diff --git a/cmake/external_libs/cmsis.cmake b/cmake/external_libs/cmsis.cmake deleted file mode 100644 index d0bed34152180e7e0ffea433bf1234f31c7575d9..0000000000000000000000000000000000000000 --- a/cmake/external_libs/cmsis.cmake +++ /dev/null @@ -1,37 +0,0 @@ -set(cmsis_pkg_name cmsis) - -if(ENABLE_GITEE OR ENABLE_GITEE_EULER) # Channel GITEE_EULER is NOT supported now, use GITEE instead. - set(REQ_URL "https://gitee.com/mirrors/CMSIS_5/repository/archive/5.7.0.tar.gz") - set(SHA256 "1b4aa6d47c7d3a5032555049b95f4962a700e2022405f863781010606fe7f8f1") -else() - set(REQ_URL "https://github.com/ARM-software/CMSIS_5/archive/5.7.0.tar.gz") - set(SHA256 "1b4aa6d47c7d3a5032555049b95f4962a700e2022405f863781010606fe7f8f1") -endif() - -set(INCLUDE "./") - -mindspore_add_pkg(${cmsis_pkg_name} - VER 5.7.0 - HEAD_ONLY ${INCLUDE} - URL ${REQ_URL} - SHA256 ${SHA256}) - -message("micro get ${cmsis_pkg_name} config hash: ${${cmsis_pkg_name}_CONFIG_HASH}") - -file(GLOB cmsic_children RELATIVE ${_MS_LIB_CACHE} ${_MS_LIB_CACHE}/*) - -foreach(child ${cmsic_children}) - string(FIND "${child}" "${cmsis_pkg_name}" position) - if(NOT "${position}" EQUAL "-1") - file(STRINGS ${_MS_LIB_CACHE}/${child}/options.txt cmsis_configs) - foreach(cmsis_config ${cmsis_configs}) - string(FIND "${cmsis_config}" "${SHA256}" position_sha256) - if(NOT "${position_sha256}" EQUAL "-1") - if(NOT IS_DIRECTORY ${CMAKE_BINARY_DIR}/${cmsis_pkg_name}) - MESSAGE("copy cmsis libaray: ${child} to ${CMAKE_BINARY_DIR}") - file(COPY ${_MS_LIB_CACHE}/${child}/CMSIS DESTINATION ${CMAKE_BINARY_DIR}/${cmsis_pkg_name}) - endif() - endif() - endforeach() - endif() -endforeach() diff --git a/cmake/external_libs/dirent.cmake b/cmake/external_libs/dirent.cmake deleted file mode 100644 index 4e58c46ca27d12047ce8a8befcd4594f27460f61..0000000000000000000000000000000000000000 --- a/cmake/external_libs/dirent.cmake +++ /dev/null @@ -1,20 +0,0 @@ -if(ENABLE_GITEE) - set(REQ_URL "https://gitee.com/mirrors/dirent/repository/archive/1.24.zip") - set(SHA256 "d762a87c01b63a1bc8d13e573a38dd067a8c290f365398cbc1a9519c8fdc720a") -else() - set(REQ_URL "https://github.com/tronkko/dirent/archive/refs/tags/1.24.zip") - set(SHA256 "46fa2833610e60275e30949c9cb4268430f945ca11fdbfa80dfad68de967103a") -endif() - - -if(MSVC) - mindspore_add_pkg(dirent - VER 1.24 - HEAD_ONLY ./include - RELEASE on - URL ${REQ_URL} - SHA256 ${SHA256}) - include_directories(${dirent_INC}) -endif() - - diff --git a/cmake/external_libs/libevent.cmake b/cmake/external_libs/libevent.cmake deleted file mode 100644 index d652cf44aac54d9e1f60ba7836ad5ab2b6b678b4..0000000000000000000000000000000000000000 --- a/cmake/external_libs/libevent.cmake +++ /dev/null @@ -1,36 +0,0 @@ -set(openssl_USE_STATIC_LIBS ON) -set(libevent_CFLAGS "-fPIC -fvisibility=hidden -fstack-protector-all -D_FORTIFY_SOURCE=2 -O2") -if(NOT CMAKE_SYSTEM_NAME MATCHES "Darwin") - set(libevent_LDFLAGS "-Wl,-z,now") -endif() - -if(NOT MINDSPORE_PROJECT_DIR) -set(MINDSPORE_PROJECT_DIR ${CMAKE_SOURCE_DIR}) -endif() - -if(ENABLE_GITEE OR ENABLE_GITEE_EULER) # Channel GITEE_EULER is NOT supported now, use GITEE instead. 
- set(REQ_URL "https://gitee.com/mirrors/libevent/repository/archive/release-2.1.12-stable.tar.gz") - set(SHA256 "7180a979aaa7000e1264da484f712d403fcf7679b1e9212c4e3d09f5c93efc24") -else() - set(REQ_URL - "https://github.com/libevent/libevent/releases/download/release-2.1.12-stable/libevent-2.1.12-stable.tar.gz") - set(SHA256 "92e6de1be9ec176428fd2367677e61ceffc2ee1cb119035037a27d346b0403bb") -endif() - -message("libevent using openssl stub dir: " ${openssl_ROOT}) - -mindspore_add_pkg(libevent - VER 2.1.12 - LIBS event event_pthreads event_core event_openssl - URL ${REQ_URL} - SHA256 ${SHA256} - PATCHES ${MINDSPORE_PROJECT_DIR}/third_party/patch/libevent/libevent.patch001 - CMAKE_OPTION -DCMAKE_BUILD_TYPE:STRING=Release -DBUILD_TESTING=OFF -DOPENSSL_ROOT_DIR:PATH=${openssl_ROOT} - -DEVENT__LIBRARY_TYPE:STRING=STATIC) - -include_directories(${libevent_INC}) - -add_library(mindspore::event ALIAS libevent::event) -add_library(mindspore::event_pthreads ALIAS libevent::event_pthreads) -add_library(mindspore::event_core ALIAS libevent::event_core) -add_library(mindspore::event_openssl ALIAS libevent::event_openssl) diff --git a/cmake/external_libs/mkl_dnn.cmake b/cmake/external_libs/mkl_dnn.cmake deleted file mode 100644 index 072569485743a9ab3b789a0e7258a97048d12a22..0000000000000000000000000000000000000000 --- a/cmake/external_libs/mkl_dnn.cmake +++ /dev/null @@ -1,44 +0,0 @@ -set(onednn_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2") -set(onednn_CFLAGS "-D_FORTIFY_SOURCE=2 -O2") -set(onednn_LDFLAGS "-s") - -if(NOT MINDSPORE_PROJECT_DIR) -set(MINDSPORE_PROJECT_DIR ${CMAKE_SOURCE_DIR}) -endif() - -if(USE_MS_THREADPOOL_FOR_DNNL) - set(USE_MS_THREADPOOL "-DDNNL_CPU_RUNTIME=THREADPOOL") -else() - set(USE_MS_THREADPOOL "") -endif() -if(ENABLE_GITEE_EULER) - set(GIT_REPOSITORY "git@gitee.com:src-openeuler/onednn.git") - set(GIT_TAG "0d726f1") - set(SHA256 "4d655c0751ee6439584ef5e3d465953fe0c2f4ee2700bc02699bdc1d1572af0d") - __download_pkg_with_git(ONEDNN ${GIT_REPOSITORY} ${GIT_TAG} ${SHA256}) - set(ONE_DNN_SRC "${CMAKE_BINARY_DIR}/_deps/onednn-src") - execute_process(COMMAND tar -xf ${ONE_DNN_SRC}/v2.2.tar.gz --strip-components 1 -C ${ONE_DNN_SRC}) -endif() - -if(ENABLE_GITEE) - set(REQ_URL "https://gitee.com/mirrors/MKL-DNN/repository/archive/v2.2.tar.gz") - set(SHA256 "2e809b11727af9d10784a5481b445a14387297161b5cc7f9c969c57fe40752bc") -else() - set(REQ_URL "https://github.com/oneapi-src/oneDNN/archive/v2.2.tar.gz") - set(SHA256 "4d655c0751ee6439584ef5e3d465953fe0c2f4ee2700bc02699bdc1d1572af0d") -endif() -mindspore_add_pkg(onednn - VER 2.2 - LIBS dnnl mkldnn - URL ${REQ_URL} - SHA256 ${SHA256} - PATCHES ${MINDSPORE_PROJECT_DIR}/third_party/patch/onednn/0001-fix-user-threadpool-bug.patch - PATCHES ${MINDSPORE_PROJECT_DIR}/third_party/patch/onednn/0002-fix-pool-nthr-bug.patch - PATCHES ${MINDSPORE_PROJECT_DIR}/third_party/patch/onednn/0003-fix-zero-threads-identified-on-AMD.patch - PATCHES ${MINDSPORE_PROJECT_DIR}/third_party/patch/onednn/0004-fix-dnnl-limits.patch - CMAKE_OPTION -DDNNL_ARCH_OPT_FLAGS='' -DDNNL_BUILD_EXAMPLES=OFF -DDNNL_BUILD_TESTS=OFF - ${USE_MS_THREADPOOL} -DDNNL_ENABLE_CONCURRENT_EXEC=ON) - -include_directories(${onednn_INC}) -add_library(mindspore::dnnl ALIAS onednn::dnnl) -add_library(mindspore::mkldnn ALIAS onednn::mkldnn) diff --git a/cmake/external_libs/robin.cmake b/cmake/external_libs/robin.cmake deleted file mode 100644 index ea1c9dd0c458532e25ec7bb719171d2024993d4e..0000000000000000000000000000000000000000 --- a/cmake/external_libs/robin.cmake +++ /dev/null @@ -1,19 +0,0 @@ 
-if(ENABLE_GITEE OR ENABLE_GITEE_EULER) # Channel GITEE_EULER is NOT supported now, use GITEE instead. - set(REQ_URL "https://gitee.com/mirrors/robin-hood-hashing/repository/archive/3.11.5.zip") - set(SHA256 "8d1f5d5ee447e5827032d1eb8b1609134618b1cc5c5bcadfcbfed99a2d3583d4") -else() - set(REQ_URL "https://github.com/martinus/robin-hood-hashing/archive/3.11.5.zip") - set(SHA256 "7aa183252527ded7f46186c1e2f4efe7d6139a3b7c0869c1b6051bd7260587ed") -endif() -set(INCLUDE "./src") - -mindspore_add_pkg(robin_hood_hashing - VER 3.11.5 - HEAD_ONLY ${INCLUDE} - URL ${REQ_URL} - SHA256 ${SHA256} - PATCHES ${TOP_DIR}/third_party/patch/robin_hood_hashing/0001-fix-unused-var-warning.patch - PATCHES ${TOP_DIR}/third_party/patch/robin_hood_hashing/0002-fix-string-isflat-symbol.patch - ) - -include_directories(${robin_hood_hashing_INC}) diff --git a/cmake/package_lite.cmake b/cmake/package_lite.cmake index 8e33fc3dcfbe6be0d1b34dc22d9434e2bac527a5..06cca6c53a34c6edc729d803ca3e12e00c8a1757 100644 --- a/cmake/package_lite.cmake +++ b/cmake/package_lite.cmake @@ -867,10 +867,6 @@ else() endif() install(FILES ${glog_LIBPATH}/${glog_name} DESTINATION ${RUNTIME_LIB_DIR} RENAME libmindspore_glog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME}) - if(MSLITE_DEPS_MKLDNN) - install(FILES ${onednn_LIBPATH}/libdnnl.so.2.2 DESTINATION ${DNNL_DIR} - RENAME libdnnl.so.2 COMPONENT ${RUNTIME_COMPONENT_NAME}) - endif() install(TARGETS mindspore_core mindspore_ops DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${BUILD_DIR}/src/extendrt/convert/libruntime_convert_plugin.so DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) diff --git a/cmake/package_micro.cmake b/cmake/package_micro.cmake index 3f0c5ccad83e28b6bfe2d3e27f20a37f570a172f..ec07167019d3900bc760e634f04c8ea09dfa1cc1 100644 --- a/cmake/package_micro.cmake +++ b/cmake/package_micro.cmake @@ -38,8 +38,4 @@ function(__install_micro_codegen) COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") install(DIRECTORY ${MICRO_CMSIS_DIR}/NN/Include DESTINATION ${CODEGEN_ROOT_DIR}/third_party/include/CMSIS/NN COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") - if(MSLITE_DEPS_CMSIS) - install(TARGETS cmsis_nn ARCHIVE DESTINATION ${CODEGEN_ROOT_DIR}/third_party/lib - COMPONENT ${RUNTIME_COMPONENT_NAME}) - endif() endfunction() diff --git a/mindspore-lite/CMakeLists.txt b/mindspore-lite/CMakeLists.txt index c497ee12a09a0839cfdf1cdab3d2a72e29f94e52..dcec5cba1d834df1a1c31482fa5ec278efab418c 100644 --- a/mindspore-lite/CMakeLists.txt +++ b/mindspore-lite/CMakeLists.txt @@ -256,8 +256,6 @@ endif() if(MSLITE_TARGET_SITEAI) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/providers/siteai) -else() - set(MSLITE_DEPS_CMSIS on) endif() if(DEFINED ENV{MSLITE_ENABLE_MODEL_ENCRYPTION}) @@ -311,9 +309,6 @@ endif() if(DEFINED ENV{ENABLE_FAST_HASH_TABLE}) add_compile_definitions(ENABLE_FAST_HASH_TABLE) - if(NOT MSLITE_TARGET_SITEAI) - set(MSLITE_DEPS_ROBIN_HOOD_HASHING on) - endif() endif() if(DEFINED ENV{MSLITE_ENABLE_MODEL_OBF}) @@ -517,7 +512,6 @@ if(MSVC) if(MSLITE_ENABLE_RUNTIME_GLOG) add_definitions(-DNOMINMAX) add_definitions(-DNOGDI) - set(MSLITE_DEPS_DIRENT on) endif() endif() @@ -747,7 +741,6 @@ set(NNACL_DIR ${OPS_DIR}/kernel/cpu/nnacl) if(PLATFORM_MCU) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-incompatible-pointer-types") -# set(MSLITE_DEPS_CMSIS on) add_subdirectory(${NNACL_DIR} build/nnacl) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter/micro/cmake/cortex-m/ build) 
include(${TOP_DIR}/cmake/package_lite.cmake) @@ -941,14 +934,8 @@ if(MSLITE_ENABLE_CONVERTER OR MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENA set(MSLITE_DEPS_EIGEN on) endif() -if(NOT MSLITE_TARGET_SITEAI) - set(MSLITE_DEPS_CMSIS on) -endif() - if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) if(NOT MSLITE_TARGET_SITEAI) - set(MSLITE_DEPS_MKLDNN on) - set(MSLITE_DEPS_LIBEVENT on) set(MSLITE_DEPS_PYBIND11 on) endif() if(SUPPORT_TENSORRT) diff --git a/mindspore-lite/cmake/ccsrc_module.cmake b/mindspore-lite/cmake/ccsrc_module.cmake index 6283cf57f36c320dbef8169df72d2d7645403025..305387473a2f306d4edd44f7ba65009f9dd7dc60 100644 --- a/mindspore-lite/cmake/ccsrc_module.cmake +++ b/mindspore-lite/cmake/ccsrc_module.cmake @@ -16,9 +16,7 @@ message(${COMM_PROTO_IN}) ms_protobuf_generate(COMM_PROTO_SRCS COMM_PROTO_HDRS ${COMM_PROTO_IN}) list(APPEND MINDSPORE_PROTO_LIST ${COMM_PROTO_SRCS}) -include(${TOP_DIR}/cmake/external_libs/robin.cmake) include(${TOP_DIR}/cmake/external_libs/eigen.cmake) -include(${TOP_DIR}/cmake/external_libs/mkl_dnn.cmake) find_package(Python3 COMPONENTS Interpreter Development) if(Python3_FOUND) @@ -34,4 +32,3 @@ if(Python3_FOUND) include(${TOP_DIR}/cmake/external_libs/pybind11.cmake) endif() endif() -include(${TOP_DIR}/cmake/external_libs/libevent.cmake) diff --git a/mindspore-lite/cmake/lite_dependences.cmake b/mindspore-lite/cmake/lite_dependences.cmake index 295a51cd781fd3234c2722e2f967b45e80cc3cb8..2b5f27cda562ef96194f493d8f75dcfa0fa59826 100644 --- a/mindspore-lite/cmake/lite_dependences.cmake +++ b/mindspore-lite/cmake/lite_dependences.cmake @@ -2,10 +2,6 @@ set(MINDSPORE_PROJECT_DIR ${TOP_DIR}) find_required_package(Patch) -if(MSLITE_DEPS_ROBIN_HOOD_HASHING) - include(${TOP_DIR}/cmake/external_libs/robin.cmake) -endif() - if(MSLITE_DEPS_FLATBUFFERS) include(${TOP_DIR}/cmake/external_libs/flatbuffers.cmake) endif() @@ -35,24 +31,6 @@ if(MSLITE_DEPS_OPENCV) include(${TOP_DIR}/cmake/external_libs/opencv.cmake) endif() -if(MSLITE_DEPS_FAST_TRANSFORMERS) - include(${TOP_DIR}/cmake/external_libs/fast_transformers.cmake) -endif() - -if(MSLITE_DEPS_MKLDNN) - if(CMAKE_SYSTEM_NAME MATCHES "Linux") - set(USE_MS_THREADPOOL_FOR_DNNL ON) - endif() - if(USE_MS_THREADPOOL_FOR_DNNL) - add_compile_definitions(USE_MS_THREADPOOL_FOR_DNNL) - endif() - include(${TOP_DIR}/cmake/external_libs/mkl_dnn.cmake) -endif() - -if(MSLITE_DEPS_LIBEVENT) - include(${TOP_DIR}/cmake/external_libs/libevent.cmake) -endif() - if(MSLITE_DEPS_PYBIND11) find_package(Python3 COMPONENTS Interpreter Development) set(PYTHON_LIBRARIES ${Python3_LIBRARIES}) @@ -76,8 +54,4 @@ endif() if(MSLITE_DEPS_OPENSSL) include(${TOP_DIR}/cmake/external_libs/openssl.cmake) -endif() - -if(MSLITE_DEPS_DIRENT) - include(${TOP_DIR}/cmake/external_libs/dirent.cmake) endif() \ No newline at end of file diff --git a/mindspore-lite/providers/siteai/CMakeLists.txt b/mindspore-lite/providers/siteai/CMakeLists.txt index dccb9942a6eaa4f05c78935da7b2f6d1fb55a301..ceb87c65b04b68424a2c57a1aaf47bf2f3d2f620 100644 --- a/mindspore-lite/providers/siteai/CMakeLists.txt +++ b/mindspore-lite/providers/siteai/CMakeLists.txt @@ -3,11 +3,7 @@ project(SiteAi) ##disable external libs set(MSLITE_DEPS_PYBIND11 on CACHE INTERNAL "setting MSLITE_DEPS_PYBIND11 value") -set(MSLITE_DEPS_ROBIN_HOOD_HASHING off CACHE INTERNAL "setting MSLITE_DEPS_ROBIN_HOOD_HASHING value") -set(MSLITE_DEPS_MKLDNN off CACHE INTERNAL "setting MSLITE_DEPS_MKLDNN value") -set(MSLITE_DEPS_LIBEVENT off CACHE INTERNAL "setting MSLITE_DEPS_LIBEVENT value") 
set(MSLITE_DEPS_OPENSSL off CACHE INTERNAL "setting MSLITE_DEPS_OPENSSL value") -set(MSLITE_DEPS_CMSIS off CACHE INTERNAL "setting MSLITE_DEPS_CMSIS value") ##enable prune simplest cloud inferenc set(MSLITE_SIMPLEST_CLOUD_INFERENCE on CACHE INTERNAL "setting MSLITE_SIMPLEST_CLOUD_INFERENCE value") diff --git a/mindspore-lite/src/extendrt/CMakeLists.txt b/mindspore-lite/src/extendrt/CMakeLists.txt index afd29018db28ab53f60ec831c20c114b49c7110e..239d1964e8e1f9da3355f4e087893828d173380a 100644 --- a/mindspore-lite/src/extendrt/CMakeLists.txt +++ b/mindspore-lite/src/extendrt/CMakeLists.txt @@ -177,22 +177,6 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) mindspore-lite-proto) target_link_libraries(mindspore-extendrt_static _mindspore_cpu_kernel_mod_depend_obj mindspore-lite-proto) - if(MSLITE_DEPS_MKLDNN) - add_dependencies(mindspore-extendrt mindspore::dnnl) - target_link_libraries(mindspore-extendrt mindspore::dnnl) - add_dependencies(mindspore-extendrt_static mindspore::dnnl) - target_link_libraries(mindspore-extendrt_static mindspore::dnnl) - endif() - - if(MSLITE_DEPS_MKLDNN) - set(CPU_KERNEL_OBJECT_COUNT 0) - add_subdirectory(${OPS_DIR}/kernel/cpu lite_kernel_mod) - foreach(number RANGE 1 ${CPU_KERNEL_OBJECT_COUNT}) - target_link_libraries(mindspore-extendrt _mindspore_ops_cpu_kernel_obj) - target_link_libraries(mindspore-extendrt_static _mindspore_ops_cpu_kernel_obj) - endforeach() - endif() - endif() if(NOT WIN32) diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/CMakeLists.txt b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/CMakeLists.txt index e1c1a89efbf4a59dfaa72ae52dcd99c046011312..a187f3aea96b840478f365e9ef87b5af383f56a8 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/CMakeLists.txt +++ b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/CMakeLists.txt @@ -186,9 +186,9 @@ endif() if(ENABLE_CPU) if(BUILD_LITE) - target_link_libraries(mslite_shared_lib PRIVATE mindspore::dnnl mindspore::mkldnn nnacl) + target_link_libraries(mslite_shared_lib PRIVATE mindspore::dnnl nnacl) else() - target_link_libraries(mindspore_shared_lib PRIVATE mindspore::dnnl mindspore::mkldnn nnacl) + target_link_libraries(mindspore_shared_lib PRIVATE mindspore::dnnl nnacl) endif() endif() diff --git a/mindspore-lite/tools/converter/micro/cmake/cortex-m/CMakeLists.txt b/mindspore-lite/tools/converter/micro/cmake/cortex-m/CMakeLists.txt index 07dbb8f673f87c2b14676ca4e09ae0450ef298e7..d5b9fe30459d04913a2228f6dd533ce23f4026b0 100644 --- a/mindspore-lite/tools/converter/micro/cmake/cortex-m/CMakeLists.txt +++ b/mindspore-lite/tools/converter/micro/cmake/cortex-m/CMakeLists.txt @@ -7,7 +7,6 @@ set(MICRO_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../..) include_directories(${NNACL_DIR}/..) 
include(${TOP_DIR}/cmake/utils.cmake) -include(${TOP_DIR}/cmake/external_libs/cmsis.cmake) set(CMSIS_DIR ${CMAKE_BINARY_DIR}/cmsis) message("build cmsis kernels") diff --git a/mindspore-lite/tools/converter/micro/coder/CMakeLists.txt b/mindspore-lite/tools/converter/micro/coder/CMakeLists.txt index 3076495cb0ecff35983eabc55cb7b784d6cbb422..5ce3ba1152a4172af21b2f4e548e4e959ee09bd5 100644 --- a/mindspore-lite/tools/converter/micro/coder/CMakeLists.txt +++ b/mindspore-lite/tools/converter/micro/coder/CMakeLists.txt @@ -20,10 +20,6 @@ include_directories(${LITE_DIR}) #include coder if(NOT MSVC OR NOT WIN32 OR NOT APPLE) - if(MSLITE_DEPS_CMSIS) - message("MSLITE_DEPS_CMSIS enabled") - include(${TOP_DIR}/cmake/external_libs/cmsis.cmake) - endif() include(${MICRO_DIR}/cmake/package_wrapper.cmake) add_subdirectory(wrapper) endif() diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/CMakeLists.txt b/mindspore-lite/tools/converter/micro/coder/wrapper/CMakeLists.txt index b6317f090db46964254e3a2091095b24bb80fe6a..e121aa4c2db02f51904167aab2a6cf3b3ed84d22 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/CMakeLists.txt +++ b/mindspore-lite/tools/converter/micro/coder/wrapper/CMakeLists.txt @@ -16,29 +16,6 @@ if(PLATFORM_ARM64) add_compile_definitions(ENABLE_ARM64) elseif(PLATFORM_ARM32) add_compile_definitions(ENABLE_ARM32) -else() - if(MSLITE_DEPS_CMSIS) - message("MSLITE_DEPS_CMSIS enabled") - set(CMSIS_DIR ${CMAKE_BINARY_DIR}/cmsis) - message("build cmsis kernels") - include_directories(${CMSIS_DIR}/CMSIS/Core/Include) - include_directories(${CMSIS_DIR}/CMSIS/DSP/Include) - include_directories(${CMSIS_DIR}/CMSIS/NN/Include) - - file(REMOVE ${CMSIS_DIR}/CMSIS/NN/Source/NNSupportFunctions/arm_q7_to_q15_reordered_no_shift.c) - file(GLOB CMSIS_OPS - ${CMSIS_DIR}/CMSIS/NN/Source/BasicMathFunctions/*.c - ${CMSIS_DIR}/CMSIS/NN/Source/ActivationFunctions/*.c - ${CMSIS_DIR}/CMSIS/NN/Source/ConcatenationFunctions/*.c - ${CMSIS_DIR}/CMSIS/NN/Source/ConvolutionFunctions/*.c - ${CMSIS_DIR}/CMSIS/NN/Source/FullyConnectedFunctions/*.c - ${CMSIS_DIR}/CMSIS/NN/Source/NNSupportFunctions/*.c - ${CMSIS_DIR}/CMSIS/NN/Source/PoolingFunctions/*.c - ${CMSIS_DIR}/CMSIS/NN/Source/ReshapeFunctions/*.c - ${CMSIS_DIR}/CMSIS/NN/Source/SoftmaxFunctions/*.c - ) - add_library(cmsis_nn STATIC ${CMSIS_OPS}) - endif() endif() include(${MICRO_DIR}/cmake/package_wrapper.cmake) diff --git a/third_party/patch/fast_transformer/001-fast_transformer.patch b/third_party/patch/fast_transformer/001-fast_transformer.patch deleted file mode 100644 index e3a6542d2e1031d3ab8818551970c6624551e33e..0000000000000000000000000000000000000000 --- a/third_party/patch/fast_transformer/001-fast_transformer.patch +++ /dev/null @@ -1,615816 +0,0 @@ -diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml -new file mode 100644 -index 0000000..18054db ---- /dev/null -+++ b/.github/ISSUE_TEMPLATE/bug_report.yml -@@ -0,0 +1,32 @@ -+name: "Bug Report" -+description: Submit a bug report -+labels: [ "bug" ] -+body: -+ - type: textarea -+ id: description -+ attributes: -+ label: Description -+ description: Please share your system info with us. -+ render: shell -+ placeholder: branch, docker version, GPU type -+ validations: -+ required: true -+ -+ - type: textarea -+ id: reproduced-steps -+ attributes: -+ label: Reproduced Steps -+ description: Please provide the step to reproduce the bugs -+ render: shell -+ placeholder: | -+ Steps to reproduce your bugs: -+ -+ 1. 
docker run -ti --gpus all nvcr.io/nvidia/pytorch:22.03-py3 bash -+ 2. git clone https://github.com/NVIDIA/FasterTransformer.git -+ 3. cd FasterTransformer mkdir build && cd build -+ 4. cmake -DSM=80 -DCMAKE_BUILD_TYPE=Release .. && make -j12 -+ 5. ./bin/bert_example 32 12 32 12 64 0 0 -+ 6. What error you see. -+ -+ validations: -+ required: true -diff --git a/.gitignore b/.gitignore -index d15caca..0bca65e 100644 ---- a/.gitignore -+++ b/.gitignore -@@ -6,3 +6,4 @@ __pycache__/ - .vscode - ./translation - .cache -+ 001-fast_transformer.patch -diff --git a/.vscode/settings.json b/.vscode/settings.json -deleted file mode 100644 -index 6f535da..0000000 ---- a/.vscode/settings.json -+++ /dev/null -@@ -1,72 +0,0 @@ --{ -- "files.associations": { -- "*.cuh": "cpp", -- "stdexcept": "cpp", -- "chrono": "cpp", -- "cmath": "cpp", -- "type_traits": "cpp", -- "cctype": "cpp", -- "clocale": "cpp", -- "cstdarg": "cpp", -- "cstddef": "cpp", -- "cstdio": "cpp", -- "cstdlib": "cpp", -- "cstring": "cpp", -- "ctime": "cpp", -- "cwchar": "cpp", -- "cwctype": "cpp", -- "array": "cpp", -- "atomic": "cpp", -- "*.tcc": "cpp", -- "condition_variable": "cpp", -- "cstdint": "cpp", -- "deque": "cpp", -- "unordered_map": "cpp", -- "vector": "cpp", -- "exception": "cpp", -- "algorithm": "cpp", -- "functional": "cpp", -- "iterator": "cpp", -- "map": "cpp", -- "memory": "cpp", -- "memory_resource": "cpp", -- "numeric": "cpp", -- "optional": "cpp", -- "random": "cpp", -- "ratio": "cpp", -- "set": "cpp", -- "string": "cpp", -- "string_view": "cpp", -- "system_error": "cpp", -- "tuple": "cpp", -- "utility": "cpp", -- "fstream": "cpp", -- "initializer_list": "cpp", -- "iomanip": "cpp", -- "iosfwd": "cpp", -- "iostream": "cpp", -- "istream": "cpp", -- "limits": "cpp", -- "mutex": "cpp", -- "new": "cpp", -- "ostream": "cpp", -- "sstream": "cpp", -- "streambuf": "cpp", -- "thread": "cpp", -- "cinttypes": "cpp", -- "typeinfo": "cpp", -- "bitset": "cpp", -- "hash_map": "cpp", -- "hash_set": "cpp", -- "slist": "cpp", -- "regex": "cpp", -- "strstream": "cpp", -- "complex": "cpp", -- "forward_list": "cpp", -- "list": "cpp", -- "unordered_set": "cpp", -- "future": "cpp", -- "cfenv": "cpp", -- "typeindex": "cpp" -- } --} -\ No newline at end of file -diff --git a/3rdparty/cutlass/LICENSE.txt b/3rdparty/cutlass/LICENSE.txt -new file mode 100644 -index 0000000..2913ab8 ---- /dev/null -+++ b/3rdparty/cutlass/LICENSE.txt -@@ -0,0 +1,27 @@ -+Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+SPDX-License-Identifier: BSD-3-Clause -+ -+Redistribution and use in source and binary forms, with or without -+modification, are permitted provided that the following conditions are met: -+ -+1. Redistributions of source code must retain the above copyright notice, this -+list of conditions and the following disclaimer. -+ -+2. Redistributions in binary form must reproduce the above copyright notice, -+this list of conditions and the following disclaimer in the documentation -+and/or other materials provided with the distribution. -+ -+3. Neither the name of the copyright holder nor the names of its -+contributors may be used to endorse or promote products derived from -+this software without specific prior written permission. -+ -+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -diff --git a/3rdparty/cutlass/cmake/nop.cu b/3rdparty/cutlass/cmake/nop.cu -new file mode 100644 -index 0000000..efdb035 ---- /dev/null -+++ b/3rdparty/cutlass/cmake/nop.cu -@@ -0,0 +1,49 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Basic CUDA file for testing compiler flags. -+*/ -+ -+__device__ int inner() -+{ -+ return -1; -+} -+ -+__global__ void test() -+{ -+ inner(); -+} -+ -+int main() -+{ -+ test<<<1,1>>>(); -+ return 0; -+} -diff --git a/3rdparty/cutlass/examples/00_basic_gemm/basic_gemm.cu b/3rdparty/cutlass/examples/00_basic_gemm/basic_gemm.cu -new file mode 100644 -index 0000000..57df36b ---- /dev/null -+++ b/3rdparty/cutlass/examples/00_basic_gemm/basic_gemm.cu -@@ -0,0 +1,497 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/* -+ This example demonstrates how to call a CUTLASS GEMM kernel and provides a naive reference -+ matrix multiply kernel to verify its correctness. -+ -+ The CUTLASS Gemm template is instantiated in the function CutlassSgemmNN. This is kernel computes -+ the general matrix product (GEMM) using single-precision floating-point arithmetic and assumes -+ all matrices have column-major layout. -+ -+ The threadblock tile size is chosen as 128x128x8 which offers good performance for large matrices. -+ See the CUTLASS Parallel for All blog post for more exposition on the tunable parameters available -+ in CUTLASS. -+ -+ https://devblogs.nvidia.com/cutlass-linear-algebra-cuda/ -+ -+ Aside from defining and launching the SGEMM kernel, this example does not use any other components -+ or utilities within CUTLASS. Such utilities are demonstrated elsewhere in other examples and are -+ prevalent in the CUTLASS unit tests. -+ -+ This example has delibrately been kept similar to the basic_gemm example from cutass-1.3 to -+ highlight the minimum amount of differences needed to transition to cutlass-2.0. -+ -+ Cutlass-1.3 sgemm: https://github.com/NVIDIA/cutlass/blob/master/examples/00_basic_gemm/basic_gemm.cu -+*/ -+ -+// Standard Library includes -+#include -+#include -+#include -+ -+// Helper methods to check for errors -+#include "helper.h" -+ -+// -+// CUTLASS includes needed for single-precision GEMM kernel -+// -+ -+// Defines cutlass::gemm::device::Gemm, the generic Gemm computation template class. -+#include "cutlass/gemm/device/gemm.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// This function defines a CUTLASS GEMM kernel instantiation, constructs its parameters object, -+// and launches it on the CUDA device. -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Define a CUTLASS GEMM template and launch a GEMM kernel. 
-+cudaError_t CutlassSgemmNN( -+ int M, -+ int N, -+ int K, -+ float alpha, -+ float const *A, -+ int lda, -+ float const *B, -+ int ldb, -+ float beta, -+ float *C, -+ int ldc) { -+ -+ // Define type definition for single-precision CUTLASS GEMM with column-major -+ // input matrices and 128x128x8 threadblock tile size (chosen by default). -+ // -+ // To keep the interface manageable, several helpers are defined for plausible compositions -+ // including the following example for single-precision GEMM. Typical values are used as -+ // default template arguments. See `cutlass/gemm/device/default_gemm_configuration.h` for more details. -+ // -+ // To view the full gemm device API interface, see `cutlass/gemm/device/gemm.h` -+ -+ using ColumnMajor = cutlass::layout::ColumnMajor; -+ -+ using CutlassGemm = cutlass::gemm::device::Gemm; // Layout of C matrix -+ -+ // Define a CUTLASS GEMM type -+ CutlassGemm gemm_operator; -+ -+ // Construct the CUTLASS GEMM arguments object. -+ // -+ // One of CUTLASS's design patterns is to define gemm argument objects that are constructible -+ // in host code and passed to kernels by value. These may include pointers, strides, scalars, -+ // and other arguments needed by Gemm and its components. -+ // -+ // The benefits of this pattern are (1.) a structured, composable strategy for passing host-constructible -+ // arguments to kernels and (2.) minimized initialization overhead on kernel entry. -+ // -+ CutlassGemm::Arguments args({M , N, K}, // Gemm Problem dimensions -+ {A, lda}, // Tensor-ref for source matrix A -+ {B, ldb}, // Tensor-ref for source matrix B -+ {C, ldc}, // Tensor-ref for source matrix C -+ {C, ldc}, // Tensor-ref for destination matrix D (may be different memory than source C matrix) -+ {alpha, beta}); // Scalars used in the Epilogue -+ -+ // -+ // Launch the CUTLASS GEMM kernel. -+ // -+ -+ cutlass::Status status = gemm_operator(args); -+ -+ // -+ // Return a cudaError_t if the CUTLASS GEMM operator returned an error code. -+ // -+ -+ if (status != cutlass::Status::kSuccess) { -+ return cudaErrorUnknown; -+ } -+ -+ // Return success, if no errors were encountered. -+ return cudaSuccess; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// The source code after this point in the file is generic CUDA using the CUDA Runtime API -+// and simple CUDA kernels to initialize matrices and compute the general matrix product. -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Kernel to initialize a matrix with small integers. -+__global__ void InitializeMatrix_kernel( -+ float *matrix, -+ int rows, -+ int columns, -+ int seed = 0) { -+ -+ int i = threadIdx.x + blockIdx.x * blockDim.x; -+ int j = threadIdx.y + blockIdx.y * blockDim.y; -+ -+ if (i < rows && j < columns) { -+ int offset = i + j * rows; -+ -+ // Generate arbitrary elements. -+ int const k = 16807; -+ int const m = 16; -+ float value = float(((offset + seed) * k % m) - m / 2); -+ -+ matrix[offset] = value; -+ } -+} -+ -+/// Simple function to initialize a matrix to arbitrary small integers. 
-+cudaError_t InitializeMatrix(float *matrix, int rows, int columns, int seed = 0) { -+ -+ dim3 block(16, 16); -+ dim3 grid( -+ (rows + block.x - 1) / block.x, -+ (columns + block.y - 1) / block.y -+ ); -+ -+ InitializeMatrix_kernel<<< grid, block >>>(matrix, rows, columns, seed); -+ -+ return cudaGetLastError(); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Allocates device memory for a matrix then fills with arbitrary small integers. -+cudaError_t AllocateMatrix(float **matrix, int rows, int columns, int seed = 0) { -+ cudaError_t result; -+ -+ size_t sizeof_matrix = sizeof(float) * rows * columns; -+ -+ // Allocate device memory. -+ result = cudaMalloc(reinterpret_cast(matrix), sizeof_matrix); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to allocate matrix: " -+ << cudaGetErrorString(result) << std::endl; -+ return result; -+ } -+ -+ // Clear the allocation. -+ result = cudaMemset(*matrix, 0, sizeof_matrix); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to clear matrix device memory: " -+ << cudaGetErrorString(result) << std::endl; -+ return result; -+ } -+ -+ // Initialize matrix elements to arbitrary small integers. -+ result = InitializeMatrix(*matrix, rows, columns, seed); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to initialize matrix: " -+ << cudaGetErrorString(result) << std::endl; -+ return result; -+ } -+ -+ return result; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Naive reference GEMM computation. -+__global__ void ReferenceGemm_kernel( -+ int M, -+ int N, -+ int K, -+ float alpha, -+ float const *A, -+ int lda, -+ float const *B, -+ int ldb, -+ float beta, -+ float *C, -+ int ldc) { -+ -+ int i = threadIdx.x + blockIdx.x * blockDim.x; -+ int j = threadIdx.y + blockIdx.y * blockDim.y; -+ -+ if (i < M && j < N) { -+ float accumulator = 0; -+ -+ for (int k = 0; k < K; ++k) { -+ accumulator += A[i + k * lda] * B[k + j * ldb]; -+ } -+ -+ C[i + j * ldc] = alpha * accumulator + beta * C[i + j * ldc]; -+ } -+} -+ -+/// Reference GEMM computation. -+cudaError_t ReferenceGemm( -+ int M, -+ int N, -+ int K, -+ float alpha, -+ float const *A, -+ int lda, -+ float const *B, -+ int ldb, -+ float beta, -+ float *C, -+ int ldc) { -+ -+ dim3 block(16, 16); -+ dim3 grid( -+ (M + block.x - 1) / block.x, -+ (N + block.y - 1) / block.y -+ ); -+ -+ ReferenceGemm_kernel<<< grid, block >>>(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); -+ -+ return cudaGetLastError(); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Allocate several matrices in GPU device memory and call a single-precision -+/// CUTLASS GEMM kernel. -+cudaError_t TestCutlassGemm(int M, int N, int K, float alpha, float beta) { -+ cudaError_t result; -+ -+ // -+ // Define several matrices to be used as operands to GEMM kernels. -+ // -+ -+ // Compute leading dimensions for each matrix. -+ int lda = M; -+ int ldb = K; -+ int ldc = M; -+ -+ // Compute size in bytes of the C matrix. -+ size_t sizeof_C = sizeof(float) * ldc * N; -+ -+ // Define pointers to matrices in GPU device memory. -+ float *A; -+ float *B; -+ float *C_cutlass; -+ float *C_reference; -+ -+ // -+ // Allocate matrices in GPU device memory with arbitrary seeds. 
-+ // -+ -+ result = AllocateMatrix(&A, M, K, 0); -+ -+ if (result != cudaSuccess) { -+ return result; -+ } -+ -+ result = AllocateMatrix(&B, K, N, 17); -+ -+ if (result != cudaSuccess) { -+ cudaFree(A); -+ return result; -+ } -+ -+ result = AllocateMatrix(&C_cutlass, M, N, 101); -+ -+ if (result != cudaSuccess) { -+ cudaFree(A); -+ cudaFree(B); -+ return result; -+ } -+ -+ result = AllocateMatrix(&C_reference, M, N, 101); -+ -+ if (result != cudaSuccess) { -+ cudaFree(A); -+ cudaFree(B); -+ cudaFree(C_cutlass); -+ return result; -+ } -+ -+ result = cudaMemcpy(C_reference, C_cutlass, sizeof_C, cudaMemcpyDeviceToDevice); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to copy C_cutlass matrix to C_reference: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // -+ // Launch CUTLASS GEMM. -+ // -+ -+ result = CutlassSgemmNN(M, N, K, alpha, A, lda, B, ldb, beta, C_cutlass, ldc); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "CUTLASS GEMM kernel failed: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // -+ // Verify. -+ // -+ -+ // Launch reference GEMM -+ result = ReferenceGemm(M, N, K, alpha, A, lda, B, ldb, beta, C_reference, ldc); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Reference GEMM kernel failed: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // Copy to host and verify equivalence. -+ std::vector host_cutlass(ldc * N, 0); -+ std::vector host_reference(ldc * N, 0); -+ -+ result = cudaMemcpy(host_cutlass.data(), C_cutlass, sizeof_C, cudaMemcpyDeviceToHost); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to copy CUTLASS GEMM results: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ result = cudaMemcpy(host_reference.data(), C_reference, sizeof_C, cudaMemcpyDeviceToHost); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to copy Reference GEMM results: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // -+ // Free device memory allocations. -+ // -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ // -+ // Test for bit equivalence of results. -+ // -+ -+ if (host_cutlass != host_reference) { -+ std::cerr << "CUTLASS results incorrect." << std::endl; -+ -+ return cudaErrorUnknown; -+ } -+ -+ return cudaSuccess; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Entry point to basic_gemm example. -+// -+// usage: -+// -+// 00_basic_gemm -+// -+int main(int argc, const char *arg[]) { -+ -+ // -+ // Parse the command line to obtain GEMM dimensions and scalar values. -+ // -+ -+ // GEMM problem dimensions. -+ int problem[3] = { 128, 128, 128 }; -+ -+ for (int i = 1; i < argc && i < 4; ++i) { -+ std::stringstream ss(arg[i]); -+ ss >> problem[i - 1]; -+ } -+ -+ // Scalars used for linear scaling the result of the matrix product. 
-+ float scalars[2] = { 1, 0 }; -+ -+ for (int i = 4; i < argc && i < 6; ++i) { -+ std::stringstream ss(arg[i]); -+ ss >> scalars[i - 4]; -+ } -+ -+ // -+ // Run the CUTLASS GEMM test. -+ // -+ -+ cudaError_t result = TestCutlassGemm( -+ problem[0], // GEMM M dimension -+ problem[1], // GEMM N dimension -+ problem[2], // GEMM K dimension -+ scalars[0], // alpha -+ scalars[1] // beta -+ ); -+ -+ if (result == cudaSuccess) { -+ std::cout << "Passed." << std::endl; -+ } -+ -+ // Exit. -+ return result == cudaSuccess ? 0 : -1; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/01_cutlass_utilities/cutlass_utilities.cu b/3rdparty/cutlass/examples/01_cutlass_utilities/cutlass_utilities.cu -new file mode 100644 -index 0000000..f4cc4d0 ---- /dev/null -+++ b/3rdparty/cutlass/examples/01_cutlass_utilities/cutlass_utilities.cu -@@ -0,0 +1,400 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/* -+ This example demonstrates several CUTLASS utilities in the context of a mixed-precision -+ floating-point matrix product computation. -+ -+ These utilities are intended to be useful supporting components for managing tensor and matrix -+ memory allocations, initializing and comparing results, and computing reference output. -+ -+ CUTLASS utilities are defined in the directory `tools/util`, and definitions appear -+ namespace `cutlass::` or an inner namespace therein. Operations in `cutlass::reference::` have -+ both host-side and device-side implementations, and the choice to use device-side initialization -+ and host-side verification in this example was arbitrary. 
-+ -+ -+ cutlass::half_t -+ -+ This is a numeric type implementing IEEE half-precision quantities. It is functional in host -+ and device code. In host-side code, CUTLASS_ENABLE_F16C optionally enables harware-accelerated -+ numeric conversion on x86-64 CPUs support F16C extensions. In device code, all available -+ hardware is used to implement conversion and numeric operations. -+ -+ -+ cutlass::HostTensor<> -+ -+ This template class simplifies the creation of tensors for all supported layouts. It simplifies -+ allocation and management of host- and device- memory allocations. -+ -+ This class offers methods device_view() and host_view() to provide TensorView objects for -+ device- and host-side memory allocations. -+ -+ -+ cutlass::reference::device::TensorFillRandomGaussian() -+ -+ This template function initializes elementsof a tensor to a random Gaussian distribution. It -+ uses cuRAND in device code to compute random numbers. -+ -+ -+ cutlass::reference::host::Gemm<> -+ -+ This template function computes the general matrix product. This template supports unique -+ data types for each matrix operand, the internal accumulation type, and the scalar parameters -+ alpha and beta. -+ -+ -+ cutlass::reference::host::TensorEquals() -+ -+ Compares two tensors of identical rank and returns true if values are bit equivalent. -+ -+*/ -+ -+// Standard Library includes -+#include -+#include -+#include -+#include -+ -+// CUTLASS includes needed for half-precision GEMM kernel -+#include "cutlass/cutlass.h" -+#include "cutlass/core_io.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+// -+// CUTLASS utility includes -+// -+ -+// Defines operator<<() to write TensorView objects to std::ostream -+#include "cutlass/util/tensor_view_io.h" -+ -+// Defines cutlass::HostTensor<> -+#include "cutlass/util/host_tensor.h" -+ -+// Defines cutlass::half_t -+#include "cutlass/numeric_types.h" -+ -+// Defines device_memory::copy_device_to_device() -+#include "cutlass/util/device_memory.h" -+ -+// Defines cutlass::reference::device::TensorFillRandomGaussian() -+#include "cutlass/util/reference/device/tensor_fill.h" -+ -+// Defines cutlass::reference::host::TensorEquals() -+#include "cutlass/util/reference/host/tensor_compare.h" -+ -+// Defines cutlass::reference::host::Gemm() -+#include "cutlass/util/reference/host/gemm.h" -+ -+#pragma warning( disable : 4503) -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Define a CUTLASS GEMM template and launch a GEMM kernel. 
-+cudaError_t cutlass_hgemm_nn( -+ int M, -+ int N, -+ int K, -+ cutlass::half_t alpha, -+ cutlass::half_t const *A, -+ cutlass::layout::ColumnMajor::Stride::Index lda, -+ cutlass::half_t const *B, -+ cutlass::layout::ColumnMajor::Stride::Index ldb, -+ cutlass::half_t beta, -+ cutlass::half_t *C, -+ cutlass::layout::ColumnMajor::Stride::Index ldc) { -+ -+ // Define the GEMM operation -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, // ElementA -+ cutlass::layout::ColumnMajor, // LayoutA -+ cutlass::half_t, // ElementB -+ cutlass::layout::ColumnMajor, // LayoutB -+ cutlass::half_t, // ElementOutput -+ cutlass::layout::ColumnMajor // LayoutOutput -+ >; -+ -+ Gemm gemm_op; -+ -+ cutlass::Status status = gemm_op({ -+ {M, N, K}, -+ {A, lda}, -+ {B, ldb}, -+ {C, ldc}, -+ {C, ldc}, -+ {alpha, beta} -+ }); -+ -+ if (status != cutlass::Status::kSuccess) { -+ return cudaErrorUnknown; -+ } -+ -+ return cudaSuccess; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Allocate several matrices in GPU device memory and call a single-precision -+/// CUTLASS GEMM kernel. -+cudaError_t TestCutlassGemm(int M, int N, int K, cutlass::half_t alpha, cutlass::half_t beta) { -+ cudaError_t result; -+ -+ // -+ // Construct cutlass::HostTensor<> using the half-precision host-side type. -+ // -+ // cutlass::HostTensor<> allocates memory on both the host and device corresponding to rank=2 -+ // tensors in column-major layout. Explicit synchronization methods are offered to copy the -+ // tensor to the device or to the host. -+ // -+ -+ // M-by-K matrix of cutlass::half_t -+ cutlass::HostTensor A(cutlass::MatrixCoord(M, K)); -+ -+ // K-by-N matrix of cutlass::half_t -+ cutlass::HostTensor B(cutlass::MatrixCoord(K, N)); -+ -+ // M-by-N matrix of cutlass::half_t -+ cutlass::HostTensor C_cutlass(cutlass::MatrixCoord(M, N)); -+ -+ // M-by-N matrix of cutlass::half_t -+ cutlass::HostTensor C_reference(cutlass::MatrixCoord(M, N)); -+ -+ // -+ // Initialize matrices with small, random integers. -+ // -+ -+ // Arbitrary RNG seed value. Hard-coded for deterministic results. -+ uint64_t seed = 2080; -+ -+ // Gaussian random distribution -+ cutlass::half_t mean = 0.0_hf; -+ cutlass::half_t stddev = 5.0_hf; -+ -+ // Specify the number of bits right of the binary decimal that are permitted -+ // to be non-zero. A value of "0" here truncates random values to integers -+ int bits_less_than_one = 0; -+ -+ cutlass::reference::device::TensorFillRandomGaussian( -+ A.device_view(), -+ seed, -+ mean, -+ stddev, -+ bits_less_than_one -+ ); -+ -+ cutlass::reference::device::TensorFillRandomGaussian( -+ B.device_view(), -+ seed * 2019, -+ mean, -+ stddev, -+ bits_less_than_one -+ ); -+ -+ cutlass::reference::device::TensorFillRandomGaussian( -+ C_cutlass.device_view(), -+ seed * 1993, -+ mean, -+ stddev, -+ bits_less_than_one -+ ); -+ -+ -+ // Copy C_cutlass into C_reference so the GEMM is correct when beta != 0. 
-+ cutlass::device_memory::copy_device_to_device( -+ C_reference.device_data(), -+ C_cutlass.device_data(), -+ C_cutlass.capacity()); -+ -+ // Copy the device-side view into host memory -+ C_reference.sync_host(); -+ -+ // -+ // Launch the CUTLASS GEMM kernel -+ // -+ -+ result = cutlass_hgemm_nn( -+ M, -+ N, -+ K, -+ alpha, -+ A.device_data(), -+ A.stride(0), -+ B.device_data(), -+ B.stride(0), -+ beta, -+ C_cutlass.device_data(), -+ C_cutlass.stride(0) -+ ); -+ -+ if (result != cudaSuccess) { -+ return result; -+ } -+ -+ // -+ // Verify the result using a host-side reference -+ // -+ -+ // A and B were initialized using device-side procedures. The intent of this example is to -+ // use the host-side reference GEMM, so we must perform a device-to-host copy. -+ A.sync_host(); -+ B.sync_host(); -+ -+ // Copy CUTLASS's GEMM results into host memory. -+ C_cutlass.sync_host(); -+ -+ // Compute the reference result using the host-side GEMM reference implementation. -+ cutlass::reference::host::Gemm< -+ cutlass::half_t, // ElementA -+ cutlass::layout::ColumnMajor, // LayoutA -+ cutlass::half_t, // ElementB -+ cutlass::layout::ColumnMajor, // LayoutB -+ cutlass::half_t, // ElementOutput -+ cutlass::layout::ColumnMajor, // LayoutOutput -+ cutlass::half_t, -+ cutlass::half_t -+ > gemm_ref; -+ -+ gemm_ref( -+ {M, N, K}, // problem size (type: cutlass::gemm::GemmCoord) -+ alpha, // alpha (type: cutlass::half_t) -+ A.host_ref(), // A (type: TensorRef) -+ B.host_ref(), // B (type: TensorRef) -+ beta, // beta (type: cutlass::half_t) -+ C_reference.host_ref() // C (type: TensorRef) -+ ); -+ -+ // Compare reference to computed results. -+ if (!cutlass::reference::host::TensorEquals( -+ C_reference.host_view(), -+ C_cutlass.host_view())) { -+ -+ char const *filename = "errors_01_cutlass_utilities.csv"; -+ -+ std::cerr << "Error - CUTLASS GEMM kernel differs from reference. Wrote computed and reference results to '" << filename << "'" << std::endl; -+ -+ // -+ // On error, print C_cutlass and C_reference to std::cerr. -+ // -+ // Note, these are matrices of half-precision elements stored in host memory as -+ // arrays of type cutlass::half_t. -+ // -+ -+ std::ofstream file(filename); -+ -+ // Result of CUTLASS GEMM kernel -+ file << "\n\nCUTLASS =\n" << C_cutlass.host_view() << std::endl; -+ -+ // Result of reference computation -+ file << "\n\nReference =\n" << C_reference.host_view() << std::endl; -+ -+ // Return error code. -+ return cudaErrorUnknown; -+ } -+ -+ // Passed error check -+ return cudaSuccess; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Entry point to cutlass_utilities example. -+// -+// usage: -+// -+// 01_cutlass_utilities -+// -+int main(int argc, const char *arg[]) { -+ -+ // -+ // This example uses half-precision and is only suitable for devices with compute capabitliy 5.3 or greater. -+ // -+ -+ cudaDeviceProp prop; -+ cudaError_t result = cudaGetDeviceProperties(&prop, 0); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to query device properties with error " << cudaGetErrorString(result) << std::endl; -+ return -1; -+ } -+ -+ if (!(prop.major > 5 || (prop.major == 5 && prop.minor >= 3))) { -+ std::cerr << "This example uses half precision and is only suitable for devices with compute capability 5.3 or greater.\n"; -+ std::cerr << "You are using a CUDA device with compute capability " << prop.major << "." 
<< prop.minor << std::endl; -+ return -1; -+ } -+ -+ // -+ // Parse the command line to obtain GEMM dimensions and scalar values. -+ // -+ -+ // GEMM problem dimensions: -+ int problem[3] = { 128, 128, 128 }; -+ -+ for (int i = 1; i < argc && i < 4; ++i) { -+ std::stringstream ss(arg[i]); -+ ss >> problem[i - 1]; -+ } -+ -+ // Linear scale factors in GEMM. Note, these are half-precision values stored as -+ // cutlass::half_t. -+ // -+ // Values outside the range of IEEE FP16 will overflow to infinity or underflow to zero. -+ // -+ cutlass::half_t scalars[2] = { 1.0_hf, 0.0_hf }; -+ -+ for (int i = 4; i < argc && i < 6; ++i) { -+ std::stringstream ss(arg[i]); -+ -+ ss >> scalars[i - 4]; // lexical cast to cutlass::half_t -+ } -+ -+ // -+ // Run the CUTLASS GEMM test. -+ // -+ -+ result = TestCutlassGemm( -+ problem[0], // GEMM M dimension -+ problem[1], // GEMM N dimension -+ problem[2], // GEMM K dimension -+ scalars[0], // alpha -+ scalars[1] // beta -+ ); -+ -+ if (result == cudaSuccess) { -+ std::cout << "Passed." << std::endl; -+ } -+ -+ // Exit. -+ return result == cudaSuccess ? 0 : -1; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/examples/02_dump_reg_shmem/dump_reg_shmem.cu b/3rdparty/cutlass/examples/02_dump_reg_shmem/dump_reg_shmem.cu -new file mode 100644 -index 0000000..f70e721 ---- /dev/null -+++ b/3rdparty/cutlass/examples/02_dump_reg_shmem/dump_reg_shmem.cu -@@ -0,0 +1,186 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ \brief Demonstrate CUTLASS debugging tool for dumping fragments and shared -+ memory -+ */ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Standard Library includes -+ -+#include -+ -+// -+// CUTLASS includes -+// -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h" -+ -+#include "cutlass/util/debug.h" -+#include "cutlass/util/device_dump.h" -+ -+#define EXAMPLE_MATRIX_ROW 64 -+#define EXAMPLE_MATRIX_COL 32 -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void kernel_dump(typename GmemIterator::Params params, -+ typename GmemIterator::TensorRef ref) { -+ extern __shared__ Element shared_storage[]; -+ -+ // Construct the global iterator and load the data to the fragments. -+ int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; -+ -+ GmemIterator gmem_iterator(params, ref.data(), -+ {EXAMPLE_MATRIX_ROW, EXAMPLE_MATRIX_COL}, -+ tb_thread_id); -+ -+ typename GmemIterator::Fragment frag; -+ -+ frag.clear(); -+ gmem_iterator.load(frag); -+ -+ // Call dump_fragment() with different parameters. -+ if (threadIdx.x == 0 && blockIdx.x == 0) -+ printf("\nAll threads dump all the elements:\n"); -+ cutlass::debug::dump_fragment(frag); -+ -+ if (threadIdx.x == 0 && blockIdx.x == 0) -+ printf("\nFirst thread dumps all the elements:\n"); -+ cutlass::debug::dump_fragment(frag, /*N = */ 1); -+ -+ if (threadIdx.x == 0 && blockIdx.x == 0) -+ printf("\nFirst thread dumps first 16 elements:\n"); -+ cutlass::debug::dump_fragment(frag, /*N = */ 1, /*M = */ 16); -+ -+ if (threadIdx.x == 0 && blockIdx.x == 0) -+ printf("\nFirst thread dumps first 16 elements with a stride of 8:\n"); -+ cutlass::debug::dump_fragment(frag, /*N = */ 1, /*M = */ 16, /*S = */ 8); -+ -+ // Construct the shared iterator and store the data to the shared memory. -+ SmemIterator smem_iterator( -+ typename SmemIterator::TensorRef( -+ {shared_storage, SmemIterator::Layout::packed( -+ {EXAMPLE_MATRIX_ROW, EXAMPLE_MATRIX_COL})}), -+ tb_thread_id); -+ -+ smem_iterator.store(frag); -+ -+ // Call dump_shmem() with different parameters. -+ if (threadIdx.x == 0 && blockIdx.x == 0) printf("\nDump all the elements:\n"); -+ cutlass::debug::dump_shmem(shared_storage, -+ EXAMPLE_MATRIX_ROW * EXAMPLE_MATRIX_COL); -+ -+ if (threadIdx.x == 0 && blockIdx.x == 0) -+ printf("\nDump all the elements with a stride of 8:\n"); -+ cutlass::debug::dump_shmem( -+ shared_storage, EXAMPLE_MATRIX_ROW * EXAMPLE_MATRIX_COL, /*S = */ 8); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Entry point for dump_reg_shmem example. -+// -+// usage: -+// -+// 02_dump_reg_shmem -+// -+int main() { -+ // Initialize a 64x32 column major matrix with sequential data (1,2,3...). 
-+ using Element = cutlass::half_t; -+ using Layout = cutlass::layout::ColumnMajor; -+ -+ cutlass::HostTensor matrix( -+ {EXAMPLE_MATRIX_ROW, EXAMPLE_MATRIX_COL}); -+ cutlass::reference::host::BlockFillSequential(matrix.host_data(), -+ matrix.capacity()); -+ -+ // Dump the matrix. -+ std::cout << "Matrix:\n" << matrix.host_view() << "\n"; -+ -+ // Copy the matrix to the device. -+ matrix.sync_device(); -+ -+ // Define a global iterator, a shared iterator and their thread map. -+ using ThreadMap = cutlass::transform::PitchLinearWarpRakedThreadMap< -+ cutlass::layout::PitchLinearShape, -+ 32, cutlass::layout::PitchLinearShape<8, 4>, 8>; -+ -+ using GmemIterator = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, Element, -+ Layout, 1, ThreadMap>; -+ -+ typename GmemIterator::Params params(matrix.layout()); -+ -+ using SmemIterator = cutlass::transform::threadblock::RegularTileIterator< -+ cutlass::MatrixShape, Element, -+ cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous<16, 64>, 1, -+ ThreadMap>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ int smem_size = -+ int(sizeof(Element) * EXAMPLE_MATRIX_ROW * EXAMPLE_MATRIX_COL); -+ -+ kernel_dump -+ <<>>(params, matrix.device_ref()); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cout << "Failed" << std::endl; -+ } -+ -+ return (result == cudaSuccess ? 0 : -1); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/03_visualize_layout/options.h b/3rdparty/cutlass/examples/03_visualize_layout/options.h -new file mode 100644 -index 0000000..fd99b1c ---- /dev/null -+++ b/3rdparty/cutlass/examples/03_visualize_layout/options.h -@@ -0,0 +1,121 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+#include -+ -+// Cutlass command line parser -+#include "cutlass/util/command_line.h" -+ -+class Options { -+public: -+ -+ bool help; -+ bool good; -+ std::vector extent; ///< extent of tile to fill -+ std::vector stride; ///< stride vector for layout function -+ std::vector output_shape; ///< output shape -+ int vectorize; ///< sequences of consecutive output elements are concatenated into a vector -+ /// if, and only if, they were consecutive in source memory -+ -+public: -+ -+ /// Options -+ Options(): -+ help(false), -+ good(true), -+ extent({32, 8}), -+ stride({32}), -+ output_shape({16, 8}), -+ vectorize(1) { -+ -+ } -+ -+ /// Constructs from command line parser -+ Options(cutlass::CommandLine const & cmd_line): help(false), good(true) { -+ -+ if (cmd_line.check_cmd_line_flag("help") || -+ cmd_line.check_cmd_line_flag("h")) { -+ -+ help = true; -+ } -+ -+ if (cmd_line.check_cmd_line_flag("extent")) { -+ cmd_line.get_cmd_line_arguments("extent", extent); -+ } -+ else { -+ extent = {32, 8}; -+ } -+ -+ if (cmd_line.check_cmd_line_flag("stride")) { -+ cmd_line.get_cmd_line_arguments("stride", stride); -+ } -+ -+ int default_output_shape[] = {16, 8}; -+ -+ if (cmd_line.check_cmd_line_flag("output-shape")) { -+ cmd_line.get_cmd_line_arguments("output-shape", output_shape); -+ } -+ -+ for (int i = int(output_shape.size()); i < 2; ++i) { -+ output_shape.push_back(default_output_shape[i]); -+ } -+ -+ if (cmd_line.check_cmd_line_flag("vectorize")) { -+ cmd_line.get_cmd_line_argument("vectorize", vectorize); -+ } -+ else { -+ vectorize = 1; -+ } -+ -+ if (output_shape.front() % vectorize) { -+ -+ std::cerr << "Error: --vectorize=" << vectorize -+ << " must divide contiguous elements in --output-shape=" -+ << output_shape.at(0) << "," << output_shape.at(1) << std::endl; -+ -+ good = false; -+ } -+ } -+ -+ /// Prints usage statement -+ static void print_usage(std::ostream &out) { -+ out -+ << " Options:\n" -+ << " --help Displays this help message.\n" -+ << " --extent= Specifies the layout-specific extent (as comma-delimited array).\n" -+ << " --stride= Specifies the layout-specific stride vector (comma-delimited array)\n" -+ << " --output-shape= Specifies the dimensions of a row-major output matrix. \n" -+ << " --vectorize= If possible, vectorizes the output into vectors of consecutive elements\n"; -+ } -+}; -diff --git a/3rdparty/cutlass/examples/03_visualize_layout/register_layout.cu b/3rdparty/cutlass/examples/03_visualize_layout/register_layout.cu -new file mode 100644 -index 0000000..423bfcc ---- /dev/null -+++ b/3rdparty/cutlass/examples/03_visualize_layout/register_layout.cu -@@ -0,0 +1,145 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief CUTLASS layout visualization example -+*/ -+ -+#include -+#include -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm70.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+ -+#include "visualize_layout.h" -+#include "register_layout.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+void RegisterLayouts(std::map > &layouts) { -+ -+ struct { -+ char const *name; -+ VisualizeLayoutBase *ptr; -+ } layout_pairs[] = { -+ -+ {"PitchLinear", new VisualizeLayout}, -+ {"ColumnMajor", new VisualizeLayout}, -+ {"RowMajor", new VisualizeLayout}, -+ {"ColumnMajorInterleaved<4>", -+ new VisualizeLayout>}, -+ {"RowMajorInterleaved<4>", -+ new VisualizeLayout>}, -+ // All Ampere/Turing H/Integer matrix multiply tensor core kernels uses the same swizzling -+ // layout implementation with different templates. 
-+ // -+ // mma.sync.aligned.m8n8k128.s32.b1.b1.s32 Interleaved-256 -+ // mma.sync.aligned.m16n8k256.s32.b1.b1.s32 Interleaved-256 -+ {"TensorOpMultiplicand<1,256>", -+ new VisualizeLayout>}, -+ // mma.sync.aligned.m8n8k128.s32.b1.b1.s32 TN kblock512 -+ // mma.sync.aligned.m16n8k256.s32.b1.b1.s32 TN kblock512 -+ {"TensorOpMultiplicand<1,512>", -+ new VisualizeLayout>}, -+ // mma.sync.aligned.m16n8k256.s32.b1.b1.s32 TN kblock1024 -+ {"TensorOpMultiplicand<1,1024>", -+ new VisualizeLayout>}, -+ // Integer matrix multiply.int4 8832 Interleaved-64 -+ // Integer matrix multiply.int4 16864 Interleaved-64 -+ {"TensorOpMultiplicand<4,64>", -+ new VisualizeLayout>}, -+ // Integer matrix multiply.int4 8832 TN kblock128 -+ // Integer matrix multiply.int4 16864 TN kblock128 -+ {"TensorOpMultiplicand<4,128>", -+ new VisualizeLayout>}, -+ // Integer matrix multiply.int4 16864 TN kblock256 -+ {"TensorOpMultiplicand<4,256>", -+ new VisualizeLayout>}, -+ // Integer matrix multiply 8816 Interleaved-32 -+ // Integer matrix multiply 16832 Interleaved-32 -+ {"TensorOpMultiplicand<8,32>", -+ new VisualizeLayout>}, -+ // Integer matrix multiply 8816 TN kblock64 -+ // Integer matrix multiply 16832 TN kblock64 -+ {"TensorOpMultiplicand<8,64>", -+ new VisualizeLayout>}, -+ // Integer matrix multiply 16832 TN kblock128 -+ {"TensorOpMultiplicand<8,128>", -+ new VisualizeLayout>}, -+ // Matrix Multiply 1688 TN kblock32 -+ // Matrix multiply 16816 TN kblock32 -+ {"TensorOpMultiplicand<16,32>", -+ new VisualizeLayout>}, -+ // Matrix multiply 1688 NT -+ // Matrix multiply 16816 NT -+ // Matrix multiply 16816 TN kblock64 -+ {"TensorOpMultiplicand<16,64>", -+ new VisualizeLayout>}, -+ // Matrix multiply 1688.TF32 TN kblock16 -+ {"TensorOpMultiplicand<32,16>", -+ new VisualizeLayout>}, -+ // Matrix multiply 1688.TF32 TN kblock32 -+ {"TensorOpMultiplicand<32,32>", -+ new VisualizeLayout>}, -+ // Matrix multiply 1688 NT -+ {"TensorOpMultiplicandCongruous<32,32>", -+ new VisualizeLayout< -+ cutlass::layout::TensorOpMultiplicandCongruous<32, 32>>}, -+ // Matrix multiply 884 NT -+ {"TensorOpMultiplicandCongruous<64,16>", -+ new VisualizeLayout< -+ cutlass::layout::TensorOpMultiplicandCongruous<64, 16>>}, -+ // Matrix multiply 884 TN -+ {"TensorOpMultiplicand64bCrosswise", -+ new VisualizeLayout}, -+ {"TensorOpMultiplicandCongruous<128,4>", -+ new VisualizeLayout< -+ cutlass::layout::TensorOpMultiplicandCongruous<128, 4>>}, -+ {"TensorOpMultiplicandCrosswise<128,4>", -+ new VisualizeLayout< -+ cutlass::layout::TensorOpMultiplicandCrosswise<128, 4>>}, -+ {"VoltaTensorOpMultiplicandCongruous<16>", -+ new VisualizeLayout< -+ cutlass::layout::VoltaTensorOpMultiplicandCongruous<16>>}, -+ {"VoltaTensorOpMultiplicandCrosswise<16,32>", -+ new VisualizeLayout< -+ cutlass::layout::VoltaTensorOpMultiplicandCrosswise<16, 32>>} -+ }; -+ -+ for (auto layout : layout_pairs) { -+ layouts.emplace(std::string(layout.name), std::unique_ptr(layout.ptr)); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/03_visualize_layout/register_layout.h b/3rdparty/cutlass/examples/03_visualize_layout/register_layout.h -new file mode 100644 -index 0000000..bb5f893 ---- /dev/null -+++ b/3rdparty/cutlass/examples/03_visualize_layout/register_layout.h -@@ -0,0 +1,59 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief CUTLASS layout visualization example -+*/ -+ -+#pragma once -+ -+#include -+#include -+ -+#include "options.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct VisualizeLayoutBase { -+ virtual bool visualize(Options const &) = 0; -+ virtual bool verify(bool verbose, std::ostream &out) = 0; -+ virtual void print_csv(std::ostream &out, char delim = '|', char new_line = '\n') = 0; -+ virtual std::ostream &print_help(std::ostream &out) { -+ return out; -+ } -+ virtual ~VisualizeLayoutBase() { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+void RegisterLayouts(std::map > &layouts); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/03_visualize_layout/visualize_layout.h b/3rdparty/cutlass/examples/03_visualize_layout/visualize_layout.h -new file mode 100644 -index 0000000..cef8579 ---- /dev/null -+++ b/3rdparty/cutlass/examples/03_visualize_layout/visualize_layout.h -@@ -0,0 +1,383 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief CUTLASS layout visualization example -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "cutlass/coord.h" -+#include "cutlass/util/reference/host/tensor_foreach.h" -+ -+#include "register_layout.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Permits copying dynamic vectors into static-length vectors -+template -+struct vector_to_coord { -+ -+ vector_to_coord(TensorCoord &coord, std::vector const &vec) { -+ -+ coord[Rank - 1] = vec.at(Rank - 1); -+ -+ if (Rank > 1) { -+ vector_to_coord(coord, vec); -+ } -+ } -+}; -+ -+/// Permits copying dynamic vectors into static-length vectors -+template -+struct vector_to_coord { -+ -+ vector_to_coord(TensorCoord &coord, std::vector const &vec) { -+ -+ coord[0] = vec.at(0); -+ } -+}; -+ -+/// Permits copying dynamic vectors into static-length vectors -+template -+struct vector_to_coord { -+ -+ vector_to_coord(TensorCoord &coord, std::vector const &vec) { -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+std::ostream &operator<<(std::ostream &out, std::vector const &vec) { -+ auto it = vec.begin(); -+ if (it != vec.end()) { -+ out << *it; -+ for (++it; it != vec.end(); ++it) { -+ out << ", " << *it; -+ } -+ } -+ return out; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Permits copying static-length vectors into dynamic vectors -+template -+struct coord_to_vector { -+ -+ coord_to_vector(std::vector &vec, TensorCoord const &coord) { -+ -+ vec.at(Rank - 1) = coord[Rank - 1]; -+ coord_to_vector(vec, coord); -+ } -+}; -+ -+/// Permits copying static-length vectors into dynamic vectors -+template -+struct coord_to_vector { -+ -+ coord_to_vector(std::vector &vec, TensorCoord const &coord) { -+ -+ vec.at(0) = coord[0]; -+ } -+}; -+ -+/// Permits copying static-length vectors into dynamic vectors -+template -+struct coord_to_vector { -+ -+ coord_to_vector(std::vector &vec, TensorCoord const &coord) { -+ } -+}; -+ 
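The vector_to_coord / coord_to_vector helpers above use a small rank-recursion idiom: the primary template copies element Rank - 1 and then delegates to the Rank - 1 case until a terminating rank is reached. Below is a minimal standalone sketch of the same idiom, with std::array standing in for cutlass::Coord; the type and helper names here are illustrative and not part of CUTLASS.

#include <array>
#include <cstdio>
#include <vector>

// Recursive case: copy element Rank - 1, then recurse on the remaining ranks.
template <typename Coord, int Rank>
struct vector_to_static {
  vector_to_static(Coord &coord, std::vector<int> const &vec) {
    coord[Rank - 1] = vec.at(Rank - 1);
    vector_to_static<Coord, Rank - 1>(coord, vec);
  }
};

// Base case: rank 0 copies nothing and terminates the recursion.
template <typename Coord>
struct vector_to_static<Coord, 0> {
  vector_to_static(Coord &, std::vector<int> const &) {}
};

int main() {
  std::vector<int> extent = {32, 8};   // dynamic length, e.g. parsed from --extent
  std::array<int, 2> coord{};          // static rank, analogous to a rank-2 coordinate
  vector_to_static<std::array<int, 2>, 2>(coord, extent);
  std::printf("coord = (%d, %d)\n", coord[0], coord[1]);
  return 0;
}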
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure representing an element in source memory -+struct Element { -+ -+ std::vector coord; ///< logical coordinate of element (as vector) -+ int offset; ///< linear offset from source memory -+ int color; ///< enables coloring each element to indicate -+ -+ /// Default ctor -+ inline Element(): offset(-1), color(0) { } -+ -+ /// Construct from logical coordinate and initial offset -+ inline Element( -+ std::vector const &coord_, -+ int offset_, -+ int color_ = 0 -+ ): -+ coord(coord_), offset(offset_), color(color_) { } -+ -+ /// Returns true if element is in a defined state -+ inline bool valid() const { -+ return offset >= 0; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Visualizes memory layouts by constructing a 'shape' -+template -+class VisualizeLayout : public VisualizeLayoutBase { -+public: -+ -+ using Layout = Layout_; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Stride = typename Layout::Stride; -+ -+public: -+ -+ Options options; -+ Layout layout; -+ TensorCoord extent; -+ std::vector elements; -+ -+public: -+ -+ /// Initializes the problem space -+ VisualizeLayout() { -+ -+ } -+ -+ /// visualization method -+ bool visualize(Options const &options_) { -+ -+ options = options_; -+ -+ if (options.extent.size() != TensorCoord::kRank) { -+ -+ std::cerr -+ << "--extent must have rank " << TensorCoord::kRank -+ << " (given: " << options.extent.size() << ")" << std::endl; -+ -+ return false; -+ } -+ -+ vector_to_coord(extent, options.extent); -+ -+ // Construct the layout for a packed tensor -+ if (options.stride.empty()) { -+ -+ layout = Layout::packed(extent); -+ } -+ else if (options.stride.size() != Stride::kRank) { -+ -+ std::cerr -+ << "--stride must have rank " << Stride::kRank -+ << " (given: " << options.stride.size() << ")" << std::endl; -+ -+ return false; -+ } -+ else { -+ // Stride from -+ Stride stride; -+ vector_to_coord(stride, options.stride); -+ -+ layout = Layout(stride); -+ } -+ -+ // Resize elements, setting elements to 'undefined' state -+ elements.resize(layout.capacity(extent)); -+ -+ // enumerate points in tensor space and assign -+ cutlass::reference::host::TensorForEachLambda( -+ extent, -+ [&](TensorCoord coord) { -+ -+ std::vector coord_vec(TensorCoord::kRank, 0); -+ coord_to_vector(coord_vec, coord); -+ -+ int offset = int(layout(coord)); -+ -+ if (offset >= int(elements.size())) { -+ std::cerr -+ << "Layout error - " << coord_vec -+ << " is out of range (computed offset: " << offset -+ << ", capacity: " << elements.size() << std::endl; -+ -+ throw std::out_of_range("(TensorForEach) layout error - coordinate out of range"); -+ } -+ -+ elements.at(offset) = Element(coord_vec, offset); -+ }); -+ -+ return true; -+ } -+ -+ /// Verifies the layout satisfies vectorization requirements -+ bool verify(bool verbose, std::ostream &out) { -+ return true; -+ } -+ -+private: -+ -+ /// returns a pair (is_vectorizable, one_changing_rank) to determine if a -+ /// vector exists (consecutive logical coordinates or uniformly invalid) -+ /// at the given location. -+ std::pair< bool, int > _is_vectorizable(int i) const { -+ // (all elements are invalid) or -+ // (all elements are valid AND -+ // exactly one rank is changing AND -+ // elements are consecutive) -+ -+ // Don't need vectorization. 
-+ if (options.vectorize <= 2) return std::make_pair(false, -1); -+ -+ // Boundary check. -+ if (i > elements.size() || (i + options.vectorize - 1) > elements.size()) -+ return std::make_pair(false, -1); -+ -+ // Check if either all elements are valid or invalid. -+ bool all_elements_invalid = std::all_of( -+ elements.begin() + i, elements.begin() + i + options.vectorize, -+ [](Element const &e) { return !e.valid(); }); -+ -+ bool all_elements_valid = std::all_of( -+ elements.begin() + i, elements.begin() + i + options.vectorize, -+ [](Element const &e) { return e.valid(); }); -+ -+ if (!all_elements_invalid && !all_elements_valid) -+ return std::make_pair(false, -1); -+ -+ // From here, it is vectorizable. -+ if (all_elements_invalid) return std::make_pair(true, -1); -+ -+ // Check if only exactly one rank is changing. -+ int one_changing_rank = -1; -+ for (int j = 0; j < options.vectorize; ++j) { -+ for (int r = 0; r < TensorCoord::kRank; ++r) { -+ if (elements.at(i + j).coord.at(r) != elements.at(i).coord.at(r)) { -+ if (one_changing_rank == -1) { -+ one_changing_rank = r; -+ } else if (one_changing_rank != r) { -+ return std::make_pair(false, -1); -+ } -+ } -+ } -+ } -+ -+ return std::make_pair(true, one_changing_rank); -+ } -+ -+ /// Prints a vector of elements -+ void _print_vector(std::ostream &out, int i, int one_changing_rank) { -+ Element const &base_element = elements.at(i); -+ if (base_element.valid()) { -+ out << "("; -+ for (int r = 0; r < TensorCoord::kRank; ++r) { -+ if (r) { -+ out << ", "; -+ } -+ -+ if (r == one_changing_rank) { -+ out -+ << base_element.coord.at(r) -+ << ".." -+ << (base_element.coord.at(r) + options.vectorize - 1); -+ } -+ else { -+ out << base_element.coord.at(r); -+ } -+ } -+ out << ")"; -+ } -+ else { -+ out << " "; -+ } -+ } -+ -+ /// Prints a single element -+ void _print_element(std::ostream &out, int k) { -+ Element const &element = elements.at(k); -+ if (element.valid()) { -+ out << "("; -+ for (int v = 0; v < TensorCoord::kRank; ++v) { -+ out << (v ? 
", " : "") << element.coord.at(v); -+ } -+ out << ")"; -+ } -+ else { -+ out << " "; -+ } -+ } -+ -+public: -+ -+ /// Pretty-prints the layout to the console -+ void print_csv(std::ostream &out, char delim = '|', char new_line = '\n') { -+ int row = -1; -+ -+ for (int i = 0; i < int(elements.size()); i += options.vectorize) { -+ if (i % options.output_shape.at(0)) { -+ out << delim; -+ } -+ else { -+ if (row >= 0) { -+ out << new_line; -+ } -+ ++row; -+ if (row == options.output_shape.at(1)) { -+ out << new_line; -+ row = 0; -+ } -+ } -+ -+ auto is_vector = _is_vectorizable(i); -+ -+ if (is_vector.first) { -+ _print_vector(out, i, is_vector.second); // print a vector starting at element i -+ } -+ else { -+ for (int j = 0; j < options.vectorize; ++j) { // print individual elements [i..i+j) -+ _print_element(out, i + j); -+ } -+ } -+ } -+ -+ out << new_line << std::flush; -+ } -+ -+ /// Help message -+ virtual std::ostream &print_help(std::ostream &out) { -+ out << "TensorCoord rank " << TensorCoord::kRank << ", Stride rank: " << Stride::kRank; -+ return out; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/04_tile_iterator/tile_iterator.cu b/3rdparty/cutlass/examples/04_tile_iterator/tile_iterator.cu -new file mode 100644 -index 0000000..8146a09 ---- /dev/null -+++ b/3rdparty/cutlass/examples/04_tile_iterator/tile_iterator.cu -@@ -0,0 +1,221 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/* -+ This example demonstrates how to use the PredicatedTileIterator in CUTLASS to load data from -+ addressable memory, and then store it back into addressable memory. 
-+ -+ TileIterator is a core concept in CUTLASS that enables efficient loading and storing of data to -+ and from addressable memory. The PredicateTileIterator accepts a ThreadMap type, which defines -+ the mapping of threads to a "tile" in memory. This separation of concerns enables user-defined -+ thread mappings to be specified. -+ -+ In this example, a PredicatedTileIterator is used to load elements from a tile in global memory, -+ stored in column-major layout, into a fragment and then back into global memory in the same -+ layout. -+ -+ This example uses CUTLASS utilities to ease the matrix operations. -+ -+*/ -+ -+// Standard Library includes -+#include -+#include -+#include -+ -+// CUTLASS includes -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+ -+// -+// CUTLASS utility includes -+// -+ -+// Defines operator<<() to write TensorView objects to std::ostream -+#include "cutlass/util/tensor_view_io.h" -+ -+// Defines cutlass::HostTensor<> -+#include "cutlass/util/host_tensor.h" -+ -+// Defines cutlass::reference::host::TensorFill() and -+// cutlass::reference::host::TensorFillBlockSequential() -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#pragma warning( disable : 4503) -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Define PredicatedTileIterators to load and store a M-by-K tile, in column major layout. -+ -+template -+__global__ void copy( -+ typename Iterator::Params dst_params, -+ typename Iterator::Element *dst_pointer, -+ typename Iterator::Params src_params, -+ typename Iterator::Element *src_pointer, -+ cutlass::Coord<2> extent) { -+ -+ -+ Iterator dst_iterator(dst_params, dst_pointer, extent, threadIdx.x); -+ Iterator src_iterator(src_params, src_pointer, extent, threadIdx.x); -+ -+ // PredicatedTileIterator uses PitchLinear layout and therefore takes in a PitchLinearShape. -+ // The contiguous dimension can be accessed via Iterator::Shape::kContiguous and the strided -+ // dimension can be accessed via Iterator::Shape::kStrided -+ int iterations = (extent[1] + Iterator::Shape::kStrided - 1) / Iterator::Shape::kStrided; -+ -+ typename Iterator::Fragment fragment; -+ -+ for(int i = 0; i < fragment.size(); ++i) { -+ fragment[i] = 0; -+ } -+ -+ src_iterator.load(fragment); -+ dst_iterator.store(fragment); -+ -+ -+ ++src_iterator; -+ ++dst_iterator; -+ -+ for(; iterations > 1; --iterations) { -+ -+ src_iterator.load(fragment); -+ dst_iterator.store(fragment); -+ -+ ++src_iterator; -+ ++dst_iterator; -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Initializes the source tile with sequentially increasing values and performs the copy into -+// the destination tile using two PredicatedTileIterators, one to load the data from addressable -+// memory into a fragment (regiser-backed array of elements owned by each thread) and another to -+// store the data from the fragment back into the addressable memory of the destination tile. -+ -+cudaError_t TestTileIterator(int M, int K) { -+ -+ // For this example, we chose a <64, 4> tile shape. The PredicateTileIterator expects -+ // PitchLinearShape and PitchLinear layout. 
-+ using Shape = cutlass::layout::PitchLinearShape<64, 4>; -+ using Layout = cutlass::layout::PitchLinear; -+ using Element = int; -+ int const kThreads = 32; -+ -+ // ThreadMaps define how threads are mapped to a given tile. The PitchLinearStripminedThreadMap -+ // stripmines a pitch-linear tile among a given number of threads, first along the contiguous -+ // dimension then along the strided dimension. -+ using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap; -+ -+ // Define the PredicateTileIterator, using TileShape, Element, Layout, and ThreadMap types -+ using Iterator = cutlass::transform::threadblock::PredicatedTileIterator< -+ Shape, Element, Layout, 1, ThreadMap>; -+ -+ -+ cutlass::Coord<2> copy_extent = cutlass::make_Coord(M, K); -+ cutlass::Coord<2> alloc_extent = cutlass::make_Coord(M, K); -+ -+ // Allocate source and destination tensors -+ cutlass::HostTensor src_tensor(alloc_extent); -+ cutlass::HostTensor dst_tensor(alloc_extent); -+ -+ Element oob_value = Element(-1); -+ -+ // Initialize destination tensor with all -1s -+ cutlass::reference::host::TensorFill(dst_tensor.host_view(), oob_value); -+ // Initialize source tensor with sequentially increasing values -+ cutlass::reference::host::BlockFillSequential(src_tensor.host_data(), src_tensor.capacity()); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ typename Iterator::Params dst_params(dst_tensor.layout()); -+ typename Iterator::Params src_params(src_tensor.layout()); -+ -+ dim3 block(kThreads, 1); -+ dim3 grid(1, 1); -+ -+ // Launch copy kernel to perform the copy -+ copy<<< grid, block >>>( -+ dst_params, -+ dst_tensor.device_data(), -+ src_params, -+ src_tensor.device_data(), -+ copy_extent -+ ); -+ -+ cudaError_t result = cudaGetLastError(); -+ if(result != cudaSuccess) { -+ std::cerr << "Error - kernel failed." << std::endl; -+ return result; -+ } -+ -+ dst_tensor.sync_host(); -+ -+ // Verify results -+ for(int s = 0; s < alloc_extent[1]; ++s) { -+ for(int c = 0; c < alloc_extent[0]; ++c) { -+ -+ Element expected = Element(0); -+ -+ if(c < copy_extent[0] && s < copy_extent[1]) { -+ expected = src_tensor.at({c, s}); -+ } -+ else { -+ expected = oob_value; -+ } -+ -+ Element got = dst_tensor.at({c, s}); -+ bool equal = (expected == got); -+ -+ if(!equal) { -+ std::cerr << "Error - source tile differs from destination tile." << std::endl; -+ return cudaErrorUnknown; -+ } -+ } -+ } -+ -+ return cudaSuccess; -+} -+ -+int main(int argc, const char *arg[]) { -+ -+ cudaError_t result = TestTileIterator(57, 35); -+ -+ if(result == cudaSuccess) { -+ std::cout << "Passed." << std::endl; -+ } -+ -+ // Exit -+ return result == cudaSuccess ? 0 : -1; -+} -+ -diff --git a/3rdparty/cutlass/examples/05_batched_gemm/batched_gemm.cu b/3rdparty/cutlass/examples/05_batched_gemm/batched_gemm.cu -new file mode 100644 -index 0000000..ab85361 ---- /dev/null -+++ b/3rdparty/cutlass/examples/05_batched_gemm/batched_gemm.cu -@@ -0,0 +1,466 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/device/gemm_array.h" -+#include "cutlass/gemm/device/gemm_batched.h" -+ -+#pragma warning( disable : 4503) -+ -+/* -+This example demonstrates how to use cutlass to compute a batched strided gemm in two different ways: -+ 1. By specifying pointers to the first matrices of the batch and the stride between the consecutive -+ matrices of the batch (this is called a strided batched gemm). -+ 2. By copying pointers to all matrices of the batch to the device memory (this is called an array gemm). 
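The practical difference between the two variants listed above is how each batch's base pointer is obtained: the strided batched GEMM derives it from one base pointer plus a fixed stride, while the array GEMM reads it from a device-resident table of pointers and therefore places no constraint on where the individual matrices live. A rough host-side sketch of the two addressing schemes follows; the function names are illustrative, not CUTLASS API.

#include <vector>

// Strided batched addressing: batch b starts at a fixed offset from a single base pointer.
float const *strided_batch_ptr(float const *A, long long batch_stride_A, int b) {
  return A + b * batch_stride_A;
}

// Array addressing: an explicit per-batch pointer table; entries may point anywhere.
std::vector<float const *> build_pointer_array(float const *A, long long batch_stride_A,
                                               int batch_count) {
  std::vector<float const *> ptrs(static_cast<size_t>(batch_count));
  for (int b = 0; b < batch_count; ++b) {
    ptrs[static_cast<size_t>(b)] = A + b * batch_stride_A;  // any address per batch would do
  }
  return ptrs;
}

The example code further down builds exactly such a table (host_ptr_A/B/C), permutes it, and copies it to device memory before invoking the array GEMM.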
-+In this example, both A and B matrix are non-transpose and column major matrix -+batched_C = batched_A x batched_B -+As an example, matrix C can be seen as -+----------------------------------------------------------- -+(0,0,0) | (0,0,1) | (0,0,2) | (1,0,0) | (1,0,1) | (1,0,2) | -+----------------------------------------------------------- -+(0,1,0) | (0,1,1) | (0,1,2) | (1,1,0) | (1,1,1) | (1,1,2) | -+----------------------------------------------------------- -+(0,2,0) | (0,2,1) | (0,2,2) | (1,2,0) | (1,2,1) | (1,2,2) | -+----------------------------------------------------------- -+(0,3,0) | (0,3,1) | (0,3,2) | (1,3,0) | (1,3,1) | (1,3,2) | -+----------------------------------------------------------- -+(0,4,0) | (0,4,1) | (0,4,2) | (1,4,0) | (1,4,1) | (1,4,2) | -+----------------------------------------------------------- -+(0,5,0) | (0,5,1) | (0,5,2) | (1,5,0) | (1,5,1) | (1,5,2) | -+----------------------------------------------------------- -+ batch 0 | batch 1 -+where we denote each element with (batch_idx, row_idx, column_idx) -+In this example, batch size is 2, M is 6 and N is 3 -+The stride (batch_stride_C) between the first element of two batches is ldc * n -+ -+matrix A can be seen as -+--------------------------------------- -+(0,0,0) | (0,0,1) | (1,0,0) | (1,0,1) | -+--------------------------------------- -+(0,1,0) | (0,1,1) | (1,1,0) | (1,1,1) | -+--------------------------------------- -+(0,2,0) | (0,2,1) | (1,2,0) | (1,2,1) | -+--------------------------------------- -+(0,3,0) | (0,3,1) | (1,3,0) | (1,3,1) | -+--------------------------------------- -+(0,4,0) | (0,4,1) | (1,4,0) | (1,4,1) | -+--------------------------------------- -+(0,5,0) | (0,5,1) | (1,5,0) | (1,5,1) | -+--------------------------------------- -+ batch 0 | batch 1 -+, where batch size is 2, M is 6 and K is 2 -+The stride (batch_stride_A) between the first element of two batches is lda * k -+ -+matrix B can be seen as -+----------------------------- -+(0,0,0) | (0,0,1) | (0,0,2) | -+----------------------------- batch 0 -+(0,1,0) | (0,1,1) | (0,1,2) | -+------------------------------------- -+(1,0,0) | (1,0,1) | (1,0,2) | -+----------------------------- batch 1 -+(1,1,0) | (1,1,1) | (1,1,2) | -+----------------------------- -+, where the batch size is 2, N is 3 and K is 2 -+The stride (batch_stride_B) between the first element of two batches is k -+ -+ -+*/ -+ -+cudaError_t cutlass_array_sgemm( -+ int m, -+ int n, -+ int k, -+ float alpha, -+ float const * const *A, -+ int lda, -+ float const * const *B, -+ int ldb, -+ float * const *C, -+ int ldc, -+ float beta, -+ int batch_count) { -+ -+ using Gemm = cutlass::gemm::device::GemmArray< -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::ColumnMajor -+ >; -+ -+ Gemm gemm_op; -+ -+ cutlass::Status status = gemm_op({ -+ {m, n, k}, -+ A, lda, -+ B, ldb, -+ C, ldc, -+ C, ldc, -+ {alpha, beta}, -+ batch_count -+ }); -+ -+ if (status != cutlass::Status::kSuccess) { -+ return cudaErrorUnknown; -+ } -+ -+ return cudaSuccess; -+} -+ -+cudaError_t cutlass_strided_batched_sgemm( -+ int m, -+ int n, -+ int k, -+ float alpha, -+ float const *A, -+ int lda, -+ long long int batch_stride_A, -+ float const *B, -+ int ldb, -+ long long int batch_stride_B, -+ float *C, -+ int ldc, -+ long long int batch_stride_C, -+ float beta, -+ int batch_count) { -+ -+ using Gemm = cutlass::gemm::device::GemmBatched< -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::ColumnMajor, -+ float, 
cutlass::layout::ColumnMajor -+ >; -+ -+ Gemm gemm_op; -+ -+ cutlass::Status status = gemm_op({ -+ {m, n, k}, -+ {A, lda}, -+ batch_stride_A, -+ {B, ldb}, -+ batch_stride_B, -+ {C, ldc}, -+ batch_stride_C, -+ {C, ldc}, -+ batch_stride_C, -+ {alpha, beta}, -+ batch_count -+ }); -+ -+ if (status != cutlass::Status::kSuccess) { -+ return cudaErrorUnknown; -+ } -+ -+ return cudaSuccess; -+} -+ -+template -+cudaError_t strided_batched_gemm_nn_reference( -+ int m, -+ int n, -+ int k, -+ T alpha, -+ std::vector const &A, -+ int lda, -+ long long int batch_stride_A, -+ std::vector const &B, -+ int ldb, -+ long long int batch_stride_B, -+ std::vector &C, -+ int ldc, -+ long long int batch_stride_C, -+ T beta, -+ int batch_count) { -+ /* -+ strided batched gemm NN -+ */ -+ -+ cudaError_t result = cudaSuccess; -+ -+ if (A.size() < lda * k * batch_count) { -+ std::cout << "the size of A is too small" << std::endl; -+ return cudaErrorInvalidValue; -+ } -+ if (B.size() < ldb * n) { -+ std::cout << "the size of B is too small" << std::endl; -+ return cudaErrorInvalidValue; -+ } -+ if (C.size() < ldc * n * batch_count) { -+ std::cout << "the size of C is too small" << std::endl; -+ return cudaErrorInvalidValue; -+ } -+ -+ for (int batch_idx = 0; batch_idx < batch_count; batch_idx++) { -+ for (int n_idx = 0; n_idx < n; n_idx++) { -+ for (int m_idx = 0; m_idx < m; m_idx++) { -+ T accum = beta * C[batch_idx * batch_stride_C + n_idx * ldc + m_idx]; -+ for (int k_idx = 0; k_idx < k; k_idx++) { -+ accum += alpha -+ * A[batch_idx * batch_stride_A + k_idx * lda + m_idx] -+ * B[batch_idx * batch_stride_B + n_idx * ldb + k_idx]; -+ } -+ C[batch_idx * batch_stride_C + n_idx * ldc + m_idx] = accum; -+ } -+ } -+ } -+ -+ return result; -+} -+ -+ -+cudaError_t run_batched_gemm(bool use_array) { -+ -+ const char* gemm_desc = use_array ? 
"array" : "strided batched"; -+ std::cout << "Running " << gemm_desc << " gemm" << std::endl; -+ -+ // Arbitrary problem size -+ int const m = 520; -+ int const n = 219; -+ int const k = 129; -+ int const batch_count = 17; -+ -+ // A, B are non-transpose, column major -+ int const lda = m; -+ int const ldb = k * batch_count; -+ int const ldc = m; -+ -+ int const count_A = batch_count * lda * k; -+ int const count_B = ldb * n; -+ int const count_C = batch_count * ldc * n; -+ -+ // the memory is batched along K dimension -+ long long int batch_stride_A = static_cast(lda) * static_cast(k); -+ long long int batch_stride_B = static_cast(k); -+ long long int batch_stride_C = static_cast(ldc) * static_cast(n); -+ -+ // alpha and beta -+ float alpha = 1.0f; -+ float beta = 2.0f; -+ -+ cudaError_t result = cudaSuccess; -+ -+ // allocate the host memory -+ std::vector host_A(count_A); -+ std::vector host_B(count_B); -+ std::vector host_C(count_C); -+ std::vector result_C(count_C); -+ -+ // allocate the device memory -+ float *A; -+ float *B; -+ float *C; -+ -+ result = cudaMalloc(&A, count_A * sizeof(float)); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMalloc result = " << result << std::endl; -+ return result; -+ } -+ result = cudaMalloc(&B, count_B * sizeof(float)); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMalloc result = " << result << std::endl; -+ return result; -+ } -+ result = cudaMalloc(&C, count_C * sizeof(float)); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMalloc result = " << result << std::endl; -+ return result; -+ } -+ -+ // Limit range to avoid floating-point errors -+ int const kRange = 8; -+ -+ // fill A -+ for (int b_idx = 0; b_idx < batch_count; b_idx++) { -+ for (int col_idx = 0; col_idx < k; col_idx++) { -+ for (int row_idx = 0; row_idx < m; row_idx++) { -+ host_A[row_idx + col_idx * lda + b_idx * lda * k] = static_cast((row_idx + col_idx * lda + b_idx * lda * k) % kRange); -+ } -+ } -+ } -+ // fill B -+ for (int b_idx = 0; b_idx < batch_count; b_idx++) { -+ for (int col_idx = 0; col_idx < n; col_idx++) { -+ for (int row_idx = 0; row_idx < k; row_idx++) { -+ host_B[row_idx + col_idx * ldb + b_idx * k] = static_cast(((n + k * ldb + batch_count * k) - (row_idx + col_idx * ldb + b_idx * k)) % kRange); -+ } -+ } -+ } -+ // fill C -+ for (int b_idx = 0; b_idx < batch_count; b_idx++) { -+ for (int col_idx = 0; col_idx < n; col_idx++) { -+ for (int row_idx = 0; row_idx < m; row_idx++) { -+ host_C[row_idx + col_idx * ldc + b_idx * ldc * n] = 1.f; -+ } -+ } -+ } -+ -+ // ref memory -+ std::vector ref_A(host_A); -+ std::vector ref_B(host_B); -+ std::vector ref_C(host_C); -+ // copy host memory to device -+ result = cudaMemcpy(A, host_A.data(), count_A * sizeof(float), cudaMemcpyHostToDevice); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMemcpy result = " << result << std::endl; -+ return result; -+ } -+ result = cudaMemcpy(B, host_B.data(), count_B * sizeof(float), cudaMemcpyHostToDevice); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMemcpy result = " << result << std::endl; -+ return result; -+ } -+ result = cudaMemcpy(C, host_C.data(), count_C * sizeof(float), cudaMemcpyHostToDevice); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMemcpy result = " << result << std::endl; -+ return result; -+ } -+ -+ // run cutlass -+ if (use_array) { -+ // allocate the host memory for the pointers to the matrices of the batch -+ std::vector host_ptr_A(batch_count); -+ std::vector host_ptr_B(batch_count); -+ std::vector host_ptr_C(batch_count); -+ -+ 
// permute the batch elements to emphasize that GemmArray does not depend on matrices being separated by a fixed stride -+ std::vector permutation = {14, 11, 3, 10, 1, 13, 9, 4, 6, 16, 8, 15, 7, 12, 0, 2, 5}; -+ for (size_t b_idx = 0; b_idx < batch_count; b_idx++) { -+ host_ptr_A[b_idx] = A + permutation[b_idx] * batch_stride_A; -+ host_ptr_B[b_idx] = B + permutation[b_idx] * batch_stride_B; -+ host_ptr_C[b_idx] = C + permutation[b_idx] * batch_stride_C; -+ } -+ -+ // allocate the corresponding device memory -+ float const **ptr_A; -+ float const **ptr_B; -+ float **ptr_C; -+ -+ result = cudaMalloc(&ptr_A, batch_count * sizeof(float*)); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMalloc result = " << result << std::endl; -+ return result; -+ } -+ result = cudaMalloc(&ptr_B, batch_count * sizeof(float*)); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMalloc result = " << result << std::endl; -+ return result; -+ } -+ result = cudaMalloc(&ptr_C, batch_count * sizeof(float*)); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMalloc result = " << result << std::endl; -+ return result; -+ } -+ -+ // copy the matrix pointers to the device -+ result = cudaMemcpy(ptr_A, host_ptr_A.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMemcpy result = " << result << std::endl; -+ return result; -+ } -+ result = cudaMemcpy(ptr_B, host_ptr_B.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMemcpy result = " << result << std::endl; -+ return result; -+ } -+ result = cudaMemcpy(ptr_C, host_ptr_C.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMemcpy result = " << result << std::endl; -+ return result; -+ } -+ -+ result = cutlass_array_sgemm(m, n, k, alpha, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, beta, batch_count); -+ -+ if (result != cudaSuccess) -+ return result; -+ } else { -+ result = cutlass_strided_batched_sgemm( -+ m, n, k, alpha, A, lda, batch_stride_A, B, ldb, batch_stride_B, C, ldc, batch_stride_C, -+ beta, batch_count); -+ if (result != cudaSuccess) -+ return result; -+ } -+ -+ // copy device memory to host -+ result = cudaMemcpy(result_C.data(), C, count_C * sizeof(float), cudaMemcpyDeviceToHost); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMemcpy result = " << result << std::endl; -+ return result; -+ } -+ -+ //compare with reference code -+ result = strided_batched_gemm_nn_reference(m, n, k, alpha, ref_A, lda, batch_stride_A, ref_B, ldb, batch_stride_B, ref_C, ldc, batch_stride_C, -+ beta, batch_count); -+ if (result != 0) -+ return result; -+ -+ // Expect bit-level accuracy for this simple example -+ if (ref_C != result_C) { -+ std::cout << "CUTLASS " << gemm_desc << " gemm does not run correctly" << std::endl; -+ return cudaErrorUnknown; -+ } -+ -+ // free memory -+ result = cudaFree(A); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaFree result = " << result << std::endl; -+ return result; -+ } -+ result = cudaFree(B); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaFree result = " << result << std::endl; -+ return result; -+ } -+ result = cudaFree(C); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaFree result = " << result << std::endl; -+ return result; -+ } -+ -+ return result; -+} -+ -+int main() { -+ -+ cudaError_t result = cudaSuccess; -+ for (bool use_array : {false, true}) { -+ result = run_batched_gemm(use_array); -+ if (result == cudaSuccess) 
{ -+ std::cout << "Passed." << std::endl; -+ } else { -+ break; -+ } -+ } -+ -+ // Exit. -+ return result == cudaSuccess ? 0 : -1; -+} -diff --git a/3rdparty/cutlass/examples/06_splitK_gemm/splitk_gemm.cu b/3rdparty/cutlass/examples/06_splitK_gemm/splitk_gemm.cu -new file mode 100644 -index 0000000..9c88851 ---- /dev/null -+++ b/3rdparty/cutlass/examples/06_splitK_gemm/splitk_gemm.cu -@@ -0,0 +1,340 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+This example shows how to use split-k version of matrix multiplication using functions and data -+structures provided by CUTLASS; which we run on a NVIDIA Volta GPU. -+ -+What is split-k? -+Consider a problem size of M = 128, N = 128, K = 4096. In this case, if my thread-block tile size (a -+tile can be viewed as a 2d matrix) is 128x128x4096, then we launch a singled a thread-block taking -+up a single SM of 84 SMs present on V100. Hence the efficiency of computation is really low. So, how -+to solve it? This is where split-k comes in. It is a way of partitioning K-dimension of matrix -+multiplication and distribute across multiple SMs and get better efficiency than single SM. In the -+above example, we can partition K-dimension with split-k factor of 16 i.e., thread-block tile size -+will be 128x128x256 and will be launching on 16 SMs. Once each thread-block computes their partial -+inner product (1/16th of output), they accumulate to single output matrix. -+ -+Writing a single high performance matrix multiplication kernel is hard but do-able. Whereas writing -+high performance kernels at scale which works for multiple problem sizes with good abstractions is -+really hard. 
CUTLASS solves this problem by providing simplified abstractions to compose -+multiple sections of gemm kernel. When used properly, the kernels can hit peak performance of GPU -+easily. -+ -+CUTLASS divides a kernel into hierarchical composable sections. Which means, at each thread, warp -+and thread-block level, they compute on their own tile-size with higher level of tile sizes being -+composed from lower level ones. Multiple thread-tiles (tile size each thread computes) can be used -+to form warp-tiles (tile size each warp computes) and multiple warp tiles can be used to compute -+threadblock-tile (tile size computed by a threadblock). -+ -+In this example, we split variable initialization into -+1. Setting up data properties : describes how matrices are laid out in the memory and how the kernel -+can view them (logical to physical mapping) -+2. Setting up computation properties : describes how the above set matrices will be used to compute -+output of matrix multiplication. -+ -+First, we setup the data types of matrices A, B, C and D along with alpha, beta as the equation for -+GEMM is D = alpha * A * B + beta * C. In CUTLASS, the kernels first compute A * B and leaves the -+rest of the computation to end of the kernel as alpha * X + beta * C is a simple element-wise -+operation on X (A * B) and C. We call this as epilogue of kernel. Hence, we setup data types for -+alpha and beta to be equal to ElementComputeEpilogue = float. As we want to MMA instructions on -+Volta and they support only half-precision floating point (fp16 or half), we use data type for -+elements in input matrix A and B as cutlass::half_t. Volta also supports accumulation of partial dot -+product to fp32, which can store wider range of numbers, we use it as data type of output matrix -+elements and accumulation. We convey this to CUTLASS kernel by initializing template variables -+ElementAccumulator (float), ElementComputeEpilogue (float), ElementInputA (cutlass::half_t), -+ElementInputB (cutlass::half_t), ElementOutput (float). Communicating just the data type is not -+enough. As the data is laid out linearly in memory, we have to convey the layout of matrices. We do -+that by initializing template variable LayoutInputA to column major cutlass variable, LayoutInputB -+to row major and LayoutOutput to row major. Next, we setup rules to compute alpha * X + beta * C -+which is called epilogue of the kernel. We initialize template variable EpilogueOp, which takes the -+data type of output ElementOutput (float), the number of elements per vector memory access (16), -+data type of accumulator (float) and data type of computation of linear combination (alpha * X + -+beta * C). -+ -+Now that we setup the properties of data, we have to setup properties of computation. -+ -+Second, we create template variables of tile sizes for thread-block, warp and mma-op to 128x128x32, -+64x64x4, 8x8x4 (MxNxK) respectively. When passed to instantiate CUTLASS GEMM kernel, it internally -+deduce the amount of threads needed per thread-block, amount of shared memory, storing data in -+bank-conflict free manner, and ton of other variables required to compose, initialize and launch a -+high performance GEMM kernel. This is the beauty of CUTLASS, it relieves developer from -+understanding and coding complicated hardware optimizations which can easily go wrong. 
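Conceptually, the split-K scheme described above computes one partial product per K-slice and then reduces the partials into the final output; with GemmSplitKParallel the reduction runs separately over workspace memory, which is why the example later queries the kernel for scratch space. Below is a plain CPU sketch of that flow, using row-major float matrices for simplicity and purely illustrative names; it is not the CUTLASS device implementation.

#include <algorithm>
#include <vector>

// CPU illustration of split-K: each slice accumulates a partial product over its
// own K-range; a final reduction combines the partials and applies the epilogue
// D = alpha * (A * B) + beta * C.
void splitk_gemm_reference(int M, int N, int K, int split_k, float alpha, float beta,
                           std::vector<float> const &A,   // M x K, row-major
                           std::vector<float> const &B,   // K x N, row-major
                           std::vector<float> const &C,   // M x N, row-major
                           std::vector<float> &D) {       // M x N, row-major
  int slice_k = (K + split_k - 1) / split_k;  // e.g. K = 4096, split_k = 16 -> 256 per slice
  std::vector<float> partial(static_cast<size_t>(split_k) * M * N, 0.0f);

  // On the GPU, each slice would be covered by its own set of threadblocks.
  for (int s = 0; s < split_k; ++s) {
    int k_begin = s * slice_k;
    int k_end = std::min(K, k_begin + slice_k);
    for (int m = 0; m < M; ++m) {
      for (int n = 0; n < N; ++n) {
        float acc = 0.0f;
        for (int k = k_begin; k < k_end; ++k) {
          acc += A[m * K + k] * B[k * N + n];
        }
        partial[(static_cast<size_t>(s) * M + m) * N + n] = acc;
      }
    }
  }

  // Reduction over the split_k partial outputs, followed by the linear-combination epilogue.
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float sum = 0.0f;
      for (int s = 0; s < split_k; ++s) {
        sum += partial[(static_cast<size_t>(s) * M + m) * N + n];
      }
      D[m * N + n] = alpha * sum + beta * C[m * N + n];
    }
  }
}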
-+ -+There are few more template variables initialized such as, which threadblock tile of output matrix -+is done which threadblock launched on an SM, CUDA SM architecture of GPU you want to run on. -+ -+These are all put together to create a template variable which describes CUTLASS GEMM kernel using -+cutlass::gemm::device::GemmSplitKParallel template. -+ -+The next step is to initialize physical data, instantiate and initialize CUTLASS kernel and run it. -+We use CUTLASS utilities to initialize, fill, compare matrices as they are simple and doesn't come -+in the way of learning CUTLASS. -+ -+Once all the matrices are initialized and filled with data, create arguments tuple to launch CUTLASS -+kernel which takes problem size (M = 5120, N = 4096 and K = 4096), matrices, alpha, beta and the -+important one, split k-dimension factor. Along with that, we query CUTLASS if any scratch-space -+memory required by the kernel we instantiated. If yes, we create it and pass it along with other -+arguments created to initialize CUTLASS kernel then, the kernel is launched. -+ -+In this example, we later on launch a reference gemm kernel (from CUTLASS utilities) to compare if -+the output from CUTLASS kernel is same as reference GEMM kernel. -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_splitk_parallel.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "helper.h" -+ -+// The code section below describes datatype for input, output matrices and computation between -+// elements in input matrices. -+using ElementAccumulator = float; // <- data type of accumulator -+using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations -+using ElementInputA = cutlass::half_t; // <- data type of elements in input matrix A -+using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B -+using ElementOutput = float; // <- data type of elements in output matrix D -+ -+// The code section below describes matrix layout of input and output matrices. Column Major for -+// Matrix A, Row Major for Matrix B and Row Major for Matrix C -+using LayoutInputA = cutlass::layout::ColumnMajor; -+using LayoutInputB = cutlass::layout::RowMajor; -+using LayoutOutput = cutlass::layout::RowMajor; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm70; -+ -+// This code section describes the tile size a thread block will compute -+using ShapeMMAThreadBlock = -+ cutlass::gemm::GemmShape<128, 128, 32>; // <- threadblock tile M = 128, N = 128, K = 32 -+// This code section describes tile size a warp will compute -+using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = 64, N = 64, K = 32 -+// This code section describes the size of MMA op -+using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 4>; // <- MMA Op tile M = 8, N = 8, K = 4 -+ -+// This code section describes ? 
-+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // <- data type of output matrix -+ 128 / cutlass::sizeof_bits::value, // <- This is the number of elements per -+ // vectorized memory access. For half -+ // precision, it's 8 elements. This becomes -+ // the vector width of math instructions in -+ // epilogue too -+ ElementAccumulator, // <- data type of accumulator -+ ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function -+ -+// Put all the created template variables to create GemmSplitKParallel template variable -+using Gemm = cutlass::gemm::device::GemmSplitKParallel; -+ -+int run() { -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (props.major != 7) { -+ std::cerr << "Volta Tensor Ops must be run on a machine with compute capability of 70, 72, or 75." -+ << std::endl; -+ -+ // Return 0 so tests pass if run on unsupported architectures or CUDA Toolkits. -+ return 0; -+ } -+ -+ // -+ // Define problem size -+ // -+ -+ const int length_m = 5120; -+ const int length_n = 4096; -+ const int length_k = 4096; -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k); -+ -+ // Initialize tensors using CUTLASS helper functions -+ cutlass::HostTensor tensor_a( -+ problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor tensor_b( -+ problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor tensor_c( -+ problem_size.mn()); // <- Create matrix C with dimensions M x N -+ cutlass::HostTensor tensor_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // CUTLASS kernel -+ cutlass::HostTensor tensor_ref_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // reference kernel -+ -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(4), -+ ElementInputA(-4), -+ 0); // <- Fill matrix A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(4), -+ ElementInputB(-4), -+ 0); // <- Fill matrix B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(4), -+ ElementOutput(-4), -+ 0); // <- Fill matrix C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); // <- fill matrix D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // Initialize alpha and beta for dot product computation -+ ElementComputeEpilogue alpha = ElementComputeEpilogue(1); -+ ElementComputeEpilogue beta = ElementComputeEpilogue(0); -+ -+ // Split K dimension into 16 partitions -+ int split_k_slices = 16; -+ -+ // Create a tuple of gemm kernel arguments. 
This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication -+ tensor_a.device_ref(), // <- reference to matrix A on device -+ tensor_b.device_ref(), // <- reference to matrix B on device -+ tensor_c.device_ref(), // <- reference to matrix C on device -+ tensor_d.device_ref(), // <- reference to matrix D on device -+ {alpha, beta}, // <- tuple of alpha and beta -+ split_k_slices}; // <- k-dimension split factor -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm gemm_op; -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(status); -+ -+ // Launch initialized CUTLASS kernel -+ status = gemm_op(); -+ CUTLASS_CHECK(status); -+ -+ // Create instantiation for device reference gemm kernel -+ cutlass::reference::device::Gemm -+ gemm_device; -+ -+ // Launch device reference gemm kernel -+ gemm_device(problem_size, -+ alpha, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ beta, -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref()); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ std::cout << (passed ? "Passed" : "Failed") << std::endl; -+ -+ return (passed ? 0 : -1); -+} -+ -+int main() { -+ -+ // -+ // Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1. -+ // -+ // CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples. -+ // -+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) { -+ std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl; -+ -+ // Returning zero, so this test passes when built with older CUDA Toolkits. Its action are no-op. -+ return 0; -+ } -+ else { -+ return run(); -+ } -+} -+ -diff --git a/3rdparty/cutlass/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu b/3rdparty/cutlass/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu -new file mode 100644 -index 0000000..c38f040 ---- /dev/null -+++ b/3rdparty/cutlass/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu -@@ -0,0 +1,357 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+This example shows how to run matrix multiplication kernels using functions and data structures -+provided by CUTLASS using tensor cores; which we run on a NVIDIA Volta GPU. -+ -+Writing a single high performance matrix multiplication kernel is hard but do-able. Whereas writing -+high performance kernels at scale which works for multiple problem sizes with good abstractions is -+really hard. CUTLASS solves this problem by providing simplified abstractions to compose -+multiple sections of gemm kernel. When used properly, the kernels can hit peak performance of GPU -+easily. -+ -+CUTLASS divides a kernel into hierarchical composable sections. Which means, at each thread, warp -+and thread-block level, they compute on their own tile-size with higher level of tile sizes being -+composed from lower level ones. Multiple thread-tiles (tile size each thread computes) can be used -+to form warp-tiles (tile size each warp computes) and multiple warp tiles can be used to compute -+threadblock-tile (tile size computed by a threadblock). -+ -+In thie example, we split variable initialization into -+1. Setting up data properties : describes how matrices are laid out in the memory and how the kernel -+can view them (logical to physical mapping) -+2. Setting up computation properties : describes how the above set matrices will be used to compute -+output of matrix multiplication. -+ -+First, we setup the data types of matrices A, B, C and D along with alpha, beta as the equation for -+GEMM is D = alpha * A * B + beta * C. In CUTLASS, the kernels first compute A * B and leaves the -+rest of the computation to end of the kernel as alpha * X + beta * C is a simple element-wise -+operation on X (A * B) and C. We call this as epilogue of kernel. Hence, we setup data types for -+alpha and beta to be equal to ElementComputeEpilogue = float. As we want to MMA instructions on -+Volta and they support only half-precision floating point (fp16 or half), we use data type for -+elements in input matrix A and B as cutlass::half_t. 
Volta also supports accumulation of partial dot -+product to fp32, which can store wider range of numbers, we use it as data type of output matrix -+elements and accumulation. We convey this to CUTLASS kernel by initializing template variables -+ElementAccumulator (float), ElementComputeEpilogue (float), ElementInputA (cutlass::half_t), -+ElementInputB (cutlass::half_t), ElementOutput (float). Communicating just the data type is not -+enough. As the data is laid out linearly in memory, we have to convey the layout of matrices. We do -+that by initializing template variable LayoutInputA to column major cutlass variable, LayoutInputB -+to row major and LayoutOutput to row major. Next, we setup rules to comptue alpha * X + beta * C -+which is called epilogue of the kernel. We initialize template variable EpilogueOp, which takes the -+data type of output ElementOutput (int32_t), the number of elements per vector memory access (16), -+data type of accumulator (int32_t) and data type of computation of linear combination (alpha * X + -+beta * C). -+ -+Now that we setup the properties of data, we have to setup properties of computation. -+ -+Second, we create template variables of tile sizes for thread-block, warp and mma-op to 128x128x32, -+64x64x32, 8x8x4 (MxNxK) respectively. When passed to instantiate CUTLASS GEMM kernel, it internally -+deduce the amount of threads needed per thread-block, amount of shared memory, storing data in -+bank-conflict free manner, and ton of other variables required to compose, intialize and launch a -+high performance GEMM kernel. This is the beauty of CUTLASS, it relieves developer from -+understanding and coding complicated hardware optimizations which can easily go wrong. -+ -+CUTLASS also supports multiple MMA pipelines in a CTA. What are MMA pipelines? MMA pipelines -+constitute the whole process of loading input data from global memory to shared memory, loading data -+from shared memory to registers, doing matrix multiplication, store to global memory. The below flow -+sequence shows a typical mma pipeline. -+ -+matrix in global memory -> registers -> tile in shared memory -> registers -> mma -> registers -> -+output to global memory -+ -+The problem with single pipeline is, each stage is synchronous which means, each stage has to wait -+until the previous finished executing. There are stages in the pipeline which do not have fixed -+latency, for example, the loads from global memory and shared memory. Therefore, we can add one more -+pipeline with a phase shift in mma kernel to hide latency from global and shared memory loads. -+Finally, the pipeline in a kernel looks like -+ -+(1) matrix in global memory -> (2) registers -> (3) tile in shared memory -> (4) registers -> (5) -+mma -> (6) registers -> (7) output to global memory (1) -> (2) -> (3) matrix in global -+memory -> (4) registers -> (5) tile in shared memory -> (6) registers -> (7) mma -> (8) registers -> -+(9) output to global memory -+ -+This way, you can hide the second global memoroy load latency by doing computation on already loaded -+input data. -+ -+There are few more template variables initialized such as, which threadblock tile of output matrix -+is done which threadblock launched on an SM, CUDA SM architecture of GPU you want to run on. -+ -+These are all put together to create a template variable which describes CUTLASS GEMM kernel using -+cutlass::gemm::device::Gemm template. -+ -+The next step is to intialize physical data, instantiate and initialize CUTLASS kernel and run it. 
-+We use CUTLASS utilities to initialize, fill, compare matrices as they are simple and doesn't come -+in the way of learning CUTLASS. -+ -+Once all the matrices are initialized and filled with data, create arguments tuple to launch CUTLASS -+kernel which takes problem size (M = 5120, N = 4096 and K = 4096), matrices, alpha, beta and the -+important one, split k-dimension factor. Along with that, we query CUTLASS if any scratch-space -+memory required by the kernel we instantiated. If yes, we create it and pass it along with other -+arguments created to intialize CUTLASS kernel then, the kernel is launched. -+ -+In this example, we later on launch a reference gemm kernel (from CUTLASS utilities) to compare if -+the output from CUTLASS kernel is same as reference GEMM kernel. -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "helper.h" -+ -+// The code section below describes datatype for input, output matrices and computation between -+// elements in input matrices. -+using ElementAccumulator = float; // <- data type of accumulator -+using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations -+using ElementInputA = cutlass::half_t; // <- data type of elements in input matrix A -+using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B -+using ElementOutput = float; // <- data type of elements in output matrix D -+ -+// The code section below describes matrix layout of input and output matrices. Column Major for -+// Matrix A, Row Major for Matrix B and Row Major for Matrix C -+using LayoutInputA = cutlass::layout::ColumnMajor; -+using LayoutInputB = cutlass::layout::RowMajor; -+using LayoutOutput = cutlass::layout::RowMajor; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm70; -+ -+// This code section describes the tile size a thread block will compute -+using ShapeMMAThreadBlock = -+ cutlass::gemm::GemmShape<128, 128, 32>; // <- threadblock tile M = 128, N = 128, K = 32 -+// This code section describes tile size a warp will compute -+using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = 64, N = 64, K = 32 -+// This code section describes the size of MMA op -+using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 4>; // <- MMA Op tile M = 8, N = 8, K = 4 -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? -+ -+// This code section describes ? -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // <- data type of output matrix -+ 128 / cutlass::sizeof_bits::value, // <- this is the number of elements per -+ // vectorized memory access. For half -+ // precision, it's 8 elements. 
This becomes -+ // the vector width of math instructions in -+ // epilogue too -+ ElementAccumulator, // <- data type of accumulator -+ ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 2; -+ -+using Gemm = cutlass::gemm::device::Gemm; -+ -+int run() { -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (props.major != 7) { -+ std::cerr << "Volta Tensor Ops must be run on a machine with compute capability of 70, 72, or 75." -+ << std::endl; -+ -+ // Return 0 so tests are considered passing if run on unsupported architectures or CUDA Toolkits. -+ return 0; -+ } -+ -+ const int length_m = 5120; -+ const int length_n = 4096; -+ const int length_k = 4096; -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k); -+ -+ // Initialize tensors using CUTLASS helper functions -+ cutlass::HostTensor tensor_a( -+ problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor tensor_b( -+ problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor tensor_c( -+ problem_size.mn()); // <- Create matrix C with dimensions M x N -+ cutlass::HostTensor tensor_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // CUTLASS kernel -+ cutlass::HostTensor tensor_ref_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // reference kernel -+ -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(4), -+ ElementInputA(-4), -+ 0); // <- Fill matrix A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(4), -+ ElementInputB(-4), -+ 0); // <- Fill matrix B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(4), -+ ElementOutput(-4), -+ 0); // <- Fill matrix C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); // <- fill matrix D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // Initialize alpha and beta for dot product computation -+ ElementComputeEpilogue alpha = ElementComputeEpilogue(1); -+ ElementComputeEpilogue beta = ElementComputeEpilogue(0); -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Create a tuple of gemm kernel arguments. 
This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication -+ tensor_a.device_ref(), // <- reference to matrix A on device -+ tensor_b.device_ref(), // <- reference to matrix B on device -+ tensor_c.device_ref(), // <- reference to matrix C on device -+ tensor_d.device_ref(), // <- reference to matrix D on device -+ {alpha, beta}, // <- tuple of alpha and beta -+ split_k_slices}; // <- k-dimension split factor -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm gemm_op; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status = gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(status); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status = gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(status); -+ -+ // Launch initialized CUTLASS kernel -+ status = gemm_op(); -+ CUTLASS_CHECK(status); -+ -+ // Create instantiation for device reference gemm kernel -+ cutlass::reference::device::Gemm -+ gemm_device; -+ -+ // Launch device reference gemm kernel -+ gemm_device(problem_size, -+ alpha, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ beta, -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref()); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ std::cout << (passed ? "Passed" : "Failed") << std::endl; -+ -+ return (passed ? 0 : -1); -+} -+ -+int main() { -+ -+ // Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1. -+ // -+ // CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) { -+ std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl; -+ -+ // Returning zero when built on older Toolkits so tests pass. The actions of this SDK example are no-op. -+ return 0; -+ } -+ else { -+ return run(); -+ } -+} -+ -diff --git a/3rdparty/cutlass/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu b/3rdparty/cutlass/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu -new file mode 100644 -index 0000000..bcff579 ---- /dev/null -+++ b/3rdparty/cutlass/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu -@@ -0,0 +1,358 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+This example shows how to run matrix multiplication kernels using functions and data structures -+provided by CUTLASS using tensor cores; which we run on a NVIDIA Turing GPU. -+ -+Writing a single high performance matrix multiplication kernel is hard but do-able. Whereas writing -+high performance kernels at scale which works for multiple problem sizes with good abstractions is -+really hard. CUTLASS solves this problem by providing simplified abstractions to compose -+multiple sections of gemm kernel. When used properly, the kernels can hit peak performance of GPU -+easily. -+ -+CUTLASS divides a kernel into hierarchical composable sections. Which means, at each thread, warp -+and thread-block level, they compute on their own tile-size with higher level of tile sizes being -+composed from lower level ones. Multiple thread-tiles (tile size each thread computes) can be used -+to form warp-tiles (tile size each warp computes) and multiple warp tiles can be used to compute -+threadblock-tile (tile size computed by a threadblock). -+ -+In thie example, we split variable initialization into -+1. Setting up data properties : describes how matrices are laid out in the memory and how the kernel -+can view them (logical to physical mapping) -+2. Setting up computation properties : describes how the above set matrices will be used to compute -+output of matrix multiplication. -+ -+First, we setup the data types of matrices A, B, C and D along with alpha, beta as the equation for -+GEMM is D = alpha * A * B + beta * C. In CUTLASS, the kernels first compute A * B and leaves the -+rest of the computation to end of the kernel as alpha * X + beta * C is a simple element-wise -+operation on X (A * B) and C. We call this as epilogue of kernel. Hence, we setup data types for -+alpha and beta to be equal to ElementComputeEpilogue = int32_t. As we want to use MMA instructions -+on Turing and they support 8-bit signed integer (int8_t), we use data type for elements in input -+matrix A and B as int8_t. Volta also supports accumulation of partial dot product to int32_t, which -+can store wider range of numbers, we use it as data type of output matrix elements and accumulation. 
-+We convey this to CUTLASS kernel by initializing template variables ElementAccumulator (int32_t), -+ElementComputeEpilogue (int32_t), ElementInputA (int8_t), ElementInputB (int8_t), ElementOutput -+(int32_t). Communicating just the data type is not enough. As the data is laid out linearly in -+memory, we have to convey the layout of matrices. We do that by initializing template variable -+LayoutInputA to column major cutlass variable, LayoutInputB to row major and LayoutOutput to row -+major. Next, we setup rules to comptue alpha * X + beta * C which is called epilogue of the kernel. -+We initialize template variable EpilogueOp, which takes the data type of output ElementOutput -+(int32_t), the number of elements per vector memory access (16), data type of accumulator (int32_t) -+and data type of computation of linear combination (alpha * X + beta * C). -+ -+Now that we setup the properties of data, we have to setup properties of computation. -+ -+Second, we create template variables of tile sizes for thread-block, warp and mma-op to 128x256x64, -+64x64x16, 8x8x16 (MxNxK) respectively. When passed to instantiate CUTLASS GEMM kernel, it internally -+deduce the amount of threads needed per thread-block, amount of shared memory, storing data in -+bank-conflict free manner, and ton of other variables required to compose, intialize and launch a -+high performance GEMM kernel. This is the beauty of CUTLASS, it relieves developer from -+understanding and coding complicated hardware optimizations which can easily go wrong. -+ -+CUTLASS also supports multiple MMA pipelines in a threadblock. What are MMA pipelines? MMA pipelines -+constitute the whole process of loading input data from global memory to shared memory, loading data -+from shared memory to registers, doing matrix multiplication, store to global memory. The below flow -+sequence shows a typical mma pipeline. -+ -+matrix in global memory -> registers -> tile in shared memory -> registers -> mma -> registers -> -+output to global memory -+ -+The problem with single pipeline is, each stage is synchronous which means, each stage has to wait -+until the previous finished executing. There are stages in the pipeline which do not have fixed -+latency, for example, the loads from global memory and shared memory. Therefore, we can add one more -+pipeline with a phase shift in mma kernel to hide latency from global and shared memory loads. -+Finally, the pipeline in a kernel looks like -+ -+(1) matrix in global memory -> (2) registers -> (3) tile in shared memory -> (4) registers -> (5) -+mma -> (6) registers -> (7) output to global memory (1) -> (2) -> (3) matrix in global -+memory -> (4) registers -> (5) tile in shared memory -> (6) registers -> (7) mma -> (8) registers -> -+(9) output to global memory -+ -+This way, you can hide the second global memoroy load latency by doing computation on already loaded -+input data. -+ -+There are few more template variables initialized such as, which threadblock tile of output matrix -+is done which threadblock launched on an SM, CUDA SM architecture of GPU you want to run on. -+ -+These are all put together to create a template variable which describes CUTLASS GEMM kernel using -+cutlass::gemm::device::Gemm template. -+ -+The next step is to intialize physical data, instantiate and initialize CUTLASS kernel and run it. -+We use CUTLASS utilities to initialize, fill, compare matrices as they are simple and doesn't come -+in the way of learning CUTLASS. 
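As a rough sketch of the two-stage pipeline described above, the structure below rotates between two buffers so that the load of the next K-tile is issued while the math operates on the tile that is already resident. It is conceptual only: load_tile and mma_on_tile are hypothetical stand-ins for the global-to-shared copy and the warp-level MMA, and on a real GPU the load would be asynchronous, whereas this serial C++ only shows the buffer rotation.

#include <array>
#include <vector>

using Tile = std::vector<float>;

// Hypothetical stand-ins for the real pipeline stages; illustrative only.
Tile load_tile(int k_tile) { return Tile(128 * 64, float(k_tile)); }        // global -> shared
void mma_on_tile(const Tile &t, float &acc) { for (float v : t) acc += v; } // shared -> registers -> mma

// Two-stage (double-buffered) mainloop: compute on buffers[cur] while buffers[nxt]
// holds the tile being fetched for the next iteration.
float pipelined_mainloop(int num_k_tiles) {
  std::array<Tile, 2> buffers;
  float acc = 0.0f;
  buffers[0] = load_tile(0);                      // prologue: fill stage 0 before the loop
  for (int k = 0; k < num_k_tiles; ++k) {
    int cur = k % 2;
    int nxt = (k + 1) % 2;
    if (k + 1 < num_k_tiles)
      buffers[nxt] = load_tile(k + 1);            // prefetch the next K-tile (hides load latency)
    mma_on_tile(buffers[cur], acc);               // math on the tile loaded in the previous step
  }
  return acc;
}

Setting NumStages = 2 further below corresponds to this kind of two-stage rotation inside the CUTLASS mainloop.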
-+ -+Once all the matrices are initialized and filled with data, create arguments tuple to launch CUTLASS -+kernel which takes problem size (M = 5120, N = 4096 and K = 4096), matrices, alpha, beta and the -+important one, split k-dimension factor. Along with that, we query CUTLASS if any scratch-space -+memory required by the kernel we instantiated. If yes, we create it and pass it along with other -+arguments created to intialize CUTLASS kernel then, the kernel is launched. -+ -+In this example, we later on launch a reference gemm kernel (from CUTLASS utilities) to compare if -+the output from CUTLASS kernel is same as reference GEMM kernel. -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "helper.h" -+ -+// The code section below describes datatype for input, output matrices and computation between -+// elements in input matrices. -+using ElementAccumulator = int32_t; // <- data type of accumulator -+using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations -+using ElementInputA = int8_t; // <- data type of elements in input matrix A -+using ElementInputB = int8_t; // <- data type of elements in input matrix B -+using ElementOutput = int32_t; // <- data type of elements in output matrix D -+ -+// The code section below describes matrix layout of input and output matrices. Column Major for -+// Matrix A, Row Major for Matrix B and Row Major for Matrix C -+using LayoutInputA = cutlass::layout::RowMajor; -+using LayoutInputB = cutlass::layout::ColumnMajor; -+using LayoutOutput = cutlass::layout::RowMajor; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm75; -+ -+// This code section describes the tile size a thread block will compute -+using ShapeMMAThreadBlock = -+ cutlass::gemm::GemmShape<128, 256, 64>; // <- threadblock tile M = 128, N = 256, K = 64 -+// This code section describes tile size a warp will compute -+using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 64>; // <- warp tile M = 64, N = 64, K = 64 -+// This code section describes the size of MMA op -+using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 16>; // <- MMA Op tile M = 8, N = 8, K = 16 -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? -+ -+// This code section describes the epilogue part of the kernel -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // <- data type of output matrix -+ 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized -+ // memory access. For a byte, it's 16 -+ // elements. 
This becomes the vector width of -+ // math instructions in the epilogue too -+ ElementAccumulator, // <- data type of accumulator -+ ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 2; -+ -+using Gemm = cutlass::gemm::device::Gemm; -+ -+int run() { -+ -+ const int length_m = 5120; -+ const int length_n = 4096; -+ const int length_k = 4096; -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k); -+ -+ // Initialize tensors using CUTLASS helper functions -+ cutlass::HostTensor tensor_a( -+ problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor tensor_b( -+ problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor tensor_c( -+ problem_size.mn()); // <- Create matrix C with dimensions M x N -+ cutlass::HostTensor tensor_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // CUTLASS kernel -+ cutlass::HostTensor tensor_ref_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // reference kernel -+ -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(4), -+ ElementInputA(-4), -+ 0); // <- Fill matrix A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(4), -+ ElementInputB(-4), -+ 0); // <- Fill matrix B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(4), -+ ElementOutput(-4), -+ 0); // <- Fill matrix C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); // <- fill matrix D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // Initialize alpha and beta for dot product computation -+ ElementComputeEpilogue alpha = ElementComputeEpilogue(1); -+ ElementComputeEpilogue beta = ElementComputeEpilogue(0); -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Create a tuple of gemm kernel arguments. 
This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication -+ tensor_a.device_ref(), // <- reference to matrix A on device -+ tensor_b.device_ref(), // <- reference to matrix B on device -+ tensor_c.device_ref(), // <- reference to matrix C on device -+ tensor_d.device_ref(), // <- reference to matrix D on device -+ {alpha, beta}, // <- tuple of alpha and beta -+ split_k_slices}; // <- k-dimension split factor -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm gemm_op; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status = gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(status); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status = gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(status); -+ -+ // Launch initialized CUTLASS kernel -+ status = gemm_op(); -+ CUTLASS_CHECK(status); -+ -+ // Create instantiation for device reference gemm kernel -+ cutlass::reference::device::Gemm -+ gemm_device; -+ -+ // Launch device reference gemm kernel -+ gemm_device(problem_size, -+ alpha, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ beta, -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref()); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ std::cout << (passed ? "Passed" : "Failed") << std::endl; -+ -+ return (passed ? 0 : -1); -+} -+ -+int main() { -+ bool notSupported = false; -+ -+ // Turing Tensor Core operations exposed with mma.sync and ldmatrix are first available -+ // in CUDA 10.2. -+ // -+ // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) { -+ std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (!((props.major * 10 + props.minor) >= 75)) { -+ std::cerr << "Turing Tensor Core operations must be run on a machine with compute capability at least 75." -+ << std::endl; -+ -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. 
-+ return 0; -+ } -+ -+ return run(); -+} -+ -diff --git a/3rdparty/cutlass/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu b/3rdparty/cutlass/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu -new file mode 100644 -index 0000000..e39784e ---- /dev/null -+++ b/3rdparty/cutlass/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu -@@ -0,0 +1,771 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+ -+ -+This example shows how to run convolution kernels using functions and data structures -+provided by CUTLASS using tensor cores; which we run on a NVIDIA Turing GPU. -+ -+Writing a single high performance convolution kernel is hard but do-able. Whereas writing -+high performance kernels at scale which works for multiple problem sizes with good abstractions is -+really hard. CUTLASS solves this problem by providing simplified abstractions to compose -+multiple sections of implicit gemm kernel. When used properly, the kernels can hit peak performance -+of GPU easily. -+ -+CUTLASS divides a kernel into hierarchical composable sections. Which means, at each thread, warp -+and thread-block level, they compute on their own tile-size with higher level of tile sizes being -+composed from lower level ones. Multiple thread-tiles (tile size each thread computes) can be used -+to form warp-tiles (tile size each warp computes) and multiple warp tiles can be used to compute -+threadblock-tile (tile size computed by a threadblock). -+ -+In thie example, we split variable initialization into -+1. 
Setting up data properties : describes how tensors are laid out in the memory and how the kernel -+can view them (logical to physical mapping) -+2. Setting up computation properties : describes how the above set tensors will be used to compute -+output of convolution. -+ -+First, we setup the data types of the input tensor A, weights' tensor B and output tensor C along -+with alpha, beta as the equation for convolution is C = alpha * Conv(A, B) + beta * C. In CUTLASS, -+the kernels first compute Conv(A, B) and leave the rest of the computation to end of the kernel as -+alpha * X + beta * C is a simple element-wise operation on X (Conv(A, B)) and C. We call this as -+epilogue of kernel. Hence, we setup data types for alpha and beta to be equal to -+ElementComputeEpilogue = float. We want to use MMA instructions on Turing and they support 4-bit -+signed integer. But int4b_t is not fully supported by Nvidia software stack, so CUTLASS introduces -+cutlass::int4b_t. We use the data type for elements in input tensor A and B as cutlass::int4b_t. We -+convey this to CUTLASS kernel by initializing template variables ElementAccumulator (int32_t), -+ElementComputeEpilogue (float), ElementInputA (cutlass::int4b_t), ElementInputB (cutlass::int4b_t), -+ElementOutput (int32_t). Communicating just the data type is not enough. As the data is laid out -+linearly in memory, we have to convey the layout of tensors. We do that by initializing template -+variables LayoutInputA, LayoutInputB and LayoutOutput to TensorNHWC cutlass variable. Next, we setup -+rules to comptue alpha * X + beta * C which is called epilogue of the kernel. We initialize template -+variable EpilogueOp, which takes the data type of output ElementOutput (int32_t), the number of -+elements per vector memory access (32), data type of accumulator (int32_t) and data type of -+computation of linear combination (alpha * X + beta * C). -+ -+Now that we setup the properties of data, we have to setup properties of computation. -+ -+Second, we create template variables of tile sizes for thread-block, warp and mma-op to 128x128x128, -+64x64x128, 8x8x32 (MxNxK) respectively. When passed to instantiate CUTLASS Implicit GEMM kernel, it -+internally deduces the amount of threads needed per thread-block, amount of shared memory, storing -+data in bank-conflict free manner, and ton of other variables required to compose, intialize and -+launch a high performance Implicit GEMM kernel. This is the beauty of CUTLASS, it relieves developer -+from understanding and coding complicated hardware optimizations which can easily go wrong. -+ -+CUTLASS also supports multiple MMA pipelines in a threadblock. What are MMA pipelines? MMA pipelines -+constitute the whole process of loading input data from global memory to shared memory, loading data -+from shared memory to registers, doing matrix multiplication, store to global memory. The below flow -+sequence shows a typical mma pipeline. -+ -+tensor in global memory -> registers -> tile in shared memory -> registers -> mma -> registers -> -+output to global memory -+ -+The problem with single pipeline is, each stage is synchronous which means, each stage has to wait -+until the previous finished executing. There are stages in the pipeline which do not have fixed -+latency, for example, the loads from global memory and shared memory. Therefore, we can add one more -+pipeline with a phase shift in mma kernel to hide latency from global and shared memory loads. 
-+Finally, the pipeline in a kernel looks like -+ -+(1) tensor in global memory -> (2) registers -> (3) tile in shared memory -> (4) registers -> (5) -+mma -> (6) registers -> (7) output to global memory (1) -> (2) -> (3) tensor in global -+memory -> (4) registers -> (5) tile in shared memory -> (6) registers -> (7) mma -> (8) registers -> -+(9) output to global memory -+ -+This way, you can hide the second global memory load latency by doing computation on already loaded -+input data. -+ -+There are few more template variables initialized such as, which threadblock tile of output matrix -+is done which threadblock launched on an SM, CUDA SM architecture of GPU you want to run on. -+ -+These are all put together to create a template variable which describes CUTLASS Implicit GEMM -+kernel using cutlass::conv::device::ImplicitGemm template. -+ -+The next step is to intialize physical data, instantiate and initialize CUTLASS kernel and run it. -+We use CUTLASS utilities to initialize, fill, compare tensors as they are simple and doesn't come -+in the way of learning CUTLASS. -+ -+Once all the tensors are initialized and filled with data, create arguments tuple to launch CUTLASS -+kernel which takes problem size (N = 1, H = 64, W = 64, C = 128), filter size (K = 64, -+R = 3, S = 3, C = 128 ), padding, strides, dilation, tensors, alpha, beta and the -+important one, split k-dimension factor. Along with that, we query CUTLASS if any scratch-space -+memory required by the kernel we instantiated. If yes, we create it and pass it along with other -+arguments created to intialize CUTLASS kernel then, the kernel is launched. -+ -+In this example, we later on launch a reference convolution kernel (from CUTLASS utilities) to -+compare if the output from CUTLASS kernel is same as the reference implicit GEMM kernel. 
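For reference, the implicit GEMM mapping this example relies on can be written down directly. The helper below is illustrative (not a CUTLASS API): for forward propagation, the output activations form the GEMM M extent (N * P * Q), the filters form the GEMM N extent (K), and the reduction runs over the filter footprint (R * S * C). P and Q follow the same formula as Options::output_size() further down, with unit dilation assumed.

#include <cstdio>

// Hypothetical helper: logical GEMM extents of an implicit-GEMM forward convolution.
struct GemmExtents { long long m, n, k; };

GemmExtents implicit_gemm_fprop_extents(int N, int H, int W, int C,    // input NHWC
                                        int K, int R, int S,           // filter K x R x S x C
                                        int pad_h, int pad_w,
                                        int stride_h, int stride_w) {
  int P = (H + 2 * pad_h - R) / stride_h + 1;   // output height
  int Q = (W + 2 * pad_w - S) / stride_w + 1;   // output width
  return { (long long)N * P * Q,                // GEMM M: one row per output activation
           (long long)K,                        // GEMM N: one column per filter
           (long long)R * S * C };              // GEMM K: reduction over the filter footprint
}

int main() {
  // Problem size quoted above: N=1, H=64, W=64, C=128, K=64, R=S=3, pad=1, stride=1.
  GemmExtents e = implicit_gemm_fprop_extents(1, 64, 64, 128, 64, 3, 3, 1, 1, 1, 1);
  std::printf("GEMM extents: M=%lld N=%lld K=%lld\n", e.m, e.n, e.k);   // M=4096, N=64, K=1152
  return 0;
}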
-+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using ElementAccumulator = int32_t; // Data type of accumulator -+using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) -+using ElementInputA = cutlass::int4b_t; // Data type of elements in input tensor -+using ElementInputB = cutlass::int4b_t; // Data type of elements in input tensor -+using ElementOutput = cutlass::int4b_t; // Data type of elements in output tensor -+ -+using LayoutInputA = cutlass::layout::TensorNHWC; -+using LayoutInputB = cutlass::layout::TensorNHWC; -+using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm75; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 2; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, // Data type of output matrix. -+ 8, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. 
-+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; // Data type for alpha/beta in linear combination -+ -+ -+using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+>::Kernel; -+ -+using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::Tensor4DCoord input_size; -+ cutlass::Tensor4DCoord filter_size; -+ cutlass::Tensor4DCoord padding; -+ cutlass::MatrixCoord conv_stride; -+ cutlass::MatrixCoord dilation; -+ bool reference_check; -+ bool measure_performance; -+ int iterations; -+ bool save_workspace; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ bool benchmark; -+ std::string tag; -+ -+ Options(): -+ help(false), -+ input_size(1, 32, 32, 32), -+ filter_size(32, 3, 3, 32), -+ padding(1, 1, 1, 1), -+ conv_stride(1, 1), -+ dilation(1, 1), -+ reference_check(false), -+ measure_performance(true), -+ iterations(20), -+ save_workspace(false), -+ alpha(1), -+ beta(0), -+ benchmark(false) { } -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ -+ // -+ // CUTLASS attempts to load 128b vectors of int4b_t elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 32 elements. 
-+ // -+ int const kAlignment = 32; -+ -+ if ((input_size.c() % kAlignment) || -+ (filter_size.n() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ // Invalid padding -+ if ((padding.h() != filter_size.h() / 2) || -+ (padding.w() != filter_size.w() / 2)) { -+ -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Updates input and filter sizes -+ void update( -+ cutlass::Tensor4DCoord input_size, -+ cutlass::Tensor4DCoord filter_size) { -+ -+ this->input_size = input_size; -+ this->filter_size = filter_size; -+ -+ padding.n() = filter_size.h() / 2; -+ padding.h() = filter_size.h() / 2; -+ padding.w() = filter_size.w() / 2; -+ padding.c() = filter_size.w() / 2; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("ref-check")) { -+ reference_check = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("perf-check")) { -+ measure_performance = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("save-workspace")) { -+ save_workspace = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("benchmark")) { -+ benchmark = true; -+ } -+ -+ cmd.get_cmd_line_argument("n", input_size.n()); -+ cmd.get_cmd_line_argument("h", input_size.h()); -+ cmd.get_cmd_line_argument("w", input_size.w()); -+ cmd.get_cmd_line_argument("c", input_size.c()); -+ -+ cmd.get_cmd_line_argument("k", filter_size.n()); -+ cmd.get_cmd_line_argument("r", filter_size.h()); -+ cmd.get_cmd_line_argument("s", filter_size.w()); -+ filter_size.c() = input_size.c(); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ -+ if (filter_size.h() == 3 && filter_size.w() == 3) { -+ padding = {1, 1, 1, 1}; -+ } -+ else { -+ filter_size.h() = 1; -+ filter_size.w() = 1; -+ padding = {0, 0, 0, 0}; -+ } -+ } -+ -+ /// Prints the usage statement. 
-+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "09_turing_tensorop_conv2dfprop example\n\n" -+ << " This example uses Turing's Tensor Core operators on int4 data types to compute\n" -+ << " forward convolution on tensors of layout NHWC.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --n= Input tensor extent N\n" -+ << " --h= Input tensor extent H\n" -+ << " --w= Input tensor extent W\n" -+ << " --c= Input tensor extent C\n" -+ << " --k= Filter extent K\n" -+ << " --r= Filter extent R\n" -+ << " --s= Filter extent S\n\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --ref-check If set (true), reference check on the host is computed\n" -+ << " --perf-check If set (true), performance is measured.\n" -+ << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --save-workspace If set, workspace is written to a text file.\n" -+ << " --tag= String to replicate across the first column in the results table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/09_turing_tensorop_conv2dfprop/09_turing_tensorop_conv2dfprop --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n" -+ << "$ ./examples/09_turing_tensorop_conv2dfprop/09_turing_tensorop_conv2dfprop --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n"; -+ -+ return out; -+ } -+ -+ /// Computes the output tensor size (NPQK) -+ cutlass::Tensor4DCoord output_size() const { -+ return cutlass::Tensor4DCoord( -+ input_size.n(), -+ (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, -+ (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, -+ filter_size.n()); -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of multiply-adds = NPQK * CRS -+ int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cutlass::Status reference_check; -+ cudaError_t error; -+ -+ Result(): -+ runtime_ms(0), -+ gflops(0), -+ status(cutlass::Status::kSuccess), -+ reference_check(cutlass::Status::kInvalid), -+ error(cudaSuccess) { } -+ -+ static std::ostream & print_header(std::ostream &out, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "Layer,N,H,W,C,K,R,S,Runtime,GFLOPs"; -+ -+ return out; -+ } -+ -+ std::ostream & print(std::ostream &out, int idx, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ out -+ << "conv_" << idx << "," -+ << options.input_size.n() << "," -+ << options.input_size.h() << "," -+ << options.input_size.w() << "," -+ << options.input_size.c() << "," -+ << options.filter_size.n() << "," -+ << options.filter_size.h() << "," -+ << options.filter_size.w() << "," -+ << runtime_ms << "," -+ << gflops; -+ -+ return out; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one benchmark -+Result profile_convolution(Options const &options) { -+ -+ Result result; -+ -+ // -+ // 
Allocate host-device tensors using the CUTLASS Utilities. -+ // -+ -+ cutlass::HostTensor tensor_a(options.input_size); -+ cutlass::HostTensor tensor_b(options.filter_size); -+ cutlass::HostTensor tensor_c(options.output_size()); -+ cutlass::HostTensor tensor_ref_c(options.output_size()); -+ -+ // -+ // Initialize tensors -+ // -+ -+ // Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(7), -+ ElementInputA(-8), -+ 0); -+ -+ // Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(7), -+ ElementInputB(-8), -+ 0); -+ -+ // Fill tensor C on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_c.host_view()); -+ -+ // Fill tensor C for reference on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_c.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_ref_c.sync_device(); -+ -+ // -+ // Define arguments for CUTLASS Convolution -+ // -+ -+ // mode (kCrossCorrelation or kConvolution) -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Construct Conv2dProblemSize with user defined output size -+ cutlass::conv::Conv2dProblemSize problem_size( -+ options.input_size, -+ options.filter_size, -+ options.padding, -+ options.conv_stride, -+ options.dilation, -+ options.output_size(), -+ mode, -+ split_k_slices); -+ -+ // Construct ImplicitGemm::Argument structure with conv2d -+ // problem size, data pointers, and epilogue values -+ typename ImplicitGemm::Arguments arguments{ -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_c.device_ref(), -+ {options.alpha, options.beta}, -+ }; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ ImplicitGemm implicit_gemm_op; -+ -+ size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ result.status = implicit_gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = implicit_gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Launch initialized CUTLASS kernel -+ // -+ result.status = implicit_gemm_op(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Optional reference check -+ // -+ -+ if (options.reference_check) { -+ std::cout << "Verification on host...\n"; -+ -+ // Compute with reference implementation -+ cutlass::reference::host::Conv2dFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator, -+ cutlass::NumericConverterClamp -+ >( -+ problem_size, -+ tensor_a.host_ref(), -+ tensor_b.host_ref(), -+ tensor_c.host_ref(), -+ tensor_ref_c.host_ref(), -+ options.alpha, -+ options.beta -+ ); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ tensor_c.sync_host(); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_c.host_view(), -+ tensor_ref_c.host_view()); -+ -+ if (!passed) { -+ result.reference_check = cutlass::Status::kErrorInternal; -+ std::cout << "ERROR - results miscompared.\n"; -+ } -+ else { -+ 
result.reference_check = cutlass::Status::kSuccess; -+ std::cout << "Passed.\n"; -+ } -+ } -+ else { -+ result.reference_check = cutlass::Status::kInvalid; -+ } -+ -+ if (options.save_workspace) { -+ -+ std::stringstream ss; -+ -+ ss << "09_tensor_conv_workspace_conv2dfprop_" -+ << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() -+ << "_" -+ << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() -+ << ".dat"; -+ -+ std::ofstream output_workspace(ss.str()); -+ -+ output_workspace -+ << "Input = \n" << tensor_a.host_view() << "\n\n" -+ << "Filters = \n" << tensor_b.host_view() << "\n\n"; -+ -+ if (options.reference_check) { -+ output_workspace << "Reference = \n" << tensor_ref_c.host_view() << "\n\n"; -+ } -+ -+ output_workspace << "Computed = \n" << tensor_c.host_view() << std::endl; -+ -+ std::cout << "Results written to '" << ss.str() << "'." << std::endl; -+ } -+ -+ // -+ // Performance measurement -+ // -+ -+ if (options.measure_performance) { -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = implicit_gemm_op(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2. -+ // -+ // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) { -+ std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." 
<< std::endl; -+ return 0; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major > 7 || (props.major == 7 && props.minor >= 5))) { -+ std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75." -+ << std::endl; -+ return 0; -+ } -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.benchmark) { -+ // Benchmark several layers -+ -+ int batch_sizes[] = {1, 32, 64, 128, 256, 512}; -+ -+ struct Benchmark { -+ int h, w, c, k, r, s; -+ } layers[] = { -+ {56, 56, 64, 256, 1, 1}, -+ {56, 56, 64, 64, 1, 1}, -+ {56, 56, 64, 64, 3, 3}, -+ {56, 56, 256, 64, 1, 1}, -+ {56, 56, 256, 512, 1, 1}, -+ {56, 56, 256, 128, 1, 1}, -+ {28, 28, 128, 128, 3, 3}, -+ {28, 28, 128, 512, 1, 1}, -+ {28, 28, 512, 128, 1, 1}, -+ {28, 28, 512, 1024, 1, 1}, -+ {28, 28, 512, 256, 1, 1}, -+ {14, 14, 256, 256, 3, 3}, -+ {14, 14, 256, 1024, 1, 1}, -+ {14, 14, 1024, 256, 1, 1}, -+ {14, 14, 1024, 2048, 1, 1}, -+ {14, 14, 1024, 512, 1, 1}, -+ {7, 7, 512, 512, 3, 3}, -+ }; -+ -+ Result::print_header(std::cout, options) << std::endl; -+ -+ int idx = 1; -+ -+ for (auto const &layer : layers) { -+ for (auto N : batch_sizes) { -+ -+ options.update({N, layer.h, layer.w, layer.c}, {layer.k, layer.r, layer.s, layer.c}); -+ -+ Result result = profile_convolution(options); -+ result.print(std::cout, idx, options) << std::endl; -+ } -+ -+ ++idx; -+ } -+ } -+ else { -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ Result result = profile_convolution(options); -+ -+ Result::print_header(std::cout, options) << std::endl; -+ result.print(std::cout, 1, options) << std::endl; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+ -diff --git a/3rdparty/cutlass/examples/10_planar_complex/planar_complex.cu b/3rdparty/cutlass/examples/10_planar_complex/planar_complex.cu -new file mode 100644 -index 0000000..9e0915d ---- /dev/null -+++ b/3rdparty/cutlass/examples/10_planar_complex/planar_complex.cu -@@ -0,0 +1,567 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Planar Complex GEMM -+ -+ This example demonstrates the CUTLASS Library's exposure of planar complex GEMM kernels supporting -+ the batched strided mode. -+ -+ These kernels represent complex matrices by storing the real and imaginary parts of the matrix in -+ disjoint regions in memory. These real-valued matrices are stored using existing cuBLAS layouts -+ as either column-major or row-major layouts with a single leading dimension indicating the stride -+ between columns or rows. -+ -+ The CUTLASS Library collects multiple template instantiations in a data structure and offers -+ a BLAS-like dispatch API to invoke the appropriate kernel on the Volta or Turing architectures. -+ -+ CUTLASS decouples matrix layout from complex transformation, so four possible transformations -+ are possible on the A and B operands: -+ -+ n: column-major -+ c: column-major complex conjugate -+ t: row-major -+ h: row-major complex conjugate -+ -+ The CUTLASS Library contains many kernel instances specialized for architecture, data type, tile -+ size, and alignment. This can result in long compile times. -+ -+ To build strictly the planar complex kernels needed for general application, execute the following -+ CMake command in an empty build directory. -+ -+ $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" \ -+ -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_*gemm_planar_complex -+ -+ This builds all planar complex GEMM variants for Volta and Turing architectures. -+ -+ To build strictly the kernels needed for this example, an even narrower filter string may be -+ specified as follows. This only builds planar complex GEMMs targeting Tensor Cores for -+ the 'CN' layout configuration (conjugate A operand with both A and B as column-major). -+ -+ $ cmake .. 
-DCUTLASS_NVCC_ARCHS="70;75;80" \ -+ -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s*gemm_planar_complex_f16*cn -+ -+ $ make 10_planar_complex -+ -+ $ ./examples/10_planar_complex/10_planar_complex --m=2048 --n=1024 --k=512 --batch=10 -+*/ -+ -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/device_memory.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor_planar_complex.h" -+ -+#include "cutlass/util/reference/device/tensor_fill.h" -+ -+#include "cutlass/util/reference/device/gemm_planar_complex.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+ -+#include "cutlass/library/handle.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ // -+ // Methods -+ // -+ -+ Result( -+ double runtime_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess -+ ): -+ runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ int batch_count; -+ cutlass::complex alpha; -+ cutlass::complex beta; -+ -+ bool reference_check; -+ int iterations; -+ -+ Options(): -+ help(false), -+ problem_size({1024, 1024, 1024}), -+ batch_count(1), -+ reference_check(true), -+ iterations(20), -+ alpha(1), -+ beta() { } -+ -+ bool valid() { -+ return true; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("m", problem_size.m()); -+ cmd.get_cmd_line_argument("n", problem_size.n()); -+ cmd.get_cmd_line_argument("k", problem_size.k()); -+ cmd.get_cmd_line_argument("batch", batch_count); -+ -+ cmd.get_cmd_line_argument("alpha", alpha.real()); -+ cmd.get_cmd_line_argument("alpha_i", alpha.imag()); -+ cmd.get_cmd_line_argument("beta", beta.real()); -+ cmd.get_cmd_line_argument("beta_i", beta.imag()); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ } -+ -+ /// Prints the usage statement. 
-+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "10_planar_complex example\n\n" -+ << " This example uses the CUTLASS Library to execute Planar Complex GEMM computations.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --m= GEMM M dimension\n" -+ << " --n= GEMM N dimension\n" -+ << " --k= GEMM K dimension\n" -+ << " --batch= Number of GEMM operations executed in one batch\n" -+ << " --alpha= Epilogue scalar alpha (real part)\n" -+ << " --alpha_i= Epilogue scalar alpha (imaginary part)\n" -+ << " --beta= Epilogue scalar beta (real part)\n\n" -+ << " --beta_i= Epilogue scalar beta (imaginary part)\n\n" -+ << " --iterations= Number of profiling iterations to perform.\n\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/10_planar_complex/10_planar_complex --batch=7 --m=1024 --n=512 --k=1024 \\\n" -+ << " --alpha=2 --alpha_i=-2 --beta=0.707 --beta_i=-.707\n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fmas = problem_size.product() * batch_count * 4; -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Performance test environment for planar complex -+class TestbedPlanarComplex { -+public: -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementCompute = float; -+ using ElementAccumulator = float; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::library::Handle handle; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ int batch_count; -+ cutlass::DeviceAllocation tensor_A; -+ cutlass::DeviceAllocation tensor_B; -+ cutlass::DeviceAllocation tensor_C; -+ cutlass::DeviceAllocation tensor_D; -+ cutlass::DeviceAllocation tensor_D_ref; -+ -+ // -+ // Methods -+ // -+ -+ TestbedPlanarComplex( -+ Options const &options -+ ): -+ problem_size(options.problem_size), batch_count(options.batch_count) { -+ -+ // Allocate device memory for batched strided GEMM -+ tensor_A.reset(int64_t(problem_size.m()) * problem_size.k() * batch_count * 2); -+ tensor_B.reset(int64_t(problem_size.k()) * problem_size.n() * batch_count * 2); -+ tensor_C.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); -+ tensor_D.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); -+ tensor_D_ref.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); -+ } -+ -+ void initialize() { -+ -+ uint64_t seed = 1073; -+ -+ // Use small integers to simplify correctness checking -+ int scope_max = 6; -+ int scope_min = -6; -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ tensor_A.get(), tensor_A.size(), seed, ElementA(scope_max), ElementA(scope_min), 0); -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ tensor_B.get(), tensor_B.size(), seed * 2019, ElementB(scope_max), ElementB(scope_min), 0); -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ tensor_C.get(), tensor_C.size(), seed * 2020, ElementC(scope_max), ElementC(scope_min), 0); -+ } -+ -+ Result profile(Options const &options) { -+ -+ Result result; -+ -+ initialize(); -+ -+ ElementA *ptr_A = tensor_A.get(); -+ ElementB *ptr_B = 
tensor_B.get(); -+ ElementC *ptr_C = tensor_C.get(); -+ ElementC *ptr_D = tensor_D.get(); -+ -+ int64_t batch_stride_A = int64_t(problem_size.m()) * problem_size.k() * 2; -+ int64_t batch_stride_B = int64_t(problem_size.k()) * problem_size.n() * 2; -+ int64_t batch_stride_C = int64_t(problem_size.m()) * problem_size.n() * 2; -+ int64_t batch_stride_D = int64_t(problem_size.m()) * problem_size.n() * 2; -+ -+ typename LayoutA::Stride::Index lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0); -+ typename LayoutB::Stride::Index ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0); -+ typename LayoutC::Stride::Index ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); -+ typename LayoutC::Stride::Index ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); -+ -+ int64_t imag_stride_A = int64_t(problem_size.m()) * problem_size.k(); -+ int64_t imag_stride_B = int64_t(problem_size.k()) * problem_size.n(); -+ int64_t imag_stride_C = int64_t(problem_size.m()) * problem_size.n(); -+ int64_t imag_stride_D = int64_t(problem_size.m()) * problem_size.n(); -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMMs -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ -+ // -+ // Execute the planar complex GEMM kernel via the CUTLASS Library's -+ // dispatch routines. -+ // -+ // Note, for planar complex GEMM kernels, all numeric type arguments -+ // specify the data type of the base real types. These are understood to -+ // apply to planar complex representations of matrices in memory and to complex -+ // structures for scalars. -+ // -+ // See tools/library/include/cutlass/library/handle.h for more details. 
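    // Illustrative addressing note (indices b, i, j are hypothetical): with the
    // column-major strides computed above, batch b of operand A stores its real
    // plane first and its imaginary plane imag_stride_A elements later, i.e.
    //
    //   real(A_b(i, j)) == ptr_A[b * batch_stride_A + j * lda + i]
    //   imag(A_b(i, j)) == ptr_A[b * batch_stride_A + imag_stride_A + j * lda + i]
    //
    // so each batch occupies 2 * M * K contiguous elements. The B, C, and D
    // operands follow the same convention with their respective strides.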
-+ // -+ -+ result.status = handle.gemm_planar_complex( -+ problem_size.m(), // GEMM M dimension -+ problem_size.n(), // GEMM N dimension -+ problem_size.k(), // GEMM K dimension -+ -+ cutlass::library::NumericTypeID::kF32, // Base data type of complex-valued accumulation -+ cutlass::library::NumericTypeID::kF32, // Base data type of complex-valued alpha/beta scalars -+ -+ &options.alpha, // Pointer to alpha scalar, of type complex -+ -+ cutlass::library::NumericTypeID::kF16, // Base data type of complex-valued A matrix -+ cutlass::library::LayoutTypeID::kColumnMajor, // Layout of A matrix -+ cutlass::library::ComplexTransform::kConjugate, // Complex transformation on A matrix operand -+ ptr_A, // Pointer to real part of A matrix -+ ptr_A + imag_stride_A, // Pointer to imaginary part of A matrix -+ lda, // Leading dimension of real part of A matrix -+ lda, // Leading dimension of imaginary part of A matrix -+ -+ cutlass::library::NumericTypeID::kF16, // Base data type of complex-valued B matrix -+ cutlass::library::LayoutTypeID::kColumnMajor, // Layout of B matrix -+ cutlass::library::ComplexTransform::kNone, // Complex transformation on B matrix operand -+ ptr_B, // Pointer to real part of B matrix -+ ptr_B + imag_stride_B, // Pointer to imaginary part of B matrix -+ ldb, // Leading dimension of real part of B matrix -+ ldb, // Leading dimension of imaginary part of B matrix -+ -+ &options.beta, // Pointer to beta scalar, of type complex -+ -+ cutlass::library::NumericTypeID::kF16, // Base data type of complex valued C and D matrices -+ -+ ptr_C, // Pointer to real part of C matrix -+ ptr_C + imag_stride_C, // Pointer to imaginary part of C matrix -+ ldc, // Leading dimension of real part of C matrix -+ ldc, // Leading dimension of imaginary part of C matrix -+ -+ ptr_D, // Pointer to real part of D matrix -+ ptr_D + imag_stride_D, // Pointer to imaginary part of D matrix -+ ldd, // Leading dimension of real part of D matrix -+ ldd, // Leading dimension of imaginary part of D matrix -+ -+ batch_count, // Number of batched elements -+ -+ batch_stride_A, // Stride between batches of real parts of A matrix -+ batch_stride_A, // Stride between batches of imaginary parts of A matrix -+ -+ batch_stride_B, // Stride between batches of real parts of B matrix -+ batch_stride_B, // Stride between batches of imaginary parts of B matrix -+ -+ batch_stride_C, // Stride between batches of real parts of C matrix -+ batch_stride_C, // Stride between batches of imaginary parts of C matrix -+ -+ batch_stride_D, // Stride between batches of real parts of D matrix -+ batch_stride_D // Stride between batches of imaginary parts of D matrix -+ ); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "CUTLASS internal error - configuration not supported" << std::endl; -+ return result; -+ } -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMMs are complete -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. 
-+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ if (handle.get_last_operation()) { -+ std::cout << "Recently executed '" << handle.get_last_operation()->description().name << "'" << std::endl; -+ } -+ -+ // -+ // Compute reference in device code -+ // -+ -+ if (options.reference_check) { -+ -+ result.passed = true; -+ -+ for (int64_t idx = 0; result.passed && idx < int64_t(batch_count); ++idx) { -+ cutlass::reference::device::GemmPlanarComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator -+ >( -+ problem_size, -+ options.alpha, -+ {tensor_A.get() + idx * batch_stride_A, lda, imag_stride_A}, -+ cutlass::ComplexTransform::kConjugate, -+ {tensor_B.get() + idx * batch_stride_B, ldb, imag_stride_B}, -+ cutlass::ComplexTransform::kNone, -+ options.beta, -+ {tensor_C.get() + idx * batch_stride_C, ldc, imag_stride_C}, -+ {tensor_D_ref.get() + idx * batch_stride_D, ldd, imag_stride_D} -+ ); -+ -+ ElementC epsilon = 0.1_hf; -+ ElementC nonzero_floor = 0.1_hf; -+ -+ result.passed = cutlass::reference::device::BlockCompareRelativelyEqual( -+ tensor_D.get() + idx * batch_stride_D, -+ tensor_D_ref.get() + idx * batch_stride_D, -+ batch_stride_D, -+ epsilon, -+ nonzero_floor -+ ); -+ } -+ -+ if (result.passed) { -+ std::cout << "Reference check passed." << std::endl; -+ } -+ else { -+ std::cerr << "Error - reference check failed." << std::endl; -+ } -+ } -+ -+ std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " GFLOPs: " << result.gflops << std::endl; -+ -+ return result; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ // -+ // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. -+ // -+ // Volta Tensor Core operations are first available in CUDA 10.1 Toolkit. -+ // -+ // Turing Tensor Core operations are first available in CUDA 10.2 Toolkit. -+ // -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (props.major < 7) { -+ std::cerr << "Volta Tensor Core operations must be run on a machine with compute capability at least 70." -+ << std::endl; -+ -+ // Returning zero so this test passes on older architectures even though its actions are no-op. -+ return 0; -+ } -+ else if (props.major == 7 && props.minor <= 2) { -+ // -+ // If running on the Volta architecture, at least CUDA 10.1 Toolkit is required to run this example. 
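    // Note: __CUDACC_VER_MAJOR__ and __CUDACC_VER_MINOR__ are predefined by nvcc
    // and describe the toolkit used to compile this file. For example, building
    // with CUDA 11.4 yields 11 and 4 respectively, so the version check below is
    // satisfied and the example proceeds.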
-+ // -+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) { -+ std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl; -+ -+ // Returning zero so this test passes on older Toolkits even though its actions are no-op. -+ return 0; -+ } -+ } -+ else if (props.major == 7 && props.minor >= 5) { -+ // -+ // If running on the Turing architecture, at least CUDA 10.2 Toolkit is required to run this example. -+ // -+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) { -+ std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl; -+ -+ // Returning zero so this test passes on older Toolkits even though its actions are no-op. -+ return 0; -+ } -+ } -+ else { -+ // NVIDIA Ampere Architecture GPUs (SM80 and later) are fully supported on CUDA 11 Toolkit and beyond. -+ // -+ // fall through -+ } -+ -+ // -+ // Parse options -+ // -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ TestbedPlanarComplex testbed(options); -+ -+ Result result = testbed.profile(options); -+ -+ return result.passed ? 0 : -1; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/examples/11_planar_complex_array/planar_complex_array.cu b/3rdparty/cutlass/examples/11_planar_complex_array/planar_complex_array.cu -new file mode 100644 -index 0000000..e317731 ---- /dev/null -+++ b/3rdparty/cutlass/examples/11_planar_complex_array/planar_complex_array.cu -@@ -0,0 +1,628 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Planar Complex Array Example -+ -+ This example demonstrates the CUTLASS Library's exposure of planar complex GEMM kernels which -+ execute a batch of matrix products, loading problem sizes and matrix base pointers from arrays -+ in global memory. -+ -+ These kernels represent complex matrices by storing the real and imaginary parts of the matrix in -+ disjoint regions in memory. These real-valued matrices are stored using existing cuBLAS layouts -+ as either column-major or row-major layouts with a single leading dimension indicating the stride -+ between columns or rows. -+ -+ The CUTLASS Library collects multiple template instantiations in a data structure and offers -+ a BLAS-like dispatch API to invoke the appropriate kernel on the Volta or Turing architectures. -+ -+ CUTLASS decouples matrix layout from complex transformation, so four possible transformations -+ are possible on the A and B operands: -+ -+ n: column-major -+ c: column-major complex conjugate -+ t: row-major -+ h: row-major complex conjugate -+ -+ To build strictly the planar complex kernels needed for general application, execute the following -+ CMake command in an empty build directory. -+ -+ $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" \ -+ -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_*gemm_planar_complex -+ -+ This builds all planar complex GEMM variants for Volta and Turing architectures. -+ -+ To build strictly the kernels needed for this example, an even narrower filter string may be -+ specified as follows. This only builds planar complex GEMMs targeting Tensor Cores for -+ the 'CN' layout configuration (conjugate A operand with both A and B as column-major). -+ -+ $ cmake .. 
-DCUTLASS_NVCC_ARCHS="70;75;80" \ -+ -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s*gemm_planar_complex_array_f16*cn -+ -+ $ make 11_planar_complex_array -+ -+ $ ./examples/11_planar_complex_array/11_planar_complex_array --m=2048 --n=1024 --k=512 --batch=10 -+*/ -+ -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/device_memory.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor_planar_complex.h" -+ -+#include "cutlass/util/reference/device/tensor_fill.h" -+ -+#include "cutlass/util/reference/device/gemm_planar_complex.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+ -+#include "cutlass/library/handle.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ // -+ // Methods -+ // -+ -+ Result( -+ double runtime_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess -+ ): -+ runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ int batch_count; -+ cutlass::complex alpha; -+ cutlass::complex beta; -+ -+ bool reference_check; -+ int iterations; -+ -+ Options(): -+ help(false), -+ problem_size({1024, 1024, 1024}), -+ batch_count(1), -+ reference_check(true), -+ iterations(20), -+ alpha(1), -+ beta() { } -+ -+ bool valid() { -+ return true; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("m", problem_size.m()); -+ cmd.get_cmd_line_argument("n", problem_size.n()); -+ cmd.get_cmd_line_argument("k", problem_size.k()); -+ cmd.get_cmd_line_argument("batch", batch_count); -+ -+ cmd.get_cmd_line_argument("alpha", alpha.real()); -+ cmd.get_cmd_line_argument("alpha_i", alpha.imag()); -+ cmd.get_cmd_line_argument("beta", beta.real()); -+ cmd.get_cmd_line_argument("beta_i", beta.imag()); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ } -+ -+ /// Prints the usage statement. 
-+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "11_planar_complex_array example\n\n" -+ << " This example uses the CUTLASS Library to execute Planar Complex Array GEMM computations.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --m= GEMM M dimension\n" -+ << " --n= GEMM N dimension\n" -+ << " --k= GEMM K dimension\n" -+ << " --batch= Number of GEMM operations executed in one batch\n" -+ << " --alpha= Epilogue scalar alpha (real part)\n" -+ << " --alpha_i= Epilogue scalar alpha (imaginary part)\n" -+ << " --beta= Epilogue scalar beta (real part)\n\n" -+ << " --beta_i= Epilogue scalar beta (imaginary part)\n\n" -+ << " --iterations= Number of profiling iterations to perform.\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/11_planar_complex_array/11_planar_complex_array\n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fmas = problem_size.product() * batch_count * 4; -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Performance test environment for planar complex -+class TestbedPlanarComplex { -+public: -+ -+ // Half-precision input and output -+ using Element = cutlass::half_t; -+ -+ // Configurations for layouts and internal computation -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementCompute = float; -+ using ElementAccumulator = float; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::library::Handle handle; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ int batch_count; -+ cutlass::DeviceAllocation tensor_A; -+ cutlass::DeviceAllocation tensor_B; -+ cutlass::DeviceAllocation tensor_C; -+ cutlass::DeviceAllocation tensor_D; -+ cutlass::DeviceAllocation tensor_D_ref; -+ -+ cutlass::DeviceAllocation ptr_A_real; -+ cutlass::DeviceAllocation ptr_A_imag; -+ cutlass::DeviceAllocation ptr_B_real; -+ cutlass::DeviceAllocation ptr_B_imag; -+ cutlass::DeviceAllocation ptr_C_real; -+ cutlass::DeviceAllocation ptr_C_imag; -+ cutlass::DeviceAllocation ptr_D_real; -+ cutlass::DeviceAllocation ptr_D_imag; -+ -+ // -+ // Methods -+ // -+ -+ TestbedPlanarComplex( -+ Options const &options -+ ): -+ problem_size(options.problem_size), batch_count(options.batch_count) { -+ -+ // Allocate device memory for batched planar complex GEMM -+ tensor_A.reset(int64_t(problem_size.m()) * problem_size.k() * batch_count * 2); -+ tensor_B.reset(int64_t(problem_size.k()) * problem_size.n() * batch_count * 2); -+ tensor_C.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); -+ tensor_D.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); -+ tensor_D_ref.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); -+ -+ ptr_A_real.reset(batch_count); -+ ptr_A_imag.reset(batch_count); -+ ptr_B_real.reset(batch_count); -+ ptr_B_imag.reset(batch_count); -+ ptr_C_real.reset(batch_count); -+ ptr_C_imag.reset(batch_count); -+ ptr_D_real.reset(batch_count); -+ ptr_D_imag.reset(batch_count); -+ -+ } -+ -+ void initialize() { -+ -+ uint64_t seed = 1073; -+ -+ // Use small integers to simplify correctness checking -+ int scope_max = 6; -+ int scope_min = -6; -+ -+ 
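    // Note: integers in [-6, 6] are exactly representable in cutlass::half_t, so
    // the random fills below introduce no rounding of the input values, which
    // keeps the relative-equality check later in profile() well conditioned.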
cutlass::reference::device::BlockFillRandomUniform( -+ tensor_A.get(), tensor_A.size(), seed, Element(scope_max), Element(scope_min), 0); -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ tensor_B.get(), tensor_B.size(), seed * 2019, Element(scope_max), Element(scope_min), 0); -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ tensor_C.get(), tensor_C.size(), seed * 2020, Element(scope_max), Element(scope_min), 0); -+ } -+ -+ Result profile(Options const &options) { -+ -+ Result result; -+ -+ initialize(); -+ -+ Element *ptr_A = tensor_A.get(); -+ Element *ptr_B = tensor_B.get(); -+ Element *ptr_C = tensor_C.get(); -+ Element *ptr_D = tensor_D.get(); -+ -+ int64_t batch_stride_A = int64_t(problem_size.m()) * problem_size.k() * 2; -+ int64_t batch_stride_B = int64_t(problem_size.k()) * problem_size.n() * 2; -+ int64_t batch_stride_C = int64_t(problem_size.m()) * problem_size.n() * 2; -+ int64_t batch_stride_D = int64_t(problem_size.m()) * problem_size.n() * 2; -+ -+ typename LayoutA::Stride::Index lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0); -+ typename LayoutB::Stride::Index ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0); -+ typename LayoutC::Stride::Index ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); -+ typename LayoutC::Stride::Index ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); -+ -+ -+ int64_t imag_stride_A = int64_t(problem_size.m()) * problem_size.k(); -+ int64_t imag_stride_B = int64_t(problem_size.k()) * problem_size.n(); -+ int64_t imag_stride_C = int64_t(problem_size.m()) * problem_size.n(); -+ int64_t imag_stride_D = int64_t(problem_size.m()) * problem_size.n(); -+ -+ // -+ // Configure pointers in global memory -+ // -+ -+ struct { -+ Element *base; -+ void **ptr_real; -+ void **ptr_imag; -+ int64_t batch_stride; -+ int64_t imag_stride; -+ } tensors[] = { -+ { tensor_A.get(), ptr_A_real.get(), ptr_A_imag.get(), batch_stride_A, imag_stride_A}, -+ { tensor_B.get(), ptr_B_real.get(), ptr_B_imag.get(), batch_stride_B, imag_stride_B}, -+ { tensor_C.get(), ptr_C_real.get(), ptr_C_imag.get(), batch_stride_C, imag_stride_C}, -+ { tensor_D.get(), ptr_D_real.get(), ptr_D_imag.get(), batch_stride_D, imag_stride_D} -+ }; -+ -+ for (auto const &tensor : tensors) { -+ for (int idx = 0; idx < batch_count; ++idx) { -+ -+ void *ptr_real = tensor.base + idx * tensor.batch_stride; -+ void *ptr_imag = tensor.base + idx * tensor.batch_stride + tensor.imag_stride; -+ -+ cudaError_t error = cudaMemcpy( -+ tensor.ptr_real + idx, -+ &ptr_real, -+ sizeof(void *), -+ cudaMemcpyHostToDevice); -+ -+ if (error != cudaSuccess) { -+ throw std::runtime_error("Failed to copy pointer to device memory"); -+ } -+ -+ error = cudaMemcpy( -+ tensor.ptr_imag + idx, -+ &ptr_imag, -+ sizeof(void *), -+ cudaMemcpyHostToDevice); -+ -+ if (error != cudaSuccess) { -+ throw std::runtime_error("Failed to copy pointer to device memory"); -+ } -+ } -+ } -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMM operations -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return 
result; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ -+ // -+ // Execute the planar complex array GEMM kernel via the CUTLASS Library's -+ // dispatch routines. -+ // -+ // Note, for planar complex array GEMM kernels, all numeric type arguments -+ // specify the data type of the base real types. These are understood to -+ // apply to planar complex representations of matrices in memory and to complex -+ // structures for scalars. -+ // -+ // See tools/library/include/cutlass/library/handle.h for more details. -+ // -+ -+ result.status = handle.gemm_planar_complex_array( -+ -+ problem_size.m(), // expected GEMM M dimension -+ problem_size.n(), // expected GEMM N dimension -+ problem_size.k(), // expected GEMM K dimension -+ batch_count, // Number of batched elements -+ -+ nullptr, -+ nullptr, -+ nullptr, -+ -+ cutlass::library::NumericTypeID::kF32, // Base data type of complex-valued accumulation -+ cutlass::library::NumericTypeID::kF32, // Base data type of complex-valued alpha/beta scalars -+ -+ &options.alpha, // Pointer to alpha scalar, of type complex -+ -+ cutlass::library::NumericTypeID::kF16, // Base data type of complex-valued A matrix -+ cutlass::library::LayoutTypeID::kColumnMajor, // Layout of A matrix -+ cutlass::library::ComplexTransform::kConjugate, // Complex transformation on A matrix operand -+ -+ ptr_A_real.get(), // Pointer to array of pointers to real part of A matrix -+ ptr_A_imag.get(), // Pointer to array of pointers to imaginary part of A matrix -+ -+ lda, // Leading dimension of real part of A matrix -+ lda, // Leading dimension of imaginary part of A matrix -+ -+ cutlass::library::NumericTypeID::kF16, // Base data type of complex-valued B matrix -+ cutlass::library::LayoutTypeID::kColumnMajor, // Layout of B matrix -+ cutlass::library::ComplexTransform::kNone, // Complex transformation on B matrix operand -+ -+ ptr_B_real.get(), // Pointer to array of pointers to real part of B matrix -+ ptr_B_imag.get(), // Pointer to array of pointers to imaginary part of B matrix -+ -+ ldb, // Leading dimension of real part of B matrix -+ ldb, // Leading dimension of imaginary part of B matrix -+ -+ &options.beta, // Pointer to beta scalar, of type complex -+ -+ cutlass::library::NumericTypeID::kF16, // Base data type of complex valued C and D matrices -+ -+ ptr_C_real.get(), // Pointer to array of pointers to real part of C matrix -+ ptr_C_imag.get(), // Pointer to array of pointers to imaginary part of C matrix -+ -+ ldc, // Leading dimension of real part of C matrix -+ ldc, // Leading dimension of imaginary part of C matrix -+ -+ ptr_D_real.get(), // Pointer to array of pointers to real part of D matrix -+ ptr_D_imag.get(), // Pointer to array of pointers to imaginary part of D matrix -+ -+ ldd, // Leading dimension of real part of D matrix -+ ldd // Leading dimension of imaginary part of D matrix -+ ); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "CUTLASS internal error - configuration not supported" << std::endl; -+ return result; -+ } -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMM operations have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. 
-+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ if (handle.get_last_operation()) { -+ std::cout << "Recently executed '" << handle.get_last_operation()->description().name << "'" << std::endl; -+ } -+ -+ // -+ // Compute reference in device code -+ // -+ -+ if (options.reference_check) { -+ -+ result.passed = true; -+ -+ for (int64_t idx = 0; result.passed && idx < int64_t(batch_count); ++idx) { -+ cutlass::reference::device::GemmPlanarComplex< -+ Element, LayoutA, -+ Element, LayoutB, -+ Element, LayoutC, -+ ElementAccumulator -+ >( -+ problem_size, -+ options.alpha, -+ {tensor_A.get() + idx * batch_stride_A, lda, imag_stride_A}, -+ cutlass::ComplexTransform::kConjugate, -+ {tensor_B.get() + idx * batch_stride_B, ldb, imag_stride_B}, -+ cutlass::ComplexTransform::kNone, -+ options.beta, -+ {tensor_C.get() + idx * batch_stride_C, ldc, imag_stride_C}, -+ {tensor_D_ref.get() + idx * batch_stride_D, ldd, imag_stride_D} -+ ); -+ -+ Element epsilon = 0.1_hf; -+ Element nonzero_floor = 0.1_hf; -+ -+ result.passed = cutlass::reference::device::BlockCompareRelativelyEqual( -+ tensor_D.get() + idx * batch_stride_D, -+ tensor_D_ref.get() + idx * batch_stride_D, -+ batch_stride_D, -+ epsilon, -+ nonzero_floor -+ ); -+ } -+ -+ if (result.passed) { -+ std::cout << "Reference check passed." << std::endl; -+ } -+ else { -+ std::cerr << "Error - reference check failed." << std::endl; -+ } -+ } -+ -+ std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " GFLOPs: " << result.gflops << std::endl; -+ -+ return result; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ // -+ // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. -+ // -+ // Volta Tensor Core operations are first available in CUDA 10.1 Toolkit. -+ // -+ // Turing Tensor Core operations are first available in CUDA 10.2 Toolkit. -+ // -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (props.major < 7) { -+ std::cerr << "Tensor Core operations must be run on a machine with compute capability at least 70." -+ << std::endl; -+ -+ // Returning zero so this passes on older architectures. Its actions are no-op. -+ return 0; -+ } -+ else if (props.major == 7 && props.minor <= 2) { -+ // -+ // If running on the Volta architecture, at least CUDA 10.1 Toolkit is required to run this example. 
-+ // -+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) { -+ std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl; -+ -+ // Returning zero so this passes on older Toolkits. Its actions are no-op. -+ return 0; -+ } -+ } -+ else if (props.major == 7 && props.minor >= 5) { -+ // -+ // If running on the Turing architecture, at least CUDA 10.2 Toolkit is required to run this example. -+ // -+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) { -+ std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl; -+ -+ // Returning zero so this passes on older Toolkits. Its actions are no-op. -+ return 0; -+ } -+ } -+ else { -+ // NVIDIA Ampere Architecture GPUs (SM80 and later) are fully supported on CUDA 11 Toolkit and beyond. -+ // -+ // fall through -+ } -+ -+ // -+ // Parse options -+ // -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ TestbedPlanarComplex testbed(options); -+ -+ Result result = testbed.profile(options); -+ -+ return result.passed ? 0 : -1; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/examples/12_gemm_bias_relu/gemm_bias_relu.cu b/3rdparty/cutlass/examples/12_gemm_bias_relu/gemm_bias_relu.cu -new file mode 100644 -index 0000000..418540f ---- /dev/null -+++ b/3rdparty/cutlass/examples/12_gemm_bias_relu/gemm_bias_relu.cu -@@ -0,0 +1,303 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+*/ -+ -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/epilogue/thread/linear_combination_relu.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "helper.h" -+ -+// The code section below describes datatype for input, output matrices and computation between -+// elements in input matrices. -+using ElementAccumulator = float; // <- data type of accumulator -+using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations -+using ElementInputA = cutlass::half_t; // <- data type of elements in input matrix A -+using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B -+using ElementOutput = float; // <- data type of elements in output matrix D -+ -+// Note that if the output is column major, the bias has to be per row. i.e. every row has different bias. -+// If the output is row major, the bias has to be per column, i.e. every column has different bias. -+// Below list some other notices: -+// -+// Note this example only works for ColumnMajor output because -+// 1) we only have row major epilogue. -+// 2) we swap A and B if the output is column major then we can still use the -+// row major epilogue. -+// 3) Mx1 bias vector becomes 1xM after the swapping/transposing. -+// 4) we can use the existing OutputIterator to load 1xM bias vector. -+ -+using LayoutInputA = cutlass::layout::ColumnMajor; -+using LayoutInputB = cutlass::layout::ColumnMajor; -+using LayoutOutput = cutlass::layout::ColumnMajor; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm75; -+ -+// This code section describes the tile size a thread block will compute -+using ShapeMMAThreadBlock = -+ cutlass::gemm::GemmShape<128, 128, 32>; // <- threadblock tile M = 128, N = 128, K = 32 -+// This code section describes tile size a warp will compute -+using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = 64, N = 64, K = 32 -+// This code section describes the size of MMA op -+using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 8, N = 8, K = 4 -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? -+ -+// Define the epilogue operation as LinearCombinationRelu. 
This is approximately equal to -+// -+// d_ij = max(0, alpha * sum_k(a_ik * b_kj) + c_ij ) -+// -+using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, // <- data type of output matrix -+ 128 / cutlass::sizeof_bits::value, // <- this is the number of elements per -+ // vectorized memory access. For half -+ // precision, it's 8 elements. This becomes -+ // the vector width of math instructions in -+ // epilogue too -+ ElementAccumulator, // <- data type of accumulator -+ ElementComputeEpilogue, // <- data type for alpha in linear combination function -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling>; // <- alpha x C + bias -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 2; -+ -+using Gemm = cutlass::gemm::device::Gemm; -+ -+int run() { -+ -+ const int length_m = 5120; -+ const int length_n = 4096; -+ const int length_k = 4096; -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k); -+ -+ // Initialize tensors using CUTLASS helper functions -+ cutlass::HostTensor tensor_a( -+ problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor tensor_b( -+ problem_size.kn()); // <- Create matrix B with dimensions K x N -+ -+ cutlass::HostTensor tensor_c_bias( -+ {problem_size.m(), 1}); // <- Create matrix C with dimensions M x 1 -+ -+ cutlass::HostTensor tensor_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // CUTLASS kernel -+ cutlass::HostTensor tensor_ref_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // reference kernel -+ -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(4), -+ ElementInputA(-4), -+ 0); // <- Fill matrix A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(4), -+ ElementInputB(-4), -+ 0); // <- Fill matrix B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c_bias.host_view(), -+ 1, -+ ElementOutput(4), -+ ElementOutput(-4), -+ 0); // <- Fill matrix C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); // <- fill matrix D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c_bias.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // Initialize alpha for dot product computation -+ ElementComputeEpilogue alpha = ElementComputeEpilogue(1); -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm::Arguments arguments{ -+ problem_size, // <- problem size of matrix multiplication -+ tensor_a.device_ref(), // <- reference to matrix A on device -+ tensor_b.device_ref(), // <- reference to matrix B on device -+ -+ {tensor_c_bias.device_data(), 0}, // <- the C matrix is treated as the bias vector. We can enable the GEMM -+ // to project away the N dimension by setting the stride to zero. 
-+ -+ tensor_d.device_ref(), // <- reference to matrix D on device -+ {alpha}, // <- alpha -+ split_k_slices}; // <- k-dimension split factor -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm gemm_op; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status = gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(status); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status = gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(status); -+ -+ // Launch initialized CUTLASS kernel -+ status = gemm_op(); -+ CUTLASS_CHECK(status); -+ -+ // -+ // Create instantiation for device reference gemm kernel -+ // -+ -+ cutlass::reference::device::Gemm -+ gemm_device_reference; -+ -+ // Launch device reference to compute strictly the product A * B -+ gemm_device_reference( -+ problem_size, -+ alpha, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ 0, -+ tensor_ref_d.device_ref()); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ // Compute bias + relu in host code -+ for (int i = 0; i < problem_size.m(); ++i) { -+ for (int j = 0; j < problem_size.n(); ++j) { -+ tensor_ref_d.at({i, j}) = std::max( -+ ElementOutput(0), -+ ElementOutput(tensor_ref_d.at({i, j}) + tensor_c_bias.at({i, 0})) -+ ); -+ } -+ } -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ std::cout << (cutlass::reference::host::TensorEquals(tensor_d.host_view(), -+ tensor_ref_d.host_view()) -+ ? "Passed" -+ : "Failed") -+ << std::endl; -+ -+ CUTLASS_CHECK(status); -+ return 0; -+} -+ -+int main() { -+ -+ bool notSupported = false; -+ -+ // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2. -+ // -+ // CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) { -+ std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (!(props.major * 10 + props.minor >= 75)) { -+ std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. -+ return 0; -+ } -+ -+ return run(); -+} -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h -new file mode 100644 -index 0000000..9da0e66 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h -@@ -0,0 +1,719 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
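// ----------------------------------------------------------------------------------------------
// Editor's note on the gemm_bias_relu example above (annotation, not part of the original patch).
// The epilogue is described as d_ij = max(0, alpha * sum_k(a_ik * b_kj) + bias_i): beta scaling is
// disabled (ScaleType::NoBetaScaling) and the M x 1 bias column is broadcast across the N
// dimension of the column-major output by passing it as "C" with a stride of zero. A minimal host
// sketch of that computation on plain float arrays follows; the function name and the explicit
// column-major indexing are illustrative assumptions, not code taken from the patch.

#include <algorithm>
#include <vector>

// D (column-major, M x N) = ReLU(alpha * A (M x K) * B (K x N) + bias (length M, one per row))
inline void reference_gemm_bias_relu(int M, int N, int K, float alpha,
                                     const std::vector<float> &A,
                                     const std::vector<float> &B,
                                     const std::vector<float> &bias,
                                     std::vector<float> &D) {
  for (int j = 0; j < N; ++j) {
    for (int i = 0; i < M; ++i) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[i + k * M] * B[k + j * K];   // column-major element (i,k) and (k,j)
      }
      D[i + j * M] = std::max(0.0f, alpha * acc + bias[i]);  // per-row bias, then ReLU
    }
  }
}
// ----------------------------------------------------------------------------------------------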
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+#include "cutlass/reduction/device/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+ -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/reference/device/tensor_relu.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "reference/device/tensor_scale_bias.h" -+#include "helper.h" -+ -+#define CHECK_GT(val1, val2) \ -+ if((val1) <= (val2)) \ -+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; -+#define CHECK_TRUE(val) \ -+ if(!(val)) \ -+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; -+ -+ -+template -+class B2bNonFusedConv2dRun { -+public: -+ -+ using Conv2d0 = Conv2d0_; -+ using Conv2d1 = Conv2d1_; -+ using ElementAccumulator = typename Conv2d0::ElementAccumulator; -+ using ElementCompute = typename Conv2d0::ElementCompute; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = Conv2d0::kConvolutionalOperator; -+ static_assert(kConvolutionalOperator == Conv2d1::kConvolutionalOperator, -+ "Fused convolution operators must be the same"); -+ -+public: -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_Bias; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A0; -+ 
cutlass::HostTensor tensor_B0; -+ cutlass::HostTensor tensor_C0; -+ cutlass::HostTensor tensor_Bias0; -+ cutlass::HostTensor tensor_D0_computed; -+ cutlass::HostTensor tensor_D0_reference; -+ -+ cutlass::HostTensor tensor_B1; -+ cutlass::HostTensor tensor_C1; -+ cutlass::HostTensor tensor_Bias1; -+ cutlass::HostTensor tensor_D1_computed; -+ cutlass::HostTensor tensor_D1_reference; -+ -+ -+public: -+ -+ B2bNonFusedConv2dRun( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { -+ -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ int scope; -+ int bits = cutlass::sizeof_bits::value; -+ -+ if (bits <= 16) { -+ scope = 2; -+ } -+ else { -+ scope = 8; -+ } -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope, -scope, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); -+ } -+ else if (dist_kind == cutlass::Distribution::AllZeros) { -+ cutlass::reference::host::TensorFill(view, Element(0)); -+ } -+ else if (dist_kind == cutlass::Distribution::AllOnes) { -+ cutlass::reference::host::TensorFill(view, Element(1)); -+ } -+ else { -+ std::cerr << "Not implemented\n"; -+ } -+ } -+ -+ void initialize( -+ cutlass::conv::Conv2dProblemSize const &problem_size_0, -+ cutlass::conv::Conv2dProblemSize const &problem_size_1, -+ uint64_t seed = 2019) { -+ -+ tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_Bias0.resize({1, 1, 1, problem_size_0.K}); -+ tensor_D0_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_Bias1.resize({1, 1, 1, problem_size_1.K}); -+ tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ -+ initialize_tensor(tensor_A0.host_view(), init_A, seed); -+ initialize_tensor(tensor_B0.host_view(), init_B, seed * 17); -+ initialize_tensor(tensor_C0.host_view(), init_C, seed * 39); -+ initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83); -+ initialize_tensor(tensor_B1.host_view(), init_B, seed * 18); -+ initialize_tensor(tensor_C1.host_view(), init_C, seed 
* 40); -+ initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84); -+ -+ tensor_A0.sync_device(); -+ tensor_B0.sync_device(); -+ tensor_C0.sync_device(); -+ tensor_Bias0.sync_device(); -+ tensor_D0_computed.sync_device(); -+ tensor_D0_reference.sync_device(); -+ tensor_B1.sync_device(); -+ tensor_C1.sync_device(); -+ tensor_Bias1.sync_device(); -+ tensor_D1_computed.sync_device(); -+ tensor_D1_reference.sync_device(); -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::conv::Conv2dProblemSize const &problem_size_0, -+ cutlass::conv::Conv2dProblemSize const &problem_size_1, -+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, -+ ElementCompute alpha0 = ElementCompute(1), -+ ElementCompute beta0 = ElementCompute(0), -+ ElementCompute alpha1 = ElementCompute(1), -+ ElementCompute beta1 = ElementCompute(0), -+ bool relu = true, -+ int warm_ups = 1, -+ int runs = 100) { -+ -+ initialize(problem_size_0, problem_size_1); -+ -+ // configure the operator -+ Conv2d0 conv2d_op_0; -+ Conv2d1 conv2d_op_1; -+ -+ typename Conv2d0::Arguments conv2d_args_0( -+ problem_size_0, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ {tensor_Bias0.device_data(), typename Conv2d0::LayoutC::Stride(0)}, -+ tensor_D0_computed.device_ref(), -+ {alpha0, beta0}, -+ split_k_mode -+ ); -+ typename Conv2d1::Arguments conv2d_args_1( -+ problem_size_1, -+ tensor_D0_computed.device_ref(), -+ tensor_B1.device_ref(), -+ {tensor_Bias1.device_data(), typename Conv2d1::LayoutC::Stride(0)}, -+ tensor_D1_computed.device_ref(), -+ {alpha1, beta1}, -+ split_k_mode -+ ); -+ -+ -+ cutlass::Status status = conv2d_op_0.initialize(conv2d_args_0); -+ -+ CUTLASS_CHECK(status); -+ -+ status = conv2d_op_1.initialize(conv2d_args_1); -+ -+ CUTLASS_CHECK(status); -+ -+ for(int i = 0; i < warm_ups; i++) { -+ status = conv2d_op_0(); -+ CUTLASS_CHECK(status); -+ status = conv2d_op_1(); -+ CUTLASS_CHECK(status); -+ } -+ -+ // -+ // Run Conv2d -+ // -+ cudaEvent_t start, stop1, stop2; -+ cudaEventCreate(&start); -+ cudaEventCreate(&stop1); -+ cudaEventCreate(&stop2); -+ -+ cudaEventRecord(start); -+ -+ -+ for(int i = 0; i < runs; i++) { -+ // run conv2d operator -+ status = conv2d_op_0(); -+ CUTLASS_CHECK(status); -+ } -+ cudaEventRecord(stop1); -+ -+ for(int i = 0; i < runs; i++) { -+ // run conv2d operator -+ status = conv2d_op_1(); -+ CUTLASS_CHECK(status); -+ } -+ cudaEventRecord(stop2); -+ cudaDeviceSynchronize(); -+ float conv2d0Time, conv2d1Time, totalTime; -+ cudaEventElapsedTime(&conv2d0Time, start, stop1); -+ cudaEventElapsedTime(&conv2d1Time, stop1, stop2); -+ cudaEventElapsedTime(&totalTime, start, stop2); -+ std::cout << "conv2d 0 time " << conv2d0Time / (float)runs << " ms\n"; -+ std::cout << "conv2d 1 time " << conv2d1Time / (float)runs << " ms\n"; -+ std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n"; -+ -+ tensor_D0_computed.sync_host(); -+ tensor_D1_computed.sync_host(); -+ -+ bool passed = false; -+ -+ cutlass::reference::device::Conv2d< -+ typename Conv2d0::ElementA, -+ typename Conv2d0::LayoutA, -+ typename Conv2d0::ElementB, -+ typename Conv2d0::LayoutB, -+ typename Conv2d0::ElementC, -+ typename Conv2d0::LayoutC, -+ ElementCompute, -+ ElementAccumulator -+ >( -+ kConvolutionalOperator, -+ problem_size_0, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ {tensor_Bias0.device_data(), typename Conv2d0::LayoutC::Stride(0)}, -+ tensor_D0_reference.device_ref(), -+ alpha0, -+ beta0); -+ -+ if(relu) { -+ 
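// ----------------------------------------------------------------------------------------------
// Editor's note (annotation, not part of the original patch): in this non-fused baseline the
// output of the first convolution, tensor_D0_computed, is written to global memory and then passed
// back in as the activation tensor of conv2d_args_1, so the intermediate makes a full round trip
// through device memory. The three CUDA events bracket the two launch loops: (stop1 - start) / runs
// and (stop2 - stop1) / runs give the average per-launch time of each convolution, and
// (stop2 - start) / runs is the combined "Non-fusion time" that the fused B2bConv2d runner later in
// this file is compared against.
// ----------------------------------------------------------------------------------------------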
cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view()); -+ } -+ -+ cutlass::reference::device::Conv2d< -+ typename Conv2d1::ElementA, -+ typename Conv2d1::LayoutA, -+ typename Conv2d1::ElementB, -+ typename Conv2d1::LayoutB, -+ typename Conv2d1::ElementC, -+ typename Conv2d1::LayoutC, -+ ElementCompute, -+ ElementAccumulator -+ >( -+ kConvolutionalOperator, -+ problem_size_1, -+ tensor_D0_reference.device_ref(), -+ tensor_B1.device_ref(), -+ {tensor_Bias1.device_data(), typename Conv2d1::LayoutC::Stride(0)}, -+ tensor_D1_reference.device_ref(), -+ alpha1, -+ beta1); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view()); -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ CHECK_TRUE(result == cudaSuccess); -+ -+ // sync host (copy device data to host) for dumping error output in case of mismatches -+ tensor_D0_reference.sync_host(); -+ tensor_D1_reference.sync_host(); -+ -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_computed.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0); -+ -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_D1_computed.host_view(), -+ tensor_D1_reference.host_view()); -+ -+ CHECK_TRUE(passed); -+ -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_B2bImplicitGemm_device_nonfused.txt"; -+ std::cerr << "Dumping results in " << fname.str() << "\n"; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size_0 << std::endl; -+ results << problem_size_1 << std::endl; -+ -+ results -+ << "\nA0:\n" << tensor_A0.host_view() << "\n" -+ << "\nB0:\n" << tensor_B0.host_view() << "\n" -+ << "\nC0:\n" << tensor_C0.host_view() << "\n" -+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" -+ << "\nD0 reference:\n" << tensor_D0_reference.host_view() << "\n" -+ << "\nD0 computed:\n" << tensor_D0_computed.host_view() << "\n" -+ << "\nB1:\n" << tensor_B1.host_view() << "\n" -+ << "\nC1:\n" << tensor_C1.host_view() << "\n" -+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" -+ << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n" -+ << "\nD1 computed:\n" << tensor_D1_computed.host_view(); -+ -+ -+ } -+ -+ return passed; -+ } -+ -+}; -+ -+template -+class B2bFusedConv2dRun { -+public: -+ -+ using B2bConv2d = B2bConv2d_; -+ using ElementAccumulator = typename B2bConv2d::ElementAccumulator; -+ using ElementCompute = typename B2bConv2d::ElementCompute; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = B2bConv2d::kConvolutionalOperator; -+ -+public: -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_Scale; -+ cutlass::Distribution::Kind init_Bias; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A0; -+ cutlass::HostTensor tensor_B0; -+ cutlass::HostTensor tensor_C0; -+ cutlass::HostTensor tensor_Scale0; -+ cutlass::HostTensor tensor_Bias0; -+ cutlass::HostTensor tensor_Z0_reference; -+ cutlass::HostTensor tensor_D0_reference; -+ -+ cutlass::HostTensor tensor_B1; -+ cutlass::HostTensor tensor_C1; -+ cutlass::HostTensor tensor_Bias1; -+ cutlass::HostTensor tensor_D1_computed; -+ cutlass::HostTensor tensor_D1_reference; -+ -+ -+public: -+ -+ B2bFusedConv2dRun( -+ cutlass::Distribution::Kind init_A_ 
= cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), -+ init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { -+ -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ int scope; -+ int bits = cutlass::sizeof_bits::value; -+ -+ if (bits <= 16) { -+ scope = 2; -+ } -+ else { -+ scope = 8; -+ } -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope, -scope, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); -+ } -+ else if (dist_kind == cutlass::Distribution::AllZeros) { -+ cutlass::reference::host::TensorFill(view, Element(0)); -+ } -+ else if (dist_kind == cutlass::Distribution::AllOnes) { -+ cutlass::reference::host::TensorFill(view, Element(1)); -+ } -+ else { -+ } -+ } -+ -+ void initialize( -+ cutlass::conv::Conv2dProblemSize const &problem_size_0, -+ cutlass::conv::Conv2dProblemSize const &problem_size_1, -+ ElementCompute alpha0, -+ ElementCompute alpha1, -+ uint64_t seed = 2019) { -+ -+ tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ tensor_Scale0.resize({1, problem_size_0.K}); -+ tensor_Bias0.resize({1, problem_size_0.K}); -+ tensor_Z0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_Bias1.resize({1, 1, 1, problem_size_1.K}); -+ tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ -+ initialize_tensor(tensor_A0.host_view(), init_A, seed); -+ initialize_tensor(tensor_B0.host_view(), init_B, seed * 17); -+ initialize_tensor(tensor_C0.host_view(), init_C, seed * 39); -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed * 61); -+ initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83); -+ initialize_tensor(tensor_B1.host_view(), init_B, seed * 18); -+ initialize_tensor(tensor_C1.host_view(), init_C, seed * 40); -+ initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84); -+ -+ tensor_A0.sync_device(); -+ 
tensor_B0.sync_device(); -+ tensor_C0.sync_device(); -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ tensor_Scale0.sync_device(); -+ tensor_Bias0.sync_device(); -+ tensor_D0_reference.sync_device(); -+ tensor_B1.sync_device(); -+ tensor_C1.sync_device(); -+ tensor_Bias1.sync_device(); -+ tensor_D1_computed.sync_device(); -+ tensor_D1_reference.sync_device(); -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::conv::Conv2dProblemSize const &problem_size_0, -+ cutlass::conv::Conv2dProblemSize const &problem_size_1, -+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, -+ ElementCompute alpha0 = ElementCompute(1), -+ ElementCompute beta0 = ElementCompute(0), -+ ElementCompute alpha1 = ElementCompute(1), -+ ElementCompute beta1 = ElementCompute(0), -+ bool relu = true, -+ int warm_ups = 1, -+ int runs = 100) { -+ -+ initialize(problem_size_0, problem_size_1, alpha0, alpha1); -+ -+ // configure the operator -+ B2bConv2d b2b_conv2d_op; -+ -+ typename B2bConv2d::Arguments b2b_conv2d_args( -+ problem_size_0, -+ problem_size_1, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ tensor_C0.device_ref(), -+ tensor_Scale0.device_ref(), -+ tensor_Bias0.device_ref(), -+ tensor_B1.device_ref(), -+ {tensor_Bias1.device_data(), typename B2bConv2d::LayoutC::Stride(0)}, -+ tensor_D1_computed.device_ref(), -+ {alpha0, beta0}, -+ {alpha1, beta1}, -+ split_k_mode -+ ); -+ -+ cutlass::Status status = b2b_conv2d_op.can_implement(b2b_conv2d_args); -+ -+ if(status != cutlass::Status::kSuccess) { -+ std::cout << "Problem sizes not supported.\n" -+ << "Requirments:\n" -+ << " problem_size_0.N*P*Q = problem_size_1.N*P*Q\n" -+ << " problem_size_0.K = problem_size_1.C\n" -+ << " problem_size_1.R = problem_size_1.S = 1\n" -+ << " ThreadblockShape0::kN = problem_size_0.K\n" -+ << " ThreadblockShape1::kN = problem_size_1.K" << std::endl; -+ } -+ -+ CUTLASS_CHECK(status); -+ -+ status = b2b_conv2d_op.initialize(b2b_conv2d_args); -+ -+ CUTLASS_CHECK(status); -+ -+ for(int i = 0; i < warm_ups; i++) { -+ status = b2b_conv2d_op(); -+ CUTLASS_CHECK(status); -+ } -+ -+ // -+ // Run the Conv2d -+ // -+ -+ cudaEvent_t start, stop; -+ cudaEventCreate(&start); -+ cudaEventCreate(&stop); -+ -+ cudaEventRecord(start); -+ -+ for(int i = 0; i < runs; i++) { -+ -+ // run conv2d operator -+ status = b2b_conv2d_op(); -+ CUTLASS_CHECK(status); -+ } -+ -+ cudaEventRecord(stop); -+ cudaDeviceSynchronize(); -+ float conv2dTime; -+ cudaEventElapsedTime(&conv2dTime, start, stop); -+ std::cout << "Fusion time " << conv2dTime / (float)runs << " ms\n"; -+ -+ tensor_D1_computed.sync_host(); -+ -+ bool passed = false; -+ -+ cutlass::reference::device::Conv2d< -+ typename B2bConv2d::ElementA, -+ typename B2bConv2d::LayoutA, -+ typename B2bConv2d::ElementB, -+ typename B2bConv2d::LayoutB, -+ ElementAccumulator, -+ typename B2bConv2d::LayoutC, -+ ElementAccumulator, -+ ElementAccumulator -+ >( -+ kConvolutionalOperator, -+ problem_size_0, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ tensor_Z0_reference.device_ref(), -+ tensor_Z0_reference.device_ref(), -+ ElementAccumulator(1), // intermediate alpha = 1 -+ ElementAccumulator(0) // beta = 0 -+ ); -+ -+ cutlass::reference::device::TensorScaleBiasConv2d< -+ ElementAccumulator, -+ typename B2bConv2d::ElementC, -+ typename B2bConv2d::LayoutC, -+ ElementCompute, -+ typename B2bConv2d::LayoutScaleBias -+ >( -+ problem_size_0, -+ tensor_Z0_reference.device_ref(), -+ tensor_D0_reference.device_ref(), -+ alpha0, -+ tensor_Scale0.device_ref(), -+ 
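// ----------------------------------------------------------------------------------------------
// Editor's note (annotation, not part of the original patch): in the fused runner, alpha0 == 0 is
// used as a flag meaning "per-channel scale": only then is the 1 x K0 tensor_Scale0 allocated,
// initialized, and copied to the device. The reference path above first computes the raw conv0
// accumulators into tensor_Z0_reference (alpha = 1, beta = 0), and this TensorScaleBiasConv2d call
// then applies, per output channel c,
//     D0[n, p, q, c] = Scale0[c] * Z0[n, p, q, c] + Bias0[c]
// (with the scalar alpha0 evidently taking the place of Scale0 when no per-channel scale is
// allocated), and ReLU is applied afterwards when relu is set -- mirroring the first-stage epilogue
// that the fused kernel produces before the second convolution consumes it.
// ----------------------------------------------------------------------------------------------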
tensor_Bias0.device_ref() -+ ); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view()); -+ } -+ -+ cutlass::reference::device::Conv2d< -+ typename B2bConv2d::ElementA, -+ typename B2bConv2d::LayoutA, -+ typename B2bConv2d::ElementB, -+ typename B2bConv2d::LayoutB, -+ typename B2bConv2d::ElementC, -+ typename B2bConv2d::LayoutC, -+ ElementCompute, -+ ElementAccumulator -+ >( -+ kConvolutionalOperator, -+ problem_size_1, -+ tensor_D0_reference.device_ref(), -+ tensor_B1.device_ref(), -+ {tensor_Bias1.device_data(), typename B2bConv2d::LayoutC::Stride(0)}, -+ tensor_D1_reference.device_ref(), -+ alpha1, -+ beta1); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view()); -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ CHECK_TRUE(result == cudaSuccess); -+ -+ // sync host (copy device data to host) for dumping error output in case of mismatches -+ tensor_D0_reference.sync_host(); -+ tensor_D1_reference.sync_host(); -+ -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0); -+ -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_D1_computed.host_view(), -+ tensor_D1_reference.host_view()); -+ -+ CHECK_TRUE(passed); -+ -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_B2bImplicitGemm_device_fused.txt"; -+ std::cerr << "Dumping results in " << fname.str() << "\n"; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size_0 << std::endl; -+ results << problem_size_1 << std::endl; -+ -+ results -+ << "\nA0:\n" << tensor_A0.host_view() << "\n" -+ << "\nB0:\n" << tensor_B0.host_view() << "\n" -+ << "\nC0:\n" << tensor_C0.host_view() << "\n" -+ << "\nScale0:\n" << tensor_Scale0.host_view() << "\n" -+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" -+ << "\nB1:\n" << tensor_B1.host_view() << "\n" -+ << "\nC1:\n" << tensor_C1.host_view() << "\n" -+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" -+ << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n" -+ << "\nD1 computed:\n" << tensor_D1_computed.host_view(); -+ -+ -+ } -+ -+ return passed; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_gemm_run.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_gemm_run.h -new file mode 100644 -index 0000000..b8b080c ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_gemm_run.h -@@ -0,0 +1,714 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/device/tensor_relu.h" -+ -+#include "reference/device/tensor_scale_bias.h" -+#include "helper.h" -+ -+#define CHECK_GT(val1, val2) \ -+ if((val1) <= (val2)) \ -+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; -+#define CHECK_TRUE(val) \ -+ if(!(val)) \ -+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct B2bNonFusedGemmRun -+{ -+ -+ using Gemm0 = Gemm0_; -+ using Gemm1 = Gemm1_; -+ using ElementAccumulator = typename Gemm0::ElementAccumulator; -+ using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_Bias; -+ uint64_t seed; -+ -+ // -+ // Methods -+ // -+ -+ B2bNonFusedGemmRun( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, 2, -2, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ 
else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else if (dist_kind == cutlass::Distribution::AllZeros) { -+ cutlass::reference::host::TensorFill(view, Element(0)); -+ } -+ else if (dist_kind == cutlass::Distribution::AllOnes) { -+ cutlass::reference::host::TensorFill(view, Element(1)); -+ } -+ else { -+ std::cerr << "Not implemented\n"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size_0, -+ cutlass::gemm::GemmCoord problem_size_1, -+ ElementCompute alpha0 = ElementCompute(1), -+ ElementCompute beta0 = ElementCompute(0), -+ ElementCompute alpha1 = ElementCompute(1), -+ ElementCompute beta1 = ElementCompute(0), -+ bool relu = true, -+ int warm_ups = 1, -+ int runs = 100) { -+ -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementA, -+ typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementB, -+ typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementC, -+ typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ ElementCompute, -+ typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()}); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementC, -+ typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementC, -+ typename Gemm0::LayoutC> reference_D0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementB, -+ typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementC, -+ typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn()); -+ -+ cutlass::HostTensor< -+ ElementCompute, -+ typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()}); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementC, -+ typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementC, -+ typename Gemm1::LayoutC> reference_D1(problem_size_1.mn()); -+ -+ -+ CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); -+ CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); -+ CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014)); -+ CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); -+ CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013)); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D0.host_view()); -+ cutlass::reference::host::TensorFill( -+ tensor_D1.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D0.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D1.host_view()); -+ -+ tensor_A0.sync_device(); -+ tensor_B0.sync_device(); -+ tensor_C0.sync_device(); -+ tensor_Bias0.sync_device(); -+ tensor_D0.sync_device(); -+ tensor_B1.sync_device(); -+ tensor_C1.sync_device(); -+ tensor_Bias1.sync_device(); -+ tensor_D1.sync_device(); -+ reference_D0.sync_device(); -+ reference_D1.sync_device(); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm0::Arguments arguments_0{ -+ problem_size_0, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ 
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, -+ tensor_D0.device_ref(), -+ {alpha0, beta0} -+ }; -+ -+ typename Gemm1::Arguments arguments_1{ -+ problem_size_1, -+ tensor_D0.device_ref(), -+ tensor_B1.device_ref(), -+ {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, -+ tensor_D1.device_ref(), -+ {alpha1, beta1} -+ }; -+ -+ -+ Gemm0 gemm_op_0; -+ Gemm1 gemm_op_1; -+ -+ cutlass::Status status = gemm_op_0.initialize(arguments_0); -+ -+ CUTLASS_CHECK(status); -+ -+ status = gemm_op_1.initialize(arguments_1); -+ -+ CUTLASS_CHECK(status); -+ -+ for(int i = 0; i < warm_ups; i++) { -+ status = gemm_op_0(); -+ CUTLASS_CHECK(status); -+ status = gemm_op_1(); -+ CUTLASS_CHECK(status); -+ } -+ -+ // -+ // Run the GEMM -+ // -+ cudaEvent_t start, stop1, stop2; -+ cudaEventCreate(&start); -+ cudaEventCreate(&stop1); -+ cudaEventCreate(&stop2); -+ -+ cudaEventRecord(start); -+ -+ for(int i = 0; i < runs; i++) { -+ status = gemm_op_0(); -+ -+ CUTLASS_CHECK(status); -+ } -+ cudaEventRecord(stop1); -+ for(int i = 0; i < runs; i++) { -+ status = gemm_op_1(); -+ -+ CUTLASS_CHECK(status); -+ } -+ -+ cudaEventRecord(stop2); -+ cudaDeviceSynchronize(); -+ float gemm0Time, gemm1Time, totalTime; -+ cudaEventElapsedTime(&gemm0Time, start, stop1); -+ cudaEventElapsedTime(&gemm1Time, stop1, stop2); -+ cudaEventElapsedTime(&totalTime, start, stop2); -+ std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n"; -+ std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n"; -+ std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n"; -+ -+ tensor_D0.sync_host(); -+ tensor_D1.sync_host(); -+ -+ // -+ // Verify -+ // -+ cutlass::reference::device::Gemm< -+ typename Gemm0::ElementA, typename Gemm0::LayoutA, -+ typename Gemm0::ElementB, typename Gemm0::LayoutB, -+ typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute, -+ ElementAccumulator, typename Gemm0::Operator> -+ reference_gemm_0; -+ -+ cutlass::reference::device::Gemm< -+ typename Gemm1::ElementA, typename Gemm1::LayoutA, -+ typename Gemm1::ElementB, typename Gemm1::LayoutB, -+ typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute, -+ ElementAccumulator, typename Gemm1::Operator> -+ reference_gemm_1; -+ -+ reference_gemm_0( -+ problem_size_0, -+ alpha0, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ beta0, -+ {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, -+ reference_D0.device_ref() -+ ); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D0.device_view()); -+ } -+ -+ reference_gemm_1( -+ problem_size_1, -+ alpha1, -+ reference_D0.device_ref(), -+ tensor_B1.device_ref(), -+ beta1, -+ {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, -+ reference_D1.device_ref() -+ ); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D1.device_view()); -+ } -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ reference_D0.sync_host(); -+ reference_D1.sync_host(); -+ -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ reference_D1.host_view(), -+ tensor_D1.host_view()); -+ -+ CHECK_TRUE(passed); -+ if (!passed) { -+ -+ std::stringstream fname; -+ -+ fname << 
"error_B2bGemm_device_nonfused.txt"; -+ std::cerr << "Dumping results in " << fname.str() << "\n"; -+ -+ std::ofstream file(fname.str()); -+ -+ file -+ << "A0 =\n" << tensor_A0.host_view() -+ << "\nB0 =\n" << tensor_B0.host_view() -+ << "\nC0 =\n" << tensor_C0.host_view() -+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" -+ << "\nD0 =\n" << tensor_D0.host_view() -+ << "\nB1 =\n" << tensor_B1.host_view() -+ << "\nC1 =\n" << tensor_C1.host_view() -+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" -+ << "\n\nReference =\n" << reference_D1.host_view() -+ << "\nComputed =\n" << tensor_D1.host_view(); -+ } -+ return passed; -+ } -+}; -+ -+template -+struct B2bFusedGemmRun -+{ -+ -+ using B2bGemm = B2bGemm_; -+ using ElementAccumulator = typename B2bGemm::ElementAccumulator; -+ using ElementCompute = typename B2bGemm::B2bGemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_Scale; -+ cutlass::Distribution::Kind init_Bias; -+ uint64_t seed; -+ -+ // -+ // Methods -+ // -+ -+ B2bFusedGemmRun( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), -+ init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, 2, -2, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else if (dist_kind == cutlass::Distribution::AllZeros) { -+ cutlass::reference::host::TensorFill(view, Element(0)); -+ } -+ else if (dist_kind == cutlass::Distribution::AllOnes) { -+ cutlass::reference::host::TensorFill(view, Element(1)); -+ } -+ else { -+ std::cerr << "Not implemented\n"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size_0, -+ cutlass::gemm::GemmCoord problem_size_1, -+ ElementCompute alpha0 = ElementCompute(1), -+ ElementCompute beta0 = ElementCompute(0), -+ ElementCompute alpha1 = ElementCompute(1), -+ ElementCompute beta1 = ElementCompute(0), -+ bool relu = true, -+ int warm_ups = 1, -+ int runs = 100) { -+ -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementA, -+ typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementB, -+ typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementC, -+ typename B2bGemm::LayoutC> 
tensor_C0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementScaleBias, -+ typename B2bGemm::LayoutScaleBias> tensor_Scale0; -+ -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ tensor_Scale0.resize({1, problem_size_0.n()}); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementScaleBias, -+ typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()}); -+ -+ cutlass::HostTensor< -+ ElementAccumulator, -+ typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementC, -+ typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementB, -+ typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementC, -+ typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn()); -+ -+ cutlass::HostTensor< -+ ElementCompute, -+ typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()}); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementC, -+ typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementC, -+ typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn()); -+ -+ -+ CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); -+ CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); -+ CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ CHECK_TRUE(initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed + 2014)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2013)); -+ CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); -+ CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012)); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D1.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D0.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D1.host_view()); -+ -+ tensor_A0.sync_device(); -+ tensor_B0.sync_device(); -+ tensor_C0.sync_device(); -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ tensor_Scale0.sync_device(); -+ tensor_Bias0.sync_device(); -+ tensor_B1.sync_device(); -+ tensor_C1.sync_device(); -+ tensor_Bias1.sync_device(); -+ tensor_D1.sync_device(); -+ reference_D0.sync_device(); -+ reference_D1.sync_device(); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename B2bGemm::Arguments arguments{ -+ problem_size_0, -+ problem_size_1, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ tensor_C0.device_ref(), -+ tensor_Scale0.device_ref(), -+ tensor_Bias0.device_ref(), -+ tensor_B1.device_ref(), -+ {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, -+ tensor_D1.device_ref(), -+ {alpha0, beta0}, -+ {alpha1, beta1}, -+ }; -+ -+ B2bGemm b2b_gemm_op; -+ -+ cutlass::Status status = b2b_gemm_op.can_implement(arguments); -+ -+ if(status != cutlass::Status::kSuccess) { -+ std::cout << "Problem sizes not supported.\n" -+ << "Requirments:\n" -+ << " problem_size_0.M = problem_size_1.M\n" -+ << " problem_size_0.N = problem_size_1.K\n" -+ << " ThreadblockShape0::kN = problem_size_0.N\n" -+ << " ThreadblockShape1::kN = problem_size_1.N" << std::endl; -+ } -+ -+ status = b2b_gemm_op.initialize(arguments); -+ -+ CUTLASS_CHECK(status); -+ -+ for(int i = 0; i < warm_ups; i++) { -+ status 
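// ----------------------------------------------------------------------------------------------
// Editor's note (annotation, not part of the original patch): the can_implement() message above
// lists the shape constraints for fusing two GEMMs -- both GEMMs share M, the first GEMM's N must
// equal the second GEMM's K, and each threadblock tile must span the full N of its GEMM so the
// intermediate can stay resident within a threadblock instead of spilling to global memory. A
// purely illustrative pair of sizes that satisfies all four constraints (not taken from the patch):
//     cutlass::gemm::GemmCoord problem_size_0(1024, 64, 256);   // M = 1024, N0 = 64,  K0 = 256
//     cutlass::gemm::GemmCoord problem_size_1(1024, 128, 64);   // M = 1024, N1 = 128, K1 = N0
//     // requires ThreadblockShape0::kN == 64 and ThreadblockShape1::kN == 128
// ----------------------------------------------------------------------------------------------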
= b2b_gemm_op(); -+ CUTLASS_CHECK(status); -+ } -+ -+ // -+ // Run the GEMM -+ // -+ -+ cudaEvent_t start, stop; -+ cudaEventCreate(&start); -+ cudaEventCreate(&stop); -+ -+ cudaEventRecord(start); -+ -+ for(int i = 0; i < runs; i++) { -+ status = b2b_gemm_op(); -+ -+ CUTLASS_CHECK(status); -+ } -+ -+ cudaEventRecord(stop); -+ cudaDeviceSynchronize(); -+ float gemmTime; -+ cudaEventElapsedTime(&gemmTime, start, stop); -+ std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n"; -+ -+ tensor_D1.sync_host(); -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::device::Gemm< -+ typename B2bGemm::ElementA, typename B2bGemm::LayoutA, -+ typename B2bGemm::ElementB, typename B2bGemm::LayoutB, -+ ElementAccumulator, typename B2bGemm::LayoutC, -+ ElementAccumulator, ElementAccumulator> -+ reference_gemm_0; -+ -+ cutlass::reference::device::Gemm< -+ typename B2bGemm::ElementA, typename B2bGemm::LayoutA, -+ typename B2bGemm::ElementB, typename B2bGemm::LayoutB, -+ typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute, -+ ElementAccumulator, typename B2bGemm::Operator> -+ reference_gemm_1; -+ -+ reference_gemm_0( -+ problem_size_0, -+ ElementAccumulator(1), //intermediate alpha=1 -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ ElementAccumulator(0), //beta = 0 -+ reference_Z0.device_ref(), -+ reference_Z0.device_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ cutlass::reference::device::TensorScaleBiasGemm< -+ ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC, -+ ElementCompute, typename B2bGemm::LayoutScaleBias -+ > ( -+ problem_size_0, -+ reference_Z0.device_ref(), -+ reference_D0.device_ref(), -+ alpha0, -+ tensor_Scale0.device_ref(), -+ tensor_Bias0.device_ref() -+ ); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D0.device_view()); -+ } -+ -+ reference_gemm_1( -+ problem_size_1, -+ alpha1, -+ reference_D0.device_ref(), -+ tensor_B1.device_ref(), -+ beta1, -+ {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, -+ reference_D1.device_ref() -+ ); -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D1.device_view()); -+ } -+ cudaDeviceSynchronize(); -+ reference_D0.sync_host(); -+ reference_D1.sync_host(); -+ -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ reference_D1.host_view(), -+ tensor_D1.host_view()); -+ -+ CHECK_TRUE(passed); -+ if (!passed) -+ { -+ -+ std::stringstream fname; -+ -+ fname << "error_B2bGemm_device_fused.txt"; -+ std::cerr << "Dumping results in " << fname.str() << "\n"; -+ -+ std::ofstream file(fname.str()); -+ -+ file -+ << "A0 =\n" << tensor_A0.host_view() -+ << "\nB0 =\n" << tensor_B0.host_view() -+ << "\nC0 =\n" << tensor_C0.host_view() -+ << "\nScale0:\n" << tensor_Scale0.host_view() << "\n" -+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" -+ << "\nB1 =\n" << tensor_B1.host_view() -+ << "\nC1 =\n" << tensor_C1.host_view() -+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" -+ << "\n\nReference =\n" << reference_D1.host_view() -+ << "\nComputed =\n" << tensor_D1.host_view(); -+ } -+ return passed; -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h 
b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h -new file mode 100644 -index 0000000..a6d0625 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h -@@ -0,0 +1,749 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+#include "cutlass/reduction/device/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/host_reorder.h" -+ -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/reference/device/tensor_relu.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "reference/device/tensor_scale_bias.h" -+#include "helper.h" -+ -+#define CHECK_GT(val1, val2) \ -+ if((val1) <= (val2)) \ -+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; -+#define CHECK_TRUE(val) \ -+ if(!(val)) \ -+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; -+ -+ -+template -+class B2bInterleavedNonFusedConv2dRun { -+public: -+ -+ using Conv2d0 = Conv2d0_; -+ using Conv2d1 = Conv2d1_; -+ using ElementAccumulator = typename Conv2d0::ElementAccumulator; -+ using ElementCompute = typename Conv2d0::ElementCompute; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = Conv2d0::kConvolutionalOperator; -+ static_assert(kConvolutionalOperator == Conv2d1::kConvolutionalOperator, -+ "Fused convolution operators must be the same"); -+ -+public: -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_Bias; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A0; -+ cutlass::HostTensor tensor_B0; -+ cutlass::HostTensor tensor_B0_reordered; -+ cutlass::HostTensor tensor_C0; -+ cutlass::HostTensor tensor_Bias0; -+ cutlass::HostTensor tensor_D0_computed; -+ cutlass::HostTensor tensor_D0_reference; -+ -+ cutlass::HostTensor tensor_B1; -+ cutlass::HostTensor tensor_B1_reordered; -+ cutlass::HostTensor tensor_C1; -+ cutlass::HostTensor tensor_Bias1; -+ cutlass::HostTensor tensor_D1_computed; -+ cutlass::HostTensor tensor_D1_reference; -+ -+ -+public: -+ -+ B2bInterleavedNonFusedConv2dRun( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { -+ -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ int scope; -+ int bits = cutlass::sizeof_bits::value; -+ -+ if (bits <= 16) { -+ scope = 2; -+ } -+ else { -+ scope = 8; -+ } -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope, -scope, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind 
== cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); -+ } -+ else if (dist_kind == cutlass::Distribution::AllZeros) { -+ cutlass::reference::host::TensorFill(view, Element(0)); -+ } -+ else if (dist_kind == cutlass::Distribution::AllOnes) { -+ cutlass::reference::host::TensorFill(view, Element(1)); -+ } -+ else { -+ } -+ } -+ -+ void initialize( -+ cutlass::conv::Conv2dProblemSize const &problem_size_0, -+ cutlass::conv::Conv2dProblemSize const &problem_size_1, uint64_t seed = 2019) { -+ -+ tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_B0_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_Bias0.resize({1, 1, 1, problem_size_0.K}); -+ tensor_D0_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_B1_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_Bias1.resize({1, 1, 1, problem_size_1.K}); -+ tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ -+ initialize_tensor(tensor_A0.host_view(), init_A, seed); -+ initialize_tensor(tensor_B0.host_view(), init_B, seed * 17); -+ initialize_tensor(tensor_C0.host_view(), init_C, seed * 39); -+ initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83); -+ initialize_tensor(tensor_B1.host_view(), init_B, seed * 18); -+ initialize_tensor(tensor_C1.host_view(), init_C, seed * 40); -+ -+ //Reorder B0 and B1 -+ cutlass::reorder_convK( -+ tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_0)); -+ cutlass::reorder_convK( -+ tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_1)); -+ -+ tensor_A0.sync_device(); -+ tensor_B0.sync_device(); -+ tensor_B0_reordered.sync_device(); -+ tensor_C0.sync_device(); -+ tensor_Bias0.sync_device(); -+ tensor_D0_computed.sync_device(); -+ tensor_D0_reference.sync_device(); -+ tensor_B1.sync_device(); -+ tensor_B1_reordered.sync_device(); -+ tensor_C1.sync_device(); -+ tensor_Bias1.sync_device(); -+ tensor_D1_computed.sync_device(); -+ tensor_D1_reference.sync_device(); -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::conv::Conv2dProblemSize const &problem_size_0, -+ cutlass::conv::Conv2dProblemSize const &problem_size_1, -+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, -+ ElementCompute alpha0 = ElementCompute(1), -+ ElementCompute beta0 = ElementCompute(0), -+ ElementCompute alpha1 = ElementCompute(1), -+ ElementCompute beta1 = ElementCompute(0), -+ bool relu = true, -+ int warm_ups = 1, -+ int runs = 100) { -+ -+ 
initialize(problem_size_0, problem_size_1); -+ -+ // configure the operator -+ Conv2d0 conv2d_op_0; -+ Conv2d1 conv2d_op_1; -+ -+ typename Conv2d0::Arguments conv2d_args_0( -+ problem_size_0, -+ tensor_A0.device_ref(), -+ tensor_B0_reordered.device_ref(), -+ tensor_C0.device_ref(), -+ tensor_D0_computed.device_ref(), -+ {alpha0, beta0}, -+ split_k_mode -+ ); -+ typename Conv2d1::Arguments conv2d_args_1( -+ problem_size_1, -+ tensor_D0_computed.device_ref(), -+ tensor_B1_reordered.device_ref(), -+ tensor_C1.device_ref(), -+ tensor_D1_computed.device_ref(), -+ {alpha1, beta1}, -+ split_k_mode -+ ); -+ -+ -+ cutlass::Status status = conv2d_op_0.initialize(conv2d_args_0); -+ -+ CUTLASS_CHECK(status); -+ -+ status = conv2d_op_1.initialize(conv2d_args_1); -+ -+ CUTLASS_CHECK(status); -+ -+ for(int i = 0; i < warm_ups; i++) { -+ status = conv2d_op_0(); -+ CUTLASS_CHECK(status); -+ status = conv2d_op_1(); -+ CUTLASS_CHECK(status); -+ } -+ -+ // -+ // Run Conv2d -+ // -+ cudaEvent_t start, stop1, stop2; -+ cudaEventCreate(&start); -+ cudaEventCreate(&stop1); -+ cudaEventCreate(&stop2); -+ -+ cudaEventRecord(start); -+ -+ -+ for(int i = 0; i < runs; i++) { -+ // run conv2d operator -+ status = conv2d_op_0(); -+ CUTLASS_CHECK(status); -+ } -+ cudaEventRecord(stop1); -+ -+ for(int i = 0; i < runs; i++) { -+ // run conv2d operator -+ status = conv2d_op_1(); -+ CUTLASS_CHECK(status); -+ } -+ cudaEventRecord(stop2); -+ cudaDeviceSynchronize(); -+ float conv2d0Time, conv2d1Time, totalTime; -+ cudaEventElapsedTime(&conv2d0Time, start, stop1); -+ cudaEventElapsedTime(&conv2d1Time, stop1, stop2); -+ cudaEventElapsedTime(&totalTime, start, stop2); -+ std::cout << "conv2d 0 time " << conv2d0Time / (float)runs << " ms\n"; -+ std::cout << "conv2d 1 time " << conv2d1Time / (float)runs << " ms\n"; -+ std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n"; -+ -+ tensor_D0_computed.sync_host(); -+ tensor_D1_computed.sync_host(); -+ -+ bool passed = false; -+ -+ cutlass::reference::device::Conv2d< -+ typename Conv2d0::ElementA, -+ typename Conv2d0::LayoutA, -+ typename Conv2d0::ElementB, -+ typename Conv2d0::LayoutB, -+ typename Conv2d0::ElementC, -+ typename Conv2d0::LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ cutlass::NumericConverterClamp -+ >( -+ kConvolutionalOperator, -+ problem_size_0, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ tensor_C0.device_ref(), -+ tensor_D0_reference.device_ref(), -+ alpha0, -+ beta0); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view()); -+ } -+ -+ cutlass::reference::device::Conv2d< -+ typename Conv2d1::ElementA, -+ typename Conv2d1::LayoutA, -+ typename Conv2d1::ElementB, -+ typename Conv2d1::LayoutB, -+ typename Conv2d1::ElementC, -+ typename Conv2d1::LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ cutlass::NumericConverterClamp -+ >( -+ kConvolutionalOperator, -+ problem_size_1, -+ tensor_D0_reference.device_ref(), -+ tensor_B1.device_ref(), -+ tensor_C1.device_ref(), -+ tensor_D1_reference.device_ref(), -+ alpha1, -+ beta1); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view()); -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ CHECK_TRUE(result == cudaSuccess); -+ -+ // sync host (copy device data to host) for dumping error output in case of mismatches -+ tensor_D0_reference.sync_host(); -+ tensor_D1_reference.sync_host(); -+ -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_computed.host_view()), 0); -+ 
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0); -+ -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_D1_computed.host_view(), -+ tensor_D1_reference.host_view()); -+ -+ CHECK_TRUE(passed); -+ -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_B2bImplicitGemm_device_interleaved_nonfused.txt"; -+ std::cerr << "Dumping results in " << fname.str() << "\n"; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size_0 << std::endl; -+ results << problem_size_1 << std::endl; -+ -+ results -+ << "\nA0:\n" << tensor_A0.host_view() << "\n" -+ << "\nB0:\n" << tensor_B0.host_view() << "\n" -+ << "\nB0_reordered:\n" << tensor_B0_reordered.host_view() << "\n" -+ << "\nC0:\n" << tensor_C0.host_view() << "\n" -+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" -+ << "\nD0 reference:\n" << tensor_D0_reference.host_view() << "\n" -+ << "\nD0 computed:\n" << tensor_D0_computed.host_view() << "\n" -+ << "\nB1:\n" << tensor_B1.host_view() << "\n" -+ << "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n" -+ << "\nC1:\n" << tensor_C1.host_view() << "\n" -+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" -+ << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n" -+ << "\nD1 computed:\n" << tensor_D1_computed.host_view(); -+ -+ -+ } -+ -+ return passed; -+ } -+ -+}; -+ -+template -+class B2bInterleavedFusedConv2dRun { -+public: -+ -+ using B2bConv2d = B2bConv2d_; -+ using ElementAccumulator = typename B2bConv2d::ElementAccumulator; -+ using ElementCompute = typename B2bConv2d::ElementCompute; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = B2bConv2d::kConvolutionalOperator; -+ -+public: -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_Scale; -+ cutlass::Distribution::Kind init_Bias; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A0; -+ cutlass::HostTensor tensor_B0; -+ cutlass::HostTensor tensor_B0_reordered; -+ cutlass::HostTensor tensor_C0; -+ cutlass::HostTensor tensor_Scale0; -+ cutlass::HostTensor tensor_Bias0; -+ cutlass::HostTensor tensor_Z0_reference; -+ cutlass::HostTensor tensor_D0_reference; -+ -+ cutlass::HostTensor tensor_B1; -+ cutlass::HostTensor tensor_B1_reordered; -+ cutlass::HostTensor tensor_C1; -+ cutlass::HostTensor tensor_Bias1; -+ cutlass::HostTensor tensor_D1_computed; -+ cutlass::HostTensor tensor_D1_reference; -+ -+ -+public: -+ -+ B2bInterleavedFusedConv2dRun( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), -+ init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { -+ -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ int scope; -+ int bits = 
cutlass::sizeof_bits::value; -+ -+ if (bits <= 16) { -+ scope = 2; -+ } -+ else { -+ scope = 8; -+ } -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope, -scope, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); -+ } -+ else if (dist_kind == cutlass::Distribution::AllZeros) { -+ cutlass::reference::host::TensorFill(view, Element(0)); -+ } -+ else if (dist_kind == cutlass::Distribution::AllOnes) { -+ cutlass::reference::host::TensorFill(view, Element(1)); -+ } -+ else { -+ } -+ } -+ -+ void initialize( -+ cutlass::conv::Conv2dProblemSize const &problem_size_0, -+ cutlass::conv::Conv2dProblemSize const &problem_size_1, -+ ElementCompute alpha0, -+ ElementCompute alpha1, -+ uint64_t seed = 2019) { -+ -+ tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_B0_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ tensor_Scale0.resize({1, problem_size_0.K}); -+ tensor_Bias0.resize({1, problem_size_0.K}); -+ tensor_Z0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_B1_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_Bias1.resize({1, 1, 1, problem_size_1.K}); -+ tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ -+ initialize_tensor(tensor_A0.host_view(), init_A, seed); -+ initialize_tensor(tensor_B0.host_view(), init_B, seed * 17); -+ initialize_tensor(tensor_C0.host_view(), init_C, seed * 39); -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed * 61); -+ initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83); -+ initialize_tensor(tensor_B1.host_view(), init_B, seed * 18); -+ initialize_tensor(tensor_C1.host_view(), init_C, seed * 40); -+ initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84); -+ -+ //Reorder B0 and B1 -+ cutlass::reorder_convK<16, InterleavedK>( -+ tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_0)); -+ cutlass::reorder_convK( -+ tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_1)); -+ -+ tensor_A0.sync_device(); -+ tensor_B0.sync_device(); -+ tensor_B0_reordered.sync_device(); -+ tensor_C0.sync_device(); -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ tensor_Scale0.sync_device(); -+ tensor_Bias0.sync_device(); -+ 
tensor_D0_reference.sync_device(); -+ tensor_B1.sync_device(); -+ tensor_B1_reordered.sync_device(); -+ tensor_C1.sync_device(); -+ tensor_Bias1.sync_device(); -+ tensor_D1_computed.sync_device(); -+ tensor_D1_reference.sync_device(); -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::conv::Conv2dProblemSize const &problem_size_0, -+ cutlass::conv::Conv2dProblemSize const &problem_size_1, -+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, -+ ElementCompute alpha0 = ElementCompute(1), -+ ElementCompute beta0 = ElementCompute(0), -+ ElementCompute alpha1 = ElementCompute(1), -+ ElementCompute beta1 = ElementCompute(0), -+ bool relu = true, -+ int warm_ups = 1, -+ int runs = 100) { -+ -+ initialize(problem_size_0, problem_size_1, alpha0, alpha1); -+ -+ // configure the operator -+ B2bConv2d b2b_conv2d_op; -+ -+ typename B2bConv2d::Arguments b2b_conv2d_args( -+ problem_size_0, -+ problem_size_1, -+ tensor_A0.device_ref(), -+ tensor_B0_reordered.device_ref(), -+ tensor_C0.device_ref(), -+ tensor_Scale0.device_ref(), -+ tensor_Bias0.device_ref(), -+ tensor_B1_reordered.device_ref(), -+ tensor_C1.device_ref(), -+ tensor_D1_computed.device_ref(), -+ {alpha0, beta0}, -+ {alpha1, beta1}, -+ split_k_mode -+ ); -+ -+ cutlass::Status status = b2b_conv2d_op.can_implement(b2b_conv2d_args); -+ -+ if(status != cutlass::Status::kSuccess) { -+ std::cout << "Problem sizes not supported.\n" -+ << "Requirments:\n" -+ << " problem_size_0.N*P*Q = problem_size_1.N*P*Q\n" -+ << " problem_size_0.K = problem_size_1.C\n" -+ << " problem_size_1.R = problem_size_1.S = 1\n" -+ << " ThreadblockShape0::kN = problem_size_0.K\n" -+ << " ThreadblockShape1::kN = problem_size_1.K" << std::endl; -+ } -+ -+ CUTLASS_CHECK(status); -+ -+ status = b2b_conv2d_op.initialize(b2b_conv2d_args); -+ -+ CUTLASS_CHECK(status); -+ -+ for(int i = 0; i < warm_ups; i++) { -+ status = b2b_conv2d_op(); -+ CUTLASS_CHECK(status); -+ } -+ -+ // -+ // Run the Conv2d -+ // -+ -+ cudaEvent_t start, stop; -+ cudaEventCreate(&start); -+ cudaEventCreate(&stop); -+ -+ cudaEventRecord(start); -+ -+ for(int i = 0; i < runs; i++) { -+ -+ // run conv2d operator -+ status = b2b_conv2d_op(); -+ CUTLASS_CHECK(status); -+ } -+ -+ cudaEventRecord(stop); -+ cudaDeviceSynchronize(); -+ float conv2dTime; -+ cudaEventElapsedTime(&conv2dTime, start, stop); -+ std::cout << "Fusion time " << conv2dTime / (float)runs << " ms\n"; -+ -+ tensor_D1_computed.sync_host(); -+ -+ bool passed = false; -+ -+ cutlass::reference::device::Conv2d< -+ typename B2bConv2d::ElementA, -+ typename B2bConv2d::LayoutA, -+ typename B2bConv2d::ElementB, -+ typename B2bConv2d::LayoutB, -+ ElementAccumulator, -+ typename B2bConv2d::LayoutC, -+ ElementAccumulator, -+ ElementAccumulator -+ >( -+ kConvolutionalOperator, -+ problem_size_0, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ tensor_Z0_reference.device_ref(), -+ tensor_Z0_reference.device_ref(), -+ ElementAccumulator(1), // intermediate alpha = 1 -+ ElementAccumulator(0) // beta = 0 -+ ); -+ -+ cutlass::reference::device::TensorScaleBiasConv2d< -+ ElementAccumulator, -+ typename B2bConv2d::ElementC, -+ typename B2bConv2d::LayoutC, -+ ElementCompute, -+ typename B2bConv2d::LayoutScaleBias, -+ cutlass::NumericConverterClamp -+ >( -+ problem_size_0, -+ tensor_Z0_reference.device_ref(), -+ tensor_D0_reference.device_ref(), -+ alpha0, -+ tensor_Scale0.device_ref(), -+ tensor_Bias0.device_ref() -+ ); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view()); -+ 
} -+ -+ cutlass::reference::device::Conv2d< -+ typename B2bConv2d::ElementA, -+ typename B2bConv2d::LayoutA, -+ typename B2bConv2d::ElementB, -+ typename B2bConv2d::LayoutB, -+ typename B2bConv2d::ElementC, -+ typename B2bConv2d::LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ cutlass::NumericConverterClamp -+ >( -+ kConvolutionalOperator, -+ problem_size_1, -+ tensor_D0_reference.device_ref(), -+ tensor_B1.device_ref(), -+ tensor_C1.device_ref(), -+ tensor_D1_reference.device_ref(), -+ alpha1, -+ beta1); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view()); -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ CHECK_TRUE(result == cudaSuccess); -+ -+ // sync host (copy device data to host) for dumping error output in case of mismatches -+ tensor_D0_reference.sync_host(); -+ tensor_D1_reference.sync_host(); -+ -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0); -+ -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_D1_computed.host_view(), -+ tensor_D1_reference.host_view()); -+ -+ CHECK_TRUE(passed); -+ -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_B2bImplicitGemm_device_interleaved_fused.txt"; -+ std::cerr << "Dumping results in " << fname.str() << "\n"; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size_0 << std::endl; -+ results << problem_size_1 << std::endl; -+ -+ results -+ << "\nA0:\n" << tensor_A0.host_view() << "\n" -+ << "\nB0:\n" << tensor_B0.host_view() << "\n" -+ << "\nB0_reordered:\n" << tensor_B0_reordered.host_view() << "\n" -+ << "\nC0:\n" << tensor_C0.host_view() << "\n" -+ << "\nScale0:\n" << tensor_Scale0.host_view() << "\n" -+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" -+ << "\nB1:\n" << tensor_B1.host_view() << "\n" -+ << "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n" -+ << "\nC1:\n" << tensor_C1.host_view() << "\n" -+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" -+ << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n" -+ << "\nD1 computed:\n" << tensor_D1_computed.host_view(); -+ -+ -+ } -+ -+ return passed; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h -new file mode 100644 -index 0000000..51ff1bb ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h -@@ -0,0 +1,749 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/host_reorder.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/device/tensor_relu.h" -+ -+#include "reference/device/tensor_scale_bias.h" -+#include "helper.h" -+ -+#define CHECK_GT(val1, val2) \ -+ if((val1) <= (val2)) \ -+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; -+#define CHECK_TRUE(val) \ -+ if(!(val)) \ -+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; -+ -+template -+struct B2bInterleavedNonFusedGemmRun -+{ -+ -+ using Gemm0 = Gemm0_; -+ using Gemm1 = Gemm1_; -+ using ElementAccumulator = typename Gemm0::ElementAccumulator; -+ using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_Bias; -+ uint64_t seed; -+ -+ // -+ // Methods -+ // -+ -+ B2bInterleavedNonFusedGemmRun( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, 2, -2, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == 
cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else if (dist_kind == cutlass::Distribution::AllZeros) { -+ cutlass::reference::host::TensorFill(view, Element(0)); -+ } -+ else if (dist_kind == cutlass::Distribution::AllOnes) { -+ cutlass::reference::host::TensorFill(view, Element(1)); -+ } -+ else { -+ std::cerr << "Not implemented\n"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size_0, -+ cutlass::gemm::GemmCoord problem_size_1, -+ ElementCompute alpha0 = ElementCompute(1), -+ ElementCompute beta0 = ElementCompute(0), -+ ElementCompute alpha1 = ElementCompute(1), -+ ElementCompute beta1 = ElementCompute(0), -+ bool relu = true, -+ int warm_ups = 1, -+ int runs = 100) { -+ -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementA, -+ typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementB, -+ typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementB, -+ typename Gemm0::LayoutB> tensor_B0_reordered(problem_size_0.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementC, -+ typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementC, -+ typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()}); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementC, -+ typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementC, -+ typename Gemm0::LayoutC> reference_D0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementB, -+ typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementB, -+ typename Gemm1::LayoutB> tensor_B1_reordered(problem_size_1.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementC, -+ typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementC, -+ typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()}); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementC, -+ typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementC, -+ typename Gemm1::LayoutC> reference_D1(problem_size_1.mn()); -+ -+ -+ CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); -+ CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); -+ CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014)); -+ CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); -+ CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013)); -+ -+ //Reorder B0 and B1 -+ cutlass::reorder_column( -+ tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0); -+ cutlass::reorder_column( -+ tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D0.host_view()); -+ cutlass::reference::host::TensorFill( -+ tensor_D1.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D0.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D1.host_view()); -+ -+ tensor_A0.sync_device(); 
-+ tensor_B0.sync_device(); -+ tensor_B0_reordered.sync_device(); -+ tensor_C0.sync_device(); -+ tensor_Bias0.sync_device(); -+ tensor_D0.sync_device(); -+ tensor_B1.sync_device(); -+ tensor_B1_reordered.sync_device(); -+ tensor_C1.sync_device(); -+ tensor_Bias1.sync_device(); -+ tensor_D1.sync_device(); -+ reference_D0.sync_device(); -+ reference_D1.sync_device(); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm0::Arguments arguments_0{ -+ problem_size_0, -+ tensor_A0.device_ref(), -+ tensor_B0_reordered.device_ref(), -+ {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, -+ tensor_D0.device_ref(), -+ {alpha0, beta0} -+ }; -+ -+ typename Gemm1::Arguments arguments_1{ -+ problem_size_1, -+ tensor_D0.device_ref(), -+ tensor_B1_reordered.device_ref(), -+ {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, -+ tensor_D1.device_ref(), -+ {alpha1, beta1} -+ }; -+ -+ -+ Gemm0 gemm_op_0; -+ Gemm1 gemm_op_1; -+ -+ cutlass::Status status = gemm_op_0.initialize(arguments_0); -+ -+ CUTLASS_CHECK(status); -+ -+ status = gemm_op_1.initialize(arguments_1); -+ -+ CUTLASS_CHECK(status); -+ -+ for(int i = 0; i < warm_ups; i++) { -+ status = gemm_op_0(); -+ CUTLASS_CHECK(status); -+ status = gemm_op_1(); -+ CUTLASS_CHECK(status); -+ } -+ -+ // -+ // Run the GEMM -+ // -+ cudaEvent_t start, stop1, stop2; -+ cudaEventCreate(&start); -+ cudaEventCreate(&stop1); -+ cudaEventCreate(&stop2); -+ -+ cudaEventRecord(start); -+ -+ for(int i = 0; i < runs; i++) { -+ status = gemm_op_0(); -+ -+ CUTLASS_CHECK(status); -+ } -+ cudaEventRecord(stop1); -+ for(int i = 0; i < runs; i++) { -+ status = gemm_op_1(); -+ -+ CUTLASS_CHECK(status); -+ } -+ -+ cudaEventRecord(stop2); -+ cudaDeviceSynchronize(); -+ float gemm0Time, gemm1Time, totalTime; -+ cudaEventElapsedTime(&gemm0Time, start, stop1); -+ cudaEventElapsedTime(&gemm1Time, stop1, stop2); -+ cudaEventElapsedTime(&totalTime, start, stop2); -+ std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n"; -+ std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n"; -+ std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n"; -+ -+ tensor_D0.sync_host(); -+ tensor_D1.sync_host(); -+ -+ // -+ // Verify -+ // -+ cutlass::reference::device::Gemm< -+ typename Gemm0::ElementA, typename Gemm0::LayoutA, -+ typename Gemm0::ElementB, typename Gemm0::LayoutB, -+ typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute, -+ ElementAccumulator, typename Gemm0::Operator> -+ reference_gemm_0; -+ -+ cutlass::reference::device::Gemm< -+ typename Gemm1::ElementA, typename Gemm1::LayoutA, -+ typename Gemm1::ElementB, typename Gemm1::LayoutB, -+ typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute, -+ ElementAccumulator, typename Gemm1::Operator> -+ reference_gemm_1; -+ -+ reference_gemm_0( -+ problem_size_0, -+ alpha0, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ beta0, -+ {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, -+ reference_D0.device_ref() -+ ); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D0.device_view()); -+ } -+ -+ reference_gemm_1( -+ problem_size_1, -+ alpha1, -+ reference_D0.device_ref(), -+ tensor_B1.device_ref(), -+ beta1, -+ {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, -+ reference_D1.device_ref() -+ ); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D1.device_view()); -+ } -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ reference_D0.sync_host(); -+ 
reference_D1.sync_host(); -+ -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ reference_D1.host_view(), -+ tensor_D1.host_view()); -+ -+ CHECK_TRUE(passed); -+ if (!passed) { -+ -+ std::stringstream fname; -+ -+ fname << "error_B2bGemm_device_interleaved_nonfused.txt"; -+ std::cerr << "Dumping results in " << fname.str() << "\n"; -+ -+ std::ofstream file(fname.str()); -+ -+ file -+ << "A0 =\n" << tensor_A0.host_view() -+ << "\nB0 =\n" << tensor_B0.host_view() -+ << "\nB0_reordered =\n" << tensor_B0_reordered.host_view() -+ << "\nC0 =\n" << tensor_C0.host_view() -+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" -+ << "\nD0 =\n" << tensor_D0.host_view() -+ << "\nB1 =\n" << tensor_B1.host_view() -+ << "\nB1_reordered =\n" << tensor_B1_reordered.host_view() -+ << "\nC1 =\n" << tensor_C1.host_view() -+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" -+ << "\n\nReference =\n" << reference_D1.host_view() -+ << "\nComputed =\n" << tensor_D1.host_view(); -+ } -+ return passed; -+ } -+}; -+ -+template -+struct B2bInterleavedFusedGemmRun -+{ -+ -+ using B2bGemm = B2bGemm_; -+ using ElementAccumulator = typename B2bGemm::ElementAccumulator; -+ using ElementCompute = typename B2bGemm::B2bGemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_Scale; -+ cutlass::Distribution::Kind init_Bias; -+ uint64_t seed; -+ -+ // -+ // Methods -+ // -+ -+ B2bInterleavedFusedGemmRun( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), -+ init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, 2, -2, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else if (dist_kind == cutlass::Distribution::AllZeros) { -+ cutlass::reference::host::TensorFill(view, Element(0)); -+ } -+ else if (dist_kind == cutlass::Distribution::AllOnes) { -+ cutlass::reference::host::TensorFill(view, Element(1)); -+ } -+ else { -+ std::cerr << "Not implemented\n"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ -+ -+ /// Executes one test -+ bool run( -+ 
cutlass::gemm::GemmCoord problem_size_0, -+ cutlass::gemm::GemmCoord problem_size_1, -+ ElementCompute alpha0 = ElementCompute(1), -+ ElementCompute beta0 = ElementCompute(0), -+ ElementCompute alpha1 = ElementCompute(1), -+ ElementCompute beta1 = ElementCompute(0), -+ bool relu = true, -+ int warm_ups = 1, -+ int runs = 100) { -+ -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementA, -+ typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementB, -+ typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementB, -+ typename B2bGemm::LayoutB> tensor_B0_reordered(problem_size_0.kn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementC, -+ typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementScaleBias, -+ typename B2bGemm::LayoutScaleBias> tensor_Scale0; -+ -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ tensor_Scale0.resize({1, problem_size_0.n()}); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementScaleBias, -+ typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()}); -+ -+ cutlass::HostTensor< -+ ElementAccumulator, -+ typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementC, -+ typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementB, -+ typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementB, -+ typename B2bGemm::LayoutB> tensor_B1_reordered(problem_size_1.kn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementC, -+ typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementC, -+ typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()}); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementC, -+ typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementC, -+ typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn()); -+ -+ -+ CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); -+ CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); -+ CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ CHECK_TRUE(initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed + 2014)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2013)); -+ CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); -+ CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012)); -+ -+ //Reorder B0 -+ cutlass::reorder_column<16>( -+ tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0); -+ cutlass::reorder_column( -+ tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D1.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D0.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D1.host_view()); -+ -+ tensor_A0.sync_device(); -+ tensor_B0.sync_device(); -+ tensor_B0_reordered.sync_device(); -+ tensor_C0.sync_device(); -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ 
tensor_Scale0.sync_device(); -+ tensor_Bias0.sync_device(); -+ tensor_B1.sync_device(); -+ tensor_B1_reordered.sync_device(); -+ tensor_C1.sync_device(); -+ tensor_Bias1.sync_device(); -+ tensor_D1.sync_device(); -+ reference_D0.sync_device(); -+ reference_D1.sync_device(); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename B2bGemm::Arguments arguments{ -+ problem_size_0, -+ problem_size_1, -+ tensor_A0.device_ref(), -+ tensor_B0_reordered.device_ref(), -+ tensor_C0.device_ref(), -+ tensor_Scale0.device_ref(), -+ tensor_Bias0.device_ref(), -+ tensor_B1_reordered.device_ref(), -+ {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, -+ tensor_D1.device_ref(), -+ {alpha0, beta0}, -+ {alpha1, beta1}, -+ }; -+ -+ B2bGemm b2b_gemm_op; -+ -+ cutlass::Status status = b2b_gemm_op.can_implement(arguments); -+ -+ if(status != cutlass::Status::kSuccess) { -+ std::cout << "Problem sizes not supported.\n" -+ << "Requirments:\n" -+ << " problem_size_0.M = problem_size_1.M\n" -+ << " problem_size_0.N = problem_size_1.K\n" -+ << " ThreadblockShape0::kN = problem_size_0.N\n" -+ << " ThreadblockShape1::kN = problem_size_1.N" << std::endl; -+ } -+ -+ status = b2b_gemm_op.initialize(arguments); -+ -+ CUTLASS_CHECK(status); -+ -+ for(int i = 0; i < warm_ups; i++) { -+ status = b2b_gemm_op(); -+ CUTLASS_CHECK(status); -+ } -+ -+ // -+ // Run the GEMM -+ // -+ -+ cudaEvent_t start, stop; -+ cudaEventCreate(&start); -+ cudaEventCreate(&stop); -+ -+ cudaEventRecord(start); -+ -+ for(int i = 0; i < runs; i++) { -+ status = b2b_gemm_op(); -+ -+ CUTLASS_CHECK(status); -+ } -+ -+ cudaEventRecord(stop); -+ cudaDeviceSynchronize(); -+ float gemmTime; -+ cudaEventElapsedTime(&gemmTime, start, stop); -+ std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n"; -+ -+ tensor_D1.sync_host(); -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::device::Gemm< -+ typename B2bGemm::ElementA, typename B2bGemm::LayoutA, -+ typename B2bGemm::ElementB, typename B2bGemm::LayoutB, -+ ElementAccumulator, typename B2bGemm::LayoutC, -+ ElementAccumulator, ElementAccumulator> -+ reference_gemm_0; -+ -+ cutlass::reference::device::Gemm< -+ typename B2bGemm::ElementA, typename B2bGemm::LayoutA, -+ typename B2bGemm::ElementB, typename B2bGemm::LayoutB, -+ typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute, -+ ElementAccumulator, typename B2bGemm::Operator> -+ reference_gemm_1; -+ -+ reference_gemm_0( -+ problem_size_0, -+ ElementAccumulator(1), //intermediate alpha=1 -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ ElementAccumulator(0), //beta = 0 -+ reference_Z0.device_ref(), -+ reference_Z0.device_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ cutlass::reference::device::TensorScaleBiasGemm< -+ ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC, -+ ElementCompute, typename B2bGemm::LayoutScaleBias -+ > ( -+ problem_size_0, -+ reference_Z0.device_ref(), -+ reference_D0.device_ref(), -+ alpha0, -+ tensor_Scale0.device_ref(), -+ tensor_Bias0.device_ref() -+ ); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D0.device_view()); -+ } -+ -+ reference_gemm_1( -+ problem_size_1, -+ alpha1, -+ reference_D0.device_ref(), -+ tensor_B1.device_ref(), -+ beta1, -+ {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, -+ reference_D1.device_ref() -+ ); -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D1.device_view()); -+ } -+ cudaDeviceSynchronize(); -+ reference_D0.sync_host(); -+ reference_D1.sync_host(); -+ -+ 
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ reference_D1.host_view(), -+ tensor_D1.host_view()); -+ -+ CHECK_TRUE(passed); -+ if (!passed) -+ { -+ -+ std::stringstream fname; -+ -+ fname << "error_B2bGemm_device_interleaved_fused.txt"; -+ std::cerr << "Dumping results in " << fname.str() << "\n"; -+ -+ std::ofstream file(fname.str()); -+ -+ file -+ << "A0 =\n" << tensor_A0.host_view() -+ << "\nB0 =\n" << tensor_B0.host_view() -+ << "\nB0_reordered =\n" << tensor_B0_reordered.host_view() -+ << "\nC0 =\n" << tensor_C0.host_view() -+ << "\nScale0:\n" << tensor_Scale0.host_view() << "\n" -+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" -+ << "\nB1 =\n" << tensor_B1.host_view() -+ << "\nB1_reordered =\n" << tensor_B1_reordered.host_view() -+ << "\nC1 =\n" << tensor_C1.host_view() -+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" -+ << "\n\nReference =\n" << reference_D1.host_view() -+ << "\nComputed =\n" << tensor_D1.host_view(); -+ } -+ return passed; -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/device/b2b_gemm.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/device/b2b_gemm.h -new file mode 100644 -index 0000000..f365b23 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/device/b2b_gemm.h -@@ -0,0 +1,451 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/epilogue/thread/linear_combination_relu.h" -+ -+#include "kernel/b2b_gemm.h" -+#include "kernel/default_b2b_gemm.h" -+#include "kernel/default_b2b_gemm_smem_accumulator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassSimt, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_ = arch::Sm70, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp0_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Epilogue output operator -+ typename EpilogueOutputOp1_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Stage accumulator in shared memory -+ bool SmemAccumulator = false, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ 
DefaultGemmConfiguration::kAlignmentB, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial = false, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator> -+class B2bGemm { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape0 = ThreadblockShape0_; -+ using ThreadblockShape1 = ThreadblockShape1_; -+ using WarpShape0 = WarpShape0_; -+ using WarpShape1 = WarpShape1_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp0 = EpilogueOutputOp0_; -+ using EpilogueOutputOp1 = EpilogueOutputOp1_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp1::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+ /// Derived types -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; -+ -+ /// Define the kernel -+ using B2bGemmKernel = typename kernel::DefaultB2bGemm< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ kStages, -+ kSplitKSerial, -+ Operator, -+ SmemAccumulator -+ >::B2bGemmKernel; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size_0; -+ GemmCoord problem_size_1; -+ TensorRef ref_A0; -+ TensorRef ref_B0; -+ TensorRef ref_C0; -+ TensorRef ref_Scale0; -+ TensorRef ref_Bias0; -+ TensorRef ref_B1; -+ TensorRef ref_C1; -+ TensorRef ref_D1; -+ typename EpilogueOutputOp0::Params epilogue0; -+ typename EpilogueOutputOp1::Params epilogue1; -+ int split_k_slices; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments(): problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), split_k_slices(1) { -+ -+ } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_0_, -+ GemmCoord problem_size_1_, -+ TensorRef ref_A0_, -+ TensorRef ref_B0_, -+ TensorRef ref_C0_, -+ TensorRef ref_Scale0_, -+ TensorRef ref_Bias0_, -+ TensorRef ref_B1_, -+ TensorRef ref_C1_, -+ TensorRef ref_D1_, -+ typename EpilogueOutputOp0::Params epilogue0_ = -+ typename EpilogueOutputOp0::Params(), -+ typename EpilogueOutputOp1::Params epilogue1_ = -+ typename EpilogueOutputOp1::Params(), -+ int split_k_slices_ = 1 -+ ): -+ problem_size_0(problem_size_0_), -+ problem_size_1(problem_size_1_), -+ ref_A0(ref_A0_), -+ ref_B0(ref_B0_), -+ ref_C0(ref_C0_), -+ ref_Scale0(ref_Scale0_), -+ ref_Bias0(ref_Bias0_), -+ 
ref_B1(ref_B1_), -+ ref_C1(ref_C1_), -+ ref_D1(ref_D1_), -+ epilogue0(epilogue0_), -+ epilogue1(epilogue1_), -+ split_k_slices(split_k_slices_) { -+ -+ } -+ }; -+ -+private: -+ -+ /// Kernel parameters object -+ typename B2bGemmKernel::Params params_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ B2bGemm() { } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ if (!kSplitKSerial && args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = B2bGemmKernel::can_implement( -+ args.problem_size_0, -+ args.problem_size_1, -+ args.ref_A0.non_const_ref(), -+ args.ref_B0.non_const_ref(), -+ args.ref_C0.non_const_ref(), -+ args.ref_B1.non_const_ref(), -+ args.ref_C1.non_const_ref(), -+ args.ref_D1 -+ ); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size_0, -+ {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, -+ args.split_k_slices); -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ -+ -+ bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); -+ } -+ -+ return bytes; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size_0, -+ {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, -+ args.split_k_slices); -+// cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape( -+// args.problem_size_1, -+// {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK}, -+// args.split_k_slices); -+ -+ if (kSplitKSerial) { -+ if (args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ size_t bytes = get_workspace_size(args); -+ -+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ else { -+ -+ if (args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ // Initialize the Params structure -+ params_ = typename B2bGemmKernel::Params{ -+ args.problem_size_0, -+ args.problem_size_1, -+ grid_shape, -+ args.ref_A0.non_const_ref(), -+ args.ref_B0.non_const_ref(), -+ args.ref_C0.non_const_ref(), -+ args.ref_Scale0.non_const_ref(), -+ args.ref_Bias0.non_const_ref(), -+ args.ref_B1.non_const_ref(), -+ args.ref_C1.non_const_ref(), -+ args.ref_D1, -+ args.epilogue0, -+ args.epilogue1, -+ static_cast(workspace), -+ }; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ } -+ -+ params_.ref_A0.reset(args.ref_A0.non_const_ref().data()); -+ params_.ref_B0.reset(args.ref_B0.non_const_ref().data()); -+ params_.ref_C0.reset(args.ref_C0.non_const_ref().data()); -+ 
params_.ref_Scale0.reset(args.ref_Scale0.non_const_ref().data()); -+ params_.ref_Bias0.reset(args.ref_Bias0.non_const_ref().data()); -+ params_.ref_B1.reset(args.ref_B1.non_const_ref().data()); -+ params_.ref_C1.reset(args.ref_C1.non_const_ref().data()); -+ params_.ref_D1.reset(args.ref_D1.data()); -+ params_.output_op_0 = args.epilogue0; -+ params_.output_op_1 = args.epilogue1; -+ params_.semaphore = static_cast(workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(B2bGemmKernel::kThreadCount, 1, 1); -+ -+ cudaError_t result; -+ -+ int smem_size = int(sizeof(typename B2bGemmKernel::SharedStorage)); -+ if (smem_size >= (48 << 10)) { -+ result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ cutlass::Kernel<<>>(params_); -+ -+ result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h -new file mode 100644 -index 0000000..b52d058 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h -@@ -0,0 +1,300 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Template for device-level Implicit GEMM -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/device_kernel.h" -+#include "cutlass/conv/convolution.h" -+ -+#include "kernel/b2b_implicit_gemm_convolution.h" -+#include "kernel/default_b2b_conv2d_fprop.h" -+#include "kernel/default_b2b_conv2d_fprop_sm75.h" -+#include "kernel/default_b2b_conv2d_fprop_sm80.h" -+#include "kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h" -+#include "kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h" -+ -+namespace cutlass { -+namespace conv { -+namespace device { -+ -+template -+class B2bImplicitGemmConvolution { -+public: -+ -+ using B2bImplicitGemmKernel = B2bImplicitGemmKernel_; -+ -+ using ElementA = typename B2bImplicitGemmKernel::ElementA; -+ using LayoutA = typename B2bImplicitGemmKernel::LayoutA; -+ using ElementB = typename B2bImplicitGemmKernel::ElementB; -+ using LayoutB = typename B2bImplicitGemmKernel::LayoutB; -+ using ElementC = typename B2bImplicitGemmKernel::ElementC; -+ using LayoutC = typename B2bImplicitGemmKernel::LayoutC; -+ using ElementAccumulator = typename B2bImplicitGemmKernel::ElementAccumulator; -+ using ElementCompute = typename B2bImplicitGemmKernel::ElementCompute; -+ using ElementScaleBias = typename B2bImplicitGemmKernel::ElementScaleBias; -+ using LayoutScaleBias = typename B2bImplicitGemmKernel::LayoutScaleBias; -+ using OperatorClass = typename B2bImplicitGemmKernel::OperatorClass; -+ using ArchTag = typename B2bImplicitGemmKernel::ArchTag; -+ using ThreadblockShape0 = typename B2bImplicitGemmKernel::ThreadblockShape0; -+ using ThreadblockShape1 = typename B2bImplicitGemmKernel::ThreadblockShape1; -+ using WarpShape0 = typename B2bImplicitGemmKernel::WarpShape0; -+ using WarpShape1 = typename B2bImplicitGemmKernel::WarpShape1; -+ using InstructionShape = typename B2bImplicitGemmKernel::InstructionShape; -+ using ThreadblockSwizzle = typename B2bImplicitGemmKernel::ThreadblockSwizzle; -+ using EpilogueOutputOp0 = typename B2bImplicitGemmKernel::EpilogueOutputOp0; -+ using EpilogueOutputOp1 = typename B2bImplicitGemmKernel::EpilogueOutputOp1; -+ static int const kStages = B2bImplicitGemmKernel::kStages; -+ static int const kConvDim = B2bImplicitGemmKernel::kConvDim; -+ using WarpMmaOperator0 = typename B2bImplicitGemmKernel::WarpMmaOperator0; -+ using WarpMmaOperator1 = typename B2bImplicitGemmKernel::WarpMmaOperator1; -+ using ArchMmaOperator = typename B2bImplicitGemmKernel::ArchMmaOperator; -+ using MathOperator = typename B2bImplicitGemmKernel::MathOperator; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = B2bImplicitGemmKernel::kConvolutionalOperator; -+ static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = B2bImplicitGemmKernel::kIteratorAlgorithm; -+ -+ static int const kWarpCount = -+ (ThreadblockShape0::kM / WarpShape0::kM) * 
-+ (ThreadblockShape0::kN / WarpShape0::kN); -+ -+ /// Argument structure -+ using Arguments = typename B2bImplicitGemmKernel::Arguments; -+ -+private: -+ -+ /// Kernel parameters object -+ typename B2bImplicitGemmKernel::Params params_; -+ -+public: -+ -+ /// Constructs Implicit GEMM -+ B2bImplicitGemmConvolution() { } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ // dispatch to iterators -+ Status status = B2bImplicitGemmKernel::B2bMma::IteratorA0::can_implement(args.problem_size_0); -+ if (Status::kSuccess != status) { -+ return status; -+ } -+ -+ status = B2bImplicitGemmKernel::B2bMma::IteratorB0::can_implement(args.problem_size_0); -+ if (Status::kSuccess != status) { -+ return status; -+ } -+ -+ status = B2bImplicitGemmKernel::B2bMma::IteratorB1::can_implement(args.problem_size_1); -+ if (Status::kSuccess != status) { -+ return status; -+ } -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape( -+ threadblock_swizzle.get_tiled_shape( -+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0), -+ {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, -+ args.problem_size_0.split_k_slices)); -+ -+ if (!(grid.y <= std::numeric_limits::max() && -+ grid.z <= std::numeric_limits::max())) { -+ -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // Determine if fusion sizes are valid -+ -+ cutlass::gemm::GemmCoord problem_size_0 = implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0); -+ cutlass::gemm::GemmCoord problem_size_1 = implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_1); -+ -+ if(problem_size_0.m() != problem_size_1.m()) -+ return Status::kErrorInvalidProblem; -+ -+ if(problem_size_0.n() != problem_size_1.k()) -+ return Status::kErrorInvalidProblem; -+ -+ if(args.problem_size_1.R != 1 || args.problem_size_1.S != 1) -+ return Status::kErrorInvalidProblem; -+ -+ if(problem_size_0.n() > ThreadblockShape0::kN) -+ return Status::kErrorInvalidProblem; -+ -+ if(problem_size_1.n() > ThreadblockShape1::kN) -+ return Status::kErrorInvalidProblem; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t workspace_bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0), -+ {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, -+ args.problem_size_0.split_k_slices); -+ -+ if(args.split_k_mode == SplitKMode::kParallel) { -+ -+ // Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace. 
-+      // The user needs to call a reduction operator to obtain the final output tensor
-+      workspace_bytes =
-+        sizeof(ElementAccumulator) *
-+        size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size_0)) *
-+        size_t(grid_tiled_shape.k());
-+    }
-+
-+    else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size_0.split_k_slices > 1) {
-+
-+      // Split-K serial: The user workspace is used to store semaphore and serialize writing the
-+      // final reduced output to user's output tensor
-+      workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
-+    }
-+
-+    return workspace_bytes;
-+  }
-+
-+  /// Initializes GEMM state from arguments.
-+  Status initialize(
-+    Arguments const &args,
-+    void *workspace = nullptr,
-+    cudaStream_t stream = nullptr) {
-+
-+    if (args.problem_size_0.split_k_slices > 1) {
-+
-+      if (!workspace) {
-+        return Status::kErrorWorkspaceNull;
-+      }
-+
-+      cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream);
-+
-+      if (status != cudaSuccess) {
-+        return Status::kErrorInternal;
-+      }
-+    }
-+
-+    // initialize the params structure from the arguments
-+    params_ = typename B2bImplicitGemmKernel::Params(
-+      args,
-+      static_cast<int *>(workspace)
-+    );
-+
-+    int smem_size = int(sizeof(typename B2bImplicitGemmKernel::SharedStorage));
-+
-+    if (smem_size >= (48 << 10)) {
-+      cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel<B2bImplicitGemmKernel>,
-+                                                cudaFuncAttributeMaxDynamicSharedMemorySize,
-+                                                smem_size);
-+
-+      if (result != cudaSuccess) {
-+        return Status::kErrorInternal;
-+      }
-+    }
-+
-+    return Status::kSuccess;
-+  }
-+
-+  /// Initializes GEMM state from arguments.
-+  Status update(Arguments const &args, void *workspace = nullptr) {
-+
-+    // update the params structure from the arguments
-+    params_.ptr_A0 = args.ref_A0.data();
-+    params_.ptr_B0 = args.ref_B0.data();
-+    params_.ptr_C0 = args.ref_C0.data();
-+    params_.ptr_Scale0 = args.ref_Scale0.data();
-+    params_.ptr_Bias0 = args.ref_Bias0.data();
-+    params_.ptr_B1 = args.ref_B1.data();
-+    params_.ptr_C1 = args.ref_C1.data();
-+    params_.ptr_D1 = args.ref_D1.data();
-+    params_.output_op_0 = args.output_op_0;
-+    params_.output_op_1 = args.output_op_1;
-+    params_.semaphore = static_cast<int *>(workspace);
-+
-+    return Status::kSuccess;
-+  }
-+
-+  /// Runs the kernel using initialized state.
-+  Status run(cudaStream_t stream = nullptr) {
-+
-+    ThreadblockSwizzle threadblock_swizzle;
-+
-+    dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
-+    dim3 block(32 * kWarpCount, 1, 1);
-+
-+    int smem_size = int(sizeof(typename B2bImplicitGemmKernel::SharedStorage));
-+
-+    cutlass::Kernel<B2bImplicitGemmKernel><<<grid, block, smem_size, stream>>>(params_);
-+
-+    cudaError_t result = cudaGetLastError();
-+
-+    return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
-+  }
-+
-+  /// Runs the kernel using initialized state.
-+  Status operator()(cudaStream_t stream = nullptr) {
-+    return run(stream);
-+  }
-+
-+  /// Runs the kernel using initialized state.
-+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+} // namespace device -+} // namespace conv -+} // namespace cutlass -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_rf.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_rf.cu -new file mode 100644 -index 0000000..6f12608 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_rf.cu -@@ -0,0 +1,234 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "device/b2b_implicit_gemm_convolution.h" -+#include "b2b_conv2d_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_0 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {64, 3, 3, 64}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 64} // output size (NPQK) -+ ); -+cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_1 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {128, 1, 1, 64}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 128} // output size (NPQK) -+ ); -+ -+bool run_nonfused_conv2d_fprop_optimized_f16_sm75() { -+ -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop1 = 
cutlass::conv::device::ImplicitGemmConvolution; -+ -+ B2bNonFusedConv2dRun nonFusedConv2d; -+ -+ std::cout << "Running Non-fused back-to-back FP16 Optimized Convolution Fprops...\n"; -+ bool pass = nonFusedConv2d.run(conv2d_f16_sm75_problem_size_0, conv2d_f16_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+bool run_fused_conv2d_fprop_optimized_f16_sm75_rf_res() { -+ -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //use beta for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ -+ const bool SmemAccumulator = false; -+ -+ using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ SmemAccumulator -+ >::Kernel; -+ -+ using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution; -+ -+ B2bFusedConv2dRun fusedConv2d; -+ -+ std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops with RF Residency...\n"; -+ bool pass = fusedConv2d.run(conv2d_f16_sm75_problem_size_0, conv2d_f16_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_conv2d_fprop_optimized_f16_sm75, -+ &run_fused_conv2d_fprop_optimized_f16_sm75_rf_res -+ }; -+ -+ return testRun(75, funcs, "conv f16 RF residency"); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_shmem.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_shmem.cu -new file mode 100644 -index 0000000..86eb1f7 ---- /dev/null -+++ 
b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_shmem.cu -@@ -0,0 +1,234 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "device/b2b_implicit_gemm_convolution.h" -+#include "b2b_conv2d_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_0 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {64, 3, 3, 64}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 64} // output size (NPQK) -+ ); -+cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_1 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {256, 1, 1, 64}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 256} // output size (NPQK) -+ ); -+ -+bool run_nonfused_conv2d_fprop_optimized_f16_sm75() { -+ -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop1 = 
cutlass::conv::device::ImplicitGemmConvolution; -+ -+ B2bNonFusedConv2dRun nonFusedConv2d; -+ -+ std::cout << "Running Non-fused back-to-back FP16 Optimized Convolution Fprops...\n"; -+ bool pass = nonFusedConv2d.run(conv2d_f16_sm75_problem_size_0, conv2d_f16_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+bool run_fused_conv2d_fprop_optimized_f16_sm75_shmem() { -+ -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ -+ const bool SmemAccumulator = true; -+ -+ using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ SmemAccumulator -+ >::Kernel; -+ -+ using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution; -+ -+ B2bFusedConv2dRun fusedConv2d; -+ -+ std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops with shared memory staging...\n"; -+ bool pass = fusedConv2d.run(conv2d_f16_sm75_problem_size_0, conv2d_f16_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_conv2d_fprop_optimized_f16_sm75, -+ &run_fused_conv2d_fprop_optimized_f16_sm75_shmem -+ }; -+ -+ return testRun(75, funcs, "conv f16 shmem staging"); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_rf.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_rf.cu -new file mode 100644 -index 0000000..14bef44 ---- /dev/null -+++ 
b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_rf.cu -@@ -0,0 +1,233 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "device/b2b_implicit_gemm_convolution.h" -+#include "b2b_conv2d_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_0 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {64, 3, 3, 64}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 64} // output size (NPQK) -+ ); -+cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_1 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {128, 1, 1, 64}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 128} // output size (NPQK) -+ ); -+ -+ -+bool run_nonfused_conv2d_fprop_optimized_f16_sm80() { -+ -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop1 = 
cutlass::conv::device::ImplicitGemmConvolution; -+ -+ B2bNonFusedConv2dRun nonFusedConv2d; -+ -+ std::cout << "Running Non-fused back-to-back FP16 Optimized Convolution Fprops...\n"; -+ bool pass = nonFusedConv2d.run(conv2d_f16_sm80_problem_size_0, conv2d_f16_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+bool run_fused_conv2d_fprop_optimized_f16_sm80_rf_res() { -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution; -+ -+ B2bFusedConv2dRun fusedConv2d; -+ -+ std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops with RF Residency...\n"; -+ bool pass = fusedConv2d.run(conv2d_f16_sm80_problem_size_0, conv2d_f16_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+ return true; -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_conv2d_fprop_optimized_f16_sm80, -+ &run_fused_conv2d_fprop_optimized_f16_sm80_rf_res -+ }; -+ -+ return testRun(80, funcs, "conv f16 RF residency"); -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_shmem.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_shmem.cu -new file mode 100644 -index 0000000..c4df985 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_shmem.cu -@@ -0,0 
+1,236 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "device/b2b_implicit_gemm_convolution.h" -+#include "b2b_conv2d_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_0 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {64, 3, 3, 64}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 64} // output size (NPQK) -+ ); -+cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_1 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {256, 1, 1, 64}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 256} // output size (NPQK) -+ ); -+ -+ -+bool run_nonfused_conv2d_fprop_optimized_f16_sm80() { -+ -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop1 = 
cutlass::conv::device::ImplicitGemmConvolution; -+ -+ B2bNonFusedConv2dRun nonFusedConv2d; -+ -+ std::cout << "Running Non-fused back-to-back FP16 Optimized Convolution Fprops...\n"; -+ bool pass = nonFusedConv2d.run(conv2d_f16_sm80_problem_size_0, conv2d_f16_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+bool run_fused_conv2d_fprop_optimized_f16_sm80_shmem() { -+ -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ const bool SmemAccumulator = true; -+ using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ SmemAccumulator -+ >::Kernel; -+ -+ using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution; -+ -+ B2bFusedConv2dRun fusedConv2d; -+ -+ std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops with shared memory staging...\n"; -+ bool pass = fusedConv2d.run(conv2d_f16_sm80_problem_size_0, conv2d_f16_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_conv2d_fprop_optimized_f16_sm80, -+ &run_fused_conv2d_fprop_optimized_f16_sm80_shmem -+ }; -+ -+ return testRun(80, funcs, "conv f16 shmem staging"); -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_rf.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_rf.cu -new file mode 100644 -index 0000000..64955f8 ---- /dev/null -+++ 
b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_rf.cu -@@ -0,0 +1,238 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "device/b2b_implicit_gemm_convolution.h" -+#include "b2b_interleaved_conv2d_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_0 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {64, 3, 3, 64}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 64} // output size (NPQK) -+ ); -+cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_1 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {128, 1, 1, 64}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 128} // output size (NPQK) -+ ); -+ -+bool run_nonfused_conv2d_fprop_optimized_s8_sm75() { -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop1 = 
cutlass::conv::device::ImplicitGemmConvolution; -+ -+ B2bInterleavedNonFusedConv2dRun nonFusedConv2d; -+ -+ std::cout << "Running Non-fused back-to-back INT8 interleaved Optimized Convolution Fprops...\n"; -+ bool pass = nonFusedConv2d.run(conv2d_s8_sm75_problem_size_0, conv2d_s8_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+ -+bool run_fused_conv2d_fprop_optimized_s8_sm75_rf_res() { -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ -+ const bool SmemAccumulator = false; -+ -+ using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ SmemAccumulator -+ >::Kernel; -+ -+ using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution; -+ -+ B2bInterleavedFusedConv2dRun fusedConv2d; -+ -+ std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops with RF residency...\n"; -+ bool pass = fusedConv2d.run(conv2d_s8_sm75_problem_size_0, conv2d_s8_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_conv2d_fprop_optimized_s8_sm75, -+ &run_fused_conv2d_fprop_optimized_s8_sm75_rf_res -+ }; -+ -+ return testRun(75, funcs, "conv int8 RF residency"); -+ -+} -+ -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_shmem.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_shmem.cu -new file mode 100644 -index 0000000..7f82518 ---- /dev/null -+++ 
b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_shmem.cu -@@ -0,0 +1,238 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "device/b2b_implicit_gemm_convolution.h" -+#include "b2b_interleaved_conv2d_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_0 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {64, 3, 3, 64}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 64} // output size (NPQK) -+ ); -+cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_1 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {256, 1, 1, 64}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 256} // output size (NPQK) -+ ); -+ -+bool run_nonfused_conv2d_fprop_optimized_s8_sm75() { -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop1 = 
cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel1>;
-+
-+  B2bInterleavedNonFusedConv2dRun nonFusedConv2d;
-+
-+  std::cout << "Running Non-fused back-to-back INT8 interleaved Optimized Convolution Fprops...\n";
-+  bool pass = nonFusedConv2d.run(conv2d_s8_sm75_problem_size_0, conv2d_s8_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
-+    alpha0, beta0, alpha1, beta1);
-+
-+  if(pass)
-+    std::cout << "Pass\n";
-+  else
-+    std::cout << "Fail\n";
-+
-+  return pass;
-+}
-+
-+bool run_fused_conv2d_fprop_optimized_s8_sm75_shmem() {
-+
-+  using ElementA = int8_t;
-+  using ElementB = int8_t;
-+  using ElementC = int8_t;
-+  using ElementAccumulator = int32_t;
-+  using ElementCompute = float;
-+
-+  ElementCompute alpha0 = ElementCompute(1);
-+  //Fused kernel has built-in bias, setting beta=0
-+  ElementCompute beta0 = ElementCompute(0);
-+  ElementCompute alpha1 = ElementCompute(1);
-+  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
-+
-+  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
-+  using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
-+  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
-+  using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
-+  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
-+
-+  using EpilogueOutputOp0 =
-+    cutlass::epilogue::thread::LinearCombinationRelu<
-+      ElementC,
-+      InstructionShape::kM * InstructionShape::kN / 32,
-+      ElementAccumulator,
-+      ElementCompute,
-+      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
-+    >;
-+
-+  using EpilogueOutputOp1 =
-+    cutlass::epilogue::thread::LinearCombinationRelu<
-+      ElementC,
-+      64 / cutlass::sizeof_bits<ElementC>::value,
-+      ElementAccumulator,
-+      ElementCompute,
-+      cutlass::epilogue::thread::ScaleType::NoBetaScaling
-+    >;
-+
-+
-+  const bool SmemAccumulator = true;
-+
-+  using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
-+    ElementA, cutlass::layout::TensorNCxHWx<32>,
-+    ElementB, cutlass::layout::TensorCxRSKx<32>,
-+    ElementC, cutlass::layout::TensorNCxHWx<32>,
-+    ElementAccumulator,
-+    cutlass::arch::OpClassTensorOp,
-+    cutlass::arch::Sm75,
-+    ThreadblockShape0,
-+    ThreadblockShape1,
-+    WarpShape0,
-+    WarpShape1,
-+    InstructionShape,
-+    EpilogueOutputOp0,
-+    EpilogueOutputOp1,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
-+    2,
-+    cutlass::arch::OpMultiplyAddSaturate,
-+    cutlass::conv::IteratorAlgorithm::kOptimized,
-+    SmemAccumulator
-+  >::Kernel;
-+
-+  using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;
-+
-+  B2bInterleavedFusedConv2dRun fusedConv2d;
-+
-+  std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops with shared memory staging...\n";
-+  bool pass = fusedConv2d.run(conv2d_s8_sm75_problem_size_0, conv2d_s8_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
-+    alpha0, beta0, alpha1, beta1);
-+
-+  if(pass)
-+    std::cout << "Pass\n";
-+  else
-+    std::cout << "Fail\n";
-+
-+  return pass;
-+}
-+
-+
-+int main() {
-+
-+  std::vector<bool (*)()>funcs = {
-+    &run_nonfused_conv2d_fprop_optimized_s8_sm75,
-+    &run_fused_conv2d_fprop_optimized_s8_sm75_shmem
-+  };
-+
-+  return testRun(75, funcs, "conv int8 shmem staging");
-+
-+}
-+
-+
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_rf.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_rf.cu
-new file mode 100644
-index 0000000..c4e0e4c
---- /dev/null
-+++ 
b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_rf.cu -@@ -0,0 +1,236 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "device/b2b_implicit_gemm_convolution.h" -+#include "b2b_interleaved_conv2d_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_0 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {64, 3, 3, 64}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 64} // output size (NPQK) -+ ); -+cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_1 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {128, 1, 1, 64}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 128} // output size (NPQK) -+ ); -+ -+bool run_nonfused_conv2d_fprop_optimized_s8_sm80() { -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop1 = 
cutlass::conv::device::ImplicitGemmConvolution; -+ -+ B2bInterleavedNonFusedConv2dRun nonFusedConv2d; -+ -+ std::cout << "Running Non-fused back-to-back INT8 interleaved Optimized Convolution Fprops...\n"; -+ bool pass = nonFusedConv2d.run(conv2d_s8_sm80_problem_size_0, conv2d_s8_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+ -+bool run_fused_conv2d_fprop_optimized_s8_sm80_rf_res() { -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 8 * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ -+ -+ using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution; -+ -+ B2bInterleavedFusedConv2dRun fusedConv2d; -+ -+ std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops with RF residency...\n"; -+ bool pass = fusedConv2d.run(conv2d_s8_sm80_problem_size_0, conv2d_s8_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_conv2d_fprop_optimized_s8_sm80, -+ &run_fused_conv2d_fprop_optimized_s8_sm80_rf_res -+ }; -+ -+ return testRun(80, funcs, "conv int8 RF residency"); -+ -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_shmem.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_shmem.cu -new file mode 100644 -index 0000000..de15106 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_shmem.cu -@@ -0,0 +1,237 
@@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "device/b2b_implicit_gemm_convolution.h" -+#include "b2b_interleaved_conv2d_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_0 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {64, 3, 3, 64}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 64} // output size (NPQK) -+ ); -+cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_1 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {256, 1, 1, 64}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 256} // output size (NPQK) -+ ); -+ -+bool run_nonfused_conv2d_fprop_optimized_s8_sm80() { -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop1 = 
cutlass::conv::device::ImplicitGemmConvolution; -+ -+ B2bInterleavedNonFusedConv2dRun nonFusedConv2d; -+ -+ std::cout << "Running Non-fused back-to-back INT8 interleaved Optimized Convolution Fprops...\n"; -+ bool pass = nonFusedConv2d.run(conv2d_s8_sm80_problem_size_0, conv2d_s8_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+bool run_fused_conv2d_fprop_optimized_s8_sm80_shmem() { -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 8 * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ const bool SmemAccumulator = true; -+ -+ using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ SmemAccumulator -+ >::Kernel; -+ -+ using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution; -+ -+ B2bInterleavedFusedConv2dRun fusedConv2d; -+ -+ std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops with shared memory staging...\n"; -+ bool pass = fusedConv2d.run(conv2d_s8_sm80_problem_size_0, conv2d_s8_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_conv2d_fprop_optimized_s8_sm80, -+ &run_fused_conv2d_fprop_optimized_s8_sm80_shmem -+ }; -+ -+ return testRun(80, funcs, "conv int8 shmem staging"); -+ -+ -+} -+ -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_rf.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_rf.cu -new file mode 100644 -index 0000000..3a02096 ---- /dev/null -+++ 
b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_rf.cu -@@ -0,0 +1,210 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "device/b2b_gemm.h" -+#include "b2b_gemm_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_0(128*640, 64, 576); -+cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_1(128*640, 128, 64); -+ -+bool run_nonfused_gemm_f16() { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Gemm0 = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2 -+ >; -+ using Gemm1 = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2 -+ >; -+ -+ B2bNonFusedGemmRun nonFusedGemm; -+ -+ std::cout << "Running Non-fused back-to-back FP16 TN GEMMs...\n"; -+ bool pass = nonFusedGemm.run(gemm_f16_sm75_problem_size_0, gemm_f16_sm75_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+ -+bool run_fused_gemm_f16_rf_res() { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias 
-+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ using B2bGemm = cutlass::gemm::device::B2bGemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2 -+ >; -+ -+ B2bFusedGemmRun fusedGemm; -+ -+ std::cout << "Running Fused back-to-back FP16 TN GEMMs with RF Residency...\n"; -+ bool passed = fusedGemm.run(gemm_f16_sm75_problem_size_0, gemm_f16_sm75_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(passed) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return passed; -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_gemm_f16, -+ &run_fused_gemm_f16_rf_res -+ }; -+ -+ return testRun(75, funcs, "gemm f16 RF residency"); -+ -+ -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_shmem.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_shmem.cu -new file mode 100644 -index 0000000..3498b40 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_shmem.cu -@@ -0,0 +1,214 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "device/b2b_gemm.h" -+#include "b2b_gemm_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_0(128*640, 64, 576); -+cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_1(128*640, 256, 64); -+ -+bool run_nonfused_gemm_f16() { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Gemm0 = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2 -+ >; -+ using Gemm1 = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2 -+ >; -+ -+ B2bNonFusedGemmRun nonFusedGemm; -+ -+ std::cout << "Running Non-fused back-to-back FP16 TN GEMMs...\n"; -+ bool pass = nonFusedGemm.run(gemm_f16_sm75_problem_size_0, gemm_f16_sm75_problem_size_1, alpha0, 
beta0, alpha1, beta1); -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+bool run_fused_gemm_f16_shmem() { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ -+ const bool SmemAccumulator = true; -+ -+ using B2bGemm = cutlass::gemm::device::B2bGemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ SmemAccumulator -+ >; -+ -+ B2bFusedGemmRun fusedGemm; -+ -+ std::cout << "Running Fused back-to-back FP16 TN GEMMs with shared memory staging...\n"; -+ bool passed = fusedGemm.run(gemm_f16_sm75_problem_size_0, gemm_f16_sm75_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(passed) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return passed; -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_gemm_f16, -+ &run_fused_gemm_f16_shmem -+ }; -+ -+ return testRun(75, funcs, "gemm f16 shmem staging"); -+ -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_rf.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_rf.cu -new file mode 100644 -index 0000000..feb22fa ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_rf.cu -@@ -0,0 +1,213 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "device/b2b_gemm.h" -+#include "b2b_gemm_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_0(128*640, 64, 576); -+cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_1(128*640, 128, 64); -+ -+bool run_nonfused_gemm_f16_sm80() { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ using Gemm0 = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3 -+ >; -+ using Gemm1 = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, 
-+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3 -+ >; -+ -+ B2bNonFusedGemmRun nonFusedGemm; -+ -+ std::cout << "Running Non-fused back-to-back FP16 TN GEMMs...\n"; -+ bool pass = nonFusedGemm.run(gemm_f16_sm80_problem_size_0, gemm_f16_sm80_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+bool run_fused_gemm_f16_sm80_rf_res() { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ using B2bGemm = cutlass::gemm::device::B2bGemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3 -+ >; -+ -+ B2bFusedGemmRun fusedGemm; -+ -+ std::cout << "Running Fused back-to-back FP16 TN GEMMs with RF residency...\n"; -+ bool passed = fusedGemm.run(gemm_f16_sm80_problem_size_0, gemm_f16_sm80_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(passed) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return passed; -+ -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_gemm_f16_sm80, -+ &run_fused_gemm_f16_sm80_rf_res -+ }; -+ -+ return testRun(80, funcs, "gemm f16 RF residency"); -+ -+ -+} -+ -+ -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_shmem.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_shmem.cu -new file mode 100644 -index 0000000..36c4819 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_shmem.cu -@@ -0,0 +1,217 @@ 
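Note on the GEMM examples above: each translation unit only defines the CUTLASS device-level types (Gemm0, Gemm1, B2bGemm) and delegates allocation, launch, and verification to the B2bNonFusedGemmRun / B2bFusedGemmRun helpers from b2b_gemm_run.h. For orientation, the sketch below shows the bare host-side call pattern such helpers wrap, using the standard CUTLASS 2.x device::Gemm API with default SIMT/float template parameters. The function name run_single_gemm, the problem shape, and the fill seeds are illustrative only and are not part of this patch.

#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"

// Illustrative single-GEMM driver: D = alpha * A * B + beta * C,
// column-major float SIMT GEMM with default template parameters.
using Gemm = cutlass::gemm::device::Gemm<
    float, cutlass::layout::ColumnMajor,   // A
    float, cutlass::layout::ColumnMajor,   // B
    float, cutlass::layout::ColumnMajor>;  // C and D

bool run_single_gemm(int M, int N, int K, float alpha, float beta) {
  cutlass::gemm::GemmCoord problem(M, N, K);

  // Host/device tensors sized from the problem shape.
  cutlass::HostTensor<float, cutlass::layout::ColumnMajor> A(problem.mk());
  cutlass::HostTensor<float, cutlass::layout::ColumnMajor> B(problem.kn());
  cutlass::HostTensor<float, cutlass::layout::ColumnMajor> C(problem.mn());
  cutlass::HostTensor<float, cutlass::layout::ColumnMajor> D(problem.mn());

  // Fill operands with small random values, zero the outputs, copy to device.
  cutlass::reference::host::TensorFillRandomUniform(A.host_view(), 2023, 2.0, -2.0);
  cutlass::reference::host::TensorFillRandomUniform(B.host_view(), 2024, 2.0, -2.0);
  cutlass::reference::host::TensorFill(C.host_view());
  cutlass::reference::host::TensorFill(D.host_view());
  A.sync_device();
  B.sync_device();
  C.sync_device();
  D.sync_device();

  // Arguments mirror what the run helpers assemble: problem size, tensor
  // references, and the epilogue scalars (alpha, beta).
  typename Gemm::Arguments args(problem,
                                A.device_ref(), B.device_ref(),
                                C.device_ref(), D.device_ref(),
                                {alpha, beta});

  Gemm gemm_op;
  cutlass::Status status = gemm_op(args);  // initializes and launches the kernel
  return status == cutlass::Status::kSuccess;
}

int main() {
  bool ok = run_single_gemm(1024, 256, 512, 1.0f, 0.0f);
  std::cout << (ok ? "Pass\n" : "Fail\n");
  return ok ? 0 : -1;
}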
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "device/b2b_gemm.h" -+#include "b2b_gemm_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_0(128*640, 64, 576); -+cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_1(128*640, 256, 64); -+ -+bool run_nonfused_gemm_f16_sm80() { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ using Gemm0 = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3 -+ >; -+ using Gemm1 = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3 -+ >; -+ -+ B2bNonFusedGemmRun nonFusedGemm; -+ -+ std::cout << "Running Non-fused back-to-back FP16 TN GEMMs...\n"; -+ bool pass = nonFusedGemm.run(gemm_f16_sm80_problem_size_0, gemm_f16_sm80_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+bool run_fused_gemm_f16_sm80_shmem() { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for 
bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ -+ const bool SmemAccumulator = true; -+ -+ using B2bGemm = cutlass::gemm::device::B2bGemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ SmemAccumulator -+ >; -+ -+ B2bFusedGemmRun fusedGemm; -+ -+ std::cout << "Running Fused back-to-back FP16 TN GEMMs with shared memory staging...\n"; -+ bool passed = fusedGemm.run(gemm_f16_sm80_problem_size_0, gemm_f16_sm80_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(passed) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return passed; -+ -+} -+ -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_gemm_f16_sm80, -+ &run_fused_gemm_f16_sm80_shmem -+ }; -+ -+ return testRun(80, funcs, "gemm f16 shmem staging"); -+ -+ -+} -+ -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_rf.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_rf.cu -new file mode 100644 -index 0000000..565cca7 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_rf.cu -@@ -0,0 +1,212 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "device/b2b_gemm.h" -+#include "b2b_interleaved_gemm_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_0(128*640, 64, 576); -+cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_1(128*640, 128, 64); -+ -+bool run_nonfused_gemm_s8() { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ using Gemm0 = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2 -+ >; -+ using Gemm1 = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ 
cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2 -+ >; -+ -+ B2bInterleavedNonFusedGemmRun nonFusedGemm; -+ -+ std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n"; -+ bool pass = nonFusedGemm.run(gemm_s8_sm75_problem_size_0, gemm_s8_sm75_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+ -+bool run_fused_gemm_s8_rf_res() { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ using B2bGemm = cutlass::gemm::device::B2bGemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2 -+ >; -+ -+ B2bInterleavedFusedGemmRun fusedGemm; -+ -+ std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with RF Residency...\n"; -+ bool passed = fusedGemm.run(gemm_s8_sm75_problem_size_0, gemm_s8_sm75_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(passed) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return passed; -+ -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_gemm_s8, -+ &run_fused_gemm_s8_rf_res -+ }; -+ -+ return testRun(75, funcs, "gemm int8 RF residency"); -+ -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_shmem.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_shmem.cu -new file mode 100644 -index 0000000..8719d74 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_shmem.cu -@@ -0,0 +1,214 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
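The main() of every example above registers its two run_* functions in a std::vector of function pointers and calls testRun(sm, funcs, name) from test_run.h, a header that is not reproduced in this patch. As a rough, assumed illustration of what such a guard typically does (skip when the GPU's SM version does not meet the requirement, otherwise run each registered check), a stand-in might look like the sketch below; the real test_run.h may differ in details such as the exact SM comparison and return convention.

#include <cuda_runtime.h>

#include <iostream>
#include <vector>

// Hypothetical stand-in for the testRun() helper used by the examples above.
// Skips gracefully on unsupported GPUs, otherwise runs every registered check.
int testRunSketch(int required_sm, std::vector<bool (*)()> const &funcs, char const *test_name) {
  cudaDeviceProp props;
  if (cudaGetDeviceProperties(&props, 0) != cudaSuccess) {
    std::cerr << test_name << ": no CUDA device available\n";
    return -1;
  }

  int sm = props.major * 10 + props.minor;
  if (sm < required_sm) {
    std::cout << test_name << ": skipped (requires SM" << required_sm
              << ", found SM" << sm << ")\n";
    return 0;  // an unsupported device is not treated as a failure
  }

  bool all_pass = true;
  for (auto fn : funcs) {
    all_pass &= fn();
  }

  std::cout << test_name << (all_pass ? ": passed\n" : ": FAILED\n");
  return all_pass ? 0 : -1;
}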
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "device/b2b_gemm.h" -+#include "b2b_interleaved_gemm_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_0(128*640, 64, 576); -+cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_1(128*640, 256, 64); -+ -+bool run_nonfused_gemm_s8() { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ using Gemm0 = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, 
-+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2 -+ >; -+ using Gemm1 = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2 -+ >; -+ -+ B2bInterleavedNonFusedGemmRun nonFusedGemm; -+ -+ std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n"; -+ bool pass = nonFusedGemm.run(gemm_s8_sm75_problem_size_0, gemm_s8_sm75_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+bool run_fused_gemm_s8_shmem() { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ const bool SmemAccumulator = true; -+ -+ using B2bGemm = cutlass::gemm::device::B2bGemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ SmemAccumulator -+ >; -+ -+ B2bInterleavedFusedGemmRun fusedGemm; -+ -+ std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with shared memory staging...\n"; -+ bool passed = fusedGemm.run(gemm_s8_sm75_problem_size_0, gemm_s8_sm75_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(passed) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return passed; -+ -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_gemm_s8, -+ &run_fused_gemm_s8_shmem -+ }; -+ -+ return testRun(75, funcs, 
"gemm int8 shmem staing"); -+ -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu -new file mode 100644 -index 0000000..60f9adb ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu -@@ -0,0 +1,227 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "device/b2b_gemm.h" -+#include "b2b_interleaved_gemm_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(128*640, 64, 576); -+cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(128*640, 128, 64); -+ -+bool run_nonfused_gemm_s8_sm80() { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ using Gemm0 = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ using Gemm1 = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ B2bInterleavedNonFusedGemmRun nonFusedGemm; -+ -+ std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n"; -+ bool pass = nonFusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+ -+bool run_fused_gemm_s8_sm80_rf_res() { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in 
bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 8 * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ const bool SmemAccumulator = false; -+ -+ using B2bGemm = cutlass::gemm::device::B2bGemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ SmemAccumulator, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ B2bInterleavedFusedGemmRun fusedGemm; -+ -+ std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with RF residency...\n"; -+ bool passed = fusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(passed) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return passed; -+} -+ -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_gemm_s8_sm80, -+ &run_fused_gemm_s8_sm80_rf_res -+ }; -+ -+ return testRun(80, funcs, "gemm int8 RF residency"); -+ -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_shmem.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_shmem.cu -new file mode 100644 -index 0000000..64788e0 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_shmem.cu -@@ -0,0 +1,226 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "device/b2b_gemm.h" -+#include "b2b_interleaved_gemm_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(128*640, 64, 576); -+cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(128*640, 256, 64); -+ -+bool run_nonfused_gemm_s8_sm80() { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ using Gemm0 = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ using Gemm1 = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ 
cutlass::arch::Sm80, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ B2bInterleavedNonFusedGemmRun nonFusedGemm; -+ -+ std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n"; -+ bool pass = nonFusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+bool run_fused_gemm_s8_sm80_shmem() { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 8 * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ const bool SmemAccumulator = true; -+ -+ using B2bGemm = cutlass::gemm::device::B2bGemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ SmemAccumulator, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ B2bInterleavedFusedGemmRun fusedGemm; -+ -+ std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with shared memory staging...\n"; -+ bool passed = fusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(passed) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return passed; -+} -+ -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_gemm_s8_sm80, -+ &run_fused_gemm_s8_sm80_shmem -+ }; -+ -+ return testRun(80, funcs, "gemm int8 shmem staging"); -+ -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h -new file mode 100644 -index 0000000..1ccf902 ---- /dev/null -+++ 
b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h -@@ -0,0 +1,460 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/semaphore.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. 
-+> -+struct B2bGemm { -+ -+ using B2bMma = B2bMma_; -+ using Epilogue = Epilogue_; -+ using OutputOp0 = typename B2bMma::OutputOp; -+ using OutputOp1 = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static bool const kSplitKSerial = SplitKSerial; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount0 = typename B2bMma::WarpCount0; -+ static int const kThreadCount = 32 * WarpCount0::kCount; -+ -+ /// Parameters structure -+ struct Params { -+ cutlass::gemm::GemmCoord problem_size_0; -+ cutlass::gemm::GemmCoord problem_size_1; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ typename B2bMma::IteratorA0::Params params_A0; -+ typename B2bMma::IteratorA0::TensorRef ref_A0; -+ typename B2bMma::IteratorB0::Params params_B0; -+ typename B2bMma::IteratorB0::TensorRef ref_B0; -+ typename Epilogue::OutputTileIterator::Params params_C0; -+ typename Epilogue::OutputTileIterator::TensorRef ref_C0; -+ typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0; -+ typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0; -+ typename B2bMma::IteratorB1::Params params_B1; -+ typename B2bMma::IteratorB1::TensorRef ref_B1; -+ typename Epilogue::OutputTileIterator::Params params_C1; -+ typename Epilogue::OutputTileIterator::TensorRef ref_C1; -+ typename Epilogue::OutputTileIterator::Params params_D1; -+ typename Epilogue::OutputTileIterator::TensorRef ref_D1; -+ typename OutputOp0::Params output_op_0; -+ typename OutputOp1::Params output_op_1; -+ int *semaphore; -+ int gemm_k_iterations_0; -+ int gemm_k_size_0; -+ int gemm_k_iterations_1; -+ int gemm_k_size_1; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0), -+ gemm_k_iterations_1(0), gemm_k_size_1(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ cutlass::gemm::GemmCoord const & problem_size_0, -+ cutlass::gemm::GemmCoord const & problem_size_1, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape, -+ typename B2bMma::IteratorA0::TensorRef ref_A0, -+ typename B2bMma::IteratorB0::TensorRef ref_B0, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C0, -+ typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0, -+ typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0, -+ typename B2bMma::IteratorB1::TensorRef ref_B1, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C1, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D1, -+ typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(), -+ typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(), -+ int *workspace = nullptr -+ ): -+ problem_size_0(problem_size_0), -+ problem_size_1(problem_size_1), -+ grid_tiled_shape(grid_tiled_shape), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A0(ref_A0.layout()), -+ ref_A0(ref_A0), -+ params_B0(ref_B0.layout()), -+ ref_B0(ref_B0), -+ params_C0(ref_C0.layout()), -+ ref_C0(ref_C0), -+ ref_Scale0(ref_Scale0), -+ ref_Bias0(ref_Bias0), -+ params_B1(ref_B1.layout()), -+ ref_B1(ref_B1), -+ params_C1(ref_C1.layout()), -+ ref_C1(ref_C1), -+ params_D1(ref_D1.layout()), -+ ref_D1(ref_D1), -+ output_op_0(output_op_0), -+ output_op_1(output_op_1) { -+ -+ int total_gemm_k_iterations_0 = (problem_size_0.k() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK; -+ int gemm_k_iterations_0 = (total_gemm_k_iterations_0 + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); -+ gemm_k_size_0 = gemm_k_iterations_0 * 
B2bMma::Shape0::kK; -+ int total_gemm_k_iterations_1 = (problem_size_1.k() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK; -+ int gemm_k_iterations_1 = (total_gemm_k_iterations_1 + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); -+ gemm_k_size_1 = gemm_k_iterations_1 * B2bMma::Shape1::kK; -+ -+ semaphore = workspace; -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename B2bMma::B2bMmaSharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ B2bGemm() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size_0, -+ cutlass::gemm::GemmCoord const & problem_size_1, -+ typename B2bMma::IteratorA0::TensorRef ref_A0, -+ typename B2bMma::IteratorB0::TensorRef ref_B0, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C0, -+ typename B2bMma::IteratorB1::TensorRef ref_B1, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C1, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D1) { -+ -+ static int const kAlignmentA = B2bMma::IteratorA0::AccessType::kElements; -+ static int const kAlignmentB = B2bMma::IteratorB0::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ if (!TensorRef_aligned(ref_A0, kAlignmentA)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_B0, kAlignmentB)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_C0, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_B1, kAlignmentB)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_C1, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_D1, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if ((problem_size_0.m() % kAlignmentA) || (problem_size_0.k() % kAlignmentA) || -+ (problem_size_0.n() % kAlignmentB) || (problem_size_0.k() % kAlignmentB) || -+ (problem_size_0.m() % kAlignmentC) || (problem_size_0.n() % kAlignmentC) || -+ (problem_size_1.m() % kAlignmentA) || (problem_size_1.k() % kAlignmentA) || -+ (problem_size_1.n() % kAlignmentB) || (problem_size_1.k() % kAlignmentB) || -+ (problem_size_1.m() % kAlignmentC) || (problem_size_1.n() % kAlignmentC)) { -+ -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ // Determine if fusion sizes are valid -+ if(problem_size_0.m() != problem_size_1.m()) -+ return Status::kErrorInvalidProblem; -+ -+ if(problem_size_0.n() != problem_size_1.k()) -+ return Status::kErrorInvalidProblem; -+ -+ if(problem_size_0.n() > B2bMma::Shape0::kN) -+ return Status::kErrorInvalidProblem; -+ -+ if(problem_size_1.n() > B2bMma::Shape1::kN) -+ return Status::kErrorInvalidProblem; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A0{ -+ 
threadblock_tile_offset.m() * B2bMma::Shape0::kM, -+ threadblock_tile_offset.k() * params.gemm_k_size_0, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B0{ -+ threadblock_tile_offset.k() * params.gemm_k_size_0, -+ threadblock_tile_offset.n() * B2bMma::Shape0::kN -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B1{ -+ threadblock_tile_offset.k() * params.gemm_k_size_1, -+ threadblock_tile_offset.n() * B2bMma::Shape1::kN -+ }; -+ -+ // Problem size is a function of threadblock index in the K dimension -+ int problem_size_k_0 = min( -+ params.problem_size_0.k(), -+ (threadblock_tile_offset.k() + 1) * params.gemm_k_size_0); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK; -+ -+ // Problem size is a function of threadblock index in the K dimension -+ int problem_size_k_1 = min( -+ params.problem_size_1.k(), -+ (threadblock_tile_offset.k() + 1) * params.gemm_k_size_1); -+ -+ // Compute threadblock-scoped matrix multiply-add -+// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK; -+ -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename B2bMma::IteratorA0 iterator_A0( -+ params.params_A0, -+ params.ref_A0.data(), -+ {params.problem_size_0.m(), problem_size_k_0}, -+ thread_idx, -+ tb_offset_A0); -+ -+ typename B2bMma::IteratorB0 iterator_B0( -+ params.params_B0, -+ params.ref_B0.data(), -+ {problem_size_k_0, params.problem_size_0.n()}, -+ thread_idx, -+ tb_offset_B0); -+ -+ typename B2bMma::IteratorB1 iterator_B1( -+ params.params_B1, -+ params.ref_B1.data(), -+ {problem_size_k_1, params.problem_size_1.n()}, -+ thread_idx, -+ tb_offset_B1); -+ -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. 
-+ int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); -+ int lane_idx = threadIdx.x % 32; -+ -+ // Construct iterators to accumulator scale/bias vector -+ typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0( -+ params.ref_Scale0.data(), -+ {1, params.problem_size_0.n()}, -+ thread_idx, -+ warp_idx, -+ MatrixCoord( -+ 0, threadblock_tile_offset.n() * B2bMma::Shape0::kN -+ ) -+ ); -+ -+ typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0( -+ params.ref_Bias0.data(), -+ {1, params.problem_size_0.n()}, -+ thread_idx, -+ warp_idx, -+ MatrixCoord( -+ 0, threadblock_tile_offset.n() * B2bMma::Shape0::kN -+ ) -+ ); -+ -+ -+ -+ // -+ // Main loop -+ // -+ -+ OutputOp0 output_op_0(params.output_op_0); -+ -+ // Construct thread-scoped matrix multiply -+ B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, params.problem_size_0.n()); -+ -+ typename B2bMma::FragmentC0 src_accum; -+ typename B2bMma::FragmentC1 accumulators; -+ -+ src_accum.clear(); -+ accumulators.clear(); -+ -+ if (!kSplitKSerial || gemm_k_iterations_0 > 0) { -+ // Compute threadblock-scoped matrix multiply-add -+ b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, -+ iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0); -+ } -+ -+ // -+ // Epilogue -+ // -+ -+ OutputOp1 output_op_1(params.output_op_1); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * B2bMma::Shape1::kM, -+ threadblock_tile_offset.n() * B2bMma::Shape1::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C1( -+ params.params_C1, -+ params.ref_C1.data(), -+ params.problem_size_1.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D1( -+ params.params_D1, -+ params.ref_D1.data(), -+ params.problem_size_1.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C1 = iterator_D1; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ __threadfence(); -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. 
-+ epilogue(output_op_1, iterator_D1, accumulators, iterator_C1); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ __threadfence(); -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h -new file mode 100644 -index 0000000..6c54087 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h -@@ -0,0 +1,521 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined Implicit GEMM kernel. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) -+ typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem -+> -+struct B2bImplicitGemmConvolution { -+ -+ using B2bMma = B2bMma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp0 = typename B2bMma::OutputOp; -+ using EpilogueOutputOp1 = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static Operator const kConvolutionalOperator = ConvOperator; -+ -+ using ElementA = typename B2bMma::IteratorA0::Element; -+ using LayoutA = typename B2bMma::IteratorA0::Layout; -+ using ElementB = typename B2bMma::IteratorB0::Element; -+ using LayoutB = typename B2bMma::IteratorB0::Layout; -+ using ElementC = typename EpilogueOutputOp1::ElementOutput; -+ -+ /// Set output tensor C layout -+ using LayoutC = LayoutA; -+ -+ using ElementAccumulator = typename EpilogueOutputOp0::ElementAccumulator; -+ using ElementCompute = typename EpilogueOutputOp0::ElementCompute; -+ -+ /// Scale and Bias -+ using ElementScaleBias = typename B2bMma::IteratorAccumulatorScaleBias::Element; -+ using LayoutScaleBias = typename B2bMma::IteratorAccumulatorScaleBias::Layout; -+ -+ using WarpMmaOperator0 = typename B2bMma::Policy0::Operator; -+ using WarpMmaOperator1 = typename B2bMma::Policy1::Operator; -+ -+ using ArchMmaOperator = typename WarpMmaOperator0::ArchMmaOperator; -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ using OperatorClass = typename WarpMmaOperator0::OperatorClass; -+ using ArchTag = typename WarpMmaOperator0::ArchTag; -+ -+ using ThreadblockShape0 = typename B2bMma::Shape0; -+ using ThreadblockShape1 = typename B2bMma::Shape1; -+ using WarpShape0 = typename WarpMmaOperator0::Shape; -+ using WarpShape1 = typename WarpMmaOperator1::Shape; -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ static int const kStages = B2bMma::kStages; -+ static IteratorAlgorithm const kIteratorAlgorithm = B2bMma::IteratorA0::kIteratorAlgorithm; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount0 = typename B2bMma::WarpCount0; -+ static int const kThreadCount = 32 * WarpCount0::kCount; -+ -+ using TensorRefA0 = typename B2bMma::IteratorA0::TensorRef; -+ using TensorRefB0 = typename B2bMma::IteratorB0::TensorRef; -+ using TensorRefScaleBias0 = typename B2bMma::IteratorAccumulatorScaleBias::TensorRef; -+ using TensorRefB1 = typename B2bMma::IteratorB1::TensorRef; -+ using TensorRefC = cutlass::TensorRef; -+ -+ /// Check iterator A and B 
convolution dimension are the same and -+ // set device::B2bImplicitGemmConvolution::kConvDim -+ static_assert(B2bMma::IteratorA0::kConvDim == B2bMma::IteratorB0::kConvDim, -+ "Convolution on different dimensions is not supported"); -+ static int const kConvDim = B2bMma::IteratorA0::kConvDim; -+ -+ /// Conv dimension and problem size structure (Conv2d or Conv3d) -+ using ConvProblemSize = ConvProblemSize_; -+ -+ /// Wgrad C stride idx for implicit gemm algorithm -+ // Conv2d row-major matrix C (KxRSC) -+ // Conv3d row-major matrix C (KxTRSC) -+ static int const kWgradCStrideIdx = -+ cutlass::platform::is_same::value ? 2 : 3; -+ -+ /// This chooses the appropriate stride element of the C tensor. -+ static int const kTensorCStrideIdx = -+ (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0); -+ -+ // -+ // -+ // -+ using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< -+ LayoutC, -+ typename Epilogue::OutputTileIterator::Layout, -+ TensorRefC, -+ ConvOperator, -+ ConvProblemSize -+ >; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ ConvProblemSize problem_size_0; -+ ConvProblemSize problem_size_1; -+ TensorRefA0 ref_A0; -+ TensorRefB0 ref_B0; -+ TensorRefC ref_C0; -+ TensorRefScaleBias0 ref_Scale0; -+ TensorRefScaleBias0 ref_Bias0; -+ TensorRefB1 ref_B1; -+ TensorRefC ref_C1; -+ TensorRefC ref_D1; -+ typename EpilogueOutputOp0::Params output_op_0; -+ typename EpilogueOutputOp1::Params output_op_1; -+ SplitKMode split_k_mode; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size_0, -+ ConvProblemSize const & problem_size_1 -+ ): -+ problem_size_0(problem_size_0), -+ problem_size_1(problem_size_1) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size_0, -+ ConvProblemSize const & problem_size_1, -+ TensorRefA0 const & ref_A0, -+ TensorRefB0 const & ref_B0, -+ TensorRefC const & ref_C0, -+ TensorRefScaleBias0 const & ref_Scale0, -+ TensorRefScaleBias0 const & ref_Bias0, -+ TensorRefB1 const & ref_B1, -+ TensorRefC const & ref_C1, -+ TensorRefC const & ref_D1, -+ typename EpilogueOutputOp0::Params const & output_op_0, -+ typename EpilogueOutputOp1::Params const & output_op_1, -+ SplitKMode const & split_k_mode = SplitKMode::kSerial -+ ): -+ problem_size_0(problem_size_0), -+ problem_size_1(problem_size_1), -+ ref_A0(ref_A0), -+ ref_B0(ref_B0), -+ ref_C0(ref_C0), -+ ref_Scale0(ref_Scale0), -+ ref_Bias0(ref_Bias0), -+ ref_B1(ref_B1), -+ ref_C1(ref_C1), -+ ref_D1(ref_D1), -+ output_op_0(output_op_0), -+ output_op_1(output_op_1), -+ split_k_mode(split_k_mode) -+ { -+ -+ } -+ -+ }; -+ -+ /// Parameters structure -+ struct Params { -+ ConvProblemSize problem_size_0; -+ ConvProblemSize problem_size_1; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ gemm::GemmCoord implicit_gemm_problem_size_0; -+ gemm::GemmCoord implicit_gemm_problem_size_1; -+ int swizzle_log_tile; -+ int gemm_k_iterations_0; -+ int gemm_k_iterations_1; -+ typename B2bMma::IteratorA0::Params iterator_A0; -+ typename B2bMma::IteratorA0::Element const *ptr_A0; -+ typename B2bMma::IteratorB0::Params iterator_B0; -+ typename B2bMma::IteratorB0::Element const *ptr_B0; -+ typename Epilogue::OutputTileIterator::Params iterator_C0; -+ typename Epilogue::OutputTileIterator::Element *ptr_C0; -+ typename B2bMma::IteratorAccumulatorScaleBias::Element *ptr_Scale0; -+ typename 
B2bMma::IteratorAccumulatorScaleBias::Element *ptr_Bias0; -+ typename B2bMma::IteratorB1::Params iterator_B1; -+ typename B2bMma::IteratorB1::Element const *ptr_B1; -+ typename Epilogue::OutputTileIterator::Params iterator_C1; -+ typename Epilogue::OutputTileIterator::Element *ptr_C1; -+ typename Epilogue::OutputTileIterator::Params iterator_D1; -+ typename Epilogue::OutputTileIterator::Element *ptr_D1; -+ typename EpilogueOutputOp0::Params output_op_0; -+ typename EpilogueOutputOp1::Params output_op_1; -+ int *semaphore; -+ SplitKMode split_k_mode; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): swizzle_log_tile(0), gemm_k_iterations_0(0), gemm_k_iterations_1(0) { } -+ -+ /// -+ CUTLASS_HOST_DEVICE -+ Params( -+ Arguments const &args, -+ int *semaphore = nullptr -+ ): -+ problem_size_0(args.problem_size_0), -+ problem_size_1(args.problem_size_1), -+ implicit_gemm_problem_size_0(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0)), -+ implicit_gemm_problem_size_1(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_1)), -+ iterator_A0(B2bMma::IteratorA0::getParams(args.problem_size_0, args.ref_A0.layout())), -+ ptr_A0(args.ref_A0.data()), -+ iterator_B0(args.problem_size_0, args.ref_B0.layout()), -+ ptr_B0(args.ref_B0.data()), -+ iterator_C0(ConvOutputIteratorParameter::layout(args.ref_C0)), -+ ptr_C0(args.ref_C0.data()), -+ ptr_Scale0(args.ref_Scale0.data()), -+ ptr_Bias0(args.ref_Bias0.data()), -+ iterator_B1(args.problem_size_1, args.ref_B1.layout()), -+ ptr_B1(args.ref_B1.data()), -+ iterator_C1(ConvOutputIteratorParameter::layout(args.ref_C1)), -+ ptr_C1(args.ref_C1.data()), -+ iterator_D1(ConvOutputIteratorParameter::layout(args.ref_D1)), -+ ptr_D1(args.ref_D1.data()), -+ output_op_0(args.output_op_0), -+ output_op_1(args.output_op_1), -+ semaphore(semaphore), -+ split_k_mode(args.split_k_mode) -+ { -+ gemm_k_iterations_0 = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape0::kK, args.problem_size_0); -+ gemm_k_iterations_1 = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape1::kK, args.problem_size_1); -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ implicit_gemm_problem_size_0, -+ {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, -+ args.problem_size_0.split_k_slices); -+ -+ swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape); -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename B2bMma::B2bMmaSharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ B2bImplicitGemmConvolution() { } -+ -+ /// Executes one ImplicitGEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { -+ -+ return; -+ } -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename B2bMma::IteratorA0 iterator_A0( -+ params.iterator_A0, -+ params.problem_size_0, -+ params.ptr_A0, -+ thread_idx, -+ MatrixCoord( 
-+ threadblock_tile_idx.m() * B2bMma::Shape0::kM, -+ threadblock_tile_idx.k() * B2bMma::Shape0::kK -+ ) -+ ); -+ -+ typename B2bMma::IteratorB0 iterator_B0( -+ params.iterator_B0, -+ params.problem_size_0, -+ params.ptr_B0, -+ thread_idx, -+ MatrixCoord( -+ threadblock_tile_idx.k() * B2bMma::Shape0::kK, -+ threadblock_tile_idx.n() * B2bMma::Shape0::kN -+ ) -+ ); -+ -+ typename B2bMma::IteratorB1 iterator_B1( -+ params.iterator_B1, -+ params.problem_size_1, -+ params.ptr_B1, -+ thread_idx, -+ MatrixCoord( -+ threadblock_tile_idx.k() * B2bMma::Shape1::kK, -+ threadblock_tile_idx.n() * B2bMma::Shape1::kN -+ ) -+ ); -+ -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ int lane_idx = threadIdx.x % 32; -+ -+ // Construct iterators to accumulator scale/bias vector -+ typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0( -+ params.ptr_Scale0, -+ {1, params.problem_size_0.K}, -+ thread_idx, -+ warp_idx, -+ MatrixCoord( -+ 0, threadblock_tile_idx.n() * B2bMma::Shape0::kN -+ ) -+ ); -+ -+ typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0( -+ params.ptr_Bias0, -+ {1, params.problem_size_0.K}, -+ thread_idx, -+ warp_idx, -+ MatrixCoord( -+ 0, threadblock_tile_idx.n() * B2bMma::Shape0::kN -+ ) -+ ); -+ -+ -+ // -+ // Main loop -+ // -+ -+ EpilogueOutputOp0 output_op_0(params.output_op_0); -+ -+ // Construct thread-scoped matrix multiply -+ B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename B2bMma::FragmentC0 src_accum; -+ typename B2bMma::FragmentC1 accumulators; -+ -+ src_accum.clear(); -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ b2bMma(params.gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, -+ iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp1 output_op_1(params.output_op_1); -+ -+ // Construct the semaphore. -+ int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); -+ -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // Compute logical position within grid -+ threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. 
-+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op_1.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); -+ } -+ -+ MatrixCoord threadblock_offset( -+ threadblock_tile_idx.m() * B2bMma::Shape1::kM, -+ threadblock_tile_idx.n() * B2bMma::Shape1::kN -+ ); -+ -+ // Tile iterator writing to destination tensor -+ typename Epilogue::OutputTileIterator iterator_D1( -+ params.iterator_D1, -+ params.ptr_D1, -+ ConvOutputIteratorParameter::extent(params.problem_size_1), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator reading from source accumulator tensor -+ typename Epilogue::OutputTileIterator iterator_C1( -+ params.iterator_C1, -+ params.ptr_C1, -+ ConvOutputIteratorParameter::extent(params.problem_size_1), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_idx.k()) { -+ iterator_C1 = iterator_D1; -+ } -+ -+ semaphore.wait(threadblock_tile_idx.k()); -+ -+ __threadfence(); -+ } -+ // Each split-k-slice writes to a unique tensor location -+ else if (params.split_k_mode == SplitKMode::kParallel) { -+ iterator_D1.add_pointer_offset(threadblock_tile_idx.k() * -+ cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size_1)); -+ } -+ -+ // Run efficient epilogue -+ epilogue(output_op_1, iterator_D1, accumulators, iterator_C1); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_idx.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop.h -new file mode 100644 -index 0000000..82e808d ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop.h -@@ -0,0 +1,94 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" -+ -+#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" -+#include "cutlass/transform/threadblock/vector_iterator.h" -+#include "cutlass/transform/warp/vector_fragment_iterator.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "kernel/b2b_implicit_gemm_convolution.h" -+#include "threadblock/b2b_implicit_gemm_pipelined.h" -+#include "threadblock/b2b_implicit_gemm_multistage.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Conv2dFprop -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, -+ bool SmemAccumulator = false -+> struct DefaultB2bConv2dFprop; -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ 
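[Editorial example, not part of the patch above.] The removed header only declares the `DefaultB2bConv2dFprop` primary template, with `IteratorAlgorithm` defaulting to `kAnalytic` and `SmemAccumulator` to `false`; the architecture-specific headers that follow supply its partial specializations. Below is a minimal, self-contained sketch of that dispatch pattern using hypothetical names rather than the CUTLASS types:

```cpp
// Hypothetical illustration only; these names are not CUTLASS types.
#include <iostream>

enum class IterAlgo { kAnalytic, kOptimized };

// Primary template: declared but never defined, so a parameter combination
// with no matching partial specialization fails to compile.
template <typename Element,
          IterAlgo Algo = IterAlgo::kAnalytic,
          bool SmemAccumulator = false>
struct DefaultKernel;

// Selected when Algo == kAnalytic and the accumulator stays in registers.
template <typename Element>
struct DefaultKernel<Element, IterAlgo::kAnalytic, false> {
  static const char *name() { return "analytic, register accumulator"; }
};

// Selected when the accumulator is staged in shared memory.
template <typename Element>
struct DefaultKernel<Element, IterAlgo::kAnalytic, true> {
  static const char *name() { return "analytic, smem accumulator"; }
};

int main() {
  // The alias resolves to the matching partial specialization at compile time.
  using Kernel = DefaultKernel<float, IterAlgo::kAnalytic, true>;
  std::cout << Kernel::name() << "\n";  // prints: analytic, smem accumulator
  return 0;
}
```

Because the primary template is only declared, asking for an unsupported combination produces a compile-time "incomplete type" error, which is the same mechanism the declaration above relies on.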
-+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm75.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm75.h -new file mode 100644 -index 0000000..d5792a8 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm75.h -@@ -0,0 +1,749 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" -+ -+#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" -+#include "cutlass/transform/threadblock/vector_iterator.h" -+#include "cutlass/transform/warp/vector_fragment_iterator.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "kernel/default_b2b_conv2d_fprop.h" -+#include "kernel/b2b_implicit_gemm_convolution.h" -+#include "threadblock/b2b_implicit_gemm_pipelined.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassTensorOp convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm -+/// and 2 stage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA0 -+ > -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ 
cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB0 -+ > -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ // Use fragment iterator for A operand -+ using AccumulatorLayout = cutlass::layout::ColumnMajor; -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB1 -+ > -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmPipelined< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ IteratorB0, -+ SmemIteratorB0, -+ ThreadblockShape1, -+ FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorA1ScaleBias, -+ IteratorB1, -+ SmemIteratorB1, -+ ElementC, -+ LayoutC, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1 -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1 -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage -+/// pipeline with interleaved layout. 
-+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int InterleavedK -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ false -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ 2, MathOperatorTag, true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ 2, MathOperatorTag, true>; -+ -+ // Define iterators over tiles from the A operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, layout::TensorNCxHWx, -+ ThreadMapA0 -+ > -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. 
-+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB0 -+ > -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ // Use fragment iterator for A operand -+ using AccumulatorLayout = cutlass::layout::RowMajor; -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB1 -+ > -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmPipelined< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ IteratorB0, -+ SmemIteratorB0, -+ ThreadblockShape1, -+ FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorA1ScaleBias, -+ IteratorB1, -+ SmemIteratorB1, -+ ElementC, -+ LayoutC, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1 -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm -+/// and 2 stage pipeline. 
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA0 -+ > -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB0 -+ > -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ // Use fragment iterator for A operand -+ using AccumulatorLayout = cutlass::layout::ColumnMajor; -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ // Define iterators over tiles from the B operand -+ using 
ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB1 -+ > -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmPipelined< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ IteratorB0, -+ SmemIteratorB0, -+ ThreadblockShape1, -+ FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorA1ScaleBias, -+ IteratorB1, -+ SmemIteratorB1, -+ ElementC, -+ LayoutC, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1 -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1 -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage -+/// pipeline with interleaved layout. -+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int InterleavedK -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ 2, MathOperatorTag, true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ 2, MathOperatorTag, true>; -+ -+ // Define iterators over tiles from the A operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. 
-+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, layout::TensorNCxHWx, -+ ThreadMapA0 -+ > -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB0 -+ > -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ // Use fragment iterator for A operand -+ using AccumulatorLayout = cutlass::layout::RowMajor; -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB1 -+ > -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmPipelined< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ IteratorB0, -+ SmemIteratorB0, -+ ThreadblockShape1, -+ FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorA1ScaleBias, -+ IteratorB1, -+ SmemIteratorB1, -+ ElementC, -+ LayoutC, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1 -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel 
-+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h -new file mode 100644 -index 0000000..7261e7e ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h -@@ -0,0 +1,740 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" -+ -+#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" -+#include "cutlass/transform/threadblock/vector_iterator.h" -+#include "cutlass/transform/warp/vector_fragment_iterator.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "kernel/default_b2b_conv2d_fprop.h" -+#include "kernel/b2b_implicit_gemm_convolution.h" -+#include "threadblock/b2b_implicit_gemm_multistage.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassTensorOp convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -+/// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA0 -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB0 -+ >; -+ -+ 
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ // Use fragment iterator for A operand -+ using AccumulatorLayout = cutlass::layout::ColumnMajor; -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB1 -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmMultistage< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ arch::CacheOperation::Always, -+ IteratorB0, -+ SmemIteratorB0, -+ arch::CacheOperation::Global, -+ ThreadblockShape1, -+ FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorA1ScaleBias, -+ IteratorB1, -+ SmemIteratorB1, -+ arch::CacheOperation::Global, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -+/// pipeline with interleaved layout. 
-+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int InterleavedK -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ Stages, MathOperatorTag, true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ Stages, MathOperatorTag, true>; -+ -+ // Define iterators over tiles from the A operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, layout::TensorNCxHWx, -+ ThreadMapA0 -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. 
-+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB0 -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ // Use fragment iterator for A operand -+ using AccumulatorLayout = cutlass::layout::RowMajor; -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB1 -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmMultistage< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ arch::CacheOperation::Always, -+ IteratorB0, -+ SmemIteratorB0, -+ arch::CacheOperation::Global, -+ ThreadblockShape1, -+ FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorA1ScaleBias, -+ IteratorB1, -+ SmemIteratorB1, -+ arch::CacheOperation::Global, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and -+/// multistage pipeline. 
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA0 -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB0 -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ // Use fragment iterator for A operand -+ using AccumulatorLayout = cutlass::layout::ColumnMajor; -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using IteratorB1 = -+ 
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB1 -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmMultistage< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ arch::CacheOperation::Always, -+ IteratorB0, -+ SmemIteratorB0, -+ arch::CacheOperation::Global, -+ ThreadblockShape1, -+ FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorA1ScaleBias, -+ IteratorB1, -+ SmemIteratorB1, -+ arch::CacheOperation::Global, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and -+// multistage pipeline with interleaved layout. -+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int InterleavedK -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ Stages, MathOperatorTag, true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ Stages, MathOperatorTag, true>; -+ -+ // Define iterators over tiles from the A operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. 
-+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, layout::TensorNCxHWx, -+ ThreadMapA0 -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB0 -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ // Use fragment iterator for A operand -+ using AccumulatorLayout = cutlass::layout::RowMajor; -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB1 -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmMultistage< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ arch::CacheOperation::Always, -+ IteratorB0, -+ SmemIteratorB0, -+ arch::CacheOperation::Global, -+ ThreadblockShape1, -+ FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorA1ScaleBias, -+ IteratorB1, -+ SmemIteratorB1, -+ arch::CacheOperation::Global, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, 
-+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h -new file mode 100644 -index 0000000..09d094f ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h -@@ -0,0 +1,817 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" -+ -+#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" -+#include "cutlass/transform/threadblock/vector_iterator.h" -+#include "cutlass/transform/warp/vector_fragment_iterator.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "kernel/default_b2b_conv2d_fprop.h" -+#include "kernel/b2b_implicit_gemm_convolution.h" -+#include "threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm -+/// and 2 stage pipeline. -+/// Accumulator will be staged in shared memory. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ true -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA0 -+ > -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB0 -+ > -+ 
>; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB1 -+ > -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::RowMajor; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ ElementC, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ IteratorB0, -+ SmemIteratorB0, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, -+ SmemIteratorD0, -+ ThreadblockShape1, -+ WarpIteratorA1, -+ IteratorB1, -+ SmemIteratorB1, -+ ElementC, -+ LayoutC, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1 -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1 -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage -+/// pipeline with interleaved layout. -+/// Accumulator will be staged in shared memory. 
-+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int InterleavedK -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ true -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ 2, MathOperatorTag, true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ 2, MathOperatorTag, true>; -+ -+ // Define iterators over tiles from the A operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, layout::TensorNCxHWx, -+ ThreadMapA0 -+ > -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. 
-+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB0 -+ > -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; //For interleaved layout -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB1 -+ > -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ ElementC, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ IteratorB0, -+ SmemIteratorB0, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, -+ SmemIteratorD0, -+ ThreadblockShape1, -+ WarpIteratorA1, -+ IteratorB1, -+ SmemIteratorB1, -+ ElementC, -+ LayoutC, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1 -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm -+/// and 2 stage 
pipeline. -+/// Accumulator will be staged in shared memory. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ true -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA0 -+ > -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB0 -+ > -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB1 -+ > -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Use fragment iterator for the 
accumulator -+ using SmemAccumulatorLayout = cutlass::layout::RowMajor; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ ElementC, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ IteratorB0, -+ SmemIteratorB0, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, -+ SmemIteratorD0, -+ ThreadblockShape1, -+ WarpIteratorA1, -+ IteratorB1, -+ SmemIteratorB1, -+ ElementC, -+ LayoutC, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1 -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1 -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage -+/// pipeline with interleaved layout. -+/// Accumulator will be staged in shared memory. 
-+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int InterleavedK -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ true -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ 2, MathOperatorTag, true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ 2, MathOperatorTag, true>; -+ -+ // Define iterators over tiles from the A operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, layout::TensorNCxHWx, -+ ThreadMapA0 -+ > -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. 
-+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB0 -+ > -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; //For interleaved layout -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB1 -+ > -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ ElementC, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ IteratorB0, -+ SmemIteratorB0, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, -+ SmemIteratorD0, -+ ThreadblockShape1, -+ WarpIteratorA1, -+ IteratorB1, -+ SmemIteratorB1, -+ ElementC, -+ LayoutC, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1 -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h -new file mode 100644 -index 0000000..7a5b380 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h -@@ -0,0 +1,804 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" -+ -+#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" -+#include "cutlass/transform/threadblock/vector_iterator.h" -+#include "cutlass/transform/warp/vector_fragment_iterator.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "kernel/default_b2b_conv2d_fprop.h" -+#include "kernel/b2b_implicit_gemm_convolution.h" -+#include "threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -+/// pipeline. -+/// Accumulator will be staged in shared memory. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ true -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA0 -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB0 -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; 
-+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB1 -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::RowMajor; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ ElementC, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ arch::CacheOperation::Always, -+ IteratorB0, -+ SmemIteratorB0, -+ arch::CacheOperation::Global, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, -+ SmemIteratorD0, -+ ThreadblockShape1, -+ WarpIteratorA1, -+ IteratorB1, -+ SmemIteratorB1, -+ arch::CacheOperation::Global, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -+/// pipeline with interleaved layout. -+/// Accumulator will be staged in shared memory. 
-+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int InterleavedK -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ true -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ Stages, MathOperatorTag, true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ Stages, MathOperatorTag, true>; -+ -+ // Define iterators over tiles from the A operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, layout::TensorNCxHWx, -+ ThreadMapA0 -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. 
-+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB0 -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB1 -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ ElementC, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ arch::CacheOperation::Always, -+ IteratorB0, -+ SmemIteratorB0, -+ arch::CacheOperation::Global, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, -+ SmemIteratorD0, -+ ThreadblockShape1, -+ WarpIteratorA1, -+ IteratorB1, -+ SmemIteratorB1, -+ arch::CacheOperation::Global, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and -+/// multistage pipeline. -+/// Accumulator will be staged in shared memory. 
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ true -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA0 -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB0 -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB1 -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::RowMajor; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, 
InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ ElementC, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ arch::CacheOperation::Always, -+ IteratorB0, -+ SmemIteratorB0, -+ arch::CacheOperation::Global, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, -+ SmemIteratorD0, -+ ThreadblockShape1, -+ WarpIteratorA1, -+ IteratorB1, -+ SmemIteratorB1, -+ arch::CacheOperation::Global, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and -+/// multistage pipeline with interleaved layout. -+/// Accumulator will be staged in shared memory.
-+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int InterleavedK -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ true -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ Stages, MathOperatorTag, true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ Stages, MathOperatorTag, true>; -+ -+ // Define iterators over tiles from the A operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, layout::TensorNCxHWx, -+ ThreadMapA0 -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. 
-+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB0 -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB1 -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ ElementC, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ arch::CacheOperation::Always, -+ IteratorB0, -+ SmemIteratorB0, -+ arch::CacheOperation::Global, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, -+ SmemIteratorD0, -+ ThreadblockShape1, -+ WarpIteratorA1, -+ IteratorB1, -+ SmemIteratorB1, -+ arch::CacheOperation::Global, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git 
a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h -new file mode 100644 -index 0000000..05c3f4e ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h -@@ -0,0 +1,442 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial -+ specializations here choose 'device::GemmTransposed' to implement this functionality. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_pipelined.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+ -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#include "kernel/b2b_gemm.h" -+#include "threadblock/default_b2b_mma.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp0, -+ /// Epilogue output operator -+ typename EpilogueOutputOp1, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Stage accumulator in shared memory -+ bool SmemAccumulator = false -+> -+struct DefaultB2bGemm; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+
typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp0, -+ /// Epilogue output operator -+ typename EpilogueOutputOp1, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultB2bGemm { -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma; -+ -+ static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using B2bGemmKernel = kernel::B2bGemm; -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Turing Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp0, -+ /// Epilogue output operator -+ typename EpilogueOutputOp1, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator -+> -+struct DefaultB2bGemm< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC, layout::RowMajor, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ 
InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ 2, -+ SplitKSerial, -+ Operator -+> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ layout::RowMajor, -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ 2, -+ Operator, -+ EpilogueOutputOp0 -+ >::ThreadblockB2bMma; -+ -+ static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape1, -+ typename B2bMma::Operator1, -+ kPartitionsK1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using B2bGemmKernel = kernel::B2bGemm; -+}; -+ -+ -+/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp0, -+ /// Epilogue output operator -+ typename EpilogueOutputOp1, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Number of Interleaved k -+ int InterleavedK, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultB2bGemm< -+ ElementA, layout::ColumnMajorInterleaved, kAlignmentA, -+ ElementB, layout::RowMajorInterleaved, kAlignmentB, -+ ElementC, layout::ColumnMajorInterleaved, int32_t, -+ arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, -+ ThreadblockSwizzle, Stages, -+ SplitKSerial, Operator> { -+ using LayoutA = layout::ColumnMajorInterleaved; -+ using LayoutB = layout::RowMajorInterleaved; -+ using LayoutC = layout::ColumnMajorInterleaved; -+ -+ using ElementAccumulator = int32_t; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, Stages, Operator, EpilogueOutputOp0, -+ true>::ThreadblockB2bMma; -+ -+ static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename 
cutlass::epilogue::threadblock:: -+ DefaultInterleavedEpilogueTensorOp< -+ ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, -+ 64 / sizeof_bits::value, InterleavedK>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using B2bGemmKernel = kernel::B2bGemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Partial specialization for Turing Integer Tensor Core Interleaved layout -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp0, -+ /// Epilogue output operator -+ typename EpilogueOutputOp1, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of Interleaved k -+ int InterleavedK, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultB2bGemm, -+ kAlignmentA, ElementB, -+ layout::RowMajorInterleaved, kAlignmentB, -+ ElementC, layout::ColumnMajorInterleaved, -+ int32_t, arch::OpClassTensorOp, arch::Sm75, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, -+ ThreadblockSwizzle, 2, SplitKSerial, Operator> { -+ using LayoutA = layout::ColumnMajorInterleaved; -+ using LayoutB = layout::RowMajorInterleaved; -+ using LayoutC = layout::ColumnMajorInterleaved; -+ -+ using ElementAccumulator = int32_t; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, -+ arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1, -+ WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true>::ThreadblockB2bMma; -+ -+ static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; -+ -+ /// Define the epilogue for the 2nd Gemm -+ using Epilogue = typename cutlass::epilogue::threadblock:: -+ DefaultInterleavedEpilogueTensorOp< -+ ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, -+ 64 / sizeof_bits::value, InterleavedK>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. 
-+  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
-+};
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+} // namespace kernel
-+} // namespace gemm
-+} // namespace cutlass
-diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h
-new file mode 100644
-index 0000000..23717c6
---- /dev/null
-+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h
-@@ -0,0 +1,397 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+
-+/*! \file
-+ \brief
-+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
-+ the appropriate threadblock-scoped epilogue.
-+
-+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
-+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial
-+ specializations here choose 'device::GemmTransposed' to implement this functionality.
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_pipelined.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+ -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/vector_iterator.h" -+#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" -+ -+#include "kernel/b2b_gemm.h" -+#include "threadblock/default_b2b_mma.h" -+#include "threadblock/default_b2b_mma_smem_accumulator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp0, -+ /// Epilogue output operator -+ typename EpilogueOutputOp1, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultB2bGemm { -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, Stages, Operator, EpilogueOutputOp0, false, true>::ThreadblockB2bMma; -+ -+ static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename 
cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using B2bGemmKernel = kernel::B2bGemm; -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Turing Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp0, -+ /// Epilogue output operator -+ typename EpilogueOutputOp1, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator -+> -+struct DefaultB2bGemm< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC, layout::RowMajor, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ 2, -+ SplitKSerial, -+ Operator, -+ true -+> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ layout::RowMajor, -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ 2, -+ Operator, -+ EpilogueOutputOp0, -+ false, -+ true -+ >::ThreadblockB2bMma; -+ -+ static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape1, -+ typename B2bMma::Operator1, -+ kPartitionsK1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. 
-+ using B2bGemmKernel = kernel::B2bGemm; -+}; -+ -+ -+/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp0, -+ /// Epilogue output operator -+ typename EpilogueOutputOp1, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Number of Interleaved k -+ int InterleavedK, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultB2bGemm< -+ ElementA, layout::ColumnMajorInterleaved, kAlignmentA, -+ ElementB, layout::RowMajorInterleaved, kAlignmentB, -+ ElementC, layout::ColumnMajorInterleaved, int32_t, -+ arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, -+ ThreadblockSwizzle, Stages, -+ SplitKSerial, Operator, true> { -+ using LayoutA = layout::ColumnMajorInterleaved; -+ using LayoutB = layout::RowMajorInterleaved; -+ using LayoutC = layout::ColumnMajorInterleaved; -+ -+ using ElementAccumulator = int32_t; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, Stages, Operator, EpilogueOutputOp0, -+ true, true>::ThreadblockB2bMma; -+ -+ static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock:: -+ DefaultInterleavedEpilogueTensorOp< -+ ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, -+ 64 / sizeof_bits::value, InterleavedK>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. 
-+ using B2bGemmKernel = kernel::B2bGemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Partial specialization for Turing Integer Tensor Core Interleaved layout -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp0, -+ /// Epilogue output operator -+ typename EpilogueOutputOp1, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of Interleaved k -+ int InterleavedK, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultB2bGemm, -+ kAlignmentA, ElementB, -+ layout::RowMajorInterleaved, kAlignmentB, -+ ElementC, layout::ColumnMajorInterleaved, -+ int32_t, arch::OpClassTensorOp, arch::Sm75, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, -+ ThreadblockSwizzle, 2, SplitKSerial, Operator, true> { -+ using LayoutA = layout::ColumnMajorInterleaved; -+ using LayoutB = layout::RowMajorInterleaved; -+ using LayoutC = layout::ColumnMajorInterleaved; -+ -+ using ElementAccumulator = int32_t; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, 2, Operator, EpilogueOutputOp0, true, true>::ThreadblockB2bMma; -+ -+ static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; -+ -+ /// Define the epilogue for the 2nd Gemm -+ using Epilogue = typename cutlass::epilogue::threadblock:: -+ DefaultInterleavedEpilogueTensorOp< -+ ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, -+ 64 / sizeof_bits::value, InterleavedK>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. 
-+  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
-+};
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+} // namespace kernel
-+} // namespace gemm
-+} // namespace cutlass
-diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h
-new file mode 100644
-index 0000000..eef9d9a
---- /dev/null
-+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h
-@@ -0,0 +1,275 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/* \file
-+ \brief Defines device-side elementwise operations on TensorView. Note, the operations defined
-+ in this header are not specialized for any particular data layout and are therefore not
-+ intended to offer the best possible performance. Rather, they are intended to be generic
-+ reference implementations to support the CUTLASS unit tests.
-+*/ -+ -+#pragma once -+ -+// Cutlass includes -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_view.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace kernel { -+ -+template < -+ typename TensorRefIn, ///< Input TensorRef Type -+ typename TensorRefOut, ///< Output TensorRef Type -+ typename ScalarType, ///< alpha Type -+ typename TensorRefScalar, ///< Scale/Bias TensorRef Type -+ typename OutputTile, -+ typename ConvertOp = NumericConverter -+> -+__global__ void TensorScaleBiasGemm( -+ gemm::GemmCoord problem_size, -+ TensorRefIn tensor_in, ///< input tensor -+ TensorRefOut tensor_out, ///< output tensor -+ ScalarType alpha, ///< alpha -+ TensorRefScalar tensor_scale, ///< scale tensor -+ TensorRefScalar tensor_bias ///< bias tensor -+) { -+ -+ ConvertOp convert_op; -+ -+ MatrixCoord output_coord( -+ MatrixCoord::Index((threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow), -+ MatrixCoord::Index((threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn) -+ ); -+ -+ // Update the output tensor -+ for (int j = 0; j < OutputTile::kRow; ++j) { -+ for (int i = 0; i < OutputTile::kColumn; ++i) { -+ MatrixCoord coord = output_coord + MatrixCoord(i, j); -+ if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) { -+ -+ ScalarType scale = alpha; -+ if(tensor_scale.good()) -+ scale = tensor_scale.at({0, coord.column()}); -+ -+ ScalarType bias = ScalarType(0); -+ -+ if(tensor_bias.good()) -+ bias = tensor_bias.at({0, coord.column()}); -+ -+ tensor_out.at(coord) = convert_op( -+ scale * ScalarType(tensor_in.at(coord)) + bias); -+ } -+ } -+ } -+} -+ -+template < -+ typename TensorRefIn, ///< Input TensorRef Type -+ typename TensorRefOut, ///< Output TensorRef Type -+ typename ScalarType, ///< alpha Type -+ typename TensorRefScalar, ///< Scale/Bias TensorRef Type -+ typename ConvertOp = NumericConverter, -+ int kThreadM = 4, // shape of a thread's tile in the GEMM M dimension -+ int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension -+ int kCtaShapeM = 16, // shape of a threadblock in units of threads -+ int kCtaShapeN = 8 // shape of a threadblock in units of threads -+> -+__global__ void TensorScaleBiasConv2d( -+ conv::Conv2dProblemSize problem_size, -+ TensorRefIn tensor_in, ///< input tensor -+ TensorRefOut tensor_out, ///< output tensor -+ ScalarType alpha, ///< alpha -+ TensorRefScalar tensor_scale, ///< scale tensor -+ TensorRefScalar tensor_bias ///< bias tensor -+) { -+ -+ ConvertOp convert_op; -+ -+ int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; -+ int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; -+ -+ int thread_n[kThreadM]; -+ int thread_p[kThreadM]; -+ int thread_q[kThreadM]; -+ -+ // Compute N, P, Q coordinates for each row of a thread's tile -+ int64_t PQ = int64_t(problem_size.P) * problem_size.Q; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ -+ int64_t npq = npq_start + m; -+ -+ thread_n[m] = int(npq / PQ); -+ -+ int64_t residual = npq % PQ; -+ thread_p[m] = int(residual / problem_size.Q); -+ thread_q[m] = int(residual % problem_size.Q); -+ } -+ -+ // Write out the results -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ if (thread_n[m] < 
problem_size.N && thread_p[m] < problem_size.P && thread_q[m] < problem_size.Q) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ int thread_k = k_start + n; -+ if (thread_k < problem_size.K) { -+ -+ ScalarType scale = alpha; -+ if(tensor_scale.good()) -+ scale = tensor_scale.at({0, thread_k}); -+ -+ ScalarType bias = ScalarType(0); -+ if(tensor_bias.good()) -+ bias = tensor_bias.at({0, thread_k}); -+ -+ tensor_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op( -+ scale * ScalarType( -+ tensor_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) -+ ) + bias); -+ } -+ } -+ } -+ } -+ -+} -+ -+} -+ -+/// Apply scale and bias on a tensor -+template < -+ typename ElementIn, ///< Input Type -+ typename ElementOut, ///< Output Type -+ typename Layout, ///< Layout of input/output tensor -+ typename ScalarType, ///< alpha Type -+ typename LayoutScaleBias, ///< Layout of scale and bias -+ typename ConvertOp = NumericConverter -+> -+void TensorScaleBiasGemm( -+ gemm::GemmCoord problem_size, -+ TensorRef tensor_in, ///< input tensor -+ TensorRef tensor_out, ///< output tensor -+ ScalarType alpha, ///< alpha -+ TensorRef tensor_scale, ///< scale tensor -+ TensorRef tensor_bias ///< bias tensor -+) { -+ -+ using OutputTile = MatrixShape<4, 4>; -+ -+ dim3 block(16, 8); -+ -+ dim3 grid( -+ (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), -+ (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn) -+ ); -+ -+ kernel::TensorScaleBiasGemm< -+ TensorRef, -+ TensorRef, -+ ScalarType, -+ TensorRef, -+ OutputTile, -+ ConvertOp -+ ><<< grid, block >>> ( -+ problem_size, -+ tensor_in, -+ tensor_out, -+ alpha, -+ tensor_scale, -+ tensor_bias -+ ); -+} -+ -+/// Apply scale and bias on a tensor -+template < -+ typename ElementIn, ///< Input Type -+ typename ElementOut, ///< Output Type -+ typename Layout, ///< Layout of input/output tensor -+ typename ScalarType, ///< alpha Type -+ typename LayoutScaleBias, ///< Layout of scale and bias -+ typename ConvertOp = NumericConverter -+> -+void TensorScaleBiasConv2d( -+ conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_in, ///< input tensor -+ TensorRef tensor_out, ///< output tensor -+ ScalarType alpha, ///< alpha -+ TensorRef tensor_scale, ///< scale tensor -+ TensorRef tensor_bias ///< bias tensor -+) { -+ -+ int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension -+ int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension -+ int const kCtaShapeM = 16; // shape of a threadblock in units of threads -+ int const kCtaShapeN = 8; // shape of a threadblock in units of threads -+ -+ int64_t npq = int64_t(problem_size.N) * problem_size.P * problem_size.Q; -+ int64_t blocks_m = (npq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); -+ -+ dim3 block(kCtaShapeM, kCtaShapeN); -+ dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); -+ -+ -+ kernel::TensorScaleBiasConv2d< -+ TensorRef, -+ TensorRef, -+ ScalarType, -+ TensorRef, -+ ConvertOp, -+ kThreadM, -+ kThreadN, -+ kCtaShapeM, -+ kCtaShapeN -+ ><<< grid, block >>> ( -+ problem_size, -+ tensor_in, -+ tensor_out, -+ alpha, -+ tensor_scale, -+ tensor_bias -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace 
reference
-+} // namespace cutlass
-diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/test_run.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/test_run.h
-new file mode 100644
-index 0000000..b64f31f
---- /dev/null
-+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/test_run.h
-@@ -0,0 +1,95 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+
-+
-+#include <iostream>
-+
-+// Run tests on GPUs
-+
-+int testRun(int arch, std::vector<bool (*)()> & test_funcs, const std::string & test_name) {
-+
-+  bool supported = false;
-+
-+  int arch_major = arch / 10;
-+  int arch_minor = arch - arch / 10 * 10;
-+
-+  if(arch_major >= 8) {
-+    // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
-+    //
-+    // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples.
-+    if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) {
-+      supported = true;
-+    }
-+  }
-+  else if(arch_major >= 7) {
-+    // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
-+    //
-+    // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
-+    if (__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) {
-+      supported = true;
-+    }
-+  }
-+
-+  cudaDeviceProp props;
-+
-+  cudaError_t error = cudaGetDeviceProperties(&props, 0);
-+  if (error != cudaSuccess) {
-+    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
-+    return -1;
-+  }
-+
-+  if (!(props.major == arch_major && props.minor == arch_minor)) {
-+    supported = false;
-+  }
-+
-+  if (!supported) {
-+    // Returning zero so this test passes on older Toolkits. Its actions are no-op.
-+    std::cout << "This example isn't supported on current architecture" << std::endl;
-+    return 0;
-+  }
-+
-+  bool pass = true;
-+
-+  std::cout << "Device: " << props.name << std::endl;
-+  std::cout << "Arch: SM" << arch << std::endl;
-+  std::cout << "Test: " << test_name << std::endl;
-+  for(auto func : test_funcs) {
-+    pass &= func();
-+  }
-+
-+
-+  if(pass)
-+    return 0;
-+  else
-+    return -1;
-+
-+}
-+
-diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage.h
-new file mode 100644
-index 0000000..4e154f5
---- /dev/null
-+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage.h
-@@ -0,0 +1,831 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*! \file
-+ \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel.
-+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/cache_operation.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "threadblock/b2b_mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape0_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA0_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA0_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA0, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB0_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB0_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB0, -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape1_, -+ /// Iterates over the intermediate accumulator tile -+ // (concept::MmaTensorOpFragmentIterator) -+ typename FragmentIteratorA1_, -+ /// Iterates over vectors of scale and bias vector in global memory -+ // (concept: VectorIterator) -+ typename IteratorAccumulatorScaleBias_, -+ /// WarpIterator to load Scale or Bias vector from threadblock fragment -+ typename FragmentIteratorA1ScaleBias_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB1_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB1_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB1, -+ /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) 
-+ typename OutputOp_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy0_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy1_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class B2bImplicitGemmMultistage : -+ public gemm::threadblock::B2bMmaBase { -+public: -+ ///< Base class -+ using Base = gemm::threadblock::B2bMmaBase; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape0 = Shape0_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA0 = IteratorA0_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB0 = IteratorB0_; -+ ///< Policy describing tuning details -+ using Policy0 = Policy0_; -+ -+ using SmemIteratorA0 = SmemIteratorA0_; -+ using SmemIteratorB0 = SmemIteratorB0_; -+ -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape1 = Shape1_; -+ ///< Iterates over tiles of A operand in global memory -+ using FragmentIteratorA1 = FragmentIteratorA1_; -+ ///< Iterates over tiles of the scale and bias vectors in global memory -+ using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; -+ ///< WarpIterator to load Scale or Bias vector from threadblock fragment -+ using FragmentIteratorA1ScaleBias = FragmentIteratorA1ScaleBias_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB1 = IteratorB1_; -+ ///< Policy describing tuning details -+ using Policy1 = Policy1_; -+ -+ using SmemIteratorB1 = SmemIteratorB1_; -+ -+ ///< Epilogue after 1st Gemm -+ using OutputOp = OutputOp_; -+ -+ static const bool PerChannelScale = (OutputOp::kScale == -+ epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; -+ -+ // -+ // Dependent types -+ // -+ -+ using ElementC = typename Policy0::Operator::ElementC; -+ -+ /// Fragment of accumulator tile -+ using FragmentC0 = typename Policy0::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator0 = typename Policy0::Operator; -+ -+ /// Fragment of Scale and Bias loaded from global memory -+ using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC1 = typename Policy1::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator1 = typename Policy1::Operator; -+ -+ /// Internal structure exposed for introspection. 
-+ struct Detail { -+ -+ static_assert(Base::kWarpGemmIterations0 > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ static_assert(Base::kWarpGemmIterations1 > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA0 = -+ IteratorA0::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB0 = -+ IteratorB0::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB1 = -+ IteratorB1::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA0 = -+ (AsyncCopyIterationsPerStageA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB0 = -+ (AsyncCopyIterationsPerStageB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB1 = -+ (AsyncCopyIterationsPerStageB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA0 = typename Operator0::FragmentA; -+ using WarpLoadedFragmentB0 = typename Operator0::FragmentB; -+ /// Warp Fragment of operand A1 loaded from accmulator tile -+ using WarpLoadedFragmentA1 = typename FragmentIteratorA1::Fragment; -+ using WarpLoadedFragmentA1ScaleBias = -+ typename FragmentIteratorA1ScaleBias::Fragment; -+ using WarpLoadedFragmentB1 = typename Operator1::FragmentB; -+ using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA; -+ using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; -+ using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; -+ using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA0 smem_iterator_A0_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB0 smem_iterator_B0_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB1 smem_iterator_B1_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ B2bImplicitGemmMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::B2bMmaSharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), thread_idx), -+ smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx), -+ smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the 
warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations0 * warp_idx_k}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {Base::kWarpGemmIterations0 * warp_idx_k, warp_idx_n}); -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {Base::kWarpGemmIterations1 * warp_idx_k, warp_idx_n}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance_0( -+ IteratorA0 &iterator_A0, IteratorB0 &iterator_B0, -+ int group_start_A0 = 0, int group_start_B0 = 0) { -+ -+ iterator_A0.set_iteration_index(group_start_A0); -+ this->smem_iterator_A0_.set_iteration_index(group_start_A0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) { -+ -+ if (group_start_A0 + j < Detail::AsyncCopyIterationsPerStageA0) { -+ typename IteratorA0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A0_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA0::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_A0.get(), iterator_A0.valid()); -+ -+ ++iterator_A0; -+ -+ ++this->smem_iterator_A0_; -+ } -+ } -+ -+ iterator_B0.set_iteration_index(group_start_B0); -+ -+ this->smem_iterator_B0_.set_iteration_index(group_start_B0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) { -+ if (group_start_B0 + j < Detail::AsyncCopyIterationsPerStageB0) { -+ typename IteratorB0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B0_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB0::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_B0.get(), iterator_B0.valid()); -+ -+ ++iterator_B0; -+ ++this->smem_iterator_B0_; -+ } -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance_1( -+ IteratorB1 &iterator_B1, -+ int group_start_B1 = 0) { -+ -+ iterator_B1.set_iteration_index(group_start_B1); -+ -+ this->smem_iterator_B1_.set_iteration_index(group_start_B1); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { -+ if (group_start_B1 + j < Detail::AsyncCopyIterationsPerStageB1) { -+ typename IteratorB1::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B1_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB1::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_B1.get(), iterator_B1.valid()); -+ -+ ++iterator_B1; -+ ++this->smem_iterator_B1_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations_0, -+ ///< destination accumulator tile -+ FragmentC1 &accum, -+ ///< iterator over A0 operand in global memory -+ IteratorA0 iterator_A0, -+ ///< iterator over B0 operand in global memory -+ IteratorB0 iterator_B0, -+ ///< iterator over A1 operand scale vector in global memory -+ IteratorAccumulatorScaleBias 
iterator_A1_scale, -+ ///< iterator over A1 operand bias vector in global memory -+ IteratorAccumulatorScaleBias iterator_A1_bias, -+ ///< iterator over B1 operand in global memory -+ IteratorB1 iterator_B1, -+ ///< initial value of accumulator -+ FragmentC0 const &src_accum, -+ ///< epilogue operation after 1st Gemm -+ OutputOp output_op_0, -+ ///< Imaginary strides used for planar-complex only - ignored here -+ int64_t imag_stride_A = 0, -+ int64_t imag_stride_B = 0) { -+ -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations_0) { -+ -+ iterator_A0.set_iteration_index(0); -+ this->smem_iterator_A0_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA0; ++j) { -+ typename IteratorA0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A0_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA0::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_A0.get(), iterator_A0.valid()); -+ -+ ++iterator_A0; -+ ++this->smem_iterator_A0_; -+ } -+ -+ iterator_B0.set_iteration_index(0); -+ this->smem_iterator_B0_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB0; ++j) { -+ typename IteratorB0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B0_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB0::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_B0.get(), iterator_B0.valid()); -+ -+ ++iterator_B0; -+ ++this->smem_iterator_B0_; -+ } -+ -+ // Move to the next stage -+ iterator_A0.advance(); -+ iterator_B0.advance(); -+ -+ this->smem_iterator_A0_.add_tile_offset({0, 1}); -+ this->smem_iterator_B0_.add_tile_offset({1, 0}); -+ -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ FragmentC0 accum0 = src_accum; -+ -+ // Waits until kStages-2 stages have committed. 
-+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA0 warp_loaded_frag_A0[2]; -+ WarpLoadedFragmentB0 warp_loaded_frag_B0[2]; -+ WarpTransformedFragmentA0 warp_transformed_frag_A0[2]; -+ WarpTransformedFragmentB0 warp_transformed_frag_B0[2]; -+ -+ Operator0 warp_mma0; -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index(0); -+ this->warp_tile_iterator_B0_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]); -+ this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ // Start issuing the first group of the next stage outside of the mainloop -+ copy_tiles_and_advance_0(iterator_A0, iterator_B0); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0], -+ warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]); -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations_0 > (-Base::kStages + 1);) { -+ -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ -+ this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ if (warp_mma_k > 0) -+ warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2], -+ warp_transformed_frag_B0[warp_mma_k % 2], -+ warp_loaded_frag_A0[warp_mma_k % 2], -+ warp_loaded_frag_B0[warp_mma_k % 2]); -+ -+ // Issue global->shared copies for the next stage -+ int group_start_iteration_A0, group_start_iteration_B0; -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations0) { -+ group_start_iteration_A0 = 0; -+ group_start_iteration_B0 = 0; -+ } else { -+ group_start_iteration_A0 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA0; -+ group_start_iteration_B0 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB0; -+ } -+ -+ copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, -+ group_start_iteration_B0); -+ -+ warp_mma0( -+ accum0, -+ warp_transformed_frag_A0[warp_mma_k % 2], -+ warp_transformed_frag_B0[warp_mma_k % 2], -+ accum0 -+ ); -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations0) -+ warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B0[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A0[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations0) { -+ // Inserts a fence to group cp.async instructions into stages. 
-+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages of cp.async have committed -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A0.advance(); -+ iterator_B0.advance(); -+ -+ this->smem_iterator_A0_.add_tile_offset({0, 1}); -+ this->smem_iterator_B0_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {0, -Base::kStages * Policy0::kPartitionsK * -+ Base::kWarpGemmIterations0}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {-Base::kStages * Policy0::kPartitionsK * -+ Base::kWarpGemmIterations0, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations_0; -+ } -+ } -+ -+ } -+ -+ // Insert fence and wait for all outstanding cp.async operations to commit. -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ -+ -+ // 2nd Implicit Gemm -+ -+ /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile -+ FragmentIteratorA1 warp_tile_iterator_A1_(accum0); -+ FragmentA1ScaleBias tb_frag_A1_scale; -+ FragmentA1ScaleBias tb_frag_A1_bias; -+ FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale); -+ FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias); -+ -+ if(PerChannelScale) { -+ tb_frag_A1_scale.clear(); -+ iterator_A1_scale.load(tb_frag_A1_scale); -+ ++iterator_A1_scale; -+ } -+ tb_frag_A1_bias.clear(); -+ iterator_A1_bias.load(tb_frag_A1_bias); -+ ++iterator_A1_bias; -+ -+ -+ // -+ // Prologue -+ // -+ int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1; -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations_1) { -+ -+ iterator_B1.set_iteration_index(0); -+ this->smem_iterator_B1_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB1; ++j) { -+ typename IteratorB1::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B1_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB1::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_B1.get(), iterator_B1.valid()); -+ -+ ++iterator_B1; -+ ++this->smem_iterator_B1_; -+ } -+ -+ // Move to the next stage -+ iterator_B1.advance(); -+ -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Waits until kStages-2 stages have committed. 
-+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; -+ WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_scale[2]; -+ WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_bias[2]; -+ WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; -+ WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; -+ WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; -+ -+ Operator1 warp_mma1; -+ -+ if(PerChannelScale) { -+ warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]); -+ ++warp_tile_iterator_A1_scale_; -+ } -+ warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[0]); -+ ++warp_tile_iterator_A1_bias_; -+ -+ warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0], -+ warp_loaded_frag_A1_scale[0], -+ warp_loaded_frag_A1_bias[0], -+ output_op_0); -+ ++warp_tile_iterator_A1_; -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index(0); -+ this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); -+ ++this->warp_tile_iterator_B1_; -+ -+ // Start issuing the first group of the next stage outside of the mainloop -+ copy_tiles_and_advance_1(iterator_B1); -+ -+ smem_write_stage_idx = Base::kStages - 1; -+ smem_read_stage_idx = 0; -+ -+ warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0], -+ warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]); -+ -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1 - (Base::kStages - 1); -+ gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; -+ ++warp_mma_k) { -+ -+ // Load threadblock-level scale/bias vector from global memory -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations1) { -+ if(PerChannelScale) { -+ tb_frag_A1_scale.clear(); -+ iterator_A1_scale.load(tb_frag_A1_scale); -+ ++iterator_A1_scale; -+ } -+ tb_frag_A1_bias.clear(); -+ iterator_A1_bias.load(tb_frag_A1_bias); -+ ++iterator_A1_bias; -+ } -+ -+ // Load warp-level scale bias fragment from threadblock scale/bias vector -+ if(PerChannelScale) { -+ warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]); -+ ++warp_tile_iterator_A1_scale_; -+ } -+ warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2]); -+ ++warp_tile_iterator_A1_bias_; -+ -+ // Load warp-level tile from accumulator fragment -+ warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2], -+ output_op_0); -+ ++warp_tile_iterator_A1_; -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. 
-+ this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); -+ this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ ++this->warp_tile_iterator_B1_; -+ -+ if (warp_mma_k > 0) -+ warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ warp_loaded_frag_A1[warp_mma_k % 2], -+ warp_loaded_frag_B1[warp_mma_k % 2]); -+ -+ // Issue global->shared copies for the next stage -+ int group_start_iteration_B1; -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations1) { -+ group_start_iteration_B1 = 0; -+ } else { -+ group_start_iteration_B1 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; -+ } -+ -+ copy_tiles_and_advance_1(iterator_B1, -+ group_start_iteration_B1); -+ -+ warp_mma1( -+ accum, -+ warp_transformed_frag_A1[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ accum -+ ); -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations1) -+ warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages of cp.async have committed -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_B1.advance(); -+ -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {-Base::kStages * Policy1::kPartitionsK * -+ Base::kWarpGemmIterations1, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ } -+ } -+ -+ } -+ -+ // Insert fence and wait for all outstanding cp.async operations to commit. -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h -new file mode 100644 -index 0000000..7c6793a ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h -@@ -0,0 +1,816 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/cache_operation.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "threadblock/b2b_mma_base_smem_accumulator.h" -+#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. 
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape0_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA0_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA0_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA0, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB0_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB0_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB0, -+ /// Iterates over vectors of scale and bias vector in global memory -+ // (concept: VectorIterator) -+ typename IteratorAccumulatorScaleBias_, -+ /// Iterates over accumulator tile -+ typename FragmentIteratorAccumulator_, -+ /// Iterates over accumulator tile in shared memory -+ typename SmemIteratorD0_, -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape1_, -+ /// Iterates over the intermediate accumulator tile -+ // (concept::MmaTensorOpFragmentIterator) -+ typename WarpIteratorA1_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB1_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB1_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB1, -+ /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) 
-+ typename OutputOp_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy0_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy1_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class B2bImplicitGemmMultistageSmemAccumulator : -+ public gemm::threadblock::B2bMmaBaseSmemAccumulator { -+public: -+ ///< Base class -+ using Base = gemm::threadblock::B2bMmaBaseSmemAccumulator; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape0 = Shape0_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA0 = IteratorA0_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB0 = IteratorB0_; -+ ///< Iterates over tiles of the scale and bias vectors in global memory -+ using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; -+ ///< Policy describing tuning details -+ using Policy0 = Policy0_; -+ -+ using SmemIteratorA0 = SmemIteratorA0_; -+ using SmemIteratorB0 = SmemIteratorB0_; -+ using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory -+ -+ using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< Iterates over accumulator tile -+ -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape1 = Shape1_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB1 = IteratorB1_; -+ ///< Policy describing tuning details -+ using Policy1 = Policy1_; -+ -+ using SmemIteratorB1 = SmemIteratorB1_; -+ using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory -+ -+ ///< Epilogue after 1st Gemm -+ using OutputOp = OutputOp_; -+ -+ static const bool PerChannelScale = (OutputOp::kScale == -+ epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; -+ -+ // -+ // Dependent types -+ // -+ -+ using ElementC = typename Policy0::Operator::ElementC; -+ -+ /// Fragment of accumulator tile -+ using FragmentC0 = typename Policy0::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator0 = typename Policy0::Operator; -+ -+ /// Fragment of Scale and Bias loaded from global memory -+ using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC1 = typename Policy1::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator1 = typename Policy1::Operator; -+ -+ /// Epilog in shared memory -+ using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator< -+ SmemIteratorD0, ///< SmemTileIterator -+ FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator -+ IteratorAccumulatorScaleBias, ///< ScaleBiasIterator -+ OutputOp>; ///< Output operator -+ -+ /// Internal structure exposed for introspection. 
-+ struct Detail { -+ -+ static_assert(Base::kWarpGemmIterations0 > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ static_assert(Base::kWarpGemmIterations1 > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA0 = -+ IteratorA0::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB0 = -+ IteratorB0::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB1 = -+ IteratorB1::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA0 = -+ (AsyncCopyIterationsPerStageA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB0 = -+ (AsyncCopyIterationsPerStageB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB1 = -+ (AsyncCopyIterationsPerStageB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA0 = typename Operator0::FragmentA; -+ using WarpLoadedFragmentB0 = typename Operator0::FragmentB; -+ using WarpLoadedFragmentA1 = typename Operator1::FragmentA; -+ using WarpLoadedFragmentB1 = typename Operator1::FragmentB; -+ using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA; -+ using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; -+ using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; -+ using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA0 smem_iterator_A0_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB0 smem_iterator_B0_; -+ -+ /// Shared Memory Iterator to store accumulator tile -+ SmemIteratorD0 smem_iterator_D0_; -+ -+ /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile -+ WarpIteratorA1 warp_tile_iterator_A1_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB1 smem_iterator_B1_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ B2bImplicitGemmMultistageSmemAccumulator( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::B2bMmaSharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx), -+ smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx), -+ smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), -+ 
warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), -+ smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ -+ int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM; -+ int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM; -+ -+ int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); -+ int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); -+ -+ int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; -+ int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {warp_idx_m_0, Base::kWarpGemmIterations0 * warp_idx_k_0}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {Base::kWarpGemmIterations0 * warp_idx_k_0, warp_idx_n_0}); -+ warp_tile_iterator_A1_.add_tile_offset( -+ {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); -+ -+ // Add smem accumulator iterator warp offset -+ smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow, -+ warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance_0( -+ IteratorA0 &iterator_A0, IteratorB0 &iterator_B0, -+ int group_start_A0 = 0, int group_start_B0 = 0) { -+ -+ iterator_A0.set_iteration_index(group_start_A0); -+ this->smem_iterator_A0_.set_iteration_index(group_start_A0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) { -+ -+ if (group_start_A0 + j < Detail::AsyncCopyIterationsPerStageA0) { -+ typename IteratorA0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A0_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA0::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_A0.get(), iterator_A0.valid()); -+ -+ ++iterator_A0; -+ -+ ++this->smem_iterator_A0_; -+ } -+ } -+ -+ iterator_B0.set_iteration_index(group_start_B0); -+ -+ this->smem_iterator_B0_.set_iteration_index(group_start_B0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) { -+ if (group_start_B0 + j < Detail::AsyncCopyIterationsPerStageB0) { -+ typename IteratorB0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B0_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB0::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_B0.get(), iterator_B0.valid()); -+ -+ ++iterator_B0; -+ ++this->smem_iterator_B0_; -+ } -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance_1( -+ IteratorB1 &iterator_B1, -+ int group_start_B1 = 0) { -+ -+ iterator_B1.set_iteration_index(group_start_B1); -+ -+ this->smem_iterator_B1_.set_iteration_index(group_start_B1); -+ -+ // Async 
Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { -+ if (group_start_B1 + j < Detail::AsyncCopyIterationsPerStageB1) { -+ typename IteratorB1::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B1_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB1::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_B1.get(), iterator_B1.valid()); -+ -+ ++iterator_B1; -+ ++this->smem_iterator_B1_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations_0, -+ ///< destination accumulator tile -+ FragmentC1 &accum, -+ ///< iterator over A0 operand in global memory -+ IteratorA0 iterator_A0, -+ ///< iterator over B0 operand in global memory -+ IteratorB0 iterator_B0, -+ ///< iterator over A1 operand scale vector in global memory -+ IteratorAccumulatorScaleBias iterator_accum0_scale, -+ ///< iterator over A1 operand bias vector in global memory -+ IteratorAccumulatorScaleBias iterator_accum0_bias, -+ ///< iterator over B1 operand in global memory -+ IteratorB1 iterator_B1, -+ ///< initial value of accumulator -+ FragmentC0 const &src_accum, -+ ///< epilogue operation after 1st Gemm -+ OutputOp output_op_0, -+ ///< Imaginary strides used for planar-complex only - ignored here -+ int64_t imag_stride_A = 0, -+ int64_t imag_stride_B = 0) { -+ -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations_0) { -+ -+ iterator_A0.set_iteration_index(0); -+ this->smem_iterator_A0_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA0; ++j) { -+ typename IteratorA0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A0_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA0::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_A0.get(), iterator_A0.valid()); -+ -+ ++iterator_A0; -+ ++this->smem_iterator_A0_; -+ } -+ -+ iterator_B0.set_iteration_index(0); -+ this->smem_iterator_B0_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB0; ++j) { -+ typename IteratorB0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B0_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB0::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_B0.get(), iterator_B0.valid()); -+ -+ ++iterator_B0; -+ ++this->smem_iterator_B0_; -+ } -+ -+ // Move to the next stage -+ iterator_A0.advance(); -+ iterator_B0.advance(); -+ -+ this->smem_iterator_A0_.add_tile_offset({0, 1}); -+ this->smem_iterator_B0_.add_tile_offset({1, 0}); -+ -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ FragmentC0 accum0 = src_accum; -+ -+ // Waits until kStages-2 stages have committed. 
-+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA0 warp_loaded_frag_A0[2]; -+ WarpLoadedFragmentB0 warp_loaded_frag_B0[2]; -+ WarpTransformedFragmentA0 warp_transformed_frag_A0[2]; -+ WarpTransformedFragmentB0 warp_transformed_frag_B0[2]; -+ -+ Operator0 warp_mma0; -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index(0); -+ this->warp_tile_iterator_B0_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]); -+ this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ // Start issuing the first group of the next stage outside of the mainloop -+ copy_tiles_and_advance_0(iterator_A0, iterator_B0); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0], -+ warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]); -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations_0 > (-Base::kStages + 1);) { -+ -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ -+ this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ if (warp_mma_k > 0) -+ warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2], -+ warp_transformed_frag_B0[warp_mma_k % 2], -+ warp_loaded_frag_A0[warp_mma_k % 2], -+ warp_loaded_frag_B0[warp_mma_k % 2]); -+ -+ // Issue global->shared copies for the next stage -+ int group_start_iteration_A0, group_start_iteration_B0; -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations0) { -+ group_start_iteration_A0 = 0; -+ group_start_iteration_B0 = 0; -+ } else { -+ group_start_iteration_A0 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA0; -+ group_start_iteration_B0 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB0; -+ } -+ -+ copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, -+ group_start_iteration_B0); -+ -+ warp_mma0( -+ accum0, -+ warp_transformed_frag_A0[warp_mma_k % 2], -+ warp_transformed_frag_B0[warp_mma_k % 2], -+ accum0 -+ ); -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations0) -+ warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B0[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A0[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations0) { -+ // Inserts a fence to group cp.async instructions into stages. 
-+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages of cp.async have committed -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A0.advance(); -+ iterator_B0.advance(); -+ -+ this->smem_iterator_A0_.add_tile_offset({0, 1}); -+ this->smem_iterator_B0_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {0, -Base::kStages * Policy0::kPartitionsK * -+ Base::kWarpGemmIterations0}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {-Base::kStages * Policy0::kPartitionsK * -+ Base::kWarpGemmIterations0, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations_0; -+ } -+ } -+ -+ } -+ -+ // Insert fence and wait for all outstanding cp.async operations to commit. -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ -+ /// Epilogue for the first Implicit Gemm -+ Epilogue0 epilogue0; -+ -+ epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias); -+ -+ __syncthreads(); -+ -+ // 2nd Implicit Gemm -+ -+ // -+ // Prologue -+ // -+ int gemm_k_iterations_1 = Shape0::kN / Shape1::kK; -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations_1) { -+ -+ iterator_B1.set_iteration_index(0); -+ this->smem_iterator_B1_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB1; ++j) { -+ typename IteratorB1::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B1_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB1::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_B1.get(), iterator_B1.valid()); -+ -+ ++iterator_B1; -+ ++this->smem_iterator_B1_; -+ } -+ -+ // Move to the next stage -+ iterator_B1.advance(); -+ -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Waits until kStages-2 stages have committed. 
-+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; -+ WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; -+ WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; -+ WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; -+ -+ Operator1 warp_mma1; -+ -+ warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]); -+ ++warp_tile_iterator_A1_; -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index(0); -+ this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); -+ ++this->warp_tile_iterator_B1_; -+ -+ // Start issuing the first group of the next stage outside of the mainloop -+ copy_tiles_and_advance_1(iterator_B1); -+ -+ smem_write_stage_idx = Base::kStages - 1; -+ smem_read_stage_idx = 0; -+ -+ warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0], -+ warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]); -+ -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( gemm_k_iterations_1 = Shape0::kN / Shape1::kK - (Base::kStages - 1); -+ gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; -+ ++warp_mma_k) { -+ -+ // Load warp-level tile from accumulator fragment -+ // skip warp tile loading for the last kgroup -+ if(gemm_k_iterations_1 > (-Base::kStages + 2) || warp_mma_k < Base::kWarpGemmIterations1 - 1) { -+ warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2]); -+ } -+ ++warp_tile_iterator_A1_; -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); -+ this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ ++this->warp_tile_iterator_B1_; -+ -+ -+ if (warp_mma_k > 0) -+ warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ warp_loaded_frag_A1[warp_mma_k % 2], -+ warp_loaded_frag_B1[warp_mma_k % 2]); -+ -+ // Issue global->shared copies for the next stage -+ int group_start_iteration_B1; -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations1) { -+ group_start_iteration_B1 = 0; -+ } else { -+ group_start_iteration_B1 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; -+ } -+ -+ copy_tiles_and_advance_1(iterator_B1, -+ group_start_iteration_B1); -+ -+ warp_mma1( -+ accum, -+ warp_transformed_frag_A1[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ accum -+ ); -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations1) -+ warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { -+ // Inserts a fence to group cp.async instructions into stages. 
-+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages of cp.async have committed -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_B1.advance(); -+ -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {-Base::kStages * Policy1::kPartitionsK * -+ Base::kWarpGemmIterations1, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ } -+ } -+ -+ } -+ -+ // Insert fence and wait for all outstanding cp.async operations to commit. -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined.h -new file mode 100644 -index 0000000..36d4563 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined.h -@@ -0,0 +1,553 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "threadblock/b2b_mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape0_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorA0_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA0_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB0_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB0_, -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape1_, -+ /// Iterates over the intermediate accumulator tile -+ // (concept::MmaTensorOpFragmentIterator) -+ typename FragmentIteratorA1_, -+ /// Iterates over vectors of scale and bias vector in global memory -+ // (concept: VectorIterator) -+ typename IteratorAccumulatorScaleBias_, -+ /// FragmentIterator to load Scale or Bias vector from threadblock fragment -+ typename FragmentIteratorA1ScaleBias_, -+ // (concept: VectorFragmentIterator) -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB1_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB1_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) 
-+ typename OutputOp_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy0_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy1_, -+ /// Transformation applied to A operand -+ typename TransformA0_ = NumericArrayConverter< -+ typename SmemIteratorA0_::Element, -+ typename IteratorA0_::Element, -+ IteratorA0_::Fragment::kElements>, -+ /// -+ /// Transformation applied to B operand -+ typename TransformB0_ = NumericArrayConverter< -+ typename SmemIteratorB0_::Element, -+ typename IteratorB0_::Element, -+ IteratorB0_::Fragment::kElements>, -+ /// -+ /// Transformation applied to B operand -+ typename TransformB1_ = NumericArrayConverter< -+ typename SmemIteratorB1_::Element, -+ typename IteratorB1_::Element, -+ IteratorB1_::Fragment::kElements>, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class B2bImplicitGemmPipelined : -+ public gemm::threadblock::B2bMmaBase { -+public: -+ -+ ///< Base class -+ using Base = gemm::threadblock::B2bMmaBase; -+ -+ using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory -+ using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory -+ using Policy0 = Policy0_; ///< Policy0 describing tuning details -+ -+ using SmemIteratorA0 = SmemIteratorA0_; -+ using SmemIteratorB0 = SmemIteratorB0_; -+ -+ using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using FragmentIteratorA1 = FragmentIteratorA1_; ///< Iterates over tiles of A1 operand from accumulator tile -+ using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory -+ using FragmentIteratorA1ScaleBias = -+ FragmentIteratorA1ScaleBias_; ///< WarpIterator to load Scale or Bias vector from the threadblock fragment -+ using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory -+ using Policy1 = Policy1_; ///< Policy1 describing tuning details -+ -+ using SmemIteratorB1 = SmemIteratorB1_; -+ -+ -+ using ElementC = ElementC_; ///< Data type of accumulator matrix -+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix -+ -+ using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm -+ -+ static const bool PerChannelScale = (OutputOp::kScale == -+ epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); -+ -+ using TransformA0 = TransformA0_; -+ using TransformB0 = TransformB0_; -+ using TransformB1 = TransformB1_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of operand A loaded from global memory -+ using FragmentA0 = typename IteratorA0::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB0 = typename IteratorB0::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC0 = typename Policy0::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator0 = typename Policy0::Operator; -+ -+ /// Fragment of Scale and Bias loaded from global memory -+ using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB1 = typename IteratorB1::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC1 = typename Policy1::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator1 = typename Policy1::Operator; -+ -+ /// Obtain the arch tag from the warp-level operator -+ using ArchTag = typename 
Policy0::Operator::ArchTag; -+ -+ /// Complex transform on A0 operand -+ static ComplexTransform const kTransformA0 = Operator0::kTransformA; -+ -+ /// Complex transform on B0 operand -+ static ComplexTransform const kTransformB0 = Operator0::kTransformB; -+ -+ /// Complex transform on B1 operand -+ static ComplexTransform const kTransformB1 = Operator1::kTransformB; -+ -+ // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) -+ static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); -+ -+private: -+ -+ using WarpFragmentA0 = typename Operator0::FragmentA; -+ using WarpFragmentB0 = typename Operator0::FragmentB; -+ /// Warp Fragment of operand A1 loaded from accmulator tile -+ using WarpFragmentA1 = typename FragmentIteratorA1::Fragment; -+ /// Warp Fragment of operand A1 scale and bias loaded from threadblock fragment -+ using WarpFragmentA1ScaleBias = -+ typename FragmentIteratorA1ScaleBias::Fragment; -+ using WarpFragmentB1 = typename Operator1::FragmentB; -+ -+protected: -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA0 smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B0 operand to shared memory -+ SmemIteratorB0 smem_iterator_B0_; -+ -+ /// Iterator to write threadblock-scoped tile of B1 operand to shared memory -+ SmemIteratorB1 smem_iterator_B1_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ B2bImplicitGemmPipelined( -+ typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ int thread_idx, ///< ID within the threadblock -+ int warp_idx, ///< ID of warp -+ int lane_idx ///< ID of each thread within a warp -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.shared_storage0.operand_A_ref(), thread_idx), -+ smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx), -+ smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx) { -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM; -+ -+ //These may change across different GEMM layers -+ int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k; -+ int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m, tile_offset_k_0}); -+ this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n}); -+ this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n}); -+ -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ int gemm_k_iterations_0, ///< number of iterations of the mainloop -+ FragmentC1 &accum, ///< destination accumulator tile -+ IteratorA0 iterator_A, ///< iterator over A operand in global memory -+ IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory -+ 
IteratorAccumulatorScaleBias iterator_A1_scale, ///< iterator over A1 operand scale vectors in global memory -+ IteratorAccumulatorScaleBias iterator_A1_bias, ///< iterator over A1 operand bias vectors in global memory -+ IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory -+ FragmentC0 const &src_accum, ///< source accumulator tile -+ OutputOp output_op_0, ///< epilogue operation after 1st Gemm -+ TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment -+ TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment -+ TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment -+ -+ // -+ // Prologue -+ // -+ -+ // Perform accumulation in the 'd' output operand -+ FragmentC0 accum0 = src_accum; -+ -+ FragmentA0 tb_frag_A; -+ FragmentB0 tb_frag_B0; -+ -+ tb_frag_A.clear(); -+ tb_frag_B0.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_A.load(tb_frag_A); -+ iterator_B0.load(tb_frag_B0); -+ -+ ++iterator_A; -+ ++iterator_B0; -+ -+ this->smem_iterator_A_.store(transform_A0(tb_frag_A)); -+ this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B0_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA0 warp_frag_A0[2]; -+ WarpFragmentB0 warp_frag_B0[2]; -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index(0); -+ this->warp_tile_iterator_B0_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A0_.load(warp_frag_A0[0]); -+ this->warp_tile_iterator_B0_.load(warp_frag_B0[0]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ Operator0 warp_mma0; -+ -+ int smem_write_stage_idx = 1; -+ -+ // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing -+ // shared memory loads (which have the tighest latency requirement). -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. 
-+ -+ if (warp_mma_k == Base::kWarpGemmIterations0 - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_A_.store(transform_A0(tb_frag_A)); -+ -+ this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B0_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ } -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ -+ this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ if (warp_mma_k == 0) { -+ -+ iterator_A.load(tb_frag_A); -+ iterator_B0.load(tb_frag_B0); -+ -+ ++iterator_A; -+ ++iterator_B0; -+ } -+ -+ warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], -+ warp_frag_B0[warp_mma_k % 2], accum0); -+ -+ } -+ } -+ -+ -+ //2nd Implicit Gemm -+ -+ /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile -+ FragmentIteratorA1 warp_tile_iterator_A1_(accum0); -+ -+ -+ -+ // -+ // Prologue -+ // -+ -+ FragmentA1ScaleBias tb_frag_A1_scale; -+ FragmentA1ScaleBias tb_frag_A1_bias; -+ FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale); -+ FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias); -+ FragmentB1 tb_frag_B1; -+ -+ if(PerChannelScale) -+ tb_frag_A1_scale.clear(); -+ tb_frag_A1_bias.clear(); -+ tb_frag_B1.clear(); -+ -+ // The last kblock is loaded in the prolog -+ if(PerChannelScale) -+ iterator_A1_scale.load(tb_frag_A1_scale); -+ iterator_A1_bias.load(tb_frag_A1_bias); -+ iterator_B1.load(tb_frag_B1); -+ -+ -+ if(PerChannelScale) -+ ++iterator_A1_scale; -+ ++iterator_A1_bias; -+ ++iterator_B1; -+ -+ this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); -+ -+ ++this->smem_iterator_B1_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA1ScaleBias warp_frag_A1_scale[2]; -+ WarpFragmentA1ScaleBias warp_frag_A1_bias[2]; -+ WarpFragmentA1 warp_frag_A1[2]; -+ WarpFragmentB1 warp_frag_B1[2]; -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index(0); -+ -+ if(PerChannelScale) -+ warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[0]); -+ warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[0]); -+ warp_tile_iterator_A1_.load(warp_frag_A1[0], warp_frag_A1_scale[0], -+ warp_frag_A1_bias[0], output_op_0); -+ this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); -+ -+ ++warp_tile_iterator_A1_; -+ if(PerChannelScale) -+ ++warp_tile_iterator_A1_scale_; -+ ++warp_tile_iterator_A1_bias_; -+ ++this->warp_tile_iterator_B1_; -+ -+ Operator1 warp_mma1; -+ -+ smem_write_stage_idx = 1; -+ -+ int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1; -+ -+ // Issue loads during the first warp-level matrix 
multiply-add *AFTER* issuing -+ // shared memory loads (which have the tighest latency requirement). -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_PRAGMA_UNROLL -+ for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations1 - 1) { -+ -+ this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_B1_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {-Base::kStages * Policy1::kPartitionsK * Base::kWarpGemmIterations1, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ -+ if(PerChannelScale) { -+ tb_frag_A1_scale.clear(); -+ iterator_A1_scale.load(tb_frag_A1_scale); -+ ++iterator_A1_scale; -+ } -+ tb_frag_A1_bias.clear(); -+ iterator_A1_bias.load(tb_frag_A1_bias); -+ ++iterator_A1_bias; -+ } -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); -+ -+ if(PerChannelScale) -+ warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[(warp_mma_k + 1) % 2]); -+ warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[(warp_mma_k + 1) % 2]); -+ warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], -+ warp_frag_A1_scale[(warp_mma_k + 1) % 2], -+ warp_frag_A1_bias[(warp_mma_k + 1) % 2], -+ output_op_0); -+ this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); -+ -+ if(PerChannelScale) -+ ++warp_tile_iterator_A1_scale_; -+ ++warp_tile_iterator_A1_bias_; -+ ++warp_tile_iterator_A1_; -+ ++this->warp_tile_iterator_B1_; -+ -+ if (warp_mma_k == 0) { -+ -+ iterator_B1.load(tb_frag_B1); -+ -+ ++iterator_B1; -+ } -+ -+ warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], -+ warp_frag_B1[warp_mma_k % 2], accum); -+ -+ } -+ } -+ -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h -new file mode 100644 -index 0000000..828426b ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h -@@ -0,0 +1,535 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "threadblock/b2b_mma_base_smem_accumulator.h" -+#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
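The pipelined class declared below relies on a two-stage (kStages == 2) shared-memory double buffer: writer and reader iterators take turns walking past the end of the circular buffer and are pulled back by kStages tiles, with smem_write_stage_idx toggled once per k iteration. The following is a minimal standalone sketch of that bookkeeping only — hypothetical variable names and a host-side loop, not the CUTLASS iterators themselves.

#include <cassert>

int main() {
  const int kStages = 2;          // double buffer: two tiles of shared memory
  int write_tile = 1;             // tile the smem store iterator points at after the prologue
  int read_tile  = 0;             // tile the warp load iterator is consuming
  int smem_write_stage_idx = 1;   // matches the initial value used in the kernels below

  for (int k_iter = 0; k_iter < 8; ++k_iter) {
    ++write_tile;                 // store iterator advances once per k iteration
    ++read_tile;                  // warp load iterator advances past the tile it finished
    if (smem_write_stage_idx == 1) {
      write_tile -= kStages;      // writer wraps back to the start of the circular buffer
    } else {
      read_tile -= kStages;       // reader wraps on the alternate iteration
    }
    smem_write_stage_idx ^= 1;
    assert(0 <= write_tile && write_tile < kStages);
    assert(0 <= read_tile && read_tile < kStages);
  }
  return 0;
}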
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape0_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorA0_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA0_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB0_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB0_, -+ /// Iterates over vectors of scale and bias vector in global memory -+ // (concept: VectorIterator) -+ typename IteratorAccumulatorScaleBias_, -+ /// Iterates over accumulator tile -+ typename FragmentIteratorAccumulator_, -+ /// Iterates over accumulator tile in shared memory -+ typename SmemIteratorD0_, -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape1_, -+ /// Iterates over the intermediate accumulator tile in shared memory -+ typename WarpIteratorA1_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB1_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB1_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) -+ typename OutputOp_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy0_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy1_, -+ /// Transformation applied to A0 operand -+ typename TransformA0_ = NumericArrayConverter< -+ typename SmemIteratorA0_::Element, -+ typename IteratorA0_::Element, -+ IteratorA0_::Fragment::kElements>, -+ /// -+ /// Transformation applied to B0 operand -+ typename TransformB0_ = NumericArrayConverter< -+ typename SmemIteratorB0_::Element, -+ typename IteratorB0_::Element, -+ IteratorB0_::Fragment::kElements>, -+ /// -+ /// Transformation applied to B1 operand -+ typename TransformB1_ = NumericArrayConverter< -+ typename SmemIteratorB1_::Element, -+ typename IteratorB1_::Element, -+ IteratorB1_::Fragment::kElements>, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class B2bImplicitGemmPipelinedSmemAccumulator : -+ public gemm::threadblock::B2bMmaBaseSmemAccumulator { -+public: -+ -+ ///< Base class -+ using Base = gemm::threadblock::B2bMmaBaseSmemAccumulator; -+ -+ using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory -+ using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory -+ using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory -+ using Policy0 = Policy0_; ///< Policy0 describing tuning details -+ -+ using SmemIteratorA0 = SmemIteratorA0_; -+ using SmemIteratorB0 = SmemIteratorB0_; -+ using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory -+ -+ using FragmentIteratorAccumulator = 
FragmentIteratorAccumulator_; ///< Iterates over accumulator tile -+ -+ using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory -+ using Policy1 = Policy1_; ///< Policy1 describing tuning details -+ -+ using SmemIteratorB1 = SmemIteratorB1_; -+ using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory -+ -+ -+ using ElementC = ElementC_; ///< Data type of accumulator matrix -+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix -+ -+ using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm -+ -+ using TransformA0 = TransformA0_; -+ using TransformB0 = TransformB0_; -+ using TransformB1 = TransformB1_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of operand A loaded from global memory -+ using FragmentA0 = typename IteratorA0::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB0 = typename IteratorB0::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC0 = typename Policy0::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator0 = typename Policy0::Operator; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB1 = typename IteratorB1::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC1 = typename Policy1::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator1 = typename Policy1::Operator; -+ -+ /// Obtain the arch tag from the warp-level operator -+ using ArchTag = typename Policy0::Operator::ArchTag; -+ -+ /// Complex transform on A0 operand -+ static ComplexTransform const kTransformA0 = Operator0::kTransformA; -+ -+ /// Complex transform on B0 operand -+ static ComplexTransform const kTransformB0 = Operator0::kTransformB; -+ -+ /// Complex transform on B1 operand -+ static ComplexTransform const kTransformB1 = Operator1::kTransformB; -+ -+ /// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) -+ static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); -+ -+ /// Epilog in shared memory -+ using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator< -+ SmemIteratorD0, ///< SmemTileIterator -+ FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator -+ IteratorAccumulatorScaleBias, ///< ScaleBiasIterator -+ OutputOp>; ///< Output operator -+ -+ -+ -+private: -+ -+ using WarpFragmentA0 = typename Operator0::FragmentA; -+ using WarpFragmentB0 = typename Operator0::FragmentB; -+ using WarpFragmentA1 = typename Operator1::FragmentA; -+ using WarpFragmentB1 = typename Operator1::FragmentB; -+ -+protected: -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA0 smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B0 operand to shared memory -+ SmemIteratorB0 smem_iterator_B0_; -+ -+ /// Shared Memory Iterator to store accumulator tile -+ SmemIteratorD0 smem_iterator_D0_; -+ -+ /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile -+ WarpIteratorA1 warp_tile_iterator_A1_; -+ -+ /// Iterator to write threadblock-scoped tile of B1 operand to shared memory -+ SmemIteratorB1 smem_iterator_B1_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ B2bImplicitGemmPipelinedSmemAccumulator( -+ typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ int thread_idx, 
///< ID within the threadblock -+ int warp_idx, ///< ID of warp -+ int lane_idx ///< ID of each thread within a warp -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx), -+ smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx), -+ smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), -+ warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), -+ smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) { -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ -+ int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM; -+ int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM; -+ -+ int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k_0; -+ -+ int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); -+ int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); -+ -+ int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; -+ int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; -+ -+ int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k_1; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m_0, tile_offset_k_0}); -+ this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n_0}); -+ warp_tile_iterator_A1_.add_tile_offset({warp_idx_m_1, tile_offset_k_1}); -+ this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n_1}); -+ -+ // Add smem accumulator iterator warp offset -+ smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow, -+ warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn}); -+ -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ int gemm_k_iterations_0, ///< number of iterations of the mainloop -+ FragmentC1 &accum, ///< destination accumulator tile -+ IteratorA0 iterator_A, ///< iterator over A operand in global memory -+ IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory -+ IteratorAccumulatorScaleBias iterator_accum0_scale, ///< iterator over D0 scale vector in global memory -+ IteratorAccumulatorScaleBias iterator_accum0_bias, ///< iterator over D0 bias vector in global memory -+ IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory -+ FragmentC0 const &src_accum, ///< source accumulator tile -+ OutputOp output_op_0, ///< epilogue operation after 1st Gemm -+ TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment -+ TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment -+ TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment -+ -+ // -+ // Prologue -+ // -+ -+ // Perform accumulation in the 'd' output operand -+ FragmentC0 accum0 = src_accum; -+ -+ FragmentA0 tb_frag_A; -+ FragmentB0 tb_frag_B0; -+ -+ 
tb_frag_A.clear(); -+ tb_frag_B0.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_A.load(tb_frag_A); -+ iterator_B0.load(tb_frag_B0); -+ -+ ++iterator_A; -+ ++iterator_B0; -+ -+ this->smem_iterator_A_.store(transform_A0(tb_frag_A)); -+ this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B0_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA0 warp_frag_A0[2]; -+ WarpFragmentB0 warp_frag_B0[2]; -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index(0); -+ this->warp_tile_iterator_B0_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A0_.load(warp_frag_A0[0]); -+ this->warp_tile_iterator_B0_.load(warp_frag_B0[0]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ Operator0 warp_mma0; -+ -+ int smem_write_stage_idx = 1; -+ -+ // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing -+ // shared memory loads (which have the tighest latency requirement). -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations0 - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_A_.store(transform_A0(tb_frag_A)); -+ -+ this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B0_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ } -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ -+ this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ if (warp_mma_k == 0) { -+ -+ iterator_A.load(tb_frag_A); -+ iterator_B0.load(tb_frag_B0); -+ ++iterator_A; -+ ++iterator_B0; -+ } -+ -+ warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], -+ warp_frag_B0[warp_mma_k % 2], accum0); -+ -+ } -+ } -+ -+ /// Epilogue for the first Implicit Gemm -+ Epilogue0 epilogue0; -+ -+ epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias); -+ -+ __syncthreads(); -+ -+ /// 2nd Implicit Gemm -+ -+ -+ // -+ // Prologue -+ // -+ -+ FragmentB1 tb_frag_B1; -+ -+ tb_frag_B1.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_B1.load(tb_frag_B1); -+ -+ ++iterator_B1; -+ -+ 
this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); -+ -+ ++this->smem_iterator_B1_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA1 warp_frag_A1[2]; -+ WarpFragmentB1 warp_frag_B1[2]; -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index(0); -+ -+ warp_tile_iterator_A1_.load(warp_frag_A1[0]); -+ this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); -+ -+ ++warp_tile_iterator_A1_; -+ ++this->warp_tile_iterator_B1_; -+ -+ Operator1 warp_mma1; -+ -+ smem_write_stage_idx = 1; -+ -+ int gemm_k_iterations_1 = Shape0::kN / Shape1::kK; -+ -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_PRAGMA_UNROLL -+ for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations1 - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_B1_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {-Base::kStages * Policy1::kPartitionsK * -+ Base::kWarpGemmIterations1, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ -+ } -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); -+ -+ // skip warp tile loading for the last kgroup -+ if(gemm_k_iterations_1 > 1 || warp_mma_k < Base::kWarpGemmIterations1 - 1) -+ warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); -+ -+ ++warp_tile_iterator_A1_; -+ ++this->warp_tile_iterator_B1_; -+ -+ if (warp_mma_k == 0) { -+ -+ iterator_B1.load(tb_frag_B1); -+ -+ ++iterator_B1; -+ } -+ -+ warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], -+ warp_frag_B1[warp_mma_k % 2], accum); -+ -+ } -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base.h -new file mode 100644 -index 0000000..660879c ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base.h -@@ -0,0 +1,236 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape0_, -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape1_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy0_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy1_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class B2bMmaBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape0 = Shape0_; -+ using Shape1 = Shape1_; -+ -+ ///< Policy describing tuning details -+ using Policy0 = Policy0_; -+ using Policy1 = Policy1_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator0 = typename Policy0::Operator; -+ using Operator1 = typename Policy1::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. 
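The definitions that follow derive the warp counts and the per-stage k-group count from the threadblock tile, the warp tile, and the instruction shape. A small numeric sketch of that arithmetic, using assumed example shapes (128x128x32 threadblock tile, 64x64x32 warp tile, k = 16 instruction) and illustrative names rather than the CUTLASS types:

#include <cstdio>

int main() {
  // Assumed example shapes (M, N, K).
  const int kShapeM = 128, kShapeN = 128, kShapeK = 32;   // threadblock tile
  const int kWarpM = 64, kWarpN = 64, kWarpK = 32;        // warp-level tile
  const int kMmaK = 16;                                   // instruction-level k

  // WarpCount: how many warps tile the threadblock in M/N/K.
  const int kWarpCountM = kShapeM / kWarpM;               // 2
  const int kWarpCountN = kShapeN / kWarpN;               // 2
  const int kWarpCountK = kShapeK / kWarpK;               // 1

  // kWarpGemmIterations: warp-level k-groups per threadblock mainloop step.
  const int kWarpGemmIterations = kWarpK / kMmaK;         // 2

  std::printf("warps = %d x %d x %d, k-groups per stage = %d\n",
              kWarpCountM, kWarpCountN, kWarpCountK, kWarpGemmIterations);
  return 0;
}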
-+ using WarpGemm0 = typename Policy0::Operator::Shape; -+ using WarpGemm1 = typename Policy1::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount0 = GemmShape; -+ using WarpCount1 = GemmShape; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations0 = -+ (WarpGemm0::kK / Operator0::Policy::MmaShape::kK); -+ static int const kWarpGemmIterations1 = -+ (WarpGemm1::kK / Operator1::Policy::MmaShape::kK); -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ template< -+ typename Shape_, -+ typename Policy_ -+ > -+ class SharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ using Shape = Shape_; -+ using Policy = Policy_; -+ using Operator = typename Policy::Operator; -+ -+ /// Tensor reference to the A operand -+ using TensorRefA = TensorRef; -+ -+ /// Tensor reference to the B operand -+ using TensorRefB = TensorRef; -+ -+ -+ /// Shape of the A matrix operand in shared memory -+ using ShapeA = MatrixShape; -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = -+ MatrixShape; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for A operand -+ AlignedBuffer operand_A; -+ -+ /// Buffer for B operand -+ AlignedBuffer operand_B; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the A matrix -+ CUTLASS_DEVICE -+ static typename Operator::LayoutA LayoutA() { -+ return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); -+ } -+ -+ /// Returns a layout object for the B matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutB LayoutB() { -+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the A operand -+ CUTLASS_HOST_DEVICE -+ TensorRefA operand_A_ref() { -+ return TensorRefA{operand_A.data(), LayoutA()}; -+ } -+ -+ /// Returns a TensorRef to the B operand -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B_ref() { -+ return TensorRefB{operand_B.data(), LayoutB()}; -+ } -+ }; -+ -+ using SharedStorage0 = SharedStorage; -+ using SharedStorage1 = SharedStorage; -+ union B2bMmaSharedStorage { -+ SharedStorage0 shared_storage0; -+ SharedStorage1 shared_storage1; -+ }; -+ -+ -+ protected: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A0 operand from shared memory -+ typename Operator0::IteratorA warp_tile_iterator_A0_; -+ -+ /// Iterator to load a warp-scoped tile of B0 operand from shared memory -+ typename Operator0::IteratorB warp_tile_iterator_B0_; -+ -+ /// Iterator to load a warp-scoped tile of B1 operand from shared memory -+ typename Operator1::IteratorB warp_tile_iterator_B1_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ B2bMmaBase( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ B2bMmaSharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ warp_tile_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), lane_idx), -+ warp_tile_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), lane_idx), -+ warp_tile_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), lane_idx) { -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // 
namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h -new file mode 100644 -index 0000000..fc8058a ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h -@@ -0,0 +1,179 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "threadblock/b2b_mma_base.h" -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. 
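The smem-accumulator base class defined below composes its shared storage differently from B2bMmaBase above: the base unions the two GEMM stages' operand buffers (they are live at different times), while the accumulator buffer is appended as a separate member because, as used here, it must stay live while the second GEMM stages its B1 tiles. A rough, self-contained layout sketch with assumed byte sizes and placeholder struct names:

#include <cstdio>

// Illustrative stand-ins for the real shared-storage types (sizes assumed).
struct SharedStorage0 { char bytes[16 * 1024]; };            // A0 + B0 tiles
struct SharedStorage1 { char bytes[8 * 1024]; };             // B1 tiles
struct AccumulatorSharedStorage0 { char bytes[32 * 1024]; }; // staged accum0

// Base variant: the two GEMM stages share one allocation.
union BaseB2bMmaSharedStorage {
  SharedStorage0 shared_storage0;
  SharedStorage1 shared_storage1;
};

// Smem-accumulator variant: the accumulator cannot join the union because it
// is read as the A1 operand while SharedStorage1 is in use.
struct B2bMmaSharedStorageSmemAccum {
  BaseB2bMmaSharedStorage b2b_mma_shared_storage;
  AccumulatorSharedStorage0 accumulator_shared_storage0;
};

int main() {
  std::printf("union: %zu bytes, with accumulator: %zu bytes\n",
              sizeof(BaseB2bMmaSharedStorage),
              sizeof(B2bMmaSharedStorageSmemAccum));
  return 0;
}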
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape0_, -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape1_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy0_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy1_, -+ /// Shared Memory Accumulator Iterator -+ typename SmemAccumulatorIterator0_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class B2bMmaBaseSmemAccumulator : -+ public B2bMmaBase { -+ -+ public: -+ ///< Base class -+ using Base = B2bMmaBase; -+ -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape0 = Shape0_; -+ using Shape1 = Shape1_; -+ -+ ///< Policy describing tuning details -+ using Policy0 = Policy0_; -+ using Policy1 = Policy1_; -+ -+ -+ using SmemAccumulatorIterator0 = SmemAccumulatorIterator0_; -+ -+ // -+ // Nested structs -+ // -+ /// Shared storage object needed by accumulator -+ template< -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename Padding_ -+ > -+ class AccumulatorSharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using Padding = Padding_; -+ -+ /// Tensor reference to the accumulator -+ using TensorRefAccum = TensorRef; -+ -+ /// Shape of the accumulator matrix in shared memory -+ using ShapeAccum = MatrixShape; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for accumulator -+ AlignedBuffer accum; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the Accum matrix -+ CUTLASS_DEVICE -+ static Layout LayoutAccum() { -+ return Layout::packed({ShapeAccum::kRow, ShapeAccum::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the Accumulator -+ CUTLASS_HOST_DEVICE -+ TensorRefAccum accum_ref() { -+ return TensorRefAccum{accum.data(), LayoutAccum()}; -+ } -+ -+ }; -+ -+ using AccumulatorSharedStorage0 = AccumulatorSharedStorage< -+ Shape0, typename SmemAccumulatorIterator0::Element, -+ typename SmemAccumulatorIterator0::TensorLayout, -+ typename SmemAccumulatorIterator0::Padding>; -+ -+ struct B2bMmaSharedStorage { -+ typename Base::B2bMmaSharedStorage b2b_mma_shared_storage; -+ AccumulatorSharedStorage0 accumulator_shared_storage0; -+ }; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ B2bMmaBaseSmemAccumulator( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ B2bMmaSharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage.b2b_mma_shared_storage, thread_idx, warp_idx, lane_idx) { -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h -new file mode 100644 -index 0000000..4ec718b ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h -@@ -0,0 +1,885 @@ 
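b2b_mma_multistage.h, added below, generalizes the double-buffered pipeline to an N-deep cp.async pipeline (Sm80 and newer): kStages - 1 stages are prefetched in the prologue, and the mainloop keeps iterating until gemm_k_iterations falls below -(kStages - 1) so the prefetched stages are drained with their global-memory loads masked off. A standalone sketch of that loop accounting under assumed values (kStages = 4, 10 k tiles), not the kernel itself:

#include <cstdio>

int main() {
  const int kStages = 4;          // assumed pipeline depth
  int gemm_k_iterations = 10;     // assumed number of k tiles in the problem

  // Prologue: kStages - 1 complete stages are prefetched with cp.async.
  for (int stage = 0; stage < kStages - 1; ++stage) {
    --gemm_k_iterations;          // 10 -> 7
  }

  // Mainloop: continues past zero to drain the prefetched stages.
  int mainloop_steps = 0;
  for (; gemm_k_iterations > (-kStages + 1); --gemm_k_iterations) {
    ++mainloop_steps;
  }
  std::printf("mainloop steps = %d\n", mainloop_steps);  // prints 10: one per k tile
  return 0;
}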
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "threadblock/b2b_mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. 
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape0_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA0_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA0_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA0, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB0_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB0_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB0, -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape1_, -+ /// Iterates over the intermediate accumulator tile -+ // (concept::MmaTensorOpFragmentIterator) -+ typename FragmentIteratorA1_, -+ /// Iterates over vectors of scale and bias vector in global memory -+ // (concept: VectorIterator) -+ typename IteratorAccumulatorScaleBias_, -+ /// WarpIterator to load Scale or Bias vector from threadblock fragment -+ typename FragmentIteratorA1ScaleBias_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB1_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB1_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB1, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) 
-+ typename OutputOp_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy0_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy1_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class B2bMmaMultistage : -+ public B2bMmaBase { -+public: -+ ///< Base class -+ using Base = B2bMmaBase; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape0 = Shape0_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA0 = IteratorA0_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB0 = IteratorB0_; -+ ///< Policy describing tuning details -+ using Policy0 = Policy0_; -+ -+ using SmemIteratorA0 = SmemIteratorA0_; -+ using SmemIteratorB0 = SmemIteratorB0_; -+ -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape1 = Shape1_; -+ ///< Iterates over intermediate accumulator tile -+ using FragmentIteratorA1 = FragmentIteratorA1_; -+ ///< Iterates over tiles of the scale and bias vectors in global memory -+ using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; -+ ///< WarpIterator to load Scale or Bias vector from threadblock fragment -+ using FragmentIteratorA1ScaleBias = FragmentIteratorA1ScaleBias_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB1 = IteratorB1_; -+ ///< Policy describing tuning details -+ using Policy1 = Policy1_; -+ -+ using SmemIteratorB1 = SmemIteratorB1_; -+ -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ -+ ///< Epilogue after 1st Gemm -+ using OutputOp = OutputOp_; -+ -+ static const bool PerChannelScale = (OutputOp::kScale == -+ epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC0 = typename Policy0::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator0 = typename Policy0::Operator; -+ -+ /// Fragment of Scale and Bias loaded from global memory -+ using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC1 = typename Policy1::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator1 = typename Policy1::Operator; -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA0 = Operator0::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB0 = Operator0::kTransformB; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB1 = Operator1::kTransformB; -+ -+ /// Internal structure exposed for introspection. 
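The Detail struct that follows spreads each stage's cp.async copies evenly across the warp-level k-group iterations via a ceiling division, so global-to-shared traffic is interleaved with the tensor-core math. A tiny illustration of that arithmetic with assumed counts (the names are placeholders, not the real members):

#include <cstdio>

int main() {
  const int TBLoadIterationsA0 = 7;     // assumed cp.async copies per stage of operand A
  const int kWarpGemmIterations0 = 4;   // assumed warp-level k-groups per stage
  const int kAccessesPerGroupA0 =
      (TBLoadIterationsA0 + kWarpGemmIterations0 - 1) / kWarpGemmIterations0;
  std::printf("cp.async copies issued per k-group: %d\n", kAccessesPerGroupA0);  // 2
  return 0;
}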
-+ struct Detail { -+ -+ static_assert(Base::kWarpGemmIterations0 > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ static_assert(Base::kWarpGemmIterations1 > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const TBLoadIterationsA0 = -+ IteratorA0::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const TBLoadIterationsB0 = -+ IteratorB0::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const TBLoadIterationsB1 = -+ IteratorB1::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA0 = -+ (TBLoadIterationsA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB0 = -+ (TBLoadIterationsB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB1 = -+ (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA0 = typename Operator0::FragmentA; -+ using WarpLoadedFragmentB0 = typename Operator0::FragmentB; -+ /// Warp Fragment of operand A1 loaded from accmulator tile -+ using WarpLoadedFragmentA1 = typename FragmentIteratorA1::Fragment; -+ using WarpLoadedFragmentA1ScaleBias = -+ typename FragmentIteratorA1ScaleBias::Fragment; -+ using WarpLoadedFragmentB1 = typename Operator1::FragmentB; -+ using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA; -+ using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; -+ using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; -+ using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA0 smem_iterator_A0_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB0 smem_iterator_B0_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB1 smem_iterator_B1_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ B2bMmaMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::B2bMmaSharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx, -+ ///< GEMM0 N is used for accumulator extent -+ int problem_size_0_n -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), thread_idx), -+ smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx), -+ smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the 
warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations0 * warp_idx_k}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {Base::kWarpGemmIterations0 * warp_idx_k, warp_idx_n}); -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {Base::kWarpGemmIterations1 * warp_idx_k, warp_idx_n}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance_0(IteratorA0 &iterator_A0, IteratorB0 &iterator_B0, -+ int group_start_A0 = 0, int group_start_B0 = 0) { -+ iterator_A0.set_iteration_index(group_start_A0 * -+ IteratorA0::kAccessesPerVector); -+ this->smem_iterator_A0_.set_iteration_index(group_start_A0); -+ -+ // Load for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) { -+ if (group_start_A0 + j < Detail::TBLoadIterationsA0) { -+ typename IteratorA0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A0_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA0::ThreadMap::kElementsPerAccess / -+ IteratorA0::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_A0.get(); -+ -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_A0.valid()); -+ -+ ++iterator_A0; -+ } -+ -+ ++this->smem_iterator_A0_; -+ } -+ } -+ -+ iterator_B0.set_iteration_index(group_start_B0 * -+ IteratorB0::kAccessesPerVector); -+ this->smem_iterator_B0_.set_iteration_index(group_start_B0); -+ -+ // Load for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) { -+ if (group_start_B0 + j < Detail::TBLoadIterationsB0) { -+ typename IteratorB0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B0_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB0::ThreadMap::kElementsPerAccess / -+ IteratorB0::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B0.get(); -+ -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B0.valid()); -+ -+ ++iterator_B0; -+ } -+ ++this->smem_iterator_B0_; -+ } -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance_1(IteratorB1 &iterator_B1, -+ int group_start_B1 = 0) { -+ iterator_B1.set_iteration_index(group_start_B1 * -+ IteratorB1::kAccessesPerVector); -+ this->smem_iterator_B1_.set_iteration_index(group_start_B1); -+ -+ // Load for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { -+ if (group_start_B1 + j < Detail::TBLoadIterationsB1) { -+ typename IteratorB1::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B1_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB1::ThreadMap::kElementsPerAccess / -+ IteratorB1::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B1.get(); -+ -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B1.valid()); -+ -+ ++iterator_B1; -+ } -+ 
++this->smem_iterator_B1_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations_0, -+ ///< destination accumulator tile -+ FragmentC1 &accum, -+ ///< iterator over A0 operand in global memory -+ IteratorA0 iterator_A0, -+ ///< iterator over B0 operand in global memory -+ IteratorB0 iterator_B0, -+ ///< iterator over A1 operand scale vector in global memory -+ IteratorAccumulatorScaleBias iterator_A1_scale, -+ ///< iterator over A1 operand bias vector in global memory -+ IteratorAccumulatorScaleBias iterator_A1_bias, -+ ///< iterator over B1 operand in global memory -+ IteratorB1 iterator_B1, -+ ///< initial value of accumulator -+ FragmentC0 const &src_accum, -+ ///< epilogue operation after 1st Gemm -+ OutputOp output_op_0) -+ { -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations_0) { -+ -+ iterator_A0.clear_mask(gemm_k_iterations_0 == 0); -+ iterator_B0.clear_mask(gemm_k_iterations_0 == 0); -+ -+ iterator_A0.set_iteration_index(0); -+ this->smem_iterator_A0_.set_iteration_index(0); -+ -+ // Load for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsA0; ++j) { -+ typename IteratorA0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A0_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA0::ThreadMap::kElementsPerAccess / -+ IteratorA0::kAccessesPerVector / 8; -+ -+ int src_bytes = (iterator_A0.valid() ? kSrcBytes : 0); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A0.get(), iterator_A0.valid()); -+ -+ ++iterator_A0; -+ } -+ -+ ++this->smem_iterator_A0_; -+ } -+ -+ iterator_B0.set_iteration_index(0); -+ this->smem_iterator_B0_.set_iteration_index(0); -+ -+ // Load for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsB0; ++j) { -+ typename IteratorB0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B0_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB0::ThreadMap::kElementsPerAccess / -+ IteratorB0::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B0.get(), iterator_B0.valid()); -+ -+ ++iterator_B0; -+ } -+ -+ ++this->smem_iterator_B0_; -+ } -+ -+ // Move to the next stage -+ iterator_A0.add_tile_offset({0, 1}); -+ iterator_B0.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A0_.add_tile_offset({0, 1}); -+ this->smem_iterator_B0_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. 
-+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ FragmentC0 accum0 = src_accum; -+ -+ // DEPBAR+SYNC -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA0 warp_loaded_frag_A0[2]; -+ WarpLoadedFragmentB0 warp_loaded_frag_B0[2]; -+ WarpTransformedFragmentA0 warp_transformed_frag_A0[2]; -+ WarpTransformedFragmentB0 warp_transformed_frag_B0[2]; -+ -+ Operator0 warp_mma0; -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index(0); -+ this->warp_tile_iterator_B0_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]); -+ this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ iterator_A0.clear_mask(gemm_k_iterations_0 == 0); -+ iterator_B0.clear_mask(gemm_k_iterations_0 == 0); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0], -+ warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]); -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations_0 > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ -+ this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ if (warp_mma_k > 0) -+ warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2], -+ warp_transformed_frag_B0[warp_mma_k % 2], -+ warp_loaded_frag_A0[warp_mma_k % 2], -+ warp_loaded_frag_B0[warp_mma_k % 2]); -+ -+ warp_mma0( -+ accum0, -+ warp_transformed_frag_A0[warp_mma_k % 2], -+ warp_transformed_frag_B0[warp_mma_k % 2], -+ accum0 -+ ); -+ -+ // Issue global->shared copies for the this stage -+ if (warp_mma_k < Base::kWarpGemmIterations0 - 1) { -+ int group_start_iteration_A0, group_start_iteration_B0; -+ -+ group_start_iteration_A0 = warp_mma_k * Detail::kAccessesPerGroupA0; -+ group_start_iteration_B0 = warp_mma_k * Detail::kAccessesPerGroupB0; -+ -+ copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, -+ group_start_iteration_B0); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations0) { -+ int group_start_iteration_A0, group_start_iteration_B0; -+ group_start_iteration_A0 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA0; -+ group_start_iteration_B0 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB0; -+ -+ copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, -+ group_start_iteration_B0); -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. 
-+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A0.add_tile_offset({0, 1}); -+ iterator_B0.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A0_.add_tile_offset({0, 1}); -+ this->smem_iterator_B0_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {0, -Base::kStages * Policy0::kPartitionsK * -+ Base::kWarpGemmIterations0}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {-Base::kStages * Policy0::kPartitionsK * -+ Base::kWarpGemmIterations0, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations_0; -+ iterator_A0.clear_mask(gemm_k_iterations_0 == 0); -+ iterator_B0.clear_mask(gemm_k_iterations_0 == 0); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations0) -+ warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B0[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A0[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); -+ } -+ -+ } -+ -+ // 2nd Gemm -+ -+ /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile -+ FragmentIteratorA1 warp_tile_iterator_A1_(accum0); -+ FragmentA1ScaleBias tb_frag_A1_scale; -+ FragmentA1ScaleBias tb_frag_A1_bias; -+ FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale); -+ FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias); -+ -+ if(PerChannelScale) { -+ tb_frag_A1_scale.clear(); -+ iterator_A1_scale.load(tb_frag_A1_scale); -+ ++iterator_A1_scale; -+ } -+ tb_frag_A1_bias.clear(); -+ iterator_A1_bias.load(tb_frag_A1_bias); -+ ++iterator_A1_bias; -+ -+ // -+ // Prologue -+ // -+ int gemm_k_iterations_1 = (FragmentIteratorA1::Policy::kIterations + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations_1) { -+ -+ iterator_B1.clear_mask(gemm_k_iterations_1 == 0); -+ -+ iterator_B1.set_iteration_index(0); -+ this->smem_iterator_B1_.set_iteration_index(0); -+ -+ // Load for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { -+ typename IteratorB1::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B1_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB1::ThreadMap::kElementsPerAccess / -+ IteratorB1::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); -+ -+ ++iterator_B1; -+ } -+ -+ ++this->smem_iterator_B1_; -+ } -+ -+ // Move to the next stage -+ iterator_B1.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. 
-+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // DEPBAR+SYNC -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; -+ WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_scale[2]; -+ WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_bias[2]; -+ WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; -+ WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; -+ WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; -+ -+ Operator1 warp_mma1; -+ -+ if(PerChannelScale) { -+ warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]); -+ ++warp_tile_iterator_A1_scale_; -+ } -+ warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[0]); -+ ++warp_tile_iterator_A1_bias_; -+ -+ warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0], -+ warp_loaded_frag_A1_scale[0], -+ warp_loaded_frag_A1_bias[0], -+ output_op_0); -+ ++warp_tile_iterator_A1_; -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index(0); -+ this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); -+ ++this->warp_tile_iterator_B1_; -+ -+ iterator_B1.clear_mask(gemm_k_iterations_1 == 0); -+ -+ smem_write_stage_idx = Base::kStages - 1; -+ smem_read_stage_idx = 0; -+ -+ warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0], -+ warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]); -+ -+ // -+ // Mainloop -+ // -+ -+ gemm_k_iterations_1 = (FragmentIteratorA1::Policy::kIterations + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1 - (Base::kStages - 1); -+ CUTLASS_PRAGMA_UNROLL -+ for (; gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; -+ ++warp_mma_k) { -+ -+ // Load threadblock-level scale/bias vector from global memory -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations1) { -+ if(PerChannelScale) { -+ tb_frag_A1_scale.clear(); -+ iterator_A1_scale.load(tb_frag_A1_scale); -+ ++iterator_A1_scale; -+ } -+ tb_frag_A1_bias.clear(); -+ iterator_A1_bias.load(tb_frag_A1_bias); -+ ++iterator_A1_bias; -+ } -+ -+ // Load warp-level scale bias fragment from threadblock scale/bias vector -+ if(PerChannelScale) { -+ warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]); -+ ++warp_tile_iterator_A1_scale_; -+ } -+ warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2]); -+ ++warp_tile_iterator_A1_bias_; -+ -+ // Load warp-level tile from accumulator fragment -+ warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2], -+ output_op_0); -+ ++warp_tile_iterator_A1_; -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. 
-+ this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); -+ this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ ++this->warp_tile_iterator_B1_; -+ -+ if (warp_mma_k > 0) -+ warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ warp_loaded_frag_A1[warp_mma_k % 2], -+ warp_loaded_frag_B1[warp_mma_k % 2]); -+ -+ -+ warp_mma1( -+ accum, -+ warp_transformed_frag_A1[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ accum -+ ); -+ -+ // Issue global->shared copies for the this stage -+ if (warp_mma_k < Base::kWarpGemmIterations1 - 1) { -+ int group_start_iteration_B1; -+ -+ group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1; -+ -+ copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { -+ int group_start_iteration_B1; -+ group_start_iteration_B1 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; -+ -+ copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_B1.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {-Base::kStages * Policy1::kPartitionsK * -+ Base::kWarpGemmIterations1, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ iterator_B1.clear_mask(gemm_k_iterations_1 == 1); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations1) -+ warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ } -+ -+ } -+ -+ -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h -new file mode 100644 -index 0000000..7f42d52 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h -@@ -0,0 +1,869 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "threadblock/b2b_mma_base_smem_accumulator.h" -+#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. 
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape0_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA0_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA0_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA0, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB0_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB0_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB0, -+ /// Iterates over vectors of scale and bias vector in global memory -+ // (concept: VectorIterator) -+ typename IteratorAccumulatorScaleBias_, -+ /// Iterates over accumulator tile -+ typename FragmentIteratorAccumulator_, -+ /// Iterates over accumulator tile in shared memory -+ typename SmemIteratorD0_, -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape1_, -+ /// Iterates over the intermediate accumulator tile in shared memory -+ typename WarpIteratorA1_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB1_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB1_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB1, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) 
-+ typename OutputOp_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy0_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy1_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class B2bMmaMultistageSmemAccumulator : -+ public gemm::threadblock::B2bMmaBaseSmemAccumulator { -+public: -+ ///< Base class -+ using Base = gemm::threadblock::B2bMmaBaseSmemAccumulator; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape0 = Shape0_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA0 = IteratorA0_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB0 = IteratorB0_; -+ ///< Iterates over tiles of the scale and bias vectors in global memory -+ using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; -+ ///< Policy describing tuning details -+ using Policy0 = Policy0_; -+ -+ using SmemIteratorA0 = SmemIteratorA0_; -+ using SmemIteratorB0 = SmemIteratorB0_; -+ using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory -+ -+ using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< Iterates over accumulator tile -+ -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape1 = Shape1_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB1 = IteratorB1_; -+ ///< Policy describing tuning details -+ using Policy1 = Policy1_; -+ -+ using SmemIteratorB1 = SmemIteratorB1_; -+ using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory -+ -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ -+ ///< Epilogue after 1st Gemm -+ using OutputOp = OutputOp_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC0 = typename Policy0::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator0 = typename Policy0::Operator; -+ -+ /// Fragment of Scale and Bias loaded from global memory -+ using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC1 = typename Policy1::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator1 = typename Policy1::Operator; -+ -+ /// Epilog in shared memory -+ using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator< -+ SmemIteratorD0, ///< SmemTileIterator -+ FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator -+ IteratorAccumulatorScaleBias, ///< ScaleBiasIterator -+ OutputOp>; ///< Output operator -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA0 = Operator0::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB0 = Operator0::kTransformB; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB1 = Operator1::kTransformB; -+ -+ /// Internal structure exposed for introspection. 
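The Detail struct below splits one stage's worth of cp.async copies evenly across the warp-level MMA iterations via a ceiling division (kAccessesPerGroup = ceil(TBLoadIterations / kWarpGemmIterations)), so the copy traffic for the next stage interleaves with the math of the current one. A small host-side illustration of that split; the constant values are hypothetical:

#include <cassert>

// Ceiling division: the number of copies each warp-level MMA iteration must
// issue so that all TBLoadIterations copies finish within kWarpGemmIterations steps.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  constexpr int kTBLoadIterations   = 7;  // hypothetical copies per stage
  constexpr int kWarpGemmIterations = 4;  // hypothetical MMA steps per stage
  constexpr int kAccessesPerGroup = ceil_div(kTBLoadIterations, kWarpGemmIterations);
  static_assert(kAccessesPerGroup == 2, "7 copies spread over 4 steps -> 2 per step");
  assert(kAccessesPerGroup * kWarpGemmIterations >= kTBLoadIterations);
  return 0;
}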
-+ struct Detail { -+ -+ static_assert(Base::kWarpGemmIterations0 > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ static_assert(Base::kWarpGemmIterations1 > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const TBLoadIterationsA0 = -+ IteratorA0::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const TBLoadIterationsB0 = -+ IteratorB0::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const TBLoadIterationsB1 = -+ IteratorB1::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA0 = -+ (TBLoadIterationsA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB0 = -+ (TBLoadIterationsB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB1 = -+ (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA0 = typename Operator0::FragmentA; -+ using WarpLoadedFragmentB0 = typename Operator0::FragmentB; -+ using WarpLoadedFragmentA1 = typename Operator1::FragmentA; -+ using WarpLoadedFragmentB1 = typename Operator1::FragmentB; -+ using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA; -+ using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; -+ using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; -+ using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA0 smem_iterator_A0_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB0 smem_iterator_B0_; -+ -+ /// Shared Memory Iterator to store accumulator tile -+ SmemIteratorD0 smem_iterator_D0_; -+ -+ /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile -+ WarpIteratorA1 warp_tile_iterator_A1_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB1 smem_iterator_B1_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ B2bMmaMultistageSmemAccumulator( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::B2bMmaSharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx, -+ ///< GEMM0 N is used for accumulator extent -+ int problem_size_0_n -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx), -+ smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx), -+ smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), -+ 
warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), {Base::WarpGemm1::kM, problem_size_0_n}, lane_idx ), -+ smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ -+ int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM; -+ int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM; -+ -+ int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); -+ int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); -+ -+ int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; -+ int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {warp_idx_m_0, Base::kWarpGemmIterations0 * warp_idx_k_0}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {Base::kWarpGemmIterations0 * warp_idx_k_0, warp_idx_n_0}); -+ warp_tile_iterator_A1_.add_tile_offset( -+ {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); -+ -+ // Add smem accumulator iterator warp offset -+ smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow, -+ warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance_0(IteratorA0 &iterator_A0, IteratorB0 &iterator_B0, -+ int group_start_A0 = 0, int group_start_B0 = 0) { -+ iterator_A0.set_iteration_index(group_start_A0 * -+ IteratorA0::kAccessesPerVector); -+ this->smem_iterator_A0_.set_iteration_index(group_start_A0); -+ -+ // cp.async for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) { -+ if (group_start_A0 + j < Detail::TBLoadIterationsA0) { -+ typename IteratorA0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A0_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA0::ThreadMap::kElementsPerAccess / -+ IteratorA0::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_A0.get(); -+ -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_A0.valid()); -+ -+ ++iterator_A0; -+ } -+ -+ ++this->smem_iterator_A0_; -+ } -+ } -+ -+ iterator_B0.set_iteration_index(group_start_B0 * -+ IteratorB0::kAccessesPerVector); -+ this->smem_iterator_B0_.set_iteration_index(group_start_B0); -+ -+ // cp.async for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) { -+ if (group_start_B0 + j < Detail::TBLoadIterationsB0) { -+ typename IteratorB0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B0_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB0::ThreadMap::kElementsPerAccess / -+ IteratorB0::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B0.get(); -+ -+ 
cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B0.valid()); -+ -+ ++iterator_B0; -+ } -+ ++this->smem_iterator_B0_; -+ } -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance_1(IteratorB1 &iterator_B1, -+ int group_start_B1 = 0) { -+ iterator_B1.set_iteration_index(group_start_B1 * -+ IteratorB1::kAccessesPerVector); -+ this->smem_iterator_B1_.set_iteration_index(group_start_B1); -+ -+ // cp.async for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { -+ if (group_start_B1 + j < Detail::TBLoadIterationsB1) { -+ typename IteratorB1::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B1_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB1::ThreadMap::kElementsPerAccess / -+ IteratorB1::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B1.get(); -+ -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B1.valid()); -+ -+ ++iterator_B1; -+ } -+ ++this->smem_iterator_B1_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations_0, -+ ///< destination accumulator tile -+ FragmentC1 &accum, -+ ///< iterator over A0 operand in global memory -+ IteratorA0 iterator_A0, -+ ///< iterator over B0 operand in global memory -+ IteratorB0 iterator_B0, -+ ///< iterator over A1 operand scale vector in global memory -+ IteratorAccumulatorScaleBias iterator_accum0_scale, -+ ///< iterator over A1 operand bias vector in global memory -+ IteratorAccumulatorScaleBias iterator_accum0_bias, -+ ///< iterator over B1 operand in global memory -+ IteratorB1 iterator_B1, -+ ///< initial value of accumulator -+ FragmentC0 const &src_accum, -+ ///< epilogue operation after 1st Gemm -+ OutputOp output_op_0) -+ { -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations_0) { -+ -+ iterator_A0.clear_mask(gemm_k_iterations_0 == 0); -+ iterator_B0.clear_mask(gemm_k_iterations_0 == 0); -+ -+ iterator_A0.set_iteration_index(0); -+ this->smem_iterator_A0_.set_iteration_index(0); -+ -+ // cp.async for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsA0; ++j) { -+ typename IteratorA0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A0_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA0::ThreadMap::kElementsPerAccess / -+ IteratorA0::kAccessesPerVector / 8; -+ -+ int src_bytes = (iterator_A0.valid() ? 
kSrcBytes : 0); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A0.get(), iterator_A0.valid()); -+ -+ ++iterator_A0; -+ } -+ -+ ++this->smem_iterator_A0_; -+ } -+ -+ iterator_B0.set_iteration_index(0); -+ this->smem_iterator_B0_.set_iteration_index(0); -+ -+ // cp.async for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsB0; ++j) { -+ typename IteratorB0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B0_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB0::ThreadMap::kElementsPerAccess / -+ IteratorB0::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B0.get(), iterator_B0.valid()); -+ -+ ++iterator_B0; -+ } -+ -+ ++this->smem_iterator_B0_; -+ } -+ -+ // Move to the next stage -+ iterator_A0.add_tile_offset({0, 1}); -+ iterator_B0.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A0_.add_tile_offset({0, 1}); -+ this->smem_iterator_B0_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ FragmentC0 accum0 = src_accum; -+ -+ // DEPBAR+SYNC -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA0 warp_loaded_frag_A0[2]; -+ WarpLoadedFragmentB0 warp_loaded_frag_B0[2]; -+ WarpTransformedFragmentA0 warp_transformed_frag_A0[2]; -+ WarpTransformedFragmentB0 warp_transformed_frag_B0[2]; -+ -+ Operator0 warp_mma0; -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index(0); -+ this->warp_tile_iterator_B0_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]); -+ this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ iterator_A0.clear_mask(gemm_k_iterations_0 == 0); -+ iterator_B0.clear_mask(gemm_k_iterations_0 == 0); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0], -+ warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]); -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations_0 > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. 
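The (warp_mma_k + 1) % kWarpGemmIterations indexing that follows prefetches the operands of the next k-group into the spare half of the double buffer while the current k-group is consumed, wrapping back to group 0 at the end of a stage. A compact illustration of the index sequence; the iteration count is an assumption chosen for the example:

#include <cstdio>

int main() {
  constexpr int kWarpGemmIterations = 4;  // hypothetical k-groups per stage
  for (int warp_mma_k = 0; warp_mma_k < kWarpGemmIterations; ++warp_mma_k) {
    int compute_group  = warp_mma_k;                              // consumed this step
    int prefetch_group = (warp_mma_k + 1) % kWarpGemmIterations;  // loaded this step
    int buffer_slot    = (warp_mma_k + 1) % 2;                    // double-buffer half written
    std::printf("step %d: compute k-group %d, prefetch k-group %d into slot %d\n",
                warp_mma_k, compute_group, prefetch_group, buffer_slot);
  }
  return 0;
}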
-+ -+ this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ -+ this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ if (warp_mma_k > 0) -+ warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2], -+ warp_transformed_frag_B0[warp_mma_k % 2], -+ warp_loaded_frag_A0[warp_mma_k % 2], -+ warp_loaded_frag_B0[warp_mma_k % 2]); -+ -+ warp_mma0( -+ accum0, -+ warp_transformed_frag_A0[warp_mma_k % 2], -+ warp_transformed_frag_B0[warp_mma_k % 2], -+ accum0 -+ ); -+ -+ // Issue global->shared copies for the this stage -+ if (warp_mma_k < Base::kWarpGemmIterations0 - 1) { -+ int group_start_iteration_A0, group_start_iteration_B0; -+ -+ group_start_iteration_A0 = warp_mma_k * Detail::kAccessesPerGroupA0; -+ group_start_iteration_B0 = warp_mma_k * Detail::kAccessesPerGroupB0; -+ -+ copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, -+ group_start_iteration_B0); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations0) { -+ int group_start_iteration_A0, group_start_iteration_B0; -+ group_start_iteration_A0 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA0; -+ group_start_iteration_B0 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB0; -+ -+ copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, -+ group_start_iteration_B0); -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A0.add_tile_offset({0, 1}); -+ iterator_B0.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A0_.add_tile_offset({0, 1}); -+ this->smem_iterator_B0_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {0, -Base::kStages * Policy0::kPartitionsK * -+ Base::kWarpGemmIterations0}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {-Base::kStages * Policy0::kPartitionsK * -+ Base::kWarpGemmIterations0, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations_0; -+ iterator_A0.clear_mask(gemm_k_iterations_0 == 0); -+ iterator_B0.clear_mask(gemm_k_iterations_0 == 0); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations0) -+ warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B0[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A0[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); -+ } -+ -+ } -+ -+ /// Epilogue for the first Implicit Gemm -+ Epilogue0 epilogue0; -+ -+ epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias); -+ -+ __syncthreads(); -+ -+ -+ // 
2nd Gemm -+ -+ // -+ // Prologue -+ // -+ int gemm_k_iterations_1 = Shape0::kN / Shape1::kK; -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations_1) { -+ -+ iterator_B1.clear_mask(gemm_k_iterations_1 == 0); -+ -+ iterator_B1.set_iteration_index(0); -+ this->smem_iterator_B1_.set_iteration_index(0); -+ -+ // cp.async for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { -+ typename IteratorB1::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B1_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB1::ThreadMap::kElementsPerAccess / -+ IteratorB1::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); -+ -+ ++iterator_B1; -+ } -+ -+ ++this->smem_iterator_B1_; -+ } -+ -+ // Move to the next stage -+ iterator_B1.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // DEPBAR+SYNC -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; -+ WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; -+ WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; -+ WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; -+ -+ Operator1 warp_mma1; -+ -+ warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]); -+ ++warp_tile_iterator_A1_; -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index(0); -+ this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); -+ ++this->warp_tile_iterator_B1_; -+ -+ iterator_B1.clear_mask(gemm_k_iterations_1 == 0); -+ -+ smem_write_stage_idx = Base::kStages - 1; -+ smem_read_stage_idx = 0; -+ -+ warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0], -+ warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]); -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( gemm_k_iterations_1 = Shape0::kN / Shape1::kK - (Base::kStages - 1); -+ gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; -+ ++warp_mma_k) { -+ -+ // Load warp-level tile from accumulator fragment -+ // skip warp tile loading for the last kgroup -+ if(gemm_k_iterations_1 > (-Base::kStages + 2) || warp_mma_k < Base::kWarpGemmIterations1 - 1) { -+ warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2]); -+ } -+ ++warp_tile_iterator_A1_; -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. 
-+ this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); -+ this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ ++this->warp_tile_iterator_B1_; -+ -+ -+ if (warp_mma_k > 0) -+ warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ warp_loaded_frag_A1[warp_mma_k % 2], -+ warp_loaded_frag_B1[warp_mma_k % 2]); -+ -+ -+ warp_mma1( -+ accum, -+ warp_transformed_frag_A1[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ accum -+ ); -+ -+ // Issue global->shared copies for the this stage -+ if (warp_mma_k < Base::kWarpGemmIterations1 - 1) { -+ int group_start_iteration_B1; -+ -+ group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1; -+ -+ copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { -+ int group_start_iteration_B1; -+ group_start_iteration_B1 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; -+ -+ copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_B1.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {-Base::kStages * Policy1::kPartitionsK * -+ Base::kWarpGemmIterations1, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ iterator_B1.clear_mask(gemm_k_iterations_1 == 1); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations1) -+ warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ } -+ -+ } -+ -+ -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h -new file mode 100644 -index 0000000..c36d133 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h -@@ -0,0 +1,554 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped Back-to-back fused GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "threadblock/b2b_mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
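The TransformA0_/TransformB0_/TransformB1_ defaults in the template below are NumericArrayConverter functors that convert an entire global-memory fragment to the shared-memory element type before it is stored. A minimal usage sketch of that converter; the element types and fragment width here are assumptions, not the values used by the example kernels:

#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"

// Convert an 8-element float fragment to half precision in one shot, the same
// shape of operation the default Transform functors perform on loaded fragments.
__host__ __device__ cutlass::Array<cutlass::half_t, 8>
to_half_fragment(cutlass::Array<float, 8> const &src) {
  cutlass::NumericArrayConverter<cutlass::half_t, float, 8> converter;
  return converter(src);
}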
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape0_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorA0_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA0_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB0_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB0_, -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape1_, -+ /// Iterates over the intermediate accumulator tile -+ // (concept::MmaTensorOpFragmentIterator) -+ typename FragmentIteratorA1_, -+ /// Iterates over vectors of scale and bias vector in global memory -+ // (concept: VectorIterator) -+ typename IteratorAccumulatorScaleBias_, -+ /// FragmentIterator to load Scale or Bias vector from threadblock fragment -+ typename FragmentIteratorA1ScaleBias_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB1_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB1_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) -+ typename OutputOp_, -+ /// Policy describing tuning details (concept: MmaPipelinedPolicy) -+ typename Policy0_, -+ /// Policy describing tuning details (concept: MmaPipelinedPolicy) -+ typename Policy1_, -+ /// Transformation applied to A0 operand -+ typename TransformA0_ = NumericArrayConverter< -+ typename SmemIteratorA0_::Element, -+ typename IteratorA0_::Element, -+ IteratorA0_::Fragment::kElements>, -+ /// -+ /// Transformation applied to B0 operand -+ typename TransformB0_ = NumericArrayConverter< -+ typename SmemIteratorB0_::Element, -+ typename IteratorB0_::Element, -+ IteratorB0_::Fragment::kElements>, -+ /// -+ /// Transformation applied to B1 operand -+ typename TransformB1_ = NumericArrayConverter< -+ typename SmemIteratorB1_::Element, -+ typename IteratorB1_::Element, -+ IteratorB1_::Fragment::kElements>, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class B2bMmaPipelined : -+ public B2bMmaBase { -+public: -+ -+ ///< Base class -+ using Base = B2bMmaBase; -+ -+ using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory -+ using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory -+ using Policy0 = Policy0_; ///< Policy describing tuning details -+ -+ using SmemIteratorA0 = SmemIteratorA0_; -+ using SmemIteratorB0 = SmemIteratorB0_; -+ -+ using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using FragmentIteratorA1 = FragmentIteratorA1_; ///< Iterates over intermediate accumulator tile -+ using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory -+ using FragmentIteratorA1ScaleBias = -+ 
FragmentIteratorA1ScaleBias_; ///< WarpIterator to load Scale or Bias vector from the threadblock fragment -+ using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory -+ using Policy1 = Policy1_; ///< Policy describing tuning details -+ -+ using SmemIteratorB1 = SmemIteratorB1_; -+ -+ -+ using ElementC = ElementC_; ///< Data type of accumulator matrix -+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix -+ -+ using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm -+ -+ static const bool PerChannelScale = (OutputOp::kScale == -+ epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); -+ -+ using TransformA0 = TransformA0_; -+ using TransformB0 = TransformB0_; -+ using TransformB1 = TransformB1_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of operand A loaded from global memory -+ using FragmentA0 = typename IteratorA0::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB0 = typename IteratorB0::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC0 = typename Policy0::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator0 = typename Policy0::Operator; -+ -+ /// Fragment of Scale and Bias loaded from global memory -+ using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB1 = typename IteratorB1::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC1 = typename Policy1::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator1 = typename Policy1::Operator; -+ -+ /// Obtain the arch tag from the warp-level operator -+ using ArchTag = typename Policy0::Operator::ArchTag; -+ -+ /// Complex transform on A0 operand -+ static ComplexTransform const kTransformA0 = Operator0::kTransformA; -+ -+ /// Complex transform on B0 operand -+ static ComplexTransform const kTransformB0 = Operator0::kTransformB; -+ -+ /// Complex transform on B1 operand -+ static ComplexTransform const kTransformB1 = Operator1::kTransformB; -+ -+ /// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) -+ static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); -+ -+private: -+ -+ using WarpFragmentA0 = typename Operator0::FragmentA; -+ using WarpFragmentB0 = typename Operator0::FragmentB; -+ /// Warp Fragment of operand A1 loaded from accmulator tile -+ using WarpFragmentA1 = typename FragmentIteratorA1::Fragment; -+ /// Warp Fragment of operand A1 scale and bias loaded from threadblock fragment -+ using WarpFragmentA1ScaleBias = -+ typename FragmentIteratorA1ScaleBias::Fragment; -+ using WarpFragmentB1 = typename Operator1::FragmentB; -+ -+protected: -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA0 smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B0 operand to shared memory -+ SmemIteratorB0 smem_iterator_B0_; -+ -+ /// Iterator to write threadblock-scoped tile of B1 operand to shared memory -+ SmemIteratorB1 smem_iterator_B1_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ B2bMmaPipelined( -+ typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ int thread_idx, ///< ID within the threadblock -+ int warp_idx, ///< ID of warp -+ int lane_idx, ///< ID of each thread within a warp -+ int problem_size_0_n ///< GEMM0 N is used for accumulator extent -+ ): -+ Base(shared_storage, 
thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.shared_storage0.operand_A_ref(), thread_idx), -+ smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx), -+ smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx) { -+ -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ //These should stay the same across different GEMM layers -+ int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM; -+ -+ //These may change across different GEMM layers -+ int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k; -+ int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m, tile_offset_k_0}); -+ this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n}); -+ this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n}); -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ int gemm_k_iterations_0, ///< number of iterations of the mainloop -+ FragmentC1 &accum, ///< destination accumulator tile -+ IteratorA0 iterator_A, ///< iterator over A operand in global memory -+ IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory -+ IteratorAccumulatorScaleBias iterator_A1_scale, ///< iterator over A1 operand scale vectors in global memory -+ IteratorAccumulatorScaleBias iterator_A1_bias, ///< iterator over A1 operand bias vectors in global memory -+ IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory -+ FragmentC0 const &src_accum, ///< source accumualtor tile -+ OutputOp output_op_0, ///< epilogue operation after 1st Gemm -+ TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment -+ TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment -+ TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment -+ -+ // -+ // Prologue -+ // -+ -+ // Perform accumulation in the 'd' output operand -+ FragmentC0 accum0 = src_accum; -+ -+ FragmentA0 tb_frag_A; -+ FragmentB0 tb_frag_B0; -+ -+ tb_frag_A.clear(); -+ tb_frag_B0.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_A.load(tb_frag_A); -+ iterator_B0.load(tb_frag_B0); -+ -+ ++iterator_A; -+ ++iterator_B0; -+ -+ this->smem_iterator_A_.store(transform_A0(tb_frag_A)); -+ this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B0_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA0 warp_frag_A0[2]; -+ WarpFragmentB0 warp_frag_B0[2]; -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index(0); -+ this->warp_tile_iterator_B0_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A0_.load(warp_frag_A0[0]); -+ this->warp_tile_iterator_B0_.load(warp_frag_B0[0]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ Operator0 warp_mma0; -+ 
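From here the pipelined mainloop is a classic two-stage software pipeline: one fragment pair feeds the warp-level MMA while the other pair is refilled from shared memory, and smem_write_stage_idx flips between the two shared-memory stages. A stripped-down sketch of that ping-pong, independent of the CUTLASS iterator types (function and variable names are illustrative):

// Double-buffered loop: compute on slot k % 2 while slot (k + 1) % 2 is
// refilled, so loads overlap with math instead of serializing.
__device__ float pipelined_dot_sketch(const float *a, const float *b, int n) {
  if (n <= 0) {
    return 0.f;
  }
  float frag_a[2], frag_b[2];
  float acc = 0.f;

  frag_a[0] = a[0];  // prologue: preload the first operands
  frag_b[0] = b[0];

  for (int k = 0; k < n; ++k) {
    if (k + 1 < n) {  // prefetch the next operands into the spare slot
      frag_a[(k + 1) % 2] = a[k + 1];
      frag_b[(k + 1) % 2] = b[k + 1];
    }
    acc += frag_a[k % 2] * frag_b[k % 2];  // math on the resident slot
  }
  return acc;
}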
-+ int smem_write_stage_idx = 1; -+ -+ // Avoid reading out of bounds -+ iterator_A.clear_mask(gemm_k_iterations_0 <= 1); -+ iterator_B0.clear_mask(gemm_k_iterations_0 <= 1); -+ -+ // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing -+ // shared memory loads (which have the tighest latency requirement). -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations0 - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_A_.store(transform_A0(tb_frag_A)); -+ -+ this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B0_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ } -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ -+ this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ if (warp_mma_k == 0) { -+ -+ iterator_A.load(tb_frag_A); -+ iterator_B0.load(tb_frag_B0); -+ ++iterator_A; -+ ++iterator_B0; -+ -+ // Avoid reading out of bounds if this was the last loop iteration -+ iterator_A.clear_mask(gemm_k_iterations_0 <= 2); -+ iterator_B0.clear_mask(gemm_k_iterations_0 <= 2); -+ } -+ -+ warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], -+ warp_frag_B0[warp_mma_k % 2], accum0); -+ } -+ } -+ -+ //2nd Gemm -+ -+ /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile -+ FragmentIteratorA1 warp_tile_iterator_A1_(accum0); -+ -+ // -+ // Prologue -+ // -+ -+ FragmentA1ScaleBias tb_frag_A1_scale; -+ FragmentA1ScaleBias tb_frag_A1_bias; -+ FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale); -+ FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias); -+ FragmentB1 tb_frag_B1; -+ -+ if(PerChannelScale) -+ tb_frag_A1_scale.clear(); -+ tb_frag_A1_bias.clear(); -+ tb_frag_B1.clear(); -+ -+ // The last kblock is loaded in the prolog -+ if(PerChannelScale) -+ iterator_A1_scale.load(tb_frag_A1_scale); -+ iterator_A1_bias.load(tb_frag_A1_bias); -+ iterator_B1.load(tb_frag_B1); -+ -+ if(PerChannelScale) -+ ++iterator_A1_scale; -+ ++iterator_A1_bias; -+ ++iterator_B1; -+ -+ this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); -+ -+ ++this->smem_iterator_B1_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap 
shared memory loads and math instructions -+ WarpFragmentA1ScaleBias warp_frag_A1_scale[2]; -+ WarpFragmentA1ScaleBias warp_frag_A1_bias[2]; -+ WarpFragmentA1 warp_frag_A1[2]; -+ WarpFragmentB1 warp_frag_B1[2]; -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index(0); -+ -+ if(PerChannelScale) -+ warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[0]); -+ warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[0]); -+ warp_tile_iterator_A1_.load(warp_frag_A1[0], warp_frag_A1_scale[0], -+ warp_frag_A1_bias[0], output_op_0); -+ this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); -+ -+ ++warp_tile_iterator_A1_; -+ if(PerChannelScale) -+ ++warp_tile_iterator_A1_scale_; -+ ++warp_tile_iterator_A1_bias_; -+ ++this->warp_tile_iterator_B1_; -+ -+ Operator1 warp_mma1; -+ -+ smem_write_stage_idx = 1; -+ -+ int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1; -+ -+ // Avoid reading out of bounds -+ iterator_B1.clear_mask(gemm_k_iterations_1 <= 1); -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::WarpGemmIterations == 2. -+ CUTLASS_PRAGMA_UNROLL -+ for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) { -+ -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations1 - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); -+ -+ __syncthreads(); -+ ++this->smem_iterator_B1_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {-Base::kStages * Policy1::kPartitionsK * -+ Base::kWarpGemmIterations1, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ -+ if(PerChannelScale) { -+ tb_frag_A1_scale.clear(); -+ iterator_A1_scale.load(tb_frag_A1_scale); -+ ++iterator_A1_scale; -+ } -+ tb_frag_A1_bias.clear(); -+ iterator_A1_bias.load(tb_frag_A1_bias); -+ ++iterator_A1_bias; -+ } -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); -+ -+ if(PerChannelScale) -+ warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[(warp_mma_k + 1) % 2]); -+ warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[(warp_mma_k + 1) % 2]); -+ warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], -+ warp_frag_A1_scale[(warp_mma_k + 1) % 2], -+ warp_frag_A1_bias[(warp_mma_k + 1) % 2], -+ output_op_0); -+ this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); -+ -+ if(PerChannelScale) -+ ++warp_tile_iterator_A1_scale_; -+ ++warp_tile_iterator_A1_bias_; -+ ++warp_tile_iterator_A1_; -+ ++this->warp_tile_iterator_B1_; -+ -+ if (warp_mma_k == 0) { -+ -+ iterator_B1.load(tb_frag_B1); -+ ++iterator_B1; -+ -+ // Avoid reading out of bounds if this was the last loop iteration -+ iterator_B1.clear_mask(gemm_k_iterations_1 <= 2); -+ } -+ -+ warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], -+ warp_frag_B1[warp_mma_k % 2], accum); -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git 
a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h -new file mode 100644 -index 0000000..351fae3 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h -@@ -0,0 +1,544 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped Back-to-back fused GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "threadblock/b2b_mma_base_smem_accumulator.h" -+#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape0_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorA0_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA0_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB0_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB0_, -+ /// Iterates over vectors of scale and bias vector in global memory -+ // (concept: VectorIterator) -+ typename IteratorAccumulatorScaleBias_, -+ /// Iterates over accumulator tile -+ typename FragmentIteratorAccumulator_, -+ /// Iterates over accumulator tile in shared memory -+ typename SmemIteratorD0_, -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape1_, -+ /// Iterates over the intermediate accumulator tile in shared memory -+ typename WarpIteratorA1_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB1_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB1_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) -+ typename OutputOp_, -+ /// Policy describing tuning details (concept: MmaPipelinedPolicy) -+ typename Policy0_, -+ /// Policy describing tuning details (concept: MmaPipelinedPolicy) -+ typename Policy1_, -+ /// Transformation applied to A0 operand -+ typename TransformA0_ = NumericArrayConverter< -+ typename SmemIteratorA0_::Element, -+ typename IteratorA0_::Element, -+ IteratorA0_::Fragment::kElements>, -+ /// -+ /// Transformation applied to B0 operand -+ typename TransformB0_ = NumericArrayConverter< -+ typename SmemIteratorB0_::Element, -+ typename IteratorB0_::Element, -+ IteratorB0_::Fragment::kElements>, -+ /// -+ /// Transformation applied to B1 operand -+ typename TransformB1_ = NumericArrayConverter< -+ typename SmemIteratorB1_::Element, -+ typename IteratorB1_::Element, -+ IteratorB1_::Fragment::kElements>, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class B2bMmaPipelinedSmemAccumulator : -+ public B2bMmaBaseSmemAccumulator { -+public: -+ -+ ///< Base class -+ using Base = B2bMmaBaseSmemAccumulator; -+ -+ using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory -+ using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory -+ using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory -+ using Policy0 = Policy0_; ///< Policy0 describing tuning details -+ -+ using SmemIteratorA0 = SmemIteratorA0_; -+ using SmemIteratorB0 = SmemIteratorB0_; -+ using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory -+ -+ using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< 
Iterates over accumulator tile -+ -+ using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory -+ using Policy1 = Policy1_; ///< Policy1 describing tuning details -+ -+ using SmemIteratorB1 = SmemIteratorB1_; -+ using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory -+ -+ -+ using ElementC = ElementC_; ///< Data type of accumulator matrix -+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix -+ -+ using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm -+ -+ using TransformA0 = TransformA0_; -+ using TransformB0 = TransformB0_; -+ using TransformB1 = TransformB1_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of operand A loaded from global memory -+ using FragmentA0 = typename IteratorA0::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB0 = typename IteratorB0::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC0 = typename Policy0::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator0 = typename Policy0::Operator; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB1 = typename IteratorB1::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC1 = typename Policy1::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator1 = typename Policy1::Operator; -+ -+ /// Obtain the arch tag from the warp-level operator -+ using ArchTag = typename Policy0::Operator::ArchTag; -+ -+ /// Complex transform on A0 operand -+ static ComplexTransform const kTransformA0 = Operator0::kTransformA; -+ -+ /// Complex transform on B0 operand -+ static ComplexTransform const kTransformB0 = Operator0::kTransformB; -+ -+ /// Complex transform on B1 operand -+ static ComplexTransform const kTransformB1 = Operator1::kTransformB; -+ -+ /// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) -+ static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); -+ -+ /// Epilog in shared memory -+ using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator< -+ SmemIteratorD0, ///< SmemTileIterator -+ FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator -+ IteratorAccumulatorScaleBias, ///< ScaleBiasIterator -+ OutputOp>; ///< Output operator -+ -+ -+ -+private: -+ -+ using WarpFragmentA0 = typename Operator0::FragmentA; -+ using WarpFragmentB0 = typename Operator0::FragmentB; -+ using WarpFragmentA1 = typename Operator1::FragmentA; -+ using WarpFragmentB1 = typename Operator1::FragmentB; -+ -+protected: -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA0 smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B0 operand to shared memory -+ SmemIteratorB0 smem_iterator_B0_; -+ -+ /// Shared Memory Iterator to store accumulator tile -+ SmemIteratorD0 smem_iterator_D0_; -+ -+ /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile -+ WarpIteratorA1 warp_tile_iterator_A1_; -+ -+ /// Iterator to write threadblock-scoped tile of B1 operand to shared memory -+ SmemIteratorB1 smem_iterator_B1_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ B2bMmaPipelinedSmemAccumulator( -+ typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ int thread_idx, ///< ID within the threadblock -+ int 
warp_idx, ///< ID of warp -+ int lane_idx, ///< ID of each thread within a warp -+ int problem_size_0_n ///< GEMM0 N is used for accumulator extent -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx), -+ smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx), -+ smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), -+ warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), {Base::WarpGemm1::kM, problem_size_0_n}, lane_idx), -+ smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) { -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ -+ int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM; -+ int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM; -+ -+ int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k_0; -+ -+ int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); -+ int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); -+ -+ int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; -+ int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; -+ -+ int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k_1; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m_0, tile_offset_k_0}); -+ this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n_0}); -+ warp_tile_iterator_A1_.add_tile_offset({warp_idx_m_1, tile_offset_k_1}); -+ this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n_1}); -+ -+ // Add smem accumulator iterator warp offset -+ smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow, -+ warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn}); -+ -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ int gemm_k_iterations_0, ///< number of iterations of the mainloop -+ FragmentC1 &accum, ///< destination accumulator tile -+ IteratorA0 iterator_A, ///< iterator over A operand in global memory -+ IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory -+ IteratorAccumulatorScaleBias iterator_accum0_scale, ///< iterator over D0 scale vector in global memory -+ IteratorAccumulatorScaleBias iterator_accum0_bias, ///< iterator over D0 bias vector in global memory -+ IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory -+ FragmentC0 const &src_accum, ///< source accumualtor tile -+ OutputOp output_op_0, ///< epilogue operation after 1st Gemm -+ TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment -+ TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment -+ TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment -+ -+ // -+ // Prologue -+ // -+ -+ // Perform accumulation in the 'd' output operand -+ FragmentC0 accum0 = 
src_accum; -+ -+ FragmentA0 tb_frag_A; -+ FragmentB0 tb_frag_B0; -+ -+ tb_frag_A.clear(); -+ tb_frag_B0.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_A.load(tb_frag_A); -+ iterator_B0.load(tb_frag_B0); -+ -+ ++iterator_A; -+ ++iterator_B0; -+ -+ this->smem_iterator_A_.store(transform_A0(tb_frag_A)); -+ this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B0_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA0 warp_frag_A0[2]; -+ WarpFragmentB0 warp_frag_B0[2]; -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index(0); -+ this->warp_tile_iterator_B0_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A0_.load(warp_frag_A0[0]); -+ this->warp_tile_iterator_B0_.load(warp_frag_B0[0]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ Operator0 warp_mma0; -+ -+ int smem_write_stage_idx = 1; -+ -+ // Avoid reading out of bounds -+ iterator_A.clear_mask(gemm_k_iterations_0 <= 1); -+ iterator_B0.clear_mask(gemm_k_iterations_0 <= 1); -+ -+ // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing -+ // shared memory loads (which have the tighest latency requirement). -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations0 - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_A_.store(transform_A0(tb_frag_A)); -+ -+ this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B0_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ } -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ -+ this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ if (warp_mma_k == 0) { -+ -+ iterator_A.load(tb_frag_A); -+ iterator_B0.load(tb_frag_B0); -+ ++iterator_A; -+ ++iterator_B0; -+ -+ // Avoid reading out of bounds if this was the last loop iteration -+ iterator_A.clear_mask(gemm_k_iterations_0 <= 2); -+ iterator_B0.clear_mask(gemm_k_iterations_0 <= 2); -+ } -+ -+ warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], -+ warp_frag_B0[warp_mma_k % 2], accum0); -+ } -+ } -+ -+ /// Epilogue for the first Implicit Gemm -+ Epilogue0 
epilogue0; -+ -+ epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias); -+ -+ __syncthreads(); -+ -+ //2nd Gemm -+ -+ // -+ // Prologue -+ // -+ -+ FragmentB1 tb_frag_B1; -+ -+ tb_frag_B1.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_B1.load(tb_frag_B1); -+ -+ ++iterator_B1; -+ -+ this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); -+ -+ ++this->smem_iterator_B1_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA1 warp_frag_A1[2]; -+ WarpFragmentB1 warp_frag_B1[2]; -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index(0); -+ -+ warp_tile_iterator_A1_.load(warp_frag_A1[0]); -+ this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); -+ -+ ++warp_tile_iterator_A1_; -+ ++this->warp_tile_iterator_B1_; -+ -+ Operator1 warp_mma1; -+ -+ smem_write_stage_idx = 1; -+ -+ int gemm_k_iterations_1 = Shape0::kN / Shape1::kK; -+ -+ // Avoid reading out of bounds -+ iterator_B1.clear_mask(gemm_k_iterations_1 <= 1); -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_PRAGMA_UNROLL -+ for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations1 - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_B1_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {-Base::kStages * Policy1::kPartitionsK * -+ Base::kWarpGemmIterations1, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ -+ } -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); -+ -+ // skip warp tile loading for the last kgroup -+ if(gemm_k_iterations_1 > 1 || warp_mma_k < Base::kWarpGemmIterations1 - 1) -+ warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); -+ -+ ++warp_tile_iterator_A1_; -+ ++this->warp_tile_iterator_B1_; -+ -+ if (warp_mma_k == 0) { -+ -+ iterator_B1.load(tb_frag_B1); -+ -+ ++iterator_B1; -+ -+ // Avoid reading out of bounds if this was the last loop iteration -+ iterator_B1.clear_mask(gemm_k_iterations_1 <= 2); -+ } -+ -+ warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], -+ warp_frag_B1[warp_mma_k % 2], accum); -+ -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h -new file mode 100644 -index 0000000..d1842f6 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h -@@ -0,0 +1,584 @@ -+/*************************************************************************************************** -+ * 
Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+ -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -+#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" -+#include "cutlass/transform/threadblock/vector_iterator.h" -+#include "cutlass/transform/warp/vector_fragment_iterator.h" -+ -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "threadblock/b2b_mma_pipelined.h" -+#include "threadblock/b2b_mma_multistage.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false, -+ /// Staging the accumulators in shared memory. 
-+ bool SmemAccumulator = false> -+struct DefaultB2bMma; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Specialization for row-major output with 2-stage pipeline -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Epilogue output operator -+ typename EpilogueOutputOp> -+struct DefaultB2bMma { -+ // Define the MmaCore components -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, 2, Operator>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, 2, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA0 = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore0::IteratorThreadMapA, kAlignmentA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB0 = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore0::IteratorThreadMapB, kAlignmentB>; -+ -+ // Use fragment iterator for A operand -+ using AccumulatorLayout = cutlass::layout::ColumnMajor; -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp>; -+ -+ using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ // Define iterators over tiles from the B operand 
-+ using IteratorB1 = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB, kAlignmentB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined< -+ typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, -+ IteratorB0, typename MmaCore0::SmemIteratorB, -+ typename MmaCore1::Shape, FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias, -+ IteratorB1, typename MmaCore1::SmemIteratorB, -+ ElementAccumulator, layout::RowMajor, -+ EpilogueOutputOp, -+ typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Specialization for row-major output for multi-stage -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Epilogue output operator -+ typename EpilogueOutputOp> -+struct DefaultB2bMma { -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? 
cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ -+ // Define the MmaCore components -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using AccessTypeA0 = cutlass::Array; -+ using IteratorA0 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA0>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using AccessTypeB0 = cutlass::Array; -+ using IteratorB0 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB0>; -+ -+ // Use fragment iterator for A operand -+ using AccumulatorLayout = cutlass::layout::ColumnMajor; -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using AccessTypeB1 = cutlass::Array; -+ using IteratorB1 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB1>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistage< -+ typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, -+ MmaCore0::kCacheOpA, -+ IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB, -+ typename MmaCore1::Shape, FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias, -+ IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB, -+ ElementAccumulator, layout::RowMajor, -+ EpilogueOutputOp, -+ typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>; -+ -+}; -+ -+ 
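The B2bMmaPipelined and B2bMmaMultistage mainloops that these DefaultB2bMma specializations instantiate all follow the same double-buffering discipline: shared-memory stores for k-block k+1 are issued while warp-level MMAs consume k-block k, and a stage index ping-pongs between the two shared-memory buffers. The sketch below is a simplified, standalone illustration of that schedule only; pipelined_mainloop, Fragment, load_global, store_shared, load_warp and warp_mma are hypothetical placeholders, not CUTLASS types or APIs.

```cpp
// Simplified, standalone sketch (hypothetical helpers, not CUTLASS API) of the
// two-stage software pipeline used by the B2bMma mainloops: global->shared
// traffic for k-block k+1 overlaps with warp-level MMAs on k-block k, and a
// stage index ping-pongs between the two shared-memory buffers (kStages == 2).
constexpr int kWarpGemmIterations = 4;  // warp-level MMAs per k-block (illustrative)

struct Fragment { float data[8]; };     // stand-in for a per-thread tile fragment

template <typename LoadGlobal, typename StoreShared, typename LoadWarp, typename WarpMma>
void pipelined_mainloop(int gemm_k_iterations, LoadGlobal load_global,
                        StoreShared store_shared, LoadWarp load_warp, WarpMma warp_mma) {
  Fragment tb_frag;       // threadblock-scoped fragment staged from global memory
  Fragment warp_frag[2];  // double-buffered warp fragments read from shared memory

  load_global(tb_frag, 0);     // prologue: fetch the first k-block from global memory
  store_shared(tb_frag, 0);    // ...and commit it to shared-memory stage 0
  load_warp(warp_frag[0], 0);  // first warp-level tile for the math pipeline

  int smem_write_stage = 1;    // next shared-memory stage to be written
  for (int k = 0; k < gemm_k_iterations; ++k) {
    for (int wk = 0; wk < kWarpGemmIterations; ++wk) {
      if (wk == kWarpGemmIterations - 1) {
        store_shared(tb_frag, smem_write_stage);   // commit the next k-block to smem
        smem_write_stage ^= 1;                     // ping-pong between the two stages
      }
      load_warp(warp_frag[(wk + 1) % 2], wk + 1);  // prefetch the next warp tile
      if (wk == 0) {
        load_global(tb_frag, k + 1);               // issue the next global loads early
      }
      warp_mma(warp_frag[wk % 2]);                 // math on the currently filled buffer
    }
  }
}
```

In the kernels above, the same schedule runs twice back to back (GEMM0, then GEMM1 consuming GEMM0's accumulator as its A operand), with __syncthreads() separating each shared-memory stage write from the warp-level loads that read it.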
-+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for column-major-interleaved output with 2-stage pipeline -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Number of Interleaved K -+ int InterleavedK> -+struct DefaultB2bMma, OperatorClass, arch::Sm75, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, 2, Operator, EpilogueOutputOp, true> { -+ // Define the MmaCore components -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, -+ true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, -+ true>; -+ -+ static_assert(kAlignmentA == 128 / sizeof_bits::value, -+ "Alignment must match thread data map's vector length"); -+ -+ static_assert(kAlignmentB ==128 / sizeof_bits::value, -+ "Alignment must match thread data map's vector length"); -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA0 = cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, ElementA, -+ LayoutA, 1, typename MmaCore0::IteratorThreadMapA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB0 = cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, ElementB, -+ LayoutB, 0, typename MmaCore0::IteratorThreadMapB>; -+ -+ // Use fragment iterator for A1 operand -+ using AccumulatorLayout = cutlass::layout::RowMajor; //AccumulatorsInRowMajor = true -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, -+ InstructionShape, EpilogueOutputOp>; -+ -+ using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ 
cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB1 = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>; -+ -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined< -+ typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, -+ IteratorB0, typename MmaCore0::SmemIteratorB, -+ typename MmaCore1::Shape, FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias, -+ IteratorB1, typename MmaCore1::SmemIteratorB, -+ ElementAccumulator, layout::ColumnMajorInterleaved, -+ EpilogueOutputOp, -+ typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for column-major-interleaved output with multi-stage -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Number of Interleaved K -+ int InterleavedK> -+struct DefaultB2bMma, OperatorClass, ArchTag, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, Stages, Operator, EpilogueOutputOp, true> { -+ // Define the MmaCore components -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, Stages, -+ Operator, true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, Stages, -+ Operator, true>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA0 = 
-+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB0 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB>; -+ -+ // Use fragment iterator for A1 operand -+ using AccumulatorLayout = cutlass::layout::RowMajor; //AccumulatorsInRowMajor = true -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, -+ InstructionShape, EpilogueOutputOp>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using IteratorB1 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB>; -+ -+ -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistage< -+ typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, -+ MmaCore0::kCacheOpA, -+ IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB, -+ typename MmaCore1::Shape, FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias, -+ IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB, -+ ElementAccumulator, layout::ColumnMajorInterleaved, -+ EpilogueOutputOp, -+ typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h -new file mode 100644 -index 0000000..1ef7e50 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h -@@ -0,0 +1,605 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+ -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" -+ -+#include "threadblock/b2b_mma_pipelined_smem_accumulator.h" -+#include "threadblock/b2b_mma_multistage_smem_accumulator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Specialization for row-major output with 2-stage pipeline -+/// Accumulator will be staged in shared memory. 
-+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Epilogue output operator -+ typename EpilogueOutputOp> -+struct DefaultB2bMma { -+ // Define the MmaCore components -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, 2, Operator>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, 2, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA0 = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore0::IteratorThreadMapA, kAlignmentA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB0 = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore0::IteratorThreadMapB, kAlignmentB>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB1 = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB, kAlignmentB>; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::RowMajor; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ 
InstructionShape, -+ typename EpilogueOutputOp::ElementOutput, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator< -+ typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, -+ IteratorB0, typename MmaCore0::SmemIteratorB, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, SmemIteratorD0, -+ typename MmaCore1::Shape, WarpIteratorA1, -+ IteratorB1, typename MmaCore1::SmemIteratorB, -+ ElementAccumulator, layout::RowMajor, -+ EpilogueOutputOp, -+ typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Specialization for row-major output for multi-stage -+/// Accumulator will be staged in shared memory. -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Epilogue output operator -+ typename EpilogueOutputOp> -+struct DefaultB2bMma { -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? 
cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ -+ // Define the MmaCore components -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using AccessTypeA0 = cutlass::Array; -+ using IteratorA0 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA0>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using AccessTypeB0 = cutlass::Array; -+ using IteratorB0 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB0>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using AccessTypeB1 = cutlass::Array; -+ using IteratorB1 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB1>; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::RowMajor; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ typename EpilogueOutputOp::ElementOutput, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistageSmemAccumulator< -+ typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, -+ MmaCore0::kCacheOpA, -+ IteratorB0, typename 
MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, SmemIteratorD0, -+ typename MmaCore1::Shape, WarpIteratorA1, -+ IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB, -+ ElementAccumulator, layout::RowMajor, -+ EpilogueOutputOp, -+ typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for column-major-interleaved output with 2-stage pipeline -+/// Accumulator will be staged in shared memory. -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Number of Interleaved K -+ int InterleavedK> -+struct DefaultB2bMma, OperatorClass, arch::Sm75, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, 2, Operator, EpilogueOutputOp, true, true> { -+ // Define the MmaCore components -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, -+ true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, -+ true>; -+ -+ static_assert(kAlignmentA == 128 / sizeof_bits::value, -+ "Alignment must match thread data map's vector length"); -+ -+ static_assert(kAlignmentB ==128 / sizeof_bits::value, -+ "Alignment must match thread data map's vector length"); -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA0 = cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, ElementA, -+ LayoutA, 1, typename MmaCore0::IteratorThreadMapA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB0 = cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, ElementB, -+ LayoutB, 0, typename MmaCore0::IteratorThreadMapB>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB1 = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename 
MmaCore1::MmaTensorOp; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; //For interleaved layout -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ typename EpilogueOutputOp::ElementOutput, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator< -+ typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, -+ IteratorB0, typename MmaCore0::SmemIteratorB, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, SmemIteratorD0, -+ typename MmaCore1::Shape, WarpIteratorA1, -+ IteratorB1, typename MmaCore1::SmemIteratorB, -+ ElementAccumulator, layout::ColumnMajorInterleaved, -+ EpilogueOutputOp, -+ typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Specialization for column-major-interleaved output with multi-stage -+/// Accumulator will be staged in shared memory. 
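The phrase "staged in shared memory" refers to the dataflow that the iterators declared in these specializations implement; as an informal sketch, using only type names that appear in the surrounding code:

// Dataflow of the shared-memory-staged B2bMma specializations (informal):
//   GEMM0 accumulators --(FragmentIteratorAccumulator writes via SmemIteratorD0)--> shared-memory tile
//   shared-memory tile --(WarpIteratorA1 loads)--> operand-A fragments consumed by WarpMmaTensorOp1 (GEMM1)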
-+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Number of Interleaved K -+ int InterleavedK> -+struct DefaultB2bMma, OperatorClass, ArchTag, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, Stages, Operator, EpilogueOutputOp, true, true> { -+ // Define the MmaCore components -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, Stages, -+ Operator, true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, Stages, -+ Operator, true>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA0 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB0 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using IteratorB1 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB>; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename 
WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ typename EpilogueOutputOp::ElementOutput, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true >; -+ -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistageSmemAccumulator< -+ typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, -+ MmaCore0::kCacheOpA, -+ IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, SmemIteratorD0, -+ typename MmaCore1::Shape, WarpIteratorA1, -+ IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB, -+ ElementAccumulator, layout::ColumnMajorInterleaved, -+ EpilogueOutputOp, -+ typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu b/3rdparty/cutlass/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu -new file mode 100644 -index 0000000..bc2185d ---- /dev/null -+++ b/3rdparty/cutlass/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu -@@ -0,0 +1,472 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+Please check example 07 and 08 for the basics of tensor op gemm kernels. On NVIDIA Ampere -+architecture, most concept still holds. The two main differences are -+ -+1. NVIDIA Ampere architecture introduces a new series of tensor core instructions (see -+ include/cutlass/arch/mma_sm80.h) which are more efficient on Ampere. -+ -+2. NVIDIA Ampere architecture uses cp_async() to build multistage software pipeline to better hide -+ latency (see include/cutlass/gemm/threadblock/mma_multistage.h) -+ -+Moreover, NVIDIA Ampere architecture starts supporting tfloat32 (see include/cutlass/tfloat32.h) -+data types in tensor cores. One big advantage is that we can load in fp32 data and convert them -+implicitly to tf32 inside the GEMM kernel which means no change is needed to accelerate traditional -+fp32 data by using NVIDIA Ampere architecture. -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ // -+ // Methods -+ // -+ -+ Result( -+ double runtime_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess -+ ): -+ runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ int batch_count; -+ float alpha; -+ float beta; -+ -+ bool reference_check; -+ int iterations; -+ -+ Options(): -+ help(false), -+ problem_size({5120, 4096, 4096}), -+ batch_count(1), -+ reference_check(true), -+ iterations(20), -+ alpha(1), -+ beta() { } -+ -+ bool valid() { -+ return true; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("m", problem_size.m()); -+ 
cmd.get_cmd_line_argument("n", problem_size.n()); -+ cmd.get_cmd_line_argument("k", problem_size.k()); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "14_ampere_tf32_tensorop_gemm example\n\n" -+ << " This example uses the CUTLASS Library to execute TF32 tensorop GEMM computations.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --m= GEMM M dimension\n" -+ << " --n= GEMM N dimension\n" -+ << " --k= GEMM K dimension\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --iterations= Number of profiling iterations to perform.\n\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/14_ampere_tf32_tensorop_gemm/14_ampere_tf32_tensorop_gemm --m=1024 --n=512 --k=1024 \\\n" -+ << " --alpha=2 --beta=0.707 \n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fmas = problem_size.product() * batch_count; -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// The code section below describes datatype for input, output matrices and computation between -+// elements in input matrices. -+using ElementAccumulator = float; // <- data type of accumulator -+using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations -+using ElementInputA = float; // <- data type of elements in input matrix A -+using ElementInputB = float; // <- data type of elements in input matrix B -+using ElementOutput = float; // <- data type of elements in output matrix D -+ -+// The code section below describes matrix layout of input and output matrices. Column Major for -+// Matrix A, Row Major for Matrix B and Row Major for Matrix C -+using LayoutInputA = cutlass::layout::RowMajor; -+using LayoutInputB = cutlass::layout::ColumnMajor; -+using LayoutOutput = cutlass::layout::RowMajor; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ShapeMMAThreadBlock = -+ cutlass::gemm::GemmShape<128, 128, 16>; // <- threadblock tile M = 128, N = 128, K = 16 -+// This code section describes tile size a warp will compute -+using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 16>; // <- warp tile M = 64, N = 64, K = 16 -+// This code section describes the size of MMA op -+using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? -+ -+// This code section describes the epilogue part of the kernel -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // <- data type of output matrix -+ 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized -+ // memory access. 
For a byte, it's 16 -+ // elements. This becomes the vector width of -+ // math instructions in the epilogue too -+ ElementAccumulator, // <- data type of accumulator -+ ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 4; -+ -+using Gemm = cutlass::gemm::device::Gemm; -+ -+int run(Options &options) { -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size = options.problem_size; -+ -+ // Initialize tensors using CUTLASS helper functions -+ cutlass::HostTensor tensor_a( -+ problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor tensor_b( -+ problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor tensor_c( -+ problem_size.mn()); // <- Create matrix C with dimensions M x N -+ cutlass::HostTensor tensor_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // CUTLASS kernel -+ cutlass::HostTensor tensor_ref_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // reference kernel -+ -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(4), -+ ElementInputA(-4), -+ 0); // <- Fill matrix A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(4), -+ ElementInputB(-4), -+ 0); // <- Fill matrix B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(4), -+ ElementOutput(-4), -+ 0); // <- Fill matrix C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); // <- fill matrix D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // Initialize alpha and beta for dot product computation -+ ElementComputeEpilogue alpha = ElementComputeEpilogue(options.alpha); -+ ElementComputeEpilogue beta = ElementComputeEpilogue(options.beta); -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Create a tuple of gemm kernel arguments. 
This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication -+ tensor_a.device_ref(), // <- reference to matrix A on device -+ tensor_b.device_ref(), // <- reference to matrix B on device -+ tensor_c.device_ref(), // <- reference to matrix C on device -+ tensor_d.device_ref(), // <- reference to matrix D on device -+ {alpha, beta}, // <- tuple of alpha and beta -+ split_k_slices}; // <- k-dimension split factor -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm gemm_op; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status = gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(status); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status = gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(status); -+ -+ // Result structure -+ Result result; -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMMs -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ // Launch initialized CUTLASS kernel -+ status = gemm_op(); -+ CUTLASS_CHECK(status); -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMMs are complete -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // Compute average runtime and GFLOPs. 
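As a concrete check on that arithmetic, the standalone sketch below reproduces the Options::gflops() computation for the default 5120x4096x4096 problem size; the 10 ms average runtime is purely a hypothetical placeholder, not a measured value.

#include <cstdio>
#include <cstdint>

int main() {
  // Two flops per multiply-add, as in Options::gflops() above.
  int64_t m = 5120, n = 4096, k = 4096;    // default problem size from Options
  double runtime_ms = 10.0;                // hypothetical average runtime per GEMM
  double flops = 2.0 * double(m) * double(n) * double(k);   // ~1.72e11 flops
  double gflops = flops / 1.0e9 / (runtime_ms / 1000.0);
  std::printf("%.0f GFLOP/s\n", gflops);   // prints 17180 for the 10 ms assumption
  return 0;
}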
-+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ // Create instantiation for device reference gemm kernel -+ cutlass::reference::device::Gemm -+ gemm_device; -+ -+ // Launch device reference gemm kernel -+ gemm_device(problem_size, -+ alpha, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ beta, -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref()); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ if (passed) { -+ std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " GFLOPs: " << result.gflops << std::endl; -+ } -+ -+ std::cout << (passed ? "Passed" : "Failed") << std::endl; -+ -+ return (passed ? 0 : -1); -+} -+ -+int main(int argc, const char **argv) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available -+ // in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ >= 11)) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (!((props.major * 10 + props.minor) >= 80)) { -+ std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. -+ return 0; -+ } -+ -+ Options options; -+ options.parse(argc, argv); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ printf("%d x %d x %d TF32 tensor op Matrix Multiply\n", \ -+ options.problem_size.m(), options.problem_size.n(), options.problem_size.k()); -+ -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ return run(options); -+} -diff --git a/3rdparty/cutlass/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu b/3rdparty/cutlass/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu -new file mode 100644 -index 0000000..dc87fff ---- /dev/null -+++ b/3rdparty/cutlass/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu -@@ -0,0 +1,317 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. 
-+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+Please check example 07, 08 and 17 for the basics of dense tensor op gemm kernels. NVIDIA Ampere -+architecture also supports structured sparse tensor op for tf32, fp16, int8 and int4. -+ -+Sparse GEMM kernels needs to takes an additional E matrix which stores the meta data. The format of -+meta data is different for every data types. CUTLASS templates can automatically infer it based on -+input A and B. Check code below. -+ -+Moreover, matrix E needs to be preprocessed so that it can use ldmatrix to load into the registers -+efficiently. -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/host_reorder.h" -+#include "cutlass/util/host_uncompress.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "helper.h" -+ -+// The code section below describes datatype for input, output matrices and computation between -+// elements in input matrices. -+using ElementAccumulator = int32_t; // <- data type of accumulator -+using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations -+using ElementInputA = cutlass::int4b_t; // <- data type of elements in input matrix A -+using ElementInputB = cutlass::int4b_t; // <- data type of elements in input matrix B -+using ElementOutput = int32_t; // <- data type of elements in output matrix D -+ -+// The code section below describes matrix layout of input and output matrices. 
Row Major for -+// Matrix A, Column Major for Matrix B and Row Major for Matrix C -+using LayoutInputA = cutlass::layout::RowMajor; -+using LayoutInputB = cutlass::layout::ColumnMajor; -+using LayoutOutput = cutlass::layout::RowMajor; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ShapeMMAThreadBlock = -+ cutlass::gemm::GemmShape<128, 128, 256>; // <- threadblock tile M = 128, N = 128, K = 256 -+// This code section describes tile size a warp will compute -+using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 256>; // <- warp tile M = 64, N = 64, K = 256 -+// This code section describes the size of MMA op -+using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 128>; // <- MMA Op tile M = 16, N = 8, K = 128 -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? -+ -+// This code section describes the epilogue part of the kernel -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // <- data type of output matrix -+ 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized -+ // memory access. For a byte, it's 16 -+ // elements. This becomes the vector width of -+ // math instructions in the epilogue too -+ ElementAccumulator, // <- data type of accumulator -+ ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 3; -+ -+using Gemm = cutlass::gemm::device::SparseGemm; -+ -+// Data type and layout of meta data matrix E can be inferred from template Gemm. 
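To make those inferred quantities concrete, here is a small illustrative helper (not part of the example) that prints the compressed extents implied by 50% structured sparsity; in the real code the constants come from the Gemm template, and run() below uses M = 512, K = 1024.

#include <cstdio>

// Illustrative only: extents implied by structured sparsity. In the example these
// constants are Gemm::kSparse and Gemm::kElementsPerElementE, not literals.
void print_sparse_extents(int m, int k, int sparse, int elements_per_element_e) {
  std::printf("A (compressed values): %d x %d\n", m, k / sparse);
  std::printf("E (packed metadata)  : %d x %d\n", m, k / sparse / elements_per_element_e);
}

For example, with M = 512, K = 1024 and kSparse == 2 ("50% Sparsity on Ampere" below), each row of the compressed A holds only 1024 / 2 = 512 values; the positions of those values are what matrix E encodes.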
-+using ElementInputE = typename Gemm::ElementE; -+using LayoutInputE = cutlass::layout::RowMajor; -+using ReorderedLayoutInputE = typename Gemm::LayoutE; -+ -+// Blow property is defined in include/cutlass/arch/sp_mma_sm80.h -+// 50% Sparsity on Ampere -+constexpr int kSparse = Gemm::kSparse; -+// How many elements of A are covered per ElementE -+constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; -+// The size of individual meta data -+constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; -+ -+int run() { -+ -+ const int length_m = 512; -+ const int length_n = 512; -+ const int length_k = 1024; -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k); -+ -+ // Initialize tensors using CUTLASS helper functions -+ cutlass::HostTensor tensor_a( -+ cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); // <- Create matrix A with dimensions M x (K / 2) -+ cutlass::HostTensor tensor_a_uncompressed( -+ problem_size.mk()); // <- Create uncompressed matrix A with dimensions M x K for reference computing -+ -+ cutlass::HostTensor tensor_b( -+ problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor tensor_c( -+ problem_size.mn()); // <- Create matrix C with dimensions M x N -+ cutlass::HostTensor tensor_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // CUTLASS kernel -+ cutlass::HostTensor tensor_ref_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // reference kernel -+ -+ // Create matrix E with dimensions M x (K / 2 / kElementsPerElementE). This one is used by reference computing. -+ cutlass::HostTensor tensor_e( -+ cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); -+ // Same size as the above. The above one needs to be reordered and stored in this one. -+ cutlass::HostTensor tensor_e_reordered( -+ cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); -+ -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(2), -+ ElementInputA(-2), -+ 0); // <- Fill matrix A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(2), -+ ElementInputB(-2), -+ 0); // <- Fill matrix B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(2), -+ ElementOutput(-2), -+ 0); // <- Fill matrix C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomSparseMeta( -+ tensor_e.host_view(), -+ 1, -+ kMetaSizeInBits); // <- Fill matrix E on host with uniform-distribution random meta data -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); // <- fill matrix D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros -+ -+ // Reorder the meta data matrix so that we can use ldmatrix to load them to tensor core -+ // instructions. 
-+ cutlass::reorder_meta(tensor_e_reordered.host_ref(), tensor_e.host_ref(), -+ {problem_size.m(), problem_size.n(), -+ problem_size.k() / kSparse / kElementsPerElementE}); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_e_reordered.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // Initialize alpha and beta for dot product computation -+ ElementComputeEpilogue alpha = ElementComputeEpilogue(1); -+ ElementComputeEpilogue beta = ElementComputeEpilogue(0); -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication -+ tensor_a.device_ref(), // <- reference to matrix A on device -+ tensor_b.device_ref(), // <- reference to matrix B on device -+ tensor_c.device_ref(), // <- reference to matrix C on device -+ tensor_d.device_ref(), // <- reference to matrix D on device -+ tensor_e_reordered.device_ref(), // <- reference to matrix E on device -+ {alpha, beta}, // <- tuple of alpha and beta -+ split_k_slices}; // <- k-dimension split factor -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm gemm_op; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status = gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(status); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status = gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(status); -+ -+ // Launch initialized CUTLASS kernel -+ status = gemm_op(); -+ CUTLASS_CHECK(status); -+ -+ // uncompress tensor_a based on meta data tensor_e. We need it for reference computing. -+ cutlass::uncompress(tensor_a_uncompressed.host_ref(), tensor_a.host_ref(), -+ tensor_e.host_ref(), problem_size.m(), problem_size.k()); -+ -+ // Create instantiation for host reference gemm kernel -+ cutlass::reference::host::Gemm -+ gemm_host; -+ -+ // Launch host reference gemm kernel -+ gemm_host(problem_size, -+ alpha, -+ tensor_a_uncompressed.host_ref(), -+ tensor_b.host_ref(), -+ beta, -+ tensor_c.host_ref(), -+ tensor_ref_d.host_ref()); -+ -+ // Copy output data from CUTLASS host for comparison -+ tensor_d.sync_host(); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ std::cout << (passed ? "Passed" : "Failed") << std::endl; -+ -+ return (passed ? 0 : -1); -+} -+ -+int main() { -+ -+ bool notSupported = false; -+ -+ // Ampere Sparse Tensor Core operations exposed with mma.sync and ldmatrix are first available -+ // in CUDA 11.1. -+ // -+ // CUTLASS must be compiled with CUDA 11.1 Toolkit to run these examples. -+ -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.1 Toolkit or later." 
<< std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (props.major * 10 + props.minor < 80) { -+ std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. -+ return 0; -+ } -+ -+ return run(); -+} -diff --git a/3rdparty/cutlass/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu b/3rdparty/cutlass/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu -new file mode 100644 -index 0000000..378b489 ---- /dev/null -+++ b/3rdparty/cutlass/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu -@@ -0,0 +1,772 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+ -+This example shows how to run convolution kernels using functions and data structures -+provided by CUTLASS using tensor cores; which we run on a NVIDIA Ampere GPU. -+ -+Writing a single high performance convolution kernel is hard but do-able. Whereas writing -+high performance kernels at scale which works for multiple problem sizes with good abstractions is -+really hard. CUTLASS solves this problem by providing simplified abstractions to compose -+multiple sections of implicit gemm kernel. When used properly, the kernels can hit peak performance -+of GPU easily. -+ -+CUTLASS divides a kernel into hierarchical composable sections. 
Which means, at each thread, warp -+and thread-block level, they compute on their own tile-size with higher level of tile sizes being -+composed from lower level ones. Multiple thread-tiles (tile size each thread computes) can be used -+to form warp-tiles (tile size each warp computes) and multiple warp tiles can be used to compute -+threadblock-tile (tile size computed by a threadblock). -+ -+In thie example, we split variable initialization into -+1. Setting up data properties : describes how tensors are laid out in the memory and how the kernel -+can view them (logical to physical mapping) -+2. Setting up computation properties : describes how the above set tensors will be used to compute -+output of convolution. -+ -+First, we setup the data types of the input tensor A, weights' tensor B and output tensor C along -+with alpha, beta as the equation for convolution is C = alpha * Conv2dFprop(A, B) + beta * C. In CUTLASS, -+the kernels first compute Conv2dFprop(A, B) and leave the rest of the computation to end of the kernel as -+alpha * X + beta * C is a simple element-wise operation on X (Conv2dFprop(A, B)) and C. We call this as -+epilogue of kernel. Hence, we setup data types for alpha and beta to be equal to -+ElementComputeEpilogue = float. We use the data type for elements in input tensor A and B as -+cutlass::half_t. We convey this to CUTLASS kernel by initializing template variables ElementAccumulator (float), -+ElementComputeEpilogue (float), ElementInputA (cutlass::half_t), ElementInputB (cutlass::half_t), -+ElementOutput (float). Communicating just the data type is not enough. As the data is laid out -+linearly in memory, we have to convey the layout of tensors. We do that by initializing template -+variables LayoutInputA, LayoutInputB and LayoutOutput to TensorNHWC cutlass variable. Next, we setup -+rules to comptue alpha * X + beta * C which is called epilogue of the kernel. We initialize template -+variable EpilogueOp, which takes the data type of output ElementOutput (float), the number of -+elements per vector memory access (8), data type of accumulator (float) and data type of -+computation of linear combination (alpha * X + beta * C). -+ -+Now that we setup the properties of data, we have to setup properties of computation. -+ -+Second, we create template variables of tile sizes for thread-block, warp and mma-op to 128x128x64, -+64x64x64, 16x8x16 (MxNxK) respectively. When passed to instantiate CUTLASS Implicit GEMM kernel, it -+internally deduces the amount of threads needed per thread-block, amount of shared memory, storing -+data in bank-conflict free manner, and ton of other variables required to compose, intialize and -+launch a high performance Implicit GEMM kernel. This is the beauty of CUTLASS, it relieves developer -+from understanding and coding complicated hardware optimizations which can easily go wrong. -+ -+CUTLASS also supports multiple MMA pipelines in a threadblock. What are MMA pipelines? MMA pipelines -+constitute the whole process of loading input data from global memory to shared memory, loading data -+from shared memory to registers, doing matrix multiplication, store to global memory. The below flow -+sequence shows a typical mma multistage pipeline. 
-+(see include/cutlass/conv/threadblock/implicit_gemm_multistage.h) -+ -+tensor in global memory --cp_async--> tile in shared memory --smem loads--> registers -+--mma--> registers --global stores--> output to global memory -+ -+NVIDIA Ampere uses `cp_async` to build multistage software pipeline to better hide latencies. -+ -+ -+There are few more template variables initialized such as, which threadblock tile of output matrix -+is done which threadblock launched on an SM, CUDA SM architecture of GPU you want to run on. -+ -+These are all put together to create a template variable which describes CUTLASS Implicit GEMM -+kernel using cutlass::conv::device::ImplicitGemm template. -+ -+The next step is to intialize physical data, instantiate and initialize CUTLASS kernel and run it. -+We use CUTLASS utilities to initialize, fill, compare tensors as they are simple and doesn't come -+in the way of learning CUTLASS. -+ -+Once all the tensors are initialized and filled with data, create arguments tuple to launch CUTLASS -+kernel which takes problem size (N = 1, H = 64, W = 64, C = 128), filter size (K = 64, -+R = 3, S = 3, C = 128 ), padding, strides, dilation, tensors, alpha, beta and the -+important one, split k-dimension factor. Along with that, we query CUTLASS if any scratch-space -+memory required by the kernel we instantiated. If yes, we create it and pass it along with other -+arguments created to intialize CUTLASS kernel then, the kernel is launched. -+ -+In this example, we later on launch a reference convolution kernel (from CUTLASS utilities) to -+compare if the output from CUTLASS kernel is same as the reference implicit GEMM kernel. -+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using ElementAccumulator = float; // Data type of accumulator -+using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) -+using ElementInputA = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputB = cutlass::half_t; // Data type of elements in input tensor -+using ElementOutput = float; // Data type of elements in output tensor -+ -+using LayoutInputA = cutlass::layout::TensorNHWC; -+using LayoutInputB = cutlass::layout::TensorNHWC; -+using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = 
cutlass::gemm::GemmShape<64, 64, 64>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 3; -+ -+// This code section describe iterator algorithm selected is Analytic or Optimized -+static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; // Data type for alpha/beta in linear combination -+ -+using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm -+>::Kernel; -+ -+using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::Tensor4DCoord input_size; -+ cutlass::Tensor4DCoord filter_size; -+ cutlass::Tensor4DCoord padding; -+ cutlass::MatrixCoord conv_stride; -+ cutlass::MatrixCoord dilation; -+ bool reference_check; -+ bool measure_performance; -+ int iterations; -+ bool save_workspace; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ bool benchmark; -+ std::string tag; -+ -+ Options(): -+ help(false), -+ input_size(1, 32, 32, 32), -+ filter_size(32, 3, 3, 32), -+ padding(1, 1, 1, 1), -+ conv_stride(1, 1), -+ dilation(1, 1), -+ reference_check(false), -+ measure_performance(true), -+ iterations(20), -+ save_workspace(false), -+ alpha(1), -+ beta(0), -+ benchmark(false) { } -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. 
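    // In other words: (128 bits per vector access) / (16 bits per cutlass::half_t) = 8
    // elements, which is where the kAlignment value just below comes from.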
-+ // -+ int const kAlignment = 8; -+ -+ if ((input_size.c() % kAlignment) || -+ (filter_size.n() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ // Invalid padding -+ if ((padding.h() != filter_size.h() / 2) || -+ (padding.w() != filter_size.w() / 2)) { -+ -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Updates input and filter sizes -+ void update( -+ cutlass::Tensor4DCoord input_size, -+ cutlass::Tensor4DCoord filter_size) { -+ -+ this->input_size = input_size; -+ this->filter_size = filter_size; -+ -+ padding.n() = filter_size.h() / 2; -+ padding.h() = filter_size.h() / 2; -+ padding.w() = filter_size.w() / 2; -+ padding.c() = filter_size.w() / 2; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("ref-check")) { -+ reference_check = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("perf-check")) { -+ measure_performance = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("save-workspace")) { -+ save_workspace = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("benchmark")) { -+ benchmark = true; -+ } -+ -+ cmd.get_cmd_line_argument("n", input_size.n()); -+ cmd.get_cmd_line_argument("h", input_size.h()); -+ cmd.get_cmd_line_argument("w", input_size.w()); -+ cmd.get_cmd_line_argument("c", input_size.c()); -+ -+ cmd.get_cmd_line_argument("k", filter_size.n()); -+ cmd.get_cmd_line_argument("r", filter_size.h()); -+ cmd.get_cmd_line_argument("s", filter_size.w()); -+ filter_size.c() = input_size.c(); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ -+ if (filter_size.h() == 3 && filter_size.w() == 3) { -+ padding = {1, 1, 1, 1}; -+ } -+ else { -+ filter_size.h() = 1; -+ filter_size.w() = 1; -+ padding = {0, 0, 0, 0}; -+ } -+ } -+ -+ /// Prints the usage statement. 
-+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "16_ampere_tensorop_conv2dfprop example\n\n" -+ << " This example uses Ampere's Tensor Core operators on F16 data types to compute\n" -+ << " forward convolution on tensors of layout NHWC.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --n= Input tensor extent N\n" -+ << " --h= Input tensor extent H\n" -+ << " --w= Input tensor extent W\n" -+ << " --c= Input tensor extent C\n" -+ << " --k= Filter extent K\n" -+ << " --r= Filter extent R\n" -+ << " --s= Filter extent S\n\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --ref-check If set (true), reference check on the host is computed\n" -+ << " --perf-check If set (true), performance is measured.\n" -+ << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --save-workspace If set, workspace is written to a text file.\n" -+ << " --tag= String to replicate across the first column in the results table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/16_ampere_tensorop_conv2dfprop/16_ampere_tensorop_conv2dfprop --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n" -+ << "$ ./examples/16_ampere_tensorop_conv2dfprop/16_ampere_tensorop_conv2dfprop --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n"; -+ -+ return out; -+ } -+ -+ /// Computes the output tensor size (NPQK) -+ cutlass::Tensor4DCoord output_size() const { -+ return cutlass::Tensor4DCoord( -+ input_size.n(), -+ (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, -+ (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, -+ filter_size.n()); -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of multiply-adds = NPQK * CRS -+ int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cutlass::Status reference_check; -+ cudaError_t error; -+ -+ Result(): -+ runtime_ms(0), -+ gflops(0), -+ status(cutlass::Status::kSuccess), -+ reference_check(cutlass::Status::kInvalid), -+ error(cudaSuccess) { } -+ -+ static std::ostream & print_header(std::ostream &out, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "Layer,N,H,W,C,K,R,S,Runtime,GFLOPs"; -+ -+ return out; -+ } -+ -+ std::ostream & print(std::ostream &out, int idx, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ out -+ << "conv_" << idx << "," -+ << options.input_size.n() << "," -+ << options.input_size.h() << "," -+ << options.input_size.w() << "," -+ << options.input_size.c() << "," -+ << options.filter_size.n() << "," -+ << options.filter_size.h() << "," -+ << options.filter_size.w() << "," -+ << runtime_ms << "," -+ << gflops; -+ -+ return out; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one benchmark -+Result profile_convolution(Options const &options) { -+ -+ Result result; -+ -+ // -+ // 
Allocate host-device tensors using the CUTLASS Utilities. -+ // -+ -+ cutlass::HostTensor tensor_a(options.input_size); -+ cutlass::HostTensor tensor_b(options.filter_size); -+ cutlass::HostTensor tensor_c(options.output_size()); -+ cutlass::HostTensor tensor_d(options.output_size()); -+ cutlass::HostTensor tensor_ref_d(options.output_size()); -+ -+ // -+ // Initialize tensors -+ // -+ -+ // Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(7), -+ ElementInputA(-8), -+ 0); -+ -+ // Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(7), -+ ElementInputB(-8), -+ 0); -+ -+ // Fill tensor C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(7), -+ ElementOutput(-8), -+ 0); -+ -+ // Fill tensor D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); -+ -+ // Fill tensor D for reference on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // -+ // Define arguments for CUTLASS Convolution -+ // -+ -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Construct Conv2dProblemSize with user defined output size -+ cutlass::conv::Conv2dProblemSize problem_size( -+ options.input_size, -+ options.filter_size, -+ options.padding, -+ options.conv_stride, -+ options.dilation, -+ options.output_size(), -+ mode, -+ split_k_slices -+ ); -+ -+ // Construct ImplicitGemm::Argument structure with conv2d -+ // problem size, data pointers, and epilogue values -+ typename ImplicitGemm::Arguments arguments{ -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_d.device_ref(), -+ {options.alpha, options.beta}, -+ }; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ ImplicitGemm implicit_gemm_op; -+ -+ size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ result.status = implicit_gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = implicit_gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Launch initialized CUTLASS kernel -+ // -+ result.status = implicit_gemm_op(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Optional reference check -+ // -+ -+ if (options.reference_check) { -+ std::cout << "Verification on host...\n"; -+ -+ // Compute with reference implementation -+ cutlass::reference::host::Conv2dFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator, -+ cutlass::NumericConverter -+ >( -+ problem_size, -+ tensor_a.host_ref(), -+ tensor_b.host_ref(), -+ tensor_c.host_ref(), -+ tensor_ref_d.host_ref(), -+ options.alpha, -+ options.beta -+ ); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ tensor_d.sync_host(); -+ -+ bool passed = 
cutlass::reference::host::TensorEquals( -+ tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ if (!passed) { -+ result.reference_check = cutlass::Status::kErrorInternal; -+ std::cout << "ERROR - results miscompared.\n"; -+ } -+ else { -+ result.reference_check = cutlass::Status::kSuccess; -+ std::cout << "Passed.\n"; -+ } -+ } -+ else { -+ result.reference_check = cutlass::Status::kInvalid; -+ } -+ -+ if (options.save_workspace) { -+ -+ std::stringstream ss; -+ -+ ss << "16_ampere_workspace_conv2dfprop_" -+ << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() -+ << "_" -+ << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() -+ << ".dat"; -+ -+ std::ofstream output_workspace(ss.str()); -+ -+ output_workspace -+ << "Input = \n" << tensor_a.host_view() << "\n\n" -+ << "Filters = \n" << tensor_b.host_view() << "\n\n"; -+ -+ if (options.reference_check) { -+ output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n"; -+ } -+ -+ output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl; -+ -+ std::cout << "Results written to '" << ss.str() << "'." << std::endl; -+ } -+ -+ // -+ // Performance measurement -+ // -+ -+ if (options.measure_performance) { -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = implicit_gemm_op(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. 
-+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major >= 8)) { -+ std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.benchmark) { -+ // Benchmark several layers -+ -+ int batch_sizes[] = {1, 32, 64, 128, 256, 512}; -+ -+ struct Benchmark { -+ int h, w, c, k, r, s; -+ } layers[] = { -+ {56, 56, 64, 256, 1, 1}, -+ {56, 56, 64, 64, 1, 1}, -+ {56, 56, 64, 64, 3, 3}, -+ {56, 56, 256, 64, 1, 1}, -+ {56, 56, 256, 512, 1, 1}, -+ {56, 56, 256, 128, 1, 1}, -+ {28, 28, 128, 128, 3, 3}, -+ {28, 28, 128, 512, 1, 1}, -+ {28, 28, 512, 128, 1, 1}, -+ {28, 28, 512, 1024, 1, 1}, -+ {28, 28, 512, 256, 1, 1}, -+ {14, 14, 256, 256, 3, 3}, -+ {14, 14, 256, 1024, 1, 1}, -+ {14, 14, 1024, 256, 1, 1}, -+ {14, 14, 1024, 2048, 1, 1}, -+ {14, 14, 1024, 512, 1, 1}, -+ {7, 7, 512, 512, 3, 3}, -+ }; -+ -+ Result::print_header(std::cout, options) << std::endl; -+ -+ int idx = 1; -+ -+ for (auto const &layer : layers) { -+ for (auto N : batch_sizes) { -+ -+ options.update({N, layer.h, layer.w, layer.c}, {layer.k, layer.r, layer.s, layer.c}); -+ -+ Result result = profile_convolution(options); -+ result.print(std::cout, idx, options) << std::endl; -+ } -+ -+ ++idx; -+ } -+ } -+ else { -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ Result result = profile_convolution(options); -+ -+ Result::print_header(std::cout, options) << std::endl; -+ result.print(std::cout, 1, options) << std::endl; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/17_fprop_per_channel_bias/fprop_per_channel_bias.cu b/3rdparty/cutlass/examples/17_fprop_per_channel_bias/fprop_per_channel_bias.cu -new file mode 100644 -index 0000000..a334511 ---- /dev/null -+++ b/3rdparty/cutlass/examples/17_fprop_per_channel_bias/fprop_per_channel_bias.cu -@@ -0,0 +1,306 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+The convolution version of 12_gemm_bias_relu. Similarly, we put bias vector in Operand C and the -+rest is the same as normal convolution. -+*/ -+ -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/host_reorder.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using ElementAccumulator = float; // Data type of accumulator -+using ElementComputeEpilogue = ElementAccumulator; // Data type of epilogue computation -+using ElementInputA = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputB = cutlass::half_t; // Data type of elements in input tensor -+using ElementOutput = float; // Data type of elements in output tensor -+ -+using LayoutInputA = cutlass::layout::TensorNHWC; -+using LayoutInputB = cutlass::layout::TensorNHWC; -+using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ 
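A minimal host-side sketch of the tile hierarchy declared above (128x128x32 threadblock tile, 64x64x32 warp tile, 16x8x16 Tensor Core instruction). The 32-threads-per-warp figure is the standard NVIDIA warp size; the other counts follow directly from those shapes, and the snippet is illustrative rather than part of the example itself.

#include <cstdio>

int main() {
  const int tb_m = 128, tb_n = 128;              // ThreadblockShape M, N
  const int wp_m = 64, wp_n = 64, wp_k = 32;     // WarpShape
  const int in_m = 16, in_n = 8, in_k = 16;      // InstructionShape (mma.sync 16x8x16)

  const int warps_per_block = (tb_m / wp_m) * (tb_n / wp_n);                 // 2 * 2 = 4
  const int mmas_per_warp   = (wp_m / in_m) * (wp_n / in_n) * (wp_k / in_k); // 4 * 8 * 2 = 64

  std::printf("%d warps (%d threads) per threadblock, %d mma ops per warp K-slice\n",
              warps_per_block, warps_per_block * 32, mmas_per_warp);
  return 0;
}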
-+// Number of pipelines you want to use -+constexpr int NumStages = 4; -+ -+// This code section describe iterator algorithm selected is Analytic or Optimized -+static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue, // Data type for alpha in linear combination -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling>; // alpha X C + per channel bias -+ -+ -+using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm -+>::Kernel; -+ -+using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int run() { -+ -+ // Construct Conv2dProblemSize with user defined output size -+ cutlass::conv::Conv2dProblemSize problem_size( -+ {1, 7, 7, 512}, // activation -+ {512, 3, 3, 512}, // filter -+ {1, 1, 1, 1}, // padding -+ {1, 1}, // striding -+ {1, 1}, // dilation -+ cutlass::conv::Mode::kCrossCorrelation, // mode (convolution or cross-correlation) -+ 1 // split-k slices -+ ); -+ -+ // Initialize tensors using CUTLASS helper functions -+ cutlass::HostTensor tensor_a(problem_size.activation_extent()); -+ cutlass::HostTensor tensor_b(problem_size.filter_extent()); -+ -+ // Create tensor C with dimensions 1x1x1xk which is the bias vector -+ cutlass::HostTensor tensor_c_bias({1, 1, 1, problem_size.K}); -+ -+ // Create tensor D used to store output from CUTLASS kernel -+ cutlass::HostTensor tensor_d(problem_size.output_extent()); -+ // Create matrix D with dimensions M x N used to store output from reference -+ // kernel -+ cutlass::HostTensor tensor_ref_d(problem_size.output_extent()); -+ -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(4), -+ ElementInputA(-4), -+ 0); // <- Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(4), -+ ElementInputB(-4), -+ 0); // <- Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c_bias.host_view(), -+ 1, -+ ElementOutput(4), -+ ElementOutput(-4), -+ 0); // <- Fill matrix C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); // <- fill matrix D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c_bias.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); 
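The Arguments built in the next step pass the 1x1x1xK bias tensor as operand C with a zero stride. A minimal sketch of why that works, assuming the conventional NHWC linear offset n*Sn + h*Sh + w*Sw + c: zeroing the batch and spatial strides makes every output coordinate read the same K-length vector, and the LinearCombinationRelu epilogue then applies max(0, alpha * acc + bias[k]) per element. The helper name below is illustrative only.

#include <cassert>
#include <cstdint>

// Generic NHWC linear offset with explicit strides.
inline int64_t nhwc_offset(int n, int h, int w, int c,
                           int64_t sn, int64_t sh, int64_t sw) {
  return n * sn + h * sh + w * sw + c;
}

int main() {
  // With all batch/spatial strides zeroed, every (n, h, w) addresses the same
  // per-channel value, so a 1x1x1xK tensor acts as a broadcast bias vector.
  assert(nhwc_offset(3, 5, 7, 42, 0, 0, 0) == nhwc_offset(0, 0, 0, 42, 0, 0, 0));
  return 0;
}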
-+ -+ // Initialize alpha for dot product computation -+ ElementComputeEpilogue alpha = ElementComputeEpilogue(1); -+ -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename ImplicitGemm::Arguments arguments{ -+ problem_size, -+ tensor_a.device_ref(), // <- reference to tensor A on device -+ tensor_b.device_ref(), // <- reference to tensor B on device -+ // tensor C is treated as the bias vector. We can enable the CONV -+ // to project away the N, H, W dimension by setting the stride to zero. -+ {tensor_c_bias.device_data(), LayoutOutput::Stride(0)}, -+ tensor_d.device_ref(), // <- reference to tensor D on device -+ {alpha} }; -+ -+ // Instantiate CUTLASS kernel depending on templates -+ ImplicitGemm implicit_gemm_op; -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Check the problem size is supported or not -+ cutlass::Status status = implicit_gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(status); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status = implicit_gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(status); -+ -+ // Launch initialized CUTLASS kernel -+ status = implicit_gemm_op(); -+ -+ CUTLASS_CHECK(status); -+ -+ // -+ // Create instantiation for device reference conv kernel -+ // -+ -+ // Launch device reference to compute strictly the product A * B -+ cutlass::reference::device::Conv2d< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator, -+ cutlass::NumericConverter> -+ ( -+ cutlass::conv::Operator::kFprop, -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c_bias.device_ref(), -+ tensor_ref_d.device_ref(), -+ alpha, ElementComputeEpilogue(0) -+ ); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ // Compute bias + relu in host code -+ for (int n = 0; n < problem_size.N; ++n) { -+ for (int p = 0; p < problem_size.P; ++p) { -+ for (int q = 0; q < problem_size.Q; ++q) { -+ for (int k = 0; k < problem_size.K; ++k) { -+ -+ tensor_ref_d.at({n, p, q, k}) = -+ std::max(ElementOutput(0), -+ ElementOutput(tensor_ref_d.at({n, p, q, k}) + -+ tensor_c_bias.at({0, 0, 0, k}))); -+ } -+ } -+ } -+ } -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ std::cout << (cutlass::reference::host::TensorEquals(tensor_d.host_view(), -+ tensor_ref_d.host_view()) -+ ? "Passed" -+ : "Failed") -+ << std::endl; -+ -+ CUTLASS_CHECK(status); -+ return 0; -+} -+ -+int main(int argc, char const **args) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." 
<< std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major >= 8)) { -+ std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ return run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu b/3rdparty/cutlass/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu -new file mode 100644 -index 0000000..d1044a2 ---- /dev/null -+++ b/3rdparty/cutlass/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu -@@ -0,0 +1,342 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+In the normal GEMM, the fast changing dimension of a matrix always has stride -+equals to 1, e.g. ColumnMajor and RowMajor matrix. Affine2 matrix can have -+larger than 1 stride in both dimensions. To support such layout, we need to -+change to method to visit the global memory: -+ -+ 1. We can only visit 1 element a time because elements are not stored -+ consecutively anymore. Vectorized load/store is not possible. -+ 2. One extra multiplication is needed in calculating the global memory -+ address -+ addr = base_pointer + coord1 * stride1 + coord2 * stride2 -+ -+The rest part of GEMM which includes shared memory load/store, mma comutation -+is the same. -+ -+This example uses Ampere fp64 tensore core Affine2 GEMM as an example. SIMT -+(e.g. 
sgemm, dgemm) has support Affine2 layout. -+*/ -+ -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_universal.h" -+#include "cutlass/reduction/device/reduce_split_k.h" -+#include "cutlass/reduction/kernel/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+#include "cutlass/matrix_coord.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using ElementAccumulator = double; // Data type of accumulator -+using ElementComputeEpilogue = ElementAccumulator; // Data type of epilogue computation -+using ElementInputA = double; // Data type of elements in input tensor -+using ElementInputB = double; // Data type of elements in input tensor -+using ElementOutput = double; // Data type of elements in output tensor -+ -+// Since Affine2 explicitly lists the strides of both dimensions, it does not really matter if -+// it is columnmajor and rowmajor. However, it helps CUTLASS to improve the load locality if -+// CUTLASS can know which dimension of A/B operand has smaller stride or more dense. -+// -+// Affine2 ColumnMajor means the row stride is smaller and Affine2 RowMajor means the column -+// stride is smaller. -+// -+// The Affine2 epilogue reuses AffineN epilogue so it does not need to specify column majore -+// or row major. -+using LayoutInputA = cutlass::layout::AffineRank2ColumnMajor; -+using LayoutInputB = cutlass::layout::AffineRank2RowMajor; -+using LayoutOutput = cutlass::layout::AffineRankN<2>; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 3; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ 1, // The number of elements per memory -+ // access has. It has to be 1 for -+ // affine2. 
-+ ElementAccumulator, -+ ElementComputeEpilogue>; -+ -+using Gemm = typename cutlass::gemm::device::GemmUniversal< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int run() { -+ -+ // Construct Gemm ProblemSize with user defined output size -+ cutlass::gemm::GemmCoord problem_size = {1024, 512, 1024}; -+ -+ // Stride factor shows the distance between two elements in the differnet dimensions. The -+ // first data is the logical distance between two rows, the second is between two columns. -+ // CUTLASS has a utility tool cutlass::layout::Affine2Layout_Factory::layout_factory -+ // to help to convert stride_factor to the two strides. -+ // -+ // It is also totally fine to compute the strides directly without using the utility to -+ // construct the affine2 layout. -+ typename LayoutInputA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutInputB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutOutput::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ // Initialize tensors using CUTLASS helper functions -+ cutlass::HostTensor tensor_a(problem_size.mk(), -+ cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mk(), -+ stride_factor_A)); -+ cutlass::HostTensor tensor_b(problem_size.kn(), -+ cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.kn(), -+ stride_factor_B)); -+ -+ // Create matrix C used to load for bias addition. -+ cutlass::HostTensor tensor_c(problem_size.mn(), -+ cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), -+ stride_factor_C)); -+ -+ // Create matrix D used to store output from CUTLASS kernel -+ cutlass::HostTensor tensor_d(problem_size.mn(), -+ cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), -+ stride_factor_C)); -+ -+ // Create matrix D with dimensions M x N used to store output from reference -+ // kernel -+ cutlass::HostTensor tensor_ref_d(problem_size.mn(), -+ cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), -+ stride_factor_C)); -+ -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(4), -+ ElementInputA(-4), -+ 0); // <- Fill matrix A on host with uniform-distribution random data -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(4), -+ ElementInputB(-4), -+ 0); // <- Fill matrix B on host with uniform-distribution random data -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(4), -+ ElementOutput(-4), -+ 0); // <- Fill matrix C on host with uniform-distribution random data -+ -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); // <- fill matrix D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // Initialize alpha for dot product computation -+ ElementComputeEpilogue alpha = 
ElementComputeEpilogue(1); -+ ElementComputeEpilogue beta = ElementComputeEpilogue(1); -+ -+ cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm; -+ -+ int batch_count = 1; -+ -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm::Arguments arguments{ -+ mode, -+ problem_size, -+ batch_count, -+ {alpha, beta}, -+ tensor_a.device_ref().data(), // <- reference to matrix A on device -+ tensor_b.device_ref().data(), // <- reference to matrix B on device -+ tensor_c.device_ref().data(), // <- reference to matrix C on device -+ tensor_d.device_ref().data(), // <- reference to matrix D on device -+ tensor_a.layout().capacity(problem_size.mk()), -+ tensor_b.layout().capacity(problem_size.kn()), -+ tensor_c.layout().capacity(problem_size.mn()), -+ tensor_d.layout().capacity(problem_size.mn()), -+ tensor_a.layout().stride(), -+ tensor_b.layout().stride(), -+ tensor_c.layout().stride(), -+ tensor_d.layout().stride() -+ }; -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm gemm_op; -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Check the problem size is supported or not -+ cutlass::Status status = gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(status); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status = gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(status); -+ -+ // Launch initialized CUTLASS kernel -+ status = gemm_op(); -+ -+ CUTLASS_CHECK(status); -+ -+ // -+ // Create instantiation for device reference gemm kernel -+ // -+ -+ // Launch device reference to compute strictly the product A * B -+ cutlass::reference::device::Gemm< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator> gemm_device; -+ -+ gemm_device -+ ( -+ problem_size, -+ alpha, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ beta, -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref() -+ ); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ bool pass = cutlass::reference::host::TensorEquals(tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ std::cout << (pass -+ ? "Passed" -+ : "Failed") -+ << std::endl; -+ -+ CUTLASS_CHECK(status); -+ -+ return (pass ? 0 : -1); -+} -+ -+int main(int argc, char const **args) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major >= 8)) { -+ std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." 
-+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ return run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/19_tensorop_canonical/tensorop_canonical.cu b/3rdparty/cutlass/examples/19_tensorop_canonical/tensorop_canonical.cu -new file mode 100644 -index 0000000..2a16936 ---- /dev/null -+++ b/3rdparty/cutlass/examples/19_tensorop_canonical/tensorop_canonical.cu -@@ -0,0 +1,438 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/* -+ This example requires NVIDIA Ampere GPU or later. -+*/ -+ -+// Standard Library includes -+#include -+#include -+#include -+ -+// CUTLASS Includes -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" -+ -+// CUTLASS Utility Includes -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Define the overal warp-level problem shape -+int const kM = 27; -+int const kN = 31; -+int const kK = 17; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Define a warp-level GEMM operator. 
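The warp-level operator defined below pads the requested warp shape up to whole instructions and covers K in ceil(kK / InstructionShape::kK) groups. A small sketch of that round-up arithmetic for the 27x31x17 problem above with the 8x8x4 double-precision Tensor Core instruction; the printed values are derived from those constants and nothing else is assumed.

#include <cstdio>

// Round x up to the nearest multiple of m, the same (x + m - 1) / m trick used
// for WarpShape and kKgroups below.
constexpr int round_up(int x, int m) { return (x + m - 1) / m * m; }

int main() {
  std::printf("padded warp tile = %dx%d, K groups = %d\n",
              round_up(27, 8), round_up(31, 8), (17 + 4 - 1) / 4);  // 32x32, 5
  return 0;
}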
-+// -+// This template could be part of the CUTLASS Template Library or implemented internally. This -+// wraps the matrix multiply operation and epilogue with a GEMM-like interface that can be -+// instantiated in device code. -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+template < -+ typename Shape, -+ typename InstructionShape, -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementScalar -+> -+class GemmTensorOp { -+public: -+ -+ using WarpShape = GemmShape< -+ ((Shape::kM + InstructionShape::kM - 1) / InstructionShape::kM) * InstructionShape::kM, -+ ((Shape::kN + InstructionShape::kN - 1) / InstructionShape::kN) * InstructionShape::kN, -+ InstructionShape::kK -+ >; -+ -+ using MmaWarp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, -+ InstructionShape, -+ double, // Data type of A elements -+ cutlass::layout::RowMajor, // Layout of A matrix -+ double, // Data type of B elements -+ cutlass::layout::ColumnMajor, // Layout of B matrix -+ double, // Data type of C elements -+ cutlass::layout::RowMajor // Layout of C matrix -+ >::Type; -+ -+ // Number of 'K groups' -+ int const kKgroups = (Shape::kK + InstructionShape::kK - 1) / InstructionShape::kK; -+ -+ // Define a 'FragmentIterator' to iterate over slices of accumulators -+ using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ typename MmaWarp::Shape, -+ InstructionShape, -+ double, -+ typename MmaWarp::Policy::Operator::FragmentC, -+ cutlass::layout::RowMajor -+ >; -+ -+ // Define an epilogue 'Tile Iteterator' to iterate over slices of elements in Shared Memory -+ using AccumulatorTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpCanonical< -+ typename MmaWarp::Shape, -+ InstructionShape, -+ double, -+ cutlass::layout::RowMajor -+ >; -+ -+ using TensorRefA = typename MmaWarp::IteratorA::TensorRef; -+ using TensorRefB = typename MmaWarp::IteratorB::TensorRef; -+ using TensorRefC = typename AccumulatorTileIterator::TensorRef; -+ -+public: -+ CUTLASS_HOST_DEVICE -+ GemmTensorOp() { } -+ -+ CUTLASS_DEVICE -+ void operator()( -+ ElementScalar alpha, -+ TensorRefA ref_A, -+ TensorRefB ref_B, -+ ElementScalar beta, -+ TensorRefC ref_C, -+ TensorRefC ref_D, -+ int lane_id) const { -+ -+ // Instantiate iterators pointing to slices of the A and B matrices in shared memory -+ typename MmaWarp::IteratorA iter_A(ref_A, {Shape::kM, Shape::kK}, lane_id); -+ typename MmaWarp::IteratorB iter_B(ref_B, {Shape::kK, Shape::kN}, lane_id); -+ -+ // Instantiate and clear accumulator tile holding the C matrix -+ typename MmaWarp::FragmentC accum; -+ accum.clear(); -+ -+ // Instantiate the warp-level matrix multiply operator -+ MmaWarp mma_op; -+ -+ // Instantiate fragments holding the slice of the matrix held by each warp -+ typename MmaWarp::FragmentA frag_A[2]; -+ typename MmaWarp::FragmentB frag_B[2]; -+ -+ // Load fragments from shared memory -+ iter_A.load(frag_A[0]); -+ iter_B.load(frag_B[0]); -+ -+ ++iter_A; -+ ++iter_B; -+ -+ // Load fragments from shared memory -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < kKgroups; ++k) { -+ -+ // Load fragments from shared memory -+ iter_A.load(frag_A[(k + 1) % 2]); -+ iter_B.load(frag_B[(k + 1) % 2]); -+ -+ ++iter_A; -+ ++iter_B; -+ -+ // Compute the matrix multiply -+ mma_op(accum, frag_A[k % 2], frag_B[k % 2], accum); -+ } -+ -+ // Instantiate iterators -+ FragmentIterator accum_frag_it(accum); -+ AccumulatorTileIterator source_tile_it(ref_C, 
{Shape::kM, Shape::kN}, lane_id); -+ AccumulatorTileIterator dest_tile_it(ref_D, {Shape::kM, Shape::kN}, lane_id); -+ -+ // Define function objects for linear scaling operation -+ cutlass::multiplies mul_source; -+ cutlass::multiply_add mul_add_accumulator; -+ -+ // Iterate over the epilogue components -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx = 0; idx < FragmentIterator::kIterations; ++idx) { -+ -+ // Define storage for slices of the accumulators -+ typename FragmentIterator::Fragment accum_fragment; -+ typename FragmentIterator::Fragment source_fragment; -+ -+ // Select a slice of accumulators from the accumulator tile -+ accum_frag_it.load(accum_fragment); -+ ++accum_frag_it; -+ -+ // Load a corresponding slice from Shared memory -+ source_tile_it.load(source_fragment); -+ ++source_tile_it; -+ -+ // Compute linear scaling - alpha * AB + beta * C -+ source_fragment = mul_source(beta, source_fragment); -+ accum_fragment = mul_add_accumulator(alpha, accum_fragment, source_fragment); -+ -+ // Store the result to shared memory -+ dest_tile_it.store(accum_fragment); -+ ++dest_tile_it; -+ } -+ } -+}; -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Sample kernel demonstrating a collective GEMM operation by a warp on arbitrary matrices held -+// in Shared Memory. -+__global__ void kernel( -+ double *D_gmem, -+ double alpha, -+ double const *A_gmem, -+ double const *B_gmem, -+ double beta, -+ double const *C_gmem) { -+ -+ // Define several matrices in shared memory -+ __shared__ double A[kM][kK]; -+ __shared__ double B[kN][kK]; -+ __shared__ double C[kM][kN]; -+ -+ // Copy data into SMEM -+ if (threadIdx.x == 0) { -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int m = 0; m < kM; ++m) { -+ for (int k = 0; k < kK; ++k) { -+ A[m][k] = A_gmem[m * kK + k]; -+ } -+ } -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int n = 0; n < kN; ++n) { -+ for (int k = 0; k < kK; ++k) { -+ B[n][k] = B_gmem[n * kK + k]; -+ } -+ } -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int m = 0; m < kM; ++m) { -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int n = 0; n < kN; ++n) { -+ C[m][n] = C_gmem[m * kN + n]; -+ } -+ } -+ } -+ -+ __syncthreads(); -+ -+ // -+ // Instantiate a warp-level matrix multiply operator given the fundamental instruction shape (8x8x4), -+ // overall shape, data type of each operand, and layout of each operand. 
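The kernel above stages operand B in shared memory as B[kN][kK], and it is handed to the warp-level operator below as {&B[0][0], kK} with a ColumnMajor layout: element (k, n) of the logical kK x kN matrix and B[n][k] are the same storage location. A short sketch of that equivalence; the array sizes mirror the kK/kN constants above and the check itself is illustrative.

#include <cassert>

int main() {
  const int kK = 17, kN = 31;
  static double B[kN][kK];              // stand-in for the __shared__ staging buffer
  double *base = &B[0][0];
  for (int n = 0; n < kN; ++n) {
    for (int k = 0; k < kK; ++k) {
      // ColumnMajor offset with leading dimension kK.
      assert(&base[n * kK + k] == &B[n][k]);
    }
  }
  return 0;
}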
-+ // -+ -+ using GemmTensorOp = cutlass::gemm::warp::GemmTensorOp< -+ cutlass::gemm::GemmShape, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ double, // Data type of A elements -+ cutlass::layout::RowMajor, // Layout of A matrix -+ double, // Data type of B elements -+ cutlass::layout::ColumnMajor, // Layout of B matrix -+ double, // Data type of C elements -+ cutlass::layout::RowMajor, // Layout of C matrix -+ double // Scalar type of alpha and beta -+ >; -+ -+ // Instantiate the GEMM operator -+ GemmTensorOp gemm; -+ -+ // Execute the warp-level GEMM operation -+ gemm( -+ alpha, -+ {&A[0][0], kK}, -+ {&B[0][0], kK}, -+ beta, -+ {&C[0][0], kN}, -+ {&C[0][0], kN}, -+ threadIdx.x); -+ -+ __syncthreads(); -+ -+ // Copy data into SMEM -+ if (threadIdx.x == 0) { -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int m = 0; m < kM; ++m) { -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int n = 0; n < kN; ++n) { -+ D_gmem[m * kN + n] = C[m][n]; -+ } -+ } -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Entry point to canonical warp-level GEMM operation -+int main(int argc, const char *arg[]) { -+ -+ bool notSupported = false; -+ -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ >= 11)) { -+ std::cerr << "NVIDIA Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (!((props.major * 10 + props.minor) >= 80)) { -+ std::cerr << "This example requires compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ // Return 0 so tests are considered passing if run on unsupported platforms. -+ return 0; -+ } -+ -+ cutlass::HostTensor A({kM, kK}); -+ cutlass::HostTensor B({kK, kN}); -+ cutlass::HostTensor C({kM, kN}); -+ cutlass::HostTensor D({kM, kN}); -+ -+ uint64_t seed = 2020; -+ double max = 8; -+ double min = -8; -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ A.host_view(), -+ seed, -+ max, -+ min, -+ 0 -+ ); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ B.host_view(), -+ seed + 17, -+ max, -+ min, -+ 0 -+ ); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ C.host_view(), -+ seed + 31, -+ max, -+ min, -+ 0 -+ ); -+ -+ A.sync_device(); -+ B.sync_device(); -+ C.sync_device(); -+ D.sync_device(); -+ -+ dim3 grid(1,1); -+ dim3 block(32, 1, 1); -+ -+ double alpha = 2.25; -+ double beta = 1.24; -+ -+ kernel<<< grid, block >>>( -+ D.device_data(), -+ alpha, -+ A.device_data(), -+ B.device_data(), -+ beta, -+ C.device_data() -+ ); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to synchronize device after kernel launch." 
<< std::endl; -+ return -1; -+ } -+ -+ D.sync_host(); -+ -+ // Compute reference on host -+ cutlass::HostTensor D_ref({kM, kN}, false); -+ -+ cutlass::reference::host::GemmComplex( -+ {kM, kN, kK}, -+ alpha, -+ A.host_ref(), -+ cutlass::ComplexTransform::kNone, -+ B.host_ref(), -+ cutlass::ComplexTransform::kNone, -+ beta, -+ C.host_ref(), -+ D_ref.host_ref(), -+ double() -+ ); -+ -+ // Verify reference matches computed -+ if (!cutlass::reference::host::TensorEquals( -+ D.host_view(), -+ D_ref.host_view())) { -+ -+ std::cerr -+ << "A =\n" << A.host_view() -+ << "\n\nB = \n" << B.host_view() -+ << "\n\nC = " << C.host_view() -+ << "\n\nRef =\n" << D_ref.host_view() -+ << "\n\nD =\n" << D.host_view() << "\n\n"; -+ -+ std::cerr << "Error - device results mismatch host reference." << std::endl; -+ -+ return -1; -+ } -+ -+ std::cout << "Passed" << std::endl; -+ -+ return 0; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/20_simt_canonical/simt_canonical.cu b/3rdparty/cutlass/examples/20_simt_canonical/simt_canonical.cu -new file mode 100644 -index 0000000..632cd22 ---- /dev/null -+++ b/3rdparty/cutlass/examples/20_simt_canonical/simt_canonical.cu -@@ -0,0 +1,425 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/* -+ This example requires NVIDIA Maxwell GPU or beyond. 
-+*/ -+ -+// Standard Library includes -+#include -+#include -+#include -+ -+// CUTLASS Includes -+#include "cutlass/cutlass.h" -+#include "cutlass/core_io.h" -+#include "cutlass/functional.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/warp/mma_simt.h" -+#include "cutlass/epilogue/warp/fragment_iterator_simt.h" -+#include "cutlass/epilogue/warp/tile_iterator_simt.h" -+ -+// CUTLASS Utility Includes -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Define the overal warp-level problem shape -+int const kM = 14; -+int const kN = 27; -+int const kK = 17; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Define a warp-level GEMM operator. -+// -+// This template could be part of the CUTLASS Template Library or implemented internally. This -+// wraps the matrix multiply operation and epilogue with a GEMM-like interface that can be -+// instantiated in device code. -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+template < -+ typename Shape, -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementScalar -+> -+class GemmSimt { -+public: -+ -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using MmaWarp = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<16, 32, 8>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ // Number of 'K groups' -+ int const kKgroups = Shape::kK; -+ -+ using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt< -+ typename MmaWarp::Shape, -+ typename MmaWarp::ThreadMma, -+ layout::RowMajor, // SMEM layout -+ typename MmaWarp::Policy -+ >; -+ -+ using AccumulatorTileIterator = cutlass::epilogue::warp::TileIteratorSimtCanonical< -+ typename MmaWarp::Shape, -+ typename MmaWarp::ThreadMma, -+ float, // ElementAccumulator -+ layout::RowMajor, // SMEM layout -+ typename MmaWarp::Policy -+ >; -+ -+ using TensorRefA = typename MmaWarp::IteratorA::TensorRef; -+ using TensorRefB = typename MmaWarp::IteratorB::TensorRef; -+ using TensorRefC = typename AccumulatorTileIterator::TensorRef; -+ -+public: -+ CUTLASS_HOST_DEVICE -+ GemmSimt() { } -+ -+ CUTLASS_DEVICE -+ void operator()( -+ ElementScalar alpha, -+ TensorRefA ref_A, -+ TensorRefB ref_B, -+ ElementScalar beta, -+ TensorRefC ref_C, -+ TensorRefC ref_D, -+ int lane_id) const { -+ -+ // Instantiate iterators pointing to slices of the A and B matrices in shared memory -+ typename MmaWarp::IteratorA iter_A(ref_A, {Shape::kM, Shape::kK}, lane_id); -+ typename MmaWarp::IteratorB iter_B(ref_B, {Shape::kK, Shape::kN}, lane_id); -+ -+ // Instantiate and clear accumulator tile holding the C matrix -+ typename MmaWarp::FragmentC accum; -+ accum.clear(); -+ -+ // Instantiate the warp-level matrix multiply operator -+ MmaWarp mma_op; -+ -+ // Instantiate fragments holding the slice of the 
matrix held by each warp -+ typename MmaWarp::FragmentA frag_A[2]; -+ typename MmaWarp::FragmentB frag_B[2]; -+ -+ // Load fragments from shared memory -+ iter_A.load(frag_A[0]); -+ iter_B.load(frag_B[0]); -+ -+ ++iter_A; -+ ++iter_B; -+ -+ // Load fragments from shared memory -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < kKgroups; ++k) { -+ -+ // Load fragments from shared memory -+ iter_A.load(frag_A[(k + 1) % 2]); -+ iter_B.load(frag_B[(k + 1) % 2]); -+ -+ ++iter_A; -+ ++iter_B; -+ -+ // Compute the matrix multiply -+ mma_op(accum, frag_A[k % 2], frag_B[k % 2], accum); -+ } -+ -+ // Instantiate iterators -+ FragmentIterator accum_frag_it(accum); -+ AccumulatorTileIterator source_tile_it(ref_C, {Shape::kM, Shape::kN}, lane_id); -+ AccumulatorTileIterator dest_tile_it(ref_D, {Shape::kM, Shape::kN}, lane_id); -+ -+ // Define function objects for linear scaling operation -+ cutlass::multiplies mul_source; -+ cutlass::multiply_add mul_add_accumulator; -+ -+ // Iterate over the epilogue components -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx = 0; idx < FragmentIterator::kIterations; ++idx) { -+ -+ // Define storage for slices of the accumulators -+ typename FragmentIterator::Fragment accum_fragment; -+ typename FragmentIterator::Fragment source_fragment; -+ -+ // Select a slice of accumulators from the accumulator tile -+ accum_frag_it.load(accum_fragment); -+ ++accum_frag_it; -+ -+ // Load a corresponding slice from Shared memory -+ source_tile_it.load(source_fragment); -+ ++source_tile_it; -+ -+ // Compute linear scaling - alpha * AB + beta * C -+ source_fragment = mul_source(beta, source_fragment); -+ accum_fragment = mul_add_accumulator(alpha, accum_fragment, source_fragment); -+ -+ // Store the result to shared memory -+ dest_tile_it.store(accum_fragment); -+ ++dest_tile_it; -+ } -+ -+ } -+ -+}; -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Sample kernel demonstrating a collective GEMM operation by a warp on arbitrary matrices held -+// in Shared Memory. -+__global__ void kernel( -+ float *D_gmem, -+ float alpha, -+ float const *A_gmem, -+ float const *B_gmem, -+ float beta, -+ float const *C_gmem) { -+ -+ // Define several matrices in shared memory -+ __shared__ float A[kM][kK]; -+ __shared__ float B[kN][kK]; -+ __shared__ float C[kM][kN]; -+ -+ // Copy data into SMEM -+ if (threadIdx.x == 0) { -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int m = 0; m < kM; ++m) { -+ for (int k = 0; k < kK; ++k) { -+ A[m][k] = A_gmem[m * kK + k]; -+ } -+ } -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int n = 0; n < kN; ++n) { -+ for (int k = 0; k < kK; ++k) { -+ B[n][k] = B_gmem[n * kK + k]; -+ } -+ } -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int m = 0; m < kM; ++m) { -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int n = 0; n < kN; ++n) { -+ C[m][n] = C_gmem[m * kN + n]; -+ } -+ } -+ } -+ -+ __syncthreads(); -+ -+ // -+ // Instantiate a warp-level matrix multiply operator given the fundamental instruction shape (8x8x4), -+ // overall shape, data type of each operand, and layout of each operand. 
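A small sketch of how the MmaSimtPolicy above tiles a warp: a 4x8 arrangement of lanes, each lane computing a 4x4 tile per step, which reproduces the 16x32 M-by-N extent of the warp-level MmaSimt shape. The identifier names in the comments are descriptive, not CUTLASS symbols.

#include <cstdio>

int main() {
  const int lane_rows = 4, lane_cols = 8;   // warp viewed as a 4x8 grid of lanes
  const int thread_m = 4, thread_n = 4;     // per-lane tile from the policy's 4x4x1 shape

  std::printf("lanes = %d, warp tile = %dx%d\n",
              lane_rows * lane_cols,        // 32 threads
              lane_rows * thread_m,         // 16 rows
              lane_cols * thread_n);        // 32 columns
  return 0;
}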
-+ // -+ -+ using GemmSimt = cutlass::gemm::warp::GemmSimt< -+ cutlass::gemm::GemmShape, -+ float, // Data type of A elements -+ cutlass::layout::RowMajor, // Layout of A matrix -+ float, // Data type of B elements -+ cutlass::layout::ColumnMajor, // Layout of B matrix -+ float, // Data type of C elements -+ cutlass::layout::RowMajor, // Layout of C matrix -+ float // Scalar type of alpha and beta -+ >; -+ -+ // Instantiate the GEMM operator -+ GemmSimt gemm; -+ -+ // Execute the warp-level GEMM operation -+ gemm( -+ alpha, -+ {&A[0][0], kK}, -+ {&B[0][0], kK}, -+ beta, -+ {&C[0][0], kN}, -+ {&C[0][0], kN}, -+ threadIdx.x); -+ -+ __syncthreads(); -+ -+ // Copy data into SMEM -+ if (threadIdx.x == 0) { -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int m = 0; m < kM; ++m) { -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int n = 0; n < kN; ++n) { -+ D_gmem[m * kN + n] = C[m][n]; -+ } -+ } -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, const char *arg[]) { -+ -+ cutlass::HostTensor A({kM, kK}); -+ cutlass::HostTensor B({kK, kN}); -+ cutlass::HostTensor C({kM, kN}); -+ cutlass::HostTensor D({kM, kN}); -+ -+ uint64_t seed = 2020; -+ float max = 8; -+ float min = -8; -+ -+ std::cout << "Simt canonical GEMM problem size = (" << cutlass::gemm::GemmShape() <<")" << std::endl; -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ A.host_view(), -+ seed, -+ max, -+ min, -+ 0 -+ ); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ B.host_view(), -+ seed + 17, -+ max, -+ min, -+ 0 -+ ); -+ -+#if 0 // Debug: fill A sequentially and B as Identity matrix for debugging -+ cutlass::reference::host::BlockFillSequential( -+ A.host_view().data(), A.host_view().capacity()); -+ -+ cutlass::reference::host::TensorFillIdentity(B.host_view()); -+#endif -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ C.host_view(), -+ seed + 31, -+ max, -+ min, -+ 0 -+ ); -+ -+ A.sync_device(); -+ B.sync_device(); -+ C.sync_device(); -+ D.sync_device(); -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ float alpha = 1.0f; -+ float beta = 0.0f; -+ -+ kernel<<< grid, block >>>( -+ D.device_data(), -+ alpha, -+ A.device_data(), -+ B.device_data(), -+ beta, -+ C.device_data() -+ ); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to synchronize device after kernel launch." << std::endl; -+ return -1; -+ } -+ -+ D.sync_host(); -+ -+ // Compute reference on host -+ cutlass::HostTensor D_ref({kM, kN}, false); -+ cutlass::reference::host::TensorCopy(D_ref.host_view(), C.host_view()); -+ -+ cutlass::reference::host::Gemm< -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ float, float> reference_gemm; -+ -+ reference_gemm( -+ {kM, kN, kK}, -+ alpha, -+ A.host_ref(), -+ B.host_ref(), -+ beta, -+ D_ref.host_ref(), -+ float() -+ ); -+ -+ // Verify reference matches computed -+ if (!cutlass::reference::host::TensorEquals( -+ D.host_view(), -+ D_ref.host_view())) { -+ -+ std::cerr -+ << "A =\n" << A.host_view() -+ << "\n\nB = \n" << B.host_view() -+ << "\n\nC = " << C.host_view() -+ << "\n\nRef =\n" << D_ref.host_view() -+ << "\n\nD =\n" << D.host_view() << "\n\n"; -+ -+ std::cerr << "Error - device results mismatch host reference." 
<< std::endl; -+ -+ return -1; -+ } -+ -+ std::cout << "Passed" << std::endl; -+ -+ return 0; -+ -+} -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/21_quaternion_gemm/quaternion_gemm.cu b/3rdparty/cutlass/examples/21_quaternion_gemm/quaternion_gemm.cu -new file mode 100644 -index 0000000..02d7b53 ---- /dev/null -+++ b/3rdparty/cutlass/examples/21_quaternion_gemm/quaternion_gemm.cu -@@ -0,0 +1,454 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ // -+ // Methods -+ // -+ -+ Result( -+ double runtime_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess -+ ): -+ runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ int batch_count; -+ cutlass::Quaternion alpha; -+ cutlass::Quaternion beta; -+ -+ bool reference_check; -+ int iterations; -+ -+ Options(): -+ help(false), -+ problem_size({1024, 1024, 1024}), -+ batch_count(1), -+ reference_check(true), -+ iterations(20), -+ alpha(1), -+ beta() { } -+ -+ bool valid() { -+ return true; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("m", problem_size.m()); -+ cmd.get_cmd_line_argument("n", problem_size.n()); -+ cmd.get_cmd_line_argument("k", problem_size.k()); -+ cmd.get_cmd_line_argument("batch", batch_count); -+ -+ cmd.get_cmd_line_argument("alpha", alpha.w()); -+ cmd.get_cmd_line_argument("alpha_i", alpha.x()); -+ cmd.get_cmd_line_argument("alpha_j", alpha.y()); -+ cmd.get_cmd_line_argument("alpha_k", alpha.z()); -+ -+ cmd.get_cmd_line_argument("beta", beta.w()); -+ cmd.get_cmd_line_argument("beta_i", beta.x()); -+ cmd.get_cmd_line_argument("beta_j", beta.y()); -+ cmd.get_cmd_line_argument("beta_k", beta.z()); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ -+ } -+ -+ /// Prints the usage statement. 
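Editorial note: parse() above reads alpha and beta as four real components (w plus the three imaginary parts) because the operands are quaternions, and the gflops() helper further down counts 16 real multiply-adds per quaternion multiply-add for the same reason. A minimal sketch of the Hamilton product (plain C++ with a hypothetical Quat struct, not the cutlass::Quaternion type) makes the 16-multiply count explicit:

struct Quat { float w, x, y, z; };  // hypothetical stand-in for a quaternion value

// Hamilton product: 16 real multiplies and 12 real additions per quaternion multiply.
Quat qmul(Quat a, Quat b) {
  return Quat{
      a.w * b.w - a.x * b.x - a.y * b.y - a.z * b.z,
      a.w * b.x + a.x * b.w + a.y * b.z - a.z * b.y,
      a.w * b.y - a.x * b.z + a.y * b.w + a.z * b.x,
      a.w * b.z + a.x * b.y - a.y * b.x + a.z * b.w};
}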
-+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "21_quaternion_gemm example\n\n" -+ << " This example uses the CUTLASS Library to execute Quaternion GEMM computations.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --m= GEMM M dimension\n" -+ << " --n= GEMM N dimension\n" -+ << " --k= GEMM K dimension\n" -+ << " --batch= Number of GEMM operations executed in one batch\n" -+ << " --alpha= Epilogue scalar alpha (real part)\n" -+ << " --alpha_i= Epilogue scalar alpha_i (imaginary part)\n" -+ << " --alpha_j= Epilogue scalar alpha_j (imaginary part)\n" -+ << " --alpha_k= Epilogue scalar alpha_k (imaginary part)\n" -+ << " --beta= Epilogue scalar beta (real part)\n\n" -+ << " --beta_i= Epilogue scalar beta_i (imaginary part)\n\n" -+ << " --beta_j= Epilogue scalar beta_j (imaginary part)\n\n" -+ << " --beta_k= Epilogue scalar beta_k (imaginary part)\n\n" -+ << " --iterations= Number of profiling iterations to perform.\n\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/21_quaternion_gemm/21_quaternion_gemm --batch=7 --m=1024 --n=512 --k=1024 \\\n" -+ << " --alpha=2 --alpha_i=-2 --beta=0.707 --beta_i=-.707\n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fmas = problem_size.product() * batch_count * 16; -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// The code section below describes datatype for input, output matrices and computation between -+// elements in input matrices. -+using precision = float; -+using Element = cutlass::Quaternion; -+using ElementComputeEpilogue = Element; // <- data type of epilogue operations -+using ElementAccumulator = Element; // <- data type of accumulator -+using ElementInputA = Element; // <- data type of elements in input matrix A -+using ElementInputB = Element; // <- data type of elements in input matrix B -+using ElementOutput = Element; // <- data type of elements in output matrix D -+ -+// The code section below describes matrix layout of input and output matrices. 
Column Major for -+// Matrix A, Row Major for Matrix B and Row Major for Matrix C -+using LayoutInputA = cutlass::layout::RowMajor; -+using LayoutInputB = cutlass::layout::ColumnMajor; -+using LayoutOutput = cutlass::layout::RowMajor; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassSimt; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm50; -+ -+// This code section describes the tile size a thread block will compute -+using ShapeMMAThreadBlock = -+ cutlass::gemm::GemmShape<64, 64, 4>; // <- threadblock tile M = 64, N = 64, K = 8 -+// This code section describes tile size a warp will compute -+using ShapeMMAWarp = cutlass::gemm::GemmShape<32, 16, 4>; // <- warp tile M = 32, N = 16, K = 8 -+// This code section describes the size of MMA op -+using ShapeMMAOp = cutlass::gemm::GemmShape<1, 1, 1>; // <- MMA Op tile M = 1, N = 1, K = 1 -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- Defaults -+ -+// This code section describes the epilogue part of the kernel -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // <- data type of output matrix -+ 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized -+ // memory access. For a byte, it's 16 -+ // elements. This becomes the vector width of -+ // math instructions in the epilogue too -+ ElementAccumulator, // <- data type of accumulator -+ ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 2; -+ -+using Gemm = cutlass::gemm::device::Gemm; -+ -+int run(Options options) { -+ -+ // PASS/FAIL status -+ bool passed = true; -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size = options.problem_size; -+ -+ // Initialize tensors using CUTLASS helper functions -+ cutlass::HostTensor tensor_a( -+ problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor tensor_b( -+ problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor tensor_c( -+ problem_size.mn()); // <- Create matrix C with dimensions M x N -+ cutlass::HostTensor tensor_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // CUTLASS kernel -+ cutlass::HostTensor tensor_ref_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // reference kernel -+ -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ 4, -+ -4, -+ 0); // <- Fill matrix A on host with uniform-distribution random data -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ 4, -+ -4, -+ 0); // <- Fill matrix B on host with uniform-distribution random data -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ 4, -+ -4, -+ 0); // <- Fill matrix C on host with uniform-distribution random data -+ -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); // <- fill matrix D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros -+ -+ // Copy data 
from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // Initialize alpha and beta for dot product computation -+ ElementComputeEpilogue alpha = ElementComputeEpilogue(1); -+ ElementComputeEpilogue beta = ElementComputeEpilogue(0); -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication -+ tensor_a.device_ref(), // <- reference to matrix A on device -+ tensor_b.device_ref(), // <- reference to matrix B on device -+ tensor_c.device_ref(), // <- reference to matrix C on device -+ tensor_d.device_ref(), // <- reference to matrix D on device -+ {alpha, beta}, // <- tuple of alpha and beta -+ split_k_slices}; // <- k-dimension split factor -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm gemm_op; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status = gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(status); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status = gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(status); -+ -+ // Result structure -+ Result result; -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMMs -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ -+ // Launch initialized CUTLASS kernel -+ status = gemm_op(); -+ CUTLASS_CHECK(status); -+ -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMMs are complete -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // Compute average runtime and GFLOPs. 
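Editorial note: the timing code above is the standard CUDA-event idiom: record an event before and after the profiling loop, synchronize on the second event, then read the elapsed time and divide by the iteration count. A compact sketch of the same pattern, assuming a hypothetical launch_work() callback standing in for the GEMM launch:

#include <cuda_runtime.h>

// Sketch of the event-based timing used above. The elapsed time covers all
// `iterations` launches, so the per-launch average is elapsed / iterations.
float time_launches_ms(void (*launch_work)(), int iterations) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  cudaEventRecord(start);
  for (int i = 0; i < iterations; ++i) {
    launch_work();                        // enqueue the work being profiled
  }
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);             // wait until the recorded work has finished

  float elapsed_ms = 0.0f;
  cudaEventElapsedTime(&elapsed_ms, start, stop);

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return elapsed_ms / float(iterations);  // average runtime per launch
}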
-+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ if (options.reference_check) { -+ -+ // Create instantiation for device reference gemm kernel -+ cutlass::reference::device::Gemm gemm_device; -+ -+ // Launch device reference gemm kernel -+ gemm_device(problem_size, -+ alpha, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ beta, -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref()); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ passed &= cutlass::reference::host::TensorEquals( -+ tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ } -+ -+ if (passed) { -+ std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " GFLOPs: " << result.gflops << std::endl; -+ } -+ -+ std::cout << (passed ? "Passed" : "Failed") << std::endl; -+ return (passed ? 0 : -1); -+} -+ -+int main(int argc, char const** argv) { -+ -+ Options options; -+ options.parse(argc, argv); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ printf("%d x %d x %d Single Precision Quaternion Matrix Multiply\n", \ -+ options.problem_size.m(), options.problem_size.n(), options.problem_size.k()); -+ -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ return run(options); -+} -+ -diff --git a/3rdparty/cutlass/examples/22_quaternion_conv/quaternion_conv.cu b/3rdparty/cutlass/examples/22_quaternion_conv/quaternion_conv.cu -new file mode 100644 -index 0000000..57df73f ---- /dev/null -+++ b/3rdparty/cutlass/examples/22_quaternion_conv/quaternion_conv.cu -@@ -0,0 +1,667 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using Element = cutlass::Quaternion; -+using ElementAccumulator = Element; // Data type of accumulator -+using ElementComputeEpilogue = Element; // Data type of epilogue computation (alpha, beta) -+using ElementInputA = Element; // Data type of elements in input tensor -+using ElementInputB = Element; // Data type of elements in input tensor -+using ElementOutput = Element; // Data type of elements in output tensor -+ -+using LayoutInputA = cutlass::layout::TensorNHWC; -+using LayoutInputB = cutlass::layout::TensorNHWC; -+using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassSimt; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm50; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; // SIMT instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 2; -+ -+// This code section describe iterator algorithm selected is Analytic or Optimized -+static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. 
This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; // Data type for alpha/beta in linear combination -+ -+ -+using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm -+>::Kernel; -+ -+using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::Tensor4DCoord input_size; -+ cutlass::Tensor4DCoord filter_size; -+ cutlass::Tensor4DCoord padding; -+ cutlass::MatrixCoord conv_stride; -+ cutlass::MatrixCoord dilation; -+ bool reference_check; -+ bool measure_performance; -+ int iterations; -+ bool save_workspace; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ bool benchmark; -+ std::string tag; -+ -+ Options(): -+ help(false), -+ input_size(1, 32, 32, 32), -+ filter_size(32, 3, 3, 32), -+ padding(1, 1, 1, 1), -+ conv_stride(1, 1), -+ dilation(1, 1), -+ reference_check(false), -+ measure_performance(true), -+ iterations(20), -+ save_workspace(false), -+ alpha(1), -+ beta(0), -+ benchmark(false) { } -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. 
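Editorial note: as the comment above states, 128-bit vectorized accesses require extents that are multiples of 128 / sizeof_bits(element) elements, i.e. 8 elements for 16-bit types. A small sketch of that arithmetic (illustrative helpers, not part of the example):

// 128-bit vector accesses imply an alignment requirement measured in elements.
constexpr int alignment_in_elements(int element_bits) {
  return 128 / element_bits;   // e.g. 8 elements for 16-bit types, 4 for 32-bit
}

constexpr bool extent_is_aligned(int extent, int element_bits) {
  return extent % alignment_in_elements(element_bits) == 0;
}

static_assert(alignment_in_elements(16) == 8, "16-bit elements load 8 per 128-bit vector");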
-+ // -+ int const kAlignment = 8; -+ -+ if ((input_size.c() % kAlignment) || -+ (filter_size.n() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ // Invalid padding -+ if ((padding.h() != filter_size.h() / 2) || -+ (padding.w() != filter_size.w() / 2)) { -+ -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Updates input and filter sizes -+ void update( -+ cutlass::Tensor4DCoord input_size, -+ cutlass::Tensor4DCoord filter_size) { -+ -+ this->input_size = input_size; -+ this->filter_size = filter_size; -+ -+ padding.n() = filter_size.h() / 2; -+ padding.h() = filter_size.h() / 2; -+ padding.w() = filter_size.w() / 2; -+ padding.c() = filter_size.w() / 2; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("ref-check")) { -+ reference_check = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("perf-check")) { -+ measure_performance = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("save-workspace")) { -+ save_workspace = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("benchmark")) { -+ benchmark = true; -+ } -+ -+ cmd.get_cmd_line_argument("n", input_size.n()); -+ cmd.get_cmd_line_argument("h", input_size.h()); -+ cmd.get_cmd_line_argument("w", input_size.w()); -+ cmd.get_cmd_line_argument("c", input_size.c()); -+ -+ cmd.get_cmd_line_argument("k", filter_size.n()); -+ cmd.get_cmd_line_argument("r", filter_size.h()); -+ cmd.get_cmd_line_argument("s", filter_size.w()); -+ filter_size.c() = input_size.c(); -+ -+ cmd.get_cmd_line_argument("alpha_w", alpha.w()); -+ cmd.get_cmd_line_argument("alpha_x", alpha.x()); -+ cmd.get_cmd_line_argument("alpha_y", alpha.y()); -+ cmd.get_cmd_line_argument("alpha_z", alpha.z()); -+ -+ cmd.get_cmd_line_argument("beta_w", beta.w()); -+ cmd.get_cmd_line_argument("beta_x", beta.x()); -+ cmd.get_cmd_line_argument("beta_y", beta.y()); -+ cmd.get_cmd_line_argument("beta_z", beta.z()); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ -+ if (filter_size.h() == 3 && filter_size.w() == 3) { -+ padding = {1, 1, 1, 1}; -+ } -+ else { -+ filter_size.h() = 1; -+ filter_size.w() = 1; -+ padding = {0, 0, 0, 0}; -+ } -+ } -+ -+ /// Prints the usage statement. 
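Editorial note: parse() above only accepts 3x3 filters (padding 1) and otherwise falls back to 1x1 filters (padding 0), while update() derives the "same"-style padding as filter_extent / 2 for odd filters. A one-line sketch of that rule (illustrative helper):

// "Same"-style padding for an odd filter extent: 3 -> 1, 1 -> 0, 5 -> 2, ...
constexpr int same_padding(int filter_extent) { return filter_extent / 2; }

static_assert(same_padding(3) == 1 && same_padding(1) == 0, "padding chosen by the example");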
-+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "22_quaternion_conv example\n\n" -+ << " This example uses Ampere's Tensor Core operators on F16 data types to compute\n" -+ << " forward convolution on tensors of layout NHWC.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --n= Input tensor extent N\n" -+ << " --h= Input tensor extent H\n" -+ << " --w= Input tensor extent W\n" -+ << " --c= Input tensor extent C\n" -+ << " --k= Filter extent K\n" -+ << " --r= Filter extent R\n" -+ << " --s= Filter extent S\n\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --ref-check If set (true), reference check on the host is computed\n" -+ << " --perf-check If set (true), performance is measured.\n" -+ << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --save-workspace If set, workspace is written to a text file.\n" -+ << " --tag= String to replicate across the first column in the results table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/22_quaternion_conv/22_quaternion_conv --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n" -+ << "$ ./examples/22_quaternion_conv/22_quaternion_conv --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n"; -+ -+ return out; -+ } -+ -+ /// Computes the output tensor size (NPQK) -+ cutlass::Tensor4DCoord output_size() const { -+ return cutlass::Tensor4DCoord( -+ input_size.n(), -+ (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, -+ (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, -+ filter_size.n()); -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of multiply-adds = NPQK * CRS -+ int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()) * 16; -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cutlass::Status reference_check; -+ cudaError_t error; -+ -+ Result(): -+ runtime_ms(0), -+ gflops(0), -+ status(cutlass::Status::kSuccess), -+ reference_check(cutlass::Status::kInvalid), -+ error(cudaSuccess) { } -+ -+ static std::ostream & print_header(std::ostream &out, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "Layer,N,H,W,C,K,R,S,Runtime,GFLOPs"; -+ -+ return out; -+ } -+ -+ std::ostream & print(std::ostream &out, int idx, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ out -+ << "conv_" << idx << "," -+ << options.input_size.n() << "," -+ << options.input_size.h() << "," -+ << options.input_size.w() << "," -+ << options.input_size.c() << "," -+ << options.filter_size.n() << "," -+ << options.filter_size.h() << "," -+ << options.filter_size.w() << "," -+ << runtime_ms << "," -+ << gflops; -+ -+ return out; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one benchmark -+Result profile_convolution(Options const &options) { -+ -+ Result result; -+ -+ // -+ // Allocate host-device tensors using the CUTLASS Utilities. 
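Editorial note: output_size() above applies the usual cross-correlation output formula per spatial dimension. A hedged sketch in plain C++ (scalar arguments, illustrative names):

// Output extent for one spatial dimension of a cross-correlation:
// out = (in + pad_begin + pad_end - filter) / stride + 1.
constexpr int conv_output_extent(int in, int pad_begin, int pad_end,
                                 int filter, int stride) {
  return (in + pad_begin + pad_end - filter) / stride + 1;
}

// Example: the default 32x32 input with a 3x3 filter, padding 1, stride 1
// keeps the spatial extent at 32.
static_assert(conv_output_extent(32, 1, 1, 3, 1) == 32, "same-padding keeps H/W");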
-+ // -+ -+ cutlass::HostTensor tensor_a(options.input_size); -+ cutlass::HostTensor tensor_b(options.filter_size); -+ cutlass::HostTensor tensor_c(options.output_size()); -+ cutlass::HostTensor tensor_ref_c(options.output_size()); -+ -+ // -+ // Initialize tensors -+ // -+ -+ // Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ 7, -+ -8, -+ 0); -+ -+ // Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ 7, -+ -8, -+ 0); -+ -+ // Fill tensor C on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_c.host_view()); -+ -+ // Fill tensor C for reference on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_c.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_ref_c.sync_device(); -+ -+ // -+ // Define arguments for CUTLASS Convolution -+ // -+ -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Construct Conv2dProblemSize with user defined output size -+ cutlass::conv::Conv2dProblemSize problem_size( -+ options.input_size, -+ options.filter_size, -+ options.padding, -+ options.conv_stride, -+ options.dilation, -+ options.output_size(), -+ mode, -+ split_k_slices -+ ); -+ -+ // Construct ImplicitGemm::Argument structure with conv2d -+ // problem size, data pointers, and epilogue values -+ typename ImplicitGemm::Arguments arguments{ -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_c.device_ref(), -+ {options.alpha, options.beta}, -+ }; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ ImplicitGemm implicit_gemm_op; -+ -+ size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ result.status = implicit_gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = implicit_gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Launch initialized CUTLASS kernel -+ // -+ result.status = implicit_gemm_op(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Optional reference check -+ // -+ -+ if (options.reference_check) { -+ std::cout << "Verification on host...\n"; -+ -+ // Compute with reference implementation -+ cutlass::reference::host::Conv2dFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator, -+ cutlass::NumericConverter -+ >( -+ problem_size, -+ tensor_a.host_ref(), -+ tensor_b.host_ref(), -+ tensor_c.host_ref(), -+ tensor_ref_c.host_ref(), -+ options.alpha, -+ options.beta -+ ); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ tensor_c.sync_host(); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_c.host_view(), -+ tensor_ref_c.host_view()); -+ -+ if (!passed) { -+ result.reference_check = cutlass::Status::kErrorInternal; -+ std::cout << "ERROR - results miscompared.\n"; -+ } -+ else { -+ result.reference_check = cutlass::Status::kSuccess; -+ std::cout << "Passed.\n"; -+ } -+ } -+ else { -+ result.reference_check = cutlass::Status::kInvalid; -+ } -+ -+ if 
(options.save_workspace) { -+ -+ std::stringstream ss; -+ -+ ss << "22_quaternion_conv_" -+ << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() -+ << "_" -+ << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() -+ << ".dat"; -+ -+ std::ofstream output_workspace(ss.str()); -+ -+ output_workspace -+ << "Input = \n" << tensor_a.host_view() << "\n\n" -+ << "Filters = \n" << tensor_b.host_view() << "\n\n"; -+ -+ if (options.reference_check) { -+ output_workspace << "Reference = \n" << tensor_ref_c.host_view() << "\n\n"; -+ } -+ -+ output_workspace << "Computed = \n" << tensor_c.host_view() << std::endl; -+ -+ std::cout << "Results written to '" << ss.str() << "'." << std::endl; -+ } -+ -+ // -+ // Performance measurement -+ // -+ -+ if (options.measure_performance) { -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = implicit_gemm_op(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. 
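Editorial note: immediately after this, the example divides the total elapsed time by the iteration count and converts milliseconds to seconds before computing GFLOP/s. A tiny sketch of that bookkeeping (illustrative helper, fmas counted as real multiply-adds per launch):

// Convert a total elapsed time (ms) over `iterations` launches into GFLOP/s:
// two flops per real multiply-add, 1e9 flops per GFLOP.
double to_gflops(double total_elapsed_ms, int iterations, double fmas_per_launch) {
  double average_s = (total_elapsed_ms / double(iterations)) / 1000.0;
  return 2.0 * fmas_per_launch / 1.0e9 / average_s;
}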
-+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.benchmark) { -+ // Benchmark several layers -+ -+ int batch_sizes[] = {1, 32, 64, 128, 256, 512}; -+ -+ struct Benchmark { -+ int h, w, c, k, r, s; -+ } layers[] = { -+ {56, 56, 64, 256, 1, 1}, -+ {56, 56, 64, 64, 1, 1}, -+ {56, 56, 64, 64, 3, 3}, -+ {56, 56, 256, 64, 1, 1}, -+ {56, 56, 256, 512, 1, 1}, -+ {56, 56, 256, 128, 1, 1}, -+ {28, 28, 128, 128, 3, 3}, -+ {28, 28, 128, 512, 1, 1}, -+ {28, 28, 512, 128, 1, 1}, -+ {28, 28, 512, 1024, 1, 1}, -+ {28, 28, 512, 256, 1, 1}, -+ {14, 14, 256, 256, 3, 3}, -+ {14, 14, 256, 1024, 1, 1}, -+ {14, 14, 1024, 256, 1, 1}, -+ {14, 14, 1024, 2048, 1, 1}, -+ {14, 14, 1024, 512, 1, 1}, -+ {7, 7, 512, 512, 3, 3}, -+ }; -+ -+ Result::print_header(std::cout, options) << std::endl; -+ -+ int idx = 1; -+ -+ for (auto const &layer : layers) { -+ for (auto N : batch_sizes) { -+ -+ options.update({N, layer.h, layer.w, layer.c}, {layer.k, layer.r, layer.s, layer.c}); -+ -+ Result result = profile_convolution(options); -+ result.print(std::cout, idx, options) << std::endl; -+ } -+ -+ ++idx; -+ } -+ } -+ else { -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ Result result = profile_convolution(options); -+ -+ Result::print_header(std::cout, options) << std::endl; -+ result.print(std::cout, 1, options) << std::endl; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu b/3rdparty/cutlass/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu -new file mode 100644 -index 0000000..81a3e15 ---- /dev/null -+++ b/3rdparty/cutlass/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu -@@ -0,0 +1,766 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+The example demenstrates how to reduce one of the operands of the GEMM along the k-dimension when -+computing GEMM. So the output also contains either a Mx1 or 1XN vector. It only works with Ampere -+16x8x16 FP16/BF16 tensor cores, though it is not difficult to apply to other Turing/Ampere tensor -+core instructions. -+ -+Most of the reduction is done in gemm/warp level, see gemm/warp/mma_with_reduction_tensor_op.h -+A few bit of reduction is done in the epilouge before storing the vector, see -+epilogue/threadblock/epilogue_gemm_k_reduction.h -+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_with_k_reduction.h" -+#include "cutlass/gemm/kernel/default_gemm_with_k_reduction.h" -+#include "cutlass/reduction/device/reduce_split_k.h" -+#include "cutlass/reduction/kernel/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+#include "cutlass/matrix_coord.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/convolution.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using ElementAccumulator = float; // Data type of accumulator -+using ElementComputeEpilogue = ElementAccumulator; // Data type of epilogue computation -+using ElementInputA = cutlass::bfloat16_t; // Data type of elements in input tensor -+using ElementInputB = cutlass::bfloat16_t; // Data type of elements in input tensor -+using ElementOutput = cutlass::bfloat16_t; // Data type of elements in output tensor -+ -+using LayoutInputA = cutlass::layout::ColumnMajor; -+using LayoutInputB = cutlass::layout::RowMajor; -+using LayoutOutput = cutlass::layout::ColumnMajor; -+// Layout of the output vector -+using LayoutGemmKReduction = cutlass::layout::PitchLinear; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape -+ -+// This code section describes 
tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 4; -+ -+// Reduce A or B operand along the K dimension -+constexpr bool ReduceKForA = true; -+ -+// Alignment of A operand -+constexpr int AlignmentA = 8; -+ -+// Alignment of B operand -+constexpr int AlignmentB = 8; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; -+ -+using Gemm = typename cutlass::gemm::device::GemmWithKReduction< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ ReduceKForA, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ AlignmentA, -+ AlignmentB, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone -+>; -+ -+// Below is the reduction kernel used in the case of parallel split-k -+using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;; -+ -+using ReduceOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ ElementOutput, -+ EpilogueOp::kCount -+ >; -+ -+using ReduceGemmSplitKKernel = cutlass::reduction::kernel::ReduceSplitK< -+ ReduceGemmSplitKShape, -+ EpilogueOp, -+ ReduceOp -+ >; -+ -+using ReduceGemmSplitK = cutlass::reduction::device::ReduceSplitK; -+ -+using ReduceVectorSplitKShape = cutlass::MatrixShape<1, 256>;; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using DummyEpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. 
-+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue, -+ cutlass::epilogue::thread::ScaleType::Nothing>; -+ -+using ReduceVectorSplitKKernel = cutlass::reduction::kernel::ReduceSplitK< -+ ReduceVectorSplitKShape, -+ DummyEpilogueOp, -+ ReduceOp -+ >; -+ -+using ReduceVectorSplitK = cutlass::reduction::device::ReduceSplitK; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::gemm::GemmCoord problem_size; -+ int split_k_slices; -+ bool parallel_split_k; -+ bool reference_check; -+ bool measure_performance; -+ int iterations; -+ bool save_workspace; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ bool benchmark; -+ std::string tag; -+ -+ Options(): -+ help(false), -+ problem_size(1024, 1024, 1024), -+ split_k_slices(1), -+ parallel_split_k(false), -+ reference_check(true), -+ measure_performance(false), -+ iterations(20), -+ save_workspace(false), -+ alpha(-1), -+ beta(-1), -+ benchmark(false) { } -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. -+ // -+ int const kAlignment = 8; -+ -+ if ((problem_size.m() % kAlignment) || -+ (problem_size.n() % kAlignment) || -+ (problem_size.k() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Updates input and filter sizes -+ void update( -+ cutlass::gemm::GemmCoord problem_size, -+ int split_k_slices, -+ bool parallel_split_k) { -+ -+ this->problem_size = problem_size; -+ this->split_k_slices = split_k_slices; -+ this->parallel_split_k = parallel_split_k; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("parallel-split-k")) { -+ parallel_split_k = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("ref-check")) { -+ reference_check = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("perf-check")) { -+ measure_performance = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("save-workspace")) { -+ save_workspace = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("benchmark")) { -+ benchmark = true; -+ } -+ -+ cmd.get_cmd_line_argument("m", problem_size.m()); -+ cmd.get_cmd_line_argument("n", problem_size.n()); -+ cmd.get_cmd_line_argument("k", problem_size.k()); -+ cmd.get_cmd_line_argument("split-k-slices", split_k_slices); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ } -+ -+ /// Prints the usage statement. 
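Editorial note: this example supports serial and parallel split-K. The K dimension is partitioned into split_k_slices ranges, each slice produces a partial result, and the partials are then summed (by the ReduceSplitK kernels declared above in the parallel case). A minimal host-side sketch of the idea (plain C++, illustrative names):

#include <vector>

// Split-K sketch: each slice accumulates over its own K range; summing the
// partial dot products recovers the full K reduction.
float splitk_dot(const std::vector<float>& a, const std::vector<float>& b, int slices) {
  const int K = int(a.size());
  std::vector<float> partial(slices, 0.0f);
  for (int s = 0; s < slices; ++s) {
    const int k_begin = s * K / slices;
    const int k_end = (s + 1) * K / slices;
    for (int k = k_begin; k < k_end; ++k) {
      partial[s] += a[k] * b[k];        // partial product for this slice
    }
  }
  float sum = 0.0f;
  for (float p : partial) sum += p;     // the final reduction pass
  return sum;
}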
-+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "28_ampere_gemm_bias_fusion example\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --m= GEMM M\n" -+ << " --n= GEMM N\n" -+ << " --k= GEMM K\n" -+ << " --split-k-slices= Split K Slices\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --parallel-split-k If set (true), use parallel split K\n" -+ << " --ref-check If set (true), reference check on the host is computed\n" -+ << " --perf-check If set (true), performance is measured.\n" -+ << " --benchmark If set (true), performance benchmarking on several problem sizes.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --save-workspace If set, workspace is written to a text file.\n" -+ << " --tag= String to replicate across the first column in the results table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/23_ampere_gemm_bias_fusion_example/ampere_gemm_bias_fusion --m=1024 --n=1024 --k=1024 \n\n"; -+ -+ return out; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ cutlass::Status status; -+ cutlass::Status reference_check; -+ cudaError_t error; -+ -+ Result(): -+ runtime_ms(0), -+ status(cutlass::Status::kSuccess), -+ reference_check(cutlass::Status::kInvalid), -+ error(cudaSuccess) { } -+ -+ static std::ostream & print_header(std::ostream &out, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "ID,M,N,K,SplitK-Slices,Parallel-SplitK,Runtime"; -+ -+ return out; -+ } -+ -+ std::ostream & print(std::ostream &out, int idx, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ out -+ << "gemm_" << idx << "," -+ << options.problem_size.m() << "," -+ << options.problem_size.n() << "," -+ << options.problem_size.k() << "," -+ << options.split_k_slices << "," -+ << options.parallel_split_k << "," -+ << runtime_ms ; -+ -+ return out; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one benchmark -+Result profile(Options const &options) { -+ -+ Result result; -+ -+ // Initialize tensors using CUTLASS helper functions -+ cutlass::HostTensor tensor_a(options.problem_size.mk()); -+ cutlass::HostTensor tensor_b(options.problem_size.kn()); -+ -+ -+ // Create tensor C with dimensions 1x1x1xk which is the bias vector -+ cutlass::HostTensor tensor_c(options.problem_size.mn()); -+ -+ // Create tensor D used to store output from CUTLASS kernel -+ cutlass::HostTensor tensor_d(options.problem_size.mn()); -+ // Create matrix D with dimensions M x N used to store output from reference -+ // kernel -+ cutlass::HostTensor tensor_ref_d(options.problem_size.mn()); -+ -+ int reduce_vector_length = ReduceKForA ? 
options.problem_size.m() : options.problem_size.n(); -+ -+ cutlass::HostTensor tensor_reduction({reduce_vector_length, 1}); -+ cutlass::HostTensor tensor_ref_reduction({reduce_vector_length, 1}); -+ -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1997, -+ ElementInputA(2), -+ ElementInputA(-2), -+ 0); // <- Fill tensor A on host with uniform-distribution random data -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 2003, -+ ElementInputB(2), -+ ElementInputB(-2), -+ 0); // <- Fill tensor B on host with uniform-distribution random data -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 2017, -+ ElementOutput(2), -+ ElementOutput(-2), -+ 0); // <- Fill matrix C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); // <- fill matrix D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros -+ -+ cutlass::reference::host::TensorFill( -+ tensor_reduction.host_view()); // <- fill matrix D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_reduction.host_view()); // <- fill matrix D for reference on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ tensor_reduction.sync_device(); -+ -+ // Initialize alpha for dot product computation -+ ElementComputeEpilogue alpha = options.parallel_split_k ? ElementComputeEpilogue(1) -+ : ElementComputeEpilogue(options.alpha); -+ ElementComputeEpilogue beta = options.parallel_split_k ? ElementComputeEpilogue(0) -+ : ElementComputeEpilogue(options.beta); -+ -+ cutlass::gemm::GemmUniversalMode mode = options.parallel_split_k ? -+ cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel : -+ cutlass::gemm::GemmUniversalMode::kGemm; -+ -+ int batch_count = options.split_k_slices; -+ -+ // Create a tuple of gemm kernel arguments. 
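Editorial note: besides D, the fused kernel emits the K reduction of one operand. With ReduceKForA the output vector holds the row sums of A over K (length M); otherwise it holds the column sums of B over K (length N). This is exactly what the host reference recomputes further below; a short sketch of that host-side reduction (plain C++, row-major A, illustrative names):

#include <vector>

// Reduce operand A along K: out[m] = sum_k A[m][k]; the vector has length M.
// (For ReduceKForA == false, the analogous loop sums B over K per column.)
std::vector<float> reduce_a_over_k(const std::vector<float>& A, int M, int K) {
  std::vector<float> out(M, 0.0f);
  for (int m = 0; m < M; ++m) {
    for (int k = 0; k < K; ++k) {
      out[m] += A[m * K + k];   // row-major indexing
    }
  }
  return out;
}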
This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm::Arguments arguments( -+ mode, -+ options.problem_size, -+ batch_count, -+ {alpha, beta}, -+ tensor_a.device_ref().data(), // <- reference to tensor A on device -+ tensor_b.device_ref().data(), // <- reference to tensor B on device -+ tensor_c.device_ref().data(), // <- reference to matrix C on device -+ tensor_d.device_ref().data(), // <- reference to matrix D on device -+ tensor_reduction.device_ref().data(), // <- reference to reduction tensor on device -+ options.problem_size.m() * options.problem_size.k(), -+ options.problem_size.n() * options.problem_size.k(), -+ options.problem_size.m() * options.problem_size.n(), -+ options.problem_size.m() * options.problem_size.n(), -+ reduce_vector_length, -+ tensor_a.layout().stride(0), -+ tensor_b.layout().stride(0), -+ tensor_c.layout().stride(0), -+ tensor_d.layout().stride(0), -+ tensor_reduction.layout().stride(0)); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm gemm_op; -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Check the problem size is supported or not -+ result.status = gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(result.status); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ result.status = gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // Launch initialized CUTLASS kernel -+ result.status = gemm_op(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ if (options.parallel_split_k && batch_count > 1) { -+ // reduce gemm -+ -+ ElementComputeEpilogue alpha = ElementComputeEpilogue(options.alpha); -+ ElementComputeEpilogue beta = ElementComputeEpilogue(options.beta); -+ -+ int splitk_gemm_stride = options.problem_size.m(); -+ -+ cutlass::layout::RowMajor splitk_gemm_layout(splitk_gemm_stride); -+ -+ void * workspace_gemm_ptr = workspace.get(); -+ cutlass::TensorRef workspace_gemm_tensorref(static_cast(workspace_gemm_ptr), splitk_gemm_layout); -+ -+ cutlass::TensorRef tensor_d_tensorref(tensor_d.device_ref().data(), splitk_gemm_layout); -+ -+ cutlass::TensorRef tensor_c_tensorref(tensor_c.device_ref().data(), splitk_gemm_layout); -+ -+ typename ReduceGemmSplitK::Arguments reduce_gemm_splitk_arguments{ -+ cutlass::MatrixCoord(options.problem_size.n(), options.problem_size.m()), -+ batch_count, -+ size_t(options.problem_size.m() * options.problem_size.n()), -+ workspace_gemm_tensorref, -+ tensor_d_tensorref, -+ tensor_c_tensorref, -+ {alpha, beta} -+ }; -+ -+ ReduceGemmSplitK reduce_gemm_splitk_op; -+ -+ result.status = reduce_gemm_splitk_op.initialize(reduce_gemm_splitk_arguments); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = reduce_gemm_splitk_op(); -+ CUTLASS_CHECK(result.status); -+ -+ // reduce k vector -+ cutlass::layout::RowMajor splitk_vector_layout(reduce_vector_length); -+ -+ ElementOutput *workspace_vector_ptr = static_cast(workspace_gemm_ptr) + batch_count * options.problem_size.m() * options.problem_size.n(); -+ cutlass::TensorRef workspace_vector_tensorref(workspace_vector_ptr, splitk_vector_layout); -+ -+ cutlass::TensorRef tensor_reduction_tensorref(tensor_reduction.device_ref().data(), splitk_vector_layout); -+ -+ cutlass::TensorRef tensor_nullptr_tensorref(nullptr, splitk_vector_layout); -+ -+ typename 
ReduceVectorSplitK::Arguments reduce_vector_splitk_arguments( -+ cutlass::MatrixCoord(1, reduce_vector_length), -+ batch_count, -+ size_t(reduce_vector_length), -+ workspace_vector_tensorref, -+ tensor_reduction_tensorref, -+ tensor_nullptr_tensorref, -+ {1.0f, 0.0f}); -+ -+ ReduceVectorSplitK reduce_vector_splitk_op; -+ -+ result.status = reduce_vector_splitk_op.initialize(reduce_vector_splitk_arguments); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = reduce_vector_splitk_op(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // -+ // Create instantiation for device reference conv kernel -+ // -+ if (options.reference_check) { -+ // Launch device reference to compute strictly the product A * B -+ cutlass::reference::device::Gemm< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator> gemm_device; -+ -+ gemm_device -+ ( -+ options.problem_size, -+ ElementComputeEpilogue(options.alpha), -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ ElementComputeEpilogue(options.beta), -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref() -+ ); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ tensor_reduction.sync_host(); -+ -+ // Reduce K in host code -+ if (ReduceKForA) { -+ for (int m = 0; m < options.problem_size.m(); ++m) { -+ for (int k = 0; k < options.problem_size.k(); ++k) { -+ tensor_ref_reduction.at({m, 0}) += -+ tensor_a.at(cutlass::MatrixCoord(m, k)); -+ } -+ } -+ } else { -+ for (int k = 0; k < options.problem_size.k(); ++k) { -+ for (int n = 0; n < options.problem_size.n(); ++n) { -+ tensor_ref_reduction.at({n, 0}) += -+ tensor_b.at(cutlass::MatrixCoord(k, n)); -+ } -+ } -+ } -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ bool pass = cutlass::reference::host::TensorEquals(tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ pass &= cutlass::reference::host::TensorEquals(tensor_ref_reduction.host_view(), -+ tensor_reduction.host_view()); -+ -+ if (!pass) { -+ result.reference_check = cutlass::Status::kErrorInternal; -+ std::cout << "ERROR - results miscompared.\n"; -+ } else { -+ result.reference_check = cutlass::Status::kSuccess; -+ std::cout << "Passed.\n"; -+ } -+ } else { -+ result.reference_check = cutlass::Status::kInvalid; -+ } -+ -+ if (options.save_workspace) { -+ -+ std::stringstream ss; -+ -+ ss << "23_ampere_gemm_operand_reduction_fusion" -+ << options.problem_size.m() << "x" << options.problem_size.n() << "x" << options.problem_size.k() -+ << ".dat"; -+ -+ std::ofstream output_workspace(ss.str()); -+ -+ output_workspace -+ << "A = \n" << tensor_a.host_view() << "\n\n" -+ << "B = \n" << tensor_b.host_view() << "\n\n"; -+ -+ if (options.reference_check) { -+ output_workspace << "Reference D = \n" << tensor_ref_d.host_view() << "\n\n"; -+ output_workspace << "Reference reduction vector = \n" << tensor_ref_reduction.host_view() << "\n\n"; -+ } -+ -+ output_workspace << "Computed D = \n" << tensor_d.host_view() << std::endl; -+ output_workspace << "Computed reduction vector = \n" << tensor_reduction.host_view() << std::endl; -+ -+ std::cout << "Results written to '" << ss.str() << "'." 
<< std::endl; -+ } -+ -+ // -+ // Performance measurement -+ // -+ -+ if (options.measure_performance) { -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = gemm_op(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ } -+ -+ return result; -+} -+ -+int main(int argc, char const **args) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major >= 8)) { -+ std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.benchmark) { -+ // Benchmark several layers -+ -+ struct Benchmark { -+ int m, n, k, split_k_slices, parallel_split_k; -+ } problem_sizes[] = { -+ {4096, 6144, 4096, 1, false}, -+ }; -+ -+ Result::print_header(std::cout, options) << "\n"; -+ -+ int idx = 1; -+ -+ for (auto const &problem_size : problem_sizes) { -+ options.update({problem_size.m, problem_size.n, problem_size.k}, -+ problem_size.split_k_slices, problem_size.parallel_split_k); -+ -+ Result result = profile(options); -+ result.print(std::cout, idx, options) << "\n"; -+ -+ ++idx; -+ } -+ } else { -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." 
<< "\n"; -+ return -1; -+ } -+ -+ Result result = profile(options); -+ -+ Result::print_header(std::cout, options) << "\n"; -+ result.print(std::cout, 1, options) << "\n"; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/24_gemm_grouped/gemm_grouped.cu b/3rdparty/cutlass/examples/24_gemm_grouped/gemm_grouped.cu -new file mode 100644 -index 0000000..4b080fc ---- /dev/null -+++ b/3rdparty/cutlass/examples/24_gemm_grouped/gemm_grouped.cu -@@ -0,0 +1,1578 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief GEMM Grouped Example. -+ -+ This workload computes a batch of GEMM operations with distinct problem sizes. Pointers to matrices -+ in Global Memory are passed to the kernel in array (also held in Global Memory). Similarly, -+ leading dimensions and problem sizes are stored in arrays in GMEM. -+ -+ This differs from "Batched Array" GEMM because the size of each GEMM problem in the Grouped GEMM -+ concept may be distinct. -+ -+ This benchmark program initializes a workspace with random problem sizes for a given number of -+ groups. Command line options enable overriding M, N, and/or K dimensions with uniform values to -+ model problems more similar to the traditional batched GEMM. -+ -+ Additionally, problem sizes are collected and binned to compute the same problem as a series of -+ conventional batched GEMMs (setup for this problem is not timed). This demonstrates the performance -+ enhancement achieved by implementing a specialized grouped GEMM kernel. 
-+ -+ Examples: -+ -+ # Runs a grouped GEMM with 100 random problem sizes -+ $ ./examples/24_gemm_grouped/24_gemm_grouped --groups=100 -+ -+ # Runs a grouped GEMM with 100 random problem sizes (with GEMM-K dimension equal to 1024) -+ $ ./examples/24_gemm_grouped/24_gemm_grouped --groups=100 --k=1024 --verbose=true -+ -+ # Runs a grouped GEMM that is equivalent to a batched GEMM -+ $ ./examples/24_gemm_grouped/24_gemm_grouped --groups=100 --m=2048 --n=1024 --k=1024 --verbose=true -+ -+ # Execute Grouped GEMM and profile with NSight -+ $ nv-nsight-cu-cli ./examples/24_gemm_grouped/24_gemm_grouped --m=256 --n=256 --k=256 --verbose=true \ -+ --iterations=1 --reference-check=false -+ -+*/ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_grouped.h" -+#include "cutlass/gemm/kernel/default_gemm_grouped.h" -+#include "cutlass/gemm/device/gemm_grouped.h" -+#include "cutlass/gemm/device/gemm_universal.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/device_memory.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+#include "cutlass/util/reference/device/gemm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/device/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double initialization_time_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ // -+ // Methods -+ // -+ -+ Result( -+ double runtime_ms = 0, -+ double initialization_time_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess -+ ): -+ runtime_ms(runtime_ms), initialization_time_ms(initialization_time_ms), gflops(gflops), -+ status(status), error(error), passed(true) { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Hash function for cutlass::gemm::GemmCoord -+struct HashGemmCoord { -+ size_t operator()(cutlass::gemm::GemmCoord const &problem) const { -+ std::hash hasher; -+ return (hasher(problem.m() * 3)) ^ (hasher(1 + problem.n() * 5)) ^ (hasher(2 + problem.k() * 7)); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ bool error; -+ bool reference_check; -+ bool profile_initialization; -+ bool sort_problems; -+ -+ std::vector problem_sizes; -+ -+ // problem size bins -+ std::unordered_map< -+ cutlass::gemm::GemmCoord, -+ std::vector, -+ HashGemmCoord> problem_bins; -+ -+ int alignment; -+ int problem_count; -+ int iterations; -+ int cuda_streams; -+ bool verbose; -+ float alpha; -+ float beta; -+ std::string benchmark_path; -+ -+ std::string output_tag; -+ std::ofstream output_file; -+ -+ using GroupScheduleMode = cutlass::gemm::kernel::GroupScheduleMode; -+ std::vector scheduler_modes; -+ -+ std::unordered_map -+ str_to_scheduler_mode = { -+ 
{"kDeviceOnly", GroupScheduleMode::kDeviceOnly}, -+ {"kHostPrecompute", GroupScheduleMode::kHostPrecompute} -+ }; -+ -+ struct GroupScheduleModeHash { -+ size_t operator()(GroupScheduleMode m) const { -+ return static_cast(m); -+ } -+ }; -+ -+ std::unordered_map -+ scheduler_mode_to_str = { -+ {GroupScheduleMode::kDeviceOnly, "kDeviceOnly"}, -+ {GroupScheduleMode::kHostPrecompute, "kHostPrecompute"} -+ }; -+ -+ std::vector all_scheduler_modes = {GroupScheduleMode::kDeviceOnly, GroupScheduleMode::kHostPrecompute}; -+ -+ // -+ // Methods -+ // -+ -+ Options(): -+ help(false), -+ error(false), -+ alignment(8), -+ reference_check(true), -+ profile_initialization(false), -+ sort_problems(false), -+ problem_count(15), -+ iterations(20), -+ cuda_streams(0), -+ verbose(false), -+ alpha(1), -+ beta(), -+ scheduler_modes({GroupScheduleMode::kDeviceOnly}) -+ { } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ return; -+ } -+ -+ cmd.get_cmd_line_argument("alignment", alignment, 16); -+ cmd.get_cmd_line_argument("groups", problem_count, 8); -+ cmd.get_cmd_line_argument("alpha", alpha, 1.0f); -+ cmd.get_cmd_line_argument("beta", beta, 0.0f); -+ cmd.get_cmd_line_argument("iterations", iterations, 20); -+ cmd.get_cmd_line_argument("streams", cuda_streams, 0); -+ cmd.get_cmd_line_argument("verbose", verbose, true); -+ cmd.get_cmd_line_argument("reference-check", reference_check, true); -+ cmd.get_cmd_line_argument("profile-initialization", profile_initialization, false); -+ cmd.get_cmd_line_argument("sort-problems", sort_problems, false); -+ cmd.get_cmd_line_argument("benchmark", benchmark_path); -+ -+ std::vector scheduler_mode_strs; -+ cmd.get_cmd_line_arguments("scheduler-modes", scheduler_mode_strs); -+ -+ if (!scheduler_mode_strs.empty()) { -+ scheduler_modes.clear(); -+ if (scheduler_mode_strs.size() == 1 && scheduler_mode_strs[0] == "all") { -+ scheduler_modes = all_scheduler_modes; -+ } else { -+ for (std::string precomp_str : scheduler_mode_strs) { -+ auto it = str_to_scheduler_mode.find(precomp_str); -+ if (it != str_to_scheduler_mode.end()) { -+ scheduler_modes.push_back(it->second); -+ } else if (precomp_str == "all") { -+ std::cerr << "Flag --scheduler-modes=all must not contain other scheduler modes in list." << std::endl; -+ error = true; -+ return; -+ } else { -+ std::cerr << "Unrecognized scheduler mode '" << precomp_str << "'" << std::endl; -+ error = true; -+ return; -+ } -+ } -+ } -+ } -+ -+ std::string output_path; -+ cmd.get_cmd_line_argument("tag", output_tag); -+ cmd.get_cmd_line_argument("output_file", output_path); -+ -+ if (!output_path.empty()) { -+ -+ std::ios_base::openmode open_mode = std::ios_base::out; -+ -+ std::ifstream input_file(output_path.c_str()); -+ -+ if (input_file.good()) { -+ open_mode = std::ios_base::app; -+ input_file.close(); -+ } -+ -+ output_file.open(output_path.c_str(), open_mode); -+ -+ if (output_file.good() && open_mode != std::ios_base::app) { -+ output_file << "Tag,Provider,Kind,Groups,Runtime,GFLOPs\n"; -+ } -+ } -+ -+ // Decide how to initialize the problems -+ if (!benchmark_path.empty()) { -+ if (!benchmark_problems()) { -+ error = true; -+ problem_sizes.clear(); -+ return; -+ } -+ } -+ else { -+ randomize_problems(cmd); -+ } -+ -+ // Post-process the problem sizes -+ bin_problems(); -+ } -+ -+ void randomize_problems(cutlass::CommandLine &cmd) { -+ -+ // -+ // For now, randomly choose the problem sizes. 
-+ // -+ -+ int cmd_line_m = -1; -+ int cmd_line_n = -1; -+ int cmd_line_k = -1; -+ -+ cmd.get_cmd_line_argument("m", cmd_line_m,128); -+ cmd.get_cmd_line_argument("n", cmd_line_n,128); -+ cmd.get_cmd_line_argument("k", cmd_line_k,64); -+ -+ problem_sizes.reserve(problem_count); -+ -+ for (int i = 0; i < problem_count; ++i) { -+ -+ int m = cmd_line_m; -+ int n = cmd_line_n; -+ int k = cmd_line_k; -+ -+ if (m < 1) { -+ m = alignment * ((rand() % 256) + 1); -+ } -+ -+ if (n < 1) { -+ n = alignment * ((rand() % 256) + 1); -+ } -+ -+ if (k < 1) { -+ k = alignment * ((rand() % 256) + 1); -+ } -+ -+ cutlass::gemm::GemmCoord problem(m, n, k); -+ -+ problem_sizes.push_back(problem); -+ } -+ } -+ -+ /// Load a benchmark -+ bool benchmark_problems() { -+ std::ifstream file(benchmark_path); -+ if (!file.good()) { -+ return false; -+ } -+ -+ while (file.good()) { -+ -+ int idx = -1; -+ std::string extent_str; -+ -+ file >> idx >> extent_str; -+ -+ if (idx < 0 || extent_str.empty()) { -+ break; -+ } -+ -+ cutlass::gemm::GemmCoord extent; -+ std::vector tokens; -+ -+ cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); -+ -+ for (int i = 0; i < int(tokens.size()); ++i) { -+ int x = std::atoi(tokens.at(i).c_str()); -+ -+ // round up -+ if (x % alignment) { -+ x += (alignment - (x % alignment)); -+ } -+ -+ extent.at(i) = x; -+ } -+ -+ if (extent.product()) { -+ problem_sizes.push_back(extent); -+ } -+ } -+ -+ return true; -+ } -+ -+ /// Post processes the problems -+ void bin_problems() { -+ -+ problem_bins.clear(); -+ -+ problem_count = int(problem_sizes.size()); -+ -+ // -+ // Insert the problem sizes into a sorted container class. This is *NOT* necessary -+ // to run the CUTLASS kernel, but it enables the execution of cublas's batched GEMM. -+ // -+ for (int i = 0; i < int(problem_sizes.size()); ++i) { -+ auto it = problem_bins.find(problem_sizes.at(i)); -+ if (it == problem_bins.end()) { -+ problem_bins.insert({problem_sizes.at(i), std::vector({i}) }); -+ } -+ else { -+ it->second.push_back(i); -+ } -+ } -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "24_gemm_grouped\n\n" -+ << " This example profiles the performance of a 'grouped' GEMM kernel. This is similar to batched GEMM\n" -+ << " in that multiple, independent GEMMs are computed by one grid launch. It differs in that each\n" -+ << " 'group' may compute a unique problem size. Problem sizes and pointers to matrices are both stored\n" -+ << " in device Global Memory and loaded by the kernel.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --benchmark= Executes a benchmark problem size.\n" -+ << " --output_file= Path to a CSV file to output results. If it exists already, results are appended.\n" -+ << " --tag= String tag to prepend to the CSV file.\n" -+ << " --groups= Number of individual GEMM problems (default: --groups=15)\n" -+ << " --m= Sets the M dimension for all groups. Otherwise, it is selected randomly\n" -+ << " --n= Sets the N dimension for all groups. Otherwise, it is selected randomly\n" -+ << " --k= Sets the K dimension for all groups. 
Otherwise, it is selected randomly\n" -+ << " --alpha= Epilogue scalar alpha (real part)\n" -+ << " --beta= Epilogue scalar beta (real part)\n" -+ << " --scheduler-modes= List of scheduler modes to be profile for grouped GEMM scheduler (default: --scheduler_modes=kDeviceOnly)\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --reference-check= If true, performs reference check.\n" -+ << " --verbose= If true, prints problem sizes and batching structure.\n" -+ << " --profile-initialization= If true, profiles the device-level kernel's initialization.\n" -+ << " --sort-problems= If true, sorts problem sizes in descending order of GEMM-K dimension.\n"; -+ -+ out << "\n\nExamples:\n\n" -+ -+ << "# Runs a grouped GEMM with 100 random problem sizes\n" -+ << "$ ./examples/24_gemm_grouped/24_gemm_grouped --groups=100\n\n" -+ -+ << "# Runs a grouped GEMM with 100 random problem sizes (with GEMM-K dimension equal to 1024)\n" -+ << "$ ./examples/24_gemm_grouped/24_gemm_grouped --groups=100 --k=1024 --verbose=true\n\n" -+ -+ << "# Runs a grouped GEMM that is equivalent to a batched GEMM\n" -+ << "$ ./examples/24_gemm_grouped/24_gemm_grouped --groups=100 --m=2048 --n=1024 --k=1024 --verbose=true\n\n" -+ -+ << "# Runs a grouped GEMM with each different scheduler mode\n" -+ << "$ ./examples/24_gemm_grouped/24_gemm_grouped --scheduler-modes=all\n\n" -+ -+ << "# Runs a grouped GEMM with each different scheduler mode and profiles host-side initialization time\n" -+ << "$ ./examples/24_gemm_grouped/24_gemm_grouped --scheduler-modes=all --profile-initialization=true\n\n" -+ -+ << "# Runs a grouped GEMM problem given an externally supplied benchmark file. This is a text file in which\n" -+ << "# Each line contains a unique group index and an MxNxK triple indicating problemsize.\n" -+ << "#\n" -+ << "# For example, assume the following are the contents of 'problems.txt'\n" -+ << "#\n" -+ << "# 0 1024x256x520\n" -+ << "# 1 520x264x1024\n" -+ << "# 2 96x48x1024\n" -+ << "#\n" -+ << "$ ./examples/24_gemm_grouped/24_gemm_grouped --benchmark=problems.txt\n\n" -+ -+ << "# Execute Grouped GEMM and profile with NSight\n" -+ << "$ nv-nsight-cu-cli ./examples/24_gemm_grouped/24_gemm_grouped --m=256 --n=256 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fmas = int64_t(); -+ -+ for (auto const & problem : problem_sizes) { -+ fmas += problem.product(); -+ } -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class BaseTestbed { -+public: -+ // -+ // Type definitions -+ // -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementC = typename Gemm::ElementC; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ -+ using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; -+ using ElementCompute = typename EpilogueOutputOp::ElementCompute; -+ -+ using LayoutA = typename Gemm::LayoutA; -+ using LayoutB = typename Gemm::LayoutB; -+ using LayoutC = typename Gemm::LayoutC; -+ -+ using MatrixCoord = typename LayoutC::TensorCoord; -+ -+ // -+ // Data members -+ // -+ -+ Options & options; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ 
cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint32_t seed; -+ -+ cutlass::DeviceAllocation problem_sizes_device; -+ -+ std::vector offset_A; -+ std::vector offset_B; -+ std::vector offset_C; -+ std::vector offset_D; -+ -+ std::vector lda_host; -+ std::vector ldb_host; -+ std::vector ldc_host; -+ std::vector ldd_host; -+ -+ cutlass::DeviceAllocation lda; -+ cutlass::DeviceAllocation ldb; -+ cutlass::DeviceAllocation ldc; -+ cutlass::DeviceAllocation ldd; -+ -+ cutlass::DeviceAllocation block_A; -+ cutlass::DeviceAllocation block_B; -+ cutlass::DeviceAllocation block_C; -+ cutlass::DeviceAllocation block_D; -+ -+ cutlass::DeviceAllocation ptr_A; -+ cutlass::DeviceAllocation ptr_B; -+ cutlass::DeviceAllocation ptr_C; -+ cutlass::DeviceAllocation ptr_D; -+ -+ BaseTestbed( -+ Options &options_, -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3080 -+ ): -+ options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ int problem_count() const { -+ return options.problem_count; -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor( -+ Element *ptr, -+ size_t capacity, -+ cutlass::Distribution::Kind dist_kind, -+ uint32_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ Element scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ if (cutlass::sizeof_bits::value <= 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } -+ else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ ptr, capacity, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::device::BlockFillRandomGaussian( -+ ptr, capacity, seed, Element(), Element(0.5f)); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ // Fill with increasing elements -+ cutlass::reference::device::BlockFillSequential( -+ ptr, capacity, Element(1), Element()); -+ } -+ else { -+ -+ // Fill with all 1s -+ cutlass::reference::device::BlockFillSequential( -+ ptr, capacity, Element(), Element(1)); -+ } -+ } -+ -+ /// Allocates device-side data -+ void allocate() { -+ int64_t total_elements_A = 0; -+ int64_t total_elements_B = 0; -+ int64_t total_elements_C = 0; -+ int64_t total_elements_D = 0; -+ -+ lda_host.resize(problem_count()); -+ ldb_host.resize(problem_count()); -+ ldc_host.resize(problem_count()); -+ ldd_host.resize(problem_count()); -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ -+ auto problem = options.problem_sizes.at(i); -+ -+ lda_host.at(i) = LayoutA::packed({problem.m(), problem.k()}).stride(0); -+ ldb_host.at(i) = LayoutB::packed({problem.k(), problem.n()}).stride(0); -+ ldc_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); -+ ldd_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); -+ -+ offset_A.push_back(total_elements_A); -+ offset_B.push_back(total_elements_B); -+ offset_C.push_back(total_elements_C); -+ offset_D.push_back(total_elements_D); -+ -+ int64_t elements_A = 
problem.m() * problem.k(); -+ int64_t elements_B = problem.k() * problem.n(); -+ int64_t elements_C = problem.m() * problem.n(); -+ int64_t elements_D = problem.m() * problem.n(); -+ -+ total_elements_A += elements_A; -+ total_elements_B += elements_B; -+ total_elements_C += elements_C; -+ total_elements_D += elements_D; -+ } -+ -+ lda.reset(problem_count()); -+ ldb.reset(problem_count()); -+ ldc.reset(problem_count()); -+ ldd.reset(problem_count()); -+ -+ block_A.reset(total_elements_A); -+ block_B.reset(total_elements_B); -+ block_C.reset(total_elements_C); -+ block_D.reset(total_elements_D); -+ } -+ -+ /// Initializes device-side data -+ void initialize() { -+ problem_sizes_device.reset(problem_count()); -+ problem_sizes_device.copy_from_host(options.problem_sizes.data()); -+ -+ lda.copy_from_host(lda_host.data()); -+ ldb.copy_from_host(ldb_host.data()); -+ ldc.copy_from_host(ldc_host.data()); -+ ldd.copy_from_host(ldd_host.data()); -+ -+ // -+ // Assign pointers -+ // -+ -+ std::vector ptr_A_host(problem_count()); -+ std::vector ptr_B_host(problem_count()); -+ std::vector ptr_C_host(problem_count()); -+ std::vector ptr_D_host(problem_count()); -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ ptr_A_host.at(i) = block_A.get() + offset_A.at(i); -+ ptr_B_host.at(i) = block_B.get() + offset_B.at(i); -+ ptr_C_host.at(i) = block_C.get() + offset_C.at(i); -+ ptr_D_host.at(i) = block_D.get() + offset_D.at(i); -+ } -+ -+ ptr_A.reset(problem_count()); -+ ptr_A.copy_from_host(ptr_A_host.data()); -+ -+ ptr_B.reset(problem_count()); -+ ptr_B.copy_from_host(ptr_B_host.data()); -+ -+ ptr_C.reset(problem_count()); -+ ptr_C.copy_from_host(ptr_C_host.data()); -+ -+ ptr_D.reset(problem_count()); -+ ptr_D.copy_from_host(ptr_D_host.data()); -+ -+ // -+ // Initialize the problems of the workspace -+ // -+ -+ initialize_tensor(block_A.get(), block_A.size(), init_A, seed * 2021); -+ initialize_tensor(block_B.get(), block_B.size(), init_B, seed * 2022); -+ initialize_tensor(block_C.get(), block_C.size(), init_C, seed * 2023); -+ -+ cutlass::reference::device::BlockFillSequential( -+ block_D.get(), block_D.size(), ElementC(), ElementC()); -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify() { -+ -+ bool passed = true; -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ cutlass::gemm::GemmCoord problem = options.problem_sizes.at(i); -+ -+ LayoutA layout_A(lda_host.at(i)); -+ LayoutB layout_B(ldb_host.at(i)); -+ LayoutC layout_C(ldc_host.at(i)); -+ LayoutC layout_D(ldd_host.at(i)); -+ -+ MatrixCoord extent_A{problem.m(), problem.k()}; -+ MatrixCoord extent_B{problem.k(), problem.n()}; -+ MatrixCoord extent_C{problem.m(), problem.n()}; -+ -+ cutlass::TensorView view_A(block_A.get() + offset_A.at(i), layout_A, extent_A); -+ cutlass::TensorView view_B(block_B.get() + offset_B.at(i), layout_B, extent_B); -+ cutlass::TensorView view_C(block_C.get() + offset_C.at(i), layout_C, extent_C); -+ -+ cutlass::DeviceAllocation block_Ref(layout_D.capacity(extent_C)); -+ cutlass::TensorView view_Ref_device(block_Ref.get(), layout_D, extent_C); -+ -+ // Reference GEMM -+ cutlass::reference::device::GemmComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, ElementAccumulator -+ >( -+ problem, -+ options.alpha, -+ view_A, -+ Gemm::kTransformA, -+ view_B, -+ Gemm::kTransformB, -+ options.beta, -+ view_C, -+ view_Ref_device, -+ ElementAccumulator(0) -+ ); -+ -+ // Copy to host memory -+ std::vector matrix_D(layout_D.capacity(extent_C)); -+ std::vector 
matrix_Ref(layout_D.capacity(extent_C)); -+ -+ cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); -+ cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_D.size()); -+ -+ cutlass::TensorView view_D( matrix_D.data(), layout_D, extent_C); -+ cutlass::TensorView view_Ref(matrix_Ref.data(), layout_D, extent_C); -+ -+ // Reference check -+ passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); -+ -+ if (!passed) { -+ std::cerr << "\n***\nError - problem " << i << " failed the QA check\n***\n" << std::endl; -+ return passed; -+ } -+ } -+ -+ return passed; -+ } -+ -+}; -+ -+template -+class TestbedBatched : BaseTestbed { -+public: -+ TestbedBatched( -+ Options &options_, -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3080 -+ ): BaseTestbed(options_, init_A_, init_B_, init_C_, seed_) {} -+ -+ void print_problem_sizes() { -+ std::cout << std::endl; -+ size_t bin_idx = 0; -+ size_t problem_count_check = 0; -+ std::cout << "Conventionally executed as " << this->options.problem_bins.size() << " batched GEMMs:\n"; -+ for (auto const & bin : this->options.problem_bins) { -+ -+ std::cout << " [" << bin_idx << "]: " -+ << bin.first.m() << "-by-" << bin.first.n() << "-by-" << bin.first.k() -+ << ", batch count: " << bin.second.size() << "\n"; -+ -+ ++bin_idx; -+ problem_count_check += bin.second.size(); -+ } -+ -+ if (problem_count_check != this->problem_count()) { -+ std::cout << "\n***\nERROR in BINNING LOGIC!\n***\n" << std::endl; -+ } -+ -+ std::cout << std::endl; -+ } -+ -+ /// Executes a batched kernel and measures runtime -+ Result profile() { -+ std::cout << "Batched GEMM:\n" -+ << "====================================================" << std::endl; -+ -+ Result result; -+ result.passed = false; -+ -+ // Initialize the problem -+ this->allocate(); -+ this->initialize(); -+ -+ if (this->options.verbose) { -+ print_problem_sizes(); -+ } -+ -+ // -+ // Prepare batched GEMM environment -+ // -+ -+ int32_t effective_streams = (this->options.cuda_streams ? 
this->options.cuda_streams : 1); -+ -+ // Array of leading dimensions used by batched GEMM calls -+ std::vector bin_problem_sizes; -+ std::vector bin_count; -+ std::vector bin_ldm_A; -+ std::vector bin_ldm_B; -+ std::vector bin_ldm_C; -+ std::vector bin_start; -+ -+ std::vector ptr_A_batched_host; -+ std::vector ptr_B_batched_host; -+ std::vector ptr_C_batched_host; -+ -+ for (auto const & bin : this->options.problem_bins) { -+ int first_idx = bin.second.front(); -+ -+ bin_problem_sizes.push_back(this->options.problem_sizes.at(first_idx)); -+ bin_count.push_back(int32_t(bin.second.size())); -+ -+ bin_ldm_A.push_back(static_cast(this->lda_host.at(first_idx))); -+ bin_ldm_B.push_back(static_cast(this->ldb_host.at(first_idx))); -+ bin_ldm_C.push_back(static_cast(this->ldc_host.at(first_idx))); -+ -+ if (ptr_A_batched_host.size() % 2) { -+ ptr_A_batched_host.push_back(nullptr); -+ ptr_B_batched_host.push_back(nullptr); -+ ptr_C_batched_host.push_back(nullptr); -+ } -+ -+ bin_start.push_back(int32_t(ptr_A_batched_host.size())); -+ -+ for (int idx : bin.second) { -+ -+ if (bin_problem_sizes.back() != this->options.problem_sizes.at(idx)) { -+ std::cerr << "Error - failed to group problems.\n"; -+ return result; -+ } -+ -+ if (bin_ldm_A.back() != this->lda_host.at(idx)) { -+ std::cerr << "Error - failed to group problems.\n"; -+ return result; -+ } -+ -+ if (bin_ldm_B.back() != this->ldb_host.at(idx)) { -+ std::cerr << "Error - failed to group problems.\n"; -+ return result; -+ } -+ -+ if (bin_ldm_C.back() != this->ldc_host.at(idx)) { -+ std::cerr << "Error - failed to group problems.\n"; -+ return result; -+ } -+ -+ ptr_A_batched_host.push_back(this->block_A.get() + this->offset_A.at(idx)); -+ ptr_B_batched_host.push_back(this->block_B.get() + this->offset_B.at(idx)); -+ ptr_C_batched_host.push_back(this->block_D.get() + this->offset_C.at(idx)); -+ } -+ } -+ -+ // Array of GMEM pointers used by batched array GEMM calls -+ cutlass::DeviceAllocation ptr_A_batched; -+ cutlass::DeviceAllocation ptr_B_batched; -+ cutlass::DeviceAllocation ptr_C_batched; -+ -+ ptr_A_batched.reset(ptr_A_batched_host.size()); -+ ptr_B_batched.reset(ptr_A_batched_host.size()); -+ ptr_C_batched.reset(ptr_A_batched_host.size()); -+ -+ ptr_A_batched.copy_from_host(ptr_A_batched_host.data()); -+ ptr_B_batched.copy_from_host(ptr_B_batched_host.data()); -+ ptr_C_batched.copy_from_host(ptr_C_batched_host.data()); -+ -+ // -+ // Create CUDA streams to maximize concurrency of batched-array GEMM kernels -+ // -+ std::vector cuda_streams; -+ -+ // -+ // Warmup run -+ // -+ -+ -+ if (this->options.cuda_streams) { -+ for (int i = 0; i < this->options.cuda_streams; ++i) { -+ cudaStream_t stream; -+ -+ result.error = cudaStreamCreate(&stream); -+ if (result.error != cudaSuccess) { -+ std::cerr << "Failed to create CUDA stream." 
<< std::endl; -+ return result; -+ } -+ cuda_streams.push_back(stream); -+ -+ } -+ } -+ else { -+ cuda_streams.push_back(nullptr); -+ -+ } -+ -+ // Use 'D' for the in/out workspace -+ this->block_D.copy_from_device(this->block_C.get()); -+ -+ for (int bin_idx = 0; bin_idx < int32_t(bin_problem_sizes.size()); ++bin_idx) { -+ -+ cutlass::gemm::GemmCoord const & problem = bin_problem_sizes[bin_idx]; -+ int32_t batch_count = bin_count[bin_idx]; -+ int32_t bin_start_idx = bin_start[bin_idx]; -+ int32_t lda = bin_ldm_A[bin_idx]; -+ int32_t ldb = bin_ldm_B[bin_idx]; -+ int32_t ldc = bin_ldm_C[bin_idx]; -+ -+ void const ** ptr_A_array = ptr_A_batched.get() + bin_start[bin_idx]; -+ void const ** ptr_B_array = ptr_B_batched.get() + bin_start[bin_idx]; -+ void ** ptr_C_array = ptr_C_batched.get() + bin_start[bin_idx]; -+ -+ // -+ // Initialize the CUTLASS GEMM operator -+ // -+ -+ // Configure the GEMM arguments -+ typename Gemm::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); -+ -+ typename Gemm::Arguments arguments{ -+ cutlass::gemm::GemmUniversalMode::kArray, -+ problem, -+ batch_count, -+ epilogue_op, -+ (void const *)ptr_A_array, -+ (void const *)ptr_B_array, -+ (void const *)ptr_C_array, -+ (void *)ptr_C_array, -+ int64_t(), -+ int64_t(), -+ int64_t(), -+ int64_t(), -+ int64_t(lda), -+ int64_t(ldb), -+ int64_t(ldc), -+ int64_t(ldc) -+ }; -+ -+ Gemm gemm_op; -+ -+ cutlass::Status status = gemm_op.initialize(arguments); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; -+ return result; -+ } -+ -+ status = gemm_op(); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; -+ return result; -+ } -+ -+ } -+ -+ // -+ // Wait for completion -+ // -+ -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // -+ // Wait for completion -+ // -+ -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // Record an event at the start of a series of GEMM operations -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ int last_stream_idx = 0; -+ -+ for (int iter = 0; iter < this->options.iterations; ++iter) { -+ -+ for (int bin_idx = 0; bin_idx < int32_t(bin_problem_sizes.size()); ++bin_idx) { -+ -+ cutlass::gemm::GemmCoord const & problem = bin_problem_sizes[bin_idx]; -+ int32_t batch_count = bin_count[bin_idx]; -+ int32_t bin_start_idx = bin_start[bin_idx]; -+ int32_t lda = bin_ldm_A[bin_idx]; -+ int32_t ldb = bin_ldm_B[bin_idx]; -+ int32_t ldc = bin_ldm_C[bin_idx]; -+ -+ void const ** ptr_A_array = ptr_A_batched.get() + bin_start[bin_idx]; -+ void const ** ptr_B_array = ptr_B_batched.get() + bin_start[bin_idx]; -+ void ** ptr_C_array = ptr_C_batched.get() + bin_start[bin_idx]; -+ -+ 
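        // [Editorial note] Within this timed loop each bin is assigned a CUDA stream
        // round-robin (bin_idx % effective_streams) so that independent batched GEMMs can
        // overlap when --streams > 0; with --streams=0 the single nullptr entry pushed above
        // routes every launch to the default stream.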
last_stream_idx = (bin_idx % effective_streams); -+ -+ // -+ // Initialize the CUTLASS GEMM operator -+ // -+ -+ // Configure the GEMM arguments -+ typename Gemm::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); -+ -+ typename Gemm::Arguments arguments{ -+ cutlass::gemm::GemmUniversalMode::kArray, -+ problem, -+ batch_count, -+ epilogue_op, -+ (void const *)ptr_A_array, -+ (void const *)ptr_B_array, -+ (void const *)ptr_C_array, -+ (void *)ptr_C_array, -+ int64_t(), -+ int64_t(), -+ int64_t(), -+ int64_t(), -+ int64_t(lda), -+ int64_t(ldb), -+ int64_t(ldc), -+ int64_t(ldc) -+ }; -+ -+ Gemm gemm_op; -+ -+ cutlass::Status status = gemm_op.initialize(arguments); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; -+ return result; -+ } -+ -+ status = gemm_op(cuda_streams[last_stream_idx]); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; -+ return result; -+ } -+ -+ } -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMM operations have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Wait for work to be completed -+ // -+ -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Compute average runtime and GFLOPs. 
-+ result.runtime_ms = double(runtime_ms) / double(this->options.iterations); -+ result.gflops = this->options.gflops(result.runtime_ms / 1000.0); -+ -+ // -+ // Cleanup -+ // -+ -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ for (auto stream : cuda_streams) { -+ if (stream) { -+ (void)cudaStreamDestroy(stream); -+ } -+ } -+ -+ std::cout << " " << this->options.problem_bins.size() << " batched GEMMs launched" << std::endl; -+ std::cout << std::endl; -+ std::cout << " " << "Batched Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " " << "Batched GFLOPs: " << result.gflops << std::endl; -+ -+ std::string provider = "CUTLASS"; -+ -+ if (this->options.output_file.good()) { -+ this->options.output_file << this->options.output_tag << "," << provider << ",batched," -+ << this->options.problem_count << "," << result.runtime_ms << "," << result.gflops << std::endl; -+ } -+ -+ result.passed = true; -+ return result; -+ } -+}; -+ -+template -+class TestbedGrouped : BaseTestbed { -+public: -+ TestbedGrouped( -+ Options &options_, -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3080 -+ ): BaseTestbed(options_, init_A_, init_B_, init_C_, seed_) {} -+ -+ // Redefine GEMM with different GroupScheduleMode_ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ typename Gemm_::ElementA, -+ typename Gemm_::LayoutA, -+ Gemm_::kTransformA, -+ Gemm_::kAlignmentA, -+ typename Gemm_::ElementB, -+ typename Gemm_::LayoutB, -+ Gemm_::kTransformB, -+ Gemm_::kAlignmentB, -+ typename Gemm_::ElementC, -+ typename Gemm_::LayoutC, -+ typename Gemm_::ElementAccumulator, -+ typename Gemm_::OperatorClass, -+ typename Gemm_::ArchTag, -+ typename Gemm_::ThreadblockShape, -+ typename Gemm_::WarpShape, -+ typename Gemm_::InstructionShape, -+ typename Gemm_::EpilogueOutputOp, -+ typename Gemm_::ThreadblockSwizzle, -+ Gemm_::kStages, -+ GroupScheduleMode_>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ /// Verbose printing of problem sizes -+ void print_problem_sizes() { -+ std::cout << std::endl; -+ -+ // Print groups -+ std::cout << this->problem_count() << " groups:\n"; -+ -+ int32_t idx = 0; -+ int64_t total_tiles = 0; -+ -+ for (auto const & problem : this->options.problem_sizes) { -+ int tiles = Gemm::problem_tile_count(problem); -+ total_tiles += tiles; -+ -+ std::cout << " [" << idx << "]: " -+ << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() -+ << " (" << tiles << " threadblock tiles)" << "\n"; -+ -+ ++idx; -+ } -+ std::cout << std::endl; -+ } -+ -+ /// Sort problems in descending order of problem-K dimension -+ void sort_problems() { -+ Gemm::sort_problems(this->options.problem_count, -+ this->options.problem_sizes.data(), -+ this->lda_host.data(), -+ this->ldb_host.data(), -+ this->ldc_host.data(), -+ this->ldd_host.data(), -+ this->offset_A.data(), -+ this->offset_B.data(), -+ this->offset_C.data(), -+ this->offset_D.data()); -+ } -+ -+ /// Executes a grouped kernel and measures runtime -+ Result profile() { -+ std::string sched_mode = this->options.scheduler_mode_to_str.find(GroupScheduleMode_)->second; -+ -+ std::cout << std::endl; -+ std::cout << "Grouped GEMM (CUTLASS) with mode " << sched_mode << ":\n" -+ << "==================================================== *********" << std::endl; -+ -+ Result result; -+ -+ 
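    // [Editorial note, inferred from the usage just below] Gemm::sufficient() reports how many
    // threadblocks of the persistent grouped kernel the active device can keep resident for
    // these problem sizes; a return value of zero triggers the early exit that follows. See
    // CUTLASS's device-level GemmGrouped for the exact occupancy computation.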
int threadblock_count = Gemm::sufficient(this->options.problem_sizes.data(), this->options.problem_count); -+ -+ // Early exit -+ if (!threadblock_count) { -+ std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." << std::endl; -+ return result; -+ } -+ -+ result.passed = false; -+ -+ // Initialize the problem -+ this->allocate(); -+ if (this->options.sort_problems) { -+ sort_problems(); -+ } -+ this->initialize(); -+ -+ if (this->options.verbose) { -+ print_problem_sizes(); -+ } -+ -+ // Configure the GEMM arguments -+ typename Gemm::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); -+ -+ // Configure GEMM arguments -+ typename Gemm::Arguments args( -+ this->problem_sizes_device.get(), -+ this->problem_count(), -+ threadblock_count, -+ epilogue_op, -+ this->ptr_A.get(), -+ this->ptr_B.get(), -+ this->ptr_C.get(), -+ this->ptr_D.get(), -+ this->lda.get(), -+ this->ldb.get(), -+ this->ldc.get(), -+ this->ldd.get(), -+ this->options.problem_sizes.data() -+ ); -+ -+ // Initialize the GEMM object -+ Gemm gemm; -+ -+ size_t workspace_size = gemm.get_workspace_size(args); -+ cutlass::DeviceAllocation workspace(workspace_size); -+ -+ result.status = gemm.initialize(args, workspace.get()); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to initialize CUTLASS Grouped GEMM kernel." << std::endl; -+ return result; -+ } -+ -+ // Run the grouped GEMM object -+ result.status = gemm.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS Grouped GEMM kernel." << std::endl; -+ return result; -+ } -+ -+ // Wait for completion -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // -+ // Verify correctness -+ // -+ result.passed = true; -+ -+ if (this->options.reference_check) { -+ result.passed = this->verify(); -+ } -+ -+ // -+ // Warm-up run of the grouped GEMM object -+ // -+ result.status = gemm.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS Grouped GEMM kernel." << std::endl; -+ return result; -+ } -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMM operations -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < this->options.iterations; ++iter) { -+ gemm(); -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMM operations have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. 
-+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(this->options.iterations); -+ result.gflops = this->options.gflops(result.runtime_ms / 1000.0); -+ -+ // -+ // Cleanup -+ // -+ -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ // Optionally profile initialization -+ if (this->options.profile_initialization) { -+ // Warm up -+ gemm.initialize(args, workspace.get()); -+ -+ auto start_time = std::chrono::high_resolution_clock::now(); -+ for (int32_t i = 0; i < this->options.iterations; ++i) { -+ gemm.initialize(args, workspace.get()); -+ } -+ auto end_time = std::chrono::high_resolution_clock::now(); -+ -+ std::chrono::duration duration = end_time - start_time; -+ duration /= double(this->options.iterations); -+ result.initialization_time_ms = duration.count(); -+ } -+ -+ int64_t total_tiles = Gemm::group_tile_count(args); -+ std::cout << " " << total_tiles << " total threadblock tiles." << std::endl; -+ -+ std::cout << std::endl; -+ std::cout << " " << "Grouped Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " " << "Grouped GFLOPs: " << result.gflops << std::endl; -+ if (this->options.profile_initialization) { -+ std::cout << " " << "Init Runtime: " << result.initialization_time_ms << " ms" << std::endl; -+ } -+ -+ if (this->options.output_file.good()) { -+ this->options.output_file << this->options.output_tag << ",CUTLASS,grouped-" << sched_mode << "," -+ << this->options.problem_count << "," << result.runtime_ms << "," << result.gflops << std::endl; -+ } -+ -+ std::cout << "\nPassed\n"; -+ -+ return result; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { -+ -+ // -+ // This example requires an NVIDIA Ampere-architecture GPU. -+ // -+ -+ std::cout -+ << "CUTLASS's Grouped GEMM example requires a GPU of NVIDIA's Ampere Architecture or " -+ << "later (compute capability 80 or greater).\n"; -+ -+ return 0; -+ } -+ -+ // -+ // Parse options -+ // -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.error) { -+ std::cerr << "Aborting execution." 
<< std::endl; -+ return -1; -+ } -+ -+ // -+ // Define the Grouped and Batched GEMM types -+ // -+ -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ // Gemm operator cutlass_tensorop_f16_s16816gemm_f16_128x128_32x4_nt_align8 -+ using GemmBatched = cutlass::gemm::device::GemmUniversal< -+ cutlass::half_t, LayoutA, -+ cutlass::half_t, LayoutB, -+ ElementOutput, LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 4 -+ >; -+ -+ // Define a grouped GEMM kernel with all template parameters set except -+ // for scheduling mode. This will be used as the template for all scheduling -+ // modes executed. -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ cutlass::half_t, -+ LayoutA, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ LayoutB, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ ElementOutput, LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. -+ // This parameter is passed in at present to match the APIs of other kernels. The parameter -+ // is unused within the kernel. 
-+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 4>::GemmKernel; -+ -+ using GemmGrouped = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Profile it -+ // -+ -+ TestbedBatched testbed_batched(options); -+ Result result = testbed_batched.profile(); -+ if (result.error) { -+ return 1; -+ } -+ -+ using GroupScheduleMode = cutlass::gemm::kernel::GroupScheduleMode; -+ for (GroupScheduleMode mode : options.scheduler_modes) { -+ Result result; -+ switch (mode) { -+ case GroupScheduleMode::kDeviceOnly: -+ { -+ TestbedGrouped runner(options); -+ result = runner.profile(); -+ break; -+ } -+ case GroupScheduleMode::kHostPrecompute: -+ { -+ TestbedGrouped runner(options); -+ result = runner.profile(); -+ break; -+ } -+ } -+ -+ if (result.error != cudaSuccess) { -+ return 1; -+ } -+ -+ // Override verbose flag to avoid printing duplicate information for each scheduling mode -+ options.verbose = false; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/25_ampere_fprop_mainloop_fusion/ampere_3d_fprop_mainloop_fusion.cu b/3rdparty/cutlass/examples/25_ampere_fprop_mainloop_fusion/ampere_3d_fprop_mainloop_fusion.cu -new file mode 100644 -index 0000000..5964028 ---- /dev/null -+++ b/3rdparty/cutlass/examples/25_ampere_fprop_mainloop_fusion/ampere_3d_fprop_mainloop_fusion.cu -@@ -0,0 +1,776 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+ -+This example shows how to fuse per channel scale+bias+relu of the activations -+into the 3D fprop mainloop. 
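Conceptually, for an activation element x that belongs to channel c, the fused mainloop applies
the per-channel scale and bias vectors (described below) together with a ReLU, i.e. it feeds

    relu(scale[c] * x + bias[c])

to the tensor core instruction instead of x itself. (Editorial sketch of the intent; the ordering
follows the scale+bias+relu naming used throughout this example.)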
-+ -+Compared with original 3D fprop kernel, this example has two more vectors, one for -+the scale and one for the bias. The length of the vectors is the same as the -+activation channel number. This kernel loads the vectors when the associated -+activation channels are loaded in the mainloop. Between reading the -+activations and scale/bias data from the shared memory and calling tensor core -+instructions, scale+bias+relu is computed in the register file. -+ -+This example is customized for Ampere 16816 fp16 tensor core instruction. -+Changing to different data types or different tensor core instruction require -+source code changing. See -+include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h for more -+technical details. -+ -+This example is modified based on 25_ampere_fprop_mainloop_fusion. The command -+line is the same. -+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_conv3d_fprop_fusion.h" -+#include "cutlass/conv/device/implicit_gemm_convolution_fusion.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using ElementAccumulator = float; // Data type of accumulator -+using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) -+using ElementInputA = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputB = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputScaleBias = cutlass::half_t; // Data type of elements in input sclae and bias vectors -+using ElementOutput = float; // Data type of elements in output tensor -+ -+using LayoutInputA = cutlass::layout::TensorNDHWC; -+using LayoutInputB = cutlass::layout::TensorNDHWC; -+using LayoutInputScaleBias = cutlass::layout::RowMajor; -+using LayoutOutput = cutlass::layout::TensorNDHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 4; -+ -+// This code section describe iterator algorithm selected is Analytic or Optimized -+static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = 
cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; // Data type for alpha/beta in linear combination -+ -+using Conv3dFpropFusionKernel = typename cutlass::conv::kernel::DefaultConv3dFpropFusion< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementInputScaleBias, LayoutInputScaleBias, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm -+>::Kernel; -+ -+using ImplicitGemmFusion = cutlass::conv::device::ImplicitGemmConvolutionFusion; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::Tensor5DCoord input_size; -+ cutlass::Tensor5DCoord filter_size; -+ cutlass::Coord<3> padding; -+ cutlass::Coord<3> conv_stride; -+ cutlass::Coord<3> dilation; -+ bool reference_check; -+ bool measure_performance; -+ int iterations; -+ bool save_workspace; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ bool benchmark; -+ std::string tag; -+ -+ Options(): -+ help(false), -+ input_size(1, 32, 32, 32, 32), -+ filter_size(32, 3, 3, 3, 32), -+ padding(cutlass::make_Coord(1, 1, 1)), -+ conv_stride(cutlass::make_Coord(1, 1, 1)), -+ dilation(cutlass::make_Coord(1, 1, 1)), -+ reference_check(true), -+ measure_performance(false), -+ iterations(20), -+ save_workspace(false), -+ alpha(1), -+ beta(0), -+ benchmark(false) { } -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. 
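// The valid() check above enforces two things: channel extents divisible by 8, because the kernel
// issues 128-bit vector loads of 16-bit elements, and padding equal to filter_extent / 2. A
// standalone sketch of the same checks on plain ints follows; the struct and field names are
// stand-ins for the example's Tensor5DCoord and Coord<3>, not CUTLASS types.

#include <cstdio>

// 128-bit vector accesses of 16-bit (half) elements: 128 / 16 = 8 elements.
constexpr int kAlignment = 128 / 16;

// Minimal stand-in for the fields the example's valid() inspects.
struct ProblemSketch {
  int c;                    // input channels
  int k;                    // filter count (output channels)
  int t, r, s;              // filter extents (depth, height, width)
  int pad_d, pad_h, pad_w;  // padding
};

bool valid(const ProblemSketch &p) {
  // Misaligned tensors: every pointer, stride, and extent must cover whole 128-bit vectors.
  if ((p.c % kAlignment) || (p.k % kAlignment)) {
    return false;
  }
  // The examples only accept "same"-style padding of filter_extent / 2.
  if (p.pad_d != p.t / 2 || p.pad_h != p.r / 2 || p.pad_w != p.s / 2) {
    return false;
  }
  return true;
}

int main() {
  ProblemSketch ok{32, 32, 3, 3, 3, 1, 1, 1};
  ProblemSketch misaligned{30, 32, 3, 3, 3, 1, 1, 1};   // C = 30 is not a multiple of 8
  std::printf("ok: %d, misaligned: %d\n", valid(ok), valid(misaligned));
  return 0;
}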
-+ // -+ int const kAlignment = 8; -+ -+ if ((input_size.c() % kAlignment) || -+ (filter_size.n() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ // Invalid padding -+ if ((padding[0] != filter_size.d() / 2) || -+ (padding[1] != filter_size.h() / 2) || -+ (padding[2] != filter_size.w() / 2)) { -+ -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Updates input and filter sizes -+ void update( -+ cutlass::Tensor5DCoord input_size, -+ cutlass::Tensor5DCoord filter_size, -+ cutlass::Coord<3> stride) { -+ -+ this->input_size = input_size; -+ this->filter_size = filter_size; -+ conv_stride = stride; -+ -+ padding[0] = filter_size.d() / 2; -+ padding[1] = filter_size.h() / 2; -+ padding[2] = filter_size.w() / 2; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("ref-check")) { -+ reference_check = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("perf-check")) { -+ measure_performance = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("save-workspace")) { -+ save_workspace = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("benchmark")) { -+ benchmark = true; -+ } -+ -+ cmd.get_cmd_line_argument("n", input_size.n()); -+ cmd.get_cmd_line_argument("d", input_size.d()); -+ cmd.get_cmd_line_argument("h", input_size.h()); -+ cmd.get_cmd_line_argument("w", input_size.w()); -+ cmd.get_cmd_line_argument("c", input_size.c()); -+ -+ cmd.get_cmd_line_argument("k", filter_size.n()); -+ cmd.get_cmd_line_argument("t", filter_size.d()); -+ cmd.get_cmd_line_argument("r", filter_size.h()); -+ cmd.get_cmd_line_argument("s", filter_size.w()); -+ filter_size.c() = input_size.c(); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ -+ if (filter_size.d() == 3 && filter_size.h() == 3 && filter_size.w() == 3) { -+ padding = cutlass::make_Coord(1, 1, 1); -+ } -+ else { -+ filter_size.d() = 1; -+ filter_size.h() = 1; -+ filter_size.w() = 1; -+ padding = cutlass::make_Coord(0, 0, 0); -+ } -+ } -+ -+ /// Prints the usage statement. 
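// parse() above reads tensor extents from flags such as --n, --d, --h, --w and --c through
// cutlass::CommandLine. For readers without CUTLASS at hand, here is a minimal standard-library
// sketch of the same --key=value convention; the helper names and defaults are illustrative and
// do not reproduce the cutlass::CommandLine API.

#include <cstdio>
#include <cstdlib>
#include <map>
#include <string>

// Collect "--key=value" arguments into a map; a bare "--flag" maps to an empty string.
std::map<std::string, std::string> parse_args(int argc, const char **argv) {
  std::map<std::string, std::string> args;
  for (int i = 1; i < argc; ++i) {
    std::string arg = argv[i];
    if (arg.rfind("--", 0) != 0) continue;          // skip tokens that are not flags
    std::string body = arg.substr(2);
    std::size_t eq = body.find('=');
    if (eq == std::string::npos) {
      args[body] = "";
    } else {
      args[body.substr(0, eq)] = body.substr(eq + 1);
    }
  }
  return args;
}

int get_int(const std::map<std::string, std::string> &args, const std::string &key, int fallback) {
  auto it = args.find(key);
  return (it == args.end() || it->second.empty()) ? fallback : std::atoi(it->second.c_str());
}

int main(int argc, const char **argv) {
  auto args = parse_args(argc, argv);
  // Defaults mirror the example's Options(): a 1 x 32 x 32 x 32 x 32 NDHWC input.
  int n = get_int(args, "n", 1), d = get_int(args, "d", 32), h = get_int(args, "h", 32);
  int w = get_int(args, "w", 32), c = get_int(args, "c", 32);
  std::printf("input = %d x %d x %d x %d x %d (NDHWC)\n", n, d, h, w, c);
  return 0;
}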
-+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "25_ampere_3d_fprop_mainloop_fusion example\n\n" -+ << " This example fuses scale+bias+relu of the activations into Ampere's\n" -+ << " Tensor Core operators on F16 data types to compute\n" -+ << " forward convolution on tensors of layout NDHWC.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --n Input tensor extent N\n" -+ << " --d Input tensor extent D\n" -+ << " --h Input tensor extent H\n" -+ << " --w Input tensor extent W\n" -+ << " --c Input tensor extent C\n" -+ << " --k Filter extent K\n" -+ << " --t Filter extent T\n" -+ << " --r Filter extent R\n" -+ << " --s Filter extent S\n\n" -+ << " --alpha Epilogue scalar alpha\n" -+ << " --beta Epilogue scalar beta\n\n" -+ << " --ref-check If set (true), reference check on the host is computed\n" -+ << " --perf-check If set (true), performance is measured.\n" -+ << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n" -+ << " --iterations Number of profiling iterations to perform.\n" -+ << " --save-workspace If set, workspace is written to a text file.\n" -+ << " --tag String to replicate across the first column in the results table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./25_ampere_3d_fprop_mainloop_fusion --n=32 --d=96 --h=96 --w=96 --c=64 --k=64 --t=1 --r=1 --s=1\n\n" -+ << "$ ./25_ampere_3d_fprop_mainloop_fusion --n=1 --d=224 --h=224 --w=224 --c=32 --k=32 --t=3 --r=3 --s=3 --ref-check\n\n" -+ << "$ ./25_ampere_3d_fprop_mainloop_fusion --n=19 --d=94 --h=96 --w=96 --c=128 --k=128 --t=1 --r=1 --s=1\n\n"; -+ -+ return out; -+ } -+ -+ /// Computes the output tensor size (NPQK) -+ cutlass::Tensor5DCoord output_size() const { -+ return cutlass::Tensor5DCoord( -+ input_size.n(), -+ (input_size.d() + padding[0] + padding[0] - filter_size.d()) / conv_stride[0] + 1, -+ (input_size.h() + padding[1] + padding[1] - filter_size.h()) / conv_stride[1] + 1, -+ (input_size.w() + padding[2] + padding[2] - filter_size.w()) / conv_stride[2] + 1, -+ filter_size.n()); -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of multiply-adds = NPQK * CRS -+ int64_t fmas = output_size().product() * int64_t(filter_size.d() * filter_size.h() * filter_size.w() * filter_size.c()); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cutlass::Status reference_check; -+ cudaError_t error; -+ -+ Result(): -+ runtime_ms(0), -+ gflops(0), -+ status(cutlass::Status::kSuccess), -+ reference_check(cutlass::Status::kInvalid), -+ error(cudaSuccess) { } -+ -+ static std::ostream & print_header(std::ostream &out, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "Layer,N,D,H,W,C,K,T,R,S,Stride_D,Stride_H,Stride_W,Runtime,GFLOPs"; -+ -+ return out; -+ } -+ -+ std::ostream & print(std::ostream &out, int idx, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ out -+ << "conv_" << idx << "," -+ << options.input_size.n() << "," -+ << options.input_size.d() << "," -+ << options.input_size.h() << "," -+ << options.input_size.w() << "," -+ << options.input_size.c() << "," -+ << options.filter_size.n() << "," -+ << options.filter_size.d() << 
"," -+ << options.filter_size.h() << "," -+ << options.filter_size.w() << "," -+ << options.conv_stride[0] << "," -+ << options.conv_stride[1] << "," -+ << options.conv_stride[2] << "," -+ << runtime_ms << "," -+ << gflops; -+ -+ return out; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one benchmark -+Result profile_convolution(Options const &options) { -+ -+ Result result; -+ -+ // -+ // Allocate host-device tensors using the CUTLASS Utilities. -+ // -+ -+ cutlass::HostTensor tensor_a(options.input_size); -+ cutlass::HostTensor tensor_transformed_a(options.input_size); -+ cutlass::HostTensor tensor_b(options.filter_size); -+ cutlass::HostTensor -+ tensor_a_scale({1, options.input_size.c()}); -+ cutlass::HostTensor -+ tensor_a_bias({1, options.input_size.c()}); -+ cutlass::HostTensor tensor_c(options.output_size()); -+ cutlass::HostTensor tensor_d(options.output_size()); -+ cutlass::HostTensor tensor_ref_d(options.output_size()); -+ -+ // -+ // Initialize tensors -+ // -+ -+ // Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(3), -+ ElementInputA(-4), -+ 0); -+ -+ // Fill scale vector for tensor A on host with uniform-distribution random -+ // data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a_scale.host_view(), -+ 1, -+ ElementInputA(3), -+ ElementInputA(-4), -+ 0); -+ -+ // Fill bias vector for tensor A on host with uniform-distribution random -+ // data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a_bias.host_view(), -+ 1, -+ ElementInputA(3), -+ ElementInputA(-4), -+ 0); -+ -+ // Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(7), -+ ElementInputB(-8), -+ 0); -+ -+ // Fill tensor C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(7), -+ ElementOutput(-8), -+ 0); -+ -+ // Fill tensor D for reference on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_a_scale.sync_device(); -+ tensor_a_bias.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // -+ // Define arguments for CUTLASS Convolution -+ // -+ -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Construct Conv3dProblemSize with user defined output size -+ cutlass::conv::Conv3dProblemSize problem_size( -+ options.input_size, -+ options.filter_size, -+ options.padding, -+ options.conv_stride, -+ options.dilation, -+ options.output_size(), -+ mode, -+ split_k_slices -+ ); -+ -+ typename ImplicitGemmFusion::Arguments arguments{ -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_a_scale.device_ref(), -+ tensor_a_bias.device_ref(), -+ tensor_c.device_ref(), -+ tensor_d.device_ref(), -+ {options.alpha, options.beta}, -+ }; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ ImplicitGemmFusion implicit_gemm_fusion_op; -+ -+ size_t workspace_size = implicit_gemm_fusion_op.get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation 
workspace(workspace_size); -+ -+ result.status = implicit_gemm_fusion_op.can_implement(arguments); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = implicit_gemm_fusion_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Launch initialized CUTLASS kernel -+ // -+ result.status = implicit_gemm_fusion_op(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Optional reference check -+ // -+ -+ if (options.reference_check) { -+ std::cout << "Verification on device...\n"; -+ -+ // Compute scale + bias + relu in host code -+ for (int n = 0; n < options.input_size.n(); ++n) { -+ for (int d = 0; d < options.input_size.d(); ++d) { -+ for (int h = 0; h < options.input_size.h(); ++h) { -+ for (int w = 0; w < options.input_size.w(); ++w) { -+ for (int c = 0; c < options.input_size.c(); ++c) { -+ tensor_transformed_a.at({n, d, h, w, c}) = std::max( -+ ElementOutput(0), ElementOutput(tensor_a.at({n, d, h, w, c}) * -+ tensor_a_scale.at({0, c}) + -+ tensor_a_bias.at({0, c}))); -+ } -+ } -+ } -+ } -+ } -+ -+ tensor_transformed_a.sync_device(); -+ -+ // Compute with reference implementation -+ cutlass::reference::device::Conv3dFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator, -+ cutlass::NumericConverter -+ >( -+ problem_size, -+ tensor_transformed_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref(), -+ options.alpha, -+ options.beta -+ ); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ if (!passed) { -+ result.reference_check = cutlass::Status::kErrorInternal; -+ std::cout << "ERROR - results miscompared.\n"; -+ } -+ else { -+ result.reference_check = cutlass::Status::kSuccess; -+ std::cout << "Passed.\n"; -+ } -+ } -+ else { -+ result.reference_check = cutlass::Status::kInvalid; -+ } -+ -+ if (options.save_workspace) { -+ -+ std::stringstream ss; -+ -+ ss << "25_ampere_3d_fprop_mainloop_fusion" -+ << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() -+ << "_" -+ << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() -+ << ".dat"; -+ -+ std::ofstream output_workspace(ss.str()); -+ -+ output_workspace -+ << "Input = \n" << tensor_a.host_view() << "\n\n" -+ << "Filters = \n" << tensor_b.host_view() << "\n\n"; -+ -+ if (options.reference_check) { -+ output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n"; -+ } -+ -+ output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl; -+ -+ std::cout << "Results written to '" << ss.str() << "'." << std::endl; -+ } -+ -+ // -+ // Performance measurement -+ // -+ -+ if (options.measure_performance) { -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. 
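// The measurement block that follows brackets a loop of kernel launches with two CUDA events and
// divides the elapsed time by the iteration count. A self-contained sketch of that timing pattern
// with a placeholder kernel is below; only the event sequence mirrors the example, the kernel and
// sizes are made up.

#include <cstdio>
#include <cuda_runtime.h>

__global__ void dummy_kernel(float *x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] = x[i] * 2.0f + 1.0f;
}

int main() {
  const int n = 1 << 20, iterations = 20;
  float *x = nullptr;
  cudaMalloc(&x, n * sizeof(float));

  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  // Record an event, launch the whole sequence, record again, then synchronize on the second
  // event: the elapsed time covers all iterations, so averaging removes per-launch jitter.
  cudaEventRecord(start);
  for (int iteration = 0; iteration < iterations; ++iteration) {
    dummy_kernel<<<(n + 255) / 256, 256>>>(x, n);
  }
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);

  float total_ms = 0.0f;
  cudaEventElapsedTime(&total_ms, start, stop);
  std::printf("average runtime: %f ms\n", total_ms / iterations);

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  cudaFree(x);
  return 0;
}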
-+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = implicit_gemm_fusion_op(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv3dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." 
<< std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major >= 8)) { -+ std::cerr << "This test must run on SM80 or above.\n"; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.benchmark) { -+ // Benchmark several layers -+ -+ int batch_sizes[] = {34, 18}; -+ -+ struct Benchmark { -+ int d, h, w, c, k, t, r, s, stride_d, stride_h, stride_w; -+ } layers[] = { -+ {56, 56, 56, 64, 256, 1, 1, 1, 1, 1, 1}, -+ {56, 56, 56, 64, 64, 1, 1, 1, 1, 1, 1}, -+ {56, 56, 56, 64, 64, 3, 3, 3, 1, 1, 1}, -+ {56, 56, 56, 256, 64, 1, 1, 1, 1, 1, 1}, -+ {56, 56, 56, 256, 512, 1, 1, 1, 2, 2, 2}, -+ {56, 56, 56, 256, 128, 1, 1, 1, 1, 1, 1}, -+ {56, 56, 56, 128, 128, 3, 3, 3, 2, 2, 2}, -+ {28, 28, 28, 128, 512, 1, 1, 1, 1, 1, 1}, -+ {28, 28, 28, 512, 128, 1, 1, 1, 1, 1, 1}, -+ {28, 28, 28, 128, 128, 3, 3, 3, 1, 1, 1}, -+ {28, 28, 28, 512, 1024, 1, 1, 1, 2, 2, 2}, -+ {28, 28, 28, 512, 256, 1, 1, 1, 1, 1, 1}, -+ {28, 28, 28, 256, 256, 3, 3, 3, 2, 2, 2}, -+ {14, 14, 14, 256, 1024, 1, 1, 1, 1, 1, 1}, -+ {14, 14, 14, 1024, 256, 1, 1, 1, 1, 1, 1}, -+ {14, 14, 14, 256, 256, 3, 3, 3, 1, 1, 1}, -+ {14, 14, 14, 1024, 2048, 1, 1, 1, 2, 2, 2}, -+ {14, 14, 14, 1024, 512, 1, 1, 1, 1, 1, 1}, -+ {14, 14, 14, 512, 512, 3, 3, 3, 2, 2, 2}, -+ { 7, 7, 7, 512, 2048, 1, 1, 1, 1, 1, 1}, -+ { 7, 7, 7, 2048, 512, 1, 1, 1, 1, 1, 1}, -+ { 7, 7, 7, 512, 512, 3, 3, 3, 1, 1, 1}, -+ }; -+ -+ Result::print_header(std::cout, options) << std::endl; -+ -+ int idx = 1; -+ -+ for (auto const &layer : layers) { -+ for (auto N : batch_sizes) { -+ options.update({N, layer.d, layer.h, layer.w, layer.c}, -+ {layer.k, layer.t, layer.r, layer.s, layer.c}, -+ cutlass::make_Coord(layer.stride_d, layer.stride_h, layer.stride_w)); -+ -+ Result result = profile_convolution(options); -+ result.print(std::cout, idx, options) << std::endl; -+ } -+ -+ ++idx; -+ } -+ } -+ else { -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ Result result = profile_convolution(options); -+ -+ Result::print_header(std::cout, options) << std::endl; -+ result.print(std::cout, 1, options) << std::endl; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu b/3rdparty/cutlass/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu -new file mode 100644 -index 0000000..71f5040 ---- /dev/null -+++ b/3rdparty/cutlass/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu -@@ -0,0 +1,768 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+ -+This example shows how to fuse per channel scale+bias+relu of the activations -+into the fprop mainloop. -+ -+Compared with original fprop kernel, this example has two more vectors, one for -+the scale and one for the bias. The length of the vectors are the same as the -+activation channel number. This kernels loads the vectors when the associated -+activation channels are loaded in the mainloop. Between reading the -+activations and scale/bias data from the shared memory and calling tensor core -+instructions, scale+bias+relu is computed in the register file. -+ -+This example is customized for Ampere 16816 fp16 tensor core instruction. -+Changing to different data types or different tensor core instruction require -+source code changing. See -+include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h for more -+technical details. -+ -+This example is modified based on 16_ampere_tensorop_conv2dfprop. The command -+line is the same. 
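// The transform this example fuses into the mainloop is the same per-channel scale+bias+relu that
// its reference check later applies on the host. A standalone NHWC version in plain C++ is
// sketched below, using float and a flat std::vector instead of half_t and HostTensor as
// simplifying assumptions; it shows the arithmetic, not the fused device kernel.

#include <algorithm>
#include <cstdio>
#include <vector>

// y[n,h,w,c] = max(0, x[n,h,w,c] * scale[c] + bias[c]) on an NHWC tensor: the per-channel
// transform the fused mainloop computes in registers before feeding the tensor cores.
void scale_bias_relu_nhwc(const std::vector<float> &x, const std::vector<float> &scale,
                          const std::vector<float> &bias, std::vector<float> &y,
                          int N, int H, int W, int C) {
  for (int n = 0; n < N; ++n)
    for (int h = 0; h < H; ++h)
      for (int w = 0; w < W; ++w)
        for (int c = 0; c < C; ++c) {
          int idx = ((n * H + h) * W + w) * C + c;   // NHWC linearization
          y[idx] = std::max(0.0f, x[idx] * scale[c] + bias[c]);
        }
}

int main() {
  const int N = 1, H = 2, W = 2, C = 4;
  std::vector<float> x(N * H * W * C, -1.0f), y(x.size());
  std::vector<float> scale(C, 2.0f), bias(C, 3.0f);
  scale_bias_relu_nhwc(x, scale, bias, y, N, H, W, C);
  std::printf("y[0] = %f\n", y[0]);   // max(0, -1 * 2 + 3) = 1
  return 0;
}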
-+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_conv2d_fprop_fusion.h" -+#include "cutlass/conv/device/implicit_gemm_convolution_fusion.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using ElementAccumulator = float; // Data type of accumulator -+using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) -+using ElementInputA = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputB = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputScaleBias = cutlass::half_t; // Data type of elements in input sclae and bias vectors -+using ElementOutput = float; // Data type of elements in output tensor -+ -+using LayoutInputA = cutlass::layout::TensorNHWC; -+using LayoutInputB = cutlass::layout::TensorNHWC; -+using LayoutInputScaleBias = cutlass::layout::RowMajor; -+using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 4; -+ -+// This code section describe iterator algorithm selected is Analytic or Optimized -+static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. 
-+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; // Data type for alpha/beta in linear combination -+ -+using Conv2dFpropFusionKernel = typename cutlass::conv::kernel::DefaultConv2dFpropFusion< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementInputScaleBias, LayoutInputScaleBias, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm -+>::Kernel; -+ -+using ImplicitGemmFusion = cutlass::conv::device::ImplicitGemmConvolutionFusion; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::Tensor4DCoord input_size; -+ cutlass::Tensor4DCoord filter_size; -+ cutlass::Tensor4DCoord padding; -+ cutlass::MatrixCoord conv_stride; -+ cutlass::MatrixCoord dilation; -+ bool reference_check; -+ bool measure_performance; -+ int iterations; -+ bool save_workspace; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ bool benchmark; -+ std::string tag; -+ -+ Options(): -+ help(false), -+ input_size(1, 32, 32, 32), -+ filter_size(32, 3, 3, 32), -+ padding(1, 1, 1, 1), -+ conv_stride(1, 1), -+ dilation(1, 1), -+ reference_check(true), -+ measure_performance(false), -+ iterations(20), -+ save_workspace(false), -+ alpha(1), -+ beta(0), -+ benchmark(false) { } -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. 
-+ // -+ int const kAlignment = 8; -+ -+ if ((input_size.c() % kAlignment) || -+ (filter_size.n() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ // Invalid padding -+ if ((padding.h() != filter_size.h() / 2) || -+ (padding.w() != filter_size.w() / 2)) { -+ -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Updates input and filter sizes -+ void update( -+ cutlass::Tensor4DCoord input_size, -+ cutlass::Tensor4DCoord filter_size, -+ cutlass::MatrixCoord stride) { -+ -+ this->input_size = input_size; -+ this->filter_size = filter_size; -+ conv_stride = stride; -+ -+ padding.n() = filter_size.h() / 2; -+ padding.h() = filter_size.h() / 2; -+ padding.w() = filter_size.w() / 2; -+ padding.c() = filter_size.w() / 2; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("ref-check")) { -+ reference_check = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("perf-check")) { -+ measure_performance = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("save-workspace")) { -+ save_workspace = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("benchmark")) { -+ benchmark = true; -+ } -+ -+ cmd.get_cmd_line_argument("n", input_size.n()); -+ cmd.get_cmd_line_argument("h", input_size.h()); -+ cmd.get_cmd_line_argument("w", input_size.w()); -+ cmd.get_cmd_line_argument("c", input_size.c()); -+ -+ cmd.get_cmd_line_argument("k", filter_size.n()); -+ cmd.get_cmd_line_argument("r", filter_size.h()); -+ cmd.get_cmd_line_argument("s", filter_size.w()); -+ filter_size.c() = input_size.c(); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ -+ if (filter_size.h() == 3 && filter_size.w() == 3) { -+ padding = {1, 1, 1, 1}; -+ } -+ else { -+ filter_size.h() = 1; -+ filter_size.w() = 1; -+ padding = {0, 0, 0, 0}; -+ } -+ } -+ -+ /// Prints the usage statement. 
-+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "25_ampere_fprop_mainloop_fusion example\n\n" -+ << " This example fuses scale+bias+relu of the activations into Ampere's\n" -+ << " Tensor Core operators on F16 data types to compute\n" -+ << " forward convolution on tensors of layout NHWC.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --n= Input tensor extent N\n" -+ << " --h= Input tensor extent H\n" -+ << " --w= Input tensor extent W\n" -+ << " --c= Input tensor extent C\n" -+ << " --k= Filter extent K\n" -+ << " --r= Filter extent R\n" -+ << " --s= Filter extent S\n\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --ref-check If set (true), reference check on the host is computed\n" -+ << " --perf-check If set (true), performance is measured.\n" -+ << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --save-workspace If set, workspace is written to a text file.\n" -+ << " --tag= String to replicate across the first column in the results table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/25_ampere_fprop_mainloop_fusion/25_ampere_fprop_mainloop_fusion --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n" -+ << "$ ./examples/25_ampere_fprop_mainloop_fusion/25_ampere_fprop_mainloop_fusion --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n"; -+ -+ return out; -+ } -+ -+ /// Computes the output tensor size (NPQK) -+ cutlass::Tensor4DCoord output_size() const { -+ return cutlass::Tensor4DCoord( -+ input_size.n(), -+ (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, -+ (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, -+ filter_size.n()); -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of multiply-adds = NPQK * CRS -+ int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cutlass::Status reference_check; -+ cudaError_t error; -+ -+ Result(): -+ runtime_ms(0), -+ gflops(0), -+ status(cutlass::Status::kSuccess), -+ reference_check(cutlass::Status::kInvalid), -+ error(cudaSuccess) { } -+ -+ static std::ostream & print_header(std::ostream &out, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "Layer,N,H,W,C,K,R,S,Stride_H,Stride_W,Runtime,GFLOPs"; -+ -+ return out; -+ } -+ -+ std::ostream & print(std::ostream &out, int idx, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ out -+ << "conv_" << idx << "," -+ << options.input_size.n() << "," -+ << options.input_size.h() << "," -+ << options.input_size.w() << "," -+ << options.input_size.c() << "," -+ << options.filter_size.n() << "," -+ << options.filter_size.h() << "," -+ << options.filter_size.w() << "," -+ << options.conv_stride.row() << "," -+ << options.conv_stride.column() << "," -+ << runtime_ms << "," -+ << gflops; -+ -+ return out; -+ } -+}; -+ 
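// output_size() and gflops() above encode the usual convolution bookkeeping: P = (H + 2 * pad - R)
// / stride + 1 (likewise Q), and FLOPs = 2 * N * P * Q * K * C * R * S. A small numeric sketch
// follows; the layer shape matches one row of the benchmark table further down, the runtime is an
// assumed placeholder, and dilation is taken as 1, which is what these examples always use.

#include <cstdint>
#include <cstdio>

// Output extent of a convolution along one spatial dimension (dilation 1).
int conv_output_extent(int input, int pad, int filter, int stride) {
  return (input + 2 * pad - filter) / stride + 1;
}

// GFLOP/s for Conv2d fprop: N * P * Q * K * C * R * S multiply-adds, two flops each.
double conv2d_gflops(int N, int H, int W, int C, int K, int R, int S,
                     int pad_h, int pad_w, int stride_h, int stride_w, double runtime_s) {
  int P = conv_output_extent(H, pad_h, R, stride_h);
  int Q = conv_output_extent(W, pad_w, S, stride_w);
  int64_t fmas = int64_t(N) * P * Q * K * int64_t(C) * R * S;
  return 2.0 * double(fmas) / 1.0e9 / runtime_s;
}

int main() {
  // A 56x56, C=K=64, 3x3, stride-1 layer at batch size 34, with an assumed 4 ms runtime.
  double gf = conv2d_gflops(34, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1, 0.004);
  std::printf("P = Q = %d, about %.1f GFLOP/s\n", conv_output_extent(56, 1, 3, 1), gf);
  return 0;
}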
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one benchmark -+Result profile_convolution(Options const &options) { -+ -+ Result result; -+ -+ // -+ // Allocate host-device tensors using the CUTLASS Utilities. -+ // -+ -+ cutlass::HostTensor tensor_a(options.input_size); -+ cutlass::HostTensor tensor_transformed_a(options.input_size); -+ cutlass::HostTensor tensor_b(options.filter_size); -+ cutlass::HostTensor -+ tensor_a_scale({1, options.input_size.c()}); -+ cutlass::HostTensor -+ tensor_a_bias({1, options.input_size.c()}); -+ cutlass::HostTensor tensor_c(options.output_size()); -+ cutlass::HostTensor tensor_d(options.output_size()); -+ cutlass::HostTensor tensor_ref_d(options.output_size()); -+ -+ // -+ // Initialize tensors -+ // -+ -+ // Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(3), -+ ElementInputA(-4), -+ 0); -+ -+ // Fill scale vector for tensor A on host with uniform-distribution random -+ // data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a_scale.host_view(), -+ 1, -+ ElementInputA(3), -+ ElementInputA(-4), -+ 0); -+ -+ // Fill bias vector for tensor A on host with uniform-distribution random -+ // data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a_bias.host_view(), -+ 1, -+ ElementInputA(3), -+ ElementInputA(-4), -+ 0); -+ -+ // Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(7), -+ ElementInputB(-8), -+ 0); -+ -+ // Fill tensor C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(7), -+ ElementOutput(-8), -+ 0); -+ -+ // Fill tensor D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); -+ -+ // Fill tensor D for reference on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_a_scale.sync_device(); -+ tensor_a_bias.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // -+ // Define arguments for CUTLASS Convolution -+ // -+ -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Construct Conv2dProblemSize with user defined output size -+ cutlass::conv::Conv2dProblemSize problem_size( -+ options.input_size, -+ options.filter_size, -+ options.padding, -+ options.conv_stride, -+ options.dilation, -+ options.output_size(), -+ mode, -+ split_k_slices -+ ); -+ -+ typename ImplicitGemmFusion::Arguments arguments{ -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_a_scale.device_ref(), -+ tensor_a_bias.device_ref(), -+ tensor_c.device_ref(), -+ tensor_d.device_ref(), -+ {options.alpha, options.beta}, -+ }; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ ImplicitGemmFusion implicit_gemm_fusion_op; -+ -+ size_t workspace_size = implicit_gemm_fusion_op.get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ result.status = implicit_gemm_fusion_op.can_implement(arguments); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = 
implicit_gemm_fusion_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Launch initialized CUTLASS kernel -+ // -+ result.status = implicit_gemm_fusion_op(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Optional reference check -+ // -+ -+ if (options.reference_check) { -+ std::cout << "Verification on device...\n"; -+ -+ // Compute scale + bias + relu in host code -+ for (int n = 0; n < options.input_size.n(); ++n) { -+ for (int h = 0; h < options.input_size.h(); ++h) { -+ for (int w = 0; w < options.input_size.w(); ++w) { -+ for (int c = 0; c < options.input_size.c(); ++c) { -+ tensor_transformed_a.at({n, h, w, c}) = std::max( -+ ElementOutput(0), ElementOutput(tensor_a.at({n, h, w, c}) * -+ tensor_a_scale.at({0, c}) + -+ tensor_a_bias.at({0, c}))); -+ } -+ } -+ } -+ } -+ -+ tensor_transformed_a.sync_device(); -+ -+ // Compute with reference implementation -+ cutlass::reference::device::Conv2dFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator, -+ cutlass::NumericConverter -+ >( -+ problem_size, -+ tensor_transformed_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref(), -+ options.alpha, -+ options.beta -+ ); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ if (!passed) { -+ result.reference_check = cutlass::Status::kErrorInternal; -+ std::cout << "ERROR - results miscompared.\n"; -+ } -+ else { -+ result.reference_check = cutlass::Status::kSuccess; -+ std::cout << "Passed.\n"; -+ } -+ } -+ else { -+ result.reference_check = cutlass::Status::kInvalid; -+ } -+ -+ if (options.save_workspace) { -+ -+ std::stringstream ss; -+ -+ ss << "25_ampere_fprop_mainloop_fusion" -+ << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() -+ << "_" -+ << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() -+ << ".dat"; -+ -+ std::ofstream output_workspace(ss.str()); -+ -+ output_workspace -+ << "Input = \n" << tensor_a.host_view() << "\n\n" -+ << "Filters = \n" << tensor_b.host_view() << "\n\n"; -+ -+ if (options.reference_check) { -+ output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n"; -+ } -+ -+ output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl; -+ -+ std::cout << "Results written to '" << ss.str() << "'." << std::endl; -+ } -+ -+ // -+ // Performance measurement -+ // -+ -+ if (options.measure_performance) { -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. 
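// The verification path above applies the scale+bias+relu transform on the host and then runs the
// device reference Conv2dFprop before comparing tensors. For readers who want the reference
// semantics spelled out, here is a naive NHWC cross-correlation fprop in plain C++ (float, zero
// padding, dilation 1, KRSC filters); it is a sketch of the math, not the cutlass::reference code.

#include <cstdio>
#include <vector>

// out[n,p,q,k] = alpha * sum_{r,s,c} x[n, p*stride + r - pad, q*stride + s - pad, c] * w[k,r,s,c]
//              + beta * c_in[n,p,q,k]
void conv2d_fprop_reference(const std::vector<float> &x, int N, int H, int W, int C,
                            const std::vector<float> &w, int K, int R, int S,
                            const std::vector<float> &c_in, std::vector<float> &out,
                            int P, int Q, int pad, int stride, float alpha, float beta) {
  for (int n = 0; n < N; ++n)
    for (int p = 0; p < P; ++p)
      for (int q = 0; q < Q; ++q)
        for (int k = 0; k < K; ++k) {
          float acc = 0.f;
          for (int r = 0; r < R; ++r)
            for (int s = 0; s < S; ++s)
              for (int c = 0; c < C; ++c) {
                int h = p * stride + r - pad;
                int wi = q * stride + s - pad;
                if (h < 0 || h >= H || wi < 0 || wi >= W) continue;   // zero padding
                acc += x[((n * H + h) * W + wi) * C + c] * w[((k * R + r) * S + s) * C + c];
              }
          int o = ((n * P + p) * Q + q) * K + k;
          out[o] = alpha * acc + beta * c_in[o];
        }
}

int main() {
  // 1x4x4x1 input of ones, one 3x3x1 filter of ones, pad 1, stride 1.
  int N = 1, H = 4, W = 4, C = 1, K = 1, R = 3, S = 3, P = 4, Q = 4;
  std::vector<float> x(N * H * W * C, 1.f), w(K * R * S * C, 1.f);
  std::vector<float> c_in(N * P * Q * K, 0.f), out(c_in.size());
  conv2d_fprop_reference(x, N, H, W, C, w, K, R, S, c_in, out, P, Q, 1, 1, 1.f, 0.f);
  std::printf("corner = %g, center = %g\n", out[0], out[((0 * P + 1) * Q + 1) * K + 0]);  // 4, 9
  return 0;
}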
-+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = implicit_gemm_fusion_op(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." 
<< std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major >= 8)) { -+ std::cerr << "This test must run on SM80 or above.\n"; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.benchmark) { -+ // Benchmark several layers -+ -+ int batch_sizes[] = {34, 408}; -+ -+ struct Benchmark { -+ int h, w, c, k, r, s, stride_h, stride_w; -+ } layers[] = { -+ {56, 56, 64, 256, 1, 1, 1, 1}, -+ {56, 56, 64, 64, 1, 1, 1, 1}, -+ {56, 56, 64, 64, 3, 3, 1, 1}, -+ {56, 56, 256, 64, 1, 1, 1, 1}, -+ {56, 56, 256, 512, 1, 1, 2, 2}, -+ {56, 56, 256, 128, 1, 1, 1, 1}, -+ {56, 56, 128, 128, 3, 3, 2, 2}, -+ {28, 28, 128, 512, 1, 1, 1, 1}, -+ {28, 28, 512, 128, 1, 1, 1, 1}, -+ {28, 28, 128, 128, 3, 3, 1, 1}, -+ {28, 28, 512, 1024, 1, 1, 2, 2}, -+ {28, 28, 512, 256, 1, 1, 1, 1}, -+ {28, 28, 256, 256, 3, 3, 2, 2}, -+ {14, 14, 256, 1024, 1, 1, 1, 1}, -+ {14, 14, 1024, 256, 1, 1, 1, 1}, -+ {14, 14, 256, 256, 3, 3, 1, 1}, -+ {14, 14, 1024, 2048, 1, 1, 2, 2}, -+ {14, 14, 1024, 512, 1, 1, 1, 1}, -+ {14, 14, 512, 512, 3, 3, 2, 2}, -+ { 7, 7, 512, 2048, 1, 1, 1, 1}, -+ { 7, 7, 2048, 512, 1, 1, 1, 1}, -+ { 7, 7, 512, 512, 3, 3, 1, 1}, -+ }; -+ -+ Result::print_header(std::cout, options) << std::endl; -+ -+ int idx = 1; -+ -+ for (auto const &layer : layers) { -+ for (auto N : batch_sizes) { -+ options.update({N, layer.h, layer.w, layer.c}, -+ {layer.k, layer.r, layer.s, layer.c}, -+ {layer.stride_h, layer.stride_w}); -+ -+ Result result = profile_convolution(options); -+ result.print(std::cout, idx, options) << std::endl; -+ } -+ -+ ++idx; -+ } -+ } -+ else { -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ Result result = profile_convolution(options); -+ -+ Result::print_header(std::cout, options) << std::endl; -+ result.print(std::cout, 1, options) << std::endl; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu b/3rdparty/cutlass/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu -new file mode 100644 -index 0000000..48e2b77 ---- /dev/null -+++ b/3rdparty/cutlass/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu -@@ -0,0 +1,766 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+ -+This example shows how to fuse activation's per channel scale+bias+relu -+into the wgrad mainloop. -+ -+Compared with original fprop kernel, this example has two more vectors, one for -+the scale and one for the bias. The length of the vectors are the same as the -+activation channel number. This kernels loads the vectors when the associated -+activation channels are loaded in the mainloop. Between reading the -+activations and scale/bias data from the shared memory and calling tensor core -+instructions, scale+bias+relu is computed in the register file. -+ -+This example is customized for Ampere 16816 fp16 tensor core instruction. -+Changing to different data types or different tensor core instruction require -+source code changing. See -+include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h for more -+technical details. 
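// The wgrad variant accumulates products of the output gradient and the transformed activation,
// with the per-channel scale+bias+relu folded into the mainloop. A host sketch of the unfused
// math is below; float, flat vectors, zero padding, and dilation 1 are simplifying assumptions,
// and the function is a reference for the semantics rather than the CUTLASS kernel.

#include <algorithm>
#include <cstdio>
#include <vector>

// dW[k,r,s,c] = sum_{n,p,q} dY[n,p,q,k] *
//               max(0, x[n, p*stride + r - pad, q*stride + s - pad, c] * scale[c] + bias[c])
void conv2d_wgrad_fused_reference(const std::vector<float> &dy, int N, int P, int Q, int K,
                                  const std::vector<float> &x, int H, int W, int C,
                                  const std::vector<float> &scale, const std::vector<float> &bias,
                                  std::vector<float> &dw, int R, int S, int pad, int stride) {
  std::fill(dw.begin(), dw.end(), 0.f);
  for (int k = 0; k < K; ++k)
    for (int r = 0; r < R; ++r)
      for (int s = 0; s < S; ++s)
        for (int c = 0; c < C; ++c) {
          float acc = 0.f;
          for (int n = 0; n < N; ++n)
            for (int p = 0; p < P; ++p)
              for (int q = 0; q < Q; ++q) {
                int h = p * stride + r - pad;
                int w = q * stride + s - pad;
                if (h < 0 || h >= H || w < 0 || w >= W) continue;
                float a = x[((n * H + h) * W + w) * C + c] * scale[c] + bias[c];
                acc += dy[((n * P + p) * Q + q) * K + k] * std::max(0.f, a);  // fused transform
              }
          dw[((k * R + r) * S + s) * C + c] = acc;
        }
}

int main() {
  int N = 1, H = 4, W = 4, C = 2, K = 2, R = 3, S = 3, P = 4, Q = 4;
  std::vector<float> x(N * H * W * C, 1.f), dy(N * P * Q * K, 1.f);
  std::vector<float> scale(C, 1.f), bias(C, 0.f), dw(K * R * S * C);
  conv2d_wgrad_fused_reference(dy, N, P, Q, K, x, H, W, C, scale, bias, dw, R, S, 1, 1);
  std::printf("dw[center tap] = %g\n", dw[((0 * R + 1) * S + 1) * C + 0]);   // expect 16
  return 0;
}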
-+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_conv2d_wgrad_fusion.h" -+#include "cutlass/conv/device/implicit_gemm_convolution_fusion.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using ElementAccumulator = float; // Data type of accumulator -+using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) -+using ElementInputA = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputB = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputScaleBias = cutlass::half_t; // Data type of elements in input sclae and bias vectors -+using ElementOutput = float; // Data type of elements in output tensor -+ -+using LayoutInputA = cutlass::layout::TensorNHWC; -+using LayoutInputB = cutlass::layout::TensorNHWC; -+using LayoutInputScaleBias = cutlass::layout::RowMajor; -+using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 5; -+ -+// This code section describe iterator algorithm selected is Analytic or Optimized -+static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. 
-+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; // Data type for alpha/beta in linear combination -+ -+using Conv2dWgradFusionKernel = typename cutlass::conv::kernel::DefaultConv2dWgradFusion< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementInputScaleBias, LayoutInputScaleBias, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm -+>::Kernel; -+ -+using ImplicitGemmFusion = cutlass::conv::device::ImplicitGemmConvolutionFusion; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::Tensor4DCoord input_size; -+ cutlass::Tensor4DCoord filter_size; -+ cutlass::Tensor4DCoord padding; -+ cutlass::MatrixCoord conv_stride; -+ cutlass::MatrixCoord dilation; -+ bool reference_check; -+ bool measure_performance; -+ int iterations; -+ bool save_workspace; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ bool benchmark; -+ std::string tag; -+ -+ Options(): -+ help(false), -+ input_size(1, 32, 32, 32), -+ filter_size(32, 3, 3, 32), -+ padding(1, 1, 1, 1), -+ conv_stride(1, 1), -+ dilation(1, 1), -+ reference_check(true), -+ measure_performance(false), -+ iterations(20), -+ save_workspace(false), -+ alpha(1), -+ beta(0), -+ benchmark(false) { } -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. 
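// The 128-bit access rule shows up twice in these examples: as the epilogue's elements-per-access
// count, written as 128 / cutlass::sizeof_bits::value, and as the kAlignment = 8 requirement in
// valid() for the 16-bit operands. A tiny sketch of that computation follows, with uint16_t
// standing in for cutlass::half_t.

#include <cstdint>
#include <cstdio>

// Elements that fit into a single 128-bit vectorized memory access.
template <typename Element>
constexpr int elements_per_128b_access() {
  return 128 / int(8 * sizeof(Element));
}

int main() {
  std::printf("16-bit half: %d elements, 32-bit float: %d elements\n",
              elements_per_128b_access<uint16_t>(),   // 8, hence kAlignment = 8
              elements_per_128b_access<float>());     // 4, the float epilogue's vector width
  return 0;
}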
-+ // -+ int const kAlignment = 8; -+ -+ if ((input_size.c() % kAlignment) || -+ (filter_size.n() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ // Invalid padding -+ if ((padding.h() != filter_size.h() / 2) || -+ (padding.w() != filter_size.w() / 2)) { -+ -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Updates input and filter sizes -+ void update( -+ cutlass::Tensor4DCoord input_size, -+ cutlass::Tensor4DCoord filter_size, -+ cutlass::MatrixCoord stride) { -+ -+ this->input_size = input_size; -+ this->filter_size = filter_size; -+ conv_stride = stride; -+ -+ padding.n() = filter_size.h() / 2; -+ padding.h() = filter_size.h() / 2; -+ padding.w() = filter_size.w() / 2; -+ padding.c() = filter_size.w() / 2; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("ref-check")) { -+ reference_check = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("perf-check")) { -+ measure_performance = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("save-workspace")) { -+ save_workspace = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("benchmark")) { -+ benchmark = true; -+ } -+ -+ cmd.get_cmd_line_argument("n", input_size.n()); -+ cmd.get_cmd_line_argument("h", input_size.h()); -+ cmd.get_cmd_line_argument("w", input_size.w()); -+ cmd.get_cmd_line_argument("c", input_size.c()); -+ -+ cmd.get_cmd_line_argument("k", filter_size.n()); -+ cmd.get_cmd_line_argument("r", filter_size.h()); -+ cmd.get_cmd_line_argument("s", filter_size.w()); -+ filter_size.c() = input_size.c(); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ -+ if (filter_size.h() == 3 && filter_size.w() == 3) { -+ padding = {1, 1, 1, 1}; -+ } -+ else { -+ filter_size.h() = 1; -+ filter_size.w() = 1; -+ padding = {0, 0, 0, 0}; -+ } -+ } -+ -+ /// Prints the usage statement. 
-+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "26_ampere_wgrad_mainloop_fusion example\n\n" -+ << " This example fuses scale+bias+relu of the activation into Ampere's\n" -+ << " Tensor Core operators on F16 data types to compute\n" -+ << " backward convolution on tensors of layout NHWC.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --n= Input tensor extent N\n" -+ << " --h= Input tensor extent H\n" -+ << " --w= Input tensor extent W\n" -+ << " --c= Input tensor extent C\n" -+ << " --k= Filter extent K\n" -+ << " --r= Filter extent R\n" -+ << " --s= Filter extent S\n\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --ref-check If set (true), reference check on the host is computed\n" -+ << " --perf-check If set (true), performance is measured.\n" -+ << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --save-workspace If set, workspace is written to a text file.\n" -+ << " --tag= String to replicate across the first column in the results table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/26_ampere_wgrad_mainloop_fusion/26_ampere_wgrad_mainloop_fusion --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n" -+ << "$ ./examples/26_ampere_wgrad_mainloop_fusion/26_ampere_wgrad_mainloop_fusion --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n"; -+ -+ return out; -+ } -+ -+ /// Computes the output tensor size (NPQK) -+ cutlass::Tensor4DCoord output_size() const { -+ return cutlass::Tensor4DCoord( -+ input_size.n(), -+ (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, -+ (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, -+ filter_size.n()); -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of multiply-adds = NPQK * CRS -+ int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cutlass::Status reference_check; -+ cudaError_t error; -+ -+ Result(): -+ runtime_ms(0), -+ gflops(0), -+ status(cutlass::Status::kSuccess), -+ reference_check(cutlass::Status::kInvalid), -+ error(cudaSuccess) { } -+ -+ static std::ostream & print_header(std::ostream &out, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "Layer,N,H,W,C,K,R,S,Stride_H,Stride_W,Runtime,GFLOPs"; -+ -+ return out; -+ } -+ -+ std::ostream & print(std::ostream &out, int idx, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ out -+ << "conv_" << idx << "," -+ << options.input_size.n() << "," -+ << options.input_size.h() << "," -+ << options.input_size.w() << "," -+ << options.input_size.c() << "," -+ << options.filter_size.n() << "," -+ << options.filter_size.h() << "," -+ << options.filter_size.w() << "," -+ << options.conv_stride.row() << "," -+ << options.conv_stride.column() << "," -+ << runtime_ms << "," -+ << gflops; -+ -+ return out; -+ } -+}; -+ 
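The arithmetic in Options::output_size() and Options::gflops() above can be checked in isolation. The following is a minimal standalone sketch, not part of the patch, that mirrors that FLOP count for the default problem size from the Options constructor; "runtime_s" is a hypothetical measured runtime used only to show the unit conversion.

// Standalone sketch: reproduces the NPQK * CRS multiply-add count used by
// Options::gflops() for the default problem size (input 1x32x32x32, filter
// 32x3x3x32, padding 1, stride 1).
#include <cstdint>
#include <iostream>

int main() {
  int64_t N = 1, H = 32, W = 32, C = 32;
  int64_t K = 32, R = 3, S = 3;
  int64_t pad = 1, stride = 1;

  // Output extents, same formula as Options::output_size().
  int64_t P = (H + 2 * pad - R) / stride + 1;   // 32
  int64_t Q = (W + 2 * pad - S) / stride + 1;   // 32

  // Number of multiply-adds = NPQK * CRS; two flops per multiply-add.
  int64_t fmas = N * P * Q * K * (C * R * S);
  double runtime_s = 1.0e-3;  // hypothetical 1 ms runtime
  double gflops = 2.0 * double(fmas) / 1.0e9 / runtime_s;

  std::cout << "FMAs: " << fmas << ", GFLOP/s at 1 ms: " << gflops << "\n";
  return 0;
}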
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+/// Runs one benchmark
-+Result profile_convolution(Options const &options) {
-+
-+  Result result;
-+
-+  //
-+  // Allocate host-device tensors using the CUTLASS Utilities.
-+  //
-+
-+  cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.output_size());
-+  cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.input_size);
-+  cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_transformed_b(options.input_size);
-+  cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias>
-+    tensor_b_scale({1, options.input_size.c()});
-+  cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias>
-+    tensor_b_bias({1, options.input_size.c()});
-+
-+  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.filter_size);
-+  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(options.filter_size);
-+  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(options.filter_size);
-+
-+  //
-+  // Initialize tensors
-+  //
-+
-+  // Fill tensor A on host with uniform-distribution random data
-+  cutlass::reference::host::TensorFillRandomUniform(
-+      tensor_a.host_view(),
-+      1,
-+      ElementInputA(3),
-+      ElementInputA(-4),
-+      0);
-+
-+  // Fill tensor B on host with uniform-distribution random data
-+  cutlass::reference::host::TensorFillRandomUniform(
-+      tensor_b.host_view(),
-+      1,
-+      ElementInputB(7),
-+      ElementInputB(-8),
-+      0);
-+
-+  // Fill scale vector for tensor B on host with uniform-distribution random
-+  // data
-+  cutlass::reference::host::TensorFillRandomUniform(
-+      tensor_b_scale.host_view(),
-+      1,
-+      ElementInputA(3),
-+      ElementInputA(-4),
-+      0);
-+
-+  // Fill bias vector for tensor B on host with uniform-distribution random
-+  // data
-+  cutlass::reference::host::TensorFillRandomUniform(
-+      tensor_b_bias.host_view(),
-+      1,
-+      ElementInputA(3),
-+      ElementInputA(-4),
-+      0);
-+
-+  // Fill tensor C on host with uniform-distribution random data
-+  cutlass::reference::host::TensorFillRandomUniform(
-+      tensor_c.host_view(),
-+      1,
-+      ElementOutput(7),
-+      ElementOutput(-8),
-+      0);
-+
-+  // Fill tensor D on host with zeros
-+  cutlass::reference::host::TensorFill(
-+      tensor_d.host_view());
-+
-+  // Fill tensor D for reference on host with zeros
-+  cutlass::reference::host::TensorFill(
-+      tensor_ref_d.host_view());
-+
-+  // Copy data from host to GPU
-+  tensor_a.sync_device();
-+  tensor_b.sync_device();
-+  tensor_b_scale.sync_device();
-+  tensor_b_bias.sync_device();
-+  tensor_c.sync_device();
-+  tensor_d.sync_device();
-+  tensor_ref_d.sync_device();
-+
-+  //
-+  // Define arguments for CUTLASS Convolution
-+  //
-+
-+  cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
-+
-+  // Split K dimension into 1 partition
-+  int split_k_slices = 1;
-+
-+  // Construct Conv2dProblemSize with user defined output size
-+  cutlass::conv::Conv2dProblemSize problem_size(
-+      options.input_size,
-+      options.filter_size,
-+      options.padding,
-+      options.conv_stride,
-+      options.dilation,
-+      options.output_size(),
-+      mode,
-+      split_k_slices
-+  );
-+
-+  typename ImplicitGemmFusion::Arguments arguments{
-+    problem_size,
-+    tensor_a.device_ref(),
-+    tensor_b.device_ref(),
-+    tensor_b_scale.device_ref(),
-+    tensor_b_bias.device_ref(),
-+    tensor_c.device_ref(),
-+    tensor_d.device_ref(),
-+    {options.alpha, options.beta},
-+  };
-+
-+  //
-+  // Initialize CUTLASS Convolution
-+  //
-+
-+  ImplicitGemmFusion implicit_gemm_fusion_op;
-+
-+  size_t workspace_size = implicit_gemm_fusion_op.get_workspace_size(arguments);
-+
-+  // Allocate workspace memory
-+  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
-+
-+  result.status = implicit_gemm_fusion_op.can_implement(arguments);
-+  CUTLASS_CHECK(result.status);
-+
-+  result.status =
implicit_gemm_fusion_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Launch initialized CUTLASS kernel -+ // -+ result.status = implicit_gemm_fusion_op(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Optional reference check -+ // -+ -+ if (options.reference_check) { -+ std::cout << "Verification on device...\n"; -+ -+ // Compute scale + bias + relu in host code -+ for (int n = 0; n < options.input_size.n(); ++n) { -+ for (int h = 0; h < options.input_size.h(); ++h) { -+ for (int w = 0; w < options.input_size.w(); ++w) { -+ for (int c = 0; c < options.input_size.c(); ++c) { -+ tensor_transformed_b.at({n, h, w, c}) = std::max( -+ ElementOutput(0), ElementOutput(tensor_b.at({n, h, w, c}) * -+ tensor_b_scale.at({0, c}) + -+ tensor_b_bias.at({0, c}))); -+ } -+ } -+ } -+ } -+ -+ tensor_transformed_b.sync_device(); -+ -+ // Compute with reference implementation -+ cutlass::reference::device::Conv2dWgrad< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator, -+ cutlass::NumericConverter -+ >( -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_transformed_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref(), -+ options.alpha, -+ options.beta -+ ); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ if (!passed) { -+ result.reference_check = cutlass::Status::kErrorInternal; -+ std::cout << "ERROR - results miscompared.\n"; -+ } -+ else { -+ result.reference_check = cutlass::Status::kSuccess; -+ std::cout << "Passed.\n"; -+ } -+ } -+ else { -+ result.reference_check = cutlass::Status::kInvalid; -+ } -+ -+ if (options.save_workspace) { -+ -+ std::stringstream ss; -+ -+ ss << "26_ampere_wgrad_mainloop_fusion_" -+ << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() -+ << "_" -+ << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() -+ << ".dat"; -+ -+ std::ofstream output_workspace(ss.str()); -+ -+ output_workspace -+ << "Input = \n" << tensor_a.host_view() << "\n\n" -+ << "Filters = \n" << tensor_b.host_view() << "\n\n"; -+ -+ if (options.reference_check) { -+ output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n"; -+ } -+ -+ output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl; -+ -+ std::cout << "Results written to '" << ss.str() << "'." << std::endl; -+ } -+ -+ // -+ // Performance measurement -+ // -+ -+ if (options.measure_performance) { -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. 
-+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = implicit_gemm_fusion_op(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." 
<< std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major == 8 && props.minor == 0)) { -+ std::cerr << "This test must run on SM80 A100.\n"; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.benchmark) { -+ // Benchmark several layers -+ -+ int batch_sizes[] = {34, 408}; -+ -+ struct Benchmark { -+ int h, w, c, k, r, s, stride_h, stride_w; -+ } layers[] = { -+ {56, 56, 64, 256, 1, 1, 1, 1}, -+ {56, 56, 64, 64, 1, 1, 1, 1}, -+ {56, 56, 64, 64, 3, 3, 1, 1}, -+ {56, 56, 256, 64, 1, 1, 1, 1}, -+ {56, 56, 256, 512, 1, 1, 2, 2}, -+ {56, 56, 256, 128, 1, 1, 1, 1}, -+ {56, 56, 128, 128, 3, 3, 2, 2}, -+ {28, 28, 128, 512, 1, 1, 1, 1}, -+ {28, 28, 512, 128, 1, 1, 1, 1}, -+ {28, 28, 128, 128, 3, 3, 1, 1}, -+ {28, 28, 512, 1024, 1, 1, 2, 2}, -+ {28, 28, 512, 256, 1, 1, 1, 1}, -+ {28, 28, 256, 256, 3, 3, 2, 2}, -+ {14, 14, 256, 1024, 1, 1, 1, 1}, -+ {14, 14, 1024, 256, 1, 1, 1, 1}, -+ {14, 14, 256, 256, 3, 3, 1, 1}, -+ {14, 14, 1024, 2048, 1, 1, 2, 2}, -+ {14, 14, 1024, 512, 1, 1, 1, 1}, -+ {14, 14, 512, 512, 3, 3, 2, 2}, -+ { 7, 7, 512, 2048, 1, 1, 1, 1}, -+ { 7, 7, 2048, 512, 1, 1, 1, 1}, -+ { 7, 7, 512, 512, 3, 3, 1, 1}, -+ }; -+ -+ Result::print_header(std::cout, options) << std::endl; -+ -+ int idx = 1; -+ -+ for (auto const &layer : layers) { -+ for (auto N : batch_sizes) { -+ options.update({N, layer.h, layer.w, layer.c}, -+ {layer.k, layer.r, layer.s, layer.c}, -+ {layer.stride_h, layer.stride_w}); -+ -+ Result result = profile_convolution(options); -+ result.print(std::cout, idx, options) << std::endl; -+ } -+ -+ ++idx; -+ } -+ } -+ else { -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ Result result = profile_convolution(options); -+ -+ Result::print_header(std::cout, options) << std::endl; -+ result.print(std::cout, 1, options) << std::endl; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu b/3rdparty/cutlass/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu -new file mode 100644 -index 0000000..e9d0287 ---- /dev/null -+++ b/3rdparty/cutlass/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu -@@ -0,0 +1,750 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+NVIDIA Ampere architecture starts supporting tfloat32 (see include/cutlass/tfloat32.h) -+data types in tensor cores. One big advantage is that we can load in fp32 data and convert them -+implicitly to tf32 inside the GEMM kernel which means no change is needed to accelerate traditional -+fp32 data by using NVIDIA Ampere architecture. -+ -+We can use the tf32 mode of tensor core to emulate a fast accurate SGEMM kernel which is accelerated -+using Ampere Tensor Cores (see include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h). -+ -+The trick is very simple -+ a x b = (a_big + a_small) x (b_big + b_small) = a_big x b_big + a_big x b_small + a_small x b_big -+ big = convert_to_tf32(fp32) -+ small = convert_to_tf32(fp32 - big) -+ -+a_small x b_small is discarded because they are too small. -+ -+This example demonstrates usage of this kernel, along with accuracy measurements w.r.t. actual FP32 -+results (SGEMM using SIMT) and against FP64 results (DGEMM) -+ -+To enable this feature, the only change needs to make is to change the default OpMultiplyAdd to -+OpMultiplyAddFastF32. -+ -+Now, we have several different flavors of sgemm now in the profiler for Ampere. Here are the difference -+ -+ sgemm // CUDA core SIMT kernel. FP32 in, accumulated in FP32, FP32 out. -+ s1688gemm // Use 3xTF32 to emulate FP32. FP32 in, converted in TF32-big and TF32-small internally, -+ // accumulated in FP32, FP32 out. -+ s1688tf32gemm // Use 1xTF32. FP32 in, converted to one TF32 internally, accumulated in FP32, FP32 out. -+ s1688gemm_tf32 // TF32 in, accumulated in FP32, FP32 out. 
-+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+ -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_reduce.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/error_metrics.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ -+ int m, n, k; -+ double l2_norm_3xtf32_vs_fp64; -+ double l2_norm_1xtf32_vs_fp64; -+ double l2_norm_fp32_vs_fp64; -+ -+ // ctor -+ Result( -+ int m, int n, int k, -+ double runtime_ms, double gflops, -+ double l2_norm_3xtf32_vs_fp64, -+ double l2_norm_1xtf32_vs_fp64, -+ double l2_norm_fp32_vs_fp64) : -+ m(m), n(n), k(k), -+ runtime_ms(runtime_ms), gflops(gflops), -+ l2_norm_3xtf32_vs_fp64(l2_norm_3xtf32_vs_fp64), -+ l2_norm_1xtf32_vs_fp64(l2_norm_1xtf32_vs_fp64), -+ l2_norm_fp32_vs_fp64(l2_norm_fp32_vs_fp64) {} -+ -+ Result() {} -+ -+ // -+ // Methods -+ // -+ static void print_csv_header() { -+ std::cout << "M,N,K,Runtime(ms),GFLOPS,3xTF32_vs_FP64,1xTF32_vs_FP64,FP32_vs_FP64" << std::endl; -+ } -+ -+ void print_csv_row() { -+ std::cout << m << "," -+ << n << "," -+ << k << "," -+ << runtime_ms << "," -+ << gflops << "," -+ << l2_norm_3xtf32_vs_fp64 << "," -+ << l2_norm_1xtf32_vs_fp64 << "," -+ << l2_norm_fp32_vs_fp64 << std::endl; -+ } -+}; -+ -+std::vector results; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ float alpha; -+ float beta; -+ std::string rand_mode; -+ -+ int iterations; -+ int seed; -+ bool benchmark; -+ -+ Options(): -+ help(false), -+ problem_size({3456, 4096, 4096}), -+ iterations(20), -+ seed(1), -+ alpha(1), -+ beta(), -+ rand_mode("uniform"), -+ benchmark(false) { } -+ -+ bool valid() { -+ // -+ // CUTLASS attempts to load 128b vectors of F32 elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 4 elements. -+ // -+ int const kAlignment = 4; -+ -+ if ((problem_size.m() % kAlignment) || -+ (problem_size.n() % kAlignment) || -+ (problem_size.k() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ return true; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("m", problem_size.m()); -+ cmd.get_cmd_line_argument("n", problem_size.n()); -+ cmd.get_cmd_line_argument("k", problem_size.k()); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("seed", seed); -+ cmd.get_cmd_line_argument("rand_mode", rand_mode); -+ -+ if (cmd.check_cmd_line_flag("benchmark")) { -+ benchmark = true; -+ } -+ } -+ -+ /// Prints the usage statement. 
-+  std::ostream & print_usage(std::ostream &out) const {
-+
-+    out << "27_ampere_3xtf32_fast_accurate_tensorop_gemm example\n\n"
-+      << "  This example uses the CUTLASS Library to emulate FP32 with TF32 tensorop GEMM computations.\n\n"
-+      << "Options:\n\n"
-+      << "  --help                      If specified, displays this usage statement.\n\n"
-+      << "  --m=<int>                   GEMM M dimension\n"
-+      << "  --n=<int>                   GEMM N dimension\n"
-+      << "  --k=<int>                   GEMM K dimension\n"
-+      << "  --alpha=<f32>               Epilogue scalar alpha\n"
-+      << "  --beta=<f32>                Epilogue scalar beta\n\n"
-+      << "  --rand_mode=<string>        gauss / uniform*\n\n"
-+      << "  --seed=<int>                Random number seed (1*)\n\n"
-+      << "  --iterations=<int>          Number of profiling iterations to perform.\n\n"
-+      << "  --benchmark                 If set (true), runs performance benchmarking on several layers and batch sizes.\n\n";
-+
-+    out << "\n\nExamples:\n\n"
-+      << "$ ./examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm --m=1024 --n=512 \\\n"
-+      << "     --alpha=2 --beta=0.707 \n\n";
-+
-+    return out;
-+  }
-+
-+  /// Compute performance in GFLOP/s
-+  double gflops(double runtime_s) const {
-+
-+    // Number of real-valued multiply-adds
-+    int64_t fmas = problem_size.product();
-+
-+    // Two flops per multiply-add
-+    return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
-+  }
-+};
-+
-+///////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+// The code section below describes matrix layout of input and output matrices. Row Major for
-+// Matrix A, Column Major for Matrix B and Row Major for Matrix C
-+using LayoutInputA = cutlass::layout::RowMajor;
-+using LayoutInputB = cutlass::layout::ColumnMajor;
-+using LayoutOutput = cutlass::layout::RowMajor;
-+
-+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
-+using MMAOp = cutlass::arch::OpClassTensorOp;
-+
-+// This code section describes CUDA SM architecture number
-+using SmArch = cutlass::arch::Sm80;
-+
-+// This code section describes the tile size a thread block will compute
-+using ShapeMMAThreadBlock =
-+    cutlass::gemm::GemmShape<128, 64, 16>;  // <- threadblock tile M = 128, N = 64, K = 16
-+// This code section describes tile size a warp will compute
-+using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 32, 16>;  // <- warp tile M = 64, N = 32, K = 16
-+// This code section describes the size of MMA op
-+using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>;  // <- MMA Op tile M = 16, N = 8, K = 8
-+
-+// This code section describes how threadblocks are scheduled on GPU
-+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;  // <- default identity swizzle
-+
-+// This code section describes the epilogue part of the kernel
-+using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
-+    float,                                     // <- data type of output matrix
-+    128 / cutlass::sizeof_bits<float>::value,  // <- the number of elements per vectorized
-+                                               //    memory access. For float, this is 4
-+                                               //    elements.
This becomes the vector width of -+ // math instructions in the epilogue too -+ float, // <- data type of accumulator -+ float>; // <- data type for alpha/beta in linear combination function -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 3; -+// Alignment -+constexpr int Alignment = 4; -+ -+// -+// Gemm Operators (Gemm_3xTF32, Gemm_1xTF32, GEMM_F32, GEMM_F64) -+// -+ -+// Gemm_3xTF32 -+using Gemm_3xTF32 = cutlass::gemm::device::Gemm< -+ float, -+ LayoutInputA, -+ float, -+ LayoutInputB, -+ float, -+ LayoutOutput, -+ float, -+ MMAOp, -+ SmArch, -+ ShapeMMAThreadBlock, -+ ShapeMMAWarp, -+ ShapeMMAOp, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ Alignment, -+ Alignment, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32>; -+ -+// Gemm_1xTF32 -+using Gemm_1xTF32 = cutlass::gemm::device::Gemm< -+ float, -+ LayoutInputA, -+ float, -+ LayoutInputB, -+ float, -+ LayoutOutput, -+ float, -+ MMAOp, -+ SmArch, -+ ShapeMMAThreadBlock, -+ ShapeMMAWarp, -+ ShapeMMAOp, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ Alignment, -+ Alignment, -+ false, -+ cutlass::arch::OpMultiplyAdd>; -+ -+// Gemm_F64 -+using Gemm_F64 = cutlass::reference::device::Gemm< -+ double, -+ LayoutInputA, -+ double, -+ LayoutInputB, -+ double, -+ LayoutOutput, -+ double, -+ double>; -+ -+// Gemm_F32 -+using Gemm_F32 = cutlass::reference::device::Gemm< -+ float, -+ LayoutInputA, -+ float, -+ LayoutInputB, -+ float, -+ LayoutOutput, -+ float, -+ float>; -+ -+bool run(Options &options) { -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size = options.problem_size; -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 1. Initialize F32 Precision input tensors using CUTLASS helper functions -+ //////////////////////////////////////////////////////////////////////////////// -+ cutlass::HostTensor tensor_a_F32(problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor tensor_b_F32(problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor tensor_c_F32(problem_size.mn()); // <- Create matrix C with dimensions M x N -+ cutlass::HostTensor tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N -+ -+ if (options.rand_mode == "uniform") { -+ const float min = -1; -+ const float max = 1; -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a_F32.host_view(), -+ options.seed, -+ double(max), -+ double(min)); // <- Fill matrix A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b_F32.host_view(), -+ options.seed, -+ double(max), -+ double(min)); // <- Fill matrix B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c_F32.host_view(), -+ options.seed, -+ double(max), -+ double(min)); // <- Fill matrix C on host with uniform-distribution random data -+ } else if (options.rand_mode == "gauss") { -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomGaussian( -+ tensor_a_F32.host_view(), -+ options.seed, -+ double(0), -+ double(5)); // <- Fill matrix A on host with gaussian-distribution random data -+ cutlass::reference::host::TensorFillRandomGaussian( -+ tensor_b_F32.host_view(), -+ options.seed, -+ double(0), -+ double(5)); // <- Fill matrix B on host with 
gaussian-distribution random data -+ cutlass::reference::host::TensorFillRandomGaussian( -+ tensor_c_F32.host_view(), -+ options.seed, -+ double(0), -+ double(5)); // <- Fill matrix C on host with gaussian-distribution random data -+ } -+ cutlass::reference::host::TensorFill( -+ tensor_d_F32.host_view()); // <- fill matrix D on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a_F32.sync_device(); -+ tensor_b_F32.sync_device(); -+ tensor_c_F32.sync_device(); -+ tensor_d_F32.sync_device(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 2. Initialize F64 tensors using the same values used for F32 -+ //////////////////////////////////////////////////////////////////////////////// -+ // Gemm input operands (A, B, C) -+ cutlass::HostTensor tensor_a_F64(problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor tensor_b_F64(problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor tensor_c_F64(problem_size.mn()); // <- Create matrix C with dimensions M x N -+ -+ // Gemm output (D) for GEMM_F64 -+ cutlass::HostTensor tensor_d_F64(problem_size.mn()); // <- Create matrix D with dimensions M x N -+ // Gemm output (D) for GEMM_3xTF32 -+ cutlass::HostTensor tensor_d_3xTF32(problem_size.mn()); // <- Create matrix D with dimensions M x N -+ // Gemm output (D) for GEMM_1xTF32 -+ cutlass::HostTensor tensor_d_1xTF32(problem_size.mn()); // <- Create matrix D with dimensions M x N -+ -+ // Copy values from the DP tensors -+ cutlass::reference::host::TensorCopy(tensor_a_F64.host_view(), tensor_a_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_b_F64.host_view(), tensor_b_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_c_F64.host_view(), tensor_c_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_F64.host_view(), tensor_d_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_3xTF32.host_view(), tensor_d_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_1xTF32.host_view(), tensor_d_F32.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a_F64.sync_device(); -+ tensor_b_F64.sync_device(); -+ tensor_c_F64.sync_device(); -+ tensor_d_F64.sync_device(); -+ tensor_d_3xTF32.sync_device(); -+ tensor_d_1xTF32.sync_device(); -+ -+ // Initialize alpha and beta for dot product computation -+ float alpha = float(options.alpha); -+ float beta = float(options.beta); -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 3. Run 3xTF32 kernel within a profiling loop -+ //////////////////////////////////////////////////////////////////////////////// -+ // Create a tuple of gemm kernel arguments. 
This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm_3xTF32::Arguments arguments_3xtf32{problem_size, // <- problem size of matrix multiplication -+ tensor_a_F32.device_ref(), // <- reference to matrix A on device -+ tensor_b_F32.device_ref(), // <- reference to matrix B on device -+ tensor_c_F32.device_ref(), // <- reference to matrix C on device -+ tensor_d_3xTF32.device_ref(), // <- reference to matrix D on device -+ {alpha, beta}, // <- tuple of alpha and beta -+ split_k_slices}; // <- k-dimension split factor -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size_3xtf32 = Gemm_3xTF32::get_workspace_size(arguments_3xtf32); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace_3xtf32(workspace_size_3xtf32); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm_3xTF32 gemm_op_3xTF32; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status_3xtf32 = gemm_op_3xTF32.can_implement(arguments_3xtf32); -+ CUTLASS_CHECK(status_3xtf32); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status_3xtf32 = gemm_op_3xTF32.initialize(arguments_3xtf32, workspace_3xtf32.get()); -+ CUTLASS_CHECK(status_3xtf32); -+ -+ // Result structure -+ Result result; -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return false; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMMs -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return false; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ // Launch initialized CUTLASS kernel -+ status_3xtf32 = gemm_op_3xTF32(); -+ CUTLASS_CHECK(status_3xtf32); -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMMs are complete -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return false; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return false; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return false; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.m = problem_size.m(); -+ result.n = problem_size.n(); -+ result.k = problem_size.k(); -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ tensor_d_3xTF32.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 4. 
Run TF32 kernel without profiling loop -+ //////////////////////////////////////////////////////////////////////////////// -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm_1xTF32::Arguments arguments_1xtf32{problem_size, // <- problem size of matrix multiplication -+ tensor_a_F32.device_ref(), // <- reference to matrix A on device -+ tensor_b_F32.device_ref(), // <- reference to matrix B on device -+ tensor_c_F32.device_ref(), // <- reference to matrix C on device -+ tensor_d_1xTF32.device_ref(), // <- reference to matrix D on device -+ {alpha, beta}, // <- tuple of alpha and beta -+ split_k_slices}; // <- k-dimension split factor -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size_1xtf32 = Gemm_1xTF32::get_workspace_size(arguments_1xtf32); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace_1xtf32(workspace_size_1xtf32); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm_1xTF32 gemm_op_1xtf32; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status_1xtf32 = gemm_op_1xtf32.can_implement(arguments_1xtf32); -+ CUTLASS_CHECK(status_1xtf32); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status_1xtf32 = gemm_op_1xtf32.initialize(arguments_1xtf32, workspace_1xtf32.get()); -+ CUTLASS_CHECK(status_1xtf32); -+ -+ // Launch initialized CUTLASS kernel -+ status_1xtf32 = gemm_op_1xtf32(); -+ CUTLASS_CHECK(status_1xtf32); -+ -+ tensor_d_1xTF32.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ // Run reference kernel (F64) -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ // Create instantiation for device reference gemm kernel -+ Gemm_F64 gemm_f64; -+ -+ // Launch device reference gemm kernel -+ gemm_f64(problem_size, -+ alpha, -+ tensor_a_F64.device_ref(), -+ tensor_b_F64.device_ref(), -+ beta, -+ tensor_c_F64.device_ref(), -+ tensor_d_F64.device_ref()); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d_F64.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ // Run reference kernel (F32) -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ // Create instantiation for device reference gemm kernel -+ Gemm_F32 gemm_f32; -+ -+ // Launch device reference gemm kernel -+ gemm_f32(problem_size, -+ alpha, -+ tensor_a_F32.device_ref(), -+ tensor_b_F32.device_ref(), -+ beta, -+ tensor_c_F32.device_ref(), -+ tensor_d_F32.device_ref()); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d_F32.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /////// Compute l2 norms -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ // l2 norm 3xTF32 vs F64 -+ cutlass::HostTensor tensor_d_3xTF32_in_F64(problem_size.mn()); -+ cutlass::reference::host::TensorCopy(tensor_d_3xTF32_in_F64.host_view(), tensor_d_3xTF32.host_view()); -+ -+ result.l2_norm_3xtf32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_3xTF32_in_F64.host_view(), tensor_d_F64.host_view()); -+ -+ // 
l2 norm 1xTF32 vs F64 -+ cutlass::HostTensor tensor_d_1xTF32_in_F64(problem_size.mn()); -+ cutlass::reference::host::TensorCopy(tensor_d_1xTF32_in_F64.host_view(), tensor_d_1xTF32.host_view()); -+ -+ result.l2_norm_1xtf32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_1xTF32_in_F64.host_view(), tensor_d_F64.host_view()); -+ -+ // l2 norm F32 vs F64 -+ cutlass::HostTensor tensor_d_F32_in_F64(problem_size.mn()); -+ cutlass::reference::host::TensorCopy(tensor_d_F32_in_F64.host_view(), tensor_d_F32.host_view()); -+ -+ result.l2_norm_fp32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_F32_in_F64.host_view(), tensor_d_F64.host_view()); -+ -+ results.push_back(result); -+ -+ /////////////////////////////////////////////////////////////////////////////// -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ -+ std::cout << std::fixed; -+ std::cout.precision(4); -+ std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout.precision(2); -+ std::cout << "GFLOPs: " << result.gflops << std::endl; -+ std::cout << "Normalized L2 norm of" << std::endl; -+ std::cout.precision(8); -+ std::cout << std::scientific -+ << " - 3xTF32 error with FP64 reference : " << result.l2_norm_3xtf32_vs_fp64 << std::endl -+ << " - 1xTF32 error with FP64 reference : " << result.l2_norm_1xtf32_vs_fp64 << std::endl -+ << " - FP32 error with FP64 reference : " << result.l2_norm_fp32_vs_fp64 << std::endl; -+ -+ return true; -+} -+ -+int main(int argc, const char **argv) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available -+ // in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ >= 11)) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return false; -+ } -+ -+ if (!((props.major * 10 + props.minor) >= 80)) { -+ std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. -+ return 0; -+ } -+ -+ Options options; -+ options.parse(argc, argv); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ bool result = true; -+ -+ if (options.benchmark) { -+ for (int k = 4; k <= 65536; k *= 2) { -+ -+ options.problem_size[2] = k; -+ -+ printf("Gemm problem size: %d x %d x %d\n", \ -+ options.problem_size.m(), options.problem_size.n(), options.problem_size.k()); -+ -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ result &= run(options); -+ } -+ } else { -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." 
<< std::endl; -+ return -1; -+ } -+ -+ result = run(options); -+ } -+ -+ if (!result) return -1; -+ -+ std::cout << std::endl << "CSV results" << std::endl; -+ Result::print_csv_header(); -+ for(auto &r : results) -+ r.print_csv_row(); -+ -+ return 0; -+} -diff --git a/3rdparty/cutlass/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu b/3rdparty/cutlass/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu -new file mode 100644 -index 0000000..27286f9 ---- /dev/null -+++ b/3rdparty/cutlass/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu -@@ -0,0 +1,822 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+ -+This example adopts example 16 to use 3xTF32 to bring FP32 accuracy with 2x performance -+compared with CUDA Cores. See example 27 for the trick of 3xTF32. 
-+*/
-+
-+#include <iostream>
-+#include <fstream>
-+#include <sstream>
-+
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm.h"
-+#include "cutlass/conv/kernel/default_conv2d_fprop.h"
-+#include "cutlass/conv/device/implicit_gemm_convolution.h"
-+
-+#include "cutlass/util/command_line.h"
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/tensor_view_io.h"
-+#include "cutlass/util/reference/device/convolution.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/reference/host/convolution.h"
-+#include "cutlass/util/reference/host/error_metrics.h"
-+#include "cutlass/util/tensor_view_io.h"
-+
-+#include "helper.h"
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+// The code section below describes datatype for input, output tensors and computation between
-+// elements
-+using ElementAccumulator = float;                  // Data type of accumulator
-+using ElementComputeEpilogue = float;              // Data type of epilogue computation (alpha, beta)
-+using ElementInputA = float;                       // Data type of elements in input tensor
-+using ElementInputB = float;                       // Data type of elements in input tensor
-+using ElementOutput = float;                       // Data type of elements in output tensor
-+
-+using LayoutInputA = cutlass::layout::TensorNHWC;
-+using LayoutInputB = cutlass::layout::TensorNHWC;
-+using LayoutOutput = cutlass::layout::TensorNHWC;
-+
-+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
-+using MMAOp = cutlass::arch::OpClassTensorOp;
-+
-+// This code section describes CUDA SM architecture number
-+using SmArch = cutlass::arch::Sm80;
-+
-+// This code section describes the tile size a thread block will compute
-+using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 16>;  // Threadblock tile shape
-+
-+// This code section describes tile size a warp will compute
-+using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>;          // Warp tile shape
-+
-+// This code section describes the size of MMA op
-+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;     // TensorCore instruction shape
-+
-+// This code section describes how threadblocks are scheduled on GPU
-+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
-+
-+// Number of pipelines you want to use
-+constexpr int NumStages = 3;
-+
-+// This code section describes which iterator algorithm is selected: Analytic or Optimized
-+static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized;
-+
-+// This code section describes the epilogue part of the kernel; we use the default value
-+using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
-+    ElementOutput,                                     // Data type of output matrix.
-+    128 / cutlass::sizeof_bits<ElementOutput>::value,  // The number of elements per vectorized
-+                                                       // memory access. This becomes the vector width of
-+                                                       // math instructions in the epilogue too.
-+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; // Data type for alpha/beta in linear combination -+ -+// 3xTF32 Fprop -+using Conv2dFpropKernel_3xTF32 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ // Only thing needs to be changed from normal Fprop -+ cutlass::arch::OpMultiplyAddFastF32, -+ IteratorAlgorithm -+>::Kernel; -+ -+// 1xTF32 Fprop -+using Conv2dFpropKernel_1xTF32 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm -+>::Kernel; -+ -+using ImplicitGemm_3xTF32 = cutlass::conv::device::ImplicitGemmConvolution; -+using ImplicitGemm_1xTF32 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::Tensor4DCoord input_size; -+ cutlass::Tensor4DCoord filter_size; -+ cutlass::Tensor4DCoord padding; -+ cutlass::MatrixCoord conv_stride; -+ cutlass::MatrixCoord dilation; -+ int iterations; -+ bool save_workspace; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ bool benchmark; -+ std::string tag; -+ -+ Options(): -+ help(false), -+ input_size(1, 32, 32, 32), -+ filter_size(32, 3, 3, 32), -+ padding(1, 1, 1, 1), -+ conv_stride(1, 1), -+ dilation(1, 1), -+ iterations(20), -+ save_workspace(false), -+ alpha(1), -+ beta(0), -+ benchmark(false) { } -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. 
-+ // -+ int const kAlignment = 4; -+ -+ if ((input_size.c() % kAlignment) || -+ (filter_size.n() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ // Invalid padding -+ if ((padding.h() != filter_size.h() / 2) || -+ (padding.w() != filter_size.w() / 2)) { -+ -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Updates input and filter sizes -+ void update( -+ cutlass::Tensor4DCoord input_size, -+ cutlass::Tensor4DCoord filter_size) { -+ -+ this->input_size = input_size; -+ this->filter_size = filter_size; -+ -+ padding.n() = filter_size.h() / 2; -+ padding.h() = filter_size.h() / 2; -+ padding.w() = filter_size.w() / 2; -+ padding.c() = filter_size.w() / 2; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("save-workspace")) { -+ save_workspace = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("benchmark")) { -+ benchmark = true; -+ } -+ -+ cmd.get_cmd_line_argument("n", input_size.n()); -+ cmd.get_cmd_line_argument("h", input_size.h()); -+ cmd.get_cmd_line_argument("w", input_size.w()); -+ cmd.get_cmd_line_argument("c", input_size.c()); -+ -+ cmd.get_cmd_line_argument("k", filter_size.n()); -+ cmd.get_cmd_line_argument("r", filter_size.h()); -+ cmd.get_cmd_line_argument("s", filter_size.w()); -+ filter_size.c() = input_size.c(); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ -+ if (filter_size.h() == 3 && filter_size.w() == 3) { -+ padding = {1, 1, 1, 1}; -+ } -+ else { -+ filter_size.h() = 1; -+ filter_size.w() = 1; -+ padding = {0, 0, 0, 0}; -+ } -+ } -+ -+ /// Prints the usage statement. 
-+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "28_ampere_3xtf32_fast_accurate_tensorop_fprop example\n\n" -+ << " This example uses Ampere's Tensor Core operators on F16 data types to compute\n" -+ << " forward convolution on tensors of layout NHWC.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --n= Input tensor extent N\n" -+ << " --h= Input tensor extent H\n" -+ << " --w= Input tensor extent W\n" -+ << " --c= Input tensor extent C\n" -+ << " --k= Filter extent K\n" -+ << " --r= Filter extent R\n" -+ << " --s= Filter extent S\n\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --save-workspace If set, workspace is written to a text file.\n" -+ << " --tag= String to replicate across the first column in the results table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/28_ampere_3xtf32_fast_accurate_tensorop_fprop --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n" -+ << "$ ./examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/28_ampere_3xtf32_fast_accurate_tensorop_fprop --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n"; -+ -+ return out; -+ } -+ -+ /// Computes the output tensor size (NPQK) -+ cutlass::Tensor4DCoord output_size() const { -+ return cutlass::Tensor4DCoord( -+ input_size.n(), -+ (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, -+ (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, -+ filter_size.n()); -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of multiply-adds = NPQK * CRS -+ int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ -+ double l2_norm_3xtf32_vs_fp64; -+ double l2_norm_1xtf32_vs_fp64; -+ double l2_norm_fp32_vs_fp64; -+ -+ Result(): -+ runtime_ms(0), -+ gflops(0), -+ status(cutlass::Status::kSuccess), -+ error(cudaSuccess), -+ l2_norm_3xtf32_vs_fp64(0), -+ l2_norm_1xtf32_vs_fp64(0), -+ l2_norm_fp32_vs_fp64(0) { } -+ -+ static std::ostream & print_header(std::ostream &out, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "Layer,N,H,W,C,K,R,S,Runtime,GFLOPs,3xTF32_vs_FP64,1xTF32_vs_FP64,FP32_vs_FP64"; -+ -+ return out; -+ } -+ -+ std::ostream & print(std::ostream &out, int idx, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ out -+ << "conv_" << idx << "," -+ << options.input_size.n() << "," -+ << options.input_size.h() << "," -+ << options.input_size.w() << "," -+ << options.input_size.c() << "," -+ << options.filter_size.n() << "," -+ << options.filter_size.h() << "," -+ << options.filter_size.w() << "," -+ << runtime_ms << "," -+ << gflops << "," -+ << l2_norm_3xtf32_vs_fp64 << "," -+ << l2_norm_1xtf32_vs_fp64 << "," -+ << l2_norm_fp32_vs_fp64; -+ -+ return out; -+ } -+}; -+ 
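The 3xTF32 trick used by this example (and explained in the header comment of example 27) splits each FP32 operand into a TF32-rounded "big" part and an FP32 residual "small" part, then accumulates three TF32 products. A minimal host-side sketch of that decomposition follows; it is not part of the patch, the helper name to_tf32_truncate is illustrative only, and it uses simple mantissa truncation where CUTLASS's TF32 conversion rounds.

// Host-side sketch: emulates a*b ~= a_big*b_big + a_big*b_small + a_small*b_big,
// where "big" keeps only TF32's 10 mantissa bits and "small" is the residual.
// The a_small*b_small term is dropped because it is far below FP32 precision.
#include <cstdint>
#include <cstring>
#include <cstdio>

// Keep the sign, exponent, and top 10 mantissa bits of an FP32 value.
float to_tf32_truncate(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFFE000u;  // clear the low 13 mantissa bits
  std::memcpy(&x, &bits, sizeof(bits));
  return x;
}

int main() {
  float a = 1.2345678f, b = 7.6543210f;

  float a_big = to_tf32_truncate(a), a_small = a - a_big;
  float b_big = to_tf32_truncate(b), b_small = b - b_big;

  double exact      = double(a) * double(b);
  double one_tf32   = double(a_big) * double(b_big);
  double three_tf32 = double(a_big) * double(b_big)
                    + double(a_big) * double(b_small)
                    + double(a_small) * double(b_big);

  std::printf("exact=%.9f  1xTF32=%.9f  3xTF32=%.9f\n", exact, one_tf32, three_tf32);
  return 0;
}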
-+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one benchmark -+Result profile_convolution(Options const &options) { -+ -+ Result result; -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 1. Initialize F32 Precision input tensors using CUTLASS helper functions -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ // -+ // Allocate host-device tensors using the CUTLASS Utilities. -+ // -+ -+ cutlass::HostTensor tensor_a_F32(options.input_size); -+ cutlass::HostTensor tensor_b_F32(options.filter_size); -+ cutlass::HostTensor tensor_c_F32(options.output_size()); -+ cutlass::HostTensor tensor_d_F32(options.output_size()); -+ -+ // -+ // Initialize tensors -+ // -+ -+ // Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a_F32.host_view(), -+ 1, -+ ElementInputA(7), -+ ElementInputA(-8)); -+ -+ // Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b_F32.host_view(), -+ 1, -+ ElementInputB(7), -+ ElementInputB(-8)); -+ -+ // Fill tensor C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c_F32.host_view(), -+ 1, -+ ElementInputB(7), -+ ElementInputB(-8)); -+ -+ // Fill tensor D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_d_F32.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a_F32.sync_device(); -+ tensor_b_F32.sync_device(); -+ tensor_c_F32.sync_device(); -+ tensor_d_F32.sync_device(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 2. Initialize F32 Precision input tensors using CUTLASS helper functions -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ // -+ // Allocate host-device tensors using the CUTLASS Utilities. 
-+ // -+ -+ cutlass::HostTensor tensor_a_F64(options.input_size); -+ cutlass::HostTensor tensor_b_F64(options.filter_size); -+ cutlass::HostTensor tensor_c_F64(options.output_size()); -+ -+ cutlass::HostTensor tensor_d_F64(options.output_size()); -+ cutlass::HostTensor tensor_d_3xTF32(options.output_size()); -+ cutlass::HostTensor tensor_d_1xTF32(options.output_size()); -+ -+ // Copy values from the DP tensors -+ cutlass::reference::host::TensorCopy(tensor_a_F64.host_view(), tensor_a_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_b_F64.host_view(), tensor_b_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_c_F64.host_view(), tensor_c_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_F64.host_view(), tensor_d_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_3xTF32.host_view(), tensor_d_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_1xTF32.host_view(), tensor_d_F32.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a_F64.sync_device(); -+ tensor_b_F64.sync_device(); -+ tensor_c_F64.sync_device(); -+ tensor_d_F64.sync_device(); -+ tensor_d_3xTF32.sync_device(); -+ tensor_d_1xTF32.sync_device(); -+ -+ // -+ // Define arguments for CUTLASS Convolution -+ // -+ -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Construct Conv2dProblemSize with user defined output size -+ cutlass::conv::Conv2dProblemSize problem_size( -+ options.input_size, -+ options.filter_size, -+ options.padding, -+ options.conv_stride, -+ options.dilation, -+ options.output_size(), -+ mode, -+ split_k_slices -+ ); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 3. Run 3xTF32 kernel within a profiling loop -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ // Construct ImplicitGemm::Argument structure with conv2d -+ // problem size, data pointers, and epilogue values -+ typename ImplicitGemm_3xTF32::Arguments arguments_3xTF32{ -+ problem_size, -+ tensor_a_F32.device_ref(), -+ tensor_b_F32.device_ref(), -+ tensor_c_F32.device_ref(), -+ tensor_d_3xTF32.device_ref(), -+ {options.alpha, options.beta}, -+ }; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ ImplicitGemm_3xTF32 implicit_gemm_op_3xTF32; -+ -+ size_t workspace_size_3xTF32 = implicit_gemm_op_3xTF32.get_workspace_size(arguments_3xTF32); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace_3xTF32(workspace_size_3xTF32); -+ -+ result.status = implicit_gemm_op_3xTF32.can_implement(arguments_3xTF32); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = implicit_gemm_op_3xTF32.initialize(arguments_3xTF32, workspace_3xTF32.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Launch initialized CUTLASS kernel -+ // -+ result.status = implicit_gemm_op_3xTF32(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Performance measurement -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. 
-+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = implicit_gemm_op_3xTF32(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ tensor_d_3xTF32.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 4. Run 1xTF32 kernel within a profiling loop -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ // Construct ImplicitGemm::Argument structure with conv2d -+ // problem size, data pointers, and epilogue values -+ typename ImplicitGemm_1xTF32::Arguments arguments_1xTF32{ -+ problem_size, -+ tensor_a_F32.device_ref(), -+ tensor_b_F32.device_ref(), -+ tensor_c_F32.device_ref(), -+ tensor_d_1xTF32.device_ref(), -+ {options.alpha, options.beta}, -+ }; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ ImplicitGemm_1xTF32 implicit_gemm_op_1xTF32; -+ -+ size_t workspace_size_1xTF32 = implicit_gemm_op_1xTF32.get_workspace_size(arguments_1xTF32); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace_1xTF32(workspace_size_1xTF32); -+ -+ result.status = implicit_gemm_op_1xTF32.can_implement(arguments_1xTF32); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = implicit_gemm_op_1xTF32.initialize(arguments_1xTF32, workspace_1xTF32.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Launch initialized CUTLASS kernel -+ // -+ result.status = implicit_gemm_op_1xTF32(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ tensor_d_1xTF32.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ // Run reference kernel (F64) -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ cutlass::reference::device::Conv2d< -+ double, -+ LayoutInputA, -+ double, -+ LayoutInputB, -+ double, -+ LayoutOutput, -+ double, -+ double -+ >( -+ cutlass::conv::Operator::kFprop, -+ problem_size, -+ tensor_a_F64.device_ref(), -+ tensor_b_F64.device_ref(), -+ tensor_c_F64.device_ref(), -+ tensor_d_F64.device_ref(), -+ options.alpha, -+ options.beta); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference 
kernel to host for comparison -+ tensor_d_F64.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ // Run reference kernel (F32) -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ cutlass::reference::device::Conv2d< -+ float, -+ LayoutInputA, -+ float, -+ LayoutInputB, -+ float, -+ LayoutOutput, -+ float, -+ float -+ >( -+ cutlass::conv::Operator::kFprop, -+ problem_size, -+ tensor_a_F32.device_ref(), -+ tensor_b_F32.device_ref(), -+ tensor_c_F32.device_ref(), -+ tensor_d_F32.device_ref(), -+ options.alpha, -+ options.beta); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d_F32.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /////// Compute l2 norms -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ // l2 norm 3xTF32 vs F64 -+ cutlass::HostTensor tensor_d_3xTF32_in_F64(options.output_size()); -+ cutlass::reference::host::TensorCopy(tensor_d_3xTF32_in_F64.host_view(), tensor_d_3xTF32.host_view()); -+ -+ result.l2_norm_3xtf32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_3xTF32_in_F64.host_view(), tensor_d_F64.host_view()); -+ -+ // l2 norm 1xTF32 vs F64 -+ cutlass::HostTensor tensor_d_1xTF32_in_F64(options.output_size()); -+ cutlass::reference::host::TensorCopy(tensor_d_1xTF32_in_F64.host_view(), tensor_d_1xTF32.host_view()); -+ -+ result.l2_norm_1xtf32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_1xTF32_in_F64.host_view(), tensor_d_F64.host_view()); -+ -+ // l2 norm F32 vs F64 -+ cutlass::HostTensor tensor_d_F32_in_F64(options.output_size()); -+ cutlass::reference::host::TensorCopy(tensor_d_F32_in_F64.host_view(), tensor_d_F32.host_view()); -+ -+ result.l2_norm_fp32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_F32_in_F64.host_view(), tensor_d_F64.host_view()); -+ -+ /////////////////////////////////////////////////////////////////////////////// -+ -+ if (options.save_workspace) { -+ -+ std::stringstream ss; -+ -+ ss << "28_ampere_3xtf32_fast_accurate_tensorop_fprop_" -+ << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() -+ << "_" -+ << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() -+ << ".dat"; -+ -+ std::ofstream output_workspace(ss.str()); -+ -+ output_workspace -+ << "Input = \n" << tensor_a_F32.host_view() << "\n\n" -+ << "Filters = \n" << tensor_b_F32.host_view() << "\n\n"; -+ -+ output_workspace << "TF32x3 = \n" << tensor_d_3xTF32.host_view() << std::endl; -+ output_workspace << "TF32x1 = \n" << tensor_d_1xTF32.host_view() << std::endl; -+ output_workspace << "FP32 = \n" << tensor_d_F32.host_view() << std::endl; -+ output_workspace << "FP64 = \n" << tensor_d_F64.host_view() << "\n\n"; -+ -+ std::cout << "Results written to '" << ss.str() << "'." << std::endl; -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. 
-+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major >= 8)) { -+ std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.benchmark) { -+ // Benchmark several layers -+ -+ int batch_sizes[] = {1, 32, 64, 128, 256}; -+ -+ struct Benchmark { -+ int h, w, c, k, r, s; -+ } layers[] = { -+ {56, 56, 64, 256, 1, 1}, -+ {56, 56, 64, 64, 1, 1}, -+ {56, 56, 64, 64, 3, 3}, -+ {56, 56, 256, 64, 1, 1}, -+ {56, 56, 256, 512, 1, 1}, -+ {56, 56, 256, 128, 1, 1}, -+ {28, 28, 128, 128, 3, 3}, -+ {28, 28, 128, 512, 1, 1}, -+ {28, 28, 512, 128, 1, 1}, -+ {28, 28, 512, 1024, 1, 1}, -+ {28, 28, 512, 256, 1, 1}, -+ {14, 14, 256, 256, 3, 3}, -+ {14, 14, 256, 1024, 1, 1}, -+ {14, 14, 1024, 256, 1, 1}, -+ {14, 14, 1024, 2048, 1, 1}, -+ {14, 14, 1024, 512, 1, 1}, -+ {7, 7, 512, 512, 3, 3}, -+ }; -+ -+ Result::print_header(std::cout, options) << std::endl; -+ -+ int idx = 1; -+ -+ for (auto const &layer : layers) { -+ for (auto N : batch_sizes) { -+ -+ options.update({N, layer.h, layer.w, layer.c}, {layer.k, layer.r, layer.s, layer.c}); -+ -+ Result result = profile_convolution(options); -+ result.print(std::cout, idx, options) << std::endl; -+ } -+ -+ ++idx; -+ } -+ } -+ else { -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ Result result = profile_convolution(options); -+ -+ Result::print_header(std::cout, options) << std::endl; -+ result.print(std::cout, 1, options) << std::endl; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu b/3rdparty/cutlass/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu -new file mode 100644 -index 0000000..fc6f6af ---- /dev/null -+++ b/3rdparty/cutlass/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu -@@ -0,0 +1,692 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+ This example is almost the same as example 27 which uses 3xTF32 to run GEMM. The only -+ difference is that this example uses 3xtf32 on complex gemm. -+ -+ To enable this feature, the only change needs to make is to change OpMultiplyAddComplex -+ to OpMultiplyAddComplexFastF32. -+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+ -+#include "cutlass/util/reference/device/gemm_complex.h" -+#include "cutlass/util/reference/host/tensor_reduce.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/error_metrics.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ -+ int m, n, k; -+ double l2_norm_3xtf32_vs_fp64; -+ double l2_norm_1xtf32_vs_fp64; -+ double l2_norm_fp32_vs_fp64; -+ -+ // ctor -+ Result( -+ int m, int n, int k, -+ double runtime_ms, double gflops, -+ double l2_norm_3xtf32_vs_fp64, -+ double l2_norm_1xtf32_vs_fp64, -+ double l2_norm_fp32_vs_fp64) : -+ m(m), n(n), k(k), -+ runtime_ms(runtime_ms), gflops(gflops), -+ l2_norm_3xtf32_vs_fp64(l2_norm_3xtf32_vs_fp64), -+ l2_norm_1xtf32_vs_fp64(l2_norm_1xtf32_vs_fp64), -+ l2_norm_fp32_vs_fp64(l2_norm_fp32_vs_fp64) {} -+ -+ Result() {} -+ -+ // -+ // Methods -+ // -+ static void print_csv_header() { -+ std::cout << "M,N,K,Runtime(ms),GFLOPS,3xTF32_vs_FP64,1xTF32_vs_FP64,FP32_vs_FP64" << std::endl; -+ } -+ -+ void print_csv_row() { -+ std::cout << m << "," -+ << n << "," -+ << k << "," -+ << runtime_ms << "," -+ << gflops << "," -+ << l2_norm_3xtf32_vs_fp64 << "," -+ << l2_norm_1xtf32_vs_fp64 << "," -+ << l2_norm_fp32_vs_fp64 << std::endl; -+ } -+}; -+ -+std::vector results; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ float alpha; -+ float 
beta; -+ std::string rand_mode; -+ -+ int iterations; -+ int seed; -+ bool benchmark; -+ -+ Options(): -+ help(false), -+ problem_size({3456, 4096, 4096}), -+ iterations(20), -+ seed(1), -+ alpha(1), -+ beta(), -+ rand_mode("uniform"), -+ benchmark(false) { } -+ -+ bool valid() { -+ return true; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("m", problem_size.m()); -+ cmd.get_cmd_line_argument("n", problem_size.n()); -+ cmd.get_cmd_line_argument("k", problem_size.k()); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("seed", seed); -+ cmd.get_cmd_line_argument("rand_mode", rand_mode); -+ -+ if (cmd.check_cmd_line_flag("benchmark")) { -+ benchmark = true; -+ } -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm example\n\n" -+ << " This example uses the CUTLASS Library to emulate FP32 complex GEMM computations with TF32 tensor cores.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --m= GEMM M dimension\n" -+ << " --n= GEMM N dimension\n" -+ << " --k= GEMM K dimension\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --rand_mode= gauss / uniform*\n\n" -+ << " --seed= Random number seed (1*)\n\n" -+ << " --iterations= Number of profiling iterations to perform.\n\n" -+ << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_complex_gemm --m=1024 --n=512 \\\n" -+ << " --alpha=2 --beta=0.707 \n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fmas = problem_size.product(); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// The code section below describes matrix layout of input and output matrices. 
Column Major for -+// Matrix A, Row Major for Matrix B and Row Major for Matrix C -+using LayoutInputA = cutlass::layout::ColumnMajor; -+using LayoutInputB = cutlass::layout::RowMajor; -+using LayoutOutput = cutlass::layout::RowMajor; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ShapeMMAThreadBlock = -+ cutlass::gemm::GemmShape<64, 64, 16>; // <- threadblock tile M = 128, N = 128, K = 16 -+// This code section describes tile size a warp will compute -+using ShapeMMAWarp = cutlass::gemm::GemmShape<32, 32, 16>; // <- warp tile M = 64, N = 64, K = 16 -+// This code section describes the size of MMA op -+using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? -+ -+// This code section describes the epilogue part of the kernel -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ cutlass::complex, // <- data type of output matrix -+ 1, // <- the number of elements per vectorized -+ // memory access. For a byte, it's 16 -+ // elements. This becomes the vector width of -+ // math instructions in the epilogue too -+ cutlass::complex, // <- data type of accumulator -+ cutlass::complex>; // <- data type for alpha/beta in linear combination function -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 3; -+// Transform -+constexpr cutlass::ComplexTransform TransformA = cutlass::ComplexTransform::kNone; -+constexpr cutlass::ComplexTransform TransformB = cutlass::ComplexTransform::kNone; -+ -+// -+// Gemm Operators (Gemm_3xTF32, Gemm_1xTF32, GEMM_F32, GEMM_F64) -+// -+ -+// Gemm_3xTF32 -+using Gemm_3xTF32 = cutlass::gemm::device::GemmComplex< -+ cutlass::complex, -+ LayoutInputA, -+ cutlass::complex, -+ LayoutInputB, -+ cutlass::complex, -+ LayoutOutput, -+ cutlass::complex, -+ MMAOp, -+ SmArch, -+ ShapeMMAThreadBlock, -+ ShapeMMAWarp, -+ ShapeMMAOp, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ TransformA, -+ TransformB, -+ cutlass::arch::OpMultiplyAddComplexFastF32>; -+ -+// Gemm_1xTF32 -+using Gemm_1xTF32 = cutlass::gemm::device::GemmComplex< -+ cutlass::complex, -+ LayoutInputA, -+ cutlass::complex, -+ LayoutInputB, -+ cutlass::complex, -+ LayoutOutput, -+ cutlass::complex, -+ MMAOp, -+ SmArch, -+ ShapeMMAThreadBlock, -+ ShapeMMAWarp, -+ ShapeMMAOp, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ TransformA, -+ TransformB, -+ cutlass::arch::OpMultiplyAddComplex>; -+ -+bool run(Options &options) { -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size = options.problem_size; -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 1. 
Initialize F32 Precision input tensors using CUTLASS helper functions -+ //////////////////////////////////////////////////////////////////////////////// -+ cutlass::HostTensor, LayoutInputA> tensor_a_F32(problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor, LayoutInputB> tensor_b_F32(problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor, LayoutOutput> tensor_c_F32(problem_size.mn()); // <- Create matrix C with dimensions M x N -+ cutlass::HostTensor, LayoutOutput> tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N -+ -+ if (options.rand_mode == "uniform") { -+ const float min = -1; -+ const float max = 1; -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a_F32.host_view(), -+ options.seed, -+ double(max), -+ double(min)); // <- Fill matrix A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b_F32.host_view(), -+ options.seed, -+ double(max), -+ double(min)); // <- Fill matrix B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c_F32.host_view(), -+ options.seed, -+ double(max), -+ double(min)); // <- Fill matrix C on host with uniform-distribution random data -+ } else if (options.rand_mode == "gauss") { -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomGaussian( -+ tensor_a_F32.host_view(), -+ options.seed, -+ double(0), -+ double(5)); // <- Fill matrix A on host with gaussian-distribution random data -+ cutlass::reference::host::TensorFillRandomGaussian( -+ tensor_b_F32.host_view(), -+ options.seed, -+ double(0), -+ double(5)); // <- Fill matrix B on host with gaussian-distribution random data -+ cutlass::reference::host::TensorFillRandomGaussian( -+ tensor_c_F32.host_view(), -+ options.seed, -+ double(0), -+ double(5)); // <- Fill matrix C on host with gaussian-distribution random data -+ } -+ cutlass::reference::host::TensorFill( -+ tensor_d_F32.host_view()); // <- fill matrix D on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a_F32.sync_device(); -+ tensor_b_F32.sync_device(); -+ tensor_c_F32.sync_device(); -+ tensor_d_F32.sync_device(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 2. 
Initialize F64 tensors using the same values used for F32 -+ //////////////////////////////////////////////////////////////////////////////// -+ // Gemm input operands (A, B, C) -+ cutlass::HostTensor, LayoutInputA> tensor_a_F64(problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor, LayoutInputB> tensor_b_F64(problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor, LayoutOutput> tensor_c_F64(problem_size.mn()); // <- Create matrix C with dimensions M x N -+ -+ // Gemm output (D) for GEMM_F64 -+ cutlass::HostTensor, LayoutOutput> tensor_d_F64(problem_size.mn()); // <- Create matrix D with dimensions M x N -+ // Gemm output (D) for GEMM_3xTF32 -+ cutlass::HostTensor, LayoutOutput> tensor_d_3xTF32(problem_size.mn()); // <- Create matrix D with dimensions M x N -+ // Gemm output (D) for GEMM_1xTF32 -+ cutlass::HostTensor, LayoutOutput> tensor_d_1xTF32(problem_size.mn()); // <- Create matrix D with dimensions M x N -+ -+ // Copy values from the DP tensors -+ cutlass::reference::host::TensorCopy(tensor_a_F64.host_view(), tensor_a_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_b_F64.host_view(), tensor_b_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_c_F64.host_view(), tensor_c_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_F64.host_view(), tensor_d_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_3xTF32.host_view(), tensor_d_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_1xTF32.host_view(), tensor_d_F32.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a_F64.sync_device(); -+ tensor_b_F64.sync_device(); -+ tensor_c_F64.sync_device(); -+ tensor_d_F64.sync_device(); -+ tensor_d_3xTF32.sync_device(); -+ tensor_d_1xTF32.sync_device(); -+ -+ // Initialize alpha and beta for dot product computation -+ cutlass::complex alpha = cutlass::complex(options.alpha); -+ cutlass::complex beta = cutlass::complex(options.beta); -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 3. Run 3xTF32 kernel within a profiling loop -+ //////////////////////////////////////////////////////////////////////////////// -+ // Create a tuple of gemm kernel arguments. 
This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm_3xTF32::Arguments arguments_3xtf32{problem_size, // <- problem size of matrix multiplication -+ tensor_a_F32.device_ref(), // <- reference to matrix A on device -+ tensor_b_F32.device_ref(), // <- reference to matrix B on device -+ tensor_c_F32.device_ref(), // <- reference to matrix C on device -+ tensor_d_3xTF32.device_ref(), // <- reference to matrix D on device -+ {alpha, beta}, // <- tuple of alpha and beta -+ split_k_slices}; // <- k-dimension split factor -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size_3xtf32 = Gemm_3xTF32::get_workspace_size(arguments_3xtf32); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace_3xtf32(workspace_size_3xtf32); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm_3xTF32 gemm_op; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status_3xtf32 = gemm_op.can_implement(arguments_3xtf32); -+ CUTLASS_CHECK(status_3xtf32); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status_3xtf32 = gemm_op.initialize(arguments_3xtf32, workspace_3xtf32.get()); -+ CUTLASS_CHECK(status_3xtf32); -+ -+ // Result structure -+ Result result; -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return false; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMMs -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return false; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ // Launch initialized CUTLASS kernel -+ status_3xtf32 = gemm_op(); -+ CUTLASS_CHECK(status_3xtf32); -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMMs are complete -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return false; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return false; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return false; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.m = problem_size.m(); -+ result.n = problem_size.n(); -+ result.k = problem_size.k(); -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ tensor_d_3xTF32.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 4. 
Run TF32 kernel without profiling loop -+ //////////////////////////////////////////////////////////////////////////////// -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm_1xTF32::Arguments arguments_1xtf32{problem_size, // <- problem size of matrix multiplication -+ tensor_a_F32.device_ref(), // <- reference to matrix A on device -+ tensor_b_F32.device_ref(), // <- reference to matrix B on device -+ tensor_c_F32.device_ref(), // <- reference to matrix C on device -+ tensor_d_1xTF32.device_ref(), // <- reference to matrix D on device -+ {alpha, beta}, // <- tuple of alpha and beta -+ split_k_slices}; // <- k-dimension split factor -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size_1xtf32 = Gemm_1xTF32::get_workspace_size(arguments_1xtf32); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace_1xtf32(workspace_size_1xtf32); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm_1xTF32 gemm_op_1xtf32; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status_1xtf32 = gemm_op_1xtf32.can_implement(arguments_1xtf32); -+ CUTLASS_CHECK(status_1xtf32); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status_1xtf32 = gemm_op_1xtf32.initialize(arguments_1xtf32, workspace_1xtf32.get()); -+ CUTLASS_CHECK(status_1xtf32); -+ -+ // Launch initialized CUTLASS kernel -+ status_1xtf32 = gemm_op_1xtf32(); -+ CUTLASS_CHECK(status_1xtf32); -+ -+ tensor_d_1xTF32.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ // Run reference kernel (F64) -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ // Launch device reference gemm kernel -+ cutlass::reference::device::GemmComplex( -+ problem_size, -+ alpha, -+ tensor_a_F64.device_ref(), -+ TransformA, -+ tensor_b_F64.device_ref(), -+ TransformB, -+ beta, -+ tensor_c_F64.device_ref(), -+ tensor_d_F64.device_ref(), -+ cutlass::complex(0.f)); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d_F64.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ // Run reference kernel (F32) -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ // Launch device reference gemm kernel -+ cutlass::reference::device::GemmComplex( -+ problem_size, -+ alpha, -+ tensor_a_F32.device_ref(), -+ TransformA, -+ tensor_b_F32.device_ref(), -+ TransformB, -+ beta, -+ tensor_c_F32.device_ref(), -+ tensor_d_F32.device_ref(), -+ cutlass::complex(0.f)); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d_F32.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /////// Compute l2 norms -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ // l2 norm 3xTF32 vs F64 -+ cutlass::HostTensor, LayoutOutput> tensor_d_3xTF32_in_F64(problem_size.mn()); -+ cutlass::reference::host::TensorCopy(tensor_d_3xTF32_in_F64.host_view(), tensor_d_3xTF32.host_view()); -+ -+ result.l2_norm_3xtf32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_3xTF32_in_F64.host_view(), 
tensor_d_F64.host_view()); -+ -+ // l2 norm 1xTF32 vs F64 -+ cutlass::HostTensor, LayoutOutput> tensor_d_1xTF32_in_F64(problem_size.mn()); -+ cutlass::reference::host::TensorCopy(tensor_d_1xTF32_in_F64.host_view(), tensor_d_1xTF32.host_view()); -+ -+ result.l2_norm_1xtf32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_1xTF32_in_F64.host_view(), tensor_d_F64.host_view()); -+ -+ // l2 norm F32 vs F64 -+ cutlass::HostTensor, LayoutOutput> tensor_d_F32_in_F64(problem_size.mn()); -+ cutlass::reference::host::TensorCopy(tensor_d_F32_in_F64.host_view(), tensor_d_F32.host_view()); -+ -+ result.l2_norm_fp32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_F32_in_F64.host_view(), tensor_d_F64.host_view()); -+ -+ results.push_back(result); -+ -+ /////////////////////////////////////////////////////////////////////////////// -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ -+ std::cout << std::fixed; -+ std::cout.precision(4); -+ std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout.precision(2); -+ std::cout << "GFLOPs: " << result.gflops << std::endl; -+ std::cout << "Normalized L2 norm of" << std::endl; -+ std::cout.precision(8); -+ std::cout << std::scientific -+ << " - 3xTF32 error with FP64 reference : " << result.l2_norm_3xtf32_vs_fp64 << std::endl -+ << " - 1xTF32 error with FP64 reference : " << result.l2_norm_1xtf32_vs_fp64 << std::endl -+ << " - FP32 error with FP64 reference : " << result.l2_norm_fp32_vs_fp64 << std::endl; -+ -+ return true; -+} -+ -+int main(int argc, const char **argv) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available -+ // in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ >= 11)) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return false; -+ } -+ -+ if (!((props.major * 10 + props.minor) >= 80)) { -+ std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. -+ return 0; -+ } -+ -+ Options options; -+ options.parse(argc, argv); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ bool result = true; -+ -+ if (options.benchmark) { -+ for (int k = 4; k <= 65536; k *= 2) { -+ -+ options.problem_size[2] = k; -+ -+ printf("Gemm problem size: %d x %d x %d\n", \ -+ options.problem_size.m(), options.problem_size.n(), options.problem_size.k()); -+ -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ result &= run(options); -+ } -+ } else { -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." 
<< std::endl; -+ return -1; -+ } -+ -+ result = run(options); -+ } -+ -+ if (!result) return -1; -+ -+ std::cout << std::endl << "CSV results" << std::endl; -+ Result::print_csv_header(); -+ for(auto &r : results) -+ r.print_csv_row(); -+ -+ return 0; -+} -diff --git a/3rdparty/cutlass/examples/30_wgrad_split_k/30_wgrad_split_k.cu b/3rdparty/cutlass/examples/30_wgrad_split_k/30_wgrad_split_k.cu -new file mode 100644 -index 0000000..e512242 ---- /dev/null -+++ b/3rdparty/cutlass/examples/30_wgrad_split_k/30_wgrad_split_k.cu -@@ -0,0 +1,791 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/* -+This example shows how to compute conv2d gradient with respect to weight (wgrad). In wgrad, the K dimension of -+impligit GEMM, corresponding to the sequential reduction loop, is very large (N * P * Q). Split-k with parallel -+reduction is highly effective for such cases. Given split_k_slices parameter, it partitions the K loop into -+split_k_slices chunks and computes partial reductions in parallel across different blocks. After that, -+a parallel reduction kernel is launched to accumulate partial reductions. -+In practice, wgrad requires fp32 accumulation to avoid overflow. When the input is fp16, some care is needed -+to correctly instantiate the GEMM template. 
-+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_conv2d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/reduction/device/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+// In Wgrad, fp32 accumulation is necessary in practice. -+using ElementAccumulator = float; // Data type of accumulator -+using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) -+using ElementInputA = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputB = cutlass::half_t; // Data type of elements in input tensor -+using ElementOutput = cutlass::half_t; // Data type of elements in output tensor -+using ElementC = ElementOutput; -+using ElementCompute = ElementComputeEpilogue; -+using LayoutInputA = cutlass::layout::TensorNHWC; -+using LayoutInputB = cutlass::layout::TensorNHWC; -+using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 3; -+ -+// This code section describe iterator algorithm selected is Analytic or Optimized -+static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+// We need two epilogue functors - one for GEMM and another for the final reduction. -+// The epilogue for GEMM is not used, but needed to instantiate the CUTLASS kernel template. -+// Note that, when the input is fp16 and accumulation is fp32, the output of GEMM needs to be fp32, -+// the final reduction is done in fp32, and the reduction epilogue converts fp32 outputs to fp16. -+// Therefore, the output type of the GEMM epilogue is ElementCompute, not ElementOutput. -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOpGEMM = cutlass::epilogue::thread::LinearCombination< -+ ElementCompute, // Data type of output matrix. 
-+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; // Data type for alpha/beta in linear combination -+ -+// The epilogue functor for reduction. This is the one that is actually used. -+using EpilogueOpReduction = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; // Data type for alpha/beta in lin -+ -+using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementAccumulator, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOpGEMM, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm -+ >::Kernel; -+ -+using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution; -+ -+using EpilogueOutputOp = EpilogueOpReduction; -+ -+/// Reduction kernel -+using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ typename EpilogueOutputOp::ElementAccumulator, -+ EpilogueOutputOp::kCount -+ >; -+ -+using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< -+ cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, -+ EpilogueOutputOp, -+ ReductionOp -+ >; -+ -+using ReductionDevice = cutlass::reduction::device::ReduceSplitK; -+using ReductionStrideIndex = typename ReductionDevice::StrideIndex; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::Tensor4DCoord input_size; -+ cutlass::Tensor4DCoord filter_size; -+ cutlass::Tensor4DCoord padding; -+ cutlass::MatrixCoord conv_stride; -+ cutlass::MatrixCoord dilation; -+ bool reference_check; -+ bool measure_performance; -+ int iterations; -+ bool save_workspace; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ int split_k_slices; -+ bool benchmark; -+ std::string tag; -+ -+ Options(): -+ help(false), -+ input_size(1, 32, 32, 32), -+ filter_size(32, 3, 3, 32), -+ padding(1, 1, 1, 1), -+ conv_stride(1, 1), -+ dilation(1, 1), -+ reference_check(true), -+ measure_performance(false), -+ iterations(20), -+ save_workspace(false), -+ alpha(1), -+ beta(0), -+ split_k_slices(8), -+ benchmark(false) { } -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. 
-+ // -+ int const kAlignment = 8; -+ -+ if ((input_size.c() % kAlignment) || -+ (filter_size.n() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ // Invalid padding -+ if ((padding.h() != filter_size.h() / 2) || -+ (padding.w() != filter_size.w() / 2)) { -+ -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Updates input and filter sizes -+ void update( -+ cutlass::Tensor4DCoord input_size, -+ cutlass::Tensor4DCoord filter_size, -+ cutlass::MatrixCoord stride) { -+ -+ this->input_size = input_size; -+ this->filter_size = filter_size; -+ conv_stride = stride; -+ -+ padding.n() = filter_size.h() / 2; -+ padding.h() = filter_size.h() / 2; -+ padding.w() = filter_size.w() / 2; -+ padding.c() = filter_size.w() / 2; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("ref-check")) { -+ reference_check = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("perf-check")) { -+ measure_performance = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("save-workspace")) { -+ save_workspace = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("benchmark")) { -+ benchmark = true; -+ } -+ -+ cmd.get_cmd_line_argument("n", input_size.n()); -+ cmd.get_cmd_line_argument("h", input_size.h()); -+ cmd.get_cmd_line_argument("w", input_size.w()); -+ cmd.get_cmd_line_argument("c", input_size.c()); -+ -+ cmd.get_cmd_line_argument("k", filter_size.n()); -+ cmd.get_cmd_line_argument("r", filter_size.h()); -+ cmd.get_cmd_line_argument("s", filter_size.w()); -+ filter_size.c() = input_size.c(); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ cmd.get_cmd_line_argument("split-k-slices", split_k_slices); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ -+ if (filter_size.h() == 3 && filter_size.w() == 3) { -+ padding = {1, 1, 1, 1}; -+ } -+ else { -+ filter_size.h() = 1; -+ filter_size.w() = 1; -+ padding = {0, 0, 0, 0}; -+ } -+ } -+ -+ /// Prints the usage statement. 
-+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "30_wgrad_split_k example\n\n" -+ << " This example shows how to compute conv2d gradient with respect to weight (wgrad).\n" -+ << " In wgrad, the K dimension of impligit GEMM, corresponding to the sequential reduction loop, is very large (N * P * Q).\n" -+ << " Split-k with parallel reduction is highly effective for such cases.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --n= Input tensor extent N\n" -+ << " --h= Input tensor extent H\n" -+ << " --w= Input tensor extent W\n" -+ << " --c= Input tensor extent C\n" -+ << " --k= Filter extent K\n" -+ << " --r= Filter extent R\n" -+ << " --s= Filter extent S\n\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --split-k-slices= Split-k factor \n\n" -+ << " --ref-check If set (true), reference check on the host is computed\n" -+ << " --perf-check If set (true), performance is measured.\n" -+ << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --save-workspace If set, workspace is written to a text file.\n" -+ << " --tag= String to replicate across the first column in the results table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/30_wgrad_split_k/30_wgrad_split_k --n=32 --h=224 --w=224 --c=128 --k=256 --r=3 --s=3 --split-k-slices=8\n\n"; -+ -+ return out; -+ } -+ -+ /// Computes the output tensor size (NPQK) -+ cutlass::Tensor4DCoord output_size() const { -+ return cutlass::Tensor4DCoord(input_size.n(), -+ (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, -+ (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, -+ filter_size.n()); -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of multiply-adds = NPQK * CRS -+ int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cutlass::Status reference_check; -+ cudaError_t error; -+ -+ Result(): -+ runtime_ms(0), -+ gflops(0), -+ status(cutlass::Status::kSuccess), -+ reference_check(cutlass::Status::kInvalid), -+ error(cudaSuccess) { } -+ -+ static std::ostream & print_header(std::ostream &out, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "Layer,N,H,W,C,K,R,S,Stride_H,Stride_W,Runtime,GFLOPs"; -+ -+ return out; -+ } -+ -+ std::ostream & print(std::ostream &out, int idx, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ out -+ << "conv_" << idx << "," -+ << options.input_size.n() << "," -+ << options.input_size.h() << "," -+ << options.input_size.w() << "," -+ << options.input_size.c() << "," -+ << options.filter_size.n() << "," -+ << options.filter_size.h() << "," -+ << options.filter_size.w() << "," -+ << options.conv_stride.row() << "," -+ << options.conv_stride.column() << "," -+ << runtime_ms << "," -+ << gflops; -+ -+ return out; -+ } -+}; -+ 
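As background (not part of the patch): the 30_wgrad_split_k example quoted above partitions the long implicit-GEMM K reduction (N * P * Q) into split_k_slices chunks, computes the partial results in parallel, and then accumulates them with a separate ReduceSplitK kernel. A minimal host-side C++ sketch of that split-K idea, not the CUTLASS kernels themselves, with assumed sizes:

// Standalone sketch (not part of the patch): host-side illustration of the split-K
// scheme. Each slice reduces its own chunk of K independently (the partial results the
// parallel GEMMs write to the workspace), then a final pass accumulates the partials,
// mirroring the role of the ReduceSplitK reduction kernel.
#include <cstdio>
#include <vector>

int main() {
  const int K = 1024;              // reduction length (stands in for N * P * Q in wgrad)
  const int split_k_slices = 8;    // same default as the example's Options
  std::vector<float> a(K, 1.0f), b(K, 2.0f);

  // Parallel phase: per-slice partial reductions over disjoint chunks of K.
  std::vector<float> partial(split_k_slices, 0.0f);
  const int chunk = (K + split_k_slices - 1) / split_k_slices;
  for (int s = 0; s < split_k_slices; ++s) {
    for (int k = s * chunk; k < K && k < (s + 1) * chunk; ++k) {
      partial[s] += a[k] * b[k];
    }
  }

  // Reduction phase: accumulate the partial results into the final value.
  float acc = 0.0f;
  for (float p : partial) acc += p;
  std::printf("dot = %.1f\n", acc);   // expect 2048.0
  return 0;
}

This is why the example reports that the workspace size grows with split_k_slices: each slice needs its own slot for a partial accumulator before the reduction runs.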
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one benchmark -+Result profile_convolution(Options const &options) { -+ -+ Result result; -+ -+ // -+ // Allocate host-device tensors using the CUTLASS Utilities. -+ // -+ -+ // Inputs are the output gradient and the original activation. -+ cutlass::HostTensor tensor_a(options.output_size()); -+ cutlass::HostTensor tensor_b(options.input_size); -+ cutlass::HostTensor tensor_c(options.filter_size); -+ cutlass::HostTensor tensor_d(options.filter_size); -+ cutlass::HostTensor tensor_ref_d(options.filter_size); -+ -+ // -+ // Initialize tensors -+ // -+ -+ // Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(7), -+ ElementInputA(-8), -+ 0); -+ -+ // Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(7), -+ ElementInputB(-8), -+ 0); -+ -+ // Fill tensor C, D on host with zeros -+ cutlass::reference::host::TensorFill(tensor_c.host_view()); -+ -+ cutlass::reference::host::TensorFill(tensor_d.host_view()); -+ -+ // Fill tensor D for reference on host with zeros -+ cutlass::reference::host::TensorFill(tensor_ref_d.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // -+ // Define arguments for CUTLASS Convolution -+ // -+ -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; -+ -+ // Partition the GEMM K loop into split_k_slices chunks -+ int split_k_slices = options.split_k_slices; -+ -+ // Construct Conv2dProblemSize with user defined output size -+ // Do not forget to pass the last argument. -+ cutlass::conv::Conv2dProblemSize problem_size( -+ options.input_size, -+ options.filter_size, -+ options.padding, -+ options.conv_stride, -+ options.dilation, -+ options.output_size(), -+ mode, -+ split_k_slices -+ ); -+ -+ using cutlass::layout::TensorNHWC; -+ -+ cutlass::conv::SplitKMode const split_k_mode = cutlass::conv::SplitKMode::kParallel; -+ -+ // Since the epilogue is not computed after GEMM, there is no need to pass the C tensor and -+ // alpha and beta can be set to 1 and 0 respectively. -+ // Moreover, since the output will be written to the workspace, there is no need to pass -+ // the D tensor as well. -+ // Do not forget to pass the last argument. -+ typename ImplicitGemm::Arguments arguments{ -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ {nullptr, TensorNHWC()}, -+ {nullptr, TensorNHWC()}, -+ {ElementCompute(1), ElementCompute(0)}, -+ split_k_mode -+ }; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ ImplicitGemm implicit_gemm; -+ -+ size_t workspace_size = implicit_gemm.get_workspace_size(arguments); -+ -+ // Split-K requires non-zero workspace size. The workspace size grows linearly with split_k_slices. -+ std::cout << "split-k-slices: " << split_k_slices << std::endl; -+ std::cout << "workspace size: " << workspace_size << std::endl; -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ result.status = implicit_gemm.can_implement(arguments); -+ CUTLASS_CHECK(result.status); -+ -+ // After the workspace is allocated, we point the GEMM destination pointer to the workspace. 
-+ TensorNHWC layout_D{TensorNHWC::packed(options.filter_size)}; -+ arguments.ref_D.reset(reinterpret_cast(workspace.get()), layout_D); -+ -+ result.status = implicit_gemm.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Launch initialized CUTLASS kernel -+ // -+ result.status = implicit_gemm(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { -+ // Do reduction -+ ReductionDevice reduction_op; -+ auto& status = result.status; -+ static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemm::kConvolutionalOperator; -+ typename ReductionDevice::Arguments reduction_args( -+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), -+ problem_size.split_k_slices, -+ cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), -+ // Reduction input -+ { -+ reinterpret_cast (workspace.get()), -+ ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx]) -+ }, -+ // Destination -+ { -+ tensor_d.device_data(), -+ ReductionStrideIndex(tensor_d.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx]) -+ }, -+ // Source -+ { -+ tensor_c.device_data(), -+ ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx]) -+ }, -+ {options.alpha, options.beta} -+ ); -+ -+ status = reduction_op.initialize(reduction_args, nullptr); -+ status = reduction_op(); -+ } -+ -+ // -+ // Optional reference check -+ // -+ -+ if (options.reference_check) { -+ std::cout << "Verification on device...\n"; -+ -+ // Compute with reference implementation -+ cutlass::reference::device::Conv2dWgrad< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator, -+ cutlass::NumericConverter -+ >( -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref(), -+ options.alpha, -+ options.beta -+ ); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ tensor_c.sync_host(); -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ bool passed = cutlass::reference::host::TensorEquals(tensor_d.host_view(), tensor_ref_d.host_view()); -+ -+ if (!passed) { -+ result.reference_check = cutlass::Status::kErrorInternal; -+ std::cout << "ERROR - results miscompared.\n"; -+ } -+ else { -+ result.reference_check = cutlass::Status::kSuccess; -+ std::cout << "Passed.\n"; -+ } -+ } -+ else { -+ result.reference_check = cutlass::Status::kInvalid; -+ } -+ -+ if (options.save_workspace) { -+ -+ std::stringstream ss; -+ -+ ss << "26_ampere_fused_wgrad_batch_normalization_" -+ << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() -+ << "_" -+ << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() -+ << ".dat"; -+ -+ std::ofstream output_workspace(ss.str()); -+ -+ output_workspace -+ << "Input = \n" << tensor_a.host_view() << "\n\n" -+ << "Filters = \n" << tensor_b.host_view() << "\n\n"; -+ -+ if (options.reference_check) { -+ output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n"; -+ } -+ -+ output_workspace << "Computed = \n" << tensor_c.host_view() << std::endl; -+ -+ std::cout << "Results written to '" << ss.str() << "'." 
<< std::endl; -+ } -+ -+ // -+ // Performance measurement -+ // -+ -+ if (options.measure_performance) { -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = implicit_gemm(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major >= 8)) { -+ std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." 
-+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.benchmark) { -+ // Benchmark several layers -+ -+ int batch_sizes[] = {34, 408}; -+ -+ struct Benchmark { -+ int h, w, c, k, r, s, stride_h, stride_w; -+ } layers[] = { -+ {56, 56, 64, 256, 1, 1, 1, 1}, -+ {56, 56, 64, 64, 1, 1, 1, 1}, -+ {56, 56, 64, 64, 3, 3, 1, 1}, -+ {56, 56, 256, 64, 1, 1, 1, 1}, -+ {56, 56, 256, 512, 1, 1, 2, 2}, -+ {56, 56, 256, 128, 1, 1, 1, 1}, -+ {56, 56, 128, 128, 3, 3, 2, 2}, -+ {28, 28, 128, 512, 1, 1, 1, 1}, -+ {28, 28, 512, 128, 1, 1, 1, 1}, -+ {28, 28, 128, 128, 3, 3, 1, 1}, -+ {28, 28, 512, 1024, 1, 1, 2, 2}, -+ {28, 28, 512, 256, 1, 1, 1, 1}, -+ {28, 28, 256, 256, 3, 3, 2, 2}, -+ {14, 14, 256, 1024, 1, 1, 1, 1}, -+ {14, 14, 1024, 256, 1, 1, 1, 1}, -+ {14, 14, 256, 256, 3, 3, 1, 1}, -+ {14, 14, 1024, 2048, 1, 1, 2, 2}, -+ {14, 14, 1024, 512, 1, 1, 1, 1}, -+ {14, 14, 512, 512, 3, 3, 2, 2}, -+ { 7, 7, 512, 2048, 1, 1, 1, 1}, -+ { 7, 7, 2048, 512, 1, 1, 1, 1}, -+ { 7, 7, 512, 512, 3, 3, 1, 1}, -+ }; -+ -+ Result::print_header(std::cout, options) << std::endl; -+ -+ int idx = 1; -+ -+ for (auto const &layer : layers) { -+ for (auto N : batch_sizes) { -+ options.update({N, layer.h, layer.w, layer.c}, -+ {layer.k, layer.r, layer.s, layer.c}, -+ {layer.stride_h, layer.stride_w}); -+ -+ Result result = profile_convolution(options); -+ result.print(std::cout, idx, options) << std::endl; -+ } -+ -+ ++idx; -+ } -+ } -+ else { -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ Result result = profile_convolution(options); -+ -+ Result::print_header(std::cout, options) << std::endl; -+ result.print(std::cout, 1, options) << std::endl; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/31_basic_syrk/basic_syrk.cu b/3rdparty/cutlass/examples/31_basic_syrk/basic_syrk.cu -new file mode 100644 -index 0000000..82f4a6a ---- /dev/null -+++ b/3rdparty/cutlass/examples/31_basic_syrk/basic_syrk.cu -@@ -0,0 +1,522 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/* -+ This example demonstrates how to call a CUTLASS SYRK kernel and provides a naive reference -+ matrix multiply kernel to verify its correctness. -+ -+ The CUTLASS Syrk template is instantiated in the function CutlassSsyrkNN. This is kernel computes -+ the symmetric rank-k update (SYRK) using double-precision doubleing-point arithmetic and assumes -+ all matrices have column-major layout. -+ -+ The threadblock tile size is chosen as 16x32x16 which offers good performance for large matrices. -+ See the CUTLASS Parallel for All blog post for more exposition on the tunable parameters available -+ in CUTLASS. -+ -+ https://devblogs.nvidia.com/cutlass-linear-algebra-cuda/ -+ -+ Aside from defining and launching the SSYRK kernel, this example does not use any other components -+ or utilities within CUTLASS. Such utilities are demonstrated elsewhere in other examples and are -+ prevalent in the CUTLASS unit tests. -+ -+*/ -+ -+// Standard Library includes -+#include -+#include -+#include -+ -+// Helper methods to check for errors -+#include "helper.h" -+ -+// -+// CUTLASS includes needed for double-precision SYRK kernel -+// -+ -+// Defines cutlass::gemm::device::Syrk, the generic Syrk computation template class. -+#include "cutlass/gemm/device/rank_k.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// This function defines a CUTLASS SYRK kernel instantiation, constructs its parameters object, -+// and launches it on the CUDA device. -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Define a CUTLASS SYRK template and launch a SYRK kernel. -+cudaError_t CutlassSsyrkNN( -+ int N, -+ int K, -+ double alpha, -+ double const *A, -+ int lda, -+ double beta, -+ double *C, -+ int ldc) { -+ -+ // Define type definition for double-precision CUTLASS SYRK with column-major -+ // input matrices and 16x32x16 threadblock tile size (chosen by default). -+ // -+ // To keep the interface manageable, several helpers are defined for plausible compositions -+ // including the following example for double-precision SYRK. Typical values are used as -+ // default template arguments. 
-+ // -+ // To view the full syrk device API interface, see `cutlass/gemm/device/syrk.h` -+ -+ using ColumnMajor = cutlass::layout::ColumnMajor; -+ -+ using CutlassSyrk = cutlass::gemm::device::RankK< -+ double, -+ ColumnMajor, -+ double, -+ ColumnMajor, -+ cutlass::FillMode::kLower, -+ double, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<16, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ double, -+ 1, -+ double, -+ double -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 5, // Stages -+ 1, // AligmentA -+ false, // SplitKSerail -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ // Define a CUTLASS SYRK type -+ CutlassSyrk syrk_operator; -+ -+ // Construct the CUTLASS SYRK arguments object. -+ // -+ // One of CUTLASS's design patterns is to define syrk argument objects that are constructible -+ // in host code and passed to kernels by value. These may include pointers, strides, scalars, -+ // and other arguments needed by Syrk and its components. -+ // -+ // The benefits of this pattern are (1.) a structured, composable strategy for passing host-constructible -+ // arguments to kernels and (2.) minimized initialization overhead on kernel entry. -+ // -+ CutlassSyrk::Arguments args(cutlass::gemm::GemmUniversalMode::kGemm, -+ {N, N, K}, // Syrk Problem dimensions -+ 1, // batch_count, -+ {alpha, beta}, // Scalars used in the Epilogue -+ reinterpret_cast(A), -+ const_cast(reinterpret_cast(C)), -+ reinterpret_cast(C), // destination matrix D (may be different memory than source C matrix) -+ (int64_t)N*K, // Batch strides -+ (int64_t)N*N, -+ (int64_t)N*N, -+ lda, -+ ldc, -+ ldc); -+ -+ // -+ // Launch the CUTLASS SYRK kernel. -+ // -+ -+ cutlass::Status status = syrk_operator(args); -+ -+ // -+ // Return a cudaError_t if the CUTLASS SYRK operator returned an error code. -+ // -+ -+ if (status != cutlass::Status::kSuccess) { -+ return cudaErrorUnknown; -+ } -+ -+ // Return success, if no errors were encountered. -+ return cudaSuccess; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// The source code after this point in the file is generic CUDA using the CUDA Runtime API -+// and simple CUDA kernels to initialize matrices and compute the general matrix product. -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Kernel to initialize a matrix with small integers. -+__global__ void InitializeMatrix_kernel( -+ double *matrix, -+ int ldm, -+ int rows, -+ int columns, -+ int seed = 0) { -+ -+ int i = threadIdx.x + blockIdx.x * blockDim.x; -+ int j = threadIdx.y + blockIdx.y * blockDim.y; -+ -+ if (i < rows && j < columns) { -+ int offset = i + j * ldm; -+ -+ // Generate arbitrary elements. -+ int const k = 16807; -+ int const m = 16; -+ double value = double(((offset + seed) * k % m) - m / 2); -+ -+ matrix[offset] = value; -+ } -+} -+ -+/// Simple function to initialize a matrix to arbitrary small integers. 
-+cudaError_t InitializeMatrix(double *matrix, int ldm, int rows, int columns, int seed = 0) { -+ -+ dim3 block(16, 16); -+ dim3 grid( -+ (rows + block.x - 1) / block.x, -+ (columns + block.y - 1) / block.y -+ ); -+ -+ InitializeMatrix_kernel<<< grid, block >>>(matrix, ldm, rows, columns, seed); -+ -+ return cudaGetLastError(); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Allocates device memory for a matrix then fills with arbitrary small integers. -+cudaError_t AllocateMatrix(double **matrix, int ldm, int rows, int columns, int seed = 0) { -+ cudaError_t result; -+ -+ size_t sizeof_matrix = sizeof(double) * ldm * columns; -+ -+ // Allocate device memory. -+ result = cudaMalloc(reinterpret_cast(matrix), sizeof_matrix); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to allocate matrix: " -+ << cudaGetErrorString(result) << std::endl; -+ return result; -+ } -+ -+ // Clear the allocation. -+ result = cudaMemset(*matrix, 0, sizeof_matrix); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to clear matrix device memory: " -+ << cudaGetErrorString(result) << std::endl; -+ return result; -+ } -+ -+ // Initialize matrix elements to arbitrary small integers. -+ result = InitializeMatrix(*matrix, ldm, rows, columns, seed); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to initialize matrix: " -+ << cudaGetErrorString(result) << std::endl; -+ return result; -+ } -+ -+ return result; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Naive reference SYRK computation. -+__global__ void ReferenceSyrk_kernel( -+ int N, -+ int K, -+ double alpha, -+ double const *A, -+ int lda, -+ double beta, -+ double *C, -+ int ldc) { -+ -+ int i = threadIdx.x + blockIdx.x * blockDim.x; -+ int j = threadIdx.y + blockIdx.y * blockDim.y; -+ -+ if (i < N && j < N && i >= j ) { // Since C is in Lower Fill Mode -+ double accumulator = 0; -+ -+ for (int k = 0; k < K; ++k) { -+ accumulator += A[i + k * lda] * A[j + k * lda]; -+ } -+ -+ C[i + j * ldc] = alpha * accumulator + beta * C[i + j * ldc]; -+ } -+} -+ -+/// Reference SYRK computation. -+cudaError_t ReferenceSyrk( -+ int N, -+ int K, -+ double alpha, -+ double const *A, -+ int lda, -+ double beta, -+ double *C, -+ int ldc) { -+ -+ dim3 block(16, 16); -+ dim3 grid( -+ (N + block.x - 1) / block.x, -+ (N + block.y - 1) / block.y -+ ); -+ -+ ReferenceSyrk_kernel<<< grid, block >>>(N, K, alpha, A, lda, beta, C, ldc); -+ -+ return cudaGetLastError(); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Allocate several matrices in GPU device memory and call a double-precision -+/// CUTLASS SYRK kernel. -+cudaError_t TestCutlassSyrk(int N, int K, double alpha, double beta) { -+ cudaError_t result; -+ -+ // -+ // Define several matrices to be used as operands to SYRK kernels. -+ // -+ -+ // Compute leading dimensions for each matrix. -+ int lda = N; -+ int ldc = N; -+ -+ // Compute size in bytes of the C matrix. -+ size_t sizeof_C = sizeof(double) * ldc * N; -+ -+ // Define pointers to matrices in GPU device memory. -+ double *A; -+ double *C_cutlass; -+ double *C_reference; -+ -+ // -+ // Allocate matrices in GPU device memory with arbitrary seeds. 
-+ // -+ -+ result = AllocateMatrix(&A, lda, N, K, 0); -+ -+ if (result != cudaSuccess) { -+ return result; -+ } -+ -+ result = AllocateMatrix(&C_cutlass, ldc, N, N, 101); -+ -+ if (result != cudaSuccess) { -+ cudaFree(A); -+ return result; -+ } -+ -+ result = AllocateMatrix(&C_reference, ldc, N, N, 101); -+ -+ if (result != cudaSuccess) { -+ cudaFree(A); -+ cudaFree(C_cutlass); -+ return result; -+ } -+ -+ result = cudaMemcpy(C_reference, C_cutlass, sizeof_C, cudaMemcpyDeviceToDevice); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to copy C_cutlass matrix to C_reference: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // -+ // Launch CUTLASS SYRK. -+ // -+ -+ result = CutlassSsyrkNN(N, K, alpha, A, lda, beta, C_cutlass, ldc); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "CUTLASS SYRK kernel failed: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // -+ // Verify. -+ // -+ -+ // Launch reference SYRK -+ result = ReferenceSyrk(N, K, alpha, A, lda, beta, C_reference, ldc); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Reference SYRK kernel failed: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // Copy to host and verify equivalence. -+ std::vector host_cutlass(ldc * N, 0); -+ std::vector host_reference(ldc * N, 0); -+ -+ result = cudaMemcpy(host_cutlass.data(), C_cutlass, sizeof_C, cudaMemcpyDeviceToHost); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to copy CUTLASS SYRK results: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ result = cudaMemcpy(host_reference.data(), C_reference, sizeof_C, cudaMemcpyDeviceToHost); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to copy Reference SYRK results: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // -+ // Free device memory allocations. -+ // -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(A); -+ -+ // -+ // Test for bit equivalence of results. -+ // -+ -+ if (host_cutlass != host_reference) { -+ std::cerr << "CUTLASS results incorrect." << std::endl; -+ -+ return cudaErrorUnknown; -+ } -+ -+ return cudaSuccess; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Entry point to basic_syrk example. -+// -+// usage: -+// -+// 00_basic_syrk -+// -+int main(int argc, const char *arg[]) { -+ -+ bool notSupported = false; -+ -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ >= 11)) { -+ std::cerr << "NVIDIA Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ -+ return -1; -+ } -+ -+ if (!((props.major * 10 + props.minor) >= 80)) { -+ -+ std::cerr << "This example requires compute capability at least 80." 
-+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ // -+ // Parse the command line to obtain SYRK dimensions and scalar values. -+ // -+ -+ // SYRK problem dimensions. -+ int problem[2] = { 128, 128 }; -+ -+ for (int i = 1; i < argc && i < 3; ++i) { -+ std::stringstream ss(arg[i]); -+ ss >> problem[i - 1]; -+ } -+ -+ // Scalars used for linear scaling the result of the matrix product. -+ double scalars[2] = { 1, 0 }; -+ -+ for (int i = 3; i < argc && i < 5; ++i) { -+ std::stringstream ss(arg[i]); -+ ss >> scalars[i - 3]; -+ } -+ -+ // -+ // Run the CUTLASS SYRK test. -+ // -+ -+ cudaError_t result = TestCutlassSyrk( -+ problem[0], // SYRK N dimension -+ problem[1], // SYRK K dimension -+ scalars[0], // alpha -+ scalars[1] // beta -+ ); -+ -+ if (result == cudaSuccess) { -+ std::cout << "Passed." << std::endl; -+ } -+ -+ // Exit. -+ return result == cudaSuccess ? 0 : -1; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/32_basic_trmm/basic_trmm.cu b/3rdparty/cutlass/examples/32_basic_trmm/basic_trmm.cu -new file mode 100644 -index 0000000..74f5cb9 ---- /dev/null -+++ b/3rdparty/cutlass/examples/32_basic_trmm/basic_trmm.cu -@@ -0,0 +1,550 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/* -+ This example demonstrates how to call a CUTLASS TRMM kernel and provides a naive reference -+ matrix multiply kernel to verify its correctness. -+ -+ The CUTLASS Trmm template is instantiated in the function CutlassStrmmNN. 
This is kernel computes -+ the triangular matrix product (TRMM) using double-precision doubleing-point arithmetic and assumes -+ all matrices have column-major layout. -+ -+ The threadblock tile size is chosen as 64x64x16 which offers good performance for large matrices. -+ See the CUTLASS Parallel for All blog post for more exposition on the tunable parameters available -+ in CUTLASS. -+ -+ https://devblogs.nvidia.com/cutlass-linear-algebra-cuda/ -+ -+ Aside from defining and launching the STRMM kernel, this example does not use any other components -+ or utilities within CUTLASS. Such utilities are demonstrated elsewhere in other examples and are -+ prevalent in the CUTLASS unit tests. -+ -+*/ -+ -+// Standard Library includes -+#include -+#include -+#include -+ -+// Helper methods to check for errors -+#include "helper.h" -+ -+// -+// CUTLASS includes needed for double-precision TRMM kernel -+// -+ -+// Defines cutlass::gemm::device::Trmm, the generic Trmm computation template class. -+#include "cutlass/gemm/device/trmm.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// This function defines a CUTLASS TRMM kernel instantiation, constructs its parameters object, -+// and launches it on the CUDA device. -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Define a CUTLASS TRMM template and launch a TRMM kernel. -+cudaError_t CutlassStrmmNN( -+ int M, -+ int N, -+ double alpha, -+ double const *A, -+ int lda, -+ double const *B, -+ int ldb, -+ double *C, -+ int ldc) { -+ -+ // Define type definition for double-precision CUTLASS TRMM with column-major -+ // input matrices and 64x64x16 threadblock tile size (chosen by default). -+ // -+ // To keep the interface manageable, several helpers are defined for plausible compositions -+ // including the following example for double-precision TRMM. Typical values are used as -+ // default template arguments. -+ // -+ // To view the full trmm device API interface, see `cutlass/gemm/device/trmm.h` -+ -+ using ColumnMajor = cutlass::layout::ColumnMajor; -+ -+ using CutlassTrmm = cutlass::gemm::device::Trmm< -+ double, -+ ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ ColumnMajor, -+ double, -+ ColumnMajor, -+ double, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ double, -+ 1, -+ double, -+ double, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 5, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ // Define a CUTLASS TRMM type -+ CutlassTrmm trmm_operator; -+ -+ // Construct the CUTLASS TRMM arguments object. -+ // -+ // One of CUTLASS's design patterns is to define trmm argument objects that are constructible -+ // in host code and passed to kernels by value. These may include pointers, strides, scalars, -+ // and other arguments needed by Trmm and its components. -+ // -+ // The benefits of this pattern are (1.) a structured, composable strategy for passing host-constructible -+ // arguments to kernels and (2.) minimized initialization overhead on kernel entry. 
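The comment above describes the arguments-object pattern in general terms; a minimal sketch of the idea outside CUTLASS may make the two stated benefits concrete. The struct and kernel names below are illustrative, not CUTLASS types: the host fills a plain struct, the kernel receives it by value, and no device-side initialization is needed before the kernel can read its pointers, strides and scalars.

#include <cuda_runtime.h>

// Illustrative parameters object: host-constructible, passed to the kernel by value.
struct ScaleParams {
  double const *input;
  double *output;
  double alpha;
  int n;
};

__global__ void ScaleKernel(ScaleParams params) {   // arrives by value in the kernel arguments
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < params.n) {
    params.output[i] = params.alpha * params.input[i];
  }
}

void LaunchScale(double const *input, double *output, double alpha, int n) {
  ScaleParams params{input, output, alpha, n};      // (1.) composed entirely on the host
  int threads = 256;
  int blocks = (n + threads - 1) / threads;
  ScaleKernel<<<blocks, threads>>>(params);         // (2.) no extra setup on kernel entry
}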
-+ // -+ CutlassTrmm::Arguments args(cutlass::gemm::GemmUniversalMode::kGemm, -+ {M, N, M}, // Trmm Problem dimensions in Left-Side Mode -+ 1, // batch_count, -+ {alpha}, // Scalars used in the Epilogue -+ reinterpret_cast(A), -+ reinterpret_cast(B), -+ reinterpret_cast(C), // destination matrix D (may be different memory than source C matrix) -+ (int64_t)M*M, // Batch strides -+ (int64_t)M*N, -+ (int64_t)M*N, -+ lda, -+ ldb, -+ ldc); -+ -+ // -+ // Launch the CUTLASS TRMM kernel. -+ // -+ -+ cutlass::Status status = trmm_operator(args); -+ -+ // -+ // Return a cudaError_t if the CUTLASS TRMM operator returned an error code. -+ // -+ -+ if (status != cutlass::Status::kSuccess) { -+ return cudaErrorUnknown; -+ } -+ -+ // Return success, if no errors were encountered. -+ return cudaSuccess; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// The source code after this point in the file is generic CUDA using the CUDA Runtime API -+// and simple CUDA kernels to initialize matrices and compute the general matrix product. -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Kernel to initialize a matrix with small integers. -+__global__ void InitializeMatrix_kernel( -+ double *matrix, -+ int ldm, -+ int rows, -+ int columns, -+ int seed = 0, -+ cutlass::FillMode fill_mode = cutlass::FillMode::kInvalid) { -+ -+ int i = threadIdx.x + blockIdx.x * blockDim.x; -+ int j = threadIdx.y + blockIdx.y * blockDim.y; -+ -+ if (i < rows && j < columns) { -+ if (fill_mode == cutlass::FillMode::kLower && i < j) return; -+ else if (fill_mode == cutlass::FillMode::kUpper && i > j) return; -+ int offset = i + j * ldm; -+ -+ // Generate arbitrary elements. -+ int const k = 16807; -+ int const m = 16; -+ double value = double(((offset + seed) * k % m) - m / 2); -+ -+ matrix[offset] = value; -+ } -+} -+ -+/// Simple function to initialize a matrix to arbitrary small integers. -+cudaError_t InitializeMatrix(double *matrix, int ldm, int rows, int columns, int seed = 0, -+ cutlass::FillMode fill_mode = cutlass::FillMode::kInvalid) { -+ -+ dim3 block(16, 16); -+ dim3 grid( -+ (rows + block.x - 1) / block.x, -+ (columns + block.y - 1) / block.y -+ ); -+ -+ InitializeMatrix_kernel<<< grid, block >>>(matrix, ldm, rows, columns, seed, fill_mode); -+ -+ return cudaGetLastError(); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Allocates device memory for a matrix then fills with arbitrary small integers. -+cudaError_t AllocateMatrix(double **matrix, int ldm, int rows, int columns, int seed = 0, -+ cutlass::FillMode fill_mode = cutlass::FillMode::kInvalid) { -+ cudaError_t result; -+ -+ size_t sizeof_matrix = sizeof(double) * ldm * columns; -+ -+ // Allocate device memory. -+ result = cudaMalloc(reinterpret_cast(matrix), sizeof_matrix); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to allocate matrix: " -+ << cudaGetErrorString(result) << std::endl; -+ return result; -+ } -+ -+ // Clear the allocation. -+ result = cudaMemset(*matrix, 0, sizeof_matrix); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to clear matrix device memory: " -+ << cudaGetErrorString(result) << std::endl; -+ return result; -+ } -+ -+ // Initialize matrix elements to arbitrary small integers. 
-+ result = InitializeMatrix(*matrix, ldm, rows, columns, seed, fill_mode); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to initialize matrix: " -+ << cudaGetErrorString(result) << std::endl; -+ return result; -+ } -+ -+ return result; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Naive reference TRMM computation. -+__global__ void ReferenceTrmm_kernel( -+ int M, -+ int N, -+ double alpha, -+ double const *A, -+ int lda, -+ double const *B, -+ int ldb, -+ double *C, -+ int ldc) { -+ -+ int i = threadIdx.x + blockIdx.x * blockDim.x; -+ int j = threadIdx.y + blockIdx.y * blockDim.y; -+ -+ if (i < M && j < N) { -+ double accumulator = 0; -+ -+ for (int k = 0; k < M; ++k) { -+ accumulator += A[i + k * lda] * B[k + j * ldb]; // Since A is in Left-Side Mode -+ } -+ -+ C[i + j * ldc] = alpha * accumulator; -+ } -+} -+ -+/// Reference TRMM computation. -+cudaError_t ReferenceTrmm( -+ int M, -+ int N, -+ double alpha, -+ double const *A, -+ int lda, -+ double const *B, -+ int ldb, -+ double *C, -+ int ldc) { -+ -+ dim3 block(16, 16); -+ dim3 grid( -+ (M + block.x - 1) / block.x, -+ (N + block.y - 1) / block.y -+ ); -+ -+ ReferenceTrmm_kernel<<< grid, block >>>(M, N, alpha, A, lda, B, ldb, C, ldc); -+ -+ return cudaGetLastError(); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Allocate several matrices in GPU device memory and call a double-precision -+/// CUTLASS TRMM kernel. -+cudaError_t TestCutlassTrmm(int M, int N, double alpha) { -+ cudaError_t result; -+ -+ // -+ // Define several matrices to be used as operands to TRMM kernels. -+ // -+ -+ // Compute leading dimensions for each matrix. -+ int lda = M; -+ int ldb = M; -+ int ldc = M; -+ -+ // Compute size in bytes of the C matrix. -+ size_t sizeof_C = sizeof(double) * ldc * N; -+ -+ // Define pointers to matrices in GPU device memory. -+ double *A; -+ double *B; -+ double *C_cutlass; -+ double *C_reference; -+ -+ // -+ // Allocate matrices in GPU device memory with arbitrary seeds. -+ // -+ -+ result = AllocateMatrix(&A, lda, M, M, 0, cutlass::FillMode::kLower); -+ -+ if (result != cudaSuccess) { -+ return result; -+ } -+ -+ result = AllocateMatrix(&B, ldb, M, N, 17); -+ -+ if (result != cudaSuccess) { -+ cudaFree(A); -+ return result; -+ } -+ -+ result = AllocateMatrix(&C_cutlass, ldc, M, N, 101); -+ -+ if (result != cudaSuccess) { -+ cudaFree(A); -+ cudaFree(B); -+ return result; -+ } -+ -+ result = AllocateMatrix(&C_reference, ldc, M, N, 101); -+ -+ if (result != cudaSuccess) { -+ cudaFree(A); -+ cudaFree(B); -+ cudaFree(C_cutlass); -+ return result; -+ } -+ -+ result = cudaMemcpy(C_reference, C_cutlass, sizeof_C, cudaMemcpyDeviceToDevice); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to copy C_cutlass matrix to C_reference: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // -+ // Launch CUTLASS TRMM. -+ // -+ -+ result = CutlassStrmmNN(M, N, alpha, A, lda, B, ldb, C_cutlass, ldc); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "CUTLASS TRMM kernel failed: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // -+ // Verify. 
-+ // -+ -+ // Launch reference TRMM -+ result = ReferenceTrmm(M, N, alpha, A, lda, B, ldb, C_reference, ldc); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Reference TRMM kernel failed: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // Copy to host and verify equivalence. -+ std::vector host_cutlass(ldc * N, 0); -+ std::vector host_reference(ldc * N, 0); -+ -+ result = cudaMemcpy(host_cutlass.data(), C_cutlass, sizeof_C, cudaMemcpyDeviceToHost); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to copy CUTLASS TRMM results: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ result = cudaMemcpy(host_reference.data(), C_reference, sizeof_C, cudaMemcpyDeviceToHost); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to copy Reference TRMM results: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // -+ // Free device memory allocations. -+ // -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ // -+ // Test for bit equivalence of results. -+ // -+ -+ if (host_cutlass != host_reference) { -+ std::cerr << "CUTLASS results incorrect." << std::endl; -+ -+ return cudaErrorUnknown; -+ } -+ -+ return cudaSuccess; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Entry point to basic_trmm example. -+// -+// usage: -+// -+// 00_basic_trmm -+// -+int main(int argc, const char *arg[]) { -+ -+ bool notSupported = false; -+ -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ >= 11)) { -+ std::cerr << "NVIDIA Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ -+ return -1; -+ } -+ -+ if (!((props.major * 10 + props.minor) >= 80)) { -+ -+ std::cerr << "This example requires compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ // -+ // Parse the command line to obtain TRMM dimensions and scalar values. -+ // -+ -+ // TRMM problem dimensions. -+ int problem[2] = { 128, 128 }; -+ -+ for (int i = 1; i < argc && i < 3; ++i) { -+ std::stringstream ss(arg[i]); -+ ss >> problem[i - 1]; -+ } -+ -+ // Scalars used for linear scaling the result of the matrix product. -+ double scalars[1] = { 1 }; -+ -+ for (int i = 3; i < argc && i < 4; ++i) { -+ std::stringstream ss(arg[i]); -+ ss >> scalars[i - 3]; -+ } -+ -+ // -+ // Run the CUTLASS TRMM test. -+ // -+ -+ cudaError_t result = TestCutlassTrmm( -+ problem[0], // TRMM M dimension -+ problem[1], // TRMM N dimension -+ scalars[0] // alpha -+ ); -+ -+ if (result == cudaSuccess) { -+ std::cout << "Passed." << std::endl; -+ } -+ -+ // Exit. -+ return result == cudaSuccess ? 
0 : -1; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/33_ampere_3xtf32_tensorop_symm/ampere_3xtf32_tensorop_symm.cu b/3rdparty/cutlass/examples/33_ampere_3xtf32_tensorop_symm/ampere_3xtf32_tensorop_symm.cu -new file mode 100644 -index 0000000..c938e23 ---- /dev/null -+++ b/3rdparty/cutlass/examples/33_ampere_3xtf32_tensorop_symm/ampere_3xtf32_tensorop_symm.cu -@@ -0,0 +1,687 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+NVIDIA Ampere architecture starts supporting tfloat32 (see include/cutlass/tfloat32.h) -+data types in tensor cores. One big advantage is that we can load in F32 data and convert them -+implicitly to tf32 inside the SYMM kernel which means no change is needed to accelerate traditional -+F32 data by using NVIDIA Ampere architecture. -+ -+We can use the tf32 mode of tensor core to emulate a fast accurate SYMM kernel which is accelerated -+using Ampere Tensor Cores (see include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h). -+ -+The trick is very simple -+ a x b = (a_big + a_small) x (b_big + b_small) = a_big x b_big + a_big x b_small + a_small x b_big -+ big = convert_to_tf32(F32) -+ small = convert_to_tf32(F32 - big) -+ -+a_small x b_small is discarded because they are too small. -+ -+This example demonstrates usage of this kernel, along with accuracy measurements w.r.t. actual F32 -+results (SSYMM from cuBLAS) and against F64 results (DSYMM from CUTLASS) -+ -+To enable this feature, the only change needs to make is to change the default OpMultiplyAdd to -+OpMultiplyAddFastF32. 
-+ -+Now, we have two different flavors of SSYMM in the profiler for Ampere: -+ -+ s1688symm // Use 3xTF32 to emulate F32. F32 in, converted in TF32-big and TF32-small internally, -+ // accumulated in F32, F32 out. -+ s1688tf32symm // Use 1xTF32. F32 in, converted to one TF32 internally, accumulated in F32, F32 out. -+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+ -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_reduce.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/error_metrics.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+#if CUTLASS_ENABLE_CUBLAS -+#include -+#endif -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ float alpha; -+ float beta; -+ std::string rand_mode; -+ int seed; -+ -+ Options(): -+ help(false), -+ problem_size({4096, 4096, 4096}), -+ seed(1), -+ alpha(1), -+ beta(), -+ rand_mode("uniform") { } -+ -+ bool valid() { -+ // -+ // CUTLASS attempts to load 128b vectors of F32 elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 4 elements. -+ // -+ int const kAlignment = 4; -+ -+ if ((problem_size.m() % kAlignment) || -+ (problem_size.n() % kAlignment) || -+ (problem_size.k() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ return true; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("m", problem_size.m()); -+ cmd.get_cmd_line_argument("n", problem_size.n()); -+ // Since the kernels in this example are in Left Side Mode -+ cmd.get_cmd_line_argument("m", problem_size.k()); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("seed", seed); -+ cmd.get_cmd_line_argument("rand_mode", rand_mode); -+ -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "33_ampere_3xtf32_tensorop_symm example\n\n" -+ << " This example uses the CUTLASS Library to execute 3xTF32 tensorop SYMM computations.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --m= SYMM M dimension\n" -+ << " --n= SYMM N dimension\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --rand_mode= gauss / uniform*\n\n" -+ << " --seed= Random number seed (1*)\n\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/33_ampere_3xtf32_tensorop_symm/33_ampere_3xtf32_tensorop_symm --m=1024 --n=512 \\\n" -+ << " --alpha=2 --beta=1 \n\n"; -+ -+ return out; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// The code section below describes matrix layout of input and output matrices. 
Column Major for -+// Matrix A, Matrix B and Matrix C (since that's what cuBLAS supports, CUTLASS supports Row Major too) -+using LayoutInputA = cutlass::layout::ColumnMajor; -+using LayoutInputB = cutlass::layout::ColumnMajor; -+using LayoutOutput = cutlass::layout::ColumnMajor; -+ -+// Symmetric Matrix A is in Left Side mode -+constexpr cutlass::SideMode SideModeA = cutlass::SideMode::kLeft; -+// Symmetric Matrix A is in Lower Filled mode -+constexpr cutlass::FillMode FillModeA = cutlass::FillMode::kLower; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ShapeMMAThreadBlock = -+ cutlass::gemm::GemmShape<128, 64, 16>; // <- threadblock tile M = 128, N = 128, K = 16 -+// This code section describes tile size a warp will compute -+using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 32, 16>; // <- warp tile M = 64, N = 64, K = 16 -+// This code section describes the size of MMA op -+using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? -+ -+// This code section describes the epilogue part of the kernel -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ float, // <- data type of output matrix -+ 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized -+ // memory access. For a byte, it's 16 -+ // elements. This becomes the vector width of -+ // math instructions in the epilogue too -+ float, // <- data type of accumulator -+ float>; // <- data type for alpha/beta in linear combination function -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 3; -+// Alignment -+constexpr int Alignment = 4; -+ -+// -+// CUTLASS Symm Operators (SSYM: Symm_3xTF32, Symm_1xTF32, DSYMM: Symm_F64) -+// -+ -+// Symm_3xTF32 -+using Symm_3xTF32 = cutlass::gemm::device::Symm< -+ float, -+ LayoutInputA, -+ SideModeA, -+ FillModeA, -+ float, -+ LayoutInputB, -+ float, -+ LayoutOutput, -+ float, -+ MMAOp, -+ SmArch, -+ ShapeMMAThreadBlock, -+ ShapeMMAWarp, -+ ShapeMMAOp, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ 1, // Symmetric matrix is always align 1 -+ Alignment, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32>; -+ -+// Symm_1xTF32 -+using Symm_1xTF32 = cutlass::gemm::device::Symm< -+ float, -+ LayoutInputA, -+ SideModeA, -+ FillModeA, -+ float, -+ LayoutInputB, -+ float, -+ LayoutOutput, -+ float, -+ MMAOp, -+ SmArch, -+ ShapeMMAThreadBlock, -+ ShapeMMAWarp, -+ ShapeMMAOp, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ 1, // Symmetric matrix is always align 1 -+ Alignment, -+ false, -+ cutlass::arch::OpMultiplyAdd>; -+ -+// Symm_F64 -+using Symm_F64 = cutlass::gemm::device::Symm< -+ double, -+ LayoutInputA, -+ SideModeA, -+ FillModeA, -+ double, -+ LayoutInputB, -+ double, -+ LayoutOutput, -+ double, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ double, -+ 1, -+ double, -+ double -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4>; 
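With the three operator types defined, it can help to see the big/small decomposition from the header comment in isolation before the kernels are run. The sketch below is host-only and emulates the TF32 conversion by truncating a float mantissa from 23 to 10 bits; the hardware path converts and accumulates differently, so the error figures are only indicative.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Emulated F32 -> TF32 conversion: keep the sign, exponent and top 10 mantissa bits.
static float to_tf32(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFFE000u;                // zero the low 13 mantissa bits
  std::memcpy(&x, &bits, sizeof(bits));
  return x;
}

int main() {
  float a = 1.2345678f, b = 7.6543210f;   // arbitrary sample values

  // big = tf32(x), small = tf32(x - big), as in the header comment of this example.
  float a_big = to_tf32(a), a_small = to_tf32(a - a_big);
  float b_big = to_tf32(b), b_small = to_tf32(b - b_big);

  double exact      = double(a) * double(b);
  double one_tf32   = double(a_big) * double(b_big);          // 1xTF32: single product
  double three_tf32 = double(a_big) * double(b_big)           // 3xTF32: small*small dropped
                    + double(a_big) * double(b_small)
                    + double(a_small) * double(b_big);

  std::printf("exact = %.9f   1xTF32 error = %.3e   3xTF32 error = %.3e\n",
              exact, exact - one_tf32, exact - three_tf32);
  return 0;
}

The same trade-off is what the run() function below measures at matrix scale: Symm_1xTF32 performs one truncated product per multiply-add, while Symm_3xTF32 recovers most of the discarded precision at the cost of additional tensor-core work (three products instead of one).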
-+ -+bool run(Options &options) { -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size = options.problem_size; -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 1. Initialize F32 Precision input tensors using CUTLASS helper functions -+ //////////////////////////////////////////////////////////////////////////////// -+ cutlass::HostTensor tensor_a_F32(problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor tensor_b_F32(problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor tensor_c_F32(problem_size.mn()); // <- Create matrix C with dimensions M x N -+ cutlass::HostTensor tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N -+ -+ if (options.rand_mode == "uniform") { -+ const float min = -1; -+ const float max = 1; -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a_F32.host_view(), -+ options.seed, -+ double(max), -+ double(min)); // <- Fill matrix A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b_F32.host_view(), -+ options.seed, -+ double(max), -+ double(min)); // <- Fill matrix B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c_F32.host_view(), -+ options.seed, -+ double(max), -+ double(min)); // <- Fill matrix C on host with uniform-distribution random data -+ } else if (options.rand_mode == "gauss") { -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomGaussian( -+ tensor_a_F32.host_view(), -+ options.seed, -+ double(0), -+ double(5)); // <- Fill matrix A on host with gaussian-distribution random data -+ cutlass::reference::host::TensorFillRandomGaussian( -+ tensor_b_F32.host_view(), -+ options.seed, -+ double(0), -+ double(5)); // <- Fill matrix B on host with gaussian-distribution random data -+ cutlass::reference::host::TensorFillRandomGaussian( -+ tensor_c_F32.host_view(), -+ options.seed, -+ double(0), -+ double(5)); // <- Fill matrix C on host with gaussian-distribution random data -+ } -+ cutlass::reference::host::TensorFill( -+ tensor_d_F32.host_view()); // <- fill matrix D on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a_F32.sync_device(); -+ tensor_b_F32.sync_device(); -+ tensor_c_F32.sync_device(); -+ tensor_d_F32.sync_device(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 2. 
Initialize F64 tensors, Output tensors and setup arguments -+ //////////////////////////////////////////////////////////////////////////////// -+ // Symm F64 input operands (A, B, C) -+ cutlass::HostTensor tensor_a_F64(problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor tensor_b_F64(problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor tensor_c_F64(problem_size.mn()); // <- Create matrix C with dimensions M x N -+ -+ // Symm output (D) for SYMM_3xTF32 -+ cutlass::HostTensor tensor_d_3xTF32(problem_size.mn()); // <- Create matrix D with dimensions M x N -+ // Symm output (D) for SYMM_1xTF32 -+ cutlass::HostTensor tensor_d_1xTF32(problem_size.mn()); // <- Create matrix D with dimensions M x N -+ // Symm output (D) for SYMM_F64 -+ cutlass::HostTensor tensor_d_F64(problem_size.mn()); // <- Create matrix D with dimensions M x N -+#if CUTLASS_ENABLE_CUBLAS -+ // Symm output (D) for SYMM_cublasF32 -+ cutlass::HostTensor tensor_d_cublasF32(problem_size.mn()); // <- Create matrix D with dimensions M x N -+#endif -+ -+ // Copy values from the DP tensors -+ cutlass::reference::host::TensorCopy(tensor_a_F64.host_view(), tensor_a_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_b_F64.host_view(), tensor_b_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_c_F64.host_view(), tensor_c_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_F64.host_view(), tensor_d_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_3xTF32.host_view(), tensor_d_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_1xTF32.host_view(), tensor_d_F32.host_view()); -+#if CUTLASS_ENABLE_CUBLAS -+ cutlass::reference::host::TensorCopy(tensor_d_cublasF32.host_view(), tensor_d_F32.host_view()); -+#endif -+ -+ // Copy data from host to GPU -+ tensor_a_F64.sync_device(); -+ tensor_b_F64.sync_device(); -+ tensor_c_F64.sync_device(); -+ tensor_d_F64.sync_device(); -+ tensor_d_3xTF32.sync_device(); -+ tensor_d_1xTF32.sync_device(); -+#if CUTLASS_ENABLE_CUBLAS -+ tensor_d_cublasF32.sync_device(); -+#endif -+ -+ // Initialize alpha and beta for dot product computation -+ float alpha = float(options.alpha); -+ float beta = float(options.beta); -+ -+ // Batch count as 1 -+ int batch_count = 1; -+ -+ // Batch stride for A, when matrix A is in Left Side mode -+ int batch_stride_A = problem_size.m()*problem_size.m(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 3. Run 3xTF32 kernel -+ //////////////////////////////////////////////////////////////////////////////// -+ // Create a tuple of gemm kernel arguments. 
This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Symm_3xTF32::Arguments arguments_3xtf32{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem_size, // <- problem size of matrix multiplication -+ batch_count, // <- batch count -+ {alpha, beta}, // <- tuple of alpha and beta -+ tensor_a_F32.device_data(), // <- reference to matrix A on device -+ tensor_b_F32.device_data(), // <- reference to matrix B on device -+ tensor_c_F32.device_data(), // <- reference to matrix C on device -+ tensor_d_3xTF32.device_data(), // <- reference to matrix D on device -+ batch_stride_A, // <- batch stride and ld for matrices -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ tensor_a_F32.layout().stride(0), -+ tensor_b_F32.layout().stride(0), -+ tensor_c_F32.layout().stride(0), -+ tensor_d_3xTF32.layout().stride(0) -+ }; -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size_3xtf32 = Symm_3xTF32::get_workspace_size(arguments_3xtf32); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace_3xtf32(workspace_size_3xtf32); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Symm_3xTF32 symm_op_3xtf32; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status_3xtf32 = symm_op_3xtf32.can_implement(arguments_3xtf32); -+ CUTLASS_CHECK(status_3xtf32); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status_3xtf32 = symm_op_3xtf32.initialize(arguments_3xtf32, workspace_3xtf32.get()); -+ CUTLASS_CHECK(status_3xtf32); -+ -+ // Launch initialized CUTLASS kernel -+ status_3xtf32 = symm_op_3xtf32(); -+ CUTLASS_CHECK(status_3xtf32); -+ -+ tensor_d_3xTF32.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 4. Run 1xTF32 kernel -+ //////////////////////////////////////////////////////////////////////////////// -+ // Create a tuple of gemm kernel arguments. 
This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Symm_1xTF32::Arguments arguments_1xtf32{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem_size, // <- problem size of matrix multiplication -+ batch_count, // <- batch count -+ {alpha, beta}, // <- tuple of alpha and beta -+ tensor_a_F32.device_data(), // <- reference to matrix A on device -+ tensor_b_F32.device_data(), // <- reference to matrix B on device -+ tensor_c_F32.device_data(), // <- reference to matrix C on device -+ tensor_d_1xTF32.device_data(), // <- reference to matrix D on device -+ batch_stride_A, // <- batch stride and ld for matrices -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ tensor_a_F32.layout().stride(0), -+ tensor_b_F32.layout().stride(0), -+ tensor_c_F32.layout().stride(0), -+ tensor_d_1xTF32.layout().stride(0) -+ }; -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size_1xtf32 = Symm_1xTF32::get_workspace_size(arguments_1xtf32); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace_1xtf32(workspace_size_1xtf32); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Symm_1xTF32 symm_op_1xtf32; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status_1xtf32 = symm_op_1xtf32.can_implement(arguments_1xtf32); -+ CUTLASS_CHECK(status_1xtf32); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status_1xtf32 = symm_op_1xtf32.initialize(arguments_1xtf32, workspace_1xtf32.get()); -+ CUTLASS_CHECK(status_1xtf32); -+ -+ // Launch initialized CUTLASS kernel -+ status_1xtf32 = symm_op_1xtf32(); -+ CUTLASS_CHECK(status_1xtf32); -+ -+ tensor_d_1xTF32.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 5. Run F64 kernel -+ //////////////////////////////////////////////////////////////////////////////// -+ // Create a tuple of gemm kernel arguments. 
This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Symm_F64::Arguments arguments_f64{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem_size, // <- problem size of matrix multiplication -+ batch_count, // <- batch count -+ {double(options.alpha), double(options.alpha)}, // <- tuple of alpha and beta -+ tensor_a_F64.device_data(), // <- reference to matrix A on device -+ tensor_b_F64.device_data(), // <- reference to matrix B on device -+ tensor_c_F64.device_data(), // <- reference to matrix C on device -+ tensor_d_F64.device_data(), // <- reference to matrix D on device -+ batch_stride_A, // <- batch stride and ld for matrices -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ tensor_a_F64.layout().stride(0), -+ tensor_b_F64.layout().stride(0), -+ tensor_c_F64.layout().stride(0), -+ tensor_d_F64.layout().stride(0) -+ }; -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size_f64 = Symm_F64::get_workspace_size(arguments_f64); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace_f64(workspace_size_f64); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Symm_F64 symm_op_f64; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status_f64 = symm_op_f64.can_implement(arguments_f64); -+ CUTLASS_CHECK(status_f64); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status_f64 = symm_op_f64.initialize(arguments_f64, workspace_f64.get()); -+ CUTLASS_CHECK(status_f64); -+ -+ // Launch initialized CUTLASS kernel -+ status_f64 = symm_op_f64(); -+ CUTLASS_CHECK(status_f64); -+ -+ cudaDeviceSynchronize(); -+ -+ tensor_d_F64.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 6. Run cuBLAS SSYMM kernel -+ //////////////////////////////////////////////////////////////////////////////// -+ -+#if CUTLASS_ENABLE_CUBLAS -+ cublasStatus_t cublas_status; -+ cublasHandle_t handle; -+ -+ cublas_status = cublasCreate(&handle); -+ if (cublas_status != CUBLAS_STATUS_SUCCESS) { -+ std::cerr << "Failed to create cuBLAS handle." << std::endl; -+ return false; -+ } -+ -+ cublas_status = cublasSsymm( -+ handle, -+ CUBLAS_SIDE_LEFT, -+ CUBLAS_FILL_MODE_LOWER, -+ problem_size.m(), -+ problem_size.n(), -+ static_cast(&alpha), -+ static_cast(tensor_a_F32.device_data()), -+ int(tensor_a_F32.layout().stride(0)), -+ static_cast(tensor_b_F32.device_data()), -+ int(tensor_b_F32.layout().stride(0)), -+ static_cast(&beta), -+ static_cast(tensor_d_cublasF32.device_data()), -+ int(tensor_d_cublasF32.layout().stride(0)) -+ ); -+ -+ cudaDeviceSynchronize(); -+ -+ tensor_d_cublasF32.sync_host(); -+#endif -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 7. 
Compute l2 norms -+ //////////////////////////////////////////////////////////////////////////////// -+ -+#if CUTLASS_ENABLE_CUBLAS -+ // l2 norm cuBLAS F32 vs F64 -+ cutlass::HostTensor tensor_d_cublasF32_in_F64(problem_size.mn()); -+ cutlass::reference::host::TensorCopy(tensor_d_cublasF32_in_F64.host_view(), tensor_d_cublasF32.host_view()); -+ -+ double l2_norm_cublasf32_vs_f64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_cublasF32_in_F64.host_view(), tensor_d_F64.host_view()); -+#endif -+ -+ // l2 norm 3xTF32 vs F64 -+ cutlass::HostTensor tensor_d_3xTF32_in_F64(problem_size.mn()); -+ cutlass::reference::host::TensorCopy(tensor_d_3xTF32_in_F64.host_view(), tensor_d_3xTF32.host_view()); -+ double l2_norm_3xtf32_vs_f64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_3xTF32_in_F64.host_view(), tensor_d_F64.host_view()); -+ -+ // l2 norm 1xTF32 vs F64 -+ cutlass::HostTensor tensor_d_1xTF32_in_F64(problem_size.mn()); -+ cutlass::reference::host::TensorCopy(tensor_d_1xTF32_in_F64.host_view(), tensor_d_1xTF32.host_view()); -+ double l2_norm_1xtf32_vs_f64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_1xTF32_in_F64.host_view(), tensor_d_F64.host_view()); -+ -+#if CUTLASS_ENABLE_CUBLAS -+ // l2 norm 3xTF32 vs cuBLAS F32 -+ double l2_norm_3xtf32_vs_cublasf32 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_3xTF32.host_view(), tensor_d_cublasF32.host_view()); -+#endif -+ -+ // l2 norm 3xTF32 vs 1xTF32 -+ double l2_norm_3xtf32_vs_1xtf32 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_3xTF32.host_view(), tensor_d_1xTF32.host_view()); -+ -+ /////////////////////////////////////////////////////////////////////////////// -+ -+ // Print kernel info and L2 norms -+ std::cout << "Problem Size: (" << problem_size.m() << "," << problem_size.n() << "," << problem_size.k() << ") " -+ << "Alpha: " << alpha << "," << " Beta: " << beta << std::endl; -+ std::cout << std::fixed; -+ std::cout << "Normalized L2 norm of" << std::endl; -+ std::cout.precision(8); -+ std::cout << std::scientific -+#if CUTLASS_ENABLE_CUBLAS -+ << " - cuBLAS F32 error with F64 reference : " << l2_norm_cublasf32_vs_f64 << std::endl -+#endif -+ << " - 3xTF32 error with F64 reference : " << l2_norm_3xtf32_vs_f64 << std::endl -+ << " - 1xTF32 error with F64 reference : " << l2_norm_1xtf32_vs_f64 << std::endl -+#if CUTLASS_ENABLE_CUBLAS -+ << " - 3xTF32 error with cuBLAS F32 reference : " << l2_norm_3xtf32_vs_cublasf32 << std::endl -+#endif -+ << " - 3xTF32 error with 1xTF32 reference : " << l2_norm_3xtf32_vs_1xtf32 << std::endl; -+ -+ return true; -+} -+ -+int main(int argc, const char **argv) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available -+ // in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ >= 11)) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return false; -+ } -+ -+ if (!((props.major * 10 + props.minor) >= 80)) { -+ std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." 
-+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. -+ return 0; -+ } -+ -+ Options options; -+ options.parse(argc, argv); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ bool result = true; -+ -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ result = run(options); -+ -+ if (!result) return -1; -+ -+ return 0; -+} -diff --git a/3rdparty/cutlass/examples/34_transposed_conv2d/34_transposed_conv2d.cu b/3rdparty/cutlass/examples/34_transposed_conv2d/34_transposed_conv2d.cu -new file mode 100644 -index 0000000..2e4ce3c ---- /dev/null -+++ b/3rdparty/cutlass/examples/34_transposed_conv2d/34_transposed_conv2d.cu -@@ -0,0 +1,639 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/* -+This example shows how to compute 2d transposed convolution, also known as deconvolution, using CUTLASS -+conv2d Dgrad kernels. Although two operations are computationaly equivalent, some care is needed to correctly -+set up a problem size for CUTLASS. -+In deep learning, transposed convolution is sometimes used for upscaling feature maps. This example -+demonstrates the 2x upscaling case using the strided Dgrad kernel. 
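As a sketch of the correspondence the example exploits (illustrative notation only, not CUTLASS API; output_padding is accounted for in Options::output_size() below):

    forward conv (3x3, stride 2, pad 1):  x [N, 2H, 2W, C]  ->  y [N, H, W, K]
    its data gradient (Dgrad):            dy [N, H, W, K]   ->  dx [N, 2H, 2W, C]
    transposed conv of t by w  ==  Dgrad run with t supplied in place of dy

so launching the strided Dgrad kernel with the transposed-convolution input in the role of dy directly produces the 2x-upscaled output.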
-+*/ -+ -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using cutlass::layout::TensorNHWC; -+using cutlass::TensorRef; -+ -+using ElementAccumulator = cutlass::half_t; // Data type of accumulator -+using ElementComputeEpilogue = cutlass::half_t; // Data type of epilogue computation (alpha, beta) -+using ElementInputA = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputB = cutlass::half_t; // Data type of elements in input tensor -+using ElementOutput = cutlass::half_t; // Data type of elements in output tensor -+using ElementC = ElementOutput; -+using ElementCompute = ElementComputeEpilogue; -+using LayoutInputA = TensorNHWC; -+using LayoutInputB = TensorNHWC; -+using LayoutOutput = TensorNHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 3; -+ -+// This code section describe iterator algorithm selected is Analytic or Optimized -+static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementCompute, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. 
-+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; // Data type for alpha/beta in linear combination -+ -+using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementAccumulator, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm, -+ cutlass::conv::StrideSupport::kStrided // Use the strided Dgrad specialization -+ >::Kernel; -+ -+using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::Tensor4DCoord input_size; -+ cutlass::Tensor4DCoord filter_size; -+ cutlass::Tensor4DCoord padding; -+ cutlass::MatrixCoord conv_stride; -+ cutlass::MatrixCoord dilation; -+ bool reference_check; -+ bool measure_performance; -+ int iterations; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ std::string tag; -+ -+ Options(): -+ help(false), -+ input_size(1, 32, 32, 32), -+ filter_size(32, 3, 3, 16), -+ padding(1, 1, 1, 1), -+ conv_stride(2, 2), -+ dilation(1, 1), -+ reference_check(true), -+ measure_performance(false), -+ iterations(20), -+ alpha(1), -+ beta(0) {} -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. -+ // -+ int const kAlignment = 8; -+ -+ if ((input_size.c() % kAlignment) || -+ (filter_size.n() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ // Invalid padding -+ if ((padding.h() != filter_size.h() / 2) || -+ (padding.w() != filter_size.w() / 2)) { -+ -+ return false; -+ } -+ -+ return true; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("skip-ref-check")) { -+ reference_check = false; -+ } -+ -+ if (cmd.check_cmd_line_flag("perf-check")) { -+ measure_performance = true; -+ } -+ -+ cmd.get_cmd_line_argument("n", input_size.n()); -+ cmd.get_cmd_line_argument("h", input_size.h()); -+ cmd.get_cmd_line_argument("w", input_size.w()); -+ cmd.get_cmd_line_argument("c", input_size.c()); -+ -+ // Filter layout is CRSK -+ cmd.get_cmd_line_argument("k", filter_size.c()); -+ cmd.get_cmd_line_argument("r", filter_size.h()); -+ cmd.get_cmd_line_argument("s", filter_size.w()); -+ filter_size.n() = input_size.c(); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ -+ if (filter_size.h() == 3 && filter_size.w() == 3) { -+ padding = {1, 1, 1, 1}; -+ } -+ else { -+ filter_size.h() = 1; -+ filter_size.w() = 1; -+ padding = {0, 0, 0, 0}; -+ } -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "34_transposed_conv2d example\n\n" -+ << " This example shows how to compute 2d transposed convolution, also known as\n" -+ << " deconvolution, using CUTLASS conv2d Dgrad kernels. 
Although two operations are\n" -+ << " computationaly equivalent, some care is needed to correctly set up a problem size.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --n= Input tensor extent N\n" -+ << " --h= Input tensor extent H\n" -+ << " --w= Input tensor extent W\n" -+ << " --c= Input tensor extent C\n" -+ << " --k= Filter extent K\n" -+ << " --r= Filter extent R\n" -+ << " --s= Filter extent S\n\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --skip-ref-check If set (true), skip reference check on the host\n" -+ << " --perf-check If set (true), performance is measured.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --tag= String to replicate across the first column in the results table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/31_transposed_conv2d/31_transposed_conv2d --n=8 --h=32 --w=32 --c=16 --k=32 --r=3 --s=3\n\n"; -+ -+ return out; -+ } -+ -+ /// Computes the output tensor size (NPQK) -+ cutlass::Tensor4DCoord output_size() const { -+ // Here, out_pad corresponds to "output_padding" of conv2d_transpose op in deep learning frameworks. -+ // See for example https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html -+ int out_pad_h = conv_stride.row() > 1 ? 1 : 0; -+ int out_pad_w = conv_stride.column() > 1 ? 1 : 0; -+ int out_h = (input_size.h() - 1) * conv_stride.row() - 2 * padding.n() + (((filter_size.h() - 1) * dilation.row() + 1)) + out_pad_h; -+ int out_w = (input_size.w() - 1) * conv_stride.column() - 2 * padding.w() + (((filter_size.w() - 1) * dilation.column() + 1)) + out_pad_w; -+ return cutlass::Tensor4DCoord(input_size.n(), out_h, out_w, filter_size.c()); -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of multiply-adds = NHWC * KRS -+ // Note that the input with the layout NHWC corresponds to the output from the perspective of dgrad, -+ // and that the filter layout is CRSK. 
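// As a back-of-envelope check with the default sizes above (illustrative only, derived from
// Options and output_size()): input 1x32x32x32, 3x3 filter, stride 2, pad 1, K = 16, so
// out_pad = 1 and out_h = (32 - 1)*2 - 2*1 + ((3 - 1)*1 + 1) + 1 = 64, i.e. an output of
// 1x64x64x16 (the 2x upscaling case). The multiply-add count computed below is then
// 1*32*32*32 * (3*3*32) = 9,437,184 fmas, or roughly 0.019 GFLOP per run.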
-+ int64_t fmas = input_size.product() * int64_t(filter_size.h() * filter_size.w() * filter_size.n()); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cutlass::Status reference_check; -+ cudaError_t error; -+ -+ Result(): -+ runtime_ms(0), -+ gflops(0), -+ status(cutlass::Status::kSuccess), -+ reference_check(cutlass::Status::kInvalid), -+ error(cudaSuccess) { } -+ -+ static std::ostream & print_header(std::ostream &out, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "Layer,N,H,W,C,K,R,S,Stride_H,Stride_W,Runtime,GFLOPs"; -+ -+ return out; -+ } -+ -+ std::ostream & print(std::ostream &out, int idx, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ out -+ << "conv_" << idx << "," -+ << options.input_size.n() << "," -+ << options.input_size.h() << "," -+ << options.input_size.w() << "," -+ << options.input_size.c() << "," -+ << options.filter_size.c() << "," -+ << options.filter_size.h() << "," -+ << options.filter_size.w() << "," -+ << options.conv_stride.row() << "," -+ << options.conv_stride.column() << "," -+ << runtime_ms << "," -+ << gflops; -+ -+ return out; -+ } -+}; -+ -+// This is the same as Conv2dDgrad in tools/util/include/cutlass/util/reference/host/convolution.h, -+// only variable names have been adapted for transposed conv2d. -+void Conv2dTransposeReference( -+ cutlass::conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ int H = problem_size.P; -+ int W = problem_size.Q; -+ int P = problem_size.H; -+ int Q = problem_size.W; -+ int K = problem_size.C; -+ int C = problem_size.K; -+ -+ for (int n = 0; n < problem_size.N; ++n) { -+ for (int p = 0; p < P; ++p) { -+ for (int q = 0; q < Q; ++q) { -+ for (int k = 0; k < K; ++k) { -+ -+ ElementAccumulator acc = ElementAccumulator(); -+ -+ for (int r = 0; r < problem_size.R; ++r) { -+ for (int s = 0; s < problem_size.S; ++s) { -+ for (int c = 0; c < C; ++c) { -+ -+ int filter_r = r; -+ int filter_s = s; -+ -+ int h = p + problem_size.pad_h - filter_r * problem_size.dilation_h; -+ int w = q + problem_size.pad_w - filter_s * problem_size.dilation_w; -+ -+ if (h >= 0 && (h % problem_size.stride_h) == 0 && -+ w >= 0 && (w % problem_size.stride_w) == 0) { -+ -+ h = h / problem_size.stride_h; -+ w = w / problem_size.stride_w; -+ -+ if (h < H && w < W) { -+ -+ ElementInputA a = tensor_a.at(cutlass::make_Coord(n, h, w, c)); -+ ElementInputB b = tensor_b.at(cutlass::make_Coord(c, r, s, k)); -+ -+ acc += ElementAccumulator(a) * ElementAccumulator(b); -+ } -+ } -+ -+ } // for (C) -+ } // for (S) -+ } // for (R) -+ -+ // Apply Epilogue, compute ElementCompute, convert and store ElementC -+ ElementC c_ref = ElementC(); -+ -+ if (beta != ElementCompute()) { -+ c_ref = tensor_c.at(cutlass::make_Coord(n, p, q, k)); -+ } -+ -+ tensor_d.at(cutlass::make_Coord(n, p, q, k)) = alpha * ElementCompute(acc) + beta * ElementCompute(c_ref); -+ -+ } // for (K) -+ } // for (W) -+ } // for (H) -+ } // for (N) -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one benchmark -+Result profile_convolution(Options const 
&options) { -+ -+ std::cout << "Output shape: " << options.output_size() << std::endl; -+ -+ Result result; -+ -+ // -+ // Allocate host-device tensors using the CUTLASS Utilities. -+ // -+ -+ cutlass::HostTensor tensor_a(options.input_size); -+ cutlass::HostTensor tensor_b(options.filter_size); -+ cutlass::HostTensor tensor_c(options.output_size()); -+ cutlass::HostTensor tensor_d(options.output_size()); -+ cutlass::HostTensor tensor_ref_d(options.output_size()); -+ -+ // -+ // Initialize tensors -+ // -+ -+ // Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(7), -+ ElementInputA(-8), -+ 0); -+ -+ // Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(7), -+ ElementInputB(-8), -+ 0); -+ -+ // Fill tensor C and D on host with zeros -+ cutlass::reference::host::TensorFill(tensor_c.host_view()); -+ -+ cutlass::reference::host::TensorFill(tensor_d.host_view()); -+ -+ // Fill tensor D for reference on host with zeros -+ cutlass::reference::host::TensorFill(tensor_ref_d.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ -+ // -+ // Define arguments for CUTLASS Convolution -+ // -+ -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; -+ -+ // Construct Conv2dProblemSize with user defined output size -+ // The input in transposed conv2d corresponds to the output in the equivalent dgrad. -+ // Similarly for the output. -+ // Although the filter layout is CRSK from the perspective of conv2d transpose, -+ // the filter size does not need to change for setting up the problem size. -+ // There is no need to transpose the filter tensor either. -+ -+ cutlass::conv::Conv2dProblemSize problem_size( -+ options.output_size(), -+ options.filter_size, -+ options.padding, -+ options.conv_stride, -+ options.dilation, -+ options.input_size, -+ mode -+ ); -+ -+ typename ImplicitGemm::Arguments arguments{ -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_d.device_ref(), -+ {options.alpha, options.beta} -+ }; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ ImplicitGemm implicit_gemm; -+ -+ size_t workspace_size = implicit_gemm.get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ result.status = implicit_gemm.can_implement(arguments); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = implicit_gemm.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = implicit_gemm(); -+ CUTLASS_CHECK(result.status); -+ -+ // // Skip reference check since there is no reference code for conv2d transpose in cutlass. 
-+ if (options.reference_check) { -+ tensor_d.sync_host(); -+ std::cout << "Verification on host...\n"; -+ Conv2dTransposeReference(problem_size, -+ tensor_a.host_ref(), -+ tensor_b.host_ref(), -+ tensor_c.host_ref(), -+ tensor_ref_d.host_ref(), -+ options.alpha, options.beta); -+ -+ bool passed = cutlass::reference::host::TensorEquals(tensor_d.host_view(), tensor_ref_d.host_view()); -+ -+ if (!passed) { -+ result.reference_check = cutlass::Status::kErrorInternal; -+ std::cout << "ERROR - results miscompared.\n"; -+ } -+ else { -+ result.reference_check = cutlass::Status::kSuccess; -+ std::cout << "Passed.\n"; -+ } -+ } -+ -+ if (options.measure_performance) { -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = implicit_gemm(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major >= 8)) { -+ std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." 
-+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ Result result = profile_convolution(options); -+ -+ Result::print_header(std::cout, options) << std::endl; -+ result.print(std::cout, 1, options) << std::endl; -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/35_gemm_softmax/gemm_softmax.cu b/3rdparty/cutlass/examples/35_gemm_softmax/gemm_softmax.cu -new file mode 100644 -index 0000000..163a634 ---- /dev/null -+++ b/3rdparty/cutlass/examples/35_gemm_softmax/gemm_softmax.cu -@@ -0,0 +1,720 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+/** -+ -+*/ -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+ -+#include "cutlass/util/reference/host/gemm_complex.h" -+#include "cutlass/util/reference/device/gemm_complex.h" -+#include "cutlass/util/reference/host/tensor_reduce.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/device/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/error_metrics.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "gemm_with_softmax.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#define TRACE(x) { std::cout << "gemm_softmax.cu:" << __LINE__ << " " << x << std::endl; } -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+enum class Disposition { -+ kPassed, -+ kIncorrect, -+ kNotVerified -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::gemm::GemmCoord problem_size; -+ int batch_count; -+ int iterations; -+ unsigned seed; -+ float alpha; -+ float beta; -+ bool verification_enabled; -+ float tolerance; -+ -+ Options(): -+ help(false), -+ problem_size({16, 24, 64}), -+ batch_count(16), -+ iterations(20), -+ seed(2022), -+ alpha(1), -+ beta(0), -+ verification_enabled(true), -+ tolerance(1e-5f) -+ { } -+ -+ bool valid() { -+ -+ return true; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("m", problem_size.m()); -+ cmd.get_cmd_line_argument("n", problem_size.n()); -+ cmd.get_cmd_line_argument("k", problem_size.k()); -+ -+ cmd.get_cmd_line_argument("batch_count", batch_count); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("verify", verification_enabled); -+ cmd.get_cmd_line_argument("seed", seed); -+ cmd.get_cmd_line_argument("tolerance", tolerance); -+ } -+ -+ /// Prints the usage statement. 
-+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "35_gemm_softmax example\n\n" -+ << " This example uses the CUTLASS Library to compute GEMM + Softmax for arbitrary problem sizes.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --m= GEMM M dimension\n" -+ << " --n= GEMM N dimension\n" -+ << " --k= GEMM K dimension\n" -+ << " --batch_count= Batch number\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --seed= Random number seed (1*)\n\n" -+ << " --iterations= Number of profiling iterations to perform (0 to disable profiling).\n\n" -+ << " --verify= If true, performs reference calculation.\n\n" -+ << " --tolerance Error tolerance\n" -+ ; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/35_gemm_softmax/35_gemm_softmax --m=1024 --n=512 \\\n" -+ << " --alpha=2 --beta=0.707 \n\n"; -+ -+ return out; -+ } -+ -+ /// Returns true if the environment and Toolkit support this -+ bool supported(bool verbose = true) const { -+ -+ // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available -+ // in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ >= 11)) { -+ if (verbose) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ } -+ return false; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ if (verbose) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ } -+ return false; -+ } -+ -+ if (!((props.major * 10 + props.minor) >= 80)) { -+ if (verbose) { -+ std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." -+ << std::endl; -+ } -+ return false; -+ } -+ -+ return true; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Testbed { -+ -+ // -+ // Type definitions -+ // -+ -+ -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementCompute = float; -+ using ElementD = ElementC; -+ using ElementSoftmax = ElementC; -+ -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ using OperatorClass = cutlass::arch::OpClassTensorOp; -+ using ArchTag = cutlass::arch::Sm80; -+ -+ // ApplyShape impacts the final Softmax performance a lot. -+ // Set ApplyShape::kColumn to be the next multiple of 32 number that is after -+ // (gemm_N / alignment). -+ // Set ApplyShape::kRow to max(1, 128 / ApplyShape::kColumn). 
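// Worked example of that rule (illustrative): with cutlass::half_t output the epilogue
// alignment is 128/16 = 8 elements, so gemm_N = 8192 gives 8192/8 = 1024 (already a
// multiple of 32) and kRow = max(1, 128/1024) = 1, i.e. the MatrixShape<1, 1024> chosen
// below; a small gemm_N such as 256 would instead suggest 256/8 = 32, hence
// MatrixShape<4, 32>.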
-+ using ApplyShape = cutlass::MatrixShape<1, 1024>; -+ -+ static int const kStages = 3; -+ -+ /// Linear scaling operator -+ using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementCompute, -+ ElementCompute -+ >; -+ -+ using GemmSoftmax = cutlass::GemmSoftmax< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ ElementCompute, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueFunctorOp, -+ kStages, -+ ApplyShape -+ >; -+ -+ using ElementNorm = typename GemmSoftmax::ElementNorm; -+ using ElementSum = typename GemmSoftmax::ElementSum; -+ using LayoutC = typename GemmSoftmax::LayoutC; -+ using LayoutN = typename GemmSoftmax::LayoutN; -+ using LayoutS = typename GemmSoftmax::LayoutS; -+ using MatrixCoord = typename LayoutC::TensorCoord; -+ -+ // -+ // Data members -+ // -+ -+ Options const &options; -+ -+ -+ cutlass::HostTensor reference_N; -+ -+ cutlass::DeviceAllocation block_A; -+ cutlass::DeviceAllocation block_B; -+ cutlass::DeviceAllocation block_C; -+ cutlass::DeviceAllocation block_D; -+ cutlass::DeviceAllocation block_Ref; -+ cutlass::DeviceAllocation block_Softmax; -+ cutlass::DeviceAllocation block_Norm; -+ cutlass::DeviceAllocation block_Sum; -+ -+ int block_num = (options.problem_size.n() + GemmSoftmax::ThreadblockShape::kN - 1) / GemmSoftmax::ThreadblockShape::kN; -+ -+ cutlass::gemm::GemmCoord problem = options.problem_size; -+ -+ int64_t lda = LayoutA::packed({problem.m(), problem.k()}).stride(0); -+ int64_t ldb = LayoutB::packed({problem.k(), problem.n()}).stride(0); -+ int64_t ldc = LayoutC::packed({problem.m(), problem.n()}).stride(0); -+ -+ // fixed rowmajor for norm and sum -+ int64_t ldn = problem.m(); -+ int64_t lds = ldn; -+ -+ int64_t total_elements_A_per_batch = problem.m() * problem.k(); -+ int64_t total_elements_B_per_batch = problem.k() * problem.n(); -+ int64_t total_elements_C_per_batch = problem.m() * problem.n(); -+ int64_t total_elements_D_per_batch = problem.m() * problem.n(); -+ int64_t total_elements_partial_norm_per_batch = block_num * problem.m(); -+ -+ int64_t total_elements_A = total_elements_A_per_batch * options.batch_count; -+ int64_t total_elements_B = total_elements_B_per_batch * options.batch_count; -+ int64_t total_elements_C = total_elements_C_per_batch * options.batch_count; -+ int64_t total_elements_D = total_elements_D_per_batch * options.batch_count; -+ int64_t total_elements_partial_norm = total_elements_partial_norm_per_batch * options.batch_count; -+ -+ // -+ // Methods -+ // -+ -+ Testbed( -+ Options const &options_ -+ ): -+ options(options_) -+ { -+ reference_N.reset({options.problem_size.m(), 1}, false); -+ } -+ -+ /// Run -+ Disposition run() { -+ -+ Disposition disposition = Disposition::kNotVerified; -+ -+ // -+ // Initialize the workspace -+ // -+ -+ initialize(); -+ -+ // -+ // Launch device kernel -+ // -+ cutlass::Status status = cutlass::Status::kSuccess; -+ -+ status = execute_device_kernel(); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "Device execution failed." 
<< std::endl; -+ return disposition; -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ std::cerr << "Device synchronize failed with error " -+ << cudaGetErrorString(result) << std::endl; -+ return disposition; -+ } -+ -+ // -+ // Verify -+ // -+ -+ if (options.verification_enabled) { -+ -+ bool passed = verify(); -+ -+ if (passed) { -+ disposition = Disposition::kPassed; -+ } -+ else { -+ disposition = Disposition::kIncorrect; -+ } -+ } -+ -+ // -+ // Profiling -+ // -+ if (options.iterations) { -+ profile(); -+ } -+ -+ return disposition; -+ } -+ -+ /// Random initialization -+ void initialize() { -+ -+ block_A.reset(total_elements_A); -+ block_B.reset(total_elements_B); -+ block_C.reset(total_elements_C); -+ block_D.reset(total_elements_D); -+ block_Softmax.reset(total_elements_D); -+ block_Ref.reset(total_elements_D_per_batch); -+ block_Norm.reset(total_elements_partial_norm); -+ block_Sum.reset(total_elements_partial_norm); -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ block_A.get(), total_elements_A, options.seed, ElementA(5), ElementA(-5), 0); -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ block_B.get(), total_elements_B, options.seed + 1, ElementB(5), ElementB(-5), 0); -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ block_C.get(), total_elements_C, options.seed + 2, ElementC(5), ElementC(-5), 0); -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ block_D.get(), total_elements_D, options.seed + 3, ElementD(5), ElementD(-5), 0); -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ block_Ref.get(), total_elements_D_per_batch, options.seed + 3, ElementD(5), ElementD(-5), 0); -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ block_Softmax.get(), total_elements_D, options.seed + 3, ElementSoftmax(5), ElementSoftmax(-5), 0); -+ -+ cutlass::reference::host::TensorFill( -+ reference_N.host_view(), -+ ElementNorm() -+ ); -+ -+ } -+ -+ cutlass::Status execute_device_kernel() { -+ -+ cutlass::Status status = cutlass::Status::kSuccess; -+ -+ // -+ // Setup arguments -+ // -+ -+ GemmSoftmax::Arguments args( -+ options.problem_size, -+ options.batch_count, -+ {block_A.get(), lda}, -+ {block_B.get(), ldb}, -+ {block_C.get(), ldc}, -+ {block_D.get(), ldc}, -+ { -+ ElementCompute(options.alpha), -+ ElementCompute(options.beta) -+ }, -+ {block_Norm.get(), ldn}, -+ {block_Sum.get(), lds}, -+ {block_Softmax.get(), ldc}, -+ total_elements_A_per_batch, -+ total_elements_B_per_batch, -+ total_elements_C_per_batch, -+ total_elements_D_per_batch, -+ total_elements_partial_norm_per_batch, -+ total_elements_partial_norm_per_batch, -+ total_elements_D_per_batch -+ ); -+ -+ // -+ // Launch -+ // -+ -+ GemmSoftmax gemm_softmax; -+ -+ // Initialize -+ status = gemm_softmax.initialize(args); -+ if (status != cutlass::Status::kSuccess) { -+ return status; -+ } -+ -+ // Run -+ status = gemm_softmax(); -+ -+ return status; -+ } -+ -+ template -+ bool verify_tensor(std::vector vector_Input, \ -+ std::vector vector_Input_Ref) { -+ -+ int64_t size = (vector_Input.size() < vector_Input_Ref.size()) ? vector_Input.size() : vector_Input_Ref.size(); -+ float abs_tol = options.tolerance; -+ float rel_tol = options.tolerance; -+ -+ for (int64_t i = 0; i < size; ++i) { -+ float diff = (float)(vector_Input.at(i) - vector_Input_Ref.at(i)); -+ float abs_diff = fabs(diff); -+ float abs_ref = fabs((float)vector_Input_Ref.at(i)); -+ float relative_diff = abs_ref > abs_tol ? 
abs_diff / abs_ref : 0; -+ if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > rel_tol && relative_diff > rel_tol)) { -+ printf("diff = %f, {%f, %f}.\n", abs_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i))); -+ return false; -+ } -+ -+ } -+ -+ return true; -+ } -+ -+ /// Verifies the reference matches -+ bool verify() { -+ -+ LayoutA layout_A(lda); -+ LayoutB layout_B(ldb); -+ LayoutC layout_C(ldc); -+ LayoutN Layout_N(ldn); -+ LayoutS Layout_S(lds); -+ -+ MatrixCoord extent_A{problem.m(), problem.k()}; -+ MatrixCoord extent_B{problem.k(), problem.n()}; -+ MatrixCoord extent_C{problem.m(), problem.n()}; -+ -+ for (int batch_idx = 0; batch_idx < options.batch_count; batch_idx++) { -+ -+ cutlass::TensorView view_A(block_A.get() + total_elements_A_per_batch * batch_idx, layout_A, extent_A); -+ cutlass::TensorView view_B(block_B.get() + total_elements_B_per_batch * batch_idx, layout_B, extent_B); -+ cutlass::TensorView view_C(block_C.get() + total_elements_C_per_batch * batch_idx, layout_C, extent_C); -+ cutlass::TensorView view_Ref_device(block_Ref.get(), layout_C, extent_C); -+ -+ cutlass::reference::device::GemmComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, ElementCompute -+ >( -+ problem, -+ options.alpha, -+ view_A, -+ cutlass::ComplexTransform::kNone, -+ view_B, -+ cutlass::ComplexTransform::kNone, -+ options.beta, -+ view_C, -+ view_Ref_device, -+ ElementCompute(0) -+ ); -+ -+ // Copy reference results to host memory for verification -+ std::vector matrix_D_Ref(layout_C.capacity(extent_C)); -+ cutlass::device_memory::copy_to_host(matrix_D_Ref.data(), block_Ref.get(), matrix_D_Ref.size()); -+ cutlass::TensorView view_Ref(matrix_D_Ref.data(), layout_C, extent_C); -+ -+ std::vector matrix_Softmax_Ref(layout_C.capacity(extent_C)); -+ cutlass::TensorView view_Softmax_Ref(matrix_Softmax_Ref.data(), layout_C, extent_C); -+ -+ // Copy computed results to host memory -+ std::vector matrix_D(layout_C.capacity(extent_C)); -+ cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + total_elements_D_per_batch * batch_idx, matrix_D.size()); -+ -+ std::vector matrix_Softmax(layout_C.capacity(extent_C)); -+ cutlass::device_memory::copy_to_host(matrix_Softmax.data(), block_Softmax.get() + total_elements_D_per_batch * batch_idx, matrix_Softmax.size()); -+ -+ // Compute the norm -+ for (int m = 0; m < options.problem_size.m(); ++m) { -+ reference_N.at({m, 0}) = view_Ref.ref().at({m, 0}); -+ for (int n = 1; n < options.problem_size.n(); ++n) { -+ reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementNorm(view_Ref.ref().at({m, n}))); -+ } -+ } -+ -+ // Compute softmax -+ for (int m = 0; m < options.problem_size.m(); ++m) { -+ -+ float sum = float(); -+ -+ for (int n = 0; n < options.problem_size.n(); ++n) { -+ sum += std::exp( float(view_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ); -+ } -+ -+ float inv_sum = float(1.0f / sum); -+ -+ for (int n = 0; n < options.problem_size.n(); ++n) { -+ -+ view_Softmax_Ref.ref().at({m, n}) = ElementSoftmax( -+ std::exp( float(view_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ) * inv_sum -+ ); -+ } -+ } -+ -+ // Verification checks - set any of these to 'true' to override the verification checks. 
-+ bool verified_D = false; -+ bool verified_Softmax = false; -+ -+ // Verify softmax output -+ if (!verified_D) { -+ verified_D = verify_tensor(matrix_D, matrix_D_Ref); -+ } -+ -+ if (!verified_Softmax) { -+ verified_Softmax = verify_tensor(matrix_Softmax, matrix_Softmax_Ref); -+ } -+ -+ if (!verified_D || !verified_Softmax) { -+ -+ std::cerr << "Verification check failed for tensor Softmax at batch " << batch_idx << "\n"; -+ -+ // Summarize which checks failed -+ if (!verified_D) { -+ std::cerr << "Verification of D tensor failed\n"; -+ } -+ -+ if (!verified_Softmax) { -+ std::cerr << "Verification of Softmax tensor failed\n"; -+ } -+ -+ return false; -+ } -+ -+ } -+ -+ return true; -+ } -+ -+ /// Profiles -+ bool profile() { -+ -+ // -+ // Profile -+ // -+ -+ cutlass::Status status = cutlass::Status::kSuccess; -+ cudaError_t result; -+ cudaEvent_t events[2]; -+ int const kIterations = options.iterations; -+ -+ for (cudaEvent_t &evt : events) { -+ result = cudaEventCreate(&evt); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaEventCreate failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ } -+ -+ result = cudaEventRecord(events[0]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ for (int iter = 0; iter < kIterations; ++iter) { -+ -+ status = execute_device_kernel(); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "Device execution failed." << std::endl; -+ return false; -+ } -+ } -+ -+ result = cudaEventRecord(events[1]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "cudaDeviceSynchronize() failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ float elapsed_ms = 0; -+ result = cudaEventElapsedTime(&elapsed_ms, events[0], events[1]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "cudaEventElapsedTime() failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ for (cudaEvent_t &evt : events) { -+ result = cudaEventDestroy(evt); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaEventDestroy() failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ } -+ -+ int64_t flops = int64_t(options.problem_size.m()) * options.problem_size.n() * options.problem_size.k() * 2; -+ int64_t bytes = (sizeof(ElementD) * 2 + sizeof(ElementSoftmax)) * options.problem_size.m() * options.problem_size.n(); -+ -+ double gflops_per_second = double(flops) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1.0e9); -+ double gbytes_per_second = double(bytes) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1 << 30); -+ -+ double elapsed_ms_per_iter = double(elapsed_ms) / kIterations; -+ -+ std::cout << " Problem: " -+ << options.problem_size.m() << "-by-" << options.problem_size.n() << "-by-" << options.problem_size.k() -+ << ", batch size: " << options.batch_count -+ << std::endl; -+ -+ std::cout << " Runtime: " << elapsed_ms_per_iter << " ms\n" << std::endl; -+ -+ std::cout << " GFLOPs: " << gflops_per_second << " GFLOPs" << std::endl; -+ std::cout << "Memory bandwidth: " << gbytes_per_second << " GiB/s" << std::endl; -+ -+ return true; -+ } -+}; -+ 
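The per-row reference softmax inside Testbed::verify() above boils down to the following standalone sketch (plain C++ with illustrative names, assuming a non-empty row; it is not part of the patch):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Softmax over one row, computed the same way as the host reference:
    // subtract the row maximum before exponentiating so exp() cannot overflow.
    std::vector<float> softmax_row(std::vector<float> const &row) {
      float row_max = *std::max_element(row.begin(), row.end());
      std::vector<float> out(row.size());
      float sum = 0.0f;
      for (std::size_t i = 0; i < row.size(); ++i) {
        out[i] = std::exp(row[i] - row_max);
        sum += out[i];
      }
      float inv_sum = 1.0f / sum;
      for (float &v : out) {
        v *= inv_sum;
      }
      return out;
    }

The max subtraction does not change the result mathematically; it only keeps the intermediate exponentials in a representable range, which matters when comparing against half-precision outputs.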
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, const char **argv) { -+ -+ // Options parsing -+ Options options; -+ options.parse(argc, argv); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (!options.supported()) { -+ return 0; -+ } -+ -+ // Run -+ Testbed testbed(options); -+ -+ Disposition disposition = testbed.run(); -+ -+ std::cout << std::endl; -+ -+ switch (disposition) { -+ case Disposition::kPassed: -+ std::cout << "Passed" << std::endl; -+ break; -+ case Disposition::kIncorrect: -+ std::cout << "Incorrect" << std::endl; -+ break; -+ case Disposition::kNotVerified: -+ std::cout << "Not verified" << std::endl; -+ break; -+ } -+ -+ return (disposition == Disposition::kPassed ? 0 : -1); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h b/3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h -new file mode 100644 -index 0000000..586c912 ---- /dev/null -+++ b/3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h -@@ -0,0 +1,536 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief GEMM kernel to support the epilogue visitor model -+ for customized softmax partial reduction epilogue fusion. -+ -+ This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once -+ its usage has been stabilized. For now, it is included in this example to demonstrate -+ some basic output fusion options. 
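In rough terms (a conceptual sketch, not a description of the exact CUTLASS interface), the visitor lets the epilogue emit, alongside each tile of D, a per-row running maximum and a per-row sum of exponentials for that threadblock column; a lightweight follow-up kernel then merges these partials and normalizes D into the softmax output, avoiding an extra full pass over D. One standard way such per-row partials can be merged is:

    #include <algorithm>
    #include <cmath>

    // One per-row partial from a single threadblock column:
    // max of the row slice and sum of exp(x - max) over that slice.
    struct SoftmaxPartial {
      float max;
      float sum;
    };

    // Illustrative merge rule only; the actual reduction lives in the CUTLASS kernels.
    SoftmaxPartial merge(SoftmaxPartial a, SoftmaxPartial b) {
      SoftmaxPartial out;
      out.max = std::max(a.max, b.max);
      out.sum = a.sum * std::exp(a.max - out.max) + b.sum * std::exp(b.max - out.max);
      return out;
    }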
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+ -+#include "cutlass/trace.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function -+> -+struct GemmWithEpilogueVisitor { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueVisitor = typename Epilogue::Visitor; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using TensorRefA = TensorRef; -+ -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using TensorRefB = TensorRef; -+ -+ using ElementC = typename EpilogueVisitor::ElementOutput; -+ using LayoutC = typename Epilogue::Layout; -+ using TensorRefC = TensorRef; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ using Operator = typename Mma::Operator; -+ -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ using ElementNorm = typename EpilogueVisitor::ElementNorm; -+ using ElementSum = typename EpilogueVisitor::ElementSum; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = const_max( -+ 128 / sizeof_bits::value, -+ 128 / sizeof_bits::value -+ ); -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmUniversalMode mode; -+ GemmCoord problem_size; -+ int batch_count; -+ -+ TensorRefA ref_A; -+ TensorRefB ref_B; -+ TensorRefC ref_C; -+ TensorRefC ref_D; -+ -+ ElementNorm *ptr_Max; -+ ElementSum *ptr_Sum; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ -+ typename EpilogueVisitor::Arguments epilogue_visitor; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ mode(GemmUniversalMode::kGemm), -+ batch_count(1) -+ { } -+ -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode_, -+ GemmCoord problem_size_, -+ int batch_count_, -+ TensorRefA ref_A_, -+ TensorRefB ref_B_, -+ TensorRefC ref_C_, -+ TensorRefC ref_D_, -+ ElementNorm *ptr_Max_, -+ ElementSum *ptr_Sum_, -+ int64_t batch_stride_A_, -+ int64_t batch_stride_B_, -+ typename EpilogueVisitor::Arguments epilogue_visitor_ -+ ): -+ mode(mode_), -+ problem_size(problem_size_), -+ 
batch_count(batch_count_), -+ ref_A(ref_A_), -+ ref_B(ref_B_), -+ ref_C(ref_C_), -+ ref_D(ref_D_), -+ ptr_Max(ptr_Max_), -+ ptr_Sum(ptr_Sum_), -+ batch_stride_A(batch_stride_A_), -+ batch_stride_B(batch_stride_B_), -+ epilogue_visitor(epilogue_visitor_) -+ { -+ -+ } -+ }; -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params { -+ -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorB::Params params_B; -+ typename EpilogueVisitor::OutputTileIterator::Params params_C; -+ typename EpilogueVisitor::OutputTileIterator::Params params_D; -+ -+ GemmUniversalMode mode; -+ int batch_count; -+ int gemm_k_size; -+ -+ void * ptr_A; -+ void * ptr_B; -+ ElementC * ptr_C; -+ ElementC * ptr_D; -+ -+ ElementNorm * ptr_Max; -+ ElementSum * ptr_Sum; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ -+ typename EpilogueVisitor::Params epilogue_visitor; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ swizzle_log_tile(0), -+ params_A(0), -+ params_B(0), -+ params_C(0), -+ params_D(0), -+ batch_count(0), -+ gemm_k_size(0), -+ mode(cutlass::gemm::GemmUniversalMode::kGemm), -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ ptr_Max(nullptr), -+ ptr_Sum(nullptr), -+ batch_stride_A(0), -+ batch_stride_B(0) -+ { } -+ -+ -+ Params( -+ Arguments const &args -+ ): -+ problem_size(args.problem_size), -+ swizzle_log_tile(0), -+ params_A(args.ref_A.layout()), -+ params_B(args.ref_B.layout()), -+ params_C(args.ref_C.layout()), -+ params_D(args.ref_D.layout()), -+ mode(args.mode), -+ batch_count(args.batch_count), -+ gemm_k_size(args.problem_size.k()), -+ ptr_A(args.ref_A.data()), -+ ptr_B(args.ref_B.data()), -+ ptr_C(args.ref_C.data()), -+ ptr_D(args.ref_D.data()), -+ ptr_Max(args.ptr_Max), -+ ptr_Sum(args.ptr_Sum), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ epilogue_visitor(args.epilogue_visitor) -+ { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); -+ -+ gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); -+ -+ if (gemm_k_size) { -+ grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); -+ } -+ } -+ -+ swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ -+ typename Mma::SharedStorage main_loop; -+ -+ struct { -+ typename Epilogue::SharedStorage epilogue; -+ typename EpilogueVisitor::SharedStorage visitor; -+ } epilogue; -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ GemmWithEpilogueVisitor() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) { -+ -+ CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); -+ -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int 
const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ bool isAMisaligned = false; -+ bool isBMisaligned = false; -+ bool isCMisaligned = false; -+ -+ if (platform::is_same::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } else if (platform::is_same::value) { -+ isAMisaligned = problem_size.m() % kAlignmentA; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } -+ -+ if (platform::is_same::value) { -+ isBMisaligned = problem_size.n() % kAlignmentB; -+ } else if (platform::is_same::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } -+ -+ if (platform::is_same::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } else if (platform::is_same::value) { -+ isCMisaligned = problem_size.m() % kAlignmentC; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } -+ -+ if (isAMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isBMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isCMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ CUTLASS_TRACE_HOST(" returning kSuccess"); -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+ #define SPLIT_K_ENABLED 1 -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ -+ #if SPLIT_K_ENABLED -+ // -+ // Fetch pointers based on mode. 
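// Worked example of the split-K slicing precomputed in Params (assuming
// half-precision A/B operands, so kAlignK == 128 / 16 == 8):
//   gemm_k_size          = round_up(ceil_div(K, batch_count), kAlignK)
//   grid_tiled_shape.k() = ceil_div(K, gemm_k_size)
// e.g. K = 4096 split across batch_count = 3 slices gives
//   ceil_div(4096, 3) = 1366, rounded up to 1368, and
//   ceil_div(4096, 1368) = 3 k-slices covering [0,1368), [1368,2736), [2736,4096).
// Each CTA below sets offset_k = tile.k() * gemm_k_size and, for every slice
// except the last, clamps problem_size_k to (tile.k() + 1) * gemm_k_size.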
-+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; -+ ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; -+ ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; -+ } -+ #endif -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ accumulators); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ // -+ // Construct the epilogue visitor -+ // -+ -+ EpilogueVisitor epilogue_visitor( -+ params.epilogue_visitor, -+ shared_storage.epilogue.visitor, -+ params.problem_size.mn(), -+ thread_idx, -+ warp_idx, -+ lane_idx, -+ params.params_C, -+ params.params_D, -+ params.ptr_C, -+ params.ptr_D, -+ params.ptr_Max, -+ params.ptr_Sum, -+ threadblock_offset, -+ blockIdx.y *params.problem_size.m() ); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ // Indicate which position in a serial reduction the output operator is currently updating -+ epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) { -+ epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); -+ } -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue.epilogue, -+ thread_idx, -+ 
warp_idx, -+ lane_idx); -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(epilogue_visitor, accumulators); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_softmax.h b/3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_softmax.h -new file mode 100644 -index 0000000..6b2fa99 ---- /dev/null -+++ b/3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_softmax.h -@@ -0,0 +1,651 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+/** -+ -+*/ -+ -+#pragma once -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/arch/memory_sm75.h" -+ -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/kernel/default_gemm_complex.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h" -+#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" -+#include "cutlass/reduction/kernel/reduce_softmax_final.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "gemm_with_epilogue_visitor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Kernel computes partial reduction -+// -+// -+// 2. Sum[m, n'] = sum_n(exp(D[m, n] - N[m, 0])) -+// -+template < -+ typename ElementD_, -+ typename ElementNorm_, -+ typename ElementSum_, -+ typename ElementSoft_, -+ typename ElementSoftmaxCompute_, -+ int Alignment, -+ typename ApplyShape_ = MatrixShape<1, 1024> -+> -+class ApplySoftmax { -+public: -+ -+ using ElementD = ElementD_; -+ using ElementNorm = ElementNorm_; -+ using ElementSum = ElementSum_; -+ using ElementSoft = ElementSoft_; -+ using ElementSoftmaxCompute = ElementSoftmaxCompute_; -+ -+ static int const kAlignment = Alignment; -+ using ApplyShape = ApplyShape_; -+ -+ using Layout = cutlass::layout::RowMajor; -+ -+ using TensorRefD = TensorRef; -+ using TensorRefN = TensorRef; -+ using TensorRefSum = TensorRef; -+ using TensorRefSoft = TensorRef; -+ -+ using FragmentSoftmax = Array; -+ -+ // -+ // Arguments -+ // -+ -+ struct Arguments { -+ -+ MatrixCoord extent; ///< Extent of D and Softmax matrices -+ int batch_count; ///< Batch count -+ TensorRefD ref_D; ///< D matrix computed by GEMM+Max (input) -+ TensorRefN ref_N; ///< Norm tensor (input) -+ TensorRefSum ref_S; ///< Sum tensor (input) -+ TensorRefSoft ref_Soft; ///< Softmax tensor (output) -+ int64_t batch_stride_D; ///< Batch stride for D tensor -+ int64_t batch_stride_N; ///< Batch stride for N tensor -+ int64_t batch_stride_S; ///< Batch stride for S tensor -+ int64_t batch_stride_Soft; ///< Batch stride for softmax tensor -+ -+ // -+ // Methods -+ // -+ Arguments(): -+ batch_count(1), -+ batch_stride_D(0), -+ batch_stride_N(0), -+ batch_stride_S(0), -+ batch_stride_Soft(0) -+ { } -+ -+ Arguments( -+ MatrixCoord extent_, ///< Extent of D and Softmax matrices -+ int batch_count_, ///< Batch count -+ TensorRefD ref_D_, ///< D matrix computed by GEMM+PartialReduce -+ TensorRefN ref_N_, ///< Output parameter for N -+ TensorRefSum ref_S_, ///< Output parameter for N -+ TensorRefSoft ref_Soft_, ///< Softmax -+ int64_t batch_stride_D_ = 0, -+ int64_t batch_stride_N_ = 0, -+ int64_t batch_stride_S_ = 0, -+ int64_t batch_stride_Soft_ = 0 -+ ): -+ extent(extent_), -+ batch_count(batch_count_), -+ ref_D(ref_D_), -+ ref_N(ref_N_), -+ ref_S(ref_S_), -+ ref_Soft(ref_Soft_), -+ batch_stride_D(batch_stride_D_), -+ 
batch_stride_N(batch_stride_N_), -+ batch_stride_S(batch_stride_S_), -+ batch_stride_Soft(batch_stride_Soft_) -+ { -+ -+ } -+ }; -+ -+ // -+ // Params struct -+ // -+ -+ struct Params { -+ Arguments args; -+ -+ // -+ // Methods -+ // -+ Params() { } -+ -+ Params(Arguments const &args_): args(args_) { } -+ }; -+ -+ // -+ // SharedStorage -+ // -+ -+ struct SharedStorage { -+ -+ }; -+ -+private: -+ -+public: -+ -+ CUTLASS_DEVICE -+ ApplySoftmax() { } -+ -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ apply(params, shared_storage); -+ } -+ -+private: -+ -+ -+ /// Compute Softmax -+ CUTLASS_DEVICE -+ void apply(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ using AccessTypeD = AlignedArray; -+ -+ int block_batch = blockIdx.z; -+ int block_m = blockIdx.x * ApplyShape::kRow; -+ int block_n = 0; -+ -+ int thread_m = threadIdx.y; -+ int thread_n = threadIdx.x * kAlignment; -+ -+ int idx_m = block_m + thread_m; -+ int idx_n = block_n + thread_n; -+ -+ int batch_offset_norm = block_batch * params.args.batch_stride_N; -+ int batch_offset_sum = block_batch * params.args.batch_stride_S; -+ -+ // Kill off thread if it is outside the row boundary -+ if (params.args.extent.row() <= idx_m) { -+ return; -+ } -+ -+ // -+ // Setup pointers to load D again -+ // -+ -+ using AccessTypeD = AlignedArray; -+ using AccessTypeSoft = AlignedArray; -+ using FragmentSoft = Array; -+ using ConvertSoftCompute = cutlass::NumericArrayConverter; -+ using ConvertSoftOutput = cutlass::NumericArrayConverter; -+ -+ using Mul = cutlass::multiplies; -+ using Minus = cutlass::minus; -+ using Exp = cutlass::fast_exp_op; -+ -+ ConvertSoftCompute convert_soft_compute; -+ ConvertSoftOutput convert_soft_output; -+ -+ Minus minus; -+ Mul mul; -+ Exp exponential; -+ -+ using ConvertSum = cutlass::NumericConverter; -+ using ConvertNorm = cutlass::NumericConverter; -+ -+ ConvertSum convert_sum; -+ ConvertNorm convert_norm; -+ -+ AccessTypeD *access_d = reinterpret_cast( -+ params.args.ref_D.data() + -+ params.args.batch_stride_D * block_batch + -+ params.args.ref_D.layout()({idx_m, idx_n})); -+ -+ AccessTypeSoft *access_soft = reinterpret_cast( -+ params.args.ref_Soft.data() + -+ params.args.batch_stride_Soft * block_batch + -+ params.args.ref_Soft.layout()({idx_m, idx_n})); -+ -+ ElementSum inv_sum = (params.args.ref_S.data())[idx_m + batch_offset_sum]; -+ ElementNorm norm = (params.args.ref_N.data())[idx_m + batch_offset_norm]; -+ -+ // -+ // Loop -+ // -+ CUTLASS_PRAGMA_UNROLL -+ for ( -+ int idx = 0; -+ idx < params.args.extent.column(); -+ idx += ApplyShape::kColumn * kAlignment) { -+ -+ if (idx_n < params.args.extent.column()) { -+ AccessTypeD fetch; -+ arch::global_load(fetch, access_d, true); -+ -+ FragmentSoftmax result = mul(exponential(minus(convert_soft_compute(fetch), convert_norm(norm))), convert_sum(inv_sum)); -+ FragmentSoft soft = convert_soft_output(result); -+ -+ arch::global_store(soft, access_soft, true); -+ } -+ -+ access_d += ApplyShape::kColumn; -+ access_soft += ApplyShape::kColumn; -+ idx_n += ApplyShape::kColumn * kAlignment; -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// -+template < -+ typename ElementA_, -+ typename LayoutA_, -+ typename ElementB_, -+ typename LayoutB_, -+ typename ElementC_, -+ typename ElementCompute_, -+ typename OperatorClass_, -+ 
typename ArchTag_, -+ typename ThreadblockShape_, -+ typename WarpShape_, -+ typename InstructionShape_, -+ typename EpilogueFunctorOp_, -+ int kStages_, -+ typename ApplyShape_ = MatrixShape<1, 1024>, -+ int AlignmentA_ = 128 / cutlass::sizeof_bits::value, -+ int AlignmentB_ = 128 / cutlass::sizeof_bits::value, -+ int AlignmentSoftmax_ = 128 / cutlass::sizeof_bits::value, -+ typename ElementNorm_ = float, -+ typename ElementSum_ = float, -+ typename ElementSoftmax_ = ElementC_ -+> -+class GemmSoftmax { -+public: -+ -+ /////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ // -+ // Type definitions -+ // -+ -+ using ElementA = ElementA_; -+ using ElementB = ElementB_; -+ using ElementC = ElementC_; -+ using ElementCompute = ElementCompute_; -+ using ElementSum = ElementSum_; -+ using ElementSoft = ElementSoftmax_; -+ using ElementSoftmaxCompute = float; -+ -+ using LayoutA = LayoutA_; -+ using LayoutB = LayoutB_; -+ -+ using EpilogueFunctorOp = EpilogueFunctorOp_; -+ using ElementNorm = ElementNorm_; -+ -+ using ApplyShape = ApplyShape_; -+ -+ // These are mandatory layouts. -+ using LayoutC = cutlass::layout::RowMajor; -+ using LayoutN = cutlass::layout::RowMajor; -+ using LayoutS = cutlass::layout::RowMajor; -+ using LayoutSoft = cutlass::layout::RowMajor; -+ -+ using TensorRefA = TensorRef; -+ using TensorRefB = TensorRef; -+ using TensorRefC = TensorRef; -+ using TensorRefN = TensorRef; -+ using TensorRefSum = TensorRef; -+ using TensorRefSoft = TensorRef; -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ -+ static int const kStages = kStages_; -+ static int const AlignmentA = AlignmentA_; -+ static int const AlignmentB = AlignmentB_; -+ static int const AlignmentSoftmax = AlignmentSoftmax_; -+ -+ using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; -+ -+ /////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ // basic GEMM kernel -+ using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm< -+ ElementA, -+ LayoutA, -+ AlignmentA, -+ ElementB, -+ LayoutB, -+ AlignmentB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueFunctorOp, -+ ThreadblockSwizzle, -+ kStages, -+ true, -+ typename cutlass::gemm::device::DefaultGemmConfiguration< -+ OperatorClass, ArchTag, ElementA, ElementB, ElementC, ElementCompute>::Operator, -+ cutlass::gemm::SharedMemoryClearOption::kNone -+ >::GemmKernel; -+ -+ /////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ // Epilogue visitor -+ using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorSoftmax< -+ ThreadblockShape, -+ DefaultGemmKernel::kThreadCount, -+ typename DefaultGemmKernel::Epilogue::OutputTileIterator, -+ ElementCompute, -+ ElementNorm, -+ ElementSum, -+ ElementSoftmaxCompute, -+ EpilogueFunctorOp -+ >; -+ -+ /// Epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue< -+ EpilogueVisitor, -+ typename DefaultGemmKernel::Epilogue -+ >::Epilogue; -+ -+ // GEMM -+ using GemmKernel = gemm::kernel::GemmWithEpilogueVisitor< -+ typename DefaultGemmKernel::Mma, -+ Epilogue, -+ ThreadblockSwizzle -+ >; -+ -+ // Softmax kernel -+ using 
SoftmaxApplyKernel = kernel::ApplySoftmax< -+ ElementC, -+ ElementNorm, -+ ElementSum, -+ ElementSoft, -+ ElementSoftmaxCompute, -+ AlignmentSoftmax, -+ ApplyShape -+ >; -+ -+ using ApplyFinalReductionKernel = cutlass::reduction::kernel::ApplySoftmaxFinalReduction< -+ ElementNorm, -+ ElementSum, -+ ElementSoftmaxCompute, -+ ThreadblockShape -+ >; -+ -+public: -+ -+ /// Arguments class -+ struct Arguments { -+ -+ typename GemmKernel::Arguments gemm; -+ typename SoftmaxApplyKernel::Arguments softmax; -+ typename ApplyFinalReductionKernel::Arguments reduction; -+ cutlass::gemm::GemmCoord extend; -+ -+ // -+ // Methods -+ // -+ Arguments() { } -+ -+ Arguments( -+ cutlass::gemm::GemmCoord problem_size, -+ int32_t batch_count_, -+ TensorRefA ref_A_, -+ TensorRefB ref_B_, -+ TensorRefC ref_C_, -+ TensorRefC ref_D_, -+ typename EpilogueFunctorOp::Params linear_scaling, -+ TensorRefN ref_N_, -+ TensorRefSum ref_S_, -+ TensorRefSoft ref_Softmax_, -+ int64_t batch_stride_A_ = 0, -+ int64_t batch_stride_B_ = 0, -+ int64_t batch_stride_C_ = 0, -+ int64_t batch_stride_D_ = 0, -+ int64_t batch_stride_Max_ = 0, -+ int64_t batch_stride_Sum_ = 0, -+ int64_t batch_stride_Softmax_ = 0 -+ ): -+ gemm( -+ cutlass::gemm::GemmUniversalMode::kBatched, -+ problem_size, -+ batch_count_, -+ ref_A_, -+ ref_B_, -+ ref_C_, -+ ref_D_, -+ ref_N_.data(), -+ ref_S_.data(), -+ batch_stride_A_, -+ batch_stride_B_, -+ typename EpilogueVisitor::Arguments( -+ linear_scaling, -+ batch_stride_C_, -+ batch_stride_D_, -+ batch_stride_Max_, -+ batch_stride_Sum_ -+ ) -+ ), -+ reduction( -+ problem_size, -+ ref_N_.data(), -+ ref_S_.data(), -+ batch_stride_Max_, -+ batch_stride_Sum_ -+ ), -+ softmax( -+ MatrixCoord(problem_size.m(), problem_size.n()), -+ batch_count_, -+ ref_D_, -+ ref_N_, -+ ref_S_, -+ ref_Softmax_, -+ batch_stride_D_, -+ batch_stride_Max_, -+ batch_stride_Sum_, -+ batch_stride_Softmax_ -+ ), -+ extend(problem_size) -+ { -+ -+ } -+ }; -+ -+ struct Params { -+ -+ typename GemmKernel::Params gemm; -+ typename SoftmaxApplyKernel::Params softmax; -+ typename ApplyFinalReductionKernel::Params reduction; -+ MatrixCoord extend; -+ // -+ // Methods -+ // -+ Params() { } -+ -+ Params(Arguments const &args): -+ gemm(args.gemm), -+ reduction(args.reduction), -+ softmax(args.softmax), -+ extend(MatrixCoord(args.extend.m(), args.extend.n())) -+ { -+ -+ } -+ }; -+ -+public: -+ -+ // Gemm -+ -+ -+ // -+ // Methods -+ // -+ -+private: -+ -+ Params params_; -+ -+public: -+ -+ /// Ctor -+ GemmSoftmax() { -+ -+ } -+ -+ /// Initialize -+ Status initialize(Arguments const &args) { -+ -+ params_ = Params(args); -+ -+ return cutlass::Status::kSuccess; -+ } -+ -+ /// Run -+ Status run(cudaStream_t stream) { -+ -+ // -+ // Launch the GEMM + max kernel -+ // -+ -+ dim3 gemm_grid = ThreadblockSwizzle().get_grid_shape(params_.gemm.grid_tiled_shape); -+ dim3 gemm_block(GemmKernel::kThreadCount, 1, 1); -+ -+ int gemm_smem_size = int(sizeof(typename GemmKernel::SharedStorage)); -+ -+ cutlass::Kernel<<>>(params_.gemm); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ if (result != cudaSuccess) { -+ return cutlass::Status::kErrorInternal; -+ } -+ -+ -+ // -+ // Launch the ApplyFinalReductionKernel -+ // -+ -+ int thread_per_block = 128; -+ int block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block; -+ if (block_per_row < 4) { -+ thread_per_block = 32; -+ block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block; -+ } -+ -+ dim3 final_reduction_grid(block_per_row, 1, params_.softmax.args.batch_count); 
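// The launch-shape arithmetic above uses thread_per_block = 128 and
// block_per_row = ceil(extend.row() / thread_per_block), falling back to
// 32 threads per block whenever that yields fewer than 4 blocks per row.
// For example, extend.row() = 256 gives ceil(256 / 128) = 2 < 4, so
// thread_per_block becomes 32 and block_per_row = ceil(256 / 32) = 8,
// i.e. a grid of (8, 1, batch_count) for the final reduction launch below.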
-+ dim3 final_reduction_block(thread_per_block); -+ -+ Kernel<<< -+ final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionKernel::SharedStorage), stream -+ >>>(params_.reduction); -+ -+ result = cudaGetLastError(); -+ -+ if (result != cudaSuccess) { -+ return cutlass::Status::kErrorInternal; -+ } -+ -+ // -+ // Launch the SoftmaxApplyKernel -+ // -+ -+ dim3 apply_block(SoftmaxApplyKernel::ApplyShape::kColumn, SoftmaxApplyKernel::ApplyShape::kRow); -+ -+ int threadblock_rows = SoftmaxApplyKernel::ApplyShape::kRow; -+ int threadblock_columns = SoftmaxApplyKernel::ApplyShape::kColumn * SoftmaxApplyKernel::kAlignment; -+ -+ dim3 apply_grid( -+ (params_.softmax.args.extent.row() + threadblock_rows - 1) / threadblock_rows, -+ (params_.softmax.args.extent.column() + threadblock_columns - 1) / threadblock_columns, -+ params_.softmax.args.batch_count); -+ -+ Kernel<<< -+ apply_grid, apply_block, sizeof(typename SoftmaxApplyKernel::SharedStorage), stream -+ >>>(params_.softmax); -+ -+ result = cudaGetLastError(); -+ -+ if (result != cudaSuccess) { -+ return cutlass::Status::kErrorInternal; -+ } -+ -+ return cutlass::Status::kSuccess; -+ } -+ -+ /// Function call operator -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu b/3rdparty/cutlass/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu -new file mode 100644 -index 0000000..3ae92c3 ---- /dev/null -+++ b/3rdparty/cutlass/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu -@@ -0,0 +1,541 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+// This example fuses gather before GEMM and scatter after GEMM into the same -+// GEMM kernel. Gather and scatter operation is controled by an index vector -+// to select rows or columns from A, B, C or D matrices. -+// -+// Suppose, all matrices are column major. The pseudo code of the fused kernel -+// in this example is essentially -+// -+// for (int i = 0; i < problem_size.m(); ++i) { -+// for (int j = 0; j < options.index_size; ++j) { -+// int b_c_d_col = tensor_indices.at({j, 0}); -+// -+// for (int k = 0; k < options.index_size; ++k) { -+// tensor_d_ref.at({i, b_c_d_col}) += -+// alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col}); -+// } -+// } -+// -+// Note that the index vector contains unique random integers with max to be N - 1 -+// -+// The gather/scatter operation works best when we can still keep the biggest -+// alignment. For example, when the matrix is row major, we select rows. When -+// the matrix is column major, we select columns. -+// -+// Not all the combination of gather and scatter are legal. For example, if A is -+// row major and C/D is column major, we cannot gather A and scatter C/D at the -+// same time. -+// -+// Also, we don't check the index value is legal and index array point is valid -+// for the sake of the performance. 
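As a concrete companion to the pseudo code above, the following minimal host-side sketch (placeholder names and plain row-major std::vector storage, not taken from the example itself) gathers columns of B and C through an index vector and scatters the results into the matching columns of D; as in the reference check performed later in this file, the inner reduction runs over the full K dimension rather than over index_size.

#include <cstddef>
#include <vector>

// Illustrative host-side reference for the fused gather-B / scatter-D GEMM:
// for each gathered column index idx[j],
//   D[i, idx[j]] = alpha * sum_k A[i, k] * B[k, idx[j]] + beta * C[i, idx[j]].
// Matrices are stored row-major in flat vectors purely for illustration.
void gather_scatter_gemm_reference(int M, int N, int K,
                                   float alpha, float beta,
                                   std::vector<float> const &A,        // M x K
                                   std::vector<float> const &B,        // K x N
                                   std::vector<float> const &C,        // M x N
                                   std::vector<float> &D,              // M x N
                                   std::vector<int> const &indices) {  // gathered/scattered columns, each < N
  for (int i = 0; i < M; ++i) {
    for (std::size_t j = 0; j < indices.size(); ++j) {
      int const col = indices[j];  // column gathered from B/C and scattered into D
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        acc += A[i * K + k] * B[k * N + col];
      }
      D[i * N + col] = alpha * acc + beta * C[i * N + col];
    }
  }
}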
-+ -+#include -+#include -+#include -+#include -+#include -+#include -+ -+#include -+#include -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_universal.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "helper.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ // -+ // Methods -+ // -+ -+ Result( -+ double runtime_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess -+ ): -+ runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ int index_size; -+ -+ bool reference_check; -+ int iterations; -+ -+ Options(): -+ help(false), -+ problem_size({248, 1024, 1024}), -+ index_size(240), -+ reference_check(true), -+ iterations(20) { } -+ -+ bool valid() { -+ return true; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("m", problem_size.m()); -+ cmd.get_cmd_line_argument("n", problem_size.n()); -+ cmd.get_cmd_line_argument("k", problem_size.k()); -+ -+ cmd.get_cmd_line_argument("index_size", index_size); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "36_gather_scatter_fusion example\n\n" -+ << " This example uses the CUTLASS Library to fuse gather/scatter into GEMM\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --m= GEMM M dimension\n" -+ << " --n= GEMM N dimension\n" -+ << " --k= GEMM K dimension\n" -+ << " --index_size= size of N dimension index\n" -+ << " --iterations= Number of profiling iterations to perform.\n\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/36_gather_scatter_fusion/36_gather_scatter_fusion --m=1024 --n=512 --k=1024 \\\n" -+ << " --index_size=128\n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fmas = problem_size.product(); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// The code section below describes datatype for input, output matrices and computation between -+// elements in input matrices. 
-+using ElementAccumulator = float; // <- data type of accumulator -+using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations -+using ElementInputA = cutlass::half_t; // <- data type of elements in input matrix A -+using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B -+using ElementOutput = float; // <- data type of elements in output matrix D -+ -+// The code section below describes matrix layout of input and output matrices. -+// Column Major for Matrix A, B and C. -+// -+using LayoutInputA = cutlass::layout::ColumnMajor; -+using LayoutInputB = cutlass::layout::ColumnMajor; -+using LayoutOutput = cutlass::layout::ColumnMajor; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ShapeMMAThreadBlock = -+ cutlass::gemm::GemmShape<128, 128, 32>; // <- threadblock tile M = 128, N = 128, K = 32 -+// This code section describes tile size a warp will compute -+using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = 64, N = 64, K = 32 -+// This code section describes the size of MMA op -+using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>; // <- MMA Op tile M = 8, N = 8, K = 4 -+// 16, 8, 8 -> Turing -+// 16, 8, 16 -> Ampere -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? -+ -+// Define the epilogue operation as LinearCombination. This is approximately equal to -+// -+// d_ij = alpha * sum_k(a_ik * b_kj) + c_ij -+// -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // <- data type of output matrix -+ 128 / cutlass::sizeof_bits::value, // <- this is the number of elements per -+ // vectorized memory access. For half -+ // precision, it's 8 elements. 
This becomes -+ // the vector width of math instructions in -+ // epilogue too -+ ElementAccumulator, // <- data type of accumulator -+ ElementComputeEpilogue>; // <- data type for alpha in linear combination function -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 5; -+// Ampere -> 4/5 -+// Turing -> 2 -+ -+using Gemm = cutlass::gemm::device::GemmUniversal; -+ -+int run(Options &options) { -+ -+ // ================================================================================ -+ // Initialization setup -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size = options.problem_size; -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size_real(problem_size.m(), -+ options.index_size, -+ problem_size.k()); -+ -+ // Initialize tensors using CUTLASS helper functions -+ cutlass::HostTensor tensor_a( -+ problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor tensor_b( -+ problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor tensor_c( -+ problem_size.mn()); // <- Create matrix C with dimensions M x N -+ cutlass::HostTensor tensor_d_scattered( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // CUTLASS kernel -+ -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(7), -+ ElementInputA(-8), -+ 0); // <- Fill matrix A on host with uniform-distribution random data -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputA(7), -+ ElementInputA(-8), -+ 0); // <- Fill matrix B on host with uniform-distribution random data -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(7), -+ ElementOutput(-8), -+ 0); // <- Fill matrix C on host with uniform-distribution random data -+ -+ cutlass::reference::host::TensorFill( -+ tensor_d_scattered.host_view()); // <- fill matrix D on host with zeros -+ -+ cutlass::HostTensor tensor_indices( -+ {options.index_size, 1}); // <- Create scatter indices with dimensions val_len x 1 -+ -+ // <- Fill tensor_b_indices on host with unique random integers -+ std::vector to_fill(problem_size.n()) ; // vector with ints. -+ std::iota (std::begin(to_fill), std::end(to_fill), 0); // Fill with 0, 1, ...., problem_size.n() -+ std::random_shuffle(to_fill.begin(), to_fill.end()); -+ memcpy(tensor_indices.host_data(), to_fill.data(), options.index_size * sizeof(int)); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_indices.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d_scattered.sync_device(); -+ -+ // Initialize alpha/beta for dot product computation -+ ElementComputeEpilogue alpha = ElementComputeEpilogue(1); -+ ElementComputeEpilogue beta = ElementComputeEpilogue(1); -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Create a tuple of gemm kernel arguments. 
This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm::Arguments arguments{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem_size_real, // <- problem size of matrix multiplication -+ split_k_slices, // <- k-dimension split factor -+ {alpha, beta}, // <- alpha, beta -+ tensor_a.device_data(), // <- reference to matrix A on device -+ tensor_b.device_data(), // <- reference to matrix B on device -+ tensor_c.device_data(), // <- reference to matrix C on device -+ tensor_d_scattered.device_data(), // <- reference to matrix D on device -+ tensor_a.layout().capacity(problem_size.mk()), -+ tensor_b.layout().capacity(cutlass::make_Coord(options.index_size, problem_size.n())), -+ tensor_c.layout().capacity(problem_size.mn()), -+ tensor_d_scattered.layout().capacity(problem_size.mn()), -+ tensor_a.layout().stride(), -+ tensor_b.layout().stride(), -+ tensor_c.layout().stride(), -+ tensor_d_scattered.layout().stride(), -+ nullptr, // <- pointer to index vector to gather A on device -+ tensor_indices.device_data(), // <- pointer to index vector to gather B on device -+ tensor_indices.device_data()}; // <- pointer to index vector to scatter D on device -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm gemm_op; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status = gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(status); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status = gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(status); -+ -+ // CPU reference calculation -+ cutlass::HostTensor tensor_d_ref(problem_size.mn()); -+ cutlass::reference::host::TensorFill( -+ tensor_d_ref.host_view()); // <- Fill matrix D on host with zeros -+ -+ status = gemm_op(); -+ cudaDeviceSynchronize(); -+ CUTLASS_CHECK(status); -+ -+ if (options.reference_check) { -+ for (int i = 0; i < problem_size.m(); ++i) { -+ for (int j = 0; j < options.index_size; ++j) { -+ int b_c_d_col = tensor_indices.at({j, 0}); -+ -+ for (int k = 0; k < problem_size.k(); ++k) { -+ tensor_d_ref.at({i, b_c_d_col}) += -+ alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col}); -+ } -+ -+ tensor_d_ref.at({i, b_c_d_col}) += (beta * tensor_c.at({i, b_c_d_col})); -+ } -+ } -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d_scattered.sync_host(); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_d_scattered.host_view(), -+ tensor_d_ref.host_view()); -+ -+ if (!passed) { -+ std::cout << "Failed!\n"; -+ -+ std::stringstream fname; -+ fname << "error_gather_GEMM_scatter_fusion.txt"; -+ std::cerr << "Dumping results in " << fname.str() << "\n"; -+ -+ std::ofstream file(fname.str()); -+ -+ file -+ << "A =\n" << tensor_a.host_view() -+ << "\nB =\n" << tensor_b.host_view() -+ << "\nindices =\n" << tensor_indices.host_view() -+ << "\nC =\n" << tensor_c.host_view() -+ << "\n\nReference =\n" << tensor_d_ref.host_view() -+ << "\nComputed =\n" << tensor_d_scattered.host_view(); -+ return -1; -+ } else { -+ std::cout << "Passed!\n"; -+ } -+ } -+ -+ // Result structure -+ Result result; -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = 
cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMMs -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ // Launch initialized CUTLASS kernel -+ status = gemm_op(); -+ CUTLASS_CHECK(status); -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMMs are complete -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ std::cout << "Runtime: " << result.runtime_ms << " ms\n"; -+ std::cout << " GFLOPs: " << result.gflops << "\n"; -+ -+ return 0; -+} -+ -+int main(int argc, const char ** argv) { -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major >= 8)) { -+ std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ Options options; -+ options.parse(argc, argv); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << "\n"; -+ return 0; -+ } -+ -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << "\n"; -+ return -1; -+ } -+ -+ return run(options); -+} -diff --git a/3rdparty/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_layernorm.cu b/3rdparty/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_layernorm.cu -new file mode 100644 -index 0000000..ffe378b ---- /dev/null -+++ b/3rdparty/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_layernorm.cu -@@ -0,0 +1,937 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief CUTLASS Layernorm Example. -+ -+ This workload provides a layer normalization example using a one-pass, square-sum-based -+ variance calculation. Specifically, we fuse the reduction operation to find -+ local mean and local square sum mean in the epilogue of 1st GEMM. After a light -+ full reduction kernel, the mean / variance values are readily calculated for element-wise -+ operations which are fused into the 2nd GEMM. -+ -+ As stated in https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data, -+ the square-sum based one-pass implementation may raise concerns on numerical stability issues. -+ That being said, though this fully fused layernorm example almost perfectly hides all the memory cost to -+ access the intermediate matrix for layernorm computation, the numerical issue might hinder a persuasive -+ usage in real-world scenarios. 
If that is the case, a user may turn to the stand-alone CUTLASS layernorm -+ example in tools/util/include/cutlass/util/device_layernorm.h -+ -+ Examples: -+ -+ # Run a CUTLASS layernorm example with default setup , -+ # using the language of the transformer model as an example, -+ (Column Major output matrix, hidden dimension = 768, valid word number = 4096, intermediate_scale = 4) -+ $ ./examples/37_gemm_layernorm_gemm_fusion/37_gemm_layernorm_gemm_fusion -+ -+ # Run an attention example with hidden dimension = 512 -+ $ ./examples/37_gemm_layernorm_gemm_fusion/37_gemm_layernorm_gemm_fusion --hidden_dim=512 -+ -+*/ -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+#include "cutlass/epilogue/thread/scale_type.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+#include "cutlass/util/reference/host/tensor_reduce.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/error_metrics.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/fast_math.h" -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "gemm_with_layernorm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+enum class Disposition { -+ kPassed, -+ kIncorrect, -+ kNotVerified -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+template -+struct Options { -+ -+ using LayoutOutput = LayoutOutput_; -+ -+ static bool const kIsColumnMajorOutput = cutlass::platform::is_same::value; -+ -+ bool help; -+ cutlass::gemm::GemmCoord problem_size0; -+ cutlass::gemm::GemmCoord problem_size1; -+ int hidden_dim; -+ int valid_word_num; -+ int intermediate_scale; -+ int iterations; -+ unsigned seed; -+ float alpha; -+ float beta; -+ bool verification_enabled; -+ double tolerance; -+ -+ Options(): -+ help(false), -+ iterations(20), -+ seed(2022), -+ hidden_dim(768), -+ valid_word_num(4096), -+ intermediate_scale(4), -+ alpha(1), -+ beta(0), -+ verification_enabled(true), -+ tolerance(0.01), -+ problem_size1(problem_size0.m() * 4, problem_size0.n(), problem_size0.m()) -+ { } -+ -+ bool valid() { -+ -+ return true; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("hidden_dim", hidden_dim, 768); -+ cmd.get_cmd_line_argument("valid_word_num", valid_word_num, 4096); -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("verify", verification_enabled); -+ cmd.get_cmd_line_argument("seed", seed); -+ cmd.get_cmd_line_argument("tolerance", tolerance); -+ -+ if (kIsColumnMajorOutput) { -+ // column major output setup -+ problem_size0.m() = hidden_dim; -+ problem_size0.n() = valid_word_num; -+ problem_size0.k() = hidden_dim; -+ -+ problem_size1.m() = hidden_dim * intermediate_scale; -+ problem_size1.n() = valid_word_num; -+ 
problem_size1.k() = hidden_dim; -+ }else{ -+ // row major output setup -+ problem_size0.m() = valid_word_num; -+ problem_size0.n() = hidden_dim; -+ problem_size0.k() = hidden_dim; -+ -+ problem_size1.m() = valid_word_num; -+ problem_size1.n() = hidden_dim * intermediate_scale; -+ problem_size1.k() = hidden_dim; -+ } -+ -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "37_gemm_layernorm_gemm_fusion example\n\n" -+ << " This example uses the CUTLASS Library to compute GEMM + Layernorm for arbitrary problem sizes.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --hidden_dim= Hidden dimension\n" -+ << " --valid_word_num= Valid word number\n" -+ << " --seed= Random number seed (1*)\n\n" -+ << " --iterations= Number of profiling iterations to perform (0 to disable profiling).\n\n" -+ << " --verify= If true, performs reference calculation.\n\n" -+ << " --tolerance Error tolerance\n" -+ ; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/37_gemm_layernorm_gemm_fusion/37_gemm_layernorm_gemm_fusion \\\n" -+ << " --hidden_dim=768 --valid_word_num=1024 \n\n"; -+ -+ return out; -+ } -+ -+ /// Returns true if the environment and Toolkit support this -+ bool supported(bool verbose = true) const { -+ -+ // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available -+ // in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ >= 11)) { -+ if (verbose) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ } -+ return false; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ if (verbose) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ } -+ return false; -+ } -+ -+ if (!((props.major * 10 + props.minor) >= 80)) { -+ if (verbose) { -+ std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." -+ << std::endl; -+ } -+ return false; -+ } -+ -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. -+ // -+ int const kAlignment = 8; -+ -+ if ((problem_size0.m() % kAlignment) || -+ (problem_size0.n() % kAlignment) || -+ (problem_size0.k() % kAlignment)) { -+ if (verbose) { -+ std::cerr << "Misaligned input in 1st GEMM." << std::endl; -+ } -+ // misaligned tensors for Gemm1 -+ return false; -+ } -+ -+ if ((problem_size1.m() % kAlignment) || -+ (problem_size1.n() % kAlignment) || -+ (problem_size1.k() % kAlignment)) { -+ if (verbose) { -+ std::cerr << "Misaligned input in 2nd GEMM." 
<< std::endl; -+ } -+ // misaligned tensors for Gemm2 -+ return false; -+ } -+ -+ return true; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template< -+ typename LayoutOutput_> -+struct Testbed { -+ -+ // -+ // Type definitions -+ // -+ -+ // User-defined data types -+ using ElementInputA0 = cutlass::half_t; -+ using ElementInputB0 = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ using LayoutInputA0 = cutlass::layout::RowMajor; -+ using LayoutInputB0 = cutlass::layout::ColumnMajor; -+ using LayoutOutput = LayoutOutput_; -+ -+ static bool const kIsColumnMajorOutput = cutlass::platform::is_same::value; -+ // turn of shifted K by default -+ static bool const kIsShiftedVariance = false; -+ -+ /// Linear scaling operator -+ using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementCompute, -+ ElementCompute -+ >; -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kStages0 = 3; -+ static int const kStages1 = 4; -+ -+ using GemmLayernorm = cutlass::GemmLayernorm< -+ ElementInputA0, -+ LayoutInputA0, -+ ElementInputB0, -+ LayoutInputB0, -+ ElementOutput, -+ LayoutOutput, -+ ElementCompute, -+ EpilogueFunctorOp, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ kStages0, -+ kStages1, -+ kIsShiftedVariance -+ >; -+ -+ using ElementInputA1 = typename GemmLayernorm::ElementInputA1; -+ using ElementOutputC1 = typename GemmLayernorm::ElementOutputC1; -+ using ElementInputScaleBias = typename GemmLayernorm::ElementInputScaleBias; -+ using ElementLayernormCompute = typename GemmLayernorm::ElementLayernormCompute; -+ -+ using LayoutInputA1 = typename GemmLayernorm::LayoutInputA1; -+ using LayoutOutputC0 = typename GemmLayernorm::LayoutOutputC0; -+ using LayoutOutputC1 = typename GemmLayernorm::LayoutOutputC1; -+ using LayoutInputScaleBias = typename GemmLayernorm::LayoutInputScaleBias; -+ -+ // -+ // Data members -+ // -+ -+ Options const &options; -+ -+ cutlass::HostTensor tensor_A0; -+ cutlass::HostTensor tensor_B0; -+ cutlass::HostTensor tensor_C0; -+ cutlass::HostTensor tensor_A1; -+ cutlass::HostTensor tensor_C1; -+ -+ cutlass::HostTensor reference_C0; -+ cutlass::HostTensor reference_C1; -+ -+ cutlass::HostTensor tensor_Variance; -+ cutlass::HostTensor tensor_Mean; -+ cutlass::HostTensor tensor_Beta; -+ cutlass::HostTensor tensor_Gamma; -+ -+ cutlass::HostTensor reference_Mean; -+ cutlass::HostTensor reference_Variance; -+ -+ // shifted K tensor to better ensure the numerical stability -+ // According to https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance -+ // the closer shifted K to the actual mean, the better numerical stability we'll observe -+ cutlass::HostTensor tensor_Shifted_K; -+ -+ // -+ // Methods -+ // -+ -+ Testbed( -+ Options const &options_ -+ ): -+ options(options_) -+ { -+ -+ tensor_A0.reset({options.problem_size0.m(), options.problem_size0.k()}); -+ tensor_B0.reset({options.problem_size0.k(), options.problem_size0.n()}); -+ -+ tensor_C0.reset({options.problem_size0.m(), options.problem_size0.n()}); -+ -+ tensor_A1.reset({options.problem_size1.m(), options.problem_size1.k()}); -+ tensor_C1.reset({options.problem_size1.m(), options.problem_size1.n()}); -+ -+ 
reference_C0.reset({options.problem_size0.m(), options.problem_size0.n()}); -+ reference_C1.reset({options.problem_size1.m(), options.problem_size1.n()}); -+ -+ int leading_dim_0 = kIsColumnMajorOutput ? options.problem_size0.n() : options.problem_size0.m(); -+ int leading_dim_1 = kIsColumnMajorOutput ? options.problem_size0.m() : options.problem_size0.n(); -+ -+ int block_num = (leading_dim_1 + GemmLayernorm::ThreadblockShape::kM - 1) / GemmLayernorm::ThreadblockShape::kM; -+ -+ tensor_Variance.reset({block_num, leading_dim_0}); -+ tensor_Mean.reset({block_num, leading_dim_0}); -+ tensor_Shifted_K.reset({1, leading_dim_0}); -+ -+ tensor_Beta.reset({1, leading_dim_1}); -+ tensor_Gamma.reset({1, leading_dim_1}); -+ -+ reference_Mean.reset({1, leading_dim_0}, false); -+ reference_Variance.reset({1, leading_dim_0}, false); -+ -+ } -+ -+ /// Run -+ Disposition run() { -+ -+ Disposition disposition = Disposition::kNotVerified; -+ -+ // -+ // Initialize the workspace -+ // -+ -+ initialize(); -+ -+ // -+ // Launch device kernel -+ // -+ cutlass::Status status = cutlass::Status::kSuccess; -+ -+ status = execute_device_kernel(); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "Device execution failed." << std::endl; -+ return disposition; -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ std::cerr << "Device synchronize failed with error " -+ << cudaGetErrorString(result) << std::endl; -+ return disposition; -+ } -+ -+ // -+ // Compute the reference -+ // -+ compute_reference(); -+ -+ // -+ // Verify -+ // -+ -+ if (options.verification_enabled) { -+ -+ bool passed = verify(); -+ -+ if (passed) { -+ disposition = Disposition::kPassed; -+ } -+ else { -+ disposition = Disposition::kIncorrect; -+ } -+ } -+ -+ // -+ // Profiling -+ // -+ if (options.iterations) { -+ profile(); -+ } -+ -+ return disposition; -+ } -+ -+ /// Random initialization -+ void initialize() { -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_A0.host_view(), -+ options.seed, -+ ElementInputA0(5), -+ ElementInputA0(-5), -+ 0 -+ ); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_B0.host_view(), -+ options.seed + 1, -+ ElementInputB0(5), -+ ElementInputB0(-5), -+ 0 -+ ); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_A1.host_view(), -+ options.seed + 2, -+ ElementInputA1(5), -+ ElementInputA1(-5), -+ 0 -+ ); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_Beta.host_view(), -+ options.seed + 3, -+ ElementInputScaleBias(5), -+ ElementInputScaleBias(-5), -+ 0 -+ ); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_Gamma.host_view(), -+ options.seed + 4, -+ ElementInputScaleBias(5), -+ ElementInputScaleBias(-5), -+ 0 -+ ); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_Shifted_K.host_view(), -+ options.seed + 5, -+ ElementOutput(5), -+ ElementOutput(-6), -+ 0 -+ ); -+ -+ tensor_A0.sync_device(); -+ tensor_B0.sync_device(); -+ tensor_A1.sync_device(); -+ tensor_Beta.sync_device(); -+ tensor_Gamma.sync_device(); -+ -+ } -+ -+ -+ -+ cutlass::Status execute_device_kernel() { -+ -+ cutlass::Status status = cutlass::Status::kSuccess; -+ -+ // -+ // Setup arguments -+ // -+ -+ typename GemmLayernorm::Arguments args( -+ options.problem_size0, -+ options.problem_size1, -+ tensor_A0.device_ref().data(), -+ tensor_B0.device_ref().data(), -+ tensor_C0.device_ref().data(), -+ tensor_C0.device_ref().data(), -+ tensor_A1.device_ref().data(), -+ tensor_C1.device_ref().data(), -+ 
tensor_A0.device_ref().stride(0), -+ tensor_B0.device_ref().stride(0), -+ tensor_C0.device_ref().stride(0), -+ tensor_C0.device_ref().stride(0), -+ tensor_A1.device_ref().stride(0), -+ tensor_C1.device_ref().stride(0), -+ { -+ ElementCompute(options.alpha), -+ ElementCompute(options.beta) -+ }, -+ tensor_Variance.device_ref(), -+ tensor_Mean.device_ref(), -+ tensor_Gamma.device_ref(), -+ tensor_Beta.device_ref(), -+ tensor_Shifted_K.device_ref().data() -+ ); -+ -+ // -+ // Launch -+ // -+ -+ GemmLayernorm gemm_layernorm; -+ -+ // Initialize -+ status = gemm_layernorm.initialize(args); -+ if (status != cutlass::Status::kSuccess) { -+ return status; -+ } -+ -+ // Run -+ status = gemm_layernorm(); -+ -+ return status; -+ } -+ -+ /// Reference calculation -+ void compute_reference() { -+ -+ cutlass::reference::device::Gemm< -+ ElementInputA0, -+ LayoutInputA0, -+ ElementInputB0, -+ LayoutInputB0, -+ ElementOutput, -+ LayoutOutputC0, -+ ElementCompute, -+ ElementCompute -+ > gemm_device0; -+ -+ cutlass::reference::device::Gemm< -+ ElementInputA1, -+ LayoutInputA1, -+ ElementOutput, -+ LayoutOutputC0, -+ ElementOutputC1, -+ LayoutOutputC1, -+ ElementCompute, -+ ElementCompute -+ > gemm_device1; -+ -+ // Compute 1st GEMM -+ gemm_device0( -+ options.problem_size0, -+ ElementCompute(options.alpha), -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ ElementCompute(options.beta), -+ tensor_C0.device_ref(), -+ reference_C0.device_ref() -+ ); -+ -+ reference_C0.sync_host(); -+ -+ tensor_Mean.sync_host(); -+ tensor_Variance.sync_host(); -+ tensor_Gamma.sync_host(); -+ tensor_Beta.sync_host(); -+ tensor_Shifted_K.sync_host(); -+ -+ // Compute the sum and square sum for verification purpose -+ if (kIsColumnMajorOutput) { -+ for (int n = 0; n < options.problem_size0.n(); ++n) { -+ -+ ElementLayernormCompute sum = ElementLayernormCompute(0); -+ ElementLayernormCompute square_sum = ElementLayernormCompute(0); -+ for (int m = 0; m < options.problem_size0.m(); ++m) { -+ sum += ElementLayernormCompute(reference_C0.at({m, n})); -+ square_sum += ElementLayernormCompute(reference_C0.at({m, n})) * ElementLayernormCompute(reference_C0.at({m, n})); -+ } -+ -+ ElementLayernormCompute mean = sum / ElementLayernormCompute(options.problem_size0.m()); -+ ElementLayernormCompute square_mean = square_sum / ElementLayernormCompute(options.problem_size0.m()); -+ ElementLayernormCompute variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean - mean * mean + ElementLayernormCompute(1e-6) ) ; -+ -+ mean = -mean * variance; -+ -+ reference_Mean.at({0, n}) = ElementInputScaleBias(mean); -+ reference_Variance.at({0, n}) = ElementInputScaleBias(variance); -+ } -+ }else{ -+ for (int m = 0; m < options.problem_size0.m(); ++m) { -+ -+ ElementLayernormCompute sum = ElementLayernormCompute(0); -+ ElementLayernormCompute square_sum = ElementLayernormCompute(0); -+ for (int n = 0; n < options.problem_size0.n(); ++n) { -+ sum += ElementLayernormCompute(reference_C0.at({m, n})) ; -+ square_sum += ElementLayernormCompute(reference_C0.at({m, n})) * ElementLayernormCompute(reference_C0.at({m, n})) ; -+ } -+ -+ ElementLayernormCompute mean = sum / ElementLayernormCompute(options.problem_size0.n()); -+ ElementLayernormCompute square_mean = square_sum / ElementLayernormCompute(options.problem_size0.n()); -+ ElementLayernormCompute variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean - mean * mean + ElementLayernormCompute(1e-6)) ; -+ -+ mean = -mean * variance; -+ -+ reference_Mean.at({0, m}) = 
ElementInputScaleBias(mean); -+ reference_Variance.at({0, m}) = ElementInputScaleBias(variance); -+ } -+ } -+ -+ // Element-wise transform for OutputC0 using 1-pass layernorm algo -+ if (kIsColumnMajorOutput) { -+ for (int n = 0; n < options.problem_size0.n(); ++n) { -+ -+ ElementLayernormCompute sum = ElementLayernormCompute(0); -+ for (int m = 0; m < options.problem_size0.m(); ++m) { -+ sum += ElementLayernormCompute(reference_C0.at({m, n})) ; -+ } -+ -+ ElementInputScaleBias mean = ElementInputScaleBias(sum / ElementLayernormCompute(options.problem_size0.m())); -+ sum = ElementLayernormCompute(0); -+ for (int m = 0; m < options.problem_size0.m(); ++m) { -+ sum += ElementLayernormCompute(reference_C0.at({m, n}) - ElementLayernormCompute(mean)) * ElementLayernormCompute(reference_C0.at({m, n}) - ElementLayernormCompute(mean)) ; -+ } -+ -+ ElementLayernormCompute square_mean = sum / ElementLayernormCompute(options.problem_size0.m()); -+ ElementInputScaleBias variance = ElementInputScaleBias(cutlass::constants::one() -+ / cutlass::fast_sqrt(square_mean + ElementLayernormCompute(1e-6))) ; -+ -+ for (int m = 0; m < options.problem_size0.m(); ++m) { -+ reference_C0.at({m, n}) = -+ ElementOutput( ( (ElementInputScaleBias(reference_C0.at({m, n})) - mean) * variance ) -+ * tensor_Gamma.at({0, m}) + tensor_Beta.at({0, m})); -+ -+ } -+ -+ } -+ }else{ -+ -+ for (int m = 0; m < options.problem_size0.m(); ++m) { -+ -+ float sum = float(0); -+ for (int n = 0; n < options.problem_size0.n(); ++n) { -+ sum += float(reference_C0.at({m, n})) ; -+ } -+ -+ float mean = sum / float(options.problem_size0.n()); -+ sum = float(0); -+ for (int n = 0; n < options.problem_size0.n(); ++n) { -+ sum += float(reference_C0.at({m, n}) - mean) * float(reference_C0.at({m, n}) - mean) ; -+ } -+ -+ float square_mean = sum / float(options.problem_size0.n()); -+ float variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean + ElementLayernormCompute(1e-6)) ; -+ -+ for (int n = 0; n < options.problem_size0.n(); ++n) { -+ reference_C0.at({m, n}) = -+ ElementOutput( ( (float(reference_C0.at({m, n})) - mean) * variance ) -+ * float(tensor_Gamma.at({0, n})) + float(tensor_Beta.at({0, n}))); -+ -+ } -+ -+ } -+ -+ } -+ -+ -+ // Sync host data with device after element-wise transform -+ reference_C0.sync_device(); -+ -+ // Compute 2nd GEMM -+ gemm_device1( -+ options.problem_size1, -+ ElementCompute(options.alpha), -+ kIsColumnMajorOutput ? tensor_A1.device_ref() : reference_C0.device_ref(), -+ kIsColumnMajorOutput ? 
reference_C0.device_ref() :tensor_A1.device_ref(), -+ ElementCompute(options.beta), -+ reference_C1.device_ref(), -+ reference_C1.device_ref() -+ ); -+ -+ } -+ -+ /// Emits all tensor values -+ void emit_results() { -+ std::cout << "tensor_C1 = \n" << tensor_C1.host_view() << "\n\n"; -+ std::cout << "Reference C1 = \n" << reference_C1.host_view() << "\n\n"; -+ std::cout << "Mean = \n" << tensor_Mean.host_view() << "\n\n"; -+ std::cout << "rsqrt(Variance) = \n" << tensor_Variance.host_view() << "\n\n"; -+ std::cout << "Reference Mean = \n" << reference_Mean.host_view() << "\n\n"; -+ std::cout << "Reference rsqrt(Variance) = \n" << reference_Variance.host_view() << "\n\n"; -+ } -+ -+ template -+ bool verify_tensor(cutlass::HostTensor tensor, \ -+ cutlass::HostTensor reference, -+ int leading_dim0, int leading_dim1, bool is_print = false) { -+ float const kThreshold = float(options.tolerance); -+ float const kAbsThreshold = 0.5f; -+ float const kRelativeThreshold = 0.1f; -+ // Adds a constant bias to avoid being divided by '0' -+ float const kBias = 1e-5f; -+ int counter = 0; -+ for (int m = 0; m < leading_dim0; m++) { -+ for (int n = 0; n < leading_dim1; ++n) { -+ float diff = (float)(tensor.at({m, n}) - reference.at({m, n})); -+ float rel_diff = fabs(diff) / fabs(reference.at({m, n}) + kBias); -+ if (fabs(diff) > kAbsThreshold && rel_diff > kRelativeThreshold) { -+ counter++; -+ } -+ } -+ } -+ -+ float err_rate = float(counter) / (float(leading_dim0) * float(leading_dim1)); -+ return (err_rate < kThreshold); -+ } -+ -+ /// Verifies the reference matches -+ bool verify() { -+ -+ tensor_Variance.sync_host(); -+ tensor_Mean.sync_host(); -+ tensor_C1.sync_host(); -+ reference_C1.sync_host(); -+ -+ // Verification checks - set any of these to 'true' to override the verification checks. 
-+ bool verified_C1 = false; -+ bool verified_Mean = false; -+ bool verified_Variance = false; -+ -+ // Verify layernorm output -+ if (!verified_C1) { -+ verified_C1 = verify_tensor(tensor_C1, reference_C1, options.problem_size1.m(), options.problem_size1.n()); -+ } -+ -+ if (!verified_Variance) { -+ verified_Variance = verify_tensor(tensor_Variance, reference_Variance, 1, options.problem_size0.n()); -+ } -+ -+ if (!verified_Mean) { -+ verified_Mean = verify_tensor(tensor_Mean, reference_Mean, 1, options.problem_size0.n()); -+ } -+ -+ if (!verified_C1 || !verified_Mean || !verified_Variance) { -+ -+ // emit_results(); -+ -+ std::cerr << "Verification check failed for tensor Layernorm" << std::endl; -+ -+ // Summarize which checks failed -+ if (!verified_C1) { -+ std::cerr << "Verification of O tensor failed\n"; -+ } -+ -+ if (!verified_Mean) { -+ std::cerr << "Verification of Mean tensor failed\n"; -+ } -+ -+ if (!verified_Variance) { -+ std::cerr << "Verification of Variance tensor failed\n"; -+ } -+ -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Profiles -+ bool profile() { -+ -+ // -+ // Profile -+ // -+ -+ cutlass::Status status = cutlass::Status::kSuccess; -+ cudaError_t result; -+ cudaEvent_t events[2]; -+ int const kIterations = options.iterations; -+ -+ for (cudaEvent_t &evt : events) { -+ result = cudaEventCreate(&evt); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaEventCreate failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ } -+ -+ result = cudaEventRecord(events[0]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ for (int iter = 0; iter < kIterations; ++iter) { -+ -+ status = execute_device_kernel(); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "Device execution failed." 
<< std::endl; -+ return false; -+ } -+ } -+ -+ result = cudaEventRecord(events[1]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "cudaDeviceSynchronize() failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ float elapsed_ms = 0; -+ result = cudaEventElapsedTime(&elapsed_ms, events[0], events[1]); -+ -+ float elapsed_ms_per_iter = elapsed_ms / float(kIterations); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "cudaEventElapsedTime() failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ for (cudaEvent_t &evt : events) { -+ result = cudaEventDestroy(evt); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaEventDestroy() failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ } -+ -+ int64_t flops = int64_t(options.problem_size0.m()) * options.problem_size0.n() * options.problem_size0.k() * 2 \ -+ + int64_t(options.problem_size1.m()) * options.problem_size1.n() * options.problem_size1.k() * 2; -+ -+ double gflops_per_second = double(flops) * kIterations / double(elapsed_ms / 1000.0f) / double(1.0e9); -+ -+ std::cout << " 1st GEMM: " -+ << options.problem_size0.m() << "-by-" << options.problem_size0.n() << "-by-" << options.problem_size0.k() << "\n" -+ << " 2nd GEMM: " -+ << options.problem_size1.m() << "-by-" << options.problem_size1.n() << "-by-" << options.problem_size1.k() -+ << std::endl; -+ -+ std::cout << " Runtime / iteration: " << elapsed_ms_per_iter << " ms\n" << std::endl; -+ std::cout << " GFLOPs: " << gflops_per_second << " GFLOPs" << std::endl; -+ -+ return true; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, const char **argv) { -+ -+ // Define final layout -+ using LayoutOutput = cutlass::layout::ColumnMajor; -+ -+ // Options parsing -+ Options options; -+ options.parse(argc, argv); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (!options.supported()) { -+ return 0; -+ } -+ -+ // Run -+ Testbed testbed(options); -+ -+ Disposition disposition = testbed.run(); -+ -+ std::cout << std::endl; -+ -+ switch (disposition) { -+ case Disposition::kPassed: -+ std::cout << "Passed" << std::endl; -+ break; -+ case Disposition::kIncorrect: -+ std::cout << "Incorrect" << std::endl; -+ break; -+ case Disposition::kNotVerified: -+ std::cout << "Not verified" << std::endl; -+ break; -+ } -+ -+ return (disposition == Disposition::kPassed ? 0 : -1); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_epilogue_visitor.h b/3rdparty/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_epilogue_visitor.h -new file mode 100644 -index 0000000..143bca3 ---- /dev/null -+++ b/3rdparty/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_epilogue_visitor.h -@@ -0,0 +1,444 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief GEMM kernel to support the epilogue visitor model -+ for customized layernorm partial reduction epilogue fusion. -+ -+ This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once -+ its usage has been stabilized. For now, it is included in this example to demonstrate -+ some basic output fusion options. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+ -+#include "cutlass/trace.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! 
Threadblock swizzling function -+> -+struct GemmWithEpilogueVisitor { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueVisitor = typename Epilogue::Visitor; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using TensorRefA = TensorRef; -+ -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using TensorRefB = TensorRef; -+ -+ using ElementC = typename EpilogueVisitor::ElementOutput; -+ using LayoutC = typename Epilogue::Layout; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ using Operator = typename Mma::Operator; -+ -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = const_max( -+ 128 / sizeof_bits::value, -+ 128 / sizeof_bits::value -+ ); -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmUniversalMode mode; -+ GemmCoord problem_size; -+ -+ TensorRefA ref_A; -+ TensorRefB ref_B; -+ -+ typename EpilogueVisitor::Arguments epilogue_visitor; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ mode(GemmUniversalMode::kGemm) -+ { } -+ -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode_, -+ GemmCoord problem_size_, -+ TensorRefA ref_A_, -+ TensorRefB ref_B_, -+ typename EpilogueVisitor::Arguments epilogue_visitor_ -+ ): -+ mode(mode_), -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ ref_B(ref_B_), -+ epilogue_visitor(epilogue_visitor_) -+ { -+ -+ } -+ }; -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params { -+ -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorB::Params params_B; -+ -+ GemmUniversalMode mode; -+ int gemm_k_size; -+ -+ void * ptr_A; -+ void * ptr_B; -+ -+ typename EpilogueVisitor::Params epilogue_visitor; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ swizzle_log_tile(0), -+ params_A(0), -+ params_B(0), -+ gemm_k_size(0), -+ mode(cutlass::gemm::GemmUniversalMode::kGemm), -+ ptr_A(nullptr), -+ ptr_B(nullptr) -+ { } -+ -+ -+ Params( -+ Arguments const &args -+ ): -+ problem_size(args.problem_size), -+ swizzle_log_tile(0), -+ params_A(args.ref_A.layout()), -+ params_B(args.ref_B.layout()), -+ mode(args.mode), -+ gemm_k_size(args.problem_size.k()), -+ ptr_A(args.ref_A.data()), -+ ptr_B(args.ref_B.data()), -+ epilogue_visitor(args.epilogue_visitor) -+ { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ 
grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, 1); -+ -+ if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); -+ -+ gemm_k_size = round_up(args.problem_size.k(), kAlignK); -+ -+ if (gemm_k_size) { -+ grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); -+ } -+ } -+ -+ swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ -+ typename Mma::SharedStorage main_loop; -+ -+ struct { -+ typename Epilogue::SharedStorage epilogue; -+ typename EpilogueVisitor::SharedStorage visitor; -+ } epilogue; -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ GemmWithEpilogueVisitor() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) { -+ -+ CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); -+ -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ bool isAMisaligned = false; -+ bool isBMisaligned = false; -+ bool isCMisaligned = false; -+ -+ if (platform::is_same::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } else if (platform::is_same::value) { -+ isAMisaligned = problem_size.m() % kAlignmentA; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } -+ -+ if (platform::is_same::value) { -+ isBMisaligned = problem_size.n() % kAlignmentB; -+ } else if (platform::is_same::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } -+ -+ if (platform::is_same::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } else if (platform::is_same::value) { -+ isCMisaligned = problem_size.m() % kAlignmentC; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } -+ -+ if (isAMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isBMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isCMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ CUTLASS_TRACE_HOST(" returning kSuccess"); -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= 
threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ accumulators); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ // -+ // Construct the epilogue visitor -+ // -+ -+ EpilogueVisitor epilogue_visitor( -+ params.epilogue_visitor, -+ shared_storage.epilogue.visitor, -+ params.problem_size.mn(), -+ thread_idx, -+ warp_idx, -+ lane_idx, -+ threadblock_offset); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ // Indicate which position in a serial reduction the output operator is currently updating -+ epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) { -+ epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); -+ } -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Execute the epilogue operator to update the destination tensor. 
-+ epilogue(epilogue_visitor, accumulators); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h b/3rdparty/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h -new file mode 100644 -index 0000000..dde3c07 ---- /dev/null -+++ b/3rdparty/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h -@@ -0,0 +1,1066 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief A file containing all functioning classes needed by GemmLayernorm.
-+ -+ GemmLayernorm example = GEMM0 with partial reduction fused in epilogue (EpilogueVisitorLayerNorm) -+ + lightweight full reduction kernel (ApplyFinalReduction) -+ + GEMM1 with elementwise operations fused in mainloop (GemmLayernormMainloopFusion) -+ -+*/ -+ -+#pragma once -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/device/gemm_layernorm_mainloop_fusion.h" -+#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/kernel/default_gemm_complex.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "gemm_with_epilogue_visitor.h" -+#include "helper.h" -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementVariance_, -+ typename ElementMean_, -+ typename ElementLayernormCompute_, -+ typename ElementOutput, -+ typename ThreadblockShape_, -+ bool IsShiftedVariance_ = false -+> -+class ApplyFinalReduction { -+public: -+ -+ using ElementVariance = ElementVariance_; -+ using ElementMean = ElementMean_; -+ using ElementLayernormCompute = ElementLayernormCompute_; -+ using ThreadblockShape = ThreadblockShape_; -+ -+ // Pre-processing has ensured the layout equivalent to RowMajor -+ using Layout = cutlass::layout::RowMajor; -+ -+ using TensorVariance = TensorRef; -+ using TensorMean = TensorRef; -+ -+ static bool const kIsShiftedVariance = IsShiftedVariance_; -+ -+ // -+ // Arguments -+ // -+ -+ struct Arguments { -+ -+ MatrixCoord extent; ///< Extent of D and Layernorm matrices -+ TensorVariance ref_Variance; ///< Sum Square or Variance tensor (input / output) -+ TensorMean ref_Mean; ///< Sum or Mean tensor (input / output) -+ ElementOutput *ptr_Shifted_K; ///< Shifted K tensor pointer -+ -+ // -+ // Methods -+ // -+ Arguments(){ } -+ -+ Arguments( -+ MatrixCoord extent_, -+ TensorVariance ref_Variance_, -+ TensorMean ref_Mean_, -+ ElementOutput *ptr_Shifted_K_ -+ ): -+ extent(extent_), -+ ref_Variance(ref_Variance_), -+ ref_Mean(ref_Mean_), -+ ptr_Shifted_K(ptr_Shifted_K_) -+ { -+ -+ } -+ }; -+ -+ struct SharedStorage { -+ -+ -+ }; -+ -+ // -+ // Params struct -+ // -+ -+ struct Params { -+ Arguments args; -+ -+ // -+ // Methods -+ // -+ Params() { } -+ -+ Params(Arguments const &args_): args(args_) { } -+ }; -+ -+private: -+ -+public: -+ -+ CUTLASS_DEVICE -+ ApplyFinalReduction() { } -+ -+ CUTLASS_DEVICE -+ void operator()(Params const &params, SharedStorage &shared_storage) { -+ -+ apply(params, shared_storage); -+ } -+ -+private: -+ -+ /// Partial reduction -+ CUTLASS_DEVICE -+ void apply(Params const &params, SharedStorage &shared_storage) { -+ -+ int threadblock_num = (params.args.extent.column() + ThreadblockShape::kM - 1) / ThreadblockShape::kM; -+ -+ int block_n = blockIdx.x * blockDim.x; -+ -+ int thread_n = threadIdx.x; -+ -+ int idx_n = block_n + thread_n; -+ -+ if 
(idx_n >= params.args.extent.row()) { -+ return; -+ } -+ -+ using ConvertVarianceOutput = cutlass::NumericConverter; -+ using ConvertMeanOutput = cutlass::NumericConverter; -+ -+ using ConvertVariance = cutlass::NumericConverter; -+ using ConvertMean = cutlass::NumericConverter; -+ -+ using ConvertShiftK = cutlass::NumericConverter; -+ -+ ConvertVariance convert_variance; -+ ConvertMean convert_mean; -+ -+ ConvertVarianceOutput convert_variance_output; -+ ConvertMeanOutput convert_mean_output; -+ -+ ElementVariance *access_square = params.args.ref_Variance.data() + idx_n; -+ ElementMean *access_mean = params.args.ref_Mean.data() + idx_n; -+ -+ ElementVariance *access_square_bak = access_square; -+ ElementMean *access_mean_bak = access_mean; -+ -+ ElementLayernormCompute frag_square_sum = ElementLayernormCompute(0); -+ ElementLayernormCompute frag_element_sum = ElementLayernormCompute(0); -+ ElementVariance fetch_square; -+ ElementMean fetch_mean; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx_m = 0; idx_m < threadblock_num; idx_m++) { -+ arch::global_load(fetch_square, access_square, true); -+ arch::global_load(fetch_mean, access_mean, true); -+ frag_element_sum += convert_mean(fetch_mean); -+ frag_square_sum += convert_variance(fetch_square); -+ access_square += params.args.extent.row(); -+ access_mean += params.args.extent.row(); -+ } -+ -+ ElementLayernormCompute mean = frag_element_sum; -+ ElementLayernormCompute square_mean = frag_square_sum; -+ -+ ElementLayernormCompute variance; -+ -+ if (kIsShiftedVariance && params.args.ptr_Shifted_K != nullptr) { -+ ElementOutput *access_shift_k = params.args.ptr_Shifted_K + idx_n; -+ ElementOutput fetch_shift_k; -+ ConvertShiftK convert_shift_k; -+ arch::global_load(fetch_shift_k, access_shift_k, true); -+ ElementLayernormCompute shifted_mean = mean - convert_shift_k(fetch_shift_k); -+ variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean - shifted_mean * shifted_mean + ElementLayernormCompute(1e-6)); -+ }else{ -+ variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean - mean * mean + ElementLayernormCompute(1e-6)); -+ } -+ -+ mean = -mean * variance; -+ -+ access_square = access_square_bak; -+ access_mean = access_mean_bak; -+ -+ access_square[0] = convert_variance_output(variance); -+ access_mean[0] = convert_mean_output(mean); -+ -+ } -+ -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ThreadblockShape_, -+ int ThreadCount, -+ typename OutputTileIterator_, -+ typename AccumulatorTile_, -+ typename ElementAccumulator_, -+ typename ElementVariance_, -+ typename ElementMean_, -+ typename ElementLayernormCompute_, -+ typename ElementwiseFunctor_, -+ bool IsShiftedVariance_ = false -+> -+class EpilogueVisitorLayerNorm { -+public: -+ -+ using ElementVariance = ElementVariance_; -+ using ElementMean = ElementMean_; -+ using ElementLayernormCompute = ElementLayernormCompute_; -+ -+ using AccumulatorTile = AccumulatorTile_; -+ -+ using ThreadblockShape = ThreadblockShape_; -+ static int const kThreadCount = ThreadCount; -+ -+ using OutputTileIterator = OutputTileIterator_; -+ using ElementwiseFunctor = ElementwiseFunctor_; -+ -+ static int const kIterations = OutputTileIterator::kIterations; -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ static int const kRowIterations = OutputTileIterator::ThreadMap::Iterations::kRow; -+ -+ static int const kThreads = OutputTileIterator::ThreadMap::kThreads; -+ -+ 
static bool const kIsShiftedVariance = IsShiftedVariance_; -+ -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ static int const kDeltaRow = OutputTileIterator::ThreadMap::Delta::kRow; -+ -+ /// Array type used in Shift-K Layernorm -+ static int const kRowAccessCount = kIterations * kRowIterations; -+ -+ using ConvertedShiftFragment = Array; -+ -+ // Conducts manual transpose externally (already supported) for column major -+ using LayoutOutput = cutlass::layout::RowMajor; -+ -+ using ElementAccumulator = ElementAccumulator_; -+ -+ using AccumulatorFragment = Array; -+ using LayernormFragment = Array; -+ using OutputVector = Array; -+ using TensorRefD = TensorRef; -+ -+ static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::RowArrangement::Detail::kShapeWidth; -+ static int const kThreadsInColumn = kThreads / kThreadsPerRow; -+ static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1); -+ -+ /// Argument structure -+ struct Arguments { -+ -+ typename ElementwiseFunctor::Params elementwise; -+ TensorRefD ref_C; -+ TensorRefD ref_D; -+ ElementVariance *ptr_Variance; -+ ElementMean *ptr_Mean; -+ ElementOutput *ptr_Shifted_K; -+ -+ // -+ // Methods -+ // -+ Arguments(): -+ ptr_Variance(nullptr), -+ ptr_Mean(nullptr), -+ ptr_Shifted_K(nullptr) -+ { -+ -+ } -+ -+ Arguments( -+ typename ElementwiseFunctor::Params elementwise_, -+ TensorRefD ref_C_, -+ TensorRefD ref_D_, -+ ElementVariance *ptr_Variance, -+ ElementMean *ptr_Mean_, -+ ElementOutput *ptr_Shifted_K_ = nullptr -+ ): -+ elementwise(elementwise_), -+ ref_C(ref_C_), -+ ref_D(ref_D_), -+ ptr_Variance(ptr_Variance), -+ ptr_Mean(ptr_Mean_), -+ ptr_Shifted_K(ptr_Shifted_K_) -+ { -+ -+ } -+ }; -+ -+ struct Params { -+ -+ typename ElementwiseFunctor::Params elementwise; -+ typename OutputTileIterator::Params params_C; -+ typename OutputTileIterator::Params params_D; -+ typename OutputTileIterator::Element *ptr_C; -+ typename OutputTileIterator::Element *ptr_D; -+ ElementVariance *ptr_Variance; -+ ElementMean *ptr_Mean; -+ ElementOutput *ptr_Shifted_K; -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params(): -+ ptr_D(nullptr), -+ ptr_Variance(nullptr), -+ ptr_Mean(nullptr) -+ { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ elementwise(args.elementwise), -+ params_C(args.ref_C.layout()), -+ params_D(args.ref_D.layout()), -+ ptr_C(args.ref_C.data()), -+ ptr_D(args.ref_D.data()), -+ ptr_Variance(args.ptr_Variance), -+ ptr_Mean(args.ptr_Mean), -+ ptr_Shifted_K(args.ptr_Shifted_K) -+ { -+ -+ } -+ }; -+ -+ /// Shared storage -+ struct SharedStorage { -+ -+ }; -+ -+private: -+ -+ Params const & params_; -+ SharedStorage & shared_storage_; -+ MatrixCoord extent_; -+ ElementwiseFunctor elementwise_; -+ -+ OutputTileIterator iterator_C_; -+ OutputTileIterator iterator_D_; -+ typename OutputTileIterator::Fragment fragment_C_; -+ typename OutputTileIterator::Fragment fragment_D_; -+ -+ ElementAccumulator alpha_; -+ ElementAccumulator beta_; -+ ConvertedShiftFragment shift_k_frag_; -+ -+ ElementLayernormCompute accum_sum_square_; -+ ElementLayernormCompute accum_sum_element_; -+ -+ MatrixCoord thread_offset_; -+ -+public: -+ -+ CUTLASS_DEVICE -+ EpilogueVisitorLayerNorm( -+ Params const ¶ms, ///< Parameters routed to the epilogue -+ SharedStorage &shared_storage, ///< Shared storage needed by the functors here -+ MatrixCoord const &problem_size0, ///< Problem size of the output -+ int thread_idx, ///< Thread index within the threadblock -+ int warp_idx, ///< Warp index within the 
threadblock -+ int lane_idx, ///< Lane index within the warp -+ MatrixCoord const &threadblock_offset = MatrixCoord(0, 0) -+ ): -+ params_(params), -+ shared_storage_(shared_storage), -+ extent_(problem_size0), -+ elementwise_(params.elementwise), -+ iterator_C_(params.params_C, params.ptr_C, problem_size0, thread_idx, threadblock_offset), -+ iterator_D_(params.params_D, params.ptr_D, problem_size0, thread_idx, threadblock_offset) -+ { -+ alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha); -+ beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta); -+ -+ if (beta_ == ElementAccumulator()) { -+ iterator_C_.clear_mask(); -+ } -+ } -+ -+ /// Helper to indicate split-K behavior -+ CUTLASS_DEVICE -+ void set_k_partition( -+ int split_k_index, ///< Index of this threadblock within split-K partitioned scheme -+ int split_k_slices) { ///< Total number of split-K slices -+ -+ } -+ -+ /// Called to set the batch index -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ -+ } -+ -+ /// Called at the start of the epilogue just before iterating over accumulator slices -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ -+ // If shift-K feature is enabled, we load shift-k fragment -+ // at the very beginning of an epilogue -+ if (kIsShiftedVariance && params_.ptr_Shifted_K != nullptr) { -+ shift_k_frag_.clear(); -+ int thread_offset_row_base = iterator_D_.thread_start_row(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int iter_idx = 0; iter_idx < kIterations; ++iter_idx) { -+ int step_offset = iter_idx * OutputTileIterator::Shape::kRow; -+ CUTLASS_PRAGMA_UNROLL -+ for (int rid = 0; rid < kRowIterations; ++rid) { -+ int row_step_offset = rid * kDeltaRow; -+ int row_offset = thread_offset_row_base + step_offset + row_step_offset; -+ bool is_load = (row_offset < extent_.row()); -+ shift_k_frag_[iter_idx * kRowIterations + rid] = load_shift_k_(row_offset, is_load); -+ } -+ -+ } -+ -+ } -+ -+ } -+ -+ /// Called at the start of one step before starting accumulator exchange -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ fragment_D_.clear(); -+ -+ if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { -+ fragment_C_.clear(); -+ iterator_C_.load(fragment_C_); -+ ++iterator_C_; -+ } -+ } -+ -+ /// Called at the start of a row -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { -+ -+ } -+ -+ /// Called after accumulators have been exchanged for each accumulator vector -+ CUTLASS_DEVICE -+ void visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorFragment const &accum) { -+ -+ using Mul = cutlass::multiplies; -+ using Minus = cutlass::minus; -+ using Exp = cutlass::fast_exp_op; -+ -+ [[maybe_unused]] Minus minus; -+ [[maybe_unused]] Mul mul; -+ [[maybe_unused]] Exp exponential; -+ -+ LayernormFragment result; -+ -+ thread_offset_ = -+ iterator_D_.thread_start() + -+ OutputTileIterator::ThreadMap::iteration_offset(frag_idx); -+ -+ NumericArrayConverter source_converter; -+ OutputVector &source_vector = reinterpret_cast(&fragment_C_)[frag_idx]; -+ -+ bool column_guard = (thread_offset_.column() < extent_.column()); -+ -+ if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { -+ result = source_converter(elementwise_(accum)); -+ }else{ -+ result = source_converter(elementwise_(accum, source_vector)); -+ } -+ -+ -+ ElementLayernormCompute inv_scalar = cutlass::constants::one() / ElementLayernormCompute(extent_.column()); -+ -+ // 
Fragment is cleared for non-reachable columns so no need to check against column guard -+ accum_sum_element_ = element_sum_accumulator_(result); -+ -+ // Square sum is different. Non-reachable columns should've been computed for shift-k -+ // Otherwise we will incorrectly have some extra k^2 added into square sum. -+ if (column_guard) { -+ accum_sum_square_ = (kIsShiftedVariance) ? \ -+ square_sum_accumulator_(result, shift_k_frag_[iter_idx * kRowIterations + row_idx]) : \ -+ square_sum_accumulator_(result); -+ } -+ else { -+ accum_sum_square_ = ElementLayernormCompute(0); -+ } -+ -+ accum_sum_element_ *= inv_scalar; -+ accum_sum_square_ *= inv_scalar; -+ -+ // After performing the in-thread reduction, we then perform cross-thread / in-warp reduction -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = kHalfThreadsPerRow; i > 0; i >>= 1) { -+ accum_sum_element_ += __shfl_xor_sync(0xFFFFFFFF, accum_sum_element_, i); -+ accum_sum_square_ += __shfl_xor_sync(0xFFFFFFFF, accum_sum_square_, i); -+ } -+ -+ // Convert to the output -+ NumericArrayConverter output_converter; -+ OutputVector &output = reinterpret_cast(&fragment_D_)[frag_idx]; -+ output = output_converter(result); -+ } -+ -+ /// Called at the start of a row -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { -+ -+ using ConvertVarianceOutput = cutlass::NumericConverter; -+ using ConvertMeanOutput = cutlass::NumericConverter; -+ -+ ConvertVarianceOutput convert_variance_output; -+ ConvertMeanOutput convert_mean_output; -+ -+ bool is_write_thread = (thread_offset_.row() < extent_.row() && (threadIdx.x % kThreadsPerRow) == 0); -+ int row_offset = thread_offset_.row() + blockIdx.y * extent_.row(); -+ -+ ElementVariance *curr_ptr_sum_square = params_.ptr_Variance + row_offset; -+ ElementMean *curr_ptr_element_sum = params_.ptr_Mean + row_offset; -+ -+ arch::global_store( -+ convert_variance_output(accum_sum_square_), -+ (void *)curr_ptr_sum_square, -+ is_write_thread); -+ -+ arch::global_store( -+ convert_mean_output(accum_sum_element_), -+ (void *)curr_ptr_element_sum, -+ is_write_thread); -+ -+ } -+ -+ /// Called after all accumulator elements have been visited -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ -+ iterator_D_.store(fragment_D_); -+ ++iterator_D_; -+ } -+ -+ /// Called after all steps have been completed -+ CUTLASS_DEVICE -+ void end_epilogue() { -+ -+ } -+ -+private: -+ -+ CUTLASS_DEVICE -+ ElementLayernormCompute load_shift_k_(int row_offset, bool is_load) { -+ using ConvertShiftK = cutlass::NumericConverter; -+ ConvertShiftK convert_shift_k; -+ ElementOutput shift_k_val; -+ -+ // Computes the address to load shift_k element -+ ElementOutput *curr_ptr_shift_k = params_.ptr_Shifted_K + row_offset; -+ // Conditionally loads from global memory -+ arch::global_load(shift_k_val, (void *)curr_ptr_shift_k, is_load); -+ // Converts data type to return -+ ElementLayernormCompute converted_shift_k_val = convert_shift_k(shift_k_val); -+ -+ return converted_shift_k_val; -+ } -+ -+ CUTLASS_DEVICE -+ ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum) { -+ ElementLayernormCompute sum_ = ElementLayernormCompute(0); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < LayernormFragment::kElements; ++i) { -+ auto accum_ = accum[i]; -+ sum_ += accum_ * accum_; -+ } -+ -+ return sum_; -+ } -+ -+ CUTLASS_DEVICE -+ ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum, ElementLayernormCompute shift_k_val) { -+ ElementLayernormCompute sum_ = ElementLayernormCompute(0); -+ -+ CUTLASS_PRAGMA_UNROLL -+ 
for (int i = 0; i < LayernormFragment::kElements; ++i) { -+ auto accum_ = accum[i] - shift_k_val; -+ sum_ += accum_ * accum_; -+ } -+ -+ return sum_; -+ } -+ -+ CUTLASS_DEVICE -+ ElementLayernormCompute element_sum_accumulator_(LayernormFragment const &accum) { -+ ElementLayernormCompute sum_ = ElementLayernormCompute(0); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < LayernormFragment::kElements; ++i) { -+ sum_ += accum[i]; -+ } -+ -+ return sum_; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// -+template < -+ typename ElementInputA0_, -+ typename LayoutInputA0_, -+ typename ElementInputB0_, -+ typename LayoutInputB0_, -+ typename ElementOutput_, -+ typename LayoutOutput_, -+ typename ElementCompute_, -+ typename EpilogueFunctorOp_, -+ typename ThreadblockShape_, -+ typename WarpShape_, -+ typename InstructionShape_, -+ int Stages0, -+ int Stages1, -+ bool IsShiftedVariance_ = false -+> -+class GemmLayernorm { -+public: -+ -+ /////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ // -+ // Type definitions -+ // -+ -+ static bool const kInternalTranspose = cutlass::platform::is_same::value; -+ static bool const kIsShiftedVariance = IsShiftedVariance_; -+ -+ // These is mandatory layout. -+ using LayoutInputScaleBias = cutlass::layout::RowMajor; -+ -+ // These are mandatory data types. -+ using ElementLayernormCompute = float; -+ using ElementInputScaleBias = cutlass::half_t; -+ -+ // These are mandatory params required by mainloop fusion -+ using OperatorClass = cutlass::arch::OpClassTensorOp; -+ using ArchTag = cutlass::arch::Sm80; -+ -+ // These are mandatory layouts and data types -+ // that are inheritated from pre-defined params -+ -+ using LayoutSumSqr = LayoutInputScaleBias; -+ using LayoutSum = LayoutInputScaleBias; -+ -+ using ElementMean = ElementInputScaleBias; -+ using ElementVariance = ElementInputScaleBias; -+ -+ /////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ using LayoutInputA0 = LayoutInputA0_; -+ using LayoutInputB0 = LayoutInputB0_; -+ using LayoutInputA1 = LayoutOutput_; -+ using LayoutInputB1 = LayoutOutput_; -+ using LayoutOutputC0 = LayoutOutput_; -+ using LayoutOutputC1 = LayoutOutput_; -+ -+ using ElementInputA0 = ElementInputA0_; -+ using ElementInputB0 = ElementInputB0_; -+ using ElementOutputC0 = ElementOutput_; -+ using ElementCompute = ElementCompute_; -+ using ElementInputB1 = ElementInputB0_; -+ -+ using ElementInputA1 = ElementOutputC0; -+ using ElementOutputC1 = ElementOutputC0; -+ -+ using EpilogueFunctorOp = EpilogueFunctorOp_; -+ -+ using TensorRefA = TensorRef; -+ using TensorRefB = TensorRef; -+ using TensorRefC = TensorRef; -+ using TensorVariance = TensorRef; -+ using TensorMean = TensorRef; -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ -+ static int const kStages0 = Stages0; -+ static int const kStages1 = Stages1; -+ -+ using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+ /////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ using MapArguments = cutlass::gemm::kernel::detail::MapArguments< -+ ElementInputA0, -+ LayoutInputA0, -+ cutlass::ComplexTransform::kNone, -+ 128 
/ cutlass::sizeof_bits::value, -+ ElementInputB0, -+ LayoutInputB0, -+ cutlass::ComplexTransform::kNone, -+ 128 / cutlass::sizeof_bits::value, -+ LayoutOutputC0, -+ kInternalTranspose -+ >; -+ -+ using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm< -+ typename MapArguments::ElementA, -+ typename MapArguments::LayoutA, -+ MapArguments::kAlignmentA, -+ typename MapArguments::ElementB, -+ typename MapArguments::LayoutB, -+ MapArguments::kAlignmentB, -+ ElementOutputC0, -+ typename MapArguments::LayoutC, -+ ElementCompute, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueFunctorOp, -+ SwizzleThreadBlock, -+ kStages0, -+ true, -+ typename cutlass::gemm::device::DefaultGemmConfiguration< -+ OperatorClass, ArchTag, ElementInputA0, ElementInputB0, ElementOutputC0, ElementCompute>::Operator, -+ cutlass::gemm::SharedMemoryClearOption::kNone -+ >::GemmKernel; -+ -+ /////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ // Epilogue visitor -+ using EpilogueVisitor = kernel::EpilogueVisitorLayerNorm< -+ ThreadblockShape, -+ DefaultGemmKernel::kThreadCount, -+ typename DefaultGemmKernel::Epilogue::OutputTileIterator, -+ typename DefaultGemmKernel::Epilogue::AccumulatorFragmentIterator::AccumulatorTile, -+ ElementCompute, -+ ElementVariance, -+ ElementMean, -+ ElementLayernormCompute, -+ EpilogueFunctorOp, -+ kIsShiftedVariance -+ >; -+ -+ /// Epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue< -+ EpilogueVisitor, -+ typename DefaultGemmKernel::Epilogue -+ >::Epilogue; -+ -+ // GEMM -+ using GemmEpilogueFusion = gemm::kernel::GemmWithEpilogueVisitor< -+ typename DefaultGemmKernel::Mma, -+ Epilogue, -+ SwizzleThreadBlock -+ >; -+ -+ using ApplyFinalReductionKernel = kernel::ApplyFinalReduction< -+ ElementVariance, -+ ElementMean, -+ ElementLayernormCompute, -+ ElementOutputC0, -+ ThreadblockShape, -+ kIsShiftedVariance -+ >; -+ -+using GemmMainloopFusion = typename cutlass::gemm::device::GemmLayernormMainloopFusion< -+ ElementInputA1, LayoutInputA1, -+ ElementInputB1, LayoutInputB1, -+ ElementInputScaleBias, LayoutInputScaleBias, -+ ElementOutputC1, LayoutOutputC1, -+ ElementCompute, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueFunctorOp, -+ SwizzleThreadBlock, -+ kStages1 -+>; -+ -+public: -+ -+ /// Arguments class -+ struct Arguments { -+ -+ typename GemmEpilogueFusion::Arguments gemm0; -+ typename GemmMainloopFusion::Arguments gemm1; -+ typename ApplyFinalReductionKernel::Arguments reduction; -+ cutlass::gemm::GemmCoord extend; -+ -+ // -+ // Methods -+ // -+ Arguments() { } -+ -+ Arguments( -+ cutlass::gemm::GemmCoord problem_size0, -+ cutlass::gemm::GemmCoord problem_size1, -+ ElementInputA0 * ptr_A, -+ ElementInputB0 * ptr_B, -+ ElementOutputC0 * ptr_C, -+ ElementOutputC0 * ptr_D, -+ ElementOutputC0 * ptr_E, -+ ElementOutputC0 * ptr_O, -+ int64_t ldm_A, -+ int64_t ldm_B, -+ int64_t ldm_C, -+ int64_t ldm_D, -+ int64_t ldm_E, -+ int64_t ldm_O, -+ typename EpilogueFunctorOp::Params linear_scaling, -+ TensorVariance ref_Variance_, -+ TensorMean ref_Mean_, -+ TensorVariance ref_Gamma_, -+ TensorMean ref_Beta_, -+ ElementOutputC0 *ptr_Shifted_K = nullptr -+ ): -+ gemm0( -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ {kInternalTranspose ? problem_size0.n() : problem_size0.m(),\ -+ kInternalTranspose ? 
problem_size0.m() : problem_size0.n(),\ -+ problem_size0.k()}, -+ {kInternalTranspose ? ptr_B : ptr_A, \ -+ kInternalTranspose ? ldm_B : ldm_A}, -+ {kInternalTranspose ? ptr_A : ptr_B, \ -+ kInternalTranspose ? ldm_A : ldm_B}, -+ typename EpilogueVisitor::Arguments( -+ linear_scaling, -+ {ptr_C, ldm_C}, -+ {ptr_D, ldm_D}, -+ ref_Variance_.data(), -+ ref_Mean_.data(), -+ ptr_Shifted_K -+ ) -+ ), -+ reduction( -+ MatrixCoord(kInternalTranspose ? problem_size0.n() : problem_size0.m(),\ -+ kInternalTranspose ? problem_size0.m() : problem_size0.n()), -+ ref_Variance_, -+ ref_Mean_, -+ ptr_Shifted_K -+ ), -+ gemm1( -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem_size1, -+ 1, -+ linear_scaling, -+ kInternalTranspose ? ptr_E : ptr_D, -+ kInternalTranspose ? ptr_D : ptr_E, -+ ref_Variance_.data(), -+ ref_Mean_.data(), -+ ref_Gamma_.data(), -+ ref_Beta_.data(), -+ ptr_O, -+ ptr_O, -+ problem_size1.m() * problem_size1.k(), -+ problem_size1.n() * problem_size1.k(), -+ problem_size1.n(), -+ problem_size1.n(), -+ problem_size1.k(), -+ problem_size1.k(), -+ problem_size1.m() * problem_size1.n(), -+ problem_size1.m() * problem_size1.n(), -+ kInternalTranspose ? ldm_E : ldm_D, -+ kInternalTranspose ? ldm_D : ldm_D, -+ ref_Variance_.layout().stride(0), -+ ref_Mean_.layout().stride(0), -+ ref_Gamma_.layout().stride(0), -+ ref_Beta_.layout().stride(0), -+ ldm_O, -+ ldm_O -+ ), -+ extend(problem_size0) -+ { -+ -+ } -+ }; -+ -+ struct Params { -+ -+ typename GemmEpilogueFusion::Params gemm0; -+ typename ApplyFinalReductionKernel::Params reduction; -+ MatrixCoord extend; -+ // -+ // Methods -+ // -+ Params() { } -+ -+ Params(Arguments const &args): -+ gemm0(args.gemm0), -+ reduction(args.reduction), -+ extend(MatrixCoord(args.extend.m(), args.extend.n())) -+ { -+ -+ } -+ }; -+ -+public: -+ -+ // Gemm -+ -+ -+ // -+ // Methods -+ // -+ -+private: -+ -+ Params params_; -+ GemmMainloopFusion gemm_fusion_op; -+ -+public: -+ -+ /// Ctor -+ GemmLayernorm() { -+ -+ } -+ -+ /// Initialize -+ Status initialize(Arguments const &args) { -+ -+ params_ = Params(args); -+ cutlass::Status status; -+ size_t workspace_size = gemm_fusion_op.get_workspace_size(args.gemm1); -+ cutlass::device_memory::allocation workspace(workspace_size); -+ status = gemm_fusion_op.can_implement(args.gemm1); -+ CUTLASS_CHECK(status); -+ -+ status = gemm_fusion_op.initialize(args.gemm1, workspace.get()); -+ CUTLASS_CHECK(status); -+ -+ return cutlass::Status::kSuccess; -+ } -+ -+ /// Run -+ Status run(cudaStream_t stream) { -+ -+ // -+ // Launch the GEMM + layernorm kernel -+ // -+ -+ dim3 gemm_grid = SwizzleThreadBlock().get_grid_shape(params_.gemm0.grid_tiled_shape); -+ dim3 gemm_block(GemmEpilogueFusion::kThreadCount, 1, 1); -+ -+ int gemm_smem_size = int(sizeof(typename GemmEpilogueFusion::SharedStorage)); -+ -+ cutlass::Kernel<<>>(params_.gemm0); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ if (result != cudaSuccess) { -+ return cutlass::Status::kErrorInternal; -+ } -+ -+ // -+ // Launch the ApplyFinalReductionKernel -+ // -+ -+ // always performs reduction from leading dimension -+ int leading_dim_0 = kInternalTranspose ? params_.extend.row() : params_.extend.column(); -+ int leading_dim_1 = kInternalTranspose ? 
params_.extend.column() : params_.extend.row(); -+ -+ int thread_per_block = 128; -+ int block_per_row = (leading_dim_1 + thread_per_block - 1) / thread_per_block; -+ if (block_per_row < 4) { -+ thread_per_block = 32; -+ block_per_row = (leading_dim_1 + thread_per_block - 1) / thread_per_block; -+ } -+ -+ dim3 final_reduction_block(thread_per_block); -+ dim3 final_reduction_grid(block_per_row); -+ -+ Kernel<<< -+ final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionKernel::SharedStorage), stream -+ >>>(params_.reduction); -+ -+ result = cudaGetLastError(); -+ -+ if (result != cudaSuccess) { -+ return cutlass::Status::kErrorInternal; -+ } -+ -+ // -+ // Launch the GEMM + mainloop fusion kernel -+ // -+ -+ cutlass::Status status = gemm_fusion_op(); -+ CUTLASS_CHECK(status); -+ -+ return cutlass::Status::kSuccess; -+ } -+ -+ /// Function call operator -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/38_syr2k_grouped/syr2k_grouped.cu b/3rdparty/cutlass/examples/38_syr2k_grouped/syr2k_grouped.cu -new file mode 100644 -index 0000000..d8adb9c ---- /dev/null -+++ b/3rdparty/cutlass/examples/38_syr2k_grouped/syr2k_grouped.cu -@@ -0,0 +1,1466 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief SYR2K Grouped Example. -+ -+ This workload computes a batch of SYR2K operations with distinct problem sizes. 
This example closely -+ follows 24_gemm_grouped. -+ -+ Examples: -+ -+ # Runs a grouped SYR2K with 100 random problem sizes -+ $ ./examples/38_syr2k_grouped/38_syr2k_grouped --groups=100 -+ -+ # Runs a grouped SYR2K with 100 random problem sizes (with SYR2K-K dimension equal to 1024) -+ $ ./examples/38_syr2k_grouped/24_gemm_grouped --groups=100 --k=1024 --verbose=true -+ -+ # Runs a grouped SYR2K that is equivalent to a batched SYR2K -+ $ ./examples/38_syr2k_grouped/38_syr2k_grouped --groups=100 --n=1024 --k=1024 --verbose=true -+ -+ # Execute grouped SYR2K and profile with NSight -+ $ nv-nsight-cu-cli ./examples/38_syr2k_grouped/38_syr2k_grouped --n=256 --k=256 --verbose=true \ -+ --iterations=1 --reference-check=false -+ -+*/ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+#include -+#include -+#include -+#include -+#include -+ -+#include "cutlass/blas3.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/device_kernel.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/device_memory.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double initialization_time_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ // -+ // Methods -+ // -+ -+ Result( -+ double runtime_ms = 0, -+ double initialization_time_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess -+ ): -+ runtime_ms(runtime_ms), initialization_time_ms(initialization_time_ms), gflops(gflops), -+ status(status), error(error), passed(true) { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ bool error; -+ bool reference_check; -+ bool profile_initialization; -+ bool sort_problems; -+ -+ std::vector problem_sizes; -+ -+ int alignment; -+ int problem_count; -+ int iterations; -+ int cuda_streams; -+ bool verbose; -+ float alpha; -+ float beta; -+ std::string benchmark_path; -+ -+ std::string output_tag; -+ std::ofstream output_file; -+ -+ using GroupScheduleMode = cutlass::gemm::kernel::GroupScheduleMode; -+ std::vector scheduler_modes; -+ -+ std::unordered_map -+ str_to_scheduler_mode = { -+ {"kDeviceOnly", GroupScheduleMode::kDeviceOnly}, -+ {"kHostPrecompute", GroupScheduleMode::kHostPrecompute} -+ }; -+ -+ struct GroupScheduleModeHash { -+ size_t operator()(GroupScheduleMode m) const { -+ return static_cast(m); -+ } -+ }; -+ -+ std::unordered_map -+ scheduler_mode_to_str = { -+ {GroupScheduleMode::kDeviceOnly, "kDeviceOnly"}, -+ 
{GroupScheduleMode::kHostPrecompute, "kHostPrecompute"} -+ }; -+ -+ std::vector all_scheduler_modes = {GroupScheduleMode::kDeviceOnly, GroupScheduleMode::kHostPrecompute}; -+ -+ // -+ // Methods -+ // -+ -+ Options(): -+ help(false), -+ error(false), -+ alignment(8), -+ reference_check(true), -+ profile_initialization(false), -+ sort_problems(false), -+ problem_count(5), -+ iterations(20), -+ cuda_streams(0), -+ verbose(false), -+ alpha(1), -+ beta(), -+ scheduler_modes({GroupScheduleMode::kDeviceOnly}) -+ { } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ return; -+ } -+ -+ cmd.get_cmd_line_argument("alignment", alignment, 8); -+ cmd.get_cmd_line_argument("groups", problem_count, 5); -+ cmd.get_cmd_line_argument("alpha", alpha, 1.0f); -+ cmd.get_cmd_line_argument("beta", beta, 0.0f); -+ cmd.get_cmd_line_argument("iterations", iterations, 20); -+ cmd.get_cmd_line_argument("streams", cuda_streams, 0); -+ cmd.get_cmd_line_argument("verbose", verbose, false); -+ cmd.get_cmd_line_argument("reference-check", reference_check, true); -+ cmd.get_cmd_line_argument("profile-initialization", profile_initialization, false); -+ cmd.get_cmd_line_argument("sort-problems", sort_problems, false); -+ cmd.get_cmd_line_argument("benchmark", benchmark_path); -+ -+ std::vector scheduler_mode_strs; -+ cmd.get_cmd_line_arguments("scheduler-modes", scheduler_mode_strs); -+ -+ if (!scheduler_mode_strs.empty()) { -+ scheduler_modes.clear(); -+ if (scheduler_mode_strs.size() == 1 && scheduler_mode_strs[0] == "all") { -+ scheduler_modes = all_scheduler_modes; -+ } else { -+ for (std::string precomp_str : scheduler_mode_strs) { -+ auto it = str_to_scheduler_mode.find(precomp_str); -+ if (it != str_to_scheduler_mode.end()) { -+ scheduler_modes.push_back(it->second); -+ } else if (precomp_str == "all") { -+ std::cerr << "Flag --scheduler-modes=all must not contain other scheduler modes in list." << std::endl; -+ error = true; -+ return; -+ } else { -+ std::cerr << "Unrecognized scheduler mode '" << precomp_str << "'" << std::endl; -+ error = true; -+ return; -+ } -+ } -+ } -+ } -+ -+ std::string output_path; -+ cmd.get_cmd_line_argument("tag", output_tag); -+ cmd.get_cmd_line_argument("output_file", output_path); -+ -+ if (!output_path.empty()) { -+ -+ std::ios_base::openmode open_mode = std::ios_base::out; -+ -+ std::ifstream input_file(output_path.c_str()); -+ -+ if (input_file.good()) { -+ open_mode = std::ios_base::app; -+ input_file.close(); -+ } -+ -+ output_file.open(output_path.c_str(), open_mode); -+ -+ if (output_file.good() && open_mode != std::ios_base::app) { -+ output_file << "Tag,Provider,Kind,Groups,Runtime,GFLOPs\n"; -+ } -+ } -+ -+ // Decide how to initialize the problems -+ if (!benchmark_path.empty()) { -+ if (!benchmark_problems()) { -+ error = true; -+ problem_sizes.clear(); -+ return; -+ } -+ } -+ else { -+ randomize_problems(cmd); -+ } -+ } -+ -+ void randomize_problems(cutlass::CommandLine &cmd) { -+ -+ // -+ // For now, randomly choose the problem sizes. -+ // -+ -+ int cmd_line_m = -1; -+ int cmd_line_n = -1; -+ int cmd_line_k = -1; -+ -+ cmd.get_cmd_line_argument("m", cmd_line_m); -+ cmd.get_cmd_line_argument("n", cmd_line_n); -+ cmd.get_cmd_line_argument("k", cmd_line_k); -+ -+ // SYR2K is defined via only N and K. 
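// Background note on the restriction enforced just below: a rank-2k (SYR2K) update computes
// C = alpha * (A * B^T + B * A^T) + beta * C, where C is an N-by-N symmetric matrix and
// A, B are N-by-K operands. The output shape is therefore fixed by N alone, which is why
// the parser rejects an explicit --m and replicates N into both the M and N slots of each
// randomly generated GemmCoord.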
-+ if (cmd_line_m != -1) { -+ std::cerr << "Parameter M is ignored for SYR2K\n"; -+ error = true; -+ return; -+ } -+ -+ problem_sizes.reserve(problem_count); -+ -+ for (int i = 0; i < problem_count; ++i) { -+ int n = cmd_line_n; -+ int k = cmd_line_k; -+ -+ if (n < 1) { -+ n = alignment * ((rand() % 256) + 1); -+ } -+ -+ if (k < 1) { -+ k = alignment * ((rand() % 256) + 1); -+ } -+ -+ // SYR2K is defined only in terms of N and K. Replicate N into -+ // the SYR2K-N dimension. -+ cutlass::gemm::GemmCoord problem(n, n, k); -+ -+ problem_sizes.push_back(problem); -+ } -+ } -+ -+ /// Load a benchmark -+ bool benchmark_problems() { -+ std::ifstream file(benchmark_path); -+ if (!file.good()) { -+ return false; -+ } -+ -+ while (file.good()) { -+ -+ int idx = -1; -+ std::string extent_str; -+ -+ file >> idx >> extent_str; -+ -+ if (idx < 0 || extent_str.empty()) { -+ break; -+ } -+ -+ cutlass::gemm::GemmCoord extent; -+ std::vector tokens; -+ -+ cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); -+ -+ for (int i = 0; i < int(tokens.size()); ++i) { -+ int x = std::atoi(tokens.at(i).c_str()); -+ -+ // round up -+ if (x % alignment) { -+ x += (alignment - (x % alignment)); -+ } -+ -+ extent.at(i) = x; -+ } -+ -+ if (extent.product()) { -+ problem_sizes.push_back(extent); -+ } -+ } -+ -+ problem_count = int(problem_sizes.size()); -+ return true; -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "38_syr2k_grouped\n\n" -+ << " This example profiles the performance of a 'grouped' SYR2K kernel. This example closely follows 24_gemm_grouped\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --benchmark= Executes a benchmark problem size.\n" -+ << " --output_file= Path to a CSV file to output results. If it exists already, results are appended.\n" -+ << " --tag= String tag to prepend to the CSV file.\n" -+ << " --groups= Number of individual SYR2K problems (default: --groups=15)\n" -+ << " --m= Sets the M dimension for all groups. Otherwise, it is selected randomly\n" -+ << " --n= Sets the N dimension for all groups. Otherwise, it is selected randomly\n" -+ << " --k= Sets the K dimension for all groups. 
Otherwise, it is selected randomly\n" -+ << " --alpha= Epilogue scalar alpha (real part)\n" -+ << " --beta= Epilogue scalar beta (real part)\n" -+ << " --scheduler-modes= List of scheduler modes to be profile for grouped GEMM scheduler (default: --scheduler_modes=kDeviceOnly)\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --reference-check= If true, performs reference check.\n" -+ << " --verbose= If true, prints problem sizes and batching structure.\n" -+ << " --profile-initialization= If true, profiles the device-level kernel's initialization.\n" -+ << " --sort-problems= If true, sorts problem sizes in descending order of SYR2K-K dimension.\n"; -+ -+ out << "\n\nExamples:\n\n" -+ -+ << "# Runs a grouped SYR2K with 100 random problem sizes\n" -+ << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --groups=100\n\n" -+ -+ << "# Runs a grouped SYR2K with 100 random problem sizes (with K dimension equal to 1024)\n" -+ << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --groups=100 --k=1024 --verbose=true\n\n" -+ -+ << "# Runs a grouped SYR2K that is equivalent to a batched SYR2K\n" -+ << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --groups=100 --n=1024 --k=1024 --verbose=true\n\n" -+ -+ << "# Runs a grouped SYR2K with each different scheduler mode\n" -+ << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --scheduler-modes=all\n\n" -+ -+ << "# Runs a grouped SYR2K with each different scheduler mode and profiles host-side initialization time\n" -+ << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --scheduler-modes=all --profile-initialization=true\n\n" -+ -+ << "# Runs a grouped SYR2K problem given an externally supplied benchmark file. This is a text file in which\n" -+ << "# Each line contains a unique group index and an MxNxK triple indicating problemsize. 
NOTE that the\n" -+ << "# GEMM-M and GEMM-N dimensions must match.\n" -+ << "#\n" -+ << "# For example, assume the following are the contents of 'problems.txt'\n" -+ << "#\n" -+ << "# 0 256x256x520\n" -+ << "# 1 264x264x1024\n" -+ << "# 2 48x48x1024\n" -+ << "#\n" -+ << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --benchmark=problems.txt\n\n" -+ -+ << "# Execute Grouped SYR2K and profile with NSight\n" -+ << "$ nv-nsight-cu-cli ./examples/24_gemm_grouped/24_gemm_grouped --n=256 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fmas = int64_t(); -+ -+ for (auto const & problem : problem_sizes) { -+ fmas += problem.product(); -+ } -+ -+ // SYR2K is defined as (A x BT) + (B x AT), so the number of FMAs is twice that in a GEMM -+ fmas *= 2; -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class BaseTestbed { -+public: -+ // -+ // Type definitions -+ // -+ -+ using ElementA = typename Rank2K::ElementA; -+ using ElementB = typename Rank2K::ElementB; -+ using ElementC = typename Rank2K::ElementC; -+ using ElementAccumulator = typename Rank2K::ElementAccumulator; -+ -+ using EpilogueOutputOp = typename Rank2K::Rank2Kkernel::Epilogue::OutputOp; -+ using ElementCompute = typename EpilogueOutputOp::ElementCompute; -+ -+ using LayoutA = typename Rank2K::LayoutA; -+ using LayoutB = typename Rank2K::LayoutB; -+ using LayoutC = typename Rank2K::LayoutC; -+ -+ using MatrixCoord = typename LayoutC::TensorCoord; -+ -+ // -+ // Data members -+ // -+ -+ Options & options; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint32_t seed; -+ -+ cutlass::DeviceAllocation problem_sizes_device; -+ -+ std::vector offset_A; -+ std::vector offset_B; -+ std::vector offset_C; -+ std::vector offset_D; -+ -+ std::vector lda_host; -+ std::vector ldb_host; -+ std::vector ldc_host; -+ std::vector ldd_host; -+ -+ cutlass::DeviceAllocation lda; -+ cutlass::DeviceAllocation ldb; -+ cutlass::DeviceAllocation ldc; -+ cutlass::DeviceAllocation ldd; -+ -+ cutlass::DeviceAllocation block_A; -+ cutlass::DeviceAllocation block_B; -+ cutlass::DeviceAllocation block_C; -+ cutlass::DeviceAllocation block_D; -+ -+ cutlass::DeviceAllocation ptr_A; -+ cutlass::DeviceAllocation ptr_B; -+ cutlass::DeviceAllocation ptr_C; -+ cutlass::DeviceAllocation ptr_D; -+ -+ BaseTestbed( -+ Options &options_, -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3080 -+ ): -+ options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ int problem_count() const { -+ return options.problem_count; -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor( -+ Element *ptr, -+ size_t capacity, -+ cutlass::Distribution::Kind dist_kind, -+ uint32_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ Element scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { 
-+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ if (cutlass::sizeof_bits::value <= 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } -+ else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ ptr, capacity, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::device::BlockFillRandomGaussian( -+ ptr, capacity, seed, Element(), Element(0.5f)); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ // Fill with increasing elements -+ cutlass::reference::device::BlockFillSequential( -+ ptr, capacity, Element(1), Element()); -+ } -+ else { -+ -+ // Fill with all 1s -+ cutlass::reference::device::BlockFillSequential( -+ ptr, capacity, Element(), Element(1)); -+ } -+ } -+ -+ /// Allocates device-side data -+ void allocate() { -+ int64_t total_elements_A = 0; -+ int64_t total_elements_B = 0; -+ int64_t total_elements_C = 0; -+ int64_t total_elements_D = 0; -+ -+ lda_host.resize(problem_count()); -+ ldb_host.resize(problem_count()); -+ ldc_host.resize(problem_count()); -+ ldd_host.resize(problem_count()); -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ -+ auto problem = options.problem_sizes.at(i); -+ -+ lda_host.at(i) = LayoutA::packed({problem.n(), problem.k()}).stride(0); -+ ldb_host.at(i) = LayoutB::packed({problem.n(), problem.k()}).stride(0); -+ ldc_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); -+ ldd_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); -+ -+ offset_A.push_back(total_elements_A); -+ offset_B.push_back(total_elements_B); -+ offset_C.push_back(total_elements_C); -+ offset_D.push_back(total_elements_D); -+ -+ int64_t elements_A = problem.n() * problem.k(); -+ int64_t elements_B = problem.n() * problem.k(); -+ int64_t elements_C = problem.n() * problem.n(); -+ int64_t elements_D = problem.n() * problem.n(); -+ -+ total_elements_A += elements_A; -+ total_elements_B += elements_B; -+ total_elements_C += elements_C; -+ total_elements_D += elements_D; -+ } -+ -+ lda.reset(problem_count()); -+ ldb.reset(problem_count()); -+ ldc.reset(problem_count()); -+ ldd.reset(problem_count()); -+ -+ block_A.reset(total_elements_A); -+ block_B.reset(total_elements_B); -+ block_C.reset(total_elements_C); -+ block_D.reset(total_elements_D); -+ } -+ -+ /// Initializes device-side data -+ void initialize() { -+ problem_sizes_device.reset(problem_count()); -+ problem_sizes_device.copy_from_host(options.problem_sizes.data()); -+ -+ lda.copy_from_host(lda_host.data()); -+ ldb.copy_from_host(ldb_host.data()); -+ ldc.copy_from_host(ldc_host.data()); -+ ldd.copy_from_host(ldd_host.data()); -+ -+ // -+ // Assign pointers -+ // -+ -+ std::vector ptr_A_host(problem_count()); -+ std::vector ptr_B_host(problem_count()); -+ std::vector ptr_C_host(problem_count()); -+ std::vector ptr_D_host(problem_count()); -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ ptr_A_host.at(i) = block_A.get() + offset_A.at(i); -+ ptr_B_host.at(i) = block_B.get() + offset_B.at(i); -+ ptr_C_host.at(i) = block_C.get() + offset_C.at(i); -+ ptr_D_host.at(i) = block_D.get() + offset_D.at(i); -+ } -+ -+ ptr_A.reset(problem_count()); -+ ptr_A.copy_from_host(ptr_A_host.data()); -+ -+ ptr_B.reset(problem_count()); -+ ptr_B.copy_from_host(ptr_B_host.data()); -+ -+ ptr_C.reset(problem_count()); -+ 
ptr_C.copy_from_host(ptr_C_host.data()); -+ -+ ptr_D.reset(problem_count()); -+ ptr_D.copy_from_host(ptr_D_host.data()); -+ -+ // -+ // Initialize the problems of the workspace -+ // -+ -+ initialize_tensor(block_A.get(), block_A.size(), init_A, seed * 2021); -+ initialize_tensor(block_B.get(), block_B.size(), init_B, seed * 2022); -+ initialize_tensor(block_C.get(), block_C.size(), init_C, seed * 2023); -+ -+ cutlass::reference::device::BlockFillSequential( -+ block_D.get(), block_D.size(), ElementC(), ElementC()); -+ } -+ -+ /// Verifies the result is a SYR2K -+ bool verify() { -+ -+ bool passed = true; -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ cutlass::gemm::GemmCoord problem = options.problem_sizes.at(i); -+ -+ LayoutA layout_A(lda_host.at(i)); -+ LayoutB layout_B(ldb_host.at(i)); -+ LayoutC layout_C(ldc_host.at(i)); -+ LayoutC layout_D(ldd_host.at(i)); -+ -+ cutlass::HostTensor host_A( -+ typename LayoutA::TensorCoord(problem.n(), problem.k()), /*device_backed=*/false); -+ cutlass::HostTensor host_B( -+ typename LayoutB::TensorCoord(problem.n(), problem.k()), /*device_backed=*/false); -+ cutlass::HostTensor host_C( -+ typename LayoutC::TensorCoord(problem.n(), problem.n()), /*device_backed=*/false); -+ cutlass::HostTensor host_D( -+ typename LayoutC::TensorCoord(problem.n(), problem.n()), /*device_backed=*/false); -+ -+ cutlass::device_memory::copy_to_host(host_A.host_data(), block_A.get() + offset_A.at(i), problem.n() * problem.k()); -+ cutlass::device_memory::copy_to_host(host_B.host_data(), block_B.get() + offset_B.at(i), problem.n() * problem.k()); -+ cutlass::device_memory::copy_to_host(host_C.host_data(), block_C.get() + offset_C.at(i), problem.n() * problem.n()); -+ cutlass::reference::host::BlockFillSequential( -+ host_D.host_data(), problem.n() * problem.n(), ElementC(), ElementC()); -+ -+ MatrixCoord extent_C{problem.n(), problem.n()}; -+ -+ // Reference Rank2K -+ cutlass::reference::host::Rank2KComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementC, ElementAccumulator -+ >( -+ problem, -+ (double)options.alpha, -+ host_A.host_view(), -+ Rank2K::kTransformA, -+ host_B.host_view(), -+ Rank2K::kTransformB, -+ (double)options.beta, -+ host_C.host_view(), -+ host_D.host_view(), -+ ElementAccumulator(0), -+ Rank2K::kFillModeC, -+ Rank2K::kBlasMode -+ ); -+ -+ // Copy to host memory -+ std::vector matrix_D(layout_D.capacity(extent_C)); -+ cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); -+ -+ cutlass::TensorView view_D(matrix_D.data(), layout_D, extent_C); -+ cutlass::TensorView view_Ref = host_D.host_view(); -+ -+ // Reference check -+ passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); -+ -+ if (!passed) { -+ std::cerr << "\n***\nError - problem " << i << " failed the QA check\n***\n" << std::endl; -+ return passed; -+ } -+ } -+ -+ return passed; -+ } -+}; -+ -+template -+class TestbedConventional : BaseTestbed { -+public: -+ TestbedConventional( -+ Options &options_, -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3080 -+ ): BaseTestbed(options_, init_A_, init_B_, init_C_, seed_) {} -+ -+ /// Verbose printing of problem sizes -+ void print_problem_sizes() { -+ -+ // Print groups -+ std::cout << this->problem_count() << " groups:\n"; -+ -+ int32_t idx = 0; -+ int64_t 
total_tiles = 0; -+ -+ for (auto const & problem : this->options.problem_sizes) { -+ int tiles = -+ ((problem.m() + Rank2K::ThreadblockShape::kM - 1) / Rank2K::ThreadblockShape::kM) * -+ ((problem.n() + Rank2K::ThreadblockShape::kN - 1) / Rank2K::ThreadblockShape::kN); -+ -+ total_tiles += tiles; -+ -+ std::cout << " [" << idx << "]: " -+ << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() -+ << " (" << tiles << " threadblock tiles)" << "\n"; -+ -+ ++idx; -+ } -+ std::cout << std::endl; -+ } -+ -+ /// Executes a conventional SYR2K kernel. -+ Result profile() { -+ std::cout << "Conventional Rank2K:\n" -+ << "====================================================" << std::endl; -+ -+ Result result; -+ result.passed = false; -+ -+ // Initialize the problem -+ this->allocate(); -+ this->initialize(); -+ -+ if (this->options.verbose) { -+ print_problem_sizes(); -+ } -+ -+ // -+ // Create CUDA streams to maximize concurrency of SYR2K kernels -+ // -+ int32_t effective_streams = (this->options.cuda_streams ? this->options.cuda_streams : 1); -+ std::vector cuda_streams; -+ char const *provider = "CUTLASS"; -+ -+ // -+ // Warmup run -+ // -+ -+ if (this->options.cuda_streams) { -+ for (int i = 0; i < this->options.cuda_streams; ++i) { -+ cudaStream_t stream; -+ -+ result.error = cudaStreamCreate(&stream); -+ if (result.error != cudaSuccess) { -+ std::cerr << "Failed to create CUDA stream." << std::endl; -+ return result; -+ } -+ cuda_streams.push_back(stream); -+ } -+ } -+ else { -+ cuda_streams.push_back(nullptr); -+ } -+ -+ // Use 'D' for the in/out workspace -+ this->block_D.copy_from_device(this->block_C.get()); -+ -+ for (int i = 0; i < this->options.problem_sizes.size(); ++i) { -+ cutlass::gemm::GemmCoord const & problem = this->options.problem_sizes[i]; -+ int32_t batch_count = 1; -+ int64_t lda = this->lda_host.at(i); -+ int64_t ldb = this->ldb_host.at(i); -+ int64_t ldc = this->ldc_host.at(i); -+ typename Rank2K::ElementA* ptrA = this->block_A.get() + this->offset_A.at(i); -+ typename Rank2K::ElementB* ptrB = this->block_B.get() + this->offset_B.at(i); -+ typename Rank2K::ElementC* ptrC = this->block_C.get() + this->offset_C.at(i); -+ typename Rank2K::ElementC* ptrD = this->block_D.get() + this->offset_D.at(i); -+ -+ // -+ // Initialize the CUTLASS SYR2K operator -+ // -+ -+ // Configure the SYR2K arguments -+ typename Rank2K::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); -+ -+ typename Rank2K::Arguments arguments{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem, -+ batch_count, -+ epilogue_op, -+ (void const *)ptrA, -+ (void const *)ptrB, -+ (void const *)ptrC, -+ (void *)ptrD, -+ int64_t(), -+ int64_t(), -+ int64_t(), -+ int64_t(), -+ int64_t(lda), -+ int64_t(ldb), -+ int64_t(ldc), -+ int64_t(ldc) -+ }; -+ -+ Rank2K rank2k_op; -+ -+ cutlass::Status status = rank2k_op.initialize(arguments); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; -+ return result; -+ } -+ -+ status = rank2k_op(); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; -+ return result; -+ } -+ } -+ -+ // -+ // Wait for completion -+ // -+ -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error 
= cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // -+ // Wait for completion -+ // -+ -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // Record an event at the start of a series of SYR2K operations -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ int last_stream_idx = 0; -+ -+ for (int iter = 0; iter < this->options.iterations; ++iter) { -+ for (int i = 0; i < this->options.problem_sizes.size(); ++i) { -+ cutlass::gemm::GemmCoord const & problem = this->options.problem_sizes[i]; -+ int32_t batch_count = 1; -+ int64_t lda = this->lda_host.at(i); -+ int64_t ldb = this->ldb_host.at(i); -+ int64_t ldc = this->ldc_host.at(i); -+ typename Rank2K::ElementA* ptrA = this->block_A.get() + this->offset_A.at(i); -+ typename Rank2K::ElementB* ptrB = this->block_B.get() + this->offset_B.at(i); -+ typename Rank2K::ElementC* ptrC = this->block_C.get() + this->offset_C.at(i); -+ typename Rank2K::ElementC* ptrD = this->block_D.get() + this->offset_D.at(i); -+ -+ last_stream_idx = (i % effective_streams); -+ -+ // -+ // Initialize the CUTLASS SYR2K operator -+ // -+ -+ // Configure the SYR2K arguments -+ typename Rank2K::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); -+ -+ typename Rank2K::Arguments arguments{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem, -+ batch_count, -+ epilogue_op, -+ (void const *)ptrA, -+ (void const *)ptrB, -+ (void const *)ptrC, -+ (void *)ptrD, -+ int64_t(), -+ int64_t(), -+ int64_t(), -+ int64_t(), -+ int64_t(lda), -+ int64_t(ldb), -+ int64_t(ldc), -+ int64_t(ldc) -+ }; -+ -+ Rank2K rank2k_op; -+ -+ cutlass::Status status = rank2k_op.initialize(arguments); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; -+ return result; -+ } -+ -+ status = rank2k_op(cuda_streams[last_stream_idx]); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; -+ return result; -+ } -+ } -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the SYR2K operations have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Wait for work to be completed -+ // -+ -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Compute average runtime and GFLOPs. 
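// Worked example of the figures reported below (illustrative numbers, not from the original
// source): with a single group of size N = K = 1024, Options::gflops() counts
//   fmas = 2 * N * N * K = 2 * 1024^3 ~= 2.147e9 multiply-adds
// (the factor of 2 accounts for A * B^T plus B * A^T), so an average runtime of 1.0 ms is
// reported as 2 * fmas / 1e9 / 1e-3 ~= 4295 GFLOP/s.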
-+ result.runtime_ms = double(runtime_ms) / double(this->options.iterations); -+ result.gflops = this->options.gflops(result.runtime_ms / 1000.0); -+ -+ // -+ // Cleanup -+ // -+ -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ for (auto stream : cuda_streams) { -+ if (stream) { -+ (void)cudaStreamDestroy(stream); -+ } -+ } -+ -+ std::cout << " " << this->options.problem_sizes.size() << " conventional Rank2Ks launched" << std::endl; -+ std::cout << std::endl; -+ std::cout << " " << "Conventional Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " " << "Conventional GFLOPS: " << result.gflops << std::endl; -+ -+ if (this->options.output_file.good()) { -+ this->options.output_file << this->options.output_tag << "," << provider << ",conventional," -+ << this->problem_count() << "," << result.runtime_ms << "," << result.gflops << std::endl; -+ } -+ -+ result.passed = true; -+ return result; -+ } -+}; -+ -+template -+class TestbedGrouped : BaseTestbed { -+public: -+ TestbedGrouped( -+ Options &options_, -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3080 -+ ) : BaseTestbed(options_, init_A_, init_B_, init_C_, seed_) {} -+ -+ // Redefine Rank2K with different GroupScheduleMode_ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ typename Rank2K_::ElementA, typename Rank2K_::LayoutA, Rank2K_::kTransformA, Rank2K_::kAlignmentA, -+ typename Rank2K_::ElementB, typename Rank2K_::LayoutB, Rank2K_::kTransformB, Rank2K_::kAlignmentB, -+ typename Rank2K_::ElementC, typename Rank2K_::LayoutC, Rank2K_::kFillModeC, -+ typename Rank2K_::ElementAccumulator, -+ typename Rank2K_::OperatorClass, -+ typename Rank2K_::ArchTag, -+ typename Rank2K_::ThreadblockShape, -+ typename Rank2K_::WarpShape, -+ typename Rank2K_::InstructionShape, -+ typename Rank2K_::EpilogueOutputOp, -+ typename Rank2K_::ThreadblockSwizzle, -+ Rank2K_::kStages, -+ typename Rank2K_::Operator::ArchMmaOperator::Operator, -+ Rank2K_::kBlasMode, -+ GroupScheduleMode_>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ /// Verbose printing of problem sizes -+ void print_problem_sizes() { -+ -+ // Print groups -+ std::cout << this->problem_count() << " groups:\n"; -+ -+ int32_t idx = 0; -+ int64_t total_tiles = 0; -+ -+ for (auto const & problem : this->options.problem_sizes) { -+ int tiles = Rank2K::problem_tile_count(problem); -+ total_tiles += tiles; -+ -+ std::cout << " [" << idx << "]: " -+ << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() -+ << " (" << tiles << " threadblock tiles)" << "\n"; -+ -+ ++idx; -+ } -+ std::cout << std::endl; -+ } -+ -+ /// Sort problems in descending order of problem-K dimension -+ void sort_problems() { -+ Rank2K::sort_problems(this->options.problem_count, -+ this->options.problem_sizes.data(), -+ this->lda_host.data(), -+ this->ldb_host.data(), -+ this->ldc_host.data(), -+ this->ldd_host.data(), -+ this->offset_A.data(), -+ this->offset_B.data(), -+ this->offset_C.data(), -+ this->offset_D.data()); -+ } -+ -+ /// Executes a grouped kernel and measures runtime. 
-+ Result profile() { -+ std::string sched_mode = this->options.scheduler_mode_to_str.find(GroupScheduleMode_)->second; -+ std::cout << std::endl; -+ std::cout << "Grouped Rank2K (CUTLASS) with mode " << sched_mode << ":\n" -+ << "====================================================" << std::endl; -+ -+ Result result; -+ -+ int threadblock_count = Rank2K::sufficient(this->options.problem_sizes.data(), this->options.problem_count); -+ -+ // Early exit -+ if (!threadblock_count) { -+ std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped SYR2K kernel." << std::endl; -+ return result; -+ } -+ -+ result.passed = false; -+ -+ // Initialize the problem -+ this->allocate(); -+ if (this->options.sort_problems) { -+ sort_problems(); -+ } -+ this->initialize(); -+ -+ if (this->options.verbose) { -+ print_problem_sizes(); -+ } -+ -+ // Configure the Rank2K arguments -+ typename Rank2K::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); -+ -+ // Configure Rank2K arguments -+ typename Rank2K::Arguments args( -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ this->problem_sizes_device.get(), -+ this->problem_count(), -+ threadblock_count, -+ epilogue_op, -+ this->ptr_A.get(), -+ this->ptr_B.get(), -+ this->ptr_C.get(), -+ this->ptr_D.get(), -+ this->lda.get(), -+ this->ldb.get(), -+ this->ldc.get(), -+ this->ldd.get(), -+ this->options.problem_sizes.data() -+ ); -+ -+ // Initialize the Rank2K object -+ Rank2K rank2k; -+ size_t workspace_size = rank2k.get_workspace_size(args); -+ cutlass::DeviceAllocation workspace(workspace_size); -+ -+ result.status = rank2k.initialize(args, workspace.get()); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to initialize CUTLASS Grouped Rank2K kernel." << std::endl; -+ return result; -+ } -+ -+ // Run the grouped Rank2K object -+ result.status = rank2k.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS Grouped Rank2K kernel." << std::endl; -+ return result; -+ } -+ -+ // Wait for completion -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // -+ // Verify correctness -+ // -+ result.passed = true; -+ if (this->options.reference_check) { -+ result.passed = this->verify(); -+ } -+ -+ // -+ // Warm-up run of the grouped Rank2K object -+ // -+ result.status = rank2k.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS Grouped Rank2K kernel." << std::endl; -+ return result; -+ } -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of SYR2K operations -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < this->options.iterations; ++iter) { -+ rank2k(); -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the Rank2K operations have been launched. 
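// Note on the timing scheme: the grouped launches above all go to the default stream, so
// events[0] and events[1] bracket the entire profiling loop. cudaEventSynchronize(events[1])
// below blocks until the final launch has completed, and cudaEventElapsedTime() then yields
// the total device time across all iterations.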
-+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(this->options.iterations); -+ result.gflops = this->options.gflops(result.runtime_ms / 1000.0); -+ -+ // -+ // Cleanup -+ // -+ -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ // Optionally profile initialization -+ if (this->options.profile_initialization) { -+ // Warm up -+ rank2k.initialize(args, workspace.get()); -+ -+ auto start_time = std::chrono::high_resolution_clock::now(); -+ for (int32_t i = 0; i < this->options.iterations; ++i) { -+ rank2k.initialize(args, workspace.get()); -+ } -+ auto end_time = std::chrono::high_resolution_clock::now(); -+ -+ std::chrono::duration duration = end_time - start_time; -+ duration /= double(this->options.iterations); -+ result.initialization_time_ms = duration.count(); -+ } -+ -+ int64_t total_tiles = Rank2K::group_tile_count(args); -+ std::cout << " " << total_tiles << " total threadblock tiles." << std::endl; -+ -+ std::cout << std::endl; -+ std::cout << " " << "Grouped Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " " << "Grouped GFLOPs: " << result.gflops << std::endl; -+ if (this->options.profile_initialization) { -+ std::cout << " " << "Init Runtime: " << result.initialization_time_ms << " ms" << std::endl; -+ } -+ -+ if (this->options.output_file.good()) { -+ this->options.output_file << this->options.output_tag << ",CUTLASS,grouped-" << sched_mode << "," -+ << this->problem_count() << "," << result.runtime_ms << "," << result.gflops << std::endl; -+ } -+ -+ std::cout << "\nPassed\n"; -+ -+ return result; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { -+ -+ // -+ // This example requires an NVIDIA Ampere-architecture GPU. -+ // -+ -+ std::cout -+ << "CUTLASS's Grouped Rank2K example requires a GPU of NVIDIA's Ampere Architecture or " -+ << "later (compute capability 80 or greater).\n"; -+ -+ return 0; -+ } -+ -+ // -+ // Parse options -+ // -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.error) { -+ std::cerr << "Aborting execution." 
<< std::endl; -+ return -1; -+ } -+ -+ // -+ // Define the Grouped and Conventional Rank2K types -+ // -+ -+ using ElementA = double; -+ using ElementB = double; -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ const cutlass::FillMode kFillModeC = cutlass::FillMode::kLower; -+ const int kAlignmentA = 1; -+ const int kAlignmentB = 1; -+ const cutlass::ComplexTransform kTransformA = cutlass::ComplexTransform::kNone; -+ const cutlass::ComplexTransform kTransformB = cutlass::ComplexTransform::kNone; -+ -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using OperatorClass = cutlass::arch::OpClassTensorOp; -+ using ArchTag = cutlass::arch::Sm80; -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 1, -+ ElementAccumulator, ElementAccumulator>; -+ -+ // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. -+ // This parameter is passed in at present to match the APIs of other kernels. The parameter -+ // is unused within the kernel. -+ using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+ const int kStages = 4; -+ const bool kSplitKSerial = false; -+ using Operator = cutlass::arch::OpMultiplyAdd; -+ const cutlass::BlasMode kBlasMode = cutlass::BlasMode::kSymmetric; -+ -+ // Define a grouped Rank2K kernel with all template parameters set except -+ // for scheduling mode. This will be used as the template for all scheduling -+ // modes executed. 
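// TestbedGrouped<Rank2K_, GroupScheduleMode_>, defined earlier in this file, re-instantiates
// DefaultRank2KGrouped from this definition with whichever scheduling mode it is given, so the
// single Rank2KGrouped type built here is sufficient to profile both kDeviceOnly and
// kHostPrecompute via the --scheduler-modes option.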
-+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, kTransformA, kAlignmentA, -+ ElementB, LayoutB, kTransformB, kAlignmentB, -+ ElementOutput, LayoutC, kFillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ Operator, -+ kBlasMode>::Rank2Kkernel; -+ -+ using Rank2KGrouped = cutlass::gemm::device::Rank2KGrouped; -+ -+ // Rank2k operator -+ using Rank2KConventional = cutlass::gemm::device::Rank2K< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementOutput, LayoutC, kFillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kAlignmentA, -+ kAlignmentB, -+ kSplitKSerial, -+ Operator, -+ kTransformA, -+ kTransformB, -+ kBlasMode -+ >; -+ -+ // -+ // Profile it -+ // -+ -+ TestbedConventional testbed(options); -+ -+ Result result = testbed.profile(); -+ if (!result.passed) { -+ std::cout << "Profiling CUTLASS conventional Rank2K has failed.\n"; -+ std::cout << "\nFailed\n"; -+ return -1; -+ } -+ -+ using GroupScheduleMode = cutlass::gemm::kernel::GroupScheduleMode; -+ for (GroupScheduleMode mode : options.scheduler_modes) { -+ Result result; -+ switch (mode) { -+ case GroupScheduleMode::kDeviceOnly: -+ { -+ TestbedGrouped runner(options); -+ result = runner.profile(); -+ break; -+ } -+ case GroupScheduleMode::kHostPrecompute: -+ { -+ TestbedGrouped runner(options); -+ result = runner.profile(); -+ break; -+ } -+ } -+ -+ if (result.error != cudaSuccess) { -+ return 1; -+ } -+ -+ // Override verbose flag to avoid printing duplicate information for each scheduling mode -+ options.verbose = false; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/39_gemm_permute/gemm_permute.cu b/3rdparty/cutlass/examples/39_gemm_permute/gemm_permute.cu -new file mode 100644 -index 0000000..ed3e399 ---- /dev/null -+++ b/3rdparty/cutlass/examples/39_gemm_permute/gemm_permute.cu -@@ -0,0 +1,1126 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief GEMM Permute Example. -+ -+ This example computes batched GEMM operations with output results permuted as reshaped tensors. -+ -+ We provide layout plugin as a flexible tool for users to add any customized output tensor permute operation, -+ or any other generalized global memory writeout address computation. To add a customized layout, add new class -+ in include/cutlass/layout/permute.h -+ -+ In this example, we used Tensor4DPermuteBMM0213 layout to perform Batched GEMM with permute([0, 2, 1, 3]) on BMM -+ whole output tensor, and used Tensor5DPermute20314 layout to perform Normal GEMM with permute([2, 0, 3, 1, 4]) on -+ output matrix. The address computations are performed in compute(col_init, row_init, stride_init, -+ BMM_batch_idx) with {col_permute, row_permute and stride_permute} as new addresses after permute op. -+ (check include/cutlass/layout/permute.h) -+ -+ Tips: -+ -+ 1) Make sure to set batch_stride_D to zero for BMM permute; Also the BMM GEMM should be in mode -+ cutlass::gemm::GemmUniversalMode::kBatched instead of kArray -+ -+ 2) When the last dimension is touched in permute op (for example permute([0, 2, 3, 1])), AlignmentC should -+ be set to 1. If the last dimension is untouched, one can set AlignmentC to be larger like 8 in our example. -+ As a result, permute op without touching the last dimension is recommended to obtain the best performance gain. 
-+ -+ Examples: -+ -+ # Runs a batched GEMM with 96 batches -+ $ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96 -+ -+ # Runs a batched GEMM with 96 batches (with GEMM-K dimension equal to 1024) -+ $ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96 --k=1024 --verbose=true -+ -+ # Execute batched GEMM and profile with NSight -+ $ nv-nsight-cu-cli ./examples/39_gemm_permute/39_gemm_permute --m=256 --n=192 --k=256 --verbose=true --iterations=1 --reference-check=false -+ -+*/ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+#include -+#include -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/device/gemm_universal.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/device_memory.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+#include "cutlass/util/reference/device/gemm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/device/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+ -+#include "cutlass/layout/permute.h" -+ -+/// Tensor4DPermuteBMM0213 ---> -+/// Permute layout function for 4-D permuted tensors for BMM with BMM output tensor (dimension as [B, M, N]) reshaped -+/// as [B/D1, D1, M, N]. Then perform permute([0, 2, 1, 3]) on the corresponding whole BMM output tensor. -+const int D1 = 12; -+ -+/// Tensor5DPermute20314 ---> -+/// Permute layout function for 5-D permuted tensors with output matrix (dimension as [M, N]) reshaped -+/// as [M/T1, T1, T2, T3, N/T2/T3]. Then perform permute([2, 0, 3, 1, 4]) on the corresponding output tensor. 
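// A minimal sketch (illustrative only; this helper is not part of CUTLASS or of the original
// example) of the index mapping that Tensor4DPermuteBMM0213 corresponds to. Viewing the
// [B, M, N] BMM output as [B/D1, D1, M, N], element (b, m, n) lands at position
// (b / D1, m, b % D1, n) of the permute([0, 2, 1, 3]) result, whose shape is [B/D1, M, D1, N].
inline int64_t permuted_offset_bmm0213(int b, int m, int n, int M, int N, int d1) {
  int b0 = b / d1;   // index into the B/D1 dimension
  int b1 = b % d1;   // index into the D1 dimension
  // Row-major offset within the permuted [B/D1, M, D1, N] tensor.
  return ((int64_t(b0) * M + m) * d1 + b1) * N + n;
}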
-+const int T1 = 16; -+const int T2 = 3; -+const int T3 = 8; -+ -+// Alignment C -+const int AlignmentC = 8; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ // -+ // Methods -+ // -+ -+ Result( -+ double runtime_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess -+ ): -+ runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ bool error; -+ bool reference_check; -+ -+ cutlass::gemm::GemmCoord problem_each; -+ -+ int batch_count; -+ int iterations; -+ int cuda_streams; -+ bool verbose; -+ float alpha; -+ float beta; -+ -+ // -+ // Methods -+ // -+ -+ Options(): -+ help(false), -+ error(false), -+ reference_check(true), -+ batch_count(-1), -+ iterations(20), -+ cuda_streams(0), -+ verbose(false), -+ alpha(1), -+ beta() -+ { } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ return; -+ } -+ -+ cmd.get_cmd_line_argument("alpha", alpha, 1.0f); -+ cmd.get_cmd_line_argument("beta", beta, 0.0f); -+ cmd.get_cmd_line_argument("iterations", iterations, 20); -+ cmd.get_cmd_line_argument("streams", cuda_streams, 0); -+ cmd.get_cmd_line_argument("verbose", verbose, false); -+ cmd.get_cmd_line_argument("reference-check", reference_check, true); -+ -+ int m, n, k; -+ -+ cmd.get_cmd_line_argument("m", m, 128); -+ cmd.get_cmd_line_argument("n", n, 192); -+ cmd.get_cmd_line_argument("k", k, 128); -+ cmd.get_cmd_line_argument("batch-count", batch_count, 768); -+ -+ cutlass::gemm::GemmCoord problem(m, n, k); -+ problem_each = problem; -+ -+ if (batch_count % D1 != 0){ -+ std::cerr << "\nProblem count error (problem-count = " << batch_count << "). " -+ << "problem-count needs to be divided with no remain by " << D1 << " (D1)." -+ << " (Required by the Batched GEMM permute Tensor4DPermuteBMM0213)\n\n"; -+ error = true; -+ } -+ -+ if (m % (AlignmentC * T1) != 0){ -+ std::cerr << "\nProblem m size error (m = " << m << "). " -+ << "m needs to be divided with no remain by " << (AlignmentC * T1) << " (AlignmentC * T1)." -+ << " (Required by the normal GEMM permute Tensor5DPermute20314)\n\n"; -+ error = true; -+ } -+ -+ if (n % (AlignmentC * (T2 * T3)) != 0){ -+ std::cerr << "\nProblem n size error (n = " << n << "). " -+ << "n needs to be divided with no remain by " << (AlignmentC * (T2 * T3)) << " (AlignmentC * T2 * T3)." -+ << " (Required by the normal GEMM permute Tensor5DPermute20314)\n\n"; -+ error = true; -+ } -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "39_gemm_permute\n\n" -+ << " 1) This example firstly profiles the performance of a batched GEMM kernel with BMM whole output" -+ << " (including output matrices for each batch) as permuted 4D Tensor." 
-+ << " The BMM tensor output in shape of [B, M, N] is reshaped as [B/D1, D1, M, N] and then permuted with" -+ << " permute([0, 2, 1, 3]) to be in shape of [B/D1, M, D1, N].\n\n" -+ << " 2) This example also profiles the performance of a normal GEMM kernel with output as permuted 5D Tensor." -+ << " The GEMM matrix output in shape of [M, N] is reshaped as [M/T1, T1, T2, T3, N/T2/T3] and then permuted" -+ << " with permute([2, 0, 3, 1, 4]) to be in shape of [T2, M/T1, T3, T1, N/T2/T3].\n\n" -+ << " Note: D1, T1, T2, T3 are compile-time constants defined in gemm_permute.cu\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --batch-count= Sets the number of batches in batched GEMM (batch number for BMM). (default: --batch-count=768)\n" -+ << " --m= Sets the M dimension for both batched GEMM and normal GEMM problems. (default: --m=128)\n" -+ << " --n= Sets the N dimension for both batched GEMM and normal GEMM problems. (default: --n=192)\n" -+ << " --k= Sets the K dimension for both batched GEMM and normal GEMM problems. (default: --k=128)\n" -+ << " --alpha= Epilogue scalar alpha (real part)\n" -+ << " --beta= Epilogue scalar beta (real part)\n\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --reference-check= If true, performs reference check.\n" -+ << " --verbose= If true, prints problem sizes and batching structure.\n"; -+ -+ out << "\n\nExamples:\n\n" -+ -+ << "# Runs a batched GEMM with 96 batches\n" -+ << "$ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96\n\n" -+ -+ << "# Runs a batched GEMM with 96 batches (with GEMM-K dimension equal to 1024)\n" -+ << "$ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96 --k=1024 --verbose=true\n\n" -+ -+ << "# Execute batched GEMM and profile with NSight\n" -+ << "$ nv-nsight-cu-cli ./examples/39_gemm_permute/39_gemm_permute --m=256 --n=192 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fmas = int64_t(); -+ -+ fmas += problem_each.product() * batch_count; -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class Testbed { -+public: -+ -+ // -+ // Type definitions -+ // -+ -+ using ElementA = typename GemmBatched::ElementA; -+ using ElementB = typename GemmBatched::ElementB; -+ using ElementC = typename GemmBatched::ElementC; -+ using ElementAccumulator = typename GemmBatched::ElementAccumulator; -+ -+ using EpilogueOutputOp = typename GemmBatched::GemmKernel::Epilogue::OutputOp; -+ using ElementCompute = typename EpilogueOutputOp::ElementCompute; -+ -+ using LayoutA = typename GemmBatched::LayoutA; -+ using LayoutB = typename GemmBatched::LayoutB; -+ using LayoutC = typename GemmBatched::LayoutC; -+ -+ using MatrixCoord = typename LayoutC::TensorCoord; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Options & options; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint32_t seed; -+ -+ cutlass::DeviceAllocation block_A; -+ cutlass::DeviceAllocation block_B; -+ cutlass::DeviceAllocation block_C; -+ cutlass::DeviceAllocation block_D; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ Testbed( -+ 
Options &options_, -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3090 -+ ): -+ options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Verbose BMM info -+ void print_BMM_info_() { -+ -+ // Print batched GEMM -+ std::cout << "Batched GEMM with permute([0, 2, 1, 3]) on BMM whole output tensor:\n"; -+ -+ auto problem = options.problem_each; -+ std::cout -+ << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() -+ << ", batch count: " << options.batch_count << "\n"; -+ -+ std::cout << "output tensor shape: [" << options.batch_count << ", " << problem.m() << ", " -+ << problem.n() <<"]\n"; -+ std::cout << "reshaped as: [" << options.batch_count / D1 << ", " << D1 << ", " -+ << problem.m() << ", " << problem.n() <<"]\n"; -+ std::cout << "finally permuted as: [" << options.batch_count / D1 << ", " << problem.m() << ", " -+ << D1 << ", " << problem.n() <<"]\n"; -+ -+ std::cout << "----------------------------------------------------\n"; -+ -+ } -+ -+ /// Verbose normal GEMM info -+ void print_GEMM_info_() { -+ -+ // Print batched GEMM -+ std::cout << "Normal GEMM with permute([2, 0, 3, 1, 4]):\n"; -+ -+ auto problem = options.problem_each; -+ std::cout -+ << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() << "\n"; -+ -+ std::cout << "output tensor shape: [" << problem.m() << ", " << problem.n() <<"]" << std::endl; -+ std::cout << "reshaped as: [" << problem.m() / T1 << ", " << T1 << ", " -+ << T2 << ", " << T3 << ", " << problem.n() / T2 / T3 <<"]" << std::endl; -+ std::cout << "finally permuted as: [" << T2 << ", " << problem.m() / T1 << ", " -+ << T3 << ", " << T1 << ", " << problem.n() / T2 / T3 <<"]" << std::endl; -+ -+ std::cout << "----------------------------------------------------\n"; -+ -+ } -+ -+private: -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor_( -+ Element *ptr, -+ size_t capacity, -+ cutlass::Distribution::Kind dist_kind, -+ uint32_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ Element scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ if (cutlass::sizeof_bits::value <= 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } -+ else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ ptr, capacity, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::device::BlockFillRandomGaussian( -+ ptr, capacity, seed, Element(), Element(0.5f)); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ // Fill with increasing elements -+ cutlass::reference::device::BlockFillSequential( -+ ptr, capacity, Element(1), Element()); -+ } -+ else { -+ -+ // Fill with all 1s -+ cutlass::reference::device::BlockFillSequential( -+ ptr, capacity, Element(), Element(1)); -+ } -+ } -+ -+ /// Initializes data structures -+ void initialize_(int batch_count) { -+ -+ // -+ // Choose random problem sizes -+ // -+ -+ // construct a few problems of random sizes -+ srand(seed); -+ 
-+ int64_t total_elements_A = options.problem_each.m() * options.problem_each.k() * batch_count; -+ int64_t total_elements_B = options.problem_each.n() * options.problem_each.k() * batch_count; -+ int64_t total_elements_C = options.problem_each.m() * options.problem_each.n() * batch_count; -+ int64_t total_elements_D = options.problem_each.m() * options.problem_each.n() * batch_count; -+ -+ // -+ // Assign space -+ // -+ -+ block_A.reset(total_elements_A); -+ block_B.reset(total_elements_B); -+ block_C.reset(total_elements_C); -+ block_D.reset(total_elements_D); -+ -+ // -+ // Initialize the problems of the workspace -+ // -+ -+ initialize_tensor_(block_A.get(), total_elements_A, init_A, seed * 2021); -+ initialize_tensor_(block_B.get(), total_elements_B, init_B, seed * 2022); -+ initialize_tensor_(block_C.get(), total_elements_C, init_C, seed * 2023); -+ -+ cutlass::reference::device::BlockFillSequential( -+ block_D.get(), total_elements_D, ElementC(), ElementC()); -+ } -+ -+ /// Verifies the BMM GEMM result -+ bool verify_BMM_() { -+ -+ bool passed = true; -+ -+ cutlass::gemm::GemmCoord problem = options.problem_each; -+ -+ LayoutA layout_A(LayoutA::packed({problem.m(), problem.k()}).stride(0)); -+ LayoutB layout_B(LayoutB::packed({problem.k(), problem.n()}).stride(0)); -+ LayoutC layout_C(LayoutC::packed({problem.m(), problem.n()}).stride(0)); -+ LayoutC layout_D(LayoutC::packed({problem.m(), problem.n()}).stride(0)); -+ -+ MatrixCoord extent_A{problem.m(), problem.k()}; -+ MatrixCoord extent_B{problem.k(), problem.n()}; -+ MatrixCoord extent_C{problem.m(), problem.n()}; -+ -+ cutlass::TensorView view_A(block_A.get(), layout_A, extent_A); -+ cutlass::TensorView view_B(block_B.get(), layout_B, extent_B); -+ cutlass::TensorView view_C(block_C.get(), layout_C, extent_C); -+ -+ cutlass::DeviceAllocation block_Ref(layout_D.capacity(extent_C) * options.batch_count); -+ cutlass::TensorView view_Ref_device(block_Ref.get(), layout_D, extent_C); -+ -+ // Reference GEMM -+ cutlass::reference::device::GemmComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, ElementAccumulator -+ >( -+ problem, -+ options.alpha, -+ view_A, -+ GemmBatched::kTransformA, -+ view_B, -+ GemmBatched::kTransformB, -+ options.beta, -+ view_C, -+ view_Ref_device, -+ ElementAccumulator(0), -+ options.batch_count, -+ options.problem_each.m() * options.problem_each.k(), -+ options.problem_each.n() * options.problem_each.k(), -+ options.problem_each.m() * options.problem_each.n(), -+ options.problem_each.m() * options.problem_each.n() -+ ); -+ -+ // Copy to host memory -+ std::vector matrix_D(layout_D.capacity(extent_C) * options.batch_count); -+ std::vector matrix_Ref(layout_D.capacity(extent_C) * options.batch_count); -+ -+ cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get(), matrix_D.size()); -+ cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_D.size()); -+ -+ // Print out the results and reference in 4D Tensor -+ // [options.batch_count, options.problem_each.m() * options.problem_each.n()] -> [D0, D1, D2, D3]. -+ // After permute Op, -> [D0, D2, D1, D3]. 
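// The reference result is produced in the unpermuted [D0, D1, D2, D3] order, so the host loop
// below moves element (n, h, w, c) to (n, w, h, c) before comparing against the kernel output.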
-+ int D0 = options.batch_count / D1; -+ int D2 = options.problem_each.m(); -+ int D3 = options.problem_each.n(); -+ -+ cutlass::TensorView view_D_Tensor(matrix_D.data(), // if LayoutC = cutlass::layout::ColumnMajor, view_D_Tensor should be constructed differently -+ cutlass::layout::TensorNHWC().packed(cutlass::Tensor4DCoord({D0, D2, D1, D3})), cutlass::Tensor4DCoord({D0, D2, D1, D3})); -+ -+ cutlass::TensorView view_Ref_Tensor(matrix_Ref.data(), -+ cutlass::layout::TensorNHWC().packed(cutlass::Tensor4DCoord({D0, D1, D2, D3})), cutlass::Tensor4DCoord({D0, D1, D2, D3})); -+ -+ // Tensor Permute Op on reference tensor -+ cutlass::HostTensor view_Ref_Permute_Tensor(cutlass::Tensor4DCoord({D0, D2, D1, D3})); -+ for (int n = 0; n < D0; ++n) { -+ for (int h = 0; h < D1; ++h) { -+ for (int w = 0; w < D2; ++w) { -+ for (int c = 0; c < D3; ++c) { -+ view_Ref_Permute_Tensor.at({n, w, h, c}) = view_Ref_Tensor.at({n, h, w, c}); -+ } -+ } -+ } -+ } -+ -+ // Reference check -+ passed = cutlass::reference::host::TensorEquals(view_Ref_Permute_Tensor.host_view(), view_D_Tensor); -+ -+ if (!passed) { -+ std::cerr << "\n***\nError - problem failed the QA check\n***\n" << std::endl; -+ return passed; -+ } -+ -+ std::cout << "Passed verification" << std::endl; -+ return passed; -+ } -+ -+ bool verify_GEMM_normal_() { -+ -+ bool passed = true; -+ -+ cutlass::gemm::GemmCoord problem = options.problem_each; -+ -+ LayoutA layout_A(LayoutA::packed({problem.m(), problem.k()}).stride(0)); -+ LayoutB layout_B(LayoutB::packed({problem.k(), problem.n()}).stride(0)); -+ LayoutC layout_C(LayoutC::packed({problem.m(), problem.n()}).stride(0)); -+ LayoutC layout_D(LayoutC::packed({problem.m(), problem.n()}).stride(0)); -+ -+ MatrixCoord extent_A{problem.m(), problem.k()}; -+ MatrixCoord extent_B{problem.k(), problem.n()}; -+ MatrixCoord extent_C{problem.m(), problem.n()}; -+ -+ cutlass::TensorView view_A(block_A.get(), layout_A, extent_A); -+ cutlass::TensorView view_B(block_B.get(), layout_B, extent_B); -+ cutlass::TensorView view_C(block_C.get(), layout_C, extent_C); -+ -+ cutlass::DeviceAllocation block_Ref(layout_D.capacity(extent_C)); -+ cutlass::TensorView view_Ref_device(block_Ref.get(), layout_D, extent_C); -+ -+ // Reference GEMM -+ cutlass::reference::device::GemmComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, ElementAccumulator -+ >( -+ problem, -+ options.alpha, -+ view_A, -+ GemmBatched::kTransformA, -+ view_B, -+ GemmBatched::kTransformB, -+ options.beta, -+ view_C, -+ view_Ref_device, -+ ElementAccumulator(0) -+ ); -+ -+ // Copy to host memory -+ std::vector matrix_D(layout_D.capacity(extent_C)); -+ std::vector matrix_Ref(layout_D.capacity(extent_C)); -+ -+ cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get(), matrix_D.size()); -+ cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_D.size()); -+ -+ // Print out the results and reference in 5D Tensor -+ // [options.problem_each.m(), options.problem_each.n()] -> [T0, T1, T2, T3, T4]. -+ // options.problem_each.m() == T0 * T1 -+ // options.problem_each.n() == T2 * T3 * T4 -+ // After permute Op, -> [T2, T0, T3, T1, T4]. 
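// Likewise for the 5-D case: the host loop below applies permute([2, 0, 3, 1, 4]) to the
// reference, moving element (n, d, h, w, c) to (h, n, w, d, c) before the comparison.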
-+ int T0 = options.problem_each.m() / T1; -+ int T4 = options.problem_each.n() / T2 / T3; -+ -+ cutlass::TensorView view_D_Tensor(matrix_D.data(), // if LayoutC = cutlass::layout::ColumnMajor, view_D_Tensor should be constructed differently -+ cutlass::layout::TensorNDHWC().packed(cutlass::Tensor5DCoord({T2, T0, T3, T1, T4})), cutlass::Tensor5DCoord({T2, T0, T3, T1, T4})); -+ cutlass::TensorView view_Ref_Tensor(matrix_Ref.data(), -+ cutlass::layout::TensorNDHWC().packed(cutlass::Tensor5DCoord({T0, T1, T2, T3, T4})), cutlass::Tensor5DCoord({T0, T1, T2, T3, T4})); -+ -+ // Tensor Permute Op on reference tensor -+ cutlass::HostTensor view_Ref_Permute_Tensor(cutlass::Tensor5DCoord({T2, T0, T3, T1, T4})); -+ for (int n = 0; n < T0; ++n) { -+ for (int d = 0; d < T1; ++d) { -+ for (int h = 0; h < T2; ++h) { -+ for (int w = 0; w < T3; ++w) { -+ for (int c = 0; c < T4; ++c) { -+ view_Ref_Permute_Tensor.at({h, n, w, d, c}) = view_Ref_Tensor.at({n, d, h, w, c}); // permute([2,0,3,1,4]) -+ } -+ } -+ } -+ } -+ } -+ -+ // Reference check -+ passed = cutlass::reference::host::TensorEquals(view_Ref_Permute_Tensor.host_view(), view_D_Tensor); -+ -+ if (!passed) { -+ std::cerr << "\n***\nError - problem failed the QA check\n***\n" << std::endl; -+ return passed; -+ } -+ -+ std::cout << "Passed verification" << std::endl; -+ return passed; -+} -+ -+public: -+ /// Executes a conventional batched GEMM kernel. -+ Result profile_batched_kBatched() { -+ -+ std::cout << "\n====================================================" << std::endl; -+ std::cout << "Batched GEMM (CUTLASS):\n" -+ << "====================================================" << std::endl; -+ -+ if (options.verbose) { -+ print_BMM_info_(); -+ } -+ -+ Result result; -+ -+ result.passed = false; -+ -+ // Initialize the problem -+ initialize_(options.batch_count); -+ -+ // Configure the GEMM arguments -+ typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta); -+ -+ // Please make sure all problem_sizes are the same for kBatched mode -+ auto problem = options.problem_each; -+ -+ // For regular BMM -+ int64_t batch_stride_C = problem.m() * problem.n(); -+ // For BMM permute output ---> make sure to set batch_stride_D to zero for BMM permute op -+ int64_t batch_stride_D = 0; -+ -+ // Configure GEMM arguments -+ typename GemmBatched::Arguments arguments{ -+ cutlass::gemm::GemmUniversalMode::kBatched, -+ options.problem_each, -+ options.batch_count, -+ epilogue_op, -+ (void*)block_A.get(), -+ (void*)block_B.get(), -+ (void*)block_C.get(), -+ (void*)block_D.get(), -+ problem.m() * problem.k(), -+ problem.n() * problem.k(), -+ batch_stride_C, -+ batch_stride_D, -+ problem.k(), -+ problem.n(), -+ problem.n(), -+ problem.n() -+ }; -+ -+ // Initialize the GEMM object -+ GemmBatched gemm; -+ -+ result.status = gemm.initialize(arguments, nullptr); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to initialize CUTLASS Batched GEMM kernel." << std::endl; -+ return result; -+ } -+ -+ // Run the batched GEMM object -+ result.status = gemm.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS Batched GEMM kernel." 
<< std::endl; -+ return result; -+ } -+ -+ // Wait for completion -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // -+ // Verify correctness -+ // -+ result.passed = true; -+ -+ if (options.reference_check) { -+ result.passed = verify_BMM_(); -+ } -+ -+ // -+ // Warm-up run of the batched GEMM object -+ // -+ result.status = gemm.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS Batched GEMM kernel." << std::endl; -+ return result; -+ } -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMM operations -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ gemm(); -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMM operations have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Compute average runtime and GFLOPs. 
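// For reference (illustrative arithmetic, not from the original source): with the defaults
// m=128, n=192, k=128 and batch-count=768, Options::gflops() counts
// 2 * 128 * 192 * 128 * 768 ~= 4.83e9 FLOP per launch, so an average runtime of 1 ms
// would report roughly 4.8e3 GFLOP/s.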
-+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // -+ // Cleanup -+ // -+ -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ std::cout << " " << 1 << " batched GEMMs launched\n"; -+ -+ std::cout << std::endl; -+ std::cout << " " << "Batched Runtime: " << result.runtime_ms << " ms\n"; -+ std::cout << " " << "Batched GFLOPs: " << result.gflops << "\n"; -+ -+ return result; -+ } -+ -+ Result profile_GEMM_permute() { -+ -+ std::cout << "\n====================================================" << std::endl; -+ std::cout << "Normal GEMM (CUTLASS):\n" -+ << "====================================================" << std::endl; -+ -+ if (options.verbose) { -+ print_GEMM_info_(); -+ } -+ -+ Result result; -+ -+ result.passed = false; -+ -+ // Initialize the problem -+ initialize_(1); -+ -+ // Configure the GEMM arguments -+ typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta); -+ -+ // Please make sure all problem_sizes are the same for kBatched mode -+ auto problem = options.problem_each; -+ -+ // Configure GEMM arguments -+ typename GemmPermute::Arguments arguments{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ options.problem_each, -+ 1, -+ epilogue_op, -+ (void*)block_A.get(), -+ (void*)block_B.get(), -+ (void*)block_C.get(), -+ (void*)block_D.get(), -+ 0, -+ 0, -+ 0, -+ 0, -+ problem.k(), -+ problem.n(), -+ problem.n(), -+ problem.n() -+ }; -+ -+ // Initialize the GEMM object -+ GemmPermute gemm_normal; -+ -+ result.status = gemm_normal.initialize(arguments, nullptr); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to initialize CUTLASS Batched GEMM kernel." << std::endl; -+ return result; -+ } -+ -+ // Run the normal GEMM object -+ result.status = gemm_normal.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS Batched GEMM kernel." << std::endl; -+ return result; -+ } -+ -+ // Wait for completion -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // -+ // Verify correctness -+ // -+ result.passed = true; -+ -+ if (options.reference_check) { -+ result.passed = verify_GEMM_normal_(); -+ } -+ -+ // -+ // Warm-up run of the normal GEMM object -+ // -+ result.status = gemm_normal.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS Batched GEMM kernel." << std::endl; -+ return result; -+ } -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMM operations -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ gemm_normal(); -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMM operations have been launched. 
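// events[0] and events[1] bracket all `options.iterations` launches, so the elapsed time
// measured below is divided by the iteration count to obtain the per-launch average.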
-+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // -+ // Cleanup -+ // -+ -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ std::cout << std::endl; -+ std::cout << " " << "Normal Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " " << "Normal GFLOPs: " << result.gflops << "\n"; -+ -+ return result; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ // -+ // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. -+ // -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { -+ -+ // -+ // This example requires an NVIDIA Ampere-architecture GPU. -+ // -+ -+ std::cout -+ << "CUTLASS's Grouped GEMM example requires a GPU of NVIDIA's Ampere Architecture or " -+ << "later (compute capability 80 or greater).\n"; -+ -+ return 0; -+ } -+ -+ // -+ // Parse options -+ // -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.error) { -+ std::cerr << "Aborting execution." 
<< std::endl; -+ return -1; -+ } -+ -+ // -+ // Define the GEMM types -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ // -+ // Define a conventional batched GEMM type -+ // -+ -+ // Gemm operator cutlass_tensorop_f16_s16816gemm_f16_128x128_32x4_nt_align8 -+ using GemmBatched = cutlass::gemm::device::GemmUniversal< -+ cutlass::half_t, LayoutA, -+ cutlass::half_t, LayoutB, -+ ElementOutput, LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ AlignmentC, //128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 4, -+ 8, /*alignmentA*/ -+ 8, /*alignmengB*/ -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ false, /*GatherA*/ -+ false, /*GatherB*/ -+ false, /*ScatterD*/ -+ cutlass::layout::Tensor4DPermuteBMM0213 /*PermuteDLayout*/ -+ >; -+ -+ // Gemm operator cutlass_tensorop_f16_s16816gemm_f16_128x128_32x4_nt_align8 -+ using GemmPermute = cutlass::gemm::device::GemmUniversal< -+ cutlass::half_t, LayoutA, -+ cutlass::half_t, LayoutB, -+ ElementOutput, LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ AlignmentC, //128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 4, -+ 8, /*alignmentA*/ -+ 8, /*alignmengB*/ -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ false, /*GatherA*/ -+ false, /*GatherB*/ -+ false, /*ScatterD*/ -+ cutlass::layout::Tensor5DPermute20314 /*PermuteDLayout*/ -+ >; -+ -+ // -+ // Profile it -+ // -+ -+ Testbed testbed(options); -+ -+ Result result; -+ result = testbed.profile_batched_kBatched(); -+ if (!result.passed) { -+ std::cout << "Profiling batched GEMM has failed.\n"; -+ std::cout << "\nFailed\n"; -+ } else { -+ std::cout << "\nPassed CUTLASS batched GEMM\n"; -+ } -+ -+ result = testbed.profile_GEMM_permute(); -+ if (!result.passed) { -+ std::cout << "Profiling normal GEMM has failed.\n"; -+ std::cout << "\nFailed\n"; -+ } else { -+ std::cout << "\nPassed CUTLASS normal GEMM\n"; -+ } -+ -+ std::cout << "\n====================================================" << std::endl; -+ std::cout << "Finished\n"; -+ std::cout << "====================================================" << std::endl; -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/attention_scaling_coefs_updater.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/attention_scaling_coefs_updater.h -new file mode 100644 -index 0000000..4de04ef ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/attention_scaling_coefs_updater.h -@@ -0,0 +1,513 @@ 
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holdvr nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "cutlass/functional.h" -+#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+#include "cutlass/matrix_shape.h" -+#include "gemm_kernel_utils.h" -+ -+namespace { -+ -+static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) { -+ // source: https://stackoverflow.com/a/51549250 -+ return (value >= 0) -+ ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) -+ : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); -+} -+} // namespace -+ -+/* Iterates on the accumulator and corresponding position on result matrix -+ -+(1) Update `mi[r]` to the max value of the row `r` -+(2) In a second iteration do the following: -+ (a) accum <- exp(accum - mi) -+ (b) m_prime <- exp(m_prime - mi) -+ (c) s_prime <- s_prime * m_prime + sum(accum) -+ -+All of this is done on registers, before we store all of this -+on shared memory for the next matmul with Value. -+ -+We have multiple implementations, because each configuration has a different way -+of iterating in the accumulators. 
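As an illustrative scalar sketch (not part of the original file), for one row with running
statistics (m_old, s_old) and a new block of scaled logits x_j:

    m_new = max(m_old, max_j x_j)                  // step (1): update `mi`
    r     = exp(m_old - m_new)                     // `m_prime` after step (2b)
    s_new = s_old * r + sum_j exp(x_j - m_new)     // step (2c)

The previously accumulated output is rescaled by r (when it is kept in registers), and the
exp(x_j - m_new) values are what get stored to shared memory as operand A of the next matmul
with Value. The code below performs the same update in base 2, i.e. with exp2f() and a
log2(e) scale factor.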
-+*/ -+ -+template -+struct RegisterOps { -+ template < -+ int kQueriesPerBlock, -+ bool kFullColumns, -+ bool kIsFirst, -+ bool kKeepOutputInRF> -+ CUTLASS_DEVICE static void update( -+ typename T::Fragment& frag_o, // output so far -+ typename T::Fragment& frag, -+ cutlass::Array& mi, -+ cutlass::Array& m_prime, -+ cutlass::Array& s_prime, -+ int8_t lane_id, -+ int8_t thread_id, -+ int8_t warp_id, -+ int16_t max_col, -+ typename T::TensorCoord const& tile_offset, -+ float scaling) { -+ // Convert to `accum_t` (rather than double) -+ constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E -+ if (!kIsFirst) { -+ if (thread_id < kQueriesPerBlock) { -+ m_prime[thread_id] = mi[thread_id]; -+ } -+ __syncthreads(); -+ } -+ -+ auto lane_offset = BASE::get_lane_offset(lane_id, warp_id, tile_offset); -+ -+ // First update `mi` to the max per-row -+ { -+ accum_t max; -+ BASE::iterateRows( -+ lane_offset, -+ [&](int accum_m) { -+ max = -cutlass::platform::numeric_limits::infinity(); -+ }, -+ [&](int accum_m, int accum_n, int idx) { -+ if (kFullColumns || accum_n < max_col) { -+ max = cutlass::fast_max(max, frag[idx]); -+ } -+ }, -+ [&](int accum_m) { -+ // Having 4x atomicMax seems faster than reduce within warp -+ // first... -+ atomicMaxFloat(&mi[accum_m], max * scaling); -+ }); -+ } -+ frag = cutlass::multiplies()(scaling * kLog2e, frag); -+ -+ // Make sure we all share the update values for `mi` -+ __syncthreads(); -+ -+ if (thread_id < kQueriesPerBlock) { -+ auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id])); -+ m_prime[thread_id] = m_prime_exp; -+ s_prime[thread_id] *= m_prime_exp; -+ } -+ __syncthreads(); // Update output fragments -+ if (kKeepOutputInRF && !kIsFirst) { -+ accum_t mp; -+ BASE::iterateRows( -+ lane_offset, -+ [&](int accum_m) { mp = m_prime[accum_m]; }, -+ [&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; }, -+ [&](int accum_m) {}); -+ __syncthreads(); -+ } -+ // Update accum_m, accum_n, ... -+ { -+ accum_t mi_row, total_row; -+ BASE::iterateRows( -+ lane_offset, -+ [&](int accum_m) { mi_row = kLog2e * mi[accum_m]; }, -+ [&](int accum_m, int accum_n, int idx) { -+ frag[idx] = (kFullColumns || accum_n < max_col) -+ ? 
exp2f(frag[idx] - mi_row) -+ : accum_t(0.0); -+ }, -+ [&](int accum_m) {}); -+ BASE::iterateRows( -+ lane_offset, -+ [&](int accum_m) { total_row = 0.0; }, -+ [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; }, -+ [&](int accum_m) { -+ if (BASE::reduceSameRow( -+ lane_id, total_row, [](accum_t a, accum_t b) { -+ return a + b; -+ })) { -+ atomicAdd(&s_prime[accum_m], total_row); -+ } -+ }); -+ } -+ } -+}; -+ -+template -+struct AttentionScalingCoefsUpdaterSm80 -+ : RegisterOps< -+ AttentionScalingCoefsUpdaterSm80, -+ T, -+ accum_t, -+ kWarpSize> { -+ static_assert( -+ cutlass::platform:: -+ is_same::value, -+ "only RowMajor is supported"); -+ -+ using Policy = typename T::Policy; -+ using InstructionShape = typename T::InstructionShape; -+ using OpDelta = typename T::OpDelta; -+ using Shape = typename T::Shape; -+ static int const kElementsPerAccess = InstructionShape::kN / 4; -+ static int const kRowsPerTile = 8; -+ static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; -+ -+ static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( -+ int8_t lane_id, -+ int8_t warp_id, -+ typename T::TensorCoord const& tile_offset) { -+ int quad = (lane_id >> 2); -+ int lane_in_quad = (lane_id & 3); -+ return cutlass::MatrixCoord( -+ quad + tile_offset.row() * Shape::kRow, -+ lane_in_quad * kElementsPerAccess + -+ tile_offset.column() * Shape::kColumn); -+ } -+ -+ template -+ CUTLASS_DEVICE static void iterateRows( -+ cutlass::MatrixCoord& lane_offset, -+ FA beginRow, -+ FB op, -+ FC endRow) { -+ // See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < kAccumulatorRows; ++row) { -+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + -+ row * kRowsPerTile + lane_offset.row(); -+ beginRow(accum_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess * -+ (mma_n * Policy::MmaIterations::kRow + mma_m); -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < kElementsPerAccess; ++col) { -+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + -+ col + lane_offset.column(); -+ int idx = mma_accum_start + row * kElementsPerAccess + col; -+ op(accum_m, accum_n, idx); -+ } -+ } -+ -+ endRow(accum_m); -+ } -+ } -+ } -+ -+ template -+ CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { -+ // In each warp, 4 threads will work on the same row -+ // - the ones with the same `quad` -+ auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1); -+ myValue = fn(myValue, otherV); -+ otherV = __shfl_xor_sync(0xffffffff, myValue, 2); -+ myValue = fn(myValue, otherV); -+ int lane_in_quad = (lane_id & 3); -+ return lane_in_quad == 0; -+ } -+}; -+ -+template -+struct AttentionScalingCoefsUpdaterVolta -+ : RegisterOps< -+ AttentionScalingCoefsUpdaterVolta, -+ T, -+ accum_t, -+ kWarpSize> { -+ static_assert( -+ cutlass::platform:: -+ is_same::value, -+ "only RowMajor is supported"); -+ -+ using Policy = typename T::Policy; -+ using InstructionShape = typename T::InstructionShape; -+ using OpDelta = typename T::OpDelta; -+ using Shape = typename T::Shape; -+ using Element = accum_t; -+ -+ static int const kElementsPerPartial = 4; -+ using EleShapePerPatial = typename cutlass::platform::conditional< -+ cutlass::platform::is_same::value, -+ cutlass::MatrixShape<2, 2>, -+ cutlass::MatrixShape<1, 4>>::type; -+ 
static int const kElementsPerMma = 8; -+ static int const kAccumulatorPatials = 2; -+ using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; -+ -+ static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( -+ int8_t lane_id, -+ int8_t warp_id, -+ typename T::TensorCoord const& tile_offset) { -+ int quad = (lane_id >> 2); -+ int lane_in_quad = (lane_id & 3); -+ int accum_m, accum_n; -+ -+ if (cutlass::platform::is_same::value) { -+ // (quad[2],quad[0])+lane_in_quad[0] -+ accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); -+ // (quad[1])+lane_in_quad[1] -+ accum_n = -+ ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + -+ (lane_in_quad & 2); -+ } else { -+ accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + -+ lane_in_quad; // (quad[2],quad[0]) -+ accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; -+ } -+ return cutlass::MatrixCoord( -+ accum_m + tile_offset.row() * Shape::kRow, -+ accum_n + tile_offset.column() * Shape::kColumn); -+ } -+ -+ template -+ CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { -+ static_assert( -+ cutlass::platform::is_same::value, -+ "update to support non-float accum"); -+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 -+ // T0 & T2 share same line within a quad -+ auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1); -+ myValue = fn(myValue, otherV); -+ // quad 0 and quad 2 are on the same lines -+ otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3); -+ myValue = fn(myValue, otherV); -+ return (lane_id & ((1 << 1) | (1 << 3))) == 0; -+ } -+ -+ template -+ CUTLASS_DEVICE static void iterateRows( -+ cutlass::MatrixCoord& lane_offset, -+ FA beginRow, -+ FB op, -+ FC endRow) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < EleShapePerPatial::kRow; ++m) { -+ int accum_m = tile_m * Policy::InterleavedTile::kRow + -+ mma_m * QuadShapePerPatialMma::kRow + m * 2 + lane_offset.row(); -+ beginRow(accum_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; -+ ++tile_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; -+ ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < kAccumulatorPatials; ++p) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { -+ int mma_accum_start = -+ (((tile_n * Policy::TileIterations::kRow + tile_m) * -+ Policy::MmaIterations::kColumn + -+ mma_n) * -+ Policy::MmaIterations::kRow + -+ mma_m) * -+ kElementsPerMma; -+ int accum_n = tile_n * Policy::InterleavedTile::kColumn + -+ mma_n * QuadShapePerPatialMma::kColumn + -+ p * Policy::InterleavedTile::kColumn / 2 + n + -+ lane_offset.column(); -+ int idx = mma_accum_start + p * kElementsPerPartial + -+ m * EleShapePerPatial::kColumn + n; -+ op(accum_m, accum_n, idx); -+ } -+ } -+ } -+ } -+ endRow(accum_m); -+ } -+ } -+ } -+ } -+}; -+ -+template -+struct AttentionScalingCoefsUpdaterSimt -+ : RegisterOps< -+ AttentionScalingCoefsUpdaterSimt, -+ T, -+ accum_t, -+ kWarpSize> { -+ using Policy = typename T::Policy; -+ using Iterations = typename T::Iterations; -+ using Element = typename T::Element; -+ using Delta = typename T::Delta; -+ using Shape = typename T::Shape; -+ static_assert( -+ cutlass::platform:: -+ is_same::value, -+ 
"only RowMajor is supported"); -+ -+ template -+ CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) { -+ auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit); -+ myValue = fn(myValue, otherV); -+ } -+ return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0; -+ } -+ -+ template -+ CUTLASS_DEVICE static void iterateRows( -+ cutlass::MatrixCoord& lane_offset, -+ FA beginRow, -+ FB op, -+ FC endRow) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { -+ int accum_m = mma_m * Delta::kRow + m + lane_offset.row(); -+ beginRow(accum_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { -+ int accum_n = -+ mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN + -+ lane_offset.column(); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { -+ int idx = n + -+ Policy::LaneMmaShape::kN * -+ (mma_n + -+ Iterations::kColumn * -+ (m + mma_m * Policy::LaneMmaShape::kM)); -+ op(accum_m, accum_n + n, idx); -+ } -+ } -+ endRow(accum_m); -+ } -+ } -+ } -+ -+ static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( -+ int8_t lane_id, -+ int8_t warp_id, -+ typename T::TensorCoord const& tile_offset) { -+ static_assert( -+ cutlass::platform::is_same< -+ typename Policy::LaneLayout, -+ cutlass::layout::RowMajorInterleaved<1>>::value, -+ ""); -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ cutlass::MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ cutlass::MatrixCoord(Policy::LaneMmaShape::kM, -+ Policy::LaneMmaShape::kN); -+ return lane_offset + -+ tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn); -+ } -+}; -+ -+template -+struct DefaultAttentionScalingCoefsUpdater; -+ -+// Simt -+template -+struct DefaultAttentionScalingCoefsUpdater< -+ cutlass::gemm::warp::MmaSimtTileIterator< -+ S, -+ cutlass::gemm::Operand::kC, -+ accum_t, -+ cutlass::layout::RowMajor, -+ P, -+ 1, -+ 1>, -+ accum_t, -+ kWarpSize> { -+ using Iterator = typename cutlass::gemm::warp::MmaSimtTileIterator< -+ S, -+ cutlass::gemm::Operand::kC, -+ accum_t, -+ cutlass::layout::RowMajor, -+ P, -+ 1, -+ 1>; -+ using Updater = -+ AttentionScalingCoefsUpdaterSimt; -+}; -+ -+// TensorOp - Volta -+template -+struct DefaultAttentionScalingCoefsUpdater< -+ cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< -+ S1, -+ accum_t, -+ cutlass::layout::RowMajor, -+ S2, -+ cutlass::MatrixShape<1, 1>>, -+ accum_t, -+ kWarpSize> { -+ using Iterator = -+ typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< -+ S1, -+ accum_t, -+ cutlass::layout::RowMajor, -+ S2, -+ cutlass::MatrixShape<1, 1>>; -+ using Updater = -+ AttentionScalingCoefsUpdaterVolta; -+}; -+ -+// TensorOp - Sm75+ -+template < -+ typename S1, -+ typename S2, -+ typename S3, -+ typename accum_t, -+ int kWarpSize> -+struct DefaultAttentionScalingCoefsUpdater< -+ cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< -+ S1, -+ accum_t, -+ cutlass::layout::RowMajor, -+ S2, -+ S3>, -+ accum_t, -+ kWarpSize> { -+ using Iterator = -+ typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< -+ S1, -+ accum_t, -+ cutlass::layout::RowMajor, -+ S2, -+ S3>; -+ using Updater = -+ AttentionScalingCoefsUpdaterSm80; -+}; -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/debug_utils.h 
b/3rdparty/cutlass/examples/41_fused_multi_head_attention/debug_utils.h -new file mode 100644 -index 0000000..73a258e ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/debug_utils.h -@@ -0,0 +1,159 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holdvr nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+#include -+#include -+#include -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Debugging functions -+//////////////////////////////////////////////////////////////////////////////// -+// Nans & inf detection -+#define NANCHECK(frag) \ -+ { \ -+ for (int _i = 0; _i < frag.size(); ++_i) { \ -+ assert(std::isfinite(float(frag[_i]))); \ -+ assert(!std::isnan(float(frag[_i]))); \ -+ } \ -+ } -+ -+// Print on the first thread of the first block -+#if 0 -+#define PRINT_WARP_ID 0 -+#define PRINT_LANE_ID 0 -+#define PRINT_T0_L0(msg, ...) 
\ -+ if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \ -+ threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \ -+ threadIdx.z == 0) { \ -+ printf(msg "\n", __VA_ARGS__); \ -+ } -+struct __string_view { -+ char const* data; -+ std::size_t size; -+}; -+template -+constexpr __string_view __get_type_name() { -+ char const* p = __PRETTY_FUNCTION__; -+ while (*p++ != '=') -+ ; -+ for (; *p == ' '; ++p) -+ ; -+ char const* p2 = p; -+ int count = 1; -+ for (;; ++p2) { -+ switch (*p2) { -+ case '[': -+ ++count; -+ break; -+ case ']': -+ --count; -+ if (!count) -+ return {p, std::size_t(p2 - p)}; -+ } -+ } -+ return {}; -+} -+#else -+#define PRINT_T0_L0 -+#endif -+ -+// Print a given array -+#define PRINT_ACCUM8_T0_L0_START(name, accum, start) \ -+ PRINT_T0_L0( \ -+ "%s[%d:%d] - {%f, %f, %f, %f, %f, %f, %f, %f}", \ -+ name, \ -+ int(start), \ -+ int(start + 8), \ -+ float(accum[start + 0]), \ -+ float(accum[start + 1]), \ -+ float(accum[start + 2]), \ -+ float(accum[start + 3]), \ -+ float(accum[start + 4]), \ -+ float(accum[start + 5]), \ -+ float(accum[start + 6]), \ -+ float(accum[start + 7])); -+#define PRINT_ACCUM8_T0_L0(name, accum) PRINT_ACCUM8_T0_L0_START(name, accum, 0) -+#define PRINT_FRAG_T0_L0(name, frag) \ -+ { \ -+ auto typeStr = __get_type_name(); \ -+ PRINT_T0_L0("printing %s (%s)", name, typeStr.data); \ -+ for (int _start = 0; _start < frag.size(); _start += 8) { \ -+ PRINT_ACCUM8_T0_L0_START(" ", frag, _start); \ -+ } \ -+ /*__syncthreads(); \ -+ NANCHECK(frag); */ \ -+ } -+#define PRINT_ARRAY_T0_L0_INCR(name, array, length, incr) \ -+ { \ -+ PRINT_T0_L0("printing %s (len=%d)", name, int(length)); \ -+ for (int _start = 0; _start < length; _start += incr) { \ -+ PRINT_ACCUM8_T0_L0_START(" ", array, _start); \ -+ } \ -+ } -+#define PRINT_ARRAY_T0_L0(name, array, length) \ -+ PRINT_ARRAY_T0_L0_INCR(name, array, length, 8) -+ -+// Print a 4x4 matrix -+#define PRINT_TENSOR4x4_T0_L0_START(name, ref, start_x, start_y) \ -+ PRINT_T0_L0( \ -+ "%s[%d:%d, %d:%d]:\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f", \ -+ name, \ -+ int(start_x), \ -+ int(start_x + 4), \ -+ int(start_y), \ -+ int(start_y + 4), \ -+ float(ref.at({start_x + 0, start_y + 0})), \ -+ float(ref.at({start_x + 0, start_y + 1})), \ -+ float(ref.at({start_x + 0, start_y + 2})), \ -+ float(ref.at({start_x + 0, start_y + 3})), \ -+ float(ref.at({start_x + 1, start_y + 0})), \ -+ float(ref.at({start_x + 1, start_y + 1})), \ -+ float(ref.at({start_x + 1, start_y + 2})), \ -+ float(ref.at({start_x + 1, start_y + 3})), \ -+ float(ref.at({start_x + 2, start_y + 0})), \ -+ float(ref.at({start_x + 2, start_y + 1})), \ -+ float(ref.at({start_x + 2, start_y + 2})), \ -+ float(ref.at({start_x + 2, start_y + 3})), \ -+ float(ref.at({start_x + 3, start_y + 0})), \ -+ float(ref.at({start_x + 3, start_y + 1})), \ -+ float(ref.at({start_x + 3, start_y + 2})), \ -+ float(ref.at({start_x + 3, start_y + 3}))); -+#define PRINT_TENSOR4x4_T0_L0(name, ref) \ -+ PRINT_TENSOR4x4_T0_L0_START(name, ref, 0, 0) -+ -+#define PRINT_PROBLEM_SIZE(name, ps) \ -+ PRINT_T0_L0( \ -+ "%s.problem_size: {.m=%d, .n=%d, .k=%d}", \ -+ name, \ -+ int(ps.m()), \ -+ int(ps.n()), \ -+ int(ps.k())) -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/default_fmha_grouped.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/default_fmha_grouped.h -new file mode 100644 -index 0000000..5a1ed5c ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/default_fmha_grouped.h -@@ -0,0 
+1,284 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial -+ specializations here choose 'device::GemmTransposed' to implement this functionality. 
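    At a high level (illustrative summary, not part of the original file), the kernel assembled
    here processes one block of queries against successive blocks of keys/values:

      S = Q_block @ K_block^T          // MM0: accumulators stay in registers
      update mi, m_prime, s_prime      // streaming softmax statistics
                                       // (see attention_scaling_coefs_updater.h)
      P = exp(S - mi)                  // written to AccumulatorSharedStorage
      O = O * m_prime + P @ V_block    // MM1: reads P back from shared memory

    with a final normalization by s_prime applied when the output is written.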
-+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+#include "fmha_grouped.h" -+#include "gemm_kernel_utils.h" -+#include "find_default_mma.h" -+#include "attention_scaling_coefs_updater.h" -+#include "mma_from_smem.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ // The datatype of Q/K/V -+ typename scalar_t_, -+ // Architecture we are targeting (eg `cutlass::arch::Sm80`) -+ typename ArchTag_, -+ // If Q/K/V are correctly aligned in memory and we can run a fast kernel -+ bool isAligned_, -+ int kQueriesPerBlock, -+ int kKeysPerBlock, -+ bool kSingleValueIteration, -+ GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly -+ > -+struct DefaultFMHAGrouped { -+ using scalar_t = scalar_t_; -+ using accum_t = float; -+ using output_t = scalar_t; -+ -+ // Accumulator between 2 iterations -+ // Using `accum_t` improves perf on f16 at the cost of -+ // numerical errors -+ using output_accum_t = accum_t; -+ -+ using ArchTag = ArchTag_; -+ static bool const kIsAligned = isAligned_; -+ static int const kWarpSize = 32; -+ static int const kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (kWarpSize * kWarpSize); -+ -+ struct MM0 { -+ /* -+ In this first matmul, we compute a block of `Q @ K.T`. -+ While the calculation result is still hot in registers, we update -+ `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value -+ into a shared-memory ("AccumulatorSharedStorage") that is used later as -+ operand A for the second matmul (see MM1) -+ */ -+ -+ using GemmType = gemm_kernel_utils::DefaultGemmType; -+ using OpClass = typename GemmType::OpClass; -+ -+ using ElementA = scalar_t; -+ using ElementB = scalar_t; -+ using ElementC = scalar_t; -+ using ElementAccumulator = accum_t; -+ -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using DefaultConfig = -+ typename cutlass::gemm::device::DefaultGemmConfiguration< -+ OpClass, -+ ArchTag, -+ ElementA, -+ ElementB, -+ ElementC, -+ ElementAccumulator -+ >; -+ -+ static int const kAlignmentA = -+ kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment; -+ static int const kAlignmentB = -+ kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; -+ using InstructionShape = typename GemmType::InstructionShape; -+ -+ static int const kStages = DefaultConfig::kStages; -+ using Operator = typename GemmType::Operator; -+ -+ using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ LayoutC, -+ OpClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ kStages, -+ Operator -+ >::DefaultMma; -+ -+ using MmaCore = typename DefaultMma::MmaCore; -+ using IteratorA = typename DefaultMma::IteratorA; -+ using IteratorB = typename DefaultMma::IteratorB; -+ using Mma = typename DefaultMma::ThreadblockMma; -+ using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater< -+ typename Mma::Operator::IteratorC, -+ ElementAccumulator, -+ kWarpSize>::Updater; -+ -+ static_assert(MmaCore::WarpCount::kCount == kNumWarpsPerBlock, ""); -+ -+ // Epilogue to store to shared-memory in a format that we can use later for -+ // the second matmul -+ using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< -+ typename Mma::Operator::IteratorC, -+ typename Mma::Operator, -+ scalar_t, -+ WarpShape, -+ ThreadblockShape>; -+ using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; -+ }; -+ -+ struct MM1 { -+ /* -+ Second matmul: perform `attn @ V` where `attn` is the attention (not -+ normalized) and stored in shared memory -+ */ -+ -+ using GemmType = typename MM0::GemmType; -+ using OpClass = typename GemmType::OpClass; -+ -+ using ElementA = scalar_t; -+ using ElementB = scalar_t; -+ using ElementC = output_accum_t; -+ using ElementAccumulator = accum_t; -+ -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using DefaultConfig = -+ typename cutlass::gemm::device::DefaultGemmConfiguration< -+ OpClass, -+ ArchTag, -+ ElementA, -+ ElementB, -+ ElementC, -+ ElementAccumulator -+ >; -+ -+ static int const kAlignmentA = DefaultConfig::kAlignmentA; -+ static int const kAlignmentB = -+ kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; -+ -+ using ThreadblockShape = typename MM0::ThreadblockShape; -+ using WarpShape = typename MM0::WarpShape; -+ using InstructionShape = typename MM0::InstructionShape; -+ -+ using EpilogueOutputOp = typename DefaultConfig::EpilogueOutputOp; -+ -+ static int const kStages = DefaultConfig::kStages; -+ using Operator = typename GemmType::Operator; -+ -+ using ThreadblockSwizzle = void; // Swizzling is unused -+ static bool const kSplitKSerial = false; -+ -+ using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OpClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kSplitKSerial, -+ Operator>; -+ -+ using DefaultMmaFromSmem = -+ typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< -+ typename DefaultGemm::Mma, -+ typename MM0::AccumulatorSharedStorage>; -+ -+ using Mma = typename DefaultMmaFromSmem::Mma; -+ using IteratorB = typename Mma::IteratorB; -+ using WarpCount = typename Mma::WarpCount; -+ static_assert(WarpCount::kCount == kNumWarpsPerBlock, ""); -+ -+ using DefaultEpilogue = typename DefaultGemm::Epilogue; -+ using OutputTileIterator = -+ typename cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename DefaultEpilogue::OutputTileIterator::ThreadMap, -+ output_t>; -+ using OutputTileIteratorAccum = -+ typename cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename DefaultEpilogue::OutputTileIterator::ThreadMap, -+ output_accum_t>; -+ -+ struct SharedStorageMM1 { -+ typename Mma::SharedStorage mm; -+ }; -+ }; -+ -+/// Define the kernel in terms of the default kernel -+ using FMHAKernel = kernel::FMHAGrouped< -+ MM0, -+ MM1, -+ scalar_t, -+ accum_t, -+ output_t, -+ output_accum_t, -+ kSingleValueIteration, -+ GroupScheduleMode_ -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/epilogue_pipelined.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/epilogue_pipelined.h -new file mode 100644 -index 0000000..2a574e7 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/epilogue_pipelined.h -@@ -0,0 +1,632 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights -+ *reserved. SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, -+ *this list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -+ *POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ File copied from "cutlass/epilogue/threadblock/epilogue.h" -+ then modified to: -+ (1) load 2 source fragments at the same time (pipelining) -+ (2) support reading from a different dtype -+ (3) pass the row id to the OutputOp if it takes it -+ (see MemoryEfficientAttentionNormalize) -+ Note that in general the fragment passed to the OutputOp could -+ span multiple rows but it does not happen with the configurations we have -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_coord.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_base.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+template -+struct ApplyEpilogueOp { -+ static CUTLASS_DEVICE typename Op::FragmentOutput apply( -+ Op const& output_op, -+ int row_id, -+ typename Op::FragmentAccumulator const& accum, -+ typename Op::FragmentOutput const& source) { -+ return output_op(accum, source); -+ } -+ static CUTLASS_DEVICE typename Op::FragmentOutput apply( -+ Op const& output_op, -+ int row_id, -+ typename Op::FragmentAccumulator const& accum) { -+ return output_op(accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator -+template < -+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) -+ typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: -+ ///< gemm::warp::MmaTensorOp) -+ int PartitionsK, ///< Number of partitions of the K dimension -+ typename OutputTileIterator_, ///< Tile iterator writing output tensors -+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting -+ ///< accumulators -+ typename WarpTileIterator_, ///< Warp-scoped tile 
iterator writing -+ ///< accumulators to SMEM -+ typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading -+ ///< from SMEM -+ typename OutputOp_, ///< Output operator -+ typename Padding_, ///< Padding added to SMEM allocation to avoid bank -+ ///< conflicts (concept: MatrixShape) -+ int FragmentsPerPartition = -+ 1, ///< Used to coarsten the epilogue granularity -+ int IterationsUnroll = ///< Used to reduce binary size when epilogue op is -+ ///< large -+ (!IsEpilogueFunctorHeavy::value), -+ typename OutputTileSourceIterator_ = -+ OutputTileIterator_ ///< Tile iterator reading tensors -+ > -+class EpiloguePipelined : public EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_, -+ FragmentsPerPartition> { -+ public: -+ using Base = EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_, -+ FragmentsPerPartition>; -+ -+ using Shape = Shape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputTileIterator = OutputTileIterator_; -+ using OutputTileSourceIterator = OutputTileSourceIterator_; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using WarpTileIterator = WarpTileIterator_; -+ using SharedLoadIterator = SharedLoadIterator_; -+ using OutputOp = OutputOp_; -+ using Padding = Padding_; -+ -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename Base::AccumulatorTile; -+ -+ /// Accumulator element -+ using ElementAccumulator = typename WarpTileIterator::Element; -+ -+ /// Output element -+ using ElementOutput = typename OutputTileIterator::Element; -+ using ElementSource = typename OutputTileSourceIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ /// Tensor reference to destination tensor -+ using TensorRef = typename OutputTileIterator::TensorRef; -+ -+ /// Tensor reference to sync tensor -+ using SyncTensorRef = -+ typename cutlass::TensorRef; -+ -+ /// Const tensor reference to source tensor -+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; -+ -+ /// Array type used to output -+ using OutputAccessType = Array< -+ typename OutputTileIterator::Element, -+ OutputTileIterator::kElementsPerAccess>; -+ using SourceAccessType = Array< -+ typename OutputTileSourceIterator::Element, -+ OutputTileSourceIterator::kElementsPerAccess>; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = Array< -+ typename WarpTileIterator::Element, -+ OutputTileIterator::kElementsPerAccess>; -+ -+ /// Number of warps -+ using WarpCount = typename Base::WarpCount; -+ -+ static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 -+ ? 
Base::kFragmentsPerIteration -+ : kPartitionsK; -+ static int constexpr kSmemPointerOffset = -+ Base::SharedStorage::StorageShape::kCount / kSmemTiles; -+ -+ public: -+ static_assert( -+ OutputTileSourceIterator::Fragment::kElements == -+ OutputTileIterator::Fragment::kElements, -+ "Mismatch between input tile and output tile iterator (kElements)"); -+ static_assert( -+ OutputTileSourceIterator::kIterations == OutputTileIterator::kIterations, -+ "Mismatch between input tile and output tile iterator (kIterations)"); -+ static_assert( -+ SharedLoadIterator::Fragment::kElements == -+ OutputTileIterator::Fragment::kElements, -+ "Mismatch between shared load iterator and output tile iterator."); -+ -+ static_assert( -+ OutputTileIterator::kElementsPerAccess, -+ "OutputTileIterator::kElementsPerAccess must not be zero."); -+ -+ static_assert( -+ !(OutputTileIterator::Fragment::kElements % -+ OutputTileIterator::kElementsPerAccess), -+ "Divisibility"); -+ -+ private: -+ /// Loads fragment from shared memory aligned with output tensor -+ SharedLoadIterator shared_load_iterator_; -+ -+ public: -+ /// Constructor -+ CUTLASS_DEVICE -+ EpiloguePipelined( -+ typename Base::SharedStorage& shared_storage, ///< Shared storage object -+ int thread_idx, ///< ID of a thread within the threadblock -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx ///< Id of thread within warp -+ ) -+ : Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ shared_load_iterator_(shared_storage.reference(), thread_idx) {} -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const& output_op, ///< Output operator -+ OutputTileIterator -+ destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const& -+ accumulators, ///< Complete warp-level accumulator tile -+ OutputTileSourceIterator -+ source_iterator) { ///< Threadblock tile coordinate in GEMM (in units -+ ///< of threadblock tiles) -+ -+ if (!output_op.is_source_needed()) { -+ compute_source_not_needed_(output_op, destination_iterator, accumulators); -+ } else { -+ compute_source_needed_( -+ output_op, destination_iterator, accumulators, source_iterator); -+ } -+ } -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const& output_op, ///< Output operator -+ OutputTileIterator -+ destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const& -+ accumulators) { ///< Complete warp-level accumulator tile -+ compute_source_not_needed_(output_op, destination_iterator, accumulators); -+ } -+ -+ private: -+ template -+ struct acc2smem_source_not_needed; -+ -+ template -+ struct acc2smem_source_not_needed> { -+ template -+ CUTLASS_DEVICE static void helper( -+ AccumulatorFragmentIterator accum_fragment_iterator, -+ WarpTileIterator& warp_tile_iterator) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Advance; i++) { -+ ++accum_fragment_iterator; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ -+ accum_fragment_iterator.load(accum_fragment); -+ ++accum_fragment_iterator; -+ -+ warp_tile_iterator.store(accum_fragment); -+ if (p < Base::kFragmentsPerIteration - 1) { -+ warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); -+ } -+ } -+ -+ if (Base::kFragmentsPerIteration > 1) { -+ warp_tile_iterator.add_pointer_offset( -+ kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); -+ } -+ } -+ -+ CUTLASS_DEVICE -+ static void push( -+ size_t pos, -+ 
AccumulatorFragmentIterator const& iterator_begin, -+ WarpTileIterator& warp_tile_iterator) { -+ int dummy[] = { -+ (pos == (Seq * Base::kFragmentsPerIteration)) && -+ (helper( -+ iterator_begin, warp_tile_iterator), -+ 0)...}; -+ -+ CUTLASS_UNUSED(dummy[0]); -+ } -+ }; -+ -+ static_assert( -+ kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, -+ "One of these must be exactly 1."); -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_not_needed_( -+ OutputOp const& output_op, ///< Output operator -+ OutputTileIterator -+ destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const& -+ accumulators ///< Complete warp-level accumulator tile -+ ) { -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+#pragma unroll( \ -+ IterationsUnroll \ -+ ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration \ -+ : 1) -+ for (int iter = 0; iter < OutputTileIterator::kIterations; -+ iter += Base::kFragmentsPerIteration) { -+ // -+ // Convert and store fragment -+ // -+ -+ __syncthreads(); -+ -+ acc2smem_source_not_needed>:: -+ push(iter, accum_fragment_iterator, this->warp_tile_iterator_); -+ -+ __syncthreads(); -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { -+ typename SharedLoadIterator::Fragment -+ aligned_accum_fragment[kPartitionsK]; -+ -+ shared_load_iterator_.load(aligned_accum_fragment[0]); -+ -+ if (p < Base::kFragmentsPerIteration - 1) { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); -+ } else if (kPartitionsK > 1) { -+ plus add_fragments; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < kPartitionsK; ++i) { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); -+ shared_load_iterator_.load(aligned_accum_fragment[i]); -+ aligned_accum_fragment[0] = add_fragments( -+ aligned_accum_fragment[0], aligned_accum_fragment[i]); -+ } -+ -+ shared_load_iterator_.add_pointer_offset( -+ (1 - kPartitionsK) * kSmemPointerOffset); -+ } -+ -+ // -+ // Compute the output result -+ // -+ -+ typename OutputTileIterator::Fragment output_fragment; -+ -+ apply_output_operator_source_not_needed_( -+ destination_iterator.thread_start_row(), -+ output_fragment, -+ output_op, -+ aligned_accum_fragment[0]); -+ -+ // -+ // Store the final result -+ // -+ -+ destination_iterator.store(output_fragment); -+ ++destination_iterator; -+ } -+ -+ if (Base::kFragmentsPerIteration > 1) { -+ shared_load_iterator_.add_pointer_offset( -+ kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); -+ } -+ } -+ } -+ -+ template -+ struct acc2smem_source_needed; -+ -+ template -+ struct acc2smem_source_needed> { -+ template -+ CUTLASS_DEVICE static void helper( -+ AccumulatorFragmentIterator accum_fragment_iterator, -+ WarpTileIterator& warp_tile_iterator) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Advance; i++) { -+ ++accum_fragment_iterator; -+ } -+ -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ accum_fragment_iterator.load(accum_fragment); -+ warp_tile_iterator.store(accum_fragment); -+ } -+ -+ CUTLASS_DEVICE -+ static void push( -+ size_t pos, -+ AccumulatorFragmentIterator const& iterator_begin, -+ WarpTileIterator& warp_tile_iterator) { -+ int dummy[] = { -+ (pos == Seq) && -+ (helper(iterator_begin, warp_tile_iterator), 0)...}; -+ } -+ }; -+ -+ /// Streams the result to 
global memory -+ CUTLASS_DEVICE -+ void compute_source_needed_( -+ OutputOp const& output_op, ///< Output operator -+ OutputTileIterator -+ destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const& -+ accumulators, ///< Complete warp-level accumulator tile -+ OutputTileSourceIterator -+ source_iterator ///< Threadblock tile coordinate in GEMM (in units of -+ ///< threadblock tiles) -+ ) { -+ typename OutputTileSourceIterator::Fragment source_fragment[2]; -+ -+ source_fragment[0].clear(); -+ source_iterator.load(source_fragment[0]); -+ ++source_iterator; -+ source_fragment[1].clear(); -+ -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) -+ for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { -+ if (iter > 0) { -+ __syncthreads(); -+ } -+ // -+ // Load the source for next iteration (pipelining) -+ // -+ -+ if (iter + 1 < OutputTileIterator::kIterations) { -+ source_iterator.load(source_fragment[(iter + 1) % 2]); -+ } -+ ++source_iterator; -+ acc2smem_source_needed< -+ cutlass::make_index_sequence>:: -+ push(iter, accum_fragment_iterator, this->warp_tile_iterator_); -+ -+ __syncthreads(); -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ typename SharedLoadIterator::Fragment -+ aligned_accum_fragment[kPartitionsK]; -+ -+ shared_load_iterator_.load(aligned_accum_fragment[0]); -+ -+ // If the number of k-slices is > 1 - perform a reduction amongst the -+ // k-slices -+ if (kPartitionsK > 1) { -+ plus add_fragments; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < kPartitionsK; ++i) { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); -+ shared_load_iterator_.load(aligned_accum_fragment[i]); -+ aligned_accum_fragment[0] = add_fragments( -+ aligned_accum_fragment[0], aligned_accum_fragment[i]); -+ } -+ -+ shared_load_iterator_.add_pointer_offset( -+ (1 - kPartitionsK) * kSmemPointerOffset); -+ } -+ -+ // -+ // Compute the output result -+ // -+ -+ typename OutputTileIterator::Fragment output_fragment; -+ -+ apply_output_operator_( -+ destination_iterator.thread_start_row(), -+ output_fragment, -+ output_op, -+ aligned_accum_fragment[0], -+ source_fragment[iter % 2]); -+ -+ // -+ // Store the final result -+ // -+ -+ destination_iterator.store(output_fragment); -+ ++destination_iterator; -+ } -+ } -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_( -+ int begin_row, -+ typename OutputTileIterator::Fragment& output_fragment, -+ OutputOp const& output_op, ///< Output operator -+ typename SharedLoadIterator::Fragment const& aligned_accum_fragment, -+ typename OutputTileSourceIterator::Fragment const& source_fragment) { -+ OutputAccessType* output_frag_ptr = -+ reinterpret_cast(&output_fragment); -+ -+ AccumulatorAccessType const* compute_frag_ptr = -+ reinterpret_cast(&aligned_accum_fragment); -+ -+ SourceAccessType const* source_frag_ptr = -+ reinterpret_cast(&source_fragment); -+ -+ int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / -+ OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ // Call the output operator -+ output_frag_ptr[i] = ApplyEpilogueOp::apply( -+ output_op, -+ begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), -+ 
compute_frag_ptr[i], -+ source_frag_ptr[i]); -+ } -+ } -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_source_not_needed_( -+ int begin_row, -+ typename OutputTileIterator::Fragment& output_fragment, -+ OutputOp const& output_op, ///< Output operator -+ typename SharedLoadIterator::Fragment const& aligned_accum_fragment) { -+ OutputAccessType* output_frag_ptr = -+ reinterpret_cast(&output_fragment); -+ -+ AccumulatorAccessType const* compute_frag_ptr = -+ reinterpret_cast(&aligned_accum_fragment); -+ -+ int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / -+ OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ // Call the output operator -+ output_frag_ptr[i] = ApplyEpilogueOp::apply( -+ output_op, -+ begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), -+ compute_frag_ptr[i]); -+ } -+ } -+ -+ // This should be constexpr, but it's only supported on c++14 -+ static int CUTLASS_HOST_DEVICE getRowOffset(int i) { -+ using ThreadMap = typename OutputTileIterator::ThreadMap; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; -+ ++cluster) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ int row_offset = row * ThreadMap::Delta::kRow + -+ group * ThreadMap::Delta::kGroup + -+ cluster * ThreadMap::Delta::kCluster; -+ int frag_row_idx = -+ (row + -+ ThreadMap::Iterations::kRow * -+ (group + ThreadMap::Iterations::kGroup * cluster)); -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; -+ ++column) { -+ int frag_idx = ThreadMap::kElementsPerAccess * -+ (frag_row_idx * ThreadMap::Iterations::kColumn + column); -+ if (i < frag_idx + ThreadMap::kElementsPerAccess) { -+ return row_offset; -+ } -+ } -+ } -+ } -+ } -+ return -1; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/epilogue_rescale_output.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/epilogue_rescale_output.h -new file mode 100644 -index 0000000..a5d8f8d ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/epilogue_rescale_output.h -@@ -0,0 +1,262 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holdvr nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory -+ to match canonical tensor layouts in global memory. Epilogues support -+ conversion and reduction operations. -+ -+ This is a copy of cutlass/epilogue/threadblock/epilogue.h that can -+ handle "row_id" as a first argument, as uses it to get the corresponding -+ `m_prime` / `s_prime` to rescale the output. -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_coord.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_base.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/thread/scale_type.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+#include "epilogue_pipelined.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements. -+// output <- alpha * accumulator + beta * source -+// with: -+// alpha = 1 / s_prime (to normalize when isLast=True, 1 otherwise) -+// beta = alpha / m_prime (renormalize the output when the max changes) -+// source is the current output -+template < -+ typename ElementOutput_, ///< Data type used to store tensors -+ typename ElementSource_, //< Data type for source (usually matches -+ //`ElementOutput`) -+ int Count, ///< Number of elements computed per operation. 
-+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data -+ ///< to store -+ typename ElementAccumulator_, ///< Accumulator data type -+ typename ElementCompute_, ///< Data type used to compute linear combination -+ bool isFirst, -+ bool isLast, -+ typename FragmentAlphaBeta_, -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> -+class MemoryEfficientAttentionNormalize { -+ public: -+ using ElementOutput = ElementOutput_; -+ using ElementSource = ElementSource_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ -+ static int const kCount = Count; -+ -+ using FragmentOutput = Array; -+ using FragmentSource = Array; -+ using FragmentAccumulator = Array; -+ using ComputeFragment = Array; -+ using FragmentAlphaBeta = FragmentAlphaBeta_; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ FragmentAlphaBeta const& s_prime_; -+ FragmentAlphaBeta const& m_prime_; -+ -+ public: -+ /// Constructs the function object, possibly loading from pointers in host -+ /// memory -+ CUTLASS_HOST_DEVICE -+ MemoryEfficientAttentionNormalize( -+ FragmentAlphaBeta const& s_prime, -+ FragmentAlphaBeta const& m_prime) -+ : s_prime_(s_prime), m_prime_(m_prime) {} -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ return !isFirst; -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) {} -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ int row, -+ FragmentAccumulator const& accumulator, -+ FragmentSource const& source) const { -+ assert(!isFirst); -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter -+ source_converter; -+ NumericArrayConverter -+ accumulator_converter; -+ -+ // Convert to destination numeric type -+ NumericArrayConverter -+ destination_converter; -+ -+ ComputeFragment converted_source = source_converter(source); -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ ComputeFragment intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ -+ ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1; -+ ElementCompute beta = alpha * m_prime_[row]; -+ -+ intermediate = mul_add_source(beta, converted_source); // X = beta * C -+ -+ intermediate = mul_add_accumulator( -+ alpha, converted_accumulator, intermediate); // D = alpha * Accum + X -+ -+ return destination_converter(intermediate); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()(int row, FragmentAccumulator const& accumulator) -+ const { -+ assert(isFirst); -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter -+ accumulator_converter; -+ -+ // Convert to destination numeric type -+ NumericArrayConverter -+ destination_converter; -+ -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ ComputeFragment intermediate; -+ multiplies mul_accumulator; -+ -+ ElementCompute alpha = isLast ? 
(1 / s_prime_[row]) : 1; -+ -+ intermediate = mul_accumulator( -+ alpha, converted_accumulator); // X = alpha * C + uniform -+ -+ return destination_converter(intermediate); -+ } -+}; -+ -+} // namespace thread -+ -+namespace threadblock { -+template < -+ typename EO, -+ typename ES, -+ int Count, -+ typename EA, -+ typename EC, -+ bool F, -+ bool L, -+ typename FAB, -+ FloatRoundStyle R> -+struct ApplyEpilogueOp> { -+ using Op = thread:: -+ MemoryEfficientAttentionNormalize; -+ static CUTLASS_DEVICE typename Op::FragmentOutput apply( -+ Op const& output_op, -+ int row_id, -+ typename Op::FragmentAccumulator const& accum, -+ typename Op::FragmentSource const& source) { -+ return output_op(row_id, accum, source); -+ } -+ static CUTLASS_DEVICE typename Op::FragmentOutput apply( -+ Op const& output_op, -+ int row_id, -+ typename Op::FragmentAccumulator const& accum) { -+ return output_op(row_id, accum); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/epilogue_thread_apply_logsumexp.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/epilogue_thread_apply_logsumexp.h -new file mode 100644 -index 0000000..2e286d3 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/epilogue_thread_apply_logsumexp.h -@@ -0,0 +1,175 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights -+ *reserved. SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, -+ *this list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -+ *POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Functor performing linear combination operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/thread/activation.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template -+struct ArrayExponential { -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ Array const& input) const { -+ Array result; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < ElementsPerAccess; ++i) { -+ result[i] = expf(input[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+template -+struct ArrayExponential { -+ CUTLASS_DEVICE -+ Array operator()( -+ Array const& input) const { -+ Array result; -+ -+ int const kVectorCount = ElementsPerAccess / 2; -+ -+ __half2 const* input_ptr = -+ reinterpret_cast<__half2 const*>(input.raw_data()); -+ __half2* res_ptr = reinterpret_cast<__half2*>(result.raw_data()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kVectorCount; ++i) { -+ res_ptr[i] = h2exp(input_ptr[i]); -+ } -+ -+ return result; -+ } -+}; -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies: -+/// output <- (input - lse).exp() -+template < -+ typename ElementOutput_, // output -+ typename ElementLSE_, // accumulator from LSE -+ typename ElementAccumulator_, // accumulator from matmul -+ typename ElementCompute_, // intermediate compute (and exp calculation) -+ int ElementsPerAccess> -+class ApplyLogSumExp { -+ public: -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ using ElementLSE = ElementLSE_; -+ -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kCount = kElementsPerAccess; -+ static const ScaleType::Kind kScale = -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling; -+ -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using FragmentCompute = Array; -+ using FragmentLSE = Array; -+ using FragmentScaleBias = FragmentLSE; // Used by epilogue_smem_accumulator.h -+ -+ public: -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ ApplyLogSumExp() {} -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ return true; -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) {} -+ -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const& AB, -+ FragmentLSE const& scale_unused, -+ // bias used as LSE -+ FragmentLSE const& bias) const { -+ FragmentCompute frag_AB = NumericArrayConverter< -+ ElementCompute, -+ ElementAccumulator, -+ kElementsPerAccess>()(AB); -+ FragmentCompute frag_lse_compute = -+ NumericArrayConverter()( -+ bias); -+ FragmentCompute frag_compute; -+ -+ minus minus_lse; -+ detail::ArrayExponential apply_exp; -+ frag_compute = minus_lse(frag_AB, frag_lse_compute); -+ frag_compute = apply_exp(frag_compute); -+ -+ return NumericArrayConverter< -+ ElementOutput, -+ ElementCompute, -+ kElementsPerAccess>()(frag_compute); -+ } 
-+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/find_default_mma.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/find_default_mma.h -new file mode 100644 -index 0000000..9c62c8c ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/find_default_mma.h -@@ -0,0 +1,189 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holdvr nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Cutlass provides helper template functions to figure out the right -+ datastructures to instanciate to run a GEMM with various parameters (see -+ `cutlass/gemm/threadblock/default_mma.h`). However, due to template -+ instantiation priority rules, it will only create an MmaMultiStage with -+ kStages=3 (otherwise creates an MmePipelined - which is not compatible with -+ FastF32). kStages=3 uses too much shared memory and we want to use kStages=2, -+ so we just copy-pasted some code from `default_mma.h` and -+ `default_mma_core.h` files and wrapped this template to allow our usecase. -+ -+ This is really only for the FastF32 case - aka using TensorCores with fp32. 
-+*/ -+ -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Layout type for C and D matrix operand -+ typename LayoutC, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator, -+ typename Enable_ = void> -+struct FindDefaultMma { -+ static constexpr bool AccumulatorsInRowMajor = false; -+ static constexpr SharedMemoryClearOption SharedMemoryClear = -+ SharedMemoryClearOption::kNone; -+ using DefaultMma = cutlass::gemm::threadblock::DefaultMma< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ LayoutC, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ Stages, -+ Operator, -+ AccumulatorsInRowMajor, -+ SharedMemoryClear>; -+}; -+ -+/// Specialization for sm80 / FastF32 / multistage with kStages=2 -+template < -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ int kStages, -+ typename Operator> -+struct FindDefaultMma< -+ ElementA_, -+ LayoutA_, -+ kAlignmentA, -+ ElementB_, -+ LayoutB_, -+ kAlignmentB, -+ ElementAccumulator, -+ layout::RowMajor, -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ kStages, -+ Operator, -+ typename cutlass::platform::enable_if<(kAlignmentA > 1)>::type> { -+ using LayoutC = layout::RowMajor; -+ using OperatorClass = arch::OpClassTensorOp; -+ using ArchTag = arch::Sm80; -+ -+ using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma< -+ ElementA_, -+ LayoutA_, -+ kAlignmentA, -+ ElementB_, -+ LayoutB_, -+ kAlignmentB, -+ ElementAccumulator, -+ LayoutC, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ 3, -+ Operator>; -+ struct DefaultMma : DefaultMma_ { -+ using MmaCore_ = typename 
DefaultMma_::MmaCore; -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< -+ typename MmaCore_::Shape, -+ typename DefaultMma_::IteratorA, -+ typename MmaCore_::SmemIteratorA, -+ MmaCore_::kCacheOpA, -+ typename DefaultMma_::IteratorB, -+ typename MmaCore_::SmemIteratorB, -+ MmaCore_::kCacheOpB, -+ ElementAccumulator, -+ LayoutC, -+ typename MmaCore_::MmaPolicy, -+ kStages>; -+ }; -+}; -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/fmha_grouped.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/fmha_grouped.h -new file mode 100644 -index 0000000..7201599 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/fmha_grouped.h -@@ -0,0 +1,839 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ \brief Grouped FMHA kernel -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/trace.h" -+#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -+ -+#include "fmha_grouped_problem_visitor.h" -+#include "gemm_kernel_utils.h" -+#include "epilogue_rescale_output.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename MM0_, ///! Structure for computing P = Q @ K -+ typename MM1_, ///! Structure for computing O = P @ V -+ typename scalar_t_, -+ typename accum_t_, -+ typename output_t_, -+ typename output_accum_t_, -+ bool kKeepOutputInRF, ///! Whether the intermediate output from MM0_ should be kept in the register file -+ GroupScheduleMode GroupScheduleMode_ ///! Type of scheduling to perform -+> -+struct FMHAGrouped { -+public: -+ using MM0 = MM0_; -+ using MM1 = MM1_; -+ -+ using scalar_t = scalar_t_; -+ using accum_t = accum_t_; -+ using output_t = output_t_; -+ using output_accum_t = output_accum_t_; -+ -+ static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; -+ -+ static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF && -+ !cutlass::platform::is_same::value; -+ -+ // Parameters to satisfy BaseGrouped -+ using ElementA = scalar_t; -+ using ElementB = scalar_t; -+ using ElementC = accum_t; -+ using LayoutA = typename MM0::LayoutA; -+ using LayoutB = typename MM0::ElementB; -+ using LayoutC = typename MM1::ElementC; -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ static int const kAlignmentA = MM0::kAlignmentA; -+ static int const kAlignmentB = MM0::kAlignmentB; -+ static int const kAlignmentC = 1; -+ using Mma = typename MM1::Mma; -+ using EpilogueOutputOp = typename MM1::EpilogueOutputOp; -+ using ThreadblockSwizzle = void; -+ using Operator = typename MM1::Operator; -+ using WarpShape = typename MM1::WarpShape; -+ using InstructionShape = typename MM1::InstructionShape; -+ -+ using ElementQ = scalar_t; -+ using ElementK = scalar_t; -+ using ElementP = accum_t; -+ using ElementV = scalar_t; -+ using ElementO = output_t; -+ using ElementOAccum = output_accum_t; -+ using ElementAccumulator = accum_t; -+ -+ using LayoutQ = typename MM0::LayoutA; -+ using LayoutK = typename MM0::LayoutB; -+ using LayoutP = typename MM0::LayoutC; -+ using LayoutV = typename MM1::LayoutB; -+ using LayoutO = typename MM1::LayoutC; -+ -+ static bool const kPreloadV = (MM1::Mma::ArchTag::kMinComputeCapability >= 80 && -+ cutlass::sizeof_bits::value == 16); -+ -+ static int const kAlignmentQ = MM0::kAlignmentA; -+ static int const kAlignmentK = MM0::kAlignmentB; -+ static int const kAlignmentV = 1; -+ -+ using ThreadblockShape = typename MM0::ThreadblockShape; -+ -+ static int const kQueriesPerBlock = ThreadblockShape::kM; -+ static int const kKeysPerBlock = ThreadblockShape::kN; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename MM1::WarpCount; -+ static int const kThreadsPerWarp = 32; -+ static int const kThreadCount = kThreadsPerWarp * WarpCount::kCount; -+ -+ using ProblemVisitor 
= FMHAGroupedProblemVisitor< -+ ThreadblockShape, -+ kGroupScheduleMode, -+ kThreadCount, -+ kThreadCount>; -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord *problem_sizes0; -+ GemmCoord *problem_sizes1; -+ -+ int problem_count; -+ int threadblock_count; -+ -+ ElementQ ** ptr_Q; -+ ElementK ** ptr_K; -+ ElementP ** ptr_P; -+ ElementV ** ptr_V; -+ ElementO ** ptr_O; -+ ElementOAccum ** ptr_O_accum; -+ -+ typename LayoutQ::Stride::LongIndex *ldq; -+ typename LayoutK::Stride::LongIndex *ldk; -+ typename LayoutP::Stride::LongIndex *ldv; -+ typename LayoutO::Stride::LongIndex *ldo; -+ -+ // Whether causal masking is to be performed -+ bool causal; -+ -+ // Only used by device-level operator -+ GemmCoord *host_problem_sizes; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments(): -+ problem_count(0), -+ threadblock_count(0), -+ ptr_Q(nullptr), -+ ptr_K(nullptr), -+ ptr_P(nullptr), -+ ptr_V(nullptr), -+ ptr_O(nullptr), -+ ptr_O_accum(nullptr), -+ ldq(nullptr), -+ ldk(nullptr), -+ ldv(nullptr), -+ ldo(nullptr), -+ causal(false), -+ host_problem_sizes(nullptr) -+ { -+ -+ } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord *problem_sizes0, -+ GemmCoord *problem_sizes1, -+ int problem_count, -+ int threadblock_count, -+ ElementQ ** ptr_Q, -+ ElementK ** ptr_K, -+ ElementP ** ptr_P, -+ ElementV ** ptr_V, -+ ElementO ** ptr_O, -+ ElementOAccum ** ptr_O_accum, -+ typename LayoutQ::Stride::LongIndex *ldq, -+ typename LayoutK::Stride::LongIndex *ldk, -+ typename LayoutP::Stride::LongIndex *ldp, -+ typename LayoutV::Stride::LongIndex *ldv, -+ typename LayoutO::Stride::LongIndex *ldo, -+ bool causal, -+ GemmCoord *host_problem_sizes=nullptr -+ ): -+ problem_sizes0(problem_sizes0), -+ problem_sizes1(problem_sizes1), -+ problem_count(problem_count), -+ threadblock_count(threadblock_count), -+ ptr_Q(ptr_Q), -+ ptr_K(ptr_K), -+ ptr_P(ptr_P), -+ ptr_V(ptr_V), -+ ptr_O(ptr_O), -+ ptr_O_accum(kNeedsOutputAccumulatorBuffer ? 
ptr_O_accum : (accum_t**)ptr_O), -+ ldq(ldq), -+ ldk(ldk), -+ ldv(ldv), -+ ldo(ldo), -+ causal(causal), -+ host_problem_sizes(host_problem_sizes) -+ { -+ -+ } -+ -+ bool __host__ check_supported() { -+ CHECK_ALIGNED_PTR(ptr_Q, kAlignmentQ); -+ CHECK_ALIGNED_PTR(ptr_K, kAlignmentK); -+ CHECK_ALIGNED_PTR(ptr_V, kAlignmentV); -+ XFORMERS_CHECK(ldq % kAlignmentQ == 0, "query is not correctly aligned"); -+ XFORMERS_CHECK(ldk % kAlignmentK == 0, "key is not correctly aligned"); -+ XFORMERS_CHECK(ldv % kAlignmentV == 0, "value is not correctly aligned"); -+ return true; -+ } -+ }; -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params { -+ -+ typename ProblemVisitor::Params problem_visitor; -+ int threadblock_count; -+ -+ ElementQ ** ptr_Q; -+ ElementK ** ptr_K; -+ ElementP ** ptr_P; -+ ElementV ** ptr_V; -+ ElementO ** ptr_O; -+ ElementOAccum ** ptr_O_accum; -+ -+ typename LayoutQ::Stride::LongIndex *ldq; -+ typename LayoutK::Stride::LongIndex *ldk; -+ typename LayoutP::Stride::LongIndex *ldv; -+ typename LayoutO::Stride::LongIndex *ldo; -+ -+ bool causal; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ ptr_Q(nullptr), -+ ptr_K(nullptr), -+ ptr_P(nullptr), -+ ptr_V(nullptr), -+ ptr_O(nullptr), -+ ptr_O_accum(nullptr), -+ ldq(nullptr), -+ ldk(nullptr), -+ ldv(nullptr), -+ ldo(nullptr), -+ causal(false) -+ { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args, -+ void *workspace = nullptr, -+ int tile_count = 0): -+ problem_visitor(args.problem_sizes0, args.problem_sizes1, args.problem_count, workspace, tile_count), -+ threadblock_count(args.threadblock_count), -+ ptr_Q(args.ptr_Q), -+ ptr_K(args.ptr_K), -+ ptr_P(args.ptr_P), -+ ptr_V(args.ptr_V), -+ ptr_O(args.ptr_O), -+ ptr_O_accum(kNeedsOutputAccumulatorBuffer ? args.ptr_O_accum : (accum_t**)args.ptr_O), -+ ldq(args.ldq), -+ ldk(args.ldk), -+ ldv(args.ldv), -+ ldo(args.ldo), -+ causal(args.causal) -+ { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void update( -+ Arguments const &args, -+ void *workspace = nullptr, -+ int tile_count = 0) { -+ -+ problem_visitor = typename ProblemVisitor::Params(args.problem_sizes0, -+ args.problem_sizes1, -+ args.problem_count, -+ workspace, tile_count); -+ threadblock_count = args.threadblock_count; -+ ptr_Q = args.ptr_Q; -+ ptr_K = args.ptr_K; -+ ptr_P = args.ptr_P; -+ ptr_V = args.ptr_V; -+ ptr_O = args.ptr_O; -+ ptr_O_accum = kNeedsOutputAccumulatorBuffer ? 
args.ptr_O_accum : (accum_t**)args.ptr_O; -+ ldq = args.ldq; -+ ldk = args.ldk; -+ ldv = args.ldv; -+ ldo = args.ldo; -+ causal = args.causal; -+ } -+ }; -+ -+ // Shared storage - depends on kernel params -+ struct ScalingCoefs { -+ cutlass::Array m_prime; -+ cutlass::Array s_prime; -+ cutlass::Array mi; -+ }; -+ -+ struct SharedStorageEpilogueAtEnd : ScalingCoefs { -+ struct SharedStorageAfterMM0 { -+ // Everything here might be overwritten during MM0 -+ typename MM0::AccumulatorSharedStorage si; -+ typename MM1::SharedStorageMM1 mm1; -+ }; -+ -+ union { -+ typename MM0::Mma::SharedStorage mm0; -+ SharedStorageAfterMM0 after_mm0; -+ typename MM1::DefaultEpilogue::SharedStorage epilogue; -+ }; -+ -+ CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& -+ epilogue_shared_storage() { -+ return epilogue; -+ } -+ -+ // ProblemVisitor shared storage can't be overlapped with others -+ typename ProblemVisitor::SharedStorage problem_visitor; -+ }; -+ -+ struct SharedStorageEpilogueInLoop : ScalingCoefs { -+ struct SharedStorageAfterMM0 { -+ // Everything here might be overwritten during MM0 -+ typename MM0::AccumulatorSharedStorage si; -+ typename MM1::SharedStorageMM1 mm1; -+ typename MM1::DefaultEpilogue::SharedStorage epilogue; -+ }; -+ -+ union { -+ typename MM0::Mma::SharedStorage mm0; -+ SharedStorageAfterMM0 after_mm0; -+ }; -+ -+ CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& -+ epilogue_shared_storage() { -+ return after_mm0.epilogue; -+ } -+ -+ // ProblemVisitor shared storage can't be overlapped with others -+ typename ProblemVisitor::SharedStorage problem_visitor; -+ }; -+ -+ using SharedStorage = typename cutlass::platform::conditional< -+ kKeepOutputInRF, -+ SharedStorageEpilogueAtEnd, -+ SharedStorageEpilogueInLoop>::type; -+ -+private: -+ -+ // Parameters to be used by an individual tile -+ struct TileParams { -+ -+ CUTLASS_HOST_DEVICE -+ static int query_start(int threadblock_idx) { -+ return threadblock_idx * kQueriesPerBlock; -+ } -+ -+ // Returns whether this threadblock computes within the number of queries, -+ // which is determined by the M dimension of problem 0 -+ CUTLASS_HOST_DEVICE -+ static bool can_compute(int threadblock_idx, const GemmCoord& problem_size0) { -+ return query_start(threadblock_idx) < problem_size0.m(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static int num_queries(int threadblock_idx, const GemmCoord& problem_size0) { -+ return problem_size0.m() - query_start(threadblock_idx); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static int num_keys(int threadblock_idx, const GemmCoord& problem_size0, bool causal) { -+ int nk = problem_size0.n(); -+ if (causal) { -+ nk = cutlass::fast_min(int32_t(query_start(threadblock_idx) + kQueriesPerBlock), nk); -+ } -+ return nk; -+ } -+ -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ FMHAGrouped() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) { -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return Status::kSuccess; -+ } -+ -+ static CUTLASS_DEVICE int16_t thread_id() { -+ return threadIdx.x; -+ } -+ -+ static CUTLASS_DEVICE int8_t warp_id() { -+ return threadIdx.x / kThreadsPerWarp; -+ } -+ -+ static CUTLASS_DEVICE int8_t lane_id() { -+ return threadIdx.x % kThreadsPerWarp; -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ auto& m_prime = shared_storage.m_prime; -+ auto& s_prime = 
shared_storage.s_prime; -+ [[maybe_unused]] auto& si = shared_storage.after_mm0.si; -+ auto& mi = shared_storage.mi; -+ -+ ProblemVisitor problem_visitor( -+ params.problem_visitor, -+ shared_storage.problem_visitor, -+ blockIdx.x); -+ -+ // Outer 'persistent' loop to iterate over tiles -+ while (problem_visitor.next_tile()) { -+ -+ GemmCoord problem_size0 = problem_visitor.problem_size0(); -+ GemmCoord problem_size1 = problem_visitor.problem_size1(); -+ const int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); -+ -+ if (!TileParams::can_compute(threadblock_idx, problem_size0)) { -+ problem_visitor.advance(gridDim.x); -+ continue; -+ } -+ -+ const int32_t problem_idx = problem_visitor.problem_index(); -+ -+ if (thread_id() < kQueriesPerBlock) { -+ s_prime[thread_id()] = ElementAccumulator(0); -+ m_prime[thread_id()] = -+ -cutlass::platform::numeric_limits::infinity(); -+ mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); -+ } -+ -+ ElementO *ptr_O = params.ptr_O[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldo[problem_idx]; -+ ElementOAccum *ptr_O_accum = params.ptr_O_accum[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldo[problem_idx]; -+ const int num_queries = TileParams::num_queries(threadblock_idx, problem_size0); -+ -+ auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { -+ using OutputTileIterator = typename MM1::OutputTileIterator; -+ return OutputTileIterator( -+ typename OutputTileIterator::Params{(int32_t)params.ldo[problem_idx]}, -+ ptr_O, -+ typename OutputTileIterator::TensorCoord{ -+ num_queries, problem_size1.n()}, -+ thread_id(), -+ {0, col}); -+ }; -+ -+ auto createOutputAccumIter = [&](int col) -> -+ typename MM1::OutputTileIteratorAccum { -+ using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; -+ return OutputTileIteratorAccum( -+ typename OutputTileIteratorAccum::Params{(int32_t)params.ldo[problem_idx]}, -+ ptr_O_accum, -+ typename OutputTileIteratorAccum::TensorCoord{ -+ num_queries, problem_size1.n()}, -+ thread_id(), -+ {0, col}); -+ }; -+ -+ typename MM1::Mma::FragmentC accum_o; -+ accum_o.clear(); -+ -+ const int num_keys = TileParams::num_keys(threadblock_idx, problem_size0, params.causal); -+ -+ for (int32_t iter_key_start = 0; iter_key_start < num_keys; -+ iter_key_start += kKeysPerBlock) { -+ int32_t problem_size_0_m = -+ cutlass::fast_min((int32_t)kQueriesPerBlock, num_queries); -+ int32_t problem_size_0_n = cutlass::fast_min( -+ (int32_t)kKeysPerBlock, num_keys - iter_key_start); -+ int32_t const& problem_size_0_k = problem_size0.k(); -+ int32_t const& problem_size_1_n = problem_size1.n(); -+ int32_t const& problem_size_1_k = problem_size_0_n; -+ -+ auto prologueV = [&](int blockN) { -+ typename MM1::Mma::IteratorB iterator_V( -+ typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])}, -+ params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx], -+ {problem_size_1_k, problem_size_1_n}, -+ thread_id(), -+ cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); -+ -+ MM1::Mma::prologue( -+ shared_storage.after_mm0.mm1.mm, -+ iterator_V, -+ thread_id(), -+ problem_size_1_k); -+ }; -+ -+ __syncthreads(); // Need to have shared memory initialized, and `m_prime` -+ // updated from end of prev iter -+ -+ // -+ // MATMUL: Q.K_t -+ // -+ // Computes the block-matrix product of: -+ // (a) query[query_start:query_end, :] -+ // with -+ // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] -+ // and stores that into 
`shared_storage.si` -+ // -+ -+ ElementQ *ptr_Q = params.ptr_Q[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldq[problem_idx]; -+ -+ // Construct iterators to A and B operands -+ typename MM0::IteratorA iterator_A( -+ typename MM0::IteratorA::Params( -+ typename MM0::MmaCore::LayoutA(params.ldq[problem_idx])), -+ ptr_Q, -+ {problem_size_0_m, problem_size_0_k}, -+ thread_id(), -+ {0, 0}); -+ -+ typename MM0::IteratorB iterator_B( -+ typename MM0::IteratorB::Params( -+ typename MM0::MmaCore::LayoutB(params.ldk[problem_idx])), -+ params.ptr_K[problem_idx] + iter_key_start * params.ldk[problem_idx], -+ {problem_size_0_k, problem_size_0_n}, -+ thread_id(), -+ {0, 0}); -+ -+ // Construct thread-scoped matrix multiply -+ typename MM0::Mma mma( -+ shared_storage.mm0, thread_id(), warp_id(), lane_id()); -+ -+ typename MM0::Mma::FragmentC accum; -+ -+ accum.clear(); -+ -+ auto gemm_k_iterations = -+ (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); -+ __syncthreads(); -+ -+ if (kPreloadV) { -+ prologueV(0); -+ } -+ -+ typename MM0::Mma::Operator::IteratorC::TensorCoord -+ iteratorC_tile_offset = { -+ (warp_id() % MM0::Mma::WarpCount::kM), -+ (warp_id() / MM0::Mma::WarpCount::kM) -+ }; -+ -+ // Mask out last if causal -+ if (params.causal && num_keys - iter_key_start <= kKeysPerBlock) { -+ auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset( -+ lane_id(), warp_id(), iteratorC_tile_offset); -+ int32_t last_col; -+ MM0::ScalingCoefsUpdater::iterateRows( -+ lane_offset, -+ [&](int accum_m) { -+ last_col = TileParams::query_start(threadblock_idx) + accum_m - iter_key_start; -+ }, -+ [&](int accum_m, int accum_n, int idx) { -+ if (accum_n > last_col) { -+ accum[idx] = -+ -cutlass::platform::numeric_limits::infinity(); -+ } -+ }, -+ [&](int accum_m) {}); -+ } -+ DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] { -+ DISPATCH_BOOL( -+ num_keys - iter_key_start >= kKeysPerBlock, -+ kFullColumns, -+ ([&] { -+ // Update `mi` from accum stored in registers -+ // Also updates `accum` with accum[i] <- -+ // exp(accum[i] * scale -+ // - mi) -+ MM0::ScalingCoefsUpdater::update< -+ kQueriesPerBlock, -+ kFullColumns, -+ kIsFirst, -+ kKeepOutputInRF>( -+ accum_o, -+ accum, -+ mi, -+ m_prime, -+ s_prime, -+ lane_id(), -+ thread_id(), -+ warp_id(), -+ num_keys - iter_key_start, -+ iteratorC_tile_offset, -+ 1.0f / cutlass::fast_sqrt(float(problem_size0.k()))); -+ })); -+ })); -+ -+ // Output results to shared-memory -+ int warp_idx_mn_0 = warp_id() % -+ (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN); -+ auto output_tile_coords = cutlass::MatrixCoord{ -+ warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, -+ warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; -+ -+ MM0::B2bGemm::accumToSmem( -+ shared_storage.after_mm0.si, accum, lane_id(), output_tile_coords); -+ -+ __syncthreads(); -+ -+ // -+ // MATMUL: Attn . V -+ // Run the matmul `attn @ V` for a block of attn and V. -+ // `attn` is read from shared memory (in `shared_storage_si`) -+ // `V` is read from global memory (with iterator_B) -+ // -+ -+ const int64_t nBlockN = kKeepOutputInRF ? 
1 -+ : ceil_div( -+ (int64_t)problem_size_1_n, -+ int64_t(MM1::ThreadblockShape::kN)); -+ -+ // Iterate over the N dimension of GEMM1 -+ for (int blockN = 0; blockN < nBlockN; ++blockN) { -+ int gemm_k_iterations = -+ (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add and store it in accum -+ // (in registers) -+ if (!kPreloadV) { -+ __syncthreads(); // we share shmem between mma and epilogue -+ } -+ -+ typename MM1::Mma::IteratorB iterator_V( -+ typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])}, -+ params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx], -+ {problem_size_1_k, problem_size_1_n}, -+ thread_id(), -+ cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); -+ -+ typename MM1::Mma mma_pv( -+ shared_storage.after_mm0.mm1.mm, -+ shared_storage.after_mm0.si, -+ (int)thread_id(), -+ (int)warp_id(), -+ (int)lane_id(), -+ (int)problem_size_1_k); -+ -+ mma_pv.set_prologue_done(kPreloadV); -+ if (!kKeepOutputInRF) { -+ accum_o.clear(); -+ } -+ -+ mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o); -+ __syncthreads(); -+ -+ if (kPreloadV && !kKeepOutputInRF && blockN + 1 < nBlockN) { -+ prologueV(blockN + 1); -+ } -+ -+ if (!kKeepOutputInRF) { -+ DISPATCH_BOOL( -+ iter_key_start == 0, kIsFirst, ([&] { -+ DISPATCH_BOOL( -+ (iter_key_start + kKeysPerBlock) >= num_keys, -+ kIsLast, -+ ([&] { -+ using DefaultEpilogue = typename MM1::DefaultEpilogue; -+ using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; -+ using ElementCompute = typename DefaultOp::ElementCompute; -+ using EpilogueOutputOp = typename cutlass::epilogue:: -+ thread::MemoryEfficientAttentionNormalize< -+ typename cutlass::platform::conditional< -+ kIsLast, -+ output_t, -+ output_accum_t>::type, -+ output_accum_t, -+ DefaultOp::kCount, -+ typename DefaultOp::ElementAccumulator, -+ output_accum_t, -+ kIsFirst, -+ kIsLast, -+ cutlass::Array>; -+ using Epilogue = typename cutlass::epilogue::threadblock:: -+ EpiloguePipelined< -+ typename DefaultEpilogue::Shape, -+ typename MM1::Mma::Operator, -+ DefaultEpilogue::kPartitionsK, -+ typename cutlass::platform::conditional< -+ kIsLast, -+ typename MM1::OutputTileIterator, -+ typename MM1::OutputTileIteratorAccum>::type, -+ typename DefaultEpilogue:: -+ AccumulatorFragmentIterator, -+ typename DefaultEpilogue::WarpTileIterator, -+ typename DefaultEpilogue::SharedLoadIterator, -+ EpilogueOutputOp, -+ typename DefaultEpilogue::Padding, -+ DefaultEpilogue::kFragmentsPerIteration, -+ true, // IterationsUnroll -+ typename MM1::OutputTileIteratorAccum // Read -+ // iterator -+ >; -+ -+ int col = blockN * MM1::Mma::Shape::kN; -+ auto source_iter = createOutputAccumIter(col); -+ auto dest_iter = gemm_kernel_utils::call_conditional< -+ kIsLast, -+ decltype(createOutputIter), -+ decltype(createOutputAccumIter)>:: -+ apply(createOutputIter, createOutputAccumIter, col); -+ EpilogueOutputOp rescale(s_prime, m_prime); -+ Epilogue epilogue( -+ shared_storage.epilogue_shared_storage(), -+ thread_id(), -+ warp_id(), -+ lane_id()); -+ epilogue(rescale, dest_iter, accum_o, source_iter); -+ })); -+ })); -+ if (!kKeepOutputInRF) { -+ __syncthreads(); -+ } -+ } -+ } -+ __syncthreads(); // we modify `m_prime` after -+ } -+ -+ if (kKeepOutputInRF) { -+ const bool kIsFirst = true; -+ const bool kIsLast = true; -+ using DefaultEpilogue = typename MM1::DefaultEpilogue; -+ using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; -+ using ElementCompute = typename 
DefaultOp::ElementCompute; -+ using EpilogueOutputOp = -+ typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize< -+ output_t, // output -+ output_accum_t, // source -+ DefaultOp::kCount, -+ typename DefaultOp::ElementAccumulator, // accum -+ output_accum_t, // compute -+ kIsFirst, -+ kIsLast, -+ cutlass::Array>; -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::EpiloguePipelined< -+ typename DefaultEpilogue::Shape, -+ typename MM1::Mma::Operator, -+ DefaultEpilogue::kPartitionsK, -+ typename MM1::OutputTileIterator, // destination -+ typename DefaultEpilogue::AccumulatorFragmentIterator, -+ typename DefaultEpilogue::WarpTileIterator, -+ typename DefaultEpilogue::SharedLoadIterator, -+ EpilogueOutputOp, -+ typename DefaultEpilogue::Padding, -+ DefaultEpilogue::kFragmentsPerIteration, -+ true, // IterationsUnroll -+ typename MM1::OutputTileIteratorAccum // source tile -+ >; -+ auto dest_iter = createOutputIter(0); -+ EpilogueOutputOp rescale(s_prime, m_prime); -+ Epilogue epilogue( -+ shared_storage.epilogue_shared_storage(), -+ thread_id(), -+ warp_id(), -+ lane_id()); -+ epilogue(rescale, dest_iter, accum_o); -+ } -+ -+ // Next tile -+ problem_visitor.advance(gridDim.x); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/fmha_grouped_problem_visitor.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/fmha_grouped_problem_visitor.h -new file mode 100644 -index 0000000..2b31319 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/fmha_grouped_problem_visitor.h -@@ -0,0 +1,178 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Scheduler for grouped FMHA -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/gemm/kernel/grouped_problem_visitor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+// Helper for correctly representing problem sizes in grouped kernels -+template -+struct FMHAGroupedProblemSizeHelper { -+ -+ CUTLASS_HOST_DEVICE -+ static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { -+ // FMHA only partitions tiles across the M dimension. -+ return cutlass::gemm::GemmCoord( -+ ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), 1, 1); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) {} -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) { -+ return grid.m() * grid.n(); -+ } -+}; -+ -+} // namespace detail -+ -+/// Visitor class to abstract away the algorithm for iterating over tiles -+template -+struct FMHAGroupedProblemVisitor : public GroupedProblemVisitor< -+ detail::FMHAGroupedProblemSizeHelper, -+ ThreadblockShape, -+ GroupScheduleMode_, -+ PrefetchTileCount, -+ ThreadCount> { -+ -+ using ProblemSizeHelper = detail::FMHAGroupedProblemSizeHelper; -+ using Base = GroupedProblemVisitor; -+ using BaseParams = typename Base::Params; -+ using SharedStorage = typename Base::SharedStorage; -+ -+ cutlass::gemm::GemmCoord const *problem_sizes0; -+ cutlass::gemm::GemmCoord const *problem_sizes1; -+ -+ struct Params { -+ cutlass::gemm::GemmCoord const *problem_sizes0; -+ cutlass::gemm::GemmCoord const *problem_sizes1; -+ int32_t problem_count; -+ void const *workspace; -+ int32_t tile_count; -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ Params(): problem_sizes0(nullptr), problem_sizes1(nullptr), -+ problem_count(0), workspace(nullptr), tile_count(0) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ Params( -+ cutlass::gemm::GemmCoord const *problem_sizes0, -+ cutlass::gemm::GemmCoord const *problem_sizes1, -+ int32_t problem_count, -+ void const *workspace = nullptr, -+ int32_t tile_count = 0 -+ ): -+ problem_sizes0(problem_sizes0), -+ problem_sizes1(problem_sizes1), -+ problem_count(problem_count), -+ workspace(workspace), -+ tile_count(tile_count) -+ {} -+ -+ /// Convert the FMHA-specific parameters to those used by the base class -+ CUTLASS_HOST_DEVICE -+ BaseParams to_base() const { -+ return BaseParams(// Set problem_sizes as problem_sizes1 because these determine -+ // shape of the final output of FMHA -+ problem_sizes1, -+ problem_count, -+ 
workspace, -+ tile_count); -+ } -+ -+ }; -+ -+ // -+ // Methods -+ // -+ CUTLASS_DEVICE -+ FMHAGroupedProblemVisitor( -+ Params const ¶ms_, -+ SharedStorage &shared_storage_, -+ int32_t block_idx -+ ): Base ( -+ params_.to_base(), -+ shared_storage_, block_idx), -+ problem_sizes0(params_.problem_sizes0), -+ problem_sizes1(params_.problem_sizes1) -+ {} -+ -+ /// Returns the problem size 0 for the current problem -+ CUTLASS_HOST_DEVICE -+ cutlass::gemm::GemmCoord problem_size0() const { -+ GemmCoord problem = problem_sizes0[this->problem_idx]; -+ ProblemSizeHelper::possibly_transpose_problem(problem); -+ return problem; -+ } -+ -+ /// Returns the problem size 1 for the current problem -+ CUTLASS_HOST_DEVICE -+ cutlass::gemm::GemmCoord problem_size1() const { -+ GemmCoord problem = problem_sizes1[this->problem_idx]; -+ ProblemSizeHelper::possibly_transpose_problem(problem); -+ return problem; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu b/3rdparty/cutlass/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu -new file mode 100644 -index 0000000..53af4ac ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu -@@ -0,0 +1,1087 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holdvr nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief CUTLASS Attention Example. 
-+
-+    This workload computes a fused multi-head attention.
-+    Because it keeps the attention matrix in shared memory, it's both faster and
-+    uses less global memory.
-+
-+    This is based on `"Self-Attention Does Not Need O(n^2) Memory" <http://arxiv.org/abs/2112.05682>`_,
-+    and very similar to `"FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" <https://arxiv.org/abs/2205.14135>`_.
-+
-+    Algorithm:
-+      In short, we can compute the output incrementally in blocks of size B,
-+      we just need to divide the final result by the sum of all coefficients in
-+      the softmax (which we compute incrementally) with the following pseudo-code:
-+
-+      ```
-+      s_prime = torch.zeros([num_queries])
-+      O = torch.zeros([num_queries, head_size_v])
-+      for i in range(0, K.shape[0], B):
-+        si = exp((Q . K[i * B:(i+1) * B].t) * scale)
-+        s_prime += si.sum(-1)
-+        O += si . V[i * B:(i+1) * B]
-+      O = O / s_prime
-+      ```
-+
-+      In practice, and for numerical stability reasons,
-+      we also subtract the maximum so far (`mi`) before doing
-+      the exponential. When we encounter new keys, the maximum
-+      used to compute O so far (`m_prime`) can differ from the
-+      current maximum, so we update O before accumulating with
-+
-+      ```
-+      O = O * exp(m_prime - mi)
-+      m_prime = mi
-+      ```
-+
-+    Implementation details:
-+      - `si` is stored in shared memory between the two back-to-back GEMMs
-+      - we keep and accumulate the output
-+        directly in registers if we can (`head_size_v <= 128`).
-+        Otherwise, we store it & accumulate in global memory (slower)
-+      - blocks are parallelized across the batch dimension, the number
-+        of heads, and the query sequence size
-+
-+
-+    Examples:
-+
-+      # Run an attention example with default setup
-+      $ ./examples/41_fused_multi_head_attention/41_fused_multi_head_attention_fixed_seqlen
-+
-+      # Run an attention example with custom setup
-+      $ ./examples/41_fused_multi_head_attention/41_fused_multi_head_attention_fixed_seqlen --head_number=2 --batch_size=3 --head_size=32 --head_size_v=64 --seq_length=512 --seq_length_kv=1024 --causal=true
-+
-+    Acknowledgement: Fixed-sequence-length FMHA code was upstreamed by Meta xFormers (https://github.com/facebookresearch/xformers).
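To make the pseudo-code above concrete, here is a minimal host-side C++ sketch of the same blocked, numerically stable accumulation (running maximum `m_prime`/`mi`, running denominator `s_prime`). It is illustrative only and not the kernel's actual code path: the single-head setup, row-major float buffers, and the name `attention_blocked_reference` are assumptions made for the sketch.

```cpp
// Sketch: blocked attention with online softmax (single head, row-major float).
#include <algorithm>
#include <cmath>
#include <vector>

void attention_blocked_reference(const std::vector<float>& Q,  // [num_queries x head_size]
                                 const std::vector<float>& K,  // [num_keys x head_size]
                                 const std::vector<float>& V,  // [num_keys x head_size_v]
                                 std::vector<float>& O,        // [num_queries x head_size_v]
                                 int num_queries, int num_keys,
                                 int head_size, int head_size_v, int B) {
  const float scale = 1.0f / std::sqrt(float(head_size));
  std::vector<float> m_prime(num_queries, -INFINITY);  // running row maximum
  std::vector<float> s_prime(num_queries, 0.0f);       // running softmax denominator
  std::fill(O.begin(), O.end(), 0.0f);

  for (int kb = 0; kb < num_keys; kb += B) {            // iterate over key blocks of size B
    int bn = std::min(B, num_keys - kb);
    for (int q = 0; q < num_queries; ++q) {
      // si = scale * Q[q] . K[kb:kb+bn]^T, and the new row maximum mi
      std::vector<float> si(bn);
      float mi = m_prime[q];
      for (int j = 0; j < bn; ++j) {
        float acc = 0.0f;
        for (int d = 0; d < head_size; ++d)
          acc += Q[q * head_size + d] * K[(kb + j) * head_size + d];
        si[j] = acc * scale;
        mi = std::max(mi, si[j]);
      }
      // If the maximum grew, rescale what has already been accumulated.
      float correction = std::exp(m_prime[q] - mi);
      s_prime[q] *= correction;
      for (int d = 0; d < head_size_v; ++d) O[q * head_size_v + d] *= correction;
      // Accumulate exp(si - mi) . V for this key block.
      for (int j = 0; j < bn; ++j) {
        float p = std::exp(si[j] - mi);
        s_prime[q] += p;
        for (int d = 0; d < head_size_v; ++d)
          O[q * head_size_v + d] += p * V[(kb + j) * head_size_v + d];
      }
      m_prime[q] = mi;
    }
  }
  // Final normalization by the softmax denominator.
  for (int q = 0; q < num_queries; ++q)
    for (int d = 0; d < head_size_v; ++d) O[q * head_size_v + d] /= s_prime[q];
}
```

The rescaling by `exp(m_prime - mi)` whenever the running maximum grows is the correction step the comment above describes (`O = O * exp(m_prime - mi)`), applied here to both the partial output and the partial denominator.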
-+*/ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_grouped.h" -+#include "cutlass/gemm/kernel/default_gemm_grouped.h" -+#include "cutlass/gemm/device/gemm_grouped.h" -+#include "cutlass/gemm/device/gemm_universal.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/device_memory.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+#include "cutlass/util/reference/device/gemm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/device/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/kernel/gemm_grouped.h" -+#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/kernel/default_gemm_complex.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" -+#include "cutlass/fast_math.h" -+#include "kernel_forward.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ // -+ // Methods -+ // -+ -+ Result( -+ double runtime_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess -+ ): -+ runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ bool error; -+ bool reference_check; -+ bool use_mask; -+ bool causal; -+ -+ std::vector problem_sizes0; -+ std::vector problem_sizes1; -+ -+ std::vector problem_sizes0_real; -+ std::vector problem_sizes1_real; -+ -+ int alignment; -+ int head_number; -+ int batch_size; -+ int head_size; -+ int head_size_v; -+ int seq_length; -+ int seq_length_kv; -+ int iterations; -+ -+ // alpha0, alpha1 and beta are fixed -+ // in this multi-head attention example -+ float alpha0; -+ float alpha1; -+ float beta; -+ -+ // -+ // Methods -+ // -+ -+ Options(): -+ help(false), -+ error(false), -+ alignment(1), -+ reference_check(true), -+ head_number(12), -+ batch_size(16), -+ head_size(64), -+ head_size_v(64), -+ seq_length(1024), -+ seq_length_kv(1024), -+ use_mask(false), -+ iterations(20), -+ causal(false) -+ { } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ return; -+ } -+ -+ cmd.get_cmd_line_argument("alignment", alignment, 1); -+ cmd.get_cmd_line_argument("head_number", head_number, 12); -+ cmd.get_cmd_line_argument("batch_size", batch_size, 16); -+ cmd.get_cmd_line_argument("head_size", head_size, 64); -+ cmd.get_cmd_line_argument("head_size_v", head_size_v, head_size); -+ cmd.get_cmd_line_argument("seq_length", seq_length, 1024); -+ 
cmd.get_cmd_line_argument("seq_length_kv", seq_length_kv, seq_length); -+ cmd.get_cmd_line_argument("use_mask", use_mask, false); -+ cmd.get_cmd_line_argument("iterations", iterations, 20); -+ cmd.get_cmd_line_argument("reference-check", reference_check, true); -+ cmd.get_cmd_line_argument("causal", causal, true); -+ -+ randomize_problems(); -+ -+ } -+ -+ void randomize_problems() { -+ -+ int problem_count = head_number * batch_size; -+ -+ problem_sizes0.reserve(problem_count); -+ problem_sizes1.reserve(problem_count); -+ -+ // When using mask, the original inputs are not padded -+ // and we need to save these info. -+ if (use_mask) { -+ problem_sizes0_real.reserve(problem_count); -+ problem_sizes1_real.reserve(problem_count); -+ } -+ -+ for (int i = 0; i < batch_size; ++i) { -+ // problems belonging to the same batch share the same seq len -+ int m_real = seq_length; -+ int mkv_real = seq_length_kv; -+ int m = (m_real + alignment - 1) / alignment * alignment; -+ int mkv = (mkv_real + alignment - 1) / alignment * alignment; -+ int k0 = head_size; -+ int k1 = head_size_v; -+ -+ for (int j = 0; j < head_number; ++j) { -+ cutlass::gemm::GemmCoord problem0(m, mkv, k0); -+ cutlass::gemm::GemmCoord problem1(m, k1, mkv); -+ problem_sizes0.push_back(problem0); -+ problem_sizes1.push_back(problem1); -+ -+ if (use_mask) { -+ cutlass::gemm::GemmCoord problem0_real(m_real, mkv_real, k0); -+ cutlass::gemm::GemmCoord problem1_real(m_real, k1, mkv_real); -+ problem_sizes0_real.push_back(problem0_real); -+ problem_sizes1_real.push_back(problem1_real); -+ } -+ } -+ } -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "41_fused_multi_head_attention_fixed_seqlen\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --head_number= Head number in multi-head attention (default: --head_number=12)\n" -+ << " --batch_size= Batch size in multi-head attention (default: --batch_size=16)\n" -+ << " --head_size= Head size in multi-head attention (default: --head_size=64)\n" -+ << " --head_size_v= Head size in multi-head attention for V (default: --head_size_v=head_size)\n" -+ << " --seq_length= Sequence length in multi-head attention for Q (default: --seq_length=1024)\n" -+ << " --seq_length_kv= Sequence length in multi-head attention for K/V (default: --seq_length_kv=seq_length)\n" -+ << " --use_mask= If true, performs padding-like masking in softmax.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --reference-check= If true, performs reference check.\n" -+ << " --causal= If true, uses causal masking.\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fops = int64_t(); -+ -+ for (int i = 0; i < problem_sizes0.size(); ++i) { -+ auto const& problem0 = problem_sizes0[i]; -+ auto const& problem1 = problem_sizes1[i]; -+ for (int row = 0; row < problem0.m(); ++row) { -+ int num_cols0 = problem0.n(); -+ if (causal) { -+ num_cols0 = std::min(row + 1, num_cols0); -+ } -+ // P <- Q . K_t -+ fops += 2 * num_cols0 * problem0.k(); -+ // P <- exp(P - max(P)) -+ fops += 2 * num_cols0; -+ // S <- sum(P) -+ fops += num_cols0 - 1; -+ // O <- P . 
V -+ fops += 2 * num_cols0 * problem1.n(); -+ // O <- O / S -+ fops += num_cols0 * problem1.n(); -+ } -+ } -+ -+ return double(fops) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class TestbedAttention { -+public: -+ -+ // -+ // Type definitions -+ // -+ -+ using ElementQ = typename Attention::scalar_t; -+ using ElementK = typename Attention::scalar_t; -+ using ElementP = typename Attention::accum_t; -+ using ElementAccumulator = typename Attention::accum_t; -+ using ElementV = typename Attention::scalar_t; -+ using ElementO = typename Attention::output_t; -+ -+ using ElementCompute = typename Attention::accum_t; -+ -+ using ElementNorm = typename Attention::accum_t; -+ using ElementSum = typename Attention::accum_t; -+ using ElementSoftmaxCompute = typename Attention::accum_t; -+ -+ using LayoutQ = cutlass::layout::RowMajor; -+ using LayoutK = cutlass::layout::ColumnMajor; -+ using LayoutP = cutlass::layout::RowMajor; -+ using LayoutV = cutlass::layout::RowMajor; -+ using LayoutO = cutlass::layout::RowMajor; -+ -+ using MatrixCoord = typename LayoutP::TensorCoord; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Options & options; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_Q; -+ cutlass::Distribution::Kind init_K; -+ cutlass::Distribution::Kind init_P; -+ cutlass::Distribution::Kind init_V; -+ cutlass::Distribution::Kind init_O; -+ uint32_t seed; -+ -+ cutlass::DeviceAllocation problem_sizes_device0; -+ cutlass::DeviceAllocation problem_sizes_device1; -+ cutlass::DeviceAllocation problem_sizes_device0_real; -+ -+ std::vector offset_Q; -+ std::vector offset_K; -+ std::vector offset_P; -+ std::vector offset_V; -+ std::vector offset_O; -+ -+ std::vector ldq_host; -+ std::vector ldk_host; -+ std::vector ldp_host; -+ std::vector ldv_host; -+ std::vector ldo_host; -+ std::vector seqlen_host; -+ -+ cutlass::DeviceAllocation ldq; -+ cutlass::DeviceAllocation ldk; -+ cutlass::DeviceAllocation ldp; -+ cutlass::DeviceAllocation ldv; -+ cutlass::DeviceAllocation ldo; -+ cutlass::DeviceAllocation seqlen; -+ -+ cutlass::DeviceAllocation block_Q; -+ cutlass::DeviceAllocation block_K; -+ cutlass::DeviceAllocation block_P; -+ cutlass::DeviceAllocation block_V; -+ cutlass::DeviceAllocation block_O; -+ cutlass::DeviceAllocation block_Norm; -+ cutlass::DeviceAllocation block_Sum; -+ -+ cutlass::DeviceAllocation offset_P_Device; -+ -+ cutlass::DeviceAllocation ptr_Q; -+ cutlass::DeviceAllocation ptr_K; -+ cutlass::DeviceAllocation ptr_P; -+ cutlass::DeviceAllocation ptr_V; -+ cutlass::DeviceAllocation ptr_O; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ TestbedAttention( -+ Options &options_, -+ cutlass::Distribution::Kind init_Q_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_K_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_P_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_V_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_O_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3080 -+ ): -+ options(options_), init_Q(init_Q_), init_K(init_K_), init_P(init_P_), init_V(init_V_), init_O(init_O_), seed(seed_) { } -+ -+ int problem_count() const { -+ return (options.head_number * options.batch_size); -+ } -+ -+private: -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor_( -+ Element *ptr, -+ size_t capacity, -+ cutlass::Distribution::Kind dist_kind, 
-+ uint32_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ Element scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 8; -+ scope_min = -8; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ ptr, capacity, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::device::BlockFillRandomGaussian( -+ ptr, capacity, seed, Element(), Element(0.5f)); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ // Fill with increasing elements -+ cutlass::reference::device::BlockFillSequential( -+ ptr, capacity, Element(1), Element()); -+ } -+ else { -+ -+ // Fill with all 1s -+ cutlass::reference::device::BlockFillSequential( -+ ptr, capacity, Element(), Element(1)); -+ } -+ } -+ -+ /// Initializes data structures -+ void initialize_() { -+ -+ // -+ // Set scalors for the mha example -+ // -+ -+ options.alpha0 = 1.0f / sqrt(float(options.head_size)); -+ options.alpha1 = 1.0f; -+ options.beta = 0; -+ -+ // -+ // Choose random problem sizes -+ // -+ -+ // construct a few problems of random sizes -+ srand(seed); -+ -+ int64_t total_elements_Q = 0; -+ int64_t total_elements_K = 0; -+ int64_t total_elements_P = 0; -+ int64_t total_elements_V = 0; -+ int64_t total_elements_O = 0; -+ -+ ldq_host.resize(problem_count()); -+ ldk_host.resize(problem_count()); -+ ldp_host.resize(problem_count()); -+ ldv_host.resize(problem_count()); -+ ldo_host.resize(problem_count()); -+ seqlen_host.resize(problem_count()); -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ -+ auto problem0 = options.problem_sizes0.at(i); -+ auto problem1 = options.problem_sizes1.at(i); -+ -+ ldq_host.at(i) = LayoutQ::packed({problem0.m(), problem0.k()}).stride(0); -+ ldk_host.at(i) = LayoutK::packed({problem0.k(), problem0.n()}).stride(0); -+ ldp_host.at(i) = LayoutP::packed({problem0.m(), problem0.n()}).stride(0); -+ ldv_host.at(i) = LayoutV::packed({problem1.k(), problem1.n()}).stride(0); -+ ldo_host.at(i) = LayoutO::packed({problem1.m(), problem1.n()}).stride(0); -+ -+ // m = n for attention problems. 
-+ seqlen_host.at(i) = problem0.m(); -+ -+ offset_Q.push_back(total_elements_Q); -+ offset_K.push_back(total_elements_K); -+ offset_P.push_back(total_elements_P); -+ offset_V.push_back(total_elements_V); -+ offset_O.push_back(total_elements_O); -+ -+ int64_t elements_Q = problem0.m() * problem0.k(); -+ int64_t elements_K = problem0.k() * problem0.n(); -+ int64_t elements_P = problem0.m() * problem0.n(); -+ int64_t elements_V = problem1.k() * problem1.n(); -+ int64_t elements_O = problem1.m() * problem1.n(); -+ -+ total_elements_Q += elements_Q; -+ total_elements_K += elements_K; -+ total_elements_P += elements_P; -+ total_elements_V += elements_V; -+ total_elements_O += elements_O; -+ } -+ -+ problem_sizes_device0.reset(problem_count()); -+ problem_sizes_device1.reset(problem_count()); -+ problem_sizes_device0.copy_from_host(options.problem_sizes0.data()); -+ problem_sizes_device1.copy_from_host(options.problem_sizes1.data()); -+ -+ if (options.use_mask) { -+ problem_sizes_device0_real.reset(problem_count()); -+ problem_sizes_device0_real.copy_from_host(options.problem_sizes0_real.data()); -+ } -+ -+ ldq.reset(problem_count()); -+ ldk.reset(problem_count()); -+ ldp.reset(problem_count()); -+ ldv.reset(problem_count()); -+ ldo.reset(problem_count()); -+ seqlen.reset(problem_count()); -+ -+ ldq.copy_from_host(ldq_host.data()); -+ ldk.copy_from_host(ldk_host.data()); -+ ldp.copy_from_host(ldp_host.data()); -+ ldv.copy_from_host(ldv_host.data()); -+ ldo.copy_from_host(ldo_host.data()); -+ seqlen.copy_from_host(seqlen_host.data()); -+ -+ // -+ // Assign pointers -+ // -+ -+ block_Q.reset(total_elements_Q); -+ block_K.reset(total_elements_K); -+ block_P.reset(total_elements_P); -+ block_V.reset(total_elements_V); -+ block_O.reset(total_elements_O); -+ -+ offset_P_Device.reset(problem_count()); -+ -+ // sync offset with device -+ cutlass::device_memory::copy_to_device(offset_P_Device.get(), offset_P.data(), offset_P.size()); -+ -+ std::vector ptr_Q_host(problem_count()); -+ std::vector ptr_K_host(problem_count()); -+ std::vector ptr_P_host(problem_count()); -+ std::vector ptr_V_host(problem_count()); -+ std::vector ptr_O_host(problem_count()); -+ std::vector ptr_norm_host(problem_count()); -+ std::vector ptr_sum_host(problem_count()); -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ ptr_Q_host.at(i) = block_Q.get() + offset_Q.at(i); -+ ptr_K_host.at(i) = block_K.get() + offset_K.at(i); -+ ptr_P_host.at(i) = block_P.get() + offset_P.at(i); -+ ptr_V_host.at(i) = block_V.get() + offset_V.at(i); -+ ptr_O_host.at(i) = block_O.get() + offset_O.at(i); -+ } -+ -+ ptr_Q.reset(problem_count()); -+ ptr_Q.copy_from_host(ptr_Q_host.data()); -+ -+ ptr_K.reset(problem_count()); -+ ptr_K.copy_from_host(ptr_K_host.data()); -+ -+ ptr_P.reset(problem_count()); -+ ptr_P.copy_from_host(ptr_P_host.data()); -+ -+ ptr_V.reset(problem_count()); -+ ptr_V.copy_from_host(ptr_V_host.data()); -+ -+ ptr_O.reset(problem_count()); -+ ptr_O.copy_from_host(ptr_O_host.data()); -+ -+ // -+ // Initialize the problems of the workspace -+ // -+ -+ initialize_tensor_(block_Q.get(), total_elements_Q, init_Q, seed + 1); -+ initialize_tensor_(block_K.get(), total_elements_K, init_K, seed + 2); -+ initialize_tensor_(block_V.get(), total_elements_V, init_V, seed + 3); -+ -+ } -+ -+ template -+ bool verify_tensor_(std::vector vector_Input, \ -+ std::vector vector_Input_Ref, -+ int64_t verify_length = -1) { -+ -+ int64_t size = (vector_Input.size() < vector_Input_Ref.size()) ? 
vector_Input.size() : vector_Input_Ref.size(); -+ size = (verify_length == -1) ? size : verify_length; -+ -+ // 0.05 for absolute error -+ float abs_tol = 5e-2f; -+ // 10% for relative error -+ float rel_tol = 1e-1f; -+ for (int64_t i = 0; i < size; ++i) { -+ float diff = (float)(vector_Input.at(i) - vector_Input_Ref.at(i)); -+ float abs_diff = fabs(diff); -+ float abs_ref = fabs((float)vector_Input_Ref.at(i) + 1e-5f); -+ float relative_diff = abs_diff / abs_ref; -+ if ( (isnan(vector_Input_Ref.at(i)) || isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > abs_tol && relative_diff > rel_tol)) { -+ printf("[%d/%d] diff = %f, rel_diff = %f, {computed=%f, ref=%f}.\n", int(i), int(size), abs_diff, relative_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i))); -+ return false; -+ } -+ -+ } -+ -+ return true; -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify_() { -+ -+ bool passed = true; -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ cutlass::gemm::GemmCoord problem0 = options.problem_sizes0.at(i); -+ cutlass::gemm::GemmCoord problem1 = options.problem_sizes1.at(i); -+ -+ LayoutQ layout_Q(ldq_host.at(i)); -+ LayoutK layout_K(ldk_host.at(i)); -+ LayoutP layout_P(ldp_host.at(i)); -+ LayoutV layout_V(ldv_host.at(i)); -+ LayoutO layout_O(ldo_host.at(i)); -+ -+ MatrixCoord extent_Q{problem0.m(), problem0.k()}; -+ MatrixCoord extent_K{problem0.k(), problem0.n()}; -+ MatrixCoord extent_P{problem0.m(), problem0.n()}; -+ MatrixCoord extent_V{problem1.k(), problem1.n()}; -+ MatrixCoord extent_O{problem1.m(), problem1.n()}; -+ -+ cutlass::TensorView view_Q(block_Q.get() + offset_Q.at(i), layout_Q, extent_Q); -+ cutlass::TensorView view_K(block_K.get() + offset_K.at(i), layout_K, extent_K); -+ cutlass::TensorView view_P(block_P.get() + offset_P.at(i), layout_P, extent_P); -+ cutlass::TensorView view_V(block_V.get() + offset_V.at(i), layout_V, extent_V); -+ -+ cutlass::DeviceAllocation block_Ref(layout_P.capacity(extent_P)); -+ cutlass::TensorView view_Ref_device(block_Ref.get(), layout_P, extent_P); -+ -+ cutlass::DeviceAllocation block_Ref_O(layout_O.capacity(extent_O)); -+ cutlass::TensorView view_Ref_O_device(block_Ref_O.get(), layout_O, extent_O); -+ -+ // Reference GEMM -+ cutlass::reference::device::GemmComplex< -+ ElementQ, LayoutQ, -+ ElementK, LayoutK, -+ ElementP, LayoutP, -+ ElementCompute, ElementAccumulator -+ >( -+ problem0, -+ ElementAccumulator(options.alpha0), -+ view_Q, -+ Attention::MM0::Mma::kTransformA, -+ view_K, -+ Attention::MM0::Mma::kTransformB, -+ ElementAccumulator(options.beta), -+ view_P, -+ view_Ref_device, -+ ElementAccumulator(0) -+ ); -+ -+ // Compute softmax for P. We need to explicitly compute softmax -+ // over P because softmax is fused to the second GEMM in the -+ // profiled implementation. -+ std::vector matrix_Ref(layout_P.capacity(extent_P)); -+ cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_Ref.size()); -+ cutlass::TensorView view_Ref_host(matrix_Ref.data(), layout_P, extent_P); -+ std::vector vector_Norm_Ref(problem0.m()); -+ std::vector vector_Sum_Ref(problem0.m()); -+ -+ int n_dim = options.use_mask ? 
options.problem_sizes0_real.at(i).n() : problem0.n(); -+ -+ // Compute softmax for referece matrix -+ for (int m = 0; m < problem0.m(); m++) { -+ int n_dim_row = n_dim; -+ if (options.causal) { -+ n_dim_row = std::min(m + 1, n_dim); -+ } -+ ElementSoftmaxCompute max = ElementSoftmaxCompute(view_Ref_host.ref().at({m, 0})); -+ for (int n = 1; n < n_dim_row; n++) { -+ max = std::max(max, ElementSoftmaxCompute(view_Ref_host.ref().at({m, n}))); -+ } -+ -+ vector_Norm_Ref.at(m) = ElementNorm(max); -+ -+ ElementSoftmaxCompute sum = ElementSoftmaxCompute(); -+ for (int n = 0; n < n_dim_row; n++) { -+ sum += std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ); -+ } -+ ElementSoftmaxCompute inv_sum = ElementSoftmaxCompute(1.0f / sum); -+ -+ vector_Sum_Ref.at(m) = ElementSum(inv_sum); -+ -+ for (int n = 0; n < n_dim_row; n++) { -+ view_Ref_host.ref().at({m, n}) = ElementP( -+ std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ) * inv_sum -+ ); -+ } -+ // Mask out the rest of the attention matrix -+ for (int n = n_dim_row; n < n_dim; ++n) { -+ view_Ref_host.ref().at({m, n}) = ElementP(0); -+ } -+ } -+ -+ // when not using mask, problem_real and problem share the same sizes -+ if (options.use_mask) { -+ for (int m = 0; m < problem0.m(); m++) { -+ for (int n = n_dim; n < problem0.n(); n++) { -+ view_Ref_host.ref().at({m, n}) = ElementP(0); -+ } -+ } -+ } -+ -+ cutlass::device_memory::copy_to_device(block_P.get() + offset_P.at(i), matrix_Ref.data(), matrix_Ref.size()); -+ -+ // Reference GEMM -+ cutlass::reference::device::GemmComplex< -+ ElementP, LayoutP, -+ ElementV, LayoutV, -+ ElementO, LayoutO, -+ ElementCompute, ElementAccumulator -+ >( -+ problem1, -+ ElementAccumulator(options.alpha1), -+ view_P, -+ Attention::MM0::Mma::kTransformA, -+ view_V, -+ Attention::MM0::Mma::kTransformB, -+ ElementAccumulator(options.beta), -+ view_Ref_O_device, -+ view_Ref_O_device, -+ ElementAccumulator(0) -+ ); -+ -+ // Copy to host memory -+ cutlass::TensorView view_Ref(matrix_Ref.data(), layout_P, extent_P); -+ -+ std::vector matrix_O(layout_O.capacity(extent_O)); -+ cutlass::device_memory::copy_to_host(matrix_O.data(), block_O.get() + offset_O.at(i), matrix_O.size()); -+ std::vector matrix_Ref_O(layout_O.capacity(extent_O)); -+ cutlass::device_memory::copy_to_host(matrix_Ref_O.data(), block_Ref_O.get(), matrix_Ref_O.size()); -+ -+ // printf("Pb %d: \n Q=(offset=%d, ldq=%d)\n K=(offset=%d, ldk=%d)\n O=(offset=%d, ldo=%d)\n", -+ // int(i), int(offset_Q[i]), int(ldq_host[i]), int(offset_K[i]), int(ldk_host[i]), int(offset_O[i]), int(ldo_host[i])); -+ -+ bool verified_O = false; -+ -+ if (!verified_O) { -+ verified_O = verify_tensor_(matrix_O, matrix_Ref_O); -+ } -+ -+ passed = passed && verified_O; -+ -+ if (!passed) { -+ std::cerr << "\n***\nError - problem " << i << " failed the QA check\n***\n" << std::endl; -+ -+ if (!verified_O) { -+ std::cout << "Final matrix output is incorrect" << std::endl; -+ } -+ -+ return passed; -+ } -+ } -+ -+ return passed; -+ } -+ -+public: -+ -+ -+ /// Executes a CUTLASS Attention kernel and measures runtime. 
-+ Result profile() { -+ -+ Result result; -+ result.passed = false; -+ -+ // Initialize the problem -+ initialize_(); -+ -+ typename Attention::Params p; -+ { // set parameters -+ p.query_ptr = block_Q.get(); -+ p.key_ptr = block_K.get(); -+ p.value_ptr = block_V.get(); -+ p.logsumexp_ptr = nullptr; // Only needed for bw -+ p.output_accum_ptr = nullptr; -+ if (Attention::kNeedsOutputAccumulatorBuffer) { -+ cudaMalloc(&p.output_accum_ptr, block_O.size() * sizeof(typename Attention::output_accum_t)); -+ } -+ p.output_ptr = block_O.get(); -+ -+ // TODO: support arbitrary seq lengths -+ // if (cu_seqlens_q.has_value()) { -+ // p.cu_seqlens_q_ptr = (int32_t*)cu_seqlens_q->data_ptr(); -+ // p.cu_seqlens_k_ptr = (int32_t*)cu_seqlens_k->data_ptr(); -+ // } -+ -+ p.num_heads = options.head_number; -+ p.num_batches = options.batch_size; -+ p.head_dim = options.head_size; -+ p.head_dim_value = options.head_size_v; -+ p.num_queries = options.seq_length; -+ p.num_keys = options.seq_length_kv; -+ p.causal = options.causal; -+ -+ // TODO: This might overflow for big tensors -+ p.q_strideM = int32_t(ldq_host[0]); -+ p.k_strideM = int32_t(ldk_host[0]); -+ p.v_strideM = int32_t(ldv_host[0]); -+ p.q_strideH = p.q_strideM * options.seq_length; -+ p.k_strideH = p.k_strideM * options.seq_length_kv; -+ p.v_strideH = p.v_strideM * options.seq_length_kv; -+ p.o_strideH = options.head_size_v * options.seq_length; -+ p.q_strideB = p.q_strideH * options.head_number; -+ p.k_strideB = p.k_strideH * options.head_number; -+ p.v_strideB = p.v_strideH * options.head_number; -+ p.o_strideB = options.head_size_v * options.seq_length * options.head_number; -+ } -+ -+ // launch kernel :) -+ constexpr auto kernel_fn = attention_kernel_batched_impl; -+ int smem_bytes = sizeof(typename Attention::SharedStorage); -+ if (smem_bytes > 0xc000) { -+ cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); -+ } -+ if (!Attention::check_supported(p)) { -+ std::cerr << "Kernel does not support these inputs" << std::endl; -+ return result; -+ } -+ kernel_fn<<>>(p); -+ -+ // Wait for completion -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // -+ // Verify correctness -+ // -+ result.passed = true; -+ -+ if (options.reference_check) { -+ result.passed = verify_(); -+ } -+ -+ // -+ // Warm-up run -+ // -+ -+ kernel_fn<<>>(p); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS Attention kernel." << std::endl; -+ return result; -+ } -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMM operations -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ kernel_fn<<>>(p); -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMM operations have been launched. 
-+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // -+ // Cleanup -+ // -+ -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ std::cout << std::endl; -+ std::cout << "CUTLASS Attention:\n" -+ << "====================================================" << std::endl; -+ std::cout << " " << " {seq length Q, seq length KV, head size, head size V, head number, batch size} = {" << options.seq_length \ -+ << ", " << options.seq_length_kv << ", " << options.head_size << ", " << options.head_size_v << ", " << options.head_number\ -+ << ", " << options.batch_size << "}." << std::endl; -+ std::cout << std::endl; -+ std::cout << " " << "Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " " << "GFLOPs: " << result.gflops << std::endl; -+ -+ return result; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ int kQueriesPerBlock, -+ int kKeysPerBlock, -+ bool kSingleValueIteration -+> -+int run_attention(Options& options) { -+ using Attention = AttentionKernel< -+ cutlass::half_t, // scalar_t -+ cutlass::arch::Sm80, // ArchTag -+ true, // Memory is aligned -+ kQueriesPerBlock, -+ kKeysPerBlock, -+ kSingleValueIteration -+ >; -+ -+ // -+ // Test and profile -+ // -+ -+ TestbedAttention testbed(options); -+ -+ Result result = testbed.profile(); -+ if (!result.passed) { -+ std::cout << "Profiling CUTLASS attention has failed.\n"; -+ std::cout << "\nFailed\n"; -+ return -1; -+ } -+ -+ std::cout << "\nPassed\n"; -+ return 0; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ // -+ // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. -+ // -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { -+ -+ // -+ // This example requires an NVIDIA Ampere-architecture GPU. -+ // -+ -+ std::cout -+ << "CUTLASS's CUTLASS Attention example requires a GPU of NVIDIA's Ampere Architecture or " -+ << "later (compute capability 80 or greater).\n"; -+ -+ return 0; -+ } -+ -+ // -+ // Parse options -+ // -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.error) { -+ std::cerr << "Aborting execution." 
<< std::endl; -+ return -1; -+ } -+ -+ if (options.use_mask) { -+ std::cerr << "--use_mask is not supported at the moment\n"; -+ return -2; -+ } -+ if (options.alignment != 1) { -+ std::cerr << "--alignment=1 is the only supported value\n"; -+ return -2; -+ } -+ -+ // Determine kernel configuration based on head size. -+ // If head size is less than or equal to 64, each block operates over 64 queries and -+ // 64 keys, and parital results can be stored in the register file. -+ // If head size is greater than 64, each block operates over 32 queries and 128 keys, -+ // and partial results are stored in shared memory. -+ if (options.head_size_v > 64) { -+ static int const kQueriesPerBlock = 32; -+ static int const kKeysPerBlock = 128; -+ if (options.head_size_v <= kKeysPerBlock) { -+ return run_attention(options); -+ } else { -+ return run_attention(options); -+ } -+ } else { -+ static int const kQueriesPerBlock = 64; -+ static int const kKeysPerBlock = 64; -+ return run_attention(options); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu b/3rdparty/cutlass/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu -new file mode 100644 -index 0000000..35b5c32 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu -@@ -0,0 +1,1193 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holdvr nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief CUTLASS Attention Example. 
-+
-+    This workload computes a fused multi-head attention that supports variable sequence lengths.
-+    Because it keeps the attention matrix in shared memory, it's both faster and
-+    uses less global memory.
-+
-+    This is based on `"Self-Attention Does Not Need O(n^2) Memory" <http://arxiv.org/abs/2112.05682>`_,
-+    and very similar to `"FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" <https://arxiv.org/abs/2205.14135>`_.
-+
-+    Algorithm:
-+      In short, we can compute the output incrementally in blocks of size B,
-+      we just need to divide the final result by the sum of all coefficients in
-+      the softmax (which we compute incrementally) with the following pseudo-code:
-+
-+      ```
-+      s_prime = torch.zeros([num_queries])
-+      O = torch.zeros([num_queries, head_size_v])
-+      for i in range(0, K.shape[0], B):
-+        si = exp((Q . K[i * B:(i+1) * B].t) * scale)
-+        s_prime += si.sum(-1)
-+        O += si . V[i * B:(i+1) * B]
-+      O = O / s_prime
-+      ```
-+
-+      In practice, and for numerical stability reasons,
-+      we also subtract the maximum so far (`mi`) before doing
-+      the exponential. When we encounter new keys, the maximum
-+      used to compute O so far (`m_prime`) can differ from the
-+      current maximum, so we update O before accumulating with
-+
-+      ```
-+      O = O * exp(m_prime - mi)
-+      m_prime = mi
-+      ```
-+
-+    Implementation details:
-+      - `si` is stored in shared memory between the two back-to-back GEMMs
-+      - we keep and accumulate the output
-+        directly in registers if we can (`head_size_v <= 128`).
-+        Otherwise, we store it & accumulate in global memory (slower)
-+      - blocks are parallelized across the batch dimension, the number
-+        of heads, and the query sequence size
-+
-+
-+    Examples:
-+
-+      # Run an attention example with default setup
-+      $ ./examples/41_fused_multi_head_attention/41_fused_multi_head_attention_variable_seqlen
-+
-+      # Run an attention example with custom setup
-+      $ ./examples/41_fused_multi_head_attention/41_fused_multi_head_attention_variable_seqlen --head_number=2 --batch_size=3 --head_size=32 --head_size_v=64 --seq_length=512 --seq_length_kv=1024 --causal=true
-+
-+    Acknowledgement: Fixed-sequence-length FMHA code was upstreamed by Meta xFormers (https://github.com/facebookresearch/xformers).
-+    Using grouped GEMM to handle variable sequence lengths is inspired by an idea originally prototyped by ByteDance Inc.
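Since this variant differs from the fixed-length example mainly in how per-problem shapes are produced, here is a small C++ sketch of turning variable sequence lengths into the two grouped-GEMM problem lists (`Q.K^T` and `softmax(Q.K^T).V`), with lengths rounded up to the required alignment. The function name `build_problems` and the parameter `max_seq_length` are illustrative, and for simplicity the key/value length is assumed equal to the query length; the example's `Options::randomize_problems` follows the same pattern.

```cpp
// Sketch: building grouped-GEMM problem sizes for variable sequence lengths.
// problem0 is Q.K^T (m x mkv x head_size); problem1 is P.V (m x head_size_v x mkv).
#include <cstdlib>
#include <vector>
#include "cutlass/gemm/gemm.h"

std::vector<cutlass::gemm::GemmCoord> problem_sizes0;
std::vector<cutlass::gemm::GemmCoord> problem_sizes1;

void build_problems(int batch_size, int head_number, int head_size,
                    int head_size_v, int max_seq_length, int alignment) {
  for (int b = 0; b < batch_size; ++b) {
    // All heads of one batch entry share the same (random) sequence length,
    // rounded up to the required alignment.
    int m_real = (std::rand() % max_seq_length) + 1;
    int m = (m_real + alignment - 1) / alignment * alignment;
    int mkv = m;  // assume key/value length equals query length for this sketch
    for (int h = 0; h < head_number; ++h) {
      problem_sizes0.emplace_back(m, mkv, head_size);    // Q.K^T
      problem_sizes1.emplace_back(m, head_size_v, mkv);  // softmax(Q.K^T).V
    }
  }
}
```

Each head of each batch entry becomes one problem in the group, which is what lets the grouped scheduler balance work across problems of different sizes.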
-+*/ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/device/gemm_grouped.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/device_memory.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+#include "cutlass/util/reference/device/gemm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/device/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/kernel/default_gemm_complex.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/fast_math.h" -+ -+#include "default_fmha_grouped.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ // -+ // Methods -+ // -+ -+ Result( -+ double runtime_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess -+ ): -+ runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ bool error; -+ bool reference_check; -+ bool use_mask; -+ bool causal; -+ bool fixed_seq_length; -+ -+ std::vector problem_sizes0; -+ std::vector problem_sizes1; -+ -+ std::vector problem_sizes0_real; -+ std::vector problem_sizes1_real; -+ -+ int alignment; -+ int head_number; -+ int batch_size; -+ int head_size; -+ int head_size_v; -+ int seq_length; -+ int seq_length_kv; -+ int iterations; -+ int problem_count; -+ -+ // alpha0, alpha1 and beta are fixed -+ // in this multi-head attention example -+ float alpha0; -+ float alpha1; -+ float beta; -+ -+ cutlass::gemm::kernel::GroupScheduleMode scheduler_mode; -+ -+ // -+ // Methods -+ // -+ -+ Options(): -+ help(false), -+ error(false), -+ alignment(1), -+ reference_check(true), -+ head_number(12), -+ batch_size(16), -+ head_size(64), -+ head_size_v(64), -+ seq_length(1024), -+ seq_length_kv(1024), -+ use_mask(false), -+ iterations(20), -+ causal(false), -+ fixed_seq_length(false), -+ problem_count(batch_size * head_number), -+ scheduler_mode(cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly) -+ { } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ return; -+ } -+ -+ cmd.get_cmd_line_argument("alignment", alignment, 1); -+ cmd.get_cmd_line_argument("head_number", head_number, 12); -+ cmd.get_cmd_line_argument("batch_size", batch_size, 16); -+ cmd.get_cmd_line_argument("head_size", head_size, 64); -+ cmd.get_cmd_line_argument("head_size_v", head_size_v, head_size); -+ cmd.get_cmd_line_argument("seq_length", seq_length, 1024); -+ cmd.get_cmd_line_argument("seq_length_kv", seq_length_kv, seq_length); -+ 
cmd.get_cmd_line_argument("use_mask", use_mask, false); -+ cmd.get_cmd_line_argument("iterations", iterations, 20); -+ cmd.get_cmd_line_argument("reference-check", reference_check, true); -+ cmd.get_cmd_line_argument("causal", causal, true); -+ cmd.get_cmd_line_argument("fixed_seq_length", fixed_seq_length, false); -+ -+ std::vector scheduler_mode_strs; -+ cmd.get_cmd_line_arguments("scheduler-mode", scheduler_mode_strs); -+ -+ if (!scheduler_mode_strs.empty()) { -+ if (scheduler_mode_strs.size() > 1) { -+ std::cerr << "Only one scheduler mode may be passed in" << std::endl; -+ error = true; -+ return; -+ } -+ std::string scheduler_mode_str = scheduler_mode_strs[0]; -+ if (scheduler_mode_str == "kDeviceOnly") { -+ scheduler_mode = cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly; -+ } else if (scheduler_mode_str == "kHostPrecompute") { -+ scheduler_mode = cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute; -+ } else { -+ std::cerr << "Unrecognized scheduler mode '" << scheduler_mode_str << "'" << std::endl; -+ error = true; -+ return; -+ } -+ } -+ -+ if (fixed_seq_length) { -+ std::cout << "NOTE: Better performance is expected for fixed-sized sequence length from 41_fused_multi_head_attention_fixed_seqlen." << std::endl; -+ } -+ -+ randomize_problems(); -+ } -+ -+ void randomize_problems() { -+ -+ problem_count = head_number * batch_size; -+ -+ problem_sizes0.reserve(problem_count); -+ problem_sizes1.reserve(problem_count); -+ -+ // When using mask, the original inputs are not padded -+ // and we need to save these info. -+ if (use_mask) { -+ problem_sizes0_real.reserve(problem_count); -+ problem_sizes1_real.reserve(problem_count); -+ } -+ -+ for (int i = 0; i < batch_size; ++i) { -+ // problems belonging to the same batch share the same seq len -+ -+ int m_real, mkv_real; -+ if (fixed_seq_length) { -+ m_real = seq_length; -+ mkv_real = seq_length_kv; -+ } else { -+ m_real = (rand() % seq_length) + 1; -+ -+ // Only randomize seq_length_kv if it was set to a different value than -+ // seq_length originally. -+ if (seq_length != seq_length_kv) { -+ mkv_real = (rand() % seq_length_kv) + 1; -+ } else { -+ mkv_real = m_real; -+ } -+ } -+ -+ int m = (m_real + alignment - 1) / alignment * alignment; -+ int mkv = (mkv_real + alignment - 1) / alignment * alignment; -+ int k0 = head_size; -+ int k1 = head_size_v; -+ -+ for (int j = 0; j < head_number; ++j) { -+ cutlass::gemm::GemmCoord problem0(m, mkv, k0); -+ cutlass::gemm::GemmCoord problem1(m, k1, mkv); -+ -+ problem_sizes0.push_back(problem0); -+ problem_sizes1.push_back(problem1); -+ -+ if (use_mask) { -+ cutlass::gemm::GemmCoord problem0_real(m_real, mkv_real, k0); -+ cutlass::gemm::GemmCoord problem1_real(m_real, k1, mkv_real); -+ problem_sizes0_real.push_back(problem0_real); -+ problem_sizes1_real.push_back(problem1_real); -+ } -+ -+ } -+ } -+ } -+ -+ void print_problems() { -+ std::cout << " Running " << batch_size << " batches, each with " << head_number << " heads of size " << head_size << ":" << std::endl; -+ for (int i = 0; i < batch_size; ++i) { -+ int idx = i * head_number; -+ std::cout << " [" << i << "] seq_length = " << problem_sizes0[idx].m() << " seq_length_kv = " << problem_sizes0[idx].n() << std::endl; -+ } -+ } -+ -+ /// Prints the usage statement. 
-+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "41_fused_multi_head_attention_variable_seqlen\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --head_number= Head number in multi-head attention (default: --head_number=12)\n" -+ << " --batch_size= Batch size in multi-head attention (default: --batch_size=16)\n" -+ << " --head_size= Head size in multi-head attention (default: --head_size=64)\n" -+ << " --head_size_v= Head size in multi-head attention for V (default: --head_size_v=head_size)\n" -+ << " --seq_length= Sequence length in multi-head attention for Q (default: --seq_length=1024)\n" -+ << " --seq_length_kv= Sequence length in multi-head attention for K/V (default: --seq_length_kv=seq_length)\n" -+ << " --use_mask= If true, performs padding-like masking in softmax.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --reference-check= If true, performs reference check.\n" -+ << " --causal= If true, uses causal masking.\n" -+ << " --fixed_seq_length= If true, uses the same sequence length for each item in the batch.\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fops = int64_t(); -+ -+ for (int i = 0; i < problem_sizes0.size(); ++i) { -+ auto const& problem0 = problem_sizes0[i]; -+ auto const& problem1 = problem_sizes1[i]; -+ -+ for (int row = 0; row < problem0.m(); ++row) { -+ int num_cols0 = problem0.n(); -+ if (causal) { -+ num_cols0 = std::min(row + 1, num_cols0); -+ } -+ // P <- Q . K_t -+ fops += 2 * num_cols0 * problem0.k(); -+ // P <- exp(P - max(P)) -+ fops += 2 * num_cols0; -+ // S <- sum(P) -+ fops += num_cols0 - 1; -+ // O <- P . 
V -+ fops += 2 * num_cols0 * problem1.n(); -+ // O <- O / S -+ fops += num_cols0 * problem1.n(); -+ } -+ } -+ -+ return double(fops) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class TestbedAttention { -+public: -+ -+ // -+ // Type definitions -+ // -+ -+ using scalar_t = typename Attention::GemmKernel::scalar_t; -+ using accum_t = typename Attention::GemmKernel::accum_t; -+ using output_t = typename Attention::GemmKernel::output_t; -+ using output_accum_t = typename Attention::GemmKernel::output_accum_t; -+ -+ using ElementQ = scalar_t; -+ using ElementK = scalar_t; -+ using ElementP = accum_t; -+ using ElementAccumulator = accum_t; -+ using ElementV = scalar_t; -+ using ElementO = output_t; -+ using ElementOAccum = output_accum_t; -+ -+ using ElementCompute = accum_t; -+ -+ using ElementNorm = accum_t; -+ using ElementSum = accum_t; -+ using ElementSoftmaxCompute = accum_t; -+ -+ using LayoutQ = cutlass::layout::RowMajor; -+ using LayoutK = cutlass::layout::ColumnMajor; -+ using LayoutP = cutlass::layout::RowMajor; -+ using LayoutV = cutlass::layout::RowMajor; -+ using LayoutO = cutlass::layout::RowMajor; -+ -+ using MatrixCoord = typename LayoutP::TensorCoord; -+ -+ static bool const kNeedsOutputAccumulatorBuffer = Attention::GemmKernel::kNeedsOutputAccumulatorBuffer; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Options & options; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_Q; -+ cutlass::Distribution::Kind init_K; -+ cutlass::Distribution::Kind init_P; -+ cutlass::Distribution::Kind init_V; -+ cutlass::Distribution::Kind init_O; -+ uint32_t seed; -+ -+ cutlass::DeviceAllocation problem_sizes_device0; -+ cutlass::DeviceAllocation problem_sizes_device1; -+ cutlass::DeviceAllocation problem_sizes_device0_real; -+ -+ std::vector offset_Q; -+ std::vector offset_K; -+ std::vector offset_P; -+ std::vector offset_V; -+ std::vector offset_O; -+ -+ std::vector ldq_host; -+ std::vector ldk_host; -+ std::vector ldp_host; -+ std::vector ldv_host; -+ std::vector ldo_host; -+ std::vector seqlen_host; -+ -+ cutlass::DeviceAllocation ldq; -+ cutlass::DeviceAllocation ldk; -+ cutlass::DeviceAllocation ldp; -+ cutlass::DeviceAllocation ldv; -+ cutlass::DeviceAllocation ldo; -+ cutlass::DeviceAllocation seqlen; -+ -+ cutlass::DeviceAllocation block_Q; -+ cutlass::DeviceAllocation block_K; -+ cutlass::DeviceAllocation block_P; -+ cutlass::DeviceAllocation block_V; -+ cutlass::DeviceAllocation block_O; -+ cutlass::DeviceAllocation block_O_accumulate; -+ cutlass::DeviceAllocation block_Norm; -+ cutlass::DeviceAllocation block_Sum; -+ -+ cutlass::DeviceAllocation offset_P_Device; -+ -+ cutlass::DeviceAllocation ptr_Q; -+ cutlass::DeviceAllocation ptr_K; -+ cutlass::DeviceAllocation ptr_P; -+ cutlass::DeviceAllocation ptr_V; -+ cutlass::DeviceAllocation ptr_O; -+ cutlass::DeviceAllocation ptr_O_accumulate; -+ -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ TestbedAttention( -+ Options &options_, -+ cutlass::Distribution::Kind init_Q_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_K_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_P_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_V_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_O_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3080 -+ ): -+ options(options_), init_Q(init_Q_), init_K(init_K_), init_P(init_P_), 
init_V(init_V_), init_O(init_O_), seed(seed_) { } -+ -+ int problem_count() const { -+ return (options.head_number * options.batch_size); -+ } -+ -+private: -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor_( -+ Element *ptr, -+ size_t capacity, -+ cutlass::Distribution::Kind dist_kind, -+ uint32_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ Element scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 8; -+ scope_min = -8; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ ptr, capacity, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::device::BlockFillRandomGaussian( -+ ptr, capacity, seed, Element(), Element(0.5f)); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ // Fill with increasing elements -+ cutlass::reference::device::BlockFillSequential( -+ ptr, capacity, Element(1), Element()); -+ } -+ else { -+ -+ // Fill with all 1s -+ cutlass::reference::device::BlockFillSequential( -+ ptr, capacity, Element(), Element(1)); -+ } -+ } -+ -+ /// Initializes data structures -+ void initialize_() { -+ -+ // -+ // Set scalors for the mha example -+ // -+ -+ options.alpha0 = 1.0f / sqrt(float(options.head_size)); -+ options.alpha1 = 1.0f; -+ options.beta = 0; -+ -+ // -+ // Choose random problem sizes -+ // -+ -+ // construct a few problems of random sizes -+ srand(seed); -+ -+ int64_t total_elements_Q = 0; -+ int64_t total_elements_K = 0; -+ int64_t total_elements_P = 0; -+ int64_t total_elements_V = 0; -+ int64_t total_elements_O = 0; -+ -+ ldq_host.resize(problem_count()); -+ ldk_host.resize(problem_count()); -+ ldp_host.resize(problem_count()); -+ ldv_host.resize(problem_count()); -+ ldo_host.resize(problem_count()); -+ seqlen_host.resize(problem_count()); -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ -+ auto problem0 = options.problem_sizes0.at(i); -+ auto problem1 = options.problem_sizes1.at(i); -+ -+ ldq_host.at(i) = LayoutQ::packed({problem0.m(), problem0.k()}).stride(0); -+ ldk_host.at(i) = LayoutK::packed({problem0.k(), problem0.n()}).stride(0); -+ ldp_host.at(i) = LayoutP::packed({problem0.m(), problem0.n()}).stride(0); -+ ldv_host.at(i) = LayoutV::packed({problem1.k(), problem1.n()}).stride(0); -+ ldo_host.at(i) = LayoutO::packed({problem1.m(), problem1.n()}).stride(0); -+ -+ // m = n for attention problems. 
-+ seqlen_host.at(i) = problem0.m(); -+ -+ offset_Q.push_back(total_elements_Q); -+ offset_K.push_back(total_elements_K); -+ offset_P.push_back(total_elements_P); -+ offset_V.push_back(total_elements_V); -+ offset_O.push_back(total_elements_O); -+ -+ int64_t elements_Q = problem0.m() * problem0.k(); -+ int64_t elements_K = problem0.k() * problem0.n(); -+ int64_t elements_P = problem0.m() * problem0.n(); -+ int64_t elements_V = problem1.k() * problem1.n(); -+ int64_t elements_O = problem1.m() * problem1.n(); -+ -+ total_elements_Q += elements_Q; -+ total_elements_K += elements_K; -+ total_elements_P += elements_P; -+ total_elements_V += elements_V; -+ total_elements_O += elements_O; -+ -+ } -+ -+ problem_sizes_device0.reset(problem_count()); -+ problem_sizes_device1.reset(problem_count()); -+ problem_sizes_device0.copy_from_host(options.problem_sizes0.data()); -+ problem_sizes_device1.copy_from_host(options.problem_sizes1.data()); -+ -+ if (options.use_mask) { -+ problem_sizes_device0_real.reset(problem_count()); -+ problem_sizes_device0_real.copy_from_host(options.problem_sizes0_real.data()); -+ } -+ -+ ldq.reset(problem_count()); -+ ldk.reset(problem_count()); -+ ldp.reset(problem_count()); -+ ldv.reset(problem_count()); -+ ldo.reset(problem_count()); -+ seqlen.reset(problem_count()); -+ -+ ldq.copy_from_host(ldq_host.data()); -+ ldk.copy_from_host(ldk_host.data()); -+ ldp.copy_from_host(ldp_host.data()); -+ ldv.copy_from_host(ldv_host.data()); -+ ldo.copy_from_host(ldo_host.data()); -+ seqlen.copy_from_host(seqlen_host.data()); -+ -+ // -+ // Assign pointers -+ // -+ -+ block_Q.reset(total_elements_Q); -+ block_K.reset(total_elements_K); -+ block_P.reset(total_elements_P); -+ block_V.reset(total_elements_V); -+ block_O.reset(total_elements_O); -+ -+ if (kNeedsOutputAccumulatorBuffer) { -+ block_O_accumulate.reset(total_elements_O); -+ } -+ -+ offset_P_Device.reset(problem_count()); -+ -+ // sync offset with device -+ cutlass::device_memory::copy_to_device(offset_P_Device.get(), offset_P.data(), offset_P.size()); -+ -+ std::vector ptr_Q_host(problem_count()); -+ std::vector ptr_K_host(problem_count()); -+ std::vector ptr_P_host(problem_count()); -+ std::vector ptr_V_host(problem_count()); -+ std::vector ptr_O_host(problem_count()); -+ std::vector ptr_O_accumulate_host(problem_count()); -+ std::vector ptr_norm_host(problem_count()); -+ std::vector ptr_sum_host(problem_count()); -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ ptr_Q_host.at(i) = block_Q.get() + offset_Q.at(i); -+ ptr_K_host.at(i) = block_K.get() + offset_K.at(i); -+ ptr_P_host.at(i) = block_P.get() + offset_P.at(i); -+ ptr_V_host.at(i) = block_V.get() + offset_V.at(i); -+ ptr_O_host.at(i) = block_O.get() + offset_O.at(i); -+ -+ if (kNeedsOutputAccumulatorBuffer) { -+ ptr_O_accumulate_host.at(i) = block_O_accumulate.get() + offset_O.at(i); -+ } -+ } -+ -+ ptr_Q.reset(problem_count()); -+ ptr_Q.copy_from_host(ptr_Q_host.data()); -+ -+ ptr_K.reset(problem_count()); -+ ptr_K.copy_from_host(ptr_K_host.data()); -+ -+ ptr_P.reset(problem_count()); -+ ptr_P.copy_from_host(ptr_P_host.data()); -+ -+ ptr_V.reset(problem_count()); -+ ptr_V.copy_from_host(ptr_V_host.data()); -+ -+ ptr_O.reset(problem_count()); -+ ptr_O.copy_from_host(ptr_O_host.data()); -+ -+ if (kNeedsOutputAccumulatorBuffer) { -+ ptr_O_accumulate.reset(problem_count()); -+ ptr_O_accumulate.copy_from_host(ptr_O_accumulate_host.data()); -+ } -+ -+ // -+ // Initialize the problems of the workspace -+ // -+ -+ initialize_tensor_(block_Q.get(), total_elements_Q, 
init_Q, seed + 1); -+ initialize_tensor_(block_K.get(), total_elements_K, init_K, seed + 2); -+ initialize_tensor_(block_V.get(), total_elements_V, init_V, seed + 3); -+ -+ } -+ -+ template -+ bool verify_tensor_(std::vector vector_Input, \ -+ std::vector vector_Input_Ref, -+ int64_t verify_length = -1) { -+ -+ int64_t size = (vector_Input.size() < vector_Input_Ref.size()) ? vector_Input.size() : vector_Input_Ref.size(); -+ size = (verify_length == -1) ? size : verify_length; -+ -+ // 0.05 for absolute error -+ float abs_tol = 5e-2f; -+ // 10% for relative error -+ float rel_tol = 1e-1f; -+ for (int64_t i = 0; i < size; ++i) { -+ float diff = (float)(vector_Input.at(i) - vector_Input_Ref.at(i)); -+ float abs_diff = fabs(diff); -+ float abs_ref = fabs((float)vector_Input_Ref.at(i) + 1e-5f); -+ float relative_diff = abs_diff / abs_ref; -+ if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > abs_tol && relative_diff > rel_tol)) { -+ printf("[%d/%d] diff = %f, rel_diff = %f, {computed=%f, ref=%f}.\n", int(i), int(size), abs_diff, relative_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i))); -+ return false; -+ } -+ -+ } -+ -+ return true; -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify_() { -+ -+ bool passed = true; -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ cutlass::gemm::GemmCoord problem0 = options.problem_sizes0.at(i); -+ cutlass::gemm::GemmCoord problem1 = options.problem_sizes1.at(i); -+ -+ LayoutQ layout_Q(ldq_host.at(i)); -+ LayoutK layout_K(ldk_host.at(i)); -+ LayoutP layout_P(ldp_host.at(i)); -+ LayoutV layout_V(ldv_host.at(i)); -+ LayoutO layout_O(ldo_host.at(i)); -+ -+ MatrixCoord extent_Q{problem0.m(), problem0.k()}; -+ MatrixCoord extent_K{problem0.k(), problem0.n()}; -+ MatrixCoord extent_P{problem0.m(), problem0.n()}; -+ MatrixCoord extent_V{problem1.k(), problem1.n()}; -+ MatrixCoord extent_O{problem1.m(), problem1.n()}; -+ -+ cutlass::TensorView view_Q(block_Q.get() + offset_Q.at(i), layout_Q, extent_Q); -+ cutlass::TensorView view_K(block_K.get() + offset_K.at(i), layout_K, extent_K); -+ cutlass::TensorView view_P(block_P.get() + offset_P.at(i), layout_P, extent_P); -+ cutlass::TensorView view_V(block_V.get() + offset_V.at(i), layout_V, extent_V); -+ -+ cutlass::DeviceAllocation block_Ref(layout_P.capacity(extent_P)); -+ cutlass::TensorView view_Ref_device(block_Ref.get(), layout_P, extent_P); -+ -+ cutlass::DeviceAllocation block_Ref_O(layout_O.capacity(extent_O)); -+ cutlass::TensorView view_Ref_O_device(block_Ref_O.get(), layout_O, extent_O); -+ cutlass::reference::device::TensorFill(view_Ref_O_device, ElementO(0)); -+ -+ // Reference GEMM -+ cutlass::reference::device::GemmComplex< -+ ElementQ, LayoutQ, -+ ElementK, LayoutK, -+ ElementP, LayoutP, -+ ElementCompute, ElementAccumulator -+ >( -+ problem0, -+ ElementAccumulator(options.alpha0), -+ view_Q, -+ Attention::GemmKernel::MM0::Mma::kTransformA, -+ view_K, -+ Attention::GemmKernel::MM0::Mma::kTransformB, -+ ElementAccumulator(options.beta), -+ view_P, -+ view_Ref_device, -+ ElementAccumulator(0) -+ ); -+ -+ // Compute softmax for P. We need to explicitly compute softmax -+ // over P because softmax is fused to the second GEMM in the -+ // profiled implementation. 
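    // Concretely, for every row m of S = alpha0 * (Q . K^T), the reference below computes
    //   P[m, n] = exp(S[m, n] - max_j S[m, j]) / sum_j exp(S[m, j] - max_j S[m, j]),
    // where j (and n) range over the first n_dim columns, truncated to min(m + 1, n_dim)
    // when --causal is set; the remaining columns are zeroed.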
-+ std::vector matrix_Ref(layout_P.capacity(extent_P)); -+ cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_Ref.size()); -+ cutlass::TensorView view_Ref_host(matrix_Ref.data(), layout_P, extent_P); -+ std::vector vector_Norm_Ref(problem0.m()); -+ std::vector vector_Sum_Ref(problem0.m()); -+ -+ int n_dim = options.use_mask ? options.problem_sizes0_real.at(i).n() : problem0.n(); -+ -+ // Compute softmax for reference matrix -+ for (int m = 0; m < problem0.m(); m++) { -+ int n_dim_row = n_dim; -+ if (options.causal) { -+ n_dim_row = std::min(m + 1, n_dim); -+ } -+ ElementSoftmaxCompute max = ElementSoftmaxCompute(view_Ref_host.ref().at({m, 0})); -+ for (int n = 1; n < n_dim_row; n++) { -+ max = std::max(max, ElementSoftmaxCompute(view_Ref_host.ref().at({m, n}))); -+ } -+ -+ vector_Norm_Ref.at(m) = ElementNorm(max); -+ -+ ElementSoftmaxCompute sum = ElementSoftmaxCompute(); -+ for (int n = 0; n < n_dim_row; n++) { -+ sum += std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ); -+ } -+ ElementSoftmaxCompute inv_sum = ElementSoftmaxCompute(1.0f / sum); -+ -+ vector_Sum_Ref.at(m) = ElementSum(inv_sum); -+ -+ for (int n = 0; n < n_dim_row; n++) { -+ view_Ref_host.ref().at({m, n}) = ElementP( -+ std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ) * inv_sum -+ ); -+ } -+ // Mask out the rest of the attention matrix -+ for (int n = n_dim_row; n < n_dim; ++n) { -+ view_Ref_host.ref().at({m, n}) = ElementP(0); -+ } -+ -+ } -+ -+ // when not using mask, problem_real and problem share the same sizes -+ if (options.use_mask) { -+ for (int m = 0; m < problem0.m(); m++) { -+ for (int n = n_dim; n < problem0.n(); n++) { -+ view_Ref_host.ref().at({m, n}) = ElementP(0); -+ } -+ } -+ } -+ -+ cutlass::device_memory::copy_to_device(block_P.get() + offset_P.at(i), matrix_Ref.data(), matrix_Ref.size()); -+ -+ // Reference GEMM -+ cutlass::reference::device::GemmComplex< -+ ElementP, LayoutP, -+ ElementV, LayoutV, -+ ElementO, LayoutO, -+ ElementCompute, ElementAccumulator -+ >( -+ problem1, -+ ElementAccumulator(options.alpha1), -+ view_P, -+ Attention::GemmKernel::MM0::Mma::kTransformA, -+ view_V, -+ Attention::GemmKernel::MM0::Mma::kTransformB, -+ ElementAccumulator(options.beta), -+ view_Ref_O_device, -+ view_Ref_O_device, -+ ElementAccumulator(0) -+ ); -+ -+ // Copy to host memory -+ cutlass::TensorView view_Ref(matrix_Ref.data(), layout_P, extent_P); -+ -+ std::vector matrix_O(layout_O.capacity(extent_O)); -+ cutlass::device_memory::copy_to_host(matrix_O.data(), block_O.get() + offset_O.at(i), matrix_O.size()); -+ std::vector matrix_Ref_O(layout_O.capacity(extent_O)); -+ cutlass::device_memory::copy_to_host(matrix_Ref_O.data(), block_Ref_O.get(), matrix_Ref_O.size()); -+ -+ -+ bool verified_O = false; -+ if (!verified_O) { -+ verified_O = verify_tensor_(matrix_O, matrix_Ref_O); -+ } -+ -+ passed = passed && verified_O; -+ -+ if (!passed) { -+ std::cerr << "\n***\nError - problem " << i << " failed the QA check\n***\n" << std::endl; -+ -+ if (!verified_O) { -+ std::cout << "Final matrix output is incorrect" << std::endl; -+ } -+ -+ return passed; -+ } -+ -+ } -+ -+ return passed; -+ } -+ -+public: -+ -+ -+ /// Executes a CUTLASS Attention kernel and measures runtime. 
-+ Result profile() { -+ -+ Result result; -+ result.passed = false; -+ -+ int threadblock_count = Attention::sufficient(options.problem_sizes1.data(), options.problem_count); -+ -+ // Early exit -+ if (!threadblock_count) { -+ std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped FMHA kernel." << std::endl; -+ return result; -+ } -+ -+ result.passed = false; -+ -+ // Initialize the problem -+ initialize_(); -+ -+ typename Attention::Arguments args( -+ problem_sizes_device0.get(), -+ problem_sizes_device1.get(), -+ options.problem_count, -+ threadblock_count, -+ ptr_Q.get(), -+ ptr_K.get(), -+ ptr_P.get(), -+ ptr_V.get(), -+ ptr_O.get(), -+ ptr_O_accumulate.get(), -+ ldq.get(), -+ ldk.get(), -+ ldp.get(), -+ ldv.get(), -+ ldo.get(), -+ options.causal, -+ options.problem_sizes1.data() -+ ); -+ -+ Attention fmha; -+ -+ size_t workspace_size = fmha.get_workspace_size(args); -+ cutlass::DeviceAllocation workspace(workspace_size); -+ -+ result.status = fmha.initialize(args, workspace.get()); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to initialize CUTLASS Grouped FMHA kernel." << std::endl; -+ return result; -+ } -+ -+ // Run the grouped FMHA object -+ result.status = fmha.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS Grouped FMHA kernel." << std::endl; -+ return result; -+ } -+ -+ // Wait for completion -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Verify correctness -+ // -+ result.passed = true; -+ -+ if (options.reference_check) { -+ result.passed = verify_(); -+ } -+ -+ // -+ // Warm-up run of the grouped FMHA object -+ // -+ result.status = fmha.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS Grouped FMHA kernel." << std::endl; -+ return result; -+ } -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of FMHA operations -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < this->options.iterations; ++iter) { -+ fmha(); -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMM operations have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. 
-+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(this->options.iterations); -+ result.gflops = this->options.gflops(result.runtime_ms / 1000.0); -+ -+ // -+ // Cleanup -+ // -+ -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ std::cout << std::endl; -+ std::cout << "CUTLASS Attention:\n" -+ << "====================================================" << std::endl; -+ std::cout << " " << " {seq length Q, seq length KV, head size, head size V, head number, batch size} = {" << options.seq_length \ -+ << ", " << options.seq_length_kv << ", " << options.head_size << ", " << options.head_size_v << ", " << options.head_number\ -+ << ", " << options.batch_size << "}." << std::endl; -+ options.print_problems(); -+ std::cout << std::endl; -+ std::cout << " " << "Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " " << "GFLOPs: " << result.gflops << std::endl; -+ -+ return result; -+ } -+ -+ -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ int kQueriesPerBlock, -+ int kKeysPerBlock, -+ bool kSingleValueIteration, -+ cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_ -+> -+int run_grouped(Options& options) { -+ using AttentionKernel = typename cutlass::gemm::kernel::DefaultFMHAGrouped< -+ cutlass::half_t, // scalar_t -+ cutlass::arch::Sm80, // ArchTag -+ true, // Memory is aligned -+ kQueriesPerBlock, -+ kKeysPerBlock, -+ kSingleValueIteration, -+ GroupScheduleMode_ -+ >::FMHAKernel; -+ -+ using FMHA = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test and profile -+ // -+ -+ TestbedAttention testbed(options); -+ -+ Result result = testbed.profile(); -+ if (!result.passed) { -+ std::cout << "Profiling CUTLASS attention has failed.\n"; -+ std::cout << "\nFailed\n"; -+ return -1; -+ } -+ -+ std::cout << "\nPassed\n"; -+ return 0; -+} -+ -+ -+template < -+ int kQueriesPerBlock, -+ int kKeysPerBlock, -+ bool kSingleValueIteration -+> -+int run_attention(Options& options) { -+ if (options.scheduler_mode == cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly) { -+ return run_grouped(options); -+ } else { -+ return run_grouped(options); -+ } -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ // -+ // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. -+ // -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { -+ -+ // -+ // This example requires an NVIDIA Ampere-architecture GPU. 
-+ // -+ -+ std::cout -+ << "CUTLASS's CUTLASS Attention example requires a GPU of NVIDIA's Ampere Architecture or " -+ << "later (compute capability 80 or greater).\n"; -+ -+ return 0; -+ } -+ -+ // -+ // Parse options -+ // -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.error) { -+ std::cerr << "Aborting execution." << std::endl; -+ return -1; -+ } -+ -+ if (options.use_mask) { -+ std::cerr << "--use_mask is not supported at the moment\n"; -+ return -2; -+ } -+ if (options.alignment != 1) { -+ std::cerr << "--alignment=1 is the only supported value\n"; -+ return -2; -+ } -+ -+ // Determine kernel configuration based on head size. -+ // If head size is less than or equal to 64, each block operates over 64 queries and -+ // 64 keys, and parital results can be stored in the register file. -+ // If head size is greater than 64, each block operates over 32 queries and 128 keys, -+ // and partial results are stored in shared memory. -+ if (options.head_size_v > 64) { -+ static int const kQueriesPerBlock = 32; -+ static int const kKeysPerBlock = 128; -+ if (options.head_size_v <= kKeysPerBlock) { -+ return run_attention(options); -+ } else { -+ return run_attention(options); -+ } -+ } else { -+ static int const kQueriesPerBlock = 64; -+ static int const kKeysPerBlock = 64; -+ return run_attention(options); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma.h -new file mode 100644 -index 0000000..7326bad ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma.h -@@ -0,0 +1,124 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holdvr nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "custom_mma_multistage.h" -+#include "custom_mma_pipelined.h" -+#include "cutlass/gemm/threadblock/mma_multistage.h" -+#include "cutlass/gemm/threadblock/mma_pipelined.h" -+ -+template -+struct MakeCustomMma; -+ -+template < -+ typename Shape, -+ typename IteratorA, -+ typename SmemIteratorA, -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ typename IteratorB, -+ typename SmemIteratorB, -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ typename ElementC, -+ typename LayoutC, -+ typename Policy, -+ int Stages, -+ cutlass::gemm::SharedMemoryClearOption SharedMemoryClear, -+ int kMaxK> -+struct MakeCustomMma< -+ cutlass::gemm::threadblock::MmaMultistage< -+ Shape, -+ IteratorA, -+ SmemIteratorA, -+ CacheOpA, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ ElementC, -+ LayoutC, -+ Policy, -+ Stages, -+ SharedMemoryClear>, -+ kMaxK> { -+ // Reduce the number of stages if we don't need that many -+ static int constexpr kStages = -+ kMaxK == cutlass::platform::numeric_limits::max() -+ ? Stages -+ : cutlass::const_min( -+ Stages, -+ (kMaxK + int(Shape::kK) - 1) / int(Shape::kK)); -+ using Mma = cutlass::gemm::threadblock::CustomMmaMultistage< -+ Shape, -+ IteratorA, -+ SmemIteratorA, -+ CacheOpA, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ ElementC, -+ LayoutC, -+ Policy, -+ kStages, -+ SharedMemoryClear, -+ kMaxK>; -+}; -+ -+template < -+ typename Shape, -+ typename IteratorA, -+ typename SmemIteratorA, -+ typename IteratorB, -+ typename SmemIteratorB, -+ typename ElementC, -+ typename LayoutC, -+ typename Policy, -+ int kMaxK> -+struct MakeCustomMma< -+ cutlass::gemm::threadblock::MmaPipelined< -+ Shape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ Policy>, -+ kMaxK> { -+ using Mma = cutlass::gemm::threadblock::CustomMmaPipelined< -+ Shape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ Policy>; -+}; -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_base.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_base.h -new file mode 100644 -index 0000000..6c6d078 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_base.h -@@ -0,0 +1,183 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights -+ *reserved. SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, -+ *this list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -+ *POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class CustomMmaBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. 
-+ using WarpGemm = typename Policy::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount = GemmShape< -+ Shape::kM / WarpGemm::kM, -+ Shape::kN / WarpGemm::kN, -+ Shape::kK / WarpGemm::kK>; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations = -+ (WarpGemm::kK / Operator::Policy::MmaShape::kK); -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ template -+ struct OperandSharedStorage { -+ AlignedBuffer buffer; -+ using TensorRef = TensorRef; -+ -+ CUTLASS_DEVICE -+ static OperandLayout Layout() { -+ return OperandLayout::packed({OperandShape::kRow, OperandShape::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the operand -+ CUTLASS_HOST_DEVICE -+ TensorRef ref() { -+ return TensorRef{buffer.data(), Layout()}; -+ } -+ }; -+ -+ /// Shape of the A matrix operand in shared memory -+ using ShapeA = MatrixShape< -+ Shape::kM + Policy::SmemPaddingA::kRow, -+ Shape::kK * kStages + Policy::SmemPaddingA::kColumn>; -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = MatrixShape< -+ Shape::kK * kStages + Policy::SmemPaddingB::kRow, -+ Shape::kN + Policy::SmemPaddingB::kColumn>; -+ -+ using SharedStorageA = OperandSharedStorage< -+ typename Operator::ElementA, -+ ShapeA, -+ typename Operator::LayoutA>; -+ using SharedStorageB = OperandSharedStorage< -+ typename Operator::ElementB, -+ ShapeB, -+ typename Operator::LayoutB>; -+ using TensorRefA = typename SharedStorageA::TensorRef; -+ using TensorRefB = typename SharedStorageB::TensorRef; -+ -+ struct SharedStorage { -+ /// Buffer for A operand -+ SharedStorageA operand_A; -+ -+ /// Buffer for B operand -+ SharedStorageB operand_B; -+ }; -+ -+ protected: -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A operand from shared memory -+ typename Operator::IteratorA warp_tile_iterator_A_; -+ -+ /// Iterator to load a warp-scoped tile of B operand from shared memory -+ typename Operator::IteratorB warp_tile_iterator_B_; -+ -+ public: -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ CustomMmaBase( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ SharedStorageA& shared_storageA, -+ SharedStorageB& shared_storageB, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : warp_tile_iterator_A_(shared_storageA.ref(), lane_idx), -+ warp_tile_iterator_B_(shared_storageB.ref(), lane_idx) {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h -new file mode 100644 -index 0000000..e5cdc88 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h -@@ -0,0 +1,767 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights -+ *reserved. 
SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, -+ *this list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -+ *POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/cache_operation.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "custom_mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. 
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, -+ /// Upper boundon the K dimension -+ int kMaxK = cutlass::platform::numeric_limits::max(), -+ /// Used for partial specialization -+ typename Enable = bool> -+class CustomMmaMultistage : public CustomMmaBase { -+ public: -+ ///< Base class -+ using Base = CustomMmaBase; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ /// Internal structure exposed for introspection. 
-+ struct Detail { -+ static_assert( -+ Base::kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA = -+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / -+ Base::kWarpGemmIterations; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / -+ Base::kWarpGemmIterations; -+ }; -+ -+ static bool const kSmemContainsEntireMat = kMaxK <= Shape::kK * Stages; -+ static constexpr int kNumStagesConcurrentLoad = -+ kSmemContainsEntireMat ? Stages : Stages - 1; -+ -+ private: -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+ bool prologue_done_; -+ -+ // Set to `True` to ensure the accumulator will be zero outside the GEMM -+ // footprint -+ bool zero_outside_bounds_; -+ -+ public: -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ CustomMmaMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorageA& shared_storageA, -+ typename Base::SharedStorageB& shared_storageB, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storageA.ref(), thread_idx), -+ smem_iterator_B_(shared_storageB.ref(), thread_idx), -+ prologue_done_(false), -+ zero_outside_bounds_(false) { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ CUTLASS_DEVICE -+ CustomMmaMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ 
typename Base::SharedStorage& st, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : CustomMmaMultistage( -+ st.operand_A, -+ st.operand_B, -+ thread_idx, -+ warp_idx, -+ lane_idx) {} -+ -+ CUTLASS_DEVICE -+ bool set_prologue_done(bool value) { -+ prologue_done_ = value; -+ } -+ -+ CUTLASS_DEVICE -+ bool set_zero_outside_bounds(bool value) { -+ zero_outside_bounds_ = value; -+ } -+ -+ template -+ CUTLASS_DEVICE static void prologue( -+ typename Base::SharedStorage& shared_storage, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ int thread_idx, -+ int problem_size_k) { -+ prologue( -+ shared_storage.operand_A, -+ shared_storage.operand_B, -+ iterator_A, -+ iterator_B, -+ thread_idx, -+ problem_size_k); -+ } -+ -+ template -+ CUTLASS_DEVICE static void prologue( -+ typename Base::SharedStorageA& shared_storageA, -+ typename Base::SharedStorageB& shared_storageB, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ int thread_idx, -+ int problem_size_k) { -+ SmemIteratorA smem_iterator_A(shared_storageA.ref(), thread_idx); -+ SmemIteratorB smem_iterator_B(shared_storageB.ref(), thread_idx); -+ int32_t iter = (problem_size_k + Base::Shape::kK - 1) / Base::Shape::kK; -+ _prologue( -+ iterator_A, iterator_B, iter, smem_iterator_A, smem_iterator_B); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance( -+ IteratorA& iterator_A, -+ IteratorB& iterator_B, -+ int group_start_A = 0, -+ int group_start_B = 0) { -+ iterator_A.set_iteration_index( -+ group_start_A * IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { -+ typename IteratorA::AccessType* dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_A.get(); -+ -+ if (zero_outside_bounds_ || -+ SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ iterator_B.set_iteration_index( -+ group_start_B * IteratorB::kAccessesPerVector); -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType* dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B.get(); -+ -+ if (zero_outside_bounds_ || -+ SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ 
cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_B.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B.valid()); -+ } -+ -+ ++iterator_B; -+ } -+ ++this->smem_iterator_B_; -+ } -+ } -+ } -+ -+ template -+ CUTLASS_DEVICE static void _prologue( -+ IteratorA& iterator_A, -+ IteratorB& iterator_B, -+ int32_t& gemm_k_iterations, -+ SmemIteratorA& smem_iterator_A_, -+ SmemIteratorB& smem_iterator_B_) { -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < kNumStagesConcurrentLoad; -+ ++stage, --gemm_k_iterations) { -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ iterator_A.set_iteration_index(0); -+ smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType* dst_ptr = -+ reinterpret_cast( -+ smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); -+ -+ if (kLoadA) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ } -+ -+ ++iterator_A; -+ } -+ -+ ++smem_iterator_A_; -+ } -+ -+ iterator_B.set_iteration_index(0); -+ smem_iterator_B_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType* dst_ptr = -+ reinterpret_cast( -+ smem_iterator_B_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ if (kLoadB) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B.get(), iterator_B.valid()); -+ } -+ -+ ++iterator_B; -+ } -+ -+ ++smem_iterator_B_; -+ } -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ -+ smem_iterator_A_.add_tile_offset({0, 1}); -+ smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC& accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ ///< initial value of accumulator -+ FragmentC const& src_accum) { -+ // -+ // Prologue -+ // -+ -+ if (!prologue_done_) { -+ _prologue( -+ iterator_A, -+ iterator_B, -+ gemm_k_iterations, -+ smem_iterator_A_, -+ smem_iterator_B_); -+ } else if (!kSmemContainsEntireMat) { -+ _prologue( -+ iterator_A, -+ iterator_B, -+ gemm_k_iterations, -+ smem_iterator_A_, -+ smem_iterator_B_); -+ } else { -+ gemm_k_iterations -= kNumStagesConcurrentLoad; -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ // -+ // Clear the remaining tiles of SMEM. This is a functional requirement for -+ // some kernels so that all accumulator elements outside the GEMM footprint -+ // are zero. 
-+ // -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { -+ /// Iterator to write threadblock-scoped tile of A operand to shared -+ /// memory -+ SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); -+ -+ typename IteratorA::AccessType zero_A; -+ zero_A.clear(); -+ -+ last_smem_iterator_A.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType* dst_ptr = -+ reinterpret_cast( -+ last_smem_iterator_A.get()); -+ -+ *dst_ptr = zero_A; -+ -+ ++last_smem_iterator_A; -+ } -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared -+ /// memory -+ SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); -+ typename IteratorB::AccessType zero_B; -+ -+ zero_B.clear(); -+ last_smem_iterator_B.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType* dst_ptr = -+ reinterpret_cast( -+ last_smem_iterator_B.get()); -+ -+ *dst_ptr = zero_B; -+ -+ ++last_smem_iterator_B; -+ } -+ } -+ -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B[2]; -+ -+ Operator warp_mma; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform( -+ warp_transformed_frag_A[0], -+ warp_transformed_frag_B[0], -+ warp_loaded_frag_A[0], -+ warp_loaded_frag_B[0]); -+ -+ // tf32x3 kernels use staging accumulation. warp_mma uses a temporary -+ // accumulator and this temporary accumulator is added to the final -+ // accumulator once in every mainloop iteration. -+ plus plus_accum; -+ -+ FragmentC tmp_accum; -+ -+ if (platform::is_same< -+ typename Operator::MathOperator, -+ arch::OpMultiplyAddFastF32>::value || -+ platform::is_same< -+ typename Operator::MathOperator, -+ arch::OpMultiplyAddComplexFastF32>::value) { -+ tmp_accum.clear(); -+ } -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-kNumStagesConcurrentLoad);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. 
-+ -+ this->warp_tile_iterator_A_.set_kgroup_index( -+ (warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index( -+ (warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ // In case of a non-circular buffer ("kSmemContainsEntireMat") -+ // make sure we don't load out of bounds data. -+ if (!kSmemContainsEntireMat || -+ gemm_k_iterations > (-kNumStagesConcurrentLoad) || -+ warp_mma_k < Base::kWarpGemmIterations - 1) { -+ this->warp_tile_iterator_A_.load( -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load( -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ } -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k > 0) -+ warp_mma.transform( -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B[warp_mma_k % 2]); -+ -+ if (platform::is_same< -+ typename Operator::MathOperator, -+ arch::OpMultiplyAddFastF32>::value || -+ platform::is_same< -+ typename Operator::MathOperator, -+ arch::OpMultiplyAddComplexFastF32>::value) { -+ warp_mma( -+ tmp_accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ tmp_accum); -+ -+ if (warp_mma_k == 0) { -+ accum = plus_accum(accum, tmp_accum); -+ tmp_accum.clear(); -+ } -+ } else { -+ warp_mma( -+ accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ accum); -+ } -+ -+ // Issue global->shared copies for the this stage -+ if (!kSmemContainsEntireMat && -+ warp_mma_k < Base::kWarpGemmIterations - 1) { -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance( -+ iterator_A, -+ iterator_B, -+ group_start_iteration_A, -+ group_start_iteration_B); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ if (!kSmemContainsEntireMat) { -+ int group_start_iteration_A, group_start_iteration_B; -+ group_start_iteration_A = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance( -+ iterator_A, -+ iterator_B, -+ group_start_iteration_A, -+ group_start_iteration_B); -+ } -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. 
-+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (!kSmemContainsEntireMat && -+ smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -+ -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) -+ warp_mma.transform( -+ warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ } -+ } -+ -+ if (platform::is_same< -+ typename Operator::MathOperator, -+ arch::OpMultiplyAddFastF32>::value || -+ platform::is_same< -+ typename Operator::MathOperator, -+ arch::OpMultiplyAddComplexFastF32>::value) { -+ accum = plus_accum(accum, tmp_accum); -+ } -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ // commit and drain all pending and predicated cp.async pnz from the GEMM -+ // mainloop -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h -new file mode 100644 -index 0000000..73112e9 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h -@@ -0,0 +1,401 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights -+ *reserved. SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, -+ *this list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -+ *POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "custom_mma_base.h" -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. 
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Transformation applied to A operand -+ typename TransformA_ = NumericArrayConverter< -+ typename SmemIteratorA_::Element, -+ typename IteratorA_::Element, -+ IteratorA_::Fragment::kElements>, -+ /// -+ /// Transformation applied to B operand -+ typename TransformB_ = NumericArrayConverter< -+ typename SmemIteratorB_::Element, -+ typename IteratorB_::Element, -+ IteratorB_::Fragment::kElements>, -+ /// Used for partial specialization -+ typename Enable = bool> -+class CustomMmaPipelined : public CustomMmaBase { -+ public: -+ ///< Base class -+ using Base = CustomMmaBase; -+ -+ using Shape = -+ Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorA = -+ IteratorA_; ///< Iterates over tiles of A operand in global memory -+ using IteratorB = -+ IteratorB_; ///< Iterates over tiles of B operand in global memory -+ using ElementC = ElementC_; ///< Data type of accumulator matrix -+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix -+ using Policy = Policy_; ///< Policy describing tuning details -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of operand A loaded from global memory -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Obtain the arch tag from the warp-level operator -+ using ArchTag = typename Policy::Operator::ArchTag; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) -+ static_assert( -+ (Base::kStages == 2), -+ "MmaPipelined requires kStages set to value 2"); -+ -+ static bool const kSmemContainsEntireMat = false; -+ -+ private: -+ using WarpFragmentA = typename Operator::FragmentA; -+ using WarpFragmentB = typename Operator::FragmentB; -+ -+ protected: -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+ public: -+ /// Construct from 
tensor references -+ CUTLASS_DEVICE -+ CustomMmaPipelined( -+ typename Base::SharedStorageA& shared_storageA, -+ typename Base::SharedStorageB& shared_storageB, -+ int thread_idx, ///< ID within the threadblock -+ int warp_idx, ///< ID of warp -+ int lane_idx ///< ID of each thread within a warp -+ ) -+ : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storageA.ref(), thread_idx), -+ smem_iterator_B_(shared_storageB.ref(), thread_idx) { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ CUTLASS_DEVICE -+ CustomMmaPipelined( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage& st, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : CustomMmaPipelined( -+ st.operand_A, -+ st.operand_B, -+ thread_idx, -+ warp_idx, -+ lane_idx) {} -+ -+ CUTLASS_DEVICE -+ bool set_prologue_done(bool value) { -+ // NOT IMPLEMENTED FOR PIPELINED -+ } -+ -+ CUTLASS_DEVICE -+ bool set_zero_outside_bounds(bool value) { -+ // NOT NEEDED FOR PIPELINED -+ // shared memory will always be zero-filled -+ } -+ -+ template -+ CUTLASS_DEVICE static void prologue( -+ typename Base::SharedStorage& shared_storage, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ int thread_idx, -+ int problem_size_k) { -+ prologue( -+ shared_storage.operand_A, -+ shared_storage.operand_B, -+ iterator_A, -+ iterator_B, -+ thread_idx, -+ problem_size_k); -+ } -+ -+ template -+ CUTLASS_DEVICE static void prologue( -+ typename Base::SharedStorageA& shared_storageA, -+ typename Base::SharedStorageB& shared_storageB, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ int thread_idx, -+ int problem_size_k) { -+ // NOT IMPLEMENTED FOR PIPELINED -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ int gemm_k_iterations, ///< number of iterations of the mainloop -+ FragmentC& accum, ///< destination accumulator tile -+ IteratorA iterator_A, ///< iterator over A operand in global memory -+ IteratorB iterator_B, ///< iterator over B operand in global memory -+ FragmentC const& src_accum, ///< source accumulator tile -+ TransformA transform_A = -+ TransformA(), ///< transformation applied to A fragment -+ TransformB transform_B = -+ TransformB()) { ///< transformation applied to B fragment -+ -+ // -+ // Prologue -+ // -+ -+ // Perform accumulation in the 'd' output operand -+ accum = 
src_accum; -+ -+ FragmentA tb_frag_A; -+ FragmentB tb_frag_B; -+ -+ tb_frag_A.clear(); -+ tb_frag_B.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_A.load(tb_frag_A); -+ iterator_B.load(tb_frag_B); -+ -+ ++iterator_A; -+ ++iterator_B; -+ -+ this->smem_iterator_A_.store(transform_A(tb_frag_A)); -+ this->smem_iterator_B_.store(transform_B(tb_frag_B)); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpFragmentA warp_frag_A[2]; -+ WarpFragmentB warp_frag_B[2]; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ Operator warp_mma; -+ -+ int smem_write_stage_idx = 1; -+ -+ // Avoid reading out of bounds -+ iterator_A.clear_mask(gemm_k_iterations <= 1); -+ iterator_B.clear_mask(gemm_k_iterations <= 1); -+ -+ // Issue loads during the first warp-level matrix multiply-add *AFTER* -+ // issuing shared memory loads (which have the tighest latency requirement). -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > 0; --gemm_k_iterations) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations - 1) { -+ // Write fragments to shared memory -+ this->smem_iterator_A_.store(transform_A(tb_frag_A)); -+ -+ this->smem_iterator_B_.store(transform_B(tb_frag_B)); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B_; -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ } else { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -+ -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ } -+ -+ this->warp_tile_iterator_A_.set_kgroup_index( -+ (warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index( -+ (warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k == 0) { -+ iterator_A.load(tb_frag_A); -+ iterator_B.load(tb_frag_B); -+ -+ ++iterator_A; -+ ++iterator_B; -+ -+ // Avoid reading out of bounds if this was the last loop iteration -+ iterator_A.clear_mask(gemm_k_iterations <= 2); -+ iterator_B.clear_mask(gemm_k_iterations <= 2); -+ } -+ -+ warp_mma( -+ accum, -+ warp_frag_A[warp_mma_k % 2], -+ warp_frag_B[warp_mma_k % 2], -+ accum); -+ } -+ } -+ } -+}; -+ 
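Editorial aside, not part of the recorded patch: the CustomMmaPipelined mainloop above follows the classic two-stage (double-buffered) schedule enforced by its static_assert on kStages == 2; while the warp-level MMA consumes the tile already resident in one shared-memory stage, the next K-tile is staged into the other, and the roles swap each iteration. The host-side C++ sketch below illustrates only that scheduling idea under simplifying assumptions (scalar "tiles", no warps, no cp.async); the names kTileK, stage, and buf are invented for the illustration and are not CUTLASS API.

// Conceptual sketch of a kStages == 2 double-buffered mainloop (illustrative only).
#include <array>
#include <cstdio>
#include <numeric>
#include <utility>
#include <vector>

int main() {
  constexpr int kTileK = 4;                  // elements per K-tile in this toy example
  std::vector<int> A(16), B(16);             // toy operands, four K-tiles each
  std::iota(A.begin(), A.end(), 0);
  std::iota(B.begin(), B.end(), 1);
  const int k_tiles = static_cast<int>(A.size()) / kTileK;

  using Tile = std::array<std::pair<int, int>, kTileK>;
  Tile buf[2];                               // stands in for the two shared-memory stages
  auto stage = [&](int tile, Tile& dst) {    // stands in for the global->shared copy of one K-tile
    for (int i = 0; i < kTileK; ++i)
      dst[i] = {A[tile * kTileK + i], B[tile * kTileK + i]};
  };

  long long accum = 0;
  stage(0, buf[0]);                          // prologue: the first K-tile is staged up front
  for (int t = 0; t < k_tiles; ++t) {
    int read = t % 2, write = (t + 1) % 2;
    if (t + 1 < k_tiles)
      stage(t + 1, buf[write]);              // overlap: stage the next tile into the other buffer
    for (auto [a, b] : buf[read])            // compute on the tile that is already resident
      accum += static_cast<long long>(a) * b;
  }
  std::printf("accum = %lld\n", accum);      // matches a plain dot product of A and B
  return 0;
}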
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm_kernel_utils.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm_kernel_utils.h -new file mode 100644 -index 0000000..1930717 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm_kernel_utils.h -@@ -0,0 +1,295 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holdvr nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "cutlass/arch/mma.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Some helper functions -+//////////////////////////////////////////////////////////////////////////////// -+#define DISPATCH_TYPES(tensor, func) \ -+ { \ -+ if (query.scalar_type() == at::ScalarType::Float) { \ -+ using scalar_t = float; \ -+ func(); \ -+ } else if (query.scalar_type() == at::ScalarType::Half) { \ -+ using scalar_t = cutlass::half_t; \ -+ func(); \ -+ } else if (query.scalar_type() == at::ScalarType::BFloat16) { \ -+ using scalar_t = cutlass::bfloat16_t; \ -+ func(); \ -+ } else { \ -+ TORCH_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \ -+ } \ -+ } -+ -+#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \ -+ { \ -+ if (BOOL_V) { \ -+ constexpr bool BOOL_NAME = true; \ -+ F(); \ -+ } else { \ -+ constexpr bool BOOL_NAME = false; \ -+ F(); \ -+ } \ -+ } -+#define DISPATCH_ARCHTAG(CC, func) \ -+ { \ -+ if (CC >= 80) { \ -+ using ArchTag = cutlass::arch::Sm80; \ -+ func(); \ -+ } else if (CC >= 75) { \ -+ using ArchTag = cutlass::arch::Sm75; \ -+ func(); \ -+ } else if (CC >= 70) { \ -+ using ArchTag = cutlass::arch::Sm70; \ -+ func(); \ -+ } else if (CC >= 50) { \ -+ using ArchTag = cutlass::arch::Sm50; \ -+ func(); \ -+ } else { \ -+ TORCH_CHECK( \ -+ false, \ -+ "Your device is too old. We require compute capability >= 50"); \ -+ } \ -+ } -+ -+#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ -+ TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ -+ TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ -+ TORCH_CHECK(TENSOR.is_contiguous()); -+ -+#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ -+ TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ -+ TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ -+ TORCH_CHECK( \ -+ TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous"); -+ -+#ifdef HAS_PYTORCH -+#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ -+ TORCH_CHECK(uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") -+#define XFORMERS_CHECK TORCH_CHECK -+#elif defined(__CUDACC_RTC__) -+#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ -+ if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \ -+ return false; \ -+ } -+#define XFORMERS_CHECK(COND, ERR) \ -+ if (!(COND)) { \ -+ return false; \ -+ } -+#else -+#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ -+ if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \ -+ std::cerr << #PTR " is not correctly aligned\n"; \ -+ return false; \ -+ } -+#define XFORMERS_CHECK(COND, ERR) \ -+ if (!(COND)) { \ -+ std::cerr << #COND " failed\n"; \ -+ return false; \ -+ } -+#endif -+ -+#define ASSIGN_CHECK_OVERFLOW(A, B) \ -+ { \ -+ A = B; \ -+ TORCH_CHECK( \ -+ B < cutlass::platform::numeric_limits::max(), \ -+ #B " overflows"); \ -+ } -+ -+namespace gemm_kernel_utils { -+ -+#ifdef HAS_PYTORCH -+template -+struct TypeTraits; -+ -+template <> -+struct TypeTraits { -+ using scalar_t = cutlass::half_t; -+ -+ static constexpr __host__ at::ScalarType atScalarType() { -+ return at::ScalarType::Half; -+ } -+ template -+ static __host__ at::PackedTensorAccessor32 packed_accessor( -+ at::Tensor const& tensor) { -+ return at::PackedTensorAccessor32( -+ (scalar_t*)(tensor.data_ptr()), -+ tensor.sizes().data(), -+ tensor.strides().data()); -+ } -+}; -+ -+template <> -+struct TypeTraits { -+ using scalar_t = cutlass::bfloat16_t; -+ -+ 
static constexpr __host__ at::ScalarType atScalarType() { -+ return at::ScalarType::BFloat16; -+ } -+ template -+ static __host__ at::PackedTensorAccessor32 packed_accessor( -+ at::Tensor const& tensor) { -+ return at::PackedTensorAccessor32( -+ (scalar_t*)(tensor.data_ptr()), -+ tensor.sizes().data(), -+ tensor.strides().data()); -+ } -+}; -+ -+template <> -+struct TypeTraits { -+ using scalar_t = float; -+ -+ static constexpr __host__ at::ScalarType atScalarType() { -+ return at::ScalarType::Float; -+ } -+ template -+ static __host__ at::PackedTensorAccessor32 packed_accessor( -+ at::Tensor const& tensor) { -+ return tensor.packed_accessor32(); -+ } -+}; -+#endif -+ -+template -+constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) { -+ return (n + m - 1) / m; -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Determine the type of GEMM we do (TensorCores or not, Shapes ...) -+// TODO: Maybe we could rely on Cutlass's DefaultGemm templates -+//////////////////////////////////////////////////////////////////////////////// -+ -+// Fallback to Simt (FMA on cuda cores) if not in a special case below -+template -+struct DefaultGemmType { -+ static constexpr int ThreadK = 8; -+ static constexpr int WarpK = 8; -+ static constexpr int kMinimumAlignment = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using OpClass = cutlass::arch::OpClassSimt; -+ using Operator = cutlass::arch::OpMultiplyAdd; -+}; -+ -+// Specialization for tensorcores with f32 -+template -+struct DefaultGemmType< -+ ArchTag, -+ float, -+ typename cutlass::platform::enable_if< -+ ArchTag::kMinComputeCapability >= 80>::type> { -+ static constexpr int ThreadK = 32; -+ static constexpr int WarpK = 32; -+ static constexpr int kMinimumAlignment = 4; -+ using OpClass = cutlass::arch::OpClassTensorOp; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Operator = cutlass::arch::OpMultiplyAdd; // FastF32; -+}; -+ -+// Specialization for tensorcores with f16/bf16 - Sm75+ -+template -+struct DefaultGemmType< -+ ArchTag, -+ scalar_t, -+ typename cutlass::platform::enable_if< -+ ArchTag::kMinComputeCapability >= 75 && -+ cutlass::sizeof_bits::value == 16>::type> { -+ static constexpr int ThreadK = 32; -+ static constexpr int WarpK = 32; -+ static constexpr int kMinimumAlignment = 4; -+ using OpClass = cutlass::arch::OpClassTensorOp; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Operator = cutlass::arch::OpMultiplyAdd; -+}; -+ -+// Specialization for tensorcores with f16 - Volta -+template <> -+struct DefaultGemmType { -+ static constexpr int ThreadK = 32; -+ static constexpr int WarpK = 32; -+ static constexpr int kMinimumAlignment = 2; -+ using OpClass = cutlass::arch::OpClassTensorOp; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Operator = cutlass::arch::OpMultiplyAdd; -+}; -+ -+// Enables to do -+// `auto x = kCondition ? 
fa(arg) : fb(arg)` -+// when `fa` and `fb` have different types -+template -+struct call_conditional; -+ -+template -+struct call_conditional { -+ template -+ static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) -+ -> decltype(ta(arg)) { -+ return ta(arg); -+ } -+}; -+ -+template -+struct call_conditional { -+ template -+ static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) -+ -> decltype(tb(arg)) { -+ return tb(arg); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Mark a variable as warp-uniform - enables some compiler optimizations -+// The cheapest way to do it is just to broadcast it from lane 0 -+//////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_DEVICE int32_t warp_uniform(int32_t value) { -+ return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0); -+} -+ -+template -+CUTLASS_DEVICE T* warp_uniform(T* ptr) { -+ struct { -+ union { -+ T* ptr; -+ uint32_t asInt[2]; -+ }; -+ } p; -+ p.ptr = ptr; -+ p.asInt[0] = warp_uniform(p.asInt[0]); -+ p.asInt[1] = warp_uniform(p.asInt[1]); -+ return p.ptr; -+} -+} // namespace gemm_kernel_utils -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h -new file mode 100644 -index 0000000..298876e ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h -@@ -0,0 +1,752 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights -+ *reserved. SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, -+ *this list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -+ *POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Epilogue iterator that supports prefetching -+ -+ Mostly copied from "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load and store output tile from global memory in -+/// epilogue. -+/// -+/// Satisfies: ReadableTileIterator | PredicatedTileIterator | -+/// ForwardTileIterator -+/// -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ typename Element_, ///< Element data type -+ bool ScatterD = false, ///< Scatter D operand or not -+ bool UseCUDAStore = false> -+class PredicatedTileIteratorPrefetch { -+ public: -+ using ThreadMap = ThreadMap_; -+ using Shape = typename ThreadMap::Shape; -+ -+ using Element = Element_; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ static int const kThreads = ThreadMap::kThreads; -+ static int const kIterations = ThreadMap::Count::kTile; -+ -+ static_assert( -+ ThreadMap::Iterations::kRow > 0, -+ "ThreadMap::Iterations::kRow must be > 0"); -+ static_assert( -+ ThreadMap::Iterations::kGroup > 0, -+ "ThreadMap::Iterations::kGroup must be > 0"); -+ static_assert( -+ ThreadMap::Iterations::kCluster > 0, -+ "ThreadMap::Iterations::kCluster must be > 0"); -+ static_assert( -+ ThreadMap::Iterations::kColumn > 0, -+ "ThreadMap::Iterations::kColumn must be > 0"); -+ -+ /// Fragment object -+ using Fragment = Array< -+ Element, -+ ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * -+ ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * -+ ThreadMap::kElementsPerAccess>; -+ -+ /// Memory access size -+ using AccessType = AlignedArray; -+ -+ // -+ // Parameters struct -+ // -+ -+ /// Uses a non-template class -+ struct Params : PredicatedTileIteratorParams { -+ using Base = PredicatedTileIteratorParams; -+ -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : PredicatedTileIteratorParams( -+ layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, -+ make_OutputTileThreadMapDesc()) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(Base const& base) : Base(base) {} -+ }; -+ -+ /// Mask object -+ struct Mask { -+ static int const kCount = ThreadMap::Iterations::kColumn; -+ -+ /// Predicate state -+ bool predicates[kCount]; -+ -+ // -+ // Mask -+ // -+ CUTLASS_HOST_DEVICE -+ Mask() { -+ enable(); -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ 
CUTLASS_HOST_DEVICE void clear() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = false; -+ } -+ } -+ -+ ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = true; -+ } -+ } -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Parameters structure containing reference and precomputed state. -+ PredicatedTileIteratorParams params_; -+ -+ /// Byte-level pointer -+ uint8_t* byte_pointer_; -+ -+ /// Array of boolean values to contain steady-state predicates -+ Mask mask_; -+ -+ /// Extent of the matrix tile in rows -+ Index extent_row_; -+ -+ /// Extent of the matrix tile in rows -+ Index extent_column_; -+ -+ /// A thread's starting row position (assuming steady-state predicates have -+ /// been computed) -+ Index thread_start_row_; -+ -+ /// A thread's starting column -+ Index thread_start_column_; -+ -+ /// Internal state counter -+ int state_[3]; -+ -+ /// Scatter indices -+ int const* indices_; -+ -+ // -+ // Static asserts about internal strides -+ // -+ -+ static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); -+ static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); -+ static_assert( -+ sizeof(PredicatedTileIteratorParams::stride) == 8, -+ "Expected 64b strides"); -+ -+ private: -+ // -+ // Methods -+ // -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ PredicatedTileIteratorPrefetch( -+ PredicatedTileIteratorParams const& params, -+ Element* pointer, -+ TensorCoord extent, -+ int thread_idx, -+ TensorCoord threadblock_offset = TensorCoord(), -+ int const* indices = nullptr) -+ : params_(params), indices_(indices) { -+ TensorCoord thread_offset = -+ ThreadMap::initial_offset(thread_idx) + threadblock_offset; -+ -+ extent_row_ = extent.row(); -+ extent_column_ = extent.column(); -+ -+ thread_start_row_ = thread_offset.row(); -+ thread_start_column_ = thread_offset.column(); -+ -+ // Initialize predicates -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { -+ mask_.predicates[c] = -+ ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < -+ extent.column()); -+ } -+ -+ // Null pointer performs no accesses -+ if (!pointer) { -+ mask_.clear(); -+ } -+ -+ if (ScatterD && !indices) { -+ mask_.clear(); -+ } -+ -+ // Initialize pointer -+ byte_pointer_ = reinterpret_cast(pointer) + -+ LongIndex(thread_offset.row()) * LongIndex(params_.stride) + -+ LongIndex(thread_offset.column()) * sizeof(AccessType) / -+ kElementsPerAccess; -+ -+ if (ScatterD) { -+ byte_pointer_ = reinterpret_cast(pointer) + -+ LongIndex(thread_offset.column()) * sizeof(AccessType) / -+ kElementsPerAccess; -+ } -+ -+ // Initialize internal state counter -+ state_[0] = state_[1] = state_[2] = 0; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_DEVICE -+ void prefetch_all() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int iter = 0; iter < kIterations; ++iter) { -+ prefetch(); -+ ++(*this); -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void prefetch() { -+ uint8_t* byte_pointer = byte_pointer_; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; -+ ++cluster) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ 
CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ int row_offset = row * ThreadMap::Delta::kRow + -+ group * ThreadMap::Delta::kGroup + -+ cluster * ThreadMap::Delta::kCluster; -+ -+ AccessType* memory_pointer = -+ reinterpret_cast(byte_pointer); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; -+ ++column) { -+ // on windows using unsigned long here gives the error -+ // error: asm operand type size(4) does not match -+ // type/size implied by constraint 'l' -+ uint64_t addr = (uint64_t)( -+ (void*)&memory_pointer -+ [column * ThreadMap::Delta::kColumn / kElementsPerAccess]); -+ asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr)); -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ if (!ScatterD) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const { -+ uint8_t* byte_pointer = byte_pointer_; -+ AccessType* frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; -+ ++cluster) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ int frag_row_idx = -+ (row + -+ ThreadMap::Iterations::kRow * -+ (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow + -+ group * ThreadMap::Delta::kGroup + -+ cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ AccessType* memory_pointer = -+ reinterpret_cast(byte_pointer + byte_offset); -+ -+ if (ScatterD && row_guard) { -+ assert(indices_); -+ -+ memory_pointer = reinterpret_cast( -+ byte_pointer + byte_offset + -+ LongIndex(indices_[row_offset + thread_start_row_]) * -+ LongIndex(params_.stride)); -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; -+ ++column) { -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ cutlass::arch::global_load( -+ frag_ptr -+ [frag_row_idx * ThreadMap::Iterations::kColumn + column], -+ (void*)&memory_pointer -+ [column * ThreadMap::Delta::kColumn / kElementsPerAccess], -+ guard); -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ if (!ScatterD) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment& frag) const { -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const { -+ uint8_t* byte_pointer = byte_pointer_; -+ AccessType const* frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; -+ ++cluster) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < 
ThreadMap::Iterations::kGroup; ++group) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ int frag_row_idx = -+ (row + -+ ThreadMap::Iterations::kRow * -+ (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow + -+ group * ThreadMap::Delta::kGroup + -+ cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ AccessType* memory_pointer = -+ reinterpret_cast(byte_pointer + byte_offset); -+ -+ if (ScatterD && row_guard) { -+ assert(indices_); -+ -+ memory_pointer = reinterpret_cast( -+ byte_pointer + byte_offset + -+ LongIndex(indices_[row_offset + thread_start_row_]) * -+ LongIndex(params_.stride)); -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; -+ ++column) { -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ if (UseCUDAStore) { -+ if (guard) { -+ memory_pointer -+ [column * ThreadMap::Delta::kColumn / kElementsPerAccess] = -+ frag_ptr -+ [frag_row_idx * ThreadMap::Iterations::kColumn + -+ column]; -+ } -+ } else { -+ cutlass::arch::global_store( -+ frag_ptr -+ [frag_row_idx * ThreadMap::Iterations::kColumn + column], -+ (void*)&memory_pointer -+ [column * ThreadMap::Delta::kColumn / kElementsPerAccess], -+ guard); -+ } -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ if (!ScatterD) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const& frag) const { -+ store_with_byte_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void downsample_load_with_byte_offset( -+ Fragment& frag, -+ int64_t byte_offset, -+ int convolution_P, -+ int convolution_Q, -+ int add_P, -+ int add_Q, -+ int problem_N) const { -+ uint8_t* byte_pointer = byte_pointer_; -+ AccessType* frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; -+ ++cluster) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ int frag_row_idx = -+ (row + -+ ThreadMap::Iterations::kRow * -+ (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow + -+ group * ThreadMap::Delta::kGroup + -+ cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ int output_row = row_offset + thread_start_row_; -+ int output_N = output_row / (convolution_P * convolution_Q); -+ int output_PQ = output_row % (convolution_P * convolution_Q); -+ int output_P = output_PQ / convolution_Q; -+ int output_Q = output_PQ % convolution_Q; -+ -+ int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + -+ (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; -+ -+ int64_t byte_offset = -+ (input_row - output_row) * problem_N * sizeof(float); -+ -+ AccessType* memory_pointer = -+ reinterpret_cast(byte_pointer + byte_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; -+ ++column) { -+ bool guard = row_guard && 
mask_.predicates[column]; -+ -+ cutlass::arch::global_load( -+ frag_ptr -+ [frag_row_idx * ThreadMap::Iterations::kColumn + column], -+ (void*)&memory_pointer -+ [column * ThreadMap::Delta::kColumn / kElementsPerAccess], -+ guard); -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void upsample_load_with_byte_offset( -+ Fragment& frag, -+ int64_t byte_offset, -+ int convolution_P, -+ int convolution_Q, -+ int add_P, -+ int add_Q, -+ int problem_N) const { -+ uint8_t* byte_pointer = byte_pointer_; -+ AccessType* frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; -+ ++cluster) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ int frag_row_idx = -+ (row + -+ ThreadMap::Iterations::kRow * -+ (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow + -+ group * ThreadMap::Delta::kGroup + -+ cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ int output_row = row_offset + thread_start_row_; -+ int output_N = output_row / (convolution_P * convolution_Q); -+ int output_PQ = output_row % (convolution_P * convolution_Q); -+ int output_P = output_PQ / convolution_Q; -+ int output_Q = output_PQ % convolution_Q; -+ int row_add_P = add_P; -+ int row_add_Q = add_Q; -+ if (output_P > convolution_P - 2) -+ row_add_P = 0; -+ if (output_Q > convolution_Q - 2) -+ row_add_Q = 0; -+ -+ int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) + -+ ((output_P + row_add_P) / 2) * (convolution_Q / 2) + -+ (output_Q + row_add_Q) / 2; -+ -+ int64_t byte_offset = -+ (input_row - output_row) * problem_N * sizeof(float); -+ -+ AccessType* memory_pointer = -+ reinterpret_cast(byte_pointer + byte_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; -+ ++column) { -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ cutlass::arch::global_load( -+ frag_ptr -+ [frag_row_idx * ThreadMap::Iterations::kColumn + column], -+ (void*)&memory_pointer -+ [column * ThreadMap::Delta::kColumn / kElementsPerAccess], -+ guard); -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ CUTLASS_DEVICE -+ MatrixCoord thread_start() const { -+ return MatrixCoord(thread_start_row_, thread_start_column_); -+ } -+ -+ /// Need to get the thread start row from the tile iterator -+ CUTLASS_DEVICE -+ int32_t thread_start_row() const { -+ return thread_start_row_; -+ } -+ -+ /// Need to get the thread start row from the tile iterator -+ CUTLASS_DEVICE -+ int32_t thread_start_column() const { -+ return thread_start_column_; -+ } -+ -+ /// Extent of the matrix in rows -+ CUTLASS_DEVICE -+ Index extent_row() const { -+ return 
extent_row_; -+ } -+ -+ /// Extent of the matrix in columns -+ CUTLASS_DEVICE -+ Index extent_column() const { -+ return extent_column_; -+ } -+ -+ /// Advances to the next position to load or store -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorPrefetch& operator++() { -+ ++state_[0]; -+ -+ if (!ScatterD) { -+ byte_pointer_ += params_.advance_row; -+ } -+ -+ thread_start_row_ += ThreadMap::Shape::kRow; -+ -+ if (state_[0] == ThreadMap::Count::kRow) { -+ state_[0] = 0; -+ ++state_[1]; -+ byte_pointer_ += params_.advance_group; -+ -+ thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * -+ ThreadMap::Shape::kRow * ThreadMap::Count::kRow; -+ -+ if (state_[1] == ThreadMap::Count::kGroup) { -+ state_[1] = 0; -+ ++state_[2]; -+ byte_pointer_ += params_.advance_cluster; -+ -+ thread_start_row_ += ThreadMap::Count::kGroup * -+ ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * -+ ThreadMap::Shape::kRow; -+ -+ if (state_[2] == ThreadMap::Count::kCluster) { -+ state_[2] = 0; -+ byte_pointer_ += params_.advance_tile; -+ } -+ } -+ } -+ -+ return *this; -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_DEVICE void clear_mask() { -+ mask_.clear(); -+ } -+ -+ ///< Efficiently enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable_mask() { -+ mask_.enable(); -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void get_mask(Mask& mask) const { -+ mask = mask_; -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void set_mask(Mask const& mask) { -+ mask_ = mask; -+ } -+}; -+ -+template -+struct MakePrefetchableIterator { -+ using Iterator = PredicatedTileIteratorPrefetch< -+ typename IT::ThreadMap, -+ typename IT::Element>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/make_residual_last.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/make_residual_last.h -new file mode 100644 -index 0000000..e6b5d58 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/make_residual_last.h -@@ -0,0 +1,97 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holdvr nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "predicated_tile_access_iterator_residual_last.h" -+#include "predicated_tile_iterator_residual_last.h" -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+template -+struct MakeIteratorResidualLast; -+ -+template < -+ typename Shape, -+ typename Element, -+ typename Layout, -+ int AdvanceRank, -+ typename ThreadMap, -+ int AccessSize, -+ bool Gather> -+struct MakeIteratorResidualLast> { -+ using Iterator = PredicatedTileIteratorResidualLast< -+ Shape, -+ Element, -+ Layout, -+ AdvanceRank, -+ ThreadMap, -+ AccessSize, -+ Gather>; -+}; -+ -+template < -+ typename Shape, -+ typename Element, -+ typename Layout, -+ int AdvanceRank, -+ typename ThreadMap, -+ typename AccessType, -+ bool Gather> -+struct MakeIteratorResidualLast> { -+ using Iterator = PredicatedTileAccessIteratorResidualLast< -+ Shape, -+ Element, -+ Layout, -+ AdvanceRank, -+ ThreadMap, -+ AccessType, -+ Gather>; -+}; -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h -new file mode 100644 -index 0000000..b9c38cc ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h -@@ -0,0 +1,2115 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights -+ *reserved. SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, -+ *this list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -+ *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -+ *POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates calculating the address and predicates to the load of tiles -+ from pitch-linear rank=2 tensors. -+ -+ This iterator uses masks to guard out-of-bounds accesses. The first tile -+ this iterator visits maybe partial, then the remaining tiles are complete. -+ So, we only need to compute the predicates twice, once before the first tile -+ and once for the remaining full tiles which can share the same predicates. -+ -+ A precomputed "Params" object minimizes the amount of state that must be -+ stored in registers, and integer addition is used to advance the pointer -+ through memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedTileAccessIteratorResidualLast -+/// -+template < -+ typename Shape, -+ typename Element, -+ typename Layout, -+ int AdvanceRank, -+ typename ThreadMap, -+ typename AccessType, -+ bool Gather = false> -+class PredicatedTileAccessIteratorResidualLast; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIteratorResidualLast for pitch-linear -+/// data. 
-+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ typename AccessType_, -+ bool Gather> -+class PredicatedTileAccessIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::PitchLinear, -+ AdvanceRank, -+ ThreadMap_, -+ AccessType_, -+ Gather> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< -+ Shape, -+ Element, -+ Layout, -+ AdvanceRank, -+ ThreadMap, -+ AccessType>; -+ -+ static int const kAccessesPerVector = -+ ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert( -+ !(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ using Mask = typename UnderlyingPredicates::Mask; -+ -+ /// Uses a non-template class -+ struct Params : PredicatedTileAccessIteratorParams { -+ using Base = PredicatedTileAccessIteratorParams; -+ -+ // Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : Base( -+ layout.stride(0), -+ MakePredicatedTileAccessIteratorDesc< -+ Shape, -+ Element, -+ Layout, -+ kAdvanceRank, -+ ThreadMap>()()) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(Base const& base) : Base(base) {} -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char*; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ UnderlyingPredicates the_predicates; -+ Mask residual_tile_mask; -+ -+ /// Parameters object with precomputed internal state -+ Params const& params_; -+ -+ /// Internal pointer to first access of tile -+ BytePointer pointer_; -+ -+ /// Below is used when Gather is turned on. We need to record strided_offset -+ /// and contiguous_offset seperated to compute the offset by using -+ /// -+ /// offset = contiguous_offset + indices[strided_offset] -+ /// -+ -+ /// Gather indices -+ int const* indices_; -+ -+ Index gather_offset_strided; -+ -+ private: -+ /// Computes predicates based on internally tracked per-thread offset. 
-+ CUTLASS_DEVICE -+ void compute_predicates_( -+ /// Extent of the matrix window -+ TensorCoord extent, -+ /// optionally, simplify predicate calculation during 'steady state' phase -+ bool is_steady_state = false) { -+ the_predicates.compute_predicates_(extent, is_steady_state); -+ } -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ /// Precomputed parameters object -+ Params const& params, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ /// Gather indices -+ int const* indices = nullptr) -+ : params_(params), -+ pointer_(reinterpret_cast( -+ const_cast(pointer))), -+ the_predicates(extent), -+ indices_(indices) { -+ the_predicates.set_predicates(thread_id, threadblock_offset); -+ the_predicates.get_mask(residual_tile_mask); -+ -+ // Working around a weird compiler bug happening on P100 for the backward. -+ // I've seen together: the_predicates.predicates_[0] = 14 (instead of 15) -+ // residual_tile_mask[0] = 15 (correct) -+ // -+ // Adding prints when the value is calculated (in `compute_predicates_`) -+ // sometimes removes the bug. The consequence is that we skip some -+ // element of a tensor, leading to wrong results -+ // Setting `compute_predicates_`'s second argument (`is_steady_state`) to -+ // true also seems to get rid of the bug - at the cost of twice as many -+ // comparisons. -+#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700) -+ constexpr bool kWorkAroundCompilerBug = false; -+#else -+ constexpr bool kWorkAroundCompilerBug = true; -+#endif -+ the_predicates.compute_predicates_(extent, true && !kWorkAroundCompilerBug); -+ -+ // update internal pointers -+ Layout layout(params_.stride_); -+ -+ if (!Gather) { -+ add_pointer_offset(layout(the_predicates.thread_offset_)); -+ } else { -+ gather_offset_strided = the_predicates.thread_offset_.strided(); -+ add_pointer_offset( -+ layout(make_Coord(the_predicates.thread_offset_.contiguous(), 0))); -+ } -+ } -+ -+ /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ /// Precomputed parameters object -+ Params const& params, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id) -+ : PredicatedTileAccessIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ the_predicates.set_iteration_index(index); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool is_residual_tile) { -+ if (is_residual_tile) { -+ the_predicates.set_mask(residual_tile_mask); -+ } -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += sizeof_bits::value * pointer_offset / 8; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const& tile_offset) { -+ if (!Gather) { -+ if (kAdvanceRank) { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); -+ pointer_ += 
Shape::kContiguous * tile_offset.contiguous(); -+ } else { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); -+ pointer_ += Shape::kStrided * tile_offset.strided(); -+ } -+ } else { -+ add_pointer_offset(Shape::kContiguous * tile_offset.contiguous()); -+ gather_offset_strided += Shape::kStrided * tile_offset.strided(); -+ } -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType* get() const { -+ if (Gather) { -+ assert(indices_); -+ -+ if (!valid()) { -+ return nullptr; -+ } -+ -+ LongIndex contiguous_offset = the_predicates.iteration_contiguous_ * -+ (ThreadMap::Delta::kContiguous * sizeof_bits::value / -+ 8) + -+ the_predicates.iteration_vector_; -+ int strided_index = gather_offset_strided + -+ the_predicates.iteration_strided_ * ThreadMap::Delta::kStrided; -+ -+ LongIndex strided_offset = indices_[strided_index] * -+ LongIndex(params_.stride_) * sizeof_bits::value / 8; -+ -+ return reinterpret_cast( -+ pointer_ + contiguous_offset + strided_offset); -+ } -+ -+ return reinterpret_cast( -+ pointer_ + -+ the_predicates.iteration_contiguous_ * -+ (ThreadMap::Delta::kContiguous * -+ sizeof_bits::value) / -+ 8) + -+ the_predicates.iteration_vector_; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast& operator++() { -+ the_predicates.operator++(); -+ -+ ++the_predicates.iteration_vector_; -+ if (the_predicates.iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ -+ the_predicates.iteration_vector_ = 0; -+ ++the_predicates.iteration_contiguous_; -+ -+ if (the_predicates.iteration_contiguous_ < -+ ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ the_predicates.iteration_contiguous_ = 0; -+ ++the_predicates.iteration_strided_; -+ -+ if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ if (!Gather) { -+ pointer_ += params_.inc_strided_; -+ } -+ -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ the_predicates.iteration_strided_ = 0; -+ -+ if (!Gather) { -+ // advance to next tile -+ pointer_ += params_.inc_next_; -+ -+ // now return to start tile - if the iterator is subsequently advanced, -+ // this subtraction as well as the subsequent integer addition are both -+ // elided by the compiler. -+ pointer_ -= params_.inc_advance_; -+ } -+ -+ return *this; -+ } -+ -+ /// Increment and return an instance to self. 
-+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast operator++(int) { -+ PredicatedTileAccessIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ the_predicates.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ the_predicates.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ the_predicates.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ the_predicates.get_mask(mask); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ return the_predicates.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major -+/// data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ typename AccessType_, -+ bool Gather> -+class PredicatedTileAccessIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::ColumnMajor, -+ AdvanceRank, -+ ThreadMap_, -+ AccessType_, -+ Gather> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 
0 : 1), -+ ThreadMap, -+ AccessType, -+ Gather>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::PitchLinear(layout.stride(0))){}; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const& base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ ///< Precomputed parameters object -+ Params const& params, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.row(), -+ threadblock_offset.column()), -+ indices) {} -+ -+ /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iterator_.set_iteration_index(index); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const& tile_offset) { -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType* get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. 
-+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast operator++(int) { -+ PredicatedTileAccessIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major -+/// data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ typename AccessType_, -+ bool Gather> -+class PredicatedTileAccessIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::RowMajor, -+ AdvanceRank, -+ ThreadMap_, -+ AccessType_, -+ Gather> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 
1 : 0), -+ ThreadMap, -+ AccessType, -+ Gather>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::PitchLinear(layout.stride(0))){}; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const& base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ ///< Precomputed parameters object -+ Params const& params, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ /// Gather indices -+ int const* indices = nullptr) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.column(), -+ threadblock_offset.row()), -+ indices) {} -+ -+ /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iterator_.set_iteration_index(index); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const& tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType* get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. 
-+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast operator++(int) { -+ PredicatedTileAccessIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 -+/// data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ typename AccessType_> -+class PredicatedTileAccessIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::AffineRankN<2>, -+ AdvanceRank, -+ ThreadMap_, -+ AccessType_, -+ false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRankN<2>; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< -+ Shape, -+ Element, -+ layout::PitchLinear, -+ AdvanceRank, -+ ThreadMap, -+ AccessType>; -+ -+ static int const kAccessesPerVector = -+ ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert( -+ !(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingPredicates::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ public: -+ friend PredicatedTileAccessIteratorResidualLast; -+ -+ private: -+ /// stride of pitch-linear layout (units of Element) -+ Coord stride_; -+ /// amount (in byte) to increment pointer to move to next access along -+ /// contiguous 
dimension -+ LongIndex inc_contiguous_; -+ /// amount (in byte) to increment pointer from first access of current -+ /// contiguous dimension to first access of next one. -+ LongIndex inc_strided_; -+ /// amount (in byte) to increment pointer from last access of current -+ /// contiguous dimension to first access of next one. -+ LongIndex inc_next_strided_; -+ /// amount (in byte) to increment pointer from last access to first access -+ /// of next tile -+ LongIndex inc_next_; -+ /// amount (in byte) to increment pointer from first access of current tile -+ /// to first access of next tile -+ LongIndex inc_advance_; -+ -+ public: -+ // Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() -+ : stride_(0), -+ inc_contiguous_(0), -+ inc_strided_(0), -+ inc_next_(0), -+ inc_advance_(0) {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : stride_({layout.stride(0), layout.stride(1)}) { -+ inc_contiguous_ = -+ (LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) * -+ sizeof_bits::value / 8; -+ -+ inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) * -+ sizeof_bits::value / 8; -+ -+ inc_next_strided_ = inc_strided_ - -+ LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_; -+ -+ if (kAdvanceRank) { -+ // advance along strided dimension -+ inc_advance_ = Shape::kStrided * LongIndex(stride_[1]) * -+ sizeof_bits::value / 8; -+ } else { -+ // advance along contiguous dimension -+ inc_advance_ = -+ Shape::kContiguous * stride_[0] * sizeof_bits::value / 8; -+ } -+ -+ inc_next_ = inc_advance_ - -+ LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ - -+ LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_; -+ }; -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char*; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object with precomputed internal state -+ Params const& params_; -+ -+ /// Internal pointer to first access of tile -+ BytePointer pointer_; -+ -+ UnderlyingPredicates the_predicates; -+ Mask residual_tile_mask; -+ -+ private: -+ /// Computes predicates based on internally tracked per-thread offset. 
-+ CUTLASS_DEVICE -+ void compute_predicates_( -+ /// Extent of the matrix window -+ TensorCoord extent, -+ /// optionally, simplify predicate calculation during 'steady state' phase -+ bool is_steady_state = false) { -+ the_predicates.compute_predicates_(extent, is_steady_state); -+ } -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ ///< Precomputed parameters object -+ Params const& params, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : params_(params), -+ pointer_(reinterpret_cast( -+ const_cast(pointer))), -+ the_predicates(extent) { -+ the_predicates.set_predicates(thread_id, threadblock_offset); -+ -+ // update internal pointers -+ Layout layout(params_.stride_); -+ add_pointer_offset(layout(the_predicates.thread_offset_)); -+ } -+ -+ /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ the_predicates.set_iteration_index(index); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool is_residual_tile) { -+ if (is_residual_tile) { -+ the_predicates.set_mask(residual_tile_mask); -+ } -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += sizeof_bits::value * pointer_offset / 8; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const& tile_offset) { -+ if (kAdvanceRank) { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1]); -+ pointer_ += Shape::kContiguous * tile_offset[0]; -+ } else { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0]); -+ pointer_ += Shape::kStrided * tile_offset[1]; -+ } -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType* get() const { -+ return reinterpret_cast(pointer_) + -+ the_predicates.iteration_vector_; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. 
-+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast& operator++() { -+ the_predicates.operator++(); -+ ++the_predicates.iteration_vector_; -+ if (the_predicates.iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ -+ the_predicates.iteration_vector_ = 0; -+ ++the_predicates.iteration_contiguous_; -+ -+ if (the_predicates.iteration_contiguous_ < -+ ThreadMap::Iterations::kContiguous) { -+ pointer_ += params_.inc_contiguous_; -+ return *this; -+ } -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ the_predicates.iteration_contiguous_ = 0; -+ ++the_predicates.iteration_strided_; -+ -+ if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ pointer_ += params_.inc_next_strided_; -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ the_predicates.iteration_strided_ = 0; -+ -+ // advance to next tile -+ pointer_ += params_.inc_next_; -+ -+ // now return to start tile - if the iterator is subsequently advanced, this -+ // subtraction as well as the subsequent integer addition are both elided by -+ // the compiler. -+ pointer_ -= params_.inc_advance_; -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast operator++(int) { -+ PredicatedTileAccessIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ the_predicates.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ the_predicates.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ the_predicates.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ the_predicates.get_mask(mask); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return the_predicates.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 -+/// column-major data. 
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ typename AccessType_> -+class PredicatedTileAccessIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::AffineRank2ColumnMajor, -+ AdvanceRank, -+ ThreadMap_, -+ AccessType_, -+ false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRank2ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ // Map to the underlying AffineRankN<2> layout -+ using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< -+ layout::PitchLinearShape, -+ Element, -+ layout::AffineRankN<2>, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap, -+ AccessType>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given an AffineRankN<2> tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){}; -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying AffineRankN<2> tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ ///< Precomputed parameters object -+ Params const& params, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.row(), -+ threadblock_offset.column())) {} -+ -+ /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : 
PredicatedTileAccessIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iterator_.set_iteration_index(index); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const& tile_offset) { -+ iterator_.add_tile_offset( -+ make_Coord(tile_offset.row(), tile_offset.column())); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType* get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast operator++(int) { -+ PredicatedTileAccessIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank-2 -+/// row-major data. 
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ typename AccessType_> -+class PredicatedTileAccessIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::AffineRank2RowMajor, -+ AdvanceRank, -+ ThreadMap_, -+ AccessType_, -+ false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRank2RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ // Map to the underlying AffineRankN<2> layout -+ using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< -+ layout::PitchLinearShape, -+ Element, -+ layout::AffineRankN<2>, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap, -+ AccessType>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given an AffineRankN<2> tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){}; -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying AffineRankN<2> tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ ///< Precomputed parameters object -+ Params const& params, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : 
PredicatedTileAccessIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iterator_.set_iteration_index(index); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const& tile_offset) { -+ iterator_.add_tile_offset( -+ make_Coord(tile_offset.column(), tile_offset.row())); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType* get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast operator++(int) { -+ PredicatedTileAccessIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major -+/// interleaved data. It is mapped to the congruous layout. 
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ typename AccessType_, -+ int InterleavedK> -+class PredicatedTileAccessIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::ColumnMajorInterleaved, -+ AdvanceRank, -+ ThreadMap_, -+ AccessType_, -+ false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::ColumnMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< -+ layout::PitchLinearShape< -+ Shape::kRow * kInterleavedK, -+ Shape::kColumn / kInterleavedK>, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap, -+ AccessType>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const& base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ /// Precomputed parameters object -+ Params const& params, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord( -+ extent.row() * kInterleavedK, -+ extent.column() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.row() * kInterleavedK, -+ threadblock_offset.column() / kInterleavedK)) {} -+ -+ /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ Params const& 
params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iterator_.set_iteration_index(index); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const& tile_offset) { -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType* get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast operator++(int) { -+ PredicatedTileAccessIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major -+/// interleaved data. -+// It is mapped to the congruous layout. 
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ typename AccessType_, -+ int InterleavedK> -+class PredicatedTileAccessIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::RowMajorInterleaved, -+ AdvanceRank, -+ ThreadMap_, -+ AccessType_, -+ false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::RowMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< -+ layout::PitchLinearShape< -+ Shape::kColumn * kInterleavedK, -+ Shape::kRow / kInterleavedK>, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap, -+ AccessType>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const& base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ /// Precomputed parameters object -+ Params const& params, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord( -+ extent.column() * kInterleavedK, -+ extent.row() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.column() * kInterleavedK, -+ threadblock_offset.row() / kInterleavedK)) {} -+ -+ /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ Params const& params, 
///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iterator_.set_iteration_index(index); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const& tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType* get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. 
-+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast operator++(int) { -+ PredicatedTileAccessIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h -new file mode 100644 -index 0000000..4bb96a1 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h -@@ -0,0 +1,2120 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights -+ *reserved. SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, -+ *this list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -+ *POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of tiles from pitch-linear rank=2 -+ tensors. 
-+
-+    This iterator uses masks to guard out-of-bounds accesses. The first tile
-+    this iterator visits may be partial, then the remaining tiles are complete.
-+    So, we only need to compute the predicates twice, once before the first tile
-+    and once for the remaining full tiles which can share the same predicates.
-+
-+    A precomputed "Params" object minimizes the amount of state that must be
-+    stored in registers, and integer addition is used to advance the pointer
-+    through memory.
-+*/
-+
-+#pragma once
-+
-+#include "cutlass/arch/memory.h"
-+#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h"
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+namespace cutlass {
-+namespace transform {
-+namespace threadblock {
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+/// PredicatedTileIteratorResidualLast
-+///
-+/// Satisfies: ForwardTileIteratorConcept |
-+///            ReadableContiguousTileIteratorConcept |
-+///            WriteableContiguousTileIteratorConcept |
-+///            MaskedTileIteratorConcept
-+///
-+/// Regular tile iterator using a precomputed control structure to minimize
-+/// register liveness and integer arithmetic.
-+///
-+/// Layout is assumed to be invariant at the time the precomputed "Params"
-+/// object is constructed.
-+///
-+/// Base pointer and tensor extents may be specified at the time the iterator is
-+/// constructed. Subsequently, they are assumed to be immutable.
-+///
-+/// Adding a logical coordinate offset may be performed at the time the iterator
-+/// is constructed. Subsequent additions to logical coordinate offset may be
-+/// performed but are relatively expensive.
-+///
-+/// Visitation order is intended to first visit a "residual" tile that may be
-+/// partially full in both the advance dimension and the steady-state dimension.
-+/// This is assumed to be the last tile in the iteration sequence. Advancing an
-+/// iterator that has just been constructed moves to the first tile that is full
-+/// in the advance dimension and recomputes predicates. Subsequent accesses may
-+/// be performed without updating internal predicates and are efficient in terms
-+/// of live register state and pointer arithmetic instructions.
-+///
-+/// To be efficient, this assumes the iterator will be dereferenced and advanced
-+/// at least once outside any looping structure to minimize integer arithmetic.
-+///
-+/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to
-+/// dereferencing the iterator.
-+///
-+///
-+/// Example:
-+///
-+/// An efficient pipeline structure may be constructed as follows:
-+///
-+// template <typename Iterator>
-+// __global__ void kernel(
-+//     typename Iterator::Params params,
-+//     typename Iterator::Element *ptr,
-+//     TensorCoord extent) {
-+//
-+//   typename Iterator::Fragment fragment;
-+//
-+//   TensorCoord threadblock_offset(0, 0);
-+//
-+//   Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offset);
-+//
-+//
-+//   fragment = *iter;        // load "residue" tile first
-+//   ++iter;                  // advance to first "steady state" tile and update
-+//   internal masks
-+//
-+//
-+//   #pragma unroll
-+//   for (int i = Remaining - 1; i >= 0; --i) {
-+//
-+//     f(fragment);
-+//
-+//     if (!i) {
-+//       iter.clear_mask();   // light-weight operation to clear masks -
-+//       subsequent loads become NO-OPs.
-+// } -+// -+// fragment = *iter; // load tile during "steady state" phase -+// ++iter; // advance to next tile - lightweight due to -+// steady-state masks -+// } -+// } -+// -+// void host(TensorView view) { -+// -+// using Iterator = -+// transform::threadblock::PredicatedTileIteratorResidualLast; -+// -+// typename Iterator::Params params(view.layout()); -+// -+// kernel(params, view.data()); -+// } -+/// -+/// -+template < -+ typename Shape, -+ typename Element, -+ typename Layout, -+ int AdvanceRank, -+ typename ThreadMap, -+ int AccessSize = ThreadMap::kElementsPerAccess, -+ bool Gather = false> -+class PredicatedTileIteratorResidualLast; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize, -+ bool Gather> -+class PredicatedTileIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::PitchLinear, -+ AdvanceRank, -+ ThreadMap_, -+ AccessSize, -+ Gather> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ /// Type used for internal memory accesses -+ using AccessType = AlignedArray< -+ Element, -+ AccessSize, -+ (AccessSize * sizeof_bits::value / 8)>; -+ -+ /// Underlying iterator to compute the addresses -+ using TileAccessIterator = PredicatedTileAccessIteratorResidualLast< -+ Shape, -+ Element, -+ Layout, -+ kAdvanceRank, -+ ThreadMap, -+ AccessType, -+ Gather>; -+ -+ static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array< -+ Element, -+ ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename TileAccessIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ public: -+ using Base = typename TileAccessIterator::Params::Base; -+ -+ friend PredicatedTileIteratorResidualLast; -+ -+ private: -+ /// Parameters object -+ typename TileAccessIterator::Params params_; -+ -+ public: -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) : params_(layout) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(Base const& base) : params_(base) {} -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char*; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Data member to the tile access iterator -+ TileAccessIterator address_iterator_; -+ -+ public: -+ /// 
Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ /// Precomputed parameters object -+ Params const& params, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ /// Gather indices -+ int const* indices = nullptr) -+ : address_iterator_( -+ params.params_, -+ pointer, -+ extent, -+ thread_id, -+ threadblock_offset, -+ indices) {} -+ -+ /// Construct a PredicatedTileIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ address_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast& operator++() { -+ if (kAdvanceRank) -+ address_iterator_.add_tile_offset({0, 1}); -+ else -+ address_iterator_.add_tile_offset({1, 0}); -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. 
-+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast operator++(int) { -+ PredicatedTileIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ address_iterator_.clear_mask(enable); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ address_iterator_.set_residual_tile(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ address_iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ address_iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ address_iterator_.get_mask(mask); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { -+ load_with_byte_offset( -+ frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { -+ AccessType* frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ int idx = v + -+ kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ address_iterator_.set_iteration_index(idx); -+ char const* byte_ptr = -+ reinterpret_cast(address_iterator_.get()) + -+ byte_offset; -+ -+ AccessType const* access_ptr = -+ reinterpret_cast(byte_ptr); -+ -+ cutlass::arch::global_load( -+ frag_ptr[idx], access_ptr, address_iterator_.valid()); -+ -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment& frag) { -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { -+ store_with_byte_offset( -+ frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { -+ address_iterator_.set_iteration_index(0); -+ AccessType const* frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ int idx = v + -+ kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ char* byte_ptr = -+ reinterpret_cast(address_iterator_.get()) + byte_offset; -+ AccessType* access_ptr = reinterpret_cast(byte_ptr); -+ -+ if (address_iterator_.valid()) { -+ *access_ptr = frag_ptr[idx]; -+ } -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const& frag) { -+ store_with_byte_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. 
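For orientation, the pitch-linear specialization that just closed is the one every layout adapter below delegates to. The following is a minimal sketch of how such an iterator might be instantiated and driven, modeled on the commented example in this file's header; the tile shape, element type, thread map, kernel name, and the loop protocol are assumptions of the sketch, not values taken from this patch.

// Sketch only: assumes this header and the CUTLASS headers below are on the include path.
#include "cutlass/numeric_types.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/transform/pitch_linear_thread_map.h"

using TileShape = cutlass::layout::PitchLinearShape<64, 8>;  // assumed tile shape
using Element   = cutlass::half_t;                           // assumed element type
using ThreadMap =
    cutlass::transform::PitchLinearStripminedThreadMap<TileShape, 32>;

using Iterator =
    cutlass::transform::threadblock::PredicatedTileIteratorResidualLast<
        TileShape, Element, cutlass::layout::PitchLinear,
        /*AdvanceRank=*/0, ThreadMap>;

// Follows the pipeline described in the header comment: the first (possibly
// partial) tile is loaded with full predication, then clear_mask() turns the
// trailing loads into no-ops.
__global__ void consume_tiles(typename Iterator::Params params,
                              Element* ptr,
                              cutlass::layout::PitchLinearCoord extent,
                              int remaining) {
  typename Iterator::Fragment fragment;
  fragment.clear();

  Iterator iter(params, ptr, extent, threadIdx.x,
                cutlass::layout::PitchLinearCoord(0, 0));

  iter.load(fragment);  // load the "residue" tile first
  ++iter;               // advance to the first "steady state" tile

  for (int i = remaining - 1; i >= 0; --i) {
    // ... use fragment ...
    if (!i) {
      iter.clear_mask();  // subsequent loads become no-ops
    }
    iter.load(fragment);
    ++iter;
  }
  // set_residual_tile(bool) is the extra knob this "ResidualLast" variant adds;
  // when it is toggled depends on the consuming mainloop and is not shown here.
}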
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize, -+ bool Gather> -+class PredicatedTileIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::ColumnMajor, -+ AdvanceRank, -+ ThreadMap_, -+ AccessSize, -+ Gather> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ using UnderlyingIterator = PredicatedTileIteratorResidualLast< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap, -+ AccessSize, -+ Gather>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array< -+ Element, -+ ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const& base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const& threadblock_offset, ///< Initial offset of threadblock -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.row(), -+ threadblock_offset.column()), -+ indices) {} -+ -+ /// Construct a PredicatedTileIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of 
tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast operator++(int) { -+ PredicatedTileIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment& frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const& frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. 
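The column-major adapter above and the row-major adapter that follows contain no iteration logic of their own; each only reorders matrix coordinates into (contiguous, strided) form and forwards to the pitch-linear specialization. The helpers below are an illustrative restatement of those two mappings (helper names are assumed, not part of the patch).

// Sketch only: the coordinate reordering performed in the wrappers' constructors.
#include "cutlass/cutlass.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/layout/pitch_linear.h"

CUTLASS_HOST_DEVICE
cutlass::layout::PitchLinearCoord column_major_to_pitch_linear(
    cutlass::MatrixCoord const& c) {
  // Column-major: rows are contiguous, columns are strided.
  return cutlass::layout::PitchLinearCoord(c.row(), c.column());
}

CUTLASS_HOST_DEVICE
cutlass::layout::PitchLinearCoord row_major_to_pitch_linear(
    cutlass::MatrixCoord const& c) {
  // Row-major: columns are contiguous, rows are strided.
  return cutlass::layout::PitchLinearCoord(c.column(), c.row());
}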
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize, -+ bool Gather> -+class PredicatedTileIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::RowMajor, -+ AdvanceRank, -+ ThreadMap_, -+ AccessSize, -+ Gather> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ using UnderlyingIterator = PredicatedTileIteratorResidualLast< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap, -+ AccessSize, -+ Gather>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array< -+ Element, -+ ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const& base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const& threadblock_offset, ///< Initial offset of threadblock -+ int const* indices = nullptr ///< Gather indices -+ ) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.column(), -+ threadblock_offset.row()), -+ indices) {} -+ -+ /// Construct a PredicatedTileIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : 
PredicatedTileIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast operator++(int) { -+ PredicatedTileIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment& frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const& frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIteratorResidualLast for affine rank-2 data. 
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize> -+class PredicatedTileIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::AffineRankN<2>, -+ AdvanceRank, -+ ThreadMap_, -+ AccessSize, -+ false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRankN<2>; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ /// Type used for internal memory accesses -+ using AccessType = AlignedArray< -+ Element, -+ AccessSize, -+ (AccessSize * sizeof_bits::value / 8)>; -+ -+ /// Underlying iterator to compute the addresses -+ using TileAccessIterator = PredicatedTileAccessIteratorResidualLast< -+ Shape, -+ Element, -+ Layout, -+ kAdvanceRank, -+ ThreadMap, -+ AccessType>; -+ -+ static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array< -+ Element, -+ ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename TileAccessIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ public: -+ friend PredicatedTileIteratorResidualLast; -+ -+ private: -+ /// Parameters object -+ typename TileAccessIterator::Params params_; -+ -+ public: -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) : params_(layout) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char*; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Data member to the tile access iterator -+ TileAccessIterator address_iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ /// Precomputed parameters object -+ Params const& params, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : address_iterator_( -+ params.params_, -+ pointer, -+ extent, -+ thread_id, -+ threadblock_offset) {} -+ -+ /// Construct a PredicatedTileIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord 
extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ address_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast& operator++() { -+ if (kAdvanceRank) -+ address_iterator_.add_tile_offset(make_Coord(0, 1)); -+ else -+ address_iterator_.add_tile_offset(make_Coord(1, 0)); -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast operator++(int) { -+ PredicatedTileIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ address_iterator_.clear_mask(enable); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ address_iterator_.set_residual_tile(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ address_iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ address_iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ address_iterator_.get_mask(mask); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { -+ load_with_byte_offset( -+ frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { -+ AccessType* frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ int idx = v + -+ kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ address_iterator_.set_iteration_index(idx); -+ char const* byte_ptr = -+ reinterpret_cast(address_iterator_.get()) + -+ byte_offset; -+ -+ AccessType const* access_ptr = -+ reinterpret_cast(byte_ptr); -+ -+ cutlass::arch::global_load( -+ frag_ptr[idx], access_ptr, address_iterator_.valid()); -+ -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment& frag) { -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { -+ store_with_byte_offset( -+ frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void 
store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { -+ address_iterator_.set_iteration_index(0); -+ AccessType const* frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ int idx = v + -+ kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ char* byte_ptr = -+ reinterpret_cast(address_iterator_.get()) + byte_offset; -+ AccessType* access_ptr = reinterpret_cast(byte_ptr); -+ -+ if (address_iterator_.valid()) { -+ *access_ptr = frag_ptr[idx]; -+ } -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const& frag) { -+ store_with_byte_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 -+/// column-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize> -+class PredicatedTileIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::AffineRank2ColumnMajor, -+ AdvanceRank, -+ ThreadMap_, -+ AccessSize, -+ false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRank2ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ // Map to the underlying AffineRankN<2> layout -+ using UnderlyingIterator = PredicatedTileIteratorResidualLast< -+ layout::PitchLinearShape, -+ Element, -+ layout::AffineRankN<2>, -+ (kAdvanceRank == 0 ? 
0 : 1), -+ ThreadMap, -+ AccessSize>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array< -+ Element, -+ ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given an AffineRankN<2> tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying AffineRankN<2> tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const& threadblock_offset, ///< Initial offset of threadblock -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.row(), -+ threadblock_offset.column())) {} -+ -+ /// Construct a PredicatedTileIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. 
-+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast operator++(int) { -+ PredicatedTileIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment& frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const& frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 -+/// row-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize> -+class PredicatedTileIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::AffineRank2RowMajor, -+ AdvanceRank, -+ ThreadMap_, -+ AccessSize, -+ false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRank2RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ // Map to the underlying AffineRankN<2> layout -+ using UnderlyingIterator = PredicatedTileIteratorResidualLast< -+ layout::PitchLinearShape, -+ Element, -+ layout::AffineRankN<2>, -+ (kAdvanceRank == 0 ? 
1 : 0), -+ ThreadMap, -+ AccessSize>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array< -+ Element, -+ ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given an AffineRankN<2> tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying AffineRankN<2> tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const& threadblock_offset, ///< Initial offset of threadblock -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ /// Construct a PredicatedTileIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. 
-+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast operator++(int) { -+ PredicatedTileIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment& frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const& frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIteratorResidualLast for interleaved data. -+/// It is mapped to the congruous layout. 
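The two affine rank-2 adapters just above differ only in the order in which their Params feed strides to the generic AffineRankN<2> layout, which is easy to miss inside the constructors. The sketch below (helper names assumed, headers assumed to be those already pulled in by this file) records that detail.

// Sketch only: stride ordering used by the affine rank-2 adapters' Params.
inline cutlass::layout::AffineRankN<2> affine_rank2_column_major_to_rank_n(
    cutlass::layout::AffineRank2ColumnMajor const& layout) {
  // Column-major adapter keeps the (stride(0), stride(1)) order.
  return cutlass::layout::AffineRankN<2>(layout.stride(0), layout.stride(1));
}

inline cutlass::layout::AffineRankN<2> affine_rank2_row_major_to_rank_n(
    cutlass::layout::AffineRank2RowMajor const& layout) {
  // Row-major adapter swaps the strides.
  return cutlass::layout::AffineRankN<2>(layout.stride(1), layout.stride(0));
}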
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize, -+ int InterleavedK> -+class PredicatedTileIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::ColumnMajorInterleaved, -+ AdvanceRank, -+ ThreadMap_, -+ AccessSize, -+ false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::ColumnMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ using UnderlyingIterator = PredicatedTileIteratorResidualLast< -+ layout::PitchLinearShape< -+ Shape::kRow * kInterleavedK, -+ Shape::kColumn / kInterleavedK>, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap, -+ AccessSize>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array< -+ Element, -+ ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const& base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ /// Precomputed parameters object -+ Params const& params, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord( -+ extent.row() * kInterleavedK, -+ extent.column() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.row() * kInterleavedK, -+ threadblock_offset.column() / kInterleavedK)) {} -+ -+ /// Construct a PredicatedTileIteratorResidualLast with zero threadblock -+ /// offset -+ 
CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast operator++(int) { -+ PredicatedTileIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment& frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const& frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIteratorResidualLast for interleaved-32 -+/// data. It is mapped to the congruous layout. 
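Both interleaved adapters (ColumnMajorInterleaved above, RowMajorInterleaved below) fold the k-wide interleave group into the contiguous dimension and shrink the strided dimension by the same factor before delegating to the pitch-linear iterator. The sketch below restates that extent remapping; the helper name and the suggested k are assumptions, not values from the patch.

// Sketch only: extent remapping used by the interleaved adapters.
#include "cutlass/cutlass.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/layout/pitch_linear.h"

template <int kInterleavedK>  // e.g. 32 for the interleaved-32 layouts
CUTLASS_HOST_DEVICE
cutlass::layout::PitchLinearCoord column_major_interleaved_extent(
    cutlass::MatrixCoord const& extent) {
  // A group of k columns forms one contiguous panel, so the contiguous extent
  // grows by k while the strided extent shrinks by k. The row-major adapter
  // applies the same rule with rows and columns exchanged.
  return cutlass::layout::PitchLinearCoord(
      extent.row() * kInterleavedK, extent.column() / kInterleavedK);
}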
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize, -+ int InterleavedK> -+class PredicatedTileIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::RowMajorInterleaved, -+ AdvanceRank, -+ ThreadMap_, -+ AccessSize, -+ false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::RowMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ using UnderlyingIterator = PredicatedTileIteratorResidualLast< -+ layout::PitchLinearShape< -+ Shape::kColumn * kInterleavedK, -+ Shape::kRow / kInterleavedK>, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap, -+ AccessSize>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array< -+ Element, -+ ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const& base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ /// Precomputed parameters object -+ Params const& params, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord( -+ extent.column() * kInterleavedK, -+ extent.row() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.column() * kInterleavedK, -+ threadblock_offset.row() / kInterleavedK)) {} -+ -+ /// Construct a PredicatedTileIteratorResidualLast with zero threadblock -+ /// offset -+ 
CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast operator++(int) { -+ PredicatedTileIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment& frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const& frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/kernel_forward.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/kernel_forward.h -new file mode 100644 -index 0000000..6f5eb3f ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/kernel_forward.h -@@ -0,0 +1,1108 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holdvr nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#ifdef HAS_PYTORCH -+#include -+#include -+#include -+#include -+#endif -+ -+#include -+#include -+ -+#include "cutlass/bfloat16.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/numeric_types.h" -+ -+#include "attention_scaling_coefs_updater.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/platform/platform.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "debug_utils.h" -+#include "epilogue_pipelined.h" -+#include "epilogue_rescale_output.h" -+#include "find_default_mma.h" -+#include "gemm_kernel_utils.h" -+#include "mma_from_smem.h" -+#include "transform/tile_smem_loader.h" -+ -+#include -+ -+using namespace gemm_kernel_utils; -+ -+namespace { -+template -+constexpr int getWarpsPerSm() { -+ return ( -+ Arch::kMinComputeCapability >= 80 && -+ !cutlass::platform::is_same::value -+ ? 
16 -+ : 12); -+} -+} // namespace -+ -+template < -+ // The datatype of Q/K/V -+ typename scalar_t_, -+ // Architecture we are targeting (eg `cutlass::arch::Sm80`) -+ typename ArchTag, -+ // If Q/K/V are correctly aligned in memory and we can run a fast kernel -+ bool isAligned_, -+ int kQueriesPerBlock, -+ int kKeysPerBlock, -+ bool kSingleValueIteration // = `value.shape[-1] <= kKeysPerBlock` -+ > -+struct AttentionKernel { -+ using scalar_t = scalar_t_; -+ using accum_t = float; -+ using lse_scalar_t = float; -+ using output_t = scalar_t; -+ // Accumulator between 2 iterations -+ // Using `accum_t` improves perf on f16 at the cost of -+ // numerical errors -+ using output_accum_t = accum_t; -+ static constexpr bool kIsAligned = isAligned_; -+ static constexpr int32_t kAlignLSE = 32; // block size of backward -+ static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 && -+ cutlass::sizeof_bits::value == 16; -+ static constexpr bool kKeepOutputInRF = kSingleValueIteration; -+ static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF && -+ !cutlass::platform::is_same::value; -+ -+ static_assert(kQueriesPerBlock % 32 == 0, ""); -+ static_assert(kKeysPerBlock % 32 == 0, ""); -+ static constexpr int kNumWarpsPerBlock = -+ kQueriesPerBlock * kKeysPerBlock / (32 * 32); -+ static constexpr int kWarpSize = 32; -+ -+ // Launch bounds -+ static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock; -+ static constexpr int kMinBlocksPerSm = -+ getWarpsPerSm() / kNumWarpsPerBlock; -+ -+ struct Params { -+ // Input tensors -+ scalar_t* query_ptr; // [num_queries, num_heads, head_dim] -+ scalar_t* key_ptr; // [num_keys, num_heads, head_dim] -+ scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value] -+ int32_t* cu_seqlens_q_ptr = nullptr; -+ int32_t* cu_seqlens_k_ptr = nullptr; -+ scalar_t* attn_mask_ptr = nullptr; // [num_queries, num_keys] -+ scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys] -+ -+ // Output tensors -+ output_t* output_ptr; // [num_queries, num_heads, head_dim_value] -+ output_accum_t* -+ output_accum_ptr; // [num_queries, num_heads, head_dim_value] -+ lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null -+ float scale; -+ // Dimensions/strides -+ int32_t head_dim; -+ int32_t head_dim_value; -+ int32_t num_queries; -+ int32_t num_keys; -+ -+ bool causal; -+ bool no_bias_head_dim; -+ bool use_past; -+ -+ int32_t q_strideM; -+ int32_t k_strideM; -+ int32_t v_strideM; -+ int32_t attn_mask_strideM; -+ int32_t attn_bias_strideM; -+ -+ // Everything below is only used in `advance_to_block` -+ // and shouldn't use registers -+ int32_t q_strideH; -+ int32_t k_strideH; -+ int32_t v_strideH; -+ int32_t o_strideH; -+ int32_t attn_mask_strideH; -+ int32_t attn_bias_strideH; -+ int64_t q_strideB; -+ int64_t k_strideB; -+ int64_t v_strideB; -+ int64_t o_strideB; -+ int64_t attn_mask_strideB; -+ int64_t attn_bias_strideB; -+ int32_t num_batches; -+ int32_t num_heads; -+ -+ CUTLASS_HOST_DEVICE int32_t o_strideM() const { -+ return head_dim_value * num_heads; -+ } -+ -+ // Moves pointers to what we should process -+ // Returns "false" if there is no work to do -+ CUTLASS_DEVICE bool advance_to_block() { -+ auto batch_id = blockIdx.z; -+ auto head_id = blockIdx.y; -+ auto query_start = blockIdx.x * kQueriesPerBlock; -+ -+ auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE; -+ -+ int64_t q_start, k_start; -+ // Advance to current batch - in case of different sequence lengths -+ if (cu_seqlens_q_ptr != 
nullptr) { -+ assert(cu_seqlens_k_ptr != nullptr); -+ if (cu_seqlens_q_ptr[batch_id] == -1) return false; -+ q_strideH = q_strideM * cu_seqlens_q_ptr[batch_id]; -+ if (!use_past) { -+ k_strideH = k_strideM * cu_seqlens_k_ptr[batch_id]; -+ v_strideH = v_strideM * cu_seqlens_k_ptr[batch_id]; -+ } -+ num_queries = cu_seqlens_q_ptr[batch_id]; -+ num_keys = cu_seqlens_k_ptr[batch_id]; -+ for (int i = 0; i < batch_id; i++) -+ { -+ if (cu_seqlens_q_ptr[i] == -1) continue; -+ query_ptr += cu_seqlens_q_ptr[i] * head_dim * num_heads; -+ output_ptr += cu_seqlens_q_ptr[i] * head_dim * num_heads; -+ if (!use_past) { -+ key_ptr += cu_seqlens_k_ptr[i] * head_dim * num_heads; -+ value_ptr += cu_seqlens_k_ptr[i] * head_dim * num_heads; -+ } -+ } -+ if (use_past) { -+ key_ptr += batch_id * k_strideB; -+ value_ptr += batch_id * v_strideB; -+ } -+ if (query_start >= num_queries) { -+ return false; -+ } -+ q_start = 0; -+ k_start = 0; -+ } else { -+ query_ptr += batch_id * q_strideB; -+ key_ptr += batch_id * k_strideB; -+ value_ptr += batch_id * v_strideB; -+ output_ptr += batch_id * o_strideB; -+ if (output_accum_ptr != nullptr) { -+ output_accum_ptr += batch_id * o_strideB; -+ } -+ q_start = 0; -+ k_start = 0; -+ } -+ if (attn_mask_ptr) { -+ attn_mask_ptr += batch_id * (attn_mask_strideB); -+ } -+ if (attn_bias_ptr) { -+ attn_bias_ptr += head_id * attn_bias_strideH; -+ } -+ // Advance to the current batch / head / query_start -+ query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH; -+ key_ptr += k_start * k_strideM + head_id * k_strideH; -+ value_ptr += k_start * v_strideM + head_id * v_strideH; -+ output_ptr += int64_t(q_start + query_start) * o_strideM() + -+ head_id * o_strideH; -+ -+ if (output_accum_ptr != nullptr) { -+ output_accum_ptr += int64_t(q_start + query_start) * o_strideM() + -+ head_id * o_strideH; -+ } else { -+ // Accumulate directly in the destination buffer (eg for f32) -+ output_accum_ptr = (accum_t*)output_ptr; -+ } -+ if (logsumexp_ptr != nullptr) { -+ // lse[batch_id, head_id, query_start] -+ logsumexp_ptr += -+ batch_id * lse_dim * num_heads + head_id * lse_dim + query_start; -+ } -+ -+ num_queries -= query_start; -+ if (causal) { -+ num_keys = cutlass::fast_min( -+ int32_t(query_start + kQueriesPerBlock), num_keys); -+ } -+ num_batches = 0; // no longer used after -+ -+ // Make sure the compiler knows these variables are the same on all -+ // the threads of the warp. -+ if (attn_mask_ptr) { -+ attn_mask_ptr = warp_uniform(attn_mask_ptr); -+ } -+ if (attn_bias_ptr) { -+ attn_bias_ptr = warp_uniform(attn_bias_ptr); -+ } -+ query_ptr = warp_uniform(query_ptr); -+ key_ptr = warp_uniform(key_ptr); -+ value_ptr = warp_uniform(value_ptr); -+ output_ptr = warp_uniform(output_ptr); -+ output_accum_ptr = warp_uniform(output_accum_ptr); -+ logsumexp_ptr = warp_uniform(logsumexp_ptr); -+ num_queries = warp_uniform(num_queries); -+ num_keys = warp_uniform(num_keys); -+ head_dim = warp_uniform(head_dim); -+ head_dim_value = warp_uniform(head_dim_value); -+ return true; -+ } -+ -+ __host__ dim3 getBlocksGrid() const { -+ return dim3( -+ ceil_div(num_queries, (int32_t)kQueriesPerBlock), -+ num_heads, -+ num_batches); -+ } -+ __host__ dim3 getThreadsGrid() const { -+ return dim3(kWarpSize, kNumWarpsPerBlock, 1); -+ } -+ }; -+ -+ struct MM0 { -+ /* -+ In this first matmul, we compute a block of `Q @ K.T`. 
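Before the MM0 description continues, a worked example of the launch geometry defined just above may help. The numbers below assume a hypothetical 64x64 tile and a configuration for which `getWarpsPerSm()` evaluates to 16 (for instance an Sm80 target with a 16-bit scalar); none of these values are prescribed by the patch itself.

```cpp
#include <cstdio>

constexpr int kWarpSize        = 32;
constexpr int kQueriesPerBlock = 64;   // hypothetical tile
constexpr int kKeysPerBlock    = 64;
constexpr int kWarpsPerSm      = 16;   // assumed result of getWarpsPerSm()

// Same arithmetic as the kernel's launch-bound constants.
constexpr int kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (kWarpSize * kWarpSize);  // 4
constexpr int kNumThreads       = kWarpSize * kNumWarpsPerBlock;                               // 128
constexpr int kMinBlocksPerSm   = kWarpsPerSm / kNumWarpsPerBlock;                             // 4

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  int num_queries = 1000, num_heads = 8, num_batches = 2;
  // Mirrors getBlocksGrid()/getThreadsGrid(): one block per 64-query slice,
  // per head, per batch; each block holds 4 warps laid out as (32, 4, 1).
  std::printf("grid  = (%d, %d, %d)\n",
              ceil_div(num_queries, kQueriesPerBlock), num_heads, num_batches);
  std::printf("block = (%d, %d, 1), threads = %d, min blocks/SM = %d\n",
              kWarpSize, kNumWarpsPerBlock, kNumThreads, kMinBlocksPerSm);
  return 0;
}
```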
-+ While the calculation result is still hot in registers, we update -+ `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value -+ into a shared-memory ("AccumulatorSharedStorage") that is used later as -+ operand A for the second matmul (see MM1) -+ */ -+ using GemmType = DefaultGemmType; -+ -+ using OpClass = typename GemmType::OpClass; -+ using DefaultConfig = -+ typename cutlass::gemm::device::DefaultGemmConfiguration< -+ OpClass, -+ ArchTag, -+ scalar_t, -+ scalar_t, -+ scalar_t, // ElementC -+ accum_t // ElementAccumulator -+ >; -+ static constexpr int kAlignmentA = -+ kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment; -+ static constexpr int kAlignmentB = -+ kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; -+ using ThreadblockShape = cutlass::gemm:: -+ GemmShape; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; -+ using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma< -+ scalar_t, // ElementA, -+ cutlass::layout::RowMajor, // LayoutA, -+ kAlignmentA, -+ scalar_t, // ElementB, -+ cutlass::layout::ColumnMajor, // LayoutB, -+ kAlignmentB, -+ accum_t, -+ cutlass::layout::RowMajor, // LayoutC, -+ OpClass, -+ ArchTag, // ArchTag -+ ThreadblockShape, // ThreadblockShape -+ WarpShape, // WarpShape -+ typename GemmType::InstructionShape, // InstructionShape -+ DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but that -+ // uses too much smem -+ typename GemmType::Operator // Operator -+ >::DefaultMma; -+ using MmaCore = typename DefaultMma::MmaCore; -+ using IteratorA = typename DefaultMma::IteratorA; -+ using IteratorB = typename DefaultMma::IteratorB; -+ using Mma = typename DefaultMma::ThreadblockMma; -+ using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater< -+ typename Mma::Operator::IteratorC, -+ accum_t, -+ kWarpSize>::Updater; -+ static_assert( -+ MmaCore::WarpCount::kM * MmaCore::WarpCount::kN * -+ MmaCore::WarpCount::kK == -+ kNumWarpsPerBlock, -+ ""); -+ -+ // used for efficient load of mask tile Mij from global to shared memory -+ using MaskLoader = TileSmemLoader< -+ scalar_t, -+ cutlass::MatrixShape, -+ MmaCore::kThreads, -+ // input restriction: kv_len has to be a multiple of this value -+ 128 / cutlass::sizeof_bits::value>; -+ -+ // used for efficient load of mask tile Mij from global to shared memory -+ using BiasLoader = TileSmemLoader< -+ scalar_t, -+ cutlass::MatrixShape, -+ MmaCore::kThreads, -+ // input restriction: kv_len has to be a multiple of this value -+ 128 / cutlass::sizeof_bits::value>; -+ -+ // Epilogue to store to shared-memory in a format that we can use later for -+ // the second matmul -+ using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< -+ typename Mma::Operator::IteratorC, -+ typename Mma::Operator, -+ scalar_t, -+ WarpShape, -+ ThreadblockShape>; -+ using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; -+ }; -+ -+ struct MM1 { -+ /** -+ Second matmul: perform `attn @ V` where `attn` is the attention (not -+ normalized) and stored in shared memory -+ */ -+ using GemmType = DefaultGemmType; -+ -+ using OpClass = typename GemmType::OpClass; -+ using DefaultConfig = -+ typename cutlass::gemm::device::DefaultGemmConfiguration< -+ OpClass, -+ ArchTag, -+ scalar_t, -+ scalar_t, -+ output_accum_t, // ElementC -+ accum_t // ElementAccumulator -+ >; -+ static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem -+ static constexpr int kAlignmentB = -+ kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; -+ using ThreadblockShape = cutlass::gemm:: -+ GemmShape; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; -+ using InstructionShape = typename GemmType::InstructionShape; -+ -+ using LayoutB = cutlass::layout::RowMajor; -+ using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< -+ scalar_t, // ElementA, -+ cutlass::layout::RowMajor, // LayoutA, -+ kAlignmentA, -+ scalar_t, // ElementB, -+ LayoutB, // LayoutB, -+ kAlignmentB, -+ output_accum_t, -+ cutlass::layout::RowMajor, // LayoutC, -+ accum_t, -+ OpClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ typename GemmType::InstructionShape, -+ typename DefaultConfig::EpilogueOutputOp, -+ void, // ThreadblockSwizzle - not used -+ DefaultConfig::kStages, -+ false, // SplitKSerial -+ typename GemmType::Operator>; -+ -+ using DefaultMmaFromSmem = -+ typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< -+ typename DefaultGemm::Mma, -+ typename MM0::AccumulatorSharedStorage>; -+ using Mma = typename DefaultMmaFromSmem::Mma; -+ using IteratorB = typename Mma::IteratorB; -+ using WarpCount = typename Mma::WarpCount; -+ static_assert( -+ WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock, -+ ""); -+ -+ using DefaultEpilogue = typename DefaultGemm::Epilogue; -+ using OutputTileIterator = -+ typename cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename DefaultEpilogue::OutputTileIterator::ThreadMap, -+ output_t>; -+ using OutputTileIteratorAccum = -+ typename cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename DefaultEpilogue::OutputTileIterator::ThreadMap, -+ output_accum_t>; -+ -+ struct SharedStorageMM1 { -+ typename Mma::SharedStorage mm; -+ }; -+ }; -+ -+ static constexpr int64_t kAlignmentQ = MM0::kAlignmentA; -+ static constexpr int64_t kAlignmentK = MM0::kAlignmentB; -+ static constexpr int64_t kAlignmentV = 1; -+ -+ // Shared storage - depends on kernel params -+ struct ScalingCoefs { -+ cutlass::Array m_prime; -+ cutlass::Array s_prime; -+ cutlass::Array mi; -+ }; -+ -+ struct SharedStorageEpilogueAtEnd : ScalingCoefs { -+ struct SharedStorageAfterMM0 { -+ // Everything here might be overwritten during MM0 -+ union { -+ typename MM0::MaskLoader::SmemTile mask; -+ typename MM0::BiasLoader::SmemTile bias; -+ typename MM0::AccumulatorSharedStorage si; -+ }; -+ typename MM1::SharedStorageMM1 mm1; -+ }; -+ -+ union { -+ typename MM0::Mma::SharedStorage mm0; -+ SharedStorageAfterMM0 after_mm0; -+ typename MM1::DefaultEpilogue::SharedStorage epilogue; -+ }; -+ -+ CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& -+ epilogue_shared_storage() { -+ return epilogue; -+ } -+ }; -+ -+ struct SharedStorageEpilogueInLoop : ScalingCoefs { -+ struct SharedStorageAfterMM0 { -+ // Everything here might be overwritten during MM0 -+ union { -+ typename MM0::MaskLoader::SmemTile mask; -+ typename MM0::BiasLoader::SmemTile bias; -+ typename MM0::AccumulatorSharedStorage si; -+ }; -+ typename MM1::SharedStorageMM1 mm1; -+ typename MM1::DefaultEpilogue::SharedStorage epilogue; -+ }; -+ -+ union { -+ typename MM0::Mma::SharedStorage mm0; -+ SharedStorageAfterMM0 after_mm0; -+ }; -+ -+ CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& -+ epilogue_shared_storage() { -+ return after_mm0.epilogue; -+ } -+ }; -+ -+ using SharedStorage = typename cutlass::platform::conditional< -+ kSingleValueIteration || kKeepOutputInRF, -+ SharedStorageEpilogueAtEnd, -+ SharedStorageEpilogueInLoop>::type; -+ -+ static bool 
__host__ check_supported(Params const& p) { -+ CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ); -+ CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK); -+ CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV); -+ XFORMERS_CHECK( -+ p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned"); -+ XFORMERS_CHECK( -+ p.k_strideM % kAlignmentK == 0, "key is not correctly aligned"); -+ XFORMERS_CHECK( -+ p.v_strideM % kAlignmentV == 0, "value is not correctly aligned"); -+ XFORMERS_CHECK( -+ p.q_strideH % kAlignmentQ == 0, "query is not correctly aligned"); -+ XFORMERS_CHECK( -+ p.k_strideH % kAlignmentK == 0, "key is not correctly aligned"); -+ XFORMERS_CHECK( -+ p.v_strideH % kAlignmentV == 0, "value is not correctly aligned"); -+ -+ if (p.attn_mask_ptr) { -+ CHECK_ALIGNED_PTR(p.attn_mask_ptr, kAlignmentQ); -+ XFORMERS_CHECK( -+ p.attn_mask_strideB % kAlignmentQ == 0, -+ "attn_mask is not correctly aligned"); -+ XFORMERS_CHECK( -+ p.attn_mask_strideH % kAlignmentQ == 0, -+ "attn_mask is not correctly aligned"); -+ XFORMERS_CHECK( -+ p.attn_mask_strideM % kAlignmentQ == 0, -+ "attn_mask is not correctly aligned"); -+ } -+ if (p.attn_bias_ptr) { -+ CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ); -+ XFORMERS_CHECK( -+ p.attn_bias_strideB % kAlignmentQ == 0, -+ "attn_bias is not correctly aligned"); -+ XFORMERS_CHECK( -+ p.attn_bias_strideH % kAlignmentQ == 0, -+ "attn_bias is not correctly aligned"); -+ XFORMERS_CHECK( -+ p.attn_bias_strideM % kAlignmentQ == 0, -+ "attn_bias is not correctly aligned"); -+ } -+ return true; -+ } -+ -+ static void CUTLASS_DEVICE attention_kernel(Params& p) { -+ // In this block, we will only ever: -+ // - read query[query_start:query_end, :] -+ // - write to output[query_start:query_end, :] -+ -+ extern __shared__ char smem_buffer[]; -+ SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); -+ auto& m_prime = shared_storage.m_prime; -+ auto& s_prime = shared_storage.s_prime; -+ [[maybe_unused]] auto& si = shared_storage.after_mm0.si; -+ auto& mi = shared_storage.mi; -+ -+ static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); -+ if (thread_id() < kQueriesPerBlock) { -+ s_prime[thread_id()] = accum_t(0); -+ m_prime[thread_id()] = -+ -cutlass::platform::numeric_limits::infinity(); -+ mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); -+ } -+ typename MM1::Mma::FragmentC accum_o; -+ accum_o.clear(); -+ -+ auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { -+ using OutputTileIterator = typename MM1::OutputTileIterator; -+ return OutputTileIterator( -+ typename OutputTileIterator::Params{(int32_t)p.o_strideM()}, -+ p.output_ptr, -+ typename OutputTileIterator::TensorCoord{ -+ p.num_queries, p.head_dim_value}, -+ thread_id(), -+ {0, col}); -+ }; -+ -+ auto createOutputAccumIter = [&](int col) -> -+ typename MM1::OutputTileIteratorAccum { -+ using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; -+ return OutputTileIteratorAccum( -+ typename OutputTileIteratorAccum::Params{(int32_t)p.o_strideM()}, -+ p.output_accum_ptr, -+ typename OutputTileIteratorAccum::TensorCoord{ -+ p.num_queries, p.head_dim_value}, -+ thread_id(), -+ {0, col}); -+ }; -+ -+ // Iterate through keys -+ for (int32_t iter_key_start = 0; iter_key_start < p.num_keys; -+ iter_key_start += kKeysPerBlock) { -+ int32_t problem_size_0_m = -+ cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries); -+ int32_t problem_size_0_n = cutlass::fast_min( -+ int32_t(kKeysPerBlock), p.num_keys - iter_key_start); -+ int32_t const& problem_size_0_k = 
p.head_dim; -+ int32_t const& problem_size_1_n = p.head_dim_value; -+ int32_t const& problem_size_1_k = problem_size_0_n; -+ -+ auto prologueV = [&](int blockN) { -+ typename MM1::Mma::IteratorB iterator_V( -+ typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)}, -+ p.value_ptr + iter_key_start * p.v_strideM, -+ {problem_size_1_k, problem_size_1_n}, -+ thread_id(), -+ cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); -+ MM1::Mma::prologue( -+ shared_storage.after_mm0.mm1.mm, -+ iterator_V, -+ thread_id(), -+ problem_size_1_k); -+ }; -+ -+ __syncthreads(); // Need to have shared memory initialized, and `m_prime` -+ // updated from end of prev iter -+ // -+ // MATMUL: Q.K_t -+ // -+ // Computes the block-matrix product of: -+ // (a) query[query_start:query_end, :] -+ // with -+ // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] -+ // and stores that into `shared_storage.si` -+ // -+ -+ // Compute threadblock location -+ cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0}; -+ -+ cutlass::MatrixCoord tb_offset_A{ -+ tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()}; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN}; -+ -+ // Construct iterators to A and B operands -+ typename MM0::IteratorA iterator_A( -+ typename MM0::IteratorA::Params( -+ typename MM0::MmaCore::LayoutA(p.q_strideM)), -+ p.query_ptr, -+ {problem_size_0_m, problem_size_0_k}, -+ thread_id(), -+ tb_offset_A); -+ -+ typename MM0::IteratorB iterator_B( -+ typename MM0::IteratorB::Params( -+ typename MM0::MmaCore::LayoutB(p.k_strideM)), -+ p.key_ptr + iter_key_start * p.k_strideM, -+ {problem_size_0_k, problem_size_0_n}, -+ thread_id(), -+ tb_offset_B); -+ -+ auto my_warp_id = warp_id(); -+ auto my_lane_id = lane_id(); -+ -+ // Construct thread-scoped matrix multiply -+ typename MM0::Mma mma( -+ shared_storage.mm0, thread_id(), my_warp_id, my_lane_id); -+ -+ typename MM0::Mma::FragmentC accum; -+ -+ accum.clear(); -+ -+ auto gemm_k_iterations = -+ (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); -+ __syncthreads(); -+ -+ if (kPreloadV) { -+ prologueV(0); -+ } -+ -+ typename MM0::Mma::Operator::IteratorC::TensorCoord -+ iteratorC_tile_offset = { -+ (tb_tile_offset.m() * MM0::Mma::WarpCount::kM) + -+ (my_warp_id % MM0::Mma::WarpCount::kM), -+ (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) + -+ (my_warp_id / MM0::Mma::WarpCount::kM)}; -+ float scale = p.scale; -+ if (p.attn_bias_ptr) { -+ if (scale != 1.0f) { -+ accum = cutlass::multiplies()(scale, accum); -+ scale = 1.0f; -+ } -+ auto query_start = blockIdx.x * kQueriesPerBlock; -+ // load bias tile Bij into shared memory -+ typename MM0::BiasLoader::GmemTileIterator bias_iter( -+ {cutlass::layout::RowMajor(p.attn_bias_strideM)}, -+ // attn_bias_pointer points to matrix of size (n_queries, n_keys) -+ // for the relevant batch_id -+ p.attn_bias_ptr + query_start * p.attn_bias_strideM + iter_key_start, -+ {problem_size_0_m, problem_size_0_n}, -+ thread_id()); -+ cutlass::TensorRef bias_tensor_ref( -+ shared_storage.after_mm0.bias.data(), -+ cutlass::layout::RowMajor(MM0::ThreadblockShape::kN)); -+ typename MM0::BiasLoader::SmemTileIterator smem_tile_iter( -+ bias_tensor_ref, thread_id()); -+ -+ MM0::BiasLoader::load(bias_iter, smem_tile_iter); -+ // Pij += Bij, Pij is in register fragment and Bij is in shared memory -+ auto lane_offset = 
MM0::ScalingCoefsUpdater::get_lane_offset( -+ lane_id(), warp_id(), iteratorC_tile_offset); -+ MM0::ScalingCoefsUpdater::iterateRows( -+ lane_offset, -+ [&](int accum_m) {}, -+ [&](int accum_m, int accum_n, int idx) { -+ if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) { -+ accum[idx] += bias_tensor_ref.at({accum_m, accum_n}); -+ } -+ }, -+ [&](int accum_m) {}); -+ } -+ -+ if (p.attn_mask_ptr) { -+ // scale*Q.K_t prio to mask apply -+ if (scale != 1.0f) { -+ accum = cutlass::multiplies()(scale, accum); -+ scale = 1.0f; -+ } -+ auto query_start = blockIdx.x * kQueriesPerBlock; -+ // load mask tile Mij into shared memory -+ typename MM0::MaskLoader::GmemTileIterator mask_iter( -+ {cutlass::layout::RowMajor(p.attn_mask_strideM)}, -+ // attn_mask_pointer points to matrix of size (n_queries, n_keys) -+ // for the relevant batch_id -+ p.attn_mask_ptr + query_start * p.attn_mask_strideM + iter_key_start, -+ {problem_size_0_m, problem_size_0_n}, -+ thread_id()); -+ cutlass::TensorRef mask_tensor_ref( -+ shared_storage.after_mm0.mask.data(), -+ cutlass::layout::RowMajor(MM0::ThreadblockShape::kN)); -+ typename MM0::MaskLoader::SmemTileIterator smem_tile_iter( -+ mask_tensor_ref, thread_id()); -+ -+ MM0::MaskLoader::load(mask_iter, smem_tile_iter); -+ // Pij += (Mij-1)*(-10000), Pij is in register fragment and Mij is in shared memory -+ auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset( -+ lane_id(), warp_id(), iteratorC_tile_offset); -+ MM0::ScalingCoefsUpdater::iterateRows( -+ lane_offset, -+ [&](int accum_m) {}, -+ [&](int accum_m, int accum_n, int idx) { -+ if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) { -+ scalar_t tmp = scalar_t(1.0) - mask_tensor_ref.at({accum_m, accum_n}); -+ accum[idx] += tmp*scalar_t(-10000.0f); -+ } -+ }, -+ [&](int accum_m) {}); -+ } -+ -+ // Mask out last if causal -+ if (p.causal && p.num_keys - iter_key_start <= kKeysPerBlock) { -+ auto query_start = blockIdx.x * kQueriesPerBlock; -+ auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset( -+ lane_id(), warp_id(), iteratorC_tile_offset); -+ int32_t last_col; -+ MM0::ScalingCoefsUpdater::iterateRows( -+ lane_offset, -+ [&](int accum_m) { -+ last_col = query_start + accum_m - iter_key_start; -+ }, -+ [&](int accum_m, int accum_n, int idx) { -+ if (accum_n > last_col) { -+ accum[idx] = -+ -cutlass::platform::numeric_limits::infinity(); -+ } -+ }, -+ [&](int accum_m) {}); -+ } -+ DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] { -+ DISPATCH_BOOL( -+ p.num_keys - iter_key_start >= kKeysPerBlock, -+ kFullColumns, -+ ([&] { -+ // Update `mi` from accum stored in registers -+ // Also updates `accum` with accum[i] <- -+ // exp(accum[i] * scale -+ // - mi) -+ MM0::ScalingCoefsUpdater::update< -+ kQueriesPerBlock, -+ kFullColumns, -+ kIsFirst, -+ kKeepOutputInRF>( -+ accum_o, -+ accum, -+ mi, -+ m_prime, -+ s_prime, -+ lane_id(), -+ thread_id(), -+ warp_id(), -+ p.num_keys - iter_key_start, -+ iteratorC_tile_offset, -+ scale); -+ })); -+ })); -+ -+ // Output results to shared-memory -+ int warp_idx_mn_0 = my_warp_id % -+ (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN); -+ auto output_tile_coords = cutlass::MatrixCoord{ -+ warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, -+ warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; -+ -+ MM0::B2bGemm::accumToSmem( -+ shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords); -+ -+ __syncthreads(); -+ -+ // -+ // MATMUL: Attn . V -+ // Run the matmul `attn @ V` for a block of attn and V. 
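The `ScalingCoefsUpdater::update<...>()` call above is the heart of the kernel: it performs the streaming ("online") softmax bookkeeping, turning raw `Q.K_t` scores into `exp(score - mi)` in place while keeping `mi`, `m_prime`, `s_prime`, and the output accumulator consistent across key blocks. The standalone scalar sketch below (one query row, one-dimensional values, block size 2; none of it CUTLASS code) shows the same recurrence and why previously accumulated quantities must be rescaled whenever the running row maximum grows.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Pre-scaled attention scores q.k_j and their (1-D) values v_j.
  std::vector<float> scores = {0.3f, 1.7f, -0.4f, 2.2f, 0.9f, -1.1f};
  std::vector<float> values = {1.0f, 2.0f,  3.0f, 4.0f, 5.0f,  6.0f};
  const std::size_t kKeysPerBlock = 2;

  float m   = -INFINITY;  // running row maximum (mi / m_prime)
  float s   = 0.0f;       // running sum of exp(score - m) (s_prime)
  float acc = 0.0f;       // un-normalized output accumulator (accum_o)

  for (std::size_t start = 0; start < scores.size(); start += kKeysPerBlock) {
    std::size_t end = std::min(start + kKeysPerBlock, scores.size());
    float m_new = m;
    for (std::size_t j = start; j < end; ++j) m_new = std::max(m_new, scores[j]);
    float correction = std::exp(m - m_new);  // rescale everything computed so far
    s   *= correction;
    acc *= correction;
    for (std::size_t j = start; j < end; ++j) {
      float p = std::exp(scores[j] - m_new);
      s   += p;
      acc += p * values[j];
    }
    m = m_new;
  }
  // Final normalization, and the quantity stored by the logsumexp step later on.
  std::printf("output = %f, logsumexp = %f\n", acc / s, m + std::log(s));
  return 0;
}
```

On the first block `correction` is `exp(-inf - m_new) = 0`, which corresponds to the `kIsFirst` case dispatched above; the `kFullColumns` flag dispatched alongside it tells the updater whether the current key block is only partially filled.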
-+ // `attn` is read from shared memory (in `shared_storage_si`) -+ // `V` is read from global memory (with iterator_B) -+ // -+ -+ const int64_t nBlockN = kSingleValueIteration -+ ? 1 -+ : ceil_div( -+ (int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN)); -+ for (int blockN = 0; blockN < nBlockN; ++blockN) { -+ int gemm_k_iterations = -+ (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add and store it in accum -+ // (in registers) -+ if (!kPreloadV) { -+ __syncthreads(); // we share shmem between mma and epilogue -+ } -+ -+ typename MM1::Mma::IteratorB iterator_V( -+ typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)}, -+ p.value_ptr + iter_key_start * p.v_strideM, -+ {problem_size_1_k, problem_size_1_n}, -+ thread_id(), -+ cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); -+ typename MM1::Mma mma_pv( -+ shared_storage.after_mm0.mm1.mm, -+ shared_storage.after_mm0.si, -+ (int)thread_id(), -+ (int)warp_id(), -+ (int)lane_id(), -+ (int)problem_size_1_k); -+ mma_pv.set_prologue_done(kPreloadV); -+ if (!kKeepOutputInRF) { -+ accum_o.clear(); -+ } -+ mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o); -+ __syncthreads(); -+ -+ if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) { -+ prologueV(blockN + 1); -+ } -+ -+ if (!kKeepOutputInRF) { -+ DISPATCH_BOOL( -+ iter_key_start == 0, kIsFirst, ([&] { -+ DISPATCH_BOOL( -+ (iter_key_start + kKeysPerBlock) >= p.num_keys, -+ kIsLast, -+ ([&] { -+ using DefaultEpilogue = typename MM1::DefaultEpilogue; -+ using DefaultOp = -+ typename MM1::DefaultConfig::EpilogueOutputOp; -+ using ElementCompute = typename DefaultOp::ElementCompute; -+ using EpilogueOutputOp = typename cutlass::epilogue:: -+ thread::MemoryEfficientAttentionNormalize< -+ typename cutlass::platform::conditional< -+ kIsLast, -+ output_t, -+ output_accum_t>::type, -+ output_accum_t, -+ DefaultOp::kCount, -+ typename DefaultOp::ElementAccumulator, -+ ElementCompute, -+ kIsFirst, -+ kIsLast, -+ cutlass::Array>; -+ using Epilogue = typename cutlass::epilogue::threadblock:: -+ EpiloguePipelined< -+ typename DefaultEpilogue::Shape, -+ typename MM1::Mma::Operator, -+ DefaultEpilogue::kPartitionsK, -+ typename cutlass::platform::conditional< -+ kIsLast, -+ typename MM1::OutputTileIterator, -+ typename MM1::OutputTileIteratorAccum>::type, -+ typename DefaultEpilogue:: -+ AccumulatorFragmentIterator, -+ typename DefaultEpilogue::WarpTileIterator, -+ typename DefaultEpilogue::SharedLoadIterator, -+ EpilogueOutputOp, -+ typename DefaultEpilogue::Padding, -+ DefaultEpilogue::kFragmentsPerIteration, -+ true, // IterationsUnroll -+ typename MM1::OutputTileIteratorAccum // Read -+ // iterator -+ >; -+ -+ int col = blockN * MM1::Mma::Shape::kN; -+ auto source_iter = createOutputAccumIter(col); -+ auto dest_iter = call_conditional< -+ kIsLast, -+ decltype(createOutputIter), -+ decltype(createOutputAccumIter)>:: -+ apply(createOutputIter, createOutputAccumIter, col); -+ EpilogueOutputOp rescale(s_prime, m_prime); -+ Epilogue epilogue( -+ shared_storage.epilogue_shared_storage(), -+ thread_id(), -+ warp_id(), -+ lane_id()); -+ epilogue(rescale, dest_iter, accum_o, source_iter); -+ })); -+ })); -+ if (!kSingleValueIteration) { -+ __syncthreads(); -+ } -+ } -+ } -+ __syncthreads(); // we modify `m_prime` after -+ } -+ -+ if (kKeepOutputInRF) { -+ constexpr bool kIsFirst = true; -+ constexpr bool kIsLast = true; -+ using DefaultEpilogue = typename MM1::DefaultEpilogue; -+ using DefaultOp = typename 
MM1::DefaultConfig::EpilogueOutputOp; -+ using ElementCompute = typename DefaultOp::ElementCompute; -+ using EpilogueOutputOp = -+ typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize< -+ output_t, // output -+ output_accum_t, // source -+ DefaultOp::kCount, -+ typename DefaultOp::ElementAccumulator, // accum -+ output_accum_t, // compute -+ kIsFirst, -+ kIsLast, -+ cutlass::Array>; -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::EpiloguePipelined< -+ typename DefaultEpilogue::Shape, -+ typename MM1::Mma::Operator, -+ DefaultEpilogue::kPartitionsK, -+ typename MM1::OutputTileIterator, // destination -+ typename DefaultEpilogue::AccumulatorFragmentIterator, -+ typename DefaultEpilogue::WarpTileIterator, -+ typename DefaultEpilogue::SharedLoadIterator, -+ EpilogueOutputOp, -+ typename DefaultEpilogue::Padding, -+ DefaultEpilogue::kFragmentsPerIteration, -+ true, // IterationsUnroll -+ typename MM1::OutputTileIteratorAccum // source tile -+ >; -+ auto dest_iter = createOutputIter(0); -+ EpilogueOutputOp rescale(s_prime, m_prime); -+ Epilogue epilogue( -+ shared_storage.epilogue_shared_storage(), -+ thread_id(), -+ warp_id(), -+ lane_id()); -+ epilogue(rescale, dest_iter, accum_o); -+ } -+ -+ // 7. Calculate logsumexp -+ // To make the backward easier, we pad logsumexp with `inf` -+ // this avoids a few bound checks, and is not more expensive during fwd -+ static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); -+ if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) { -+ auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE; -+ if (thread_id() < p.num_queries) { -+ p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()]) + -+ cutlass::fast_log(accum_t(s_prime[thread_id()])); -+ } else if (thread_id() < lse_dim) { -+ p.logsumexp_ptr[thread_id()] = -+ cutlass::platform::numeric_limits::infinity(); -+ } -+ } -+ } -+ -+ static CUTLASS_DEVICE int8_t lane_id() { -+ return threadIdx.x; -+ } -+ static CUTLASS_DEVICE int8_t warp_id() { -+ return threadIdx.y; -+ } -+ static CUTLASS_DEVICE int16_t thread_id() { -+ return threadIdx.x + threadIdx.y * blockDim.x; -+ } -+}; -+ -+template -+__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) -+ attention_kernel_batched_impl(typename AK::Params p) { -+ if (!p.advance_to_block()) { -+ return; -+ } -+ AK::attention_kernel(p); -+} -+ -+template -+__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) -+ attention_kernel_batched(typename AK::Params params); -+ -+#define _ATTENTION_KERNEL_FORWARD_BEGIN(...) 
\ -+ template <> \ -+ __global__ void __launch_bounds__( \ -+ __VA_ARGS__::kNumThreads, __VA_ARGS__::kMinBlocksPerSm) \ -+ attention_kernel_batched<__VA_ARGS__>(typename __VA_ARGS__::Params p) { \ -+ using Kernel = __VA_ARGS__; -+#define _ATTENTION_KERNEL_FORWARD_END() } -+ -+#ifdef __CUDA_ARCH__ -+#define __CUDA_ARCH_OR_ZERO__ __CUDA_ARCH__ -+#else -+#define __CUDA_ARCH_OR_ZERO__ 0 -+#endif -+ -+#define INSTANTIATE_ATTENTION_KERNEL_FORWARD( \ -+ ARCH, \ -+ SCALAR_T, \ -+ IS_ALIGNED, \ -+ QUERIES_PER_BLOCK, \ -+ KEYS_PER_BLOCK, \ -+ SINGLE_VALUE_ITER) \ -+ _ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \ -+ SCALAR_T, \ -+ cutlass::arch::Sm##ARCH, \ -+ IS_ALIGNED, \ -+ QUERIES_PER_BLOCK, \ -+ KEYS_PER_BLOCK, \ -+ SINGLE_VALUE_ITER>) \ -+ if (!p.advance_to_block()) { \ -+ return; \ -+ } \ -+ Kernel::attention_kernel(p); \ -+ _ATTENTION_KERNEL_FORWARD_END(); -+ -+#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED( \ -+ ARCH, \ -+ SCALAR_T, \ -+ IS_ALIGNED, \ -+ QUERIES_PER_BLOCK, \ -+ KEYS_PER_BLOCK, \ -+ SINGLE_VALUE_ITER) \ -+ _ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \ -+ SCALAR_T, \ -+ cutlass::arch::Sm##ARCH, \ -+ IS_ALIGNED, \ -+ QUERIES_PER_BLOCK, \ -+ KEYS_PER_BLOCK, \ -+ SINGLE_VALUE_ITER>) \ -+ printf( \ -+ "FATAL: this function is for sm%d, but was built for sm%d\n", \ -+ int(ARCH), \ -+ int(__CUDA_ARCH_OR_ZERO__)); \ -+ _ATTENTION_KERNEL_FORWARD_END(); -+ -+// All kernels are disabled by default -+#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \ -+ INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(50, __VA_ARGS__) -+#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \ -+ INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(70, __VA_ARGS__) -+#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \ -+ INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(75, __VA_ARGS__) -+#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \ -+ INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(80, __VA_ARGS__) -+ -+// Enable the right one based on __CUDA_ARCH__ -+#ifndef __CUDA_ARCH__ -+#elif __CUDA_ARCH__ < 500 -+#error "Need cuda arch at least 5.0" -+#elif __CUDA_ARCH__ < 700 -+#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50 -+#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \ -+ INSTANTIATE_ATTENTION_KERNEL_FORWARD(50, __VA_ARGS__) -+#elif __CUDA_ARCH__ < 750 -+#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70 -+#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \ -+ INSTANTIATE_ATTENTION_KERNEL_FORWARD(70, __VA_ARGS__) -+#elif __CUDA_ARCH__ < 800 -+#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75 -+#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \ -+ INSTANTIATE_ATTENTION_KERNEL_FORWARD(75, __VA_ARGS__) -+#elif __CUDA_ARCH__ >= 800 -+#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80 -+#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \ -+ INSTANTIATE_ATTENTION_KERNEL_FORWARD(80, __VA_ARGS__) -+#endif -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/mma_from_smem.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/mma_from_smem.h -new file mode 100644 -index 0000000..21ac4d1 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/mma_from_smem.h -@@ -0,0 +1,1780 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights -+ *reserved. 
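A hypothetical `.cu` translation unit showing how the instantiation macros defined at the end of kernel_forward.h (just above) are intended to be used: each file names one (arch, dtype, alignment, tile) combination, and the SM-specific macro expands to a real kernel definition only when the file is compiled for that architecture, otherwise to the disabled stub that prints the fatal-mismatch message. The chosen parameters are illustrative, and the include paths are assumed to point at the example directory.

```cpp
// attention_forward_f16_sm80.cu (hypothetical file name)
#include "cutlass/half.h"
#include "kernel_forward.h"

// f16, aligned inputs, 64x64 tile, single value iteration (head_dim_value <= 64).
INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(cutlass::half_t, true, 64, 64, true);
```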
SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, -+ *this list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -+ *POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/threadblock/vector_iterator.h" -+ -+#include "attention_scaling_coefs_updater.h" -+#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" -+#include "epilogue_thread_apply_logsumexp.h" -+#include "gemm_kernel_utils.h" -+#include "iterators/make_residual_last.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+/// Shared storage object needed by accumulator -+/// From 13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h -+template < -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename Padding_> -+class AccumulatorSharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using Padding = Padding_; -+ -+ /// Tensor reference to the accumulator -+ using TensorRefAccum = cutlass::TensorRef; -+ -+ /// Shape of the 
accumulator matrix in shared memory -+ using ShapeAccum = cutlass:: -+ MatrixShape; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for accumulator -+ cutlass::AlignedBuffer accum; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the Accum matrix -+ CUTLASS_DEVICE -+ static Layout LayoutAccum() { -+ return Layout::packed({ShapeAccum::kRow, ShapeAccum::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the Accumulator -+ CUTLASS_HOST_DEVICE -+ TensorRefAccum accum_ref() { -+ return TensorRefAccum{accum.data(), LayoutAccum()}; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Taken from -+// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ // Maximum value for K -+ int kMaxK, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaBaseFromSharedMemory { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. -+ using WarpGemm = typename Policy::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount = GemmShape< -+ Shape::kM / WarpGemm::kM, -+ Shape::kN / WarpGemm::kN, -+ Shape::kK / WarpGemm::kK>; -+ using WarpCount1 = WarpCount; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations = -+ (WarpGemm::kK / Operator::Policy::MmaShape::kK); -+ static int const kWarpGemmIterations1 = kWarpGemmIterations; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// If this is true, we fill the entire shmem buffer at start -+ /// and don't need to iterate through it in a circular fashion -+ static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * kStages; -+ -+ /// Tensor reference to the A operand -+ using TensorRefA = -+ TensorRef; -+ -+ /// Tensor reference to the B operand -+ using TensorRefB = -+ TensorRef; -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ class SharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = MatrixShape< -+ Shape::kK * kStages + Policy::SmemPaddingB::kRow, -+ Shape::kN + Policy::SmemPaddingB::kColumn>; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for B operand -+ AlignedBuffer operand_B; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the B matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutB LayoutB() { -+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the B operand -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B_ref() { -+ return TensorRefB{operand_B.data(), LayoutB()}; -+ } -+ }; -+ -+ protected: -+ // -+ // Data members -+ // -+ -+ 
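A worked example of the bookkeeping in `MmaBaseFromSharedMemory` above, using hypothetical tile sizes (none of them taken from this patch): a 32-deep threadblock K, a 32-wide warp-level K, a 16-wide MMA K shape, two stages, and a maximum K extent of 128.

```cpp
#include <cstdio>

constexpr int kShapeK = 32;   // threadblock tile K (Shape::kK)
constexpr int kWarpK  = 32;   // warp tile K (WarpGemm::kK)
constexpr int kMmaK   = 16;   // instruction-level K (Operator::Policy::MmaShape::kK)
constexpr int kStages = 2;
constexpr int kMaxK   = 128;  // maximum K extent this Mma must cover

constexpr int kWarpGemmIterations = kWarpK / kMmaK;                  // 2
// True only when the staged shared-memory buffer already covers the whole K
// extent, in which case the circular-buffer advance logic can be skipped.
constexpr bool kSmemContainsEntireB = (kMaxK <= kShapeK * kStages);  // 128 <= 64 -> false

int main() {
  std::printf("warp-gemm iterations = %d, smem holds entire B: %s\n",
              kWarpGemmIterations, kSmemContainsEntireB ? "yes" : "no");
  return 0;
}
```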
// /// Iterator to load a warp-scoped tile of A operand from shared memory -+ // typename Operator::IteratorA warp_tile_iterator_A_; -+ -+ /// Iterator to load a warp-scoped tile of B operand from shared memory -+ typename Operator::IteratorB warp_tile_iterator_B_; -+ -+ public: -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaBaseFromSharedMemory( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ SharedStorage& shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Taken from -+// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ // BEGIN smem -+ /// Iterates over the intermediate accumulator tile in shared memory -+ typename WarpIteratorA, -+ // Accumulator type -+ typename AccumulatorSharedStorage, -+ // END smem -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Transformation applied to B operand -+ typename TransformB_ = NumericArrayConverter< -+ typename SmemIteratorB_::Element, -+ typename IteratorB_::Element, -+ IteratorB_::Fragment::kElements>, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< -+ Shape_, -+ AccumulatorSharedStorage::Shape::kN, -+ Policy_, -+ 2> { -+ public: -+ ///< Base class -+ using Base = MmaBaseFromSharedMemory< -+ Shape_, -+ AccumulatorSharedStorage::Shape::kN, -+ Policy_, -+ 2>; -+ -+ using Shape = -+ Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorB = -+ IteratorB_; ///< Iterates over tiles of B operand in global memory -+ using ElementC = ElementC_; ///< Data type of accumulator matrix -+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix -+ using Policy = Policy_; ///< Policy describing tuning details -+ -+ using SmemIteratorB = SmemIteratorB_; -+ -+ using TransformB = TransformB_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Obtain the arch tag from the warp-level operator -+ using ArchTag = typename Policy::Operator::ArchTag; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ // staticaly assert kStages for MmaPipelined is two (Double-buffered 
pipeline) -+ static_assert( -+ (Base::kStages == 2), -+ "MmaPipelined requires kStages set to value 2"); -+ -+ private: -+ using WarpFragmentA = typename Operator::FragmentA; -+ using WarpFragmentB = typename Operator::FragmentB; -+ -+ protected: -+ // /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ // SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+ /// Iterator to load a warp-scoped tile of A operand from intermediate -+ /// accumulator tile -+ WarpIteratorA warp_tile_iterator_A_; -+ -+ public: -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaPipelinedFromSharedMemory( -+ typename Base::SharedStorage& -+ shared_storage, ///< Shared storage needed for internal use by -+ ///< threadblock-scoped GEMM -+ AccumulatorSharedStorage& accumulator_shared_storage, -+ int thread_idx, ///< ID within the threadblock -+ int warp_idx, ///< ID of warp -+ int lane_idx, ///< ID of each thread within a warp -+ int problem_size_0_n) -+ : Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ warp_tile_iterator_A_(accumulator_shared_storage.accum_ref(), lane_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ // For API compatibility with MmaMultistageFromSharedMemory -+ // but not supported as it worsens perf: older gpus < sm80 don't -+ // support async tranfers and have to waste registers -+ CUTLASS_DEVICE -+ void set_prologue_done(bool value) {} -+ CUTLASS_DEVICE -+ static void prologue( -+ typename Base::SharedStorage& shared_storage, -+ IteratorB iterator_B1, -+ int thread_idx, -+ int problem_size_0_n) {} -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ int gemm_k_iterations, ///< number of iterations of the mainloop -+ FragmentC& accum, ///< destination accumulator tile -+ // IteratorA iterator_A, ///< iterator over A -+ // operand in global memory -+ IteratorB iterator_B, ///< iterator over B operand in global memory -+ FragmentC const& src_accum, ///< source accumulator tile -+ // TransformA transform_A = TransformA(), ///< transformation -+ // applied to A fragment -+ TransformB transform_B = -+ TransformB()) { ///< transformation applied to B fragment -+ -+ // -+ // Prologue -+ // -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ FragmentB tb_frag_B; -+ -+ tb_frag_B.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_B.set_residual_tile(gemm_k_iterations == 1); -+ iterator_B.load(tb_frag_B); -+ -+ ++iterator_B; -+ -+ 
this->smem_iterator_B_.store(transform_B(tb_frag_B)); -+ -+ ++this->smem_iterator_B_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpFragmentA warp_frag_A[2]; -+ WarpFragmentB warp_frag_B[2]; -+ warp_frag_A[0].clear(); -+ warp_frag_B[0].clear(); -+ -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ Operator warp_mma; -+ -+ int smem_write_stage_idx = 1; -+ -+ // Avoid reading out of bounds -+ iterator_B.set_residual_tile(gemm_k_iterations == 2); -+ iterator_B.clear_mask(gemm_k_iterations <= 1); -+ -+ // Issue loads during the first warp-level matrix multiply-add *AFTER* -+ // issuing shared memory loads (which have the tighest latency requirement). -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > 0; --gemm_k_iterations) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ bool hasNext = true; -+ -+ if (warp_mma_k == Base::kWarpGemmIterations - 1) { -+ // Write fragments to shared memory -+ this->smem_iterator_B_.store(transform_B(tb_frag_B)); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_B_; -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory SMEM: Don't reset iterator A, as -+ // we are continuing our iteration at this point -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ } else { -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ hasNext = gemm_k_iterations > 1; -+ } -+ -+ // Only read the next if we need to -+ if (hasNext) { -+ this->warp_tile_iterator_B_.set_kgroup_index( -+ (warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k == 0) { -+ iterator_B.load(tb_frag_B); -+ -+ ++iterator_B; -+ -+ // Avoid reading out of bounds if this was the last loop iteration -+ iterator_B.set_residual_tile(gemm_k_iterations == 3); -+ iterator_B.clear_mask(gemm_k_iterations <= 2); -+ } -+ } -+ -+ warp_mma( -+ accum, -+ warp_frag_A[warp_mma_k % 2], -+ warp_frag_B[warp_mma_k % 2], -+ accum); -+ } -+ } -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Taken from -+// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. 
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape1_, -+ /// Iterates over the intermediate accumulator tile in shared memory -+ typename WarpIteratorA1_, -+ // Accumulator type -+ typename AccumulatorSharedStorage, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB1_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB1_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB1, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy1_, -+ /// Number of stages, -+ int Stages_, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory< -+ Shape1_, -+ AccumulatorSharedStorage::Shape::kN, -+ Policy1_, -+ Stages_> { -+ public: -+ ///< Base class -+ using Base = MmaBaseFromSharedMemory< -+ Shape1_, -+ AccumulatorSharedStorage::Shape::kN, -+ Policy1_, -+ Stages_>; -+ -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape1 = Shape1_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB1 = IteratorB1_; -+ using IteratorB = IteratorB1; -+ ///< Policy describing tuning details -+ using Policy1 = Policy1_; -+ -+ using SmemIteratorB1 = SmemIteratorB1_; -+ using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate -+ ///< accumulator tile in shared memory -+ -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; -+ static constexpr bool kSmemContainsEntireB = Base::kSmemContainsEntireB; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC1 = typename Policy1::Operator::FragmentC; -+ using FragmentC = FragmentC1; -+ -+ /// Warp-level Mma -+ using Operator1 = typename Policy1::Operator; -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB1 = Operator1::kTransformB; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ static_assert( -+ Base::kWarpGemmIterations1 > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const TBLoadIterationsB1 = -+ IteratorB1::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB1 = -+ (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) / -+ Base::kWarpGemmIterations1; -+ }; -+ -+ static constexpr int kNumStagesConcurrentLoad = -+ kSmemContainsEntireB ? 
Base::kStages : Base::kStages - 1; -+ -+ private: -+ using WarpLoadedFragmentA1 = typename Operator1::FragmentA; -+ using WarpLoadedFragmentB1 = typename Operator1::FragmentB; -+ using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; -+ using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A1 operand from intermediate -+ /// accumulator tile -+ WarpIteratorA1 warp_tile_iterator_A1_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB1 smem_iterator_B1_; -+ -+ bool prologue_done_; -+ -+ public: -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaMultistageFromSharedMemory( -+ typename Base::SharedStorage& -+ shared_storage, ///< Shared storage needed for internal use by -+ ///< threadblock-scoped GEMM -+ AccumulatorSharedStorage& accumulator_shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx, -+ ///< GEMM0 N is used for accumulator extent -+ int problem_size_0_n) -+ : Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ warp_tile_iterator_A1_( -+ accumulator_shared_storage.accum_ref(), -+ lane_idx), -+ smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx), -+ prologue_done_(false) { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn_1 = -+ warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); -+ int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); -+ -+ int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; -+ int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ warp_tile_iterator_A1_.add_tile_offset( -+ {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); -+ } -+ -+ CUTLASS_DEVICE -+ void set_prologue_done(bool value) { -+ prologue_done_ = value; -+ } -+ -+ CUTLASS_DEVICE -+ static void prologue( -+ typename Base::SharedStorage& shared_storage, -+ IteratorB iterator_B1, -+ int thread_idx, -+ int problem_size_0_n) { -+ SmemIteratorB1 smem_iterator_B1(shared_storage.operand_B_ref(), thread_idx); -+ _prologue( -+ iterator_B1, -+ (problem_size_0_n + Base::Shape::kK - 1) / Base::Shape::kK, -+ smem_iterator_B1); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance_1( -+ IteratorB1& iterator_B1, -+ int group_start_B1 = 0) { -+ iterator_B1.set_iteration_index( -+ group_start_B1 * IteratorB1::kAccessesPerVector); -+ this->smem_iterator_B1_.set_iteration_index(group_start_B1); -+ -+ // Load for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { -+ if (group_start_B1 + j < Detail::TBLoadIterationsB1) { -+ typename IteratorB1::AccessType* dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B1_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB1::ThreadMap::kElementsPerAccess / -+ IteratorB1::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { -+ auto 
gmem_ptr = iterator_B1.get(); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_B1.valid()); -+ -+ ++iterator_B1; -+ } -+ ++this->smem_iterator_B1_; -+ } -+ } -+ } -+ -+ CUTLASS_DEVICE -+ static void _prologue( -+ IteratorB& iterator_B1, -+ int32_t gemm_k_iterations_1, -+ SmemIteratorB1& smem_iterator_B1_) { -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < kNumStagesConcurrentLoad; -+ ++stage, --gemm_k_iterations_1) { -+ iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1); -+ iterator_B1.clear_mask(gemm_k_iterations_1 == 0); -+ -+ iterator_B1.set_iteration_index(0); -+ smem_iterator_B1_.set_iteration_index(0); -+ -+ // Load for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { -+ typename IteratorB1::AccessType* dst_ptr = -+ reinterpret_cast( -+ smem_iterator_B1_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB1::ThreadMap::kElementsPerAccess / -+ IteratorB1::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); -+ -+ ++iterator_B1; -+ } -+ -+ ++smem_iterator_B1_; -+ } -+ -+ // Move to the next stage -+ iterator_B1.add_tile_offset({1, 0}); -+ -+ smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1); -+ iterator_B1.clear_mask(gemm_k_iterations_1 == 0); -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations_1_, -+ ///< destination accumulator tile -+ FragmentC1& accum, -+ ///< iterator over B1 operand in global memory -+ IteratorB1 iterator_B1, -+ ///< initial value of accumulator -+ FragmentC1 const& src_accum) { -+ // 2nd Gemm -+ -+ // -+ // Prologue -+ // -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ if (!prologue_done_) { -+ _prologue(iterator_B1, gemm_k_iterations_1_, smem_iterator_B1_); -+ } else if (!kSmemContainsEntireB) { -+ // Restore the iterators increments -+ -+ int gemm_k_iterations_1 = gemm_k_iterations_1_; -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < kNumStagesConcurrentLoad; -+ ++stage, --gemm_k_iterations_1) { -+ iterator_B1.set_iteration_index(0); -+ this->smem_iterator_B1_.set_iteration_index(0); -+ -+ // Load for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { -+ ++iterator_B1; -+ } -+ ++this->smem_iterator_B1_; -+ } -+ iterator_B1.add_tile_offset({1, 0}); -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ } -+ iterator_B1.set_residual_tile(gemm_k_iterations_1 <= 1); -+ iterator_B1.clear_mask(gemm_k_iterations_1 <= 0); -+ } -+ -+ // DEPBAR+SYNC -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; -+ WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; -+ WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; -+ WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; -+ -+ Operator1 warp_mma1; -+ -+ warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]); -+ ++warp_tile_iterator_A1_; -+ -+ 
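      // Descriptive note (editorial, not from the upstream CUTLASS sources): by
      // now the prologue has issued cp.async copies for the first
      // kNumStagesConcurrentLoad stages of operand B, and the cp_async_wait /
      // __syncthreads() above guarantee enough of them have landed in shared
      // memory. The code below pre-loads warp fragment 0 of B and then runs a
      // software-pipelined mainloop: while warp-level MMAs consume one stage
      // from shared memory, cp.async copies for a later stage are issued, with
      // smem_write_stage_idx / smem_read_stage_idx walking the circular buffer
      // of Base::kStages stages (the extra copies and offset rewinds are
      // skipped when kSmemContainsEntireB, i.e. all of B already fits in
      // shared memory).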
this->warp_tile_iterator_B_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B1[0]); -+ ++this->warp_tile_iterator_B_; -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma1.transform( -+ warp_transformed_frag_A1[0], -+ warp_transformed_frag_B1[0], -+ warp_loaded_frag_A1[0], -+ warp_loaded_frag_B1[0]); -+ -+ // tf32x3 kernels use staging accumulation. warp_mma uses a temporary -+ // accumulator and this temporary accumulator is added to the final -+ // accumulator once in every mainloop iteration. -+ plus plus_accum; -+ -+ FragmentC1 tmp_accum; -+ -+ if (platform::is_same< -+ typename Operator1::MathOperator, -+ arch::OpMultiplyAddFastF32>::value || -+ platform::is_same< -+ typename Operator1::MathOperator, -+ arch::OpMultiplyAddComplexFastF32>::value) { -+ tmp_accum.clear(); -+ } -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int gemm_k_iterations_1 = gemm_k_iterations_1_ - (Base::kStages - 1); -+ gemm_k_iterations_1 > (-Base::kStages + 1); -+ gemm_k_iterations_1--) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; -+ ++warp_mma_k) { -+ // Load warp-level tile from accumulator fragment (A) -+ // or shared memory (operand B) -+ this->warp_tile_iterator_B_.set_kgroup_index( -+ (warp_mma_k + 1) % Base::kWarpGemmIterations1); -+ // skip warp tile loading for the last kgroup (we are out of the buf) -+ if (gemm_k_iterations_1 > (-Base::kStages + 2) || -+ warp_mma_k < Base::kWarpGemmIterations1 - 1) { -+ warp_tile_iterator_A1_.load( -+ warp_loaded_frag_A1[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load( -+ warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ } -+ ++warp_tile_iterator_A1_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k > 0) -+ warp_mma1.transform( -+ warp_transformed_frag_A1[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ warp_loaded_frag_A1[warp_mma_k % 2], -+ warp_loaded_frag_B1[warp_mma_k % 2]); -+ -+ if (platform::is_same< -+ typename Operator1::MathOperator, -+ arch::OpMultiplyAddFastF32>::value || -+ platform::is_same< -+ typename Operator1::MathOperator, -+ arch::OpMultiplyAddComplexFastF32>::value) { -+ warp_mma1( -+ tmp_accum, -+ warp_transformed_frag_A1[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ tmp_accum); -+ -+ if (warp_mma_k == 0) { -+ accum = plus_accum(accum, tmp_accum); -+ tmp_accum.clear(); -+ } -+ } else { -+ warp_mma1( -+ accum, -+ warp_transformed_frag_A1[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ accum); -+ } -+ -+ // Issue global->shared copies for the this stage -+ if (warp_mma_k < Base::kWarpGemmIterations1 - 1) { -+ int group_start_iteration_B1; -+ -+ group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1; -+ -+ if (!kSmemContainsEntireB) { -+ copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); -+ } -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { -+ int group_start_iteration_B1; -+ group_start_iteration_B1 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; -+ -+ if (!kSmemContainsEntireB) { -+ copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); -+ } -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. 
-+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_B1.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (!kSmemContainsEntireB) { -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy1::kPartitionsK * -+ Base::kWarpGemmIterations1, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ } -+ -+ iterator_B1.set_residual_tile(gemm_k_iterations_1 == 2); -+ iterator_B1.clear_mask(gemm_k_iterations_1 == 1); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations1) -+ warp_mma1.transform( -+ warp_transformed_frag_A1[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ } -+ } -+ -+ if (platform::is_same< -+ typename Operator1::MathOperator, -+ arch::OpMultiplyAddFastF32>::value || -+ platform::is_same< -+ typename Operator1::MathOperator, -+ arch::OpMultiplyAddComplexFastF32>::value) { -+ accum = plus_accum(accum, tmp_accum); -+ } -+ } -+}; -+ -+template < -+ typename WarpShape, -+ typename InstructionShape, -+ typename RegularWarpIterator, -+ typename Policy> -+struct DefaultWarpIteratorAFromSharedMemory {}; -+ -+// TensorOp - Ampere -+template -+struct DefaultWarpIteratorAFromSharedMemory< -+ WarpShape, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ RegularWarpIterator, -+ Policy> { -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ static constexpr auto kWarpSize = 32; -+ using OpDelta = typename Policy::Operator::Policy::OpDelta; -+ -+ using WarpIterator = -+ cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::gemm::Operand::kA, -+ typename RegularWarpIterator::Element, -+ cutlass::layout::RowMajor, -+ cutlass::MatrixShape, -+ OpDelta::kRow, -+ kWarpSize>; -+}; -+ -+// TensorOp - Volta -+template -+struct DefaultWarpIteratorAFromSharedMemory< -+ WarpShape, -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ RegularWarpIterator, -+ Policy> { -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>; -+ static constexpr auto kWarpSize = 32; -+ using OpDelta = typename Policy::Operator::Policy::OpDelta; -+ -+ using WarpIterator = -+ cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator< -+ cutlass::MatrixShape<32, 32>, // MatrixShape, -+ cutlass::gemm::Operand::kA, -+ typename RegularWarpIterator::Element, -+ cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, -+ cutlass::MatrixShape<16, 4>, -+ OpDelta::kRow, -+ kWarpSize>; -+}; -+ -+// Simt -+template -+struct DefaultWarpIteratorAFromSharedMemory< -+ WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ RegularWarpIterator, -+ Policy> { -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr auto kWarpSize = 32; -+ -+ // We just use the same iterator, as we reproduced the same shared-memory -+ // schema. Just modify it to handle non-complete tiles. 
-+ using WarpIterator = RegularWarpIterator; -+}; -+ -+// Converts a "regular" Mma into their counterpart from shared memory -+template -+struct DefaultMmaFromSharedMemory; -+ -+// Mma pipelined -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Transformation applied to A operand -+ typename TransformA_, -+ /// Transformation applied to B operand -+ typename TransformB_, -+ typename AccumulatorSharedStorage_> -+struct DefaultMmaFromSharedMemory< -+ MmaPipelined< -+ Shape_, -+ IteratorA_, -+ SmemIteratorA_, -+ IteratorB_, -+ SmemIteratorB_, -+ ElementC_, -+ LayoutC_, -+ Policy_, -+ TransformA_, -+ TransformB_>, -+ AccumulatorSharedStorage_> { -+ static constexpr int kWarpSize = 32; -+ using SmemAccumulatorLayout = cutlass::layout::RowMajor; -+ -+ using RegularMma = MmaPipelined< -+ Shape_, -+ IteratorA_, -+ SmemIteratorA_, -+ IteratorB_, -+ SmemIteratorB_, -+ ElementC_, -+ LayoutC_, -+ Policy_, -+ TransformA_, -+ TransformB_>; -+ -+ using WarpShape = typename Policy_::Operator::Shape; -+ using InstructionShape = typename Policy_::Operator::InstructionShape; -+ using ArchMmaOperator = typename Policy_::Operator; -+ -+ using WarpIteratorA = typename DefaultWarpIteratorAFromSharedMemory< -+ WarpShape, -+ InstructionShape, -+ typename RegularMma::Operator::IteratorA, -+ Policy_>::WarpIterator; -+ using IteratorB = -+ typename cutlass::transform::threadblock::MakeIteratorResidualLast< -+ IteratorB_>::Iterator; -+ -+ using Mma = typename cutlass::gemm::threadblock::MmaPipelinedFromSharedMemory< -+ Shape_, -+ WarpIteratorA, -+ AccumulatorSharedStorage_, -+ IteratorB, -+ SmemIteratorB_, -+ ElementC_, -+ LayoutC_, -+ Policy_>; -+}; -+ -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator 
matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear, -+ typename AccumulatorSharedStorage_> -+struct DefaultMmaFromSharedMemory< -+ MmaMultistage< -+ Shape_, -+ IteratorA_, -+ SmemIteratorA_, -+ CacheOpA, -+ IteratorB_, -+ SmemIteratorB_, -+ CacheOpB, -+ ElementC_, -+ LayoutC_, -+ Policy_, -+ Stages, -+ SharedMemoryClear>, -+ AccumulatorSharedStorage_> { -+ static constexpr int kWarpSize = 32; -+ -+ using RegularMma = MmaMultistage< -+ Shape_, -+ IteratorA_, -+ SmemIteratorA_, -+ CacheOpA, -+ IteratorB_, -+ SmemIteratorB_, -+ CacheOpB, -+ ElementC_, -+ LayoutC_, -+ Policy_, -+ Stages, -+ SharedMemoryClear>; -+ -+ using WarpShape = typename Policy_::Operator::Shape; -+ using InstructionShape = typename Policy_::Operator::InstructionShape; -+ using WarpIteratorA = typename DefaultWarpIteratorAFromSharedMemory< -+ WarpShape, -+ InstructionShape, -+ typename RegularMma::Operator::IteratorA, -+ Policy_>::WarpIterator; -+ -+ static int constexpr kMaxK = AccumulatorSharedStorage_::Shape::kN; -+ // Reduce the number of stages if we don't need that many -+ static int constexpr kStagesMax = -+ (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK); -+ static int constexpr kStages = cutlass::const_min(Stages, kStagesMax); -+ -+ using IteratorB = -+ typename cutlass::transform::threadblock::MakeIteratorResidualLast< -+ IteratorB_>::Iterator; -+ using Mma = -+ typename cutlass::gemm::threadblock::MmaMultistageFromSharedMemory< -+ Shape_, -+ WarpIteratorA, -+ AccumulatorSharedStorage_, -+ IteratorB, -+ SmemIteratorB_, -+ RegularMma::kCacheOpB, -+ ElementC_, -+ LayoutC_, -+ Policy_, -+ kStages>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename IteratorC, -+ typename Operator, -+ typename scalar_t, -+ typename WarpShape_, -+ typename ThreadblockShape_> -+struct B2bGemm; -+ -+// Tensor Cores >= Sm75 specialization (Ampere ...) 
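// Editorial note (not from the upstream CUTLASS sources): in the
// DefaultMmaFromSharedMemory specialization above, the K extent of the second
// GEMM is bounded by the accumulator tile held in shared memory
// (kMaxK = AccumulatorSharedStorage_::Shape::kN), so at most
// ceil(kMaxK / Shape_::kK) K-tiles of B are ever worth pipelining. With
// hypothetical sizes kMaxK = 128 and Shape_::kK = 32, for instance,
// kStagesMax = (128 + 32 - 1) / 32 = 4, and a requested Stages of 5 or more is
// clamped down to kStages = 4.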
-+template < /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Element type -+ typename Element_, -+ /// Layout of operand in memory -+ typename Layout_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions, concept: MatrixShape) -+ typename OpDelta_, -+ typename Operator, -+ typename scalar_t, -+ typename WarpShape_, -+ typename ThreadblockShape_> -+struct B2bGemm< -+ cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< -+ Shape_, -+ Element_, -+ Layout_, -+ InstructionShape_, -+ OpDelta_>, -+ Operator, -+ scalar_t, -+ WarpShape_, -+ ThreadblockShape_> { -+ using IteratorC = -+ typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< -+ Shape_, -+ Element_, -+ Layout_, -+ InstructionShape_, -+ OpDelta_>; -+ using FragmentC = typename IteratorC::Fragment; -+ using InstructionShape = InstructionShape_; -+ using WarpShape = WarpShape_; -+ using ThreadblockShape = ThreadblockShape_; -+ using accum_t = Element_; -+ using lse_scalar_t = float; -+ -+ using SmemAccumulatorLayout = cutlass::layout::RowMajor; -+ -+ // Iterator to load accumulators (results of matmul in registers) -+ using FragmentIteratorAccumulator = -+ cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape, -+ InstructionShape, -+ accum_t, -+ typename Operator::Policy::Operator::FragmentC, -+ cutlass::layout::RowMajor>; -+ -+ // Iterator to store to shared-memory -+ using SmemIteratorD0 = typename cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape, -+ InstructionShape, -+ scalar_t, // accum_t, -+ SmemAccumulatorLayout>; -+ using AccumulatorSharedStorage = -+ cutlass::gemm::threadblock::AccumulatorSharedStorage< -+ ThreadblockShape, -+ typename SmemIteratorD0::Element, -+ typename SmemIteratorD0::TensorLayout, -+ typename SmemIteratorD0::Padding>; -+ // We need to provide an operation for the epilogue. Let's create an -+ // operation that does nothing (ScaleType::Nothing), just converts -+ // from accum_t (float) -> scalar_t (can be half) -+ using OutputOpNoOp = cutlass::epilogue::thread::LinearCombination< -+ typename SmemIteratorD0::Element, // ElementOutput -+ FragmentIteratorAccumulator::Fragment::kElements, -+ accum_t, // ElementAccumulator -+ typename SmemIteratorD0::Element, // ElementCompute -+ cutlass::epilogue::thread::ScaleType::Nothing>; -+ using Epilogue = cutlass::epilogue::threadblock::EpilogueSmemAccumulator< -+ SmemIteratorD0, -+ FragmentIteratorAccumulator, -+ SmemIteratorD0, // ScaleBiasIterator - not used -+ OutputOpNoOp>; -+ -+ // Epilogue 2: with LSE (for backwards pass) -+ static int const kElementsPerAccess = 2; // TODO: Why 2? 
-+ using IteratorAccumulatorLSE = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ // Shape -+ cutlass::MatrixShape, -+ // WarpShape -+ cutlass::MatrixShape, -+ lse_scalar_t, -+ cutlass::layout::RowMajor, -+ kElementsPerAccess>>; -+ using EpilogueOpApplyLSE = cutlass::epilogue::thread::ApplyLogSumExp< -+ scalar_t, // ElementOutput_ -+ lse_scalar_t, // ElementLSE_ -+ accum_t, // ElementAccumulator_ -+ accum_t, // ElementCompute_ -+ 128 / cutlass::sizeof_bits::value -+ // FragmentIteratorAccumulator::Fragment::kElements -+ // InstructionShape::kM * InstructionShape::kN / 32 -+ >; -+ using EpilogueWithLSE = -+ cutlass::epilogue::threadblock::EpilogueSmemAccumulator< -+ SmemIteratorD0, -+ FragmentIteratorAccumulator, -+ IteratorAccumulatorLSE, -+ EpilogueOpApplyLSE>; -+ -+ static void CUTLASS_DEVICE accumToSmem( -+ AccumulatorSharedStorage& shared_storage, -+ FragmentC const& accum, -+ int lane_id, -+ cutlass::MatrixCoord const& tile_coords) { -+ SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); -+ smem_iterator_attn.add_tile_offset( -+ tile_coords * -+ cutlass::MatrixCoord{ -+ SmemIteratorD0::TileIterations::kRow, -+ SmemIteratorD0::TileIterations::kColumn}); -+ Epilogue epilogue; -+ epilogue(OutputOpNoOp({}), smem_iterator_attn, accum); -+ } -+ -+ static void CUTLASS_DEVICE accumApplyLSEToSmem( -+ AccumulatorSharedStorage& shared_storage, -+ FragmentC& accum, -+ lse_scalar_t const* lse, -+ int32_t lse_extents, -+ int thread_id, -+ int warp_id, -+ int lane_id, -+ cutlass::MatrixCoord const& tile_coords) { -+ constexpr int32_t kAlignLSE = 32; -+ IteratorAccumulatorLSE iterator_lse( -+ lse, -+ {(int32_t)0, (int32_t)ceil_div(lse_extents, kAlignLSE) * kAlignLSE}, -+ thread_id, -+ warp_id, -+ cutlass::MatrixCoord{0, 0} // offset -+ ); -+ -+ SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); -+ smem_iterator_attn.add_tile_offset( -+ tile_coords * -+ cutlass::MatrixCoord{ -+ SmemIteratorD0::TileIterations::kRow, -+ SmemIteratorD0::TileIterations::kColumn}); -+ EpilogueWithLSE epilogue; -+ EpilogueOpApplyLSE minus_lse_exp({}); -+ epilogue( -+ minus_lse_exp, -+ smem_iterator_attn, -+ accum, -+ // scale - unused -+ iterator_lse, -+ // bias -+ iterator_lse); -+ } -+}; -+ -+// Volta Specialization -+// only supported for f16 -+template -+struct B2bGemm< -+ cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< -+ cutlass::MatrixShape<32, 32>, -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ cutlass::MatrixShape<1, 1>>, -+ Operator, -+ cutlass::half_t, -+ WarpShape_, -+ ThreadblockShape_> { -+ using IteratorC = -+ cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< -+ cutlass::MatrixShape<32, 32>, -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ cutlass::MatrixShape<1, 1>>; -+ using scalar_t = cutlass::half_t; -+ using accum_t = IteratorC::Element; -+ using WarpShape = WarpShape_; -+ using ThreadblockShape = ThreadblockShape_; -+ using FragmentC = IteratorC::Fragment; -+ using lse_scalar_t = float; -+ -+ using SmemAccumulatorLayout = cutlass::layout::RowMajor; -+ using SmemIteratorD0 = cutlass::epilogue::warp::TileIteratorVoltaTensorOp< -+ WarpShape, -+ cutlass::gemm::GemmShape<32, 32, 4>, -+ scalar_t, -+ SmemAccumulatorLayout>; -+ -+ // // Storage in shared-memory for Q.Kt -+ using AccumulatorSharedStorage = -+ cutlass::gemm::threadblock::AccumulatorSharedStorage< -+ ThreadblockShape, -+ scalar_t, -+ 
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise< -+ 16, -+ 32>, // typename SmemIteratorD0::TensorLayout, -+ cutlass::MatrixShape<0, 0> // Padding -+ >; -+ -+ using OutputLayout = -+ cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>; -+ using TensorRef = cutlass::TensorRef; -+ using Policy = typename IteratorC::Policy; -+ using Element = accum_t; -+ // Those are MmaVoltaTensorOpAccumulatorTileIterator private fields -+ // Let's copy their values -+ static int const kElementsPerPartial = 4; -+ using EleShapePerPatial = typename cutlass::platform::conditional< -+ cutlass::platform::is_same::value, -+ cutlass::MatrixShape<2, 2>, -+ cutlass::MatrixShape<1, 4>>::type; -+ static int const kElementsPerMma = 8; -+ static int const kAccumulatorPatials = 2; -+ using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; -+ -+ static void CUTLASS_DEVICE accumToSmem( -+ AccumulatorSharedStorage& shared_storage, -+ FragmentC const& accum, -+ int lane_id, -+ cutlass::MatrixCoord const& tile_coords) { -+ // ctor - from MmaVoltaTensorOpAccumulatorTileIterator -+ TensorRef ref_(shared_storage.accum_ref()); -+ int quad = (lane_id >> 2); -+ int lane_in_quad = (lane_id & 3); -+ int accum_m, accum_n; -+ -+ if (cutlass::platform::is_same::value) { -+ // (quad[2],quad[0])+lane_in_quad[0] -+ accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); -+ // (quad[1])+lane_in_quad[1] -+ accum_n = -+ ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + -+ (lane_in_quad & 2); -+ } else { -+ accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + -+ lane_in_quad; // (quad[2],quad[0]) -+ accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; -+ } -+ cutlass::MatrixCoord lane_offset(accum_m, accum_n); -+ -+ // Tile offset -+ ref_.add_coord_offset( -+ tile_coords * -+ cutlass::MatrixCoord( -+ {IteratorC::Shape::kRow, IteratorC::Shape::kColumn})); -+ -+ using AccessType = cutlass::Array; -+ -+ // store - from MmaVoltaTensorOpAccumulatorTileIterator -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ int mma_accum_start = -+ (((tile_n * Policy::TileIterations::kRow + tile_m) * -+ Policy::MmaIterations::kColumn + -+ mma_n) * -+ Policy::MmaIterations::kRow + -+ mma_m) * -+ kElementsPerMma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < kAccumulatorPatials; ++p) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < EleShapePerPatial::kRow; ++m) { -+ int accum_m = tile_m * Policy::InterleavedTile::kRow + -+ mma_m * QuadShapePerPatialMma::kRow + m * 2; -+ int accum_n = tile_n * Policy::InterleavedTile::kColumn + -+ mma_n * QuadShapePerPatialMma::kColumn + -+ p * Policy::InterleavedTile::kColumn / 2; -+ int r = (accum_m + lane_offset.row()); -+ AccessType to_store; -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { -+ int idx = mma_accum_start + p * kElementsPerPartial + -+ m * EleShapePerPatial::kColumn + n; -+ int c = (accum_n + n + lane_offset.column()); -+ to_store[n] = scalar_t(accum[idx]); -+ } -+ int c = (accum_n + lane_offset.column()); -+ assert(r < 32); -+ assert(c < 32); -+ *reinterpret_cast( -+ ref_.data() + ref_.offset({r, c})) = to_store; -+ } -+ } -+ } -+ } -+ } -+ } -+ } 
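  // Editorial sketch (not from the upstream CUTLASS sources): the
  // accumApplyLSEToSmem() methods below and in the later specializations apply
  // the softmax normalization in log-space before the accumulator is written
  // to shared memory: every element becomes exp(value - lse[n]), where n is
  // the accumulator column (accum holds the attention matrix transposed, so n
  // indexes attention rows / queries), and out-of-range columns read +infinity
  // so that exp(-inf) contributes 0. A scalar rendering of that update, with
  // hypothetical names:
  //
  //   // accum: fragment values, col_of[i]: accumulator column of accum[i],
  //   // lse: per-row log-sum-exp, lse_extent: number of valid rows.
  //   void apply_lse_scalar(float* accum, const int* col_of, int n_elems,
  //                         const float* lse, int lse_extent) {
  //     for (int i = 0; i < n_elems; ++i) {
  //       float l = (col_of[i] < lse_extent)
  //           ? lse[col_of[i]]
  //           : std::numeric_limits<float>::infinity();  // needs <limits>
  //       accum[i] = expf(accum[i] - l);                  // needs <cmath>
  //     }
  //   }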
-+ -+ static void CUTLASS_DEVICE accumApplyLSEToSmem( -+ AccumulatorSharedStorage& shared_storage, -+ typename IteratorC::Fragment& accum, -+ lse_scalar_t const* lse, -+ int lse_extent, -+ int thread_id, -+ int warp_id, -+ int lane_id, -+ cutlass::MatrixCoord const& tile_coords) { -+ // Non-optimized way to apply LSE to registers -+ // NOTE: accum is attn.T -+ // TODO: Optimize for each architecture -+ static constexpr int WarpSize = 32; -+ using RegistersIter = typename DefaultAttentionScalingCoefsUpdater< -+ IteratorC, -+ accum_t, -+ WarpSize>::Updater; -+ auto lane_offset = -+ RegistersIter::get_lane_offset(lane_id, warp_id, tile_coords); -+ -+ cutlass::Array lse_prefetched; -+ lse_prefetched.clear(); -+ int rowIdx = 0; -+ int colIdx = 0; -+ RegistersIter::iterateRows( -+ lane_offset, -+ [&](int accum_m) { -+ ++rowIdx; -+ colIdx = 0; -+ }, -+ [&](int accum_m, int accum_n, int idx) { -+ if (rowIdx == 1) { -+ lse_prefetched[colIdx] = accum_n < lse_extent -+ ? lse[accum_n] -+ : platform::numeric_limits::infinity(); -+ } -+ accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); -+ ++colIdx; -+ }, -+ [&](int accum_m) {}); -+ accumToSmem(shared_storage, accum, lane_id, tile_coords); -+ } -+}; -+ -+// Simt Specialization -+// for f32 on Sm70-Sm75 and f16/f32 below -+ -+template < -+ typename Operator, -+ typename OperatorPolicy, -+ typename scalar_t, -+ typename WarpShape_, -+ typename ThreadblockShape_> -+struct B2bGemm< -+ cutlass::gemm::warp::MmaSimtTileIterator< -+ cutlass::MatrixShape<32, 32>, -+ cutlass::gemm::Operand::kC, -+ float, -+ cutlass::layout::RowMajor, -+ OperatorPolicy, -+ 1, -+ 1>, -+ Operator, -+ scalar_t, -+ WarpShape_, -+ ThreadblockShape_> { -+ using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator< -+ cutlass::MatrixShape<32, 32>, -+ cutlass::gemm::Operand::kC, -+ float, -+ cutlass::layout::RowMajor, -+ OperatorPolicy, -+ 1, -+ 1>; -+ using accum_t = typename IteratorC::Element; -+ using WarpShape = WarpShape_; -+ using ThreadblockShape = ThreadblockShape_; -+ using FragmentC = typename IteratorC::Fragment; -+ using lse_scalar_t = float; -+ -+ // Storage in shared-memory for Q.Kt -+ using AccumulatorSharedStorage = -+ cutlass::gemm::threadblock::AccumulatorSharedStorage< -+ ThreadblockShape, -+ scalar_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::MatrixShape<0, 0> // Padding -+ >; -+ -+ static void CUTLASS_DEVICE accumToSmem( -+ AccumulatorSharedStorage& shared_storage, -+ FragmentC const& accum, -+ int lane_id, -+ cutlass::MatrixCoord const& tile_coords) { -+ using Policy = typename IteratorC::Policy; -+ using Element = typename IteratorC::Element; -+ using Iterations = typename IteratorC::Iterations; -+ using Delta = typename IteratorC::Delta; -+ -+ auto ref_ = shared_storage.accum_ref(); -+ // ctor - MmaSimtTileIterator -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN); -+ -+ ref_.add_coord_offset(lane_offset); -+ -+ // Tile offset -+ ref_.add_coord_offset( -+ tile_coords * -+ cutlass::MatrixCoord( -+ {IteratorC::Shape::kRow, IteratorC::Shape::kColumn})); -+ -+ // store - MmaSimtTileIterator -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { -+ 
CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { -+ int r = -+ Policy::LaneMmaShape::kM * (mma_m * Policy::WarpShape::kRow) + -+ m; -+ int c = mma_n * Delta::kColumn + n; -+ int idx = n + -+ Policy::LaneMmaShape::kN * -+ (mma_n + -+ Iterations::kColumn * -+ (m + mma_m * Policy::LaneMmaShape::kM)); -+ ref_.at({r, c}) = scalar_t(accum[idx]); -+ } -+ } -+ } -+ } -+ } -+ -+ static void CUTLASS_DEVICE accumApplyLSEToSmem( -+ AccumulatorSharedStorage& shared_storage, -+ typename IteratorC::Fragment& accum, -+ lse_scalar_t const* lse, -+ int lse_extent, -+ int thread_id, -+ int warp_id, -+ int lane_id, -+ cutlass::MatrixCoord const& tile_coords) { -+ // Non-optimized way to apply LSE to registers -+ // NOTE: accum is attn.T -+ // TODO: Optimize for each architecture -+ static constexpr int WarpSize = 32; -+ using RegistersIter = typename DefaultAttentionScalingCoefsUpdater< -+ IteratorC, -+ accum_t, -+ WarpSize>::Updater; -+ auto lane_offset = -+ RegistersIter::get_lane_offset(lane_id, warp_id, tile_coords); -+ -+ cutlass::Array lse_prefetched; -+ lse_prefetched.clear(); -+ int rowIdx = 0; -+ int colIdx = 0; -+ RegistersIter::iterateRows( -+ lane_offset, -+ [&](int accum_m) { -+ ++rowIdx; -+ colIdx = 0; -+ }, -+ [&](int accum_m, int accum_n, int idx) { -+ if (rowIdx == 1) { -+ lse_prefetched[colIdx] = accum_n < lse_extent -+ ? lse[accum_n] -+ : platform::numeric_limits::infinity(); -+ } -+ accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); -+ ++colIdx; -+ }, -+ [&](int accum_m) {}); -+ accumToSmem(shared_storage, accum, lane_id, tile_coords); -+ } -+}; -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h -new file mode 100644 -index 0000000..c3a2d9b ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h -@@ -0,0 +1,57 @@ -+#include -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+template < -+ typename scalar_t, // scalar type -+ typename ThreadblockTileShape, // size of tile to load -+ int Threads, // number of participating threads -+ int ElementsPerAccess> // thread access width in elements -+class TileSmemLoader { -+ public: -+ using SmemTile = -+ cutlass::AlignedBuffer; -+ -+ using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< -+ cutlass::layout::PitchLinearShape< -+ ThreadblockTileShape::kColumn, // contiguous -+ ThreadblockTileShape::kRow>, // strided -+ Threads, // Threads -+ ElementsPerAccess>; // ElementsPerAccess -+ -+ using GmemTileIterator = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ ThreadblockTileShape, // Shape -+ scalar_t, // Element -+ cutlass::layout::RowMajor, // Layout -+ 0, // AdvanceRank -+ ThreadMap>; // ThreadMap -+ -+ using SmemTileIterator = cutlass::transform::threadblock::RegularTileIterator< -+ ThreadblockTileShape, // Shape -+ scalar_t, // Element -+ cutlass::layout::RowMajor, // Layout -+ 0, // AdvanceRank 
-+ ThreadMap>; // ThreadMap -+ -+ using Fragment = typename GmemTileIterator::Fragment; -+ -+ /// load a tile from global memory into shared memory -+ CUTLASS_DEVICE -+ static void load( -+ GmemTileIterator tile_load_iter, -+ SmemTileIterator tile_store_iter) { -+ Fragment tb_frag; -+ tb_frag.clear(); -+ tile_load_iter.load(tb_frag); -+ tile_store_iter.store(tb_frag); -+ -+ __syncthreads(); -+ } -+}; -\ No newline at end of file -diff --git a/3rdparty/cutlass/examples/42_ampere_tensorop_group_conv/ampere_tensorop_group_conv.cu b/3rdparty/cutlass/examples/42_ampere_tensorop_group_conv/ampere_tensorop_group_conv.cu -new file mode 100644 -index 0000000..5f16ff1 ---- /dev/null -+++ b/3rdparty/cutlass/examples/42_ampere_tensorop_group_conv/ampere_tensorop_group_conv.cu -@@ -0,0 +1,706 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+This example shows how to run group convolution kernels using functions and data structures -+provided by CUTLASS using tensor cores; which we run on a NVIDIA Ampere GPU. -+ -+There are 2 group conv mode: -+ 1. cutlass::conv::GroupMode::kSingleGroup -+ This mode is for large K problem size: k_per_group (K/groups) equals or larger than -+ threadblock_tile_N. One or multiple threadblocks calculate data of one group. -+ 2. cutlass::conv::GroupMode::kMultipleGroup -+ This mode is for small K problem size: k_per_group (K/groups) is smaller than threadblock_tile_N. -+ One threadblock will calculate data from more than one group. -+ -+Function profile_convolution_selecter() shows how to choose kernel with different group mode according -+to problem size and threadblock_tile size. 
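For example, with the ThreadblockShape of GemmShape<64, 64, 64> used below
(threadblock_tile_N = 64), the usage examples further down give K=128 with
groups=2, i.e. k_per_group = 64 >= 64, which runs in kSingleGroup mode, while
K=128 with groups=8 gives k_per_group = 16 < 64 and requires kMultipleGroup mode.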
-+*/ -+ -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_conv2d_group_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using ElementAccumulator = float; // Data type of accumulator -+using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) -+using ElementInputA = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputB = cutlass::half_t; // Data type of elements in input tensor -+using ElementOutput = float; // Data type of elements in output tensor -+ -+using LayoutInputA = cutlass::layout::TensorNHWC; -+using LayoutInputB = cutlass::layout::TensorNHWC; -+using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 3; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. 
-+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; // Data type for alpha/beta in linear combination -+ -+// Analytic kernel and operation for single group problem size -+using AnalyticSingleGroupKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::GroupMode::kSingleGroup, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+>::Kernel; -+using AnalyticSingleGroupOperation = cutlass::conv::device::ImplicitGemmConvolution; -+ -+// Analytic kernel and operation for multiple group problem size -+using AnalyticMultipleGroupKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::GroupMode::kMultipleGroup, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+>::Kernel; -+using AnalyticMultipleGroupOperation = cutlass::conv::device::ImplicitGemmConvolution; -+ -+// Optimized kernel and operation for single group problem size -+using OptimizedSingleGroupKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::GroupMode::kSingleGroup, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+>::Kernel; -+using OptimizedSingleGroupOperation = cutlass::conv::device::ImplicitGemmConvolution; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::Tensor4DCoord input_size; -+ cutlass::Tensor4DCoord filter_size; -+ cutlass::Tensor4DCoord padding; -+ cutlass::MatrixCoord conv_stride; -+ cutlass::MatrixCoord dilation; -+ int groups; -+ bool reference_check; -+ bool measure_performance; -+ int iterations; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ bool optimized; -+ std::string tag; -+ -+ Options(): -+ help(false), -+ input_size(1, 32, 32, 32), -+ filter_size(32, 3, 3, 32), -+ padding(1, 1, 1, 1), -+ conv_stride(1, 1), -+ dilation(1, 1), -+ groups(1), -+ reference_check(false), -+ measure_performance(false), -+ iterations(20), -+ alpha(1), -+ beta(0), -+ optimized(false) { } -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. 
-+ // -+ int const kAlignment = 8; -+ -+ if ((input_size.c() % kAlignment) || -+ (filter_size.n() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ // Invalid padding -+ if ((padding.h() != filter_size.h() / 2) || -+ (padding.w() != filter_size.w() / 2)) { -+ -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Updates input and filter sizes -+ void update( -+ cutlass::Tensor4DCoord input_size, -+ cutlass::Tensor4DCoord filter_size) { -+ -+ this->input_size = input_size; -+ this->filter_size = filter_size; -+ -+ padding.n() = filter_size.h() / 2; -+ padding.h() = filter_size.h() / 2; -+ padding.w() = filter_size.w() / 2; -+ padding.c() = filter_size.w() / 2; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("ref-check")) { -+ reference_check = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("perf-check")) { -+ measure_performance = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("optimized")) { -+ optimized = true; -+ } -+ -+ cmd.get_cmd_line_argument("n", input_size.n()); -+ cmd.get_cmd_line_argument("h", input_size.h()); -+ cmd.get_cmd_line_argument("w", input_size.w()); -+ cmd.get_cmd_line_argument("c", input_size.c()); -+ -+ cmd.get_cmd_line_argument("k", filter_size.n()); -+ cmd.get_cmd_line_argument("r", filter_size.h()); -+ cmd.get_cmd_line_argument("s", filter_size.w()); -+ -+ cmd.get_cmd_line_argument("g", groups); -+ filter_size.c() = input_size.c() / groups; -+ -+ cmd.get_cmd_line_argument("u", conv_stride.row()); -+ cmd.get_cmd_line_argument("v", conv_stride.column()); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ -+ if (filter_size.h() == 3 && filter_size.w() == 3) { -+ padding = {1, 1, 1, 1}; -+ } -+ else { -+ filter_size.h() = 1; -+ filter_size.w() = 1; -+ padding = {0, 0, 0, 0}; -+ } -+ } -+ -+ /// Prints the usage statement. 
-+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "42_ampere_tensorop_group_conv example\n\n" -+ << " This example uses Ampere's Tensor Core operators on F16 data types to compute\n" -+ << " forward grouped convolution on tensors of layout NHWC.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --n= Input tensor extent N\n" -+ << " --h= Input tensor extent H\n" -+ << " --w= Input tensor extent W\n" -+ << " --c= Input tensor extent C\n" -+ << " --k= Filter extent K\n" -+ << " --r= Filter extent R\n" -+ << " --s= Filter extent S\n\n" -+ << " --g= Conv groups G\n\n" -+ << " --u= Conv stride_h\n\n" -+ << " --v= Conv stride_w\n\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --ref-check If set (true), reference check is computed\n" -+ << " --perf-check If set (true), performance is measured.\n" -+ << " --optimized If set (true), use optimized kernel, otherwise use analytic kernel.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --tag= String to replicate across the first column in the results table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/42_ampere_tensorop_group_conv/42_ampere_tensorop_group_conv --n=4 --h=16 --w=16 --c=256 --k=128 --r=3 --s=3 --g=8 --ref-check\n\n" -+ << "$ ./examples/42_ampere_tensorop_group_conv/42_ampere_tensorop_group_conv --n=4 --h=16 --w=16 --c=256 --k=128 --r=3 --s=3 --g=2 --ref-check\n\n" -+ << "$ ./examples/42_ampere_tensorop_group_conv/42_ampere_tensorop_group_conv --n=4 --h=16 --w=16 --c=256 --k=128 --r=3 --s=3 --g=2 --ref-check --optimized\n\n"; -+ -+ return out; -+ } -+ -+ /// Computes the output tensor size (NPQK) -+ cutlass::Tensor4DCoord output_size() const { -+ return cutlass::Tensor4DCoord( -+ input_size.n(), -+ (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, -+ (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, -+ filter_size.n()); -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of multiply-adds = NPQK * CRS -+ int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cutlass::Status reference_check; -+ cudaError_t error; -+ -+ Result(): -+ runtime_ms(0), -+ gflops(0), -+ status(cutlass::Status::kSuccess), -+ reference_check(cutlass::Status::kInvalid), -+ error(cudaSuccess) { } -+ -+ static std::ostream & print_header(std::ostream &out, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "Layer,N,H,W,C,K,R,S,G,Runtime,GFLOPs"; -+ -+ return out; -+ } -+ -+ std::ostream & print(std::ostream &out, int idx, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ out -+ << "conv_" << idx << "," -+ << options.input_size.n() << "," -+ << options.input_size.h() << "," -+ << options.input_size.w() << "," -+ << options.input_size.c() << "," -+ << options.filter_size.n() << "," -+ << options.filter_size.h() << "," -+ << options.filter_size.w() << "," -+ << options.groups << "," -+ << runtime_ms << "," -+ << gflops; -+ -+ return out; -+ } -+}; -+ 
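// Editorial sketch (not part of the upstream example): Options::output_size()
// and Options::gflops() above implement the usual implicit-GEMM accounting for
// grouped convolution. The same arithmetic as a standalone helper, with
// hypothetical parameter names:
//
//   P = (H + pad_h_top + pad_h_bottom - R) / stride_h + 1   (Q analogously)
//   multiply-adds = N * P * Q * K * R * S * (C / groups)
//   GFLOP/s       = 2 * MACs / 1e9 / runtime_seconds
inline double grouped_conv2d_gflops(
    int N, int H, int W, int C, int K, int R, int S,
    int pad_h, int pad_w, int stride_h, int stride_w,
    int groups, double runtime_s) {
  int P = (H + 2 * pad_h - R) / stride_h + 1;   // output height
  int Q = (W + 2 * pad_w - S) / stride_w + 1;   // output width
  // Each output element accumulates over an R x S x (C/groups) window of its group.
  double macs = double(N) * P * Q * K * R * S * (double(C) / groups);
  return 2.0 * macs / 1.0e9 / runtime_s;
}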
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one benchmark -+template -+Result profile_convolution(Options const &options) { -+ -+ Result result; -+ -+ // -+ // Allocate host-device tensors using the CUTLASS Utilities. -+ // -+ -+ cutlass::HostTensor tensor_a(options.input_size); -+ cutlass::HostTensor tensor_b(options.filter_size); -+ cutlass::HostTensor tensor_c(options.output_size()); -+ cutlass::HostTensor tensor_d(options.output_size()); -+ cutlass::HostTensor tensor_ref_d(options.output_size()); -+ -+ // -+ // Initialize tensors -+ // -+ -+ // Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(7), -+ ElementInputA(-8), -+ 0); -+ -+ // Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(7), -+ ElementInputB(-8), -+ 0); -+ -+ // Fill tensor C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(7), -+ ElementOutput(-8), -+ 0); -+ -+ // Fill tensor D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); -+ -+ // Fill tensor D for reference on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // -+ // Define arguments for CUTLASS Convolution -+ // -+ -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Construct Conv2dProblemSize with user defined output size -+ cutlass::conv::Conv2dProblemSize problem_size( -+ options.input_size, -+ options.filter_size, -+ options.padding, -+ options.conv_stride, -+ options.dilation, -+ options.output_size(), -+ mode, -+ split_k_slices, -+ options.groups -+ ); -+ -+ // Construct Conv2dOperation::Argument structure with conv2d -+ // problem size, data pointers, and epilogue values -+ typename Conv2dOperation::Arguments arguments{ -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_d.device_ref(), -+ {options.alpha, options.beta}, -+ }; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ Conv2dOperation implicit_gemm_op; -+ -+ size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ result.status = implicit_gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = implicit_gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Launch initialized CUTLASS kernel -+ // -+ result.status = implicit_gemm_op(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Optional reference check -+ // -+ -+ if (options.reference_check) { -+ std::cout << "Verification on device...\n"; -+ -+ // Compute with reference implementation -+ cutlass::reference::device::Conv2dFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator, -+ cutlass::NumericConverter -+ >( -+ problem_size, -+ tensor_a.device_ref(), -+ 
tensor_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref(), -+ options.alpha, -+ options.beta -+ ); -+ -+ tensor_ref_d.sync_host(); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ tensor_d.sync_host(); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ if (!passed) { -+ result.reference_check = cutlass::Status::kErrorInternal; -+ std::cout << "ERROR - results miscompared.\n"; -+ } else { -+ result.reference_check = cutlass::Status::kSuccess; -+ std::cout << "Passed.\n"; -+ } -+ } else { -+ result.reference_check = cutlass::Status::kInvalid; -+ } -+ -+ // -+ // Performance measurement -+ // -+ -+ if (options.measure_performance) { -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = implicit_gemm_op(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. 
-+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+Result profile_convolution_selecter(Options const &options) { -+ int k_per_group = options.filter_size.n() / options.groups; -+ -+ // In group conv, if k_per_group < threadblock_N, one Threadblock will calculate multiple groups -+ if (k_per_group < ThreadblockShape::kN) { // MultipleGroup mode -+ if (options.optimized) { -+ std::cerr << "Invalid problem: optimized group conv kernel doesn't support MultipleGroup (one CTA calculate multiple groups) mode" << std::endl; -+ exit(-1); -+ } else { -+ std::cout << "Select AnalyticMultipleGroupOperation\n"; -+ return profile_convolution(options); -+ } -+ } else { // SingleGroup mode -+ if (options.optimized) { -+ std::cout << "Select OptimizedSingleGroupOperation\n"; -+ return profile_convolution(options); -+ } else { -+ std::cout << "Select AnalyticSingleGroupOperation\n"; -+ return profile_convolution(options); -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) { -+ std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ Result result = profile_convolution_selecter(options); -+ -+ Result::print_header(std::cout, options) << std::endl; -+ result.print(std::cout, 1, options) << std::endl; -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/43_ell_block_sparse_gemm/ell_block_sparse_gemm.cu b/3rdparty/cutlass/examples/43_ell_block_sparse_gemm/ell_block_sparse_gemm.cu -new file mode 100644 -index 0000000..90711df ---- /dev/null -+++ b/3rdparty/cutlass/examples/43_ell_block_sparse_gemm/ell_block_sparse_gemm.cu -@@ -0,0 +1,740 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Blocked-ELL sparse GEMM example. -+ -+ This example performs a Sparse-matrix dense-matrix multiplication (SpMM) operation. -+ Matrix A is stored in the Blocked-Ellpack (Blocked-ELL) storage format. -+ Details about the Blocked-Ellpack (Blocked-ELL) storage format can be found here: -+ https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-generic-spmat-create-blockedell -+ Matrix B, in contrast, is a dense matrix. -+ -+ The Blocked-Ellpack (Blocked-ELL) storage format comprises two matrices. -+ The first is a packed matrix (ellValue matrix) that stores non-zero values in consecutive blocks, -+ represented by tensor_a in this example. The second is a matrix of indices (ellColInd matrix), -+ represented by tensor_ell_idx in this example, that represents the column indices of the -+ corresponding non-zero blocks. All rows in the matrices must have the same number of blocks. -+ ellColInd can contain -1 values to indicate empty blocks. These matrices store elements in -+ row-major order. -+ -+ Description of parameters and tensors used to represent the Blocked-Ellpack (ELL) format -+ for this example: -+ a_rows - Rows in the sparse matrix. -+ a_cols - Columns in the sparse matrix. -+ a_ell_blocksize - Size of the ELL-Blocks.
-+ a_ell_num_columns - Number of columns in the Blocked-Ellpack format (ellValue columns) -+ tensor_a - ellValue matrix, whose size is (a_rows * a_ell_num_columns) -+ tensor_ell_idx - Blocked-ELL Column indices (ellColInd), whose size is -+ (a_rows / a_ell_blocksize) * (a_ell_num_columns / a_ell_blocksize) -+ tensor_b - Input dense matrix whose size is (a_cols * n) -+ tensor_c/tensor_d - Output dense matrix whose size is (a_rows * n) -+ {a_rows, n, a_cols} - Problem size -+ -+*/ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+#include -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_grouped.h" -+#include "cutlass/gemm/kernel/default_gemm_grouped.h" -+#include "cutlass/gemm/device/ell_gemm.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/device_memory.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/host_uncompress.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ // -+ // Methods -+ // -+ -+ Result( -+ double runtime_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess -+ ): -+ runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ bool reference_check; -+ int iterations; -+ int cuda_streams; -+ int a_rows, n, a_cols; -+ int a_ell_num_columns; -+ int a_ell_blocksize; -+ int a_base; -+ float alpha; -+ float beta; -+ -+ // -+ // Methods -+ // -+ -+ Options(): -+ help(false), -+ reference_check(true), -+ iterations(20), -+ cuda_streams(0), -+ a_rows(1024), -+ n(1024), -+ a_cols(1024), -+ a_ell_num_columns(512), -+ a_ell_blocksize(16), -+ a_base(0), -+ alpha(1), -+ beta() -+ { } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("alpha", alpha, 1.0f); -+ cmd.get_cmd_line_argument("beta", beta, 0.0f); -+ cmd.get_cmd_line_argument("iterations", iterations, 20); -+ cmd.get_cmd_line_argument("streams", cuda_streams, 0); -+ cmd.get_cmd_line_argument("reference-check", reference_check, true); -+ -+ cmd.get_cmd_line_argument("a_rows", a_rows, 1024); -+ cmd.get_cmd_line_argument("n", n, 1024); -+ cmd.get_cmd_line_argument("a_cols", a_cols, 1024); -+ -+ cmd.get_cmd_line_argument("a_ell_num_columns", a_ell_num_columns, 512); -+ cmd.get_cmd_line_argument("a_ell_blocksize", a_ell_blocksize, 16); -+ cmd.get_cmd_line_argument("a_base", a_base, 0); -+ } -+ -+ /// Prints the usage statement. 
-+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "43_ell_block_sparse_gemm\n\n" -+ << " This example profiles the performance of a ELL block sparse GEMM kernel.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --a_rows= Sets the number of the rows of the sparse matrix.\n" -+ << " --n= Sets the N dimension.\n" -+ << " --a_cols= Sets the number of columns of the sparse matrix.\n" -+ << " --a_ell_num_columns= Sets the actual number of columns of the Blocked-Ellpack format.\n" -+ << " --a_ell_blocksize= Sets the size of the ELL-Block.\n" -+ << " --a_base= Sets the base index.\n" -+ << " --alpha= Epilogue scalar alpha (real part)\n" -+ << " --beta= Epilogue scalar beta (real part)\n\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --reference-check= If true, performs reference check.\n"; -+ -+ out << "\n\nExamples:\n\n" -+ -+ << "# Runs a 1024x1024x1024 ELL block sparse GEMM with 16x16 block size and actual 512 non-zero columns in A operand\n" -+ << "$ ./examples/43_ell_block_sparse_gemm/43_ell_block_sparse_gemm --a_rows=1024 --n=1024 --a_cols=1024 --a_ell_num_columns=512 --a_ell_blocksize=16\n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fmas = (int64_t)a_rows * (int64_t)a_cols * (int64_t)n; -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class Testbed { -+public: -+ -+ // -+ // Type definitions -+ // -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementC = typename Gemm::ElementC; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ -+ using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; -+ using ElementCompute = typename EpilogueOutputOp::ElementCompute; -+ -+ using LayoutA = typename Gemm::LayoutA; -+ using LayoutB = typename Gemm::LayoutB; -+ using LayoutC = typename Gemm::LayoutC; -+ -+ using MatrixCoord = typename LayoutC::TensorCoord; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Options options; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_ELL; -+ uint32_t seed; -+ -+ cutlass::HostTensor tensor_a; -+ cutlass::HostTensor tensor_b; -+ cutlass::HostTensor tensor_c; -+ cutlass::HostTensor tensor_d; -+ -+ cutlass::HostTensor tensor_a_uncompressed; -+ cutlass::HostTensor reference_d; -+ -+ cutlass::HostTensor tensor_ell_idx; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ Testbed( -+ Options const &options_, -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_ELL_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3080 -+ ): -+ options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), init_ELL(init_ELL_), seed(seed_) { } -+ -+private: -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor_( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint32_t seed) { -+ -+ if (dist_kind == 
cutlass::Distribution::Uniform) { -+ -+ Element scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ if (cutlass::sizeof_bits::value <= 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } -+ else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian( -+ view, seed, Element(), Element(0.5f)); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ // Fill with increasing elements -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity(), Element(1), Element()); -+ } else { -+ -+ // Fill with all 1s -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity(), Element(), Element(1)); -+ } -+ } -+ -+ /// Initializes data structures -+ void initialize_() { -+ tensor_a.resize(cutlass::make_Coord(options.a_rows, options.a_ell_num_columns)); -+ tensor_b.resize(cutlass::make_Coord(options.a_cols, options.n)); -+ tensor_c.resize(cutlass::make_Coord(options.a_rows, options.n)); -+ tensor_d.resize(cutlass::make_Coord(options.a_rows, options.n)); -+ -+ tensor_a_uncompressed.resize(cutlass::make_Coord(options.a_rows, options.a_cols)); -+ reference_d.resize(cutlass::make_Coord(options.a_rows, options.n)); -+ -+ tensor_ell_idx.resize(cutlass::make_Coord(options.a_rows / options.a_ell_blocksize, -+ options.a_ell_num_columns / options.a_ell_blocksize)); -+ -+ // -+ // Initialize the problems of the workspace -+ // -+ -+ initialize_tensor_(tensor_a.host_view(), init_A, seed * 2021); -+ initialize_tensor_(tensor_b.host_view(), init_B, seed * 2022); -+ initialize_tensor_(tensor_c.host_view(), init_C, seed * 2023); -+ -+ if (init_ELL == cutlass::Distribution::Uniform) { -+ cutlass::reference::host::TensorFillRandomEllIdx( -+ tensor_ell_idx.host_view(), seed, -+ options.a_rows / options.a_ell_blocksize, -+ options.a_ell_num_columns / options.a_ell_blocksize, -+ options.a_cols / options.a_ell_blocksize); -+ -+ } else { -+ for(int i = 0; i < options.a_rows / options.a_ell_blocksize; ++i) { -+ for(int j = 0; j < options.a_ell_num_columns / options.a_ell_blocksize; ++j) { -+ tensor_ell_idx.at({i, j}) = j+3; -+ } -+ } -+ } -+ -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ell_idx.sync_device(); -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify_() { -+ -+ bool passed = true; -+ -+ tensor_d.sync_host(); -+ -+ cutlass::uncompress_ell_block_sparse( -+ tensor_a_uncompressed.host_ref(), -+ tensor_a.host_ref(), -+ tensor_ell_idx.host_ref(), -+ options.a_rows, -+ options.a_cols, -+ options.a_ell_num_columns, -+ options.a_ell_blocksize -+ ); -+ -+ cutlass::reference::host::Gemm< -+ typename Gemm::ElementA, typename Gemm::LayoutA, -+ typename Gemm::ElementB, typename Gemm::LayoutB, -+ typename Gemm::ElementC, typename Gemm::LayoutC, -+ ElementCompute, -+ ElementAccumulator, typename Gemm::Operator> -+ reference_gemm; -+ -+ reference_gemm( -+ {options.a_rows, options.n, options.a_cols}, -+ options.alpha, -+ tensor_a_uncompressed.host_ref(), -+ tensor_b.host_ref(), -+ options.beta, -+ 
reference_d.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ // Reference check -+ passed = cutlass::reference::host::TensorEquals(tensor_d.host_view(), reference_d.host_view()); -+ -+ if (!passed) { -+ std::cerr << "\n***\nError - problem failed the QA check\n***\n" << std::endl; -+ -+ std::stringstream fname; -+ -+ fname << "error_43_ell_block_sparse_gemm" -+ << "mnk_" -+ << options.a_rows << "x" -+ << options.n << "x" -+ << options.a_cols << "_" -+ << options.a_ell_num_columns << "_" -+ << options.a_ell_blocksize << ".txt"; -+ -+ std::cout << fname.str() << std::endl; -+ -+ std::ofstream results(fname.str()); -+ -+ results -+ << "alpha: " << ElementCompute(options.alpha) << "\n" -+ << "beta: " << ElementCompute(options.beta) << "\n" -+ << "block size: " << options.a_ell_blocksize << "\n" -+ << "\nA:\n" << tensor_a.host_view() << "\n" -+ << "\nA Ell Index:\n" << tensor_ell_idx.host_view() << "\n" -+ << "\nB:\n" << tensor_b.host_view() << "\n" -+ << "\nC:\n" << tensor_c.host_view() << "\n" -+ << "\nD reference:\n" << reference_d.host_view() << "\n" -+ << "\nD computed:\n" << tensor_d.host_view() << "\n"; -+ -+ -+ return passed; -+ } -+ -+ return passed; -+ } -+ -+public: -+ -+ /// Returns the number of threadblocks to launch if the kernel can run on the target -+ /// device. Otherwise, returns zero. -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes a BlockedEll SpMM kernel and measures runtime. -+ Result profile() { -+ -+ Result result; -+ -+ // Early exit -+ if (!sufficient()) { -+ std::cout << "Active CUDA device lacks hardware resources to run CUTLASS BlockedEll SpMM kernel." << std::endl; -+ return result; -+ } -+ -+ result.passed = false; -+ -+ // Initialize the problem -+ initialize_(); -+ -+ // Configure the GEMM arguments -+ typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta); -+ -+ // Configure GEMM arguments -+ typename Gemm::Arguments args( -+ {options.a_rows, options.n, options.a_cols}, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_d.device_ref(), -+ tensor_ell_idx.device_data(), -+ options.a_ell_num_columns, -+ options.a_ell_blocksize, -+ options.a_base, -+ epilogue_op -+ ); -+ -+ // Initialize the GEMM object -+ Gemm gemm; -+ -+ result.status = gemm.initialize(args); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to initialize CUTLASS BlockedEll SpMM kernel." << std::endl; -+ return result; -+ } -+ -+ // Run the BlockedEll SpMM object -+ result.status = gemm.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS BlockedEll SpMM kernel." 
<< std::endl; -+ return result; -+ } -+ -+ // Wait for completion -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // -+ // Verify correctness -+ // -+ result.passed = true; -+ -+ if (options.reference_check) { -+ result.passed = verify_(); -+ } -+ -+ // -+ // Warm-up run -+ // -+ result.status = gemm.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS BlockedEll SpMM kernel." << std::endl; -+ return result; -+ } -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMM operations -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ gemm(); -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMM operations have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // -+ // Cleanup -+ // -+ -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ std::cout << std::endl; -+ std::cout << "ELL Block Sparse GEMM (CUTLASS):\n" -+ << "====================================================" << std::endl; -+ -+ std::cout << std::endl; -+ std::cout << " " << "Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " " << " GFLOPs: " << result.gflops << std::endl; -+ -+ return result; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ // -+ // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. -+ // -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { -+ -+ // -+ // This example requires an NVIDIA Ampere-architecture GPU. 
-+ // -+ -+ std::cout -+ << "CUTLASS's BlockedEll SpMM example requires a GPU of NVIDIA's Ampere Architecture or " -+ << "later (compute capability 80 or greater).\n"; -+ -+ return 0; -+ } -+ -+ // -+ // Parse options -+ // -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ // -+ // Define the BlockedEll type -+ // -+ -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ constexpr int32_t kAlignmentA = 128 / cutlass::sizeof_bits::value; -+ constexpr int32_t kAlignmentB = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ constexpr int32_t kStages = 4; -+ using Gemm = typename cutlass::gemm::device::EllGemm< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementOutput, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ kStages, kAlignmentA, kAlignmentB>; -+ -+ // -+ // Profile it -+ // -+ -+ Testbed testbed(options); -+ -+ if (!testbed.sufficient()) { -+ std::cout << "The active CUDA device lacks sufficient hardware resources to execute this kernel.\n"; -+ return 0; -+ } -+ -+ Result result = testbed.profile(); -+ if (!result.passed) { -+ std::cout << "Profiling CUTLASS ELL block sparse GEMM has failed.\n"; -+ std::cout << "\nFailed\n"; -+ return -1; -+ } -+ -+ std::cout << "\nPassed\n"; -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h -new file mode 100644 -index 0000000..fead537 ---- /dev/null -+++ b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h -@@ -0,0 +1,154 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_clamp.h" -+#include "cutlass/epilogue/thread/conversion_op.h" -+#include "cutlass/epilogue/thread/reduction_op.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" -+ -+#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" -+#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" -+#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" -+ -+// #include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/threadblock/interleaved_epilogue.h" -+ -+#include "fused_bias_act_epilogue.h" -+#include "../warp/fused_bias_act_fragment_iterator_tensor_op.h" -+#include "output_tile_thread_map_for_fused_bias.h" -+#include "default_thread_map_tensor_op_for_fused_bias.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps. 
-+template < -+ typename Shape_, -+ typename WarpMmaTensorOp_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultFusedBiasActEpilogueTensorOp { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOpForFusedBias< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ OutputTileThreadMap, -+ ElementOutput -+ >; -+ -+ using AccumulatorFragmentIterator = typename std::conditional::value, -+ cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC>, -+ cutlass::epilogue::warp::FusedBiasActFragmentIteratorTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC> >::type; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::FusedBiasActEpilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ OutputOp -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h -new file mode 100644 -index 0000000..d9ce0f8 ---- /dev/null -+++ b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h -@@ -0,0 +1,113 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the optimal thread map for TensorOp accumulator layouts -+template < -+ typename ThreadblockShape_, -+ typename WarpShape_, -+ int PartitionsK, -+ typename Element_, -+ int ElementsPerAccess -+> -+struct DefaultThreadMapTensorOpForFusedBias { -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ static int const kPartitionsK = PartitionsK; -+ using Element = Element_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ // -+ // Definitions -+ // -+ -+ struct Detail { -+ -+ /// Tensor Operations fundamentally perform operations on 8 rows -+ static int const kTensorOpRows = 8; -+ static int const kWarpSize = 32; -+ -+ static_assert( -+ !(ThreadblockShape::kM % WarpShape::kM) && -+ !(ThreadblockShape::kM % WarpShape::kM), "Divisibility"); -+ -+ /// Number of warps -+ using WarpCount = gemm::GemmShape< -+ ThreadblockShape::kM / WarpShape::kM, -+ ThreadblockShape::kN / WarpShape::kN, -+ kPartitionsK -+ >; -+ -+ /// Number of participating threads -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ }; -+ -+ // -+ // ThreadMap -+ // -+ -+ /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap -+ using Type = OutputTileOptimalThreadMapBiasAct < -+ OutputTileShape, -+ OutputTileShape<1, WarpShape::kM / Detail::kTensorOpRows, 1, 1, WarpShape::kM / Detail::kTensorOpRows>, -+ Detail::kThreads, -+ kElementsPerAccess, -+ sizeof_bits::value -+ >; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h 
b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h -new file mode 100644 -index 0000000..8b9c24c ---- /dev/null -+++ b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h -@@ -0,0 +1,222 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. 
-+ -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_base.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator without splitk -+template < -+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) -+ typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ int PartitionsK, ///< Number of partitions of the K dimension -+ typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors -+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators -+ typename OutputOp_ ///< Output operator -+> -+class FusedBiasActEpilogue { -+ -+public: -+ -+ using Shape = Shape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputTileIterator = OutputTileIterator_; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using OutputOp = OutputOp_; -+ -+ /// Output layout is always row-major -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; -+ -+ /// Output element -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ -+public: -+ -+ -+ static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); -+ -+ static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), -+ "Divisibility"); -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ FusedBiasActEpilogue( -+ ){ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile -+ AccumulatorTile & fused_bias_act_accumlators, -+ OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ -+ bool need_bias = output_op.is_source_needed(); -+ -+ if (need_bias) -+ compute_source_needed_(output_op, accumulators, fused_bias_act_accumlators, source_iterator); -+ else -+ compute_source_no_needed_(output_op, accumulators, fused_bias_act_accumlators); -+ -+ -+ } -+ -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile -+ AccumulatorTile & fused_bias_act_accumlators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ -+ compute_source_no_needed_(output_op, accumulators, 
fused_bias_act_accumlators); -+ } -+ -+ CUTLASS_DEVICE -+ void compute_source_needed_( -+ OutputOp const &output_op, ///< Output operator -+ AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile -+ AccumulatorTile & fused_bias_act_accumlators, -+ OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ -+ typename OutputTileIterator::Fragment source_fragment; -+ -+ -+ source_fragment.clear(); -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ AccumulatorFragmentIterator fused_bias_act_fragment_iterator(fused_bias_act_accumlators); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { -+ -+ source_iterator.load(source_fragment); -+ ++source_iterator; -+ -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ -+ accum_fragment_iterator.load(accum_fragment); -+ ++accum_fragment_iterator; -+ -+ typename AccumulatorFragmentIterator::Fragment fused_bias_act_fragment; -+ fused_bias_act_fragment = output_op(accum_fragment, source_fragment); -+ -+ fused_bias_act_fragment_iterator.store(fused_bias_act_fragment); -+ ++fused_bias_act_fragment_iterator; -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void compute_source_no_needed_( -+ OutputOp const &output_op, ///< Output operator -+ AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile -+ AccumulatorTile & fused_bias_act_accumlators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ AccumulatorFragmentIterator fused_bias_act_fragment_iterator(fused_bias_act_accumlators); -+ -+ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int iter = 0; iter < AccumulatorFragmentIterator::kIterations; ++iter) { -+ -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ -+ accum_fragment_iterator.load(accum_fragment); -+ ++accum_fragment_iterator; -+ -+ typename AccumulatorFragmentIterator::Fragment fused_bias_act_fragment; -+ fused_bias_act_fragment = output_op(accum_fragment); -+ -+ fused_bias_act_fragment_iterator.store(fused_bias_act_fragment); -+ ++fused_bias_act_fragment_iterator; -+ } -+ } -+ -+}; -+ -+ -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h -new file mode 100644 -index 0000000..66a6a34 ---- /dev/null -+++ b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h -@@ -0,0 +1,311 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Metaprogram for determining the mapping of output elements to threads for epilogue tiles. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/fast_math.h" -+ -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// RowArrangement determines how one or more warps cover a region of consecutive rows. -+template < -+ typename Shape, -+ int WarpsRemaining, -+ int ElementsPerAccess, -+ int ElementSize, -+ bool Is2dTile -+> -+struct RowArrangementBiasAct; -+ -+/// RowArrangement in which each warp's access is a 1D tiled arrangement. -+template < -+ typename Shape, -+ int WarpsRemaining, -+ int ElementsPerAccess, -+ int ElementSize -+> -+struct RowArrangementBiasAct { -+ static int const kWarpSize = 32; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kElementSize = ElementSize; -+ -+ static int const kIterationsRow = 1; -+ static int const kDeltaRow = 1; -+ static int const kIterationsColumn = Shape::kColumn / kElementsPerAccess / kWarpSize; -+ static int const kDeltaColumn = kWarpSize * kElementsPerAccess; -+ -+ static int const kAccessWidth = kWarpSize; -+ static int const kAccessRows = 1; -+ static int const kWarpPartitionsRow = 1; -+ static int const kWarpPartitionsColumn = WarpsRemaining; -+}; -+ -+/// RowArrangement in which each warp's access is a 2D tiled arrangement. 
-+template < -+ typename Shape, -+ int WarpsRemaining, -+ int ElementsPerAccess, -+ int ElementSize -+> -+struct RowArrangementBiasAct { -+ -+ static int const kMemoryAccessSize = 4;//128; -+ static int const kWarpSize = 32; -+ -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kElementSize = ElementSize; -+ -+ struct Detail { -+ static int const kShapeRow = Shape::kRow / WarpsRemaining; -+ static int const kShapeWidth = Shape::kColumn / kElementsPerAccess; -+ -+ static int const kTargetMemoryAccessWidth = -+ kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8); -+ -+ static int const kTargetAccessRows = kWarpSize / kTargetMemoryAccessWidth; -+ }; -+ -+ static int const kAccessWidth = -+ (Detail::kTargetAccessRows > Detail::kShapeRow ? -+ kWarpSize / Detail::kShapeRow -+ : const_min( -+ Detail::kShapeWidth, -+ const_min(kWarpSize, kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8)) -+ )); -+ -+ static int const kAccessRows = -+ (Detail::kTargetAccessRows > Detail::kShapeRow ? -+ Detail::kShapeRow -+ : const_min(Shape::kRow, kWarpSize / kAccessWidth)); -+ -+ static int const kIterationsRow = Detail::kShapeRow / kAccessRows; -+ static int const kDeltaRow = kAccessRows; -+ -+ static int const kIterationsColumn = Detail::kShapeWidth / kAccessWidth; -+ static int const kDeltaColumn = kAccessWidth * kElementsPerAccess; -+ -+ static_assert( kAccessWidth * kElementsPerAccess <= Shape::kColumn, "Accessing too many elements per access"); -+ static_assert( kIterationsColumn > 0, "Iteration Count Column must be > 0" ); -+ static_assert( kIterationsRow > 0, "Iteration Count Row must be > 0" ); -+ -+ static int const kWarpPartitionsRow = 1; -+ static int const kWarpPartitionsColumn = 1; -+}; -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template metaprogram for partitioning a 4D space across warps to achieve several performance -+/// objectives: -+/// -+/// - coalesced memory accesses in units of 16 Byte lines -+/// - minimal address arithmetic -+/// - minimal predicate calculations -+/// -+template < -+ typename Shape_, -+ typename Count_, -+ int Threads, -+ int ElementsPerAccess, -+ int ElementSize -+> -+struct OutputTileOptimalThreadMapBiasAct { -+ -+ using Shape = Shape_; -+ using Count = Count_; -+ -+ static int const kWarpSize = 32; -+ static int const kThreads = Threads; -+ static int const kWarpCount = kThreads / kWarpSize; -+ -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kElementSize = ElementSize; -+ -+ // -+ // Metaprogram computation -+ // -+ -+ struct Detail { -+ -+ // Clusters -+ static int const kIterationsCluster = -+ ((Shape::kCluster > kWarpCount) ? -+ Shape::kCluster / kWarpCount -+ : 1); -+ -+ static int const kDeltaCluster = -+ ((Shape::kCluster > kWarpCount) ? -+ Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup * Shape::kCluster / kIterationsCluster -+ : 1); -+ -+ static int const kCompactedDeltaCluster = -+ ((Shape::kCluster > kWarpCount) ? -+ Shape::kRow * Shape::kGroup * Shape::kCluster / kIterationsCluster -+ : 1); -+ -+ static int const kWarpPartitionsCluster = -+ ((Shape::kCluster > kWarpCount) ? -+ kWarpCount -+ : kWarpCount / Shape::kCluster); -+ -+ static int const kWarpsRemainingForGroups = -+ ((Shape::kCluster > kWarpCount) ? 1 : kWarpCount / Shape::kCluster); -+ -+ // Groups -+ static int const kIterationsGroup = -+ ((Shape::kGroup > kWarpsRemainingForGroups) ? 
-+ Shape::kGroup / kWarpsRemainingForGroups -+ : 1); -+ -+ static int const kDeltaGroup = -+ ((Shape::kGroup > kWarpsRemainingForGroups) ? -+ Shape::kRow * Count::kRow * Shape::kGroup / kIterationsGroup -+ : 1); -+ -+ static int const kCompactedDeltaGroup = -+ ((Shape::kGroup > kWarpsRemainingForGroups) ? -+ Shape::kRow * Shape::kGroup / kIterationsGroup -+ : 1); -+ -+ static int const kWarpPartitionsGroup = -+ ((Shape::kGroup > kWarpsRemainingForGroups) ? -+ 1 -+ : kWarpsRemainingForGroups / Shape::kGroup); -+ -+ static int const kWarpsRemainingForRows = -+ ((Shape::kGroup > kWarpsRemainingForGroups) ? -+ 1 -+ : kWarpsRemainingForGroups / Shape::kGroup); -+ -+ // Rows -+ using RowArrangement = detail::RowArrangementBiasAct< -+ Shape, -+ kWarpsRemainingForRows, -+ kElementsPerAccess, -+ kElementSize, -+ (Shape::kRow > kWarpsRemainingForRows) -+ >; -+ -+ // Warp partitions -+ using WarpPartitions = OutputTileShape< -+ RowArrangement::kWarpPartitionsColumn, -+ RowArrangement::kWarpPartitionsRow, -+ kWarpPartitionsGroup, -+ kWarpPartitionsCluster, -+ 1>; -+ -+ static int const kAccessWidth = RowArrangement::kAccessWidth; -+ static int const kAccessRows = RowArrangement::kAccessRows; -+ }; -+ -+ // -+ // Output -+ // -+ -+ using Iterations = OutputTileShape< -+ Detail::RowArrangement::kIterationsColumn, -+ Detail::RowArrangement::kIterationsRow, -+ Detail::kIterationsGroup, -+ Detail::kIterationsCluster, -+ 1>; -+ -+ using Delta = OutputTileShape< -+ Detail::RowArrangement::kDeltaColumn, -+ Detail::RowArrangement::kDeltaRow, -+ Detail::kDeltaGroup, -+ Detail::kDeltaCluster, -+ 1>; -+ -+ /// Initial offset function -+ CUTLASS_HOST_DEVICE -+ static MatrixCoord initial_offset(int thread_idx) { -+ -+ int warp_idx = thread_idx / kWarpSize; -+ int lane_idx = thread_idx % kWarpSize; -+ -+ // Compute warp location -+ int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster; -+ int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster; -+ -+ int group_idx = residual_cluster / Detail::WarpPartitions::kGroup; -+ int residual_group = residual_cluster % Detail::WarpPartitions::kGroup; -+ -+ int row_idx = residual_group / Detail::WarpPartitions::kRow; -+ int col_idx = residual_group % Detail::WarpPartitions::kRow; -+ -+ // Compute per-lane offset -+ int lane_row_offset = lane_idx / Detail::kAccessWidth; -+ int lane_col_offset = lane_idx % Detail::kAccessWidth; -+ -+ // Compute coordinate in output space -+ int cluster_offset = cluster_idx * Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup; -+ int group_offset = group_idx * Shape::kRow * Count::kRow; -+ int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows; -+ int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess; -+ -+ return MatrixCoord( -+ cluster_offset + group_offset + row_offset + lane_row_offset, -+ (column_offset + lane_col_offset) * kElementsPerAccess -+ ); -+ } -+ -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h -new file mode 100644 -index 0000000..9d7a6c7 ---- /dev/null -+++ 
b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h -@@ -0,0 +1,189 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile -+ that participate in one warp-level store operation. -+ -+ Typically, the accumulator tile is the largest single block of register-backed storage -+ within the kernel. Storing it to memory is best accomplished by partitioning it into -+ smaller tiles and storing these sequentially. -+ -+ Round trips through shared memory during the Epilogue phase require partitioning, as -+ shared memory capacity is typically insufficient for a threadblock's total accumulator -+ size. 
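For a sense of the sizes involved (figures assumed for illustration, not taken from this header): a 128x128 fp32 threadblock accumulator occupies 128 * 128 * 4 B = 64 KiB of registers across the threadblock, whereas a 128x64 fp16 staging buffer in shared memory is only 16 KiB, so the accumulators are written out over several fragment-sized passes.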
-+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/epilogue/warp/tensor_op_policy.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorElementC, ///< matrix multiply operation data type (concept: data type) -+ typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: Array) -+ typename Layout ///< target shared memory layout -+> -+class FusedBiasActFragmentIteratorTensorOp; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for row-major shared memory -+template < -+ typename WarpShape_, ///< shape of the warp-level GEMM tile -+ typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorElementC_, ///< matrix multiply operation data type (concept: data type) -+ typename OperatorFragmentC_ ///< matrix multiply operation fragment (concept: Array) -+> -+class FusedBiasActFragmentIteratorTensorOp { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using OperatorElementC = OperatorElementC_; -+ using OperatorFragmentC = OperatorFragmentC_; -+ using Layout = layout::RowMajor; -+ -+ using Policy = TensorOpPolicy; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ OperatorElementC, -+ Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; -+ -+ /// This is the complete warp-level accumulator tile. 
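// Size illustration with assumed numbers (not taken from this header): with an
// OperatorCount of 4x4 MMA operations and a 4-element OperatorFragmentC, the
// complete tile below holds 4 * 4 * 4 = 64 accumulators per thread, visited in
// Policy::kIterations fragment-sized steps by load()/store().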
-+ using AccumulatorTile = Array< -+ OperatorElementC, -+ OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn>; -+ -+ using OutputAccumulatorTile = AccumulatorTile; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+private: -+ -+ /// Internal access type -+ using AccessType = Array; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+public: -+ -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ FusedBiasActFragmentIteratorTensorOp(AccumulatorTile &accum): -+ accumulators_(reinterpret_cast(&accum)), -+ index_(0) { -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ FusedBiasActFragmentIteratorTensorOp &operator++() { -+ ++index_; -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ FusedBiasActFragmentIteratorTensorOp &operator--() { -+ --index_; -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, int index_offset = 0) const { -+ -+ int index = index_ + index_offset; -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ -+ int accumulator_access_offset = -+ index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess; -+ -+ frag_ptr[n] = accumulators_[accumulator_access_offset]; -+ } -+ } -+ /// Stores a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void store(Fragment &frag, int index_offset = 0) const { -+ -+ int index = index_ + index_offset; -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ -+ int accumulator_access_offset = -+ index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess; -+ -+ accumulators_[accumulator_access_offset] = frag_ptr[n]; -+ } -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h -new file mode 100644 -index 0000000..05a4c90 ---- /dev/null -+++ b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h -@@ -0,0 +1,427 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/numeric_conversion.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Size of the accumulation tile shape (concept: MatrixShape) -+ typename AccumulatorShape_, -+ /// KBlocks columns to compute residual -+ int KBlocksColumn_, -+ /// Accumulator Element type -+ typename ElementAccumulator_, -+ /// Element type -+ typename Element_, -+ /// Layout of operand in memory -+ typename Layout_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Whether beta is zero -+ bool IsBetaZero_ > -+class MmaTensorOpPureFragmentIterator; -+ -+ -+// Partial specialization for col-major accumulator tile -+// And Element type is the same as Accumulator Element type -+ -+template < -+ /// Shape of warp tile to load (concept: MatrixShape) -+ typename Shape_, -+ /// Shape of the warp accumulation tile (concept: MatrixShape) -+ typename AccumulatorShape_, -+ /// KBlocks columns to compute residual -+ int KBlocksColumn_, -+ /// Element type -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_> -+class MmaTensorOpPureFragmentIterator { -+ public: -+ -+ /// Shape of warp tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Shape of the warp accumulation tile (concept: MatrixShape) -+ using AccumulatorShape = AccumulatorShape_; -+ -+ /// KBlocks columns to compute residual -+ static int const kKBlockColumn = KBlocksColumn_; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Whether beta is zero -+ static bool const IsBetaZero = true; -+ -+ /// Number 
of participating threads -+ static int const kThreads = 32; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kRow % InstructionShape::kM) && -+ !(Shape::kColumn % InstructionShape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ static_assert( -+ !(AccumulatorShape::kRow % Shape::kRow) && -+ !(AccumulatorShape::kColumn % Shape::kColumn), -+ "Shape of Warp Accumulator must be divisible by warp shape."); -+ static_assert( -+ !(kKBlockColumn % Shape::kColumn), -+ "KBlock size must be divisible by warp shape."); -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = AccumulatorShape::kCount / Shape::kCount; -+ }; -+ -+private: -+ -+ static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads; -+ -+ /// Number of mma operations performed by a warp -+ using MmaIterations = MatrixShape; -+ /// Number of mma operations performed by the entire accumulator -+ using AccumulatorIterations = MatrixShape; -+ -+ /// Number of K iterations -+ static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn; -+ static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn; -+ static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn -+ * (AccumulatorShape::kRow / Shape::kRow); -+ static int const kResidualIndex = kResidualColumn / Shape::kColumn -+ * (AccumulatorShape::kRow / Shape::kRow); -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array; -+ -+ /// Accumulator Fragment object -+ using AccumulatorFragment = Array; -+ -+ -+private: -+ -+ /// Internal access type -+ using AccessType = Array; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+ /// Used to access residual tile first -+ bool is_residual_tile_; -+ -+public: -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpPureFragmentIterator(AccumulatorFragment const &accum) -+ : accumulators_(reinterpret_cast(&accum)), -+ index_(0), is_residual_tile_(true) {} -+ -+ /// Add offset -+ CUTLASS_HOST_DEVICE -+ void add_offset(int index_offset) { -+ index_ += index_offset; -+ if(is_residual_tile_ && index_ >= kKBlockColumnIterations) { -+ index_ = index_ - kKBlockColumnIterations + kResidualIndex; -+ is_residual_tile_ = false; -+ } -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpPureFragmentIterator &operator++() { -+ add_offset(1); -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpPureFragmentIterator &operator--() { -+ add_offset(-1); -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ AccessType src_fragment; -+ src_fragment.clear(); -+ -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow; -+ int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow -+ * MmaIterations::kColumn; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; n++) { -+ for (int m = 0; m < MmaIterations::kRow; m++) { -+ int accumulator_access_offset = -+ (n + index_n) * 
AccumulatorIterations::kRow + m + index_m; -+ -+ frag_ptr[n * MmaIterations::kRow + m].clear(); -+ if(!(is_residual_tile_ && index_ >= kResidualIndex)) -+ frag_ptr[n * MmaIterations::kRow + m] = accumulators_[accumulator_access_offset]; -+ // frag_ptr[n * MmaIterations::kRow + m] = output_op(accumulators_[accumulator_access_offset], src_fragment); -+ } -+ } -+ } -+ -+}; -+ -+// Partial specialization for row-major accumulator tile -+ -+template < -+ /// Shape of warp tile to load (concept: MatrixShape) -+ typename Shape_, -+ /// Shape of the warp accumulation tile (concept: MatrixShape) -+ typename AccumulatorShape_, -+ /// KBlocks columns to compute residual -+ int KBlocksColumn_, -+ /// Accumulator Element type -+ typename ElementAccumulator_, -+ /// Element type -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_> -+class MmaTensorOpPureFragmentIterator { -+ public: -+ -+ /// Shape of warp tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Shape of the warp accumulation tile (concept: MatrixShape) -+ using AccumulatorShape = AccumulatorShape_; -+ -+ /// KBlocks columns to compute residual -+ static int const kKBlockColumn = KBlocksColumn_; -+ -+ /// Accumulator Element type -+ using ElementAccumulator = ElementAccumulator_; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Whether beta is zero -+ static bool const IsBetaZero = true; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kRow % InstructionShape::kM) && -+ !(Shape::kColumn % InstructionShape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ static_assert( -+ !(AccumulatorShape::kRow % Shape::kRow) && -+ !(AccumulatorShape::kColumn % Shape::kColumn), -+ "Shape of Warp Accumulator must be divisible by warp shape."); -+ static_assert( -+ !(kKBlockColumn % Shape::kColumn), -+ "KBlock size must be divisible by warp shape."); -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = AccumulatorShape::kCount / Shape::kCount; -+ }; -+ -+private: -+ -+ static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads; -+ -+ /// Number of mma operations performed by a warp -+ using MmaIterations = MatrixShape; -+ /// Number of mma operations performed by the entire accumulator -+ using AccumulatorIterations = MatrixShape; -+ -+ /// Number of K iterations -+ static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn; -+ static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn; -+ static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn -+ * (AccumulatorShape::kRow / Shape::kRow); -+ static int const kResidualIndex = kResidualColumn / Shape::kColumn -+ * (AccumulatorShape::kRow / Shape::kRow); -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ /// This is the fragment size produced by one access of the iterator. 
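// Worked example with assumed shapes (illustration only): for Shape = 64x32,
// AccumulatorShape = 64x96 and kKBlockColumn = 64, the constants above give
// kKBlockIterations = 2, kResidualColumn = 96 - 64 = 32,
// kKBlockColumnIterations = (64/32) * (64/64) = 2 and
// kResidualIndex = (32/32) * (64/64) = 1; once index_ reaches
// kKBlockColumnIterations, add_offset() rebases it onto kResidualIndex and
// clears is_residual_tile_.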
-+ using Fragment = Array; -+ -+ /// Accumulator Fragment object -+ using AccumulatorFragment = Array; -+ -+ -+private: -+ -+ /// Internal access type -+ using AccessType = Array; -+ using FragmentAccessType = Array; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+ /// Used to access residual tile first -+ bool is_residual_tile_; -+ -+public: -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpPureFragmentIterator(AccumulatorFragment const &accum) -+ : accumulators_(reinterpret_cast(&accum)), -+ index_(0), is_residual_tile_(true) {} -+ -+ /// Add offset -+ CUTLASS_HOST_DEVICE -+ void add_offset(int index_offset) { -+ index_ += index_offset; -+ if(is_residual_tile_ && index_ >= kKBlockColumnIterations) { -+ index_ = index_ - kKBlockColumnIterations + kResidualIndex; -+ is_residual_tile_ = false; -+ } -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpPureFragmentIterator &operator++() { -+ add_offset(1); -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpPureFragmentIterator &operator--() { -+ add_offset(-1); -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ -+ FragmentAccessType src_fragment; -+ src_fragment.clear(); -+ -+ FragmentAccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow; -+ int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow -+ * MmaIterations::kColumn; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; m++) { -+ for (int n = 0; n < MmaIterations::kColumn; n++) { -+ int accumulator_access_offset = -+ (m + index_m) * AccumulatorIterations::kColumn + n + index_n; -+ -+ frag_ptr[m * MmaIterations::kColumn + n].clear(); -+ if(!(is_residual_tile_ && index_ >= kResidualIndex)) -+ frag_ptr[m * MmaIterations::kColumn + n] = (accumulators_[accumulator_access_offset]); -+ } -+ } -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/leaky_bias.h b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/leaky_bias.h -new file mode 100644 -index 0000000..5b46a5a ---- /dev/null -+++ b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/leaky_bias.h -@@ -0,0 +1,292 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+#include -+ -+template -+__device__ -+T add(T const & a, T const &b){ -+ return (a + b); -+} -+ -+template <> -+__device__ -+half2 add(half2 const & a, half2 const &b){ -+ return (__hadd2(a,b)); -+} -+ -+template -+struct RELU{ -+ __device__ -+ T operator()(T const & a){ -+ return a > T(0) ? a : T(0); -+ } -+ __device__ -+ half2 operator()(half2 const & a){ -+ float2 a_fp32x2 = __half22float2(a); -+ a_fp32x2.x = a_fp32x2.x > 0.f ? a_fp32x2.x : 0.f; -+ a_fp32x2.y = a_fp32x2.y > 0.f ? a_fp32x2.y : 0.f; -+ if(a_fp32x2.x < 0.f || a_fp32x2.y < 0.f) -+ printf(" %f %f\n", a_fp32x2.x ,a_fp32x2.y); -+ return __float22half2_rn(a_fp32x2); -+ } -+}; -+ -+template -+struct LEAKY_RELU{ -+ __device__ -+ T operator()(T const & a, T const & scale = half(1)){ -+ return a > T(0) ? a : scale * a; -+ } -+ __device__ -+ half2 operator()(half2 const & a, half const & scale = half(1)){ -+ half2 zero = __half2half2(half(0)); -+ half2 gt_zero = __hge2(a, zero); -+ half2 le_zero = __hle2(a, zero); -+ -+ -+ half2 scale_f16x2 = __half2half2(scale); -+ half2 mask_scale_f16x2 = __hfma2(le_zero, scale_f16x2, gt_zero); -+ return __hmul2(a, mask_scale_f16x2); -+ } -+}; -+ -+template -+__global__ void leaky_and_activation(half* inout, half* bias, half scale, bool mat_bias){ -+ -+ constexpr bool N_MOD_2 = N & 1 ? false : true; -+ -+ using Access_tp = typename std::conditional::type; -+ -+ constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); -+ -+ constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); -+ -+ LEAKY_RELU Act; -+ Access_tp src_v[iter]; -+ Access_tp bias_v[iter]; -+ -+ int batch_id = blockIdx.y; -+ int batch_offset = batch_id * gridDim.x * N; -+ -+ for(int i = 0; i < iter; i++){ -+ int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; -+ if (idx < N){ -+ src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); -+ if (mat_bias) -+ bias_v[i] = *reinterpret_cast(bias + blockIdx.x * N + idx + batch_offset); -+ else -+ bias_v[i] = *reinterpret_cast(bias + idx + batch_id * N); -+ *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(add(src_v[i],bias_v[i]),scale); -+ } -+ -+ } -+} -+ -+ -+ -+template -+__global__ void leaky_and_activation(half* inout, half scale){ -+ -+ constexpr bool N_MOD_2 = N & 1 ? 
false : true; -+ -+ using Access_tp = typename std::conditional::type; -+ -+ constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); -+ -+ constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); -+ -+ int batch_id = blockIdx.y; -+ int batch_offset = batch_id * gridDim.x * N; -+ -+ LEAKY_RELU Act; -+ Access_tp src_v[iter]; -+ -+ for(int i = 0; i < iter; i++){ -+ int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; -+ if (idx < N){ -+ src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); -+ *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(src_v[i], scale); -+ } -+ -+ } -+} -+ -+ -+ -+template -+void leaky_and_activation(half* inout, half* bias, int m, int b, half scale, bool mat_bias){ -+ -+ dim3 grid(m, b); -+ if (bias == nullptr) -+ leaky_and_activation<<>>(inout, scale); -+ else -+ leaky_and_activation<<>>(inout, bias, scale, mat_bias); -+} -+ -+template -+__global__ void relu_and_activation(half* inout, half* bias, bool mat_bias){ -+ -+ constexpr bool N_MOD_2 = N & 1 ? false : true; -+ -+ using Access_tp = typename std::conditional::type; -+ -+ constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); -+ -+ constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); -+ -+ RELU Act; -+ Access_tp src_v[iter]; -+ Access_tp bias_v[iter]; -+ -+ int batch_id = blockIdx.y; -+ int batch_offset = batch_id * gridDim.x * N; -+ -+ for(int i = 0; i < iter; i++){ -+ int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; -+ if (idx < N){ -+ src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); -+ if (mat_bias) -+ bias_v[i] = *reinterpret_cast(bias + blockIdx.x * N + idx + batch_offset); -+ else -+ bias_v[i] = *reinterpret_cast(bias + idx + batch_id * N); -+ *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(add(src_v[i],bias_v[i])); -+ } -+ -+ } -+} -+ -+ -+ -+template -+__global__ void relu_and_activation(half* inout){ -+ -+ constexpr bool N_MOD_2 = N & 1 ? false : true; -+ -+ using Access_tp = typename std::conditional::type; -+ -+ constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); -+ -+ constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); -+ -+ int batch_id = blockIdx.y; -+ int batch_offset = batch_id * gridDim.x * N; -+ -+ RELU Act; -+ Access_tp src_v[iter]; -+ -+ for(int i = 0; i < iter; i++){ -+ int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; -+ if (idx < N){ -+ src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); -+ *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(src_v[i]); -+ } -+ -+ } -+} -+ -+ -+ -+template -+void relu_and_activation(half* inout, half* bias, int m, int b, bool mat_bias){ -+ dim3 grid(m, b); -+ if (bias == nullptr) -+ relu_and_activation<<>>(inout); -+ else -+ relu_and_activation<<>>(inout, bias, mat_bias); -+} -+ -+ -+template -+__global__ void identity_and_activation(half* inout, half* bias, bool mat_bias){ -+ -+ constexpr bool N_MOD_2 = N & 1 ? 
false : true; -+ -+ using Access_tp = typename std::conditional::type; -+ -+ constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); -+ -+ constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); -+ -+ int batch_id = blockIdx.y; -+ int batch_offset = batch_id * gridDim.x * N; -+ -+ Access_tp src_v[iter]; -+ Access_tp bias_v[iter]; -+ -+ for(int i = 0; i < iter; i++){ -+ int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; -+ if (idx < N){ -+ src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); -+ if (mat_bias) -+ bias_v[i] = *reinterpret_cast(bias + blockIdx.x * N + idx + batch_offset); -+ else -+ bias_v[i] = *reinterpret_cast(bias + idx + batch_id * N); -+ *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = (add(src_v[i],bias_v[i])); -+ } -+ -+ } -+} -+ -+template -+__global__ void identity_and_activation(half* inout){ -+ -+ constexpr bool N_MOD_2 = N & 1 ? false : true; -+ -+ using Access_tp = typename std::conditional::type; -+ -+ constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); -+ -+ constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); -+ -+ int batch_id = blockIdx.y; -+ int batch_offset = batch_id * gridDim.x * N; -+ Access_tp src_v[iter]; -+ -+ for(int i = 0; i < iter; i++){ -+ int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; -+ if (idx < N){ -+ src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); -+ *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = (src_v[i]); -+ } -+ -+ } -+} -+ -+template -+void identity_and_activation(half* inout, half* bias, int m, int b, bool mat_bias){ -+ dim3 grid(m, b); -+ if (bias == nullptr) -+ identity_and_activation<<>>(inout); -+ else -+ identity_and_activation<<>>(inout, bias, mat_bias); -+} -diff --git a/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/utils.h b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/utils.h -new file mode 100644 -index 0000000..9e1a732 ---- /dev/null -+++ b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/utils.h -@@ -0,0 +1,94 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+#define TI(tag) \ -+ cudaEvent_t _event_start_ ##tag; \ -+ cudaEvent_t _event_end_ ##tag; \ -+ float _event_time_ ##tag; \ -+ cudaEventCreate(& _event_start_ ##tag); \ -+ cudaEventCreate(& _event_end_ ##tag); \ -+ cudaEventRecord(_event_start_ ##tag); -+ -+#define TO(tag, str, times) \ -+ cudaEventRecord(_event_end_ ##tag); \ -+ cudaEventSynchronize(_event_end_ ##tag); \ -+ cudaEventElapsedTime(&_event_time_ ##tag, _event_start_ ##tag, _event_end_ ##tag); \ -+ float _event_time_once_ ##tag = _event_time_ ##tag / times; \ -+ printf("%20s:\t %10.3fus\t", str, _event_time_once_ ##tag * 1000); \ -+ cudaDeviceSynchronize(); \ -+ printf("%20s string: %s\n",str, cudaGetErrorString(cudaGetLastError())); -+ -+template -+struct memory_unit{ -+ T* host_ptr; -+ T* device_ptr; -+ int size_bytes; -+ int elements; -+ void h2d(){ -+ cudaMemcpy(device_ptr, host_ptr, size_bytes, cudaMemcpyHostToDevice); -+ } -+ void d2h(){ -+ cudaMemcpy(host_ptr, device_ptr, size_bytes, cudaMemcpyDeviceToHost); -+ } -+ void free_all(){ -+ free(host_ptr); -+ cudaFree(device_ptr); -+ } -+ memory_unit(int elements_): size_bytes(elements_ * sizeof(T)), elements(elements_){ -+ host_ptr = (T*) malloc(elements_ * sizeof(T)); -+ cudaMalloc((void**)&device_ptr, elements_ * sizeof(T)); -+ } -+ void init(int abs_range = 1){ -+ for(int i = 0; i < elements; i++){ -+ host_ptr[i] = T(rand() % 100 / float(100) * 2 * abs_range - abs_range); -+ } -+ h2d(); -+ } -+}; -+ -+template -+int check_result(T * a, T * b, int N){ -+ int cnt = 0; -+ for(int i = 0; i < N; i ++){ -+ float std = float(a[i]); -+ float my = float(b[i]); -+ -+ if(abs(std - my) / abs(std) > 1e-2) -+ { -+ // printf("my: %f , std: %f\n", my, std); -+ cnt++; -+ } -+ -+ } -+ printf("total err: %d / %d\n", cnt, N); -+ return cnt; -+} -diff --git a/3rdparty/cutlass/examples/45_dual_gemm/device/dual_gemm.h b/3rdparty/cutlass/examples/45_dual_gemm/device/dual_gemm.h -new file mode 100644 -index 0000000..491888b ---- /dev/null -+++ b/3rdparty/cutlass/examples/45_dual_gemm/device/dual_gemm.h -@@ -0,0 +1,457 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Performs a dual gemm in one fused kernel: -+``` -+D0 = epilogue0(X @ B0, C0) -+D1 = epilogue1(X @ B1, C1) -+D2 = element_wise(D0, D1) -+``` -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/epilogue/thread/linear_combination_relu.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+ -+#include "../kernel/dual_gemm.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp0_, -+ typename EpilogueOutputOp1_, -+ typename EpilogueOutputOp2_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ bool StoreD0 = true, -+ bool StoreD1 = true, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial = false, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity 
of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator> -+class DualGemm { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp0 = EpilogueOutputOp0_; -+ using EpilogueOutputOp1 = EpilogueOutputOp1_; -+ using EpilogueOutputOp2 = EpilogueOutputOp2_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp1::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ static bool constexpr kStoreD0 = StoreD0; -+ static bool constexpr kStoreD1 = StoreD1; -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+ using LayoutScaleBias = layout::RowMajor; -+ /// Define the kernel -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ static_assert(ArchTag::kMinComputeCapability >= 80, "Only multistage is implemented"); -+ static_assert(kStages >= 3, "Only multistage is implemented"); -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, -+ ThreadblockShape, WarpShape, -+ InstructionShape, Stages, Operator>::ThreadblockMma; -+ using DualMma = threadblock::DualMmaMultistage< -+ typename Mma::Shape, -+ typename Mma::IteratorA, -+ typename Mma::SmemIteratorA, -+ Mma::kCacheOpA, -+ typename Mma::IteratorB, -+ typename Mma::SmemIteratorB, -+ Mma::kCacheOpB, -+ typename Mma::ElementC, -+ typename Mma::LayoutC, -+ typename Mma::Policy, -+ Mma::kStages, -+ SharedMemoryClearOption::kNone -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue0 = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename DualMma::Operator, kPartitionsK, EpilogueOutputOp0, -+ EpilogueOutputOp0::kCount>::Epilogue; -+ using Epilogue1 = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename DualMma::Operator, kPartitionsK, EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. 
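// With the shapes used by the fused example later in this patch
// (ThreadblockShape 128x64x32, WarpShape 64x32x32), kPartitionsK = 32 / 32 = 1,
// so the warps of a threadblock do not partition the K dimension and each
// epilogue above sees a single K-partition.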
-+ using DualGemmKernel = kernel::DualGemm< -+ DualMma, -+ Epilogue0, Epilogue1, EpilogueOutputOp2, -+ ThreadblockSwizzle, kSplitKSerial, -+ kStoreD0, kStoreD1>; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A0; -+ TensorRef ref_B0; -+ TensorRef ref_C0; -+ TensorRef ref_D0; -+ TensorRef ref_B1; -+ TensorRef ref_C1; -+ TensorRef ref_D1; -+ TensorRef ref_D2; -+ typename EpilogueOutputOp0::Params epilogue0; -+ typename EpilogueOutputOp1::Params epilogue1; -+ typename EpilogueOutputOp2::Params epilogue2; -+ int split_k_slices; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments(): problem_size(0, 0, 0), split_k_slices(1) { -+ -+ } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A0_, -+ TensorRef ref_B0_, -+ TensorRef ref_C0_, -+ TensorRef ref_D0_, -+ TensorRef ref_B1_, -+ TensorRef ref_C1_, -+ TensorRef ref_D1_, -+ TensorRef ref_D2_, -+ typename EpilogueOutputOp0::Params epilogue0_ = -+ typename EpilogueOutputOp0::Params(), -+ typename EpilogueOutputOp1::Params epilogue1_ = -+ typename EpilogueOutputOp1::Params(), -+ typename EpilogueOutputOp2::Params epilogue2_ = -+ typename EpilogueOutputOp2::Params(), -+ int split_k_slices_ = 1 -+ ): -+ problem_size(problem_size_), -+ ref_A0(ref_A0_), -+ ref_B0(ref_B0_), -+ ref_C0(ref_C0_), -+ ref_D0(ref_D0_), -+ ref_B1(ref_B1_), -+ ref_C1(ref_C1_), -+ ref_D1(ref_D1_), -+ ref_D2(ref_D2_), -+ epilogue0(epilogue0_), -+ epilogue1(epilogue1_), -+ epilogue2(epilogue2_), -+ split_k_slices(split_k_slices_) { -+ -+ } -+ }; -+ -+private: -+ -+ /// Kernel parameters object -+ typename DualGemmKernel::Params params_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ DualGemm() = default; -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ if (!kSplitKSerial && args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ if (kStoreD0 != (args.ref_D0.data() != nullptr)) { -+ return Status::kErrorInternal; -+ } -+ if (kStoreD1 != (args.ref_D1.data() != nullptr)) { -+ return Status::kErrorInternal; -+ } -+ -+ Status status = DualGemmKernel::can_implement( -+ args.problem_size, -+ args.ref_A0.non_const_ref(), -+ args.ref_B0.non_const_ref(), -+ args.ref_C0.non_const_ref(), -+ args.ref_D0, -+ args.ref_B1.non_const_ref(), -+ args.ref_C1.non_const_ref(), -+ args.ref_D1, -+ args.ref_D2 -+ ); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ -+ -+ bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); -+ } -+ -+ return bytes; -+ } -+ -+ /// Initializes GEMM state from arguments. 
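// Workspace arithmetic for reference (shapes taken from the example in this
// patch, so still an illustration): with problem_size 4096x4096x8192 and
// 128x64 threadblock tiles, a serial split-K run would reserve
// sizeof(int) * (4096/128) * (4096/64) = 4 * 32 * 64 = 8192 B of semaphores;
// with kSplitKSerial = false and split_k_slices == 1, get_workspace_size()
// returns 0 and initialize() accepts a null workspace.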
-+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ if (kSplitKSerial) { -+ if (args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ size_t bytes = get_workspace_size(args); -+ -+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ else { -+ -+ if (args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ // Initialize the Params structure -+ params_ = typename DualGemmKernel::Params{ -+ args.problem_size, -+ grid_shape, -+ args.ref_A0.non_const_ref(), -+ args.ref_B0.non_const_ref(), -+ args.ref_C0.non_const_ref(), -+ args.ref_D0, -+ args.ref_B1.non_const_ref(), -+ args.ref_C1.non_const_ref(), -+ args.ref_D1, -+ args.ref_D2, -+ args.epilogue0, -+ args.epilogue1, -+ args.epilogue2, -+ reinterpret_cast(workspace), -+ }; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ } -+ -+ params_.ref_A0.reset(args.ref_A0.non_const_ref().data()); -+ params_.ref_B0.reset(args.ref_B0.non_const_ref().data()); -+ params_.ref_C0.reset(args.ref_C0.non_const_ref().data()); -+ params_.ref_D0.reset(args.ref_D0.data()); -+ params_.ref_B1.reset(args.ref_B1.non_const_ref().data()); -+ params_.ref_C1.reset(args.ref_C1.non_const_ref().data()); -+ params_.ref_D1.reset(args.ref_D1.data()); -+ params_.ref_D2.reset(args.ref_D2.data()); -+ params_.output_op_0 = args.epilogue0; -+ params_.output_op_1 = args.epilogue1; -+ params_.output_op_2 = args.epilogue2; -+ params_.semaphore = reinterpret_cast(workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(DualGemmKernel::kThreadCount, 1, 1); -+ -+ cudaError_t result; -+ -+ int smem_size = int(sizeof(typename DualGemmKernel::SharedStorage)); -+ if (smem_size >= (48 << 10)) { -+ result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ cutlass::Kernel<<>>(params_); -+ -+ result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. 
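// Typical host-side flow, sketched with assumed tensor refs (names borrowed
// from the concrete instantiation in dual_gemm.cu later in this patch):
//   DualGemm::Arguments args(problem_size,
//                            ref_A0, ref_B0, ref_C0, ref_D0,
//                            ref_B1, ref_C1, ref_D1, ref_D2,
//                            {alpha0, beta0}, {alpha1, beta1}, {},
//                            /*split_k_slices=*/1);
//   DualGemm dual_gemm;
//   if (DualGemm::can_implement(args) == Status::kSuccess) {
//     size_t ws = DualGemm::get_workspace_size(args);   // 0 unless serial split-K
//     Status status = dual_gemm(args, /*workspace=*/nullptr, stream);
//   }
// The three-argument operator() below simply chains initialize() and run().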
-+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/45_dual_gemm/dual_gemm.cu b/3rdparty/cutlass/examples/45_dual_gemm/dual_gemm.cu -new file mode 100644 -index 0000000..15974e0 ---- /dev/null -+++ b/3rdparty/cutlass/examples/45_dual_gemm/dual_gemm.cu -@@ -0,0 +1,262 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief CUTLASS Dual-GEMM Example. -+ -+ Fused kernel that outputs `D0` and `D1`. 
-+ We assume that B0/B1 have the same shape/layout
-+
-+```
-+D0 = epilogue0(X @ B0, C0)
-+D1 = epilogue1(X @ B1, C1)
-+D2 = element_wise(D0, D1)
-+```
-+ D0 and D1 will be optionally stored in gmem (`kStoreD0` / `kStoreD1`)
-+*/
-+
-+// #define IS_PROFILING
-+
-+#include <iostream>
-+
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm.h"
-+
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/tensor_view_io.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+
-+#include "device/dual_gemm.h"
-+#include "thread/left_silu_and_mul.h"
-+#include "dual_gemm_run.h"
-+#include "test_run.h"
-+
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+cutlass::gemm::GemmCoord problem_size(4096, 4096, 8192);
-+
-+constexpr int kStages = 3;
-+constexpr bool kSplitKSerial = false;
-+constexpr bool kUseBias = true;
-+
-+
-+#if 0
-+using ElementOperandA = cutlass::bfloat16_t;
-+using ElementOperandB = cutlass::bfloat16_t;
-+using ElementOutput = cutlass::bfloat16_t;
-+using ElementAccumulator = float;
-+using ElementCompute = float;
-+#else
-+using ElementOperandA = cutlass::half_t;
-+using ElementOperandB = cutlass::half_t;
-+using ElementOutput = cutlass::half_t;
-+using ElementAccumulator = cutlass::half_t;
-+using ElementCompute = cutlass::half_t;
-+#endif
-+
-+constexpr auto kScaleType = kUseBias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling : (
-+  // No bias
-+  kSplitKSerial ? cutlass::epilogue::thread::ScaleType::Default : cutlass::epilogue::thread::ScaleType::Nothing
-+);
-+using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombination<
-+  ElementOutput,
-+  128 / cutlass::sizeof_bits<ElementOutput>::value,
-+  ElementAccumulator,
-+  ElementCompute,
-+  kScaleType
-+>;
-+using EpilogueOutputOp1 = cutlass::epilogue::thread::LinearCombination<
-+  ElementOutput,
-+  128 / cutlass::sizeof_bits<ElementOutput>::value,
-+  ElementAccumulator,
-+  ElementCompute,
-+  kScaleType
-+>;
-+using EpilogueOutputOp2 = cutlass::epilogue::thread::LeftSiLUAndMul<
-+  ElementOutput,
-+  128 / cutlass::sizeof_bits<ElementOutput>::value,
-+  ElementOutput,
-+  ElementCompute
-+>;
-+
-+const ElementCompute alpha0 = ElementCompute(1);
-+const ElementCompute beta0 = ElementCompute(kUseBias ? 1 : 0);
-+const ElementCompute alpha1 = ElementCompute(1);
-+const ElementCompute beta1 = ElementCompute(kUseBias ? 1 : 0);
-+
-+bool run_nonfused_gemm_f16_sm80() {
-+  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
-+  using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
-+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
-+
-+  using Gemm0 = cutlass::gemm::device::Gemm<
-+    ElementOperandA,
-+    cutlass::layout::RowMajor,
-+    ElementOperandB,
-+    cutlass::layout::ColumnMajor,
-+    ElementOutput,
-+    cutlass::layout::RowMajor,
-+    ElementAccumulator,
-+    cutlass::arch::OpClassTensorOp,
-+    cutlass::arch::Sm80,
-+    ThreadblockShape,
-+    WarpShape,
-+    InstructionShape,
-+    EpilogueOutputOp0,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
-+    kStages,
-+    8,
-+    8,
-+    kSplitKSerial
-+  >;
-+  using Gemm1 = cutlass::gemm::device::Gemm<
-+    ElementOperandA,
-+    cutlass::layout::RowMajor,
-+    ElementOperandB,
-+    cutlass::layout::ColumnMajor,
-+    ElementOutput,
-+    cutlass::layout::RowMajor,
-+    ElementAccumulator,
-+    cutlass::arch::OpClassTensorOp,
-+    cutlass::arch::Sm80,
-+    ThreadblockShape,
-+    WarpShape,
-+    InstructionShape,
-+    EpilogueOutputOp1,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
-+    kStages,
-+    8,
-+    8,
-+    kSplitKSerial
-+  >;
-+
-+  NonFusedDualGemmRun<Gemm0, Gemm1> nonFusedGemm;
-+
-+  std::cout << "Running Non-fused GEMMs FP16 TN GEMMs...\n";
-+  bool pass = nonFusedGemm.run(problem_size, alpha0, beta0, alpha1, beta1);
-+  if(pass)
-+    std::cout << "Pass\n";
-+  else
-+    std::cout << "Fail\n";
-+
-+  return pass;
-+}
-+
-+template <typename T>
-+struct LeftSiLUAndMul {
-+  struct Params{};
-+  CUTLASS_HOST_DEVICE LeftSiLUAndMul(Params p) {}
-+
-+  CUTLASS_HOST_DEVICE void set_k_partition(int, int) {}
-+
-+  CUTLASS_HOST_DEVICE T operator() (
-+    T const &lhs,
-+    T const &rhs) const {
-+    cutlass::epilogue::thread::SiLu<T> silu;
-+    cutlass::multiplies<T> mul;
-+    auto silu_lhs = silu(lhs);
-+    return mul(silu_lhs, rhs);
-+  }
-+
-+  template <int kCount>
-+  CUTLASS_HOST_DEVICE cutlass::Array<T, kCount> operator() (
-+    cutlass::Array<T, kCount> const &lhs,
-+    cutlass::Array<T, kCount> const &rhs) const {
-+    cutlass::epilogue::thread::SiLu<cutlass::Array<T, kCount>> silu;
-+    cutlass::multiplies<cutlass::Array<T, kCount>> mul;
-+    auto silu_lhs = silu(lhs);
-+    return mul(silu_lhs, rhs);
-+  }
-+};
-+
-+bool run_fused_gemm_f16_sm80_shmem() {
-+  using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
-+  using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
-+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
-+
-+  // Optionally, we might not need intermediate GEMM outputs
-+  constexpr bool kStoreD0 = true;
-+  constexpr bool kStoreD1 = true;
-+
-+  using DualGemm = cutlass::gemm::device::DualGemm<
-+    ElementOperandA,
-+    cutlass::layout::RowMajor,
-+    ElementOperandB,
-+    cutlass::layout::ColumnMajor,
-+    ElementOutput,
-+    cutlass::layout::RowMajor,
-+    ElementAccumulator,
-+    cutlass::arch::OpClassTensorOp,
-+    cutlass::arch::Sm80,
-+    ThreadblockShape,
-+    WarpShape,
-+    InstructionShape,
-+    EpilogueOutputOp0,
-+    EpilogueOutputOp1,
-+    EpilogueOutputOp2,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
-+    kStages,
-+    kStoreD0,
-+    kStoreD1,
-+    kSplitKSerial
-+  >;
-+
-+  DualFusedGemmRun<DualGemm> fusedGemm;
-+
-+  std::cout << "Running Fused FP16 TN GEMMs + Epilogue2...\n";
-+  bool passed = fusedGemm.run(problem_size, alpha0, beta0, alpha1, beta1);
-+  if(passed)
-+    std::cout << "Pass\n";
-+  else
-+    std::cout << "Fail\n";
-+
-+  return passed;
-+
-+}
-+
-+int main() {
-+
-+  std::vector<bool(*)()> funcs = {
-+    &run_nonfused_gemm_f16_sm80,
-+    &run_fused_gemm_f16_sm80_shmem
-+  };
-+
-+  std::string test_name = "dual-gemm f16 bias=" + std::to_string(kUseBias) + " split_k_serial=" +
std::to_string(kSplitKSerial); -+ return testRun(80, funcs, test_name); -+} -+ -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/45_dual_gemm/dual_gemm_run.h b/3rdparty/cutlass/examples/45_dual_gemm/dual_gemm_run.h -new file mode 100644 -index 0000000..63ca2ac ---- /dev/null -+++ b/3rdparty/cutlass/examples/45_dual_gemm/dual_gemm_run.h -@@ -0,0 +1,829 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/device/tensor_relu.h" -+ -+#include "helper.h" -+ -+#define CHECK_GT(val1, val2) \ -+ if((val1) <= (val2)) \ -+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; -+#define CHECK_TRUE(val) \ -+ if(!(val)) \ -+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; -+ -+template < -+ typename OutputOp, -+ typename Element, -+ typename Layout> -+struct TensorEpilogueForEachFunc { -+ /// View type -+ using TensorView = cutlass::TensorView; -+ -+ /// Coordinate in tensor's index space -+ using TensorCoord = typename TensorView::TensorCoord; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view_x0; -+ TensorView view_x1; -+ TensorView view_y; -+ OutputOp output_op; -+ -+ -+ // -+ // Methods -+ // -+ -+ Params( -+ TensorView view_x0_ = TensorView(), -+ TensorView view_x1_ = TensorView(), -+ TensorView view_y_ = TensorView(), -+ OutputOp output_op_ = OutputOp(typename OutputOp::Params{}) -+ ): -+ view_x0(view_x0_), view_x1(view_x1_), view_y(view_y_), output_op(output_op_) { -+ } -+ }; -+ -+ Params params; -+ -+ CUTLASS_DEVICE -+ TensorEpilogueForEachFunc(Params const ¶ms): params(params) { -+ -+ } -+ -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ Element const & x0 = params.view_x0.at(coord); -+ Element const & x1 = params.view_x1.at(coord); -+ Element& y = params.view_y.at(coord); -+ y = params.output_op(x0, x1); -+ } -+}; -+ -+template < -+ typename OutputOp, -+ typename Element, -+ typename Layout> -+void TensorEpilogueForEach( -+ cutlass::TensorView x0, -+ cutlass::TensorView x1, -+ cutlass::TensorView y) { -+ -+ using Func = TensorEpilogueForEachFunc; -+ using Params = typename Func::Params; -+ -+ cutlass::reference::device::TensorForEach( -+ y.extent(), -+ Params(x0, x1, y) -+ ); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct NonFusedDualGemmRun -+{ -+ -+ using Gemm0 = Gemm0_; -+ using Gemm1 = Gemm1_; -+ using ElementAccumulator = typename Gemm0::ElementAccumulator; -+ using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_Bias; -+ uint64_t seed; -+ -+ // -+ // Methods -+ // -+ -+ NonFusedDualGemmRun( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ 
cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, 2, -2, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else if (dist_kind == cutlass::Distribution::AllZeros) { -+ cutlass::reference::host::TensorFill(view, Element(0)); -+ } -+ else if (dist_kind == cutlass::Distribution::AllOnes) { -+ cutlass::reference::host::TensorFill(view, Element(1)); -+ } -+ else { -+ std::cerr << "Not implemented\n"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha0 = ElementCompute(1), -+ ElementCompute beta0 = ElementCompute(0), -+ ElementCompute alpha1 = ElementCompute(1), -+ ElementCompute beta1 = ElementCompute(0), -+ bool relu = false, -+ int warm_ups = 1, -+ int runs = 100) { -+ -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementA, -+ typename Gemm0::LayoutA> tensor_A0(problem_size.mk()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementB, -+ typename Gemm0::LayoutB> tensor_B0(problem_size.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementC, -+ typename Gemm0::LayoutC> tensor_C0(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementC, -+ typename Gemm0::LayoutC> tensor_Bias0({1, problem_size.n()}); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementC, -+ typename Gemm0::LayoutC> tensor_D0(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementC, -+ typename Gemm0::LayoutC> reference_D0(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementB, -+ typename Gemm1::LayoutB> tensor_B1(problem_size.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementC, -+ typename Gemm1::LayoutC> tensor_C1(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementC, -+ typename Gemm1::LayoutC> tensor_Bias1({1, problem_size.n()}); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementC, -+ typename Gemm1::LayoutC> tensor_D1(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementC, -+ typename Gemm1::LayoutC> reference_D1(problem_size.mn()); -+ -+ -+ CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); -+ CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); -+ CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014)); -+ CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); -+ CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013)); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D0.host_view()); -+ cutlass::reference::host::TensorFill( -+ tensor_D1.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D0.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D1.host_view()); -+ -+ tensor_A0.sync_device(); -+ tensor_B0.sync_device(); -+ 
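-+
-+    // Note: the 1 x N bias tensors are bound below as TensorRefs with a
-+    // LayoutC stride of 0, so the single bias row is broadcast across all
-+    // M rows of the output.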
tensor_C0.sync_device(); -+ tensor_Bias0.sync_device(); -+ tensor_D0.sync_device(); -+ reference_D0.sync_device(); -+ tensor_B1.sync_device(); -+ tensor_C1.sync_device(); -+ tensor_Bias1.sync_device(); -+ tensor_D1.sync_device(); -+ reference_D1.sync_device(); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ int split_k_slices = Gemm0::kSplitKSerial ? 2 : 1; -+ typename Gemm0::Arguments arguments_0{ -+ problem_size, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, -+ tensor_D0.device_ref(), -+ {alpha0, beta0}, -+ split_k_slices -+ }; -+ -+ split_k_slices = Gemm1::kSplitKSerial ? 2 : 1; -+ typename Gemm1::Arguments arguments_1{ -+ problem_size, -+ tensor_A0.device_ref(), -+ tensor_B1.device_ref(), -+ {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, -+ tensor_D1.device_ref(), -+ {alpha1, beta1}, -+ split_k_slices -+ }; -+ -+ -+ Gemm0 gemm_op_0; -+ Gemm1 gemm_op_1; -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace0(gemm_op_0.get_workspace_size(arguments_0)); -+ cutlass::device_memory::allocation workspace1(gemm_op_1.get_workspace_size(arguments_1)); -+ -+ cutlass::Status status = gemm_op_0.initialize(arguments_0, workspace0.get()); -+ -+ CUTLASS_CHECK(status); -+ -+ status = gemm_op_1.initialize(arguments_1, workspace1.get()); -+ -+ CUTLASS_CHECK(status); -+ -+ for(int i = 0; i < warm_ups; i++) { -+ status = gemm_op_0(); -+ CUTLASS_CHECK(status); -+ status = gemm_op_1(); -+ CUTLASS_CHECK(status); -+ } -+#ifdef IS_PROFILING -+ return true; -+#endif -+ // -+ // Run the GEMM -+ // -+ cudaEvent_t start, stop1, stop2; -+ cudaEventCreate(&start); -+ cudaEventCreate(&stop1); -+ cudaEventCreate(&stop2); -+ -+ cudaEventRecord(start); -+ -+ for(int i = 0; i < runs; i++) { -+ status = gemm_op_0(); -+ -+ CUTLASS_CHECK(status); -+ } -+ cudaEventRecord(stop1); -+ for(int i = 0; i < runs; i++) { -+ status = gemm_op_1(); -+ -+ CUTLASS_CHECK(status); -+ } -+ -+ cudaEventRecord(stop2); -+ cudaDeviceSynchronize(); -+ float gemm0Time, gemm1Time, totalTime; -+ cudaEventElapsedTime(&gemm0Time, start, stop1); -+ cudaEventElapsedTime(&gemm1Time, stop1, stop2); -+ cudaEventElapsedTime(&totalTime, start, stop2); -+ std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n"; -+ std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n"; -+ std::cout << "Non-fusion GEMM only time " << totalTime / (float)runs << " ms\n"; -+ -+ tensor_D0.sync_host(); -+ tensor_D1.sync_host(); -+ -+ // -+ // Verify -+ // -+ cutlass::reference::device::Gemm< -+ typename Gemm0::ElementA, typename Gemm0::LayoutA, -+ typename Gemm0::ElementB, typename Gemm0::LayoutB, -+ typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute, -+ ElementAccumulator, typename Gemm0::Operator> -+ reference_gemm_0; -+ -+ cutlass::reference::device::Gemm< -+ typename Gemm1::ElementA, typename Gemm1::LayoutA, -+ typename Gemm1::ElementB, typename Gemm1::LayoutB, -+ typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute, -+ ElementAccumulator, typename Gemm1::Operator> -+ reference_gemm_1; -+ -+ reference_gemm_0( -+ problem_size, -+ alpha0, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ beta0, -+ {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, -+ reference_D0.device_ref() -+ ); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D0.device_view()); -+ } -+ -+ reference_gemm_1( -+ problem_size, -+ alpha1, -+ tensor_A0.device_ref(), -+ tensor_B1.device_ref(), -+ 
beta1, -+ {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, -+ reference_D1.device_ref() -+ ); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D1.device_view()); -+ } -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ reference_D0.sync_host(); -+ reference_D1.sync_host(); -+ -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); -+ -+ bool passed0 = cutlass::reference::host::TensorEquals( -+ reference_D1.host_view(), -+ tensor_D1.host_view()); -+ CHECK_TRUE(passed0); -+ -+ bool passed1 = cutlass::reference::host::TensorEquals( -+ reference_D1.host_view(), -+ tensor_D1.host_view()); -+ CHECK_TRUE(passed1); -+ if (!passed0 || !passed1) { -+ -+ std::stringstream fname; -+ -+ fname << "error_DualGemm_device_nonfused.txt"; -+ std::cerr << "Dumping results in " << fname.str() << "\n"; -+ -+ std::ofstream file(fname.str()); -+ -+ file -+ << "A0 =\n" << tensor_A0.host_view() -+ << "\nB0 =\n" << tensor_B0.host_view() -+ << "\nC0 =\n" << tensor_C0.host_view() -+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" -+ << "\nD0 =\n" << tensor_D0.host_view() -+ << "\nB1 =\n" << tensor_B1.host_view() -+ << "\nC1 =\n" << tensor_C1.host_view() -+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" -+ << "\n\nReference =\n" << reference_D1.host_view() -+ << "\nComputed =\n" << tensor_D1.host_view(); -+ } -+ return passed0 && passed1; -+ } -+}; -+ -+template -+struct DualFusedGemmRun -+{ -+ -+ using DualGemm = DualGemm_; -+ using ElementAccumulator = typename DualGemm::ElementAccumulator; -+ using ElementCompute = typename DualGemm::DualGemmKernel::Epilogue0::OutputOp::ElementCompute; -+ using EpilogueOutputOp2 = typename DualGemm::EpilogueOutputOp2; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_Scale; -+ cutlass::Distribution::Kind init_Bias; -+ uint64_t seed; -+ -+ // -+ // Methods -+ // -+ -+ DualFusedGemmRun( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), -+ init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, 2, -2, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else if 
(dist_kind == cutlass::Distribution::AllZeros) { -+ cutlass::reference::host::TensorFill(view, Element(0)); -+ } -+ else if (dist_kind == cutlass::Distribution::AllOnes) { -+ cutlass::reference::host::TensorFill(view, Element(1)); -+ } -+ else { -+ std::cerr << "Not implemented\n"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha0 = ElementCompute(1), -+ ElementCompute beta0 = ElementCompute(1), -+ ElementCompute alpha1 = ElementCompute(1), -+ ElementCompute beta1 = ElementCompute(1), -+ bool relu = false, -+ int warm_ups = 1, -+ int runs = 100) { -+ -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementA, -+ typename DualGemm::LayoutA> tensor_A0(problem_size.mk()); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementB, -+ typename DualGemm::LayoutB> tensor_B0(problem_size.kn()); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementC, -+ typename DualGemm::LayoutC> tensor_C0(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementC, -+ typename DualGemm::LayoutScaleBias> tensor_Bias0({1, problem_size.n()}); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementC, -+ typename DualGemm::LayoutC> tensor_D0(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementC, -+ typename DualGemm::LayoutC> reference_D0(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementB, -+ typename DualGemm::LayoutB> tensor_B1(problem_size.kn()); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementC, -+ typename DualGemm::LayoutC> tensor_C1(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementC, -+ typename DualGemm::LayoutScaleBias> tensor_Bias1({1, problem_size.n()}); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementC, -+ typename DualGemm::LayoutC> tensor_D1(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementC, -+ typename DualGemm::LayoutC> tensor_D2(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementC, -+ typename DualGemm::LayoutC> reference_D1(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementC, -+ typename DualGemm::LayoutC> reference_D2(problem_size.mn()); -+ -+ CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); -+ CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2118)); -+ CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2011)); -+ CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2113)); -+ CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012)); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D0.host_view()); -+ cutlass::reference::host::TensorFill( -+ tensor_D1.host_view()); -+ cutlass::reference::host::TensorFill( -+ tensor_D2.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D0.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D1.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D2.host_view()); -+ -+ tensor_A0.sync_device(); -+ tensor_B0.sync_device(); -+ tensor_C0.sync_device(); -+ tensor_Bias0.sync_device(); -+ tensor_B1.sync_device(); -+ tensor_C1.sync_device(); -+ tensor_Bias1.sync_device(); -+ tensor_D0.sync_device(); -+ 
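-+
-+    // The bias TensorRefs (ref_B0 / ref_B1) are only bound below when the
-+    // corresponding beta is non-zero; otherwise they stay default-constructed
-+    // (nullptr_ref), matching the no-bias epilogue configuration.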
tensor_D1.sync_device(); -+ tensor_D2.sync_device(); -+ reference_D0.sync_device(); -+ reference_D1.sync_device(); -+ reference_D2.sync_device(); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ int split_k_slices = DualGemm::kSplitKSerial ? 2 : 1; -+ typename cutlass::TensorRef nullptr_ref{}; -+ decltype(nullptr_ref) ref_B0, ref_B1; -+ if (beta0 != ElementCompute(0)) { -+ ref_B0 = {tensor_Bias0.device_data(), typename DualGemm::LayoutC::Stride(0)}; -+ } -+ if (beta1 != ElementCompute(0)) { -+ ref_B1 = {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)}; -+ } -+ typename DualGemm::Arguments arguments{ -+ problem_size, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ ref_B0, -+ DualGemm::kStoreD0 ? tensor_D0.device_ref() : nullptr_ref, -+ tensor_B1.device_ref(), -+ ref_B1, -+ DualGemm::kStoreD1 ? tensor_D1.device_ref() : nullptr_ref, -+ tensor_D2.device_ref(), -+ {alpha0, beta0}, -+ {alpha1, beta1}, -+ {}, -+ split_k_slices -+ }; -+ -+ DualGemm b2b_gemm_op; -+ -+ cutlass::device_memory::allocation workspace(b2b_gemm_op.get_workspace_size(arguments)); -+ -+ cutlass::Status status = b2b_gemm_op.can_implement(arguments); -+ -+ CUTLASS_CHECK(status); -+ -+ status = b2b_gemm_op.initialize(arguments, workspace.get()); -+ -+ CUTLASS_CHECK(status); -+ -+ for(int i = 0; i < warm_ups; i++) { -+ status = b2b_gemm_op(); -+ CUTLASS_CHECK(status); -+ } -+ -+#ifdef IS_PROFILING -+ return true; -+#endif -+ // -+ // Run the GEMM -+ // -+ -+ cudaEvent_t start, stop; -+ cudaEventCreate(&start); -+ cudaEventCreate(&stop); -+ -+ cudaEventRecord(start); -+ -+ for(int i = 0; i < runs; i++) { -+ status = b2b_gemm_op(); -+ -+ CUTLASS_CHECK(status); -+ } -+ -+ cudaEventRecord(stop); -+ cudaDeviceSynchronize(); -+ float gemmTime; -+ cudaEventElapsedTime(&gemmTime, start, stop); -+ std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n"; -+ -+ tensor_D0.sync_host(); -+ tensor_D1.sync_host(); -+ tensor_D2.sync_host(); -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::device::Gemm< -+ typename DualGemm::ElementA, typename DualGemm::LayoutA, -+ typename DualGemm::ElementB, typename DualGemm::LayoutB, -+ typename DualGemm::ElementC, typename DualGemm::LayoutC, -+ ElementAccumulator, ElementAccumulator> -+ reference_gemm_0; -+ -+ cutlass::reference::device::Gemm< -+ typename DualGemm::ElementA, typename DualGemm::LayoutA, -+ typename DualGemm::ElementB, typename DualGemm::LayoutB, -+ typename DualGemm::ElementC, typename DualGemm::LayoutC, ElementCompute, -+ ElementAccumulator, typename DualGemm::Operator> -+ reference_gemm_1; -+ -+ reference_gemm_0( -+ problem_size, -+ alpha0, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ beta0, -+ {tensor_Bias0.device_data(), typename DualGemm::LayoutC::Stride(0)}, -+ reference_D0.device_ref() -+ ); -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D0.device_view()); -+ } -+ -+ reference_gemm_1( -+ problem_size, -+ alpha1, -+ tensor_A0.device_ref(), -+ tensor_B1.device_ref(), -+ beta1, -+ {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)}, -+ reference_D1.device_ref() -+ ); -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D1.device_view()); -+ } -+ TensorEpilogueForEach(reference_D0.device_view(), reference_D1.device_view(), reference_D2.device_view()); -+ cudaDeviceSynchronize(); -+ reference_D0.sync_host(); -+ reference_D1.sync_host(); -+ reference_D2.sync_host(); -+ -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); -+ 
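-+
-+    // Verification: D0/D1 are compared against two independent reference
-+    // GEMMs (only when kStoreD0/kStoreD1 request storing them), and D2 is
-+    // compared against TensorEpilogueForEach applied to the reference outputs.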
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D2.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D2.host_view()), 0); -+ -+ bool passed_out0 = true; -+ if (DualGemm::kStoreD0) { -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); -+ passed_out0 = cutlass::reference::host::TensorEquals( -+ reference_D0.host_view(), -+ tensor_D0.host_view()); -+ } -+ CHECK_TRUE(passed_out0); -+ -+ bool passed_out1 = true; -+ if (DualGemm::kStoreD1) { -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); -+ passed_out1 = cutlass::reference::host::TensorEquals( -+ reference_D1.host_view(), -+ tensor_D1.host_view()); -+ } -+ CHECK_TRUE(passed_out1); -+ -+ bool passed_out2 = cutlass::reference::host::TensorEquals( -+ reference_D2.host_view(), -+ tensor_D2.host_view()); -+ CHECK_TRUE(passed_out2); -+ -+ bool passed = passed_out0 && passed_out1 && passed_out2; -+ if (!passed) -+ { -+ -+ std::stringstream fname; -+ -+ fname << "error_DualGemm_device_fused.txt"; -+ std::cerr << "Dumping results in " << fname.str() << "\n"; -+ -+ std::ofstream file(fname.str()); -+ -+ file -+ << "A0 =\n" << tensor_A0.host_view() -+ << "\nB0 =\n" << tensor_B0.host_view() -+ << "\nC0 =\n" << tensor_C0.host_view() -+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" -+ << "\nB1 =\n" << tensor_B1.host_view() -+ << "\nC1 =\n" << tensor_C1.host_view() -+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" -+ << "\n\nReference0 =\n" << reference_D0.host_view() -+ << "\nComputed0 =\n" << tensor_D0.host_view() -+ << "\n\nReference1 =\n" << reference_D1.host_view() -+ << "\nComputed1 =\n" << tensor_D1.host_view() -+ << "\n\nReference2 =\n" << reference_D2.host_view() -+ << "\nComputed2 =\n" << tensor_D2.host_view(); -+ } -+ //std::cout << "A0 " << tensor_A0.host_view() << std::endl; -+ // std::cout << "reference_D0 " << reference_D0.host_view() << std::endl; -+ // std::cout << "reference_D1 " << reference_D1.host_view() << std::endl; -+ // std::cout << "reference_D2 " << reference_D2.host_view() << std::endl; -+ //std::cout << "reference_D0 " << reference_D0.host_view() << std::endl; -+ return passed; -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/45_dual_gemm/kernel/dual_gemm.h b/3rdparty/cutlass/examples/45_dual_gemm/kernel/dual_gemm.h -new file mode 100644 -index 0000000..4cbddaa ---- /dev/null -+++ b/3rdparty/cutlass/examples/45_dual_gemm/kernel/dual_gemm.h -@@ -0,0 +1,489 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/semaphore.h" -+ -+#include "../threadblock/dual_mma_multistage.h" -+#include "../threadblock/dual_epilogue.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename DualMma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue0_, ///! Epilogue -+ typename Epilogue1_, ///! Epilogue -+ typename OutputOp2_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ bool SplitKSerial, ///! If true, code supporting split-K via serial reduction is enabled. 
-+ bool StoreD0, -+ bool StoreD1 -+> -+struct DualGemm { -+ -+ using DualMma = DualMma_; -+ -+ using Epilogue0 = Epilogue0_; -+ using Epilogue1 = Epilogue1_; -+ using OutputOp0 = typename Epilogue0::OutputOp; -+ using OutputOp1 = typename Epilogue1::OutputOp; -+ using OutputOp2 = OutputOp2_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static constexpr bool kStoreD0 = StoreD0; -+ static constexpr bool kStoreD1 = StoreD1; -+ -+ using DualEpilogue = cutlass::epilogue::threadblock::DualEpilogue< -+ typename Epilogue0::Shape, -+ typename Epilogue0::WarpMmaOperator, -+ Epilogue0::kPartitionsK, -+ typename Epilogue0::OutputTileIterator, -+ typename Epilogue0::AccumulatorFragmentIterator, -+ typename Epilogue0::WarpTileIterator, -+ typename Epilogue0::SharedLoadIterator, -+ OutputOp0, -+ OutputOp1, -+ OutputOp2, -+ typename Epilogue0::Padding, -+ kStoreD0, -+ kStoreD1, -+ Epilogue0::kFragmentsPerIteration, -+ true // IterationsUnroll -+ >; -+ -+ static bool const kSplitKSerial = SplitKSerial; -+ static_assert(!kSplitKSerial || (kStoreD0 && kStoreD1), -+ "Split-K serial requires buffers for D0/D1 for reduction"); -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount0 = typename DualMma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount0::kCount; -+ -+ /// Parameters structure -+ struct Params { -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ -+ // Mma0 -+ typename DualMma::IteratorA::Params params_A0; -+ typename DualMma::IteratorA::TensorRef ref_A0; -+ typename DualMma::IteratorB::Params params_B0; -+ typename DualMma::IteratorB::TensorRef ref_B0; -+ typename Epilogue0::OutputTileIterator::Params params_C0; -+ typename Epilogue0::OutputTileIterator::TensorRef ref_C0; -+ typename Epilogue0::OutputTileIterator::Params params_D0; -+ typename Epilogue0::OutputTileIterator::TensorRef ref_D0; -+ typename OutputOp0::Params output_op_0; -+ -+ // Mma1 -+ typename DualMma::IteratorB::Params params_B1; -+ typename DualMma::IteratorB::TensorRef ref_B1; -+ typename Epilogue1::OutputTileIterator::Params params_C1; -+ typename Epilogue1::OutputTileIterator::TensorRef ref_C1; -+ typename Epilogue1::OutputTileIterator::Params params_D1; -+ typename Epilogue1::OutputTileIterator::TensorRef ref_D1; -+ typename OutputOp1::Params output_op_1; -+ -+ typename Epilogue1::OutputTileIterator::Params params_D2; -+ typename Epilogue1::OutputTileIterator::TensorRef ref_D2; -+ typename OutputOp2::Params output_op_2; -+ -+ int *semaphore; -+ int gemm_k_size; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape, -+ // Mma0: D0 = A @ B0 + C0 -+ typename DualMma::IteratorA::TensorRef ref_A0, -+ typename DualMma::IteratorB::TensorRef ref_B0, -+ typename Epilogue0::OutputTileIterator::TensorRef ref_C0, -+ typename Epilogue0::OutputTileIterator::TensorRef ref_D0, -+ // Mma1: D1 = A @ B1 + C1 -+ typename DualMma::IteratorB::TensorRef ref_B1, -+ typename Epilogue1::OutputTileIterator::TensorRef ref_C1, -+ typename Epilogue1::OutputTileIterator::TensorRef ref_D1, -+ -+ typename Epilogue1::OutputTileIterator::TensorRef ref_D2, -+ typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(), -+ typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(), -+ typename OutputOp2::Params output_op_2 = typename 
OutputOp2::Params(), -+ int *workspace = nullptr -+ ): -+ problem_size(problem_size), -+ grid_tiled_shape(grid_tiled_shape), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ // Mma0 -+ params_A0(ref_A0.layout()), -+ ref_A0(ref_A0), -+ params_B0(ref_B0.layout()), -+ ref_B0(ref_B0), -+ params_C0(ref_C0.layout()), -+ ref_C0(ref_C0), -+ params_D0(ref_D0.layout()), -+ ref_D0(ref_D0), -+ // Mma1 -+ params_B1(ref_B1.layout()), -+ ref_B1(ref_B1), -+ params_C1(ref_C1.layout()), -+ ref_C1(ref_C1), -+ params_D1(ref_D1.layout()), -+ ref_D1(ref_D1), -+ params_D2(ref_D2.layout()), -+ ref_D2(ref_D2), -+ output_op_0(output_op_0), -+ output_op_1(output_op_1), -+ output_op_2(output_op_2) { -+ -+ int total_gemm_k_iterations = (problem_size.k() + DualMma::Shape::kK - 1) / DualMma::Shape::kK; -+ int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); -+ gemm_k_size = gemm_k_iterations * DualMma::Shape::kK; -+ -+ semaphore = workspace; -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename DualMma::SharedStorage main_loop; -+ typename DualEpilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ DualGemm() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size, -+ typename DualMma::IteratorA::TensorRef ref_A0, -+ typename DualMma::IteratorB::TensorRef ref_B0, -+ typename Epilogue0::OutputTileIterator::TensorRef ref_C0, -+ typename Epilogue0::OutputTileIterator::TensorRef ref_D0, -+ typename DualMma::IteratorB::TensorRef ref_B1, -+ typename Epilogue1::OutputTileIterator::TensorRef ref_C1, -+ typename Epilogue1::OutputTileIterator::TensorRef ref_D1, -+ typename Epilogue1::OutputTileIterator::TensorRef ref_D2) { -+ -+ static int const kAlignmentA = DualMma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = DualMma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue0::OutputTileIterator::kElementsPerAccess; -+ -+ if (!TensorRef_aligned(ref_A0, kAlignmentA)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_B0, kAlignmentB)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_C0, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_D0, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_B1, kAlignmentB)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_C1, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_D1, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_D2, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A0{ -+ 
threadblock_tile_offset.m() * DualMma::Shape::kM, -+ threadblock_tile_offset.k() * params.gemm_k_size, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B0{ -+ threadblock_tile_offset.k() * params.gemm_k_size, -+ threadblock_tile_offset.n() * DualMma::Shape::kN -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B1{ -+ threadblock_tile_offset.k() * params.gemm_k_size, -+ threadblock_tile_offset.n() * DualMma::Shape::kN -+ }; -+ -+ // Problem size is a function of threadblock index in the K dimension -+ int problem_size_k = -+ (params.problem_size.k() < (threadblock_tile_offset.k() + 1) * params.gemm_k_size) ? -+ params.problem_size.k() : -+ (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - tb_offset_A0.column() + DualMma::Shape::kK - 1) / DualMma::Shape::kK; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename DualMma::IteratorA iterator_A0( -+ params.params_A0, -+ params.ref_A0.data(), -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A0); -+ -+ typename DualMma::IteratorB iterator_B0( -+ params.params_B0, -+ params.ref_B0.data(), -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B0); -+ -+ typename DualMma::IteratorB iterator_B1( -+ params.params_B1, -+ params.ref_B1.data(), -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B1); -+ -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ -+ // Construct thread-scoped matrix multiply -+ typename DualMma::FragmentC accum0; -+ typename DualMma::FragmentC accum1; -+ accum0.clear(); -+ accum1.clear(); -+ -+ DualMma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ if (!kSplitKSerial || gemm_k_iterations > 0) { -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, -+ accum0, accum1, -+ iterator_A0, iterator_B0, iterator_B1, -+ accum0, accum1); -+ } -+ -+ // -+ // Epilogue -+ // -+ -+ OutputOp0 output_op_0(params.output_op_0); -+ OutputOp1 output_op_1(params.output_op_1); -+ OutputOp2 output_op_2(params.output_op_2); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * DualMma::Shape::kM, -+ threadblock_tile_offset.n() * DualMma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op_0.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ -+ // Tile iterator loading from source tensor. 
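-+    // With serial split-K, k-partitions are serialized by the semaphore:
-+    // partitions after the first re-point their source iterators at the
-+    // partially reduced D0/D1 (iterator_C0/iterator_C1 below), and only the
-+    // final partition evaluates OutputOp2 and writes D2 (writeToD2).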
-+ typename Epilogue0::OutputTileIterator iterator_C0( -+ params.params_C0, -+ params.ref_C0.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ typename Epilogue1::OutputTileIterator iterator_C1( -+ params.params_C1, -+ params.ref_C1.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue0::OutputTileIterator iterator_D0( -+ params.params_D0, -+ params.ref_D0.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ typename Epilogue1::OutputTileIterator iterator_D1( -+ params.params_D1, -+ params.ref_D1.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ typename Epilogue1::OutputTileIterator iterator_D2( -+ params.params_D2, -+ params.ref_D2.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ DualEpilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C0 = iterator_D0; -+ iterator_C1 = iterator_D1; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ __threadfence(); -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ typename Epilogue0::OutputTileIterator source_iters[] = { -+ iterator_C0, iterator_C1 -+ }; -+ const bool writeToD2 = (!kSplitKSerial || params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1); -+ epilogue( -+ output_op_0, output_op_1, output_op_2, -+ iterator_D0, iterator_D1, iterator_D2, -+ accum0, accum1, -+ source_iters, -+ writeToD2 -+ ); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ __threadfence(); -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/examples/45_dual_gemm/test_run.h b/3rdparty/cutlass/examples/45_dual_gemm/test_run.h -new file mode 100644 -index 0000000..b64f31f ---- /dev/null -+++ b/3rdparty/cutlass/examples/45_dual_gemm/test_run.h -@@ -0,0 +1,95 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+ -+#include -+ -+// Run tests on GPUs -+ -+int testRun(int arch, std::vector & test_funcs, const std::string & test_name) { -+ -+ bool supported = false; -+ -+ int arch_major = arch / 10; -+ int arch_minor = arch - arch / 10 * 10; -+ -+ if(arch_major >= 8) { -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) { -+ supported = true; -+ } -+ } -+ else if(arch_major >= 7) { -+ // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2. -+ // -+ // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples. -+ if (__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) { -+ supported = true; -+ } -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (!(props.major == arch_major && props.minor == arch_minor)) { -+ supported = false; -+ } -+ -+ if (!supported) { -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. -+ std::cout << "This example isn't supported on current architecture" << std::endl; -+ return 0; -+ } -+ -+ bool pass = true; -+ -+ std::cout << "Device: " << props.name << std::endl; -+ std::cout << "Arch: SM" << arch << std::endl; -+ std::cout << "Test: " << test_name << std::endl; -+ for(auto func : test_funcs) { -+ pass &= func(); -+ } -+ -+ -+ if(pass) -+ return 0; -+ else -+ return -1; -+ -+} -+ -diff --git a/3rdparty/cutlass/examples/45_dual_gemm/thread/left_silu_and_mul.h b/3rdparty/cutlass/examples/45_dual_gemm/thread/left_silu_and_mul.h -new file mode 100644 -index 0000000..0ba9bb9 ---- /dev/null -+++ b/3rdparty/cutlass/examples/45_dual_gemm/thread/left_silu_and_mul.h -@@ -0,0 +1,150 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/epilogue/thread/scale_type.h" -+#include "cutlass/epilogue/thread/linear_combination_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements. -+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation. 
///< Usually it is 128/sizeof_bits<ElementOutput_>,
-+                                        ///< but we use 64 or 32 sometimes when there are not enough data to store
-+  typename ElementAccumulator_ = ElementOutput_,  ///< Accumulator data type
-+  typename ElementCompute_ = ElementOutput_,      ///< Data type used to compute linear combination
-+  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
-+>
-+class LeftSiLUAndMul {
-+public:
-+
-+  using ElementOutput = ElementOutput_;
-+  using ElementAccumulator = ElementAccumulator_;
-+  using ElementCompute = ElementCompute_;
-+
-+  static int const kCount = Count;
-+  using FragmentOutput = Array<ElementOutput, kCount>;
-+  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
-+  using ComputeFragment = Array<ElementCompute, kCount>;
-+
-+  static FloatRoundStyle const kRound = Round;
-+
-+  struct Params{};
-+
-+private:
-+
-+  //
-+  // Data members
-+  //
-+
-+  ElementCompute alpha_;
-+  ElementCompute beta_;
-+
-+public:
-+
-+  /// Constructs the function object, possibly loading from pointers in host memory
-+  CUTLASS_HOST_DEVICE
-+  LeftSiLUAndMul(Params const &/*params*/) {}
-+
-+  /// Returns true if source is needed
-+  CUTLASS_HOST_DEVICE
-+  bool is_source_needed() const {
-+    return true;
-+  }
-+
-+  /// Functionally required for serial reduction in the epilogue
-+  CUTLASS_HOST_DEVICE
-+  void set_k_partition(int k_partition, int k_partition_count) {
-+    assert(false);
-+  }
-+
-+  /// Computes linear scaling: D = alpha * accumulator + beta * source
-+  CUTLASS_HOST_DEVICE
-+  FragmentOutput operator()(
-+    FragmentAccumulator const &lhs,
-+    FragmentAccumulator const &rhs) const {
-+
-+    // Convert source to internal compute numeric type
-+    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_to_compute;
-+
-+    // Convert to destination numeric type
-+    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> compute_to_output;
-+
-+    ComputeFragment converted_lhs = accumulator_to_compute(lhs);
-+    ComputeFragment converted_rhs = accumulator_to_compute(rhs);
-+
-+    cutlass::epilogue::thread::SiLu<ComputeFragment> silu;
-+    cutlass::multiplies<ComputeFragment> mul;
-+    auto silu_lhs = silu(converted_lhs);
-+    return compute_to_output(mul(silu_lhs, converted_rhs));
-+  }
-+
-+  CUTLASS_HOST_DEVICE
-+  ElementOutput operator()(
-+    ElementAccumulator const& lhs,
-+    ElementAccumulator const& rhs
-+  ) const {
-+    ElementCompute convert_lhs(lhs);
-+    ElementCompute convert_rhs(rhs);
-+    cutlass::epilogue::thread::SiLu<ElementCompute> silu;
-+    cutlass::multiplies<ElementCompute> mul;
-+    auto silu_lhs = silu(convert_lhs);
-+    return ElementOutput(mul(silu_lhs, convert_rhs));
-+  }
-+};
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+} // namespace thread
-+} // namespace epilogue
-+} // namespace cutlass
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-diff --git a/3rdparty/cutlass/examples/45_dual_gemm/threadblock/dual_epilogue.h b/3rdparty/cutlass/examples/45_dual_gemm/threadblock/dual_epilogue.h
-new file mode 100644
-index 0000000..d9492ab
---- /dev/null
-+++ b/3rdparty/cutlass/examples/45_dual_gemm/threadblock/dual_epilogue.h
-@@ -0,0 +1,430 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2.
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_base.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator -+template < -+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) -+ typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ int PartitionsK, ///< Number of partitions of the K dimension -+ typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors -+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators -+ typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM -+ typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM -+ ///< Output operator -+ typename OutputOp0_, -+ typename OutputOp1_, -+ typename OutputOp2_, -+ typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) -+ bool StoreD0 = true, -+ bool StoreD1 = true, -+ int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity -+ int IterationsUnroll = 
///< Used to reduce binary size when epilogue op is large -+ (!IsEpilogueFunctorHeavy::value) -+> -+class DualEpilogue { -+ -+public: -+ -+ using Base = EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_, -+ FragmentsPerPartition>; -+ -+ using Shape = Shape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ static int const kPartitionsK = PartitionsK; -+ static bool constexpr kStoreD0 = StoreD0; -+ static bool constexpr kStoreD1 = StoreD1; -+ using OutputTileIterator = OutputTileIterator_; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using WarpTileIterator = WarpTileIterator_; -+ using SharedLoadIterator = SharedLoadIterator_; -+ using OutputOp0 = OutputOp0_; -+ using OutputOp1 = OutputOp1_; -+ using OutputOp2 = OutputOp2_; -+ using Padding = Padding_; -+ -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename Base::AccumulatorTile; -+ -+ /// Accumulator element -+ using ElementAccumulator = typename WarpTileIterator::Element; -+ -+ /// Output element -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ /// Tensor reference to destination tensor -+ using TensorRef = typename OutputTileIterator::TensorRef; -+ -+ /// Tensor reference to sync tensor -+ using SyncTensorRef = typename cutlass::TensorRef; -+ -+ /// Const tensor reference to source tensor -+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; -+ -+ /// Array type used to output -+ using OutputAccessType = Array< -+ typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = Array; -+ -+ /// Number of warps -+ using WarpCount = typename Base::WarpCount; -+ -+ struct SharedStorage { -+ using Element = typename WarpTileIterator::Element; -+ -+ /// Tensor reference to shared memory allocation -+ using TensorRef = typename WarpTileIterator::TensorRef; -+ -+ /// Logical shape of the shared memory tile written to by all warps. -+ using Shape = typename Base::Shape; -+ -+ /// Shape of the shared memory allocation for the epilogue -+ using StorageShape = typename Base::SharedStorage::StorageShape; -+ -+ // -+ // Data members -+ // -+ -+ AlignedBuffer storage[2]; -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a tensor reference to the shared memory buffer -+ CUTLASS_DEVICE -+ TensorRef reference(int i) { -+ return TensorRef( -+ storage[i].data(), -+ Layout::packed({StorageShape::kRow, StorageShape::kColumn})); -+ } -+ }; -+ -+ static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? 
Base::kFragmentsPerIteration : kPartitionsK; -+ static int constexpr kSmemPointerOffset = SharedStorage::StorageShape::kCount / kSmemTiles; -+ -+public: -+ -+ static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, -+ "Mismatch between shared load iterator and output tile iterator."); -+ -+ static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); -+ -+ static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), -+ "Divisibility"); -+ -+private: -+ -+ /// Loads fragment from shared memory aligned with output tensor -+ SharedLoadIterator shared_load_iterator0_; -+ SharedLoadIterator shared_load_iterator1_; -+ -+ /// Stores a warp's fragment of accumulators to SMEM -+ WarpTileIterator warp_tile_iterator0_; -+ WarpTileIterator warp_tile_iterator1_; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ DualEpilogue( -+ SharedStorage &shared_storage, ///< Shared storage object -+ int thread_idx, ///< ID of a thread within the threadblock -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx ///< Id of thread within warp -+ ): -+ shared_load_iterator0_(shared_storage.reference(0), thread_idx), -+ shared_load_iterator1_(shared_storage.reference(1), thread_idx), -+ warp_tile_iterator0_(shared_storage.reference(0), lane_idx), -+ warp_tile_iterator1_(shared_storage.reference(1), lane_idx) -+ { -+ int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN); -+ int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN); -+ int warp_m = warp_mn % WarpCount::kM; -+ int warp_n = warp_mn / WarpCount::kM; -+ -+ MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n}; -+ -+ warp_tile_iterator0_.add_tile_offset(warp_offset); -+ warp_tile_iterator1_.add_tile_offset(warp_offset); -+ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp0 const &output_op0, -+ OutputOp1 const &output_op1, -+ OutputOp2 const &output_op2, -+ OutputTileIterator dest0, -+ OutputTileIterator dest1, -+ OutputTileIterator dest2, -+ AccumulatorTile const &accumulator0, -+ AccumulatorTile const &accumulator1, -+ OutputTileIterator source_iterator[2], -+ bool writeToD2 // true if it's the final split-k -+ ) { -+ // TODO: Implement when no source is needed -+ -+ typename OutputTileIterator::Fragment source_fragment[2]; -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 2; ++i) { -+ source_fragment[i].clear(); -+ } -+ -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ AccumulatorFragmentIterator accum_fragment_iterator[2] = {accumulator0, accumulator1}; -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+ #pragma unroll(IterationsUnroll ? 
OutputTileIterator::kIterations : 1) -+ for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { -+ -+ // -+ // Load the source -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 2; ++i) { -+ source_iterator[i].load(source_fragment[i]); -+ ++source_iterator[i]; -+ } -+ -+ // -+ // Convert and store fragment -+ // -+ -+ __syncthreads(); -+ -+ acc2smem_source_needed>::push( -+ iter, accum_fragment_iterator[0], this->warp_tile_iterator0_); -+ acc2smem_source_needed>::push( -+ iter, accum_fragment_iterator[1], this->warp_tile_iterator1_); -+ -+ __syncthreads(); -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ typename SharedLoadIterator::Fragment aligned_accum_fragment0[kPartitionsK]; -+ typename SharedLoadIterator::Fragment aligned_accum_fragment1[kPartitionsK]; -+ -+ shared_load_iterator0_.load(aligned_accum_fragment0[0]); -+ shared_load_iterator1_.load(aligned_accum_fragment1[0]); -+ -+ // If the number of k-slices is > 1 - perform a reduction amongst the k-slices -+ if (kPartitionsK > 1) { -+ -+ plus add_fragments; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( int i = 1; i < kPartitionsK; ++i) { -+ shared_load_iterator0_.add_pointer_offset(kSmemPointerOffset); -+ shared_load_iterator1_.add_pointer_offset(kSmemPointerOffset); -+ shared_load_iterator0_.load(aligned_accum_fragment0[i]); -+ shared_load_iterator1_.load(aligned_accum_fragment1[i]); -+ aligned_accum_fragment0[0] = add_fragments(aligned_accum_fragment0[0], aligned_accum_fragment0[i]); -+ aligned_accum_fragment1[0] = add_fragments(aligned_accum_fragment1[0], aligned_accum_fragment1[i]); -+ } -+ -+ shared_load_iterator0_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); -+ shared_load_iterator1_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); -+ } -+ -+ // -+ // Compute the output result -+ // -+ -+ typename OutputTileIterator::Fragment output_fragment[3]; -+ -+ apply_output_operator_(output_fragment, -+ output_op0, output_op1, output_op2, -+ aligned_accum_fragment0[0], aligned_accum_fragment1[0], -+ source_fragment); -+ -+ -+ // -+ // Store the final result -+ // -+ -+ if (kStoreD0) { -+ dest0.store(output_fragment[0]); -+ ++dest0; -+ } -+ if (kStoreD1) { -+ dest1.store(output_fragment[1]); -+ ++dest1; -+ } -+ if (writeToD2) { -+ dest2.store(output_fragment[2]); -+ ++dest2; -+ } -+ } -+ } -+ -+private: -+ -+ static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1."); -+ -+ template -+ struct acc2smem_source_needed; -+ -+ template -+ struct acc2smem_source_needed> { -+ template -+ CUTLASS_DEVICE -+ static void helper(AccumulatorFragmentIterator accum_fragment_iterator, -+ WarpTileIterator &warp_tile_iterator) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Advance; i++) { -+ ++accum_fragment_iterator; -+ } -+ -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ accum_fragment_iterator.load(accum_fragment); -+ warp_tile_iterator.store(accum_fragment); -+ } -+ -+ CUTLASS_DEVICE -+ static void push(size_t pos, -+ AccumulatorFragmentIterator const &iterator_begin, -+ WarpTileIterator &warp_tile_iterator) { -+ int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; -+ } -+ }; -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_( -+ typename OutputTileIterator::Fragment (&output_fragment)[3], -+ OutputOp0 const &output_op0, -+ OutputOp1 const &output_op1, -+ OutputOp2 const &output_op2, -+ typename SharedLoadIterator::Fragment const& 
aligned_accum_fragment0, -+ typename SharedLoadIterator::Fragment const& aligned_accum_fragment1, -+ typename OutputTileIterator::Fragment const (&source_fragment)[2]) { -+ -+ OutputAccessType* output_frag_ptr[3] = { -+ reinterpret_cast(&output_fragment[0]), -+ reinterpret_cast(&output_fragment[1]), -+ reinterpret_cast(&output_fragment[2]) -+ }; -+ -+ AccumulatorAccessType const *compute_frag_ptr[2] = { -+ reinterpret_cast(&aligned_accum_fragment0), -+ reinterpret_cast(&aligned_accum_fragment1) -+ }; -+ -+ OutputAccessType const *source_frag_ptr[2] = { -+ reinterpret_cast(&source_fragment[0]), -+ reinterpret_cast(&source_fragment[1]) -+ }; -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ -+ // Call the output operators -+ output_frag_ptr[0][i] = output_op0(compute_frag_ptr[0][i], source_frag_ptr[0][i]); -+ output_frag_ptr[1][i] = output_op1(compute_frag_ptr[1][i], source_frag_ptr[1][i]); -+ output_frag_ptr[2][i] = output_op2(output_frag_ptr[0][i], output_frag_ptr[1][i]); -+ } -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/45_dual_gemm/threadblock/dual_mma_base.h b/3rdparty/cutlass/examples/45_dual_gemm/threadblock/dual_mma_base.h -new file mode 100644 -index 0000000..10563e7 ---- /dev/null -+++ b/3rdparty/cutlass/examples/45_dual_gemm/threadblock/dual_mma_base.h -@@ -0,0 +1,218 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
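The apply_output_operator_ loop above is the heart of the dual epilogue: OutputOp0 and OutputOp1 each turn one accumulator fragment (plus an optional source fragment) into an output fragment, and OutputOp2 combines those two results into the third output. A minimal host-side scalar sketch of that per-element combination, assuming linear-combination ops for OutputOp0/OutputOp1 and a SiLU-gated product for OutputOp2 (the function name and the choice of operators are illustrative, not part of this patch):

#include <cmath>
#include <cstddef>

// Scalar model of DualEpilogue's per-element work: d0/d1 mirror OutputOp0/OutputOp1,
// d2 mirrors OutputOp2(d0, d1). Names and operator choices are assumptions.
void dual_epilogue_reference(const float* acc0, const float* acc1,
                             const float* src0, const float* src1,
                             float alpha0, float beta0, float alpha1, float beta1,
                             float* d0, float* d1, float* d2, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    d0[i] = alpha0 * acc0[i] + beta0 * src0[i];                // OutputOp0
    d1[i] = alpha1 * acc1[i] + beta1 * src1[i];                // OutputOp1
    d2[i] = (d0[i] / (1.0f + std::exp(-d0[i]))) * d1[i];       // OutputOp2 (SiLU gate, assumed)
  }
}

In the kernel the same three results go to dest0, dest1 and dest2, with the first two gated by kStoreD0/kStoreD1 and dest2 written only when writeToD2 is set on the final split-k slice.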
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/threadblock/mma_base.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class DualMmaBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. -+ using WarpGemm = typename Policy::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount = GemmShape; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations = -+ (WarpGemm::kK / Operator::Policy::MmaShape::kK); -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Tensor reference to the A operand -+ using TensorRefA = TensorRef; -+ -+ /// Tensor reference to the B operand -+ using TensorRefB = TensorRef; -+ -+ static_assert(kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ static_assert((kWarpGemmIterations % 2) == 0, -+ "Inner loop iteration must be an even number."); -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ class SharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ -+ /// Shape of the A matrix operand in shared memory -+ using ShapeA = MatrixShape; -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = -+ MatrixShape; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for A operand -+ AlignedBuffer operand_A; -+ -+ /// Buffer for B operand -+ AlignedBuffer operand_B0; -+ AlignedBuffer operand_B1; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the A matrix -+ CUTLASS_DEVICE -+ static typename Operator::LayoutA LayoutA() { -+ return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); -+ } -+ -+ /// Returns a layout object for the B matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutB LayoutB() { -+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the A operand -+ CUTLASS_HOST_DEVICE -+ TensorRefA operand_A_ref() { -+ return TensorRefA{operand_A.data(), LayoutA()}; -+ } -+ -+ /// Returns a TensorRef to the B operand -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B0_ref() { -+ return 
TensorRefB{operand_B0.data(), LayoutB()}; -+ } -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B1_ref() { -+ return TensorRefB{operand_B1.data(), LayoutB()}; -+ } -+ }; -+ -+ protected: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A operand from shared memory -+ typename Operator::IteratorA warp_tile_iterator_A_; -+ -+ /// Iterator to load a warp-scoped tile of B operand from shared memory -+ typename Operator::IteratorB warp_tile_iterator_B0_; -+ typename Operator::IteratorB warp_tile_iterator_B1_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ DualMmaBase( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), -+ warp_tile_iterator_B0_(shared_storage.operand_B0_ref(), lane_idx), -+ warp_tile_iterator_B1_(shared_storage.operand_B1_ref(), lane_idx) { -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/45_dual_gemm/threadblock/dual_mma_multistage.h b/3rdparty/cutlass/examples/45_dual_gemm/threadblock/dual_mma_multistage.h -new file mode 100644 -index 0000000..7843f2b ---- /dev/null -+++ b/3rdparty/cutlass/examples/45_dual_gemm/threadblock/dual_mma_multistage.h -@@ -0,0 +1,760 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
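Compared with the single-GEMM MmaBase, DualMmaBase::SharedStorage above carries one staged A buffer and two staged B buffers (operand_B0, operand_B1), so the A tile is loaded into shared memory once and reused for both products. A rough shared-memory budget under assumed tile sizes; the real figures depend on Shape, kStages and the policy's padding, none of which are fixed by this sketch:

#include <cstdio>

int main() {
  // Assumed configuration, for illustration only.
  constexpr int kM = 128, kN = 128, kK = 32;            // threadblock tile
  constexpr int kStages = 3;                            // pipeline depth
  constexpr int kBytes = 2;                             // sizeof(cutlass::half_t)
  constexpr int bytes_A = kM * kK * kStages * kBytes;   // operand_A
  constexpr int bytes_B = kK * kN * kStages * kBytes;   // operand_B0 or operand_B1
  std::printf("A: %d KiB, B0+B1: %d KiB, total: %d KiB (padding ignored)\n",
              bytes_A / 1024, 2 * bytes_B / 1024, (bytes_A + 2 * bytes_B) / 1024);
  return 0;
}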
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/threadblock/mma_base.h" -+#include "dual_mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, -+ /// Used for partial specialization -+ typename Enable = bool> -+class DualMmaMultistage : -+ public DualMmaBase { -+public: -+ ///< Base class -+ using Base = DualMmaBase; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static 
ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA = -+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B0_; -+ SmemIteratorB smem_iterator_B1_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ DualMmaMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B0_(shared_storage.operand_B0_ref(), thread_idx), -+ smem_iterator_B1_(shared_storage.operand_B1_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B0, IteratorB &iterator_B1, -+ int group_start_A = 0, int group_start_B = 0) { -+ iterator_A.set_iteration_index(group_start_A * -+ IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ 
CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_A.get(); -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ iterator_B0.set_iteration_index(group_start_B * -+ IteratorB::kAccessesPerVector); -+ iterator_B1.set_iteration_index(group_start_B * -+ IteratorB::kAccessesPerVector); -+ this->smem_iterator_B0_.set_iteration_index(group_start_B); -+ this->smem_iterator_B1_.set_iteration_index(group_start_B); -+ -+ // Async Copy for operand B0 -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B0_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B0.get(); -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_B0.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B0.valid()); -+ } -+ -+ ++iterator_B0; -+ } -+ ++this->smem_iterator_B0_; -+ } -+ } -+ // Async Copy for operand B1 -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B1_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B1.get(); -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_B1.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B1.valid()); -+ } -+ -+ ++iterator_B1; -+ } -+ ++this->smem_iterator_B1_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum0, -+ FragmentC &accum1, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B0, -+ IteratorB iterator_B1, -+ ///< initial value of accumulator -+ FragmentC const &src_accum0, -+ FragmentC const &src_accum1 -+ ) { -+ -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations) { -+ -+ 
iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B0.clear_mask(gemm_k_iterations == 0); -+ iterator_B1.clear_mask(gemm_k_iterations == 0); -+ -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ iterator_B0.set_iteration_index(0); -+ iterator_B1.set_iteration_index(0); -+ this->smem_iterator_B0_.set_iteration_index(0); -+ this->smem_iterator_B1_.set_iteration_index(0); -+ -+ // Async Copy for operand B0 -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B0_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B0.get(), iterator_B0.valid()); -+ -+ ++iterator_B0; -+ } -+ -+ ++this->smem_iterator_B0_; -+ } -+ // Async Copy for operand B1 -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B1_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); -+ -+ ++iterator_B1; -+ } -+ -+ ++this->smem_iterator_B1_; -+ } -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B0.add_tile_offset({1, 0}); -+ iterator_B1.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B0_.add_tile_offset({1, 0}); -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum0 = src_accum0; -+ accum1 = src_accum1; -+ -+ // -+ // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels -+ // so that all accumulator elements outside the GEMM footprint are zero. 
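The prologue above issues Base::kStages - 1 complete cp.async stages (A, B0 and B1 tiles) before any math is performed; the mainloop further below then overlaps warp-level MMAs on the oldest resident K-tile with the asynchronous copy of a newer tile into the circular shared-memory buffer. A host-side simulation of that schedule, with the stage count and tile count picked purely for illustration:

#include <cstdio>

int main() {
  constexpr int kStages = 4;     // assumed pipeline depth
  const int k_tiles = 8;         // assumed number of K-tiles (gemm_k_iterations)
  int next_load = 0;

  // Prologue: prefetch kStages - 1 tiles into consecutive SMEM stages.
  for (int s = 0; s < kStages - 1 && next_load < k_tiles; ++s, ++next_load)
    std::printf("prefetch k-tile %d -> smem stage %d\n", next_load, next_load % kStages);

  // Mainloop: compute on tile k while prefetching tile k + kStages - 1.
  for (int k = 0; k < k_tiles; ++k) {
    std::printf("compute  k-tile %d <- smem stage %d\n", k, k % kStages);
    if (next_load < k_tiles) {
      std::printf("prefetch k-tile %d -> smem stage %d\n", next_load, next_load % kStages);
      ++next_load;
    }
  }
  return 0;
}

In the kernel the same ordering is enforced with cp_async_fence/cp_async_wait plus __syncthreads, and the wrap-around of smem_write_stage_idx and smem_read_stage_idx implements the circular buffer.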
-+ // -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); -+ -+ typename IteratorA::AccessType zero_A; -+ zero_A.clear(); -+ -+ last_smem_iterator_A.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ last_smem_iterator_A.get()); -+ -+ *dst_ptr = zero_A; -+ -+ ++last_smem_iterator_A; -+ } -+ -+ typename IteratorB::AccessType zero_B; -+ zero_B.clear(); -+ -+ /// Iterator to write threadblock-scoped tile of B0 operand to shared memory -+ SmemIteratorB last_smem_iterator_B0(this->smem_iterator_B0_); -+ last_smem_iterator_B0.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ last_smem_iterator_B0.get()); -+ -+ *dst_ptr = zero_B; -+ -+ ++last_smem_iterator_B0; -+ } -+ /// Iterator to write threadblock-scoped tile of B1 operand to shared memory -+ SmemIteratorB last_smem_iterator_B1(this->smem_iterator_B1_); -+ last_smem_iterator_B1.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ last_smem_iterator_B1.get()); -+ -+ *dst_ptr = zero_B; -+ -+ ++last_smem_iterator_B1; -+ } -+ } -+ -+ // Waits until stages up to the previous (kStages-2)th stage have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B0[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B1[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B0[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B1[2]; -+ -+ Operator warp_mma; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B0_.set_kgroup_index(0); -+ this->warp_tile_iterator_B1_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); -+ this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B0_; -+ ++this->warp_tile_iterator_B1_; -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B0.clear_mask(gemm_k_iterations == 0); -+ iterator_B1.clear_mask(gemm_k_iterations == 0); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B0[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B0[0]); -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B1[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B1[0]); -+ -+ // tf32x3 kernels use staging accumulation. warp_mma uses a temporary -+ // accumulator and this temporary accumulator is added to the final -+ // accumulator once in every mainloop iteration. 
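As the comment above notes, the 3xTF32 path accumulates each warp-level product into a temporary fragment and folds it into the persistent accumulator once per mainloop iteration. A scalar model of that staging, with hypothetical names, assuming one fold per iters_per_fold products:

#include <cstddef>
#include <vector>

// tmp plays the role of tmp_accum0, accum the role of accum0, and iters_per_fold
// one mainloop iteration's worth of warp-level MMAs. Illustrative only.
float staged_dot(const std::vector<float>& a, const std::vector<float>& b,
                 std::size_t iters_per_fold) {
  float accum = 0.0f;
  float tmp = 0.0f;
  for (std::size_t i = 0; i < a.size(); ++i) {
    tmp += a[i] * b[i];               // warp_mma(tmp_accum0, ..., tmp_accum0)
    if (i % iters_per_fold == 0) {    // once per mainloop iteration
      accum += tmp;                   // accum0 = plus_accum(accum0, tmp_accum0)
      tmp = 0.0f;                     // tmp_accum0.clear()
    }
  }
  return accum + tmp;                 // final fold after the mainloop
}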
-+ plus plus_accum; -+ -+ FragmentC tmp_accum0, tmp_accum1; -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ -+ tmp_accum0.clear(); -+ tmp_accum1.clear(); -+ } -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B0_; -+ ++this->warp_tile_iterator_B1_; -+ -+ if (warp_mma_k > 0) { -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B0[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B0[warp_mma_k % 2]); -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B1[warp_mma_k % 2]); -+ } -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ -+ warp_mma( -+ tmp_accum0, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B0[warp_mma_k % 2], -+ tmp_accum0 -+ ); -+ warp_mma( -+ tmp_accum1, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ tmp_accum1 -+ ); -+ -+ if (warp_mma_k == 0) { -+ accum0 = plus_accum(accum0, tmp_accum0); -+ accum1 = plus_accum(accum1, tmp_accum1); -+ tmp_accum0.clear(); -+ tmp_accum1.clear(); -+ } -+ } else { -+ warp_mma( -+ accum0, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B0[warp_mma_k % 2], -+ accum0 -+ ); -+ warp_mma( -+ accum1, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ accum1 -+ ); -+ } -+ -+ // Issue global->shared copies for the this stage -+ if (warp_mma_k < Base::kWarpGemmIterations - 1) { -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance(iterator_A, iterator_B0, iterator_B1, group_start_iteration_A, -+ group_start_iteration_B); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ int group_start_iteration_A, group_start_iteration_B; -+ group_start_iteration_A = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance(iterator_A, iterator_B0, iterator_B1, group_start_iteration_A, -+ group_start_iteration_B); -+ -+ // Inserts a memory fence between stages of cp.async instructions. 
-+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until stages up to the previous (kStages-2)th stage have committed. -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B0.add_tile_offset({1, 0}); -+ iterator_B1.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B0_.add_tile_offset({1, 0}); -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); -+ this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B0.clear_mask(gemm_k_iterations == 0); -+ iterator_B1.clear_mask(gemm_k_iterations == 0); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) { -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B0[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ } -+ } -+ -+ } -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ accum0 = plus_accum(accum0, tmp_accum0); -+ accum1 = plus_accum(accum1, tmp_accum1); -+ } -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ // commit and drain all pending and predicated cp.async pnz from the GEMM mainloop -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu b/3rdparty/cutlass/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu -new file mode 100644 -index 0000000..9a26e89 ---- /dev/null -+++ b/3rdparty/cutlass/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu -@@ -0,0 +1,672 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+This example shows how to run depthwise 2d convolution kernels using functions and data structures -+provided by CUTLASS using SIMT instruction; -+ -+There are 3 types of implementations of depthwise 2d convoltion -+ 1. kAnalytic -+ Implicit gemm 2d convoltion algorithm. -+ 2. kOptimized -+ An optimized algorithm and supports arbitrary stride and dilation. -+ 3. kFixedStrideDilation -+ An optimized algorithm with fixed stride and dilation to reduce the runtime computation and do -+more optimizations. -+ -+In general, the perf of kFixedStrideDilation would be better than kOptimized. However, if the filter -+size, stride or dilation is large, it would encounter register spilling and may hurt the perf. If -+in this case, please use kOptimized. -+ -+For kOptimized and kFixedStrideDilation, in order to fully utilize GPU hardware resources and achieve -+better perf, when the output tensor size is large, splitk should be enabled to achieve better perf. -+ -+In this example, it demonstrates how to construct and run a FixedStrideDilation depthwise 2d -+convolution kernel. 
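Because every output channel of a depthwise convolution reads only its own input channel (groups equals C and the per-group filter depth is 1), a host reference is just an independent 2-D stencil per channel. A minimal sketch assuming NHWC activations, a [C][R][S] filter layout, stride 1, unit dilation and "same" padding; these are assumptions of the sketch, not requirements of the kernel below:

#include <cstddef>
#include <vector>

void depthwise_conv2d_nhwc(const std::vector<float>& x, int N, int H, int W, int C,
                           const std::vector<float>& w, int R, int S,
                           std::vector<float>& y) {
  const int pad_h = R / 2, pad_w = S / 2;            // "same" padding
  y.assign(static_cast<std::size_t>(N) * H * W * C, 0.0f);
  for (int n = 0; n < N; ++n)
    for (int p = 0; p < H; ++p)
      for (int q = 0; q < W; ++q)
        for (int c = 0; c < C; ++c) {                // one private RxS filter per channel
          float acc = 0.0f;
          for (int r = 0; r < R; ++r)
            for (int s = 0; s < S; ++s) {
              const int h = p - pad_h + r;
              const int v = q - pad_w + s;
              if (h >= 0 && h < H && v >= 0 && v < W)
                acc += x[((static_cast<std::size_t>(n) * H + h) * W + v) * C + c] *
                       w[(static_cast<std::size_t>(c) * R + r) * S + s];
            }
          y[((static_cast<std::size_t>(n) * H + p) * W + q) * C + c] = acc;
        }
}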
-+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_depthwise_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+#include "cutlass/conv/device/direct_convolution.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using ElementAccumulator = cutlass::half_t; // Data type of accumulator -+using ElementComputeEpilogue = cutlass::half_t; // Data type of epilogue computation (alpha, beta) -+using ElementInputA = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputB = cutlass::half_t; // Data type of elements in input tensor -+using ElementOutput = cutlass::half_t; // Data type of elements in output tensor -+ -+using LayoutInputA = cutlass::layout::TensorNHWC; -+using LayoutInputB = cutlass::layout::TensorNHWC; -+using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassSimt; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm60; -+ -+// This code section describes the groups a thread block will compute -+constexpr int groups_per_cta = 64; -+ -+// This code section describes the output tile a thread block will compute -+using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; -+ -+// This code section describes the filter shape -+using FilterShape = cutlass::MatrixShape<3, 3>; -+ -+// Threadblock tile shape -+using ThreadblockShape = -+ cutlass::gemm::GemmShape; -+ -+// This code section describes tile size a warp will computes -+// WarpShape::kM = P * Q the warps would process -+// WarpShape::kN = groups_per_cta that the warps would process -+// WarpShape::kK = filter_size that the warps would process -+using WarpShape = cutlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>; -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = -+ cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< -+ 1, -+ ThreadBlockOutputShape::kN, -+ ThreadBlockOutputShape::kH, -+ ThreadBlockOutputShape::kW>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 4; -+ -+// This code section describe iterator algorithm selected is kFixedStrideDilation -+static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = -+ cutlass::conv::IteratorAlgorithm::kFixedStrideDilation; -+using StrideShape = cutlass::MatrixShape<1, 1>; -+using DilationShape = cutlass::MatrixShape<1, 1>; -+ -+constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ 
ElementOutput, // Data type of output matrix. -+ kEpilogueElementsPerAccess, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue, // Data type for alpha/beta in linear combination -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>; // Epilogue scaling operation. -+ -+using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm, -+ cutlass::conv::StrideSupport::kFixed, -+ StrideShape, -+ DilationShape>::Kernel; -+ -+using Direct2dConv = cutlass::conv::device::DirectConvolution; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ bool help; -+ cutlass::Tensor4DCoord input_size; -+ cutlass::Tensor4DCoord filter_size; -+ cutlass::Tensor4DCoord padding; -+ cutlass::MatrixCoord conv_stride; -+ cutlass::MatrixCoord dilation; -+ int groups; -+ int splitk; -+ bool reference_check; -+ bool measure_performance; -+ int iterations; -+ bool save_workspace; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ std::string tag; -+ -+ Options() -+ : help(false), -+ input_size(1, 128, 128, 32), -+ filter_size(32, 3, 3, 1), -+ groups(32), -+ padding(1, 1, 1, 1), -+ conv_stride(1, 1), -+ dilation(1, 1), -+ reference_check(false), -+ measure_performance(true), -+ iterations(20), -+ save_workspace(false), -+ alpha(1), -+ beta(0), -+ splitk(1) {} -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. 
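The 8-element alignment requirement follows directly from the 128-bit vector width: 128 bits divided by the 16-bit cutlass::half_t gives 8 elements per access, which is also how kEpilogueElementsPerAccess is derived earlier in this file. As a quick check:

constexpr int kBitsPerAccess     = 128;                            // one vectorized load/store
constexpr int kBitsPerElement    = 16;                             // cutlass::half_t
constexpr int kElementsPerAccess = kBitsPerAccess / kBitsPerElement;
static_assert(kElementsPerAccess == 8, "C, K and strides must be multiples of 8 elements");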
-+ // -+ int const kAlignment = 8; -+ -+ if ((input_size.c() % kAlignment) || (filter_size.n() % kAlignment)) { -+ // misaligned tensors -+ return false; -+ } -+ -+ // depthwise conv -+ if (groups != input_size.c()) { -+ return false; -+ } -+ -+ if (filter_size.n() != groups) { -+ return false; -+ } -+ -+ // Invalid padding -+ if ((padding.h() != filter_size.h() / 2) || (padding.w() != filter_size.w() / 2)) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Updates input and filter sizes -+ void update(cutlass::Tensor4DCoord input_size, cutlass::Tensor4DCoord filter_size) { -+ this->input_size = input_size; -+ this->filter_size = filter_size; -+ -+ padding.n() = filter_size.h() / 2; -+ padding.h() = filter_size.h() / 2; -+ padding.w() = filter_size.w() / 2; -+ padding.c() = filter_size.w() / 2; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("ref-check")) { -+ reference_check = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("perf-check")) { -+ measure_performance = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("save-workspace")) { -+ save_workspace = true; -+ } -+ -+ cmd.get_cmd_line_argument("n", input_size.n()); -+ cmd.get_cmd_line_argument("h", input_size.h()); -+ cmd.get_cmd_line_argument("w", input_size.w()); -+ cmd.get_cmd_line_argument("c", input_size.c()); -+ -+ cmd.get_cmd_line_argument("k", filter_size.n()); -+ cmd.get_cmd_line_argument("r", filter_size.h()); -+ cmd.get_cmd_line_argument("s", filter_size.w()); -+ -+ cmd.get_cmd_line_argument("g", groups); -+ -+ filter_size.c() = 1; -+ filter_size.n() = input_size.c(); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ cmd.get_cmd_line_argument("splitk", splitk); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ -+ int32_t padding_h = filter_size.h() / 2; -+ int32_t padding_w = filter_size.w() / 2; -+ padding = {padding_h, padding_h, padding_w, padding_w}; -+ } -+ -+ /// Prints the usage statement. 
-+ std::ostream &print_usage(std::ostream &out) const { -+ out << "41_depthwise_gemm_fprop example\n\n" -+ << " This example uses Ampere's Tensor Core operators on F16 data types to compute\n" -+ << " forward convolution on tensors of layout NHWC.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --n= Input tensor extent N\n" -+ << " --h= Input tensor extent H\n" -+ << " --w= Input tensor extent W\n" -+ << " --c= Input tensor extent C\n" -+ << " --k= Filter extent K\n" -+ << " --r= Filter extent R\n" -+ << " --s= Filter extent S\n\n" -+ << " --g= Groups\n\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --splitk= Enable splitK\n\n" -+ << " --ref-check If set (true), reference check on the host is computed\n" -+ << " --perf-check If set (true), performance is measured.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --save-workspace If set, workspace is written to a text file.\n" -+ << " --tag= String to replicate across the first column in the results " -+ "table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/45_depthwise_simt_conv2dfprop/45_depthwise_simt_conv2dfprop --n=32 " -+ "--h=224 --w=224 --c=128 --k=128 --g=128 --r=3 --s=3\n\n" -+ << "$ ./examples/45_depthwise_simt_conv2dfprop/45_depthwise_simt_conv2dfprop --n=1 " -+ "--h=224 --w=224 --c=32 --k=32 --g=32 --r=3 --s=3 --splitk=10 --ref-check\n\n"; -+ -+ return out; -+ } -+ -+ /// Computes the output tensor size (NPQK) -+ cutlass::Tensor4DCoord output_size() const { -+ return cutlass::Tensor4DCoord( -+ input_size.n(), -+ (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, -+ (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, -+ filter_size.n()); -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ // Number of multiply-adds = NPQK * CRS -+ int64_t fmas = -+ output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cutlass::Status reference_check; -+ cudaError_t error; -+ -+ Result() -+ : runtime_ms(0), -+ gflops(0), -+ status(cutlass::Status::kSuccess), -+ reference_check(cutlass::Status::kInvalid), -+ error(cudaSuccess) {} -+ -+ static std::ostream &print_header(std::ostream &out, Options const &options) { -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "Layer,N,H,W,C,K,R,S,G,stride_h,stride_w,dilation_h,dilation_w,splitK,Runtime,GFLOPs"; -+ -+ return out; -+ } -+ -+ std::ostream &print(std::ostream &out, int idx, Options const &options) { -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ cutlass::Tensor4DCoord output_size = options.output_size(); -+ out << "conv_" << idx << "," << options.input_size.n() << "," << options.input_size.h() << "," -+ << options.input_size.w() << "," << options.input_size.c() << "," -+ -+ << options.filter_size.n() << "," << options.filter_size.h() << "," -+ << options.filter_size.w() << "," -+ -+ << options.groups << "," << options.conv_stride.row() << "," << options.conv_stride.column() -+ << "," -+ -+ << options.dilation.row() << "," << options.dilation.column() << "," -+ -+ << options.splitk 
<< "," -+ -+ << runtime_ms << "," << gflops; -+ -+ return out; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one testcase -+Result profile_convolution(Options const &options) { -+ Result result; -+ -+ // -+ // Allocate host-device tensors using the CUTLASS Utilities. -+ // -+ -+ cutlass::HostTensor tensor_a(options.input_size); -+ cutlass::HostTensor tensor_b(options.filter_size); -+ cutlass::HostTensor tensor_b_transpose(options.filter_size); -+ cutlass::HostTensor tensor_c(options.output_size()); -+ cutlass::HostTensor tensor_d(options.output_size()); -+ cutlass::HostTensor tensor_ref_d(options.output_size()); -+ -+ // -+ // Initialize tensors -+ // -+ -+ // Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), 1, ElementInputA(5), ElementInputA(-6), 0); -+ -+ // Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), 1, ElementInputB(3), ElementInputB(-6), 0); -+ -+ // Fill tensor C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), 1, ElementOutput(5), ElementOutput(-6), 0); -+ -+ // Fill tensor D on host with zeros -+ cutlass::reference::host::TensorFill(tensor_d.host_view()); -+ -+ // Fill tensor D for reference on host with zeros -+ cutlass::reference::host::TensorFill(tensor_ref_d.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_b_transpose.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // -+ // Define arguments for CUTLASS Convolution -+ // -+ -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; -+ -+ // Split P*Q into multiple CTA -+ int split_k_slices = options.splitk; -+ -+ // Construct Conv2dProblemSize with user defined output size -+ cutlass::conv::Conv2dProblemSize problem_size(options.input_size, -+ options.filter_size, -+ options.padding, -+ options.conv_stride, -+ options.dilation, -+ options.output_size(), -+ mode, -+ split_k_slices, -+ options.groups); -+ -+ // Construct Direc2dConv::Argument structure with conv2d -+ // problem size, data pointers, and epilogue values -+ typename Direct2dConv::Arguments arguments{problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_d.device_ref(), -+ {options.alpha, options.beta}, -+ tensor_b_transpose.device_ref()}; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ Direct2dConv implicit_gemm_op; -+ -+ size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ result.status = implicit_gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = implicit_gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Launch initialized CUTLASS kernel -+ // -+ result.status = implicit_gemm_op(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Optional reference check -+ // -+ -+ if (options.reference_check) { -+ std::cout << "Verification on host...\n"; -+ -+ // Compute with reference implementation -+ cutlass::reference::host::Conv2dFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ 
ElementComputeEpilogue, -+ ElementAccumulator, -+ cutlass::NumericConverter >(problem_size, -+ tensor_a.host_ref(), -+ tensor_b.host_ref(), -+ tensor_c.host_ref(), -+ tensor_ref_d.host_ref(), -+ options.alpha, -+ options.beta); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ tensor_d.sync_host(); -+ -+ bool passed = -+ cutlass::reference::host::TensorEquals(tensor_d.host_view(), tensor_ref_d.host_view()); -+ -+ if (!passed) { -+ result.reference_check = cutlass::Status::kErrorInternal; -+ std::cout << "ERROR - results miscompared.\n"; -+ } else { -+ result.reference_check = cutlass::Status::kSuccess; -+ std::cout << "Passed.\n"; -+ } -+ } else { -+ result.reference_check = cutlass::Status::kInvalid; -+ } -+ -+ if (options.save_workspace) { -+ std::stringstream ss; -+ -+ ss << "45_depthwise_simt_conv2dfprop" << options.input_size.n() << "x" << options.input_size.h() -+ << "x" << options.input_size.w() << "x" << options.input_size.c() << "_" -+ << options.filter_size.n() << "x" << options.filter_size.h() << "x" -+ << options.filter_size.w() << "x" << options.filter_size.c() << ".dat"; -+ -+ std::ofstream output_workspace(ss.str()); -+ -+ output_workspace << "Input = \n" -+ << tensor_a.host_view() << "\n\n" -+ << "Filters = \n" -+ << tensor_b.host_view() << "\n\n"; -+ -+ if (options.reference_check) { -+ output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n"; -+ } -+ -+ output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl; -+ -+ std::cout << "Results written to '" << ss.str() << "'." << std::endl; -+ } -+ -+ // -+ // Performance measurement -+ // -+ -+ if (options.measure_performance) { -+ cudaEvent_t events[2]; -+ -+ for (auto &event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = implicit_gemm_op(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) -+ << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. 
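// --- Editor's sketch (not part of the upstream diff). The performance-measurement
// --- block above times the convolution with CUDA events (cudaEventCreate /
// --- cudaEventRecord / cudaEventSynchronize / cudaEventElapsedTime) and divides
// --- the elapsed time by the iteration count. Below is a minimal, self-contained
// --- version of that same pattern applied to a stand-in kernel; `dummy_kernel`
// --- and `kIterations` are illustrative names, not taken from the example.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void dummy_kernel() {}

int main() {
  constexpr int kIterations = 100;

  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  // Record an event, launch the work kIterations times, record a second event.
  cudaEventRecord(start);
  for (int i = 0; i < kIterations; ++i) {
    dummy_kernel<<<1, 32>>>();
  }
  cudaEventRecord(stop);

  // Waiting on the second event also waits on every launch recorded before it.
  cudaEventSynchronize(stop);

  // The elapsed time covers the whole loop; dividing by the iteration count
  // gives the average per-launch runtime, exactly as the example does.
  float total_ms = 0.f;
  cudaEventElapsedTime(&total_ms, start, stop);
  std::printf("avg runtime: %f ms\n", double(total_ms) / double(kIterations));

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return 0;
}
// --- End of editor's sketch.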
-+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ bool notSupported = false; -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major >= 6)) { -+ std::cerr << "Run on a machine with compute capability at least 60." << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ Result result = profile_convolution(options); -+ -+ Result::print_header(std::cout, options) << std::endl; -+ result.print(std::cout, 1, options) << std::endl; -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu b/3rdparty/cutlass/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu -new file mode 100644 -index 0000000..12739a0 ---- /dev/null -+++ b/3rdparty/cutlass/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu -@@ -0,0 +1,592 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+/*************************************************************************************************** -+ Example contrasting the Stream-K parallel decomposition for GEMM threadblocks versus the -+ "classic data-parallel" and "Split-K" decompositions. -+ -+ For more details regarding the Stream-K method, see "Stream-K: Work-centric Parallel Decomposition -+ for Dense Matrix-Matrix Multiplication on the GPU" (https://arxiv.org/abs/2301.03598) -+ -+ Requires NVIDIA Ampere or newer device (SM80+). -+ -+ - To lock persistence mode, power (400W), clocks (1005MHz) for evaluation (assumes device 0 and A100) -+ -+ cutlass$ sudo nvidia-smi -pm 1 -i 0 -+ -+ cutlass$ sudo nvidia-smi -i 0 -pl 400 -+ -+ cutlass$ sudo nvidia-smi -i 0 -lgc 1005 -+ -+ - Build and run: -+ -+ cutlass$ mkdir build -+ -+ cutlass$ cd build -+ -+ cutlass/build$ cmake .. -DCUTLASS_NVCC_ARCHS=80 -+ -+ cutlass/build$ make 47_ampere_gemm_universal_streamk -+ -+ cutlass/build$ ./examples/47_ampere_gemm_universal_streamk/47_ampere_gemm_universal_streamk -+ -+ 10000 timing iterations of 2048 x 2048 x 2048 matrix-matrix multiply -+ -+ Basic data-parallel GEMM -+ Disposition: Passed -+ Avg runtime: 0.112633 ms -+ GFLOPs: 152530 -+ -+ StreamK GEMM with default load-balancing -+ Disposition: Passed -+ Avg runtime: 0.0941929 ms -+ GFLOPs: 182390 -+ Speedup vs Basic-DP: 1.196 -+ -+ StreamK emulating basic data-parallel GEMM -+ Disposition: Passed -+ Avg runtime: 0.113119 ms -+ GFLOPs: 151875 -+ Speedup vs Basic-DP: 0.996 -+ -+ Basic split-K GEMM with tile-splitting factor 2 -+ Disposition: Passed -+ Avg runtime: 0.104772 ms -+ GFLOPs: 163973 -+ -+ StreamK emulating Split-K GEMM with tile-splitting factor 2 -+ Disposition: Passed -+ Avg runtime: 0.105379 ms -+ GFLOPs: 163029 -+ Speedup vs Basic-SplitK: 0.994 -+ -+ **************************************************************************************************/ -+ -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_universal.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// GEMM kernel configurations (cutlass_tensorop_h16816gemm_128x128_32x4_nn_align8) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// A matrix configuration -+using ElementA = cutlass::half_t; // Element type for A matrix operand -+using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand -+constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) -+ -+// B matrix configuration -+using ElementB = cutlass::half_t; // Element type for B matrix operand -+using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand -+constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) -+ -+// C/D matrix configuration -+using ElementC = cutlass::half_t; // Element type for C and D 
matrix operands -+using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands -+constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C/D matrices in units of elements (up to 16 bytes) -+ -+// Multiply-accumulate blocking/pipelining details -+using ElementAccumulator = cutlass::half_t; // Element type for internal accumulation -+using ArchTag = cutlass::arch::Sm80; // Tag indicating the minimum SM that supports the intended feature -+using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock-level tile size (concept: GemmShape) -+using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp-level tile size (concept: GemmShape) -+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // Instruction-level tile size (concept: GemmShape) -+constexpr int NumStages = 4; // Number of global->shared pipeline stages used in the GEMM mainloop -+ -+// Epilogue output operator -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementC, // Element type for C and D matrix operands -+ AlignmentC, // Memory access granularity of C and D matrix in units of elements -+ ElementAccumulator, // Element type from internal accumaccumulation -+ ElementAccumulator>; // Data type used to compute linear combination -+ -+// Reference device GEMM implementation type -+using DeviceGemmReference = cutlass::reference::device::Gemm< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ ElementAccumulator>; -+ -+// Classic data-parallel device GEMM implementation type -+using DeviceGemmBasic = cutlass::gemm::device::GemmUniversal< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ NumStages, -+ AlignmentA, -+ AlignmentB>; -+ -+// StreamK device GEMM implementation type -+using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversal< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, // <-- Only difference -+ NumStages, -+ AlignmentA, -+ AlignmentB>; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Testbed utility types -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result -+{ -+ double avg_runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ Result( -+ double avg_runtime_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess) -+ : -+ avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(true) -+ {} -+ -+}; -+ -+ -+/// Command line options parsing -+struct Options -+{ -+ std::string command_name; -+ bool help; -+ cutlass::gemm::GemmCoord problem_size; -+ float alpha; -+ float beta; -+ int split_k_factor; -+ int avail_sms; -+ bool reference_check; -+ int iterations; -+ -+ cutlass::HostTensor tensor_a; -+ cutlass::HostTensor tensor_b; -+ cutlass::HostTensor tensor_c; -+ cutlass::HostTensor 
tensor_d; -+ cutlass::HostTensor tensor_ref_d; -+ -+ Options(std::string command_name) : -+ command_name(command_name), -+ help(false), -+ problem_size({2048, 2048, 2048}), -+ alpha(1.0f), -+ beta(0.0f), -+ split_k_factor(1), -+ avail_sms(-1), // Number of device SMs to use is unlimited -+ reference_check(true), -+ iterations(10000) -+ {} -+ -+ bool valid() const -+ { -+ return true; -+ } -+ -+ void parse(int argc, char const **args) -+ { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("m", problem_size.m()); -+ cmd.get_cmd_line_argument("n", problem_size.n()); -+ cmd.get_cmd_line_argument("k", problem_size.k()); -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ cmd.get_cmd_line_argument("split", split_k_factor); -+ cmd.get_cmd_line_argument("iterations", iterations); -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const -+ { -+ out -+ << "Performs a GEMM computation.\n" -+ << "\n" -+ << "Options:\n" -+ << "\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --m= GEMM M dimension\n" -+ << " --n= GEMM N dimension\n" -+ << " --k= GEMM K dimension\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --split= Split-K factor to emulate\n\n" -+ << " --iterations= Number of profiling iterations to perform.\n\n"; -+ -+ out -+ << "\n\nExamples:\n\n" -+ << "$ " << command_name << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const -+ { -+ // Two flops per multiply-add -+ return 2.0 * double(problem_size.product()) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// GEMM evaluation -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Populates a DeviceGemmBasic::Arguments structure from the given commandline options -+typename DeviceGemmBasic::Arguments args_from_options( -+ const DeviceGemmBasic &device_gemm, -+ const Options &options, -+ cutlass::HostTensor &tensor_a, -+ cutlass::HostTensor &tensor_b, -+ cutlass::HostTensor &tensor_c, -+ cutlass::HostTensor &tensor_d) -+{ -+ return typename DeviceGemmBasic::Arguments( -+ cutlass::gemm::GemmUniversalMode::kGemm, // universal mode -+ options.problem_size, // problem_size -+ options.split_k_factor, // batch count / splitk slices -+ { // epilogue parameters -+ ElementAccumulator(options.alpha), -+ ElementAccumulator(options.beta) -+ }, -+ tensor_a.device_data(), // ptr_A -+ tensor_b.device_data(), // ptr_B -+ tensor_c.device_data(), // ptr_C -+ tensor_d.device_data(), // ptr_D -+ options.problem_size.mk().product(), // batch_stride_A -+ options.problem_size.nk().product(), // batch_stride_B -+ options.problem_size.mn().product(), // batch_stride_C -+ options.problem_size.mn().product(), // batch_stride_D -+ tensor_a.layout().stride(0), // stride_a -+ tensor_b.layout().stride(0), // stride_b -+ tensor_c.layout().stride(0), // stride_c -+ tensor_d.layout().stride(0)); // stride_d -+} -+ -+/// Populates a DeviceGemmStreamK::Arguments structure from the given commandline options -+typename DeviceGemmStreamK::Arguments args_from_options( -+ const DeviceGemmStreamK &device_gemm, -+ const Options &options, -+ cutlass::HostTensor &tensor_a, -+ 
cutlass::HostTensor &tensor_b, -+ cutlass::HostTensor &tensor_c, -+ cutlass::HostTensor &tensor_d) -+{ -+ return typename DeviceGemmStreamK::Arguments( -+ cutlass::gemm::GemmUniversalMode::kGemm, // universal mode -+ options.problem_size, // problem_size -+ options.split_k_factor, // batch count / splitk slices -+ { // epilogue parameters -+ ElementAccumulator(options.alpha), -+ ElementAccumulator(options.beta) -+ }, -+ tensor_a.device_data(), // ptr_A -+ tensor_b.device_data(), // ptr_B -+ tensor_c.device_data(), // ptr_C -+ tensor_d.device_data(), // ptr_D -+ options.problem_size.mk().product(), // batch_stride_A -+ options.problem_size.nk().product(), // batch_stride_B -+ options.problem_size.mn().product(), // batch_stride_C -+ options.problem_size.mn().product(), // batch_stride_D -+ tensor_a.layout().stride(0), // stride_a -+ tensor_b.layout().stride(0), // stride_b -+ tensor_c.layout().stride(0), // stride_c -+ tensor_d.layout().stride(0), // stride_d -+ options.avail_sms); // avail_sms -+} -+ -+ -+/// Execute a given example GEMM computation -+template -+Result run(std::string description, Options &options) -+{ -+ // Display test description -+ std::cout << std::endl << description << std::endl; -+ -+ // Zero-initialize test output matrix D -+ cutlass::reference::host::TensorFill(options.tensor_d.host_view()); -+ options.tensor_d.sync_device(); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ DeviceGemmT device_gemm; -+ -+ // Create a structure of gemm kernel arguments suitable for invoking an instance of DeviceGemmT -+ auto arguments = args_from_options(device_gemm, options, options.tensor_a, options.tensor_b, options.tensor_c, options.tensor_d); -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = DeviceGemmT::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Check the problem size is supported or not -+ CUTLASS_CHECK(device_gemm.can_implement(arguments)); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ CUTLASS_CHECK(device_gemm.initialize(arguments, workspace.get())); -+ -+ // Correctness / Warmup iteration -+ CUTLASS_CHECK(device_gemm()); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ options.tensor_d.sync_host(); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ Result result; -+ result.passed = cutlass::reference::host::TensorEquals( -+ options.tensor_d.host_view(), -+ options.tensor_ref_d.host_view()); -+ -+ std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; -+ -+ // Run profiling loop -+ if (options.iterations > 0) -+ { -+ GpuTimer timer; -+ timer.start(); -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ CUTLASS_CHECK(device_gemm()); -+ } -+ timer.stop(); -+ -+ // Compute average runtime and GFLOPs. -+ float elapsed_ms = timer.elapsed_millis(); -+ result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); -+ -+ std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; -+ std::cout << " GFLOPs: " << result.gflops << std::endl; -+ } -+ -+ if (!result.passed) { -+ exit(-1); -+ } -+ -+ return result; -+} -+ -+ -+/// Program entrypoint -+int main(int argc, const char **argv) -+{ -+ // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. 
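// --- Editor's sketch (not part of the upstream diff). Options::gflops() used by
// --- the run() routine above computes 2 * M * N * K flops (two flops per
// --- multiply-accumulate) divided by 1e9 and the runtime in seconds. The
// --- standalone check below plugs in the numbers quoted in this example's own
// --- header comment (2048 x 2048 x 2048 GEMM, 0.112633 ms average runtime) and
// --- reproduces the quoted ~152530 GFLOP/s figure.
#include <cstdio>
#include <cstdint>

// A GEMM of size M x N x K performs M*N*K multiply-accumulates, i.e. 2*M*N*K flops.
static double gemm_gflops(int64_t m, int64_t n, int64_t k, double runtime_s) {
  double flops = 2.0 * double(m) * double(n) * double(k);
  return flops / 1.0e9 / runtime_s;
}

int main() {
  double runtime_s = 0.112633e-3;  // 0.112633 ms, from the sample output above
  std::printf("GFLOP/s: %.0f\n", gemm_gflops(2048, 2048, 2048, runtime_s));
  // Prints roughly 152530, matching the "Basic data-parallel GEMM" figure quoted
  // in the example's header comment.
  return 0;
}
// --- End of editor's sketch.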
-+ if (!(__CUDACC_VER_MAJOR__ >= 11)) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. -+ return 0; -+ } -+ -+ // Current device must must have compute capability at least 80 -+ cudaDeviceProp props; -+ int current_device_id; -+ CUDA_CHECK(cudaGetDevice(¤t_device_id)); -+ CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); -+ if (!((props.major * 10 + props.minor) >= 80)) -+ { -+ std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." -+ << std::endl; -+ -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. -+ return 0; -+ } -+ -+ // Parse commandline options -+ Options options("ampere_streamk_gemm"); -+ options.parse(argc, argv); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ std::cout << -+ options.iterations << " timing iterations of " << -+ options.problem_size.m() << " x " << -+ options.problem_size.n() << " x " << -+ options.problem_size.k() << " matrix-matrix multiply" << std::endl; -+ -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ -+ // -+ // Initialize GEMM datasets -+ // -+ -+ // Initialize tensors using CUTLASS helper functions -+ options.tensor_a.resize(options.problem_size.mk()); // <- Create matrix A with dimensions M x K -+ options.tensor_b.resize(options.problem_size.kn()); // <- Create matrix B with dimensions K x N -+ options.tensor_c.resize(options.problem_size.mn()); // <- Create matrix C with dimensions M x N -+ options.tensor_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from CUTLASS kernel -+ options.tensor_ref_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from reference kernel -+ -+ // Fill matrix A on host with uniform-random data [4, -4] -+ cutlass::reference::host::TensorFillRandomUniform( -+ options.tensor_a.host_view(), -+ 1, -+ ElementA(2), -+ ElementA(-2), -+ 0); -+ -+ // Fill matrix B on host with uniform-random data [4, -4] -+ cutlass::reference::host::TensorFillRandomUniform( -+ options.tensor_b.host_view(), -+ 1, -+ ElementB(2), -+ ElementB(-2), -+ 0); -+ -+ // Fill matrix C on host with uniform-random data [4, -4] -+ cutlass::reference::host::TensorFillRandomUniform( -+ options.tensor_c.host_view(), -+ 1, -+ ElementC(2), -+ ElementC(-2), -+ 0); -+ -+ -+ // -+ // Compute reference output -+ // -+ -+ // Copy data from host to GPU -+ options.tensor_a.sync_device(); -+ options.tensor_b.sync_device(); -+ options.tensor_c.sync_device(); -+ -+ // Zero-initialize reference output matrix D -+ cutlass::reference::host::TensorFill(options.tensor_ref_d.host_view()); -+ options.tensor_ref_d.sync_device(); -+ -+ // Create instantiation for device reference gemm kernel -+ DeviceGemmReference gemm_reference; -+ -+ // Launch device reference gemm kernel -+ gemm_reference( -+ options.problem_size, -+ ElementAccumulator(options.alpha), -+ options.tensor_a.device_ref(), -+ options.tensor_b.device_ref(), -+ ElementAccumulator(options.beta), -+ options.tensor_c.device_ref(), -+ options.tensor_ref_d.device_ref()); -+ -+ // Wait for kernels to finish -+ CUDA_CHECK(cudaDeviceSynchronize()); -+ -+ // Copy output data from reference kernel to host for comparison -+ options.tensor_ref_d.sync_host(); -+ -+ -+ // -+ // Evaluate 
CUTLASS kernels -+ // -+ -+ // Test default operation -+ if (options.split_k_factor == 1) -+ { -+ // Compare basic data-parallel version versus StreamK version using default load-balancing heuristics -+ Result basic_dp = run("Basic data-parallel GEMM", options); -+ Result streamk_default = run("StreamK GEMM with default load-balancing", options); -+ -+ printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_default.avg_runtime_ms)); -+ -+ // Show that StreamK can emulate basic data-parallel GEMM when we set the number of SMs to load-balance across = 1 -+ options.avail_sms = 1; // Set loadbalancing width to 1 SM (no load balancing) -+ Result streamk_dp = run("StreamK emulating basic data-parallel GEMM", options); -+ options.avail_sms = -1; // Reset loadbalancing width to unspecified SMs (i.e., the number of device SMs) -+ -+ printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_dp.avg_runtime_ms)); -+ -+ options.split_k_factor++; // Increment splitting factor for next evaluation -+ -+ } -+ -+ // Show that StreamK can emulate "Split-K" with a tile-splitting factor -+ Result basic_splitk = run( -+ std::string("Basic split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor), -+ options); -+ -+ Result streamk_splitk = run( -+ std::string("StreamK emulating Split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor), -+ options); -+ -+ printf(" Speedup vs Basic-SplitK: %.3f\n", (basic_splitk.avg_runtime_ms / streamk_splitk.avg_runtime_ms)); -+ -+ return 0; -+} -diff --git a/3rdparty/cutlass/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu b/3rdparty/cutlass/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu -new file mode 100644 -index 0000000..599d1d5 ---- /dev/null -+++ b/3rdparty/cutlass/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu -@@ -0,0 +1,463 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Simple Hopper GEMM example using CUTLASS 3.0 APIs for NVIDIA Hopper architecture -+ -+ This example demonstrate a simple way to instantiate and run a TF32 GEMM using the new CUTLASS 3.0 -+ APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows: -+ -+ 1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA) -+ which are more efficient than the Ampere tensor core instructions. -+ -+ 2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large -+ blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous -+ copies between thread blocks in a cluster. Another advantage is that TMA can load in FP32 data and -+ convert them implicitly to TF32. -+ -+ 3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details). -+ -+ Examples: -+ -+ $ ./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm --m=2048 --n=2048 --k=2048 -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cute/tensor.hpp" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/gemm/dispatch_policy.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/packed_stride.hpp" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+#include "cutlass/util/reference/device/tensor_fill.h" -+ -+#include "helper.h" -+ -+using namespace cute; -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// GEMM kernel configurations -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// A matrix configuration -+using ElementA = float; // Element type for A matrix operand -+using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand -+constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) -+ -+// B matrix configuration -+using ElementB = float; // Element type for B matrix operand -+using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand -+constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) -+ 
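// --- Editor's sketch (not part of the upstream diff). The Alignment* constants
// --- above presumably expand to 128 / cutlass::sizeof_bits<Element>::value,
// --- i.e. the number of elements that fit into one 128-bit (16-byte) memory
// --- access; the template argument appears to have been lost in extraction.
// --- `sizeof_bits_demo` below is a stand-in for cutlass::sizeof_bits, used only
// --- to illustrate the arithmetic.
#include <cstdio>
#include <cstdint>

template <typename T>
struct sizeof_bits_demo { static constexpr int value = int(sizeof(T)) * 8; };

int main() {
  // For the float operands used in this example: 128 / 32 = 4 elements per access.
  constexpr int kAlignmentFloat = 128 / sizeof_bits_demo<float>::value;
  // For 16-bit operands (as in the half_t GEMM of example 47 above, using
  // uint16_t here as a plain-C++ stand-in): 128 / 16 = 8 elements per access.
  constexpr int kAlignmentHalf = 128 / sizeof_bits_demo<uint16_t>::value;
  std::printf("float: %d elements per 16B access, 16-bit: %d elements per 16B access\n",
              kAlignmentFloat, kAlignmentHalf);
  return 0;
}
// --- End of editor's sketch.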
-+// C/D matrix configuration -+using ElementC = float; // Element type for C and D matrix operands -+using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands -+ -+// Core kernel configurations -+using ElementAccumulator = float; // Element type for internal accumulation -+using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature -+using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag -+using TilesShape = Shape<_128,_128,_32>; // Threadblock-level tile size -+using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster -+using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size -+using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder -+ -+using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ ArchTag, OperatorClass, -+ ElementA, LayoutA, AlignmentA, -+ ElementB, LayoutB, AlignmentB, -+ ElementAccumulator, -+ TilesShape, ClusterShape, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, // Indicates ProblemShape -+ CollectiveMainloop, -+ CollectiveEpilogue -+>; -+ -+using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+// Reference device GEMM implementation type -+using DeviceGemmReference = cutlass::reference::device::Gemm< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ ElementAccumulator>; -+ -+using StrideA = typename Gemm::GemmKernel::StrideA; -+using StrideB = typename Gemm::GemmKernel::StrideB; -+using StrideC = typename Gemm::GemmKernel::StrideC; -+using StrideD = typename Gemm::GemmKernel::StrideD; -+ -+// -+// Data members -+// -+ -+/// Initialization -+StrideA stride_A; -+StrideB stride_B; -+StrideC stride_C; -+StrideD stride_D; -+uint64_t seed; -+ -+cutlass::DeviceAllocation block_A; -+cutlass::DeviceAllocation block_B; -+cutlass::DeviceAllocation block_C; -+cutlass::DeviceAllocation block_D; -+cutlass::DeviceAllocation block_ref_D; -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Testbed utility types -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ -+ float alpha, beta; -+ int iterations; -+ int m, n, k; -+ -+ Options(): -+ help(false), -+ m(5120), n(4096), k(4096), -+ alpha(1.f), beta(0.f), -+ iterations(1000) -+ { } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ return; -+ } -+ -+ cmd.get_cmd_line_argument("m", m); -+ cmd.get_cmd_line_argument("n", n); -+ cmd.get_cmd_line_argument("k", k); -+ cmd.get_cmd_line_argument("alpha", alpha, 1.f); -+ cmd.get_cmd_line_argument("beta", beta, 0.f); -+ cmd.get_cmd_line_argument("iterations", iterations); -+ } -+ -+ /// Prints the usage statement. 
-+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "48_hopper_warp_specialized_gemm\n\n" -+ << " Hopper FP32 GEMM using a Warp Specialized kernel.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement\n\n" -+ << " --m= Sets the M extent of the GEMM\n" -+ << " --n= Sets the N extent of the GEMM\n" -+ << " --k= Sets the K extent of the GEMM\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --iterations= Number of profiling iterations to perform.\n\n"; -+ -+ out -+ << "\n\nExamples:\n\n" -+ << "$ " << "48_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const -+ { -+ // Two flops per multiply-add -+ uint64_t flop = uint64_t(2) * m * n * k; -+ double gflop = double(flop) / double(1.0e9); -+ return gflop / runtime_s; -+ } -+}; -+ -+/// Result structure -+struct Result -+{ -+ double avg_runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ Result( -+ double avg_runtime_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess) -+ : -+ avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) -+ {} -+ -+}; -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// GEMM setup and evaluation -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper to initialize a block of device data -+template -+bool initialize_block( -+ cutlass::DeviceAllocation& block, -+ uint64_t seed=2023) { -+ -+ Element scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ block.get(), block.size(), seed, scope_max, scope_min, 0); -+ -+ return true; -+} -+ -+/// Initialize operands to be used in the GEMM and reference GEMM -+void initialize(const Options &options) { -+ -+ stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, Int<1>{})); -+ stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, Int<1>{})); -+ stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, Int<1>{})); -+ stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, Int<1>{})); -+ -+ block_A.reset(options.m * options.k); -+ block_B.reset(options.k * options.n); -+ block_C.reset(options.m * options.n); -+ block_D.reset(options.m * options.n); -+ block_ref_D.reset(options.m * options.n); -+ -+ initialize_block(block_A, seed + 2023); -+ initialize_block(block_B, seed + 2022); -+ initialize_block(block_C, seed + 2021); -+} -+ -+/// Populates a Gemm::Arguments structure from the given commandline options -+typename Gemm::Arguments args_from_options(const Options &options) -+{ -+ typename Gemm::Arguments arguments{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ {options.m, options.n, options.k}, -+ block_A.get(), -+ stride_A, -+ block_B.get(), -+ stride_B, -+ {block_C.get(), stride_C, block_D.get(), stride_D, {options.alpha, options.beta}} -+ }; -+ -+ return 
arguments; -+} -+ -+bool verify(const Options &options) { -+ cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k})); -+ cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.n, options.k})); -+ cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n})); -+ cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n})); -+ -+ // -+ // Compute reference output -+ // -+ -+ // Create instantiation for device reference gemm kernel -+ DeviceGemmReference gemm_reference; -+ -+ // Launch device reference gemm kernel -+ gemm_reference( -+ {options.m, options.n, options.k}, -+ ElementAccumulator(options.alpha), -+ ref_A, -+ ref_B, -+ ElementAccumulator(options.beta), -+ ref_C, -+ ref_D); -+ -+ // Wait for kernel to finish -+ CUDA_CHECK(cudaDeviceSynchronize()); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); -+ -+ return passed; -+} -+ -+/// Execute a given example GEMM computation -+template -+int run(Options &options) -+{ -+ initialize(options); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm gemm; -+ -+ // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm -+ auto arguments = args_from_options(options); -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Check if the problem size is supported or not -+ CUTLASS_CHECK(gemm.can_implement(arguments)); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); -+ -+ // Correctness / Warmup iteration -+ CUTLASS_CHECK(gemm.run()); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ Result result; -+ result.passed = verify(options); -+ -+ std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; -+ -+ if (!result.passed) { -+ exit(-1); -+ } -+ -+ // Run profiling loop -+ if (options.iterations > 0) -+ { -+ GpuTimer timer; -+ timer.start(); -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ CUTLASS_CHECK(gemm.run()); -+ } -+ timer.stop(); -+ -+ // Compute average runtime and GFLOPs. -+ float elapsed_ms = timer.elapsed_millis(); -+ result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); -+ -+ std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; -+ std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; -+ std::cout << " GFLOPS: " << result.gflops << std::endl; -+ } -+ -+ return 0; -+} -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example -+ // and must have compute capability at least 90. -+ if (__CUDACC_VER_MAJOR__ < 12) { -+ std::cerr << "This example requires CUDA 12 or newer.\n"; -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. 
-+ return 0; -+ } -+ -+ cudaDeviceProp props; -+ int current_device_id; -+ CUDA_CHECK(cudaGetDevice(¤t_device_id)); -+ CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (props.major < 9) { -+ std::cerr -+ << "This example requires a GPU of NVIDIA's Hopper Architecture or " -+ << "later (compute capability 90 or greater).\n"; -+ return 0; -+ } -+ -+ // -+ // Parse options -+ // -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ // -+ // Evaluate CUTLASS kernels -+ // -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ run(options); -+#endif -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu b/3rdparty/cutlass/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu -new file mode 100644 -index 0000000..7323cc3 ---- /dev/null -+++ b/3rdparty/cutlass/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu -@@ -0,0 +1,529 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Hopper GEMM example to create a GEMM kernel with custom Collectives -+ -+ The following example shows how to assemble a custom GEMM kernel that spells out the Collectives -+ directly instead of using a builder and, in the process, instance a more efficient Epilogue -+ (from `cutlass/epilogue/collective/epilogue.hpp`) instead of using the default epilogue. 
-+ -+ The GemmUniversal API takes 3 main template arguments: -+ (1) the problem shape / extents -+ (2) the collective mainloop type -+ (3) the collective epilogue type -+ -+ While the collecive mainloop can be stamped out using a CollectiveBuilder interface, it is -+ possible to build a custom collective mainloop directly as well. Furthermore, since epilogues -+ do not yet have a builder interface, this example shows how to instantiate a more-efficient -+ epilogue alongside the collective mainloop. -+ -+ Note: there are several ways to implement the GEMM epilogue in Hopper - each with its own set -+ of trade-offs. So it is recommended that users look at the options available under -+ cutlass/epilogue/collective and evaluate for their particular scenario. -+ -+ Please refer to examples 48, 49 to learn more about kernel schedules and other CuTe examples -+ present in `test/unit/cute` to famialiarize with the basics of CuTe. -+ -+ Examples: -+ -+ $ ./examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cute/tensor.hpp" -+#include "cutlass/util/command_line.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/epilogue/collective/epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/gemm/dispatch_policy.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/dispatch_policy.hpp" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/packed_stride.hpp" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm_complex.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+#include "cutlass/util/reference/device/tensor_fill.h" -+ -+using namespace cute; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ bool error; -+ -+ int m, n, k, l; -+ int alpha, beta; -+ -+ Options(): -+ help(false), -+ error(false), -+ m(2048), n(2048), k(2048), l(1), -+ alpha(1), beta(0) -+ { } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ return; -+ } -+ -+ cmd.get_cmd_line_argument("m", m, 2048); -+ cmd.get_cmd_line_argument("n", n, 2048); -+ cmd.get_cmd_line_argument("k", k, 2048); -+ cmd.get_cmd_line_argument("l", l, 1); -+ cmd.get_cmd_line_argument("alpha", alpha, 1); -+ cmd.get_cmd_line_argument("beta", beta, 0); -+ } -+ -+ /// Prints the usage statement. 
-+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "50_hopper_gemm_with_vectorized_epilogue\n\n" -+ << "Hopper GEMM Example with Epilogue Swizzle.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement\n\n" -+ << " --m= Sets the M extent of the GEMM\n" -+ << " --n= Sets the N extent of the GEMM\n" -+ << " --k= Sets the K extent of the GEMM\n" -+ << " --l= Sets the L extent (batch count) of the GEMM\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n"; -+ -+ return out; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper to initialize a block of device data -+template -+bool initialize_block( -+ cutlass::DeviceAllocation& block, -+ uint64_t seed=2023) { -+ -+ Element scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ block.get(), block.size(), seed, scope_max, scope_min, 0); -+ -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+// Wrapper to run and verify a GEMM. -+template < -+ class Gemm -+> -+struct ExampleRunner { -+ -+ using StrideA = typename Gemm::GemmKernel::StrideA; -+ using StrideB = typename Gemm::GemmKernel::StrideB; -+ using StrideC = typename Gemm::GemmKernel::StrideC; -+ using StrideD = typename Gemm::GemmKernel::StrideD; -+ -+ using LayoutA = typename Gemm::LayoutA; -+ using LayoutB = typename Gemm::LayoutB; -+ using LayoutC = typename Gemm::LayoutC; -+ using LayoutD = typename Gemm::LayoutD; -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementAcc = typename Gemm::ElementAccumulator; -+ -+ using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; -+ using ElementC = typename Gemm::ElementC; -+ using ElementOutput = typename CollectiveEpilogue::ElementOutput; -+ using ElementCompute = typename CollectiveEpilogue::ElementCompute; -+ using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; -+ -+ using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; -+ -+ // -+ // Data members -+ // -+ -+ /// Initialization -+ StrideA stride_A; -+ StrideB stride_B; -+ StrideC stride_C; -+ StrideD stride_D; -+ uint64_t seed = 0; -+ -+ cutlass::DeviceAllocation block_A; -+ cutlass::DeviceAllocation block_B; -+ cutlass::DeviceAllocation block_C; -+ cutlass::DeviceAllocation block_D; -+ cutlass::DeviceAllocation block_ref_D; -+ -+ // -+ // Methods -+ // -+ -+ bool verify(const ProblemShapeType& problem_size, int32_t alpha, int32_t beta) { -+ auto [M, N, K, L] = problem_size; -+ -+ cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); -+ cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); -+ cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); -+ cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); -+ -+ cutlass::reference::device::GemmComplex( -+ {M, N, K}, -+ ElementCompute(alpha), -+ ref_A, -+ cutlass::ComplexTransform::kNone, -+ ref_B, -+ cutlass::ComplexTransform::kNone, -+ ElementCompute(beta), -+ ref_C, -+ ref_D, -+ ElementAccumulator(0), -+ L, // batch_count -+ M * K, // batch_stride_A -+ K * N, 
// batch_stride_B -+ M * N, // batch_stride_C -+ M * N // batch_stride_D -+ ); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ std::cerr << "Reference kernel failed. Last CUDA error: " -+ << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); -+ -+ return passed; -+ } -+ -+ /// Initialize operands to be used in the GEMM and reference GEMM -+ void initialize(const ProblemShapeType& problem_size) { -+ auto problem_shape_MNKL = cute::append<4>(problem_size, 1); -+ auto [M, N, K, L] = problem_shape_MNKL; -+ -+ stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); -+ stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); -+ stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); -+ stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); -+ -+ block_A.reset(M * K * L); -+ block_B.reset(K * N * L); -+ block_C.reset(M * N * L); -+ block_D.reset(M * N * L); -+ block_ref_D.reset(M * N * L); -+ -+ initialize_block(block_A, seed + 2023); -+ initialize_block(block_B, seed + 2022); -+ initialize_block(block_C, seed + 2021); -+ } -+ -+ bool run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { -+ ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; -+ -+ initialize(problem_size); -+ -+ typename Gemm::GemmKernel::Arguments arguments{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem_size, -+ block_A.get(), -+ stride_A, -+ block_B.get(), -+ stride_B, -+ {block_C.get(), stride_C, block_D.get(), stride_D, {options.alpha, options.beta}}, -+ hw_info -+ }; -+ -+ Gemm gemm_op; -+ -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = gemm_op.can_implement(arguments); -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "This kernel is not supported. Last CUDA error is: " -+ << cudaGetErrorString(cudaGetLastError()) << std::endl; -+ return false; -+ } -+ -+ status = gemm_op.initialize(arguments, workspace.get()); -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: " -+ << cudaGetErrorString(cudaGetLastError()) << std::endl; -+ return false; -+ } -+ -+ // Run the GEMM -+ status = gemm_op.run(); -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " -+ << cudaGetErrorString(cudaGetLastError()) << std::endl; -+ return false; -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ std::cerr << "Error running the CUTLASS kernel. 
Last CUDA error is: " -+ << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ // Verify that the result is correct -+ bool passed = verify(problem_size, options.alpha, options.beta); -+ if (!passed) { -+ std::cerr << "Reference check failed" << std::endl; -+ } -+ -+ return passed; -+ } -+ -+}; -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (__CUDACC_VER_MAJOR__ < 12 || props.major < 9) { -+ std::cout -+ << "This example requires a GPU of NVIDIA's Hopper Architecture or " -+ << "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n"; -+ return 0; -+ } -+ -+ // -+ // Parse options -+ // -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.error) { -+ std::cerr << "Aborting execution." << std::endl; -+ return -1; -+ } -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+ // -+ // Run examples -+ // -+ -+ // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This -+ // information is used by the underlying kernel. -+ cutlass::KernelHardwareInfo hw_info; -+ -+ // Change device_id to another value if you are running on a machine with multiple GPUs and wish -+ // to use a GPU other than that with device ID 0. -+ hw_info.device_id = 0; -+ hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); -+ -+ bool passed; -+ -+ // Problem configuration -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementAcc = int32_t; -+ using ElementOutput = int8_t; -+ -+ // Note : Only TN WGMMA Gemm is supported currently in 3.0 -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using LayoutD = cutlass::layout::ColumnMajor; -+ -+ // Tiling configuration selection -+ using TileShape = Shape<_128,_64,_128>; -+ -+ // Choosing a thread block cluster larger than 1 allows us to Multicast data across thread blocks -+ using ClusterShape = Shape<_1,_2,_1>; -+ -+ // -+ // Assembling the CollectiveMainloop type -+ // -+ -+ // Pipeline Depth to be used i.e number of A, B buffers in shared memory -+ constexpr int PipelineStages = 8; -+ -+ // Let's choose a Warp-Specialized Mainloop implemention which uses TMA -+ // Note : This requires / assumes the tensors to be 16B aligned -+ using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized; -+ -+ // TN => K Major for both A & B -+ static constexpr cute::GMMA::Major GmmaMajorA = cute::GMMA::Major::K; -+ static constexpr cute::GMMA::Major GmmaMajorB = cute::GMMA::Major::K; -+ -+ // We use the SS op selector as both A, B operands are read directly from SMEM (for TN WGMMA) -+ using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< -+ ElementA, ElementB, ElementAcc, TileShape, GmmaMajorA, GmmaMajorB>())); -+ -+ // A loads can be optimized with multicast if cluster-n > 1 -+ using GmemTiledCopyA = std::conditional< cute::size(shape<1>(ClusterShape{})) == 1, -+ cute::SM90_TMA_LOAD, -+ 
cute::SM90_TMA_LOAD_MULTICAST>::type; -+ -+ // B loads can be optimized with multicast if cluster-m > 1 -+ using GmemTiledCopyB = std::conditional< cute::size(shape<0>(ClusterShape{})) == 1, -+ cute::SM90_TMA_LOAD, -+ cute::SM90_TMA_LOAD_MULTICAST>::type; -+ -+ using SmemLayoutAtomA = decltype(cute::GMMA::smem_selector< -+ GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape{})), decltype(cute::get<2>(TileShape{})) -+ >()); -+ -+ using SmemLayoutAtomB = decltype(cute::GMMA::smem_selector< -+ GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape{})), decltype(cute::get<2>(TileShape{})) -+ >()); -+ -+ using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< -+ DispatchPolicy, -+ TileShape, -+ ElementA, -+ cutlass::gemm::TagToStrideA_t, -+ ElementB, -+ cutlass::gemm::TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, -+ SmemLayoutAtomA, -+ void, // Does not need a SmemCopyAtom, since A is read directly from SMEM -+ cute::identity, -+ GmemTiledCopyB, -+ SmemLayoutAtomB, -+ void, // Does not need a SmemCopyAtom, since B is read directly from SMEM -+ cute::identity -+ >; -+ -+ // -+ // Assembling the Collective Epilogue Type -+ // -+ -+ // Break the 128 along TILE_M into chunks of 32, to get a 128B leading dimension -+ using PreSwizzleLayout = Layout< Shape< Shape <_32,_4 >,_64>, -+ Stride,_32>>; -+ -+ // 128 threads loading 16 elements each (to get vectorized global stores) -+ using TileShapeS2R = Shape<_128,_16>; -+ -+ // Layout to ensure bank-conflict free loads & stores -+ using SmemLayout = ComposedLayout< -+ Swizzle<3,4,3>, -+ smem_ptr_flag_bits::value>, -+ PreSwizzleLayout>; -+ -+ // Tiled copy from Smem to Registers -+ // Note : CuTe will vectorize this copy if the tiling + swizzling above were right -+ using TiledCopyS2R = TiledCopy< -+ Copy_Atom, -+ Layout< Shape<_128,_16>, -+ Stride<_16,_1>>, -+ TileShapeS2R>; -+ -+ using Epilogue = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ SmemLayout, -+ Copy_Atom, -+ TiledCopyS2R, -+ Copy_Atom>; -+ -+ // -+ // Assembling the GemmKernel -+ // -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ Epilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ ExampleRunner runner; -+ -+ passed = runner.run(options, hw_info); -+ -+ std::cout << "WGMMA GEMM with Epilogue Swizzle : " << (passed ? 
"Passed" : "Failed") << std::endl; -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/common/helper.h b/3rdparty/cutlass/examples/common/helper.h -new file mode 100644 -index 0000000..ba04113 ---- /dev/null -+++ b/3rdparty/cutlass/examples/common/helper.h -@@ -0,0 +1,77 @@ -+#pragma once -+ -+#include "cuda_runtime.h" -+ -+/** -+ * Panic wrapper for unwinding CUTLASS errors -+ */ -+#define CUTLASS_CHECK(status) \ -+ { \ -+ cutlass::Status error = status; \ -+ if (error != cutlass::Status::kSuccess) { \ -+ std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ -+ << std::endl; \ -+ exit(EXIT_FAILURE); \ -+ } \ -+ } -+ -+ -+/** -+ * Panic wrapper for unwinding CUDA runtime errors -+ */ -+#define CUDA_CHECK(status) \ -+ { \ -+ cudaError_t error = status; \ -+ if (error != cudaSuccess) { \ -+ std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ -+ << " at line: " << __LINE__ << std::endl; \ -+ exit(EXIT_FAILURE); \ -+ } \ -+ } -+ -+ -+/** -+ * GPU timer for recording the elapsed time across kernel(s) launched in GPU stream -+ */ -+struct GpuTimer -+{ -+ cudaStream_t _stream_id; -+ cudaEvent_t _start; -+ cudaEvent_t _stop; -+ -+ /// Constructor -+ GpuTimer() : _stream_id(0) -+ { -+ CUDA_CHECK(cudaEventCreate(&_start)); -+ CUDA_CHECK(cudaEventCreate(&_stop)); -+ } -+ -+ /// Destructor -+ ~GpuTimer() -+ { -+ CUDA_CHECK(cudaEventDestroy(_start)); -+ CUDA_CHECK(cudaEventDestroy(_stop)); -+ } -+ -+ /// Start the timer for a given stream (defaults to the default stream) -+ void start(cudaStream_t stream_id = 0) -+ { -+ _stream_id = stream_id; -+ CUDA_CHECK(cudaEventRecord(_start, _stream_id)); -+ } -+ -+ /// Stop the timer -+ void stop() -+ { -+ CUDA_CHECK(cudaEventRecord(_stop, _stream_id)); -+ } -+ -+ /// Return the elapsed time (in milliseconds) -+ float elapsed_millis() -+ { -+ float elapsed = 0.0; -+ CUDA_CHECK(cudaEventSynchronize(_stop)); -+ CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start, _stop)); -+ return elapsed; -+ } -+}; -diff --git a/3rdparty/cutlass/examples/cute/tutorial/sgemm_nt_1.cu b/3rdparty/cutlass/examples/cute/tutorial/sgemm_nt_1.cu -new file mode 100644 -index 0000000..fc4839a ---- /dev/null -+++ b/3rdparty/cutlass/examples/cute/tutorial/sgemm_nt_1.cu -@@ -0,0 +1,426 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include -+#include -+ -+#include -+ -+#include "cutlass/util/print_error.hpp" -+#include "cutlass/util/GPU_Clock.hpp" -+#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0 -+# include "cutlass/util/cublas_wrappers.hpp" -+#endif -+#include "cutlass/util/helper_cuda.hpp" -+ -+template -+__global__ static -+__launch_bounds__(decltype(size(CThreadLayout{}))::value) -+void -+gemm_device(MShape M, NShape N, KShape K, -+ TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA, -+ TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB, -+ TC * C, CStride dC, CBlockLayout , CThreadLayout tC, -+ Alpha alpha, Beta beta) -+{ -+ using namespace cute; -+ using X = Underscore; -+ -+ // Preconditions -+ CUTE_STATIC_ASSERT(is_static::value); -+ CUTE_STATIC_ASSERT(is_static::value); -+ CUTE_STATIC_ASSERT(is_static::value); -+ -+ CUTE_STATIC_ASSERT(is_static::value); -+ CUTE_STATIC_ASSERT(is_static::value); -+ CUTE_STATIC_ASSERT(is_static::value); -+ -+ CUTE_STATIC_ASSERT_V(size(tA) == size(tC)); -+ CUTE_STATIC_ASSERT_V(size(tB) == size(tC)); -+ -+ //CUTE_STATIC_ASSERT_V(shape<0>(blockA) == shape<0>(blockC)); // BLK_M -+ //CUTE_STATIC_ASSERT_V(shape<0>(blockB) == shape<1>(blockC)); // BLK_N -+ CUTE_STATIC_ASSERT_V(shape<1>(blockA) == shape<1>(blockB)); // BLK_K -+ -+ // Shared memory buffers -+ __shared__ TA smemA[cosize_v]; -+ __shared__ TB smemB[cosize_v]; -+ auto sA = make_tensor(make_smem_ptr(smemA), blockA); // (BLK_M,BLK_K) -+ auto sB = make_tensor(make_smem_ptr(smemB), blockB); // (BLK_N,BLK_K) -+ -+ // Represent the full tensors -+ auto mA = make_tensor(make_gmem_ptr(A), make_shape(M,K), dA); // (M,K) -+ auto mB = make_tensor(make_gmem_ptr(B), make_shape(N,K), dB); // (N,K) -+ auto mC = make_tensor(make_gmem_ptr(C), make_shape(M,N), dC); // (M,N) -+ -+ // Get the appropriate blocks for this thread block -- -+ // potential for thread block locality -+ auto blk_shape = make_shape(size<0>(sA), size<0>(sB), size<1>(sB));// (BLK_M,BLK_N,BLK_K) -+ auto blk_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k) -+ -+ auto gA = local_tile(mA, blk_shape, blk_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) -+ auto gB = local_tile(mB, blk_shape, blk_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) -+ auto gC = local_tile(mC, blk_shape, blk_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) -+ -+ // -+ // Partition the copying of A and B tiles across the threads -+ // -+ -+ // TUTORIAL: Example of simple partitioning of A|B tiles over tA|tB -+ // Default is a raked partition, but can be changed with Step parameter -+ -+ auto tAgA = local_partition(gA, tA, threadIdx.x); // (THR_M,THR_K,k) -+ 
auto tAsA = local_partition(sA, tA, threadIdx.x); // (THR_M,THR_K) -+ -+ auto tBgB = local_partition(gB, tB, threadIdx.x); // (THR_N,THR_K,k) -+ auto tBsB = local_partition(sB, tB, threadIdx.x); // (THR_N,THR_K) -+ -+ // -+ // Define C accumulators and A/B partitioning -+ // -+ -+ // TUTORIAL: Example of partitioning via projections of tC -+ -+ // Partition sA (M,K) by the rows of tC -+ auto tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{}); // (THR_M,BLK_K) -+ // Partition sB (N,K) by the cols of tC -+ auto tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{}); // (THR_N,BLK_K) -+ // Partition gC (M,N) by the tile of tC -+ auto tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{}); // (THR_M,THR_N) -+ -+ // Allocate the accumulators -- same size as the projected data -+ auto tCrC = make_fragment_like(tCgC); // (THR_M,THR_N) -+ -+ // Clear the accumulators -+ clear(tCrC); -+ -+#if 0 -+ if(thread0()) { -+ print("mA\n"); -+ print(mA.shape()); print("\n"); print(mA.stride()); -+ print("\n\ngA\n"); -+ print(gA.shape()); print("\n"); print(gA.stride()); -+ print("\n\ntAgA\n"); -+ print(tAgA.shape()); print("\n"); print(tAgA.stride()); -+ print("\n\nsA\n"); -+ print(sA.shape()); print("\n"); print(sA.stride()); -+ print("\n\ntAsA\n"); -+ print(tAsA.shape()); print("\n"); print(tAsA.stride()); -+ print("\n\n"); -+ } -+#endif -+ -+#if 0 -+ if(thread0()) { -+ print("mB\n"); -+ print(mB.shape()); print("\n"); print(mB.stride()); -+ print("\n\ngB\n"); -+ print(gB.shape()); print("\n"); print(gB.stride()); -+ print("\n\ntBgB\n"); -+ print(tBgB.shape()); print("\n"); print(tBgB.stride()); -+ print("\n\nsB\n"); -+ print(sB.shape()); print("\n"); print(sB.stride()); -+ print("\n\ntBsB\n"); -+ print(tBsB.shape()); print("\n"); print(tBsB.stride()); -+ print("\n\n"); -+ } -+#endif -+ -+#if 0 -+ if(thread0()) { -+ print("mC\n"); -+ print(mC.shape()); print("\n"); print(mC.stride()); -+ print("\n\ngC\n"); -+ print(gC.shape()); print("\n"); print(gC.stride()); -+ print("\n\ntCsA\n"); -+ print(tCsA.shape()); print("\n"); print(tCsA.stride()); -+ print("\n\ntCsB\n"); -+ print(tCsB.shape()); print("\n"); print(tCsB.stride()); -+ print("\n\ntCgC\n"); -+ print(tCgC.shape()); print("\n"); print(tCgC.stride()); -+ print("\n\ntCrC\n"); -+ print(tCrC.shape()); print("\n"); print(tCrC.stride()); -+ print("\n\n"); -+ } -+#endif -+ -+#if 1 -+ -+ // TUTORIAL: Example of a very simple compute loop -+ // Data is read from global to shared memory via the tA|tB partitioning -+ // gemm(.) operates on the shared memory directly via the tC partitioning -+ -+ auto k_max = size<2>(tAgA); -+ -+ for (int k = 0; k < k_max; ++k) -+ { -+ // Copy gmem to smem -+ copy(tAgA(_,_,k), tAsA); -+ copy(tBgB(_,_,k), tBsB); -+ -+ // In case copy uses cp.async, make sure that the cp.async -+ // instructions are ordered with respect to other cp.async -+ // instructions (fence), then wait on all the outstanding copy -+ // operations (wait<0>()). __syncthreads() alone does not do -+ // this. -+ // -+ // NOTE: cp_async_wait<0>() currently issues cp.async.wait_all. -+ // This is equivalent to cp.async.commit_group followed by -+ // cp.async_wait_group 0. This should make the first -+ // cp_async_fence() (which also issues cp.async.commit_group) -+ // redundant. The tutorial works as-is, so we'll leave the -+ // redundant fence in for now and study its removal later. 
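// [Editorial note, not part of the patch] A hedged sketch of the PTX-level picture the
// NOTE above refers to (instruction names per the CUDA PTX ISA):
//   cp_async_fence()   -> cp.async.commit_group   (closes the current group of copies)
//   cp_async_wait<0>() -> cp.async.wait_all       (waits on every prior cp.async, which
//                         behaves like commit_group followed by wait_group 0)
// The __syncthreads() below is still required: the wait only orders this thread's own
// asynchronous copies, while the barrier makes the tiles written by *other* threads
// visible before gemm() reads shared memory.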
-+ cp_async_fence(); -+ cp_async_wait<0>(); -+ -+ __syncthreads(); -+ -+ // Compute gemm on smem -+ gemm(tCsA, tCsB, tCrC); -+ -+ __syncthreads(); -+ } -+ -+#endif -+ -+ // -+ // Epilogue -+ // -+ -+ axpby(alpha, tCrC, beta, tCgC); -+} -+ -+ -+template -+void -+gemm(int m, int n, int k, -+ Alpha alpha, -+ TA const* A, int ldA, -+ TB const* B, int ldB, -+ Beta beta, -+ TC * C, int ldC, -+ cudaStream_t stream = 0) -+{ -+ using namespace cute; -+ -+ // Define shapes (dynamic) -+ auto M = int(m); -+ auto N = int(n); -+ auto K = int(k); -+ -+ // Define strides (mixed) -+ auto dA = make_stride(Int<1>{}, ldA); -+ auto dB = make_stride(Int<1>{}, ldB); -+ auto dC = make_stride(Int<1>{}, ldC); -+ -+ // Define block sizes (static) -+ auto bM = Int<128>{}; -+ auto bN = Int<128>{}; -+ auto bK = Int< 8>{}; -+ -+ // Define the block layouts (static) -+ auto sA = make_layout(make_shape(bM,bK)); -+ auto sB = make_layout(make_shape(bN,bK)); -+ auto sC = make_layout(make_shape(bM,bN)); -+ -+ // Define the thread layouts (static) -+ auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{})); -+ auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{})); -+ auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); -+ -+ dim3 dimBlock(size(tC)); -+ dim3 dimGrid(ceil_div(size(M), size(bM)), -+ ceil_div(size(N), size(bN))); -+ gemm_device -+ <<< dimGrid, dimBlock, 0, stream >>> -+ (M, N, K, -+ A, dA, sA, tA, -+ B, dB, sB, tB, -+ C, dC, sC, tC, -+ alpha, beta); -+} -+ -+#include -+#include -+#include -+ -+void test_gemm(int m, int n, int k) -+{ -+ cute::device_init(0); -+ -+ std::cout << "M = " << m << std::endl; -+ std::cout << "N = " << n << std::endl; -+ std::cout << "K = " << k << std::endl; -+ -+ using TA = float; -+ using TB = float; -+ using TC = float; -+ using TI = float; -+ -+ thrust::host_vector h_A(m*k); -+ thrust::host_vector h_B(n*k); -+ thrust::host_vector h_C(m*n); -+ -+ for (int j = 0; j < m*k; ++j) h_A[j] = static_cast( 2*(rand() / double(RAND_MAX)) - 1 ); -+ for (int j = 0; j < n*k; ++j) h_B[j] = static_cast( 2*(rand() / double(RAND_MAX)) - 1 ); -+ for (int j = 0; j < m*n; ++j) h_C[j] = static_cast(-1); -+ -+ thrust::device_vector d_A = h_A; -+ thrust::device_vector d_B = h_B; -+ thrust::device_vector d_C = h_C; -+ -+ TI alpha = 1.0; -+ TI beta = 0.0; -+ -+ double gflops = (2.0*m*n*k) * 1e-9; -+ -+ const int timing_iterations = 100; -+ GPU_Clock timer; -+ -+#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0 -+ // -+ // cuBLas -+ // -+ -+ cublasHandle_t handle; -+ cublasCreate(&handle); -+ -+ // Run once -+ d_C = h_C; -+ blam::cublas::gemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, -+ m, n, k, -+ &alpha, -+ d_A.data().get(), m, -+ d_B.data().get(), n, -+ &beta, -+ d_C.data().get(), m); -+ CUTE_CHECK_LAST(); -+ -+ thrust::host_vector cublas_result = d_C; -+ -+ // Timing iterations -+ timer.start(); -+ for (int i = 0; i < timing_iterations; ++i) { -+ blam::cublas::gemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, -+ m, n, k, -+ &alpha, -+ d_A.data().get(), m, -+ d_B.data().get(), n, -+ &beta, -+ d_C.data().get(), m); -+ } -+ double cublas_time = timer.seconds() / timing_iterations; -+ CUTE_CHECK_LAST(); -+ printf("CUBLAS_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cublas_time, cublas_time*1000); -+ -+#else -+ -+ std::cout << "Verification by comparison with cuBLAS is disabled, " -+ "either because the CMake option CUTLASS_ENABLE_CUBLAS " -+ "was explicitly set to OFF, or because CMake could not find cuBLAS. 
" -+ "If you would like to enable verification with cuBLAS, " -+ "please set the CMake option CUTLASS_ENABLE_CUBLAS to ON, " -+ "rerun CMake, and recompile this example.\n"; -+ -+#endif // CUTLASS_ENABLE_CUBLAS -+ -+ // -+ // CuTe -+ // -+ -+ // Run once (and check) -+ d_C = h_C; -+ gemm(m, n, k, -+ alpha, -+ d_A.data().get(), m, -+ d_B.data().get(), n, -+ beta, -+ d_C.data().get(), m); -+ CUTE_CHECK_LAST(); -+ thrust::host_vector cute_result = d_C; -+ -+ // Timing iterations -+ timer.start(); -+ for (int i = 0; i < timing_iterations; ++i) { -+ gemm(m, n, k, -+ alpha, -+ d_A.data().get(), m, -+ d_B.data().get(), n, -+ beta, -+ d_C.data().get(), m); -+ } -+ double cute_time = timer.seconds() / timing_iterations; -+ CUTE_CHECK_LAST(); -+ printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000); -+ -+#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0 -+ printf("Empirical Perf: %.1f%%\n", (cublas_time / cute_time) * 100); -+ -+ auto host_matrix_to_const_column_major_cute_tensor = -+ [](const auto& X, int num_rows, int num_cols, int LDX) { -+ const auto shape = cute::Shape{num_rows, num_cols}; -+ const auto strides = cute::Stride{1, LDX}; -+ return cute::make_tensor(X.data(), cute::make_layout(shape, strides)); -+ }; -+ -+ const auto A_view = host_matrix_to_const_column_major_cute_tensor(h_A, m, k, m); -+ // B^T is k x n, so B is n x k. -+ const auto B_view = host_matrix_to_const_column_major_cute_tensor(h_B, n, k, n); -+ const auto C_computed_view = host_matrix_to_const_column_major_cute_tensor(cute_result, m, n, m); -+ const auto C_expected_view = host_matrix_to_const_column_major_cute_tensor(cublas_result, m, n, m); -+ print_matrix_multiply_mollified_relative_error("float", A_view, B_view, C_computed_view, C_expected_view); -+ -+#endif // CUTLASS_ENABLE_CUBLAS -+} -+ -+ -+int main(int argc, char** argv) -+{ -+ int m = 5120; -+ if (argc >= 2) -+ sscanf(argv[1], "%d", &m); -+ -+ int n = 5120; -+ if (argc >= 3) -+ sscanf(argv[2], "%d", &n); -+ -+ int k = 4096; -+ if (argc >= 4) -+ sscanf(argv[3], "%d", &k); -+ -+ test_gemm(m, n, k); -+ -+ return 0; -+} -diff --git a/3rdparty/cutlass/include/cute/algorithm/axpby.hpp b/3rdparty/cutlass/include/cute/algorithm/axpby.hpp -new file mode 100644 -index 0000000..a613417 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/algorithm/axpby.hpp -@@ -0,0 +1,79 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+// -+// Accept mutable temporaries -+// -+template -+CUTE_HOST_DEVICE -+void -+axpby(Alpha const& alpha, -+ Tensor const& x, -+ Beta const& beta, -+ Tensor && y) -+{ -+ return axpby(alpha, x, beta, y); -+} -+ -+// -+// AXPBY -+// -+template -+CUTE_HOST_DEVICE -+void -+axpby(Alpha const& alpha, -+ Tensor const& x, -+ Beta const& beta, -+ Tensor & y) -+{ -+ auto isBetaZero = (beta == Int<0>{}); -+ -+ CUTE_UNROLL -+ for (int i = 0; i < size(x); ++i) { -+ y(i) = (isBetaZero ? alpha * x(i) : alpha * x(i) + beta * y(i)); -+ } -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/algorithm/clear.hpp b/3rdparty/cutlass/include/cute/algorithm/clear.hpp -new file mode 100644 -index 0000000..ce7b510 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/algorithm/clear.hpp -@@ -0,0 +1,66 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+// -+// Accept mutable temporaries -+// -+template -+CUTE_HOST_DEVICE -+void -+clear(Tensor&& tensor) -+{ -+ return clear(tensor); -+} -+ -+// -+// Set elements to zero -+// -+template -+CUTE_HOST_DEVICE -+void -+clear(Tensor& tensor) -+{ -+ using T = typename Tensor::value_type; -+ -+ fill(tensor, T{}); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/algorithm/copy.hpp b/3rdparty/cutlass/include/cute/algorithm/copy.hpp -new file mode 100644 -index 0000000..04ceb05 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/algorithm/copy.hpp -@@ -0,0 +1,262 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+// -+// Accept mutable temporaries -+// -+ -+template -+CUTE_HOST_DEVICE -+void -+copy_if(PrdTensor const& pred, -+ Tensor const& src, -+ Tensor && dst) -+{ -+ return copy_if(pred, src, dst); -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+copy_if(Copy_Atom const& copy_atom, -+ PrdTensor const& pred, -+ Tensor const& src, -+ Tensor && dst) -+{ -+ return copy_if(copy_atom, pred, src, dst); -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+copy_vec(Tensor const& src, -+ Tensor && dst) -+{ -+ return copy_vec(src, dst); -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+copy(Tensor const& src, -+ Tensor && dst) -+{ -+ return copy(src, dst); -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+copy(Copy_Atom const& copy_atom, -+ Tensor const& src, -+ Tensor && dst) -+{ -+ return copy(copy_atom, src, dst); -+} -+ -+// -+// copy_if -- Predicated Copy -+// -+ -+template -+CUTE_HOST_DEVICE -+void -+copy_if(PrdTensor const& pred, -+ Tensor const& src, -+ Tensor & dst) -+{ -+ auto copy_op = select_elementwise_copy(src, dst); -+ -+ CUTE_UNROLL -+ for (int i = 0; i < size(src); ++i) { -+ if (pred(i)) { -+ copy_op.copy(src(i), dst(i)); -+ } -+ } -+} -+ -+// -+// copy_if -- Predicated CopyAtom -+// -+ -+template -+CUTE_HOST_DEVICE -+void -+copy_if(Copy_Atom const& copy_atom, -+ PredTensor const& pred, // (Rest...) -+ Tensor const& src, // (V,Rest...) -+ Tensor & dst) // (V,Rest...) -+{ -+ static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch."); -+ if constexpr (SrcLayout::rank == 1) { // Dispatch the copy -+ copy_atom.call(src, dst); -+ } else { // Loop over all but the first mode -+ constexpr int R = SrcLayout::rank; -+ auto src_v = group_modes<1,R>(src); -+ auto dst_v = group_modes<1,R>(dst); -+ CUTE_UNROLL -+ for (int i = 0; i < size<1>(src_v); ++i) { -+ if (pred(i)) { -+ copy_atom.call(src_v(_,i), dst_v(_,i)); -+ } -+ } -+ } -+} -+ -+// -+// copy_vec -- attempt vectorized copy with VecType -+// -+ -+template -+CUTE_HOST_DEVICE -+void -+copy_vec(Tensor const& src, -+ Tensor & dst) -+{ -+ using SrcType = typename SrcEngine::value_type; -+ using DstType = typename DstEngine::value_type; -+ if constexpr (sizeof(SrcType) == sizeof(DstType) && sizeof(VecType) > sizeof(DstType)) -+ { -+ /* @pre is_aligned(src.data()) && -+ * is_aligned(dst.data()) -+ */ -+ auto src_v = recast(src); -+ auto dst_v = recast(dst); -+ -+#if 0 -+ if (thread0()) { -+ print("copy_vec -- vectorizing copy from %3db to %3db\n", int(8*sizeof(SrcType)), int(8*sizeof(VecType))); -+ print(" "); print(layout(src)); print(" => "); print(layout(src_v)); print("\n"); -+ print(" "); print(layout(dst)); print(" => "); print(layout(dst_v)); print("\n"); -+ } -+#endif -+ -+ return copy_if(TrivialPredTensor{}, src_v, dst_v); -+ } else { -+#if 0 -+ if (thread0()) { -+ print("copy_vec -- not vectorizing, copy with %3db and %3db\n", int(8*sizeof(SrcType)), int(8*sizeof(DstType))); -+ print(" "); print(layout(src)); print("\n"); -+ print(" "); print(layout(dst)); print("\n"); -+ } -+#endif -+ -+ return copy_if(TrivialPredTensor{}, src, dst); -+ } -+} -+ -+// -+// copy -- auto-vectorizing copy -+// -+ -+template -+CUTE_HOST_DEVICE -+void -+copy(Tensor const& src, -+ Tensor & dst) -+{ -+ constexpr int N = decltype(max_common_vector(src, dst))::value; -+ -+#if 0 -+ if (thread0()) { -+ print("copy -- found a max_common_vector of %d\n", N); -+ print(" 
"); print(src.data()); print(" o "); print(layout(src)); print("\n"); -+ print(" "); print(dst.data()); print(" o "); print(layout(dst)); print("\n"); -+ } -+#endif -+ -+ if constexpr (N <= 1) { -+ return copy_if(TrivialPredTensor{}, src, dst); -+ } else { -+ constexpr int vec_bits = N * sizeof_bits::value; -+ using VecType = uint_bit_t; -+ return copy_vec(src, dst); -+ } -+} -+ -+// -+// copy -- CopyAtom -+// -+ -+template -+CUTE_HOST_DEVICE -+void -+copy(Copy_Atom const& copy_atom, -+ Tensor const& src, -+ Tensor & dst) -+{ -+ return copy_if(copy_atom, TrivialPredTensor{}, src, dst); -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+copy(Copy_Atom const&, -+ Tensor const& src, -+ Tensor & dst) -+{ -+ return copy(src, dst); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/algorithm/fill.hpp b/3rdparty/cutlass/include/cute/algorithm/fill.hpp -new file mode 100644 -index 0000000..bc0c4ad ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/algorithm/fill.hpp -@@ -0,0 +1,87 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+namespace cute -+{ -+ -+// -+// Accept mutable temporaries -+// -+template -+CUTE_HOST_DEVICE -+void -+fill(Tensor&& tensor, T const& value) -+{ -+ return fill(tensor, value); -+} -+ -+namespace detail -+{ -+ -+// Prefer fill(tensor.data(), value), if possible -+template -+CUTE_HOST_DEVICE -+auto -+fill(Tensor& tensor, T const& value, prefer<1>) -+ -> decltype(fill(tensor.data(), value)) -+{ -+ fill(tensor.data(), value); -+} -+ -+// Default implementation -+template -+CUTE_HOST_DEVICE -+void -+fill(Tensor& tensor, T const& value, prefer<0>) -+{ -+ CUTE_UNROLL -+ for (int i = 0; i < size(tensor); ++i) { -+ tensor(i) = value; -+ } -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE -+void -+fill(Tensor& tensor, T const& value) -+{ -+ return detail::fill(tensor, value, prefer<1>{}); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/algorithm/functional.hpp b/3rdparty/cutlass/include/cute/algorithm/functional.hpp -new file mode 100644 -index 0000000..e66cd97 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/algorithm/functional.hpp -@@ -0,0 +1,198 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+/** C++14 extensions */ -+ -+namespace cute { -+ -+/**************/ -+/** Identity **/ -+/**************/ -+ -+struct identity { -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) operator()(T&& arg) const { -+ return std::forward(arg); -+ } -+}; -+ -+template -+struct constant_fn { -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) operator()(T&&...) const { -+ return r_; -+ } -+ R r_; -+}; -+ -+/***********/ -+/** Unary **/ -+/***********/ -+ -+#define CUTE_LEFT_UNARY_OP(NAME,OP) \ -+ struct NAME { \ -+ template \ -+ CUTE_HOST_DEVICE constexpr \ -+ decltype(auto) operator()(T&& arg) const { \ -+ return OP std::forward(arg); \ -+ } \ -+ } -+#define CUTE_RIGHT_UNARY_OP(NAME,OP) \ -+ struct NAME { \ -+ template \ -+ CUTE_HOST_DEVICE constexpr \ -+ decltype(auto) operator()(T&& arg) const { \ -+ return std::forward(arg) OP ; \ -+ } \ -+ } -+#define CUTE_NAMED_UNARY_OP(NAME,OP) \ -+ struct NAME { \ -+ template \ -+ CUTE_HOST_DEVICE constexpr \ -+ decltype(auto) operator()(T&& arg) const { \ -+ return OP (std::forward(arg)); \ -+ } \ -+ } -+ -+CUTE_LEFT_UNARY_OP(unary_plus, +); -+CUTE_LEFT_UNARY_OP(negate, -); -+CUTE_LEFT_UNARY_OP(bit_not, ~); -+CUTE_LEFT_UNARY_OP(logical_not, !); -+CUTE_LEFT_UNARY_OP(dereference, *); -+CUTE_LEFT_UNARY_OP(address_of, &); -+CUTE_LEFT_UNARY_OP(pre_increment, ++); -+CUTE_LEFT_UNARY_OP(pre_decrement, --); -+ -+CUTE_RIGHT_UNARY_OP(post_increment, ++); -+CUTE_RIGHT_UNARY_OP(post_decrement, --); -+ -+CUTE_NAMED_UNARY_OP(abs_fn, abs); -+CUTE_NAMED_UNARY_OP(conjugate, cute::conj); -+ -+#undef CUTE_LEFT_UNARY_OP -+#undef CUTE_RIGHT_UNARY_OP -+#undef CUTE_NAMED_UNARY_OP -+ -+/************/ -+/** Binary **/ -+/************/ -+ -+#define CUTE_BINARY_OP(NAME,OP) \ -+ struct NAME { \ -+ template \ -+ CUTE_HOST_DEVICE constexpr \ -+ decltype(auto) operator()(T&& lhs, U&& rhs) const { \ -+ return std::forward(lhs) OP std::forward(rhs); \ -+ } \ -+ } -+#define CUTE_NAMED_BINARY_OP(NAME,OP) \ -+ struct NAME { \ -+ template \ -+ CUTE_HOST_DEVICE constexpr \ -+ decltype(auto) operator()(T&& lhs, U&& rhs) const { \ -+ return OP (std::forward(lhs), std::forward(rhs)); \ -+ } \ -+ } -+ -+ -+CUTE_BINARY_OP(plus, +); -+CUTE_BINARY_OP(minus, -); -+CUTE_BINARY_OP(multiplies, *); -+CUTE_BINARY_OP(divides, /); -+CUTE_BINARY_OP(modulus, %); -+ -+CUTE_BINARY_OP(plus_assign, +=); -+CUTE_BINARY_OP(minus_assign, -=); -+CUTE_BINARY_OP(multiplies_assign, *=); -+CUTE_BINARY_OP(divides_assign, /=); -+CUTE_BINARY_OP(modulus_assign, %=); -+ -+CUTE_BINARY_OP(bit_and, &); -+CUTE_BINARY_OP(bit_or, |); -+CUTE_BINARY_OP(bit_xor, ^); -+CUTE_BINARY_OP(left_shift, <<); -+CUTE_BINARY_OP(right_shift, >>); -+ -+CUTE_BINARY_OP(bit_and_assign, &=); -+CUTE_BINARY_OP(bit_or_assign, |=); -+CUTE_BINARY_OP(bit_xor_assign, ^=); -+CUTE_BINARY_OP(left_shift_assign, <<=); -+CUTE_BINARY_OP(right_shift_assign, >>=); -+ -+CUTE_BINARY_OP(logical_and, &&); -+CUTE_BINARY_OP(logical_or, ||); -+ -+CUTE_BINARY_OP(equal_to, ==); -+CUTE_BINARY_OP(not_equal_to, !=); -+CUTE_BINARY_OP(greater, >); -+CUTE_BINARY_OP(less, <); -+CUTE_BINARY_OP(greater_equal, >=); -+CUTE_BINARY_OP(less_equal, <=); -+ -+CUTE_NAMED_BINARY_OP(max_fn, cute::max); -+CUTE_NAMED_BINARY_OP(min_fn, cute::min); -+ -+#undef CUTE_BINARY_OP -+#undef CUTE_NAMED_BINARY_OP -+ -+/**********/ -+/** Meta **/ -+/**********/ -+ -+template -+struct bound_fn { -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ 
decltype(auto) -+ operator()(T&& arg) { -+ return fn_(arg_, std::forward(arg)); -+ } -+ -+ Fn fn_; -+ Arg arg_; -+}; -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+bind(Fn const& fn, Arg const& arg) { -+ return bound_fn{fn, arg}; -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/algorithm/gemm.hpp b/3rdparty/cutlass/include/cute/algorithm/gemm.hpp -new file mode 100644 -index 0000000..6e2ce61 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/algorithm/gemm.hpp -@@ -0,0 +1,718 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+#include -+#include -+ -+/** The gemm algorithm takes four (or three) tensors and computes -+ * D += A * B + C -+ * It dispatches based on the number of modes each tensor has: -+ * -+ * 1. `(V) x (V) => (V)`. -+ * The element-wise product of vectors. Dispatches to FMA or MMA. -+ * 2. `(M) x (N) => (M,N)`. -+ * The outer product of vectors. Dispatches to [3] with new mode K=(1). -+ * 3. `(M,K) x (N,K) => (M,N)`. -+ * The product of matrices. Dispatches to [5] with MMA vector-mode V. -+ * 4. `(V,M) x (V,N) => (V,M,N)`. -+ * The batched outer product of vectors. Accounts for register reuse and dispatches to [1] for each (m,n). -+ * 5. `(V,M,K) x (V,N,K) => (V,M,N)`. -+ * The batched product of matrices. Dispatches to [4] for each (k). 
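 *
 * [Editorial illustration, not part of the original header] For example, given
 * register fragments partitioned by an MMA atom with shapes
 *     tCrA : (V,M,K),   tCrB : (V,N,K),   tCrC : (V,M,N)
 * the single call
 *     gemm(mma, tCrC, tCrA, tCrB, tCrC);
 * resolves to case [5], which loops over K and forwards each k-slice to case [4];
 * case [4] then issues one case-[1] MMA per (m,n) coordinate, traversing (m,n)
 * in a serpentine order to maximize register reuse.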
-+ */ -+ -+namespace cute -+{ -+ -+// -+// Three arguments to four -+// -+ -+template -+CUTE_HOST_DEVICE -+void -+gemm(Tensor const& A, -+ Tensor const& B, -+ Tensor & C) -+{ -+ return gemm(C, A, B, C); -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+gemm(MMA_Atom const& mma, -+ Tensor const& A, -+ Tensor const& B, -+ Tensor & C) -+{ -+ return gemm(mma, C, A, B, C); -+} -+ -+// -+// Accept mutable temporaries -+// -+ -+template -+CUTE_HOST_DEVICE -+void -+gemm(Tensor const& A, -+ Tensor const& B, -+ Tensor && C) -+{ -+ return gemm(C, A, B, C); -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+gemm(Tensor && D, -+ Tensor const& A, -+ Tensor const& B, -+ Tensor const& C) -+{ -+ return gemm(D, A, B, C); -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+gemm(MMA_Atom const& mma, -+ Tensor const& A, -+ Tensor const& B, -+ Tensor && C) -+{ -+ return gemm(mma, C, A, B, C); -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+gemm(MMA_Atom const& mma, -+ Tensor && D, -+ Tensor const& A, -+ Tensor const& B, -+ Tensor const& C) -+{ -+ return gemm(mma, D, A, B, C); -+} -+ -+// -+// Default MMA is UniversalFMA -+// -+ -+template -+CUTE_HOST_DEVICE -+void -+gemm(Tensor & D, -+ Tensor const& A, -+ Tensor const& B, -+ Tensor const& C) -+{ -+ using MMA = MMA_Atom::value_type, -+ typename Tensor::value_type, -+ typename Tensor::value_type, -+ typename Tensor::value_type>>; -+ -+ return gemm(MMA{}, D, A, B, C); -+} -+ -+// -+// Thread-Local Register-Memory GEMMs -+// -+ -+// Dispatch [1]: (V) x (V) => (V) -+template ::value && -+ ALayout::rank == 1 && is_rmem::value && -+ BLayout::rank == 1 && is_rmem::value && -+ CLayout::rank == 1 && is_rmem::value)> -+CUTE_HOST_DEVICE -+void -+gemm(MMA_Atom const& mma, -+ Tensor & D, // (V) Logical data -+ Tensor const& A, // (V) Logical data -+ Tensor const& B, // (V) Logical data -+ Tensor const& C) // (V) Logical data -+{ -+ // No static assertions on (V), MMA checks compatibility -+ mma.call(D, A, B, C); -+} -+ -+// Dispatch [2]: (M) x (N) => (M,N) -+template ::value && -+ ALayout::rank == 1 && is_rmem::value && -+ BLayout::rank == 1 && is_rmem::value && -+ CLayout::rank == 2 && is_rmem::value)> -+CUTE_HOST_DEVICE -+void -+gemm(MMA_Atom const& mma, -+ Tensor & D, // (M,N) Logical data -+ Tensor const& A, // (M) Logical data -+ Tensor const& B, // (N) Logical data -+ Tensor const& C) // (M,N) Logical data -+{ -+ CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM -+ CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN -+ CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D)); -+ -+ gemm(mma, -+ D, // (M,N) -+ make_tensor(A.data(), append<2>(A.layout())), // (M,1) -+ make_tensor(B.data(), append<2>(B.layout())), // (N,1) -+ C); // (M,N) -+} -+ -+// Dispatch [3]: (M,K) x (N,K) => (M,N) -+template ::value && -+ ALayout::rank == 2 && is_rmem::value && -+ BLayout::rank == 2 && is_rmem::value && -+ CLayout::rank == 2 && is_rmem::value)> -+CUTE_HOST_DEVICE -+void -+gemm(MMA_Atom const& mma, -+ Tensor & D, // (M,N) Logical data -+ Tensor const& A, // (M,K) Logical data -+ Tensor const& B, // (N,K) Logical data -+ Tensor const& C) // (M,N) Logical data -+{ -+ CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM -+ CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN -+ CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(B)); // AK == BK -+ CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D)); -+ -+ // Assert this is a 1-value MMA -+ CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutC_TV{}) == Int<1>{}); -+ 
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutA_TV{}) == Int<1>{}); -+ CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutB_TV{}) == Int<1>{}); -+ -+ gemm(mma, -+ make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N) -+ make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K) -+ make_tensor(B.data(), prepend<3>(B.layout())), // (1,N,K) -+ make_tensor(C.data(), prepend<3>(C.layout()))); // (1,M,N) -+} -+ -+// Dispatch [4]: (V,M) x (V,N) => (V,M,N) -+template ::value && -+ ALayout::rank == 2 && is_rmem::value && -+ BLayout::rank == 2 && is_rmem::value && -+ CLayout::rank == 3 && is_rmem::value)> -+CUTE_HOST_DEVICE -+void -+gemm(MMA_Atom const& mma, -+ Tensor & D, // (V,M,N) Logical data -+ Tensor const& A, // (V,M) Logical data -+ Tensor const& B, // (V,N) Logical data -+ Tensor const& C) // (V,M,N) Logical data -+{ -+ CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM -+ CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN -+ CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); -+ -+ // REGISTER .reuse OPTIMIZATIONS -+ -+ auto M = size<1>(A); -+ auto N = size<1>(B); -+ -+ // 64-bit traversal specialization -- serpentine path -+ if (size<0>(A) * sizeof(typename Tensor::value_type) == 8 && -+ size<0>(B) * sizeof(typename Tensor::value_type) == 8) -+ { -+#if 1 // NOTE: Must depend on the C-matrix order... (which we can test) -+ // Row-major iteration -+ CUTE_UNROLL -+ for (int m = 0; m < M; ++m) { -+ CUTE_UNROLL -+ for (int n = 0; n < N; ++n) { -+ int ns = (m & 1) ? N-1-n : n; // Serpentine coordinate -+ gemm(mma, D(_,m,ns), A(_,m), B(_,ns), C(_,m,ns)); -+ } -+ } -+#else -+ // Col-major iteration -+ CUTE_UNROLL -+ for (int n = 0; n < N; ++n) { -+ CUTE_UNROLL -+ for (int m = 0; m < M; ++m) { -+ int ms = (n & 1) ? M-1-m : m; // Serpentine coordinate -+ gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n)); -+ } -+ } -+#endif -+ } else -+ -+ // 32-bit traversal specialization -- kinked serpentine path -+ if (size<0>(A) * sizeof(typename Tensor::value_type) == 4 && -+ size<0>(B) * sizeof(typename Tensor::value_type) == 4) -+ { -+#if 1 // NOTE: Must depend on the C-matrix order... (which we can test) -+ // Row-major iteration -+ CUTE_UNROLL -+ for (int m = 0; m < M; m += 2) { -+ CUTE_UNROLL -+ for (int n = 0; n < N; ++n) { -+ int ns = (m & 2) ? N-1-n : n; -+ gemm(mma, D(_,m+0,ns), A(_,m+0), B(_,ns), C(_,m+0,ns)); -+ -+ if (m+1 < M) { -+ gemm(mma, D(_,m+1,ns), A(_,m+1), B(_,ns), C(_,m+1,ns)); -+ } -+ } -+ } -+#else -+ // Col-major iteration -+ CUTE_UNROLL -+ for (int n = 0; n < N; n += 2) { -+ CUTE_UNROLL -+ for (int m = 0; m < M; ++m) { -+ // Kinked serpentine traversal for maximum register reuse -+ int ms = (n & 2) ? M-1-m : m; -+ gemm(mma, D(_,ms,n+0), A(_,ms), B(_,n+0), C(_,ms,n+0)); -+ -+ if (n+1 < N) { -+ gemm(mma, D(_,ms,n+1), A(_,ms), B(_,n+1), C(_,ms,n+1)); -+ } -+ } -+ } -+#endif -+ } else { -+ // Fallback to serpentine loop -+ // Col-major iteration -+ CUTE_UNROLL -+ for (int n = 0; n < N; ++n) { -+ CUTE_UNROLL -+ for (int m = 0; m < M; ++m) { -+ int ms = (n & 1) ? 
M-1-m : m; // Serpentine coordinate -+ gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n)); -+ } -+ } -+ } -+} -+ -+// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) -+template ::value && -+ ALayout::rank == 3 && is_rmem::value && -+ BLayout::rank == 3 && is_rmem::value && -+ CLayout::rank == 3 && is_rmem::value)> -+CUTE_HOST_DEVICE -+void -+gemm(MMA_Atom const& mma, -+ Tensor & D, // (V,M,N) Logical data -+ Tensor const& A, // (V,M,K) Logical data -+ Tensor const& B, // (V,N,K) Logical data -+ Tensor const& C) // (V,M,N) Logical data -+{ -+ CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM -+ CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN -+ CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B)); // AK == BK -+ CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); -+ -+ auto K = size<2>(A); -+ -+ CUTE_UNROLL -+ for (int k = 0; k < K; ++k) { -+ gemm(mma, D, A(_,_,k), B(_,_,k), C); -+ } -+} -+ -+// -+// Thread-Local Shared-Memory GEMMs -+// -+ -+// Dispatch [1]: (V) x (V) => (V) -+// Dispatch [2]: (M) x (N) => (M,N) -+// Dispatch [3]: (M,K) x (N,K) => (M,N) -+// Dispatch [4]: (V,M) x (V,N) => (V,M,N) -+// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) -+// Dispatch [3]: (M,K) x (N,K) => (M,N) -+template ::value && -+ ALayout::rank == 2 && is_smem::value && -+ BLayout::rank == 2 && is_smem::value && -+ CLayout::rank == 2 && is_rmem::value)> -+CUTE_HOST_DEVICE -+void -+gemm(MMA_Atom const& mma, -+ Tensor & D, // (M,N) Logical data -+ Tensor const& A, // (M,K) Logical data -+ Tensor const& B, // (N,K) Logical data -+ Tensor const& C) // (M,N) Logical data -+{ -+ CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM -+ CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN -+ CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(B)); // AK == BK -+ CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D)); -+ -+ // Assert this is a 1-value MMA -+ CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutC_TV{}) == Int<1>{}); -+ CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutA_TV{}) == Int<1>{}); -+ CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutB_TV{}) == Int<1>{}); -+ -+ gemm(mma, -+ make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N) -+ make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K) -+ make_tensor(B.data(), prepend<3>(B.layout())), // (1,N,K) -+ make_tensor(C.data(), prepend<3>(C.layout()))); // (1,M,N) -+} -+ -+// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) -+template ::value && -+ ALayout::rank == 3 && is_smem::value && -+ BLayout::rank == 3 && is_smem::value && -+ CLayout::rank == 3 && is_rmem::value)> -+CUTE_HOST_DEVICE -+void -+gemm(MMA_Atom const& mma, -+ Tensor & D, // (V,M,N) Logical data -+ Tensor const& A, // (V,M,K) Logical data -+ Tensor const& B, // (V,N,K) Logical data -+ Tensor const& C) // (V,M,N) Logical data -+{ -+ CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM -+ CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN -+ CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B)); // AK == BK -+ CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); -+ -+ auto rA = MMA_Atom::make_fragment_A(A); -+ auto rB = MMA_Atom::make_fragment_B(B); -+ -+ auto K = size<2>(A); -+ -+ CUTE_UNROLL -+ for (int k = 0; k < K; ++k) -+ { -+ copy(A(_,_,k), rA(_,_,k)); -+ copy(B(_,_,k), rB(_,_,k)); -+ // Thread-level register gemm for k -+ gemm(mma, D, rA(_,_,k), rB(_,_,k), C); -+ } -+} -+ -+// -+// Collective Shared-Memory GEMMs -+// -+ -+template 
::value && -+ BLayout::rank == 2 && is_smem::value && -+ CLayout::rank == 2 && is_smem::value)> -+CUTE_HOST_DEVICE -+void -+gemm(ThrMMA const& thr_mma, -+ Alpha const& alpha, -+ Tensor sA, -+ Tensor sB, -+ Beta const& beta, -+ Tensor sC, -+ ALoadTransformOp const& sA_load_op /* transforms A values before used in GEMM */, -+ BLoadTransformOp const& sB_load_op /* transforms B values before used in GEMM */) -+{ -+ CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM -+ CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN -+ CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK -+ -+ using TypeA = typename TA::value_type; -+ using TypeB = typename TB::value_type; -+ using TypeC = typename TC::value_type; -+ -+ static_assert(std::is_same_v>, TypeA>, -+ "ALoadTransformOp functor must accept and return value of type TA::value_type"); -+ static_assert(std::is_same_v>, TypeB>, -+ "BLoadTransformOp functor must accept and return value of type TB::value_type"); -+ -+ // Original, static size of the problem -+ auto M = size<0>(sC); -+ auto N = size<1>(sC); -+ auto K = size<1>(sA); -+ -+ // Block size of the compute tile -+ auto BLK_M = tile_size<0>(thr_mma); -+ auto BLK_N = tile_size<1>(thr_mma); -+ auto BLK_K = tile_size<2>(thr_mma); -+ -+ // Compute the "residues" -+ auto m_residue = M - BLK_M * (ceil_div(M, BLK_M) - Int<1>{}); // (0,BLK_M] -+ auto n_residue = N - BLK_N * (ceil_div(N, BLK_N) - Int<1>{}); // (0,BLK_N] -+ auto k_residue = K - BLK_K * (ceil_div(K, BLK_K) ); // (-BLK_K,0] -+ -+ // Shift the origin so k_residue is zeroth tile -+ sA.data() = &sA(0,k_residue); -+ sB.data() = &sB(0,k_residue); -+ -+#if 0 -+ if (thread0()) { -+ printf("%d in BLK_M (%d)\n", int(m_residue), int(BLK_M)); -+ printf("%d in BLK_N (%d)\n", int(n_residue), int(BLK_N)); -+ printf("%d in BLK_K (%d)\n", int(k_residue), int(BLK_K)); -+ } -+#endif -+ -+ // -+ // MMA Partitioning -+ // -+ -+ // Round the layout extents up to BLK_X -+ Tensor rounded_sA = sA.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(K, BLK_K) * BLK_K)); -+ Tensor rounded_sB = sB.compose(make_shape(ceil_div(N, BLK_N) * BLK_N, ceil_div(K, BLK_K) * BLK_K)); -+ Tensor rounded_sC = sC.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(N, BLK_N) * BLK_N)); -+ -+#if 0 -+ if (thread0()) { -+ print(rounded_sA.layout()); print("\n"); -+ print(rounded_sB.layout()); print("\n"); -+ print(rounded_sC.layout()); print("\n"); -+ } -+#endif -+ -+ // Partition the sA and sB tiles across the threads for the MMA -+ Tensor tCsA = thr_mma.partition_A(rounded_sA); // (MMA,MMA_M,MMA_K) -+ Tensor tCsB = thr_mma.partition_B(rounded_sB); // (MMA,MMA_N,MMA_K) -+ Tensor tCsC = thr_mma.partition_C(rounded_sC); // (MMA,MMA_M,MMA_N) -+ // Create register tensors for the MMA to operate on -+ Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K) -+ Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K) -+ Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N) -+ -+#if 0 -+ if (thread0()) { -+ print(tCsA.layout()); print("\n"); -+ print(tCsB.layout()); print("\n"); -+ print(tCsC.layout()); print("\n"); -+ print(tCrA.layout()); print("\n"); -+ print(tCrB.layout()); print("\n"); -+ print(tCrC.layout()); print("\n"); -+ } -+#endif -+ -+ // -+ // PREDICATION -+ // -+ -+ // Allocate the preds for only the MMA-mode of tCsA and tCsB -+ Tensor tCpA = make_tensor(size<0>(tCsA)); -+ Tensor tCpB = make_tensor(size<0>(tCsB)); -+ -+ // Create coordinate tensors on a single compute block for predication -+ Tensor cA = 
make_identity_tensor(make_shape(BLK_M, BLK_K)); // (BLK_M,BLK_K) -> (blk_m,blk_k) -+ Tensor cB = make_identity_tensor(make_shape(BLK_N, BLK_K)); // (BLK_M,BLK_K) -> (blk_n,blk_k) -+ -+ // Repeat partitioning with thr_mma -+ Tensor tCcA = thr_mma.partition_A(cA); // (MMA,1,1) -> (blk_m,blk_k) -+ Tensor tCcB = thr_mma.partition_B(cB); // (MMA,1,1) -> (blk_n,blk_k) -+ -+ // Populate the m and n predicates -+ CUTE_UNROLL -+ for (int i = 0; i < size(tCpA); ++i) { -+ tCpA(i) = elem_less(get<0>(tCcA(i)), m_residue); -+ } -+ CUTE_UNROLL -+ for (int i = 0; i < size(tCpB); ++i) { -+ tCpB(i) = elem_less(get<0>(tCcB(i)), n_residue); -+ } -+ -+#if 0 -+ printf("Thr %d: A(%d,%d):%d B(%d,%d):%d\n", -+ threadIdx.x, -+ int(get<0>(tCcA(0))), int(get<1>(tCcA(0))), int(tCpA(0)), -+ int(get<0>(tCcB(0))), int(get<1>(tCcB(0))), int(tCpB(0))); -+#endif -+ -+ // -+ // PREFETCH k_block = 0 (with k-predication) -+ // -+ -+ CUTE_UNROLL -+ for (int i = 0; i < size<0>(tCsA); ++i) { // Copy MMA_I -+ if (k_residue == 0 || get<1>(tCcA(i)) >= -k_residue) { // k_block = 0, predicated on k -+ CUTE_UNROLL -+ for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M, predicated on m -+ tCrA(i,m,0) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,0)) : TypeA{}; -+ } -+ } -+ } -+ -+ CUTE_UNROLL -+ for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I -+ if (k_residue == 0 || get<1>(tCcB(i)) >= -k_residue) { // k_block = 0, predicated on k -+ CUTE_UNROLL -+ for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N, predicated on n -+ tCrB(i,n,0) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,0)) : TypeB{}; -+ } -+ } -+ } -+ // -+ // MAINLOOP -+ // -+ -+ // Clear accumulators -+ clear(tCrC); -+ -+ constexpr int K_BLOCK_MAX = size<2>(tCrA); -+ -+ CUTE_UNROLL -+ for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) -+ { -+ // static-if load the next k_block. No k-predication required on these loads. -+ if (k_block < K_BLOCK_MAX-1) -+ { -+ // Load the next k_block -+ int k_next = k_block + 1; -+ -+ CUTE_UNROLL -+ for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M -+ CUTE_UNROLL -+ for (int i = 0; i < size<0>(tCsA); ++i) { // Copy_if MMA_I predicated on m -+ tCrA(i,m,k_next) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{}; -+ } -+ } -+ -+ CUTE_UNROLL -+ for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N -+ CUTE_UNROLL -+ for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I predicated on n -+ tCrB(i,n,k_next) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{}; -+ } -+ } -+ } -+ -+ // GEMM on k_block in registers -+ gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); -+ } -+ -+ // -+ // Epilogue -+ // -+ -+ Tensor cC = make_identity_tensor(make_shape(BLK_M, BLK_N)); // (BLK_M,BLK_N) -> (blk_m,blk_n) -+ Tensor tCcC = thr_mma.partition_C(cC); // (MMA, 1, 1) -> (blk_m,blk_n) -+ -+ const bool isBetaZero = (beta == Beta{}); -+ -+ // Custom axpby_if for now -+ CUTE_UNROLL -+ for (int m = 0; m < size<1>(tCsC); ++m) -+ { -+ CUTE_UNROLL -+ for (int n = 0; n < size<2>(tCsC); ++n) -+ { -+ CUTE_UNROLL -+ for (int i = 0; i < size<0>(tCsC); ++i) -+ { -+ if ((m_residue == BLK_M || m < size<1>(tCrC)-1 || get<0>(tCcC(i)) < m_residue) && -+ (n_residue == BLK_N || n < size<2>(tCrC)-1 || get<1>(tCcC(i)) < n_residue)) -+ { -+ tCsC(i,m,n) = isBetaZero ? 
alpha * tCrC(i,m,n) : alpha * tCrC(i,m,n) + beta * tCsC(i,m,n); -+ } -+ } -+ } -+ } -+} -+ -+template ::value && -+ BLayout::rank == 2 && is_smem::value && -+ CLayout::rank == 2 && is_smem::value)> -+CUTE_HOST_DEVICE -+void -+gemm(ThrMMA const& thr_mma, -+ Alpha const& alpha, -+ Tensor sA, -+ Tensor sB, -+ Beta const& beta, -+ Tensor sC) -+{ -+ gemm(thr_mma, alpha, sA, sB, beta, sC, identity() /* sA_load_op */, identity() /* sB_load_op */); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/algorithm/prefer.hpp b/3rdparty/cutlass/include/cute/algorithm/prefer.hpp -new file mode 100644 -index 0000000..700edff ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/algorithm/prefer.hpp -@@ -0,0 +1,46 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+namespace cute -+{ -+ -+// Infinite types that inherit from each other -+template -+struct prefer : prefer {}; -+ -+template <> -+struct prefer<0> {}; -+ -+// Can be used to preferencially overload implementations -+// Higher N in prefer have higher priority. -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/algorithm/tensor_algorithms.hpp b/3rdparty/cutlass/include/cute/algorithm/tensor_algorithms.hpp -new file mode 100644 -index 0000000..258ddec ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/algorithm/tensor_algorithms.hpp -@@ -0,0 +1,102 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
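For context, the prefer<N> tag shown above enables a standard overload-ranking idiom; a minimal sketch with hypothetical local names (not part of the patch, and not cute's implementation):

#include <cstddef>
#include <vector>

// An exact match on priority_tag<1> outranks the derived-to-base conversion
// needed to reach the priority_tag<0> fallback, so the .size() overload wins
// whenever it is well-formed.
template <int N> struct priority_tag : priority_tag<N - 1> {};
template <>      struct priority_tag<0> {};

template <class T>
auto size_of(T const& t, priority_tag<1>) -> decltype(t.size()) { return t.size(); }

template <class T>
std::size_t size_of(T const&, priority_tag<0>) { return 0; }   // fallback

template <class T>
std::size_t size_of(T const& t) { return size_of(t, priority_tag<1>{}); }

// size_of(std::vector<int>{1, 2, 3}) == 3;  size_of(42) == 0 (fallback)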
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/** Common algorithms on (hierarchical) tensors */ -+ -+#pragma once -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+// -+// for_each -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+void -+for_each(Tensor const& tensor, UnaryOp&& op) -+{ -+ CUTE_UNROLL -+ for (int i = 0; i < size(tensor); ++i) { -+ static_cast(op)(tensor(i)); -+ } -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+void -+for_each(Tensor& tensor, UnaryOp&& op) -+{ -+ CUTE_UNROLL -+ for (int i = 0; i < size(tensor); ++i) { -+ static_cast(op)(tensor(i)); -+ } -+} -+ -+// Accept mutable temporaries -+template -+CUTE_HOST_DEVICE constexpr -+void -+for_each(Tensor&& tensor, UnaryOp&& op) -+{ -+ return for_each(tensor, static_cast(op)); -+} -+ -+// -+// transform -+// -+ -+// Similar to std::transform but does not return number of elements affected -+template -+CUTE_HOST_DEVICE constexpr -+void -+transform(Tensor& tensor, UnaryOp&& op) -+{ -+ CUTE_UNROLL -+ for (int i = 0; i < size(tensor); ++i) { -+ tensor(i) = static_cast(op)(tensor(i)); -+ } -+} -+ -+// Accept mutable temporaries -+template -+CUTE_HOST_DEVICE constexpr -+void -+transform(Tensor&& tensor, UnaryOp&& op) -+{ -+ return transform(tensor, std::forward(op)); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/algorithm/tuple_algorithms.hpp b/3rdparty/cutlass/include/cute/algorithm/tuple_algorithms.hpp -new file mode 100644 -index 0000000..35b19f9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/algorithm/tuple_algorithms.hpp -@@ -0,0 +1,846 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+#include -+#include -+#include -+ -+/** Common algorithms on (hierarchical) tuples */ -+/** Style choice: -+ * Forward params [using static_cast(.)] for const/non-const/ref/non-ref args -+ * but don't bother forwarding functions as ref-qualified member fns are extremely rare -+ */ -+ -+namespace cute -+{ -+ -+// -+// Apply (Unpack) -+// (t, f) => f(t_0,t_1,...,t_n) -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+apply(T&& t, F&& f, seq) -+{ -+ return f(get(static_cast(t))...); -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+apply(T&& t, F&& f) -+{ -+ return detail::apply(static_cast(t), f, tuple_seq{}); -+} -+ -+// -+// Transform Apply -+// (t, f, g) => g(f(t_0),f(t_1),...) 
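For intuition only, the same unpack-transform-apply shape can be sketched on std::tuple (illustrative analogue with local names; cute's version additionally works in device code and on its own tuple type):

#include <tuple>
#include <utility>

// g receives the pack produced by applying f to every element of t.
template <class Tuple, class F, class G>
constexpr auto transform_apply_sketch(Tuple const& t, F&& f, G&& g)
{
  return std::apply([&](auto const&... a) { return g(f(a)...); }, t);
}

// Example: sum of squares over a heterogeneous tuple.
// transform_apply_sketch(std::make_tuple(1, 2, 3.0),
//                        [](auto x) { return x * x; },
//                        [](auto... sq) { return (sq + ...); });   // == 14.0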
-+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tapply(T&& t, F&& f, G&& g, seq) -+{ -+ return g(f(get(static_cast(t)))...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tapply(T0&& t0, T1&& t1, F&& f, G&& g, seq) -+{ -+ return g(f(get(static_cast(t0)), -+ get(static_cast(t1)))...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tapply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g, seq) -+{ -+ return g(f(get(static_cast(t0)), -+ get(static_cast(t1)), -+ get(static_cast(t2)))...); -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transform_apply(T&& t, F&& f, G&& g) -+{ -+ return detail::tapply(static_cast(t), f, g, tuple_seq{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transform_apply(T0&& t0, T1&& t1, F&& f, G&& g) -+{ -+ return detail::tapply(static_cast(t0), static_cast(t1), f, g, tuple_seq{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transform_apply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g) -+{ -+ return detail::tapply(static_cast(t0), static_cast(t1), static_cast(t2), f, g, tuple_seq{}); -+} -+ -+// -+// For Each -+// (t, f) => f(t_0),f(t_1),...,f(t_n) -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+void -+for_each(T&& t, F&& f) -+{ -+ detail::apply(t, [&](auto&&... a) { (f(static_cast(a)), ...); }, tuple_seq{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+for_each_leaf(T&& t, F&& f) -+{ -+ if constexpr (is_tuple>::value) { -+ return detail::apply(static_cast(t), [&](auto&&... a){ return (for_each_leaf(static_cast(a), f), ...); }, tuple_seq{}); -+ } else { -+ return f(static_cast(t)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Transform -+// (t, f) => (f(t_0),f(t_1),...,f(t_n)) -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transform(T const& t, F&& f) -+{ -+ return detail::tapply(t, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transform(T0 const& t0, T1 const& t1, F&& f) -+{ -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); -+ return detail::tapply(t0, t1, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transform(T0 const& t0, T1 const& t1, T2 const& t2, F&& f) -+{ -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); -+ return detail::tapply(t0, t1, t2, f, [](auto const&... 
a){ return cute::make_tuple(a...); }, tuple_seq{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transform_leaf(T const& t, F&& f) -+{ -+ if constexpr (is_tuple::value) { -+ return transform(t, [&](auto const& a) { return transform_leaf(a, f); }); -+ } else { -+ return f(t); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// find and find_if -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+find_if(T const& t, F&& f, seq<>) -+{ -+ return cute::integral_constant::value>{}; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+find_if(T const& t, F&& f, seq) -+{ -+ if constexpr (decltype(f(get(t)))::value) { -+ return cute::integral_constant{}; -+ } else { -+ return find_if(t, f, seq{}); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+find_if(T const& t, F&& f) -+{ -+ if constexpr (is_tuple::value) { -+ return detail::find_if(t, f, tuple_seq{}); -+ } else { -+ return cute::integral_constant{}; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+find(T const& t, X const& x) -+{ -+ return find_if(t, [&](auto const& v) { return v == x; }); // This should always return a static true/false -+} -+ -+template -+auto -+none_of(T const& t, F&& f) -+{ -+ return cute::integral_constant::value>{}; -+} -+ -+template -+auto -+all_of(T const& t, F&& f) -+{ -+ auto not_f = [&](auto const& a) { return !f(a); }; -+ return cute::integral_constant::value>{}; -+} -+ -+template -+auto -+any_of(T const& t, F&& f) -+{ -+ return cute::integral_constant{}; -+} -+ -+// -+// Filter -+// (t, f) => -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+filter_tuple(T const& t, F&& f) -+{ -+ return transform_apply(t, f, [](auto const&... a) { return cute::tuple_cat(a...); }); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+filter_tuple(T0 const& t0, T1 const& t1, F&& f) -+{ -+ return transform_apply(t0, t1, f, [](auto const&... 
a) { return cute::tuple_cat(a...); }); -+} -+ -+// -+// Fold (Reduce, Accumulate) -+// (t, v, f) => f(...f(f(v,t_0),t_1),...,t_n) -+// -+ -+namespace detail { -+ -+// This impl compiles much faster than cute::apply and variadic args -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+fold(T&& t, V&& v, F&& f, seq<>) -+{ -+ return static_cast(v); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+fold(T&& t, V&& v, F&& f, seq) -+{ -+ if constexpr (sizeof...(Is) == 0) { -+ return f(static_cast(v), get(static_cast(t))); -+ } else { -+ return fold(static_cast(t), -+ f(static_cast(v), get(static_cast(t))), -+ f, -+ seq{}); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+fold(T&& t, V&& v, F&& f) -+{ -+ if constexpr (is_tuple>::value) { -+ return detail::fold(static_cast(t), -+ static_cast(v), -+ f, -+ tuple_seq{}); -+ } else { -+ return f(static_cast(v), static_cast(t)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+fold_first(T&& t, F&& f) -+{ -+ if constexpr (is_tuple>::value) { -+ return detail::fold(static_cast(t), -+ get<0>(static_cast(t)), -+ f, -+ make_range<1,std::tuple_size>::value>{}); -+ } else { -+ return static_cast(t); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// front, back, take, unwrap -+// -+ -+// Get the first non-tuple element in a hierarchical tuple -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+front(T&& t) -+{ -+ if constexpr (is_tuple>::value) { -+ return front(get<0>(static_cast(t))); -+ } else { -+ return static_cast(t); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// Get the last non-tuple element in a hierarchical tuple -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+back(T&& t) -+{ -+ if constexpr (is_tuple>::value) { -+ constexpr int N = tuple_size>::value; -+ return back(get(static_cast(t))); -+ } else { -+ return static_cast(t); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// Takes the elements in the range [B,E) -+template -+CUTE_HOST_DEVICE constexpr -+auto -+take(T const& t) -+{ -+ return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range{}); -+} -+ -+// Unwrap rank-1 tuples until we're left with a rank>1 tuple or a non-tuple -+template -+CUTE_HOST_DEVICE constexpr -+auto -+unwrap(T const& t) -+{ -+ if constexpr (is_tuple::value) { -+ if constexpr (tuple_size::value == 1) { -+ return unwrap(get<0>(t)); -+ } else { -+ return t; -+ } -+ } else { -+ return t; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Flatten a hierarchical tuple to a tuple of depth one. 
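For intuition, a flatten of this shape can be sketched on std::tuple (hypothetical helper names, not the cute implementation):

#include <tuple>
#include <utility>

// Non-tuple leaf -> rank-1 tuple.
template <class T>
constexpr auto flatten_sketch(T const& leaf)
{
  return std::make_tuple(leaf);
}

// Tuple -> concatenation of its flattened children.
template <class... Ts>
constexpr auto flatten_sketch(std::tuple<Ts...> const& t)
{
  return std::apply([](auto const&... a) { return std::tuple_cat(flatten_sketch(a)...); }, t);
}

// flatten_sketch(std::make_tuple(1, std::make_tuple(2, std::make_tuple(3)), 4))
//   == std::make_tuple(1, 2, 3, 4)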
-+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+flatten_to_tuple(T const& t) -+{ -+ if constexpr (is_tuple::value) { -+ return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); -+ } else { -+ return cute::make_tuple(t); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+flatten(T const& t) -+{ -+ if constexpr (is_tuple::value) { -+ return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); -+ } else { -+ return t; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// insert and remove and replace -+// -+ -+namespace detail { -+ -+// Shortcut around tuple_cat for common insert/remove/repeat cases -+template -+CUTE_HOST_DEVICE constexpr -+auto -+construct(T const& t, X const& x, seq, seq, seq) -+{ -+ return cute::make_tuple(get(t)..., (void(J),x)..., get(t)...); -+} -+ -+} // end namespace detail -+ -+// Insert x into the Nth position of the tuple -+template -+CUTE_HOST_DEVICE constexpr -+auto -+insert(T const& t, X const& x) -+{ -+ return detail::construct(t, x, make_seq{}, seq<0>{}, make_range::value>{}); -+} -+ -+// Remove the Nth element of the tuple -+template -+CUTE_HOST_DEVICE constexpr -+auto -+remove(T const& t) -+{ -+ return detail::construct(t, 0, make_seq{}, seq<>{}, make_range::value>{}); -+} -+ -+// Replace the Nth element of the tuple with x -+template -+CUTE_HOST_DEVICE constexpr -+auto -+replace(T const& t, X const& x) -+{ -+ return detail::construct(t, x, make_seq{}, seq<0>{}, make_range::value>{}); -+} -+ -+// Replace the first element of the tuple with x -+template -+CUTE_HOST_DEVICE constexpr -+auto -+replace_front(T const& t, X const& x) -+{ -+ if constexpr (is_tuple::value) { -+ return detail::construct(t, x, seq<>{}, seq<0>{}, make_range<1,tuple_size::value>{}); -+ } else { -+ return x; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// Replace the last element of the tuple with x -+template -+CUTE_HOST_DEVICE constexpr -+auto -+replace_back(T const& t, X const& x) -+{ -+ if constexpr (is_tuple::value) { -+ return detail::construct(t, x, make_seq::value-1>{}, seq<0>{}, seq<>{}); -+ } else { -+ return x; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Make a tuple of Xs of tuple_size N -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+repeat(X const& x) -+{ -+ return detail::construct(0, x, seq<>{}, make_seq{}, seq<>{}); -+} -+ -+// -+// Make a tuple of Xs the same profile as tuple -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+repeat_like(T const& t, X const& x) -+{ -+ if constexpr (is_tuple::value) { -+ return transform(t, [&](auto const& a) { return repeat_like(a,x); }); -+ } else { -+ return x; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// Group the elements [B,E) of a T into a single element -+// e.g. 
group<2,4>(T<_1,_2,_3,_4,_5,_6>{}) -+// => T<_1,_2,T<_3,_4>,_5,_6>{} -+template -+CUTE_HOST_DEVICE constexpr -+auto -+group(T const& t) -+{ -+ return detail::construct(t, take(t), make_seq{}, seq<0>{}, make_range::value>{}); -+} -+ -+// -+// Extend a T to rank N by appending/prepending an element -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+append(T const& a, X const& x) -+{ -+ if constexpr (is_tuple::value) { -+ if constexpr (N == tuple_size::value) { -+ return a; -+ } else { -+ static_assert(N > tuple_size::value); -+ return detail::construct(a, x, make_seq::value>{}, make_seq::value>{}, seq<>{}); -+ } -+ } else { -+ if constexpr (N == 1) { -+ return a; -+ } else { -+ return detail::construct(cute::make_tuple(a), x, seq<0>{}, make_seq{}, seq<>{}); -+ } -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+template -+CUTE_HOST_DEVICE constexpr -+auto -+append(T const& a, X const& x) -+{ -+ if constexpr (is_tuple::value) { -+ return detail::construct(a, x, make_seq::value>{}, seq<0>{}, seq<>{}); -+ } else { -+ return cute::make_tuple(a, x); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+prepend(T const& a, X const& x) -+{ -+ if constexpr (is_tuple::value) { -+ if constexpr (N == tuple_size::value) { -+ return a; -+ } else { -+ static_assert(N > tuple_size::value); -+ return detail::construct(a, x, seq<>{}, make_seq::value>{}, make_seq::value>{}); -+ } -+ } else { -+ if constexpr (N == 1) { -+ return a; -+ } else { -+ static_assert(N > 1); -+ return detail::construct(cute::make_tuple(a), x, seq<>{}, make_seq{}, seq<0>{}); -+ } -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+template -+CUTE_HOST_DEVICE constexpr -+auto -+prepend(T const& a, X const& x) -+{ -+ if constexpr (is_tuple::value) { -+ return detail::construct(a, x, seq<>{}, seq<0>{}, make_seq::value>{}); -+ } else { -+ return cute::make_tuple(x, a); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Inclusive scan (prefix sum) -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+iscan(T const& t, V const& v, F&& f, seq) -+{ -+ // Apply the function to v and the element at I -+ auto v_next = f(v, get(t)); -+ // Replace I with v_next -+ auto t_next = replace(t, v_next); -+ -+#if 0 -+ std::cout << "ISCAN i" << I << std::endl; -+ std::cout << " t " << t << std::endl; -+ std::cout << " i " << v << std::endl; -+ std::cout << " f(i,t) " << v_next << std::endl; -+ std::cout << " t_n " << t_next << std::endl; -+#endif -+ -+ if constexpr (sizeof...(Is) == 0) { -+ return t_next; -+ } else { -+ return iscan(t_next, v_next, f, seq{}); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+iscan(T const& t, V const& v, F&& f) -+{ -+ return detail::iscan(t, v, f, tuple_seq{}); -+} -+ -+// -+// Exclusive scan (prefix sum) -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+escan(T const& t, V const& v, F&& f, seq) -+{ -+ if constexpr (sizeof...(Is) == 0) { -+ // Replace I with v -+ return replace(t, v); -+ } else { -+ // Apply the function to v and the element at I -+ auto v_next = f(v, get(t)); -+ // Replace I with v -+ auto t_next = replace(t, v); -+ -+#if 0 -+ std::cout << "ESCAN i" << I << std::endl; -+ std::cout << " t " << t << std::endl; -+ std::cout << " i " << v << std::endl; -+ std::cout << " f(i,t) " << v_next << std::endl; -+ std::cout << " t_n " << t_next << std::endl; -+#endif -+ -+ // Recurse -+ return escan(t_next, v_next, f, seq{}); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end 
namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+escan(T const& t, V const& v, F&& f) -+{ -+ return detail::escan(t, v, f, tuple_seq{}); -+} -+ -+// -+// Zip (Transpose) -+// -+ -+// Take ((a,b,c,...),(x,y,z,...),...) rank-R0 x rank-R1 input -+// to produce ((a,x,...),(b,y,...),(c,z,...),...) rank-R1 x rank-R0 output -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zip_(T const& t, seq) -+{ -+ return cute::make_tuple(get(get(t))...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zip(T const& t, seq, seq) -+{ -+ static_assert(conjunction>::value == tuple_size>::value>...>::value, "Mismatched Ranks"); -+ return cute::make_tuple(detail::zip_(t, seq{})...); -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zip(T const& t) -+{ -+ if constexpr (is_tuple::value) { -+ if constexpr (is_tuple>::value) { -+ return detail::zip(t, tuple_seq{}, tuple_seq>{}); -+ } else { -+ return cute::make_tuple(t); -+ } -+ } else { -+ return t; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// Convenient to pass them in separately -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zip(T0 const& t0, T1 const& t1, Ts const&... ts) -+{ -+ return zip(cute::make_tuple(t0, t1, ts...)); -+} -+ -+// -+// zip2_by -- A guided zip for rank-2 tuples -+// Take a tuple like ((A,a),((B,b),(C,c)),d) -+// and produce a tuple ((A,(B,C)),(a,(b,c),d)) -+// where the rank-2 modes are selected by the terminals of the guide (X,(X,X)) -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zip2_by(T const& t, TG const& guide, seq, seq) -+{ -+ // zip2_by produces the modes like ((A,a),(B,b),...) -+ auto split = cute::make_tuple(zip2_by(get(t), get(guide))...); -+ -+ // Rearrange and append missing modes from t to make ((A,B,...),(a,b,...,x,y)) -+ return cute::make_tuple(cute::make_tuple(get(split)...), -+ cute::make_tuple(get(split)..., get(t)...)); -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zip2_by(T const& t, TG const& guide) -+{ -+ if constexpr (is_tuple::value) { -+ constexpr int TR = tuple_size::value; -+ constexpr int GR = tuple_size::value; -+ static_assert(TR >= GR, "Mismatched ranks"); -+ return detail::zip2_by(t, guide, -+ make_range< 0, GR>{}, -+ make_range{}); -+ } else { -+ static_assert(tuple_size::value == 2, "Mismatched ranks"); -+ return t; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/cluster_sm90.hpp b/3rdparty/cutlass/include/cute/arch/cluster_sm90.hpp -new file mode 100644 -index 0000000..6fd9edd ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/cluster_sm90.hpp -@@ -0,0 +1,190 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+// Config -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \ -+ ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)))) -+# define CUTE_ARCH_CLUSTER_SM90_ENABLED -+#endif -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) -+# define CUTE_ARCH_ELECT_ONE_SM90_ENABLED -+#endif -+ -+namespace cute { -+ -+CUTE_DEVICE void cluster_arrive_relaxed() -+{ -+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) -+ asm volatile("barrier.cluster.arrive.relaxed.aligned;\n" : : ); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+} -+ -+CUTE_DEVICE void cluster_arrive() -+{ -+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) -+ asm volatile("barrier.cluster.arrive.aligned;\n" : : ); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+} -+ -+CUTE_DEVICE void cluster_wait() -+{ -+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) -+ asm volatile("barrier.cluster.wait.aligned;\n" : : ); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+} -+ -+CUTE_DEVICE void cluster_sync() -+{ -+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) -+ cluster_arrive(); -+ cluster_wait(); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+} -+ -+// Returns the dim3 grid size in terms of number of clusters. -+CUTE_DEVICE dim3 cluster_grid_dims() -+{ -+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) -+ uint32_t x, y, z; -+ asm volatile("mov.u32 %0, %nclusterid.x;\n" : "=r"(x) : ); -+ asm volatile("mov.u32 %0, %nclusterid.y;\n" : "=r"(y) : ); -+ asm volatile("mov.u32 %0, %nclusterid.z;\n" : "=r"(z) : ); -+ return {x, y, z}; -+#else -+ return gridDim; -+#endif -+} -+ -+// Returns the dim3 cluster rank in the grid. -+CUTE_DEVICE dim3 cluster_id_in_grid() -+{ -+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) -+ uint32_t x, y, z; -+ asm volatile("mov.u32 %0, %clusterid.x;\n" : "=r"(x) : ); -+ asm volatile("mov.u32 %0, %clusterid.y;\n" : "=r"(y) : ); -+ asm volatile("mov.u32 %0, %clusterid.z;\n" : "=r"(z) : ); -+ return {x, y, z}; -+#else -+ return blockIdx; -+#endif -+} -+ -+// Returns the relative dim3 block rank local to the cluster. 
-+CUTE_DEVICE dim3 block_id_in_cluster() -+{ -+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) -+ uint32_t x, y, z; -+ asm volatile("mov.u32 %0, %cluster_ctaid.x;\n" : "=r"(x) : ); -+ asm volatile("mov.u32 %0, %cluster_ctaid.y;\n" : "=r"(y) : ); -+ asm volatile("mov.u32 %0, %cluster_ctaid.z;\n" : "=r"(z) : ); -+ return {x, y, z}; -+#else -+ return {0,0,0}; -+#endif -+} -+ -+// Returns the dim3 cluster shape. -+CUTE_DEVICE dim3 cluster_shape() -+{ -+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) -+ uint32_t x, y, z; -+ asm volatile("mov.u32 %0, %cluster_nctaid.x;\n" : "=r"(x) : ); -+ asm volatile("mov.u32 %0, %cluster_nctaid.y;\n" : "=r"(y) : ); -+ asm volatile("mov.u32 %0, %cluster_nctaid.z;\n" : "=r"(z) : ); -+ return {x, y, z}; -+#else -+ return {1,1,1}; -+#endif -+} -+ -+// Get 1D ctaid in a cluster. -+CUTLASS_DEVICE uint32_t block_rank_in_cluster() -+{ -+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) -+ uint32_t rank; -+ asm volatile("mov.u32 %0, %cluster_ctarank;\n" : "=r"(rank) :); -+ return rank; -+#else -+ return 0; -+#endif -+} -+ -+// Set the destination block-ID in cluster for a given SMEM Address -+CUTLASS_DEVICE uint32_t set_block_rank(uint32_t smemAddr, uint32_t rank) -+{ -+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) -+ uint32_t result; -+ asm volatile("mapa.shared::cluster.u32 %0, %1, %2;\n" -+ : "=r"(result) -+ : "r"(smemAddr), "r"(rank)); -+ return result; -+#else -+ return smemAddr; -+#endif -+} -+ -+// Elect one thread in the warp. The elected thread gets its predicate set to true, all others obtain false. -+CUTE_HOST_DEVICE uint32_t elect_one_sync() -+{ -+#if defined(CUTE_ARCH_ELECT_ONE_SM90_ENABLED) -+ uint32_t pred = 0; -+ uint32_t laneid = 0; -+ asm volatile( -+ "{\n" -+ ".reg .b32 %rx;\n" -+ ".reg .pred %px;\n" -+ " elect.sync %rx|%px, %2;\n" -+ "@%px mov.s32 %1, 1;\n" -+ " mov.s32 %0, %rx;\n" -+ "}\n" -+ : "+r"(laneid), "+r"(pred) -+ : "r"(0xFFFFFFFF)); -+ return pred; -+#elif defined(__CUDA_ARCH__) -+ return (threadIdx.x % 32) == 0; -+#else -+ return true; -+#endif -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/copy.hpp b/3rdparty/cutlass/include/cute/arch/copy.hpp -new file mode 100644 -index 0000000..aa7bb33 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/copy.hpp -@@ -0,0 +1,71 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
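A minimal usage sketch of the cluster utilities above, assuming sm_90 and a CUDA 12 toolchain; it uses the compile-time __cluster_dims__ attribute and cooperative_groups rather than the raw PTX wrappers, and all names below are local to the sketch:

#include <cooperative_groups.h>
#include <cuda_runtime.h>
namespace cg = cooperative_groups;

// Two CTAs form one cluster; cluster.sync() has the same effect as
// cluster_arrive() followed by cluster_wait() above.
__global__ void __cluster_dims__(2, 1, 1) cluster_kernel(int* out)
{
  cg::cluster_group cluster = cg::this_cluster();
  // ... produce per-CTA data (e.g. into distributed shared memory) here ...
  cluster.sync();
  if (threadIdx.x == 0) {
    out[blockIdx.x] = static_cast<int>(cluster.block_rank());
  }
}

int main()
{
  int* d = nullptr;
  cudaMalloc(&d, 2 * sizeof(int));
  cluster_kernel<<<2, 128>>>(d);     // grid size must be a multiple of the cluster size
  cudaDeviceSynchronize();
  cudaFree(d);
  return 0;
}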
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+namespace cute -+{ -+ -+// -+// Direct Copy for any type -+// -+ -+template -+struct UniversalCopy -+{ -+ using SRegisters = S[1]; -+ using DRegisters = D[1]; -+ -+ CUTE_HOST_DEVICE static constexpr void -+ copy(S const& src, -+ D & dst) -+ { -+ dst = src; -+ } -+}; -+ -+// -+// Placeholder for the copy algorithm's default, auto-vectorizing behavior -+// -+ -+struct DefaultCopy -+{ -+ using SRegisters = uint128_t[1]; -+ using DRegisters = uint128_t[1]; -+}; -+ -+using AutoVectorizingCopy = DefaultCopy; -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/copy_sm75.hpp b/3rdparty/cutlass/include/cute/arch/copy_sm75.hpp -new file mode 100644 -index 0000000..fda6340 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/copy_sm75.hpp -@@ -0,0 +1,215 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+// Config -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) -+# define CUTE_ARCH_LDSM_SM75_ENABLED -+#endif -+ -+namespace cute -+{ -+ -+struct SM75_U32x1_LDSM_N -+{ -+ using SRegisters = uint128_t[1]; -+ using DRegisters = uint32_t[1]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint128_t const& smem_src, -+ uint32_t& dst) -+ { -+#if defined(CUTE_ARCH_LDSM_SM75_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); -+ asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" -+ : "=r"(dst) -+ : "r"(smem_int_ptr)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM75_U32x2_LDSM_N -+{ -+ using SRegisters = uint128_t[1]; -+ using DRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint128_t const& smem_src, -+ uint32_t& dst0, uint32_t& dst1) -+ { -+#if defined(CUTE_ARCH_LDSM_SM75_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); -+ asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" -+ : "=r"(dst0), "=r"(dst1) -+ : "r"(smem_int_ptr)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM75_U32x4_LDSM_N -+{ -+ using SRegisters = uint128_t[1]; -+ using DRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint128_t const& smem_src, -+ uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) -+ { -+#if defined(CUTE_ARCH_LDSM_SM75_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); -+ asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" -+ : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) -+ : "r"(smem_int_ptr)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM75_U16x2_LDSM_T -+{ -+ using SRegisters = uint128_t[1]; -+ using DRegisters = uint32_t[1]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint128_t const& smem_src, -+ uint32_t& dst) -+ { -+#if defined(CUTE_ARCH_LDSM_SM75_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); -+ asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" -+ : "=r"(dst) -+ : "r"(smem_int_ptr)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM75_U16x4_LDSM_T -+{ -+ using SRegisters = uint128_t[1]; -+ using DRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint128_t const& smem_src, -+ uint32_t& dst0, uint32_t& dst1) -+ { -+#if defined(CUTE_ARCH_LDSM_SM75_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); -+ asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" -+ : "=r"(dst0), "=r"(dst1) -+ : "r"(smem_int_ptr)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM75_U16x8_LDSM_T -+{ -+ using SRegisters = uint128_t[1]; -+ using DRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint128_t const& smem_src, -+ uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) -+ { -+#if defined(CUTE_ARCH_LDSM_SM75_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); -+ asm volatile ("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, 
%3}, [%4];\n" -+ : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) -+ : "r"(smem_int_ptr)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED."); -+#endif -+ } -+}; -+ -+// -+// Legacy LDSM interfaces that aren't very useful -+// -+ -+template -+CUTE_HOST_DEVICE -+void -+copy_ldsm(uint128_t const* const smem_ptr, -+ T* rmem_ptr) -+{ -+ uint32_t* reg_ptr = reinterpret_cast(rmem_ptr); -+ -+ // if constexpr -+ if (sizeof(T) == 4) { -+ SM75_U32x1_LDSM_N::copy(smem_ptr[0], reg_ptr[0]); -+ } -+ else if (sizeof(T) == 8) { -+ SM75_U32x2_LDSM_N::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1]); -+ } -+ else if (sizeof(T) == 16) { -+ SM75_U32x4_LDSM_N::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3]); -+ } -+ else { -+ static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); -+ } -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+copy_ldsm_trans(uint128_t const* const smem_ptr, -+ T* rmem_ptr) -+{ -+ uint32_t* reg_ptr = reinterpret_cast(rmem_ptr); -+ -+ // if constexpr -+ if (sizeof(T) == 4) { -+ SM75_U16x2_LDSM_T::copy(smem_ptr[0], reg_ptr[0]); -+ } -+ else if (sizeof(T) == 8) { -+ SM75_U16x4_LDSM_T::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1]); -+ } -+ else if (sizeof(T) == 16) { -+ SM75_U16x8_LDSM_T::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3]); -+ } -+ else { -+ static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); -+ } -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/copy_sm80.hpp b/3rdparty/cutlass/include/cute/arch/copy_sm80.hpp -new file mode 100644 -index 0000000..c6c4412 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/copy_sm80.hpp -@@ -0,0 +1,138 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
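A minimal usage sketch of the x1 ldmatrix pattern wrapped above (assumes sm_75+, a single warp of 32 threads, and __half data; raw inline PTX in place of the SM75_U32x1_LDSM_N struct, with names local to the sketch):

#include <cuda_fp16.h>
#include <cstdint>

// Load one 8x8 matrix of b16 elements from shared memory into registers.
// For the .x1 form, lanes 0..7 supply the byte addresses of the 8 rows;
// afterwards lane t holds elements (row t/4, columns 2*(t%4) and 2*(t%4)+1).
__global__ void ldsm_x1_demo(const __half* __restrict__ in, uint32_t* __restrict__ out)
{
  __shared__ __align__(16) __half tile[8][8];       // each row is 16 bytes
  int lane = threadIdx.x;                           // launch with exactly 32 threads
  for (int i = lane; i < 64; i += 32) {
    tile[i / 8][i % 8] = in[i];
  }
  __syncwarp();

  uint32_t row_addr = static_cast<uint32_t>(__cvta_generic_to_shared(&tile[lane % 8][0]));
  uint32_t frag;
  asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
               : "=r"(frag)
               : "r"(row_addr));
  out[lane] = frag;                                 // two packed __half values per lane
}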
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+// Config -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) -+# define CUTE_ARCH_CP_ASYNC_SM80_ENABLED -+#endif -+ -+namespace cute -+{ -+ -+/// Copy via cp.async with caching at all levels -+template -+struct SM80_CP_ASYNC_CACHEALWAYS -+{ -+ using SRegisters = TS[1]; -+ using DRegisters = TD[1]; -+ -+ static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); -+ static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); -+ -+ CUTE_HOST_DEVICE static void -+ copy(TS const& gmem_src, -+ TD & smem_dst) -+ { -+#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) -+ TS const* gmem_ptr = &gmem_src; -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); -+ asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n" -+ :: "r"(smem_int_ptr), -+ "l"(gmem_ptr), -+ "n"(sizeof(TS))); -+#else -+ CUTE_RUNTIME_ASSERT("Support for cp.async instructions has not been enabled"); -+#endif -+ } -+}; -+ -+/// Copy via cp.async with caching at global level -+template -+struct SM80_CP_ASYNC_CACHEGLOBAL -+{ -+ using SRegisters = TS[1]; -+ using DRegisters = TD[1]; -+ -+ static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); -+ static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); -+ -+ CUTE_HOST_DEVICE static void -+ copy(TS const& gmem_src, -+ TD & smem_dst) -+ { -+#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) -+ TS const* gmem_ptr = &gmem_src; -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); -+ asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n" -+ :: "r"(smem_int_ptr), -+ "l"(gmem_ptr), -+ "n"(sizeof(TS))); -+#else -+ CUTE_RUNTIME_ASSERT("Support for cp.async instructions has not been enabled"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. -+CUTE_HOST_DEVICE -+void -+cp_async_fence() -+{ -+#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) -+ asm volatile("cp.async.commit_group;\n" ::); -+#endif -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Blocks until all but N previous cp.async.commit_group operations have committed. -+template -+CUTE_HOST_DEVICE -+void -+cp_async_wait() -+{ -+#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) -+ if constexpr (N == 0) { -+ asm volatile("cp.async.wait_all;\n" ::); -+ } else { -+ asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); -+ } -+#endif -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+cp_async_wait(Int) -+{ -+ return cp_async_wait(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/copy_sm90.hpp b/3rdparty/cutlass/include/cute/arch/copy_sm90.hpp -new file mode 100644 -index 0000000..6ac9643 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/copy_sm90.hpp -@@ -0,0 +1,225 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
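A minimal usage sketch of the cp.async issue/commit/wait pattern above (assumes sm_80+, 128 threads per block, and 16-byte elements; raw inline PTX in place of the cute wrappers, with names local to the sketch):

#include <cuda_runtime.h>
#include <cstdint>

__global__ void stage_tile(const float4* __restrict__ gmem, float4* __restrict__ out)
{
  __shared__ float4 smem[128];                      // one 16-byte element per thread
  int t = threadIdx.x;
  uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(&smem[t]));

  // Issue the async global->shared copy (cache-global variant, 16 bytes).
  asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n"
               :: "r"(smem_addr), "l"(gmem + t));

  asm volatile("cp.async.commit_group;\n" ::);      // cp_async_fence()
  asm volatile("cp.async.wait_group 0;\n" ::);      // cp_async_wait<0>()
  __syncthreads();                                  // make the staged data visible block-wide

  out[t] = smem[t];
}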
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+// Config -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) -+# define CUTE_ARCH_STSM_SM90_ENABLED -+# define CUTE_ARCH_TMA_SM90_ENABLED -+#endif -+ -+namespace cute -+{ -+ -+struct SM90_U32x1_STSM_N -+{ -+ using SRegisters = uint32_t[1]; -+ using DRegisters = uint128_t[1]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint32_t const& src, -+ uint128_t & smem_dst) -+ { -+#if defined(CUTE_ARCH_STSM_SM90_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); -+ asm volatile ("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" -+ :: "r"(smem_int_ptr), -+ "r"(src)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_U32x2_STSM_N -+{ -+ using SRegisters = uint32_t[2]; -+ using DRegisters = uint128_t[1]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint32_t const& src0, uint32_t const& src1, -+ uint128_t& smem_dst) -+ { -+#if defined(CUTE_ARCH_STSM_SM90_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); -+ asm volatile ("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" -+ :: "r"(smem_int_ptr), -+ "r"(src0), "r"(src1)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_U32x4_STSM_N -+{ -+ using SRegisters = uint32_t[4]; -+ using DRegisters = uint128_t[1]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, -+ uint128_t& smem_dst) -+ { -+#if defined(CUTE_ARCH_STSM_SM90_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); -+ asm volatile ("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" -+ :: "r"(smem_int_ptr), -+ "r"(src0), 
"r"(src1), "r"(src2), "r"(src3)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_U16x2_STSM_T -+{ -+ using SRegisters = uint32_t[1]; -+ using DRegisters = uint128_t[1]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint32_t const& src, -+ uint128_t& smem_dst) -+ { -+#if defined(CUTE_ARCH_STSM_SM90_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); -+ asm volatile ("stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" -+ :: "r"(smem_int_ptr), -+ "r"(src)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_U16x4_STSM_T -+{ -+ using SRegisters = uint32_t[2]; -+ using DRegisters = uint128_t[1]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint32_t const& src0, uint32_t const& src1, -+ uint128_t& smem_dst) -+ { -+#if defined(CUTE_ARCH_STSM_SM90_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); -+ asm volatile ("stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" -+ :: "r"(smem_int_ptr), -+ "r"(src0), "r"(src1)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_U16x8_STSM_T -+{ -+ using SRegisters = uint32_t[4]; -+ using DRegisters = uint128_t[1]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, -+ uint128_t& smem_dst) -+ { -+#if defined(CUTE_ARCH_STSM_SM90_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); -+ asm volatile ("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" -+ :: "r"(smem_int_ptr), -+ "r"(src0), "r"(src1), "r"(src2), "r"(src3)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+// -+// Legacy STSM interfaces that aren't very useful -+// -+ -+template -+CUTE_HOST_DEVICE -+void -+copy_stsm(T const* const rmem_ptr, -+ uint128_t* const smem_ptr) -+{ -+ uint32_t const* reg_ptr = reinterpret_cast(rmem_ptr); -+ -+ // if constexpr -+ if (sizeof(T) == 4) { -+ SM90_U32x1_STSM_N::copy(reg_ptr[0], smem_ptr[0]); -+ } -+ else if (sizeof(T) == 8) { -+ SM90_U32x2_STSM_N::copy(reg_ptr[0], reg_ptr[1], smem_ptr[0]); -+ } -+ else if (sizeof(T) == 16) { -+ SM90_U32x4_STSM_N::copy(reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3], smem_ptr[0]); -+ } -+ else { -+ static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); -+ } -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+copy_stsm_trans(T const* const rmem_ptr, -+ uint128_t* const smem_ptr) -+{ -+ uint32_t const* reg_ptr = reinterpret_cast(rmem_ptr); -+ -+ // if constexpr -+ if (sizeof(T) == 4) { -+ SM90_U16x2_STSM_T::copy(reg_ptr[0], smem_ptr[0]); -+ } -+ else if (sizeof(T) == 8) { -+ SM90_U16x4_STSM_T::copy(reg_ptr[0], reg_ptr[1], smem_ptr[0]); -+ } -+ else if (sizeof(T) == 16) { -+ SM90_U16x8_STSM_T::copy(reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3], smem_ptr[0]); -+ } -+ else { -+ static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // end namespace cute -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+#include -+ 
-+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cute/arch/copy_sm90_desc.hpp b/3rdparty/cutlass/include/cute/arch/copy_sm90_desc.hpp -new file mode 100644 -index 0000000..ca8320f ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/copy_sm90_desc.hpp -@@ -0,0 +1,194 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+#include -+#include -+ -+#include -+#include -+#include // to_Format<[u]intX> -+#include // to_Format -+ -+namespace cute -+{ -+ -+////////////////////////////////////////////////////////////////////////////////////////////////////// -+/// Barriers are 64-bit of user-managed information used in broadly two types syncronization patterns -+/// 1) arrive/wait on threads (usage: cp.async and warp-specialized kernels) -+/// 2) transaction-based (usage: TMA transaction where a CTA issues one transaction) -+////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Initialize barrier present in shared memory -+CUTE_HOST_DEVICE -+void -+initialize_barrier(uint64_t& smem_barrier, // 64 bits user-manged barrier in smem -+ int thread_count = 1) // Thread count expected to arrive/wait on this barrier -+{ -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier); -+ asm volatile ("mbarrier.init.shared.b64 [%0], %1;\n" -+ :: "r"(smem_int_ptr), -+ "r"(thread_count)); -+#endif -+} -+ -+// Set the number of bytes transfered per transaction -+CUTE_HOST_DEVICE -+void -+set_barrier_transaction_bytes(uint64_t& smem_barrier, // 64 bits user-manged barrier in smem -+ uint32_t bytes) // Number of bytes transfered by per TMA transaction -+{ -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier); -+ asm volatile ("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;\n" -+ :: "r"(smem_int_ptr), -+ "r"(bytes)); -+#endif -+} -+ -+// Barrier wait -+CUTE_HOST_DEVICE -+void -+wait_barrier(uint64_t& smem_barrier, // 64 bits user-manged barrier in smem -+ int phase_bit) // Current phase bit the barrier waiting to flip -+{ -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier); -+ asm volatile( -+ "{\n" -+ ".reg .pred P1;\n" -+ "LAB_WAIT:\n" -+ "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1;\n" -+ "@P1 bra.uni DONE;\n" -+ "bra.uni LAB_WAIT;\n" -+ "DONE:\n" -+ "}\n" -+ :: "r"(smem_int_ptr), -+ "r"(phase_bit)); -+ -+#endif -+} -+ -+// Barrier arrive -+CUTE_HOST_DEVICE -+void -+arrive_barrier(uint64_t& smem_barrier) // 64 bits user-manged barrier in smem -+{ -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier); -+ asm volatile( -+ "{\n" -+ ".reg .b64 state; \n" -+ "mbarrier.arrive.shared.b64 state, [%0];\n" -+ "}\n" -+ :: "r"(smem_int_ptr)); -+#endif -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// TMA Descriptor and utilities -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace TMA { -+ -+enum class SmemSwizzleBits : uint8_t { -+ DISABLE = 0, -+ B32 = 1, -+ B64 = 2, -+ B128 = 3, -+}; -+ -+#if (__CUDACC_VER_MAJOR__ >= 12) -+ -+template -+inline CUtensorMapDataType to_CUtensorMapDataType() { -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT16; } else -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT32; } else -+ if constexpr (std::is_same::value) { return 
CU_TENSOR_MAP_DATA_TYPE_UINT64; } else -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_INT32; } else -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_INT64; } else -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; } else -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; } else -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; } else -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; } else -+ { static_assert(sizeof(T) < 0, "Unknown TMA Format!"); } -+} -+ -+inline CUtensorMapSwizzle to_CUtensorMapSwizzle(SmemSwizzleBits const& t) { -+ switch (t) { -+ default: assert(false && "Unknown SmemSwizzleBits!"); -+ case SmemSwizzleBits::DISABLE: return CU_TENSOR_MAP_SWIZZLE_NONE; -+ case SmemSwizzleBits::B32: return CU_TENSOR_MAP_SWIZZLE_32B; -+ case SmemSwizzleBits::B64: return CU_TENSOR_MAP_SWIZZLE_64B; -+ case SmemSwizzleBits::B128: return CU_TENSOR_MAP_SWIZZLE_128B; -+ } -+} -+ -+#endif // (__CUDACC_VER_MAJOR__ >= 12) -+} // end namespace TMA -+ -+#if (__CUDACC_VER_MAJOR__ >= 12) -+using TmaDescriptor = CUtensorMap; -+#else -+using TmaDescriptor = struct { char bytes[128]; }; -+#endif -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+/// Initiates a TensorMap Prefetch -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+CUTE_HOST_DEVICE -+void -+prefetch_tma_descriptor(TmaDescriptor const* desc_ptr) -+{ -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ // Prefetch TMA Descriptor using generic addressing (i.e. no specific state space: const or param) -+ asm volatile ( -+ "prefetch.tensormap [%0];" -+ : -+ : "l"(gmem_int_desc) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use TMA Descriptor Prefetch without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/copy_sm90_tma.hpp b/3rdparty/cutlass/include/cute/arch/copy_sm90_tma.hpp -new file mode 100644 -index 0000000..d6025e4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/copy_sm90_tma.hpp -@@ -0,0 +1,552 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+namespace cute -+{ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+/// TMA_LOAD : Initiates a TMA copy from global memory to shared memory -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct SM90_TMA_LOAD_1D -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, -+ void const* const smem_ptr, -+ int32_t const& crd0) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes" -+ " [%0], [%1, {%3}], [%2];" -+ : -+ : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), -+ "r"(crd0) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_LOAD_2D -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes" -+ " [%0], [%1, {%3, %4}], [%2];" -+ : -+ : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), -+ "r"(crd0), "r"(crd1) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_LOAD_3D -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes" -+ " [%0], [%1, {%3, %4, %5}], [%2];" -+ : -+ : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), -+ "r"(crd0), "r"(crd1), "r"(crd2) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without 
CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_LOAD_4D -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes" -+ " [%0], [%1, {%3, %4, %5, %6}], [%2];" -+ : -+ : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), -+ "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_LOAD_5D -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes" -+ " [%0], [%1, {%3, %4, %5, %6, %7}], [%2];" -+ : -+ : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), -+ "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_LOAD -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, -+ void const* const smem_ptr, -+ int32_t const& crd0) -+ { -+ return SM90_TMA_LOAD_1D::copy(desc_ptr, smem_mbar, smem_ptr, crd0); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1) -+ { -+ return SM90_TMA_LOAD_2D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) -+ { -+ return SM90_TMA_LOAD_3D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1, crd2); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) -+ { -+ return SM90_TMA_LOAD_4D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1, crd2, crd3); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) -+ { -+ return SM90_TMA_LOAD_5D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1, crd2, crd3, crd4); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+/// TMA_LOAD_MULTICAST: Initiates a TMA copy from global memory to shared memory -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct SM90_TMA_LOAD_1D_MULTICAST -+{ -+ 
CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, -+ void const* const smem_ptr, -+ int32_t const& crd0) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" -+ " [%0], [%1, {%4}], [%2], %3;" -+ : -+ : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), -+ "h"(multicast_mask), -+ "r"(crd0) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_LOAD_2D_MULTICAST -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" -+ " [%0], [%1, {%4, %5}], [%2], %3;" -+ : -+ : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), -+ "h"(multicast_mask), -+ "r"(crd0), "r"(crd1) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_LOAD_3D_MULTICAST -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" -+ " [%0], [%1, {%4, %5, %6}], [%2], %3;" -+ : -+ : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), -+ "h"(multicast_mask), -+ "r"(crd0), "r"(crd1), "r"(crd2) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_LOAD_4D_MULTICAST -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" -+ " [%0], [%1, {%4, %5, %6, %7}], [%2], %3;" -+ : -+ : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), -+ "h"(multicast_mask), -+ "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_LOAD_5D_MULTICAST -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const 
desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" -+ " [%0], [%1, {%4, %5, %6, %7, %8}], [%2], %3;" -+ : -+ : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), -+ "h"(multicast_mask), -+ "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_LOAD_MULTICAST -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, -+ void const* const smem_ptr, -+ int32_t const& crd0) -+ { -+ return SM90_TMA_LOAD_1D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1) -+ { -+ return SM90_TMA_LOAD_2D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) -+ { -+ return SM90_TMA_LOAD_3D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) -+ { -+ return SM90_TMA_LOAD_4D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2, crd3); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) -+ { -+ return SM90_TMA_LOAD_5D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2, crd3, crd4); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+/// TMA_STORE : Initiates a TMA copy from shared memory to global memory -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct SM90_TMA_STORE_1D -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, -+ void const* const smem_ptr, -+ int32_t const& crd0) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [%0, {%2}], [%1];" -+ : -+ : "l"(gmem_int_desc), "r"(smem_int_ptr), -+ "r"(crd0) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_STORE_2D -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, -+ void 
const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%2, %3}], [%1];" -+ : -+ : "l"(gmem_int_desc), "r"(smem_int_ptr), -+ "r"(crd0), "r"(crd1) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_STORE_3D -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [%0, {%2, %3, %4}], [%1];" -+ : -+ : "l"(gmem_int_desc), "r"(smem_int_ptr), -+ "r"(crd0), "r"(crd1), "r"(crd2) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_STORE_4D -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5}], [%1];" -+ : -+ : "l"(gmem_int_desc), "r"(smem_int_ptr), -+ "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_STORE_5D -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5, %6}], [%1];" -+ : -+ : "l"(gmem_int_desc), "r"(smem_int_ptr), -+ "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_STORE -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, -+ void const* const smem_ptr, -+ int32_t const& crd0) -+ { -+ return SM90_TMA_STORE_1D::copy(desc_ptr, smem_ptr, crd0); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1) -+ { -+ return SM90_TMA_STORE_2D::copy(desc_ptr, smem_ptr, crd0, crd1); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) -+ { -+ return SM90_TMA_STORE_3D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) -+ { -+ return 
SM90_TMA_STORE_4D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) -+ { -+ return SM90_TMA_STORE_5D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3, crd4); -+ } -+}; -+ -+// Indicate arrival of warp issuing TMA_STORE -+CUTE_HOST_DEVICE static void -+tma_store_arrive() { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ asm volatile("cp.async.bulk.commit_group;"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+} -+ -+// Wait on prior N (Count) TMA_STORE instructions to complete -+template -+CUTE_HOST_DEVICE static void -+tma_store_wait() { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ asm volatile( -+ "cp.async.bulk.wait_group.read %0;" -+ : -+ : "n"(Count) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/mma.hpp b/3rdparty/cutlass/include/cute/arch/mma.hpp -new file mode 100644 -index 0000000..1c1058f ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/mma.hpp -@@ -0,0 +1,64 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+// -+// Direct FMA for any type -+// -+ -+template -+struct UniversalFMA -+{ -+ using DRegisters = D[1]; -+ using ARegisters = A[1]; -+ using BRegisters = B[1]; -+ using CRegisters = C[1]; -+ -+ CUTE_HOST_DEVICE static constexpr void -+ fma(D & d, -+ A const& a, -+ B const& b, -+ C const& c) -+ { -+ // Forward to an ADL/cute free function for these types -+ using cute::fma; -+ fma(d, a, b, c); -+ } -+}; -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/mma_sm61.hpp b/3rdparty/cutlass/include/cute/arch/mma_sm61.hpp -new file mode 100644 -index 0000000..32a9fbb ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/mma_sm61.hpp -@@ -0,0 +1,87 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+#include -+ -+// Config -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)) -+# define CUTE_ARCH_MMA_SM61_ENABLED -+#endif -+ -+namespace cute -+{ -+ -+struct SM61_DP4A -+{ -+ using DRegisters = int32_t[1]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = int32_t[1]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(int32_t& d, uint32_t const& a, uint32_t const& b, int32_t const& c) -+ { -+#if defined(CUTE_ARCH_MMA_SM61_ENABLED) -+ asm volatile("dp4a.s32.s32 %0, %1, %2, %3;" -+ : "=r"(d) -+ : "r"(a), "r"(b), "r"(c)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM61_DP4A without CUTE_ARCH_MMA_SM61_ENABLED"); -+#endif -+ } -+}; -+ -+struct SM61_DP2A -+{ -+ using DRegisters = int32_t[1]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = int32_t[1]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(int32_t& d, uint32_t const& a, uint32_t const& b, int32_t const& c) -+ { -+#if defined(CUTE_ARCH_MMA_SM61_ENABLED) -+ asm volatile("dp2a.s32.s32 %0, %1, %2, %3;" -+ : "=r"(d) -+ : "r"(a), "r"(b), "r"(c)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM61_DP2A without CUTE_ARCH_MMA_SM61_ENABLED"); -+#endif -+ } -+}; -+ -+} // namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/mma_sm70.hpp b/3rdparty/cutlass/include/cute/arch/mma_sm70.hpp -new file mode 100644 -index 0000000..139e600 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/mma_sm70.hpp -@@ -0,0 +1,329 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+// Config -+#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1)) -+# define CUTE_ARCH_MMA_SM70_SUPPORTED -+# if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) -+# define CUTE_ARCH_MMA_SM70_ENABLED -+# endif -+#endif -+ -+namespace cute -+{ -+ -+// -+// SM70 MMA 884 F16F16F16 -+// -+ -+struct SM70_8x8x4_F16F16F16F16_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM70_ENABLED) -+ asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16" -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6, %7}," -+ "{%8, %9, %10, %11};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_TN without CUTE_ARCH_MMA_SM70_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct SM70_8x8x4_F16F16F16F16_NT -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM70_ENABLED) -+ asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16" -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6, %7}," -+ "{%8, %9, %10, %11};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_NT without CUTE_ARCH_MMA_SM70_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct SM70_8x8x4_F16F16F16F16_NN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM70_ENABLED) -+ asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16" -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6, %7}," -+ "{%8, %9, %10, %11};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_NN without CUTE_ARCH_MMA_SM70_ENABLED"); -+#endif -+ } -+}; -+ 
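The four SM70_8x8x4_F16F16F16F16_{TN,NT,NN,TT} atoms in this header (the TT variant follows below) differ only in the layout suffix, which maps directly onto the row/col modifiers in the mma.sync.aligned.m8n8k4 PTX string: the first letter selects the A-operand modifier (T maps to row, N to col) and the second letter the B-operand modifier. A hedged sketch of invoking one atom directly is shown next; frag_a, frag_b, and acc are hypothetical per-thread fragments, each uint32_t packing two f16 values, and the accumulator is reused as both the C input and the D output, as is usual for mma.sync.

    // Sketch only: the operand fragments are hypothetical; the atom comes from the patch.
    __device__ void hmma_884_f16(uint32_t const (&frag_a)[2],
                                 uint32_t const (&frag_b)[2],
                                 uint32_t (&acc)[4])
    {
      cute::SM70_8x8x4_F16F16F16F16_TN::fma(acc[0], acc[1], acc[2], acc[3],   // D
                                            frag_a[0], frag_a[1],             // A
                                            frag_b[0], frag_b[1],             // B
                                            acc[0], acc[1], acc[2], acc[3]);  // C
    }
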
-+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct SM70_8x8x4_F16F16F16F16_TT -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM70_ENABLED) -+ asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16" -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6, %7}," -+ "{%8, %9, %10, %11};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_TT without CUTE_ARCH_MMA_SM70_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// SM70 MMA 884 F16F16F32 -+// -+ -+struct SM70_8x8x4_F32F16F16F32_TN -+{ -+ using DRegisters = float[8]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = float[8]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(float & d0, float & d1, float & d2, float & d3, -+ float & d4, float & d5, float & d6, float & d7, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, uint32_t const& b1, -+ float const& c0, float const& c1, float const& c2, float const& c3, -+ float const& c4, float const& c5, float const& c6, float const& c7) -+ { -+#if defined(CUTE_ARCH_MMA_SM70_ENABLED) -+ asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32" -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11}," -+ "{%12, %13, %14, %15, %16, %17, %18, %19};\n" -+ : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), -+ "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) -+ : "r"(a0), "r"(a1), -+ "r"(b0), "r"(b1), -+ "f"(c0), "f"(c1), "f"(c2), "f"(c3), -+ "f"(c4), "f"(c5), "f"(c6), "f"(c7)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_TN without CUTE_ARCH_MMA_SM70_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct SM70_8x8x4_F32F16F16F32_NT -+{ -+ using DRegisters = float[8]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = float[8]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(float & d0, float & d1, float & d2, float & d3, -+ float & d4, float & d5, float & d6, float & d7, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, uint32_t const& b1, -+ float const& c0, float const& c1, float const& c2, float const& c3, -+ float const& c4, float const& c5, float const& c6, float const& c7) -+ { -+#if defined(CUTE_ARCH_MMA_SM70_ENABLED) -+ asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32" -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11}," -+ "{%12, %13, %14, %15, %16, %17, %18, %19};" -+ : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), -+ "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) -+ : "r"(a0), "r"(a1), -+ "r"(b0), "r"(b1), -+ "f"(c0), "f"(c1), "f"(c2), "f"(c3), -+ "f"(c4), "f"(c5), "f"(c6), "f"(c7)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_NT without 
CUTE_ARCH_MMA_SM70_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct SM70_8x8x4_F32F16F16F32_NN -+{ -+ using DRegisters = float[8]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = float[8]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(float & d0, float & d1, float & d2, float & d3, -+ float & d4, float & d5, float & d6, float & d7, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, uint32_t const& b1, -+ float const& c0, float const& c1, float const& c2, float const& c3, -+ float const& c4, float const& c5, float const& c6, float const& c7) -+ { -+#if defined(CUTE_ARCH_MMA_SM70_ENABLED) -+ asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32" -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11}," -+ "{%12, %13, %14, %15, %16, %17, %18, %19};" -+ : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), -+ "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) -+ : "r"(a0), "r"(a1), -+ "r"(b0), "r"(b1), -+ "f"(c0), "f"(c1), "f"(c2), "f"(c3), -+ "f"(c4), "f"(c5), "f"(c6), "f"(c7)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_NN without CUTE_ARCH_MMA_SM70_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct SM70_8x8x4_F32F16F16F32_TT -+{ -+ using DRegisters = float[8]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = float[8]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(float & d0, float & d1, float & d2, float & d3, -+ float & d4, float & d5, float & d6, float & d7, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, uint32_t const& b1, -+ float const& c0, float const& c1, float const& c2, float const& c3, -+ float const& c4, float const& c5, float const& c6, float const& c7) -+ { -+#if defined(CUTE_ARCH_MMA_SM70_ENABLED) -+ asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32" -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11}," -+ "{%12, %13, %14, %15, %16, %17, %18, %19};" -+ : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), -+ "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) -+ : "r"(a0), "r"(a1), -+ "r"(b0), "r"(b1), -+ "f"(c0), "f"(c1), "f"(c2), "f"(c3), -+ "f"(c4), "f"(c5), "f"(c6), "f"(c7)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_TT without CUTE_ARCH_MMA_SM70_ENABLED"); -+#endif -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/mma_sm75.hpp b/3rdparty/cutlass/include/cute/arch/mma_sm75.hpp -new file mode 100644 -index 0000000..20d2b56 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/mma_sm75.hpp -@@ -0,0 +1,120 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+// Config -+#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) -+# define CUTE_ARCH_MMA_SM75_SUPPORTED -+# if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) -+# define CUTE_ARCH_MMA_SM75_ENABLED -+# endif -+#endif -+ -+namespace cute -+{ -+ -+// -+// SM75 MMA 1688 F16F16F32 -+// -+ -+struct SM75_16x8x8_F32F16F16F32_TN -+{ -+ using DRegisters = float[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = float[4]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(float & d0, float & d1, float & d2, float & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ float const& c0, float const& c1, float const& c2, float const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM75_ENABLED) -+ asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "f"(c0), "f"(c1), "f"(c2), "f"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM75_16x8x8_F32F16F16F32_TN without CUTE_ARCH_MMA_SM75_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// SM75 MMA 8816 S8S8S32 -+// -+ -+struct SM75_8x8x16_S32S8S8S32_TN -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM75_ENABLED) -+ asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32" -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM75_8x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM75_ENABLED"); -+#endif -+ } -+}; -+ 
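Both SM75 atoms above follow the register-interface convention used throughout these headers: the DRegisters/ARegisters/BRegisters/CRegisters array typedefs state how many machine words each operand occupies per thread, which higher-level CuTe machinery can use to size fragments. For the m8n8k16 s8 atom, the single 32-bit A and B operands each carry four packed signed 8-bit values. A hedged packing sketch follows; the helper name pack_s8x4 is hypothetical and only the atom itself comes from the patch.

    // Sketch: pack four int8 values into the 32-bit operand expected per thread by
    // SM75_8x8x16_S32S8S8S32_TN::fma (x0 lands in the least-significant byte).
    __device__ inline uint32_t pack_s8x4(int8_t x0, int8_t x1, int8_t x2, int8_t x3)
    {
      return ( uint32_t(uint8_t(x0))       ) |
             ( uint32_t(uint8_t(x1)) <<  8 ) |
             ( uint32_t(uint8_t(x2)) << 16 ) |
             ( uint32_t(uint8_t(x3)) << 24 );
    }
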
-+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/mma_sm80.hpp b/3rdparty/cutlass/include/cute/arch/mma_sm80.hpp -new file mode 100644 -index 0000000..6050500 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/mma_sm80.hpp -@@ -0,0 +1,2132 @@ -+ /************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+#include -+ -+// Config -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) -+# define CUTE_ARCH_MMA_SM80_ENABLED -+#endif -+ -+namespace cute { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x8 TN -+struct SM80_16x8x8_F16F16F16F16_TN -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " -+ "{%0, %1}," -+ "{%2, %3}," -+ "{%4}," -+ "{%5, %6};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x8_F16F16F16F16_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM80_16x8x16_F16F16F16F16_TN -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " -+ "{%0, %1}," -+ "{%2, %3, %4, %5}," -+ "{%6, %7}," -+ "{%8, %9};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_F16F16F16F16_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x8 TN -+struct SM80_16x8x8_F32F16F16F32_TN -+{ -+ using DRegisters = float[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = float[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(float & d0, float & d1, float & d2, float & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ float const & c0, float const & c1, float const & c2, float const & c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "f"(c0), "f"(c1), "f"(c2), "f"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x8_F32F16F16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM80_16x8x16_F32F16F16F32_TN -+{ -+ using DRegisters = float[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = float[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(float & d0, float & d1, float & d2, float & d3, -+ uint32_t 
const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ float const & c0, float const & c1, float const & c2, float const & c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "f"(c0), "f"(c1), "f"(c2), "f"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_F32F16F16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x8 TN -+struct SM80_16x8x8_F32BF16BF16F32_TN -+{ -+ using DRegisters = float[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = float[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(float & d0, float & d1, float & d2, float & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ float const & c0, float const & c1, float const & c2, float const & c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "f"(c0), "f"(c1), "f"(c2), "f"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x8_F32BF16BF16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM80_16x8x16_F32BF16BF16F32_TN -+{ -+ using DRegisters = float[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = float[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(float & d0, float & d1, float & d2, float & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ float const & c0, float const & c1, float const & c2, float const & c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "f"(c0), "f"(c1), "f"(c2), "f"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_F32BF16BF16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x4 TN -+struct SM80_16x8x4_F32TF32TF32F32_TN -+{ -+ using DRegisters = float[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = float[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(float & d0, float & d1, float & d2, float & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ float const & c0, float const & c1, float const & c2, float const & c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) -+ : "r"(a0), 
"r"(a1), -+ "r"(b0), -+ "f"(c0), "f"(c1), "f"(c2), "f"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x4_F32TF32TF32F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x8 TN -+struct SM80_16x8x8_F32TF32TF32F32_TN -+{ -+ using DRegisters = float[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = float[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(float & d0, float & d1, float & d2, float & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ float const & c0, float const & c1, float const & c2, float const & c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "f"(c0), "f"(c1), "f"(c2), "f"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x8_F32TF32TF32F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x4 TN -+struct SM80_8x8x4_F64F64F64F64_TN -+{ -+ using DRegisters = double[2]; -+ using ARegisters = double[1]; -+ using BRegisters = double[1]; -+ using CRegisters = double[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(double & d0, double & d1, -+ double const& a0, -+ double const& b0, -+ double const& c0, double const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=d"(d0), "=d"(d1) -+ : "d"(a0), -+ "d"(b0), -+ "d"(c0), "d"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+// MMA 8x8x4 TN with Planar Complex multiplication -+struct SM80_8x8x4_C64C64C64C64_TN -+{ -+ using DRegisters = complex[2]; -+ using ARegisters = complex[1]; -+ using BRegisters = complex[1]; -+ using CRegisters = complex[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(complex & d0, complex & d1, -+ complex const& a0, -+ complex const& b0, -+ complex const& c0, complex const& c1) -+ { -+ // Because thrust::complex does not provide a mutable ref -+ double& rd0 = reinterpret_cast(d0)[0]; -+ double& id0 = reinterpret_cast(d0)[1]; -+ double& rd1 = reinterpret_cast(d1)[0]; -+ double& id1 = reinterpret_cast(d1)[1]; -+ -+ // d.real() = a.real() * b.real() + c.real(); -+ SM80_8x8x4_F64F64F64F64_TN::fma( -+ rd0, rd1, -+ a0.real(), -+ b0.real(), -+ c0.real(), c1.real()); -+ -+ // d.imag() = a.imag() * b.real() + c.imag(); -+ SM80_8x8x4_F64F64F64F64_TN::fma( -+ id0, id1, -+ a0.imag(), -+ b0.real(), -+ c0.imag(), c1.imag()); -+ -+ // d.real() = -a.imag() * b.imag() + d.real(); -+ SM80_8x8x4_F64F64F64F64_TN::fma( -+ rd0, rd1, -+ -a0.imag(), -+ b0.imag(), -+ d0.real(), d1.real()); -+ -+ // d.imag() = a.real() * b.imag() + d.imag(); -+ SM80_8x8x4_F64F64F64F64_TN::fma( -+ id0, id1, -+ a0.real(), -+ b0.imag(), -+ d0.imag(), d1.imag()); -+ } -+}; -+ -+// MMA 8x8x4 TN with Gaussian Complex multiplication: -+// (a + bi)*(c + di) -+// yields -+// t0 += a*c -+// t1 += b*d -+// t2 += (a+b)*(c+d) -+// then -+// re = t0 - t1 -+// im = t2 
-+struct SM80_8x8x4_GC64C64C64GC64_TN
-+{
-+  struct GaussComplex {
-+    double t0, t1, t2;
-+
-+    CUTE_HOST_DEVICE //constexpr
-+    operator complex<double>() const { return complex<double>(t0 - t1, t2 - t0 - t1); }
-+
-+    CUTE_HOST_DEVICE friend //constexpr
-+    complex<double> operator*(GaussComplex const& a, complex<double> const& b) { return static_cast<complex<double>>(a) * b; }
-+    CUTE_HOST_DEVICE friend //constexpr
-+    complex<double> operator*(complex<double> const& a, GaussComplex const& b) { return b * a; }
-+
-+    CUTE_HOST_DEVICE friend //constexpr
-+    complex<double> operator+(GaussComplex const& a, complex<double> const& b) { return static_cast<complex<double>>(a) + b; }
-+    CUTE_HOST_DEVICE friend //constexpr
-+    complex<double> operator+(complex<double> const& a, GaussComplex const& b) { return b + a; }
-+  };
-+
-+  using DRegisters = GaussComplex[2];
-+  using ARegisters = complex<double>[1];
-+  using BRegisters = complex<double>[1];
-+  using CRegisters = GaussComplex[2];
-+
-+  CUTE_HOST_DEVICE static void
-+  fma(GaussComplex & d0, GaussComplex & d1,
-+      complex<double> const& a0,
-+      complex<double> const& b0,
-+      GaussComplex const& c0, GaussComplex const& c1)
-+  {
-+    SM80_8x8x4_F64F64F64F64_TN::fma(d0.t0, d1.t0,
-+                                    a0.real(),
-+                                    b0.real(),
-+                                    c0.t0, c1.t0);
-+    SM80_8x8x4_F64F64F64F64_TN::fma(d0.t1, d1.t1,
-+                                    a0.imag(),
-+                                    b0.imag(),
-+                                    c0.t1, c1.t1);
-+    SM80_8x8x4_F64F64F64F64_TN::fma(d0.t2, d1.t2,
-+                                    a0.real() + a0.imag(),
-+                                    b0.real() + b0.imag(),
-+                                    c0.t2, c1.t2);
-+  }
-+};
-+
-+////////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+// MMA 8x8x16 TN
-+struct SM80_8x8x16_S32S8S8S32_TN
-+{
-+  using DRegisters = uint32_t[2];
-+  using ARegisters = uint32_t[1];
-+  using BRegisters = uint32_t[1];
-+  using CRegisters = uint32_t[2];
-+
-+  CUTE_HOST_DEVICE static void
-+  fma(uint32_t & d0, uint32_t & d1,
-+      uint32_t const& a0,
-+      uint32_t const& b0,
-+      uint32_t const& c0, uint32_t const& c1)
-+  {
-+#if defined(CUTE_ARCH_MMA_SM80_ENABLED)
-+    asm volatile(
-+      "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 "
-+      "{%0, %1},"
-+      "{%2},"
-+      "{%3},"
-+      "{%4, %5};\n"
-+      : "=r"(d0), "=r"(d1)
-+      : "r"(a0),
-+        "r"(b0),
-+        "r"(c0), "r"(c1));
-+#else
-+    CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
-+#endif
-+  }
-+};
-+
-+////////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+// MMA 8x8x16 TN
-+struct SM80_8x8x16_S32S8S8S32_TN_SATURATE
-+{
-+  using DRegisters = uint32_t[2];
-+  using ARegisters = uint32_t[1];
-+  using BRegisters = uint32_t[1];
-+  using CRegisters = uint32_t[2];
-+
-+  CUTE_HOST_DEVICE static void
-+  fma(uint32_t & d0, uint32_t & d1,
-+      uint32_t const& a0,
-+      uint32_t const& b0,
-+      uint32_t const& c0, uint32_t const& c1)
-+  {
-+#if defined(CUTE_ARCH_MMA_SM80_ENABLED)
-+    asm volatile(
-+      "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
-+      "{%0, %1},"
-+      "{%2},"
-+      "{%3},"
-+      "{%4, %5};\n"
-+      : "=r"(d0), "=r"(d1)
-+      : "r"(a0),
-+        "r"(b0),
-+        "r"(c0), "r"(c1));
-+#else
-+    CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
-+#endif
-+  }
-+};
-+
-+////////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+// MMA 16x8x16 TN
-+struct SM80_16x8x16_S32S8S8S32_TN
-+{
-+  using DRegisters = uint32_t[4];
-+  using ARegisters = uint32_t[2];
-+  using BRegisters = uint32_t[1];
-+  using CRegisters = uint32_t[4];
-+
-+  CUTE_HOST_DEVICE static void
-+  fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
-+      uint32_t const& a0, uint32_t const& a1,
-+      uint32_t
const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM80_16x8x16_S32S8S8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32S8S8S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32S8S8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), 
"=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x16 TN -+struct SM80_8x8x16_S32S8U8S32_TN -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k16.row.col.s32.s8.u8.s32 " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x16 TN -+struct SM80_8x8x16_S32S8U8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k16.row.col.s32.s8.u8.s32.satfinite " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM80_16x8x16_S32S8U8S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM80_16x8x16_S32S8U8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) 
-+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32S8U8S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32S8U8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x16 TN -+struct SM80_8x8x16_S32U8S8S32_TN -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32 " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; 
-+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x16 TN -+struct SM80_8x8x16_S32U8S8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32.satfinite " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM80_16x8x16_S32U8S8S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM80_16x8x16_S32U8S8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32U8S8S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if 
defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32U8S8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x16 TN -+struct SM80_8x8x16_S32U8U8S32_TN -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x16 TN -+struct SM80_8x8x16_S32U8U8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32.satfinite " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM80_16x8x16_S32U8U8S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = 
uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM80_16x8x16_S32U8U8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32U8U8S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32U8U8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ 
{ -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x32 TN -+struct SM80_8x8x32_S32S4S4S32_TN -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x32 TN -+struct SM80_8x8x32_S32S4S4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32.satfinite " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32S4S4S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s4.s4.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32S4S4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ 
CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s4.s4.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x64 TN -+struct SM80_16x8x64_S32S4S4S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x64 TN -+struct SM80_16x8x64_S32S4S4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x32 TN -+struct SM80_8x8x32_S32S4U4S32_TN -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32 " -+ "{%0, %1}," 
-+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x32 TN -+struct SM80_8x8x32_S32S4U4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32.satfinite " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32S4U4S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s4.u4.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32S4U4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s4.u4.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x64 TN -+struct SM80_16x8x64_S32S4U4S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, 
uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x64 TN -+struct SM80_16x8x64_S32S4U4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x32 TN -+struct SM80_8x8x32_S32U4S4S32_TN -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32 " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x32 TN -+struct SM80_8x8x32_S32U4S4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32.satfinite " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif 
-+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32U4S4S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u4.s4.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32U4S4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u4.s4.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x64 TN -+struct SM80_16x8x64_S32U4S4S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x64 TN -+struct SM80_16x8x64_S32U4S4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, 
uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x32 TN -+struct SM80_8x8x32_S32U4U4S32_TN -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x32 TN -+struct SM80_8x8x32_S32U4U4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32.satfinite " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32U4U4S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u4.u4.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ 
-+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32U4U4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u4.u4.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x64 TN -+struct SM80_16x8x64_S32U4U4S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x64 TN -+struct SM80_16x8x64_S32U4U4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x128 TN -+struct SM80_8x8x128_S32U1U1S32_TN_XORPOPC 
-+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.xor.popc " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x128_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x128 TN -+struct SM80_16x8x128_S32U1U1S32_TN_XORPOPC -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.xor.popc " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x128_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x256 TN -+struct SM80_16x8x256_S32U1U1S32_TN_XORPOPC -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x256_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/mma_sm90.hpp b/3rdparty/cutlass/include/cute/arch/mma_sm90.hpp -new file mode 100644 -index 0000000..08fe2b2 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/mma_sm90.hpp -@@ -0,0 +1,961 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+ -+#include -+ -+// Config -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) -+# define CUTE_ARCH_MMA_SM90_ENABLED -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cute { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x4 TN -+struct SM90_16x8x4_F64F64F64F64_TN -+{ -+ using DRegisters = double[4]; -+ using ARegisters = double[2]; -+ using BRegisters = double[1]; -+ using CRegisters = double[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(double & d0, double & d1, double & d2, double & d3, -+ double const& a0, double const& a1, -+ double const& b0, -+ double const& c0, double const& c1, double const& c2, double const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64" -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) -+ : "d"(a0), "d"(a1), -+ "d"(b0), -+ "d"(c0), "d"(c1), "d"(c2), "d"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_16x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x8 TN -+struct SM90_16x8x8_F64F64F64F64_TN -+{ -+ using DRegisters = double[4]; -+ using ARegisters = double[4]; -+ using BRegisters = double[2]; -+ using CRegisters = double[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(double & d0, double & d1, double & d2, double & d3, -+ double const& a0, double const& a1, double const& a2, double const& a3, -+ double const& b0, double const& b1, -+ double const& c0, double const& 
c1, double const& c2, double const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64" -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) -+ : "d"(a0), "d"(a1), "d"(a2), "d"(a3), -+ "d"(b0), "d"(b1), -+ "d"(c0), "d"(c1), "d"(c2), "d"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_16x8x8_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM90_16x8x16_F64F64F64F64_TN -+{ -+ using DRegisters = double[4]; -+ using ARegisters = double[8]; -+ using BRegisters = double[4]; -+ using CRegisters = double[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(double & d0, double & d1, double & d2, double & d3, -+ double const& a0, double const& a1, double const& a2, double const& a3, -+ double const& a4, double const& a5, double const& a6, double const& a7, -+ double const& b0, double const& b1, double const& b2, double const& b3, -+ double const& c0, double const& c1, double const& c2, double const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64" -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7, %8, %9, %10, %11}," -+ "{%12, %13, %14, %15}," -+ "{%16, %17, %18, %19};\n" -+ : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) -+ : "d"(a0), "d"(a1), "d"(a2), "d"(a3), -+ "d"(a4), "d"(a5), "d"(a6), "d"(a7), -+ "d"(b0), "d"(b1), "d"(b2), "d"(b3), -+ "d"(c0), "d"(c1), "d"(c2), "d"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_16x8x16_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x4 TN -+struct SM90_16x8x4_C64C64C64C64_TN -+{ -+ using DRegisters = complex[4]; -+ using ARegisters = complex[2]; -+ using BRegisters = complex[1]; -+ using CRegisters = complex[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(complex & d0, complex & d1, -+ complex & d2, complex & d3, -+ complex const& a0, complex const& a1, -+ complex const& b0, -+ complex const& c0, complex const& c1, -+ complex const& c2, complex const& c3) -+ { -+ // Because thrust::complex does not provide a mutable ref -+ double& rd0 = reinterpret_cast(d0)[0]; -+ double& id0 = reinterpret_cast(d0)[1]; -+ double& rd1 = reinterpret_cast(d1)[0]; -+ double& id1 = reinterpret_cast(d1)[1]; -+ double& rd2 = reinterpret_cast(d2)[0]; -+ double& id2 = reinterpret_cast(d2)[1]; -+ double& rd3 = reinterpret_cast(d3)[0]; -+ double& id3 = reinterpret_cast(d3)[1]; -+ -+ // d.real() = a.real() * b.real() + c.real(); -+ SM90_16x8x4_F64F64F64F64_TN::fma( -+ rd0, rd1, rd2, rd3, -+ a0.real(), a1.real(), -+ b0.real(), -+ c0.real(), c1.real(), c2.real(), c3.real()); -+ -+ // d.imag() = a.imag() * b.real() + c.imag(); -+ SM90_16x8x4_F64F64F64F64_TN::fma( -+ id0, id1, id2, id3, -+ a0.imag(), a1.imag(), -+ b0.real(), -+ c0.imag(), c1.imag(), c2.imag(), c3.imag()); -+ -+ // d.real() = -a.imag() * b.imag() + d.real(); -+ SM90_16x8x4_F64F64F64F64_TN::fma( -+ rd0, rd1, rd2, rd3, -+ -a0.imag(), -a1.imag(), -+ b0.imag(), -+ d0.real(), d1.real(), d2.real(), d3.real()); -+ -+ // d.imag() = a.real() * b.imag() + d.imag(); -+ SM90_16x8x4_F64F64F64F64_TN::fma( -+ id0, id1, id2, id3, -+ a0.real(), a1.real(), -+ b0.imag(), -+ d0.imag(), d1.imag(), d2.imag(), 
d3.imag()); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x8 TN -+struct SM90_16x8x8_C64C64C64C64_TN -+{ -+ using DRegisters = complex[4]; -+ using ARegisters = complex[4]; -+ using BRegisters = complex[2]; -+ using CRegisters = complex[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(complex & d0, complex & d1, -+ complex & d2, complex & d3, -+ complex const& a0, complex const& a1, -+ complex const& a2, complex const& a3, -+ complex const& b0, complex const& b1, -+ complex const& c0, complex const& c1, -+ complex const& c2, complex const& c3) -+ { -+ // Because thrust::complex does not provide a mutable ref -+ double& rd0 = reinterpret_cast(d0)[0]; -+ double& id0 = reinterpret_cast(d0)[1]; -+ double& rd1 = reinterpret_cast(d1)[0]; -+ double& id1 = reinterpret_cast(d1)[1]; -+ double& rd2 = reinterpret_cast(d2)[0]; -+ double& id2 = reinterpret_cast(d2)[1]; -+ double& rd3 = reinterpret_cast(d3)[0]; -+ double& id3 = reinterpret_cast(d3)[1]; -+ -+ // d.real() = a.real() * b.real() + c.real(); -+ SM90_16x8x8_F64F64F64F64_TN::fma( -+ rd0, rd1, rd2, rd3, -+ a0.real(), a1.real(), a2.real(), a3.real(), -+ b0.real(), b1.real(), -+ c0.real(), c1.real(), c2.real(), c3.real()); -+ -+ // d.imag() = a.imag() * b.real() + c.imag(); -+ SM90_16x8x8_F64F64F64F64_TN::fma( -+ id0, id1, id2, id3, -+ a0.imag(), a1.imag(), a2.imag(), a3.imag(), -+ b0.real(), b1.real(), -+ c0.imag(), c1.imag(), c2.imag(), c3.imag()); -+ -+ // d.real() = -a.imag() * b.imag() + d.real(); -+ SM90_16x8x8_F64F64F64F64_TN::fma( -+ rd0, rd1, rd2, rd3, -+ -a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(), -+ b0.imag(), b1.imag(), -+ d0.real(), d1.real(), d2.real(), d3.real()); -+ -+ // d.imag() = a.real() * b.imag() + d.imag(); -+ SM90_16x8x8_F64F64F64F64_TN::fma( -+ id0, id1, id2, id3, -+ a0.real(), a1.real(), a2.real(), a3.real(), -+ b0.imag(), b1.imag(), -+ d0.imag(), d1.imag(), d2.imag(), d3.imag()); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM90_16x8x16_C64C64C64C64_TN -+{ -+ using DRegisters = complex[4]; -+ using ARegisters = complex[8]; -+ using BRegisters = complex[4]; -+ using CRegisters = complex[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(complex & d0, complex & d1, -+ complex & d2, complex & d3, -+ complex const& a0, complex const& a1, -+ complex const& a2, complex const& a3, -+ complex const& a4, complex const& a5, -+ complex const& a6, complex const& a7, -+ complex const& b0, complex const& b1, -+ complex const& b2, complex const& b3, -+ complex const& c0, complex const& c1, -+ complex const& c2, complex const& c3) -+ { -+ // Because thrust::complex does not provide a mutable ref -+ double& rd0 = reinterpret_cast(d0)[0]; -+ double& id0 = reinterpret_cast(d0)[1]; -+ double& rd1 = reinterpret_cast(d1)[0]; -+ double& id1 = reinterpret_cast(d1)[1]; -+ double& rd2 = reinterpret_cast(d2)[0]; -+ double& id2 = reinterpret_cast(d2)[1]; -+ double& rd3 = reinterpret_cast(d3)[0]; -+ double& id3 = reinterpret_cast(d3)[1]; -+ -+ // d.real() = a.real() * b.real() + c.real(); -+ SM90_16x8x16_F64F64F64F64_TN::fma( -+ rd0, rd1, rd2, rd3, -+ a0.real(), a1.real(), a2.real(), a3.real(), -+ a4.real(), a5.real(), a6.real(), a7.real(), -+ b0.real(), b1.real(), b2.real(), b3.real(), -+ c0.real(), c1.real(), c2.real(), c3.real()); -+ -+ // d.imag() = a.imag() * b.real() + c.imag(); -+ SM90_16x8x16_F64F64F64F64_TN::fma( -+ id0, id1, id2, id3, -+ a0.imag(), 
a1.imag(), a2.imag(), a3.imag(), -+ a4.imag(), a5.imag(), a6.imag(), a7.imag(), -+ b0.real(), b1.real(), b2.real(), b3.real(), -+ c0.imag(), c1.imag(), c2.imag(), c3.imag()); -+ -+ // d.real() = -a.imag() * b.imag() + d.real(); -+ SM90_16x8x16_F64F64F64F64_TN::fma( -+ rd0, rd1, rd2, rd3, -+ -a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(), -+ -a4.imag(), -a5.imag(), -a6.imag(), -a7.imag(), -+ b0.imag(), b1.imag(), b2.imag(), b3.imag(), -+ d0.real(), d1.real(), d2.real(), d3.real()); -+ -+ // d.imag() = a.real() * b.imag() + d.imag(); -+ SM90_16x8x16_F64F64F64F64_TN::fma( -+ id0, id1, id2, id3, -+ a0.real(), a1.real(), a2.real(), a3.real(), -+ a4.real(), a5.real(), a6.real(), a7.real(), -+ b0.imag(), b1.imag(), b2.imag(), b3.imag(), -+ d0.imag(), d1.imag(), d2.imag(), d3.imag()); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cute -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+#include -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cute { -+namespace GMMA { -+ -+template< -+ class ElementA, -+ class ElementB, -+ class ElementC, -+ class TileShape_MNK, -+ GMMA::Major MajorA = GMMA::Major::K, -+ GMMA::Major MajorB = GMMA::Major::K, -+ auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One] -+ // But most commonly leave empty for defaults -+> -+CUTE_HOST_DEVICE constexpr -+auto -+ss_op_selector() -+{ -+ static_assert(is_static::value, "TileShape_MNK must be static."); -+ static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); -+ static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); -+ auto Tile_N = size<1>(TileShape_MNK{}); -+ -+ // FP16 accumulator -+ if constexpr (std::is_same_v) { -+ static_assert(std::is_same_v, "Element types for AB must be half if ElementC is half."); -+ static_assert(std::is_same_v, "Element types for AB must be half if ElementC is half."); -+ static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); -+ -+ // Dispatch against the Tile N mode size -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x16_F16F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x16_F16F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x16_F16F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x16_F16F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x16_F16F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x16_F16F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x16_F16F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x16_F16F16F16_SS{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // FP32 accumulator -+ else if constexpr (std::is_same_v) { -+ -+ // FP16 inputs -+ if constexpr (std::is_same_v) { -+ static_assert(std::is_same_v, "ElementA and ElementB must be the same type for this config."); -+ static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x16_F32F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x16_F32F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { 
-+ return SM90_64x128x16_F32F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x16_F32F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x16_F32F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x16_F32F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x16_F32F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x16_F32F16F16_SS{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // BF16 inputs -+ else if constexpr (std::is_same_v) { -+ static_assert(std::is_same_v, "ElementA and ElementB must be the same type for this config."); -+ static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); -+ -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x16_F32BF16BF16_SS{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x16_F32BF16BF16_SS{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x16_F32BF16BF16_SS{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x16_F32BF16BF16_SS{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x16_F32BF16BF16_SS{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x16_F32BF16BF16_SS{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x16_F32BF16BF16_SS{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x16_F32BF16BF16_SS{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // TF32 inputs -+ else if constexpr (std::is_same_v) { -+ static_assert(std::is_same_v, "ElementA and ElementB must be the same type for this config."); -+ static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); -+ static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); -+ static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); -+ -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x8_F32TF32TF32_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x8_F32TF32TF32_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x8_F32TF32TF32_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x8_F32TF32TF32_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x8_F32TF32TF32_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x8_F32TF32TF32_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x8_F32TF32TF32_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x8_F32TF32TF32_SS_TN{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ else { -+ static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); -+ } -+ } -+ -+ // S32 accumulator -+ else if constexpr (std::is_same_v) { -+ static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); -+ static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); -+ static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); -+ -+ // ElementA == int8_t && ElementB == int8_t -+ if constexpr (std::is_same_v && std::is_same_v) { -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x32_S32S8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N 
% 192 == 0) { -+ return SM90_64x192x32_S32S8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x32_S32S8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x32_S32S8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x32_S32S8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x32_S32S8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x32_S32S8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x32_S32S8S8_SS_TN{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // ElementA == int8_t && ElementB == uint8_t -+ else if constexpr (std::is_same_v && std::is_same_v) { -+ static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); -+ -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x32_S32S8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x32_S32S8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x32_S32S8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x32_S32S8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x32_S32S8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x32_S32S8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x32_S32S8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x32_S32S8U8_SS_TN{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // ElementA == uint8_t && ElementB == int8_t -+ else if constexpr (std::is_same_v && std::is_same_v) { -+ static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); -+ -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x32_S32U8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x32_S32U8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x32_S32U8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x32_S32U8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x32_S32U8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x32_S32U8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x32_S32U8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x32_S32U8S8_SS_TN{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // ElementA == uint8_t && ElementB == uint8_t -+ else if constexpr (std::is_same_v && std::is_same_v) { -+ static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); -+ -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x32_S32U8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x32_S32U8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x32_S32U8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x32_S32U8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x32_S32U8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x32_S32U8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x32_S32U8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return 
SM90_64x8x32_S32U8U8_SS_TN{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ } -+ -+ // Unknown accumulator type -+ else { -+ static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); -+ } -+} -+ -+template< -+ class ElementA, -+ class ElementB, -+ class ElementC, -+ class TileShape_MNK, -+ GMMA::Major MajorA = GMMA::Major::K, -+ GMMA::Major MajorB = GMMA::Major::K, -+ auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One] -+ // But most commonly leave empty for defaults -+> -+CUTE_HOST_DEVICE constexpr -+auto -+rs_op_selector() -+{ -+ static_assert(is_static::value, "TileShape_MNK must be static."); -+ static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); -+ static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); -+ static_assert(MajorA == GMMA::Major::K, "Register source A operand GMMAs must have K-major A layout."); -+ auto Tile_N = size<1>(TileShape_MNK{}); -+ -+ // FP16 accumulator -+ if constexpr (std::is_same_v) { -+ static_assert(std::is_same_v, "Element types for AB must be half if ElementC is half."); -+ static_assert(std::is_same_v, "Element types for AB must be half if ElementC is half."); -+ static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); -+ -+ // Dispatch against the Tile N mode size -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x16_F16F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x16_F16F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x16_F16F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x16_F16F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x16_F16F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x16_F16F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x16_F16F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x16_F16F16F16_RS{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // FP32 accumulator -+ else if constexpr (std::is_same_v) { -+ static_assert(std::is_same_v, "ElementA and ElementB must be the same type for this config."); -+ static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); -+ -+ // FP16 inputs -+ if constexpr (std::is_same_v) { -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x16_F32F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x16_F32F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x16_F32F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x16_F32F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x16_F32F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x16_F32F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x16_F32F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x16_F32F16F16_RS{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // BF16 inputs -+ else if constexpr (std::is_same_v) { -+ static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); -+ -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x16_F32BF16BF16_RS{}; -+ } -+ else if constexpr (Tile_N 
% 192 == 0) { -+ return SM90_64x192x16_F32BF16BF16_RS{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x16_F32BF16BF16_RS{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x16_F32BF16BF16_RS{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x16_F32BF16BF16_RS{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x16_F32BF16BF16_RS{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x16_F32BF16BF16_RS{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x16_F32BF16BF16_RS{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // TF32 inputs -+ else if constexpr (std::is_same_v) { -+ static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); -+ static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); -+ -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x8_F32TF32TF32_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x8_F32TF32TF32_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x8_F32TF32TF32_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x8_F32TF32TF32_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x8_F32TF32TF32_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x8_F32TF32TF32_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x8_F32TF32TF32_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x8_F32TF32TF32_RS_TN{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ else { -+ static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); -+ } -+ } -+ -+ // S32 accumulator -+ else if constexpr (std::is_same_v) { -+ static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); -+ static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); -+ -+ // ElementA == int8_t && ElementB == int8_t -+ if constexpr (std::is_same_v && std::is_same_v) { -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x32_S32S8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x32_S32S8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x32_S32S8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x32_S32S8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x32_S32S8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x32_S32S8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x32_S32S8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x32_S32S8S8_RS_TN{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // ElementA == int8_t && ElementB == uint8_t -+ else if constexpr (std::is_same_v && std::is_same_v) { -+ static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); -+ -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x32_S32S8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x32_S32S8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x32_S32S8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return 
SM90_64x96x32_S32S8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x32_S32S8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x32_S32S8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x32_S32S8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x32_S32S8U8_RS_TN{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // ElementA == uint8_t && ElementB == int8_t -+ else if constexpr (std::is_same_v && std::is_same_v) { -+ static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); -+ -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x32_S32U8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x32_S32U8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x32_S32U8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x32_S32U8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x32_S32U8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x32_S32U8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x32_S32U8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x32_S32U8S8_RS_TN{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // ElementA == uint8_t && ElementB == uint8_t -+ else if constexpr (std::is_same_v && std::is_same_v) { -+ static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); -+ -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x32_S32U8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x32_S32U8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x32_S32U8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x32_S32U8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x32_S32U8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x32_S32U8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x32_S32U8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x32_S32U8U8_RS_TN{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ } -+ -+ // Unknown accumulator type -+ else { -+ static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); -+ } -+} -+} // end namespace GMMA -+} // end namespace cute -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cute/arch/mma_sm90_desc.hpp b/3rdparty/cutlass/include/cute/arch/mma_sm90_desc.hpp -new file mode 100644 -index 0000000..abac517 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/mma_sm90_desc.hpp -@@ -0,0 +1,131 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+ -+#include -+ -+// Config -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) -+# define CUTE_ARCH_MMA_SM90_ENABLED -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cute { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// GMMA Descriptor and utilities -+ -+// GMMA enums and utilities -+namespace GMMA -+{ -+ -+enum class LayoutType : uint8_t { -+ INTERLEAVE = 0, -+ B128 = 1, -+ B64 = 2, -+ B32 = 3, -+}; -+ -+CUTE_HOST_DEVICE char const* to_string(LayoutType const& t) { -+ switch (t) { -+ case LayoutType::INTERLEAVE: return "INTERLEAVE"; -+ case LayoutType::B128: return "B128"; -+ case LayoutType::B64: return "B64"; -+ case LayoutType::B32: return "B32"; -+ } -+ return nullptr; -+} -+ -+// Output operator for all enums in this namespace -+CUTE_HOST std::ostream& operator<<(std::ostream& os, LayoutType const& t) { -+ char const* s = to_string(t); -+ if (s) { -+ std::operator<<(os, s); // Explicit call to avoid ambiguity -+ } else { -+ os.setstate(std::ios_base::failbit); -+ } -+ return os; -+} -+ -+} // end namespace GMMA -+ -+union GmmaDescriptor -+{ -+ uint64_t desc_; -+ uint32_t reg32_[2]; -+ uint16_t reg16_[4]; -+ -+ // Bitfield implementation avoids the need for shifts in assignment -+ struct { -+ // start_address, bit [0,14), 4LSB not included -+ uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused -+ // leading dimension byte offset, bit [16,30), 4LSB not included -+ // For N: This is the stride from the first col to the second col of the 8x2 brick in INTERLEAVED -+ // Unused for all SWIZZLE_* layouts (and assumed to be 1) -+ // For T: This is the stride from the first 8 rows to the next 8 rows. 
-+ uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused -+ // stride dimension byte offset, bit [32,46), 4LSB not included -+ // For N: This is the stride from the first 8 rows to the next 8 rows. -+ // For T: This is the stride fro mthe first 8 cols to the next 8 cols. -+ uint16_t stride_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused -+ // base_offset, bit [49,52) -+ // Valid only for SWIZZLE_128B and SWIZZLE_64B -+ uint8_t : 1, base_offset_ : 3, : 4; // 1 bit unused, 3 bits [1,4), 4 bits unused -+ // layout type, bit [62,64) -+ // SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1 -+ uint8_t : 6, layout_type_ : 2; // 6 bits unused, 2 bits [6,8) -+ }; -+ -+ // Decay to a uint64_t -+ CUTE_HOST_DEVICE constexpr -+ operator uint64_t() const noexcept { return desc_; } -+ -+ // Printer -+ CUTE_HOST_DEVICE friend void print(GmmaDescriptor const& t) -+ { -+ printf("GmmaDescriptor: 0x%016lx\n", t.desc_); -+ printf(" start_addr : 0x%04x\n", t.start_address_); -+ printf(" leading_off: 0x%04x (%d)\n", t.leading_byte_offset_, t.leading_byte_offset_); -+ printf(" stride_off : 0x%04x (%d)\n", t.stride_byte_offset_, t.stride_byte_offset_); -+ printf(" base_offset: 0x%01x\n", t.base_offset_); -+ printf(" layout_type: 0x%01x (%s)\n", t.layout_type_, to_string(static_cast(t.layout_type_))); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cute -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cute/arch/mma_sm90_gmma.hpp b/3rdparty/cutlass/include/cute/arch/mma_sm90_gmma.hpp -new file mode 100644 -index 0000000..25a1d17 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/mma_sm90_gmma.hpp -@@ -0,0 +1,12265 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
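The GmmaDescriptor union defined above in mma_sm90_desc.hpp packs the shared-memory operand metadata that the wgmma instructions in the next header consume: a 14-bit start address, leading- and stride-dimension byte offsets (each stored without its 4 LSBs), a 3-bit base offset, and a 2-bit layout/swizzle code. A minimal host-side sketch of how those bitfields map onto the 64-bit descriptor follows; the offsets, the swizzle choice, and the include path are illustrative assumptions, not values taken from this patch.

// Sketch only. Assumes the CuTe headers added by this patch are on the
// include path and the translation unit is compiled with nvcc (the header
// pulls in CUTE_HOST_DEVICE and related macros from its own includes).
#include <cstdint>
#include <cute/arch/mma_sm90_desc.hpp>

int main() {
  cute::GmmaDescriptor desc{};                                  // desc_ == 0
  // Addresses and offsets drop their 4 LSBs (see the field comments above),
  // so a byte offset X is stored as X >> 4. All values below are hypothetical.
  desc.start_address_       = 0x0400 >> 4;                      // smem byte offset of the tile
  desc.leading_byte_offset_ = 128   >> 4;                       // leading-dimension stride in bytes
  desc.stride_byte_offset_  = 1024  >> 4;                       // 8-row stride in bytes
  desc.base_offset_         = 0;
  desc.layout_type_ = uint8_t(cute::GMMA::LayoutType::B128);    // SWIZZLE_128B == 1
  uint64_t raw = desc;   // operator uint64_t(): the value handed to wgmma
  print(desc);           // friend printer defined in the union, found via ADL
  return raw == 0 ? 1 : 0;
}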
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+ -+// Config -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) -+# define CUTE_ARCH_MMA_SM90_ENABLED -+#endif -+ -+namespace cute { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// Warpgroup sync primitives -+ -+CUTE_HOST_DEVICE -+void -+warpgroup_arrive() -+{ -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile ("wgmma.fence.sync.aligned;\n" ::: "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use wgmma.fence without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+warpgroup_wait() -+{ -+ static_assert(N >= 0 && N <= 7, "_warpgroup.wait {N}; must be in range [0, 7]"); -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use wgmma.wait_group without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+} -+ -+// Marks the commit point for one or more sized batch of warpgroup MMAs. -+CUTE_HOST_DEVICE -+void -+warpgroup_commit_batch() -+{ -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use wgmma.commit_group without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+} -+ -+CUTE_HOST_DEVICE -+void -+warpgroup_fence_operand(uint32_t& reg) { -+ asm volatile("" : "+r"(reg) :: "memory"); -+} -+ -+CUTE_HOST_DEVICE -+void -+warpgroup_fence_operand(float& reg) { -+ asm volatile("" : "+f"(reg) :: "memory"); -+} -+ -+namespace GMMA { -+ -+enum class Major { -+ K = 0, -+ MN = 1 -+}; -+ -+enum class ScaleOut { -+ Zero = 0, -+ One = 1 -+}; -+ -+enum class ScaleIn { -+ Neg = -1, -+ One = 1 -+}; -+ -+} // namespace GMMA -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// GMMA PTX definitions: C = (scaleA * A) * (scaleB * B) + (scaleD * C) -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x8x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x8x16_F16F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " -+ "{%0, %1}," -+ " %2," -+ " %3," -+ " %4, %5, %6, %7, %8;\n" -+ : "+r"(d0), "+r"(d1) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x8x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x8x16_F16F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " -+ "{%0, %1}," -+ "{%2, %3, %4, %5}," -+ " %6," -+ " %7, %8, %9, %10;\n" -+ : "+r"(d0), "+r"(d1) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x16x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x16x16_F16F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6, %7, %8, %9, %10;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x16x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x16x16_F16F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if 
defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9, %10, %11, %12;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x32x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x32x16_F16F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10, %11, %12, %13, %14;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x32x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x32x16_F16F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13, %14, %15, %16;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x64x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = 
GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x64x16_F16F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18, %19, %20, %21, %22;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x64x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x64x16_F16F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21, %22, %23, %24;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x96x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x96x16_F16F16F16_SS -+{ -+ using 
DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[24]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23}," -+ " %24," -+ " %25," -+ " %26, %27, %28, %29, %30;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x96x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x96x16_F16F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[24]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23}," -+ "{%24, %25, %26, %27}," -+ " %28," -+ " %29, %30, %31, %32;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ 
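Each of the SS-flavored GMMA structs above takes 64-bit shared-memory descriptors for A and B and accumulates in place into the registers passed by reference, while the warpgroup primitives at the top of this header (warpgroup_arrive, warpgroup_commit_batch, warpgroup_wait, warpgroup_fence_operand) order the asynchronous MMAs. Below is a minimal device-side sketch of that call sequence, assuming valid descriptors and shared-memory tiles are prepared elsewhere and the file is compiled for sm_90a (so __CUDA_ARCH_FEAT_SM90_ALL and hence CUTE_ARCH_MMA_SM90_ENABLED are defined).

#include <cstdint>
#include <cute/arch/mma_sm90_gmma.hpp>

// Sketch only: one 64x8x16 F16 GMMA issued by a full warpgroup (128 threads).
// desc_a / desc_b are assumed to be valid GMMA descriptors for K-major tiles
// already staged in shared memory; building them is not shown here.
__device__ void gmma_64x8x16_once(uint64_t desc_a, uint64_t desc_b,
                                  uint32_t& d0, uint32_t& d1)
{
  using Op = cute::SM90_64x8x16_F16F16F16_SS<cute::GMMA::Major::K,
                                             cute::GMMA::Major::K>;  // scaleD/scaleA/scaleB default to One
  cute::warpgroup_fence_operand(d0);   // keep the compiler from moving accumulator accesses
  cute::warpgroup_fence_operand(d1);
  cute::warpgroup_arrive();            // wgmma.fence: registers are ready for the async MMA
  Op::fma(desc_a, desc_b, d0, d1);     // D += A * B, accumulating in place
  cute::warpgroup_commit_batch();      // wgmma.commit_group: close this batch
  cute::warpgroup_wait<0>();           // wgmma.wait_group 0: block until all batches finish
  cute::warpgroup_fence_operand(d0);
  cute::warpgroup_fence_operand(d1);
}

In a real mainloop the wait would typically use a nonzero N so several batches stay in flight, and the selector functions in mma_sm90.hpp (ss_op_selector / rs_op_selector) would pick one of these structs from the element types and tile shape instead of naming it directly.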
-+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x128x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x128x16_F16F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34, %35, %36, %37, %38;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x128x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x128x16_F16F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ 
"wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37, %38, %39, %40;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x192x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x192x16_F16F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50, %51, %52, %53, %54;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ 
CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x192x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x192x16_F16F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53, %54, %55, %56;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x256x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x256x16_F16F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE 
static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66, %67, %68, %69, %70;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x256x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x256x16_F16F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & 
d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69, %70, %71, %72;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x8x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x8x16_F32F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6, %7, %8, %9, %10;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x8x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x8x16_F32F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[4]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9, %10, %11, %12;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x16x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x16x16_F32F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3, -+ float & d4, float & d5, float & d6, float & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10, %11, %12, %13, %14;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), -+ "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x16x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x16x16_F32F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[8]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t 
const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3, -+ float & d4, float & d5, float & d6, float & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13, %14, %15, %16;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), -+ "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x32x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x32x16_F32F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18, %19, %20, %21, %22;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x32x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x32x16_F32F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[16]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21, %22, 
%23, %24;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x64x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x64x16_F32F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34, %35, %36, %37, %38;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x64x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x64x16_F32F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[32]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & 
d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37, %38, %39, %40;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x96x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x96x16_F32F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50, %51, %52, %53, %54;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ 
"n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x96x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x96x16_F32F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[48]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53, %54, %55, %56;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x128x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x128x16_F32F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& 
desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66, %67, %68, %69, %70;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x128x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x128x16_F32F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[64]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, 
float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69, %70, %71, %72;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x192x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x192x16_F32F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & 
d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63, -+ float & d64, float & d65, float & d66, float & d67, -+ float & d68, float & d69, float & d70, float & d71, -+ float & d72, float & d73, float & d74, float & d75, -+ float & d76, float & d77, float & d78, float & d79, -+ float & d80, float & d81, float & d82, float & d83, -+ float & d84, float & d85, float & d86, float & d87, -+ float & d88, float & d89, float & d90, float & d91, -+ float & d92, float & d93, float & d94, float & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ " %96," -+ " %97," -+ " %98, %99, %100, %101, %102;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), -+ "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), -+ "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), -+ "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), -+ "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), -+ "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), -+ "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), -+ "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), -+ "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x192x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x192x16_F32F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[96]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, 
float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63, -+ float & d64, float & d65, float & d66, float & d67, -+ float & d68, float & d69, float & d70, float & d71, -+ float & d72, float & d73, float & d74, float & d75, -+ float & d76, float & d77, float & d78, float & d79, -+ float & d80, float & d81, float & d82, float & d83, -+ float & d84, float & d85, float & d86, float & d87, -+ float & d88, float & d89, float & d90, float & d91, -+ float & d92, float & d93, float & d94, float & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ "{%96, %97, %98, %99}," -+ " %100," -+ " %101, %102, %103, %104;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), -+ "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), -+ "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), -+ "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), -+ "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), -+ "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), -+ "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), -+ "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), -+ "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x256x16 F32+=F16*F16 -+template< 
-+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x256x16_F32F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d000, float & d001, float & d002, float & d003, -+ float & d004, float & d005, float & d006, float & d007, -+ float & d008, float & d009, float & d010, float & d011, -+ float & d012, float & d013, float & d014, float & d015, -+ float & d016, float & d017, float & d018, float & d019, -+ float & d020, float & d021, float & d022, float & d023, -+ float & d024, float & d025, float & d026, float & d027, -+ float & d028, float & d029, float & d030, float & d031, -+ float & d032, float & d033, float & d034, float & d035, -+ float & d036, float & d037, float & d038, float & d039, -+ float & d040, float & d041, float & d042, float & d043, -+ float & d044, float & d045, float & d046, float & d047, -+ float & d048, float & d049, float & d050, float & d051, -+ float & d052, float & d053, float & d054, float & d055, -+ float & d056, float & d057, float & d058, float & d059, -+ float & d060, float & d061, float & d062, float & d063, -+ float & d064, float & d065, float & d066, float & d067, -+ float & d068, float & d069, float & d070, float & d071, -+ float & d072, float & d073, float & d074, float & d075, -+ float & d076, float & d077, float & d078, float & d079, -+ float & d080, float & d081, float & d082, float & d083, -+ float & d084, float & d085, float & d086, float & d087, -+ float & d088, float & d089, float & d090, float & d091, -+ float & d092, float & d093, float & d094, float & d095, -+ float & d096, float & d097, float & d098, float & d099, -+ float & d100, float & d101, float & d102, float & d103, -+ float & d104, float & d105, float & d106, float & d107, -+ float & d108, float & d109, float & d110, float & d111, -+ float & d112, float & d113, float & d114, float & d115, -+ float & d116, float & d117, float & d118, float & d119, -+ float & d120, float & d121, float & d122, float & d123, -+ float & d124, float & d125, float & d126, float & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ " %128," -+ " %129," -+ " %130, %131, %132, %133, %134;\n" -+ : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), -+ "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), -+ "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), -+ "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), -+ "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), -+ 
"+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), -+ "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), -+ "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), -+ "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), -+ "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), -+ "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), -+ "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), -+ "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), -+ "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), -+ "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), -+ "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), -+ "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), -+ "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), -+ "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), -+ "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), -+ "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), -+ "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), -+ "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), -+ "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), -+ "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), -+ "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), -+ "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), -+ "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), -+ "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), -+ "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), -+ "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), -+ "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x256x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x256x16_F32F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[128]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, -+ uint64_t const& desc_b, -+ float & d000, float & d001, float & d002, float & d003, -+ float & d004, float & d005, float & d006, float & d007, -+ float & d008, float & d009, float & d010, float & d011, -+ float & d012, float & d013, float & d014, float & d015, -+ float & d016, float & d017, float & d018, float & d019, -+ float & d020, float & d021, float & d022, float & d023, -+ float & d024, float & d025, float & d026, float & d027, -+ float & d028, float & d029, float & d030, float & d031, -+ float & d032, float & d033, float & d034, float & d035, -+ float & d036, float & d037, float & d038, float & d039, -+ float & d040, float & d041, float & d042, float & d043, -+ float & d044, float & d045, float & d046, float & d047, -+ float & d048, float & d049, float & d050, float & d051, -+ float & d052, float & d053, float & d054, float & d055, -+ float & d056, float & d057, float & d058, float & d059, -+ float & d060, float & d061, float & d062, float & d063, -+ float & d064, float & d065, float & d066, float & d067, -+ float & d068, float & d069, float & d070, float & d071, -+ float & d072, float & d073, float & d074, float & d075, -+ 
float & d076, float & d077, float & d078, float & d079, -+ float & d080, float & d081, float & d082, float & d083, -+ float & d084, float & d085, float & d086, float & d087, -+ float & d088, float & d089, float & d090, float & d091, -+ float & d092, float & d093, float & d094, float & d095, -+ float & d096, float & d097, float & d098, float & d099, -+ float & d100, float & d101, float & d102, float & d103, -+ float & d104, float & d105, float & d106, float & d107, -+ float & d108, float & d109, float & d110, float & d111, -+ float & d112, float & d113, float & d114, float & d115, -+ float & d116, float & d117, float & d118, float & d119, -+ float & d120, float & d121, float & d122, float & d123, -+ float & d124, float & d125, float & d126, float & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ "{%128, %129, %130, %131}," -+ " %132," -+ " %133, %134, %135, %136;\n" -+ : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), -+ "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), -+ "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), -+ "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), -+ "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), -+ "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), -+ "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), -+ "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), -+ "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), -+ "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), -+ "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), -+ "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), -+ "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), -+ "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), -+ "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), -+ "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), -+ "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), -+ "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), -+ "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), -+ "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), -+ "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), -+ "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), -+ "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), -+ "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), -+ "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), -+ "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), -+ "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), -+ "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), -+ "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), -+ "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), -+ "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), -+ "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) -+ : "r"(a000), "r"(a001), "r"(a002), "r"(a003), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ 
CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x8x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x8x16_F32BF16BF16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6, %7, %8, %9, %10;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x8x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x8x16_F32BF16BF16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[4]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9, %10, %11, %12;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x16x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x16x16_F32BF16BF16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3, -+ float & d4, float & d5, float & d6, float & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10, %11, 
%12, %13, %14;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), -+ "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x16x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x16x16_F32BF16BF16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[8]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3, -+ float & d4, float & d5, float & d6, float & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13, %14, %15, %16;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), -+ "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x32x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x32x16_F32BF16BF16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18, %19, %20, %21, %22;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x32x16 
F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x32x16_F32BF16BF16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[16]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21, %22, %23, %24;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x64x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x64x16_F32BF16BF16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34, %35, %36, %37, %38;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ 
CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x64x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x64x16_F32BF16BF16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[32]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37, %38, %39, %40;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x96x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x96x16_F32BF16BF16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & 
d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50, %51, %52, %53, %54;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x96x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x96x16_F32BF16BF16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[48]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53, %54, %55, %56;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), 
"+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x128x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x128x16_F32BF16BF16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66, %67, %68, %69, %70;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ 
CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x128x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x128x16_F32BF16BF16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[64]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69, %70, %71, %72;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x192x16 F32+=BF16*BF16 -+template< -+ GMMA::Major 
tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x192x16_F32BF16BF16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63, -+ float & d64, float & d65, float & d66, float & d67, -+ float & d68, float & d69, float & d70, float & d71, -+ float & d72, float & d73, float & d74, float & d75, -+ float & d76, float & d77, float & d78, float & d79, -+ float & d80, float & d81, float & d82, float & d83, -+ float & d84, float & d85, float & d86, float & d87, -+ float & d88, float & d89, float & d90, float & d91, -+ float & d92, float & d93, float & d94, float & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ " %96," -+ " %97," -+ " %98, %99, %100, %101, %102;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), -+ "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), -+ "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), -+ "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), -+ "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), -+ "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), -+ "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), -+ "+f"(d88), 
"+f"(d89), "+f"(d90), "+f"(d91), -+ "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x192x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x192x16_F32BF16BF16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[96]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63, -+ float & d64, float & d65, float & d66, float & d67, -+ float & d68, float & d69, float & d70, float & d71, -+ float & d72, float & d73, float & d74, float & d75, -+ float & d76, float & d77, float & d78, float & d79, -+ float & d80, float & d81, float & d82, float & d83, -+ float & d84, float & d85, float & d86, float & d87, -+ float & d88, float & d89, float & d90, float & d91, -+ float & d92, float & d93, float & d94, float & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ "{%96, %97, %98, %99}," -+ " %100," -+ " %101, %102, %103, %104;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), 
"+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), -+ "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), -+ "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), -+ "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), -+ "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), -+ "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), -+ "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), -+ "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), -+ "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x256x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x256x16_F32BF16BF16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d000, float & d001, float & d002, float & d003, -+ float & d004, float & d005, float & d006, float & d007, -+ float & d008, float & d009, float & d010, float & d011, -+ float & d012, float & d013, float & d014, float & d015, -+ float & d016, float & d017, float & d018, float & d019, -+ float & d020, float & d021, float & d022, float & d023, -+ float & d024, float & d025, float & d026, float & d027, -+ float & d028, float & d029, float & d030, float & d031, -+ float & d032, float & d033, float & d034, float & d035, -+ float & d036, float & d037, float & d038, float & d039, -+ float & d040, float & d041, float & d042, float & d043, -+ float & d044, float & d045, float & d046, float & d047, -+ float & d048, float & d049, float & d050, float & d051, -+ float & d052, float & d053, float & d054, float & d055, -+ float & d056, float & d057, float & d058, float & d059, -+ float & d060, float & d061, float & d062, float & d063, -+ float & d064, float & d065, float & d066, float & d067, -+ float & d068, float & d069, float & d070, float & d071, -+ float & d072, float & d073, float & d074, float & d075, -+ float & d076, float & d077, float & d078, float & d079, -+ float & d080, float & d081, float & d082, float & d083, -+ float & d084, float & d085, float & d086, float & d087, -+ float & d088, float & d089, float & d090, float & d091, -+ float & d092, float & d093, float & d094, float & d095, -+ float & d096, float & d097, float & d098, float & d099, -+ float & d100, float & d101, float & d102, float & d103, -+ float & d104, float & d105, float & d106, float & d107, -+ float & d108, float & d109, float & d110, float & d111, -+ float & d112, float & d113, float & d114, float & d115, -+ float & d116, float & d117, float & d118, float & d119, -+ float & d120, float & d121, float & d122, float & d123, -+ float & d124, float & d125, float & 
d126, float & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ " %128," -+ " %129," -+ " %130, %131, %132, %133, %134;\n" -+ : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), -+ "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), -+ "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), -+ "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), -+ "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), -+ "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), -+ "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), -+ "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), -+ "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), -+ "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), -+ "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), -+ "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), -+ "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), -+ "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), -+ "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), -+ "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), -+ "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), -+ "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), -+ "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), -+ "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), -+ "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), -+ "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), -+ "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), -+ "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), -+ "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), -+ "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), -+ "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), -+ "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), -+ "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), -+ "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), -+ "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), -+ "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x256x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x256x16_F32BF16BF16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[128]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ 
fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, -+ uint64_t const& desc_b, -+ float & d000, float & d001, float & d002, float & d003, -+ float & d004, float & d005, float & d006, float & d007, -+ float & d008, float & d009, float & d010, float & d011, -+ float & d012, float & d013, float & d014, float & d015, -+ float & d016, float & d017, float & d018, float & d019, -+ float & d020, float & d021, float & d022, float & d023, -+ float & d024, float & d025, float & d026, float & d027, -+ float & d028, float & d029, float & d030, float & d031, -+ float & d032, float & d033, float & d034, float & d035, -+ float & d036, float & d037, float & d038, float & d039, -+ float & d040, float & d041, float & d042, float & d043, -+ float & d044, float & d045, float & d046, float & d047, -+ float & d048, float & d049, float & d050, float & d051, -+ float & d052, float & d053, float & d054, float & d055, -+ float & d056, float & d057, float & d058, float & d059, -+ float & d060, float & d061, float & d062, float & d063, -+ float & d064, float & d065, float & d066, float & d067, -+ float & d068, float & d069, float & d070, float & d071, -+ float & d072, float & d073, float & d074, float & d075, -+ float & d076, float & d077, float & d078, float & d079, -+ float & d080, float & d081, float & d082, float & d083, -+ float & d084, float & d085, float & d086, float & d087, -+ float & d088, float & d089, float & d090, float & d091, -+ float & d092, float & d093, float & d094, float & d095, -+ float & d096, float & d097, float & d098, float & d099, -+ float & d100, float & d101, float & d102, float & d103, -+ float & d104, float & d105, float & d106, float & d107, -+ float & d108, float & d109, float & d110, float & d111, -+ float & d112, float & d113, float & d114, float & d115, -+ float & d116, float & d117, float & d118, float & d119, -+ float & d120, float & d121, float & d122, float & d123, -+ float & d124, float & d125, float & d126, float & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ "{%128, %129, %130, %131}," -+ " %132," -+ " %133, %134, %135, %136;\n" -+ : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), -+ "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), -+ "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), -+ "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), -+ "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), -+ "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), -+ "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), -+ "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), -+ "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), -+ "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), -+ "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), -+ "+f"(d044), 
"+f"(d045), "+f"(d046), "+f"(d047), -+ "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), -+ "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), -+ "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), -+ "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), -+ "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), -+ "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), -+ "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), -+ "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), -+ "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), -+ "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), -+ "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), -+ "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), -+ "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), -+ "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), -+ "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), -+ "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), -+ "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), -+ "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), -+ "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), -+ "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) -+ : "r"(a000), "r"(a001), "r"(a002), "r"(a003), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x8x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x8x8_F32TF32TF32_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6, %7, %8;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x8x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x8x8_F32TF32TF32_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9, %10, %11;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x8_F32TF32TF32_RS_TN 
without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x16x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x16x8_F32TF32TF32_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3, -+ float & d4, float & d5, float & d6, float & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10, %11, %12;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), -+ "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x16x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x16x8_F32TF32TF32_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3, -+ float & d4, float & d5, float & d6, float & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13, %14, %15;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), -+ "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x32x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x32x8_F32TF32TF32_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " 
%16," -+ " %17," -+ " %18, %19, %20;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x32x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x32x8_F32TF32TF32_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21, %22, %23;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x64x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x64x8_F32TF32TF32_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34, %35, %36;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), 
"+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x64x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x64x8_F32TF32TF32_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37, %38, %39;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x96x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x96x8_F32TF32TF32_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float 
& d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50, %51, %52;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x96x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x96x8_F32TF32TF32_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53, %54, %55;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), 
"+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x128x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x128x8_F32TF32TF32_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66, %67, %68;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ 
-+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x128x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x128x8_F32TF32TF32_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69, %70, %71;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x192x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x192x8_F32TF32TF32_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = 
float[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63, -+ float & d64, float & d65, float & d66, float & d67, -+ float & d68, float & d69, float & d70, float & d71, -+ float & d72, float & d73, float & d74, float & d75, -+ float & d76, float & d77, float & d78, float & d79, -+ float & d80, float & d81, float & d82, float & d83, -+ float & d84, float & d85, float & d86, float & d87, -+ float & d88, float & d89, float & d90, float & d91, -+ float & d92, float & d93, float & d94, float & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ " %96," -+ " %97," -+ " %98, %99, %100;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), -+ "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), -+ "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), -+ "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), -+ "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), -+ "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), -+ "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), -+ "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), -+ "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ 
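The accumulator fragment sizes follow directly from the tile shape: a 64×N tile of 32-bit accumulators spread over the 128 threads of a warpgroup needs 64·N/128 = N/2 registers per thread, hence `float[96]` here for N = 192 (and `float[128]` for the N = 256 atoms below). A small compile-time check of that arithmetic, added for illustration only:

```cpp
// Illustration only: per-thread accumulator count for a 64xN warpgroup MMA tile.
#include <cstddef>

constexpr std::size_t kWarpgroupThreads = 128;   // four warps cooperate on one wgmma

constexpr std::size_t acc_regs_per_thread(std::size_t m, std::size_t n) {
  return m * n / kWarpgroupThreads;              // one 32-bit register per accumulator element
}

static_assert(acc_regs_per_thread(64,   8) ==   4, "CRegisters = float[4]  / uint32_t[4]");
static_assert(acc_regs_per_thread(64, 128) ==  64, "CRegisters = float[64] / uint32_t[64]");
static_assert(acc_regs_per_thread(64, 192) ==  96, "CRegisters = float[96] / uint32_t[96]");
static_assert(acc_regs_per_thread(64, 256) == 128, "CRegisters = float[128]/ uint32_t[128]");
```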
-+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x192x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x192x8_F32TF32TF32_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63, -+ float & d64, float & d65, float & d66, float & d67, -+ float & d68, float & d69, float & d70, float & d71, -+ float & d72, float & d73, float & d74, float & d75, -+ float & d76, float & d77, float & d78, float & d79, -+ float & d80, float & d81, float & d82, float & d83, -+ float & d84, float & d85, float & d86, float & d87, -+ float & d88, float & d89, float & d90, float & d91, -+ float & d92, float & d93, float & d94, float & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ "{%96, %97, %98, %99}," -+ " %100," -+ " %101, %102, %103;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), -+ "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), -+ "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), -+ "+f"(d72), 
"+f"(d73), "+f"(d74), "+f"(d75), -+ "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), -+ "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), -+ "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), -+ "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), -+ "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x256x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x256x8_F32TF32TF32_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d000, float & d001, float & d002, float & d003, -+ float & d004, float & d005, float & d006, float & d007, -+ float & d008, float & d009, float & d010, float & d011, -+ float & d012, float & d013, float & d014, float & d015, -+ float & d016, float & d017, float & d018, float & d019, -+ float & d020, float & d021, float & d022, float & d023, -+ float & d024, float & d025, float & d026, float & d027, -+ float & d028, float & d029, float & d030, float & d031, -+ float & d032, float & d033, float & d034, float & d035, -+ float & d036, float & d037, float & d038, float & d039, -+ float & d040, float & d041, float & d042, float & d043, -+ float & d044, float & d045, float & d046, float & d047, -+ float & d048, float & d049, float & d050, float & d051, -+ float & d052, float & d053, float & d054, float & d055, -+ float & d056, float & d057, float & d058, float & d059, -+ float & d060, float & d061, float & d062, float & d063, -+ float & d064, float & d065, float & d066, float & d067, -+ float & d068, float & d069, float & d070, float & d071, -+ float & d072, float & d073, float & d074, float & d075, -+ float & d076, float & d077, float & d078, float & d079, -+ float & d080, float & d081, float & d082, float & d083, -+ float & d084, float & d085, float & d086, float & d087, -+ float & d088, float & d089, float & d090, float & d091, -+ float & d092, float & d093, float & d094, float & d095, -+ float & d096, float & d097, float & d098, float & d099, -+ float & d100, float & d101, float & d102, float & d103, -+ float & d104, float & d105, float & d106, float & d107, -+ float & d108, float & d109, float & d110, float & d111, -+ float & d112, float & d113, float & d114, float & d115, -+ float & d116, float & d117, float & d118, float & d119, -+ float & d120, float & d121, float & d122, float & d123, -+ float & d124, float & d125, float & d126, float & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, 
%75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ " %128," -+ " %129," -+ " %130, %131, %132;\n" -+ : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), -+ "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), -+ "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), -+ "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), -+ "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), -+ "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), -+ "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), -+ "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), -+ "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), -+ "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), -+ "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), -+ "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), -+ "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), -+ "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), -+ "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), -+ "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), -+ "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), -+ "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), -+ "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), -+ "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), -+ "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), -+ "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), -+ "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), -+ "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), -+ "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), -+ "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), -+ "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), -+ "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), -+ "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), -+ "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), -+ "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), -+ "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x256x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x256x8_F32TF32TF32_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, -+ uint64_t const& desc_b, -+ float & d000, float & d001, float & d002, float & d003, -+ float & d004, float & d005, float & d006, float & d007, -+ float & d008, float & d009, float & d010, float & d011, -+ float & d012, float & d013, float & d014, float & d015, -+ float & d016, float & d017, float & d018, float & d019, -+ float & d020, float & d021, float & d022, float & d023, -+ float & d024, float & d025, float & d026, float & d027, -+ float & d028, float & d029, float & d030, float & d031, -+ float & d032, float & d033, float & d034, float & d035, -+ float & d036, float & d037, float & d038, float & d039, -+ float & d040, float & d041, float & d042, float & d043, -+ 
float & d044, float & d045, float & d046, float & d047, -+ float & d048, float & d049, float & d050, float & d051, -+ float & d052, float & d053, float & d054, float & d055, -+ float & d056, float & d057, float & d058, float & d059, -+ float & d060, float & d061, float & d062, float & d063, -+ float & d064, float & d065, float & d066, float & d067, -+ float & d068, float & d069, float & d070, float & d071, -+ float & d072, float & d073, float & d074, float & d075, -+ float & d076, float & d077, float & d078, float & d079, -+ float & d080, float & d081, float & d082, float & d083, -+ float & d084, float & d085, float & d086, float & d087, -+ float & d088, float & d089, float & d090, float & d091, -+ float & d092, float & d093, float & d094, float & d095, -+ float & d096, float & d097, float & d098, float & d099, -+ float & d100, float & d101, float & d102, float & d103, -+ float & d104, float & d105, float & d106, float & d107, -+ float & d108, float & d109, float & d110, float & d111, -+ float & d112, float & d113, float & d114, float & d115, -+ float & d116, float & d117, float & d118, float & d119, -+ float & d120, float & d121, float & d122, float & d123, -+ float & d124, float & d125, float & d126, float & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ "{%128, %129, %130, %131}," -+ " %132," -+ " %133, %134, %135;\n" -+ : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), -+ "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), -+ "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), -+ "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), -+ "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), -+ "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), -+ "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), -+ "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), -+ "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), -+ "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), -+ "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), -+ "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), -+ "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), -+ "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), -+ "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), -+ "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), -+ "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), -+ "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), -+ "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), -+ "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), -+ "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), -+ "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), -+ "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), -+ "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), -+ "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), -+ "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), -+ "+f"(d104), "+f"(d105), 
"+f"(d106), "+f"(d107), -+ "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), -+ "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), -+ "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), -+ "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), -+ "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) -+ : "r"(a000), "r"(a001), "r"(a002), "r"(a003), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32S8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32S8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32S8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ 
-+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32S8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32S8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32S8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ 
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32S8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32S8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ 
" %34;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32S8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32S8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, 
uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32S8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if 
defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32S8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66;\n" -+ : 
"+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32S8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, 
%69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ " %96," -+ " %97," -+ " %98;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32S8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t 
& d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ " %96," -+ " %97," -+ " %98;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32S8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & 
d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ " %128," -+ " %129," -+ " %130;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), 
"+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32S8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, 
%28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ " %128," -+ " %129," -+ " %130;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32S8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ 
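Unlike the `_SS_TN` forms, the `_RS_TN` atoms read the A fragment from registers: for the int8 k = 32 shapes each thread owns 64·32/128 = 16 A values, carried as four `uint32_t` words with four signed bytes packed per word (`ARegisters = uint32_t[4]`). A hypothetical packing helper is sketched below; the little-endian byte order is an assumption for illustration and is not taken from the patch.

```cpp
// Hypothetical helper: pack four int8 A-fragment values into one uint32_t word.
#include <cstdint>

__host__ __device__ inline uint32_t pack_s8x4(int8_t a0, int8_t a1, int8_t a2, int8_t a3)
{
  return (static_cast<uint32_t>(static_cast<uint8_t>(a0))      ) |
         (static_cast<uint32_t>(static_cast<uint8_t>(a1)) <<  8) |
         (static_cast<uint32_t>(static_cast<uint8_t>(a2)) << 16) |
         (static_cast<uint32_t>(static_cast<uint8_t>(a3)) << 24);
}
```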
-+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32S8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32S8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32S8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN 
S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32S8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32S8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32S8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t 
& d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32S8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct 
SM90_64x96x32_S32S8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32S8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, 
uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32S8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69;\n" -+ : 
"+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32S8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), 
"+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32S8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ "{%96, %97, %98, %99}," -+ " %100," -+ " %101;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), 
"+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32S8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, 
%6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ "{%96, %97, %98, %99}," -+ " %100," -+ " %101;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32S8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t 
& d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ "{%128, %129, %130, %131}," -+ " %132," -+ " %133;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), 
"+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "r"(a000), "r"(a001), "r"(a002), "r"(a003), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32S8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, 
%27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ "{%128, %129, %130, %131}," -+ " %132," -+ " %133;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "r"(a000), "r"(a001), "r"(a002), "r"(a003), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32S8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ 
-+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32S8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32S8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32S8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32S8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ 
fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32S8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32S8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " 
%16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32S8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32S8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t 
& d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32S8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ 
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32S8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ 
-+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32S8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32S8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, 
uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ " %96," -+ " %97," -+ " %98;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8U8_SS_TN 
without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32S8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ " %96," -+ " %97," -+ " %98;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), 
"+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32S8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " 
-+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ " %128," -+ " %129," -+ " %130;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32S8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t 
& d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ " %128," -+ " %129," -+ " %130;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), 
"+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32S8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32S8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32S8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ 
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32S8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32S8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=S8*U8 
-+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32S8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32S8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = 
GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32S8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32S8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, 
%41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32S8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ 
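+
+// Minimal direct-call sketch (illustrative; the identifiers below are assumed
+// and not defined by this header). In CuTe these atoms are normally reached
+// through their MMA_Traits specializations and cute::gemm rather than invoked
+// by hand, and the default scaleD = GMMA::ScaleOut::One accumulates into the
+// existing C values.
+//
+//   uint32_t a_frag[4];      // this thread's A fragment, four s8 values per register
+//   uint64_t desc_b;         // GMMA shared-memory descriptor for B, built elsewhere
+//   uint32_t acc[4] = {};    // N/2 = 4 signed 32-bit accumulators per thread
+//   SM90_64x8x32_S32S8U8_RS_TN<>::fma(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
+//                                     desc_b,
+//                                     acc[0], acc[1], acc[2], acc[3]);
+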
-+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32S8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32S8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ 
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32S8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & 
d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ "{%96, %97, %98, %99}," -+ " %100," -+ " %101;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=S8*U8 -+template< -+ 
GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32S8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ "{%96, %97, %98, %99}," -+ " %100," -+ " %101;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), 
"+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32S8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, 
%15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ "{%128, %129, %130, %131}," -+ " %132," -+ " %133;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "r"(a000), "r"(a001), "r"(a002), "r"(a003), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32S8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, 
uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ "{%128, %129, %130, %131}," -+ " %132," -+ " %133;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), 
-+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "r"(a000), "r"(a001), "r"(a002), "r"(a003), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32U8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32U8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32U8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & 
d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32U8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32U8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32U8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t 
& d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32U8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32U8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t 
& d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32U8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to 
use SM90_64x96x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32U8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32U8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & 
d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32U8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, 
uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32U8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ 
uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ " %96," -+ " %97," -+ " %98;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32U8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, 
uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ " %96," -+ " %97," -+ " %98;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32U8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, 
uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ " %128," -+ " %129," -+ " %130;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), 
"+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32U8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & 
d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ " %128," -+ " %129," -+ " %130;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32U8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t 
const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32U8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32U8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32U8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ 
"wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32U8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32U8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN 
S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32U8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32U8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), 
"+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32U8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32U8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, 
-+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32U8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & 
d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32U8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, 
%46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32U8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, 
%27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ "{%96, %97, %98, %99}," -+ " %100," -+ " %101;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32U8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & 
d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ "{%96, %97, %98, %99}," -+ " %100," -+ " %101;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32U8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & 
d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ "{%128, %129, %130, %131}," -+ " %132," -+ " %133;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), 
"+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "r"(a000), "r"(a001), "r"(a002), "r"(a003), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32U8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ 
uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ "{%128, %129, %130, %131}," -+ " %132," -+ " %133;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "r"(a000), "r"(a001), "r"(a002), "r"(a003), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32U8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ 
uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32U8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32U8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32U8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use 
SM90_64x16x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32U8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32U8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32U8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, 
uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32U8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32U8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static 
void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32U8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ 
" %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32U8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), 
"+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32U8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ 
-+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32U8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ " %96," -+ " %97," -+ " %98;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ 
"+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32U8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ " %96," -+ " %97," -+ " %98;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), 
"+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32U8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & 
d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ " %128," -+ " %129," -+ " %130;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32U8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ 
fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ " %128," -+ " %129," -+ " %130;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), 
"+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32U8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32U8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to 
use SM90_64x8x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32U8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32U8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32U8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), 
"+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32U8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32U8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), 
"+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32U8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32U8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, 
uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32U8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), 
-+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32U8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ 
-+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32U8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32U8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static 
void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ "{%96, %97, %98, %99}," -+ " %100," -+ " %101;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), 
"+r"(d93), "+r"(d94), "+r"(d95) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32U8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ "{%96, %97, %98, %99}," -+ " %100," -+ " %101;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), 
"+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32U8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & 
d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ "{%128, %129, %130, %131}," -+ " %132," -+ " %133;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "r"(a000), "r"(a001), "r"(a002), "r"(a003), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32U8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, -+ 
uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ "{%128, %129, %130, %131}," -+ " %132," -+ " %133;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), 
"+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "r"(a000), "r"(a001), "r"(a002), "r"(a003), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/util.hpp b/3rdparty/cutlass/include/cute/arch/util.hpp -new file mode 100644 -index 0000000..007781f ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/util.hpp -@@ -0,0 +1,178 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+#if (! defined (__clang__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) -+ extern "C" { -+ // This NVVM intrinsic is subject to change in future versions of CUDA. -+ // Clients should not call it directly. -+ CUTE_DEVICE uint32_t __nvvm_get_smem_pointer(void*); -+ } -+#endif -+ -+namespace cute -+{ -+ -+/// CUTE helper to cast SMEM pointer to unsigned -+CUTE_HOST_DEVICE -+uint32_t -+cast_smem_ptr_to_uint(void const* const ptr) -+{ -+// We prefer to use the new CVTA intrinsics if they are available, otherwise we will fall back to -+// the previous internal intrinsics if they are available. -+#if (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 11) -+ // -+ // This NVVM intrinsic converts an address in shared memory to a plain -+ // unsigned integer. This is necessary to pass to shared memory instructions -+ // in inline PTX. -+ // -+ // In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only available in 10.2]. -+ // -+ //__device__ size_t __cvta_generic_to_shared(void* ptr); -+ -+ /// CUTE helper to get SMEM pointer -+ return static_cast(__cvta_generic_to_shared(ptr)); -+ -+#elif (! 
defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) -+ -+ return __nvvm_get_smem_pointer(ptr); -+ -+#elif defined(__CUDA_ARCH__) -+ -+ uint32_t smem_ptr; -+ -+ asm( -+ "{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n" -+ : "=r"(smem_ptr) : "l"(ptr)); -+ -+ return smem_ptr; -+ -+#else -+ -+ -+ (void) ptr; -+ printf("ERROR: cast_smem_ptr_to_uint not supported but used.\n"); -+ return 0; -+ -+#endif -+} -+ -+// -+// Utility for pointer interfaces -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+void -+explode(Fn fn, -+ PtrS&& s, int_sequence, -+ PtrD&& d, int_sequence) -+{ -+ return fn(s[Is]..., d[Id]...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+void -+explode(Fn fn, -+ PtrA&& a, int_sequence, -+ PtrB&& b, int_sequence, -+ PtrC&& c, int_sequence) -+{ -+ return fn(a[Ia]..., b[Ib]..., c[Ic]...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+void -+explode(Fn fn, -+ PtrD&& d, int_sequence, -+ PtrA&& a, int_sequence, -+ PtrB&& b, int_sequence, -+ PtrC&& c, int_sequence) -+{ -+ return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]...); -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+void -+explode(Fn fn, PtrS&& s, PtrD&& d) -+{ -+ return detail::explode(fn, -+ s, make_int_sequence{}, -+ d, make_int_sequence{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+void -+explode(Fn fn, PtrA&& a, PtrB&& b, PtrC&& c) -+{ -+ return detail::explode(fn, -+ a, make_int_sequence{}, -+ b, make_int_sequence{}, -+ c, make_int_sequence{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+void -+explode(Fn fn, PtrD&& d, PtrA&& a, PtrB&& b, PtrC&& c) -+{ -+ return detail::explode(fn, -+ d, make_int_sequence{}, -+ a, make_int_sequence{}, -+ b, make_int_sequence{}, -+ c, make_int_sequence{}); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/atom/copy_atom.hpp b/3rdparty/cutlass/include/cute/atom/copy_atom.hpp -new file mode 100644 -index 0000000..2c5d9c5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/copy_atom.hpp -@@ -0,0 +1,671 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+#include -+#include -+ -+#include -+ -+namespace cute { -+ -+// Generic copy_unpack for any Copy_Traits -+template -+CUTE_HOST_DEVICE constexpr -+void -+copy_unpack(Copy_Traits const&, -+ Tensor const& src, -+ Tensor & dst) -+{ -+ // Specializations can generalize on these checks -+ //static_assert(is_smem::value, "Expected smem for this Copy_Traits"); -+ //static_assert(is_rmem::value, "Expected rmem for this Copy_Traits"); -+ -+ using RegistersSrc = typename Operation::SRegisters; -+ using RegistersDst = typename Operation::DRegisters; -+ using RegTypeSrc = typename std::remove_extent::type; -+ using RegTypeDst = typename std::remove_extent::type; -+ constexpr int RegNumSrc = std::extent::value; -+ constexpr int RegNumDst = std::extent::value; -+ -+ Tensor rS = recast(src); -+ Tensor rD = recast(dst); -+ -+ CUTE_STATIC_ASSERT_V(size(rS) == Int{}, -+ "In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy."); -+ CUTE_STATIC_ASSERT_V(size(rD) == Int{}, -+ "In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this tiled copy."); -+ -+ detail::explode(Operation::copy, -+ rS, make_int_sequence{}, -+ rD, make_int_sequence{}); -+} -+ -+ -+template -+struct Copy_Atom; -+ -+template -+struct Copy_Atom : Copy_Atom, T> -+{}; -+ -+template -+struct Copy_Atom, T> -+ : Copy_Traits -+{ -+ using Traits = Copy_Traits; -+ -+ // Bit and Thr layouts from the Copy_Traits -+ using ThrID = typename Traits::ThrID; -+ using BitLayoutSrc = typename Traits::SrcLayout; -+ using BitLayoutDst = typename Traits::DstLayout; -+ using BitLayoutRef = typename Traits::RefLayout; -+ -+ using ValType = T; -+ -+ using ValLayoutSrc = decltype(upcast::value>(BitLayoutSrc{})); -+ using ValLayoutDst = decltype(upcast::value>(BitLayoutDst{})); -+ using ValLayoutRef = decltype(upcast::value>(BitLayoutRef{})); -+ -+ CUTE_STATIC_ASSERT_V(size<0>(ValLayoutSrc{}) == size(ThrID{}), "CopyOperation is not valid for Src of ValType."); -+ CUTE_STATIC_ASSERT_V(size<0>(ValLayoutDst{}) == size(ThrID{}), "CopyOperation is not valid for Dst of ValType."); -+ CUTE_STATIC_ASSERT_V(size<0>(ValLayoutRef{}) == size(ThrID{}), "CopyOperation is not valid for Ref of ValType."); -+ -+ static constexpr int NumValSrc = size<1>(ValLayoutSrc{}); -+ static constexpr int NumValDst = size<1>(ValLayoutDst{}); -+ -+ // Additional Trait parameters/transformations -+ template -+ CUTE_HOST_DEVICE -+ auto -+ with(TraitsArgs&&... 
args) const { -+ auto traits = Traits::with(std::forward(args)...); -+ return Copy_Atom{traits}; -+ } -+ -+ // Print thread and data layouts for debugging -+ CUTE_HOST_DEVICE static -+ void -+ print_all() -+ { -+ print("ThrID: "); print(ThrID{}); print("\n"); -+ print("BitLayoutSrc: "); print(BitLayoutSrc{}); print("\n"); -+ print("BitLayoutDst: "); print(BitLayoutDst{}); print("\n"); -+ print("BitLayoutRef: "); print(BitLayoutRef{}); print("\n"); -+ print("ValLayoutSrc: "); print(ValLayoutSrc{}); print("\n"); -+ print("ValLayoutDst: "); print(ValLayoutDst{}); print("\n"); -+ print("ValLayoutRef: "); print(ValLayoutRef{}); print("\n"); -+ print("ValueType: %db", sizeof_bits::value); print("\n"); -+ } -+ -+ // -+ // Tensor call interfaces -+ // -+ -+ // Cast, check, and call -+ template -+ CUTE_HOST_DEVICE -+ void -+ call(Tensor const& src, -+ Tensor & dst) const -+ { -+ static_assert(SLayout::rank == 1, "Expected rank-1 src tensor"); -+ static_assert(DLayout::rank == 1, "Expected rank-1 dst tensor"); -+ -+ if constexpr (is_constant::value || is_constant::value) { -+ // Dispatch to unpack for instruction -+ return copy_unpack(*this, src, dst); -+ } else { -+ // Recurse if needed by peeling the tensor mode -+ return copy(*this, tensor<0>(src), tensor<0>(dst)); -+ } -+ } -+ -+ // Accept mutable temporaries -+ template -+ CUTE_HOST_DEVICE -+ void -+ call(Tensor const& src, -+ Tensor && dst) const -+ { -+ return call(src, dst); -+ } -+}; -+ -+// -+// A tiling of copy atoms -+// -+ -+template coord [Need not be 2D...] -+ class ShapeTile_MN> // coord space -+struct TiledCopy : Copy_Atom -+{ -+ // Layout information from the CopyAtom -+ using AtomThrID = typename Copy_Atom::ThrID; // thrid -> thr_idx -+ using AtomLayoutSrc = typename Copy_Atom::ValLayoutSrc; // (thr,val) -> offset -+ using AtomLayoutDst = typename Copy_Atom::ValLayoutDst; // (thr,val) -> offset -+ using AtomLayoutRef = typename Copy_Atom::ValLayoutRef; // (thr,val) -> offset -+ -+ using AtomNumThr = decltype(size<0>(AtomLayoutRef{})); -+ using AtomNumVal = decltype(size<1>(AtomLayoutRef{})); -+ -+ // Layout information for the TiledCopy -+ using Tiler_MN = ShapeTile_MN; -+ using TiledShape_MN = decltype(shape(ShapeTile_MN{})); -+ using TiledLayout_TV = LayoutCopy_TV; -+ using TiledNumThr = decltype(size<0>(TiledLayout_TV{})); -+ using TiledNumVal = decltype(size<1>(TiledLayout_TV{})); -+ -+ CUTE_STATIC_ASSERT_V(TiledNumThr{} % AtomNumThr{} == Int<0>{}, "TiledCopy uses too few thrs for selected CopyAtom"); -+ CUTE_STATIC_ASSERT_V(TiledNumVal{} % AtomNumVal{} == Int<0>{}, "TiledCopy uses too few vals for selected CopyAtom"); -+ -+ // Tile a tensor or a layout from shape -+ // (M,N,...) -+ // to shape -+ // ((ThrV,ThrX),FrgV,(RestM,RestN,...)) -+ // where -+ // ThrV: The threads local to a COPY_ATOM Src. -+ // ThrX: The threads tiled across COPY_ATOMs Src. -+ // FrgV: The values local to a COPY_ATOM Src. -+ // RestM: The values tiled in M. -+ // RestN: The values tiled in N. -+ template -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ tidfrg_S(STensor&& stensor) -+ { -+ return thrfrg(stensor, right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{})); -+ } -+ -+ // Tile a tensor or a layout from shape -+ // (M,N,...) -+ // to shape -+ // ((ThrV,ThrX),FrgV,(RestM,RestN,...)) -+ // where -+ // ThrV: The threads local to a COPY_ATOM Dst. -+ // ThrX: The threads tiled across COPY_ATOMs Dst. -+ // FrgV: The values local to a COPY_ATOM Dst. -+ // RestM: The values tiled in M. -+ // RestN: The values tiled in N. 
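-+ //
-+ // Illustrative example (hypothetical sizes, not taken from this file): tiling a
-+ // (256,128) tensor with a (64,32) Tiler_MN leaves RestM = 256/64 = 4 and
-+ // RestN = 128/32 = 4, so the partitioned shape is ((ThrV,ThrX),FrgV,(4,4)),
-+ // with (ThrV,ThrX) indexing the threads within and across atoms and FrgV the
-+ // per-atom values, exactly as described above.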
-+ template -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ tidfrg_D(DTensor&& dtensor) -+ { -+ return thrfrg(dtensor, right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{})); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ thrfrg(Tensor&& tensor, Ref2TrgLayout const& ref2trg) -+ { -+ constexpr int R = remove_cvref_t::rank; -+ static_assert(R >= rank_v, "Rank of tensor to be partitioned too small."); -+ // Generalize the dimension checks for arbitrary rank -+ //CUTE_STATIC_ASSERT_V(size<0>(stensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); -+ //CUTE_STATIC_ASSERT_V(size<1>(stensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); -+ -+ // Take the thrs/vals that the atom is interested in -+ // NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID -+ auto atom_layout_TV = zipped_divide(TiledLayout_TV{}, make_shape(AtomNumThr{}, AtomNumVal{})); -+ // ((atom_tid,atom_val),(rest_tid,rest_val)) -> (m,n) -+ -+ // Transform to the trg layout -+ auto trg_layout_TV = atom_layout_TV.compose(ref2trg, _); -+ // ((trg_tid,trg_val),(rest_tid,rest_val)) -> (m,n) -+ -+ // Transform the thrs mode from thrid to thr_idx -+ // NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID -+ auto thrval2mn = coalesce(zip(trg_layout_TV), Shape<_1,Shape<_1,_1>>{}); -+ // ((trg_tid,rest_tid),(trg_val,rest_val)) -> (m,n) -+ -+ /// ================== -+ -+ // Tile the tensor for TiledLayout -+ auto t_tensor = zipped_divide(tensor, Tiler_MN{}); -+ // ((TileM,TileN,...),(RestM,RestN,...)) -+ -+ // Transform the tile mode -+ auto tv_tensor = t_tensor.compose(thrval2mn, _); -+ // ((thrid,val),(RM,RN,...)) -+ -+ // Unfold and return -+ return tv_tensor(make_coord(_,_), _); -+ } -+ -+ // retile_S and retile_D assume they are working with the reference layout -- they are the same -+ template -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ retile(Tensor&& tensor) -+ { -+ constexpr int R = remove_cvref_t::rank; -+ // Assert that AtomLayoutSrc|Dst is identity so we can skip the Ref transformation -+ -+ // Assume the first size<0>(tensor) elements are the first val_ids in TiledLayout_TV. 
-+ // Then, we only need the shape+layout of those size<0>(tensor) elements in TiledLayout_TV -+ // and that shape is what we gather from the other modes of tensor -+ -+ auto V = size<0>(tensor); -+ -+ auto frg_layout_mn = upcast(right_inverse(TiledLayout_TV{}).with_shape(TiledShape_MN{})); -+ // (m,n) -> v_idx -- The shape and order of the V inside of TiledLayout_TV -+ -+ auto frg_layout_v = zipped_divide(logical_product(make_layout(V), right_inverse(frg_layout_mn)), make_layout(AtomNumVal{})); -+ // (atom_vals,rest_vals) -> (v,m,n) -+ -+ /// ======= -+ -+ // Tile the tensor for TileFrg -+ auto t_tensor = zipped_divide(tensor, prepend(product_each(shape(frg_layout_mn)), V)); -+ // ((TileV,TileM,TileN,...),(1,RestM,RestN,...)) -+ -+ // Transform the tile mode -+ auto v_tensor = t_tensor.compose(frg_layout_v, _); -+ // ((atom_vals,rest_vals),(1,RM,RN,...)) -+ -+ // Unfold and return -+ return v_tensor(_, append(Int<0>{},_)); -+ } -+ -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ get_layoutS_MN() -+ { -+ // (M,N) -> (M,N) -+ auto ref_S = make_layout(TiledShape_MN{}); -+ // (thr_idx,val_idx) -> (M,N) -+ auto layoutS_TV = tidfrg_S(ref_S); -+ // (M,K) -> (thr_idx,val_idx) -+ auto layoutS_MK = right_inverse(layoutS_TV).with_shape(shape(ref_S)); -+ -+ // athrid = (v,m,k) -> thr_idx -+ auto thrID_S = make_layout(size<0>(TiledLayout_TV{})); -+ -+ return cute::make_tuple(layoutS_MK, thrID_S); -+ } -+ -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ get_layoutS_TV() -+ { -+ // (M,N) -> (M,N) -+ auto ref_S = make_layout(TiledShape_MN{}); -+ // (thr_idx,val_idx) -> (M,N) -+ return tidfrg_S(ref_S)(_,_,Int<0>{}); -+ } -+ -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ get_layoutD_MN() -+ { -+ // (M,N) -> (M,N) -+ auto ref_D = make_layout(TiledShape_MN{}); -+ // (thr_idx,val_idx) -> (M,N) -+ auto layoutD_TV = tidfrg_D(ref_D); -+ // (M,K) -> (thr_idx,val_idx) -+ auto layoutD_MK = right_inverse(layoutD_TV).with_shape(shape(ref_D)); -+ -+ // athrid = (v,m,k) -> thr_idx -+ auto thrID_D = make_layout(size<0>(TiledLayout_TV{})); -+ -+ return cute::make_tuple(layoutD_MK, thrID_D); -+ } -+ -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ get_layoutD_TV() -+ { -+ // (M,N) -> (M,N) -+ auto ref_D = make_layout(TiledShape_MN{}); -+ // (thr_idx,val_idx) -> (M,N) -+ return tidfrg_D(ref_D)(_,_,Int<0>{}); -+ } -+ -+ template -+ struct ThrCopy : Copy_Atom -+ { -+ ThrIdx thr_idx_; -+ -+ CUTE_HOST_DEVICE -+ ThrCopy(ThrIdx const& thr_idx) : thr_idx_(thr_idx) {} -+ -+ template -+ CUTE_HOST_DEVICE -+ auto -+ partition_S(STensor&& stensor) { -+ //static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename Copy_Atom::ValType), -+ // "Expected ValType for tiling SrcTensor."); -+ auto thr_tensor = make_tensor(std::forward(stensor).data(), tidfrg_S(stensor.layout())); -+ return thr_tensor(thr_idx_, _, repeat>(_)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE -+ auto -+ partition_D(DTensor&& dtensor) { -+ //static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename Copy_Atom::ValType), -+ // "Expected ValType for tiling DstTensor."); -+ auto thr_tensor = make_tensor(std::forward(dtensor).data(), tidfrg_D(dtensor.layout())); -+ return thr_tensor(thr_idx_, _, repeat>(_)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE static -+ auto -+ retile_S(STensor&& stensor) { -+ static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename Copy_Atom::ValType), -+ "Expected ValType for tiling SrcTensor."); -+ return make_tensor(std::forward(stensor).data(), TiledCopy::retile(stensor.layout())); -+ } -+ -+ 
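-+ // Typical use of the partitioning interface defined above (illustrative sketch;
-+ // the copy atom, layouts, and tensor names are hypothetical, not from this file):
-+ //
-+ //   auto tiled_copy = make_tiled_copy(Copy_Atom<UniversalCopy<float>, float>{},
-+ //                                     Layout<Shape<_16,_8>>{},   // (m,n) -> thr_idx
-+ //                                     Layout<Shape< _1,_4>>{});  // (m,n) -> val_idx
-+ //   auto thr_copy  = tiled_copy.get_slice(threadIdx.x);
-+ //   Tensor tSgS = thr_copy.partition_S(gS);  // partition the source tensor
-+ //   Tensor tDsD = thr_copy.partition_D(sD);  // partition the destination tensor
-+ //   copy(tiled_copy, tSgS, tDsD);
-+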
template -+ CUTE_HOST_DEVICE static -+ auto -+ retile_D(DTensor&& dtensor) { -+ static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename Copy_Atom::ValType), -+ "Expected ValType for tiling DstTensor."); -+ return make_tensor(std::forward(dtensor).data(), TiledCopy::retile(dtensor.layout())); -+ } -+ }; -+ -+ template ::value)> -+ CUTE_HOST_DEVICE static -+ auto -+ get_slice(ThrIdx const& thr_idx) -+ { -+ return ThrCopy(thr_idx); -+ } -+ -+ template ::value)> -+ CUTE_HOST_DEVICE static -+ auto -+ get_thread_slice(ThrIdx const& thr_idx) -+ { -+ return get_slice(thr_idx); -+ } -+}; -+ -+ -+template -+CUTE_HOST_DEVICE -+auto -+make_tiled_copy_impl(Copy_Atom const& atom, -+ LayoutCopy_TV const&, -+ Tile const&) -+{ -+ return TiledCopy, LayoutCopy_TV, Tile>{atom}; -+} -+ -+// -+// These tile the Copy_Atom as a whole -+// -+ -+template -+CUTE_HOST_DEVICE -+auto -+make_tiled_copy_A(Copy_Atom const& copy_atom, -+ TiledMMA const& tiled_mma) -+{ -+ using MNK = typename TiledMMA::TiledShape_MNK; -+ return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), make_shape(size<0>(MNK{}),size<2>(MNK{}))); -+} -+ -+template -+CUTE_HOST_DEVICE -+auto -+make_tiled_copy_B(Copy_Atom const& copy_atom, -+ TiledMMA const& tiled_mma) -+{ -+ using MNK = typename TiledMMA::TiledShape_MNK; -+ return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), make_shape(size<1>(MNK{}),size<2>(MNK{}))); -+} -+ -+template -+CUTE_HOST_DEVICE -+auto -+make_tiled_copy_C(Copy_Atom const& copy_atom, -+ TiledMMA const& tiled_mma) -+{ -+ using MNK = typename TiledMMA::TiledShape_MNK; -+ return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), make_shape(size<0>(MNK{}),size<1>(MNK{}))); -+} -+ -+template > -+CUTE_HOST_DEVICE -+auto -+make_tiled_copy(Copy_Atom const& copy_atom, -+ ThrLayout const& thr_layout = {}, // (m,n) -> thr_idx -+ ValLayout const& val_layout = {}) -+{ -+ constexpr int R = cute::max(rank_v, rank_v); -+ -+ auto thr_layout_mn = append(thr_layout, Layout<_1>{}); -+ auto val_layout_mn = append(val_layout, Layout<_1>{}); -+ -+ // Take the raked_products to compute the Layout_MN -+ auto layout_mn = raked_product(thr_layout_mn, val_layout_mn); -+ auto layout_tv = right_inverse(layout_mn).with_shape(make_shape(size(thr_layout), size(val_layout))); -+ -+ //print("thr_layout: "); print(thr_layout_mn); print("\n"); -+ //print("val_layout: "); print(val_layout_mn); print("\n"); -+ //print("layout_mn : "); print(layout_mn); print("\n"); -+ //print("layout_tv : "); print(layout_tv); print("\n"); -+ -+ return make_tiled_copy_impl(copy_atom, layout_tv, product_each(shape(layout_mn))); -+} -+ -+// Make a TiledCopy out of the copy_atom that matches the Src-Layout of tiled_copy -+template -+CUTE_HOST_DEVICE -+auto -+make_tiled_copy_S(Copy_Atom const& copy_atom, -+ TiledCopy const& tiled_copy) -+{ -+ return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutS_TV(), typename TiledCopy::Tiler_MN{}); -+} -+ -+// Make a TiledCopy out of the copy_atom that matches the Dst-Layout of tiled_copy -+template -+CUTE_HOST_DEVICE -+auto -+make_tiled_copy_D(Copy_Atom const& copy_atom, -+ TiledCopy const& tiled_copy) -+{ -+ return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutD_TV(), typename TiledCopy::Tiler_MN{}); -+} -+ -+// -+// Size -+// -+ -+// The logical size of a TileCopy -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tile_size(TiledCopy const&) -+{ -+ return size(typename TiledCopy::TiledShape_MN{}); -+} -+ -+// The number of threads involved in a TiledCopy -+template 
-+CUTE_HOST_DEVICE constexpr -+auto -+size(TiledCopy const&) -+{ -+ return typename TiledCopy::TiledNumThr{}; -+} -+ -+// -+// Display utilities -+// -+ -+template -+CUTE_HOST_DEVICE -+auto -+print_latex(TiledCopy const& copy) -+{ -+ auto [layoutS_MN, thrID_S] = copy.get_layoutS_MN(); -+ auto [layoutD_MN, thrID_D] = copy.get_layoutD_MN(); -+ -+ print_latex_copy(layoutS_MN, thrID_S, -+ layoutD_MN, thrID_D); -+} -+ -+// MNK Copy Layout to Latex TIKZ -- 8-value color coded by thread -+template -+CUTE_HOST_DEVICE -+void -+print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and tid -> thr_idx -+ LayoutD const& D, ThrIDD const& TD) // (m,n) -> (tid,vid) and tid -> thr_idx -+{ -+ CUTE_STATIC_ASSERT_V(rank(S) == Int<2>{}); -+ CUTE_STATIC_ASSERT_V(rank(D) == Int<2>{}); -+ -+ assert(size<0>(S) == size<0>(D)); -+ assert(size<1>(S) == size<1>(D)); -+ -+ char const* latex_header = -+ "\\documentclass{standalone}\n" -+ "\\usepackage{tikz}\n" -+ "\\usetikzlibrary{external}\n" -+ "\\tikzexternalize\n" -+ "\\begin{document}\n" -+ "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; -+ char const* latex_footer = -+ "\\end{tikzpicture}\n" -+ "\\end{document}\n"; -+ -+ char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", -+ "{rgb,255:red,175;green,255;blue,175}", -+ "{rgb,255:red,255;green,255;blue,175}", -+ "{rgb,255:red,255;green,175;blue,175}", -+ "{rgb,255:red,210;green,210;blue,255}", -+ "{rgb,255:red,210;green,255;blue,210}", -+ "{rgb,255:red,255;green,255;blue,210}", -+ "{rgb,255:red,255;green,210;blue,210}",}; -+ -+ // Header -+ printf("%% LayoutS: "); print(S); printf("\n"); -+ printf("%% ThrIDS : "); print(TS); printf("\n"); -+ printf("%% LayoutD: "); print(D); printf("\n"); -+ printf("%% ThrIDD : "); print(TD); printf("\n\n"); -+ -+ printf(latex_header); -+ -+ // S starting at 0,0 -+ for (int i = 0; i < size<0>(S); ++i) { -+ for (int j = 0; j < size<1>(S); ++j) { -+ int thrid = S(i,j) % size(TS); -+ int val_idx = S(i,j) / size(TS); -+ int thr_idx = TS(thrid); -+ -+ printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", -+ color_map[thr_idx % 8], -+ i, j, -+ thr_idx, val_idx); -+ } -+ } -+ -+ // D starting at 0,size<1>(S)+3 -+ for (int i = 0; i < size<0>(D); ++i) { -+ for (int j = 0; j < size<1>(D); ++j) { -+ int thrid = D(i,j) % size(TD); -+ int val_idx = D(i,j) / size(TD); -+ int thr_idx = TD(thrid); -+ -+ printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", -+ color_map[thr_idx % 8], -+ i, j + size<1>(S) + 3, -+ thr_idx, val_idx); -+ } -+ } -+ -+ // S Labels -+ for (int i = 0, j = -1; i < size<0>(S); ++i) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); -+ } -+ for (int j = 0, i = -1; j < size<1>(S); ++j) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); -+ } -+ // D Labels -+ for (int i = 0, j = size<1>(D); i < size<0>(S); ++i) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, i); -+ } -+ for (int j = 0, i = -1; j < size<1>(D); ++j) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, j); -+ } -+ -+ // Footer -+ printf(latex_footer); -+} -+ -+} // end namespace cute -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+#include -+#include -+#include -+// Config -+#if (__CUDACC_VER_MAJOR__ >= 12) -+# define CUTE_COPY_ATOM_TMA_SM90_ENABLED -+#endif -+ -+#if 
defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) -+#include -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cute/atom/copy_traits.hpp b/3rdparty/cutlass/include/cute/atom/copy_traits.hpp -new file mode 100644 -index 0000000..83cb056 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/copy_traits.hpp -@@ -0,0 +1,76 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+template -+struct Copy_Traits -+{ -+ static_assert(sizeof(CopyOperation) == 0, "Copy_Traits not implemented for this Copy_Operation."); -+}; -+ -+template -+struct Copy_Traits> -+{ -+ // Logical thread id to thread idx (one-thread) -+ using ThrID = Layout<_1>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout::value>>>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout::value>>>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+}; -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (one-thread) -+ using ThrID = Layout<_1>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout, Stride<_0,_0>>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout, Stride<_0,_0>>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+}; -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/atom/copy_traits_sm75.hpp b/3rdparty/cutlass/include/cute/atom/copy_traits_sm75.hpp -new file mode 100644 -index 0000000..13eb166 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/copy_traits_sm75.hpp -@@ -0,0 +1,143 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout,_128>, -+ Stride, _1>>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout, -+ Stride<_32, _1>>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = DstLayout; -+}; -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout,_128>, -+ Stride, _1>>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout>, -+ Stride<_32,Stride< _1,_1024>>>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = DstLayout; -+}; -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout, -+ Stride<_128, _1>>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout>, -+ Stride<_32,Stride< _1,_1024>>>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = DstLayout; -+}; -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout,_128>, -+ Stride, _1>>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout,Shape <_16, _2>>, -+ Stride,Stride< _1,_128>>>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = DstLayout; -+}; -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout,_128>, -+ Stride, _1>>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout,Shape <_16, _2, _2>>, -+ Stride,Stride< _1,_128,_1024>>>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = DstLayout; -+}; -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout, -+ Stride<_128, _1>>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout,Shape <_16, _2, _4>>, -+ Stride,Stride< _1,_128,_1024>>>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = DstLayout; -+}; -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/atom/copy_traits_sm80.hpp b/3rdparty/cutlass/include/cute/atom/copy_traits_sm80.hpp -new file mode 100644 -index 0000000..089d193 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/copy_traits_sm80.hpp -@@ -0,0 +1,98 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+template -+struct Copy_Traits> -+{ -+ // Logical thread id to thread idx (one-thread) -+ using ThrID = Layout<_1>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout::value>>>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout::value>>>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+}; -+ -+template -+struct Copy_Traits> -+{ -+ // Logical thread id to thread idx (one-thread) -+ using ThrID = Layout<_1>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout::value>>>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout::value>>>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Element copy selector -+template -+CUTE_HOST_DEVICE constexpr -+auto -+select_elementwise_copy(SrcTensor const&, DstTensor const&) -+{ -+ using SrcType = typename SrcTensor::value_type; -+ using DstType = typename DstTensor::value_type; -+ -+#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) -+ if constexpr (is_gmem::value && is_smem::value && -+ sizeof(SrcType) == sizeof(DstType) && -+ (sizeof(SrcType) == 4 || sizeof(SrcType) == 8 || sizeof(SrcType) == 16)) -+ { -+ return SM80_CP_ASYNC_CACHEALWAYS{}; -+ } else { -+ return UniversalCopy{}; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+#else -+ return UniversalCopy{}; -+#endif -+} -+ -+} -diff --git a/3rdparty/cutlass/include/cute/atom/copy_traits_sm90.hpp b/3rdparty/cutlass/include/cute/atom/copy_traits_sm90.hpp -new file mode 100644 -index 0000000..8c5e843 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/copy_traits_sm90.hpp -@@ -0,0 +1,132 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = typename Copy_Traits::DstLayout; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = typename Copy_Traits::SrcLayout; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+}; -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = typename Copy_Traits::DstLayout; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = typename Copy_Traits::SrcLayout; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+}; -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = typename Copy_Traits::DstLayout; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = typename Copy_Traits::SrcLayout; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+}; -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = typename Copy_Traits::DstLayout; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = typename Copy_Traits::SrcLayout; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+}; -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = typename 
Copy_Traits::DstLayout; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = typename Copy_Traits::SrcLayout; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+}; -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = typename Copy_Traits::DstLayout; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = typename Copy_Traits::SrcLayout; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+}; -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/atom/copy_traits_sm90_tma.hpp b/3rdparty/cutlass/include/cute/atom/copy_traits_sm90_tma.hpp -new file mode 100644 -index 0000000..18e22bf ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/copy_traits_sm90_tma.hpp -@@ -0,0 +1,795 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+////////////////////////////////////////////////////////////////////////////// -+///////////////////////////// TMA_LOAD /////////////////////////////////////// -+////////////////////////////////////////////////////////////////////////////// -+ -+struct SM90_TMA_LOAD_OP : SM90_TMA_LOAD {}; -+ -+// The executable SM90_TMA_LOAD with tma_desc and tma_mbar -+template -+struct Copy_Traits -+{ -+ using ThrID = Layout<_1>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+ -+ // SM90_TMA_LOAD arguments -+ TmaDescriptor const& tma_desc_; -+ uint64_t& tma_load_mbar_; -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ void -+ copy_unpack_(void const* const dst_ptr, -+ Coord const& src_coord, seq) const -+ { -+#if 0 -+ print("THR (%d,%d,%d) BLK (%d,%d,%d)\n", -+ threadIdx.x, threadIdx.y, threadIdx.z, -+ blockIdx.x, blockIdx.y, blockIdx.z); -+ print(" TMA Coord "); print(src_coord); print("\n"); -+ print(" TMA Shape "); print(make_tuple(uint64_t(tma_desc_.size0_), -+ uint64_t(tma_desc_.size1_), -+ uint64_t(tma_desc_.size2_), -+ uint64_t(tma_desc_.size3_))); print("\n"); -+#endif -+ -+ SM90_TMA_LOAD::copy(&tma_desc_, -+ tma_load_mbar_, -+ dst_ptr, -+ get(src_coord)...); -+ } -+ -+ // This is the copy_unpack dispatch for this Copy_Traits -+ // Src needs to be a gmem tensor with TmaCoordIterator .data() -+ // Dst needs to be a smem tensor -+ template -+ CUTE_HOST_DEVICE friend constexpr -+ void -+ copy_unpack(Copy_Traits const& traits, -+ Tensor const& src, -+ Tensor & dst) -+ { -+ //static_assert(is_gmem::value, "Expected gmem src for SM90_TMA_LOAD"); // TMA spoofed src tensor -+ static_assert(is_smem::value, "Expected smem dst for SM90_TMA_LOAD"); -+ -+ traits.copy_unpack_(dst.data().get(), src.data().coord_, tuple_seq{}); -+ } -+}; -+ -+// The non-executable SM90_TMA_LOAD with tma_desc and no tma_mbar -+// Use .with(tma_mbar) to construct an executable version -+template -+struct Copy_Traits -+{ -+ using ThrID = Layout<_1>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+ -+ // SM90_TMA_LOAD arguments -+ TmaDescriptor tma_desc_; -+ GmemStrides g_stride_; -+ -+ // Return TmaDescriptor/TensorMap -+ CUTE_HOST_DEVICE constexpr -+ TmaDescriptor const* -+ get_tma_descriptor() const { -+ return &tma_desc_; -+ } -+ -+ // Construct an executable SM90_TMA_LOAD with tma_mbar -+ CUTE_HOST_DEVICE constexpr -+ Copy_Traits -+ with(uint64_t& tma_mbar, uint16_t const& multicast_mask = 0) const { -+ // We accept multicast_mask here to keep the API for both atoms consistent -+ // assert(multicast_mask == 0); -+ (void) multicast_mask; -+ return {tma_desc_, tma_mbar}; -+ } -+ -+ // Generate the TMA coord tensor -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_tma_tensor(GShape const& g_shape) const { -+ static_assert(is_congruent::value); -+ constexpr int tma_rank = decltype(cute::min(rank(flatten(g_stride_)), Int<5>{}))::value; -+ return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat(Int<0>{}))), -+ g_shape, -+ g_stride_); -+ } -+ -+ // Don't try to 
execute a copy with SM90_TMA_LOAD before calling .with() -+ template -+ CUTE_HOST_DEVICE friend constexpr void -+ copy_unpack(Copy_Traits const& traits, -+ Tensor const& src, -+ Tensor & dst) = delete; -+}; -+ -+////////////////////////////////////////////////////////////////////////////// -+///////////////////////////// TMA_LOAD_MULTICAST ///////////////////////////// -+////////////////////////////////////////////////////////////////////////////// -+ -+struct SM90_TMA_LOAD_MULTICAST_OP : SM90_TMA_LOAD_MULTICAST {}; -+ -+template -+struct Copy_Traits -+{ -+ using ThrID = Layout<_1>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+ -+ // SM90_TMA_LOAD_MULTICAST arguments -+ TmaDescriptor const& tma_desc_; -+ uint64_t& tma_load_mbar_; -+ uint16_t const& multicast_mask_; -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ void -+ copy_unpack_(void const* const dst_ptr, -+ Coord const& src_coord, seq) const -+ { -+#if 0 -+ print("THR (%d,%d,%d) BLK (%d,%d,%d)\n", -+ threadIdx.x, threadIdx.y, threadIdx.z, -+ blockIdx.x, blockIdx.y, blockIdx.z); -+ print(" TMA Coord "); print(src_coord); print("\n"); -+ print(" TMA Shape "); print(make_tuple(uint64_t(tma_desc_.size0_), -+ uint64_t(tma_desc_.size1_), -+ uint64_t(tma_desc_.size2_), -+ uint64_t(tma_desc_.size3_))); print("\n"); -+#endif -+ -+ SM90_TMA_LOAD_MULTICAST::copy(&tma_desc_, -+ tma_load_mbar_, -+ multicast_mask_, -+ dst_ptr, -+ get(src_coord)...); -+ } -+ -+ template -+ CUTE_HOST_DEVICE friend constexpr -+ void -+ copy_unpack(Copy_Traits const& traits, -+ Tensor const& src, -+ Tensor & dst) -+ { -+ //static_assert(is_gmem::value, "Expected gmem src for SM90_TMA_LOAD"); // TMA spoofed src tensor -+ static_assert(is_smem::value, "Expected smem dst for SM90_TMA_LOAD_MULTICAST"); -+ -+ traits.copy_unpack_(dst.data().get(), src.data().coord_, tuple_seq{}); -+ } -+}; -+ -+template -+struct Copy_Traits -+{ -+ using ThrID = Layout<_1>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+ -+ // SM90_TMA_LOAD_MULTICAST arguments -+ TmaDescriptor tma_desc_; -+ GmemStrides g_stride_; -+ -+ // Return TmaDescriptor/TensorMap -+ CUTE_HOST_DEVICE constexpr -+ TmaDescriptor const* -+ get_tma_descriptor() const { -+ return &tma_desc_; -+ } -+ -+ // Construct an executable SM90_TMA_LOAD_MULTICAST with tma_mbar -+ CUTE_HOST_DEVICE constexpr -+ Copy_Traits -+ with(uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const { -+ return {tma_desc_, tma_load_mbar, multicast_mask}; -+ } -+ -+ // Generate the TMA coord tensor -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_tma_tensor(GShape const& g_shape) const { -+ static_assert(is_congruent::value); -+ constexpr int tma_rank = decltype(cute::min(rank(flatten(g_stride_)), Int<5>{}))::value; -+ return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat(Int<0>{}))), -+ g_shape, -+ g_stride_); -+ } -+ -+ // Don't try to execute a copy with SM90_TMA_LOAD_MULTICAST before calling .with() -+ template -+ CUTE_HOST_DEVICE friend constexpr void -+ copy_unpack(Copy_Traits const& traits, -+ Tensor const& src, -+ Tensor & dst) = delete; -+}; -+ -+////////////////////////////////////////////////////////////////////////////// -+///////////////////////////// TMA_STORE 
////////////////////////////////////// -+////////////////////////////////////////////////////////////////////////////// -+ -+// The executable SM90_TMA_STORE with tma_desc -+template -+struct Copy_Traits -+{ -+ using ThrID = Layout<_1>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+ -+ // SM90_TMA_STORE arguments -+ TmaDescriptor tma_desc_; -+ GmemStrides g_stride_; -+ -+ // Generate the TMA coord tensor -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_tma_tensor(GShape const& g_shape) const { -+ static_assert(is_congruent::value); -+ constexpr int tma_rank = decltype(cute::min(rank(flatten(g_stride_)), Int<5>{}))::value; -+ return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat(Int<0>{}))), -+ g_shape, -+ g_stride_); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ void -+ copy_unpack_(void const* const src_ptr, -+ Coord const& dst_coord, seq) const -+ { -+#if 0 -+ print("THR (%d,%d,%d) BLK (%d,%d,%d)\n", -+ threadIdx.x, threadIdx.y, threadIdx.z, -+ blockIdx.x, blockIdx.y, blockIdx.z); -+ print(" TMA Coord "); print(dst_coord); print("\n"); -+ print(" TMA Shape "); print(make_tuple(uint64_t(tma_desc_.size0_), -+ uint64_t(tma_desc_.size1_), -+ uint64_t(tma_desc_.size2_), -+ uint64_t(tma_desc_.size3_))); print("\n"); -+#endif -+ -+ SM90_TMA_STORE::copy(&tma_desc_, -+ src_ptr, -+ get(dst_coord)...); -+ } -+ -+ // This is the copy_unpack dispatch for this Copy_Traits -+ // Src needs to be a smem tensor -+ // Dst needs to be a gmem tensor with TmaCoordIterator .data() -+ template -+ CUTE_HOST_DEVICE friend constexpr -+ void -+ copy_unpack(Copy_Traits const& traits, -+ Tensor const& src, -+ Tensor & dst) -+ { -+ static_assert(is_smem::value, "Expected smem src for SM90_TMA_STORE"); -+ //static_assert(is_gmem::value, "Expected gmem dst for SM90_TMA_STORE"); // TMA spoofed src tensor -+ -+ traits.copy_unpack_(src.data().get(), dst.data().coord_, tuple_seq{}); -+ } -+}; -+ -+// -+// MAKE_TMA_COPY and related -+// -+ -+template -+TMA::SmemSwizzleBits -+get_tma_swizzle_bits(ComposedLayout,Offset,SLayout>) -+{ -+ static_assert(M == 4, "Expected 128b=16B=(2^4)B base swizzle."); -+ static_assert(S == 3, "Unsupported layout swizzle"); -+ -+ switch (B) { -+ default: static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3. Unsupported layout swizzle."); -+ case 3: return TMA::SmemSwizzleBits::B128; -+ case 2: return TMA::SmemSwizzleBits::B64; -+ case 1: return TMA::SmemSwizzleBits::B32; -+ case 0: return TMA::SmemSwizzleBits::DISABLE; -+ } -+} -+ -+template -+TMA::SmemSwizzleBits -+get_tma_swizzle_bits(Layout) -+{ -+ return TMA::SmemSwizzleBits::DISABLE; -+} -+ -+template -+auto -+get_nonswizzle_layout(ComposedLayout,Offset,SLayout> const& slayout) -+{ -+ return slayout.layout_fn(); -+} -+ -+template -+auto -+get_nonswizzle_layout(Layout const& slayout) -+{ -+ return slayout; -+} -+ -+/** Make a CuTe CTA-collective TiledCopy for a TMA operation. -+ * -+ * @param CopyOp The target copy operation: SM90_TMA_LOAD, SM90_TMA_LOAD_MULTICAST, SM90_TMA_STORE -+ * @param gtensor The GMEM Tensor to be involved in the TMA. -+ * @param slayout The SMEM Layout to be involved in the TMA. -+ * @param cta_tile The CTA-local tile that each CTA will be tiling GMEM with. 
-+ * This is often the blk_shape that is used to tile the GMEM for CTAs: -+ * local_tile(gtensor, blk_shape, blk_coord) -> CTA-local tile of gtensor -+ * @param cluster_size When using SM90_TMA_LOAD_MULTICAST, this can be a (static) power-of-2 <= 16 -+ * defining the multicast size (used to further partition the SMEM) -+ * Else, static-1 -+ * -+ * This code attempts to maximize the TMA box size. It does this by tracing -+ * the SMEM "vector" -- the inverse of the smem layout -- to find the largest -+ * contiguous array of smem that can be written to/from global memory given -+ * the constraints that the TMA instruction imposes. -+ * -+ * This is accomplished by assigning "basis" strides to the GMEM to track which -+ * modes of SMEM map to which modes of GMEM, then reorder the modes of GMEM according -+ * to the SMEM vector, and then using those GMEM/SMEM modes to fill in the desc. -+ * -+ * Examples: -+ using T = float; -+ T* gptr = nullptr; -+ -+ { -+ // Simple 2D -+ Tensor gtensor = make_tensor(gptr, make_shape(1024, 256), GenRowMajor{}); // K-Major GMEM -+ auto slayout = make_layout(make_shape(_64{}, _32{}), GenRowMajor{}); // K-Major SMEM -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); -+ } -+ -+ { -+ // GMMA 2D -+ Tensor gtensor = make_tensor(gptr, make_shape(1024, 256)); // MN-Major GMEM -+ auto slayout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, make_shape(_128{},_64{})); // MN-Major Swizzled+Tiled 128x64 SMEM -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); -+ } -+ -+ { -+ // 3D -+ Tensor gtensor = make_tensor(gptr, make_shape(1024, 32, 512), make_stride(64, Int<1>{}, 65536)); // GMEM -+ auto slayout = make_layout(make_shape(_16{}, _8{}, _2{}), make_stride(_16{}, _1{}, _8{})); // SMEM w/ same major-mode -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); -+ } -+ -+ { -+ // cuTENSOR 4D -+ auto layout = make_shape(make_shape(32,40),make_shape(make_shape(8,8),656)); // GMEM -+ auto cta_tile = make_shape(_128{},make_shape(_32{},_2{})); // GMEM Tiling: -+ // Take 128-elem from m: m0 must divide 128, -+ // m-last may be predicated -+ // Take 32-elem from k0, 2-elem from k1 -+ auto slayout = make_layout(cta_tile); // Col-Major SMEM -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout, cta_tile, Int<1>{}); -+ } -+ * -+ * Check the TMA box size and desc: -+ print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ print("TMA desc : "); print(tma.tma_desc_); print("\n"); -+ * -+ * Usage: -+ Tensor mA = tma_a.get_tma_tensor(make_shape(M,N)); // (M,N) TMA coord tensor -+ Tensor gA = local_tile(mA, cta_tile, cta_coord); // (BLK_M,BLK_N) TMA coord tensor for this CTA -+ Tensor sA = make_tensor(make_smem_ptr(sptr), slayout); // (BLK_M,BLK_N) SMEM tensor -+ -+ auto cta_tma = tma.get_slice(cta_idx_in_cluster); // Slice for multicast partitioning -+ Tensor tAgA = cta_tma.partition_S(gA); // Partition for src -+ Tensor tAsA = cta_tma.partition_D(sA); // Partition for dst -+ -+ copy(tma.with(barrier, mcast_mask), tAgA, tAsA); // copy with supporting TMA params -+ */ -+template -+CUTE_HOST -+auto -+make_tma_copy(CopyOp, -+ Tensor const& gtensor, -+ SLayout const& slayout, -+ CTA_Tile const& cta_tile, -+ Cluster_Size const& cluster_size) -+{ -+ static_assert((std::is_same::value && is_constant<1, Cluster_Size>::value) || -+ (std::is_same::value) || -+ (std::is_same::value && is_constant<1, Cluster_Size>::value)); -+ -+ using T = typename Tensor::value_type; -+ -+ // -+ // TMA parameter checking -+ // -+ -+ auto flat_glayout 
= flatten(gtensor.layout()); -+ -+ CUTE_STATIC_ASSERT_V(rank(flatten(cta_tile)) <= Int<5>{}, -+ "CTA_Tile cannot have more than five modes, TMA arch restriction."); -+ CUTE_STATIC_ASSERT_V(rank(flat_glayout) <= Int<5>{} || rank(flatten(cta_tile)) <= Int<4>{}, -+ "If GTensor has more than five modes, then CTA_Tile cannot have more than four modes. TMA multimode."); -+ CUTE_STATIC_ASSERT_V(compatible(product_each(shape(slayout)), shape(cta_tile)), -+ "CTA_Tile must be compatible with SLayout."); -+ CUTE_STATIC_ASSERT_V(is_integral{} && has_single_bit(cluster_size) && cluster_size <= Int<16>{}, -+ "Expecting a pow2 integral Cluster_Size leq 16."); -+ CUTE_STATIC_ASSERT_V(size(slayout) % cluster_size == Int<0>{}, -+ "ClusterShape must divide domain size of slayout."); -+ -+ // -+ // TMA slayout manipulation -+ // -+ -+ auto tma_multimode = rank(flat_glayout) > Int<5>{}; -+ -+ // Invert the smem to get the largest contiguous vector in the smem layout -+ auto inv_smem_layout = right_inverse(get_nonswizzle_layout(slayout)); -+ // trunc_smem_idx -> trunc_smem_coord -+ -+ // Map from smem idx to a gmem mode -+ auto sidx_to_gmode = flatten(composition(make_identity_layout(cta_tile), inv_smem_layout)); -+ -+ // Truncate any incompatibilities -+ auto smem_rank = find_if(stride(sidx_to_gmode), [](auto e){ -+ [[maybe_unused]] auto v = basis_value(e); -+ return not is_constant<1,decltype(v)>{}; -+ }); -+ static_assert(smem_rank > 0, "Could not find a common smem-gmem vectorization for TMA."); -+ constexpr int smem_tma_rank = cute::min(int(smem_rank), (tma_multimode ? 4 : 5)); -+ -+ // Keep only the static-1 basis modes into gmem -+ auto sidx_to_gmode_cluster_trunc = take<0,smem_tma_rank>(sidx_to_gmode); -+ // Keep only the portion each multicast CTA will be responsible for -+ auto sidx_to_gmode_cta_trunc = composition(sidx_to_gmode_cluster_trunc, shape_div(size(sidx_to_gmode_cluster_trunc), cluster_size)); -+ -+ // -+ // TMA gtensor manipulation -+ // -+ -+ // Generate a TupleBasis for the gtensor -+ auto flat_gbasis = make_basis_like(shape(flat_glayout)); -+ -+ // Fold the flat_gbasis into the glayout -+ auto glayout_basis = make_layout(shape(gtensor), -+ stride(composition(make_layout(repeat_like(shape(flat_glayout), Int<2>{}), flat_gbasis), -+ make_layout(repeat_like(shape(gtensor), Int<2>{}))))); -+ -+ // Tile the modes of gtensor with cta_tile -+ auto cta_glayout_basis = composition(glayout_basis, cta_tile); -+ -+ // Check that the cta_tile selects modes from gtensor properly -+ for_each(flatten(stride(cta_glayout_basis)), [](auto d) { -+ static_assert(is_constant<1, decltype(d.value())>::value, -+ "CTA_Tile does not faithfully partition the GMEM, it should select the number of elements from each mode of glayout."); -+ }); -+ -+ // Tile the modes of gtensor again with the truncated cta_tile o inv_smem_layout -+ auto tma_layout_cta_trunc = flatten(composition(glayout_basis, sidx_to_gmode_cta_trunc)); -+ -+ // Append any missing basis on the end as size-1 modes b/c they got truncated -+ auto missing_basis = fold(stride(tma_layout_cta_trunc), flat_gbasis, [](auto init, auto e){ -+ auto k = find(init, e); -+ return remove(init); -+ }); -+ -+ // The appended map from truncated smem codomain to gmem mode: trunc_smem_idx -> gmem_mode -+ auto tma_layout_cta = flatten(make_layout(tma_layout_cta_trunc, -+ make_layout(repeat(Int<1>{}), missing_basis))); -+ -+#if 0 -+ print("g_layout : "); print(gtensor.layout()); print("\n"); -+ print("s_layout : "); print(slayout); print("\n"); -+ print("cta_tile : "); 
print(cta_tile); print("\n"); -+ print("cluster_size : "); print(cluster_size); print("\n"); -+ print("flat_gbasis : "); print(flat_gbasis); print("\n"); -+ print("cta_glayout : "); print(cta_glayout_basis); print("\n"); -+ print("inv_smem : "); print(inv_smem_layout); print("\n"); -+ print("sidx_to_gmode : "); print(sidx_to_gmode); print("\n"); -+ print("missing_b : "); print(missing_basis); print("\n"); -+ print("tma_layout_cta: "); print(tma_layout_cta); print("\n"); -+#endif -+ -+ // -+ // TMA gmem desc info -+ // -+ -+ constexpr int TmaRANK = cute::min(rank(flat_glayout), 5); -+ void* gmem_address = (void*) gtensor.data(); -+ -+ cute::array gmem_prob_shape = {1,1,1,1,1}; -+ cute::array gmem_prob_stride = {0,0,0,0,0}; -+ for_each(make_seq{}, [&](auto i) { -+ // NOTE : WAR g++-7.3.5, let it deduce e rather than fuse with below -+ auto e = stride(tma_layout_cta); -+ constexpr int j = decltype(e.mode())::value; -+ constexpr int tma_i = i < 5 ? i : 4; -+ -+ // Problem stride -+ uint64_t stride_j = stride(flat_glayout) * sizeof(T); -+ uint64_t old_stride = gmem_prob_stride[tma_i]; -+ gmem_prob_stride[tma_i] = gcd(gmem_prob_stride[tma_i], stride_j); -+ -+ // Problem shape -+ uint64_t shape_j = shape(flat_glayout); -+ if (gmem_prob_stride[tma_i] != 0) { -+ // We're "resetting" this TMA mode and using it as a "multimode" -+ // Recurrence: g_shape = (s_i - 1) * (d_i / gcd_j d_j) + 1 -+ gmem_prob_shape[tma_i] = (gmem_prob_shape[tma_i]-1) * (old_stride / gmem_prob_stride[tma_i]) -+ + (shape_j-1) * (stride_j / gmem_prob_stride[tma_i]) -+ + 1; -+ } else { -+ gmem_prob_shape[tma_i] = shape_j; -+ } -+ }); -+ -+ assert((reinterpret_cast(gmem_address) & 0b1111) == 0); // Address must be 16B-aligned -+ -+ assert(gmem_prob_shape[0] >= (uint64_t(1))); // Size must be min 1 -+ assert(gmem_prob_shape[0] <= (uint64_t(1) << 32)); // Size must be max 2^32 -+ assert(gmem_prob_shape[1] >= (uint64_t(1))); // Size must be min 1 -+ assert(gmem_prob_shape[1] <= (uint64_t(1) << 32)); // Size must be max 2^32 -+ assert(gmem_prob_shape[2] >= (uint64_t(1))); // Size must be min 1 -+ assert(gmem_prob_shape[2] <= (uint64_t(1) << 32)); // Size must be max 2^32 -+ assert(gmem_prob_shape[3] >= (uint64_t(1))); // Size must be min 1 -+ assert(gmem_prob_shape[3] <= (uint64_t(1) << 32)); // Size must be max 2^32 -+ assert(gmem_prob_shape[4] >= (uint64_t(1))); // Size must be min 1 -+ assert(gmem_prob_shape[4] <= (uint64_t(1) << 32)); // Size must be max 2^32 -+ -+ assert((gmem_prob_stride[0]) == sizeof(T)); // First stride is implicitly 1 -+ assert((gmem_prob_stride[1]) < (uint64_t(1) << 40)); // Stride must be max 2^40 -+ assert((gmem_prob_stride[1] & 0b1111) == 0); // Stride must be multiple of 16B (128b) -+ assert((gmem_prob_stride[2]) < (uint64_t(1) << 40)); // Stride must be max 2^40 -+ assert((gmem_prob_stride[2] & 0b1111) == 0); // Stride must be multiple of 16B (128b) -+ assert((gmem_prob_stride[3]) < (uint64_t(1) << 40)); // Stride must be max 2^40 -+ assert((gmem_prob_stride[3] & 0b1111) == 0); // Stride must be multiple of 16B (128b) -+ assert((gmem_prob_stride[4]) < (uint64_t(1) << 40)); // Stride must be max 2^40 -+ assert((gmem_prob_stride[4] & 0b1111) == 0); // Stride must be multiple of 16B (128b) -+ -+ // -+ // TMA smem desc info -+ // -+ -+ // TMA smem box size -+ cute::array smem_box_shape = {1,1,1,1,1}; -+ for_each(make_seq{}, [&](auto i) { -+ uint32_t shape_i = shape(tma_layout_cta); -+ constexpr int tma_i = i < 5 ? 
i : 4; -+ if (tma_multimode && tma_i == 4) { -+ // We're "reusing" this TMA mode and using it as a "multimode" -+ smem_box_shape[tma_i] = 1; -+ } else { -+ smem_box_shape[tma_i] = shape_i; -+ } -+ }); -+ -+ // TMA smem mode strides -+ [[maybe_unused]] cute::array smem_box_stride = {1,1,1,1,1}; -+ -+ assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1 -+ assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 -+ assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1 -+ assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 -+ assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1 -+ assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 -+ assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1 -+ assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 -+ -+ assert(smem_box_stride[0] >= (uint32_t(1))); // Stride must be min 1 -+ assert(smem_box_stride[0] <= (uint32_t(8))); // Stride must be max 2^3 -+ assert(smem_box_stride[1] >= (uint32_t(1))); // Stride must be min 1 -+ assert(smem_box_stride[1] <= (uint32_t(8))); // Stride must be max 2^3 -+ assert(smem_box_stride[2] >= (uint32_t(1))); // Stride must be min 1 -+ assert(smem_box_stride[2] <= (uint32_t(8))); // Stride must be max 2^3 -+ assert(smem_box_stride[3] >= (uint32_t(1))); // Stride must be min 1 -+ assert(smem_box_stride[3] <= (uint32_t(8))); // Stride must be max 2^3 -+ assert(smem_box_stride[4] >= (uint32_t(1))); // Stride must be min 1 -+ assert(smem_box_stride[4] <= (uint32_t(8))); // Stride must be max 2^3 -+ -+ // -+ // Construct the descriptor -+ // -+ -+ TmaDescriptor tma_desc = {0}; -+ -+#if (__CUDACC_VER_MAJOR__ >= 12) -+ -+ // -+ // TMA general info -+ // -+ -+ cuuint32_t tma_dim = TmaRANK; -+ CUtensorMapDataType tma_format = TMA::to_CUtensorMapDataType(); -+ CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; -+ CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE; -+ CUtensorMapFloatOOBfill tma_oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; -+ -+ // TMA smem swizzle type -+ CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(get_tma_swizzle_bits(slayout)); -+ -+ CUresult result = cuTensorMapEncodeTiled( -+ &tma_desc, -+ tma_format, -+ tma_dim, -+ gmem_address, -+ gmem_prob_shape.data(), -+ gmem_prob_stride.data() + 1, // gmem_prob_stride[0] implicitly 1 -+ smem_box_shape.data(), -+ smem_box_stride.data(), -+ tma_interleave, -+ smem_swizzle, -+ tma_l2Promotion, -+ tma_oobFill); -+ -+ if (result != CUDA_SUCCESS) { -+ std::cerr << "TMA Desc Addr: " << &tma_desc -+ << "\nformat " << tma_format -+ << "\ndim " << tma_dim -+ << "\ngmem_address " << gmem_address -+ << "\nglobalDim " << gmem_prob_shape -+ << "\nglobalStrides " << gmem_prob_stride -+ << "\nboxDim " << smem_box_shape -+ << "\nelementStrides " << smem_box_stride -+ << "\ninterleave " << tma_interleave -+ << "\nswizzle " << smem_swizzle -+ << "\nl2Promotion " << tma_l2Promotion -+ << "\noobFill " << tma_oobFill << std::endl; -+ std::cerr << "Error: Failed to intialize the TMA descriptor " << result << std::endl; -+ assert(false); -+ } -+#endif // (__CUDACC_VER_MAJOR__ >= 12) -+ -+ // -+ // Construct the Copy_Traits -+ // -+ -+ // Finally, get the inverse permutation of the E bases for the mocked gmem stride -+ auto gmem_stride_bases_flat = transform(make_seq{}, [&](auto i) { -+ auto k = find(stride(tma_layout_cta), E{}); -+ // NOTE: gcc 7.3.5 WAR -- avoid if constexpr -+ int32_t tma_coord_stride = 
int32_t(stride(flat_glayout) * sizeof(T) / (gmem_prob_stride[4] != 0 ? gmem_prob_stride[4] : 16)); -+ return conditional_return(tma_multimode && (k >= Int<4>{}), -+ E<4>{} * tma_coord_stride, // The 4th TMA mode is the multimode, use int32_t coord stride -+ E{}); -+ }); -+ -+ // Give that the profile of gtensor and fold it -+ auto gmem_stride_bases = stride(composition(make_layout(repeat_like(shape(flat_glayout), Int<2>{}), gmem_stride_bases_flat), -+ make_layout(repeat_like(shape(gtensor), Int<2>{})))); -+ -+ constexpr int num_bits = size(sidx_to_gmode_cta_trunc) * sizeof(T) * 8; -+ using Traits = Copy_Traits, decltype(gmem_stride_bases)>; -+ -+#if 0 -+ print("num_bits : "); print(num_bits); print("\n"); -+ print("g_stride_bases: "); print(gmem_stride_bases); print("\n"); -+#endif -+ -+ // -+ // Construct the TiledCopy -+ // -+ -+ // The ThrVal layout for 1 TMA instruction within cta_tile -+ auto layout_tv_1 = composition(inv_smem_layout, make_layout(make_shape(cluster_size, size(sidx_to_gmode_cta_trunc)), GenRowMajor{})); -+ // The ThrVal layout for N TMA instructions within cta_tile -+ auto layout_tv = tile_to_shape(layout_tv_1, make_shape(cluster_size, size(cta_tile)/cluster_size)); -+ -+#if 0 -+ print("layout_tv : "); print(layout_tv); print("\n"); -+#endif -+ -+ return TiledCopy, decltype(layout_tv), decltype(cta_tile)>{tma_desc, gmem_stride_bases}; -+} -+ -+// Explicit defaulting -+template -+CUTE_HOST -+auto -+make_tma_copy(CopyOp const& copy_op, -+ Tensor const& gtensor, -+ SLayout const& slayout) -+{ -+ return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), Int<1>{}); -+} -+ -+template -+CUTE_HOST -+auto -+make_tma_copy(CopyOp const& copy_op, -+ Tensor const& gtensor, -+ SLayout const& slayout, -+ Cluster_Size const& cluster_size) -+{ -+ return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), cluster_size); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/atom/mma_atom.hpp b/3rdparty/cutlass/include/cute/atom/mma_atom.hpp -new file mode 100644 -index 0000000..c3025f5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/mma_atom.hpp -@@ -0,0 +1,1081 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+#include -+#include -+#include -+#include -+ -+namespace cute { -+ -+// Generic mma_unpack for any MMA_Traits -+template -+CUTE_HOST_DEVICE constexpr -+void -+mma_unpack(MMA_Traits const&, -+ Tensor & D, -+ Tensor const& A, -+ Tensor const& B, -+ Tensor const& C) -+{ -+ static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); -+ static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); -+ static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); -+ static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); -+ -+ // Register value types from the MMA_Operation register arrays -+ using RegTypeD = typename std::remove_extent::type; -+ using RegTypeA = typename std::remove_extent::type; -+ using RegTypeB = typename std::remove_extent::type; -+ using RegTypeC = typename std::remove_extent::type; -+ constexpr int RegNumD = std::extent::value; -+ constexpr int RegNumA = std::extent::value; -+ constexpr int RegNumB = std::extent::value; -+ constexpr int RegNumC = std::extent::value; -+ -+ Tensor rA = recast(A); -+ Tensor rB = recast(B); -+ -+ CUTE_STATIC_ASSERT_V(size(rA) == Int{}); -+ CUTE_STATIC_ASSERT_V(size(rB) == Int{}); -+ -+ if constexpr (std::is_same::value) -+ { -+ static_assert(std::is_same::value, "GMMA C and D value_type must match."); -+ static_assert(std::is_same::value, "GMMA C and D layouts must match."); -+ // assert((void*)&C == (void*)&D); -+ -+ Tensor rC = recast(D); // NOTE: D and C are same, so use mutable D -+ -+ //CUTE_STATIC_ASSERT_V(size(rC) == Int{}); -+ -+ detail::explode(Operation::fma, -+ rA, make_int_sequence{}, -+ rB, make_int_sequence{}, -+ rC, make_int_sequence{}); -+ } else -+ { -+ Tensor rD = recast(D); -+ Tensor rC = recast(C); -+ -+ CUTE_STATIC_ASSERT_V(size(rD) == Int{}); -+ CUTE_STATIC_ASSERT_V(size(rC) == Int{}); -+ -+ detail::explode(Operation::fma, -+ rD, make_int_sequence{}, -+ rA, make_int_sequence{}, -+ rB, make_int_sequence{}, -+ rC, make_int_sequence{}); -+ } -+} -+ -+ -+namespace detail { -+ -+template -+struct FrgTypeA_or_Default { using type = typename X::ElementAVal; }; -+template -+struct FrgTypeA_or_Default> { using type = typename X::ElementAFrg; }; -+ -+template -+struct FrgTypeB_or_Default { using type = typename X::ElementBVal; }; -+template -+struct FrgTypeB_or_Default> { using type = typename X::ElementBFrg; }; -+ -+template -+struct FrgTypeC_or_Default { using type = typename X::ElementCVal; }; -+template -+struct FrgTypeC_or_Default> { using type = typename X::ElementCFrg; }; -+ -+} // end namespace detail -+ -+template -+struct MMA_Atom; -+ -+template -+struct MMA_Atom : MMA_Atom> -+{}; -+ -+template -+struct MMA_Atom> -+ : MMA_Traits -+{ -+ using Traits = MMA_Traits; -+ -+ // Element value types from the MMA_Traits -+ using ValTypeD = typename Traits::ElementDVal; -+ using ValTypeA = typename 
Traits::ElementAVal; -+ using ValTypeB = typename Traits::ElementBVal; -+ using ValTypeC = typename Traits::ElementCVal; -+ -+ // Thr-Val layouts from the MMA_Traits -+ using Shape_MNK = typename Traits::Shape_MNK; -+ using ThrID = typename Traits::ThrID; -+ using LayoutC_TV = typename Traits::CLayout; -+ using LayoutA_TV = typename Traits::ALayout; -+ using LayoutB_TV = typename Traits::BLayout; -+ -+ // Fragment value types from the MMA_Traits (optional, defaults to Val type) -+ using FrgTypeD = typename detail::FrgTypeC_or_Default::type; -+ using FrgTypeA = typename detail::FrgTypeA_or_Default::type; -+ using FrgTypeB = typename detail::FrgTypeB_or_Default::type; -+ using FrgTypeC = typename detail::FrgTypeC_or_Default::type; -+ -+ // Additional Trait parameters/transformations -+ template -+ CUTE_HOST_DEVICE -+ auto -+ with(TraitsArgs&&... args) const { -+ auto traits = Traits::with(std::forward(args)...); -+ return MMA_Atom{traits}; -+ } -+ -+ // Print thread and data layouts for debugging -+ CUTE_HOST_DEVICE static -+ void -+ print_all() -+ { -+ print("ThrID: "); print(ThrID{}); print("\n"); -+ print("LayoutA_TV: "); print(LayoutA_TV{}); print("\n"); -+ print("LayoutB_TV: "); print(LayoutB_TV{}); print("\n"); -+ print("LayoutC_TV: "); print(LayoutC_TV{}); print("\n"); -+ } -+ -+ // -+ // Tensor call interfaces -+ // -+ -+ // Cast, check, and call fma -+ template -+ CUTE_HOST_DEVICE constexpr -+ void -+ call(Tensor & D, -+ Tensor const& A, -+ Tensor const& B, -+ Tensor const& C) const -+ { -+ static_assert(DLayout::rank == 1, "Expected rank-1 D tensor"); -+ static_assert(ALayout::rank == 1, "Expected rank-1 A tensor"); -+ static_assert(BLayout::rank == 1, "Expected rank-1 B tensor"); -+ static_assert(CLayout::rank == 1, "Expected rank-1 C tensor"); -+ -+ return mma_unpack(*this, D, A, B, C); -+ } -+ -+ // Three arguments reproduces C -+ template -+ CUTE_HOST_DEVICE constexpr -+ void -+ call(Tensor const& A, -+ Tensor const& B, -+ Tensor & C) const -+ { -+ return call(C, A, B, C); -+ } -+ -+ // -+ // make_fragment_A|B|C -+ // These functions are awkward as they expect already-partitioned tensors -+ // resulting from a previous call to partition_A|B|C -+ // The reasoning is that we can inspect the layout of the partitioned data -+ // and attempt to match it in generated fragment to promote vectorization -+ // when copying from partition to fragment. 
-+ // -+ -+ template -+ CUTE_HOST_DEVICE static constexpr -+ auto -+ make_fragment_C(CTensor&& ctensor) -+ { -+ // Check that this tensor is likely already partitioned -+ CUTE_STATIC_ASSERT_V(rank(ctensor) >= Int<3>{}); // VMN -+ CUTE_STATIC_ASSERT_V(size<0>(ctensor) == size<1>(LayoutC_TV{})); -+ -+ // C is a bit special because we are after accumulators here -+ // The input/output type doesn't have to match the accumulator type -+ //static_assert(std::is_same::value_type>::value, "Expecting ValTypeC type"); -+ -+ // We'll never base the accumulator layout on the input tensor layout, so just return a FrgTypeC tensor -+ return make_tensor(shape(ctensor)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE static constexpr -+ auto -+ make_fragment_A(ATensor&& atensor) -+ { -+ // Check that this tensor is likely already partitioned -+ CUTE_STATIC_ASSERT_V(rank(atensor) >= Int<3>{}); // VMK -+ CUTE_STATIC_ASSERT_V(size<0>(atensor) == size<1>(LayoutA_TV{})); -+ static_assert(std::is_same::value_type>::value, "Expecting ValTypeA type"); -+ -+ if constexpr (has_dereference::value) { -+ return recast(std::forward(atensor)); -+ } else { -+ return make_tensor(make_fragment_like(atensor.layout())); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+ } -+ -+ template -+ CUTE_HOST_DEVICE static constexpr -+ auto -+ make_fragment_B(BTensor&& btensor) -+ { -+ // Check that this tensor is likely already partitioned -+ CUTE_STATIC_ASSERT_V(rank(btensor) >= Int<3>{}); // VNK -+ CUTE_STATIC_ASSERT_V(size<0>(btensor) == size<1>(LayoutB_TV{})); -+ static_assert(std::is_same::value_type>::value, "Expecting ValTypeB type"); -+ -+ if constexpr (has_dereference::value) { -+ return recast(std::forward(btensor)); -+ } else { -+ return make_tensor(make_fragment_like(btensor.layout())); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+ } -+}; -+ -+// -+// A tiling of mma atoms -+// -+ -+template -+struct ThrMMA; -+ -+template >, -+ class ValLayoutMNK = Layout>, -+ class PermutationsMNK = Tile> -+struct TiledMMA : MMA_Atom -+{ -+ static_assert(rank_v == 3, "TiledMMA requires rank-3 AtomLayoutMNK"); -+ static_assert(rank_v == 3, "TiledMMA requires rank-3 ValLayoutMNK"); -+ static_assert(rank_v == 3, "TiledMMA requires rank-3 PermutationsMNK"); -+ -+ using AtomShape_MNK = typename MMA_Atom::Shape_MNK; -+ -+ using AtomLayoutC_TV = typename MMA_Atom::LayoutC_TV; -+ using AtomLayoutA_TV = typename MMA_Atom::LayoutA_TV; -+ using AtomLayoutB_TV = typename MMA_Atom::LayoutB_TV; -+ -+ // ThrV -> thread_idx -+ using AtomThrID = typename MMA_Atom::ThrID; -+ -+ // (M,N,K) -+ using TiledShape_MNK = decltype(make_shape(size<0>(AtomShape_MNK{})*size<0>(AtomLayoutMNK{})*size<0>(ValLayoutMNK{}), -+ size<1>(AtomShape_MNK{})*size<1>(AtomLayoutMNK{})*size<1>(ValLayoutMNK{}), -+ size<2>(AtomShape_MNK{})*size<2>(AtomLayoutMNK{})*size<2>(ValLayoutMNK{}))); -+ -+ // thrid = (ThrV,ThrM,ThrN,ThrK) -> thr_idx -+ using ThrLayoutVMNK = decltype(tiled_product(AtomThrID{}, AtomLayoutMNK{})); -+ -+ // thr_idx -> (ThrV,ThrM,ThrN,ThrK) -+ using TidLayout = decltype(right_inverse(ThrLayoutVMNK{})); -+ -+ // Tile a tensor or a layout from shape -+ // (M,N,...) -+ // to shape -+ // ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN,...))) -+ // where -+ // ThrV: The threads local to an MMA. layout<0>(ThrLayoutVMNK): ThrV -> thread_idx -+ // ThrM: The threads tiled in M. layout<1>(ThrLayoutVMNK): ThrM -> thread_idx -+ // ThrN: The threads tiled in N. layout<2>(ThrLayoutVMNK): ThrN -> thread_idx -+ // FrgV: The values local to an MMA. -+ // RestM: The values tiled in M. -+ // RestN: The values tiled in N. 
-+ template -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ thrfrg_C(CTensor&& ctensor) -+ { -+ CUTE_STATIC_ASSERT_V(rank(ctensor) >= Int<2>{}); -+ CUTE_STATIC_ASSERT_V(size<0>(ctensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); -+ CUTE_STATIC_ASSERT_V(size<1>(ctensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); -+ -+ // Reorder the tensor for the TiledAtom -+ auto t_tile = make_tile(left_inverse(get<0>(PermutationsMNK{})), -+ left_inverse(get<1>(PermutationsMNK{}))); -+ auto t_tensor = logical_divide(ctensor, t_tile); // (PermM,PermN) -+ -+ // Tile the tensor for the Atom -+ auto a_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), -+ make_layout(size<1>(AtomShape_MNK{}))); -+ auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomM,AtomN),(RestM,RestN)) -+ -+ // Transform the Atom mode from (M,K) to (Thr,Val) -+ auto tv_tensor = a_tensor.compose(AtomLayoutC_TV{},_); // ((ThrV,FrgV),(RestM,RestN)) -+ -+ // Tile the tensor for the C-threads -+ auto thr_tile = make_tile(_, -+ make_tile(make_layout(size<1>(ThrLayoutVMNK{})), -+ make_layout(size<2>(ThrLayoutVMNK{})))); -+ auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN))) -+ -+ return thr_tensor; -+ } -+ -+ // Tile from (M,N,...) -+ // to (thr_idx,(FrgV,(RestM,RestN,...))) -+ template -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ tidfrg_C(CTensor&& ctensor) -+ { -+ // Don't need a ctile composition because ThrK is last mode in TidLayout -+ -+ return thrfrg_C(ctensor).compose(TidLayout{}, _); -+ } -+ -+ // Tile a tensor or a layout from shape -+ // (M,K,...) -+ // to shape -+ // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK,...))) -+ // where -+ // ThrV: The threads local to an MMA. layout<0>(ThrLayoutVMNK): ThrV -> thread_idx -+ // ThrM: The threads tiled in M. layout<1>(ThrLayoutVMNK): ThrM -> thread_idx -+ // ThrK: The threads tiled in K. layout<3>(ThrLayoutVMNK): ThrK -> thread_idx -+ // FrgV: The values local to an MMA. -+ // RestM: The values tiled in M. -+ // RestK: The values tiled in K. -+ template -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ thrfrg_A(ATensor&& atensor) -+ { -+ CUTE_STATIC_ASSERT_V(rank(atensor) >= Int<2>{}); -+ CUTE_STATIC_ASSERT_V(size<0>(atensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); -+ CUTE_STATIC_ASSERT_V(size<1>(atensor) % size<2>(TiledShape_MNK{}) == Int<0>{}); -+ -+ // Reorder the tensor for the TiledAtom -+ auto t_tile = make_tile(left_inverse(get<0>(PermutationsMNK{})), -+ left_inverse(get<2>(PermutationsMNK{}))); -+ auto t_tensor = logical_divide(atensor, t_tile); // (PermM,PermK) -+ -+ // Tile the tensor for the Atom -+ auto a_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), -+ make_layout(size<2>(AtomShape_MNK{}))); -+ auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomM,AtomK),(RestM,RestK)) -+ -+ // Transform the Atom mode from (M,K) to (Thr,Val) -+ auto tv_tensor = a_tensor.compose(AtomLayoutA_TV{},_); // ((ThrV,FrgV),(RestM,RestK)) -+ -+ // Tile the tensor for the Thread -+ auto thr_tile = make_tile(_, -+ make_tile(make_layout(size<1>(ThrLayoutVMNK{})), -+ make_layout(size<3>(ThrLayoutVMNK{})))); -+ auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) -+ -+ return thr_tensor; -+ } -+ -+ // Tile from (M,K,...) 
-+ // to (thr_idx,(FrgV,(RestM,RestK,...))) -+ template -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ tidfrg_A(ATensor&& atensor) -+ { -+ auto atile = make_tile(_, -+ make_tile(make_layout(make_shape (size<1>(ThrLayoutVMNK{}), size<2>(ThrLayoutVMNK{})), -+ make_stride( Int<1>{} , Int<0>{} )), -+ _)); -+ // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) -+ -+ return thrfrg_A(atensor).compose(atile, _).compose(TidLayout{}, _); -+ } -+ -+ // Tile a tensor or a layout from shape -+ // (N,K,...) -+ // to shape -+ // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -+ // where -+ // ThrV: The threads local to an MMA. layout<0>(ThrLayoutVMNK): ThrV -> thread_idx -+ // ThrN: The threads tiled in N. layout<2>(ThrLayoutVMNK): ThrN -> thread_idx -+ // ThrK: The threads tiled in K. layout<3>(ThrLayoutVMNK): ThrK -> thread_idx -+ // FrgV: The values local to an MMA. -+ // RestN: The values tiled in N. -+ // RestK: The values tiled in K. -+ template -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ thrfrg_B(BTensor&& btensor) -+ { -+ CUTE_STATIC_ASSERT_V(rank(btensor) >= Int<2>{}); -+ CUTE_STATIC_ASSERT_V(size<0>(btensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); -+ CUTE_STATIC_ASSERT_V(size<1>(btensor) % size<2>(TiledShape_MNK{}) == Int<0>{}); -+ -+ // Reorder the tensor for the TiledAtom -+ auto t_tile = make_tile(left_inverse(get<1>(PermutationsMNK{})), -+ left_inverse(get<2>(PermutationsMNK{}))); -+ auto t_tensor = logical_divide(btensor, t_tile); // (PermN,PermK) -+ -+ // Tile the tensor for the Atom -+ auto a_tile = make_tile(make_layout(size<1>(AtomShape_MNK{})), -+ make_layout(size<2>(AtomShape_MNK{}))); -+ auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomN,AtomK),(RestN,RestK)) -+ -+ // Transform the Atom mode from (M,K) to (Thr,Val) -+ auto tv_tensor = a_tensor.compose(AtomLayoutB_TV{},_); // ((ThrV,FrgV),(RestN,RestK)) -+ -+ // Tile the tensor for the Thread -+ auto thr_tile = make_tile(_, -+ make_tile(make_layout(size<2>(ThrLayoutVMNK{})), -+ make_layout(size<3>(ThrLayoutVMNK{})))); -+ auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK))) -+ -+ return thr_tensor; -+ } -+ -+ // Tile from (N,K,...) 
-+ // to (thr_idx,(FrgV,(RestN,RestK,...))) -+ template -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ tidfrg_B(BTensor&& btensor) -+ { -+ auto btile = make_tile(_, -+ make_tile(make_layout(make_shape (size<1>(ThrLayoutVMNK{}), size<2>(ThrLayoutVMNK{})), -+ make_stride( Int<0>{} , Int<1>{} )), -+ _)); -+ // (ThrV,(ThrN,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) -+ -+ return thrfrg_B(btensor).compose(btile, _).compose(TidLayout{}, _); -+ } -+ -+ template ::value)> -+ CUTE_HOST_DEVICE static constexpr -+ auto -+ get_slice(ThrIdx const& thr_idx) -+ { -+ auto thr_vmnk = ThrLayoutVMNK{}.get_flat_coord(thr_idx); -+ return ThrMMA(thr_vmnk); -+ } -+ -+ template ::value)> -+ CUTE_HOST_DEVICE static constexpr -+ auto -+ get_thread_slice(ThrIdx const& thr_idx) -+ { -+ return get_slice(thr_idx); -+ } -+ -+ // -+ // Utility for printing and visualization -+ // -+ -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ get_layoutC_MN() -+ { -+ // (M,N) -> (M,N) -+ auto ref_C = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<1>(TiledShape_MNK{}))); -+ // (cthrid,val) -> (M,N) -+ auto layoutC_TV = thrfrg_C(ref_C); -+ // (M,N) -> (cthrid,frg) -+ auto layoutC_MN = right_inverse(layoutC_TV).with_shape(shape(ref_C)); -+ -+ // cthrid = (v,m,n) -> thr_idx -+ auto thrID_C = ThrLayoutVMNK{}(_,_,_,Int<0>{}); -+ -+ return cute::make_tuple(layoutC_MN, thrID_C); -+ } -+ -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ get_layoutC_TV() -+ { -+ // (M,N) -> (M,N) -+ auto ref_C = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<1>(TiledShape_MNK{}))); -+ -+ return tidfrg_C(ref_C); -+ } -+ -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ get_layoutA_MK() -+ { -+ // (M,K) -> (M,K) -+ auto ref_A = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<2>(TiledShape_MNK{}))); -+ // (athrid,val) -> (M,K) -+ auto layoutA_TV = thrfrg_A(ref_A); -+ // (M,K) -> (athrid,frg) -+ auto layoutA_MK = right_inverse(layoutA_TV).with_shape(shape(ref_A)); -+ -+ // athrid = (v,m,k) -> thr_idx -+ auto thrID_A = ThrLayoutVMNK{}(_,_,Int<0>{},_); -+ -+ return cute::make_tuple(layoutA_MK, thrID_A); -+ } -+ -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ get_layoutA_TV() -+ { -+ // (M,K) -> (M,K) -+ auto ref_A = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<2>(TiledShape_MNK{}))); -+ -+ return tidfrg_A(ref_A); -+ } -+ -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ get_layoutB_NK() -+ { -+ // (N,K) -> (N,K) -+ auto ref_B = make_layout(make_shape(size<1>(TiledShape_MNK{}), size<2>(TiledShape_MNK{}))); -+ // (bthrid,val) -> (N,K) -+ auto layoutB_TV = thrfrg_B(ref_B); -+ // (N,K) -> (bthrid,frg) -+ auto layoutB_NK = right_inverse(layoutB_TV).with_shape(shape(ref_B)); -+ -+ // bthrid = (v,n,k) -> thr_idx -+ auto thrID_B = ThrLayoutVMNK{}(_,Int<0>{},_,_); -+ -+ return cute::make_tuple(layoutB_NK, thrID_B); -+ } -+ -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ get_layoutB_TV() -+ { -+ // (N,K) -> (N,K) -+ auto ref_B = make_layout(make_shape(size<1>(TiledShape_MNK{}), size<2>(TiledShape_MNK{}))); -+ -+ return tidfrg_B(ref_B); -+ } -+}; -+ -+template -+struct ThrMMA : TiledMMA -+{ -+ // Use ThrVMNK and thrfrg rather than thr_idx and tidfrg -+ // to support swizzled threads partitioning dynamic layouts -+ ThrVMNK thr_vmnk_; -+ -+ CUTE_HOST_DEVICE constexpr -+ ThrMMA(ThrVMNK const& thr_vmnk) : thr_vmnk_(thr_vmnk) {} -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ partition_C(CTensor&& ctensor) const -+ { -+ auto thr_tensor = make_tensor(std::forward(ctensor).data(), thrfrg_C(ctensor.layout())); -+ -+ auto thr_vmn = 
make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<2>(thr_vmnk_))); -+ return thr_tensor(thr_vmn, make_coord(_, repeat(thr_tensor)>(_))); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ partition_A(ATensor&& atensor) const -+ { -+ auto thr_tensor = make_tensor(std::forward(atensor).data(), thrfrg_A(atensor.layout())); -+ -+ auto thr_vmk = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<3>(thr_vmnk_))); -+ return thr_tensor(thr_vmk, make_coord(_, repeat(thr_tensor)>(_))); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ partition_B(BTensor&& btensor) const -+ { -+ auto thr_tensor = make_tensor(std::forward(btensor).data(), thrfrg_B(btensor.layout())); -+ -+ auto thr_vnk = make_coord(get<0>(thr_vmnk_), make_coord(get<2>(thr_vmnk_), get<3>(thr_vmnk_))); -+ return thr_tensor(thr_vnk, make_coord(_, repeat(thr_tensor)>(_))); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ partition_fragment_C(CTensor&& ctensor) const -+ { -+ return make_fragment_C(partition_C(ctensor)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ partition_fragment_A(ATensor&& atensor) const -+ { -+ return make_fragment_A(partition_A(atensor)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ partition_fragment_B(BTensor&& btensor) const -+ { -+ return make_fragment_B(partition_B(btensor)); -+ } -+}; -+ -+// -+// These tile the MMA_Atom as a whole -+// -+ -+template >, -+ class MMAValLayout = Layout>, -+ class Permutations = Tile> -+CUTE_HOST_DEVICE constexpr -+auto -+make_tiled_mma(MMA_Atom const&, -+ MMAThrLayout const& thr_layout = {}, -+ MMAValLayout const& val_layout = {}, -+ Permutations const& permutations = {}) -+{ -+ auto thr_layout_mnk = append<3>(thr_layout, Layout<_1>{}); -+ auto val_layout_mnk = append<3>(val_layout, Layout<_1>{}); -+ auto permutation_mnk = append<3>(permutations, _); -+ -+ return TiledMMA, -+ decltype(thr_layout_mnk), -+ decltype(val_layout_mnk), -+ decltype(permutation_mnk)>{}; -+} -+ -+template >, -+ class MMAValLayout = Layout>, -+ class Permutations = Tile> -+CUTE_HOST_DEVICE constexpr -+auto -+make_tiled_mma(MMA_Op const&, -+ MMAThrLayout const& thr_layout = {}, -+ MMAValLayout const& val_layout = {}, -+ Permutations const& permutations = {}) -+{ -+ // Attempt to wrap in an MMA_Atom<> and forward -+ return make_tiled_mma(MMA_Atom{}, thr_layout, val_layout, permutations); -+} -+ -+// -+// partition_fragment_C -- static context -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+partition_fragment_C(TiledMMA, Shape_MN shapeMN) -+{ -+ constexpr int R = rank_v; -+ static_assert(R >= 2, "Must have at least rank-2"); -+ auto atomMNK = typename TiledMMA::AtomShape_MNK{}; -+ auto thrVMNK = typename TiledMMA::ThrLayoutVMNK{}; -+ -+ auto V = size<1>(typename TiledMMA::AtomLayoutC_TV{}); -+ auto M = shape_div(size<0>(shapeMN), size<0>(atomMNK) * size<1>(thrVMNK)); -+ auto N = shape_div(size<1>(shapeMN), size<1>(atomMNK) * size<2>(thrVMNK)); -+ auto frg_shape = tuple_cat(make_shape(V,M,N), take<2,R>(shapeMN)); -+ -+ return make_tensor::FrgTypeC>(frg_shape); -+} -+ -+// partition_fragment_A and partition_fragment_B often depend on the -+// layout of A and B and/or the thread_idx that is requesting the partition. -+// For these reasons, they should not be used in a static context. -+// See TiledMMA::get_slice(thr_idx).partition_fragment_A(tensorA) instead. 
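For orientation, the pieces above compose as follows inside a kernel: build a TiledMMA, slice it by thread index, partition the operand tensors, then create matching register fragments and run the atom over them. A minimal sketch, assuming sA and sB are shared-memory tiles and gC a global-memory tile supplied by the surrounding kernel, and assuming the SM80 half-precision atom from cute/arch/mma_sm80.hpp is available; the names and the 2x2x1 tiling are illustrative, not prescribed by this header:

    #include <cute/tensor.hpp>
    #include <cute/atom/mma_atom.hpp>

    template <class SA, class SB, class GC>
    __device__ void mma_sketch(SA const& sA, SB const& sB, GC& gC)
    {
      using namespace cute;

      // Tile the 16x8x16 F16 atom 2x2x1 across a 128-thread group (hypothetical choice)
      auto tiled_mma = make_tiled_mma(SM80_16x8x16_F16F16F16F16_TN{},
                                      Layout<Shape<_2,_2,_1>>{});
      auto thr_mma   = tiled_mma.get_slice(threadIdx.x);

      auto tCsA = thr_mma.partition_A(sA);           // this thread's slice of smem A
      auto tCsB = thr_mma.partition_B(sB);           // this thread's slice of smem B
      auto tCgC = thr_mma.partition_C(gC);           // this thread's slice of gmem C

      auto tCrA = thr_mma.partition_fragment_A(sA);  // register fragment shaped like tCsA
      auto tCrB = thr_mma.partition_fragment_B(sB);  // register fragment shaped like tCsB
      auto tCrC = thr_mma.partition_fragment_C(gC);  // accumulator fragment

      copy(tCsA, tCrA);                              // smem -> registers
      copy(tCsB, tCrB);
      clear(tCrC);                                   // zero the accumulators
      gemm(tiled_mma, tCrA, tCrB, tCrC);             // tCrC += tCrA * tCrB
      copy(tCrC, tCgC);                              // registers -> gmem
    }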
-+ -+// -+// Size -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tile_size(TiledMMA const& mma) -+{ -+ return size(typename TiledMMA::TiledShape_MNK{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+size(TiledMMA const& mma) -+{ -+ return size(typename TiledMMA::ThrLayoutVMNK{}); -+} -+ -+// -+// Display utilities -+// -+ -+template -+CUTE_HOST_DEVICE -+auto -+print_latex(TiledMMA const& mma) -+{ -+ auto layout_and_thrid_C = mma.get_layoutC_MN(); -+ auto layoutC_MN = get<0>(layout_and_thrid_C); -+ auto thrID_C = get<1>(layout_and_thrid_C); -+ -+ auto layout_and_thrid_A = mma.get_layoutA_MK(); -+ auto layoutA_MK = get<0>(layout_and_thrid_A); -+ auto thrID_A = get<1>(layout_and_thrid_A); -+ -+ auto layout_and_thrid_B = mma.get_layoutB_NK(); -+ auto layoutB_NK = get<0>(layout_and_thrid_B); -+ auto thrID_B = get<1>(layout_and_thrid_B); -+ -+ print_latex_mma(layoutC_MN, thrID_C, -+ layoutA_MK, thrID_A, -+ layoutB_NK, thrID_B); -+} -+ -+// EXPERIMENTAL -- Doesn't work with Swizzled Thr TileMMAs... -+template -+CUTE_HOST_DEVICE -+auto -+print_latex_2(TiledMMA const& mma) -+{ -+ print_latex_mma(typename TiledMMA::TiledShape_MNK{}, -+ mma.get_layoutC_TV(), -+ mma.get_layoutA_TV(), -+ mma.get_layoutB_TV()); -+} -+ -+// MNK MMA Layout to console printer -- 8-value color coded by thread -+template -+CUTE_HOST_DEVICE -+void -+print_layout_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx -+ LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx -+ LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx -+{ -+ CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); -+ CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); -+ CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); -+ -+ assert(size<0>(A) == size<0>(C)); -+ assert(size<0>(B) == size<1>(C)); -+ assert(size<1>(A) == size<1>(B)); -+ -+ int a_width = size<1>(A) * 6 + 4; -+ -+ // Print out B (white-shifted) k-by-n -+ for (int k = 0; k < size<1>(B); ++k) { -+ // Header -+ printf("%*s", a_width, ""); -+ for (int n = 0; n < size<0>(B); ++n) printf("+-----"); -+ printf("+\n"); -+ // Values -+ printf("%*s", a_width, ""); -+ for (int n = 0; n < size<0>(B); ++n) printf("|T%02dV%1d", int(TB(B(n,k) % size(TB))), int(B(n,k) / size(TB))); -+ printf("|\n"); -+ } -+ // Footer -+ printf("%*s", a_width, ""); -+ for (int n = 0; n < size<0>(B); ++n) printf("+-----"); -+ printf("+\n\n"); -+ -+ // Print out A m-by-k and C m-by-n -+ for (int m = 0; m < size<0>(A); ++m) { -+ // Header -+ for (int k = 0; k < size<1>(A); ++k) printf("+-----"); -+ printf("+ "); -+ for (int n = 0; n < size<1>(C); ++n) printf("+-----"); -+ printf("+\n"); -+ // Values -+ for (int k = 0; k < size<1>(A); ++k) printf("|T%02dV%1d", int(TA(A(m,k) % size(TA))), int(A(m,k) / size(TA))); -+ printf("| "); -+ for (int n = 0; n < size<1>(C); ++n) printf("|T%02dV%1d", int(TC(C(m,n) % size(TC))), int(C(m,n) / size(TC))); -+ printf("|\n"); -+ } -+ // Footer -+ for (int k = 0; k < size<1>(A); ++k) printf("+-----"); -+ printf("+ "); -+ for (int n = 0; n < size<1>(C); ++n) printf("+-----"); -+ printf("+\n"); -+} -+ -+// MNK MMA Layout to Latex TIKZ -- 8-value color coded by thread -+template -+CUTE_HOST_DEVICE -+void -+print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx -+ LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx -+ LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx -+{ -+ CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); -+ CUTE_STATIC_ASSERT_V(rank(A) 
== Int<2>{}); -+ CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); -+ -+ assert(size<0>(A) == size<0>(C)); -+ assert(size<0>(B) == size<1>(C)); -+ assert(size<1>(A) == size<1>(B)); -+ -+ char const* latex_header = -+ "\\documentclass{standalone}\n" -+ "\\usepackage{tikz}\n" -+ "\\usetikzlibrary{external}\n" -+ "\\tikzexternalize\n" -+ "\\begin{document}\n" -+ "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; -+ char const* latex_footer = -+ "\\end{tikzpicture}\n" -+ "\\end{document}\n"; -+ -+ char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", -+ "{rgb,255:red,175;green,255;blue,175}", -+ "{rgb,255:red,255;green,255;blue,175}", -+ "{rgb,255:red,255;green,175;blue,175}", -+ "{rgb,255:red,210;green,210;blue,255}", -+ "{rgb,255:red,210;green,255;blue,210}", -+ "{rgb,255:red,255;green,255;blue,210}", -+ "{rgb,255:red,255;green,210;blue,210}"}; -+ -+ // Header -+ printf("%% LayoutC: "); print(C); printf("\n"); -+ printf("%% ThrIDC : "); print(TC); printf("\n"); -+ printf("%% LayoutA: "); print(A); printf("\n"); -+ printf("%% ThrIDA : "); print(TA); printf("\n"); -+ printf("%% LayoutB: "); print(B); printf("\n"); -+ printf("%% ThrIDB : "); print(TB); printf("\n\n"); -+ -+ printf(latex_header); -+ -+ // C starting at 0,0 -+ for (int m = 0; m < size<0>(C); ++m) { -+ for (int n = 0; n < size<1>(C); ++n) { -+ int thrid = C(m,n) % size(TC); -+ int val_idx = C(m,n) / size(TC); -+ int thr_idx = TC(thrid); -+ -+ printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", -+ color_map[thr_idx % 8], -+ m, n, -+ thr_idx, val_idx); -+ } -+ } -+ -+ // A starting at 0,-size<1>(A)-1 -+ for (int m = 0; m < size<0>(A); ++m) { -+ for (int k = 0; k < size<1>(A); ++k) { -+ int thrid = A(m,k) % size(TA); -+ int val_idx = A(m,k) / size(TA); -+ int thr_idx = TA(thrid); -+ -+ printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", -+ color_map[thr_idx % 8], -+ m, k-1-size<1>(A), -+ thr_idx, val_idx); -+ } -+ } -+ -+ // B starting at -size<1>(B)-1,0 -+ for (int n = 0; n < size<0>(B); ++n) { -+ for (int k = 0; k < size<1>(B); ++k) { -+ int thrid = B(n,k) % size(TB); -+ int val_idx = B(n,k) / size(TB); -+ int thr_idx = TB(thrid); -+ -+ printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", -+ color_map[thr_idx % 8], -+ k-1-size<1>(B), n, -+ thr_idx, val_idx); -+ } -+ } -+ -+ // A labels -+ for (int m = 0, k = -1; m < size<0>(A); ++m) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), m); -+ } -+ for (int k = 0, m = -1; k < size<1>(A); ++k) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), k); -+ } -+ // B labels -+ for (int n = 0, k = -1; n < size<0>(B); ++n) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, n); -+ } -+ for (int k = 0, n = -1; k < size<1>(B); ++k) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, k); -+ } -+ -+ // Footer -+ printf(latex_footer); -+} -+ -+// ThrVal MMA Layout to Latex TIKZ -- 8-value color coded by thread -+template -+CUTE_HOST_DEVICE -+void -+print_latex_mma(Shape_MNK const& shape_mnk, -+ LayoutC const& C, // (thr_idx,vid) -> (m,n) -+ LayoutA const& A, // (thr_idx,vid) -> (m,k) -+ LayoutB const& B) // (thr_idx,vid) -> (n,k) -+{ -+ CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); -+ CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); -+ CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); -+ -+ char const* latex_header = -+ "\\documentclass{standalone}\n" -+ 
"\\usepackage{tikz}\n" -+ "\\usetikzlibrary{external}\n" -+ "\\tikzexternalize\n" -+ "\\begin{document}\n" -+ "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; -+ char const* latex_footer = -+ "\\end{tikzpicture}\n" -+ "\\end{document}\n"; -+ -+ char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", -+ "{rgb,255:red,175;green,255;blue,175}", -+ "{rgb,255:red,255;green,255;blue,175}", -+ "{rgb,255:red,255;green,175;blue,175}", -+ "{rgb,255:red,210;green,210;blue,255}", -+ "{rgb,255:red,210;green,255;blue,210}", -+ "{rgb,255:red,255;green,255;blue,210}", -+ "{rgb,255:red,255;green,210;blue,210}"}; -+ -+ // Header -+ printf("%% Shape_MNK: "); print(shape_mnk); printf("\n"); -+ printf("%% LayoutC : "); print(C); printf("\n"); -+ printf("%% LayoutA : "); print(A); printf("\n"); -+ printf("%% LayoutB : "); print(B); printf("\n\n"); -+ -+ printf(latex_header); -+ -+ int M = size<0>(shape_mnk); -+ int N = size<1>(shape_mnk); -+ int K = size<2>(shape_mnk); -+ -+ // C starting at 0,0 -+ bool c_filled[M][N] = {}; -+ for (int t = 0; t < size<0>(C); ++t) { -+ for (int v = 0; v < size<1>(C); ++v) { -+ int m = C(t,v) % M; -+ int n = C(t,v) / M; -+ -+ if (not c_filled[m][n]) { -+ printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", -+ color_map[t % 8], -+ m, n, -+ t, v); -+ c_filled[m][n] = true; -+ } -+ } -+ } -+ -+ // A starting at 0,-size<1>(A)-1 -+ bool a_filled[M][K] = {}; -+ for (int t = 0; t < size<0>(A); ++t) { -+ for (int v = 0; v < size<1>(A); ++v) { -+ int m = A(t,v) % M; -+ int k = A(t,v) / M; -+ -+ if (not a_filled[m][k]) { -+ printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", -+ color_map[t % 8], -+ m, k - 1 - K, -+ t, v); -+ a_filled[m][k] = true; -+ } -+ } -+ } -+ -+ // B starting at -size<1>(B)-1,0 -+ bool b_filled[N][K] = {}; -+ for (int t = 0; t < size<0>(B); ++t) { -+ for (int v = 0; v < size<1>(B); ++v) { -+ int n = B(t,v) % N; -+ int k = B(t,v) / N; -+ -+ if (not b_filled[n][k]) { -+ printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", -+ color_map[t % 8], -+ k - 1 - K, n, -+ t, v); -+ b_filled[n][k] = true; -+ } -+ } -+ } -+ -+ // A labels -+ for (int m = 0, k = -1; m < M; ++m) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k - 1 - K, m); -+ } -+ for (int k = 0, m = -1; k < K; ++k) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k - 1 - K, k); -+ } -+ // B labels -+ for (int n = 0, k = -1; n < N; ++n) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k - 1 - K, n, n); -+ } -+ for (int k = 0, n = -1; k < K; ++k) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k - 1 - K, n, k); -+ } -+ -+ // Footer -+ printf(latex_footer); -+} -+ -+} // namespace cute -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cute/atom/mma_traits.hpp b/3rdparty/cutlass/include/cute/atom/mma_traits.hpp -new file mode 100644 -index 0000000..a8c3323 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/mma_traits.hpp -@@ -0,0 +1,70 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+template -+struct MMA_Traits -+{ -+ static_assert(sizeof(MMAOperation) == 0, "MMA_Traits not implemented for this MMA_Operation."); -+}; -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = D; -+ using ElementAVal = A; -+ using ElementBVal = B; -+ using ElementCVal = C; -+ -+ // Logical shape of the MMA -+ using Shape_MNK = Shape<_1,_1,_1>; -+ -+ // Logical thread id (tid) -> tidx -+ using ThrID = Layout<_1>; -+ -+ // (Logical thread id (tid), Logical value id (vid)) -> coord -+ -+ // (tid,vid) -> (m,k) -+ using ALayout = Layout>; -+ // (tid,vid) -> (n,k) -+ using BLayout = Layout>; -+ // (tid,vid) -> (m,n) -+ using CLayout = Layout>; -+}; -+ -+} // namespace cute -diff --git a/3rdparty/cutlass/include/cute/atom/mma_traits_sm61.hpp b/3rdparty/cutlass/include/cute/atom/mma_traits_sm61.hpp -new file mode 100644 -index 0000000..85d4e98 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/mma_traits_sm61.hpp -@@ -0,0 +1,73 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+namespace cute -+{ -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using Shape_MNK = Shape<_1,_1,_4>; -+ using ThrID = Layout<_1>; -+ using ALayout = Layout>; -+ using BLayout = Layout>; -+ using CLayout = Layout>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int16_t; -+ using ElementBVal = int16_t; -+ using ElementCVal = int32_t; -+ -+ using Shape_MNK = Shape<_1,_1,_2>; -+ using ThrID = Layout<_1>; -+ using ALayout = Layout>; -+ using BLayout = Layout>; -+ using CLayout = Layout>; -+}; -+ -+} // namespace cute -diff --git a/3rdparty/cutlass/include/cute/atom/mma_traits_sm70.hpp b/3rdparty/cutlass/include/cute/atom/mma_traits_sm70.hpp -new file mode 100644 -index 0000000..7943035 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/mma_traits_sm70.hpp -@@ -0,0 +1,198 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+namespace cute -+{ -+ -+namespace { -+ -+// Logical thread id to thread idx (quadpair) -+using SM70_QuadPair = Layout, -+ Stride<_1,_16>>; -+// (T8,V4) -> (M8,K4) -+using SM70_8x4_Row = Layout, -+ Stride<_1,_8>>; -+// (T8,V4) -> (M8,K4) -+using SM70_8x4_Col = Layout,_4>, -+ Stride,_1>>; -+// (T8,V8) -> (M8,N8) -+using SM70_8x8_16b = Layout, -+ Stride<_1,_8>>; -+// (T8,V8) -> (M8,N8) -+using SM70_8x8_32b = Layout,Shape <_2,_2, _2>>, -+ Stride,Stride<_8,_2,_32>>>; -+ -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using Shape_MNK = Shape<_8,_8,_4>; -+ using ThrID = SM70_QuadPair; -+ using ALayout = SM70_8x4_Row; -+ using BLayout = SM70_8x4_Row; -+ using CLayout = SM70_8x8_16b; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using Shape_MNK = Shape<_8,_8,_4>; -+ using ThrID = SM70_QuadPair; -+ using ALayout = SM70_8x4_Col; -+ using BLayout = SM70_8x4_Col; -+ using CLayout = SM70_8x8_16b; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using Shape_MNK = Shape<_8,_8,_4>; -+ using ThrID = SM70_QuadPair; -+ using ALayout = SM70_8x4_Col; -+ using BLayout = SM70_8x4_Row; -+ using CLayout = SM70_8x8_16b; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using Shape_MNK = Shape<_8,_8,_4>; -+ using ThrID = SM70_QuadPair; -+ using ALayout = SM70_8x4_Row; -+ using BLayout = SM70_8x4_Col; -+ using CLayout = SM70_8x8_16b; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using Shape_MNK = Shape<_8,_8,_4>; -+ using ThrID = SM70_QuadPair; -+ using ALayout = SM70_8x4_Row; -+ using BLayout = SM70_8x4_Row; -+ using CLayout = SM70_8x8_32b; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using 
ElementCVal = float; -+ -+ using Shape_MNK = Shape<_8,_8,_4>; -+ using ThrID = SM70_QuadPair; -+ using ALayout = SM70_8x4_Col; -+ using BLayout = SM70_8x4_Col; -+ using CLayout = SM70_8x8_32b; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using Shape_MNK = Shape<_8,_8,_4>; -+ using ThrID = SM70_QuadPair; -+ using ALayout = SM70_8x4_Col; -+ using BLayout = SM70_8x4_Row; -+ using CLayout = SM70_8x8_32b; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using Shape_MNK = Shape<_8,_8,_4>; -+ using ThrID = SM70_QuadPair; -+ using ALayout = SM70_8x4_Row; -+ using BLayout = SM70_8x4_Col; -+ using CLayout = SM70_8x8_32b; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+} // namespace cute -diff --git a/3rdparty/cutlass/include/cute/atom/mma_traits_sm75.hpp b/3rdparty/cutlass/include/cute/atom/mma_traits_sm75.hpp -new file mode 100644 -index 0000000..405e871 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/mma_traits_sm75.hpp -@@ -0,0 +1,81 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+namespace cute -+{ -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using Shape_MNK = Shape<_16,_8,_8>; -+ using ThrID = Layout<_32>; -+ using ALayout = Layout,Shape < _2,_2>>, -+ Stride,Stride<_16,_1>>>; -+ using BLayout = Layout,_2>, -+ Stride,_8>>; -+ using CLayout = Layout,Shape < _2,_2>>, -+ Stride,Stride<_16,_1>>>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using Shape_MNK = Shape<_8,_8,_16>; -+ using ThrID = Layout<_32>; -+ using ALayout = Layout,_4>, -+ Stride,_8>>; -+ using BLayout = Layout,_4>, -+ Stride,_8>>; -+ using CLayout = Layout,_2>, -+ Stride,_8>>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cute -diff --git a/3rdparty/cutlass/include/cute/atom/mma_traits_sm80.hpp b/3rdparty/cutlass/include/cute/atom/mma_traits_sm80.hpp -new file mode 100644 -index 0000000..6636b7a ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/mma_traits_sm80.hpp -@@ -0,0 +1,446 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+ -+#include -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+namespace { -+ -+// (T32,V1) -> (M8,N8) -+using SM80_8x4 = Layout,_1>, -+ Stride,_0>>; -+// (T32,V2) -> (M8,N8) -+using SM80_8x8_Row = Layout,_2>, -+ Stride,_8>>; -+// (T32,V4) -> (M8,N16) -+using SM80_8x16_Row = Layout,_4>, -+ Stride,_8>>; -+// (T32,V4) -> (M16,N8) -+using SM80_16x8_Row = Layout,Shape < _2,_2>>, -+ Stride,Stride<_16,_8>>>; -+ -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+//////////////////////// fp16 = fp16 * fp16 + fp16 //////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using Shape_MNK = Shape<_16,_8,_8>; -+ using ThrID = Layout<_32>; -+ using ALayout = SM80_16x8_Row; -+ using BLayout = SM80_8x8_Row; -+ using CLayout = SM80_16x8_Row; -+}; -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using Shape_MNK = Shape<_16,_8,_16>; -+ using ThrID = Layout<_32>; -+ using ALayout = Layout,Shape < _2,_2, _2>>, -+ Stride,Stride<_16,_8,_128>>>; -+ using BLayout = Layout,Shape <_2, _2>>, -+ Stride,Stride<_8,_64>>>; -+ using CLayout = SM80_16x8_Row; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+//////////////////////// fp32 = fp16 * fp16 + fp32 //////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+//////////////////////// fp32 = bf16 * bf16 + fp32 //////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+//////////////////////// fp32 = tf32 * tf32 + fp32 //////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = float; -+ using ElementAVal = cutlass::tfloat32_t; -+ using ElementBVal = cutlass::tfloat32_t; -+ using ElementCVal = float; -+ -+ using Shape_MNK = Shape<_16,_8,_4>; -+ using ThrID = Layout<_32>; -+ using ALayout = Layout,_2>, -+ Stride,_8>>; -+ using BLayout = SM80_8x4; -+ using CLayout = SM80_16x8_Row; -+}; -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = float; -+ using 
ElementAVal = cutlass::tfloat32_t; -+ using ElementBVal = cutlass::tfloat32_t; -+ using ElementCVal = float; -+ -+ using Shape_MNK = Shape<_16,_8,_8>; -+ using ThrID = Layout<_32>; -+ using ALayout = Layout,Shape <_2, _2>>, -+ Stride,Stride<_8,_64>>>; -+ using BLayout = Layout, _2>, -+ Stride,_32>>; -+ using CLayout = SM80_16x8_Row; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+//////////////////////// fp64 = fp64 * fp64 + fp64 //////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = double; -+ using ElementAVal = double; -+ using ElementBVal = double; -+ using ElementCVal = double; -+ -+ using Shape_MNK = Shape<_8,_8,_4>; -+ using ThrID = Layout<_32>; -+ using ALayout = SM80_8x4; -+ using BLayout = SM80_8x4; -+ using CLayout = SM80_8x8_Row; -+}; -+ -+// Custom complex fp64 MMA composed of 4 fp64 MMAs -- same layouts -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = complex; -+ using ElementAVal = complex; -+ using ElementBVal = complex; -+ using ElementCVal = complex; -+}; -+ -+// Custom complex fp64 MMA composed of 3 fp64 MMAs -- same layouts -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex; -+ using ElementAVal = complex; -+ using ElementBVal = complex; -+ using ElementCVal = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////// s32 = s8 * s8 + s32 /////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using Shape_MNK = Shape<_8,_8,_16>; -+ using ThrID = Layout<_32>; -+ using ALayout = SM80_8x16_Row; -+ using BLayout = SM80_8x16_Row; -+ using CLayout = SM80_8x8_Row; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using Shape_MNK = Shape<_16,_8,_16>; -+ using ThrID = Layout<_32>; -+ using ALayout = Layout,Shape < _4,_2>>, -+ Stride,Stride<_16,_8>>>; -+ using BLayout = SM80_8x16_Row; -+ using CLayout = SM80_16x8_Row; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using Shape_MNK = Shape<_16,_8,_32>; -+ using ThrID = Layout<_32>; -+ using ALayout = Layout,Shape < _4,_2, _2>>, -+ Stride,Stride<_16,_8,_256>>>; -+ using BLayout = Layout, Shape <_4, _2>>, -+ Stride, Stride<_8,_128>>>; -+ using CLayout = SM80_16x8_Row; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////// s32 = s8 * u8 + s32 /////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+}; -+ 
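The specializations above all follow one pattern: an MMA_Traits bundle records the operand value types, the atom's Shape_MNK, a ThrID layout for the participating threads, and A/B/C layouts that map (thread, value) pairs to MMA coordinates. A minimal host-side sketch of how such a bundle can be inspected with CuTe's print utilities, assuming the stock CUTLASS op name SM80_16x8x16_F32F16F16F32_TN for the "fp32 = fp16 * fp16 + fp32" case (the specialization arguments are elided in this hunk):

    #include <cute/tensor.hpp>
    #include <cute/atom/mma_traits_sm80.hpp>

    int main() {
      using namespace cute;
      // Assumed op name for the 16x8x16 fp16*fp16+fp32 tensor-core instruction.
      using Traits = MMA_Traits<SM80_16x8x16_F32F16F16F32_TN>;

      print(typename Traits::Shape_MNK{}); print("\n");   // (_16,_8,_16)
      print(typename Traits::ThrID{});     print("\n");   // one warp: 32 threads
      // CLayout maps (thread, value) -> (M, N) coordinates of the accumulator fragment.
      print_layout(typename Traits::CLayout{});
      return 0;
    }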
-+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////// s32 = u8 * s8 + s32 /////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////// s32 = u8 * u8 + s32 /////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////// s32 = b1 ^ b1 + s32 /////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = cute::uint1b_t; -+ using ElementBVal = cute::uint1b_t; -+ using ElementCVal = int32_t; -+ -+ using Shape_MNK = Shape<_16,_8,_256>; -+ using ThrID = Layout<_32>; -+ using ALayout = Layout>, -+ Stride<_64,Stride<_64,_16,_8,_2048>>>; -+ using BLayout = Layout>, -+ Stride<_32,Stride< _1,_1024>>>; -+ using CLayout = SM80_16x8_Row; -+}; -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/atom/mma_traits_sm90.hpp b/3rdparty/cutlass/include/cute/atom/mma_traits_sm90.hpp -new file mode 100644 -index 0000000..b7a12b9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/mma_traits_sm90.hpp -@@ -0,0 +1,132 @@ 
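Taken together, the SM75/SM80 traits above are consumed by wrapping an op in an MMA_Atom and tiling it across warps with make_tiled_mma. A sketch under the assumption that the stock op name SM80_16x8x16_F32F16F16F32_TN is the one behind the fp16 specializations (the arguments are not visible in this hunk):

    #include <cute/tensor.hpp>
    #include <cute/atom/mma_atom.hpp>

    int main() {
      using namespace cute;
      // 2x2x1 grid of warps, each issuing the 16x8x16 atom -> a 32x16x16 MMA tile.
      auto tiled_mma = make_tiled_mma(MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>{},
                                      Layout<Shape<_2,_2,_1>>{});
      print(size(tiled_mma)); print(" threads\n");   // 128, i.e. 4 warps
      print_latex(tiled_mma);                        // dump the thread/value partitioning
      return 0;
    }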
-+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+ -+#include -+ -+namespace cute { -+ -+/////////////////////////////////////////////////////////////////////////////// -+//////////////////////// fp64 = fp64 * fp64 + fp64 //////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = double; -+ using ElementAVal = double; -+ using ElementBVal = double; -+ using ElementCVal = double; -+ -+ using Shape_MNK = Shape<_16,_8,_4>; -+ using ThrID = Layout<_32>; -+ using ALayout = Layout,_2>, -+ Stride,_8>>; -+ using BLayout = Layout,_1>, -+ Stride,_0>>; -+ using CLayout = Layout,Shape < _2,_2>>, -+ Stride,Stride<_16,_8>>>; -+}; -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = double; -+ using ElementAVal = double; -+ using ElementBVal = double; -+ using ElementCVal = double; -+ -+ using Shape_MNK = Shape<_16,_8,_8>; -+ using ThrID = Layout<_32>; -+ using ALayout = Layout,Shape <_2, _2>>, -+ Stride,Stride<_8,_64>>>; -+ using BLayout = Layout, _2>, -+ Stride,_32>>; -+ using CLayout = Layout,Shape < _2,_2>>, -+ Stride,Stride<_16,_8>>>; -+}; -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = double; -+ using ElementAVal = double; -+ using ElementBVal = double; -+ using ElementCVal = double; -+ -+ using Shape_MNK = Shape<_16,_8,_16>; -+ using ThrID = Layout<_32>; -+ using ALayout = Layout,Shape <_2, _4>>, -+ Stride,Stride<_8,_64>>>; -+ using BLayout = Layout, _4>, -+ Stride,_32>>; -+ using CLayout = Layout,Shape < _2,_2>>, -+ Stride,Stride<_16,_8>>>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////// -+//////////////////////// cfp64 = cfp64 * cfp64 + cfp64 //////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = complex; -+ using ElementAVal = complex; -+ using ElementBVal = complex; -+ using ElementCVal = complex; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = complex; -+ using ElementAVal = complex; -+ using ElementBVal = complex; -+ using ElementCVal = complex; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = complex; -+ using ElementAVal = complex; -+ using ElementBVal = complex; -+ using ElementCVal = complex; -+}; -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/atom/mma_traits_sm90_gmma.hpp b/3rdparty/cutlass/include/cute/atom/mma_traits_sm90_gmma.hpp -new file mode 100644 -index 0000000..d390daf ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/mma_traits_sm90_gmma.hpp -@@ -0,0 +1,2975 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+ -+#include -+ -+namespace cute { -+ -+namespace GMMA { -+ -+/////////////////////////////////////////// -+// Common layouts for GMMA Shared Memory // -+/////////////////////////////////////////// -+ -+// M|N-major GMMA layouts in units of bits -+using Layout_MN_INTER_Atom_Bits = Layout,Stride<_1,_128>>; -+using Layout_MN_SW32_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _256>>>; -+using Layout_MN_SW64_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _512>>>; -+using Layout_MN_SW128_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1,_1024>>>; -+ -+// K-major GMMA layouts in units of bits -+using Layout_K_INTER_Atom_Bits = Layout,Stride<_128,_1>>; -+using Layout_K_SW32_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride< _256,_1>>>; -+using Layout_K_SW64_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride< _512,_1>>>; -+using Layout_K_SW128_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1024,_1>>>; -+ -+// M|N-major layouts in units of Type -+template -+using Layout_MN_INTER_Atom = decltype(upcast::value>(Layout_MN_INTER_Atom_Bits{})); -+template -+using Layout_MN_SW32_Atom = decltype(upcast::value>(Layout_MN_SW32_Atom_Bits{})); -+template -+using Layout_MN_SW64_Atom = decltype(upcast::value>(Layout_MN_SW64_Atom_Bits{})); -+template -+using Layout_MN_SW128_Atom = decltype(upcast::value>(Layout_MN_SW128_Atom_Bits{})); -+ -+// K-major layouts in units of Type -+template -+using Layout_K_INTER_Atom = decltype(upcast::value>(Layout_K_INTER_Atom_Bits{})); -+template -+using Layout_K_SW32_Atom = decltype(upcast::value>(Layout_K_SW32_Atom_Bits{})); -+template -+using Layout_K_SW64_Atom = decltype(upcast::value>(Layout_K_SW64_Atom_Bits{})); -+template -+using Layout_K_SW128_Atom = decltype(upcast::value>(Layout_K_SW128_Atom_Bits{})); -+ -+// With GMMA::Major param -+template -+using Layout_INTER_Atom = typename std::conditional, -+ Layout_K_INTER_Atom>::type; -+template -+using Layout_SW32_Atom = typename std::conditional, -+ Layout_K_SW32_Atom>::type; -+template -+using Layout_SW64_Atom = typename std::conditional, -+ Layout_K_SW64_Atom>::type; -+template -+using Layout_SW128_Atom = 
typename std::conditional, -+ Layout_K_SW128_Atom>::type; -+ -+// Helper for GMMA smem selection that considers a tensor TileShape: -+// (BLK_MN, BLK_K) -+// or hierarchically -+// ((BLK_MN0,BLK_MN1,...),(BLK_K0,BLK_K1,...)) -+// and returns the largest GMMA::Layout that fits BLK_MN0 and BLK_K0 -+template -+CUTE_HOST_DEVICE constexpr -+auto -+smem_selector() -+{ -+ auto BLK_MN0 = size<0>(BLK_MN{}); -+ auto BLK_K0 = size<0>(BLK_K{}); -+ -+ static_assert(BLK_MN0 % 8 == 0, "BLK_MN0 must be a multiple of 8."); -+ static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8."); -+ -+ -+ if constexpr (major == GMMA::Major::MN) { -+ if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom{}) == 0) { -+ return GMMA::Layout_MN_SW128_Atom{}; -+ } else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW64_Atom{}) == 0) { -+ return GMMA::Layout_MN_SW64_Atom{}; -+ } else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0) { -+ return GMMA::Layout_MN_SW32_Atom{}; -+ } else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0) { -+ return GMMA::Layout_MN_INTER_Atom{}; -+ } else { -+ static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0, -+ "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom{})"); -+ } -+ } else if constexpr (major == GMMA::Major::K) { -+ if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_Atom{}) == 0) { -+ return GMMA::Layout_K_SW128_Atom{}; -+ } else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW64_Atom{}) == 0) { -+ return GMMA::Layout_K_SW64_Atom{}; -+ } else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW32_Atom{}) == 0) { -+ return GMMA::Layout_K_SW32_Atom{}; -+ } else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0) { -+ return GMMA::Layout_K_INTER_Atom{}; -+ } else { -+ static_assert(BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0, -+ "BLK_K0 must be a multiple of size<1>(GMMA::Layout_K_INTER_Atom{})"); -+ } -+ } -+} -+ -+// -+// Tensor to LayoutType utility -+// -+ -+// smem_ptr_swizzle LayoutType -+template -+CUTE_HOST_DEVICE constexpr -+LayoutType -+layout_type(Tensor>>, -+ Layout> const&) -+{ -+ static_assert(M == 4, "Unsupported layout swizzle"); -+ static_assert(0 <= B && B <= 3, "Unsupported layout swizzle"); -+ static_assert(S == 3, "Unsupported layout swizzle"); -+ -+ switch (B) { -+ case 0: return LayoutType::INTERLEAVE; -+ case 1: return LayoutType::B32; -+ case 2: return LayoutType::B64; -+ case 3: return LayoutType::B128; -+ } -+ return LayoutType::INTERLEAVE; // ERROR -+} -+ -+// smem_ptr non-swizzled LayoutType -+template -+CUTE_HOST_DEVICE constexpr -+LayoutType -+layout_type(Tensor>, -+ Layout> const&) -+{ -+ return LayoutType::INTERLEAVE; -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+// Construction method for GMMA Descriptors -+/////////////////////////////////////////////////////////////////////////////// -+ -+/** -+* /////////////////////////////// -+* // make_gmma_desc // -+* /////////////////////////////// -+* Each GmmaDescriptor Major-MN describes a canonical layout of the form -+* -+* LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((T,1,m),(8,k)):((1,T,SBO),(1T,LBO)) -+* LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((T,2,m),(8,k)):((1,T,LBO),(2T,SBO)) -+* LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((T,4,m),(8,k)):((1,T,LBO),(4T,SBO)) -+* LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((T,8,m),(8,k)):((1,T,LBO),(8T,SBO)) -+* -+* where -+* T : sizeof(uint128_t) / sizeof(value_type) -+* m : integer in [1,16] corresponding 
to GMMA shape -+* k : integer in [1,32] corresponding to GMMA shape -+* SBO: stride byte offset -+* LBO: leading byte offset -+* -+* See GMMA::Layout_MN_XXX_Atom for building canonical GmmaDescriptor Major-MN layouts. -+* For example, -+* auto smem_layout = tile_to_shape(Layout_MN_SW128_Atom{}, Shape<_128,_64>{}); -+* is guaranteed to be accepted by make_gmma_desc for appropriate value_type. -+* -+* ////////////////////////////// -+* // make_gmma_desc // -+* ////////////////////////////// -+* Each GmmaDescriptor Major-K describes a canonical layout of the form -+* -+* LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((8,m),(T,2)):((1T,SBO),(1,LBO)) -+* LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((8,m),(T,2)):((2T,SBO),(1, T )) -+* LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((8,m),(T,2)):((4T,SBO),(1, T )) -+* LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,m),(T,2)):((8T,SBO),(1, T )) -+* -+* See GMMA::Layout_K_XXX_Atom for building canonical GmmaDescriptor Major-K layouts. -+* For example, -+* auto smem_layout = tile_to_shape(Layout_K_SW128_Atom{}, Shape<_128,_64>{}); -+* is guaranteed to be accepted by make_gmma_desc for appropriate value_type. -+*/ -+template -+CUTE_HOST_DEVICE constexpr -+GmmaDescriptor -+make_gmma_desc(Tensor const& tensor) -+{ -+ static_assert(is_smem::value, "GMMA Descriptors can only be constructed on smem."); -+ static_assert(TLayout::rank == 2, "GMMA Descriptors can only be constructed on rank-2 tensors."); -+ using value_type = typename TEngine::value_type; -+ -+ Tensor u128_tensor = recast(tensor); -+ -+ // Result -+ GmmaDescriptor desc; -+ -+ // Layout type -+ constexpr GMMA::LayoutType LAYOUT_TYPE = GMMA::layout_type(u128_tensor); -+ desc.layout_type_ = uint8_t(LAYOUT_TYPE); -+ -+ // Start address (4LSB not included) -+ uint32_t start_address = cast_smem_ptr_to_uint(u128_tensor.data().get()); -+ desc.start_address_ = start_address >> 4; -+ -+ constexpr uint8_t base_offset = 0; -+ desc.base_offset_ = base_offset; -+ -+ // LayoutType meta -+ constexpr int W = LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE ? 1 : -+ LAYOUT_TYPE == GMMA::LayoutType::B32 ? 2 : -+ LAYOUT_TYPE == GMMA::LayoutType::B64 ? 4 : -+ LAYOUT_TYPE == GMMA::LayoutType::B128 ? 8 : -1; -+ -+ if constexpr (MajorMode == GMMA::Major::MN) -+ { -+ /* In units of uint128_t, each GmmaDescriptor Major-MN describes a canonical layout of the form -+ * -+ * LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((1,n),(8,k)):((X,SBO),(1,LBO)) -+ * LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((2,n),(8,k)):((1,LBO),(2,SBO)) -+ * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((4,n),(8,k)):((1,LBO),(4,SBO)) -+ * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,n),(8,k)):((1,LBO),(8,SBO)) -+ */ -+ static_assert(size<1>(u128_tensor) == Int<(256 / cute::sizeof_bits::value)>{}, // K size -+ "Not a canonical GMMA_MN Layout: Expected K-size 256/sizeof_bits."); -+ -+ // Construct the canonical GMMA T Layout with shape ((W,n),(8,2)) -+ Layout canonical_layout = logical_divide(layout(u128_tensor), make_tile(Layout,_1>{}, Layout,_1>{})); -+ -+ // Check ranks of canonical -+ CUTE_STATIC_ASSERT_V(rank<0>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_MN Layout: No flat offset mode"); -+ CUTE_STATIC_ASSERT_V(rank<1>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_MN Layout: No flat offset mode"); -+ // Check canonical mode strides -+ constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); -+ constexpr uint32_t expected_stride_00 = LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE ? 
stride<0,0>(canonical_layout) : 1; -+ static_assert(stride_00 == expected_stride_00, "Not a canonical GMMA_MN Layout: Expected stride failure."); -+ constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); -+ constexpr uint32_t expected_stride_10 = W; -+ static_assert(stride_10 == expected_stride_10, "Not a canonical GMMA_MN Layout: Expected stride failure."); -+ -+ // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) -+ constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); -+ constexpr uint32_t stride_11 = stride<1,1>(canonical_layout); -+ -+ desc.stride_byte_offset_ = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? stride_01 : stride_11; -+ desc.leading_byte_offset_ = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? stride_11 : stride_01; -+ } -+ else if constexpr (MajorMode == GMMA::Major::K) -+ { -+ /* In units of uint128_t, each GmmaDescriptor Major-K describes a canonical layout of the form -+ * -+ * LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((8,n),2):((1,SBO),LBO) -+ * LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((8,n),2):((2,SBO),1) -+ * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((8,n),2):((4,SBO),1) -+ * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,n),2):((8,SBO),1) -+ */ -+ CUTE_STATIC_ASSERT_V(size<0>(u128_tensor) % Int<8>{} == Int<0>{}, // N|M size -+ "Not a canonical GMMA_K Layout: Expected MN-size multiple of 8."); -+ CUTE_STATIC_ASSERT_V(size<1>(u128_tensor) == Int<2>{}, // K size -+ "Not a canonical GMMA_K Layout: Expected K-size 2 (in units of uint128_t)."); -+ -+ // Construct the canonical GMMA N Layout with shape ((8,n),(2,1)) -+ Layout canonical_layout = logical_divide(layout(u128_tensor), make_tile(Layout<_8,_1>{}, Layout<_2,_1>{})); -+ -+ // Check ranks of canonical -+ CUTE_STATIC_ASSERT_V(rank<0>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_K Layout: No flat offset mode"); -+ CUTE_STATIC_ASSERT_V(rank<1>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_K Layout: No flat offset mode"); -+ // Check canonical mode strides -+ constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); -+ constexpr uint32_t expected_stride_00 = W; -+ static_assert(stride_00 == expected_stride_00, "Not a canonical GMMA_K Layout: Expected stride failure."); -+ constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); -+ constexpr uint32_t expected_stride_10 = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? 
stride<1,0>(canonical_layout) : 1; -+ static_assert(stride_10 == expected_stride_10, "Not a canonical GMMA_K Layout: Expected stride failure."); -+ -+ // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) -+ constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); -+ -+ desc.stride_byte_offset_ = stride_01; -+ desc.leading_byte_offset_ = stride_10; -+ } else { -+ static_assert(MajorMode != GMMA::Major::MN && MajorMode != GMMA::Major::K, "Unrecognized MajorMode!"); -+ } -+ -+#if 0 -+ // DEBUG and SANITY -+ assert((start_address & 0b0000001111) == 0); // Must be 16B aligned (4LSB are 0) no negotiation -+ assert((start_address & 0b1110000000) == 0); // Assert base_offset is 0, generalize later -+ if (thread0()) { -+ print("smem_desc input tensor: "); print(tensor.data()); print(" o "); print(tensor.layout()); print("\n"); -+ print("smem_desc uint128_t tensor: "); print(u128_tensor.data()); print(" o "); print(u128_tensor.layout()); print("\n"); -+ //print(" desc canonical layout: "); print(canonical_layout); print("\n"); -+ print(desc); -+ } -+#endif -+ -+ return desc; -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+// Higher level GMMA Descriptor utilities -+/////////////////////////////////////////////////////////////////////////////// -+ -+struct gmma_descriptor_iterator -+{ -+ GmmaDescriptor desc_; -+ -+ // Dereference returns the GmmaDescriptor -+ CUTE_HOST_DEVICE constexpr -+ GmmaDescriptor const& operator*() const { return desc_; } -+ -+ // Advance and return a new GmmaDescriptor -+ template -+ CUTE_HOST_DEVICE constexpr -+ GmmaDescriptor operator[](Index const& i) const { return *(*this + i); } -+ -+ // Return an advanced iterator -+ template -+ CUTE_HOST_DEVICE constexpr -+ gmma_descriptor_iterator operator+(Index const& offset) const -+ { -+ // offset is in the units of uint128_t (4LSB of start_address not included) -+ -+ //GmmaDescriptor desc = desc_; -+ //desc.start_address_ += uint16_t(offset); -+ //desc.reg32_[0] += uint16_t(offset); // Generates better asm than adding to the bitfield -+ -+ // May need to update base_offset if swizzle alignment isn't guaranteed -+ //desc.base_offset_ = 0; -+ //assert((desc.start_address_ & 0b111000) == 0); // Assert base_offset is 0, generalize later -+ -+ //return {desc}; -+ -+ // The above seems to not work for some reason... 
-+ return {desc_ + uint64_t(offset)}; -+ } -+}; -+ -+template -+struct smem_desc : gmma_descriptor_iterator {}; -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_gmma_desc_fragment(Tensor const& t) -+{ -+ // Cast to a uint128_t tensor for GMMA Desc iteration -+ return make_tensor(gmma_descriptor_iterator{make_gmma_desc(tensor<0>(t))}, -+ recast(t).layout()); -+} -+ -+// Recast a tensor to a tensor of gmma_descriptor_iterator -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(Tensor&& tensor, type_list>) -+{ -+ return make_gmma_desc_fragment(tensor); -+} -+ -+// Recast a gmma_descriptor_iterator Tensor to uint64_t, it's RegType -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(Tensor,TLayout> const& tensor, type_list) -+{ -+ static_assert(std::is_same::value, "Can only cast descriptors to uint64_t."); -+ return make_tensor(tensor.data(), Layout<_1,_0>{}); -+} -+ -+} // end namespace GMMA -+ -+// Fence between the async destination accumulators of GMMA & source for their dependent use -+template -+CUTE_HOST_DEVICE -+void -+warpgroup_fence_operand(Tensor& frg) { -+ CUTE_STATIC_ASSERT(is_static::value); -+ if constexpr (std::is_same_v) { -+ auto f32_frg = recast(frg); -+ CUTE_UNROLL -+ for (int i = 0; i < size(f32_frg); ++i) { -+ warpgroup_fence_operand(f32_frg(i)); -+ } -+ } -+ else { -+ CUTE_STATIC_ASSERT(is_rmem::value); -+ auto u32_frg = recast(frg); -+ CUTE_UNROLL -+ for (int i = 0; i < size(u32_frg); ++i) { -+ warpgroup_fence_operand(u32_frg(i)); -+ } -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+//////////////////////////// MMA_TRAITS /////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+namespace GMMA { -+ -+// Accumulator layouts -+using CLayout_64x8 = Layout,Shape < _2,_2>>, -+ Stride,Stride<_64,_8>>>; -+ -+using CLayout_64x16 = Layout,Shape < _2,_2, _2>>, -+ Stride,Stride<_64,_8,_512>>>; -+ -+using CLayout_64x32 = Layout,Shape < _2,_2, _4>>, -+ Stride,Stride<_64,_8,_512>>>; -+ -+using CLayout_64x64 = Layout,Shape < _2,_2, _8>>, -+ Stride,Stride<_64,_8,_512>>>; -+ -+using CLayout_64x96 = Layout,Shape < _2,_2, _12>>, -+ Stride,Stride<_64,_8,_512>>>; -+ -+using CLayout_64x128 = Layout,Shape < _2,_2, _16>>, -+ Stride,Stride<_64,_8,_512>>>; -+ -+using CLayout_64x192 = Layout,Shape < _2,_2, _24>>, -+ Stride,Stride<_64,_8,_512>>>; -+ -+using CLayout_64x256 = Layout,Shape < _2,_2, _32>>, -+ Stride,Stride<_64,_8,_512>>>; -+ -+// Register source layout for 32-bit value types -+using ALayout_64x8 = Layout,Shape < _2, _2>>, -+ Stride,Stride< _8,_256>>>; -+ -+// Register source layout for 16-bit value types -+using ALayout_64x16 = CLayout_64x16; -+ -+// Register source layout for 8-bit value types -+using ALayout_64x32 = Layout,Shape < _4,_2, _2>>, -+ Stride,Stride<_64,_8,_1024>>>; -+ -+// Shared memory source layouts for any value type -+template -+using ABLayout = Layout,Int>>, -+ Stride< _0,Stride< _1,Int>>>; -+ -+} // namespace GMMA -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 8, 16>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ 
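The make_gmma_desc documentation above states that shared-memory layouts built by tiling a GMMA::Layout_K_XXX_Atom (or Layout_MN_XXX_Atom) with tile_to_shape are guaranteed to be accepted as canonical. Expanding that one-liner into a compilable sketch, with half_t and a 128x64 block chosen purely for illustration:

    #include <cute/tensor.hpp>
    #include <cute/atom/mma_traits_sm90_gmma.hpp>

    int main() {
      using namespace cute;
      // Largest K-major swizzle atom for a 16-bit type: Swizzle<3,4,3> over an 8x64 tile of half_t.
      auto atom        = GMMA::Layout_K_SW128_Atom<half_t>{};
      // Tiling it over a (BLK_MN, BLK_K) = (128, 64) shared-memory block produces a
      // layout that make_gmma_desc<GMMA::Major::K> accepts as canonical.
      auto smem_layout = tile_to_shape(atom, Shape<_128,_64>{});
      print(smem_layout); print("\n");
      return 0;
    }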
-+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 8, 16>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 16, 16>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 16, 16>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 32, 16>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 32, 16>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 64, 16>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using 
Shape_MNK = Shape<_64,_64,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 64, 16>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 96, 16>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 96, 16>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout<128, 16>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout<128, 16>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout<192, 16>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout<192, 16>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template 
-+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout<256, 16>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout<256, 16>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 8, 16>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 8, 16>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 16, 16>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 16, 16>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = 
GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 32, 16>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 32, 16>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 64, 16>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 64, 16>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 96, 16>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 96, 16>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout<128, 16>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using 
ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout<128, 16>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout<192, 16>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout<192, 16>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout<256, 16>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout<256, 16>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 8, 16>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 8, 16>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ 
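Every *_SS trait in this file reads both operands through GmmaDescriptors (ElementAFrg/ElementBFrg are GMMA::smem_desc), while the *_RS variants source A from registers; in both cases ThrID is Layout<_128>, i.e. one full warpgroup. A usage sketch; the op name SM90_64x64x16_F32F16F16_SS and its GMMA::Major template arguments are assumed stock CUTLASS spellings, since the specialization arguments are elided here:

    #include <cute/tensor.hpp>
    #include <cute/atom/mma_atom.hpp>

    int main() {
      using namespace cute;
      // One warpgroup (128 threads) drives a single 64x64x16 GMMA with K-major A and B.
      using Op = SM90_64x64x16_F32F16F16_SS<GMMA::Major::K, GMMA::Major::K>;
      auto tiled_mma = make_tiled_mma(Op{});          // 1x1x1 tiling of the atom
      print(size(tiled_mma)); print(" threads\n");    // 128: a full warpgroup
      return 0;
    }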
-+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 16, 16>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 16, 16>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 32, 16>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 32, 16>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 64, 16>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 64, 16>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using 
ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 96, 16>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 96, 16>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout<128, 16>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout<128, 16>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout<192, 16>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout<192, 16>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout<256, 16>; -+ using CLayout = GMMA::CLayout_64x256; -+}; 
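The difference between the descriptor-backed and register-backed operand layouts used above can be made visible by printing them: GMMA::ABLayout gives every one of the 128 threads a zero-stride view of the whole MxK tile (it is read through a descriptor), whereas GMMA::ALayout_64x16 is a genuine per-thread register fragment. A small host-side sketch:

    #include <cute/tensor.hpp>
    #include <cute/atom/mma_traits_sm90_gmma.hpp>

    int main() {
      using namespace cute;
      // Shared-memory operand view: shape (128, (64,16)), thread stride _0.
      print(GMMA::ABLayout<64, 16>{});  print("\n");
      // Register-sourced A fragment for the 16-bit ops (same layout as the 64x16 accumulator).
      print(GMMA::ALayout_64x16{});     print("\n");
      return 0;
    }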
-+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout<256, 16>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 8>; -+ using BLayout = GMMA::ABLayout< 8, 8>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x8; -+ using BLayout = GMMA::ABLayout< 8, 8>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 8>; -+ using BLayout = GMMA::ABLayout< 16, 8>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x8; -+ using BLayout = GMMA::ABLayout< 16, 8>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 8>; -+ using BLayout = GMMA::ABLayout< 32, 8>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = 
GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x8; -+ using BLayout = GMMA::ABLayout< 32, 8>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 8>; -+ using BLayout = GMMA::ABLayout< 64, 8>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x8; -+ using BLayout = GMMA::ABLayout< 64, 8>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 8>; -+ using BLayout = GMMA::ABLayout< 96, 8>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x8; -+ using BLayout = GMMA::ABLayout< 96, 8>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 8>; -+ using BLayout = GMMA::ABLayout<128, 8>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x8; -+ using BLayout = GMMA::ABLayout<128, 8>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ 
-+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 8>; -+ using BLayout = GMMA::ABLayout<192, 8>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x8; -+ using BLayout = GMMA::ABLayout<192, 8>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 8>; -+ using BLayout = GMMA::ABLayout<256, 8>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x8; -+ using BLayout = GMMA::ABLayout<256, 8>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 8, 32>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 16, 32>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = 
int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 32, 32>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 64, 32>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 96, 32>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<128, 32>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<192, 32>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<256, 32>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using 
BLayout = GMMA::ABLayout< 8, 32>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 16, 32>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 32, 32>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 64, 32>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 96, 32>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<128, 32>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<192, 32>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_32>; -+ 
using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<256, 32>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 8, 32>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 16, 32>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 32, 32>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 64, 32>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 96, 32>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<128, 32>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ 
-+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<192, 32>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<256, 32>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 8, 32>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 16, 32>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 32, 32>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 64, 32>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = 
Shape<_64,_96,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 96, 32>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<128, 32>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<192, 32>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<256, 32>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 8, 32>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 16, 32>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 32, 32>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template 
-+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 64, 32>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 96, 32>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<128, 32>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<192, 32>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<256, 32>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 8, 32>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK 
= Shape<_64,_16,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 16, 32>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 32, 32>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 64, 32>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 96, 32>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<128, 32>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<192, 32>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<256, 32>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using 
ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 8, 32>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 16, 32>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 32, 32>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 64, 32>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 96, 32>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<128, 32>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_32>; -+ using ThrID = 
Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<192, 32>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<256, 32>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 8, 32>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 16, 32>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 32, 32>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 64, 32>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 96, 32>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ 
using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<128, 32>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<192, 32>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<256, 32>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/config.hpp b/3rdparty/cutlass/include/cute/config.hpp -new file mode 100644 -index 0000000..b2f4de8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/config.hpp -@@ -0,0 +1,121 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA) -+# define CUTE_HOST_DEVICE __forceinline__ __host__ __device__ -+# define CUTE_DEVICE __forceinline__ __device__ -+# define CUTE_HOST __forceinline__ __host__ -+#else -+# define CUTE_HOST_DEVICE inline -+# define CUTE_DEVICE inline -+# define CUTE_HOST inline -+#endif // CUTE_HOST_DEVICE, CUTE_DEVICE -+ -+#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA) -+# define CUTE_UNROLL #pragma unroll -+# define CUTE_NO_UNROLL #pragma unroll 1 -+#else -+# define CUTE_UNROLL -+# define CUTE_NO_UNROLL -+#endif // CUTE_UNROLL -+ -+#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA) -+# define CUTE_INLINE_CONSTANT static const __device__ -+#else -+# define CUTE_INLINE_CONSTANT static constexpr -+#endif -+ -+// Some versions of GCC < 11 have trouble deducing that a -+// function with "auto" return type and all of its returns in an "if -+// constexpr ... else" statement must actually return. Thus, GCC -+// emits spurious "missing return statement" build warnings. -+// Developers can suppress these warnings by using the -+// CUTE_GCC_UNREACHABLE macro, which must be followed by a semicolon. -+// It's harmless to use the macro for other GCC versions or other -+// compilers, but it has no effect. -+#if ! defined(CUTE_GCC_UNREACHABLE) -+# if defined(__GNUC__) && __GNUC__ < 11 -+ // GCC 10, but not 7.5, 9.4.0, or 11, issues "missing return -+ // statement" warnings without this little bit of help. -+# define CUTE_GCC_UNREACHABLE __builtin_unreachable() -+# else -+# define CUTE_GCC_UNREACHABLE -+# endif -+#endif -+ -+// -+// Assertion helpers -+// -+ -+#include -+ -+#define CUTE_STATIC_ASSERT static_assert -+#define CUTE_STATIC_ASSERT_V(x,...) static_assert(decltype(x)::value, ##__VA_ARGS__) -+ -+#if defined(__CUDA_ARCH__) -+# define CUTE_RUNTIME_ASSERT(x) asm volatile ("brkpt;\n" ::: "memory") -+#else -+# define CUTE_RUNTIME_ASSERT(x) assert(0 && x) -+#endif -+ -+// -+// IO -+// -+ -+#include -+#include -+#include -+ -+// -+// Support -+// -+ -+#include -+ -+// -+// Basic types -+// -+ -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+ -+// -+// Debugging utilities -+// -+ -+#include -+#include -diff --git a/3rdparty/cutlass/include/cute/container/alignment.hpp b/3rdparty/cutlass/include/cute/container/alignment.hpp -new file mode 100644 -index 0000000..49101fa ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/container/alignment.hpp -@@ -0,0 +1,70 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
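// ----------------------------------------------------------------------------
// Illustrative aside (not part of the patch above): the CUTE_HOST_DEVICE and
// CUTE_UNROLL macros defined in the cute/config.hpp hunk are written directly
// in front of the construct they modify.  A standalone sketch that mimics the
// host-only fallback branch of those definitions (the real header switches to
// the __device__ / "#pragma unroll" variants when __CUDA_ARCH__ or _NVHPC_CUDA
// is defined); the SKETCH_* names are placeholders, not cute's:
// ----------------------------------------------------------------------------
#include <cassert>
#include <cstdio>

#define SKETCH_HOST_DEVICE inline  // mirrors the non-CUDA branch of CUTE_HOST_DEVICE
#define SKETCH_UNROLL              // mirrors the non-CUDA branch of CUTE_UNROLL (expands to nothing)

SKETCH_HOST_DEVICE int sum4(const int (&v)[4]) {
  int acc = 0;
  SKETCH_UNROLL                    // on a device build this position would carry the unroll pragma
  for (int i = 0; i < 4; ++i) {
    acc += v[i];
  }
  return acc;
}

int main() {
  int v[4] = {1, 2, 3, 4};
  assert(sum4(v) == 10);
  std::printf("sum4 = %d\n", sum4(v));
  return 0;
}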
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+namespace cute -+{ -+ -+// Test if a pointer is aligned to N bytes -+template -+CUTE_HOST_DEVICE constexpr -+bool -+is_byte_aligned(void const* const ptr) -+{ -+ static_assert(N > 0 && (N & (N - 1)) == 0, "N must be a power of 2 in alignment check"); -+ return (reinterpret_cast(ptr) & (N-1)) == 0; -+} -+ -+#if defined(__CUDACC__) -+# define CUTE_ALIGNAS(n) __align__(n) -+#else -+# define CUTE_ALIGNAS(n) alignas(n) -+#endif -+ -+template -+struct aligned_struct {}; -+ -+template <> struct CUTE_ALIGNAS( 1) aligned_struct< 1> {}; -+template <> struct CUTE_ALIGNAS( 2) aligned_struct< 2> {}; -+template <> struct CUTE_ALIGNAS( 4) aligned_struct< 4> {}; -+template <> struct CUTE_ALIGNAS( 8) aligned_struct< 8> {}; -+template <> struct CUTE_ALIGNAS( 16) aligned_struct< 16> {}; -+template <> struct CUTE_ALIGNAS( 32) aligned_struct< 32> {}; -+template <> struct CUTE_ALIGNAS( 64) aligned_struct< 64> {}; -+template <> struct CUTE_ALIGNAS(128) aligned_struct<128> {}; -+template <> struct CUTE_ALIGNAS(256) aligned_struct<256> {}; -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/container/array.hpp b/3rdparty/cutlass/include/cute/container/array.hpp -new file mode 100644 -index 0000000..571ac08 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/container/array.hpp -@@ -0,0 +1,282 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
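// ----------------------------------------------------------------------------
// Illustrative aside (not part of the patch above): is_byte_aligned in the
// cute/container/alignment.hpp hunk uses the usual power-of-two trick -- an
// address is N-byte aligned exactly when its low log2(N) bits are all zero.
// A standalone sketch of the same check (names here are placeholders):
// ----------------------------------------------------------------------------
#include <cstdint>
#include <cstdio>

template <int N>
bool is_byte_aligned_sketch(void const* ptr) {
  static_assert(N > 0 && (N & (N - 1)) == 0, "N must be a power of 2");
  return (reinterpret_cast<std::uintptr_t>(ptr) & (N - 1)) == 0;
}

int main() {
  alignas(16) unsigned char buf[32];
  std::printf("buf     is 16B aligned: %d\n", int(is_byte_aligned_sketch<16>(buf)));      // 1
  std::printf("buf + 1 is 16B aligned: %d\n", int(is_byte_aligned_sketch<16>(buf + 1)));  // 0
  return 0;
}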
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+template -+struct array -+{ -+ using value_type = T; -+ using size_type = std::size_t; -+ using difference_type = std::ptrdiff_t; -+ using reference = value_type&; -+ using const_reference = const value_type&; -+ using pointer = value_type*; -+ using const_pointer = const value_type*; -+ using iterator = pointer; -+ using const_iterator = const_pointer; -+ -+ CUTE_HOST_DEVICE constexpr -+ reference operator[](size_type pos) -+ { -+ return begin()[pos]; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference operator[](size_type pos) const -+ { -+ return begin()[pos]; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ reference front() -+ { -+ return *begin(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference front() const -+ { -+ return *begin(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ reference back() -+ { -+ // return *rbegin(); -+ return operator[](N-1); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference back() const -+ { -+ // return *rbegin(); -+ return operator[](N-1); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ T* data() -+ { -+ return __elems_; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ T const* data() const -+ { -+ return __elems_; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator begin() -+ { -+ return data(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator begin() const -+ { -+ return data(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator cbegin() -+ { -+ return begin(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator cbegin() const -+ { -+ return begin(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator end() -+ { -+ return data() + size(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator end() const -+ { -+ return data() + size(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator cend() -+ { -+ return end(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator cend() const -+ { -+ return end(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ bool empty() const -+ { -+ return size() == 0; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ size_type size() const -+ { -+ return N; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ size_type max_size() const -+ { -+ return size(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ void fill(const T& value) -+ { -+ for (auto& e : *this) { -+ e = value; -+ } -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ void clear() -+ { -+ fill(T(0)); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ void swap(array& other) -+ { -+ using std::swap; -+ for (size_type i = 0; i < size(); ++i) { -+ swap((*this)[i], other[i]); -+ } -+ } -+ -+ value_type __elems_[N > 0 ? 
N : 1]; -+}; -+ -+ -+template -+CUTE_HOST_DEVICE constexpr -+bool operator==(array const& lhs, array const& rhs) -+{ -+ for (std::size_t i = 0; i < N; ++i) { -+ if (lhs[i] != rhs[i]) { -+ return false; -+ } -+ } -+ return true; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+void clear(array& a) -+{ -+ a.fill(T(0)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+void fill(array& a, T const& value) -+{ -+ a.fill(value); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+void swap(array& a, array& b) -+{ -+ a.swap(b); -+} -+ -+} // end cute -+ -+ -+// -+// Specialize tuple-related functionality for cute::array -+// -+ -+#include -+ -+namespace cute -+{ -+ -+template -+CUTE_HOST_DEVICE constexpr -+T& get(array& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return a[I]; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+T const& get(array const& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return a[I]; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+T&& get(array&& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return std::move(a[I]); -+} -+ -+} // end namespace cute -+ -+namespace std -+{ -+ -+template -+struct tuple_size> -+ : std::integral_constant -+{}; -+ -+template -+struct tuple_element> -+{ -+ using type = T; -+}; -+ -+} // end std -diff --git a/3rdparty/cutlass/include/cute/container/array_aligned.hpp b/3rdparty/cutlass/include/cute/container/array_aligned.hpp -new file mode 100644 -index 0000000..b1b3572 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/container/array_aligned.hpp -@@ -0,0 +1,276 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
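// ----------------------------------------------------------------------------
// Illustrative aside (not part of the patch above): the free get<I>() overloads
// together with the std::tuple_size / std::tuple_element specializations in the
// cute/container/array.hpp hunk are what give cute::array its tuple-style
// access.  A minimal standalone sketch of the same protocol on a hypothetical
// MiniArray type (all names below are placeholders):
// ----------------------------------------------------------------------------
#include <cstddef>
#include <cstdio>
#include <tuple>

template <class T, std::size_t N>
struct MiniArray {
  T elems[N];
};

template <std::size_t I, class T, std::size_t N>
constexpr T& get(MiniArray<T, N>& a) {
  static_assert(I < N, "Index out of range");
  return a.elems[I];
}

template <std::size_t I, class T, std::size_t N>
constexpr T const& get(MiniArray<T, N> const& a) {
  static_assert(I < N, "Index out of range");
  return a.elems[I];
}

namespace std {
template <class T, size_t N>
struct tuple_size<MiniArray<T, N>> : integral_constant<size_t, N> {};
template <size_t I, class T, size_t N>
struct tuple_element<I, MiniArray<T, N>> { using type = T; };
}  // namespace std

int main() {
  MiniArray<int, 3> a{{10, 20, 30}};
  auto [x, y, z] = a;  // structured bindings resolve through the protocol above
  std::printf("%d %d %d, get<1> = %d\n", x, y, z, get<1>(a));
  return 0;
}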
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+#include -+ -+namespace cute -+{ -+ -+template -+struct array_aligned -+ : public aligned_struct -+{ -+ /// Make sure the Alignment makes sense wrt the size of elements. -+ static_assert(Alignment == 16 || Alignment >= sizeof(T), "Alignment is too small"); -+ /// Alignment must be a power of two -+ static_assert(has_single_bit(Alignment), "Alignment must be a power of two"); -+ -+ using value_type = T; -+ using size_type = std::size_t; -+ using difference_type = std::ptrdiff_t; -+ using reference = value_type&; -+ using const_reference = const value_type&; -+ using pointer = value_type*; -+ using const_pointer = const value_type*; -+ using iterator = pointer; -+ using const_iterator = const_pointer; -+ -+ CUTE_HOST_DEVICE constexpr -+ reference operator[](size_type pos) -+ { -+ return begin()[pos]; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference operator[](size_type pos) const -+ { -+ return begin()[pos]; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ reference front() -+ { -+ return *begin(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference front() const -+ { -+ return *begin(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ reference back() -+ { -+ // return *rbegin(); -+ return operator[](N-1); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference back() const -+ { -+ // return *rbegin(); -+ return operator[](N-1); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ T* data() -+ { -+ return reinterpret_cast(storage); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ T const* data() const -+ { -+ return reinterpret_cast(storage); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator begin() -+ { -+ return data(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator begin() const -+ { -+ return data(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator cbegin() -+ { -+ return begin(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator cbegin() const -+ { -+ return begin(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator end() -+ { -+ return data() + size(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator end() const -+ { -+ return data() + size(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator cend() -+ { -+ return end(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator cend() const -+ { -+ return end(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ bool empty() const -+ { -+ return size() == 0; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ size_type size() const -+ { -+ return N; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ size_type max_size() const -+ { -+ return size(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ void fill(T const& value) -+ { -+ for (auto& e : *this) { -+ e = value; -+ } -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ void clear() -+ { -+ fill(T(0)); -+ } -+ -+ // Not private, we want trivial type -+ //private: -+ -+ /// Storage type to use for Elements -+ using StorageType = typename uint_byte(Alignment)>::type; -+ -+ /// Ensure that there's enough storage for all elements -+ static_assert(sizeof(StorageType) <= Alignment, "StorageType is too big for given alignment"); -+ -+ /// Number of elements in the storage -+ static constexpr std::size_t storageN = (sizeof(T)*N + sizeof(StorageType) - 1) / sizeof(StorageType); -+ -+ /// The storage. -+ StorageType storage[storageN > 0 ? 
storageN : 1]; -+}; -+ -+// -+// Operators -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+void clear(array_aligned& a) -+{ -+ a.clear(); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+void fill(array_aligned& a, T const& value) -+{ -+ a.fill(value); -+} -+ -+} // end namespace cute -+ -+// -+// Specialize tuple-related functionality for cute::array -+// -+ -+#include -+ -+namespace cute -+{ -+ -+template -+CUTE_HOST_DEVICE constexpr -+T& get(array_aligned& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return a[I]; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+T const& get(array_aligned const& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return a[I]; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+T&& get(array_aligned&& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return std::move(a[I]); -+} -+ -+} // end namespace cute -+ -+namespace std -+{ -+ -+template -+struct tuple_size> -+ : std::integral_constant -+{}; -+ -+template -+struct tuple_element> -+{ -+ using type = T; -+}; -+ -+} // end std -diff --git a/3rdparty/cutlass/include/cute/container/array_subbyte.hpp b/3rdparty/cutlass/include/cute/container/array_subbyte.hpp -new file mode 100644 -index 0000000..a217a67 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/container/array_subbyte.hpp -@@ -0,0 +1,613 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Statically sized array of elements that accommodates subbyte trivial types -+ in a packed storage. 
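array_aligned above backs its elements with an aligned_struct base (defined in a header not shown here), so the whole array can be placed with a guaranteed alignment. A host-side sketch; the include path, the (T, N, Alignment) template-parameter order, and the 16-byte alignment used here are assumptions.

#include <cute/container/array_aligned.hpp>   // assumed path

float demo_array_aligned() {
  cute::array_aligned<float, 8, 16> buf;   // 8 floats in 16-byte-aligned storage
  buf.fill(1.5f);
  static_assert(alignof(decltype(buf)) >= 16,
                "the aligned_struct base is expected to enforce the alignment");
  float sum = 0.0f;
  for (float v : buf) {                    // begin()/end() walk the reinterpreted storage
    sum += v;
  }
  return sum;                              // 12.0f
}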
-+*/ -+ -+#pragma once -+ -+#include -+ -+#include // sizeof_bits -+ -+namespace cute -+{ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Statically sized array for any data type -+template -+class array_subbyte -+{ -+ public: -+ -+ /// Number of total bits in the array -+ static constexpr int kSizeBits = sizeof_bits::value * N; -+ -+ /// Storage type -+ using Storage = typename std::conditional< -+ (kSizeBits % 32) == 0, -+ uint32_t, -+ typename std::conditional< -+ (kSizeBits % 16) == 0, -+ uint16_t, -+ uint8_t -+ >::type -+ >::type; -+ -+ -+ /// Number of logical elements per stored object -+ static constexpr int kElementsPerStoredItem = sizeof_bits::value / sizeof_bits::value; -+ -+ /// Number of storage elements -+ static constexpr std::size_t kStorageElements = (N + kElementsPerStoredItem - 1) / kElementsPerStoredItem; -+ -+ /// Bitmask for covering one item -+ static constexpr Storage bit_mask_ = ((Storage(1) << sizeof_bits::value) - 1); -+ -+ // -+ // C++ standard members with reference and iterator types omitted -+ // -+ -+ using value_type = T; -+ using pointer = value_type*; -+ using const_pointer = value_type const*; -+ -+ using size_type = std::size_t; -+ using difference_type = std::ptrdiff_t; -+ -+ // -+ // References -+ // -+ -+ /// Reference object inserts or extracts sub-byte items -+ class reference { -+ /// Pointer to storage element -+ Storage* ptr_; -+ -+ /// Index into elements packed into Storage object -+ int idx_; -+ -+ public: -+ -+ /// Default ctor -+ CUTE_HOST_DEVICE constexpr -+ reference() : ptr_(nullptr), idx_(0) {} -+ -+ /// Ctor -+ CUTE_HOST_DEVICE constexpr -+ reference(Storage* ptr, int idx = 0) : ptr_(ptr), idx_(idx) {} -+ -+ /// Assignment -+ CUTE_HOST_DEVICE constexpr -+ reference& operator=(T x) { -+ Storage item = (reinterpret_cast(x) & bit_mask_); -+ Storage kUpdateMask = Storage(~(bit_mask_ << (idx_ * sizeof_bits::value))); -+ *ptr_ = Storage((*ptr_ & kUpdateMask) | (item << (idx_ * sizeof_bits::value))); -+ return *this; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ T get() const { -+ Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits::value)) & bit_mask_); -+ return reinterpret_cast(item); -+ } -+ -+ /// Extract to type T -- disable if T == bool -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ operator T() const { -+ return get(); -+ } -+ -+ // Extract to bool -- potentially faster impl -+ CUTE_HOST_DEVICE constexpr -+ operator bool() const { -+ return bool((*ptr_) & (bit_mask_ << (idx_ * sizeof_bits::value))); -+ } -+ -+ /// Explicit cast to int -+ CUTE_HOST_DEVICE constexpr -+ explicit operator int() const { -+ return int(get()); -+ } -+ -+ /// Explicit cast to float -+ CUTE_HOST_DEVICE constexpr -+ explicit operator float() const { -+ return float(get()); -+ } -+ }; -+ -+ /// Reference object extracts sub-byte items -+ class const_reference { -+ -+ /// Pointer to storage element -+ Storage const* ptr_; -+ -+ /// Index into elements packed into Storage object -+ int idx_; -+ -+ public: -+ -+ /// Default ctor -+ CUTE_HOST_DEVICE constexpr -+ const_reference(): ptr_(nullptr), idx_(0) { } -+ -+ /// Ctor -+ CUTE_HOST_DEVICE constexpr -+ const_reference(Storage const* ptr, int idx = 0): ptr_(ptr), idx_(idx) { } -+ -+ CUTE_HOST_DEVICE constexpr -+ const T get() const { -+ Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits::value)) & bit_mask_); -+ return reinterpret_cast(item); -+ } -+ -+ /// Extract to type T -- disable if T == bool -+ template ::value)> -+ CUTE_HOST_DEVICE 
constexpr -+ operator T() const { -+ return get(); -+ } -+ -+ // Extract to bool -- potentially faster impl -+ CUTE_HOST_DEVICE constexpr -+ operator bool() const { -+ return bool((*ptr_) & (bit_mask_ << (idx_ * sizeof_bits::value))); -+ } -+ -+ /// Explicit cast to int -+ CUTE_HOST_DEVICE constexpr -+ explicit operator int() const { -+ return int(get()); -+ } -+ -+ /// Explicit cast to float -+ CUTE_HOST_DEVICE constexpr -+ explicit operator float() const { -+ return float(get()); -+ } -+ }; -+ -+ // -+ // Iterators -+ // -+ -+ /// Bidirectional iterator over elements -+ class iterator { -+ -+ /// Pointer to storage element -+ Storage* ptr_; -+ -+ /// Index into elements packed into Storage object -+ int idx_; -+ -+ public: -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator(): ptr_(nullptr), idx_(0) { } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator(Storage* ptr, int idx = 0): ptr_(ptr), idx_(idx) { } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator& operator++() { -+ ++idx_; -+ if (idx_ == kElementsPerStoredItem) { -+ ++ptr_; -+ idx_ = 0; -+ } -+ return *this; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator& operator--() { -+ if (idx_) { -+ --idx_; -+ } else { -+ --ptr_; -+ idx_ = kElementsPerStoredItem - 1; -+ } -+ return *this; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator operator++(int) { -+ iterator ret(*this); -+ ++(*this); -+ return ret; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator operator--(int) { -+ iterator ret(*this); -+ --(*this); -+ return ret; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator& operator+=(int k) { -+ idx_ += k; -+ ptr_ += idx_ / kElementsPerStoredItem; -+ idx_ = idx_ % kElementsPerStoredItem; -+ return *this; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator operator+(int k) const { -+ return iterator(ptr_,idx_) += k; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ reference operator*() const { -+ return reference(ptr_, idx_); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ reference operator[](int k) const { -+ return *(*this + k); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ bool operator==(iterator const& other) const { -+ return ptr_ == other.ptr_ && idx_ == other.idx_; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ bool operator!=(iterator const& other) const { -+ return !(*this == other); -+ } -+ }; -+ -+ /// Bidirectional constant iterator over elements -+ class const_iterator { -+ -+ /// Pointer to storage element -+ Storage const* ptr_; -+ -+ /// Index into elements packed into Storage object -+ int idx_; -+ -+ public: -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator(): ptr_(nullptr), idx_(0) { } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator(Storage const* ptr, int idx = 0): ptr_(ptr), idx_(idx) { } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator& operator++() { -+ ++idx_; -+ if (idx_ == kElementsPerStoredItem) { -+ ++ptr_; -+ idx_ = 0; -+ } -+ return *this; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator& operator--() { -+ if (idx_) { -+ --idx_; -+ } else { -+ --ptr_; -+ idx_ = kElementsPerStoredItem - 1; -+ } -+ return *this; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator operator++(int) { -+ iterator ret(*this); -+ ++idx_; -+ if (idx_ == kElementsPerStoredItem) { -+ ++ptr_; -+ idx_ = 0; -+ } -+ return ret; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator operator--(int) { -+ iterator ret(*this); -+ if (idx_) { -+ --idx_; -+ } else { -+ --ptr_; -+ idx_ = kElementsPerStoredItem - 1; -+ } -+ return ret; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator& operator+=(int k) { -+ idx_ += k; -+ ptr_ += idx_ / kElementsPerStoredItem; -+ 
idx_ = idx_ % kElementsPerStoredItem; -+ return *this; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator operator+(int k) const { -+ return const_iterator(ptr_,idx_) += k; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference operator*() const { -+ return const_reference(ptr_, idx_); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference operator[](int k) const { -+ return *(*this + k); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ bool operator==(iterator const& other) const { -+ return ptr_ == other.ptr_ && idx_ == other.idx_; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ bool operator!=(iterator const& other) const { -+ return !(*this == other); -+ } -+ }; -+ -+private: -+ -+ /// Internal storage -+ Storage storage[kStorageElements]; -+ -+public: -+ -+ CUTE_HOST_DEVICE constexpr -+ array_subbyte() { } -+ -+ CUTE_HOST_DEVICE constexpr -+ array_subbyte(array_subbyte const& x) { -+ CUTE_UNROLL -+ for (unsigned i = 0; i < kStorageElements; ++i) { -+ storage[i] = x.storage[i]; -+ } -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ size_type size() const { -+ return N; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ size_type max_size() const { -+ return N; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ bool empty() const { -+ return !N; -+ } -+ -+ /// Efficient clear method -+ CUTE_HOST_DEVICE constexpr -+ void clear() { -+ CUTE_UNROLL -+ for (unsigned i = 0; i < kStorageElements; ++i) { -+ storage[i] = Storage(0); -+ } -+ } -+ -+ // Efficient fill method -+ CUTE_HOST_DEVICE constexpr -+ void fill(T const& value) { -+ Storage item = (reinterpret_cast(value) & bit_mask_); -+ -+ // Reproduce the value over the bits of the storage item -+ CUTE_UNROLL -+ for (unsigned s = sizeof_bits::value; s < sizeof_bits::value; s *= 2) { -+ item |= item << s; -+ } -+ -+ CUTE_UNROLL -+ for (unsigned i = 0; i < kStorageElements; ++i) { -+ storage[i] = item; -+ } -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ reference at(size_type pos) { -+ return reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference at(size_type pos) const { -+ return const_reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ reference operator[](size_type pos) { -+ return at(pos); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference operator[](size_type pos) const { -+ return at(pos); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ reference front() { -+ return at(0); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference front() const { -+ return at(0); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ reference back() { -+ return reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference back() const { -+ return const_reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ pointer data() { -+ return reinterpret_cast(storage); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_pointer data() const { -+ return reinterpret_cast(storage); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ Storage* raw_data() { -+ return storage; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ Storage const* raw_data() const { -+ return storage; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator begin() { -+ return iterator(storage); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator begin() const { -+ return const_iterator(storage); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator cbegin() const { -+ return begin(); -+ } -+ -+ 
CUTE_HOST_DEVICE constexpr -+ iterator end() { -+ return iterator(storage + N / kElementsPerStoredItem, N % kElementsPerStoredItem); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator end() const { -+ return const_iterator(storage + N / kElementsPerStoredItem, N % kElementsPerStoredItem); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator cend() const { -+ return end(); -+ } -+ -+ // -+ // Comparison operators -+ // -+ -+}; -+ -+// -+// Operators -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+void clear(array_subbyte& a) -+{ -+ a.clear(); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+void fill(array_subbyte& a, T const& value) -+{ -+ a.fill(value); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cute -+ -+// -+// Specialize tuple-related functionality for cute::array_subbyte -+// -+ -+#include -+ -+namespace cute -+{ -+ -+template -+CUTE_HOST_DEVICE constexpr -+T& get(array_subbyte& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return a[I]; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+T const& get(array_subbyte const& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return a[I]; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+T&& get(array_subbyte&& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return std::move(a[I]); -+} -+ -+} // end namespace cute -+ -+namespace std -+{ -+ -+template -+struct tuple_size> -+ : std::integral_constant -+{}; -+ -+template -+struct tuple_element> -+{ -+ using type = T; -+}; -+ -+} // end namespace std -diff --git a/3rdparty/cutlass/include/cute/container/array_view.hpp b/3rdparty/cutlass/include/cute/container/array_view.hpp -new file mode 100644 -index 0000000..51b3ccc ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/container/array_view.hpp -@@ -0,0 +1,274 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
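array_subbyte never hands out a real T&: operator[] and the iterators return proxy reference objects, since several logical elements may share one Storage word. The sketch below uses a byte-sized element type so the packing math is trivial (one element per storage byte); with a genuinely sub-byte element type from the numeric headers, several elements would share each word. The include path is an assumption.

#include <cute/container/array_subbyte.hpp>   // assumed path
#include <cstdint>
#include <cassert>

void demo_array_subbyte() {
  cute::array_subbyte<std::uint8_t, 5> a;
  a.fill(std::uint8_t{3});         // replicates the element's bit pattern across storage words
  a[2] = std::uint8_t{9};          // operator[] yields a reference proxy, not uint8_t&
  std::uint8_t x = a[2];           // the proxy converts back to the element type
  assert(x == 9);
  assert(std::uint8_t(a.front()) == 3);
  a.clear();                       // zeroes the raw storage words
  assert(std::uint8_t(a.back()) == 0);
}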
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+template -+struct array_view -+{ -+ using value_type = T; -+ using size_type = std::size_t; -+ using difference_type = std::ptrdiff_t; -+ using reference = value_type&; -+ using const_reference = const value_type&; -+ using pointer = value_type*; -+ using const_pointer = const value_type*; -+ using iterator = pointer; -+ using const_iterator = const_pointer; -+ -+ array_view(array& a) -+ : __elems_(a.data()) {} -+ -+ CUTE_HOST_DEVICE -+ reference operator[](size_type pos) -+ { -+ return begin()[pos]; -+ } -+ -+ CUTE_HOST_DEVICE -+ const_reference operator[](size_type pos) const -+ { -+ return begin()[pos]; -+ } -+ -+ CUTE_HOST_DEVICE -+ reference front() -+ { -+ return *begin(); -+ } -+ -+ CUTE_HOST_DEVICE -+ const_reference front() const -+ { -+ return *begin(); -+ } -+ -+ CUTE_HOST_DEVICE -+ reference back() -+ { -+ // return *rbegin(); -+ return operator[](N-1); -+ } -+ -+ CUTE_HOST_DEVICE -+ const_reference back() const -+ { -+ // return *rbegin(); -+ return operator[](N-1); -+ } -+ -+ CUTE_HOST_DEVICE -+ T* data() -+ { -+ return __elems_; -+ } -+ -+ CUTE_HOST_DEVICE -+ const T* data() const -+ { -+ return __elems_; -+ } -+ -+ CUTE_HOST_DEVICE -+ iterator begin() -+ { -+ return data(); -+ } -+ -+ CUTE_HOST_DEVICE -+ const_iterator begin() const -+ { -+ return data(); -+ } -+ -+ CUTE_HOST_DEVICE -+ const_iterator cbegin() -+ { -+ return begin(); -+ } -+ -+ CUTE_HOST_DEVICE -+ const_iterator cbegin() const -+ { -+ return begin(); -+ } -+ -+ CUTE_HOST_DEVICE -+ iterator end() -+ { -+ return data() + size(); -+ } -+ -+ CUTE_HOST_DEVICE -+ const_iterator end() const -+ { -+ return data() + size(); -+ } -+ -+ CUTE_HOST_DEVICE -+ const_iterator cend() -+ { -+ return end(); -+ } -+ -+ CUTE_HOST_DEVICE -+ const_iterator cend() const -+ { -+ return end(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ bool empty() const -+ { -+ return size() == 0; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ size_type size() const -+ { -+ return N; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ size_type max_size() const -+ { -+ return size(); -+ } -+ -+ CUTE_HOST_DEVICE -+ void fill(const T& value) -+ { -+ for(auto& e : *this) -+ { -+ e = value; -+ } -+ } -+ -+ CUTE_HOST_DEVICE -+ void swap(array_view& other) -+ { -+ using std::swap; -+ swap(__elems_, other.__elems_); -+ } -+ -+ value_type* __elems_; -+}; -+ -+ -+template -+CUTE_HOST_DEVICE -+bool operator==(const array_view& lhs, const array_view& rhs) -+{ -+ for(std::size_t i = 0; i < N; ++i) -+ { -+ if(lhs[i] != rhs[i]) return false; -+ } -+ -+ return true; -+} -+ -+template -+CUTE_HOST_DEVICE -+void clear(array_view& a) -+{ -+ a.fill(T(0)); -+} -+ -+template -+CUTE_HOST_DEVICE -+void swap(array_view& a, array_view& b) -+{ -+ a.swap(b); -+} -+ -+} // end cute -+ -+ -+// -+// Specialize tuple-related functionality for cute::array_view -+// -+ -+#include 
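array_view is a non-owning wrapper: it stores only the pointer obtained from array::data(), so reads and writes go straight through to the viewed array and swap() exchanges pointers rather than elements. A small sketch; the include paths are assumptions.

#include <cute/container/array.hpp>        // assumed paths
#include <cute/container/array_view.hpp>
#include <cassert>

void demo_array_view() {
  cute::array<int, 4> storage;
  cute::fill(storage, 2);
  cute::array_view<int, 4> v(storage);     // wraps storage.data(), owns nothing
  v[0] = 9;                                // writes through to the underlying array
  assert(storage[0] == 9);
  cute::array_view<int, 4> w(storage);
  assert(v == w);                          // element-wise comparison of the viewed data
  cute::clear(v);                          // fills the viewed elements with 0
  assert(storage[3] == 0);
}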
-+ -+namespace cute -+{ -+ -+template -+CUTE_HOST_DEVICE constexpr -+T& -+get(array_view& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return a[I]; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+const T& -+get(const array_view& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return a[I]; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+T&& -+get(array_view&& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return std::move(a[I]); -+} -+ -+} // end namespace cute -+ -+namespace std -+{ -+ -+template -+struct tuple_size> -+ : std::integral_constant -+{}; -+ -+template -+struct tuple_element> -+{ -+ using type = T; -+}; -+ -+} // end std -diff --git a/3rdparty/cutlass/include/cute/container/bit_field.hpp b/3rdparty/cutlass/include/cute/container/bit_field.hpp -new file mode 100644 -index 0000000..06b0875 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/container/bit_field.hpp -@@ -0,0 +1,131 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Portable bit field that supports byte and word straddling that can -+ be used in unions to bit-wise define parameters. -+*/ -+ -+#pragma once -+ -+#include -+ -+#include // uint_bit_t -+ -+namespace cute -+{ -+ -+class dummy_type {}; -+ -+template -+struct bit_field -+{ -+ static_assert(0 < NumBits && NumBits <= 64, "bit_fields with more than 64 bits are not supported."); -+ -+ // value_type: Use the smallest value type that fits NumBits -+ static constexpr uint32_t value_type_bits = (NumBits <= 8) ? 8 : -+ (NumBits <= 16) ? 16 : -+ (NumBits <= 32) ? 
32 : 64; -+ using value_type = cute::uint_bit_t; -+ // storage_type: Use the smallest storage_type that avoids boundary crossing -+ static constexpr uint32_t storage_type_bits = (BitStart / 8 == (BitStart + NumBits - 1) / 8) ? 8 : -+ (BitStart / 16 == (BitStart + NumBits - 1) / 16) ? 16 : -+ (BitStart / 32 == (BitStart + NumBits - 1) / 32) ? 32 : 64; -+ using storage_type = cute::uint_bit_t; -+ -+ static_assert(sizeof(OtherValueType) == sizeof(value_type) || std::is_same::value, -+ "sizeof(OtherValueType) must be same as sizeof(value_type)."); -+ -+ // Number of storage values needed: ceil_div(BitStart + NumBits, storage_type_bits) -+ static constexpr uint32_t N = (BitStart + NumBits + storage_type_bits - 1) / storage_type_bits; -+ // Index of storage value for BitStart -+ static constexpr uint32_t idx = BitStart / storage_type_bits; -+ // Bit of data_[idx] for BitStart -+ static constexpr uint32_t bit_lo = BitStart % storage_type_bits; -+ // Number of bits in data_[idx] used for NumBits if straddling, else 0 -+ static constexpr uint32_t bit_hi = (idx + 1 < N) ? (storage_type_bits - bit_lo) : 0; -+ -+ // NumBits mask -+ static constexpr value_type mask = (NumBits < 64) ? ((uint64_t(1) << NumBits) - 1) : uint64_t(-1); -+ // NumBits mask for BitStart -+ static constexpr storage_type mask_lo = storage_type(mask) << bit_lo; -+ // NumBits mask for leftover bits in data_[idx+1] if straddling, else 0 -+ static constexpr storage_type mask_hi = (idx + 1 < N) ? (storage_type(mask) >> bit_hi) : 0; -+ -+ storage_type data_[N]; -+ -+ // Get value -+ CUTE_HOST_DEVICE constexpr -+ value_type get() const { -+ storage_type result = (data_[idx] & mask_lo) >> bit_lo; -+ if constexpr (bit_hi) { -+ result |= (data_[idx+1] & mask_hi) << bit_hi; -+ } -+ return static_cast(result); -+ } -+ -+ // Set value -+ CUTE_HOST_DEVICE constexpr -+ void set(value_type x) { -+ storage_type item = static_cast(x & mask); -+ data_[idx] = static_cast((data_[idx] & ~mask_lo) | (item << bit_lo)); -+ if constexpr (bit_hi) { -+ data_[idx+1] = static_cast((data_[idx+1] & ~mask_hi) | (item >> bit_hi)); -+ } -+ } -+ -+ // Assign value -+ CUTE_HOST_DEVICE constexpr -+ bit_field& operator=(value_type x) { -+ set(x); -+ return *this; -+ } -+ -+ // Cast to value -+ CUTE_HOST_DEVICE constexpr -+ operator value_type () const { -+ return get(); -+ } -+ -+ // Assign OtherValueType -+ CUTE_HOST_DEVICE constexpr -+ bit_field& operator=(OtherValueType x) { -+ return *this = *reinterpret_cast(&x); -+ } -+ -+ // Cast to OtherValueType -+ CUTE_HOST_DEVICE constexpr -+ operator OtherValueType () const { -+ value_type x = get(); -+ return *reinterpret_cast(&x); -+ } -+}; -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/container/tuple.hpp b/3rdparty/cutlass/include/cute/container/tuple.hpp -new file mode 100644 -index 0000000..1b3ffa4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/container/tuple.hpp -@@ -0,0 +1,671 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+ -+#include -+#include -+ -+#include // cute::true_type, cute::false_type -+//#include // Advanced optimizations -+ -+#if 0 -+// -+// Use of agency::tuple is functional, but is over-engineered for our purposes... -+// This tends to result in slow compilation times and unintentionally propagated cvref types -+// -+ -+#include -+ -+namespace cute -+{ -+ -+using agency::tuple; -+ -+using agency::make_tuple; -+using agency::tuple_cat; -+ -+} // end namespace cute -+#endif -+ -+// cute::tuple is like std::tuple, with two differences. -+// -+// 1. It works on both host and device. -+// 2. Its template arguments must be semiregular types. -+// -+// Semiregular types are default constructible and copyable. -+// They include "value types" like int or float, -+// but do _not_ include references like int& or float&. -+// (See std::tie for an example of a tuple of references.) -+// -+// This is simplified over the implementation in std:: and agency:: by ignoring much of -+// the conversion SFINAE, special overloading, and avoiding cvref template types. -+// Furthermore, the empty base optimization (EBO) is MORE aggressive by avoiding -+// construction calls, and ignoring any need for unique element addresses. -+// -+// Over the agency::tuple implementation, this appears to accelerate compilation times by over 3x. -+ -+namespace cute -+{ -+ -+namespace detail -+{ -+ -+// EBO stands for "empty base optimization." -+// We use this technique to ensure that cute::tuple -+// doesn't need to waste space storing any template arguments -+// of cute::tuple that have no data (like integral_constant). -+// Otherwise, cute::tuple would need to spend at least 1 byte -+// for each of its template arguments. -+// -+// EBO always "holds" a single value of type T. -+// N is like an array index that TupleBase uses -+// to access the desired tuple element. -+template ::value> -+struct EBO; -+ -+// Specialization for types T that have no data; -+// the "static tuple leaf." Valid T here include -+// integral_constant, Int, -+// and any other semiregular type -+// for which std::is_empty_v is true. 
-+template -+struct EBO -+{ -+ CUTE_HOST_DEVICE constexpr -+ EBO() {} -+ -+ CUTE_HOST_DEVICE constexpr -+ EBO(T const&) {} -+}; -+ -+template -+CUTE_HOST_DEVICE constexpr T getv(EBO const&) -+{ return {}; } -+ -+// Specialization for types T that are not empty; -+// the "dynamic tuple leaf." Valid T here include int, -+// any other integral or floating-point type, -+// or any semiregular type for which std::is_empty_v is false. -+template -+struct EBO -+{ -+ CUTE_HOST_DEVICE constexpr -+ EBO() : t_{} {} -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ EBO(U const& u) : t_{u} {} -+ -+ T t_; -+}; -+ -+template -+CUTE_HOST_DEVICE constexpr T const& getv(EBO const& x) -+{ return x.t_; } -+ -+template -+CUTE_HOST_DEVICE constexpr T& getv(EBO& x) -+{ return x.t_; } -+ -+template -+CUTE_HOST_DEVICE constexpr T&& getv(EBO&& x) -+{ return static_cast(x.t_); } -+ -+template -+struct TupleBase; -+ -+// Base class of cute::tuple. -+// It inherits from EBO for each (i, t) in (I..., T...). -+// The actual storage (for nonempty t) lives in the base classes. -+// index_sequence is a way to wrap up a sequence of zero or more -+// compile-time integer values in a single type. -+// We only ever use index_sequence<0, 1, ..., sizeof...(T)> in practice, -+// as the type alias TupleBase below indicates. -+template -+struct TupleBase, T...> -+ : EBO... -+{ -+ CUTE_HOST_DEVICE constexpr -+ TupleBase() {} -+ -+ template -+ CUTE_HOST_DEVICE constexpr explicit -+ TupleBase(U const&... u) -+ : EBO(u)... {} -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ TupleBase(TupleBase, U...> const& u) -+ : EBO(getv(static_cast const&>(u)))... {} -+}; -+ -+} // end namespace detail -+ -+// make_index_sequence returns index_sequence<0, 1, ..., K-1>. -+template -+using TupleBase = detail::TupleBase, T...>; -+ -+// This is the actual cute::tuple class. -+// The storage (if any) lives in TupleBase's EBO base classes. -+template -+struct tuple : TupleBase -+{ -+ CUTE_HOST_DEVICE constexpr -+ tuple() {} -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ tuple(U const&... u) : TupleBase(u...) {} -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ tuple(tuple const& u) -+ : TupleBase(static_cast const&>(u)) {} -+}; -+ -+// -+// get for cute::tuple (just like std::get for std::tuple) -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+get(tuple const& t) noexcept -+{ -+ static_assert(I < sizeof...(T), "Index out of range"); -+ return detail::getv(t); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+get(tuple& t) noexcept -+{ -+ static_assert(I < sizeof...(T), "Index out of range"); -+ return detail::getv(t); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+get(tuple&& t) noexcept -+{ -+ static_assert(I < sizeof...(T), "Index out of range"); -+ return detail::getv(static_cast&&>(t)); -+} -+ -+// -+// Custom is_tuple trait simply checks the existence of std::tuple_size -+// and assumes std::get(.), std::tuple_element -+// -+namespace detail { -+ -+template -+std::integral_constant::value >= 0> has_tuple_size(int); -+ -+template -+std::false_type has_tuple_size(...); -+ -+} // end namespace detail -+ -+template -+struct is_tuple : decltype(detail::has_tuple_size(0)) {}; -+ -+// -+// make_tuple (value-based implementation) -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+tuple -+make_tuple(T const&... t) -+{ -+ return {t...}; -+} -+ -+// -+// tuple_cat concatenates multiple cute::tuple into a single cute::tuple, -+// just like std::tuple_cat for std::tuple. 
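A sketch of the semantics described in the comments above: static elements are rebuilt from their type via the empty EBO leaf, dynamic elements live in the non-empty leaf, and tuple_cat behaves like std::tuple_cat. The include paths, and the location of cute::Int, are assumptions.

#include <cute/container/tuple.hpp>               // assumed path
#include <cute/numeric/integral_constant.hpp>     // assumed location of cute::Int
#include <type_traits>
#include <tuple>
#include <cassert>

void demo_cute_tuple() {
  auto t = cute::make_tuple(cute::Int<2>{}, 3, 4.0f);   // cute::tuple<Int<2>, int, float>
  assert(cute::get<1>(t) == 3);                         // dynamic leaf, stored in a non-empty EBO base
  static_assert(decltype(cute::get<0>(t))::value == 2,
                "static leaf is rebuilt from its type alone");
  // A tuple of only static values is itself an empty type: EBO stores nothing.
  using AllStatic = cute::tuple<cute::Int<1>, cute::Int<8>>;
  static_assert(std::is_empty<AllStatic>::value, "static-only tuples carry no data");
  // tuple_cat concatenates into one flat cute::tuple, like std::tuple_cat.
  auto u = cute::tuple_cat(t, cute::make_tuple(5));
  static_assert(std::tuple_size<decltype(u)>::value == 4, "");
  assert(cute::get<3>(u) == 5);
}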
-+// -+ -+#if 0 -+// Original implementation -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1, -+ std::index_sequence, std::index_sequence) -+{ -+ return cute::make_tuple(get(t0)..., get(t1)...); -+} -+ -+} // end namespace detail -+ -+CUTE_HOST_DEVICE constexpr -+tuple<> -+tuple_cat() -+{ -+ return {}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+Tuple const& -+tuple_cat(Tuple const& t) -+{ -+ return t; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1) -+{ -+ return detail::tuple_cat(t0, t1, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, Ts const&... ts) -+{ -+ return cute::tuple_cat(cute::tuple_cat(t0,t1),t2,ts...); -+} -+#endif -+ -+#if 1 -+// Extended implementation -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1, -+ std::index_sequence, std::index_sequence) -+{ -+ return cute::make_tuple(get(t0)..., get(t1)...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, -+ std::index_sequence, std::index_sequence, std::index_sequence) -+{ -+ return cute::make_tuple(get(t0)..., get(t1)..., get(t2)...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, -+ std::index_sequence, std::index_sequence, std::index_sequence, std::index_sequence) -+{ -+ return cute::make_tuple(get(t0)..., get(t1)..., get(t2)..., get(t3)...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, -+ std::index_sequence, std::index_sequence, std::index_sequence, std::index_sequence, std::index_sequence) -+{ -+ return cute::make_tuple(get(t0)..., get(t1)..., get(t2)..., get(t3)..., get(t4)...); -+} -+ -+} // end namespace detail -+ -+CUTE_HOST_DEVICE constexpr -+tuple<> -+tuple_cat() -+{ -+ return {}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+Tuple const& -+tuple_cat(Tuple const& t) -+{ -+ return t; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1) -+{ -+ return detail::tuple_cat(t0, t1, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2) -+{ -+ return detail::tuple_cat(t0, t1, t2, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3) -+{ -+ return detail::tuple_cat(t0, t1, t2, t3, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4) -+{ -+ return detail::tuple_cat(t0, t1, t2, t3, t4, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 
const& t3, T4 const& t4, T5 const& t5, Ts const&... ts) -+{ -+ return cute::tuple_cat(cute::tuple_cat(t0,t1,t2,t3,t4), t5, ts...); -+} -+#endif -+ -+#if 0 -+// Outer-Inner indexing trick to concat all tuples at once -+ -+namespace detail { -+ -+template -+struct tuple_cat_helper -+{ -+ static constexpr cute::array ns = {Ns...}; -+ -+ static constexpr std::size_t total_size() { -+ std::size_t sum = 0; -+ for (std::size_t n : ns) sum += n; -+ return sum; -+ } -+ static constexpr std::size_t total_size_ = total_size(); -+ -+ static constexpr auto values() { -+ cute::array outer_inner = {}; -+ -+ std::size_t idx = 0; -+ for (std::size_t i = 0; i < ns.size(); ++i) { -+ for (std::size_t j = 0; j < ns[i]; ++j, ++idx) { -+ outer_inner[idx][0] = i; -+ outer_inner[idx][1] = j; -+ } -+ } -+ return outer_inner; -+ } -+ static constexpr auto outer_inner_ = values(); -+ -+ using total_sequence = std::make_index_sequence; -+}; -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(Tuple const& t, std::index_sequence) -+{ -+ return cute::make_tuple(get(get(t))...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1, -+ std::index_sequence, std::index_sequence) -+{ -+ return cute::make_tuple(get(t0)..., get(t1)...); -+} -+ -+} // end namespace detail -+ -+CUTE_HOST_DEVICE constexpr -+tuple<> -+tuple_cat() -+{ -+ return {}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+Tuple const& -+tuple_cat(Tuple const& t) -+{ -+ return t; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1) -+{ -+ return detail::tuple_cat(t0, t1, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(Tuples const&... ts) -+{ -+ using Helper = detail::tuple_cat_helper::value...>; -+ return detail::tuple_cat(make_tuple(ts...), typename Helper::total_sequence{}); -+} -+#endif -+ -+// -+// Equality operators -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+equal_impl(TupleA const& a, TupleB const& b) -+{ -+ if constexpr (I == std::tuple_size::value) { -+ return cute::true_type{}; // Terminal: TupleA is exhausted -+ } else if constexpr (I == std::tuple_size::value) { -+ return cute::false_type{}; // Terminal: TupleA is not exhausted, TupleB is exhausted -+ } else { -+ return (get(a) == get(b)) && equal_impl(a,b); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+template ::value && is_tuple::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+operator==(TupleT const& t, TupleU const& u) -+{ -+ return detail::equal_impl<0>(t, u); -+} -+ -+template ::value ^ is_tuple::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+operator==(TupleT const& t, TupleU const& u) -+{ -+ return cute::false_type{}; -+} -+ -+template ::value && is_tuple::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+operator!=(TupleT const& t, TupleU const& u) -+{ -+ return !(t == u); -+} -+ -+template ::value ^ is_tuple::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+operator!=(TupleT const& t, TupleU const& u) -+{ -+ return cute::true_type{}; -+} -+ -+// -+// Comparison operators -+// -+ -+// -+// There are many ways to compare tuple of elements and because CuTe is built -+// on parameterizing layouts of coordinates, some comparisons are appropriate -+// only in certain cases. 
-+// -- lexicographical comparison [reverse, reflected, revref] -+// -- colexicographical comparison [reverse, reflected, revref] -+// -- element-wise comparison [any,all] -+// This can be very confusing. To avoid errors in selecting the appropriate -+// comparison, op<|op<=|op>|op>= are *not* implemented for cute::tuple. -+// -+// That said, see int_tuple for more explicitly named common comparison ops. -+// -+ -+// -+// Shortcuts -+// -+ -+//using std::get; -+using std::tuple_size; -+using std::tuple_element; -+using std::tuple_element_t; -+ -+// -+// Display utilities -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE void print_tuple(Tuple const& t, -+ std::index_sequence, char s = '(', char e = ')') -+{ -+ using eat = int[]; -+ using cute::print; -+ (void) eat {(print(s), 0), -+ (print(Is == 0 ? "" : ","), print(get(t)), 0)..., -+ (print(e), 0)}; -+} -+ -+template -+CUTE_HOST std::ostream& print_tuple_os(std::ostream& os, Tuple const& t, -+ std::index_sequence, char s = '(', char e = ')') -+{ -+ using eat = int[]; -+ (void) eat {(void(os << s), 0), -+ (void(os << (Is == 0 ? "" : ",") << get(t)), 0)..., -+ (void(os << e), 0)}; -+ return os; -+} -+ -+} // end namespace detail -+ -+template ::value)> -+CUTE_HOST_DEVICE void print(Tuple const& t) -+{ -+ return detail::print_tuple(t, std::make_index_sequence::value>{}); -+} -+ -+template ::value)> -+CUTE_HOST std::ostream& operator<<(std::ostream& os, Tuple const& t) -+{ -+ return detail::print_tuple_os(os, t, std::make_index_sequence::value>{}); -+} -+ -+} // end namespace cute -+ -+// -+// std:: compatability -+// -+ -+namespace std -+{ -+ -+template -+struct tuple_size> -+ : std::integral_constant -+{}; -+ -+template -+struct tuple_element> -+ : std::tuple_element> -+{}; -+ -+} // end std -diff --git a/3rdparty/cutlass/include/cute/container/type_list.hpp b/3rdparty/cutlass/include/cute/container/type_list.hpp -new file mode 100644 -index 0000000..c082a6d ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/container/type_list.hpp -@@ -0,0 +1,84 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+namespace cute -+{ -+ -+template -+struct type_c { -+ using type = T; -+}; -+ -+template -+struct type_list {}; -+ -+} // end namespace cute -+ -+// -+// Specialize tuple-related functionality for cute::type_list -+// -+ -+#include -+#include -+ -+namespace cute -+{ -+ -+template -+CUTE_HOST_DEVICE constexpr -+std::tuple_element_t> -+get(type_list&) noexcept { -+ return {}; -+} -+template -+CUTE_HOST_DEVICE constexpr -+std::tuple_element_t> -+get(type_list const& t) noexcept { -+ return {}; -+} -+ -+} // end namespace cute -+ -+namespace std -+{ -+ -+template -+struct tuple_size> -+ : std::integral_constant -+{}; -+ -+template -+struct tuple_element> -+ : cute::type_c>::type> -+{}; -+ -+} // end namespace std -diff --git a/3rdparty/cutlass/include/cute/int_tuple.hpp b/3rdparty/cutlass/include/cute/int_tuple.hpp -new file mode 100644 -index 0000000..045e721 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/int_tuple.hpp -@@ -0,0 +1,827 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
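Before the int_tuple utilities below, a short sketch of the tuple equality operators defined earlier: comparisons are element-wise and short-circuiting, and the mixed tuple/non-tuple overload is statically false. Include paths and the location of cute::Int are assumptions.

#include <cute/container/tuple.hpp>               // assumed path
#include <cute/numeric/integral_constant.hpp>     // assumed location of cute::Int
#include <cassert>

void demo_tuple_equality() {
  auto a = cute::make_tuple(1, cute::Int<2>{});
  auto b = cute::make_tuple(1, cute::Int<2>{});
  auto c = cute::make_tuple(7, cute::Int<2>{});
  assert(a == b);    // element-wise comparison via detail::equal_impl
  assert(a != c);    // differs in the dynamic leaf
  // Exactly one operand being a tuple selects the overload that returns a static false.
  static_assert(!decltype(a == 5)::value, "tuple vs. non-tuple compares unequal");
  // op< / op> and friends are intentionally not provided for cute::tuple;
  // int_tuple.hpp adds the explicitly named lex_less / colex_less / elem_less instead.
}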
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+#include -+#include -+ -+namespace cute -+{ -+ -+template -+using IntTuple = cute::tuple; -+ -+// Construct an IntTuple with all value-elements -+template -+CUTE_HOST_DEVICE constexpr -+IntTuple -+make_int_tuple(Ts const&... t) -+{ -+ return {t...}; -+} -+ -+/** if rank(int) == 1, then get<0>(int) should work too -+ */ -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+get(T&& t) noexcept -+{ -+ static_assert(I == 0, "Index out of range"); -+ return static_cast(t); -+} -+ -+/** Custom recursive get for anything that implements get(.) -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+get(Tuple&& t) noexcept -+{ -+ return get(get(static_cast(t))); -+} -+ -+// -+// rank -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+rank(IntTuple const& t) -+{ -+ if constexpr (sizeof...(Is) == 0) { -+ if constexpr (is_tuple::value) { -+ return Int::value>{}; -+ } else { -+ return Int<1>{}; -+ } -+ } else { -+ return rank(get(t)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+using rank_t = decltype(rank(std::declval())); -+ -+template -+static constexpr int rank_v = rank_t::value; -+ -+// -+// shape -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+shape(IntTuple const& s) -+{ -+ if constexpr (is_tuple::value) { -+ return transform(s, [](auto const& a) { return shape(a); }); -+ } else { -+ return s; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+shape(IntTuple const& s) -+{ -+ if constexpr (is_tuple::value) { -+ return shape(get(s)); -+ } else { -+ return get(shape(s)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// max -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+max(T0 const& t0, Ts const&... ts) -+{ -+ if constexpr (is_tuple::value) { -+ return cute::max(cute::apply(t0, [](auto const&... a){ return cute::max(a...); }), ts...); -+ } else if constexpr (sizeof...(Ts) == 0) { -+ return t0; -+ } else { -+ return cute::max(t0, cute::max(ts...)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// min -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+min(T0 const& t0, Ts const&... ts) -+{ -+ if constexpr (is_tuple::value) { -+ return cute::min(cute::apply(t0, [](auto const&... a){ return cute::min(a...); }), ts...); -+ } else if constexpr (sizeof...(Ts) == 0) { -+ return t0; -+ } else { -+ return cute::min(t0, cute::min(ts...)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// depth -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+depth(IntTuple const& t) -+{ -+ if constexpr (sizeof...(Is) == 0) { -+ if constexpr (is_tuple::value) { -+ return Int<1>{} + cute::apply(t, [](auto const&... v){ return cute::max(depth(v)...); }); -+ } else { -+ return Int<0>{}; -+ } -+ } else { -+ return depth(get(t)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+using depth_t = decltype(depth(std::declval())); -+ -+template -+static constexpr int depth_v = depth_t::value; -+ -+// -+// product -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+product(IntTuple const& a) -+{ -+ if constexpr (is_tuple::value) { -+ return cute::apply(a, [](auto const&... v){ return (Int<1>{} * ... * product(v)); }); -+ } else { -+ return a; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// Product of a subrange -+template -+CUTE_HOST_DEVICE constexpr -+auto -+product(Tuple const& a) -+{ -+ return detail::apply(a, [](auto const&... v){ return (Int<1>{} * ... 
* product(v)); }, make_range{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+product_each(Tuple const& t) -+{ -+ return transform(t, [](auto const& x) { return product(x); }); -+} -+ -+// Return the product of elements in a mode -+template -+CUTE_HOST_DEVICE constexpr -+auto -+size(IntTuple const& a) -+{ -+ if constexpr (sizeof...(Is) == 0) { -+ return product(a); -+ } else { -+ return product(get(a)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+static constexpr int size_v = decltype(size(std::declval()))::value; -+ -+// -+// sum -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+sum(IntTuple const& a) -+{ -+ if constexpr (is_tuple::value) { -+ return cute::apply(a, [](auto const&... v){ return (Int<0>{} + ... + sum(v)); }); -+ } else { -+ return a; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// inner_product -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+inner_product(IntTupleA const& a, IntTupleB const& b) -+{ -+ if constexpr (is_tuple::value && is_tuple::value) { -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); -+ return transform_apply(a, b, [](auto const& x, auto const& y) { return inner_product(x,y); }, -+ [](auto const&... v) { return (Int<0>{} + ... + v); }); -+ } else { -+ return a * b; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// ceil_div -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+ceil_div(IntTupleA const& a, IntTupleB const& b) -+{ -+ if constexpr (is_tuple::value && is_tuple::value) { -+ static_assert(tuple_size::value >= tuple_size::value, "Mismatched ranks"); -+ constexpr int R = tuple_size::value; // Missing ranks in TupleB are implictly 1 -+ return transform(a, append(b,Int<1>{}), [](auto const& x, auto const& y) { return ceil_div(x,y); }); -+ } else { -+ return (a + b - Int<1>{}) / b; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+/** Division for Shapes -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+shape_div(IntTupleA const& a, IntTupleB const& b) -+{ -+ if constexpr (is_tuple::value) { -+ if constexpr (is_tuple::value) { // tuple tuple -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); -+ return transform(a, b, [](auto const& x, auto const& y) { return shape_div(x,y); }); -+ } else { // tuple int -+ auto const [result, rest] = fold(a, make_tuple(make_tuple(), b), -+ [] (auto const& init, auto const& ai) { -+ return make_tuple(append(get<0>(init), shape_div(ai, get<1>(init))), shape_div(get<1>(init), ai)); -+ }); -+ return result; -+ } -+ } else { -+ if constexpr (is_tuple::value) { // int tuple -+ return shape_div(a, product(b)); -+ } else { // int int -+ //assert(a % b == 0 || b % a == 0); -+ return a / b != 0 ? 
a / b : signum(a) * signum(b); // divide with rounding away from zero -+ } -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+/** Division for Shapes that are static constants -+ * @pre t % u == 0 || u % t == 0 -+ * @result if t % u == 0, then t / u -+ * if u % t == 0, then signum(t) * signum(u) -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+constant -+shape_div(constant const&, constant const&) -+{ -+ static_assert(t % u == 0 || u % t == 0, "Static shape_div failure"); -+ return {}; -+} -+ -+/** Return a tuple the same profile as A scaled by corresponding elements in B -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+elem_scale(A const& a, B const& b) -+{ -+ if constexpr (is_tuple::value) { -+ return transform(a, b, [](auto const& x, auto const& y) { return elem_scale(x,y); }); -+ } else { -+ return a * product(b); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+/** Test if two IntTuple have the same profile (hierarchical rank division) -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+congruent(IntTupleA const& a, IntTupleB const& b) -+{ -+ return bool_constant::value>{}; -+} -+ -+template -+using is_congruent = decltype(congruent(std::declval(), std::declval())); -+ -+/** Test if Shape B is compatible with Shape A: -+ * Any coordinate into A can also be used as a coordinate into B -+ * A <= B is a partially ordered set of factored shapes -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+compatible(IntTupleA const& a, IntTupleB const& b) -+{ -+ if constexpr (is_tuple::value && is_tuple::value) { -+ if constexpr (tuple_size::value != tuple_size::value) { -+ return false_type{}; -+ } else { -+ return transform_apply(a, b, [](auto const& x, auto const& y) { return compatible(x,y); }, -+ [](auto const&... z) { return (true_type{} && ... && z); }); -+ } -+ } else if constexpr (is_integral::value) { -+ return a == size(b); -+ } else if constexpr (is_integral::value) { -+ return false_type{}; -+ } else { -+ return compatible(shape(a), shape(b)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+using is_compatible = decltype(compatible(std::declval(), std::declval())); -+ -+/** Replace the elements of Tuple B that are paired with an Int<0> with an Int<1> -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+filter_zeros(IntTupleA const& a, IntTupleB const& b) -+{ -+ if constexpr (is_tuple::value) { -+ return transform(a, b, [](auto const& x, auto const& y) { return filter_zeros(x,y); }); -+ } else if constexpr (is_constant<0, IntTupleA>::value) { -+ return Int<1>{}; -+ } else { -+ return b; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+filter_zeros(Tuple const& t) -+{ -+ return filter_zeros(t, t); -+} -+ -+// -+// Converters and constructors with arrays and params -+// -+ -+/** Make an IntTuple of rank N from an Indexable array. -+ * Access elements up to a dynamic index n, then use init (requires compatible types) -+ * Consider cute::take if all indexing is known to be valid -+ * \code -+ * std::vector a = {6,3,4}; -+ * auto tup = make_int_tuple<5>(a, a.size(), 0) // (6,3,4,0,0) -+ * \endcode -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_int_tuple(Indexable const& t, int n, T const& init) -+{ -+ static_assert(N > 0); -+ if constexpr (N == 1) { -+ return 0 < n ? t[0] : init; -+ } else { -+ return transform(make_seq{}, [&](auto i) { return i < n ? 
t[i] : init; }); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+/** Fill the dynamic values of a Tuple with values from another Tuple -+ * \code -+ * auto params = make_int_tuple(6,3,4); -+ * cute::tuple, cute::tuple>, int, Int<2>> result; -+ * fill_int_tuple_from(result, params); // (_1,(6,3,_3),4,_2) -+ * \endcode -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+fill_int_tuple_from(Tuple& result, TupleV const& vals) -+{ -+ return fold(result, vals, [](auto const& init, auto&& r) { -+ if constexpr (is_static>::value) { // Skip static elements of result -+ return init; -+ } else if constexpr (is_tuple>::value) { // Recurse into tuples -+ return fill_int_tuple_from(r, init); -+ } else { // Assign and consume arg -+ static_assert(tuple_size>::value > 0, "Not enough values to fill with!"); -+ r = get<0>(init); -+ return remove<0>(init); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+ }); -+} -+ -+/** Make a "Tuple" by filling in the dynamic values in order from the arguments -+ * \code -+ * using result_t = cute::tuple, cute::tuple>, int, Int<2>>; -+ * auto result = make_int_tuple_from(6,3,4); // (_1,(6,3,_3),4,_2) -+ * \endcode -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+Tuple -+make_int_tuple_from(Ts const&... ts) -+{ -+ Tuple result = Tuple{}; -+ fill_int_tuple_from(result, make_tuple(ts...)); -+ return result; -+} -+ -+/** Convert a tuple to a flat homogeneous array of type T -+ * \code -+ * auto tup = make_tuple(Int<1>{}, make_tuple(6,3,Int<3>{}),4,Int<2>{}); -+ * cute::array result = to_array(tup); // [1,6,3,3,4,2] -+ * \endcode -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+to_array(IntTuple const& t) -+{ -+ auto flat_t = flatten_to_tuple(t); -+ constexpr int N = tuple_size::value; -+ cute::array result; -+ for_each(make_seq{}, [&] (auto i) { result[i] = get(flat_t); }); -+ return result; -+} -+ -+// -+// Comparison operators -+// -+ -+// -+// There are many ways to compare tuple of elements and because CuTe is built -+// on parameterizing layouts of coordinates, some comparisons are appropriate -+// only in certain cases. -+// -- lexicographical comparison [reverse, reflected, revref] : Correct for coords in RowMajor Layout -+// -- colexicographical comparison [reverse, reflected, revref] : Correct for coords in ColMajor Layout -+// -- element-wise comparison [any,all] : -+// This can be very confusing. To avoid errors in selecting the appropriate -+// comparison, op<|op<=|op>|op>= are *not* implemented for cute::tuple. -+// -+// When actually desiring to order coordinates, the user should map them to -+// their indices within the Layout they came from: -+// e.g. layoutX(coordA) < layoutX(coordB) -+// That said, we implement the three most common ways to compare tuples below. -+// These are implemented with slighly more explicit names than op<. 
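As an illustration of the three comparison flavors declared below, here is a minimal sketch. This is an editorial example, not part of the original header; the include path, the function name compare_sketch, and the commented results are assumptions based on the definitions that follow.

#include <cute/int_tuple.hpp>   // assumed include path for this header

inline void compare_sketch() {
  using namespace cute;

  auto a = make_tuple(1, 2);   // coordinate (1,2)
  auto b = make_tuple(2, 1);   // coordinate (2,1)

  // Lexicographical: mode 0 is most significant (matches RowMajor coords).
  bool l = lex_less(a, b);     // expected: true,  since 1 < 2 in mode 0
  // Colexicographical: the last mode is most significant (matches ColMajor coords).
  bool c = colex_less(a, b);   // expected: false, since 2 < 1 fails in the last mode
  // Element-wise [all]: every mode of a must be less than the matching mode of b.
  bool e = elem_less(a, b);    // expected: false, since 2 < 1 fails in mode 1
  (void)l; (void)c; (void)e;
}

To order coordinates consistently with a particular Layout, map them to linear indices first, e.g. layoutX(coordA) < layoutX(coordB), as the comment above recommends.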
-+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+lex_less(IntTupleA const& a, IntTupleB const& b); -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+colex_less(IntTupleA const& a, IntTupleB const& b); -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+elem_less(IntTupleA const& a, IntTupleB const& b); -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+lex_less_impl(TupleA const& a, TupleB const& b) -+{ -+ if constexpr (I == tuple_size::value) { -+ return cute::false_type{}; // Terminal: TupleB is exhausted -+ } else if constexpr (I == tuple_size::value) { -+ return cute::true_type{}; // Terminal: TupleA is exhausted, TupleB is not exhausted -+ } else { -+ return lex_less(get(a), get(b)) || (get(a) == get(b) && lex_less_impl(a,b)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+colex_less_impl(TupleA const& a, TupleB const& b) -+{ -+ if constexpr (I == tuple_size::value) { -+ return cute::false_type{}; // Terminal: TupleB is exhausted -+ } else if constexpr (I == tuple_size::value) { -+ return cute::true_type{}; // Terminal: TupleA is exhausted, TupleB is not exhausted -+ } else { -+ constexpr std::size_t A = tuple_size::value - 1 - I; -+ constexpr std::size_t B = tuple_size::value - 1 - I; -+ return colex_less(get(a), get(b)) || (get(a) == get(b) && colex_less_impl(a,b)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+elem_less_impl(TupleA const& a, TupleB const& b) -+{ -+ if constexpr (I == tuple_size::value) { -+ return cute::true_type{}; // Terminal: TupleA is exhausted -+ } else if constexpr (I == tuple_size::value) { -+ return cute::false_type{}; // Terminal: TupleA is not exhausted, TupleB is exhausted -+ } else { -+ return elem_less(get(a), get(b)) && elem_less_impl(a,b); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+// Lexicographical comparison -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+lex_less(IntTupleA const& a, IntTupleB const& b) -+{ -+ if constexpr (is_tuple::value && is_tuple::value) { -+ return detail::lex_less_impl<0>(a, b); -+ } else { -+ return a < b; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+lex_leq(T const& t, U const& u) { -+ return !lex_less(u, t); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+lex_gtr(T const& t, U const& u) { -+ return lex_less(u, t); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+lex_geq(T const& t, U const& u) { -+ return !lex_less(t, u); -+} -+ -+// Colexicographical comparison -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+colex_less(IntTupleA const& a, IntTupleB const& b) -+{ -+ if constexpr (is_tuple::value && is_tuple::value) { -+ return detail::colex_less_impl<0>(a, b); -+ } else { -+ return a < b; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+colex_leq(T const& t, U const& u) { -+ return !colex_less(u, t); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+colex_gtr(T const& t, U const& u) { -+ return colex_less(u, t); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+colex_geq(T const& t, U const& u) { -+ return !colex_less(t, u); -+} -+ -+// Elementwise [all] comparison -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+elem_less(IntTupleA const& a, IntTupleB const& b) -+{ -+ if constexpr (is_tuple::value && is_tuple::value) { -+ return detail::elem_less_impl<0>(a, b); -+ } else { -+ return a < b; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template 
-+CUTE_HOST_DEVICE constexpr -+auto -+elem_leq(T const& t, U const& u) { -+ return !elem_less(u, t); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+elem_gtr(T const& t, U const& u) { -+ return elem_less(u, t); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+elem_geq(T const& t, U const& u) { -+ return !elem_less(t, u); -+} -+ -+/** Increment a (dynamic) coord lexicographically within a shape -+ * \code -+ * auto shape = make_shape(1,2,make_shape(2,3),3); -+ * -+ * int i = 0; -+ * for (auto coord = repeat_like(shape, 0); back(coord) != back(shape); increment(coord, shape)) { -+ * std::cout << i++ << ": " << coord << std::endl; -+ * } -+ * assert(i == size(shape)); -+ * \endcode -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+void -+increment(Coord& coord, Shape const& shape); -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+void -+increment(Coord& coord, Shape const& shape, seq) -+{ -+ cute::increment(get(coord), get(shape)); -+ if constexpr (sizeof...(Is) != 0) { -+ if (back(get(coord)) == back(get(shape))) { -+ back(get(coord)) = 0; -+ increment(coord, shape, seq{}); -+ } -+ } -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+void -+increment(Coord& coord, Shape const& shape) -+{ -+ if constexpr (is_integral::value && is_integral::value) { -+ ++coord; -+ } else if constexpr (is_tuple::value && is_tuple::value) { -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); -+ detail::increment(coord, shape, tuple_seq{}); -+ } else { -+ static_assert(sizeof(Coord) == 0, "Invalid parameters"); -+ } -+} -+ -+struct ForwardCoordIteratorSentinal -+{}; -+ -+// A forward iterator for a coordinate that starts from zero and goes to shape -+template -+struct ForwardCoordIterator -+{ -+ static_assert(is_congruent::value); -+ -+ CUTE_HOST_DEVICE constexpr -+ Coord const& operator*() const { return coord; } -+ -+ CUTE_HOST_DEVICE constexpr -+ ForwardCoordIterator& operator++() { increment(coord, shape); return *this; } -+ -+ // Sentinal for the end of the implied range -+ CUTE_HOST_DEVICE constexpr -+ bool operator< (ForwardCoordIteratorSentinal const&) const { return back(coord) < back(shape); } -+ CUTE_HOST_DEVICE constexpr -+ bool operator==(ForwardCoordIteratorSentinal const&) const { return back(coord) == back(shape); } -+ CUTE_HOST_DEVICE constexpr -+ bool operator!=(ForwardCoordIteratorSentinal const&) const { return back(coord) != back(shape); } -+ // NOTE: These are expensive, avoid use -+ CUTE_HOST_DEVICE constexpr -+ bool operator< (ForwardCoordIterator const& other) const { return colex_less(coord, other.coord); } -+ CUTE_HOST_DEVICE constexpr -+ bool operator==(ForwardCoordIterator const& other) const { return coord == other.coord; } -+ CUTE_HOST_DEVICE constexpr -+ bool operator!=(ForwardCoordIterator const& other) const { return coord != other.coord; } -+ -+ Coord coord; -+ Shape const& shape; -+}; -+ -+// A forward iterator for a coordinate that starts from zero -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_coord_iterator(Shape const& shape) -+{ -+ auto coord = repeat_like(shape, int(0)); -+ return ForwardCoordIterator{coord,shape}; -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/layout.hpp b/3rdparty/cutlass/include/cute/layout.hpp -new file mode 100644 -index 0000000..fe937ee ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/layout.hpp -@@ -0,0 +1,1638 @@ -+/*************************************************************************************************** -+ * 
Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+#include -+#include -+ -+namespace cute -+{ -+ -+// Aliases -+ -+template -+using Shape = IntTuple; -+ -+template -+using Stride = IntTuple; -+ -+template -+using Step = IntTuple; -+ -+template -+using Coord = IntTuple; -+ -+template -+CUTE_HOST_DEVICE constexpr -+Shape -+make_shape(Ts const&... t) { -+ return {t...}; -+} -+template -+CUTE_HOST_DEVICE constexpr -+Stride -+make_stride(Ts const&... t) { -+ return {t...}; -+} -+template -+CUTE_HOST_DEVICE constexpr -+Step -+make_step(Ts const&... t) { -+ return {t...}; -+} -+template -+CUTE_HOST_DEVICE constexpr -+Coord -+make_coord(Ts const&... t) { -+ return {t...}; -+} -+ -+ -+template > -+struct Layout -+ : private cute::tuple // EBO for static layouts -+{ -+ // Avoid bad CTAD: -+ // Layout smem = GMMA::Layout_MN_SW128_Atom; -+ // Should fail because smem is a ComposedLayout (SwizzleLayout) and not a Layout -+ static_assert(is_integral::value || is_tuple::value); -+ -+ // Expensive in compilation time... 
-+ //static_assert(is_congruent::value, -+ // "Shape and Stride must have the same hierarchical structure"); -+ //static_assert(is_integral::value || is_tuple::value); -+ -+ // NOTE: This defaults static Shapes/Strides correctly, but not dynamic -+ CUTE_HOST_DEVICE constexpr -+ Layout(LogicalShape const& logical_shape = {}, -+ LogicalStride const& logical_stride = {}) -+ : cute::tuple(logical_shape, logical_stride) -+ {} -+ -+ // -+ // Accessors -+ // -+ -+ static constexpr int rank = rank_v ; -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ layout() { -+ return *this; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ layout() const { -+ return *this; -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ shape() { -+ return get<0,I...>(static_cast&>(*this)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ shape() const { -+ return get<0,I...>(static_cast const&>(*this)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ stride() { -+ return get<1,I...>(static_cast&>(*this)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ stride() const { -+ return get<1,I...>(static_cast const&>(*this)); -+ } -+ -+ // -+ // Mappings -+ // -+ -+ // Map a logical coordinate to a linear index (Coord has no Underscore slice operators) -+ // OR -+ // Slice the layout and return the sublayout (Coord has an Underscore slice op) -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ operator()(Coord const& coord) const { -+ if constexpr (has_underscore::value) { -+ return slice(coord, *this); -+ } else { -+ return crd2idx(coord, shape(), stride()); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+ } -+ -+ // Convenience function for multi-dimensional coordinates -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { -+ return operator()(make_coord(c0,c1,cs...)); -+ } -+ -+ // Map a linear index to a hier ND logical coordinate -+ // NOTE: Dangerous and error-prone -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ operator[](Int const& linear_idx) const { -+ static_assert(is_integral::value); -+ return get_hier_coord(linear_idx); -+ } -+ -+ // -+ // Compose -+ // -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ compose(OtherLayout const& other) const { -+ return composition(*this, other); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ compose(Layouts const&... layouts) const { -+ return composition(*this, make_tile(layouts...)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ with_shape(OtherShape const& shape) const { -+ return composition(*this, make_layout(shape)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ with_shape(Shapes const&... shapes) const { -+ return composition(*this, make_layout(make_shape(shapes...))); -+ } -+ -+ // -+ // Tile -+ // -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ tile(OtherLayout const& other) const { -+ return tiled_divide(*this, other); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ tile(Layouts const&... 
layouts) const { -+ return tiled_divide(*this, make_tile(layouts...)); -+ } -+ -+ // -+ // Utility -+ // -+ -+ // -+ // Index to Coordinate -+ // -+ -+ // NOTE: Only valid for compact layouts -+ -+ // Return the (hierarchical) ND logical coordinate corresponding to the linear index -+ // @post crd2idx(@a result, shape(), stride()) == idx -+ // @post congruent(@a result, shape()) -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_hier_coord(IInt const& idx) const { -+ return cute::idx2crd(idx, shape(), stride()); -+ } -+ -+ // Return the (flat) ND logical coordinate corresponding to the linear index -+ // @post crd2idx(@a result, shape(), stride()) == idx -+ // @post rank(@a result) == rank(shape()) && depth(@a result) == 1 -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_flat_coord(IInt const& idx) const { -+ return cute::crd2crd(this->get_hier_coord(idx), shape(), repeat(Int<1>{})); -+ } -+ -+ // Return the generalized column-major 1D logical coordinate corresponding to the linear index -+ // @post crd2idx(@a result, shape(), stride()) == idx -+ // @post is_integral::value -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_1d_coord(IInt const& idx) const { -+ return cute::crd2idx(this->get_hier_coord(idx), shape()); -+ } -+ -+ // -+ // Coordinate to Coordinate -+ // -+ -+#if 0 -+ // Return the (hierarchical) ND logical coordinate corresponding to the linear index -+ // @post congruent(@a result, shape()) -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ crd_2_hier_coord(Coord const& crd) const { -+ return cute::crd2crd(crd, shape(), shape()); -+ } -+ -+ // Return the (flat) ND logical coordinate corresponding to the linear index -+ // @post rank(@a result) == rank(shape()) && depth(@a result) == 1 -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ crd_2_flat_coord(Coord const& crd) const { -+ return cute::crd2crd(crd, shape(), product_each(shape())); -+ } -+ -+ // Return the generalized column-major 1D logical coordinate corresponding to the linear index -+ // @post is_integral::value -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ crd_2_1d_coord(Coord const& crd) const { -+ //return cute::crd2crd(crd, shape(), product(shape())); -+ return cute::crd2idx(crd, shape()); -+ } -+#endif -+}; -+ -+ -+template -+struct is_layout : false_type {}; -+template -+struct is_layout> : true_type {}; -+ -+ -+template ::value || is_integral::value) && -+ (is_tuple::value || is_integral::value))> -+CUTE_HOST_DEVICE constexpr -+auto -+make_layout(Shape const& shape, Stride const& stride) -+{ -+ return Layout(shape, stride); -+} -+ -+template ::value || is_integral::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+make_layout(Shape const& shape) -+{ -+ return make_layout(shape, compact_col_major(shape)); -+} -+ -+// Construct a layout from multiple layouts by -+// concatenating each layout as an independent mode -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_layout(Layout const&... 
layouts) -+{ -+ return make_layout(make_shape (layouts.shape()...), -+ make_stride(layouts.stride()...)); -+} -+ -+// -+// Convenience tags for common layouts -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_layout(Shape const& shape, GenColMajor) -+{ -+ return make_layout(shape, compact_col_major(shape)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_layout(Shape const& shape, GenRowMajor) -+{ -+ return make_layout(shape, compact_row_major(shape)); -+} -+ -+// Follow the same ordering induced by the strides, but make the layout compact -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_ordered_layout(Shape const& shape, Order const& order) -+{ -+ static_assert(is_static::value && is_static::value); -+ return make_layout(shape, compact_order(shape, order)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_ordered_layout(Layout const& layout) -+{ -+ return make_ordered_layout(layout.shape(), layout.stride()); -+} -+ -+// Make a layout of the same shape that is either ordered or colmajor depending on staticness -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_layout_like(Layout const& layout) -+{ -+ if constexpr (is_static::value && is_static::value) { -+ return make_ordered_layout(layout.shape(), layout.stride()); -+ } else { -+ return make_layout(layout.shape()); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// Make a layout of the same shape, -+// with mode-0 being colmajor then following the the mode order in layout -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_fragment_like(Layout const& layout) -+{ -+ auto shape = replace<0>(layout.shape(), size<0>(layout)); -+ auto order = replace<0>(layout.stride(), Int<0>{}); -+ if constexpr (is_static::value && is_static::value) { -+ return make_ordered_layout(shape, order); -+ } else { -+ return make_layout(layout.shape()); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_identity_layout(Shape const& shape) -+{ -+ return make_layout(shape, make_basis_like(shape)); -+} -+ -+// -+// Operations to manipulate Layouts like a tuple of pairs -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+get(Layout const& layout) -+{ -+ // Let the static_asserts in get(shape|stride) catch problems -+ return make_layout(get(layout.shape()), get(layout.stride())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+take(Layout const& layout) -+{ -+ // Let the static_asserts in take(shape|stride) catch problems -+ return make_layout(take(layout.shape()), take(layout.stride())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+flatten(Layout const& layout) -+{ -+ return make_layout(flatten(layout.shape()), flatten(layout.stride())); -+} -+ -+// -+// Utilities -+// -+ -+// Return the layout of a mode -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+layout(Layout const& layout) -+{ -+ if constexpr (sizeof...(Is) == 0) { -+ return layout; -+ } else { -+ return get(layout); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// Return the shape of a mode -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+shape(Layout& layout) -+{ -+ return layout.template shape(); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+shape(Layout const& layout) -+{ -+ return layout.template shape(); -+} -+ -+// Return the stride of a mode -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+stride(Layout& layout) -+{ -+ return layout.template stride(); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+stride(Layout const& layout) -+{ -+ 
return layout.template stride(); -+} -+ -+// Return the number of elements in a mode -+template -+CUTE_HOST_DEVICE constexpr -+auto -+size(Layout const& layout) -+{ -+ return size(shape(layout)); -+} -+ -+// Return the number of modes -+template -+CUTE_HOST_DEVICE constexpr -+auto -+rank(Layout const& layout) -+{ -+ return rank(shape(layout)); -+} -+ -+// Return the depth of the layout -+template -+CUTE_HOST_DEVICE constexpr -+auto -+depth(Layout const& layout) -+{ -+ return depth(shape(layout)); -+} -+ -+// Return the codomain size of a mode -+// @return M smallest integer such that @a sub_layout(c) < M for all c < size(@a sub_layout) -+// where sub_layout = get(layout). -+template -+CUTE_HOST_DEVICE constexpr -+auto -+cosize(Layout const& layout) -+{ -+ // Protect against negative strides -+ auto abs_sub_layout = make_layout(shape(layout), -+ transform_leaf(stride(layout), abs_fn{})); -+ return abs_sub_layout(size(abs_sub_layout) - Int<1>{}) + Int<1>{}; -+} -+ -+template -+using cosize_t = decltype(cosize(std::declval())); -+ -+template -+static constexpr int cosize_v = cosize_t::value; -+ -+// Equality -+// Return a static or dynamic boolean -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator==(Layout const& layoutA, Layout const& layoutB) -+{ -+ return layoutA.shape() == layoutB.shape() && layoutA.stride() == layoutB.stride(); -+} -+ -+// With crd2idx(coord, shape), makes sense to have crd2idx(coord, Layout) as well -+template -+CUTE_HOST_DEVICE constexpr -+auto -+crd2idx(Coord const& c, Layout const& layout) -+{ -+ return crd2idx(c, layout.shape(), layout.stride()); -+} -+ -+// -+// Slice and Dice a layout -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+slice(Coord const& c, Layout const& layout) -+{ -+ return make_layout(slice(c, layout.shape()), -+ slice(c, layout.stride())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+slice_and_offset(Coord const& c, Layout const& layout) -+{ -+ return cute::make_tuple(slice(c, layout), crd2idx(c, layout)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+dice(Coord const& c, Layout const& layout) -+{ -+ return make_layout(dice(c, layout.shape()), -+ dice(c, layout.stride())); -+} -+ -+// -+// Transform the modes of a layout -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transform_layout(Tuple const& t, F&& f, seq) -+{ -+ return make_layout(f(get(t))...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transform_layout(Tuple0 const& t0, Tuple1 const& t1, F&& f, seq, seq, seq) -+{ -+ return make_layout(f(get(t0),get(t1))..., get(t0)..., get(t1)...); -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transform_layout(Tuple const& t, F&& f) -+{ -+ return detail::transform_layout(t, f, make_seq{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transform_layout(Tuple0 const& t0, Tuple1 const& t1, F&& f) -+{ -+ constexpr int R0 = decltype(rank(t0))::value; -+ constexpr int R1 = decltype(rank(t1))::value; -+ constexpr int R = (R0 < R1) ? 
R0 : R1; -+ return detail::transform_layout(t0, t1, f, make_seq{}, make_range{}, make_range{}); -+} -+ -+// -+// Coalesce and Filter -+// -+ -+namespace detail { -+ -+// Look at each element and the front of the stack (in order of priority) -+// front(NewLayout) get(Layout) -+// s0:d0 _1:d1 => continue -+// _1:d0 s1:d1 => replace_front s1:d1 -+// s0:s1*d1 s1:d1 => replace_front s0*s1:d1 -+// s0:d0 s1:d1 => prepend s1:d1 -+// -+// @pre OldShape and OldStride are flat -+template -+CUTE_HOST_DEVICE constexpr -+auto -+bw_coalesce(OldShape const& old_shape, OldStride const& old_stride, -+ NewShape const& new_shape, NewStride const& new_stride) -+{ -+ if constexpr (I == -1) { -+ // Base case, we're done -+ if constexpr (is_constant<1, NewShape>::value) { -+ return Layout<_1,_0>{}; -+ } else { -+ return Layout{new_shape,new_stride}; -+ } -+ } else if constexpr (is_constant<1, decltype(get(old_shape))>::value) { -+ // shape(layout) == _1, skip it and continue -+ return bw_coalesce(old_shape, old_stride, new_shape, new_stride); -+ } else if constexpr (is_constant<1, NewShape>::value) { -+ // Replace our shape-1 with anything (Can only happen on input new_shape/new_stride) -+ return bw_coalesce(old_shape, old_stride, get(old_shape), get(old_stride)); -+ } else if constexpr (is_constant(old_shape) * get(old_stride) == get<0>(new_stride))>::value) { -+ // Merge modes because the shapes and strides match -+ return bw_coalesce(old_shape, old_stride, -+ replace_front(new_shape, get(old_shape) * get<0>(new_shape)), -+ replace_front(new_stride, get(old_stride))); -+ } else { -+ // Can't replace or merge, so prepend a new mode -+ return bw_coalesce(old_shape, old_stride, -+ prepend(new_shape, get(old_shape)), -+ prepend(new_stride, get(old_stride))); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+// Combine all the modes that are possible to combine -+// Does not respect the profile of the layout, but does preserve total size -+template -+CUTE_HOST_DEVICE constexpr -+auto -+coalesce(Layout const& layout) -+{ -+ auto flat_shape = flatten(layout.shape()); -+ auto flat_stride = flatten(layout.stride()); -+ -+ constexpr int R = decltype(rank(flat_shape))::value; -+ return detail::bw_coalesce(flat_shape, flat_stride, get(flat_shape), get(flat_stride)); -+} -+ -+// Apply coalesce at the terminals of trg_profile -+template -+CUTE_HOST_DEVICE constexpr -+auto -+coalesce(Layout const& layout, IntTuple const& trg_profile) -+{ -+ if constexpr (is_tuple::value) { -+ static_assert(tuple_size::value <= Layout::rank); -+ return transform_layout(layout, trg_profile, [](auto const& l, auto const& t) { return coalesce(l,t); }); -+ } else { -+ return coalesce(layout); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// Replace the modes in layout that have a 0-stride with a 1-size -+template -+CUTE_HOST_DEVICE constexpr -+auto -+filter_zeros(Layout const& layout) -+{ -+ return make_layout(filter_zeros(layout.stride(), layout.shape()), layout.stride()); -+} -+ -+// Remove all of the 0-strides and 1-sizes -+// Return 1-shape if empty -+template -+CUTE_HOST_DEVICE constexpr -+auto -+filter(Layout const& layout) -+{ -+ return coalesce(filter_zeros(layout)); -+} -+ -+// Apply filter at the terminals of trg_profile -+template -+CUTE_HOST_DEVICE constexpr -+auto -+filter(Layout const& layout, IntTuple const& trg_profile) -+{ -+ if constexpr (is_tuple::value) { -+ static_assert(tuple_size::value <= Layout::rank); -+ return transform_layout(layout, trg_profile, [](auto const& l, auto const& t) { return filter(l,t); 
}); -+ } else { -+ return filter(layout); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Append, Prepend, Replace -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+append(Layout const& layout, -+ Layout const& x = {}) -+{ -+ return make_layout(append(layout.shape(), x.shape()), -+ append(layout.stride(), x.stride())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+prepend(Layout const& layout, -+ Layout const& x = {}) -+{ -+ return make_layout(prepend(layout.shape(), x.shape()), -+ prepend(layout.stride(), x.stride())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+replace(Layout const& layout, -+ Layout const& x) -+{ -+ return make_layout(replace(layout.shape(), x.shape()), -+ replace(layout.stride(), x.stride())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+group(Layout const& layout) -+{ -+ return make_layout(group(layout.shape()), -+ group(layout.stride())); -+} -+ -+// -+// Composition of two layouts: lhs o rhs -+// @post compatible(rhs, result) -+// @post result(c) = lhs(rhs(c)) -+// for all c in the domain of result -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(Layout const& lhs, -+ RShape const& rhs_shape, RStride const& rhs_stride) -+{ -+ if constexpr (is_tuple::value) { -+ // Apply the right-distributivity of Layout composition -+ return transform_layout(rhs_shape, rhs_stride, [&](auto const& s, auto const& d) { return composition(lhs, s, d); }); -+ } else -+ if constexpr (is_scaled_basis::value) { -+ // Special case for a ScaledBasis stride -+ return composition(get(lhs), rhs_shape, rhs_stride.value()); -+ } else -+ if constexpr (is_integral::value) { -+ // Integral Rstride (and RShape) -+ -+ // NOTE: Should only flatten once for efficiency -+ auto flat_shape = flatten(lhs.shape()); -+ auto flat_stride = flatten(lhs.stride()); -+ [[maybe_unused]] constexpr int R = rank(flat_shape); -+ -+ if constexpr (is_constant<0, RStride>::value) { -+ // Special case shortcut for any static stride-0 -+ return Layout{rhs_shape, rhs_stride}; -+ } else -+ if constexpr (is_integral::value) { -+ // Special case shortcut for any integral LShape -+ auto result_stride = rhs_stride * flat_stride; -+ return Layout{rhs_shape, result_stride}; -+ } else -+ if constexpr (is_constant<1, RStride>::value) { -+ // Special case shortcut for any static stride-1 -+ auto result_shape_0 = take<0,R-1>(flat_shape); -+ -+ // Mod out the rhs_shape from the lhs.shape() -+ auto const [result_shape_1, rest_shape] = fold(result_shape_0, make_tuple(make_tuple(), rhs_shape), -+ [] (auto const& init, auto const& si) { -+ return make_tuple(append(get<0>(init), cute::min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); -+ }); -+ -+ // Jump into coalesce and append (rest_shape, get(lhs.stride()) -+ return detail::bw_coalesce(result_shape_1, flat_stride, rest_shape, get(flat_stride)); -+ } else -+ { -+ // General case -+ auto result_shape_0 = take<0,R-1>(flat_shape); -+ auto result_stride_0 = take<0,R-1>(flat_stride); -+ -+ // Divide out the rhs_stride from the lhs.shape() -+ auto const [result_shape_1, rest_stride] = fold(result_shape_0, make_tuple(make_tuple(), rhs_stride), -+ [] (auto const& init, auto const& di) { -+ return make_tuple(append(get<0>(init), shape_div(di, get<1>(init))), shape_div(get<1>(init), di)); -+ }); -+ -+ // Apply any lhs.shape() changes to the stride -+ auto result_stride_1 = elem_scale(result_stride_0, shape_div(result_shape_0, result_shape_1)); -+ -+ // Mod out the rhs_shape from the lhs.shape() -+ auto 
const [result_shape_2, rest_shape] = fold(result_shape_1, make_tuple(make_tuple(), rhs_shape), -+ [] (auto const& init, auto const& si) { -+ return make_tuple(append(get<0>(init), cute::min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); -+ }); -+ -+ // Jump into coalesce and append (rest_shape, rest_stride * get(lhs.stride()) -+ return detail::bw_coalesce(result_shape_2, result_stride_1, rest_shape, rest_stride * get(flat_stride)); -+ } -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(Layout const& lhs, -+ Layout const& rhs) -+{ -+ //return detail::composition(flatten(lhs), rhs.shape(), rhs.stride()); -+ return detail::composition(lhs, rhs.shape(), rhs.stride()); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(Layout const& lhs, -+ IntTuple const& rhs) -+{ -+ if constexpr (is_tuple::value) { -+ static_assert(tuple_size::value <= Layout::rank); -+ // Drop any modes of lhs that aren't hit by rhs -+ return detail::transform_layout(lhs, rhs, [](auto const& l, auto const& r) { return composition(l,r); }, make_seq::value>{}, seq<>{}, seq<>{}); -+ } else if constexpr (is_underscore::value) { -+ return lhs; -+ } else { -+ return composition(lhs, make_layout(rhs)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Complement -+// -+// Build the complement of a layout. -+// @post size(@a result) >= @a cosize_hi / size(filter(@a layout))); -+// @post For all i in [1,size(@a result)), -+// @a result(i) < @a result(i-1) -+// For all j in [0, size(@a layout)), -+// @a result(i) != @a layout(j) -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+complement(Layout const& layout, CoSizeHi const& cosize_hi) -+{ -+ // Remove the stride-0 modes, the size-1 modes, and flatten the layout -+ auto flat_layout = filter(layout); -+ -+ if constexpr (is_constant<0, decltype(flat_layout.stride())>::value) { -+ // Special case for stride-0 layout -+ return make_layout(cosize_hi); -+ } else { -+ // General case -+ constexpr int R = decltype(rank(flat_layout))::value; -+ static_assert(R == 1 || is_static::value, -+ "Dynamic-stride complement only for rank-1 layouts"); -+ -+ // Should just be a sort and a fold... 
-+ // Then we could even handle dynamic strides (but they would destroy all static strides) -+ auto result = fold(make_seq{}, -+ make_tuple(flat_layout.shape(), -+ flat_layout.stride(), -+ make_tuple(), -+ make_tuple(Int<1>{})), -+ [](auto const& init, auto i) -+ { -+ auto curr_stride = cute::min(get<1>(init)); -+ auto curr_idx = find(get<1>(init), curr_stride); -+ auto curr_shape = get(get<0>(init)); -+ -+ return make_tuple(remove(get<0>(init)), // Remove the curr shape -+ remove(get<1>(init)), // Remove the curr stride -+ append(get<2>(init), curr_stride / get<3,i>(init)), // new shape = curr_stride / last_stride -+ append(get<3>(init), curr_shape * curr_stride)); // new stride = curr_shape * curr_stride -+ }); -+ -+ // Append the last shape mode -+ auto result_stride = get<3>(result); -+ auto result_shape = append(get<2>(result), get<1,0>(result) / back(result_stride)); // new shape = curr_stride / last_stride -+ -+ // Compute the rest_stride -+ auto rest_stride = get<0,0>(result) * get<1,0>(result); -+ //return make_layout(append(result_shape, ceil_div(cosize_hi, rest_stride)), append(result_stride, rest_stride)); -+ // Jump into coalesce and append (ceil_div(cosize_hi, rest_stride), rest_stride) -+ return detail::bw_coalesce(result_shape, result_stride, ceil_div(cosize_hi, rest_stride), rest_stride); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+complement(Layout const& layout) -+{ -+ return complement(layout, cosize(layout)); -+} -+ -+// -+// Right-Inverse and Left-Inverse -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+inverse_seq(Shape const& shape, Stride const& stride, seq) -+{ -+ if constexpr (I == decltype(rank(stride))::value) { -+ return seq{}; -+ } else { -+ //auto next_stride = get(shape) * get(stride); -+ using next_stride = decltype(get(shape) * get(stride)); // NOTE: WAR for g++-7 -+ -+ if constexpr (is_static::value) { -+ auto next_idx = find_if(stride, [](auto a) { return is_constant{}; }); -+ return inverse_seq(shape, stride, seq{}); -+ } else { -+ return seq{}; -+ } -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+// -+// Build the right-inverse of a layout -+// @pre is_static -+// @result A layout @a result such that -+// @a layout(@a result(i)) == i for all i < size(@a result) -+// @result A layout @a result such that -+// composition(@a layout, @a result) is identical to make_layout(shape(result)) -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+right_inverse(Layout const& layout) -+{ -+ auto flat_layout = coalesce(layout); -+ auto astride = transform_leaf(flat_layout.stride(), abs_fn{}); -+ -+ // Find Int<1>{}, the starting idx, and follow the strides to gen inverse_seq -+ auto next_I = find_if(astride, [](auto a) { return is_constant<1, decltype(a)>{}; }); -+ [[maybe_unused]] auto iseq = detail::inverse_seq(flat_layout.shape(), astride, seq<>{}); -+ -+ if constexpr (tuple_size::value == 0) { -+ return Layout<_1,_0>{}; // Empty case, nothing found -+ } else { -+ // Generate the corresponding new strides and construct -+ auto rstride = compact_col_major(flat_layout.shape()); -+ return make_layout(unwrap(transform(iseq, [&](auto i) { return shape(flat_layout); })), -+ unwrap(transform(iseq, [&](auto i) { return signum(stride(flat_layout)) * get(rstride); }))); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+CUTE_HOST_DEVICE constexpr -+auto -+right_inverse(Underscore const& _) -+{ -+ return _; -+} -+ -+// -+// Build the left-inverse of a layout -+// @pre is_static 
-+// @pre not has_int0 // @a layout has no 0-strides (is injective) -+// @result A layout @a result such that -+// @a result(@a layout(i)) == i for all i < size(@a layout) -+// @result A layout @a result such that -+// composition(@a result, @a layout) is identical to make_layout(shape(layout)) -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+left_inverse(Layout const& layout) -+{ -+ return right_inverse(make_layout(layout, complement(layout))); -+} -+ -+CUTE_HOST_DEVICE constexpr -+auto -+left_inverse(Underscore const& _) -+{ -+ return _; -+} -+ -+// -+// Max Common Vector -+// -+ -+/* Return Int such that N is the maximum number of continguous elements -+ * that logically correspond in the layouts of @a a and @a b. This is, -+ * the number of elements that could reasonably be "vectorized" in the layouts. -+ * -+ * @returns Int with N >= 1 -+ * @post For all 0 <= n < N, a(b[n]) == n (NOTE: Problems with negative strides/coords in this post-condition) -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+max_common_vector(Layout const& a, Layout const& b) -+{ -+ if constexpr (is_static>::value && -+ is_static>::value) -+ { -+ auto result = coalesce(composition(a, right_inverse(b))); -+ -+ if constexpr (is_constant<1, decltype(stride<0>(result))>::value) { -+ return shape<0>(result); -+ } else { -+ return Int<1>{}; -+ } -+ } else { -+ // Dynamic case NOTE: could weaken if we assume dynamic strides are large and multiples of the vector -+ return Int<1>{}; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Zip -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zip(Layout const& layout) -+{ -+ return make_layout(zip(layout.shape()), -+ zip(layout.stride())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zip(Layout const& layoutA, -+ Layout const& layoutB) -+{ -+ return make_layout(zip(layoutA.shape(), layoutB.shape()), -+ zip(layoutA.stride(), layoutB.stride())); -+} -+ -+// -+// Tile unzip -+// Logical product and logical divide (on layouts) produce rank-2 results by design. -+// Follow the profile of @a tile and zip the rank-2 modes located at the terminals into -+// their own mode. -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tile_unzip(Layout const& layout, -+ IntTuple const& tile) -+{ -+ return make_layout(zip2_by(layout.shape(), tile), -+ zip2_by(layout.stride(), tile)); -+} -+ -+// -+// Logical divide -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+logical_divide(Layout const& layout, -+ Layout const& tile) -+{ -+ //CUTE_STATIC_ASSERT_V(size(layout) % size(tile) == Int<0>{}, -+ // "Tiling does not evenly divide the block"); -+ // NOTE: With tiles that have stride-0, this doesn't have to be true -+ -+ return composition(layout, make_layout(tile, complement(tile, size(layout)))); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+logical_divide(Layout const& layout, -+ IntTuple const& tile) -+{ -+ if constexpr (is_tuple::value) { -+ static_assert(tuple_size::value <= Layout::rank, "logical_divide: Too many modes in tile."); -+ return transform_layout(layout, tile, [](auto const& l, auto const& t) { return logical_divide(l,t); }); -+ } else if constexpr (is_underscore::value) { -+ return layout; -+ } else if constexpr (is_integral::value) { -+ return logical_divide(layout, make_layout(tile)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Convenience operator -+// that produces layouts like ((BLK_A,BLK_B,...),(a,b,...,x,y)) -+// by gathering the tile modes and residuals into a rank-2 result. 
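As a concrete illustration of the rank-2 gathering described above, the following sketch tiles an 8x24 column-major layout by a static 4x8 tile. This is an editorial example, not part of the original header; the include path, the function name, and the commented result shapes are assumptions, and the tile is chosen so that it divides the layout evenly.

#include <cute/layout.hpp>   // assumed include path for this header

inline void zipped_divide_sketch() {
  using namespace cute;

  auto layout = make_layout(make_shape(Int<8>{}, Int<24>{}));   // (_8,_24), column-major strides
  auto tile   = make_shape(Int<4>{}, Int<8>{});                 // static 4x8 tile

  auto zipped = zipped_divide(layout, tile);
  // Expected shape: ((_4,_8),(_2,_3)); mode 0 addresses elements within one
  // tile, mode 1 addresses the 2x3 grid of tiles (the residuals).
  print(zipped); print("\n");

  // tiled_divide unpacks the second mode instead: ((_4,_8),_2,_3).
  print(tiled_divide(layout, tile)); print("\n");
}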
-+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zipped_divide(Layout const& layout, -+ Tile const& tile) -+{ -+ return tile_unzip(logical_divide(layout, tile), tile); -+} -+ -+// Same as zipped_divide, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y) -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tiled_divide(Layout const& layout, -+ Tile const& tile) -+{ -+ auto div = zipped_divide(layout, tile); -+ -+ auto R = rank<1>(div); -+ return div(_, repeat(_)); -+} -+ -+// -+// Logical product -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+logical_product(Layout const& layout, -+ Layout const& tile) -+{ -+ return make_layout(layout, composition(complement(layout, size(layout)*cosize(tile)), tile)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+logical_product(Layout const& layout, -+ IntTuple const& tile) -+{ -+ if constexpr (is_tuple::value) { -+ static_assert(tuple_size::value <= Layout::rank); -+ return transform_layout(layout, tile, [](auto const& l, auto const& t) { return logical_product(l,t); }); -+ } else if constexpr (is_underscore::value) { -+ return layout; -+ } else if constexpr (is_integral::value) { -+ return logical_product(layout, make_layout(tile)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Convenience operator -+// that produces layouts like ((BLK_A,BLK_B,...),(a,b,...,x,y)) -+// by gathering the block modes and products into a rank-2 result. -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zipped_product(Layout const& layout, -+ Tile const& tile) -+{ -+ return tile_unzip(logical_product(layout, tile), tile); -+} -+ -+// Same as zipped_product, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y) -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tiled_product(Layout const& layout, -+ Tile const& tile) -+{ -+ auto div = zipped_product(layout, tile); -+ -+ auto R = rank(tile); -+ return div(_, repeat(_)); -+} -+ -+// Attempts to reproduce layout "block" over layout "layout" -+// That is, think of every element of "layout" as a "block" -+// and return the layout of the resulting structure -+template -+CUTE_HOST_DEVICE constexpr -+auto -+blocked_product(Layout const& block, -+ Layout const& layout) -+{ -+ constexpr int R = cute::max(rank_v, rank_v); -+ auto padded_block = append(block); -+ auto padded_layout = append(layout); -+ -+ auto result = logical_product(padded_block, padded_layout); -+ -+ return coalesce(zip(get<0>(result), get<1>(result)), repeat(Int<1>{})); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+raked_product(Layout const& block, -+ Layout const& layout) -+{ -+ constexpr int R = cute::max(rank_v, rank_v); -+ auto padded_block = append(block); -+ auto padded_layout = append(layout); -+ -+ auto result = logical_product(padded_block, padded_layout); -+ -+ return coalesce(zip(get<1>(result), get<0>(result)), repeat(Int<1>{})); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tile_to_shape(Layout const& layout, -+ TrgShape const& trg_shape, -+ ModeOrder const& ord_shape = {}) -+{ -+ CUTE_STATIC_ASSERT_V(rank(layout) <= rank(trg_shape), "Rank of layout must be <= rank of target shape."); -+ constexpr int R = rank_v; -+ -+ auto padded_layout = append(layout); -+ -+ auto layout_shape = product_each(padded_layout.shape()); -+ auto target_shape = product_each(trg_shape); -+ -+ // Assert proper division -+ CUTE_STATIC_ASSERT_V(sum(transform(target_shape, layout_shape, modulus{})) == Int<0>{}, -+ "Layout shape does not divide the target shape."); -+ -+ auto product_shape = shape_div(target_shape, 
layout_shape); -+ -+ return coalesce(blocked_product(padded_layout, make_ordered_layout(product_shape, ord_shape)), product_shape); -+} -+ -+// -+// Upcast -+// For stride-1 mode, divide size by N. Divide all other strides by N. -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+upcast(Shape const& shape, Stride const& stride) -+{ -+ if constexpr (is_tuple::value) { // tuple stride -+ return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s,d); }); -+ } else if constexpr (is_constant<0, Stride>::value) { // static-0 stride -+ return Layout{shape,stride}; -+ } else if constexpr (is_static::value) { // static stride -+ return make_layout(shape_div(shape, shape_div(Int{}, abs(stride))), -+ shape_div(stride, Int{})); -+ } else { // dynamic stride -+ // assume dynamic strides are larger than N and divisible -+ // assert(stride % N == 0); -+ return make_layout(shape, safe_div(stride, Int{})); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+upcast(Layout const& layout) -+{ -+ return upcast(layout.shape(), layout.stride()); -+} -+ -+// -+// Downcast -+// For stride-1 mode, multiply size by N. Multiply all other strides by N. -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+downcast(Shape const& shape, Stride const& stride) -+{ -+ if constexpr (is_tuple::value) { -+ return transform_layout(shape, stride, [](auto const& s, auto const& d) { return downcast(s,d); }); -+ } else if constexpr (is_constant<1, Stride>::value || is_constant<-1, Stride>::value) { -+ return make_layout(shape * Int{}, stride); -+ } else { -+ return make_layout(shape, stride * Int{}); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+downcast(Layout const& layout) -+{ -+ CUTE_STATIC_ASSERT(has_int1::value, "Downcast requires adjacent elements"); -+ return downcast(layout.shape(), layout.stride()); -+} -+ -+// -+// Recast -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(Layout const& layout) -+{ -+ if constexpr (sizeof(NewType) == sizeof(OldType)) { -+ return layout; -+ } else if constexpr (sizeof(NewType) > sizeof(OldType)) { -+ static_assert(sizeof(NewType) % sizeof(OldType) == 0, "NewType must be a multiple of OldType"); -+ return upcast(layout); -+ } else if constexpr (sizeof(NewType) < sizeof(OldType)) { -+ static_assert(sizeof(OldType) % sizeof(NewType) == 0, "NewType must be a divisor of OldType"); -+ return downcast(layout); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Display utilities -+// -+ -+template -+CUTE_HOST_DEVICE void print(Layout const& layout) -+{ -+ print(layout.shape()); print(":"); print(layout.stride()); -+} -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, Layout const& layout) -+{ -+ return os << shape(layout) << ":" << stride(layout); -+} -+ -+// Generic 2D Layout to console table -+template -+CUTE_HOST_DEVICE -+void -+print_layout(Layout const& layout) // (m,n) -> idx -+{ -+ CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); -+ -+ int idx_width = num_digits(cosize(layout)) + 2; -+ const char* delim = "+-----------------------"; -+ -+ print(layout); print("\n"); -+ -+ // Column indices -+ print(" "); -+ for (int n = 0; n < size<1>(layout); ++n) { printf(" %*d ", idx_width-2, n); } -+ printf("\n"); -+ -+ // Print out A m-by-n -+ for (int m = 0; m < size<0>(layout); ++m) { -+ // Header -+ print(" "); -+ for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); } -+ printf("+\n"); -+ // Values -+ printf("%2d ", m); // Row 
indices -+ for (int n = 0; n < size<1>(layout); ++n) { printf("| %*d ", idx_width-2, int(layout(m,n))); } -+ printf("|\n"); -+ } -+ // Footer -+ print(" "); -+ for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); } -+ printf("+\n"); -+} -+ -+// Generic ThrVal 2D Layout to console table -+template -+CUTE_HOST_DEVICE -+void -+print_layout(Layout const& layout, ThrID const& thrid) // (m,n) -> (tid,vid) and tid -> thr_idx -+{ -+ CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); -+ -+ print(layout); print("\n"); -+ print(thrid); print("\n"); -+ -+ // Print out m-by-n -+ for (int m = 0; m < size<0>(layout); ++m) { -+ // Header -+ for (int n = 0; n < size<1>(layout); ++n) printf("+------"); -+ printf("+\n"); -+ // Values -+ for (int n = 0; n < size<1>(layout); ++n) printf("|%03d-%02d", int(thrid(layout(m,n) % size(thrid))), int(layout(m,n) / size(thrid))); -+ printf("|\n"); -+ } -+ // Footer -+ for (int n = 0; n < size<1>(layout); ++n) printf("+------"); -+ printf("+\n"); -+} -+ -+// Generic 2D Layout to Latex printer -- B&W 8-value color coding -+template -+CUTE_HOST_DEVICE -+void -+print_latex(Layout const& layout) // (m,n) -> idx -+{ -+ CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); -+ -+ char const* latex_header = -+ "\\documentclass[convert]{standalone}\n" -+ "\\usepackage{tikz}\n\n" -+ "\\begin{document}\n" -+ "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center,font=\\Large}]\n\n"; -+ char const* latex_footer = -+ "\\end{tikzpicture}\n" -+ "\\end{document}\n"; -+ -+ char const* color_map[8] = {"black!00", -+ "black!40", -+ "black!20", -+ "black!60", -+ "black!10", -+ "black!50", -+ "black!30", -+ "black!70"}; -+ -+ // Header -+ printf("%% Layout: "); print(layout); printf("\n"); -+ -+ printf(latex_header); -+ -+ // Layout -+ for (int i = 0; i < size<0>(layout); ++i) { -+ for (int j = 0; j < size<1>(layout); ++j) { -+ int idx = layout(i,j); -+ -+ printf("\\node[box,fill=%s] at (%d,%d) {%d};\n", -+ color_map[idx % 8], -+ i, j, -+ idx); -+ } -+ } -+ -+ // Labels -+ for (int i = 0, j = -1; i < size<0>(layout); ++i) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); -+ } -+ for (int j = 0, i = -1; j < size<1>(layout); ++j) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); -+ } -+ -+ // Footer -+ printf(latex_footer); -+} -+ -+// Generic ThrVal 2D Layout to Latex TIKZ -- 8-value color coded by thread -+template -+CUTE_HOST_DEVICE -+void -+print_latex(Layout const& layout, ThrID const& thr) // (m,n) -> (tid,vid) and tid -> thr_idx -+{ -+ CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); -+ -+ char const* latex_header = -+ "\\documentclass[convert]{standalone}\n" -+ "\\usepackage{tikz}\n\n" -+ "\\begin{document}\n" -+ "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; -+ char const* latex_footer = -+ "\\end{tikzpicture}\n" -+ "\\end{document}\n"; -+ -+ char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", -+ "{rgb,255:red,175;green,255;blue,175}", -+ "{rgb,255:red,255;green,255;blue,175}", -+ "{rgb,255:red,255;green,175;blue,175}", -+ "{rgb,255:red,210;green,210;blue,255}", -+ "{rgb,255:red,210;green,255;blue,210}", -+ "{rgb,255:red,255;green,255;blue,210}", -+ "{rgb,255:red,255;green,210;blue,210}"}; -+ -+ // Header -+ printf("%% layout: "); print(layout); printf("\n"); -+ printf("%% thrid: "); print(thr); printf("\n\n"); -+ -+ printf(latex_header); -+ -+ // Layout -+ for (int 
i = 0; i < size<0>(layout); ++i) { -+ for (int j = 0; j < size<1>(layout); ++j) { -+ int thrid = layout(i,j) % size(thr); -+ int val_idx = layout(i,j) / size(thr); -+ int thr_idx = thr(thrid); -+ -+ printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", -+ color_map[thr_idx % 8], -+ i, j, -+ thr_idx, val_idx); -+ } -+ } -+ -+ // Labels -+ for (int i = 0, j = -1; i < size<0>(layout); ++i) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); -+ } -+ for (int j = 0, i = -1; j < size<1>(layout); ++j) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); -+ } -+ -+ // Footer -+ printf(latex_footer); -+} -+ -+} // end namespace cute -+ -+// -+// Extended Layouts -+// -+ -+#include -diff --git a/3rdparty/cutlass/include/cute/numeric/arithmetic_tuple.hpp b/3rdparty/cutlass/include/cute/numeric/arithmetic_tuple.hpp -new file mode 100644 -index 0000000..33471e4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/arithmetic_tuple.hpp -@@ -0,0 +1,388 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+#include -+#include -+ -+namespace cute -+{ -+ -+template -+struct ArithmeticTuple : tuple -+{ -+ template -+ CUTE_HOST_DEVICE constexpr -+ ArithmeticTuple(ArithmeticTuple const& u) -+ : tuple(static_cast const&>(u)) {} -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ ArithmeticTuple(tuple const& u) -+ : tuple(u) {} -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ ArithmeticTuple(U const&... u) -+ : tuple(u...) {} -+}; -+ -+template -+struct is_tuple> : true_type {}; -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_arithmetic_tuple(T const&... 
t) { -+ return ArithmeticTuple(t...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+as_arithmetic_tuple(tuple const& t) { -+ return ArithmeticTuple(t); -+} -+ -+// -+// Numeric operators -+// -+ -+// Addition -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator+(ArithmeticTuple const& t, ArithmeticTuple const& u) { -+ constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U))); -+ return transform_apply(append(t,Int<0>{}), append(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); }); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator+(ArithmeticTuple const& t, tuple const& u) { -+ constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U))); -+ return transform_apply(append(t,Int<0>{}), append(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); }); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator+(tuple const& t, ArithmeticTuple const& u) { -+ constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U))); -+ return transform_apply(append(t,Int<0>{}), append(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); }); -+} -+ -+// -+// Special cases -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator+(constant, ArithmeticTuple const& u) { -+ return u; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator+(ArithmeticTuple const& t, constant) { -+ return t; -+} -+ -+// -+// ArithmeticTupleIterator -+// -+ -+template -+struct ArithmeticTupleIterator -+{ -+ ArithTuple coord_; -+ -+ CUTE_HOST_DEVICE constexpr -+ ArithmeticTupleIterator() : coord_() {} -+ CUTE_HOST_DEVICE constexpr -+ ArithmeticTupleIterator(ArithTuple const& coord) : coord_(coord) {} -+ -+ CUTE_HOST_DEVICE constexpr -+ ArithTuple const& operator*() const { return coord_; } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto operator+(Coord const& c) const { -+ return ArithmeticTupleIterator(coord_ + c); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto operator[](Coord const& c) const { return *(*this + c); } -+}; -+ -+template -+CUTE_HOST_DEVICE void print(ArithmeticTupleIterator const& iter) { -+ printf("ArithTuple"); print(iter.coord_); -+} -+ -+// -+// ArithmeticTuple "basis" elements -+// -+ -+// Abstract value: -+// A ScaledBasis is a (at least) rank-N0 ArithmeticTuple: -+// (_0,_0,...,T,_0,...) 
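The following sketch shows how these basis elements behave under scaling and addition. It is an editorial example, not part of the original header; the include path, the function name, and the commented result are assumptions derived from the operators defined below.

#include <cute/numeric/arithmetic_tuple.hpp>   // assumed include path for this header

inline void basis_sketch() {
  using namespace cute;

  auto e1 = E<1>{};     // unit basis element tagged with mode 1
  auto s  = e1 * 3;     // scaling preserves the mode tag: a ScaledBasis with value 3 on mode 1

  // Adding basis elements of different modes widens them into a single
  // ArithmeticTuple; the expected value here is (_1,3).
  auto v = E<0>{} + s;
  (void)v;
}

These are the stride elements that make_basis_like and make_identity_layout (in layout.hpp earlier in this patch) use to build coordinate-preserving layouts.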
-+ -+template -+struct ScaledBasis : private tuple -+{ -+ CUTE_HOST_DEVICE constexpr -+ ScaledBasis(T const& t = {}) : tuple(t) {} -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) value() { return get<0>(static_cast &>(*this)); } -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) value() const { return get<0>(static_cast const&>(*this)); } -+ -+ CUTE_HOST_DEVICE static constexpr -+ auto mode() { return Int{}; } -+}; -+ -+template -+struct is_scaled_basis : false_type {}; -+template -+struct is_scaled_basis> : true_type {}; -+ -+template -+struct is_integral> : true_type {}; -+ -+template -+CUTE_HOST_DEVICE constexpr auto -+basis_value(T const& e) { -+ return e; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr auto -+basis_value(ScaledBasis const& e) { -+ return basis_value(e.value()); -+} -+ -+namespace detail { -+ -+template -+struct Basis; -+ -+template <> -+struct Basis<> { -+ using type = Int<1>; -+}; -+ -+template -+struct Basis { -+ using type = ScaledBasis::type, N>; -+}; -+ -+} // end namespace detail -+ -+template -+using E = typename detail::Basis::type; -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+as_arithmetic_tuple(T const& t, seq, seq) { -+ return make_arithmetic_tuple((void(I),Int<0>{})..., t, (void(J),Int<0>{})...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+as_arithmetic_tuple(ArithmeticTuple const& t, seq, seq) { -+ return make_arithmetic_tuple(get(t)..., (void(J),Int<0>{})...); -+} -+ -+} // end namespace detail -+ -+// Turn a ScaledBases into a rank-M ArithmeticTuple -+// with N prefix 0s: (_0,_0,...N...,_0,T,_0,...,_0,_0) -+template -+CUTE_HOST_DEVICE constexpr -+auto -+as_arithmetic_tuple(ScaledBasis const& t) { -+ static_assert(M > N, "Mismatched ranks"); -+ return detail::as_arithmetic_tuple(t.value(), make_seq{}, make_seq{}); -+} -+ -+// Turn an ArithmeticTuple into a rank-M ArithmeticTuple -+// with postfix 0s: (t0,t1,t2,...,_0,...,_0,_0) -+template -+CUTE_HOST_DEVICE constexpr -+auto -+as_arithmetic_tuple(ArithmeticTuple const& t) { -+ static_assert(M >= sizeof...(T), "Mismatched ranks"); -+ return detail::as_arithmetic_tuple(t, make_seq{}, make_seq{}); -+} -+ -+// Return... 
-+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_basis_like(Shape const& shape) -+{ -+ if constexpr (is_integral::value) { -+ return Int<1>{}; -+ } else { -+ // Generate bases for each rank of shape -+ return transform(tuple_seq{}, [&](auto I) { -+ // Generate bases for each rank of shape_i and add an i on front -+ constexpr int i = decltype(I)::value; // NOTE: nvcc workaround -+ return transform_leaf(make_basis_like(get(shape)), [&](auto e) { return ScaledBasis{}; }); -+ }); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// Equality -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator==(ScaledBasis, Int) { -+ return false_type{}; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator==(Int, ScaledBasis) { -+ return false_type{}; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator==(ScaledBasis const& t, ScaledBasis const& u) { -+ return bool_constant{} && t.value() == u.value(); -+} -+ -+// Multiplication -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+operator*(A const& a, ScaledBasis const& e) { -+ return ScaledBasis{a*e.value()}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+operator*(ScaledBasis const& e, B const& b) { -+ return ScaledBasis{e.value()*b}; -+} -+ -+// Addition -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator+(ScaledBasis const& t, ArithmeticTuple const& u) { -+ constexpr int R = cute::max(N+1, int(sizeof...(U))); -+ return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator+(ArithmeticTuple const& t, ScaledBasis const& u) { -+ constexpr int R = cute::max(int(sizeof...(T)), M+1); -+ return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator+(ScaledBasis const& t, ScaledBasis const& u) { -+ constexpr int R = cute::max(N+1,M+1); -+ return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator+(constant, ScaledBasis const& u) { -+ return u; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator+(ScaledBasis const& t, constant) { -+ return t; -+} -+ -+// -+// Display utilities -+// -+ -+template -+CUTE_HOST_DEVICE void print(ScaledBasis const& e) { -+ printf("%d:", N); print(e.value()); -+} -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis const& e) { -+ return os << N << ":" << e.value(); -+} -+ -+} // end namespace cute -+ -+ -+namespace std -+{ -+ -+template -+struct tuple_size> -+ : std::integral_constant -+{}; -+ -+template -+struct tuple_element> -+ : std::tuple_element> -+{}; -+ -+} // end namespace std -diff --git a/3rdparty/cutlass/include/cute/numeric/bfloat.hpp b/3rdparty/cutlass/include/cute/numeric/bfloat.hpp -new file mode 100644 -index 0000000..94f64ab ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/bfloat.hpp -@@ -0,0 +1,51 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+namespace cute { -+ -+using cutlass::bfloat16_t; -+ -+// -+// Display utilities -+// -+ -+CUTE_HOST std::ostream& operator<<(std::ostream& os, bfloat16_t const& v) -+{ -+ return os << float(v); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/numeric/complex.hpp b/3rdparty/cutlass/include/cute/numeric/complex.hpp -new file mode 100644 -index 0000000..3790ebd ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/complex.hpp -@@ -0,0 +1,163 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+//#if defined(__CUDA_ARCH__) -+//# include -+//#else -+//# include -+//#endif -+ -+// With CUDA 11.4, builds show spurious "-Wconversion" warnings -+// on line 656 of thrust/detail/type_traits.h. -+// These pragmas suppress the warnings. -+#pragma GCC diagnostic push -+#pragma GCC diagnostic ignored "-Wconversion" -+#include -+#pragma GCC diagnostic pop -+ -+#include -+ -+namespace cute -+{ -+ -+//#if defined(__CUDA_ARCH__) -+//template -+//using complex = cuda::std::complex; -+//#else -+//template -+//using complex = std::complex; -+//#endif -+ -+//template -+//using complex = thrust::complex; -+ -+using thrust::complex; -+ -+template -+CUTE_HOST_DEVICE -+T real(complex const& z) { -+ return z.real(); -+} -+ -+template -+CUTE_HOST_DEVICE -+T imag(complex const& z) { -+ return z.imag(); -+} -+ -+template -+CUTE_HOST_DEVICE -+complex conj(complex const& z) { -+ return complex(real(z), -imag(z)); -+} -+ -+// cute::conj forwards scalars -+template -+CUTE_HOST_DEVICE -+T conj(T z) { -+ return z; -+} -+ -+//CUTE_HOST_DEVICE constexpr -+//float conj(float z) { return z; } -+//CUTE_HOST_DEVICE constexpr -+//double conj(double z) { return z; } -+ -+/// Fused multiply-add for complex numbers -+template -+CUTE_HOST_DEVICE constexpr -+void -+fma(complex & d, -+ complex const& a, -+ complex const& b, -+ complex const& c) -+{ -+ d.real(c.real() + a.real() * b.real()); -+ d.imag(c.imag() + a.real() * b.imag()); -+ d.real(d.real() - a.imag() * b.imag()); -+ d.imag(d.imag() + a.imag() * b.real()); -+} -+ -+/// Fused multiply-add for triplets -+template -+CUTE_HOST_DEVICE constexpr -+void -+fma(complex const& a, -+ complex const& b, -+ complex & c) -+{ -+ return fma(c, a, b, c); -+} -+ -+/// Used to determine the real-valued underlying type of a numeric type T -+template -+struct RealType { -+ using Type = T; -+}; -+ -+/// Partial specialization for complex-valued type -+template -+struct RealType> { -+ using Type = T; -+}; -+ -+////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct is_complex { -+ static bool const value = false; -+}; -+ -+template -+struct is_complex> { -+ static bool const value = true; -+}; -+ -+////////////////////////////////////////////////////////////////////////////////////////////////// -+// Display utilities -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, complex const& z) -+{ -+ T _r = z.real(); -+ T _i = z.imag(); -+ -+ if (bool(_i)) { -+ return os << _r << "+i" << _i; -+ } else { -+ return os << _r; -+ } -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/numeric/float8.hpp b/3rdparty/cutlass/include/cute/numeric/float8.hpp -new file mode 100644 -index 0000000..3fa471d ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/float8.hpp -@@ -0,0 +1,43 @@ 
-+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+namespace cute { -+ -+using cutlass::float_e4m3_t; -+using cutlass::float_e5m2_t; -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/numeric/half.hpp b/3rdparty/cutlass/include/cute/numeric/half.hpp -new file mode 100644 -index 0000000..704ba28 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/half.hpp -@@ -0,0 +1,41 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+#include -+ -+namespace cute { -+ -+using cutlass::half_t; -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/numeric/int.hpp b/3rdparty/cutlass/include/cute/numeric/int.hpp -new file mode 100644 -index 0000000..a08297f ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/int.hpp -@@ -0,0 +1,129 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
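
Editor's note: a small, hypothetical usage sketch for the complex helpers added a few hunks above. cute::complex is thrust::complex, so Thrust must be on the include path; the include path and constructor forms are assumptions, not taken from the patch.

#include <cute/numeric/complex.hpp>   // assumed include path; pulls in thrust::complex
#include <cstdio>

int main() {
  using cute::complex;
  complex<float> a(1.0f, 2.0f), b(3.0f, -1.0f), c(0.5f, 0.5f);
  complex<float> d;
  cute::fma(d, a, b, c);   // d = a*b + c = (5.5, 5.5), via the real/imag expansion above
  std::printf("%f + i%f\n", cute::real(d), cute::imag(d));
  // cute::conj conjugates complex values but forwards plain scalars unchanged:
  std::printf("%f\n", cute::conj(2.0f));
  return 0;
}
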
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include -+#include -+ -+namespace cute -+{ -+ -+// -+// Signed integers -+// -+ -+using int8_t = std::int8_t; -+using int16_t = std::int16_t; -+using int32_t = std::int32_t; -+using int64_t = std::int64_t; -+ -+template struct int_bit; -+template <> struct int_bit< 2> { using type = cute::int2b_t; }; -+template <> struct int_bit< 4> { using type = cute::int4b_t; }; -+template <> struct int_bit< 8> { using type = int8_t; }; -+template <> struct int_bit< 16> { using type = int16_t; }; -+template <> struct int_bit< 32> { using type = int32_t; }; -+template <> struct int_bit< 64> { using type = int64_t; }; -+ -+template -+using int_bit_t = typename int_bit::type; -+ -+template -+using int_byte = int_bit<8*N>; -+ -+template -+using int_byte_t = typename int_byte::type; -+ -+// -+// Unsigned integers -+// -+ -+using uint8_t = std::uint8_t; -+using uint16_t = std::uint16_t; -+using uint32_t = std::uint32_t; -+using uint64_t = std::uint64_t; -+ -+template struct uint_bit; -+template <> struct uint_bit< 1> { using type = cute::uint1b_t; }; -+template <> struct uint_bit< 2> { using type = cute::uint2b_t; }; -+template <> struct uint_bit< 4> { using type = cute::uint4b_t; }; -+template <> struct uint_bit< 8> { using type = uint8_t; }; -+template <> struct uint_bit< 16> { using type = uint16_t; }; -+template <> struct uint_bit< 32> { using type = uint32_t; }; -+template <> struct uint_bit< 64> { using type = uint64_t; }; -+template <> struct uint_bit<128> { using type = cute::uint128_t; }; -+ -+template -+using uint_bit_t = typename uint_bit::type; -+ -+template -+using uint_byte = uint_bit<8*N>; -+ -+template -+using uint_byte_t = typename uint_byte::type; -+ -+// -+// sizeof_bytes -+// -+ -+template -+struct sizeof_bytes { -+ static constexpr std::size_t value = sizeof(T); -+}; -+template -+static constexpr int sizeof_bytes_v = sizeof_bytes::value; -+ -+// -+// sizeof_bits -+// -+ -+template -+struct sizeof_bits { -+ static constexpr std::size_t value = sizeof(T) * 8; -+}; -+template <> -+struct sizeof_bits { -+ static constexpr std::size_t value = 1; -+}; -+template -+struct sizeof_bits> { -+ static constexpr std::size_t value = Bits; -+}; -+template -+static constexpr int sizeof_bits_v = sizeof_bits::value; -+ -+} // namespace cute -diff --git a/3rdparty/cutlass/include/cute/numeric/integer_sequence.hpp b/3rdparty/cutlass/include/cute/numeric/integer_sequence.hpp -new file mode 100644 -index 0000000..73a83f7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/integer_sequence.hpp -@@ -0,0 +1,139 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include // std::integer_sequence -+ -+#include -+ -+namespace cute -+{ -+ -+using std::integer_sequence; -+using std::make_integer_sequence; -+ -+namespace detail { -+ -+template -+struct make_integer_range_impl; -+ -+template -+struct make_integer_range_impl, Begin> { -+ using type = integer_sequence; -+}; -+ -+} // end namespace detail -+ -+template -+using make_integer_range = typename detail::make_integer_range_impl< -+ T, -+ make_integer_sequence 0) ? (End-Begin) : 0>, -+ Begin>::type; -+ -+// -+// Common aliases -+// -+ -+// int_sequence -+ -+template -+using int_sequence = integer_sequence; -+ -+template -+using make_int_sequence = make_integer_sequence; -+ -+template -+using make_int_range = make_integer_range; -+ -+// index_sequence -+ -+template -+using index_sequence = integer_sequence; -+ -+template -+using make_index_sequence = make_integer_sequence; -+ -+template -+using make_index_range = make_integer_range; -+ -+// -+// Shortcuts -+// -+ -+template -+using seq = int_sequence; -+ -+template -+using make_seq = make_int_sequence; -+ -+template -+using make_range = make_int_range; -+ -+template -+using tuple_seq = make_seq>::value>; -+ -+} // end namespace cute -+ -+ -+// -+// Specialize tuple-related functionality for cute::integer_sequence -+// -+ -+#include -+#include -+ -+namespace cute -+{ -+ -+template -+CUTE_HOST_DEVICE constexpr -+std::tuple_element_t> -+get(integer_sequence) { -+ static_assert(I < sizeof...(Ints), "Index out of range"); -+ return {}; -+} -+ -+} // end namespace cute -+ -+namespace std -+{ -+ -+template -+struct tuple_size> -+ : std::integral_constant -+{}; -+ -+template -+struct tuple_element> -+ : std::tuple_element...>> -+{}; -+ -+} // end namespace std -diff --git a/3rdparty/cutlass/include/cute/numeric/integer_subbyte.hpp b/3rdparty/cutlass/include/cute/numeric/integer_subbyte.hpp -new file mode 100644 -index 0000000..3d24a95 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/integer_subbyte.hpp -@@ -0,0 +1,233 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include -+#include -+ -+namespace cute { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct integer_subbyte -+{ -+ /// Storage type -+ using Storage = uint8_t; -+ -+ /// Number of bits -+ static_assert(Bits <= 8*sizeof(Storage), "Require a subbyte of bits in integer_subbyte"); -+ -+ /// External type -+ using xint_t = typename std::conditional::type; -+ -+ /// Bitmask for truncation from larger integers -+ static constexpr Storage bits_mask_ = Storage((1 << Bits) - 1); -+ /// Bitmask for the sign bit -+ static constexpr Storage sign_mask_ = Storage((Signed ? 1 : 0) << (Bits - 1)); -+ -+ // -+ // Data members -+ // -+ -+ Storage storage; -+ -+ // -+ // Methods -+ // -+ -+ /// No operation -+ CUTE_HOST_DEVICE constexpr -+ integer_subbyte() {} -+ -+ /// Conversion from integer type -+ CUTE_HOST_DEVICE constexpr -+ integer_subbyte(int value) // NOTE: Sign extension? 
-+ : storage(reinterpret_cast(value) & bits_mask_) {} -+ -+ CUTE_HOST_DEVICE constexpr -+ integer_subbyte(unsigned value) -+ : storage(reinterpret_cast(value) & bits_mask_) {} -+ -+ /// Convert to int or unsigned -+ CUTE_HOST_DEVICE constexpr -+ operator xint_t() const { -+ if (sign_mask_ & storage) { // Sign extend -+ return xint_t(storage) | ~xint_t(bits_mask_); -+ } else { -+ return xint_t(storage); -+ } -+ } -+ -+ /// Equality -+ CUTE_HOST_DEVICE constexpr -+ bool operator==(integer_subbyte const& rhs) const { -+ return storage == rhs.storage; -+ } -+ -+ /// Inequality -+ CUTE_HOST_DEVICE constexpr -+ bool operator!=(integer_subbyte const& rhs) const { -+ return storage != rhs.storage; -+ } -+ -+ /// Less than or equal -+ CUTE_HOST_DEVICE constexpr -+ bool operator<=(integer_subbyte const& rhs) const { -+ if (sign_mask_ & storage) { -+ return !(rhs.storage < storage); -+ } else { -+ return storage < rhs.storage; -+ } -+ } -+ -+ /// Less than -+ CUTE_HOST_DEVICE constexpr -+ bool operator<(integer_subbyte const& rhs) const { -+ if (sign_mask_ & storage) { -+ return !(rhs.storage <= storage); -+ } else { -+ return storage < rhs.storage; -+ } -+ } -+ -+ /// Greater than or equal -+ CUTE_HOST_DEVICE constexpr -+ bool operator>=(integer_subbyte const& rhs) const { -+ return !(*this < rhs); -+ } -+ -+ /// Greater than -+ CUTE_HOST_DEVICE constexpr -+ bool operator>(integer_subbyte const& rhs) const { -+ return !(*this <= rhs); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 1-bit unsigned integer type -+using uint1b_t = integer_subbyte<1, false>; -+ -+/// 2-bit integer type -+using int2b_t = integer_subbyte<2, true>; -+ -+/// 2-bit unsigned integer type -+using uint2b_t = integer_subbyte<2, false>; -+ -+/// 4-bit integer type -+using int4b_t = integer_subbyte<4, true>; -+ -+/// 4-bit unsigned integer type -+using uint4b_t = integer_subbyte<4, false>; -+ -+/// 1-bit binary type -+using bin1_t = bool; -+ -+} // namespace cute -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if !defined(__CUDACC_RTC__) -+ -+#include -+ -+namespace std { -+ -+template <> -+struct numeric_limits { -+ CUTE_HOST_DEVICE static constexpr -+ cute::uint1b_t const lowest() noexcept { return 0; } -+ CUTE_HOST_DEVICE static constexpr -+ cute::uint1b_t const min() noexcept { return 0; } -+ CUTE_HOST_DEVICE static constexpr -+ cute::uint1b_t const max() noexcept { return 1; } -+ static constexpr bool is_integer = true; -+ static constexpr bool is_signed = false; -+}; -+ -+template <> -+struct numeric_limits { -+ CUTE_HOST_DEVICE static constexpr -+ cute::int2b_t lowest() noexcept { return -2; } -+ CUTE_HOST_DEVICE static constexpr -+ cute::int2b_t min() noexcept { return -2; } -+ CUTE_HOST_DEVICE static constexpr -+ cute::int2b_t max() noexcept { return 1; } -+ static constexpr bool is_integer = true; -+ static constexpr bool is_signed = true; -+}; -+ -+template <> -+struct numeric_limits { -+ CUTE_HOST_DEVICE static constexpr -+ cute::uint2b_t const lowest() noexcept { return 0; } -+ CUTE_HOST_DEVICE static constexpr -+ cute::uint2b_t const min() noexcept { return 0; } -+ CUTE_HOST_DEVICE static constexpr -+ cute::uint2b_t const max() noexcept { return 3; } -+ static constexpr bool is_integer = true; -+ static constexpr bool is_signed = false; -+}; -+ -+template <> -+struct numeric_limits { -+ CUTE_HOST_DEVICE static constexpr -+ cute::int4b_t lowest() noexcept { return -8; } -+ 
CUTE_HOST_DEVICE static constexpr -+ cute::int4b_t min() noexcept { return -8; } -+ CUTE_HOST_DEVICE static constexpr -+ cute::int4b_t max() noexcept { return 7; } -+ static constexpr bool is_integer = true; -+ static constexpr bool is_signed = true; -+}; -+ -+template <> -+struct numeric_limits { -+ CUTE_HOST_DEVICE static constexpr -+ cute::uint4b_t const lowest() noexcept { return 0; } -+ CUTE_HOST_DEVICE static constexpr -+ cute::uint4b_t const min() noexcept { return 0; } -+ CUTE_HOST_DEVICE static constexpr -+ cute::uint4b_t const max() noexcept { return 15; } -+ static constexpr bool is_integer = true; -+ static constexpr bool is_signed = false; -+}; -+ -+} // namespace std -+ -+#endif -diff --git a/3rdparty/cutlass/include/cute/numeric/integral_constant.hpp b/3rdparty/cutlass/include/cute/numeric/integral_constant.hpp -new file mode 100644 -index 0000000..106763d ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/integral_constant.hpp -@@ -0,0 +1,414 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
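
Editor's note: a brief sketch, not part of the patch, of how the sub-byte integer types and their std::numeric_limits specializations above behave; the include path is an assumption.

#include <cute/numeric/integer_subbyte.hpp>   // assumed include path
#include <cassert>
#include <limits>

int main() {
  cute::int4b_t x(-3);       // value masked into the low 4 bits of a uint8_t
  assert(int(x) == -3);      // conversion to int sign-extends via sign_mask_
  cute::uint4b_t u(15);
  assert(unsigned(u) == 15u);
  // The numeric_limits specializations report the 4-bit ranges:
  assert(std::numeric_limits<cute::int4b_t>::min()  == cute::int4b_t(-8));
  assert(std::numeric_limits<cute::int4b_t>::max()  == cute::int4b_t(7));
  assert(std::numeric_limits<cute::uint4b_t>::max() == cute::uint4b_t(15));
  return 0;
}
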
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+namespace cute -+{ -+ -+template -+struct constant : std::integral_constant { -+ static constexpr T value = v; -+ using value_type = T; -+ using type = constant; -+ CUTE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; } -+ CUTE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; } -+}; -+ -+template -+using integral_constant = constant; -+ -+template -+using bool_constant = constant; -+ -+using true_type = bool_constant; -+using false_type = bool_constant; -+ -+// -+// Traits -+// -+ -+// Use std::is_integral to match built-in integral types (int, int64_t, unsigned, etc) -+// Use cute::is_integral to match both built-in integral types AND constant -+ -+template -+struct is_integral : bool_constant::value> {}; -+template -+struct is_integral> : true_type {}; -+ -+// is_static detects if an (abstract) value is defined completely by it's type (no members) -+ -+template -+struct is_static : bool_constant::value> {}; -+ -+// is_constant detects if a type is a constant and if v is equal to a value -+ -+template -+struct is_constant : false_type {}; -+template -+struct is_constant > : bool_constant {}; -+template -+struct is_constant const > : bool_constant {}; -+template -+struct is_constant const&> : bool_constant {}; -+template -+struct is_constant &> : bool_constant {}; -+template -+struct is_constant &&> : bool_constant {}; -+ -+// -+// Specializations -+// -+ -+template -+using Int = constant; -+ -+using _m32 = Int<-32>; -+using _m24 = Int<-24>; -+using _m16 = Int<-16>; -+using _m12 = Int<-12>; -+using _m10 = Int<-10>; -+using _m9 = Int<-9>; -+using _m8 = Int<-8>; -+using _m7 = Int<-7>; -+using _m6 = Int<-6>; -+using _m5 = Int<-5>; -+using _m4 = Int<-4>; -+using _m3 = Int<-3>; -+using _m2 = Int<-2>; -+using _m1 = Int<-1>; -+using _0 = Int<0>; -+using _1 = Int<1>; -+using _2 = Int<2>; -+using _3 = Int<3>; -+using _4 = Int<4>; -+using _5 = Int<5>; -+using _6 = Int<6>; -+using _7 = Int<7>; -+using _8 = Int<8>; -+using _9 = Int<9>; -+using _10 = Int<10>; -+using _12 = Int<12>; -+using _16 = Int<16>; -+using _24 = Int<24>; -+using _32 = Int<32>; -+using _64 = Int<64>; -+using _96 = Int<96>; -+using _128 = Int<128>; -+using _192 = Int<192>; -+using _256 = Int<256>; -+using _512 = Int<512>; -+using _1024 = Int<1024>; -+using _2048 = Int<2048>; -+using _4096 = Int<4096>; -+using _8192 = Int<8192>; -+ -+/***************/ -+/** Operators **/ -+/***************/ -+ -+#define CUTE_LEFT_UNARY_OP(OP) \ -+ template \ -+ CUTE_HOST_DEVICE constexpr \ -+ constant \ -+ operator OP (constant) { \ -+ return {}; \ -+ } -+#define CUTE_RIGHT_UNARY_OP(OP) \ -+ template \ -+ CUTE_HOST_DEVICE constexpr \ -+ constant \ -+ operator OP (constant) { \ -+ return {}; \ -+ } -+ -+#define CUTE_BINARY_OP(OP) \ -+ template \ -+ CUTE_HOST_DEVICE constexpr \ -+ constant \ -+ operator OP (constant, constant) { \ -+ return {}; \ -+ } -+ -+CUTE_LEFT_UNARY_OP(+); -+CUTE_LEFT_UNARY_OP(-); -+CUTE_LEFT_UNARY_OP(~); -+CUTE_LEFT_UNARY_OP(!); -+CUTE_LEFT_UNARY_OP(*); -+ -+CUTE_BINARY_OP( +); -+CUTE_BINARY_OP( -); -+CUTE_BINARY_OP( *); -+CUTE_BINARY_OP( /); -+CUTE_BINARY_OP( %); -+CUTE_BINARY_OP( &); -+CUTE_BINARY_OP( |); -+CUTE_BINARY_OP( ^); -+CUTE_BINARY_OP(<<); -+CUTE_BINARY_OP(>>); -+ -+CUTE_BINARY_OP(&&); -+CUTE_BINARY_OP(||); -+ -+CUTE_BINARY_OP(==); -+CUTE_BINARY_OP(!=); -+CUTE_BINARY_OP( >); -+CUTE_BINARY_OP( <); 
-+CUTE_BINARY_OP(>=); -+CUTE_BINARY_OP(<=); -+ -+#undef CUTE_BINARY_OP -+#undef CUTE_LEFT_UNARY_OP -+#undef CUTE_RIGHT_UNARY_OP -+ -+// -+// Mixed static-dynamic special cases -+// -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+constant -+operator*(constant, U) { -+ return {}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+constant -+operator*(U, constant) { -+ return {}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+constant -+operator/(constant, U) { -+ return {}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+constant -+operator%(U, constant) { -+ return {}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+constant -+operator%(U, constant) { -+ return {}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+constant -+operator%(constant, U) { -+ return {}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+constant -+operator&(constant, U) { -+ return {}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+constant -+operator&(U, constant) { -+ return {}; -+} -+ -+template ::value && !bool(t))> -+CUTE_HOST_DEVICE constexpr -+constant -+operator&&(constant, U) { -+ return {}; -+} -+ -+template ::value && !bool(t))> -+CUTE_HOST_DEVICE constexpr -+constant -+operator&&(U, constant) { -+ return {}; -+} -+ -+template ::value && bool(t))> -+CUTE_HOST_DEVICE constexpr -+constant -+operator||(constant, U) { -+ return {}; -+} -+ -+template ::value && bool(t))> -+CUTE_HOST_DEVICE constexpr -+constant -+operator||(U, constant) { -+ return {}; -+} -+ -+// -+// Named functions from math.hpp -+// -+ -+#define CUTE_NAMED_UNARY_FN(OP) \ -+ template \ -+ CUTE_HOST_DEVICE constexpr \ -+ constant \ -+ OP (constant) { \ -+ return {}; \ -+ } -+ -+#define CUTE_NAMED_BINARY_FN(OP) \ -+ template \ -+ CUTE_HOST_DEVICE constexpr \ -+ constant \ -+ OP (constant, constant) { \ -+ return {}; \ -+ } \ -+ \ -+ template ::value)> \ -+ CUTE_HOST_DEVICE constexpr \ -+ auto \ -+ OP (constant, U u) { \ -+ return OP(t,u); \ -+ } \ -+ \ -+ template ::value)> \ -+ CUTE_HOST_DEVICE constexpr \ -+ auto \ -+ OP (T t, constant) { \ -+ return OP(t,u); \ -+ } -+ -+CUTE_NAMED_UNARY_FN(abs); -+CUTE_NAMED_UNARY_FN(signum); -+CUTE_NAMED_UNARY_FN(has_single_bit); -+ -+CUTE_NAMED_BINARY_FN(max); -+CUTE_NAMED_BINARY_FN(min); -+CUTE_NAMED_BINARY_FN(shiftl); -+CUTE_NAMED_BINARY_FN(shiftr); -+CUTE_NAMED_BINARY_FN(gcd); -+CUTE_NAMED_BINARY_FN(lcm); -+ -+#undef CUTE_NAMED_UNARY_FN -+#undef CUTE_NAMED_BINARY_FN -+ -+// -+// Other functions -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+constant -+safe_div(constant, constant) { -+ static_assert(t % u == 0, "Static safe_div requires t % u == 0"); -+ return {}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+safe_div(constant, U u) { -+ return t / u; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+safe_div(T t, constant) { -+ return t / u; -+} -+ -+// cute::true_type prefers standard conversion to std::true_type -+// over user-defined conversion to bool -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+conditional_return(std::true_type, TrueType&& t, FalseType&&) { -+ return static_cast(t); -+} -+ -+// cute::false_type prefers standard conversion to std::false_type -+// over user-defined conversion to bool -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+conditional_return(std::false_type, TrueType&&, FalseType&& f) { -+ return static_cast(f); -+} -+ -+// TrueType and FalseType must have a common type -+template -+CUTE_HOST_DEVICE constexpr -+auto 
-+conditional_return(bool b, TrueType const& t, FalseType const& f) { -+ return b ? t : f; -+} -+ -+// -+// Display utilities -+// -+ -+template -+CUTE_HOST_DEVICE void print(integral_constant const&) { -+ printf("_%d", N); -+} -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, integral_constant const&) { -+ return os << "_" << N; -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/numeric/math.hpp b/3rdparty/cutlass/include/cute/numeric/math.hpp -new file mode 100644 -index 0000000..03e8379 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/math.hpp -@@ -0,0 +1,319 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include -+ -+namespace cute -+{ -+ -+// -+// Common Operations -+// -+ -+template ::value && -+ std::is_arithmetic::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+max(T const& t, U const& u) { -+ return t < u ? u : t; -+} -+ -+template ::value && -+ std::is_arithmetic::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+min(T const& t, U const& u) { -+ return t < u ? t : u; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+abs(T const& t) { -+ if constexpr (std::is_signed::value) { -+ return t < T(0) ? 
-t : t; -+ } else { -+ return t; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// C++17 operations -+// -+ -+// Greatest common divisor of two integers -+template ::value && -+ std::is_integral::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+gcd(T t, U u) { -+ while (true) { -+ if (t == 0) { return u; } -+ u %= t; -+ if (u == 0) { return t; } -+ t %= u; -+ } -+} -+ -+// Least common multiple of two integers -+template ::value && -+ std::is_integral::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+lcm(T const& t, U const& u) { -+ return (t / gcd(t,u)) * u; -+} -+ -+// -+// C++20 operations -+// -+ -+// Checks if a number is an integral power of two -+template -+CUTE_HOST_DEVICE constexpr -+bool -+has_single_bit(T x) { -+ return x != 0 && (x & (x - 1)) == 0; -+} -+ -+// Smallest number of bits needed to represent the given value -+// bit_width( 0b0000 ) = 0 -+// bit_width( 0b0001 ) = 1 -+// bit_width( 0b0010 ) = 2 -+// bit_width( 0b0011 ) = 2 -+// bit_width( 0b0100 ) = 3 -+// bit_width( 0b0101 ) = 3 -+// bit_width( 0b0110 ) = 3 -+// bit_width( 0b0111 ) = 3 -+template -+CUTE_HOST_DEVICE constexpr -+T -+bit_width(T x) { -+ static_assert(std::is_unsigned::value, "Only to be used for unsigned types."); -+ constexpr int N = (std::numeric_limits::digits == 64 ? 6 : -+ (std::numeric_limits::digits == 32 ? 5 : -+ (std::numeric_limits::digits == 16 ? 4 : -+ (std::numeric_limits::digits == 8 ? 3 : (assert(false),0))))); -+ T r = 0; -+ for (int i = N - 1; i >= 0; --i) { -+ T shift = (x > ((T(1) << (T(1) << i))-1)) << i; -+ x >>= shift; -+ r |= shift; -+ } -+ return r + (x != 0); -+} -+ -+// Smallest integral power of two not less than the given value -+// bit_ceil( 0b00000000 ) = 0b00000001 -+// bit_ceil( 0b00000001 ) = 0b00000001 -+// bit_ceil( 0b00000010 ) = 0b00000010 -+// bit_ceil( 0b00000011 ) = 0b00000100 -+// bit_ceil( 0b00000100 ) = 0b00000100 -+// bit_ceil( 0b00000101 ) = 0b00001000 -+// bit_ceil( 0b00000110 ) = 0b00001000 -+// bit_ceil( 0b00000111 ) = 0b00001000 -+// bit_ceil( 0b00001000 ) = 0b00001000 -+// bit_ceil( 0b00001001 ) = 0b00010000 -+template -+CUTE_HOST_DEVICE constexpr -+T -+bit_ceil(T x) { -+ return x == 0 ? T(1) : (T(1) << bit_width(x - 1)); -+} -+ -+// Largest integral power of two not greater than the given value -+// bit_floor( 0b00000000 ) = 0b00000000 -+// bit_floor( 0b00000001 ) = 0b00000001 -+// bit_floor( 0b00000010 ) = 0b00000010 -+// bit_floor( 0b00000011 ) = 0b00000010 -+// bit_floor( 0b00000100 ) = 0b00000100 -+// bit_floor( 0b00000101 ) = 0b00000100 -+// bit_floor( 0b00000110 ) = 0b00000100 -+// bit_floor( 0b00000111 ) = 0b00000100 -+// bit_floor( 0b00001000 ) = 0b00001000 -+// bit_floor( 0b00001001 ) = 0b00001000 -+template -+CUTE_HOST_DEVICE constexpr -+T -+bit_floor(T x) { -+ return x == 0 ? 0 : (T(1) << (bit_width(x) - 1)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr T rotl(T x, int s); -+template -+CUTE_HOST_DEVICE constexpr T rotr(T x, int s); -+ -+// Computes the result of circular bitwise left-rotation -+template -+CUTE_HOST_DEVICE constexpr -+T -+rotl(T x, int s) { -+ constexpr int N = std::numeric_limits::digits; -+ return s == 0 ? x : s > 0 ? (x << s) | (x >> (N - s)) : rotr(x, -s); -+} -+ -+// Computes the result of circular bitwise right-rotation -+template -+CUTE_HOST_DEVICE constexpr -+T -+rotr(T x, int s) { -+ constexpr int N = std::numeric_limits::digits; -+ return s == 0 ? x : s > 0 ? 
(x >> s) | (x << (N - s)) : rotl(x, -s); -+} -+ -+// Counts the number of consecutive 0 bits, starting from the most significant bit -+// countl_zero( 0b00000000 ) = 8 -+// countl_zero( 0b11111111 ) = 0 -+// countl_zero( 0b00011100 ) = 3 -+template -+CUTE_HOST_DEVICE constexpr -+T -+countl_zero(T x) { -+ return std::numeric_limits::digits - bit_width(x); -+} -+ -+// Counts the number of consecutive 1 bits, starting from the most significant bit -+// countl_one( 0b00000000 ) = 0 -+// countl_one( 0b11111111 ) = 8 -+// countl_one( 0b11100011 ) = 3 -+template -+CUTE_HOST_DEVICE constexpr -+T -+countl_one(T x) { -+ return countl_zero(~x); -+} -+ -+// Counts the number of consecutive 0 bits, starting from the least significant bit -+// countr_zero( 0b00000000 ) = 8 -+// countr_zero( 0b11111111 ) = 0 -+// countr_zero( 0b00011100 ) = 2 -+template -+CUTE_HOST_DEVICE constexpr -+T -+countr_zero(T x) { -+ return x == 0 ? std::numeric_limits::digits : bit_width(T(x & T(-x))) - 1; // bit_width of the LSB -+} -+ -+// Counts the number of consecutive 1 bits, starting from the least significant bit -+// countr_one( 0b00000000 ) = 0 -+// countr_one( 0b11111111 ) = 8 -+// countr_one( 0b11100011 ) = 2 -+template -+CUTE_HOST_DEVICE constexpr -+T -+countr_one(T x) { -+ return countr_zero(~x); -+} -+ -+// Counts the number of 1 bits in an unsigned integer -+// popcount( 0b00000000 ) = 0 -+// popcount( 0b11111111 ) = 8 -+// popcount( 0b00011101 ) = 4 -+template -+CUTE_HOST_DEVICE constexpr -+int -+popcount(T x) { -+ int c = 0; -+ while (x) { -+ ++c; -+ x &= x - 1; // clear the least significant bit set -+ } -+ return c; -+} -+ -+// -+// Custom operations -+// -+ -+// Computes the result of bitwise left-shift -+template -+CUTE_HOST_DEVICE constexpr -+T -+shiftl(T x, int s) { -+ return s >= 0 ? (x << s) : (x >> -s); -+} -+ -+// Computes the result of bitwise right-shift -+template -+CUTE_HOST_DEVICE constexpr -+T -+shiftr(T x, int s) { -+ return s >= 0 ? (x >> s) : (x << -s); -+} -+ -+// Returns 1 if x > 0, -1 if x < 0, and 0 if x is zero. -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+int -+signum(T const& x) { -+ return T(0) < x; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+int -+signum(T const& x) { -+ return (T(0) < x) - (x < T(0)); -+} -+ -+// Safe divide -+// @pre t % u == 0 -+// @result t / u -+template ::value && -+ std::is_integral::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+safe_div(T const& t, U const& u) { -+ //assert(t % u == 0); -+ return t / u; -+} -+ -+} // namespace cute -diff --git a/3rdparty/cutlass/include/cute/numeric/real.hpp b/3rdparty/cutlass/include/cute/numeric/real.hpp -new file mode 100644 -index 0000000..d85e304 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/real.hpp -@@ -0,0 +1,56 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+namespace cute -+{ -+ -+/// Generic fused multiply-add -+template -+CUTE_HOST_DEVICE constexpr -+void -+fma(D& d, A const& a, B const& b, C const& c) -+{ -+ d = a * b + c; -+} -+ -+/// Fused multiply-add for triplets -+template -+CUTE_HOST_DEVICE constexpr -+void -+fma(A const& a, B const& b, C& c) -+{ -+ return fma(c, a, b, c); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/numeric/tfloat.hpp b/3rdparty/cutlass/include/cute/numeric/tfloat.hpp -new file mode 100644 -index 0000000..bb68b70 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/tfloat.hpp -@@ -0,0 +1,51 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
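
Editor's note: a compile-time arithmetic sketch, not part of the patch, tying together the constant<> operators, the named math functions, and the generic fma from the integral_constant.hpp, math.hpp, and real.hpp hunks above; include paths are assumptions.

#include <cute/numeric/integral_constant.hpp>   // assumed include paths
#include <cute/numeric/math.hpp>
#include <cute/numeric/real.hpp>
#include <cassert>

int main() {
  using namespace cute;
  // Operators on constant<> stay in the type system: _4 * _8 is Int<32>.
  auto n = _4{} * _8{};
  static_assert(decltype(n)::value == 32, "static product");
  static_assert(decltype(gcd(_12{}, _8{}))::value == 4, "static gcd");
  // Mixed static/dynamic special case: multiplying by _0 collapses to _0
  // no matter what the runtime int holds.
  int k = 7;
  auto z = _0{} * k;
  static_assert(is_static<decltype(z)>::value, "still a compile-time zero");
  // Generic fused multiply-add from real.hpp: d = a*b + c.
  double d = 0.0;
  fma(d, 2.0, 3.0, 1.0);
  assert(d == 7.0);
  return 0;
}
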
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+namespace cute { -+ -+using cutlass::tfloat32_t; -+ -+// -+// Display utilities -+// -+ -+CUTE_HOST std::ostream& operator<<(std::ostream& os, tfloat32_t const& v) -+{ -+ return os << float(v); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/numeric/uint128.hpp b/3rdparty/cutlass/include/cute/numeric/uint128.hpp -new file mode 100644 -index 0000000..fb02441 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/uint128.hpp -@@ -0,0 +1,259 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#include -+#include -+#include -+#include -+#endif -+ -+#include -+ -+/// Optionally enable GCC's built-in type -+#if defined(__x86_64) && !defined(__CUDA_ARCH__) -+# if defined(__GNUC__) && 0 -+# define CUTE_UINT128_NATIVE -+# elif defined(_MSC_VER) -+# define CUTE_INT128_ARITHMETIC -+# include -+# endif -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cute { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///! Unsigned 128b integer type -+struct alignas(16) uint128_t -+{ -+ /// Size of one part of the uint's storage in bits -+ static constexpr int storage_bits_ = 64; -+ -+ struct hilo -+ { -+ uint64_t lo; -+ uint64_t hi; -+ }; -+ -+ // Use a union to store either low and high parts or, if present, a built-in 128b integer type. 
-+ union -+ { -+ struct hilo hilo_; -+ -+#if defined(CUTE_UINT128_NATIVE) -+ unsigned __int128 native; -+#endif // defined(CUTE_UINT128_NATIVE) -+ }; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTE_HOST_DEVICE constexpr -+ uint128_t() : hilo_{0, 0} {} -+ -+ /// Constructor from uint64 -+ CUTE_HOST_DEVICE constexpr -+ uint128_t(uint64_t lo_) : hilo_{lo_, 0} {} -+ -+ /// Constructor from two 64b unsigned integers -+ CUTE_HOST_DEVICE constexpr -+ uint128_t(uint64_t lo_, uint64_t hi_) : hilo_{lo_, hi_} {} -+ -+ /// Optional constructor from native value -+#if defined(CUTE_UINT128_NATIVE) -+ uint128_t(unsigned __int128 value) : native(value) { } -+#endif -+ -+ /// Lossily cast to uint64 -+ CUTE_HOST_DEVICE constexpr -+ explicit operator uint64_t() const -+ { -+ return hilo_.lo; -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ static void exception() -+ { -+ //static_assert(sizeof(Dummy) == 0, "Not implemented exception!"); -+ //abort(); -+ //printf("uint128 not implemented!\n"); -+ } -+ -+ /// Add -+ CUTE_HOST_DEVICE constexpr -+ uint128_t operator+(uint128_t const& rhs) const -+ { -+ uint128_t y; -+#if defined(CUTE_UINT128_NATIVE) -+ y.native = native + rhs.native; -+#else -+ y.hilo_.lo = hilo_.lo + rhs.hilo_.lo; -+ y.hilo_.hi = hilo_.hi + rhs.hilo_.hi + (!y.hilo_.lo && (rhs.hilo_.lo)); -+#endif -+ return y; -+ } -+ -+ /// Subtract -+ CUTE_HOST_DEVICE constexpr -+ uint128_t operator-(uint128_t const& rhs) const -+ { -+ uint128_t y; -+#if defined(CUTE_UINT128_NATIVE) -+ y.native = native - rhs.native; -+#else -+ y.hilo_.lo = hilo_.lo - rhs.hilo_.lo; -+ y.hilo_.hi = hilo_.hi - rhs.hilo_.hi - (rhs.hilo_.lo && y.hilo_.lo > hilo_.lo); -+#endif -+ return y; -+ } -+ -+ /// Multiply by unsigned 64b integer yielding 128b integer -+ CUTE_HOST_DEVICE constexpr -+ uint128_t operator*(uint64_t const& rhs) const -+ { -+ uint128_t y; -+#if defined(CUTE_UINT128_NATIVE) -+ y.native = native * rhs; -+#elif defined(CUTE_INT128_ARITHMETIC) -+ // Multiply by the low part -+ y.hilo_.lo = _umul128(hilo_.lo, rhs, &y.hilo_.hi); -+ -+ // Add the high part and ignore the overflow -+ uint64_t overflow; -+ y.hilo_.hi += _umul128(hilo_.hi, rhs, &overflow); -+#else -+ exception(); -+#endif -+ return y; -+ } -+ -+ /// Divide 128b operation by 64b operation yielding a 64b quotient -+ CUTE_HOST_DEVICE constexpr -+ uint64_t operator/(uint64_t const& divisor) const -+ { -+ uint64_t quotient = 0; -+#if defined(CUTE_UINT128_NATIVE) -+ quotient = uint64_t(native / divisor); -+#elif defined(CUTE_INT128_ARITHMETIC) -+ // implemented using MSVC's arithmetic intrinsics -+ uint64_t remainder = 0; -+ quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); -+#else -+ exception(); -+#endif -+ return quotient; -+ } -+ -+ /// Divide 128b operation by 64b operation yielding a 64b quotient -+ CUTE_HOST_DEVICE constexpr -+ uint64_t operator%(uint64_t const& divisor) const -+ { -+ uint64_t remainder = 0; -+#if defined(CUTE_UINT128_NATIVE) -+ remainder = uint64_t(native % divisor); -+#elif defined(CUTE_INT128_ARITHMETIC) -+ // implemented using MSVC's arithmetic intrinsics -+ (void)_udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); -+#else -+ exception(); -+#endif -+ return remainder; -+ } -+ -+ /// Computes the quotient and remainder in a single method. 
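// Usage sketch (illustrative; assumes the native __int128 or MSVC _udiv128 path
// above is enabled -- the generic fallback only calls exception() and yields a
// zero quotient):
//
//   uint128_t x(0, 1);              // represents 2^64
//   uint64_t  r = 0;
//   uint64_t  q = x.divmod(r, 3);   // q == 0x5555555555555555, r == 1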
-+ CUTE_HOST_DEVICE constexpr -+ uint64_t divmod(uint64_t &remainder, uint64_t divisor) const -+ { -+ uint64_t quotient = 0; -+#if defined(CUTE_UINT128_NATIVE) -+ quotient = uint64_t(native / divisor); -+ remainder = uint64_t(native % divisor); -+#elif defined(CUTE_INT128_ARITHMETIC) -+ // implemented using MSVC's arithmetic intrinsics -+ quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); -+#else -+ exception(); -+#endif -+ return quotient; -+ } -+ -+ /// Left-shifts a 128b unsigned integer -+ CUTE_HOST_DEVICE constexpr -+ uint128_t operator<<(int sh) const -+ { -+ if (sh == 0) { -+ return *this; -+ } -+ else if (sh >= storage_bits_) { -+ return uint128_t(0, hilo_.lo << (sh - storage_bits_)); -+ } -+ else { -+ return uint128_t( -+ (hilo_.lo << sh), -+ (hilo_.hi << sh) | uint64_t(hilo_.lo >> (storage_bits_ - sh)) -+ ); -+ } -+ } -+ -+ /// Right-shifts a 128b unsigned integer -+ CUTE_HOST_DEVICE constexpr -+ uint128_t operator>>(int sh) const -+ { -+ if (sh == 0) { -+ return *this; -+ } -+ else if (sh >= storage_bits_) { -+ return uint128_t((hilo_.hi >> (sh - storage_bits_)), 0); -+ } -+ else { -+ return uint128_t( -+ (hilo_.lo >> sh) | (hilo_.hi << (storage_bits_ - sh)), -+ (hilo_.hi >> sh) -+ ); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cute -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cute/pointer.hpp b/3rdparty/cutlass/include/cute/pointer.hpp -new file mode 100644 -index 0000000..40ce5d1 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/pointer.hpp -@@ -0,0 +1,322 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+#include -+ -+namespace cute -+{ -+ -+// -+// has_dereference to determine if a type is a pointer concept -+// -+ -+template -+struct has_dereference : std::false_type { -+}; -+ -+template -+struct has_dereference())>> : std::true_type { -+}; -+ -+// -+// Pointer categories -+// -+ -+template -+struct is_gmem : false_type {}; -+ -+template -+struct is_smem : false_type {}; -+ -+// Anything that is not gmem or smem is rmem -+template -+struct is_rmem : bool_constant< not (is_gmem::value || is_smem::value)> {}; -+ -+// -+// A very simplified wrapper for pointers -- use for constructing tagged pointers -+// -+template -+struct device_ptr -+{ -+ using value_type = T; -+ -+ CUTE_HOST_DEVICE constexpr -+ device_ptr(T* ptr) : ptr_(ptr) {} -+ -+ CUTE_HOST_DEVICE constexpr -+ T* get() const { return ptr_; } -+ -+ CUTE_HOST_DEVICE constexpr -+ T& operator*() const { return *ptr_; } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ T& operator[](Index const& i) const { return ptr_[i]; } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ DerivedType operator+(Index const& i) const { return {ptr_ + i}; } -+ -+ CUTE_HOST_DEVICE constexpr friend -+ std::ptrdiff_t operator-(device_ptr const& a, -+ device_ptr const& b) { -+ return a.ptr_ - b.ptr_; -+ } -+ -+ T* ptr_; -+}; -+ -+// -+// gmem_ptr -+// -+ -+template -+struct gmem_ptr : device_ptr> { -+ using device_ptr>::device_ptr; -+}; -+ -+template -+CUTE_HOST_DEVICE constexpr -+gmem_ptr -+make_gmem_ptr(T* ptr) { -+ return {ptr}; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+gmem_ptr -+make_gmem_ptr(void* ptr) { -+ return {reinterpret_cast(ptr)}; -+} -+ -+template -+struct is_gmem> : true_type {}; -+ -+// -+// smem_ptr -+// -+ -+template -+struct smem_ptr : device_ptr> { -+ using device_ptr>::device_ptr; -+}; -+ -+template -+CUTE_HOST_DEVICE constexpr -+smem_ptr -+make_smem_ptr(T* ptr) { -+ return {ptr}; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+smem_ptr -+make_smem_ptr(void* ptr) { -+ return {reinterpret_cast(ptr)}; -+} -+ -+template -+struct is_smem> : true_type {}; -+ -+// -+// rmem_ptr -+// -+ -+template -+struct rmem_ptr : device_ptr> { -+ using device_ptr>::device_ptr; -+}; -+ -+template -+CUTE_HOST_DEVICE constexpr -+rmem_ptr -+make_rmem_ptr(T* ptr) { -+ return {ptr}; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+rmem_ptr -+make_rmem_ptr(void* ptr) { -+ return {reinterpret_cast(ptr)}; -+} -+ -+template -+struct is_rmem> : true_type {}; -+ -+// -+// counting iterator -- quick and dirty -+// -+ -+struct counting -+{ -+ using index_type = int; -+ using value_type = index_type; -+ -+ CUTE_HOST_DEVICE constexpr -+ counting() : n_(0) {} -+ CUTE_HOST_DEVICE constexpr -+ counting(index_type const& n) : n_(n) {} -+ -+ CUTE_HOST_DEVICE constexpr -+ index_type operator[](index_type const& i) const { return n_ + i; } -+ -+ CUTE_HOST_DEVICE constexpr -+ index_type const& operator*() const { return n_; } -+ -+ CUTE_HOST_DEVICE constexpr -+ counting operator+(index_type const& i) const { return {n_ + i}; } -+ CUTE_HOST_DEVICE constexpr -+ counting& operator++() { ++n_; return *this; } -+ -+ CUTE_HOST_DEVICE constexpr -+ bool operator==(counting const& other) const { return n_ == other.n_; } -+ CUTE_HOST_DEVICE constexpr -+ bool operator!=(counting const& other) const { return n_ != other.n_; } -+ -+ CUTE_HOST_DEVICE constexpr -+ bool operator< (counting const& other) const { return n_ < other.n_; 
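// Usage sketch for the pointer tags above (illustrative): make_gmem_ptr /
// make_smem_ptr only wrap a raw pointer so that traits such as is_gmem / is_smem
// can dispatch on the memory space at compile time.
//
//   float* g_raw = /* pointer into global memory (hypothetical) */ nullptr;
//   float* s_raw = /* pointer into shared memory (hypothetical) */ nullptr;
//   auto g = make_gmem_ptr(g_raw);   // gmem_ptr<float>
//   auto s = make_smem_ptr(s_raw);   // smem_ptr<float>
//   static_assert(is_gmem<decltype(g)>::value && is_smem<decltype(s)>::value, "");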
} -+ -+ index_type n_; -+}; -+ -+// -+// recast -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(T* ptr) { -+ return reinterpret_cast(ptr); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(T const* ptr) { -+ return reinterpret_cast(ptr); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(gmem_ptr const& ptr) { -+ return make_gmem_ptr(recast(ptr.ptr_)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(gmem_ptr const& ptr) { -+ return make_gmem_ptr(recast(ptr.ptr_)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(smem_ptr const& ptr) { -+ return make_smem_ptr(recast(ptr.ptr_)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(smem_ptr const& ptr) { -+ return make_smem_ptr(recast(ptr.ptr_)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(rmem_ptr const& ptr) { -+ return make_rmem_ptr(recast(ptr.ptr_)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(rmem_ptr const& ptr) { -+ return make_rmem_ptr(recast(ptr.ptr_)); -+} -+ -+// -+// Display utilities -+// -+ -+template -+CUTE_HOST_DEVICE void print(T const* const ptr) -+{ -+ printf("raw_ptr_%db(%p)", int(8*sizeof(T)), ptr); -+} -+ -+template -+CUTE_HOST_DEVICE void print(gmem_ptr const& ptr) -+{ -+ printf("gmem_ptr_%db(%p)", int(8*sizeof(T)), ptr.get()); -+} -+ -+template -+CUTE_HOST_DEVICE void print(smem_ptr const& ptr) -+{ -+ printf("smem_ptr_%db(%p)", int(8*sizeof(T)), ptr.get()); -+} -+ -+template -+CUTE_HOST_DEVICE void print(rmem_ptr const& ptr) -+{ -+ printf("rmem_ptr_%db(%p)", int(8*sizeof(T)), ptr.get()); -+} -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, gmem_ptr const& ptr) -+{ -+ return os << "gmem_ptr_" << int(8*sizeof(T)) << "b"; -+} -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr const& ptr) -+{ -+ return os << "smem_ptr_" << int(8*sizeof(T)) << "b"; -+} -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, rmem_ptr const& ptr) -+{ -+ return os << "rmem_ptr_" << int(8*sizeof(T)) << "b"; -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/stride.hpp b/3rdparty/cutlass/include/cute/stride.hpp -new file mode 100644 -index 0000000..5fb0da8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/stride.hpp -@@ -0,0 +1,411 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+/** crd2idx maps a coordinate within to an index -+ * This is computed as follows: -+ * [coord, shape, and stride are all integers => step forward by stride] -+ * op(c, s, d) => c * d -+ * [coord is integer, shape and stride are tuple => divmod coord for each mode] -+ * op(c, (s,S), (d,D)) => op(c % prod(s), s, d) + op(c / prod(s), (S), (D)) -+ * [coord, shape, and stride are all tuples => consider each mode independently] -+ * op((c,C), (s,S), (d,D)) => op(c, s, d) + op((C), (S), (D)) -+ */ -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+crd2idx(Coord const& coord, -+ Shape const& shape, -+ Stride const& stride); -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+crd2idx_ttt(Coord const& coord, -+ Shape const& shape, -+ Stride const& stride, seq) -+{ -+ return (... 
+ crd2idx(get(coord), get(shape), get(stride))); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+crd2idx_itt(CInt const& coord, -+ STuple const& shape, -+ DTuple const& stride, seq) -+{ -+ if constexpr (sizeof...(Is) == 0) { // Avoid recursion and mod on single/last iter -+ return crd2idx(coord, get(shape), get(stride)); -+ } else { // General case -+ return crd2idx(coord % product(get(shape)), get(shape), get(stride)) -+ + crd2idx_itt(coord / product(get(shape)), shape, stride, seq{}); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+crd2idx(Coord const& coord, -+ Shape const& shape, -+ Stride const& stride) -+{ -+ if constexpr (is_tuple::value) { -+ if constexpr (is_tuple::value) { // tuple tuple tuple -+ static_assert(tuple_size::value == tuple_size< Shape>::value, "Mismatched Ranks"); -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); -+ return detail::crd2idx_ttt(coord, shape, stride, tuple_seq{}); -+ } else { // tuple "int" "int" -+ static_assert(sizeof(Coord) == 0, "Invalid parameters"); -+ } -+ } else { -+ if constexpr (is_tuple::value) { // "int" tuple tuple -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); -+ return detail::crd2idx_itt(coord, shape, stride, tuple_seq{}); -+ } else { // "int" "int" "int" -+ return coord * stride; -+ } -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// If we know Stride is default [CompactColMajor], then we can take shortcuts -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+crd2idx_horner(CTuple const& coord, -+ STuple const& shape, seq) -+{ -+ if constexpr (sizeof...(Is) == 0) { // No recursion on single/last iter -+ return get(coord); -+ } else { // General case -+ return get(coord) + get(shape) * crd2idx_horner(coord, shape, seq{}); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+crd2idx(Coord const& coord, -+ Shape const& shape) -+{ -+ static_assert(decltype(congruent(coord,shape))::value, "Mismatched Ranks"); -+ if constexpr (is_tuple::value) { -+ // Flatten and apply Horner's method -+ auto flat_coord = flatten(coord); -+ auto flat_shape = flatten(shape); -+ return detail::crd2idx_horner(flat_coord, flat_shape, tuple_seq{}); -+ } else { -+ return coord; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+/** idx2crd splits an index to a coordinate within . -+ * -+ * This is computed as follows: -+ * [index, shape, and stride are all integers => determine 1D coord] -+ * op(i, s, d) => (i / d) % s -+ * [index is integer, shape and stride are tuple => determine component for each mode] -+ * op(i, (s,S), (d,D)) => (op(i, s, d), op(i, S, D)...) -+ * [index, shape, and stride are all tuples => consider each mode independently] -+ * op((i,I), (s,S), (d,D)) => (op(i, s, d), op((I), (S), (D))) -+ * -+ * NOTE: This only works for compact shape+stride layouts. 
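 * Worked example (illustrative, column-major): with shape (3,4) and stride (1,3),
 *   crd2idx((1,2), (3,4), (1,3)) == 1*1 + 2*3 == 7
 * and the inverse below recovers the coordinate:
 *   idx2crd(7, (3,4), (1,3)) == ((7/1)%3, (7/3)%4) == (1,2)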
A more general version would -+ * apply to all surjective layouts -+ */ -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+idx2crd(Index const& idx, -+ Shape const& shape, -+ Stride const& stride) -+{ -+ if constexpr (is_tuple::value) { -+ if constexpr (is_tuple::value) { // tuple tuple tuple -+ static_assert(tuple_size::value == tuple_size< Shape>::value, "Mismatched Ranks"); -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); -+ return transform(idx, shape, stride, [](auto const& i, auto const& s, auto const& d){ return idx2crd(i,s,d); }); -+ } else { // tuple "int" "int" -+ static_assert(sizeof(Index) == 0, "Invalid parameters"); -+ } -+ } else { -+ if constexpr (is_tuple::value) { -+ if constexpr (is_tuple::value) { // "int" tuple tuple -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); -+ return transform(shape, stride, [&](auto const& s, auto const& d){ return idx2crd(idx,s,d); }); -+ } else { // "int" tuple "int" -+ return transform(shape, compact_col_major(shape, stride), [&](auto const& s, auto const& d){ return idx2crd(idx,s,d); }); -+ } -+ } else { // "int" "int" "int" -+ return (idx / stride) % shape; -+ } -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// If we know Stride is default [CompactColMajor], then we can take shortcuts -+// -+ -+//(idx / 1) % s0 -+//(idx / s0) % s1 -+//(idx / (s0 * s1)) % s2 -+//... -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+idx2crd(Index const& idx, -+ Shape const& shape) -+{ -+ if constexpr (is_tuple::value) { -+ if constexpr (is_tuple::value) { // tuple tuple -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); -+ return transform(idx, shape, [](auto const& i, auto const& s) { return idx2crd(i,s); }); -+ } else { // tuple "int" -+ static_assert(sizeof(Index) == 0, "Invalid parameters"); -+ } -+ } else { -+ if constexpr (is_tuple::value) { // "int" tuple -+ return idx2crd(idx, shape, compact_col_major(shape)); -+ } else { // "int" "int" -+ return idx; -+ } -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// crd2crd -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+crd2crd(Coord const& coord, -+ SShape const& src_shape, -+ DShape const& dst_shape) -+{ -+ if constexpr (is_tuple::value && is_tuple::value && is_tuple::value) { -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); -+ return transform(coord, src_shape, dst_shape, [](auto const& c, auto const& s, auto const& d) { return crd2crd(c,s,d); }); -+ } else { -+ // assert(size(src_shape) == size(dst_shape)) -+ return idx2crd(crd2idx(coord, src_shape), dst_shape); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Compact Major -+// -+ -+// General tag for common layouts and dispatching -+struct GenColMajor {}; -+struct GenRowMajor {}; -+ -+template , class Major = GenColMajor> -+CUTE_HOST_DEVICE constexpr -+auto -+compact_major(Shape const& shape, -+ Current const& current = {}, -+ Major const& major = {}); -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+compact_major_ti(Shape const& shape, -+ Current const& current, -+ GenColMajor const& major, seq) -+{ -+ return cute::make_tuple(compact_major(get(shape), current * product<0,Is>(shape), major)...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+compact_major_ti(Shape const& shape, -+ Current const& current, -+ GenRowMajor const& major, seq) -+{ -+ constexpr int E = tuple_size::value; -+ return 
cute::make_tuple(compact_major(get(shape), current * product(shape), major)...); -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+compact_major(Shape const& shape, -+ Current const& current, -+ Major const& major) -+{ -+ if constexpr (is_tuple::value) { -+ if constexpr (is_tuple::value) { // tuple tuple -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); -+ return transform(shape, current, [&](auto const& s, auto const& c){ return compact_major(s,c,major); }); -+ } else { // tuple int -+ return detail::compact_major_ti(shape, current, major, tuple_seq{}); -+ } -+ } else { -+ if constexpr (is_tuple::value) { // int tuple -+ static_assert(sizeof(Shape) == 0, "Invalid parameters"); -+ } else { // int int -+ if constexpr (is_constant<1, Shape>::value) { -+ return Int<0>{}; // If current is dynamic, this could save a reg -+ } else { -+ return current; -+ } -+ } -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Compact Col Major -+// -+ -+template > -+CUTE_HOST_DEVICE constexpr -+auto -+compact_col_major(Shape const& shape, -+ Current const& current = {}) -+{ -+ return compact_major(shape, current, GenColMajor{}); -+} -+ -+template -+using ColMajor = decltype(compact_col_major(std::declval())); -+ -+// -+// Compact Row Major -+// -+ -+template > -+CUTE_HOST_DEVICE constexpr -+auto -+compact_row_major(Shape const& shape, -+ Current const& current = {}) -+{ -+ return compact_major(shape, current, GenRowMajor{}); -+} -+ -+template -+using RowMajor = decltype(compact_row_major(std::declval())); -+ -+// -+// Compact Order -- compute a compact stride based on an ordering of the modes -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+compact_order(Shape const& shape, Order const& order, -+ OrigShape const& orig_shape, OrigOrder const& orig_order) -+{ -+ if constexpr (is_tuple::value) { -+ return transform(shape, order, [&](auto const& x, auto const& y) { return compact_order(x, y, orig_shape, orig_order); }); -+ } else { -+ auto d = product(transform(orig_shape, orig_order, -+ [&](auto const& s, auto const& o) { -+ return conditional_return(o < order, product(s), Int<1>{}); -+ })); -+ return compact_col_major(shape, d); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+compact_order(Shape const& shape, Order const& order) -+{ -+ static_assert(is_congruent::value, "Need congruence of shape and order."); -+ return detail::compact_order(shape, order, flatten_to_tuple(shape), flatten_to_tuple(order)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+compact_order(Shape const& shape, GenColMajor const& major) -+{ -+ return compact_major(shape, Int<1>{}, major); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+compact_order(Shape const& shape, GenRowMajor const& major) -+{ -+ return compact_major(shape, Int<1>{}, major); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/swizzle.hpp b/3rdparty/cutlass/include/cute/swizzle.hpp -new file mode 100644 -index 0000000..0a13e55 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/swizzle.hpp -@@ -0,0 +1,497 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
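// Worked example for the compact_major helpers of stride.hpp above (illustrative):
//   compact_col_major(make_shape(Int<3>{}, Int<4>{}, Int<5>{}))  ->  strides (_1, _3, _12)
//   compact_row_major(make_shape(Int<3>{}, Int<4>{}, Int<5>{}))  ->  strides (_20, _5, _1)
// i.e. each stride is the exclusive product of the shape modes before it
// (column-major) or after it (row-major).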
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+#include -+#include -+#include -+ -+namespace cute -+{ -+ -+// A generic Swizzle functor -+/* 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx -+ * ^--^ MBase is the number of least-sig bits to keep constant -+ * ^-^ ^-^ BBits is the number of bits in the mask -+ * ^---------^ SShift is the distance to shift the YYY mask -+ * (pos shifts YYY to the right, neg shifts YYY to the left) -+ * -+ * e.g. Given -+ * 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx -+ * the result is -+ * 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY -+ */ -+template -+struct Swizzle -+{ -+ static constexpr int num_bits = BBits; -+ static constexpr int num_base = MBase; -+ static constexpr int num_shft = SShift; -+ -+ static_assert(num_base >= 0, "MBase must be positive."); -+ static_assert(num_bits >= 0, "BBits must be positive."); -+ static_assert(abs(num_shft) >= num_bits, "abs(SShift) must be more than BBits."); -+ -+ // using 'int' type here to avoid unintentially casting to unsigned... unsure. 
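// Worked example (illustrative): Swizzle<2,0,2> XORs bits [3:2] into bits [1:0],
// so Swizzle<2,0,2>{}(0b1110) == 0b1110 ^ 0b0011 == 0b1101.
// A larger instance such as Swizzle<3,3,3> likewise XORs bits [8:6] into
// bits [5:3], leaving the low MBase = 3 bits untouched.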
-+ using bit_msk = cute::constant; -+ using yyy_msk = cute::constant; -+ using zzz_msk = cute::constant; -+ using msk_sft = cute::constant; -+ -+ static constexpr uint32_t swizzle_code = uint32_t(yyy_msk{} | zzz_msk{}); -+ -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ apply(Offset const& offset) -+ { -+ return offset ^ shiftr(offset & yyy_msk{}, msk_sft{}); // ZZZ ^= YYY -+ } -+ -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ auto -+ operator()(Offset const& offset) const -+ { -+ return apply(offset); -+ } -+}; -+ -+// Translation for legacy SwizzleXor -+// TODO: Deprecate -+template -+using SwizzleXor = Swizzle; -+ -+// -+// make_swizzle<0b1000, 0b0100>() -> Swizzle<1,2,1> -+// make_swizzle<0b11000000, 0b00000110>() -> Swizzle<2,1,5> -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_swizzle() -+{ -+ constexpr uint32_t BZ = popcount(Y); // Number of swizzle bits -+ constexpr uint32_t BY = popcount(Z); // Number of swizzle bits -+ static_assert(BZ == BY, "Number of bits in Y and Z don't match"); -+ constexpr uint32_t TZ_Y = countr_zero(Y); // Number of trailing zeros in Y -+ constexpr uint32_t TZ_Z = countr_zero(Z); // Number of trailing zeros in Z -+ constexpr uint32_t M = cute::min(TZ_Y, TZ_Z) % 32; -+ constexpr int32_t S = int32_t(TZ_Y) - int32_t(TZ_Z); // Difference in trailing zeros -+ static_assert((Y | Z) == Swizzle::swizzle_code, "Something went wrong."); -+ return Swizzle{}; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(Swizzle, Swizzle) -+{ -+ static_assert(S0 == S1, "Can only merge swizzles of the same shift."); -+ constexpr uint32_t Y = Swizzle::yyy_msk::value ^ Swizzle::yyy_msk::value; -+ constexpr uint32_t Z = Swizzle::zzz_msk::value ^ Swizzle::zzz_msk::value; -+ return make_swizzle(); -+ -+ //return ComposedFn, Swizzle>{}; -+} -+ -+// -+// Upcast and Downcast -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+upcast(Swizzle const& swizzle) -+{ -+ static_assert(has_single_bit(N), "N must be a power of two"); -+ constexpr int log2_n = bit_width(uint32_t(N)) - 1; -+ constexpr int NewM = M - log2_n; -+ if constexpr (NewM >= 0) { -+ return Swizzle{}; -+ } else { -+ return Swizzle{}; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+downcast(Swizzle const& swizzle) -+{ -+ static_assert(has_single_bit(N), "N must be a power of two"); -+ constexpr int log2_n = bit_width(uint32_t(N)) - 1; -+ return Swizzle{}; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(Swizzle const& swizzle) -+{ -+ if constexpr (sizeof_bits::value == sizeof_bits::value) { -+ return swizzle; -+ } else if constexpr (sizeof_bits::value > sizeof_bits::value) { -+ static_assert(sizeof_bits::value % sizeof_bits::value == 0, "NewType must be a multiple of OldType"); -+ return upcast::value/sizeof_bits::value>(swizzle); -+ } else if constexpr (sizeof_bits::value < sizeof_bits::value) { -+ static_assert(sizeof_bits::value % sizeof_bits::value == 0, "NewType must be a divisor of OldType"); -+ return downcast::value/sizeof_bits::value>(swizzle); -+ } -+} -+ -+// -+// Utility for slicing and swizzle "offsets" -+// -+ -+// For swizzle functions, it is often needed to keep track of which bits are -+// consumed and which bits are free. Furthermore, it is useful to know whether -+// each of these bits is known statically or dynamically. -+ -+// MixedBits is an integer class where some bits are known statically and some -+// bits are known dynamically. 
These sets of bits are disjoint and it is known -+// statically which bits are known dynamically. -+ -+// MixedBits can only be manipulated through bitwise operations -+ -+// Abstract value: StaticInt | (dynamic_int_ & StaticFlags) -+template // 0: static, 1: dynamic -+struct MixedBits -+{ -+ // Representation invariants -+ static_assert(StaticFlags != 0, "Should be at least one dynamic bit in MixedBits."); -+ static_assert((StaticInt & StaticFlags) == 0, "No static/dynamic overlap allowed in MixedBits."); -+ // assert((dynamic_int_ & ~F) == 0); -+ -+ DynamicType dynamic_int_; -+}; -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_mixed_bits(constant const&, DynamicType const& d, constant const&) -+{ -+ static_assert(is_integral::value); -+ if constexpr (is_static::value) { -+ static_assert((s & DynamicType::value & f) == 0, "No static/dynamic overlap allowed."); -+ return constant{} | (d & constant{}); // Just return a static int -+ } else if constexpr (f == 0) { -+ return constant{}; // Just return a static int -+ } else { -+ return MixedBits{d & f}; // MixedBits -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Explicit conversion for now -- consider casting on plus or minus -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+to_integral(MixedBits const& m) -+{ -+ //return S | (m.dynamic_int_ & F); -+ return S | m.dynamic_int_; -+} -+ -+// Any cute::is_integral -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+to_integral(I const& i) -+{ -+ return i; -+} -+ -+// -+// Operators -+// -+ -+// Equality -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator==(MixedBits const& m, constant const&) -+{ -+ return (S0 == (S1 & ~F0)) && (m.dynamic_int_ == (S1 & F0)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator==(constant const& s, MixedBits const& m) -+{ -+ return m == s; -+} -+ -+// Bitwise AND -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator&(MixedBits const& m0, MixedBits const& m1) -+{ -+ // Truth table for (S0,D0,F0) & (S1,D1,F1) -> (S,D,F) -+ // S0D0F0 | 0X0 | 001 | 011 | 1X0 | -+ // S1D1F1 -+ // 0X0 | 0X0 | 0X0 | 0X0 | 0X0 | -+ // 001 | 0X0 | 001 | 001 | 001 | -+ // 011 | 0X0 | 001 | 011 | 011 | -+ // 1X0 | 0X0 | 001 | 011 | 1X0 | -+ -+ return make_mixed_bits(constant{}, -+ //(S0 | m0.dynamic_int_) & (S1 | m1.dynamic_int_), -+ ((S1 & F0) & m0.dynamic_int_) | ((S0 & F1) & m1.dynamic_int_) | (m0.dynamic_int_ & m1.dynamic_int_), -+ constant{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator&(MixedBits const& m, constant const&) -+{ -+ return make_mixed_bits(constant{}, -+ m.dynamic_int_, -+ constant{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator&(constant const& s, MixedBits const& m) -+{ -+ return m & s; -+} -+ -+// Bitwise OR -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator|(MixedBits const& m0, MixedBits const& m1) -+{ -+ // Truth table for (S0,D0,F0) | (S1,D1,F1) -> (S,D,F) -+ // S0D0F0 | 0X0 | 001 | 011 | 1X0 | -+ // S1D1F1 -+ // 0X0 | 0X0 | 001 | 011 | 1X0 | -+ // 001 | 001 | 001 | 011 | 1X0 | -+ // 011 | 011 | 011 | 011 | 1X0 | -+ // 1X0 | 1X0 | 1X0 | 1X0 | 1X0 | -+ -+ return make_mixed_bits(constant{}, -+ ((~S1 & F0) & m0.dynamic_int_) | ((~S0 & F1) & m1.dynamic_int_), -+ constant{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator|(MixedBits const& m, constant const&) -+{ -+ return make_mixed_bits(constant{}, -+ m.dynamic_int_, -+ constant{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator|(constant const& s, MixedBits const& m) -+{ -+ return m | s; 
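// Illustrative reading of the abstract value above: make_mixed_bits(Int<4>{}, d, Int<3>{})
// denotes 4 | (d & 3), i.e. bit 2 is statically 1, bits 0-1 are taken from the
// runtime value d, and every other bit is statically 0.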
-+} -+ -+// Bitwise XOR -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator^(MixedBits const& m0, MixedBits const& m1) -+{ -+ // Truth table for (S0,D0,F0) ^ (S1,D1,F1) -> (S,D,F) -+ // S0D0F0 | 0X0 | 001 | 011 | 1X0 | -+ // S1D1F1 -+ // 0X0 | 0X0 | 001 | 011 | 1X0 | -+ // 001 | 001 | 001 | 011 | 011 | -+ // 011 | 011 | 011 | 001 | 001 | -+ // 1X0 | 1X0 | 011 | 001 | 0X0 | -+ -+ return make_mixed_bits(constant{}, -+ (S0 | m0.dynamic_int_) ^ (S1 | m1.dynamic_int_), -+ constant{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator^(MixedBits const& m, constant const&) -+{ -+ return make_mixed_bits(constant{}, -+ (S0 | m.dynamic_int_) ^ S1, -+ constant{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator^(constant const& s, MixedBits const& m) -+{ -+ return m ^ s; -+} -+ -+// -+// upcast and downcast -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+safe_div(MixedBits const& m, constant const& s) -+{ -+ static_assert(has_single_bit(S1), "Only divide MixedBits by powers of two."); -+ return make_mixed_bits(safe_div(constant{}, s), -+ safe_div(m.dynamic_int_, s), -+ safe_div(constant{}, s)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+upcast(MixedBits const& m) -+{ -+ static_assert(has_single_bit(N), "Only divide MixedBits by powers of two."); -+ return safe_div(m, constant{}); -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+upcast(T const& m) -+{ -+ return safe_div(m, constant{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+downcast(MixedBits const& m) -+{ -+ static_assert(has_single_bit(N), "Only scale MixedBits by powers of two."); -+ return make_mixed_bits(constant{}, -+ m.dynamic_int_ * N, -+ constant{}); -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+downcast(T const& m) -+{ -+ return m * constant{}; -+} -+ -+// -+// Convert a Pow2Layout+Coord to a MixedBits -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+to_mixed_bits(Shape const& shape, Stride const& stride, Coord const& coord) -+{ -+ if constexpr (is_tuple::value && is_tuple::value && is_tuple::value) { -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); -+ return transform_apply(shape, stride, coord, [](auto const& s, auto const& d, auto const& c) { return to_mixed_bits(s,d,c); }, -+ [](auto const&... 
a) { return (a ^ ...); }); -+ } else if constexpr (is_integral::value && is_integral::value && is_integral::value) { -+ static_assert(decltype(shape*stride)::value == 0 || has_single_bit(decltype(shape*stride)::value), "Requires pow2 shape*stride."); -+ return make_mixed_bits(Int<0>{}, coord * stride, (shape - Int<1>{}) * stride); -+ } else { -+ static_assert(is_integral::value && is_integral::value && is_integral::value, "Either Shape, Stride, and Coord must be all tuples, or they must be all integral (in the sense of cute::is_integral)."); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+to_mixed_bits(Layout const& layout, Coord const& coord) -+{ -+ return to_mixed_bits(layout.shape(), layout.stride(), idx2crd(coord, layout.shape())); -+} -+ -+// -+// Display utilities -+// -+ -+template -+CUTE_HOST_DEVICE void print(MixedBits const& m) -+{ -+ printf("M_%u|(%u&%u)=%u", S, uint32_t(m.dynamic_int_), F, to_integral(m)); -+} -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, MixedBits const& m) -+{ -+ return os << "M_" << S << "|(" << uint32_t(m.dynamic_int_) << "&" << F << ")=" << to_integral(m); -+} -+ -+template -+CUTE_HOST_DEVICE void print(Swizzle const&) -+{ -+ print("S<%d,%d,%d>", B, M, S); -+} -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, Swizzle const&) -+{ -+ return os << "S<" << B << "," << M << "," << S << ">"; -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/swizzle_layout.hpp b/3rdparty/cutlass/include/cute/swizzle_layout.hpp -new file mode 100644 -index 0000000..1376a47 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/swizzle_layout.hpp -@@ -0,0 +1,1010 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+#include -+ -+/* This implements a ComposedLayout of the form -+ * InvolutionFn o OffsetPlus o Layout -+ * where the InvolutionFn need not be linear (hence the need for the Offset). -+ * -+ * This ComposedLayout provides similar coordinate-to-index mapping and layout manipulations, -+ * but is not considered a "normal" layout. -+ * For example, this layout provides size() functions, but does not provide stride() functions. -+ * -+ * Furthermore, for known InvolutionFns, this layout attempts to decay itself -+ * to a normal-layout with dynamic or static strides. -+ * This is possible by determining the subdomain of the Involution function -+ * that is identity and testing if the right Layout's codomain is contained -+ * within it. -+ */ -+ -+namespace cute -+{ -+ -+// A Layout of non-trivially composable functions: F o I o L -+template -+struct ComposedLayout -+ : private cute::tuple // EBO for static layouts -+{ -+ CUTE_HOST_DEVICE constexpr -+ ComposedLayout(InvolutionFn const& fn = {}, -+ IntermediateOffset const& offset = {}, -+ Layout const& layout = {}) -+ : cute::tuple(fn, offset, layout) -+ {} -+ -+ // -+ // Accessors -+ // -+ -+ static constexpr int rank = Layout::rank; -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ swizzle_fn() const { -+ return get<0>(static_cast const&>(*this)); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ offset_fn() const { -+ return get<1>(static_cast const&>(*this)); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ layout_fn() const { -+ return get<2>(static_cast const&>(*this)); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ layout() const { -+ return *this; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ shape() const { -+ return layout_fn().shape(); -+ } -+ -+ // Doesn't really make sense to ask for the strides of this "layout" -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ stride() const = delete; -+ -+ // -+ // Mappings -+ // -+ -+ // Map a logical coordinate to a linear index (Coord has no Underscore slice operators) -+ // OR -+ // Slice the layout and return the sublayout (Coord has an Underscore slice op) -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ operator()(Coord const& coord) const { -+ if constexpr (has_underscore::value) { -+ return slice(coord, *this); -+ } else { -+ return swizzle_fn()(to_integral(offset_fn()) + layout_fn()(coord)); // (F o L)(c) -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+ } -+ -+ // Map a 1D linear coordinate to a flat ND logical coordinate -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ auto -+ operator[](Int const& linear_idx) const { -+ return get_flat_coord(linear_idx); -+ } -+ -+ // Convenience function for multi-dimensional coordinates -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { -+ return operator()(make_coord(c0,c1,cs...)); -+ } -+ -+ // -+ // Compose -+ // -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ compose(OtherLayout const& other) const { -+ return composition(*this, other); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ compose(Layouts const&... 
layouts) const { -+ return composition(*this, make_tile(layouts...)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ with_shape(OtherShape const& shape) const { -+ return composition(*this, make_layout(shape)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ with_shape(Shapes const&... shapes) const { -+ return composition(*this, make_layout(make_shape(shapes...))); -+ } -+ -+ // -+ // Tile -+ // -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ tile(OtherLayout const& other) const { -+ return tiled_divide(*this, other); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ tile(Layouts const&... layouts) const { -+ return tiled_divide(*this, make_tile(layouts...)); -+ } -+ -+ // -+ // Utility -+ // -+ -+ // -+ // Index to Coordinate -+ // -+ -+ // NOTE Only valid for compact layouts -+ -+ // Return the (hierarchical) ND logical coordinate corresponding to the linear index -+ // @post this->crd2idx(@a result) == idx -+ // @post congruent(@a result, shape()) -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_hier_coord(IInt const& idx) const { -+ return layout_fn().get_hier_coord(swizzle_fn()(idx) - to_integral(offset_fn())); // (L^-1 o F)(k) -+ } -+ -+ // Return the (flat) ND logical coordinate corresponding to the linear index -+ // @post this->crd2idx(@a result) == idx -+ // @post rank(@a result) == rank(shape()) && depth(@a result) == 1 -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_flat_coord(IInt const& idx) const { -+ return layout_fn().get_flat_coord(swizzle_fn()(idx) - to_integral(offset_fn())); // (L^-1 o F)(k) -+ } -+ -+ // Return the generalized column-major 1D logical coordinate corresponding to the linear index -+ // @post this->crd2idx(@a result) == idx -+ // @post is_integral::value -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_1d_coord(IInt const& idx) const { -+ return layout_fn().get_1d_coord(swizzle_fn()(idx) - to_integral(offset_fn())); // (L^-1 o F)(k) -+ } -+}; -+ -+template -+struct is_layout> : true_type {}; -+ -+template -+struct is_composed_layout : false_type {}; -+template -+struct is_composed_layout> : true_type {}; -+ -+// -+// Constructors -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_layout(Swizzle const& sxor) -+{ -+ return composition(sxor, Layout,Int<1>>{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_layout(ComposedLayout const& a, Layout const& b) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), make_layout(a.layout_fn(), b)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_layout(Layout const& a, ComposedLayout const& b) -+{ -+ return composition(b.swizzle_fn(), b.offset_fn(), make_layout(a, b.layout_fn())); -+} -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transfer_swizzle(Layout const& old_layout, -+ Layout const& new_layout) -+{ -+ // Our goal is to determine a new swizzle for the strides in new_layout for consistent vectorizations -+ -+ // This is accomplished by identifying -+ // S o L :=: S? 
o L* -+ // We identify the "active" portion of S by computing (P o L)(c*) where P is a projection generated by S -+ // Then that active identifier is transformed through the layouts: -+ // L*(L[(P o L)(c*)]) -+ // which is a new swizzle identifier for S?, the new swizzle -+ -+ // Projections of the swizzle layout for composition, P -+ auto swizzle_only_zy = make_layout(make_shape (Int<(1 << M)>{}, Int<(1 << B)>{}, Int<(1 << (abs(S)-B))>{}, Int<(1 << B )>{}, Int<1>{}), -+ make_stride( Int<0>{}, Int<(1 << M)>{}, Int<0>{}, Int<(1 << (M+abs(S)))>{}, Int<0>{})); -+ -+ // Compose with the tile to get the swizzle projection, P o L [The Z and Y contributing portions of L] -+ auto layout_only_zy = composition(swizzle_only_zy, old_layout); -+ // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) -+ auto swizzle_active_bits = layout_only_zy(size(layout_only_zy)-Int<1>{}); -+ -+ // Get the Z bit and the Y bits -- keep only those that are active in Z *and* Y -+ auto zzz_msk = typename Swizzle::zzz_msk{}; -+ auto yyy_msk = typename Swizzle::yyy_msk{}; -+ auto msk_sft = typename Swizzle::msk_sft{}; -+ auto active_Z = swizzle_active_bits & shiftr(swizzle_active_bits, msk_sft) & zzz_msk; -+ auto active_Y = swizzle_active_bits & shiftr(swizzle_active_bits, -msk_sft) & yyy_msk; -+ -+ // Pass the identifiers through the old layout and new layout to make a new swizzle identifier, L*(L[(P o L)(c*)]) -+ auto new_active_Z = new_layout(old_layout.get_1d_coord(active_Z)); -+ auto new_active_Y = new_layout(old_layout.get_1d_coord(active_Y)); -+ -+ // Use this new swizzle identifier to construct the new swizzle for new_layout -+ // (this also makes sure it's a "valid" swizzle that Swizzle can represent) -+ return composition(make_swizzle(), new_layout); -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_fragment_like(ComposedLayout,Offset,Layout> const& layout) -+{ -+ return detail::transfer_swizzle(layout.layout_fn(), make_fragment_like(layout.layout_fn())); -+} -+ -+// -+// Utilities -+// -+ -+// Return the layout of a mode -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+layout(ComposedLayout const& clayout) -+{ -+ return composition(clayout.swizzle_fn(), clayout.offset_fn(), layout(clayout.layout_fn())); -+} -+ -+// Return the shape of a mode -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+shape(ComposedLayout const& layout) -+{ -+ return shape(layout.layout_fn()); -+} -+ -+// Doesn't make sense to directly ask for the strides of this "layout" -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+stride(ComposedLayout const& layout) = delete; -+ -+// Return the number of elements in a mode -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+size(ComposedLayout const& layout) -+{ -+ return size(layout.layout_fn()); -+} -+ -+// Return the number of modes -+template -+CUTE_HOST_DEVICE constexpr -+auto -+rank(ComposedLayout const& layout) -+{ -+ return rank(layout.layout_fn()); -+} -+ -+// Return the depth of the layout -+template -+CUTE_HOST_DEVICE constexpr -+auto -+depth(ComposedLayout const& layout) -+{ -+ return depth(layout.layout_fn()); -+} -+ -+// Return the codomain size of a mode -+template -+CUTE_HOST_DEVICE constexpr -+auto -+cosize(ComposedLayout const& layout) -+{ -+ return cosize(layout.layout_fn()); -+} -+ -+// -+// Operations to manipulate Layouts like a tuple of pairs -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+get(ComposedLayout const& a) -+{ -+ return composition(a.swizzle_fn(), 
a.offset_fn(), get(a.layout_fn())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+take(ComposedLayout const& a) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), take(a.layout_fn())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+flatten(ComposedLayout const& a) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), flatten(a.layout_fn())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+append(ComposedLayout const& a, X const& x) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), append(a.layout_fn(), x)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+group(ComposedLayout const& a) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), group(a.layout_fn())); -+} -+ -+// -+// Slice a ComposedLayout -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_swizzle_strides(true_type, -+ IntZ const& Z, -+ IntY const& Y, -+ Offset const& offset, -+ int_sequence) -+{ -+ // Below is an optimized/compressed version of: -+ //return make_tuple((swizzle(offset + Z*Int<(1 << I)>{}) - swizzle(offset))...); -+ // with knowledge of Swizzle, I... ranges for each B bits, -+ // and the layout won't slice along z-bits that are already set -+ -+ // y\z 0 1 -+ // 0 Z DC -+ // 1 -Z DC -+ -+ return cute::make_tuple(conditional_return((offset & (Y << Int{})) == Int<0>{}, Z << Int{}, -(Z << Int{}))...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_swizzle_strides(false_type, -+ IntZ const& Z, -+ IntY const& Y, -+ Offset const& offset, -+ int_sequence) -+{ -+ // Below is an optimized/compressed version of: -+ //return make_tuple((swizzle(offset + Y*Int<(1 << I)>{}) - swizzle(offset))...); -+ // with knowledge of Swizzle, I... ranges for each B bits, -+ // and the layout won't slice along y-bits that are already set -+ -+ // y\z 0 1 -+ // 0 Y+Z Y-Z -+ // 1 DC DC -+ -+ return cute::make_tuple(conditional_return((offset & (Z << Int{})) == Int<0>{}, (Y+Z) << Int{}, (Y-Z) << Int{})...); -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+slice_and_offset(Coord const& coord, ComposedLayout,Offset,Layout> const& layout) -+{ -+ if constexpr (all_underscore::value) { -+ // Skip the expensive/complicated attempt to decay to a normal layout and just reshape -+ return cute::make_tuple(composition(layout.swizzle_fn(), layout.offset_fn(), slice(coord, layout.layout_fn())), Int<0>{}); -+ } else { -+ -+ // Projections of the swizzle layout for composition -+ auto sw = make_layout(make_shape(Int<(1 << M)>{}, Int<(1 << B)>{}, Int<(1 << (abs(S)-B))>{}, Int<(1 << B)>{}, Int<1>{})); -+ -+ auto swizzle_anti_zy = make_layout(shape(sw), -+ make_stride(stride<0>(sw), Int<0>{}, stride<2>(sw), Int<0>{}, size(sw))); -+ auto swizzle_only_zy = make_layout(shape(sw), -+ make_stride( Int<0>{}, stride<1>(sw), Int<0>{}, stride<3>(sw), Int<0>{})); -+ -+ // The portion of the layout that is not yet consumed -+ auto sliced_layout = slice(coord, layout.layout_fn()); -+ -+ // If the sliced_layout hits two bits that are swizzled together, then don't attempt to decay -+ -+ // Compose with the layout to get the swizzle projection, P o L [The Z and Y contributing portions of L] -+ // (this also tests that shape/stride of layout compose with swizzle) -+ auto sliced_layout_only_zy = composition(swizzle_only_zy, sliced_layout); -+ // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) -+ auto swizzle_active_bits = sliced_layout_only_zy(size(sliced_layout_only_zy)-Int<1>{}); -+ // Determine if any 
active bits collide under the swizzle -+ auto hit_ZandY = !(swizzle_active_bits & ~layout.swizzle_fn()(swizzle_active_bits)); -+ -+ // The portion of the layout that we are consuming now -+ auto diced_layout = dice(coord, layout.layout_fn()); -+ auto diced_coord = dice(coord, coord); -+ -+ auto diced_layout_anti_zy = composition(swizzle_anti_zy, diced_layout); -+ auto diced_layout_only_zy = composition(swizzle_only_zy, diced_layout); -+ -+ // New swizzle and offset -+ auto swizzle = layout.swizzle_fn(); -+ // offset_only_zy interacts with swizzle and gets accumulated with layout.offset_fn() -+ // being careful about the static/dynamic contributions from diced_layout and diced_coord -+ auto offset_only_zy = layout.offset_fn() ^ to_mixed_bits(diced_layout_only_zy, diced_coord); -+ // offset_anti_zy always gets passed through, no interaction with swizzle -+ auto offset_anti_zy = diced_layout_anti_zy(diced_coord); -+ -+ // If Layout's codomain hits on Y AND Z, then it's not reducible -+ // If Layout's codomain hits on Y XOR Z, then it's dynamic-normal -+ // If Layout's codomain hits on neither Y NOR Z, then it's static-normal -+ -+ // Test the sliced layout for hit_X & hit_Y for potential decay -+ if constexpr (is_constant::value) -+ { // Hits on Y AND Z, so it's not reducible -+ return cute::make_tuple(composition(swizzle, offset_only_zy, sliced_layout), offset_anti_zy); -+ } else -+ { // Misses on Y or Z, so it's static-normal or dynamic-normal -+ -+ // Lowest bit of the Z and Y masks -+ auto Z = typename Swizzle::zzz_msk{} & -typename Swizzle::zzz_msk{}; -+ auto Y = typename Swizzle::yyy_msk{} & -typename Swizzle::yyy_msk{}; -+ auto stride_lo = detail::make_swizzle_strides(Z < Y, Z, Y, offset_only_zy, make_int_sequence{}); -+ auto stride_hi = detail::make_swizzle_strides(Z > Y, Z, Y, offset_only_zy, make_int_sequence{}); -+ -+ // Construct a (dynamic) layout that we can perform the composition with -+ auto swizzle_layout = make_layout(make_shape (Int<(1 << M)>{}, repeat(Int<2>{}), Int<(1 << (abs(S)-B))>{}, repeat(Int<2>{}), Int< 1>{}), -+ make_stride(Int< 1>{}, stride_lo, Int<(1 << (M+B))>{}, stride_hi , Int<(1 << (M+B+abs(S)))>{})); -+ -+ // Decay to a normal layout with offset -+ return cute::make_tuple(composition(swizzle_layout, sliced_layout), -+ swizzle(to_integral(offset_only_zy)) + offset_anti_zy); -+ } -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+slice(Coord const& coord, ComposedLayout const& layout) -+{ -+ return get<0>(slice_and_offset(coord, layout)); -+} -+ -+// -+// composition -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(Swizzle const& sxor, -+ Offset const& offset, -+ Layout const& layout) -+{ -+ return ComposedLayout>{sxor, offset, layout}; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(Swizzle const& sxor, -+ Offset const& offset, -+ ComposedLayout const& layout) -+{ -+ // Assume disjoint swizzles and offsets for commutivity -+ return composition(composition(sxor,layout.swizzle_fn()), offset ^ layout.offset_fn(), layout.layout_fn()); -+} -+ -+// Ignore identity case -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(Swizzle<0,M,S> const&, -+ Int<0> const&, -+ Layout const& layout) -+{ -+ return layout; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(Swizzle const& sxor, -+ Layout const& layout) -+{ -+ return composition(sxor, Int<0>{}, layout); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(ComposedLayout const& a, -+ LayoutOrTile 
const& b) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), composition(a.layout_fn(), b)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(Layout const& a, -+ Swizzle const& b) -+{ -+ // Get the Z bits and the Y bits -+ auto active_Y = a(typename Swizzle::yyy_msk{}); -+ auto active_Z = a(typename Swizzle::zzz_msk{}); -+ -+ // Works in simple cases... but could be greatly generalized -+ -+ return composition(make_swizzle(), a); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(Layout const& a, -+ ComposedLayout const& b) -+{ -+ CUTE_STATIC_ASSERT_V(b.offset_fn() == Int<0>{}, "Require Swizzle offset == 0."); -+ -+ return composition(composition(a, b.swizzle_fn()), b.layout_fn()); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(ComposedLayout const& a, -+ ComposedLayout const& b) -+{ -+ auto asb = composition(a.layout_fn(), b); -+ -+ return composition(composition(a.swizzle_fn(),asb.swizzle_fn()), asb.offset_fn(), asb.layout_fn()); -+} -+ -+// -+// complement -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+complement(ComposedLayout const& layout, CoSizeHi const& cosize_hi) -+{ -+ // Assume there is no swizzle component in the complement -+ return complement(layout.layout_fn(), cosize_hi); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+complement(ComposedLayout const& layout) -+{ -+ return complement(layout, cosize(layout)); -+} -+ -+// -+// inverse -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+right_inverse(ComposedLayout const& layout) -+{ -+ CUTE_STATIC_ASSERT_V(layout.offset_fn() == Int<0>{}, "Requires 0-offset."); -+ return composition(right_inverse(layout.layout_fn()), layout.swizzle_fn()); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+left_inverse(ComposedLayout const& layout) -+{ -+ CUTE_STATIC_ASSERT_V(layout.offset_fn() == Int<0>{}, "Requires 0-offset."); -+ return composition(left_inverse(layout.layout_fn()), layout.swizzle_fn()); -+} -+ -+// -+// Other operations -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+max_common_vector(ComposedLayout,Offset,SLayout> const& a, -+ Layout const& b) -+{ -+ // This assumes that Offset is in the YZ domain of the Swizzle... 
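// Note: Swizzle<B,M,S> leaves the low M "base" bits of an offset untouched (its
// identity subdomain), so contiguity is only guaranteed within a run of 2^M
// consecutive codomain addresses before the XOR can reorder lines; the min() below
// caps the reported common vector at Int<(1 << M)> for exactly that reason.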
-+ return cute::min(Int<(1 << M)>{}, max_common_vector(a.layout_fn(), b)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+max_common_vector(Layout const& a, -+ ComposedLayout,Offset,SLayout> const& b) -+{ -+ return max_common_vector(b, a); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+max_common_vector(ComposedLayout,Offset0,SLayout0> const& a, -+ ComposedLayout,Offset1,SLayout1> const& b) -+{ -+ auto result = coalesce(composition(a, right_inverse(b))); -+ -+ if constexpr (is_constant<1, decltype(stride<0>(result.layout_fn()))>::value) { -+ return shape<0>(result); -+ } else { -+ return Int<1>{}; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zip(ComposedLayout const& a) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), zip(a.layout_fn())); -+} -+ -+// Partitions -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+logical_divide(ComposedLayout const& a, -+ Tile const& b) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), logical_divide(a.layout_fn(), b)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tile_unzip(ComposedLayout const& a, -+ Tile const& b) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), tile_unzip(a.layout_fn(), b)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tiled_divide(ComposedLayout const& a, -+ Tile const& b) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), tiled_divide(a.layout_fn(), b)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zipped_divide(ComposedLayout const& a, -+ Tile const& b) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), zipped_divide(a.layout_fn(), b)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+logical_product(ComposedLayout const& a, -+ Tile const& b) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), logical_product(a.layout_fn(), b)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tiled_product(ComposedLayout const& a, -+ Tile const& b) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), tiled_product(a.layout_fn(), b)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+blocked_product(ComposedLayout const& a, -+ Tile const& b) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), blocked_product(a.layout_fn(), b)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+raked_product(ComposedLayout const& a, -+ Tile const& b) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), raked_product(a.layout_fn(), b)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tile_to_shape(ComposedLayout const& layout, -+ Shape const& trg_shape, -+ ModeOrder const& ord_shape = {}) -+{ -+ return composition(layout.swizzle_fn(), layout.offset_fn(), tile_to_shape(layout.layout_fn(), trg_shape, ord_shape)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+filter(ComposedLayout const& layout, Shape const& trg_profile) -+{ -+ return composition(layout.swizzle_fn(), layout.offset_fn(), filter(layout.layout_fn(), trg_profile)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+coalesce(ComposedLayout const& layout) -+{ -+ return composition(layout.swizzle_fn(), layout.offset_fn(), coalesce(layout.layout_fn())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+coalesce(ComposedLayout const& layout, Shape const& trg_profile) -+{ -+ return composition(layout.swizzle_fn(), layout.offset_fn(), coalesce(layout.layout_fn(), trg_profile)); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+// ComposedLayout as second argument is often more 
difficult... -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+logical_product(Layout const& block, -+ ComposedLayout,Offset,LayoutT> const& tile) -+{ -+ CUTE_STATIC_ASSERT_V(tile.offset_fn() == Int<0>{}, "Require Swizzle offset == 0."); -+ // The new layout -- if swizzle wasn't an issue, this is the result -+ // our goal is to determine a new swizzle for these strides -+ auto new_layout = logical_product(block, tile.layout_fn()); -+ -+ // This is accomplished by identifying -+ // S o L :=: S? o L* -+ // We identify the "active" portion of S by computing (P o L)(c*) where P is a projection generated by S -+ // Then that active identifier is transformed through the layouts: -+ // L*(L[(P o L)(c*)]) -+ // which is a new swizzle identifier for S?, the new swizzle -+ -+ // Projections of the swizzle layout for composition, P -+ auto swizzle_only_zy = make_layout(make_shape (Int<(1 << M)>{}, Int<(1 << B)>{}, Int<(1 << (abs(S)-B))>{}, Int<(1 << B )>{}, Int<1>{}), -+ make_stride( Int<0>{}, Int<(1 << M)>{}, Int<0>{}, Int<(1 << (M+abs(S)))>{}, Int<0>{})); -+ -+ // Compose with the tile to get the swizzle projection, P o L [The Z and Y contributing portions of L] -+ auto layout_only_zy = composition(swizzle_only_zy, tile.layout_fn()); -+ // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) -+ auto swizzle_active_bits = layout_only_zy(size(layout_only_zy)-Int<1>{}); -+ // Get the Z bit and the Y bits -+ auto active_Z = swizzle_active_bits & typename Swizzle::zzz_msk{}; -+ auto active_Y = swizzle_active_bits & typename Swizzle::yyy_msk{}; -+ -+ // Pass the identifiers through the old layout and new layout to make a new swizzle identifier, L*(L[(P o L)(c*)]) -+ auto new_active_Z = new_layout(Int<0>{}, tile.layout_fn()[active_Z]); -+ auto new_active_Y = new_layout(Int<0>{}, tile.layout_fn()[active_Y]); -+ -+ // Use this new swizzle identifier to construxt the new swizzle for new_layout -+ // (this also makes sure it's a "valid" swizzle that Swizzle can represent) -+ return composition(make_swizzle(), new_layout); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tiled_product(Layout const& block, -+ ComposedLayout const& tile) -+{ -+ /// Avoid swizzle slice -+ auto result = logical_product(block, tile); -+ return composition(result.swizzle_fn(), result.offset_fn(), result.layout_fn()(_, repeat>(_))); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+blocked_product(Layout const& block, -+ ComposedLayout const& layout) -+{ -+ constexpr int R = cute::max(rank_v, rank_v); -+ auto padded_block = append(block, Layout<_1,_0>{}); -+ auto padded_layout = append(layout, Layout<_1,_0>{}); -+ -+ auto result = logical_product(padded_block, padded_layout); -+ -+ return composition(result.swizzle_fn(), -+ result.offset_fn(), -+ coalesce(zip(get<0>(result.layout_fn()), get<1>(result.layout_fn())), repeat(Int<1>{}))); -+} -+ -+// -+// Upcast and Downcast -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+upcast(ComposedLayout const& layout) -+{ -+ return composition(upcast(layout.swizzle_fn()), upcast(layout.offset_fn()), upcast(layout.layout_fn())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+downcast(ComposedLayout const& layout) -+{ -+ return composition(downcast(layout.swizzle_fn()), downcast(layout.offset_fn()), downcast(layout.layout_fn())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(ComposedLayout const& layout) -+{ -+ if constexpr (sizeof(NewType) == sizeof(OldType)) { -+ return layout; -+ } else if constexpr (sizeof(NewType) > 
sizeof(OldType)) { -+ static_assert(sizeof(NewType) % sizeof(OldType) == 0, "NewType must be a multiple of OldType"); -+ return upcast(layout); -+ } else if constexpr (sizeof(NewType) < sizeof(OldType)) { -+ static_assert(sizeof(OldType) % sizeof(NewType) == 0, "NewType must be a divisor of OldType"); -+ return downcast(layout); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Display utilities -+// -+ -+template -+CUTE_HOST_DEVICE void print(ComposedLayout const& layout) -+{ -+ print(layout.swizzle_fn()); print(" o "); print(layout.offset_fn()); print(" o "); print(layout.layout_fn()); -+} -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, ComposedLayout const& layout) -+{ -+ return os << layout.swizzle_fn() << " o " << layout.offset_fn() << " o " << layout.layout_fn(); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/swizzle_ptr.hpp b/3rdparty/cutlass/include/cute/swizzle_ptr.hpp -new file mode 100644 -index 0000000..ed77acb ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/swizzle_ptr.hpp -@@ -0,0 +1,282 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+#include -+#include -+#include -+ -+#include -+#include -+#include -+ -+/* This implements a swizzle pointer of the form -+ * InvolutionFn o PtrAdd -+ * where the InvolutionFn need not be linear. -+ * -+ * This differs subtly from swizzle_layout because the smem pointer is used -+ * as the offset. That means that swizzle_layout will implement position-independent -+ * swizzle layouts, while swizzle_ptr implements position-dependent swizzle tensors. -+ * Arch chose to design hardware with position-dependent swizzles. 
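A small standalone sketch makes the position-dependence concrete. This is not the library's implementation; the XOR form and the Swizzle<3,4,3>-style parameters (B=3, M=4, S=3) are assumptions picked purely for illustration. The point is that feeding base + idx through the involution generally disagrees with feeding idx alone unless the base address is suitably aligned:

#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Toy XOR involution in the spirit of Swizzle<B,M,S> with S > 0:
// the B "Y" bits sit S positions above the B "Z" bits; the low M bits pass through.
template <int B, int M, int S>
uint32_t xor_swizzle(uint32_t offset) {
  uint32_t y = (offset >> (M + S)) & ((1u << B) - 1u);
  return offset ^ (y << M);
}

int main() {
  constexpr uint32_t idx = 0x90;             // a logical byte offset into a tile
  for (uint32_t base : {0x000u, 0x480u}) {   // two hypothetical smem base addresses
    // Position-dependent (swizzle-pointer style): the involution sees base + idx.
    uint32_t dependent   = xor_swizzle<3, 4, 3>(base + idx) - base;
    // Position-independent (swizzle-layout style): the involution sees idx alone.
    uint32_t independent = xor_swizzle<3, 4, 3>(idx);
    std::printf("base=0x%03x dependent=0x%03x independent=0x%03x\n",
                base, dependent, independent);
  }
  return 0;
}

With a zero (or sufficiently aligned) base the two results agree; with the misaligned base they do not, which is why the as_position_independent_swizzle_* helpers further down assert that the shared-memory address is aligned before re-expressing the tensor with a plain swizzle layout.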
-+ * -+ * For clarity: -+ * NormalLayout : DeRef <- PtrAdd <- [Layout] -+ * ComposedLayout: DeRef <- PtrAdd <- [Swizzle <- OffsetAdd <- Layout] -+ * SwizzlePtr : [DeRef <- Swizzle <- PtrAdd] <- Layout -+ * -+ * Furthermore, for known swizzles, this pointer attempts to decay itself -+ * to a normal-pointer with a new layout containing dynamic or static strides. -+ * This is possible by determining the subdomain of the InvolutionFn -+ * that is identity and testing if the Layout's codomain is contained -+ * within it. -+ */ -+ -+namespace cute -+{ -+ -+template -+struct smem_ptr_swizzle -+{ -+ static_assert(std::is_empty::value, "Swizzle can't have state."); -+ -+ CUTE_HOST_DEVICE constexpr -+ T* get() const -+ { -+ return ptr_; -+ } -+ -+ CUTE_HOST_DEVICE constexpr static -+ Swizzle get_swizzle() -+ { -+ return {}; -+ } -+ -+ CUTE_HOST_DEVICE constexpr static -+ T* apply_swizzle(T* ptr) -+ { -+ return reinterpret_cast(Swizzle::apply(reinterpret_cast(ptr))); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ T& operator*() const -+ { -+ return *apply_swizzle(get()); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ T& operator[](Int const& i) const -+ { -+ return *apply_swizzle(get() + i); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ smem_ptr_swizzle operator+(Int const& i) const -+ { -+ return {ptr_ + i}; -+ } -+ -+ T* ptr_; -+}; -+ -+template -+struct is_smem> : true_type {}; -+ -+// Make a swizzle pointer -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_smem_ptr(T* ptr, Swizzle const& swizzle) -+{ -+ return smem_ptr_swizzle{ptr}; -+} -+ -+// A model of a nullptr smem_ptr with B == sizeof_bits::value -+// That represents an unset pointer. This is a placeholder type that is waiting for an smem_ptr -+template -+struct smem_ptr_flag_bits : Int<0> {}; -+ -+using smem_ptr_flag = smem_ptr_flag_bits<1>; -+ -+// A flagged construction method to transform ComposedLayout -+// Make a swizzle pointer tensor and check that the intended type size matches -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_tensor(smem_ptr const& ptr, -+ ComposedLayout,Layout> const& layout) -+{ -+ static_assert(B == sizeof_bits::value, "Expected a B-bit pointer type."); -+ return make_tensor(make_smem_ptr(ptr.get(), layout.swizzle_fn()), -+ layout.layout_fn()); -+} -+ -+// Specialization for immediate decay -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_tensor(smem_ptr_swizzle>& p, Layout const& layout) -+{ -+ return make_tensor(make_smem_ptr(p.ptr_), layout); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_tensor(smem_ptr_swizzle> const& p, Layout const& layout) -+{ -+ return make_tensor(make_smem_ptr(p.ptr_), layout); -+} -+ -+// NOTE: To preserve smem_ptr_flag_bits under recast ops -+template -+CUTE_HOST_DEVICE constexpr -+auto -+upcast(ComposedLayout,Layout> const& layout) -+{ -+ return composition(layout.swizzle_fn(), smem_ptr_flag_bits{}, upcast(layout.layout_fn())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+downcast(ComposedLayout,Layout> const& layout) -+{ -+ return composition(layout.swizzle_fn(), smem_ptr_flag_bits{}, downcast(layout.layout_fn())); -+} -+ -+// -+// Recast -+// Swizzle operates on the pointer address, so it doesn't care about the type -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(smem_ptr_swizzle const& ptr) -+{ -+ return smem_ptr_swizzle{recast(ptr.ptr_)}; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(smem_ptr_swizzle const& ptr) -+{ -+ return smem_ptr_swizzle{recast(ptr.ptr_)}; -+} -+ -+// -+// Conversion with 
swizzle_layout -+// -+ -+template -+CUTE_HOST_DEVICE -+auto -+as_position_independent_swizzle_layout(ComposedLayout,Layout> const& layout) -+{ -+ return composition(recast,uint_bit_t>(layout.swizzle_fn()), Int<0>{}, layout.layout_fn()); -+} -+ -+template -+CUTE_HOST_DEVICE -+auto -+as_position_independent_swizzle_tensor(Tensor>, Layout> const& tensor) -+{ -+ { -+ uint32_t address = cast_smem_ptr_to_uint(tensor.data().get()); -+ uint32_t mask = ((uint32_t(1) << Swizzle::num_base) - 1) & (Swizzle::swizzle_code); -+ assert((address & mask) == 0); // Alignment to the Base, Z, and Y of Swizzle -+ } -+ auto new_swizzle = recast,uint_bit_t>>(tensor.data().get_swizzle()); -+ return make_tensor(make_smem_ptr(tensor.data().get()), composition(new_swizzle, Int<0>{}, tensor.layout())); -+} -+ -+template -+CUTE_HOST_DEVICE -+auto -+as_position_independent_swizzle_tensor(Tensor>, Layout>& tensor) -+{ -+ { -+ uint32_t address = cast_smem_ptr_to_uint(tensor.data().get()); -+ uint32_t mask = ((uint32_t(1) << Swizzle::num_base) - 1) & (Swizzle::swizzle_code); -+ assert((address & mask) == 0); // Alignment to the Base, Z, and Y of Swizzle -+ } -+ auto new_swizzle = recast,uint_bit_t>>(tensor.data().get_swizzle()); -+ return make_tensor(make_smem_ptr(tensor.data().get()), composition(new_swizzle, Int<0>{}, tensor.layout())); -+} -+ -+template -+CUTE_HOST_DEVICE -+auto -+as_position_independent_swizzle_tensor(Tensor>, Layout>&& tensor) -+{ -+ return as_position_independent_swizzle_tensor(tensor); -+} -+ -+// -+// Print -+// -+ -+// Capture and cast smem_ptr_flag Layouts to offset-0 layouts -+template -+CUTE_HOST_DEVICE -+void -+print_latex(ComposedLayout,Layout> const& layout) -+{ -+ auto new_swizzle = recast,uint_bit_t>(layout.swizzle_fn()); -+ print_latex(composition(new_swizzle, Int<0>{}, layout.layout_fn())); -+} -+ -+template -+CUTE_HOST_DEVICE void print(smem_ptr_flag_bits const& ptr) -+{ -+ printf("smem_ptr_%db(unset)", B); -+} -+ -+template -+CUTE_HOST_DEVICE void print(smem_ptr_swizzle> const& ptr) -+{ -+ printf("smem_ptr_S<%d,%d,%d>_%db(%p)", B, M, S, int(8*sizeof(T)), ptr.get()); -+} -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr_swizzle> const&) -+{ -+ return os << "smem_ptr_S<" << B << "," << M << "," << S << ">_" << int(8*sizeof(T)) << "b"; -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/tensor.hpp b/3rdparty/cutlass/include/cute/tensor.hpp -new file mode 100644 -index 0000000..e88c22b ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/tensor.hpp -@@ -0,0 +1,900 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+ -+#include -+#include -+#include -+ -+namespace cute -+{ -+ -+// -+// Engine -- owning or non-owning data store -+// -+ -+// concept Engine { -+// using value_type = ; -+// iterator begin(); -+// }; -+ -+template -+using ArrayEngine = typename std::conditional<(sizeof_bits::value % 8 == 0), -+ array_aligned, -+ array_subbyte>::type; -+ -+template -+struct ViewEngine -+{ -+ using value_type = typename cute::remove_cvref())>::type; -+ -+ using iterator = Iterator; -+ iterator storage_; -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator const& -+ begin() const { -+ return storage_; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator& -+ begin() { -+ return storage_; -+ } -+}; -+ -+template -+struct is_rmem> : is_rmem {}; -+template -+struct is_smem> : is_smem {}; -+template -+struct is_gmem> : is_gmem {}; -+template -+struct ConstViewEngine -+{ -+ using value_type = typename cute::remove_cvref())>::type; -+ -+ using iterator = Iterator; -+ iterator storage_; -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator const& -+ begin() const { -+ return storage_; -+ } -+}; -+ -+template -+struct is_rmem> : is_rmem {}; -+template -+struct is_smem> : is_smem {}; -+template -+struct is_gmem> : is_gmem {}; -+// -+// Tensor -+// -+ -+template -+struct Tensor -+{ -+ using value_type = typename Engine::value_type; -+ //using pointer = typename engine_traits::pointer; -+ //using const_pointer = typename engine_traits::const_pointer; -+ //using reference = typename engine_traits::reference; -+ //using const_reference = typename engine_traits::const_reference; -+ -+ using engine_type = Engine; -+ using layout_type = Layout; -+ -+ CUTE_HOST_DEVICE constexpr -+ Tensor() {} -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ Tensor(Ptr const& ptr, Layout const& layout) -+ : rep_(layout, ptr) { -+ } -+ -+ // -+ // Accessors -+ // -+ -+ static constexpr int rank = Layout::rank; -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ tensor() const { -+ return *this; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ layout() const { -+ return get<0>(rep_); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ engine() const { -+ return get<1>(rep_); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ engine() { -+ return get<1>(rep_); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ data() const { -+ return engine().begin(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ data() { -+ return engine().begin(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ shape() const { -+ 
return layout().shape(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ auto -+ size() const { -+ return cute::size(shape()); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ stride() const { -+ return layout().stride(); -+ } -+ -+ // -+ // Indexing op() and op[] -+ // -+ -+ // Index into this tensor like an array by computing the offset via layout() -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ operator[](Coord const& coord) { -+ return data()[layout()(coord)]; -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ operator[](Coord const& coord) const { -+ return data()[layout()(coord)]; -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ operator()(Coord const& coord) { -+ if constexpr (has_underscore::value) { -+ auto const& [sliced_layout,offset] = slice_and_offset(coord, layout()); -+ return make_tensor(data() + offset, sliced_layout); -+ } else { -+ return data()[layout()(coord)]; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ operator()(Coord const& coord) const { -+ if constexpr (has_underscore::value) { -+ auto const& [sliced_layout,offset] = slice_and_offset(coord, layout()); -+ return make_tensor(data() + offset, sliced_layout); -+ } else { -+ return data()[layout()(coord)]; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+ } -+ -+ // op() convenience function for multi-dimensional coordinates -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) { -+ return operator()(make_coord(c0,c1,cs...)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { -+ return operator()(make_coord(c0,c1,cs...)); -+ } -+ -+ // -+ // Compose -+ // -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ compose(Layouts const&... layouts) { -+ return make_tensor(data(), layout().compose(layouts...)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ compose(Layouts const&... layouts) const { -+ return make_tensor(data(), layout().compose(layouts...)); -+ } -+ -+ // -+ // Tile -+ // -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ tile(Layouts const&... layouts) { -+ return make_tensor(data(), layout().tile(layouts...)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ tile(Layouts const&... 
layouts) const { -+ return make_tensor(data(), layout().tile(layouts...)); -+ } -+ -+ // -+ // Utility -+ // -+ -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_1d_coord(Int const& linear_idx) const { -+ return layout().get_1d_coord(linear_idx); -+ } -+ -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_hier_coord(Int const& linear_idx) const { -+ return layout().get_hier_coord(linear_idx); -+ } -+ -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_flat_coord(Int const& linear_idx) const { -+ return layout().get_flat_coord(linear_idx); -+ } -+ -+ cute::tuple rep_; -+}; -+ -+ -+template -+struct is_tensor : false_type {}; -+template -+struct is_tensor> : true_type {}; -+ -+template -+struct is_rmem> : is_rmem {}; -+template -+struct is_smem> : is_smem {}; -+template -+struct is_gmem> : is_gmem {}; -+// -+// Make an owning Tensor that will allocate a static array -+// -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+make_tensor(Layout const& layout) -+{ -+ static_assert(is_static::value, "Dynamic owning tensors not supported"); -+ using Engine = ArrayEngine>; -+ return Tensor(); -+} -+ -+// e.g. make_tensor(12) -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+make_tensor(LayoutArg const& arg, LayoutArgs const&... args) -+{ -+ return make_tensor(make_layout(arg, args...)); -+} -+ -+// -+// Make a non-owning Tensor that will use a pointer (view) -+// -+ -+template ::value && -+ is_layout::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+make_tensor(Iterator const& iter, Layout const& layout) -+{ -+ using Engine = ViewEngine; -+ return Tensor(iter, layout); -+} -+ -+// e.g. make_tensor(vec.data(), 12) -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+make_tensor(Iterator const& iter, LayoutArg const& arg, LayoutArgs const&... 
args) -+{ -+ return make_tensor(iter, make_layout(arg, args...)); -+} -+ -+// -+// make_tensor_like -- make a register tensor the same type and shape as another -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_tensor_like(Tensor const& tensor) -+{ -+ using value_type = typename Tensor::value_type; -+ return make_tensor(tensor.shape()); -+} -+ -+// -+// make_fragment_like -- make a register tensor the same type, shape, and (if possible) order as another tensor -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_fragment_like(Tensor const& tensor) -+{ -+ using value_type = typename Tensor::value_type; -+ return make_tensor(make_layout_like(tensor.layout())); -+} -+ -+// -+// make_identity_tensor -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_identity_tensor(Shape const& shape) -+{ -+ return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat_like(shape, Int<0>{}))), -+ make_identity_layout(shape)); -+} -+ -+// -+// Utilities -+// -+ -+// Return the subtensor of a mode -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+tensor(Tensor&& tensor) -+{ -+ return std::forward(tensor); -+} -+ -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+tensor(Tensor&& tensor) -+{ -+ return make_tensor(std::forward(tensor).data(), get(tensor.layout())); -+} -+ -+// Return the subtensor of a range of modes -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+take(Tensor&& tensor) -+{ -+ return make_tensor(std::forward(tensor).data(), take(tensor.layout())); -+} -+ -+// Return the layout of a mode -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+layout(Tensor const& tensor) -+{ -+ return layout(tensor.layout()); -+} -+ -+// Return the shape of a mode -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+shape(Tensor const& tensor) -+{ -+ return shape(tensor.layout()); -+} -+ -+// Return the stride of a mode -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+stride(Tensor const& tensor) -+{ -+ return stride(tensor.layout()); -+} -+ -+// Return the number of elements in a mode -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+size(Tensor const& tensor) -+{ -+ return size(tensor.layout()); -+} -+ -+// Return the rank of a mode -+template -+CUTE_HOST_DEVICE constexpr -+auto -+rank(Tensor const& tensor) -+{ -+ return rank(tensor.layout()); -+} -+ -+// Return the depth of a mode -+template -+CUTE_HOST_DEVICE constexpr -+auto -+depth(Tensor const& tensor) -+{ -+ return depth(tensor.layout()); -+} -+ -+// -+// Operations to manipulate Tensors like a Layout -+// -+ -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+flatten(Tensor&& tensor) -+{ -+ return make_tensor(std::forward(tensor).data(), flatten(tensor.layout())); -+} -+ -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+coalesce(Tensor&& tensor) -+{ -+ return make_tensor(std::forward(tensor).data(), coalesce(tensor.layout())); -+} -+ -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+coalesce(Tensor&& tensor, Profile const& profile) -+{ -+ return make_tensor(std::forward(tensor).data(), coalesce(tensor.layout(), profile)); -+} -+ -+// Group the modes [B,E) into a single mode -+// e.g. 
group<2,4>(make_tensor(Layout>{})) -+// => make_tensor(Layout,_5,_6>>{}) -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+group_modes(Tensor&& tensor) -+{ -+ return make_tensor(std::forward(tensor).data(), -+ group(tensor.layout())); -+} -+ -+// -+// Recast -+// -+ -+// NOTE: This is very dangerous to do -+// -- doesn't check dynamic integer divisibility -+// -- doesn't check alignment -+ -+// A tagged version for dispatching -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+recast(Tensor&& tensor, type_list) -+{ -+ using OldType = typename remove_cvref_t::value_type; -+ auto old_layout = tensor.layout(); -+ auto new_layout = recast(old_layout); -+ -+ // If this is an upcast of a normal Layout with static negative strides, then offset as well -+ if constexpr (sizeof(OldType) < sizeof(NewType) && not is_composed_layout::value) { -+ auto shape_diff = transform(flatten(old_layout.shape()), flatten(new_layout.shape()), minus{}); -+ auto extent_diff = transform(shape_diff, flatten(old_layout.stride()), multiplies{}); -+ auto offset = fold(extent_diff, Int<0>{}, [](auto const& i, auto const& a) { return i + cute::min(a,Int<0>{}); }); -+ -+ return make_tensor(recast(std::forward(tensor).data() + offset), new_layout); -+ } else { -+ return make_tensor(recast(std::forward(tensor).data() ), new_layout); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+recast(Tensor&& tensor) -+{ -+ return recast(std::forward(tensor), type_list{}); -+} -+ -+// -+// max_common_vector -+// -+ -+/* Return Int such that N is the maximum number of continguous elements -+ * that logically correspond in the tensors of @a a and @a b. This is, -+ * the number of elements that could reasonably be vectorized into a single load/store. -+ * -+ * @returns Int with N >= 0 -+ * -+ * A return value of Int<0> indicates that no such conclusion can be made and no -+ * vectorization should be attempted. 
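 * Roughly speaking: if both tensors hold the same trivially copyable value type and
 * their layouts agree on a static stride-1 inner mode of 8 elements, the result is at
 * least Int<8>, so a copy between them can be issued as 8-element vector accesses;
 * dynamic strides typically collapse the result to Int<1>, and mismatched value-type
 * sizes or non-reference iterators collapse it to Int<0>.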
-+ */ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+max_common_vector(Tensor const& a, -+ Tensor const& b) -+{ -+ using SrcType = typename Tensor::value_type; -+ using DstType = typename Tensor::value_type; -+ -+ using SrcRef = decltype(*(a.data())); -+ using DstRef = decltype(*(b.data())); -+ -+ // Determine if vectorization candidates at all -+ if constexpr (// Should be the same value_types, else the copy is also performing a cast -+ sizeof(SrcType) == sizeof(DstType) && -+ // The types should be trivially copyable so that vectorization is valid -+ std::is_trivially_copyable::value && -+ std::is_trivially_copyable::value && -+ // Should be load/storing real data, rather than implicit iterators or such -+ std::is_reference::value && -+ std::is_reference::value) -+ { -+ return max_common_vector(a.layout(), b.layout()); -+ } else { -+ return Int<0>{}; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Key algebraic operations -+// -+ -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+logical_divide(Tensor && tensor, -+ Tile const& tile) -+{ -+ return make_tensor(std::forward(tensor).data(), -+ logical_divide(tensor.layout(), tile)); -+} -+ -+// zipped_divide is logical_divide with modes gathered into standard form ((BLK_A,BLK_B),(a,b)) -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+zipped_divide(Tensor && tensor, -+ Tile const& tile) // Layout or Tile -+{ -+ return make_tensor(std::forward(tensor).data(), -+ zipped_divide(tensor.layout(), tile)); -+} -+ -+// tiled_divide is logical_divide with the second output mode flattened ((BLK_A,BLK_B),a,b) -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+tiled_divide(Tensor && tensor, -+ Tile const& tile) // Layout or Tile -+{ -+ return make_tensor(std::forward(tensor).data(), -+ tiled_divide(tensor.layout(), tile)); -+} -+ -+// logical_product on a Tensor doesn't make sense since it often increases cosize -+ -+// -+// Logicial Divide utilities: local_partition and local_tile -+// -+ -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+local_partition(Tensor && tensor, -+ Tile const& tile, -+ Coord const& coord) -+{ -+ constexpr int R1 = decltype(rank(tensor))::value; -+ -+ // Split the modes of tensor according to the modes of tile -+ // zipped_divide returns something like ((VEC_A,VEC_B,...),(a,b,...)) -+ -+ // The_coord is the coord into the first mode, flatten the rest -+ return zipped_divide(std::forward(tensor), tile)(coord, repeat(_)); -+} -+ -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+local_partition(Tensor && tensor, -+ Tile const& tile, -+ Coord const& coord, -+ Projection const& proj) -+{ -+ return local_partition(std::forward(tensor), -+ dice(proj, tile), -+ dice(proj, coord)); -+} -+ -+// Special case with Layout and Integral that extracts the coord first -+// e.g. local_partition(tensor, ThrLayout, threadIdx.x) -+template >::value && -+ is_integral::value)> -+CUTE_HOST_DEVICE -+auto -+local_partition(Tensor && tensor, -+ Layout const& tile, -+ Index const& index) -+{ -+ return local_partition(std::forward(tensor), -+ product_each(shape(tile)), -+ tile.get_flat_coord(index)); -+} -+ -+// Special case with Layout and Integral that extracts the coord first -+// e.g. 
local_partition(tensor, ThrLayout, threadIdx.x, Step<_1,X,_1>{}) -+template >::value && -+ is_integral::value)> -+CUTE_HOST_DEVICE -+auto -+local_partition(Tensor && tensor, -+ Layout const& tile, -+ Index const& index, -+ Projection const& proj) -+{ -+ return local_partition(std::forward(tensor), -+ dice(proj, product_each(shape(tile))), -+ dice(proj, tile).get_flat_coord(index)); -+} -+ -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+local_tile(Tensor && tensor, -+ Tile const& tile, -+ Coord const& coord) -+{ -+ constexpr int R0 = decltype(rank(tile))::value; -+ constexpr int R1 = decltype(rank(tensor))::value; -+ -+ // Split the modes of tensor according to the modes of tile -+ // zipped_divide returns something like ((VEC_A,VEC_B,...),(a,b,...)) -+ -+ // The padded_coord is the coord into the second mode, flatten the rest -+ return zipped_divide(std::forward(tensor), tile)(repeat(_), append(coord,_)); -+} -+ -+template >::value)> -+CUTE_HOST_DEVICE -+auto -+local_tile(Tensor && tensor, -+ Tile const& tile, -+ Coord const& coord, -+ Proj const& proj) -+{ -+ return local_tile(std::forward(tensor), -+ dice(proj, tile), -+ dice(proj, coord)); -+} -+ -+// -+// Display utilities -+// -+ -+template -+CUTE_HOST_DEVICE void print_tensor(Tensor const& tensor) -+{ -+ auto format = get_format(tensor(0)); -+ using type = typename decltype(format)::type; -+ -+ if constexpr (Layout::rank == 1) -+ { -+ for (int m = 0; m < size(tensor); ++m) { -+ printf(format.format, format.digits, type(tensor(m))); -+ printf("\n"); -+ } -+ } else -+ if constexpr (Layout::rank == 2) -+ { -+ for (int m = 0; m < size<0>(tensor); ++m) { -+ for (int n = 0; n < size<1>(tensor); ++n) { -+ printf(format.format, format.digits, type(tensor(m,n))); -+ } -+ printf("\n"); -+ } -+ } else -+ if constexpr (Layout::rank == 3) -+ { -+ print_tensor(tensor(_,_,0)); -+ for (int k = 1; k < size<2>(tensor); ++k) { -+ for (int i = 0; i < format.digits*size<1>(tensor); ++i) { print("-"); } print("\n"); -+ print_tensor(tensor(_,_,k)); -+ } -+ } else -+ if constexpr (Layout::rank == 4) -+ { -+ print_tensor(tensor(_,_,_,0)); -+ for (int p = 1; p < size<3>(tensor); ++p) { -+ for (int i = 0; i < format.digits*size<1>(tensor); ++i) { print("="); } print("\n"); -+ print_tensor(tensor(_,_,_,p)); -+ } -+ } -+} -+ -+template -+CUTE_HOST_DEVICE void print(Tensor const& tensor) -+{ -+ print(tensor.layout()); print("\n"); -+ print_tensor(tensor); -+} -+ -+template -+CUTE_HOST std::ostream& print_tensor_os(std::ostream& os, Tensor const& tensor) -+{ -+ int digits = 9; -+ -+ if constexpr (Layout::rank == 1) -+ { -+ for (int m = 0; m < size(tensor); ++m) { -+ os << std::setw(digits) << tensor(m) << std::endl; -+ } -+ } else -+ if constexpr (Layout::rank == 2) -+ { -+ for (int m = 0; m < size<0>(tensor); ++m) { -+ for (int n = 0; n < size<1>(tensor); ++n) { -+ os << std::setw(digits) << tensor(m,n); -+ } -+ os << std::endl; -+ } -+ } else -+ if constexpr (Layout::rank == 3) -+ { -+ print_tensor_os(os, tensor(_,_,0)); -+ for (int k = 1; k < size<2>(tensor); ++k) { -+ for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "-"; } os << std::endl; -+ print_tensor_os(os, tensor(_,_,k)); -+ } -+ } else -+ if constexpr (Layout::rank == 4) -+ { -+ print_tensor_os(os, tensor(_,_,_,0)); -+ for (int p = 1; p < size<3>(tensor); ++p) { -+ for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "="; } os << std::endl; -+ print_tensor_os(os, tensor(_,_,_,p)); -+ } -+ } -+ -+ return os; -+} -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, 
Tensor const& tensor) -+{ -+ os << tensor.layout() << std::endl; -+ return print_tensor_os(os, tensor); -+} -+ -+} // end namespace cute -+ -+// -+// Extended Engines -+// -+ -+#include -+ -+// -+// Tensor Algorithms -+// -+ -+#include -+#include -+#include -+#include -+#include -+#include -diff --git a/3rdparty/cutlass/include/cute/tensor_predicate.hpp b/3rdparty/cutlass/include/cute/tensor_predicate.hpp -new file mode 100644 -index 0000000..730f219 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/tensor_predicate.hpp -@@ -0,0 +1,63 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+template -+struct ConstantTensor -+{ -+ template -+ CUTE_HOST_DEVICE constexpr -+ T const& -+ operator()(Coords const&...) const { -+ return val_; -+ } -+ -+ T val_; -+}; -+ -+struct TrivialPredTensor -+{ -+ template -+ CUTE_HOST_DEVICE constexpr -+ true_type -+ operator()(Coords const&...) const { -+ return {}; -+ } -+}; -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/tile.hpp b/3rdparty/cutlass/include/cute/tile.hpp -new file mode 100644 -index 0000000..b2fa2e8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/tile.hpp -@@ -0,0 +1,58 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+// -+// A Tile is not a Layout, it's a tuple of Layouts or Tiles or Underscores -+// -+ -+template -+using Tile = tuple; -+ -+template -+using is_tile = is_tuple; -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_tile(Layouts const&... layouts) -+{ -+ return Tile(layouts...); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/underscore.hpp b/3rdparty/cutlass/include/cute/underscore.hpp -new file mode 100644 -index 0000000..d79b4ee ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/underscore.hpp -@@ -0,0 +1,148 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+#include -+#include -+ -+namespace cute -+{ -+ -+// For slicing -+struct Underscore : Int<0> {}; -+ -+CUTE_INLINE_CONSTANT Underscore _; -+ -+// Treat Underscore as an integral like integral_constant -+template <> -+struct is_integral : true_type {}; -+ -+template -+struct is_underscore : false_type {}; -+template <> -+struct is_underscore : true_type {}; -+ -+// Tuple trait for detecting static member element -+template -+struct has_elem : false_type {}; -+template -+struct has_elem : true_type {}; -+template -+struct has_elem::value> > -+ : has_elem > {}; -+template -+struct has_elem> -+ : disjunction, Elem>...> {}; -+ -+// Tuple trait for detecting static member element -+template -+struct all_elem : false_type {}; -+template -+struct all_elem : true_type {}; -+template -+struct all_elem::value> > -+ : all_elem > {}; -+template -+struct all_elem> -+ : conjunction, Elem>...> {}; -+ -+// Tuple trait for detecting Underscore member -+template -+using has_underscore = has_elem; -+ -+template -+using all_underscore = all_elem; -+ -+template -+using has_int1 = has_elem>; -+ -+template -+using has_int0 = has_elem>; -+ -+// -+// Slice keeps only the elements of Tuple B that are paired with an Underscore -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+slice(A const& a, B const& b) -+{ -+ if constexpr (is_tuple::value) { -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); -+ return filter_tuple(a, b, [](auto const& x, auto const& y) { return slice(x,y); }); -+ } else if constexpr (is_underscore::value) { -+ return cute::tuple{b}; -+ } else { -+ return cute::tuple<>{}; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Dice keeps only the elements of Tuple B that are paired with an Int -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+dice(A const& a, B const& b) -+{ -+ if constexpr (is_tuple::value) { -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); -+ return filter_tuple(a, b, [](auto const& x, auto const& y) { return dice(x,y); }); -+ } else if constexpr (is_underscore::value) { -+ return cute::tuple<>{}; -+ } else { -+ return cute::tuple{b}; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Display utilities -+// -+ -+CUTE_HOST_DEVICE void print(Underscore const&) { -+ printf("_"); -+} -+ -+CUTE_HOST std::ostream& operator<<(std::ostream& os, Underscore const&) { -+ return os << "_"; -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/util/debug.hpp b/3rdparty/cutlass/include/cute/util/debug.hpp -new file mode 100644 -index 0000000..9a62143 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/util/debug.hpp -@@ -0,0 +1,153 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+/** -+ * \file -+ * \brief Debugging and logging functionality -+ */ -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+/****************************************************************************** -+ * Debug and logging macros -+ ******************************************************************************/ -+ -+/** -+ * Formats and prints the given message to stdout -+ */ -+#if !defined(CUTE_LOG) -+# if !defined(__CUDA_ARCH__) -+# define CUTE_LOG(format, ...) printf(format, __VA_ARGS__) -+# else -+# define CUTE_LOG(format, ...) \ -+ printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \ -+ blockIdx.x, blockIdx.y, blockIdx.z, \ -+ threadIdx.x, threadIdx.y, threadIdx.z, \ -+ __VA_ARGS__); -+# endif -+#endif -+ -+/** -+ * Formats and prints the given message to stdout only if DEBUG is defined -+ */ -+#if !defined(CUTE_LOG_DEBUG) -+# ifdef DEBUG -+# define CUTE_LOG_DEBUG(format, ...) CUTE_LOG(format, __VA_ARGS__) -+# else -+# define CUTE_LOG_DEBUG(format, ...) 
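// When DEBUG is not defined, CUTE_LOG_DEBUG(...) expands to nothing, so both the
// formatting work and the evaluation of its arguments disappear from release builds.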
-+# endif -+#endif -+ -+/** -+ * \brief Perror macro with exit -+ */ -+#if !defined(CUTE_ERROR_EXIT) -+# define CUTE_ERROR_EXIT(e) \ -+ do { \ -+ cudaError_t code = (e); \ -+ if (code != cudaSuccess) { \ -+ fprintf(stderr, "<%s:%d> %s:\n %s: %s\n", \ -+ __FILE__, __LINE__, #e, \ -+ cudaGetErrorName(code), cudaGetErrorString(code)); \ -+ fflush(stderr); \ -+ exit(0); \ -+ } \ -+ } while (0) -+#endif -+ -+#if !defined(CUTE_CHECK_LAST) -+# define CUTE_CHECK_LAST() CUTE_ERROR_EXIT(cudaPeekAtLastError()); CUTE_ERROR_EXIT(cudaDeviceSynchronize()) -+#endif -+ -+#if !defined(CUTE_CHECK_ERROR) -+# define CUTE_CHECK_ERROR(e) CUTE_ERROR_EXIT(e) -+#endif -+ -+// A dummy function that uses compilation failure to print a type -+template -+CUTE_HOST_DEVICE -+void -+print_type(T&&) { -+ static_assert(sizeof(T) < 0, "Printing type T."); -+} -+ -+// -+// Device-specific helpers -+// -+// e.g. -+// if (thread0()) print(...); -+// if (block0()) print(...); -+// if (thread(42)) print(...); -+ -+CUTE_HOST_DEVICE -+bool -+thread(int tid, int bid) -+{ -+#if defined(__CUDA_ARCH__) -+ return (threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.x*blockDim.y == tid) -+ && ( blockIdx.x + blockIdx.y* gridDim.x + blockIdx.z* gridDim.x* gridDim.y == bid); -+#else -+ return true; -+#endif -+} -+ -+CUTE_HOST_DEVICE -+bool -+thread(int tid) -+{ -+ return thread(tid, 0); -+} -+ -+CUTE_HOST_DEVICE -+bool -+thread0() -+{ -+ return thread(0,0); -+} -+ -+CUTE_HOST_DEVICE -+bool -+block0() -+{ -+#if defined(__CUDA_ARCH__) -+ return !(blockIdx.x | blockIdx.y | blockIdx.z); -+#else -+ return true; -+#endif -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/util/print.hpp b/3rdparty/cutlass/include/cute/util/print.hpp -new file mode 100644 -index 0000000..ec774b0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/util/print.hpp -@@ -0,0 +1,140 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+// -+// CUDA compatible print and printf -+// -+ -+namespace cute -+{ -+ -+CUTE_HOST_DEVICE -+int -+num_digits(int x) -+{ -+ return (x < 10 ? 1 : -+ (x < 100 ? 2 : -+ (x < 1000 ? 3 : -+ (x < 10000 ? 4 : -+ (x < 100000 ? 5 : -+ (x < 1000000 ? 6 : -+ (x < 10000000 ? 7 : -+ (x < 100000000 ? 8 : -+ (x < 1000000000 ? 9 : -+ 10))))))))); -+} -+ -+template -+struct format_and_size { -+ using type = T; -+ char const* format; -+ int digits; -+}; -+ -+CUTE_HOST_DEVICE -+format_and_size -+get_format(bool) { -+ return {"%*d", 3}; -+} -+ -+CUTE_HOST_DEVICE -+format_and_size -+get_format(int32_t) { -+ return {"%*d", 5}; -+} -+ -+CUTE_HOST_DEVICE -+format_and_size -+get_format(uint32_t) { -+ return {"%*d", 5}; -+} -+ -+CUTE_HOST_DEVICE -+format_and_size -+get_format(int64_t) { -+ return {"%*d", 5}; -+} -+ -+CUTE_HOST_DEVICE -+format_and_size -+get_format(uint64_t) { -+ return {"%*d", 5}; -+} -+ -+CUTE_HOST_DEVICE -+format_and_size -+get_format(half_t) { -+ return {"%*.2f", 8}; -+} -+ -+CUTE_HOST_DEVICE -+format_and_size -+get_format(float) { -+ return {"%*.2e", 10}; -+} -+ -+CUTE_HOST_DEVICE -+format_and_size -+get_format(double) { -+ return {"%*.3e", 11}; -+} -+ -+// -+// print dispatcher -+// -+ -+CUTE_HOST_DEVICE -+void -+print(char const& c) { -+ printf("%c", c); -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE -+void -+print(T const& a) { -+ printf("%d", int(a)); -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+print(char const* format, T const&... t) { -+ printf(format, t...); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/util/type_traits.hpp b/3rdparty/cutlass/include/cute/util/type_traits.hpp -new file mode 100644 -index 0000000..4d37eb9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/util/type_traits.hpp -@@ -0,0 +1,101 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+#define __CUTE_REQUIRES(...) typename std::enable_if<(__VA_ARGS__)>::type* = nullptr -+#define __CUTE_REQUIRES_V(...) typename std::enable_if::type* = nullptr -+ -+namespace cute -+{ -+ -+using std::conjunction; -+using std::conjunction_v; -+ -+using std::disjunction; -+using std::disjunction_v; -+ -+using std::negation; -+using std::negation_v; -+ -+using std::void_t; -+ -+// C++20 -+// using std::remove_cvref; -+template -+struct remove_cvref { -+ using type = std::remove_cv_t>; -+}; -+ -+// C++20 -+// using std::remove_cvref_t; -+template -+using remove_cvref_t = typename remove_cvref::type; -+ -+// -+// is_valid -+// -+ -+namespace detail { -+ -+template ()(std::declval()...))> -+CUTE_HOST_DEVICE constexpr auto -+is_valid_impl(int) { return std::true_type{}; } -+ -+template -+CUTE_HOST_DEVICE constexpr auto -+is_valid_impl(...) { return std::false_type{}; } -+ -+template -+struct is_valid_fn { -+ template -+ CUTE_HOST_DEVICE constexpr auto -+ operator()(Args&&...) const { return is_valid_impl(int{}); } -+}; -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr auto -+is_valid(F&&) { -+ return detail::is_valid_fn{}; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr auto -+is_valid(F&&, Args&&...) { -+ return detail::is_valid_impl(int{}); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cutlass/aligned_buffer.h b/3rdparty/cutlass/include/cutlass/aligned_buffer.h -new file mode 100644 -index 0000000..1b29277 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/aligned_buffer.h -@@ -0,0 +1,129 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief AlignedBuffer is a container for trivially copyable elements suitable for use in -+ unions and shared memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Modifies semantics of cutlass::Array<> to provide guaranteed alignment. -+template < -+ typename T, -+ int N, -+ int Align = 16 -+> -+struct AlignedBuffer { -+ -+ /// Internal storage type -+ using Storage = uint8_t; -+ -+ /// Number of logical elements held in buffer -+ static int const kCount = N; -+ -+ /// Alignment requirement in bytes -+ static int const kAlign = Align; -+ -+ /// Number of storage elements -+ static int const kBytes = -+ (sizeof_bits::value * N + 7) / 8; -+ -+private: -+ -+ /// Internal storage -+ alignas(Align) Storage storage[kBytes]; -+ -+public: -+ -+ // -+ // C++ standard members -+ // -+ -+ typedef T value_type; -+ typedef size_t size_type; -+ typedef ptrdiff_t difference_type; -+ typedef value_type *pointer; -+ typedef value_type const * const_pointer; -+ -+ using Array = Array; -+ using reference = typename Array::reference; -+ using const_reference = typename Array::const_reference; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ pointer data() { -+ return reinterpret_cast(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_pointer data() const { -+ return reinterpret_cast(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Storage * raw_data() { -+ return storage; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Storage const * raw_data() const { -+ return storage; -+ } -+ -+ -+ CUTLASS_HOST_DEVICE -+ constexpr bool empty() const { -+ return !kCount; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ constexpr size_type size() const { -+ return kCount; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ constexpr size_type max_size() const { -+ return kCount; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/arch/arch.h b/3rdparty/cutlass/include/cutlass/arch/arch.h -new file mode 100644 -index 0000000..043bfac ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/arch.h -@@ -0,0 +1,107 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines tags for architecture-specific configurations. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace arch { -+ -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) -+ -+/// Computes laneId within a warp -+CUTLASS_DEVICE -+int LaneId() { -+ int ret; -+ asm ("mov.u32 %0, %%laneid;" : "=r"(ret) : ); -+ return ret; -+} -+ -+/// Computes SM number the thread is running on -+CUTLASS_DEVICE -+int SmId() { -+ int ret; -+ asm ("mov.u32 %0, %%smid;" : "=r"(ret) : ); -+ return ret; -+} -+ -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+struct Sm50 { -+ static int const kMinComputeCapability = 50; -+}; -+struct Sm60 { -+ static int const kMinComputeCapability = 60; -+}; -+struct Sm61 { -+ static int const kMinComputeCapability = 61; -+}; -+struct Sm70 { -+ static int const kMinComputeCapability = 70; -+}; -+struct Sm72 { -+ static int const kMinComputeCapability = 72; -+}; -+struct Sm75 { -+ static int const kMinComputeCapability = 75; -+}; -+struct Sm80 { -+ static int const kMinComputeCapability = 80; -+}; -+struct Sm86 { -+ static int const kMinComputeCapability = 86; -+}; -+ -+struct Sm90 { -+ static int const kMinComputeCapability = 90; -+}; -+ -+/// Triggers a breakpoint on the device -+CUTLASS_DEVICE -+void device_breakpoint() { -+#if defined(__CUDA_ARCH__) -+ asm volatile (" brkpt;\n"); -+#endif -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git 
a/3rdparty/cutlass/include/cutlass/arch/barrier.h b/3rdparty/cutlass/include/cutlass/arch/barrier.h -new file mode 100644 -index 0000000..34f0b4e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/barrier.h -@@ -0,0 +1,404 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2011-2019, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Redistribution and use in source and binary forms, with or without modification, are not permit- -+ * ted. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR -+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND -+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; -+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Barrier Operations on SM90+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+ -+namespace cutlass { -+/// @brief -+namespace arch { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && (__CUDACC_VER_MAJOR__ >= 12) -+#define CUDA_BARRIER_ENABLED 1 -+#else -+#define CUDA_BARRIER_ENABLED 0 -+#endif -+ -+class NamedBarrier { -+ -+ // Data Members: -+ -+ // Range = [1 , NUM_THREADS_PER_CTA] -+ // Range % warp-size (i.e 32) == 0 -+ uint32_t const num_threads_; -+ -+ // Range : [0, 15] -+ uint32_t const id_; -+ -+ public: -+ -+ CUTLASS_DEVICE -+ NamedBarrier(uint32_t num_threads, uint32_t id = 0) -+ : num_threads_(num_threads), id_(id) {} -+ -+ CUTLASS_DEVICE -+ void arrive_and_wait() const { -+ NamedBarrier::arrive_and_wait(num_threads_, id_); -+ } -+ -+ CUTLASS_DEVICE -+ void arrive() const { -+ NamedBarrier::arrive(num_threads_, id_); -+ } -+ -+ CUTLASS_DEVICE -+ void sync() const { -+ NamedBarrier::arrive_and_wait(); -+ } -+ -+ // Static variants -+ CUTLASS_DEVICE -+ static void arrive_and_wait(uint32_t num_threads, uint32_t barrier_id) { -+#if CUDA_BARRIER_ENABLED -+ asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ } -+ -+ CUTLASS_DEVICE -+ static void arrive(uint32_t num_threads, uint32_t barrier_id) { -+#if CUDA_BARRIER_ENABLED -+ asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ } -+ -+ CUTLASS_DEVICE -+ static void sync(uint32_t num_threads, uint32_t barrier_id) { -+ NamedBarrier::arrive_and_wait(num_threads, barrier_id); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Hopper introduces a new cluster-wide barrier which handle with Cluster-wide AW behaviour. -+// This is an extension to the Ampere AW barriers -+// Note : Ampere AW Barriers have a larger max-arrive count (2^30) than Hopper AW Barriers (2^20). 
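[Editor's illustration, not part of the patched sources: a minimal usage sketch of the ClusterBarrier defined just below, driven through its static member functions on a raw 64-bit word in shared memory. The kernel name, the per-block arrive count, and the single-thread init are assumptions made for this example. Subsequent rounds reuse the same barrier with the phase bit flipped.]

#include <cstdint>
#include "cutlass/arch/barrier.h"

__global__ void cluster_barrier_sketch() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // Back the barrier with a raw 64-bit word in shared memory (assumption for the sketch).
  __shared__ uint64_t storage;
  if (threadIdx.x == 0) {
    // Expect one arrival from every thread in the block per phase.
    cutlass::arch::ClusterBarrier::init(&storage, blockDim.x);
  }
  // Make the init visible before any thread arrives on the barrier.
  cutlass::arch::fence_barrier_init();
  __syncthreads();
  cutlass::arch::ClusterBarrier::arrive(&storage);             // each thread arrives once
  cutlass::arch::ClusterBarrier::wait(&storage, /*phase=*/0);  // spin until phase 0 completes
#endif
}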
-+struct ClusterBarrier { -+ -+ using ValueType = uint64_t; -+ -+protected: -+ // Can never be initializated - can only be aliased to smem -+ ValueType barrier_; -+ -+public: -+ -+ CUTLASS_DEVICE -+ ClusterBarrier() = delete; -+ -+ CUTLASS_DEVICE -+ void init(uint32_t arrive_count) const { -+ ClusterBarrier::init(&this->barrier_, arrive_count); -+ } -+ -+ CUTLASS_DEVICE -+ uint32_t test_wait(uint32_t phase, uint32_t pred=true) const { -+ return ClusterBarrier::test_wait(&this->barrier_, phase, pred); -+ } -+ -+ CUTLASS_DEVICE -+ void wait(uint32_t phase) const { -+ ClusterBarrier::wait(&this->barrier_, phase); -+ } -+ -+ // Barrier arrive on local smem -+ CUTLASS_DEVICE -+ void arrive() const { -+ ClusterBarrier::arrive(&this->barrier_); -+ } -+ -+ // Remote SMEM arrive with a perdicate (usually done to pick the thread doing the arrive) -+ CUTLASS_DEVICE -+ void arrive(uint32_t cta_id, uint32_t pred = true ) const { -+ ClusterBarrier::arrive(&this->barrier_, cta_id, pred); -+ } -+ -+ // -+ // Static Versions -+ // -+ CUTLASS_DEVICE -+ static void init(ValueType const* smem_ptr, uint32_t arrive_count) { -+#if CUDA_BARRIER_ENABLED -+ uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile( -+ "{\n\t" -+ "mbarrier.init.shared.b64 [%1], %0; \n" -+ "}" -+ : -+ : "r"(arrive_count), "r"(smem_addr)); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ } -+ -+ // Static version of wait - in case we don't want to burn a register -+ CUTLASS_DEVICE -+ static void wait(ValueType const* smem_ptr, uint32_t phase) { -+#if CUDA_BARRIER_ENABLED -+ uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); -+ // Arbitrarily large timer value after which try-wait expires and re-tries. -+ uint32_t ticks = 0x989680; -+ asm volatile( -+ "{\n\t" -+ ".reg .pred P1; \n\t" -+ "LAB_WAIT: \n\t" -+ "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t" -+ "@P1 bra.uni DONE; \n\t" -+ "bra.uni LAB_WAIT; \n\t" -+ "DONE: \n\t" -+ "}" -+ : -+ : "r"(smem_addr), "r"(phase), "r"(ticks)); -+ -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ } -+ -+ CUTLASS_DEVICE -+ static uint32_t test_wait(ValueType const* smem_ptr, uint32_t phase, uint32_t pred) { -+#if CUDA_BARRIER_ENABLED -+ uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); -+ uint32_t waitComplete; -+ -+ asm volatile( -+ "{\n\t" -+ ".reg .pred P1; \n\t" -+ ".reg .pred P2; \n\t" -+ "setp.eq.u32 P2, %3, 1;\n\t" -+ "@P2 mbarrier.test_wait.parity.shared.b64 P1, [%1], %2; \n\t" -+ "selp.b32 %0, 1, 0, P1; \n\t" -+ "}" -+ : "=r"(waitComplete) -+ : "r"(smem_addr), "r"(phase), "r"(pred)); -+ -+ return waitComplete; -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ return 0; -+ } -+ -+ // Static Predicated version of the above - in case we know the address. 
-+ CUTLASS_DEVICE -+ static void arrive(ValueType const* smem_ptr, uint32_t cta_id, uint32_t pred) { -+#if CUDA_BARRIER_ENABLED -+ uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile( -+ "{\n\t" -+ ".reg .pred p;\n\t" -+ ".reg .b32 remAddr32;\n\t" -+ "setp.eq.u32 p, %2, 1;\n\t" -+ "@p mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t" -+ "@p mbarrier.arrive.shared::cluster.b64 _, [remAddr32];\n\t" -+ "}" -+ : -+ : "r"(smem_addr), "r"(cta_id), "r"(pred)); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ } -+ -+ // Barrier arrive on local smem -+ CUTLASS_DEVICE -+ static void arrive(ValueType const* smem_ptr) { -+#if CUDA_BARRIER_ENABLED -+ uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); -+ uint64_t state = 0; -+ asm volatile( -+ "{\n\t" -+ "mbarrier.arrive.shared.b64 %1, [%0];\n\t" -+ "}" -+ : -+ : "r"(smem_addr), "l"(state)); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ } -+ -+ CUTLASS_DEVICE -+ static void invalidate(ValueType const* smem_ptr) { -+#if CUDA_BARRIER_ENABLED -+ uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile( -+ "{\n\t" -+ "mbarrier.ival.shared.b64 [%0]; \n\t" -+ "}" -+ : -+ : "r"(smem_addr)); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// SM90 also introduces a new type of cluster-barrier which supports sync. -+// not just based on Arrive Count, but also transaction count (in bytes) -+struct ClusterTransactionBarrier : public ClusterBarrier { -+ -+ CUTLASS_DEVICE -+ ClusterTransactionBarrier() = delete; -+ -+ // Performs an arrive operation + bytes reset -+ CUTLASS_DEVICE -+ void arrive_and_reset_bytes(uint32_t transaction_bytes) const { -+ ClusterTransactionBarrier::arrive_and_reset_bytes(&this->barrier_, transaction_bytes); -+ } -+ -+ // Performs an arrive operation + bytes reset -+ CUTLASS_DEVICE -+ void arrive_and_reset_bytes(uint32_t transaction_bytes, uint32_t cta_id) const { -+ ClusterTransactionBarrier::arrive_and_reset_bytes(&this->barrier_, transaction_bytes , cta_id, true); -+ } -+ -+ CUTLASS_DEVICE -+ void commit(uint32_t transaction_bytes, uint32_t pred = 1) const { -+ uint32_t cta_rank = cute::block_rank_in_cluster(); -+ ClusterTransactionBarrier::commit(&this->barrier_, cta_rank, transaction_bytes, pred); -+ } -+ -+ CUTLASS_DEVICE -+ void commit(uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred) const { -+ ClusterTransactionBarrier::commit(&this->barrier_, dst_cta_id, transaction_bytes, pred); -+ } -+ -+ // -+ // Static Versions -+ // -+ -+ // Performs an arrive operation + bytes reset -+ CUTLASS_DEVICE -+ static void arrive_and_reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes) { -+#if CUDA_BARRIER_ENABLED -+ uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile( -+ "{\n\t" -+ "mbarrier.arrive.expect_tx.shared.b64 _, [%1], %0; \n\t" -+ "}" -+ : -+ : "r"(transaction_bytes), "r"(smem_addr)); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ } -+ -+ // Performs an arrive operation + bytes reset for a remote cta_id in a Cluster -+ CUTLASS_DEVICE -+ static void arrive_and_reset_bytes( -+ ValueType const* smem_ptr, uint32_t transaction_bytes, uint32_t cta_id, uint32_t pred) { -+#if CUDA_BARRIER_ENABLED -+ uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile( -+ "{\n\t" -+ ".reg .pred p;\n\t" -+ ".reg .b32 remAddr32;\n\t" -+ "setp.eq.u32 p, %2, 1;\n\t" -+ "@p mapa.shared::cluster.u32 
remAddr32, %0, %1;\n\t" -+ "@p mbarrier.arrive.expect_tx.shared::cluster.b64 _, [remAddr32], %3;\n\t" -+ "}" -+ : -+ : "r"(smem_addr), "r"(cta_id), "r"(pred), "r"(transaction_bytes)); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ } -+ -+ // Performs an bytes reset without doing an arrive operation -+ CUTLASS_DEVICE -+ static void reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes) { -+#if CUDA_BARRIER_ENABLED -+ uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile( -+ "{\n\t" -+ "mbarrier.expect_tx.shared.b64 [%1], %0; \n\t" -+ "}" -+ : -+ : "r"(transaction_bytes), "r"(smem_addr)); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ } -+ -+ // Increments transaction bytes in the barrier -+ CUTLASS_DEVICE -+ static void commit( -+ ValueType const* smem_ptr, uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred = 1) { -+#if CUDA_BARRIER_ENABLED -+ uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); -+ smem_addr = cute::set_block_rank(smem_addr, dst_cta_id); -+ asm volatile( -+ "{\n\t" -+ ".reg .pred p;\n\t" -+ "setp.eq.u32 p, %2, 1;\n\t" -+ "@p mbarrier.complete_tx.shared::cluster.relaxed.cluster.b64 [%1], %0;" -+ "}" -+ : -+ : "r"(transaction_bytes), "r"(smem_addr), "r"(pred)); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ } -+}; -+ -+// Helps with visibility of barrier init operations across warps / cta / cluster -+// Available as a separate function so as to batch inits across barriers and fence once -+// Note : It must be composed with an appropriate sync instruction with the right scope -+// to ensure visibility eg. __syncthreads() or a cluster_arrive() + cluster_wait() -+CUTLASS_DEVICE -+void fence_barrier_init() { -+#if CUDA_BARRIER_ENABLED -+ asm volatile( -+ "{\n\t" -+ "fence.mbarrier_init.release.cluster; \n" -+ "}" -+ ::); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+} -+ -+// Issue a shared memory fence for async operations -+CUTLASS_DEVICE -+void fence_view_async_shared() { -+#if CUDA_BARRIER_ENABLED -+ asm volatile ( -+ "{\n\t" -+ "fence.proxy.async.shared::cta; \n" -+ "}" -+ ::); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+} // end namespace arch -+} // end namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/arch/cache_operation.h b/3rdparty/cutlass/include/cutlass/arch/cache_operation.h -new file mode 100644 -index 0000000..fa70c4c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/cache_operation.h -@@ -0,0 +1,66 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Directives related to cache operations -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+namespace arch { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Controls PTX cache operations -+struct CacheOperation { -+ enum Kind { -+ /// Cache at all levels - accessed again -+ Always, -+ /// Cache at global level -+ Global, -+ /// Streaming - likely to be accessed once -+ Streaming, -+ /// Indicates the line will not be used again -+ LastUse, -+ /// Don't cache, and fetch again -+ Volatile, -+ /// Write back at all coherent levels -+ WriteBack, -+ /// Write through to system memory -+ WriteThrough -+ }; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/arch/memory.h b/3rdparty/cutlass/include/cutlass/arch/memory.h -new file mode 100644 -index 0000000..b2a9468 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/memory.h -@@ -0,0 +1,474 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Architecture-specific operators on memory -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+namespace arch { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Fragment type to store loaded data -+ typename AccessType, -+ /// The bytes of loading -+ int LoadBytes -+ > -+struct global_load; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Specializations -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \ -+ (__CUDACC_VER_MAJOR__ > 11)) && \ -+ defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && \ -+ ! (defined(__clang__) && defined(__CUDA__)) -+ #define CUTLASS_ENABLE_L2_PREFETCH 1 -+#else -+ #define CUTLASS_ENABLE_L2_PREFETCH 0 -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// The redundant mov PTX instruction is used to enforce the compiler to -+// keep the initializing code before ld.global -+template -+struct global_load { -+ CUTLASS_DEVICE -+ global_load(AccessType &D, void const *ptr, bool pred_guard) { -+ uint4 *data = reinterpret_cast(&D); -+ -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %9, 0;\n" -+ " mov.b32 %0, %10;\n" -+ " mov.b32 %1, %11;\n" -+ " mov.b32 %2, %12;\n" -+ " mov.b32 %3, %13;\n" -+ " mov.b32 %4, %14;\n" -+ " mov.b32 %5, %15;\n" -+ " mov.b32 %6, %16;\n" -+ " mov.b32 %7, %17;\n" -+#if CUTLASS_ENABLE_L2_PREFETCH -+ " @p ld.global.L2::128B.v4.u32 {%0, %1, %2, %3}, [%8];\n" -+ " @p ld.global.L2::128B.v4.u32 {%4, %5, %6, %7}, [%18];\n" -+#else -+ " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n" -+ " @p ld.global.v4.u32 {%4, %5, %6, %7}, [%18];\n" -+#endif -+ "}\n" -+ : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w), -+ "=r"(data[1].x), "=r"(data[1].y), "=r"(data[1].z), "=r"(data[1].w) -+ : "l"(ptr), "r"((int)pred_guard), "r"(data[0].x), "r"(data[0].y), -+ "r"(data[0].z), "r"(data[0].w), "r"(data[1].x), "r"(data[1].y), -+ "r"(data[1].z), "r"(data[1].w), "l"(((uint8_t *)ptr) + 16)); -+ } -+}; -+ -+template -+struct global_load { -+ CUTLASS_DEVICE -+ global_load(AccessType &D, void const *ptr, bool pred_guard) { -+ uint4 &data = reinterpret_cast(D); -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %5, 0;\n" -+ " mov.b32 %0, %6;\n" -+ " mov.b32 %1, %7;\n" -+ " mov.b32 %2, %8;\n" -+ " mov.b32 %3, %9;\n" -+#if CUTLASS_ENABLE_L2_PREFETCH -+ " @p ld.global.L2::128B.v4.u32 {%0, %1, %2, %3}, [%4];\n" -+#else -+ " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n" -+#endif -+ "}\n" -+ : 
"=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) -+ : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w)); -+ } -+}; -+ -+template -+struct global_load { -+ CUTLASS_DEVICE -+ global_load(AccessType &D, void const *ptr, bool pred_guard) { -+ uint2 &data = reinterpret_cast(D); -+ -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %3, 0;\n" -+ " mov.b32 %0, %4;\n" -+ " mov.b32 %1, %5;\n" -+#if CUTLASS_ENABLE_L2_PREFETCH -+ " @p ld.global.L2::128B.v2.u32 {%0, %1}, [%2];\n" -+#else -+ " @p ld.global.v2.u32 {%0, %1}, [%2];\n" -+#endif -+ "}\n" -+ : "=r"(data.x), "=r"(data.y) -+ : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y)); -+ } -+}; -+ -+template -+struct global_load { -+ CUTLASS_DEVICE -+ global_load(AccessType &D, void const *ptr, bool pred_guard) { -+ unsigned &data = reinterpret_cast(D); -+ -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %2, 0;\n" -+ " mov.b32 %0, %3;\n" -+#if CUTLASS_ENABLE_L2_PREFETCH -+ " @p ld.global.L2::128B.u32 %0, [%1];\n" -+#else -+ " @p ld.global.u32 %0, [%1];\n" -+#endif -+ "}\n" -+ : "=r"(data) -+ : "l"(ptr), "r"((int)pred_guard), "r"(data)); -+ } -+}; -+ -+template -+struct global_load { -+ CUTLASS_DEVICE -+ global_load(AccessType &D, void const *ptr, bool pred_guard) { -+ uint16_t &data = reinterpret_cast(D); -+ -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %2, 0;\n" -+ " mov.b16 %0, %3;\n" -+#if CUTLASS_ENABLE_L2_PREFETCH -+ " @p ld.global.L2::128B.u16 %0, [%1];\n" -+#else -+ " @p ld.global.u16 %0, [%1];\n" -+#endif -+ "}\n" -+ : "=h"(data) -+ : "l"(ptr), "r"((int)pred_guard), "h"(data)); -+ } -+}; -+ -+template -+struct global_load { -+ CUTLASS_DEVICE -+ global_load(AccessType &D, void const *ptr, bool pred_guard) { -+ if (pred_guard) D = *(reinterpret_cast(ptr)); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Fragment type to store data -+ typename AccessType, -+ /// The bytes of storing -+ int StoreBytes -+ > -+struct global_store; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Specializations -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+template -+struct global_store { -+ CUTLASS_DEVICE -+ global_store(AccessType const &D, void *ptr, bool pred_guard) { -+ uint4 const *data = reinterpret_cast(&D); -+ -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %5, 0;\n" -+ " @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n" -+ " @p st.global.v4.u32 [%6], {%7, %8, %9, %10};\n" -+ " @p st.global.v4.u32 [%11], {%12, %13, %14, %15};\n" -+ " @p st.global.v4.u32 [%16], {%17, %18, %19, %20};\n" -+ "}\n" -+ : -+ : "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), -+ "r"(data[0].w), "r"((int)pred_guard), "l"(((uint8_t *)ptr) + 16), -+ "r"(data[1].x), "r"(data[1].y), "r"(data[1].z), "r"(data[1].w), -+ "l"(((uint8_t *)ptr) + 32), -+ "r"(data[2].x), "r"(data[2].y), "r"(data[2].z), "r"(data[2].w), -+ "l"(((uint8_t *)ptr) + 48), -+ "r"(data[3].x), "r"(data[3].y), "r"(data[3].z), "r"(data[3].w)); -+ } -+}; -+ -+ -+template -+struct global_store { -+ CUTLASS_DEVICE -+ global_store(AccessType const &D, void *ptr, bool pred_guard) { -+ uint4 const *data = reinterpret_cast(&D); -+ -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %5, 0;\n" -+ " @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n" -+ " @p st.global.v4.u32 
[%6], {%7, %8, %9, %10};\n" -+ "}\n" -+ : -+ : "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), -+ "r"(data[0].w), "r"((int)pred_guard), "l"(((uint8_t *)ptr) + 16), -+ "r"(data[1].x), "r"(data[1].y), "r"(data[1].z), "r"(data[1].w)); -+ } -+}; -+ -+template -+struct global_store { -+ CUTLASS_DEVICE -+ global_store(AccessType const &D, void *ptr, bool pred_guard) { -+ uint4 const &data = reinterpret_cast(D); -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %5, 0;\n" -+ " @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n" -+ "}\n" -+ : -+ : "l"(ptr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w), "r"((int)pred_guard)); -+ } -+}; -+ -+template -+struct global_store { -+ CUTLASS_DEVICE -+ global_store(AccessType const &D, void *ptr, bool pred_guard) { -+ uint2 const &data = reinterpret_cast(D); -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %3, 0;\n" -+ " @p st.global.v2.u32 [%0], {%1, %2};\n" -+ "}\n" -+ : -+ : "l"(ptr), "r"(data.x), "r"(data.y), "r"((int)pred_guard)); -+ } -+}; -+ -+template -+struct global_store { -+ CUTLASS_DEVICE -+ global_store(AccessType const &D, void *ptr, bool pred_guard) { -+ uint32_t const &data = reinterpret_cast(D); -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %2, 0;\n" -+ " @p st.global.u32 [%0], %1;\n" -+ "}\n" -+ : -+ : "l"(ptr), "r"(data), "r"((int)pred_guard)); -+ } -+}; -+ -+template -+struct global_store { -+ CUTLASS_DEVICE -+ global_store(AccessType const &D, void *ptr, bool pred_guard) { -+ uint16_t const &data = reinterpret_cast(D); -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %2, 0;\n" -+ " @p st.global.u16 [%0], %1;\n" -+ "}\n" -+ : -+ : "l"(ptr), "h"(data), "r"((int)pred_guard)); -+ } -+}; -+ -+template -+struct global_store { -+ CUTLASS_DEVICE -+ global_store(AccessType const &D, void *ptr, bool pred_guard) { -+ if (pred_guard) *(reinterpret_cast(ptr)) = D; -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// ld.shared -+template -+CUTLASS_DEVICE -+void shared_load(void *dst, uint32_t ptr); -+ -+/// ld.shared - 16b -+template <> -+CUTLASS_DEVICE -+void shared_load<2>(void *dst, uint32_t ptr) { -+ asm volatile("ld.shared.u16 %0, [%1];\n" -+ : "=h"(*reinterpret_cast(dst)) -+ : "r"(ptr)); -+} -+ -+/// ld.shared - 32b -+template <> -+CUTLASS_DEVICE -+void shared_load<4>(void *dst, uint32_t ptr) { -+ asm volatile("ld.shared.u32 %0, [%1];\n" -+ : "=r"(*reinterpret_cast(dst)) -+ : "r"(ptr)); -+} -+ -+/// ld.shared - 64b -+template <> -+CUTLASS_DEVICE -+void shared_load<8>(void *dst, uint32_t ptr) { -+ uint2 *dst_u64 = reinterpret_cast(dst); -+ asm volatile("ld.shared.v2.u32 {%0, %1}, [%2];\n" -+ : -+ "=r"(dst_u64->x), -+ "=r"(dst_u64->y) -+ : "r"(ptr)); -+} -+ -+/// ld.shared - 128b -+template <> -+CUTLASS_DEVICE -+void shared_load<16>(void *dst, uint32_t ptr) { -+ uint4 *dst_u128 = reinterpret_cast(dst); -+ asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];\n" -+ : -+ "=r"(dst_u128->x), -+ "=r"(dst_u128->y), -+ "=r"(dst_u128->z), -+ "=r"(dst_u128->w) -+ : "r"(ptr)); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// st.shared -+template -+CUTLASS_DEVICE -+void shared_store(uint32_t ptr, void const *src); -+ -+/// st.shared - 16b -+template <> -+CUTLASS_DEVICE -+void shared_store<2>(uint32_t ptr, void const *src) { -+ asm volatile("st.shared.u16 [%0], %1;\n" -+ : : -+ "r"(ptr), -+ "h"(*reinterpret_cast(src)) -+ ); -+} -+ 
-+/// st.shared - 32b -+template <> -+CUTLASS_DEVICE -+void shared_store<4>(uint32_t ptr, void const *src) { -+ asm volatile("st.shared.u32 [%0], %1;\n" -+ : : -+ "r"(ptr), -+ "r"(*reinterpret_cast(src)) -+ ); -+} -+ -+/// st.shared - 64b -+template <> -+CUTLASS_DEVICE -+void shared_store<8>(uint32_t ptr, void const *src) { -+ uint2 const *dst_u64 = reinterpret_cast(src); -+ asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n" -+ : : -+ "r"(ptr), -+ "r"(dst_u64->x), -+ "r"(dst_u64->y) -+ ); -+} -+ -+/// st.shared - 128b -+template <> -+CUTLASS_DEVICE -+void shared_store<16>(uint32_t ptr, void const *src) { -+ uint4 const *dst_u128 = reinterpret_cast(src); -+ asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n" -+ : : -+ "r"(ptr), -+ "r"(dst_u128->x), -+ "r"(dst_u128->y), -+ "r"(dst_u128->z), -+ "r"(dst_u128->w) -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "memory_sm75.h" -+#include "memory_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/arch/memory_sm75.h b/3rdparty/cutlass/include/cutlass/arch/memory_sm75.h -new file mode 100644 -index 0000000..ba59364 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/memory_sm75.h -@@ -0,0 +1,279 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Architecture-specific operators on memory added for SM75 -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cute/arch/util.hpp" -+ -+namespace cutlass { -+namespace arch { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Layout of destination matrix (column-major implies transpose) -+ typename Layout, -+ /// .x1, .x2, or .x4 -+ int MatrixCount -+> -+inline __device__ void ldsm(Array & D, void const* ptr); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Determine the appropriate way to target PTX's "ldmatrix" instruction. -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) || (__CUDACC_VER_MAJOR__ >= 11) -+ -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) -+#define CUDA_LDMATRIX_ACTIVATED 1 -+#endif -+ -+#define CUDA_LDMATRIX_SUPPORTED 1 -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// CUTLASS helper to get SMEM pointer -+inline __device__ unsigned cutlass_get_smem_pointer(void *ptr) { -+ return cute::cast_smem_ptr_to_uint(ptr); -+} -+ -+/// CUTLASS helper to get SMEM pointer -+inline __device__ unsigned cutlass_get_smem_pointer(void const *ptr) { -+ return cutlass_get_smem_pointer(const_cast(ptr)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+inline __device__ void ldsm( -+ Array & D, -+ void const* ptr) { -+ -+ #if defined(CUDA_LDMATRIX_ACTIVATED) -+ -+ unsigned addr = cutlass_get_smem_pointer(ptr); -+ -+ int x; -+ asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];" : "=r"(x) : "r"(addr)); -+ reinterpret_cast(D) = x; -+ -+ #else -+ -+ CUTLASS_UNUSED(D); -+ CUTLASS_UNUSED(ptr); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+ #endif -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+inline __device__ void ldsm( -+ Array & D, -+ void const* ptr) { -+ -+ #if defined(CUDA_LDMATRIX_ACTIVATED) -+ -+ unsigned addr = cutlass_get_smem_pointer(ptr); -+ -+ int x, y; -+ asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];" : "=r"(x), "=r"(y) : "r"(addr)); -+ reinterpret_cast(D) = make_int2(x, y); -+ -+ #else -+ -+ CUTLASS_UNUSED(D); -+ CUTLASS_UNUSED(ptr); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+ #endif -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+inline __device__ void ldsm( -+ Array & D, -+ void const* ptr) { -+ -+ #if defined(CUDA_LDMATRIX_ACTIVATED) -+ -+ unsigned addr = cutlass_get_smem_pointer(ptr); -+ -+ int x, y, z, w; -+ asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];" : "=r"(x), "=r"(y), "=r"(z), "=r"(w) : "r"(addr)); -+ reinterpret_cast(D) = make_int4(x, y, z, w); -+ -+ #else -+ -+ CUTLASS_UNUSED(D); -+ CUTLASS_UNUSED(ptr); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+ #endif -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Transpose on 16b granularity -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+inline __device__ void ldsm( -+ Array & D, -+ void const* ptr) { -+ -+ #if 
CUDA_LDMATRIX_ACTIVATED -+ -+ unsigned addr = cutlass_get_smem_pointer(ptr); -+ -+ int x; -+ asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];" : "=r"(x) : "r"(addr)); -+ reinterpret_cast(D) = x; -+ -+ #else -+ -+ CUTLASS_UNUSED(D); -+ CUTLASS_UNUSED(ptr); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+ #endif -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+inline __device__ void ldsm( -+ Array & D, -+ void const* ptr) { -+ -+ #if defined(CUDA_LDMATRIX_ACTIVATED) -+ -+ unsigned addr = cutlass_get_smem_pointer(ptr); -+ -+ int x, y; -+ asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];" : "=r"(x), "=r"(y) : "r"(addr)); -+ reinterpret_cast(D) = make_int2(x, y); -+ -+ #else -+ -+ CUTLASS_UNUSED(D); -+ CUTLASS_UNUSED(ptr); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+ #endif -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+inline __device__ void ldsm( -+ Array & D, -+ void const* ptr) { -+ -+ #if defined(CUDA_LDMATRIX_ACTIVATED) -+ -+ unsigned addr = cutlass_get_smem_pointer(ptr); -+ -+ int x, y, z, w; -+ asm volatile ("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];" : "=r"(x), "=r"(y), "=r"(z), "=r"(w) : "r"(addr)); -+ reinterpret_cast(D) = make_int4(x, y, z, w); -+ -+ #else -+ -+ CUTLASS_UNUSED(D); -+ CUTLASS_UNUSED(ptr); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+ #endif -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct shared_load_op { -+ CUTLASS_DEVICE -+ shared_load_op(AccessType &D, void const *ptr) { -+ D = *reinterpret_cast(ptr); -+ } -+}; -+ -+template -+CUTLASS_DEVICE void shared_load(AccessType &D, void const *ptr) { -+ shared_load_op(D, ptr); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct shared_load_op { -+ CUTLASS_DEVICE -+ shared_load_op(AccessType &D, void const *ptr) { -+ unsigned addr = cutlass_get_smem_pointer(ptr); -+ -+ uint4 v; -+ asm volatile ("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];" : -+ "=r"(v.x), "=r"(v.y), "=r"(v.z), "=r"(v.w) : "r"(addr)); -+ -+ D = reinterpret_cast(v); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct shared_load_op { -+ CUTLASS_DEVICE -+ shared_load_op(AccessType &D, void const *ptr) { -+ unsigned addr = cutlass_get_smem_pointer(ptr); -+ -+ uint2 v; -+ asm volatile ("ld.shared.v2.b32 {%0, %1}, [%2];" : -+ "=r"(v.x), "=r"(v.y) : "r"(addr)); -+ -+ D = reinterpret_cast(v); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/arch/memory_sm80.h b/3rdparty/cutlass/include/cutlass/arch/memory_sm80.h -new file mode 100644 -index 0000000..04bab1d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/memory_sm80.h -@@ -0,0 +1,466 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Architecture-specific operators on memory added for SM80 -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/arch/cache_operation.h" -+ -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ #define CUDA_CP_ASYNC_ACTIVATED 1 -+#else -+ #define CUDA_CP_ASYNC_ACTIVATED 0 -+#endif -+ -+namespace cutlass { -+namespace arch { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Initiates an asynchronous copy from global memory to shared memory. -+/// -+/// cp.async -+/// -+template < -+ /// Size of the access in bytes -+ int SizeInBytes, -+ /// Cache operation -+ CacheOperation::Kind cache_op = CacheOperation::Always> -+struct cp_async; -+ -+/// Initiates an asynchronous copy from global memory to shared memory. Rather than predicate -+/// the entire transfer, zeros are written to SMEM if the guard predicate is false. -+/// -+/// cp.async -+/// -+template < -+ /// Size of the access in bytes -+ int SizeInBytes, -+ /// Cache operation -+ CacheOperation::Kind cache_op = CacheOperation::Always> -+struct cp_async_zfill; -+ -+/// Initiates an asynchronous copy from global memory to shared memory. Rather than predicate -+/// the entire transfer, nans (0x7eff) are written to SMEM if the guard predicate is false. 
-+/// -+/// cp.async -+/// -+template < -+ /// Size of the access in bytes -+ int SizeInBytes, -+ /// Cache operation -+ CacheOperation::Kind cache_op = CacheOperation::Always> -+struct cp_async_nan; -+ -+/// Either 0 or 1 are written to SMEM based on input element type -+/// Used for diagonal elements of triangular matrix of BLAS3 functions -+/// -+/// st.shared -+/// -+template < -+ /// Type of Element -+ typename Element, -+ /// If the data is for a Hermitian matrix diagonal -+ bool IsHermitianData = false> -+struct cp_async_diag; -+ -+static const uint32_t OOB_NAN_F16 = 0x7eff; -+static const uint32_t OOB_NAN_F16x2 = ((OOB_NAN_F16 << 16) | OOB_NAN_F16); -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization -+template < -+ /// Size of the access in bytes -+ int SizeInBytes> -+struct cp_async { -+ -+ /// Copy -+ CUTLASS_DEVICE -+ cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { -+ #if CUDA_CP_ASYNC_ACTIVATED -+ -+ // Make sure the size is supported. -+ static_assert((SizeInBytes == 4 || SizeInBytes == 8 || SizeInBytes == 16), -+ "Size is not supported"); -+ -+ unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); -+ -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %0, 0;\n" -+#if CUTLASS_ENABLE_L2_PREFETCH -+ " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;\n" -+#else -+ " @p cp.async.ca.shared.global [%1], [%2], %3;\n" -+#endif -+ "}\n" ::"r"((int)pred_guard), -+ "r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes)); -+ -+ #else -+ using AccessType = Array; -+ -+ if (pred_guard) { -+ *static_cast(smem_ptr) = *static_cast(global_ptr); -+ } -+ #endif -+ } -+}; -+ -+/// Partial specialization -+template < -+ /// Size of the access in bytes -+ int SizeInBytes> -+struct cp_async_zfill { -+ -+ /// Copy with zero fill -+ CUTLASS_DEVICE -+ cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) { -+ #if CUDA_CP_ASYNC_ACTIVATED -+ -+ // Make sure the size is supported. -+ static_assert((SizeInBytes == 4 || SizeInBytes == 8 || SizeInBytes == 16), -+ "Size is not supported"); -+ -+ unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); -+ int src_in_bytes = (pred_guard ? 
SizeInBytes : 0); -+ -+ asm volatile( -+#if CUTLASS_ENABLE_L2_PREFETCH -+ "cp.async.ca.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), -+#else -+ "cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), -+#endif -+ "l"(global_ptr), "n"(SizeInBytes), "r"(src_in_bytes)); -+ -+ #else -+ using AccessType = Array; -+ -+ if (pred_guard) { -+ *static_cast(smem_ptr) = *static_cast(global_ptr); -+ } -+ else { -+ AccessType zeros; -+ zeros.clear(); -+ *static_cast(smem_ptr) = zeros; -+ } -+ #endif -+ } -+}; -+ -+/// Partial specialization -+template <> -+struct cp_async_nan<16, CacheOperation::Always> { -+ static int const kSizeInBytes = 16; -+ -+ /// Copy with nan fill -+ CUTLASS_DEVICE -+ cp_async_nan(void *smem_ptr, void const *global_ptr, bool pred_guard) { -+ #if CUDA_CP_ASYNC_ACTIVATED -+ -+ static __constant__ uint4 OOB_NAN_F16x8 = {OOB_NAN_F16x2, OOB_NAN_F16x2, -+ OOB_NAN_F16x2, OOB_NAN_F16x2}; -+ -+ unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); -+ -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %0, 0;\n" -+#if CUTLASS_ENABLE_L2_PREFETCH -+ " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;\n" -+#else -+ " @p cp.async.ca.shared.global [%1], [%2], %3;\n" -+#endif -+ " @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};\n" -+ "}\n" -+ : -+ : "r"((int)pred_guard), "r"(smem_int_ptr), "l"(global_ptr), -+ "n"(kSizeInBytes), "r"(OOB_NAN_F16x8.x), "r"(OOB_NAN_F16x8.y), "r"(OOB_NAN_F16x8.z), -+ "r"(OOB_NAN_F16x8.w)); -+ -+ #else -+ -+ CUTLASS_UNUSED(smem_ptr); -+ CUTLASS_UNUSED(global_ptr); -+ CUTLASS_UNUSED(pred_guard); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+ #endif -+ } -+}; -+ -+/// Partial specialization to write one (1) -+template -+struct cp_async_diag { -+ using Element = Element_; -+ -+ CUTLASS_DEVICE -+ cp_async_diag(void *smem_ptr) { -+ #if CUDA_CP_ASYNC_ACTIVATED -+ -+ /// Values for the diagonal elements of the triangular input matrix -+ static __constant__ uint2 DIAG_DATA_DOUBLE_ONE = {0x3ff00000, 0x00000000}; -+ static __constant__ uint1 DIAG_DATA_FLOAT_ONE = {0x3f800000}; -+ static __constant__ uint1 DIAG_DATA_ZERO = {0x00000000}; -+ -+ unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); -+ -+ if (platform::is_same>::value) { -+ asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n" -+ : : -+ "r"(smem_int_ptr), "r"(DIAG_DATA_DOUBLE_ONE.y), "r"(DIAG_DATA_DOUBLE_ONE.x), -+ "r"(DIAG_DATA_ZERO.x), "r"(DIAG_DATA_ZERO.x)); -+ } else if (platform::is_same>::value) { -+ asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n" -+ : : -+ "r"(smem_int_ptr), "r"(DIAG_DATA_FLOAT_ONE.x), "r"(DIAG_DATA_ZERO.x)); -+ } else if (platform::is_same::value) { -+ asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n" -+ : : -+ "r"(smem_int_ptr), "r"(DIAG_DATA_DOUBLE_ONE.y),"r"(DIAG_DATA_DOUBLE_ONE.x)); -+ } else if (platform::is_same::value) { -+ asm volatile("st.shared.u32 [%0], %1;\n" -+ : : -+ "r"(smem_int_ptr), "r"(DIAG_DATA_FLOAT_ONE.x)); -+ } else { -+ CUTLASS_UNUSED(smem_int_ptr); -+ CUTLASS_NOT_IMPLEMENTED(); -+ } -+ -+ #else -+ -+ CUTLASS_UNUSED(smem_ptr); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+ #endif -+ } -+}; -+ -+/// Partial specialization to write zero for the imaginary part of Hermitian data -+template -+struct cp_async_diag { -+ using Element = Element_; -+ -+ CUTLASS_DEVICE -+ cp_async_diag(void *smem_ptr) { -+ #if CUDA_CP_ASYNC_ACTIVATED -+ -+ /// Values for the diagonal elements of the triangular input matrix -+ static __constant__ uint1 DIAG_DATA_ZERO = {0x00000000}; -+ -+ unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); 
-+ -+ if (platform::is_same>::value) { -+ asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n" -+ : : -+ "r"(smem_int_ptr), "r"(DIAG_DATA_ZERO.x), "r"(DIAG_DATA_ZERO.x)); -+ } else if (platform::is_same>::value) { -+ asm volatile("st.shared.u32 [%0], %1;\n" -+ : : -+ "r"(smem_int_ptr), "r"(DIAG_DATA_ZERO.x)); -+ } else { -+ CUTLASS_UNUSED(smem_int_ptr); -+ CUTLASS_NOT_IMPLEMENTED(); -+ } -+ -+ #else -+ -+ CUTLASS_UNUSED(smem_ptr); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+ #endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization -+template < -+ /// Size of the access in bytes -+ int SizeInBytes> -+struct cp_async { -+ -+ /// Copy -+ CUTLASS_DEVICE -+ cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { -+ #if CUDA_CP_ASYNC_ACTIVATED -+ -+ static_assert(SizeInBytes == 16, -+ "cp.async only supports CacheOperation::Global when access size is 16B."); -+ -+ unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); -+ -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %0, 0;\n" -+#if CUTLASS_ENABLE_L2_PREFETCH -+ " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n" -+#else -+ " @p cp.async.cg.shared.global [%1], [%2], %3;\n" -+#endif -+ "}\n" ::"r"((int)pred_guard), -+ "r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes)); -+ -+ #else -+ using AccessType = Array; -+ -+ if (pred_guard) { -+ *static_cast(smem_ptr) = *static_cast(global_ptr); -+ } -+ #endif -+ } -+}; -+ -+/// Partial specialization -+template < -+ /// Size of the access in bytes -+ int SizeInBytes> -+struct cp_async_zfill { -+ -+ /// Copy with zero fill -+ CUTLASS_DEVICE -+ cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { -+ #if CUDA_CP_ASYNC_ACTIVATED -+ -+ static_assert(SizeInBytes == 16, -+ "cp.async only supports CacheOperation::Global when access size is 16B."); -+ -+ unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); -+ int src_in_bytes = (pred_guard ? 
SizeInBytes : 0); -+ -+ asm volatile( -+#if CUTLASS_ENABLE_L2_PREFETCH -+ "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), -+#else -+ "cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), -+#endif -+ "l"(global_ptr), "n"(SizeInBytes), "r"(src_in_bytes)); -+ -+ #else -+ using AccessType = Array; -+ -+ if (pred_guard) { -+ *static_cast(smem_ptr) = *static_cast(global_ptr); -+ } -+ else { -+ AccessType zeros; -+ zeros.clear(); -+ *static_cast(smem_ptr) = zeros; -+ } -+ #endif -+ } -+}; -+ -+/// Partial specialization -+template <> -+struct cp_async_nan<16, CacheOperation::Global> { -+ static int const kSizeInBytes = 16; -+ -+ /// Copy with nan fill -+ CUTLASS_DEVICE -+ cp_async_nan(void *smem_ptr, void const *global_ptr, bool pred_guard) { -+ #if CUDA_CP_ASYNC_ACTIVATED -+ -+ static __constant__ uint4 OOB_NAN_F16x8 = {OOB_NAN_F16x2, OOB_NAN_F16x2, -+ OOB_NAN_F16x2, OOB_NAN_F16x2}; -+ -+ unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); -+ -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %0, 0;\n" -+#if CUTLASS_ENABLE_L2_PREFETCH -+ " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n" -+#else -+ " @p cp.async.cg.shared.global [%1], [%2], %3;\n" -+#endif -+ " @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};\n" -+ "}\n" -+ : -+ : "r"((int)pred_guard), "r"(smem_int_ptr), "l"(global_ptr), -+ "n"(kSizeInBytes), "r"(OOB_NAN_F16x8.x), "r"(OOB_NAN_F16x8.y), "r"(OOB_NAN_F16x8.z), -+ "r"(OOB_NAN_F16x8.w)); -+ -+ #else -+ -+ CUTLASS_UNUSED(smem_ptr); -+ CUTLASS_UNUSED(global_ptr); -+ CUTLASS_UNUSED(pred_guard); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+ #endif -+ } -+}; -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. -+CUTLASS_DEVICE -+void cp_async_fence() { -+ #if CUDA_CP_ASYNC_ACTIVATED -+ asm volatile("cp.async.commit_group;\n" ::); -+ #endif -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Blocks until all but previous cp.async.commit_group operations have committed. -+template -+CUTLASS_DEVICE void cp_async_wait() { -+ #if CUDA_CP_ASYNC_ACTIVATED -+ asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); -+ #endif -+} -+ -+/// Blocks until all previous cp.async.commit_group operations have committed. -+template <> -+CUTLASS_DEVICE void cp_async_wait<0>() { -+ #if CUDA_CP_ASYNC_ACTIVATED -+ asm volatile("cp.async.wait_all;\n" ::); -+ #endif -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/arch/mma.h b/3rdparty/cutlass/include/cutlass/arch/mma.h -new file mode 100644 -index 0000000..7d4d693 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/mma.h -@@ -0,0 +1,227 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates exposing architecture support for multiply-add operations -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/arch/arch.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace arch { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tag indicating the operation implied by MMA. 
-+struct OpMultiplyAdd {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tag indicating the result is saturated to MAX_FLOAT|MIN_FLOAT or MAX_INT|MIN_INT -+struct OpMultiplyAddSaturate {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tag indicating the input is converted to a narrower type (BF16) -+struct OpMultiplyAddFastBF16 {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tag indicating the input is converted to a narrower type (F16) -+struct OpMultiplyAddFastF16 {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tag indicating the input is converted to 2 (big and small) TF32 components -+// Perform 3xTF32 or 4xTF32 for every F32 output element -+struct OpMultiplyAddFastF32 {}; -+ -+/// Tag indicating the input is converted to 2 (big and small) TF32 components -+// Perform 3xTF32 or 4xTF32 for every complex output element -+struct OpMultiplyAddComplexFastF32 {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tag indicating the complex multiply-add operation -+struct OpMultiplyAddComplex {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tag indicating the gaussian complex multiply-add operation -+struct OpMultiplyAddGaussianComplex {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tag indicating the inner product is defined by (XOR, POPC) -+struct OpXorPopc {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tag classifying math operators as thread-level operations. -+struct OpClassSimt {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tag classifing operators as Tensor Core operations. 
-+struct OpClassTensorOp {};
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+/// Tag classifing operators as WMMA Tensor Core operations
-+struct OpClassWmmaTensorOp {};
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+/// Matrix multiply-add operation
-+template <
-+ /// Size of the matrix product (concept: GemmShape)
-+ typename Shape_,
-+ /// Number of threads participating
-+ int kThreads_,
-+ /// Data type of A elements
-+ typename ElementA,
-+ /// Layout of A matrix (concept: MatrixLayout)
-+ typename LayoutA,
-+ /// Data type of B elements
-+ typename ElementB,
-+ /// Layout of B matrix (concept: MatrixLayout)
-+ typename LayoutB,
-+ /// Element type of C matrix
-+ typename ElementC,
-+ /// Layout of C matrix (concept: MatrixLayout)
-+ typename LayoutC,
-+ /// Inner product operator
-+ typename Operator
-+>
-+struct Mma;
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+/// Matrix multiply-add operation - specialized for 1x1x1x1 matrix multiply operation
-+template <
-+ /// Data type of A elements
-+ typename ElementA,
-+ /// Layout of A matrix (concept: MatrixLayout)
-+ typename LayoutA,
-+ /// Data type of B elements
-+ typename ElementB,
-+ /// Layout of B matrix (concept: MatrixLayout)
-+ typename LayoutB,
-+ /// Element type of C matrix
-+ typename ElementC_,
-+ /// Layout of C matrix (concept: MatrixLayout)
-+ typename LayoutC,
-+ /// Inner product operator
-+ typename Operator_
-+>
-+struct Mma<gemm::GemmShape<1, 1, 1>, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC_, LayoutC, Operator_> {
-+
-+ using Shape = gemm::GemmShape<1, 1, 1>;
-+ using Operator = Operator_;
-+ using ElementC = ElementC_;
-+
-+ CUTLASS_HOST_DEVICE
-+ void operator()(
-+ Array<ElementC, 1> &d,
-+ Array<ElementA, 1> const &a,
-+ Array<ElementB, 1> const &b,
-+ Array<ElementC, 1> const &c
-+ ) {
-+
-+ multiply_add<ElementA, ElementB, ElementC> op;
-+
-+ d[0] = op(a[0], b[0], c[0]);
-+ }
-+};
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+/// Specifies internal data type for computation
-+struct SPFormatType {
-+ enum Kind {
-+ Thread
-+ };
-+};
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+/// Matrix multiply-add operation
-+template <
-+ /// Size of the matrix product (concept: GemmShape)
-+ typename Shape_,
-+ /// Number of threads participating
-+ int kThreads_,
-+ /// Data type of A elements
-+ typename ElementA,
-+ /// Layout of A matrix (concept: MatrixLayout)
-+ typename LayoutA,
-+ /// Data type of B elements
-+ typename ElementB,
-+ /// Layout of B matrix (concept: MatrixLayout)
-+ typename LayoutB,
-+ /// Element type of C matrix
-+ typename ElementC,
-+ /// Layout of C matrix (concept: MatrixLayout)
-+ typename LayoutC,
-+ /// Inner product operator
-+ typename Operator,
-+ /// Specifies meta data format
-+ SPFormatType::Kind SPFormat = SPFormatType::Thread
-+>
-+struct SparseMma;
-+
-+} // namespace arch
-+} // namespace cutlass
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+//
-+// Specializations for each compute capability
-+//
-+
-+#include "cutlass/arch/mma_sm50.h"
-+#include "cutlass/arch/mma_sm60.h"
-+#include "cutlass/arch/mma_sm61.h"
-+#include "cutlass/arch/mma_sm70.h"
-+#include "cutlass/arch/mma_sm75.h"
-+#include
"cutlass/arch/mma_sm80.h" -+#include "cutlass/arch/mma_sparse_sm80.h" -+#include "cutlass/arch/mma_sm90.h" -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/arch/mma_sm50.h b/3rdparty/cutlass/include/cutlass/arch/mma_sm50.h -new file mode 100644 -index 0000000..8aca344 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/mma_sm50.h -@@ -0,0 +1,432 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Matrix multiply -+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/mma.h" -+#include "cutlass/complex.h" -+#include "cutlass/quaternion.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace arch { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template < -+ /// Layout of A matrix -+ typename LayoutA, -+ /// Layout of B matrix -+ typename LayoutB, -+ /// Layout of C matrix -+ typename LayoutC -+> -+struct Mma, 1, float, LayoutA, float, LayoutB, float, LayoutC, OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = OpMultiplyAdd; -+ using ElementC = float; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ d[0] = a[0] * b[0] + c[0]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template < -+ /// Layout of A matrix -+ typename LayoutA, -+ /// Layout of B matrix -+ typename LayoutB, -+ /// Layout of C matrix -+ typename LayoutC -+> -+struct Mma, 1, double, LayoutA, double, LayoutB, double, LayoutC, OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = OpMultiplyAdd; -+ using ElementC = double; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ -+ d[0] = a[0] * b[0] + c[0]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template < -+ /// Layout of A matrix -+ typename LayoutA, -+ /// Layout of B matrix -+ typename LayoutB, -+ /// Layout of C matrix -+ typename LayoutC -+> -+struct Mma, 1, int, LayoutA, int, LayoutB, int, LayoutC, OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = OpMultiplyAdd; -+ using ElementC = int; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ -+ d[0] = a[0] * b[0] + c[0]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template < -+ /// Layout of A matrix -+ typename LayoutA, -+ /// Layout of B matrix -+ typename LayoutB, -+ /// Layout of C matrix -+ typename LayoutC -+> -+struct Mma< -+ gemm::GemmShape<1, 1, 1>, -+ 1, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = OpMultiplyAddComplex; -+ using ElementC = complex; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array, 1> &d, -+ Array, 1> const &a, -+ Array, 1> const &b, -+ Array, 1> const &c -+ ) { -+ -+ d[0].real() = a[0].real() * b[0].real() + c[0].real(); -+ d[0].imag() = a[0].imag() * b[0].real() + c[0].imag(); -+ d[0].real() = -a[0].imag() * b[0].imag() + d[0].real(); -+ d[0].imag() = a[0].real() * b[0].imag() + d[0].imag(); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template < -+ /// Layout of A matrix -+ typename LayoutA, -+ /// Layout of B matrix -+ typename LayoutB, 
-+ /// Layout of C matrix -+ typename LayoutC -+> -+struct Mma< -+ gemm::GemmShape<1, 1, 1>, -+ 1, -+ complex, -+ LayoutA, -+ float, -+ LayoutB, -+ complex, -+ LayoutC, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = OpMultiplyAddComplex; -+ using ElementC = complex; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array, 1> &d, -+ Array, 1> const &a, -+ Array const &b, -+ Array, 1> const &c -+ ) { -+ -+ d[0].real() = a[0].real() * b[0] + c[0].real(); -+ d[0].imag() = a[0].imag() * b[0] + c[0].imag(); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template < -+ /// Layout of A matrix -+ typename LayoutA, -+ /// Layout of B matrix -+ typename LayoutB, -+ /// Layout of C matrix -+ typename LayoutC -+> -+struct Mma< -+ gemm::GemmShape<1, 1, 1>, -+ 1, -+ float, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = OpMultiplyAddComplex; -+ using ElementC = complex; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array, 1> &d, -+ Array const &a, -+ Array, 1> const &b, -+ Array, 1> const &c -+ ) { -+ -+ d[0].real() = a[0] * b[0].real() + c[0].real(); -+ d[0].imag() = a[0] * b[0].imag() + d[0].imag(); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template < -+ /// Layout of A matrix -+ typename LayoutA, -+ /// Layout of B matrix -+ typename LayoutB, -+ /// Layout of C matrix -+ typename LayoutC -+> -+struct Mma< -+ gemm::GemmShape<1, 1, 1>, -+ 1, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = OpMultiplyAddComplex; -+ using ElementC = complex; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array, 1> &d, -+ Array, 1> const &a, -+ Array, 1> const &b, -+ Array, 1> const &c -+ ) { -+ -+ d[0].real() = a[0].real() * b[0].real() + c[0].real(); -+ d[0].imag() = a[0].imag() * b[0].real() + c[0].imag(); -+ d[0].real() = -a[0].imag() * b[0].imag() + d[0].real(); -+ d[0].imag() = a[0].real() * b[0].imag() + d[0].imag(); -+ } -+}; -+ -+/// Matrix multiply-add operation -+template < -+ /// Layout of A matrix -+ typename LayoutA, -+ /// Layout of B matrix -+ typename LayoutB, -+ /// Layout of C matrix -+ typename LayoutC -+> -+struct Mma< -+ gemm::GemmShape<1, 1, 1>, -+ 1, -+ complex, -+ LayoutA, -+ double, -+ LayoutB, -+ complex, -+ LayoutC, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = OpMultiplyAddComplex; -+ using ElementC = complex; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array, 1> &d, -+ Array, 1> const &a, -+ Array const &b, -+ Array, 1> const &c -+ ) { -+ -+ d[0].real() = a[0].real() * b[0] + c[0].real(); -+ d[0].imag() = a[0].imag() * b[0] + c[0].imag(); -+ } -+}; -+ -+/// Matrix multiply-add operation -+template < -+ /// Layout of A matrix -+ typename LayoutA, -+ /// Layout of B matrix -+ typename LayoutB, -+ /// Layout of C matrix -+ typename LayoutC -+> -+struct Mma< -+ gemm::GemmShape<1, 1, 1>, -+ 1, -+ double, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = OpMultiplyAddComplex; -+ using ElementC = complex; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array, 1> &d, -+ Array const &a, -+ Array, 1> 
const &b, -+ Array, 1> const &c -+ ) { -+ -+ d[0].real() = a[0] * b[0].real() + c[0].real(); -+ d[0].imag() = a[0] * b[0].imag() + d[0].imag(); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template < -+ /// Layout of A matrix -+ typename LayoutA, -+ /// Layout of B matrix -+ typename LayoutB, -+ /// Layout of C matrix -+ typename LayoutC -+> -+struct Mma, 1, half_t, LayoutA, half_t, LayoutB, float, LayoutC, OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = OpMultiplyAdd; -+ using ElementC = float; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ d[0] = float(a[0]) * float(b[0]) + c[0]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation for Quaternions -+template < -+ /// Layout of A matrix -+ typename LayoutA, -+ /// Layout of B matrix -+ typename LayoutB, -+ /// Layout of C matrix -+ typename LayoutC -+> -+struct Mma, 1, Quaternion, LayoutA, Quaternion, LayoutB, Quaternion, LayoutC, OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = OpMultiplyAdd; -+ using Element = Quaternion; -+ using ElementC = Element; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ multiply_add op; -+ d[0] = op(a[0], b[0], c[0]); -+ } -+ -+}; -+ -+} -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/arch/mma_sm60.h b/3rdparty/cutlass/include/cutlass/arch/mma_sm60.h -new file mode 100644 -index 0000000..349c838 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/mma_sm60.h -@@ -0,0 +1,252 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Matrix multiply -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/arch/mma.h" -+ -+#include "cutlass/layout/matrix.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace arch { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template -+struct Mma< -+ gemm::GemmShape<2,1,1>, -+ 1, -+ half_t, -+ LayoutA, -+ half_t, -+ LayoutB, -+ half_t, -+ LayoutC, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<2, 1, 1>; -+ using Operator = OpMultiplyAdd; -+ using ElementC = half_t; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) -+ -+ __half2 const & A = reinterpret_cast<__half2 const &>(a); -+ __half2 B = __half2half2(reinterpret_cast<__half const &>(b)); -+ __half2 const & C = reinterpret_cast<__half2 const &>(c); -+ -+ __half2 D = __hfma2(A, B, C); -+ -+ d = reinterpret_cast &>(D); -+ -+#else -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 2; ++i) { -+ d[i] = a[i] * b[0] + c[i]; -+ } -+#endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template -+struct Mma< -+ gemm::GemmShape<1,2,1>, -+ 1, -+ half_t, -+ LayoutA, -+ half_t, -+ LayoutB, -+ half_t, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 2, 1>; -+ using Operator = OpMultiplyAdd; -+ using ElementC = half_t; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) -+ -+ __half2 const & A = __half2half2(reinterpret_cast<__half const &>(a)); -+ __half2 B = reinterpret_cast<__half2 const &>(b); -+ __half2 const & C = reinterpret_cast<__half2 const &>(c); -+ -+ __half2 D = __hfma2(A, B, C); -+ -+ d = reinterpret_cast &>(D); -+ -+#else -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 2; ++i) { -+ d[i] = a[0] * b[i] + c[i]; -+ } -+#endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template <> -+struct Mma < -+ gemm::GemmShape<2, 2, 1>, -+ 1, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<2, 2, 1>; -+ using Operator = OpMultiplyAdd; -+ using ElementC = half_t; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) -+ -+ __half2 const & A = 
reinterpret_cast<__half2 const &>(a); -+ __half2 Blo = __low2half2(reinterpret_cast<__half2 const &>(b)); -+ __half2 Bhi = __high2half2(reinterpret_cast<__half2 const &>(b)); -+ -+ __half2 const *C = reinterpret_cast<__half2 const *>(&c); -+ -+ __half2 Dlo = __hfma2(A, Blo, C[0]); -+ __half2 Dhi = __hfma2(A, Bhi, C[1]); -+ -+ Array * D = reinterpret_cast *>(&d); -+ -+ D[0] = reinterpret_cast const &>(Dlo); -+ D[1] = reinterpret_cast const &>(Dhi); -+ -+#else -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < 2; ++j) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 2; ++i) { -+ d[i + 2 * j] = a[i] * b[j] + c[i + 2 * j]; -+ } -+ } -+#endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template <> -+struct Mma< -+ gemm::GemmShape<2, 2, 1>, -+ 1, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<2, 2, 1>; -+ using Operator = OpMultiplyAdd; -+ using ElementC = half_t; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) -+ -+ __half2 Alo = __low2half2(reinterpret_cast<__half2 const &>(a)); -+ __half2 Ahi = __high2half2(reinterpret_cast<__half2 const &>(a)); -+ __half2 const & B = reinterpret_cast<__half2 const &>(b); -+ -+ __half2 const *C = reinterpret_cast<__half2 const *>(&c); -+ -+ __half2 Dlo = __hfma2(Alo, B, C[0]); -+ __half2 Dhi = __hfma2(Ahi, B, C[0]); -+ -+ Array * D = reinterpret_cast *>(&d); -+ -+ D[0] = reinterpret_cast &>(Dlo); -+ D[1] = reinterpret_cast &>(Dhi); -+#else -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 2; ++i) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < 2; ++j) { -+ d[i * 2 + j] = a[i] * b[j] + c[i * 2 + j]; -+ } -+ } -+#endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} -+} -diff --git a/3rdparty/cutlass/include/cutlass/arch/mma_sm61.h b/3rdparty/cutlass/include/cutlass/arch/mma_sm61.h -new file mode 100644 -index 0000000..a1af935 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/mma_sm61.h -@@ -0,0 +1,142 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*! \file
-+ \brief Matrix multiply
-+*/
-+
-+#pragma once
-+
-+#include "cutlass/layout/matrix.h"
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+namespace cutlass {
-+namespace arch {
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+/// Matrix multiply-add operation
-+template <typename LayoutA, typename LayoutB, typename LayoutC>
-+struct Mma<
-+ gemm::GemmShape<1,1,4>,
-+ 1,
-+ int8_t,
-+ LayoutA,
-+ int8_t,
-+ LayoutB,
-+ int,
-+ LayoutC,
-+ OpMultiplyAdd> {
-+
-+ using Shape = gemm::GemmShape<1, 1, 4>;
-+ using Operator = OpMultiplyAdd;
-+ using ElementC = int;
-+
-+ CUTLASS_HOST_DEVICE
-+ void operator()(
-+ Array<int, 1> &d,
-+ Array<int8_t, 4> const &a,
-+ Array<int8_t, 4> const &b,
-+ Array<int, 1> const &c
-+ ) {
-+
-+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610))
-+
-+ unsigned const &A = reinterpret_cast<unsigned const &>(a);
-+ unsigned const &B = reinterpret_cast<unsigned const &>(b);
-+
-+ asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"
-+ : "=r"(d[0])
-+ : "r"(A), "r"(B), "r"(c[0]));
-+
-+#else
-+
-+ d[0] = c[0];
-+
-+ CUTLASS_PRAGMA_UNROLL
-+ for (int k = 0; k < 4; ++k) {
-+ d[0] += a[k] * b[k];
-+ }
-+
-+#endif
-+ }
-+};
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+/// Matrix multiply-add operation
-+template <typename LayoutC>
-+struct Mma<
-+ gemm::GemmShape<1, 1, 2>,
-+ 1,
-+ int16_t,
-+ layout::RowMajor,
-+ int16_t,
-+ layout::ColumnMajor,
-+ int,
-+ LayoutC,
-+ OpMultiplyAdd> {
-+
-+ using Shape = gemm::GemmShape<1, 1, 2>;
-+ using Operator = OpMultiplyAdd;
-+ using ElementC = int;
-+
-+ CUTLASS_HOST_DEVICE
-+ void operator()(
-+ Array<int, 1> &d,
-+ Array<int16_t, 2> const &a,
-+ Array<int16_t, 2> const &b,
-+ Array<int, 1> const &c
-+ ) {
-+
-+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610))
-+
-+ unsigned const &A = reinterpret_cast<unsigned const &>(a);
-+ unsigned const &B = reinterpret_cast<unsigned const &>(b);
-+
-+ asm volatile("dp2a.s32.s32 %0, %1, %2, %3;"
-+ : "=r"(d[0])
-+ : "r"(A), "r"(B), "r"(c[0]));
-+#else
-+ d[0] = c[0];
-+
-+ CUTLASS_PRAGMA_UNROLL
-+ for (int k = 0; k < 2; ++k) {
-+ d[0] += a[k] * b[k];
-+ }
-+#endif
-+ }
-+};
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+}
-+}
-diff --git a/3rdparty/cutlass/include/cutlass/arch/mma_sm70.h b/3rdparty/cutlass/include/cutlass/arch/mma_sm70.h
-new file mode 100644
-index 0000000..9f93714
---- /dev/null
-+++ b/3rdparty/cutlass/include/cutlass/arch/mma_sm70.h
-@@ -0,0 +1,665 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1.
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Matrix multiply -+*/ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "mma.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1)) -+#define CUTLASS_ARCH_MMA_SM70_SUPPORTED -+#endif -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) -+ -+#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 &&__CUDACC_VER_MINOR__ >= 1)) -+#define CUTLASS_ARCH_MMA_SM70_ENABLED -+#endif -+ -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace arch { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix multiply accumulate 884 - FP16 accumulation -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F16 = F16 * F16 + F16 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,4>, -+ 8, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 4>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::ColumnMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = half_t; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm70; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) { -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) -+ -+ unsigned const *A = reinterpret_cast(&a); -+ unsigned const *B = reinterpret_cast(&b); -+ unsigned const *C = reinterpret_cast(&c); -+ unsigned *D = reinterpret_cast(&d); -+ -+ asm 
volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]) -+ ); -+ -+#else -+ assert(0); -+ #if defined(__CUDA_ARCH__) -+ asm volatile ("brkpt;\n" ::); -+ #endif -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: F16 = F16 * F16 + F16 -+template <> -+struct Mma< -+ gemm::GemmShape<8, 8, 4>, -+ 8, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 4>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::ColumnMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::RowMajor; -+ using FragmentB = Array; -+ -+ using ElementC = half_t; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm70; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) { -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) -+ -+ unsigned const *A = reinterpret_cast(&a); -+ unsigned const *B = reinterpret_cast(&b); -+ unsigned const *C = reinterpret_cast(&c); -+ unsigned *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]) -+ ); -+ -+#else -+ assert(0); -+ #if defined(__CUDA_ARCH__) -+ asm volatile ("brkpt;\n" ::); -+ #endif -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: F16 = F16 * F16 + F16 -+template <> -+struct Mma< -+ gemm::GemmShape<8, 8, 4>, -+ 8, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 4>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = half_t; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm70; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) { -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) -+ -+ unsigned const *A = reinterpret_cast(&a); -+ unsigned const *B = reinterpret_cast(&b); -+ unsigned const *C = reinterpret_cast(&c); -+ unsigned *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]) -+ ); -+ -+#else -+ assert(0); -+ #if defined(__CUDA_ARCH__) -+ asm volatile ("brkpt;\n" ::); -+ #endif -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: F16 = F16 * F16 + F16 -+template <> -+struct Mma< -+ gemm::GemmShape<8, 8, 4>, -+ 8, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 4>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA 
= Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::RowMajor; -+ using FragmentB = Array; -+ -+ using ElementC = half_t; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm70; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) { -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) -+ -+ unsigned const *A = reinterpret_cast(&a); -+ unsigned const *B = reinterpret_cast(&b); -+ unsigned const *C = reinterpret_cast(&c); -+ unsigned *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]) -+ ); -+ -+#else -+ assert(0); -+ #if defined(__CUDA_ARCH__) -+ asm volatile ("brkpt;\n" ::); -+ #endif -+#endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix multiply accumulate 884 - FP32 accumulation -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F32 = F16 * F16 + F32 -+template <> -+struct Mma< -+ gemm::GemmShape<8, 8, 4>, -+ 8, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::ColumnMajor, -+ float, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 4>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::ColumnMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm70; -+ -+ /// Multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) { -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) -+ -+ unsigned const *A = reinterpret_cast(&a); -+ unsigned const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " -+ "{%12,%13,%14,%15,%16,%17,%18,%19};\n" -+ : "=f"(D[0]), -+ "=f"(D[1]), -+ "=f"(D[2]), -+ "=f"(D[3]), -+ "=f"(D[4]), -+ "=f"(D[5]), -+ "=f"(D[6]), -+ "=f"(D[7]) -+ : "r"(A[0]), -+ "r"(A[1]), -+ "r"(B[0]), -+ "r"(B[1]), -+ "f"(C[0]), -+ "f"(C[1]), -+ "f"(C[2]), -+ "f"(C[3]), -+ "f"(C[4]), -+ "f"(C[5]), -+ "f"(C[6]), -+ "f"(C[7]) -+ ); -+ -+#else -+ assert(0); -+ #if defined(__CUDA_ARCH__) -+ asm volatile ("brkpt;\n" ::); -+ #endif -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: F32 = F16 * F16 + F32 -+template <> -+struct Mma< -+ gemm::GemmShape<8, 8, 4>, -+ 8, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ float, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 4>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::ColumnMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::RowMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm70; -+ -+ /// 
Multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) { -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) -+ -+ unsigned const *A = reinterpret_cast(&a); -+ unsigned const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " -+ "{%12,%13,%14,%15,%16,%17,%18,%19};\n" -+ : "=f"(D[0]), -+ "=f"(D[1]), -+ "=f"(D[2]), -+ "=f"(D[3]), -+ "=f"(D[4]), -+ "=f"(D[5]), -+ "=f"(D[6]), -+ "=f"(D[7]) -+ : "r"(A[0]), -+ "r"(A[1]), -+ "r"(B[0]), -+ "r"(B[1]), -+ "f"(C[0]), -+ "f"(C[1]), -+ "f"(C[2]), -+ "f"(C[3]), -+ "f"(C[4]), -+ "f"(C[5]), -+ "f"(C[6]), -+ "f"(C[7]) -+ ); -+ -+#else -+ assert(0); -+ #if defined(__CUDA_ARCH__) -+ asm volatile ("brkpt;\n" ::); -+ #endif -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: F32 = F16 * F16 + F32 -+template <> -+struct Mma< -+ gemm::GemmShape<8, 8, 4>, -+ 8, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ float, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 4>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm70; -+ -+ /// Multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) { -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) -+ -+ unsigned const *A = reinterpret_cast(&a); -+ unsigned const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " -+ "{%12,%13,%14,%15,%16,%17,%18,%19};\n" -+ : "=f"(D[0]), -+ "=f"(D[1]), -+ "=f"(D[2]), -+ "=f"(D[3]), -+ "=f"(D[4]), -+ "=f"(D[5]), -+ "=f"(D[6]), -+ "=f"(D[7]) -+ : "r"(A[0]), -+ "r"(A[1]), -+ "r"(B[0]), -+ "r"(B[1]), -+ "f"(C[0]), -+ "f"(C[1]), -+ "f"(C[2]), -+ "f"(C[3]), -+ "f"(C[4]), -+ "f"(C[5]), -+ "f"(C[6]), -+ "f"(C[7]) -+ ); -+ -+#else -+ assert(0); -+ #if defined(__CUDA_ARCH__) -+ asm volatile ("brkpt;\n" ::); -+ #endif -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: F32 = F16 * F16 + F32 -+template <> -+struct Mma< -+ gemm::GemmShape<8, 8, 4>, -+ 8, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::RowMajor, -+ float, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 4>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::RowMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm70; -+ -+ /// Multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) { -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) -+ -+ unsigned const *A = reinterpret_cast(&a); -+ unsigned const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); 
-+ -+ asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " -+ "{%12,%13,%14,%15,%16,%17,%18,%19};\n" -+ : "=f"(D[0]), -+ "=f"(D[1]), -+ "=f"(D[2]), -+ "=f"(D[3]), -+ "=f"(D[4]), -+ "=f"(D[5]), -+ "=f"(D[6]), -+ "=f"(D[7]) -+ : "r"(A[0]), -+ "r"(A[1]), -+ "r"(B[0]), -+ "r"(B[1]), -+ "f"(C[0]), -+ "f"(C[1]), -+ "f"(C[2]), -+ "f"(C[3]), -+ "f"(C[4]), -+ "f"(C[5]), -+ "f"(C[6]), -+ "f"(C[7]) -+ ); -+ -+#else -+ assert(0); -+ #if defined(__CUDA_ARCH__) -+ asm volatile ("brkpt;\n" ::); -+ #endif -+#endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation specialized for the entire warp -+template < -+ typename LayoutA, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename Operator -+> -+struct Mma< -+ gemm::GemmShape<16, 16, 4>, -+ 32, -+ half_t, -+ LayoutA, -+ half_t, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Operator -+> : -+ public Mma< -+ gemm::GemmShape<8, 8, 4>, -+ 8, -+ half_t, -+ LayoutA, -+ half_t, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Operator> { -+ -+ using Shape = gemm::GemmShape<16, 16, 4>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/arch/mma_sm75.h b/3rdparty/cutlass/include/cutlass/arch/mma_sm75.h -new file mode 100644 -index 0000000..1402e76 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/mma_sm75.h -@@ -0,0 +1,1301 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Matrix multiply for SM75 -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/arch/wmma.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+// CUDA Toolkit includes for nvcuda::wmma needed for binarized matrix multiply. -+#include -+#include "cutlass/wmma_array.h" -+#endif -+ -+// CUTLASS includes -+#include "cutlass/arch/mma.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) -+ -+#define CUTLASS_ARCH_MMA_SM75_SUPPORTED 1 -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) -+#define CUTLASS_ARCH_MMA_SM75_ENABLED -+#endif -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace arch { -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 1688 - FP16 accumulation -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation - F16 = F16 * F16 + F16 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 8>, -+ 32, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16, 8, 8>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = half_t; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm75; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const *A = reinterpret_cast(&a); -+ unsigned const *B = reinterpret_cast(&b); -+ unsigned const *C = reinterpret_cast(&c); -+ unsigned *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0,%1}, {%2,%3}, {%4}, {%5,%6};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 1688 - FP32 accumulation -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F32 = F16 * F16 + F32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 8>, -+ 32, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ float, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16, 8, 8>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, -+ 
FragmentC const &c) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const *A = reinterpret_cast(&a); -+ unsigned const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" -+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) -+ : -+ "r"(A[0]), "r"(A[1]), -+ "r"(B[0]), -+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) -+ ); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Integer matrix multiply .8816 (8b) -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8, 8, 16>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 16>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8, 8, 16>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 16>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// 
Matrix multiply-add operation: S32 = S8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8, 8, 16>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 16>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k16.row.col.s8.u8 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8, 8, 16>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 16>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Integer matrix multiply (8b) with SATURATE -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,16>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<8,8,16>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = 
arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,16>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<8,8,16>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,16>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<8,8,16>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -+template <> -+struct Mma< -+ 
gemm::GemmShape<8,8,16>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<8,8,16>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Integer matrix multiply (4b) -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S4 * S4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,32>, -+ 32, -+ int4b_t, -+ layout::RowMajor, -+ int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8,8,32>; -+ -+ using ElementA = int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,32>, -+ 32, -+ uint4b_t, -+ layout::RowMajor, -+ int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8,8,32>; -+ -+ using ElementA = uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ 
FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,32>, -+ 32, -+ int4b_t, -+ layout::RowMajor, -+ uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8,8,32>; -+ -+ using ElementA = int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * U4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,32>, -+ 32, -+ uint4b_t, -+ layout::RowMajor, -+ uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8,8,32>; -+ -+ using ElementA = uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Integer matrix multiply (4b) - SATURATE -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix 
multiply-add operation: S32 = S4 * S4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,32>, -+ 32, -+ int4b_t, -+ layout::RowMajor, -+ int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<8,8,32>; -+ -+ using ElementA = int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,32>, -+ 32, -+ uint4b_t, -+ layout::RowMajor, -+ int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<8,8,32>; -+ -+ using ElementA = uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,32>, -+ 32, -+ int4b_t, -+ layout::RowMajor, -+ uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<8,8,32>; -+ -+ using ElementA = int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { 
-+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * U4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,32>, -+ 32, -+ uint4b_t, -+ layout::RowMajor, -+ uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<8,8,32>; -+ -+ using ElementA = uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// b1 ^ b1 + s32 => s32 -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,128>, -+ 32, -+ uint1b_t, -+ layout::RowMajor, -+ uint1b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpXorPopc> { -+ -+ using Shape = gemm::GemmShape<8,8,128>; -+ -+ using ElementA = uint1b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint1b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpXorPopc; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+#if (__CUDA_ARCH__ >= 900) || (defined(CUTLASS_ARCH_WMMA_ENABLED)) -+ using WmmaFragmentA = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_a, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ nvcuda::wmma::experimental::precision::b1, -+ nvcuda::wmma::row_major>; -+ -+ using WmmaFragmentB = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_b, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ nvcuda::wmma::experimental::precision::b1, -+ nvcuda::wmma::col_major>; -+ -+ using WmmaFragmentC = nvcuda::wmma::fragment< -+ nvcuda::wmma::accumulator, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ int>; -+ -+ 
WmmaFragmentA const & A = reinterpret_cast(a); -+ WmmaFragmentB const & B = reinterpret_cast(b); -+ -+ WmmaFragmentC const & C = reinterpret_cast(c); -+ WmmaFragmentC & D = reinterpret_cast(d); -+ -+ nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR, -+ nvcuda::wmma::experimental::bmmaAccumulateOpPOPC); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); // WMMA must be supported to issue binary matrix multiply-accumulate instructions. -+ -+#endif // defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/arch/mma_sm80.h b/3rdparty/cutlass/include/cutlass/arch/mma_sm80.h -new file mode 100644 -index 0000000..8682ae1 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/mma_sm80.h -@@ -0,0 +1,2185 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file
-+    \brief Matrix multiply
-+*/
-+
-+#pragma once
-+
-+#if defined(__CUDACC_RTC__)
-+#include <cuda/std/cassert>
-+#else
-+#include <assert.h>
-+#endif
-+
-+#include "cutlass/cutlass.h"
-+#include "mma.h"
-+#include "cutlass/layout/matrix.h"
-+#include "cutlass/numeric_types.h"
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
-+
-+#define CUTLASS_ARCH_MMA_SM80_SUPPORTED 1
-+
-+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
-+#define CUTLASS_ARCH_MMA_SM80_ENABLED
-+#endif
-+#endif
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+namespace cutlass {
-+namespace arch {
-+
-+////////////////////////////////////////////////////////////////////////////////
-+//
-+// Matrix Multiply 1688 - Float BF16, FP32 accumulation
-+//
-+////////////////////////////////////////////////////////////////////////////////
-+
-+/// Matrix multiply-add operation - F32 = bf16 * bf16 + F32
-+template <>
-+struct Mma<
-+  gemm::GemmShape<16, 8, 8>,
-+  32,
-+  bfloat16_t,
-+  layout::RowMajor,
-+  bfloat16_t,
-+  layout::ColumnMajor,
-+  float,
-+  layout::RowMajor,
-+  OpMultiplyAdd> {
-+
-+  using Shape = gemm::GemmShape<16, 8, 8>;
-+
-+  using ElementA = bfloat16_t;
-+  using LayoutA = layout::RowMajor;
-+  using FragmentA = Array<bfloat16_t, 4>;
-+
-+  using ElementB = bfloat16_t;
-+  using LayoutB = layout::ColumnMajor;
-+  using FragmentB = Array<bfloat16_t, 2>;
-+
-+  using ElementC = float;
-+  using LayoutC = layout::RowMajor;
-+  using FragmentC = Array<float, 4>;
-+
-+  using Operator = OpMultiplyAdd;
-+  using ArchTag = arch::Sm80;
-+
-+  CUTLASS_HOST_DEVICE
-+  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
-+    FragmentC const &c) const {
-+
-+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
-+
-+    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
-+    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
-+    float const *C = reinterpret_cast<float const *>(&c);
-+    float *D = reinterpret_cast<float *>(&d);
-+
-+    asm(
-+      "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 "
-+      "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
-+      : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
-+      :
-+        "r"(A[0]), "r"(A[1]),
-+        "r"(B[0]),
-+        "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
-+    );
-+
-+#else
-+
-+    CUTLASS_UNUSED(d);
-+    CUTLASS_UNUSED(a);
-+    CUTLASS_UNUSED(b);
-+    CUTLASS_UNUSED(c);
-+    CUTLASS_NOT_IMPLEMENTED();
-+
-+#endif
-+  }
-+};
-+
-+////////////////////////////////////////////////////////////////////////////////
-+//
-+// Matrix Multiply 1684 - Float TF32
-+//
-+////////////////////////////////////////////////////////////////////////////////
-+
-+/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32
-+template <>
-+struct Mma<
-+  gemm::GemmShape<16, 8, 4>,
-+  32,
-+  tfloat32_t,
-+  layout::RowMajor,
-+  tfloat32_t,
-+  layout::ColumnMajor,
-+  float,
-+  layout::RowMajor,
-+  OpMultiplyAdd> {
-+
-+  using Shape = gemm::GemmShape<16, 8, 4>;
-+
-+  using ElementA = tfloat32_t;
-+  using LayoutA = layout::RowMajor;
-+  using FragmentA = Array<tfloat32_t, 2>;
-+
-+  using ElementB = tfloat32_t;
-+  using LayoutB = layout::ColumnMajor;
-+  using FragmentB = Array<tfloat32_t, 1>;
-+
-+  using ElementC = float;
-+  using LayoutC = layout::RowMajor;
-+  using FragmentC = Array<float, 4>;
-+
-+  using Operator = OpMultiplyAdd;
-+  using ArchTag = arch::Sm80;
-+
-+  CUTLASS_HOST_DEVICE
-+  void operator()(
-+    FragmentC &d,
-+    FragmentA const &a,
-+    FragmentB const &b,
-+    FragmentC const &c
-+  ) const {
-+
-+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
-+
-+    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
-+    uint32_t
const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" -+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) -+ : -+ "r"(A[0]), "r"(A[1]), -+ "r"(B[0]), -+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) -+ ); -+ -+#else -+ -+ CUTLASS_UNUSED(d); -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 1688 - Float TF32 -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32 -+template <> -+struct Mma, 32, tfloat32_t, layout::RowMajor, -+ tfloat32_t, layout::ColumnMajor, float, layout::RowMajor, -+ OpMultiplyAdd> { -+ using Shape = gemm::GemmShape<16, 8, 8>; -+ -+ using ElementA = tfloat32_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = tfloat32_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, -+ FragmentC const &c) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " -+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" -+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); -+ -+#else -+ -+ CUTLASS_UNUSED(d); -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 16816 -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F16 = F16 * F16 + F16 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 16>, -+ 32, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16, 8, 16>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = half_t; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, -+ FragmentC const &c) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ uint32_t const *C = reinterpret_cast(&c); -+ uint32_t *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" -+ : 
"=r"(D[0]), "=r"(D[1]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), -+ "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]) -+ ); -+ -+#else -+ -+ CUTLASS_UNUSED(d); -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F32 = bf16 * bf16 + F32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 16>, -+ 32, -+ bfloat16_t, -+ layout::RowMajor, -+ bfloat16_t, -+ layout::ColumnMajor, -+ float, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16, 8, 16>; -+ -+ using ElementA = bfloat16_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = bfloat16_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " -+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" -+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); -+ -+#else -+ -+ CUTLASS_UNUSED(d); -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F32 = F16 * F16 + F32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 16>, -+ 32, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ float, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16, 8, 16>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " -+ "{%10,%11,%12,%13};\n" -+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); -+ -+#else -+ -+ CUTLASS_UNUSED(d); -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+#endif -+ } -+}; -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 884 - F64 -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F64 = F64 * F64 + F64 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,4>, -+ 32, -+ double, -+ layout::RowMajor, -+ double, -+ layout::ColumnMajor, -+ double, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8,8,4>; -+ -+ using ElementA = double; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = double; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = double; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ -+ using ArchTag = arch::Sm80; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, -+ FragmentC const &c) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ double const & A = reinterpret_cast(a); -+ double const & B = reinterpret_cast(b); -+ -+ double const *C = reinterpret_cast(&c); -+ double *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=d"(D[0]), "=d"(D[1]) -+ : "d"(A), "d"(B), "d"(C[0]), "d"(C[1])); -+ -+#else -+ -+ CUTLASS_UNUSED(d); -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 16816 - S8 input, S32 accumulation -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,16>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16,8,16>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const &B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " -+ "{%7,%8,%9,%10};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), -+ "r"(C[3])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,16>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16,8,16>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ 
using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const &B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " -+ "{%7,%8,%9,%10};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), -+ "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,16>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16,8,16>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const &B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " -+ "{%7,%8,%9,%10};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), -+ "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,16>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16,8,16>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const &B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " -+ "{%7,%8,%9,%10};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), 
"r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), -+ "r"(C[3])); -+ -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 16816 - S8 input, S32 accumulation - SATURATE -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,16>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16,8,16>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const &B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " -+ "{%6}, {%7,%8,%9,%10};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), -+ "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,16>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16,8,16>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const &B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " -+ "{%6}, {%7,%8,%9,%10};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), -+ "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,16>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16,8,16>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; 
-+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const &B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " -+ "{%6}, {%7,%8,%9,%10};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), -+ "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,16>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16,8,16>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const &B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " -+ "{%6}, {%7,%8,%9,%10};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), -+ "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 16832 - S8 input, S32 accumulation -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,32>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16,8,32>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm 
volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,32>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16,8,32>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,32>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16,8,32>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,32>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16,8,32>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using 
LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 16832 - S8 input, S32 accumulation - SATURATE -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,32>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16,8,32>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const * A = reinterpret_cast(&a); -+ uint32_t const * B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,32>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16,8,32>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ 
uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,32>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16,8,32>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,32>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16,8,32>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 16864 - S4 input, S32 accumulation -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// 
Matrix multiply-add operation: S32 = S4 * S4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 64>, -+ 32, -+ cutlass::int4b_t, -+ layout::RowMajor, -+ cutlass::int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16, 8, 64>; -+ -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 64>, -+ 32, -+ cutlass::uint4b_t, -+ layout::RowMajor, -+ cutlass::int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16, 8, 64>; -+ -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 64>, -+ 32, -+ cutlass::int4b_t, -+ layout::RowMajor, -+ cutlass::uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16, 8, 64>; -+ -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = 
Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * U4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 64>, -+ 32, -+ cutlass::uint4b_t, -+ layout::RowMajor, -+ cutlass::uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16, 8, 64>; -+ -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 16864 - S4 input, S32 accumulation - SATURATE -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S4 * S4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 64>, -+ 32, -+ cutlass::int4b_t, -+ layout::RowMajor, -+ cutlass::int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16, 8, 64>; -+ -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ 
FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const * A = reinterpret_cast(&a); -+ uint32_t const * B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 64>, -+ 32, -+ cutlass::uint4b_t, -+ layout::RowMajor, -+ cutlass::int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16, 8, 64>; -+ -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 64>, -+ 32, -+ cutlass::int4b_t, -+ layout::RowMajor, -+ cutlass::uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16, 8, 64>; -+ -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" -+ 
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * U4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 64>, -+ 32, -+ cutlass::uint4b_t, -+ layout::RowMajor, -+ cutlass::uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16, 8, 64>; -+ -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = B1 & B1 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,256>, -+ 32, -+ cutlass::uint1b_t, -+ layout::RowMajor, -+ cutlass::uint1b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16,8,256>; -+ -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int32_t; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, " -+ "{%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 168256 - B1 input, S32 accumulation - XOR,POPC 
-+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = B1 & B1 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,256>, -+ 32, -+ cutlass::uint1b_t, -+ layout::RowMajor, -+ cutlass::uint1b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpXorPopc> { -+ -+ using Shape = gemm::GemmShape<16,8,256>; -+ -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpXorPopc; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, " -+ "{%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/arch/mma_sm90.h b/3rdparty/cutlass/include/cutlass/arch/mma_sm90.h -new file mode 100644 -index 0000000..1d0745b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/mma_sm90.h -@@ -0,0 +1,264 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Matrix multiply -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "mma.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 8)) -+ #define CUTLASS_ARCH_MMA_SM90_F64_MMA_SUPPORTED -+ #if (!defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED)) -+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -+ #define CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED -+ #endif -+ #endif -+#endif -+ -+#if (__CUDACC_VER_MAJOR__ >= 12) -+ #define CUTLASS_ARCH_MMA_SM90_SUPPORTED -+ #if (!defined(CUTLASS_ARCH_MMA_SM90_ENABLED)) -+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -+ #define CUTLASS_ARCH_MMA_SM90_ENABLED -+ #endif -+ #endif -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace arch { -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Matrix Multiply-Add 16x8x4 fp64 -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F64 = F64 * F64 + F64 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,4>, -+ 32, -+ double, -+ layout::RowMajor, -+ double, -+ layout::ColumnMajor, -+ double, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16,8,4>; -+ -+ using ElementA = double; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = double; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = double; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ -+ using ArchTag = arch::Sm90; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, -+ FragmentC const &c) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+ double const *A = reinterpret_cast(&a); -+ double const *B = reinterpret_cast(&b); -+ -+ double const *C = reinterpret_cast(&c); -+ double *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64.rn {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" -+ : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3]) -+ : "d"(A[0]), "d"(A[1]), -+ "d"(B[0]), -+ "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3])); -+ -+#else -+ CUTLASS_UNUSED(d); -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_NOT_IMPLEMENTED(); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Matrix Multiply-Add 16x8x8 fp64 -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// 
Matrix multiply-add operation: F64 = F64 * F64 + F64
-+template <>
-+struct Mma<
-+  gemm::GemmShape<16,8,8>,
-+  32,
-+  double,
-+  layout::RowMajor,
-+  double,
-+  layout::ColumnMajor,
-+  double,
-+  layout::RowMajor,
-+  OpMultiplyAdd> {
-+
-+  using Shape = gemm::GemmShape<16,8,8>;
-+
-+  using ElementA = double;
-+  using LayoutA = layout::RowMajor;
-+  using FragmentA = Array<double, 4>;
-+
-+  using ElementB = double;
-+  using LayoutB = layout::ColumnMajor;
-+  using FragmentB = Array<double, 2>;
-+
-+  using ElementC = double;
-+  using LayoutC = layout::RowMajor;
-+  using FragmentC = Array<double, 4>;
-+
-+  using Operator = OpMultiplyAdd;
-+
-+  using ArchTag = arch::Sm90;
-+
-+  CUTLASS_HOST_DEVICE
-+  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
-+                  FragmentC const &c) const {
-+
-+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED)
-+
-+    double const *A = reinterpret_cast<double const *>(&a);
-+    double const *B = reinterpret_cast<double const *>(&b);
-+
-+    double const *C = reinterpret_cast<double const *>(&c);
-+    double *D = reinterpret_cast<double *>(&d);
-+
-+    asm volatile("mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
-+        : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3])
-+        : "d"(A[0]), "d"(A[1]), "d"(A[2]), "d"(A[3]),
-+          "d"(B[0]), "d"(B[1]),
-+          "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3]));
-+
-+#else
-+
-+    CUTLASS_UNUSED(d);
-+    CUTLASS_UNUSED(a);
-+    CUTLASS_UNUSED(b);
-+    CUTLASS_UNUSED(c);
-+    CUTLASS_NOT_IMPLEMENTED();
-+#endif
-+  }
-+};
-+
-+////////////////////////////////////////////////////////////////////////////////
-+/// Matrix Multiply-Add 16x8x16 fp64
-+////////////////////////////////////////////////////////////////////////////////
-+
-+/// Matrix multiply-add operation: F64 = F64 * F64 + F64
-+template <>
-+struct Mma<
-+  gemm::GemmShape<16,8,16>,
-+  32,
-+  double,
-+  layout::RowMajor,
-+  double,
-+  layout::ColumnMajor,
-+  double,
-+  layout::RowMajor,
-+  OpMultiplyAdd> {
-+
-+  using Shape = gemm::GemmShape<16,8,16>;
-+
-+  using ElementA = double;
-+  using LayoutA = layout::RowMajor;
-+  using FragmentA = Array<double, 8>;
-+
-+  using ElementB = double;
-+  using LayoutB = layout::ColumnMajor;
-+  using FragmentB = Array<double, 4>;
-+
-+  using ElementC = double;
-+  using LayoutC = layout::RowMajor;
-+  using FragmentC = Array<double, 4>;
-+
-+  using Operator = OpMultiplyAdd;
-+
-+  using ArchTag = arch::Sm90;
-+
-+  CUTLASS_HOST_DEVICE
-+  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
-+                  FragmentC const &c) const {
-+
-+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED)
-+
-+    double const *A = reinterpret_cast<double const *>(&a);
-+    double const *B = reinterpret_cast<double const *>(&b);
-+
-+    double const *C = reinterpret_cast<double const *>(&c);
-+    double *D = reinterpret_cast<double *>(&d);
-+
-+    asm volatile("mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5, %6, %7, %8, %9, %10, %11}, {%12, %13, %14, %15}, {%16, %17, %18, %19};\n"
-+        : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3])
-+        : "d"(A[0]), "d"(A[1]), "d"(A[2]), "d"(A[3]), "d"(A[4]), "d"(A[5]), "d"(A[6]), "d"(A[7]),
-+          "d"(B[0]), "d"(B[1]), "d"(B[2]), "d"(B[3]),
-+          "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3]));
-+
-+#else
-+    CUTLASS_NOT_IMPLEMENTED();
-+#endif
-+  }
-+};
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+} // namespace arch
-+} // namespace cutlass
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-diff --git a/3rdparty/cutlass/include/cutlass/arch/mma_sparse_sm80.h b/3rdparty/cutlass/include/cutlass/arch/mma_sparse_sm80.h
-new file
mode 100644 -index 0000000..a1f5b1d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/mma_sparse_sm80.h -@@ -0,0 +1,1685 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ \brief Sparse matrix multiply accumulate for SM80 -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "mma.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1)) -+ -+#define CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED 1 -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) -+#define CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED -+#endif -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace arch { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Sparse Matrix Multiply 16832 -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F16 = F16 * F16 + F16 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16, 8, 32>, -+ 32, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ OpMultiplyAdd, -+ SPFormatType::Thread -+> { -+ -+ using Shape = gemm::GemmShape<16, 8, 32>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = half_t; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 2; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, -+ FragmentC const &c, uint32_t const &E, int const id2) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ uint32_t const *C = reinterpret_cast(&c); -+ uint32_t *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) { -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, " -+ "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E)); -+ } -+ else if (id2 == 1) { -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, " -+ "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x1;\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E)); -+ } -+ else { -+ assert(0); -+ } -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F32 = F16 * F16 + F32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16, 8, 32>, -+ 32, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ float, -+ layout::RowMajor, -+ OpMultiplyAdd, -+ SPFormatType::Thread -+ > { -+ -+ using Shape 
= gemm::GemmShape<16, 8, 32>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 2; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, -+ FragmentC const &c, uint32_t const &E, int const id2) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) { -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), -+ "r"(E)); -+ } -+ else if (id2 == 1) { -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" -+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), -+ "r"(E)); -+ } -+ else { -+ assert(0); -+ } -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Sparse Matrix Multiply 16832 - Float BF16, FP32 accumulation -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F32 = bf16 * bf16 + F32 -+template <> -+struct SparseMma, 32, bfloat16_t, layout::RowMajor, -+ bfloat16_t, layout::ColumnMajor, float, layout::RowMajor, -+ OpMultiplyAdd, SPFormatType::Thread> { -+ using Shape = gemm::GemmShape<16, 8, 32>; -+ -+ using ElementA = bfloat16_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = bfloat16_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 2; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, -+ FragmentC const &c, uint32_t const &E, int const id2) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) { -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " -+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=f"(D[0]), "=f"(D[1]), 
"=f"(D[2]), "=f"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); -+ } else if (id2 == 1) { -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " -+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" -+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); -+ } else { -+ assert(0); -+ } -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Sparse Matrix Multiply 16816 - Float TF32 -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32 -+template <> -+struct SparseMma, 32, tfloat32_t, layout::RowMajor, -+ tfloat32_t, layout::ColumnMajor, float, layout::RowMajor, -+ OpMultiplyAdd, SPFormatType::Thread> { -+ using Shape = gemm::GemmShape<16, 8, 16>; -+ -+ using ElementA = tfloat32_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = tfloat32_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 4; -+ -+ static int const kMaxID2 = 2; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, -+ FragmentC const &c, uint32_t const &E, int const id2) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) { -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " -+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); -+ } else if (id2 == 1) { -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " -+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" -+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); -+ } else { -+ assert(0); -+ } -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Sparse Matrix Multiply 16864 - S8 input, S32 accumulation -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,64>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ 
OpMultiplyAdd, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,64>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,64>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,64>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,64>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ int8_t, -+ 
layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,64>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,64>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,64>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Sparse Matrix Multiply 16864 - S8 
input, S32 accumulation - SATURATE -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,64>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,64>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,64>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,64>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ 
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,64>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,64>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,64>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,64>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : 
"r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Sparse Matrix Multiply 168128 - S4 input, S32 accumulation -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S4 * S4 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,128>, -+ 32, -+ cutlass::int4b_t, -+ layout::RowMajor, -+ cutlass::int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,128>; -+ -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,128>, -+ 32, -+ cutlass::int4b_t, -+ layout::RowMajor, -+ cutlass::uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,128>; -+ -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = 
reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,128>, -+ 32, -+ cutlass::uint4b_t, -+ layout::RowMajor, -+ cutlass::int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,128>; -+ -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * U4 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,128>, -+ 32, -+ cutlass::uint4b_t, -+ layout::RowMajor, -+ cutlass::uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,128>; -+ -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, 
-+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Sparse Matrix Multiply 168128 - S4 input, S32 accumulation - SATURATE -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S4 * S4 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,128>, -+ 32, -+ cutlass::int4b_t, -+ layout::RowMajor, -+ cutlass::int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,128>; -+ -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,128>, -+ 32, -+ cutlass::int4b_t, -+ layout::RowMajor, -+ cutlass::uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,128>; -+ -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; 
-+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,128>, -+ 32, -+ cutlass::uint4b_t, -+ layout::RowMajor, -+ cutlass::int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,128>; -+ -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * U4 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,128>, -+ 32, -+ cutlass::uint4b_t, -+ layout::RowMajor, -+ cutlass::uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,128>; -+ -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ 
using ElementB = cutlass::uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/arch/reg_reconfig.h b/3rdparty/cutlass/include/cutlass/arch/reg_reconfig.h -new file mode 100644 -index 0000000..2b74a22 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/reg_reconfig.h -@@ -0,0 +1,68 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
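For context, the SparseMma specializations above all follow the same pattern: only 1/kSparse of the A operand's K extent is stored (the remainder is implied zero), and the extra uint32_t operand E carries the kMetaSizeInBits-wide indices recording which K positions were kept. The standalone snippet below is a sanity check of the per-thread fragment sizes implied by the m16n8k64 8-bit specializations; the constants deliberately mirror the code above rather than being pulled from CUTLASS headers, so treat it as an illustrative sketch only.

```cpp
// Per-thread fragment bookkeeping for SparseMma<GemmShape<16,8,64>, 32, uint8_t, ...>.
// Standalone illustration: compiles without CUTLASS; constants mirror the diff above.
#include <cstdint>

constexpr int kM = 16, kN = 8, kK = 64;  // instruction shape m16n8k64
constexpr int kThreads = 32;             // one full warp cooperates on the MMA
constexpr int kSparse = 2;               // A is stored at half density

constexpr int kElementsA = kM * kK / kSparse / kThreads;  // compressed A per thread
constexpr int kElementsB = kN * kK / kThreads;            // dense B per thread
constexpr int kElementsC = kM * kN / kThreads;            // accumulators per thread

static_assert(kElementsA == 16, "16 x 8-bit = the four 32-bit registers A[0..3] in the asm");
static_assert(kElementsB == 16, "16 x 8-bit = the four 32-bit registers B[0..3] in the asm");
static_assert(kElementsC == 4,  "four 32-bit accumulators C[0..3] / D[0..3]");

int main() { return 0; }
```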
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief PTX for CTA Reconfiguration -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) -+ #if (defined(__CUDA_ARCH_FEAT_SM90_ALL)) -+ #define CUDA_CTA_RECONFIG_ACTIVATED 1 -+ #endif -+#else -+ #define CUDA_CTA_RECONFIG_ACTIVATED 0 -+#endif -+ -+namespace cutlass { -+namespace arch { -+ -+template -+CUTLASS_DEVICE -+void warpgroup_reg_alloc(){ -+#if CUDA_CTA_RECONFIG_ACTIVATED -+ asm volatile( "setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount) ); -+#endif -+} -+ -+template -+CUTLASS_DEVICE -+void warpgroup_reg_dealloc(){ -+#if CUDA_CTA_RECONFIG_ACTIVATED -+ asm volatile( "setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount) ); -+#endif -+} -+ -+} // namespace arch -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/arch/simd.h b/3rdparty/cutlass/include/cutlass/arch/simd.h -new file mode 100644 -index 0000000..71128c2 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/simd.h -@@ -0,0 +1,125 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
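The reg_reconfig.h helpers above are thin wrappers around the Hopper `setmaxnreg` PTX instruction. Their typical use is in warp-specialized kernels, where a producer warpgroup shrinks its register budget so that consumer warpgroups can grow theirs. The following is a minimal sketch, assuming CUTLASS is on the include path and an SM90-capable toolchain; the kernel structure and the register counts 40/232 are illustrative placeholders (counts must be multiples of 8 accepted by `setmaxnreg`), not values taken from this patch.

```cpp
// Illustrative warp-specialization skeleton built on warpgroup_reg_alloc/_dealloc.
// Not taken from the patch; thread-role selection and register counts are placeholders.
#include "cutlass/arch/reg_reconfig.h"

__global__ void warp_specialized_kernel() {
  int warpgroup_idx = threadIdx.x / 128;      // 128 threads = one warpgroup

  if (warpgroup_idx == 0) {
    // Producer warpgroup: give registers back so consumers can claim them.
    cutlass::arch::warpgroup_reg_dealloc<40>();
    // ... issue TMA / cp.async loads into shared memory ...
  } else {
    // Consumer warpgroup(s): grow the register budget for the MMA mainloop.
    cutlass::arch::warpgroup_reg_alloc<232>();
    // ... run tensor core MMAs over the staged tiles ...
  }
}
```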
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates exposing SIMD operators -+*/ -+ -+#pragma once -+ -+#include "../array.h" -+#include "../numeric_types.h" -+ -+namespace cutlass { -+namespace arch { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Element-wise operators -+// -+ -+CUTLASS_HOST_DEVICE -+template -+Array operator*(Array const &a, Array const &b) { -+ Array d; -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ d[i] = a[i] * b[i]; -+ } -+ return d; -+} -+ -+CUTLASS_HOST_DEVICE -+template -+Array operator+(Array const &a, Array const &b) { -+ Array d; -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ d[i] = a[i] + b[i]; -+ } -+ return d; -+} -+ -+CUTLASS_HOST_DEVICE -+template -+Array operator-(Array const &a, Array const &b) { -+ Array d; -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ d[i] = a[i] - b[i]; -+ } -+ return d; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Multiply-accumulate operators -+// -+ -+CUTLASS_HOST_DEVICE -+template -+Array mac(Array const &a, Array const &b, Array const &c) { -+ Array d; -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ d[i] = a[i] * b[i] + c[i]; -+ } -+ return d; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Dot product operator -+// -+ -+CUTLASS_HOST_DEVICE -+template -+Accumulator dot(Array const &a, Array const &b, Accumulator accum) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ accum += a[i] * b[i]; -+ } -+ return accum; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "simd_sm60.h" -+#include "simd_sm61.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/arch/simd_sm60.h b/3rdparty/cutlass/include/cutlass/arch/simd_sm60.h -new file mode 100644 -index 0000000..16d528b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/simd_sm60.h -@@ -0,0 +1,116 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
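The generic operators in simd.h above are plain element-wise loops over Array<T, N>; no special hardware is assumed at this level (the SM60/SM61 headers it pulls in add the specializations). The snippet below reproduces the same mac() and dot() semantics with std::array so it compiles standalone; the container, names, and values are illustrative and not part of CUTLASS.

```cpp
// Standalone illustration of what cutlass::arch::mac() and dot() above compute.
#include <array>
#include <cstdio>

template <typename T, int N>
std::array<T, N> mac(std::array<T, N> const &a, std::array<T, N> const &b,
                     std::array<T, N> const &c) {
  std::array<T, N> d{};
  for (int i = 0; i < N; ++i) d[i] = a[i] * b[i] + c[i];  // element-wise multiply-add
  return d;
}

template <typename T, typename Accumulator, int N>
Accumulator dot(std::array<T, N> const &a, std::array<T, N> const &b,
                Accumulator accum) {
  for (int i = 0; i < N; ++i) accum += a[i] * b[i];  // reduction into accum
  return accum;
}

int main() {
  std::array<int, 4> a{1, 2, 3, 4}, b{5, 6, 7, 8}, c{1, 1, 1, 1};
  auto d = mac(a, b, c);   // {6, 13, 22, 33}
  int s = dot(a, b, 0);    // 70
  std::printf("%d %d %d %d | %d\n", d[0], d[1], d[2], d[3], s);
  return 0;
}
```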
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates exposing SIMD operators for SM60 -+*/ -+ -+#pragma once -+ -+#include "simd.h" -+ -+namespace cutlass { -+namespace arch { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Element-wise operators - specialized for half_t x 2 -+// -+ -+CUTLASS_HOST_DEVICE -+template <> -+Array operator*(Array const &a, Array const &b) { -+ Array d; -+ -+ // TODO -+ -+ return d; -+} -+ -+CUTLASS_HOST_DEVICE -+template <> -+Array operator+(AArray const &a, Array const &b) { -+ Array d; -+ -+ // TODO -+ -+ return d; -+} -+ -+CUTLASS_HOST_DEVICE -+template <> -+Array operator-(Array const &a, Array const &b) { -+ Array d; -+ -+ // TODO -+ -+ return d; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Multiply-accumulate operators - specialized for half_t x 2 -+CUTLASS_HOST_DEVICE -+template <> -+Array mac(Array const &a, Array const &b, Array const &c) { -+ Array d; -+ -+ // TODO -+ -+ return d; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Dot product operator - specialized for half_t <- (half_t * half_t) x 2 + half_t -+CUTLASS_HOST_DEVICE -+template <> -+half_t dot(Array const &a, Array const &b, half_t accum) { -+ -+ // TODO -+ -+ return accum; -+} -+ -+/// Dot product operator - specialized for float <- (half_t * half_t) x 2 + float -+CUTLASS_HOST_DEVICE -+template <> -+float dot(Array const &a, Array const &b, float accum) { -+ -+ // TODO -+ -+ return accum; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/arch/simd_sm61.h b/3rdparty/cutlass/include/cutlass/arch/simd_sm61.h -new file mode 100644 -index 0000000..ba9abb7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/simd_sm61.h -@@ -0,0 +1,147 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates exposing SIMD operators for SM61 -+*/ -+ -+#pragma once -+ -+#include "simd.h" -+ -+namespace cutlass { -+namespace arch { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Dot product operator - specialized for int32_t <- (int8_t * int8_t) x 4 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+/// Dot product operator - specialized for int32_t <- (uint8_t * int8_t) x 4 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+/// Dot product operator - specialized for int32_t <- (int8_t * uint8_t) x 4 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+/// Dot product operator - specialized for int32_t <- (uint8_t * uint8_t) x 4 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Dot product operator - specialized for int32_t <- (int16_t * int8_t) x 2 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+/// Dot product operator - specialized for int32_t <- (uint16_t * int8_t) x 2 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+/// Dot product operator - specialized for int32_t <- (int16_t * int8_t) x 2 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+/// Dot product operator - specialized for int32_t <- (uint16_t * int8_t) x 2 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Dot product 
operator - specialized for int32_t <- (int16_t * int16_t) x 2 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+/// Dot product operator - specialized for int32_t <- (uint16_t * int16_t) x 2 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+/// Dot product operator - specialized for int32_t <- (int16_t * int16_t) x 2 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+/// Dot product operator - specialized for int32_t <- (uint16_t * int16_t) x 2 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/arch/wmma.h b/3rdparty/cutlass/include/cutlass/arch/wmma.h -new file mode 100644 -index 0000000..db54e45 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/wmma.h -@@ -0,0 +1,223 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates exposing architecture support for warp matrix multiply-add (WMMA) operations -+*/ -+ -+#pragma once -+ -+// CUTLASS WMMA does not support clang at present. 
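Note that the dot() specializations in simd_sm60.h / simd_sm61.h above are placeholders in this copy (the bodies are TODO or simply return accum). On SM61-class hardware the 8-bit cases are normally lowered to the `__dp4a` intrinsic; a hedged sketch of that lowering for the s8 x s8 -> s32 case follows. The wrapper name dp4a_dot is illustrative, not a CUTLASS symbol.

```cpp
// Sketch of the usual SM61 lowering for a 4-way int8 dot product with int32 accumulation.
#include <cstdint>
#include <cuda_runtime.h>

__device__ int32_t dp4a_dot(int32_t const &a_packed,   // four int8 lanes packed into 32 bits
                            int32_t const &b_packed,   // four int8 lanes packed into 32 bits
                            int32_t accum) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)
  return __dp4a(a_packed, b_packed, accum);            // hardware 4-way int8 dot product
#else
  // Portable fallback: unpack and multiply-accumulate per lane.
  int8_t const *a = reinterpret_cast<int8_t const *>(&a_packed);
  int8_t const *b = reinterpret_cast<int8_t const *>(&b_packed);
  for (int i = 0; i < 4; ++i) {
    accum += int32_t(a[i]) * int32_t(b[i]);
  }
  return accum;
#endif
}
```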
-+#if !(defined(__clang__) && defined(__CUDA__)) -+ -+#if (__CUDACC_VER_MAJOR__ >= 9) -+#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700)) -+#define CUTLASS_ARCH_WMMA_ENABLED -+#define CUTLASS_ARCH_WMMA_SM70_ENABLED -+#endif -+#endif -+ -+#if (__CUDACC_VER_MAJOR__ >= 10) -+#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 720)) -+#define CUTLASS_ARCH_INTEGER_MATRIX_MULTIPLY_ENABLED -+#define CUTLASS_ARCH_WMMA_SM72_ENABLED -+#endif -+#endif -+ -+#if (__CUDACC_VER_MAJOR__ >= 10) -+#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 750)) -+#define CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED -+#define CUTLASS_ARCH_WMMA_SM75_ENABLED -+#endif -+#endif -+ -+#endif //!(defined(__clang__) && defined(__CUDA__)) -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+#include -+#include "cutlass/arch/mma.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/gemm/gemm.h" -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace arch { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+/// Statically maps cutlass data types => nvcuda::wmma data types -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+struct CutlassToWmmaDataType{ -+ using Type = Type_; -+}; -+ -+/// Statically maps cutlass::half_t => __half -+template<> -+struct CutlassToWmmaDataType { -+ using Type = __half; -+}; -+ -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) -+template<> -+struct CutlassToWmmaDataType { -+ using Type = __nv_bfloat16; -+}; -+#endif -+ -+/// Statically maps int8_t => char -+template<> -+struct CutlassToWmmaDataType { -+ using Type = signed char; -+}; -+ -+/// Statically maps uint8_t => char -+template<> -+struct CutlassToWmmaDataType { -+ using Type = unsigned char; -+}; -+ -+/// Statically maps int32_t => int -+template<> -+struct CutlassToWmmaDataType { -+ using Type = int; -+}; -+ -+#if defined(CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED) -+/// Statically maps cutlass::int4b_t => experimental::precision::s4 -+template<> -+struct CutlassToWmmaDataType { -+ using Type = nvcuda::wmma::experimental::precision::s4; -+}; -+ -+/// Statically maps cutlass::uint4b_t => experimental::precision::s4 -+template<> -+struct CutlassToWmmaDataType { -+ using Type = nvcuda::wmma::experimental::precision::u4; -+}; -+ -+/// Statically maps cutlass::uint1b_t => experimental::precision::b1 -+template<> -+struct CutlassToWmmaDataType { -+ using Type = nvcuda::wmma::experimental::precision::b1; -+}; -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+/// Statically maps cutlass::layout => nvcuda::wmma layout tags -+//////////////////////////////////////////////////////////////////////////////////////////////// -+template -+struct CutlassToWmmaLayout { -+}; -+ -+/// Statically maps cutlass::layout::RowMajor => nvcuda::wmma::row_major layout tags -+template <> -+struct CutlassToWmmaLayout { -+ using Layout = nvcuda::wmma::row_major; -+ static nvcuda::wmma::layout_t const value = nvcuda::wmma::layout_t::mem_row_major; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+/// Statically maps cutlass::layout::RowMajor => nvcuda::wmma::row_major layout tags -+//////////////////////////////////////////////////////////////////////////////////////////////// 
-+template <> -+struct CutlassToWmmaLayout { -+ using Layout = nvcuda::wmma::col_major; -+ static nvcuda::wmma::layout_t const value = nvcuda::wmma::layout_t::mem_col_major; -+}; -+//////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+/// Statically maps nvcuda::wmma data types => cutlass data types -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+struct WmmaToCutlassDataType{ -+ using Type = Type_; -+}; -+ -+/// Statically maps __half => cutlass::half_t -+template<> -+struct WmmaToCutlassDataType<__half> { -+ using Type = cutlass::half_t; -+}; -+ -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) -+template<> -+struct WmmaToCutlassDataType<__nv_bfloat16> { -+ using Type = cutlass::bfloat16_t; -+}; -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// WMMA template structure defines nvcuda::wmma::fragments and static assertion chaeks -+// for a specific template paramterized data type (Element[A|B|C]), layout (Layout[A|B|C]), -+// and native wmma size (Shape) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ typename Shape_, ///< Size of the matrix product (concept: GemmShape) -+ typename ElementA_, ///< Data type of A elements -+ typename LayoutA_, ///< Layout of A matrix (concept: MatrixLayout) -+ typename ElementB_, ///< Data type of B elements -+ typename LayoutB_, ///< Layout of B matrix (concept: MatrixLayout) -+ typename ElementC_, ///< Element type of C matrix -+ typename LayoutC_, /// Layout of C matrix (concept: MatrixLayout) -+ typename Operator_ = cutlass::arch::OpMultiplyAdd ///< Inner product operator (multiply-add, xor.popc) -+> -+struct Wmma; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Specializations for each compute capability -+// -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include "cutlass/arch/wmma_sm70.h" -+#endif -+ -+#ifdef CUTLASS_ARCH_WMMA_SM72_ENABLED -+#include "cutlass/arch/wmma_sm72.h" -+#endif -+ -+#ifdef CUTLASS_ARCH_WMMA_SM75_ENABLED -+#include "cutlass/arch/wmma_sm75.h" -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif //CUTLASS_ARCH_WMMA_ENABLED -diff --git a/3rdparty/cutlass/include/cutlass/arch/wmma_sm70.h b/3rdparty/cutlass/include/cutlass/arch/wmma_sm70.h -new file mode 100644 -index 0000000..0658474 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/wmma_sm70.h -@@ -0,0 +1,136 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Matrix multiply -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+#include "cutlass/layout/matrix.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+namespace cutlass { -+namespace arch { -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// WMMA template structure defines nvcuda::wmma::fragments and static assert for -+// wmma native instruction sizes supported for half -+// -+//////////////////////////////////////////////////////////////////////////////// -+template < -+typename Shape_, -+typename LayoutA_, -+typename LayoutB_, -+typename ElementC_, -+typename LayoutC_> -+struct Wmma< -+ Shape_, ///< Size of the matrix product (concept: GemmShape) -+ cutlass::half_t, ///< ElementA -+ LayoutA_, ///< LayoutA -+ cutlass::half_t, ///< ElementB -+ LayoutB_, ///< LayoutB -+ ElementC_, ///< ElementC -+ LayoutC_, ///< LayoutC -+ cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc) -+> { -+ -+#if defined(CUTLASS_ARCH_WMMA_SM70_ENABLED) -+ using Shape = Shape_; -+ using ElementA = cutlass::half_t; -+ using LayoutA = LayoutA_; -+ using ElementB = cutlass::half_t; -+ using LayoutB = LayoutB_; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using Operator = cutlass::arch::OpMultiplyAdd; -+ using ArchTag = arch::Sm70; -+ -+ // check supported wmma shape for the given multiplicand data types -+ static_assert( -+ platform::is_same, Shape>::value || -+ platform::is_same, Shape>::value || -+ platform::is_same, Shape>::value, -+ "Supported list of wmma operator shape for f16 multiplicands are: 16x16x16, 8x32x16, and 32x8x16"); -+ -+ // check supported wmma output data type for the given multiplicand data types -+ static_assert( -+ platform::is_same::value || platform::is_same::value, -+ "Supported of wmma output data type for f16 multiplicands are: f16 and f32"); -+ -+ // Wmma Fragment -+ using FragmentA = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_a, -+ 
Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type, -+ typename CutlassToWmmaLayout::Layout>; -+ -+ using FragmentB = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_b, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type, -+ typename CutlassToWmmaLayout::Layout>; -+ -+ using FragmentC = nvcuda::wmma::fragment< -+ nvcuda::wmma::accumulator, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type>; -+ -+ /// Performs a nvcuda::wmma matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A, -+ FragmentB const &B, -+ FragmentC const &C) const { -+ -+ nvcuda::wmma::mma_sync(D, A, B, C); -+ } -+#else -+ static_assert(false, "wmma.mma.sync for floating point multiplicands is avialable only for SM70 and beyond"); -+#endif -+ -+}; -+ -+} // namespace arch -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/arch/wmma_sm72.h b/3rdparty/cutlass/include/cutlass/arch/wmma_sm72.h -new file mode 100644 -index 0000000..c20e1b3 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/wmma_sm72.h -@@ -0,0 +1,210 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Matrix multiply -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+#include "cutlass/layout/matrix.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+namespace cutlass { -+namespace arch { -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// WMMA template structure defines nvcuda::wmma::fragments and static assert for -+// wmma native instruction sizes supported for int8_t -+// -+//////////////////////////////////////////////////////////////////////////////// -+template < -+typename Shape_, -+typename LayoutA_, -+typename LayoutB_, -+typename LayoutC_> -+struct Wmma< -+ Shape_, ///< Size of the matrix product (concept: GemmShape) -+ int8_t, ///< ElementA -+ LayoutA_, ///< LayoutA -+ int8_t, ///< ElementB -+ LayoutB_, ///< LayoutB -+ int32_t, ///< ElementC -+ LayoutC_, ///< LayoutC -+ cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc) -+> { -+#if defined(CUTLASS_ARCH_WMMA_SM72_ENABLED) -+ using Shape = Shape_; -+ using ElementA = int8_t; -+ using LayoutA = LayoutA_; -+ using ElementB = int8_t; -+ using LayoutB = LayoutB_; -+ using ElementC = int32_t; -+ using LayoutC = LayoutC_; -+ using Operator = cutlass::arch::OpMultiplyAdd; -+ using ArchTag = arch::Sm72; -+ -+ // check supported wmma shape for the given multiplicand data types -+ static_assert( -+ platform::is_same, Shape>::value || -+ platform::is_same, Shape>::value || -+ platform::is_same, Shape>::value, -+ "Supported list of wmma operator shape for s8 multiplicands are: 16x16x16, 8x32x16, and 32x8x16"); -+ -+ -+ // Wmma Fragment -+ using FragmentA = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_a, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type, -+ typename CutlassToWmmaLayout::Layout>; -+ -+ using FragmentB = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_b, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type, -+ typename CutlassToWmmaLayout::Layout>; -+ -+ using FragmentC = nvcuda::wmma::fragment< -+ nvcuda::wmma::accumulator, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type>; -+ -+ /// Performs a nvcuda::wmma matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A, -+ FragmentB const &B, -+ FragmentC const &C) const { -+ -+ nvcuda::wmma::mma_sync(D, A, B, C); -+ } -+ -+#else -+ static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM72 and beyond"); -+#endif -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// WMMA template structure defines nvcuda::wmma::fragments and static assert for -+// wmma native instruction sizes supported for uint8_t -+// -+//////////////////////////////////////////////////////////////////////////////// -+template < -+typename Shape_, -+typename LayoutA_, -+typename LayoutB_, -+typename LayoutC_> -+struct Wmma< -+ Shape_, ///< Size of the matrix product (concept: GemmShape) -+ uint8_t, ///< ElementA -+ LayoutA_, ///< LayoutA -+ uint8_t, ///< ElementB -+ LayoutB_, ///< LayoutB -+ int32_t, ///< ElementC -+ LayoutC_, ///< LayoutC -+ cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc) -+> { -+#if defined(CUTLASS_ARCH_WMMA_SM72_ENABLED) -+ using Shape = Shape_; -+ using ElementA = uint8_t; -+ using LayoutA = LayoutA_; -+ using ElementB = uint8_t; -+ using LayoutB = LayoutB_; -+ using ElementC = 
int32_t; -+ using LayoutC = LayoutC_; -+ using Operator = cutlass::arch::OpMultiplyAdd; -+ using ArchTag = arch::Sm72; -+ -+ // check supported wmma shape for the given multiplicand data types -+ static_assert( -+ platform::is_same, Shape>::value || -+ platform::is_same, Shape>::value || -+ platform::is_same, Shape>::value, -+ "Supported list of wmma operator shape for u8 multiplicands are: 16x16x16, 8x32x16, and 32x8x16"); -+ -+ // Wmma Fragment -+ using FragmentA = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_a, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type, -+ typename CutlassToWmmaLayout::Layout>; -+ -+ using FragmentB = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_b, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type, -+ typename CutlassToWmmaLayout::Layout>; -+ -+ using FragmentC = nvcuda::wmma::fragment< -+ nvcuda::wmma::accumulator, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type>; -+ -+ /// Performs a nvcuda::wmma matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A, -+ FragmentB const &B, -+ FragmentC const &C) const { -+ -+ nvcuda::wmma::mma_sync(D, A, B, C); -+ } -+ -+#else -+ static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM72 and beyond"); -+#endif -+ -+}; -+ -+} // namespace arch -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/arch/wmma_sm75.h b/3rdparty/cutlass/include/cutlass/arch/wmma_sm75.h -new file mode 100644 -index 0000000..89d030f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/wmma_sm75.h -@@ -0,0 +1,207 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
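For reference, the integer Wmma specializations above are thin wrappers over the nvcuda::wmma fragment API. The sketch below shows the equivalent raw usage for a single 16x16x16 s8 tile, matching what FragmentA/FragmentB/FragmentC and operator() encapsulate; the function name, pointers, layouts, and leading dimensions are illustrative.

```cpp
// Minimal device-side sketch of what the s8 Wmma specialization wraps:
// declare nvcuda::wmma fragments, load tiles, and call mma_sync.
#include <mma.h>
#include <cstdint>

using namespace nvcuda;

__device__ void wmma_s8_tile(int8_t const *A, int8_t const *B, int32_t *C,
                             int lda, int ldb, int ldc) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720)
  wmma::fragment<wmma::matrix_a, 16, 16, 16, signed char, wmma::row_major> a;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, signed char, wmma::col_major> b;
  wmma::fragment<wmma::accumulator, 16, 16, 16, int> acc;

  wmma::fill_fragment(acc, 0);                 // zero the accumulator tile
  wmma::load_matrix_sync(a, A, lda);           // 16x16 tile of A
  wmma::load_matrix_sync(b, B, ldb);           // 16x16 tile of B
  wmma::mma_sync(acc, a, b, acc);              // acc += a * b
  wmma::store_matrix_sync(C, acc, ldc, wmma::mem_row_major);
#endif
}
```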
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Matrix multiply -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+#include "cutlass/layout/matrix.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+namespace cutlass { -+namespace arch { -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// WMMA template structure defines nvcuda::wmma::fragments and static assert for -+// wmma native instruction sizes supported for cutlass::int4b_t (experimental::s4). -+// -+//////////////////////////////////////////////////////////////////////////////// -+template < -+typename Shape_, -+typename LayoutA_, -+typename LayoutB_, -+typename LayoutC_> -+struct Wmma< -+ Shape_, ///< Size of the matrix product (concept: GemmShape) -+ cutlass::int4b_t, ///< ElementA -+ LayoutA_, ///< LayoutA -+ cutlass::int4b_t, ///< ElementB -+ LayoutB_, ///< LayoutB -+ int32_t, ///< ElementC -+ LayoutC_, ///< LayoutC -+ cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc) -+> { -+#if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED) -+ using Shape = Shape_; -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = LayoutA_; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = LayoutB_; -+ using ElementC = int32_t; -+ using LayoutC = LayoutC_; -+ using Operator = cutlass::arch::OpMultiplyAdd; -+ using ArchTag = arch::Sm75; -+ -+ // check supported wmma shape for the given multiplicand data types -+ static_assert( -+ platform::is_same, Shape>::value, -+ "Supported list of wmma operator shape for s8 multiplicands is: 8x8x32"); -+ -+ -+ // Wmma Fragment -+ using FragmentA = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_a, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type, -+ typename CutlassToWmmaLayout::Layout>; -+ -+ using FragmentB = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_b, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type, -+ typename CutlassToWmmaLayout::Layout>; -+ -+ using FragmentC = nvcuda::wmma::fragment< -+ nvcuda::wmma::accumulator, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type>; -+ -+ /// Performs a nvcuda::wmma matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A, -+ FragmentB const &B, -+ FragmentC const &C) const { -+ nvcuda::wmma::mma_sync(D, A, B, C); -+ -+ } -+ -+#else -+ static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM75 and beyond"); -+#endif -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// WMMA template structure defines nvcuda::wmma::fragments and static assert for -+// wmma native instruction sizes supported for cutlass::uint1b_t (experimental::b1). 
-+// -+//////////////////////////////////////////////////////////////////////////////// -+template < -+typename Shape_, -+typename LayoutA_, -+typename LayoutB_, -+typename LayoutC_> -+struct Wmma< -+ Shape_, ///< Size of the matrix product (concept: GemmShape) -+ cutlass::uint1b_t, ///< ElementA -+ LayoutA_, ///< LayoutA -+ cutlass::uint1b_t, ///< ElementB -+ LayoutB_, ///< LayoutB -+ int32_t, ///< ElementC -+ LayoutC_, ///< LayoutC -+ cutlass::arch::OpXorPopc ///< Operator (multiply-add, xor.popc) -+> { -+#if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED) -+ using Shape = Shape_; -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = LayoutA_; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = LayoutB_; -+ using ElementC = int32_t; -+ using LayoutC = LayoutC_; -+ using Operator = cutlass::arch::OpXorPopc; -+ using ArchTag = arch::Sm75; -+ -+ // check supported wmma shape for the given multiplicand data types -+ static_assert( -+ platform::is_same, Shape>::value, -+ "Supported list of wmma operator shape for b1 multiplicands is: 8x8x128"); -+ -+ -+ // Wmma Fragment -+ using FragmentA = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_a, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type, -+ typename CutlassToWmmaLayout::Layout>; -+ -+ using FragmentB = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_b, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type, -+ typename CutlassToWmmaLayout::Layout>; -+ -+ using FragmentC = nvcuda::wmma::fragment< -+ nvcuda::wmma::accumulator, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type>; -+ -+ /// Performs a nvcuda::wmma matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A, -+ FragmentB const &B, -+ FragmentC const &C) const { -+ nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR, -+ nvcuda::wmma::experimental::bmmaAccumulateOpPOPC); -+ } -+ -+#else -+ static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM75 and beyond"); -+#endif -+ -+}; -+ -+} // namespace arch -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/array.h b/3rdparty/cutlass/include/cutlass/array.h -new file mode 100644 -index 0000000..9fe245b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/array.h -@@ -0,0 +1,2457 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types -+ and is safe to use in a union. -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/half.h" -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Statically sized array for any data type -+template < -+ typename T, -+ int N, -+ bool RegisterSized = sizeof_bits::value >= 32 -+> -+class Array; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the size of an Array<> in bits -+template -+struct sizeof_bits > { -+ static int const value = -+ int(sizeof(typename Array::Storage)) * 8 * int(Array::kStorageElements); -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns true if the argument is a power of 2 -+CUTLASS_HOST_DEVICE -+constexpr bool ispow2(unsigned x) { -+ return x && (!(x & (x - 1))); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns the largest power of two not greater than the argument. -+CUTLASS_HOST_DEVICE -+constexpr unsigned floor_pow_2(unsigned x) { -+ return (x == 0 || ispow2(x)) ? 
x : ((floor_pow_2(x >> 1)) << 1); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Statically sized array for any data type -+template < -+ typename T, -+ int N -+> -+class Array { -+public: -+ -+ /// Storage type -+ using Storage = T; -+ -+ /// Element type -+ using Element = T; -+ -+ /// Number of storage elements -+ //static std::size_t const kStorageElements = N; -+ static size_t const kStorageElements = N; -+ -+ /// Number of logical elements -+ static size_t const kElements = N; -+ -+ // -+ // C++ standard members -+ // -+ -+ typedef T value_type; -+ typedef size_t size_type; -+ typedef ptrdiff_t difference_type; -+ typedef value_type &reference; -+ typedef value_type const & const_reference; -+ typedef value_type *pointer; -+ typedef value_type const * const_pointer; -+ -+ // -+ // Iterators -+ // -+ -+ /// Bidirectional iterator over elements -+ class iterator { -+ -+ /// Pointer to object -+ T *ptr_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ iterator(): ptr_(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ iterator(T *_ptr): ptr_(_ptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ iterator &operator++() { -+ ++ptr_; -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator &operator--() { -+ --ptr_; -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator operator++(int) { -+ iterator ret(*this); -+ ++ptr_; -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator operator--(int) { -+ iterator ret(*this); -+ --ptr_; -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ T &operator*() const { -+ return *ptr_; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator==(iterator const &other) const { -+ return ptr_ == other.ptr_; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator!=(iterator const &other) const { -+ return ptr_ != other.ptr_; -+ } -+ }; -+ -+ /// Bidirectional constant iterator over elements -+ class const_iterator { -+ -+ /// Pointer to object -+ const T *ptr_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator(): ptr_(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator(T const *_ptr): ptr_(_ptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator &operator++() { -+ ++ptr_; -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator &operator--() { -+ --ptr_; -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator operator++(int) { -+ const_iterator ret(*this); -+ ++ptr_; -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator operator--(int) { -+ const_iterator ret(*this); -+ --ptr_; -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ T const &operator*() const { -+ return *ptr_; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator==(const_iterator const &other) const { -+ return ptr_ == other.ptr_; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator!=(const_iterator const &other) const { -+ return ptr_ != other.ptr_; -+ } -+ }; -+ -+ /// Bidirectional iterator over elements -+ class reverse_iterator { -+ -+ /// Pointer to object -+ T *ptr_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator(): ptr_(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator(T *_ptr): ptr_(_ptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator &operator++() { -+ --ptr_; -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator &operator--() { -+ ++ptr_; -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator operator++(int) { -+ iterator ret(*this); -+ --ptr_; -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator operator--(int) { -+ iterator ret(*this); -+ ++ptr_; -+ return ret; -+ 
} -+ -+ CUTLASS_HOST_DEVICE -+ T &operator*() const { -+ return *(ptr_ - 1); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator==(reverse_iterator const &other) const { -+ return ptr_ == other.ptr_; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator!=(reverse_iterator const &other) const { -+ return ptr_ != other.ptr_; -+ } -+ }; -+ -+ /// Bidirectional constant iterator over elements -+ class const_reverse_iterator { -+ -+ /// Pointer to object -+ T const *ptr_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator(): ptr_(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator(T const *_ptr): ptr_(_ptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator &operator++() { -+ --ptr_; -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator &operator--() { -+ ++ptr_; -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator operator++(int) { -+ const_reverse_iterator ret(*this); -+ --ptr_; -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator operator--(int) { -+ const_reverse_iterator ret(*this); -+ ++ptr_; -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ T const &operator*() const { -+ return *(ptr_ - 1); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator==(const_iterator const &other) const { -+ return ptr_ == other.ptr_; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator!=(const_iterator const &other) const { -+ return ptr_ != other.ptr_; -+ } -+ }; -+ -+private: -+ -+ /// Internal storage -+ Storage storage[kElements]; -+ -+public: -+ -+ #if 0 -+ CUTLASS_HOST_DEVICE -+ Array() { } -+ -+ CUTLASS_HOST_DEVICE -+ Array(Array const &x) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kElements; ++i) { -+ storage[i] = x.storage[i]; -+ } -+ } -+ #endif -+ -+ /// Efficient clear method -+ CUTLASS_HOST_DEVICE -+ void clear() { -+ fill(T(0)); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reference at(size_type pos) { -+ return reinterpret_cast(storage[pos]); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reference at(size_type pos) const { -+ return reinterpret_cast(storage[pos]); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reference operator[](size_type pos) { -+ return reinterpret_cast(storage[pos]); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reference operator[](size_type pos) const { -+ return reinterpret_cast(storage[pos]); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reference front() { -+ return reinterpret_cast(storage[0]); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reference front() const { -+ return reinterpret_cast(storage[0]); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reference back() { -+ return reinterpret_cast(storage[kStorageElements - 1]); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reference back() const { -+ return reinterpret_cast(storage[kStorageElements - 1]); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ pointer data() { -+ return reinterpret_cast(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_pointer data() const { -+ return reinterpret_cast(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ pointer raw_data() { -+ return reinterpret_cast(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_pointer raw_data() const { -+ return reinterpret_cast(storage); -+ } -+ -+ -+ CUTLASS_HOST_DEVICE -+ constexpr bool empty() const { -+ return !kElements; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ constexpr size_type size() const { -+ return kElements; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ constexpr size_type max_size() const { -+ return kElements; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void fill(T const &value) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kElements; ++i) { -+ storage[i] = static_cast(value); 
-+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator begin() { -+ return iterator(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator begin() const { -+ return cbegin(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator cbegin() const { -+ return const_iterator(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator end() { -+ return iterator(reinterpret_cast(storage + kStorageElements)); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator end() const { -+ return cend(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator cend() const { -+ return const_iterator(reinterpret_cast(storage + kStorageElements)); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator rbegin() { -+ return reverse_iterator(reinterpret_cast(storage + kStorageElements)); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator rbegin() const { -+ return crbegin(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator crbegin() const { -+ return const_reverse_iterator(reinterpret_cast(storage + kStorageElements)); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator rend() { -+ return reverse_iterator(reinterpret_cast(storage)); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator rend() const { -+ return crend(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator crend() const { -+ return const_reverse_iterator(reinterpret_cast(storage)); -+ } -+ -+ // -+ // Comparison operators -+ // -+ -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// Factories -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+CUTLASS_HOST_DEVICE -+Array make_Array(Element x) { -+ Array m; -+ m[0] = x; -+ return m; -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array make_Array(Element x, Element y) { -+ Array m; -+ m[0] = x; -+ m[1] = y; -+ return m; -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array make_Array(Element x, Element y, Element z) { -+ Array m; -+ m[0] = x; -+ m[1] = y; -+ m[2] = z; -+ return m; -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array make_Array(Element x, Element y, Element z, Element w) { -+ Array m; -+ m[0] = x; -+ m[1] = y; -+ m[2] = z; -+ m[3] = w; -+ return m; -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// functional.h numeric specializations -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct absolute_value_op< Array > { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs) const { -+ -+ Array result; -+ absolute_value_op scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+template -+struct plus> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, Array const &rhs) const { -+ -+ Array result; -+ plus scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], rhs[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, T const &scalar) const { -+ -+ Array result; -+ plus scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], scalar); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( T const &scalar, Array const &rhs) const { -+ -+ Array result; -+ plus scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(scalar, rhs[i]); -+ } -+ -+ 
return result; -+ } -+}; -+template -+struct minus> { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, Array const &rhs) const { -+ -+ Array result; -+ minus scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], rhs[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, T const &scalar) const { -+ -+ Array result; -+ minus scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], scalar); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( T const &scalar, Array const &rhs) const { -+ -+ Array result; -+ minus scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(scalar, rhs[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+template -+struct multiplies> { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, Array const &rhs) const { -+ -+ Array result; -+ multiplies scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], rhs[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, T const &scalar) const { -+ -+ Array result; -+ multiplies scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], scalar); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( T const &scalar, Array const &rhs) const { -+ -+ Array result; -+ multiplies scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(scalar, rhs[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+template -+struct divides> { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, Array const &rhs) const { -+ -+ Array result; -+ divides scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], rhs[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, T const &scalar) const { -+ -+ Array result; -+ divides scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], scalar); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( T const &scalar, Array const &rhs) const { -+ -+ Array result; -+ divides scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(scalar, rhs[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+template -+struct maximum> { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, Array const &rhs) const { -+ -+ Array result; -+ maximum scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], rhs[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, T const &scalar) const { -+ -+ Array result; -+ maximum scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], scalar); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( T const &scalar, Array const &rhs) const { -+ -+ Array result; -+ maximum scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(scalar, rhs[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+template -+struct minimum> { -+ -+ CUTLASS_HOST_DEVICE -+ static T scalar_op(T const &lhs, T const &rhs) { -+ return (rhs < lhs ? 
rhs : lhs); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, Array const &rhs) const { -+ -+ Array result; -+ minimum scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], rhs[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, T const &scalar) const { -+ -+ Array result; -+ minimum scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], scalar); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( T const &scalar, Array const &rhs) const { -+ -+ Array result; -+ minimum scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(scalar, rhs[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+template -+struct negate> { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs) const { -+ -+ Array result; -+ negate scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+/// Fused multiply-add -+template -+struct multiply_add, Array, Array> { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &a, Array const &b, Array const &c) const { -+ -+ Array result; -+ multiply_add scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(a[i], b[i], c[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &a, T const &scalar, Array const &c) const { -+ -+ Array result; -+ multiply_add scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(a[i], scalar, c[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(T const &scalar, Array const &b, Array const &c) const { -+ -+ Array result; -+ multiply_add scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(scalar, b[i], c[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+/// Fused multiply-add-relu0 -+template -+struct multiply_add_relu0, Array, Array> { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &a, Array const &b, Array const &c) const { -+ -+ Array result; -+ multiply_add scalar_op; -+ maximum mx; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = mx(scalar_op(a[i], b[i], c[i]), T(0)); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &a, T const &scalar, Array const &c) const { -+ -+ Array result; -+ multiply_add scalar_op; -+ maximum mx; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = mx(scalar_op(a[i], scalar, c[i]), T(0)); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(T const &scalar, Array const &b, Array const &c) const { -+ -+ Array result; -+ multiply_add scalar_op; -+ maximum mx; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = mx(scalar_op(scalar, b[i], c[i]), T(0)); -+ } -+ -+ return result; -+ } -+}; -+ -+ -+template -+struct conjugate > { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &a) const { -+ -+ conjugate conj_op; -+ -+ Array ca; -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ ca[i] = conj_op(a[i]); -+ } -+ return ca; -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// functional.h numeric specializations targeting SIMD instructions in device code. 
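Note: the specializations that follow process half_t elements two at a time through __half2 intrinsics and fall back to a scalar intrinsic for the odd tail element. A standalone sketch of that pattern, outside CUTLASS, is shown below; it assumes compute capability 5.3+, cuda_fp16.h, and pointers that are 4-byte aligned, and the function name add_halves is illustrative only.

    #include <cuda_fp16.h>

    // Pairwise fp16 addition with a scalar tail, mirroring the SIMD pattern below.
    template <int N>
    __device__ void add_halves(__half *out, __half const *a, __half const *b) {
      __half2 *o2 = reinterpret_cast<__half2 *>(out);
      __half2 const *a2 = reinterpret_cast<__half2 const *>(a);
      __half2 const *b2 = reinterpret_cast<__half2 const *>(b);
      #pragma unroll
      for (int i = 0; i < N / 2; ++i) {
        o2[i] = __hadd2(a2[i], b2[i]);            // two fp16 adds per instruction
      }
      if (N % 2) {
        out[N - 1] = __hadd(a[N - 1], b[N - 1]);  // odd leftover element, scalar add
      }
    }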
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct plus> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hadd2(lhs_ptr[i], rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ __half d_residual = __hadd(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs[i] + rhs[i]; -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(half_t const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hadd2(lhs_pair, rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ __half d_residual = __hadd(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs + rhs[i]; -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, half_t const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hadd2(lhs_ptr[i], rhs_pair); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ __half d_residual = __hadd(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs[i] + rhs; -+ } -+ #endif -+ -+ return result; -+ } -+}; -+ -+template -+struct minus> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hsub2(lhs_ptr[i], rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ __half d_residual = 
__hsub(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs[i] - rhs[i]; -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(half_t const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hsub2(lhs_pair, rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ __half d_residual = __hsub(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs - rhs[i]; -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, half_t const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hsub2(lhs_ptr[i], rhs_pair); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ __half d_residual = __hsub(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs[i] - rhs; -+ } -+ #endif -+ -+ return result; -+ } -+}; -+ -+template -+struct multiplies> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hmul2(lhs_ptr[i], rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ __half d_residual = __hmul(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs[i] * rhs[i]; -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(half_t const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hmul2(lhs_pair, rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *b_residual_ptr = 
reinterpret_cast<__half const *>(&rhs); -+ -+ __half d_residual = __hmul( -+ reinterpret_cast<__half const &>(lhs), -+ b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs * rhs[i]; -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, half_t const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hmul2(lhs_ptr[i], rhs_pair); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ -+ __half d_residual = __hmul( -+ a_residual_ptr[N - 1], -+ reinterpret_cast<__half const &>(rhs)); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs[i] * rhs; -+ } -+ #endif -+ -+ return result; -+ } -+}; -+ -+template -+struct divides> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __h2div(lhs_ptr[i], rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ -+ __half d_residual = __hdiv( -+ a_residual_ptr[N - 1], -+ b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs[i] / rhs[i]; -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(half_t const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __h2div(lhs_pair, rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ -+ __half d_residual = __hdiv( -+ reinterpret_cast<__half const &>(lhs), -+ b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs / rhs[i]; -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, half_t const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ 
result_ptr[i] = __h2div(lhs_ptr[i], rhs_pair); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ -+ __half d_residual = __hdiv( -+ a_residual_ptr[N - 1], -+ reinterpret_cast<__half const &>(rhs)); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs[i] / rhs; -+ } -+ #endif -+ -+ return result; -+ } -+}; -+ -+template -+struct negate> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *source_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hneg2(source_ptr[i]); -+ } -+ -+ if (N % 2) { -+ half_t x = lhs[N - 1]; -+ __half lhs_val = -reinterpret_cast<__half const &>(x); -+ result[N - 1] = reinterpret_cast(lhs_val); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = -lhs[i]; -+ } -+ #endif -+ -+ return result; -+ } -+}; -+ -+/// Fused multiply-add -+template -+struct multiply_add, Array, Array> { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ Array const &a, -+ Array const &b, -+ Array const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); -+ __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); -+ __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_ptr[i]); -+ } -+ -+ if (N % 2) { -+ -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); -+ __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); -+ -+ __half d_residual = __hfma( -+ a_residual_ptr[N - 1], -+ b_residual_ptr[N - 1], -+ c_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ multiply_add op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = op(a[i], b[i], c[i]); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ half_t const &a, -+ Array const &b, -+ Array const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a)); -+ __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); -+ __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hfma2(a_pair, b_ptr[i], c_ptr[i]); -+ } -+ -+ if (N % 2) { -+ -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); -+ __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); -+ __half d_residual = __hfma( -+ reinterpret_cast<__half const &>(a), -+ b_residual_ptr[N - 1], -+ c_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ multiply_add op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = op(a, b[i], c[i]); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array 
operator()( -+ Array const &a, -+ half_t const &b, -+ Array const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); -+ __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b)); -+ __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hfma2(a_ptr[i], b_pair, c_ptr[i]); -+ } -+ -+ if (N % 2) { -+ -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); -+ __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); -+ -+ __half d_residual = __hfma( -+ a_residual_ptr[N - 1], -+ reinterpret_cast<__half const &>(b), -+ c_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ multiply_add op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = op(a[i], b, c[i]); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ Array const &a, -+ Array const &b, -+ half_t const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); -+ __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); -+ __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_pair); -+ } -+ -+ if (N % 2) { -+ -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); -+ -+ __half d_residual = __hfma( -+ a_residual_ptr[N - 1], -+ b_residual_ptr[N - 1], -+ reinterpret_cast<__half const &>(c)); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ multiply_add op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = op(a[i], b[i], c); -+ } -+ #endif -+ -+ return result; -+ } -+}; -+ -+/// Fused multiply-add-relu0 -+template -+struct multiply_add_relu0, Array, Array> { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ Array const &a, -+ Array const &b, -+ Array const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); -+ __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); -+ __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_ptr[i]); -+ } -+ -+ if (N % 2) { -+ -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); -+ __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); -+ -+ __half d_residual = __hfma_relu( -+ a_residual_ptr[N - 1], -+ b_residual_ptr[N - 1], -+ c_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ multiply_add op; -+ maximum mx; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = mx(op(a[i], b[i], c[i]), (half_t)0); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ half_t const &a, -+ Array const &b, -+ 
Array const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a)); -+ __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); -+ __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hfma2_relu(a_pair, b_ptr[i], c_ptr[i]); -+ } -+ -+ if (N % 2) { -+ -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); -+ __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); -+ __half d_residual = __hfma_relu( -+ reinterpret_cast<__half const &>(a), -+ b_residual_ptr[N - 1], -+ c_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ multiply_add op; -+ maximum mx; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = mx(op(a, b[i], c[i]), half_t(0)); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ Array const &a, -+ half_t const &b, -+ Array const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); -+ __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b)); -+ __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hfma2_relu(a_ptr[i], b_pair, c_ptr[i]); -+ } -+ -+ if (N % 2) { -+ -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); -+ __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); -+ -+ __half d_residual = __hfma_relu( -+ a_residual_ptr[N - 1], -+ reinterpret_cast<__half const &>(b), -+ c_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ multiply_add op; -+ maximum mx; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = mx(op(a[i], b, c[i]), half_t(0)); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ Array const &a, -+ Array const &b, -+ half_t const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); -+ __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); -+ __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_pair); -+ } -+ -+ if (N % 2) { -+ -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); -+ -+ __half d_residual = __hfma_relu( -+ a_residual_ptr[N - 1], -+ b_residual_ptr[N - 1], -+ reinterpret_cast<__half const &>(c)); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ multiply_add op; -+ maximum mx; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = mx(op(a[i], b[i], c), half_t(0)); -+ } -+ #endif -+ -+ return result; -+ } -+}; -+ -+template -+struct minimum> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ __half2 
*result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hmin2(lhs_ptr[i], rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ -+ __half d_residual = __hmin( -+ a_residual_ptr[N - 1], -+ b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = (rhs[i] < lhs[i] ? rhs[i] : lhs[i]); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(half_t const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hmin2(lhs_pair, rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ -+ __half d_residual = __hmin( -+ reinterpret_cast<__half const &>(lhs), -+ b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = (rhs[i] < lhs ? rhs[i] : lhs); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, half_t const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hmin2(lhs_ptr[i], rhs_pair); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ -+ __half d_residual = __hmin( -+ a_residual_ptr[N - 1], -+ reinterpret_cast<__half const &>(rhs)); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = (rhs < lhs[i] ? rhs : lhs[i]); -+ } -+ #endif -+ -+ return result; -+ } -+}; -+ -+template -+struct maximum> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ -+ __half d_residual = __hmax( -+ a_residual_ptr[N - 1], -+ b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = (lhs[i] < rhs[i] ? 
rhs[i] : lhs[i]); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(half_t const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hmax2(lhs_pair, rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ -+ __half d_residual = __hmax( -+ reinterpret_cast<__half const &>(lhs), -+ b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = (lhs < rhs[i] ? rhs[i] : lhs); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, half_t const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ -+ __half d_residual = __hmax( -+ a_residual_ptr[N - 1], -+ reinterpret_cast<__half const &>(rhs)); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = (lhs[i] < rhs ? 
rhs : lhs[i]); -+ } -+ #endif -+ -+ return result; -+ } -+}; -+ -+/// Fused multiply-add -+template -+struct multiply_add, Array, Array> { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ Array const &a, -+ Array const &b, -+ Array const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ unsigned *result_ptr = reinterpret_cast(&result); -+ unsigned const *a_ptr = reinterpret_cast(&a); -+ unsigned const *b_ptr = reinterpret_cast(&b); -+ unsigned const *c_ptr = reinterpret_cast(&c); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" -+ : "=r"(result_ptr[i]) -+ : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_ptr[i]) -+ ); -+ } -+ -+ if (N % 2) { -+ -+ uint16_t *result_ptr = reinterpret_cast(&result); -+ uint16_t const *a_residual_ptr = reinterpret_cast(&a); -+ uint16_t const *b_residual_ptr = reinterpret_cast(&b); -+ uint16_t const *c_residual_ptr = reinterpret_cast(&c); -+ -+ asm ("fma.rn.bf16 %0, %1, %2, %3;\n" -+ : "=h"(result_ptr[N - 1]) -+ : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1]) -+ ); -+ } -+ -+ #else -+ -+ multiply_add op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = op(a[i], b[i], c[i]); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ bfloat16_t const &a, -+ Array const &b, -+ Array const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ unsigned *result_ptr = reinterpret_cast(&result); -+ -+ unsigned const *b_ptr = reinterpret_cast(&b); -+ unsigned const *c_ptr = reinterpret_cast(&c); -+ -+ unsigned a_packed = static_cast(a.raw()); -+ a_packed = (a_packed | (a_packed << 16)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" -+ : "=r"(result_ptr[i]) -+ : "r"(a_packed), "r"(b_ptr[i]), "r"(c_ptr[i]) -+ ); -+ } -+ -+ if (N % 2) { -+ -+ uint16_t *result_ptr = reinterpret_cast(&result); -+ uint16_t const *a_residual_ptr = reinterpret_cast(&a); -+ uint16_t const *b_residual_ptr = reinterpret_cast(&b); -+ uint16_t const *c_residual_ptr = reinterpret_cast(&c); -+ -+ asm ("fma.rn.bf16 %0, %1, %2, %3;\n" -+ : "=h"(result_ptr[N - 1]) -+ : "h"(a_residual_ptr[0]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1]) -+ ); -+ } -+ -+ #else -+ -+ multiply_add op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = op(a, b[i], c[i]); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ Array const &a, -+ bfloat16_t const &b, -+ Array const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ unsigned *result_ptr = reinterpret_cast(&result); -+ -+ unsigned const *a_ptr = reinterpret_cast(&a); -+ unsigned const *c_ptr = reinterpret_cast(&c); -+ -+ unsigned b_packed = static_cast(b.raw()); -+ b_packed = (b_packed | (b_packed << 16)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" -+ : "=r"(result_ptr[i]) -+ : "r"(a_ptr[i]), "r"(b_packed), "r"(c_ptr[i]) -+ ); -+ } -+ -+ if (N % 2) { -+ -+ uint16_t *result_ptr = reinterpret_cast(&result); -+ uint16_t const *a_residual_ptr = reinterpret_cast(&a); -+ uint16_t const *b_residual_ptr = reinterpret_cast(&b); -+ uint16_t const *c_residual_ptr = reinterpret_cast(&c); -+ -+ asm ("fma.rn.bf16 %0, %1, %2, %3;\n" -+ : "=h"(result_ptr[N - 1]) -+ : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[0]), 
"h"(c_residual_ptr[N - 1]) -+ ); -+ } -+ -+ #else -+ -+ multiply_add op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = op(a[i], b, c[i]); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ Array const &a, -+ Array const &b, -+ bfloat16_t const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ unsigned *result_ptr = reinterpret_cast(&result); -+ -+ unsigned const *a_ptr = reinterpret_cast(&a); -+ unsigned const *b_ptr = reinterpret_cast(&b); -+ -+ unsigned c_packed = static_cast(c.raw()); -+ c_packed = (c_packed | (c_packed << 16)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" -+ : "=r"(result_ptr[i]) -+ : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_packed) -+ ); -+ } -+ -+ if (N % 2) { -+ -+ uint16_t *result_ptr = reinterpret_cast(&result); -+ uint16_t const *a_residual_ptr = reinterpret_cast(&a); -+ uint16_t const *b_residual_ptr = reinterpret_cast(&b); -+ uint16_t const *c_residual_ptr = reinterpret_cast(&c); -+ -+ asm ("fma.rn.bf16 %0, %1, %2, %3;\n" -+ : "=h"(result_ptr[N - 1]) -+ : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[0]) -+ ); -+ } -+ -+ #else -+ -+ multiply_add op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = op(a[i], b[i], c); -+ } -+ #endif -+ -+ return result; -+ } -+}; -+ -+ -+/// bit_and -+template -+struct bit_and> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &a, Array const &b) const { -+ using ArrayType = Array; -+ using Storage = typename ArrayType::Storage; -+ ArrayType result; -+ -+ Storage *result_data = result.raw_data(); -+ Storage const *a_data = a.raw_data(); -+ Storage const *b_data = b.raw_data(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < ArrayType::kStorageElements; ++i) { -+ result_data[i] = (a_data[i] & b_data[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+ -+/// bit_or -+template -+struct bit_or> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &a, Array const &b) const { -+ using ArrayType = Array; -+ using Storage = typename ArrayType::Storage; -+ ArrayType result; -+ -+ Storage *result_data = result.raw_data(); -+ Storage const *a_data = a.raw_data(); -+ Storage const *b_data = b.raw_data(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < ArrayType::kStorageElements; ++i) { -+ result_data[i] = (a_data[i] | b_data[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+ -+/// bit_not -+template -+struct bit_not> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &a) const { -+ using ArrayType = Array; -+ using Storage = typename ArrayType::Storage; -+ ArrayType result; -+ -+ Storage *result_data = result.raw_data(); -+ Storage const *a_data = a.raw_data(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < ArrayType::kStorageElements; ++i) { -+ result_data[i] = (~a_data[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+ -+/// bit_xor -+template -+struct bit_xor> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &a, Array const &b) const { -+ using ArrayType = Array; -+ using Storage = typename ArrayType::Storage; -+ ArrayType result; -+ -+ Storage *result_data = result.raw_data(); -+ Storage const *a_data = a.raw_data(); -+ Storage const *b_data = b.raw_data(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < ArrayType::kStorageElements; ++i) { -+ result_data[i] = (a_data[i] ^ b_data[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+ 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Operator overloads -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+CUTLASS_HOST_DEVICE -+Array operator+(Array const &lhs, Array const &rhs) { -+ plus> op; -+ return op(lhs, rhs); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array operator-(Array const &lhs, Array const &rhs) { -+ minus> op; -+ return op(lhs, rhs); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array operator-(Array const &lhs) { -+ negate> op; -+ return op(lhs); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array operator*(Array const &lhs, Array const &rhs) { -+ multiplies> op; -+ return op(lhs, rhs); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array operator*(T lhs, Array const &rhs) { -+ multiplies> op; -+ return op(lhs, rhs); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array operator*(Array const &lhs, T rhs) { -+ multiplies> op; -+ return op(lhs, rhs); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array operator/(Array const &lhs, Array const &rhs) { -+ divides> op; -+ return op(lhs, rhs); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array fma(Array const &a, Array const &b, Array const &c) { -+ multiply_add> op; -+ return op(a, b, c); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array fma(T a, Array const &b, Array const &c) { -+ multiply_add> op; -+ return op(a, b, c); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array fma(Array const &a, T b, Array const &c) { -+ multiply_add> op; -+ return op(a, b, c); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array fma(Array const &a, Array const &b, T c) { -+ multiply_add> op; -+ return op(a, b, c); -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+ -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "cutlass/array_subbyte.h" -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// AlignedArray -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Aligned array type -+template < -+ /// Element type -+ typename T, -+ /// Number of elements in the array -+ int N, -+ /// Alignment requirement in bytes -+ int Alignment = sizeof_bits::value * N / 8 -+> -+class alignas(Alignment) AlignedArray: public Array { -+public: -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/array_planar_complex.h b/3rdparty/cutlass/include/cutlass/array_planar_complex.h -new file mode 100644 -index 0000000..4503b77 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/array_planar_complex.h -@@ -0,0 +1,103 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Array holding planar complex elements -+template -+struct ArrayPlanarComplex { -+ -+ /// Underlying real element -+ using Element = Element_; -+ -+ /// Number of logical elements -+ static size_t const kElements = N; -+ -+ /// Underlying Fragment of real-valued elemenets -+ using ArrayReal = Array; -+ -+public: -+ -+ /// Fragment of real-valued elements representing the real part -+ ArrayReal real; -+ -+ /// Fragment of real-valued elements representing the imaginary part -+ ArrayReal imag; -+ -+public: -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ArrayPlanarComplex() { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ArrayPlanarComplex( -+ ArrayReal const &real_, -+ ArrayReal const &imag_ -+ ): -+ real(real_), imag(imag_) { } -+ -+ /// Sets the array to zero efficiently -+ CUTLASS_HOST_DEVICE -+ void clear() { -+ real.clear(); -+ imag.clear(); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper to deduce template arguments -+template -+CUTLASS_HOST_DEVICE -+ArrayPlanarComplex -+make_ArrayPlanarComplex(Array const &real, Array const &imag) { -+ return ArrayPlanarComplex(real, imag); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/array_subbyte.h b/3rdparty/cutlass/include/cutlass/array_subbyte.h -new file mode 100644 -index 0000000..ac30422 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/array_subbyte.h -@@ -0,0 
+1,564 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types -+ and is safe to use in a union. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/platform/platform.h" -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Statically sized array for any data type -+template < -+ typename T, -+ int N -+> -+class Array { -+public: -+ -+ static int const kSizeBits = sizeof_bits::value * N; -+ -+ /// Storage type -+ using Storage = typename platform::conditional< -+ ((kSizeBits % 32) != 0), -+ typename platform::conditional< -+ ((kSizeBits % 16) != 0), -+ uint8_t, -+ uint16_t -+ >::type, -+ uint32_t -+ >::type; -+ -+ /// Element type -+ using Element = T; -+ -+ /// Number of logical elements per stored object -+ static int const kElementsPerStoredItem = int(sizeof(Storage) * 8) / sizeof_bits::value; -+ -+ /// Number of storage elements -+ static size_t const kStorageElements = N / kElementsPerStoredItem; -+ -+ /// Number of logical elements -+ static size_t const kElements = N; -+ -+ /// Bitmask for covering one item -+ static Storage const kMask = ((Storage(1) << sizeof_bits::value) - 1); -+ -+ // -+ // C++ standard members with pointer types removed -+ // -+ -+ typedef T value_type; -+ typedef size_t size_type; -+ typedef ptrdiff_t difference_type; -+ typedef value_type *pointer; -+ typedef value_type const *const_pointer; -+ -+ // -+ // References -+ // -+ -+ /// Reference object inserts or extracts sub-byte items -+ class reference { -+ /// Pointer to storage element -+ Storage *ptr_; -+ -+ /// Index into elements packed into Storage object -+ int idx_; -+ -+ public: -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ reference(): ptr_(nullptr), idx_(0) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ reference(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } -+ -+ /// Assignment -+ CUTLASS_HOST_DEVICE -+ reference &operator=(T x) { -+ Storage item = (reinterpret_cast(x) & kMask); -+ -+ Storage kUpdateMask = Storage(~(kMask << (idx_ * sizeof_bits::value))); -+ *ptr_ = Storage(((*ptr_ & kUpdateMask) | (item << idx_ * sizeof_bits::value))); -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ T get() const { -+ Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits::value)) & kMask); -+ return reinterpret_cast(item); -+ } -+ -+ /// Extract -+ CUTLASS_HOST_DEVICE -+ operator T() const { -+ return get(); -+ } -+ -+ /// Explicit cast to int -+ CUTLASS_HOST_DEVICE -+ explicit operator int() const { -+ return int(get()); -+ } -+ -+ /// Explicit cast to float -+ CUTLASS_HOST_DEVICE -+ explicit operator float() const { -+ return float(get()); -+ } -+ }; -+ -+ /// Reference object extracts sub-byte items -+ class const_reference { -+ -+ /// Pointer to storage element -+ Storage const *ptr_; -+ -+ /// Index into elements packed into Storage object -+ int idx_; -+ -+ public: -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ const_reference(): ptr_(nullptr), idx_(0) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ const_reference(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } -+ -+ CUTLASS_HOST_DEVICE -+ const T get() const { -+ Storage item = (*ptr_ >> (idx_ * sizeof_bits::value)) & kMask; -+ return reinterpret_cast(item); -+ } -+ -+ /// Extract -+ CUTLASS_HOST_DEVICE -+ operator T() const { -+ Storage item = Storage(Storage(*ptr_ >> Storage(idx_ * sizeof_bits::value)) & kMask); -+ return reinterpret_cast(item); -+ } -+ -+ /// Explicit cast to int -+ CUTLASS_HOST_DEVICE -+ explicit operator int() const { -+ return int(get()); -+ } -+ -+ /// Explicit cast 
to float -+ CUTLASS_HOST_DEVICE -+ explicit operator float() const { -+ return float(get()); -+ } -+ }; -+ -+ // -+ // Iterators -+ // -+ -+ /// Bidirectional iterator over elements -+ class iterator { -+ -+ /// Pointer to storage element -+ Storage *ptr_; -+ -+ /// Index into elements packed into Storage object -+ int idx_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ iterator(): ptr_(nullptr), idx_(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } -+ -+ CUTLASS_HOST_DEVICE -+ iterator &operator++() { -+ ++idx_; -+ if (idx_ == kElementsPerStoredItem) { -+ ++ptr_; -+ idx_ = 0; -+ } -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator &operator--() { -+ if (!idx_) { -+ --ptr_; -+ idx_ = kElementsPerStoredItem - 1; -+ } -+ else { -+ --idx_; -+ } -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator operator++(int) { -+ iterator ret(*this); -+ ++idx_; -+ if (idx_ == kElementsPerStoredItem) { -+ ++ptr_; -+ idx_ = 0; -+ } -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator operator--(int) { -+ iterator ret(*this); -+ if (!idx_) { -+ --ptr_; -+ idx_ = kElementsPerStoredItem - 1; -+ } -+ else { -+ --idx_; -+ } -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reference operator*() const { -+ return reference(ptr_, idx_); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator==(iterator const &other) const { -+ return ptr_ == other.ptr_ && idx_ == other.idx_; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator!=(iterator const &other) const { -+ return !(*this == other); -+ } -+ }; -+ -+ /// Bidirectional constant iterator over elements -+ class const_iterator { -+ -+ /// Pointer to storage element -+ Storage const *ptr_; -+ -+ /// Index into elements packed into Storage object -+ int idx_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator(): ptr_(nullptr), idx_(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } -+ -+ CUTLASS_HOST_DEVICE -+ iterator &operator++() { -+ ++idx_; -+ if (idx_ == kElementsPerStoredItem) { -+ ++ptr_; -+ idx_ = 0; -+ } -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator &operator--() { -+ if (!idx_) { -+ --ptr_; -+ idx_ = kElementsPerStoredItem - 1; -+ } -+ else { -+ --idx_; -+ } -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator operator++(int) { -+ iterator ret(*this); -+ ++idx_; -+ if (idx_ == kElementsPerStoredItem) { -+ ++ptr_; -+ idx_ = 0; -+ } -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator operator--(int) { -+ iterator ret(*this); -+ if (!idx_) { -+ --ptr_; -+ idx_ = kElementsPerStoredItem - 1; -+ } -+ else { -+ --idx_; -+ } -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reference operator*() const { -+ return const_reference(ptr_, idx_); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator==(iterator const &other) const { -+ return ptr_ == other.ptr_ && idx_ == other.idx_; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator!=(iterator const &other) const { -+ return !(*this == other); -+ } -+ }; -+ -+ /// Bidirectional iterator over elements -+ class reverse_iterator { -+ -+ /// Pointer to storage element -+ Storage *ptr_; -+ -+ /// Index into elements packed into Storage object -+ int idx_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator(): ptr_(nullptr), idx_(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } -+ }; -+ -+ /// Bidirectional constant iterator over elements -+ class const_reverse_iterator { -+ -+ /// Pointer to storage element -+ Storage 
const *ptr_; -+ -+ /// Index into elements packed into Storage object -+ int idx_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator(): ptr_(nullptr), idx_(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } -+ }; -+ -+private: -+ -+ /// Internal storage -+ Storage storage[kStorageElements]; -+ -+public: -+ -+ #if 0 -+ CUTLASS_HOST_DEVICE -+ Array() { } -+ -+ CUTLASS_HOST_DEVICE -+ Array(Array const &x) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < int(kStorageElements); ++i) { -+ storage[i] = x.storage[i]; -+ } -+ } -+ #endif -+ -+ /// Efficient clear method -+ CUTLASS_HOST_DEVICE -+ void clear() { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < int(kStorageElements); ++i) { -+ storage[i] = Storage(0); -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reference at(size_type pos) { -+ return reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reference at(size_type pos) const { -+ return const_reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reference operator[](size_type pos) { -+ return at(pos); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reference operator[](size_type pos) const { -+ return at(pos); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reference front() { -+ return at(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reference front() const { -+ return at(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reference back() { -+ return reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reference back() const { -+ return const_reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ pointer data() { -+ return reinterpret_cast(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_pointer data() const { -+ return reinterpret_cast(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Storage * raw_data() { -+ return storage; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Storage const * raw_data() const { -+ return storage; -+ } -+ -+ -+ CUTLASS_HOST_DEVICE -+ constexpr bool empty() const { -+ return !kElements; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ constexpr size_type size() const { -+ return kElements; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ constexpr size_type max_size() const { -+ return kElements; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void fill(T const &value) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kElementsPerStoredItem; ++i) { -+ reference ref(storage, i); -+ ref = value; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < kStorageElements; ++i) { -+ storage[i] = storage[0]; -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator begin() { -+ return iterator(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator cbegin() const { -+ return const_iterator(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator end() { -+ return iterator(storage + kStorageElements); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator cend() const { -+ return const_iterator(storage + kStorageElements); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator rbegin() { -+ return reverse_iterator(storage + kStorageElements); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator crbegin() const { -+ return const_reverse_iterator(storage + kStorageElements); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator rend() { -+ return reverse_iterator(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator crend() const { -+ return 
const_reverse_iterator(storage); -+ } -+ -+ // -+ // Comparison operators -+ // -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/barrier.h b/3rdparty/cutlass/include/cutlass/barrier.h -new file mode 100644 -index 0000000..85a178b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/barrier.h -@@ -0,0 +1,197 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Implementation of a CTA-wide barrier for inter-CTA synchronization. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// CTA-wide semaphore for inter-CTA synchronization. 
-+struct Barrier -+{ -+ -+public: -+ -+ /// Flag type -+ using T = int; -+ -+ /// Initial flag value -+ static const T INIT = 0; -+ -+ -+protected: -+ -+ /// Load flag, as a strong acquire operation (int specialization) -+ CUTLASS_DEVICE -+ static int ld_acquire(int *ptr) -+ { -+ int state = 0; -+ -+#if (__CUDA_ARCH__ >= 700) -+ /// SM70 and newer use memory consistency qualifiers -+ -+ // Acquire pattern using acquire modifier -+ asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); -+ -+#else -+ asm volatile ("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); -+#endif // (__CUDA_ARCH__ >= 700) -+ -+ return state; -+ } -+ -+ -+ /// Reduce into flag, with release pattern (int specialization) -+ CUTLASS_DEVICE -+ static void red_release(int *ptr, int val) -+ { -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) -+#if (__CUDA_ARCH__ >= 700) -+ /// SM70 and newer use memory consistency qualifiers -+ -+ // Release pattern using acq_rel fence + relaxed modifier. (The fence also releases data -+ // that was weakly-written by other threads prior to the last syncthreads) -+ asm volatile ("fence.acq_rel.gpu;\n"); -+ asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val)); -+ -+#else -+ __threadfence(); -+ atomicAdd(ptr, val); -+#endif // (__CUDA_ARCH__ >= 700) -+#endif -+ } -+ -+ -+public: -+ -+ /// Uses thread[0] to wait for at least the specified count of signals on the given flag counter -+ CUTLASS_DEVICE -+ static void wait_lt(void *lock_ptr, int thread_idx, int flag_idx, int count) -+ { -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) -+ T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; -+ -+ if (thread_idx == 0) -+ { -+ // Spin-loop -+ #pragma unroll 1 -+ while(ld_acquire(flag_ptr) < count) {} -+ } -+ -+ __syncthreads(); -+#endif -+ } -+ -+ /// Uses thread[0] to wait for at least the specified count of signals on the given flag counter -+ CUTLASS_DEVICE -+ static void wait_eq(void *lock_ptr, int thread_idx, int flag_idx, T val = 1) -+ { -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) -+ T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; -+ -+ if (thread_idx == 0) -+ { -+ // Spin-loop -+ #pragma unroll 1 -+ while(ld_acquire(flag_ptr) != val) {} -+ } -+ __syncthreads(); -+#endif -+ } -+ -+ /// Uses thread[0] to wait for the specified count of signals on the given flag counter -+ CUTLASS_DEVICE -+ static void wait_eq_reset(void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) -+ T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; -+ -+ if (thread_idx == 0) -+ { -+ // Spin-loop -+ #pragma unroll 1 -+ while(atomicCAS(flag_ptr, val, 0) != val) {} -+ } -+ -+ __syncthreads(); -+#endif -+ } -+ -+ /// Increment the arrival count for a flag -+ CUTLASS_DEVICE -+ static void arrive_inc(void *lock_ptr, int thread_idx, int flag_idx) -+ { -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) -+ T* flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; -+ -+ __syncthreads(); -+ -+ if (thread_idx == 0) -+ { -+ red_release(flag_ptr, 1); -+ } -+#endif -+ } -+ -+ -+ /// Increment the arrival counts for a range of flags -+ CUTLASS_DEVICE -+ static void arrive_range_inc(void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1) -+ { -+#if defined(__NVCC__) || 
(defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) -+ int flag_idx = first_flag_idx + thread_idx; -+ T* flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; -+ -+ // Barrier to make sure all other threads in block have written their data -+ __syncthreads(); -+ -+ // Select threads increment their flags -+ if (thread_idx < count) { -+ red_release(flag_ptr, 1); -+ } -+#endif -+ } -+}; -+ -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/bfloat16.h b/3rdparty/cutlass/include/cutlass/bfloat16.h -new file mode 100644 -index 0000000..b660cd4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/bfloat16.h -@@ -0,0 +1,500 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Defines a proxy class for storing non-standard 16-bit floating point values with -+ 8 bits of exponent and 7 bit of mantissa. -+*/ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include "cutlass/floating_point_nvrtc.h" -+#else -+#include -+#include -+#include -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Floating-point type with 8 bits of exponent and 7 bits of mantissa. 
-+struct alignas(2) bfloat16_t { -+ -+ // -+ // Data members -+ // -+ -+ /// Storage type -+ uint16_t storage; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs from an unsigned short -+ CUTLASS_HOST_DEVICE -+ static bfloat16_t bitcast(uint16_t x) { -+ bfloat16_t h; -+ h.storage = x; -+ return h; -+ } -+ -+ /// Default constructor -+ bfloat16_t() = default; -+ -+ /// Floating-point conversion - round toward nearest -+ CUTLASS_HOST_DEVICE -+ explicit bfloat16_t(float x) { -+ -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) -+ -+ asm("cvt.rn.bf16.f32 %0, %1;\n" : "=h"(storage) : "f"(x)); -+ -+ #else -+ uint32_t bits; -+ -+ #if defined(__CUDA_ARCH__) -+ bits = reinterpret_cast(x); -+ #else -+ std::memcpy(&bits, &x, sizeof(bits)); -+ #endif -+ -+ if ((bits & 0x7f800000) != 0x7f800000) { -+ -+ bool mantissa_bit = ((bits & (1 << 16)) != 0); -+ bool round_bit = ((bits & (1 << 15)) != 0); -+ bool sticky_bit = ((bits & ((1 << 15) - 1)) != 0); -+ -+ if ((round_bit && sticky_bit) || (round_bit && mantissa_bit)) { -+ bits += uint32_t(1 << 16); -+ } -+ } -+ else if (bits & ~0xff800000) { -+ bits = 0x7fffffff; -+ } -+ -+ storage = uint16_t((bits >> 16) & 0xffff); -+ #endif -+ } -+ -+ /// Floating-point conversion - round toward nearest -+ CUTLASS_HOST_DEVICE -+ explicit bfloat16_t(double x): bfloat16_t(float(x)) { -+ -+ } -+ -+ /// Integer conversion - round toward nearest -+ CUTLASS_HOST_DEVICE -+ explicit bfloat16_t(int x) { -+ float flt = static_cast(x); -+ uint32_t bits; -+ -+ #if defined(__CUDA_ARCH__) -+ bits = reinterpret_cast(flt); -+ #else -+ std::memcpy(&bits, &flt, sizeof(bits)); -+ #endif -+ -+ storage = uint16_t(bits >> 16); -+ } -+ -+ /// Converts to float -+ CUTLASS_HOST_DEVICE -+ operator float() const { -+ unsigned bits = (unsigned(storage) << 16); -+ #if defined(__CUDA_ARCH__) -+ return reinterpret_cast(bits); -+ #else -+ float flt; -+ std::memcpy(&flt, &bits, sizeof(flt)); -+ return flt; -+ #endif -+ } -+ -+ /// Converts to float -+ CUTLASS_HOST_DEVICE -+ explicit operator double() const { -+ return double(float(*this)); -+ } -+ -+ /// Converts to int -+ CUTLASS_HOST_DEVICE -+ explicit operator int() const { -+ return int(float(*this)); -+ } -+ -+ /// Casts to bool -+ CUTLASS_HOST_DEVICE -+ explicit operator bool() const { -+ return (float(*this) != 0.0f); -+ } -+ -+ /// Obtains raw bits -+ CUTLASS_HOST_DEVICE -+ uint16_t raw() const { -+ return storage; -+ } -+ /// Returns the sign bit -+ CUTLASS_HOST_DEVICE -+ bool signbit() const { -+ return ((raw() & 0x8000) != 0); -+ } -+ -+ /// Returns the biased exponent -+ CUTLASS_HOST_DEVICE -+ int exponent_biased() const { -+ return int((raw() >> 7) & 0x0ff); -+ } -+ -+ /// Returns the unbiased exponent -+ CUTLASS_HOST_DEVICE -+ int exponent() const { -+ return exponent_biased() - 127; -+ } -+ -+ /// Returns the mantissa -+ CUTLASS_HOST_DEVICE -+ int mantissa() const { -+ return int(raw() & 0x7f); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_HOST_DEVICE -+bool signbit(cutlass::bfloat16_t const& h) { -+ return h.signbit(); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::bfloat16_t abs(cutlass::bfloat16_t const& h) { -+ return cutlass::bfloat16_t::bitcast(h.raw() & 0x7fffffff); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isnan(cutlass::bfloat16_t const& h) { -+ return (h.exponent_biased() == 0x0ff) && h.mantissa(); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isfinite(cutlass::bfloat16_t const& h) { -+ return (h.exponent_biased() != 0x0ff); -+} -+ 
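A short host-only sketch (illustration only, not part of the vendored sources) exercising the bfloat16_t members and free functions defined above: raw-bit inspection, the round-to-nearest behavior of the float constructor, and isnan() on the canonical NaN encoding. It assumes the vendored headers and an ordinary host compiler.

#include <cstdio>
#include "cutlass/bfloat16.h"

int main() {
  using cutlass::bfloat16_t;

  // 7.0f encodes as 0x40E0: sign 0, biased exponent 129, mantissa 0b1100000.
  bfloat16_t x(7.0f);
  std::printf("raw=0x%04x exponent=%d mantissa=0x%02x\n",
              unsigned(x.raw()), x.exponent(), unsigned(x.mantissa()));

  // The float constructor rounds to nearest-even: 1.00390625f lies halfway
  // between 1.0 and the next representable value 1.0078125, and the kept
  // mantissa LSB is 0, so it rounds down to 1.0.
  bfloat16_t y(1.00390625f);
  std::printf("y=%f\n", float(y));   // prints y=1.000000

  // 0x7FFF is the canonical quiet-NaN bit pattern used by this header.
  std::printf("isnan=%d\n", int(cutlass::isnan(bfloat16_t::bitcast(0x7fff))));
  return 0;
}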
-+CUTLASS_HOST_DEVICE -+cutlass::bfloat16_t nan_bf16(const char*) { -+ // NVIDIA canonical NaN -+ return cutlass::bfloat16_t::bitcast(0x7fff); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isinf(cutlass::bfloat16_t const& h) { -+ return (h.exponent_biased() == 0x0ff) && !h.mantissa(); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isnormal(cutlass::bfloat16_t const& h) { -+ return h.exponent_biased() && h.exponent_biased() != 0x0ff; -+} -+ -+CUTLASS_HOST_DEVICE -+int fpclassify(cutlass::bfloat16_t const& h) { -+ int exp = h.exponent_biased(); -+ int mantissa = h.mantissa(); -+ if (exp == 0x0ff) { -+ if (mantissa) { -+ return FP_NAN; -+ } -+ else { -+ return FP_INFINITE; -+ } -+ } -+ else if (!exp) { -+ if (mantissa) { -+ return FP_SUBNORMAL; -+ } -+ else { -+ return FP_ZERO; -+ } -+ } -+ return FP_NORMAL; -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::bfloat16_t sqrt(cutlass::bfloat16_t const& h) { -+#if defined(__CUDACC_RTC__) -+ return cutlass::bfloat16_t(sqrtf(float(h))); -+#else -+ return cutlass::bfloat16_t(std::sqrt(float(h))); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t copysign(bfloat16_t const& a, bfloat16_t const& b) { -+ -+ uint16_t a_bits; -+ uint16_t b_bits; -+ -+ #if defined(__CUDA_ARCH__) -+ a_bits = reinterpret_cast(a); -+ b_bits = reinterpret_cast(b); -+ #else -+ std::memcpy(&a_bits, &a, sizeof(a_bits)); -+ std::memcpy(&b_bits, &b, sizeof(b_bits)); -+ #endif -+ -+ uint16_t a_mag = (a_bits & 0x7fff); -+ uint16_t b_sign = (b_bits & 0x8000); -+ uint16_t result = (a_mag | b_sign); -+ -+ return bfloat16_t::bitcast(result); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Standard Library operations and definitions -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace std { -+ -+#if !defined(__CUDACC_RTC__) -+/// Numeric limits -+template <> -+struct numeric_limits { -+ static bool const is_specialized = true; -+ static bool const is_signed = true; -+ static bool const is_integer = false; -+ static bool const is_exact = false; -+ static bool const has_infinity = true; -+ static bool const has_quiet_NaN = true; -+ static bool const has_signaling_NaN = false; -+ static std::float_denorm_style const has_denorm = std::denorm_present; -+ static bool const has_denorm_loss = true; -+ static std::float_round_style const round_style = std::round_to_nearest; -+ static bool const is_iec559 = false; -+ static bool const is_bounded = true; -+ static bool const is_modulo = false; -+ static int const digits = 7; -+ -+ /// Least positive value -+ CUTLASS_HOST_DEVICE -+ static cutlass::bfloat16_t min() { return cutlass::bfloat16_t::bitcast(0x01); } -+ -+ /// Minimum finite value -+ CUTLASS_HOST_DEVICE -+ static cutlass::bfloat16_t lowest() { return cutlass::bfloat16_t::bitcast(0xff7f); } -+ -+ /// Maximum finite value -+ CUTLASS_HOST_DEVICE -+ static cutlass::bfloat16_t max() { return cutlass::bfloat16_t::bitcast(0x7f7f); } -+ -+ /// Returns smallest finite value -+ CUTLASS_HOST_DEVICE -+ static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x1000); } -+ -+ /// Returns smallest finite value -+ CUTLASS_HOST_DEVICE -+ static cutlass::bfloat16_t round_error() { return cutlass::bfloat16_t(0.5f); } -+ -+ /// Returns smallest finite value -+ CUTLASS_HOST_DEVICE -+ static cutlass::bfloat16_t infinity() { return 
cutlass::bfloat16_t::bitcast(0x7f80); } -+ -+ /// Returns smallest finite value -+ CUTLASS_HOST_DEVICE -+ static cutlass::bfloat16_t quiet_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); } -+ -+ /// Returns smallest finite value -+ CUTLASS_HOST_DEVICE -+ static cutlass::bfloat16_t signaling_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); } -+ -+ /// Returns smallest finite value -+ CUTLASS_HOST_DEVICE -+ static cutlass::bfloat16_t denorm_min() { return cutlass::bfloat16_t::bitcast(0x1); } -+}; -+#endif -+ -+} // namespace std -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Arithmetic operators -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_HOST_DEVICE -+bool operator==(bfloat16_t const& lhs, bfloat16_t const& rhs) { -+ return float(lhs) == float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator!=(bfloat16_t const& lhs, bfloat16_t const& rhs) { -+ return float(lhs) != float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator<(bfloat16_t const& lhs, bfloat16_t const& rhs) { -+ return float(lhs) < float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator<=(bfloat16_t const& lhs, bfloat16_t const& rhs) { -+ return float(lhs) <= float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator>(bfloat16_t const& lhs, bfloat16_t const& rhs) { -+ return float(lhs) > float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator>=(bfloat16_t const& lhs, bfloat16_t const& rhs) { -+ return float(lhs) >= float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t operator+(bfloat16_t const& lhs, bfloat16_t const& rhs) { -+ return bfloat16_t(float(lhs) + float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t operator-(bfloat16_t const& lhs) { -+ return bfloat16_t(-float(lhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t operator-(bfloat16_t const& lhs, bfloat16_t const& rhs) { -+ return bfloat16_t(float(lhs) - float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t operator*(bfloat16_t const& lhs, bfloat16_t const& rhs) { -+ return bfloat16_t(float(lhs) * float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t operator/(bfloat16_t const& lhs, bfloat16_t const& rhs) { -+ return bfloat16_t(float(lhs) / float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t& operator+=(bfloat16_t & lhs, bfloat16_t const& rhs) { -+ lhs = bfloat16_t(float(lhs) + float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t& operator-=(bfloat16_t & lhs, bfloat16_t const& rhs) { -+ lhs = bfloat16_t(float(lhs) - float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t& operator*=(bfloat16_t & lhs, bfloat16_t const& rhs) { -+ lhs = bfloat16_t(float(lhs) * float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t& operator/=(bfloat16_t & lhs, bfloat16_t const& rhs) { -+ lhs = bfloat16_t(float(lhs) / float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t& operator++(bfloat16_t & lhs) { -+ float tmp(lhs); -+ ++tmp; -+ lhs = bfloat16_t(tmp); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t& operator--(bfloat16_t & lhs) { -+ float tmp(lhs); -+ --tmp; -+ lhs = bfloat16_t(tmp); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t operator++(bfloat16_t & lhs, int) { -+ bfloat16_t ret(lhs); -+ float tmp(lhs); -+ tmp++; -+ lhs = bfloat16_t(tmp); -+ return ret; -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t operator--(bfloat16_t & 
lhs, int) { -+ bfloat16_t ret(lhs); -+ float tmp(lhs); -+ tmp--; -+ lhs = bfloat16_t(tmp); -+ return ret; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// User-defined literals -+// -+ -+CUTLASS_HOST_DEVICE -+cutlass::bfloat16_t operator "" _bf16(long double x) { -+ return cutlass::bfloat16_t(float(x)); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::bfloat16_t operator "" _bf16(unsigned long long int x) { -+ return cutlass::bfloat16_t(int(x)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/blas3.h b/3rdparty/cutlass/include/cutlass/blas3.h -new file mode 100644 -index 0000000..f5f8a09 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/blas3.h -@@ -0,0 +1,176 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Basic include for CUTLASS BLAS3/HPC code. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Enumerated type describing the type of kernel (based on input or output matrices). 
-+enum class BlasMode {
-+  kGemm,
-+  kSymmetric,
-+  kHermitian,
-+  kTriangular,
-+  kInvalid
-+};
-+
-+/// Enumerated type describing the fill mode for matrices for BLAS functions.
-+enum class FillMode {
-+  kFull,     /// The entire tensor is covered.
-+  kLower,    /// The 'lower' part of a tensor is covered including diagonal
-+  kUpper,    /// The 'upper' part of a tensor is covered including diagonal
-+  kDiagonal, /// Only diagonal elements are covered.
-+  kNone,     /// No element is covered.
-+  kInvalid
-+};
-+
-+/// Enumerated type describing the diagonal property of matrices for BLAS functions.
-+enum class DiagType {
-+  kNonUnit,
-+  kUnit,
-+  kZero, // Only used internally for computing SYMM/HEMM
-+  kInvalid
-+};
-+
-+/// Enumerated type describing the side dense matrix is in matrix equation for BLAS functions.
-+enum class SideMode {
-+  kLeft,
-+  kRight,
-+  kInvalid
-+};
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+/// Defines FillMode inversions
-+template <FillMode kFillMode>
-+struct InvertFillMode;
-+
-+/// Invert FillMode lower to upper
-+template <>
-+struct InvertFillMode<FillMode::kLower> {
-+  static FillMode const mode = FillMode::kUpper;
-+};
-+
-+/// Invert FillMode upper to lower
-+template <>
-+struct InvertFillMode<FillMode::kUpper> {
-+  static FillMode const mode = FillMode::kLower;
-+};
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+/// Defines SideMode inversions
-+template <SideMode kSideMode>
-+struct InvertSideMode;
-+
-+/// Invert SideMode left to right
-+template <>
-+struct InvertSideMode<SideMode::kLeft> {
-+  static SideMode const mode = SideMode::kRight;
-+};
-+
-+/// Invert SideMode right to left
-+template <>
-+struct InvertSideMode<SideMode::kRight> {
-+  static SideMode const mode = SideMode::kLeft;
-+};
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+/// Defines correct compare operation for Triangular matrix boundary
-+template <FillMode kFillMode, DiagType kDiagType = DiagType::kNonUnit>
-+struct TrMatrixCompareOp {
-+  using Index = int32_t;
-+  using Type = typename platform::conditional<
-+                  (kFillMode == FillMode::kLower),
-+                  greater_equal<Index>,
-+                  less_equal<Index>>::type;
-+};
-+
-+template <FillMode kFillMode>
-+struct TrMatrixCompareOp<kFillMode, DiagType::kUnit> {
-+  using Index = int32_t;
-+  using Type = typename platform::conditional<
-+                  (kFillMode == FillMode::kLower),
-+                  greater_equal<Index>,
-+                  less_equal<Index>>::type;
-+};
-+
-+template <FillMode kFillMode>
-+struct TrMatrixCompareOp<kFillMode, DiagType::kZero> {
-+  using Index = int32_t;
-+  using Type = typename platform::conditional<
-+                  (kFillMode == FillMode::kLower),
-+                  greater<Index>,
-+                  less<Index>>::type;
-+};
-+////////////////////////////////////////////////////////////////////////////////////////////////////
-+// Returns precision in terms of bits (based on datatype) to fill tensors with.
-+// Defaults to 5 bits of mantissa for TF32 and FP32 (with implicit round-offs).
-+// Also defines acceptable mantissa result variance/error.
-+template -+struct MantissaInBits { -+ static int constexpr bits = 5; -+ static double constexpr error = 1.0e-7; -+}; -+ -+// Full precision is supported for FP64 -+template <> -+struct MantissaInBits { -+ static int constexpr bits = 30; -+ static double constexpr error = 1.0e-15; -+}; -+ -+template <> -+struct MantissaInBits> { -+ static int constexpr bits = 30; -+ static double constexpr error = 1.0e-15; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/block_striped.h b/3rdparty/cutlass/include/cutlass/block_striped.h -new file mode 100644 -index 0000000..563e619 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/block_striped.h -@@ -0,0 +1,267 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Utilities for performing block-striped access (load, store, reduce) of trivially-copyable, -+ statically-sized array types to global memory. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/wmma_array.h" -+#include "cutlass/functional.h" -+#include "cutlass/complex.h" -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// AccessWidth -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes the maximal power-of-two that evenly divides the size of T, capped at Limit -+template < -+ typename T, -+ int Limit> -+struct AccessWidth -+{ -+ // Inductive case -+ template < -+ int ObjectBytes, /// Size of T in bytes -+ int AlignBytes, /// Template induction variable -+ bool IsAligned = /// Whether ObjectBytes is an even multiple of AlignBytes -+ ((AlignBytes <= Limit) && (ObjectBytes % AlignBytes == 0))> -+ struct Detail -+ { -+ static const int value = Detail::value; -+ }; -+ -+ // Base case (ObjectBytes is not an even multiple of AlignBytes) -+ template < -+ int ObjectBytes, /// Size of T in bytes -+ int AlignBytes> /// Template induction variable -+ struct Detail -+ { -+ static const int value = AlignBytes / 2; -+ }; -+ -+ /// The maximal power-of-two that evenly divides the size of T -+ static const int value = Detail< -+ (int) sizeof(T), -+ 1>::value; -+}; -+ -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// StripedAccessType -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// ReinterpretCast type for striping a trivially-copyable type in global memory -+/// (Default specialization. Striping granularity is type T.) -+template < -+ typename T, /// Data type -+ int TransferBytes = /// Data access width (16 byte max for global memory access on current architectures) -+ AccessWidth::value> -+struct alignas(TransferBytes) StripedAccessType : public T -+{}; -+ -+ -+/// ReinterpretCast type for striping a trivially-copyable type in global memory -+/// (Specialization for cutlass::Array. Striping granularity is a multiple of T.) -+template < -+ typename T, /// Array element type -+ int N, /// Number of elements in array -+ bool RegisterSized, /// T is register-sized -+ int TransferBytes> /// Data access width -+struct StripedAccessType< -+ Array, -+ TransferBytes> -+: public AlignedArray< -+ T, // Element type of StripedAccessType -+ __NV_STD_MAX(1, TransferBytes / (int) sizeof(T)), // Number of elements T in StripedAccessType -+ TransferBytes> // Alignment of StripedAccessType -+{}; -+ -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+/// ReinterpretCast type for striping a trivially-copyable type in global memory -+/// (Specialization for cutlass::WmmaFragmentArray. Striping granularity is a multiple of T.) 
-+template< -+ typename Use, -+ int m, -+ int n, -+ int k, -+ typename ElementT, -+ typename Layout, -+ int kFragments, -+ int TransferBytes> -+struct StripedAccessType< -+ WmmaFragmentArray, kFragments>, -+ TransferBytes> -+: public AlignedArray< -+ ElementT, -+ __NV_STD_MAX(1, TransferBytes / (int) sizeof(ElementT)), -+ TransferBytes> -+{}; -+ -+#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// BlockStriped -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Utility for performing block-striped access (load, store) of trivially-copyable, -+/// statically-sized array types to global memory -+template < -+ int BlockThreads, -+ typename ArrayT, -+ typename AccessT = StripedAccessType > -+struct BlockStriped -+{ -+ /// Number of striped accesses -+ static const int kStripes = int(sizeof(ArrayT) / sizeof(AccessT)); -+ static_assert(kStripes > 0, "AccessT type must be smaller than or equal to ArrayT type"); -+ -+ /// Load -+ CUTLASS_DEVICE -+ static void load(ArrayT &data, ArrayT *ptr, int thread_idx) -+ { -+ AccessT *access_input = reinterpret_cast(ptr); -+ AccessT *access_data = reinterpret_cast(&data); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kStripes; ++i) { -+ access_data[i] = access_input[(BlockThreads * i) + thread_idx]; -+ } -+ } -+ -+ /// Load & Add -+ CUTLASS_DEVICE -+ static void load_add(ArrayT &data, ArrayT *ptr, int thread_idx) -+ { -+ AccessT *access_input = reinterpret_cast(ptr); -+ AccessT *access_data = reinterpret_cast(&data); -+ -+ plus add; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kStripes; ++i) -+ { -+ access_data[i] = add(access_data[i], access_input[(BlockThreads * i) + thread_idx]); -+ } -+ } -+ -+ /// Store -+ CUTLASS_DEVICE -+ static void store(ArrayT *ptr, const ArrayT &data, int thread_idx) -+ { -+ AccessT *access_output = reinterpret_cast(ptr); -+ const AccessT *access_data = reinterpret_cast(&data); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kStripes; ++i) { -+ access_output[(BlockThreads * i) + thread_idx] = access_data[i]; -+ } -+ } -+ -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// BlockStripedReduce -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable, -+/// statically-sized array types to global memory. -+/// (Default specialization) -+template < -+ int BlockThreads, -+ typename ArrayT, -+ typename ElementT = typename StripedAccessType::Element> -+struct BlockStripedReduce : -+ BlockStriped< -+ BlockThreads, -+ ArrayT, -+ ElementT> -+{ -+ /// Reduce -+ CUTLASS_DEVICE -+ static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx) -+ { -+ cutlass::red reduce; -+ ElementT *access_output = reinterpret_cast(ptr); -+ const ElementT *access_data = reinterpret_cast(&data); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < BlockStripedReduce::kStripes; ++i) { -+ reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]); -+ } -+ } -+}; -+ -+ -+/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable, -+/// statically-sized array types to global memory. -+/// (Specialization for half_t. Uses half2 vectorized-reduction.) 
-+template < -+ int BlockThreads, -+ typename ArrayT> -+struct BlockStripedReduce : -+ BlockStriped< -+ BlockThreads, -+ ArrayT, -+ half2> -+{ -+ static_assert(BlockStripedReduce::kStripes % 2 == 0, "Array of half must be even number in length"); -+ -+ /// Reduce -+ CUTLASS_DEVICE -+ static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx) -+ { -+ cutlass::red reduce; -+ half2 *access_output = reinterpret_cast(ptr); -+ const half2 *access_data = reinterpret_cast(&data); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < BlockStripedReduce::kStripes; ++i) -+ { -+ reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]); -+ } -+ } -+}; -+ -+ -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/cluster_launch.hpp b/3rdparty/cutlass/include/cutlass/cluster_launch.hpp -new file mode 100644 -index 0000000..4843540 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/cluster_launch.hpp -@@ -0,0 +1,156 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ \brief PTX for TMA Tensor Memory Access operators on memory added for SM90 -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/trace.h" -+ -+#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) -+# define CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED -+#endif -+ -+namespace cutlass { -+ -+#ifndef NDEBUG -+#define Return_Status(cudaError_t_status) \ -+ if (cudaError_t_status != cudaSuccess) { \ -+ fprintf(stderr, \ -+ "[ ERROR: CUDA Runtime ] %s:%d: %s\n", \ -+ __FILE__, \ -+ __LINE__, \ -+ cudaGetErrorString(cudaError_t_status)); \ -+ return Status::kInvalid; \ -+ } else { \ -+ return Status::kSuccess; \ -+ } -+#else -+#define Return_Status(cudaError_t_status) \ -+ if (cudaError_t_status != cudaSuccess) { \ -+ return Status::kInvalid; \ -+ } else { \ -+ return Status::kSuccess; \ -+ } -+#endif -+ -+struct ClusterLauncher { -+ constexpr static int MaxClusterSize = 32; -+ -+ // Check for hardware compatibility -+ static inline __host__ -+ Status check_cluster_dims(dim3 const& grid, dim3 const& cluster) { -+ if (((cluster.x * cluster.y * cluster.z) <= MaxClusterSize) && -+ (grid.x % cluster.x == 0) && (grid.y % cluster.y == 0) && (grid.z % cluster.z == 0)) { -+ return Status::kSuccess; -+ } -+ else { -+ CUTLASS_TRACE_HOST("ClusterLauncher: Invalid cluster configuration -- aborting launch."); -+ return Status::kInvalid; -+ } -+ } -+ -+ static inline __host__ -+ Status -+#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) -+ init(void const* kernel_function) -+#else -+ init(void const* /* kernel_function */) -+#endif -+ { -+#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) -+ // This attribute was added in CUDA 11.8. -+ cudaError_t status = -+ cudaFuncSetAttribute( -+ kernel_function, cudaFuncAttributeNonPortableClusterSizeAllowed, 1); -+ Return_Status(status); -+#else -+ return Status::kInvalid; -+#endif -+ } -+ -+ // This is the method we expect to use going forward -+ static inline __host__ -+ Status launch( -+ dim3 const& grid_dims, -+ dim3 const& cluster_dims, -+ dim3 const& block_dims, -+ size_t const& smem_size, -+ cudaStream_t& cuda_stream, -+ void const* kernel, -+ void** kernel_params) { -+#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) -+ if (check_cluster_dims(grid_dims, cluster_dims) != Status::kSuccess) { -+ CUTLASS_TRACE_HOST("ClusterLauncher: check_cluster_dims() failed. Aborting."); -+ return Status::kInvalid; -+ } -+ -+ auto init_status = init(kernel); -+ if (init_status != Status::kSuccess) { -+ CUTLASS_TRACE_HOST("ClusterLauncher: init(kernel) failed with status " << int(init_status) << ". 
Aborting."); -+ return Status::kInvalid; -+ } -+ -+ cudaLaunchConfig_t launch_config; -+ launch_config.gridDim = {grid_dims.x, grid_dims.y, grid_dims.z}; -+ launch_config.blockDim = {block_dims.x, block_dims.y, block_dims.z}; -+ launch_config.dynamicSmemBytes = smem_size; -+ launch_config.stream = cuda_stream; -+ -+ cudaLaunchAttribute launch_attribute[1]; -+ launch_attribute[0].id = cudaLaunchAttributeClusterDimension; -+ launch_attribute[0].val.clusterDim.x = cluster_dims.x; -+ launch_attribute[0].val.clusterDim.y = cluster_dims.y; -+ launch_attribute[0].val.clusterDim.z = cluster_dims.z; -+ -+ launch_config.attrs = launch_attribute; -+ launch_config.numAttrs = 1; -+ -+ CUTLASS_TRACE_HOST("ClusterLauncher: Launching GPC_CLUSTER_GRID GridDims = " -+ "(" << grid_dims.x << ", " << grid_dims.y << ", " << grid_dims.z << "), " -+ "And ClusterDims = " -+ "(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n"); -+ -+ cudaError_t status = cudaLaunchKernelExC(&launch_config, kernel, kernel_params); -+ Return_Status(status); -+#else -+ CUTLASS_TRACE_HOST("ClusterLauncher: CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED not defined! Aborting cluster launch."); -+ return Status::kInvalid; -+#endif -+ } -+}; -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/complex.h b/3rdparty/cutlass/include/cutlass/complex.h -new file mode 100644 -index 0000000..089f474 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/complex.h -@@ -0,0 +1,705 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+#include "cutlass/half.h" -+#include "cutlass/real.h" -+ -+#include "cutlass/bfloat16.h" -+#include "cutlass/tfloat32.h" -+ -+#include "cutlass/fast_math.h" -+ -+#if !defined(__CUDACC_RTC__) -+#include -+#endif -+ -+namespace cutlass { -+ -+ -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Enumeraed type describing a transformation on a complex value. -+enum class ComplexTransform { -+ kNone, -+ kConjugate -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines ComplexTransform inversions -+template -+struct InvertComplexTransform; -+ -+/// Invert ComplexTransform from kNone to kConjugate -+template <> -+struct InvertComplexTransform { -+ static ComplexTransform const transform = ComplexTransform::kConjugate; -+}; -+ -+/// Invert ComplexTransform from kConjugate to kNone -+template <> -+struct InvertComplexTransform { -+ static ComplexTransform const transform = ComplexTransform::kNone; -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Accessors for CUDA complex types -+// -+ -+#if !defined(__CUDACC_RTC__) -+/// Returns the real part of the complex number -+CUTLASS_HOST_DEVICE -+float const &real(cuFloatComplex const &z) { return z.x; } -+ -+/// Returns the real part of the complex number -+CUTLASS_HOST_DEVICE -+float &real(cuFloatComplex &z) { return z.x; } -+ -+/// Returns the real part of the complex number -+CUTLASS_HOST_DEVICE -+double const &real(cuDoubleComplex const &z) { return z.x; } -+ -+/// Returns the real part of the complex number -+CUTLASS_HOST_DEVICE -+double &real(cuDoubleComplex &z) { return z.x; } -+ -+/// Returns the imaginary part of the complex number -+CUTLASS_HOST_DEVICE -+float const &imag(cuFloatComplex const &z) { return z.y; } -+ -+/// Returns the imaginary part of the complex number -+CUTLASS_HOST_DEVICE -+float &imag(cuFloatComplex &z) { return z.y; } -+ -+/// Returns the imaginary part of the complex number -+CUTLASS_HOST_DEVICE -+double const &imag(cuDoubleComplex const &z) { return z.y; } -+ -+/// Returns the imaginary part of the complex number -+CUTLASS_HOST_DEVICE -+double &imag(cuDoubleComplex &z) { return z.y; } -+#endif -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Class for representing and manipulating complex numbers with conversions from built-in CUDA -+/// complex types. 
-+ -+template -+class complex -+{ -+ public: -+ /// Type alias for scalar type -+ using value_type = T; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Real part -+ T _real; -+ -+ /// Imaginary part -+ T _imag; -+ -+ public: -+ -+// -+// Methods -+// -+ -+ /// Default constructor -+ complex() = default; -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ complex(T r) : _real(r), _imag(T(0)) {} -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ complex(T r, T i) : _real(r), _imag(i) {} -+ -+ /// Constructor -+ template -+ CUTLASS_HOST_DEVICE -+ complex(complex const &z) : _real(static_cast(z.real())), _imag(static_cast(z.imag())) {} -+ -+ -+ #if !defined(__CUDACC_RTC__) -+ /// Conversion from cuFloatComplex -+ CUTLASS_HOST_DEVICE -+ complex(cuFloatComplex const &z) : _real(static_cast(cuCrealf(z))), _imag(static_cast(cuCimagf(z))) {} -+ -+ /// Conversion from cuDoubleComplex -+ CUTLASS_HOST_DEVICE -+ complex(cuDoubleComplex const &z) : _real(static_cast(cuCreal(z))), _imag(static_cast(cuCimag(z))) {} -+ #endif -+ -+ /// Assignment -+ template -+ CUTLASS_HOST_DEVICE -+ complex& operator=(complex const &z) -+ { -+ _real = static_cast(z.real()); -+ _imag = static_cast(z.imag()); -+ return *this; -+ } -+ -+ /// Equality operator -+ CUTLASS_HOST_DEVICE bool operator==(complex const &rhs) const { -+ return this->real() == rhs.real() && this->imag() == rhs.imag(); -+ } -+ -+ /// Inequality operator -+ CUTLASS_HOST_DEVICE bool operator!=(complex const &rhs) const { -+ return !(*this == rhs); -+ } -+ -+ /// Addition -+ template -+ CUTLASS_HOST_DEVICE complex operator+(complex const &rhs) const { -+ return complex(this->real() + rhs.real(), this->imag() + rhs.imag()); -+ } -+ -+ /// Reduction into memory address. Components may update out of order. -+ template -+ CUTLASS_DEVICE void red(complex *ptr) const { -+ static_assert(platform::is_same::value, "Component type must match"); -+ cutlass::red reduce; -+ reduce(&ptr->_real, _real); -+ reduce(&ptr->_imag, _imag); -+ } -+ -+ /// Reduction into memory address. Components may update out of order. 
(Half specialization) -+ CUTLASS_DEVICE void red(complex *ptr) const { -+ static_assert(platform::is_same::value, "Component type must match"); -+ half2 *h2_ptr = reinterpret_cast(ptr); -+ half2 h2_data = reinterpret_cast(*this); -+ cutlass::red reduce; -+ reduce(h2_ptr, h2_data); -+ } -+ -+ /// Subtraction -+ template -+ CUTLASS_HOST_DEVICE complex operator-(complex const &rhs) const { -+ return complex(this->real() - rhs.real(), this->imag() - rhs.imag()); -+ } -+ -+ /// Multiplication -+ template -+ CUTLASS_HOST_DEVICE complex operator*(complex const &rhs) const { -+ return complex(this->real() * rhs.real() - this->imag() * rhs.imag(), -+ this->real() * rhs.imag() + this->imag() * rhs.real()); -+ } -+ -+ /// Scalar Multiplication -+ template -+ CUTLASS_HOST_DEVICE complex operator*(A const &s) const { -+ return complex(this->real() * s, this->imag() * s); -+ } -+ -+ /// Division -+ template -+ CUTLASS_HOST_DEVICE complex operator/(complex const &rhs) const { -+ T d = T(rhs.real() * rhs.real() + rhs.imag() * rhs.imag()); -+ -+ return complex( -+ (real() * rhs.real() + imag() * rhs.imag()) / d, -+ (imag() * rhs.real() - real() * rhs.imag()) / d -+ ); -+ } -+ -+ /// Scalar Division -+ template -+ CUTLASS_HOST_DEVICE complex operator/(A const &s) const { -+ return complex(this->real() / s, this->imag() / s); -+ } -+ -+ /// Addition -+ template -+ CUTLASS_HOST_DEVICE complex &operator+=(complex const &rhs) { -+ *this = *this + rhs; -+ return *this; -+ } -+ -+ /// Subtraction -+ template -+ CUTLASS_HOST_DEVICE complex &operator-=(complex const &rhs) { -+ *this = *this - rhs; -+ return *this; -+ } -+ -+ /// Multiplication -+ template -+ CUTLASS_HOST_DEVICE complex &operator*=(complex const &rhs) { -+ *this = *this * rhs; -+ return *this; -+ } -+ -+ /// Scalar multiplication -+ template -+ CUTLASS_HOST_DEVICE complex &operator*=(A s) { -+ *this = *this * s; -+ return *this; -+ } -+ -+ /// Division -+ template -+ CUTLASS_HOST_DEVICE complex &operator/=(complex const &rhs) { -+ *this = *this / rhs; -+ return *this; -+ } -+ -+ /// Accesses the real part of the complex number -+ CUTLASS_HOST_DEVICE -+ T const &real() const { return _real; } -+ -+ /// Accesses the real part of the complex number -+ CUTLASS_HOST_DEVICE -+ T &real() { return _real; } -+ -+ /// Accesses the imaginary part of the complex number -+ CUTLASS_HOST_DEVICE -+ T const &imag() const { return _imag; } -+ -+ /// Accesses the imaginary part of the complex number -+ CUTLASS_HOST_DEVICE -+ T &imag() { return _imag; } -+ -+ -+ #if !defined(__CUDACC_RTC__) -+ /// Converts to cuFloatComplex -+ CUTLASS_HOST_DEVICE -+ explicit operator cuFloatComplex() const { return make_cuFloatComplex(float(real()), float(imag())); } -+ -+ /// Converts to cuDoubleComplex -+ CUTLASS_HOST_DEVICE -+ explicit operator cuDoubleComplex() const { return make_cuDoubleComplex(real(), imag()); } -+ #endif -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Accessors for complex template -+// -+ -+/// Returns the real part of the complex number -+template -+CUTLASS_HOST_DEVICE T const &real(complex const &z) { -+ return z.real(); -+} -+ -+/// Returns the real part of the complex number -+template -+CUTLASS_HOST_DEVICE T &real(complex &z) { -+ return z.real(); -+} -+ -+/// Returns the imaginary part of the complex number -+template -+CUTLASS_HOST_DEVICE T const &imag(complex const &z) { -+ return z.imag(); -+} -+ -+/// Returns the imaginary part of the complex number -+template 
-+CUTLASS_HOST_DEVICE T &imag(complex &z) { -+ return z.imag(); -+} -+ -+/// Returns the real part of the real number -+template -+CUTLASS_HOST_DEVICE T const &real(T const &r) { -+ return r; -+} -+ -+/// Returns the real part of the real number -+template -+CUTLASS_HOST_DEVICE T &real(T &r) { -+ return r; -+} -+ -+/// Returns the imaginary part of the real number -+template -+CUTLASS_HOST_DEVICE T const &imag(T const &r) { -+ return T(); -+} -+ -+/// Returns the imaginary part of the complex number -+template -+CUTLASS_HOST_DEVICE T &imag(T &r) { -+ return T(); -+} -+ -+// -+// Output operators -+// -+ -+#if !defined(__CUDACC_RTC__) -+template -+std::ostream &operator<<(std::ostream &out, complex const &z) { -+ T _r = real(z); -+ T _i = imag(z); -+ -+ if (bool(_i)) { -+ return out << _r << "+i" << _i; -+ } -+ return out << _r; -+} -+#endif -+ -+// -+// Non-member operators defined for complex types -+// -+ -+ -+// -+// Non-member functions defined for complex numbers -+// -+ -+/// Returns the magnitude of the complex number -+template -+CUTLASS_HOST_DEVICE T abs(complex const &z) { -+ return sqrt(norm(z)); -+} -+ -+/// Returns the magnitude of the complex number -+template -+CUTLASS_HOST_DEVICE T arg(complex const &z) { -+ return atan2(imag(z), real(z)); -+} -+ -+/// Returns the squared magnitude of a real number -+template -+CUTLASS_HOST_DEVICE T norm(T const &z) { -+ return z * z; -+} -+ -+/// Returns the squared magnitude of a real number -+template <> -+CUTLASS_HOST_DEVICE int8_t norm(int8_t const &z) { -+ return static_cast(z * z); -+} -+ -+/// Returns the squared magnitude of a complex number -+template -+CUTLASS_HOST_DEVICE double norm(complex const &z) { -+ return real(z) * real(z) + imag(z) * imag(z); -+} -+ -+/// Norm-accumulate calculation -+template -+CUTLASS_HOST_DEVICE R norm_accumulate(T const &x, R const & accumulator) { -+ return accumulator + static_cast(x) * static_cast(x); -+} -+ -+/// Norm accumulate specialized for complex types -+template -+CUTLASS_HOST_DEVICE R norm_accumulate(complex const &z, R const &accumulator) { -+ return accumulator + static_cast(real(z)) * static_cast(real(z)) + -+ static_cast(imag(z)) * static_cast(imag(z)); -+} -+ -+/// Returns the complex conjugate -+CUTLASS_HOST_DEVICE float conj(float const &z) { -+ return z; -+} -+ -+/// Returns the complex conjugate -+CUTLASS_HOST_DEVICE double conj(double const &z) { -+ return z; -+} -+ -+/// Returns the complex conjugate -+template -+CUTLASS_HOST_DEVICE complex conj(complex const &z) { -+ return complex(real(z), -imag(z)); -+} -+/// Indentity transform for non-complex types -+template -+CUTLASS_HOST_DEVICE T conj(T const &z) { -+ static_assert( !platform::is_same::value && -+ !platform::is_same::value && -+ !platform::is_same>::value && -+ !platform::is_same>::value, "May not be a complex data type"); -+ return z; -+} -+ -+/// Projects the complex number z onto the Riemann sphere -+template -+CUTLASS_HOST_DEVICE complex proj(complex const &z) { -+ T d = real(z) * real(z) + imag(z) * imag(z) + T(1); -+ return complex((T(2) * real(z)) / d, (T(2) * imag(z)) / d); -+} -+ -+/// Returns a complex number with magnitude r and phase theta -+template -+CUTLASS_HOST_DEVICE complex polar(T const &r, T const &theta = T()) { -+ return complex(r * cos(theta), r * sin(theta)); -+} -+ -+/// Computes the complex exponential of z. 
-+template -+CUTLASS_HOST_DEVICE complex exp(complex const &z) { -+ return complex(fast_exp(real(z)) * fast_cos(imag(z)), fast_exp(real(z)) * fast_sin(imag(z))); -+} -+ -+/// Computes the log of z -+template -+CUTLASS_HOST_DEVICE complex log(complex const &z) { -+ return complex(log(abs(z)), arg(z)); -+} -+ -+/// Computes the log base 10 of z -+template -+CUTLASS_HOST_DEVICE complex log10(complex const &z) { -+ return log(z) / T(log(T(10))); -+} -+ -+/// Computes the square root of complex number z -+template -+CUTLASS_HOST_DEVICE complex sqrt(complex const &z) { -+ return sqrt(T(2)) / T(2) * -+ complex(sqrt(sqrt(norm(z)) + real(z)), -+ (imag(z) < 0 ? T(-1) : T(1)) * sqrt(sqrt(norm(z)) - real(z))); -+} -+ -+/// Computes the cosine of complex z. -+template -+CUTLASS_HOST_DEVICE complex cos(complex const &z) { -+ return (exp(z) + exp(-z)) / T(2); -+} -+ -+/// Computes the sin of complex z. -+template -+CUTLASS_HOST_DEVICE complex sin(complex const &z) { -+ return (exp(-z) - exp(z)) * complex(T(0), T(1) / T(2)); -+} -+ -+/// Comparison -+template -+CUTLASS_HOST_DEVICE bool operator<(complex const &lhs, complex const &rhs) { -+ //TODO -+ return true; -+} -+ -+////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex-valued type. -+template -+struct RealType< complex > -+{ -+ using Type = T; -+ -+ /// Number of elements -+ static int const kExtent = 2; -+ -+ CUTLASS_HOST_DEVICE -+ static complex from_real(double x) { -+ return complex(static_cast(x)); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+CUTLASS_HOST_DEVICE -+cutlass::complex from_real >(double r) { -+ return cutlass::complex(half_t(r)); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+cutlass::complex from_real >(double r) { -+ return cutlass::complex(float(r)); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+cutlass::complex from_real >(double r) { -+ return cutlass::complex(r); -+} -+ -+////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct is_complex { -+ static bool const value = false; -+}; -+ -+template -+struct is_complex> { -+ static bool const value = true; -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// functional.h numeric specializations -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Squares with optional conversion -+template -+struct magnitude_squared, Output> { -+ CUTLASS_HOST_DEVICE -+ Output operator()(complex lhs) const { -+ multiplies mul_op; -+ -+ Output y_r = Output(lhs.real()); -+ Output y_i = Output(lhs.imag()); -+ -+ return mul_op(y_r, y_r) + mul_op(y_i, y_i); -+ } -+}; -+ -+/// Fused multiply-add -+template -+struct multiply_add, complex, complex> { -+ CUTLASS_HOST_DEVICE -+ complex operator()( -+ complex const &a, -+ complex const &b, -+ complex const &c) const { -+ -+ T real = c.real(); -+ T imag = c.imag(); -+ -+ real += a.real() * b.real(); -+ real += -a.imag() * b.imag(); -+ imag += a.real() * b.imag(); -+ imag += a.imag () * b.real(); -+ -+ return complex{ -+ real, -+ imag -+ }; -+ } -+}; -+ -+/// Fused multiply-add -+template -+struct multiply_add, T, complex> { -+ CUTLASS_HOST_DEVICE -+ complex operator()( -+ complex const &a, -+ T const &b, -+ complex const &c) const { -+ -+ T real = c.real(); -+ T imag = c.imag(); -+ -+ real += a.real() * b; -+ 
imag += a.imag () * b; -+ -+ return complex{ -+ real, -+ imag -+ }; -+ } -+}; -+ -+/// Fused multiply-add -+template -+struct multiply_add, complex> { -+ CUTLASS_HOST_DEVICE -+ complex operator()( -+ T const &a, -+ complex const &b, -+ complex const &c) const { -+ -+ T real = c.real(); -+ T imag = c.imag(); -+ -+ real += a * b.real(); -+ imag += a * b.imag(); -+ -+ return complex{ -+ real, -+ imag -+ }; -+ } -+}; -+ -+/// Conjugate -+template -+struct conjugate> { -+ CUTLASS_HOST_DEVICE -+ complex operator()(complex const &a) const { -+ return conj(a); -+ } -+}; -+ -+/// Computes the square of a difference with optional conversion -+template -+struct magnitude_squared_difference, Output> { -+ CUTLASS_HOST_DEVICE -+ Output operator()(complex lhs, complex rhs) const { -+ multiplies mul_op; -+ -+ Output y_r = Output(lhs.real()) - Output(rhs.real()); -+ Output y_i = Output(lhs.imag()) - Output(rhs.imag()); -+ -+ return mul_op(y_r, y_r) + mul_op(y_i, y_i); -+ } -+}; -+ -+/// Reduces value into the data pointed to by ptr (complex specialization) -+template -+struct red> { -+ CUTLASS_DEVICE -+ void operator()(complex *ptr, const complex &data) -+ { -+ data.red(ptr); -+ } -+}; -+ -+ -+////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/constants.h b/3rdparty/cutlass/include/cutlass/constants.h -new file mode 100644 -index 0000000..ca7ea89 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/constants.h -@@ -0,0 +1,1239 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
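
The multiply_add specializations above expand a complex fused multiply-add into four real accumulations. A plain host-side sketch of the same expansion, using std::complex<float> purely for illustration (the function name is an assumption for this example):

    #include <complex>

    // (a * b + c) computed exactly as in the complex multiply_add above:
    // four real multiply-accumulates into the real and imaginary parts of c.
    inline std::complex<float> complex_fma(std::complex<float> a,
                                           std::complex<float> b,
                                           std::complex<float> c) {
      float re = c.real();
      float im = c.imag();
      re += a.real() * b.real();
      re += -a.imag() * b.imag();
      im += a.real() * b.imag();
      im += a.imag() * b.real();
      return {re, im};
    }
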
-+ * -+ **************************************************************************************************/ -+ -+/* \file -+ \brief Boost-style constant definitions for floating-point types. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/complex.h" -+ -+/////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace constants { -+ -+/////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Primary templates -+// -+ -+/// Returns 1, the multiplicative identity element -+template CUTLASS_HOST_DEVICE T one(); -+ -+/// Returns 0, the additive identity element -+template CUTLASS_HOST_DEVICE T zero(); -+ -+/// Returns 2 -+template CUTLASS_HOST_DEVICE T two(); -+ -+/// Returns pi, approximately 3.141 -+template CUTLASS_HOST_DEVICE T pi(); -+ -+/// Returns 2 * pi -+template CUTLASS_HOST_DEVICE T two_pi(); -+ -+/// Returns pi / 2 -+template CUTLASS_HOST_DEVICE T half_pi(); -+ -+/// Returns sqrt(pi) -+template CUTLASS_HOST_DEVICE T root_pi(); -+ -+/// Returns sqrt(pi / 2) -+template CUTLASS_HOST_DEVICE T root_half_pi(); -+ -+/// Returns sqrt(2 * pi) -+template CUTLASS_HOST_DEVICE T root_two_pi(); -+ -+/// Returns sqrt(ln(4)) -+template CUTLASS_HOST_DEVICE T root_ln_four(); -+ -+/// Returns e, approximately 2.718... -+template CUTLASS_HOST_DEVICE T e(); -+ -+/// Returns (1/2) -+template CUTLASS_HOST_DEVICE T half(); -+ -+/// Returns sqrt(2), approximately 1.414... -+template CUTLASS_HOST_DEVICE T root_two(); -+ -+/// Returns sqrt(2)/2, approximately 0.707... -+template CUTLASS_HOST_DEVICE T half_root_two(); -+ -+/// Returns ln(2), approximately 0.693... -+template CUTLASS_HOST_DEVICE T ln_two(); -+ -+/// Returns ln(ln(2)), approximately -0.3665... -+template CUTLASS_HOST_DEVICE T ln_ln_two(); -+ -+/// Returns 1/3, approximately 0.333... -+template CUTLASS_HOST_DEVICE T third(); -+ -+/// Returns 2/3, approximately 0.666... -+template CUTLASS_HOST_DEVICE T twothirds(); -+ -+/// Returns pi - 3, approximately 0.1416... -+template CUTLASS_HOST_DEVICE T pi_minus_three(); -+ -+/// Returns 4 - pi, approximately 0.858... 
-+template CUTLASS_HOST_DEVICE T four_minus_pi(); -+ -+ -+///////////////////////////////////////////////////////////////////////////////////// -+ -+// Specialization for double -+ -+/// Returns 1, the multiplicative identity element (specialization for double) -+template <> CUTLASS_HOST_DEVICE double one() { -+ uint64_t bits = 0x3ff0000000000000ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 1, the multiplicative identity element (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex one< complex >() { -+ return complex(one(), double()); -+} -+ -+/// Returns 0, the additive identity element (specialization for double) -+template <> CUTLASS_HOST_DEVICE double zero() { -+ uint64_t bits = 0x0ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 0, the additive identity element (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex zero< complex >() { -+ return complex(zero(), double()); -+} -+ -+/// Returns 2 (specialization for double) -+template <> CUTLASS_HOST_DEVICE double two() { -+ uint64_t bits = 0x4000000000000000ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex two< complex >() { -+ return complex(two(), double()); -+} -+ -+/// Returns pi, approximately 3.141 (specialization for double) -+template <> CUTLASS_HOST_DEVICE double pi() { -+ uint64_t bits = 0x400921fb54442d18ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi, approximately 3.141 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex pi< complex >() { -+ return complex(pi(), double()); -+} -+ -+/// Returns 2 * pi (specialization for double) -+template <> CUTLASS_HOST_DEVICE double two_pi() { -+ uint64_t bits = 0x401921fb54442d18ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2 * pi (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { -+ return complex(two_pi(), double()); -+} -+ -+/// Returns pi / 2 (specialization for double) -+template <> CUTLASS_HOST_DEVICE double half_pi() { -+ uint64_t bits = 0x3ff921fb54442d18ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi / 2 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { -+ return complex(half_pi(), double()); -+} -+ -+/// Returns sqrt(pi) (specialization for double) -+template <> CUTLASS_HOST_DEVICE double root_pi() { -+ uint64_t bits = 0x3ffc5bf891b4ef6aull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(pi) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { -+ return complex(root_pi(), double()); -+} -+ -+/// Returns sqrt(pi / 2) (specialization for double) -+template <> CUTLASS_HOST_DEVICE double root_half_pi() { -+ uint64_t bits = 0x3ff40d931ff62705ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(pi / 2) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { -+ return complex(root_half_pi(), double()); -+} -+ -+/// Returns sqrt(2 * pi) (specialization for double) -+template <> CUTLASS_HOST_DEVICE double root_two_pi() { -+ uint64_t bits = 0x40040d931ff62705ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2 * pi) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { -+ return complex(root_two_pi(), double()); -+} -+ -+/// Returns sqrt(ln(4)) (specialization for double) -+template <> CUTLASS_HOST_DEVICE double root_ln_four() { -+ 
uint64_t bits = 0x3ff2d6abe44afc43ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(ln(4)) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { -+ return complex(root_ln_four(), double()); -+} -+ -+/// Returns e, approximately 2.718... (specialization for double) -+template <> CUTLASS_HOST_DEVICE double e() { -+ uint64_t bits = 0x4005bf0a8b145769ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns e, approximately 2.718... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex e< complex >() { -+ return complex(e(), double()); -+} -+ -+/// Returns (1/2) (specialization for double) -+template <> CUTLASS_HOST_DEVICE double half() { -+ uint64_t bits = 0x3fe0000000000000ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns (1/2) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half< complex >() { -+ return complex(half(), double()); -+} -+ -+/// Returns sqrt(2), approximately 1.414... (specialization for double) -+template <> CUTLASS_HOST_DEVICE double root_two() { -+ uint64_t bits = 0x3ff6a09e667f3bcdull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2), approximately 1.414... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { -+ return complex(root_two(), double()); -+} -+ -+/// Returns sqrt(2)/2, approximately 0.707... (specialization for double) -+template <> CUTLASS_HOST_DEVICE double half_root_two() { -+ uint64_t bits = 0x3fe6a09e667f3bcdull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { -+ return complex(half_root_two(), double()); -+} -+ -+/// Returns ln(2), approximately 0.693... (specialization for double) -+template <> CUTLASS_HOST_DEVICE double ln_two() { -+ uint64_t bits = 0x3fe62e42fefa39efull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns ln(2), approximately 0.693... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { -+ return complex(ln_two(), double()); -+} -+ -+/// Returns ln(ln(2)), approximately -0.3665... (specialization for double) -+template <> CUTLASS_HOST_DEVICE double ln_ln_two() { -+ uint64_t bits = 0xbfd774f29bdd6b9full; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { -+ return complex(ln_ln_two(), double()); -+} -+ -+/// Returns 1/3, approximately 0.333... (specialization for double) -+template <> CUTLASS_HOST_DEVICE double third() { -+ uint64_t bits = 0x3fd5555555555555ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 1/3, approximately 0.333... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex third< complex >() { -+ return complex(third(), double()); -+} -+ -+/// Returns 2/3, approximately 0.666... (specialization for double) -+template <> CUTLASS_HOST_DEVICE double twothirds() { -+ uint64_t bits = 0x3fe5555555555555ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2/3, approximately 0.666... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { -+ return complex(twothirds(), double()); -+} -+ -+/// Returns pi - 3, approximately 0.1416... 
(specialization for double) -+template <> CUTLASS_HOST_DEVICE double pi_minus_three() { -+ uint64_t bits = 0x3fc21fb54442d180ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi - 3, approximately 0.1416... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { -+ return complex(pi_minus_three(), double()); -+} -+ -+/// Returns 4 - pi, approximately 0.858... (specialization for double) -+template <> CUTLASS_HOST_DEVICE double four_minus_pi() { -+ uint64_t bits = 0x3feb7812aeef4ba0ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 4 - pi, approximately 0.858... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { -+ return complex(four_minus_pi(), double()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////// -+ -+// Specialization for float -+ -+/// Returns 1, the multiplicative identity element (specialization for float) -+template <> CUTLASS_HOST_DEVICE float one() { -+ uint32_t bits = 0x3f800000u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 1, the multiplicative identity element (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex one< complex >() { -+ return complex(one(), float()); -+} -+ -+/// Returns 0, the additive identity element (specialization for float) -+template <> CUTLASS_HOST_DEVICE float zero() { -+ uint32_t bits = 0x0u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 0, the additive identity element (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex zero< complex >() { -+ return complex(zero(), float()); -+} -+ -+/// Returns 2 (specialization for float) -+template <> CUTLASS_HOST_DEVICE float two() { -+ uint32_t bits = 0x40000000u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex two< complex >() { -+ return complex(two(), float()); -+} -+ -+/// Returns pi, approximately 3.141 (specialization for float) -+template <> CUTLASS_HOST_DEVICE float pi() { -+ uint32_t bits = 0x40490fdbu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi, approximately 3.141 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex pi< complex >() { -+ return complex(pi(), float()); -+} -+ -+/// Returns 2 * pi (specialization for float) -+template <> CUTLASS_HOST_DEVICE float two_pi() { -+ uint32_t bits = 0x40c90fdbu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2 * pi (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { -+ return complex(two_pi(), float()); -+} -+ -+/// Returns pi / 2 (specialization for float) -+template <> CUTLASS_HOST_DEVICE float half_pi() { -+ uint32_t bits = 0x3fc90fdbu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi / 2 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { -+ return complex(half_pi(), float()); -+} -+ -+/// Returns sqrt(pi) (specialization for float) -+template <> CUTLASS_HOST_DEVICE float root_pi() { -+ uint32_t bits = 0x3fe2dfc5u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(pi) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { -+ return complex(root_pi(), float()); -+} -+ -+/// Returns sqrt(pi / 2) (specialization for float) -+template <> CUTLASS_HOST_DEVICE float root_half_pi() { -+ uint32_t bits = 0x3fa06c99u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(pi / 2) 
(specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { -+ return complex(root_half_pi(), float()); -+} -+ -+/// Returns sqrt(2 * pi) (specialization for float) -+template <> CUTLASS_HOST_DEVICE float root_two_pi() { -+ uint32_t bits = 0x40206c99u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2 * pi) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { -+ return complex(root_two_pi(), float()); -+} -+ -+/// Returns sqrt(ln(4)) (specialization for float) -+template <> CUTLASS_HOST_DEVICE float root_ln_four() { -+ uint32_t bits = 0x3f96b55fu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(ln(4)) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { -+ return complex(root_ln_four(), float()); -+} -+ -+/// Returns e, approximately 2.718... (specialization for float) -+template <> CUTLASS_HOST_DEVICE float e() { -+ uint32_t bits = 0x402df854u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns e, approximately 2.718... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex e< complex >() { -+ return complex(e(), float()); -+} -+ -+/// Returns (1/2) (specialization for float) -+template <> CUTLASS_HOST_DEVICE float half() { -+ uint32_t bits = 0x3f000000u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns (1/2) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half< complex >() { -+ return complex(half(), float()); -+} -+ -+/// Returns sqrt(2), approximately 1.414... (specialization for float) -+template <> CUTLASS_HOST_DEVICE float root_two() { -+ uint32_t bits = 0x3fb504f3u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2), approximately 1.414... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { -+ return complex(root_two(), float()); -+} -+ -+/// Returns sqrt(2)/2, approximately 0.707... (specialization for float) -+template <> CUTLASS_HOST_DEVICE float half_root_two() { -+ uint32_t bits = 0x3f3504f3u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { -+ return complex(half_root_two(), float()); -+} -+ -+/// Returns ln(2), approximately 0.693... (specialization for float) -+template <> CUTLASS_HOST_DEVICE float ln_two() { -+ uint32_t bits = 0x3f317218u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns ln(2), approximately 0.693... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { -+ return complex(ln_two(), float()); -+} -+ -+/// Returns ln(ln(2)), approximately -0.3665... (specialization for float) -+template <> CUTLASS_HOST_DEVICE float ln_ln_two() { -+ uint32_t bits = 0xbebba795u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { -+ return complex(ln_ln_two(), float()); -+} -+ -+/// Returns 1/3, approximately 0.333... (specialization for float) -+template <> CUTLASS_HOST_DEVICE float third() { -+ uint32_t bits = 0x3eaaaaabu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 1/3, approximately 0.333... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex third< complex >() { -+ return complex(third(), float()); -+} -+ -+/// Returns 2/3, approximately 0.666... 
(specialization for float) -+template <> CUTLASS_HOST_DEVICE float twothirds() { -+ uint32_t bits = 0x3f2aaaabu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2/3, approximately 0.666... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { -+ return complex(twothirds(), float()); -+} -+ -+/// Returns pi - 3, approximately 0.1416... (specialization for float) -+template <> CUTLASS_HOST_DEVICE float pi_minus_three() { -+ uint32_t bits = 0x3e10fdaau; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi - 3, approximately 0.1416... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { -+ return complex(pi_minus_three(), float()); -+} -+ -+/// Returns 4 - pi, approximately 0.858... (specialization for float) -+template <> CUTLASS_HOST_DEVICE float four_minus_pi() { -+ uint32_t bits = 0x3f5bc095u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 4 - pi, approximately 0.858... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { -+ return complex(four_minus_pi(), float()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////// -+ -+// Specialization for tfloat32_t -+ -+/// Returns 1, the multiplicative identity element (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t one() { -+ uint32_t bits = 0x3f801000u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 1, the multiplicative identity element (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex one< complex >() { -+ return complex(one(), tfloat32_t()); -+} -+ -+/// Returns 0, the additive identity element (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t zero() { -+ uint32_t bits = 0x1000u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 0, the additive identity element (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex zero< complex >() { -+ return complex(zero(), tfloat32_t()); -+} -+ -+/// Returns 2 (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t two() { -+ uint32_t bits = 0x40001000u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex two< complex >() { -+ return complex(two(), tfloat32_t()); -+} -+ -+/// Returns pi, approximately 3.141 (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t pi() { -+ uint32_t bits = 0x40491fdbu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi, approximately 3.141 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex pi< complex >() { -+ return complex(pi(), tfloat32_t()); -+} -+ -+/// Returns 2 * pi (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t two_pi() { -+ uint32_t bits = 0x40c91fdbu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2 * pi (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { -+ return complex(two_pi(), tfloat32_t()); -+} -+ -+/// Returns pi / 2 (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t half_pi() { -+ uint32_t bits = 0x3fc91fdbu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi / 2 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { -+ return complex(half_pi(), tfloat32_t()); -+} -+ -+/// Returns sqrt(pi) (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t 
root_pi() { -+ uint32_t bits = 0x3fe2efc5u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(pi) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { -+ return complex(root_pi(), tfloat32_t()); -+} -+ -+/// Returns sqrt(pi / 2) (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t root_half_pi() { -+ uint32_t bits = 0x3fa07c99u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(pi / 2) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { -+ return complex(root_half_pi(), tfloat32_t()); -+} -+ -+/// Returns sqrt(2 * pi) (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t root_two_pi() { -+ uint32_t bits = 0x40207c99u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2 * pi) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { -+ return complex(root_two_pi(), tfloat32_t()); -+} -+ -+/// Returns sqrt(ln(4)) (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t root_ln_four() { -+ uint32_t bits = 0x3f96c55fu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(ln(4)) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { -+ return complex(root_ln_four(), tfloat32_t()); -+} -+ -+/// Returns e, approximately 2.718... (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t e() { -+ uint32_t bits = 0x402e0854u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns e, approximately 2.718... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex e< complex >() { -+ return complex(e(), tfloat32_t()); -+} -+ -+/// Returns (1/2) (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t half() { -+ uint32_t bits = 0x3f001000u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns (1/2) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half< complex >() { -+ return complex(half(), tfloat32_t()); -+} -+ -+/// Returns sqrt(2), approximately 1.414... (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t root_two() { -+ uint32_t bits = 0x3fb514f3u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2), approximately 1.414... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { -+ return complex(root_two(), tfloat32_t()); -+} -+ -+/// Returns sqrt(2)/2, approximately 0.707... (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t half_root_two() { -+ uint32_t bits = 0x3f3514f3u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { -+ return complex(half_root_two(), tfloat32_t()); -+} -+ -+/// Returns ln(2), approximately 0.693... (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t ln_two() { -+ uint32_t bits = 0x3f318218u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns ln(2), approximately 0.693... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { -+ return complex(ln_two(), tfloat32_t()); -+} -+ -+/// Returns ln(ln(2)), approximately -0.3665... 
(specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t ln_ln_two() { -+ uint32_t bits = 0xbebbb795u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { -+ return complex(ln_ln_two(), tfloat32_t()); -+} -+ -+/// Returns 1/3, approximately 0.333... (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t third() { -+ uint32_t bits = 0x3eaabaabu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 1/3, approximately 0.333... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex third< complex >() { -+ return complex(third(), tfloat32_t()); -+} -+ -+/// Returns 2/3, approximately 0.666... (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t twothirds() { -+ uint32_t bits = 0x3f2abaabu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2/3, approximately 0.666... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { -+ return complex(twothirds(), tfloat32_t()); -+} -+ -+/// Returns pi - 3, approximately 0.1416... (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t pi_minus_three() { -+ uint32_t bits = 0x3e110daau; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi - 3, approximately 0.1416... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { -+ return complex(pi_minus_three(), tfloat32_t()); -+} -+ -+/// Returns 4 - pi, approximately 0.858... (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t four_minus_pi() { -+ uint32_t bits = 0x3f5bd095u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 4 - pi, approximately 0.858... 
(specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { -+ return complex(four_minus_pi(), tfloat32_t()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////// -+ -+// Specialization for half_t -+ -+/// Returns 1, the multiplicative identity element (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t one() { -+ uint16_t bits = 0x3c00u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 1, the multiplicative identity element (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex one< complex >() { -+ return complex(one(), half_t()); -+} -+ -+/// Returns 0, the additive identity element (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t zero() { -+ uint16_t bits = 0x0u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 0, the additive identity element (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex zero< complex >() { -+ return complex(zero(), half_t()); -+} -+ -+/// Returns 2 (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t two() { -+ uint16_t bits = 0x4000u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex two< complex >() { -+ return complex(two(), half_t()); -+} -+ -+/// Returns pi, approximately 3.141 (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t pi() { -+ uint16_t bits = 0x4248u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi, approximately 3.141 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex pi< complex >() { -+ return complex(pi(), half_t()); -+} -+ -+/// Returns 2 * pi (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t two_pi() { -+ uint16_t bits = 0x4648u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2 * pi (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { -+ return complex(two_pi(), half_t()); -+} -+ -+/// Returns pi / 2 (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t half_pi() { -+ uint16_t bits = 0x3e48u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi / 2 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { -+ return complex(half_pi(), half_t()); -+} -+ -+/// Returns sqrt(pi) (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t root_pi() { -+ uint16_t bits = 0x3f17u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(pi) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { -+ return complex(root_pi(), half_t()); -+} -+ -+/// Returns sqrt(pi / 2) (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t root_half_pi() { -+ uint16_t bits = 0x3d03u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(pi / 2) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { -+ return complex(root_half_pi(), half_t()); -+} -+ -+/// Returns sqrt(2 * pi) (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t root_two_pi() { -+ uint16_t bits = 0x4103u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2 * pi) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { -+ return complex(root_two_pi(), half_t()); -+} -+ -+/// Returns sqrt(ln(4)) (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t root_ln_four() { -+ uint16_t bits = 
0x3cb6u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(ln(4)) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { -+ return complex(root_ln_four(), half_t()); -+} -+ -+/// Returns e, approximately 2.718... (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t e() { -+ uint16_t bits = 0x4170u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns e, approximately 2.718... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex e< complex >() { -+ return complex(e(), half_t()); -+} -+ -+/// Returns (1/2) (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t half() { -+ uint16_t bits = 0x3800u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns (1/2) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half< complex >() { -+ return complex(half(), half_t()); -+} -+ -+/// Returns sqrt(2), approximately 1.414... (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t root_two() { -+ uint16_t bits = 0x3da8u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2), approximately 1.414... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { -+ return complex(root_two(), half_t()); -+} -+ -+/// Returns sqrt(2)/2, approximately 0.707... (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t half_root_two() { -+ uint16_t bits = 0x39a8u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { -+ return complex(half_root_two(), half_t()); -+} -+ -+/// Returns ln(2), approximately 0.693... (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t ln_two() { -+ uint16_t bits = 0x398cu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns ln(2), approximately 0.693... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { -+ return complex(ln_two(), half_t()); -+} -+ -+/// Returns ln(ln(2)), approximately -0.3665... (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t ln_ln_two() { -+ uint16_t bits = 0xb5ddu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { -+ return complex(ln_ln_two(), half_t()); -+} -+ -+/// Returns 1/3, approximately 0.333... (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t third() { -+ uint16_t bits = 0x3555u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 1/3, approximately 0.333... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex third< complex >() { -+ return complex(third(), half_t()); -+} -+ -+/// Returns 2/3, approximately 0.666... (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t twothirds() { -+ uint16_t bits = 0x3955u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2/3, approximately 0.666... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { -+ return complex(twothirds(), half_t()); -+} -+ -+/// Returns pi - 3, approximately 0.1416... (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t pi_minus_three() { -+ uint16_t bits = 0x3088u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi - 3, approximately 0.1416... 
(specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { -+ return complex(pi_minus_three(), half_t()); -+} -+ -+/// Returns 4 - pi, approximately 0.858... (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t four_minus_pi() { -+ uint16_t bits = 0x3adeu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 4 - pi, approximately 0.858... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { -+ return complex(four_minus_pi(), half_t()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////// -+ -+// Specialization for bfloat16_t -+ -+/// Returns 1, the multiplicative identity element (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t one() { -+ uint16_t bits = 0x3f80u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 1, the multiplicative identity element (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex one< complex >() { -+ return complex(one(), bfloat16_t()); -+} -+ -+/// Returns 0, the additive identity element (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t zero() { -+ uint16_t bits = 0x0u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 0, the additive identity element (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex zero< complex >() { -+ return complex(zero(), bfloat16_t()); -+} -+ -+/// Returns 2 (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t two() { -+ uint16_t bits = 0x4000u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex two< complex >() { -+ return complex(two(), bfloat16_t()); -+} -+ -+/// Returns pi, approximately 3.141 (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t pi() { -+ uint16_t bits = 0x4049u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi, approximately 3.141 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex pi< complex >() { -+ return complex(pi(), bfloat16_t()); -+} -+ -+/// Returns 2 * pi (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t two_pi() { -+ uint16_t bits = 0x40c9u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2 * pi (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { -+ return complex(two_pi(), bfloat16_t()); -+} -+ -+/// Returns pi / 2 (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t half_pi() { -+ uint16_t bits = 0x3fc9u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi / 2 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { -+ return complex(half_pi(), bfloat16_t()); -+} -+ -+/// Returns sqrt(pi) (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t root_pi() { -+ uint16_t bits = 0x3fe3u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(pi) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { -+ return complex(root_pi(), bfloat16_t()); -+} -+ -+/// Returns sqrt(pi / 2) (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t root_half_pi() { -+ uint16_t bits = 0x3fa0u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(pi / 2) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { -+ return complex(root_half_pi(), bfloat16_t()); 
-+} -+ -+/// Returns sqrt(2 * pi) (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t root_two_pi() { -+ uint16_t bits = 0x4020u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2 * pi) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { -+ return complex(root_two_pi(), bfloat16_t()); -+} -+ -+/// Returns sqrt(ln(4)) (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t root_ln_four() { -+ uint16_t bits = 0x3f97u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(ln(4)) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { -+ return complex(root_ln_four(), bfloat16_t()); -+} -+ -+/// Returns e, approximately 2.718... (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t e() { -+ uint16_t bits = 0x402eu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns e, approximately 2.718... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex e< complex >() { -+ return complex(e(), bfloat16_t()); -+} -+ -+/// Returns (1/2) (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t half() { -+ uint16_t bits = 0x3f00u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns (1/2) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half< complex >() { -+ return complex(half(), bfloat16_t()); -+} -+ -+/// Returns sqrt(2), approximately 1.414... (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t root_two() { -+ uint16_t bits = 0x3fb5u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2), approximately 1.414... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { -+ return complex(root_two(), bfloat16_t()); -+} -+ -+/// Returns sqrt(2)/2, approximately 0.707... (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t half_root_two() { -+ uint16_t bits = 0x3f35u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { -+ return complex(half_root_two(), bfloat16_t()); -+} -+ -+/// Returns ln(2), approximately 0.693... (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t ln_two() { -+ uint16_t bits = 0x3f31u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns ln(2), approximately 0.693... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { -+ return complex(ln_two(), bfloat16_t()); -+} -+ -+/// Returns ln(ln(2)), approximately -0.3665... (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t ln_ln_two() { -+ uint16_t bits = 0xbebcu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { -+ return complex(ln_ln_two(), bfloat16_t()); -+} -+ -+/// Returns 1/3, approximately 0.333... (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t third() { -+ uint16_t bits = 0x3eabu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 1/3, approximately 0.333... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex third< complex >() { -+ return complex(third(), bfloat16_t()); -+} -+ -+/// Returns 2/3, approximately 0.666... 
(specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t twothirds() { -+ uint16_t bits = 0x3f2bu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2/3, approximately 0.666... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { -+ return complex(twothirds(), bfloat16_t()); -+} -+ -+/// Returns pi - 3, approximately 0.1416... (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t pi_minus_three() { -+ uint16_t bits = 0x3e11u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi - 3, approximately 0.1416... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { -+ return complex(pi_minus_three(), bfloat16_t()); -+} -+ -+/// Returns 4 - pi, approximately 0.858... (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t four_minus_pi() { -+ uint16_t bits = 0x3f5cu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 4 - pi, approximately 0.858... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { -+ return complex(four_minus_pi(), bfloat16_t()); -+} -+/////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace constants -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/conv2d_problem_size.h b/3rdparty/cutlass/include/cutlass/conv/conv2d_problem_size.h -new file mode 100644 -index 0000000..2bc4eb0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/conv2d_problem_size.h -@@ -0,0 +1,652 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
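// ---------------------------------------------------------------------------
// Illustrative sketch (editorial note, not part of the vendored CUTLASS patch
// above): the half_t / bfloat16_t constant specializations return hard-coded
// IEEE-754 bit patterns.  For bfloat16_t the pattern is simply the upper 16
// bits of the float32 encoding, which is how literals such as pi -> 0x4049,
// two_pi -> 0x40c9, half -> 0x3f00 and ln_two -> 0x3f31 can be re-derived.
#include <cstdint>
#include <cstdio>
#include <cstring>

static std::uint16_t bfloat16_bits(float x) {
  std::uint32_t u = 0;
  std::memcpy(&u, &x, sizeof(u));               // raw float32 bits
  return static_cast<std::uint16_t>(u >> 16);   // sign + exponent + top 7 mantissa bits
}

int main() {
  std::printf("pi     -> 0x%04x\n", bfloat16_bits(3.14159265f));  // expect 0x4049
  std::printf("two_pi -> 0x%04x\n", bfloat16_bits(6.28318531f));  // expect 0x40c9
  std::printf("half   -> 0x%04x\n", bfloat16_bits(0.5f));         // expect 0x3f00
  std::printf("ln_two -> 0x%04x\n", bfloat16_bits(0.69314718f));  // expect 0x3f31
  return 0;
}
// (The half_t table uses the analogous 16-bit binary16 encodings; CUTLASS's
// own conversions may round rather than truncate, so treat this only as a
// sanity check of the listed literals, not as the library's conversion path.)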
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief This file contains definitions and utility functions for describing convolution problem sizes. -+ -+ Conv2dProblem desciption: -+ activation (NHWC), -+ filter (KRSC), -+ output (NPQK), -+ pading (pad_h, pad_w), -+ stride (stride_h, stride_w), -+ dilation (dilation_h, dilation_w). -+ -+ Free functions to map: -+ Map tensor extents (Conv2d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_extent(ConvolutionOperator) -+ Map tensor sizes (Conv2d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator) -+ Map tensor problem sizes (Conv2d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator) -+*/ -+ -+#pragma once -+ -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/functional.h" -+ -+namespace cutlass { -+namespace conv { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Problem size structure -+struct Conv2dProblemSize { -+ -+ // Conv2d strictly problem size parameters -+ int N, H, W, C, P, Q, K, R, S; -+ int pad_h, pad_w; -+ int stride_h, stride_w; -+ int dilation_h, dilation_w; -+ Mode mode; -+ -+ // Conv2d implementation-related parameters -+ int split_k_slices; -+ int groups; -+ -+ // -+ // Methods -+ // -+ -+public: -+ CUTLASS_HOST_DEVICE -+ Conv2dProblemSize(): -+ N(0), H(0), W(0), C(0), P(0), Q(0), K(0), R(0), S(0), -+ pad_h(0), pad_w(0), stride_h(1), stride_w(1), dilation_h(1), dilation_w(1), -+ mode(Mode::kConvolution), split_k_slices(1), groups(1) { } -+ -+ /// Constructor for default padding, stride, dilation, and split-K -+ CUTLASS_HOST_DEVICE -+ Conv2dProblemSize( -+ int N, -+ int H, -+ int W, -+ int C, -+ int P, -+ int Q, -+ int K, -+ int R, -+ int S, -+ Mode mode -+ ): -+ N(N), H(H), W(W), C(C), P(P), Q(Q), K(K), R(R), S(S), -+ pad_h(R / 2), pad_w(S / 2), stride_h(1), stride_w(1), dilation_h(1), dilation_w(1), -+ mode(mode), split_k_slices(1), groups (1) { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ Conv2dProblemSize( -+ int N, -+ int H, -+ int W, -+ int C, -+ int K, -+ int R, -+ int S, -+ int P, -+ int Q, -+ int pad_h, -+ int pad_w, -+ int stride_h, -+ int stride_w, -+ int dilation_h, -+ int dilation_w, -+ Mode mode, -+ int split_k_slices = 1, -+ int groups = 1 -+ ): -+ N(N), H(H), W(W), C(C), K(K), R(R), S(S), P(P), Q(Q), -+ pad_h(pad_h), pad_w(pad_w), stride_h(stride_h), stride_w(stride_w), -+ dilation_h(dilation_h), dilation_w(dilation_w), -+ mode(mode), split_k_slices(split_k_slices), groups (groups) { } -+ -+ /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord -+ // set user-defined output size and sets P and Q (include all data members in ctor) -+ CUTLASS_HOST_DEVICE -+ Conv2dProblemSize( -+ cutlass::Tensor4DCoord input_size, // NHWC -+ cutlass::Tensor4DCoord filter_size, // KRSC -+ cutlass::Tensor4DCoord padding, // pad_h, _, pad_w, _ -+ cutlass::MatrixCoord stride, // stride_h, stride_w -+ cutlass::MatrixCoord dilation, // dilation_h, dilation_w -+ cutlass::Tensor4DCoord output_size, // NPQK -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, -+ int split_k_slices = 1, -+ int groups = 1 -+ ): -+ N(input_size.n()), H(input_size.h()), 
W(input_size.w()), C(input_size.c()), -+ K(filter_size.n()), R(filter_size.h()), S(filter_size.w()), -+ pad_h(padding[0]), pad_w(padding[2]), -+ stride_h(stride.row()), stride_w(stride.column()), -+ dilation_h(dilation.row()), dilation_w(dilation.column()), -+ P(output_size.h()), Q(output_size.w()), -+ mode(mode), split_k_slices(split_k_slices), groups(groups) {} -+ -+ /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord -+ // computes output size and sets P and Q (skip output from ctor arguments) -+ CUTLASS_HOST_DEVICE -+ Conv2dProblemSize( -+ cutlass::Tensor4DCoord input_size, // NHWC -+ cutlass::Tensor4DCoord filter_size, // KRSC -+ cutlass::Tensor4DCoord padding, // pad_h, _, pad_w, _ -+ cutlass::MatrixCoord stride, // stride_h, stride_w -+ cutlass::MatrixCoord dilation, // dilation_h, dilation_w -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, -+ int split_k_slices = 1, -+ int groups = 1 -+ ): -+ N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()), -+ K(filter_size.n()), R(filter_size.h()), S(filter_size.w()), -+ pad_h(padding[0]), pad_w(padding[2]), -+ stride_h(stride.row()), stride_w(stride.column()), -+ dilation_h(dilation.row()), dilation_w(dilation.column()), -+ mode(mode), split_k_slices(split_k_slices), groups(groups) { -+ // set output P and Q -+ P = ((H + pad_h * 2 - R * dilation_h) / stride_h) + 1; -+ Q = ((W + pad_w * 2 - S * dilation_w) / stride_w) + 1; -+ } -+ -+ /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord -+ // set user-defined output size and sets P and Q (skip padding, striding, and dilation) -+ CUTLASS_HOST_DEVICE -+ Conv2dProblemSize( -+ cutlass::Tensor4DCoord input_size, // NHWC -+ cutlass::Tensor4DCoord filter_size, // KRSC -+ cutlass::Tensor4DCoord output_size, // NPQK -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, -+ int split_k_slices = 1, -+ int groups = 1 -+ ): -+ N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()), -+ K(filter_size.n()), R(filter_size.h()), S(filter_size.w()), -+ P(output_size.h()), Q(output_size.w()), -+ pad_h(R / 2), pad_w(S / 2), stride_h(1), stride_w(1), -+ dilation_h(1), dilation_w(1), -+ mode(mode), split_k_slices(split_k_slices), groups(groups) {} -+ -+ // Reset covolution mode in the problem -+ CUTLASS_HOST_DEVICE -+ Conv2dProblemSize reset_mode(cutlass::conv::Mode mode_) { -+ Conv2dProblemSize tmp(*this); -+ tmp.mode = mode_; -+ return tmp; -+ } -+ -+ // Reset covolution mode in the problem -+ CUTLASS_HOST_DEVICE -+ Conv2dProblemSize reset_split_k_slices(int split_k_slices_) { -+ Conv2dProblemSize tmp(*this); -+ tmp.split_k_slices = split_k_slices_; -+ return tmp; -+ } -+ -+ /// Equality operator (ignores mode and split_k_slice) -+ CUTLASS_HOST_DEVICE -+ bool operator==(Conv2dProblemSize const &conv) const { -+ return ( -+ (N == conv.N) && (H == conv.H) && (W == conv.W) && (C == conv.C) && -+ (K == conv.K) && (R == conv.R) && (S == conv.S) && -+ (P == conv.P) && (Q == conv.Q) && -+ (pad_h == conv.pad_h) && (pad_w == conv.pad_w) && -+ (stride_h == conv.stride_h) && (stride_w == conv.stride_w) && -+ (dilation_h == conv.dilation_h) && (dilation_w == conv.dilation_w) -+ ); -+ } -+ -+ /// Inequality operator -+ CUTLASS_HOST_DEVICE -+ bool operator!=(Conv2dProblemSize const &rhs) const { -+ return !(*this == rhs); -+ } -+ -+ /// Returns activation extent as Tensor4DCoord -+ CUTLASS_HOST_DEVICE -+ cutlass::Tensor4DCoord activation_extent() const { -+ -+ return cutlass::Tensor4DCoord 
({N, H, W, C}); -+ } -+ -+ /// Returns filter extent as Tensor4DCoord -+ CUTLASS_HOST_DEVICE -+ cutlass::Tensor4DCoord filter_extent() const { -+ -+ return cutlass::Tensor4DCoord ({K, R, S, C / groups}); -+ } -+ -+ /// Returns output extent as Tensor4DCoord -+ CUTLASS_HOST_DEVICE -+ cutlass::Tensor4DCoord output_extent() const { -+ -+ return cutlass::Tensor4DCoord ({N, P, Q, K}); -+ } -+ -+ /// Returns activation size in number of elements -+ CUTLASS_HOST_DEVICE -+ int64_t activation_size() const { -+ -+ return (N * H * W * C); -+ } -+ -+ /// Returns filter size in number of elements -+ CUTLASS_HOST_DEVICE -+ int64_t filter_size() const { -+ -+ return (K * R * S * C / groups); -+ } -+ -+ /// Returns output size in number of elements -+ CUTLASS_HOST_DEVICE -+ int64_t output_size() const { -+ -+ return (N * P * Q * K); -+ } -+ -+ /// Returns padding as Tensor4DCoord -+ CUTLASS_HOST_DEVICE -+ cutlass::Tensor4DCoord padding() const { -+ -+ return cutlass::Tensor4DCoord ({pad_h, pad_h, pad_w, pad_w}); -+ } -+ -+ /// Returns stride as MatrixCoord -+ CUTLASS_HOST_DEVICE -+ cutlass::MatrixCoord stride() const { -+ -+ return cutlass::MatrixCoord ({stride_h, stride_w}); -+ } -+ -+ /// Returns dilation as MatrixCoord -+ CUTLASS_HOST_DEVICE -+ cutlass::MatrixCoord dilation() const { -+ -+ return cutlass::MatrixCoord ({dilation_h, dilation_w}); -+ } -+ -+ ///////////////////////////////////////////////////////////////// -+ // Methods used for strided dgrad implementation -+ ///////////////////////////////////////////////////////////////// -+ /// Number of filter r positions to accumulate in gemm-k dim -+ CUTLASS_HOST_DEVICE -+ int num_gemm_k_filter_r(int r) const { -+ return ((R - r + stride_h - 1) / stride_h); -+ } -+ -+ /// Number of filter s positions to accumulate in gemm-k dim -+ CUTLASS_HOST_DEVICE -+ int num_gemm_k_filter_s(int s) const { -+ return ((S - s + stride_w - 1) / stride_w); -+ } -+ -+ /// Number of filter positions to accumulate in gemm-k dim -+ CUTLASS_HOST_DEVICE -+ int num_gemm_k_filter_positions(int r, int s) const { -+ return num_gemm_k_filter_r(r) * num_gemm_k_filter_s(s); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// ImplicitGemm helper functions // -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Determine the problem size of the implicit GEMM operation -+CUTLASS_HOST_DEVICE -+cutlass::gemm::GemmCoord implicit_gemm_problem_size( -+ Operator conv_operator, -+ Conv2dProblemSize const &problem_size) { -+ // Compute problem size -+ switch (conv_operator) { -+ case Operator::kFprop: -+ return gemm::GemmCoord( -+ problem_size.N * problem_size.P * problem_size.Q, -+ problem_size.K, -+ problem_size.R * problem_size.S * problem_size.C / problem_size.groups -+ ); -+ case Operator::kDgrad: -+ return gemm::GemmCoord( -+ problem_size.N * problem_size.H * problem_size.W, -+ problem_size.C, -+ problem_size.R * problem_size.S * problem_size.K -+ ); -+ case Operator::kWgrad: -+ return gemm::GemmCoord( -+ problem_size.K, -+ problem_size.R * problem_size.S * problem_size.C, -+ problem_size.N * problem_size.P * problem_size.Q -+ ); -+ default: -+ break; -+ } -+ return gemm::GemmCoord(); -+} -+ -+// Determine the number of gemm_k iterations for conv2d problem using implicit gemm algorithm -+CUTLASS_HOST_DEVICE -+int implicit_gemm_k_iterations( -+ Operator conv_operator, -+ int threadblock_K, -+ Conv2dProblemSize const &problem_size, -+ IteratorAlgorithm 
algorithm = IteratorAlgorithm::kAnalytic, -+ GroupMode group_mode = GroupMode::kNone, -+ int threadblock_N = 0) { -+ -+ int iterations = 0; -+ -+ if (group_mode == GroupMode::kNone) { -+ -+ if (algorithm == IteratorAlgorithm::kFixedChannels) { -+ -+ int positions_per_iteration = threadblock_K / problem_size.C; -+ switch (conv_operator) { -+ case Operator::kFprop: -+ iterations = (problem_size.R * problem_size.S + positions_per_iteration - 1 ) / positions_per_iteration; -+ break; -+ -+ default: -+ break; -+ } -+ } -+ else if (algorithm == IteratorAlgorithm::kFewChannels) { -+ -+ switch (conv_operator) { -+ case Operator::kFprop: -+ iterations = (problem_size.R * problem_size.S * problem_size.C + threadblock_K - 1 ) / threadblock_K; -+ break; -+ -+ default: -+ break; -+ } -+ } -+ else { -+ int elements_per_split_k_slice = 0; -+ -+ switch (conv_operator) { -+ case Operator::kFprop: -+ elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices; -+ iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); -+ break; -+ -+ case Operator::kDgrad: -+ elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices; -+ iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); -+ break; -+ -+ case Operator::kWgrad: -+ elements_per_split_k_slice = (problem_size.N * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; -+ iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K; -+ break; -+ -+ default: -+ break; -+ } -+ } -+ -+ } else if (group_mode == GroupMode::kDepthwise) { -+ int channels_per_cta = threadblock_N; -+ -+ if (algorithm == IteratorAlgorithm::kAnalytic) { -+ switch (conv_operator) { -+ case Operator::kFprop: -+ iterations = problem_size.R * problem_size.S * -+ ((channels_per_cta + threadblock_K - 1) / threadblock_K); -+ break; -+ -+ default: -+ break; -+ } -+ } -+ } else { // Group conv -+ -+ int channels_per_group = problem_size.C / problem_size.groups; -+ int k_per_group = problem_size.K / problem_size.groups; -+ -+ if (algorithm == IteratorAlgorithm::kAnalytic) { -+ switch (conv_operator) { -+ case Operator::kFprop: -+ iterations = problem_size.R * problem_size.S * ((channels_per_group + threadblock_K - 1) / threadblock_K); -+ // In group conv, if k_per_group < threadblock_N, one Threadblock will calculate multiple groups -+ if (problem_size.groups != 1) { -+ if (k_per_group < threadblock_N) { -+ iterations *= threadblock_N / k_per_group; -+ } -+ } -+ break; -+ -+ default: -+ break; -+ } -+ } else if (algorithm == IteratorAlgorithm::kOptimized) { -+ // Current optimized iterator only support GroupMode::kSingleGroup -+ if (group_mode == GroupMode::kSingleGroup) { -+ switch (conv_operator) { -+ case Operator::kFprop: -+ iterations = problem_size.R * problem_size.S * ((channels_per_group + threadblock_K - 1) / threadblock_K); -+ break; -+ -+ default: -+ break; -+ } -+ } -+ } -+ -+ } -+ -+ return iterations; -+} -+ -+ -+template -+CUTLASS_HOST_DEVICE -+int depthwise_gemm_k_iterations( -+ Operator conv_operator, -+ int threadblock_K, -+ Conv2dProblemSize const &problem_size, -+ IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic, -+ GroupMode group_mode = GroupMode::kNone, -+ int threadblock_N = 0) { -+ -+ int n = problem_size.N; -+ int p = (problem_size.P + Output_P - 1) / Output_P; -+ int q = (problem_size.Q 
+ Output_Q - 1) / Output_Q; -+ -+ int iterations = (n * p * q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; -+ return iterations; -+} -+ -+ -+CUTLASS_HOST_DEVICE -+int implicit_gemm_k_iterations_per_channel( -+ Operator conv_operator, -+ int threadblock_K, -+ Conv2dProblemSize const &problem_size, -+ IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic) { -+ -+ int iterations = 0; //0 means not applicable -+ if (algorithm == IteratorAlgorithm::kAnalytic || algorithm == IteratorAlgorithm::kOptimized) { -+ switch (conv_operator) { -+ case Operator::kFprop: -+ iterations = problem_size.R * problem_size.S; -+ break; -+ -+ case Operator::kDgrad: -+ iterations = problem_size.R * problem_size.S; -+ break; -+ -+ default: -+ break; -+ } -+ } -+ return iterations; -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Mapping function (ImplicitGemm A, B, C -> Conv Activation, Filter, Output) -+//////////////////////////////////////////////////////////////////////////////// -+/// Returns ImplicitGemm tensor A extent as Tensor4DCoord -+CUTLASS_HOST_DEVICE -+cutlass::Tensor4DCoord implicit_gemm_tensor_a_extent( -+ Operator conv_operator, -+ Conv2dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.activation_extent(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.output_extent(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.output_extent(); -+ default : break; -+ } -+ return cutlass::Tensor4DCoord(); -+} -+ -+/// Returns ImplicitGemm tensor B extent as Tensor4DCoord -+CUTLASS_HOST_DEVICE -+cutlass::Tensor4DCoord implicit_gemm_tensor_b_extent( -+ Operator conv_operator, -+ Conv2dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.filter_extent(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.filter_extent(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.activation_extent(); -+ default : break; -+ } -+ return cutlass::Tensor4DCoord(); -+} -+ -+/// Returns ImplicitGemm tensor C extent as Tensor4DCoord -+CUTLASS_HOST_DEVICE -+cutlass::Tensor4DCoord implicit_gemm_tensor_c_extent( -+ Operator conv_operator, -+ Conv2dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.output_extent(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.activation_extent(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.filter_extent(); -+ default : break; -+ } -+ return cutlass::Tensor4DCoord(); -+} -+ -+/// Returns ImplicitGemm tensor A size in number of elements -+CUTLASS_HOST_DEVICE -+int64_t implicit_gemm_tensor_a_size( -+ Operator conv_operator, -+ Conv2dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.activation_size(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.output_size(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.output_size(); -+ default : break; -+ } -+ return 0; -+} -+ -+/// Returns ImplicitGemm tensor B size in number of elements -+CUTLASS_HOST_DEVICE -+int64_t implicit_gemm_tensor_b_size( -+ Operator conv_operator, -+ Conv2dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.filter_size(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.filter_size(); -+ 
case cutlass::conv::Operator::kWgrad: return problem_size.activation_size(); -+ default : break; -+ } -+ return 0; -+} -+ -+/// Returns ImplicitGemm tensor C size in number of elements -+CUTLASS_HOST_DEVICE -+int64_t implicit_gemm_tensor_c_size( -+ Operator conv_operator, -+ Conv2dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.output_size(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.activation_size(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.filter_size(); -+ default : break; -+ } -+ return 0; -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// Strided dgrad helper functions // -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// Returns number of CTAs tile M to cover valid MMAs per starting filter postion -+CUTLASS_HOST_DEVICE -+int strided_dgrad_tile_m_per_filter( -+ Conv2dProblemSize const &problem_size, -+ int tile_size_m) { -+ -+ // Compute NHW rows in Dx output that needs MMA per starting filter position -+ int rows_h_per_filter = (problem_size.H + problem_size.stride_h - 1) / problem_size.stride_h; -+ int rows_w_per_filter = (problem_size.W + problem_size.stride_w - 1) / problem_size.stride_w; -+ int rows_nhw_per_filter = problem_size.N * rows_h_per_filter * rows_w_per_filter; -+ -+ // Number of CTAs tile M to cover valid MMAs per starting filter postion -+ int tile_m_per_filter = (rows_nhw_per_filter + tile_size_m - 1) / tile_size_m; -+ -+ return tile_m_per_filter; -+} -+ -+// Computes starting Dx coord (h, w) for given starting filter postion -+CUTLASS_HOST_DEVICE -+void strided_dgrad_starting_coords( -+ Conv2dProblemSize const &problem_size, -+ FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, -+ int r, int s, -+ int &start_h, int &start_w) { -+ -+ // function locals for remainder by fast divmod -+ int pad_h_rem_, pad_w_rem_; -+ -+ // start_h = std::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h; -+ stride_h_divmod.divmod(pad_h_rem_, problem_size.pad_h); -+ int r_ = absolute_value(problem_size.stride_h - (pad_h_rem_ - r)); -+ stride_h_divmod.divmod(start_h, r_); -+ -+ //start_w = std::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w; -+ stride_w_divmod.divmod(pad_w_rem_, problem_size.pad_w); -+ int s_ = absolute_value(problem_size.stride_w - (pad_w_rem_ - s)); -+ stride_w_divmod.divmod(start_w, s_); -+} -+ -+} // namespace conv -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/conv3d_problem_size.h b/3rdparty/cutlass/include/cutlass/conv/conv3d_problem_size.h -new file mode 100644 -index 0000000..5bef4ff ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/conv3d_problem_size.h -@@ -0,0 +1,477 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
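// ---------------------------------------------------------------------------
// Illustrative sketch (editorial note, not part of the vendored CUTLASS patch
// above): a worked Conv2dProblemSize example and its implicit-GEMM mapping.
// With N=1, H=W=56, C=64, K=128, R=S=3, pad=1, stride=1, dilation=1 the
// output-size-computing constructor gives P = ((56 + 2*1 - 3*1) / 1) + 1 = 56
// (likewise Q = 56), so implicit_gemm_problem_size(Operator::kFprop, ...)
// returns GEMM extents M = N*P*Q = 3136, N = K = 128, K = R*S*C = 576.
// With threadblock_K = 64 and split_k_slices = 1, implicit_gemm_k_iterations
// for kFprop evaluates to R * S * ceil(C / 64) = 9.
#include "cutlass/conv/conv2d_problem_size.h"

inline cutlass::gemm::GemmCoord example_fprop_gemm_shape() {
  cutlass::conv::Conv2dProblemSize problem(
      cutlass::Tensor4DCoord(1, 56, 56, 64),    // input  NHWC
      cutlass::Tensor4DCoord(128, 3, 3, 64),    // filter KRSC
      cutlass::Tensor4DCoord(1, 1, 1, 1),       // padding (pad_h, _, pad_w, _)
      cutlass::MatrixCoord(1, 1),               // stride  (stride_h, stride_w)
      cutlass::MatrixCoord(1, 1));              // dilation(dilation_h, dilation_w)
  // P and Q are derived by the constructor; expected result: {3136, 128, 576}.
  return cutlass::conv::implicit_gemm_problem_size(
      cutlass::conv::Operator::kFprop, problem);
}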
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief This file contains definitions and utility functions for describing convolution problem sizes. -+ -+ Conv3dProblem desciption: -+ activation (NDHWC), -+ filter (KTRSC), -+ output (NZPQK), -+ pading (pad_d, pad_h, pad_w), -+ stride (stride_d, stride_h, stride_w), -+ dilation (dilation_d, dilation_h, dilation_w). 
-+ -+ Free functions to map: -+ Map tensor extents (Conv3d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_extent(ConvolutionOperator) -+ Map tensor sizes (Conv3d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator) -+ Map tensor problem sizes (Conv3d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator) -+*/ -+ -+#pragma once -+ -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+namespace cutlass { -+namespace conv { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Problem size structure -+struct Conv3dProblemSize : public Conv2dProblemSize { -+ // -+ // Type definitions -+ // -+ -+ // 3D coordinate for padding, stride, and dilation in (d, h, w) dimensions -+ using Coord3D = Coord<3>; -+ -+ // -+ // Data members -+ // -+ -+ // Conv3d strictly problem size parameters -+ int D, T, Z; // input depth, filter depth, output depth -+ int pad_d; // padding in depth dimension -+ int stride_d; // stride in depth dimension -+ int dilation_d; // dilation in depth dimension -+ -+ // -+ // Methods -+ // -+public: -+ CUTLASS_HOST_DEVICE -+ Conv3dProblemSize(): -+ D(0), T(0), Z(0), -+ pad_d(0), -+ stride_d(1), -+ dilation_d(1), -+ Conv2dProblemSize() { } -+ -+ /// Constructor for default padding, stride, dilation, and split-K -+ CUTLASS_HOST_DEVICE -+ Conv3dProblemSize( -+ int N, -+ int D, -+ int H, -+ int W, -+ int C, -+ int Z, -+ int P, -+ int Q, -+ int K, -+ int T, -+ int R, -+ int S, -+ Mode mode -+ ): -+ D(D), T(T), Z(Z), -+ pad_d(T / 2), stride_d(1), dilation_d(1), -+ Conv2dProblemSize(N, H, W, C, P, Q, K, R, S, mode) { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ Conv3dProblemSize( -+ int N, -+ int D, -+ int H, -+ int W, -+ int C, -+ int K, -+ int T, -+ int R, -+ int S, -+ int Z, -+ int P, -+ int Q, -+ int pad_d, -+ int pad_h, -+ int pad_w, -+ int stride_d, -+ int stride_h, -+ int stride_w, -+ int dilation_d, -+ int dilation_h, -+ int dilation_w, -+ Mode mode, -+ int split_k_slices = 1, -+ int groups = 1 -+ ): -+ D(D), T(T), Z(Z), -+ pad_d(pad_d), stride_d(stride_d), dilation_d(dilation_d), -+ Conv2dProblemSize( -+ N, H, W, C, K, R, S, P, Q, -+ pad_h, pad_w, -+ stride_h, stride_w, -+ dilation_h, dilation_w, -+ mode, split_k_slices, groups) { } -+ -+ /// Constructs convolution problem size from cutlass Tensor5DCoord and Coord3D -+ // set *user-defined* output size and sets Z, P, and Q (include all data members in ctor) -+ CUTLASS_HOST_DEVICE -+ Conv3dProblemSize( -+ cutlass::Tensor5DCoord input_size, // NDHWC -+ cutlass::Tensor5DCoord filter_size, // KTRSC -+ Coord3D padding, // pad_d, pad_h, pad_w -+ Coord3D stride, // stride_d, stride_h, stride_w -+ Coord3D dilation, // dilation_d, dilation_h, dilation_w -+ cutlass::Tensor5DCoord output_size, // NZPQK -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, -+ int split_k_slices = 1, -+ int groups = 1 -+ ): -+ D(input_size.d()), T(filter_size.d()), Z(output_size.d()), -+ pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]), -+ Conv2dProblemSize( -+ {input_size.n(), input_size.h(), input_size.w(), input_size.c()}, -+ {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()}, -+ {padding[1], padding[1], padding[2], padding[2]}, -+ {stride[1], stride[2]}, -+ {dilation[1], dilation[2]}, -+ {output_size.n(), output_size.h(), output_size.w(), output_size.c()}, -+ mode, split_k_slices, groups -+ ) { } -+ -+ /// Constructs convolution problem size from cutlass Tensor5DCoord 
and Coord3D -+ // *computes* output size and sets Z, P and Q (include all data members in ctor) -+ CUTLASS_HOST_DEVICE -+ Conv3dProblemSize( -+ cutlass::Tensor5DCoord input_size, // NDHWC -+ cutlass::Tensor5DCoord filter_size, // KTRSC -+ Coord3D padding, // pad_d, pad_h, pad_w -+ Coord3D stride, // stride_d, stride_h, stride_w -+ Coord3D dilation, // dilation_d, dilation_h, dilation_w -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, -+ int split_k_slices = 1, -+ int groups = 1 -+ ): -+ D(input_size.d()), T(filter_size.d()), -+ pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]), -+ Conv2dProblemSize( -+ {input_size.n(), input_size.h(), input_size.w(), input_size.c()}, -+ {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()}, -+ {padding[1], padding[1], padding[2], padding[2]}, -+ {stride[1], stride[2]}, -+ {dilation[1], dilation[2]}, -+ mode, split_k_slices, groups -+ ) { -+ // set output Z -+ Z = ((D + pad_d * 2 - T * dilation_d) / stride_d) + 1; -+ } -+ -+ /// Equality operator (ignores mode and split_k_slice) -+ CUTLASS_HOST_DEVICE -+ bool operator==(Conv3dProblemSize const &conv) const { -+ return ( -+ (N == conv.N) && (D == conv.D) && (H == conv.H) && (W == conv.W) && (C == conv.C) && -+ (K == conv.K) && (T == conv.T) && (R == conv.R) && (S == conv.S) && -+ (Z == conv.Z) &&(P == conv.P) && (Q == conv.Q) && -+ (pad_d == conv.pad_d) && (pad_h == conv.pad_h) && (pad_w == conv.pad_w) && -+ (stride_d == conv.stride_d) && (stride_h == conv.stride_h) && (stride_w == conv.stride_w) && -+ (dilation_d == conv.dilation_d) && (dilation_h == conv.dilation_h) && (dilation_w == conv.dilation_w) -+ ); -+ } -+ -+ /// Inequality operator -+ CUTLASS_HOST_DEVICE -+ bool operator!=(Conv3dProblemSize const &rhs) const { -+ return !(*this == rhs); -+ } -+ -+ // Reset covolution mode in the problem -+ CUTLASS_HOST_DEVICE -+ Conv3dProblemSize reset_mode(cutlass::conv::Mode mode_) { -+ Conv3dProblemSize tmp(*this); -+ tmp.mode = mode_; -+ return tmp; -+ } -+ -+ // Reset covolution mode in the problem -+ CUTLASS_HOST_DEVICE -+ Conv3dProblemSize reset_split_k_slices(int split_k_slices_) { -+ Conv3dProblemSize tmp(*this); -+ tmp.split_k_slices = split_k_slices_; -+ return tmp; -+ } -+ -+ /// Returns activation extent as Tensor5DCoord -+ CUTLASS_HOST_DEVICE -+ cutlass::Tensor5DCoord activation_extent() const { -+ -+ return cutlass::Tensor5DCoord ({N, D, H, W, C}); -+ } -+ -+ /// Returns filter extent as Tensor5DCoord -+ CUTLASS_HOST_DEVICE -+ cutlass::Tensor5DCoord filter_extent() const { -+ -+ return cutlass::Tensor5DCoord ({K, T, R, S, C}); -+ } -+ -+ /// Returns output extent as Tensor5DCoord -+ CUTLASS_HOST_DEVICE -+ cutlass::Tensor5DCoord output_extent() const { -+ -+ return cutlass::Tensor5DCoord ({N, Z, P, Q, K}); -+ } -+ -+ /// Returns activation size in number of elements -+ CUTLASS_HOST_DEVICE -+ int64_t activation_size() const { -+ -+ return (N * D * H * W * C); -+ } -+ -+ /// Returns filter size in number of elements -+ CUTLASS_HOST_DEVICE -+ int64_t filter_size() const { -+ -+ return (K * T * R * S * C); -+ } -+ -+ /// Returns output size in number of elements -+ CUTLASS_HOST_DEVICE -+ int64_t output_size() const { -+ -+ return (N * Z * P * Q * K); -+ } -+ -+ /// Returns output extent as Tensor5DCoord -+ CUTLASS_HOST_DEVICE -+ Coord3D padding() const { -+ -+ return Coord3D ({pad_d, pad_h, pad_w}); -+ } -+ -+ /// Returns stride as MatrixCoord -+ CUTLASS_HOST_DEVICE -+ Coord3D stride() const { -+ -+ return Coord3D ({stride_d, stride_h, stride_w}); -+ } -+ 
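// ---------------------------------------------------------------------------
// Illustrative sketch (editorial note, not part of the vendored CUTLASS patch
// above): the 3-D analogue of the Conv2d mapping.  For kFprop the
// implicit_gemm_problem_size() overload defined further below returns
//   M = N * Z * P * Q,   N = K,   K = T * R * S * C.
// For example N=1, D=H=W=16, C=32, K=64, T=R=S=3, pad=1, stride=1, dilation=1
// gives Z = P = Q = ((16 + 2*1 - 3*1) / 1) + 1 = 16 and GEMM {4096, 64, 864}.
// Assumes the usual Tensor5DCoord(n, d, h, w, c) constructor.
#include "cutlass/conv/conv3d_problem_size.h"

inline cutlass::gemm::GemmCoord example_fprop3d_gemm_shape() {
  cutlass::conv::Conv3dProblemSize problem(
      cutlass::Tensor5DCoord(1, 16, 16, 16, 32),   // input  NDHWC
      cutlass::Tensor5DCoord(64, 3, 3, 3, 32),     // filter KTRSC
      cutlass::make_Coord(1, 1, 1),                // padding  (d, h, w)
      cutlass::make_Coord(1, 1, 1),                // stride   (d, h, w)
      cutlass::make_Coord(1, 1, 1));               // dilation (d, h, w)
  // Z, P, Q are derived by the constructor; expected result: {4096, 64, 864}.
  return cutlass::conv::implicit_gemm_problem_size(
      cutlass::conv::Operator::kFprop, problem);
}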
-+ /// Returns dilation as MatrixCoord -+ CUTLASS_HOST_DEVICE -+ Coord3D dilation() const { -+ -+ return Coord3D ({dilation_d, dilation_h, dilation_w}); -+ } -+ -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// ImplicitGemm helper functions // -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Determine the problem size of the implicit GEMM operation -+CUTLASS_HOST_DEVICE -+cutlass::gemm::GemmCoord implicit_gemm_problem_size( -+ Operator conv_operator, -+ Conv3dProblemSize const &problem_size) { -+ // Compute problem size -+ switch (conv_operator) { -+ case Operator::kFprop: -+ return gemm::GemmCoord( -+ problem_size.N * problem_size.Z * problem_size.P * problem_size.Q, -+ problem_size.K, -+ problem_size.T * problem_size.R * problem_size.S * problem_size.C -+ ); -+ case Operator::kDgrad: -+ return gemm::GemmCoord( -+ problem_size.N * problem_size.D * problem_size.H * problem_size.W, -+ problem_size.C, -+ problem_size.T * problem_size.R * problem_size.S * problem_size.K -+ ); -+ case Operator::kWgrad: -+ return gemm::GemmCoord( -+ problem_size.K, -+ problem_size.T * problem_size.R * problem_size.S * problem_size.C, -+ problem_size.N * problem_size.Z * problem_size.P * problem_size.Q -+ ); -+ default: -+ break; -+ } -+ return gemm::GemmCoord(); -+} -+ -+// Determine the number of gemm_k iterations for conv2d problem using implicit gemm algorithm -+CUTLASS_HOST_DEVICE -+int implicit_gemm_k_iterations( -+ Operator conv_operator, -+ int threadblock_K, -+ Conv3dProblemSize const &problem_size, -+ IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic, -+ GroupMode group_mode = GroupMode::kNone, -+ int threadblock_N = 0) { -+ -+ int iterations = 0; -+ int elements_per_split_k_slice = 0; -+ if (group_mode == GroupMode::kNone) { -+ switch (conv_operator) { -+ case Operator::kFprop: -+ elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices; -+ iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); -+ break; -+ -+ case Operator::kDgrad: -+ elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices; -+ iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); -+ break; -+ -+ case Operator::kWgrad: -+ elements_per_split_k_slice = (problem_size.N * problem_size.Z * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; -+ iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K; -+ break; -+ -+ default: -+ break; -+ } -+ } else if (group_mode == GroupMode::kDepthwise) { -+ int channels_per_cta = threadblock_N; -+ -+ if (algorithm == IteratorAlgorithm::kAnalytic) { -+ switch (conv_operator) { -+ case Operator::kFprop: -+ iterations = problem_size.T * problem_size.R * problem_size.S * -+ ((channels_per_cta + threadblock_K - 1) / threadblock_K); -+ break; -+ -+ default: -+ break; -+ } -+ } -+ } -+ -+ return iterations; -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Mapping function (ImplicitGemm A, B, C -> Conv Activation, Filter, Output) -+//////////////////////////////////////////////////////////////////////////////// -+/// Returns ImplicitGemm tensor A extent as Tensor5DCoord -+CUTLASS_HOST_DEVICE 
-+cutlass::Tensor5DCoord implicit_gemm_tensor_a_extent( -+ Operator conv_operator, -+ Conv3dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.activation_extent(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.output_extent(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.output_extent(); -+ default : break; -+ } -+ return cutlass::Tensor5DCoord(); -+} -+ -+/// Returns ImplicitGemm tensor B extent as Tensor5DCoord -+CUTLASS_HOST_DEVICE -+cutlass::Tensor5DCoord implicit_gemm_tensor_b_extent( -+ Operator conv_operator, -+ Conv3dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.filter_extent(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.filter_extent(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.activation_extent(); -+ default : break; -+ } -+ return cutlass::Tensor5DCoord(); -+} -+ -+/// Returns ImplicitGemm tensor C extent as Tensor5DCoord -+CUTLASS_HOST_DEVICE -+cutlass::Tensor5DCoord implicit_gemm_tensor_c_extent( -+ Operator conv_operator, -+ Conv3dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.output_extent(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.activation_extent(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.filter_extent(); -+ default : break; -+ } -+ return cutlass::Tensor5DCoord(); -+} -+ -+/// Returns ImplicitGemm tensor A size in number of elements -+CUTLASS_HOST_DEVICE -+int64_t implicit_gemm_tensor_a_size( -+ Operator conv_operator, -+ Conv3dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.activation_size(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.output_size(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.output_size(); -+ default : break; -+ } -+ return 0; -+} -+ -+/// Returns ImplicitGemm tensor B size in number of elements -+CUTLASS_HOST_DEVICE -+int64_t implicit_gemm_tensor_b_size( -+ Operator conv_operator, -+ Conv3dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.filter_size(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.filter_size(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.activation_size(); -+ default : break; -+ } -+ return 0; -+} -+ -+/// Returns ImplicitGemm tensor C size in number of elements -+CUTLASS_HOST_DEVICE -+int64_t implicit_gemm_tensor_c_size( -+ Operator conv_operator, -+ Conv3dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.output_size(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.activation_size(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.filter_size(); -+ default : break; -+ } -+ return 0; -+} -+ -+} // namespace conv -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/convolution.h b/3rdparty/cutlass/include/cutlass/conv/convolution.h -new file mode 100644 -index 0000000..0647edf ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/convolution.h -@@ -0,0 +1,167 @@ 
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+ -+This file contains definitions and utility functions for describing convolution problem sizes in terms of -+activation (NHWC), filter (KRSC), output (NPQK), pading (pad_h, pad_w), stride (stride_h, stride_w), -+dilation (dilation_h, dilation_w). Furthermore, it defines helper functions to map cutlass' implicit gemm -+tensor extents, sizes, data types to that of convolutions extents, sizes, and data types. -+ -+ * Mapping convolutions to Gemm computation * -+ -+Cutlass employs ImplicitGemm algorithm to implement convolutions. ImplicitGemm algorithm runs gemm operation -+on convolution tensors Activation, Filter, and Output . The underlying gemm operation follows the standard -+gemm definition: -+ -+ C = A * B + C -+ -+ A and B are input matrices -+ C is source and output matrix -+ -+ -+For the three convolutional operators (Fprop, Dgrad, Wgrad), ImplicitGemm matrices A, B, and C are mapped on -+to convolution tensors Activation, Filter and Output as per the below table: -+ -+ ___________________________________________________________________________ -+ ConvolutionalOperator | A | B | C -+ ___________________________________________________________________________ -+ | | | | | -+ | Fprop | Activation | Filter | Output | -+ | Dgrad | Output | Filter | Activation | -+ | Wgrad | Output | Activation | Filter | -+ ___________________________________________________________________________ -+ -+In convolution codebase, DO NOT mix using (A, B, C) with (Acvitation, Filter, Output). -+ -+For example, a convolution class/function with A, B, Output is confusing and error-prone. 
Instead use below -+mapping functions and adhere to using either A, B, C or Acvitation, Filter, Output. -+ -+Map elements' data types (ImplicitGemm -> Conv): GemmToConvElementMap -+Map elements' data types (Conv -> ImplicitGemm): ConvToGemmElementMap -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+ -+namespace cutlass { -+namespace conv { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Convolutional operator -+enum class Operator { -+ kFprop, -+ kDgrad, -+ kWgrad -+}; -+ -+/// Distinguishes convolution from cross correlation -+enum class Mode { -+ kCrossCorrelation, -+ kConvolution -+}; -+ -+/// Selects among several implementation variants trading off performance with simplicity -+enum class IteratorAlgorithm { -+ kAnalytic, ///< functionally correct in all cases but lower performance -+ kOptimized, ///< optimized for R <= 32, S <= 32 and unity-stride dgrad -+ kFixedChannels, ///< Analytic algorithm optimized for fixed channel count (C == AccessSize) -+ kFewChannels, ///< Analytic algorithm optimized for few channels (C divisible by AccessSize) -+ kFixedStrideDilation ///< Optimized for fixed stride and dilation -+}; -+ -+/// Distinguishes among partial specializations that accelerate certain problems where convolution -+/// stride is unit. -+enum class StrideSupport { -+ kStrided, ///< arbitrary convolution stride -+ kUnity, ///< unit convolution stride -+ kFixed ///< fixed convolution stride -+}; -+ -+/// Identifies split-K mode -+enum class SplitKMode { -+ kNone, -+ kSerial, -+ kParallel -+}; -+ -+/// Identifies group mode -+enum class GroupMode { -+ kNone, -+ kSingleGroup, ///< One CTA calculates one group or less -+ kMultipleGroup, ///< One CTA calculates multiple groups -+ kDepthwise ///< One CTA calculates cta_n groups (problem_size.C == problem_size.K == problem_size.groups) -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Shape of a tensor -+template < -+ int N = 1, -+ int H = 1, -+ int W = 1, -+ int C = 1 -+> -+struct TensorNHWCShape { -+ static int const kN = N; -+ static int const kH = H; -+ static int const kW = W; -+ static int const kC = C; -+ -+ static int const kHW = H * W; -+ static int const kNHW = N * kHW; -+ static int const kNHWC = N * H * W * C; -+ -+ static int const kCount = kNHWC; -+ -+ // -+ // Static member functions -+ // -+ -+ /// Returns a Coord object -+ CUTLASS_HOST_DEVICE -+ static Coord<4> toCoord() { -+ return make_Coord(kN, kH, kW, kC); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace conv -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/device/direct_convolution.h b/3rdparty/cutlass/include/cutlass/conv/device/direct_convolution.h -new file mode 100644 -index 0000000..d7f28f1 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/device/direct_convolution.h -@@ -0,0 +1,269 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
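// ---------------------------------------------------------------------------
// Illustrative note (editorial, not part of the vendored CUTLASS patch above):
// the operator-to-operand table in convolution.h is exactly what the
// implicit_gemm_tensor_[a|b|c]_extent / _size helpers encode.  For instance,
// for Operator::kDgrad the Conv2d helpers defined earlier return
//   A extent = problem_size.output_extent()
//   B extent = problem_size.filter_extent()
//   C extent = problem_size.activation_extent()
// matching the table row   Dgrad | Output | Filter | Activation.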
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Template for device-level Depthwise Convolution -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/device_kernel.h" -+#include "cutlass/conv/convolution.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class DirectConvolution { -+public: -+ -+ using UnderlyingKernel = DirectConvolutionKernel_; -+ -+ using ElementA = typename UnderlyingKernel::ElementA; -+ using LayoutA = typename UnderlyingKernel::LayoutA; -+ using ElementB = typename UnderlyingKernel::ElementB; -+ using LayoutB = typename UnderlyingKernel::LayoutB; -+ using ElementC = typename UnderlyingKernel::ElementC; -+ using LayoutC = typename UnderlyingKernel::LayoutC; -+ using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator; -+ using ElementCompute = typename UnderlyingKernel::ElementCompute; -+ using OperatorClass = typename UnderlyingKernel::OperatorClass; -+ using ArchTag = typename UnderlyingKernel::ArchTag; -+ using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape; -+ using WarpShape = typename UnderlyingKernel::WarpShape; -+ using InstructionShape = typename UnderlyingKernel::InstructionShape; -+ using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle; -+ using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp; -+ static int const kStages = UnderlyingKernel::kStages; -+ static int const kConvDim = UnderlyingKernel::kConvDim; -+ using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator; -+ using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator; -+ using MathOperator = typename 
UnderlyingKernel::MathOperator; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator; -+ static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm; -+ static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport; -+ static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode; -+ -+ static int const kWarpCount = -+ (ThreadblockShape::kM / WarpShape::kM) * -+ (ThreadblockShape::kN / WarpShape::kN) * -+ (ThreadblockShape::kK / WarpShape::kK); -+ -+ /// Argument structure -+ using Arguments = typename UnderlyingKernel::Arguments; -+ -+ using ReorderKernel = typename UnderlyingKernel::ReorderKernel; -+ -+ private: -+ -+ /// Kernel parameters object -+ typename UnderlyingKernel::Params params_; -+ -+public: -+ -+ /// Constructs Implicit GEMM -+ DirectConvolution() { } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ // dispatch to iterators -+ Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size); -+ if (Status::kSuccess != status) { -+ return status; -+ } -+ -+ status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size); -+ if (Status::kSuccess != status) { -+ return status; -+ } -+ -+ if (kGroupMode != conv::GroupMode::kDepthwise) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // C and K should be multiple of groups -+ if (args.problem_size.K != args.problem_size.groups && -+ args.problem_size.C != args.problem_size.groups) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ -+ static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess; -+ if (kConvolutionalOperator == conv::Operator::kFprop) { -+ if (args.problem_size.K % kAlignmentC) -+ return Status::kErrorMisalignedOperand; -+ } else if (kConvolutionalOperator == conv::Operator::kDgrad) { -+ if (args.problem_size.C % kAlignmentC) -+ return Status::kErrorMisalignedOperand; -+ } else if (kConvolutionalOperator == conv::Operator::kWgrad) { -+ if (args.problem_size.C % kAlignmentC) -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape( -+ threadblock_swizzle.get_tiled_shape( -+ kConvolutionalOperator, -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.problem_size.split_k_slices)); -+ -+ if (!(grid.y <= std::numeric_limits::max() && -+ grid.z <= std::numeric_limits::max())) { -+ -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ return 0; -+ } -+ -+ /// Initializes GEMM state from arguments. 
-+ Status initialize( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ // initialize the params structure from the arguments -+ params_ = typename UnderlyingKernel::Params( -+ args, -+ static_cast(workspace) -+ ); -+ -+ int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage)); -+ -+ if (smem_size >= (48 << 10)) { -+ cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ // update the params structure from the arguments -+ params_.ptr_A = args.ref_A.data(); -+ params_.ptr_B = args.ref_B.data(); -+ params_.ptr_C = args.ref_C.data(); -+ params_.ptr_D = args.ref_D.data(); -+ params_.output_op = args.output_op; -+ params_.ptr_reordered_B = args.ref_reordered_B.data();; -+ params_.semaphore = static_cast(workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ // Launch reorder kernel -+ if (params_.ptr_reordered_B != nullptr) { -+ dim3 grid = ReorderKernel::get_grid_shape(params_); -+ dim3 block = ReorderKernel::get_block_shape(); -+ -+ cutlass::Kernel<<>>(params_); -+ } -+ -+ // Launch main kernel -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(32 * kWarpCount, 1, 1); -+ -+ // Dynamic SMEM size based on input params. -+ int smem_size = int(params_.get_smem_size()); -+ -+ // Make sure we can use that much shared memory. -+ cudaError_t status = -+ cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); -+ if (status != cudaSuccess) -+ return Status::kErrorInternal; -+ -+ -+ cutlass::Kernel<<>>(params_); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+ -+ int get_smem_size() { return int(params_.get_smem_size()); } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} -+} -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/device/implicit_gemm_convolution.h b/3rdparty/cutlass/include/cutlass/conv/device/implicit_gemm_convolution.h -new file mode 100644 -index 0000000..50bdc47 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/device/implicit_gemm_convolution.h -@@ -0,0 +1,328 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Template for device-level Implicit GEMM Convolution -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/device_kernel.h" -+#include "cutlass/conv/convolution.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class ImplicitGemmConvolution { -+public: -+ -+ using UnderlyingKernel = ImplicitGemmKernel_; -+ -+ using ElementA = typename UnderlyingKernel::ElementA; -+ using LayoutA = typename UnderlyingKernel::LayoutA; -+ using ElementB = typename UnderlyingKernel::ElementB; -+ using LayoutB = typename UnderlyingKernel::LayoutB; -+ using ElementC = typename UnderlyingKernel::ElementC; -+ using LayoutC = typename UnderlyingKernel::LayoutC; -+ using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator; -+ using ElementCompute = typename UnderlyingKernel::ElementCompute; -+ using OperatorClass = typename UnderlyingKernel::OperatorClass; -+ using ArchTag = typename UnderlyingKernel::ArchTag; -+ using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape; -+ using WarpShape = typename UnderlyingKernel::WarpShape; -+ using InstructionShape = typename UnderlyingKernel::InstructionShape; -+ using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle; -+ using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp; -+ static int const kStages = UnderlyingKernel::kStages; -+ static int const kConvDim = UnderlyingKernel::kConvDim; -+ using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator; -+ using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator; -+ using MathOperator = typename 
UnderlyingKernel::MathOperator; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator; -+ static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm; -+ static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport; -+ static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode; -+ -+ static int const kWarpCount = -+ (ThreadblockShape::kM / WarpShape::kM) * -+ (ThreadblockShape::kN / WarpShape::kN) * -+ (ThreadblockShape::kK / WarpShape::kK); -+ -+ /// Argument structure -+ using Arguments = typename UnderlyingKernel::Arguments; -+ -+private: -+ -+ /// Kernel parameters object -+ typename UnderlyingKernel::Params params_; -+ -+public: -+ -+ /// Constructs Implicit GEMM -+ ImplicitGemmConvolution() { } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ // dispatch to iterators -+ Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size); -+ if (Status::kSuccess != status) { -+ return status; -+ } -+ -+ status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size); -+ if (Status::kSuccess != status) { -+ return status; -+ } -+ -+ // check group conv constraint -+ if (args.problem_size.groups != 1) { -+ if (kGroupMode == conv::GroupMode::kNone) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // C and K should be multiple of groups -+ if (args.problem_size.K % args.problem_size.groups || -+ args.problem_size.C % args.problem_size.groups) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // split-k is not supported -+ if (args.problem_size.split_k_slices != 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ int k_per_group = args.problem_size.K / args.problem_size.groups; -+ // k_per_group should be multiple of ThreadblockShape N, one CTA calculate one group -+ if (kGroupMode == conv::GroupMode::kSingleGroup && k_per_group % ThreadblockShape::kN) { -+ return Status::kErrorInvalidProblem; -+ } -+ // ThreadblockShape::kN should be divisible by k_per_group, one CTA calculate multiple groups -+ if (kGroupMode == conv::GroupMode::kMultipleGroup && ThreadblockShape::kN % k_per_group) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // current optimized iterator algo only supports SingleGroup mode -+ if (kIteratorAlgorithm == IteratorAlgorithm::kOptimized && -+ kGroupMode != conv::GroupMode::kSingleGroup) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess; -+ if (kConvolutionalOperator == conv::Operator::kFprop) { -+ if (args.problem_size.K % kAlignmentC) -+ return Status::kErrorMisalignedOperand; -+ } else if (kConvolutionalOperator == conv::Operator::kDgrad) { -+ if (args.problem_size.C % kAlignmentC) -+ return Status::kErrorMisalignedOperand; -+ } else if (kConvolutionalOperator == conv::Operator::kWgrad) { -+ if (args.problem_size.C % kAlignmentC) -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ // check for unsupported problem sizes for strided dgrad implementation -+ if (kConvolutionalOperator == conv::Operator::kDgrad && -+ kStrideSupport == conv::StrideSupport::kStrided) { -+ -+ // split-k (serial or parallel) is not supported for strided dgrad -+ if(args.problem_size.split_k_slices > 1) { -+ return Status::kErrorNotSupported; -+ } -+ -+ // dilation > {1x1} is not supported for strided dgrad -+ 
if(args.problem_size.dilation_h > 1 || args.problem_size.dilation_w > 1) { -+ return Status::kErrorNotSupported; -+ } -+ } -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape( -+ threadblock_swizzle.get_tiled_shape( -+ kConvolutionalOperator, -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.problem_size.split_k_slices)); -+ -+ if (!(grid.y <= std::numeric_limits::max() && -+ grid.z <= std::numeric_limits::max())) { -+ -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t workspace_bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ kConvolutionalOperator, -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.problem_size.split_k_slices); -+ -+ if(args.split_k_mode == SplitKMode::kParallel) { -+ -+ // Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace. -+ // The user needs to call a reduction operator to optain the final output tensor -+ workspace_bytes = -+ sizeof(ElementAccumulator) * -+ size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size)) * -+ size_t(grid_tiled_shape.k()); -+ } -+ -+ else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size.split_k_slices > 1) { -+ -+ // Split-K serial: The user workspace is used to store semaphore and serialize writing the -+ // final reduced output to user's output tensor -+ workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); -+ } -+ -+ return workspace_bytes; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ if (args.problem_size.split_k_slices > 1) { -+ -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream); -+ -+ if (status != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ // initialize the params structure from the arguments -+ params_ = typename UnderlyingKernel::Params( -+ args, -+ static_cast(workspace) -+ ); -+ -+ int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage)); -+ -+ if (smem_size >= (48 << 10)) { -+ cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ // update the params structure from the arguments -+ params_.ptr_A = args.ref_A.data(); -+ params_.ptr_B = args.ref_B.data(); -+ params_.ptr_C = args.ref_C.data(); -+ params_.ptr_D = args.ref_D.data(); -+ params_.output_op = args.output_op; -+ params_.semaphore = static_cast(workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. 
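Taken together, can_implement(), get_workspace_size(), initialize() and run() above define the host-side contract of the device-level operator. A minimal sketch of that call sequence, assuming some concrete ImplicitGemmConvolution instantiation Conv and an already populated Conv::Arguments; the helper name and the error handling are illustrative, not part of this header:

    #include <cuda_runtime.h>
    #include "cutlass/cutlass.h"

    // Typical host-side flow: validate, size and allocate the split-K workspace,
    // initialize the params (which also raises the dynamic shared-memory limit
    // past 48 KB when the kernel needs it), then launch.
    template <typename Conv>
    cutlass::Status run_conv(Conv &conv_op,
                             typename Conv::Arguments const &arguments,
                             cudaStream_t stream = nullptr) {
      cutlass::Status status = Conv::can_implement(arguments);
      if (status != cutlass::Status::kSuccess) {
        return status;                                   // unsupported shape or alignment
      }

      void *workspace = nullptr;
      size_t workspace_bytes = Conv::get_workspace_size(arguments);
      if (workspace_bytes != 0 &&
          cudaMalloc(&workspace, workspace_bytes) != cudaSuccess) {
        return cutlass::Status::kErrorInternal;
      }

      status = conv_op.initialize(arguments, workspace, stream);
      if (status == cutlass::Status::kSuccess) {
        status = conv_op.run(stream);                    // or conv_op(stream)
      }

      cudaStreamSynchronize(stream);                     // the kernel may still use the workspace
      cudaFree(workspace);
      return status;
    }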
-+ Status run(cudaStream_t stream = nullptr) { -+ -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(32 * kWarpCount, 1, 1); -+ -+ int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage)); -+ -+ cutlass::Kernel<<>>(params_); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} -+} -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h b/3rdparty/cutlass/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h -new file mode 100644 -index 0000000..2f434bd ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h -@@ -0,0 +1,268 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Template for device-level fused activation's scale+bias+relu and Implicit GEMM Convolution -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/device_kernel.h" -+#include "cutlass/conv/convolution.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class ImplicitGemmConvolutionFusion { -+public: -+ -+ using ImplicitGemmFusionKernel = ImplicitGemmFusionKernel_; -+ -+ using ElementA = typename ImplicitGemmFusionKernel::ElementA; -+ using LayoutA = typename ImplicitGemmFusionKernel::LayoutA; -+ using ElementB = typename ImplicitGemmFusionKernel::ElementB; -+ using LayoutB = typename ImplicitGemmFusionKernel::LayoutB; -+ -+// using ElementScaleBias = typename ImplicitGemmFusionKernel::ElementScaleBias; -+// using LayoutScaleBias = typename ImplicitGemmFusionKernel::LayoutScaleBias; -+ -+ using ElementC = typename ImplicitGemmFusionKernel::ElementC; -+ using LayoutC = typename ImplicitGemmFusionKernel::LayoutC; -+ using ElementAccumulator = typename ImplicitGemmFusionKernel::ElementAccumulator; -+ using ElementCompute = typename ImplicitGemmFusionKernel::ElementCompute; -+ using OperatorClass = typename ImplicitGemmFusionKernel::OperatorClass; -+ using ArchTag = typename ImplicitGemmFusionKernel::ArchTag; -+ using ThreadblockShape = typename ImplicitGemmFusionKernel::ThreadblockShape; -+ using WarpShape = typename ImplicitGemmFusionKernel::WarpShape; -+ using InstructionShape = typename ImplicitGemmFusionKernel::InstructionShape; -+ using ThreadblockSwizzle = typename ImplicitGemmFusionKernel::ThreadblockSwizzle; -+ using EpilogueOutputOp = typename ImplicitGemmFusionKernel::EpilogueOutputOp; -+ static int const kStages = ImplicitGemmFusionKernel::kStages; -+ static int const kConvDim = ImplicitGemmFusionKernel::kConvDim; -+ using WarpMmaOperator = typename ImplicitGemmFusionKernel::WarpMmaOperator; -+ using ArchMmaOperator = typename ImplicitGemmFusionKernel::ArchMmaOperator; -+ using MathOperator = typename ImplicitGemmFusionKernel::MathOperator; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemmFusionKernel::kConvolutionalOperator; -+ static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = ImplicitGemmFusionKernel::kIteratorAlgorithm; -+ -+ static int const kWarpCount = -+ (ThreadblockShape::kM / WarpShape::kM) * -+ (ThreadblockShape::kN / WarpShape::kN) * -+ (ThreadblockShape::kK / WarpShape::kK); -+ -+ /// Argument structure -+ using Arguments = typename ImplicitGemmFusionKernel::Arguments; -+ -+private: -+ -+ /// Kernel parameters object -+ typename ImplicitGemmFusionKernel::Params params_; -+ -+public: -+ -+ /// Constructs Implicit GEMM -+ ImplicitGemmConvolutionFusion() { } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. 
-+ static Status can_implement(Arguments const &args) { -+ -+ // dispatch to iterators -+ Status status = ImplicitGemmFusionKernel::Mma::IteratorA::can_implement(args.problem_size); -+ if (Status::kSuccess != status) { -+ return status; -+ } -+ -+ status = ImplicitGemmFusionKernel::Mma::IteratorB::can_implement(args.problem_size); -+ if (Status::kSuccess != status) { -+ return status; -+ } -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape( -+ threadblock_swizzle.get_tiled_shape( -+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size), -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.problem_size.split_k_slices)); -+ -+ if (!(grid.y <= std::numeric_limits::max() && -+ grid.z <= std::numeric_limits::max())) { -+ -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t workspace_bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size), -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.problem_size.split_k_slices); -+ -+ if(args.split_k_mode == SplitKMode::kParallel) { -+ -+ // Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace. -+ // The user needs to call a reduction operator to optain the final output tensor -+ workspace_bytes = -+ sizeof(ElementAccumulator) * -+ size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size)) * -+ size_t(grid_tiled_shape.k()); -+ } -+ -+ else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size.split_k_slices > 1) { -+ -+ // Split-K serial: The user workspace is used to store semaphore and serialize writing the -+ // final reduced output to user's output tensor -+ workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); -+ } -+ -+ return workspace_bytes; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ if (args.problem_size.split_k_slices > 1) { -+ -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream); -+ -+ if (status != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ // initialize the params structure from the arguments -+ params_ = typename ImplicitGemmFusionKernel::Params( -+ args, -+ static_cast(workspace) -+ ); -+ -+ int smem_size = int(sizeof(typename ImplicitGemmFusionKernel::SharedStorage)); -+ -+ if (smem_size >= (48 << 10)) { -+ cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Initializes Impicit GEMM state from arguments. 
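The two split-K branches of get_workspace_size() above scale very differently, which is worth spelling out once with concrete, made-up numbers: parallel split-K keeps one accumulator-typed copy of the whole output tensor per k-slice, while serial split-K only needs one integer semaphore per threadblock tile.

    #include <cstddef>

    // Made-up sizes for illustration: float accumulators, a 1x56x56x64 output
    // tensor, a 4x4 grid of threadblock tiles, split_k_slices = 4.
    constexpr std::size_t kOutputElements = 1 * 56 * 56 * 64;                     // tensor_c size
    constexpr std::size_t kParallelBytes  = sizeof(float) * kOutputElements * 4;  // 3,211,264 bytes
    constexpr std::size_t kSerialBytes    = sizeof(int) * 4 * 4;                  // 64 bytes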
-+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ // update the params structure from the arguments -+ params_.ptr_A = args.ref_A.data(); -+ params_.ptr_B = args.ref_B.data(); -+ params_.ptr_scale = args.ref_A_scale.data(); -+ params_.ptr_bias = args.ref_A_bias.data(); -+ params_.ptr_C = args.ref_C.data(); -+ params_.ptr_D = args.ref_D.data(); -+ params_.output_op = args.output_op; -+ params_.semaphore = static_cast(workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(32 * kWarpCount, 1, 1); -+ -+ int smem_size = int(sizeof(typename ImplicitGemmFusionKernel::SharedStorage)); -+ -+ cutlass::Kernel<<>>(params_); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} -+} -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d.h -new file mode 100644 -index 0000000..cb7980b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d.h -@@ -0,0 +1,272 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions for threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/conv/threadblock/threadblock_swizzle.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_with_reduction.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" -+#include "cutlass/conv/threadblock/implicit_gemm_pipelined.h" -+#include "cutlass/conv/threadblock/implicit_gemm_multistage.h" -+#include "cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h" -+#include "cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h" -+#include "cutlass/conv/kernel/implicit_gemm_convolution.h" -+#include "cutlass/conv/kernel/implicit_gemm_convolution_fusion.h" -+#include "cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template < -+ typename ArchTag, -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename OutputOp -+> -+struct DefaultConvEpilogue { -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputOp, -+ OutputOp::kCount -+ >::Epilogue; -+}; -+ -+template < -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename OutputOp -+> -+struct DefaultConvEpilogue< -+ arch::Sm70, -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputOp -+> { -+ -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputOp, -+ OutputOp::kCount -+ >::Epilogue; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ArchTag, -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename ElementOutput, -+ typename ElementTensor, -+ typename ElementVector, -+ typename OutputOp, -+ int ElementsPerAccess -+> -+struct DefaultConvEpilogueWithBroadcastTensorOp { -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ ElementOutput, -+ 
ElementTensor, -+ ElementVector, -+ OutputOp, -+ ElementsPerAccess -+ >::Epilogue; -+}; -+ -+template < -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename ElementOutput, -+ typename ElementTensor, -+ typename ElementVector, -+ typename OutputOp, -+ int ElementsPerAccess -+> -+struct DefaultConvEpilogueWithBroadcastTensorOp< -+ arch::Sm70, -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ ElementOutput, -+ ElementTensor, -+ ElementVector, -+ OutputOp, -+ ElementsPerAccess -+ > { -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ ElementOutput, -+ ElementTensor, -+ ElementVector, -+ OutputOp, -+ ElementsPerAccess -+ >::Epilogue; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ArchTag, -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename ElementOutput, -+ typename OutputOp, -+ typename ReductionOp, -+ int ElementsPerAccess -+> -+struct DefaultConvEpilogueWithReductionTensorOp { -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ ElementsPerAccess -+ >::Epilogue; -+}; -+ -+template < -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename ElementOutput, -+ typename OutputOp, -+ typename ReductionOp, -+ int ElementsPerAccess -+> -+struct DefaultConvEpilogueWithReductionTensorOp< -+ arch::Sm70, -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ ElementsPerAccess -+ > { -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithReductionVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ ElementsPerAccess -+ >::Epilogue; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Defaults for strided Dgrad -+template < -+ typename ArchTag, -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename OutputOp -+> -+struct DefaultConvEpilogueStridedDgrad { -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputOp, -+ OutputOp::kCount -+ >::Epilogue; -+}; -+ -+template < -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename OutputOp -+> -+struct DefaultConvEpilogueStridedDgrad< -+ arch::Sm70, -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputOp -+> { -+ -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueVoltaTensorOpStridedDgrad< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputOp, -+ OutputOp::kCount -+ >::Epilogue; -+}; -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_dgrad.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_dgrad.h -new file mode 100644 -index 0000000..6a54120 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_dgrad.h -@@ -0,0 +1,1927 @@ 
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Conv2dDgrad -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = 128 / cutlass::sizeof_bits::value, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = 128 / cutlass::sizeof_bits::value -+> struct DefaultConv2dDgrad; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassTensorOp convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Strided and -+// multistage pipeline. 
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport::kStrided, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kStrided, -+ AccessTypeA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ StrideSupport::kStrided, -+ AccessTypeB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+}; -+ -+/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Strided -+// and 2 stage pipeline. 
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport::kStrided, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIteratorStridedDgrad< -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kStrided, -+ AccessTypeA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIteratorStridedDgrad< -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ StrideSupport::kStrided, -+ AccessTypeB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogueStridedDgrad< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Unity Strided -+// and multistage pipeline. 
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport::kUnity, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kUnity, -+ AccessTypeA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ StrideSupport::kUnity, -+ AccessTypeB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+}; -+ -+/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Unity -+// 2 stage pipeline. 
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport::kUnity, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kUnity, -+ AccessTypeA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ StrideSupport::kUnity, -+ AccessTypeB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dDgrad specialization for optimized IteratorAlgorithm Dgrad Unity Strided -+// and multistage pipeline. 
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport::kUnity, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kUnity, -+ AccessTypeA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ StrideSupport::kUnity, -+ AccessTypeB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+}; -+ -+/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided and -+// multistage pipeline. 
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport::kStrided, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kStrided, -+ AccessTypeA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ StrideSupport::kStrided, -+ AccessTypeB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+}; -+ -+/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided -+// and 2 stage pipeline. 
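Each DefaultConv2dDgrad specialization in this file plugs a different iterator/pipeline combination into the same parameter list, so composing a concrete dgrad operator amounts to filling that list in and taking the nested Kernel type. A sketch of one such instantiation, fed into the device-level ImplicitGemmConvolution wrapper from earlier in this patch; every concrete element type, tile shape and swizzle below is an illustrative assumption rather than a configuration defined in this file:

    #include "cutlass/conv/kernel/default_conv2d_dgrad.h"
    #include "cutlass/conv/device/implicit_gemm_convolution.h"
    #include "cutlass/epilogue/thread/linear_combination.h"
    #include "cutlass/gemm/threadblock/threadblock_swizzle.h"
    #include "cutlass/layout/tensor.h"

    // Illustrative choices only: FP16 tensors, FP32 accumulation, Ampere tensor cores.
    using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
        cutlass::half_t, 8, float, float>;             // output, access width, accum, compute

    using DgradKernel = cutlass::conv::kernel::DefaultConv2dDgrad<
        cutlass::half_t, cutlass::layout::TensorNHWC,  // ElementA/LayoutA: output gradient
        cutlass::half_t, cutlass::layout::TensorNHWC,  // ElementB/LayoutB: filter
        cutlass::half_t, cutlass::layout::TensorNHWC,  // ElementC/LayoutC: input gradient
        float,                                         // ElementAccumulator
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<128, 128, 32>,        // ThreadblockShape
        cutlass::gemm::GemmShape<64, 64, 32>,          // WarpShape
        cutlass::gemm::GemmShape<16, 8, 16>,           // InstructionShape
        EpilogueOp,
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
        3,                                             // Stages
        cutlass::arch::OpMultiplyAdd,                  // MathOperatorTag
        cutlass::conv::IteratorAlgorithm::kOptimized,
        cutlass::conv::StrideSupport::kUnity>::Kernel;

    using DgradOp = cutlass::conv::device::ImplicitGemmConvolution<DgradKernel>;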
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport::kStrided, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIteratorStridedDgrad< -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kStrided, -+ AccessTypeA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIteratorStridedDgrad< -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ StrideSupport::kStrided, -+ AccessTypeB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogueStridedDgrad< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+}; -+ -+/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Unity -+// 2 stage pipeline -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ 
ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport::kUnity, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kUnity, -+ AccessTypeA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ StrideSupport::kUnity, -+ AccessTypeB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassSimt convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm, -+/// multi-stage pipeline, and FFMA-based mainloop for SM80 -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ conv::StrideSupport::kUnity, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the 
core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ conv::StrideSupport::kUnity -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ conv::StrideSupport::kUnity -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ conv::StrideSupport::kStrided, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ conv::StrideSupport::kStrided -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using 
IteratorB = -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ conv::StrideSupport::kStrided -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm, -+/// multi-stage pipeline, and FFMA-based mainloop for SM80 -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport::kUnity, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kUnity -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ StrideSupport::kUnity -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; 
-+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ conv::StrideSupport::kStrided, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ conv::StrideSupport::kStrided -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ conv::StrideSupport::kStrided -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+ -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm, -+/// 2 stage pipeline, and FFMA-based mainloop for SM50 -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, 
-+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ conv::StrideSupport::kUnity, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ conv::StrideSupport::kUnity -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ conv::StrideSupport::kUnity -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+ -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ conv::StrideSupport::kStrided, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ 
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIteratorStridedDgrad< -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ conv::StrideSupport::kStrided -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIteratorStridedDgrad< -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ conv::StrideSupport::kStrided -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm, -+/// 2 stage pipeline, and FFMA-based mainloop for SM50 -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport::kUnity, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kUnity -+ > -+ >; -+ -+ using SmemIteratorA = typename 
MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ StrideSupport::kUnity -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ conv::StrideSupport::kStrided, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIteratorStridedDgrad< -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ conv::StrideSupport::kStrided -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIteratorStridedDgrad< -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ conv::StrideSupport::kStrided -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, 
-+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+ -+}; -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop.h -new file mode 100644 -index 0000000..3e16d17 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop.h -@@ -0,0 +1,1989 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h" -+ -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Conv2dFprop -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = 128 / cutlass::sizeof_bits::value, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = 128 / cutlass::sizeof_bits::value -+> struct DefaultConv2dFprop; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassTensorOp convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -+/// pipeline. 
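In the primary DefaultConv2dFprop declaration above, the defaulted access-granularity arguments read, in upstream CUTLASS, 128 / cutlass::sizeof_bits<ElementA>::value and 128 / cutlass::sizeof_bits<ElementB>::value, i.e. the number of elements covered by one 128-bit global-memory access. A minimal compile-time check of that rule, assuming the upstream definition of sizeof_bits:

#include <cstdint>
#include "cutlass/numeric_types.h"

// Default alignment = elements per 128-bit (16-byte) vectorized access.
static_assert(128 / cutlass::sizeof_bits<cutlass::half_t>::value == 8, "8 half_t per access");
static_assert(128 / cutlass::sizeof_bits<int8_t>::value == 16, "16 int8 per access");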
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA, -+ AccessTypeA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB, -+ AccessTypeB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -+/// pipeline. 
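The AccessTypeA / AccessTypeB aliases in these specializations are aligned vectors of AlignmentA / AlignmentB elements; upstream they are written cutlass::AlignedArray<ElementA, AlignmentA> and cutlass::AlignedArray<ElementB, AlignmentB>. A small sketch, assuming the upstream AlignedArray from cutlass/array.h, with a hypothetical half_t, 8-wide instantiation:

#include "cutlass/array.h"
#include "cutlass/numeric_types.h"

// Hypothetical instantiation: 8 x half_t forms one 128-bit, 16-byte-aligned access.
using AccessTypeA = cutlass::AlignedArray<cutlass::half_t, 8>;
static_assert(sizeof(AccessTypeA) == 16, "one access spans 128 bits");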
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kFixedChannels, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFixedChannels< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA, -+ AccessTypeA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFixedChannels< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB, -+ AccessTypeB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and two stage -+/// pipeline. 
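The CacheOpB selection used by the multistage specializations keys on whether a single B-operand access covers a full 128 bits (upstream the condition is sizeof_bits<ElementB>::value * AlignmentB == 128): full 128-bit accesses take the Global cache hint for cp.async, narrower ones fall back to Always. A self-contained restatement of that rule as a sketch:

// Sketch of the cache-hint rule (assumption: mirrors the upstream condition
// sizeof_bits<ElementB>::value * AlignmentB == 128).
constexpr bool use_global_cache_hint(int element_bits, int alignment_elements) {
  return element_bits * alignment_elements == 128;
}
static_assert(use_global_cache_hint(16, 8), "half_t with an 8-wide access is a full 128-bit load");
static_assert(!use_global_cache_hint(16, 4), "narrower accesses keep CacheOperation::Always");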
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kFixedChannels, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFixedChannels< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA, -+ AccessTypeA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFixedChannels< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB, -+ AccessTypeB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -+/// pipeline. 
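The kFixedChannels and kFewChannels specializations exist for activations whose channel count is too small for the general Optimized path: roughly, kFixedChannels targets the case where C exactly equals the vector access width, and kFewChannels a small C that need not match it. A hedged sketch of that selection logic follows; the helper is hypothetical and not a CUTLASS API.

#include "cutlass/conv/convolution.h"

// Hypothetical helper: rough heuristic only, based on how the upstream examples
// use these iterator algorithms; not part of CUTLASS.
inline cutlass::conv::IteratorAlgorithm choose_fprop_algorithm(int channels, int access_width) {
  if (channels == access_width) {
    return cutlass::conv::IteratorAlgorithm::kFixedChannels;  // whole C in one access
  }
  if (channels < access_width) {
    return cutlass::conv::IteratorAlgorithm::kFewChannels;    // sub-access channel count
  }
  return cutlass::conv::IteratorAlgorithm::kOptimized;        // general fast path
}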
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kFewChannels, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFewChannels< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA, -+ AccessTypeA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFewChannels< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB, -+ AccessTypeB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -+/// pipeline. 
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kFewChannels, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFewChannels< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA, -+ AccessTypeA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFewChannels< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB, -+ AccessTypeB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -+/// pipeline with interleaved layout. 
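The interleaved specializations below are keyed on layout::TensorNCxHWx activations and layout::TensorCxRSKx filters; upstream, both layouts are parameterized by the interleave factor InterleavedK. A minimal sketch of the common int8 NC/32HW32 case, assuming the upstream layout types in cutlass/layout/tensor.h:

#include "cutlass/layout/tensor.h"

// Assumption: kInterleavedK = 32 is illustrative (the int8 NC/32HW32 format).
static int const kInterleavedK = 32;
using LayoutActivation = cutlass::layout::TensorNCxHWx<kInterleavedK>;
using LayoutFilter     = cutlass::layout::TensorCxRSKx<kInterleavedK>;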
-+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB, -+ int InterleavedK -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ Stages, MathOperatorTag, true>; -+ -+ // Define iterators over tiles from the A operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapA = typename MmaCore::SmemThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, layout::TensorNCxHWx, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. 
-+ using ThreadMapB = typename MmaCore::SmemThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Global, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm -+/// and 2 stage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA, -+ AccessTypeA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB, -+ AccessTypeB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename 
MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage -+/// pipeline with interleaved layout. -+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB, -+ int InterleavedK -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ 2, MathOperatorTag, true>; -+ -+ // Define iterators over tiles from the A operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapA = typename MmaCore::SmemThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, layout::TensorNCxHWx, -+ ThreadMapA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. 
-+ using ThreadMapB = typename MmaCore::SmemThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and -+/// multistage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag -+ >; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ LayoutA, -+ ThreadMapA, -+ AccessTypeA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ LayoutB, -+ ThreadMapB, -+ AccessTypeB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ 
-+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and -+// multistage pipeline with interleaved layout. -+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB, -+ int InterleavedK -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ Stages, MathOperatorTag, true -+ >; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::SmemThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ layout::TensorNCxHWx, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::SmemThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ layout::TensorCxRSKx, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Global, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // 
Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm -+/// and 2 stage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ LayoutA, -+ ThreadMapA, -+ AccessTypeA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ LayoutB, -+ ThreadMapB, -+ AccessTypeB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ 
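For context, the DefaultConv2dFprop specializations above are not used directly: each one only exposes a kernel-level `Kernel` type, which callers wrap in the device-level `cutlass::conv::device::ImplicitGemmConvolution` template. The sketch below follows the pattern of CUTLASS's own fprop examples; the element types, NHWC layouts, tile shapes, epilogue, swizzle, stage count, and SM80 target are illustrative assumptions made for this note and are not part of this patch.

// Sketch: wrapping a DefaultConv2dFprop kernel in the device-level API
// (illustrative types and tile sizes; assumes an SM80 target and the standard CUTLASS headers).
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"

using ElementA   = cutlass::half_t;
using ElementB   = cutlass::half_t;
using ElementC   = cutlass::half_t;
using ElementAcc = float;

// Epilogue: D = alpha * accumulator + beta * C, vectorized to 128-bit accesses.
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementC, 128 / cutlass::sizeof_bits<ElementC>::value, ElementAcc, ElementAcc>;

using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAcc,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,   // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,     // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,      // tensor core instruction shape
    EpilogueOp,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,                                        // multistage pipeline depth
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized
>::Kernel;

// Device-level handle used to configure and launch the implicit GEMM convolution.
using ImplicitGemmFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;

At run time the wrapper is fed a `cutlass::conv::Conv2dProblemSize` plus tensor references and alpha/beta, in the same way as the fprop examples shipped with CUTLASS.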
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage -+/// pipeline with interleaved layout. -+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB, -+ int InterleavedK -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ 2, MathOperatorTag, true>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::SmemThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, layout::TensorNCxHWx, -+ ThreadMapA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::SmemThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassSimt convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm, -+/// multi-stage pipeline, and FFMA-based mainloop for SM80 -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ 
typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm, -+/// multi-stage pipeline, and FFMA-based mainloop for SM80 -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ 
IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ LayoutA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ LayoutB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm, -+/// 2 stage pipeline, and FFMA-based mainloop for SM50 -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ 
ElementA, LayoutA, -+ ThreadMapA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm, -+/// 2 stage pipeline, and FFMA-based mainloop for SM50 -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ LayoutA, -+ ThreadMapA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ LayoutB, -+ ThreadMapB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // 
Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h -new file mode 100644 -index 0000000..da48878 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h -@@ -0,0 +1,357 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+ Default kernel-level fused activation's scale+bias+relu and implicit GEMM convolution -+ definitions that combine threadblock-scoped matrix multiply-add with the -+ appropriate threadblock-scoped epilogue. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h" -+#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h" -+#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for fused batch norm and Conv2dFprop -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementScaleBias, -+ typename LayoutScaleBias, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided -+> struct DefaultConv2dFpropFusion; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassTensorOp convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -+/// pipeline. 
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementScaleBias, -+ typename LayoutScaleBias, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv2dFpropFusion < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementScaleBias, -+ LayoutScaleBias, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using IteratorScaleBias = -+ cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< -+ cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, -+ LayoutScaleBias>; -+ -+ using SmemIteratorScaleBias = -+ cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< -+ cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, -+ LayoutScaleBias>; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static int const kThreadCount = 32; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< -+ MatrixShape, ElementScaleBias, -+ LayoutScaleBias, MatrixShape, -+ typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, -+ MmaCore::WarpCount::kK>; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmFpropFusionMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Global, -+ IteratorScaleBias, -+ SmemIteratorScaleBias, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ WarpIteratorScaleBias, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ 
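The fused scale+bias+relu path introduced by this header is consumed the same way: DefaultConv2dFpropFusion yields a kernel that is wrapped by the device-level ImplicitGemmConvolutionFusion template. The sketch below mirrors CUTLASS's mainloop-fusion example; the element types, per-channel scale/bias layout, tile shapes, stage count, and SM80 target are illustrative assumptions, and the device-side header path is taken from the upstream CUTLASS tree rather than from this patch.

// Sketch: instantiating the fused fprop kernel defined by DefaultConv2dFpropFusion
// (illustrative types and tile sizes; assumes an SM80 target).
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/conv/kernel/default_conv2d_fprop_fusion.h"
#include "cutlass/conv/device/implicit_gemm_convolution_fusion.h"

using ElementA         = cutlass::half_t;
using ElementB         = cutlass::half_t;
using ElementScaleBias = cutlass::half_t;     // per-channel scale and bias vectors
using ElementC         = cutlass::half_t;
using ElementAcc       = float;

// Epilogue applies the linear combination followed by a ReLU clamp.
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu<
    ElementC, 128 / cutlass::sizeof_bits<ElementC>::value, ElementAcc, ElementAcc>;

using Conv2dFpropFusionKernel = typename cutlass::conv::kernel::DefaultConv2dFpropFusion<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementScaleBias, cutlass::layout::RowMajor,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAcc,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,   // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,     // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,      // tensor core instruction shape
    EpilogueOp,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,                                        // multistage pipeline depth
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kOptimized
>::Kernel;

// Device-level handle; its arguments additionally carry references to the
// scale and bias vectors consumed by the fused mainloop.
using ImplicitGemmFusion =
    cutlass::conv::device::ImplicitGemmConvolutionFusion<Conv2dFpropFusionKernel>;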
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and -+/// multistage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementScaleBias, -+ typename LayoutScaleBias, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv2dFpropFusion < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementScaleBias, -+ LayoutScaleBias, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag -+ >; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ LayoutA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ LayoutB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using IteratorScaleBias = -+ cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< -+ cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, -+ LayoutScaleBias>; -+ -+ using SmemIteratorScaleBias = -+ cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< -+ cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, -+ LayoutScaleBias>; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static int const kThreadCount = 32; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< -+ MatrixShape, ElementScaleBias, -+ LayoutScaleBias, MatrixShape, -+ typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, -+ MmaCore::WarpCount::kK>; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmFpropFusionMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Global, -+ IteratorScaleBias, -+ SmemIteratorScaleBias, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ WarpIteratorScaleBias, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, 
-+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h -new file mode 100644 -index 0000000..d744ae8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h -@@ -0,0 +1,130 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Defines a GEMM with Reduction based on an existing UniversalGemm kernel. 
-+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" -+#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = 128 / cutlass::sizeof_bits::value, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = 128 / cutlass::sizeof_bits::value -+> -+struct DefaultConv2dFpropWithBroadcast { -+ -+ using ImplicitGemmBase = typename DefaultConv2dFprop< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+ >::Kernel; -+ -+ // Replace epilogue -+ using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastTensorOp< -+ ArchTag, -+ typename ImplicitGemmBase::Epilogue::Shape, -+ typename ImplicitGemmBase::Epilogue::WarpMmaOperator, -+ ImplicitGemmBase::Epilogue::kPartitionsK, -+ ElementC, -+ typename EpilogueOutputOp::ElementT, -+ ElementC, -+ EpilogueOutputOp, -+ ImplicitGemmBase::Epilogue::kElementsPerAccess -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< -+ typename ImplicitGemmBase::Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h -new file mode 100644 -index 0000000..00b8c90 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h -@@ -0,0 +1,130 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Defines a GEMM with Reduction based on an existing UniversalGemm kernel. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_with_reduction.h" -+#include "cutlass/epilogue/threadblock/epilogue_with_reduction.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename EpilogueReductionOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = 128 / cutlass::sizeof_bits::value, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = 128 / cutlass::sizeof_bits::value -+> -+struct DefaultConv2dFpropWithReduction { -+ -+ using ImplicitGemmBase = typename DefaultConv2dFprop< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+ >::Kernel; -+ -+ // Replace epilogue -+ using Epilogue = typename 
cutlass::conv::kernel::detail::DefaultConvEpilogueWithReductionTensorOp< -+ ArchTag, -+ typename ImplicitGemmBase::Epilogue::Shape, -+ typename ImplicitGemmBase::Epilogue::WarpMmaOperator, -+ ImplicitGemmBase::Epilogue::kPartitionsK, -+ ElementC, -+ EpilogueOutputOp, -+ EpilogueReductionOp, -+ ImplicitGemmBase::Epilogue::kElementsPerAccess -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< -+ typename ImplicitGemmBase::Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_group_fprop.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_group_fprop.h -new file mode 100644 -index 0000000..cdd89e0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_group_fprop.h -@@ -0,0 +1,490 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h" -+ -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Conv2dGroupFpro -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::GroupMode GroupMode, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = 128 / cutlass::sizeof_bits::value, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = 128 / cutlass::sizeof_bits::value -+> struct DefaultConv2dGroupFprop; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassTensorOp convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dGroupFprop specialization for Analytic IteratorAlgorithm and multistage -+/// pipeline that supports all GroupMode. 
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::GroupMode GroupMode, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dGroupFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ GroupMode, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ static_assert(std::is_same::value, -+ "Current group conv only support NHWC layout"); -+ static_assert(std::is_same::value, -+ "Current group conv only support NHWC layout"); -+ static_assert(std::is_same::value, -+ "Current group conv only support NHWC layout"); -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA, -+ AccessTypeA, -+ GroupMode -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB, -+ AccessTypeB, -+ GroupMode -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? 
cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv2dProblemSize, -+ GroupMode -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dGroupFprop specialization for Optimized IteratorAlgorithm and multistage -+/// pipeline that supports GroupMode::kSingleGroup. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dGroupFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ GroupMode::kSingleGroup, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ static_assert(std::is_same::value, -+ "Current group conv only support NHWC layout"); -+ static_assert(std::is_same::value, -+ "Current group conv only support NHWC layout"); -+ static_assert(std::is_same::value, -+ "Current group conv only support NHWC layout"); -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA, -+ AccessTypeA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB, -+ AccessTypeB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static 
cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv2dProblemSize, -+ GroupMode::kSingleGroup -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dGroupFprop specialization for Optimized IteratorAlgorithm and -+/// 2 stage pipeline that supports GroupMode::kSingleGroup. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dGroupFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ GroupMode::kSingleGroup, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ static_assert(std::is_same::value, -+ "Current group conv only support NHWC layout"); -+ static_assert(std::is_same::value, -+ "Current group conv only support NHWC layout"); -+ static_assert(std::is_same::value, -+ "Current group conv only support NHWC layout"); -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ LayoutA, -+ ThreadMapA, -+ AccessTypeA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ LayoutB, -+ ThreadMapB, -+ AccessTypeB -+ > -+ >; -+ -+ using SmemIteratorB = typename 
MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv2dProblemSize, -+ GroupMode::kSingleGroup -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad.h -new file mode 100644 -index 0000000..099bb6c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad.h -@@ -0,0 +1,1011 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = 128 / cutlass::sizeof_bits::value, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = 128 / cutlass::sizeof_bits::value -+> struct DefaultConv2dWgrad; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassTensorOp convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm and multistage -+// pipeline. 
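The AlignmentA and AlignmentB defaults in the DefaultConv2dWgrad declaration above are expressed in elements and are sized so that a single access spans 128 bits. A minimal sketch of that arithmetic, assuming cutlass::half_t and float operands purely for illustration (neither element type is taken from this patch):

    // Assumes CUTLASS 2.x; cutlass::sizeof_bits is available via cutlass/numeric_types.h.
    #include "cutlass/numeric_types.h"

    // 16-bit operands default to 8-element accesses and 32-bit operands to 4-element
    // accesses, so both resolve to one 128-bit vectorized load/store.
    static_assert(128 / cutlass::sizeof_bits<cutlass::half_t>::value == 8, "8 x 16b = 128b");
    static_assert(128 / cutlass::sizeof_bits<float>::value == 4, "4 x 32b = 128b");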
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ AccessTypeA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ AccessTypeB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm and two -+// pipeline. 
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ AccessTypeA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ AccessTypeB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm and multistage -+// pipeline. 
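The specializations in this file come in pairs: a generic Stages parameter selects the ImplicitGemmMultistage mainloop (conventionally used with cp.async on SM80 and newer), while the Stages == 2 partial specializations, such as the one above, fall back to the double-buffered ImplicitGemmPipelined mainloop. The helper below is hypothetical, shown only to make that convention concrete; it is not part of CUTLASS or of this patch.

    #include "cutlass/arch/arch.h"
    #include "cutlass/platform/platform.h"

    // Hypothetical: map an architecture tag to the stage count conventionally paired with it.
    template <typename ArchTag>
    struct DefaultStageCount : cutlass::platform::integral_constant<int, 2> {};  // two-stage pipeline

    template <>
    struct DefaultStageCount<cutlass::arch::Sm80>
        : cutlass::platform::integral_constant<int, 3> {};                       // multistage mainloop

    // DefaultStageCount<cutlass::arch::Sm75>::value == 2
    // DefaultStageCount<cutlass::arch::Sm80>::value == 3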
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ AccessTypeA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ AccessTypeB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm and two -+// pipeline. 
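To show how the TensorOp wgrad specializations above are consumed, here is an illustrative instantiation. Every concrete choice (half-precision NHWC operands, SM80 tile shapes, three stages, the linear-combination epilogue) is an assumption made for the example; only the parameter order comes from the DefaultConv2dWgrad declaration in this file.

    #include "cutlass/conv/kernel/default_conv2d_wgrad.h"
    #include "cutlass/conv/device/implicit_gemm_convolution.h"
    #include "cutlass/epilogue/thread/linear_combination.h"
    #include "cutlass/gemm/threadblock/threadblock_swizzle.h"
    // Element, layout, shape and arch tags come in transitively through the headers above.

    // Kernel-level type: operand A is the output gradient (dy), operand B the activation (x),
    // and C/D the filter gradient (dw) accumulated in float.
    using WgradKernel = cutlass::conv::kernel::DefaultConv2dWgrad<
        cutlass::half_t, cutlass::layout::TensorNHWC,   // ElementA, LayoutA
        cutlass::half_t, cutlass::layout::TensorNHWC,   // ElementB, LayoutB
        float,           cutlass::layout::TensorNHWC,   // ElementC, LayoutC
        float,                                          // ElementAccumulator
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<128, 128, 32>,         // threadblock tile
        cutlass::gemm::GemmShape<64, 64, 32>,           // warp tile
        cutlass::gemm::GemmShape<16, 8, 16>,            // tensor core instruction
        cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
        3,                                              // Stages: picks the multistage mainloop
        cutlass::arch::OpMultiplyAdd
    >::Kernel;

    // Device-level launcher wrapping the kernel.
    using WgradOp = cutlass::conv::device::ImplicitGemmConvolution<WgradKernel>;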
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ AccessTypeA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ AccessTypeB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassSimt convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm, -+/// multi-stage pipeline, and FFMA-based mainloop for SM80 -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename 
ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AccessTypeA, -+ int AccessTypeB -+> -+struct DefaultConv2dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AccessTypeA, -+ AccessTypeB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad -+ >; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm, -+/// multi-stage pipeline, and FFMA-based mainloop for SM80 -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AccessTypeA, -+ int AccessTypeB -+> -+struct DefaultConv2dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AccessTypeA, -+ AccessTypeB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, 
InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad -+ >; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm, -+/// 2 stage pipeline, and FFMA-based mainloop for SM50 -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AccessTypeA, -+ int AccessTypeB -+> -+struct DefaultConv2dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AccessTypeA, -+ AccessTypeB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ 
cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm, -+/// 2 stage pipeline, and FFMA-based mainloop for SM50 -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AccessTypeA, -+ int AccessTypeB -+> -+struct DefaultConv2dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AccessTypeA, -+ AccessTypeB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename 
epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad -+ >; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h -new file mode 100644 -index 0000000..62bf177 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h -@@ -0,0 +1,325 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" -+#include "cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementScaleBias, -+ typename LayoutScaleBias, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided -+> struct DefaultConv2dWgradFusion; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassTensorOp convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm and multistage -+// pipeline. 
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementScaleBias, -+ typename LayoutScaleBias, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv2dWgradFusion < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementScaleBias, -+ LayoutScaleBias, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using IteratorScaleBias = -+ cutlass::conv::threadblock::PredicatedScaleBiasVectorIterator< -+ cutlass::MatrixShape<1, WarpShape::kN>, -+ ElementScaleBias, -+ LayoutScaleBias>; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmWgradFusionMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ IteratorScaleBias, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm and multistage -+// pipeline. 
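Relative to DefaultConv2dWgrad, the fused variant defined above adds only the ElementScaleBias and LayoutScaleBias parameters, which describe the per-channel scale and bias vectors streamed into the mainloop through PredicatedScaleBiasVectorIterator. An illustrative instantiation follows; the scale/bias element and layout chosen here (half-precision, RowMajor) are assumptions just like the other concrete types, and the remaining includes match the earlier wgrad sketch.

    #include "cutlass/conv/kernel/default_conv2d_wgrad_fusion.h"

    using FusedWgradKernel = cutlass::conv::kernel::DefaultConv2dWgradFusion<
        cutlass::half_t, cutlass::layout::TensorNHWC,   // ElementA, LayoutA: output gradient
        cutlass::half_t, cutlass::layout::TensorNHWC,   // ElementB, LayoutB: activation
        cutlass::half_t, cutlass::layout::RowMajor,     // ElementScaleBias, LayoutScaleBias
        float,           cutlass::layout::TensorNHWC,   // ElementC, LayoutC: filter gradient
        float,
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<128, 128, 32>,
        cutlass::gemm::GemmShape<64, 64, 32>,
        cutlass::gemm::GemmShape<16, 8, 16>,
        cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
        4,                                              // Stages: fusion uses the multistage mainloop
        cutlass::arch::OpMultiplyAdd
    >::Kernel;                                          // an ImplicitGemmConvolutionFusion kernel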
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementScaleBias, -+ typename LayoutScaleBias, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv2dWgradFusion < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementScaleBias, -+ LayoutScaleBias, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using IteratorScaleBias = -+ cutlass::conv::threadblock::PredicatedScaleBiasVectorIterator< -+ cutlass::MatrixShape<1, WarpShape::kN>, -+ ElementScaleBias, -+ LayoutScaleBias>; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmWgradFusionMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ IteratorScaleBias, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_dgrad.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_dgrad.h -new file mode 100644 -index 0000000..01e895c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_dgrad.h -@@ -0,0 +1,303 @@ 
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h" -+ -+#include "cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Conv3dDgrad -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided -+> struct DefaultConv3dDgrad; -+ -+/// Defines a kernel for Conv3dDgrad specialization for Analytic IteratorAlgorithm Dgrad Strided -+// and multistage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv3dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport::kStrided -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kStrided -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = 
threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Global, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad, -+ Conv3dProblemSize -+ >; -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv3dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided -+// and multistage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv3dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport::kUnity -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kUnity -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ -+ using IteratorB = -+ cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Global, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad, -+ Conv3dProblemSize -+ >; -+}; -+ -+ 
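The two DefaultConv3dDgrad specializations above differ in stride support as well as iterator algorithm: the analytic path handles strided dgrad (StrideSupport::kStrided), while the optimized path is restricted to unit stride (StrideSupport::kUnity). Below is an illustrative instantiation of the unit-stride path; the element types, NDHWC layouts and tile shapes are assumptions, and only the parameter order follows the DefaultConv3dDgrad declaration in this file.

    #include "cutlass/conv/kernel/default_conv3d_dgrad.h"
    #include "cutlass/conv/device/implicit_gemm_convolution.h"
    #include "cutlass/epilogue/thread/linear_combination.h"
    #include "cutlass/gemm/threadblock/threadblock_swizzle.h"

    using Dgrad3dKernel = cutlass::conv::kernel::DefaultConv3dDgrad<
        cutlass::half_t, cutlass::layout::TensorNDHWC,  // ElementA, LayoutA: output gradient
        cutlass::half_t, cutlass::layout::TensorNDHWC,  // ElementB, LayoutB: filter
        cutlass::half_t, cutlass::layout::TensorNDHWC,  // ElementC, LayoutC: activation gradient
        float,
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<128, 128, 32>,
        cutlass::gemm::GemmShape<64, 64, 32>,
        cutlass::gemm::GemmShape<16, 8, 16>,
        cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>,
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
        3,
        cutlass::arch::OpMultiplyAdd,
        cutlass::conv::IteratorAlgorithm::kOptimized,
        cutlass::conv::StrideSupport::kUnity            // the optimized 3-D dgrad path assumes unit stride
    >::Kernel;

    using Dgrad3dOp = cutlass::conv::device::ImplicitGemmConvolution<Dgrad3dKernel>;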
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop.h -new file mode 100644 -index 0000000..9c8f8cf ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop.h -@@ -0,0 +1,515 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h" -+ -+ -+#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Conv2dFprop -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided -+> struct DefaultConv3dFprop; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv3dFprop specialization for Analytic Iterator Algorithm -+/// and 2 stage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag -+> -+struct DefaultConv3dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ 
using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv3dProblemSize -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -+// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv3dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Global, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv3dProblemSize -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv3dFprop specialization for Optimized Iterator Algorithm -+/// and 2 stage pipeline. 
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag -+> -+struct DefaultConv3dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ LayoutA, -+ ThreadMapA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ LayoutB, -+ ThreadMapB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv3dProblemSize -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv3dFprop specialization for Optimized IteratorAlgorithm and multistage -+// pipeline. 
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv3dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ LayoutA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ -+ using IteratorB = -+ cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ LayoutB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Global, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv3dProblemSize -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h -new file mode 100644 -index 0000000..66cbbcd ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h -@@ -0,0 +1,360 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level fused activation's scale+bias+relu and implicit GEMM convolution -+ definitions that combine threadblock-scoped matrix multiply-add with the -+ appropriate threadblock-scoped epilogue. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h" -+#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h" -+#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for fused batch norm and Conv3dFprop -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementScaleBias, -+ typename LayoutScaleBias, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided -+> struct DefaultConv3dFpropFusion; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassTensorOp convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv3dFprop specialzation for Analytic IteratorAlgorithm and multistage -+/// pipeline. 
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementScaleBias, -+ typename LayoutScaleBias, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv3dFpropFusion < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementScaleBias, -+ LayoutScaleBias, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using IteratorScaleBias = -+ cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< -+ cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, -+ LayoutScaleBias>; -+ -+ using SmemIteratorScaleBias = -+ cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< -+ cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, -+ LayoutScaleBias>; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static int const kThreadCount = 32; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< -+ MatrixShape, ElementScaleBias, -+ LayoutScaleBias, MatrixShape, -+ typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, -+ MmaCore::WarpCount::kK>; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmFpropFusionMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Global, -+ IteratorScaleBias, -+ SmemIteratorScaleBias, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ WarpIteratorScaleBias, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv3dProblemSize -+ >; -+}; -+ 
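As a point of reference for how the fused kernel defined above is consumed, the following is a minimal, illustrative sketch and is not part of the vendored patch: it composes DefaultConv3dFpropFusion into a device-level operator in the same way the existing CUTLASS Conv2d fusion examples do. The FP16/NDHWC element types, RowMajor scale/bias layout, SM80 tile shapes, stage count, and LinearCombinationRelu epilogue are assumptions chosen so that the Analytic, multistage specialization above is the one selected.

// Illustrative usage sketch (assumed types and tile shapes; not part of this patch).
#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv3d_fprop_fusion.h"
#include "cutlass/conv/device/implicit_gemm_convolution_fusion.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"

// FP16 activations/filters in NDHWC, FP32 accumulation, FP16 per-channel scale/bias vectors.
using ElementA           = cutlass::half_t;
using ElementB           = cutlass::half_t;
using ElementScaleBias   = cutlass::half_t;
using ElementC           = cutlass::half_t;
using ElementAccumulator = float;

using Conv3dFpropFusionKernel = typename cutlass::conv::kernel::DefaultConv3dFpropFusion<
    ElementA, cutlass::layout::TensorNDHWC,
    ElementB, cutlass::layout::TensorNDHWC,
    ElementScaleBias, cutlass::layout::RowMajor,         // scale/bias vectors
    ElementC, cutlass::layout::TensorNDHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,              // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,                // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,                 // tensor-op instruction shape
    cutlass::epilogue::thread::LinearCombinationRelu<
        ElementC, 8, ElementAccumulator, ElementAccumulator>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,                                                   // Stages >= 3 uses the multistage mainloop
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kAnalytic          // matches the specialization above
>::Kernel;

// Device-level wrapper that launches the fused scale+bias+relu fprop kernel.
using Conv3dFpropFusion =
    cutlass::conv::device::ImplicitGemmConvolutionFusion<Conv3dFpropFusionKernel>;

At runtime the wrapper's Arguments would be filled with a Conv3dProblemSize and device tensor references for the activations, filters, scale and bias vectors, and the source/destination tensors, mirroring how the Conv2d fusion kernels are driven elsewhere in CUTLASS.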
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv3dFprop specialzation for Optimzed IteratorAlgorithm and -+/// multistage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementScaleBias, -+ typename LayoutScaleBias, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv3dFpropFusion < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementScaleBias, -+ LayoutScaleBias, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag -+ >; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ LayoutA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ LayoutB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using IteratorScaleBias = -+ cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< -+ cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, -+ LayoutScaleBias>; -+ -+ using SmemIteratorScaleBias = -+ cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< -+ cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, -+ LayoutScaleBias>; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static int const kThreadCount = 32; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< -+ MatrixShape, ElementScaleBias, -+ LayoutScaleBias, MatrixShape, -+ typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, -+ MmaCore::WarpCount::kK>; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmFpropFusionMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Global, -+ IteratorScaleBias, -+ SmemIteratorScaleBias, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ WarpIteratorScaleBias, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ 
EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv3dProblemSize -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_wgrad.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_wgrad.h -new file mode 100644 -index 0000000..3807911 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_wgrad.h -@@ -0,0 +1,509 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided -+> struct DefaultConv3dWgrad; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv3dWgrad specialization for Analytic IteratorAlgorithm and multistage -+// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv3dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = 
threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad, -+ Conv3dProblemSize -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Conv3dWgrad specialization for Analytic IteratorAlgorithm and two -+// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag -+> -+struct DefaultConv3dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad, -+ Conv3dProblemSize -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv3dWgrad specialization for Optimized 
IteratorAlgorithm and multistage -+// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv3dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad, -+ Conv3dProblemSize -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Conv3dWgrad specialization for Optimized IteratorAlgorithm and two -+// pipeline. 
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag -+> -+struct DefaultConv3dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad, -+ Conv3dProblemSize -+ >; -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_depthwise_fprop.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_depthwise_fprop.h -new file mode 100644 -index 0000000..df57e30 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_depthwise_fprop.h -@@ -0,0 +1,588 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level Depthwise implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+#include "cutlass/conv/kernel/direct_convolution.h" -+ -+#include "cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h" -+ -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/depthwise_fprop_pipelined.h" -+ -+// Direct Conv Related Header files -+#include "cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h" -+#include "cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h" -+ -+#include "cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h" -+#include "cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for DepthwiseFprop -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided, -+ /// Access granularity of A matrix in units 
of elements -+ int AlignmentA = 128 / cutlass::sizeof_bits::value, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = cutlass::sizeof_bits::value / cutlass::sizeof_bits::value -+> struct DefaultDepthwiseFprop; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for DepthwiseFprop with direct convolution algorithm -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename ThreadBlockOutputShape, -+ typename FilterShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided, -+ // MatrixShape -+ typename StrideShape = cutlass::MatrixShape<-1, -1>, -+ // MatrixShape< Height, Width> -+ typename DilationShape = cutlass::MatrixShape<-1, -1>, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = 128 / cutlass::sizeof_bits::value, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = 128 / cutlass::sizeof_bits::value -+> struct DefaultDepthwiseDirect2dConvFprop; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassSimt convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Depthwise specialization for Analytic IteratorAlgorithm -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultDepthwiseFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, // cutlass::arch::OpMultiplyAdd -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::conv::threadblock::DepthwiseMmaCoreWithLaneAccessSize< -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ layout::RowMajor, -+ ElementB, -+ layout::ColumnMajor, -+ ElementAccumulator, -+ layout::RowMajor, -+ arch::OpClassSimt, -+ 128, -+ sizeof_bits::value, -+ 2, -+ MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using 
AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB, -+ AccessTypeB, -+ cutlass::conv::GroupMode::kDepthwise -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::DepthwiseFpropPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv2dProblemSize, -+ cutlass::conv::GroupMode::kDepthwise -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Depthwise specialization for direct 2d conv implementation, -+/// multiple stage pipeline, and SIMT-based mainloop -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename ThreadBlockOutputShape, -+ typename FilterShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ typename StrideShape, -+ typename DilationShape, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultDepthwiseDirect2dConvFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ StrideShape, -+ DilationShape, -+ AlignmentA, -+ AlignmentB -+> { -+ // One warp handles the entrie groups per cta. 
-+ static_assert(ThreadblockShape::kN == WarpShape::kN, -+ "ThreadblockShape::kN should be same as WarpShape::kN "); -+ static_assert(ThreadblockShape::kK == FilterShape::kCount && WarpShape::kK == FilterShape::kCount, -+ "ThreadblockShape::kK and WarpShape::kK should be same as filter size"); -+ static_assert(ThreadblockShape::kM % WarpShape::kM == 0, -+ "ThreadblockShape::kM must be divisible by WarpShape shape::kM"); -+ static_assert(ThreadBlockOutputShape::kN, "ThreadBlockOutputShape::kN should be 1"); -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::conv::threadblock::DepthwiseDirectConvMmaCoreWithLaneAccessSize< -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ layout::RowMajor, -+ ElementB, -+ layout::ColumnMajor, -+ ElementAccumulator, -+ layout::RowMajor, -+ arch::OpClassSimt, -+ 128, -+ 128, -+ Stages, -+ MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized< -+ cutlass::MatrixShape, // < outputShape:KMNK, groups per cta> -+ ThreadBlockOutputShape, -+ ElementA, LayoutA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ using ThreadOutputShape = typename MmaCore::ThreadOutputShape; -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * AlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? 
cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultDirectConvEpilogueSimt< -+ ThreadblockShape, // < outputShape:KMNK, groups per cta> -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount, -+ ThreadOutputShape, -+ ThreadBlockOutputShape -+ >::Epilogue; -+ -+ // Define the Mma -+ using Mma = threadblock::DepthwiseFpropDirectConvMultipleStage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ CacheOpA, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages, -+ Epilogue -+ >; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::DirectConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv2dProblemSize, -+ cutlass::conv::GroupMode::kDepthwise, -+ ThreadBlockOutputShape -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Depthwise specialization for direct 2d conv implementation, -+/// multiple stage pipeline, and SIMT-based mainloop -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename ThreadBlockOutputShape, -+ typename FilterShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ typename StrideShape, -+ typename DilationShape, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultDepthwiseDirect2dConvFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kFixedStrideDilation, -+ StrideSupport, -+ StrideShape, -+ DilationShape, -+ AlignmentA, -+ AlignmentB, -+> { -+ -+ -+ -+ // One warp handles the entrie groups per cta. 
-+ static_assert(ThreadblockShape::kN == WarpShape::kN, -+ "ThreadblockShape::kN should be same as WarpShape::kN "); -+ static_assert(ThreadblockShape::kK == FilterShape::kCount && WarpShape::kK == FilterShape::kCount, -+ "ThreadblockShape::kK and WarpShape::kK should be same as filter size"); -+ static_assert(ThreadblockShape::kM % WarpShape::kM == 0, -+ "ThreadblockShape::kM must be divisible by WarpShape shape::kM"); -+ static_assert(ThreadBlockOutputShape::kN, "ThreadBlockOutputShape::kN should be 1"); -+ -+ static_assert(StrideShape::kRow >= 0 && StrideShape::kColumn >= 0, "Stride should be fixed"); -+ static_assert(DilationShape::kRow >= 0 && DilationShape::kColumn >= 0, "Stride should be fixed"); -+ -+ // Activations loaded by threadblock -+ static int const ActivationShapeH = (ThreadBlockOutputShape::kH - 1) * StrideShape::kRow + -+ (FilterShape::kRow - 1) * DilationShape::kRow + 1; -+ -+ static int const ActivationShapeW = (ThreadBlockOutputShape::kW - 1) * StrideShape::kColumn + -+ (FilterShape::kColumn - 1) * DilationShape::kColumn + 1; -+ -+ using ActivationShape = -+ cutlass::conv::TensorNHWCShape<1, ActivationShapeH, ActivationShapeW, ThreadblockShape::kN >; -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::conv::threadblock::DepthwiseDirectConvMmaCoreWithLaneAccessSize< -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ layout::RowMajor, -+ ElementB, -+ layout::ColumnMajor, -+ ElementAccumulator, -+ layout::RowMajor, -+ arch::OpClassSimt, -+ 128, -+ 128, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kFixedStrideDilation, -+ StrideShape, -+ DilationShape, -+ ActivationShape>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation< -+ cutlass::MatrixShape, // < outputShape:KMNK, groups per cta> -+ ThreadBlockOutputShape, -+ StrideShape, -+ DilationShape, -+ ActivationShape, -+ ElementA, LayoutA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ using ThreadOutputShape = typename MmaCore::ThreadOutputShape; -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * AlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? 
cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultDirectConvEpilogueSimt< -+ ThreadblockShape, // < outputShape:KMNK, groups per cta> -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount, -+ ThreadOutputShape, -+ ThreadBlockOutputShape -+ >::Epilogue; -+ -+ // Define the Mma -+ using Mma = threadblock::DepthwiseFpropDirectConvMultipleStage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ CacheOpA, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages, -+ Epilogue, -+ IteratorAlgorithm::kFixedStrideDilation -+ >; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::DirectConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv2dProblemSize, -+ cutlass::conv::GroupMode::kDepthwise, -+ ThreadBlockOutputShape -+ >; -+}; -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/direct_convolution.h b/3rdparty/cutlass/include/cutlass/conv/kernel/direct_convolution.h -new file mode 100644 -index 0000000..ef7a920 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/direct_convolution.h -@@ -0,0 +1,505 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a multi-staged Depthwise Convolution kernel. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters structure -+template > ///! OutputShape per ThreadBlock -+struct DirectConvolutionParams { -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using ThreadBlockOutputShape = ThreadBlockOutputShape_; -+ static Operator const kConvolutionalOperator = ConvOperator; -+ using ConvProblemSize = ConvProblemSize_; -+ using Arguments = Arguments_; -+ using ConvOutputIteratorParameter = ConvOutputIteratorParameter_; -+ -+ using ThreadblockShape = typename Mma::Shape; -+ static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; -+ static conv::GroupMode const kGroupMode = GroupMode_; -+ static int const kStages = Mma::kStages; -+ -+ ConvProblemSize problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ gemm::GemmCoord implicit_gemm_problem_size; -+ int swizzle_log_tile; -+ int smem_size_; -+ -+ int gemm_k_iterations; -+ int gemm_k_iterations_per_channel; -+ typename Mma::IteratorA::Params iterator_A; -+ typename Mma::IteratorA::Element const *ptr_A; -+ typename Mma::IteratorB::Params iterator_B; -+ typename Mma::IteratorB::Element const *ptr_B; -+ typename Mma::IteratorB::Element *ptr_reordered_B; -+ typename Epilogue::OutputTileIterator::Params iterator_C; -+ typename Epilogue::OutputTileIterator::Element *ptr_C; -+ typename Epilogue::OutputTileIterator::Params iterator_D; -+ typename Epilogue::OutputTileIterator::Element *ptr_D; -+ typename EpilogueOutputOp::Params output_op; -+ int *semaphore; -+ SplitKMode split_k_mode; -+ int split_k_slices; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ DirectConvolutionParams() : swizzle_log_tile(0), gemm_k_iterations(0) {} -+ -+ /// -+ CUTLASS_HOST_DEVICE -+ DirectConvolutionParams(Arguments const &args, int *semaphore = nullptr) -+ : problem_size(args.problem_size), -+ implicit_gemm_problem_size( -+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), -+ iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), -+ ptr_A(args.ref_A.data()), -+ iterator_B(Mma::IteratorB::getParams(args.problem_size, args.ref_B.layout())), -+ ptr_B(args.ref_B.data()), -+ ptr_reordered_B(args.ref_reordered_B.data()), -+ iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), args.problem_size), -+ ptr_C(args.ref_C.data()), -+ iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), args.problem_size), -+ ptr_D(args.ref_D.data()), -+ output_op(args.output_op), -+ semaphore(semaphore), -+ split_k_mode(args.split_k_mode), -+ split_k_slices(args.problem_size.split_k_slices) { -+ gemm_k_iterations = -+ 
depthwise_gemm_k_iterations(kConvolutionalOperator, -+ ThreadblockShape::kK, -+ args.problem_size, -+ kIteratorAlgorithm, -+ kGroupMode, -+ ThreadblockShape::kN); -+ -+ gemm_k_iterations_per_channel = implicit_gemm_k_iterations_per_channel( -+ kConvolutionalOperator, ThreadblockShape::kK, args.problem_size, kIteratorAlgorithm); -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ kConvolutionalOperator, -+ problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.problem_size.split_k_slices); -+ -+ swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); -+ -+ // Dynamic SMEM usage because stride and dilation are runtime params. -+ smem_size_ = (iterator_A.activation_size * kStages + iterator_B.filter_size); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int get_smem_size() { -+ // Dynamic Smem Size -+ return smem_size_; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+struct ReorderKernel { -+ using Params = Params_; -+ using ElementB = ElementB_; -+ -+ union SharedStorage {}; -+ -+ static unsigned int const kReorderKernelThreadPerCTA = 128; -+ -+ CUTLASS_HOST_DEVICE -+ ReorderKernel() {} -+ -+ CUTLASS_HOST_DEVICE -+ static dim3 get_grid_shape(Params const ¶ms) { -+ return dim3{static_cast( -+ (params.problem_size.filter_size() + kReorderKernelThreadPerCTA - 1) / -+ kReorderKernelThreadPerCTA), -+ 1, -+ 1}; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static dim3 get_block_shape() { return dim3{kReorderKernelThreadPerCTA, 1, 1}; } -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ int64_t m = static_cast(params.problem_size.groups); -+ int64_t n = static_cast(params.problem_size.filter_size() / params.problem_size.K); -+ const ElementB *src_with_type = static_cast(params.ptr_B); -+ ElementB *dst_with_type = static_cast(params.ptr_reordered_B); -+ -+ int64_t linear_index = blockIdx.x * kReorderKernelThreadPerCTA + threadIdx.x; -+ int64_t index_m = linear_index / n; -+ int64_t index_n = linear_index % n; -+ int64_t new_linear_index = index_m + index_n * m; -+ -+ if (linear_index < m * n) { -+ dst_with_type[new_linear_index] = src_with_type[linear_index]; -+ } -+ return; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) -+ typename ConvProblemSize_ = Conv2dProblemSize, ///! Convolutional operator on 2D or 3D problem -+ conv::GroupMode GroupMode_ = conv::GroupMode::kNone, ///! 
Group mode -+ typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1> -+> -+struct DirectConvolution { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using ThreadBlockOutputShape = ThreadBlockOutputShape_; -+ static Operator const kConvolutionalOperator = ConvOperator; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename EpilogueOutputOp::ElementOutput; -+ -+ /// Set output tensor C layout -+ using LayoutC = LayoutA; -+ -+ using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; -+ using ElementCompute = typename EpilogueOutputOp::ElementCompute; -+ -+ using WarpMmaOperator = typename Mma::Policy::Operator; -+ -+ using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ using OperatorClass = typename WarpMmaOperator::OperatorClass; -+ using ArchTag = typename WarpMmaOperator::ArchTag; -+ -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename WarpMmaOperator::Shape; -+ using InstructionShape = typename cutlass::gemm::GemmShape<1, 1, 1>; -+ -+ static int const kStages = Mma::kStages; -+ static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; -+ static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ using TensorRefA = typename Mma::IteratorA::TensorRef; -+ using TensorRefB = typename Mma::IteratorB::TensorRef; -+ using TensorRefC = cutlass::TensorRef; -+ -+ /// Check iterator A and B convolution dimension are the same and -+ // set device::ImplicitGemmConvolution::kConvDim -+ static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, -+ "Convolution on different different dimensions is not supported"); -+ static int const kConvDim = Mma::IteratorA::kConvDim; -+ -+ /// Conv dimension and problem size structure (Conv2d or Conv3d) -+ using ConvProblemSize = ConvProblemSize_; -+ -+ static conv::GroupMode const kGroupMode = GroupMode_; -+ -+ -+ // -+ // -+ // -+ using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< -+ LayoutC, -+ typename Epilogue::OutputTileIterator::Layout, -+ TensorRefC, -+ ConvOperator, -+ ConvProblemSize -+ >; -+ -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ ConvProblemSize problem_size; -+ TensorRefA ref_A; -+ TensorRefB ref_B; -+ TensorRefB ref_reordered_B; -+ TensorRefC ref_C; -+ TensorRefC ref_D; -+ typename EpilogueOutputOp::Params output_op; -+ SplitKMode split_k_mode; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size -+ ): -+ problem_size(problem_size) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size, -+ TensorRefA const & ref_A, -+ TensorRefB const & ref_B, -+ TensorRefC const & ref_C, -+ TensorRefC const & ref_D, -+ typename EpilogueOutputOp::Params const & output_op, -+ TensorRefB const & ref_reordered_B = nullptr, -+ SplitKMode const & split_k_mode = SplitKMode::kSerial -+ ): -+ 
problem_size(problem_size), -+ ref_A(ref_A), -+ ref_B(ref_B), -+ ref_C(ref_C), -+ ref_D(ref_D), -+ output_op(output_op), -+ ref_reordered_B(ref_reordered_B), -+ split_k_mode(split_k_mode) -+ { -+ -+ } -+ -+ }; -+ -+ using Params = -+ typename cutlass::conv::kernel::DirectConvolutionParams; -+ -+ using ReorderKernel = typename cutlass::conv::kernel::ReorderKernel; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ DirectConvolution() { } -+ -+ /// Executes one ImplicitGEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if threadblock is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { -+ -+ return; -+ } -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ int iterator_column_offset = 0; -+ int filter_row_offset = 0; -+ if (kGroupMode != GroupMode::kNone) { -+ if (kGroupMode == GroupMode::kDepthwise) { -+ iterator_column_offset += threadblock_tile_idx.n() * Mma::Shape::kN; -+ } -+ } -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.iterator_A, -+ params.problem_size, -+ params.ptr_A, -+ thread_idx, -+ MatrixCoord( -+ threadblock_tile_idx.m() + threadblock_tile_idx.k(), -+ iterator_column_offset -+ ) -+ ); -+ -+ typename Mma::IteratorB iterator_B( -+ params.iterator_B, -+ params.problem_size, -+ params.ptr_reordered_B, -+ thread_idx, -+ MatrixCoord( -+ filter_row_offset, -+ iterator_column_offset -+ ) -+ ); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. 
-+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // Compute logical position within grid -+ threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ -+ MatrixCoord threadblock_offset( -+ threadblock_tile_idx.m() + threadblock_tile_idx.k(), -+ threadblock_tile_idx.n() * Mma::Shape::kN -+ ); -+ -+ // Tile iterator writing to destination tensor -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.iterator_D, -+ params.ptr_D, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator reading from source accumulator tensor -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.iterator_C, -+ params.ptr_C, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ -+ // Compute threadblock-scoped matrix multiply-add -+ // Epilogue is fused in the mainloop -+ mma(params.gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ params.iterator_A, -+ iterator_B, -+ params.iterator_B, -+ accumulators, -+ epilogue, -+ output_op, -+ iterator_D, -+ iterator_C, -+ params.split_k_slices); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h b/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h -new file mode 100644 -index 0000000..11ac967 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h -@@ -0,0 +1,456 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined Implicit GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) -+ typename ConvProblemSize_ = Conv2dProblemSize, ///! Convolutional operator on 2D or 3D problem -+ conv::GroupMode GroupMode_ = conv::GroupMode::kNone ///! 
Group mode -+> -+struct ImplicitGemmConvolution { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static Operator const kConvolutionalOperator = ConvOperator; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename EpilogueOutputOp::ElementOutput; -+ -+ /// Set output tensor C layout -+ using LayoutC = LayoutA; -+ -+ using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; -+ using ElementCompute = typename EpilogueOutputOp::ElementCompute; -+ -+ using WarpMmaOperator = typename Mma::Policy::Operator; -+ -+ using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ using OperatorClass = typename WarpMmaOperator::OperatorClass; -+ using ArchTag = typename WarpMmaOperator::ArchTag; -+ -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename WarpMmaOperator::Shape; -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ static int const kStages = Mma::kStages; -+ static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; -+ static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ using TensorRefA = typename Mma::IteratorA::TensorRef; -+ using TensorRefB = typename Mma::IteratorB::TensorRef; -+ using TensorRefC = cutlass::TensorRef; -+ -+ /// Check iterator A and B convolution dimension are the same and -+ // set device::ImplicitGemmConvolution::kConvDim -+ static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, -+ "Convolution on different different dimensions is not supported"); -+ static int const kConvDim = Mma::IteratorA::kConvDim; -+ -+ /// Conv dimension and problem size structure (Conv2d or Conv3d) -+ using ConvProblemSize = ConvProblemSize_; -+ -+ static conv::GroupMode const kGroupMode = GroupMode_; -+ -+ /// Wgrad C stride idx for implicit gemm algorithm -+ // Conv2d row-major matrix C (KxRSC) -+ // Conv3d row-major matrix C (KxTRSC) -+ static int const kWgradCStrideIdx = -+ platform::is_same::value ? 2 : 3; -+ -+ /// This chooses the appropriate stride element of the C tensor. -+ static int const kTensorCStrideIdx = -+ (kConvolutionalOperator == conv::Operator::kWgrad ? 
kWgradCStrideIdx : 0); -+ -+ // -+ // -+ // -+ using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< -+ LayoutC, -+ typename Epilogue::OutputTileIterator::Layout, -+ TensorRefC, -+ ConvOperator, -+ ConvProblemSize -+ >; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ ConvProblemSize problem_size; -+ TensorRefA ref_A; -+ TensorRefB ref_B; -+ TensorRefC ref_C; -+ TensorRefC ref_D; -+ typename EpilogueOutputOp::Params output_op; -+ SplitKMode split_k_mode; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size -+ ): -+ problem_size(problem_size) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size, -+ TensorRefA const & ref_A, -+ TensorRefB const & ref_B, -+ TensorRefC const & ref_C, -+ TensorRefC const & ref_D, -+ typename EpilogueOutputOp::Params const & output_op, -+ SplitKMode const & split_k_mode = SplitKMode::kSerial -+ ): -+ problem_size(problem_size), -+ ref_A(ref_A), -+ ref_B(ref_B), -+ ref_C(ref_C), -+ ref_D(ref_D), -+ output_op(output_op), -+ split_k_mode(split_k_mode) -+ { -+ -+ } -+ -+ }; -+ -+ /// Parameters structure -+ struct Params { -+ ConvProblemSize problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ gemm::GemmCoord implicit_gemm_problem_size; -+ int swizzle_log_tile; -+ -+ int gemm_k_iterations; -+ int gemm_k_iterations_per_channel; -+ typename Mma::IteratorA::Params iterator_A; -+ typename Mma::IteratorA::Element const *ptr_A; -+ typename Mma::IteratorB::Params iterator_B; -+ typename Mma::IteratorB::Element const *ptr_B; -+ typename Epilogue::OutputTileIterator::Params iterator_C; -+ typename Epilogue::OutputTileIterator::Element *ptr_C; -+ typename Epilogue::OutputTileIterator::Params iterator_D; -+ typename Epilogue::OutputTileIterator::Element *ptr_D; -+ typename EpilogueOutputOp::Params output_op; -+ int *semaphore; -+ SplitKMode split_k_mode; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): swizzle_log_tile(0), gemm_k_iterations(0) { } -+ -+ /// -+ CUTLASS_HOST_DEVICE -+ Params( -+ Arguments const &args, -+ int *semaphore = nullptr -+ ): -+ problem_size(args.problem_size), -+ implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), -+ iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), -+ ptr_A(args.ref_A.data()), -+ iterator_B(args.problem_size, args.ref_B.layout()), -+ ptr_B(args.ref_B.data()), -+ iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)), -+ ptr_C(args.ref_C.data()), -+ iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)), -+ ptr_D(args.ref_D.data()), -+ output_op(args.output_op), -+ semaphore(semaphore), -+ split_k_mode(args.split_k_mode) -+ { -+ gemm_k_iterations = implicit_gemm_k_iterations( -+ kConvolutionalOperator, -+ ThreadblockShape::kK, -+ args.problem_size, -+ kIteratorAlgorithm, -+ kGroupMode, -+ ThreadblockShape::kN); -+ -+ gemm_k_iterations_per_channel = implicit_gemm_k_iterations_per_channel( -+ kConvolutionalOperator, ThreadblockShape::kK, args.problem_size, kIteratorAlgorithm); -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ implicit_gemm_problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.problem_size.split_k_slices); -+ -+ swizzle_log_tile = 
threadblock_swizzle.get_log_tile(grid_tiled_shape); -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ ImplicitGemmConvolution() { } -+ -+ /// Executes one ImplicitGEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { -+ -+ return; -+ } -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ int iterator_A_column_offset = threadblock_tile_idx.k() * Mma::Shape::kK; -+ if (kGroupMode != GroupMode::kNone) { -+ if (kGroupMode != GroupMode::kDepthwise) { -+ int k_per_group = params.problem_size.K / params.problem_size.groups; -+ int group_idx = threadblock_tile_idx.n() * Mma::Shape::kN / k_per_group; -+ int channels_per_group = params.problem_size.C / params.problem_size.groups; -+ iterator_A_column_offset += group_idx * channels_per_group; -+ } else { -+ iterator_A_column_offset += threadblock_tile_idx.n() * Mma::Shape::kN; -+ } -+ } -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.iterator_A, -+ params.problem_size, -+ params.ptr_A, -+ thread_idx, -+ MatrixCoord( -+ threadblock_tile_idx.m() * Mma::Shape::kM, -+ iterator_A_column_offset -+ ) -+ ); -+ -+ typename Mma::IteratorB iterator_B( -+ params.iterator_B, -+ params.problem_size, -+ params.ptr_B, -+ thread_idx, -+ MatrixCoord( -+ threadblock_tile_idx.k() * Mma::Shape::kK, -+ threadblock_tile_idx.n() * Mma::Shape::kN -+ ) -+ ); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, params.gemm_k_iterations_per_channel); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // Construct the semaphore. -+ int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); -+ -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // Compute logical position within grid -+ threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. 
-+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); -+ } -+ -+ MatrixCoord threadblock_offset( -+ threadblock_tile_idx.m() * Mma::Shape::kM, -+ threadblock_tile_idx.n() * Mma::Shape::kN -+ ); -+ -+ // Tile iterator writing to destination tensor -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.iterator_D, -+ params.ptr_D, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator reading from source accumulator tensor -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.iterator_C, -+ params.ptr_C, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_idx.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_idx.k()); -+ -+ } -+ // Each split-k-slice writes to a unique tensor location -+ else if (params.split_k_mode == SplitKMode::kParallel) { -+ iterator_D.add_pointer_offset(threadblock_tile_idx.k() * -+ cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size)); -+ } -+ -+ // Run efficient epilogue -+ epilogue(output_op, iterator_D, accumulators, iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_idx.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h b/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h -new file mode 100644 -index 0000000..b740c90 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h -@@ -0,0 +1,463 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined fused activation's scale+bias+relu and Implicit GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) -+ typename ConvProblemSize_ = Conv2dProblemSize ///! 
Convolutional operator on 2D or 3D problem -+> -+struct ImplicitGemmConvolutionFusion { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static Operator const kConvolutionalOperator = ConvOperator; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ -+ using ElementScaleBias = typename Mma::IteratorScaleBias::Element; -+ using LayoutScaleBias = typename Mma::IteratorScaleBias::Layout; -+ -+ using ElementC = typename EpilogueOutputOp::ElementOutput; -+ using LayoutC = LayoutA; -+ -+ using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; -+ using ElementCompute = typename EpilogueOutputOp::ElementCompute; -+ -+ using WarpMmaOperator = typename Mma::Policy::Operator; -+ -+ using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ using OperatorClass = typename WarpMmaOperator::OperatorClass; -+ using ArchTag = typename WarpMmaOperator::ArchTag; -+ -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename WarpMmaOperator::Shape; -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ static int const kStages = Mma::kStages; -+ static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ using TensorRefA = typename Mma::IteratorA::TensorRef; -+ using TensorRefB = typename Mma::IteratorB::TensorRef; -+ using TensorRefScaleBias = typename Mma::IteratorScaleBias::TensorRef; -+ using TensorRefC = cutlass::TensorRef; -+ -+ /// Check iterator A and B convolution dimension are the same and -+ // set device::ImplicitGemmConvolution::kConvDim -+ static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, -+ "Convolution on different different dimensions is not supported"); -+ static int const kConvDim = Mma::IteratorA::kConvDim; -+ -+ /// Conv dimension and problem size structure (Conv2d or Conv3d) -+ using ConvProblemSize = ConvProblemSize_; -+ -+ static conv::GroupMode const kGroupMode = conv::GroupMode::kNone; -+ -+ /// Wgrad C stride idx for implicit gemm algorithm -+ // Conv2d row-major matrix C (KxRSC) -+ // Conv3d row-major matrix C (KxTRSC) -+ static int const kWgradCStrideIdx = -+ platform::is_same::value ? 2 : 3; -+ -+ /// This chooses the appropriate stride element of the C tensor. -+ static int const kTensorCStrideIdx = -+ (kConvolutionalOperator == conv::Operator::kWgrad ? 
kWgradCStrideIdx : 0); -+ -+ // -+ // -+ // -+ using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< -+ LayoutC, -+ typename Epilogue::OutputTileIterator::Layout, -+ TensorRefC, -+ ConvOperator, -+ ConvProblemSize -+ >; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ ConvProblemSize problem_size; -+ TensorRefA ref_A; -+ TensorRefB ref_B; -+ TensorRefScaleBias ref_scale; -+ TensorRefScaleBias ref_bias; -+ TensorRefC ref_C; -+ TensorRefC ref_D; -+ typename EpilogueOutputOp::Params output_op; -+ SplitKMode split_k_mode; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size -+ ): -+ problem_size(problem_size) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size, -+ TensorRefA const & ref_A, -+ TensorRefB const & ref_B, -+ TensorRefScaleBias const & ref_scale, -+ TensorRefScaleBias const & ref_bias, -+ TensorRefC const & ref_C, -+ TensorRefC const & ref_D, -+ typename EpilogueOutputOp::Params const & output_op, -+ SplitKMode const & split_k_mode = SplitKMode::kSerial -+ ): -+ problem_size(problem_size), -+ ref_A(ref_A), -+ ref_B(ref_B), -+ ref_scale(ref_scale), -+ ref_bias(ref_bias), -+ ref_C(ref_C), -+ ref_D(ref_D), -+ output_op(output_op), -+ split_k_mode(split_k_mode) -+ { -+ -+ } -+ -+ }; -+ -+ /// Parameters structure -+ struct Params { -+ ConvProblemSize problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ gemm::GemmCoord implicit_gemm_problem_size; -+ int swizzle_log_tile; -+ int gemm_k_iterations; -+ typename Mma::IteratorA::Params iterator_A; -+ typename Mma::IteratorA::Element const *ptr_A; -+ typename Mma::IteratorB::Params iterator_B; -+ typename Mma::IteratorB::Element const *ptr_B; -+ typename Mma::IteratorScaleBias::Params iterator_scale_bias; -+ typename Mma::IteratorScaleBias::Element const *ptr_scale; -+ typename Mma::IteratorScaleBias::Element const *ptr_bias; -+ typename Epilogue::OutputTileIterator::Params iterator_C; -+ typename Epilogue::OutputTileIterator::Element *ptr_C; -+ typename Epilogue::OutputTileIterator::Params iterator_D; -+ typename Epilogue::OutputTileIterator::Element *ptr_D; -+ typename EpilogueOutputOp::Params output_op; -+ int *semaphore; -+ SplitKMode split_k_mode; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): swizzle_log_tile(0), gemm_k_iterations(0) { } -+ -+ /// -+ CUTLASS_HOST_DEVICE -+ Params( -+ Arguments const &args, -+ int *semaphore = nullptr -+ ): -+ problem_size(args.problem_size), -+ implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), -+ iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), -+ ptr_A(args.ref_A.data()), -+ iterator_B(args.problem_size, args.ref_B.layout()), -+ ptr_B(args.ref_B.data()), -+ iterator_scale_bias(args.problem_size, args.ref_scale.layout()), -+ ptr_scale(args.ref_scale.data()), -+ ptr_bias(args.ref_bias.data()), -+ iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)), -+ ptr_C(args.ref_C.data()), -+ iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)), -+ ptr_D(args.ref_D.data()), -+ output_op(args.output_op), -+ semaphore(semaphore), -+ split_k_mode(args.split_k_mode) -+ { -+ gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size); -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ grid_tiled_shape = 
threadblock_swizzle.get_tiled_shape( -+ implicit_gemm_problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.problem_size.split_k_slices); -+ -+ swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ ImplicitGemmConvolutionFusion() { } -+ -+ /// Executes one ImplicitGEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { -+ -+ return; -+ } -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A operand -+ typename Mma::IteratorA iterator_A( -+ params.iterator_A, -+ params.problem_size, -+ params.ptr_A, -+ thread_idx, -+ MatrixCoord( -+ threadblock_tile_idx.m() * Mma::Shape::kM, -+ threadblock_tile_idx.k() * Mma::Shape::kK -+ ) -+ ); -+ -+ // Construct iterators to B operand -+ typename Mma::IteratorB iterator_B( -+ params.iterator_B, -+ params.problem_size, -+ params.ptr_B, -+ thread_idx, -+ MatrixCoord( -+ threadblock_tile_idx.k() * Mma::Shape::kK, -+ threadblock_tile_idx.n() * Mma::Shape::kN -+ ) -+ ); -+ -+ // Construct iterators to A scale/bias vector -+ typename Mma::IteratorScaleBias iterator_scale_bias( -+ params.iterator_scale_bias, -+ params.problem_size, -+ params.ptr_scale, -+ params.ptr_bias, -+ thread_idx, -+ MatrixCoord( -+ 0, (kConvolutionalOperator == conv::Operator::kFprop) ? -+ (threadblock_tile_idx.k() * Mma::Shape::kK) : -+ // Wgrad -+ (threadblock_tile_idx.n() * Mma::Shape::kN) -+ ) -+ ); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(params.gemm_k_iterations, accumulators, iterator_A, -+ iterator_B, iterator_scale_bias, accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // Construct the semaphore. -+ int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); -+ -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // Compute logical position within grid -+ threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. 
-+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); -+ } -+ -+ MatrixCoord threadblock_offset( -+ threadblock_tile_idx.m() * Mma::Shape::kM, -+ threadblock_tile_idx.n() * Mma::Shape::kN -+ ); -+ -+ // Tile iterator writing to destination tensor -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.iterator_D, -+ params.ptr_D, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator reading from source accumulator tensor -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.iterator_C, -+ params.ptr_C, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_idx.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_idx.k()); -+ -+ } -+ // Each split-k-slice writes to a unique tensor location -+ else if (params.split_k_mode == SplitKMode::kParallel) { -+ iterator_D.add_pointer_offset(threadblock_tile_idx.k() * -+ cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size)); -+ } -+ -+ // Run efficient epilogue -+ epilogue(output_op, iterator_D, accumulators, iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_idx.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h b/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h -new file mode 100644 -index 0000000..7304cbd ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h -@@ -0,0 +1,492 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined Implicit GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) -+ typename ConvProblemSize_ = Conv2dProblemSize ///! 
Convolutional operator on 2D or 3D problem -+> -+struct ImplicitGemmConvolutionStridedDgrad { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static Operator const kConvolutionalOperator = ConvOperator; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename EpilogueOutputOp::ElementOutput; -+ -+ /// Set output tensor C layout -+ using LayoutC = LayoutA; -+ -+ using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; -+ using ElementCompute = typename EpilogueOutputOp::ElementCompute; -+ -+ using WarpMmaOperator = typename Mma::Policy::Operator; -+ -+ using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ using OperatorClass = typename WarpMmaOperator::OperatorClass; -+ using ArchTag = typename WarpMmaOperator::ArchTag; -+ -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename WarpMmaOperator::Shape; -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ static int const kStages = Mma::kStages; -+ static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; -+ static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ using TensorRefA = typename Mma::IteratorA::TensorRef; -+ using TensorRefB = typename Mma::IteratorB::TensorRef; -+ using TensorRefC = cutlass::TensorRef; -+ -+ /// Check iterator A and B convolution dimension are the same and -+ // set device::ImplicitGemmConvolution::kConvDim -+ static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, -+ "Convolution on different different dimensions is not supported"); -+ static int const kConvDim = Mma::IteratorA::kConvDim; -+ -+ /// Conv dimension and problem size structure (Conv2d or Conv3d) -+ using ConvProblemSize = ConvProblemSize_; -+ -+ static conv::GroupMode const kGroupMode = conv::GroupMode::kNone; -+ -+ /// Wgrad C stride idx for implicit gemm algorithm -+ // Conv2d row-major matrix C (KxRSC) -+ // Conv3d row-major matrix C (KxTRSC) -+ static int const kWgradCStrideIdx = -+ platform::is_same::value ? 2 : 3; -+ -+ /// This chooses the appropriate stride element of the C tensor. -+ static int const kTensorCStrideIdx = -+ (kConvolutionalOperator == conv::Operator::kWgrad ? 
kWgradCStrideIdx : 0); -+ -+ // Strided dgrad uses a specialized threadblock swizzle for functionality and performance -+ static_assert((platform::is_same::value) || -+ (platform::is_same>::value) || -+ (platform::is_same>::value) || -+ (platform::is_same>::value), -+ "Needs ThreadblockSwizzle type specialized for strided dgrad"); -+ -+ // -+ // -+ // -+ using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< -+ LayoutC, -+ typename Epilogue::OutputTileIterator::Layout, -+ TensorRefC, -+ ConvOperator, -+ ConvProblemSize -+ >; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ ConvProblemSize problem_size; -+ TensorRefA ref_A; -+ TensorRefB ref_B; -+ TensorRefC ref_C; -+ TensorRefC ref_D; -+ typename EpilogueOutputOp::Params output_op; -+ SplitKMode split_k_mode; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size -+ ): -+ problem_size(problem_size) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size, -+ TensorRefA const & ref_A, -+ TensorRefB const & ref_B, -+ TensorRefC const & ref_C, -+ TensorRefC const & ref_D, -+ typename EpilogueOutputOp::Params const & output_op, -+ SplitKMode const & split_k_mode = SplitKMode::kSerial -+ ): -+ problem_size(problem_size), -+ ref_A(ref_A), -+ ref_B(ref_B), -+ ref_C(ref_C), -+ ref_D(ref_D), -+ output_op(output_op), -+ split_k_mode(split_k_mode) -+ { -+ -+ } -+ -+ }; -+ -+ /// Parameters structure -+ struct Params { -+ ConvProblemSize problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ FastDivmod stride_h_divmod; -+ FastDivmod stride_w_divmod; -+ int gemm_k_iterations; -+ typename Mma::IteratorA::Params iterator_A; -+ typename Mma::IteratorA::Element const *ptr_A; -+ typename Mma::IteratorB::Params iterator_B; -+ typename Mma::IteratorB::Element const *ptr_B; -+ typename Epilogue::OutputTileIterator::Params iterator_C; -+ typename Epilogue::OutputTileIterator::Element *ptr_C; -+ typename Epilogue::OutputTileIterator::Params iterator_D; -+ typename Epilogue::OutputTileIterator::Element *ptr_D; -+ typename EpilogueOutputOp::Params output_op; -+ int *semaphore; -+ SplitKMode split_k_mode; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): gemm_k_iterations(0) { } -+ -+ /// -+ CUTLASS_HOST_DEVICE -+ Params( -+ Arguments const &args, -+ int *semaphore = nullptr -+ ): -+ problem_size(args.problem_size), -+ stride_h_divmod(args.problem_size.stride_h), -+ stride_w_divmod(args.problem_size.stride_w), -+ iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), -+ ptr_A(args.ref_A.data()), -+ iterator_B(args.problem_size, args.ref_B.layout()), -+ ptr_B(args.ref_B.data()), -+ iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), args.problem_size, ThreadblockShape::kM), -+ ptr_C(args.ref_C.data()), -+ iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), args.problem_size, ThreadblockShape::kM), -+ ptr_D(args.ref_D.data()), -+ output_op(args.output_op), -+ semaphore(semaphore), -+ split_k_mode(args.split_k_mode) -+ { -+ gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size); -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ kConvolutionalOperator, -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ 
args.problem_size.split_k_slices); -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ ImplicitGemmConvolutionStridedDgrad() { } -+ -+ /// Executes one ImplicitGEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { -+ -+ return; -+ } -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Compute starting filter position for strided dgrad -+ int tile_m_per_filter = strided_dgrad_tile_m_per_filter(params.problem_size, -+ ThreadblockShape::kM); -+ int filter_tile_m = (threadblock_tile_idx.m() / tile_m_per_filter); -+ -+ -+ // The subsequent fast_divmod() operations are equivalent to the following logical computation: -+ // -+ // int start_r = filter_tile_m / (params.problem_size.stride_w); -+ // int start_s = filter_tile_m % (params.problem_size.stride_w); -+ -+ int start_r, start_s; -+ params.stride_w_divmod(start_r, start_s, filter_tile_m); -+ -+ int filter_r = start_r; -+ int filter_s = start_s; -+ -+ if (params.problem_size.mode == Mode::kConvolution) { -+ filter_r = (params.problem_size.R - 1 - filter_r); -+ filter_s = (params.problem_size.S - 1 - filter_s); -+ } -+ -+ // Starting h, w positions for filter position in gemm_k=0 -+ int start_h, start_w; -+ strided_dgrad_starting_coords( -+ params.problem_size, -+ params.stride_h_divmod, params.stride_w_divmod, -+ filter_r, filter_s, -+ start_h, start_w); -+ -+ if (start_h >= params.problem_size.H || start_w >= params.problem_size.W) { -+ return; -+ } -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. 
-+ int warp_idx = canonical_warp_idx(); -+ int lane_idx = threadIdx.x % 32; -+ -+ // Check if CTA contributes valid MMA (Dy * w) and accumulator will be non-zero after MMA -+ if (start_r < params.problem_size.R && start_s < params.problem_size.S) { -+ // Scale gemm_k_iterations for strided dgrad -+ int gemm_k_iterations = (params.gemm_k_iterations / (params.problem_size.R * params.problem_size.S) -+ ) * params.problem_size.num_gemm_k_filter_positions(start_r, start_s); -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.iterator_A, -+ params.problem_size, -+ params.ptr_A, -+ thread_idx, -+ params.stride_h_divmod, params.stride_w_divmod, -+ start_r, start_s, -+ MatrixCoord( -+ threadblock_tile_idx.m() * Mma::Shape::kM, -+ threadblock_tile_idx.k() * Mma::Shape::kK -+ ) -+ ); -+ -+ typename Mma::IteratorB iterator_B( -+ params.iterator_B, -+ params.problem_size, -+ params.ptr_B, -+ thread_idx, -+ start_r, start_s, -+ MatrixCoord( -+ threadblock_tile_idx.k() * Mma::Shape::kK, -+ threadblock_tile_idx.n() * Mma::Shape::kN -+ ) -+ ); -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); -+ } -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // Construct the semaphore. -+ int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // Compute logical position within grid -+ threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); -+ } -+ -+ MatrixCoord threadblock_offset( -+ threadblock_tile_idx.m() * Mma::Shape::kM, -+ threadblock_tile_idx.n() * Mma::Shape::kN -+ ); -+ -+ // Tile iterator writing to destination tensor -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.iterator_D, -+ params.ptr_D, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ thread_idx, -+ params.stride_h_divmod, params.stride_w_divmod, -+ start_r, start_s, -+ threadblock_offset -+ ); -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ if (output_op.is_source_needed()) -+ { -+ // Tile iterator reading from source accumulator tensor -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.iterator_C, -+ params.ptr_C, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ thread_idx, -+ params.stride_h_divmod, params.stride_w_divmod, -+ start_r, start_s, -+ threadblock_offset); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. 
-+ if (threadblock_tile_idx.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_idx.k()); -+ } -+ -+ // Run epilogue with addend source iterator -+ epilogue(output_op, iterator_D, accumulators, iterator_C); -+ } -+ else -+ { -+ // Run epilogue without addend source iterator -+ epilogue(output_op, iterator_D, accumulators); -+ } -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_idx.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h b/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h -new file mode 100644 -index 0000000..3fa7dac ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h -@@ -0,0 +1,499 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined Implicit GEMM kernel. 
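Each kernel in this patch guards its epilogue with the same serial split-K handshake: a threadblock waits until the per-output-tile semaphore equals its k-slice index, accumulates on top of the previous slice's result (slice 0 reads the real source C, later slices read the partial sums already written to D), and then releases the lock as k+1, with the last slice resetting it to 0 for the next launch. A single-threaded sketch of that sequencing, using illustrative names that are not part of the patch:

#include <cstdio>

// Illustrative, single-threaded model of the serial split-K protocol: slices
// acquire the lock in order 0, 1, ..., K-1; the final slice resets it to 0.
int main() {
  const int k_slices = 3;
  int semaphore = 0;  // stands in for the per-tile lock held in global memory

  for (int k = 0; k < k_slices; ++k) {
    // semaphore.wait(k): proceed only once the lock holds this slice's index.
    while (semaphore != k) { /* spin; the real kernels poll an atomic */ }

    // Epilogue: slice 0 uses tensor C as the addend, later slices use the
    // partial result the previous slice stored in D.
    std::printf("slice %d accumulates on top of %s\n", k, k == 0 ? "C" : "D");

    // semaphore.release(...): hand off k+1, or reset to 0 on the last slice.
    semaphore = (k + 1 == k_slices) ? 0 : k + 1;
  }
  return 0;
}

Parallel split-K (SplitKMode::kParallel) skips this handshake entirely and instead offsets iterator_D so each slice writes a disjoint region, as the code above shows.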
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) -+ typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem -+> -+struct ImplicitGemmConvolutionWithFusedEpilogue { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static Operator const kConvolutionalOperator = ConvOperator; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename EpilogueOutputOp::ElementOutput; -+ -+ /// Set output tensor C layout -+ using LayoutC = LayoutA; -+ -+ using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; -+ using ElementCompute = typename EpilogueOutputOp::ElementCompute; -+ -+ using WarpMmaOperator = typename Mma::Policy::Operator; -+ -+ using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ using OperatorClass = typename WarpMmaOperator::OperatorClass; -+ using ArchTag = typename WarpMmaOperator::ArchTag; -+ -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename WarpMmaOperator::Shape; -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ static int const kStages = Mma::kStages; -+ static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; -+ static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ using TensorRefA = typename Mma::IteratorA::TensorRef; -+ using TensorRefB = typename Mma::IteratorB::TensorRef; -+ using TensorRefC = cutlass::TensorRef; -+ -+ /// Check iterator A and B convolution dimension are the same and -+ // set device::ImplicitGemmConvolution::kConvDim -+ static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, -+ "Convolution on different different dimensions is not supported"); -+ static int const kConvDim = Mma::IteratorA::kConvDim; -+ -+ /// Conv dimension and problem size structure (Conv2d or Conv3d) -+ using ConvProblemSize = ConvProblemSize_; -+ -+ static conv::GroupMode const kGroupMode = conv::GroupMode::kNone; -+ -+ /// Wgrad C stride idx for implicit gemm algorithm -+ // Conv2d 
row-major matrix C (KxRSC) -+ // Conv3d row-major matrix C (KxTRSC) -+ static int const kWgradCStrideIdx = -+ platform::is_same::value ? 2 : 3; -+ -+ /// This chooses the appropriate stride element of the C tensor. -+ static int const kTensorCStrideIdx = -+ (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0); -+ -+ // -+ // -+ // -+ using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< -+ LayoutC, -+ typename Epilogue::OutputTileIterator::Layout, -+ TensorRefC, -+ ConvOperator, -+ ConvProblemSize -+ >; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ ConvProblemSize problem_size; -+ TensorRefA ref_A; -+ TensorRefB ref_B; -+ TensorRefC ref_C; -+ TensorRefC ref_D; -+ -+ typename EpilogueOutputOp::Params output_op; -+ SplitKMode split_k_mode; -+ -+ void * ptr_Vector; -+ void * ptr_Tensor; -+ -+ typename LayoutC::Stride::Index ldr; -+ typename LayoutC::Stride::Index ldt; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size -+ ): -+ problem_size(problem_size) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size, -+ TensorRefA const & ref_A, -+ TensorRefB const & ref_B, -+ TensorRefC const & ref_C, -+ TensorRefC const & ref_D, -+ typename EpilogueOutputOp::Params const & output_op, -+ SplitKMode const & split_k_mode = SplitKMode::kSerial, -+ void * ptr_Vector = nullptr, -+ void * ptr_Tensor = nullptr, -+ typename LayoutC::Stride::Index ldr = 0, -+ typename LayoutC::Stride::Index ldt = 0 -+ ): -+ problem_size(problem_size), -+ ref_A(ref_A), -+ ref_B(ref_B), -+ ref_C(ref_C), -+ ref_D(ref_D), -+ output_op(output_op), -+ split_k_mode(split_k_mode), -+ ptr_Vector(ptr_Vector), -+ ptr_Tensor(ptr_Tensor), -+ ldr(ldr), -+ ldt(ldt) -+ { -+ -+ } -+ -+ }; -+ -+ /// Parameters structure -+ struct Params { -+ ConvProblemSize problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ gemm::GemmCoord implicit_gemm_problem_size; -+ int swizzle_log_tile; -+ -+ int gemm_k_iterations; -+ typename Mma::IteratorA::Params iterator_A; -+ typename Mma::IteratorA::Element const *ptr_A; -+ typename Mma::IteratorB::Params iterator_B; -+ typename Mma::IteratorB::Element const *ptr_B; -+ typename Epilogue::OutputTileIterator::Params iterator_C; -+ typename Epilogue::OutputTileIterator::Element *ptr_C; -+ typename Epilogue::OutputTileIterator::Params iterator_D; -+ typename Epilogue::OutputTileIterator::Element *ptr_D; -+ typename EpilogueOutputOp::Params output_op; -+ int *semaphore; -+ SplitKMode split_k_mode; -+ -+ typename Epilogue::TensorTileIterator::Params params_Tensor; -+ void * ptr_Vector; -+ typename LayoutC::Stride::Index ldr; -+ void * ptr_Tensor; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ swizzle_log_tile(0), -+ gemm_k_iterations(0), -+ ptr_Vector(nullptr), -+ ldr(0), -+ ptr_Tensor(nullptr) -+ { } -+ -+ /// -+ CUTLASS_HOST_DEVICE -+ Params( -+ Arguments const &args, -+ int *semaphore = nullptr -+ ): -+ problem_size(args.problem_size), -+ implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), -+ iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), -+ ptr_A(args.ref_A.data()), -+ iterator_B(args.problem_size, args.ref_B.layout()), -+ ptr_B(args.ref_B.data()), -+ iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)), -+ ptr_C(args.ref_C.data()), -+ 
iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)), -+ ptr_D(args.ref_D.data()), -+ output_op(args.output_op), -+ semaphore(semaphore), -+ split_k_mode(args.split_k_mode), -+ params_Tensor(args.ldt), -+ ptr_Vector(args.ptr_Vector), -+ ldr(args.ldr), -+ ptr_Tensor(args.ptr_Tensor) -+ -+ { -+ gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size); -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ implicit_gemm_problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.problem_size.split_k_slices); -+ -+ swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ ImplicitGemmConvolutionWithFusedEpilogue() { } -+ -+ /// Executes one ImplicitGEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { -+ -+ return; -+ } -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.iterator_A, -+ params.problem_size, -+ params.ptr_A, -+ thread_idx, -+ MatrixCoord( -+ threadblock_tile_idx.m() * Mma::Shape::kM, -+ threadblock_tile_idx.k() * Mma::Shape::kK -+ ) -+ ); -+ -+ typename Mma::IteratorB iterator_B( -+ params.iterator_B, -+ params.problem_size, -+ params.ptr_B, -+ thread_idx, -+ MatrixCoord( -+ threadblock_tile_idx.k() * Mma::Shape::kK, -+ threadblock_tile_idx.n() * Mma::Shape::kN -+ ) -+ ); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // Construct the semaphore. -+ int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); -+ -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // Compute logical position within grid -+ threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. 
-+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); -+ } -+ -+ MatrixCoord threadblock_offset( -+ threadblock_tile_idx.m() * Mma::Shape::kM, -+ threadblock_tile_idx.n() * Mma::Shape::kN -+ ); -+ -+ // Tile iterator writing to destination tensor -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.iterator_D, -+ params.ptr_D, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator reading from source accumulator tensor -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.iterator_C, -+ params.ptr_C, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ typename Epilogue::ElementTensor *ptr_Tensor = -+ static_cast(params.ptr_Tensor); -+ -+ // Define the reduction output pointer and move to the appropriate place -+ typename Epilogue::ElementVector *ptr_Vector = -+ static_cast(params.ptr_Vector); -+ -+ // Additional tensor to load from -+ typename Epilogue::TensorTileIterator tensor_iterator( -+ params.params_Tensor, -+ // Only the final block outputs Tensor -+ ((params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) && -+ (params.grid_tiled_shape.k() != threadblock_tile_idx.k() + 1)) -+ ? nullptr -+ : ptr_Tensor, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ thread_idx, -+ threadblock_offset); -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Move to appropriate location for this output tile -+ if (ptr_Vector) { -+ ptr_Vector += threadblock_offset.column() + threadblock_tile_idx.m() * params.ldr; -+ } -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_idx.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_idx.k()); -+ -+ } -+ // Each split-k-slice writes to a unique tensor location -+ else if (params.split_k_mode == SplitKMode::kParallel) { -+ iterator_D.add_pointer_offset(threadblock_tile_idx.k() * -+ cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size)); -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(output_op, -+ // Only the final block uses Vector -+ ((params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) && -+ (params.grid_tiled_shape.k() != threadblock_tile_idx.k() + 1)) -+ ? nullptr -+ : ptr_Vector, -+ iterator_D, -+ accumulators, -+ iterator_C, -+ tensor_iterator, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ threadblock_offset); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. 
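For serial split-K, the lock value itself encodes the reduction order: partition k waits until the semaphore holds k, folds its partial result into D, then publishes k + 1, and the last partition writes 0 back so the next grid starts clean. A host-side model of that ordering using std::atomic in place of the device-side Semaphore; this is a sketch of the protocol only, not the CUTLASS implementation.

#include <atomic>
#include <cassert>
#include <thread>
#include <vector>

// Host-side model of serial split-K: partition k waits for the lock to
// reach k, adds its partial result into the shared output, then releases
// k + 1 (the last partition resets the lock to 0 for the next grid).
int main() {
  const int kPartitions = 4;
  std::atomic<int> lock{0};
  float output = 0.0f;                       // plays the role of tensor D
  std::vector<float> partial = {1, 2, 3, 4}; // per-partition partial sums

  std::vector<std::thread> workers;
  for (int k = 0; k < kPartitions; ++k) {
    workers.emplace_back([&, k] {
      while (lock.load(std::memory_order_acquire) != k) { /* spin (wait) */ }
      output += partial[k];                  // partitions k > 0 see the prior sum
      int next = (k + 1 == kPartitions) ? 0 : k + 1;
      lock.store(next, std::memory_order_release);   // release
    });
  }
  for (auto &w : workers) w.join();
  assert(output == 10.0f && lock.load() == 0);
  return 0;
}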
-+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_idx.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/thread/depthwise_mma.h b/3rdparty/cutlass/include/cutlass/conv/thread/depthwise_mma.h -new file mode 100644 -index 0000000..8f84563 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/thread/depthwise_mma.h -@@ -0,0 +1,325 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Templates exposing architecture support for depthwise convolution -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/arch/mma.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/thread/mma.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// MMA operation -+template < -+ /// Size of the matrix product (concept: GemmShape) -+ typename Shape_, -+ /// Number of threads participating -+ int kThreads_, -+ /// Data type of A elements -+ typename ElementA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Element type of C matrix -+ typename ElementC, -+ /// Inner product operator -+ typename Operator -+> -+struct ElementwiseInnerProduct; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// General implementation -+template < -+ /// Size of the matrix product (concept: GemmShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Element type of C matrix -+ typename ElementC_> -+struct ElementwiseInnerProduct { -+ using Shape = Shape_; -+ using Operator = arch::OpMultiplyAdd; -+ using ElementC = ElementC_; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Shape::kN; ++i) { -+ d[i] = a[i] * b[i] + c[i]; -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Specialization of half_t -+template <> -+struct ElementwiseInnerProduct< -+ gemm::GemmShape<2, 2, 1>, -+ 1, -+ half_t, -+ half_t, -+ half_t, -+ arch::OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<2, 2, 1>; -+ using Operator = arch::OpMultiplyAdd; -+ using ElementC = half_t; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) -+ -+ __half2 const & A = reinterpret_cast<__half2 const &>(a); -+ __half2 const & B = reinterpret_cast<__half2 const &>(b); -+ __half2 const & C = reinterpret_cast<__half2 const &>(c); -+ -+ __half2 tmp_D = __hfma2(A, B, C); -+ -+ d = reinterpret_cast const &>(tmp_D); -+ -+#else -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 2; ++i) { -+ d[i] = a[i] * b[i] + c[i]; -+ } -+#endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape, -+ /// Data type of A elements -+ typename ElementA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Element type of C matrix -+ typename ElementC, -+ /// Concept: arch::OpMultiplyAdd or arch::Mma<> -+ typename Operator = arch::OpMultiplyAdd, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+struct DepthwiseDirectConvElementwiseInnerProduct; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Gemplate that handles all packed matrix layouts -+template < -+ /// Size of the Gemm problem - concept: 
gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Operator used to compute GEMM -+ typename Operator_ -+> -+struct DepthwiseDirectConvElementwiseInnerProductGeneric { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// Data type of operand A -+ using ElementA = ElementA_; -+ -+ /// Data type of operand B -+ using ElementB = ElementB_; -+ -+ /// Element type of operand C -+ using ElementC = ElementC_; -+ -+ /// Underlying mathematical operator -+ using Operator = Operator_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Instruction -+ using MmaOp = cutlass::conv::thread::ElementwiseInnerProduct< -+ gemm::GemmShape, -+ 1, -+ ElementA, -+ ElementB, -+ ElementC, -+ Operator>; -+ -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ Array *ptr_D = reinterpret_cast *>(&D); -+ Array const *ptr_A = -+ reinterpret_cast const *>(&A); -+ Array const *ptr_B = -+ reinterpret_cast const *>(&B); -+ -+ MmaOp mma_op; -+ -+ // Copy accumulators -+ D = C; -+ -+ // Compute matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Shape::kN / MmaOp::Shape::kN; ++n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Shape::kM; ++m) { -+ -+ Array tmpD = ptr_D[m * Shape::kN / MmaOp::Shape::kN + n]; -+ Array tmpA = ptr_A[m * Shape::kN / MmaOp::Shape::kN + n]; -+ Array tmpB = ptr_B[n]; -+ -+ mma_op(tmpD, tmpA, tmpB, tmpD); -+ -+ ptr_D[m * Shape::kN / MmaOp::Shape::kN + n] = tmpD; -+ -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Element type of C matrix -+ typename ElementC_ -+> -+struct DepthwiseDirectConvElementwiseInnerProduct< -+ Shape_, -+ ElementA_, -+ ElementB_, -+ ElementC_, -+ arch::OpMultiplyAdd -+ > { -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// Data type of operand A -+ using ElementA = ElementA_; -+ -+ /// Data type of operand B -+ using ElementB = ElementB_; -+ -+ /// Element type of operand C -+ using ElementC = ElementC_; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ /// A operand storage -+ using FragmentA = -+ Array; // output_tile_size per thread * groups_per_thread -+ -+ /// B operand storage -+ using FragmentB = Array; // 1 * groups_per_thread -+ -+ /// C operand storage -+ using FragmentC = -+ Array; // output_tile_size per thread * groups_per_thread -+ -+ static bool const use_optimized = 0; -+ -+ using ArchMmaOperator = DepthwiseDirectConvElementwiseInnerProductGeneric; -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ ArchMmaOperator mma; -+ -+ mma(D, A, B, C); -+ -+ } -+}; -+ -+} // namespace thread -+} // 
namespace conv -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..9464074 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h -@@ -0,0 +1,485 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
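The depthwise thread-level operator above is a purely elementwise multiply-add per channel rather than a dot product; the half_t specialization merely packs two lanes into a __half2 and issues __hfma2 on SM60+, falling back to the scalar loop otherwise. A float sketch of the same contract, with an arbitrary illustrative size:

#include <array>
#include <cassert>

// Elementwise (per-channel) multiply-add: d[i] = a[i] * b[i] + c[i].
// This is the contract of ElementwiseInnerProduct above; the half_t
// specialization performs it two lanes at a time via __hfma2.
template <int N>
void elementwise_inner_product(std::array<float, N> &d,
                               const std::array<float, N> &a,
                               const std::array<float, N> &b,
                               const std::array<float, N> &c) {
  for (int i = 0; i < N; ++i) {
    d[i] = a[i] * b[i] + c[i];
  }
}

int main() {
  std::array<float, 4> d{}, a{1, 2, 3, 4}, b{10, 10, 10, 10}, c{1, 1, 1, 1};
  elementwise_inner_product<4>(d, a, b, c);
  assert(d[2] == 31.0f);   // 3 * 10 + 1
  return 0;
}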
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dDgradFilterTileAccessIteratorAnalytic; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conv2dDgradFilterTileAccessIteratorAnalytic strided dgrad needs special handling to skip MMAs -+// on non-contributing w positions -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ -+> -+class Conv2dDgradFilterTileAccessIteratorAnalytic < -+ Shape_, -+ Element_, -+ ThreadMap_, -+ conv::StrideSupport::kStrided, -+ AccessType_ -+> { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ static_assert(sizeof_bits::value >= 8, -+ "DGRAD requires elements of size 8b or larger."); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dAnalyticParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ // For a fixed filter position (r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension -+ int filter_r_; -+ int filter_s_; -+ int start_r_; -+ int start_s_; -+ int offset_k_[ThreadMap::Iterations::kStrided]; -+ int offset_c_[ThreadMap::Iterations::kContiguous]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradFilterTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ int start_r, int start_s, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ filter_r_(start_r), -+ filter_s_(start_s), -+ start_r_(start_r), -+ 
start_s_(start_s) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ offset_c_[c] = threadblock_offset.column() + thread_coord.contiguous() -+ + c * ThreadMap::Delta::kContiguous; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_k_[s] = -+ threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ } -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // Moves filter_s -+ filter_s_ += problem_size_.stride_w; -+ if (filter_s_ < problem_size_.S) { -+ return; -+ } -+ // Restore filter_s -+ filter_s_ = start_s_; -+ -+ // Move filter_r -+ filter_r_ += problem_size_.stride_h; -+ if (filter_r_ < problem_size_.R) { -+ return; -+ } -+ // Restore filter_r -+ filter_r_ = start_r_; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_k_[s] += Shape::kRow * problem_size_.split_k_slices; -+ } -+ } -+ -+ /// Returns the coordinate in the filter tensor w that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int k = offset_k_[iteration_strided_]; -+ int c = offset_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements; -+ -+ return TensorCoord(k, filter_r_, filter_s_, c); -+ } -+ -+ /// Returns true if the current coordinate is within the filter tensor w -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.K && coord.c() < problem_size_.C; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradFilterTileAccessIteratorAnalytic &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. 
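advance() in the strided specialization above walks only the filter taps owned by this CTA, stepping s by stride_w and r by stride_h, and bumps the GEMM-K offset once both wrap back to (start_r, start_s). A host-side sketch of that visiting order (problem sizes chosen arbitrarily for illustration):

#include <cstdio>

// Reproduce the (r, s) visiting order of the strided analytic iterator's
// advance(): s advances by stride_w until it leaves [0, S), then r advances
// by stride_h; when both wrap back to (start_r, start_s) the K tile moves on.
int main() {
  const int R = 5, S = 5, stride_h = 2, stride_w = 2;
  const int start_r = 1, start_s = 0;

  int r = start_r, s = start_s;
  for (int visit = 0; visit < 8; ++visit) {
    std::printf("tap (r=%d, s=%d)\n", r, s);
    s += stride_w;                 // Moves filter_s
    if (s >= S) {
      s = start_s;                 // Restore filter_s
      r += stride_h;               // Move filter_r
      if (r >= R) {
        r = start_r;               // Restore filter_r; next GEMM-K tile begins
        std::printf("-- advance K offset --\n");
      }
    }
  }
  return 0;
}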
-+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conv2dDgradFilterTileAccessIteratorAnalytic unity strided dgrad is more performant for dgrad -+// on problem sizes with stride = {1x1} -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ -+> -+class Conv2dDgradFilterTileAccessIteratorAnalytic < -+ Shape_, -+ Element_, -+ ThreadMap_, -+ conv::StrideSupport::kUnity, -+ AccessType_ -+>{ -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ static_assert(sizeof_bits::value >= 8, -+ "DGRAD requires elements of size 8b or larger."); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dAnalyticParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ // For a fixed filter position (r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension -+ int filter_r_; -+ int filter_s_; -+ int offset_k_[ThreadMap::Iterations::kStrided]; -+ int offset_c_[ThreadMap::Iterations::kContiguous]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradFilterTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ filter_r_(0), -+ filter_s_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ offset_c_[c] = threadblock_offset.column() + thread_coord.contiguous() -+ + c * ThreadMap::Delta::kContiguous; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_k_[s] = -+ threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ } -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = 
residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next tile -+ ++filter_s_; -+ if (filter_s_ < problem_size_.S) { -+ return; -+ } -+ filter_s_ = 0; -+ ++filter_r_; -+ if (filter_r_ < problem_size_.R) { -+ return; -+ } -+ filter_r_ = 0; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_k_[s] += Shape::kRow * problem_size_.split_k_slices; -+ } -+ } -+ -+ /// Returns the coordinate in the filter tensor w that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int k = offset_k_[iteration_strided_]; -+ int c = offset_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements; -+ -+ return TensorCoord(k, filter_r_, filter_s_, c); -+ } -+ -+ /// Returns true if the current coordinate is within the filter tensor w -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.K && coord.c() < problem_size_.C; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradFilterTileAccessIteratorAnalytic &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..bd5aa70 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h -@@ -0,0 +1,619 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dDgradFilterTileAccessIteratorOptimized; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conv2dDgradFilterTileAccessIteratorOptimized unity strided dgrad is more performant for dgrad -+// on problem sizes with stride = {1x1} -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ -+> -+class Conv2dDgradFilterTileAccessIteratorOptimized < -+ Shape_, -+ Element_, -+ ThreadMap_, -+ conv::StrideSupport::kStrided, -+ AccessType_ -+ > { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params : Conv2dStridedDgradFilterIteratorOptimizedParams { -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Conv2dStridedDgradFilterIteratorOptimizedParams const &base): -+ Conv2dStridedDgradFilterIteratorOptimizedParams(base) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout -+ ): -+ Conv2dStridedDgradFilterIteratorOptimizedParams( -+ problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} -+ ) { } -+ -+ }; -+ -+private: -+ -+ Conv2dStridedDgradFilterIteratorOptimizedParams const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ uint32_t predicates_[kAccessesPerVector]; -+ int filter_k_; -+ int 
filter_r_; -+ int filter_s_; -+ -+ int start_r_; -+ int start_s_; -+ -+ int64_t reset_bytes_s_; -+ int64_t reset_bytes_r_; -+ -+ // -+ // Assertions -+ // -+ -+ // We map predicates into bits packed in this uint32_t container -+ static_assert(ThreadMap::Iterations::kStrided * -+ ThreadMap::Iterations::kContiguous < sizeof(predicates_) * 8, -+ "Currently, the number of loads per iteration is limited by the size of the predicates container."); -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradFilterTileAccessIteratorOptimized( -+ Conv2dStridedDgradFilterIteratorOptimizedParams const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ int start_r, int start_s, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ predicates_{0}, -+ filter_r_(start_r), -+ filter_s_(start_s), -+ start_r_(start_r), -+ start_s_(start_s) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.row() + thread_coord.strided(); -+ Index column = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ reset_bytes_s_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0]; -+ reset_bytes_r_ = reset_bytes_s_ + -+ (problem_size_.num_gemm_k_filter_r(start_r_) - 1) * params_.inc_next[1]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided; -+ int filter_c = column + c * ThreadMap::Delta::kContiguous; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ uint32_t pred = ((filter_k < problem_size_.K && (filter_c + v * AccessType::kElements) < problem_size_.C) ? 1u : 0); -+ -+ int pred_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ predicates_[v] |= (pred << pred_idx); -+ } -+ } -+ } -+ -+ TensorCoord coord{filter_k_, filter_r_, filter_s_, column}; -+ -+ pointer_ += params_.layout(coord) * sizeof_bits::value / 8; -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_DEVICE -+ void advance() { -+ -+ int next_idx = 0; -+ LongIndex reset_bytes = params_.reset_bytes; -+ -+ // Move filter_s by stride_w -+ filter_s_ += problem_size_.stride_w; -+ if (filter_s_ >= problem_size_.S) { -+ -+ // Restore filter_s -+ filter_s_ = start_s_; -+ -+ // Move filter_r by stride_h -+ filter_r_ += problem_size_.stride_h; -+#if 0 -+ bool check = (filter_r_ < problem_size_.R); -+ -+ filter_r_ = check ? filter_r_ : start_r_; -+ next_idx = check ? 1 : 2; -+ reset_bytes += (check ? 
reset_bytes_s_ : reset_bytes_r_); -+#else -+ asm volatile( -+ "{\n\t" -+ " .reg .pred %%p;\n\t" -+ " .reg .s64 t1;\n\t" -+ " setp.lt.s32 %%p, %3, %4;\n\t" -+ " selp.s32 %0, %3, %5, %%p;\n\t" -+ " selp.s32 %1, 1, 2, %%p;\n\t" -+ " selp.s64 t1, %6, %7, %%p;\n\t" -+ " add.s64 %2, %8, t1;\n\t" -+ "}\n" -+ : "=r"(filter_r_), "=r"(next_idx), "=l"(reset_bytes) -+ : "r"(filter_r_), "r"(problem_size_.R), "r"(start_r_), -+ "l"(reset_bytes_s_), "l"(reset_bytes_r_), "l"(reset_bytes)); -+#endif -+ } -+ -+ // offset pointers by offset_bytes -+ pointer_ += (params_.inc_next[next_idx] - reset_bytes); -+ -+ if (next_idx == 2) { -+ filter_k_ += params_.filter_k_delta; -+ } -+ -+ // Clear predicates if needed -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) { -+ uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ predicates_[v] = (predicates_[v] & (~kClearMask)); -+ } -+ } -+ } -+ } -+ -+ /// Returns true if the current coordinate is within the filter tensor W -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; -+ return (predicates_[iteration_vector_] & (1u << pred_idx)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ return reinterpret_cast(pointer_ + -+ iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits::value / 8) + iteration_vector_; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradFilterTileAccessIteratorOptimized &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ -+ // Move to the next K coordinate within the tile -+ pointer_ += params_.inc_next_strided; -+ -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. 
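The optimized iterators fold their bounds checks into one uint32_t of predicate bits per vector access: bit c + s * Iterations::kContiguous guards iteration (c, s), an entire strided row of bits is cleared in one mask operation when filter_k_ runs past K, and valid() reduces to a single bit test (the inline PTX in advance() above is a branchless form of the #if 0 C++ path). A small sketch of that packing, with illustrative iteration counts:

#include <cassert>
#include <cstdint>

int main() {
  const int kContiguous = 4, kStrided = 4;   // 16 predicate bits, fits in 32
  uint32_t predicates = 0;

  // Pack one guard bit per (contiguous, strided) iteration.
  for (int s = 0; s < kStrided; ++s) {
    for (int c = 0; c < kContiguous; ++c) {
      uint32_t in_bounds = 1u;               // stand-in for the K/C bounds test
      predicates |= in_bounds << (c + s * kContiguous);
    }
  }
  assert(predicates == 0xFFFFu);

  // Clearing: drop every bit of strided row s = 2 in one mask operation,
  // as advance() does when filter_k_ + s * Delta::kStrided >= K.
  int s = 2;
  uint32_t clear_mask = ((1u << kContiguous) - 1) << (s * kContiguous);
  predicates &= ~clear_mask;
  assert(predicates == 0xF0FFu);

  // valid() for iteration (c, s) is then a single bit test.
  auto valid = [&](int c, int s_) {
    return (predicates & (1u << (c + s_ * kContiguous))) != 0;
  };
  assert(!valid(0, 2) && valid(0, 3));
  return 0;
}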
-+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conv2dDgradFilterTileAccessIteratorOptimized unity strided dgrad is more performant for dgrad -+// on problem sizes with stride = {1x1} -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ -+> -+class Conv2dDgradFilterTileAccessIteratorOptimized < -+ Shape_, -+ Element_, -+ ThreadMap_, -+ conv::StrideSupport::kUnity, -+ AccessType_ -+ > { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params : Conv2dDgradFilterIteratorOptimizedParams { -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Conv2dDgradFilterIteratorOptimizedParams const &base): -+ Conv2dDgradFilterIteratorOptimizedParams(base) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout -+ ): -+ Conv2dDgradFilterIteratorOptimizedParams( -+ problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} -+ ) { } -+ -+ }; -+ -+private: -+ -+ Conv2dDgradFilterIteratorOptimizedParams const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ uint32_t predicates_[kAccessesPerVector]; -+ int filter_rs_; -+ int filter_k_; -+ -+ // -+ // Assertions -+ // -+ -+ // We map predicates into bits packed in this uint32_t container -+ static_assert(ThreadMap::Iterations::kStrided * -+ ThreadMap::Iterations::kContiguous < sizeof(predicates_) * 8, -+ "Currently, the number of loads per iteration is limited by the size of the predicates container."); -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradFilterTileAccessIteratorOptimized( -+ Conv2dDgradFilterIteratorOptimizedParams const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ predicates_{0}, -+ filter_rs_(0), -+ 
filter_k_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.row() + thread_coord.strided(); -+ Index column = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided; -+ int filter_c = column + c * ThreadMap::Delta::kContiguous; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ uint32_t pred = ((filter_k < problem_size_.K && (filter_c + v * AccessType::kElements) < problem_size_.C) ? 1u : 0); -+ -+ int pred_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ predicates_[v] |= (pred << pred_idx); -+ } -+ } -+ } -+ -+ pointer_ += ( -+ filter_k_ * params.layout.stride()[2] + column -+ ) * sizeof_bits::value / 8; -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ LongIndex next = params_.inc_next_rs; -+ -+ // moves to the next tile -+ ++filter_rs_; -+ if (filter_rs_ == params_.RS) { -+ -+ filter_rs_ = 0; -+ next = params_.inc_next_k; -+ filter_k_ += params_.filter_k_delta; -+ } -+ -+ // Clear predicates if needed -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) { -+ uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ predicates_[v] = (predicates_[v] & (~kClearMask)); -+ } -+ } -+ } -+ -+ pointer_ += next; -+ } -+ -+ /// Returns true if the current coordinate is within the filter tensor W -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; -+ return (predicates_[iteration_vector_] & (1u << pred_idx)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ return reinterpret_cast(pointer_ + -+ iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits::value / 8) + iteration_vector_; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradFilterTileAccessIteratorOptimized &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ -+ // Move to the next K coordinate within the tile -+ pointer_ += params_.inc_next_strided; -+ -+ return *this; -+ } -+ 
iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..08f5465 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h -@@ -0,0 +1,606 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/functional.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ conv::StrideSupport StrideSupport_ = conv::StrideSupport::kStrided, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dDgradOutputGradientTileAccessIteratorAnalytic; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conv2dDgradOutputGradientTileAccessIteratorAnalytic strided dgrad needs special handling using -+// unscaled coordinations -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ -+> -+class Conv2dDgradOutputGradientTileAccessIteratorAnalytic < -+ Shape_, -+ Element_, -+ ThreadMap_, -+ conv::StrideSupport::kStrided, -+ AccessType_ -+> { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ static_assert(sizeof_bits::value >= 8, -+ "DGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Simpligying assertions -+ // -+ -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ int filter_k_; -+ int filter_r_; -+ int filter_s_; -+ int start_r_; -+ int start_s_; -+ -+ int offset_n_[ThreadMap::Iterations::kStrided]; -+ int offset_p_[ThreadMap::Iterations::kStrided]; -+ int offset_q_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, -+ 
int start_r, int start_s, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ filter_k_(0), -+ filter_r_(start_r), -+ filter_s_(start_s), -+ start_r_(start_r), -+ start_s_(start_s) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ int filter_r = filter_r_; -+ int filter_s = filter_s_; -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ filter_r = (problem_size_.R - 1 - filter_r); -+ filter_s = (problem_size_.S - 1 - filter_s); -+ } -+ -+ // Starting h, w positions for filter position in gemm_k=0 -+ int start_h, start_w; -+ strided_dgrad_starting_coords( -+ problem_size_, -+ stride_h_divmod, stride_w_divmod, -+ filter_r, filter_s, -+ start_h, start_w); -+ -+ // Effective P and Q for filter position required for remapping NHW rows -+ int P = (problem_size_.H - start_h + problem_size_.stride_h - 1) / problem_size_.stride_h; -+ int Q = (problem_size_.W - start_w + problem_size_.stride_w - 1) / problem_size_.stride_w; -+ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ int offset_npq = (threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided) % params_.tiled_rows_per_filter; -+ -+ // (STEP 1) [reorder NHW rows to start with same filter positions] -+ offset_n_[s] = offset_npq / (P * Q); -+ int residual = offset_npq % (P * Q); -+ -+ int p = (residual / Q); -+ int q = (residual % Q); -+ -+ int mapped_h = (start_h + p * problem_size_.stride_h); -+ int mapped_w = (start_w + q * problem_size_.stride_w); -+ -+ // Access (p, q) coordinates for Dy tensor and a filter position in gemm_k=0 -+ // note that (h + pad_h - filter_r) and (w + pad_w - filter_s) are divisible -+ // by stride_h and stride_w -+ offset_p_[s] = (mapped_h + problem_size_.pad_h - filter_r) / problem_size_.stride_h; -+ offset_q_[s] = (mapped_w + problem_size_.pad_w - filter_s) / problem_size_.stride_w; -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ // Move filter_s by stride_w -+ filter_s_ += problem_size_.stride_w; -+ if (filter_s_ < problem_size_.S) { -+ return; -+ } -+ -+ // Restore filter_s -+ filter_s_ = start_s_; -+ -+ // Move filter_r by stride_h -+ filter_r_ += problem_size_.stride_h; -+ if (filter_r_ < problem_size_.R) { -+ return; -+ } -+ -+ // Restore filter_r -+ filter_r_ = start_r_; -+ -+ // Move filter_k -+ filter_k_ += Shape_::kColumn * problem_size_.split_k_slices; -+ } -+ -+ /// Returns the coordinate in the output tensor Dy that is currently pointed to -+ /// by the iterator. 
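
For reference, a minimal host-only C++ sketch of the gemm-k traversal that the strided-dgrad advance() above implements: the filter column s moves in steps of stride_w and wraps back to start_s, then the filter row r moves in steps of stride_h and wraps back to start_r, and only then does the iterator move on to the next K block. The sizes and offsets below are illustrative assumptions, not values taken from the patch.

#include <cstdio>

int main() {
  const int R = 3, S = 3, stride_h = 2, stride_w = 2;   // assumed problem size
  const int start_r = 1, start_s = 0;                   // assumed starting filter position
  int r = start_r, s = start_s, k_block = 0;

  for (int step = 0; step < 8; ++step) {
    std::printf("step=%d r=%d s=%d k_block=%d\n", step, r, s, k_block);
    s += stride_w;                 // advance(): move filter_s_ by stride_w
    if (s < S) continue;
    s = start_s;                   // wrap s, then move filter_r_ by stride_h
    r += stride_h;
    if (r < R) continue;
    r = start_r;                   // wrap r, then move to the next K block
    ++k_block;
  }
  return 0;
}
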
-+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ int n = offset_n_[iteration_strided_]; -+ int p = offset_p_[iteration_strided_]; -+ int q = offset_q_[iteration_strided_]; -+ -+ int conv_sign = (problem_size_.mode == Mode::kConvolution ? 1 : -1); -+ -+ p += (conv_sign * (filter_r_ / problem_size_.stride_h)); -+ q += (conv_sign * (filter_s_ / problem_size_.stride_w)); -+ -+ int k = filter_k_ + iteration_vector_ * AccessType::kElements; -+ -+ return TensorCoord( -+ n, -+ p, -+ q, -+ k); -+ } -+ -+ -+ /// Returns true if the current coordinate is within the output tensor Dy -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ return -+ coord.n() < problem_size_.N && -+ coord.h() >= 0 && coord.h() < problem_size_.P && -+ coord.w() >= 0 && coord.w() < problem_size_.Q && -+ coord.c() < problem_size_.K; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientTileAccessIteratorAnalytic &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. 
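
For reference, the feasibility test that follows reduces to a divisibility condition: every global access moves AccessType::kElements elements at once, so the extent along the iterator's contiguous dimension must be a multiple of that vector width. A minimal host-only C++ sketch of the same check, with assumed values rather than the patch's CUTLASS types:

#include <cstdio>

// Mirrors the "problem_size.K % AccessType::kElements" style of check.
bool can_vectorize(int extent, int elements_per_access) {
  return (extent % elements_per_access) == 0;
}

int main() {
  std::printf("%d\n", can_vectorize(64, 8));  // 1: 64 channels, 8-wide accesses
  std::printf("%d\n", can_vectorize(62, 8));  // 0: would require a partial vector
  return 0;
}
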
-+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conv2dDgradOutputGradientTileAccessIteratorAnalytic for unity strides can be optimized by -+// eliminating modulo arithmetic to compute unscaled coordinates -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ -+> -+class Conv2dDgradOutputGradientTileAccessIteratorAnalytic < -+ Shape_, -+ Element_, -+ ThreadMap_, -+ conv::StrideSupport::kUnity, -+ AccessType_ -+> { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ static_assert(sizeof_bits::value >= 8, -+ "DGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Simpligying assertions -+ // -+ -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params { -+ -+ Layout layout; -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout -+ ): layout(layout) { -+ -+ } -+ }; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ int filter_k_; -+ int filter_r_; -+ int filter_s_; -+ -+ int offset_n_[ThreadMap::Iterations::kStrided]; -+ int offset_w_[ThreadMap::Iterations::kStrided]; -+ int offset_h_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ filter_k_(0), -+ filter_r_(0), -+ filter_s_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ int offset_nhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ -+ offset_n_[s] = offset_nhw / (problem_size_.H * 
problem_size_.W); -+ int residual = offset_nhw % (problem_size_.H * problem_size_.W); -+ -+ offset_h_[s] = residual / problem_size_.W; -+ offset_w_[s] = residual % problem_size_.W; -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, layout); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // move to the next tile -+ ++filter_s_; -+ if (filter_s_ < problem_size_.S) { -+ return; -+ } -+ filter_s_ = 0; -+ ++filter_r_; -+ if (filter_r_ < problem_size_.R) { -+ return; -+ } -+ filter_r_ = 0; -+ -+ filter_k_ += Shape_::kColumn * problem_size_.split_k_slices; -+ } -+ -+ /// Returns the coordinate in the output tensor Dy that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int n = offset_n_[iteration_strided_]; -+ int h = offset_h_[iteration_strided_]; -+ int w = offset_w_[iteration_strided_]; -+ -+ int r = filter_r_; -+ int s = filter_s_; -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ r = (problem_size_.R - 1 - r); -+ s = (problem_size_.S - 1 - s); -+ } -+ -+ int p = (h + problem_size_.pad_h - r * problem_size_.dilation_h) / problem_size_.stride_h; -+ int q = (w + problem_size_.pad_w - s * problem_size_.dilation_w) / problem_size_.stride_w; -+ -+ int k = filter_k_ + iteration_vector_ * AccessType::kElements; -+ -+ return TensorCoord(n, p, q, k); -+ } -+ -+ /// Returns true if the current coordinate is within the output tensor Dy -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.N && -+ coord.h() >= 0 && coord.h() < problem_size_.P && -+ coord.w() >= 0 && coord.w() < problem_size_.Q && -+ coord.c() < problem_size_.K; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientTileAccessIteratorAnalytic &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. 
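
For reference, a minimal host-only C++ sketch of the coordinate mapping used by the unity-stride at()/valid() pair above: with stride_h == 1, an activation row h and filter row r map to the output-gradient row p = h + pad_h - r*dilation_h, and the access is predicated on 0 <= p < P. The sizes below are illustrative assumptions only.

#include <cstdio>

int main() {
  const int H = 8, P = 8, R = 3;                // assumed tensor and filter extents
  const int pad_h = 1, dilation_h = 1;          // assumed conv parameters
  for (int h = 0; h < H; ++h) {
    for (int r = 0; r < R; ++r) {
      int p = h + pad_h - r * dilation_h;       // stride_h == 1 specialization
      bool valid = (p >= 0 && p < P);           // same predicate as valid()
      std::printf("h=%d r=%d -> p=%d valid=%d\n", h, r, p, int(valid));
    }
  }
  return 0;
}
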
-+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // Conv2dDgradFilterTileAccessIteratorAnalytic unity stride specialization -+ // only supports (stride_h, stride_w) = (1, 1) -+ if (problem_size.stride() != MatrixCoord({1, 1})) { -+ return Status::kErrorNotSupported; -+ } -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..38d94ac ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h -@@ -0,0 +1,821 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
-+*/ -+ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dDgradOutputGradientTileAccessIteratorOptimized; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Conv2dDgradOutputGradientTileAccessIteratorOptimized strided dgrad needs special handling -+// to skip MMAs (Dx = Dy * w) on invalid filter positions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ -+> -+class Conv2dDgradOutputGradientTileAccessIteratorOptimized < -+ Shape_, -+ Element_, -+ ThreadMap_, -+ conv::StrideSupport::kStrided, -+ AccessType_ -+> { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ using Mask = uint64_t; -+ -+ static_assert(sizeof_bits::value >= 8, -+ "DGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Simpligying assertions -+ // -+ -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dStridedDgradOutputGradientIteratorOptimizedParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ -+ // One pointer per access -+ char const *pointer_[ThreadMap::Iterations::kStrided]; -+ -+ int filter_k_; -+ int filter_r_; -+ int filter_s_; -+ int start_r_; -+ int start_s_; -+ int64_t reset_bytes_s_; -+ int64_t reset_bytes_r_; -+ -+ Index 
masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientTileAccessIteratorOptimized( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, -+ int start_r, int start_s, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ filter_k_(0), -+ filter_r_(start_r), -+ filter_s_(start_s), -+ start_r_(start_r), -+ start_s_(start_s) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ reset_bytes_s_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0]; -+ -+ reset_bytes_r_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0] + -+ (problem_size_.num_gemm_k_filter_r(start_r_) - 1) * params_.inc_next[1]; -+ -+ int offset_n[ThreadMap::Iterations::kStrided]; -+ int offset_p[ThreadMap::Iterations::kStrided]; -+ int offset_q[ThreadMap::Iterations::kStrided]; -+ -+ int filter_r = filter_r_; -+ int filter_s = filter_s_; -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ filter_r = (problem_size_.R - 1 - filter_r); -+ filter_s = (problem_size_.S - 1 - filter_s); -+ } -+ -+ // Starting h, w positions for filter position in gemm_k=0 -+ int start_h, start_w; -+ strided_dgrad_starting_coords( -+ problem_size_, -+ stride_h_divmod, stride_w_divmod, -+ filter_r, filter_s, -+ start_h, start_w); -+ -+ -+ // Effective starting P and Q for filter position required for remapping NHW rows -+ int P = (problem_size_.H - start_h + problem_size_.stride_h - 1) / problem_size_.stride_h; -+ int Q = (problem_size_.W - start_w + problem_size_.stride_w - 1) / problem_size_.stride_w; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ pointer_[s] = reinterpret_cast(ptr); -+ -+ int offset_npq = (threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided) % params_.tiled_rows_per_filter; -+ -+ // (STEP 1) [reorder NHW rows to start with same filter positions] -+ offset_n[s] = offset_npq / (P * Q); -+ int residual = offset_npq % (P * Q); -+ -+ int p = (residual / Q); -+ int q = (residual % Q); -+ -+ int mapped_h = (start_h + p * problem_size_.stride_h); -+ int mapped_w = (start_w + q * problem_size_.stride_w); -+ -+ // Access (p, q) coordinates for Dy tensor for filter position in gemm_k=0 -+ // note that (h + pad_h - filter_r) and (w + pad_w - filter_s) are ensured to be -+ // divisible by stride_h and stride_w -+ offset_p[s] = (mapped_h + problem_size_.pad_h - filter_r) / problem_size_.stride_h; -+ offset_q[s] = (mapped_w + problem_size_.pad_w - filter_s) / problem_size_.stride_w; -+ -+ // Intialize pointers for gemm_k=0 -+ TensorCoord coord{offset_n[s], offset_p[s], offset_q[s], filter_k_}; -+ -+ pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; -+ } -+ -+ // -+ // Precompute mask predicates -+ // -+ clear_mask(); -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int r = start_r; r < problem_size_.R; r += problem_size_.stride_h) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int p = offset_p[s_idx] ; -+ -+ p += (params_.conv_sign * (r / problem_size_.stride_h)); -+ -+ bool pred = (offset_n[s_idx] < problem_size_.N && p >= 0 && p < 
problem_size_.P); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ masks_[s_idx][v_idx][0] |= (pred << r); -+ } -+ } -+ } -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for(int s = start_s; s < problem_size_.S; s += problem_size_.stride_w) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int q = offset_q[s_idx]; -+ q += (params_.conv_sign * (s / problem_size_.stride_w)); -+ -+ bool pred = (q >=0 && q < problem_size_.Q); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ masks_[s_idx][v_idx][1] |= (pred << s); -+ } -+ } -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size.K); -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}); -+ } -+ -+private: -+ -+ /// Adds a pointer offset in units of element -+ CUTLASS_HOST_DEVICE -+ void add_byte_offset_(LongIndex byte_offset, LongIndex byte_reset = 0) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ pointer_[s] += byte_offset - byte_reset; -+ } -+ } -+ -+public: -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ add_byte_offset_(pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_DEVICE -+ void advance() { -+ -+ int next_idx = 0; -+ int64_t reset_bytes = 0; -+ -+ // Move filter_s by stride_w -+ filter_s_ += problem_size_.stride_w; -+ if (filter_s_ >= problem_size_.S) { -+ -+ // Restore filter_s -+ filter_s_ = start_s_; -+ -+ // Move filter_r by stride_h -+ filter_r_ += problem_size_.stride_h; -+#if 0 -+ if (filter_r_ < problem_size_.R) { -+ -+ next_idx = 1; -+ -+ // Restore bytes in q coordinate (Mma in filter s dimenstion) -+ reset_bytes = reset_bytes_s_; -+ -+ } else { -+ -+ // Restore filter_r -+ filter_r_ = start_r_; -+ -+ next_idx = 2; -+ -+ // Restore bytes in p and q coordinate (Mma in filter s and r dimenstion) -+ reset_bytes = reset_bytes_r_; -+ } -+#else -+ asm volatile( -+ "{\n\t" -+ " .reg .pred %%p;\n\t" -+ " setp.lt.s32 %%p, %3, %4;\n\t" -+ " selp.s32 %0, %3, %5, %%p;\n\t" -+ " selp.s32 %1, 1, 2, %%p;\n\t" -+ " selp.s64 %2, %6, %7, %%p;\n\t" -+ "}\n" -+ : "=r"(filter_r_), "=r"(next_idx), "=l"(reset_bytes) -+ : "r"(filter_r_), "r"(problem_size_.R), "r"(start_r_), -+ "l"(reset_bytes_s_), "l"(reset_bytes_r_)); -+#endif -+ } -+ -+ // offset pointers by offset_bytes -+ add_byte_offset_(params_.inc_next[next_idx] - reset_bytes); -+ -+ if (next_idx == 2) { -+ filter_k_ += params_.filter_k_delta; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size_.K); -+ } -+ } -+ -+ /// Clears the predicates -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool clear = true) { -+ 
CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0]; -+ masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1]; -+ } -+ } -+ } -+ -+ /// Clears the predicates -+ CUTLASS_HOST_DEVICE -+ void clear_mask(int v, bool clear = true) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0]; -+ masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1]; -+ } -+ } -+ -+ /// Returns true if the current coordinate is within the output tensor Dy -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ return -+ (masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) && -+ (masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ return reinterpret_cast(pointer_[iteration_strided_]) + iteration_vector_; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientTileAccessIteratorOptimized &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. 
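
For reference, a simplified host-only C++ sketch of the predicate scheme used above: one bit per filter row r is packed into masks_[..][..][0] and one bit per filter column s into masks_[..][..][1], and valid() simply ANDs the (r, s) bits. This is also why can_implement() below rejects filters with R or S larger than 32, since each extent must fit in the mask word. The bounds and base coordinates here are assumptions chosen purely for illustration.

#include <cstdint>
#include <cstdio>

int main() {
  const int R = 3, S = 3, P = 4, Q = 4;          // assumed filter/output extents
  const int p0 = 3, q0 = 0;                      // assumed base output coordinates
  uint64_t row_mask = 0, col_mask = 0;

  for (int r = 0; r < R; ++r) {
    bool pred = (p0 - r >= 0) && (p0 - r < P);   // row-direction predicate, one bit per r
    row_mask |= (uint64_t(pred) << r);
  }
  for (int s = 0; s < S; ++s) {
    bool pred = (q0 - s >= 0) && (q0 - s < Q);   // column-direction predicate, one bit per s
    col_mask |= (uint64_t(pred) << s);
  }

  // valid() style test: bit r of the row mask AND bit s of the column mask.
  for (int r = 0; r < R; ++r)
    for (int s = 0; s < S; ++s)
      std::printf("r=%d s=%d valid=%d\n", r, s,
                  int((row_mask >> r) & (col_mask >> s) & 1u));
  return 0;
}
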
-+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // Limit on filter size -+ if (problem_size.R > 32 || problem_size.S > 32) { -+ return Status::kErrorNotSupported; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Conv2dDgradOutputGradientTileAccessIteratorOptimized unity stride dgrad is optimized for dgrad -+// with problem stride = {1x1} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ -+> -+class Conv2dDgradOutputGradientTileAccessIteratorOptimized < -+ Shape_, -+ Element_, -+ ThreadMap_, -+ conv::StrideSupport::kUnity, -+ AccessType_ -+> { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ using Mask = uint64_t; -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dDgradOutputGradientIteratorOptimizedParams; -+ -+private: -+ -+ Conv2dDgradOutputGradientIteratorOptimizedParams const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ -+ // One pointer per access -+ char const *pointer_[ThreadMap::Iterations::kStrided]; -+ -+ // current filter position (r, s) -+ int filter_r_; -+ int filter_s_; -+ int filter_k_; -+ -+ Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientTileAccessIteratorOptimized( -+ Conv2dDgradOutputGradientIteratorOptimizedParams const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ filter_k_(0), -+ filter_r_(0), -+ filter_s_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ int offset_n[ThreadMap::Iterations::kStrided]; -+ int offset_h[ThreadMap::Iterations::kStrided]; -+ int offset_w[ThreadMap::Iterations::kStrided]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < 
ThreadMap::Iterations::kStrided; ++s) { -+ -+ pointer_[s] = reinterpret_cast(ptr); -+ -+ int offset_nhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ -+ // The subseqnet fast_divmod() operations are equivalent to the following logical computation: -+ // -+ // -+ // offset_n[s] = offset_nhw / (problem_size_.H * problem_size_.W); -+ // int residual = offset_nhw % (problem_size_.H * problem_size_.W); -+ // -+ // offset_h[s] = residual / problem_size_.W; -+ // offset_w[s] = residual % problem_size_.W; -+ // -+ -+ int residual; -+ -+ params_.hw_divmod(offset_n[s], residual, offset_nhw); -+ params_.w_divmod(offset_h[s], offset_w[s], residual); -+ -+ TensorCoord coord = at_(offset_n[s], offset_h[s], offset_w[s], 0, 0); -+ -+ pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; -+ } -+ -+ clear_mask(); -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int r = 0; r < problem_size_.R; ++r) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int r_ = r; -+ if (problem_size_.mode == Mode::kConvolution) { -+ r_ = problem_size_.R - 1 - r; -+ } -+ -+ int p = offset_h[s_idx] + problem_size_.pad_h - r_ * problem_size_.dilation_h; -+ -+ bool pred = (offset_n[s_idx] < problem_size_.N && p >= 0 && p < problem_size_.P); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ masks_[s_idx][v_idx][0] |= (pred << r); -+ } -+ } -+ } -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int s = 0; s < problem_size_.S; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int s_ = s; -+ if (problem_size_.mode == Mode::kConvolution) { -+ s_ = problem_size_.S - 1 - s; -+ } -+ -+ int q = offset_w[s_idx] + problem_size_.pad_w - s_ * problem_size_.dilation_w; -+ -+ bool pred = (q >= 0 && q < problem_size_.Q); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ masks_[s_idx][v_idx][1] |= (pred << s); -+ } -+ } -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ clear_mask(v_idx, filter_k_ + v_idx * AccessType::kElements >= problem_size.K); -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); -+ } -+ -+private: -+ -+ /// Returns the coordinate in the output gradient tensor dy that is correspoinding to -+ // activation nhw and filter position k, r, s -+ CUTLASS_HOST_DEVICE -+ TensorCoord at_(int n, int h, int w, int r, int s) const { -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ r = problem_size_.R - 1 - r; -+ s = problem_size_.S - 1 - s; -+ } -+ -+ int p = h + problem_size_.pad_h - r * problem_size_.dilation_h; -+ int q = w + problem_size_.pad_w - s * problem_size_.dilation_w; -+ -+ return TensorCoord(n, p, q, filter_k_); -+ } -+ -+ /// Adds a pointer offset in units of element -+ CUTLASS_HOST_DEVICE -+ void add_byte_offset_(LongIndex byte_offset) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ pointer_[s] += byte_offset; -+ } -+ } -+ -+public: -+ -+ /// Overrides the internal iteration index -+ 
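
For reference, a minimal host-only C++ sketch of the flat-index decomposition performed by set_iteration_index() below: the index splits into a vector component (fastest), then a contiguous component, then a strided component. The iteration counts here are illustrative stand-ins, not values from a concrete ThreadMap.

#include <cstdio>

int main() {
  const int kAccessesPerVector   = 2;   // assumed
  const int kIterationsContiguous = 1;  // the iterators assert this is 1
  const int kIterationsStrided   = 4;   // assumed

  const int total = kAccessesPerVector * kIterationsContiguous * kIterationsStrided;
  for (int index = 0; index < total; ++index) {
    int vector     = index % kAccessesPerVector;
    int residual   = index / kAccessesPerVector;
    int contiguous = residual % kIterationsContiguous;
    int strided    = residual / kIterationsContiguous;
    std::printf("index=%d -> vector=%d contiguous=%d strided=%d\n",
                index, vector, contiguous, strided);
  }
  return 0;
}
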
CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ add_byte_offset_(pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ int next_idx = 0; -+ -+ // moves to the next tile -+ ++filter_s_; -+ if (filter_s_ == problem_size_.S) { -+ filter_s_ = 0; -+ ++filter_r_; -+ -+ if (filter_r_ < problem_size_.R) { -+ next_idx = 1; -+ } -+ else { -+ filter_r_ = 0; -+ next_idx = 2; -+ } -+ } -+ -+ add_byte_offset_(params_.inc_next[next_idx]); -+ -+ if (next_idx == 2) { -+ filter_k_ += params_.filter_k_delta; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size_.K); -+ } -+ } -+ -+ /// Clears the predicates -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool clear = true) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0]; -+ masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1]; -+ } -+ } -+ } -+ -+ /// Clears the predicates -+ CUTLASS_HOST_DEVICE -+ void clear_mask(int v, bool clear = true) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0]; -+ masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1]; -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ -+ return -+ (masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) && -+ (masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ return reinterpret_cast(pointer_[iteration_strided_]) + iteration_vector_; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientTileAccessIteratorOptimized &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. 
-+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // This is specialized for unit stride -+ if (problem_size.stride() != MatrixCoord({1, 1})) { -+ return Status::kErrorNotSupported; -+ } -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % AccessType::kElements) { -+ return Status::kErrorNotSupported; -+ } -+ -+ // Limit on filter size -+ if (problem_size.R > 32 || problem_size.S > 32) { -+ return Status::kErrorNotSupported; -+ } -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..e667ddd ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h -@@ -0,0 +1,332 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC or TensorNCxHWx layout of tensors in Global Memory. 
-+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray, -+ conv::GroupMode GroupMode_ = conv::GroupMode::kNone -+> -+class Conv2dFpropActivationTileAccessIteratorAnalytic { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ static conv::GroupMode const kGroupMode = GroupMode_; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dAnalyticParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ int filter_c_; -+ int filter_r_; -+ int filter_s_; -+ int filter_c_init_; -+ int group_idx_offset_; -+ int channels_per_group_; -+ int crs_cnt_; -+ int crs_per_group_; -+ -+ int offset_n_[ThreadMap::Iterations::kStrided]; -+ int offset_p_[ThreadMap::Iterations::kStrided]; -+ int offset_q_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ crs_cnt_(0), -+ group_idx_offset_(0), -+ filter_c_(0), -+ filter_r_(0), -+ filter_s_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ 
filter_c_ = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ if (kGroupMode != conv::GroupMode::kNone) { -+ filter_c_init_ = filter_c_; -+ channels_per_group_ = problem_size_.C / problem_size_.groups; -+ crs_per_group_ = problem_size_.S * problem_size_.R * ((channels_per_group_ + Shape::kColumn - 1) / Shape::kColumn); -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ -+ offset_n_[s] = offset_npq / (problem_size_.P * problem_size_.Q); -+ int residual = offset_npq % (problem_size_.P * problem_size_.Q); -+ -+ offset_p_[s] = residual / problem_size_.Q; -+ offset_q_[s] = residual % problem_size_.Q; -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, layout); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next tile -+ if (kGroupMode != conv::GroupMode::kNone) { -+ ++crs_cnt_; -+ } -+ -+ ++filter_s_; -+ if (filter_s_ < problem_size_.S) { -+ return; -+ } -+ filter_s_ = 0; -+ ++filter_r_; -+ if (filter_r_ < problem_size_.R) { -+ return; -+ } -+ filter_r_ = 0; -+ -+ if (kGroupMode == conv::GroupMode::kNone) { -+ filter_c_ += Shape::kColumn * problem_size_.split_k_slices; -+ } else { -+ if (crs_cnt_ == crs_per_group_) { -+ // moves to next group -+ crs_cnt_ = 0; -+ ++group_idx_offset_; -+ filter_c_ = group_idx_offset_ * channels_per_group_ + filter_c_init_; -+ } else { -+ filter_c_ += Shape::kColumn * problem_size_.split_k_slices; -+ } -+ } -+ } -+ -+ /// Returns the coordinate in the activations tensor X that is currently pointed to -+ /// by the iterator. 
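
For reference, a minimal host-only C++ sketch of the fprop mapping computed by the at()/valid() pair below: an output row p and filter row r map to the activation row h = p*stride_h - pad_h + r*dilation_h, and the load is predicated on 0 <= h < H (out-of-range rows fall in the padding region). The sizes below are illustrative assumptions only.

#include <cstdio>

int main() {
  const int H = 5, P = 5, R = 3;                      // assumed tensor and filter extents
  const int stride_h = 1, pad_h = 1, dilation_h = 1;  // assumed conv parameters

  for (int p = 0; p < P; ++p) {
    for (int r = 0; r < R; ++r) {
      int h = p * stride_h - pad_h + r * dilation_h;  // same formula as at()
      std::printf("p=%d r=%d -> h=%d valid=%d\n", p, r, h, int(h >= 0 && h < H));
    }
  }
  return 0;
}
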
-+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ int n = offset_n_[iteration_strided_]; -+ int p = offset_p_[iteration_strided_]; -+ int q = offset_q_[iteration_strided_]; -+ -+ int r = filter_r_; -+ int s = filter_s_; -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ r = (problem_size_.R - 1 - filter_r_); -+ s = (problem_size_.S - 1 - filter_s_); -+ } -+ -+ int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; -+ int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; -+ -+ int c = filter_c_ + iteration_vector_ * AccessType::kElements; -+ -+ return TensorCoord(n, h, w, c); -+ } -+ -+ /// Returns true if the current coordinate is within the activations tensor X -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.N && -+ coord.h() >= 0 && coord.h() < problem_size_.H && -+ coord.w() >= 0 && coord.w() < problem_size_.W && -+ coord.c() < problem_size_.C; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ AccessType const *ptr = reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ -+ return ptr; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationTileAccessIteratorAnalytic &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.C % 32) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.C % 64) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h -new file mode 100644 -index 0000000..1b668ce ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h -@@ -0,0 +1,360 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC or TensorNCxHWx layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dFpropActivationTileAccessIteratorFewChannels { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kFewChannels; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ static int const kPositionsPerTile = Shape::kColumn; -+ -+ static int const kAccessesPerVector = kElementsPerAccess / AccessType::kElements; -+ -+ static bool const kUseFastDivmodPrologue = true; -+ static bool const kUseFastDivmodMainloop = true; -+ -+ static int const kStrideH = 0; -+ static int const kStrideW = 0; -+ static int const kDilationH = 0; -+ static int const kDilationW = 0; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dFewChannelsParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ int rsc_index_; -+ int offset_n_[ThreadMap::Iterations::kStrided]; -+ int offset_p_[ThreadMap::Iterations::kStrided]; -+ int offset_q_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationTileAccessIteratorFewChannels( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ rsc_index_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ rsc_index_ = (threadblock_offset.column() + thread_coord.contiguous()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < 
ThreadMap::Iterations::kStrided; ++s) { -+ int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ -+ if (kUseFastDivmodPrologue) { -+ int residual = params_.divmod_Q.divmod(offset_q_[s], offset_npq); -+ offset_n_[s] = params_.divmod_P.divmod(offset_p_[s], residual); -+ } -+ else { -+ offset_n_[s] = offset_npq / (problem_size_.P * problem_size_.Q); -+ int residual = offset_npq % (problem_size_.P * problem_size_.Q); -+ -+ offset_p_[s] = residual / problem_size_.Q; -+ offset_q_[s] = residual % problem_size_.Q; -+ } -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, layout); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ rsc_index_ += kPositionsPerTile * problem_size_.split_k_slices; -+ } -+ -+ /// Returns the coordinate in the activations tensor X that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ int n = offset_n_[iteration_strided_]; -+ int p = offset_p_[iteration_strided_]; -+ int q = offset_q_[iteration_strided_]; -+ -+ int rsc_index = rsc_index_ + iteration_vector_ * AccessType::kElements; -+ -+ int r = 0; -+ int s = 0; -+ int c = 0; -+ -+ if (kUseFastDivmodMainloop) { -+ int rs_index = params_.divmod_C.divmod(c, rsc_index); -+ r = params_.divmod_S.divmod(s, rs_index); -+ } -+ else { -+ c = (rsc_index % problem_size_.C); -+ -+ int rs_index = (rsc_index / problem_size_.C); -+ s = (rs_index % problem_size_.S); -+ r = (rs_index / problem_size_.S); -+ } -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ r = (problem_size_.R - 1 - r); -+ s = (problem_size_.S - 1 - s); -+ } -+ -+ int stride_h = kStrideH; -+ if (!kStrideH) { -+ stride_h = problem_size_.stride_h; -+ } -+ -+ int stride_w = kStrideW; -+ if (!kStrideW) { -+ stride_w = problem_size_.stride_w; -+ } -+ -+ int dilation_h = kDilationH; -+ if (!kDilationH) { -+ dilation_h = problem_size_.dilation_h; -+ } -+ -+ int dilation_w = kDilationW; -+ if (!kDilationW) { -+ dilation_w = problem_size_.dilation_w; -+ } -+ -+ int h = p * stride_h - problem_size_.pad_h + r * dilation_h; -+ int w = q * stride_w - problem_size_.pad_w + s * dilation_w; -+ -+ return TensorCoord(n, h, w, c); -+ } -+ -+ /// Returns true if the current coordinate is within the activations tensor X -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ bool in_bounds = -+ coord.n() < problem_size_.N && -+ coord.h() >= 0 && coord.h() < problem_size_.H && -+ coord.w() >= 0 && coord.w() < problem_size_.W && -+ coord.c() < problem_size_.C; -+ -+ return in_bounds; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ -+ int32_t offset = -+ coord.n() * params_.stride_n + -+ coord.h() * params_.stride_h + -+ coord.w() * 
params_.stride_w + -+ coord.c(); -+ -+ AccessType const *ptr = reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ -+ return ptr; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationTileAccessIteratorFewChannels &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (kDilationH && problem_size.dilation_h != kDilationH) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (kDilationW && problem_size.dilation_w != kDilationW) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (kStrideH && problem_size.stride_h != kStrideH) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (kStrideW && problem_size.stride_w != kStrideW) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.C % 32) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.C % 64) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h -new file mode 100644 -index 0000000..3e680f4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h -@@ -0,0 +1,353 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC or TensorNCxHWx layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dFpropActivationTileAccessIteratorFixedChannels { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kFixedChannels; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kFilterPositionsPerTile = Shape::kColumn / AccessType::kElements; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static bool const kUseFastDivmodPrologue = true; -+ static bool const kUseFastDivmodMainloop = true; -+ -+ static int const kStrideH = 0; -+ static int const kStrideW = 0; -+ static int const kDilationH = 0; -+ static int const kDilationW = 0; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by 
the access type."); -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dFewChannelsParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ int rs_index_; -+ int offset_n_[ThreadMap::Iterations::kStrided]; -+ int offset_p_[ThreadMap::Iterations::kStrided]; -+ int offset_q_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationTileAccessIteratorFixedChannels( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ rs_index_(0) { -+ -+ // -+ // This requires problem_size.C == AccessType::kElements -+ // -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ rs_index_ = (threadblock_offset.column() + thread_coord.contiguous()) / AccessType::kElements; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ -+ if (kUseFastDivmodPrologue) { -+ int residual = params_.divmod_Q.divmod(offset_q_[s], offset_npq); -+ offset_n_[s] = params_.divmod_P.divmod(offset_p_[s], residual); -+ } -+ else { -+ offset_n_[s] = offset_npq / (problem_size_.P * problem_size_.Q); -+ int residual = offset_npq % (problem_size_.P * problem_size_.Q); -+ -+ offset_p_[s] = residual / problem_size_.Q; -+ offset_q_[s] = residual % problem_size_.Q; -+ } -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, layout); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ rs_index_ += kFilterPositionsPerTile * problem_size_.split_k_slices; -+ } -+ -+ /// Returns the coordinate in the activations tensor X that is currently pointed to -+ /// by the iterator. 
-+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ int n = offset_n_[iteration_strided_]; -+ int p = offset_p_[iteration_strided_]; -+ int q = offset_q_[iteration_strided_]; -+ -+ int rs_index = rs_index_ + iteration_vector_; -+ -+ int r = 0; -+ int s = 0; -+ -+ if (kUseFastDivmodMainloop) { -+ r = params_.divmod_S.divmod(s, rs_index); -+ } -+ else { -+ s = (rs_index % problem_size_.S); -+ r = (rs_index / problem_size_.S); -+ } -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ r = (problem_size_.R - 1 - r); -+ s = (problem_size_.S - 1 - s); -+ } -+ -+ int stride_h = kStrideH; -+ if (!kStrideH) { -+ stride_h = problem_size_.stride_h; -+ } -+ -+ int stride_w = kStrideW; -+ if (!kStrideW) { -+ stride_w = problem_size_.stride_w; -+ } -+ -+ int dilation_h = kDilationH; -+ if (!kDilationH) { -+ dilation_h = problem_size_.dilation_h; -+ } -+ -+ int dilation_w = kDilationW; -+ if (!kDilationW) { -+ dilation_w = problem_size_.dilation_w; -+ } -+ -+ int h = p * stride_h - problem_size_.pad_h + r * dilation_h; -+ int w = q * stride_w - problem_size_.pad_w + s * dilation_w; -+ -+ return TensorCoord(n, h, w, 0); -+ } -+ -+ /// Returns true if the current coordinate is within the activations tensor X -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.N && -+ coord.h() >= 0 && coord.h() < problem_size_.H && -+ coord.w() >= 0 && coord.w() < problem_size_.W; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ -+ int32_t offset = -+ coord.n() * params_.stride_n + -+ coord.h() * params_.stride_h + -+ coord.w() * params_.stride_w + coord.c(); -+ -+ AccessType const *ptr = reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ -+ return ptr; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationTileAccessIteratorFixedChannels &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. 
-+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C != AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (kDilationH && problem_size.dilation_h != kDilationH) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (kDilationW && problem_size.dilation_w != kDilationW) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (kStrideH && problem_size.stride_h != kStrideH) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (kStrideW && problem_size.stride_w != kStrideW) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.C % 32) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.C % 64) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..fb1fcfc ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h -@@ -0,0 +1,422 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC or TensorNCxHWx layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dFpropActivationTileAccessIteratorOptimized { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ using Mask = uint64_t; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dFpropActivationIteratorOptimizedParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ -+ // One pointer per access -+ char const *pointer_[ThreadMap::Iterations::kStrided]; -+ -+ // current filter position (r, s) -+ int filter_r_; -+ int filter_s_; -+ int filter_c_; -+ -+ Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationTileAccessIteratorOptimized( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ filter_c_(0), -+ filter_r_(0), -+ filter_s_(0) 
{ -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_c_ = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ int offset_n[ThreadMap::Iterations::kStrided]; -+ int offset_p[ThreadMap::Iterations::kStrided]; -+ int offset_q[ThreadMap::Iterations::kStrided]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ pointer_[s] = reinterpret_cast<char const *>(ptr); -+ -+ int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ -+ // The subsequent fast_divmod() operations are equivalent to the following logical computation: -+ // -+ // -+ // offset_n[s] = offset_npq / (problem_size_.P * problem_size_.Q); -+ // int residual = offset_npq % (problem_size_.P * problem_size_.Q); -+ // -+ // offset_p[s] = residual / problem_size_.Q; -+ // offset_q[s] = residual % problem_size_.Q; -+ // -+ -+ int residual; -+ -+ params.pq_divmod(offset_n[s], residual, offset_npq); -+ params.q_divmod(offset_p[s], offset_q[s], residual); -+ -+ TensorCoord coord = at_(offset_n[s], offset_p[s], offset_q[s], 0, 0); -+ -+ pointer_[s] += params_.layout(coord) * sizeof_bits<Element>::value / 8; -+ } -+ -+ clear_mask(); -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int r = 0; r < problem_size_.R; ++r) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int r_ = r; -+ if (problem_size_.mode == Mode::kConvolution) { -+ r_ = problem_size_.R - 1 - r; -+ } -+ -+ int h = offset_p[s_idx] * problem_size_.stride_h - problem_size_.pad_h + r_ * problem_size_.dilation_h; -+ -+ bool pred = (offset_n[s_idx] < problem_size_.N && h >= 0 && h < problem_size_.H); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ masks_[s_idx][v_idx][0] |= (pred << r); -+ } -+ } -+ } -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int s = 0; s < problem_size_.S; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int s_ = s; -+ if (problem_size_.mode == Mode::kConvolution) { -+ s_ = problem_size_.S - 1 - s; -+ } -+ -+ int w = offset_q[s_idx] * problem_size_.stride_w - problem_size_.pad_w + s_ * problem_size_.dilation_w; -+ -+ bool pred = (w >= 0 && w < problem_size_.W); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ masks_[s_idx][v_idx][1] |= (pred << s); -+ } -+ } -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C); -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, -+ layout, -+ sizeof_bits<Element>::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); -+ } -+ -+private: -+ -+ /// Returns the coordinate in the activations tensor X that is corresponding to -+ // output npq and filter position r, s -+ CUTLASS_HOST_DEVICE -+ TensorCoord at_(int n, int p, int q, int r, int s) const { -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ r = problem_size_.R - 1 - r; -+ s = problem_size_.S - 1 - s; -+ } -+ -+ int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; -+ int w = q *
problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; -+ -+ return TensorCoord(n, h, w, filter_c_); -+ } -+ -+ /// Adds a pointer offset in units of element -+ CUTLASS_HOST_DEVICE -+ void add_byte_offset_(LongIndex byte_offset) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ pointer_[s] += byte_offset; -+ } -+ } -+ -+public: -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ add_byte_offset_(pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ int next_idx = 0; -+ -+ // moves to the next tile -+ ++filter_s_; -+ if (filter_s_ == problem_size_.S) { -+ filter_s_ = 0; -+ ++filter_r_; -+ -+ if (filter_r_ < problem_size_.R) { -+ next_idx = 1; -+ } -+ else { -+ filter_r_ = 0; -+ next_idx = 2; -+ } -+ } -+ -+ add_byte_offset_(params_.inc_next[next_idx]); -+ -+ if (next_idx == 2) { -+ filter_c_ += params_.filter_c_delta; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C); -+ } -+ } -+ -+ /// Clears the predicates -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool clear = true) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ masks_[s][v][0] = clear ? 0 : masks_[s][v][0]; -+ masks_[s][v][1] = clear ? 0 : masks_[s][v][1]; -+ } -+ } -+ } -+ -+ /// Clears the predicates -+ CUTLASS_HOST_DEVICE -+ void clear_mask(int v, bool clear = true) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ masks_[s][v][0] = clear ? 0 : masks_[s][v][0]; -+ masks_[s][v][1] = clear ? 0 : masks_[s][v][1]; -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ -+ return -+ (masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) && -+ (masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ return reinterpret_cast(pointer_[iteration_strided_]) + iteration_vector_; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationTileAccessIteratorOptimized &operator++() { -+ -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. 
-+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.C % 32) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.C % 64) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ // Conv2dFpropActivationTileAccessIteratorOptimized has constraint on filter positions -+ // due to the number of mask bits. -+ if (problem_size.R > 32 || problem_size.S > 32) { -+ return Status::kErrorNotSupported; -+ } -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..5c7dbd7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h -@@ -0,0 +1,319 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) -+ matrix from memory. 
-+ -+ This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray, -+ conv::GroupMode GroupMode_ = conv::GroupMode::kNone -+> -+class Conv2dFpropFilterTileAccessIteratorAnalytic { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ static conv::GroupMode const kGroupMode = GroupMode_; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dAnalyticParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ int filter_r_; -+ int filter_s_; -+ int filter_c_; -+ int filter_c_init_; -+ int crs_cnt_; -+ int crs_per_group_; -+ int group_idx_offset_c_; -+ int channels_per_group_; -+ -+ int offset_k_[ThreadMap::Iterations::kStrided]; -+ int group_idx_offset_k_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ crs_cnt_(0), -+ group_idx_offset_c_(0), -+ filter_r_(0), -+ filter_s_(0), -+ filter_c_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_c_ = threadblock_offset.row() + 
thread_coord.contiguous(); -+ -+ if (kGroupMode != conv::GroupMode::kNone) { -+ filter_c_init_ = filter_c_; -+ if (kGroupMode == conv::GroupMode::kDepthwise){ -+ channels_per_group_ = 1; -+ crs_per_group_ = problem_size_.S * problem_size_.R; -+ } else { -+ channels_per_group_ = problem_size_.C / problem_size_.groups; -+ crs_per_group_ = problem_size_.S * problem_size_.R * ((channels_per_group_ + Shape::kRow - 1) / Shape::kRow); -+ } -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ if (kGroupMode != conv::GroupMode::kNone && kGroupMode != conv::GroupMode::kDepthwise) { -+ group_idx_offset_k_[s] = (thread_coord.strided() + s * ThreadMap::Delta::kStrided) / (problem_size_.K / problem_size_.groups); -+ } -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * 8 / sizeof_bits::value; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next tile -+ if (kGroupMode != conv::GroupMode::kNone) { -+ ++crs_cnt_; -+ } -+ -+ ++filter_s_; -+ if (filter_s_ < problem_size_.S) { -+ return; -+ } -+ filter_s_ = 0; -+ -+ ++filter_r_; -+ if (filter_r_ < problem_size_.R) { -+ return; -+ } -+ filter_r_ = 0; -+ -+ if (kGroupMode == conv::GroupMode::kNone) { -+ filter_c_ += Shape::kRow * problem_size_.split_k_slices; -+ } else { -+ if (crs_cnt_ == crs_per_group_) { -+ crs_cnt_ = 0; -+ filter_c_ = filter_c_init_; -+ if (kGroupMode != conv::GroupMode::kDepthwise) { -+ // moves to next group -+ ++group_idx_offset_c_; -+ } -+ } else { -+ filter_c_ += Shape::kRow * problem_size_.split_k_slices; -+ } -+ } -+ } -+ -+ /// Returns the coordinate in the filter tensor W that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int k = offset_k_[iteration_strided_]; -+ int c = filter_c_ + iteration_vector_ * AccessType::kElements; -+ -+ return TensorCoord(k, filter_r_, filter_s_, c); -+ } -+ -+ /// Returns true if the current coordinate is within the activations tensor W -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ if (kGroupMode == conv::GroupMode::kNone) { -+ return coord.n() < problem_size_.K && coord.c() < problem_size_.C; -+ } else if (kGroupMode == conv::GroupMode::kDepthwise) { -+ return coord.n() < problem_size_.K && coord.c() < 1; // channels_per_group_ is always equal to ONE. 
-+ } else { -+ return coord.n() < problem_size_.K && coord.c() < channels_per_group_ && -+ group_idx_offset_c_ == group_idx_offset_k_[iteration_strided_]; -+ } -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterTileAccessIteratorAnalytic &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.K % 32) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.K % 64) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h -new file mode 100644 -index 0000000..f0a3219 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h -@@ -0,0 +1,289 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dFpropFilterTileAccessIteratorFewChannels { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kFewChannels; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ static int const kPositionsPerTile = Shape::kRow; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static bool const kUseFastDivmodPrologue = true; -+ static bool const kUseFastDivmodMainloop = true; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require 
Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dFewChannelsParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ int rsc_index_; -+ -+ int offset_k_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterTileAccessIteratorFewChannels( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ rsc_index_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ rsc_index_ = (threadblock_offset.row() + thread_coord.contiguous()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * 8 / sizeof_bits::value; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next tile -+ rsc_index_ += kPositionsPerTile * problem_size_.split_k_slices; -+ } -+ -+ /// Returns the coordinate in the filter tensor W that is currently pointed to -+ /// by the iterator. 
-+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int rsc_index = rsc_index_ + iteration_vector_ * AccessType::kElements; -+ -+ int c = 0; -+ int s = 0; -+ int r = 0; -+ -+ if (kUseFastDivmodMainloop) { -+ int rs_index = params_.divmod_C.divmod(c, rsc_index); -+ r = params_.divmod_S.divmod(s, rs_index); -+ } -+ else { -+ c = (rsc_index % problem_size_.C); -+ int rs_index = (rsc_index / problem_size_.C); -+ -+ s = (rs_index % problem_size_.S); -+ r = (rs_index / problem_size_.S); -+ } -+ -+ int k = offset_k_[iteration_strided_]; -+ -+ return TensorCoord(k, r, s, c); -+ } -+ -+ /// Returns true if the current coordinate is within the activations tensor W -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ bool in_bounds = -+ coord.n() < problem_size_.K && -+ coord.h() >= 0 && -+ coord.h() < problem_size_.R && -+ coord.c() < problem_size_.C; -+ -+ return in_bounds; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ -+ int32_t offset = -+ coord.n() * params_.stride_n + -+ coord.h() * params_.stride_h + -+ coord.w() * params_.stride_w + -+ coord.c(); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterTileAccessIteratorFewChannels &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.K % 32) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.K % 64) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h -new file mode 100644 -index 0000000..6536f62 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h -@@ -0,0 +1,275 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dFpropFilterTileAccessIteratorFixedChannels { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kFixedChannels; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kFilterPositionsPerTile = Shape::kRow / AccessType::kElements; -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static bool const kUseFastDivmodPrologue = true; -+ static bool const kUseFastDivmodMainloop = true; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dFewChannelsParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ int rs_index_; -+ -+ int offset_k_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterTileAccessIteratorFixedChannels( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ rs_index_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ rs_index_ = (threadblock_offset.row() + thread_coord.contiguous()) / AccessType::kElements; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = 
index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * 8 / sizeof_bits::value; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next tile -+ rs_index_ += kFilterPositionsPerTile * problem_size_.split_k_slices; -+ } -+ -+ /// Returns the coordinate in the filter tensor W that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int rs_index = rs_index_ + iteration_vector_; -+ -+ int r = 0; -+ int s = 0; -+ -+ if (kUseFastDivmodMainloop) { -+ r = params_.divmod_S.divmod(s, rs_index); -+ } -+ else { -+ s = (rs_index % problem_size_.S); -+ r = (rs_index / problem_size_.S); -+ } -+ -+ int k = offset_k_[iteration_strided_]; -+ -+ return TensorCoord(k, r, s, 0); -+ } -+ -+ /// Returns true if the current coordinate is within the activations tensor W -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.K && coord.h() >= 0 && coord.h() < problem_size_.R; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ -+ int32_t offset = -+ coord.n() * params_.stride_n + -+ coord.h() * params_.stride_h + -+ coord.w() * params_.stride_w + coord.c(); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterTileAccessIteratorFixedChannels &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. 
-+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C != AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.K % 32) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.K % 64) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..a85c620 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h -@@ -0,0 +1,317 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dFpropFilterTileAccessIteratorOptimized{ -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params : Conv2dFpropFilterIteratorOptimizedParams { -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Conv2dFpropFilterIteratorOptimizedParams const &base): -+ Conv2dFpropFilterIteratorOptimizedParams(base) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout -+ ): -+ Conv2dFpropFilterIteratorOptimizedParams( -+ problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} -+ ) { -+ -+ } -+ }; -+ -+private: -+ -+ Conv2dFpropFilterIteratorOptimizedParams const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ uint32_t predicates_[kAccessesPerVector]; -+ int filter_rs_; -+ int filter_c_; -+ int channels_per_group_; -+ -+ // -+ // Assertions -+ // -+ -+ // We map predicates into bits packed in this uint32_t container -+ static_assert(ThreadMap::Iterations::kStrided < sizeof(predicates_) * 8, -+ "Currently, the number of loads per iteration is limited by the size of the predicates container."); -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterTileAccessIteratorOptimized( -+ Conv2dFpropFilterIteratorOptimizedParams const ¶ms, -+ 
Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ predicates_{0}, -+ filter_rs_(0), -+ filter_c_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); -+ Index column = threadblock_offset.column() + thread_coord.strided(); -+ channels_per_group_ = problem_size_.C / problem_size_.groups; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < problem_size_.K) ? 1u : 0); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ predicates_[v_idx] |= (pred << s); -+ } -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= channels_per_group_); -+ } -+ -+ pointer_ += ( -+ params_.layout({filter_c_, column}) -+ ) * sizeof_bits::value / 8; -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ LongIndex next = params_.inc_next_rs; -+ -+ // moves to the next tile -+ ++filter_rs_; -+ if (filter_rs_ == params_.RS) { -+ -+ filter_rs_ = 0; -+ next = params_.inc_next_c; -+ filter_c_ += params_.filter_c_delta; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= channels_per_group_); -+ } -+ -+ pointer_ += next; -+ } -+ -+ /// Clears the predicates -+ CUTLASS_HOST_DEVICE -+ void clear_mask(int v, bool clear = true) { -+ predicates_[v] = clear ? 
0u : predicates_[v]; -+ } -+ -+ /// Returns true if the current coordinate is within the filter tensor W -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return (predicates_[iteration_vector_] & (1u << iteration_strided_)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ return reinterpret_cast(pointer_) + iteration_vector_; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterTileAccessIteratorOptimized &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ -+ // Move to the next K coordinate within the tile -+ pointer_ += params_.inc_next_k; -+ -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.K % 32) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.K % 64) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_params.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_params.h -new file mode 100644 -index 0000000..d96dee8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_params.h -@@ -0,0 +1,893 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Extracts the host-params objects into non-template code. -+*/ -+ -+#pragma once -+ -+#define TRACE_CONV_PARAMS_INITIALIZERS_ENABLED 0 -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED -+#include -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Params structure used for all Conv2d analytic tile iterators -+template< typename Layout_ = layout::TensorNHWC > -+struct Conv2dAnalyticParams { -+ -+ using Layout = Layout_; -+ -+ Layout layout; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dAnalyticParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dAnalyticParams( -+ Conv2dProblemSize const &, // unused; placeholder to match other Params interfaces. -+ Layout const &layout -+ ): layout(layout) { -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Params structure used for all Conv2d analytic tile iterators -+template< typename Layout_ = layout::TensorNHWC > -+struct Conv2dFewChannelsParams { -+ -+ using Layout = Layout_; -+ -+ -+ int32_t stride_w; -+ int32_t stride_h; -+ int32_t stride_n; -+ -+ FastDivmod divmod_P; -+ FastDivmod divmod_Q; -+ FastDivmod divmod_S; -+ FastDivmod divmod_C; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFewChannelsParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFewChannelsParams( -+ Conv2dProblemSize const &problem_size, // unused; placeholder to match other Params interfaces. 
-+ Layout const &layout -+ ): -+ stride_w(int32_t(layout.stride()[0])), -+ stride_h(int32_t(layout.stride()[1])), -+ stride_n(int32_t(layout.stride()[2])), -+ divmod_P(problem_size.P), -+ divmod_Q(problem_size.Q), -+ divmod_S(problem_size.S), -+ divmod_C(problem_size.C) -+ { -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters structure used for Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams -+struct Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams { -+ -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ int tiled_rows_per_filter; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, ///< layout object -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape -+ ): layout(layout) { -+ -+ int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, threadblock_shape.row()); -+ -+ tiled_rows_per_filter = tile_m_per_filter * threadblock_shape.row(); -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED -+ -+CUTLASS_HOST_DEVICE -+void TraceIteratorParams( -+ char const *conv_operator, -+ char const *operand, -+ int element_size_bits, -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+) { -+ -+#if !defined(__CUDA_ARCH__) -+ -+ char const *fname = "conv_iterator_params.csv"; -+ -+ std::ifstream test(fname); -+ bool file_exists = test.is_open(); -+ -+ if (file_exists) { -+ test.close(); -+ } -+ -+ std::ofstream trace("conv_iterator_params.csv", std::ofstream::app); -+ -+ if (!file_exists) { -+ trace -+ << "Operator,Operand,ElementSize,CtaRows,CtaColumns,ThreadCount,AccessSize," -+ << "IterationsContiguous,IterationsStrided,DeltaContiguous,DeltaStrided\n"; -+ } -+ -+ trace << conv_operator << "," << operand << "," << element_size_bits << "," -+ << threadblock_shape.row() << "," << threadblock_shape.column() -+ << "," << thread_count << "," << access_size -+ << "," << threadmap_iterations.contiguous() << "," << threadmap_iterations.strided() -+ << "," << threadmap_delta.contiguous() << "," << threadmap_delta.strided() << "\n"; -+#endif -+} -+ -+#define TRACE_CONV_INITIALIZERS(conv_op, operand, element_size, cta_shape, thread_count, access_size, iterations, delta) \ -+ TraceIteratorParams(conv_op, operand, element_size, cta_shape, thread_count, access_size, iterations, delta); -+ -+#else -+ -+#define TRACE_CONV_INITIALIZERS(conv_op, operand, element_size, cta_shape, thread_count, access_size, iterations, delta) {} -+ -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters structure used for Conv2dFpropActivationTileIteratorOptimized -+template< typename Layout_ = layout::TensorNHWC > -+struct Conv2dFpropActivationIteratorOptimizedParams; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters structure used for Conv2dFpropActivationTileIteratorOptimized -+template<> -+struct Conv2dFpropActivationIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNHWC; 
-+ -+ Layout layout; -+ -+ int64_t inc_next[3]; // {next S, next R, next C} -+ int filter_c_delta; // number of logical elements to add to filter_c_ -+ int PQ; // product of P*Q -+ -+ FastDivmod pq_divmod; -+ FastDivmod q_divmod; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationIteratorOptimizedParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, ///< layout object -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout), -+ PQ(problem_size.P * problem_size.Q), -+ pq_divmod(PQ), -+ q_divmod(problem_size.Q) { -+ -+ TRACE_CONV_INITIALIZERS("conv2d_fprop", "activation", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ int conv_sign = (problem_size.mode == Mode::kConvolution ? -1 : 1); -+ -+ // next S -+ inc_next[0] = conv_sign * ( -+ int64_t(layout.stride()[0]) * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next R -+ inc_next[1] = conv_sign * ( -+ int64_t(layout.stride()[1]) * problem_size.dilation_h -+ - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next C -+ inc_next[2] = ( -+ threadblock_shape.column() * problem_size.split_k_slices -+ - conv_sign * int64_t(problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h -+ - conv_sign * int64_t(problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // logical offset added to internal channel counter - units are elements, not bytes -+ filter_c_delta = threadblock_shape.column() * problem_size.split_k_slices; -+ } -+ -+#if ENABLE_CONV2D_PARAMS_PRINT -+ /// Prints internal state. 
-+ CUTLASS_HOST_DEVICE -+ void print() { -+ auto stride = layout.stride(); -+ printf( -+ "Conv2dFpropActivationIteratorOptimizedParams:\n" -+ " layout(w: %d, h: %d, n: %d)\n" -+ " inc_next[%ld, %ld, %ld]\n" -+ " filter_c_delta(%d) - PQ(%d)\n" -+ " pq_divmod(divisor: %d, multiplier: %u, shift_right: %u)\n" -+ " q_divmod(divisor: %d, multiplier: %u, shift_right: %u)\n", -+ stride[0], stride[1], stride[2], -+ inc_next[0], inc_next[1], inc_next[2], -+ filter_c_delta, -+ PQ, -+ pq_divmod.divisor, -+ pq_divmod.multiplier, -+ pq_divmod.shift_right, -+ q_divmod.divisor, -+ q_divmod.multiplier, -+ q_divmod.shift_right -+ ); -+ } -+#endif -+}; -+ -+/// Parameters structure used for Conv2dFpropActivationTileIteratorOptimized -+template -+struct Conv2dFpropActivationIteratorOptimizedParams> { -+ static int const kInterleaved = Interleaved_; -+ -+ using Layout = layout::TensorNCxHWx; -+ -+ Layout layout; -+ -+ int64_t inc_next[3]; // {next S, next R, next C} -+ int filter_c_delta; // number of logical elements to add to filter_c_ -+ int PQ; // product of P*Q -+ -+ FastDivmod pq_divmod; -+ FastDivmod q_divmod; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationIteratorOptimizedParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, ///< layout object -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout), PQ(problem_size.P * problem_size.Q), pq_divmod(PQ), q_divmod(problem_size.Q) { -+ -+ TRACE_CONV_INITIALIZERS("conv2d_fprop", "activation", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ int conv_sign = (problem_size.mode == Mode::kConvolution ? 
-1 : 1); -+ -+ // next S -+ inc_next[0] = conv_sign * (kInterleaved * problem_size.dilation_w) * element_size_bits / 8; -+ -+ // next R -+ inc_next[1] = conv_sign * ( -+ int64_t(layout.stride()[0]) * problem_size.dilation_h -+ - (problem_size.S - 1) * kInterleaved * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next C -+ inc_next[2] = ( -+ threadblock_shape.column() * problem_size.split_k_slices / kInterleaved * int64_t(layout.stride()[1]) -+ - conv_sign * int64_t(problem_size.R - 1) * layout.stride()[0] * problem_size.dilation_h -+ - conv_sign * int64_t(problem_size.S - 1) * kInterleaved * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // logical offset added to internal channel counter - units are elements, not bytes -+ filter_c_delta = threadblock_shape.column() * problem_size.split_k_slices; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template< typename Layout_ = layout::TensorNHWC > -+struct Conv2dFpropFilterIteratorOptimizedParams; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template<> -+struct Conv2dFpropFilterIteratorOptimizedParams -+{ -+ -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ int RS; -+ int filter_c_delta; -+ -+ int64_t inc_next_k; // offset in units of bytes to next K position -+ int64_t inc_next_rs; // offset in units of bytes to next RS position -+ int64_t inc_next_c; // offset in units of bytes to next C position -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterIteratorOptimizedParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout) { -+ -+ TRACE_CONV_INITIALIZERS("conv2d_fprop", "filter", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ RS = problem_size.R * problem_size.S; -+ -+ inc_next_k = (int64_t(layout.stride()[2]) * threadmap_delta.strided() * element_size_bits) / 8; -+ -+ inc_next_rs = -+ ( int64_t(layout.stride()[0]) -+ - int64_t(layout.stride()[2]) * (threadmap_iterations.strided() - 1) * threadmap_delta.strided() -+ ) * element_size_bits / 8; -+ -+ inc_next_c = -+ ( -+ threadblock_shape.row() * problem_size.split_k_slices -+ - int64_t(RS - 1) * layout.stride()[0] -+ - int64_t(threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] -+ ) * element_size_bits / 8; -+ -+ filter_c_delta = threadblock_shape.row() * problem_size.split_k_slices; -+ } -+ -+#if ENABLE_CONV2D_PARAMS_PRINT -+ /// Prints internal state. 
-+ CUTLASS_HOST_DEVICE -+ void print() { -+ auto stride = layout.stride(); -+ printf( -+ "Conv2dFpropFilterIteratorOptimizedParams:\n" -+ " layout[%d, %d, %d]\n" -+ " RS(%d), filter_c_delta(%d), inc_next(k: %ld, rs: %ld, c: %ld)\n", -+ stride[0], stride[1], stride[2], -+ RS, -+ filter_c_delta, -+ inc_next_k, inc_next_rs, inc_next_c -+ ); -+ } -+#endif -+}; -+ -+template -+struct Conv2dFpropFilterIteratorOptimizedParams> -+{ -+ static int const kInterleaved = Interleaved_; -+ using Layout = layout::TensorCxRSKx; -+ -+ Layout layout; -+ int RS; -+ int filter_c_delta; -+ -+ int64_t inc_next_k; // offset in units of bytes to next K position -+ int64_t inc_next_rs; // offset in units of bytes to next RS position -+ int64_t inc_next_c; // offset in units of bytes to next C position -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterIteratorOptimizedParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout) { -+ -+ TRACE_CONV_INITIALIZERS("conv2d_fprop", "filter", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ RS = problem_size.R * problem_size.S; -+ -+ inc_next_k = (kInterleaved * threadmap_delta.strided() * element_size_bits) / 8; -+ -+ inc_next_rs = -+ ( int64_t(layout.stride()[0]) -+ - kInterleaved * (threadmap_iterations.strided() - 1) * threadmap_delta.strided() -+ ) * element_size_bits / 8; -+ -+ inc_next_c = -+ ( -+ threadblock_shape.row() * problem_size.split_k_slices / kInterleaved * int64_t(layout.stride()[2]) -+ - int64_t(RS - 1) * layout.stride()[0] -+ - int64_t(threadmap_iterations.strided() - 1) * threadmap_delta.strided() * kInterleaved -+ ) * element_size_bits / 8; -+ -+ filter_c_delta = threadblock_shape.row() * problem_size.split_k_slices; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Dgrad Optimized Dy params (layout::TensorNHWC) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Parameters object for Conv2d DGRAD OutputGradient (dy) iterator -+struct Conv2dDgradOutputGradientIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ -+ int64_t inc_next[3]; // {next S, next R, next K} -+ -+ int filter_k_delta; // number of logical elements to add to filter_k_ -+ -+ int HW; // product of H*W -+ -+ FastDivmod hw_divmod; -+ FastDivmod w_divmod; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientIteratorOptimizedParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout), -+ HW(problem_size.H *problem_size.W), -+ hw_divmod(HW), -+ w_divmod(problem_size.W) { -+ -+ TRACE_CONV_INITIALIZERS("conv2d_dgrad", "output_gradient", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, 
threadmap_delta); -+ -+ int conv_sign = (problem_size.mode == Mode::kConvolution ? 1 : -1); -+ -+ // next S -+ inc_next[0] = conv_sign * ( -+ (int64_t)layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next R -+ inc_next[1] = conv_sign * ( -+ (int64_t)layout.stride()[1] * problem_size.dilation_h -+ - (problem_size.S - 1) * (int64_t)layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next K -+ inc_next[2] = ( -+ threadblock_shape.column() * problem_size.split_k_slices -+ - conv_sign * (problem_size.R - 1) * (int64_t)layout.stride()[1] * problem_size.dilation_h -+ - conv_sign * (problem_size.S - 1) * (int64_t)layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // logical offset added to internal channel counter - units are elements, not bytes -+ filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Strided Dgrad Optimized Dy params (layout::TensorNHWC) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+struct Conv2dStridedDgradOutputGradientIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ -+ int64_t inc_next[3]; // {next S, next R, next K} -+ -+ int filter_k_delta; // number of logical elements to add to filter_k_ -+ -+ int tiled_rows_per_filter; -+ -+ int conv_sign; -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dStridedDgradOutputGradientIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dStridedDgradOutputGradientIteratorOptimizedParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, ///< layout object -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape -+ ): layout(layout) { -+ -+ int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, threadblock_shape.row()); -+ -+ tiled_rows_per_filter = tile_m_per_filter * threadblock_shape.row(); -+ -+ conv_sign = (problem_size.mode == Mode::kConvolution ? 
1 : -1); -+ -+ // next S -+ inc_next[0] = conv_sign * ( -+ (int64_t)layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next R -+ inc_next[1] = conv_sign * ( -+ (int64_t)layout.stride()[1] * problem_size.dilation_h -+ ) * element_size_bits / 8; -+ -+ // next K -+ inc_next[2] = ( -+ threadblock_shape.column() * problem_size.split_k_slices -+ ) * element_size_bits / 8; -+ -+ // logical offset added to internal channel counter - units are elements, not bytes -+ filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices; -+ } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+// Dgrad Optimized w params (layout::TensorNHWC) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+struct Conv2dDgradFilterIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ int RS; -+ int filter_k_delta; -+ -+ int64_t inc_next_strided; // offset in units of bytes to next K coordinate within tile -+ int64_t inc_next_rs; // offset in units of bytes to next RS position -+ int64_t inc_next_k; // offset in units of bytes to next K position in subsequent tile -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradFilterIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradFilterIteratorOptimizedParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout), RS(problem_size.R * problem_size.S) { -+ -+ TRACE_CONV_INITIALIZERS("conv2d_dgrad", "filter", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ inc_next_strided = ((int64_t)layout.stride()[2] * threadmap_delta.strided() * element_size_bits) / 8; -+ -+ inc_next_rs = -+ ( (int64_t)layout.stride()[0] -+ - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2] -+ ) * element_size_bits / 8; -+ -+ inc_next_k = -+ ( -+ threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[2] -+ - (problem_size.R * problem_size.S - 1) * (int64_t)layout.stride()[0] -+ - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2] -+ ) * element_size_bits / 8; -+ -+ filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+// StridedDgrad Optimized w params (layout::TensorNHWC) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+struct Conv2dStridedDgradFilterIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ int RS; -+ int filter_k_delta; -+ -+ int64_t inc_next_strided; // offset in units of bytes to next K coordinate within tile -+ int64_t inc_next[3]; // {next S, next R, next K} -+ int64_t reset_bytes; // offset in units of bytes to move back the pointer -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ 
Conv2dStridedDgradFilterIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dStridedDgradFilterIteratorOptimizedParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout), RS(problem_size.R * problem_size.S) { -+ -+ TRACE_CONV_INITIALIZERS("conv2d_dgrad", "filter", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ inc_next_strided = (layout.stride()[2] * threadmap_delta.strided() * element_size_bits) / 8; -+ -+ // next S -+ inc_next[0] = -+ ( (int64_t)layout.stride()[0] * problem_size.stride_w -+ //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] -+ ) * element_size_bits / 8; -+ -+ // next R -+ inc_next[1] = -+ ( (int64_t)layout.stride()[1] * problem_size.stride_h -+ //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] -+ ) * element_size_bits / 8; -+ -+ // next K -+ inc_next[2] = -+ ( -+ threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[2] -+ //- (problem_size.R * problem_size.S - 1) * layout.stride()[0] -+ //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] -+ ) * element_size_bits / 8; -+ -+ // offset in units of bytes to move the pointer in backward direction -+ reset_bytes = (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2] -+ * element_size_bits / 8; -+ -+ filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices; -+ } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters object for Conv2d WGRAD Output Gradient (dy) iterator -+struct Conv2dWgradOutputGradientIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ -+ int NPQ; // precomputd product of N*P*Q for clearing predicates -+ -+ FastDivmod pq_divmod; -+ FastDivmod q_divmod; -+ -+ int64_t offset_next_strided; // offset in units of bytes to next npq coordinate within tile -+ int64_t offset_next_contiguous; // offset in units of bytes to next k coordinate within tile -+ int64_t inc_next_npq; // offset in units of bytes to next npq position in subsequent tile -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradOutputGradientIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradOutputGradientIteratorOptimizedParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout), -+ NPQ(problem_size.N * problem_size.P * problem_size.Q), -+ pq_divmod(problem_size.P * problem_size.Q), -+ q_divmod(problem_size.Q) { -+ -+ TRACE_CONV_INITIALIZERS("conv2d_wgrad", "output_gradient", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ // Incremental offsets in unites of bytes (number of elements) * sizeof_bits::value / 8 -+ offset_next_strided = (threadmap_delta.strided() * (int64_t)layout.stride()[0]) -+ * element_size_bits / 8; -+ -+ 
offset_next_contiguous = (threadmap_delta.contiguous()) -+ * element_size_bits / 8; -+ -+ inc_next_npq = (threadblock_shape.column() * problem_size.split_k_slices * (int64_t)layout.stride()[0]) -+ * element_size_bits / 8; -+ } -+}; -+ -+struct Conv2dWgradActivationIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ -+ FastDivmod sc_divmod; -+ FastDivmod pq_divmod; -+ FastDivmod q_divmod; -+ FastDivmod c_divmod; -+ FastDivmod s_divmod; -+ int small_channel_conv_s_offset; -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradActivationIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradActivationIteratorOptimizedParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout -+ ): -+ layout(layout), -+ sc_divmod(problem_size.S * problem_size.C), -+ pq_divmod(problem_size.P * problem_size.Q), -+ q_divmod(problem_size.Q), -+ c_divmod(problem_size.C), -+ s_divmod(problem_size.S * problem_size.dilation_w), -+ small_channel_conv_s_offset((problem_size.S - 1) * problem_size.dilation_w - problem_size.pad_w) { -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradActivationIteratorOptimizedParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ Conv2dWgradActivationIteratorOptimizedParams( -+ problem_size, -+ layout -+ ) { -+ -+ TRACE_CONV_INITIALIZERS("conv2d_wgrad", "activation", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ } -+}; -+ -+struct PredicatedScaleBiasVectorAccessIteratorParams { -+ public: -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIteratorParams() { } -+ -+ // Default ctor -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIteratorParams( -+ Conv2dProblemSize const &problem_size, -+ layout::PitchLinear const &layout) {} -+ -+ // Default ctor -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIteratorParams( -+ Conv2dProblemSize const &problem_size, -+ layout::RowMajor const &layout) {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_tile_iterator.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_tile_iterator.h -new file mode 100644 -index 0000000..9c1742d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_tile_iterator.h -@@ -0,0 +1,337 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template wraps the tile access iterator concept to load whole tiles from tensors in -+ memory used for implicit GEMM convolution. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class TileIterator { -+public: -+ using TileAccessIterator = TileAccessIterator_; -+ -+ using Shape = typename TileAccessIterator::Shape; -+ using Element = typename TileAccessIterator::Element; -+ using Layout = typename TileAccessIterator::Layout; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = typename TileAccessIterator::ThreadMap; -+ using AccessType = typename TileAccessIterator::AccessType; -+ using TensorRef = typename TileAccessIterator::TensorRef; -+ using Index = typename TileAccessIterator::Index; -+ using LongIndex = typename TileAccessIterator::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = TileAccessIterator::kIteratorAlgorithm; -+ static StrideSupport const kStrideSupport = TileAccessIterator::kStrideSupport; -+ using Params = typename TileAccessIterator::Params; -+ static int const kConvDim = TileAccessIterator::kConvDim; -+ using ConvProblemSize = typename TileAccessIterator::ConvProblemSize; -+ static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array< -+ Element, -+ ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; -+ -+private: -+ -+ /// Internal state -+ TileAccessIterator tile_access_iterator_; -+ -+public: -+ -+ /// Constructor -+ 
CUTLASS_HOST_DEVICE -+ TileIterator( -+ Params const ¶ms, -+ ConvProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ tile_access_iterator_(params, problem_size, ptr, thread_idx, threadblock_offset) { } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(ConvProblemSize const &problem_size, Layout const &layout) { -+ return TileAccessIterator::getParams(problem_size, layout); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ tile_access_iterator_.set_iteration_index(index); -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ tile_access_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ TileIterator &operator++() { -+ tile_access_iterator_.advance(); -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ TileIterator operator++(int) { -+ TileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ frag.clear(); -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ frag_ptr[idx], -+ tile_access_iterator_.get() + pointer_offset, -+ tile_access_iterator_.valid() -+ ); -+ -+ ++tile_access_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ tile_access_iterator_.set_iteration_index(0); -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ CUTLASS_DEVICE -+ void advance() { -+ tile_access_iterator_.advance(); -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. 
-+ CUTLASS_HOST_DEVICE -+ static Status can_implement(ConvProblemSize const &problem_size) { -+ -+ // dispatch to iterator implementation -+ return TileAccessIterator::can_implement(problem_size); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Strided Dgrad Tile Iterator -+template -+class TileIteratorStridedDgrad { -+public: -+ using TileAccessIterator = TileAccessIterator_; -+ -+ using Shape = typename TileAccessIterator::Shape; -+ using Element = typename TileAccessIterator::Element; -+ using Layout = typename TileAccessIterator::Layout; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = typename TileAccessIterator::ThreadMap; -+ using AccessType = typename TileAccessIterator::AccessType; -+ using TensorRef = typename TileAccessIterator::TensorRef; -+ using Index = typename TileAccessIterator::Index; -+ using LongIndex = typename TileAccessIterator::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = TileAccessIterator::kIteratorAlgorithm; -+ static StrideSupport const kStrideSupport = TileAccessIterator::kStrideSupport; -+ using Params = typename TileAccessIterator::Params; -+ static int const kConvDim = TileAccessIterator::kConvDim; -+ using ConvProblemSize = typename TileAccessIterator::ConvProblemSize; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array< -+ Element, -+ ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; -+ -+private: -+ -+ /// Internal state -+ TileAccessIterator tile_access_iterator_; -+ -+public: -+ -+ /// Constructor (output gradient (Dy) OperandA ctor) -+ CUTLASS_HOST_DEVICE -+ TileIteratorStridedDgrad( -+ Params const ¶ms, -+ ConvProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, -+ int start_r, int start_s, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ tile_access_iterator_( -+ params, -+ problem_size, -+ ptr, -+ thread_idx, -+ stride_h_divmod, stride_w_divmod, -+ start_r, start_s, -+ threadblock_offset) { } -+ -+ /// Constructor (filter (w) OperandB ctor) -+ CUTLASS_HOST_DEVICE -+ TileIteratorStridedDgrad( -+ Params const ¶ms, -+ ConvProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ int start_r, int start_s, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ tile_access_iterator_(params, -+ problem_size, -+ ptr, -+ thread_idx, -+ start_r, start_s, -+ threadblock_offset) { } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(ConvProblemSize const &problem_size, Layout const &layout) { -+ return TileAccessIterator::getParams(problem_size, layout); -+ } -+ -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ tile_access_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ TileIteratorStridedDgrad &operator++() { -+ tile_access_iterator_.advance(); -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. 
-+ CUTLASS_HOST_DEVICE -+ TileIteratorStridedDgrad operator++(int) { -+ TileIteratorStridedDgrad self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ frag.clear(); -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ frag_ptr[c + s * ThreadMap::Iterations::kContiguous], -+ tile_access_iterator_.get() + pointer_offset, -+ tile_access_iterator_.valid() -+ ); -+ -+ ++tile_access_iterator_; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ tile_access_iterator_.set_iteration_index(0); -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ CUTLASS_DEVICE -+ void advance() { -+ tile_access_iterator_.advance(); -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(ConvProblemSize const &problem_size) { -+ -+ // dispatch to iterator implementation -+ return TileAccessIterator::can_implement(problem_size); -+ } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..6e73115 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h -@@ -0,0 +1,285 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dWgradActivationTileAccessIteratorAnalytic { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ static_assert(sizeof_bits::value >= 8, -+ "WGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dAnalyticParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ // Filter postion (r,s,c) in contiguous dimension stays constant for each gemm_iteration_k -+ int filter_r_[ThreadMap::Iterations::kContiguous]; -+ int filter_s_[ThreadMap::Iterations::kContiguous]; -+ int filter_c_[ThreadMap::Iterations::kContiguous]; -+ -+ int 
offset_npq_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradActivationTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)) -+ { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ // initialize r,s,c filter position for every contiguous iteration -+ CUTLASS_PRAGMA_UNROLL -+ for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int rsc_offset = threadblock_offset.column() + thread_coord.contiguous() -+ + c * ThreadMap::Delta::kContiguous; -+ -+ filter_r_[c] = rsc_offset / (problem_size_.S * problem_size_.C); -+ int residual = rsc_offset % (problem_size_.S * problem_size_.C); -+ -+ filter_s_[c] = residual / problem_size_.C; -+ filter_c_[c] = residual % problem_size_.C; -+ } -+ -+ // initialize n, p, q offset for every strided iteration -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ offset_npq_[s] = threadblock_offset.row() + thread_coord.strided() -+ + s * ThreadMap::Delta::kStrided; -+ } -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ // moves to the next GEMM-K offset (offset_npq_) in GEMM-B by a CTA-K tile -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_npq_[s] += Shape::kRow * problem_size_.split_k_slices; -+ } -+ } -+ -+ /// Returns the coordinate in the activation tensor x that is currently pointed to -+ /// by the iterator. 
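The constructor above turns each thread's linear GEMM-N offset into a fixed (r, s, c) filter position with two divide/modulo pairs; the n/p/q offsets on the strided axis are decomposed the same way later in at(). A small host-side restatement of that decomposition and its inverse, using illustrative sizes:

#include <cassert>

struct Rsc { int r, s, c; };

// Decompose a linear offset into a (r, s, c) filter position, exactly as the
// constructor above does with / and %.
Rsc decompose_rsc(int rsc_offset, int S, int C) {
  Rsc out;
  out.r = rsc_offset / (S * C);
  int residual = rsc_offset % (S * C);
  out.s = residual / C;
  out.c = residual % C;
  return out;
}

int main() {
  // Illustrative sizes only.
  const int R = 3, S = 3, C = 64;
  for (int offset = 0; offset < R * S * C; ++offset) {
    Rsc p = decompose_rsc(offset, S, C);
    // Recomposition shows the decomposition inverts (r*S + s)*C + c.
    assert((p.r * S + p.s) * C + p.c == offset);
  }
  return 0;
}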
-+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ int r, s, c; -+ -+ if (kAccessesPerVector == 1) { -+ /// One 128b aligned access fetching more than one element -+ c = filter_c_[iteration_contiguous_]; -+ r = filter_r_[iteration_contiguous_]; -+ s = filter_s_[iteration_contiguous_]; -+ } -+ else { -+ /// Multiple access to support non-128b alignment in contiguous dimenstion -+ c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) % problem_size_.C; -+ int wrap_c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) / problem_size_.C; -+ s = (filter_s_[iteration_contiguous_] + wrap_c) % problem_size_.S; -+ int wrap_s = (filter_s_[iteration_contiguous_] + wrap_c) / problem_size_.S; -+ r = filter_r_[iteration_contiguous_] + wrap_s; -+ } -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ r = (problem_size_.R - 1 - r); -+ s = (problem_size_.S - 1 - s); -+ } -+ -+ int n = offset_npq_[iteration_strided_] / (problem_size_.P * problem_size_.Q); -+ int residual = offset_npq_[iteration_strided_] % (problem_size_.P * problem_size_.Q); -+ -+ int p = residual / problem_size_.Q; -+ int q = residual % problem_size_.Q; -+ -+ int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; -+ int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; -+ -+ return TensorCoord(n, h, w, c); -+ } -+ -+ /// Returns true if the current coordinate is within the activation tensor x -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.N && -+ coord.h() >= 0 && coord.h() < problem_size_.H && -+ coord.w() >= 0 && coord.w() < problem_size_.W; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradActivationTileAccessIteratorAnalytic &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. 
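at() above composes the strided npq offset with the contiguous (r, s, c) position into an activation coordinate, and valid() turns halo reads into predicated-off accesses. A host-side sketch of that mapping for an illustrative padded 3x3 problem (cross-correlation mode; kConvolution would first flip r and s):

#include <cassert>

struct Coord { int n, h, w, c; };

Coord at(int offset_npq, int r, int s, int c,
         int P, int Q,
         int stride_h, int stride_w,
         int pad_h, int pad_w,
         int dilation_h, int dilation_w) {
  int n = offset_npq / (P * Q);
  int residual = offset_npq % (P * Q);
  int p = residual / Q;
  int q = residual % Q;
  // Same mapping as the at() above.
  int h = p * stride_h - pad_h + r * dilation_h;
  int w = q * stride_w - pad_w + s * dilation_w;
  return {n, h, w, c};
}

bool valid(Coord coord, int N, int H, int W) {
  return coord.n < N &&
         coord.h >= 0 && coord.h < H &&
         coord.w >= 0 && coord.w < W;
}

int main() {
  // Illustrative 3x3 / pad-1 / stride-1 problem: N=1, H=W=8, P=Q=8.
  const int N = 1, H = 8, W = 8, P = 8, Q = 8;
  // p=0, q=0 with r=0, s=0 reads the (-1,-1) halo element: predicated off.
  Coord halo = at(/*offset_npq=*/0, /*r=*/0, /*s=*/0, /*c=*/0, P, Q, 1, 1, 1, 1, 1, 1);
  assert(halo.h == -1 && halo.w == -1 && !valid(halo, N, H, W));
  // r=1, s=1 lands on the valid (0,0) activation element.
  Coord center = at(0, 1, 1, 0, P, Q, 1, 1, 1, 1, 1, 1);
  assert(center.h == 0 && center.w == 0 && valid(center, N, H, W));
  return 0;
}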
-+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..8871735 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h -@@ -0,0 +1,321 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dWgradActivationTileAccessIteratorOptimized { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ static_assert(sizeof_bits::value >= 8, -+ "WGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dWgradActivationIteratorOptimizedParams; -+ -+private: -+ -+ Conv2dWgradActivationIteratorOptimizedParams const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ // Precomputed effective filter postion (r,s) in contiguous dimension stays constant for each gemm_iteration_k -+ // required for npq -> nhw translation -+ int precomputed_filter_r_[ThreadMap::Iterations::kContiguous]; -+ int precomputed_filter_s_[ThreadMap::Iterations::kContiguous]; -+ -+ // Channel dimension in contiguous dimension stays constant for each gemm_iteration_k -+ int filter_c_[ThreadMap::Iterations::kContiguous]; -+ -+ int offset_npq_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradActivationTileAccessIteratorOptimized( -+ Conv2dWgradActivationIteratorOptimizedParams const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)) -+ { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ // initialize r,s,c filter position for every contiguous iteration -+ CUTLASS_PRAGMA_UNROLL -+ for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int rsc_offset = threadblock_offset.column() + thread_coord.contiguous() -+ + c * ThreadMap::Delta::kContiguous; -+ -+ // The subseqnet fast_divmod() operations 
are equivalent to the following logical computation: -+ // -+ // -+ // filter_r_[c] = rsc_offset / (problem_size_.S * problem_size_.C); -+ // int residual = rsc_offset % (problem_size_.S * problem_size_.C); -+ // -+ // filter_s_[c] = residual / problem_size_.C; -+ // filter_c_[c] = residual % problem_size_.C; -+ -+ int residual; -+ params_.sc_divmod(precomputed_filter_r_[c], residual, rsc_offset); -+ params_.c_divmod(precomputed_filter_s_[c], filter_c_[c], residual); -+ -+ int r = precomputed_filter_r_[c]; -+ int s = precomputed_filter_s_[c]; -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ r = (problem_size_.R - 1 - r); -+ s = (problem_size_.S - 1 - s); -+ } -+ -+ precomputed_filter_r_[c] = -problem_size_.pad_h + r * problem_size_.dilation_h; -+ precomputed_filter_s_[c] = -problem_size_.pad_w + s * problem_size_.dilation_w; -+ } -+ -+ // initialize n, p, q offset for every strided iteration -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ offset_npq_[s] = threadblock_offset.row() + thread_coord.strided() -+ + s * ThreadMap::Delta::kStrided; -+ } -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ // moves to the next GEMM-K offset (offset_npq_) in GEMM-B by a CTA-K tile -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_npq_[s] += Shape::kRow * problem_size_.split_k_slices; -+ } -+ } -+ -+ /// Returns the coordinate in the activation tensor x that is currently pointed to -+ /// by the iterator. 
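The optimized iterator folds padding and dilation into precomputed_filter_r_/precomputed_filter_s_ once in the constructor, so at() only needs one multiply-add per coordinate. A trivial host-side check of that equivalence, with illustrative parameters:

#include <cassert>

int main() {
  // Illustrative parameters only.
  const int R = 3, P = 8, stride_h = 2, pad_h = 1, dilation_h = 2;

  for (int r = 0; r < R; ++r) {
    // Folded once per filter position in the constructor above.
    int precomputed_r = -pad_h + r * dilation_h;
    for (int p = 0; p < P; ++p) {
      int h_analytic  = p * stride_h - pad_h + r * dilation_h;  // analytic at()
      int h_optimized = p * stride_h + precomputed_r;           // optimized at()
      assert(h_analytic == h_optimized);
    }
  }
  return 0;
}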
-+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ int r = precomputed_filter_r_[iteration_contiguous_]; -+ int s = precomputed_filter_s_[iteration_contiguous_]; -+ int c = filter_c_[iteration_contiguous_]; -+ -+ if (kAccessesPerVector > 1) { -+ // This code section is only to support non-128b alignment -+ // Multiple access to support non-128b alignment in contiguous dimenstion -+ int wrap_c; -+ params_.c_divmod(wrap_c, c, c + iteration_vector_ * AccessType::kElements); -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ s -= (problem_size_.dilation_w * wrap_c); -+ -+ int wrap_s; -+ params_.s_divmod(wrap_s, s, params_.small_channel_conv_s_offset - s); -+ s = params_.small_channel_conv_s_offset - s; -+ -+ r -= (problem_size_.dilation_h * wrap_s); -+ -+ } else { -+ s += (problem_size_.dilation_w * wrap_c); -+ -+ int wrap_s; -+ params_.s_divmod(wrap_s, s, s + problem_size_.pad_w); -+ s -= problem_size_.pad_w; -+ -+ r += (problem_size_.dilation_h * wrap_s); -+ } -+ } -+ -+ // The subseqnet fast_divmod() operations are equivalent to the following logical computation: -+ // -+ // -+ // int n = offset_npq_[iteration_strided_] / (problem_size_.P * problem_size_.Q); -+ // int residual = offset_npq_[iteration_strided_] % (problem_size_.P * problem_size_.Q); -+ // -+ // int p = residual / problem_size_.Q; -+ // int q = residual % problem_size_.Q; -+ -+ int residual, n, p, q; -+ -+ params_.pq_divmod(n, residual, offset_npq_[iteration_strided_]); -+ params_.q_divmod(p, q, residual); -+ -+ int h = p * problem_size_.stride_h + r; -+ int w = q * problem_size_.stride_w + s; -+ -+ return TensorCoord(n, h, w, c); -+ } -+ -+ /// Returns true if the current coordinate is within the activation tensor x -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.N && -+ coord.h() >= 0 && coord.h() < problem_size_.H && -+ coord.w() >= 0 && coord.w() < problem_size_.W; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradActivationTileAccessIteratorOptimized &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. 
-+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..97fd31e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h -@@ -0,0 +1,260 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dWgradOutputGradientTileAccessIteratorAnalytic { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ static_assert(sizeof_bits::value >= 8, -+ "WGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dAnalyticParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ int filter_k_[ThreadMap::Iterations::kContiguous]; -+ -+ int offset_npq_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradOutputGradientTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ // initialize filter_k for every contiguous iteration -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ filter_k_[c] = threadblock_offset.row() + thread_coord.contiguous() -+ + c * ThreadMap::Delta::kContiguous; -+ } -+ -+ // initialize n, p, q offset for every strided iteration -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_npq_[s] = threadblock_offset.column() + thread_coord.strided() -+ + s * ThreadMap::Delta::kStrided; -+ -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, layout); -+ } -+ -+ /// Overrides the internal iteration index -+ 
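The constructor above reads the GEMM row offset into filter_k_ and the column offset into offset_npq_, which together with the activation iterator earlier in this patch matches the usual wgrad implicit-GEMM shape: M = K, N = R*S*C, reduction over N*P*Q. A small sketch of that shape for an illustrative problem; the mapping is inferred from these iterators rather than stated explicitly in the hunk:

#include <cassert>

struct GemmShape { long long m, n, k; };

GemmShape wgrad_gemm_shape(int N, int P, int Q, int K, int R, int S, int C) {
  GemmShape g;
  g.m = K;                    // rows of the weight gradient
  g.n = 1LL * R * S * C;      // columns: one per filter tap element
  g.k = 1LL * N * P * Q;      // reduction over every output position
  return g;
}

int main() {
  // Illustrative problem: 64 output channels, 3x3 filter, 32 input channels,
  // batch 8 with 28x28 output.
  GemmShape g = wgrad_gemm_shape(/*N=*/8, /*P=*/28, /*Q=*/28,
                                 /*K=*/64, /*R=*/3, /*S=*/3, /*C=*/32);
  assert(g.m == 64 && g.n == 288 && g.k == 6272);
  return 0;
}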
CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next GEMM-K offset (offset_npq_) in GEMM-A by a CTA-K tile -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_npq_[s] += Shape::kColumn * problem_size_.split_k_slices; -+ } -+ } -+ -+ /// Returns the coordinate in the output gradient tensor Dy that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int npq = offset_npq_[iteration_strided_]; -+ -+ int n = npq / (problem_size_.P * problem_size_.Q); -+ int residual = npq % (problem_size_.P * problem_size_.Q); -+ -+ int p = residual / problem_size_.Q; -+ int q = residual % problem_size_.Q; -+ -+ int k = filter_k_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements; -+ -+ return TensorCoord(n, p, q, k); -+ } -+ -+ -+ /// Returns true if the current coordinate is within the output gradient tensor Dy -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.N && -+ coord.h() < problem_size_.P && -+ coord.w() < problem_size_.Q && -+ coord.c() < problem_size_.K; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradOutputGradientTileAccessIteratorAnalytic &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. 
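set_iteration_index and operator++ above maintain the same (vector, contiguous, strided) state in two ways: one decodes a linear access index, the other steps it like an odometer. A host-side check that the two agree, with illustrative extents standing in for ThreadMap values:

#include <cassert>

struct State { int v, c, s; };

// Illustrative extents only.
const int kAccessesPerVector = 2, kContiguous = 4, kStrided = 2;

State decode(int index) {                       // mirrors set_iteration_index
  State st;
  st.v = index % kAccessesPerVector;
  int residual = index / kAccessesPerVector;
  st.c = residual % kContiguous;
  st.s = residual / kContiguous;
  return st;
}

State increment(State st) {                     // mirrors operator++
  if (++st.v < kAccessesPerVector) return st;
  st.v = 0;
  if (++st.c < kContiguous) return st;
  st.c = 0;
  if (++st.s < kStrided) return st;
  st.s = 0;
  return st;
}

int main() {
  const int kTotal = kAccessesPerVector * kContiguous * kStrided;
  for (int i = 0; i + 1 < kTotal; ++i) {
    State a = increment(decode(i));
    State b = decode(i + 1);
    assert(a.v == b.v && a.c == b.c && a.s == b.s);
  }
  return 0;
}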
-+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..6725ed4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h -@@ -0,0 +1,310 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dWgradOutputGradientTileAccessIteratorOptimized { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ static_assert(sizeof_bits::value >= 8, -+ "WGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dWgradOutputGradientIteratorOptimizedParams; -+ -+private: -+ -+ Conv2dWgradOutputGradientIteratorOptimizedParams const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ uint32_t predicates_[kAccessesPerVector]; -+ int filter_k_; -+ int offset_npq_; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradOutputGradientTileAccessIteratorOptimized( -+ Conv2dWgradOutputGradientIteratorOptimizedParams const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ predicates_{0}, -+ filter_k_(0), -+ offset_npq_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.row() + thread_coord.contiguous(); -+ offset_npq_ = threadblock_offset.column() + thread_coord.strided(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int filter_k = filter_k_ + c * ThreadMap::Delta::kContiguous; -+ int offset_npq = offset_npq_ + s * ThreadMap::Delta::kStrided; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ bool predicate = valid_(at_(offset_npq, filter_k + v * AccessType::kElements)); -+ -+ uint32_t pred = (predicate ? 
1u : 0); -+ -+ int pred_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ predicates_[v] |= (pred << pred_idx); -+ } -+ } -+ } -+ -+ // Offset pointer to (iteration_strided_, iteration_contiguous_) = (0, 0) -+ pointer_ += ( -+ offset_npq_ * params.layout.stride()[0] + filter_k_ -+ ) * sizeof_bits::value / 8; -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next GEMM-K offset (offset_npq_) in GEMM-A by a CTA-K tile -+ offset_npq_ += Shape::kColumn * problem_size_.split_k_slices; -+ -+ // Clear predicates if needed -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ if (offset_npq_ + s * ThreadMap::Delta::kStrided >= params_.NPQ) { -+ uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ predicates_[v] = (predicates_[v] & (~kClearMask)); -+ } -+ } -+ } -+ -+ pointer_ += params_.inc_next_npq; -+ } -+ -+private: -+ /// Returns the coordinate in the output gradient tensor Dy that is pointed to -+ /// by offset_npq and k. 
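The optimized output-gradient iterator above replaces per-access coordinate checks with predicate bits computed once in the constructor, cleared in advance() when offset_npq_ runs past NPQ, and tested in valid(). A minimal host-side sketch of that bookkeeping, with illustrative extents:

#include <cassert>
#include <cstdint>

int main() {
  // Illustrative extents only; the real values come from ThreadMap.
  const int kContiguous = 4, kStrided = 2;

  uint32_t predicates = 0;
  // Mark every access valid, as the constructor does with its ?: predicate.
  for (int s = 0; s < kStrided; ++s)
    for (int c = 0; c < kContiguous; ++c) {
      int pred_idx = c + s * kContiguous;
      predicates |= (1u << pred_idx);
    }
  assert(predicates == 0xFFu);  // 8 accesses -> low 8 bits set

  // advance() past the tensor extent: clear every bit of strided row s = 1.
  int s = 1;
  uint32_t kClearMask = ((1u << kContiguous) - 1) << (s * kContiguous);
  predicates &= ~kClearMask;
  assert(predicates == 0x0Fu);  // row 0 still valid, row 1 predicated off

  // valid() for a single access is then one bit test.
  int pred_idx = 2 + 0 * kContiguous;
  assert((predicates & (1u << pred_idx)) != 0);
  return 0;
}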
-+ CUTLASS_HOST_DEVICE -+ TensorCoord at_(int offset_npq, int k) const { -+ -+ // The subsequent fast_divmod() operations are equivalent to the following logical computation: -+ // -+ // -+ // int npq = offset_npq; -+ // int n = npq / (problem_size_.P * problem_size_.Q); -+ // int residual = npq % (problem_size_.P * problem_size_.Q); -+ // -+ // int p = residual / problem_size_.Q; -+ // int q = residual % problem_size_.Q; -+ -+ int residual, n, p, q; -+ -+ params_.pq_divmod(n, residual, offset_npq); -+ params_.q_divmod(p, q, residual); -+ -+ return TensorCoord(n, p, q, k); -+ } -+ -+ /// Returns true if the coord is within the output gradient tensor Dy -+ CUTLASS_HOST_DEVICE -+ bool valid_(TensorCoord coord) const { -+ -+ return coord.n() < problem_size_.N && -+ coord.c() < problem_size_.K; -+ } -+ -+public: -+ -+ /// Returns true if the current coordinate is within the output gradient tensor Dy -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; -+ return (predicates_[iteration_vector_] & (1u << pred_idx)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ return reinterpret_cast( -+ pointer_ + -+ iteration_strided_ * params_.offset_next_strided + -+ iteration_contiguous_ * params_.offset_next_contiguous -+ ) + iteration_vector_; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradOutputGradientTileAccessIteratorOptimized &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..8566f07 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h -@@ -0,0 +1,268 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNDHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_ -+> -+class Conv3dDgradFilterTileAccessIteratorAnalytic { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNDHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ static int const kAccessesPerVector = 1; -+ -+ static_assert(sizeof_bits::value >= 8, -+ "DGRAD requires elements of size 8b or larger."); -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params { -+ -+ Layout layout; -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Conv3dProblemSize const &problem_size, -+ Layout const &layout -+ ): layout(layout) { -+ -+ } -+ }; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv3dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ char const *pointer_; -+ -+ // For a fixed filter position (t,r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension -+ int filter_t_; -+ int filter_r_; -+ int filter_s_; -+ int offset_k_[ThreadMap::Iterations::kStrided]; -+ int offset_c_[ThreadMap::Iterations::kContiguous]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradFilterTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ Conv3dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ filter_t_(0), -+ filter_r_(0), -+ filter_s_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ offset_c_[c] = threadblock_offset.column() + thread_coord.contiguous() -+ + c * ThreadMap::Delta::kContiguous; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_k_[s] = -+ threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ } -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a 
pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next tile -+ ++filter_s_; -+ if (filter_s_ < problem_size_.S) { -+ return; -+ } -+ filter_s_ = 0; -+ ++filter_r_; -+ if (filter_r_ < problem_size_.R) { -+ return; -+ } -+ filter_r_ = 0; -+ ++filter_t_; -+ if (filter_t_ < problem_size_.T) { -+ return; -+ } -+ filter_t_ = 0; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_k_[s] += Shape::kRow * problem_size_.split_k_slices; -+ } -+ } -+ -+ /// Returns the coordinate in the filter tensor w that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int c = offset_c_[iteration_contiguous_]; -+ int k = offset_k_[iteration_strided_]; -+ -+ return TensorCoord(k, filter_t_, filter_r_, filter_s_, c); -+ } -+ -+ /// Returns true if the current coordinate is within the filter tensor w -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.K && coord.c() < problem_size_.C; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradFilterTileAccessIteratorAnalytic &operator++() { -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv3dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % (128/sizeof_bits::value)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..b9876ff ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h -@@ -0,0 +1,289 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+ -+#include "cutlass/conv/threadblock/conv3d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity -+> -+class Conv3dDgradFilterTileAccessIteratorOptimized { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNDHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = StrideSupport_; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ static int const kAccessesPerVector = 1; -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params : Conv3dDgradFilterIteratorOptimizedParams { -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Conv3dDgradFilterIteratorOptimizedParams const &base): -+ Conv3dDgradFilterIteratorOptimizedParams(base) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Conv3dProblemSize const &problem_size, -+ Layout const &layout -+ ): -+ Conv3dDgradFilterIteratorOptimizedParams( -+ problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} -+ ) { } -+ -+ }; -+ -+private: -+ -+ Conv3dDgradFilterIteratorOptimizedParams const ¶ms_; -+ Conv3dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ char const *pointer_; -+ -+ uint32_t predicates_; -+ int filter_trs_; -+ int filter_k_; -+ -+ // -+ // Assertions -+ // -+ -+ // We map predicates into bits packed in this uint32_t container -+ static_assert(ThreadMap::Iterations::kStrided * -+ ThreadMap::Iterations::kContiguous < sizeof(predicates_) * 8, -+ "Currently, the number of loads per iteration is limited by the size of the predicates container."); -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradFilterTileAccessIteratorOptimized( -+ Conv3dDgradFilterIteratorOptimizedParams const ¶ms, -+ Conv3dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ predicates_(0), -+ filter_trs_(0), -+ filter_k_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ 
filter_k_ = threadblock_offset.row() + thread_coord.strided(); -+ Index column = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided; -+ int filter_c = column + c * ThreadMap::Delta::kContiguous; -+ -+ uint32_t pred = ((filter_k < problem_size_.K && filter_c < problem_size_.C) ? 1u : 0); -+ -+ int pred_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ predicates_ |= (pred << pred_idx); -+ } -+ } -+ -+ pointer_ += ( -+ filter_k_ * params.layout.stride()[3] + column -+ ) * sizeof_bits::value / 8; -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ LongIndex next = params_.inc_next_trs; -+ -+ // moves to the next tile -+ ++filter_trs_; -+ if (filter_trs_ == params_.TRS) { -+ -+ filter_trs_ = 0; -+ next = params_.inc_next_k; -+ filter_k_ += params_.filter_k_delta; -+ } -+ -+ // Clear predicates if needed -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) { -+ uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); -+ -+ predicates_ = (predicates_ & (~kClearMask)); -+ } -+ } -+ -+ pointer_ += next; -+ } -+ -+ /// Returns true if the current coordinate is within the filter tensor W -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; -+ return (predicates_ & (1u << pred_idx)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ return reinterpret_cast(pointer_ + -+ iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradFilterTileAccessIteratorOptimized &operator++() { -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ -+ // Move to the next K coordinate within the tile -+ pointer_ += params_.inc_next_strided; -+ -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. 
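/// For this iterator the only constraint is that problem_size.C be a multiple of the
/// 128-bit access width of Element (see the alignment check below).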
-+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv3dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % (128/sizeof_bits::value)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..5c399e2 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h -@@ -0,0 +1,343 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNDHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ conv::StrideSupport StrideSupport_ = conv::StrideSupport::kStrided -+> -+class Conv3dDgradOutputGradientTileAccessIteratorAnalytic; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conv3dDgradOutputGradientTileAccessIteratorAnalytic strided dgrad needs special handling using -+// unscaled coordinations -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_ -+> -+class Conv3dDgradOutputGradientTileAccessIteratorAnalytic < -+ Shape_, -+ Element_, -+ ThreadMap_, -+ conv::StrideSupport::kStrided -+> { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNDHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ static int const kAccessesPerVector = 1; -+ -+ static_assert(sizeof_bits::value >= 8, -+ "DGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Simpligying assertions -+ // -+ -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params { -+ -+ Layout layout; -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ConvProblemSize const &problem_size, -+ Layout const &layout -+ ): layout(layout) { -+ -+ } -+ }; -+ -+private: -+ -+ Params const ¶ms_; -+ ConvProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ char const *pointer_; -+ -+ int filter_k_; -+ int filter_t_; -+ int filter_r_; -+ int filter_s_; -+ -+ int offset_n_[ThreadMap::Iterations::kStrided]; -+ int offset_d_[ThreadMap::Iterations::kStrided]; -+ int offset_w_[ThreadMap::Iterations::kStrided]; -+ int offset_h_[ThreadMap::Iterations::kStrided]; -+ -+private: -+ -+ /// Returns the coordinate in the output tensor Dy that is currently pointed to -+ /// by the iterator but DOES NOT scale by the convolution stride. This is needed -+ /// to compute predicates in the valid() method. The return value of the public at() -+ /// method is correctly scaled. 
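/// valid() accepts a position only when the unscaled z, p and q components are exact
/// multiples of stride_d, stride_h and stride_w; at() then divides by those strides.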
-+ CUTLASS_HOST_DEVICE -+ TensorCoord unscaled_at_() const { -+ int n = offset_n_[iteration_strided_]; -+ int d = offset_d_[iteration_strided_]; -+ int h = offset_h_[iteration_strided_]; -+ int w = offset_w_[iteration_strided_]; -+ -+ int t = filter_t_; -+ int r = filter_r_; -+ int s = filter_s_; -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ t = (problem_size_.T - 1 - t); -+ r = (problem_size_.R - 1 - r); -+ s = (problem_size_.S - 1 - s); -+ } -+ -+ int z = (d + problem_size_.pad_d - t * problem_size_.dilation_d); -+ int p = (h + problem_size_.pad_h - r * problem_size_.dilation_h); -+ int q = (w + problem_size_.pad_w - s * problem_size_.dilation_w); -+ -+ return TensorCoord(n, z, p, q, filter_k_); -+ } -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradOutputGradientTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ ConvProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ filter_k_(0), -+ filter_t_(0), -+ filter_r_(0), -+ filter_s_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ int offset_ndhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ -+ offset_n_[s] = offset_ndhw / (problem_size_.D * problem_size_.H * problem_size_.W); -+ int residual = offset_ndhw % (problem_size_.D * problem_size_.H * problem_size_.W); -+ -+ offset_d_[s] = residual / (problem_size_.H * problem_size_.W); -+ residual = residual % (problem_size_.H * problem_size_.W); -+ -+ offset_h_[s] = residual / problem_size_.W; -+ offset_w_[s] = residual % problem_size_.W; -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, layout); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // move to the next tile -+ ++filter_s_; -+ if (filter_s_ < problem_size_.S) { -+ return; -+ } -+ filter_s_ = 0; -+ ++filter_r_; -+ if (filter_r_ < problem_size_.R) { -+ return; -+ } -+ filter_r_ = 0; -+ ++filter_t_; -+ if (filter_t_ < problem_size_.T) { -+ return; -+ } -+ filter_t_ = 0; -+ -+ filter_k_ += Shape_::kColumn * problem_size_.split_k_slices; -+ } -+ -+ /// Returns the coordinate in the output tensor Dy that is currently pointed to -+ /// by the iterator. 
-+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ TensorCoord coord = unscaled_at_(); -+ -+ return TensorCoord( -+ coord.n(), -+ coord.d() / problem_size_.stride_d, -+ coord.h() / problem_size_.stride_h, -+ coord.w() / problem_size_.stride_w, -+ coord.c()); -+ } -+ -+ -+ /// Returns true if the current coordinate is within the output tensor Dy -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord unscaled_coord = unscaled_at_(); -+ TensorCoord coord = at(); -+ -+ return -+ !(unscaled_coord.d() % problem_size_.stride_d) && -+ !(unscaled_coord.h() % problem_size_.stride_h) && -+ !(unscaled_coord.w() % problem_size_.stride_w) && -+ coord.n() < problem_size_.N && -+ coord.d() >= 0 && coord.d() < problem_size_.Z && -+ coord.h() >= 0 && coord.h() < problem_size_.P && -+ coord.w() >= 0 && coord.w() < problem_size_.Q && -+ coord.c() < problem_size_.K; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradOutputGradientTileAccessIteratorAnalytic &operator++() { -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(ConvProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % (128/sizeof_bits::value)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..f834a34 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h -@@ -0,0 +1,489 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNDHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/conv/threadblock/conv3d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity -+> -+class Conv3dDgradOutputGradientTileAccessIteratorOptimized { -+public: -+ -+ static_assert(StrideSupport_ == conv::StrideSupport::kUnity, -+ "Only unit-stride dgrad is supported at this time."); -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNDHWC; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ using Coord3D = Coord<3>; -+ static int const kAccessesPerVector = 1; -+ using Mask = uint64_t; -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ 
// Parameters structure -+ // -+ -+ using Params = Conv3dDgradOutputGradientIteratorOptimizedParams; -+ -+private: -+ -+ Params const ¶ms_; -+ ConvProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ -+ -+ // One pointer per access -+ char const *pointer_[ThreadMap::Iterations::kStrided]; -+ -+ // current filter position (t, r, s) -+ int filter_t_; -+ int filter_r_; -+ int filter_s_; -+ int filter_k_; -+ -+ Index masks_[ThreadMap::Iterations::kStrided][3]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradOutputGradientTileAccessIteratorOptimized( -+ Params const ¶ms, -+ ConvProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ filter_k_(0), -+ filter_t_(0), -+ filter_r_(0), -+ filter_s_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ int offset_n[ThreadMap::Iterations::kStrided]; -+ int offset_d[ThreadMap::Iterations::kStrided]; -+ int offset_h[ThreadMap::Iterations::kStrided]; -+ int offset_w[ThreadMap::Iterations::kStrided]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ pointer_[s] = reinterpret_cast(ptr); -+ -+ int offset_ndhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ -+ // The subseqnet fast_divmod() operations are equivalent to the following logical computation: -+ // -+ // -+ // offset_n[s] = offset_ndhw / (problem_size_.D * problem_size_.H * problem_size_.W); -+ // int residual = offset_ndhw % (problem_size_.D * problem_size_.H * problem_size_.W); -+ // -+ // -+ // offset_d[s] = residual / (problem_size_.H * problem_size_.W); -+ // residual = residual % (problem_size_.H * problem_size_.W); -+ // -+ // offset_h[s] = residual / problem_size_.W; -+ // offset_w[s] = residual % problem_size_.W; -+ // -+ -+ int residual; -+ -+ // input: (ndhw offset) output: (n offset and resudial (dhw offset)) -+ params_.dhw_divmod(offset_n[s], residual, offset_ndhw); -+ // input: (dhw offset) output: (d offset and resudial (hw)) -+ params_.hw_divmod(offset_d[s], residual, residual); -+ // input: (hw offset) output: (h offset and resudial (w offset)) -+ params_.w_divmod(offset_h[s], offset_w[s], residual); -+ -+ TensorCoord coord = at_(offset_n[s], offset_d[s], offset_h[s], offset_w[s], 0, 0, 0); -+ -+ pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; -+ } -+ -+ clear_mask(); -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int t = 0; t < problem_size_.T; ++t) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int t_ = t; -+ if (problem_size_.mode == Mode::kConvolution) { -+ t_ = problem_size_.T - 1 - t; -+ } -+ -+ int z = offset_d[s_idx] + problem_size_.pad_d - t_ * problem_size_.dilation_d; -+ -+ bool pred = (offset_n[s_idx] < problem_size_.N && z >= 0 && z < problem_size_.Z); -+ masks_[s_idx][0] |= (pred << t); -+ } -+ } -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int r = 0; r < problem_size_.R; ++r) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int r_ = r; -+ if (problem_size_.mode == Mode::kConvolution) { -+ r_ = problem_size_.R - 1 - r; -+ } -+ -+ int p = offset_h[s_idx] + problem_size_.pad_h - r_ * problem_size_.dilation_h; 
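// For unit-stride dgrad, p is the Dy row that reaches input row offset_h through filter
// offset r; the R-mask bit for r is set only when p lies inside [0, problem_size_.P).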
-+ -+ bool pred = (p >= 0 && p < problem_size_.P); -+ masks_[s_idx][1] |= (pred << r); -+ } -+ } -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int s = 0; s < problem_size_.S; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int s_ = s; -+ if (problem_size_.mode == Mode::kConvolution) { -+ s_ = problem_size_.S - 1 - s; -+ } -+ -+ int q = offset_w[s_idx] + problem_size_.pad_w - s_ * problem_size_.dilation_w; -+ -+ bool pred = (q >= 0 && q < problem_size_.Q); -+ masks_[s_idx][2] |= (pred << s); -+ } -+ } -+ -+ if (filter_k_ >= problem_size.K) { -+ clear_mask(); -+ } -+ -+ set_iteration_index(0); -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); -+ } -+ -+private: -+ -+ -+ /// Returns the coordinate in the output gradient tensor dy that is correspoinding to -+ // activation ndhw and filter position k, t, r, s -+ CUTLASS_HOST_DEVICE -+ TensorCoord at_(int n, int d, int h, int w, int t, int r, int s) const { -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ t = problem_size_.T - 1 - t; -+ r = problem_size_.R - 1 - r; -+ s = problem_size_.S - 1 - s; -+ } -+ -+ int z = d + problem_size_.pad_d - t * problem_size_.dilation_d; -+ int p = h + problem_size_.pad_h - r * problem_size_.dilation_h; -+ int q = w + problem_size_.pad_w - s * problem_size_.dilation_w; -+ -+ return TensorCoord(n, z, p, q, filter_k_); -+ } -+ -+ -+ /// Adds a pointer offset in units of element -+ CUTLASS_HOST_DEVICE -+ void add_byte_offset_(LongIndex byte_offset) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ pointer_[s] += byte_offset; -+ } -+ } -+ -+ /// Clears the predicates -+ CUTLASS_HOST_DEVICE -+ void clear_mask_(bool clear) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ // We are using inline PTX assembly here to avoid an CUDA C++ compilation -+ // artifact in which control flow instructions are generated. Instead, our -+ // intent is to predicate the mov instructions. 
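// Each asm block below copies the current mask into a register, sets predicate p when
// 'clear' is nonzero, conditionally zeroes the register with a predicated mov, and writes
// it back. The #else branch is the equivalent plain C++ path used for host compilation.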
-+ #if defined(__CUDA_ARCH__) -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " .reg .u32 m;" -+ " mov.u32 m, %2;" -+ " setp.ne.b32 p, %1, 0;\n" -+ " @p mov.u32 m, 0;\n" -+ " mov.u32 %0, m;\n" -+ "}\n" -+ : -+ "=r"(masks_[s][0]) -+ : -+ "r"((int)clear), -+ "r"(masks_[s][0]) -+ ); -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " .reg .u32 m;" -+ " mov.u32 m, %2;" -+ " setp.ne.b32 p, %1, 0;\n" -+ " @p mov.u32 m, 0;\n" -+ " mov.u32 %0, m;\n" -+ "}\n" -+ : -+ "=r"(masks_[s][1]) -+ : -+ "r"((int)clear), -+ "r"(masks_[s][1]) -+ ); -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " .reg .u32 m;" -+ " mov.u32 m, %2;" -+ " setp.ne.b32 p, %1, 0;\n" -+ " @p mov.u32 m, 0;\n" -+ " mov.u32 %0, m;\n" -+ "}\n" -+ : -+ "=r"(masks_[s][2]) -+ : -+ "r"((int)clear), -+ "r"(masks_[s][2]) -+ ); -+ #else -+ if (clear) { -+ masks_[s][0] = 0; -+ masks_[s][1] = 0; -+ masks_[s][2] = 0; -+ } -+ #endif -+ } -+ } -+ -+public: -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ add_byte_offset_(pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ int next_idx = 0; -+ -+ // moves to the next tile -+ ++filter_s_; -+ if (filter_s_ == problem_size_.S) { -+ -+ filter_s_ = 0; -+ ++filter_r_; -+ next_idx = 1; -+ -+ if (filter_r_ == problem_size_.R) { -+ filter_r_ = 0; -+ ++filter_t_; -+ -+ if (filter_t_ < problem_size_.T) { -+ next_idx = 2; -+ } -+ else { -+ filter_t_ = 0; -+ next_idx = 3; -+ } -+ } -+ } -+ -+ add_byte_offset_(params_.inc_next[next_idx]); -+ -+ if (next_idx == 3) { -+ filter_k_ += params_.filter_k_delta; -+ } -+ -+ clear_mask_(filter_k_ >= problem_size_.K); -+ } -+ -+ -+ /// Clears the predicates -+ CUTLASS_HOST_DEVICE -+ void clear_mask() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ masks_[s][0] = Mask(0); -+ masks_[s][1] = Mask(0); -+ masks_[s][2] = Mask(0); -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ -+ return -+ (masks_[iteration_strided_][0] & (Index(1) << filter_t_)) && -+ (masks_[iteration_strided_][1] & (Index(1) << filter_r_)) && -+ (masks_[iteration_strided_][2] & (Index(1) << filter_s_)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ return reinterpret_cast(pointer_[iteration_strided_]); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradOutputGradientTileAccessIteratorOptimized &operator++() { -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. 
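/// The checks below reject non-unit strides, a K extent that is not a multiple of the
/// 128-bit access width, and filter extents T, R or S above the 32 mask bits per dimension.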
-+ CUTLASS_HOST_DEVICE -+ static Status can_implement(ConvProblemSize const &problem_size) { -+ -+ // This is specialized for unit stride -+ if (problem_size.stride() != Coord3D({1, 1, 1})) { -+ return Status::kErrorNotSupported; -+ } -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % (128/sizeof_bits::value)) { -+ return Status::kErrorNotSupported; -+ } -+ -+ // Limit on filter size -+ if (problem_size.T > 32 || problem_size.R > 32 || problem_size.S > 32) { -+ return Status::kErrorNotSupported; -+ } -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..0519ebe ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h -@@ -0,0 +1,291 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNDHWC layout of tensors in Global Memory. 
-+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/conv/threadblock/conv3d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_ -+> -+class Conv3dFpropActivationTileAccessIteratorAnalytic { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNDHWC; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ static int const kAccessesPerVector = 1; -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv3dAnalyticParams; -+ -+private: -+ -+ Params const ¶ms_; -+ ConvProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ char const *pointer_; -+ -+ int filter_t_; -+ int filter_r_; -+ int filter_s_; -+ int filter_c_; -+ -+ int offset_n_[ThreadMap::Iterations::kStrided]; -+ int offset_z_[ThreadMap::Iterations::kStrided]; -+ int offset_p_[ThreadMap::Iterations::kStrided]; -+ int offset_q_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropActivationTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ ConvProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ filter_t_(0), -+ filter_r_(0), -+ filter_s_(0), -+ filter_c_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_c_ = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ int offset_nzpq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ -+ offset_n_[s] = offset_nzpq / (problem_size_.Z * problem_size_.P * problem_size_.Q); -+ int residual = offset_nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q); -+ -+ offset_z_[s] = residual / (problem_size_.P * 
problem_size_.Q); -+ residual = residual % (problem_size_.P * problem_size_.Q); -+ -+ offset_p_[s] = residual / problem_size_.Q; -+ offset_q_[s] = residual % problem_size_.Q; -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, layout); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next tile -+ ++filter_s_; -+ if (filter_s_ < problem_size_.S) { -+ return; -+ } -+ filter_s_ = 0; -+ ++filter_r_; -+ if (filter_r_ < problem_size_.R) { -+ return; -+ } -+ filter_r_ = 0; -+ ++filter_t_; -+ if (filter_t_ < problem_size_.T) { -+ return; -+ } -+ filter_t_ = 0; -+ -+ filter_c_ += Shape::kColumn * problem_size_.split_k_slices; -+ } -+ -+ /// Returns the coordinate in the activations tensor X that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ int n = offset_n_[iteration_strided_]; -+ int z = offset_z_[iteration_strided_]; -+ int p = offset_p_[iteration_strided_]; -+ int q = offset_q_[iteration_strided_]; -+ -+ int t = filter_t_; -+ int r = filter_r_; -+ int s = filter_s_; -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ t = (problem_size_.T - 1 - filter_t_); -+ r = (problem_size_.R - 1 - filter_r_); -+ s = (problem_size_.S - 1 - filter_s_); -+ } -+ -+ int d = z * problem_size_.stride_d - problem_size_.pad_d + t * problem_size_.dilation_d; -+ int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; -+ int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; -+ -+ return TensorCoord(n, d, h, w, filter_c_); -+ } -+ -+ /// Returns true if the current coordinate is within the activations tensor X -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.N && -+ coord.d() >= 0 && coord.d() < problem_size_.D && -+ coord.h() >= 0 && coord.h() < problem_size_.H && -+ coord.w() >= 0 && coord.w() < problem_size_.W && -+ coord.c() < problem_size_.C; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ AccessType const *ptr = reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ -+ return ptr; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropActivationTileAccessIteratorAnalytic &operator++() { -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. 
-+ CUTLASS_HOST_DEVICE -+ static Status can_implement(ConvProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % (128/sizeof_bits::value)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..c51eb59 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h -@@ -0,0 +1,478 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNDHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/conv/threadblock/conv3d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename ThreadMap_ -+> -+class Conv3dFpropActivationTileAccessIteratorOptimized { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ static int const kAccessesPerVector = 1; -+ using Mask = uint64_t; -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv3dFpropActivationIteratorOptimizedParams; -+ -+private: -+ -+ Conv3dFpropActivationIteratorOptimizedParams const ¶ms_; -+ Conv3dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ -+ // One pointer per access -+ char const *pointer_[ThreadMap::Iterations::kStrided]; -+ -+ // current filter position (t, r, s) -+ int filter_t_; -+ int filter_r_; -+ int filter_s_; -+ int filter_c_; -+ -+ // mask for t, r, and s -+ Index masks_[ThreadMap::Iterations::kStrided][3]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropActivationTileAccessIteratorOptimized( -+ Conv3dFpropActivationIteratorOptimizedParams const ¶ms, -+ Conv3dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles -+ ) : -+ params_(params), -+ problem_size_(problem_size), -+ filter_t_(0), -+ filter_r_(0), -+ filter_s_(0), -+ filter_c_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_c_ = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ int offset_n[ThreadMap::Iterations::kStrided]; -+ int offset_z[ThreadMap::Iterations::kStrided]; -+ int offset_p[ThreadMap::Iterations::kStrided]; -+ int offset_q[ThreadMap::Iterations::kStrided]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ pointer_[s] = reinterpret_cast(ptr); -+ -+ int offset_nzpq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ -+ // The subseqnet fast_divmod() operations are equivalent to the 
following logical computation: -+ // -+ // -+ // offset_n[s] = offset_nzpq / (problem_size_.Z * problem_size_.P * problem_size_.Q); -+ // int residual = offset_nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q); -+ // -+ // offset_z[s] = residual / (problem_size_.P * problem_size_.Q); -+ // residual = residual % (problem_size_.P * problem_size_.Q); -+ // -+ // offset_p[s] = residual / problem_size_.Q; -+ // offset_q[s] = residual % problem_size_.Q; -+ // -+ -+ int residual; -+ -+ // input: (nzpq offset) output: (n offset and resudial (zpq offset)) -+ params.zpq_divmod(offset_n[s], residual, offset_nzpq); -+ // input: (zpq offset) output: (z offset and resudial (pq)) -+ params.pq_divmod(offset_z[s], residual, residual); -+ // input: (pq offset) output: (p offset and resudial (q offset)) -+ params.q_divmod(offset_p[s], offset_q[s], residual); -+ -+ TensorCoord coord = at_(offset_n[s], offset_z[s], offset_p[s], offset_q[s], 0, 0, 0); -+ -+ pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; -+ } -+ -+ clear_mask(); -+ -+ // mask predicates for filter position T -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int t = 0; t < problem_size_.T; ++t) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int t_ = t; -+ if (problem_size_.mode == Mode::kConvolution) { -+ t_ = problem_size_.T - 1 - t; -+ } -+ -+ int d = offset_z[s_idx] * problem_size_.stride_d - problem_size_.pad_d + t_ * problem_size_.dilation_d; -+ -+ bool pred = (offset_n[s_idx] < problem_size_.N && d >= 0 && d < problem_size_.D); -+ masks_[s_idx][0] |= (pred << t); -+ } -+ } -+ -+ // mask predicates for filter position R -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int r = 0; r < problem_size_.R; ++r) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int r_ = r; -+ if (problem_size_.mode == Mode::kConvolution) { -+ r_ = problem_size_.R - 1 - r; -+ } -+ -+ int h = offset_p[s_idx] * problem_size_.stride_h - problem_size_.pad_h + r_ * problem_size_.dilation_h; -+ -+ bool pred = (h >= 0 && h < problem_size_.H); -+ masks_[s_idx][1] |= (pred << r); -+ } -+ } -+ -+ // mask predicates for filter position S -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int s = 0; s < problem_size_.S; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int s_ = s; -+ if (problem_size_.mode == Mode::kConvolution) { -+ s_ = problem_size_.S - 1 - s; -+ } -+ -+ int w = offset_q[s_idx] * problem_size_.stride_w - problem_size_.pad_w + s_ * problem_size_.dilation_w; -+ -+ bool pred = (w >= 0 && w < problem_size_.W); -+ masks_[s_idx][2] |= (pred << s); -+ } -+ } -+ -+ if (filter_c_ >= problem_size.C) { -+ clear_mask(); -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); -+ } -+ -+private: -+ -+ /// Returns the coordinate in the activations tensor X that is correspoinding to -+ // output nzpq and filter position t, r, s -+ CUTLASS_HOST_DEVICE -+ TensorCoord at_(int n, int z, int p, int q, int t, int r, int s) const { -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ t = problem_size_.T - 1 - t; -+ r = 
problem_size_.R - 1 - r; -+ s = problem_size_.S - 1 - s; -+ } -+ -+ int d = z * problem_size_.stride_d - problem_size_.pad_d + t * problem_size_.dilation_d; -+ int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; -+ int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; -+ -+ return TensorCoord(n, d, h, w, filter_c_); -+ } -+ -+ /// Adds a pointer offset in units of element -+ CUTLASS_HOST_DEVICE -+ void add_byte_offset_(LongIndex byte_offset) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ pointer_[s] += byte_offset; -+ } -+ } -+ -+ -+ /// Clears the predicates -+ CUTLASS_HOST_DEVICE -+ void clear_mask_(bool clear) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ // We are using inline PTX assembly here to avoid an CUDA C++ compilation -+ // artifact in which control flow instructions are generated. Instead, our -+ // intent is to predicate the mov instructions. -+ #if defined(__CUDA_ARCH__) -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " .reg .u32 m;" -+ " mov.u32 m, %2;" -+ " setp.ne.b32 p, %1, 0;\n" -+ " @p mov.u32 m, 0;\n" -+ " mov.u32 %0, m;\n" -+ "}\n" -+ : -+ "=r"(masks_[s][0]) -+ : -+ "r"((int)clear), -+ "r"(masks_[s][0]) -+ ); -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " .reg .u32 m;" -+ " mov.u32 m, %2;" -+ " setp.ne.b32 p, %1, 0;\n" -+ " @p mov.u32 m, 0;\n" -+ " mov.u32 %0, m;\n" -+ "}\n" -+ : -+ "=r"(masks_[s][1]) -+ : -+ "r"((int)clear), -+ "r"(masks_[s][1]) -+ ); -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " .reg .u32 m;" -+ " mov.u32 m, %2;" -+ " setp.ne.b32 p, %1, 0;\n" -+ " @p mov.u32 m, 0;\n" -+ " mov.u32 %0, m;\n" -+ "}\n" -+ : -+ "=r"(masks_[s][2]) -+ : -+ "r"((int)clear), -+ "r"(masks_[s][2]) -+ ); -+ #else -+ if (clear) { -+ masks_[s][0] = 0; -+ masks_[s][1] = 0; -+ masks_[s][2] = 0; -+ } -+ #endif -+ } -+ } -+ -+public: -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ add_byte_offset_(pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ int next_idx = 0; -+ -+ // moves to the next tile -+ ++filter_s_; -+ if (filter_s_ == problem_size_.S) { -+ -+ filter_s_ = 0; -+ ++filter_r_; -+ next_idx = 1; -+ -+ if (filter_r_ == problem_size_.R) { -+ filter_r_ = 0; -+ ++filter_t_; -+ -+ if (filter_t_ < problem_size_.T) { -+ next_idx = 2; -+ } -+ else { -+ filter_t_ = 0; -+ next_idx = 3; -+ } -+ } -+ } -+ -+ add_byte_offset_(params_.inc_next[next_idx]); -+ -+ if (next_idx == 3) { -+ filter_c_ += params_.filter_c_delta; -+ } -+ -+ clear_mask_(filter_c_ >= problem_size_.C); -+ } -+ -+ /// Clears the predicates -+ CUTLASS_HOST_DEVICE -+ void clear_mask() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ masks_[s][0] = Mask(0); -+ masks_[s][1] = Mask(0); -+ masks_[s][2] = Mask(0); -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ -+ return -+ (masks_[iteration_strided_][0] & (Index(1) << filter_t_)) && -+ (masks_[iteration_strided_][1] & (Index(1) << filter_r_)) && -+ (masks_[iteration_strided_][2] & (Index(1) << filter_s_)); -+ } -+ -+ /// Returns a pointer to the vector 
starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ return reinterpret_cast(pointer_[iteration_strided_]); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropActivationTileAccessIteratorOptimized &operator++() { -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv3dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % (128/sizeof_bits::value)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // Conv3dFpropActivationTileAccessIteratorOptimized has constraint on filter positions -+ // due to the number of mask bits. -+ if (problem_size.T > 32 || problem_size.R > 32 || problem_size.S > 32) { -+ return Status::kErrorNotSupported; -+ } -+ return Status::kSuccess; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..41d87fe ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h -@@ -0,0 +1,253 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNDHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/conv/threadblock/conv3d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_ -+> -+class Conv3dFpropFilterTileAccessIteratorAnalytic { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNDHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ static int const kAccessesPerVector = 1; -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv3dAnalyticParams; -+ -+private: -+ -+ Params const ¶ms_; -+ ConvProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ char const *pointer_; -+ -+ int filter_t_; -+ int filter_r_; -+ int filter_s_; -+ int filter_c_; -+ -+ int offset_k_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropFilterTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ ConvProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ filter_t_(0), -+ filter_r_(0), -+ 
filter_s_(0), -+ filter_c_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * 8 / sizeof_bits::value; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next tile -+ ++filter_s_; -+ if (filter_s_ < problem_size_.S) { -+ return; -+ } -+ filter_s_ = 0; -+ -+ ++filter_r_; -+ if (filter_r_ < problem_size_.R) { -+ return; -+ } -+ filter_r_ = 0; -+ -+ ++filter_t_; -+ if (filter_t_ < problem_size_.T) { -+ return; -+ } -+ filter_t_ = 0; -+ -+ filter_c_ += Shape::kRow * problem_size_.split_k_slices; -+ } -+ -+ /// Returns the coordinate in the filter tensor W that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int k = offset_k_[iteration_strided_]; -+ -+ return TensorCoord(k, filter_t_, filter_r_, filter_s_, filter_c_); -+ } -+ -+ /// Returns true if the current coordinate is within the activations tensor W -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.K && -+ coord.c() < problem_size_.C; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropFilterTileAccessIteratorAnalytic &operator++() { -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. 
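Editorial note: the advance() routine of the analytic filter iterator above walks the filter footprint like an odometer: the fastest-varying position s rolls over into r, r into t, and only when all three wrap does the channel offset jump forward by a full threadblock tile. The standalone C++ sketch below isolates that pattern; it is an illustration added for this review, not part of the patch, and all names are made up.

    // Editorial sketch: odometer-style traversal as used by advance() above.
    struct FilterOdometer {
      int s = 0, r = 0, t = 0, c = 0;

      void advance(int S, int R, int T, int c_step) {
        if (++s < S) { return; }   // next filter column
        s = 0;
        if (++r < R) { return; }   // next filter row
        r = 0;
        if (++t < T) { return; }   // next filter depth slice
        t = 0;
        c += c_step;               // whole footprint visited: move to the next channel tile
      }
    };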
-+ CUTLASS_HOST_DEVICE -+ static Status can_implement(ConvProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % (128/sizeof_bits::value)) { -+ return Status::kErrorInvalidProblem; -+ } -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..c6c6f6f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h -@@ -0,0 +1,277 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+ -+#include "cutlass/conv/threadblock/conv3d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename ThreadMap_ -+> -+class Conv3dFpropFilterTileAccessIteratorOptimized{ -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ static int const kAccessesPerVector = 1; -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params : Conv3dFpropFilterIteratorOptimizedParams { -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Conv3dFpropFilterIteratorOptimizedParams const &base): -+ Conv3dFpropFilterIteratorOptimizedParams(base) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Conv3dProblemSize const &problem_size, -+ Layout const &layout -+ ): -+ Conv3dFpropFilterIteratorOptimizedParams( -+ problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} -+ ) { -+ -+ } -+ }; -+ -+private: -+ -+ Conv3dFpropFilterIteratorOptimizedParams const ¶ms_; -+ Conv3dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ char const *pointer_; -+ -+ uint32_t predicates_; -+ int filter_trs_; -+ int filter_c_; -+ -+ // -+ // Assertions -+ // -+ -+ // We map predicates into bits packed in this uint32_t container -+ static_assert(ThreadMap::Iterations::kStrided < sizeof(predicates_) * 8, -+ "Currently, the number of loads per iteration is limited by the size of the predicates container."); -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropFilterTileAccessIteratorOptimized( -+ Conv3dFpropFilterIteratorOptimizedParams const ¶ms, -+ Conv3dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ predicates_{0}, -+ filter_trs_(0), -+ filter_c_(0) { -+ -+ layout::PitchLinearCoord thread_coord = 
ThreadMap::initial_offset(thread_idx); -+ -+ filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); -+ Index column = threadblock_offset.column() + thread_coord.strided(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < problem_size_.K) ? 1u : 0); -+ predicates_ |= (pred << s); -+ } -+ -+ if (filter_c_ >= problem_size.C) { -+ predicates_ = 0u; -+ } -+ -+ pointer_ += ( -+ params_.layout({filter_c_, column}) -+ ) * sizeof_bits::value / 8; -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ LongIndex next = params_.inc_next_trs; -+ -+ // moves to the next tile -+ ++filter_trs_; -+ if (filter_trs_ == params_.TRS) { -+ -+ filter_trs_ = 0; -+ next = params_.inc_next_c; -+ filter_c_ += params_.filter_c_delta; -+ } -+ -+ if (filter_c_ >= problem_size_.C) { -+ predicates_ = 0; -+ } -+ -+ pointer_ += next; -+ } -+ -+ /// Returns true if the current coordinate is within the filter tensor W -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return (predicates_ & (1u << iteration_strided_)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ return reinterpret_cast(pointer_); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropFilterTileAccessIteratorOptimized &operator++() { -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ -+ // Move to the next K coordinate within the tile -+ pointer_ += params_.inc_next_k; -+ -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv3dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % (128/sizeof_bits::value)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_params.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_params.h -new file mode 100644 -index 0000000..180dca5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_params.h -@@ -0,0 +1,508 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Extracts the host-params objects into non-template code. -+*/ -+ -+#pragma once -+ -+#define TRACE_CONV_PARAMS_INITIALIZERS_ENABLED 0 -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+ -+#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED -+#include -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Params structure used for all Conv3d analytic tile iterators -+template< typename Layout_ = layout::TensorNDHWC > -+struct Conv3dAnalyticParams { -+ -+ using Layout = Layout_; -+ -+ Layout layout; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dAnalyticParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dAnalyticParams( -+ Conv3dProblemSize const &, // unused; placeholder to match other Params interfaces. 
-+ Layout const &layout -+ ): layout(layout) { -+ -+ } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters structure used for Conv3dFpropActivationTileIteratorOptimized -+template< typename Layout_ = layout::TensorNDHWC > -+struct Conv3dFpropActivationIteratorOptimizedParams; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters structure used for Conv3dFpropActivationTileIteratorOptimized -+template<> -+struct Conv3dFpropActivationIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNDHWC; -+ -+ Layout layout; -+ -+ int64_t inc_next[4]; // {next S, next R, next T, next C} -+ int filter_c_delta; // number of logical elements to add to filter_c_ -+ int ZPQ; // product of Z*P*Q -+ int PQ; // product of P*Q -+ -+ FastDivmod zpq_divmod; -+ FastDivmod pq_divmod; -+ FastDivmod q_divmod; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropActivationIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropActivationIteratorOptimizedParams( -+ Conv3dProblemSize const &problem_size, -+ Layout const &layout, ///< layout object -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout), -+ PQ(problem_size.P * problem_size.Q), -+ ZPQ(problem_size.Z * problem_size.P * problem_size.Q), -+ zpq_divmod(ZPQ), -+ pq_divmod(PQ), -+ q_divmod(problem_size.Q) { -+ -+ TRACE_CONV_INITIALIZERS("conv3d_fprop", "activation", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ -+ int conv_sign = (problem_size.mode == Mode::kConvolution ? 
-1 : 1); -+ -+ // next S -+ inc_next[0] = conv_sign * ( -+ int64_t(layout.stride()[0]) * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next R -+ inc_next[1] = conv_sign * ( -+ int64_t(layout.stride()[1]) * problem_size.dilation_h -+ - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next T -+ inc_next[2] = conv_sign * ( -+ int64_t(layout.stride()[2]) * problem_size.dilation_d -+ - (problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h -+ - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next C -+ inc_next[3] = ( -+ threadblock_shape.column() * problem_size.split_k_slices -+ - conv_sign * int64_t(problem_size.T - 1) * layout.stride()[2] * problem_size.dilation_d -+ - conv_sign * int64_t(problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h -+ - conv_sign * int64_t(problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // logical offset added to internal channel counter - units are elements, not bytes -+ filter_c_delta = threadblock_shape.column() * problem_size.split_k_slices; -+ } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+template< typename Layout_ = layout::TensorNDHWC > -+struct Conv3dFpropFilterIteratorOptimizedParams; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template<> -+struct Conv3dFpropFilterIteratorOptimizedParams -+{ -+ -+ using Layout = layout::TensorNDHWC; -+ -+ Layout layout; -+ int TRS; -+ int filter_c_delta; -+ -+ int64_t inc_next_k; // offset in units of bytes to next K position -+ int64_t inc_next_trs; // offset in units of bytes to next TRS position -+ int64_t inc_next_c; // offset in units of bytes to next C position -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropFilterIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropFilterIteratorOptimizedParams( -+ Conv3dProblemSize const &problem_size, -+ Layout const &layout, -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout) { -+ -+ TRACE_CONV_INITIALIZERS("conv3d_fprop", "filter", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ TRS = problem_size.T * problem_size.R * problem_size.S; -+ -+ inc_next_k = (int64_t(layout.stride()[3]) * threadmap_delta.strided() * element_size_bits) / 8; -+ -+ inc_next_trs = -+ ( int64_t(layout.stride()[0]) -+ - int64_t(layout.stride()[3]) * (threadmap_iterations.strided() - 1) * threadmap_delta.strided() -+ ) * element_size_bits / 8; -+ -+ inc_next_c = -+ ( -+ threadblock_shape.row() * problem_size.split_k_slices -+ - int64_t(TRS - 1) * layout.stride()[0] -+ - int64_t(threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[3] -+ ) * element_size_bits / 8; -+ -+ filter_c_delta = threadblock_shape.row() * problem_size.split_k_slices; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters object for Conv3d DGRAD OutputGradient (dy) iterator -+struct Conv3dDgradOutputGradientIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNDHWC; -+ -+ Layout layout; 
-+ -+ int64_t inc_next[4]; // {next S, next R, next T, next K} -+ int filter_k_delta; // number of logical elements to add to filter_k_ -+ -+ FastDivmod dhw_divmod; -+ FastDivmod hw_divmod; -+ FastDivmod w_divmod; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradOutputGradientIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradOutputGradientIteratorOptimizedParams( -+ Conv3dProblemSize const &problem_size, -+ Layout const &layout, ///< layout object -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout), -+ dhw_divmod(problem_size.D * problem_size.H * problem_size.W), -+ hw_divmod(problem_size.H * problem_size.W), -+ w_divmod(problem_size.W) { -+ -+ TRACE_CONV_INITIALIZERS("conv3d_dgrad", "output_gradient", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ int conv_sign = (problem_size.mode == Mode::kConvolution ? 1 : -1); -+ -+ // next S -+ inc_next[0] = conv_sign * ( -+ int64_t(layout.stride()[0]) * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next R -+ inc_next[1] = conv_sign * ( -+ int64_t(layout.stride()[1]) * problem_size.dilation_h -+ - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next T -+ inc_next[2] = conv_sign * ( -+ int64_t(layout.stride()[2]) * problem_size.dilation_d -+ - (problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h -+ - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next K -+ inc_next[3] = ( -+ threadblock_shape.column() * problem_size.split_k_slices -+ - conv_sign * int64_t(problem_size.T - 1) * layout.stride()[2] * problem_size.dilation_d -+ - conv_sign * int64_t(problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h -+ - conv_sign * int64_t(problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // logical offset added to internal channel counter - units are elements, not bytes -+ filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters object for Conv2d DGRAD Filter (w) iterator -+struct Conv3dDgradFilterIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNDHWC; -+ -+ Layout layout; -+ int TRS; -+ int filter_k_delta; -+ -+ int64_t inc_next_strided; // offset in units of bytes to next K coordinate within tile -+ int64_t inc_next_trs; // offset in units of bytes to next TRS position -+ int64_t inc_next_k; // offset in units of bytes to next K position in subsequent tile -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradFilterIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradFilterIteratorOptimizedParams( -+ Conv3dProblemSize const &problem_size, -+ Layout const &layout, -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout), TRS(problem_size.T * problem_size.R * problem_size.S) { -+ -+ TRACE_CONV_INITIALIZERS("conv3d_dgrad", "filter", -+ element_size_bits, threadblock_shape, 
thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ inc_next_strided = ((int64_t)layout.stride()[3] * threadmap_delta.strided() * element_size_bits) / 8; -+ -+ inc_next_trs = -+ ( (int64_t)layout.stride()[0] -+ - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[3] -+ ) * element_size_bits / 8; -+ -+ inc_next_k = -+ ( -+ threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[3] -+ - (problem_size.T * problem_size.R * problem_size.S - 1) * (int64_t)layout.stride()[0] -+ - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[3] -+ ) * element_size_bits / 8; -+ -+ filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices; -+ } -+}; -+ -+/// Parameters object for Conv3d WGRAD OutputGradient iterator -+struct Conv3dWgradOutputGradientIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNDHWC; -+ using LongIndex = typename Layout::LongIndex; -+ -+ Layout layout; -+ -+ int NZPQ; // precomputd product of N*Z*P*Q for clearing predicates -+ int ZPQ; // product of Z*P*Q -+ unsigned zpq_mul; // precomputed quantities for fast computation of div/% by ZPQ -+ unsigned zpq_shr; // in device code. -+ -+ int PQ; // product of P*Q -+ unsigned pq_mul; // precomputed quantities for fast computation of div/% by PQ -+ unsigned pq_shr; // in device code. -+ -+ unsigned q_mul; // precomputed quantities for fast computation of div/% by Q -+ unsigned q_shr; // in device code. -+ -+ LongIndex offset_next_strided; // offset in units of bytes to next nzpq coordinate within tile -+ LongIndex offset_next_contiguous; // offset in units of bytes to next k coordinate within tile -+ LongIndex inc_next_nzpq; // offset in units of bytes to next nzpq position in subsequent tile -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradOutputGradientIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradOutputGradientIteratorOptimizedParams( -+ Conv3dProblemSize const &problem_size, -+ Layout const &layout, -+ int element_size_bits, -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): layout(layout) { -+ -+ TRACE_CONV_INITIALIZERS("conv3d_wgrad", "output_gradient", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ // Incremental offsets in unites of bytes (number of elements) * element_size_bits / 8 -+ offset_next_strided = (threadmap_delta.strided() * (int64_t)layout.stride()[0]) -+ * element_size_bits / 8; -+ -+ offset_next_contiguous = (threadmap_delta.contiguous()) -+ * element_size_bits / 8; -+ -+ inc_next_nzpq = (threadblock_shape.column() * problem_size.split_k_slices * (int64_t)layout.stride()[0]) -+ * element_size_bits / 8; -+ -+ // Precompute several quantities for fast modulo arithmetic. 
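Editorial note: the find_divisor()/fast_divmod() pairs set up in the parameter structures below exist to remove integer division from device-side index math. The divisions they replace are the usual unflattening of a linear GEMM-K index over N*Z*P*Q into (n, z, p, q). The standalone sketch below shows that plain-arithmetic form; it is an illustration added for this review, not part of the patch.

    // Editorial sketch: the decomposition that the (mul, shr) divisor pairs accelerate.
    struct Nzpq { int n, z, p, q; };

    Nzpq decompose_nzpq(int nzpq, int Z, int P, int Q) {
      Nzpq c;
      c.n = nzpq / (Z * P * Q);        // batch index
      int residual = nzpq % (Z * P * Q);
      c.z = residual / (P * Q);        // output depth
      residual %= (P * Q);
      c.p = residual / Q;              // output row
      c.q = residual % Q;              // output column
      return c;
    }

On the device, each / and % pair becomes a multiply-and-shift against the precomputed (mul, shr) values, which is why the Params structs store one pair per divisor.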
-+ NZPQ = problem_size.N * problem_size.Z * problem_size.P * problem_size.Q; -+ ZPQ = problem_size.Z * problem_size.P * problem_size.Q; -+ find_divisor(zpq_mul, zpq_shr, ZPQ); -+ -+ PQ = problem_size.P * problem_size.Q; -+ find_divisor(pq_mul, pq_shr, PQ); -+ -+ find_divisor(q_mul, q_shr, problem_size.Q); -+ -+ } -+}; -+ -+/// Parameters object for Conv3d WGRAD Activation Tile Access Iterator -+struct Conv3dWgradActivationIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNDHWC; -+ -+ Layout layout; -+ -+ int RSC; // product of R*S*C -+ unsigned rsc_mul; // precomputed quantities for fast computation of div/% by RSC -+ unsigned rsc_shr; // in device code. -+ -+ int SC; // product of S*C -+ unsigned sc_mul; // precomputed quantities for fast computation of div/% by SC -+ unsigned sc_shr; // in device code. -+ -+ unsigned c_mul; // precomputed quantities for fast computation of div/% by C -+ unsigned c_shr; // in device code. -+ -+ int ZPQ; // product of Z*P*Q -+ unsigned zpq_mul; // precomputed quantities for fast computation of div/% by ZPQ -+ unsigned zpq_shr; // in device code. -+ -+ int PQ; // product of P*Q -+ unsigned pq_mul; // precomputed quantities for fast computation of div/% by PQ -+ unsigned pq_shr; // in device code. -+ -+ unsigned q_mul; // precomputed quantities for fast computation of div/% by Q -+ unsigned q_shr; // in device code. -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradActivationIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradActivationIteratorOptimizedParams( -+ Conv3dProblemSize const &problem_size, -+ Layout const &layout, -+ int element_size_bits, -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): layout(layout) { -+ -+ TRACE_CONV_INITIALIZERS("conv3d_wgrad", "activation", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ // Precompute several quantities for fast modulo arithmetic. -+ RSC = problem_size.R * problem_size.S * problem_size.C; -+ find_divisor(rsc_mul, rsc_shr, RSC); -+ -+ SC = problem_size.S * problem_size.C; -+ find_divisor(sc_mul, sc_shr, SC); -+ -+ find_divisor(c_mul, c_shr, problem_size.C); -+ -+ ZPQ = problem_size.Z * problem_size.P * problem_size.Q; -+ find_divisor(zpq_mul, zpq_shr, ZPQ); -+ -+ PQ = problem_size.P * problem_size.Q; -+ find_divisor(pq_mul, pq_shr, PQ); -+ -+ find_divisor(q_mul, q_shr, problem_size.Q); -+ -+ } -+}; -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..d9fe9ad ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h -@@ -0,0 +1,289 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNDHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_ -+> -+class Conv3dWgradActivationTileAccessIteratorAnalytic { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNDHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ -+ static int const kAccessesPerVector = 1; -+ -+ static_assert(sizeof_bits::value >= 8, -+ "WGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params { -+ -+ Layout layout; -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Conv3dProblemSize const &problem_size, -+ Layout const &layout -+ ): layout(layout) { -+ -+ } -+ }; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv3dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ char const *pointer_; -+ -+ // Filter postion (t,r,s,c) in contiguous dimension stays constant for each gemm_iteration_k -+ int filter_t_[ThreadMap::Iterations::kContiguous]; -+ int filter_r_[ThreadMap::Iterations::kContiguous]; -+ int filter_s_[ThreadMap::Iterations::kContiguous]; -+ int filter_c_[ThreadMap::Iterations::kContiguous]; -+ -+ int offset_nzpq_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradActivationTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ Conv3dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ // initialize t,r,s,c filter position for every contiguous iteration -+ CUTLASS_PRAGMA_UNROLL -+ for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int trsc_offset = threadblock_offset.column() + thread_coord.contiguous() -+ + c * ThreadMap::Delta::kContiguous; -+ -+ filter_t_[c] = trsc_offset / (problem_size_.R * problem_size_.S * problem_size_.C); -+ int residual = trsc_offset % (problem_size_.R * problem_size_.S * problem_size_.C); -+ -+ filter_r_[c] = residual / (problem_size_.S * problem_size_.C); -+ residual = residual % (problem_size_.S * problem_size_.C); -+ -+ filter_s_[c] = residual / 
problem_size_.C; -+ filter_c_[c] = residual % problem_size_.C; -+ -+ } -+ -+ // initialize n, z, p, q offset for every strided iteration -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ offset_nzpq_[s] = threadblock_offset.row() + thread_coord.strided() -+ + s * ThreadMap::Delta::kStrided; -+ } -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ // moves to the next GEMM-K offset (offset_nzpq_) in GEMM-B by a CTA-K tile -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_nzpq_[s] += Shape::kRow * problem_size_.split_k_slices; -+ } -+ } -+ -+ /// Returns the coordinate in the activation tensor x that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int t = filter_t_[iteration_contiguous_]; -+ int r = filter_r_[iteration_contiguous_]; -+ int s = filter_s_[iteration_contiguous_]; -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ t = (problem_size_.T - 1 - t); -+ r = (problem_size_.R - 1 - r); -+ s = (problem_size_.S - 1 - s); -+ } -+ -+ int n = offset_nzpq_[iteration_strided_] / (problem_size_.Z * problem_size_.P * problem_size_.Q); -+ int residual = offset_nzpq_[iteration_strided_] % (problem_size_.Z * problem_size_.P * problem_size_.Q); -+ -+ int z = residual / (problem_size_.P * problem_size_.Q); -+ residual = residual % (problem_size_.P * problem_size_.Q); -+ -+ int p = residual / problem_size_.Q; -+ int q = residual % problem_size_.Q; -+ -+ int d = z * problem_size_.stride_d - problem_size_.pad_d + t * problem_size_.dilation_d; -+ int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; -+ int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; -+ -+ return TensorCoord(n, d, h, w, filter_c_[iteration_contiguous_]); -+ } -+ -+ /// Returns true if the current coordinate is within the activation tensor x -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.N && -+ coord.d() >= 0 && coord.d() < problem_size_.D && -+ coord.h() >= 0 && coord.h() < problem_size_.H && -+ coord.w() >= 0 && coord.w() < problem_size_.W && -+ coord.c() < problem_size_.C; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradActivationTileAccessIteratorAnalytic &operator++() { -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. 
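Editorial note: the at() and valid() methods above turn an output coordinate and a filter position into an input (activation) coordinate using stride, padding and dilation, and treat anything that lands in the padding region as an invalid access, which is how the iterator realizes zero-padding implicitly. A minimal sketch of that mapping for the depth dimension, added for this review and not part of the patch:

    // Editorial sketch: output-to-input mapping with implicit zero-padding.
    struct MappedCoord { int d; bool in_bounds; };

    MappedCoord map_depth(int z, int t, int stride_d, int pad_d, int dilation_d, int D) {
      int d = z * stride_d - pad_d + t * dilation_d;  // same formula as at() above
      return { d, d >= 0 && d < D };                  // out-of-range reads are masked, not clamped
    }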
-+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv3dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % (128/sizeof_bits::value)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..2d56341 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h -@@ -0,0 +1,319 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNDHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/conv/threadblock/conv3d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_ -+> -+class Conv3dWgradActivationTileAccessIteratorOptimized { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNDHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ static int const kAccessesPerVector = 1; -+ static_assert(sizeof_bits::value >= 8, -+ "WGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params : Conv3dWgradActivationIteratorOptimizedParams { -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(Conv3dWgradActivationIteratorOptimizedParams const &base) -+ : Conv3dWgradActivationIteratorOptimizedParams(base) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(Conv3dProblemSize const &problem_size, Layout const &layout) -+ : Conv3dWgradActivationIteratorOptimizedParams( -+ problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}) {} -+ }; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv3dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ char const *pointer_; -+ -+ // Precomputed effective filter postion (t,r,s) in contiguous dimension stays constant for each gemm_iteration_k -+ // required for nzpq -> ndhw translation -+ int precomputed_filter_t_[ThreadMap::Iterations::kContiguous]; -+ int precomputed_filter_r_[ThreadMap::Iterations::kContiguous]; -+ int precomputed_filter_s_[ThreadMap::Iterations::kContiguous]; -+ -+ // Channel dimension in contiguous dimension stays constant for each gemm_iteration_k -+ int filter_c_[ThreadMap::Iterations::kContiguous]; -+ -+ int offset_nzpq_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradActivationTileAccessIteratorOptimized( -+ Params const ¶ms, -+ Conv3dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)) { -+ -+ 
layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);
-+
-+    // initialize t,r,s,c filter position for every contiguous iteration
-+    CUTLASS_PRAGMA_UNROLL
-+    for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
-+
-+      int trsc_offset = threadblock_offset.column() + thread_coord.contiguous()
-+        + c * ThreadMap::Delta::kContiguous;
-+
-+      // The subsequent fast_divmod() operations are equivalent to the following logical computation:
-+      //
-+      //
-+      // filter_t_[c] = trsc_offset / (problem_size_.R * problem_size_.S * problem_size_.C);
-+      // int residual = trsc_offset % (problem_size_.R * problem_size_.S * problem_size_.C);
-+      //
-+      // filter_r_[c] = residual / (problem_size_.S * problem_size_.C);
-+      // residual = residual % (problem_size_.S * problem_size_.C);
-+      //
-+      // filter_s_[c] = residual / problem_size_.C;
-+      // filter_c_[c] = residual % problem_size_.C;
-+
-+      int residual;
-+      fast_divmod(precomputed_filter_t_[c], residual, trsc_offset, params_.RSC, params_.rsc_mul, params_.rsc_shr);
-+      fast_divmod(precomputed_filter_r_[c], residual, residual, params_.SC, params_.sc_mul, params_.sc_shr);
-+      fast_divmod(precomputed_filter_s_[c], filter_c_[c], residual, problem_size_.C, params_.c_mul, params_.c_shr);
-+
-+      int t = precomputed_filter_t_[c];
-+      int r = precomputed_filter_r_[c];
-+      int s = precomputed_filter_s_[c];
-+
-+      if (problem_size_.mode == Mode::kConvolution) {
-+        t = (problem_size_.T - 1 - t);
-+        r = (problem_size_.R - 1 - r);
-+        s = (problem_size_.S - 1 - s);
-+      }
-+
-+      // effective t,r,s for every contiguous dimension
-+      precomputed_filter_t_[c] = - problem_size_.pad_d + t * problem_size_.dilation_d;
-+      precomputed_filter_r_[c] = - problem_size_.pad_h + r * problem_size_.dilation_h;
-+      precomputed_filter_s_[c] = - problem_size_.pad_w + s * problem_size_.dilation_w;
-+
-+    }
-+
-+    // initialize n, z, p, q offset for every strided iteration
-+    CUTLASS_PRAGMA_UNROLL
-+    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
-+
-+      offset_nzpq_[s] = threadblock_offset.row() + thread_coord.strided()
-+        + s * ThreadMap::Delta::kStrided;
-+    }
-+  }
-+
-+  /// Overrides the internal iteration index
-+  CUTLASS_HOST_DEVICE
-+  void set_iteration_index(Index index) {
-+    iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
-+    iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
-+  }
-+
-+  /// Adds a pointer offset in units of Element
-+  CUTLASS_HOST_DEVICE
-+  void add_pointer_offset(LongIndex pointer_offset) {
-+    pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
-+  }
-+
-+  CUTLASS_HOST_DEVICE
-+  void advance() {
-+
-+    // moves to the next GEMM-K offset (offset_nzpq_) in GEMM-B by a CTA-K tile
-+    CUTLASS_PRAGMA_UNROLL
-+    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
-+      offset_nzpq_[s] += Shape::kRow * problem_size_.split_k_slices;
-+    }
-+  }
-+
-+  /// Returns the coordinate in the activation tensor x that is currently pointed to
-+  /// by the iterator.
-+
-+  CUTLASS_HOST_DEVICE
-+  TensorCoord at() const {
-+
-+    // The subsequent fast_divmod() operations are equivalent to the following logical computation:
-+    //
-+    //
-+    // int n = offset_nzpq_[iteration_strided_] / (problem_size_.Z * problem_size_.P * problem_size_.Q);
-+    // int residual = offset_nzpq_[iteration_strided_] % (problem_size_.Z * problem_size_.P * problem_size_.Q);
-+    //
-+    // int z = residual / (problem_size_.P * problem_size_.Q);
-+    // residual = residual % (problem_size_.P * problem_size_.Q);
-+    //
-+    // int p = residual / problem_size_.Q;
-+    // int q = residual % problem_size_.Q;
-+
-+    int residual, n, z, p, q;
-+    fast_divmod(n, residual, offset_nzpq_[iteration_strided_], params_.ZPQ, params_.zpq_mul, params_.zpq_shr);
-+    fast_divmod(z, residual, residual, params_.PQ, params_.pq_mul, params_.pq_shr);
-+    fast_divmod(p, q, residual, problem_size_.Q, params_.q_mul, params_.q_shr);
-+
-+    int d = z * problem_size_.stride_d + precomputed_filter_t_[iteration_contiguous_];
-+    int h = p * problem_size_.stride_h + precomputed_filter_r_[iteration_contiguous_];
-+    int w = q * problem_size_.stride_w + precomputed_filter_s_[iteration_contiguous_];
-+
-+    return TensorCoord(n, d, h, w, filter_c_[iteration_contiguous_]);
-+  }
-+
-+  /// Returns true if the current coordinate is within the activation tensor x
-+  CUTLASS_HOST_DEVICE
-+  bool valid() const {
-+    TensorCoord coord = at();
-+
-+    return coord.n() < problem_size_.N &&
-+      coord.d() >= 0 && coord.d() < problem_size_.D &&
-+      coord.h() >= 0 && coord.h() < problem_size_.H &&
-+      coord.w() >= 0 && coord.w() < problem_size_.W &&
-+      coord.c() < problem_size_.C;
-+  }
-+
-+  /// Returns a pointer to the vector starting at the current coordinate
-+  CUTLASS_DEVICE
-+  AccessType const *get() const {
-+
-+    TensorCoord coord = at();
-+    LongIndex offset = params_.layout(coord);
-+
-+    return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
-+  }
-+
-+  /// Increments to the next memory access
-+  CUTLASS_HOST_DEVICE
-+  Conv3dWgradActivationTileAccessIteratorOptimized &operator++() {
-+    ++iteration_contiguous_;
-+    if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
-+      return *this;
-+    }
-+    iteration_contiguous_ = 0;
-+    ++iteration_strided_;
-+    if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
-+      return *this;
-+    }
-+    iteration_strided_ = 0;
-+
-+    return *this;
-+  }
-+
-+  /// Determines whether the Implicit GEMM can execute the given problem.
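Editorial note: compared with the analytic iterator earlier in this patch, the optimized wgrad activation iterator above folds padding and dilation into precomputed_filter_{t,r,s}_ once in the constructor, so at() only needs a multiply-add per spatial dimension. A standalone sketch of that precomputation, added for this review (not part of the patch; function names are made up):

    // Editorial sketch: fold pad and dilation into a per-filter-position constant.
    int effective_offset(int filter_pos, int pad, int dilation) {
      return -pad + filter_pos * dilation;          // constant for a fixed filter position
    }

    int input_coord(int output_coord, int stride, int precomputed) {
      return output_coord * stride + precomputed;   // what at() evaluates per access
    }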
-+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv3dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % (128/sizeof_bits::value)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..c21d3f9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h -@@ -0,0 +1,267 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNDHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_ -+> -+class Conv3dWgradOutputGradientTileAccessIteratorAnalytic { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNDHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ static int const kAccessesPerVector = 1; -+ static_assert(sizeof_bits::value >= 8, -+ "WGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params { -+ -+ Layout layout; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Conv3dProblemSize const &problem_size, -+ Layout const &layout -+ ): layout(layout) { -+ -+ } -+ }; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv3dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ char const *pointer_; -+ -+ int filter_k_[ThreadMap::Iterations::kContiguous]; -+ -+ int offset_nzpq_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradOutputGradientTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ Conv3dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)) { -+ -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ // initialize filter_k for every contiguous iteration -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ filter_k_[c] = threadblock_offset.row() + thread_coord.contiguous() -+ + c * ThreadMap::Delta::kContiguous; -+ } -+ -+ // initialize n, p, q offset for every strided iteration -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_nzpq_[s] = threadblock_offset.column() + thread_coord.strided() -+ + s * ThreadMap::Delta::kStrided; -+ -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, layout); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; 
-+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next GEMM-K offset (offset_nzpq_) in GEMM-A by a CTA-K tile -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_nzpq_[s] += Shape::kColumn * problem_size_.split_k_slices; -+ } -+ } -+ -+ /// Returns the coordinate in the output gradient tensor Dy that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int nzpq = offset_nzpq_[iteration_strided_]; -+ -+ int n = nzpq / (problem_size_.Z * problem_size_.P * problem_size_.Q); -+ int residual = nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q); -+ -+ int z = residual / (problem_size_.P * problem_size_.Q); -+ residual = residual % (problem_size_.P * problem_size_.Q); -+ -+ int p = residual / problem_size_.Q; -+ int q = residual % problem_size_.Q; -+ -+ return TensorCoord(n, z, p, q, filter_k_[iteration_contiguous_]); -+ } -+ -+ -+ /// Returns true if the current coordinate is within the output gradient tensor Dy -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.N && -+ coord.d() < problem_size_.Z && -+ coord.h() < problem_size_.P && -+ coord.w() < problem_size_.Q && -+ coord.c() < problem_size_.K; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradOutputGradientTileAccessIteratorAnalytic &operator++() { -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. 
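  /// Editorial sketch (not from the upstream source): the alignment check below requires the
  /// iterator's contiguous dimension to support full 128-bit vector accesses, i.e. to be a
  /// multiple of 128 / sizeof_bits of the element type -- for example 8 elements for a
  /// 16-bit type or 4 elements for a 32-bit type.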
-+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv3dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % (128/sizeof_bits::value)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..7a79983 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h -@@ -0,0 +1,310 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNDHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/conv/threadblock/conv3d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_ -+> -+class Conv3dWgradOutputGradientTileAccessIteratorOptimized { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNDHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ static int const kAccessesPerVector = 1; -+ static_assert(sizeof_bits::value >= 8, -+ "WGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params : Conv3dWgradOutputGradientIteratorOptimizedParams { -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(Conv3dWgradOutputGradientIteratorOptimizedParams const &base) -+ : Conv3dWgradOutputGradientIteratorOptimizedParams(base) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(Conv3dProblemSize const &problem_size, Layout const &layout) -+ : Conv3dWgradOutputGradientIteratorOptimizedParams( -+ problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}) {} -+ }; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv3dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ char const *pointer_; -+ -+ uint32_t predicates_; -+ int filter_k_; -+ int offset_nzpq_; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradOutputGradientTileAccessIteratorOptimized( -+ Params const ¶ms, -+ Conv3dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ predicates_(0), -+ filter_k_(0), -+ offset_nzpq_(0) { -+ -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.row() + thread_coord.contiguous(); -+ offset_nzpq_ = threadblock_offset.column() + thread_coord.strided(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int 
filter_k = filter_k_ + c * ThreadMap::Delta::kContiguous; -+ int offset_nzpq = offset_nzpq_ + s * ThreadMap::Delta::kStrided; -+ -+ bool predicate = valid_(at_(offset_nzpq, filter_k)); -+ -+ uint32_t pred = (predicate ? 1u : 0); -+ -+ int pred_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ predicates_ |= (pred << pred_idx); -+ } -+ } -+ -+ // Offset pointer to (iteration_strided_, iteration_contiguous_) = (0, 0) -+ pointer_ += ( -+ offset_nzpq_ * params.layout.stride()[0] + filter_k_ -+ ) * sizeof_bits::value / 8; -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, layout); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next GEMM-K offset (offset_npq_) in GEMM-A by a CTA-K tile -+ offset_nzpq_ += Shape::kColumn * problem_size_.split_k_slices; -+ -+ // Clear predicates if needed -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ if (offset_nzpq_ + s * ThreadMap::Delta::kStrided >= params_.NZPQ) { -+ uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); -+ predicates_ = (predicates_ & (~kClearMask)); -+ } -+ } -+ pointer_ += params_.inc_next_nzpq; -+ } -+ -+private: -+ /// Returns the coordinate in the output gradient tensor Dy that is (offset_nzpq, k) pointed to -+ /// by the iterator. 
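  /// Editorial sketch (not from the upstream source): for intuition, the linear offset
  /// decomposes as offset_nzpq = ((n * Z + z) * P + p) * Q + q. Assuming Z = 2, P = 4, Q = 8
  /// (so ZPQ = 64 and PQ = 32) and offset_nzpq = 137:
  ///   n = 137 / 64 = 2, residual = 137 % 64 = 9; z = 9 / 32 = 0; p = 9 / 8 = 1; q = 9 % 8 = 1.
  /// The fast_divmod() calls below compute exactly this decomposition using precomputed
  /// multipliers and shifts instead of integer division.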
-+ CUTLASS_HOST_DEVICE -+ TensorCoord at_(int offset_nzpq, int k) const { -+ -+ // The subseqnet fast_divmod() operations are equivalent to the following logical computation: -+ // -+ // -+ // int nzpq = offset_nzpq_; -+ // int n = nzpq / (problem_size_.Z * problem_size_.P * problem_size_.Q); -+ // int residual = nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q); -+ // -+ // int z = residual / (problem_size_.P * problem_size_.Q); -+ // residual = residual % (problem_size_.P * problem_size_.Q); -+ // -+ // int p = residual / problem_size_.Q; -+ // int q = residual % problem_size_.Q; -+ -+ int residual, n, z, p, q; -+ fast_divmod(n, residual, offset_nzpq, params_.ZPQ, params_.zpq_mul, params_.zpq_shr); -+ fast_divmod(z, residual, residual, params_.PQ, params_.pq_mul, params_.pq_shr); -+ fast_divmod(p, q, residual, problem_size_.Q, params_.q_mul, params_.q_shr); -+ -+ return TensorCoord(n, z, p, q, k); -+ } -+ -+ /// Returns true if the coord is within the output gradient tensor Dy -+ CUTLASS_HOST_DEVICE -+ bool valid_(TensorCoord coord) const { -+ -+ return coord.n() < problem_size_.N && -+ coord.c() < problem_size_.K; -+ } -+ -+public: -+ -+ /// Returns true if the current coordinate is within the output gradient tensor Dy -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; -+ return (predicates_ & (1u << pred_idx)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ return reinterpret_cast( -+ pointer_ + -+ iteration_strided_ * params_.offset_next_strided + -+ iteration_contiguous_ * params_.offset_next_contiguous -+ ); -+ -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradOutputGradientTileAccessIteratorOptimized &operator++() { -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv3dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % (128/sizeof_bits::value)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_direct_conv_params.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_direct_conv_params.h -new file mode 100644 -index 0000000..86b5bc4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_direct_conv_params.h -@@ -0,0 +1,230 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Extracts the host-params objects into non-template code. 
-+*/ -+ -+#pragma once -+ -+#define TRACE_CONV_PARAMS_INITIALIZERS_ENABLED 0 -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED -+#include -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized -+template -+struct Depthwise2dFpropDirectConvParams; -+ -+/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation -+template -+struct Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams; -+ -+/// Parameters structure used for DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized -+template -+struct Depthwise2dFpropDirectConvFilterIteratorParams; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized -+template<> -+struct Depthwise2dFpropDirectConvParams { -+ -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ -+ int32_t activation_tile_h; -+ int32_t activation_tile_w; -+ int32_t activation_tile_hw; -+ FastDivmod activation_tile_w_divmod; -+ -+ int filter[2]; -+ int stride[2]; -+ int dilation[2]; -+ int inc_next[2]; -+ FastDivmod pq_divmod; -+ FastDivmod q_divmod; -+ -+ int activation_load_count; -+ int activation_storage_elements; -+ int activation_size; -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Depthwise2dFpropDirectConvParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Depthwise2dFpropDirectConvParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, ///< layout object -+ MatrixCoord threadblock_shape, ///< CTA threadblock Shape -+ Layout::TensorCoord threadblock_output_shape, ///< Output tile Shape per threadblock -+ const int element_size_bits, ///< bits of activation element -+ const int thread_count, ///< threads per threadblock -+ const int thread_count_contiguous, ///< number of threads for continuous dimension -+ const int element_per_load) ///< element per each load -+ : layout(layout) { -+ -+ filter[0] = problem_size.S; -+ filter[1] = problem_size.R; -+ -+ stride[0] = problem_size.stride_w; -+ stride[1] = problem_size.stride_h; -+ -+ dilation[0] = problem_size.dilation_w; -+ dilation[1] = problem_size.dilation_h; -+ -+ // Compute activation_tile size per threadblock because stride and dilation are runtime params. 
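    // Editorial sketch (not part of the original patch): the formulas below compute the
    // effective receptive field of one output tile. For example, with an 8x8 output tile per
    // threadblock, stride 1, dilation 1 and a 3x3 filter (R = S = 3):
    //   activation_tile_h = (8 - 1) * 1 + (3 - 1) * 1 + 1 = 10
    //   activation_tile_w = (8 - 1) * 1 + (3 - 1) * 1 + 1 = 10
    // so each threadblock stages a 10x10 activation patch per channel group.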
-+ activation_tile_h = (threadblock_output_shape.h() - 1) * problem_size.stride_h + -+ (problem_size.R - 1) * problem_size.dilation_h + 1; -+ activation_tile_w = (threadblock_output_shape.w() - 1) * problem_size.stride_w + -+ (problem_size.S - 1) * problem_size.dilation_w + 1; -+ activation_tile_hw = activation_tile_h * activation_tile_w; -+ -+ activation_tile_w_divmod = FastDivmod(activation_tile_w); -+ -+ /// Below two values could not be templatized because the stride and dilation are runtime params -+ activation_load_count = (thread_count_contiguous * activation_tile_hw + (thread_count - 1)) / thread_count; -+ activation_storage_elements = activation_load_count * element_per_load * thread_count; -+ activation_size = activation_storage_elements * element_size_bits / 8; -+ -+ // Fastdivmod for output P, Q -+ int tiles_p = -+ (problem_size.P + (threadblock_output_shape.h() - 1)) / (threadblock_output_shape.h()); -+ int tiles_q = (problem_size.Q + (threadblock_output_shape.w() - 1)) / -+ (threadblock_output_shape.w()); -+ -+ pq_divmod = FastDivmod(tiles_p * tiles_q); -+ q_divmod = FastDivmod(tiles_q); -+ -+ // next S -+ inc_next[0] = problem_size.dilation_w; -+ // next R -+ inc_next[1] = (activation_tile_w * problem_size.dilation_h - (problem_size.S - 1) * problem_size.dilation_w); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation -+template <> -+struct Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams { -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ -+ FastDivmod pq_divmod; -+ FastDivmod q_divmod; -+ -+ int activation_size; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams() {} -+ -+ CUTLASS_HOST_DEVICE -+ Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, ///< Layout object -+ MatrixCoord threadblock_shape, ///< Threadblock Shape -+ Layout::TensorCoord threadblock_output_shape, ///< Output tile Shape per threadblock -+ const int activation_size_ ///< Activation size loaded by iterator -+ ) -+ : layout(layout), -+ activation_size(activation_size_) { -+ // Fastdivmod for output P, Q -+ int tiles_p = -+ (problem_size.P + (threadblock_output_shape.h() - 1)) / (threadblock_output_shape.h()); -+ int tiles_q = -+ (problem_size.Q + (threadblock_output_shape.w() - 1)) / (threadblock_output_shape.w()); -+ -+ pq_divmod = FastDivmod(tiles_p * tiles_q); -+ q_divmod = FastDivmod(tiles_q); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters structure used for DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized -+template <> -+struct Depthwise2dFpropDirectConvFilterIteratorParams { -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ -+ int filter_size; -+ -+ bool is_convolution; -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Depthwise2dFpropDirectConvFilterIteratorParams() {} -+ -+ CUTLASS_HOST_DEVICE -+ Depthwise2dFpropDirectConvFilterIteratorParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, ///< Layout object -+ MatrixCoord threadblock_shape, ///< Threadblock Shape -+ const int filter_size_) ///< Filter size loaded by iterator -+ : layout(layout), -+ filter_size(filter_size_), -+ 
is_convolution(problem_size.mode == Mode::kConvolution){} -+}; -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h -new file mode 100644 -index 0000000..80ec5d0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h -@@ -0,0 +1,314 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. 
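    Editorial note (not from the upstream source): in this variant the stride, dilation and
    the resulting activation footprint are compile-time template parameters (StrideShape,
    DilationShape, ActivationShape), so the per-access address arithmetic is fixed at
    compile time; can_implement() rejects problems whose runtime stride or dilation do not
    match these template values.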
-+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/threadblock/depthwise_direct_conv_params.h" -+#include "cutlass/coord.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template > -+class DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation { -+ public: -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using OutputTileShape = OutputTileShape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ // Compilation value of stride , dialtion and activation shape -+ using StrideShape = StrideShape_; -+ using DilationShape = DilationShape_; -+ using ActivationShape = ActivationShape_; -+ -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ static int const kActivationSize = ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess * ThreadMap::kThreads * -+ sizeof_bits::value / 8; -+ -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1"); -+ -+ static_assert(OutputTileShape::kN == 1, "Require OutputTileShape::kN == 1"); -+ static_assert(OutputTileShape::kC == Shape::kColumn, "Require OutputTile shape == channels per threadblock"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams; -+ -+ private: -+ Conv2dProblemSize const &problem_size_; -+ Params const ¶ms_; -+ char const *pointer_; -+ -+ // Base channels for current threadblock -+ int base_c_; -+ // Base activation index for current threadblock -+ int offset_intial_npq_; -+ // Base activation coord for current threadblock -+ TensorCoord activatioin_base_; -+ // Intial thread positioin -+ int offset_initial_hwc_; -+ // Overall load instruction per thread. -+ int iterator_load_; -+ // thread loading position. 
-+ int iterator_hwc_; -+ // activation N is inside the Tensor or not -+ bool valid_n_; -+ -+ public: -+ -+ -+ CUTLASS_HOST_DEVICE -+ DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = -+ MatrixCoord() -+ ) -+ : params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ offset_intial_npq_(threadblock_offset.row()), -+ offset_initial_hwc_(thread_idx), -+ iterator_load_(0) { -+ -+ base_c_ = threadblock_offset.column(); -+ -+ set_iteration_index(0); -+ -+ set_activation_coord(offset_intial_npq_); -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_activation_coord(int offset_npq) { -+ int offset_inital_n, offset_inital_p, offset_inital_q; -+ int residual; -+ -+ params_.pq_divmod(offset_inital_n, residual, offset_npq); -+ params_.q_divmod(offset_inital_p, offset_inital_q, residual); -+ -+ int base_n = offset_inital_n; -+ -+ int base_h = -+ offset_inital_p * OutputTileShape::kH * StrideShape::kRow - problem_size_.pad_h; -+ -+ int base_w = -+ offset_inital_q * OutputTileShape::kW * StrideShape::kColumn - problem_size_.pad_w; -+ -+ activatioin_base_ = TensorCoord(base_n, base_h, base_w, base_c_); -+ -+ valid_n_ = activatioin_base_.n() < problem_size_.N; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params( -+ problem_size, -+ layout, -+ {Shape::kRow, Shape::kColumn}, -+ {OutputTileShape::kN, OutputTileShape::kH, OutputTileShape::kW, OutputTileShape::kC}, -+ kActivationSize); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iterator_hwc_ = offset_initial_hwc_ + index * ThreadMap::kThreads; -+ iterator_load_ = index; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // Go to next threadblock -+ offset_intial_npq_ += problem_size_.split_k_slices; -+ -+ set_iteration_index(0); -+ -+ set_activation_coord(offset_intial_npq_); -+ } -+ -+ /// Returns the coordinate in the activations tensor X that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ int c = iterator_hwc_ % ThreadMap::Detail::ShapeVec::kContiguous ; -+ int next = iterator_hwc_ / ThreadMap::Detail::ShapeVec::kContiguous ; -+ int h = next / ActivationShape::kW; -+ int w = next % ActivationShape::kW; -+ -+ c = c * AccessType::kElements; -+ -+ return activatioin_base_ + TensorCoord(0, h, w, c); -+ } -+ -+ /// Returns true if the current coordinate is within the activations tensor X -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ TensorCoord coord = at(); -+ bool valid_c = coord.c() < problem_size_.C; -+ bool valid_h = coord.h() >= 0 && coord.h() < problem_size_.H; -+ bool valid_w = coord.w() >= 0 && coord.w() < problem_size_.W; -+ return valid_n_ ? 
valid_c & valid_h & valid_w : 0; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ AccessType const *ptr = -+ reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ -+ return ptr; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation &operator++() { -+ -+ ++iterator_load_; -+ iterator_hwc_ += ThreadMap::kThreads; -+ -+ if (iterator_load_ < ThreadMap::Iterations::kCount) { -+ return *this; -+ } -+ -+ iterator_load_ = 0; -+ iterator_hwc_ = offset_initial_hwc_; -+ -+ return *this; -+ } -+ -+ /// Determines the activation size loaded by iterator -+ CUTLASS_HOST_DEVICE -+ int get_load_size() { -+ return kActivationSize; -+ } -+ -+ /// Determines the iterations needed -+ CUTLASS_HOST_DEVICE -+ int get_iteration_num() { -+ return ThreadMap::Iterations::kCount; -+ } -+ -+ /// Determines whether the Depthwise fprop can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check stride and dilation constraint -+ if (problem_size.stride_h != StrideShape::kRow || problem_size.stride_w != StrideShape::kColumn) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (problem_size.dilation_h != DilationShape::kRow || problem_size.dilation_w != DilationShape::kColumn) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h -new file mode 100644 -index 0000000..3439d46 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h -@@ -0,0 +1,291 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/threadblock/depthwise_direct_conv_params.h" -+#include "cutlass/coord.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template > -+class DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized { -+ public: -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using OutputTileShape = OutputTileShape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1"); -+ -+ static_assert(OutputTileShape::kN == 1, "Require OutputTileShape::kN == 1"); -+ static_assert(OutputTileShape::kC == Shape::kColumn, "Require OutputTile shape == channels per threadblock"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Depthwise2dFpropDirectConvParams; -+ -+ private: -+ Conv2dProblemSize const &problem_size_; -+ Params const ¶ms_; -+ char const *pointer_; -+ -+ // Base channels for current threadblock -+ 
int base_c_; -+ // Base activation index for current threadblock -+ int offset_intial_npq_; -+ // Base activation coord for current threadblock -+ TensorCoord activatioin_base_; -+ // Intial thread positioin -+ int offset_initial_hwc_; -+ // Overall load instruction per thread. -+ int iterator_load_; -+ // thread loading position. -+ int iterator_hwc_; -+ // Number of loads for activations tensor X. -+ const int number_of_loads_; -+ -+ public: -+ -+ -+ CUTLASS_HOST_DEVICE -+ DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = -+ MatrixCoord() -+ ) -+ : params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ offset_intial_npq_(threadblock_offset.row()), -+ offset_initial_hwc_(thread_idx), -+ iterator_load_(0), -+ number_of_loads_(params.activation_load_count) { -+ -+ base_c_ = threadblock_offset.column(); -+ -+ set_activation_coord(offset_intial_npq_); -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_activation_coord(int offset_npq) { -+ int offset_inital_n, offset_inital_p, offset_inital_q; -+ int residual; -+ -+ params_.pq_divmod(offset_inital_n, residual, offset_npq); -+ params_.q_divmod(offset_inital_p, offset_inital_q, residual); -+ -+ int base_n = offset_inital_n; -+ -+ int base_h = -+ offset_inital_p * OutputTileShape::kH * problem_size_.stride_h - problem_size_.pad_h; -+ -+ int base_w = -+ offset_inital_q * OutputTileShape::kW * problem_size_.stride_w - problem_size_.pad_w; -+ -+ activatioin_base_ = TensorCoord(base_n, base_h, base_w, base_c_); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params( -+ problem_size, -+ layout, -+ {Shape::kRow, Shape::kColumn}, -+ {OutputTileShape::kN, OutputTileShape::kH, OutputTileShape::kW, OutputTileShape::kC}, -+ sizeof_bits::value, -+ ThreadMap::kThreads, -+ ThreadMap::Detail::ShapeVec::kContiguous, -+ ThreadMap::kElementsPerAccess); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iterator_hwc_ = offset_initial_hwc_ + index * ThreadMap::kThreads; -+ iterator_load_ = index; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // Go to next threadblock -+ offset_intial_npq_ += problem_size_.split_k_slices; -+ -+ set_activation_coord(offset_intial_npq_); -+ } -+ -+ /// Returns the coordinate in the activations tensor X that is currently pointed to -+ /// by the iterator. 
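  /// Editorial sketch (not from the upstream source): iterator_hwc_ is a flat index over the
  /// (h, w, channel-vector) positions of the staged activation tile. Assuming
  /// ShapeVec::kContiguous = 8 vectors per row and activation_tile_w = 10, an index of 83
  /// decomposes as c = 83 % 8 = 3 (third vector), next = 83 / 8 = 10, and
  /// activation_tile_w_divmod then gives h = 10 / 10 = 1, w = 10 % 10 = 0; c is finally
  /// scaled by AccessType::kElements to a channel offset.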
-+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int c = iterator_hwc_ % ThreadMap::Detail::ShapeVec::kContiguous ; -+ int next = iterator_hwc_ / ThreadMap::Detail::ShapeVec::kContiguous ; -+ int h, w; -+ params_.activation_tile_w_divmod(h, w, next) ; -+ -+ c = c * AccessType::kElements; -+ -+ return activatioin_base_ + TensorCoord(0, h, w, c); -+ } -+ -+ /// Returns true if the current coordinate is within the activations tensor X -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.N && coord.h() >= 0 && coord.h() < problem_size_.H && -+ coord.w() >= 0 && coord.w() < problem_size_.W && coord.c() < problem_size_.C; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ AccessType const *ptr = -+ reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ -+ return ptr; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized &operator++() { -+ -+ ++iterator_load_; -+ iterator_hwc_ += ThreadMap::kThreads; -+ -+ if (iterator_load_ < number_of_loads_) { -+ return *this; -+ } -+ -+ iterator_load_ = 0; -+ iterator_hwc_ = offset_initial_hwc_; -+ -+ return *this; -+ } -+ -+ /// Determines the activation size loaded by iterator -+ CUTLASS_HOST_DEVICE -+ int get_load_size() { -+ return params_.activation_size; -+ } -+ -+ /// Determines the iterations needed -+ CUTLASS_HOST_DEVICE -+ int get_iteration_num() { -+ return number_of_loads_; -+ } -+ -+ /// Determines whether the Depthwise fprop can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h -new file mode 100644 -index 0000000..26bbe57 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h -@@ -0,0 +1,551 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/cache_operation.h" -+#include "cutlass/conv/threadblock/depthwise_mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. 
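/// Editorial sketch (not part of the original patch): this mainloop keeps `Stages`
/// shared-memory buffers for the activation tile (operand A) and fills them with cp.async.
/// The prologue issues kStages - 1 stages up front; the filter tile (operand B) is copied
/// only once, since it is reused for every CTA tile. The mainloop then overlaps
/// global-to-shared copies for the next stage with warp-level MMAs on the current stage,
/// and runs the per-tile epilogue inside the loop because each iteration produces an
/// independent output tile.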
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Epilogue stores the data into global memory -+ typename Epilogue_, -+ /// iterator implementation variants -+ conv::IteratorAlgorithm IteratorAlgorithm_ = conv::IteratorAlgorithm::kOptimized, -+ /// Used for partial specialization -+ typename Enable = bool> -+class DepthwiseFpropDirectConvMultipleStage : -+ public DepthwiseDirectConvMmaBase { -+public: -+ ///< Base class -+ using Base = DepthwiseDirectConvMmaBase; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ using Epilogue = Epilogue_; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ static conv::IteratorAlgorithm const kItertorAlgorithm = IteratorAlgorithm_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ -+ using ElementC = typename Policy::Operator::ElementC; -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Internal structure exposed for introspection. 
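  /// Editorial sketch (not from the upstream source): kAccessesPerGroupB below is a ceiling
  /// division that spreads the cp.async copies of operand B across the warp-level GEMM
  /// iterations of one stage. For example, with AsyncCopyIterationsPerStageB = 3 and
  /// kWarpGemmIterations = 2: kAccessesPerGroupB = (3 + 2 - 1) / 2 = 2, i.e. at most two
  /// copies are issued per warp-level MMA step.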
-+ struct Detail { -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ DepthwiseFpropDirectConvMultipleStage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance(IteratorA &iterator_A, -+ IteratorB &iterator_B, -+ int group_start_A = 0, -+ int group_start_B = 0) { -+ if (kItertorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation) { -+ // Number of iterators is a static value. 
-+ iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast(this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ ++this->smem_iterator_A_; -+ } -+ } else { -+ // Number of iterators is a runtime value. -+ iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < iterator_A.get_iteration_num(); ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast(this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ ++this->smem_iterator_A_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA &iterator_A, -+ ///< Params of global memory iterator -+ typename IteratorA::Params const &iterator_a_params, -+ ///< iterator over B operand in global memory -+ IteratorB &iterator_B, -+ ///< Params of global memory iterator -+ typename IteratorB::Params const &iterator_b_params, -+ ///< initial value of accumulator -+ FragmentC const &src_accum, -+ /// Epilogue -+ Epilogue &epilogue, -+ ///< Output operator -+ typename Epilogue::OutputOp const &output_op, -+ ///< Tile iterator for destination -+ typename Epilogue::OutputTileIterator &destination_iterator, -+ ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ typename Epilogue::OutputTileIterator &source_iterator, -+ -+ int split_k_slices = 1 -+ ) { -+ -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { -+ -+ if (stage == 0) { -+ iterator_B.set_iteration_index(0); -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast(this->smem_iterator_B_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ } -+ -+ ++this->smem_iterator_B_; -+ } -+ } -+ -+ if(kItertorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation){ -+ // Number of iterators is 
compilation static. -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast(this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ } else { -+ // Number of iterators is a runtime value. -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_num(iterator_A.get_iteration_num()); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < iterator_A.get_iteration_num(); ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast(this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ // Move to the next stage -+ iterator_A.advance(); -+ -+ this->smem_iterator_A_.add_tile_offset({1, 0}); -+ -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ ///////////////////////////////////////////////////////////////////////////// -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B[2]; -+ -+ Operator warp_mma; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.setup_initial_status(iterator_a_params); -+ -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B[0]); -+ -+ // -+ // Mainloop -+ // -+ -+ unsigned int iterations = 0; -+ constexpr int inner_loop_iterations = round_up(Base::kWarpGemmIterations, 2); -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { // Each iteration is a cta tile. 
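      // Editorial sketch (not part of the original patch): the inner loop below
      // double-buffers warp fragments, indexing them with (warp_mma_k % 2) so the loads for
      // step k+1 overlap the math of step k. inner_loop_iterations = round_up(kWarpGemmIterations, 2),
      // so e.g. kWarpGemmIterations = 3 gives 4 inner steps; the extra step performs only
      // fragment loads/transforms, not an MMA. cp.async copies for the next stage are issued
      // at warp_mma_k == 0 and fenced when warp_mma_k + 2 == inner_loop_iterations.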
-+ -+ accum.clear(); -+ -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < inner_loop_iterations; ++warp_mma_k) { -+ if (Base::kWarpGemmIterations % 2 == 0 || warp_mma_k + 1 != Base::kWarpGemmIterations) { -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Shape::kK); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Shape::kK); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ } -+ -+ if (warp_mma_k > 0) -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B[warp_mma_k % 2]); -+ -+ // Issue global->shared copies for the next stage -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ if (warp_mma_k == 0) { -+ group_start_iteration_A = 0; -+ group_start_iteration_B = 0; -+ copy_tiles_and_advance( -+ iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); -+ } -+ -+ if (warp_mma_k < Base::kWarpGemmIterations) { -+ warp_mma( -+ accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ accum -+ ); -+ } -+ -+ if (warp_mma_k + 1 == inner_loop_iterations) -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ if (warp_mma_k + 2 == inner_loop_iterations) { -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages of cp.async have committed -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next cta -+ iterator_A.advance(); -+ -+ this->smem_iterator_A_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({-Base::kStages, 0}); -+ -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.advance(- (Base::kStages-1) * iterator_A.get_load_size()); -+ smem_read_stage_idx = 0; -+ } else { -+ this->warp_tile_iterator_A_.advance(iterator_A.get_load_size()); -+ ++smem_read_stage_idx; -+ } -+ -+ if (kItertorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation) { -+ this->warp_tile_iterator_A_.setup_initial_status(iterator_a_params); -+ } -+ -+ // goback to start position. 
B has no multiple stage -+ this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Shape::kK, 0}); -+ -+ --gemm_k_iterations; -+ } -+ } -+ -+ // -+ // Epilogue -+ // -+ int32_t smem_base_offset = iterator_B.get_load_size() + (iterations % Base::kStages) * iterator_A.get_load_size(); -+ -+ destination_iterator.set_tile_index(iterations * split_k_slices); -+ -+ source_iterator.set_tile_index(iterations * split_k_slices); -+ -+ epilogue(output_op, destination_iterator, accum, source_iterator, smem_base_offset); -+ -+ ++iterations; -+ } -+ -+ // Insert fence and wait for all outstanding cp.async operations to commit. -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h -new file mode 100644 -index 0000000..e9153c9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h -@@ -0,0 +1,261 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) -+ matrix from memory. 
-+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+template > -+class DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized { -+public: -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static int const kFilterSize = ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess * ThreadMap::kThreads * -+ sizeof_bits::value / 8; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ using Params = Depthwise2dFpropDirectConvFilterIteratorParams; -+ -+ protected: -+ -+ Conv2dProblemSize const &problem_size_; -+ Params const ¶ms_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ int filter_k_; -+ int offset_trs_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ -+ -+ CUTLASS_HOST_DEVICE -+ DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ filter_k_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_trs_[s] = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, layout, {Shape::kRow, Shape::kColumn}, kFilterSize); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void 
set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * 8 / sizeof_bits::value; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // Do nothing because the filter is persistent in the SMEM -+ } -+ -+ /// Returns the coordinate in the filter tensor W that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int k = filter_k_ + iteration_vector_ * AccessType::kElements; -+ int trs = offset_trs_[iteration_strided_]; -+ -+ return TensorCoord(k, trs, 0 , 0); // As a 2D-matrix -+ } -+ -+ /// Returns true if the current coordinate is within the activations tensor W -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.K && -+ coord.h() < Shape::kColumn; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ TensorCoord coord = at(); -+ int64_t offset = coord.n(); -+ if (params_.is_convolution) { -+ offset += (Shape::kColumn - coord.h() - 1)* problem_size_.K; -+ } else { -+ offset += coord.h() * problem_size_.K; -+ } -+ -+ return reinterpret_cast(pointer_ + -+ offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines the filter size loaded by iterator -+ CUTLASS_HOST_DEVICE -+ int get_load_size() { -+ return kFilterSize; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // check whether runtime filter size is same as templated filter size. 
-+ if ((problem_size.R * problem_size.S) != Shape::kColumn) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h -new file mode 100644 -index 0000000..fd43e40 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h -@@ -0,0 +1,336 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Transformation applied to A operand -+ typename TransformA_ = NumericArrayConverter< -+ typename SmemIteratorA_::Element, -+ typename IteratorA_::Element, -+ IteratorA_::Fragment::kElements>, -+ /// -+ /// Transformation applied to A operand -+ typename TransformB_ = NumericArrayConverter< -+ typename SmemIteratorB_::Element, -+ typename IteratorB_::Element, -+ IteratorB_::Fragment::kElements>, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class DepthwiseFpropPipelined : public gemm::threadblock::MmaBase { -+public: -+ -+ ///< Base class -+ using Base = gemm::threadblock::MmaBase; -+ -+ using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory -+ using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory -+ using ElementC = ElementC_; ///< Data type of accumulator matrix -+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix -+ using Policy = Policy_; ///< Policy describing tuning details -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of operand A loaded from global memory -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Obtain the arch tag from the warp-level operator -+ using ArchTag = typename Policy::Operator::ArchTag; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) -+ static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); -+ -+private: -+ -+ using WarpFragmentA = typename Operator::FragmentA; -+ using WarpFragmentB = typename Operator::FragmentB; -+ -+protected: -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ 
DepthwiseFpropPipelined( -+ typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ int thread_idx, ///< ID within the threadblock -+ int warp_idx, ///< ID of warp -+ int lane_idx ///< ID of each thread within a warp -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ int gemm_k_iterations, ///< number of iterations of the mainloop -+ FragmentC &accum, ///< destination accumulator tile -+ IteratorA iterator_A, ///< iterator over A operand in global memory -+ IteratorB iterator_B, ///< iterator over B operand in global memory -+ FragmentC const &src_accum, ///< source accumulator tile -+ int gemm_k_iterations_per_channel = 0, ///< number of iterations per channel -+ TransformA transform_A = TransformA(), ///< transformation applied to A fragment -+ TransformB transform_B = TransformB()) { ///< transformation applied to B fragment -+ -+ // -+ // Prologue -+ // -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ FragmentA tb_frag_A; -+ FragmentB tb_frag_B; -+ -+ tb_frag_A.clear(); -+ tb_frag_B.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_A.load(tb_frag_A); -+ iterator_B.load(tb_frag_B); -+ -+ ++iterator_A; -+ ++iterator_B; -+ -+ this->smem_iterator_A_.store(transform_A(tb_frag_A)); -+ this->smem_iterator_B_.store(transform_B(tb_frag_B)); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA warp_frag_A[2]; -+ WarpFragmentB warp_frag_B[2]; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ Operator warp_mma; -+ -+ int smem_write_stage_idx = 1; -+ // Depthwise specific -+ int channel_start_index = 0; -+ int rs_plane_idx = 0; -+ -+ // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing -+ // shared memory loads (which have the tighest latency requirement). -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. 
-+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > 0; --gemm_k_iterations) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ if(rs_plane_idx == gemm_k_iterations_per_channel - 1){ -+ // Reset interation index. -+ iterator_B.set_iteration_index(0); -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_A_.store(transform_A(tb_frag_A)); -+ -+ this->smem_iterator_B_.store(transform_B(tb_frag_B)); -+ -+ __syncthreads(); -+ -+ if(rs_plane_idx == gemm_k_iterations_per_channel - 1){ -+ // Move to next set of filter groups. -+ channel_start_index += Base::kWarpGemmIterations; -+ } -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ } -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(channel_start_index + (warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index(channel_start_index + (warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k == 0) { -+ -+ iterator_A.load(tb_frag_A); -+ iterator_B.load(tb_frag_B); -+ -+ ++iterator_A; -+ ++iterator_B; -+ } -+ -+ warp_mma(accum, warp_frag_A[warp_mma_k % 2], -+ warp_frag_B[warp_mma_k % 2], accum); -+ } -+ -+ rs_plane_idx = (rs_plane_idx == gemm_k_iterations_per_channel - 1) ? 0: (rs_plane_idx + 1); -+ -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_mma_base.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_mma_base.h -new file mode 100644 -index 0000000..e839b9a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_mma_base.h -@@ -0,0 +1,229 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a directconv threadblock-scoped Depthwise kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Policy object describing MmaTensorOp -+template < -+ /// Warp-level GEMM operator (concept: gemm::warp::Mma) -+ typename Operator_, -+ /// Padding used for A operand in shared memory (concept: MatrixShape) -+ typename SmemPaddingA_, -+ /// Padding used for B operand in shared memory (concept: MatrixShape) -+ typename SmemPaddingB_, -+ /// -+ typename ThreadMapA_, -+ /// -+ typename ThreadMapB_, -+ /// Number of partitions of K dimension of GEMM -+ int PartitionsK = 1> -+struct DepthwiseDirectConvMmaPolicy { -+ /// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt) -+ using Operator = Operator_; -+ -+ /// Padding used for A operand in shared memory -+ using SmemPaddingA = SmemPaddingA_; -+ -+ /// Padding used for B operand in shared memory -+ using SmemPaddingB = SmemPaddingB_; -+ -+ using ThreadMapA = ThreadMapA_; -+ using ThreadMapB = ThreadMapB_; -+ -+ /// Number of partitions of K dimension -+ static int const kPartitionsK = PartitionsK; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. 
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class DepthwiseDirectConvMmaBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. -+ using WarpGemm = typename Policy::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount = cutlass::gemm:: -+ GemmShape; -+ -+ /// Number of warp-level GEMM oeprations -+ /// kWarpGemmIterations could be even and odd. -+ static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Tensor reference to the A operand -+ using TensorRefA = TensorRef; -+ -+ /// Tensor reference to the B operand -+ using TensorRefB = TensorRef; -+ -+ static_assert(kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ class SharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ -+ /// Shape of the A matrix operand in shared memory -+ using ShapeA = MatrixShape<1, // Not determined at compile-time :( -+ Shape::kN + Policy::SmemPaddingA::kRow>; -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = MatrixShape; // Tile N = 64? -+ -+ public: -+ // -+ // Data members -+ // -+ -+ // Let persistent B matrix in front of dynamic matrix A -+ /// Buffer for B operand -+ AlignedBuffer operand_B; -+ -+ /// Buffer for A operand -+ /// Not be determined at compile-time -- Just to get a Smem start address. 
-+ AlignedBuffer operand_A; -+ public: -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the A matrix -+ CUTLASS_DEVICE -+ static typename Operator::LayoutA LayoutA() { -+ return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); -+ } -+ -+ /// Returns a layout object for the B matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutB LayoutB() { -+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the A operand -+ CUTLASS_HOST_DEVICE -+ TensorRefA operand_A_ref() { return TensorRefA{operand_A.data(), LayoutA()}; } -+ -+ /// Returns a TensorRef to the B operand -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B_ref() { return TensorRefB{operand_B.data(), LayoutB()}; } -+ }; -+ -+ protected: -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A operand from shared memory -+ typename Operator::IteratorA warp_tile_iterator_A_; -+ -+ /// Iterator to load a warp-scoped tile of B operand from shared memory -+ typename Operator::IteratorB warp_tile_iterator_B_; -+ -+ public: -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ DepthwiseDirectConvMmaBase( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), -+ warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h -new file mode 100644 -index 0000000..dadd2b4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h -@@ -0,0 +1,952 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data -+ layout of the global memory fragments, data types, and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting depthwise related simt instructions. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/warp/mma_depthwise_simt.h" -+ -+#include "cutlass/gemm/threadblock/mma_pipelined.h" -+#include "cutlass/gemm/threadblock/mma_singlestage.h" -+ -+#include "cutlass/gemm/threadblock/mma_base.h" -+#include "cutlass/conv/threadblock/depthwise_mma_base.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h" -+ -+#include "cutlass/arch/cache_operation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+namespace detail { -+// -+// Convert a WarpShapeM which is the whole tile of elements into the number of elements (2D) held by -+// each partitions within warp. -+// The goal is for each thread's tile of elements to be as square as -+// possible for performance (4x4 will be faster than 2x8). -+template // The number of partitions within the warp -+struct SimtWarpShape { -+ // kP * kQ * WarpNumThreadsM = WarpShapeM -+ // If needed, enable more specializations. 
-+}; -+template <> -+struct SimtWarpShape<4, 4> { -+ static constexpr int kP = 1; -+ static constexpr int kQ = 1; -+}; -+ -+template <> -+struct SimtWarpShape<4, 2> { -+ static constexpr int kP = 2; -+ static constexpr int kQ = 1; -+}; -+ -+template <> -+struct SimtWarpShape<4, 1> { -+ static constexpr int kP = 2; -+ static constexpr int kQ = 2; -+}; -+ -+template <> -+struct SimtWarpShape<8, 1> { -+ static constexpr int kP = 2; -+ static constexpr int kQ = 4; -+}; -+template <> -+struct SimtWarpShape<8, 2> { -+ static constexpr int kP = 2; -+ static constexpr int kQ = 2; -+}; -+template <> -+struct SimtWarpShape<8, 4> { -+ static constexpr int kP = 1; -+ static constexpr int kQ = 2; -+}; -+ -+template <> -+struct SimtWarpShape<16, 1> { -+ static constexpr int kP = 4; -+ static constexpr int kQ = 4; -+}; -+template <> -+struct SimtWarpShape<16, 2> { -+ static constexpr int kP = 2; -+ static constexpr int kQ = 4; -+}; -+template <> -+struct SimtWarpShape<16, 4> { -+ static constexpr int kP = 2; -+ static constexpr int kQ = 2; -+}; -+ -+template -+struct SimtWarpShape<25, WarpNumThreadsM> { -+ static_assert(WarpNumThreadsM == 1, "WarpShapeM could not be evenly splited by threads"); -+ static constexpr int kP = 5; -+ static constexpr int kQ = 5; -+}; -+ -+template <> -+struct SimtWarpShape<32, 1> { -+ static constexpr int kP = 4; -+ static constexpr int kQ = 8; -+}; -+ -+template <> -+struct SimtWarpShape<32, 2> { -+ static constexpr int kP = 4; -+ static constexpr int kQ = 4; -+}; -+ -+template <> -+struct SimtWarpShape<32, 4> { -+ static constexpr int kP = 2; -+ static constexpr int kQ = 4; -+}; -+ -+} // namespace detail -+ -+template < -+ /// Shape of threadblock-scoped matrix multiply operator -+ typename Shape, -+ /// Shape of warp-level matrix multiply operator -+ typename WarpShape, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape, -+ /// Element data type of A operand -+ typename ElementA, -+ /// Layout of operand A -+ typename LayoutA, -+ /// Element data type of B operand -+ typename ElementB, -+ /// Layout of operand B -+ typename LayoutB, -+ /// Data type of accumulator -+ typename ElementC, -+ /// Layout of accumulator -+ typename LayoutC, -+ /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) -+ typename OperatorClass, -+ /// Size of a warp-scoped per thread access -+ int kLaneAccessSizeA_ = 0, -+ /// Size of a warp-scoped per thread access -+ int kLaneAccessSizeB_ = 0, -+ /// Number of stages -+ int Stages = 2, -+ /// Operation performed by MMA -+ typename Operator = typename platform::conditional< -+ (platform::is_same::value) && -+ (platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value), -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::arch::OpMultiplyAdd>::type, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. 
-+ bool AccumulatorsInRowMajor = false, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA = -+ cutlass::arch::CacheOperation::Global, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB = -+ cutlass::arch::CacheOperation::Global, -+ /// per-element transformation for elements of A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// per-element transformation for elements of B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ bool IsComplex = false // (is_complex::value || is_complex::value) -+> -+struct DepthwiseMmaCoreWithLaneAccessSize; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Shape of threadblock-scoped matrix multiply operator -+ typename Shape, -+ /// Shape of threadblock-scoped output tile -+ typename ThreadBlockOutputShape, -+ /// Shape of filter shape per threadblock -+ typename FilterShape, -+ /// Shape of warp-level matrix multiply operator -+ typename WarpShape, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape, -+ /// Element data type of A operand -+ typename ElementA, -+ /// Layout of operand A -+ typename LayoutA, -+ /// Element data type of B operand -+ typename ElementB, -+ /// Layout of operand B -+ typename LayoutB, -+ /// Data type of accumulator -+ typename ElementC, -+ /// Layout of accumulator -+ typename LayoutC, -+ /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) -+ typename OperatorClass, -+ /// Size of a warp-scoped per thread access -+ int kLaneAccessSizeA_ = 0, -+ /// Size of a warp-scoped per thread access -+ int kLaneAccessSizeB_ = 0, -+ /// Number of stages -+ int Stages = 2, -+ /// Operation performed by MMA -+ typename Operator = typename platform::conditional< -+ (platform::is_same::value) && -+ (platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value), -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::arch::OpMultiplyAdd>::type, -+ /// Iterator algo type -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, -+ /// Stride ( MatrixShape ) -+ typename StrideShape = cutlass::MatrixShape<-1, -1>, -+ /// Dilation ( MatrixShape ) -+ typename DilationShape = cutlass::MatrixShape<-1, -1>, -+ /// Activation Shape loaded by threadblock -+ typename ActivationShape = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. 
-+ bool AccumulatorsInRowMajor = false, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA = -+ cutlass::arch::CacheOperation::Global, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB = -+ cutlass::arch::CacheOperation::Global, -+ /// per-element transformation for elements of A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// per-element transformation for elements of B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ bool IsComplex = false // (is_complex::value || is_complex::value) -+> -+struct DepthwiseDirectConvMmaCoreWithLaneAccessSize; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Shape of threadblock-scoped matrix multiply operator -+ typename Shape, -+ /// Shape of warp-level matrix multiply operator -+ typename WarpShape, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape, -+ /// Element data type of A operand -+ typename ElementA, -+ /// Layout of operand A -+ typename LayoutA, -+ /// Element data type of B operand -+ typename ElementB, -+ /// Layout of operand B -+ typename LayoutB, -+ /// Data type of accumulator -+ typename ElementC, -+ /// Layout of accumulator -+ typename LayoutC, -+ /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) -+ typename OperatorClass, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// per-element transformation for elements of A -+ ComplexTransform TransformA, -+ /// per-element transformation for elements of B -+ ComplexTransform TransformB, -+ bool IsComplex -+> -+struct DepthwiseMmaCoreWithLaneAccessSize< -+ Shape, WarpShape, InstructionShape, -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, -+ OperatorClass, -1, -1, Stages, Operator, AccumulatorsInRowMajor, -+ CacheOpA, CacheOpB, TransformA, TransformB, IsComplex -+> : cutlass::gemm::threadblock::DefaultMmaCore< -+ Shape, WarpShape, InstructionShape, -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, -+ OperatorClass, Stages, Operator, AccumulatorsInRowMajor, -+ CacheOpA, CacheOpB, TransformA, TransformB, IsComplex -+> {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Size of a warp-scoped per thread access (a value of -1 indicates the default) -+ int kLaneAccessSizeA_, -+ /// Size of a warp-scoped per thread access (a value of -1 indicates the default) -+ int kLaneAccessSizeB_, -+ /// Operation 
performed by GEMM -+ typename Operator_> -+struct DepthwiseMmaCoreWithLaneAccessSize, -+ ElementA_, -+ layout::RowMajor, -+ ElementB_, -+ layout::ColumnMajor, -+ ElementC_, -+ LayoutC_, -+ arch::OpClassSimt, -+ kLaneAccessSizeA_, -+ kLaneAccessSizeB_, -+ 2, -+ Operator_> : public cutlass::gemm::threadblock::DefaultMmaCore, -+ ElementA_, -+ layout::RowMajor, -+ ElementB_, -+ layout::ColumnMajor, -+ ElementC_, -+ LayoutC_, -+ arch::OpClassSimt, -+ 2, -+ Operator_> { -+ using Base = cutlass::gemm::threadblock::DefaultMmaCore, -+ ElementA_, -+ layout::RowMajor, -+ ElementB_, -+ layout::ColumnMajor, -+ ElementC_, -+ LayoutC_, -+ arch::OpClassSimt, -+ 2, -+ Operator_>; -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ -+ static int const kLaneAccessSizeA = kLaneAccessSizeA_; -+ static int const kLaneAccessSizeB = kLaneAccessSizeB_; -+ -+ // Divisility requirements -+ static_assert( kLaneAccessSizeA > 0 && kLaneAccessSizeB > 0, -+ "Size of a warp-scoped per thread access should be larger then ZERO" ); -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = typename Base::WarpCount; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; -+ -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory are same as base class -+ // -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = cutlass::gemm::threadblock::detail::simt_get_warp_threads_m(); -+ static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 
2 : 1; -+ static const int numElementsA = kLaneAccessSizeA / sizeof_bits::value; -+ static const int numElementsB = kLaneAccessSizeB / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ -+ static int const kPaddingM = cutlass::gemm::threadblock::detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ static int const kPaddingN = cutlass::gemm::threadblock::detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ -+ static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), -+ "Padding must be divisible by Lane"); -+ -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = cutlass::gemm::threadblock::MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape, // skew for A matrix to avoid SMEM bank conflicts -+ MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of threadblock-scoped output tile (concept: TensorNHWCShape) -+ typename ThreadBlockOutputShape_, -+ /// Shape of filter shape per threadblock -+ typename FilterShape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Size of a warp-scoped per thread access -+ int kLaneAccessSizeA_, -+ /// Number of stages -+ int Stages_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DepthwiseDirectConvMmaCoreWithLaneAccessSize, -+ ElementA_, -+ layout::RowMajor, -+ ElementB_, -+ layout::ColumnMajor, -+ ElementC_, -+ LayoutC_, -+ arch::OpClassSimt, -+ kLaneAccessSizeA_, -+ 128, -+ Stages_, -+ Operator_> { -+ using Shape = Shape_; -+ using FilterShape = FilterShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ -+ static int const kLaneAccessSizeB = 128; -+ 
-+ // Divisility requirements -+ static_assert( kLaneAccessSizeB > 0, -+ "Size of a warp-scoped per thread access should be larger then ZERO" ); -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = cutlass::gemm::GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ 1 -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ // For Gmem load -+ static int const kElementsPerAccessA = 128 / sizeof_bits::value; -+ static int const kElementsPerAccessB = 128 / sizeof_bits::value; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajor; -+ using SmemLayoutB = layout::RowMajor; -+ -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, // Set kStrided = 1 because activation shape is runtime value. -+ kThreads, -+ kElementsPerAccessA -+ >; -+ -+ /// ThreadMap of iterator A -+ using SmemThreadMapA = IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIteratorDirectConv< -+ MatrixShape<1, Shape::kN>, // set kRow is 1 because it is a runtime value -+ ElementA, -+ SmemLayoutA, -+ 0, -+ SmemThreadMapA, // was IteratorThreadMapA -+ true // Dynamic iterations. -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccessB -+ >; -+ -+ /// Transpose the ThreadMap of iterator B -+ using SmemThreadMapB = IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIteratorDirectConv< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ SmemThreadMapB, // was IteratorThreadMapB -+ false // static iterations. -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ // Groups per threads -+ // Fp32: 2 groups -+ // Fp16: 2 groups -+ static const int GroupsPerThread = sizeof(ElementB) > 1 ? 
2 : 4; -+ // Define the warp-level op -+ static const int WarpNumThreadsN = cutlass::const_min(WarpShape::kN / GroupsPerThread, kWarpSize); -+ static const int WarpNumThreadsM = kWarpSize / WarpNumThreadsN; -+ -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ -+ // Get output P, Q per thread -+ static const int TileP = cutlass::conv::threadblock::detail::SimtWarpShape::kP; -+ static const int TileQ = cutlass::conv::threadblock::detail::SimtWarpShape::kQ; -+ -+ static const int LaneLayout = 1; -+ static const int numElementsB = kLaneAccessSizeB / sizeof_bits::value; -+ static const int LaneN = cutlass::const_min(numElementsB, WarpShape::kN / WarpNumThreadsN); -+ -+ // Define the output tile computed by each thread -+ using ThreadOutputShape = cutlass::conv::TensorNHWCShape<1, TileP, TileQ, LaneN>; -+ -+ // Fetch the channel with same access size -+ static const int LaneM = LaneN; -+ -+ // No paddings -+ static int const kPaddingM = 0; -+ static int const kPaddingN = 0; -+ -+ static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), -+ "Padding must be divisible by Lane"); -+ -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseDirectConvSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ FilterShape, /// Shape of filter shape per threadblock - concept: gemm::GemmShape -+ ThreadOutputShape, /// Size of the output tile computed by thread - concept: conv::TensorNHWCShape<> -+ ThreadBlockOutputShape_, /// Size of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = cutlass::conv::threadblock::DepthwiseDirectConvMmaPolicy< -+ MmaWarpSimt, -+ MatrixShape, // skew for A matrix to avoid SMEM bank conflicts -+ MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts -+ IteratorThreadMapA, -+ IteratorThreadMapB, -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of threadblock-scoped output tile (concept: TensorNHWCShape) -+ typename ThreadBlockOutputShape_, -+ /// Shape of filter shape per threadblock -+ typename FilterShape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of 
accumulator -+ typename LayoutC_, -+ /// Size of a warp-scoped per thread access -+ int kLaneAccessSizeA_, -+ /// Number of stages -+ int Stages_, -+ /// Operation performed by GEMM -+ typename Operator_, -+ /// Stride ( MatrixShape ) -+ typename StrideShape_, -+ /// Dilation ( MatrixShape ) -+ typename DilationShape_, -+ /// Activation Shape loaded by threadblock -+ typename ActivationShape_> -+struct DepthwiseDirectConvMmaCoreWithLaneAccessSize, -+ ElementA_, -+ layout::RowMajor, -+ ElementB_, -+ layout::ColumnMajor, -+ ElementC_, -+ LayoutC_, -+ arch::OpClassSimt, -+ kLaneAccessSizeA_, -+ 128, -+ Stages_, -+ Operator_, -+ IteratorAlgorithm::kFixedStrideDilation, -+ StrideShape_, -+ DilationShape_, -+ ActivationShape_> { -+ using Shape = Shape_; -+ using FilterShape = FilterShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ using StrideShape = StrideShape_; -+ using DilationShape = DilationShape_; -+ using ThreadBlockOutputShape = ThreadBlockOutputShape_; -+ using ActivationShape = ActivationShape_; -+ -+ static int const kLaneAccessSizeB = 128; -+ -+ // Divisility requirements -+ static_assert( kLaneAccessSizeB > 0, -+ "Size of a warp-scoped per thread access should be larger then ZERO" ); -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = cutlass::gemm::GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ 1 -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ // For Gmem load -+ static int const kElementsPerAccessA = 128 / sizeof_bits::value; -+ static int const kElementsPerAccessB = 128 / sizeof_bits::value; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajor; -+ using SmemLayoutB = layout::RowMajor; -+ -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccessA -+ >; -+ -+ /// ThreadMap of iterator A -+ using SmemThreadMapA = IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIteratorDirectConv< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 0, -+ SmemThreadMapA, // was IteratorThreadMapA -+ false // static iterations. -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccessB -+ >; -+ -+ /// Transpose the ThreadMap of iterator B -+ using SmemThreadMapB = IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIteratorDirectConv< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ SmemThreadMapB, // was IteratorThreadMapB -+ false // static iterations. 
-+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ // Groups per threads -+ // Fp32: 2 groups -+ // Fp16: 2 groups -+ static const int GroupsPerThread = sizeof(ElementB) > 1 ? 2 : 4; -+ // Define the warp-level op -+ static const int WarpNumThreadsN = cutlass::const_min(WarpShape::kN / GroupsPerThread, kWarpSize); -+ static const int WarpNumThreadsM = kWarpSize / WarpNumThreadsN; -+ -+ static const int TileP = cutlass::conv::threadblock::detail::SimtWarpShape::kP; -+ static const int TileQ = cutlass::conv::threadblock::detail::SimtWarpShape::kQ; -+ -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ -+ static const int LaneLayout = 1; -+ static const int numElementsB = kLaneAccessSizeB / sizeof_bits::value; -+ static const int LaneN = cutlass::const_min(numElementsB, WarpShape::kN / WarpNumThreadsN); -+ -+ // Define the output tile computed by each thread -+ using ThreadOutputShape = cutlass::conv::TensorNHWCShape<1, TileP, TileQ, LaneN>; -+ -+ // Fetch the channel with same access size -+ static const int LaneM = LaneN; -+ -+ // No paddings -+ static int const kPaddingM = 0; -+ static int const kPaddingN = 0; -+ -+ static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), -+ "Padding must be divisible by Lane"); -+ -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseDirectConvSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ FilterShape, /// Shape of filter shape per threadblock - concept: gemm::GemmShape -+ ThreadOutputShape, /// Size of the output tile computed by thread - concept: conv::TensorNHWCShape<> -+ ThreadBlockOutputShape, /// Size of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy, /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ IteratorAlgorithm::kFixedStrideDilation, /// Iterator algo type -+ StrideShape, /// Stride ( MatrixShape ) -+ DilationShape, /// Dilation ( MatrixShape ) -+ ActivationShape /// Activation Shape loaded by threadblock -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = cutlass::conv::threadblock::DepthwiseDirectConvMmaPolicy< -+ MmaWarpSimt, -+ MatrixShape, // skew for A matrix to avoid SMEM bank conflicts -+ MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts -+ IteratorThreadMapA, -+ IteratorThreadMapB, -+ WarpCount::kK -+ >; -+}; -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h -new file mode 100644 -index 0000000..cc33c69 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h -@@ -0,0 +1,802 @@ 
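Note: the two depthwise DirectConv partial specializations above derive their warp and lane partition entirely from compile-time arithmetic (warp count, threads per threadblock, threads per warp along M/N, and the per-lane N tile bounded by the 128-bit access width). A minimal standalone sketch of that arithmetic, assuming illustrative tile sizes and local stand-in names rather than the CUTLASS types:

// Editorial sketch, not part of the upstream patch. Tile sizes and element
// width below are assumptions chosen only to make the constraints checkable.
#include <algorithm>
#include <cstdio>

struct TileShape { int m, n; };                    // stand-in for gemm::GemmShape<M, N, K>

constexpr int kWarpSize        = 32;
constexpr int kLaneAccessSizeB = 128;              // bits per warp-scoped access of operand B
constexpr int kElementBitsB    = 16;               // e.g. fp16 filter elements (assumption)

constexpr TileShape kThreadblockTile{64, 64};      // Shape::kM x Shape::kN (illustrative)
constexpr TileShape kWarpTile{16, 64};             // WarpShape::kM x WarpShape::kN (illustrative)

constexpr int kWarpCount = (kThreadblockTile.m / kWarpTile.m) *
                           (kThreadblockTile.n / kWarpTile.n);
constexpr int kThreads   = kWarpCount * kWarpSize;

// Groups per thread: 2 for fp16/fp32, 4 for 8-bit types (mirrors sizeof(ElementB) > 1 ? 2 : 4).
constexpr int kGroupsPerThread = (kElementBitsB / 8 > 1) ? 2 : 4;
constexpr int kWarpNumThreadsN = std::min(kWarpTile.n / kGroupsPerThread, kWarpSize);
constexpr int kWarpNumThreadsM = kWarpSize / kWarpNumThreadsN;

// Per-lane tile along N: bounded by both the access width and the per-thread share of N.
constexpr int kNumElementsB = kLaneAccessSizeB / kElementBitsB;
constexpr int kLaneN        = std::min(kNumElementsB, kWarpTile.n / kWarpNumThreadsN);

static_assert(kThreadblockTile.m % kWarpTile.m == 0 && kThreadblockTile.n % kWarpTile.n == 0,
              "Threadblock-scoped tile must be divisible by the warp-scoped tile");
static_assert(kWarpTile.m % kWarpNumThreadsM == 0 && kWarpTile.n % kWarpNumThreadsN == 0,
              "Warp tile must be divisible by the thread arrangement");

int main() {
  std::printf("warps=%d threads=%d lanes MxN=%dx%d laneN=%d\n",
              kWarpCount, kThreads, kWarpNumThreadsM, kWarpNumThreadsN, kLaneN);
}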
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a multistage threadblock-scoped fused activation's -+ scale+bias+relu and Implicit GEMM Convolution kernel. -+ -+ The original implicit gemm will store out-of-bound data as zeroes in the -+ shared memory because zeros into the tensor core, zeroes out of the tensor -+ cores. The result is remained the same. When fusing scale+bias+relu -+ into the mainloop, it is no longer true because -+ -+ 0 x scale + bias = bias -+ -+ which is no longer always 0. So, instead of storing zeroes, this fused -+ kernel stores the out-of-bound data as a special NaN (0x7eff), when applying -+ scale+bias+relu, the code is like -+ -+ if (data == 0x7eff) -+ data = 0; -+ else -+ data = scale+bias+relu(data, scale, bias); -+ -+ See include/cutlass/conv/warp/scale_bias_relu_transformation.h for the -+ elementwise computation. See include/cutlass/arch/memory_sm80.h for nan fill. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/cache_operation.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" -+#include "cutlass/conv/warp/scale_bias_relu_transform.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. 
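The file comment above explains why the fused fprop kernel cannot rely on the usual zero-fill of out-of-bound data: after scale+bias, a zero would turn into `bias`. It therefore tags padding with a special fp16 NaN (0x7eff) and suppresses it at transform time. A minimal host-side sketch of that control flow, using a 32-bit NaN payload as a stand-in sentinel so it runs without an fp16 type (editorial, not the kernel's actual transform):

#include <cstdint>
#include <cstring>
#include <cstdio>

constexpr std::uint32_t kOutOfBoundNaN = 0x7fc07effu;  // stand-in for the fp16 sentinel 0x7eff

std::uint32_t bits_of(float x) { std::uint32_t b; std::memcpy(&b, &x, sizeof b); return b; }
float from_bits(std::uint32_t b) { float x; std::memcpy(&x, &b, sizeof x); return x; }

// Fused scale+bias+relu with the sentinel check: out-of-bound elements must
// contribute 0 to the accumulation, not `bias`.
float scale_bias_relu(float data, float scale, float bias) {
  if (bits_of(data) == kOutOfBoundNaN) {
    return 0.0f;                        // padding element: suppress scale/bias entirely
  }
  float y = data * scale + bias;
  return y > 0.0f ? y : 0.0f;           // ReLU
}

int main() {
  float padding = from_bits(kOutOfBoundNaN);
  std::printf("%f %f\n",
              scale_bias_relu(2.0f, 1.5f, -1.0f),      // 2*1.5 - 1 = 2.0
              scale_bias_relu(padding, 1.5f, -1.0f));  // 0.0, not bias
}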
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Element type of scale and bias vectors -+ typename ElementScaleBias_, -+ /// Layout of scale and bias vectors -+ typename LayoutScaleBias_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// WarpIterator to load Scale or Bias vector from the shared memory -+ typename WarpIteratorScaleBias_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaFpropFusionBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Element type of scale and bias vectors -+ using ElementScaleBias = ElementScaleBias_; -+ -+ /// Layout of scale and bias vectors -+ using LayoutScaleBias = LayoutScaleBias_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ ///< WarpIterator to load Scale or Bias vector from the shared memory -+ using WarpIteratorScaleBias = WarpIteratorScaleBias_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. -+ using WarpGemm = typename Policy::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount = cutlass::gemm::GemmShape; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations = -+ (WarpGemm::kK / Operator::Policy::MmaShape::kK); -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Tensor reference to the A operand -+ using TensorRefA = TensorRef; -+ -+ /// Tensor reference to the scale and bias vectors -+ using TensorRefScaleBias = TensorRef; -+ -+ /// Tensor reference to the B operand -+ using TensorRefB = TensorRef; -+ -+ static_assert(kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ static_assert((kWarpGemmIterations % 2) == 0, -+ "Inner loop iteration must be an even number."); -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ class SharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ -+ /// Shape of the A matrix operand in shared memory -+ using ShapeA = MatrixShape; -+ -+ /// Shape of the A scale and bias vectors in shared memory -+ using ShapeScaleBias = -+ MatrixShape<1 + Policy::SmemPaddingA::kRow, -+ 2 * Shape::kK * kStages + Policy::SmemPaddingA::kColumn>; -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = -+ MatrixShape; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for A operand -+ AlignedBuffer operand_A; -+ -+ /// Buffer for B operand -+ AlignedBuffer operand_B; -+ -+ /// Buffer for A operand Scale and Bias -+ AlignedBuffer operand_A_scale_bias; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the A matrix -+ CUTLASS_DEVICE -+ static typename Operator::LayoutA LayoutA() { -+ return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); -+ } -+ -+ /// Returns a layout object for the B matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutB LayoutB() { -+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); -+ } -+ -+ /// Returns a layout object for the A scale and bias vectors -+ CUTLASS_DEVICE -+ static LayoutScaleBias LayoutScaleBias() { -+ return LayoutScaleBias::packed( -+ {ShapeScaleBias::kRow, 
ShapeScaleBias::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the A operand -+ CUTLASS_HOST_DEVICE -+ TensorRefA operand_A_ref() { -+ return TensorRefA{operand_A.data(), LayoutA()}; -+ } -+ -+ /// Returns a TensorRef to the B operand -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B_ref() { -+ return TensorRefB{operand_B.data(), LayoutB()}; -+ } -+ -+ /// Returns a TensorRef to the A operand Scale vector -+ CUTLASS_HOST_DEVICE -+ TensorRefScaleBias operand_A_scale_bias_ref() { -+ return TensorRefScaleBias{operand_A_scale_bias.data(), LayoutScaleBias()}; -+ } -+ }; -+ -+ protected: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A operand from shared memory -+ typename Operator::IteratorA warp_tile_iterator_A_; -+ -+ /// Iterator to load a warp-scoped tile of A operand scale and bias vector -+ /// from shared memory -+ WarpIteratorScaleBias warp_tile_iterator_A_scale_bias_; -+ -+ /// Iterator to load a warp-scoped tile of B operand from shared memory -+ typename Operator::IteratorB warp_tile_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaFpropFusionBase( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), -+ warp_tile_iterator_A_scale_bias_( -+ shared_storage.operand_A_scale_bias_ref(), lane_idx), -+ warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. 
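The SharedStorage definition above stages kStages tiles of A and B plus the scale/bias vector in shared memory (scale and bias share one buffer, hence the factor of 2 on the K extent). A rough footprint estimate under assumed tile sizes and element widths, padding terms omitted (editorial sketch, not derived from a concrete instantiation):

#include <cstdio>

int main() {
  constexpr int kM = 128, kN = 128, kK = 32;   // threadblock Shape (illustrative)
  constexpr int kStages = 3;                   // multistage depth (illustrative)
  constexpr int kBytesA = 2, kBytesB = 2;      // e.g. fp16 operands (assumption)
  constexpr int kBytesScaleBias = 2;           // fp16 scale/bias (assumption)

  // ShapeA  : kM             x (kK * kStages)
  // ShapeSB : 1              x (2 * kK * kStages)   -- scale and bias interleaved
  // ShapeB  : (kK * kStages) x kN
  constexpr int bytes_A  = kM * kK * kStages * kBytesA;
  constexpr int bytes_SB = 2 * kK * kStages * kBytesScaleBias;
  constexpr int bytes_B  = kK * kStages * kN * kBytesB;

  std::printf("smem bytes: A=%d B=%d scale/bias=%d total=%d\n",
              bytes_A, bytes_B, bytes_SB, bytes_A + bytes_B + bytes_SB);
}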
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Iterates over vectors of scale and bias vector in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorScaleBias_, -+ /// Iterates over vectors of scale and bias vector in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorScaleBias_, -+ /// Cache operation for scale/bias operand -+ cutlass::arch::CacheOperation::Kind CacheOpScaleBias, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// WarpIterator to load Scale or Bias vector from the shared memory -+ typename WarpIteratorScaleBias_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class ImplicitGemmFpropFusionMultistage -+ : public MmaFpropFusionBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Iterates over tiles of the scale and bias vectors in global memory -+ using IteratorScaleBias = IteratorScaleBias_; -+ ///< WarpIterator to load Scale or Bias vector from the shared memory -+ using WarpIteratorScaleBias = WarpIteratorScaleBias_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ ///< Base class -+ using Base = MmaFpropFusionBase; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ using SmemIteratorScaleBias = SmemIteratorScaleBias_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpScaleBias = -+ CacheOpScaleBias; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ -+ using ElementC = typename Policy::Operator::ElementC; -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Internal structure exposed for introspection. 
-+ struct Detail { -+ -+ static_assert(Base::kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA = -+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpLoadedFragmentScaleBias = -+ typename WarpIteratorScaleBias::Fragment; -+ -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of A operand scale vector to shared memory -+ SmemIteratorScaleBias smem_iterator_A_scale_bias_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ ImplicitGemmFpropFusionMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_A_scale_bias_(shared_storage.operand_A_scale_bias_ref(), -+ thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_A_scale_bias_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance(IteratorA &iterator_A, -+ IteratorScaleBias 
&iterator_A_scale_bias, -+ IteratorB &iterator_B, int group_start_A = 0, -+ int group_start_B = 0) { -+ iterator_A.set_iteration_index(group_start_A); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ -+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / 8; -+ -+ // Uses nan fill for out of bound data -+ cutlass::arch::cp_async_nan( -+ dst_ptr, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ // Async Copy for operand A scale and bias vector. Scale and bias vectors -+ // are small. One iteration is enough. -+ if (group_start_A == 0) { -+ typename IteratorScaleBias::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_scale_bias_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorScaleBias::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async( -+ dst_ptr, iterator_A_scale_bias.get(), iterator_A_scale_bias.valid()); -+ } -+ -+ iterator_B.set_iteration_index(group_start_B); -+ -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ ++this->smem_iterator_B_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ ///< iterator over scale and bias vectors in global memory -+ IteratorScaleBias iterator_A_scale_bias, -+ ///< initial value of accumulator -+ FragmentC const &src_accum, -+ ///< number of iterations per channel -+ int gemm_k_iterations_per_channel = 0, -+ ///< Imaginary strides used for planar-complex only - ignored here -+ int64_t imag_stride_A = 0, -+ int64_t imag_stride_B = 0) { -+ -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations) { -+ -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / 8; -+ -+ // Uses Nan fill for out of bound data -+ cutlass::arch::cp_async_nan( -+ dst_ptr, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ ++this->smem_iterator_A_; -+ } -+ -+ // Async Copy for operand A scale and bias vectors. Scale and bias -+ // vectors are small. 
One iteration is enough. -+ { -+ typename IteratorScaleBias::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_scale_bias_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorScaleBias::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async( -+ dst_ptr, iterator_A_scale_bias.get(), iterator_A_scale_bias.valid()); -+ } -+ -+ iterator_B.set_iteration_index(0); -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ ++this->smem_iterator_B_; -+ } -+ -+ // Move to the next stage -+ iterator_A.advance(); -+ iterator_A_scale_bias.advance(); -+ iterator_B.advance(); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_A_scale_bias_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B[2]; -+ WarpLoadedFragmentScaleBias warp_loaded_frag_A_scale_bias[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B[2]; -+ -+ Operator warp_mma; -+ cutlass::conv::warp::FpropScaleBiasReluTransform -+ elementwise_transform; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_A_scale_bias_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_A_scale_bias_.load( -+ warp_loaded_frag_A_scale_bias[0]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_A_scale_bias_; -+ ++this->warp_tile_iterator_B_; -+ -+ // Start issuing the first group of the next stage outside of the mainloop -+ copy_tiles_and_advance(iterator_A, iterator_A_scale_bias, iterator_B); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B[0]); -+ -+ elementwise_transform(warp_transformed_frag_A[0], -+ warp_loaded_frag_A_scale_bias[0]); -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. 
-+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_A_scale_bias_.set_kgroup_index( -+ (warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_A_scale_bias_.load( -+ warp_loaded_frag_A_scale_bias[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_A_scale_bias_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k > 0) { -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B[warp_mma_k % 2]); -+ -+ elementwise_transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_A_scale_bias[warp_mma_k % 2]); -+ } -+ -+ warp_mma( -+ accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ accum -+ ); -+ -+ // Issue global->shared copies for the next stage -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) { -+ group_start_iteration_A = 0; -+ group_start_iteration_B = 0; -+ } else { -+ group_start_iteration_A = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ } -+ -+ copy_tiles_and_advance(iterator_A, iterator_A_scale_bias, iterator_B, -+ group_start_iteration_A, -+ group_start_iteration_B); -+ -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) { -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ elementwise_transform( -+ warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A_scale_bias[(warp_mma_k + 1) % 2]); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ // Inserts a fence to group cp.async instructions into stages. 
-+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages of cp.async have committed -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A.advance(); -+ iterator_A_scale_bias.advance(); -+ iterator_B.advance(); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_A_scale_bias_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_A_scale_bias_.add_tile_offset( -+ {0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_A_scale_bias_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ } -+ } -+ -+ } -+ -+ // Insert fence and wait for all outstanding cp.async operations to commit. -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_multistage.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_multistage.h -new file mode 100644 -index 0000000..80dc435 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_multistage.h -@@ -0,0 +1,542 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/cache_operation.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class ImplicitGemmMultistage : -+ public gemm::threadblock::MmaBase { -+public: -+ ///< Base class -+ using Base = gemm::threadblock::MmaBase; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ -+ using ElementC = typename Policy::Operator::ElementC; -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename 
Policy::Operator; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA = -+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ ImplicitGemmMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance( -+ IteratorA &iterator_A, IteratorB &iterator_B, -+ int group_start_A = 0, int group_start_B = 0) { -+ -+ iterator_A.set_iteration_index(group_start_A * -+ IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ -+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = 
sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ iterator_B.set_iteration_index(group_start_B * -+ IteratorB::kAccessesPerVector); -+ -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ } -+ ++this->smem_iterator_B_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ ///< initial value of accumulator -+ FragmentC const &src_accum, -+ ///< number of iterations per channel -+ int gemm_k_iterations_per_channel = 0, -+ ///< Imaginary strides used for planar-complex only - ignored here -+ int64_t imag_stride_A = 0, -+ int64_t imag_stride_B = 0) { -+ -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations) { -+ -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ iterator_B.set_iteration_index(0); -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ } -+ -+ ++this->smem_iterator_B_; -+ } -+ -+ // Move to the next stage -+ iterator_A.advance(); -+ iterator_B.advance(); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ 
this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B[2]; -+ -+ Operator warp_mma; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ // Start issuing the first group of the next stage outside of the mainloop -+ copy_tiles_and_advance(iterator_A, iterator_B); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B[0]); -+ -+ // tf32x3 kernels use staging accumulation. warp_mma uses a temporary -+ // accumulator and this temporary accumulator is added to the final -+ // accumulator once in every mainloop iteration. -+ plus plus_accum; -+ -+ FragmentC tmp_accum; -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ tmp_accum.clear(); -+ } -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. 
-+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k > 0) -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B[warp_mma_k % 2]); -+ -+ // Issue global->shared copies for the next stage -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) { -+ group_start_iteration_A = 0; -+ group_start_iteration_B = 0; -+ } else { -+ group_start_iteration_A = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ } -+ -+ copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, -+ group_start_iteration_B); -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ warp_mma( -+ tmp_accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ tmp_accum -+ ); -+ -+ if (warp_mma_k == 0) { -+ accum = plus_accum(accum, tmp_accum); -+ tmp_accum.clear(); -+ } -+ } else { -+ warp_mma( -+ accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ accum -+ ); -+ } -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages of cp.async have committed -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A.advance(); -+ iterator_B.advance(); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ } -+ } -+ -+ } -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ accum = plus_accum(accum, tmp_accum); -+ } -+ -+ // Insert fence and wait for all outstanding cp.async operations to commit. 
-+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h -new file mode 100644 -index 0000000..4a36ef5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h -@@ -0,0 +1,320 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
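The multistage mainloops above prefetch kStages-1 tiles in the prologue, keep draining until gemm_k_iterations reaches -(kStages-1), and wrap both shared-memory stage indices at kStages-1 to reuse the circular buffer. A small host-side simulation of that bookkeeping, with illustrative values and the per-k-tile stage advance collapsed into the loop header (editorial sketch only):

#include <cstdio>

int main() {
  constexpr int kStages = 4;
  int gemm_k_iterations = 10;                 // total k tiles (illustrative)

  // Prologue: kStages-1 complete stages are issued up front.
  gemm_k_iterations -= (kStages - 1);
  int smem_write_stage_idx = kStages - 1;
  int smem_read_stage_idx  = 0;

  // Mainloop: the real kernel advances one stage per k tile (at
  // warp_mma_k + 2 == kWarpGemmIterations); here that advance happens once per iteration.
  for (; gemm_k_iterations > (-kStages + 1); --gemm_k_iterations) {
    smem_write_stage_idx = (smem_write_stage_idx == kStages - 1) ? 0 : smem_write_stage_idx + 1;
    smem_read_stage_idx  = (smem_read_stage_idx  == kStages - 1) ? 0 : smem_read_stage_idx  + 1;
    std::printf("k_iter=%3d write_stage=%d read_stage=%d\n",
                gemm_k_iterations, smem_write_stage_idx, smem_read_stage_idx);
  }
}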
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Transformation applied to A operand -+ typename TransformA_ = NumericArrayConverter< -+ typename SmemIteratorA_::Element, -+ typename IteratorA_::Element, -+ IteratorA_::Fragment::kElements>, -+ /// -+ /// Transformation applied to A operand -+ typename TransformB_ = NumericArrayConverter< -+ typename SmemIteratorB_::Element, -+ typename IteratorB_::Element, -+ IteratorB_::Fragment::kElements>, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class ImplicitGemmPipelined : public gemm::threadblock::MmaBase { -+public: -+ -+ ///< Base class -+ using Base = gemm::threadblock::MmaBase; -+ -+ using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory -+ using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory -+ using ElementC = ElementC_; ///< Data type of accumulator matrix -+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix -+ using Policy = Policy_; ///< Policy describing tuning details -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of operand A loaded from global memory -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Obtain the arch tag from the warp-level operator -+ using ArchTag = typename Policy::Operator::ArchTag; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) -+ static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); -+ -+private: -+ -+ using WarpFragmentA = typename Operator::FragmentA; -+ using WarpFragmentB = typename Operator::FragmentB; -+ -+protected: -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ 
ImplicitGemmPipelined( -+ typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ int thread_idx, ///< ID within the threadblock -+ int warp_idx, ///< ID of warp -+ int lane_idx ///< ID of each thread within a warp -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ int gemm_k_iterations, ///< number of iterations of the mainloop -+ FragmentC &accum, ///< destination accumulator tile -+ IteratorA iterator_A, ///< iterator over A operand in global memory -+ IteratorB iterator_B, ///< iterator over B operand in global memory -+ FragmentC const &src_accum, ///< source accumulator tile -+ int gemm_k_iterations_per_channel = 0, ///< number of iterations per channel -+ TransformA transform_A = TransformA(), ///< transformation applied to A fragment -+ TransformB transform_B = TransformB()) { ///< transformation applied to B fragment -+ -+ // -+ // Prologue -+ // -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ FragmentA tb_frag_A; -+ FragmentB tb_frag_B; -+ -+ tb_frag_A.clear(); -+ tb_frag_B.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_A.load(tb_frag_A); -+ iterator_B.load(tb_frag_B); -+ -+ ++iterator_A; -+ ++iterator_B; -+ -+ this->smem_iterator_A_.store(transform_A(tb_frag_A)); -+ this->smem_iterator_B_.store(transform_B(tb_frag_B)); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA warp_frag_A[2]; -+ WarpFragmentB warp_frag_B[2]; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ Operator warp_mma; -+ -+ int smem_write_stage_idx = 1; -+ -+ // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing -+ // shared memory loads (which have the tighest latency requirement). -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. 
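// For orientation, a condensed and purely illustrative sketch of the two-stage (double-buffered)
// schedule that the mainloop below implements; names are simplified:
//
//   for (k = gemm_k_iterations; k > 0; --k) {
//     for (warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
//       if (warp_mma_k == Base::kWarpGemmIterations - 1) {
//         // store the prefetched global tile to shared memory, __syncthreads(),
//         // and flip between the two shared-memory stages
//       }
//       // load warp fragment (warp_mma_k + 1) into the spare register slot [... % 2]
//       if (warp_mma_k == 0) {
//         // prefetch the next global tile of A and B into register fragments
//       }
//       // warp_mma on the fragments loaded one step earlier (slot [warp_mma_k % 2])
//     }
//   }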
-+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > 0; --gemm_k_iterations) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_A_.store(transform_A(tb_frag_A)); -+ -+ this->smem_iterator_B_.store(transform_B(tb_frag_B)); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ } -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k == 0) { -+ -+ iterator_A.load(tb_frag_A); -+ iterator_B.load(tb_frag_B); -+ -+ ++iterator_A; -+ ++iterator_B; -+ } -+ -+ warp_mma(accum, warp_frag_A[warp_mma_k % 2], -+ warp_frag_B[warp_mma_k % 2], accum); -+ } -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h -new file mode 100644 -index 0000000..13b5a34 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h -@@ -0,0 +1,729 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a multistage threadblock-scoped fused activation's scale+bias+relu and -+ Implicit GEMM Convolution kernel. -+ -+ The original implicit gemm will store out-of-bound data as zeroes in the -+ shared memory because zeros into the tensor core, zeroes out of the tensor -+ cores. The result is remained the same. When fusing scale+bias+relu -+ into the mainloop, it is no longer true because -+ -+ 0 x scale + bias = bias -+ -+ which is no longer always 0. So, instead of storing zeroes, this fused -+ kernel stores the out-of-bound data as a special NaN (0x7eff), when applying -+ scale+bias+relu, the code is like -+ -+ if (data == 0x7eff) -+ data = 0; -+ else -+ data = scale+bias+relu(data, scale, bias); -+ -+ The biggest difference compared with the fused Fprop and scale+bias+relu is -+ that scale and bias are loop invariant in Wgrad so that they only needs to -+ be loaded once before the mainloop. -+ -+ See include/cutlass/conv/warp/scale_bias_relu_transformation.h for the -+ elementwise computation. See include/cutlass/arch/memory_sm80.h for nan fill. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/cache_operation.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" -+#include "cutlass/conv/warp/scale_bias_relu_transform.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. 
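// A scalar sketch of the NaN-fill scale+bias+relu described in the file comment above. The
// production code in include/cutlass/conv/warp/scale_bias_relu_transform.h works on packed
// __half2 fragments; the function name below is illustrative only and not part of this header.

#include <cuda_fp16.h>

__device__ __forceinline__ __half wgrad_scale_bias_relu(__half x, __half scale, __half bias) {
  if (__half_as_ushort(x) == 0x7effu) {
    return __float2half(0.0f);                   // out-of-bound fill value contributes nothing
  }
  __half y = __hfma(scale, x, bias);             // y = scale * x + bias
  return __hgt(y, __float2half(0.0f)) ? y : __float2half(0.0f);   // ReLU
}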
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Element type of scale and bias vectors -+ typename ElementScaleBias_, -+ /// Layout of scale and bias vectors -+ typename LayoutScaleBias_, -+ /// Element type of scale and bias vectors -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaWgradFusionBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Element type of scale and bias vectors -+ using ElementScaleBias = ElementScaleBias_; -+ -+ /// Layout of scale and bias vectors -+ using LayoutScaleBias = LayoutScaleBias_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. -+ using WarpGemm = typename Policy::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount = cutlass::gemm::GemmShape; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations = -+ (WarpGemm::kK / Operator::Policy::MmaShape::kK); -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Tensor reference to the A operand -+ using TensorRefA = TensorRef; -+ -+ /// Tensor reference to the B operand -+ using TensorRefB = TensorRef; -+ -+ static_assert(kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ static_assert((kWarpGemmIterations % 2) == 0, -+ "Inner loop iteration must be an even number."); -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ class SharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ -+ /// Shape of the A matrix operand in shared memory -+ using ShapeA = MatrixShape; -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = -+ MatrixShape; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for A operand -+ AlignedBuffer operand_A; -+ -+ /// Buffer for B operand -+ AlignedBuffer operand_B; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the A matrix -+ CUTLASS_DEVICE -+ static typename Operator::LayoutA LayoutA() { -+ return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); -+ } -+ -+ /// Returns a layout object for the B matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutB LayoutB() { -+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the A operand -+ CUTLASS_HOST_DEVICE -+ TensorRefA operand_A_ref() { -+ return TensorRefA{operand_A.data(), LayoutA()}; -+ } -+ -+ /// Returns a TensorRef to the B operand -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B_ref() { -+ return TensorRefB{operand_B.data(), LayoutB()}; -+ } -+ }; -+ -+ protected: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A operand from shared memory -+ typename Operator::IteratorA warp_tile_iterator_A_; -+ -+ /// Iterator to load a warp-scoped tile of B operand from shared memory -+ typename Operator::IteratorB warp_tile_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaWgradFusionBase( -+ ///< Shared storage needed for internal use by 
threadblock-scoped GEMM -+ SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), -+ warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Iterates over vectors of scale and bias vector in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorScaleBias_, -+ /// Iterates over vectors of scale and bias vector i -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class ImplicitGemmWgradFusionMultistage -+ : public MmaWgradFusionBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Iterates over tiles of the scale and bias vectors in global memory -+ using IteratorScaleBias = IteratorScaleBias_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ ///< Base class -+ using Base = MmaWgradFusionBase; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ -+ using ElementC = typename Policy::Operator::ElementC; -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Internal structure exposed for introspection. 
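// The Detail struct below spreads one stage's cp.async traffic across the warp-level MMA
// iterations: kAccessesPerGroupA/B are ceil-divisions of the per-stage copy count by
// Base::kWarpGemmIterations, so global->shared copies are issued a group at a time while the
// math pipeline runs rather than bunching up at stage boundaries. kBBufferSize selects single-
// or double-buffered register fragments for operand A depending on element types and warp tile
// size, trading registers for latency hiding.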
-+ struct Detail { -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA = -+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ static int const kBBufferSize = -+ ((sizeof(typename Operator::ElementC) == 4) && -+ ((platform::is_same::value && -+ platform::is_same::value)) && -+ (Operator::Shape::kM >= 64 && Operator::Shape::kN >= 64)) -+ ? 1 -+ : 2; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpLoadedFragmentScaleBias = typename IteratorScaleBias::Fragment; -+ -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+ int warp_idx_m_; -+ -+ int warp_idx_n_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ ImplicitGemmWgradFusionMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ warp_idx_m_ = warp_idx_mn % Base::WarpCount::kM; -+ warp_idx_n_ = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m_, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n_}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance(IteratorA &iterator_A, -+ IteratorB &iterator_B, -+ int group_start_A = 0, int group_start_B = 0) { -+ -+ iterator_A.set_iteration_index(group_start_A); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for 
(int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ -+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ iterator_B.set_iteration_index(group_start_B); -+ -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / 8; -+ -+ // Uses nan fill for out of bound data -+ cutlass::arch::cp_async_nan( -+ dst_ptr, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ ++this->smem_iterator_B_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ ///< iterator over scale and bias vectors in global memory -+ IteratorScaleBias iterator_B_scale_bias, -+ ///< initial value of accumulator -+ FragmentC const &src_accum, -+ ///< number of iterations per channel -+ int gemm_k_iterations_per_channel = 0, -+ ///< Imaginary strides used for planar-complex only - ignored here -+ int64_t imag_stride_A = 0, -+ int64_t imag_stride_B = 0) { -+ -+ // -+ // Prologue -+ // -+ -+ WarpLoadedFragmentScaleBias warp_loaded_frag_B_scale_bias; -+ iterator_B_scale_bias.add_tile_offset({0, warp_idx_n_}); -+ iterator_B_scale_bias.load(warp_loaded_frag_B_scale_bias); -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations) { -+ -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ ++this->smem_iterator_A_; -+ } -+ -+ iterator_B.set_iteration_index(0); -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / 8; -+ -+ // Uses Nan fill for out of bound data -+ cutlass::arch::cp_async_nan( -+ dst_ptr, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ ++this->smem_iterator_B_; -+ } -+ -+ // Move to the next stage -+ iterator_A.advance(); -+ iterator_B.advance(); -+ -+ 
this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[Detail::kBBufferSize]; -+ WarpLoadedFragmentB warp_loaded_frag_B[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A[Detail::kBBufferSize]; -+ WarpTransformedFragmentB warp_transformed_frag_B[2]; -+ -+ Operator warp_mma; -+ cutlass::conv::warp::WgradScaleBiasReluTransform -+ elementwise_transform; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ // Start issuing the first group of the next stage outside of the mainloop -+ copy_tiles_and_advance(iterator_A, iterator_B); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B[0]); -+ -+ elementwise_transform(warp_transformed_frag_B[0], -+ warp_loaded_frag_B_scale_bias); -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. 
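// Register-level double buffering: the loads below fill slot [(warp_mma_k + 1) % 2] (and, for
// operand A, slot [(warp_mma_k + 1) % Detail::kBBufferSize]) while the warp_mma call in this
// iteration consumes the fragments loaded on the previous k-group, so shared-memory loads
// overlap the tensor core math.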
-+ -+ if (Detail::kBBufferSize == 2) { -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % Detail::kBBufferSize]); -+ ++this->warp_tile_iterator_A_; -+ } -+ -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k > 0) { -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % Detail::kBBufferSize], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % Detail::kBBufferSize], -+ warp_loaded_frag_B[warp_mma_k % 2]); -+ -+ elementwise_transform(warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_B_scale_bias); -+ } -+ -+ warp_mma( -+ accum, -+ warp_transformed_frag_A[warp_mma_k % Detail::kBBufferSize], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ accum -+ ); -+ -+ if (Detail::kBBufferSize == 1) { -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ ++this->warp_tile_iterator_A_; -+ -+ } -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) { -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % Detail::kBBufferSize], -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % Detail::kBBufferSize], -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ elementwise_transform( -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B_scale_bias); -+ } -+ -+ // Issue global->shared copies for the next stage -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) { -+ group_start_iteration_A = 0; -+ group_start_iteration_B = 0; -+ } else { -+ group_start_iteration_A = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ } -+ -+ copy_tiles_and_advance(iterator_A, iterator_B, -+ group_start_iteration_A, -+ group_start_iteration_B); -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages of cp.async have committed -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A.advance(); -+ iterator_B.advance(); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ } -+ } -+ -+ } -+ -+ // Insert fence and wait for all outstanding cp.async operations to commit. 
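// cp_async_fence() closes the current cp.async commit group (PTX cp.async.commit_group) and
// cp_async_wait<N>() blocks until at most N committed groups remain in flight (PTX
// cp.async.wait_group N). The mainloop above waits with kStages - 2 groups still outstanding;
// the <0> wait below drains every outstanding copy before shared memory is reused.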
-+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h -new file mode 100644 -index 0000000..8b5b111 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h -@@ -0,0 +1,470 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Templates calculating the address and predicates to the load of scale and bias vectors. -+ -+ This iterator uses masks to guard out-of-bounds accesses. -+ -+ A precomputed "Params" object minimizes the amount of state that must be -+ stored in registers, and integer addition is used to advance the pointer -+ through memory. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedScaleBiasVectorAccessIterator -+/// -+template -+class PredicatedScaleBiasVectorAccessIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for fprop pitch-linear data. -+/// -+template -+class PredicatedScaleBiasVectorAccessIterator { -+ public: -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ConstPointer = const Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ static int const kElementsPerAccess = 128 / sizeof_bits::value; -+ static int const kThreads = ThreadblockShape::kContiguous / kElementsPerAccess; -+ -+ using AccessType = AlignedArray; -+ -+ using Params = PredicatedScaleBiasVectorAccessIteratorParams; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Parameters object with precomputed internal state -+ Params const ¶ms_; -+ -+ /// Internal pointer to first access of tile -+ BytePointer pointer_; -+ -+ int problem_size_trs; -+ int problem_size_c; -+ int filter_trs_; -+ -+ TensorCoord thread_offset_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Extent of tensor -+ Conv2dProblemSize const &problem_size, -+ /// Pointer to the start of the scale vector -+ ConstPointer scale_pointer, -+ /// Pointer to the start of the bias vector -+ ConstPointer bias_pointer, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : params_(params), -+ problem_size_trs(problem_size.R * problem_size.S), -+ problem_size_c(problem_size.C), -+ filter_trs_(0) { -+ pointer_ = (thread_id < kThreads) -+ ? reinterpret_cast( -+ const_cast(scale_pointer)) -+ : reinterpret_cast( -+ const_cast(bias_pointer)); -+ -+ // Per-thread offset in logical coordinates of tensor -+ int thread_base = (thread_id < kThreads) ? 
0 : kThreads; -+ -+ thread_offset_ = -+ threadblock_offset + -+ TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0); -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Extent of tensor -+ Conv3dProblemSize const &problem_size, -+ /// Pointer to the start of the scale vector -+ ConstPointer scale_pointer, -+ /// Pointer to the start of the bias vector -+ ConstPointer bias_pointer, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : params_(params), -+ problem_size_trs(problem_size.T * problem_size.R * problem_size.S), -+ problem_size_c(problem_size.C), -+ filter_trs_(0) { -+ pointer_ = (thread_id < kThreads) -+ ? reinterpret_cast( -+ const_cast(scale_pointer)) -+ : reinterpret_cast( -+ const_cast(bias_pointer)); -+ -+ // Per-thread offset in logical coordinates of tensor -+ int thread_base = (thread_id < kThreads) ? 0 : kThreads; -+ -+ thread_offset_ = -+ threadblock_offset + -+ TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Extent of tensor -+ Conv2dProblemSize const &problem_size, -+ /// Pointer to start of scale vector -+ ConstPointer scale_pointer, -+ /// Pointer to start of scale vector -+ ConstPointer bias_pointer, -+ ///< ID of each participating thread -+ int thread_id) -+ : PredicatedScaleBiasVectorAccessIterator(params, problem_size, -+ scale_pointer, bias_pointer, -+ thread_id, make_Coord(0, 0)) {} -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Extent of tensor -+ Conv3dProblemSize const &problem_size, -+ /// Pointer to start of scale vector -+ ConstPointer scale_pointer, -+ /// Pointer to start of scale vector -+ ConstPointer bias_pointer, -+ ///< ID of each participating thread -+ int thread_id) -+ : PredicatedScaleBiasVectorAccessIterator(params, problem_size, -+ scale_pointer, bias_pointer, -+ thread_id, make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) {} -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole threadblock tiles -+ CUTLASS_DEVICE -+ void add_tile_offset( -+ TensorCoord const &tile_offset) { -+ thread_offset_ = -+ thread_offset_ + -+ TensorCoord(ThreadblockShape::kContiguous * tile_offset.contiguous(), 0); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ -+ return reinterpret_cast( -+ pointer_ + -+ (thread_offset_.contiguous() * sizeof_bits::value / 8)); -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator &operator++() { -+ return *this; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next tile -+ ++filter_trs_; -+ if (filter_trs_ == problem_size_trs) { -+ filter_trs_ = 0; -+ add_tile_offset(TensorCoord(1, 0)); -+ } -+ } -+ -+ /// Increment and return an instance to self. 
-+ CUTLASS_DEVICE -+ PredicatedScaleBiasVectorAccessIterator operator++(int) { -+ PredicatedScaleBiasVectorAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ uint32_t enabled = 0; -+ -+#if defined(_MSC_VER) || (__CUDACC_VER_MAJOR__ < 11) -+ enabled = threadIdx.x < kThreads * 2; -+#else -+ asm volatile( -+ "{\n" -+ " .reg .u32 tid_reg;\n" -+ " .reg .pred p;\n" -+ " mov.u32 tid_reg, %%tid.x;\n" -+ " setp.lt.u32 p, tid_reg, %1;\n" -+ " selp.u32 %0, 1, 0, p;\n" -+ "}\n" : "+r"(enabled) :"n"(kThreads * 2)); -+#endif -+ -+ return ((thread_offset_.contiguous() < problem_size_c) && enabled); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for row-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedScaleBiasVectorAccessIterator { -+ public: -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ConstPointer = const Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedScaleBiasVectorAccessIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; -+ -+ using Params = PredicatedScaleBiasVectorAccessIteratorParams; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Extent of tensor -+ Conv2dProblemSize const &problem_size, -+ ///< Pointer to the start of the scale vector -+ ConstPointer scale_pointer, -+ ///< Pointer to the start of the bias vector -+ ConstPointer bias_pointer, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params, problem_size, scale_pointer, bias_pointer, -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Extent of tensor -+ Conv3dProblemSize const &problem_size, -+ ///< Pointer to the start of the scale vector -+ ConstPointer scale_pointer, -+ ///< Pointer to the start of the bias vector -+ ConstPointer bias_pointer, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params, problem_size, scale_pointer, bias_pointer, -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ /// Construct a PredicatedTileAccessIterator with zero 
threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Conv2dProblemSize const &problem_size, ///< Extent of tensor -+ ConstPointer scale_pointer, ///< Pointer to the start of the scale vector -+ ConstPointer bias_pointer, ///< Pointer to the start of the bias vector -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedScaleBiasVectorAccessIterator(params, problem_size, -+ scale_pointer, bias_pointer, -+ thread_id, make_Coord(0, 0)) {} -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Conv3dProblemSize const &problem_size, ///< Extent of tensor -+ ConstPointer scale_pointer, ///< Pointer to the start of the scale vector -+ ConstPointer bias_pointer, ///< Pointer to the start of the bias vector -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedScaleBiasVectorAccessIterator(params, problem_size, -+ scale_pointer, bias_pointer, -+ thread_id, make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// threadblock tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator operator++(int) { -+ PredicatedScaleBiasVectorAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Increment and return an instance to self. 
-+ CUTLASS_HOST_DEVICE -+ void advance() { -+ iterator_.advance(); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h -new file mode 100644 -index 0000000..98b4c82 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h -@@ -0,0 +1,371 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Templates calculating the address and predicates to the load of scale and bias vectors. -+ -+ This iterator uses masks to guard out-of-bounds accesses. -+ -+ A precomputed "Params" object minimizes the amount of state that must be -+ stored in registers, and integer addition is used to advance the pointer -+ through memory. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedScaleBiasVectorIterator -+/// -+template -+class PredicatedScaleBiasVectorIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for wgrad pitch-linear data. -+/// -+template -+class PredicatedScaleBiasVectorIterator { -+ public: -+ -+ using WarpShape = WarpShape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ConstPointer = const Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ static int const kElementsPerAccess = 1; -+ -+ using AccessType = AlignedArray; -+ -+ static int const kIterations = WarpShape::kContiguous / 8; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array<__half2, 2 * kIterations * kElementsPerAccess>; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ using Params = Conv2dWgradActivationIteratorOptimizedParams; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Parameters object with precomputed internal state -+ Params const ¶ms_; -+ -+ /// Internal pointer to first access of tile -+ ConstPointer scale_pointer_; -+ ConstPointer bias_pointer_; -+ -+ /// Size of tensor -+ Conv2dProblemSize problem_size_; -+ -+ int32_t thread_offset_; -+ -+ // Channel dimension in contiguous dimension stays constant for each gemm_iteration_k -+ int32_t filter_c_[kIterations]; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Extent of tensor -+ Conv2dProblemSize const &problem_size, -+ /// Pointer to the start of the scale vector -+ ConstPointer scale_pointer, -+ /// Pointer to the start of the bias vector -+ ConstPointer bias_pointer, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : params_(params), -+ problem_size_(problem_size), -+ scale_pointer_(scale_pointer), -+ bias_pointer_(bias_pointer) { -+ -+ thread_offset_ = threadblock_offset.contiguous() + (thread_id % 32) / 4; -+ } -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Extent of tensor -+ Conv2dProblemSize const &problem_size, -+ /// Pointer to start of scale vector -+ ConstPointer scale_pointer, -+ /// Pointer to start of scale vector -+ ConstPointer bias_pointer, -+ ///< ID of each participating thread -+ int thread_id) -+ : PredicatedScaleBiasVectorIterator(params, problem_size, -+ 
scale_pointer, bias_pointer, -+ thread_id, make_Coord(0, 0)) {} -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole warp tiles -+ CUTLASS_DEVICE -+ void add_tile_offset( -+ TensorCoord const &tile_offset) { -+ -+ thread_offset_ += (WarpShape::kContiguous * tile_offset.contiguous()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int c = 0; c < kIterations; ++c) { -+ int rsc_offset = thread_offset_ + c * 8; -+ -+ int residual, tmp; -+ params_.sc_divmod(tmp, residual, rsc_offset); -+ params_.c_divmod(tmp, filter_c_[c], residual); -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ frag.fill(__float2half2_rn(0.0f)); -+ __half2 *frag_ptr = reinterpret_cast<__half2 *>(&frag); -+ -+ // load scale -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < kIterations; ++c) { -+ -+ cutlass::arch::global_load< -+ __half, -+ sizeof(AccessType) -+ >( -+ frag_ptr[c * 2].x, -+ scale_pointer_ + filter_c_[c], -+ true -+ ); -+ } -+ -+ // load bias -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < kIterations; ++c) { -+ -+ cutlass::arch::global_load< -+ __half, -+ sizeof(AccessType) -+ >( -+ frag_ptr[c * 2 + 1].x, -+ bias_pointer_ + filter_c_[c], -+ true -+ ); -+ } -+ -+ // duplicate scale -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < kIterations; ++c) { -+ frag_ptr[c * 2].y = frag_ptr[c * 2].x; -+ } -+ -+ // duplicate bias -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < kIterations; ++c) { -+ frag_ptr[c * 2 + 1].y = frag_ptr[c * 2 + 1].x; -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for row-major data. 
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedScaleBiasVectorIterator { -+ public: -+ -+ using WarpShape = WarpShape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ConstPointer = const Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedScaleBiasVectorIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; -+ using Fragment = typename UnderlyingIterator::Fragment; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedScaleBiasVectorIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Conv2dProblemSize const &problem_size, Layout const &layout) -+ : params_(problem_size, layout::TensorNHWC(0, 0, 0)){}; -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorIterator( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Extent of tensor -+ Conv2dProblemSize const &problem_size, -+ ///< Pointer to the start of the scale vector -+ ConstPointer scale_pointer, -+ ///< Pointer to the start of the bias vector -+ ConstPointer bias_pointer, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params.params_, problem_size, scale_pointer, bias_pointer, -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Conv2dProblemSize const &problem_size, ///< Extent of tensor -+ ConstPointer scale_pointer, ///< Pointer to the start of the scale vector -+ ConstPointer bias_pointer, ///< Pointer to the start of the bias vector -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedScaleBiasVectorIterator(params, problem_size, -+ scale_pointer, bias_pointer, -+ thread_id, make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// threadblock tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void 
load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ iterator_.load(frag); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/threadblock_swizzle.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/threadblock_swizzle.h -new file mode 100644 -index 0000000..0ed0b24 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/threadblock_swizzle.h -@@ -0,0 +1,193 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Implements several possible threadblock-swizzling functions mapping blockIdx to -+ Convolution problems. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/platform/platform.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+CUTLASS_HOST_DEVICE -+static int get_strided_dgrad_tile_m( -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ int tile_size_m) { -+ -+ // CTAs in M dimension per starting filter position -+ int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, tile_size_m); -+ -+ // Inflate number of CTAs in M dimension to cover every strating filter position even those that -+ // may fall out of valid MMA (Dy * w) but are needed to apply epilogue (beta * Dx_source) -+ // and point-wise fusion -+ int tile_m = tile_m_per_filter * int(problem_size.stride().product()); -+ -+ // There is a possible performance optimization here that leads up to 2x speeds than the current -+ // CUTLASS strided dgrad performance for stride > filter, i.e., stride={2x2} and filter={1x1}) -+ // -+ // * Optimization * -+ // Only launch CTAs in M dimenstion which contribute to a row in Dx output -+ // -+ // -+ // * Constraints * -+ // (A) stride <= filter, for example, stride={2x2} and filter={3x3}: -+ // - (A.1): There are no constraints for this case and the optimization does -+ // affect this case functionality or performance. -+ // (B) stride > filter, for example, stride={2x2} and filter={1x1}: -+ // - (B.1): Dx output tensor should be zero initialized -+ // - (B.2): The kernel epilogue cannot apply beta. 
Thus, beta should be zero -+ -+ return tile_m; -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Threadblock swizzling function for strided dgrad convolution -+struct StridedDgradHorizontalThreadblockSwizzle : -+ public gemm::threadblock::GemmHorizontalThreadblockSwizzle { -+ -+ using Base = gemm::threadblock::GemmHorizontalThreadblockSwizzle; -+ -+ CUTLASS_HOST_DEVICE -+ StridedDgradHorizontalThreadblockSwizzle() { } -+ -+ /// Returns the shape of the problem in units of logical tiles -+ /// For ImplicitGemmConvolution Conv2d problem size: conv_operator(NPQK, NHWC, KRSC) -+ CUTLASS_HOST_DEVICE -+ gemm::GemmCoord get_tiled_shape( -+ cutlass::conv::Operator conv_operator, -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ gemm::GemmCoord tile_size, -+ int split_k_slices) const { -+ -+ gemm::GemmCoord implicit_gemm_problem_size = -+ cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); -+ -+ // compute number of tiles in m dimension -+ int tile_m = get_strided_dgrad_tile_m(problem_size, tile_size.m()); -+ -+ // compute number of tiles in n dimenstion -+ int tile_n = (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(); -+ -+ return gemm::GemmCoord( -+ tile_m, -+ tile_n, -+ split_k_slices); -+ } -+ -+ /// Returns the shape of the problem in units of logical tiles -+ /// For GEMM problem size (MxNxK) (Do not use base class get_tiled_shape()) -+ private: -+ using Base::get_tiled_shape; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Threadblock swizzling function for strided dgrad convolution -+template -+struct StridedDgradIdentityThreadblockSwizzle : -+ public gemm::threadblock::GemmIdentityThreadblockSwizzle { -+ -+ using Base = gemm::threadblock::GemmIdentityThreadblockSwizzle; -+ -+ CUTLASS_HOST_DEVICE -+ StridedDgradIdentityThreadblockSwizzle() { } -+ -+ /// Returns the shape of the problem in units of logical tiles -+ /// For ImplicitGemmConvolution Conv2d problem size: conv_operator(NPQK, NHWC, KRSC) -+ CUTLASS_HOST_DEVICE -+ gemm::GemmCoord get_tiled_shape( -+ cutlass::conv::Operator conv_operator, -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ gemm::GemmCoord tile_size, -+ int split_k_slices) const { -+ -+ gemm::GemmCoord implicit_gemm_problem_size = -+ cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); -+ -+ // compute number of tiles in m dimension -+ int tile_m = get_strided_dgrad_tile_m(problem_size, tile_size.m()); -+ -+ // compute number of tiles in n dimenstion -+ int tile_n = (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(); -+ -+ return gemm::GemmCoord( -+ tile_m, -+ tile_n, -+ split_k_slices); -+ } -+ -+ /// Returns the shape of the problem in units of logical tiles -+ /// For GEMM problem size (MxNxK) (Do not use base class get_tiled_shape()) -+ private: -+ using Base::get_tiled_shape; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Threadblock swizzling function for GEMMs -+template -+struct DepthwiseDirect2dConvIdentityThreadblockSwizzle -+ : public gemm::threadblock::GemmIdentityThreadblockSwizzle { -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvIdentityThreadblockSwizzle() {} -+ -+ /// Returns the shape of the problem in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ 
gemm::GemmCoord get_tiled_shape(cutlass::conv::Operator conv_operator, -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ gemm::GemmCoord tile_size, -+ int split_k_slices) const { -+ -+ gemm::GemmCoord implicit_gemm_problem_size = -+ cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); -+ -+ return gemm::GemmCoord(1, -+ (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(), -+ split_k_slices); -+ } -+}; -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/conv/warp/mma_depthwise_simt.h b/3rdparty/cutlass/include/cutlass/conv/warp/mma_depthwise_simt.h -new file mode 100644 -index 0000000..ae49cc1 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/warp/mma_depthwise_simt.h -@@ -0,0 +1,381 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/thread/mma.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/thread/depthwise_mma.h" -+ -+ -+#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+ -+#include "cutlass/gemm/warp/mma_simt.h" -+#include "cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK = 1, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaDepthwiseSimt -+ : public cutlass::gemm::warp:: -+ MmaSimt { -+ using Base = cutlass::gemm::warp:: -+ MmaSimt; -+ -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = ElementA_; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = ElementB_; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = ElementC_; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassSimt; -+ -+ /// Hard-coded for now -+ using ArchTag = arch::Sm50; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = TransformB; -+ -+public: -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = cutlass::conv::warp::DepthwiseMmaSimtTileIterator< -+ MatrixShape, -+ cutlass::gemm::Operand::kB, -+ ElementB, -+ LayoutB, -+ Policy, -+ PartitionsK, -+ Shape::kK -+ >; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentB = FragmentB; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaDepthwiseSimt():Base() {} -+}; -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math 
instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Shape of filter shape per threadblock - concept: gemm::GemmShape -+ typename FilterShape_, -+ /// Shape of the output tile computed by thread- concept: conv::TensorNHWCShape<> -+ typename ThreadOutputShape_, -+ /// Shape of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> -+ typename ThreadBlockOutputShape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Iterator algo type -+ conv::IteratorAlgorithm IteratorAlgorithm_ = IteratorAlgorithm::kAnalytic, -+ /// Stride ( MatrixShape ) -+ typename StrideShape_ = cutlass::MatrixShape<-1, -1>, -+ /// Dilation ( MatrixShape ) -+ typename DilationShape_ = cutlass::MatrixShape<-1, -1>, -+ /// Activation Shape loaded by threadblock -+ typename ActivationShape_ = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>, -+ /// Number of partitions along K dimension -+ int PartitionsK = 1, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaDepthwiseDirectConvSimt { -+ public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Shape of filter shape per threadblock - concept: gemm::GemmShape -+ using FilterShape = FilterShape_; -+ -+ /// Shape of the output tile computed by thread- concept: conv::TensorNHWCShape<> -+ using ThreadOutputShape = ThreadOutputShape_; -+ -+ /// Shape of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> -+ using ThreadBlockOutputShape = ThreadBlockOutputShape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = ElementA_; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = ElementB_; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = ElementC_; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Iterator algo type -+ static conv::IteratorAlgorithm const IteratorAlgorithm = IteratorAlgorithm_; -+ -+ /// Stride ( MatrixShape ) -+ using StrideShape = StrideShape_; -+ -+ /// Dilation ( MatrixShape ) -+ using DilationShape = DilationShape_; -+ -+ /// Activation Shape loaded by threadblock -+ using ActivationShape = ActivationShape_; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassSimt; -+ -+ /// Hard-coded for now -+ using ArchTag = arch::Sm50; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = TransformB; -+ -+ static constexpr bool use_dp4a = (platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA>::value || -+ platform::is_same< 
layout::RowMajorInterleaved<4>, LayoutA >::value) && -+ platform::is_same< ElementA, int8_t >::value && -+ platform::is_same< ElementB, int8_t >::value; -+ -+ using dp4a_type = typename platform::conditional< use_dp4a , int8_t, bool >::type; -+ -+ /// Thread-level matrix multiply accumulate operator -+ using ThreadMma = cutlass::conv::thread::DepthwiseDirectConvElementwiseInnerProduct< -+ cutlass::gemm::GemmShape< -+ Shape::kM / Policy::WarpShape::kRow, // number of output pixels proccessed per thread -+ Shape::kN / Policy::WarpShape::kColumn, // number of channels proccessed per thread -+ 1>, -+ ElementA, -+ ElementB, -+ ElementC, -+ arch::OpMultiplyAdd, -+ dp4a_type -+ >; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename ThreadMma::ArchMmaOperator; -+ -+ /// Indicates math operator -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ /// Shape of the underlying instruction -+ using InstructionShape = cutlass::gemm::GemmShape<1,1,use_dp4a ? 4 : 1>; -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = cutlass::conv::warp::DepthwiseDirect2dConvSimtTileIterator< -+ MatrixShape, // per warp -+ FilterShape, -+ ThreadOutputShape, -+ ThreadBlockOutputShape, -+ cutlass::gemm::Operand::kA, -+ ElementA, -+ Policy, -+ IteratorAlgorithm, -+ StrideShape, -+ DilationShape, -+ ActivationShape, -+ PartitionsK, -+ Shape::kK -+ >; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentA = FragmentA; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = cutlass::gemm::warp::MmaSimtTileIterator< -+ MatrixShape<1, Shape::kN>, -+ cutlass::gemm::Operand::kB, -+ ElementB, -+ LayoutB, -+ Policy, -+ PartitionsK, -+ Shape::kK -+ >; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentB = FragmentB; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator< -+ MatrixShape, -+ cutlass::gemm::Operand::kC, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ /// Storage for C tile -+ using FragmentC = typename ThreadMma::FragmentC; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaDepthwiseDirectConvSimt() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA a, -+ FragmentB b, -+ FragmentC const &c, int group_idx = 0) const { -+ -+ ThreadMma mma; -+ -+ mma(d, a, b, c); -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ //TODO: Implement this -+ dst_A = A; -+ dst_B = B; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace conv -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h b/3rdparty/cutlass/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h -new file mode 100644 -index 0000000..b750a4b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h -@@ -0,0 +1,862 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 
- 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Describes the lane policy used by warp-level matrix multiply operators targeting SIMT -+ instructions -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/conv/convolution.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+ -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Iterates over operands to warp-level matrix multiply operations targeting SIMT instructions -+/// -+/// concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Operand identity -+ cutlass::gemm::Operand Operand, -+ /// Data type of A elements -+ typename Element_, -+ /// Layout of operand -+ typename Layout_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Number of partitions along K dimension - used in sliced-K -+ int PartitionsK = 1, -+ /// Group Size along kPartition - used in sliced-K -+ int PartitionGroupSize = 1 -+> -+class DepthwiseMmaSimtTileIterator; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for B operands of row-major layouts -+/// -+/// Concept: MutableRandomAccessContiguousTileIteratorConcept -+/// 
-+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Group Size along kPartition - used in sliced-K -+ int PartitionGroupSize> -+class DepthwiseMmaSimtTileIterator -+ : public cutlass::gemm::warp::MmaSimtTileIterator { -+ -+ using Base = cutlass::gemm::warp::MmaSimtTileIterator; -+ public: -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kB; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of policy -+ using Layout = layout::RowMajor; -+ -+ /// Decomposition of elements among threads -+ using Policy = Policy_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = typename Base::TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Thread-level shape of a fragment -+ using ThreadShape = typename Base::ThreadShape; -+ -+ /// Number of individual loads -+ using Iterations = typename Base::Iterations; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+ static_assert(Policy::LaneMmaShape::kN == 1, "Each thread should be 1 element per LDS along the k-dim"); -+ -+private: -+ -+ MatrixCoord lane_offset_; -+ int channel_idx_; -+ int base_channel_idx_; -+ int warps_n_; -+ -+ public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ DepthwiseMmaSimtTileIterator():Base() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ DepthwiseMmaSimtTileIterator( -+ TensorRef ref, -+ int lane_id -+ ) : Base(ref, lane_id) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ warps_n_ = -1; -+ channel_idx_ = 0; -+ base_channel_idx_ = 0; -+ lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ DepthwiseMmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { -+ -+ if(warps_n_ == -1){ -+ warps_n_ = coord.column(); -+ } -+ -+ Base::add_tile_offset(coord); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ Array *dst_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kRow; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kColumn; ++n) { -+ -+ void const *ptr = this->ref_.data() + -+ this->ref_.offset({-(channel_idx_ - base_channel_idx_), -+ n * Policy::WarpShape::kColumn}) + -+ pointer_offset / Policy::LaneMmaShape::kN; -+ -+ // Base_k of a warp + Base_k of current threads. 
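            // In other words: thread_k_base_idx is the first channel (k) owned by this
            // thread -- the warp's column offset (warps_n_ * Shape::kColumn /
            // Policy::LaneMmaShape::kN) plus the lane's column offset. Since a depthwise
            // kernel only multiplies activation channel c with filter channel c, the
            // branch below issues a shared-memory load only when the channel currently
            // being computed (channel_idx_ + k) matches the channel that fragment
            // column n maps to for this thread; every other slot is filled with zeros,
            // which also avoids redundant SMEM traffic.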
-+ int thread_k_base_idx = -+ warps_n_ * Shape::kColumn / Policy::LaneMmaShape::kN + lane_offset_.column(); -+ -+ if (channel_idx_ + k == thread_k_base_idx + n * Policy::WarpShape::kColumn) { -+ // Depthwise kernel would only do computation when channel == k. -+ // Loads an element when the current computation channel == the k corresponding to this thread. -+ arch::shared_load(dst_ptr[n + k * Iterations::kColumn], ptr); -+ } else { -+ // Reduce SMEM load -+ dst_ptr[n + k * Iterations::kColumn].fill(Element(0)); -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ if(k_group % PartitionGroupSize == 0 && k_group != 0){ -+ base_channel_idx_ = k_group; -+ } -+ channel_idx_ = k_group; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Size of filter (concept: gemm::GemmShape) -+ typename FilterShape_, -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename ThreadOutputShape_, -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename ThreadBlockOutputShape_, -+ /// Operand identity -+ cutlass::gemm::Operand Operand, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Iterator algo type -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, -+ /// Stride ( MatrixShape ) -+ typename StrideShape = cutlass::MatrixShape<-1, -1>, -+ /// Dilation ( MatrixShape ) -+ typename DilationShape = cutlass::MatrixShape<-1, -1>, -+ /// Activation Shape loaded by threadblock -+ typename ActivationShape = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>, -+ /// Number of partitions along K dimension - used in sliced-K -+ int PartitionsK = 1, -+ /// Group Size along kPartition - used in sliced-K -+ int PartitionGroupSize = 1> -+class DepthwiseDirect2dConvSimtTileIterator; -+ -+ -+/// Specialization for A operands of row-major layouts -+/// -+/// Concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Size of filter (concept: gemm::GemmShape) -+ typename FilterShape_, -+ /// Size of the matrix to load (concept: TensorNHWC) -+ typename ThreadOutputShape_, -+ /// Size of the matrix to load (concept: TensorNHWC) -+ typename ThreadBlockOutputShape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Iterator algo type -+ conv::IteratorAlgorithm IteratorAlgorithm, -+ /// Stride ( MatrixShape ) -+ typename StrideShape, -+ /// Dilation ( MatrixShape ) -+ typename DilationShape, -+ /// Activation Shape loaded by threadblock -+ typename ActivationShape, -+ /// Number of partitions along K dimension - used in sliced-K -+ int PartitionsK, -+ /// Group Size along kPartition - used in sliced-K -+ int PartitionGroupSize> -+class DepthwiseDirect2dConvSimtTileIterator { -+ public: -+ /// Shape of tile to load (concept: 
MatrixShape) -+ using Shape = Shape_; -+ -+ /// Shape of filter (concept: gemm::GemmShape) -+ using FilterShape = FilterShape_; -+ -+ /// Shape of tile to load (concept: TensorNHWC) -+ using ThreadOutputShape = ThreadOutputShape_; -+ -+ /// Shape of tile to load (concept: TensorNHWC) -+ using ThreadBlockOutputShape = ThreadBlockOutputShape_; -+ -+ /// Operand tag -+ static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kA; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of policy -+ using Layout = layout::RowMajor; -+ -+ /// Decomposition of elements among threads -+ using Policy = Policy_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ // -+ // Derived quantities -+ // -+ -+ static_assert(!(Shape::kRow % Policy::WarpShape::kRow), -+ "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension."); -+ -+ static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); -+ static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); -+ static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); -+ -+// Thread-level shape of a fragment -+ using ThreadShape = MatrixShape< -+ ThreadOutputShape::kNHW, // Output tile shape Computed by current threads -+ ThreadOutputShape::kC -+ >; -+ -+ static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN), -+ "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); -+ -+ /// Number of individual loads -+ using Iterations = MatrixShape< -+ ThreadShape::kRow, -+ ThreadShape::kColumn / Policy::LaneMmaShape::kN -+ >; -+ -+ using ThreadTileCount = MatrixShape< -+ ThreadBlockOutputShape::kH / ThreadOutputShape::kH, -+ ThreadBlockOutputShape::kW / ThreadOutputShape::kW -+ >; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+protected: -+ -+ /// Internal reference -+ cutlass::TensorRef, layout::RowMajor> ref_; -+ -+ int activation_offset[ThreadOutputShape::kH][ThreadOutputShape::kW][Iterations::kColumn]; -+ int iterator_r_; -+ int iterator_s_; -+ int iterator_offset_; -+ -+ int inc_next_s_ ; -+ int inc_next_r_ ; -+ -+ MatrixCoord lane_offset_; -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator( -+ TensorRef ref, -+ int lane_id -+ ) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ // Set channel offset -+ lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN); -+ -+ ref.add_coord_offset(lane_offset_); -+ -+ ref_.reset(reinterpret_cast *>(ref.data()), -+ ref.stride(0) / Policy::LaneMmaShape::kN); -+ -+ iterator_r_ = 0; -+ iterator_s_ = 0; -+ iterator_offset_ = 0; -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator &add_pointer_offset(LongIndex offset) { -+ 
ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ template -+ CUTLASS_HOST_DEVICE -+ void setup_initial_status(Params const& params) { -+ -+ inc_next_s_ = params.inc_next[0]; -+ inc_next_r_ = params.inc_next[1]; -+ -+ // Get base HW offset of current threads -+ int threadgroup = threadIdx.x / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC); -+ int base_p_ = -+ (threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH; -+ int base_q_ = -+ (threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < ThreadOutputShape::kH; ++p) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int q = 0; q < ThreadOutputShape::kW; ++q) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < Iterations::kColumn; ++col) { -+ int base_w = (base_q_ + q) * params.stride[0]; -+ int base_h = (base_p_ + p) * params.stride[1]; -+ -+ int offset = base_h * params.activation_tile_w + base_w; -+ activation_offset[p][q][col] = offset; -+ } -+ } -+ } -+ } -+ -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator &add_tile_offset(TensorCoord const &coord) { -+ // Set warp row and col start -+ lane_offset_ = MatrixCoord({lane_offset_.row() + coord.row() * Shape::kRow, lane_offset_.column()}); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ void advance(int32_t pointer_offset) { -+ ref_.reset(ref_.data() + pointer_offset / sizeof(Element) / Policy::LaneMmaShape::kN); -+ iterator_s_ = 0; -+ iterator_r_ = 0; -+ iterator_offset_ = 0; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator &operator++() { -+ ++iterator_s_; -+ if (iterator_s_ < FilterShape::kColumn) { -+ iterator_offset_ += inc_next_s_; -+ -+ return *this; -+ } -+ -+ iterator_s_ = 0; -+ -+ ++iterator_r_; -+ if (iterator_r_ < FilterShape::kRow) { -+ iterator_offset_ += inc_next_r_; -+ return *this; -+ } -+ -+ iterator_r_ = 0; -+ iterator_offset_ = 0; -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator & operator--() { -+ // Do nothing -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ Array *dst_ptr = -+ reinterpret_cast *>(&frag); -+ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < ThreadOutputShape::kH; ++p) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int q = 0; q < ThreadOutputShape::kW; ++q) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kColumn; ++n) { -+ void const *ptr = ref_.data() + -+ ref_.offset({activation_offset[p][q][n] + (iterator_offset_), -+ n * Policy::WarpShape::kColumn}) + -+ pointer_offset / Policy::LaneMmaShape::kN; -+ arch::shared_load(dst_ptr[n + q + p * ThreadOutputShape::kW], ptr); -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. 
-+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { -+ // Do nothing at present. -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag, Index pointer_offset) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/// Specialization for A operands of row-major layouts -+/// -+/// Concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Size of filter (concept: gemm::GemmShape) -+ typename FilterShape_, -+ /// Size of the matrix to load (concept: TensorNHWC) -+ typename ThreadOutputShape_, -+ /// Size of the matrix to load (concept: TensorNHWC) -+ typename ThreadBlockOutputShape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Stride ( MatrixShape ) -+ typename StrideShape_, -+ /// Dilation ( MatrixShape ) -+ typename DilationShape_, -+ /// Activation Shape loaded by threadblock -+ typename ActivationShape_, -+ /// Number of partitions along K dimension - used in sliced-K -+ int PartitionsK, -+ /// Group Size along kPartition - used in sliced-K -+ int PartitionGroupSize> -+class DepthwiseDirect2dConvSimtTileIterator { -+ public: -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Shape of filter (concept: gemm::GemmShape) -+ using FilterShape = FilterShape_; -+ -+ /// Shape of tile to load (concept: TensorNHWC) -+ using ThreadOutputShape = ThreadOutputShape_; -+ -+ /// Shape of tile to load (concept: TensorNHWC) -+ using ThreadBlockOutputShape = ThreadBlockOutputShape_; -+ -+ /// Stride ( MatrixShape ) -+ using StrideShape = StrideShape_; -+ -+ /// Dilation ( MatrixShape ) -+ using DilationShape = DilationShape_; -+ -+ /// Activation Shape loaded by threadblock -+ using ActivationShape = ActivationShape_; -+ -+ /// Operand tag -+ static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kA; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of policy -+ using Layout = layout::RowMajor; -+ -+ /// Decomposition of elements among threads -+ using Policy = Policy_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ // -+ // Derived quantities -+ // -+ -+ static_assert(!(Shape::kRow % Policy::WarpShape::kRow), -+ "The warp-level GEMM M size must be divisible by the number of threads arranged " -+ "along the M dimension."); -+ -+ static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); -+ static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); -+ static_assert(Shape::kRow / 
Policy::WarpShape::kRow > 0, -+ "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); -+ -+ // Activations loaded by threadblock -+ static int const ThreadActivationShapeH = (ThreadOutputShape::kH - 1) * StrideShape::kRow + -+ (FilterShape::kRow - 1) * DilationShape::kRow + 1; -+ -+ static int const ThreadActivationShapeW = (ThreadOutputShape::kW - 1) * StrideShape::kColumn + -+ (FilterShape::kColumn - 1) * DilationShape::kColumn + 1; -+ -+ using ThreadActivationShape = cutlass::conv:: -+ TensorNHWCShape<1, ThreadActivationShapeH, ThreadActivationShapeW, ThreadOutputShape::kC>; -+ -+ // Thread-level shape of a fragment -+ using ThreadShape = -+ MatrixShape; -+ -+ static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN), -+ "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); -+ -+ /// Number of individual loads -+ using Iterations = -+ MatrixShape; -+ -+ using ThreadTileCount = MatrixShape; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+ protected: -+ /// Internal reference -+ cutlass::TensorRef, layout::RowMajor> ref_; -+ -+ Array -+ activation[ThreadActivationShape::kH][ThreadActivationShape::kW][Iterations::kColumn]; -+ int iterator_r_; -+ int iterator_s_; -+ -+ -+ MatrixCoord lane_offset_; -+ -+ public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator() {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator(TensorRef ref, int lane_id) { -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ // Set channel offset -+ lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN); -+ -+ ref.add_coord_offset(lane_offset_); -+ -+ ref_.reset(reinterpret_cast *>(ref.data()), -+ ref.stride(0) / Policy::LaneMmaShape::kN); -+ -+ iterator_r_ = 0; -+ iterator_s_ = 0; -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. 
-+ template -+ CUTLASS_HOST_DEVICE void setup_initial_status( -+ Params const ¶ms) { -+ -+ // Get base HW offset of current threads -+ int threadgroup = threadIdx.x / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC); -+ int base_h = -+ (threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH * StrideShape::kRow; -+ int base_w = -+ (threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW * StrideShape::kColumn; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int h = 0; h < ThreadActivationShape::kH; ++h) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int w = 0; w < ThreadActivationShape::kW; ++w) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < Iterations::kColumn; ++col) { -+ int offset = (base_h + h) * ActivationShape::kW + (base_w + w); -+ -+ void const *ptr = ref_.data() + ref_.offset({offset, col * Policy::WarpShape::kColumn}); -+ arch::shared_load(activation[h][w][col], ptr); -+ } -+ } -+ } -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator &add_tile_offset(TensorCoord const &coord) { -+ // Set warp row and col start -+ lane_offset_ = -+ MatrixCoord({lane_offset_.row() + coord.row() * Shape::kRow, lane_offset_.column()}); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ void advance(int32_t pointer_offset) { -+ ref_.reset(ref_.data() + pointer_offset / sizeof(Element) / Policy::LaneMmaShape::kN); -+ iterator_s_ = 0; -+ iterator_r_ = 0; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator &operator++() { -+ ++iterator_s_; -+ if (iterator_s_ < FilterShape::kColumn) { -+ return *this; -+ } -+ -+ iterator_s_ = 0; -+ -+ ++iterator_r_; -+ if (iterator_r_ < FilterShape::kRow) { -+ return *this; -+ } -+ -+ iterator_r_ = 0; -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator &operator--() { -+ // Do nothing -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ Array *dst_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < ThreadOutputShape::kH; ++p) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int q = 0; q < ThreadOutputShape::kW; ++q) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kColumn; ++n) { -+ const int h = p * StrideShape::kRow + iterator_r_ * DilationShape::kRow; -+ const int w = q * StrideShape::kColumn + iterator_s_ * DilationShape::kColumn; -+ -+ dst_ptr[n + q + p * ThreadOutputShape::kW] = activation[h][w][n]; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { load_with_pointer_offset(frag, 0); } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { -+ // Do nothing at present. 
-+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag, Index pointer_offset) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+} // namespace warp -+} // namespace conv -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/conv/warp/scale_bias_relu_transform.h b/3rdparty/cutlass/include/cutlass/conv/warp/scale_bias_relu_transform.h -new file mode 100644 -index 0000000..a1a4dff ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/warp/scale_bias_relu_transform.h -@@ -0,0 +1,223 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level per channel scale+bias+relu before -+ matrix multiply-accumulate operations targeting Tensor Cores. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/arch/mma_sm75.h" -+#include "cutlass/arch/mma_sm80.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct FpropScaleBiasReluTransform { -+ -+ using T = typename FragmentActivations::Element; -+ -+ static int const NumActivations = FragmentActivations::kElements; -+ static int const NumScaleBias = FragmentScaleBias::kElements; -+ static int const MmaElements = 2; -+ // One element has one scale and one bias -+ static int const MmaScaleBiasPair = 2; -+ // 16816 has 2 columns -+ static int const MmaCols = 2; -+ -+ using MmaOperand = Array; -+ using ScaleBiasOperand = Array; -+ -+ CUTLASS_DEVICE -+ void transform(MmaOperand &activations, ScaleBiasOperand const &scale_bias) { -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) -+ uint32_t *ptr_activations = reinterpret_cast(&activations); -+ uint32_t const *ptr_scale_bias = reinterpret_cast(&scale_bias); -+ -+ // Apply per channel scale+bias+relu if the data is not a special NaN -+ // (0x7eff). If it is a special NaN (0x7eff), hard code the output to 0. -+ -+ // We assumes the pair of FP16 are either both inbound or both out-of-bound. -+ // It requires C to be an even number. 
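      // Step-by-step reading of the inline PTX below:
      //   setp.eq.u32       p, act, OOB_NAN_F16x2   -> p = (the packed FP16 pair equals the 0x7eff sentinel)
      //   fma.rn.f16x2.relu t1, scale, act, bias    -> t1 = max(scale * act + bias, 0) on both halves
      //   selp.u32          out, 0, t1, p           -> out = p ? 0 : t1
      // i.e. scale+bias+ReLU is applied to the activation pair unless that pair was
      // tagged out-of-bounds with the sentinel NaN, in which case the result is forced to zero.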
-+ asm volatile( -+ "{\n\t" -+ " .reg .pred %%p;\n\t" -+ " .reg .b32 t1;\n\t" -+ " setp.eq.u32 %%p, %2, %4;\n\t" -+ " fma.rn.f16x2.relu t1, %1, %2, %3;\n" -+ " selp.u32 %0, 0, t1, %%p;\n\t" -+ "}\n" -+ : "=r"(ptr_activations[0]) -+ : "r"(ptr_scale_bias[0]), "r"(ptr_activations[0]), -+ "r"(ptr_scale_bias[1]), "n"(cutlass::arch::OOB_NAN_F16x2)); -+#else -+ // TODO: write emulation code -+ assert(0); -+#endif -+ } -+ -+ CUTLASS_DEVICE -+ void operator()(FragmentActivations &activations, -+ FragmentScaleBias const &scale_bias) { -+ MmaOperand *ptr_activations = reinterpret_cast(&activations); -+ ScaleBiasOperand const *ptr_scale_bias = -+ reinterpret_cast(&scale_bias); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < (NumActivations / MmaElements); ++i) { -+ transform(ptr_activations[i], ptr_scale_bias[(i / MmaScaleBiasPair) % MmaCols]); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct WgradScaleBiasReluTransform { -+ -+ using T = typename FragmentActivations::Element; -+ -+ static int const NumActivations = FragmentActivations::kElements; -+ static int const NumScaleBias = FragmentScaleBias::kElements; -+ static int const MmaElements = 2; -+ // One element has one scale and one bias -+ static int const MmaScaleBiasPair = 2; -+ // 16816 has 2 rows -+ static int const MmaRows = 2; -+ -+ using MmaOperand = Array; -+ using ScaleBiasOperand = Array<__half2, MmaScaleBiasPair>; -+ -+ CUTLASS_DEVICE -+ void transform(MmaOperand &activations, ScaleBiasOperand const &scale_bias) { -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) -+ -+ __half2 *ptr_activations = reinterpret_cast<__half2 *>(&activations); -+ uint32_t const *ptr_scale_bias = reinterpret_cast(&scale_bias); -+ -+#if 1 -+ // CUDA + PTX version -+ -+ bool h1_oob = (reinterpret_cast(ptr_activations[0].x) == cutlass::arch::OOB_NAN_F16); -+ bool h2_oob = (reinterpret_cast(ptr_activations[0].y) == cutlass::arch::OOB_NAN_F16); -+ -+ // Apply per channel scale+bias+relu if the data is not a special NaN -+ // (0x7eff). If it is a special NaN (0x7eff), hard code the output to 0. -+ -+ // We cannot gurantee that the pair of F16 are both in bound or both -+ // out-of-bound because C x R x S can be an odd number. -+ asm volatile( -+ "{\n\t" -+ " fma.rn.f16x2.relu %0, %1, %2, %3;\n" -+ "}" -+ : "=r"(reinterpret_cast(ptr_activations[0])) -+ : "r"(ptr_scale_bias[0]), "r"(reinterpret_cast(ptr_activations[0])), -+ "r"(ptr_scale_bias[1])); -+ -+ reinterpret_cast(ptr_activations[0]) = h1_oob ? -+ (reinterpret_cast(ptr_activations[0]) & 0xffff0000) : -+ reinterpret_cast(ptr_activations[0]); -+ -+ reinterpret_cast(ptr_activations[0]) = h2_oob ? -+ (reinterpret_cast(ptr_activations[0]) & 0xffff) : -+ reinterpret_cast(ptr_activations[0]); -+#else -+ // pure PTX version -+ -+ // Apply per channel scale+bias+relu if the data is not a special NaN -+ // (0x7eff). If it is a special NaN (0x7eff), hard code the output to 0. 
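      // Unlike the fprop transform above, the two FP16 halves are tested separately
      // here (mov.b32 splits the pair, then setp.eq.s16 checks each half against the
      // out-of-bounds sentinel), because C x R x S may be odd. After
      // fma.rn.f16x2.relu, the and.b32 / selp.b32 pairs clear only the half (or
      // halves) whose input carried the sentinel, leaving the other half's
      // scale+bias+ReLU result intact.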
-+ asm volatile( -+ "{\n" -+ " .reg .b16 t1, t2;\n" -+ " .reg .b32 t3, t4, t5, t6;\n" -+ " .reg .pred p1, p2;\n" -+ " mov.b32 {t1, t2}, %2;\n" -+ " setp.eq.s16 p1, t1, %4;\n" -+ " setp.eq.s16 p2, t2, %4;\n" -+ " fma.rn.f16x2.relu t3, %1, %2, %3;\n" -+ " and.b32 t4, t3, %5;\n" -+ " selp.b32 t5, t4, t3, p1;\n" -+ " and.b32 t6, t5, %6;\n" -+ " selp.b32 %0, t6, t5, p2;\n" -+ "}\n" -+ : "=r"(reinterpret_cast(ptr_activations[0])) -+ : "r"(ptr_scale_bias[0]), "r"(reinterpret_cast(ptr_activations[0])), -+ "r"(ptr_scale_bias[1]), "n"(cutlass::arch::OOB_NAN_F16), "n"(0xffff0000), "n"(0x0000ffff)); -+#endif -+#else -+ // TODO: write emulation code -+ assert(0); -+#endif -+ } -+ -+ CUTLASS_DEVICE -+ void operator()(FragmentActivations &activations, -+ FragmentScaleBias const &scale_bias) { -+ MmaOperand *ptr_activations = reinterpret_cast(&activations); -+ ScaleBiasOperand const *ptr_scale_bias = -+ reinterpret_cast(&scale_bias); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < (NumActivations / MmaElements); ++i) { -+ transform(ptr_activations[i], ptr_scale_bias[(i / MmaRows)]); -+ } -+ } -+}; -+} // namespace warp -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/coord.h b/3rdparty/cutlass/include/cutlass/coord.h -new file mode 100644 -index 0000000..4558385 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/coord.h -@@ -0,0 +1,480 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief A Coord is a coordinate of arbitrary rank into a tensor or matrix -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Statically-sized array specifying Coords within a tensor -+template < -+ int Rank_, ///< Logical rank of coordinate -+ typename Index_ = int, ///< Index type used for each dimension -+ typename LongIndex_ = int64_t ///< Long index type used for linear offsets -+> -+struct Coord { -+ -+public: -+ -+ // -+ // Type and constant definitions -+ // -+ -+ /// Number of elements in Coord -+ static int const kRank = Rank_; -+ -+ /// Index type used to store elements -+ using Index = Index_; -+ -+ /// Type used to represent linear offsets -+ using LongIndex = LongIndex_; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Indices -+ Index idx[kRank]; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor initializes uniformly -+ CUTLASS_HOST_DEVICE -+ explicit Coord(Index value = Index(0)) { -+ for (int i = 0; i < kRank; ++i) { -+ idx[i] = value; -+ } -+ } -+ -+ /// Constructs from an array of integers -+ CUTLASS_HOST_DEVICE -+ Coord(Index const (&_idx)[kRank]) { -+ for (int i = 0; i < kRank; ++i) { -+ idx[i] = _idx[i]; -+ } -+ } -+ -+ /// Constructs from some other Coord -+ template -+ CUTLASS_HOST_DEVICE -+ Coord(Coord other) { -+ for (int i = 0; i < kRank; ++i) { -+ idx[i] = other[i]; -+ } -+ } -+ -+ /// Returns a slice of the Coord which may be larger or smaller in rank -+ /// than this. -+ template -+ CUTLASS_HOST_DEVICE -+ Coord slice(int start = 0, Index identity = 0) const { -+ Coord result; -+ for (int i = 0; i < Slice; ++i) { -+ if (i + start < kRank) { -+ result[i] = idx[i + start]; -+ } -+ else { -+ result[i] = identity; -+ } -+ } -+ return result; -+ } -+ -+ /// Returns the index of the dimension with least value -+ CUTLASS_HOST_DEVICE -+ int min_dim_index() const { -+ int i = 0; -+ for (int j = 1; j < kRank; ++j) { -+ if (idx[j] < idx[i]) { -+ i = j; -+ } -+ } -+ return i; -+ } -+ -+ /// Returns the index of the dimension with greatest value -+ CUTLASS_HOST_DEVICE -+ int max_dim_index() const { -+ int i = 0; -+ for (int j = 1; j < kRank; ++j) { -+ if (idx[j] > idx[i]) { -+ i = j; -+ } -+ } -+ return i; -+ } -+ -+ /// Returns true if Coord is non-zero. -+ CUTLASS_HOST_DEVICE -+ explicit operator bool() const { -+ for (int i = 0; i < kRank; ++i) { -+ if (idx[i]) { -+ return true; -+ } -+ } -+ return false; -+ } -+ -+ /// Returns true if Coord is uniformly zero. 
-+ CUTLASS_HOST_DEVICE -+ bool operator!() const { -+ for (int i = 0; i < kRank; ++i) { -+ if (idx[i]) { -+ return false; -+ } -+ } -+ return true; -+ } -+ -+ /// Element-wise addition -+ CUTLASS_HOST_DEVICE -+ Coord operator+(Coord const& b) const { -+ Coord c; -+ for (int i = 0; i < kRank; ++i) { -+ c.idx[i] = idx[i] + b.idx[i]; -+ } -+ return c; -+ } -+ -+ /// Element-wise subtraction -+ CUTLASS_HOST_DEVICE -+ Coord operator-(Coord const& b) const { -+ Coord c; -+ for (int i = 0; i < kRank; ++i) { -+ c.idx[i] = idx[i] - b.idx[i]; -+ } -+ return c; -+ } -+ -+ /// Element-wise multiplication -+ CUTLASS_HOST_DEVICE -+ Coord operator*(Coord const& b) const { -+ Coord c; -+ for (int i = 0; i < kRank; ++i) { -+ c.idx[i] = idx[i] * b.idx[i]; -+ } -+ return c; -+ } -+ -+ /// Element-wise division -+ CUTLASS_HOST_DEVICE -+ Coord operator/(Coord const& b) const { -+ Coord c; -+ for (int i = 0; i < kRank; ++i) { -+ c.idx[i] = idx[i] / b.idx[i]; -+ } -+ return c; -+ } -+ -+ /// In-place addition -+ CUTLASS_HOST_DEVICE -+ Coord& operator+=(Coord const& b) { -+ for (int i = 0; i < kRank; ++i) { -+ idx[i] += b.idx[i]; -+ } -+ return *this; -+ } -+ -+ /// In-place subtraction -+ CUTLASS_HOST_DEVICE -+ Coord& operator-=(Coord const& b) { -+ for (int i = 0; i < kRank; ++i) { -+ idx[i] -= b.idx[i]; -+ } -+ return *this; -+ } -+ -+ /// In-place multiplication -+ CUTLASS_HOST_DEVICE -+ Coord& operator*=(Coord const& b) { -+ for (int i = 0; i < kRank; ++i) { -+ idx[i] *= b.idx[i]; -+ } -+ return *this; -+ } -+ -+ /// In-place division -+ CUTLASS_HOST_DEVICE -+ Coord& operator/=(Coord const& b) { -+ for (int i = 0; i < kRank; ++i) { -+ idx[i] /= b.idx[i]; -+ } -+ return *this; -+ } -+ -+ /// Member access operator -+ CUTLASS_HOST_DEVICE Index& operator[](int dim) { return idx[dim]; } -+ -+ /// Member access operator -+ CUTLASS_HOST_DEVICE Index const& operator[](int dim) const { return idx[dim]; } -+ -+ /// Computes the dot product with anotherCoord object -+ CUTLASS_HOST_DEVICE -+ LongIndex dot(Coord const& b, LongIndex sum = LongIndex(0)) const { -+ for (int i = 0; i < kRank; ++i) { -+ sum += idx[i] * b.idx[i]; -+ } -+ return sum; -+ } -+ -+ /// Gets the index of a given Coord element -+ template -+ CUTLASS_HOST_DEVICE Index& at() { -+ return idx[Dim]; -+ } -+ -+ /// Access via index; may limit unrolling potential -+ CUTLASS_HOST_DEVICE -+ Index& at(int dim) { return idx[dim]; } -+ -+ /// Gets the index of a given Coord element -+ template -+ CUTLASS_HOST_DEVICE Index const& at() const { -+ return idx[Dim]; -+ } -+ -+ /// Access via index; may limit unrolling potential -+ CUTLASS_HOST_DEVICE -+ Index const& at(int dim) const { return idx[dim]; } -+ -+ /// Determines if two Coord<> objects are equal -+ CUTLASS_HOST_DEVICE -+ bool operator==(Coord const& b) const { -+ bool equal = true; -+ for (int i = 0; equal && i < kRank; ++i) { -+ equal = (idx[i] == b.idx[i]); -+ } -+ return equal; -+ } -+ -+ /// Not equal -+ CUTLASS_HOST_DEVICE -+ bool operator!=(Coord const& b) const { return !(*this == b); } -+ -+ /// Clamps a coordinate to a range specified by maximum and minimum values -+ CUTLASS_HOST_DEVICE -+ Coord& clamp(Coord const& max, Coord const& min = Coord()) { -+ for (int i = 0; i < kRank; ++i) { -+ idx[i] = __NV_STD_MAX(__NV_STD_MIN(idx[i], max.idx[i]), min.idx[i]); -+ } -+ return *this; -+ } -+ -+ /// Returns the sum of all elements -+ CUTLASS_HOST_DEVICE -+ Index sum() const { -+ Index sum_(idx[0]); -+ for (int i = 1; i < kRank; ++i) { -+ sum_ += idx[i]; -+ } -+ return sum_; -+ } -+ -+ /// Returns the 
product of all elements
-+  CUTLASS_HOST_DEVICE
-+  LongIndex product() const {
-+    LongIndex product_(idx[0]);
-+    for (int i = 1; i < kRank; ++i) {
-+      product_ *= idx[i];
-+    }
-+    return product_;
-+  }
-+
-+  /// Less than operator
-+  CUTLASS_HOST_DEVICE
-+  bool operator<(Coord const &b) const {
-+    for (int i = 0; i < kRank; ++i) {
-+      if (!(idx[i] < b[i])) {
-+        return false;
-+      }
-+    }
-+    return true;
-+  }
-+
-+  /// Less than or equals operator
-+  CUTLASS_HOST_DEVICE
-+  bool operator<=(Coord const &b) const {
-+    for (int i = 0; i < kRank; ++i) {
-+      if (!(idx[i] <= b[i])) {
-+        return false;
-+      }
-+    }
-+    return true;
-+  }
-+
-+  /// Greater than operator
-+  CUTLASS_HOST_DEVICE
-+  bool operator>(Coord const &b) const {
-+    return !(*this <= b);
-+  }
-+
-+  /// Greater than or equals operator
-+  CUTLASS_HOST_DEVICE
-+  bool operator>=(Coord const &b) const {
-+    return !(*this < b);
-+  }
-+};
-+
-+} // namespace cutlass
-+
-+////////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+namespace cutlass {
-+
-+
-+/// Scalar multiplication
-+template <int Rank, typename Index>
-+CUTLASS_HOST_DEVICE
-+Coord<Rank, Index> operator*(Index s, Coord<Rank, Index> coord) {
-+  CUTLASS_PRAGMA_UNROLL
-+  for (int i = 0; i < Rank; ++i) {
-+    coord[i] *= s;
-+  }
-+  return coord;
-+}
-+
-+/// Scalar multiplication
-+template <int Rank, typename Index>
-+CUTLASS_HOST_DEVICE
-+Coord<Rank, Index> operator*(Coord<Rank, Index> coord, Index s) {
-+  CUTLASS_PRAGMA_UNROLL
-+  for (int i = 0; i < Rank; ++i) {
-+    coord[i] *= s;
-+  }
-+  return coord;
-+}
-+
-+/// Scalar division
-+template <int Rank, typename Index>
-+CUTLASS_HOST_DEVICE
-+Coord<Rank, Index> operator/(Index s, Coord<Rank, Index> coord) {
-+  CUTLASS_PRAGMA_UNROLL
-+  for (int i = 0; i < Rank; ++i) {
-+    coord[i] = s / coord[i];
-+  }
-+  return coord;
-+}
-+
-+/// Scalar division
-+template <int Rank, typename Index>
-+CUTLASS_HOST_DEVICE
-+Coord<Rank, Index> operator/(Coord<Rank, Index> coord, Index s) {
-+  CUTLASS_PRAGMA_UNROLL
-+  for (int i = 0; i < Rank; ++i) {
-+    coord[i] /= s;
-+  }
-+  return coord;
-+}
-+
-+////////////////////////////////////////////////////////////////////////////////////////////////////
-+//
-+// Integer-valued make_Coord
-+//
-+////////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+/// Helper to make a 1-element coordinate
-+template <typename T>
-+CUTLASS_HOST_DEVICE
-+Coord<1, T> make_Coord(T _0) {
-+  T values[1] = {_0};
-+  return Coord<1, T>(values);
-+}
-+
-+/// Helper to make a 2-element coordinate
-+template <typename T>
-+CUTLASS_HOST_DEVICE
-+Coord<2, T> make_Coord(T _0, T _1) {
-+  T values[2] = {_0, _1};
-+  return Coord<2, T>(values);
-+}
-+
-+/// Helper to make a 3-element coordinate
-+template <typename T>
-+CUTLASS_HOST_DEVICE
-+Coord<3, T> make_Coord(T _0, T _1, T _2) {
-+  T values[3] = {_0, _1, _2};
-+  return Coord<3, T>(values);
-+}
-+
-+/// Helper to make a 4-element coordinate
-+template <typename T>
-+CUTLASS_HOST_DEVICE
-+Coord<4, T> make_Coord(T _0, T _1, T _2, T _3) {
-+  T values[4] = {_0, _1, _2, _3};
-+  return Coord<4, T>(values);
-+}
-+
-+/// Helper to make a 5-element coordinate
-+template <typename T>
-+CUTLASS_HOST_DEVICE
-+Coord<5, T> make_Coord(T _0, T _1, T _2, T _3, T _4) {
-+  T values[5] = {_0, _1, _2, _3, _4};
-+  return Coord<5, T>(values);
-+}
-+
-+/// Helper to make a rank-N coordinate from a single leading element, zero-padding the rest
-+template <int N, typename T>
-+CUTLASS_HOST_DEVICE
-+Coord<N, T> make_Coord_with_padding(T _0) {
-+  Coord<N, T> coord;
-+
-+  CUTLASS_PRAGMA_UNROLL
-+  for (int i = N - 1; i > 0; --i) {
-+    coord[i] = 0;
-+  }
-+
-+  coord[0] = _0;
-+
-+  return coord;
-+}
-+
-+////////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+} // namespace cutlass
-+
-diff --git
a/3rdparty/cutlass/include/cutlass/core_io.h b/3rdparty/cutlass/include/cutlass/core_io.h -new file mode 100644 -index 0000000..4d15432 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/core_io.h -@@ -0,0 +1,289 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Helpers for printing cutlass/core objects -+*/ -+ -+#pragma once -+ -+#include -+#include -+ -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix.h" -+#include "cutlass/quaternion.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Output operator for CUDA built-in dim3 type -+inline std::ostream &operator<<(std::ostream &out, dim3 d) { -+ return out << d.x << ", " << d.y << ", " << d.z; -+} -+ -+/// Output operator for CUDA built-in error type -+inline std::ostream &operator<<(std::ostream &out, cudaError_t error) { -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) -+ return out << cudaGetErrorString(error); -+#endif -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// stream operators for cutlass namespace // -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+inline -+std::ostream& operator<<(std::ostream& out, Array const& v) { -+ for (int i = 0; i < Rank; ++i) { -+ out << (i ? ", " : "") << v[i]; -+ } -+ return out; -+} -+ -+template -+inline -+std::ostream& operator<<(std::ostream& out, Coord const& coord) { -+ for (int i = 0; i < Rank; ++i) { -+ out << (i ? 
", " : "") << coord[i]; -+ } -+ return out; -+} -+ -+inline -+std::istream & operator>>(std::istream &stream, half_t &x) { -+ float tmp; -+ stream >> tmp; -+ x = static_cast(tmp); -+ return stream; -+} -+ -+inline -+std::ostream & operator<<(std::ostream &out, half_t const &x) { -+ return out << float(x); -+} -+ -+inline -+std::ostream & operator<<(std::ostream &out, bfloat16_t const &x) { -+ return out << float(x); -+} -+ -+inline -+std::ostream & operator<<(std::ostream &out, tfloat32_t const &x) { -+ return out << float(x); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper to enable formatted printing of CUTLASS scalar types to an ostream -+template -+struct ScalarIO { -+ -+ /// Value to print -+ T value; -+ -+ /// Default ctor -+ ScalarIO() { } -+ -+ /// Constructs from a value -+ ScalarIO(T value): value(value) {} -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Default printing to ostream -+template -+inline std::ostream &operator<<(std::ostream &out, ScalarIO const &scalar) { -+ return out << scalar.value; -+} -+ -+/// Printing to ostream of int8_t as integer rather than character -+template <> -+inline std::ostream &operator<<(std::ostream &out, ScalarIO const &scalar) { -+ return out << int(scalar.value); -+} -+ -+/// Printing to ostream of uint8_t as integer rather than character -+template <> -+inline std::ostream &operator<<(std::ostream &out, ScalarIO const &scalar) { -+ return out << unsigned(scalar.value); -+} -+ -+ -+/// Default printing to ostream for MatrixShape -+template -+inline -+std::ostream & operator<<(std::ostream &out, MatrixShape const &matrix_shape) { -+ out << "cutlass::MatrixShape::(kRow, kColumn) {" -+ << cutlass::MatrixShape::kRow <<"," -+ << cutlass::MatrixShape::kColumn <<"}"; -+ return out; -+} -+ -+ -+/// Prints matrix to ostream -+template -+std::ostream & operator<<(std::ostream &out, Matrix const &rhs) { -+ -+ for (int i = 0; i < Rows; ++i) { -+ for (int j = 0; j < Columns; ++j) { -+ ScalarIO element(rhs.at(i, j)); -+ out << (j ? 
", " : "") << element; -+ } -+ out << "\\n"; -+ } -+ -+ return out; -+} -+ -+template -+std::ostream &operator<<(std::ostream &out, Quaternion const &rhs) { -+ -+ out << ScalarIO(rhs.w()) << " "; -+ if (rhs.x() >= 0) { -+ out << "+"; -+ } -+ -+ out << ScalarIO(rhs.x()) << "*i "; -+ if (rhs.y() >= 0) { -+ out << "+"; -+ } -+ -+ out << ScalarIO(rhs.y()) << "*j "; -+ if (rhs.z() >= 0) { -+ out << "+"; -+ } -+ -+ out << ScalarIO(rhs.z()) << "*k"; -+ -+ return out; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// stream operators for cutlass::gemm namespace // -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+namespace gemm { -+ -+/// Default printing to ostream for GemmShape -+template -+inline -+std::ostream & operator<<(std::ostream &out, GemmShape const &gemm_shape) { -+ out << "cutlass::gemm::GemmShape::(kM, kN, kK) {" -+ << cutlass::gemm::GemmShape::kM <<"," -+ << cutlass::gemm::GemmShape::kN <<"," -+ << cutlass::gemm::GemmShape::kK << "}"; -+ return out; -+} -+ -+/// Default printing to ostream for GemmCoord -+inline -+std::ostream & operator<<(std::ostream &out, GemmCoord const &gemm_coord) { -+ out << "cutlass::gemm::GemmCoord {" -+ << gemm_coord.m() <<"," -+ << gemm_coord.n() <<"," -+ << gemm_coord.k() << "}"; -+ return out; -+} -+ -+} //namespace gemm -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// stream operators for cutlass namespace // -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Default printing to ostream for PitchLinearShape -+template < int Contiguous, int Strided> -+inline -+std::ostream & operator<<(std::ostream &out, PitchLinearShape const &pitch_linear_shape) { -+ out << "cutlass::PitchLinearShape:(kContiguous, kStrided) {" -+ << cutlass::layout::PitchLinearShape::kContiguous <<"," -+ << cutlass::layout::PitchLinearShape::kStrided <<"}"; -+ return out; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// stream operators for cutlass::conv namespace // -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+namespace conv { -+/// Default printing to ostream for Conv2dProblemSize -+inline -+std::ostream& operator<<(std::ostream& out, Conv2dProblemSize const& problem) { -+ out << "NHWC: (" << problem.N << ", " << problem.H << ", " << problem.W << ", " << problem.C << ")" << std::endl -+ << "KRSC: (" << problem.K << ", " << problem.R << ", " << problem.S << ", " << problem.C / problem.groups << ")" << std::endl -+ << "NPQK: (" << problem.N << ", " << problem.P << ", " << problem.Q << ", " << problem.K << ")" << std::endl -+ << "groups: (" << problem.groups << ")" << std::endl -+ << "Pad_h, Pad_w: (" << problem.pad_h << ", " << problem.pad_w << ")" << std::endl -+ << "Stride_h, Stride_w: (" << problem.stride_h << ", " << problem.stride_w << ")" << std::endl -+ << "Dilation_h, Dilation_w: (" << problem.dilation_h << ", " << problem.dilation_w << ")" << std::endl -+ << "split_k_slices: (" << problem.split_k_slices << ")" << std::endl -+ << "mode: (" << ((problem.mode==conv::Mode::kConvolution) ? 
"conv" : "xcross") << ")"; -+ -+ return out; -+} -+ -+ -+/// Default printing to ostream for Conv3dProblemSize -+inline -+std::ostream& operator<<(std::ostream& out, Conv3dProblemSize const& problem) { -+ out << "NDHWC: (" << problem.N << ", " << problem.D << ", " << problem.H << ", " << problem.W << ", " << problem.C << ")" << std::endl -+ << "KTRSC: (" << problem.K << ", " << problem.T << ", " << problem.R << ", " << problem.S << ", " << problem.C << ")" << std::endl -+ << "NZPQK: (" << problem.N << ", " << problem.Z << ", " << problem.P << ", " << problem.Q << ", " << problem.K << ")" << std::endl -+ << "pad_d, pad_h, pad_w: (" << problem.pad_d << ", " << problem.pad_h << ", " << problem.pad_w << ")" << std::endl -+ << "stride_d, stride_h, stride_w: (" << problem.stride_d << ", " << problem.stride_h << ", " << problem.stride_w << ")" << std::endl -+ << "dilation_d, dilation_h, dilation_w: (" << problem.dilation_d << ", " << problem.dilation_h << ", " << problem.dilation_w << ")" << std::endl -+ << "split_k_slices: (" << problem.split_k_slices << ") " << std::endl -+ << "mode: (" << ((problem.mode==conv::Mode::kConvolution) ? "conv" : "xcross") << ")"; -+ -+ return out; -+} -+ -+} // namespace conv -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/cutlass.h b/3rdparty/cutlass/include/cutlass/cutlass.h -new file mode 100644 -index 0000000..12bc3a3 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/cutlass.h -@@ -0,0 +1,227 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Basic include for CUTLASS. -+*/ -+ -+#pragma once -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#ifdef CUTLASS_NAMESPACE -+#define concat_tok(a, b) a ## b -+#define mkcutlassnamespace(pre, ns) concat_tok(pre, ns) -+#define cutlass mkcutlassnamespace(cutlass_, CUTLASS_NAMESPACE) -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) -+#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__ -+#define CUTLASS_DEVICE __forceinline__ __device__ -+#elif defined(__CUDACC_RTC__) -+#define CUTLASS_HOST_DEVICE __forceinline__ __device__ -+#define CUTLASS_DEVICE __forceinline__ __device__ -+#else -+#define CUTLASS_HOST_DEVICE inline -+#define CUTLASS_DEVICE inline -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+CUTLASS_HOST_DEVICE void __CUTLASS_UNUSED(T const &) -+{ } -+ -+#if defined(__GNUC__) -+ #define CUTLASS_UNUSED(expr) __CUTLASS_UNUSED(expr) -+#else -+ #define CUTLASS_UNUSED(expr) do { ; } while (&expr != &expr) -+#endif -+ -+#if !defined(__CUDACC_RTC__) -+ -+#include -+ -+ #if defined(__CUDA_ARCH__) -+ #if defined(_MSC_VER) -+ #define CUTLASS_NOT_IMPLEMENTED() { printf("%s not implemented\n", __FUNCSIG__); asm volatile ("brkpt;\n"); } -+ #else -+ #define CUTLASS_NOT_IMPLEMENTED() { printf("%s not implemented\n", __PRETTY_FUNCTION__); asm volatile ("brkpt;\n"); } -+ #endif -+ -+ #else -+ #if defined(_MSC_VER) -+ #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __FUNCSIG__) -+ #else -+ #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __PRETTY_FUNCTION__) -+ #endif -+ #endif -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+/// Status code returned by CUTLASS operations -+enum class Status { -+ kSuccess, ///< Operation was successful. -+ kErrorMisalignedOperand, ///< operands fail alignment requirements. -+ kErrorInvalidDataType, ///< DataType fails requirement. -+ kErrorInvalidLayout, ///< Layout fails alignment requirement. -+ kErrorInvalidProblem, ///< Specified problem size is not supported by operator. -+ kErrorNotSupported, ///< Operation is not supported on current device. -+ kErrorWorkspaceNull, ///< The given workspace is null when it is required to be non-null. -+ kErrorInternal, ///< An error within CUTLASS occurred. -+ kErrorArchMismatch, ///< CUTLASS runs on a device that it was not compiled for. -+ kErrorInsufficientDriver, ///< CUTLASS runs with a driver that is too old. -+ kErrorMemoryAllocation, ///< Kernel launch failed due to insufficient device memory. -+ kInvalid ///< Status is unspecified. 
-+}; -+ -+/// Convert cutlass status to status strings -+CUTLASS_HOST_DEVICE -+static char const* cutlassGetStatusString(cutlass::Status status) { -+ switch (status) { -+ case cutlass::Status::kSuccess: -+ return "Success"; -+ case cutlass::Status::kErrorMisalignedOperand: -+ return "Error Misaligned Operand"; -+ case cutlass::Status::kErrorInvalidDataType: -+ return "Error Invalid Data Type"; -+ case cutlass::Status::kErrorInvalidLayout: -+ return "Error Invalid Layout"; -+ case cutlass::Status::kErrorInvalidProblem: -+ return "Error Invalid Problem"; -+ case cutlass::Status::kErrorNotSupported: -+ return "Error Not Supported"; -+ case cutlass::Status::kErrorWorkspaceNull: -+ return "Error Workspace Null"; -+ case cutlass::Status::kErrorInternal: -+ return "Error Internal"; -+ case cutlass::Status::kErrorInsufficientDriver: -+ return "Error Insufficient Driver"; -+ case cutlass::Status::kErrorArchMismatch: -+ return "Error Architecture Mismatch"; -+ case cutlass::Status::kErrorMemoryAllocation: -+ return "Error Memory Allocation failed"; -+ case cutlass::Status::kInvalid: break; -+ } -+ -+ return "Invalid status"; -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+#ifndef CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED -+#define CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED 0 -+#endif -+ -+ -+// CUDA 10.1 introduces the mma instruction -+#if !defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) -+#define CUTLASS_ENABLE_TENSOR_CORE_MMA 0 -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#define CUTLASS_ASSERT(x) assert(x) -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler. -+#if defined(__CUDA_ARCH__) -+ #if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__)) -+ #define CUTLASS_PRAGMA_UNROLL _Pragma("unroll") -+ #define CUTLASS_PRAGMA_NO_UNROLL _Pragma("unroll 1") -+ #else -+ #define CUTLASS_PRAGMA_UNROLL #pragma unroll -+ #define CUTLASS_PRAGMA_NO_UNROLL #pragma unroll 1 -+ #endif -+ -+ #define CUTLASS_GEMM_LOOP CUTLASS_PRAGMA_NO_UNROLL -+ -+#else -+ -+ #define CUTLASS_PRAGMA_UNROLL -+ #define CUTLASS_PRAGMA_NO_UNROLL -+ #define CUTLASS_GEMM_LOOP -+ -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static const int NumThreadsPerWarp = 32; -+static const int NumThreadsPerWarpGroup = 128; -+static const int NumThreadsPerHalfWarp = NumThreadsPerWarp / 2; -+static const int NumThreadsPerQuad = 4; -+static const int NumThreadsPerQuadPair = NumThreadsPerQuad * 2; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper function to return true when called by thread 0 of threadblock 0. -+CUTLASS_HOST_DEVICE bool thread0() { -+ #if defined(__CUDA_ARCH__) -+ return (!threadIdx.x && !threadIdx.y && !threadIdx.z) && (!blockIdx.x && !blockIdx.y && !blockIdx.z); -+ #else -+ return false; -+ #endif -+} -+ -+/// Returns a warp-uniform value indicating the canonical warp index of the calling threads. -+/// Threads within the warp must be converged. 
-+CUTLASS_DEVICE -+int canonical_warp_idx() { -+ #if defined(__CUDA_ARCH__) -+ return __shfl_sync(0xffffffff, threadIdx.x / NumThreadsPerWarp, 0); -+ #else -+ return 0; -+ #endif -+} -+ -+/// Returns a warp-uniform value indicating the canonical warp group index of the calling threads. -+/// Threads within the warp must be converged. -+CUTLASS_DEVICE -+int canonical_warp_group_idx() { -+ #if defined(__CUDA_ARCH__) -+ return __shfl_sync(0xffffffff, threadIdx.x / NumThreadsPerWarpGroup, 0); -+ #else -+ return 0; -+ #endif -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/device_kernel.h b/3rdparty/cutlass/include/cutlass/device_kernel.h -new file mode 100644 -index 0000000..68042e3 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/device_kernel.h -@@ -0,0 +1,113 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for generic CUTLASS kernel. -+*/ -+ -+#pragma once -+ -+// __grid_constant__ was introduced in CUDA 11.7. -+#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) -+# define CUTLASS_GRID_CONSTANT_SUPPORTED -+#endif -+ -+// __grid_constant__ can be enabled only on SM70+ -+#if defined(CUTLASS_GRID_CONSTANT_SUPPORTED) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) -+# define CUTLASS_GRID_CONSTANT_ENABLED -+#endif -+ -+#if ! 
defined(CUTLASS_GRID_CONSTANT)
-+#  if defined(CUTLASS_GRID_CONSTANT_ENABLED)
-+#    define CUTLASS_GRID_CONSTANT __grid_constant__
-+#  else
-+#    define CUTLASS_GRID_CONSTANT
-+#  endif
-+#endif
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+namespace cutlass {
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+/// Generic CUTLASS kernel template.
-+template <typename Operator>
-+__global__
-+void Kernel(typename Operator::Params params) {
-+  // Dynamic shared memory base pointer
-+  extern __shared__ int SharedStorageBase[];
-+
-+  // Declare pointer to dynamic shared memory.
-+  typename Operator::SharedStorage *shared_storage =
-+      reinterpret_cast<typename Operator::SharedStorage *>(SharedStorageBase);
-+
-+  Operator op;
-+
-+  op(params, *shared_storage);
-+}
-+
-+
-+/// Generic CUTLASS kernel template.
-+template <typename Operator>
-+__global__
-+void Kernel2(typename Operator::Params params) {
-+  // Dynamic shared memory base pointer
-+  extern __shared__ int SharedStorageBase[];
-+
-+  // Declare pointer to dynamic shared memory.
-+  typename Operator::SharedStorage *shared_storage =
-+      reinterpret_cast<typename Operator::SharedStorage *>(SharedStorageBase);
-+
-+  Operator::invoke(params, *shared_storage);
-+
-+}
-+
-+
-+////////////////////////////////////////////////////////////////////////////////
-+//
-+// 3.0 specific launch
-+//
-+////////////////////////////////////////////////////////////////////////////////
-+
-+/// Generic CUTLASS kernel template.
-+template <typename Operator>
-+__global__ __launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor)
-+void device_kernel(CUTLASS_GRID_CONSTANT typename Operator::Params const params)
-+{
-+  // Dynamic shared memory base pointer
-+  extern __shared__ char smem[];
-+
-+  Operator op;
-+  op(params, smem);
-+}
-+
-+////////////////////////////////////////////////////////////////////////////////
-+} /// namespace cutlass
-diff --git a/3rdparty/cutlass/include/cutlass/epilogue/collective/collective_epilogue.hpp b/3rdparty/cutlass/include/cutlass/epilogue/collective/collective_epilogue.hpp
-new file mode 100644
-index 0000000..5b1b924
---- /dev/null
-+++ b/3rdparty/cutlass/include/cutlass/epilogue/collective/collective_epilogue.hpp
-@@ -0,0 +1,49 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-+ *
-+ * Redistribution and use in source and binary forms, with or without modification, are permitted
-+ * provided that the following conditions are met:
-+ *     * Redistributions of source code must retain the above copyright notice, this list of
-+ *       conditions and the following disclaimer.
-+ *     * Redistributions in binary form must reproduce the above copyright notice, this list of
-+ *       conditions and the following disclaimer in the documentation and/or other materials
-+ *       provided with the distribution.
-+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
-+ *       to endorse or promote products derived from this software without specific prior written
-+ *       permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
-+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
-+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; -+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::epilogue::collective { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ class DispatchPolicy, -+ class... Args -+> -+struct CollectiveEpilogue { -+ static_assert(std::is_void_v, "Could not find an epilogue specialization."); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::epilogue::collective -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "default_epilogue.hpp" -+#include "epilogue.hpp" -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/collective/default_epilogue.hpp b/3rdparty/cutlass/include/cutlass/epilogue/collective/default_epilogue.hpp -new file mode 100644 -index 0000000..71499b5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/collective/default_epilogue.hpp -@@ -0,0 +1,195 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing elementwise operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cute/tensor.hpp" -+#include "cute/numeric/int.hpp" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace collective { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies an element wise operation to all elements within the fragment -+/// and writes them out to destination storage. -+template < -+ class StrideC_, -+ class StrideD_, -+ class ThreadEpilogueOp_ -+> -+class DefaultEpilogue { -+public: -+ // -+ // Type Aliases -+ // -+ // derived types of output thread level operator -+ using ThreadEpilogueOp = ThreadEpilogueOp_; -+ using ElementOutput = typename ThreadEpilogueOp::ElementOutput; -+ using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; -+ using ElementCompute = typename ThreadEpilogueOp::ElementCompute; -+ using ElementScalar = ElementCompute; -+ using ElementC = typename ThreadEpilogueOp::ElementC; -+ using StrideC = StrideC_; -+ using ElementD = typename ThreadEpilogueOp::ElementD; -+ using StrideD = StrideD_; -+ -+ static const int kOutputAlignment = ThreadEpilogueOp::kCount; -+ using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; -+ -+ static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); -+ static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); -+ -+ struct SharedStorage { }; -+ -+ // Params of epilogue::collective contain the epilogue::thread params -+ struct Params { -+ ElementC const* ptr_C = nullptr; -+ StrideC dC{}; -+ ElementD* ptr_D = nullptr; -+ StrideD dD{}; -+ typename ThreadEpilogueOp::Params thread_params{}; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ template -+ static constexpr Params -+ to_underlying_arguments(Args const& args, void* workspace) { -+ (void) workspace; -+ return {args.epilogue_params}; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ DefaultEpilogue(Params const& params_) : params(params_) { } -+ -+ template< -+ class ProblemShapeMNKL, -+ class BlockShapeMNK, -+ class BlockCoordMNKL, -+ class FrgEngine, class FrgLayout, -+ class TiledMma, -+ class ResidueMNK -+ > -+ CUTLASS_HOST_DEVICE void -+ operator()( -+ ProblemShapeMNKL problem_shape_mnkl, -+ BlockShapeMNK blk_shape_MNK, -+ BlockCoordMNKL blk_coord_mnkl, -+ cute::Tensor const& accumulators, -+ TiledMma tiled_mma, -+ ResidueMNK residue_mnk, -+ int thread_idx, -+ char* smem_buf) -+ { -+ using namespace cute; -+ using X = Underscore; -+ -+ static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); -+ static_assert(is_static::value, "ThreadBlock tile shape must be static"); -+ static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must 
be rank 3"); -+ static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); -+ -+ (void) smem_buf; -+ ThreadEpilogueOp epilogue_op{params.thread_params}; -+ -+ // Separate out problem shape for convenience -+ auto M = get<0>(problem_shape_mnkl); -+ auto N = get<1>(problem_shape_mnkl); -+ auto L = get<3>(problem_shape_mnkl); -+ -+ // Represent the full output tensor -+ Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), params.dC); // (m,n,l) -+ Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), params.dD); // (m,n,l) -+ Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) -+ Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) -+ -+ // Slice to get the tile this CTA is responsible for -+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; -+ Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) -+ Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) -+ -+ // Partition source and destination tiles to match the accumulator partitioning -+ auto thr_mma = tiled_mma.get_thread_slice(thread_idx); -+ Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) -+ Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) -+ -+ static_assert(is_static::value, "Accumulator layout must be static"); -+ CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), -+ "Source and destination must have the same number of elements."); -+ CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), -+ "Accumulator count must have the same destination element count."); -+ -+ // Make an identity coordinate tensor for predicating our output MN tile -+ auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); -+ Tensor tCcD = thr_mma.partition_C(cD); -+ -+ // source is needed -+ if (epilogue_op.is_source_needed()) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < size(accumulators); ++i) { -+ if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { -+ tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); -+ } -+ } -+ } -+ // source is not needed, avoid load -+ else { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < size(accumulators); ++i) { -+ if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { -+ tCgD(i) = epilogue_op(accumulators(i)); -+ } -+ } -+ } -+ } -+ -+private: -+ Params params; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace collective -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/collective/default_transposed_epilogue.hpp b/3rdparty/cutlass/include/cutlass/epilogue/collective/default_transposed_epilogue.hpp -new file mode 100644 -index 0000000..7e38acd ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/collective/default_transposed_epilogue.hpp -@@ -0,0 +1,203 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing elementwise operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cute/tensor.hpp" -+#include "cute/numeric/int.hpp" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace collective { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+using namespace cute; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies an element wise operation to all elements within the fragment -+/// and writes them out to destination storage. 
-+template < -+ class StrideC_, -+ class StrideD_, -+ class ThreadEpilogueOp_ -+> -+class DefaultTransposedEpilogue { -+ -+public: -+ // -+ // Type Aliases -+ // -+ // derived types of output thread level operator -+ using ThreadEpilogueOp = ThreadEpilogueOp_; -+ using ElementOutput = typename ThreadEpilogueOp::ElementOutput; -+ using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; -+ using ElementCompute = typename ThreadEpilogueOp::ElementCompute; -+ using ElementScalar = ElementCompute; -+ using ElementC = typename ThreadEpilogueOp::ElementC; -+ using StrideC = StrideC_; -+ using ElementD = typename ThreadEpilogueOp::ElementD; -+ using StrideD = StrideD_; -+ -+ static const int kOutputAlignment = ThreadEpilogueOp::kCount; -+ using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; -+ -+ static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); -+ static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); -+ -+ struct SharedStorage { }; -+ -+ // Params of epilogue::collective contain the epilogue::thread params -+ struct Params { -+ ElementC const* ptr_C = nullptr; -+ StrideC dC{}; -+ ElementD* ptr_D = nullptr; -+ StrideD dD{}; -+ typename ThreadEpilogueOp::Params thread_params{}; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ template -+ static constexpr Params -+ to_underlying_arguments(Args const& args, void* workspace) { -+ (void) workspace; -+ return {args.epilogue_params}; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ DefaultTransposedEpilogue(Params const& params_) : params(params_) { } -+ -+ template< -+ class ProblemShapeMNKL, -+ class BlockShapeMNK, -+ class BlockCoordMNKL, -+ class FrgEngine, class FrgLayout, -+ class TiledMma, -+ class ResidueMNK -+ > -+ CUTLASS_HOST_DEVICE void -+ operator()( -+ ProblemShapeMNKL problem_shape_mnkl, -+ BlockShapeMNK blk_shape_MNK, -+ BlockCoordMNKL blk_coord_mnkl, -+ cute::Tensor const& accumulators, -+ TiledMma tiled_mma, -+ ResidueMNK residue_mnk, -+ int thread_idx, -+ char* smem_buf) -+ { -+ using namespace cute; -+ using X = Underscore; -+ -+ static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); -+ static_assert(is_static::value, "ThreadBlock tile shape must be static"); -+ static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); -+ static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); -+ -+ (void) smem_buf; -+ ThreadEpilogueOp epilogue_op{params.thread_params}; -+ -+ // Separate out problem shape for convenience -+ auto M = get<0>(problem_shape_mnkl); -+ auto N = get<1>(problem_shape_mnkl); -+ auto L = get<3>(problem_shape_mnkl); -+ -+ // Tranpose stride C/D. 
-+ auto stride_c = make_stride(get<1>(params.dC), get<0>(params.dC), get<2>(params.dC)); -+ auto stride_d = make_stride(get<1>(params.dD), get<0>(params.dD), get<2>(params.dD)); -+ -+ // Represent the full output tensor -+ Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) -+ Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) -+ Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) -+ Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) -+ -+ // Slice to get the tile this CTA is responsible for -+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; -+ Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) -+ Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) -+ -+ // Partition source and destination tiles to match the accumulator partitioning -+ auto thr_mma = tiled_mma.get_thread_slice(thread_idx); -+ Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) -+ Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) -+ -+ static_assert(is_static::value, "Accumulator layout must be static"); -+ CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), -+ "Source and destination must have the same number of elements."); -+ CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), -+ "Accumulator count must have the same destination element count."); -+ -+ auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); -+ Tensor tCcD = thr_mma.partition_C(cD); -+ -+ // source is needed -+ if (epilogue_op.is_source_needed()) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < size(accumulators); ++i) { -+ if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { -+ tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); -+ } -+ } -+ } -+ // source is not needed, avoid load -+ else { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < size(accumulators); ++i) { -+ if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { -+ tCgD(i) = epilogue_op(accumulators(i)); -+ } -+ } -+ } -+ } -+ -+private: -+ Params params; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace collective -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/collective/epilogue.hpp b/3rdparty/cutlass/include/cutlass/epilogue/collective/epilogue.hpp -new file mode 100644 -index 0000000..565e752 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/collective/epilogue.hpp -@@ -0,0 +1,322 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing elementwise operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cute/tensor.hpp" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace collective { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies an element wise operation to all elements within the fragment -+/// and writes it out to destination storage. 
-+/// -+/// Ways to generalize this: -+/// - CTA tile shape -+/// - vectorization requirements (GMEM) -+/// - vectoriz(able) transform() -+/// -+template < -+ class StrideC_, -+ class StrideD_, -+ class ThreadEpilogueOp_, -+ class SmemLayout_, -+ class CopyAtomR2S_, -+ class TiledCopyS2R_, -+ class CopyAtomR2G_ -+> -+class Epilogue { -+public: -+ // -+ // Type Aliases -+ // -+ // derived types of output thread level operator -+ using ThreadEpilogueOp = ThreadEpilogueOp_; -+ using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; -+ using ElementCompute = typename ThreadEpilogueOp::ElementCompute; -+ using ElementScalar = ElementCompute; -+ using ElementOutput = typename ThreadEpilogueOp::ElementOutput; -+ using ElementC = typename ThreadEpilogueOp::ElementC; -+ using StrideC = StrideC_; -+ using ElementD = typename ThreadEpilogueOp::ElementD; -+ using StrideD = StrideD_; -+ -+ using SmemLayout = SmemLayout_; -+ using CopyAtomR2S = CopyAtomR2S_; -+ using TiledCopyS2R = TiledCopyS2R_; -+ using CopyAtomR2G = CopyAtomR2G_; -+ -+ static const int kOutputAlignment = ThreadEpilogueOp::kCount; -+ using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; -+ -+ static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); -+ static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); -+ -+ struct SharedStorage -+ { -+ cute::array_aligned> smem_epilogue; -+ }; -+ -+ // Params of epilogue::collective contain the epilogue::thread params -+ struct Params { -+ ElementC const* ptr_C = nullptr; -+ StrideC dC{}; -+ ElementD* ptr_D = nullptr; -+ StrideD dD{}; -+ typename ThreadEpilogueOp::Params thread_params{}; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ template -+ static constexpr Params -+ to_underlying_arguments(Args const& args, void* workspace) { -+ (void) workspace; -+ return {args.epilogue_params}; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Epilogue(Params const& params_) : params(params_) { }; -+ -+ template< -+ class ProblemShapeMNKL, -+ class BlockShapeMNK, -+ class BlockCoordMNKL, -+ class FrgEngine, class FrgLayout, -+ class TiledMma, -+ class ResidueMNK -+ > -+ CUTLASS_DEVICE void -+ operator()( -+ ProblemShapeMNKL problem_shape_mnkl, -+ BlockShapeMNK blk_shape_MNK, -+ BlockCoordMNKL blk_coord_mnkl, -+ cute::Tensor const& accumulators, // (MMA,MMA_M,MMA_N) -+ TiledMma tiled_mma, -+ ResidueMNK residue_mnk, -+ int thread_idx, -+ char* smem_buf) -+ { -+ using namespace cute; -+ using X = Underscore; -+ -+ static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); -+ static_assert(is_static::value, "ThreadBlock tile shape must be static"); -+ static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); -+ static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); -+ -+ // synchronizing function for smem reads/writes -+#if CUDA_BARRIER_ENABLED -+ auto synchronize = [] () { NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, 0); }; -+#else -+ auto synchronize = [] () { __syncthreads(); }; -+#endif -+ -+ ThreadEpilogueOp epilogue_op{this->params.thread_params}; -+ -+ // Separate out problem shape for convenience -+ auto M = get<0>(problem_shape_mnkl); -+ auto N = get<1>(problem_shape_mnkl); -+ auto L = get<3>(problem_shape_mnkl); -+ -+ // Represent the full output tensor -+ Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), params.dC); // (m,n,l) -+ Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), params.dD); // (m,n,l) -+ Tensor gC_mnl 
= local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) -+ Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) -+ -+ // Slice to get the tile this CTA is responsible for -+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; -+ Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) -+ Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) -+ -+ // Construct a tensor in SMEM that we can partition for rearranging data -+ SharedStorage& storage = *reinterpret_cast(smem_buf); -+ Tensor sC = make_tensor(make_smem_ptr(storage.smem_epilogue.data()), SmemLayout{}); // (SMEM_M,SMEM_N) -+ -+ // Partition sC to match the accumulator partitioning -+ auto tC = make_tiled_copy_C(CopyAtomR2S{}, tiled_mma).get_thread_slice(thread_idx); -+ Tensor tCaC = tC.retile_S(accumulators); // ((Atom,AtomNum), MMA_M, MMA_N) -+ Tensor tCsC = tC.partition_D(sC); // ((Atom,AtomNum),PIPE_M,PIPE_N) -+ -+ // Tile gD and gC by the shape of SmemLayout first -+ auto tile = make_shape(size<0>(sC), size<1>(sC)); -+ Tensor gCt = local_tile(gC, tile, _); // (SMEM_M,SMEM_N,TILE_M,TILE_N) -+ Tensor gDt = local_tile(gD, tile, _); // (SMEM_M,SMEM_N,TILE_M,TILE_N) -+ -+ // Partition sC, gC, and gD for the output -+ auto tD = TiledCopyS2R{}.get_thread_slice(thread_idx); -+ Tensor tDsC = tD.partition_S(sC); // ((Atom,AtomNum),ATOM_M,ATOM_N) -+ Tensor tDgC = tD.partition_D(gCt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) -+ Tensor tDgD = tD.partition_D(gDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) -+ -+ // Allocate intermediate registers on the dst tensors -+ Tensor tDrC = make_tensor(take<0,3>(shape(tDgC))); // ((Atom,AtomNum),ATOM_M,ATOM_N) -+ Tensor tDrD = make_tensor(shape(tDrC)); // ((Atom,AtomNum),ATOM_M,ATOM_N) -+ -+ // Repeat the D-partitioning for coordinates and predication -+ Tensor cD = make_identity_tensor(make_shape(size<0>(gD),size<1>(gD))); // (BLK_M,BLK_N) -> (blk_m,blk_n) -+ Tensor cDt = local_tile(cD, tile, _); // (SMEM_M,SMEM_N,TILE_M,TILE_N) -+ Tensor tDcD = tD.partition_D(cDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) -+ -+ CUTE_STATIC_ASSERT(size<1>(tCaC) % size<3>(tDgC) == 0); // TILE_M divides MMA_M -+ CUTE_STATIC_ASSERT(size<2>(tCaC) % size<4>(tDgC) == 0); // TILE_N divides MMA_N -+ CUTE_STATIC_ASSERT(typename TiledCopyS2R::TiledNumThr{} == size<0>(typename TiledMma::AtomLayoutC_TV{})); -+ -+#if 0 -+ if (thread_idx == 0 && m_coord == 0 && n_coord == 0) { -+ print("aC : "); print(accumulators.layout()); print("\n"); -+ print("gC : "); print(gC.layout()); print("\n"); -+ print("gD : "); print(gD.layout()); print("\n"); -+ print("sC : "); print(sC.layout()); print("\n"); -+ print("\n"); -+ print("tCsC : "); print(tCsC.layout()); print("\n"); -+ print("tCaC : "); print(tCaC.layout()); print("\n"); -+ print("\n"); -+ print("gDt : "); print(gDt.layout()); print("\n"); -+ print("tDsC : "); print(tDsC.layout()); print("\n"); -+ print("tDrC : "); print(tDrC.layout()); print("\n"); -+ print("\n"); -+ print("tDrD : "); print(tDrD.layout()); print("\n"); -+ print("tDgC : "); print(tDgC.layout()); print("\n"); -+ print("tDgD : "); print(tDgD.layout()); print("\n"); -+ print("\n"); -+ } -+#endif -+ -+ // For each tiling needed for SmemLayout to cover shape(gD) -+ CUTLASS_PRAGMA_UNROLL -+ for (int step_m = 0; step_m < size<2>(cDt); ++step_m) -+ { -+ CUTLASS_PRAGMA_UNROLL -+ for (int step_n = 0; step_n < size<3>(cDt); ++step_n) -+ { -+ // Step 1. 
Copy to SMEM -+ CUTLASS_PRAGMA_UNROLL -+ for (int pipe_m = 0; pipe_m < size<1>(tCsC); ++pipe_m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int pipe_n = 0; pipe_n < size<2>(tCsC); ++pipe_n) { -+ int mma_m = step_m * size<1>(tCsC) + pipe_m; -+ int mma_n = step_n * size<2>(tCsC) + pipe_n; -+ -+ copy(tC, tCaC(_,mma_m,mma_n), tCsC(_,pipe_m,pipe_n)); -+ } -+ } -+ -+ // Step 2. Wait for SMEM writes to complete -+ synchronize(); -+ -+ // Step 3. Copy from SMEM into a fragment -+ copy(tD, tDsC, tDrC); -+ -+ // Step 4. Wait for SMEM reads to complete -+ synchronize(); -+ -+ Tensor tDgDmn = tDgD(_,_,_,step_m,step_n); -+ Tensor tDcDmn = tDcD(_,_,_,step_m,step_n); -+ -+ if (epilogue_op.is_source_needed()) { -+ // source is needed -+ Tensor tDgCmn = tDgC(_,_,_,step_m,step_n); -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < size<1>(tDgDmn); ++m) -+ { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < size<2>(tDgDmn); ++n) -+ { -+ // Predication -+ if (get<0>(tDcDmn(0,m,n)) < get<0>(residue_mnk) && -+ get<1>(tDcDmn(0,m,n)) < get<1>(residue_mnk)) -+ { -+ // Step 5. Elementwise operation with conversion -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < size<0>(tDrC); ++i) { -+ tDrD(i,m,n) = epilogue_op(tDrC(i,m,n), tDgCmn(i,m,n)); -+ } -+ // Step 6. Copy to GMEM -+ copy(CopyAtomR2G{}, tDrD(_,m,n), tDgDmn(_,m,n)); -+ } -+ } -+ } -+ } -+ else { -+ // source is not needed, avoid load and lift compute -+ -+ // Step 5. Elementwise operation with conversion -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < size(tDrC); ++i) { -+ tDrD(i) = epilogue_op(tDrC(i)); -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < size<1>(tDgDmn); ++m) -+ { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < size<2>(tDgDmn); ++n) -+ { -+ // Predication -+ if (get<0>(tDcDmn(0,m,n)) < get<0>(residue_mnk) && -+ get<1>(tDcDmn(0,m,n)) < get<1>(residue_mnk)) -+ { -+ // Step 6. Copy to GMEM -+ copy(CopyAtomR2G{}, tDrD(_,m,n), tDgDmn(_,m,n)); -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+private: -+ Params params; -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace collective -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/dispatch_policy.hpp b/3rdparty/cutlass/include/cutlass/epilogue/dispatch_policy.hpp -new file mode 100644 -index 0000000..de318d5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/dispatch_policy.hpp -@@ -0,0 +1,39 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Redistribution and use in source and binary forms, with or without modification, are permitted -+ * provided that the following conditions are met: -+ * * Redistributions of source code must retain the above copyright notice, this list of -+ * conditions and the following disclaimer. -+ * * Redistributions in binary form must reproduce the above copyright notice, this list of -+ * conditions and the following disclaimer in the documentation and/or other materials -+ * provided with the distribution. -+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used -+ * to endorse or promote products derived from this software without specific prior written -+ * permission. 
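The collective Epilogue::operator() above walks the CTA tile in SmemLayout-sized steps: accumulator fragments are copied register-to-shared (CopyAtomR2S), the CTA synchronizes, data is read back in a store-friendly arrangement (TiledCopyS2R), the thread-level epilogue is applied, and results are written to global memory (CopyAtomR2G) under residue predication for partial tiles. The standalone host-side C++ sketch below only illustrates that control flow; the tile sizes, the epilogue_op lambda, and the residue_m/residue_n names are assumptions made for the example, not patch content.

// Host-side analogy of the SMEM-staged epilogue loop (illustrative only).
#include <vector>

int main() {
  constexpr int BLK_M = 8, BLK_N = 8;      // CTA tile (toy size)
  constexpr int SMEM_M = 4, SMEM_N = 4;    // SmemLayout tile covered per step
  const int residue_m = 7, residue_n = 6;  // valid extent of this (partial) tile
  const float alpha = 2.0f;

  std::vector<float> acc(BLK_M * BLK_N, 1.0f);   // "accumulators" in registers
  std::vector<float> smem(SMEM_M * SMEM_N);      // shared-memory staging buffer
  std::vector<float> gD(BLK_M * BLK_N, 0.0f);    // output tile in global memory

  auto epilogue_op = [&](float x) { return alpha * x; };  // source not needed

  // For each SmemLayout-sized step needed to cover the CTA tile
  for (int step_m = 0; step_m < BLK_M / SMEM_M; ++step_m) {
    for (int step_n = 0; step_n < BLK_N / SMEM_N; ++step_n) {
      // Step 1: copy the accumulator sub-tile into the staging buffer (R2S)
      for (int m = 0; m < SMEM_M; ++m)
        for (int n = 0; n < SMEM_N; ++n)
          smem[m * SMEM_N + n] =
              acc[(step_m * SMEM_M + m) * BLK_N + step_n * SMEM_N + n];

      // Steps 2/4: on the GPU, __syncthreads()/NamedBarrier separate the phases.

      // Steps 3/5/6: read back (S2R), apply the thread epilogue, and store (R2G)
      // under residue predication so out-of-bounds rows/columns are skipped.
      for (int m = 0; m < SMEM_M; ++m)
        for (int n = 0; n < SMEM_N; ++n) {
          int gm = step_m * SMEM_M + m, gn = step_n * SMEM_N + n;
          if (gm < residue_m && gn < residue_n)
            gD[gm * BLK_N + gn] = epilogue_op(smem[m * SMEM_N + n]);
        }
    }
  }
  return 0;
}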
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR -+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND -+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; -+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::epilogue { -+ -+////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Collective Epilogue Policies -+// -+ -+////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::epilogue -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/activation.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/activation.h -new file mode 100644 -index 0000000..484f2cc ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/activation.h -@@ -0,0 +1,705 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief This extends the contents of cutlass/functional.h with frequently used activation functions. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/constants.h" -+#include "cutlass/complex.h" -+#include "cutlass/array.h" -+#include "cutlass/half.h" -+#include "cutlass/functional.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+struct LinearCombinationGenericParams { -+ T alpha; ///< scales accumulators -+ T beta; ///< scales source tensor -+ T const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ T const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ LinearCombinationGenericParams(): -+ alpha(T(1)), -+ beta(T(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ LinearCombinationGenericParams( -+ T alpha, -+ T beta = T(0) -+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ LinearCombinationGenericParams( -+ T const *alpha_ptr, -+ T const *beta_ptr = nullptr -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Identity operator -+template -+struct Identity { -+ static const bool kIsHeavy=false; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T value) const { -+ return value; -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &value, Params const ¶ms_) const { -+ return this->operator()(value); -+ } -+}; -+ -+template -+struct Identity > { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value) const { -+ return value; -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value, Params const ¶ms_) const { -+ return this->operator()(value); -+ } -+}; -+ -+/// ReLu operator - propagates NaNs -+/// Always put threshold in the right hand side of max to propagate NaN. 
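The note above about keeping the threshold on the right-hand side of max is what lets the ReLu functors that follow propagate NaN: a comparison-based maximum written as (a < b) ? b : a returns its left operand whenever that operand is NaN, because every comparison against NaN is false. The sketch below is a standalone illustration, not patch content; max_keep_lhs_nan is a hypothetical helper that uses the same comparison order, not the cutlass::maximum functor itself.

#include <cassert>
#include <cmath>
#include <limits>

// (a < b) ? b : a -- any comparison against NaN is false, so a NaN in the
// left slot is returned unchanged; a NaN in the right slot would be dropped.
static float max_keep_lhs_nan(float a, float b) { return (a < b) ? b : a; }

int main() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  // ReLU written as max(value, threshold): NaN inputs propagate to the output.
  assert(std::isnan(max_keep_lhs_nan(nan, 0.0f)));
  // Swapping the operands would silently replace NaN with the threshold.
  assert(max_keep_lhs_nan(0.0f, nan) == 0.0f);
  return 0;
}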
-+template -+struct ReLu { -+ static const bool kIsHeavy=false; -+ CUTLASS_HOST_DEVICE -+ T operator()(T const & threshold, T value) const { -+ maximum mx; -+ -+ return mx(value, threshold); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T value) const { -+ maximum mx; -+ -+ return mx(value, T(0)); -+ } -+ -+ /// Host-constructable parameters structure -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T value, Params const ¶ms_) const { -+ return this->operator()(value); -+ } -+}; -+ -+template -+struct ReLu> { -+ static const bool kIsHeavy=false; -+ CUTLASS_HOST_DEVICE -+ Array operator()(T const & threshold, Array const &frag) const { -+ maximum > mx; -+ -+ return mx(frag, threshold); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &frag) const { -+ maximum > mx; -+ return mx(frag, T(0)); -+ } -+ -+ /// Host-constructable parameters structure -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &frag, Params const ¶ms_) const { -+ return this->operator()(frag); -+ } -+}; -+ -+// Leaky Relu operator -+template -+struct LeakyReLU { -+ -+ struct Params: LinearCombinationGenericParams { -+ T leaky_alpha; ///< leaky_alpha -+ -+ // Methods -+ using LinearCombinationGenericParams::LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ LinearCombinationGenericParams(), -+ leaky_alpha(T(1)) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ T alpha, -+ T beta, -+ T leaky_alpha = T(1) -+ ): LinearCombinationGenericParams(alpha, beta), leaky_alpha(leaky_alpha) {} -+ }; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &value, T const & alpha_recip) const { -+ T res = value > T(0) ? value : value * alpha_recip; -+ return res; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &value, Params const ¶ms_) const { -+ this->operator()(value, params_.leaky_alpha); -+ } -+}; -+ -+template -+struct LeakyReLU > { -+ -+ struct Params: LinearCombinationGenericParams { -+ T leaky_alpha; ///< leaky_alpha -+ using LinearCombinationGenericParams::LinearCombinationGenericParams; -+ -+ // Methods -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ LinearCombinationGenericParams(), -+ leaky_alpha(T(1)) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ T alpha, -+ T beta, -+ T leaky_alpha = T(1) -+ ): LinearCombinationGenericParams(alpha, beta), leaky_alpha(leaky_alpha) {} -+ }; -+ -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value, T const & alpha_recip) const { -+ Array y; -+ LeakyReLU leaky_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < int(value.size()); ++i) { -+ y[i] = leaky_op(value[i], alpha_recip); -+ } -+ -+ return y; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value, Params const ¶ms_) const { -+ return this->operator()(value, params_.leaky_alpha); -+ } -+}; -+ -+// Tanh operator -+template -+struct Tanh { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &scalar) const { -+ return fast_tanh(scalar); -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &scalar, Params const ¶ms_) const { -+ return this->operator()(scalar); -+ } -+}; -+ -+template -+struct Tanh > { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value) const { -+ Array y; -+ Tanh tanh_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ y[i] = tanh_op(value[i]); -+ } -+ -+ return y; -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value, Params const 
¶ms_) const { -+ return this->operator()(value); -+ } -+}; -+ -+template -+struct Tanh> { -+ using T = half_t; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const& z) const { -+ fast_tanh_op> tanh; -+ return tanh(z); -+ -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value, Params const ¶ms_) const { -+ return this->operator()(value); -+ } -+}; -+ -+// Sigmoid operator -+template -+struct Sigmoid { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &scalar) const { -+ return T(1) / (T(1) + fast_exp(-scalar)); -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &scalar, Params const ¶ms_) const { -+ return this->operator()(scalar); -+ } -+}; -+ -+template -+struct Sigmoid > { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value) const { -+ Array y; -+ Sigmoid sigmoid_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ y[i] = sigmoid_op(value[i]); -+ } -+ -+ return y; -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value, Params const ¶ms_) const { -+ return this->operator()(value); -+ } -+}; -+ -+template -+struct Sigmoid> { -+ using T = half_t; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const& z) const { -+ plus> add; -+ -+#if defined(CUTLASS_USE_TANH_FOR_SIGMOID) -+ multiplies> mul; -+ fast_tanh_op> tanh; -+ return mul(add(tanh(mul(z, cutlass::constants::half())), cutlass::constants::one()), -+ cutlass::constants::half()); -+#else -+ divides> div; -+ negate> neg; -+ fast_exp_op> fast_exp; -+ return div(cutlass::constants::one(), -+ add(cutlass::constants::one(), -+ fast_exp(neg(z)))); -+#endif -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &z, Params const ¶ms_) const { -+ return this->operator()(z); -+ } -+}; -+ -+// SiLu (swish) operator introduced by Elfwing et al. in the following paper -+// "Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning" (2017) -+// https://arxiv.org/pdf/1702.03118.pdf -+// It is used in EfficientNet and YOLOv5, for example. -+// Reference: https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html -+template -+struct SiLu { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &scalar) const { -+ Sigmoid sigmoid; -+ return scalar * sigmoid(scalar); -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &scalar, Params const ¶ms_) const { -+ return this->operator()(scalar); -+ } -+}; -+ -+template -+struct SiLu> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value) const { -+ Sigmoid> sigmoid_op; -+ multiplies> mul; -+ return mul(value, sigmoid_op(value)); -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value, Params const ¶ms_) const { -+ return this->operator()(value); -+ } -+}; -+ -+// Hardswish operator introduced by Howard et al. in the following paper -+// "Searching for MobileNetV3" (2019) -+// https://arxiv.org/pdf/1905.02244.pdf -+// It is used in models based on MobilenetNetV3. 
-+// Reference: https://pytorch.org/docs/stable/generated/torch.nn.Hardswish.html -+template -+struct HardSwish { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &x) const { -+ minimum mn; -+ maximum mx; -+ T relu6 = mn(mx(x + T(3), T(0)), T(6)); -+ return x * relu6 / T(6); -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &x, Params const ¶ms_) const { -+ return this->operator()(x); -+ } -+}; -+ -+template <> -+struct HardSwish { -+ using T = float; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &x) const { -+ minimum mn; -+ maximum mx; -+ T relu6 = mn(mx(x + T(3), T(0)), T(6)); -+ return x * relu6 * 0.16666667f; -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &x, Params const ¶ms_) const { -+ return this->operator()(x); -+ } -+}; -+ -+template -+struct HardSwish > { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value) const { -+ Array y; -+ HardSwish hardswish_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ y[i] = hardswish_op(value[i]); -+ } -+ -+ return y; -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &x, Params const ¶ms_) const { -+ return this->operator()(x); -+ } -+}; -+ -+template -+struct HardSwish > { -+ using T = half_t; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value) const { -+ minimum > mn; -+ maximum > mx; -+ multiplies > mul; -+ plus > add; -+ -+ return mul(mul(mn(mx(add(value, T(3)), T(0)), T(6)), value), T(0.16666667f)); -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &x, Params const ¶ms_) const { -+ return this->operator()(x); -+ } -+}; -+ -+// -+// GELU function definitions implemented as described by -+// Hendrycks, D., and Gimpel, K. in -+// "Gaussian Error Linear Units (GELUs)." (2020) -+// https://arxiv.org/pdf/1606.08415.pdf -+// -+// Floating-point constants are Taylor coefficients described in the paper. 
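For the HardSwish functors above: hard_swish(x) = x * relu6(x + 3) / 6, where relu6(v) = min(max(v, 0), 6), and the float specialization replaces the division by 6 with a multiply by 0.16666667f. The standalone sketch below compares the two forms; it is illustrative only and not patch content.

#include <algorithm>
#include <cassert>
#include <cmath>

// Reference form and the reciprocal-multiply variant used by the float
// specialization; the two agree to within float rounding.
static float hard_swish_div(float x) {
  float relu6 = std::min(std::max(x + 3.0f, 0.0f), 6.0f);
  return x * relu6 / 6.0f;
}
static float hard_swish_mul(float x) {
  float relu6 = std::min(std::max(x + 3.0f, 0.0f), 6.0f);
  return x * relu6 * 0.16666667f;
}

int main() {
  assert(hard_swish_div(-4.0f) == 0.0f);   // saturated low: relu6 == 0
  assert(hard_swish_div(4.0f) == 4.0f);    // saturated high: relu6 == 6
  assert(std::fabs(hard_swish_div(1.0f) - 1.0f * 4.0f / 6.0f) < 1e-6f);
  for (float x = -5.0f; x <= 5.0f; x += 0.25f)
    assert(std::fabs(hard_swish_div(x) - hard_swish_mul(x)) < 1e-5f);
  return 0;
}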
-+// -+ -+// GELU operator -+template -+struct GELU { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &scalar) const { -+ return T(cutlass::constants::half() * scalar * -+ (cutlass::constants::one() + (T)erff((float)(scalar / cutlass::constants::root_two())))); -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &scalar, Params const ¶ms_) const { -+ return this->operator()(scalar); -+ } -+}; -+ -+template <> -+struct GELU { -+ CUTLASS_HOST_DEVICE -+ float operator()(float const &scalar) const { -+ return cutlass::constants::half() * scalar * -+ (cutlass::constants::one() + erff( scalar / cutlass::constants::root_two() )); -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ float operator()(float const &scalar, Params const ¶ms_) const { -+ return this->operator()(scalar); -+ } -+}; -+ -+template <> -+struct GELU { -+ CUTLASS_HOST_DEVICE -+ double operator()(double const &scalar) const { -+ return cutlass::constants::half() * scalar * -+ (cutlass::constants::one() + erf( scalar / cutlass::constants::root_two() )); -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ double operator()(double const &scalar, Params const ¶ms_) const { -+ return this->operator()(scalar); -+ } -+}; -+ -+template -+struct GELU > { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value) const { -+ Array y; -+ GELU gelu_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ y[i] = gelu_op(value[i]); -+ } -+ -+ return y; -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value, Params const ¶ms_) const { -+ return this->operator()(value); -+ } -+}; -+ -+// GELU operator implemented using the Taylor series approximation -+template -+struct GELU_taylor { -+ static const bool kIsHeavy=true; -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &z) const { -+ -+ T k0 = T(0.7978845608028654); -+ T k1 = T(0.044715); -+ -+ return T(cutlass::constants::half() * z * -+ (cutlass::constants::one() + fast_tanh(k0 * z * (cutlass::constants::one() + k1 * z * z)))); -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &scalar, Params const ¶ms_) const { -+ return this->operator()(scalar); -+ } -+}; -+ -+template -+struct GELU_taylor > { -+ static const bool kIsHeavy=true; -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &z) const { -+ -+ using T = half_t; -+ Array y; -+ -+ half_t k0 = half_t(0.7978845608028654); -+ half_t k1 = half_t(0.044715); -+ -+ multiply_add> fma; -+ multiplies> mul; -+ plus> add; -+ -+ fast_tanh_op> tanh; -+ -+ Array u = mul(mul(k0, z), fma(mul(k1, z), z, cutlass::constants::one())); -+ -+ y = mul(mul(z, cutlass::constants::half()), add(cutlass::constants::one(), tanh(u))); -+ -+ return y; -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value, Params const ¶ms_) const { -+ return this->operator()(value); -+ } -+}; -+ -+template -+struct GELU_taylor > { -+ static const bool kIsHeavy=true; -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value) const { -+ Array y; -+ GELU_taylor gelu_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ y[i] = gelu_op(value[i]); -+ } -+ -+ return y; -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value, Params const ¶ms_) const { -+ return 
this->operator()(value); -+ } -+}; -+ -+/// Computes backwards pass for GELU operator assuming d_t is the layer gradient and -+/// z is computed from the forward pass. -+template -+struct dGELU { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &d_t, T const &z) const { -+ -+ T k0 = T(0.7978845608028654); -+ T k1 = T(0.044715); -+ T k2 = T(0.1070322243); -+ -+ T tanh_out = fast_tanh(k0 * z * (1 + k1 * z * z)); -+ -+ T ff = constants::half() * z * ((1 - tanh_out * tanh_out) * (k0 + k2 * z * z)) + -+ constants::half() * (1 + tanh_out); -+ -+ return ff * d_t; -+ } -+}; -+ -+template -+struct dGELU > { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &d_t, Array const &z) const { -+ Array y; -+ dGELU gelu_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ y[i] = gelu_op(d_t[i], z[i]); -+ } -+ -+ return y; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/conversion_op.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/conversion_op.h -new file mode 100644 -index 0000000..98e3beb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/conversion_op.h -@@ -0,0 +1,132 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing conversion operations used by epilogues. 
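For the GELU, GELU_taylor, and dGELU functors above: GELU_taylor evaluates 0.5 * z * (1 + tanh(k0 * z * (1 + k1 * z * z))) with k0 = sqrt(2/pi), and dGELU's k2 constant is, to within rounding, 3 * k0 * k1, the factor that appears when differentiating that expression. The standalone sketch below compares the tanh form against the erf-based definition and finite-difference checks the derivative; the function names and tolerances are assumptions made for the example, not patch content.

#include <cassert>
#include <cmath>

// Exact GELU and the tanh approximation used by GELU_taylor.
static float gelu_erf(float z)  { return 0.5f * z * (1.0f + std::erf(z / std::sqrt(2.0f))); }
static float gelu_tanh(float z) {
  const float k0 = 0.7978845608028654f, k1 = 0.044715f;
  return 0.5f * z * (1.0f + std::tanh(k0 * z * (1.0f + k1 * z * z)));
}
// Analytic derivative of the tanh approximation (the dGELU forward factor).
static float dgelu_tanh(float z) {
  const float k0 = 0.7978845608028654f, k1 = 0.044715f, k2 = 0.1070322243f;
  float t = std::tanh(k0 * z * (1.0f + k1 * z * z));
  return 0.5f * z * (1.0f - t * t) * (k0 + k2 * z * z) + 0.5f * (1.0f + t);
}

int main() {
  for (float z = -3.0f; z <= 3.0f; z += 0.25f) {
    // The tanh form stays close to the erf definition over typical activations.
    assert(std::fabs(gelu_erf(z) - gelu_tanh(z)) < 2e-3f);
    // Finite-difference check of the dGELU expression.
    const float h = 1e-3f;
    float fd = (gelu_tanh(z + h) - gelu_tanh(z - h)) / (2.0f * h);
    assert(std::fabs(fd - dgelu_tanh(z)) < 1e-2f);
  }
  return 0;
}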
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Converts the result without other operations -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+class Convert { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementAccumulator_; -+ -+ static int const kCount = Count; -+ -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using ComputeFragment = FragmentAccumulator; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ static bool const kIsHeavy = false; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ }; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ Convert(Params const ¶ms = Params()) { -+ -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ -+ } -+ -+ /// Returns true if source is needed based on state of runtime arguments -+ CUTLASS_HOST_DEVICE -+ constexpr bool is_source_needed() const { -+ return false; -+ } -+ -+ /// Constexpr function to enable the compiler to optimize away the source loading if it is -+ /// never needed. -+ CUTLASS_HOST_DEVICE -+ constexpr bool is_source_ever_needed() const { -+ return false; -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentOutput const &source = FragmentOutput(), -+ ElementCompute uniform = ElementCompute(0)) const { -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(accumulator); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination.h -new file mode 100644 -index 0000000..0c4b384 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination.h -@@ -0,0 +1,306 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/epilogue/thread/scale_type.h" -+#include "cutlass/epilogue/thread/linear_combination_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements. -+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation. 
-+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest, -+ typename ElementSource_ = ElementOutput_ -+> -+class LinearCombination { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ using ElementC = ElementSource_; -+ using ElementD = ElementOutput_; -+ -+ static int const kCount = Count; -+ static const ScaleType::Kind kScale = Scale; -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using ComputeFragment = Array; -+ -+ using ParamsBase = LinearCombinationParams; -+ static FloatRoundStyle const kRound = Round; -+ -+ /// Host-constructable parameters structure -+ struct Params : ParamsBase{ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ ParamsBase( -+ ElementCompute(1), -+ ElementCompute(0) -+ ), -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta -+ ): -+ ParamsBase(alpha, beta), -+ alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha -+ ): -+ ParamsBase(alpha, ElementCompute(0)), -+ alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr -+ ): -+ ParamsBase(*alpha_ptr, *beta_ptr), -+ alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr -+ ): -+ ParamsBase(*alpha_ptr, ElementCompute(0)), -+ alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ParamsBase const& base -+ ): ParamsBase(base), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ #if defined(__CUDA_ARCH__) -+ alpha = reinterpret_cast(base.alpha_data); -+ beta = reinterpret_cast(base.beta_data); -+ #else -+ memcpy( alpha, base.alpha_data, sizeof(ElementCompute) ); -+ memcpy( beta, base.alpha_data, sizeof(ElementCompute) ); -+ #endif -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombination(Params const ¶ms) { -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ if (Scale == ScaleType::NoBetaScaling) return true; -+ -+ if (Scale == ScaleType::OnlyAlphaScaling) return false; -+ -+ if (Scale == ScaleType::Nothing) return false; -+ -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentOutput const &source) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ ComputeFragment converted_source = source_converter(source); -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ if (Scale == ScaleType::Nothing) -+ return destination_converter(converted_accumulator); -+ -+ // Perform binary operations -+ ComputeFragment intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ -+ if (Scale == ScaleType::NoBetaScaling) -+ intermediate = converted_source; -+ else -+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform -+ -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ -+ return destination_converter(intermediate); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ if (Scale == ScaleType::Nothing) -+ return destination_converter(converted_accumulator); -+ -+ // Perform binary operations -+ ComputeFragment intermediate; -+ multiplies mul_accumulator; -+ -+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ -+ return destination_converter(intermediate); -+ } -+ -+ // -+ // Specializations for scalar (for use with cute::collective::DefaultEpilogue) -+ // -+ CUTLASS_HOST_DEVICE -+ ElementD operator()(ElementAccumulator const accumulator, ElementC const source) const { -+ // Convert everything to Compute type, do compute, and then store to output type -+ NumericConverter accumulator_converter; -+ [[maybe_unused]] NumericConverter source_converter; -+ NumericConverter destination_converter; -+ -+ // Convert to destination numeric type -+ -+ ElementCompute converted_accumulator = accumulator_converter(accumulator); -+ if constexpr (Scale == ScaleType::Nothing) { -+ return destination_converter(converted_accumulator); -+ } -+ -+ // Perform binary operations -+ ElementCompute intermediate; -+ multiplies multiply; -+ multiply_add madd; -+ -+ if constexpr (Scale == ScaleType::NoBetaScaling) { -+ intermediate = source_converter(source); -+ } -+ else { -+ intermediate = multiply(beta_, source); // X = beta * C + uniform -+ } -+ -+ intermediate = madd(alpha_, 
converted_accumulator, intermediate); // D = alpha * Accum + X -+ return destination_converter(intermediate); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ ElementD operator()(ElementAccumulator const accumulator) const { -+ // Convert everything to Compute type, do compute, and then store to output type -+ NumericConverter accumulator_converter; -+ NumericConverter destination_converter; -+ ElementCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Convert to destination numeric type -+ if constexpr (Scale == ScaleType::Nothing) { -+ return destination_converter(converted_accumulator); -+ } -+ -+ // Perform binary operations -+ ElementCompute intermediate; -+ multiplies multiply; -+ -+ intermediate = multiply(alpha_, accumulator); // D = alpha * Accum -+ return destination_converter(intermediate); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h -new file mode 100644 -index 0000000..6892efb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h -@@ -0,0 +1,260 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Functor performing linear combination operations used by epilogues. 
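For the LinearCombination functor defined above (D = alpha * accumulator + beta * source): the source operand is only consumed when beta is non-zero (is_source_needed()), and serial split-K reduction forces beta to 1 for every partition after the first so partial results accumulate through D (set_k_partition()). The scalar model below mirrors that behaviour for illustration only; it is not the CUTLASS class and ignores ScaleType, fragments, and numeric conversion.

#include <cassert>

// Scalar model of the linear-combination epilogue: D = alpha * acc + beta * C.
struct LinearCombinationModel {
  float alpha = 1.0f;
  float beta = 0.0f;
  bool is_source_needed() const { return beta != 0.0f; }
  void set_k_partition(int k_partition) { if (k_partition) beta = 1.0f; }
  float operator()(float accumulator, float source = 0.0f) const {
    return is_source_needed() ? alpha * accumulator + beta * source
                              : alpha * accumulator;
  }
};

int main() {
  LinearCombinationModel op{2.0f, 0.5f};
  assert(op(3.0f, 4.0f) == 2.0f * 3.0f + 0.5f * 4.0f);  // D = alpha*Acc + beta*C

  LinearCombinationModel no_source{2.0f, 0.0f};
  assert(!no_source.is_source_needed());   // beta == 0: the C load can be skipped
  assert(no_source(3.0f) == 6.0f);

  no_source.set_k_partition(/*k_partition=*/1);
  assert(no_source.is_source_needed());    // later split-K partitions read back D
  return 0;
}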
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/epilogue/thread/activation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This base class is meant to define the concept required of the -+/// EpilogueWithBroadcast::OutputOp -+template < -+ typename ElementC_, -+ typename ElementAccumulator_, -+ typename ElementCompute_, -+ typename ElementZ_, -+ typename ElementT_, -+ int ElementsPerAccess, -+ typename ElementwiseOp_ = Identity, -+ typename BinaryOp_ = plus -+> -+class LinearCombinationBiasElementwise { -+public: -+ -+ using ElementOutput = ElementC_; -+ using ElementC = ElementC_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ using ElementZ = ElementZ_; -+ using ElementT = ElementT_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kCount = kElementsPerAccess; -+ -+ using ElementwiseOp = ElementwiseOp_; -+ using BinaryOp = BinaryOp_; -+ -+ // Indicates that this epilogue applies only one binary operation -+ static bool const kIsSingleSource = true; -+ -+ using FragmentAccumulator = Array; -+ using FragmentCompute = Array; -+ using FragmentC = Array; -+ using FragmentZ = Array; -+ using FragmentT = Array; -+ -+ using FragmentOutput = FragmentZ; -+ -+ static bool const kIsHeavy = ElementwiseOp::kIsHeavy; -+ -+ /// If true, the 'Z' tensor is stored -+ static bool const kStoreZ = true; -+ -+ /// If true, the 'T' tensor is stored -+ static bool const kStoreT = true; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta -+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha -+ ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ bool skip_elementwise_; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor from Params -+ CUTLASS_HOST_DEVICE -+ LinearCombinationBiasElementwise(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); -+ skip_elementwise_ = false; -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ -+ if (k_partition != k_partition_count - 1) { -+ skip_elementwise_ = true; -+ } -+ } -+ -+ /// Applies the operation when is_source_needed() is true -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentZ &frag_Z, -+ FragmentT &frag_T, -+ FragmentAccumulator const &AB, -+ FragmentC const &frag_C, -+ FragmentCompute const &V) const { -+ -+ ElementwiseOp elementwise_op; -+ BinaryOp binary_op; -+ -+ FragmentCompute tmp_Accum = NumericArrayConverter()(AB); -+ FragmentCompute tmp_C = NumericArrayConverter()(frag_C); -+ FragmentCompute result_Z; -+ FragmentCompute result_T; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kElementsPerAccess; ++i) { -+ ElementCompute z = binary_op(alpha_ * tmp_Accum[i] + beta_ * tmp_C[i], V[i]); -+ result_T[i] = z; -+ result_Z[i] = skip_elementwise_ ? z : elementwise_op(z); -+ } -+ -+ NumericArrayConverter convert_z; -+ frag_Z = convert_z(result_Z); -+ -+ NumericArrayConverter convert_t; -+ frag_T = convert_t(result_T); -+ } -+ -+ /// Applies the operation when is_source_needed() is false -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentZ &frag_Z, -+ FragmentT &frag_T, -+ FragmentAccumulator const &AB, -+ FragmentCompute const &V) const { -+ -+ ElementwiseOp elementwise_op; -+ BinaryOp binary_op; -+ -+ FragmentCompute tmp_Accum = NumericArrayConverter()(AB); -+ FragmentCompute result_Z; -+ FragmentCompute result_T; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kElementsPerAccess; ++i) { -+ ElementCompute z = binary_op(alpha_ * tmp_Accum[i], V[i]); -+ result_T[i] = z; -+ result_Z[i] = skip_elementwise_ ? z : elementwise_op(z); -+ } -+ -+ NumericArrayConverter convert_z; -+ frag_Z = convert_z(result_Z); -+ -+ NumericArrayConverter convert_t; -+ frag_T = convert_t(result_T); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_relu.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_relu.h -new file mode 100644 -index 0000000..b095c91 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_relu.h -@@ -0,0 +1,450 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/epilogue/thread/activation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template -+struct ArrayMaximum { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ Array const &lhs, -+ Array const &rhs) const { -+ -+ Array result; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < ElementsPerAccess; ++i) { -+ result[i] = fmax(lhs[i], rhs[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+template -+struct ArrayMaximum { -+ -+ CUTLASS_DEVICE -+ Array operator()( -+ Array const &lhs, -+ Array const &rhs) const { -+ -+ Array result; -+ -+ #if __CUDA_ARCH__ >= 800 -+ int const kVectorCount = ElementsPerAccess / 2; -+ -+ -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(lhs.raw_data()); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(rhs.raw_data()); -+ __half2 *res_ptr = reinterpret_cast<__half2 *>(result.raw_data()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kVectorCount; ++i) { -+ res_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]); -+ } -+ -+ #else -+ __half const *lhs_ptr = reinterpret_cast<__half const *>(lhs.raw_data()); -+ __half const *rhs_ptr = reinterpret_cast<__half const *>(rhs.raw_data()); -+ __half *res_ptr = reinterpret_cast<__half *>(result.raw_data()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < ElementsPerAccess; ++i) { -+ res_ptr[i] = ((lhs_ptr[i] < rhs_ptr[i]) ? 
rhs_ptr[i] : lhs_ptr[i]); -+ } -+ -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_DEVICE -+ Array operator()( -+ Array const &lhs, -+ half_t const &rhs) const { -+ -+ Array result; -+ -+ #if __CUDA_ARCH__ >= 800 -+ int const kVectorCount = ElementsPerAccess / 2; -+ -+ -+ __half rhs_raw = reinterpret_cast<__half const &>(rhs); -+ __half2 rhs_pair = __half2half2(rhs_raw); -+ -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(lhs.raw_data()); -+ __half2 *res_ptr = reinterpret_cast<__half2 *>(result.raw_data()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kVectorCount; ++i) { -+ res_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair); -+ } -+ -+ #else -+ -+ __half const *lhs_ptr = reinterpret_cast<__half const *>(lhs.raw_data()); -+ __half const rhs_raw = reinterpret_cast<__half const &>(rhs); -+ __half *res_ptr = reinterpret_cast<__half *>(result.raw_data()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < ElementsPerAccess; ++i) { -+ res_ptr[i] = ((lhs_ptr[i] < rhs_raw) ? rhs_raw : lhs_ptr[i]); -+ } -+ -+ #endif -+ -+ return result; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct ReluConditional { -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ bool conditional[], -+ Array const &fragment, -+ Element threshold) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < ElementsPerAccess; ++i) { -+ conditional[i] = !(fragment[i] < threshold); -+ } -+ } -+}; -+ -+template -+struct ReluConditional { -+ -+ CUTLASS_DEVICE -+ void operator()( -+ bool conditional[], -+ Array const &fragment, -+ half_t threshold) const { -+ -+ __half y = reinterpret_cast<__half const &>(threshold); -+ __half const *x = reinterpret_cast<__half const *>(fragment.raw_data()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < ElementsPerAccess; ++i) { -+ conditional[i] = !__hlt(x[i], y); -+ } -+ } -+}; -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This is a partial specialization for fused Bias and ReLU. It supports the option of packing -+/// ReLU conditionals in a bit vector that may be used by backwards passes as an optimization. -+/// -+/// This class can only be used with cutlass::epilogue::threadblock::EpilogueWithBroadcast<>. 
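The fused bias + ReLU output operator described above clamps z = alpha * AB + beta * C + V at a threshold and, when kStoreT is enabled, packs the per-element conditionals into a bit vector that a backward pass can reuse as the ReLU gradient mask. The CPU-side sketch below models that idea; the uint32_t mask and the variable names are assumptions for the example, not the uint1b_t fragment layout or the CUTLASS API.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const float alpha = 1.0f, beta = 0.0f, threshold = 0.0f;
  std::vector<float> acc  = {0.5f, -2.0f, 3.0f, -0.25f};
  std::vector<float> bias = {0.0f,  1.0f, 0.0f,  0.0f};

  std::vector<float> z(acc.size());
  uint32_t relu_mask = 0;
  for (size_t i = 0; i < acc.size(); ++i) {
    float v = alpha * acc[i] + beta * 0.0f + bias[i];  // z before the clamp
    bool keep = !(v < threshold);                      // the ReLU conditional
    relu_mask |= (keep ? 1u : 0u) << i;                // pack one predicate bit
    z[i] = keep ? v : threshold;                       // ReLU(v)
  }

  assert(relu_mask == 0b0101u);                // elements 0 and 2 stayed positive
  assert(z[1] == 0.0f && z[3] == 0.0f);        // clamped to the threshold

  // Backward: dZ/dX is 1 where the bit is set, 0 otherwise.
  std::vector<float> grad_out = {1.0f, 1.0f, 1.0f, 1.0f}, grad_in(4);
  for (size_t i = 0; i < grad_in.size(); ++i)
    grad_in[i] = ((relu_mask >> i) & 1u) ? grad_out[i] : 0.0f;
  assert(grad_in[0] == 1.0f && grad_in[1] == 0.0f);
  return 0;
}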
-+/// -+/// This base class is meant to define the concept required of the -+/// EpilogueWithBroadcast::OutputOp -+template < -+ typename ElementC_, -+ typename ElementAccumulator_, -+ typename ElementCompute_, -+ typename ElementZ_, -+ int ElementsPerAccess, -+ bool StoreT = true -+> -+class LinearCombinationBiasRelu { -+public: -+ -+ using ElementOutput = ElementC_; -+ using ElementC = ElementC_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ using ElementZ = ElementZ_; -+ -+ using ElementT = uint1b_t; -+ -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kCount = kElementsPerAccess; -+ -+ using ElementwiseOp = ReLu; -+ using BinaryOp = plus; -+ -+ // Indicates that this epilogue applies only one binary operation -+ static bool const kIsSingleSource = true; -+ -+ using FragmentAccumulator = Array; -+ using FragmentCompute = Array; -+ using FragmentC = Array; -+ using FragmentZ = Array; -+ using FragmentT = Array; -+ -+ /// If true, the 'Z' tensor is stored -+ static bool const kStoreZ = true; -+ -+ /// If true, the 'T' tensor is stored -+ static bool const kStoreT = StoreT; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ ElementZ threshold; ///< ReLu threshold -+ -+ // -+ // Methods -+ // -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute()), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr), -+ threshold(ElementCompute()) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta, -+ ElementCompute threshold_ = ElementCompute() -+ ): -+ alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ NumericConverter convert_threshold; -+ -+ threshold = convert_threshold(threshold_); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha -+ ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr), threshold(ElementZ()) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr, -+ ElementCompute threshold_ = ElementCompute() -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ NumericConverter convert_threshold; -+ -+ threshold = convert_threshold(threshold_); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr), threshold(ElementZ()) { -+ } -+ -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ ElementZ threshold_; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor from Params -+ CUTLASS_HOST_DEVICE -+ LinearCombinationBiasRelu(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); -+ threshold_ = params.threshold; -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ -+ if (k_partition != k_partition_count - 1) { -+ // set to NaN to make ReLU no-op for all except last k partitions -+ int64_t allones = -1; -+ threshold_ = reinterpret_cast(allones); -+ } -+ } -+ -+ /// Applies the operation when is_source_needed() is true -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentZ &frag_Z, -+ FragmentT &frag_T, -+ FragmentAccumulator const &AB, -+ FragmentC const &frag_C, -+ FragmentCompute const &V) const { -+ -+ BinaryOp binary_op; -+ -+ FragmentCompute tmp_Accum = NumericArrayConverter()(AB); -+ FragmentCompute tmp_C = NumericArrayConverter()(frag_C); -+ FragmentCompute result_Z; -+ -+ bool conditions[kElementsPerAccess]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kElementsPerAccess; ++i) { -+ -+ ElementCompute z = alpha_ * tmp_Accum[i]; -+ z += beta_ * tmp_C[i]; -+ -+ z = binary_op(z, V[i]); -+ result_Z[i] = z; -+ } -+ -+ NumericArrayConverter convert_z; -+ frag_Z = convert_z(result_Z); -+ -+ // -+ // Compute condition -+ // -+ -+ detail::ReluConditional relu_conditional; -+ relu_conditional(conditions, frag_Z, threshold_); -+ -+ detail::ArrayMaximum maximum_op; -+ frag_Z = maximum_op(frag_Z, threshold_); -+ -+ if (kStoreT) { -+ PackPredicates pack_predicates; -+ frag_T = pack_predicates(conditions); -+ } -+ } -+ -+ /// Applies the operation when is_source_needed() is false -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentZ &frag_Z, -+ FragmentT &frag_T, -+ FragmentAccumulator const &AB, -+ FragmentCompute const &V) const { -+ -+ BinaryOp binary_op; -+ -+ FragmentCompute tmp_Accum = NumericArrayConverter()(AB); -+ FragmentCompute result_Z; -+ -+ bool conditions[kElementsPerAccess]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kElementsPerAccess; ++i) { -+ ElementCompute z = binary_op(alpha_ * tmp_Accum[i], V[i]); -+ result_Z[i] = z; -+ } -+ -+ NumericArrayConverter convert_z; -+ frag_Z = convert_z(result_Z); -+ -+ // -+ // Compute condition -+ // -+ -+ detail::ReluConditional relu_conditional; -+ relu_conditional(conditions, frag_Z, threshold_); -+ -+ detail::ArrayMaximum maximum_op; -+ frag_Z = maximum_op(frag_Z, threshold_); -+ -+ // -+ // Compute conditions -+ // -+ -+ // -+ // Store -+ // -+ if (kStoreT) { -+ PackPredicates pack_predicates; -+ frag_T = pack_predicates(conditions); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_clamp.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_clamp.h -new file mode 100644 -index 0000000..fdfe171 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_clamp.h -@@ -0,0 +1,693 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear scaling operations used by epilogues. Values are clamped before -+ converting to the output element type. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Single source of truth for whether to unroll for `LinearCombinationClamp()` -+constexpr bool LinearCombinationClampIsHeavy() { -+ return false; -+} -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements then clamps the output before -+/// converting to the output element type. 
-+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+class LinearCombinationClamp { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ -+ static int const kCount = Count; -+ -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using ComputeFragment = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ static bool const kIsHeavy = detail::LinearCombinationClampIsHeavy(); -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta -+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha -+ ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationClamp(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ if (Scale == ScaleType::NoBetaScaling) return true; -+ -+ if (Scale == ScaleType::OnlyAlphaScaling) return false; -+ -+ if (Scale == ScaleType::Nothing) return false; -+ -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentOutput const &source, -+ ElementCompute uniform = ElementCompute(0)) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ ComputeFragment converted_source = source_converter(source); -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ -+ ComputeFragment intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ -+ minimum min_accumulator; -+ maximum max_accumulator; -+ -+ if (Scale == ScaleType::NoBetaScaling) { -+ intermediate = converted_source; -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } else if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } -+ -+ /// Clamping constant value -+ ElementCompute const kClampMax = -+ ElementCompute(platform::numeric_limits::max()); -+ -+ ElementCompute const kClampMin = -+ ElementCompute(platform::numeric_limits::lowest()); -+ -+ intermediate = max_accumulator(intermediate, kClampMin); -+ intermediate = min_accumulator(intermediate, kClampMax); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ -+ ComputeFragment intermediate; -+ -+ multiplies mul_accumulator; -+ -+ minimum min_accumulator; -+ maximum max_accumulator; -+ -+ if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ } -+ -+ /// Clamping constant value -+ ElementCompute const kClampMax = -+ ElementCompute(platform::numeric_limits::max()); -+ -+ ElementCompute const kClampMin = -+ ElementCompute(platform::numeric_limits::lowest()); -+ -+ intermediate = max_accumulator(intermediate, kClampMin); -+ intermediate = min_accumulator(intermediate, kClampMax); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+}; -+ 
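-+// Reference behaviour of the functor above (illustrative sketch, not part of the original header;
-+// float and std::numeric_limits stand in for ElementCompute and platform::numeric_limits):
-+//
-+//   template <typename OutT>
-+//   OutT linear_combination_clamp(float accumulator, OutT source, float alpha, float beta) {
-+//     float d = alpha * accumulator + beta * static_cast<float>(source);
-+//     d = std::max(d, static_cast<float>(std::numeric_limits<OutT>::lowest()));
-+//     d = std::min(d, static_cast<float>(std::numeric_limits<OutT>::max()));
-+//     return static_cast<OutT>(d);  // rounding is handled by NumericArrayConverter in the real code
-+//   }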
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conditional guards to enable partial specialization for packed integers -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) -+ -+/// Applies a linear combination operator to an array of elements then clamps the output before -+/// converting to the output element type. -+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ScaleType::Kind Scale, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round -+> -+class LinearCombinationClamp { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ static_assert( -+ platform::numeric_limits::is_integer, -+ "This elementwise op expects the output to be int."); -+ -+ static int const kCount = Count; -+ -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using ComputeFragment = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ static bool const kIsHeavy = detail::LinearCombinationClampIsHeavy(); -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta -+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha -+ ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationClamp(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ if (Scale == ScaleType::NoBetaScaling) return true; -+ -+ if (Scale == ScaleType::OnlyAlphaScaling) return false; -+ -+ if (Scale == ScaleType::Nothing) return false; -+ -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentOutput const &source, -+ ElementCompute uniform = ElementCompute(0)) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ ComputeFragment converted_source = source_converter(source); -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ // Compute linear scaling in floating point -+ ComputeFragment intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ -+ // Float min-max -+ if (Scale == ScaleType::NoBetaScaling) { -+ intermediate = converted_source; -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } else if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } -+ -+ // Convert floats back to INT -+ FragmentAccumulator scaled_accumulator; -+ -+ NumericArrayConverter compute_converter; -+ -+ scaled_accumulator = compute_converter(intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(scaled_accumulator); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()(FragmentAccumulator const &accumulator) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ // Compute linear scaling in floating point -+ ComputeFragment intermediate; -+ -+ multiplies mul_add_accumulator; -+ -+ // Float min-max -+ if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ } -+ -+ // Convert floats back to INT -+ FragmentAccumulator scaled_accumulator; -+ -+ NumericArrayConverter compute_converter; -+ -+ scaled_accumulator = compute_converter(intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(scaled_accumulator); -+ } -+}; -+ -+#endif // Conditional guards to enable partial specialization for packed integers -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements then clamps -+/// the output before converting to the output element type. 
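-+///
-+/// Worked example (illustrative, assuming ElementOutput = int8_t): sizeof_bits<int8_t>::value == 8,
-+/// so kClamp = 1 << (8 - 1) = 128 and the float intermediate is clamped to [-128, 127], i.e.
-+///
-+///   d = min(max(alpha * accumulator + beta * source, -128.0f), 127.0f);
-+///
-+/// before the fast float-to-integer conversion.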
-+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+/// Note: The below method only when problem_size_K <= 256 for signed int8 gemm -+/// or problem_size_K <= 128 for unsigned int8 gemm. The default approach is -+/// above. -+/// TODO: Add logic to fallback to the default approach -+template < -+ /// Data type used to load and store< tensors -+ typename ElementOutput_, -+ /// Number of elements computed per operation -+ int Count, -+ ///< Control Alpha and Beta scaling -+ ScaleType::Kind Scale = ScaleType::Default, -+ /// Rounding mode -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> -+class FastLinearCombinationClamp { -+ public: -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ static_assert( -+ platform::numeric_limits::is_integer, -+ "This elementwise op expects the output to be int."); -+ -+ static int const kCount = Count; -+ -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using ComputeFragment = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ static bool const kIsHeavy = false; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ /// scales accumulators -+ ElementCompute alpha; -+ /// scales source tensor -+ ElementCompute beta; -+ /// pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *alpha_ptr; -+ /// pointer to source scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params() -+ : alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(ElementCompute alpha, ElementCompute beta) -+ : alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(ElementCompute alpha) -+ : alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr) -+ : alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(ElementCompute const *alpha_ptr) -+ : alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ -+ public: -+ /// Constructs the function object, possibly loading from pointers in host -+ /// memory -+ CUTLASS_HOST_DEVICE -+ FastLinearCombinationClamp(Params const ¶ms) { -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ if (Scale == ScaleType::NoBetaScaling) return true; -+ -+ if (Scale == ScaleType::OnlyAlphaScaling) return false; -+ -+ if (Scale == ScaleType::Nothing) return false; -+ -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()(FragmentAccumulator const &accumulator, -+ FragmentOutput const &source, -+ ElementCompute uniform = ElementCompute(0)) const { -+ // Convert source to interal compute numeric type -+ FastNumericArrayConverter -+ source_converter; -+ FastNumericArrayConverter -+ accumulator_converter; -+ -+ ComputeFragment converted_source = source_converter(source); -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ // Compute linear scaling in floating point -+ ComputeFragment intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ -+ minimum min_accumulator; -+ maximum max_accumulator; -+ -+ // Float min-max -+ if (Scale == ScaleType::NoBetaScaling) { -+ intermediate = converted_source; -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } else if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = -+ mul_add_source(beta_, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, -+ intermediate); // D = alpha * Accum + X -+ } -+ -+ /// Clamping constant value -+ ElementCompute const kClamp = -+ ElementCompute(1 << (sizeof_bits::value - 1)); -+ -+ intermediate = max_accumulator(intermediate, -kClamp); -+ intermediate = min_accumulator(intermediate, kClamp - ElementCompute(1)); -+ -+ // Convert to destination numeric type -+ FastNumericArrayConverter -+ destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()(FragmentAccumulator const &accumulator) const { -+ -+ // Convert source to interal compute numeric type -+ FastNumericArrayConverter -+ accumulator_converter; -+ -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ // Compute linear scaling in floating point -+ ComputeFragment intermediate; -+ -+ multiplies mul_accumulator; -+ -+ minimum min_accumulator; -+ maximum max_accumulator; -+ -+ // Float min-max -+ if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_accumulator(alpha_, converted_accumulator); -+ } -+ -+ /// Clamping constant value -+ ElementCompute const kClamp = -+ ElementCompute(1 << (sizeof_bits::value - 1)); -+ -+ intermediate = max_accumulator(intermediate, -kClamp); -+ intermediate = min_accumulator(intermediate, kClamp - ElementCompute(1)); -+ -+ // Convert to destination numeric type -+ FastNumericArrayConverter -+ destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue 
-+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_dgelu.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_dgelu.h -new file mode 100644 -index 0000000..d026a8c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_dgelu.h -@@ -0,0 +1,250 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Functor performing linear combination followed by dGelu operation -+*/ -+ -+#pragma once -+ -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/constants.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/epilogue/thread/activation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements. 
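-+///
-+/// Illustrative per-element sketch (not part of the original header; scalars stand in for the
-+/// fragment types and 'dgelu_op' names an instance of the dGELU functor):
-+///
-+///   compute_t d = alpha * accumulator + beta * source;   // linear combination
-+///   d = dgelu_op(d, compute_t(tensor));                  // dGELU applied with the additional tensor operand
-+///
-+/// The result is returned as a compute-type fragment; no conversion to an output element type is
-+/// performed by this functor.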
-+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+template < -+ typename ElementCompute_, ///< Data type returned by this functor -+ typename ElementAccumulator_, ///< Data type of accumulators -+ typename ElementSource_, ///< Data type of source tensor -+ typename ElementTensor_, ///< Data type of additional tensor -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+class LinearCombinationDGelu { -+public: -+ -+ using ElementOutput = ElementSource_; -+ using ElementCompute = ElementCompute_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementSource = ElementSource_; -+ using ElementTensor = ElementTensor_; -+ -+ static bool const kIsHeavy = true; -+ -+ static int const kCount = Count; -+ -+ using FragmentCompute = Array; -+ using FragmentAccumulator = Array; -+ using FragmentSource = Array; -+ using FragmentTensor = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute threshold; ///< minimum value that is output -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ threshold(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta, -+ ElementCompute threshold = ElementCompute(0) -+ ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr, -+ ElementCompute threshold = ElementCompute(0) -+ ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ ElementCompute threshold_; -+ bool participates_in_reduction_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationDGelu(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); -+ threshold_ = params.threshold; -+ participates_in_reduction_ = true; -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Returns true if the threadblock computes the reduction -+ CUTLASS_HOST_DEVICE -+ bool participates_in_reduction() const { -+ return participates_in_reduction_; -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ -+ if (k_partition != k_partition_count - 1) { -+ // set to NaN to make ReLU no-op for all except last k partitions -+ int64_t allones = -1; -+ threshold_ = reinterpret_cast(allones); -+ // Avoid computing the reduction if this isn't the final Split-K slice -+ participates_in_reduction_ = false; -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentCompute operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentSource const &source, -+ FragmentTensor const &tensor) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_source = source_converter(source); -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ -+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ -+ dGELU gelu_op; -+ -+ // dGelu -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ intermediate[i] = gelu_op(intermediate[i], ElementCompute(tensor[i])); -+ } -+ -+ return intermediate; -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentCompute operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentTensor const &tensor) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_accumulator; -+ -+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ -+ dGELU gelu_op; -+ -+ // dGelu with conversion -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ intermediate[i] = gelu_op(intermediate[i], ElementCompute(tensor[i])); -+ } -+ -+ return intermediate; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_drelu.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_drelu.h -new file mode 100644 -index 0000000..f05da6d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_drelu.h -@@ -0,0 +1,452 @@ 
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination with a maximum operation used by epilogues. -+*/ -+ -+#pragma once -+ -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/epilogue/thread/activation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements. 
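-+///
-+/// Illustrative per-element sketch (not part of the original header; scalars stand in for the
-+/// fragment types):
-+///
-+///   compute_t d = alpha * accumulator + beta * source;   // linear combination
-+///   if (tensor <= threshold) { d = compute_t(); }        // dReLU: zero the result where the additional tensor fails the threshold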
-+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+template < -+ typename ElementCompute_, ///< Data type returned by this functor -+ typename ElementAccumulator_, ///< Data type of accumulators -+ typename ElementSource_, ///< Data type of source tensor -+ typename ElementTensor_, ///< Data type of additional tensor -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+class LinearCombinationDRelu { -+public: -+ -+ using ElementOutput = ElementSource_; -+ using ElementCompute = ElementCompute_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementSource = ElementSource_; -+ using ElementTensor = ElementTensor_; -+ -+ static int const kCount = Count; -+ -+ using FragmentCompute = Array; -+ using FragmentAccumulator = Array; -+ using FragmentSource = Array; -+ using FragmentTensor = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute threshold; ///< minimum value that is output -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ threshold(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta, -+ ElementCompute threshold = ElementCompute(0) -+ ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr, -+ ElementCompute threshold = ElementCompute(0) -+ ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ ElementTensor threshold_; -+ bool participates_in_reduction_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationDRelu(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); -+ threshold_ = ElementTensor(params.threshold); -+ participates_in_reduction_ = true; -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Returns true if the threadblock computes the reduction -+ CUTLASS_HOST_DEVICE -+ bool participates_in_reduction() const { -+ return participates_in_reduction_; -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ -+ if (k_partition != k_partition_count - 1) { -+ // set to NaN to make ReLU no-op for all except last k partitions -+ int64_t allones = -1; -+ threshold_ = reinterpret_cast(allones); -+ participates_in_reduction_ = false; -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentCompute operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentSource const &source, -+ FragmentTensor const &tensor) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_source = source_converter(source); -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ -+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ -+ // dReLU = (cond ? dy : 0) -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ ElementTensor cond = tensor[i]; -+ if (cond <= threshold_) { -+ intermediate[i] = ElementCompute(); -+ } -+ } -+ -+ return intermediate; -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentCompute operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentTensor const &tensor) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_accumulator; -+ -+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ -+ // dReLU = (cond ? dy : 0) -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ ElementTensor cond = tensor[i]; -+ if (cond <= threshold_) { -+ intermediate[i] = ElementCompute(); -+ } -+ } -+ -+ return intermediate; -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements. 
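-+///
-+/// Illustrative per-element sketch (not part of the original header): the additional tensor is a
-+/// fragment of packed one-bit predicates (uint1b_t), such as the 'T' output of a forward
-+/// bias + ReLU epilogue. After OR-ing with the internal predicate mask (all ones for non-final
-+/// split-K slices, so the masking is bypassed there) and unpacking:
-+///
-+///   compute_t d = alpha * accumulator + beta * source;   // linear combination
-+///   if (!condition) { d = compute_t(); }                 // dReLU via the stored predicate bit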
-+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+template < -+ typename ElementCompute_, ///< Data type returned by this functor -+ typename ElementAccumulator_, ///< Data type of accumulators -+ typename ElementSource_, ///< Data type of source tensor -+ int Count, ///< Number of elements computed per operation -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+class LinearCombinationDReluConditionalBits { -+public: -+ -+ using ElementOutput = ElementSource_; -+ using ElementCompute = ElementCompute_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementSource = ElementSource_; -+ using ElementTensor = uint1b_t; -+ -+ static bool const kIsHeavy = false; -+ -+ static int const kCount = Count; -+ -+ using FragmentCompute = Array; -+ using FragmentAccumulator = Array; -+ using FragmentSource = Array; -+ using FragmentTensor = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta -+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ FragmentTensor predicate_mask_; -+ bool participates_in_reduction_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationDReluConditionalBits(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); -+ participates_in_reduction_ = true; -+ predicate_mask_.clear(); -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Returns true if the threadblock computes the reduction -+ CUTLASS_HOST_DEVICE -+ bool participates_in_reduction() const { -+ return participates_in_reduction_; -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ predicate_mask_.clear(); -+ -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ -+ if (k_partition != k_partition_count - 1) { -+ // Avoid computing the reduction if this isn't the final Split-K slice -+ participates_in_reduction_ = false; -+ -+ bit_not not_op; -+ predicate_mask_ = not_op(predicate_mask_); -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_DEVICE -+ FragmentCompute operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentSource const &source, -+ FragmentTensor const &tensor) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_source = source_converter(source); -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ -+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ -+ bit_or or_op; -+ -+ FragmentTensor predicates = or_op(tensor, predicate_mask_); -+ -+ // Obtain from packed bits -+ bool conditions[kCount]; -+ UnpackPredicates unpack_predicates; -+ -+ unpack_predicates(conditions, predicates); -+ -+ // dReLU = (cond ? dy : 0) -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ if (!conditions[i]) { -+ intermediate[i] = ElementCompute(); -+ } -+ } -+ -+ return intermediate; -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentCompute operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentTensor const &tensor) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_accumulator; -+ -+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ -+ bit_or or_op; -+ -+ FragmentTensor predicates = or_op(tensor, predicate_mask_); -+ -+ // Obtain from packed bits -+ bool conditions[kCount]; -+ UnpackPredicates unpack_predicates; -+ -+ unpack_predicates(conditions, predicates); -+ -+ // dReLU = (cond ? 
dy : 0) -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ if (!conditions[i]) { -+ intermediate[i] = ElementCompute(); -+ } -+ } -+ -+ return intermediate; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_gelu.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_gelu.h -new file mode 100644 -index 0000000..0a68c16 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_gelu.h -@@ -0,0 +1,70 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination with GELU operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/thread/activation.h" -+#include "cutlass/epilogue/thread/linear_combination_generic.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator followed by the GELU activation to an array of elements. 
-+/// -+/// D = gelu(alpha * accumulator + beta * source + uniform) -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+using LinearCombinationGELU = LinearCombinationGeneric; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_generic.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_generic.h -new file mode 100644 -index 0000000..71ada3f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_generic.h -@@ -0,0 +1,207 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination operations used by epilogues. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/epilogue/thread/scale_type.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator followed by an activation function to an array of elements. -+/// -+/// D = activation(alpha * accumulator + beta * source + uniform) -+/// -+template < -+ template class ActivationFunctor, -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest, -+ bool IsHeavy = false -+> -+class LinearCombinationGeneric { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ -+ static bool const kIsHeavy = IsHeavy; -+ static int const kCount = Count; -+ static const ScaleType::Kind kScale = Scale; -+ -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using FragmentCompute = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ /// Host-constructable parameters structure -+ using Params = typename ActivationFunctor::Params; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Params params_; -+ bool skip_elementwise_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationGeneric(Params const ¶ms) { -+ params_ = params; -+ params_.alpha = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ params_.beta = (params.beta_ptr ? 
*params.beta_ptr : params.beta); -+ skip_elementwise_ = false; -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ if (Scale == ScaleType::NoBetaScaling) return true; -+ -+ if (Scale == ScaleType::OnlyAlphaScaling) return false; -+ -+ if (Scale == ScaleType::Nothing) return false; -+ -+ return params_.beta != ElementCompute(0); -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ params_.beta = ElementCompute(1); -+ } -+ -+ if (k_partition != k_partition_count - 1) { -+ skip_elementwise_ = true; -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentOutput const &source) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_source = source_converter(source); -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ -+ FragmentCompute intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ ActivationFunctor activation; -+ -+ if (Scale == ScaleType::NoBetaScaling) { -+ intermediate = converted_source; -+ intermediate = mul_add_accumulator(params_.alpha, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } else if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_add_source(params_.beta, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_accumulator(params_.alpha, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } -+ -+ intermediate = skip_elementwise_ ? intermediate : activation(intermediate, params_); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ -+ FragmentCompute intermediate; -+ -+ multiplies mul_add_accumulator; -+ ActivationFunctor activation; -+ -+ if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_add_accumulator(params_.alpha, converted_accumulator); // D = alpha * Accum -+ } -+ -+ intermediate = skip_elementwise_ ? 
intermediate : activation(intermediate, params_); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_hardswish.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_hardswish.h -new file mode 100644 -index 0000000..3bd4b89 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_hardswish.h -@@ -0,0 +1,69 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination with HardSwish operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/thread/activation.h" -+#include "cutlass/epilogue/thread/linear_combination_generic.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator followed by the HardSwish activation to an array of elements. 
-+/// -+/// D = hardswish(alpha * accumulator + beta * source + uniform) -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+using LinearCombinationHardSwish = LinearCombinationGeneric; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h -new file mode 100644 -index 0000000..ebee6b4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h -@@ -0,0 +1,230 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/epilogue/thread/activation.h" -+#include "cutlass/epilogue/thread/scale_type.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements. -+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+class LinearCombinationLeakyRelu { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ -+ static int const kCount = Count; -+ static const ScaleType::Kind kScale = Scale; -+ -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using ComputeFragment = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta_bias; ///< scales bias tensor -+ ElementCompute leaky_alpha; ///< leaky_alpha -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta_bias(ElementCompute(0)), -+ leaky_alpha(ElementCompute(1)) -+ { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta_bias, -+ ElementCompute leaky_alpha = ElementCompute(1) -+ ): alpha(alpha), beta_bias(beta_bias), leaky_alpha(leaky_alpha) { -+ -+ } -+ -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_bias_; -+ ElementCompute leaky_alpha_recip_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationLeakyRelu(Params const ¶ms) { -+ alpha_ = (params.alpha); -+ beta_bias_ = (params.beta_bias); -+ leaky_alpha_recip_ = (ElementCompute(params.leaky_alpha)); -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ if (Scale == ScaleType::NoBetaScaling) return true; -+ -+ if (Scale == ScaleType::OnlyAlphaScaling) return false; -+ -+ if (Scale == ScaleType::Nothing) return false; -+ -+ return beta_bias_ != ElementCompute(0); -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition) { -+ if (k_partition) { -+ beta_bias_ = ElementCompute(1); -+ } -+ } -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_bias_ = ElementCompute(1); -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ 
CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentOutput const &source) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ ComputeFragment converted_source = source_converter(source); -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ ComputeFragment intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ -+ LeakyReLU leakyrelu; -+ -+ if (Scale == ScaleType::NoBetaScaling) { -+ intermediate = converted_source; -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } else if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_add_source(beta_bias_, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } -+ // Compute threshold optionally -+ intermediate = leakyrelu(intermediate, leaky_alpha_recip_); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ ComputeFragment intermediate; -+ -+ multiplies mul_accumulator; -+ LeakyReLU leakyrelu; -+ //printf("in doing with bias"); -+ if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ } -+ -+ // Compute threshold optionally -+ intermediate = leakyrelu(intermediate, leaky_alpha_recip_); -+ -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_params.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_params.h -new file mode 100644 -index 0000000..a3f825e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_params.h -@@ -0,0 +1,75 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct LinearCombinationParams { -+ uint64_t alpha_data[2]; -+ uint64_t beta_data[2]; -+ -+ CUTLASS_HOST_DEVICE -+ LinearCombinationParams() -+ : alpha_data {0lu, 0lu}, beta_data {0lu, 0lu} -+ { } -+ -+ template -+ CUTLASS_HOST_DEVICE -+ LinearCombinationParams(ElementCompute alpha, ElementCompute beta) -+ : alpha_data {0lu, 0lu}, beta_data {0lu, 0lu} -+ { -+ #if defined(__CUDA_ARCH__) -+ reinterpret_cast(alpha_data) = alpha; -+ reinterpret_cast(beta_data) = beta; -+ #else -+ memcpy( alpha_data, &alpha, sizeof(ElementCompute) ); -+ memcpy( beta_data, &beta, sizeof(ElementCompute) ); -+ #endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_planar_complex.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_planar_complex.h -new file mode 100644 -index 0000000..005e301 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_planar_complex.h -@@ -0,0 +1,237 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination operations on planar-complex arrays -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/complex.h" -+#include "cutlass/array_planar_complex.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to arrays of planar-complex elements. -+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+/// Note, as with most CUTLASS components for planar complex, the template arguments describe -+/// the underlying real data type. 
-+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+class LinearCombinationPlanarComplex { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ -+ static int const kCount = Count; -+ -+ using FragmentOutput = ArrayPlanarComplex; -+ using FragmentAccumulator = ArrayPlanarComplex; -+ using ComputeFragment = ArrayPlanarComplex; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ complex alpha; ///< scales accumulators -+ complex beta; ///< scales source tensor -+ complex const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ complex const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ complex alpha, -+ complex beta -+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ complex const *alpha_ptr, -+ complex const *beta_ptr -+ ): alpha(complex()), beta(complex()), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ complex alpha_; -+ complex beta_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationPlanarComplex(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ return beta_.real() != ElementCompute(0) || beta_.imag() != ElementCompute(0); -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentOutput const &source) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ ComputeFragment converted_source( -+ source_converter(source.real), -+ source_converter(source.imag)); -+ -+ ComputeFragment converted_accumulator( -+ accumulator_converter(accumulator.real), -+ accumulator_converter(accumulator.imag)); -+ -+ // Perform binary operations -+ ComputeFragment intermediate; -+ -+ multiplies > mul_op; -+ multiply_add > mul_add_op; -+ -+ // complex multiply: I = beta * C -+ intermediate.real = mul_op(beta_.real(), converted_source.real); -+ intermediate.imag = mul_op(beta_.real(), converted_source.imag); -+ -+ intermediate.real = mul_add_op(-beta_.imag(), converted_source.imag, intermediate.real); -+ intermediate.imag = mul_add_op( beta_.imag(), converted_source.real, intermediate.imag); -+ -+ // complex multiply-add: I = alpha * AB + I -+ intermediate.real = mul_add_op(alpha_.real(), converted_accumulator.real, intermediate.real); -+ intermediate.imag = mul_add_op(alpha_.real(), converted_accumulator.imag, intermediate.imag); -+ -+ intermediate.real = mul_add_op(-alpha_.imag(), converted_accumulator.imag, intermediate.real); -+ intermediate.imag = mul_add_op( alpha_.imag(), converted_accumulator.real, intermediate.imag); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return FragmentOutput( -+ destination_converter(intermediate.real), -+ destination_converter(intermediate.imag)); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ ComputeFragment converted_accumulator( -+ accumulator_converter(accumulator.real), -+ accumulator_converter(accumulator.imag)); -+ -+ // Perform binary operations -+ ComputeFragment intermediate; -+ -+ multiplies > mul_op; -+ multiply_add > mul_add_op; -+ -+ // complex multiply-add: I = alpha * AB + I -+ intermediate.real = mul_add_op(alpha_.real(), converted_accumulator.real); -+ intermediate.imag = mul_add_op(alpha_.real(), converted_accumulator.imag); -+ -+ intermediate.real = mul_add_op(-alpha_.imag(), converted_accumulator.imag, intermediate.real); -+ intermediate.imag = mul_add_op( alpha_.imag(), converted_accumulator.real, intermediate.imag); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return FragmentOutput( -+ destination_converter(intermediate.real), -+ destination_converter(intermediate.imag)); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ 
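
For reference, the planar-complex epilogue above applies a complex-valued alpha and beta to fragments whose real and imaginary planes are stored as separate arrays, expanding each complex multiply into the four real multiply-adds visible in operator(). A minimal, standalone host-side sketch of the same arithmetic follows; the PlanarComplexFragment type, function name, and element count are illustrative stand-ins (not part of the patch), chosen only to mirror ArrayPlanarComplex and the functor's math.

// Sketch only: D = alpha * accumulator + beta * source with complex alpha/beta,
// real and imaginary planes stored separately (assumed names, not CUTLASS API).
#include <array>
#include <complex>
#include <cstdio>

template <typename T, int N>
struct PlanarComplexFragment {   // hypothetical stand-in for ArrayPlanarComplex
  std::array<T, N> real;
  std::array<T, N> imag;
};

template <typename T, int N>
PlanarComplexFragment<T, N> linear_combination_planar_complex(
    std::complex<T> alpha, std::complex<T> beta,
    PlanarComplexFragment<T, N> const &accum,
    PlanarComplexFragment<T, N> const &source) {
  PlanarComplexFragment<T, N> d;
  for (int i = 0; i < N; ++i) {
    // I = beta * C   (complex multiply expanded over the two planes)
    T re = beta.real() * source.real[i] - beta.imag() * source.imag[i];
    T im = beta.real() * source.imag[i] + beta.imag() * source.real[i];
    // I += alpha * AB (complex multiply-add)
    re += alpha.real() * accum.real[i] - alpha.imag() * accum.imag[i];
    im += alpha.real() * accum.imag[i] + alpha.imag() * accum.real[i];
    d.real[i] = re;
    d.imag[i] = im;
  }
  return d;
}

int main() {
  PlanarComplexFragment<float, 4> acc{{1, 2, 3, 4}, {0, 1, 0, 1}};
  PlanarComplexFragment<float, 4> src{{1, 1, 1, 1}, {1, 1, 1, 1}};
  auto d = linear_combination_planar_complex<float, 4>({2.f, 0.f}, {0.5f, 0.f}, acc, src);
  // With real alpha=2 and beta=0.5: real = 2*1 + 0.5*1 = 2.5, imag = 2*0 + 0.5*1 = 0.5
  std::printf("%f %f\n", d.real[0], d.imag[0]);
  return 0;
}

The device functor differs only in that it evaluates this per fragment element with CUTLASS's multiplies/multiply_add functors and converts the result back to the output element type via NumericArrayConverter.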
-+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_relu.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_relu.h -new file mode 100644 -index 0000000..eb1b436 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_relu.h -@@ -0,0 +1,570 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination with a maximum operation used by epilogues. -+*/ -+ -+#pragma once -+ -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/epilogue/thread/activation.h" -+#include "cutlass/epilogue/thread/scale_type.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Single source of truth for whether to unroll for `LinearCombinationClamp()` -+constexpr bool LinearCombinationReluIsHeavy() { -+ return false; -+} -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements. 
-+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+class LinearCombinationRelu { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ -+ static int const kCount = Count; -+ static const ScaleType::Kind kScale = Scale; -+ -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using FragmentCompute = Array; -+ using FragmentScaleBias = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ static bool const kIsHeavy = detail::LinearCombinationReluIsHeavy(); -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute threshold; ///< minimum value that is output -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ threshold(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta = ElementCompute(0), -+ ElementCompute threshold = ElementCompute(0) -+ ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr = nullptr, -+ ElementCompute threshold = ElementCompute(0) -+ ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ ElementCompute threshold_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationRelu(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); -+ threshold_ = params.threshold; -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ if (Scale == ScaleType::NoBetaScaling) return true; -+ -+ if (Scale == ScaleType::OnlyAlphaScaling) return false; -+ -+ if (Scale == ScaleType::OnlyAlphaPerChannelScaling) return false; -+ -+ if (Scale == ScaleType::Nothing) return false; -+ -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ -+ if (k_partition != k_partition_count - 1) { -+ // set to NaN to make ReLU no-op for all except last k partitions -+ int64_t allones = -1; -+ threshold_ = reinterpret_cast(allones); -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentOutput const &source) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_source = source_converter(source); -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ ReLu relu; -+ -+ if (Scale == ScaleType::NoBetaScaling) { -+ intermediate = converted_source; -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } else if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } -+ -+ // Compute threshold optionally -+ intermediate = relu(threshold_, intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_accumulator; -+ ReLu relu; -+ -+ if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ } -+ -+ // Compute threshold optionally -+ intermediate = relu(threshold_, intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+ -+ /// Computes per-channel linear scaling and bias : D = scale * accumulator + bias -+ /// Scale and Bias are from input Fragment -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentScaleBias const &scale, -+ FragmentScaleBias const &bias) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; 
-+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform per-channel scale and bias -+ FragmentCompute intermediate; -+ -+ multiply_add mul_add_accumulator; -+ -+ if(Scale == ScaleType::OnlyAlphaPerChannelScaling) -+ intermediate = mul_add_accumulator(scale, converted_accumulator, bias); // D = scale * Accum + bias -+ else -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias); // D = alpha * Accum + bias -+ -+ ReLu relu; -+ -+ // Compute threshold optionally -+ intermediate = relu(threshold_, intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conditional guards to enable partial specialization for packed integers -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) -+ -+/// Applies a linear combination operator to an array of elements. -+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+/// Special handling for int types -+ -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ScaleType::Kind Scale, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round -+> -+class LinearCombinationRelu { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ static bool const kIsHeavy = detail::LinearCombinationReluIsHeavy(); -+ -+ static int const kCount = Count; -+ static const ScaleType::Kind kScale = Scale; -+ -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using FragmentCompute = Array; -+ using FragmentScaleBias = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute threshold; ///< minimum value that is output -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ threshold(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta = ElementCompute(0), -+ ElementCompute threshold = ElementCompute(0) -+ ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr = nullptr, -+ ElementCompute threshold = ElementCompute(0) -+ ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ ElementCompute threshold_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationRelu(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? 
*params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); -+ threshold_ = params.threshold; -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ if (Scale == ScaleType::NoBetaScaling) return true; -+ -+ if (Scale == ScaleType::OnlyAlphaScaling) return false; -+ -+ if (Scale == ScaleType::OnlyAlphaPerChannelScaling) return false; -+ -+ if (Scale == ScaleType::Nothing) return false; -+ -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ -+ if (k_partition != k_partition_count - 1) { -+ // set to NaN to make ReLU no-op for all except last k partitions -+ int64_t allones = -1; -+ threshold_ = reinterpret_cast(allones); -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentOutput const &source) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_source = source_converter(source); -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ ReLu relu; -+ -+ if (Scale == ScaleType::NoBetaScaling) { -+ intermediate = converted_source; -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } else if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } -+ -+ // Compute threshold optionally -+ intermediate = relu(threshold_, intermediate); -+ -+ if (platform::numeric_limits::is_integer) { -+ // Convert floats back to INT -+ FragmentAccumulator scaled_accumulator; -+ -+ NumericArrayConverter compute_converter; -+ -+ scaled_accumulator = compute_converter(intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter -+ destination_converter; -+ -+ return destination_converter(scaled_accumulator); -+ } else { -+ NumericArrayConverter -+ destination_converter; -+ return destination_converter(intermediate); -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_accumulator; -+ ReLu relu; -+ -+ if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ } -+ -+ // Compute threshold optionally -+ intermediate = relu(threshold_, intermediate); -+ -+ if (platform::numeric_limits::is_integer) { -+ // Convert floats back to INT -+ FragmentAccumulator scaled_accumulator; -+ -+ 
NumericArrayConverter compute_converter; -+ -+ scaled_accumulator = compute_converter(intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter -+ destination_converter; -+ -+ return destination_converter(scaled_accumulator); -+ } else { -+ NumericArrayConverter -+ destination_converter; -+ return destination_converter(intermediate); -+ } -+ } -+ -+ /// Computes per-channel linear scaling and bias : D = scale * accumulator + bias -+ /// Scale and Bias are from input Fragment -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentScaleBias const &scale, -+ FragmentScaleBias const &bias) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform per-channel scale and bias -+ FragmentCompute intermediate; -+ -+ multiply_add mul_add_accumulator; -+ -+ if(Scale == ScaleType::OnlyAlphaPerChannelScaling) -+ intermediate = mul_add_accumulator(scale, converted_accumulator, bias); // D = scale * Accum + bias -+ else -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias); // D = alpha * Accum + bias -+ -+ ReLu relu; -+ -+ // Compute threshold optionally -+ intermediate = relu(threshold_, intermediate); -+ -+ if (platform::numeric_limits::is_integer) { -+ // Convert floats back to INT -+ FragmentAccumulator scaled_accumulator; -+ -+ NumericArrayConverter compute_converter; -+ -+ scaled_accumulator = compute_converter(intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter -+ destination_converter; -+ -+ return destination_converter(scaled_accumulator); -+ } else { -+ NumericArrayConverter -+ destination_converter; -+ return destination_converter(intermediate); -+ } -+ } -+}; -+ -+#endif // Conditional guards to enable partial specialization for packed integers -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_relu0.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_relu0.h -new file mode 100644 -index 0000000..3cffd93 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_relu0.h -@@ -0,0 +1,541 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination with a relu operation used by epilogues. -+ This one only supports relu0 and tries to folding relu into other instructions. Thus, -+ serial splitk is not supported by this one. For example, relu can be folded into -+ hfma2/hmul2 for sm80+ -+*/ -+ -+#pragma once -+ -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/epilogue/thread/activation.h" -+#include "cutlass/epilogue/thread/scale_type.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Single source of truth for whether to unroll for `LinearCombinationClamp()` -+constexpr bool LinearCombinationRelu0IsHeavy() { -+ return false; -+} -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements. 
-+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+class LinearCombinationRelu0 { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ -+ static int const kCount = Count; -+ static const ScaleType::Kind kScale = Scale; -+ -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using FragmentCompute = Array; -+ using FragmentScaleBias = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ static bool const kIsHeavy = detail::LinearCombinationRelu0IsHeavy(); -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta = ElementCompute(0) -+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr = nullptr -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationRelu0(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ if (Scale == ScaleType::NoBetaScaling) return true; -+ -+ if (Scale == ScaleType::OnlyAlphaScaling) return false; -+ -+ if (Scale == ScaleType::Nothing) return false; -+ -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// This is used for serial reduction which is not supported by Relu0 -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ assert(k_partition == 0); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentOutput const &source) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_source = source_converter(source); -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add_relu0 mul_add_relu0_accumulator; -+ ReLu relu; -+ -+ if (Scale == ScaleType::NoBetaScaling) { -+ intermediate = converted_source; -+ intermediate = mul_add_relu0_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } else if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ -+ // Compute threshold optionally -+ intermediate = relu(intermediate); -+ } else { -+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_relu0_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_accumulator; -+ ReLu relu; -+ -+ if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ } -+ -+ // Compute threshold optionally -+ intermediate = relu(intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+ -+ /// Computes per-channel linear scaling and bias : D = scale * accumulator + bias -+ /// Scale and Bias are from input Fragment -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentScaleBias const &scale, -+ FragmentScaleBias const &bias) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform per-channel scale and bias -+ FragmentCompute intermediate; -+ -+ multiply_add mul_add_accumulator; -+ -+ if(Scale == ScaleType::OnlyAlphaPerChannelScaling) -+ intermediate = mul_add_accumulator(scale, 
converted_accumulator, bias); // D = scale * Accum + bias -+ else -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias); // D = alpha * Accum + bias -+ -+ ReLu relu; -+ -+ // Compute threshold optionally -+ intermediate = relu(intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conditional guards to enable partial specialization for packed integers -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) -+ -+/// Applies a linear combination operator to an array of elements. -+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+/// Special handling for int types -+ -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ScaleType::Kind Scale, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round -+> -+class LinearCombinationRelu0 { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ static bool const kIsHeavy = detail::LinearCombinationRelu0IsHeavy(); -+ -+ static int const kCount = Count; -+ static const ScaleType::Kind kScale = Scale; -+ -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using FragmentCompute = Array; -+ using FragmentScaleBias = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta = ElementCompute(0) -+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr = nullptr -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationRelu0(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ if (Scale == ScaleType::NoBetaScaling) return true; -+ -+ if (Scale == ScaleType::OnlyAlphaScaling) return false; -+ -+ if (Scale == ScaleType::Nothing) return false; -+ -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// This is used for serial reduction which is not supported by Relu0 -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ assert(k_partition == 0); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentOutput const &source) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_source = source_converter(source); -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ ReLu relu; -+ -+ if (Scale == ScaleType::NoBetaScaling) { -+ intermediate = converted_source; -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } else if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } -+ -+ // Compute threshold optionally -+ intermediate = relu(intermediate); -+ -+ if (platform::numeric_limits::is_integer) { -+ // Convert floats back to INT -+ FragmentAccumulator scaled_accumulator; -+ -+ NumericArrayConverter compute_converter; -+ -+ scaled_accumulator = compute_converter(intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter -+ destination_converter; -+ -+ return destination_converter(scaled_accumulator); -+ } else { -+ NumericArrayConverter -+ destination_converter; -+ return destination_converter(intermediate); -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_accumulator; -+ ReLu relu; -+ -+ if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ } -+ -+ // Compute threshold optionally -+ intermediate = relu(intermediate); -+ -+ if (platform::numeric_limits::is_integer) { -+ // Convert floats back to INT -+ FragmentAccumulator scaled_accumulator; -+ -+ NumericArrayConverter compute_converter; -+ -+ scaled_accumulator = compute_converter(intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter -+ destination_converter; -+ -+ return destination_converter(scaled_accumulator); -+ } else { -+ NumericArrayConverter -+ destination_converter; -+ return destination_converter(intermediate); -+ } -+ } -+ -+ /// Computes per-channel linear 
scaling and bias : D = scale * accumulator + bias -+ /// Scale and Bias are from input Fragment -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentScaleBias const &scale, -+ FragmentScaleBias const &bias) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform per-channel scale and bias -+ FragmentCompute intermediate; -+ -+ multiply_add mul_add_accumulator; -+ -+ if(Scale == ScaleType::OnlyAlphaPerChannelScaling) -+ intermediate = mul_add_accumulator(scale, converted_accumulator, bias); // D = scale * Accum + bias -+ else -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias); // D = alpha * Accum + bias -+ -+ ReLu relu; -+ -+ // Compute threshold optionally -+ intermediate = relu(intermediate); -+ -+ if (platform::numeric_limits::is_integer) { -+ // Convert floats back to INT -+ FragmentAccumulator scaled_accumulator; -+ -+ NumericArrayConverter compute_converter; -+ -+ scaled_accumulator = compute_converter(intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter -+ destination_converter; -+ -+ return destination_converter(scaled_accumulator); -+ } else { -+ NumericArrayConverter -+ destination_converter; -+ return destination_converter(intermediate); -+ } -+ } -+}; -+ -+#endif // Conditional guards to enable partial specialization for packed integers -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_residual_block.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_residual_block.h -new file mode 100644 -index 0000000..7c47c24 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_residual_block.h -@@ -0,0 +1,302 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
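
Aside: the `LinearCombinationRelu0` functor added above fuses the ReLU clamp into the final multiply-add of the linear combination. A minimal scalar sketch of the per-element computation it performs (illustrative names only, not CUTLASS code; the real functor operates on `Array<>` fragments and uses `multiply_add_relu0`):

```cpp
// Scalar reference for D = relu(alpha * accumulator + beta * source).
#include <algorithm>
#include <cstdio>

float linear_combination_relu0(float accum, float source, float alpha, float beta) {
  float x = beta * source;      // X = beta * C
  float d = alpha * accum + x;  // D = alpha * Accum + X
  return std::max(d, 0.0f);     // fused ReLU clamps negatives to zero
}

int main() {
  // With beta == 0 the source operand is not needed, matching is_source_needed()
  // returning false for the OnlyAlphaScaling / Nothing scale kinds.
  std::printf("%f\n", linear_combination_relu0(-2.5f, 4.0f, 1.0f, 0.5f)); // 0.0
  std::printf("%f\n", linear_combination_relu0( 2.5f, 4.0f, 1.0f, 0.5f)); // 4.5
}
```
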
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Epilogue functor specialized for residual blocks in deep neural networks. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+namespace detail { -+ -+/// Dummy class used to designate that the second binary operator in the epilogue is unsued -+template -+class NoOp {}; -+ -+} -+ -+/// Models a residual block of the form: UnaryOp(BinaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual1), residual2)) -+template class ActivationOp_, -+ template class BinaryOp1_, -+ template class UnaryOp_, -+ template class BinaryOp2_ = detail::NoOp> -+class LinearCombinationResidualBlock { -+public: -+ static bool const kIsSingleSource = false; -+ -+ using ElementOutput = ElementC_; -+ using ElementC = ElementC_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kCount = kElementsPerAccess; -+ -+ using UnaryOp = UnaryOp_>; -+ using BinaryOp1 = BinaryOp1_>; -+ using BinaryOp2 = BinaryOp2_>; -+ using ActivationOp = ActivationOp_>; -+ -+ using FragmentAccumulator = Array; -+ using FragmentCompute = Array; -+ using FragmentC = Array; -+ using FragmentOutput = Array; -+ -+ using ElementZ = ElementOutput_; -+ using ElementT = ElementZ; -+ using FragmentZ = Array; -+ using FragmentT = Array; -+ -+ static bool const kIsHeavy = true; -+ static bool const kStoreZ = true; -+ static bool const kStoreT = false; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales residual input -+ ElementCompute const *alpha_ptr{nullptr}; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr{nullptr}; ///< pointer to residual scalar - if not null, loads it from memory -+ -+ CUTLASS_HOST_DEVICE -+ Params() : alpha(ElementCompute(1)), beta(ElementCompute(1)) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(ElementCompute alpha, ElementCompute beta) -+ : alpha(alpha), beta(beta) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr) -+ : alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {} -+ }; -+ -+private: -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ bool skip_elementwise_; -+ -+public: -+ -+ /// Constructor from Params -+ CUTLASS_HOST_DEVICE -+ LinearCombinationResidualBlock(Params const ¶ms) { -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); -+ skip_elementwise_ = false; -+ } -+ -+ /// The "source" tensor corresponds to the residual input -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { return true; } -+ -+ /// Functionally required for serial reduction in the epilogue -+ /// IMPORTANT: Split-k is supported only when ActivationOp is Identity. -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ -+ if (k_partition != k_partition_count - 1) { -+ skip_elementwise_ = true; -+ } -+ } -+ -+ /// Applies the operation UnaryOp(BinaryOp(BinaryOp(ActivationOp(AB + bias), residual1), residual2)) -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB, -+ FragmentC const &residual1, FragmentC const &residual2, -+ FragmentCompute const &bias) const { -+ UnaryOp unary_op; -+ BinaryOp1 binary_op1; -+ BinaryOp2 binary_op2; -+ ActivationOp activation; -+ -+ FragmentCompute tmp_Accum = -+ NumericArrayConverter()(AB); -+ FragmentCompute tmp_residual1 = -+ NumericArrayConverter()(residual1); -+ FragmentCompute tmp_residual2 = -+ NumericArrayConverter()(residual2); -+ -+ FragmentCompute z = -+ binary_op2(binary_op1(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual1), beta_ * tmp_residual2); -+ FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z); -+ -+ NumericArrayConverter convert_z; -+ frag_Z = convert_z(result_Z); -+ } -+ -+ /// Should never be called -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentOutput &, FragmentOutput &, FragmentAccumulator const &, -+ FragmentCompute const &) const {} -+}; -+ -+/// Models a residual block of the form: UnaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual)) -+template class ActivationOp_, -+ template class BinaryOp1_, -+ template class UnaryOp_> -+class LinearCombinationResidualBlock { -+public: -+ static bool const kIsSingleSource = true; -+ -+ using ElementOutput = ElementC_; -+ using ElementC = ElementC_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kCount = kElementsPerAccess; -+ -+ using UnaryOp = UnaryOp_>; -+ using BinaryOp = BinaryOp1_>; -+ using ActivationOp = ActivationOp_>; -+ -+ using FragmentAccumulator = Array; -+ using FragmentCompute = Array; -+ using FragmentC = Array; -+ using FragmentOutput = Array; -+ -+ using ElementZ = ElementOutput_; -+ using ElementT = ElementZ; -+ using FragmentZ = Array; -+ using FragmentT = Array; -+ -+ static bool const kIsHeavy = true; -+ static bool const kStoreZ = true; -+ static bool const kStoreT = false; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales residual input -+ ElementCompute const *alpha_ptr{nullptr}; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr{nullptr}; ///< pointer to residual scalar - if not null, loads it from memory -+ -+ CUTLASS_HOST_DEVICE -+ Params() : alpha(ElementCompute(1)), beta(ElementCompute(1)) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(ElementCompute alpha, ElementCompute beta) -+ : alpha(alpha), beta(beta) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr) -+ : alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {} -+ }; -+ -+private: -+ -+ ElementCompute 
alpha_; -+ ElementCompute beta_; -+ bool skip_elementwise_; -+ -+public: -+ -+ /// Constructor from Params -+ CUTLASS_HOST_DEVICE -+ LinearCombinationResidualBlock(Params const ¶ms) { -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); -+ skip_elementwise_ = false; -+ } -+ -+ /// The "source" tensor corresponds to the residual input -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { return true; } -+ -+ /// Functionally required for serial reduction in the epilogue -+ /// IMPORTANT: Split-k is supported only when ActivationOp is Identity. -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ -+ if (k_partition != k_partition_count - 1) { -+ skip_elementwise_ = true; -+ } -+ } -+ -+ /// Applies the operation UnaryOp(BinaryOp(ActivationOp(AB + bias), residual)) -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB, -+ FragmentC const &residual, -+ FragmentCompute const &bias) const { -+ UnaryOp unary_op; -+ BinaryOp binary_op; -+ ActivationOp activation; -+ -+ FragmentCompute tmp_Accum = -+ NumericArrayConverter()(AB); -+ FragmentCompute tmp_residual = -+ NumericArrayConverter()(residual); -+ -+ FragmentCompute z = -+ binary_op(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual); -+ FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z); -+ -+ NumericArrayConverter convert_z; -+ frag_Z = convert_z(result_Z); -+ } -+ -+ /// Should never be called -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentOutput &, FragmentOutput &, FragmentAccumulator const &, -+ FragmentCompute const &) const {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_sigmoid.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_sigmoid.h -new file mode 100644 -index 0000000..c449d23 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_sigmoid.h -@@ -0,0 +1,70 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
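
Aside: the residual-block functor above evaluates `UnaryOp(BinaryOp(ActivationOp(alpha * AB + bias), beta * residual))` (with a second residual and `BinaryOp2` in the two-source form). A minimal scalar sketch of that shape under assumed names, not part of the patch:

```cpp
// Single-source residual-block epilogue, modelled per element.
#include <algorithm>
#include <cstdio>

template <class ActivationOp, class BinaryOp, class UnaryOp>
float residual_block(float AB, float residual, float bias,
                     float alpha, float beta,
                     ActivationOp act, BinaryOp bin, UnaryOp unary,
                     bool skip_elementwise = false) {
  float z = bin(act(alpha * AB + bias), beta * residual);
  return skip_elementwise ? z : unary(z);  // non-final split-K slices skip UnaryOp
}

int main() {
  auto relu     = [](float x) { return std::max(x, 0.0f); };
  auto plus     = [](float a, float b) { return a + b; };
  auto identity = [](float x) { return x; };
  // Z = relu(1.0 * AB + bias) + 1.0 * residual
  std::printf("%f\n", residual_block(2.0f, -0.5f, 0.25f, 1.0f, 1.0f, relu, plus, identity));
}
```
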
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination with Sigmoid operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/thread/activation.h" -+#include "cutlass/epilogue/thread/linear_combination_generic.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator followed by the Sigmoid activation, to an array of elements. -+/// -+/// D = sigmoid(alpha * accumulator + beta * source + uniform) -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+using LinearCombinationSigmoid = LinearCombinationGeneric; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_silu.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_silu.h -new file mode 100644 -index 0000000..222f6de ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_silu.h -@@ -0,0 +1,69 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination with SiLU operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/thread/activation.h" -+#include "cutlass/epilogue/thread/linear_combination_generic.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator folllowed by the SiLU activation to an array of elements. -+/// -+/// D = silu(alpha * accumulator + beta * source + uniform) -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+using LinearCombinationSilu = LinearCombinationGeneric; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_with_elementwise.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_with_elementwise.h -new file mode 100644 -index 0000000..aac19b0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_with_elementwise.h -@@ -0,0 +1,234 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
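
Aside: the Sigmoid and SiLU headers above are thin aliases of `LinearCombinationGeneric` that apply the activation after the linear combination. A scalar reference of what each computes per element (illustrative sketch only):

```cpp
// D = act(alpha * accumulator + beta * source), with act = sigmoid or SiLU.
#include <cmath>
#include <cstdio>

float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }
float silu(float x)    { return x * sigmoid(x); }  // SiLU(x) = x * sigmoid(x)

float epilogue_sigmoid(float accum, float source, float alpha, float beta) {
  return sigmoid(alpha * accum + beta * source);
}
float epilogue_silu(float accum, float source, float alpha, float beta) {
  return silu(alpha * accum + beta * source);
}

int main() {
  std::printf("%f %f\n", epilogue_sigmoid(2.0f, 1.0f, 1.0f, 0.0f),
                         epilogue_silu(2.0f, 1.0f, 1.0f, 0.0f));
}
```
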
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Functor performing linear combination with elementwise -+*/ -+ -+#pragma once -+ -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/constants.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/epilogue/thread/activation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements. 
-+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+template < -+ typename ElementCompute_, ///< Data type returned by this functor -+ typename ElementAccumulator_, ///< Data type of accumulators -+ typename ElementSource_, ///< Data type of source tensor -+ typename ElementTensor_, ///< Data type of additional tensor -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+class LinearCombinationWithElementwise { -+public: -+ -+ using ElementOutput = ElementSource_; -+ using ElementCompute = ElementCompute_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementSource = ElementSource_; -+ using ElementTensor = ElementTensor_; -+ -+ static bool const kIsHeavy = true; -+ -+ static int const kCount = Count; -+ -+ using FragmentCompute = Array; -+ using FragmentAccumulator = Array; -+ using FragmentSource = Array; -+ using FragmentTensor = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute threshold; ///< minimum value that is output -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ threshold(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta, -+ ElementCompute threshold = ElementCompute(0) -+ ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr, -+ ElementCompute threshold = ElementCompute(0) -+ ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ ElementCompute threshold_; -+ bool participates_in_reduction_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationWithElementwise(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); -+ threshold_ = params.threshold; -+ participates_in_reduction_ = true; -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Returns true if the threadblock computes the reduction -+ CUTLASS_HOST_DEVICE -+ bool participates_in_reduction() const { -+ return participates_in_reduction_; -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ -+ if (k_partition != k_partition_count - 1) { -+ // set to NaN to make ReLU no-op for all except last k partitions -+ int64_t allones = -1; -+ threshold_ = reinterpret_cast(allones); -+ // Avoid computing the reduction if this isn't the final Split-K slice -+ participates_in_reduction_ = false; -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentCompute operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentSource const &source, -+ FragmentTensor const &tensor) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_source = source_converter(source); -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ -+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ -+ return intermediate; -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentCompute operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentTensor const &tensor) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_accumulator; -+ -+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ -+ return intermediate; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/reduction_op.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/reduction_op.h -new file mode 100644 -index 0000000..f904856 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/reduction_op.h -@@ -0,0 +1,97 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
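
Aside: `set_k_partition` in the functor above encodes the serial split-K convention: slice 0 keeps the user's beta, later slices force beta to 1 so they accumulate onto the partial result already written, and only the final slice participates in the elementwise/reduction step. A small sketch of that state transition (assumed names, not CUTLASS code):

```cpp
#include <cstdio>

struct SplitKState {
  float beta;
  bool participates_in_reduction;
};

SplitKState set_k_partition(float user_beta, int k_partition, int k_partition_count) {
  SplitKState s{user_beta, true};
  if (k_partition) s.beta = 1.0f;          // accumulate onto prior partial sums
  if (k_partition != k_partition_count - 1)
    s.participates_in_reduction = false;   // only the final slice reduces
  return s;
}

int main() {
  for (int k = 0; k < 3; ++k) {
    SplitKState s = set_k_partition(0.0f, k, 3);
    std::printf("slice %d: beta=%g reduce=%d\n", k, s.beta, s.participates_in_reduction);
  }
}
```
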
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing reduction operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a reduction sum to an array of elements. -+/// -+/// -+template < -+ typename Element_, ///< Data type used to load and store tensors -+ int Count ///< Number of elements computed per operation -+> -+class ReductionOpPlus { -+public: -+ -+ using Element = Element_; -+ static int const kCount = Count; -+ -+ using Fragment = Array; -+ using Operator = plus; -+ -+ /// Host-constructable parameters structure -+ struct Params { }; -+ -+private: -+ -+ /// reduction operator -+ Operator operator_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ ReductionOpPlus(Params const ¶ms) { -+ -+ } -+ -+ /// Computes Compute => -+ CUTLASS_HOST_DEVICE -+ Fragment operator()( -+ Fragment const &lhs, -+ Fragment const &rhs) const { -+ -+ return operator_(lhs, rhs); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/scale_type.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/scale_type.h -new file mode 100644 -index 0000000..f229927 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/scale_type.h -@@ -0,0 +1,62 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Enum defines the behaviors of the epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specifies internal data type for computation -+struct ScaleType { -+ enum Kind { -+ Default, // alpha x C + beta x D -+ NoBetaScaling, // alpha x C + D -+ OnlyAlphaScaling, // alpha x C -+ OnlyAlphaPerChannelScaling, // alpha_vec x C -+ Nothing // C -+ }; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h -new file mode 100644 -index 0000000..1b25816 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h -@@ -0,0 +1,255 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
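
Aside: the `ScaleType` kinds introduced above select which terms of the linear combination are actually applied. A scalar sketch of the mapping implied by the enum comments (illustrative only; per-channel scaling is omitted since it takes a vector of alphas):

```cpp
#include <cstdio>

enum class ScaleKind { Default, NoBetaScaling, OnlyAlphaScaling, Nothing };

float apply_scale(ScaleKind kind, float accum, float source, float alpha, float beta) {
  switch (kind) {
    case ScaleKind::Default:          return alpha * accum + beta * source;
    case ScaleKind::NoBetaScaling:    return alpha * accum + source;  // source added unscaled
    case ScaleKind::OnlyAlphaScaling: return alpha * accum;           // source ignored
    case ScaleKind::Nothing:          return accum;                   // pass-through
  }
  return 0.0f;
}

int main() {
  std::printf("%f\n", apply_scale(ScaleKind::NoBetaScaling, 2.0f, 3.0f, 0.5f, 0.25f)); // 4.0
}
```
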
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped complex GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_relu.h" -+#include "cutlass/epilogue/thread/linear_combination_gelu.h" -+#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" -+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -+ -+#include "cutlass/epilogue/thread/conversion_op.h" -+#include "cutlass/epilogue/thread/reduction_op.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" -+ -+#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" -+#include "cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Specialization and defines sensible defaults for epilogues for complex*complex case -+// 4 real-valued mma operations (Complex) -+// A = (ar + j ai), B (br +j bi), D = AB -+// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Epilouge Shape -+ typename Shape_, -+ /// Warp-level mma operator -+ typename WarpMmaTensorOp_, -+ /// Number of k partitions -+ int PartitionsK, -+ /// 
Epilogue output operator -+ typename OutputOp_, -+ /// Elements accessed by inner-most loop of AccumulatorFragmentIterator::load() -+ int ElementsPerAccess, -+ /// Multiply-add operator -+ /// Selects between (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator_ = arch::OpMultiplyAddComplex -+> -+struct DefaultEpilogueComplexTensorOp { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ using Operator = Operator_; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ OutputTileThreadMap, -+ ElementOutput -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ ElementAccumulator, -+ LayoutC -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ typename OutputTileThreadMap::CompactedThreadMap, -+ ElementAccumulator -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = cutlass::MatrixShape<0, 0>; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization and defines sensible defaults for epilogues for complex*complex case -+// 3 real-valued mma operations (Gaussian Complex) -+// A = (ar + j ai), B = (br +j bi), D = AB -+// P1 = (ar + ai) * br, P2 = - ar * (br - bi), P3 = ai * (br + bi) -+// D = dr + j di = (P1 - P3) + j (P1 + P2) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ typename Shape_, -+ typename WarpMmaTensorOp_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpilogueComplexTensorOp { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ using Operator = arch::OpMultiplyAddGaussianComplex; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< -+ Shape, -+ typename 
WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ OutputTileThreadMap, -+ ElementOutput -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorGaussianComplexTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ ElementAccumulator, -+ LayoutC -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ typename OutputTileThreadMap::CompactedThreadMap, -+ ElementAccumulator -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = cutlass::MatrixShape<0, 0>; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op_blas3.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op_blas3.h -new file mode 100644 -index 0000000..966d44c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op_blas3.h -@@ -0,0 +1,264 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
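
Aside: the Gaussian-complex specialization above relies on forming the complex product with three real multiplies instead of four, exactly as documented in its header comment (P1 = (ar + ai)·br, P2 = -ar·(br - bi), P3 = ai·(br + bi), D = (P1 - P3) + j·(P1 + P2)). A standalone check of that identity (not CUTLASS code):

```cpp
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  double ar = 1.5, ai = -2.0, br = 0.75, bi = 3.0;

  // 3-multiply (Gaussian) form.
  double p1 = (ar + ai) * br;
  double p2 = -ar * (br - bi);
  double p3 = ai * (br + bi);
  double dr = p1 - p3;  // real part
  double di = p1 + p2;  // imaginary part

  // Reference 4-multiply form used by the general complex specialization.
  double ref_r = ar * br - ai * bi;
  double ref_i = ar * bi + ai * br;

  assert(std::fabs(dr - ref_r) < 1e-12 && std::fabs(di - ref_i) < 1e-12);
  std::printf("D = %g + j%g\n", dr, di);
}
```
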
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped complex GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_relu.h" -+#include "cutlass/epilogue/thread/linear_combination_gelu.h" -+#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" -+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -+ -+#include "cutlass/epilogue/thread/conversion_op.h" -+#include "cutlass/epilogue/thread/reduction_op.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" -+ -+#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" -+#include "cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Specialization and defines sensible defaults for epilogues for complex*complex case -+// 4 real-valued mma operations (Complex) -+// A = (ar + j ai), B (br +j bi), D = AB -+// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Epilouge Shape -+ typename Shape_, -+ /// Warp-level mma operator -+ typename WarpMmaTensorOp_, -+ /// Number of k partitions -+ int PartitionsK, -+ /// Epilogue output operator -+ typename OutputOp_, -+ /// Elements accessed by inner-most loop of AccumulatorFragmentIterator::load() -+ int ElementsPerAccess, -+ /// Multiply-add operator -+ /// Selects between (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator_ = arch::OpMultiplyAddComplex, -+ /// Is for a symmetric kernel -+ BlasMode BlasMode_ = BlasMode::kGemm -+> -+struct DefaultEpilogueComplexTensorOpBlas3 { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ using Operator = 
Operator_; -+ static BlasMode const kBlasMode = BlasMode_; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorBlas3< -+ OutputTileThreadMap, -+ ElementOutput -+ , kBlasMode -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ ElementAccumulator, -+ LayoutC -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ typename OutputTileThreadMap::CompactedThreadMap, -+ ElementAccumulator -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = cutlass::MatrixShape<0, 0>; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization and defines sensible defaults for epilogues for complex*complex case -+// 3 real-valued mma operations (Gaussian Complex) -+// A = (ar + j ai), B = (br +j bi), D = AB -+// P1 = (ar + ai) * br, P2 = - ar * (br - bi), P3 = ai * (br + bi) -+// D = dr + j di = (P1 - P3) + j (P1 + P2) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ typename Shape_, -+ typename WarpMmaTensorOp_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess, -+ BlasMode BlasMode_ -+> -+struct DefaultEpilogueComplexTensorOpBlas3 { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ using Operator = arch::OpMultiplyAddGaussianComplex; -+ static BlasMode const kBlasMode = BlasMode_; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorBlas3< -+ OutputTileThreadMap, -+ ElementOutput, -+ kBlasMode -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorGaussianComplexTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename 
WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ ElementAccumulator, -+ LayoutC -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ typename OutputTileThreadMap::CompactedThreadMap, -+ ElementAccumulator -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = cutlass::MatrixShape<0, 0>; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_direct_store.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_direct_store.h -new file mode 100644 -index 0000000..fc93eb0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_direct_store.h -@@ -0,0 +1,74 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Direct store epilogue -+*/ -+ -+#pragma once -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#include "cutlass/epilogue/threadblock/epilogue_direct_store.h" -+#include "cutlass/epilogue/threadblock/direct_store_epilogue_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Given a properly constructed epilogue, returns a direct store epilogue -+template -+struct DefaultEpilogueDirectStore { -+ -+ using OutputTileIterator = DirectStoreEpilogueIterator; -+ -+ using Epilogue = EpilogueDirectStore< -+ typename EpilogueTensorOp::Shape, -+ typename EpilogueTensorOp::WarpMmaOperator, -+ EpilogueTensorOp::kPartitionsK, -+ OutputTileIterator, -+ typename EpilogueTensorOp::AccumulatorFragmentIterator, -+ typename EpilogueTensorOp::WarpTileIterator, -+ typename EpilogueTensorOp::SharedLoadIterator, -+ typename EpilogueTensorOp::OutputOp -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_planar_complex.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_planar_complex.h -new file mode 100644 -index 0000000..872e425 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_planar_complex.h -@@ -0,0 +1,241 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
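As a side note on the DefaultEpilogueComplexTensorOpBlas3 specializations above: the three-multiplication Gaussian-complex identity documented there (P1 = (ar + ai) * br, P2 = -ar * (br - bi), P3 = ai * (br + bi), D = (P1 - P3) + j (P1 + P2)) can be checked against the ordinary four-multiplication complex product with a short standalone C++ program. This is a minimal sketch for clarity only; the operand values are arbitrary and nothing in it comes from the patch itself.

#include <cassert>
#include <cmath>
#include <complex>

int main() {
  // Arbitrary sample operands A = ar + j*ai and B = br + j*bi.
  double ar = 1.5, ai = -2.0, br = 0.5, bi = 3.0;

  // Gaussian (Karatsuba-style) complex multiply: three real multiplications.
  double p1 = (ar + ai) * br;
  double p2 = -ar * (br - bi);
  double p3 = ai * (br + bi);
  std::complex<double> gaussian(p1 - p3, p1 + p2);

  // Reference: ordinary four-multiplication complex product.
  std::complex<double> reference =
      std::complex<double>(ar, ai) * std::complex<double>(br, bi);

  // (P1 - P3) == ar*br - ai*bi and (P1 + P2) == ar*bi + ai*br, so both agree.
  assert(std::abs(gaussian - reference) < 1e-12);
  return 0;
}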
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Constructs a default epilogue for planar complex outputs. -+ -+ This template reuses components for real-valued epilogues and applies them to planar complex -+ output matrices. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/array_planar_complex.h" -+ -+#include "cutlass/arch/arch.h" -+ -+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_planar_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues. -+template < -+ typename ThreadblockShape_, -+ typename WarpMma_, -+ typename OpcodeClass_, -+ typename ArchTag_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpiloguePlanarComplex; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues. -+template < -+ typename ThreadblockShape_, -+ typename WarpMmaOperator_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpiloguePlanarComplex< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ arch::OpClassTensorOp, -+ arch::Sm70, -+ PartitionsK, -+ OutputOp_, -+ ElementsPerAccess> { -+ -+ using RealEpilogue = DefaultEpilogueVoltaTensorOp< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ PartitionsK, -+ OutputOp_, -+ ElementsPerAccess -+ >; -+ -+ using Epilogue = EpiloguePlanarComplex< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ PartitionsK, -+ typename RealEpilogue::OutputTileIterator, -+ typename RealEpilogue::AccumulatorFragmentIterator, -+ typename RealEpilogue::WarpTileIterator, -+ typename RealEpilogue::SharedLoadIterator, -+ OutputOp_, -+ typename RealEpilogue::Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues. -+template < -+ typename ThreadblockShape_, -+ typename WarpMmaOperator_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpiloguePlanarComplex< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ PartitionsK, -+ OutputOp_, -+ ElementsPerAccess> { -+ -+ using RealEpilogue = DefaultEpilogueTensorOp< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ PartitionsK, -+ OutputOp_, -+ ElementsPerAccess -+ >; -+ -+ using Epilogue = EpiloguePlanarComplex< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ PartitionsK, -+ typename RealEpilogue::OutputTileIterator, -+ typename RealEpilogue::AccumulatorFragmentIterator, -+ typename RealEpilogue::WarpTileIterator, -+ typename RealEpilogue::SharedLoadIterator, -+ OutputOp_, -+ typename RealEpilogue::Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues. 
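For readers unfamiliar with the term used in the \brief above: a planar complex matrix stores its real and imaginary parts as two separate real-valued planes (typically one allocation with a fixed element offset between the planes) rather than as interleaved (real, imag) pairs, which is why this header can reuse the real-valued default epilogues once per plane. The sketch below illustrates only that storage convention; the struct and member names are invented for illustration and are not CUTLASS APIs.

#include <cstddef>
#include <vector>

// Toy stand-in for a planar complex tensor reference: one base pointer for the
// real plane plus an element offset to the matching imaginary plane.
struct PlanarComplexRef {
  float *real;                      // base of the real plane
  std::ptrdiff_t imaginary_offset;  // element offset from real plane to imaginary plane

  float &real_at(std::size_t linear_idx) { return real[linear_idx]; }
  float &imag_at(std::size_t linear_idx) { return real[linear_idx + imaginary_offset]; }
};

int main() {
  const std::size_t rows = 4, cols = 4, plane = rows * cols;
  std::vector<float> storage(2 * plane, 0.0f);  // [ real plane | imaginary plane ]

  PlanarComplexRef D{storage.data(), static_cast<std::ptrdiff_t>(plane)};
  D.real_at(0) = 1.0f;  // element (0,0), real part
  D.imag_at(0) = 2.0f;  // element (0,0), imaginary part
  return 0;
}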
-+template < -+ typename ThreadblockShape_, -+ typename WarpMmaOperator_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpiloguePlanarComplex< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ PartitionsK, -+ OutputOp_, -+ ElementsPerAccess> { -+ -+ using RealEpilogue = DefaultEpilogueTensorOp< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ PartitionsK, -+ OutputOp_, -+ ElementsPerAccess -+ >; -+ -+ using Epilogue = EpiloguePlanarComplex< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ PartitionsK, -+ typename RealEpilogue::OutputTileIterator, -+ typename RealEpilogue::AccumulatorFragmentIterator, -+ typename RealEpilogue::WarpTileIterator, -+ typename RealEpilogue::SharedLoadIterator, -+ OutputOp_, -+ typename RealEpilogue::Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues. -+template < -+ typename ThreadblockShape_, -+ typename WarpMmaOperator_, -+ typename ArchTag_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpiloguePlanarComplex< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ arch::OpClassSimt, -+ ArchTag_, -+ PartitionsK, -+ OutputOp_, -+ ElementsPerAccess> { -+ -+ using RealEpilogue = DefaultEpilogueSimt< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ OutputOp_, -+ ElementsPerAccess -+ >; -+ -+ using Epilogue = EpiloguePlanarComplex< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ PartitionsK, -+ typename RealEpilogue::OutputTileIterator, -+ typename RealEpilogue::AccumulatorFragmentIterator, -+ typename RealEpilogue::WarpTileIterator, -+ typename RealEpilogue::SharedLoadIterator, -+ OutputOp_, -+ typename RealEpilogue::Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_simt.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_simt.h -new file mode 100644 -index 0000000..3214d19 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_simt.h -@@ -0,0 +1,422 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using SIMT. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/arch/mma.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_clamp.h" -+#include "cutlass/epilogue/thread/linear_combination_relu.h" -+#include "cutlass/epilogue/thread/linear_combination_gelu.h" -+#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" -+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -+#include "cutlass/epilogue/thread/conversion_op.h" -+#include "cutlass/epilogue/thread/reduction_op.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" -+ -+#include "cutlass/epilogue/warp/fragment_iterator_simt.h" -+#include "cutlass/epilogue/warp/tile_iterator_simt.h" -+#include "cutlass/epilogue/threadblock/default_thread_map_simt.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+ -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h" -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/threadblock/epilogue_depthwise.h" -+ -+#include "cutlass/layout/permute.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for SimtOps. 
-+template < -+ typename Shape_, -+ typename WarpMmaSimt_, -+ typename OutputOp_, -+ int ElementsPerAccess, -+ bool ScatterD = false, -+ typename PermuteDLayout = layout::NoPermute -+> -+struct DefaultEpilogueSimt { -+ -+ using Shape = Shape_; -+ using WarpMmaSimt = WarpMmaSimt_; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static const int kPartitionsK = Shape::kK / WarpMmaSimt::Shape::kK; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaSimt::LayoutC; -+ using ElementAccumulator = typename WarpMmaSimt::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapSimt< -+ Shape, -+ typename WarpMmaSimt::Shape, -+ typename WarpMmaSimt::Policy, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ OutputTileThreadMap, -+ ElementOutput, -+ ScatterD, -+ PermuteDLayout -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt< -+ typename WarpMmaSimt::Shape, -+ typename WarpMmaSimt::ThreadMma, -+ layout::RowMajor, -+ typename WarpMmaSimt::Policy -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimt< -+ typename WarpMmaSimt::Shape, -+ typename WarpMmaSimt::ThreadMma, -+ ElementAccumulator, -+ layout::RowMajor, -+ typename WarpMmaSimt::Policy -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ typename OutputTileThreadMap::CompactedThreadMap, -+ ElementAccumulator -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = typename WarpTileIterator::Padding; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaSimt, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for SimtOps. 
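A quick note on the kPartitionsK member appearing in the SIMT defaults above: it is the ratio Shape::kK / WarpMmaSimt::Shape::kK, i.e. how many warp-level K partitions produce partial accumulators that the epilogue has to combine. The toy shapes below only mirror the kM/kN/kK members of cutlass::gemm::GemmShape to make that arithmetic concrete; the tile sizes are illustrative assumptions, not values from this patch.

#include <cstdio>

// Toy shape type mirroring the kM/kN/kK static members of cutlass::gemm::GemmShape.
template <int M, int N, int K>
struct ToyGemmShape {
  static constexpr int kM = M, kN = N, kK = K;
};

// Illustrative tile sizes (assumptions for this sketch only).
using ThreadblockShape = ToyGemmShape<128, 128, 16>;
using WarpShape        = ToyGemmShape<32, 64, 8>;

// Mirrors: static const int kPartitionsK = Shape::kK / WarpMmaSimt::Shape::kK;
constexpr int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;  // == 2

int main() {
  std::printf("kPartitionsK = %d\n", kPartitionsK);
  return 0;
}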
-+template < -+ typename Shape_, -+ typename WarpMmaSimt_, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpilogueSimtStridedDgrad { -+ -+ using Shape = Shape_; -+ using WarpMmaSimt = WarpMmaSimt_; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static const int kPartitionsK = Shape::kK / WarpMmaSimt::Shape::kK; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaSimt::LayoutC; -+ using ElementAccumulator = typename WarpMmaSimt::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapSimt< -+ Shape, -+ typename WarpMmaSimt::Shape, -+ typename WarpMmaSimt::Policy, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad< -+ OutputTileThreadMap, -+ ElementOutput -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt< -+ typename WarpMmaSimt::Shape, -+ typename WarpMmaSimt::ThreadMma, -+ layout::RowMajor, -+ typename WarpMmaSimt::Policy -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimt< -+ typename WarpMmaSimt::Shape, -+ typename WarpMmaSimt::ThreadMma, -+ ElementAccumulator, -+ layout::RowMajor, -+ typename WarpMmaSimt::Policy -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ typename OutputTileThreadMap::CompactedThreadMap, -+ ElementAccumulator -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = typename WarpTileIterator::Padding; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaSimt, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for SimtOps. 
-+template < -+ int Rank, -+ typename Shape_, -+ typename WarpMmaSimt_, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpilogueSimtAffineRankN { -+ -+ using Shape = Shape_; -+ using WarpMmaSimt = WarpMmaSimt_; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static const int kPartitionsK = Shape::kK / WarpMmaSimt::Shape::kK; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaSimt::LayoutC; -+ using ElementAccumulator = typename WarpMmaSimt::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapSimt< -+ Shape, -+ typename WarpMmaSimt::Shape, -+ typename WarpMmaSimt::Policy, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankN< -+ OutputTileThreadMap, -+ ElementOutput, -+ Rank -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt< -+ typename WarpMmaSimt::Shape, -+ typename WarpMmaSimt::ThreadMma, -+ layout::RowMajor, -+ typename WarpMmaSimt::Policy -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimt< -+ typename WarpMmaSimt::Shape, -+ typename WarpMmaSimt::ThreadMma, -+ ElementAccumulator, -+ layout::RowMajor, -+ typename WarpMmaSimt::Policy -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ typename OutputTileThreadMap::CompactedThreadMap, -+ ElementAccumulator -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = typename WarpTileIterator::Padding; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaSimt, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for SimtOps. 
-+template , -+ typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1> > -+struct DefaultDirectConvEpilogueSimt { -+ using Shape = Shape_; -+ using WarpMmaSimt = WarpMmaSimt_; -+ using WarpShape = typename WarpMmaSimt::Shape; -+ using OutputOp = OutputOp_; -+ using ThreadOutputShape = ThreadOutputShape_; -+ using ThreadBlockOutputShape = ThreadBlockOutputShape_; -+ static int const kElementsPerAccess = ElementsPerAccess_; -+ -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaSimt::LayoutC; -+ using ElementAccumulator = typename WarpMmaSimt::ElementC; -+ -+ /// Number of threads total -+ using WarpCount = gemm::GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN -+ >; -+ -+ static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; -+ -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorDirectConv< -+ OutputTileThreadMap, -+ ElementOutput, -+ ThreadOutputShape, -+ ThreadBlockOutputShape -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt< -+ typename WarpMmaSimt::Shape, -+ typename WarpMmaSimt::ThreadMma, -+ layout::RowMajor, -+ typename WarpMmaSimt::Policy -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimtDirect2dConv< -+ typename WarpMmaSimt::Shape, -+ ThreadOutputShape, -+ ThreadBlockOutputShape, -+ typename WarpMmaSimt::ThreadMma, -+ ElementAccumulator, -+ layout::RowMajor, -+ typename WarpMmaSimt::Policy -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorPitchLiner< -+ OutputTileThreadMap, -+ ElementAccumulator -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = typename WarpTileIterator::Padding; -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::EpilogueDepthwise< -+ Shape, -+ ThreadOutputShape, -+ ThreadBlockOutputShape, -+ WarpMmaSimt, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h -new file mode 100644 -index 0000000..77411f3 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h -@@ -0,0 +1,808 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_clamp.h" -+#include "cutlass/epilogue/thread/linear_combination_relu.h" -+#include "cutlass/epilogue/thread/linear_combination_relu0.h" -+#include "cutlass/epilogue/thread/linear_combination_gelu.h" -+#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" -+#include "cutlass/epilogue/thread/linear_combination_hardswish.h" -+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -+ -+#include "cutlass/epilogue/thread/conversion_op.h" -+#include "cutlass/epilogue/thread/reduction_op.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" -+ -+#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" -+#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" -+#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/threadblock/interleaved_epilogue.h" -+ -+#include "cutlass/layout/permute.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace 
threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template < -+ typename ElementOutput, -+ typename ElementAccumulator, -+ int ElementsPerAccess, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename ThreadMap -+> -+struct DefaultIteratorsTensorOp { -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape, -+ InstructionShape, -+ ElementAccumulator, -+ layout::RowMajor -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ ThreadMap, -+ ElementAccumulator -+ >; -+ -+ static int const kFragmentsPerIteration = 1; -+}; -+ -+/// Partial specialization for float <= float x 4 -+template < -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename ThreadMap -+> -+struct DefaultIteratorsTensorOp { -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape, -+ InstructionShape, -+ float, -+ layout::RowMajor -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ ThreadMap, -+ float -+ >; -+ -+ static int const kFragmentsPerIteration = 2; -+}; -+ -+/// Partial specialization for int32_t <= int32_t x 4 -+template < -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename ThreadMap -+> -+struct DefaultIteratorsTensorOp { -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape, -+ InstructionShape, -+ int32_t, -+ layout::RowMajor -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ ThreadMap, -+ int32_t -+ >; -+ -+ static int const kFragmentsPerIteration = 1; -+}; -+ -+/// Partial specialization for float <= int32_t x 4 -+template < -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename ThreadMap -+> -+struct DefaultIteratorsTensorOp { -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape, -+ InstructionShape, -+ int32_t, -+ layout::RowMajor -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ ThreadMap, -+ int32_t -+ >; -+ -+ static int const kFragmentsPerIteration = 1; -+}; -+ -+/// Partial specialization for half <= float x 8 epilogues avoids shared memory bank conflicts. -+template < -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename ThreadMap -+> -+struct DefaultIteratorsTensorOp< -+ half_t, -+ float, -+ 8, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ ThreadMap> { -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed< -+ WarpShape, -+ InstructionShape, -+ float, -+ 32, -+ 16, -+ 8, -+ 8 -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed< -+ ThreadMap, -+ float, -+ 32, -+ 16, -+ 8, -+ 8 -+ >; -+ -+ static int const kFragmentsPerIteration = 2; -+}; -+ -+/// Partial specialization for int8/int4b_t <= int32 x 16/8 epilogues avoids shared memory bank conflicts. -+/// Threadblock::kN = 256 still has bank conflicts. 
-+template < -+ typename ElementOutput, -+ int ElementsPerAccess, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename ThreadMap -+> -+struct DefaultIteratorsTensorOp< -+ ElementOutput, -+ int32_t, -+ ElementsPerAccess, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ ThreadMap> { -+ -+ static_assert(platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value, -+ "ElementOutput needs to be 4 or 8 bit (unsigned) int."); -+ -+ static_assert((ElementsPerAccess == 16 || ElementsPerAccess == 8), -+ "ElementsPerAccess needs to be 16 or 8."); -+ -+ using WarpTileIteratorMixed = cutlass::epilogue::warp::TileIteratorTensorOpMixed< -+ WarpShape, -+ InstructionShape, -+ int32_t, -+ 32, -+ cutlass::sizeof_bits::value, -+ ElementsPerAccess, -+ 8 -+ >; -+ -+ using WarpTileIteratorNotMixed = cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape, -+ InstructionShape, -+ int32_t, -+ layout::RowMajor -+ >; -+ -+ using WarpTileIterator = typename platform::conditional< -+ (ThreadblockShape::kN == 256), -+ WarpTileIteratorNotMixed, -+ WarpTileIteratorMixed>::type; -+ -+ using SharedLoadIteratorMixed = cutlass::epilogue::threadblock::SharedLoadIteratorMixed< -+ ThreadMap, -+ int32_t, -+ 32, -+ cutlass::sizeof_bits::value, -+ ElementsPerAccess, -+ 8 -+ >; -+ -+ using SharedLoadIteratorNotMixed = cutlass::epilogue::threadblock::SharedLoadIterator< -+ ThreadMap, -+ int32_t -+ >; -+ -+ using SharedLoadIterator = typename platform::conditional< -+ (ThreadblockShape::kN == 256), -+ SharedLoadIteratorNotMixed, -+ SharedLoadIteratorMixed>::type; -+ -+ static int const kFragmentsPerIteration = 1; -+}; -+ -+/// Partial specialization for float_e4m3_t <= float x 16/8 epilogues avoids shared memory bank conflicts. -+/// Threadblock::kN = 256 still has bank conflicts. -+template < -+ int ElementsPerAccess, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename ThreadMap -+> -+struct DefaultIteratorsTensorOp< -+ cutlass::float_e4m3_t, -+ float, -+ ElementsPerAccess, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ ThreadMap> { -+ -+ using ElementOutput = cutlass::float_e4m3_t; -+ -+ static_assert((ElementsPerAccess == 16 || ElementsPerAccess == 8), -+ "ElementsPerAccess needs to be 16 or 8."); -+ -+ using WarpTileIteratorMixed = cutlass::epilogue::warp::TileIteratorTensorOpMixed< -+ WarpShape, -+ InstructionShape, -+ float, -+ 32, -+ cutlass::sizeof_bits::value, -+ ElementsPerAccess, -+ 8 -+ >; -+ -+ using WarpTileIteratorNotMixed = cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape, -+ InstructionShape, -+ float, -+ layout::RowMajor -+ >; -+ -+ using WarpTileIterator = typename platform::conditional< -+ (ThreadblockShape::kN == 256), -+ WarpTileIteratorNotMixed, -+ WarpTileIteratorMixed>::type; -+ -+ using SharedLoadIteratorMixed = cutlass::epilogue::threadblock::SharedLoadIteratorMixed< -+ ThreadMap, -+ float, -+ 32, -+ cutlass::sizeof_bits::value, -+ ElementsPerAccess, -+ 8 -+ >; -+ -+ using SharedLoadIteratorNotMixed = cutlass::epilogue::threadblock::SharedLoadIterator< -+ ThreadMap, -+ float -+ >; -+ -+ using SharedLoadIterator = typename platform::conditional< -+ (ThreadblockShape::kN == 256), -+ SharedLoadIteratorNotMixed, -+ SharedLoadIteratorMixed>::type; -+ -+ static int const kFragmentsPerIteration = 1; -+}; -+ -+/// Partial specialization for float_e5m2_t <= float x 16/8 epilogues avoids shared memory bank conflicts. 
-+/// Threadblock::kN = 256 still has bank conflicts. -+template < -+ int ElementsPerAccess, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename ThreadMap -+> -+struct DefaultIteratorsTensorOp< -+ cutlass::float_e5m2_t, -+ float, -+ ElementsPerAccess, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ ThreadMap> { -+ -+ using ElementOutput = cutlass::float_e5m2_t; -+ -+ static_assert((ElementsPerAccess == 16 || ElementsPerAccess == 8), -+ "ElementsPerAccess needs to be 16 or 8."); -+ -+ using WarpTileIteratorMixed = cutlass::epilogue::warp::TileIteratorTensorOpMixed< -+ WarpShape, -+ InstructionShape, -+ float, -+ 32, -+ cutlass::sizeof_bits::value, -+ ElementsPerAccess, -+ 8 -+ >; -+ -+ using WarpTileIteratorNotMixed = cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape, -+ InstructionShape, -+ float, -+ layout::RowMajor -+ >; -+ -+ using WarpTileIterator = typename platform::conditional< -+ (ThreadblockShape::kN == 256), -+ WarpTileIteratorNotMixed, -+ WarpTileIteratorMixed>::type; -+ -+ using SharedLoadIteratorMixed = cutlass::epilogue::threadblock::SharedLoadIteratorMixed< -+ ThreadMap, -+ float, -+ 32, -+ cutlass::sizeof_bits::value, -+ ElementsPerAccess, -+ 8 -+ >; -+ -+ using SharedLoadIteratorNotMixed = cutlass::epilogue::threadblock::SharedLoadIterator< -+ ThreadMap, -+ float -+ >; -+ -+ using SharedLoadIterator = typename platform::conditional< -+ (ThreadblockShape::kN == 256), -+ SharedLoadIteratorNotMixed, -+ SharedLoadIteratorMixed>::type; -+ -+ static int const kFragmentsPerIteration = 1; -+}; -+ -+} // namespace detail -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps. -+template < -+ typename Shape_, -+ typename WarpMmaTensorOp_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess, -+ bool ScatterD = false, -+ typename PermuteDLayout = layout::NoPermute -+> -+struct DefaultEpilogueTensorOp { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ static bool const UseCUDAStore = platform::is_same::value; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ OutputTileThreadMap, -+ ElementOutput, -+ ScatterD, -+ PermuteDLayout, -+ UseCUDAStore -+ >; -+ -+ using AccumulatorFragmentIterator = typename platform::conditional::value, -+ cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC>, -+ cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC> >::type; -+ -+ /// 
Support several implementations depending on structure of epilogue -+ using DefaultIterators = detail::DefaultIteratorsTensorOp< -+ ElementOutput, -+ ElementAccumulator, -+ kElementsPerAccess, -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename OutputTileThreadMap::CompactedThreadMap -+ >; -+ -+ using WarpTileIterator = typename DefaultIterators::WarpTileIterator; -+ using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator; -+ -+ /// Hard-coded padding elements added -+ using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits::value * 4>; -+ -+ static int const kFragmentsPerIteration = (kPartitionsK == 1 ? DefaultIterators::kFragmentsPerIteration : 1); -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding, -+ kFragmentsPerIteration -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps. -+template < -+ typename Shape_, -+ typename WarpMmaTensorOp_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpilogueTensorOpStridedDgrad { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad< -+ OutputTileThreadMap, -+ ElementOutput -+ >; -+ -+ using AccumulatorFragmentIterator = typename platform::conditional::value, -+ cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC>, -+ cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC> >::type; -+ -+ /// Support several implementations depending on structure of epilogue -+ using DefaultIterators = detail::DefaultIteratorsTensorOp< -+ ElementOutput, -+ ElementAccumulator, -+ kElementsPerAccess, -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename OutputTileThreadMap::CompactedThreadMap -+ >; -+ -+ using WarpTileIterator = typename DefaultIterators::WarpTileIterator; -+ using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator; -+ -+ /// Hard-coded padding elements added -+ using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits::value * 4>; -+ -+ static int const kFragmentsPerIteration = (kPartitionsK == 1 ? 
DefaultIterators::kFragmentsPerIteration : 1); -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding, -+ kFragmentsPerIteration -+ >; -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps. -+template < -+ int Rank, -+ typename Shape_, -+ typename WarpMmaTensorOp_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpilogueTensorOpAffineRankN { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankN< -+ OutputTileThreadMap, -+ ElementOutput, -+ Rank -+ >; -+ -+ // Map to the row major iterator since the iterator selection for affineN is the same. -+ using AccumulatorFragmentIterator = typename platform::conditional::value, -+ cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ layout::RowMajor>, -+ cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ layout::RowMajor> >::type; -+ -+ /// Support several implementations depending on structure of epilogue -+ using DefaultIterators = detail::DefaultIteratorsTensorOp< -+ ElementOutput, -+ ElementAccumulator, -+ kElementsPerAccess, -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename OutputTileThreadMap::CompactedThreadMap -+ >; -+ -+ using WarpTileIterator = typename DefaultIterators::WarpTileIterator; -+ using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator; -+ -+ /// Hard-coded padding elements added -+ using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits::value * 4>; -+ -+ static int const kFragmentsPerIteration = (kPartitionsK == 1 ? DefaultIterators::kFragmentsPerIteration : 1); -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding, -+ kFragmentsPerIteration -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Defines sensible defaults for epilogues for TensorOps which uses -+/// intereleaved output layout. For this case, shared memory is not needed. 
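The Default*EpilogueTensorOp traits in this header are rarely instantiated by hand; they are normally selected by the kernel-level GEMM defaults once tile shapes and an epilogue output operator have been chosen at the device API level. A hedged sketch of that typical entry point follows; the element types, tile shapes, and the choice of LinearCombinationRelu are illustrative assumptions, not values taken from this patch, and compiling it requires the CUTLASS headers under nvcc.

#include "cutlass/gemm/device/gemm.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"

// Epilogue output operator: D = max(0, alpha * accum + beta * C), vectorized 8-wide.
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu<
    cutlass::half_t,  // element type written to global memory
    8,                // elements per vectorized access (ElementsPerAccess)
    float,            // accumulator element type
    float>;           // element type used to compute alpha/beta scaling

// Device-level GEMM; its kernel-level defaults pick a threadblock epilogue
// (for OpClassTensorOp, a DefaultEpilogueTensorOp instantiation) to match
// these shapes and the output operator above.
using Gemm = cutlass::gemm::device::Gemm<
    cutlass::half_t, cutlass::layout::RowMajor,     // A
    cutlass::half_t, cutlass::layout::ColumnMajor,  // B
    cutlass::half_t, cutlass::layout::RowMajor,     // C and D
    float,                                          // internal accumulation type
    cutlass::arch::OpClassTensorOp,                 // use Tensor Cores
    cutlass::arch::Sm80,                            // target architecture
    cutlass::gemm::GemmShape<128, 128, 32>,         // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,           // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,            // MMA instruction shape
    EpilogueOp>;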
-+template -+struct DefaultInterleavedEpilogueTensorOp { -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock:: -+ DefaultInterleavedThreadMapTensorOp< -+ Shape, typename WarpMmaTensorOp::Shape, kPartitionsK, ElementOutput, -+ kElementsPerAccess, InterleavedK>::Type; -+ -+ using OutputTileIterator = -+ cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator< -+ OutputTileThreadMap, ElementOutput, InterleavedK>; -+ -+ using AccumulatorFragmentIterator = -+ cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC>; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::InterleavedEpilogue< -+ Shape, WarpMmaTensorOp, kPartitionsK, OutputTileIterator, -+ AccumulatorFragmentIterator, OutputOp, InterleavedK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps which uses -+/// intereleaved output layout. For this case, shared memory is not needed. -+template -+struct DefaultInterleavedConvEpilogue { -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock:: -+ DefaultInterleavedConvThreadMapTensorOp< -+ Shape, typename WarpMmaTensorOp::Shape, kPartitionsK, ElementOutput, -+ kElementsPerAccess, InterleavedK>::Type; -+ -+ using OutputTileIterator = -+ cutlass::epilogue::threadblock::InterleavedConvPredicatedTileIterator< -+ OutputTileThreadMap, ElementOutput, InterleavedK>; -+ -+ using AccumulatorFragmentIterator = -+ cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ // can reuse the gemm version here to do element selection -+ layout::ColumnMajorInterleaved>; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::InterleavedEpilogue< -+ Shape, WarpMmaTensorOp, kPartitionsK, OutputTileIterator, -+ AccumulatorFragmentIterator, OutputOp, InterleavedK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h 
b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h -new file mode 100644 -index 0000000..aef4961 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h -@@ -0,0 +1,175 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. 
-+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_clamp.h" -+#include "cutlass/epilogue/thread/linear_combination_relu.h" -+#include "cutlass/epilogue/thread/linear_combination_gelu.h" -+#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" -+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -+ -+#include "cutlass/epilogue/thread/conversion_op.h" -+#include "cutlass/epilogue/thread/reduction_op.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" -+ -+#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" -+#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" -+#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/threadblock/interleaved_epilogue.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps. 
-+template < -+ typename Shape_, -+ typename WarpMmaTensorOp_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess, -+ /// Is for a symmetric kernel -+ BlasMode BlasMode_ = BlasMode::kGemm -+> -+struct DefaultEpilogueTensorOpBlas3 { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static BlasMode const kBlasMode = BlasMode_; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorBlas3< -+ OutputTileThreadMap, -+ ElementOutput, -+ kBlasMode -+ >; -+ -+ using AccumulatorFragmentIterator = typename std::conditional::value, -+ cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC>, -+ cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC> >::type; -+ -+ /// Support several implementations depending on structure of epilogue -+ using DefaultIterators = detail::DefaultIteratorsTensorOp< -+ ElementOutput, -+ ElementAccumulator, -+ kElementsPerAccess, -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename OutputTileThreadMap::CompactedThreadMap -+ >; -+ -+ using WarpTileIterator = typename DefaultIterators::WarpTileIterator; -+ using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator; -+ -+ /// Hard-coded padding elements added -+ using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits::value * 4>; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h -new file mode 100644 -index 0000000..9936f96 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h -@@ -0,0 +1,337 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops on Volta. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. 
-+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_clamp.h" -+#include "cutlass/epilogue/thread/linear_combination_relu.h" -+#include "cutlass/epilogue/thread/linear_combination_gelu.h" -+#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" -+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -+ -+#include "cutlass/epilogue/thread/conversion_op.h" -+#include "cutlass/epilogue/thread/reduction_op.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator.h" -+ -+#include "cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+ -+#include "cutlass/layout/permute.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps. -+template < -+ typename Shape_, -+ typename WarpMmaTensorOp_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess, -+ bool ScatterD = false, -+ typename PermuteDLayout = layout::NoPermute -+> -+struct DefaultEpilogueVoltaTensorOp { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ OutputTileThreadMap, -+ ElementOutput, -+ ScatterD, -+ PermuteDLayout -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ gemm::GemmShape<32, 32, 4>, -+ ElementAccumulator, -+ LayoutC -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorVoltaTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ gemm::GemmShape<32, 32, 4>, -+ ElementAccumulator, -+ LayoutC -+ >; -+ -+ static int const kSharedMemAlignment = sizeof_bits::value * WarpTileIterator::kElementsPerAccess / 8; -+ -+ static_assert(kSharedMemAlignment == 8, "Shared memory alignment must be 8B"); -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ typename OutputTileThreadMap::CompactedThreadMap, -+ ElementAccumulator, -+ 
kSharedMemAlignment -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = typename WarpTileIterator::Padding; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps. -+template < -+ typename Shape_, -+ typename WarpMmaTensorOp_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpilogueVoltaTensorOpStridedDgrad { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad< -+ OutputTileThreadMap, -+ ElementOutput -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ gemm::GemmShape<32, 32, 4>, -+ ElementAccumulator, -+ LayoutC -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorVoltaTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ gemm::GemmShape<32, 32, 4>, -+ ElementAccumulator, -+ LayoutC -+ >; -+ -+ static int const kSharedMemAlignment = sizeof_bits::value * WarpTileIterator::kElementsPerAccess / 8; -+ -+ static_assert(kSharedMemAlignment == 8, "Shared memory alignment must be 8B"); -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ typename OutputTileThreadMap::CompactedThreadMap, -+ ElementAccumulator, -+ kSharedMemAlignment -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = typename WarpTileIterator::Padding; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps. 
-+template < -+ int Rank, -+ typename Shape_, -+ typename WarpMmaTensorOp_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpilogueVoltaTensorOpAffineRankN { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankN< -+ OutputTileThreadMap, -+ ElementOutput, -+ Rank -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ gemm::GemmShape<32, 32, 4>, -+ ElementAccumulator, -+ LayoutC -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorVoltaTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ gemm::GemmShape<32, 32, 4>, -+ ElementAccumulator, -+ LayoutC -+ >; -+ -+ static int const kSharedMemAlignment = sizeof_bits::value * WarpTileIterator::kElementsPerAccess / 8; -+ -+ static_assert(kSharedMemAlignment == 8, "Shared memory alignment must be 8B"); -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ typename OutputTileThreadMap::CompactedThreadMap, -+ ElementAccumulator, -+ kSharedMemAlignment -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = typename WarpTileIterator::Padding; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h -new file mode 100644 -index 0000000..381cb30 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h -@@ -0,0 +1,183 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" -+ -+#include "cutlass/layout/permute.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps. 
-+template < -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename ElementOutput, -+ typename ElementTensor, -+ typename ElementVector, -+ typename OutputOp, -+ int ElementsPerAccess, -+ bool ScatterD = false, -+ typename PermuteDLayout = layout::NoPermute -+> -+struct DefaultEpilogueWithBroadcastTensorOp { -+ -+ /// Use defaults related to the existing epilogue -+ using Base = DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputOp, -+ ElementsPerAccess -+ >; -+ -+ // -+ // Stores the result z = (y = GEMM(A, B, C), broadcast) -+ // -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename Base::OutputTileThreadMap, -+ ElementOutput, -+ ScatterD, -+ PermuteDLayout -+ >; -+ -+ // -+ // Additional tensor tile iterator - stores t = Elementwise(z) -+ // -+ using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename Base::OutputTileThreadMap, -+ ElementTensor -+ >; -+ -+ /// Define the epilogue -+ using Epilogue = EpilogueWithBroadcast< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputTileIterator, -+ TensorTileIterator, -+ ElementVector, -+ typename Base::AccumulatorFragmentIterator, -+ typename Base::WarpTileIterator, -+ typename Base::SharedLoadIterator, -+ OutputOp, -+ typename Base::Padding, -+ Base::kFragmentsPerIteration -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for VoltaTensorOps. -+template < -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename ElementOutput, -+ typename ElementTensor, -+ typename ElementVector, -+ typename OutputOp, -+ int ElementsPerAccess -+> -+struct DefaultEpilogueWithBroadcastVoltaTensorOp { -+ -+ /// Use defaults related to the existing epilogue -+ using Base = DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputOp, -+ ElementsPerAccess -+ >; -+ -+ // -+ // Stores the result z = (y = GEMM(A, B, C), broadcast) -+ // -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename Base::OutputTileThreadMap, -+ ElementOutput -+ >; -+ -+ // -+ // Additional tensor tile iterator - stores t = Elementwise(z) -+ // -+ using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename Base::OutputTileThreadMap, -+ ElementTensor -+ >; -+ -+ /// Define the epilogue -+ using Epilogue = EpilogueWithBroadcast< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputTileIterator, -+ TensorTileIterator, -+ ElementVector, -+ typename Base::AccumulatorFragmentIterator, -+ typename Base::WarpTileIterator, -+ typename Base::SharedLoadIterator, -+ OutputOp, -+ typename Base::Padding -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_reduction.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_reduction.h -new file mode 100644 -index 0000000..3c85551 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_reduction.h -@@ -0,0 +1,177 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 
2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/threadblock/epilogue_with_reduction.h" -+ -+#include "cutlass/layout/permute.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps. 
-+template < -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename ElementOutput, -+ typename OutputOp, -+ typename ReductionOp, -+ int ElementsPerAccess, -+ bool ScatterD = false, -+ typename PermuteDLayout = layout::NoPermute -+> -+struct DefaultEpilogueWithReductionTensorOp { -+ -+ /// Use defaults related to the existing epilogue -+ using Base = DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputOp, -+ ElementsPerAccess -+ >; -+ -+ /// Additional tensor tile iterator -+ using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename Base::OutputTileThreadMap, -+ typename OutputOp::ElementTensor -+ >; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename Base::OutputTileThreadMap, -+ ElementOutput, -+ ScatterD, -+ PermuteDLayout -+ >; -+ -+ /// Define the epilogue -+ using Epilogue = EpilogueWithReduction< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputTileIterator, -+ TensorTileIterator, -+ typename WarpMmaTensorOp::ElementC, -+ typename Base::AccumulatorFragmentIterator, -+ typename Base::WarpTileIterator, -+ typename Base::SharedLoadIterator, -+ typename Base::OutputOp, -+ ReductionOp, -+ typename Base::Padding -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps. -+template < -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename ElementOutput, -+ typename OutputOp, -+ typename ReductionOp, -+ int ElementsPerAccess, -+ bool ScatterD = false, -+ typename PermuteDLayout = layout::NoPermute -+> -+struct DefaultEpilogueWithReductionVoltaTensorOp { -+ -+ /// Use defaults related to the existing epilogue -+ using Base = DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputOp, -+ ElementsPerAccess -+ >; -+ -+ /// Additional tensor tile iterator -+ using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename Base::OutputTileThreadMap, -+ typename OutputOp::ElementTensor -+ >; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename Base::OutputTileThreadMap, -+ ElementOutput, -+ ScatterD, -+ PermuteDLayout -+ >; -+ -+ /// Define the epilogue -+ using Epilogue = EpilogueWithReduction< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputTileIterator, -+ TensorTileIterator, -+ typename WarpMmaTensorOp::ElementC, -+ typename Base::AccumulatorFragmentIterator, -+ typename Base::WarpTileIterator, -+ typename Base::SharedLoadIterator, -+ typename Base::OutputOp, -+ ReductionOp, -+ typename Base::Padding -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h -new file mode 100644 -index 0000000..f95e4ea ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h -@@ -0,0 +1,165 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using WMMA. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_clamp.h" -+#include "cutlass/epilogue/thread/linear_combination_relu.h" -+#include "cutlass/epilogue/thread/linear_combination_gelu.h" -+#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" -+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -+ -+#include "cutlass/epilogue/thread/conversion_op.h" -+#include "cutlass/epilogue/thread/reduction_op.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" -+ -+#include "cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+ -+#include "cutlass/layout/permute.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for WMMA TensorOps. 
-+template < -+ typename Shape_, -+ typename WarpMmaTensorOp_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess, -+ bool ScatterD = false, -+ typename PermuteDLayout = layout::NoPermute -+> -+struct DefaultEpilogueWmmaTensorOp { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapWmmaTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ OutputTileThreadMap, -+ ElementOutput, -+ ScatterD, -+ PermuteDLayout -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorWmmaTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorWmmaTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ typename OutputTileThreadMap::CompactedThreadMap, -+ ElementAccumulator -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = typename WarpTileIterator::Padding; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_simt.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_simt.h -new file mode 100644 -index 0000000..363d1e5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_simt.h -@@ -0,0 +1,127 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+ -+*/ -+ -+#pragma once -+ -+#include "predicated_tile_iterator.h" -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the optimal thread map for SIMT accumulator layouts -+template < -+ typename ThreadblockShape_, -+ typename WarpShape_, -+ typename MmaSimtPolicy_, -+ int PartitionsK, -+ typename Element_, -+ int ElementsPerAccess -+> -+struct DefaultThreadMapSimt { -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using MmaSimtPolicy = MmaSimtPolicy_; -+ static int const kPartitionsK = PartitionsK; -+ using Element = Element_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ // -+ // Definitions -+ // -+ -+ struct Detail { -+ -+ static int const kWarpSize = 32; -+ -+ static_assert( -+ !(ThreadblockShape::kM % WarpShape::kM) && -+ !(ThreadblockShape::kN % WarpShape::kN), "Divisibility"); -+ -+ /// Number of warps -+ using WarpCount = gemm::GemmShape< -+ ThreadblockShape::kM / WarpShape::kM, -+ ThreadblockShape::kN / WarpShape::kN, -+ kPartitionsK -+ >; -+ -+ /// Computes number of thread-level matrix multiplies are needed to span a warp -+ static int const kGroupCount = -+ WarpShape::kM / (MmaSimtPolicy::WarpShape::kRow * MmaSimtPolicy::LaneMmaShape::kM); -+ -+ /// Number of participating threads -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Number of iterations -+ static int const kIterations = MmaSimtPolicy::LaneMmaShape::kM * kGroupCount; -+ }; -+ -+ // -+ // ThreadMap -+ // -+ -+ /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap -+ using Type = OutputTileOptimalThreadMap< -+ OutputTileShape< // Shape -+ ThreadblockShape::kN, -+ 1, -+ MmaSimtPolicy::WarpShape::kRow, -+ Detail::WarpCount::kM, -+ 1>, -+ OutputTileShape< // Count -+ 1, -+ MmaSimtPolicy::LaneMmaShape::kM, -+ Detail::kGroupCount, -+ 1, -+ Detail::kIterations>, -+ Detail::kThreads, -+ 
kElementsPerAccess, -+ sizeof_bits<Element>::value -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h -new file mode 100644 -index 0000000..14972d0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h -@@ -0,0 +1,208 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*!
\file -+ \brief -+ -+*/ -+ -+#pragma once -+ -+#include "predicated_tile_iterator.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the optimal thread map for TensorOp accumulator layouts -+template < -+ typename ThreadblockShape_, -+ typename WarpShape_, -+ int PartitionsK, -+ typename Element_, -+ int ElementsPerAccess -+> -+struct DefaultThreadMapTensorOp { -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ static int const kPartitionsK = PartitionsK; -+ using Element = Element_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ // -+ // Definitions -+ // -+ -+ struct Detail { -+ -+ /// Tensor Operations fundamentally perform operations on 8 rows -+ static int const kTensorOpRows = 8; -+ static int const kWarpSize = 32; -+ -+ static_assert( -+ !(ThreadblockShape::kM % WarpShape::kM) && -+ !(ThreadblockShape::kN % WarpShape::kN), "Divisibility"); -+ -+ /// Number of warps -+ using WarpCount = gemm::GemmShape< -+ ThreadblockShape::kM / WarpShape::kM, -+ ThreadblockShape::kN / WarpShape::kN, -+ kPartitionsK -+ >; -+ -+ /// Number of participating threads -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ }; -+ -+ // -+ // ThreadMap -+ // -+ -+ /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap -+ using Type = OutputTileOptimalThreadMap < -+ OutputTileShape, -+ OutputTileShape<1, WarpShape::kM / Detail::kTensorOpRows, 1, 1, WarpShape::kM / Detail::kTensorOpRows>, -+ Detail::kThreads, -+ kElementsPerAccess, -+ sizeof_bits::value -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the optimal thread map for TensorOp accumulator layouts -+template -+struct DefaultInterleavedThreadMapTensorOp { -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ static int const kPartitionsK = PartitionsK; -+ using Element = Element_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kInterleavedK = InterleavedK; -+ -+ // -+ // Definitions -+ // -+ -+ struct Detail { -+ /// Tensor Operations fundamentally perform operations on 8 rows -+ static int const kTensorOpRows = 8; -+ static int const kWarpSize = 32; -+ -+ static_assert(!(ThreadblockShape::kM % WarpShape::kM) && -+ !(ThreadblockShape::kN % WarpShape::kN), -+ "Divisibility"); -+ -+ /// Number of warps -+ using WarpCount = -+ gemm::GemmShape; -+ -+ /// Number of participating threads -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ }; -+ -+ // -+ // ThreadMap -+ // -+ -+ /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept -+ /// InterleavedOutputTileThreadMap -+ using Type = InterleavedOutputTileThreadMap< -+ layout::PitchLinearShape, -+ layout::PitchLinearShape, -+ Detail::kThreads, kElementsPerAccess, sizeof_bits::value>; -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the optimal thread map for TensorOp accumulator layouts -+template -+struct DefaultInterleavedConvThreadMapTensorOp { -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ static int const kPartitionsK = PartitionsK; -+ using Element = 
Element_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kInterleavedK = InterleavedK; -+ -+ // -+ // Definitions -+ // -+ -+ struct Detail { -+ /// Tensor Operations fundamentally perform operations on 8 rows -+ static int const kTensorOpRows = 8; -+ static int const kWarpSize = 32; -+ -+ static_assert(!(ThreadblockShape::kM % WarpShape::kM) && -+ !(ThreadblockShape::kN % WarpShape::kN), -+ "Divisibility"); -+ -+ /// Number of warps -+ using WarpCount = -+ gemm::GemmShape; -+ -+ /// Number of participating threads -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ }; -+ -+ // -+ // ThreadMap -+ // -+ -+ /// ThreadMap to be used by epilogue::MaskedTileIterator satisfying concept -+ /// InterleavedOutputTileThreadMap -+ using Type = InterleavedConvOutputTileThreadMap< -+ MatrixShape, -+ MatrixShape, -+ Detail::kThreads, kElementsPerAccess, sizeof_bits::value>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h -new file mode 100644 -index 0000000..1c0edb1 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h -@@ -0,0 +1,228 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief -+ -+*/ -+ -+#pragma once -+ -+#include "predicated_tile_iterator.h" -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the optimal thread map for TensorOp accumulator layouts -+template < -+ typename ThreadblockShape, -+ typename WarpShape, -+ int PartitionsK, -+ typename ElementOutput, -+ int ElementsPerAccess, -+ typename ElementAccumulator -+> -+struct DefaultThreadMapVoltaTensorOp; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the optimal thread map for TensorOp accumulator layouts -+template < -+ typename ThreadblockShape_, -+ typename WarpShape_, -+ int PartitionsK, -+ typename ElementOutput_, -+ int ElementsPerAccess -+> -+struct DefaultThreadMapVoltaTensorOp< -+ ThreadblockShape_, -+ WarpShape_, -+ PartitionsK, -+ ElementOutput_, -+ ElementsPerAccess, -+ half_t> { -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ static int const kPartitionsK = PartitionsK; -+ using ElementOutput = ElementOutput_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ using ElementAccumulator = half_t; -+ -+ // -+ // Definitions -+ // -+ -+ struct Detail { -+ -+ static int const kTensorOpRows = 16; -+ static int const kWarpSize = 32; -+ static int const kInterleavedTilesM = WarpShape::kM / 32; -+ -+ static_assert( -+ !(ThreadblockShape::kM % WarpShape::kM) && -+ !(ThreadblockShape::kN % WarpShape::kN), "Divisibility"); -+ -+ /// Number of warps -+ using WarpCount = gemm::GemmShape< -+ ThreadblockShape::kM / WarpShape::kM, -+ ThreadblockShape::kN / WarpShape::kN, -+ kPartitionsK -+ >; -+ -+ /// Number of participating threads -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape< -+ ThreadblockShape::kN, // column -+ 4, // row -+ 4, // group -+ WarpCount::kM, // cluster -+ 1 // tile -+ >; -+ -+ /// Number of iterations per subspace -+ using Count = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 2, // row -+ kInterleavedTilesM, // group -+ 1, // cluster -+ WarpShape::kM / kTensorOpRows // iterations -+ >; -+ }; -+ -+ // -+ // ThreadMap -+ // -+ -+ /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap -+ using Type = OutputTileOptimalThreadMap < -+ typename Detail::Shape, -+ typename Detail::Count, -+ Detail::kThreads, -+ kElementsPerAccess, -+ sizeof_bits::value -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the optimal thread map for TensorOp accumulator layouts -+template < -+ typename ThreadblockShape_, -+ typename WarpShape_, -+ int PartitionsK, -+ typename ElementOutput_, -+ int ElementsPerAccess -+> -+struct DefaultThreadMapVoltaTensorOp< -+ ThreadblockShape_, -+ WarpShape_, -+ PartitionsK, -+ ElementOutput_, -+ ElementsPerAccess, -+ float> { -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ static int const kPartitionsK = PartitionsK; -+ using ElementOutput = ElementOutput_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ using ElementAccumulator = float; -+ -+ // -+ // Definitions -+ // -+ -+ struct Detail { -+ -+ 
static int const kTensorOpRows = 16; -+ static int const kWarpSize = 32; -+ static int const kInterleavedTilesM = WarpShape::kM / 32; -+ -+ static_assert( -+ !(ThreadblockShape::kM % WarpShape::kM) && -+ !(ThreadblockShape::kN % WarpShape::kN), "Divisibility"); -+ -+ /// Number of warps -+ using WarpCount = gemm::GemmShape< -+ ThreadblockShape::kM / WarpShape::kM, -+ ThreadblockShape::kN / WarpShape::kN, -+ kPartitionsK -+ >; -+ -+ /// Number of participating threads -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape< -+ ThreadblockShape::kN, // column -+ 4, // row -+ 4, // group -+ WarpCount::kM, // cluster -+ 1 // tile -+ >; -+ -+ /// Number of iterations per subspace -+ using Count = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 2, // row -+ kInterleavedTilesM, // group -+ 1, // cluster -+ WarpShape::kM / kTensorOpRows // iterations -+ >; -+ }; -+ -+ // -+ // ThreadMap -+ // -+ -+ /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap -+ using Type = OutputTileOptimalThreadMap < -+ typename Detail::Shape, -+ typename Detail::Count, -+ Detail::kThreads, -+ kElementsPerAccess, -+ sizeof_bits::value -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h -new file mode 100644 -index 0000000..929762b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h -@@ -0,0 +1,113 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+ -+*/ -+ -+#pragma once -+ -+#include "predicated_tile_iterator.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the optimal thread map for Wmma TensorOp accumulator layouts -+template < -+ typename ThreadblockShape_, -+ typename WarpShape_, -+ typename InstructionShape_, -+ int PartitionsK, -+ typename Element_, -+ int ElementsPerAccess -+> -+struct DefaultThreadMapWmmaTensorOp { -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ static int const kPartitionsK = PartitionsK; -+ using Element = Element_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ // -+ // Definitions -+ // -+ -+ struct Detail { -+ -+ /// Wmma Tensor Operations fundamentally perform operations on InstructionShape::kM rows -+ static int const kTensorOpRows = InstructionShape::kM; -+ static int const kWarpSize = 32; -+ -+ static_assert( -+ !(ThreadblockShape::kM % WarpShape::kM) && -+ !(ThreadblockShape::kN % WarpShape::kN), "Divisibility"); -+ -+ /// Number of warps -+ using WarpCount = gemm::GemmShape< -+ ThreadblockShape::kM / WarpShape::kM, -+ ThreadblockShape::kN / WarpShape::kN, -+ kPartitionsK -+ >; -+ -+ /// Number of participating threads -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ }; -+ -+ // -+ // ThreadMap -+ // -+ -+ /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap -+ using Type = OutputTileOptimalThreadMap < -+ OutputTileShape, -+ OutputTileShape<1, WarpShape::kM / Detail::kTensorOpRows, 1, 1, WarpShape::kM / Detail::kTensorOpRows>, -+ Detail::kThreads, -+ kElementsPerAccess, -+ sizeof_bits::value -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/direct_store_epilogue_iterator.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/direct_store_epilogue_iterator.h -new file mode 100644 -index 0000000..afacca2 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/direct_store_epilogue_iterator.h -@@ -0,0 +1,142 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. 
-+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template <typename Element_> -+class DirectStoreEpilogueIterator { -+public: -+ -+ using Element = Element_; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef<Element, Layout>; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static int const kElementsPerAccess = 1; -+ -+ /// Uses a non-template class -+ struct Params : PredicatedTileIteratorParams { -+ using Base = PredicatedTileIteratorParams; -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) { -+ stride = layout.stride(0) * sizeof(Element); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Base const &base) : -+ Base(base) { } -+ }; -+ -+public: -+ -+ // -+ // Data members -+ // -+ -+ Element *pointer; // pointer to the output matrix -+ -+ LongIndex stride; // stride in elements between rows -+ -+ TensorCoord extent; // extent of output matrix -+ -+ int thread_idx; // thread index -+ -+ TensorCoord threadblock_offset; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ DirectStoreEpilogueIterator( -+ PredicatedTileIteratorParams const & params, -+ Element *pointer_, -+ TensorCoord extent_, -+ int thread_idx_, -+ TensorCoord threadblock_offset_ = TensorCoord(), -+ int const * indices = nullptr -+ ): -+ pointer(pointer_), -+ stride(params.stride / sizeof(Element)), -+ extent(extent_), -+ thread_idx(thread_idx_), -+ threadblock_offset(threadblock_offset_) -+ { -+ -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue.h -new file mode 100644 -index 0000000..7672a59 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue.h -@@ -0,0 +1,535 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2.
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+ The shared memory resource is time-sliced across warps. -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_base.h" -+#include "cutlass/epilogue/threadblock/epilogue_base_streamk.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator -+template < -+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) -+ typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ int PartitionsK, ///< Number of partitions of the K dimension -+ typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors -+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators -+ typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM -+ typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM -+ typename OutputOp_, ///< Output operator -+ typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) -+ int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity -+ int IterationsUnroll = ///< Used 
to reduce binary size when epilogue op is large -+ (!IsEpilogueFunctorHeavy::value) -+> -+class Epilogue : -+ public EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_, -+ FragmentsPerPartition>, -+ public EpilogueBaseStreamK< -+ Shape_, -+ PartitionsK, -+ WarpMmaOperator_, -+ AccumulatorFragmentIterator_> -+{ -+ -+public: -+ -+ using Base = EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_, -+ FragmentsPerPartition>; -+ -+ using BaseStreamK = EpilogueBaseStreamK< -+ Shape_, -+ PartitionsK, -+ WarpMmaOperator_, -+ AccumulatorFragmentIterator_>; -+ -+ using Shape = Shape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputTileIterator = OutputTileIterator_; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using WarpTileIterator = WarpTileIterator_; -+ using SharedLoadIterator = SharedLoadIterator_; -+ using OutputOp = OutputOp_; -+ using Padding = Padding_; -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// Number of warps per block -+ using WarpCount = typename Base::WarpCount; -+ -+ /// Number of threads per block -+ static int const kBlockThreads = 32 * WarpCount::kCount; -+ -+ /// Per-thread accumulator tile type -+ using AccumulatorTile = typename Base::AccumulatorTile; -+ -+ /// Numerical accumulation element type -+ using ElementAccumulator = typename WarpMmaOperator::ElementC; -+ -+ /// Fragment type used by the accumulator tile's fragment iterator -+ using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment; -+ -+ /// Output element -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ /// Tensor reference to destination tensor -+ using TensorRef = typename OutputTileIterator::TensorRef; -+ -+ /// Tensor reference to sync tensor -+ using SyncTensorRef = typename cutlass::TensorRef; -+ -+ /// Const tensor reference to source tensor -+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; -+ -+ /// Vector type used by the global output iterator -+ using OutputAccessType = Array< -+ typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; -+ -+ /// Vector type used by the shared output iterator -+ using AccumulatorAccessType = Array; -+ -+ static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? 
Base::kFragmentsPerIteration : kPartitionsK; -+ -+ static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles; -+ -+ -+public: -+ -+ static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, -+ "Mismatch between shared load iterator and output tile iterator."); -+ -+ static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); -+ -+ static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), -+ "Divisibility"); -+ -+ static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1."); -+ -+ -+public: -+ -+ /// Aspect for when epilogue source is not needed -+ struct SourceAspectNotNeeded -+ { -+ /// Constructor -+ CUTLASS_DEVICE -+ SourceAspectNotNeeded() -+ {} -+ -+ /// Invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator( -+ typename OutputTileIterator::Fragment &output_fragment, -+ OutputOp const &output_op, -+ typename SharedLoadIterator::Fragment const &aligned_accum_fragment) -+ { -+ OutputAccessType *output_frag_ptr = -+ reinterpret_cast(&output_fragment); -+ -+ AccumulatorAccessType const *compute_frag_ptr = -+ reinterpret_cast(&aligned_accum_fragment); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) -+ { -+ // Call the output operator -+ output_frag_ptr[i] = output_op(compute_frag_ptr[i]); -+ } -+ } -+ }; -+ -+ -+ /// Aspect for when epilogue source is needed -+ struct SourceAspectNeeded -+ { -+ OutputTileIterator source_iterator; -+ -+ typename OutputTileIterator::Fragment source_fragment; -+ -+ /// Invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ static void apply_output_operator( -+ typename OutputTileIterator::Fragment &output_fragment, -+ OutputOp const &output_op, -+ typename SharedLoadIterator::Fragment const &aligned_accum_fragment, -+ typename OutputTileIterator::Fragment const &source_fragment) -+ { -+ OutputAccessType *output_frag_ptr = -+ reinterpret_cast(&output_fragment); -+ -+ AccumulatorAccessType const *compute_frag_ptr = -+ reinterpret_cast(&aligned_accum_fragment); -+ -+ OutputAccessType const *source_frag_ptr = -+ reinterpret_cast(&source_fragment); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) -+ { -+ // Call the output operator -+ output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]); -+ } -+ } -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ SourceAspectNeeded(OutputTileIterator source_iterator) : -+ source_iterator(source_iterator) -+ { -+ source_fragment.clear(); -+ } -+ -+ /// Invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator( -+ typename OutputTileIterator::Fragment &output_fragment, -+ OutputOp const &output_op, -+ typename SharedLoadIterator::Fragment const &aligned_accum_fragment) -+ { -+ // Load addend source fragment from global memory -+ source_iterator.load(source_fragment); -+ ++source_iterator; -+ -+ apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment); -+ } -+ }; -+ -+ -+private: -+ -+ /// Loads fragment from shared memory aligned with output tensor -+ SharedLoadIterator 
shared_load_iterator_; -+ -+ /// Thread index in the threadblock -+ int thread_idx; -+ -+ /// Warp index in the threadblock -+ int warp_idx; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ Epilogue( -+ typename Base::SharedStorage &shared_storage, ///< Shared storage object -+ int thread_idx, ///< ID of a thread within the threadblock -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx) ///< Id of thread within warp -+ : -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ BaseStreamK(thread_idx), -+ shared_load_iterator_(shared_storage.reference(), thread_idx), -+ thread_idx(thread_idx), -+ warp_idx(warp_idx) -+ {} -+ -+ -+ /// Aggregates the accumulator sets shared by peer blocks in the global workspace, -+ /// performing epilogue computations, writing to output -+ CUTLASS_DEVICE -+ void reduce( -+ int peer_idx_begin, -+ int peer_idx_end, -+ int reduce_fragment_idx, -+ void *element_workspace, -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ OutputTileIterator source_iterator) ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ { -+ // Redcuce peer accumulator fragments into one fragment -+ AccumulatorFragment accum_fragment; -+ BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace); -+ -+ // Store fragment to shared memory -+ this->warp_tile_iterator_.store(accum_fragment); -+ -+ __syncthreads(); -+ -+ // Initialize/load source-fragment data -+ typename OutputTileIterator::Fragment source_fragment; -+ source_fragment.clear(); -+ -+ if (output_op.is_source_needed()) -+ { -+ source_iterator += reduce_fragment_idx; -+ source_iterator.load(source_fragment); -+ } -+ -+ // Load fragment from shared memory -+ typename SharedLoadIterator::Fragment aligned_accum_fragment; -+ shared_load_iterator_.load(aligned_accum_fragment); -+ -+ // Add fragments shared by other k partitions -+ if (kPartitionsK > 1) -+ { -+ plus add_fragments; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( int i = 1; i < kPartitionsK; ++i) { -+ typename SharedLoadIterator::Fragment aligned_addend_fragment; -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); -+ shared_load_iterator_.load(aligned_addend_fragment); -+ aligned_accum_fragment = add_fragments(aligned_accum_fragment, aligned_addend_fragment); -+ } -+ } -+ -+ // Compute the output result -+ typename OutputTileIterator::Fragment output_fragment; -+ -+ // Apply the output operator -+ SourceAspectNeeded::apply_output_operator( -+ output_fragment, -+ output_op, -+ aligned_accum_fragment, -+ source_fragment); -+ -+ // Store the final result -+ destination_iterator += reduce_fragment_idx; -+ destination_iterator.store(output_fragment); -+ } -+ -+ -+ /// Perform the epilogue computations and stream the result to global memory. -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators) ///< Complete warp-level accumulator tile -+ { -+ operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded()); -+ } -+ -+ -+ /// Perform the epilogue computations and stream the result to global memory. Implements -+ /// two alternative codepaths, depending on whether the output op requires addend data to be loaded. 
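The reduce() path above makes the split-K folding explicit: after the shared-memory round trip, the fragment for partition 0 is loaded, then the remaining kPartitionsK - 1 fragments are loaded at successive kSmemPointerOffset strides and combined with an element-wise plus. A host-side sketch of just that accumulation loop, with an assumed partition count, an assumed fragment width, and a flat array standing in for the shared-memory tiles:

#include <array>
#include <cstdio>

constexpr int kPartitionsK   = 4;   // illustrative split-K factor
constexpr int kFragmentElems = 8;   // illustrative fragment size
using Fragment = std::array<float, kFragmentElems>;

// Element-wise plus, standing in for the plus<Fragment> functor used above.
Fragment add_fragments(const Fragment &a, const Fragment &b) {
  Fragment c;
  for (int i = 0; i < kFragmentElems; ++i) c[i] = a[i] + b[i];
  return c;
}

int main() {
  // One fragment per K partition, laid out back to back as if in shared memory.
  Fragment smem[kPartitionsK];
  for (int p = 0; p < kPartitionsK; ++p) smem[p].fill(float(p + 1));

  // Load the partition-0 fragment, then add the remaining partitions,
  // mirroring the add_pointer_offset(kSmemPointerOffset) loop in reduce().
  Fragment aligned_accum_fragment = smem[0];
  for (int p = 1; p < kPartitionsK; ++p)
    aligned_accum_fragment = add_fragments(aligned_accum_fragment, smem[p]);

  // Every element now holds 1 + 2 + 3 + 4 = 10.
  std::printf("%g\n", aligned_accum_fragment[0]);
  return 0;
}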
-+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator ) ///< Tile iterator for addend source -+ { -+ if (output_op.is_source_needed()) -+ { -+ operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator)); -+ } -+ else -+ { -+ operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded()); -+ } -+ } -+ -+ -+ /// Perform the epilogue computations and stream the result to global memory. Implements a -+ /// single codepath, regardless of whether the output op requires addend data to be loaded -+ CUTLASS_DEVICE -+ void unified( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator ) ///< Tile iterator for addend source -+ { -+ if (!output_op.is_source_needed()) -+ { -+ source_iterator.clear_mask(); -+ __syncthreads(); // Dummy (CUDA 11.0) -+ } -+ -+ operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator)); -+ } -+ -+ -+ /// Streams the result to global memory -+ template -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ SourceAspect source) -+ { -+ // Iterator over warp-level accumulator fragment -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+ #pragma unroll(IterationsUnroll ? 
OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1) -+ for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) -+ { -+ -+ // -+ // Convert and store fragment -+ // -+ -+ __syncthreads(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) -+ { -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ -+ accum_fragment_iterator.load(accum_fragment); -+ ++accum_fragment_iterator; -+ -+ this->warp_tile_iterator_.store(accum_fragment); -+ -+ if (p < Base::kFragmentsPerIteration - 1) { -+ this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset); -+ } -+ } -+ -+ if (Base::kFragmentsPerIteration > 1) { -+ this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); -+ } -+ -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ __syncthreads(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) -+ { -+ typename SharedLoadIterator::Fragment aligned_accum_fragment; -+ shared_load_iterator_.load(aligned_accum_fragment); -+ -+ if (p < Base::kFragmentsPerIteration - 1) -+ { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); -+ } -+ else if (kPartitionsK > 1) -+ { -+ plus add_fragments; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( int i = 1; i < kPartitionsK; ++i) { -+ typename SharedLoadIterator::Fragment aligned_accum_fragment_addend; -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); -+ shared_load_iterator_.load(aligned_accum_fragment_addend); -+ aligned_accum_fragment = add_fragments(aligned_accum_fragment, aligned_accum_fragment_addend); -+ } -+ -+ shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); -+ } -+ -+ // -+ // Compute the output result -+ // -+ -+ typename OutputTileIterator::Fragment output_fragment; -+ source.apply_output_operator(output_fragment, output_op, aligned_accum_fragment); -+ -+ // -+ // Store the final result -+ // -+ -+ destination_iterator.store(output_fragment); -+ ++destination_iterator; -+ } -+ -+ if (Base::kFragmentsPerIteration > 1) { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); -+ } -+ } -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_base.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_base.h -new file mode 100644 -index 0000000..cad06bb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_base.h -@@ -0,0 +1,240 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#if !defined(__CUDACC_RTC__) -+#include -+#include -+#endif -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/aligned_buffer.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+// -+// This is used for metaprogramming epilogue functors. If they define -+// `static bool const kIsHeavy = true;`, then the epilogue functor itself is -+// not inlined. This results in smaller code and is advantageous if the epilogue -+// functor consists of many instructions. -+// -+// If the epilogue functor does not define `kIsHeavy` or if it is `false`, then -+// the behavior from CUTLASS 2.5 and before is retained. The epilogue is fully -+// unrolled and inlined. 
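The comment above is backed by a small detection trait defined immediately below it. The following is a self-contained restatement of the idiom, assuming the partial specialization keys on the presence of T::kIsHeavy; the two functor types are invented for the example:

#include <cstdio>

// Restatement of the detection idiom for illustration.
template <typename T> struct TypeSink { typedef void type; };
template <typename T> using TypeSinkT = typename TypeSink<T>::type;

// Primary template: functors that say nothing are treated as "light".
template <typename T, typename = void>
struct IsEpilogueFunctorHeavy { static bool const value = false; };

// Specialization selected only when T defines kIsHeavy.
template <typename T>
struct IsEpilogueFunctorHeavy<T, TypeSinkT<decltype(T::kIsHeavy)>> {
  static bool const value = T::kIsHeavy;
};

struct LightOp { };                                     // no kIsHeavy -> false
struct HeavyOp { static bool const kIsHeavy = true; };  // opts out of full unrolling

int main() {
  static_assert(!IsEpilogueFunctorHeavy<LightOp>::value, "LightOp stays fully inlined");
  static_assert(IsEpilogueFunctorHeavy<HeavyOp>::value,  "HeavyOp is treated as heavy");
  std::printf("kIsHeavy detection works as expected\n");
  return 0;
}

An epilogue functor therefore opts out of full unrolling simply by declaring static bool const kIsHeavy = true;.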
-+// -+ -+template -+struct TypeSink { typedef void type; }; -+ -+template using TypeSinkT = typename TypeSink::type; -+ -+template struct IsEpilogueFunctorHeavy { -+ static bool const value = false; -+}; -+ -+template struct IsEpilogueFunctorHeavy > { -+ static bool const value = T::kIsHeavy; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Base class for epilogues defining warp-level -+template < -+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) -+ typename WarpShape_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ int PartitionsK, ///< Number of partitions of the K dimension -+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators -+ typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM -+ typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) -+ int FragmentsPerIteration = 1 -+> -+class EpilogueBase { -+public: -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ static int const kPartitionsK = PartitionsK; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using WarpTileIterator = WarpTileIterator_; -+ using Padding = Padding_; -+ -+ /// Output layout is always row-major -+ using Layout = layout::RowMajor; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; -+ -+ /// Accumulator element -+ using ElementAccumulator = typename AccumulatorTile::Element; -+ -+ /// Number of warps -+ using WarpCount = gemm::GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ kPartitionsK -+ >; -+ -+ /// Use this to control the granularity of one epilogue 'iteration' -+ static int const kFragmentsPerIteration = FragmentsPerIteration; -+ -+public: -+ -+ /// Shared storage allocation needed by the epilogue -+ struct SharedStorage { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element type of shared memory -+ using Element = typename WarpTileIterator::Element; -+ -+ /// Tensor reference to shared memory allocation -+ using TensorRef = typename WarpTileIterator::TensorRef; -+ -+ /// Layout of shared memory allocation -+ using Layout = typename WarpTileIterator::Layout; -+ -+ /// Logical shape of the shared memory tile written to by all warps. 
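WarpCount above is the threadblock tile divided by the warp tile in M and N, times the K partitioning, and the EpilogueBase constructor further down in this file turns a linear warp_idx back into those (m, n, k) coordinates. A worked example with assumed 128x128 threadblock and 64x64 warp tiles:

#include <cstdio>

// Illustrative tile sizes; the real values come from Shape_ and WarpShape_.
constexpr int kShapeM = 128, kShapeN = 128;
constexpr int kWarpM  = 64,  kWarpN  = 64;
constexpr int kPartitionsK = 1;

constexpr int kWarpCountM   = kShapeM / kWarpM;                         // 2
constexpr int kWarpCountN   = kShapeN / kWarpN;                         // 2
constexpr int kWarpCount    = kWarpCountM * kWarpCountN * kPartitionsK; // 4
constexpr int kBlockThreads = 32 * kWarpCount;                          // 128

int main() {
  std::printf("warps per block: %d, threads per block: %d\n", kWarpCount, kBlockThreads);

  // Decomposition used by the EpilogueBase constructor: warp_idx -> (m, n, k).
  for (int warp_idx = 0; warp_idx < kWarpCount; ++warp_idx) {
    int warp_k  = warp_idx / (kWarpCountM * kWarpCountN);
    int warp_mn = warp_idx % (kWarpCountM * kWarpCountN);
    int warp_m  = warp_mn % kWarpCountM;
    int warp_n  = warp_mn / kWarpCountM;
    std::printf("warp %d -> (m=%d, n=%d, k=%d)\n", warp_idx, warp_m, warp_n, warp_k);
  }
  return 0;
}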
-+ using Shape = MatrixShape< -+ WarpCount::kM * WarpTileIterator::Shape::kRow * WarpCount::kK, -+ WarpCount::kN * WarpTileIterator::Shape::kColumn -+ >; -+ -+ /// Shape of the shared memory allocation for the epilogue -+ using StorageShape = MatrixShape< -+ (Shape::kRow + Padding::kRow) * kFragmentsPerIteration, -+ Shape::kColumn + Padding::kColumn -+ >; -+ -+ // -+ // Data members -+ // -+ -+ AlignedBuffer storage; -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a pointer to the shared memory buffer -+ CUTLASS_DEVICE -+ Element *data() { -+ return storage.data(); -+ } -+ -+ /// Returns a tensor reference to the shared memory buffer -+ CUTLASS_DEVICE -+ TensorRef reference() { -+ return TensorRef( -+ storage.data(), -+ Layout::packed({StorageShape::kRow, StorageShape::kColumn})); -+ } -+ }; -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ SharedStorage &shared_storage_; -+ -+ /// Stores a warp's fragment of accumulators to SMEM -+ WarpTileIterator warp_tile_iterator_; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueBase( -+ SharedStorage &shared_storage, ///< Shared storage object -+ int thread_idx, ///< ID of a thread within the threadblock -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx ///< Id of thread within warp -+ ): -+ shared_storage_(shared_storage), -+ warp_tile_iterator_(shared_storage.reference(), lane_idx) { -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to three coordinates: -+ // -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN); -+ int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN); -+ int warp_m = warp_mn % WarpCount::kM; -+ int warp_n = warp_mn / WarpCount::kM; -+ -+ MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n}; -+ -+ warp_tile_iterator_.add_tile_offset(warp_offset); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_base_streamk.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_base_streamk.h -new file mode 100644 -index 0000000..2be1aeb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_base_streamk.h -@@ -0,0 +1,197 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Basic subset of epilogue functionality for supporting StreamK decompositions -+*/ -+ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+#include "cutlass/block_striped.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// StreamK epilogue functionality for cross-block accumulator fragment reduction -+template < -+ typename Shape, ///< Shape of threadblock tile (concept: GemmShape) -+ int PartitionsK, -+ typename WarpMmaOperator, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ typename AccumulatorFragmentIterator> ///< Iterator for enumerating fragments within the per-thread tile of raw accumulators -+class EpilogueBaseStreamK -+{ -+ -+protected: -+ -+ /// The per-thread tile of raw accumulators -+ using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; -+ -+ /// Number of warps -+ using WarpCount = gemm::GemmShape< -+ Shape::kM / WarpMmaOperator::Shape::kM, -+ Shape::kN / WarpMmaOperator::Shape::kN, -+ PartitionsK>; -+ -+ /// Number of threads per block -+ static int const kBlockThreads = 32 * WarpCount::kCount; -+ -+ /// Numerical accumulation element type -+ using ElementAccumulator = typename WarpMmaOperator::ElementC; -+ -+ /// Fragment type used by the accumulator tile's fragment iterator -+ using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment; -+ -+public: -+ -+ /// Number of AccumulatorTile fragments per thread -+ static int const kAccumulatorFragments = AccumulatorFragmentIterator::Policy::kIterations; -+ -+protected: -+ -+ /// Number of AccumulatorTile fragments per block output tile -+ static int const kOutputTileFragments = kBlockThreads * kAccumulatorFragments; -+ -+ /// Block-striped transfer utility for sharing AccumulatorFragment -+ using BlockStripedT = BlockStriped; -+ -+ /// AccumulatorFragment stride in the shared workspace between different peer blocks (each thread block can share accumulators for up to two block output tiles) -+ static const int kPeerFragmentStride = kOutputTileFragments * 2; -+ -+public: -+ -+ /// Workspace bytes per thread block -+ static size_t const kWorkspaceBytesPerBlock =sizeof(AccumulatorFragment) * kPeerFragmentStride; -+ 
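The workspace accounting above composes three constants: each thread contributes kAccumulatorFragments fragments, one block output tile therefore occupies kBlockThreads * kAccumulatorFragments fragments, and every peer block reserves space for two such tiles (its started tile and a possible non-started one), which fixes kWorkspaceBytesPerBlock. A quick check with assumed example numbers (128 threads, 8 fragments per thread, 256-byte fragments):

#include <cstddef>
#include <cstdio>

int main() {
  // Illustrative numbers; the real ones are derived from the iterator types.
  int kBlockThreads         = 128;   // 32 * WarpCount::kCount
  int kAccumulatorFragments = 8;     // fragments per thread
  std::size_t fragment_bytes = 256;  // sizeof(AccumulatorFragment)

  int kOutputTileFragments = kBlockThreads * kAccumulatorFragments;  // per output tile
  int kPeerFragmentStride  = kOutputTileFragments * 2;               // two tiles per peer
  std::size_t kWorkspaceBytesPerBlock = fragment_bytes * kPeerFragmentStride;

  std::printf("fragments per output tile: %d\n", kOutputTileFragments);        // 1024
  std::printf("fragment stride per peer:  %d\n", kPeerFragmentStride);         // 2048
  std::printf("workspace bytes per block: %zu\n", kWorkspaceBytesPerBlock);    // 524288
  return 0;
}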
-+public: -+ -+ /// Thread index in the threadblock -+ int thread_idx; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueBaseStreamK( -+ int thread_idx) ///< ID of a thread within the threadblock -+ : -+ thread_idx(thread_idx) -+ {} -+ -+ -+ /// Aggregates the accumulator sets shared by peer blocks in the global workspace -+ CUTLASS_DEVICE -+ void reduce( -+ AccumulatorFragment &accum_fragment, ///< [out] sum of all shared accumulator fragments for these peer partials -+ int peer_idx_begin, -+ int peer_idx_end, -+ int reduce_fragment_idx, -+ void *workspace_ptr) -+ { -+ plus add_fragments; -+ -+ AccumulatorFragment *fragment_workspace = reinterpret_cast(workspace_ptr); -+ -+ int fragment_offset = (peer_idx_begin * kPeerFragmentStride) + (reduce_fragment_idx * kBlockThreads); -+ -+ // Load first peer fragment -+ BlockStripedT::load(accum_fragment, fragment_workspace + fragment_offset, this->thread_idx); -+ -+ fragment_offset += kPeerFragmentStride; // Move to next peer -+ fragment_offset += kOutputTileFragments; // Move to the set of fragments for this peer's "non-started" output tile -+ -+ // Reduce fragments from additional peers -+ #pragma unroll 2 -+ for (; fragment_offset < peer_idx_end * kPeerFragmentStride; fragment_offset += kPeerFragmentStride) -+ { -+ // Load peer fragment -+ AccumulatorFragment addend_fragment; -+ BlockStripedT::load(addend_fragment, fragment_workspace + fragment_offset, this->thread_idx); -+ -+ // Add peer fragment -+ accum_fragment = add_fragments(accum_fragment, addend_fragment); -+ } -+ } -+ -+ -+ /// Shares the accumulator set with peers in the global workspace -+ CUTLASS_DEVICE -+ void share( -+ int peer_idx, -+ void *workspace_ptr, -+ AccumulatorTile const &accumulators, -+ bool started_tile) ///< Whether this thread block computed the first work volume for the current output tile -+ { -+ AccumulatorFragment *fragment_workspace = reinterpret_cast(workspace_ptr); -+ -+ int fragment_offset = peer_idx * kPeerFragmentStride; -+ -+ if (!started_tile) { -+ // Move to the set of fragments for the "non-started" output tile -+ fragment_offset += kOutputTileFragments; -+ } -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // Convert raw accumulator tile to fragments and store -+ CUTLASS_PRAGMA_UNROLL -+ for (int iter = 0; iter < kAccumulatorFragments; ++iter) -+ { -+ // Acquire reordered accumulator fragment -+ AccumulatorFragment accum_fragment; -+ accum_fragment_iterator.load(accum_fragment); -+ ++accum_fragment_iterator; -+ -+ // Store accumulator fragment -+ BlockStripedT::store(fragment_workspace + fragment_offset, accum_fragment, this->thread_idx); -+ -+ fragment_offset += kBlockThreads; -+ } -+ } -+ -+}; -+ -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_depthwise.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_depthwise.h -new file mode 100644 -index 0000000..d5a52ea ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_depthwise.h -@@ -0,0 +1,335 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for Depthwise convoltuion -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. 
-+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/thread/conversion_op.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/reduction_op.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator -+template -+class EpilogueDepthwise { -+ public: -+ using Shape = Shape_; -+ using WarpShape = typename WarpMmaOperator_::Shape; -+ using ThreadOutputShape = ThreadOutputShape_; -+ using ThreadBlockOutputShape = ThreadBlockOutputShape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ using OutputTileIterator = OutputTileIterator_; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using WarpTileIterator = WarpTileIterator_; -+ using SharedLoadIterator = SharedLoadIterator_; -+ using OutputOp = OutputOp_; -+ using Padding = Padding_; -+ -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; -+ -+ /// Accumulator element -+ using ElementAccumulator = typename WarpTileIterator::Element; -+ -+ /// Output element -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ /// Tensor reference to destination tensor -+ using TensorRef = typename OutputTileIterator::TensorRef; -+ -+ /// Tensor reference to sync tensor -+ using SyncTensorRef = typename cutlass::TensorRef; -+ -+ /// Const tensor reference to source tensor -+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; -+ -+ /// Array type used to output -+ using OutputAccessType = -+ Array; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = -+ Array; -+ -+ /// Number of warps -+ using WarpCount = -+ gemm::GemmShape; -+ -+ public: -+ static_assert(SharedLoadIterator::Fragment::kElements == -+ OutputTileIterator::Fragment::kElements, -+ "Mismatch between shared load iterator and output tile iterator."); -+ -+ static_assert(OutputTileIterator::kElementsPerAccess, -+ "OutputTileIterator::kElementsPerAccess must not be zero."); -+ -+ static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), -+ "Divisibility"); -+ -+ /// Shared storage allocation needed by the epilogue -+ struct SharedStorage { -+ // -+ // Type definitions -+ // -+ -+ /// Element type of shared memory -+ using Element = typename WarpTileIterator::Element; -+ -+ /// Tensor reference to shared memory allocation -+ using TensorRef = typename WarpTileIterator::TensorRef; -+ -+ /// Layout of shared memory allocation -+ using Layout = typename WarpTileIterator::Layout; -+ -+ /// Logical shape of the shared memory tile written to by all warps. 
-+ using Shape = MatrixShape; -+ -+ /// Shape of the shared memory allocation for the epilogue -+ using StorageShape = MatrixShape; -+ -+ // -+ // Data members -+ // -+ -+ AlignedBuffer storage; -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a pointer to the shared memory buffer -+ CUTLASS_DEVICE -+ Element *data() { return storage.data(); } -+ -+ /// Returns a tensor reference to the shared memory buffer -+ CUTLASS_DEVICE -+ TensorRef reference() { -+ return TensorRef(storage.data(), Layout::packed({StorageShape::kRow, StorageShape::kColumn})); -+ } -+ }; -+ -+ private: -+ /// Loads fragment from shared memory aligned with output tensor -+ SharedLoadIterator shared_load_iterator_; -+ -+ /// Stores a warp's fragment of accumulators to SMEM -+ WarpTileIterator warp_tile_iterator_; -+ -+ LongIndex warp_offset; -+ int thread_idx; -+ int warp_idx; -+ int lane_idx; -+ int warp_m, warp_n; // warp coordinates within a cta -+ int tid_m, tid_n; // thread coordinates within a warp -+ -+ public: -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueDepthwise(SharedStorage &shared_storage, ///< Shared storage object -+ int thread_idx_, ///< ID of a thread within the threadblock -+ int warp_idx_, ///< ID of warp within threadblock -+ int lane_idx_ ///< Id of thread within warp -+ ) -+ : thread_idx(thread_idx_), -+ warp_idx(warp_idx_), -+ lane_idx(lane_idx_), -+ shared_load_iterator_(shared_storage.reference(), thread_idx_), -+ warp_tile_iterator_(shared_storage.reference(), thread_idx_, lane_idx_) {} -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()(OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator, ///< Threadblock tile coordinate in GEMM (in -+ ///< units of threadblock tiles) -+ const int smem_base_offset) { ///< SMEM base offset for epilogue operation -+ // initiate the smem base offset for different output tile. 
-+ warp_tile_iterator_.set_smem_base_address(smem_base_offset); -+ -+ shared_load_iterator_.set_smem_base_address(smem_base_offset); -+ -+ if (!output_op.is_source_needed()) { -+ compute_source_not_needed_(output_op, destination_iterator, accumulators); -+ } else { -+ compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator); -+ } -+ } -+ -+ private: -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_needed_( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ -+ typename OutputTileIterator::Fragment source_fragment; -+ -+ source_fragment.clear(); -+ -+ source_iterator.load(source_fragment); -+ -+ // store to smem -+ warp_tile_iterator_.store(accumulators); -+ -+ __syncthreads(); -+ -+ typename SharedLoadIterator::Fragment aligned_accum_fragment; -+ -+ // load from smem -+ shared_load_iterator_.load(aligned_accum_fragment); -+ -+ typename OutputTileIterator::Fragment output_fragment; -+ -+ apply_output_operator_(output_fragment, output_op, aligned_accum_fragment, source_fragment); -+ -+ // Store to GMEM -+ destination_iterator.store(output_fragment); -+ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_not_needed_( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ -+ // store to smem -+ warp_tile_iterator_.store(accumulators); -+ -+ __syncthreads(); -+ -+ typename SharedLoadIterator::Fragment aligned_accum_fragment; -+ -+ // load from smem -+ shared_load_iterator_.load(aligned_accum_fragment); -+ -+ typename OutputTileIterator::Fragment output_fragment; -+ -+ apply_output_operator_source_not_needed_(output_fragment, output_op, aligned_accum_fragment); -+ -+ // Store to GMEM -+ destination_iterator.store(output_fragment); -+ } -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_( -+ typename OutputTileIterator::Fragment &output_fragment, -+ OutputOp const &output_op, ///< Output operator -+ typename SharedLoadIterator::Fragment const &aligned_accum_fragment, -+ typename OutputTileIterator::Fragment const &source_fragment) { -+ -+ OutputAccessType *output_frag_ptr = -+ reinterpret_cast(&output_fragment); -+ -+ AccumulatorAccessType const *compute_frag_ptr = -+ reinterpret_cast(&aligned_accum_fragment); -+ -+ OutputAccessType const *source_frag_ptr = -+ reinterpret_cast(&source_fragment); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ // Call the output operator -+ output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]); -+ } -+ } -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_source_not_needed_( -+ typename OutputTileIterator::Fragment &output_fragment, -+ OutputOp const &output_op, ///< Output operator -+ typename SharedLoadIterator::Fragment const &aligned_accum_fragment) { -+ OutputAccessType *output_frag_ptr = 
reinterpret_cast(&output_fragment); -+ -+ AccumulatorAccessType const *compute_frag_ptr = -+ reinterpret_cast(&aligned_accum_fragment); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ // Call the output operator -+ output_frag_ptr[i] = output_op(compute_frag_ptr[i]); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_direct_store.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_direct_store.h -new file mode 100644 -index 0000000..8cd4791 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_direct_store.h -@@ -0,0 +1,347 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs and convolution using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. 
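The apply_output_operator_ helpers above, like the SourceAspect helpers earlier in this diff, all follow one pattern: view the flat fragments as arrays of kElementsPerAccess-wide vectors and call the output functor once per vector. A host-side sketch with assumed fragment and access widths and a hand-written linear-combination functor in place of the real epilogue op (the reinterpret_casts mirror the ones in the helpers):

#include <array>
#include <cstdio>

constexpr int kFragmentElements  = 16;  // assumed Fragment::kElements
constexpr int kElementsPerAccess = 4;   // assumed vector width of one access

using AccessType = std::array<float, kElementsPerAccess>;

// Stand-in for an epilogue functor: d = alpha * accumulator + beta * source.
struct LinearCombination {
  float alpha = 1.0f, beta = 1.0f;
  AccessType operator()(const AccessType &acc, const AccessType &src) const {
    AccessType d;
    for (int i = 0; i < kElementsPerAccess; ++i) d[i] = alpha * acc[i] + beta * src[i];
    return d;
  }
};

int main() {
  float accum_fragment[kFragmentElements];
  float source_fragment[kFragmentElements];
  float output_fragment[kFragmentElements];
  for (int i = 0; i < kFragmentElements; ++i) { accum_fragment[i] = float(i); source_fragment[i] = 100.0f; }

  // Reinterpret the flat fragments as vectors and call the functor per vector,
  // mirroring the kOutputOpIterations loop in the epilogue helpers.
  auto *out = reinterpret_cast<AccessType *>(output_fragment);
  auto *acc = reinterpret_cast<const AccessType *>(accum_fragment);
  auto *src = reinterpret_cast<const AccessType *>(source_fragment);

  LinearCombination op{0.5f, 1.0f};
  constexpr int kOutputOpIterations = kFragmentElements / kElementsPerAccess;
  for (int i = 0; i < kOutputOpIterations; ++i) out[i] = op(acc[i], src[i]);

  for (int i = 0; i < kFragmentElements; ++i) std::printf("%g ", output_fragment[i]);
  std::printf("\n");
  return 0;
}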
-+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/conversion_op.h" -+#include "cutlass/epilogue/thread/reduction_op.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator -+template < -+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) -+ typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ int PartitionsK, ///< Number of partitions of the K dimension -+ typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors -+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators -+ typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM -+ typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM -+ typename OutputOp_ ///< Output operator -+> -+class EpilogueDirectStore { -+public: -+ -+ using Shape = Shape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ using WarpShape = typename WarpMmaOperator_::Shape; -+ static int const kPartitionsK = PartitionsK; -+ using OutputTileIterator = OutputTileIterator_; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using WarpTileIterator = WarpTileIterator_; -+ using OutputOp = OutputOp_; -+ using Padding = MatrixShape<0, 0>; -+ -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; -+ -+ /// Accumulator element -+ using ElementAccumulator = typename WarpTileIterator::Element; -+ -+ /// Output element -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ /// Tensor reference to destination tensor -+ using TensorRef = typename OutputTileIterator::TensorRef; -+ -+ /// Tensor reference to sync tensor -+ using SyncTensorRef = typename cutlass::TensorRef; -+ -+ /// Const tensor reference to source tensor -+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; -+ -+ /// Array type used to output -+ using OutputAccessType = Array< -+ typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = Array; -+ -+ /// Number of warps -+ using WarpCount = gemm::GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ kPartitionsK -+ >; -+ -+ /// Use this to control the granularity of one epilogue 'iteration' -+ static int const kFragmentsPerIteration = 1; -+ -+ static int constexpr kSmemTiles = 1; -+ static int constexpr kSmemPointerOffset = 0; -+ -+ /// Shared storage allocation needed by the epilogue -+ struct SharedStorage { } ; -+ -+private: -+ -+ // Assume accumulator tile is multipile interleaved 32x32 tile. 
-+ static int const kElementsPerPartial = 4; -+ using EleShapePerPatial = typename platform::conditional< -+ platform::is_same::value, -+ MatrixShape<2, 2>, -+ MatrixShape<1, 4> >::type; -+ static int const kElementsPerMma = 8; -+ static int const kAccumulatorPatials = 2; -+ using QuadShapePerPatialMma = MatrixShape<4, 4>; -+ -+ static_assert(OutputOp::kCount >= 2, -+ "The direct store epilogue for Tensor Ops requires the output functor have kCount >= 2."); -+ -+private: -+ -+ LongIndex warp_offset; -+ int thread_idx; -+ int warp_idx; -+ int lane_idx; -+ int warp_m, warp_n; // warp coordinates within a cta -+ int tid_m, tid_n; // thread coordinates within a warp -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueDirectStore( -+ SharedStorage &shared_storage, ///< Shared storage object -+ int thread_idx_, ///< ID of a thread within the threadblock -+ int warp_idx_, ///< ID of warp within threadblock -+ int lane_idx_ ///< Id of thread within warp -+ ): -+ thread_idx(thread_idx_), -+ warp_idx(warp_idx_), -+ lane_idx(lane_idx_) -+ { -+ -+ // warp offsetting calculations -+ warp_offset = warp_idx * WarpShape::kM * WarpShape::kN; -+ int warp_id_mn = warp_idx % (WarpCount::kM * WarpShape::kN); -+ warp_m = warp_id_mn % WarpCount::kM; -+ warp_n = warp_id_mn / WarpCount::kM; -+ MatrixCoord warp_offset_coord(warp_m*WarpShape::kM, warp_n*WarpShape::kN); -+ -+ // thread offsetting calculations -+ int quad = (lane_idx >> 2); -+ int lane_in_quad = (lane_idx & 3); -+ -+ // this seems to be te correct layout -+ tid_m = quad; -+ tid_n = 2 * lane_in_quad; -+ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ -+ if (!output_op.is_source_needed()) { -+ compute_source_not_needed_(output_op, destination_iterator, accumulators); -+ } -+ else { -+ compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator); -+ } -+ } -+ -+private: -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_needed_( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ -+ const int kAccumBlockN = 2; -+ const int kThreadsM = 8; -+ const int kThreadsN = 4; -+ const int kBlockM = WarpShape::kM / kThreadsM; -+ -+ /// Array type used to output -+ using OutputAccessType = AlignedArray; -+ -+ /// Array type passed to the output operator - unused elements are optimized away -+ using OutputFragmentType = Array; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = Array; -+ -+ /// Array type used by output functor -+ using AccumulatorFragmentType = Array; -+ -+ AccumulatorAccessType const *accumulator_pair = reinterpret_cast(&accumulators); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int accum_m_idx = 0; accum_m_idx < WarpShape::kM / kThreadsM; accum_m_idx++) { -+ -+ int accum_m = kThreadsM * accum_m_idx; -+ int mL = destination_iterator.threadblock_offset.row() + WarpShape::kM * warp_m + tid_m + accum_m; -+ int nL_base = 
destination_iterator.threadblock_offset.column() + WarpShape::kN * warp_n + tid_n; -+ -+ ElementOutput *output_ptr = destination_iterator.pointer + mL * destination_iterator.stride; -+ ElementOutput *source_ptr = source_iterator.pointer + mL * source_iterator.stride; -+ -+ int const kIterationsN = WarpShape::kN / kThreadsN / kAccumBlockN; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int accum_n_idx = 0; accum_n_idx < kIterationsN; accum_n_idx++) { -+ -+ int accum_idx = accum_m_idx + kBlockM * accum_n_idx; -+ int accum_n = kThreadsM * accum_n_idx; -+ -+ // mL and nL are logical coordinate in 2D mapping of epilogue's 4D output -+ int nL = nL_base + accum_n; -+ -+ bool guard = (mL < destination_iterator.extent.row()) && (nL < destination_iterator.extent.column()); -+ -+ AccumulatorFragmentType accum_fragment; -+ reinterpret_cast(accum_fragment) = accumulator_pair[accum_idx]; -+ -+ OutputFragmentType output_fragment; -+ -+ if(guard) { -+ reinterpret_cast(output_fragment) = -+ *reinterpret_cast(source_ptr + nL); -+ } -+ -+ // Perform output operator -+ output_fragment = output_op(accum_fragment, output_fragment); -+ -+ if(guard) { -+ // Store -+ *reinterpret_cast(output_ptr + nL) = reinterpret_cast(output_fragment); -+ } -+ } -+ } -+ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_not_needed_( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ -+ const int kAccumBlockN = 2; -+ const int kThreadsM = 8; -+ const int kThreadsN = 4; -+ const int kBlockM = WarpShape::kM / kThreadsM; -+ -+ /// Array type used to output -+ using OutputAccessType = AlignedArray; -+ -+ /// Array type passed to the output operator - unused elements are optimized away -+ using OutputFragmentType = Array; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = Array; -+ -+ /// Array type used by output functor -+ using AccumulatorFragmentType = Array; -+ -+ AccumulatorAccessType const *accumulator_pair = reinterpret_cast(&accumulators); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int accum_m_idx = 0; accum_m_idx < WarpShape::kM / kThreadsM; accum_m_idx++) { -+ -+ int accum_m = kThreadsM * accum_m_idx; -+ int mL = destination_iterator.threadblock_offset.row() + WarpShape::kM * warp_m + tid_m + accum_m; -+ int nL_base = destination_iterator.threadblock_offset.column() + WarpShape::kN * warp_n + tid_n; -+ -+ ElementOutput *output_ptr = destination_iterator.pointer + mL * destination_iterator.stride; -+ -+ int const kIterationsN = WarpShape::kN / kThreadsN / kAccumBlockN; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int accum_n_idx = 0; accum_n_idx < kIterationsN; accum_n_idx++) { -+ -+ int accum_idx = accum_m_idx + kBlockM * accum_n_idx; -+ int accum_n = kThreadsM * accum_n_idx; -+ -+ // mL and nL are logical coordinate in 2D mapping of epilogue's 4D output -+ int nL = nL_base + accum_n; -+ -+ bool guard = (mL < destination_iterator.extent.row()) && (nL < destination_iterator.extent.column()); -+ -+ AccumulatorFragmentType accum_fragment; -+ reinterpret_cast(accum_fragment) = accumulator_pair[accum_idx]; -+ -+ OutputFragmentType output_fragment; -+ -+ // Perform output operator -+ output_fragment = output_op(accum_fragment); -+ -+ if(guard) { -+ -+ // Store -+ *reinterpret_cast(output_ptr + nL) = -+ reinterpret_cast(output_fragment); -+ } -+ } -+ } -+ } -+}; -+ 
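EpilogueDirectStore above bypasses shared memory entirely: the constructor maps each lane to a (row, column) position inside the warp tile via its quad and lane-in-quad, and the compute_* methods guard every global-memory access against the output extent. A host-side print of that lane-to-coordinate map and the guard predicate, with the extents and coordinates chosen arbitrarily for the example:

#include <cstdio>

int main() {
  // Lane placement computed in the constructor above:
  // quad -> row within the warp tile, lane-in-quad -> paired columns.
  for (int lane_idx = 0; lane_idx < 32; ++lane_idx) {
    int quad         = lane_idx >> 2;
    int lane_in_quad = lane_idx & 3;
    int tid_m = quad;               // thread row inside the warp tile
    int tid_n = 2 * lane_in_quad;   // thread column inside the warp tile
    std::printf("lane %2d -> (tid_m=%d, tid_n=%d)\n", lane_idx, tid_m, tid_n);
  }

  // Extent guard applied before every load/store in the compute_* methods.
  int extent_row = 100, extent_col = 70;  // assumed problem extents
  int mL = 96, nL = 72;                   // assumed logical output coordinates
  bool guard = (mL < extent_row) && (nL < extent_col);
  std::printf("store at (%d, %d)? %s\n", mL, nL, guard ? "yes" : "no");
  return 0;
}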
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h -new file mode 100644 -index 0000000..927035b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h -@@ -0,0 +1,212 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. 
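The epilogue_gemm_k_reduction.h hunk introduced here (class EpilogueGemmKReduction, further below) folds each lane's partial K-reduction into its neighbours with XOR shuffles (lane masks 1 and 2) before the guarded global stores. The following small CUDA sketch shows that butterfly-reduction primitive generalized to a full 32-lane warp; the kernel and names are illustrative only, not the CUTLASS code.

    #include <cstdio>
    #include <cuda_runtime.h>

    // Butterfly reduction across a warp: after log2(width) XOR shuffles every
    // participating lane holds the sum of the group's values.
    __device__ float warp_reduce_sum(float v, int width) {
      for (int offset = width / 2; offset > 0; offset >>= 1) {
        v += __shfl_xor_sync(0xffffffffu, v, offset);
      }
      return v;
    }

    // Illustrative kernel: each lane contributes its lane id; every lane of the
    // warp ends up holding the same total (0 + 1 + ... + 31 = 496).
    __global__ void demo(float* out) {
      int lane = threadIdx.x & 31;
      out[threadIdx.x] = warp_reduce_sum(static_cast<float>(lane), 32);
    }

    int main() {
      float* d = nullptr;
      cudaMalloc(&d, 32 * sizeof(float));
      demo<<<1, 32>>>(d);
      float h[32];
      cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
      std::printf("lane 0 sum = %f\n", h[0]);  // expected 496
      cudaFree(d);
      return 0;
    }

The class below only shuffles with lane masks 1 and 2, i.e. it reduces within groups of four lanes; the same loop with a smaller width reproduces that behavior.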
-+ -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_base.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/numeric_types.h" -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator -+template < -+ typename ElementAccumulator_, -+ typename ElementOutput_, -+ typename ThreadBlockShape_, ///< Shape of threadblock tile (concept: GemmShape) -+ typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ bool ReduceKForA_ -+> -+class EpilogueGemmKReduction { -+ -+public: -+ -+ using ThreadBlockShape = ThreadBlockShape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ using WarpShape = typename WarpMmaOperator::Shape; -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// Accumulator element -+ using ElementAccumulator = ElementAccumulator_; -+ -+ /// Output element -+ using ElementOutput = ElementOutput_; -+ -+ /// Output access size -+ static int const kElementsPerAccess = 1; -+ -+ static bool const kReduceKForA = ReduceKForA_; -+ -+ static int const kThreadBlockSize = kReduceKForA ? ThreadBlockShape::kM : ThreadBlockShape::kN; -+ -+ static int const kWarpSize = kReduceKForA ? 
WarpShape::kM : WarpShape::kN; -+ -+ static int const kIterations = kWarpSize / 8; -+ -+ using FragmentAccumulator = Array; -+ -+private: -+ -+ int thread_offset_; -+ ElementOutput* pointer_; -+ int col_; -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueGemmKReduction( -+ int thread_idx, ///< ID of a thread within the threadblock -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx, ///< Id of thread within warp -+ int threadblock_offset, -+ ElementOutput* pointer -+ ) -+ { -+ col_ = lane_idx % 4; -+ thread_offset_ = threadblock_offset * kThreadBlockSize -+ + warp_idx * kWarpSize -+ + lane_idx / 4 + col_ * 8; -+ -+ pointer_ = pointer + LongIndex(thread_offset_); -+ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()( -+ int size, -+ FragmentAccumulator &gemm_k_with_reduction_accumulation, -+ bool LoadForSerialSplitK -+ ) { -+ bool guard[kIterations / 4]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kIterations / 4; ++i) { -+ guard[i] = ((thread_offset_ + i * 32) < size); -+ } -+ -+ Array source; -+ source.clear(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kIterations / 4; ++i) { -+ ElementOutput *source_ptr = reinterpret_cast(&source); -+ cutlass::arch::global_load( -+ source_ptr[i], -+ (void *)(pointer_ + i * 32), -+ guard[i] && LoadForSerialSplitK); -+ -+ } -+ -+ FragmentAccumulator sum = gemm_k_with_reduction_accumulation; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kIterations; ++i) { -+ sum[i] += __shfl_xor_sync(0xffffffff, sum[i], 1); -+ sum[i] += __shfl_xor_sync(0xffffffff, sum[i], 2); -+ } -+ -+ Array intermediate; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kIterations / 4; ++i) { -+ if (col_ == 0) { -+ intermediate[i] = sum[0 + i * 4]; -+ } -+ -+ if (col_ == 1) { -+ intermediate[i] = sum[1 + i * 4]; -+ } -+ -+ if (col_ == 2) { -+ intermediate[i] = sum[2 + i * 4]; -+ } -+ -+ if (col_ == 3) { -+ intermediate[i] = sum[3 + i * 4]; -+ } -+ } -+ -+ NumericArrayConverter source_converter; -+ Array converted_source = source_converter(source); -+ -+ plus> plus_source; -+ intermediate = plus_source(intermediate, converted_source); -+ -+ NumericArrayConverter converter; -+ Array result = converter(intermediate); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kIterations / 4; ++i) { -+ cutlass::arch::global_store(result[i], -+ (void *)(pointer_ + i * 32), guard[i]); -+ } -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_planar_complex.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_planar_complex.h -new file mode 100644 -index 0000000..1c70bed ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_planar_complex.h -@@ -0,0 +1,401 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. 
-+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/array_planar_complex.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_base.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator for planar-complex output representations. -+/// -+/// Note, as with most CUTLASS components for planar complex, the template arguments describe -+/// the underlying real data type. 
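As context for the planar-complex epilogue declared below: complex tiles are kept as two real-valued planes, with the imaginary plane at a fixed offset (compare kImaginaryStride in the SharedStorage further down), and the output operator is applied to real/imaginary pairs. The scalar C++ sketch below illustrates that convention with a complex alpha/beta linear combination; the struct and function names are assumptions for illustration, not the CUTLASS types.

    // A complex tile of n elements stored as two real-valued planes, the
    // imaginary plane living at a fixed offset from the real plane.
    struct PlanarComplexRef {
      float* data;            // base pointer to the real plane
      long long imag_stride;  // offset from an element to its imaginary part
      float& real(long long i) { return data[i]; }
      float& imag(long long i) { return data[i + imag_stride]; }
    };

    // d = alpha * ab + beta * c, all operands planar complex, alpha/beta complex.
    void planar_complex_epilogue(PlanarComplexRef d, PlanarComplexRef ab,
                                 PlanarComplexRef c, long long n,
                                 float alpha_r, float alpha_i,
                                 float beta_r, float beta_i) {
      for (long long i = 0; i < n; ++i) {
        float ab_r = ab.real(i), ab_i = ab.imag(i);
        float c_r  = c.real(i),  c_i  = c.imag(i);
        // Complex multiply-accumulate carried out on the separated planes.
        d.real(i) = alpha_r * ab_r - alpha_i * ab_i + beta_r * c_r - beta_i * c_i;
        d.imag(i) = alpha_r * ab_i + alpha_i * ab_r + beta_r * c_i + beta_i * c_r;
      }
    }

Keeping the planes separate is what lets the template parameters below describe only the underlying real data type.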
-+template < -+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) -+ typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ int PartitionsK, ///< Number of partitions of the K dimension -+ typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors -+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators -+ typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM -+ typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM -+ typename OutputOp_, ///< Output operator -+ typename Padding_ ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) -+> -+class EpiloguePlanarComplex { -+public: -+ -+ using Shape = Shape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputTileIterator = OutputTileIterator_; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using WarpTileIterator = WarpTileIterator_; -+ using SharedLoadIterator = SharedLoadIterator_; -+ using OutputOp = OutputOp_; -+ using Padding = Padding_; -+ -+ /// Output layout is always row-major -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = ArrayPlanarComplex< -+ typename WarpMmaOperator::FragmentC::Element, -+ WarpMmaOperator::FragmentC::kElements -+ >; -+ -+ /// Accumulator element -+ using ElementAccumulator = typename WarpTileIterator::Element; -+ -+ /// Output element -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ /// Tensor reference to destination tensor -+ using TensorRef = typename OutputTileIterator::TensorRef; -+ -+ /// Tensor reference to sync tensor -+ using SyncTensorRef = typename cutlass::TensorRef; -+ -+ /// Const tensor reference to source tensor -+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; -+ -+ /// Array type used to output -+ using OutputAccessType = Array< -+ typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = Array; -+ -+ /// Shape of each warp-level operation -+ using WarpShape = typename WarpMmaOperator::Shape; -+ -+ /// Number of warps -+ using WarpCount = gemm::GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ kPartitionsK -+ >; -+ -+ /// Shared memory allocation -+ struct SharedStorage { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element type of shared memory -+ using Element = typename WarpTileIterator::Element; -+ -+ /// Tensor reference to shared memory allocation -+ using TensorRef = typename WarpTileIterator::TensorRef; -+ -+ /// Layout of shared memory allocation -+ using Layout = typename WarpTileIterator::Layout; -+ -+ /// Logical shape of the shared memory tile written to by all warps. 
-+ using Shape = MatrixShape< -+ WarpCount::kM * WarpTileIterator::Shape::kRow * WarpCount::kK, -+ WarpCount::kN * WarpTileIterator::Shape::kColumn -+ >; -+ -+ /// Shape of the shared memory allocation for the epilogue -+ using StorageShape = MatrixShape< -+ Shape::kRow + Padding::kRow, -+ Shape::kColumn + Padding::kColumn -+ >; -+ -+ static int const kImaginaryStride = StorageShape::kCount; -+ -+ // -+ // Data members -+ // -+ -+ AlignedBuffer storage; -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a pointer to the shared memory buffer -+ CUTLASS_DEVICE -+ Element *data() { -+ return storage.data(); -+ } -+ -+ /// Returns a tensor reference to the shared memory buffer -+ CUTLASS_DEVICE -+ TensorRef reference() { -+ return TensorRef( -+ storage.data(), -+ Layout::packed({StorageShape::kRow, StorageShape::kColumn})); -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ SharedStorage &shared_storage_; -+ -+ /// Loads fragment from shared memory aligned with output tensor -+ SharedLoadIterator shared_load_iterator_; -+ -+ /// Stores a warp's fragment of accumulators to SMEM -+ WarpTileIterator warp_tile_iterator_; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpiloguePlanarComplex( -+ SharedStorage &shared_storage, ///< Shared storage object -+ int thread_idx, ///< ID of a thread within the threadblock -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx ///< Id of thread within warp -+ ): -+ shared_storage_(shared_storage), -+ shared_load_iterator_(shared_storage.reference(), thread_idx), -+ warp_tile_iterator_(shared_storage.reference(), lane_idx) { -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to three coordinates: -+ // -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN); -+ int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN); -+ int warp_m = warp_mn % WarpCount::kM; -+ int warp_n = warp_mn / WarpCount::kM; -+ -+ MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n}; -+ -+ warp_tile_iterator_.add_tile_offset(warp_offset); -+ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator_real, ///< Tile iterator for destination -+ OutputTileIterator destination_iterator_imag, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator_real, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ OutputTileIterator source_iterator_imag) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ -+ typename OutputTileIterator::Fragment source_fragment_real; -+ typename OutputTileIterator::Fragment source_fragment_imag; -+ -+ if (!output_op.is_source_needed()) { -+ source_iterator_real.clear_mask(); -+ source_iterator_imag.clear_mask(); -+ } -+ -+ source_fragment_real.clear(); -+ source_fragment_imag.clear(); -+ -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ AccumulatorFragmentIterator accum_fragment_iterator_real(accumulators.real); -+ AccumulatorFragmentIterator accum_fragment_iterator_imag(accumulators.imag); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int iter = 
0; iter < OutputTileIterator::kIterations; ++iter) { -+ -+ // -+ // Load the source -+ // -+ -+ source_iterator_real.load(source_fragment_real); -+ source_iterator_imag.load(source_fragment_imag); -+ -+ ++source_iterator_real; -+ ++source_iterator_imag; -+ -+ // -+ // Convert and store fragment -+ // -+ -+ __syncthreads(); -+ -+ typename AccumulatorFragmentIterator::Fragment accum_fragment_real; -+ typename AccumulatorFragmentIterator::Fragment accum_fragment_imag; -+ -+ accum_fragment_iterator_real.load(accum_fragment_real); -+ accum_fragment_iterator_imag.load(accum_fragment_imag); -+ -+ ++accum_fragment_iterator_real; -+ ++accum_fragment_iterator_imag; -+ -+ this->warp_tile_iterator_.store(accum_fragment_real); -+ this->warp_tile_iterator_.store_with_pointer_offset(accum_fragment_imag, SharedStorage::kImaginaryStride); -+ -+ __syncthreads(); -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ typename SharedLoadIterator::Fragment aligned_accum_fragment_real[kPartitionsK]; -+ typename SharedLoadIterator::Fragment aligned_accum_fragment_imag[kPartitionsK]; -+ -+ shared_load_iterator_.load(aligned_accum_fragment_real[0]); -+ shared_load_iterator_.load_with_pointer_offset(aligned_accum_fragment_imag[0], SharedStorage::kImaginaryStride); -+ -+ // If the number of k-slices is > 1 - perform a reduction amongst the k-slices -+ static_assert(kPartitionsK == 1, "Sliced-K not supported for planar complex at this time"); -+ -+ // -+ // Compute the output result -+ // -+ -+ typename OutputTileIterator::Fragment output_fragment_real; -+ typename OutputTileIterator::Fragment output_fragment_imag; -+ -+ apply_output_operator_( -+ output_fragment_real, -+ output_fragment_imag, -+ output_op, -+ aligned_accum_fragment_real[0], -+ aligned_accum_fragment_imag[0], -+ source_fragment_real, -+ source_fragment_imag); -+ -+ // -+ // Store the final result -+ // -+ -+ destination_iterator_real.store(output_fragment_real); -+ destination_iterator_imag.store(output_fragment_imag); -+ -+ ++destination_iterator_real; -+ ++destination_iterator_imag; -+ } -+ } -+ -+private: -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_( -+ typename OutputTileIterator::Fragment &output_fragment_real, -+ typename OutputTileIterator::Fragment &output_fragment_imag, -+ OutputOp const &output_op, ///< Output operator -+ typename SharedLoadIterator::Fragment const &aligned_accum_fragment_real, -+ typename SharedLoadIterator::Fragment const &aligned_accum_fragment_imag, -+ typename OutputTileIterator::Fragment const &source_fragment_real, -+ typename OutputTileIterator::Fragment const &source_fragment_imag) { -+ -+ OutputAccessType *output_frag_real_ptr = -+ reinterpret_cast(&output_fragment_real); -+ -+ OutputAccessType *output_frag_imag_ptr = -+ reinterpret_cast(&output_fragment_imag); -+ -+ AccumulatorAccessType const *compute_frag_real_ptr = -+ reinterpret_cast(&aligned_accum_fragment_real); -+ -+ AccumulatorAccessType const *compute_frag_imag_ptr = -+ reinterpret_cast(&aligned_accum_fragment_imag); -+ -+ OutputAccessType const *source_frag_real_ptr = -+ reinterpret_cast(&source_fragment_real); -+ -+ OutputAccessType const *source_frag_imag_ptr = -+ reinterpret_cast(&source_fragment_imag); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ -+ // Call the output operator -+ auto result_fragment = 
output_op( -+ make_ArrayPlanarComplex(compute_frag_real_ptr[i], compute_frag_imag_ptr[i]), -+ make_ArrayPlanarComplex(source_frag_real_ptr[i], source_frag_imag_ptr[i]) -+ ); -+ -+ output_frag_real_ptr[i] = result_fragment.real; -+ output_frag_imag_ptr[i] = result_fragment.imag; -+ } -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h -new file mode 100644 -index 0000000..6dabe72 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h -@@ -0,0 +1,230 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMM/CONV to store accumulator in shared memory after -+ applying scale, bias loaded from global memory and element-wise operations. -+ -+ This Epilogue is typically used in fused GEMM/CONV to stage the intermediate accumulator. 
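A scalar sketch of the per-fragment work the EpilogueSmemAccumulator below performs: apply a per-channel scale and bias to the accumulator, run an elementwise op, and stage the result in shared memory for a fused consumer instead of writing it to global memory. The tile sizes, the ReLU activation, and all names in this sketch are assumptions, not the CUTLASS implementation.

    #include <cuda_runtime.h>

    // Hypothetical tile sizes for the staging buffer.
    constexpr int kTileM = 16;
    constexpr int kTileN = 16;

    // Illustrative output op: out = relu(scale[n] * accum + bias[n]), with
    // scale/bias broadcast along the column (per-channel) dimension.
    __device__ float scale_bias_relu(float accum, float scale, float bias) {
      float v = scale * accum + bias;
      return v > 0.0f ? v : 0.0f;
    }

    // Each thread handles a strided subset of the tile; the epilogue stages the
    // transformed accumulators in shared memory rather than global memory.
    __global__ void stage_accumulators(const float* accum,   // kTileM x kTileN
                                       const float* scale,   // kTileN
                                       const float* bias,    // kTileN
                                       float* out) {
      __shared__ float smem[kTileM][kTileN];

      for (int idx = threadIdx.x; idx < kTileM * kTileN; idx += blockDim.x) {
        int m = idx / kTileN;
        int n = idx % kTileN;
        smem[m][n] = scale_bias_relu(accum[idx], scale[n], bias[n]);
      }
      __syncthreads();

      // Stand-in for the fused consumer (e.g. the second GEMM's operand load):
      // copy the staged tile back out so the kernel can be verified.
      for (int idx = threadIdx.x; idx < kTileM * kTileN; idx += blockDim.x) {
        out[idx] = smem[idx / kTileN][idx % kTileN];
      }
    }

A launch such as stage_accumulators<<<1, 128>>>(accum, scale, bias, out) exercises the staging path end to end.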
-+ -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator -+template < -+ typename SmemTileIterator_, ///< Shared memory Tile iterator to output to shared memory -+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators -+ typename ScaleBiasIterator_, ///< Iterator to load scale and bias from global memory -+ typename OutputOp_ ///< Output operator -+> -+class EpilogueSmemAccumulator { -+ -+public: -+ -+ using SmemTileIterator = SmemTileIterator_; -+ -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ -+ using ScaleBiasIterator = ScaleBiasIterator_; -+ -+ using OutputOp = OutputOp_; -+ -+ /// Fragment of accumulator tile -+ using FragmentAccumulator = typename AccumulatorFragmentIterator::Fragment; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; -+ -+ /// Fragment of Scale and Bias loaded from global memory -+ using FragmentScaleBias = typename ScaleBiasIterator::Fragment; -+ -+ static const bool PerChannelScale = (OutputOp::kScale == -+ epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueSmemAccumulator() {} -+ -+ /// Streams the result to shared memory -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ SmemTileIterator smem_iterator, ///< Tile iterator for destination in shared memory -+ AccumulatorTile const &accumulator, ///< Complete warp-level accumulator tile -+ ScaleBiasIterator scale_iterator, ///< iterator for scale vector in global memory -+ ScaleBiasIterator bias_iterator) { ///< iterator for bias vector in global memory -+ -+ -+ // Fragment to load scale bias from global memory -+ FragmentScaleBias tb_frag_scale; -+ FragmentScaleBias tb_frag_bias; -+ -+ /// Fragment Iterator to load slice of accumulator tile -+ AccumulatorFragmentIterator frag_iterator_accum(accumulator); -+ FragmentAccumulator tb_frag_accum; -+ -+ /// Epilogue output fragment -+ typename SmemTileIterator::Fragment tb_frag_smem; -+ -+ /// Load scale and bias from global memory -+ -+ if(PerChannelScale) -+ scale_iterator.load(tb_frag_scale); -+ -+ bias_iterator.load(tb_frag_bias); -+ -+ /// Iterate over the accumulator tile and store to shared memory -+ CUTLASS_PRAGMA_UNROLL -+ for (int rid = 0; rid < AccumulatorFragmentIterator::TileIterations::kRow; ++rid) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cid = 0; cid < AccumulatorFragmentIterator::TileIterations::kColumn; ++cid) { -+ -+ using AccumulatorAccessType = typename OutputOp::FragmentAccumulator; -+ using ScaleBiasAccessType = typename OutputOp::FragmentScaleBias; -+ using FragmentSmemAccessType = typename OutputOp::FragmentOutput; -+ -+ -+ ScaleBiasAccessType const * scale_frag_ptr = -+ reinterpret_cast(&tb_frag_scale); 
-+ ScaleBiasAccessType const * bias_frag_ptr = -+ reinterpret_cast(&tb_frag_bias); -+ -+ FragmentSmemAccessType * smem_frag_ptr = -+ reinterpret_cast(&tb_frag_smem); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx = 0; idx < AccumulatorFragmentIterator::kIterationsPerTile; ++idx) { -+ frag_iterator_accum.load(tb_frag_accum); -+ ++frag_iterator_accum; -+ -+ AccumulatorAccessType const * accumulator_frag_ptr = -+ reinterpret_cast(&tb_frag_accum); -+ const int kOutputIterations = FragmentAccumulator::kElements / OutputOp::kCount; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int it = 0; it < kOutputIterations; it++) { -+ smem_frag_ptr[idx * kOutputIterations + it] = output_op(accumulator_frag_ptr[it], -+ scale_frag_ptr[cid * kOutputIterations + it], bias_frag_ptr[cid * kOutputIterations + it]); -+ } -+ } -+ -+ smem_iterator.store(tb_frag_smem); -+ ++smem_iterator; -+ -+ } -+ } -+ } -+ -+ /// Streams the result to shared memory -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ SmemTileIterator smem_iterator, ///< Tile iterator for destination in shared memory -+ AccumulatorTile const &accumulator) { ///< Complete warp-level accumulator tile -+ -+ /// Fragment Iterator to load slice of accumulator tile -+ AccumulatorFragmentIterator frag_iterator_accum(accumulator); -+ FragmentAccumulator tb_frag_accum; -+ -+ /// Epilogue output fragment -+ typename SmemTileIterator::Fragment tb_frag_smem; -+ -+ /// Iterate over the accumulator tile and store to shared memory -+ CUTLASS_PRAGMA_UNROLL -+ for (int rid = 0; rid < AccumulatorFragmentIterator::TileIterations::kRow; ++rid) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cid = 0; cid < AccumulatorFragmentIterator::TileIterations::kColumn; ++cid) { -+ -+ using AccumulatorAccessType = typename OutputOp::FragmentAccumulator; -+ using FragmentSmemAccessType = typename OutputOp::FragmentOutput; -+ -+ FragmentSmemAccessType * smem_frag_ptr = -+ reinterpret_cast(&tb_frag_smem); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx = 0; idx < AccumulatorFragmentIterator::kIterationsPerTile; ++idx) { -+ frag_iterator_accum.load(tb_frag_accum); -+ ++frag_iterator_accum; -+ -+ AccumulatorAccessType const * accumulator_frag_ptr = -+ reinterpret_cast(&tb_frag_accum); -+ const int kOutputIterations = FragmentAccumulator::kElements / OutputOp::kCount; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int it = 0; it < kOutputIterations; it++) { -+ smem_frag_ptr[idx * kOutputIterations + it] = output_op(accumulator_frag_ptr[it]); -+ } -+ } -+ -+ smem_iterator.store(tb_frag_smem); -+ ++smem_iterator; -+ -+ } -+ } -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h -new file mode 100644 -index 0000000..de70352 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h -@@ -0,0 +1,513 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue visitor for threadblock scoped GEMMs that process softmax computations in epilogue. -+ -+ The epilogue finds max values in each row of the row-major output matrix and stores them. -+ The max values are also used for a further round of threadblock scoped reduction operation, where -+ the partial reduction results are stored in a pre-allocated array and used for further full reduction. 
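The visitor defined below keeps a running row maximum and a running sum, rescaling the sum whenever the maximum grows (the `S* = S* x updater + sum_row(P')` comment in visit(), after arXiv:2205.14135, Section 3.1). The host-side scalar sketch below shows that online-softmax update with hypothetical names; the CUTLASS version applies the same rule per fragment and then reduces across the threads of a row with XOR shuffles.

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Streaming ("online") softmax statistics for one row, processed in chunks:
    // keep a running maximum m and a running sum s of exp(x - m).  When a chunk
    // raises the maximum, the old sum is rescaled by exp(m_old - m_new).
    struct SoftmaxStats {
      float m = -INFINITY;  // running row maximum
      float s = 0.0f;       // running sum of exp(x - m)

      void update(const float* chunk, int n) {
        float m_new = m;
        for (int i = 0; i < n; ++i) m_new = std::fmax(m_new, chunk[i]);
        float rescale = std::exp(m - m_new);  // exp(-inf) == 0 on the first chunk
        float partial = 0.0f;
        for (int i = 0; i < n; ++i) partial += std::exp(chunk[i] - m_new);
        s = s * rescale + partial;
        m = m_new;
      }
    };

    int main() {
      std::vector<float> row = {0.5f, 2.0f, -1.0f, 3.0f, 0.0f, 1.5f};
      SoftmaxStats stats;
      stats.update(row.data(), 3);       // first half of the row
      stats.update(row.data() + 3, 3);   // second half, may raise the maximum
      // softmax(x_i) can now be recovered as exp(x_i - stats.m) / stats.s.
      std::printf("row max = %f, normalizer = %f\n", stats.m, stats.s);
      return 0;
    }

The per-threadblock maxima and sums that the visitor writes out correspond to stats.m and stats.s here; a later full reduction combines them across threadblocks.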
-+ -+*/ -+ -+#pragma once -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/fast_math.h" -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+template < -+ typename ThreadblockShape_, -+ int ThreadCount, -+ typename OutputTileIterator_, -+ typename ElementAccumulator_, -+ typename ElementNorm_, -+ typename ElementSum_, -+ typename ElementSoftmaxCompute_, -+ typename ElementwiseFunctor_, -+ bool UseMasking_ = false -+> -+class EpilogueVisitorSoftmax { -+public: -+ -+ using ThreadblockShape = ThreadblockShape_; -+ static int const kThreadCount = ThreadCount; -+ -+ using OutputTileIterator = OutputTileIterator_; -+ using ElementwiseFunctor = ElementwiseFunctor_; -+ -+ static int const kIterations = OutputTileIterator::kIterations; -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ using ElementOutput = typename OutputTileIterator::Element; -+ using LayoutOutput = cutlass::layout::RowMajor; -+ using ElementAccumulator = ElementAccumulator_; -+ -+ using ElementNorm = ElementNorm_; -+ using ElementSum = ElementSum_; -+ using ElementSoftmaxCompute = ElementSoftmaxCompute_; -+ -+ using AccumulatorFragment = Array; -+ using SoftmaxFragment = Array; -+ using OutputVector = Array; -+ using TensorRefD = TensorRef; -+ -+ static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth; -+ static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1); -+ static bool const kUseMasking = UseMasking_; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ typename ElementwiseFunctor::Params elementwise; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+ int64_t batch_stride_Max; -+ int64_t batch_stride_Sum; -+ -+ // -+ // Methods -+ // -+ Arguments(): -+ batch_stride_C(0), -+ batch_stride_D(0), -+ batch_stride_Max(0), -+ batch_stride_Sum(0) -+ { -+ -+ } -+ -+ Arguments( -+ typename ElementwiseFunctor::Params elementwise_ -+ ): -+ elementwise(elementwise_), -+ batch_stride_C(0), -+ batch_stride_D(0), -+ batch_stride_Max(0), -+ batch_stride_Sum(0) -+ { -+ -+ } -+ -+ Arguments( -+ typename ElementwiseFunctor::Params elementwise_, -+ int64_t batch_stride_C_, -+ int64_t batch_stride_D_, -+ int64_t batch_stride_Max_, -+ int64_t batch_stride_Sum_ -+ ): -+ elementwise(elementwise_), -+ batch_stride_C(batch_stride_C_), -+ batch_stride_D(batch_stride_D_), -+ batch_stride_Max(batch_stride_Max_), -+ batch_stride_Sum(batch_stride_Sum_) -+ { -+ -+ } -+ -+ }; -+ -+ struct Params { -+ -+ typename ElementwiseFunctor::Params elementwise; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+ int64_t batch_stride_Max; -+ int64_t batch_stride_Sum; -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() -+ { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ elementwise(args.elementwise), -+ batch_stride_C(args.batch_stride_C), -+ batch_stride_D(args.batch_stride_D), -+ batch_stride_Max(args.batch_stride_Max), -+ batch_stride_Sum(args.batch_stride_Sum) -+ { -+ -+ } -+ }; -+ -+ /// Shared storage -+ struct SharedStorage { -+ -+ }; -+ -+private: -+ -+ Params const & params_; -+ SharedStorage & shared_storage_; -+ MatrixCoord extent_; -+ MatrixCoord extent_real_; -+ ElementwiseFunctor elementwise_; -+ -+ OutputTileIterator iterator_C_; -+ OutputTileIterator 
iterator_D_; -+ typename OutputTileIterator::Fragment fragment_C_; -+ typename OutputTileIterator::Fragment fragment_D_; -+ -+ ElementAccumulator alpha_; -+ ElementAccumulator beta_; -+ -+ ElementNorm *ptr_Max_; -+ ElementSum *ptr_Sum_; -+ -+ int column_offset_; -+ -+ ElementSoftmaxCompute accum_max_; -+ ElementSoftmaxCompute accum_sum_; -+ -+ MatrixCoord thread_offset_; -+ -+ float infinity_; -+ -+public: -+ -+ CUTLASS_DEVICE -+ EpilogueVisitorSoftmax( -+ Params const ¶ms, -+ SharedStorage &shared_storage, -+ cutlass::MatrixCoord const &problem_size, -+ int thread_idx, -+ int warp_idx, -+ int lane_idx, -+ typename OutputTileIterator::Params params_C, -+ typename OutputTileIterator::Params params_D, -+ typename OutputTileIterator::Element *ptr_C, -+ typename OutputTileIterator::Element *ptr_D, -+ ElementNorm *ptr_Max = nullptr, -+ ElementSum *ptr_Sum = nullptr, -+ cutlass::MatrixCoord const &threadblock_offset = cutlass::MatrixCoord(0, 0), -+ int column_offset = 0, -+ cutlass::MatrixCoord const &problem_size_real = cutlass::MatrixCoord(0, 0), -+ float infinity = 10000.0f -+ ): -+ params_(params), -+ shared_storage_(shared_storage), -+ extent_(problem_size), -+ elementwise_(params.elementwise), -+ iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset), -+ iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset), -+ ptr_Max_(ptr_Max), -+ ptr_Sum_(ptr_Sum), -+ column_offset_(column_offset), -+ extent_real_(problem_size_real), -+ infinity_(infinity) -+ { -+ alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha); -+ beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta); -+ -+ if (beta_ == ElementAccumulator()) { -+ iterator_C_.clear_mask(); -+ } -+ } -+ -+ /// Helper to indicate split-K behavior -+ CUTLASS_DEVICE -+ void set_k_partition( -+ int split_k_index, ///< Index of this threadblock within split-K partitioned scheme -+ int split_k_slices) { ///< Total number of split-K slices -+ -+ } -+ -+ /// Called to set the batch index -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); -+ iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); -+ } -+ -+ /// Called at the start of the epilogue just before iterating over accumulator slices -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ -+ } -+ -+ /// Called at the start of one step before starting accumulator exchange -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ fragment_D_.clear(); -+ fragment_C_.clear(); -+ -+ if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { -+ iterator_C_.load(fragment_C_); -+ ++iterator_C_; -+ } -+ -+ } -+ -+ /// Called at the start of a row -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { -+ // Clear accumulators for max and sum when starting a whole row -+ clear_accum_(); -+ -+ } -+ -+ /// Called after accumulators have been exchanged for each accumulator vector -+ CUTLASS_DEVICE -+ void visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorFragment const &accum) { -+ -+ using Mul = cutlass::multiplies; -+ using Minus = cutlass::minus; -+ using Exp = cutlass::fast_exp_op; -+ -+ Minus minus; -+ Exp exponential; -+ -+ SoftmaxFragment result; -+ -+ NumericArrayConverter source_converter; -+ OutputVector &source_vector = reinterpret_cast(&fragment_C_)[frag_idx]; -+ -+ if (elementwise_.kScale == 
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { -+ result = source_converter(elementwise_(accum)); -+ }else{ -+ result = source_converter(elementwise_(accum, source_vector)); -+ } -+ -+ thread_offset_ = -+ iterator_D_.thread_start() + -+ OutputTileIterator::ThreadMap::iteration_offset(frag_idx); -+ -+ bool column_guard = (thread_offset_.column() < extent_.column()); -+ -+ if (kUseMasking) { -+ int elements_in_boundary = extent_real_.column() - thread_offset_.column(); -+ elements_in_boundary = (elements_in_boundary > kElementsPerAccess) ? kElementsPerAccess : elements_in_boundary; -+ elementwise_padding_(result, elements_in_boundary); -+ } -+ -+ ElementSoftmaxCompute accum_max_prev = accum_max_; -+ -+ // Compute the maximum within one row -+ if (!column_idx) { -+ // This is the first fragment in a new row -+ if (column_guard) { -+ accum_max_ = maximum_accumulator_(result); -+ } -+ } -+ else { -+ // This is an additional fragment in the same row -+ if (column_guard) { -+ accum_max_ = maximum_accumulator_(result, accum_max_); -+ } -+ } -+ -+ // proactively compute max in warps -+ accum_max_ = warp_reduce_max_(accum_max_); -+ -+ ElementSoftmaxCompute updater = fast_exp(accum_max_prev - accum_max_); -+ -+ SoftmaxFragment intermediate = exponential(minus(result, accum_max_)); -+ -+ if (kHasMultiStepsInRow) { -+ if (!column_idx) { -+ accum_sum_ = (column_guard) ? \ -+ sum_accumulator_(intermediate) : ElementSoftmaxCompute(0); -+ } else { -+ // Algorithm in $3.1, https://arxiv.org/pdf/2205.14135v1.pdf -+ // S* = S* x updater + sum_row(P'), where updater = exp(M* - M_row) -+ accum_sum_ = (column_guard) ? \ -+ sum_accumulator_(intermediate, accum_sum_ * updater) : accum_sum_ * updater; -+ } -+ } else { -+ accum_sum_ = (column_guard) ? sum_accumulator_(intermediate, accum_sum_) : ElementSoftmaxCompute(0); -+ } -+ -+ // Convert to the output -+ NumericArrayConverter output_converter; -+ OutputVector &output = reinterpret_cast(&fragment_D_)[frag_idx]; -+ output = output_converter(result); -+ } -+ -+ /// Called at the end of a row -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { -+ -+ using ConvertSumOutput = cutlass::NumericConverter; -+ using ConvertNormOutput = cutlass::NumericConverter; -+ -+ ConvertSumOutput convert_sum_output; -+ ConvertNormOutput convert_norm_output; -+ -+ // Compute accumulate sum only in the last step -+ accum_sum_ = warp_reduce_sum_(accum_sum_); -+ -+ bool is_first_thread_in_tile = ((threadIdx.x % kThreadsPerRow) == 0); -+ bool row_guard = thread_offset_.row() < extent_.row(); -+ bool is_write_thread = row_guard && is_first_thread_in_tile; -+ -+ int block_batch = blockIdx.z; -+ -+ ElementNorm *curr_ptr_max = ptr_Max_ + thread_offset_.row() + column_offset_ + block_batch * params_.batch_stride_Max; -+ ElementSum *curr_ptr_sum = ptr_Sum_ + thread_offset_.row() + column_offset_ + block_batch * params_.batch_stride_Sum; -+ -+ arch::global_store( -+ convert_norm_output(accum_max_), -+ (void *)curr_ptr_max, -+ is_write_thread); -+ -+ arch::global_store( -+ convert_sum_output(accum_sum_), -+ (void *)curr_ptr_sum, -+ is_write_thread); -+ -+ // Clear accumulators for max and sum when finishing a whole row -+ clear_accum_(); -+ -+ } -+ -+ /// Called after all accumulator elements have been visited -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ -+ iterator_D_.store(fragment_D_); -+ ++iterator_D_; -+ } -+ -+ /// Called after all steps have been completed -+ CUTLASS_DEVICE -+ void end_epilogue() { -+ -+ } -+ -+private: -+ -+ CUTLASS_DEVICE -+ void 
elementwise_padding_(SoftmaxFragment &result, int elements_in_boundary) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < SoftmaxFragment::kElements; ++i) { -+ result[i] = (i < elements_in_boundary) ? result[i] : ElementSoftmaxCompute(-infinity_); -+ } -+ } -+ -+ CUTLASS_DEVICE -+ ElementSoftmaxCompute warp_reduce_sum_(ElementSoftmaxCompute sum_) { -+ int half_thread_in_row = (kThreadsPerRow >> 1); -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = half_thread_in_row; i > 0; i >>= 1) { -+ ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, sum_, i); -+ sum_ += tmp; -+ } -+ return sum_; -+ } -+ -+ CUTLASS_DEVICE -+ ElementSoftmaxCompute warp_reduce_max_(ElementSoftmaxCompute max_) { -+ int half_thread_in_row = (kThreadsPerRow >> 1); -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = half_thread_in_row; i > 0; i >>= 1) { -+ ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, max_, i); -+ max_ = fast_max(max_, tmp); -+ } -+ return max_; -+ } -+ -+ CUTLASS_DEVICE -+ void clear_accum_() { -+ -+ uint32_t float_max_bits = 0xff7fffff; // -FLT_MAX -+ float min_float = reinterpret_cast(float_max_bits); -+ accum_max_ = ElementSoftmaxCompute(min_float); -+ accum_sum_ = ElementSoftmaxCompute(0); -+ } -+ -+ CUTLASS_DEVICE -+ ElementSoftmaxCompute sum_accumulator_(SoftmaxFragment const &accum) { -+ ElementSoftmaxCompute sum_ = ElementSoftmaxCompute(0); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < SoftmaxFragment::kElements; ++i) { -+ sum_ += ElementSoftmaxCompute(accum[i]); -+ } -+ -+ return sum_; -+ } -+ -+ CUTLASS_DEVICE -+ ElementSoftmaxCompute sum_accumulator_(SoftmaxFragment const &accum, ElementSoftmaxCompute sum_) { -+ // ElementSoftmaxCompute sum_ = ElementSoftmaxCompute(0); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < SoftmaxFragment::kElements; ++i) { -+ sum_ += ElementSoftmaxCompute(accum[i]); -+ } -+ -+ return sum_; -+ } -+ -+ CUTLASS_DEVICE -+ ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum) { -+ ElementSoftmaxCompute max_ = accum[0]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < SoftmaxFragment::kElements; ++i) { -+ max_ = fast_max(max_, ElementSoftmaxCompute(accum[i])); -+ } -+ -+ return max_; -+ } -+ -+ CUTLASS_DEVICE -+ ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum, ElementSoftmaxCompute max_) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < SoftmaxFragment::kElements; ++i) { -+ max_ = fast_max(max_, ElementSoftmaxCompute(accum[i])); -+ } -+ -+ return max_; -+ } -+}; -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h -new file mode 100644 -index 0000000..9c9f716 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h -@@ -0,0 +1,1540 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#include -+#else -+#include -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/functional.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/layout/tensor.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_base.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+ -+#include "cutlass/numeric_types.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This base class is meant to define the concept required of the -+/// EpilogueWithBroadcast::OutputOp -+template < -+ typename ElementC_, -+ typename ElementAccumulator_, -+ typename ElementCompute_, -+ typename ElementZ_, -+ typename ElementT_, -+ int ElementsPerAccess, -+ bool StoreZ = true, -+ bool StoreT = true -+> -+struct EpilogueWithBroadcastOpBase { -+ -+ using ElementOutput = ElementC_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ using ElementZ = ElementZ_; -+ using ElementT = ElementT_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using FragmentAccumulator = Array; -+ using FragmentCompute = Array; -+ using FragmentC = Array; -+ using FragmentZ = Array; -+ using FragmentT = Array; -+ -+ /// If true, the 'Z' tensor is stored -+ static bool const 
kStoreZ = StoreZ; -+ -+ /// If true, the 'T' tensor is stored -+ static bool const kStoreT = StoreT; -+ -+ /// Parameters structure - required -+ struct Params { }; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor from Params -+ EpilogueWithBroadcastOpBase(Params const ¶ms_) { } -+ -+ /// Determine if the source is needed. May return false if -+ bool is_source_needed() const { -+ return true; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { } -+ -+ /// Applies the operation when is_source_needed() is true -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentZ &frag_Z, -+ FragmentT &frag_T, -+ FragmentAccumulator const &AB, -+ FragmentC const &frag_C1, -+ FragmentC const &frag_C2, -+ FragmentCompute const &V) const { -+ -+ } -+ -+ /// Applies the operation when is_source_needed() is false -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentZ &frag_Z, -+ FragmentT &frag_T, -+ FragmentAccumulator const &AB, -+ FragmentCompute const &V) const { -+ -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator with bias vector broadcast over columns. -+/// -+/// Computes the following: -+/// -+/// -+/// Z, T = OutputOp(AB, C, Broadcast) -+/// -+/// if (ElementwiseOp::kStoreZ) { -+/// store(converted_u); -+/// } -+/// -+/// if (ElementwiseOp::kStoreT) { -+/// store(v); -+/// } -+/// -+template < -+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) -+ typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ int PartitionsK, ///< Number of partitions of the K dimension -+ typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors (z) -+ typename TensorTileIterator_, ///< Additional tile iterator for tensor-valued operands (t) -+ typename ElementVector_, ///< Pointer to broadcast vector -+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators -+ typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM -+ typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM -+ typename OutputOp_, ///< Output operator - concept is EpilogueWithBroadcastOp -+ typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) -+ int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity -+ int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large -+ (!IsEpilogueFunctorHeavy::value), -+ bool IsSingleSource = OutputOp_::kIsSingleSource -+> -+class EpilogueWithBroadcast; -+ -+template < -+ typename Shape_, -+ typename WarpMmaOperator_, -+ int PartitionsK, -+ typename OutputTileIterator_, -+ typename TensorTileIterator_, -+ typename ElementVector_, -+ typename AccumulatorFragmentIterator_, -+ typename WarpTileIterator_, -+ typename SharedLoadIterator_, -+ typename OutputOp_, -+ typename Padding_, -+ int FragmentsPerPartition, -+ int IterationsUnroll -+> -+class EpilogueWithBroadcast< -+ Shape_, -+ WarpMmaOperator_, -+ PartitionsK, -+ OutputTileIterator_, -+ TensorTileIterator_, -+ ElementVector_, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ SharedLoadIterator_, -+ OutputOp_, -+ Padding_, -+ FragmentsPerPartition, -+ IterationsUnroll, -+ false -+> : -+ public EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_, -+ FragmentsPerPartition> { -+ -+public: -+ -+ 
using Base = EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_, -+ FragmentsPerPartition>; -+ -+ static bool const kIsSingleSource = false; -+ using Shape = Shape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputTileIterator = OutputTileIterator_; -+ using TensorTileIterator = TensorTileIterator_; -+ using ElementVector = ElementVector_; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using WarpTileIterator = WarpTileIterator_; -+ using SharedLoadIterator = SharedLoadIterator_; -+ using OutputOp = OutputOp_; -+ using Padding = Padding_; -+ -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename Base::AccumulatorTile; -+ -+ /// Accumulator element -+ using ElementAccumulator = typename WarpTileIterator::Element; -+ -+ /// Compute data type produced by the output op -+ using ElementCompute = typename OutputOp::ElementCompute; -+ -+ /// Compute fragment -+ using FragmentCompute = Array; -+ -+ /// Thread map used by output tile iterators -+ using ThreadMap = typename OutputTileIterator::ThreadMap; -+ -+ /// Fragment object used to store the broadcast values -+ using BroadcastFragment = Array< -+ ElementCompute, -+ ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>; -+ -+ /// Output element -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ /// Data type of additional tensor -+ using ElementTensor = typename TensorTileIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ /// Tensor reference to destination tensor -+ using TensorRef = typename OutputTileIterator::TensorRef; -+ -+ /// Tensor reference to sync tensor -+ using SyncTensorRef = typename cutlass::TensorRef; -+ -+ /// Const tensor reference to source tensor -+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; -+ -+ /// Array type used to output -+ using OutputAccessType = Array< -+ typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = Array; -+ -+ /// Array type used by output functor -+ using ComputeAccessType = Array; -+ -+ /// Tensor access type -+ using TensorAccessType = Array; -+ -+ /// Number of warps -+ using WarpCount = typename Base::WarpCount; -+ -+ /// Shared memory allocation from epilogue base class -+ using BaseSharedStorage = typename Base::SharedStorage; -+ -+ static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? 
Base::kFragmentsPerIteration : kPartitionsK; -+ static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles; -+ -+ /// Used for the broadcast -+ struct BroadcastDetail { -+ -+ /// Number of threads per warp -+ static int const kWarpSize = 32; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ /// Number of distinct scalar column indices handled by each thread -+ static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; -+ -+ /// Number of distinct scalar row indices handled by each thread -+ static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; -+ -+ /// Number of threads per threadblock -+ static int const kThreadCount = kWarpSize * WarpCount::kCount; -+ -+ /// Number of distinct threads per row of output tile -+ static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread); -+ -+ /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock. -+ static int const kThreadRows = kThreadCount / kThreadsPerRow; -+ -+ /// I'm not sure what I meant here. -+ static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); -+ -+ /// Shape of the shared memory allocation for the epilogue -+ using StorageShape = MatrixShape< -+ kThreadRows, -+ Shape::kN -+ >; -+ -+ /// Debug printing -+ CUTLASS_DEVICE -+ static void print() { -+#if 0 -+ printf("BroadcastDetail {\n"); -+ printf( -+ " kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n" -+ "kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n", -+ kColumnsPerThread, -+ kRowsPerThread, -+ kThreadCount, -+ kThreadsPerRow, -+ kThreadRows, -+ kThreadAccessesPerRow, -+ StorageShape::kRow, -+ StorageShape::kColumn, -+ StorageShape::kCount -+ ); -+ printf("};\n"); -+#endif -+ } -+ }; -+ -+ /// Shared storage structure (shadows base) with additional SMEM buffer for reduction -+ struct SharedStorage { -+ union { -+ BaseSharedStorage base; -+ }; -+ -+ CUTLASS_HOST_DEVICE -+ SharedStorage() { } -+ }; -+ -+public: -+ -+ -+ static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, -+ "Mismatch between shared load iterator and output tile iterator."); -+ -+ static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); -+ -+ static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), -+ "Divisibility"); -+ -+private: -+ -+ /// Loads fragment from shared memory aligned with output tensor -+ SharedLoadIterator shared_load_iterator_; -+ -+ /// Thread index within the threadblock -+ int thread_idx_; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueWithBroadcast( -+ SharedStorage &shared_storage, ///< Shared storage object -+ int thread_idx, ///< ID of a thread within the threadblock -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx ///< Id of thread within warp -+ ): -+ Base(shared_storage.base, thread_idx, warp_idx, lane_idx), -+ shared_load_iterator_(shared_storage.base.reference(), thread_idx), -+ thread_idx_(thread_idx) -+ { -+ -+ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ ElementVector const * broadcast_ptr, ///< Broadcast vector -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ 
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator1, ///< Tile iterator for first source accumulator matrix -+ OutputTileIterator source_iterator2, ///< Tile iterator for second source accumulator matrix -+ TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand -+ MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses -+ MatrixCoord(Shape::kM, Shape::kN), -+ MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space -+ MatrixCoord()) { -+ -+ BroadcastFragment broadcast_fragment; -+ -+ load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset); -+ -+ if (!output_op.is_source_needed()) { -+ compute_source_not_needed_( -+ output_op, -+ broadcast_fragment, -+ destination_iterator, -+ accumulators, -+ tensor_iterator); -+ } -+ else { -+ compute_source_needed_( -+ output_op, -+ broadcast_fragment, -+ destination_iterator, -+ accumulators, -+ source_iterator1, -+ source_iterator2, -+ tensor_iterator); -+ } -+ } -+ -+private: -+ -+ CUTLASS_DEVICE -+ void load_broadcast_fragment_( -+ BroadcastFragment & broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns -+ ElementVector const * broadcast_ptr, ///< Broadcast vector -+ MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses -+ MatrixCoord const &threadblock_offset ///< Threadblock's initial offset within the problem size space -+ ) { -+ -+ broadcast_fragment.clear(); -+ -+ // If no pointer is supplied, set with all zeros and avoid memory accesses -+ if (!broadcast_ptr) { -+ return; -+ } -+ -+ int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column(); -+ -+ int thread_column_idx = threadblock_offset.column() + thread_initial_column; -+ broadcast_ptr += thread_initial_column; -+ -+ NumericArrayConverter converter; -+ using AccessType = AlignedArray; -+ using ComputeFragmentType = Array; -+ -+ ComputeFragmentType *frag_ptr = reinterpret_cast(&broadcast_fragment); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) { -+ -+ AccessType loaded; -+ -+ loaded.clear(); -+ -+ if (thread_column_idx < problem_size.column()) { -+ loaded = *reinterpret_cast(broadcast_ptr); -+ } -+ -+ ComputeFragmentType cvt = converter(loaded); -+ frag_ptr[j] = cvt; -+ -+ thread_column_idx += ThreadMap::Delta::kColumn; -+ broadcast_ptr += ThreadMap::Delta::kColumn; -+ } -+ } -+ -+ template -+ struct acc2smem_source_not_needed; -+ -+ template -+ struct acc2smem_source_not_needed> { -+ template -+ CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, -+ WarpTileIterator &warp_tile_iterator) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Advance; i++) { -+ ++accum_fragment_iterator; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ -+ accum_fragment_iterator.load(accum_fragment); -+ ++accum_fragment_iterator; -+ -+ warp_tile_iterator.store(accum_fragment); -+ if (p < Base::kFragmentsPerIteration - 1) { -+ warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); -+ } -+ } -+ -+ if (Base::kFragmentsPerIteration > 1) { -+ warp_tile_iterator.add_pointer_offset(kSmemPointerOffset * -+ (1 - Base::kFragmentsPerIteration)); -+ } -+ } -+ -+ CUTLASS_DEVICE -+ static void push(size_t 
pos, -+ AccumulatorFragmentIterator const &iterator_begin, -+ WarpTileIterator &warp_tile_iterator) { -+ int dummy[] = { -+ (pos == (Seq * Base::kFragmentsPerIteration)) && -+ (helper(iterator_begin, warp_tile_iterator), 0)...}; -+ -+ CUTLASS_UNUSED(dummy[0]); -+ } -+ }; -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_not_needed_( -+ OutputOp const &output_op, ///< Output operator -+ BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand -+ ) { -+ -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+ // CUTLASS_PRAGMA_UNROLL -+ #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1) -+ for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) { -+ -+ // -+ // Convert and store fragment -+ // -+ -+ -+ __syncthreads(); -+ -+ acc2smem_source_not_needed< -+ cutlass::make_index_sequence>::push(iter, -+ accum_fragment_iterator, -+ this->warp_tile_iterator_); -+ -+ __syncthreads(); -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { -+ -+ -+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; -+ -+ shared_load_iterator_.load(aligned_accum_fragment[0]); -+ -+ if (p < Base::kFragmentsPerIteration - 1) { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); -+ } -+ else if (kPartitionsK > 1) { -+ -+ plus add_fragments; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( int i = 1; i < kPartitionsK; ++i) { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); -+ shared_load_iterator_.load(aligned_accum_fragment[i]); -+ aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); -+ } -+ -+ shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); -+ } -+ -+ // -+ // Apply output operation -+ // -+ -+ typename OutputTileIterator::Fragment frag_Z; -+ typename TensorTileIterator::Fragment frag_T; -+ -+ apply_output_operator_source_not_needed_( -+ frag_Z, -+ frag_T, -+ output_op, -+ aligned_accum_fragment[0], -+ broadcast_fragment); -+ -+ // -+ // Conditionally store fragments -+ // -+ -+ if (OutputOp::kStoreZ) { -+ destination_iterator.store(frag_Z); -+ ++destination_iterator; -+ } -+ -+ if (OutputOp::kStoreT) { -+ tensor_iterator.store(frag_T); -+ ++tensor_iterator; -+ } -+ } -+ -+ if (Base::kFragmentsPerIteration > 1) { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); -+ } -+ } -+ } -+ -+ -+ template -+ struct acc2smem_source_needed; -+ -+ template -+ struct acc2smem_source_needed> { -+ template -+ CUTLASS_DEVICE -+ static void helper(AccumulatorFragmentIterator accum_fragment_iterator, -+ WarpTileIterator &warp_tile_iterator) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Advance; i++) { -+ ++accum_fragment_iterator; -+ } -+ -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ accum_fragment_iterator.load(accum_fragment); -+ warp_tile_iterator.store(accum_fragment); -+ } -+ 
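The push() members of the acc2smem_* helpers in this file all rely on the same pack-expansion trick: the initializer of dummy expands once per compile-time index Seq, and the short-circuit && ensures that only the helper instantiation whose Seq equals the runtime position actually executes. A minimal standalone illustration of that dispatch pattern, with hypothetical names and no CUTLASS types, is:

#include <cstdio>
#include <utility>

template <typename Seq>
struct Dispatcher;

template <std::size_t... Seq>
struct Dispatcher<std::index_sequence<Seq...>> {

  template <std::size_t Advance>
  static void helper() {
    std::printf("advanced the iterator by %zu steps\n", Advance);
  }

  // Runtime 'pos' selects the single helper<Seq>() whose compile-time Seq matches it;
  // every other element of the expansion short-circuits before calling helper().
  static void push(std::size_t pos) {
    int dummy[] = {(pos == Seq) && (helper<Seq>(), 0)...};
    (void)dummy;
  }
};

int main() {
  Dispatcher<std::make_index_sequence<4>>::push(2);  // prints: advanced the iterator by 2 steps
  return 0;
}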
-+ CUTLASS_DEVICE -+ static void push(size_t pos, -+ AccumulatorFragmentIterator const &iterator_begin, -+ WarpTileIterator &warp_tile_iterator) { -+ int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; -+ } -+ }; -+ -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_needed_( -+ OutputOp const &output_op, ///< Output operator -+ BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator1, ///< Tile iterator for first source accumulator matrix -+ OutputTileIterator source_iterator2, ///< Tile iterator for second source accumulator matrix -+ TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand -+ ) { -+ -+ typename OutputTileIterator::Fragment source_fragment1; -+ source_fragment1.clear(); -+ typename OutputTileIterator::Fragment source_fragment2; -+ source_fragment2.clear(); -+ -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+ #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) -+ for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { -+ -+ // -+ // Load the source -+ // -+ -+ source_iterator1.load(source_fragment1); -+ ++source_iterator1; -+ -+ source_iterator2.load(source_fragment2); -+ ++source_iterator2; -+ -+ // -+ // Convert and store fragment -+ // -+ -+ __syncthreads(); -+ -+ acc2smem_source_needed>::push( -+ iter, accum_fragment_iterator, this->warp_tile_iterator_); -+ -+ __syncthreads(); -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; -+ -+ shared_load_iterator_.load(aligned_accum_fragment[0]); -+ -+ // If the number of k-slices is > 1 - perform a reduction amongst the k-slices -+ if (kPartitionsK > 1) -+ { -+ plus add_fragments; -+ const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( int i = 1; i < kPartitionsK; ++i) { -+ shared_load_iterator_.add_tile_offset({tile_row_offset , 0}); -+ shared_load_iterator_.load(aligned_accum_fragment[i]); -+ aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); -+ } -+ -+ shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0}); -+ } -+ -+ // -+ // Apply output operation -+ // -+ -+ typename OutputTileIterator::Fragment frag_Z; -+ typename TensorTileIterator::Fragment frag_T; -+ -+ apply_output_operator_( -+ frag_Z, -+ frag_T, -+ output_op, -+ aligned_accum_fragment[0], -+ source_fragment1, -+ source_fragment2, -+ broadcast_fragment); -+ -+ // -+ // Conditionally store fragments -+ // -+ -+ if (OutputOp::kStoreZ) { -+ destination_iterator.store(frag_Z); -+ ++destination_iterator; -+ } -+ -+ if (OutputOp::kStoreT) { -+ tensor_iterator.store(frag_T); -+ ++tensor_iterator; -+ } -+ } -+ } -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_( -+ typename OutputTileIterator::Fragment &frag_Z, -+ typename TensorTileIterator::Fragment &frag_T, -+ OutputOp const &output_op, -+ typename SharedLoadIterator::Fragment const &frag_AB, 
-+ typename OutputTileIterator::Fragment const &frag_C1, -+ typename OutputTileIterator::Fragment const &frag_C2, -+ BroadcastFragment const &frag_Broadcast) { -+ -+ using AccessTypeZ = Array; -+ using AccessTypeT = Array; -+ using AccessTypeBroadcast = Array; -+ -+ AccessTypeZ *frag_Z_ptr = reinterpret_cast(&frag_Z); -+ AccessTypeT *frag_T_ptr = reinterpret_cast(&frag_T); -+ -+ AccumulatorAccessType const *frag_AB_ptr = -+ reinterpret_cast(&frag_AB); -+ -+ OutputAccessType const *frag_C1_ptr = -+ reinterpret_cast(&frag_C1); -+ -+ OutputAccessType const *frag_C2_ptr = -+ reinterpret_cast(&frag_C2); -+ -+ AccessTypeBroadcast const *frag_Broadcast_ptr = -+ reinterpret_cast(&frag_Broadcast); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ output_op( -+ frag_Z_ptr[i], -+ frag_T_ptr[i], -+ frag_AB_ptr[i], -+ frag_C1_ptr[i], -+ frag_C2_ptr[i], -+ frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); -+ } -+ } -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_source_not_needed_( -+ typename OutputTileIterator::Fragment &frag_Z, -+ typename TensorTileIterator::Fragment &frag_T, -+ OutputOp const &output_op, -+ typename SharedLoadIterator::Fragment const &frag_AB, -+ BroadcastFragment const &frag_Broadcast) { -+ -+ using AccessTypeZ = Array; -+ using AccessTypeT = Array; -+ using AccessTypeBroadcast = Array; -+ -+ AccessTypeZ *frag_Z_ptr = reinterpret_cast(&frag_Z); -+ AccessTypeT *frag_T_ptr = reinterpret_cast(&frag_T); -+ -+ AccumulatorAccessType const *frag_AB_ptr = -+ reinterpret_cast(&frag_AB); -+ -+ AccessTypeBroadcast const *frag_Broadcast_ptr = -+ reinterpret_cast(&frag_Broadcast); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ -+ output_op( -+ frag_Z_ptr[i], -+ frag_T_ptr[i], -+ frag_AB_ptr[i], -+ frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); -+ } -+ } -+}; -+ -+ -+template < -+ typename Shape_, -+ typename WarpMmaOperator_, -+ int PartitionsK, -+ typename OutputTileIterator_, -+ typename TensorTileIterator_, -+ typename ElementVector_, -+ typename AccumulatorFragmentIterator_, -+ typename WarpTileIterator_, -+ typename SharedLoadIterator_, -+ typename OutputOp_, -+ typename Padding_, -+ int FragmentsPerPartition, -+ int IterationsUnroll -+> -+class EpilogueWithBroadcast< -+ Shape_, -+ WarpMmaOperator_, -+ PartitionsK, -+ OutputTileIterator_, -+ TensorTileIterator_, -+ ElementVector_, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ SharedLoadIterator_, -+ OutputOp_, -+ Padding_, -+ FragmentsPerPartition, -+ IterationsUnroll, -+ true -+> : -+ public EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_, -+ FragmentsPerPartition> { -+ -+public: -+ -+ using Base = EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_, -+ FragmentsPerPartition>; -+ -+ static bool const kIsSingleSource = true; -+ using Shape = Shape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputTileIterator = OutputTileIterator_; -+ using TensorTileIterator = TensorTileIterator_; -+ 
using ElementVector = ElementVector_; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using WarpTileIterator = WarpTileIterator_; -+ using SharedLoadIterator = SharedLoadIterator_; -+ using OutputOp = OutputOp_; -+ using Padding = Padding_; -+ -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename Base::AccumulatorTile; -+ -+ /// Accumulator element -+ using ElementAccumulator = typename WarpTileIterator::Element; -+ -+ /// Compute data type produced by the output op -+ using ElementCompute = typename OutputOp::ElementCompute; -+ -+ /// Compute fragment -+ using FragmentCompute = Array; -+ -+ /// Thread map used by output tile iterators -+ using ThreadMap = typename OutputTileIterator::ThreadMap; -+ -+ /// Fragment object used to store the broadcast values -+ using BroadcastFragment = Array< -+ ElementCompute, -+ ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>; -+ -+ /// Output element -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ /// Data type of additional tensor -+ using ElementTensor = typename TensorTileIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ /// Tensor reference to destination tensor -+ using TensorRef = typename OutputTileIterator::TensorRef; -+ -+ /// Tensor reference to sync tensor -+ using SyncTensorRef = typename cutlass::TensorRef; -+ -+ /// Const tensor reference to source tensor -+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; -+ -+ /// Array type used to output -+ using OutputAccessType = Array< -+ typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = Array; -+ -+ /// Array type used by output functor -+ using ComputeAccessType = Array; -+ -+ /// Tensor access type -+ using TensorAccessType = Array; -+ -+ /// Number of warps -+ using WarpCount = typename Base::WarpCount; -+ -+ /// Shared memory allocation from epilogue base class -+ using BaseSharedStorage = typename Base::SharedStorage; -+ -+ static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; -+ static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles; -+ -+ /// Used for the broadcast -+ struct BroadcastDetail { -+ -+ /// Number of threads per warp -+ static int const kWarpSize = 32; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ /// Number of distinct scalar column indices handled by each thread -+ static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; -+ -+ /// Number of distinct scalar row indices handled by each thread -+ static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; -+ -+ /// Number of threads per threadblock -+ static int const kThreadCount = kWarpSize * WarpCount::kCount; -+ -+ /// Number of distinct threads per row of output tile -+ static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread); -+ -+ /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock. -+ static int const kThreadRows = kThreadCount / kThreadsPerRow; -+ -+ /// I'm not sure what I meant here. 
-+ static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); -+ -+ /// Shape of the shared memory allocation for the epilogue -+ using StorageShape = MatrixShape< -+ kThreadRows, -+ Shape::kN -+ >; -+ -+ /// Debug printing -+ CUTLASS_DEVICE -+ static void print() { -+#if 0 -+ printf("BroadcastDetail {\n"); -+ printf( -+ " kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n" -+ "kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n", -+ kColumnsPerThread, -+ kRowsPerThread, -+ kThreadCount, -+ kThreadsPerRow, -+ kThreadRows, -+ kThreadAccessesPerRow, -+ StorageShape::kRow, -+ StorageShape::kColumn, -+ StorageShape::kCount -+ ); -+ printf("};\n"); -+#endif -+ } -+ }; -+ -+ /// Shared storage structure (shadows base) with additional SMEM buffer for reduction -+ struct SharedStorage { -+ union { -+ BaseSharedStorage base; -+ }; -+ -+ CUTLASS_HOST_DEVICE -+ SharedStorage() { } -+ }; -+ -+public: -+ -+ -+ static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, -+ "Mismatch between shared load iterator and output tile iterator."); -+ -+ static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); -+ -+ static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), -+ "Divisibility"); -+ -+private: -+ -+ /// Loads fragment from shared memory aligned with output tensor -+ SharedLoadIterator shared_load_iterator_; -+ -+ /// Thread index within the threadblock -+ int thread_idx_; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueWithBroadcast( -+ SharedStorage &shared_storage, ///< Shared storage object -+ int thread_idx, ///< ID of a thread within the threadblock -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx ///< Id of thread within warp -+ ): -+ Base(shared_storage.base, thread_idx, warp_idx, lane_idx), -+ shared_load_iterator_(shared_storage.base.reference(), thread_idx), -+ thread_idx_(thread_idx) -+ { -+ -+ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ ElementVector const * broadcast_ptr, ///< Broadcast vector -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix -+ TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand -+ MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses -+ MatrixCoord(Shape::kM, Shape::kN), -+ MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space -+ MatrixCoord()) { -+ -+ BroadcastFragment broadcast_fragment; -+ -+ load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset); -+ -+ if (!output_op.is_source_needed()) { -+ compute_source_not_needed_( -+ output_op, -+ broadcast_fragment, -+ destination_iterator, -+ accumulators, -+ tensor_iterator); -+ } -+ else { -+ compute_source_needed_( -+ output_op, -+ broadcast_fragment, -+ destination_iterator, -+ accumulators, -+ source_iterator, -+ tensor_iterator); -+ } -+ } -+ -+private: -+ -+ CUTLASS_DEVICE -+ void load_broadcast_fragment_( -+ BroadcastFragment & broadcast_fragment, ///< Fragment containing the 
accumulated partial reduction over columns -+ ElementVector const * broadcast_ptr, ///< Broadcast vector -+ MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses -+ MatrixCoord const &threadblock_offset ///< Threadblock's initial offset within the problem size space -+ ) { -+ -+ broadcast_fragment.clear(); -+ -+ // If no pointer is supplied, set with all zeros and avoid memory accesses -+ if (!broadcast_ptr) { -+ return; -+ } -+ -+ int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column(); -+ -+ int thread_column_idx = threadblock_offset.column() + thread_initial_column; -+ broadcast_ptr += thread_initial_column; -+ -+ NumericArrayConverter converter; -+ using AccessType = AlignedArray; -+ using ComputeFragmentType = Array; -+ -+ ComputeFragmentType *frag_ptr = reinterpret_cast(&broadcast_fragment); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) { -+ -+ AccessType loaded; -+ -+ loaded.clear(); -+ -+ if (thread_column_idx < problem_size.column()) { -+ loaded = *reinterpret_cast(broadcast_ptr); -+ } -+ -+ ComputeFragmentType cvt = converter(loaded); -+ frag_ptr[j] = cvt; -+ -+ thread_column_idx += ThreadMap::Delta::kColumn; -+ broadcast_ptr += ThreadMap::Delta::kColumn; -+ } -+ } -+ -+ template -+ struct acc2smem_source_not_needed; -+ -+ template -+ struct acc2smem_source_not_needed> { -+ template -+ CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, -+ WarpTileIterator &warp_tile_iterator) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Advance; i++) { -+ ++accum_fragment_iterator; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ -+ accum_fragment_iterator.load(accum_fragment); -+ ++accum_fragment_iterator; -+ -+ warp_tile_iterator.store(accum_fragment); -+ if (p < Base::kFragmentsPerIteration - 1) { -+ warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); -+ } -+ } -+ -+ if (Base::kFragmentsPerIteration > 1) { -+ warp_tile_iterator.add_pointer_offset(kSmemPointerOffset * -+ (1 - Base::kFragmentsPerIteration)); -+ } -+ } -+ -+ CUTLASS_DEVICE -+ static void push(size_t pos, -+ AccumulatorFragmentIterator const &iterator_begin, -+ WarpTileIterator &warp_tile_iterator) { -+ int dummy[] = { -+ (pos == (Seq * Base::kFragmentsPerIteration)) && -+ (helper(iterator_begin, warp_tile_iterator), 0)...}; -+ -+ CUTLASS_UNUSED(dummy[0]); -+ } -+ }; -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_not_needed_( -+ OutputOp const &output_op, ///< Output operator -+ BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand -+ ) { -+ -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+ // CUTLASS_PRAGMA_UNROLL -+ #pragma unroll(IterationsUnroll ? 
OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1) -+ for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) { -+ -+ // -+ // Convert and store fragment -+ // -+ -+ -+ __syncthreads(); -+ -+ acc2smem_source_not_needed< -+ cutlass::make_index_sequence>::push(iter, -+ accum_fragment_iterator, -+ this->warp_tile_iterator_); -+ -+ __syncthreads(); -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { -+ -+ -+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; -+ -+ shared_load_iterator_.load(aligned_accum_fragment[0]); -+ -+ if (p < Base::kFragmentsPerIteration - 1) { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); -+ } -+ else if (kPartitionsK > 1) { -+ -+ plus add_fragments; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( int i = 1; i < kPartitionsK; ++i) { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); -+ shared_load_iterator_.load(aligned_accum_fragment[i]); -+ aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); -+ } -+ -+ shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); -+ } -+ -+ // -+ // Apply output operation -+ // -+ -+ typename OutputTileIterator::Fragment frag_Z; -+ typename TensorTileIterator::Fragment frag_T; -+ -+ apply_output_operator_source_not_needed_( -+ frag_Z, -+ frag_T, -+ output_op, -+ aligned_accum_fragment[0], -+ broadcast_fragment); -+ -+ // -+ // Conditionally store fragments -+ // -+ -+ if (OutputOp::kStoreZ) { -+ destination_iterator.store(frag_Z); -+ ++destination_iterator; -+ } -+ -+ if (OutputOp::kStoreT) { -+ tensor_iterator.store(frag_T); -+ ++tensor_iterator; -+ } -+ } -+ -+ if (Base::kFragmentsPerIteration > 1) { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); -+ } -+ } -+ } -+ -+ -+ template -+ struct acc2smem_source_needed; -+ -+ template -+ struct acc2smem_source_needed> { -+ template -+ CUTLASS_DEVICE -+ static void helper(AccumulatorFragmentIterator accum_fragment_iterator, -+ WarpTileIterator &warp_tile_iterator) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Advance; i++) { -+ ++accum_fragment_iterator; -+ } -+ -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ accum_fragment_iterator.load(accum_fragment); -+ warp_tile_iterator.store(accum_fragment); -+ } -+ -+ CUTLASS_DEVICE -+ static void push(size_t pos, -+ AccumulatorFragmentIterator const &iterator_begin, -+ WarpTileIterator &warp_tile_iterator) { -+ int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; -+ } -+ }; -+ -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_needed_( -+ OutputOp const &output_op, ///< Output operator -+ BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix -+ TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand -+ ) { -+ -+ typename OutputTileIterator::Fragment source_fragment; -+ source_fragment.clear(); -+ -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ AccumulatorFragmentIterator 
accum_fragment_iterator(accumulators); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+ #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) -+ for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { -+ -+ // -+ // Load the source -+ // -+ -+ source_iterator.load(source_fragment); -+ ++source_iterator; -+ -+ // -+ // Convert and store fragment -+ // -+ -+ __syncthreads(); -+ -+ acc2smem_source_needed>::push( -+ iter, accum_fragment_iterator, this->warp_tile_iterator_); -+ -+ __syncthreads(); -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; -+ -+ shared_load_iterator_.load(aligned_accum_fragment[0]); -+ -+ // If the number of k-slices is > 1 - perform a reduction amongst the k-slices -+ if (kPartitionsK > 1) -+ { -+ plus add_fragments; -+ const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( int i = 1; i < kPartitionsK; ++i) { -+ shared_load_iterator_.add_tile_offset({tile_row_offset , 0}); -+ shared_load_iterator_.load(aligned_accum_fragment[i]); -+ aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); -+ } -+ -+ shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0}); -+ } -+ -+ // -+ // Apply output operation -+ // -+ -+ typename OutputTileIterator::Fragment frag_Z; -+ typename TensorTileIterator::Fragment frag_T; -+ -+ apply_output_operator_( -+ frag_Z, -+ frag_T, -+ output_op, -+ aligned_accum_fragment[0], -+ source_fragment, -+ broadcast_fragment); -+ -+ // -+ // Conditionally store fragments -+ // -+ -+ if (OutputOp::kStoreZ) { -+ destination_iterator.store(frag_Z); -+ ++destination_iterator; -+ } -+ -+ if (OutputOp::kStoreT) { -+ tensor_iterator.store(frag_T); -+ ++tensor_iterator; -+ } -+ } -+ } -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_( -+ typename OutputTileIterator::Fragment &frag_Z, -+ typename TensorTileIterator::Fragment &frag_T, -+ OutputOp const &output_op, -+ typename SharedLoadIterator::Fragment const &frag_AB, -+ typename OutputTileIterator::Fragment const &frag_C, -+ BroadcastFragment const &frag_Broadcast) { -+ -+ using AccessTypeZ = Array; -+ using AccessTypeT = Array; -+ using AccessTypeBroadcast = Array; -+ -+ AccessTypeZ *frag_Z_ptr = reinterpret_cast(&frag_Z); -+ AccessTypeT *frag_T_ptr = reinterpret_cast(&frag_T); -+ -+ AccumulatorAccessType const *frag_AB_ptr = -+ reinterpret_cast(&frag_AB); -+ -+ OutputAccessType const *frag_C_ptr = -+ reinterpret_cast(&frag_C); -+ -+ AccessTypeBroadcast const *frag_Broadcast_ptr = -+ reinterpret_cast(&frag_Broadcast); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ output_op( -+ frag_Z_ptr[i], -+ frag_T_ptr[i], -+ frag_AB_ptr[i], -+ frag_C_ptr[i], -+ frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); -+ } -+ } -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_source_not_needed_( -+ typename OutputTileIterator::Fragment &frag_Z, -+ typename TensorTileIterator::Fragment &frag_T, -+ OutputOp const &output_op, -+ typename SharedLoadIterator::Fragment const &frag_AB, -+ BroadcastFragment const &frag_Broadcast) { -+ -+ using AccessTypeZ = Array; -+ using AccessTypeT 
= Array; -+ using AccessTypeBroadcast = Array; -+ -+ AccessTypeZ *frag_Z_ptr = reinterpret_cast(&frag_Z); -+ AccessTypeT *frag_T_ptr = reinterpret_cast(&frag_T); -+ -+ AccumulatorAccessType const *frag_AB_ptr = -+ reinterpret_cast(&frag_AB); -+ -+ AccessTypeBroadcast const *frag_Broadcast_ptr = -+ reinterpret_cast(&frag_Broadcast); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ -+ output_op( -+ frag_Z_ptr[i], -+ frag_T_ptr[i], -+ frag_AB_ptr[i], -+ frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); -+ } -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h -new file mode 100644 -index 0000000..6e76f7e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h -@@ -0,0 +1,823 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. 
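The data flow described in this brief (accumulators staged to shared memory, a barrier, a reload in an output-friendly arrangement, the output operator, then a guarded store to global memory) can be summarized by a toy kernel. The sketch below is a greatly simplified, hypothetical analogue rather than the CUTLASS implementation: one value per thread, no tile iterators, no k-partition reduction, and the rearrangement is shown as a plain transpose purely for illustration.

#include <cuda_runtime.h>

// Toy analogue of a threadblock epilogue: stage results in SMEM, rearrange, apply an
// output op (scaling by alpha), and store with a bounds guard. Launch with a TILE x TILE
// block, e.g. toy_epilogue<32><<<grid, dim3(32, 32)>>>(acc, out, alpha, n).
template <int TILE>
__global__ void toy_epilogue(float const *accum, float *output, float alpha, int n) {
  __shared__ float smem[TILE][TILE];

  int row = blockIdx.y * TILE + threadIdx.y;
  int col = blockIdx.x * TILE + threadIdx.x;

  // 1) Stage the "accumulator" into shared memory in the layout it was produced in.
  float value = (row < n && col < n) ? accum[row * n + col] : 0.0f;
  smem[threadIdx.y][threadIdx.x] = value;

  __syncthreads();

  // 2) Re-read in a different arrangement (here: transposed), apply the output operator,
  //    and store to global memory guarded against the problem boundary.
  int out_row = blockIdx.x * TILE + threadIdx.y;
  int out_col = blockIdx.y * TILE + threadIdx.x;
  if (out_row < n && out_col < n) {
    output[out_row * n + out_col] = alpha * smem[threadIdx.x][threadIdx.y];
  }
}

The real epilogues replace the raw indexing above with the OutputTileIterator, WarpTileIterator, and SharedLoadIterator machinery and add the k-partition reduction and output-operator fragments defined in these files.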
-+ -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/functional.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/layout/tensor.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_base.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator with reduction over each column -+template < -+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) -+ typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ int PartitionsK, ///< Number of partitions of the K dimension -+ typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors -+ typename TensorTileIterator_, ///< Additional tile iterator for tensor-valued operands -+ typename ElementVector_, ///< Pointer to reduction vector -+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators -+ typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM -+ typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM -+ typename OutputOp_, ///< Output operator -+ typename ReductionOp_, ///< Reduction operator -+ typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) -+ int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large -+ (!IsEpilogueFunctorHeavy::value) -+> -+class EpilogueWithReduction : -+ public EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_> { -+ -+public: -+ -+ using Base = EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_>; -+ -+ using Shape = Shape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputTileIterator = OutputTileIterator_; -+ using TensorTileIterator = TensorTileIterator_; -+ using ElementVector = ElementVector_; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using WarpTileIterator = WarpTileIterator_; -+ using SharedLoadIterator = SharedLoadIterator_; -+ using OutputOp = OutputOp_; -+ using ReductionOp = ReductionOp_; -+ using Padding = Padding_; -+ -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ static bool const kIsSingleSource = true; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename Base::AccumulatorTile; -+ -+ /// Accumulator element -+ using ElementAccumulator = typename WarpTileIterator::Element; -+ -+ /// Compute data type produced by the output op -+ using ElementCompute = typename OutputOp::ElementCompute; -+ -+ /// Compute fragment -+ using FragmentCompute = Array; -+ -+ /// 
Thread map used by output tile iterators -+ using ThreadMap = typename OutputTileIterator::ThreadMap; -+ -+ /// Fragment object used in reduction -+ using ReductionFragment = Array< -+ ElementAccumulator, -+ ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>; -+ -+ /// Output element -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ /// Data type of additional tensor -+ using ElementTensor = typename TensorTileIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ /// Tensor reference to destination tensor -+ using TensorRef = typename OutputTileIterator::TensorRef; -+ -+ /// Tensor reference to sync tensor -+ using SyncTensorRef = typename cutlass::TensorRef; -+ -+ /// Const tensor reference to source tensor -+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; -+ -+ /// Array type used to output -+ using OutputAccessType = Array< -+ typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = Array; -+ -+ /// Array type used by output functor -+ using ComputeAccessType = Array; -+ -+ /// Tensor access type -+ using TensorAccessType = Array; -+ -+ /// Number of warps -+ using WarpCount = typename Base::WarpCount; -+ -+ /// Shared memory allocation from epilogue base class -+ using BaseSharedStorage = typename Base::SharedStorage; -+ -+ /// Used for the reduction -+ struct ReductionDetail { -+ -+ /// If true, accumulator coordinates are computed and out-of-bounds checks are enabled when -+ /// performing the reduction. -+ static bool const kOobCheck = false; -+ -+ /// Number of threads per warp -+ static int const kWarpSize = 32; -+ -+ /// Number of distinct scalar column indices handled by each thread -+ static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; -+ -+ /// Number of distinct scalar row indices handled by each thread -+ static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; -+ -+ /// Number of threads per threadblock -+ static int const kThreadCount = kWarpSize * WarpCount::kCount; -+ -+ /// Number of distinct threads per row of output tile -+ static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread); -+ -+ /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock. -+ static int const kThreadRows = kThreadCount / kThreadsPerRow; -+ -+ /// I'm not sure what I meant here. 
-+ static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); -+ -+ /// Shape of the shared memory allocation for the epilogue -+ using StorageShape = MatrixShape< -+ kThreadRows, -+ Shape::kN -+ >; -+ -+ /// Debug printing -+ CUTLASS_DEVICE -+ static void print() { -+#if 0 -+ printf("ReductionDetail {\n"); -+ printf( -+ " kElementsPerAccess:%d\nkColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n" -+ "kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n", -+ kElementsPerAccess, -+ kColumnsPerThread, -+ kRowsPerThread, -+ kThreadCount, -+ kThreadsPerRow, -+ kThreadRows, -+ kThreadAccessesPerRow, -+ StorageShape::kRow, -+ StorageShape::kColumn, -+ StorageShape::kCount -+ ); -+ printf("};\n"); -+#endif -+ } -+ }; -+ -+ /// Shared storage structure (shadows base) with additional SMEM buffer for reduction -+ struct SharedStorage { -+ union { -+ BaseSharedStorage base; -+ AlignedArray reduction; ///< Shared storage for reduction -+ }; -+ -+ CUTLASS_HOST_DEVICE -+ SharedStorage() { } -+ }; -+ -+public: -+ -+ -+ static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, -+ "Mismatch between shared load iterator and output tile iterator."); -+ -+ static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); -+ -+ static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), -+ "Divisibility"); -+ -+private: -+ -+ /// Loads fragment from shared memory aligned with output tensor -+ SharedLoadIterator shared_load_iterator_; -+ -+ /// Shared memory pointer fo rreduction -+ ElementAccumulator *reduction_ptr_; -+ -+ /// Thread index within the threadblock -+ int thread_idx_; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueWithReduction( -+ SharedStorage &shared_storage, ///< Shared storage object -+ int thread_idx, ///< ID of a thread within the threadblock -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx ///< Id of thread within warp -+ ): -+ Base(shared_storage.base, thread_idx, warp_idx, lane_idx), -+ shared_load_iterator_(shared_storage.base.reference(), thread_idx), -+ reduction_ptr_(shared_storage.reduction.data()), -+ thread_idx_(thread_idx) -+ { -+ -+ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ ElementVector * reduction_output_ptr, ///< Reduction output vector -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix -+ TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand -+ MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses -+ MatrixCoord(Shape::kM, Shape::kN), -+ MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space -+ MatrixCoord()) { -+ -+ ReductionFragment reduction_fragment; -+ reduction_fragment.clear(); -+ -+ if (!output_op.is_source_needed()) { -+ compute_source_not_needed_( -+ output_op, -+ reduction_fragment, -+ destination_iterator, -+ accumulators, -+ tensor_iterator, -+ problem_size, -+ threadblock_offset); -+ } -+ else { -+ compute_source_needed_( -+ output_op, -+ reduction_fragment, -+ 
destination_iterator, -+ accumulators, -+ source_iterator, -+ tensor_iterator, -+ problem_size, -+ threadblock_offset); -+ } -+ -+ if (output_op.participates_in_reduction()) { -+ reduction_(problem_size, threadblock_offset, reduction_output_ptr, reduction_fragment); -+ } -+ } -+ -+private: -+ -+ /// Perform the reduction -+ CUTLASS_DEVICE -+ void reduction_( -+ MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses -+ MatrixCoord const &threadblock_offset, ///< Problem size needed to guard against out-of-bounds accesses -+ ElementVector * reduction_output_ptr, ///< Reduction output vector -+ ReductionFragment const & reduction_fragment) { -+ -+ // -+ // Store the partially reduced value to SMEM -+ // -+ -+ // Guard against uses of the existing SMEM tile -+ __syncthreads(); -+ -+ using AccessType = AlignedArray; -+ -+ // -+ // Determine a compacted thread arrangement to store to SMEM. -+ // -+ int const kThreadsPerRow = Shape::kN / (ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess); -+ -+ MatrixCoord thread_offset( -+ thread_idx_ / kThreadsPerRow, -+ (thread_idx_ % kThreadsPerRow) * ThreadMap::kElementsPerAccess); -+ -+ // -+ // Each thread store its fragment to a SMEM -+ // -+ -+ AccessType *aligned_reduction_ptr = reinterpret_cast( -+ &reduction_ptr_[thread_offset.row() * Shape::kN + thread_offset.column()]); -+ -+ AccessType const *frag_ptr = reinterpret_cast(&reduction_fragment); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ int col_idx = column * ThreadMap::Delta::kColumn / ThreadMap::kElementsPerAccess; -+ -+ aligned_reduction_ptr[col_idx] = frag_ptr[column]; -+ } -+ -+ __syncthreads(); -+ -+ // -+ // Now, threads are assigned several columns of the output. They fetch over all rows from -+ // the compacted SMEM tile and perform a reduction. 
-+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < ReductionDetail::kThreadAccessesPerRow; ++j) { -+ int column_idx = thread_idx_ + j * ReductionDetail::kThreadCount; -+ -+ ReductionOp reduction_op; -+ ElementAccumulator reduction_element = ElementAccumulator(); -+ -+ int output_column_idx = threadblock_offset.column() + column_idx; -+ -+ if (column_idx < Shape::kN && output_column_idx < problem_size.column()) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ReductionDetail::kThreadRows; ++row) { -+ if (row) { -+ auto frag = reduction_ptr_[row * Shape::kN + column_idx]; -+ -+ reduction_element = reduction_op(reduction_element, frag); -+ } -+ else { -+ -+ reduction_element = reduction_ptr_[column_idx]; -+ } -+ } -+ -+ // Store -+ reduction_output_ptr[column_idx] = ElementVector(reduction_element); -+ } -+ } -+ } -+ -+ template -+ struct acc2smem; -+ -+ template -+ struct acc2smem> { -+ template -+ CUTLASS_DEVICE -+ static void helper(AccumulatorFragmentIterator accum_fragment_iterator, -+ WarpTileIterator &warp_tile_iterator) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Advance; i++) { -+ ++accum_fragment_iterator; -+ } -+ -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ accum_fragment_iterator.load(accum_fragment); -+ warp_tile_iterator.store(accum_fragment); -+ } -+ -+ CUTLASS_DEVICE -+ static void push(size_t pos, -+ AccumulatorFragmentIterator const &iterator_begin, -+ WarpTileIterator &warp_tile_iterator) { -+ int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; -+ } -+ }; -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_not_needed_( -+ OutputOp const &output_op, ///< Output operator -+ ReductionFragment &reduction_fragment, ///< Fragment containing the accumulated partial reduction over columns -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additioanl tensor operand -+ MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses -+ MatrixCoord const &threadblock_offset ///< Threadblock's initial offset within the problem size space -+ ) { -+ -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ typename TensorTileIterator::Fragment tensor_fragment; -+ tensor_fragment.clear(); -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+ #pragma unroll(IterationsUnroll ? 
OutputTileIterator::kIterations : 1) -+ for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { -+ -+ // -+ // Convert and store fragment -+ // -+ -+ tensor_iterator.load(tensor_fragment); -+ ++tensor_iterator; -+ -+ __syncthreads(); -+ -+ acc2smem>::push( -+ iter, accum_fragment_iterator, this->warp_tile_iterator_); -+ -+ __syncthreads(); -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; -+ -+ shared_load_iterator_.load(aligned_accum_fragment[0]); -+ -+ // -+ // If the number of k-slices is > 1 - perform a reduction amongst the k-slices -+ // -+ if (kPartitionsK > 1) -+ { -+ plus add_fragments; -+ const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( int i = 1; i < kPartitionsK; ++i) { -+ shared_load_iterator_.add_tile_offset({tile_row_offset , 0}); -+ shared_load_iterator_.load(aligned_accum_fragment[i]); -+ aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); -+ } -+ -+ shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0}); -+ } -+ -+ // -+ // Compute the output result -+ // -+ -+ FragmentCompute compute_fragment; -+ -+ apply_output_operator_source_not_needed_( -+ reduction_fragment, -+ compute_fragment, -+ output_op, -+ aligned_accum_fragment[0], -+ tensor_fragment, -+ destination_iterator); -+ -+ // -+ // Store the final result -+ // -+ -+ NumericArrayConverter converter; -+ -+ typename OutputTileIterator::Fragment output_fragment = converter(compute_fragment); -+ -+ destination_iterator.store(output_fragment); -+ ++destination_iterator; -+ } -+ } -+ -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_needed_( -+ OutputOp const &output_op, ///< Output operator -+ ReductionFragment &reduction_fragment, ///< Fragment containing the accumulated partial reduction over columns -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additioanl tensor operand -+ MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses -+ MatrixCoord const &threadblock_offset ///< Threadblock's initial offset within the problem size space -+ ) { -+ -+ typename OutputTileIterator::Fragment source_fragment; -+ source_fragment.clear(); -+ -+ typename TensorTileIterator::Fragment tensor_fragment; -+ tensor_fragment.clear(); -+ -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+ #pragma unroll(IterationsUnroll ? 
OutputTileIterator::kIterations : 1) -+ for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { -+ -+ // -+ // Load the source -+ // -+ -+ source_fragment.clear(); -+ source_iterator.load(source_fragment); -+ ++source_iterator; -+ -+ tensor_iterator.load(tensor_fragment); -+ ++tensor_iterator; -+ -+ // -+ // Convert and store fragment -+ // -+ -+ __syncthreads(); -+ -+ acc2smem>::push( -+ iter, accum_fragment_iterator, this->warp_tile_iterator_); -+ -+ __syncthreads(); -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; -+ -+ shared_load_iterator_.load(aligned_accum_fragment[0]); -+ -+ // If the number of k-slices is > 1 - perform a reduction amongst the k-slices -+ if (kPartitionsK > 1) -+ { -+ plus add_fragments; -+ const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( int i = 1; i < kPartitionsK; ++i) { -+ shared_load_iterator_.add_tile_offset({tile_row_offset , 0}); -+ shared_load_iterator_.load(aligned_accum_fragment[i]); -+ aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); -+ } -+ -+ shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0}); -+ } -+ -+ // -+ // Compute the output result -+ // -+ -+ FragmentCompute compute_fragment; -+ -+ apply_output_operator_( -+ reduction_fragment, -+ compute_fragment, -+ output_op, -+ aligned_accum_fragment[0], -+ source_fragment, -+ tensor_fragment, -+ destination_iterator); -+ -+ // -+ // Convert and store the final result -+ // -+ -+ NumericArrayConverter converter; -+ -+ typename OutputTileIterator::Fragment output_fragment = converter(compute_fragment); -+ -+ destination_iterator.store(output_fragment); -+ ++destination_iterator; -+ } -+ } -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_( -+ ReductionFragment &reduction_fragment, -+ FragmentCompute &compute_fragment, -+ OutputOp const &output_op, ///< Output operator -+ typename SharedLoadIterator::Fragment const &aligned_accum_fragment, -+ typename OutputTileIterator::Fragment const &source_fragment, -+ typename TensorTileIterator::Fragment const &tensor_fragment, -+ OutputTileIterator const & destination_iterator) { -+ -+ ComputeAccessType *compute_frag_ptr = -+ reinterpret_cast(&compute_fragment); -+ -+ AccumulatorAccessType const *accum_frag_ptr = -+ reinterpret_cast(&aligned_accum_fragment); -+ -+ OutputAccessType const *source_frag_ptr = -+ reinterpret_cast(&source_fragment); -+ -+ TensorAccessType const *tensor_frag_ptr = -+ reinterpret_cast(&tensor_fragment); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ -+ // Call the output operator -+ compute_frag_ptr[i] = output_op(accum_frag_ptr[i], source_frag_ptr[i], tensor_frag_ptr[i]); -+ } -+ -+ // -+ // Partial reduction over each column -+ // -+ -+ ReductionOp reduction_op; -+ -+ typename OutputTileIterator::Mask mask; -+ destination_iterator.get_mask(mask); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ReductionDetail::kColumnsPerThread; ++column) { -+ -+ int column_vector_idx = column / ThreadMap::kElementsPerAccess; -+ bool column_guard = mask.predicates[column_vector_idx]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ReductionDetail::kRowsPerThread; 
++row) { -+ -+ bool fetch; -+ if (ReductionDetail::kOobCheck) { -+ int row_idx = (row % ThreadMap::Iterations::kRow); -+ int residual = (row / ThreadMap::Iterations::kRow); -+ -+ int group_idx = (residual % ThreadMap::Iterations::kGroup); -+ residual = (residual / ThreadMap::Iterations::kGroup); -+ -+ int cluster_idx = (residual % ThreadMap::Iterations::kCluster); -+ -+ int row_offset = row_idx * ThreadMap::Delta::kRow -+ + group_idx * ThreadMap::Delta::kGroup -+ + cluster_idx * ThreadMap::Delta::kCluster; -+ -+ int output_row = destination_iterator.thread_start_row() + row_offset; -+ -+ fetch = (output_row < destination_iterator.extent_row() && column_guard); -+ } -+ else { -+ fetch = true; -+ } -+ -+ ElementCompute value = ElementCompute(); -+ if (fetch) { -+ value = compute_fragment[row * ReductionDetail::kColumnsPerThread + column]; -+ } -+ -+ reduction_fragment[column] = reduction_op( -+ reduction_fragment[column], -+ value); -+ } -+ } -+ } -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_source_not_needed_( -+ ReductionFragment &reduction_fragment, -+ FragmentCompute &compute_fragment, -+ OutputOp const &output_op, ///< Output operator -+ typename SharedLoadIterator::Fragment const &aligned_accum_fragment, -+ typename TensorTileIterator::Fragment const &tensor_fragment, -+ OutputTileIterator const & destination_iterator -+ ) { -+ -+ ComputeAccessType *compute_frag_ptr = -+ reinterpret_cast(&compute_fragment); -+ -+ AccumulatorAccessType const *accum_frag_ptr = -+ reinterpret_cast(&aligned_accum_fragment); -+ -+ TensorAccessType const *tensor_frag_ptr = -+ reinterpret_cast(&tensor_fragment); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ -+ // Call the output operator -+ compute_frag_ptr[i] = output_op(accum_frag_ptr[i], tensor_frag_ptr[i]); -+ } -+ -+ // -+ // Partial reduction over each column -+ // -+ -+ ReductionOp reduction_op; -+ -+ typename OutputTileIterator::Mask mask; -+ destination_iterator.get_mask(mask); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ReductionDetail::kColumnsPerThread; ++column) { -+ -+ int column_vector_idx = column / ThreadMap::kElementsPerAccess; -+ bool column_guard = mask.predicates[column_vector_idx]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ReductionDetail::kRowsPerThread; ++row) { -+ -+ bool fetch; -+ if (ReductionDetail::kOobCheck) { -+ int row_idx = (row % ThreadMap::Iterations::kRow); -+ int residual = (row / ThreadMap::Iterations::kRow); -+ -+ int group_idx = (residual % ThreadMap::Iterations::kGroup); -+ residual = (residual / ThreadMap::Iterations::kGroup); -+ -+ int cluster_idx = (residual % ThreadMap::Iterations::kCluster); -+ -+ int row_offset = row_idx * ThreadMap::Delta::kRow -+ + group_idx * ThreadMap::Delta::kGroup -+ + cluster_idx * ThreadMap::Delta::kCluster; -+ -+ int output_row = destination_iterator.thread_start_row() + row_offset; -+ -+ fetch = (output_row < destination_iterator.extent_row() && column_guard); -+ } -+ else { -+ fetch = true; -+ } -+ -+ ElementCompute value = ElementCompute(); -+ if (fetch) { -+ value = compute_fragment[row * ReductionDetail::kColumnsPerThread + column]; -+ } -+ -+ reduction_fragment[column] = reduction_op( -+ reduction_fragment[column], -+ value); -+ } -+ } -+ } -+}; -+ 
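
The per-column partial reduction in apply_output_operator_ and apply_output_operator_source_not_needed_ above follows one pattern: recover each fragment element's row in the output tile from the thread map (row, group, cluster), fetch the value only when that row is inside the tile extent and the column predicate is set, and fold it into reduction_fragment[column] with ReductionOp. A minimal stand-alone C++ sketch of that guard-then-reduce pattern, using hypothetical hard-coded tile parameters in place of the CUTLASS ThreadMap/ReductionDetail types, might look like this:

#include <array>
#include <cstdio>
#include <functional>

// Hypothetical stand-ins for the ThreadMap / ReductionDetail constants used above.
constexpr int kRowsPerThread    = 4;    // rows of the compute fragment each thread holds
constexpr int kColumnsPerThread = 2;    // columns of the compute fragment each thread holds
constexpr int kDeltaRow         = 8;    // spacing in the output tile between those rows
constexpr int kExtentRow        = 30;   // number of valid rows in the output tile

// Guard-then-reduce over one thread's compute fragment (stored row-major as
// row * kColumnsPerThread + column), mirroring the loop structure above.
std::array<float, kColumnsPerThread> reduce_columns(
    std::array<float, kRowsPerThread * kColumnsPerThread> const &compute_fragment,
    int thread_start_row,
    std::array<bool, kColumnsPerThread> const &column_guard) {
  std::array<float, kColumnsPerThread> reduction_fragment{};  // zero-initialized
  std::plus<float> reduction_op;                              // plays the role of ReductionOp

  for (int column = 0; column < kColumnsPerThread; ++column) {
    for (int row = 0; row < kRowsPerThread; ++row) {
      int output_row = thread_start_row + row * kDeltaRow;
      // Out-of-bounds guard; the real code may skip it when kOobCheck is false.
      bool fetch = (output_row < kExtentRow) && column_guard[column];
      float value = fetch ? compute_fragment[row * kColumnsPerThread + column] : 0.0f;
      reduction_fragment[column] = reduction_op(reduction_fragment[column], value);
    }
  }
  return reduction_fragment;
}

int main() {
  std::array<float, kRowsPerThread * kColumnsPerThread> frag;
  frag.fill(1.0f);
  // Rows land at 16, 24, 32, 40; only 16 and 24 are inside the 30-row extent.
  auto reduced = reduce_columns(frag, /*thread_start_row=*/16, {true, true});
  std::printf("%f %f\n", reduced[0], reduced[1]);  // prints 2.000000 2.000000
  return 0;
}
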
-+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_visitor.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_visitor.h -new file mode 100644 -index 0000000..6c54353 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_visitor.h -@@ -0,0 +1,409 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Generic epilogue for implementing certain kinds of fused epilogue behavior. 
-+*/ -+ -+#pragma once -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/epilogue/threadblock/epilogue_base.h" -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+class EpilogueFusedVisitorConcept { -+public: -+ -+ static int const kIterations = 1; -+ static int const kElementsPerAccess = 4; -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using AccumulatorFragment = Array; -+ -+ /// Arguments structure -+ struct Arguments { }; -+ -+ /// Params structure -+ struct Params { -+ -+ Params() { } -+ Params(Arguments const &args) { } -+ }; -+ -+ /// Shared storage -+ struct SharedStorage { }; -+ -+public: -+ -+ CUTLASS_DEVICE -+ EpilogueFusedVisitorConcept( -+ Params const ¶ms, ///< Parameters routed to the epilogue -+ SharedStorage &shared_storage, ///< Shared storage needed by the functors here -+ MatrixCoord const &problem_size, ///< Problem size of the output -+ int thread_idx, ///< Thread index within the threadblock -+ int warp_idx, ///< Warp index within the threadblock -+ int lane_idx, ///< Lane index within the warp -+ MatrixCoord const &threadblock_offset = MatrixCoord(0, 0)) { ///< Coordinate -+ -+ } -+ -+ /// Helper to indicate split-K behavior -+ CUTLASS_DEVICE -+ void set_k_partition( -+ int split_k_index, ///< Index of this threadblock within split-K partitioned scheme -+ int split_k_slices) { ///< Total number of split-K slices -+ -+ } -+ -+ /// Called to set the batch index -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ -+ } -+ -+ /// Called at the start of the epilogue just before iterating over accumulator slices -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ -+ } -+ -+ /// Called at the start of one step before starting accumulator exchange -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ -+ } -+ -+ /// Called at the start of a row -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { -+ -+ } -+ -+ /// Called after accumulators have been exchanged for each accumulator vector -+ CUTLASS_DEVICE -+ void visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorFragment const &accum) { -+ -+ } -+ -+ /// Called at the end of a row -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { -+ -+ } -+ -+ /// Called after all accumulator elements have been visited -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ -+ } -+ -+ /// Called after all steps have been completed -+ CUTLASS_DEVICE -+ void end_epilogue() { -+ -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator -+template < -+ typename Visitor_, ///< Functor containing fused operations (satisfies EpilogueFusedVisitorConcept) -+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) -+ typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ int PartitionsK, ///< Number of partitions of the K dimension -+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators -+ typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM -+ 
typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM -+ typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) -+ int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity -+ int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large -+ (true || !IsEpilogueFunctorHeavy::value) -+> -+class EpilogueWithVisitor : -+ public EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_, -+ FragmentsPerPartition> { -+ -+public: -+ -+ using Visitor = Visitor_; -+ -+ using Base = EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_, -+ FragmentsPerPartition>; -+ -+ using Shape = Shape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ static int const kPartitionsK = PartitionsK; -+ -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using WarpTileIterator = WarpTileIterator_; -+ using SharedLoadIterator = SharedLoadIterator_; -+ using Padding = Padding_; -+ -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename Base::AccumulatorTile; -+ -+ /// Accumulator element -+ using ElementAccumulator = typename WarpTileIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = Visitor::kElementsPerAccess; -+ -+ /// Tensor reference to sync tensor -+ using SyncTensorRef = typename cutlass::TensorRef; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = Array< -+ typename WarpTileIterator::Element, kElementsPerAccess>; -+ -+ /// Number of warps -+ using WarpCount = typename Base::WarpCount; -+ -+ static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; -+ static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles; -+ -+ using SharedStorage = typename Base::SharedStorage; -+ -+private: -+ -+ /// Loads fragment from shared memory aligned with output tensor -+ SharedLoadIterator shared_load_iterator_; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueWithVisitor( -+ SharedStorage &shared_storage, ///< Shared storage object -+ int thread_idx, ///< ID of a thread within the threadblock -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx ///< Id of thread within warp -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ shared_load_iterator_(shared_storage.reference(), thread_idx) -+ { -+ -+ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()( -+ Visitor & visitor, -+ AccumulatorTile const &accumulators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ -+ visitor.begin_epilogue(); -+ -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+ #pragma unroll(IterationsUnroll ? 
Visitor::kIterations : 1) -+ for (int iter_idx = 0; iter_idx < Visitor::kIterations; ++iter_idx) { -+ -+ // -+ // Load the source -+ // -+ -+ visitor.begin_step(iter_idx); -+ -+ // -+ // Convert and store fragment -+ // -+ -+ __syncthreads(); -+ -+ acc2smem_source_needed>::push( -+ iter_idx, accum_fragment_iterator, this->warp_tile_iterator_); -+ -+ __syncthreads(); -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; -+ -+ shared_load_iterator_.load(aligned_accum_fragment[0]); -+ -+ // If the number of k-slices is > 1 - perform a reduction amongst the k-slices -+ if (kPartitionsK > 1) { -+ -+ plus add_fragments; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( int i = 1; i < kPartitionsK; ++i) { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); -+ shared_load_iterator_.load(aligned_accum_fragment[i]); -+ aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); -+ } -+ -+ shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); -+ } -+ -+ // -+ // Iterate over output fragments -+ // -+ -+ AccumulatorAccessType const *accum_frag_ptr = -+ reinterpret_cast(&aligned_accum_fragment[0]); -+ -+ int const kAccumulatorFragmentCount = AccumulatorTile::kElements / (Visitor::kIterations * AccumulatorAccessType::kElements); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx = 0; idx < kAccumulatorFragmentCount; ++idx) { -+ -+ int row_idx = idx / SharedLoadIterator::ThreadMap::Iterations::kColumn; -+ int col_idx = idx % SharedLoadIterator::ThreadMap::Iterations::kColumn; -+ -+ // Start a new row of the output fragment -+ if (!col_idx) { -+ visitor.begin_row(row_idx); -+ } -+ -+ visitor.visit( -+ iter_idx, -+ row_idx, -+ col_idx, -+ idx, -+ accum_frag_ptr[idx] -+ ); -+ -+ // End the row of the output fragment -+ if (col_idx + 1 == SharedLoadIterator::ThreadMap::Iterations::kColumn) { -+ visitor.end_row(row_idx); -+ } -+ } -+ -+ // -+ // Conclude the step -+ // -+ -+ visitor.end_step(iter_idx); -+ } -+ -+ visitor.end_epilogue(); -+ } -+ -+private: -+ -+ -+ template -+ struct acc2smem_source_needed; -+ -+ template -+ struct acc2smem_source_needed> { -+ template -+ CUTLASS_DEVICE -+ static void helper(AccumulatorFragmentIterator accum_fragment_iterator, -+ WarpTileIterator &warp_tile_iterator) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Advance; i++) { -+ ++accum_fragment_iterator; -+ } -+ -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ accum_fragment_iterator.load(accum_fragment); -+ warp_tile_iterator.store(accum_fragment); -+ } -+ -+ CUTLASS_DEVICE -+ static void push(size_t pos, -+ AccumulatorFragmentIterator const &iterator_begin, -+ WarpTileIterator &warp_tile_iterator) { -+ int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; -+ } -+ }; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper to create an EpilogueWithVisitor from an existing epilogue -+template -+struct EpilogueWithVisitorFromExistingEpilogue { -+ -+ using Epilogue = EpilogueWithVisitor< -+ Visitor_, -+ typename Existing_::Shape, -+ typename Existing_::WarpMmaOperator, -+ Existing_::kPartitionsK, -+ typename Existing_::AccumulatorFragmentIterator, -+ typename Existing_::WarpTileIterator, -+ typename Existing_::SharedLoadIterator, -+ typename Existing_::Padding, -+ Existing_::kFragmentsPerIteration, -+ IterationsUnroll -+ >; -+}; -+ 
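
EpilogueWithVisitor drives any functor satisfying EpilogueFusedVisitorConcept through a fixed call sequence: begin_epilogue once, then per iteration begin_step, begin_row / visit / end_row across the accumulator fragments, end_step, and finally end_epilogue. A stand-alone C++ toy of that callback shape (toy visitor and made-up loop bounds, not the CUTLASS types) is sketched below:

#include <cstdio>

// Toy visitor following the same callback shape as EpilogueFusedVisitorConcept.
struct CountingVisitor {
  int visits = 0;
  void begin_epilogue() { std::puts("begin_epilogue"); }
  void begin_step(int step) { std::printf("begin_step %d\n", step); }
  void begin_row(int row) { std::printf("  begin_row %d\n", row); }
  void visit(int step, int row, int col, int frag, float accum) {
    ++visits;  // a real visitor would convert/scale accum and stage an output fragment here
  }
  void end_row(int row) { std::printf("  end_row %d\n", row); }
  void end_step(int step) { std::printf("end_step %d\n", step); }
  void end_epilogue() { std::printf("end_epilogue after %d visits\n", visits); }
};

// Drives the visitor the way EpilogueWithVisitor::operator() does, with made-up bounds.
template <typename Visitor>
void run_epilogue(Visitor &v, int steps, int rows, int cols) {
  v.begin_epilogue();
  for (int s = 0; s < steps; ++s) {
    v.begin_step(s);
    for (int r = 0; r < rows; ++r) {
      v.begin_row(r);
      for (int c = 0; c < cols; ++c) v.visit(s, r, c, r * cols + c, 0.0f);
      v.end_row(r);
    }
    v.end_step(s);
  }
  v.end_epilogue();
}

int main() {
  CountingVisitor v;
  run_epilogue(v, /*steps=*/2, /*rows=*/2, /*cols=*/4);
  return 0;
}
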
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_workspace.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_workspace.h -new file mode 100644 -index 0000000..5034af3 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_workspace.h -@@ -0,0 +1,197 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs. -+ -+ This does not attempt to target any particular output layout. Instead, each threadblock -+ streams out its accumulator elements using 128b store operations. This assumes all threadblocks -+ have unique output tiles. -+ -+ The target data layout is: -+ - threadblock indices mapped to linear offsets as (m, n, k), where m is fastest-changing -+ - threadblock output space partitioned into warps; each warp's region is contiguous -+ - per-thread accumulators partitioned into 128b accesses -+ - output memory striped across the threads of a warp -+ -+ This enables very fast streaming of data, completely limited by the memory system. No predication -+ or data exchange is performed, and each threadblock is assumed to have a full region of memory -+ to write to. -+ -+ This epilogue establishes an upper bound for epilogue performance and is suitable for -+ reductions across the GEMM K dimension which require a separate workspace. 
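
Concretely, the striping described above reduces to simple vector arithmetic: each thread owns kIterations 128-bit stores, consecutive lanes of a warp write consecutive vectors, and each warp and threadblock tile gets a contiguous region. A worked host-side C++ example with hypothetical sizes (float accumulators, 4 warps, 8 stores per thread; the stride values are placeholders) that mirrors the pointer arithmetic in the class below:

#include <cstdio>

constexpr int kElementsPerAccess   = 128 / (8 * sizeof(float));   // 4 elements per 128-bit vector
constexpr int kWarpSize            = 32;
constexpr int kWarpCount           = 4;                            // hypothetical warp count
constexpr int kIterations          = 8;                            // vector stores per thread (hypothetical)
constexpr int kWarpAccesses        = kIterations * kWarpSize;      // vectors per warp
constexpr int kThreadblockAccesses = kWarpAccesses * kWarpCount;   // vectors per threadblock tile

// Offset (in vectors) of the i-th store issued by (warp_idx, lane_idx) in threadblock tile
// (tb_m, tb_n, tb_k); strides are given in vectors, as in the Params struct below.
int vector_offset(int tb_m, int tb_n, int tb_k, int stride_n, int stride_k,
                  int warp_idx, int lane_idx, int i) {
  int tile_base   = tb_m * kThreadblockAccesses + tb_n * stride_n + tb_k * stride_k;
  int thread_base = warp_idx * kWarpAccesses + lane_idx;
  return tile_base + thread_base + i * kWarpSize;  // consecutive lanes hit consecutive vectors
}

int main() {
  std::printf("vector width: %d elements\n", kElementsPerAccess);
  // Lanes 0 and 1 of warp 0, first store: adjacent 128-bit vectors, so the stores coalesce.
  std::printf("%d %d\n",
              vector_offset(0, 0, 0, /*stride_n=*/1024, /*stride_k=*/4096, 0, 0, 0),
              vector_offset(0, 0, 0, /*stride_n=*/1024, /*stride_k=*/4096, 0, 1, 0));
  return 0;
}
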
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, ///< shape of accumulator tile (concept: MatrixShape) -+ int WarpCount, ///< number of warps -+ typename FragmentC_ ///< warp-level GEMM operator (concept: gemm::warp::Mma) -+> -+class EpilogueWorkspace { -+public: -+ -+ using Shape = Shape_; -+ using FragmentC = FragmentC_; -+ using ElementC = typename FragmentC::value_type; -+ -+ static int const kWarpCount = WarpCount; -+ -+ /// Optimize for 128b accesses -+ static int const kAccessSizeInBits = 128; -+ -+ /// Warp size from the perspective of memory operations -+ static int const kWarpSize = 32; -+ -+ /// Vector length of accesses -+ static int const kElementsPerAccess = -+ kAccessSizeInBits / sizeof_bits::value; -+ -+ /// Number of stores per thread -+ static int const kIterations = FragmentC::kElements / kElementsPerAccess; -+ -+ static_assert( -+ !(FragmentC::kElements % kElementsPerAccess), -+ "The number of accumulators must be divisible by the access size."); -+ -+ /// Total number of vectorized accesses in warp (in units of vector) -+ static int const kWarpAccesses = kIterations * kWarpSize; -+ -+ /// Total number of vectorized accesses in threadblock tile (in units of vector) -+ static int const kThreadblockAccesses = kWarpAccesses * kWarpCount; -+ -+ /// Parameters structure -+ struct Params { -+ -+ /// Pointer to C matrix -+ ElementC *ptr_C; -+ -+ /// Stride between tiles along the GEMM N dimension (in units of vectors) -+ int stride_n; -+ -+ /// Stride between tiles along the GEMM K dimension (in units of vectors) -+ int stride_k; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementC *ptr_C, ///< Pointer to C matrix -+ int stride_n_, ///< Stride between tiles along the GEMM N dimension (in units of ElementC) -+ int stride_k_ ///< Stride between tiles along the GEMM K dimension (in units of ElementC) -+ ): -+ ptr_C(ptr_C), stride_n(stride_n_ / kElementsPerAccess), stride_k(stride_k_ / kElementsPerAccess) { -+ -+ } -+ }; -+ -+ /// Shared storage allocation needed by the epilogue -+ struct SharedStorage { -+ // Intentionally empty -+ }; -+ -+private: -+ -+ struct alignas((kAccessSizeInBits / 8)) AccessType { -+ Array storage; -+ }; -+ -+ /// Constant reference to parameters object -+ AccessType *pointer_; -+ -+ /// Stride between tiles along the n dimension (in vectors) -+ int stride_n_; -+ -+ /// Stride between tiles along the k dimension (in vectors) -+ int stride_k_; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueWorkspace( -+ Params const ¶ms, ///< Host-constructable params object -+ SharedStorage &, ///< Shared storage object -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx ///< Id of thread within warp -+ -+ ): -+ pointer_(reinterpret_cast(params.ptr_C)), -+ stride_n_(params.stride_n), -+ stride_k_(params.stride_k) { -+ -+ // Add per-thread offset -+ pointer_ += lane_idx + warp_idx * kWarpAccesses; -+ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()( -+ cutlass::gemm::GemmCoord problem_size, ///< Problem size of GEMM (units of ElementC) -+ cutlass::gemm::GemmCoord tb_tile_coord, ///< Threadblock tile coordinate in GEMM (in 
units of threadblock tiles) -+ FragmentC const &accum) { ///< Accumulator tile -+ -+ // Compute offset for entire threadblock (note, per-thread offset has been folded in already) -+ AccessType *pointer = pointer_ + -+ tb_tile_coord.m() * kThreadblockAccesses + -+ tb_tile_coord.n() * stride_n_ + -+ tb_tile_coord.k() * stride_k_; -+ -+ // Cast to vectorized view of accumulator fragments -+ AccessType const * src_pointer = reinterpret_cast(&accum); -+ -+ // Write out accumulators at full speed -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kIterations; ++i) { -+ pointer[i * kWarpSize] = src_pointer[i]; -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/interleaved_epilogue.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/interleaved_epilogue.h -new file mode 100644 -index 0000000..b4d1bbe ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/interleaved_epilogue.h -@@ -0,0 +1,407 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. 
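
The conversion-and-reduction behaviour mentioned here typically amounts to an output operator of the form D = alpha * accum + beta * C, and the epilogue below skips loading C entirely when the operator reports that the source is not needed. A simplified host-side C++ sketch of that dispatch, using a toy linear-combination functor rather than the CUTLASS epilogue types:

#include <cstdio>
#include <vector>

// Toy stand-in for a CUTLASS linear-combination output operator.
struct LinearCombination {
  float alpha, beta;
  bool is_source_needed() const { return beta != 0.0f; }
  float operator()(float accum) const { return alpha * accum; }  // source-not-needed path
  float operator()(float accum, float source) const { return alpha * accum + beta * source; }
};

// Mirrors the epilogue's two codepaths: only touch the source tensor C when required.
void run_epilogue(LinearCombination const &op,
                  std::vector<float> const &accum,
                  std::vector<float> const &source,
                  std::vector<float> &dest) {
  if (op.is_source_needed()) {
    for (size_t i = 0; i < accum.size(); ++i) dest[i] = op(accum[i], source[i]);
  } else {
    for (size_t i = 0; i < accum.size(); ++i) dest[i] = op(accum[i]);  // no loads of C
  }
}

int main() {
  std::vector<float> accum{1.f, 2.f, 3.f}, source{10.f, 10.f, 10.f}, dest(3);
  run_epilogue({/*alpha=*/2.f, /*beta=*/0.f}, accum, source, dest);   // 2, 4, 6
  run_epilogue({/*alpha=*/1.f, /*beta=*/0.5f}, accum, source, dest);  // 6, 7, 8
  std::printf("%f %f %f\n", dest[0], dest[1], dest[2]);
  return 0;
}
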
-+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/aligned_buffer.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_base_streamk.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator without splitk -+template < -+ /// Shape of threadblock tile (concept: GemmShape) -+ typename Shape_, -+ /// Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ typename WarpMmaOperator_, -+ /// Number of partitions of the K dimension -+ int PartitionsK, -+ /// Tile iterator reading and writing output tensors -+ typename OutputTileIterator_, -+ /// Fragment iterator selecting accumulators -+ typename AccumulatorFragmentIterator_, -+ /// Output operator -+ typename OutputOp_, -+ /// Number of interleaved k -+ int InterleavedK> -+class InterleavedEpilogue : -+ public EpilogueBaseStreamK< -+ Shape_, -+ PartitionsK, -+ WarpMmaOperator_, -+ AccumulatorFragmentIterator_> -+{ -+public: -+ -+ using BaseStreamK = EpilogueBaseStreamK< -+ Shape_, -+ PartitionsK, -+ WarpMmaOperator_, -+ AccumulatorFragmentIterator_>; -+ -+ using Shape = Shape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ static int const kPartitionsK = PartitionsK; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using OutputTileIterator = OutputTileIterator_; -+ using OutputOp = OutputOp_; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; -+ -+ /// Fragment type used by the accumulator tile's fragment iterator -+ using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment; -+ -+ /// Accumulator element -+ using ElementAccumulator = typename AccumulatorTile::Element; -+ -+ /// Output element -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ /// Tensor reference to destination tensor -+ using TensorRef = typename OutputTileIterator::TensorRef; -+ -+ /// Tensor reference to sync tensor -+ using SyncTensorRef = -+ typename cutlass::TensorRef; -+ -+ /// Const tensor reference to source tensor -+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; -+ -+ /// Array type used to output -+ using OutputAccessType = Array; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = -+ Array; -+ -+ /// Number of warps -+ using WarpCount = -+ gemm::GemmShape; -+ -+public: -+ -+ static_assert(OutputTileIterator::kElementsPerAccess, -+ "This must not be zero."); -+ -+ static_assert(!(OutputTileIterator::Fragment::kElements % -+ OutputTileIterator::kElementsPerAccess), -+ "Divisibility"); -+ -+public: -+ -+ /// Aspect for when epilogue source is not needed -+ struct SourceAspectNotNeeded -+ { -+ /// Constructor -+ CUTLASS_DEVICE -+ SourceAspectNotNeeded() -+ {} -+ -+ /// Invoke the output functor over each vector of output -+ 
CUTLASS_DEVICE -+ void apply_output_operator( -+ typename OutputTileIterator::Fragment &output_fragment, -+ OutputOp const &output_op, -+ typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment) -+ { -+ OutputAccessType *output_frag_ptr = -+ reinterpret_cast(&output_fragment); -+ -+ AccumulatorAccessType const *compute_frag_ptr = -+ reinterpret_cast(&aligned_accum_fragment); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) -+ { -+ // Call the output operator -+ output_frag_ptr[i] = output_op(compute_frag_ptr[i]); -+ } -+ } -+ }; -+ -+ -+ /// Aspect for when epilogue source is needed -+ struct SourceAspectNeeded -+ { -+ OutputTileIterator source_iterator; -+ -+ typename OutputTileIterator::Fragment source_fragment; -+ -+ /// Invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ static void apply_output_operator( -+ typename OutputTileIterator::Fragment &output_fragment, -+ OutputOp const &output_op, -+ typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment, -+ typename OutputTileIterator::Fragment const &source_fragment) -+ { -+ OutputAccessType *output_frag_ptr = -+ reinterpret_cast(&output_fragment); -+ -+ AccumulatorAccessType const *compute_frag_ptr = -+ reinterpret_cast(&aligned_accum_fragment); -+ -+ OutputAccessType const *source_frag_ptr = -+ reinterpret_cast(&source_fragment); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) -+ { -+ // Call the output operator -+ output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]); -+ } -+ } -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ SourceAspectNeeded(OutputTileIterator source_iterator) : -+ source_iterator(source_iterator) -+ { -+ source_fragment.clear(); -+ } -+ -+ /// Invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator( -+ typename OutputTileIterator::Fragment &output_fragment, -+ OutputOp const &output_op, -+ typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment) -+ { -+ // Load addend source fragment from global memory -+ source_iterator.load(source_fragment); -+ ++source_iterator; -+ -+ apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment); -+ } -+ }; -+ -+ -+ /// Shared storage allocation needed by the epilogue -+ struct SharedStorage {}; -+ -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ InterleavedEpilogue( -+ SharedStorage &shared_storage, ///< Shared storage object -+ int thread_idx, ///< ID of a thread within the threadblock -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx) ///< Id of thread within warp -+ : -+ BaseStreamK(thread_idx) -+ {} -+ -+ -+ /// Aggregates the accumulator sets shared by peer blocks in the global workspace, -+ /// performing epilogue computations, writing to output -+ CUTLASS_DEVICE -+ void reduce( -+ int peer_idx_begin, -+ int peer_idx_end, -+ int reduce_fragment_idx, -+ void *element_workspace, -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ OutputTileIterator source_iterator) ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ { -+ // Redcuce peer accumulator fragments into one fragment -+ 
AccumulatorFragment accum_fragment; -+ BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace); -+ -+ // Source-fragment data (zero-initialized for scenarios where the -+ // output operator allows us to skip loading it from global input) -+ typename OutputTileIterator::Fragment source_fragment; -+ source_fragment.clear(); -+ -+ if (output_op.is_source_needed()) -+ { -+ source_iterator += reduce_fragment_idx; -+ source_iterator.load(source_fragment); -+ } -+ -+ // Compute the output result -+ typename OutputTileIterator::Fragment output_fragment; -+ -+ // Apply the output operator -+ SourceAspectNeeded::apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment); -+ -+ // Store the final result -+ destination_iterator += reduce_fragment_idx; -+ destination_iterator.store(output_fragment); -+ } -+ -+ -+ /// Perform the epilogue computations and stream the result to global memory. -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators) ///< Complete warp-level accumulator tile -+ { -+ operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded()); -+ } -+ -+ -+ /// Perform the epilogue computations and stream the result to global memory. Implements -+ /// two alternative codepaths, depending on whether the output op requires addend data to be loaded. -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator ) ///< Tile iterator for addend source -+ { -+ if (output_op.is_source_needed()) -+ { -+ operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator)); -+ } -+ else -+ { -+ operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded()); -+ } -+ } -+ -+ -+ /// Perform the epilogue computations and stream the result to global memory. 
Implements a -+ /// single codepath, regardless of whether the output op requires addend data to be loaded -+ CUTLASS_DEVICE -+ void unified( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator ) ///< Tile iterator for addend source -+ { -+ if (!output_op.is_source_needed()) -+ { -+ source_iterator.clear_mask(); -+ __syncthreads(); // Dummy (CUDA 11.0) -+ } -+ -+ operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator)); -+ } -+ -+ -+ /// Streams the result to global memory -+ template -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ SourceAspect source) -+ { -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { -+ -+ // -+ // Convert fragment -+ // -+ -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ -+ accum_fragment_iterator.load(accum_fragment); -+ ++accum_fragment_iterator; -+ -+ // -+ // Compute the output result -+ // -+ -+ typename OutputTileIterator::Fragment output_fragment; -+ source.apply_output_operator(output_fragment, output_op, accum_fragment); -+ -+ // -+ // Store the final result -+ // -+ -+ destination_iterator.set_iteration_index(iter); -+ destination_iterator.store(output_fragment); -+ ++destination_iterator; -+ } -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/output_iterator_parameter.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/output_iterator_parameter.h -new file mode 100644 -index 0000000..8cfba76 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/output_iterator_parameter.h -@@ -0,0 +1,92 @@ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/tensor_ref.h" -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+template< -+ typename TensorLayout_, ///! The original output tensor layout -+ typename OutputIteratorLayout_, ///! Layout used by epilogue output iterator -+ typename TensorRef_, ///! Input tensor to epilogue output iterator -+ conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) -+ typename ConvProblemSize_ ///! 
Convolutional operator on 2D or 3D problem -+> -+struct ConvOutputIteratorParameter { -+ -+ using TensorLayout = TensorLayout_; -+ using OutputIteratorLayout = OutputIteratorLayout_; -+ using OutputTensorCoord = typename OutputIteratorLayout::TensorCoord; -+ using TensorRef = TensorRef_; -+ static conv::Operator const kConvolutionalOperator = ConvOperator; -+ using ConvProblemSize = ConvProblemSize_; -+ -+ /// Wgrad stride idx for implicit gemm algorithm -+ // Conv2d row-major matrix (KxRSC) -+ // Conv3d row-major matrix (KxTRSC) -+ static int const kWgradStrideIdx = -+ platform::is_same::value ? 2 : 3; -+ -+ /// This chooses the appropriate stride element of the C tensor. -+ static int const kTensorStrideIdx = -+ (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradStrideIdx : 0); -+ -+ -+ CUTLASS_HOST_DEVICE -+ static OutputIteratorLayout layout(const TensorRef & ref) { -+ return ref.stride(kTensorStrideIdx); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static OutputTensorCoord extent(ConvProblemSize problem_size) { -+ return conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(); -+ } -+ -+}; -+ -+ -+ -+template < -+ int InterleavedK, -+ typename TensorRef_, -+ conv::Operator ConvOperator, -+ typename ConvProblemSize_ -+> -+struct ConvOutputIteratorParameter< -+ layout::TensorNCxHWx, -+ layout::TensorNCxHWx, -+ TensorRef_, -+ ConvOperator, -+ ConvProblemSize_> -+{ -+ -+ using TensorLayout = typename layout::TensorNCxHWx; -+ using OutputIteratorLayout = typename layout::TensorNCxHWx; -+ using OutputTensorCoord = typename OutputIteratorLayout::TensorCoord; -+ using TensorRef = TensorRef_; -+ static conv::Operator const kConvolutionalOperator = ConvOperator; -+ using ConvProblemSize = ConvProblemSize_; -+ -+ CUTLASS_HOST_DEVICE -+ static OutputIteratorLayout layout(const TensorRef & ref) { -+ return ref.stride(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static OutputTensorCoord extent(ConvProblemSize problem_size) { -+ return problem_size.output_extent(); -+ } -+ -+}; -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/output_tile_thread_map.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/output_tile_thread_map.h -new file mode 100644 -index 0000000..828b7a7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/output_tile_thread_map.h -@@ -0,0 +1,626 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Metaprogram for determining the mapping of output elements to threads for epilogue tiles. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/fast_math.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tuple defining point in output tile -+template < -+ int Column, -+ int Row, -+ int Group, -+ int Cluster, -+ int Tile -+> -+struct OutputTileShape { -+ static int const kColumn = Column; -+ static int const kRow = Row; -+ static int const kGroup = Group; -+ static int const kCluster = Cluster; -+ static int const kTile = Tile; -+ -+ static int const kCount = kColumn * kRow * kGroup * kCluster * kTile; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct OutputTileThreadMapHelpers { -+ -+ /// Determines the iteration index of a vector access according to the thread map -+ CUTLASS_HOST_DEVICE -+ static void iteration_index( -+ int &column_idx, -+ int &row_idx, -+ int &group_idx, -+ int &cluster_idx, -+ int &tile_idx, -+ int iter_idx) { -+ -+ column_idx = iter_idx % Iterations::kColumn; -+ int residual = iter_idx / Iterations::kColumn; -+ -+ row_idx = residual % Iterations::kRow; -+ residual = residual / Iterations::kRow; -+ -+ group_idx = residual % Iterations::kGroup; -+ residual = residual / Iterations::kGroup; -+ -+ cluster_idx = residual % Iterations::kCluster; -+ tile_idx = residual / Iterations::kCluster; -+ } -+ -+ /// Computes the offset of a given vector access -+ CUTLASS_HOST_DEVICE -+ static MatrixCoord iteration_offset(int iter_idx) { -+ -+ int column_idx; -+ int row_idx; -+ int group_idx; -+ int cluster_idx; -+ int tile_idx; -+ -+ iteration_index(column_idx, row_idx, group_idx, cluster_idx, tile_idx, iter_idx); -+ -+ return -+ MatrixCoord( -+ row_idx * Delta::kRow + -+ group_idx * Delta::kGroup + -+ cluster_idx * Delta::kCluster + -+ tile_idx * Delta::kTile, -+ -+ column_idx * Delta::kColumn); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+template < -+ typename ThreadMap_, -+ typename Shape_, -+ typename Iterations_, -+ typename Delta_, -+ typename Count_ -+> -+struct OutputTileThreadMap : public 
OutputTileThreadMapHelpers { -+ -+ /// Conventional thread map (concept: ThreadMap) -+ using ThreadMap = ThreadMap_; -+ -+ /// Number of threads participating in the operation -+ static int const kThreads = ThreadMap::kThreads; -+ -+ /// Number of scalar elements per access -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ /// Shape of the tile -+ using Shape = Shape_; -+ -+ /// Iterations performed by each thread -+ using Iterations = Iterations_; -+ -+ /// Delta between accesses -+ using Delta = Delta_; -+ -+ /// Number of iterator iterations -+ using Count = Count_; -+ -+ /// Initial offset function -+ CUTLASS_HOST_DEVICE -+ static MatrixCoord initial_offset(int thread_idx) { -+ -+ using Index = typename layout::PitchLinearCoord::Index; -+ -+ layout::PitchLinearCoord coord = ThreadMap::initial_offset(thread_idx); -+ -+ Index cluster = coord.strided() / (Shape::kGroup * Shape::kRow); -+ Index cluster_residual = coord.strided() % (Shape::kGroup * Shape::kRow); -+ -+ Index group = cluster_residual / (Shape::kRow); -+ Index row = cluster_residual % (Shape::kRow); -+ -+ return MatrixCoord{ -+ row + group * Shape::kRow * Count::kRow -+ + cluster * Shape::kGroup * Count::kGroup * Shape::kRow * Count::kRow, -+ coord.contiguous() -+ }; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// RowArrangement determines how one or more warps cover a region of consecutive rows. -+template < -+ typename Shape, -+ int WarpsRemaining, -+ int ElementsPerAccess, -+ int ElementSize, -+ bool Is2dTile -+> -+struct RowArrangement; -+ -+/// RowArrangement in which each warp's access is a 1D tiled arrangement. -+template < -+ typename Shape, -+ int WarpsRemaining, -+ int ElementsPerAccess, -+ int ElementSize -+> -+struct RowArrangement { -+ static int const kWarpSize = 32; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kElementSize = ElementSize; -+ -+ static int const kIterationsRow = 1; -+ static int const kDeltaRow = 1; -+ static int const kIterationsColumn = Shape::kColumn / kElementsPerAccess / kWarpSize; -+ static int const kDeltaColumn = kWarpSize * kElementsPerAccess; -+ -+ static int const kAccessWidth = kWarpSize; -+ static int const kAccessRows = 1; -+ static int const kWarpPartitionsRow = 1; -+ static int const kWarpPartitionsColumn = WarpsRemaining; -+}; -+ -+/// RowArrangement in which each warp's access is a 2D tiled arrangement. -+template < -+ typename Shape, -+ int WarpsRemaining, -+ int ElementsPerAccess, -+ int ElementSize -+> -+struct RowArrangement { -+ -+ static int const kMemoryAccessSize = 256; // Preferred access size -+ static int const kWarpSize = 32; -+ -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kElementSize = ElementSize; -+ -+ struct Detail { -+ static int const kShapeRow = Shape::kRow / WarpsRemaining; -+ static int const kShapeWidth = Shape::kColumn / kElementsPerAccess; -+ -+ static int const kTargetMemoryAccessWidth = -+ kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8); -+ -+ static int const kTargetAccessRows = kWarpSize / kTargetMemoryAccessWidth; -+ }; -+ -+ static int const kAccessWidth = -+ (Detail::kTargetAccessRows > Detail::kShapeRow ? -+ kWarpSize / Detail::kShapeRow -+ : const_min( -+ Detail::kShapeWidth, -+ const_min(kWarpSize, kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8)) -+ )); -+ -+ static int const kAccessRows = -+ (Detail::kTargetAccessRows > Detail::kShapeRow ? 
-+ Detail::kShapeRow -+ : const_min(Shape::kRow, kWarpSize / kAccessWidth)); -+ -+ static int const kIterationsRow = Detail::kShapeRow / kAccessRows; -+ static int const kDeltaRow = kAccessRows; -+ -+ static int const kIterationsColumn = Detail::kShapeWidth / kAccessWidth; -+ static int const kDeltaColumn = kAccessWidth * kElementsPerAccess; -+ -+ static_assert( kAccessWidth * kElementsPerAccess <= Shape::kColumn, "Accessing too many elements per access"); -+ static_assert( kIterationsColumn > 0, "Iteration Count Column must be > 0" ); -+ static_assert( kIterationsRow > 0, "Iteration Count Row must be > 0" ); -+ -+ static int const kWarpPartitionsRow = 1; -+ static int const kWarpPartitionsColumn = 1; -+}; -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template metaprogram for partitioning a 4D space across warps to achieve several performance -+/// objectives: -+/// -+/// - coalesced memory accesses in units of 128 Byte lines -+/// - minimal address arithmetic -+/// - minimal predicate calculations -+/// -+template < -+ typename Shape_, -+ typename Count_, -+ int Threads, -+ int ElementsPerAccess, -+ int ElementSize -+> -+struct OutputTileOptimalThreadMap { -+ -+ using Shape = Shape_; -+ using Count = Count_; -+ -+ static int const kWarpSize = 32; -+ static int const kThreads = Threads; -+ static int const kWarpCount = kThreads / kWarpSize; -+ -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kElementSize = ElementSize; -+ -+ // -+ // Metaprogram computation -+ // -+ -+ struct Detail { -+ -+ // Clusters -+ static int const kIterationsCluster = -+ ((Shape::kCluster > kWarpCount) ? -+ Shape::kCluster / kWarpCount -+ : 1); -+ -+ static int const kDeltaCluster = -+ ((Shape::kCluster > kWarpCount) ? -+ Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup * Shape::kCluster / kIterationsCluster -+ : 1); -+ -+ static int const kCompactedDeltaCluster = -+ ((Shape::kCluster > kWarpCount) ? -+ Shape::kRow * Shape::kGroup * Shape::kCluster / kIterationsCluster -+ : 1); -+ -+ static int const kWarpPartitionsCluster = -+ ((Shape::kCluster > kWarpCount) ? -+ kWarpCount -+ : kWarpCount / Shape::kCluster); -+ -+ static int const kWarpsRemainingForGroups = -+ ((Shape::kCluster > kWarpCount) ? 1 : kWarpCount / Shape::kCluster); -+ -+ // Groups -+ static int const kIterationsGroup = -+ ((Shape::kGroup > kWarpsRemainingForGroups) ? -+ Shape::kGroup / kWarpsRemainingForGroups -+ : 1); -+ -+ static int const kDeltaGroup = -+ ((Shape::kGroup > kWarpsRemainingForGroups) ? -+ Shape::kRow * Count::kRow * Shape::kGroup / kIterationsGroup -+ : 1); -+ -+ static int const kCompactedDeltaGroup = -+ ((Shape::kGroup > kWarpsRemainingForGroups) ? -+ Shape::kRow * Shape::kGroup / kIterationsGroup -+ : 1); -+ -+ static int const kWarpPartitionsGroup = -+ ((Shape::kGroup > kWarpsRemainingForGroups) ? -+ 1 -+ : kWarpsRemainingForGroups / Shape::kGroup); -+ -+ static int const kWarpsRemainingForRows = -+ ((Shape::kGroup > kWarpsRemainingForGroups) ? 
-+ 1 -+ : kWarpsRemainingForGroups / Shape::kGroup); -+ -+ // Rows -+ using RowArrangement = detail::RowArrangement< -+ Shape, -+ kWarpsRemainingForRows, -+ kElementsPerAccess, -+ kElementSize, -+ (Shape::kRow > kWarpsRemainingForRows) -+ >; -+ -+ // Warp partitions -+ using WarpPartitions = OutputTileShape< -+ RowArrangement::kWarpPartitionsColumn, -+ RowArrangement::kWarpPartitionsRow, -+ kWarpPartitionsGroup, -+ kWarpPartitionsCluster, -+ 1>; -+ -+ static int const kAccessWidth = RowArrangement::kAccessWidth; -+ static int const kAccessRows = RowArrangement::kAccessRows; -+ }; -+ -+ // -+ // Output -+ // -+ -+ using Iterations = OutputTileShape< -+ Detail::RowArrangement::kIterationsColumn, -+ Detail::RowArrangement::kIterationsRow, -+ Detail::kIterationsGroup, -+ Detail::kIterationsCluster, -+ 1>; -+ -+ using Delta = OutputTileShape< -+ Detail::RowArrangement::kDeltaColumn, -+ Detail::RowArrangement::kDeltaRow, -+ Detail::kDeltaGroup, -+ Detail::kDeltaCluster, -+ 1>; -+ -+ /// Initial offset function -+ CUTLASS_DEVICE -+ static MatrixCoord initial_offset(int thread_idx) { -+ -+ int warp_idx = __shfl_sync(0xffffffff, thread_idx / kWarpSize, 0); -+ int lane_idx = thread_idx % kWarpSize; -+ -+ // Compute warp location -+ int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster; -+ int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster; -+ -+ int group_idx = residual_cluster / Detail::WarpPartitions::kGroup; -+ int residual_group = residual_cluster % Detail::WarpPartitions::kGroup; -+ -+ int row_idx = residual_group / Detail::WarpPartitions::kRow; -+ int col_idx = residual_group % Detail::WarpPartitions::kRow; -+ -+ // Compute per-lane offset -+ int lane_row_offset = lane_idx / Detail::kAccessWidth; -+ int lane_col_offset = lane_idx % Detail::kAccessWidth; -+ -+ // Compute coordinate in output space -+ int cluster_offset = cluster_idx * Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup; -+ int group_offset = group_idx * Shape::kRow * Count::kRow; -+ int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows; -+ int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess; -+ -+ return MatrixCoord( -+ cluster_offset + group_offset + row_offset + lane_row_offset, -+ column_offset + lane_col_offset * kElementsPerAccess -+ ); -+ } -+ -+ /// Computes the offset of a given vector access -+ CUTLASS_HOST_DEVICE -+ static MatrixCoord iteration_offset(int iter_idx) { -+ return OutputTileThreadMapHelpers::iteration_offset(iter_idx); -+ } -+ -+ /// Compacted thread map in which the 4D region is contiguous -+ struct CompactedThreadMap { -+ -+ -+ using Shape = Shape_; -+ -+ using TileShape = MatrixShape< -+ Shape::kTile * Shape::kCluster * Shape::kGroup * Shape::kRow, -+ Shape::kColumn -+ >; -+ -+ using Iterations = OutputTileShape< -+ Detail::RowArrangement::kIterationsColumn, -+ Detail::RowArrangement::kIterationsRow, -+ Detail::kIterationsGroup, -+ Detail::kIterationsCluster, -+ 1>; -+ -+ using Delta = OutputTileShape< -+ Detail::RowArrangement::kDeltaColumn, -+ Detail::RowArrangement::kDeltaRow, -+ Detail::kCompactedDeltaGroup, -+ Detail::kCompactedDeltaCluster, -+ 1>; -+ -+ /// Number of elements within each vector access -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ /// Number of threads -+ static int const kThreads = Threads; -+ -+ /// Function to compute each thread's initial offset -+ CUTLASS_DEVICE -+ static MatrixCoord initial_offset(int thread_idx) { -+ -+ int warp_idx = __shfl_sync(0xffffffff, thread_idx / 
kWarpSize, 0); -+ int lane_idx = thread_idx % kWarpSize; -+ -+ // Compute warp location -+ int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster; -+ int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster; -+ -+ int group_idx = residual_cluster / Detail::WarpPartitions::kGroup; -+ int residual_group = residual_cluster % Detail::WarpPartitions::kGroup; -+ -+ int row_idx = residual_group / Detail::WarpPartitions::kRow; -+ int col_idx = residual_group % Detail::WarpPartitions::kRow; -+ -+ // Compute per-lane offset -+ int lane_row_offset = lane_idx / Detail::kAccessWidth; -+ int lane_col_offset = lane_idx % Detail::kAccessWidth; -+ -+ // Compute coordinate in output space -+ int cluster_offset = cluster_idx * Shape::kRow * Shape::kGroup; -+ int group_offset = group_idx * Shape::kRow; -+ int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows; -+ int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess; -+ -+ MatrixCoord coord( -+ cluster_offset + group_offset + row_offset + lane_row_offset, -+ column_offset + lane_col_offset * kElementsPerAccess -+ ); -+ -+ return coord; -+ } -+ }; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template metaprogram for partitioning a 3D interleaved layout across warps -+/// to achieve several performance objectives: -+/// -+/// - coalesced memory accesses in units of 64 Byte lines -+/// - minimal address arithmetic -+/// - minimal predicate calculations -+/// -+template -+struct InterleavedOutputTileThreadMap { -+ using WarpCount = WarpCount_; -+ -+ static int const kWarpSize = 32; -+ static int const kThreads = Threads; -+ static int const kWarpCount = kThreads / kWarpSize; -+ -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kElementSize = ElementSize; -+ -+ // -+ // Metaprogram computation -+ // -+ -+ struct Detail {}; -+ -+ // -+ // Output -+ // -+ -+ using Iterations = Iterations_; -+ -+ using Delta = layout::PitchLinearShape; -+ -+ /// Initial offset function -+ CUTLASS_HOST_DEVICE -+ static layout::PitchLinearCoord initial_offset(int thread_idx) { -+ int warp_idx = thread_idx / kWarpSize; -+ int lane_idx = thread_idx % kWarpSize; -+ -+ // Compute warp location -+ layout::PitchLinearCoord warp_footprint{ -+ Delta::kContiguous * Iterations::kContiguous, -+ Delta::kStrided * Iterations::kStrided}; -+ -+ layout::PitchLinearCoord warp_offset{warp_idx % WarpCount::kContiguous, -+ warp_idx / WarpCount::kContiguous}; -+ -+ // Compute per-lane offset -+ layout::PitchLinearCoord thread_offset_in_warp{ -+ lane_idx * kElementsPerAccess, 0}; -+ -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = -+ warp_footprint * warp_offset + thread_offset_in_warp; -+ -+ return thread_offset_in_threadblock_tile; -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template metaprogram for partitioning a 4D interleaved layout across warps -+/// to achieve several performance objectives: -+/// -+/// - coalesced memory accesses in units of 64 Byte lines -+/// - minimal address arithmetic -+/// - minimal predicate calculations -+/// -+template -+struct InterleavedConvOutputTileThreadMap { -+ using WarpCount = WarpCount_; -+ -+ static int const kWarpSize = 32; -+ static int const kThreads = Threads; -+ static int const kWarpCount = kThreads / kWarpSize; -+ -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kElementSize = ElementSize; -+ -+ // 
-+ // Metaprogram computation -+ // -+ -+ struct Detail {}; -+ -+ // -+ // Output -+ // -+ -+ using Iterations = Iterations_; -+ -+ using Delta = MatrixShape; -+ -+ /// Initial offset function -+ CUTLASS_HOST_DEVICE -+ static MatrixCoord initial_offset(int thread_idx) { -+ int warp_idx = thread_idx / kWarpSize; -+ int lane_idx = thread_idx % kWarpSize; -+ -+ // Compute warp location -+ MatrixCoord warp_footprint{ -+ Delta::kRow * Iterations::kRow, -+ Delta::kColumn * Iterations::kColumn, -+ }; -+ -+ MatrixCoord warp_offset{warp_idx % WarpCount::kRow, -+ warp_idx / WarpCount::kRow}; -+ -+ // Compute per-lane offset -+ MatrixCoord thread_offset_in_warp{lane_idx / 4, -+ (lane_idx % 4) * kElementsPerAccess}; -+ -+ MatrixCoord thread_offset_in_threadblock_tile = -+ warp_footprint * warp_offset + thread_offset_in_warp; -+ -+ return thread_offset_in_threadblock_tile; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h -new file mode 100644 -index 0000000..685b6bb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h -@@ -0,0 +1,1351 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. 
Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/permute.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load and store output tile from global memory in epilogue. -+/// -+/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator -+/// -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ typename Element_, ///< Element data type -+ bool ScatterD = false, ///< Scatter D operand or not -+ typename PermuteDLayout = layout::NoPermute, ///< Permute D operand or not -+ bool UseCUDAStore = false -+> -+class PredicatedTileIterator { -+public: -+ using ThreadMap = ThreadMap_; -+ using Shape = typename ThreadMap::Shape; -+ -+ using Element = Element_; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ static int const kThreads = ThreadMap::kThreads; -+ static int const kIterations = ThreadMap::Count::kTile; -+ -+ static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); -+ static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); -+ static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); -+ static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); -+ -+ /// Fragment object -+ using Fragment = Array< -+ Element, -+ ThreadMap::Iterations::kColumn * -+ ThreadMap::Iterations::kRow * -+ ThreadMap::Iterations::kGroup * -+ ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; -+ -+ /// Memory access size -+ using AccessType = AlignedArray; -+ -+ // -+ // Parameters struct -+ // -+ -+ /// Uses a non-template class -+ struct Params : PredicatedTileIteratorParams { -+ using Base = PredicatedTileIteratorParams; -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): -+ PredicatedTileIteratorParams( -+ layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, -+ make_OutputTileThreadMapDesc() -+ ) -+ { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Base const &base) : -+ Base(base) { } -+ }; -+ -+ /// Mask object -+ struct Mask { -+ -+ static int const kCount = ThreadMap::Iterations::kColumn; -+ -+ /// Predicate state -+ bool predicates[kCount]; -+ -+ // -+ // Mask -+ // -+ CUTLASS_HOST_DEVICE -+ Mask() { -+ enable(); -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ 
CUTLASS_HOST_DEVICE void clear() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = false; -+ } -+ } -+ -+ ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = true; -+ } -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters structure containing reference and precomputed state. -+ PredicatedTileIteratorParams params_; -+ -+ /// Byte-level pointer. This pointer is usually for both load() and store(), unless PermuteD is performed. When having PermuteD, byte_pointer_ is only for load(). -+ uint8_t *byte_pointer_; -+ -+ /// Byte-level pointer for store(). Due to PermuteD Op, store_byte_pointer_ may be with different address computation compared to byte_pointer_. -+ uint8_t *store_byte_pointer_; -+ -+ /// Array of boolean values to contain steady-state predicates -+ Mask mask_; -+ -+ /// Extent of the matrix tile in rows -+ Index extent_row_; -+ -+ /// Extent of the matrix tile in rows -+ Index extent_column_; -+ -+ /// A thread's starting row position (assuming steady-state predicates have been computed) -+ Index thread_start_row_; -+ -+ /// A thread's starting column -+ Index thread_start_column_; -+ -+ /// Internal state counter -+ int state_[3]; -+ -+ /// Scatter indices -+ int const *indices_; -+ -+ /// Whether to perform Permute Op -+ bool PermuteD; -+ /// PermuteDLayout -+ mutable PermuteDLayout permute_layout_; -+ -+ // -+ // Static asserts about internal strides -+ // -+ -+ static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); -+ static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); -+ static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); -+ -+private: -+ -+ // -+ // Methods -+ // -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ PredicatedTileIterator( -+ PredicatedTileIteratorParams const & params, -+ Element *pointer, -+ TensorCoord extent, -+ int thread_idx, -+ TensorCoord threadblock_offset = TensorCoord(), -+ int const *indices = nullptr -+ ): -+ params_(params), indices_(indices) -+ { -+ -+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; -+ -+ extent_row_ = extent.row(); -+ extent_column_ = extent.column(); -+ -+ thread_start_row_ = thread_offset.row(); -+ thread_start_column_ = thread_offset.column(); -+ -+ // Initialize predicates -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { -+ -+ mask_.predicates[c] = ((thread_offset.column() -+ + ThreadMap::Delta::kColumn * c) < extent.column()); -+ } -+ -+ // Null pointer performs no accesses -+ if (!pointer) { -+ mask_.clear(); -+ } -+ -+ if (ScatterD && !indices) { -+ mask_.clear(); -+ } -+ -+ // Initialize byte_pointer_ -+ byte_pointer_ = reinterpret_cast(pointer) + -+ LongIndex(thread_offset.row()) * LongIndex(params_.stride) + -+ LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; -+ -+ if (ScatterD) { -+ byte_pointer_ = reinterpret_cast(pointer) + -+ LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; -+ } -+ -+ // store_byte_pointer_ is set to be the same with byte_pointer_ unless PermuteD is used. -+ store_byte_pointer_ = byte_pointer_; -+ -+ // Initialize PermuteD. If PermuteD is true, store_byte_pointer_ is initialized accordingly. 
-+ if (platform::is_same::value) { -+ PermuteD = false; -+ }else{ -+ PermuteD = true; -+ store_byte_pointer_ = reinterpret_cast(pointer); -+ permute_layout_ = PermuteDLayout(extent, -+ params_.stride * kElementsPerAccess / sizeof(AccessType)); -+ } -+ -+ // Initialize internal state counter -+ state_[0] = state_[1] = state_[2] = 0; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ store_byte_pointer_ += pointer_offset * sizeof_bits::value / 8; -+ byte_pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, int64_t byte_offset) const { -+ -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow -+ + group * ThreadMap::Delta::kGroup -+ + cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); -+ -+ if (ScatterD && row_guard) { -+ assert(indices_); -+ -+ memory_pointer = reinterpret_cast(byte_pointer + byte_offset + -+ LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + -+ column], -+ (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / -+ kElementsPerAccess], -+ guard); -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ if (!ScatterD) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const { -+ uint8_t *byte_pointer = store_byte_pointer_; -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow -+ + group * ThreadMap::Delta::kGroup -+ + cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ AccessType 
*memory_pointer = reinterpret_cast(byte_pointer + byte_offset); -+ -+ if (ScatterD && row_guard) { -+ assert(indices_); -+ -+ memory_pointer = reinterpret_cast(byte_pointer + byte_offset + -+ LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ int col_offset = column * ThreadMap::Delta::kColumn; -+ -+ if (PermuteD) { -+ int col = col_offset + thread_start_column_; -+ int row = row_offset + thread_start_row_; -+ -+ TensorCoord init_coord(row, col); -+ -+ // Locate memory_pointer -+ memory_pointer = reinterpret_cast(byte_pointer + byte_offset -+ + permute_layout_(init_coord) * sizeof(AccessType) / kElementsPerAccess); -+ } -+ -+ if (UseCUDAStore) { -+ if (guard) { -+ memory_pointer[0] = -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; -+ } -+ } else { -+ cutlass::arch::global_store( -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], -+ (void *)&memory_pointer[0], -+ guard); -+ } -+ -+ if (!PermuteD) { -+ memory_pointer += (ThreadMap::Delta::kColumn / kElementsPerAccess); -+ } -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ if (!ScatterD && !PermuteD) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) const { -+ -+ store_with_byte_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void downsample_load_with_byte_offset(Fragment &frag, int64_t byte_offset, int convolution_P, int convolution_Q, int add_P, int add_Q, int problem_N) const { -+ -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow -+ + group * ThreadMap::Delta::kGroup -+ + cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ int output_row = row_offset + thread_start_row_; -+ int output_N = output_row / (convolution_P * convolution_Q); -+ int output_PQ = output_row % (convolution_P * convolution_Q); -+ int output_P = output_PQ / convolution_Q; -+ int output_Q = output_PQ % convolution_Q; -+ -+ int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + -+ (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; -+ -+ int64_t byte_offset = (input_row-output_row)*problem_N*sizeof(float); -+ -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ frag_ptr[frag_row_idx * 
ThreadMap::Iterations::kColumn + -+ column], -+ (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / -+ kElementsPerAccess], -+ guard); -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void upsample_load_with_byte_offset(Fragment &frag, int64_t byte_offset, int convolution_P, int convolution_Q, int add_P, int add_Q, int problem_N) const { -+ -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow -+ + group * ThreadMap::Delta::kGroup -+ + cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ int output_row = row_offset + thread_start_row_; -+ int output_N = output_row / (convolution_P * convolution_Q); -+ int output_PQ = output_row % (convolution_P * convolution_Q); -+ int output_P = output_PQ / convolution_Q; -+ int output_Q = output_PQ % convolution_Q; -+ int row_add_P = add_P; -+ int row_add_Q = add_Q; -+ if (output_P > convolution_P - 2) row_add_P = 0; -+ if (output_Q > convolution_Q - 2) row_add_Q = 0; -+ -+ int input_row = output_N * (convolution_P/2) * (convolution_Q/2) + -+ ((output_P + row_add_P)/2) * (convolution_Q/2) + (output_Q + row_add_Q)/2; -+ -+ int64_t byte_offset = (input_row-output_row)*problem_N*sizeof(float); -+ -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + -+ column], -+ (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / -+ kElementsPerAccess], -+ guard); -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ CUTLASS_DEVICE -+ MatrixCoord thread_start() const { -+ return MatrixCoord(thread_start_row_, thread_start_column_); -+ } -+ -+ /// Need to get the thread start row from the tile iterator -+ CUTLASS_DEVICE -+ int32_t thread_start_row() const { -+ return thread_start_row_; -+ } -+ -+ /// Need to get the thread start row from the tile iterator -+ CUTLASS_DEVICE -+ int32_t thread_start_column() const { -+ return thread_start_column_; -+ } -+ -+ /// Extent of the matrix in rows -+ CUTLASS_DEVICE -+ Index extent_row() const { -+ return extent_row_; -+ } -+ -+ /// Extent of the matrix in columns -+ CUTLASS_DEVICE -+ Index extent_column() 
const { -+ return extent_column_; -+ } -+ -+ /// Advances to the next position to load or store -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator &operator++() { -+ -+ ++state_[0]; -+ -+ if (!ScatterD && !PermuteD) { -+ store_byte_pointer_ += params_.advance_row; -+ } -+ -+ if (!ScatterD) { -+ byte_pointer_ += params_.advance_row; -+ } -+ -+ thread_start_row_ += ThreadMap::Shape::kRow; -+ -+ if (state_[0] == ThreadMap::Count::kRow) { -+ -+ state_[0] = 0; -+ ++state_[1]; -+ byte_pointer_ += params_.advance_group; -+ store_byte_pointer_ += params_.advance_group; -+ -+ thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * -+ ThreadMap::Shape::kRow * ThreadMap::Count::kRow; -+ -+ if (state_[1] == ThreadMap::Count::kGroup) { -+ -+ state_[1] = 0; -+ ++state_[2]; -+ byte_pointer_ += params_.advance_cluster; -+ store_byte_pointer_ += params_.advance_cluster; -+ -+ thread_start_row_ += ThreadMap::Count::kGroup * -+ ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; -+ -+ if (state_[2] == ThreadMap::Count::kCluster) { -+ state_[2] = 0; -+ byte_pointer_ += params_.advance_tile; -+ store_byte_pointer_ += params_.advance_tile; -+ -+ thread_start_row_ += ThreadMap::Shape::kGroup * ThreadMap::Shape::kRow -+ * ThreadMap::Shape::kCluster * ThreadMap::Shape::kTile; -+ } -+ } -+ } -+ -+ return *this; -+ } -+ -+ /// Advances a number of positions to load or store -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator &operator+=(int increment) -+ { -+ // Row -+ state_[0] += increment; -+ int increment_row = state_[0] / ThreadMap::Count::kRow; -+ state_[0] = state_[0] % ThreadMap::Count::kRow; -+ -+ byte_pointer_ += (params_.advance_row * increment); -+ store_byte_pointer_ += (params_.advance_row * increment); -+ thread_start_row_ += (ThreadMap::Shape::kRow * increment); -+ -+ // Group -+ state_[1] += increment_row; -+ int increment_group = state_[1] / ThreadMap::Count::kGroup; -+ state_[1] = state_[1] % ThreadMap::Count::kGroup; -+ -+ byte_pointer_ += (params_.advance_group * increment_row); -+ store_byte_pointer_ += (params_.advance_group * increment_row); -+ thread_start_row_ += -+ (ThreadMap::Shape::kGroup - 1) * -+ ThreadMap::Shape::kRow * -+ ThreadMap::Count::kRow * -+ increment_row; -+ -+ -+ // Cluster -+ state_[2] += increment_group; -+ int increment_cluster = state_[2] / ThreadMap::Count::kCluster; -+ state_[2] = state_[2] % ThreadMap::Count::kCluster; -+ -+ byte_pointer_ += (params_.advance_cluster * increment_group); -+ store_byte_pointer_ += (params_.advance_cluster * increment_group); -+ thread_start_row_ += -+ ThreadMap::Count::kGroup * -+ ThreadMap::Shape::kGroup * -+ ThreadMap::Count::kRow * -+ ThreadMap::Shape::kRow * -+ increment_group; -+ -+ // Tile -+ byte_pointer_ += (params_.advance_tile * increment_cluster); -+ store_byte_pointer_ += (params_.advance_tile * increment_cluster); -+ thread_start_row_ += -+ ThreadMap::Shape::kGroup * -+ ThreadMap::Shape::kRow * -+ ThreadMap::Shape::kCluster * -+ ThreadMap::Shape::kTile * -+ increment_cluster; -+ -+ return *this; -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_DEVICE void clear_mask() { -+ mask_.clear(); -+ } -+ -+ ///< Efficiently enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable_mask() { -+ mask_.enable(); -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void get_mask(Mask &mask) const { -+ mask = mask_; -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void set_mask(Mask const &mask) { -+ mask_ = mask; -+ } -+}; -+ 
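Every global access issued by the iterator above goes through cutlass::arch::global_load / global_store with an explicit guard, so rows and columns that fall outside the problem extent are skipped rather than masked after the fact. As a point of reference, a minimal host-side sketch of that guarded-tile pattern (flat row-major buffer, hypothetical tile shape, none of the ThreadMap/cluster/group bookkeeping) might look like:

    #include <cstdint>

    // Guarded copy of a TileRows x TileCols window out of a row-major matrix.
    // Accesses whose coordinates fall outside (rows, cols) are skipped, which is
    // what the guard argument to arch::global_load/global_store accomplishes above.
    template <int TileRows, int TileCols>
    void load_tile_guarded(float const *src, int rows, int cols, int64_t stride,
                           int row0, int col0, float (&frag)[TileRows][TileCols]) {
      for (int r = 0; r < TileRows; ++r) {
        for (int c = 0; c < TileCols; ++c) {
          bool row_guard = (row0 + r) < rows;   // row predicate, cf. row_guard above
          bool col_guard = (col0 + c) < cols;   // column predicate, cf. mask_.predicates[]
          // Touch memory only when both predicates hold; otherwise leave the
          // fragment element at a neutral value instead of reading out of bounds.
          frag[r][c] = (row_guard && col_guard)
                           ? src[(row0 + r) * stride + (col0 + c)]
                           : 0.0f;
        }
      }
    }

The real iterator additionally folds the row traversal into precomputed byte increments (increment_row / increment_group / increment_cluster), so its inner loops mostly reduce to pointer adds.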
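The downsample_load_with_byte_offset / upsample_load_with_byte_offset helpers above share the same bookkeeping step: the epilogue's linear output row is reinterpreted as an (n, p, q) coordinate of an N x P x Q output, remapped to the corresponding source row, and the row delta is converted into a byte offset of problem_N floats. A plain-arithmetic sketch of the 2x downsample remapping, with the same hypothetical P, Q, add_P, add_Q arguments used in the patch:

    #include <cstdint>

    // Maps a linear output row of an N x P x Q tensor onto the matching row of a
    // 2P x 2Q source, then converts the row delta into a byte offset, mirroring
    // the computation inside downsample_load_with_byte_offset.
    inline int64_t downsample_byte_offset(int output_row, int P, int Q,
                                          int add_P, int add_Q, int problem_N) {
      int n  = output_row / (P * Q);
      int pq = output_row % (P * Q);
      int p  = pq / Q;
      int q  = pq % Q;
      int input_row = n * (2 * P) * (2 * Q) + (2 * p + add_P) * (2 * Q) + (2 * q + add_Q);
      return int64_t(input_row - output_row) * problem_N * int64_t(sizeof(float));
    }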
-+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load output tile from global memory in epilogue. -+/// -+/// Satisfies: ReadableTileIterator | InterleavedPredicatedTileIterator | ForwardTileIterator -+/// -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ typename Element_, ///< Element data type -+ int InterleavedN ///< Number of Interleaved N -+> -+class InterleavedPredicatedTileIterator { -+public: -+ using ThreadMap = ThreadMap_; -+ -+ using Element = Element_; -+ -+ using Layout = layout::ColumnMajorInterleaved; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = layout::PitchLinearCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ static int const kThreads = ThreadMap::kThreads; -+ static int const kIterations = ThreadMap::Iterations::kCount; -+ -+ /// Fragment object -+ using Fragment = Array; -+ -+ /// Memory access size -+ using AccessType = AlignedArray; -+ -+ /// Uses a non-template class -+ struct Params : InterleavedPredicatedTileIteratorParams { -+ using Base = InterleavedPredicatedTileIteratorParams; -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): -+ Base( -+ layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, -+ make_InterleavedPredicatedTileIteratorDesc() -+ ) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Base const &base) : -+ Base(base) { } -+ }; -+ -+ /// Mask object -+ struct Mask { -+ static int const kCount = (ThreadMap::Iterations::kContiguous < 8) -+ ? 8 -+ : ThreadMap::Iterations::kContiguous; -+ -+ /// Predicate state -+ bool predicates[kCount]; -+ -+ // -+ // Mask -+ // -+ CUTLASS_HOST_DEVICE -+ Mask() { -+ enable(); -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_HOST_DEVICE void clear() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = false; -+ } -+ } -+ -+ ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = true; -+ } -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters structure containing reference and precomputed state. 
-+ Params params_; -+ -+ /// Byte-level pointer -+ uint8_t *byte_pointer_; -+ -+ /// Array of boolean values to contain steady-state predicates -+ Mask mask_; -+ -+ /// Extent of the matrix tile in columns -+ Index extent_col_; -+ -+ /// A thread's starting column position (assuming steady-state predicates have -+ /// been computed) -+ Index thread_start_col_; -+ -+ /// Internal iteration counter -+ int iteration_contiguous_; -+ -+ int iteration_strided_; -+ -+private: -+ -+ // -+ // Methods -+ // -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ InterleavedPredicatedTileIterator( -+ Params const & params, -+ Element *pointer, -+ TensorCoord extent, -+ int thread_idx, -+ TensorCoord threadblock_offset, -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ): -+ params_(params) { -+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + -+ TensorCoord(threadblock_offset.contiguous() * InterleavedN, -+ threadblock_offset.strided() / InterleavedN); -+ -+ extent_col_ = extent.strided() / InterleavedN; -+ thread_start_col_ = thread_offset.strided(); -+ -+ // Initialize predicates -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ mask_.predicates[c] = -+ ((thread_offset.contiguous() + ThreadMap::Delta::kContiguous * c) < -+ (extent.contiguous() * InterleavedN)); -+ } -+ -+ // Initialize pointer -+ byte_pointer_ = reinterpret_cast(pointer) + -+ LongIndex(thread_offset.strided()) * LongIndex(params_.stride) + -+ LongIndex(thread_offset.contiguous()) * sizeof(AccessType) / kElementsPerAccess; -+ -+ // Initialize internal state counter -+ iteration_contiguous_ = iteration_strided_ = 0; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer); -+ -+ int col_offset = iteration_strided_ * ThreadMap::Delta::kStrided; -+ -+ bool col_guard = ((thread_start_col_ + col_offset) < extent_col_); -+ -+ bool guard = col_guard && mask_.predicates[iteration_contiguous_]; -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ *frag_ptr, -+ (void *)memory_pointer, -+ guard); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer); -+ -+ int col_offset = iteration_strided_ * ThreadMap::Delta::kStrided; -+ -+ bool col_guard = ((thread_start_col_ + col_offset) < extent_col_); -+ -+ bool guard = col_guard && mask_.predicates[iteration_contiguous_]; -+ -+ cutlass::arch::global_store( -+ *frag_ptr, (void *)memory_pointer, guard); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int iteration) { -+ iteration_contiguous_ = iteration % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = iteration / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Advances to the next position to load or store -+ CUTLASS_HOST_DEVICE -+ InterleavedPredicatedTileIterator &operator++() { -+ -+ 
++iteration_contiguous_; -+ byte_pointer_ += params_.advance_row; -+ -+ if (iteration_contiguous_ == ThreadMap::Iterations::kContiguous) { -+ -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ byte_pointer_ += params_.advance_column; -+ -+ if (iteration_strided_ == ThreadMap::Iterations::kStrided) { -+ iteration_strided_ = 0; -+ } -+ } -+ -+ return *this; -+ } -+ -+ /// Advances a number of positions to load or store -+ CUTLASS_HOST_DEVICE -+ InterleavedPredicatedTileIterator &operator+=(int increment) -+ { -+ // Contiguous -+ iteration_contiguous_ += increment; -+ int increment_strided = iteration_contiguous_ / ThreadMap::Iterations::kContiguous; -+ iteration_contiguous_ = iteration_contiguous_ % ThreadMap::Iterations::kContiguous; -+ byte_pointer_ += (params_.advance_row * increment); -+ -+ // Strided -+ iteration_strided_ += increment_strided; -+ byte_pointer_ += (params_.advance_column * increment_strided); -+ -+ return *this; -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_DEVICE void clear_mask() { -+ mask_.clear(); -+ } -+ -+ ///< Efficiently enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable_mask() { -+ mask_.enable(); -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void get_mask(Mask &mask) { -+ mask = mask_; -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void set_mask(Mask const &mask) { -+ mask_ = mask; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load output tile from global memory in epilogue. -+/// -+/// Satisfies: ReadableTileIterator | InterleavedMaskedTileIterator | ForwardTileIterator -+/// -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ typename Element_, ///< Element data type -+ int InterleavedN ///< Number of Interleaved N -+> -+class InterleavedConvPredicatedTileIterator { -+public: -+ using ThreadMap = ThreadMap_; -+ -+ using Element = Element_; -+ -+ using Layout = layout::TensorNCxHWx; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = Tensor4DCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ static int const kThreads = ThreadMap::kThreads; -+ static int const kIterations = ThreadMap::Iterations::kCount; -+ -+ /// Fragment object -+ using Fragment = Array; -+ -+ /// Memory access size -+ using AccessType = AlignedArray; -+ -+ // -+ // Parameters struct -+ // -+ -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ LongIndex stride_col; ///< stride in bytes between columns -+ LongIndex stride_row; ///< stride in bytes between rows -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Status initialize(typename Layout::Stride stride_) { -+ stride_col = stride_[1]; -+ stride_row = stride_[2]; -+ -+ return Status::kSuccess; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params() { -+ initialize(cutlass::make_Coord(0, 0, 0)); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) { -+ -+ initialize(layout.stride()); -+ } -+ }; -+ -+ /// Mask object -+ struct Mask { -+ static int const kCount = -+ (ThreadMap::Iterations::kRow < 8) ? 
8 : ThreadMap::Iterations::kRow; -+ -+ /// Predicate state -+ bool predicates[kCount]; -+ -+ // -+ // Mask -+ // -+ CUTLASS_HOST_DEVICE -+ Mask() { -+ enable(); -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_HOST_DEVICE void clear() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = false; -+ } -+ } -+ -+ ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = true; -+ } -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters structure containing reference and precomputed state. -+ Params params_; -+ -+ /// Byte-level pointer -+ uint8_t *byte_pointer_; -+ -+ /// Array of boolean values to contain steady-state predicates -+ Mask mask_; -+ -+ /// Extent of the matrix tile in columns -+ Index extent_col_; -+ -+ /// Extent of the matrix tile in rows -+ Index extent_row_; -+ -+ /// Extent of the matrix tile in pq -+ Index extent_pq_; -+ -+ /// A thread's starting row position (assuming steady-state predicates have -+ /// been computed) -+ Index thread_start_row_; -+ -+ /// A thread's starting column position (assuming steady-state predicates have -+ /// been computed) -+ Index thread_start_col_; -+ -+ /// Internal iteration counter -+ LongIndex iteration_row_; -+ LongIndex iteration_col_; -+ -+ uint32_t pq_mul_; -+ -+ uint32_t pq_shr_; -+ -+private: -+ -+ // -+ // Methods -+ // -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ InterleavedConvPredicatedTileIterator( -+ Params const & params, -+ Element *pointer, -+ TensorCoord extent, -+ int thread_idx, -+ MatrixCoord threadblock_offset -+ ): -+ params_(params) { -+ MatrixCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; -+ -+ extent_col_ = extent.c(); -+ extent_pq_ = extent.h() * extent.w(); -+ extent_row_ = extent.n() * extent_pq_; -+ -+ find_divisor(pq_mul_, pq_shr_, extent_pq_); -+ -+ thread_start_row_ = thread_offset.row(); -+ thread_start_col_ = thread_offset.column(); -+ -+ // Initialize predicates -+ CUTLASS_PRAGMA_UNROLL -+ for (int r = 0; r < ThreadMap::Iterations::kRow; ++r) { -+ mask_.predicates[r] = -+ ((thread_offset.row() + ThreadMap::Delta::kRow * r) < extent_row_); -+ } -+ -+ // Initialize pointer -+ byte_pointer_ = reinterpret_cast(pointer) + -+ ((thread_start_col_ / InterleavedN) * params_.stride_col + -+ (thread_start_col_ % InterleavedN)) * -+ sizeof_bits::value / 8; -+ -+ // Initialize internal state counter -+ iteration_row_ = iteration_col_ = 0; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ -+ int col_offset = iteration_col_ * ThreadMap::Delta::kColumn; -+ bool col_guard = ((thread_start_col_ + col_offset) < extent_col_); -+ bool guard = col_guard && mask_.predicates[iteration_row_]; -+ -+ int n, pq_rem; -+ -+ fast_divmod(n, pq_rem, -+ thread_start_row_ + iteration_row_ * ThreadMap::Delta::kRow, -+ extent_pq_, pq_mul_, pq_shr_); -+ -+ uint8_t *byte_pointer = -+ byte_pointer_ + (n * params_.stride_row + pq_rem * InterleavedN) * -+ sizeof_bits::value / 8; -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ AccessType const *memory_pointer = -+ reinterpret_cast(byte_pointer); -+ -+ cutlass::arch::global_load< 
-+ AccessType, -+ sizeof(AccessType) -+ >( -+ *frag_ptr, -+ (void *)memory_pointer, -+ guard); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ -+ int col_offset = iteration_col_ * ThreadMap::Delta::kColumn; -+ bool col_guard = ((thread_start_col_ + col_offset) < extent_col_); -+ bool guard = col_guard && mask_.predicates[iteration_row_]; -+ -+ int n, pq_rem; -+ -+ fast_divmod(n, pq_rem, -+ thread_start_row_ + iteration_row_ * ThreadMap::Delta::kRow, -+ extent_pq_, pq_mul_, pq_shr_); -+ -+ uint8_t *byte_pointer = -+ byte_pointer_ + (n * params_.stride_row + pq_rem * InterleavedN) * -+ sizeof_bits::value / 8; -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer); -+ -+ cutlass::arch::global_store( -+ *frag_ptr, (void *)memory_pointer, guard); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int iteration) { -+ iteration_row_ = iteration % ThreadMap::Iterations::kRow; -+ iteration_col_ = iteration / ThreadMap::Iterations::kRow; -+ } -+ -+ /// Advances to the next position to load or store -+ CUTLASS_HOST_DEVICE -+ InterleavedConvPredicatedTileIterator &operator++() { -+ -+ ++iteration_row_; -+ -+ if (iteration_row_ == ThreadMap::Iterations::kRow) { -+ -+ iteration_row_ = 0; -+ ++iteration_col_; -+ byte_pointer_ += params_.stride_col; -+ -+ if (iteration_col_ == ThreadMap::Iterations::kColumn) { -+ iteration_col_ = 0; -+ } -+ } -+ -+ return *this; -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_DEVICE void clear_mask() { -+ mask_.clear(); -+ } -+ -+ ///< Efficiently enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable_mask() { -+ mask_.enable(); -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void get_mask(Mask &mask) { -+ mask = mask_; -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void set_mask(Mask const &mask) { -+ mask_ = mask; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h -new file mode 100644 -index 0000000..505f529 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h -@@ -0,0 +1,615 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load and store output tile from global memory in epilogue. -+/// -+/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator -+/// -+/// It provides a fast path for the case Rank = 2 which does not need div/rem to -+/// calculate modes. 
-+ -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ typename Element_, ///< Element data type -+ int Rank -+> -+class PredicatedTileIteratorAffineRankN { -+public: -+ using ThreadMap = ThreadMap_; -+ using Shape = typename ThreadMap::Shape; -+ -+ using Element = Element_; -+ -+ using Layout = layout::AffineRankN; -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ static int const kThreads = ThreadMap::kThreads; -+ static int const kIterations = ThreadMap::Count::kTile; -+ -+ static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); -+ static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); -+ static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); -+ static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); -+ static_assert( !(Layout::kRank % 2), -+ "Layout rank must be even. This assumes the first half of the modes correspond to the 'row' " -+ "and the second half of the modes correspond to the 'column'"); -+ -+ static bool const kBigEndian = false; -+ -+ /// Fragment object -+ using Fragment = Array< -+ Element, -+ ThreadMap::Iterations::kColumn * -+ ThreadMap::Iterations::kRow * -+ ThreadMap::Iterations::kGroup * -+ ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; -+ -+ /// Memory access size -+ using AccessType = AlignedArray; -+ -+ // -+ // Parameters struct -+ // -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ Layout layout; -+ -+ /// Stride in units of bytes along M modes -+ Coord stride_m; -+ -+ /// Stride in units of bytes along N modes -+ Coord stride_n; -+ -+ /// Fast divmod objects divided by tensor extents -+ FastDivmod divmod_m[(Layout::kRank == 2) ? 1 : (Layout::kRank/2 - 1)]; -+ -+ /// Fast divmod objects divided by tensor extents -+ FastDivmod divmod_n[(Layout::kRank == 2) ? 1 : (Layout::kRank/2 - 1)]; -+ -+ int64_t rank2_inc_col; -+ int64_t rank2_inc_row; -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(TensorCoord const &extent, Layout const &layout_): layout(layout_) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank / 2; ++i) { -+ stride_m[i] = OffsetBytes(layout_.stride()[i]); -+ stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2]); -+ } -+ -+ if (kBigEndian) { -+ // "Big Endian" scheme -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank / 2 - 1; ++i) { -+ divmod_m[i] = FastDivmod(extent[i + 1]); -+ divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2 + 1]); -+ } -+ } -+ else { -+ // "Little Endian" scheme -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank / 2 - 1; ++i) { -+ divmod_m[i] = FastDivmod(extent[i]); -+ divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2]); -+ } -+ } -+ -+ #if 0 -+ // -+ // Debug print statements to verify extents and strides are passed correctly. 
-+ // -+ printf("PredicatedTileIteratorAffine::Params() entered\n"); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank; ++i) { -+ printf(" extent[%d]: %d\n", i, extent[i]); -+ } -+ for (int i = 0; i < Layout::kRank; ++i) { -+ printf(" stride[%d]: %ld\n", i, layout_.stride()[i]); -+ } -+ printf("PredicatedTileIteratorAffine::Params() returning\n"); -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout_): layout(layout_) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank / 2; ++i) { -+ stride_m[i] = OffsetBytes(layout_.stride()[i]); -+ stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2]); -+ } -+ -+ rank2_inc_col = ThreadMap::Delta::kColumn * stride_n[0]; -+ rank2_inc_row = ThreadMap::Delta::kRow * stride_m[0]; -+ } -+ }; -+ -+ /// Mask object -+ struct Mask { -+ -+ static int const kCount = ThreadMap::Iterations::kColumn; -+ -+ /// Predicate state -+ bool predicates[kCount]; -+ -+ // -+ // Mask -+ // -+ CUTLASS_HOST_DEVICE -+ Mask() { -+ enable(); -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_HOST_DEVICE void clear() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = false; -+ } -+ } -+ -+ ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = true; -+ } -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters structure containing reference and precomputed state. -+ Params params_; -+ -+ /// Byte-level pointer -+ uint8_t *byte_pointer_; -+ -+ /// Array of boolean values to contain steady-state predicates -+ Mask mask_; -+ -+ /// Extent of the matrix tile in rows -+ Index extent_row_; -+ -+ /// Extent of the matrix tile in columns -+ Index extent_col_; -+ -+ /// A thread's starting row position (assuming steady-state predicates have been computed) -+ Index thread_start_row_; -+ -+ /// A thread's starting column position (assuming steady-state predicates have been computed) -+ Index thread_start_column_; -+ -+ /// Internal state counter -+ int state_[3]; -+ -+ /// Offsets in columns, cached for performance -+ int64_t offset_modes_n_[ThreadMap::Iterations::kColumn]; -+ -+ // -+ // Static asserts about internal strides -+ // -+ -+ static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); -+ static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); -+ -+private: -+ -+ // -+ // Methods -+ // -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ PredicatedTileIteratorAffineRankN( -+ Params const & params, -+ Element *pointer, -+ MatrixCoord extent, -+ int thread_idx, -+ MatrixCoord threadblock_offset = MatrixCoord(), -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ): -+ params_(params) -+ { -+ -+ MatrixCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; -+ -+ extent_row_ = extent.row(); -+ extent_col_ = extent.column(); -+ -+ thread_start_row_ = thread_offset.row(); -+ thread_start_column_ = thread_offset.column(); -+ -+ if (Layout::kRank > 2) { -+ // Initialize predicates -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { -+ -+ // -+ // Compute coordinate and decompose into N modes -+ // -+ -+ int coord_n = thread_start_column_ + c * ThreadMap::Delta::kColumn; -+ -+ mask_.predicates[c] = coord_n < extent.column(); -+ -+ Coord modes_n; -+ -+ 
int64_t offset_modes_n = 0; -+ -+ if (kBigEndian) { -+ modes_n = CoordinateDecomposition(coord_n, params_.divmod_n); -+ -+ offset_modes_n = dot(modes_n, params_.stride_n); -+ } -+ else { -+ modes_n = CoordinateDecompositionLittleEndian(coord_n, params_.divmod_n); -+ -+ offset_modes_n = dot(modes_n, params_.stride_n); -+ } -+ -+ offset_modes_n_[c] = offset_modes_n; -+ -+ } -+ -+ if (!pointer) { -+ mask_.clear(); -+ } -+ } -+ -+ // Initialize pointer -+ byte_pointer_ = reinterpret_cast(pointer); -+ -+ // Initialize internal state counter -+ state_[0] = state_[1] = state_[2] = 0; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, int64_t byte_offset) { -+ uint8_t const *byte_pointer = byte_pointer_; -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ int row_begin = thread_start_row_ + group * ThreadMap::Delta::kGroup + cluster * ThreadMap::Delta::kCluster; -+ int64_t offset_modes_m = row_begin * params_.stride_m[0]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ // -+ // Compute coordinate and decompose into M modes -+ // -+ -+ int coord_m = row * ThreadMap::Delta::kRow + row_begin; -+ -+ Coord modes_m; -+ -+ if (Layout::kRank > 2) { -+ if (kBigEndian) { -+ modes_m = CoordinateDecomposition(coord_m, params_.divmod_m); -+ } else { -+ modes_m = CoordinateDecompositionLittleEndian(coord_m, params_.divmod_m); -+ } -+ -+ offset_modes_m = dot(modes_m, params_.stride_m); -+ } -+ -+ // -+ // Compute the offset due to modes M -+ // -+ -+ bool row_guard = (coord_m < extent_row_); -+ int64_t offset_modes_n = thread_start_column_ * params_.stride_n[0]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ // -+ // Compute coordinate and decompose into N modes -+ // -+ -+ if (Layout::kRank > 2) { -+ offset_modes_n = offset_modes_n_[column]; -+ } -+ -+ // -+ // Compute the pointer and access -+ // -+ bool guard; -+ -+ if (Layout::kRank > 2) { -+ guard = row_guard && mask_.predicates[column]; -+ } else { -+ guard = (coord_m < extent_row_) && -+ ((thread_start_column_ + ThreadMap::Delta::kColumn * column) < extent_col_); -+ } -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], -+ (void *)(byte_pointer + offset_modes_m + offset_modes_n + byte_offset), -+ guard -+ ); -+ -+ if (Layout::kRank == 2) { -+ offset_modes_n += params_.rank2_inc_col; -+ } -+ } -+ -+ if (Layout::kRank == 2) { -+ offset_modes_m += params_.rank2_inc_row; -+ } -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) { -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ 
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ int row_begin = thread_start_row_ + group * ThreadMap::Delta::kGroup + cluster * ThreadMap::Delta::kCluster; -+ int64_t offset_modes_m = row_begin * params_.stride_m[0]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ // -+ // Compute coordinate and decompose into M modes -+ // -+ -+ int coord_m = row * ThreadMap::Delta::kRow + row_begin; -+ -+ Coord modes_m; -+ -+ if (Layout::kRank > 2) { -+ if (kBigEndian) { -+ modes_m = CoordinateDecomposition(coord_m, params_.divmod_m); -+ } else { -+ modes_m = CoordinateDecompositionLittleEndian(coord_m, params_.divmod_m); -+ } -+ -+ offset_modes_m = dot(modes_m, params_.stride_m); -+ } -+ -+ // -+ // Compute the offset due to modes M -+ // -+ -+ bool row_guard = (coord_m < extent_row_); -+ int64_t offset_modes_n = thread_start_column_ * params_.stride_n[0]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ // -+ // Compute coordinate and decompose into N modes -+ // -+ -+ if (Layout::kRank > 2) { -+ offset_modes_n = offset_modes_n_[column]; -+ } -+ -+ // -+ // Compute the pointer and access -+ // -+ bool guard; -+ if (Layout::kRank > 2) { -+ guard = row_guard && mask_.predicates[column]; -+ } else { -+ guard = (coord_m < extent_row_) && ((thread_start_column_ + ThreadMap::Delta::kColumn * column) < extent_col_); -+ } -+ -+ cutlass::arch::global_store( -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], -+ (void *)(byte_pointer + offset_modes_m + offset_modes_n + byte_offset), -+ guard); -+ -+ if (Layout::kRank == 2) { -+ offset_modes_n += params_.rank2_inc_col; -+ } -+ } -+ -+ if (Layout::kRank == 2) { -+ offset_modes_m += params_.rank2_inc_row; -+ } -+ } -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ -+ store_with_byte_offset(frag, 0); -+ } -+ -+ /// Advances to the next position to load or store -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorAffineRankN &operator++() { -+ -+ ++state_[0]; -+ thread_start_row_ += ThreadMap::Shape::kRow; -+ -+ if (state_[0] == ThreadMap::Count::kRow) { -+ -+ state_[0] = 0; -+ ++state_[1]; -+ -+ thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * -+ ThreadMap::Shape::kRow * ThreadMap::Count::kRow; -+ -+ if (state_[1] == ThreadMap::Count::kGroup) { -+ -+ state_[1] = 0; -+ ++state_[2]; -+ -+ thread_start_row_ += ThreadMap::Count::kGroup * -+ ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; -+ -+ if (state_[2] == ThreadMap::Count::kCluster) { -+ state_[2] = 0; -+ } -+ } -+ } -+ -+ return *this; -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_DEVICE void clear_mask() { -+ mask_.clear(); -+ } -+ -+ ///< Efficiently enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable_mask() { -+ mask_.enable(); -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void get_mask(Mask &mask) { -+ mask = mask_; -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void set_mask(Mask const &mask) { -+ mask_ = mask; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ 
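PredicatedTileIteratorAffineRankN needs div/rem only when Rank > 2: a linear row or column index is decomposed into its modes (little-endian by default, since kBigEndian is false) and the byte offset is the dot product of those modes with the per-mode byte strides, while the Rank == 2 case advances by the precomputed rank2_inc_row / rank2_inc_col with no division at all. A rough standalone sketch of the little-endian decomposition, using plain division where the iterator uses its FastDivmod objects:

    #include <array>
    #include <cstdint>

    // Decompose a linear index into kModes coordinates (least-significant mode
    // first) and fold them with per-mode byte strides. With kModes == 1 this
    // degenerates to the rank-2 fast path: a single multiply, no division.
    template <int kModes>
    int64_t affine_offset_bytes(int index,
                                std::array<int, kModes> const &extent,
                                std::array<int64_t, kModes> const &stride_bytes) {
      int64_t offset = 0;
      for (int i = 0; i + 1 < kModes; ++i) {
        offset += int64_t(index % extent[i]) * stride_bytes[i];  // divmod_m / divmod_n in the patch
        index /= extent[i];
      }
      offset += int64_t(index) * stride_bytes[kModes - 1];       // leading mode keeps the quotient
      return offset;
    }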
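Earlier in this diff, InterleavedConvPredicatedTileIterator addresses a TensorNCxHWx output by splitting the channel coordinate into an interleaved block plus a remainder, and by splitting the row into a batch index and a spatial position with precomputed fast-divmod state (pq_mul_, pq_shr_). A plain-division sketch of that element-offset computation, with hypothetical strides expressed in elements rather than the byte arithmetic used in the patch:

    #include <cstdint>

    // Element offset of (row, channel) in an NCxHWx-interleaved output, where
    // row = n * (P * Q) + spatial position. stride_col and stride_row play the
    // role of params_.stride_col / params_.stride_row above, here in elements.
    inline int64_t interleaved_conv_offset(int row, int channel, int interleave,
                                           int extent_pq,
                                           int64_t stride_col, int64_t stride_row) {
      int n      = row / extent_pq;   // fast_divmod(n, pq_rem, ...) in the patch
      int pq_rem = row % extent_pq;
      return (channel / interleave) * stride_col   // which interleaved channel block
           + (channel % interleave)                // position inside the block
           + n * stride_row                        // which image in the batch
           + int64_t(pq_rem) * interleave;         // spatial position within the block
    }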
-+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine_layout_params.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine_layout_params.h -new file mode 100644 -index 0000000..7832fde ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine_layout_params.h -@@ -0,0 +1,156 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/fast_math.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ int Rank -+> -+struct PredicatedTileIteratorAffineLayoutRankNParams { -+ using Layout = layout::AffineRankN; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ static bool const kBigEndian = false; -+ -+ // -+ // Data members -+ // -+ -+ Layout layout; -+ -+ /// Stride in units of bytes along M modes -+ Coord stride_m; -+ -+ /// Stride in units of bytes along N modes -+ Coord stride_n; -+ -+ /// Fast divmod objects divided by tensor extents -+ FastDivmod divmod_m[(Layout::kRank == 2) ? 1 : (Layout::kRank/2 - 1)]; -+ -+ /// Fast divmod objects divided by tensor extents -+ FastDivmod divmod_n[(Layout::kRank == 2) ? 
1 : (Layout::kRank/2 - 1)]; -+ -+ int64_t rank2_inc_col; -+ int64_t rank2_inc_row; -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorAffineLayoutRankNParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorAffineLayoutRankNParams(TensorCoord const &extent, -+ Layout const &layout_, -+ int64_t element_sizeof_bits) -+ : layout(layout_) -+ { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank / 2; ++i) { -+ stride_m[i] = OffsetBytes(layout_.stride()[i], element_sizeof_bits); -+ stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2], element_sizeof_bits); -+ } -+ -+ if (kBigEndian) { -+ // "Big Endian" scheme -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank / 2 - 1; ++i) { -+ divmod_m[i] = FastDivmod(extent[i + 1]); -+ divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2 + 1]); -+ } -+ } -+ else { -+ // "Little Endian" scheme -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank / 2 - 1; ++i) { -+ divmod_m[i] = FastDivmod(extent[i]); -+ divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2]); -+ } -+ } -+ -+ #if 0 -+ // -+ // Debug print statements to verify extents and strides are passed correctly. -+ // -+ printf("PredicatedTileIteratorAffine::Params() entered\n"); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank; ++i) { -+ printf(" extent[%d]: %d\n", i, extent[i]); -+ } -+ for (int i = 0; i < Layout::kRank; ++i) { -+ printf(" stride[%d]: %ld\n", i, layout_.stride()[i]); -+ } -+ printf("PredicatedTileIteratorAffine::Params() returning\n"); -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorAffineLayoutRankNParams(Layout const &layout_, -+ int32_t threadmap_delta_kColumn, -+ int32_t threadmap_delta_kRow, -+ int64_t element_sizeof_bits) -+ : layout(layout_) -+ { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank / 2; ++i) { -+ stride_m[i] = OffsetBytes(layout_.stride()[i], element_sizeof_bits); -+ stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2], element_sizeof_bits); -+ } -+ -+ rank2_inc_col = threadmap_delta_kColumn * stride_n[0]; -+ rank2_inc_row = threadmap_delta_kRow * stride_m[0]; -+ } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h -new file mode 100644 -index 0000000..9aab017 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h -@@ -0,0 +1,633 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load and store output tile from global memory in epilogue. 
-+/// -+/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator -+/// -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ typename Element_, ///< Element data type -+ BlasMode BlasMode_ = BlasMode::kGemm ///< Tile Iterator for a Symmetric or Hermitian Kernel -+> -+class PredicatedTileIteratorBlas3 { -+public: -+ using ThreadMap = ThreadMap_; -+ using Shape = typename ThreadMap::Shape; -+ -+ using Element = Element_; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static BlasMode const kBlasMode = BlasMode_; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ static int const kThreads = ThreadMap::kThreads; -+ static int const kIterations = ThreadMap::Count::kTile; -+ -+ static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); -+ static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); -+ static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); -+ static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); -+ -+ /// Fragment object -+ using Fragment = Array< -+ Element, -+ ThreadMap::Iterations::kColumn * -+ ThreadMap::Iterations::kRow * -+ ThreadMap::Iterations::kGroup * -+ ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; -+ -+ /// Memory access size -+ using AccessType = AlignedArray; -+ static_assert( AccessType::kElements == 1, "BLAS3 Epilogue must use AccessType::kElements as 1"); -+ -+ // -+ // Parameters struct -+ // -+ -+ /// Uses a non-template class -+ struct Params : PredicatedTileIteratorParams { -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): -+ PredicatedTileIteratorParams( -+ layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, -+ make_OutputTileThreadMapDesc() -+ ) -+ { -+ -+ } -+ }; -+ -+ /// Mask object -+ struct Mask { -+ -+ static int const kCount = ThreadMap::Iterations::kColumn; -+ -+ /// Predicate state -+ bool predicates[kCount]; -+ -+ // -+ // Mask -+ // -+ CUTLASS_HOST_DEVICE -+ Mask() { -+ enable(); -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_HOST_DEVICE void clear() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = false; -+ } -+ } -+ -+ ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = true; -+ } -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters structure containing reference and precomputed state. 
-+ PredicatedTileIteratorParams params_; -+ -+ /// Byte-level pointer -+ uint8_t *byte_pointer_; -+ -+ /// Fill Mode for a tile on diagonal of a symmetric kernel -+ cutlass::FillMode fill_mode; -+ -+ /// Array of boolean values to contain steady-state predicates -+ Mask mask_; -+ -+ /// Extent of the matrix tile in rows -+ Index extent_row_; -+ -+ /// A thread's starting row position (assuming steady-state predicates have been computed) -+ Index thread_start_row_; -+ -+ /// Internal state counter -+ int state_[3]; -+ -+ /// Starting address of the matrix -+ size_t matrix_start_addr; -+ -+ static_assert((kBlasMode == BlasMode::kSymmetric || kBlasMode == BlasMode::kHermitian), -+ "Unsupported blas3 mode."); -+ -+private: -+ -+ // -+ // Methods -+ // -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ PredicatedTileIteratorBlas3( -+ PredicatedTileIteratorParams const & params, -+ Element *pointer, -+ TensorCoord extent, -+ int thread_idx, -+ TensorCoord threadblock_offset -+ , cutlass::FillMode fill_mode -+ ): -+ params_(params), fill_mode(fill_mode) -+ { -+ -+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; -+ -+ extent_row_ = extent.row(); -+ thread_start_row_ = thread_offset.row(); -+ -+ // Initialize predicates -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { -+ -+ mask_.predicates[c] = ((thread_offset.column() -+ + ThreadMap::Delta::kColumn * c) < extent.column()); -+ } -+ -+ // Check Symmetric kernel modes (Lower and Upper - for diagonal CTAs, None for rest CTAs) -+ if ((kBlasMode == BlasMode::kSymmetric || kBlasMode == BlasMode::kHermitian) && -+ fill_mode == cutlass::FillMode::kInvalid) { -+ arch::device_breakpoint(); -+ } -+ -+ // Starting address of the matrix -+ matrix_start_addr = reinterpret_cast(pointer); -+ -+ // Initialize pointer -+ byte_pointer_ = reinterpret_cast(pointer) + -+ LongIndex(thread_offset.row()) * LongIndex(params_.stride) + -+ LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; -+ -+ // Initialize internal state counter -+ state_[0] = state_[1] = state_[2] = 0; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, int64_t byte_offset) { -+ -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow -+ + group * ThreadMap::Delta::kGroup -+ + cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ 
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + -+ column], -+ (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / -+ kElementsPerAccess], -+ guard); -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ /// Loads a fragment on the diagonal of a symmetric kernel to memory -+ CUTLASS_DEVICE -+ void load_symmetric_with_byte_offset(Fragment &frag, int64_t byte_offset) { -+ -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ bool isLowerMode = (fill_mode == cutlass::FillMode::kLower) ? true : false; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow -+ + group * ThreadMap::Delta::kGroup -+ + cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); -+ -+ // Offset of row from beginning of the matrix per thread -+ size_t row_start_offset = (size_t)memory_pointer - matrix_start_addr; -+ -+ // Absolute row index -+ int row_index = int(row_start_offset/params_.stride); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ // Offset of column from beginning of row per thread -+ size_t col_start_offset = row_start_offset + -+ (column * ThreadMap::Delta::kColumn / kElementsPerAccess) * sizeof(AccessType); -+ -+ // Absolute column index -+ size_t col_index = (col_start_offset%params_.stride)/sizeof(AccessType); -+ guard = guard && ( (isLowerMode && row_index >= col_index) || -+ (!isLowerMode && row_index <= col_index) ); -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + -+ column], -+ (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / -+ kElementsPerAccess], -+ guard); -+ -+ // The imaginary parts of the diagonal elements of a complex element are assumed and set to zero -+ if (guard && kBlasMode == BlasMode::kHermitian && cutlass::is_complex::value) { -+ Element *scalar_ptr = reinterpret_cast(frag_ptr); -+ -+ if (row_index == col_index) { -+ scalar_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column] = -+ real(scalar_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]); -+ } -+ } -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ -+ if (fill_mode == cutlass::FillMode::kNone) { -+ load_with_byte_offset(frag, 0); -+ } 
-+ else { -+ load_symmetric_with_byte_offset(frag, 0); -+ } -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) { -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow -+ + group * ThreadMap::Delta::kGroup -+ + cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ cutlass::arch::global_store( -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], -+ (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], -+ guard); -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ /// Stores a fragment on the diagonal of a symmetric kernel to memory -+ CUTLASS_DEVICE -+ void store_symmetric_with_byte_offset(Fragment const &frag, int64_t byte_offset) { -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ bool isLowerMode = (fill_mode == cutlass::FillMode::kLower) ? 
true : false; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow -+ + group * ThreadMap::Delta::kGroup -+ + cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); -+ -+ // Offset of row from beginning of the matrix per thread -+ size_t row_start_offset = (size_t)memory_pointer - matrix_start_addr; -+ -+ // Absolute row index -+ int row_index = int(row_start_offset/params_.stride); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ // Offset of column from beginning of row per thread -+ size_t col_start_offset = row_start_offset + -+ (column * ThreadMap::Delta::kColumn / kElementsPerAccess) * sizeof(AccessType); -+ -+ // Absolute column index -+ size_t col_index = (col_start_offset%params_.stride)/sizeof(AccessType); -+ -+ guard = guard && ( (isLowerMode && row_index >= col_index) || -+ (!isLowerMode && row_index <= col_index) ); -+ -+ // The imaginary parts of the diagonal elements of a complex element are assumed and set to zero -+ if (guard && kBlasMode == BlasMode::kHermitian && cutlass::is_complex::value) { -+ -+ AccessType *frag_ptr_modify = const_cast(frag_ptr); -+ Element *scalar_ptr = reinterpret_cast(frag_ptr_modify); -+ -+ if (row_index == col_index) { -+ scalar_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column] = -+ real(scalar_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]); -+ } -+ } -+ -+ cutlass::arch::global_store( -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + -+ column], -+ (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / -+ kElementsPerAccess], -+ guard); -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ -+ if (fill_mode == cutlass::FillMode::kNone) { -+ store_with_byte_offset(frag, 0); -+ } -+ else { -+ store_symmetric_with_byte_offset(frag, 0); -+ } -+ -+ } -+ -+ /// Advances to the next position to load or store -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorBlas3 &operator++() { -+ -+ ++state_[0]; -+ byte_pointer_ += params_.advance_row; -+ thread_start_row_ += ThreadMap::Shape::kRow; -+ -+ if (state_[0] == ThreadMap::Count::kRow) { -+ -+ state_[0] = 0; -+ ++state_[1]; -+ byte_pointer_ += params_.advance_group; -+ -+ thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * -+ ThreadMap::Shape::kRow * ThreadMap::Count::kRow; -+ -+ if (state_[1] == ThreadMap::Count::kGroup) { -+ -+ state_[1] = 0; -+ ++state_[2]; -+ byte_pointer_ += params_.advance_cluster; -+ -+ thread_start_row_ += ThreadMap::Count::kGroup * -+ ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * 
ThreadMap::Shape::kRow; -+ -+ if (state_[2] == ThreadMap::Count::kCluster) { -+ state_[2] = 0; -+ byte_pointer_ += params_.advance_tile; -+ } -+ } -+ } -+ -+ return *this; -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_DEVICE void clear_mask() { -+ mask_.clear(); -+ } -+ -+ ///< Efficiently enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable_mask() { -+ mask_.enable(); -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void get_mask(Mask &mask) { -+ mask = mask_; -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void set_mask(Mask const &mask) { -+ mask_ = mask; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h -new file mode 100644 -index 0000000..a641f60 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h -@@ -0,0 +1,445 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. 
-+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/permute.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load and store output tile from global memory in epilogue. -+/// -+/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator -+/// -+template < -+ typename ThreadMap_, ///< Thread map (conept: PitchLinearThreadMap) -+ typename Element_, ///< Element data type -+ typename ThreadOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1>, -+ typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1> -+> -+class PredicatedTileIteratorDirectConv { -+public: -+ using ThreadMap = ThreadMap_; -+ using Shape = typename ThreadMap::Shape; -+ using ThreadOutputShape = ThreadOutputShape_; -+ using ThreadBlockOutputShape = ThreadBlockOutputShape_; -+ -+ using Element = Element_; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ static int const kThreads = ThreadMap::kThreads; -+ -+ using ConvProblemSize = typename cutlass::conv::Conv2dProblemSize; -+ -+ /// Fragment object -+ using Fragment = Array; -+ -+ /// Memory access size -+ using AccessType = AlignedArray; -+ -+ static int const kLoadsPerAccess = AccessType::kElements / AccessType::kElements; -+ -+ using ThreadTileCount = MatrixShape< -+ ThreadBlockOutputShape::kH / ThreadOutputShape::kH, -+ ThreadBlockOutputShape::kW / ThreadOutputShape::kW -+ >; -+ -+ // -+ // Parameters struct -+ // -+ -+ /// Uses a non-template class -+ struct Params : PredicatedTileIteratorDirect2dConvParams { -+ using Base = PredicatedTileIteratorDirect2dConvParams; -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout, cutlass::conv::Conv2dProblemSize const &problem_size): -+ PredicatedTileIteratorDirect2dConvParams( -+ layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, -+ problem_size, -+ {ThreadBlockOutputShape::kH, ThreadBlockOutputShape::kW} -+ ) -+ { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Base const &base) : -+ Base(base) { } -+ }; -+ -+ /// Mask object -+ struct Mask { -+ -+ static int const kCount = ThreadMap::Iterations::kContiguous; -+ -+ /// Predicate state -+ bool predicates[kCount]; -+ -+ // -+ // Mask -+ // -+ CUTLASS_HOST_DEVICE -+ Mask() { -+ enable(); -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_HOST_DEVICE void clear() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 
kCount; ++i) { -+ predicates[i] = false; -+ } -+ } -+ -+ ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = true; -+ } -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters structure containing reference and precomputed state. -+ PredicatedTileIteratorDirect2dConvParams params_; -+ -+ /// Byte-level pointer -+ uint8_t *byte_pointer_; -+ -+ /// -+ Element *pointer_; -+ -+ -+ /// Array of boolean values to contain steady-state predicates -+ Mask mask_; -+ -+ /// Extent of the matrix tile in rows -+ Index extent_row_; -+ -+ /// Extent of the matrix tile in rows -+ Index extent_column_; -+ -+ /// A thread's starting row position (assuming steady-state predicates have been computed) -+ Index thread_start_row_; -+ -+ /// A thread's starting column -+ Index thread_start_column_; -+ -+ /// Initial thread ouput location -+ int thread_start_n_, thread_start_p_, thread_start_q_; -+ -+ /// Current threadblock tile index -+ int tile_index_; -+ -+ // -+ // Static asserts about internal strides -+ // -+ -+ static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); -+ static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); -+ static_assert(sizeof(PredicatedTileIteratorDirect2dConvParams::stride) == 8, "Expected 64b strides"); -+ -+private: -+ -+ // -+ // Methods -+ // -+ -+ -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ PredicatedTileIteratorDirectConv( -+ PredicatedTileIteratorDirect2dConvParams const & params, -+ Element *pointer, -+ TensorCoord extent, -+ int thread_idx, -+ TensorCoord threadblock_offset = TensorCoord() -+ ): -+ params_(params), pointer_(pointer) -+ { -+ -+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); -+ -+ extent_row_ = extent.row(); -+ extent_column_ = extent.column(); -+ -+ // stride dim (PQ) -+ thread_start_row_ = thread_offset.column(); -+ // contiguous dim (Channels) -+ thread_start_column_ = threadblock_offset.column() + thread_offset.row(); -+ -+ tile_index_ = threadblock_offset.row(); -+ -+ set_tile_index(0); -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void set_tile_index(const int index) { -+ -+ int residual; -+ params_.pq_divmod(thread_start_n_, residual, tile_index_ + index); -+ params_.q_divmod(thread_start_p_, thread_start_q_, residual); -+ -+ // Compute the base output coord of ThreadBlock -+ thread_start_p_ *= ThreadBlockOutputShape::kH; -+ thread_start_q_ *= ThreadBlockOutputShape::kW; -+ -+ // Initialize predicates -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ mask_.predicates[c] = ((thread_start_column_ -+ + c * ThreadMap::Delta::kContiguous) < extent_column_); -+ } -+ -+ // Null pointer performs no accesses -+ if (!pointer_) { -+ mask_.clear(); -+ } -+ -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, int64_t byte_offset) const { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ int frag_base_idx = s * ThreadMap::Iterations::kContiguous + c; -+ -+ int current_row = 
thread_start_row_ + s * ThreadMap::Delta::kStrided; -+ int p = current_row / ThreadBlockOutputShape::kW; -+ int q = current_row % ThreadBlockOutputShape::kW; -+ -+ int current_p = thread_start_p_ + p; -+ int current_q = thread_start_q_ + q; -+ -+ bool row_guard = (current_p) < params_.P && (current_q) < params_.Q && -+ (thread_start_n_ < params_.N) && current_row < ThreadMap::Shape::kStrided; -+ -+ int output_row_offset = -+ thread_start_n_ * params_.stride_n + current_p * params_.stride_p + current_q; -+ -+ uint8_t *byte_pointer = -+ reinterpret_cast(pointer_) + -+ LongIndex(output_row_offset) * LongIndex(params_.stride) + -+ LongIndex(thread_start_column_ + c * ThreadMap::Delta::kContiguous) * -+ sizeof(AccessType) / kElementsPerAccess; -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); -+ -+ bool guard = row_guard && mask_.predicates[c]; -+ -+ cutlass::arch::global_load( -+ frag_ptr[frag_base_idx], (void *)&memory_pointer[0], guard); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) const { -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ int frag_base_idx = s * ThreadMap::Iterations::kContiguous + c; -+ -+ int current_row = thread_start_row_ + s * ThreadMap::Delta::kStrided; -+ int p = current_row / ThreadBlockOutputShape::kW; -+ int q = current_row % ThreadBlockOutputShape::kW; -+ -+ int current_p = thread_start_p_ + p; -+ int current_q = thread_start_q_ + q; -+ -+ bool row_guard = (current_p) < params_.P && (current_q) < params_.Q && -+ (thread_start_n_ < params_.N) && current_row < ThreadMap::Shape::kStrided; -+ -+ int output_row_offset = -+ thread_start_n_ * params_.stride_n + current_p * params_.stride_p + current_q; -+ -+ uint8_t *byte_pointer = -+ reinterpret_cast(pointer_) + -+ LongIndex(output_row_offset) * LongIndex(params_.stride) + -+ LongIndex(thread_start_column_ + c * ThreadMap::Delta::kContiguous) * -+ sizeof(AccessType) / kElementsPerAccess; -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); -+ -+ bool guard = row_guard && mask_.predicates[c]; -+ -+ cutlass::arch::global_store( -+ frag_ptr[frag_base_idx], (void *)&memory_pointer[0], guard); -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) const { -+ -+ store_with_byte_offset(frag, 0); -+ } -+ -+ CUTLASS_DEVICE -+ MatrixCoord thread_start() const { -+ return MatrixCoord(thread_start_row_, thread_start_column_); -+ } -+ -+ /// Need to get the thread start row from the tile iterator -+ CUTLASS_DEVICE -+ int32_t thread_start_row() const { -+ return thread_start_row_; -+ } -+ -+ /// Need to get the thread start row from the tile iterator -+ CUTLASS_DEVICE -+ int32_t thread_start_column() const { -+ return thread_start_column_; -+ } -+ -+ /// Extent of the matrix in rows -+ CUTLASS_DEVICE -+ Index extent_row() const { -+ return extent_row_; -+ } -+ -+ /// Extent of the matrix in columns -+ CUTLASS_DEVICE -+ Index extent_column() const { -+ return extent_column_; -+ } -+ -+ /// Advances to the next position to load or store -+ 
CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorDirectConv &operator++() { -+ // do nothing -+ -+ return *this; -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_DEVICE void clear_mask() { -+ mask_.clear(); -+ } -+ -+ ///< Efficiently enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable_mask() { -+ mask_.enable(); -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void get_mask(Mask &mask) const { -+ mask = mask_; -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void set_mask(Mask const &mask) { -+ mask_ = mask; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h -new file mode 100644 -index 0000000..937409a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h -@@ -0,0 +1,475 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct OutputTileShapeDesc { -+ -+ int column; -+ int row; -+ int group; -+ int cluster; -+ int tile; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ OutputTileShapeDesc(): column(0), row(0), group(0), cluster(0), tile(0) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ OutputTileShapeDesc( -+ int column_, -+ int row_, -+ int group_, -+ int cluster_, -+ int tile_ -+ ): -+ column(column_), -+ row(row_), -+ group(group_), -+ cluster(cluster_), -+ tile(tile_) { } -+ -+ /// Total number of points in the 5D space -+ CUTLASS_HOST_DEVICE -+ int count() const { -+ return column * row * group * cluster * tile; -+ } -+ -+ #if 0 -+ CUTLASS_HOST_DEVICE -+ void print() const { -+ printf("{%d, %d, %d, %d, %d}", column, row, group, cluster, tile); -+ } -+ #endif -+}; -+ -+/// Helper template to construct an OutputTileShapeDesc from a OutputTileShape template. -+template -+CUTLASS_HOST_DEVICE -+OutputTileShapeDesc make_OutputTileShapeDesc() { -+ return OutputTileShapeDesc( -+ Shape::kColumn, -+ Shape::kRow, -+ Shape::kGroup, -+ Shape::kCluster, -+ Shape::kTile -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Thread map description -+struct OutputTileThreadMapDesc { -+ -+ int threads; -+ int elements_per_access; -+ OutputTileShapeDesc shape; -+ OutputTileShapeDesc iterations; -+ OutputTileShapeDesc delta; -+ OutputTileShapeDesc count; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ OutputTileThreadMapDesc() { } -+ -+ CUTLASS_HOST_DEVICE -+ OutputTileThreadMapDesc( -+ int threads_, -+ int elements_per_access_, -+ OutputTileShapeDesc shape_, -+ OutputTileShapeDesc iterations_, -+ OutputTileShapeDesc delta_, -+ OutputTileShapeDesc count_ -+ ): -+ threads(threads_), -+ elements_per_access(elements_per_access_), -+ shape(shape_), -+ iterations(iterations_), -+ delta(delta_), -+ count(count_) -+ { -+ -+ } -+}; -+ -+/// Helper template to construct an OutputTileShapeDesc from a OutputTileThreadMap template. 
-+template -+CUTLASS_HOST_DEVICE -+OutputTileThreadMapDesc make_OutputTileThreadMapDesc() { -+ return OutputTileThreadMapDesc( -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ make_OutputTileShapeDesc(), -+ make_OutputTileShapeDesc(), -+ make_OutputTileShapeDesc(), -+ make_OutputTileShapeDesc() -+ ); -+} -+/////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Parameters struct for PredicatedTileIterator -+// -+ -+struct PredicatedTileIteratorParams { -+ -+ using Index = int32_t; -+ using LongIndex = int64_t; -+ -+ // -+ // Data members -+ // -+ -+ LongIndex stride; ///< stride in bytes between rows -+ -+ LongIndex increment_row; ///< increment quantity (in bytes) to advance when moving between rows -+ LongIndex increment_group; ///< increment quantity (in bytes) to advance when moving to the next group -+ LongIndex increment_cluster; ///< increment quantity (in bytes) to advance when moving to the next cluster -+ -+ LongIndex advance_row; ///< amount to add to move to the next 'row' position -+ LongIndex advance_group; ///< amount to add to move to the next 'group' position -+ LongIndex advance_cluster; ///< amount to add to move to the next 'cluster' position -+ LongIndex advance_tile; ///< amount to add to move to the next 'tile' -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Status initialize(LongIndex stride_, OutputTileThreadMapDesc thread_map) { -+ -+ stride = stride_; -+ -+ increment_row = stride * thread_map.delta.row; -+ -+ increment_group = stride * thread_map.delta.group -+ - stride * thread_map.delta.row * (thread_map.iterations.row - 1); -+ -+ increment_cluster = stride * thread_map.delta.cluster -+ - stride * thread_map.delta.group * (thread_map.iterations.group - 1) -+ - stride * thread_map.delta.row * (thread_map.iterations.row - 1); -+ -+ advance_row = stride * thread_map.shape.row; -+ -+ advance_group = -+ stride * -+ (thread_map.shape.group - 1) * thread_map.shape.row * thread_map.count.row; -+ -+ advance_cluster = -+ stride * -+ thread_map.count.group * -+ thread_map.shape.group * -+ thread_map.count.row * -+ thread_map.shape.row; -+ -+ advance_tile = -+ stride * -+ thread_map.shape.group * -+ thread_map.shape.row * -+ thread_map.shape.cluster * -+ thread_map.shape.tile; -+ -+ return Status::kSuccess; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Status initialize(Index stride_, OutputTileThreadMapDesc thread_map) { -+ return initialize(LongIndex(stride_), thread_map); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorParams() { -+ initialize(LongIndex(0), OutputTileThreadMapDesc()); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorParams(Index stride, OutputTileThreadMapDesc thread_map) { -+ initialize(stride, thread_map); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorParams(LongIndex stride, OutputTileThreadMapDesc thread_map) { -+ initialize(stride, thread_map); -+ } -+}; -+ -+ -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Parameters struct for PredicatedTileIteratorDirect2dConv -+// -+ -+struct PredicatedTileIteratorDirect2dConvParams{ -+ using Index = int32_t; -+ using LongIndex = int64_t; -+ -+ // -+ // Data members -+ // -+ FastDivmod pq_divmod; -+ FastDivmod q_divmod; -+ -+ LongIndex stride; -+ LongIndex stride_n; -+ LongIndex stride_p; -+ -+ int N; -+ int P; -+ int Q; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Status initialize(LongIndex stride_, -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ MatrixCoord 
threadblock_output_shape) { -+ stride = stride_; // The stride per row of output tensor (bytes) -+ stride_n = problem_size.P * problem_size.Q; -+ stride_p = problem_size.Q ; -+ -+ N = problem_size.N; -+ P = problem_size.P; -+ Q = problem_size.Q; -+ -+ // Fastdivmod for output O, P, Q -+ if(threadblock_output_shape.row() != 0 && threadblock_output_shape.column() !=0 ){ -+ int tiles_p = -+ (problem_size.P + (threadblock_output_shape.row() - 1)) / (threadblock_output_shape.row()); -+ int tiles_q = (problem_size.Q + (threadblock_output_shape.column() - 1)) / -+ (threadblock_output_shape.column()); -+ -+ pq_divmod = FastDivmod(tiles_p * tiles_q); -+ q_divmod = FastDivmod(tiles_q); -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Status initialize( -+ Index stride_, -+ cutlass::conv::Conv2dProblemSize const &problem_size = cutlass::conv::Conv2dProblemSize(), -+ MatrixCoord threadblock_output_shape = MatrixCoord()) { -+ return initialize(LongIndex(stride_), problem_size, threadblock_output_shape); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorDirect2dConvParams() { initialize(LongIndex(0)); } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorDirect2dConvParams(Index stride, -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ MatrixCoord threadblock_output_shape) { -+ initialize(stride, problem_size, threadblock_output_shape); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorDirect2dConvParams(LongIndex stride, -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ MatrixCoord threadblock_output_shape) { -+ initialize(stride, problem_size, threadblock_output_shape); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+// InterleavedPredicatedTileIterator -+/////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Predicated tile access iterator descriptor object containing template dependent state -+struct InterleavedPredicatedTileIteratorDesc { -+ -+ int element_size_bits; -+ int elements_per_access; -+ int threadmap_warp_size; -+ layout::PitchLinearCoord threadmap_iterations; -+ layout::PitchLinearCoord threadmap_delta; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ InterleavedPredicatedTileIteratorDesc() { } -+ -+ CUTLASS_HOST_DEVICE -+ InterleavedPredicatedTileIteratorDesc( -+ int element_size_bits_, -+ int elements_per_access_, -+ int threadmap_warp_size_, -+ layout::PitchLinearCoord threadmap_iterations_, -+ layout::PitchLinearCoord threadmap_delta_ -+ ): -+ element_size_bits(element_size_bits_), -+ elements_per_access(elements_per_access_), -+ threadmap_warp_size(threadmap_warp_size_), -+ threadmap_iterations(threadmap_iterations_), -+ threadmap_delta(threadmap_delta_) { } -+}; -+ -+// -+// Parameters struct InterleavedPredicatedTileIterator -+// -+ -+struct InterleavedPredicatedTileIteratorParams { -+ -+ using Index = int32_t; -+ using LongIndex = int64_t; -+ -+ // -+ // Data members -+ // -+ -+ LongIndex stride; ///< stride in bytes between rows -+ LongIndex advance_row; ///< amount to add to move to the next 'row' position -+ LongIndex advance_column; ///< amount to add to move to the next 'column' position -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Status initialize(LongIndex stride_, InterleavedPredicatedTileIteratorDesc desc) { -+ -+ stride = stride_; -+ -+ advance_row = desc.threadmap_delta.contiguous() * desc.element_size_bits / 8; -+ -+ advance_column = stride_ - desc.threadmap_iterations.contiguous() * -+ 
desc.elements_per_access * -+ desc.element_size_bits * -+ desc.threadmap_warp_size / 8; -+ -+ return Status::kSuccess; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ InterleavedPredicatedTileIteratorParams() { -+ initialize(LongIndex(0), InterleavedPredicatedTileIteratorDesc()); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ InterleavedPredicatedTileIteratorParams(Index stride, InterleavedPredicatedTileIteratorDesc desc) { -+ initialize(stride, desc); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ InterleavedPredicatedTileIteratorParams(LongIndex stride, InterleavedPredicatedTileIteratorDesc desc) { -+ initialize(stride, desc); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Helper template to construct an OutputTileShapeDesc from a OutputTileThreadMap template. -+template -+CUTLASS_HOST_DEVICE -+InterleavedPredicatedTileIteratorDesc make_InterleavedPredicatedTileIteratorDesc() { -+ return InterleavedPredicatedTileIteratorDesc( -+ sizeof_bits::value, -+ ThreadMap::kElementsPerAccess, -+ ThreadMap::kWarpSize, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Helper template to construct an MakePredicatedTileIteratorDesc from a template -+// dependent state -+template -+ struct MakePredicatedTileIteratorDesc; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for layout::RowMajor output data. -+template -+struct MakePredicatedTileIteratorDesc < -+ Element, layout::RowMajor, ThreadMap> { -+ -+ CUTLASS_HOST_DEVICE -+ OutputTileThreadMapDesc operator()() { -+ -+ return make_OutputTileThreadMapDesc(); -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for layout::ColumnMajorInterleaved output data. -+template -+struct MakePredicatedTileIteratorDesc < -+ Element, layout::ColumnMajorInterleaved, ThreadMap> { -+ -+ CUTLASS_HOST_DEVICE -+ InterleavedPredicatedTileIteratorDesc operator()() { -+ -+ return make_InterleavedPredicatedTileIteratorDesc(); -+ } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_predicates.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_predicates.h -new file mode 100644 -index 0000000..36202be ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_predicates.h -@@ -0,0 +1,309 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief PredicatedTileIteratorPredicates. -+ -+ PredicatedTileIteratorPredicates enables both upper and lower bounds for predicates. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator predicates used to bound computations in epilogue. 
-+/// -+/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator -+/// -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ typename Element_ ///< Element data type -+> -+class PredicatedTileIteratorPredicates { -+public: -+ using ThreadMap = ThreadMap_; -+ using Shape = typename ThreadMap::Shape; -+ -+ using Element = Element_; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ static int const kThreads = ThreadMap::kThreads; -+ static int const kIterations = ThreadMap::Count::kTile; -+ -+ static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); -+ static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); -+ static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); -+ static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); -+ -+ /// Fragment object -+ using Fragment = Array< -+ Element, -+ ThreadMap::Iterations::kColumn * -+ ThreadMap::Iterations::kRow * -+ ThreadMap::Iterations::kGroup * -+ ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; -+ -+ /// Memory access size -+ using AccessType = AlignedArray; -+ -+ // -+ // Parameters struct -+ // -+ -+ /// Uses a non-template class -+ struct Params : PredicatedTileIteratorParams { -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): -+ PredicatedTileIteratorParams( -+ layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, -+ make_OutputTileThreadMapDesc() -+ ) -+ { -+ -+ } -+ }; -+ -+ /// Mask object -+ struct Mask { -+ -+ static int const kCount = ThreadMap::Iterations::kColumn; -+ -+ /// Predicate state -+ bool predicates[kCount]; -+ -+ // -+ // Mask -+ // -+ CUTLASS_HOST_DEVICE -+ Mask() { -+ enable(); -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_HOST_DEVICE void clear() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = false; -+ } -+ } -+ -+ ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = true; -+ } -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters structure containing reference and precomputed state. 
-+ PredicatedTileIteratorParams params_; -+ -+ /// Array of boolean values to contain steady-state predicates -+ Mask mask_; -+ -+ /// Extent of the matrix tile in rows -+ Index lower_extent_row_; -+ Index upper_extent_row_; -+ -+ /// A thread's starting row position (assuming steady-state predicates have been computed) -+ Index thread_start_row_; -+ -+ /// Internal state counter -+ int state_[3]; -+ -+ // -+ // Static asserts about internal strides -+ // -+ -+ static_assert(sizeof(lower_extent_row_) == 4, "Expected 32b extents"); -+ static_assert(sizeof(upper_extent_row_) == 4, "Expected 32b extents"); -+ static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); -+ static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); -+ -+private: -+ -+ // -+ // Methods -+ // -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ PredicatedTileIteratorPredicates( -+ PredicatedTileIteratorParams const & params, -+ TensorCoord lower_extent, -+ TensorCoord upper_extent, -+ int thread_idx, -+ TensorCoord threadblock_offset = TensorCoord() -+ ): -+ params_(params) -+ { -+ -+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; -+ -+ lower_extent_row_ = lower_extent.row(); -+ upper_extent_row_ = upper_extent.row(); -+ thread_start_row_ = thread_offset.row(); -+ -+ // Initialize predicates -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { -+ -+ mask_.predicates[c] = ((thread_offset.column() -+ + ThreadMap::Delta::kColumn * c) < upper_extent.column()) && -+ ((thread_offset.column() + ThreadMap::Delta::kColumn * c) >= lower_extent.column()); -+ } -+ -+ // Initialize internal state counter -+ state_[0] = state_[1] = state_[2] = 0; -+ } -+ -+ /// Advances to the next position to load or store -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorPredicates &operator++() { -+ -+ ++state_[0]; -+ thread_start_row_ += ThreadMap::Shape::kRow; -+ -+ if (state_[0] == ThreadMap::Count::kRow) { -+ -+ state_[0] = 0; -+ ++state_[1]; -+ -+ thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * -+ ThreadMap::Shape::kRow * ThreadMap::Count::kRow; -+ -+ if (state_[1] == ThreadMap::Count::kGroup) { -+ -+ state_[1] = 0; -+ ++state_[2]; -+ -+ thread_start_row_ += ThreadMap::Count::kGroup * -+ ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; -+ -+ if (state_[2] == ThreadMap::Count::kCluster) { -+ state_[2] = 0; -+ } -+ } -+ } -+ -+ return *this; -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_DEVICE void clear_mask() { -+ mask_.clear(); -+ } -+ -+ ///< Efficiently enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable_mask() { -+ mask_.enable(); -+ } -+ -+ ///< Gets the mask -+ CUTLASS_DEVICE void get_mask(Mask &mask) { -+ mask = mask_; -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void set_mask(Mask const &mask) { -+ mask_ = mask; -+ } -+ -+ ///< Gets lower_extent_row_ -+ CUTLASS_DEVICE Index get_lower_extent_row() { -+ return lower_extent_row_; -+ } -+ -+ ///< Gets upper_extent_row_ -+ CUTLASS_DEVICE Index get_upper_extent_row() { -+ return upper_extent_row_; -+ } -+ -+ ///< Gets thread_start_row_ -+ CUTLASS_DEVICE Index get_thread_start_row() { -+ return thread_start_row_; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ 
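Editor's aside (not part of the patch): to make the guarding logic in PredicatedTileIteratorPredicates easier to follow, here is a minimal standalone sketch of how a per-column predicate mask is derived from lower and upper column extents, mirroring the constructor above. The constants kColumnIterations and kDeltaColumn are hypothetical stand-ins for the ThreadMap values.

#include <array>
#include <cstdio>

// Hypothetical stand-ins for ThreadMap::Iterations::kColumn and ThreadMap::Delta::kColumn.
constexpr int kColumnIterations = 4;
constexpr int kDeltaColumn = 32;

// Mirrors the predicate initialization: a column access is enabled only when it
// falls inside [lower_extent_column, upper_extent_column).
std::array<bool, kColumnIterations> make_column_mask(int thread_start_column,
                                                     int lower_extent_column,
                                                     int upper_extent_column) {
  std::array<bool, kColumnIterations> predicates{};
  for (int c = 0; c < kColumnIterations; ++c) {
    int column = thread_start_column + kDeltaColumn * c;
    predicates[c] = (column < upper_extent_column) && (column >= lower_extent_column);
  }
  return predicates;
}

int main() {
  auto mask = make_column_mask(/*thread_start_column=*/48, /*lower=*/32, /*upper=*/96);
  for (bool p : mask) {
    std::printf("%d ", int(p));  // expected: 1 1 0 0 (columns 48 and 80 are in range; 112 and 144 are not)
  }
  return 0;
}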
-+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h -new file mode 100644 -index 0000000..1e8c71e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h -@@ -0,0 +1,479 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. 
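Editor's aside (not part of the patch): the iterator defined below remaps each linear NPQ row offset of the output tile to an NHW row of Dx, using the starting (h, w) coordinates and the convolution strides. A minimal host-side sketch of that index arithmetic, with all sizes chosen arbitrarily for illustration, might look like the following.

#include <cstdio>

// Hypothetical problem sizes, for illustration only.
constexpr int H = 8, W = 8;               // Dx spatial extents
constexpr int stride_h = 2, stride_w = 2;
constexpr int start_h = 1, start_w = 1;   // starting Dx coordinates for this (r, s)
constexpr int P = (H - start_h + stride_h - 1) / stride_h;  // effective rows covered
constexpr int Q = (W - start_w + stride_w - 1) / stride_w;  // effective cols covered

// Mirrors the "(STEP 4.a)" remapping: linear npq offset -> row index in the NHWxC output.
int mapped_row(int npq_offset) {
  int n = npq_offset / (P * Q);
  int residual = npq_offset % (P * Q);
  int p = residual / Q;
  int q = residual % Q;
  return n * (H * W) + (start_h + p * stride_h) * W + (start_w + q * stride_w);
}

int main() {
  for (int npq = 0; npq < 6; ++npq) {
    std::printf("npq=%d -> row=%d\n", npq, mapped_row(npq));
  }
  return 0;
}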
-+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load and store output tile from global memory in epilogue. -+/// -+/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator -+/// -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ typename Element_ ///< Element data type -+> -+class PredicatedTileIteratorStridedDgrad { -+public: -+ using ThreadMap = ThreadMap_; -+ using Shape = typename ThreadMap::Shape; -+ -+ using Element = Element_; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ static int const kThreads = ThreadMap::kThreads; -+ static int const kIterations = ThreadMap::Count::kTile; -+ -+ static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); -+ static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); -+ static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); -+ static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); -+ -+ /// Fragment object -+ using Fragment = Array< -+ Element, -+ ThreadMap::Iterations::kColumn * -+ ThreadMap::Iterations::kRow * -+ ThreadMap::Iterations::kGroup * -+ ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; -+ -+ /// Memory access size -+ using AccessType = AlignedArray; -+ -+ // -+ // Parameters struct -+ // -+ -+ /// Uses a non-template class -+ struct Params : PredicatedTileIteratorParams { -+ -+ /// Convolution problem size -+ cutlass::conv::Conv2dProblemSize problem_size; -+ int tiled_rows_per_filter; -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout, cutlass::conv::Conv2dProblemSize problem_size_, int threadblock_row): -+ problem_size(problem_size_), -+ PredicatedTileIteratorParams( -+ layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, -+ make_OutputTileThreadMapDesc() -+ ) -+ { -+ -+ int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, threadblock_row); -+ -+ tiled_rows_per_filter = tile_m_per_filter * threadblock_row; -+ } -+ }; -+ -+ /// Mask object -+ struct Mask { -+ -+ static int const kCount = ThreadMap::Iterations::kColumn; -+ -+ /// Predicate state -+ bool predicates[kCount]; -+ -+ // -+ // Mask -+ // -+ CUTLASS_HOST_DEVICE -+ Mask() { -+ enable(); -+ } -+ -+ 
///< Efficiently disables all accesses guarded by mask -+ CUTLASS_HOST_DEVICE void clear() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = false; -+ } -+ } -+ -+ ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = true; -+ } -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters structure containing reference and precomputed state. -+ Params params_; -+ -+ /// Byte-level pointer -+ uint8_t *byte_pointer_; -+ -+ /// Array of boolean values to contain steady-state predicates -+ Mask mask_; -+ -+ /// Extent of the matrix tile in rows -+ Index extent_row_; -+ -+ /// Starting Dx h and w dimenstion for strided dgrad mapping -+ int start_h_, start_w_; -+ -+ /// Effective Dy P and Q dimenstions for strided dgrad mapping -+ int p_, q_; -+ -+ /// A thread's starting row position (assuming steady-state predicates have been computed) -+ Index thread_start_row_; -+ -+ /// A thread's starting column position (assuming steady-state predicates have been computed) -+ Index thread_start_column_; -+ -+ /// Internal state counter -+ int state_[3]; -+ -+ // -+ // Static asserts about internal strides -+ // -+ -+ static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); -+ static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); -+ static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); -+ -+private: -+ -+ // -+ // Methods -+ // -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ PredicatedTileIteratorStridedDgrad( -+ Params const & params, -+ Element *pointer, -+ TensorCoord extent, -+ int thread_idx, -+ FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, -+ int start_r, int start_s, -+ TensorCoord threadblock_offset = TensorCoord() -+ ): -+ params_(params) -+ { -+ -+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; -+ -+ int r = start_r; -+ int s = start_s; -+ -+ if (params_.problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ r = (params_.problem_size.R - 1 - r); -+ s = (params_.problem_size.S - 1 - s); -+ } -+ -+ // compute starting coordinates in Dx start_h_ and start_w_ -+ strided_dgrad_starting_coords( -+ params_.problem_size, -+ stride_h_divmod, stride_w_divmod, -+ r, s, -+ start_h_, start_w_); -+ -+ p_ = (params_.problem_size.H - start_h_ + params_.problem_size.stride_h - 1) / params_.problem_size.stride_h; -+ q_ = (params_.problem_size.W - start_w_ + params_.problem_size.stride_w - 1) / params_.problem_size.stride_w; -+ -+ extent_row_ = extent.row(); -+ thread_start_row_ = thread_offset.row(); -+ thread_start_column_ = thread_offset.column(); -+ -+ // Initialize predicates -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { -+ -+ mask_.predicates[c] = ((thread_offset.column() -+ + ThreadMap::Delta::kColumn * c) < extent.column()); -+ } -+ -+ // Null pointer performs no accesses -+ if (!pointer) { -+ mask_.clear(); -+ } -+ -+ // Initialize pointer -+ byte_pointer_ = reinterpret_cast(pointer); -+ -+ // Initialize internal state counter -+ state_[0] = state_[1] = state_[2] = 0; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ 
void load_with_byte_offset(Fragment &frag, int64_t byte_offset) { -+ -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow -+ + group * ThreadMap::Delta::kGroup -+ + cluster * ThreadMap::Delta::kCluster; -+ -+ // remapping rows to find the mapped_row_offset -+ int npq_offset = (row_offset + thread_start_row_) % params_.tiled_rows_per_filter; -+ -+ // (STEP 4.a) [order NHW rows to be loaded and stored in output Dx NHWxC layout] -+ int n = npq_offset / (p_ * q_); -+ int residual = npq_offset % (p_ * q_); -+ int p = residual / q_; -+ int q = residual % q_; -+ -+ int mapped_row_offset = n * (params_.problem_size.H * params_.problem_size.W) + -+ (start_h_ + p * params_.problem_size.stride_h) * params_.problem_size.W + -+ (start_w_ + q * params_.problem_size.stride_w); -+ bool row_guard = mapped_row_offset < extent_row_; -+ -+ int64_t row_byte_offset = mapped_row_offset * params_.stride; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ int64_t column_byte_offset = (thread_start_column_ + column * ThreadMap::Delta::kColumn) * (sizeof_bits::value / 8); -+ -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + -+ column], -+ (void *)(byte_pointer + row_byte_offset + column_byte_offset + byte_offset), -+ guard); -+ } -+ } -+ } -+ } -+ } -+ -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) { -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow -+ + group * ThreadMap::Delta::kGroup -+ + cluster * ThreadMap::Delta::kCluster; -+ -+ // remapping rows to find the mapped_row_offset -+ int npq_offset = (row_offset + thread_start_row_) % params_.tiled_rows_per_filter; -+ -+ // (STEP 4.a) [order NHW rows to be loaded and stored in output Dx NHWxC layout] -+ int n = npq_offset / (p_ * q_); -+ int residual = npq_offset % (p_ * q_); -+ int p = residual / q_; -+ int q = residual % q_; -+ -+ int mapped_row_offset = n * (params_.problem_size.H * params_.problem_size.W) + -+ (start_h_ + p * params_.problem_size.stride_h) * params_.problem_size.W + -+ (start_w_ + q * params_.problem_size.stride_w); -+ bool row_guard = mapped_row_offset < extent_row_; -+ -+ int64_t row_byte_offset = mapped_row_offset * 
params_.stride; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ int64_t column_byte_offset = (thread_start_column_ + column * ThreadMap::Delta::kColumn) * (sizeof_bits::value / 8); -+ -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ cutlass::arch::global_store( -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], -+ (void *)(byte_pointer + row_byte_offset + column_byte_offset + byte_offset), -+ guard); -+ } -+ } -+ } -+ } -+ } -+ -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ -+ store_with_byte_offset(frag, 0); -+ } -+ -+ /// Advances to the next position to load or store -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorStridedDgrad &operator++() { -+ -+ ++state_[0]; -+ -+ thread_start_row_ += ThreadMap::Shape::kRow; -+ -+ if (state_[0] == ThreadMap::Count::kRow) { -+ -+ state_[0] = 0; -+ ++state_[1]; -+ -+ thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * -+ ThreadMap::Shape::kRow * ThreadMap::Count::kRow; -+ -+ if (state_[1] == ThreadMap::Count::kGroup) { -+ -+ state_[1] = 0; -+ ++state_[2]; -+ -+ thread_start_row_ += ThreadMap::Count::kGroup * -+ ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; -+ -+ if (state_[2] == ThreadMap::Count::kCluster) { -+ state_[2] = 0; -+ } -+ } -+ } -+ -+ return *this; -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_DEVICE void clear_mask() { -+ mask_.clear(); -+ } -+ -+ ///< Efficiently enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable_mask() { -+ mask_.enable(); -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void get_mask(Mask &mask) { -+ mask = mask_; -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void set_mask(Mask const &mask) { -+ mask_ = mask; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator.h -new file mode 100644 -index 0000000..197a4df ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator.h -@@ -0,0 +1,223 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+ -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load output tile from shared memory in epilogue. -+/// -+/// Satisfies: ReadableTileIterator -+/// -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ typename Element_, ///< Element data type -+ int MaxAlignment = ThreadMap_::kElementsPerAccess * sizeof_bits::value / 8 -+> -+class SharedLoadIterator { -+public: -+ using ThreadMap = ThreadMap_; -+ using Shape = typename ThreadMap::TileShape; -+ -+ using Element = Element_; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ static int const kMinAlignment = ThreadMap_::kElementsPerAccess * sizeof_bits::value / 8; -+ -+ static int const kAlignment = (MaxAlignment < kMinAlignment ? 
MaxAlignment : kMinAlignment); -+ -+ static int const kThreads = ThreadMap::kThreads; -+ -+ /// Fragment object -+ using Fragment = Array< -+ Element, -+ ThreadMap::Iterations::kColumn * -+ ThreadMap::Iterations::kRow * -+ ThreadMap::Iterations::kGroup * -+ ThreadMap::Iterations::kCluster * -+ ThreadMap::kElementsPerAccess>; -+ -+ /// Memory access size -+ using AccessType = AlignedArray< -+ Element, -+ ThreadMap::kElementsPerAccess, -+ kAlignment>; -+ -+ /// Vector type used for SMEM loads -+ using LoadType = AlignedArray< -+ Element, -+ const_min(128 / sizeof_bits::value, ThreadMap::kElementsPerAccess), -+ const_min(16, kAlignment) -+ >; -+ -+ static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Byte-level pointer -+ uint8_t *byte_pointer_; -+ -+ /// Stride along adjacent rows -+ int stride_; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ SharedLoadIterator( -+ TensorRef ref, -+ int thread_idx -+ ): -+ byte_pointer_(reinterpret_cast(ref.data())), -+ stride_((ref.stride(0) * sizeof_bits::value) / 8) { -+ -+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); -+ -+ // Initialize pointer -+ byte_pointer_ += -+ thread_offset.row() * stride_ + -+ thread_offset.column() * sizeof(AccessType) / kElementsPerAccess; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &offset) { -+ byte_pointer_ += -+ offset.row() * Shape::kRow * stride_ + -+ offset.column() * Shape::kColumn * sizeof_bits::value / 8; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ uint8_t const *byte_pointer = byte_pointer_ + -+ row * ThreadMap::Delta::kRow * stride_ + -+ group * ThreadMap::Delta::kGroup* stride_ + -+ cluster * ThreadMap::Delta::kCluster * stride_ + -+ pointer_offset * sizeof_bits::value / 8; -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ LoadType *frag_ptr = reinterpret_cast(&frag); -+ LoadType const *memory_pointer = reinterpret_cast(byte_pointer); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kLoadsPerAccess; ++v) { -+ frag_ptr[frag_idx * kLoadsPerAccess + v] = -+ memory_pointer[(column * ThreadMap::Delta::kColumn / kElementsPerAccess) * kLoadsPerAccess + v]; -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+ -+ /// Loads a fragment -+ CUTLASS_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_pointer_offset(frag, 0); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ 
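Editor's aside (not part of the patch): the load loop in SharedLoadIterator above walks the (cluster, group, row) hierarchy and turns each step into a byte offset along the row-major shared-memory tile. A minimal sketch of that flattening, with hypothetical delta and iteration counts standing in for the ThreadMap values, follows.

#include <cstdint>
#include <cstdio>

// Hypothetical stand-ins for the ThreadMap's Iterations and Delta values.
constexpr int kRowIters = 2, kGroupIters = 2, kClusterIters = 1;
constexpr int kDeltaRow = 8, kDeltaGroup = 32, kDeltaCluster = 0;

// Mirrors the addressing in load_with_pointer_offset(): each (cluster, group, row)
// step advances the pointer by a whole number of row strides.
int64_t row_byte_offset(int cluster, int group, int row, int64_t stride_bytes) {
  return (int64_t(row) * kDeltaRow +
          int64_t(group) * kDeltaGroup +
          int64_t(cluster) * kDeltaCluster) * stride_bytes;
}

// Mirrors frag_row_idx: the position of this (cluster, group, row) step inside the fragment.
int fragment_row_index(int cluster, int group, int row) {
  return row + kRowIters * (group + kGroupIters * cluster);
}

int main() {
  const int64_t stride_bytes = 512;  // arbitrary row stride for the example
  for (int g = 0; g < kGroupIters; ++g)
    for (int r = 0; r < kRowIters; ++r)
      std::printf("group=%d row=%d -> frag_row=%d byte_offset=%lld\n",
                  g, r, fragment_row_index(0, g, r),
                  static_cast<long long>(row_byte_offset(0, g, r, stride_bytes)));
  return 0;
}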
-+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h -new file mode 100644 -index 0000000..a471137 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h -@@ -0,0 +1,585 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops optimized for mixed-precision. -+ -+ This assumes the shared memory tile is in a permuted layout which avoids bank conflicts on loading. -+ -+ When the fragment is loaded into registers, it matches the row-major thread map assumed by -+ the predicated tile iterator writing to global memory. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. 
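Editor's aside (not part of the patch): the iterator defined below assumes the epilogue wrote its shared-memory tile in a permuted order, so each of the kLoadsPerAccess pointers is rotated into a different 128-byte bank group. A standalone sketch of that column swizzle, using assumed values for kElementsPerAccess, kLoadsPerAccess, and the LoadType width, is shown here.

#include <cstdio>

// Assumed values for illustration: 8 elements per access, split into 2 vectorized 16-byte loads.
constexpr int kElementsPerAccess = 8;
constexpr int kLoadsPerAccess = 2;
constexpr int kLoadTypeBytes = 16;

// Mirrors the pointer setup in the generic SharedLoadIteratorMixed constructor:
// the i-th load pointer is rotated within its 128-byte bank group so that loads
// issued by neighbouring threads do not land on the same bank.
int swizzled_column(int thread_column, int i) {
  int col_idx = (thread_column / kElementsPerAccess) * kLoadsPerAccess;
  int bank_offset = (col_idx * kLoadTypeBytes / 128) % kLoadsPerAccess;
  return col_idx + (bank_offset + i) % kLoadsPerAccess;
}

int main() {
  for (int thread_column = 0; thread_column < 64; thread_column += kElementsPerAccess) {
    std::printf("column %2d -> load0 at %2d, load1 at %2d\n",
                thread_column, swizzled_column(thread_column, 0), swizzled_column(thread_column, 1));
  }
  return 0;
}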
-+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+ -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load output tile from shared memory in epilogue. -+/// -+/// Satisfies: ReadableTileIterator -+/// -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ typename Element_, ///< Accumulator data type -+ int ElementSizeBits_, ///< Size of accumulator in bits -+ int OutputSizeBits_, ///< Size of output element in bits -+ int ElementsPerAccess, ///< Vector length of output vector -+ int ContiguousLanes ///< Number of lanes in the warp writing to contiguous elements -+ /// in the global memory tensor -+> -+class SharedLoadIteratorMixed; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load output tile from shared memory in epilogue. -+/// -+/// Satisfies: ReadableTileIterator -+/// -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ typename Element_ ///< Accumulator data type -+> -+class SharedLoadIteratorMixed { -+public: -+ using ThreadMap = ThreadMap_; -+ using Shape = typename ThreadMap::Shape; -+ -+ using Element = Element_; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits::value / 8; -+ -+ static int const kThreads = ThreadMap::kThreads; -+ -+ /// Fragment object -+ using Fragment = Array< -+ Element, -+ ThreadMap::Iterations::kColumn * -+ ThreadMap::Iterations::kRow * -+ ThreadMap::Iterations::kGroup * -+ ThreadMap::Iterations::kCluster * -+ ThreadMap::kElementsPerAccess>; -+ -+ /// Memory access size -+ using AccessType = AlignedArray< -+ Element, -+ ThreadMap::kElementsPerAccess, -+ kAlignment>; -+ -+ /// Vector type used for SMEM loads -+ using LoadType = AlignedArray< -+ Element, -+ const_min(128 / sizeof_bits::value, ThreadMap::kElementsPerAccess), -+ const_min(16, kAlignment) -+ >; -+ -+ static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Byte-level pointer -+ LoadType const *pointers_[kLoadsPerAccess]; -+ -+ /// Stride along adjacent rows in units of LoadType -+ int stride_; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ SharedLoadIteratorMixed( -+ TensorRef ref, -+ int thread_idx -+ ): -+ stride_((ref.stride(0) / LoadType::kElements)) { -+ -+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); -+ -+ // Initialize pointers -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kLoadsPerAccess; ++i) { -+ pointers_[i] = reinterpret_cast(ref.data()); -+ -+ int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; -+ int 
bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess; -+ -+ col_idx += (bank_offset + i) % kLoadsPerAccess; -+ -+ pointers_[i] += thread_offset.row() * stride_ + col_idx; -+ } -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kLoadsPerAccess; ++i) { -+ pointers_[i] += pointer_offset / LoadType::kElements; -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &offset) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kLoadsPerAccess; ++i) { -+ pointers_[i] += -+ offset.row() * Shape::kRow * stride_ + -+ offset.column() * Shape::kColumn / LoadType::kElements; -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int row_ptr_offset = -+ row * ThreadMap::Delta::kRow * stride_ + -+ group * ThreadMap::Delta::kGroup* stride_ + -+ cluster * ThreadMap::Delta::kCluster * stride_ + -+ pointer_offset / LoadType::kElements; -+ -+ int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ LoadType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kLoadsPerAccess; ++v) { -+ -+ int vector_idx = (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess); -+ -+ LoadType const *memory_pointer = pointers_[v] + row_ptr_offset; -+ -+ frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx]; -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ /// Set base smem address -+ CUTLASS_DEVICE -+ void set_smem_base_address(Index address) {} -+ -+ /// Loads a fragment -+ CUTLASS_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_pointer_offset(frag, 0); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for int32_t x 16 => int8_t/int4b_t x 16 -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ int OutputSizeBits_ ///< Size of output element in bits -+> -+class SharedLoadIteratorMixed { -+public: -+ using ThreadMap = ThreadMap_; -+ using Shape = typename ThreadMap::Shape; -+ -+ using Element = int32_t; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ static int const kAlignment = 16; -+ -+ static int const kThreads = ThreadMap::kThreads; -+ -+ /// Fragment object -+ using Fragment = Array< -+ Element, -+ ThreadMap::Iterations::kColumn * -+ ThreadMap::Iterations::kRow * -+ ThreadMap::Iterations::kGroup * -+ ThreadMap::Iterations::kCluster * -+ ThreadMap::kElementsPerAccess>; -+ -+ /// Memory access size -+ using AccessType = AlignedArray< -+ Element, -+ 
16, -+ kAlignment>; -+ -+ /// Vector type used for SMEM loads -+ using LoadType = AlignedArray< -+ Element, -+ 4, -+ 16 -+ >; -+ -+ static int const kLoadsPerAccess = 4; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Byte-level pointer -+ LoadType const *pointers_[kLoadsPerAccess]; -+ -+ /// Stride along adjacent rows in units of LoadType -+ int stride_; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ SharedLoadIteratorMixed( -+ TensorRef ref, -+ int thread_idx -+ ): -+ stride_((ref.stride(0) / LoadType::kElements)) { -+ -+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); -+ -+ // Initialize pointers -+ LoadType const *base_ptr = reinterpret_cast(ref.data()) + thread_offset.row() * stride_; -+ -+ int lane_col_idx = thread_offset.column() / 16; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kLoadsPerAccess; ++i) { -+ int lane_offset = (lane_col_idx % 2) * 4 | ((lane_col_idx / 2) * 8) | ((lane_col_idx / 2) ^ i); -+ -+ pointers_[i] = base_ptr + lane_offset; -+ } -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kLoadsPerAccess; ++i) { -+ pointers_[i] += pointer_offset / LoadType::kElements; -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &offset) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kLoadsPerAccess; ++i) { -+ pointers_[i] += -+ offset.row() * Shape::kRow * stride_ + -+ offset.column() * Shape::kColumn / LoadType::kElements; -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int row_ptr_offset = -+ row * ThreadMap::Delta::kRow * stride_ + -+ group * ThreadMap::Delta::kGroup* stride_ + -+ cluster * ThreadMap::Delta::kCluster * stride_ + -+ pointer_offset / LoadType::kElements; -+ -+ int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ LoadType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kLoadsPerAccess; ++v) { -+ -+ LoadType const *memory_pointer = pointers_[v]; -+ -+ frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[row_ptr_offset]; -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ /// Set base smem address -+ CUTLASS_DEVICE -+ void set_smem_base_address(Index address) {} -+ -+ /// Loads a fragment -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ -+ load_with_pointer_offset(frag, 0); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for int32_t x 8 => int8_t/int4b_t x 8 -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ int OutputSizeBits_ -+> -+class SharedLoadIteratorMixed { -+public: -+ using ThreadMap = ThreadMap_; -+ using Shape = typename ThreadMap::Shape; -+ -+ using Element = int32_t; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = 
TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ static int const kAlignment = 8; -+ -+ static int const kThreads = ThreadMap::kThreads; -+ -+ /// Fragment object -+ using Fragment = Array< -+ Element, -+ ThreadMap::Iterations::kColumn * -+ ThreadMap::Iterations::kRow * -+ ThreadMap::Iterations::kGroup * -+ ThreadMap::Iterations::kCluster * -+ ThreadMap::kElementsPerAccess>; -+ -+ /// Memory access size -+ using AccessType = AlignedArray< -+ Element, -+ 8, -+ kAlignment>; -+ -+ /// Vector type used for SMEM loads -+ using LoadType = AlignedArray< -+ Element, -+ 4, -+ 16 -+ >; -+ -+ static int const kLoadsPerAccess = 2; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Byte-level pointer -+ LoadType const *pointers_[kLoadsPerAccess]; -+ -+ /// Stride along adjacent rows in units of LoadType -+ int stride_; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ SharedLoadIteratorMixed( -+ TensorRef ref, -+ int thread_idx -+ ): -+ stride_((ref.stride(0) / LoadType::kElements)) { -+ -+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); -+ -+ // Initialize pointers -+ LoadType const *base_ptr = reinterpret_cast(ref.data()) + thread_offset.row() * stride_; -+ -+ int lane_col_idx = thread_offset.column() / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kLoadsPerAccess; ++i) { -+ int lane_offset = (lane_col_idx % 8) * 2 | ((lane_col_idx / 4) ^ i); -+ -+ pointers_[i] = base_ptr + lane_offset; -+ } -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kLoadsPerAccess; ++i) { -+ pointers_[i] += pointer_offset / LoadType::kElements; -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &offset) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kLoadsPerAccess; ++i) { -+ pointers_[i] += -+ offset.row() * Shape::kRow * stride_ + -+ offset.column() * Shape::kColumn / LoadType::kElements; -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int row_ptr_offset = -+ row * ThreadMap::Delta::kRow * stride_ + -+ group * ThreadMap::Delta::kGroup* stride_ + -+ cluster * ThreadMap::Delta::kCluster * stride_ + -+ pointer_offset / LoadType::kElements; -+ -+ int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ LoadType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kLoadsPerAccess; ++v) { -+ -+ LoadType const *memory_pointer = pointers_[v]; -+ -+ frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[row_ptr_offset]; -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ /// Set base smem address -+ CUTLASS_DEVICE -+ void 
set_smem_base_address(Index address) {} -+ -+ /// Loads a fragment -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ -+ load_with_pointer_offset(frag, 0); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h -new file mode 100644 -index 0000000..df8676e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h -@@ -0,0 +1,194 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ This assumes the shared memory tile is in a permuted layout which avoids bank conflicts on loading. -+ -+ When the fragment is loaded into registers, it matches the row-major thread map assumed by -+ the predicated tile iterator writing to global memory. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load output tile from shared memory in epilogue. -+/// -+/// Satisfies: ReadableTileIterator -+/// -+template ::value / 8> -+class SharedLoadIteratorPitchLiner { -+ public: -+ using ThreadMap = ThreadMap_; -+ using Element = Element_; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ static int const kMinAlignment = -+ ThreadMap_::kElementsPerAccess * sizeof_bits::value / 8; -+ -+ static int const kAlignment = (MaxAlignment < kMinAlignment ? MaxAlignment : kMinAlignment); -+ -+ static int const kThreads = ThreadMap::kThreads; -+ -+ /// Fragment object -+ using Fragment = Array; -+ -+ /// Memory access size -+ using AccessType = AlignedArray; -+ -+ /// Vector type used for SMEM loads -+ using LoadType = -+ AlignedArray::value, ThreadMap::kElementsPerAccess), -+ const_min(16, kAlignment)>; -+ -+ static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Byte-level pointer -+ uint8_t *byte_pointer_; -+ -+ /// Stride along adjacent rows -+ int stride_; -+ -+ /// Base address offset -+ Index base_smem_address_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ SharedLoadIteratorPitchLiner(TensorRef ref, int thread_idx) -+ : byte_pointer_(reinterpret_cast(ref.data())), -+ stride_((ref.stride(0) * sizeof_bits::value) / 8), -+ base_smem_address_(0) { -+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); -+ -+ // Initialize pointer -+ // thread_offset.row() is contiguous dim -+ // thread_offset.column() is stride dim -+ byte_pointer_ += thread_offset.row() * sizeof(AccessType) / kElementsPerAccess+ -+ thread_offset.column() * stride_ ; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &offset) { -+ byte_pointer_ += -+ offset.row() * ThreadMap::StorageShape::kContiguous * sizeof(AccessType) / kElementsPerAccess + -+ offset.column() * ThreadMap::StorageShape::kStrided * stride_; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ uint8_t const *byte_pointer = -+ byte_pointer_ + s * ThreadMap::Delta::kStrided * stride_ + -+ c * ThreadMap::Delta::kContiguous * 
ThreadMap::kElementsPerAccess * -+ sizeof_bits::value / 8 + -+ pointer_offset * sizeof_bits::value / 8 + base_smem_address_; -+ -+ int frag_base_idx = s * ThreadMap::Iterations::kContiguous + c; -+ -+ LoadType *frag_ptr = reinterpret_cast(&frag); -+ -+ LoadType const *memory_pointer = reinterpret_cast(byte_pointer); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kLoadsPerAccess; ++v) { -+ frag_ptr[frag_base_idx * kLoadsPerAccess + v] = memory_pointer[v]; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void set_smem_base_address(Index address) { base_smem_address_ = address; } -+ -+ /// Loads a fragment -+ CUTLASS_DEVICE -+ void load(Fragment &frag) const { load_with_pointer_offset(frag, 0); } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h -new file mode 100644 -index 0000000..6dd04ed ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h -@@ -0,0 +1,187 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile -+ that participate in one warp-level store operation. -+ -+ Typically, the accumulator tile is the largest single block of register-backed storage -+ within the kernel. 
Storing it to memory is best accomplished by partitioning it into -+ smaller tiles and storing these sequentially. -+ -+ Round trips through shared memory during the Epilogue phase require partitioning, as -+ shared memory capacity is typically insufficient for a threadblock's total accumulator -+ size. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/epilogue/warp/tensor_op_policy.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorElementC, ///< matrix multiply operation data type (concept: data type) -+ typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: Array) -+ typename Layout ///< target shared memory layout -+> -+class FragmentIteratorComplexTensorOp; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Partial specialization for row-major shared memory -+template < -+ typename WarpShape_, ///< shape of the warp-level GEMM tile -+ typename OperatorShape_, ///< underlying real-valued matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorElementC_, ///< underlying real-valued matrix multiply operation data type -+ typename OperatorFragmentC_ ///< underlying real-valued matrix multiply operation fragment (concept: Array) -+> -+class FragmentIteratorComplexTensorOp { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using OperatorElementC = OperatorElementC_; -+ using OperatorFragmentC = OperatorFragmentC_; -+ using Layout = layout::RowMajor; -+ -+ using Policy = TensorOpPolicy; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ complex, -+ Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; -+ -+ static int const kRealIndex = 0; -+ -+ /// Offset into the accumulator fragment -+ static int const kImaginaryIndex = -+ OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn; -+ -+ /// This is the complete warp-level accumulator tile. -+ using AccumulatorTile = Array; -+ -+ /// This is the complete warp-level accumulator tile. 
-+ using OutputAccumulatorTile = Array, kImaginaryIndex>; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+private: -+ -+ /// Internal access type -+ using AccessType = Array; -+ -+ using FragmentAccessType = Array, Policy::kElementsPerAccess>; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+public: -+ -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorComplexTensorOp(AccumulatorTile const &accum): -+ accumulators_(reinterpret_cast(&accum)), -+ index_(0) { -+ -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorComplexTensorOp &operator++() { -+ ++index_; -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorComplexTensorOp &operator--() { -+ --index_; -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, int index_offset = 0) const { -+ -+ int index = index_ + index_offset; -+ -+ FragmentAccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ -+ int accumulator_access_offset = -+ index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess; -+ -+ auto const & real_accum_array = accumulators_[accumulator_access_offset + kRealIndex]; -+ auto const & imag_accum_array = accumulators_[accumulator_access_offset + kImaginaryIndex / Policy::kElementsPerAccess]; -+ -+ // Pack real and imaginary parts into a structure. This is likely to result in MOVs -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Policy::kElementsPerAccess; ++i) { -+ -+ frag_ptr[n][i].real() = real_accum_array[i]; -+ frag_ptr[n][i].imag() = imag_accum_array[i]; -+ } -+ } -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h -new file mode 100644 -index 0000000..f55c4bd ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h -@@ -0,0 +1,194 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
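The complex fragment iterator above reads two planar halves of the accumulator tile, real parts starting at kRealIndex and imaginary parts starting at kImaginaryIndex, and packs them element-wise into fragments of complex values. The following host-side sketch reproduces that packing pattern with assumed sizes (8 accumulator elements per plane, 2 elements per access); the constants are illustrative and not taken from any particular Policy instantiation.

// Standalone sketch of the planar-to-complex packing performed by
// FragmentIteratorComplexTensorOp::load(). Sizes are illustrative assumptions.
#include <array>
#include <complex>
#include <cstdio>

int main() {
  constexpr int kElementsPerPlane = 8;   // assumed accumulator elements per plane
  constexpr int kElementsPerAccess = 2;  // assumed elements packed per access

  // Planar accumulator: real parts first, imaginary parts at kImaginaryIndex.
  std::array<float, 2 * kElementsPerPlane> accum{};
  constexpr int kRealIndex = 0;
  constexpr int kImaginaryIndex = kElementsPerPlane;
  for (int i = 0; i < kElementsPerPlane; ++i) {
    accum[kRealIndex + i] = float(i);            // real plane
    accum[kImaginaryIndex + i] = float(10 * i);  // imaginary plane
  }

  // One "fragment": kElementsPerAccess consecutive complex values selected by
  // an access index n (analogous to the column loop inside load()).
  int n = 1;
  std::array<std::complex<float>, kElementsPerAccess> frag;
  for (int i = 0; i < kElementsPerAccess; ++i) {
    int idx = n * kElementsPerAccess + i;
    frag[i] = {accum[kRealIndex + idx], accum[kImaginaryIndex + idx]};
  }

  for (auto const &c : frag) {
    std::printf("(%g, %g)\n", c.real(), c.imag());
  }
  return 0;
}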
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile -+ that participate in one warp-level store operation. -+ -+ Typically, the accumulator tile is the largest single block of register-backed storage -+ within the kernel. Storing it to memory is best accomplished by partitioning it into -+ smaller tiles and storing these sequentially. -+ -+ Round trips through shared memory during the Epilogue phase require partitioning, as -+ shared memory capacity is typically insufficient for a threadblock's total accumulator -+ size. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/epilogue/warp/tensor_op_policy.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorElementC, ///< matrix multiply operation data type (concept: data type) -+ typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: Array) -+ typename Layout ///< target shared memory layout -+> -+class FragmentIteratorGaussianComplexTensorOp; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Partial specialization for row-major shared memory -+template < -+ typename WarpShape_, ///< shape of the warp-level GEMM tile -+ typename OperatorShape_, ///< underlying real-valued matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorElementC_, ///< underlying real-valued matrix multiply operation data type -+ typename OperatorFragmentC_ ///< underlying real-valued matrix multiply operation fragment (concept: Array) -+> -+class FragmentIteratorGaussianComplexTensorOp { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using OperatorElementC = OperatorElementC_; -+ using OperatorFragmentC = OperatorFragmentC_; -+ using Layout = layout::RowMajor; -+ -+ using Policy = TensorOpPolicy; -+ -+ /// This is the fragment size produced by one access of the iterator. 
-+ using Fragment = Array< -+ complex, -+ Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; -+ -+ /// Size of one part of accumulator of 3-part accumulator in units of number of OperatorElementC -+ static int const kElementsAccumulatorPerPart = -+ OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn; -+ -+ /// Offset into the accumulator fragment part 1 -+ static int const kPart1Index = kElementsAccumulatorPerPart * 0; -+ -+ /// Offset into the accumulator fragment part 2 -+ static int const kPart2Index = kElementsAccumulatorPerPart * 1; -+ -+ /// Offset into the accumulator fragment part 3 -+ static int const kPart3Index = kElementsAccumulatorPerPart * 2; -+ -+ /// This is the complete warp-level accumulator tile holding part1, part2, and part3 -+ using AccumulatorTile = Array; -+ -+ /// This is the complete warp-level accumulator tile holding final output of complex type -+ using OutputAccumulatorTile = Array, kElementsAccumulatorPerPart>; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+private: -+ -+ /// Internal access type -+ using AccessType = Array; -+ -+ using FragmentAccessType = Array, Policy::kElementsPerAccess>; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+public: -+ -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorGaussianComplexTensorOp(AccumulatorTile const &accum): -+ accumulators_(reinterpret_cast(&accum)), -+ index_(0) { -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorGaussianComplexTensorOp &operator++() { -+ ++index_; -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorGaussianComplexTensorOp &operator--() { -+ --index_; -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, int index_offset = 0) const { -+ -+ int index = index_ + index_offset; -+ -+ FragmentAccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ -+ int accumulator_access_offset = -+ index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess; -+ -+ auto const & part1_accum_array = accumulators_[accumulator_access_offset + kPart1Index]; -+ auto const & part2_accum_array = accumulators_[accumulator_access_offset + kPart2Index / Policy::kElementsPerAccess]; -+ auto const & part3_accum_array = accumulators_[accumulator_access_offset + kPart3Index / Policy::kElementsPerAccess]; -+ -+ // Pack parts 1, 2, and 3 into a structure. 
This is likely to result in MOVs -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Policy::kElementsPerAccess; ++i) { -+ -+ frag_ptr[n][i].real() = part1_accum_array[i] - part3_accum_array[i]; -+ frag_ptr[n][i].imag() = part1_accum_array[i] + part2_accum_array[i]; -+ } -+ } -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_simt.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_simt.h -new file mode 100644 -index 0000000..b181c81 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_simt.h -@@ -0,0 +1,164 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile -+ that participate in one warp-level store operation. -+ -+ Typically, the accumulator tile is the largest single block of register-backed storage -+ within the kernel. Storing it to memory is best accomplished by partitioning it into -+ smaller tiles and storing these sequentially. -+ -+ Round trips through shared memory during the Epilogue phase require partitioning, as -+ shared memory capacity is typically insufficient for a threadblock's total accumulator -+ size. 
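The recombination above, real = part1 - part3 and imag = part1 + part2, is the final step of the 3-multiplication (Gaussian) complex product. The scalar sketch below verifies the identity under one common assignment of the three partial products; the actual partials are formed by the warp-level mainloop and are only an assumption here.

// Scalar illustration of the 3-part (Gaussian) complex multiply whose results
// the iterator above recombines. The assignment of part1..part3 is assumed.
#include <cassert>

int main() {
  float ar = 2.0f, ai = -3.0f;  // A = ar + i*ai
  float br = 5.0f, bi = 7.0f;   // B = br + i*bi

  // Three real multiplications instead of four:
  float part1 = br * (ar + ai);
  float part2 = ar * (bi - br);
  float part3 = ai * (br + bi);

  // Reference product computed the usual way.
  float real_ref = ar * br - ai * bi;
  float imag_ref = ar * bi + ai * br;

  // Same recombination as FragmentIteratorGaussianComplexTensorOp::load().
  assert(part1 - part3 == real_ref);
  assert(part1 + part2 == imag_ref);
  return 0;
}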
-+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/epilogue/warp/simt_policy.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fragment iterator for SIMT accumulator arrangements -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename Operator, ///< matrix multiply operation (concept: arch::Mma) -+ typename Layout, ///< target shared memory layout -+ typename MmaSimtPolicy ///< policy defining lane arrangement (concept: MmaSimtPolicy) -+> -+class FragmentIteratorSimt; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for row-major shared memory -+template < -+ typename WarpShape_, ///< shape of the warp-level GEMM tile -+ typename Operator_ , ///< matrix multiply operator (concept: arch::Mma) -+ typename MmaSimtPolicy_ ///< policy defining lane arrangement (concept: MmaSimtPolicy) -+> -+class FragmentIteratorSimt { -+public: -+ -+ using WarpShape = WarpShape_; -+ using Operator = Operator_; -+ using Layout = layout::RowMajor; -+ -+ /// Policy for warp-level epilogue components -+ using Policy = SimtPolicy; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ typename Operator::ElementC, -+ Policy::kElementsPerIteration>; -+ -+ /// This is the complete warp-level accumulator tile. -+ using AccumulatorTile = Array< -+ typename Operator::ElementC, -+ Policy::kAccumulatorElementCount>; -+ -+ using OutputAccumulatorTile = AccumulatorTile; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+private: -+ -+ /// Internal access type -+ using AccessType = Array; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+public: -+ -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorSimt(AccumulatorTile const &accum): -+ accumulators_(reinterpret_cast(&accum)), -+ index_(0) { -+ -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorSimt &operator++() { -+ ++index_; -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorSimt &operator--() { -+ --index_; -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, int index_offset = 0) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { -+ -+ int accumulator_access_offset = index_ * Policy::kAccessesPerIteration + n; -+ -+ frag_ptr[n] = accumulators_[accumulator_access_offset]; -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h -new file mode 100644 -index 
0000000..f9b20a6 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h -@@ -0,0 +1,277 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile -+ that participate in one warp-level store operation. -+ -+ Typically, the accumulator tile is the largest single block of register-backed storage -+ within the kernel. Storing it to memory is best accomplished by partitioning it into -+ smaller tiles and storing these sequentially. -+ -+ Round trips through shared memory during the Epilogue phase require partitioning, as -+ shared memory capacity is typically insufficient for a threadblock's total accumulator -+ size. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/epilogue/warp/tensor_op_policy.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorElementC, ///< matrix multiply operation data type (concept: data type) -+ typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: Array) -+ typename Layout ///< target shared memory layout -+> -+class FragmentIteratorTensorOp; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for row-major shared memory -+template < -+ typename WarpShape_, ///< shape of the warp-level GEMM tile -+ typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorElementC_, ///< matrix multiply operation data type (concept: data type) -+ typename OperatorFragmentC_ ///< matrix multiply operation fragment (concept: Array) -+> -+class FragmentIteratorTensorOp { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using OperatorElementC = OperatorElementC_; -+ using OperatorFragmentC = OperatorFragmentC_; -+ using Layout = layout::RowMajor; -+ -+ using Policy = TensorOpPolicy; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ OperatorElementC, -+ Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; -+ -+ /// This is the complete warp-level accumulator tile. 
-+ using AccumulatorTile = Array< -+ OperatorElementC, -+ OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn>; -+ -+ using OutputAccumulatorTile = AccumulatorTile; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ using TileIterations = typename Policy::TileIterations; -+ static int const kIterationsPerTile = kIterations / TileIterations::kCount; -+ -+private: -+ -+ /// Internal access type -+ using AccessType = Array; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+public: -+ -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorTensorOp(AccumulatorTile const &accum): -+ accumulators_(reinterpret_cast(&accum)), -+ index_(0) { -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorTensorOp &operator++() { -+ ++index_; -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorTensorOp &operator--() { -+ --index_; -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, int index_offset = 0) const { -+ -+ int index = index_ + index_offset; -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ -+ int accumulator_access_offset = -+ index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess; -+ -+ frag_ptr[n] = accumulators_[accumulator_access_offset]; -+ } -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Dedicated to interleaved layout -+template < -+ /// shape of the warp-level GEMM tile -+ typename WarpShape_, -+ /// matrix multiply operator shape (concept: gemm::GemmShape) -+ typename OperatorShape_, -+ /// matrix multiply operator data type (concept: data type) -+ typename OperatorElementC_, -+ /// matrix multiply operator fragment (concept: Array) -+ typename OperatorFragmentC_, -+ /// number of interleaved k -+ int InterleavedK> -+class FragmentIteratorTensorOp> { -+ public: -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using OperatorElementC = OperatorElementC_; -+ using OperatorFragmentC = OperatorFragmentC_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::ColumnMajorInterleaved; -+ -+ using Policy = TensorOpPolicy; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = -+ Array; -+ -+ /// This is the complete warp-level accumulator tile. 
-+ using AccumulatorTile = -+ Array; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ using TileIterations = typename Policy::TileIterations; -+ static int const kIterationsPerTile = kIterations / TileIterations::kCount; -+ -+ private: -+ /// Internal access type -+ using AccessType = -+ Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+ public: -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorTensorOp(AccumulatorTile const &accum) -+ : accumulators_(reinterpret_cast(&accum)), -+ index_(0) {} -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorTensorOp &operator++() { -+ ++index_; -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorTensorOp &operator--() { -+ --index_; -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, int index_offset = 0) const { -+ int index = index_ + index_offset; -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < (InterleavedK / OperatorShape::kN); ++n) { -+ int index_m = index % (Policy::OperatorCount::kRow * -+ Policy::kIterationsPerInstruction); -+ int index_n = index / (Policy::OperatorCount::kRow * -+ Policy::kIterationsPerInstruction); -+ int accumulator_access_offset = -+ (index_m / Policy::kIterationsPerInstruction) * -+ (Policy::OperatorCount::kColumn * -+ Policy::kIterationsPerInstruction) + -+ (index_m % Policy::kIterationsPerInstruction) + -+ index_n * (InterleavedK / OperatorShape::kN) * -+ Policy::kIterationsPerInstruction + -+ n * Policy::kIterationsPerInstruction; -+ -+ frag_ptr[n] = accumulators_[accumulator_access_offset]; -+ } -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h -new file mode 100644 -index 0000000..d37e82e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h -@@ -0,0 +1,269 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile -+ that participate in one warp-level store operation. -+ -+ Typically, the accumulator tile is the largest single block of register-backed storage -+ within the kernel. Storing it to memory is best accomplished by partitioning it into -+ smaller tiles and storing these sequentially. -+ -+ Round trips through shared memory during the Epilogue phase require partitioning, as -+ shared memory capacity is typically insufficient for a threadblock's total accumulator -+ size. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/warp/volta_tensor_op_policy.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename InterleavedTileShape, ///< shape of indivisible instruction-level arrangement (concept: GemmShape) -+ typename ElementC, ///< Accumulator layout -+ typename Layout ///< target shared memory layout -+> -+class FragmentIteratorVoltaTensorOp; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for row-major shared memory -+template < -+ typename WarpShape_ ///< shape of warp-level GEMM (concept: MatrixShape) -+> -+class FragmentIteratorVoltaTensorOp, half_t, layout::RowMajor> { -+public: -+ -+ using WarpShape = WarpShape_; -+ using InterleavedTileShape = gemm::GemmShape<32, 32, 4>; -+ using ElementC = half_t; -+ using Layout = layout::RowMajor; -+ -+ /// Policy operator -+ using Policy = VoltaTensorOpPolicy; -+ -+ /// Array type for aligned memory accesses -+ using AccessType = typename Policy::AccessType; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = typename Policy::Fragment; -+ -+ /// This is the complete warp-level accumulator tile. 
-+ using AccumulatorTile = typename Policy::AccumulatorTile; -+ -+ using OutputAccumulatorTile = AccumulatorTile; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+private: -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+public: -+ -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorVoltaTensorOp(AccumulatorTile const &accum): -+ accumulators_(reinterpret_cast(&accum)), -+ index_(0) { -+ -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorVoltaTensorOp &operator++() { -+ ++index_; -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorVoltaTensorOp &operator--() { -+ --index_; -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, int index_offset = 0) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ static int const kAccessesPerMma = Policy::kElementsPerMma / Policy::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { -+ -+ int tile_access_idx = -+ (tile_n * Policy::TileIterations::kRow + (index_ & 2) / 2) * Policy::MmaIterations::kCount * kAccessesPerMma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn * kAccessesPerMma; ++mma_n) { -+ -+ int mma_access_idx = ((mma_n & 1) * 2 + (index_ & 1)) * kAccessesPerMma + (mma_n & 2) / 2; -+ -+ frag_ptr[tile_n * Policy::MmaIterations::kColumn * kAccessesPerMma + -+ mma_n] = accumulators_[tile_access_idx + mma_access_idx]; -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for row-major shared memory -+template < -+ typename WarpShape_ ///< shape of warp-level GEMM (concept: MatrixShape) -+> -+class FragmentIteratorVoltaTensorOp, float, layout::RowMajor> { -+public: -+ -+ using WarpShape = WarpShape_; -+ using InterleavedTileShape = gemm::GemmShape<32, 32, 4>; -+ using ElementC = float; -+ using Layout = layout::RowMajor; -+ -+ /// Policy operator -+ using Policy = VoltaTensorOpPolicy; -+ -+ /// Array type for aligned memory accesses -+ using AccessType = typename Policy::AccessType; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = typename Policy::Fragment; -+ -+ /// This is the complete warp-level accumulator tile. 
-+ using AccumulatorTile = typename Policy::AccumulatorTile; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+private: -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+public: -+ -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorVoltaTensorOp(AccumulatorTile const &accum): -+ accumulators_(reinterpret_cast(&accum)), -+ index_(0) { -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorVoltaTensorOp &operator++() { -+ ++index_; -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorVoltaTensorOp &operator--() { -+ --index_; -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, int index_offset = 0) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ int const kRegsPerMmaRow = 2; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int reg_row = 0; reg_row < Policy::kRowsPerMmaTile; ++reg_row) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn * 2; ++mma_n) { -+ -+ int mma_idx = (index_ & 1) + (index_ & 2) * Policy::MmaIterations::kCount / 2 + -+ (tile_n * Policy::TileIterations::kRow) * Policy::MmaIterations::kCount + (mma_n & 1) * 2; -+ -+ int reg_offset = reg_row * kRegsPerMmaRow + (mma_n & 2) * 2; -+ int reg_idx = mma_idx * Policy::kElementsPerMma + reg_offset; -+ -+ *frag_ptr = accumulators_[reg_idx / Policy::kElementsPerAccess]; -+ ++frag_ptr; -+ } -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h -new file mode 100644 -index 0000000..225e0f0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h -@@ -0,0 +1,164 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile -+ that participate in one warp-level store operation. -+ -+ Typically, the accumulator tile is the largest single block of register-backed storage -+ within the kernel. Storing it to memory is best accomplished by partitioning it into -+ smaller tiles and storing these sequentially. -+ -+ Round trips through shared memory during the Epilogue phase require partitioning, as -+ shared memory capacity is typically insufficient for a threadblock's total accumulator -+ size. -+*/ -+ -+#pragma once -+ -+#if !(defined(__clang__) && defined(__CUDA__)) -+ -+#include "cutlass/wmma_array.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/epilogue/warp/wmma_tensor_op_policy.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorElementC, ///< matrix multiply operation data type (concept: data type) -+ typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: nvcuda::cuda::fragment) -+ typename Layout ///< target shared memory layout -+> -+class FragmentIteratorWmmaTensorOp; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for row-major shared memory -+template < -+ typename WarpShape_, ///< shape of the warp-level GEMM tile -+ typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorElementC_, ///< matrix multiply operation data type (concept: data type) -+ typename OperatorFragmentC_ ///< matrix multiply operation fragment (concept: nvcuda::cuda::fragment) -+> -+class FragmentIteratorWmmaTensorOp { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using OperatorElementC = OperatorElementC_; -+ using OperatorFragmentC = OperatorFragmentC_; -+ using Layout = layout::RowMajor; -+ -+ using Policy = WmmaTensorOpPolicy; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = WmmaFragmentArray; -+ -+ /// This is the complete warp-level accumulator tile. 
-+ using AccumulatorTile = WmmaFragmentArray; -+ -+ using OutputAccumulatorTile = AccumulatorTile; -+ -+private: -+ -+ /// Internal access type -+ using AccessType = WmmaFragmentArray; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+public: -+ -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorWmmaTensorOp(AccumulatorTile const &accum): -+ accumulators_(reinterpret_cast(&accum)), -+ index_(0) { -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorWmmaTensorOp &operator++() { -+ ++index_; -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorWmmaTensorOp &operator--() { -+ --index_; -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, int index_offset = 0) const { -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int n=0; n < Policy::OperatorCount::kColumn; n++) { -+ -+ int accumulator_access_offset = index_ * Policy::OperatorCount::kColumn + n; -+ -+ frag_ptr[n] = accumulators_[accumulator_access_offset]; -+ } -+ } -+}; -+ -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#else -+#error (defined(__clang__) && defined(__CUDA__)) -+#endif // !defined(__clang__) -+ -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/simt_policy.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/simt_policy.h -new file mode 100644 -index 0000000..21bca80 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/simt_policy.h -@@ -0,0 +1,107 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic structures needed for implementing the warp-scoped phase of the epilogue. -+ These quantities assume a 'column-major' arrangement of SimtOp instructions, of which -+ a row-oriented slice is visible per iteration. -+*/ -+ -+#pragma once -+ -+#include "cutlass/matrix_shape.h" -+#include "cutlass/layout/matrix.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: GemmShape) -+ typename Operator, ///< matrix multiply operation (concept: arch::Mma) -+ typename Layout, ///< destination layout in shared memory -+ typename MmaSimtPolicy ///< policy defining lane arrangement (concept: MmaSimtPolicy) -+> -+struct SimtPolicy; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for row-major -+template < -+ typename WarpShape_, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename Operator_, ///< matrix multiply operation (concept: arch::Mma) -+ typename MmaSimtPolicy_ ///< policy defining lane arrangement (concept: MmaSimtPolicy) -+> -+struct SimtPolicy { -+ -+ using WarpShape = WarpShape_; -+ using Operator = Operator_; -+ using MmaSimtPolicy = MmaSimtPolicy_; -+ -+ static_assert(!(WarpShape::kM % MmaSimtPolicy::WarpShape::kRow), "Divisibility"); -+ static_assert(!(WarpShape::kN % MmaSimtPolicy::WarpShape::kColumn), "Divisibility"); -+ -+ /// Number of iterations -+ static int const kIterations = WarpShape::kM / MmaSimtPolicy::WarpShape::kRow; -+ -+ /// Number of accumulators written per iteration -+ static int const kElementsPerIteration = -+ (WarpShape::kN / MmaSimtPolicy::WarpShape::kColumn); -+ -+ /// Total number of accumulators -+ static int const kAccumulatorElementCount = kElementsPerIteration * kIterations; -+ -+ /// Number of consecutive elements -+ static int const kElementsPerAccess = MmaSimtPolicy::LaneMmaShape::kN; -+ -+ /// Number of rows per epilogue iteration -+ static int const kRowsPerIteration = MmaSimtPolicy::WarpShape::kRow; -+ -+ /// Number of accesses made in one iteration -+ static int const kAccessesPerIteration = kElementsPerIteration / kElementsPerAccess; -+ -+ /// Number of elements in between accumulator chunks of (LaneMmaShape::kM x LaneMmaShape::kN) -+ using Delta = MatrixShape< -+ MmaSimtPolicy::WarpShape::kRow * MmaSimtPolicy::LaneMmaShape::kM, -+ MmaSimtPolicy::WarpShape::kColumn * MmaSimtPolicy::LaneMmaShape::kN -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ 
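As a concrete check of the SimtPolicy quantities defined above, the sketch below plugs in assumed shapes: a 64x64 warp tile, a 4x8 lane arrangement (32 lanes), and LaneMmaShape::kN = 4. The numbers are illustrative only, not a specific CUTLASS configuration.

// Worked example of the SimtPolicy arithmetic, with assumed shapes.
constexpr int kWarpM = 64, kWarpN = 64;      // WarpShape
constexpr int kLaneRows = 4, kLaneCols = 8;  // MmaSimtPolicy::WarpShape (4x8 = 32 lanes)
constexpr int kLaneMmaN = 4;                 // MmaSimtPolicy::LaneMmaShape::kN

constexpr int kIterations = kWarpM / kLaneRows;                          // 16
constexpr int kElementsPerIteration = kWarpN / kLaneCols;                // 8
constexpr int kAccumulatorElementCount =
    kElementsPerIteration * kIterations;                                 // 128
constexpr int kElementsPerAccess = kLaneMmaN;                            // 4
constexpr int kAccessesPerIteration =
    kElementsPerIteration / kElementsPerAccess;                          // 2

static_assert(kIterations == 16 && kAccumulatorElementCount == 128 &&
              kAccessesPerIteration == 2, "SimtPolicy worked example");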
-+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/tensor_op_policy.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/tensor_op_policy.h -new file mode 100644 -index 0000000..e0d1f6f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/tensor_op_policy.h -@@ -0,0 +1,148 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic structures needed for implementing the warp-scoped phase of the epilogue. -+ These quantities assume a 'column-major' arrangement of TensorOp instructions, of which -+ a row-oriented slice is visible per iteration. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/matrix_shape.h" -+#include "cutlass/layout/matrix.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Policy details related to the epilogue -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape, ///< matrix multiply operation shape (concept: gemm:GemmShape) -+ typename Layout ///< target shared memory layout -+> -+struct TensorOpPolicy; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for row-major -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape ///< matrix multiply operation shape (concept: gemm::GemmShape) -+> -+struct TensorOpPolicy { -+ -+ /// Number of operations -+ using OperatorCount = MatrixShape< -+ (WarpShape::kM + OperatorShape::kM - 1) / OperatorShape::kM, -+ (WarpShape::kN + OperatorShape::kN - 1) / OperatorShape::kN -+ >; -+ -+ // -+ // Hard-coded constants regarding Tensor Operations -+ // -+ -+ static int const kElementsPerAccess = 2; -+ static int const kRowsPerIteration = 8; -+ static bool const kDivisible = -+ !(WarpShape::kM % OperatorShape::kM) && !(WarpShape::kN % OperatorShape::kN); -+ -+ // -+ // Derived quantities -+ // -+ -+ // Number of 'externally visible' iterations per actual instruction -+ static int const kIterationsPerInstruction = OperatorShape::kM / kRowsPerIteration; -+ -+ // Number of externally visible iterations -+ static int const kIterations = OperatorCount::kRow * kIterationsPerInstruction; -+ -+ using TileIterations = MatrixShape; -+ -+ static int const kAccumulatorRowStride = kElementsPerAccess; -+ static int const kAccumulatorColumnStride = kElementsPerAccess * OperatorCount::kRow * kIterationsPerInstruction; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for column-major-interleaved -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape, ///< matrix multiply operation (concept: arch::Mma) -+ int InterleavedK ///< number of interleaved k -+ > -+struct TensorOpPolicy > { -+ /// Number of operations -+ using OperatorCount = MatrixShape; -+ -+ // -+ // Hard-coded constants regarding Tensor Operations -+ // -+ -+ static int const kElementsPerAccess = 2; -+ static int const kRowsPerIteration = 8; -+ -+ // -+ // Derived quantities -+ // -+ -+ // Number of 'externally visible' iterations per actual instruction -+ static int const kIterationsPerInstruction = -+ OperatorShape::kM / kRowsPerIteration; -+ -+ // Number of externally visible iterations -+ static int const kIterations = WarpShape::kN / InterleavedK * -+ OperatorCount::kRow * -+ kIterationsPerInstruction; -+ -+ static int const kElementsPerIteration = InterleavedK / OperatorShape::kN * kElementsPerAccess; -+ -+ static int const kAccessPerIteration = kElementsPerIteration / kElementsPerAccess; -+ -+ // Number of externally visible iterations -+ //static int const kTileIterations = OperatorCount::kRow * kIterationsPerInstruction; -+ using TileIterations = MatrixShape<1, WarpShape::kN / InterleavedK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // 
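For reference, here is the same row-major TensorOpPolicy arithmetic evaluated for assumed shapes, a 64x64 warp tile and a 16x8 tensor-op instruction; the values are illustrative and not tied to any particular kernel configuration.

// Worked example of the row-major TensorOpPolicy quantities, with assumed shapes.
constexpr int kWarpM = 64, kWarpN = 64;  // WarpShape
constexpr int kOpM = 16, kOpN = 8;       // OperatorShape

constexpr int kOperatorCountRow    = (kWarpM + kOpM - 1) / kOpM;            // 4
constexpr int kOperatorCountColumn = (kWarpN + kOpN - 1) / kOpN;            // 8

constexpr int kElementsPerAccess = 2;  // hard-coded by the policy
constexpr int kRowsPerIteration  = 8;  // hard-coded by the policy

constexpr int kIterationsPerInstruction = kOpM / kRowsPerIteration;         // 2
constexpr int kIterations = kOperatorCountRow * kIterationsPerInstruction;  // 8
constexpr int kAccumulatorColumnStride =
    kElementsPerAccess * kOperatorCountRow * kIterationsPerInstruction;     // 16

static_assert(kOperatorCountColumn == 8 && kIterations == 8 &&
              kAccumulatorColumnStride == 16, "TensorOpPolicy worked example");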
namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_simt.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_simt.h -new file mode 100644 -index 0000000..5ef4b2e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_simt.h -@@ -0,0 +1,785 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+#include "cutlass/epilogue/warp/simt_policy.h" -+ -+#define CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES 1 -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename Operator, ///< matrix multiply operation (concept: arch::Mma) -+ typename Element, ///< data type of element to be written -+ typename Layout, ///< target shared memory layout -+ typename MmaSimtPolicy ///< policy defining lane arrangement (concept: MmaSimtPolicy) -+> -+class TileIteratorSimt; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) -+ typename Operator_, ///< matrix multiply operation (concept: arch::Mma) -+ typename Element_, ///< data type of element to be written -+ typename MmaSimtPolicy_ ///< policy defining lane arrangement (concept: MmaSimtPolicy) -+> -+class TileIteratorSimt { -+public: -+ -+ using WarpShape = WarpShape_; -+ using Operator = Operator_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = SimtPolicy; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape< -+ Policy::kRowsPerIteration, -+ WarpShape::kN -+ >; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ typename Operator::ElementC, -+ Policy::kElementsPerIteration>; -+ -+ /// This is the complete warp-level accumulator tile. 
-+ using AccumulatorTile = Array< -+ typename Operator::ElementC, -+ Policy::kAccumulatorElementCount>; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+ /// Padding quantity -+ using Padding = MatrixShape< -+ 0, -+ 4 * Policy::kElementsPerAccess -+#if CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES -+ + 1 -+#endif -+ >; -+ -+private: -+ -+#if CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES -+ /// Storage type for accessing memory -+ using AccessType = AlignedArray< -+ Element, -+ 1 -+ >; -+ -+#else -+ /// Storage type for accessing memory -+ using AccessType = AlignedArray< -+ Element, -+ Policy::kElementsPerAccess -+ >; -+#endif -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointer_; -+ -+ /// Internal layout object -+ Layout layout_; -+ -+public: -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimt(): pointer_(nullptr) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimt( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): -+ pointer_(reinterpret_cast(ref.data())), -+ layout_(ref.stride()[0] / AccessType::kElements) { -+ -+ auto lane_layout = Policy::MmaSimtPolicy::get_lane_layout(); -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id); -+ -+ pointer_ += layout_({ -+ lane_offset.row(), -+ lane_offset.column() * Policy::kElementsPerAccess / int(AccessType::kElements) -+ }); -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimt & add_pointer_offset(Index pointer_offset) { -+ pointer_ += pointer_offset / AccessType::kElements; -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimt & add_tile_offset(TensorCoord const &tile_offset) { -+ -+ pointer_ += layout_({ -+ tile_offset.row() * Shape::kRow, -+ (tile_offset.column() * Shape::kColumn / int(AccessType::kElements)) -+ }); -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimt & operator+=(TensorCoord const &tile_offset) { -+ -+ add_tile_offset(tile_offset); -+ -+ return *this; -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+#if CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES -+ // de-vectorized stores -+ using ScalarAccessType = AlignedArray; -+ ScalarAccessType const *scalarFragPtr = reinterpret_cast(&frag); -+ ScalarAccessType *scalarPointer = reinterpret_cast(pointer_) + pointer_offset; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::kElementsPerAccess; s++) { -+ scalarPointer[n * Policy::MmaSimtPolicy::WarpShape::kColumn * Policy::kElementsPerAccess + s] = scalarFragPtr[n * Policy::kElementsPerAccess + s]; -+ } -+ } -+#else -+ // original vector stores -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { -+ pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)] = frag_ptr[n]; -+ } -+#endif -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ AccessType *frag_ptr = 
reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { -+ frag_ptr[n] = pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)]; -+ } -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template -+class TileIteratorSimtDirectConv { -+ public: -+ -+ using WarpShape = WarpShape_; -+ using Operator = Operator_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = SimtPolicy; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array; -+ -+ /// This is the complete warp-level accumulator tile. -+ using AccumulatorTile = Array; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+ /// Padding quantity -+ using Padding = MatrixShape<0, -+ 0 -+ >; -+ -+private: -+ /// Storage type for accessing memory -+ using AccessType = AlignedArray< -+ Element, -+ Policy::kElementsPerAccess -+ >; -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointer_; -+ -+ /// Internal layout object -+ Layout layout_; -+ -+ /// Base smem offset; -+ Index base_smem_address_; -+ -+ public: -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtDirectConv() : pointer_(nullptr) {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtDirectConv( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): -+ pointer_(reinterpret_cast(ref.data())), -+ layout_(ref.stride()[0] / AccessType::kElements) { -+ -+ auto lane_layout = Policy::MmaSimtPolicy::get_lane_layout(); -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id); -+ -+ pointer_ += layout_({ -+ lane_offset.row(), -+ lane_offset.column() * Policy::kElementsPerAccess / int(AccessType::kElements) -+ }); -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtDirectConv & add_pointer_offset(Index pointer_offset) { -+ pointer_ += pointer_offset / AccessType::kElements; -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtDirectConv & add_tile_offset(TensorCoord const &tile_offset) { -+ -+ pointer_ += layout_({ -+ tile_offset.row() * Shape::kRow, -+ (tile_offset.column() * Shape::kColumn / int(AccessType::kElements)) -+ }); -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtDirectConv & operator+=(TensorCoord const &tile_offset) { -+ -+ add_tile_offset(tile_offset); -+ -+ return *this; -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void 
store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ // original vector stores -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ AccessType * load_pointer_ = reinterpret_cast(reinterpret_cast(pointer_) + base_smem_address_); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { -+ load_pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)] = frag_ptr[n]; -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { -+ frag_ptr[n] = pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)]; -+ } -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address){ -+ base_smem_address_ = address; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Template for reading and writing tiles of accumulators to shared memory -+template -+class TileIteratorSimtDirect2dConv { -+ public: -+ using WarpShape = WarpShape_; -+ using ThreadOutputShape = ThreadOutputShape_; -+ using ThreadBlockOutputShape = ThreadBlockOutputShape_; -+ using Operator = Operator_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ using MmaSimtPolicy = MmaSimtPolicy_; -+ -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ // Thread-level shape of a fragment -+ using ThreadShape = MatrixShape; -+ -+ static_assert(!(ThreadShape::kColumn % MmaSimtPolicy::LaneMmaShape::kN), -+ "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); -+ -+ using ThreadTileCount = MatrixShape; -+ -+ using Iterations = -+ MatrixShape; -+ -+ /// This is the complete warp-level accumulator tile. -+ using AccumulatorTile = typename Operator::FragmentC; -+ -+ /// This is the fragment size produced by one access of the iterator. 
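// (Illustrative note on the thread-to-output mapping in the constructor below.)
// Threads are grouped along the channel dimension first:
// ThreadBlockOutputShape::kC / ThreadOutputShape::kC threads share one (p, q)
// spatial position, the resulting threadgroup index is split into a
// (base_p, base_q) origin, and row_offset = base_p * ThreadBlockOutputShape::kW
// + base_q linearizes that origin into the row-major shared-memory tile.
// Worked example with assumed shapes (hypothetical, for illustration only):
//   ThreadBlockOutputShape = (kH 8, kW 8, kC 64), ThreadOutputShape = (kH 2, kW 2, kC 8)
//   => 64 / 8 = 8 threads per position, ThreadTileCount::kColumn = 8 / 2 = 4;
//   thread_id = 20: threadgroup = 20 / 8 = 2, base_p = (2 / 4) * 2 = 0,
//   base_q = (2 % 4) * 2 = 4, row_offset = 0 * 8 + 4 = 4.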
-+ using Fragment = AccumulatorTile; -+ -+ /// Padding quantity -+ using Padding = MatrixShape<0, 0>; -+ -+ private: -+ // Storage type for accessing memory -+ using AccessType = AlignedArray; -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointer_; -+ -+ /// Internal layout object -+ Layout layout_; -+ -+ /// Base smem offset; -+ Index base_smem_address_; -+ -+ public: -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtDirect2dConv() : pointer_(nullptr) {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtDirect2dConv(TensorRef const &ref, unsigned thread_id, unsigned lane_id) -+ : pointer_(reinterpret_cast(ref.data())), -+ layout_(ref.stride()[0] / AccessType::kElements) { -+ -+ auto lane_layout = MmaSimtPolicy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id); -+ -+ // Get base HW offset of current threads -+ const int threadgroup = thread_id / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC); -+ const int base_p = (threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH; -+ const int base_q = (threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW; -+ -+ const int row_offset = base_p * ThreadBlockOutputShape::kW + base_q; -+ -+ pointer_ += layout_( -+ {row_offset, -+ lane_offset.column() * MmaSimtPolicy::LaneMmaShape::kN / int(AccessType::kElements)}); -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtDirect2dConv &add_pointer_offset(Index pointer_offset) { -+ pointer_ += pointer_offset / AccessType::kElements; -+ return *this; -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ AccessType *storer_pointer_ = -+ reinterpret_cast(reinterpret_cast(pointer_) + base_smem_address_); -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int h = 0; h < ThreadOutputShape::kH; ++h) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int w = 0; w < ThreadOutputShape::kW; ++w) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < Iterations::kColumn; ++col) { -+ int offset = (w + h * ThreadBlockOutputShape::kW) * -+ (ThreadBlockOutputShape::kC / AccessType::kElements) + -+ col; -+ storer_pointer_[offset + pointer_offset / int(AccessType::kElements)] = -+ frag_ptr[w + h * ThreadOutputShape::kW + col]; -+ } -+ } -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { base_smem_address_ = address; } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) -+ typename Operator_, ///< matrix multiply operation (concept: arch::Mma) -+ typename Element_, ///< data type of element to be written -+ typename Layout_, ///< target shared memory layout -+ typename MmaSimtPolicy_ ///< policy defining lane arrangement (concept: MmaSimtPolicy) -+> -+class TileIteratorSimtCanonical { -+public: -+ -+ using WarpShape = WarpShape_; -+ using Operator = Operator_; -+ using Element = Element_; -+ using Layout = Layout_; -+ -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = 
typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = SimtPolicy; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape< -+ Policy::kRowsPerIteration, -+ WarpShape::kN -+ >; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ typename Operator::ElementC, -+ Policy::kElementsPerIteration>; -+ -+ /// This is the complete warp-level accumulator tile. -+ using AccumulatorTile = Array< -+ typename Operator::ElementC, -+ Policy::kAccumulatorElementCount>; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+ /// Padding quantity -+ using Padding = MatrixShape< -+ 0, -+ 4 * Policy::kElementsPerAccess + 1 -+ >; -+ -+private: -+ -+ /// Storage type for accessing memory -+ using AccessType = AlignedArray< -+ Element, -+ 1 -+ >; -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointer_; -+ -+ /// Internal layout object -+ Layout layout_; -+ -+ /// Guard to indicate whether the shape is divisible -+ bool divisible_; -+ -+ /// Extent of the output tensor -+ MatrixCoord extent_; -+ -+ /// Thread offset -+ MatrixCoord thread_offset_; -+ -+public: -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtCanonical(): pointer_(nullptr) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtCanonical( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): -+ pointer_(reinterpret_cast(ref.data())), -+ layout_(ref.stride()[0] / AccessType::kElements), -+ divisible_(true), -+ extent_(WarpShape::kM, WarpShape::kN) { -+ -+ auto lane_layout = Policy::MmaSimtPolicy::get_lane_layout(); -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id); -+ -+ thread_offset_ = { -+ lane_offset.row() * Shape::kRow, -+ lane_offset.column() * Policy::kElementsPerAccess -+ }; -+ -+ pointer_ += layout_({ -+ lane_offset.row() * Shape::kRow, -+ lane_offset.column() * Policy::kElementsPerAccess / int(AccessType::kElements) -+ }); -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtCanonical( -+ TensorRef const &ref, -+ TensorCoord const &extent, -+ unsigned lane_id -+ ): -+ pointer_(reinterpret_cast(ref.data())), -+ layout_(ref.stride()[0] / AccessType::kElements), -+ divisible_(false), -+ extent_(extent) { -+ -+ auto lane_layout = Policy::MmaSimtPolicy::get_lane_layout(); -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id); -+ -+ thread_offset_ = { -+ lane_offset.row() * Shape::kRow, -+ lane_offset.column() * Policy::kElementsPerAccess -+ }; -+ -+ pointer_ += layout_({ -+ lane_offset.row() * Shape::kRow, -+ lane_offset.column() * Policy::kElementsPerAccess / int(AccessType::kElements) -+ }); -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtCanonical & add_pointer_offset(Index pointer_offset) { -+ pointer_ += pointer_offset / AccessType::kElements; -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtCanonical & add_tile_offset(TensorCoord const &tile_offset) { -+ -+ MatrixCoord coord_offset( -+ tile_offset.row(), -+ tile_offset.column() * Shape::kColumn -+ ); -+ -+ thread_offset_ += coord_offset; -+ -+ pointer_ += layout_({ -+ coord_offset.row(), -+ coord_offset.column() -+ }); -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ 
CUTLASS_HOST_DEVICE -+ TileIteratorSimtCanonical & operator+=(TensorCoord const &tile_offset) { -+ -+ add_tile_offset(tile_offset); -+ -+ return *this; -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ // de-vectorized stores -+ using ScalarAccessType = AlignedArray; -+ ScalarAccessType const *scalarFragPtr = reinterpret_cast(&frag); -+ ScalarAccessType *scalarPointer = reinterpret_cast(pointer_) + pointer_offset; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::kElementsPerAccess; s++) { -+ -+ int ptr_idx = n * Policy::MmaSimtPolicy::WarpShape::kColumn * Policy::kElementsPerAccess + s; -+ int frag_idx = n * Policy::kElementsPerAccess + s; -+ -+ int col = thread_offset_.column() + ptr_idx; -+ -+ if (divisible_ || (thread_offset_.row() < extent_.row() && col < extent_.column())) { -+ scalarPointer[ptr_idx] = scalarFragPtr[frag_idx]; -+ } -+ } -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ // de-vectorized loads -+ using ScalarAccessType = AlignedArray; -+ ScalarAccessType *scalarFragPtr = reinterpret_cast(&frag); -+ ScalarAccessType const *scalarPointer = reinterpret_cast(pointer_) + pointer_offset; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::kElementsPerAccess; s++) { -+ -+ int ptr_idx = n * Policy::MmaSimtPolicy::WarpShape::kColumn * Policy::kElementsPerAccess + s; -+ int frag_idx = n * Policy::kElementsPerAccess + s; -+ -+ int col = thread_offset_.column() + ptr_idx; -+ -+ if (divisible_ || (thread_offset_.row() < extent_.row() && col < extent_.column())) { -+ scalarFragPtr[frag_idx] = scalarPointer[ptr_idx]; -+ } -+ } -+ } -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtCanonical & operator++() { -+ return add_tile_offset({1, 0}); -+ } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+}; -+ -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h -new file mode 100644 -index 0000000..a1eb5c9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h -@@ -0,0 +1,671 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+#include "cutlass/epilogue/warp/tensor_op_policy.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename Element, ///< data type of element to be written -+ typename Layout ///< target shared memory layout -+> -+class TileIteratorTensorOp; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) -+ typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename Element_ ///< data type of element to be written -+> -+class TileIteratorTensorOp { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ -+ using TensorLayout = Layout; -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = TensorOpPolicy; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape< -+ Policy::kRowsPerIteration, -+ WarpShape::kN -+ >; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ Element, -+ Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; -+ -+ /// This is the complete warp-level accumulator tile. 
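// (Worked example of the lane mapping used by the constructor below.)
// Each quad of four lanes maps to one row of the tile:
//   quad_id      = lane_id / Detail::kLanesInQuad   (= lane_id / 4)
//   lane_in_quad = lane_id % Detail::kLanesInQuad
//   thread_offset_ = { quad_id, lane_in_quad * Policy::kElementsPerAccess }
// For lane 13 this gives quad_id 3 and lane_in_quad 1; assuming
// Policy::kElementsPerAccess == 2 (an assumed value, for illustration only),
// thread_offset_ = (row 3, column 2) and the base pointer advances by
// layout_({3, 2 / 2}) = layout_({3, 1}) AccessType units.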
-+ //using AccumulatorTile = typename Operator::FragmentC; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+ /// Number of times this iterator can be incremented -+ using TileIterations = typename Policy::TileIterations; -+ -+ // Internal constants -+ struct Detail { -+ static int const kLanesInQuad = 4; -+ }; -+ -+ /// Padding quantity -+ using Padding = MatrixShape< -+ 0, -+ Detail::kLanesInQuad * Policy::kElementsPerAccess>; -+ -+private: -+ -+ /// Storage type for accessing memory -+ using AccessType = AlignedArray; -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointer_; -+ -+ /// Internal layout object -+ Layout layout_; -+ -+ /// Thread offset -+ MatrixCoord thread_offset_; -+ -+public: -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp(): pointer_(nullptr) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): -+ pointer_(reinterpret_cast(ref.data())), -+ layout_(ref.stride()[0] / Policy::kElementsPerAccess) { -+ -+ int quad_id = (lane_id / Detail::kLanesInQuad); -+ int lane_in_quad = (lane_id % Detail::kLanesInQuad); -+ -+ thread_offset_ = { -+ quad_id, lane_in_quad * Policy::kElementsPerAccess -+ }; -+ -+ pointer_ += layout_({thread_offset_.row(), thread_offset_.column() / Policy::kElementsPerAccess}); -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp & add_pointer_offset(Index pointer_offset) { -+ pointer_ += pointer_offset / Policy::kElementsPerAccess; -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp & add_tile_offset(TensorCoord const &tile_offset) { -+ -+ MatrixCoord coord_offset( -+ tile_offset.row() * Shape::kRow, -+ tile_offset.column() * Shape::kColumn -+ ); -+ -+ thread_offset_ += coord_offset; -+ -+ pointer_ += layout_({ -+ coord_offset.row(), -+ coord_offset.column() / Policy::kElementsPerAccess -+ }); -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ pointer_[n * Detail::kLanesInQuad + pointer_offset / Policy::kElementsPerAccess] = frag_ptr[n]; -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ frag_ptr[n] = pointer_[n * Detail::kLanesInQuad + pointer_offset / Policy::kElementsPerAccess]; -+ } -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp & operator++() { -+ return add_tile_offset({1, 0}); -+ } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void 
set_smem_base_address(Index address) { -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) -+ typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename Element_, ///< data type of element to be written -+ int InterleavedK ///< number of interleaved k -+> -+class TileIteratorTensorOp > { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorInterleaved; -+ using TensorLayout = Layout; ///< shared memory tensor ref layout -+ -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = TensorOpPolicy; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape< -+// Policy::kRowsPerIteration, -+ WarpShape::kM, -+ InterleavedK -+ >; -+ -+ /// This is the fragment size produced by one tile -+ using Fragment = Array< -+ Element, -+ Policy::OperatorCount::kRow * Policy::kIterationsPerInstruction -+ * Policy::kElementsPerIteration>; -+ -+ /// This is the fragment size produced by one iteration -+// using Fragment = Array< -+// Element, Policy::kElementsPerIteration >; -+ -+ /// This is the complete warp-level accumulator tile. -+ //using AccumulatorTile = typename Operator::FragmentC; -+ -+ /// Number of times this iterator can be incremented -+ using TileIterations = typename Policy::TileIterations; -+ -+ // Internal constants -+ struct Detail { -+ static int const kLanesInQuad = 4; -+ }; -+ -+ /// Padding quantity -+ using Padding = MatrixShape< -+ 0, -+ Detail::kLanesInQuad * Policy::kElementsPerIteration>; -+ -+private: -+ -+ /// Storage type for accessing memory -+ using AccessType = AlignedArray; -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointer_; -+ -+ /// Internal layout object -+ TensorLayout layout_; -+ -+ /// Thread offset -+ MatrixCoord thread_offset_; -+ -+public: -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp(): pointer_(nullptr) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): -+ pointer_(reinterpret_cast(ref.data())), -+ layout_(ref.stride()[0]) { -+ -+ int quad_id = (lane_id / Detail::kLanesInQuad); -+ int lane_in_quad = (lane_id % Detail::kLanesInQuad); -+ -+ thread_offset_ = { -+ quad_id, lane_in_quad * Policy::kElementsPerIteration -+ }; -+ -+ pointer_ += (layout_({thread_offset_.row(), thread_offset_.column()}) / Policy::kElementsPerAccess); -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp & add_pointer_offset(Index pointer_offset) { -+ pointer_ += pointer_offset / Policy::kElementsPerAccess; -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp & add_tile_offset(TensorCoord const &tile_offset) { -+ -+ MatrixCoord coord_offset( -+ tile_offset.row() * Shape::kRow, -+ tile_offset.column() * Shape::kColumn -+ ); -+ -+ thread_offset_ += coord_offset; -+ -+ pointer_ += (layout_({ -+ coord_offset.row(), -+ 
coord_offset.column() -+ }) / Policy::kElementsPerAccess); -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kRow * Policy::kIterationsPerInstruction; n++ ) { -+ -+ AccessType *ptr = pointer_ + layout_({n * Policy::kRowsPerIteration, 0}) / Policy::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int a = 0; a < Policy::kAccessPerIteration; ++a) { -+ ptr[a + pointer_offset / Policy::kElementsPerAccess] = frag_ptr[n * Policy::kAccessPerIteration + a]; -+ -+// printf("store thread %d, address %p, bank %ld\n", threadIdx.x, pointer_+a+n*Detail::kLanesInQuad, -+// ((long long)(pointer_+a+n*Detail::kLanesInQuad)>>2)&0x1f); -+ } -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kRow * Policy::kIterationsPerInstruction; n++ ) { -+ -+ AccessType *ptr = pointer_ + layout_({n * Policy::kRowsPerIteration, 0}) / Policy::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int a = 0; a < Policy::kAccessPerIteration; ++a) { -+ frag_ptr[n * Policy::kAccessPerIteration + a] = ptr[a + pointer_offset / Policy::kElementsPerAccess]; -+ } -+ } -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp & operator++() { -+ return add_tile_offset({0, 1}); -+ } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) -+ typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename Element_, ///< data type of element to be written -+ typename Layout_ -+> -+class TileIteratorTensorOpCanonical { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = TensorOpPolicy; -+ -+ static int const kAccessSize = 1; -+ static int const kAccessCount = Policy::kElementsPerAccess / kAccessSize; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape< -+ Policy::kRowsPerIteration, -+ WarpShape::kN -+ >; -+ -+ /// This is the fragment size produced by one access of the iterator. 
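// (Note on why this "canonical" variant carries divisible_ / extent_.)
// When constructed with an explicit extent, every scalar access in the
// store_with_pointer_offset / load_with_pointer_offset loops below is
// predicated on
//   thread_offset_.row() < extent_.row() && col < extent_.column()
// so a partial tile at the boundary of the output never writes or reads past
// the valid extent. The extent-less constructor instead sets divisible_ = true,
// which short-circuits the predicate for the full-tile case.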
-+ using Fragment = Array< -+ Element, -+ Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; -+ -+ /// This is the complete warp-level accumulator tile. -+ //using AccumulatorTile = typename Operator::FragmentC; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+ // Internal constants -+ struct Detail { -+ static int const kLanesInQuad = 4; -+ }; -+ -+ /// Padding quantity -+ using Padding = MatrixShape< -+ 0, -+ Detail::kLanesInQuad * Policy::kElementsPerAccess>; -+ -+private: -+ -+ /// Storage type for accessing memory -+ using AccessType = AlignedArray; -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointer_; -+ -+ /// Internal layout object -+ Layout layout_; -+ -+ /// Guard to indicate whether the shape is divisible -+ bool divisible_; -+ -+ /// Extent of the output tensor -+ MatrixCoord extent_; -+ -+ /// Thread offset -+ MatrixCoord thread_offset_; -+ -+public: -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpCanonical(): pointer_(nullptr) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpCanonical( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): -+ pointer_(reinterpret_cast(ref.data())), -+ layout_(ref.stride()[0]), -+ divisible_(true), -+ extent_(WarpShape::kM, WarpShape::kN) { -+ -+ int quad_id = (lane_id / Detail::kLanesInQuad); -+ int lane_in_quad = (lane_id % Detail::kLanesInQuad); -+ -+ thread_offset_ = { -+ quad_id, lane_in_quad * Policy::kElementsPerAccess -+ }; -+ -+ pointer_ += layout_({thread_offset_.row(), thread_offset_.column()}); -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpCanonical( -+ TensorRef const &ref, -+ TensorCoord const &extent, -+ unsigned lane_id -+ ): -+ pointer_(reinterpret_cast(ref.data())), -+ layout_(ref.stride()[0]), -+ divisible_(false), -+ extent_(extent) { -+ -+ int quad_id = (lane_id / Detail::kLanesInQuad); -+ int lane_in_quad = (lane_id % Detail::kLanesInQuad); -+ -+ thread_offset_ = { -+ quad_id, lane_in_quad * Policy::kElementsPerAccess -+ }; -+ -+ pointer_ += layout_({thread_offset_.row(), thread_offset_.column()}); -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpCanonical & add_pointer_offset(Index pointer_offset) { -+ pointer_ += pointer_offset; -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpCanonical & add_tile_offset(TensorCoord const &tile_offset) { -+ -+ MatrixCoord coord_offset( -+ tile_offset.row() * Shape::kRow, -+ tile_offset.column() * Shape::kColumn -+ ); -+ -+ thread_offset_ += coord_offset; -+ -+ pointer_ += layout_({ -+ coord_offset.row(), -+ coord_offset.column() -+ }); -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpCanonical & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int a = 0; a < kAccessCount; ++a) { -+ -+ int ptr_idx = n * Detail::kLanesInQuad * kAccessCount + pointer_offset + a; -+ 
int frag_idx = n * kAccessCount + a; -+ -+ int col = thread_offset_.column() + n * Detail::kLanesInQuad * Policy::kElementsPerAccess + a; -+ -+ if (divisible_ || (thread_offset_.row() < extent_.row() && col < extent_.column())) { -+ pointer_[ptr_idx] = frag_ptr[frag_idx]; -+ } -+ } -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int a = 0; a < kAccessCount; ++a) { -+ -+ int ptr_idx = n * Detail::kLanesInQuad * kAccessCount + pointer_offset + a; -+ int frag_idx = n * kAccessCount + a; -+ -+ int col = thread_offset_.column() + n * Detail::kLanesInQuad * Policy::kElementsPerAccess + a; -+ -+ if (divisible_ || (thread_offset_.row() < extent_.row() && col < extent_.column())) { -+ frag_ptr[frag_idx] = pointer_[ptr_idx]; -+ } -+ } -+ } -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpCanonical & operator++() { -+ return add_tile_offset({1, 0}); -+ } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h -new file mode 100644 -index 0000000..3bbc942 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h -@@ -0,0 +1,727 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/epilogue/warp/tensor_op_policy.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// This is an optimization available on CUDA 11.2 and beyond that eliminates branches in the epilogue. -+#define CUTLASS_EPILOGUE_WARP_TILE_ITERATOR_TENSOR_OP_MIXED_OPTIMIZATION_ENABLED ((__CUDACC_VER_MAJOR__ * 10 + __CUDACC_VER_MINOR__) >= 112) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory. This is optimized -+/// for mixed-precision epilogues in which the accumulators are 32b in width, but the output -+/// data type is smaller. -+template < -+ typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) -+ typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename Element_, ///< data type of accumulator element -+ int ElementSizeBits, ///< Size of accumulator element in bits -+ int OutputSizeBits, ///< Size of output element in bits -+ int OutputElementCount, ///< number of elements in output vector -+ int ContiguousLanes ///< Number of consecutive lanes writing to contiguous memory -+> -+class TileIteratorTensorOpMixed { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kOutputElementCount = OutputElementCount; -+ -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = TensorOpPolicy; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape< -+ Policy::kRowsPerIteration, -+ WarpShape::kN -+ >; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ Element, -+ Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; -+ -+ /// This is the complete warp-level accumulator tile. 
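// (Sketch of the pointer-count arithmetic in Detail below; interpretation hedged.)
// The primary template chooses
//   kPointerCount = (OutputElementCount * accumulator bits)
//                   / min(128, OutputElementCount * accumulator bits)
// i.e. one shared-memory pointer per 128 bits of 32-bit accumulator data backing
// a single output vector: with OutputElementCount = 8 that is 256 / 128 = 2
// pointers, with 16 it is 4. Each pointer is pre-offset to a different
// interleaved column in the constructor, and the store selects a pointer from
// the current column index; spreading writes across these staggered pointers
// appears intended to avoid shared-memory bank conflicts when the 32-bit
// accumulators are later read back and converted to the narrower output type.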
-+ //using AccumulatorTile = typename Operator::FragmentC; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+ // Internal constants -+ struct Detail { -+ static int const kLanesInQuad = 4; -+ -+ /// Number of pointers needed to write accumulators -+ static int const kPointerCount = -+ (OutputElementCount * sizeof_bits::value) / (const_min(128, OutputElementCount * sizeof_bits::value)); -+ -+ static_assert(kPointerCount <= 4, "Can only accommodate four pointers at present."); -+ static_assert(sizeof(Element) == 4, "This can only be used with 32b accumulator data types (f32, s32)."); -+ }; -+ -+ /// Padding quantity -+ using Padding = MatrixShape< -+ 0, -+ Detail::kLanesInQuad * Policy::kElementsPerAccess>; -+ -+private: -+ -+ /// Storage type for accessing memory -+ using AccessType = AlignedArray; -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointers_[Detail::kPointerCount]; -+ -+ /// Stride in units of AccessType -+ int stride_; -+ -+ /// Logical column in which warp tile is aligned -+ int warp_column_; -+ -+public: -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int64_t i = 0; i < Detail::kPointerCount; ++i) { -+ pointers_[i] = nullptr; -+ } -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): -+ stride_(ref.stride()[0] / Policy::kElementsPerAccess), -+ warp_column_(0) { -+ -+ int quad_id = (lane_id / Detail::kLanesInQuad); -+ int lane_in_quad = (lane_id % Detail::kLanesInQuad); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int64_t i = 0; i < Detail::kPointerCount; ++i) { -+ AccessType *ptr = reinterpret_cast(ref.data()) + quad_id * stride_; -+ int column_idx = (lane_in_quad % 2) + (((lane_in_quad / 2) + i) % Detail::kPointerCount) * 2; -+ -+ ptr += column_idx; -+ -+ if (i == 0) { -+ pointers_[0 % Detail::kPointerCount] = ptr; -+ } -+ else if (i == 1) { -+ pointers_[1 % Detail::kPointerCount] = ptr; -+ } -+ else if (i == 2) { -+ pointers_[2 % Detail::kPointerCount] = ptr; -+ } -+ else if (i == 3) { -+ pointers_[3 % Detail::kPointerCount] = ptr; -+ } -+ } -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed & add_pointer_offset(Index pointer_offset) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int64_t i = 0; i < Detail::kPointerCount; ++i) { -+ pointers_[i] += pointer_offset / Policy::kElementsPerAccess; -+ } -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed & add_tile_offset(TensorCoord const &tile_offset) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int64_t i = 0; i < Detail::kPointerCount; ++i) { -+ pointers_[i] += tile_offset.row() * Shape::kRow * stride_ + -+ tile_offset.column() * Shape::kColumn / Policy::kElementsPerAccess; -+ } -+ -+ warp_column_ += tile_offset.column() * Shape::kColumn; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed & operator+=(TensorCoord const &tile_offset) { -+ return add_tile_offset(tile_offset); -+ } -+ -+ /// Store -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ AccessType *ptr = pointers_[0]; -+ -+#if 
CUTLASS_EPILOGUE_WARP_TILE_ITERATOR_TENSOR_OP_MIXED_OPTIMIZATION_ENABLED -+ -+ // When the optimization is enabled, small tiles require separate logic. -+ bool kN32_optimization = (WarpShape::kN * Detail::kLanesInQuad * Policy::kElementsPerAccess * sizeof_bits::value) % 1024 == 0; -+ if (kN32_optimization) { -+ int ptr_idx = ((warp_column_ * sizeof_bits::value) / 1024) % Detail::kPointerCount; -+ if (ptr_idx == 0) { -+ ptr = pointers_[0]; -+ } else if (ptr_idx == 1) { -+ ptr = pointers_[1]; -+ } else if (ptr_idx == 2) { -+ ptr = pointers_[2]; -+ } else if (ptr_idx == 3) { -+ ptr = pointers_[3]; -+ } -+ } -+ -+#endif -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int64_t n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ -+#if CUTLASS_EPILOGUE_WARP_TILE_ITERATOR_TENSOR_OP_MIXED_OPTIMIZATION_ENABLED -+ -+ // -+ // When the optimization is enabled, this expression suffices to obtain the SMEM pointer. -+ // -+ if (WarpShape::kN == 64) { -+ ptr = pointers_[n / 4]; -+ } -+ else if (!kN32_optimization) -+#endif -+ { -+ // This is the reference implementation -+ int column_idx = warp_column_ + n * Detail::kLanesInQuad * Policy::kElementsPerAccess; -+ int ptr_idx = ((column_idx * sizeof_bits::value) / 1024) % Detail::kPointerCount; -+ -+ if (ptr_idx == 0) { -+ ptr = pointers_[0 % Detail::kPointerCount]; -+ } -+ else if (ptr_idx == 1) { -+ ptr = pointers_[1 % Detail::kPointerCount]; -+ } -+ else if (ptr_idx == 2) { -+ ptr = pointers_[2 % Detail::kPointerCount]; -+ } -+ else if (ptr_idx == 3) { -+ ptr = pointers_[3 % Detail::kPointerCount]; -+ } -+ } -+ -+ int offset = n * Detail::kLanesInQuad + pointer_offset / Policy::kElementsPerAccess; -+ ptr[offset] = frag_ptr[n]; -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int64_t n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ -+ int column_idx = warp_column_ + n * Detail::kLanesInQuad * Policy::kElementsPerAccess; -+ int ptr_idx = ((column_idx * sizeof_bits::value) / 1024) % Detail::kPointerCount; -+ -+ AccessType const *smem_ptr = pointers_[ptr_idx]; -+ frag_ptr[n] = smem_ptr[n * Detail::kLanesInQuad + pointer_offset / Policy::kElementsPerAccess]; -+ } -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for int32_t x 16 => int8_t/int4b_t x 16 -+template < -+ typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) -+ typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape), -+ int OutputSizeBits ///< Size of output element in bits -+> -+class TileIteratorTensorOpMixed { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using Element = int32_t; -+ using Layout = layout::RowMajor; -+ static int const kOutputElementCount = 16; -+ -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using 
Policy = TensorOpPolicy; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape< -+ Policy::kRowsPerIteration, -+ WarpShape::kN -+ >; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ Element, -+ Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; -+ -+ /// This is the complete warp-level accumulator tile. -+ //using AccumulatorTile = typename Operator::FragmentC; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+ // Internal constants -+ struct Detail { -+ static int const kLanesInQuad = 4; -+ -+ /// Number of pointers needed to write accumulators -+ static int const kPointerCount = 2; -+ -+ /// Offsets added -+ static int const kOffsetCount = 4; -+ -+ static_assert(sizeof(Element) == 4, "This can only be used with 32b accumulator data types (f32, s32)."); -+ }; -+ -+ /// Padding quantity -+ using Padding = MatrixShape<0, Detail::kLanesInQuad * 2>; -+ -+private: -+ -+ /// Storage type for accessing memory -+ using AccessType = AlignedArray; -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointers_[Detail::kPointerCount]; -+ -+ /// Stride in units of AccessType -+ int stride_; -+ -+ /// Uniform offset in bytes added to warp tile iterator -+ int uniform_offset_[Detail::kOffsetCount]; -+ -+public: -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int64_t i = 0; i < Detail::kPointerCount; ++i) { -+ pointers_[i] = nullptr; -+ } -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): -+ stride_(ref.stride()[0] / AccessType::kElements) { -+ -+ int quad_id = (lane_id / Detail::kLanesInQuad); -+ int lane_in_quad = (lane_id % Detail::kLanesInQuad); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Detail::kPointerCount; ++i) { -+ AccessType *ptr = reinterpret_cast(ref.data()) + quad_id * stride_; -+ int column_idx = lane_in_quad ^ (i * 2); -+ -+ ptr += column_idx; -+ -+ if (i == 0) { -+ pointers_[0] = ptr; -+ } -+ else if (i == 1) { -+ pointers_[1] = ptr; -+ } -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Detail::kOffsetCount; ++i) { -+ uniform_offset_[i] = (i ^ 0) * 4 * sizeof(AccessType); -+ } -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed & add_pointer_offset(Index pointer_offset) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int64_t i = 0; i < Detail::kPointerCount; ++i) { -+ pointers_[i] += pointer_offset / AccessType::kElements; -+ } -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed & add_tile_offset(TensorCoord const &tile_offset) { -+ -+ int ptr_offset = tile_offset.row() * Shape::kRow * stride_ + -+ tile_offset.column() * Shape::kColumn / AccessType::kElements; -+ -+ pointers_[0] += ptr_offset; -+ pointers_[1] += ptr_offset; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Detail::kOffsetCount; ++i) { -+ uniform_offset_[i] = (i ^ tile_offset.column()) * 4 * sizeof(AccessType); -+ } -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed & operator+=(TensorCoord const &tile_offset) { -+ return add_tile_offset(tile_offset); -+ } -+ -+ /// Store -+ CUTLASS_DEVICE -+ void 
store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ -+ int ptr_idx = (n / 4); -+ int offset_idx = (n % 4); -+ -+ AccessType *ptr; -+ if (ptr_idx == 0) { -+ ptr = pointers_[0]; -+ } -+ else if (ptr_idx == 1) { -+ ptr = pointers_[1]; -+ } -+ -+ int offset = (n / 4) * 16 + pointer_offset / AccessType::kElements; -+ -+#if 0 -+ // -+ // Using inline PTX to avoid generic memory -+ // -+ AccessType *smem_ptr = pointers_[ptr_idx]; -+ smem_ptr[offset] = frag_ptr[n]; -+#else -+ uint32_t smem_addr = arch::cutlass_get_smem_pointer(ptr); -+ uint32_t const *data = reinterpret_cast(frag_ptr + n); -+ uint32_t offset_in_bytes = offset * sizeof(AccessType) + uniform_offset_[offset_idx]; -+ -+ asm volatile( -+ "{ .reg .u32 smem_ptr; add.u32 smem_ptr, %0, %1; st.shared.v2.u32 [smem_ptr], {%2, %3}; }\n" -+ : : "r"(smem_addr), "r"(offset_in_bytes), "r"(data[0]), "r"(data[1]) -+ ); -+#endif -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for int32_t x 8 => int8_t/int4b_t x 8 -+template < -+ typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) -+ typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ int OutputSizeBits ///< Size of output element in bits -+> -+class TileIteratorTensorOpMixed { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using Element = int32_t; -+ using Layout = layout::RowMajor; -+ static int const kOutputElementCount = 8; -+ -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = TensorOpPolicy; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape< -+ Policy::kRowsPerIteration, -+ WarpShape::kN -+ >; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ Element, -+ Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; -+ -+ /// This is the complete warp-level accumulator tile. 
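// (Note on the store path of this specialization; same pattern as the x16
// specialization above.) Rather than writing through the C++ pointer (kept as a
// reference under "#if 0"), store_with_pointer_offset converts the pointer with
// arch::cutlass_get_smem_pointer() and issues st.shared.v2.u32 directly, writing
// one two-element 32-bit AccessType per instruction with the byte offset folded
// into the address by add.u32. Staying in the shared state space avoids a
// generic-address store.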
-+ //using AccumulatorTile = typename Operator::FragmentC; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+ // Internal constants -+ struct Detail { -+ static int const kLanesInQuad = 4; -+ -+ /// Number of pointers needed to write accumulators -+ static int const kPointerCount = 2; -+ -+ static_assert(sizeof(Element) == 4, "This can only be used with 32b accumulator data types (f32, s32)."); -+ }; -+ -+ /// Padding quantity -+ using Padding = MatrixShape<0, Detail::kLanesInQuad * 2>; -+ -+private: -+ -+ /// Storage type for accessing memory -+ using AccessType = AlignedArray; -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointers_[Detail::kPointerCount]; -+ -+ /// Stride in units of AccessType -+ int stride_; -+ -+public: -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int64_t i = 0; i < Detail::kPointerCount; ++i) { -+ pointers_[i] = nullptr; -+ } -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): -+ stride_(ref.stride()[0] / AccessType::kElements) { -+ -+ int quad_id = (lane_id / Detail::kLanesInQuad); -+ int lane_in_quad = (lane_id % Detail::kLanesInQuad); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Detail::kPointerCount; ++i) { -+ AccessType *ptr = reinterpret_cast(ref.data()) + quad_id * stride_; -+ int column_idx = lane_in_quad ^ (i * 2); -+ -+ ptr += column_idx; -+ -+ if (i == 0) { -+ pointers_[0] = ptr; -+ } -+ else if (i == 1) { -+ pointers_[1] = ptr; -+ } -+ } -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed & add_pointer_offset(Index pointer_offset) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int64_t i = 0; i < Detail::kPointerCount; ++i) { -+ pointers_[i] += pointer_offset / AccessType::kElements; -+ } -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed & add_tile_offset(TensorCoord const &tile_offset) { -+ -+ int ptr_offset = tile_offset.row() * Shape::kRow * stride_ + -+ tile_offset.column() * Shape::kColumn / AccessType::kElements; -+ -+ pointers_[0] += ptr_offset; -+ pointers_[1] += ptr_offset; -+ -+ if (tile_offset.column() % 2) { -+ auto tmp = pointers_[0]; -+ pointers_[0] = pointers_[1]; -+ pointers_[1] = tmp; -+ } -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed & operator+=(TensorCoord const &tile_offset) { -+ return add_tile_offset(tile_offset); -+ } -+ -+ /// Store -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ -+ int ptr_idx = (n / 4); -+ -+ AccessType *ptr; -+ if (ptr_idx == 0) { -+ ptr = pointers_[0]; -+ } -+ else if (ptr_idx == 1) { -+ ptr = pointers_[1]; -+ } -+ -+ int offset = (n / 4) * 16 + pointer_offset / AccessType::kElements + (n % 4) * 4; -+ -+#if 0 -+ // -+ // Using inline PTX to avoid generic memory -+ // -+ AccessType *smem_ptr = pointers_[ptr_idx]; -+ smem_ptr[offset] = frag_ptr[n]; -+#else -+ uint32_t smem_addr = arch::cutlass_get_smem_pointer(ptr); -+ uint32_t const *data = reinterpret_cast(frag_ptr + 
n); -+ uint32_t offset_in_bytes = offset * sizeof(AccessType); -+ -+ asm volatile( -+ "{ .reg .u32 smem_ptr; add.u32 smem_ptr, %0, %1; st.shared.v2.u32 [smem_ptr], {%2, %3}; }\n" -+ : : "r"(smem_addr), "r"(offset_in_bytes), "r"(data[0]), "r"(data[1]) -+ ); -+#endif -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#undef CUTLASS_EPILOGUE_WARP_TILE_ITERATOR_TENSOR_OP_MIXED_OPTIMIZATION_ENABLED -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h -new file mode 100644 -index 0000000..a4cabd7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h -@@ -0,0 +1,440 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+#include "cutlass/epilogue/warp/tensor_op_policy.h" -+#include "cutlass/epilogue/warp/volta_tensor_op_policy.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename InterleavedTileShape, ///< shape of indivisible instruction-level arrangement (concept: GemmShape) -+ typename ElementC, ///< Accumulator layout -+ typename Layout ///< target shared memory layout -+> -+struct TileIteratorVoltaTensorOp; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape_ ///< shape of warp-level GEMM (concept: MatrixShape) -+> -+struct TileIteratorVoltaTensorOp, half_t, layout::RowMajor> { -+public: -+ -+ using WarpShape = WarpShape_; -+ using InterleavedTileShape = gemm::GemmShape<32, 32, 4>; -+ using Element = half_t; -+ using Layout = layout::RowMajor; -+ -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = VoltaTensorOpPolicy; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape< -+ Policy::kRowsPerIteration, -+ WarpShape::kN -+ >; -+ -+ /// Array type for aligned memory accesses -+ using AccessType = typename Policy::AccessType; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = typename Policy::Fragment; -+ -+ /// This is the complete warp-level accumulator tile. 
-+ using AccumulatorTile = typename Policy::AccumulatorTile; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+ /// Number of elements per access -+ static int const kElementsPerAccess = Policy::kElementsPerAccess; -+ -+ // Internal constants -+ struct Detail { -+ static int const kLanesInQuad = 4; -+ static int const kRowsPerQuad = 4; -+ static int const kColumnsPerQuad = 8; -+ static int const kAccessesPerQuad = kColumnsPerQuad / Policy::kElementsPerAccess; -+ static int const kAccessQuadDelta = 16; -+ }; -+ -+ /// Padding quantity -+ using Padding = MatrixShape< -+ 0, -+ Policy::kElementsPerAccess>; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointer_; -+ -+ /// Internal layout object -+ Layout layout_; -+ -+public: -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorVoltaTensorOp(): pointer_(nullptr) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ TileIteratorVoltaTensorOp( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): -+ pointer_(reinterpret_cast(ref.data())), -+ layout_(ref.stride()[0] / Policy::kElementsPerAccess) { -+ -+ int quad_id = lane_id / Detail::kLanesInQuad; -+ int lane_in_quad = (lane_id % Detail::kLanesInQuad); -+ -+ int quad_row_idx = ((quad_id & 4) >> 1) + (quad_id & 1); -+ int quad_col_idx = ((quad_id & 2) >> 1); -+ -+ int row = quad_row_idx * Detail::kRowsPerQuad + lane_in_quad; -+ int column = quad_col_idx * Detail::kColumnsPerQuad; -+ -+ pointer_ += layout_({row, column / kElementsPerAccess}); -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorVoltaTensorOp & add_pointer_offset(Index pointer_offset) { -+ pointer_ += pointer_offset / Policy::kElementsPerAccess; -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorVoltaTensorOp & add_tile_offset(TensorCoord const &tile_offset) { -+ -+ pointer_ += layout_({ -+ tile_offset.row() * Shape::kRow, -+ tile_offset.column() * Shape::kColumn / Policy::kElementsPerAccess}); -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorVoltaTensorOp & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ /// Store -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_idx = 0; tile_idx < Policy::TileIterations::kColumn; ++tile_idx) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int access_idx = 0; access_idx < Policy::kAccessesPerInterleavedTile; ++access_idx) { -+ -+ int access_quad = access_idx / 2; -+ int access = access_idx % 2; -+ -+ int ptr_offset = tile_idx * InterleavedTileShape::kN / Policy::kElementsPerAccess + -+ access_quad * Detail::kAccessQuadDelta / Policy::kElementsPerAccess + -+ access + pointer_offset / Policy::kElementsPerAccess; -+ -+ int frag_idx = tile_idx * Policy::kAccessesPerInterleavedTile + access_idx; -+ -+ AccessType access_vector = frag_ptr[frag_idx]; -+ -+ pointer_[ptr_offset] = access_vector; -+ } -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment const &frag, Index pointer_offset) { 
-+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_idx = 0; tile_idx < Policy::TileIterations::kColumn; ++tile_idx) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int access_idx = 0; access_idx < Policy::kAccessesPerInterleavedTile; ++access_idx) { -+ -+ int access_quad = access_idx / 2; -+ int access = access_idx % 2; -+ -+ int ptr_offset = tile_idx * Detail::kTileDelta + access_quad * Detail::kAccessQuadDelta + -+ access + pointer_offset / Policy::kElementsPerAccess; -+ -+ int frag_idx = tile_idx * Policy::kAccessesPerInterleavedTile + access_idx; -+ -+ frag_ptr[frag_idx] = pointer_[ptr_offset]; -+ } -+ } -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load(Fragment const &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape_ ///< shape of warp-level GEMM (concept: MatrixShape) -+> -+struct TileIteratorVoltaTensorOp, float, layout::RowMajor> { -+public: -+ -+ using WarpShape = WarpShape_; -+ using InterleavedTileShape = gemm::GemmShape<32, 32, 4>; -+ using Element = float; -+ using Layout = layout::RowMajor; -+ -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = VoltaTensorOpPolicy; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape< -+ Policy::kRowsPerIteration, -+ WarpShape::kN -+ >; -+ -+ /// Array type for aligned memory accesses -+ using AccessType = typename Policy::AccessType; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = typename Policy::Fragment; -+ -+ /// This is the complete warp-level accumulator tile. 
-+ using AccumulatorTile = typename Policy::AccumulatorTile; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+ /// Number of elements per access -+ static int const kElementsPerAccess = Policy::kElementsPerAccess; -+ -+ // Internal constants -+ struct Detail { -+ static int const kLanesInQuad = 4; -+ static int const kRowsPerQuad = 4; -+ static int const kColumnsPerQuad = 8; -+ static int const kAccessesPerQuad = kColumnsPerQuad / Policy::kElementsPerAccess; -+ static int const kAccessQuadDelta = 16; -+ }; -+ -+ /// Padding quantity -+ using Padding = MatrixShape< -+ 0, -+ Policy::kElementsPerAccess>; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointer_; -+ -+ /// Internal layout object -+ Layout layout_; -+ -+public: -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorVoltaTensorOp(): pointer_(nullptr) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ TileIteratorVoltaTensorOp( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): -+ pointer_(reinterpret_cast(ref.data())), -+ layout_(ref.stride()[0] / Policy::kElementsPerAccess) { -+ -+ int quad_id = lane_id / Detail::kLanesInQuad; -+ int lane_in_quad = (lane_id % Detail::kLanesInQuad); -+ -+ int const kQuadRowDelta = 4; -+ int const kQuadColumnDelta = 2 * Policy::MmaIterations::kColumn; -+ -+ int quad_row_offset = ((quad_id & 4) / 2 + (quad_id & 1)) * kQuadRowDelta; -+ int quad_column_offset = (quad_id & 2) / 2 * kQuadColumnDelta; -+ -+ int thread_row_offset = (lane_in_quad & 1); -+ int thread_column_offset = (lane_in_quad & 2) / 2; -+ -+ int row = quad_row_offset + thread_row_offset; -+ int column = quad_column_offset + thread_column_offset; -+ -+ pointer_ += layout_({row, column}); -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorVoltaTensorOp & add_pointer_offset(Index pointer_offset) { -+ pointer_ += pointer_offset / Policy::kElementsPerAccess; -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorVoltaTensorOp & add_tile_offset(TensorCoord const &tile_offset) { -+ -+ pointer_ += layout_({ -+ tile_offset.row() * Shape::kRow, -+ tile_offset.column() * Shape::kColumn / Policy::kElementsPerAccess}); -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorVoltaTensorOp & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ /// Store -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ int const kAccessesPerRow = Policy::TileIterations::kColumn * Policy::MmaIterations::kColumn * 2; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row_idx = 0; row_idx < Policy::kRowsPerMmaTile; ++row_idx) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int access_idx = 0; access_idx < kAccessesPerRow; ++access_idx) { -+ -+ int frag_idx = row_idx * kAccessesPerRow + access_idx; -+ -+ int ptr_column_offset = (access_idx & 1) * 2 + -+ (access_idx & 2) * Policy::MmaIterations::kColumn * 2 + -+ (access_idx & 4) * Policy::MmaIterations::kColumn * 2; -+ -+ int ptr_row_offset = row_idx * 2; -+ -+ int ptr_offset = layout_({ptr_row_offset, ptr_column_offset}) + pointer_offset / Policy::kElementsPerAccess; -+ -+ pointer_[ptr_offset] = frag_ptr[frag_idx]; 
-+ } -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ assert(0); // TODO -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load(Fragment const &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h -new file mode 100644 -index 0000000..6856b3e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h -@@ -0,0 +1,227 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief -+*/ -+ -+#pragma once -+ -+#if !(defined(__clang__) && defined(__CUDA__)) -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/wmma_array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/tensor_ref.h" -+ -+#include "cutlass/epilogue/warp/wmma_tensor_op_policy.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorFragment, ///< wmma fragment to be written (concept: nvcuda::wmma::fragment) -+ typename Layout ///< target shared memory layout -+> -+class TileIteratorWmmaTensorOp; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) -+ typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorFragment_ ///< wmma fragment to be written (concept: nvcuda::wmma::fragment) -+> -+class TileIteratorWmmaTensorOp { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using OperatorFragment = OperatorFragment_; -+ using Layout = layout::RowMajor; -+ -+ // -+ // Derived types -+ // -+ using WmmaDataType = typename OperatorFragment::element_type; -+ using Element = typename cutlass::arch::WmmaToCutlassDataType::Type; ///< Data Type of element stored in nvcuda::wmma::frament -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = WmmaTensorOpPolicy; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape< -+ Policy::kRowsPerIteration, -+ WarpShape::kN -+ >; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = WmmaFragmentArray; -+ -+ -+ /// This is the complete warp-level accumulator tile. 
-+ //using AccumulatorTile = typename Operator::FragmentC; -+ -+ -+ /// Padding quantity -+ // (Epilogue shared memory padding for WMMA Gemm kernel is set to run optimaly on Turing) -+ using Padding = MatrixShape< -+ 0, -+ 4 * Policy::kElementsPerAccess -+ >; -+ -+private: -+ -+ /// Storage type for accessing memory -+ //using AccessType = AlignedArray; -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to shared memory -+ TensorRef ref_; -+ -+ -+public: -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorWmmaTensorOp(): ref_(nullptr) { -+ -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorWmmaTensorOp( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): ref_(ref) { -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorWmmaTensorOp & add_pointer_offset(Index pointer_offset) { -+ ref_.add_pointer_offset(pointer_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorWmmaTensorOp & add_tile_offset(TensorCoord const &tile_offset) { -+ ref_.add_coord_offset({tile_offset.row() * OperatorShape::kM, tile_offset.column() * WarpShape::kN}); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorWmmaTensorOp & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ for(int n=0; n < Policy::OperatorCount::kColumn; n++) { -+ -+ WmmaDataType* ptr = reinterpret_cast (ref_.data() + ref_.offset({0, n * OperatorShape::kN}) + pointer_offset); -+ -+ nvcuda::wmma::store_matrix_sync( -+ ptr, -+ frag[n], -+ ref_.stride()[0], -+ nvcuda::wmma::layout_t::mem_row_major -+ ); -+ -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ for(int n=0; n < Policy::OperatorCount::kColumn; n++) { -+ -+ WmmaDataType* ptr = reinterpret_cast (ref_.data() + ref_.offset({0, n * OperatorShape::kN}) + pointer_offset); -+ -+ nvcuda::wmma::load_matrix_sync( -+ frag[n], -+ ptr, -+ ref_.stride()[0], -+ nvcuda::wmma::layout_t::mem_row_major -+ ); -+ -+ } -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // !defined(__clang__) -+ -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/volta_tensor_op_policy.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/volta_tensor_op_policy.h -new file mode 100644 -index 0000000..dede3fd ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/volta_tensor_op_policy.h -@@ -0,0 +1,195 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic structures needed for implementing the warp-scoped phase of the epilogue. -+ These quantities assume a 'column-major' arrangement of TensorOp instructions, of which -+ a row-oriented slice is visible per iteration. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/matrix_shape.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Policy details related to the epilogue -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename InterleavedTileShape, ///< shape of indivisible instruction-level arrangement (concept: GemmShape) -+ typename ElementC, ///< Accumulator layout -+ typename Layout ///< target shared memory layout -+> -+struct VoltaTensorOpPolicy; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for row-major -+template < -+ typename WarpShape_ ///< shape of warp-level GEMM (concept: GemmShape) -+> -+struct VoltaTensorOpPolicy, half_t, layout::RowMajor> { -+ -+ using WarpShape = WarpShape_; -+ using InterleavedTileShape = gemm::GemmShape<32, 32, 4>; -+ using ElementC = half_t; -+ using Layout = layout::RowMajor; -+ -+ /// Shape of one warp-levelinstruction -+ using InstructionShape = gemm::GemmShape<16, 16, 4>; -+ -+ /// Number of mma operations performed for one 32x32x4 interleaved tile -+ using MmaIterations = MatrixShape< -+ InterleavedTileShape::kM / InstructionShape::kM, -+ InterleavedTileShape::kN / InstructionShape::kN -+ >; -+ -+ /// Number of 32x32x4 interleaved tiles performed to cover the warp-level GEMM shape -+ using TileIterations = MatrixShape< -+ WarpShape::kM / InterleavedTileShape::kM, -+ WarpShape::kN / InterleavedTileShape::kN -+ >; -+ -+ /// Number of accumulator elements owned by each thread per Mma -+ static int const kElementsPerMma = 8; -+ static int const kRowsPerIteration = 16; -+ -+ // -+ // Hard-coded constants regarding Tensor Operations -+ // -+ -+ /// Number of accumulator elements stored per memory instruction to shared memory -+ static int const kElementsPerAccess = 4; -+ -+ /// Number of accesses performed per interleaved tile -+ static int const kAccessesPerInterleavedTile = 4; -+ -+ /// Total number of iterations needed to cover the entire tile -+ static int const kIterations = TileIterations::kRow * 2; -+ -+ // -+ // Derived types -+ // -+ -+ /// Array type for aligned memory accesses -+ using AccessType = AlignedArray; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ ElementC, -+ kElementsPerAccess * kAccessesPerInterleavedTile * TileIterations::kColumn>; -+ -+ /// This is the complete warp-level accumulator tile. 
-+ using AccumulatorTile = Array< -+ ElementC, -+ TileIterations::kCount * MmaIterations::kCount * kElementsPerMma>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for row-major -+template < -+ typename WarpShape_ ///< shape of warp-level GEMM (concept: MatrixShape) -+> -+struct VoltaTensorOpPolicy, float, layout::RowMajor> { -+ -+ using WarpShape = WarpShape_; -+ using InterleavedTileShape = gemm::GemmShape<32, 32, 4>; -+ using ElementC = float; -+ using Layout = layout::RowMajor; -+ -+ /// Shape of one warp-levelinstruction -+ using InstructionShape = gemm::GemmShape<16, 16, 4>; -+ -+ /// Number of mma operations performed for one 32x32x4 interleaved tile -+ using MmaIterations = MatrixShape< -+ InterleavedTileShape::kM / InstructionShape::kM, -+ InterleavedTileShape::kN / InstructionShape::kN -+ >; -+ -+ /// Number of 32x32x4 interleaved tiles performed to cover the warp-level GEMM shape -+ using TileIterations = MatrixShape< -+ WarpShape::kM / InterleavedTileShape::kM, -+ WarpShape::kN / InterleavedTileShape::kN -+ >; -+ -+ /// Number of accumulator elements owned by each thread per Mma -+ static int const kElementsPerMma = 8; -+ static int const kRowsPerIteration = 16; -+ -+ // -+ // Hard-coded constants regarding Tensor Operations -+ // -+ -+ /// Number of accumulator elements stored per memory instruction to shared memory -+ static int const kElementsPerAccess = 2; -+ -+ /// Number of accesses performed per interleaved tile -+ static int const kAccessesPerInterleavedTile = 8; -+ -+ /// Number of rows per interleaved tile -+ static int const kRowsPerMmaTile = 2; -+ -+ /// Total number of iterations needed to cover the entire tile -+ static int const kIterations = TileIterations::kRow * MmaIterations::kRow; -+ -+ // -+ // Derived types -+ // -+ -+ /// Array type for aligned memory accesses -+ using AccessType = AlignedArray; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ ElementC, -+ kElementsPerAccess * kAccessesPerInterleavedTile * TileIterations::kColumn>; -+ -+ /// This is the complete warp-level accumulator tile. -+ using AccumulatorTile = Array< -+ ElementC, -+ TileIterations::kCount * MmaIterations::kCount * kElementsPerMma>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/wmma_tensor_op_policy.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/wmma_tensor_op_policy.h -new file mode 100644 -index 0000000..bbce5cb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/wmma_tensor_op_policy.h -@@ -0,0 +1,101 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic structures needed for implementing the warp-scoped phase of the epilogue. -+ These quantities assume a 'column-major' arrangement of TensorOp instructions, of which -+ a row-oriented slice is visible per iteration. -+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/wmma.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/layout/matrix.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Policy details related to the epilogue -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape, ///< matrix multiply operation shape (concept: gemm:GemmShape) -+ typename Layout ///< target shared memory layout -+> -+struct WmmaTensorOpPolicy; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for row-major -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape ///< matrix multiply operation shape (concept: gemm::GemmShape) -+> -+struct WmmaTensorOpPolicy { -+ -+ /// Number of operations -+ using OperatorCount = MatrixShape< -+ WarpShape::kM / OperatorShape::kM, -+ WarpShape::kN / OperatorShape::kN -+ >; -+ -+ // -+ // Hard-coded constants regarding Tensor Operations -+ // -+ static int const kElementsPerAccess = 2; -+ static int const kRowsPerIteration = OperatorShape::kM; -+ static int const kWmmaFragmentsPerAccess = 1; -+ -+ // -+ // Derived quantities -+ // -+ -+ // Number of externally visible iterations -+ static int const kIterations = OperatorCount::kRow; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif -+ -diff --git a/3rdparty/cutlass/include/cutlass/fast_math.h 
b/3rdparty/cutlass/include/cutlass/fast_math.h -new file mode 100644 -index 0000000..c449def ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/fast_math.h -@@ -0,0 +1,975 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#include -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/uint128.h" -+#include "cutlass/coord.h" -+#include "cutlass/numeric_types.h" -+ -+/** -+ * \file -+ * \brief Math utilities -+ */ -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+CUTLASS_HOST_DEVICE void swap(T &lhs, T &rhs) { -+ T tmp = lhs; -+ lhs = rhs; -+ rhs = tmp; -+} -+ -+/****************************************************************************** -+ * Static math utilities -+ ******************************************************************************/ -+ -+/// Mixed precision dot product -+template -+CUTLASS_HOST_DEVICE LongIndex dot( -+ Coord const &coord, -+ Coord const &stride, -+ LongIndex acc = LongIndex()) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < N; ++n) { -+ acc += LongIndex(coord[n]) * stride[n]; -+ } -+ return acc; -+} -+ -+/** -+ * Statically determine if N is a power-of-two -+ */ -+template -+struct is_pow2 { -+ static bool const value = ((N & (N - 1)) == 0); -+}; -+ -+/** -+ * Statically determine log2(N), rounded down -+ */ -+template -+struct log2_down { -+ /// Static logarithm value -+ enum { value = log2_down> 1), Count + 1>::value }; -+}; -+ -+// Base case -+template -+struct log2_down { -+ enum { value = Count }; -+}; -+ -+/** -+ * Statically determine log2(N), rounded up -+ */ -+template -+struct log2_up { -+ /// Static logarithm value -+ enum { value = log2_up> 1), Count + 1>::value }; -+}; -+ -+// Base case -+template -+struct log2_up { -+ enum { value = ((1 << Count) < N) ? Count + 1 : Count }; -+}; -+ -+/** -+ * Statically estimate sqrt(N) to the nearest power-of-two -+ */ -+template -+struct sqrt_est { -+ enum { value = 1 << (log2_up::value / 2) }; -+}; -+ -+/** -+ * For performing a constant-division with a compile-time assertion that the -+ * Divisor evenly-divides the Dividend. -+ */ -+template -+struct divide_assert { -+ enum { value = Dividend / Divisor }; -+ -+ static_assert((Dividend % Divisor == 0), "Not an even multiple"); -+}; -+ -+/****************************************************************************** -+ * Rounding -+ ******************************************************************************/ -+ -+/** -+ * Round dividend up to the nearest multiple of divisor -+ */ -+template -+CUTLASS_HOST_DEVICE dividend_t round_nearest(dividend_t dividend, divisor_t divisor) { -+ return ((dividend + divisor - 1) / divisor) * divisor; -+} -+ -+/** -+ * Greatest common divisor -+ */ -+template -+CUTLASS_HOST_DEVICE value_t gcd(value_t a, value_t b) { -+ for (;;) { -+ if (a == 0) return b; -+ b %= a; -+ if (b == 0) return a; -+ a %= b; -+ } -+} -+ -+/** -+ * Least common multiple -+ */ -+template -+CUTLASS_HOST_DEVICE value_t lcm(value_t a, value_t b) { -+ value_t temp = gcd(a, b); -+ -+ return temp ? 
(a / temp * b) : 0; -+} -+ -+/// Returns the smallest value in the half-open range [a, a+b) that is a multiple of b -+CUTLASS_HOST_DEVICE -+constexpr int round_up(int a, int b) { -+ return ((a + b - 1) / b) * b; -+} -+ -+/// Returns the ceiling of (a / b) -+CUTLASS_HOST_DEVICE -+constexpr int ceil_div(int a, int b) { -+ return (a + b - 1) / b; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/** -+ * log2 computation, what's the -+ * difference between the below codes and -+ * log2_up/down codes? -+ */ -+template -+CUTLASS_HOST_DEVICE value_t clz(value_t x) { -+ for (int i = 31; i >= 0; --i) { -+ if ((1 << i) & x) return 31 - i; -+ } -+ return 32; -+} -+ -+template -+CUTLASS_HOST_DEVICE value_t find_log2(value_t x) { -+ int a = int(31 - clz(x)); -+ a += (x & (x - 1)) != 0; // Round up, add 1 if not a power of 2. -+ return a; -+} -+ -+ -+/** -+ * Find divisor, using find_log2 -+ */ -+CUTLASS_HOST_DEVICE -+void find_divisor(unsigned int& mul, unsigned int& shr, unsigned int denom) { -+ if (denom == 1) { -+ mul = 0; -+ shr = 0; -+ } else { -+ unsigned int p = 31 + find_log2(denom); -+ unsigned m = unsigned(((1ull << p) + unsigned(denom) - 1) / unsigned(denom)); -+ -+ mul = m; -+ shr = p - 32; -+ } -+} -+ -+/** -+ * Find quotient and remainder using device-side intrinsics -+ */ -+CUTLASS_HOST_DEVICE -+void fast_divmod(int& quo, int& rem, int src, int div, unsigned int mul, unsigned int shr) { -+ -+ #if defined(__CUDA_ARCH__) -+ // Use IMUL.HI if div != 1, else simply copy the source. -+ quo = (div != 1) ? __umulhi(src, mul) >> shr : src; -+ #else -+ quo = int((div != 1) ? int(((int64_t)src * mul) >> 32) >> shr : src); -+ #endif -+ -+ // The remainder. -+ rem = src - (quo * div); -+} -+ -+// For long int input -+CUTLASS_HOST_DEVICE -+void fast_divmod(int& quo, int64_t& rem, int64_t src, int div, unsigned int mul, unsigned int shr) { -+ -+ #if defined(__CUDA_ARCH__) -+ // Use IMUL.HI if div != 1, else simply copy the source. -+ quo = (div != 1) ? __umulhi(src, mul) >> shr : src; -+ #else -+ quo = int((div != 1) ? ((src * mul) >> 32) >> shr : src); -+ #endif -+ // The remainder. -+ rem = src - (quo * div); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Object to encapsulate the fast division+modulus operation. -+/// -+/// This object precomputes two values used to accelerate the computation and is best used -+/// when the divisor is a grid-invariant. In this case, it may be computed in host code and -+/// marshalled along other kernel arguments using the 'Params' pattern. -+/// -+/// Example: -+/// -+/// -+/// int quotient, remainder, dividend, divisor; -+/// -+/// FastDivmod divmod(divisor); -+/// -+/// divmod(quotient, remainder, dividend); -+/// -+/// // quotient = (dividend / divisor) -+/// // remainder = (dividend % divisor) -+/// -+struct FastDivmod { -+ -+ int divisor; -+ unsigned int multiplier; -+ unsigned int shift_right; -+ -+ /// Find quotient and remainder using device-side intrinsics -+ CUTLASS_HOST_DEVICE -+ void fast_divmod(int& quotient, int& remainder, int dividend) const { -+ -+#if defined(__CUDA_ARCH__) -+ // Use IMUL.HI if divisor != 1, else simply copy the source. -+ quotient = (divisor != 1) ? __umulhi(dividend, multiplier) >> shift_right : dividend; -+#else -+ quotient = int((divisor != 1) ? int(((int64_t)dividend * multiplier) >> 32) >> shift_right : dividend); -+#endif -+ -+ // The remainder. 
-+ remainder = dividend - (quotient * divisor); -+ } -+ -+ /// For long int input -+ CUTLASS_HOST_DEVICE -+ void fast_divmod(int& quotient, int64_t& remainder, int64_t dividend) const { -+ -+#if defined(__CUDA_ARCH__) -+ // Use IMUL.HI if divisor != 1, else simply copy the source. -+ quotient = (divisor != 1) ? __umulhi(dividend, multiplier) >> shift_right : dividend; -+#else -+ quotient = int((divisor != 1) ? ((dividend * multiplier) >> 32) >> shift_right : dividend); -+#endif -+ // The remainder. -+ remainder = dividend - (quotient * divisor); -+ } -+ -+ -+ /// Construct the FastDivmod object, in host code ideally. -+ /// -+ /// This precomputes some values based on the divisor and is computationally expensive. -+ -+ CUTLASS_HOST_DEVICE -+ FastDivmod(): divisor(0), multiplier(0), shift_right(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ FastDivmod(int divisor): divisor(divisor) { -+ -+ if (divisor != 1) { -+ unsigned int p = 31 + find_log2(divisor); -+ unsigned m = unsigned(((1ull << p) + unsigned(divisor) - 1) / unsigned(divisor)); -+ -+ multiplier = m; -+ shift_right = p - 32; -+ } else { -+ multiplier = 0; -+ shift_right = 0; -+ } -+ } -+ -+ /// Computes integer division and modulus using precomputed values. This is computationally -+ /// inexpensive. -+ CUTLASS_HOST_DEVICE -+ void operator()(int "ient, int &remainder, int dividend) const { -+ fast_divmod(quotient, remainder, dividend); -+ } -+ -+ /// Computes integer division using precomputed values. This is computationally -+ /// inexpensive. -+ CUTLASS_HOST_DEVICE -+ int div(int dividend) const { -+ int quotient, remainder; -+ fast_divmod(quotient, remainder, dividend); -+ return quotient; -+ } -+ -+ /// Computes integer division and modulus using precomputed values. This is computationally -+ /// inexpensive. -+ /// -+ /// Simply returns the quotient -+ CUTLASS_HOST_DEVICE -+ int divmod(int &remainder, int dividend) const { -+ int quotient; -+ fast_divmod(quotient, remainder, dividend); -+ return quotient; -+ } -+ -+ /// Computes integer division and modulus using precomputed values. This is computationally -+ /// inexpensive. -+ CUTLASS_HOST_DEVICE -+ void operator()(int "ient, int64_t &remainder, int64_t dividend) const { -+ fast_divmod(quotient, remainder, dividend); -+ } -+ -+ /// Computes integer division and modulus using precomputed values. This is computationally -+ /// inexpensive. -+ CUTLASS_HOST_DEVICE -+ int divmod(int64_t &remainder, int64_t dividend) const { -+ int quotient; -+ fast_divmod(quotient, remainder, dividend); -+ return quotient; -+ } -+ -+ /// Returns the divisor when cast to integer -+ CUTLASS_HOST_DEVICE -+ operator int() const { return divisor; } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Object to encapsulate the fast division+modulus operation for 64b integer division. -+/// -+/// This object precomputes two values used to accelerate the computation and is best used -+/// when the divisor is a grid-invariant. In this case, it may be computed in host code and -+/// marshalled along other kernel arguments using the 'Params' pattern. 
-+/// -+/// Example: -+/// -+/// -+/// uint64_t quotient, remainder, dividend, divisor; -+/// -+/// FastDivmodU64 divmod(divisor); -+/// -+/// divmod(quotient, remainder, dividend); -+/// -+/// // quotient = (dividend / divisor) -+/// // remainder = (dividend % divisor) -+/// -+struct FastDivmodU64 { -+ -+ uint64_t divisor; -+ uint64_t multiplier; -+ unsigned int shift_right; -+ unsigned int round_up; -+ -+ // -+ // Static methods -+ // -+ -+ /// Computes b, where 2^b is the greatest power of two that is less than or equal to x -+ CUTLASS_HOST_DEVICE -+ static uint32_t integer_log2(uint64_t x) { -+ uint32_t n = 0; -+ while (x >>= 1) { -+ ++n; -+ } -+ return n; -+ } -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ FastDivmodU64(): divisor(0), multiplier(0), shift_right(0), round_up(0) { } -+ -+ /// Construct the FastDivmod object, in host code ideally. -+ /// -+ /// This precomputes some values based on the divisor and is computationally expensive. -+ CUTLASS_HOST_DEVICE -+ FastDivmodU64(uint64_t divisor_): divisor(divisor_), multiplier(1), shift_right(0), round_up(0) { -+ -+ if (divisor) { -+ shift_right = integer_log2(divisor); -+ -+ if ((divisor & (divisor - 1)) == 0) { -+ multiplier = 0; -+ } -+ else { -+ uint64_t power_of_two = (uint64_t(1) << shift_right); -+ uint64_t multiplier_lo = uint128_t(0, power_of_two) / divisor; -+ multiplier = uint128_t(power_of_two, power_of_two) / divisor; -+ round_up = (multiplier_lo == multiplier ? 1 : 0); -+ } -+ } -+ } -+ -+ /// Returns the quotient of floor(dividend / divisor) -+ CUTLASS_HOST_DEVICE -+ uint64_t divide(uint64_t dividend) const { -+ uint64_t quotient = 0; -+ -+ #ifdef __CUDA_ARCH__ -+ uint64_t x = dividend; -+ if (multiplier) { -+ x = __umul64hi(dividend + round_up, multiplier); -+ } -+ quotient = (x >> shift_right); -+ #else -+ // TODO - use proper 'fast' division here also. No reason why x86-code shouldn't be optimized. -+ quotient = dividend / divisor; -+ #endif -+ -+ return quotient; -+ } -+ -+ /// Computes the remainder given a computed quotient and dividend -+ CUTLASS_HOST_DEVICE -+ uint64_t modulus(uint64_t quotient, uint64_t dividend) const { -+ return uint32_t(dividend - quotient * divisor); -+ } -+ -+ /// Returns the quotient of floor(dividend / divisor) and computes the remainder -+ CUTLASS_HOST_DEVICE -+ uint64_t divmod(uint64_t &remainder, uint64_t dividend) const { -+ uint64_t quotient = divide(dividend); -+ remainder = modulus(quotient, dividend); -+ return quotient; -+ } -+ -+ /// Computes integer division and modulus using precomputed values. This is computationally -+ /// inexpensive. -+ CUTLASS_HOST_DEVICE -+ void operator()(uint64_t "ient, uint64_t &remainder, uint64_t dividend) const { -+ quotient = divmod(remainder, dividend); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes the coordinate decomposition from a linear index (64-bit linear index => coord) -+/// -+/// This decomposition is accelerated by the FastDivmodU64 object. It is assumed that -+/// a coordinate of indices can be decomposed by div/mod operations. -+/// Note, is assumed that element divmod[0] divides by extent[1]. -+/// -+/// For example, assume 4-D coordinate (n, p, q, c) is mapped to a linear index `npqc`. 
This -+/// can be decomposed via three divide and modulus operations: -+/// -+/// c = npqc % C; | divmod[2] = FastDivmodU64(C) -+/// npq = npqc / C; | coord[3] = c -+/// -+/// q = npq % Q; | divmod[1] = FastDivmodU64(Q) -+/// np = npq / Q; | coord[2] = q -+/// -+/// p = np % P; | divmod[0] = FastDivmodU64(P) -+/// n = np / P; | coord[1] = p -+/// -+/// | coord[0] = n -+/// -+template -+CUTLASS_HOST_DEVICE Coord CoordinateDecomposition( -+ uint64_t linear_idx, ///< Linear index to decompose -+ FastDivmodU64 const *divmod) { ///< Pointer to array of Rank-1 FastDivmodU64 objects -+ -+ static_assert(Rank > 0, "CoordinateDecomposition requires Rank=1 or greater."); -+ -+ Coord coord; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = Rank; i > 1; --i) { -+ uint64_t remainder; -+ linear_idx = divmod[i - 2].divmod(remainder, linear_idx); -+ coord[i - 1] = int(remainder); -+ } -+ -+ coord[0] = int(linear_idx); -+ -+ return coord; -+} -+ -+/// Computes the coordinate decomposition from a linear index (32-bit linear index => coord) -+template -+CUTLASS_HOST_DEVICE Coord CoordinateDecomposition( -+ int linear_idx, ///< Linear index to decompose -+ FastDivmod const *divmod) { ///< Pointer to array of Rank-1 FastDivmodU64 objects -+ -+ static_assert(Rank > 0, "CoordinateDecomposition requires Rank=1 or greater."); -+ -+ Coord coord; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = Rank; i > 1; --i) { -+ int remainder; -+ linear_idx = divmod[i - 2].divmod(remainder, linear_idx); -+ coord[i - 1] = int(remainder); -+ } -+ -+ coord[0] = int(linear_idx); -+ -+ return coord; -+} -+ -+template -+CUTLASS_HOST_DEVICE Coord CoordinateDecompositionLittleEndian( -+ uint64_t linear_idx, ///< Linear index to decompose -+ FastDivmodU64 const *divmod) { ///< Pointer to array of Rank-1 FastDivmodU64 objects -+ -+ static_assert(Rank > 0, "CoordinateDecomposition requires Rank=1 or greater."); -+ -+ Coord coord; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Rank - 1; ++i) { -+ uint64_t remainder; -+ linear_idx = divmod[i].divmod(remainder, linear_idx); -+ coord[i] = int(remainder); -+ } -+ -+ coord[Rank - 1] = int(linear_idx); -+ -+ return coord; -+} -+ -+/// Computes the coordinate decomposition from a linear index (32-bit linear index => coord) -+template -+CUTLASS_HOST_DEVICE Coord CoordinateDecompositionLittleEndian( -+ int linear_idx, ///< Linear index to decompose -+ FastDivmod const *divmod) { ///< Pointer to array of Rank-1 FastDivmodU64 objects -+ -+ static_assert(Rank > 0, "CoordinateDecomposition requires Rank=1 or greater."); -+ -+ Coord coord; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Rank - 1; ++i) { -+ int remainder; -+ linear_idx = divmod[i].divmod(remainder, linear_idx); -+ coord[i] = int(remainder); -+ } -+ -+ coord[Rank - 1] = int(linear_idx); -+ -+ return coord; -+} -+ -+/// Safely computes the offset of a linear index in bytes for all types -+template -+CUTLASS_HOST_DEVICE int64_t OffsetBytes(int64_t index) { -+ -+ static_assert( -+ (sizeof_bits::value >= 8 && !(sizeof_bits::value % 8)) || -+ (sizeof_bits::value < 8 && !(8 % sizeof_bits::value)), -+ "Size of numeric type in bits must either be divisible by 8 bits, or 8 bits must be divisible by the size."); -+ -+ if (sizeof_bits::value >= 8) { -+ return index * (sizeof_bits::value / 8); -+ } -+ else { -+ int const kElementsPerByte = ((8 / sizeof_bits::value) + ((sizeof_bits::value >= 8) ? 
1 : 0)); -+ return index / kElementsPerByte; -+ } -+} -+ -+CUTLASS_HOST_DEVICE int64_t OffsetBytes(int64_t index, int64_t element_sizeof_bits) { -+ if (element_sizeof_bits >= 8) { -+ return index * (element_sizeof_bits / 8); -+ } -+ else { -+ int64_t const kElementsPerByte = ((8 / element_sizeof_bits) + ((element_sizeof_bits >= 8) ? 1 : 0)); -+ return index / kElementsPerByte; -+ } -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Min/Max -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct Min { -+ static int const kValue = (A < B) ? A : B; -+}; -+ -+template -+struct Max { -+ static int const kValue = (A > B) ? A : B; -+}; -+ -+CUTLASS_HOST_DEVICE -+constexpr int const_min(int a, int b) { -+ return (b < a ? b : a); -+} -+ -+CUTLASS_HOST_DEVICE -+constexpr int const_max(int a, int b) { -+ return (b > a ? b : a); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+T fast_min(T a, T b) { -+ return (b < a ? b : a); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+float fast_min(float a, float b) { -+ return fminf(a, b); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+T fast_max(T a, T b) { -+ return (a < b ? b : a); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+float fast_max(float a, float b) { -+ return fmaxf(a, b); -+} -+ -+CUTLASS_HOST_DEVICE -+float fast_cos(float theta) { -+ #if defined(__CUDA_ARCH__) -+ return ::cosf(theta); -+ #else -+ return std::cos(theta); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+double fast_cos(double theta) { -+ #if defined(__CUDA_ARCH__) -+ return ::cos(theta); -+ #else -+ return std::cos(theta); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+float fast_sin(float theta) { -+ #if defined(__CUDA_ARCH__) -+ return ::sinf(theta); -+ #else -+ return std::sin(theta); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+double fast_sin(double theta) { -+ #if defined(__CUDA_ARCH__) -+ return ::sin(theta); -+ #else -+ return std::sin(theta); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+float fast_acos(float theta) { -+ #if defined(__CUDA_ARCH__) -+ return ::acosf(theta); -+ #else -+ return std::acos(theta); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+double fast_acos(double theta) { -+ #if defined(__CUDA_ARCH__) -+ return ::acos(theta); -+ #else -+ return std::acos(theta); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+float fast_asin(float theta) { -+ #if defined(__CUDA_ARCH__) -+ return ::asinf(theta); -+ #else -+ return std::asin(theta); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+double fast_asin(double theta) { -+ #if defined(__CUDA_ARCH__) -+ return ::asin(theta); -+ #else -+ return std::asin(theta); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+float fast_sqrt(float theta) { -+ #if defined(__CUDA_ARCH__) -+ return ::sqrtf(theta); -+ #else -+ return std::sqrt(theta); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+double fast_sqrt(double theta) { -+ #if defined(__CUDA_ARCH__) -+ return ::sqrt(theta); -+ #else -+ return std::sqrt(theta); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+float fast_exp(float x) { -+ #if defined(__CUDA_ARCH__) -+ return ::expf(x); -+ #else -+ return std::exp(x); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+double fast_exp(double x) { -+ #if defined(__CUDA_ARCH__) -+ return ::exp(x); -+ #else -+ return std::exp(x); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+half_t fast_exp(half_t x) { -+ #if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDA_ARCH__ >= 750) -+ return (half_t)(::hexp(x.to_half())); -+ #else -+ return (half_t)(fast_exp(float(x))); -+ #endif -+} -+ 
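To make the intended use of the helpers above concrete, here is a minimal host-only sketch, not part of the patch itself: it checks FastDivmodU64 against plain integer division. The include path is an assumption based on the directory layout this diff introduces; on the host the divide() fallback is exercised, while device code would take the __umul64hi path shown above.

    // Host-side sanity check for FastDivmodU64 (illustrative sketch; include path assumed).
    #include <cassert>
    #include <cstdint>
    #include "cutlass/fast_math.h"

    int main() {
      uint64_t divisor = 48;                     // e.g. a channel extent
      cutlass::FastDivmodU64 divmod(divisor);    // precompute multiplier/shift once, ideally on the host

      for (uint64_t dividend : {0ull, 1ull, 47ull, 48ull, 1000003ull}) {
        uint64_t quotient, remainder;
        divmod(quotient, remainder, dividend);   // same results as '/' and '%', cheaper in device code
        assert(quotient == dividend / divisor);
        assert(remainder == dividend % divisor);
      }
      return 0;
    }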
-+CUTLASS_HOST_DEVICE -+float fast_log(float x) { -+ #if defined(__CUDA_ARCH__) -+ return ::logf(x); -+ #else -+ return std::log(x); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+double fast_log(double x) { -+ #if defined(__CUDA_ARCH__) -+ return ::log(x); -+ #else -+ return std::log(x); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+float fast_tanh(float x) { -+ #if defined(__CUDA_ARCH__) -+ #if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750) -+ float y; -+ asm volatile ( "tanh.approx.f32 %0, %1; " : "=f"(y) : "f"(x)); -+ return y; -+ #else -+ return ::tanhf(x); -+ #endif -+ #else -+ return std::tanh(x); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+double fast_tanh(double x) { -+ #if defined(__CUDA_ARCH__) -+ return ::tanh(x); -+ #else -+ return std::tanh(x); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+half_t fast_tanh(half_t x) { -+ #if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750) -+ -+ asm volatile ( "tanh.approx.f16 %0, %1;" : "=h"(x.raw()) : "h"(x.raw())); -+ return x; -+ -+ #else -+ return half_t(fast_tanh(float(x))); -+ #endif -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct fast_exp_op { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &rhs) const { -+ return fast_exp(rhs); -+ } -+}; -+ -+#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDA_ARCH__ >= 750) -+template -+struct fast_exp_op> { -+ CUTLASS_DEVICE -+ Array operator()(Array const &rhs) const { -+ -+ Array result; -+ -+ // use x2 specialization -+ __half2 const *in = reinterpret_cast<__half2 const *>(&rhs); -+ __half2 *out = reinterpret_cast<__half2 *>(&result); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ out[i] = ::h2exp(in[i]); -+ } -+ -+ // residual -+ if (N % 2) { -+ half_t last = rhs[N - 1]; -+ result[N - 1] = half_t(::hexp(last.to_half())); -+ } -+ -+ return result; -+ } -+}; -+#endif // #if defined(__CUDA_ARCH__) -+ -+template -+struct fast_exp_op> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &rhs) const { -+ -+ fast_exp_op fast_op; -+ Array y; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ y[i] = fast_op(rhs[i]); -+ } -+ -+ return y; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct fast_tanh_op { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &rhs) const { -+ return fast_tanh(rhs); -+ } -+}; -+ -+#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750) -+template -+struct fast_tanh_op> { -+ CUTLASS_DEVICE -+ Array operator()(Array const &rhs) const { -+ -+ Array result; -+ -+ // use x2 specialization -+ uint32_t const *in = reinterpret_cast(&rhs); -+ uint32_t *out = reinterpret_cast(&result); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ asm volatile ("tanh.approx.f16x2 %0, %1;" : "=r"(out[i]) : "r"(in[i])); -+ } -+ -+ // residual -+ if (N % 2) { -+ uint16_t const *in = reinterpret_cast(&rhs); -+ uint16_t *out = reinterpret_cast(&result); -+ asm volatile ("tanh.approx.f16 %0, %1;" : "=h"(out[N - 1]) : "h"(in[N - 1])); -+ } -+ -+ return result; -+ } -+}; -+#endif // #if defined(__CUDA_ARCH__) -+ -+template -+struct fast_tanh_op> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &rhs) const { -+ -+ fast_tanh_op fast_op; -+ Array y; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ y[i] = fast_op(rhs[i]); -+ } -+ -+ return y; -+ } -+}; -+ 
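The element-wise functors above follow the same pattern, so a short host-side sketch may help; it is illustrative rather than part of the patch, and the header names are assumed from this diff's layout. It applies fast_tanh_op to a small cutlass::Array: the generic specialization loops over elements and calls the scalar fast_tanh, while in SM75+ device code the half_t specializations above would emit tanh.approx instead.

    // Applying fast_tanh_op element-wise on the host (illustrative sketch; includes assumed).
    #include <cstdio>
    #include "cutlass/array.h"
    #include "cutlass/fast_math.h"

    int main() {
      cutlass::Array<float, 4> x;
      x[0] = -2.0f; x[1] = -0.5f; x[2] = 0.5f; x[3] = 2.0f;

      cutlass::fast_tanh_op<cutlass::Array<float, 4>> op;  // generic Array specialization
      cutlass::Array<float, 4> y = op(x);                  // y[i] = std::tanh(x[i]) on the host

      for (int i = 0; i < 4; ++i) {
        std::printf("tanh(%f) = %f\n", double(x[i]), double(y[i]));
      }
      return 0;
    }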
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Absolute value function -+template -+CUTLASS_HOST_DEVICE -+T absolute_value(T x) { -+ if (x < T()) { -+ return -x; -+ } -+ return x; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/float8.h b/3rdparty/cutlass/include/cutlass/float8.h -new file mode 100644 -index 0000000..93e3209 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/float8.h -@@ -0,0 +1,1215 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Defines a class for using IEEE half-precision floating-point types in host or -+ device code. -+*/ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#if defined(__CUDACC_RTC__) -+ -+#include "cutlass/floating_point_nvrtc.h" -+ -+#else -+// -+// Standard Library headers belong here to avoid conflicts with NVRTC. -+// -+#include -+#include -+#include -+#include -+#endif -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if (__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)) -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) -+#ifndef CUDA_PTX_FP8_CVT_ENABLED -+#define CUDA_PTX_FP8_CVT_ENABLED 1 -+#endif -+#endif -+#endif -+ -+#ifdef __GNUC__ -+// Ignore checks on reinterpret-casts that are being used for bitcasts. 
-+#pragma GCC diagnostic ignored "-Wstrict-aliasing" -+#endif -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// FP8 Has 2 encodings possible : E4M3 and E5M2 -+// -+// E4M3 : 7 | 6 5 4 3 | 2 1 0 -+// E5M2 : 7 | 6 5 4 3 2 | 1 0 -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+enum class FloatEncoding { -+ E4M3, -+ E5M2 -+}; -+ -+template -+struct alignas(1) float8_base { -+ -+ static constexpr bool IS_E4M3 = (T == FloatEncoding::E4M3); -+ static constexpr bool IS_E5M2 = (T == FloatEncoding::E5M2); -+ -+ // Number of Bits representing mantissa and exponents -+ static constexpr int FP32_NUM_BITS = 32; -+ static constexpr int FP32_NUM_EXPONENT_BITS = 8; -+ static constexpr int FP32_NUM_MANTISSA_BITS = 23; -+ static constexpr uint32_t FP32_NAN = 0x7fffffff; -+ static constexpr uint32_t FP32_INFINITY_MASK = 0x7f800000; -+ static constexpr int FP32_MAX_EXPONENT = 127; -+ static constexpr int FP32_MIN_EXPONENT = -126; -+ static constexpr int FP32_EXPONENT_BIAS = 127; -+ -+ static constexpr int FP16_NUM_BITS = 16; -+ static constexpr int FP16_NUM_EXPONENT_BITS = 5; -+ static constexpr int FP16_NUM_MANTISSA_BITS = 10; -+ static constexpr uint16_t FP16_NAN = 0x7fff; -+ static constexpr uint16_t FP16_INFINITY_MASK = 0x7c00; -+ static constexpr int FP16_MAX_EXPONENT = 15; -+ static constexpr int FP16_MIN_EXPONENT = -14; -+ static constexpr int FP16_EXPONENT_BIAS = 15; -+ -+ static constexpr int FP8_NUM_BITS = 8; -+ static constexpr int FP8_NUM_EXPONENT_BITS = IS_E4M3 ? 4 : 5; -+ static constexpr int FP8_NUM_MANTISSA_BITS = IS_E4M3 ? 3 : 2; -+ static constexpr uint8_t FP8_NAN = 0x7f; // Also F8_INF -+ static constexpr uint8_t FP8_INFINITY_MASK = IS_E4M3 ? 0x78 : 0x7c; -+ static constexpr int FP8_MAX_EXPONENT = IS_E4M3 ? 7 : 15; -+ static constexpr int FP8_MIN_EXPONENT = IS_E4M3 ? -6 : -14; -+ static constexpr int FP8_EXPONENT_BIAS = IS_E4M3 ? 7 : 15; -+ -+ static constexpr uint8_t FP8_EXPONENT_MASK = (1 << FP8_NUM_EXPONENT_BITS) - 1; -+ static constexpr uint8_t FP8_MANTISSA_MASK = (1 << FP8_NUM_MANTISSA_BITS) - 1; -+ -+ static constexpr uint8_t FP8_MAX_FLT = (IS_E4M3 ? 0x7e : 0x7b); -+ -+ // 256 in float -+ static constexpr uint32_t FP8_SAT_VAL_FP32 = 0x43800000; -+ -+ // -+ // Data members -+ // -+ -+ /// Data container -+ uint8_t storage; -+ -+ /// Ctors. 
-+ CUTLASS_HOST_DEVICE -+ float8_base() : storage(0) { } -+ -+ /// Is finite implementation -+ CUTLASS_HOST_DEVICE -+ static bool isfinite(float flt) { -+ uint32_t s; -+ -+ #if defined(__CUDA_ARCH__) -+ s = reinterpret_cast(flt); -+ #else -+ std::memcpy(&s, &flt, sizeof(s)); -+ #endif -+ -+ return (s & 0x7f800000) < 0x7f800000; -+ } -+ -+ /// Is NaN implementation -+ CUTLASS_HOST_DEVICE -+ static bool isnan(float flt) { -+ uint32_t s; -+ -+ #if defined(__CUDA_ARCH__) -+ s = reinterpret_cast(flt); -+ #else -+ std::memcpy(&s, &flt, sizeof(s)); -+ #endif -+ -+ return (s & 0x7fffffff) > 0x7f800000; -+ } -+ -+ /// Is infinite implementation -+ CUTLASS_HOST_DEVICE -+ static bool isinf(float flt) { -+ uint32_t s; -+ -+ #if defined(__CUDA_ARCH__) -+ s = reinterpret_cast(flt); -+ #else -+ std::memcpy(&s, &flt, sizeof(s)); -+ #endif -+ -+ // Sign = 0 for +inf, 1 for -inf -+ // Exponent = all ones -+ // Mantissa = all zeros -+ return (s == 0x7f800000) || (s == 0xff800000); -+ } -+ -+ /// FP32 -> FP8 conversion - rounds to nearest even -+ CUTLASS_HOST_DEVICE -+ static uint8_t convert_float_to_fp8(float const& flt) { -+ -+ // software implementation rounds toward nearest even -+ uint32_t s; -+ -+ #if defined(__CUDA_ARCH__) -+ s = reinterpret_cast(flt); -+ #else -+ std::memcpy(&s, &flt, sizeof(s)); -+ #endif -+ -+ // Extract the bits in the FP32 type -+ uint8_t sign = uint8_t((s >> 24 & 0x80)); -+ int8_t exp = uint8_t(((s >> FP32_NUM_MANTISSA_BITS) & 0xff) - FP32_EXPONENT_BIAS); -+ int mantissa = s & 0x7fffff; -+ uint8_t u = 0; -+ -+ uint8_t const kF8_NaN = 0x7f; -+ -+ // NaN => NaN -+ if (isnan(flt)) { -+ return kF8_NaN; -+ } -+ -+ // Inf => MAX_FLT (satfinite) -+ if (isinf(flt)) { -+ return sign | FP8_MAX_FLT; -+ } -+ -+ // Special handling -+ if ( exp == -128 ) { -+ // int8 range is from -128 to 127 -+ // So 255(inf) - 127(bias) = 128 - will show up as -128 -+ -+ // satfinite -+ return (sign | FP8_MAX_FLT); -+ } -+ -+ int sticky_bit = 0; -+ -+ bool skip_sign = false; -+ bool may_be_nan = false; -+ -+ if ( (exp >= FP8_MIN_EXPONENT) && (exp <= FP8_MAX_EXPONENT) ) { -+ // normal fp32 to normal fp8 -+ exp = uint8_t(exp + uint8_t(FP8_EXPONENT_BIAS)); -+ u = uint8_t(((exp & FP8_EXPONENT_MASK) << FP8_NUM_MANTISSA_BITS)); -+ u = uint8_t(u | (mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS))); -+ } else if(exp < FP8_MIN_EXPONENT) { -+ // normal single-precision to subnormal float8-precision representation -+ int rshift = (FP8_MIN_EXPONENT - exp); -+ if (rshift < FP32_NUM_BITS) { -+ mantissa |= (1 << FP32_NUM_MANTISSA_BITS); -+ -+ sticky_bit = ((mantissa & ((1 << rshift) - 1)) != 0); -+ -+ mantissa = (mantissa >> rshift); -+ u = (uint8_t(mantissa >> (FP32_NUM_MANTISSA_BITS- FP8_NUM_MANTISSA_BITS)) & FP8_MANTISSA_MASK); -+ } else { -+ mantissa = 0; -+ u = 0; -+ } -+ // Exponent > FP8_MAX_EXPONENT - this is a special case done to match HW -+ // 0x4380_0000 to 0x43e0_0000 - maps from 256 to 448, and does not saturate / inf. 
-+ } else { -+ if( exp == (FP8_MAX_EXPONENT + 1) ) { -+ uint8_t mantissa_tmp = uint8_t(mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS)); -+ if( mantissa_tmp < FP8_MANTISSA_MASK) { -+ exp = uint8_t(exp + uint8_t(FP8_EXPONENT_BIAS)); -+ u = uint8_t(exp << FP8_NUM_MANTISSA_BITS) | mantissa_tmp; -+ may_be_nan = (mantissa_tmp == (FP8_MANTISSA_MASK-1)); -+ } else { -+ // satfinite -+ return (sign | FP8_MAX_FLT); -+ } -+ } else{ -+ // satfinite -+ return (sign | FP8_MAX_FLT); -+ } -+ } -+ -+ // round to nearest even -+ int NUM_BITS_SHIFT = FP32_NUM_MANTISSA_BITS - (FP8_NUM_MANTISSA_BITS + 1); -+ int round_bit = ((mantissa >> NUM_BITS_SHIFT) & 1); -+ sticky_bit |= ((mantissa & ((1 << NUM_BITS_SHIFT) - 1)) != 0); -+ -+ if ((round_bit && sticky_bit) || (round_bit && (u & 1))) { -+ u = uint8_t(u + 1); -+ if( may_be_nan ) { -+ skip_sign = true; -+ } -+ } -+ -+ if (u > FP8_MAX_FLT) { -+ // satfinite -+ u = (sign | FP8_MAX_FLT); -+ } -+ -+ if( ! skip_sign ) { -+ u |= sign; -+ } -+ -+ return u; -+ } -+ -+ -+ /// Converts a fp8 value stored as a uint8_t to a float -+ CUTLASS_HOST_DEVICE -+ static float convert_fp8_to_float(uint8_t const& x) { -+ -+ uint32_t constexpr kF32_NaN = 0x7fffffff; -+ -+ uint8_t const &f8 = x; -+ int sign = (f8 >> (FP8_NUM_BITS - 1)) & 1; -+ int exp = (f8 >> FP8_NUM_MANTISSA_BITS) & FP8_EXPONENT_MASK; -+ int mantissa = f8 & FP8_MANTISSA_MASK; -+ unsigned f = (sign << (FP32_NUM_BITS-1)); -+ -+ if (IS_E4M3 && exp == 15 && mantissa == 0x7) { -+ f = kF32_NaN; -+ } -+ else if (exp > 0 && (IS_E4M3 || exp < (FP8_MAX_EXPONENT + FP8_EXPONENT_BIAS + 1))) { -+ // normal -+ exp += (FP32_EXPONENT_BIAS - FP8_EXPONENT_BIAS); -+ f = f | -+ (exp << FP32_NUM_MANTISSA_BITS) | -+ (mantissa << (FP32_NUM_MANTISSA_BITS-FP8_NUM_MANTISSA_BITS)); -+ } else if (exp == 0) { -+ if (mantissa) { -+ // subnormal -+ exp += (FP32_EXPONENT_BIAS - FP8_EXPONENT_BIAS) + 1; -+ while ((mantissa & (1 << FP8_NUM_MANTISSA_BITS)) == 0) { -+ mantissa <<= 1; -+ exp--; -+ } -+ mantissa &= FP8_MANTISSA_MASK; -+ f = f | -+ (exp << FP32_NUM_MANTISSA_BITS) | -+ (mantissa << (FP32_NUM_MANTISSA_BITS-FP8_NUM_MANTISSA_BITS)); -+ } else { -+ // sign-preserving zero -+ } -+ } else { -+ if(mantissa == 0){ -+ // Sign-preserving infinity -+ f = (f | 0x7f800000); -+ } else { -+ // Canonical NaN -+ f = kF32_NaN; -+ } -+ } -+ -+ #if defined(__CUDA_ARCH__) -+ return reinterpret_cast(f); -+ #else -+ float flt; -+ std::memcpy(&flt, &f, sizeof(flt)); -+ return flt; -+ #endif -+ } -+}; -+ -+ -+// Forward declaration of float_e5m2_t to define float_e4m3_t <=> float_e5m2_t -+// conversions in class float_e4m3_t -+struct float_e5m2_t; -+ -+ -+/////////////////////////////////////////////////////////////// -+/// -+/// floating-point 8 type : E4M3 -+/// -+/////////////////////////////////////////////////////////////// -+struct alignas(1) float_e4m3_t : float8_base { -+ -+ using Base = float8_base; -+ -+ static constexpr int MAX_EXPONENT = Base::FP8_MAX_EXPONENT; -+ -+ // -+ // Static conversion operators -+ // -+ -+ /// Constructs from an uint8_t -+ CUTLASS_HOST_DEVICE -+ static float_e4m3_t bitcast(uint8_t x) { -+ float_e4m3_t f; -+ f.storage = x; -+ return f; -+ } -+ -+ /// FP32 -> FP8 conversion - rounds to nearest even -+ CUTLASS_HOST_DEVICE -+ static float_e4m3_t from_float(float const& flt) { -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint16_t tmp; -+ float y = float(); -+ asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(tmp) : "f"(y), "f"(flt)); -+ -+ return *reinterpret_cast(&tmp); -+ #else -+ return 
bitcast(Base::convert_float_to_fp8(flt)); -+ #endif -+ } -+ -+ /// FP16 -> E5M2 conversion - rounds to nearest even -+ CUTLASS_HOST_DEVICE -+ static float_e4m3_t from_half(half const& flt) { -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint16_t tmp = 0; -+ uint32_t bits = reinterpret_cast(flt); -+ asm volatile("cvt.rn.satfinite.e4m3x2.f16x2 %0, %1;" : "=h"(tmp) : "r"(bits)); -+ -+ return *reinterpret_cast(&tmp); -+ #else -+ return bitcast(Base::convert_float_to_fp8(float(flt))); -+ #endif -+ } -+ -+ // E4M3 -> half -+ CUTLASS_HOST_DEVICE -+ static half to_half(float_e4m3_t const& x) { -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint16_t bits = x.storage; -+ uint32_t packed; -+ asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;\n" : "=r"(packed) : "h"(bits)); -+ -+ return reinterpret_cast(packed).x; -+ #else -+ return half(Base::convert_fp8_to_float(x.storage)); -+ #endif -+ } -+ -+ // E4M3 -> Float -+ CUTLASS_HOST_DEVICE -+ static float to_float(float_e4m3_t const& x) { -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint16_t bits = x.storage; -+ uint32_t packed; -+ asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;\n" : "=r"(packed) : "h"(bits)); -+ -+ return float(reinterpret_cast(packed).x); -+ #else -+ return Base::convert_fp8_to_float(x.storage); -+ #endif -+ } -+ -+ // -+ // Methods -+ // -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ float_e4m3_t() : Base() { } -+ -+ /// Reinterpret cast from CUDA's FP8 type -+ CUTLASS_HOST_DEVICE -+ float_e4m3_t(float_e4m3_t const& x) { -+ #if defined(__CUDA_ARCH__) -+ storage = reinterpret_cast(x); -+ #else -+ uint8_t raw = x.storage; -+ std::memcpy(&storage, &raw, sizeof(storage)); -+ #endif -+ } -+ -+ /// Floating point conversion -+ CUTLASS_HOST_DEVICE -+ explicit float_e4m3_t(float x) { -+ storage = from_float(x).storage; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ explicit float_e4m3_t(half x) { -+ storage = from_half(x).storage; -+ } -+ -+ /// Floating point conversion -+ CUTLASS_HOST_DEVICE -+ explicit float_e4m3_t(double x): float_e4m3_t(float(x)) { -+ } -+ -+ /// Integer conversion -+ CUTLASS_HOST_DEVICE -+ explicit float_e4m3_t(int x): float_e4m3_t(float(x)) { -+ } -+ -+ /// E5M2 conversion. Defined after float_e5m2_t is defined. 
-+ CUTLASS_HOST_DEVICE -+ explicit float_e4m3_t(float_e5m2_t x); -+ -+ /// Assignment -+ CUTLASS_HOST_DEVICE -+ float_e4m3_t & operator=(float_e4m3_t const &x) { -+ #if defined(__CUDA_ARCH__) -+ storage = reinterpret_cast(x); -+ #else -+ uint8_t raw = x.storage; -+ std::memcpy(&storage, &raw, sizeof(storage)); -+ #endif -+ return *this; -+ } -+ -+ /// Converts to float -+ CUTLASS_HOST_DEVICE -+ operator float() const { -+ return to_float(*this); -+ } -+ -+ /// Converts to half -+ CUTLASS_HOST_DEVICE -+ operator half() const { -+ return to_half(*this); -+ } -+ -+ /// Converts to float -+ CUTLASS_HOST_DEVICE -+ explicit operator double() const { -+ return double(to_float(*this)); -+ } -+ -+ /// Converts to int -+ CUTLASS_HOST_DEVICE -+ explicit operator int() const { -+ #if defined(__CUDA_ARCH__) -+ return __half2int_rn(to_half(*this)); -+ #else -+ return int(to_float(*this)); -+ #endif -+ } -+ -+ /// Casts to bool -+ CUTLASS_HOST_DEVICE -+ explicit operator bool() const { -+ #if defined(__CUDA_ARCH__) -+ return bool(__half2int_rn(to_half(*this))); -+ #else -+ return bool(int(to_float(*this))); -+ #endif -+ } -+ -+ /// Accesses raw internal state -+ CUTLASS_HOST_DEVICE -+ uint8_t& raw() { -+ return storage; -+ } -+ -+ /// Accesses raw internal state -+ CUTLASS_HOST_DEVICE -+ uint8_t raw() const { -+ return storage; -+ } -+ -+ /// Returns the sign bit -+ CUTLASS_HOST_DEVICE -+ bool signbit() const { -+ return ((storage & (1 << (Base::FP8_NUM_BITS - 1))) != 0); -+ } -+ -+ /// Returns the biased exponent -+ CUTLASS_HOST_DEVICE -+ int exponent_biased() const { -+ return int((storage >> FP8_NUM_MANTISSA_BITS) & Base::FP8_EXPONENT_MASK); -+ } -+ -+ /// Returns the unbiased exponent -+ CUTLASS_HOST_DEVICE -+ int exponent() const { -+ return exponent_biased() - 15; -+ } -+ -+ /// Returns the mantissa -+ CUTLASS_HOST_DEVICE -+ int mantissa() const { -+ return int(storage & Base::FP8_MANTISSA_MASK); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////// -+/// -+/// floating-point 8 type : E5M2 -+/// -+/////////////////////////////////////////////////////////////// -+struct alignas(1) float_e5m2_t : float8_base { -+ -+ using Base = float8_base; -+ -+ static constexpr int MAX_EXPONENT = Base::FP8_MAX_EXPONENT; -+ -+ // -+ // Static conversion operators -+ // -+ -+ /// Constructs from an uint8_t -+ CUTLASS_HOST_DEVICE -+ static float_e5m2_t bitcast(uint8_t x) { -+ float_e5m2_t f; -+ f.storage = x; -+ return f; -+ } -+ -+ /// FP32 -> FP8 conversion - rounds to nearest even -+ CUTLASS_HOST_DEVICE -+ static float_e5m2_t from_float(float const& flt) { -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint16_t tmp; -+ float y = float(); -+ asm volatile("cvt.rn.satfinite.e5m2x2.f32 %0, %1, %2;" : "=h"(tmp) : "f"(y), "f"(flt)); -+ -+ return *reinterpret_cast(&tmp); -+ #else -+ return bitcast(Base::convert_float_to_fp8(flt)); -+ #endif -+ } -+ -+ /// FP16 -> E5M2 conversion - rounds to nearest even -+ CUTLASS_HOST_DEVICE -+ static float_e5m2_t from_half(half const& flt) { -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint16_t tmp = 0; -+ uint32_t bits = reinterpret_cast(flt); -+ asm volatile("cvt.rn.satfinite.e5m2x2.f16x2 %0, %1;" : "=h"(tmp) : "r"(bits)); -+ -+ return *reinterpret_cast(&tmp); -+ #else -+ return bitcast(Base::convert_float_to_fp8(float(flt))); -+ #endif -+ } -+ -+ // E5M2 -> half -+ CUTLASS_HOST_DEVICE -+ static half to_half(float_e5m2_t const& x) { -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint16_t bits = x.storage; -+ uint32_t packed; -+ asm volatile("cvt.rn.f16x2.e5m2x2 %0, 
%1;\n" : "=r"(packed) : "h"(bits)); -+ -+ return reinterpret_cast(packed).x; -+ #else -+ return half(Base::convert_fp8_to_float(x.storage)); -+ #endif -+ } -+ -+ // E5M2 -> Float -+ CUTLASS_HOST_DEVICE -+ static float to_float(float_e5m2_t const& x) { -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint16_t bits = x.storage; -+ uint32_t packed; -+ asm volatile("cvt.rn.f16x2.e5m2x2 %0, %1;\n" : "=r"(packed) : "h"(bits)); -+ -+ return float(reinterpret_cast(packed).x); -+ #else -+ return Base::convert_fp8_to_float(x.storage); -+ #endif -+ } -+ -+ // -+ // Methods -+ // -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ float_e5m2_t() : Base() { } -+ -+ /// Reinterpret cast from CUDA's FP8 type -+ CUTLASS_HOST_DEVICE -+ float_e5m2_t(float_e5m2_t const& x) { -+ #if defined(__CUDA_ARCH__) -+ storage = reinterpret_cast(x); -+ #else -+ uint8_t raw = x.storage; -+ std::memcpy(&storage, &raw, sizeof(storage)); -+ #endif -+ } -+ -+ /// Floating point conversion -+ CUTLASS_HOST_DEVICE -+ explicit float_e5m2_t(float x) { -+ storage = from_float(x).storage; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ explicit float_e5m2_t(half x) { -+ storage = from_half(x).storage; -+ } -+ -+ /// Floating point conversion -+ CUTLASS_HOST_DEVICE -+ explicit float_e5m2_t(double x): float_e5m2_t(float(x)) { -+ } -+ -+ /// Integer conversion -+ CUTLASS_HOST_DEVICE -+ explicit float_e5m2_t(int x): float_e5m2_t(float(x)) { -+ } -+ -+ /// E4M3 conversion -+ CUTLASS_HOST_DEVICE -+ explicit float_e5m2_t(float_e4m3_t x); -+ -+ /// Assignment -+ CUTLASS_HOST_DEVICE -+ float_e5m2_t & operator=(float_e5m2_t const &x) { -+ #if defined(__CUDA_ARCH__) -+ storage = reinterpret_cast(x); -+ #else -+ uint8_t raw = x.storage; -+ std::memcpy(&storage, &raw, sizeof(storage)); -+ #endif -+ return *this; -+ } -+ -+ /// Converts to float -+ CUTLASS_HOST_DEVICE -+ operator float() const { -+ return to_float(*this); -+ } -+ -+ /// Converts to half -+ CUTLASS_HOST_DEVICE -+ operator half() const { -+ return to_half(*this); -+ } -+ -+ /// Converts to float -+ CUTLASS_HOST_DEVICE -+ explicit operator double() const { -+ return double(to_float(*this)); -+ } -+ -+ /// Converts to int -+ CUTLASS_HOST_DEVICE -+ explicit operator int() const { -+ #if defined(__CUDA_ARCH__) -+ return __half2int_rn(to_half(*this)); -+ #else -+ return int(to_float(*this)); -+ #endif -+ } -+ -+ /// Casts to bool -+ CUTLASS_HOST_DEVICE -+ explicit operator bool() const { -+ #if defined(__CUDA_ARCH__) -+ return bool(__half2int_rn(to_half(*this))); -+ #else -+ return bool(int(to_float(*this))); -+ #endif -+ } -+ -+ /// Accesses raw internal state -+ CUTLASS_HOST_DEVICE -+ uint8_t& raw() { -+ return storage; -+ } -+ -+ /// Accesses raw internal state -+ CUTLASS_HOST_DEVICE -+ uint8_t raw() const { -+ return storage; -+ } -+ -+ /// Returns the sign bit -+ CUTLASS_HOST_DEVICE -+ bool signbit() const { -+ return ((storage & (1 << (Base::FP8_NUM_BITS - 1))) != 0); -+ } -+ -+ /// Returns the biased exponent -+ CUTLASS_HOST_DEVICE -+ int exponent_biased() const { -+ return int((storage >> FP8_NUM_MANTISSA_BITS) & Base::FP8_EXPONENT_MASK); -+ } -+ -+ /// Returns the unbiased exponent -+ CUTLASS_HOST_DEVICE -+ int exponent() const { -+ return exponent_biased() - 15; -+ } -+ -+ /// Returns the mantissa -+ CUTLASS_HOST_DEVICE -+ int mantissa() const { -+ return int(storage & Base::FP8_MANTISSA_MASK); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Arithmetic operators -+// 
-+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_HOST_DEVICE -+bool operator==(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { -+ return float(lhs) == float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator!=(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { -+ return float(lhs) != float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator<(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { -+ return float(lhs) < float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator<=(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { -+ return float(lhs) <= float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator>(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { -+ return float(lhs) > float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator>=(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { -+ return float(lhs) >= float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t operator+(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { -+ return float_e4m3_t(float(lhs) + float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t operator-(float_e4m3_t const& lhs) { -+ return float_e4m3_t(-float(lhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t operator-(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { -+ return float_e4m3_t(float(lhs) - float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t operator*(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { -+ return float_e4m3_t(float(lhs) * float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t operator/(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { -+ return float_e4m3_t(float(lhs) / float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t& operator+=(float_e4m3_t & lhs, float_e4m3_t const& rhs) { -+ lhs = float_e4m3_t(float(lhs) + float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t& operator-=(float_e4m3_t & lhs, float_e4m3_t const& rhs) { -+ lhs = float_e4m3_t(float(lhs) - float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t& operator*=(float_e4m3_t & lhs, float_e4m3_t const& rhs) { -+ lhs = float_e4m3_t(float(lhs) * float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t& operator/=(float_e4m3_t & lhs, float_e4m3_t const& rhs) { -+ lhs = float_e4m3_t(float(lhs) / float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t& operator++(float_e4m3_t & lhs) { -+ float tmp(lhs); -+ ++tmp; -+ lhs = float_e4m3_t(tmp); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t& operator--(float_e4m3_t & lhs) { -+ float tmp(lhs); -+ --tmp; -+ lhs = float_e4m3_t(tmp); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t operator++(float_e4m3_t & lhs, int) { -+ float_e4m3_t ret(lhs); -+ float tmp(lhs); -+ tmp++; -+ lhs = float_e4m3_t(tmp); -+ return ret; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t operator--(float_e4m3_t & lhs, int) { -+ float_e4m3_t ret(lhs); -+ float tmp(lhs); -+ tmp--; -+ lhs = float_e4m3_t(tmp); -+ return ret; -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator==(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { -+ return float(lhs) == float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator!=(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { -+ return float(lhs) != float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator<(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { -+ return float(lhs) < float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator<=(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { -+ return float(lhs) <= float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool 
operator>(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { -+ return float(lhs) > float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator>=(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { -+ return float(lhs) >= float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t operator+(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { -+ return float_e5m2_t(float(lhs) + float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t operator-(float_e5m2_t const& lhs) { -+ return float_e5m2_t(-float(lhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t operator-(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { -+ return float_e5m2_t(float(lhs) - float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t operator*(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { -+ return float_e5m2_t(float(lhs) * float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t operator/(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { -+ return float_e5m2_t(float(lhs) / float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t& operator+=(float_e5m2_t & lhs, float_e5m2_t const& rhs) { -+ lhs = float_e5m2_t(float(lhs) + float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t& operator-=(float_e5m2_t & lhs, float_e5m2_t const& rhs) { -+ lhs = float_e5m2_t(float(lhs) - float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t& operator*=(float_e5m2_t & lhs, float_e5m2_t const& rhs) { -+ lhs = float_e5m2_t(float(lhs) * float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t& operator/=(float_e5m2_t & lhs, float_e5m2_t const& rhs) { -+ lhs = float_e5m2_t(float(lhs) / float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t& operator++(float_e5m2_t & lhs) { -+ float tmp(lhs); -+ ++tmp; -+ lhs = float_e5m2_t(tmp); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t& operator--(float_e5m2_t & lhs) { -+ float tmp(lhs); -+ --tmp; -+ lhs = float_e5m2_t(tmp); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t operator++(float_e5m2_t & lhs, int) { -+ float_e5m2_t ret(lhs); -+ float tmp(lhs); -+ tmp++; -+ lhs = float_e5m2_t(tmp); -+ return ret; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t operator--(float_e5m2_t & lhs, int) { -+ float_e5m2_t ret(lhs); -+ float tmp(lhs); -+ tmp--; -+ lhs = float_e5m2_t(tmp); -+ return ret; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// float_e4m3_t <=> float_e5m2_t conversions -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// float_e4m3_t <= float_e5m2_t -+CUTLASS_HOST_DEVICE -+float_e4m3_t::float_e4m3_t(float_e5m2_t x) { -+ storage = from_float(float_e5m2_t::to_float(x)).storage; -+} -+ -+/// float_e5m2_t <= float_e4m3_t -+CUTLASS_HOST_DEVICE -+float_e5m2_t::float_e5m2_t(float_e4m3_t x) { -+ storage = from_float(float_e4m3_t::to_float(x)).storage; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Standard Library operations and definitions -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if !defined(__CUDACC_RTC__) -+namespace std { -+ -+/// Numeric limits common to all float8 types -+template -+struct float8_base_numeric_limits { -+private: -+ using F8Type = T; -+public: -+ static bool const is_specialized = true; -+ static 
bool const is_signed = true; -+ static bool const is_integer = false; -+ static bool const is_exact = false; -+ static bool const has_quiet_NaN = true; -+ static bool const has_signaling_NaN = false; -+ static std::float_denorm_style const has_denorm = std::denorm_present; -+ static bool const has_denorm_loss = true; -+ static std::float_round_style const round_style = std::round_to_nearest; -+ static bool const is_iec559 = false; -+ static bool const is_bounded = true; -+ static bool const is_modulo = false; -+ static int const digits = F8Type::FP8_NUM_MANTISSA_BITS; -+ -+ /// Least positive value -+ static F8Type min() { return F8Type::bitcast(0x01); } -+ -+ /// Maximum finite value -+ static F8Type max() { return F8Type::bitcast(F8Type::FP8_MAX_FLT); } -+ -+ /// Returns maximum rounding error -+ static F8Type round_error() { return F8Type(0.5f); } -+ -+ /// Returns positive infinity value -+ static F8Type infinity() { return F8Type::bitcast(F8Type::FP8_INFINITY_MASK); } -+ -+ /// Returns quiet NaN value -+ static F8Type quiet_NaN() { return F8Type::bitcast(F8Type::FP8_NAN); } -+ -+ /// Returns signaling NaN value -+ static F8Type signaling_NaN() { return F8Type::bitcast(F8Type::FP8_NAN); } -+ -+ /// Returns smallest positive subnormal value -+ static F8Type denorm_min() { return F8Type::bitcast(0x01); } -+}; -+ -+/// Numeric limits for float_e4m3_t -+template <> -+struct numeric_limits : -+ public float8_base_numeric_limits { -+ static bool const has_infinity = false; -+ -+ /// Minimum finite value -+ static cutlass::float_e4m3_t lowest() { return cutlass::float_e4m3_t::bitcast(0xfe); } -+ -+ /// Returns smallest finite value -+ static cutlass::float_e4m3_t epsilon() { return cutlass::float_e4m3_t::bitcast(0x20); } -+}; -+ -+/// Numeric limits for float_e5m2_t -+template <> -+struct numeric_limits : -+ public float8_base_numeric_limits { -+ static bool const has_infinity = true; -+ -+ /// Minimum finite value -+ static cutlass::float_e5m2_t lowest() { return cutlass::float_e5m2_t::bitcast(0xfb); } -+ -+ /// Returns smallest finite value -+ static cutlass::float_e5m2_t epsilon() { return cutlass::float_e5m2_t::bitcast(0x34); } -+}; -+ -+} // namespace std -+#endif -+ -+namespace platform { -+ -+/// Numeric limits common to all float8 types -+template -+struct float8_base_numeric_limits { -+private: -+ using F8Type = T; -+public: -+ static bool const is_specialized = true; -+ static bool const is_signed = true; -+ static bool const is_integer = false; -+ static bool const is_exact = false; -+ static bool const has_quiet_NaN = true; -+ static bool const has_signaling_NaN = false; -+#if !defined(__CUDACC_RTC__) -+ static std::float_denorm_style const has_denorm = std::denorm_present; -+#endif -+ static bool const has_denorm_loss = true; -+#if !defined(__CUDACC_RTC__) -+ static std::float_round_style const round_style = std::round_to_nearest; -+#endif -+ static bool const is_iec559 = false; -+ static bool const is_bounded = true; -+ static bool const is_modulo = false; -+ static int const digits = F8Type::FP8_NUM_MANTISSA_BITS; -+ -+ /// Least positive value -+ static F8Type min() { return F8Type::bitcast(0x01); } -+ -+ /// Maximum finite value -+ static F8Type max() { return F8Type::bitcast(F8Type::FP8_MAX_FLT); } -+ -+ /// Returns maximum rounding error -+ static F8Type round_error() { return F8Type(0.5f); } -+ -+ /// Returns positive infinity value -+ static F8Type infinity() { return F8Type::bitcast(F8Type::FP8_INFINITY_MASK); } -+ -+ /// Returns quiet NaN value -+ static F8Type 
quiet_NaN() { return F8Type::bitcast(F8Type::FP8_NAN); } -+ -+ /// Returns signaling NaN value -+ static F8Type signaling_NaN() { return F8Type::bitcast(F8Type::FP8_NAN); } -+ -+ /// Returns smallest positive subnormal value -+ static F8Type denorm_min() { return F8Type::bitcast(0x01); } -+}; -+ -+/// std::numeric_limits -+template -+struct numeric_limits; -+ -+/// Numeric limits for float_e4m3_t -+template <> -+struct numeric_limits : -+ public float8_base_numeric_limits { -+ static bool const has_infinity = false; -+ -+ /// Minimum finite value -+ static cutlass::float_e4m3_t lowest() { return cutlass::float_e4m3_t::bitcast(0xfe); } -+ -+ /// Returns smallest finite value -+ static cutlass::float_e4m3_t epsilon() { return cutlass::float_e4m3_t::bitcast(0x20); } -+}; -+ -+/// Numeric limits for float_e5m2_t -+template <> -+struct numeric_limits : -+ public float8_base_numeric_limits { -+ static bool const has_infinity = true; -+ -+ /// Minimum finite value -+ static cutlass::float_e5m2_t lowest() { return cutlass::float_e5m2_t::bitcast(0xfb); } -+ -+ /// Returns smallest finite value -+ static cutlass::float_e5m2_t epsilon() { return cutlass::float_e5m2_t::bitcast(0x34); } -+}; -+ -+} // namespace platform -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// User-defined literals -+// -+ -+CUTLASS_HOST_DEVICE -+cutlass::float_e4m3_t operator "" _fe4m3(long double x) { -+ return cutlass::float_e4m3_t(float(x)); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::float_e4m3_t operator "" _fe4m3(unsigned long long int x) { -+ return cutlass::float_e4m3_t(int(x)); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::float_e5m2_t operator "" _fe5m2(long double x) { -+ return cutlass::float_e5m2_t(float(x)); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::float_e5m2_t operator "" _fe5m2(unsigned long long int x) { -+ return cutlass::float_e5m2_t(int(x)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/floating_point_nvrtc.h b/3rdparty/cutlass/include/cutlass/floating_point_nvrtc.h -new file mode 100644 -index 0000000..99deff5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/floating_point_nvrtc.h -@@ -0,0 +1,65 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Defines categories for floating point numbers for use in NVRTC-compiled code -+*/ -+ -+#pragma once -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// All floating-point numbers can be put in one of these categories. -+enum { -+ FP_NAN = -+# define FP_NAN 0 -+ FP_NAN, -+ FP_INFINITE = -+# define FP_INFINITE 1 -+ FP_INFINITE, -+ FP_ZERO = -+# define FP_ZERO 2 -+ FP_ZERO, -+ FP_SUBNORMAL = -+# define FP_SUBNORMAL 3 -+ FP_SUBNORMAL, -+ FP_NORMAL = -+# define FP_NORMAL 4 -+ FP_NORMAL -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/functional.h b/3rdparty/cutlass/include/cutlass/functional.h -new file mode 100644 -index 0000000..277bad5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/functional.h -@@ -0,0 +1,490 @@ -+ /*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Define basic numeric operators -+ -+ This is inspired by the Standard Library's header. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/half.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include -+#endif // defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct absolute_value_op { -+ CUTLASS_HOST_DEVICE -+ T operator()(T lhs) const { -+ return abs(lhs); -+ } -+}; -+ -+template <> -+struct absolute_value_op { -+ CUTLASS_HOST_DEVICE -+ float operator()(float lhs) const { return fabs(lhs); } -+}; -+ -+template -+struct plus { -+ CUTLASS_HOST_DEVICE -+ T operator()(T lhs, T const &rhs) const { -+ lhs += rhs; -+ return lhs; -+ } -+}; -+ -+template -+struct minus { -+ CUTLASS_HOST_DEVICE -+ T operator()(T lhs, T const &rhs) const { -+ lhs -= rhs; -+ return lhs; -+ } -+}; -+ -+template -+struct multiplies { -+ CUTLASS_HOST_DEVICE -+ T operator()(T lhs, T const &rhs) const { -+ lhs *= rhs; -+ return lhs; -+ } -+}; -+ -+// Maximum with nan propogation -+// To propgate the NANs, the "max" of a two element that contains NaNs should also return a NaN -+template -+struct maximum_with_nan_propogation { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &lhs, T const &rhs) const { -+ return lhs > rhs or std::isnan(lhs) ? lhs : rhs; -+ } -+}; -+ -+template <> -+struct maximum_with_nan_propogation { -+ CUTLASS_HOST_DEVICE -+ float operator()(float const lhs, float const rhs) const { -+ float res; -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ asm volatile("max.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs)); -+#else -+ res = lhs > rhs or std::isnan(lhs) ? lhs : rhs; -+#endif -+ return res; -+ } -+}; -+ -+/// Squares with optional conversion -+template -+struct square { -+ CUTLASS_HOST_DEVICE -+ Output operator()(T lhs) const { -+ multiplies mul_op; -+ -+ Output y = Output(lhs); -+ return mul_op(y, y); -+ } -+}; -+ -+/// Returns the magnitude squared of an element. 
-+template -+struct magnitude_squared { -+ CUTLASS_HOST_DEVICE -+ Output operator()(T lhs) const { -+ multiplies mul_op; -+ -+ Output y = Output(lhs); -+ return mul_op(y, y); -+ } -+}; -+ -+/// Computes the square of a difference with optional conversion -+template -+struct square_difference { -+ CUTLASS_HOST_DEVICE -+ Output operator()(T lhs, T rhs) const { -+ multiplies mul_op; -+ -+ Output y = Output(lhs) - Output(rhs); -+ return mul_op(y, y); -+ } -+}; -+ -+/// Computes the square of a difference with optional conversion -+template -+struct magnitude_squared_difference { -+ CUTLASS_HOST_DEVICE -+ Output operator()(T lhs, T rhs) const { -+ multiplies mul_op; -+ -+ Output y = Output(lhs) - Output(rhs); -+ return mul_op(y, y); -+ } -+}; -+ -+/// Divides -+template -+struct divides { -+ CUTLASS_HOST_DEVICE -+ T operator()(T lhs, T const &rhs) const { -+ lhs /= rhs; -+ return lhs; -+ } -+}; -+ -+/// Negate -+template -+struct negate { -+ CUTLASS_HOST_DEVICE -+ T operator()(T lhs) const { -+ return -lhs; -+ } -+}; -+ -+/// Greater equal -+template -+struct greater_equal { -+ CUTLASS_HOST_DEVICE -+ bool operator()(T const &lhs, T const &rhs) const { -+ return (lhs >= rhs); -+ } -+}; -+ -+/// Greater -+template -+struct greater { -+ CUTLASS_HOST_DEVICE -+ bool operator()(T const &lhs, T const &rhs) const { -+ return (lhs > rhs); -+ } -+}; -+ -+/// Less equal -+template -+struct less_equal { -+ CUTLASS_HOST_DEVICE -+ bool operator()(T const &lhs, T const &rhs) const { -+ return (lhs <= rhs); -+ } -+}; -+ -+/// Less -+template -+struct less { -+ CUTLASS_HOST_DEVICE -+ bool operator()(T const &lhs, T const &rhs) const { -+ return (lhs < rhs); -+ } -+}; -+ -+template -+struct maximum { -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &lhs, T const &rhs) const { -+ return (lhs < rhs ? rhs : lhs); -+ } -+}; -+ -+template <> -+struct maximum { -+ CUTLASS_HOST_DEVICE -+ float operator()(float const &lhs, float const &rhs) const { -+ return fmaxf(lhs, rhs); -+ } -+}; -+ -+template -+struct minimum { -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &lhs, T const &rhs) const { -+ return (rhs < lhs ? rhs : lhs); -+ } -+}; -+ -+template <> -+struct minimum { -+ CUTLASS_HOST_DEVICE -+ float operator()(float const &lhs, float const &rhs) const { -+ return fminf(lhs, rhs); -+ } -+}; -+ -+/// Fused multiply-add -+template -+struct multiply_add { -+ CUTLASS_HOST_DEVICE -+ C operator()(A const &a, B const &b, C const &c) const { -+ return C(a) * C(b) + c; -+ } -+}; -+ -+/// Fused multiply-add -+template -+struct multiply_add_relu0 { -+ CUTLASS_HOST_DEVICE -+ C operator()(A const &a, B const &b, C const &c) const { -+ maximum mx; -+ return mx(C(a) * C(b) + c, C(0)); -+ } -+}; -+ -+/// Fused multiply-add -+template -+struct and_add { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &a, T const &b, T const &c) const { -+ return ((a & b) + c); -+ } -+}; -+ -+ -+/// Fused multiply-add -+template -+struct xor_add { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &a, T const &b, T const &c) const { -+ return ((a ^ b) + c); -+ } -+}; -+ -+template -+struct conjugate { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &a) const { -+ return a; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct logical_and { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &a, T const &b) const { -+ return ((a && b) ? 
T(1) : T()); -+ } -+}; -+ -+template -+struct logical_or { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &a, T const &b) const { -+ return ((a || b) ? T(1) : T()); -+ } -+}; -+ -+template -+struct logical_not { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &a) const { -+ return T(!(a)); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct bit_and { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &a, T const &b) const { -+ return a & b; -+ } -+}; -+ -+template -+struct bit_or { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &a, T const &b) const { -+ return a | b; -+ } -+}; -+ -+template -+struct bit_not { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &a) const { -+ return ~a; -+ } -+}; -+ -+template -+struct bit_xor { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &a, T const &b) const { -+ return a ^ b; -+ } -+}; -+ -+ -+ -+////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Reduces value into the data pointed to by ptr -+template -+struct red -+{ -+ CUTLASS_DEVICE -+ void operator()(T *ptr, const T &data) -+ { -+ atomicAdd(ptr, data); -+ } -+}; -+ -+ -+/// Reduces value into the data pointed to by ptr (double specialization) -+template<> -+struct red -+{ -+ CUTLASS_DEVICE -+ void operator()(double *ptr, const double &data) -+ { -+#if !defined(__CUDA_ARCH__) -+ CUTLASS_UNUSED(ptr); -+ CUTLASS_UNUSED(data); -+#elif (__CUDA_ARCH__ >= 600) -+ -+ atomicAdd(ptr, data); -+ -+#else -+ -+ // Use CAS loop -+ unsigned long long int* ptr_int = reinterpret_cast(ptr); -+ unsigned long long int old_int = *ptr_int; -+ unsigned long long int assumed_int; -+ -+ do { -+ double update = data + __longlong_as_double(old_int); -+ assumed_int = old_int; -+ old_int = atomicCAS(ptr_int, assumed_int, __double_as_longlong(update)); -+ } while (assumed_int != old_int); -+ -+#endif // (__CUDA_ARCH__ >= 600) -+ } -+}; -+ -+ -+/// Reduces value into the data pointed to by ptr (half2 specialization) -+template<> -+struct red -+{ -+ CUTLASS_DEVICE -+ void operator()(half2 *ptr, const half2 &data) -+ { -+#if !defined(__CUDA_ARCH__) -+ CUTLASS_UNUSED(ptr); -+ CUTLASS_UNUSED(data); -+#elif (__CUDA_ARCH__ >= 600) -+ -+ // Vector-2 atomic reduction requires .target sm_60 or higher -+ uint32_t word = reinterpret_cast(data); -+ asm volatile ("red.gpu.global.add.noftz.f16x2 [%0], %1;\n" : : "l"(ptr), "r"(word)); -+ -+#else -+ -+ // Use CAS loop -+ uint32_t *ptr_int = reinterpret_cast(ptr); -+ uint32_t old_int = *ptr_int; -+ uint32_t assumed_int; -+ -+ do -+ { -+ half2 old = reinterpret_cast(old_int); -+ -+ half hi = __hadd(__high2half(old), __high2half(data)); -+ half lo = __hadd(__low2half(old), __low2half(data)); -+ half2 update = __halves2half2(hi, lo); -+ uint32_t update_int = reinterpret_cast(update); -+ -+ assumed_int = old_int; -+ old_int = atomicCAS(ptr_int, assumed_int, update_int); -+ -+ } while (assumed_int != old_int); -+ -+#endif // (__CUDA_ARCH__ >= 600) -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations for nvcuda::wmma::fragment -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+template -+struct plus> -+{ -+ using Fragment = nvcuda::wmma::fragment; -+ using ElementType = typename Fragment::element_type; -+ -+ CUTLASS_HOST_DEVICE -+ Fragment operator()(Fragment const &lhs, 
Fragment const &rhs) const -+ { -+ Fragment result; -+ plus scalar_op; -+ -+ ElementType *result_elts = reinterpret_cast(&result); -+ const ElementType *lhs_elts = reinterpret_cast(&lhs); -+ const ElementType *rhs_elts = reinterpret_cast(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Fragment::num_elements; i++) { -+ result_elts[i] = scalar_op(lhs_elts[i], rhs_elts[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+#endif // defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/collective/collective_builder.hpp b/3rdparty/cutlass/include/cutlass/gemm/collective/collective_builder.hpp -new file mode 100644 -index 0000000..3cd68a4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/collective/collective_builder.hpp -@@ -0,0 +1,78 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#include "collective_mma.hpp" -+ -+namespace cutlass::gemm::collective { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Used to specify stage counts or dispatch to automatic computation of stage count -+template -+struct StageCount { static constexpr int value = num_stages; }; -+struct StageCountAuto {}; -+ -+// Used to automatically let the builder pick the kernel schedule. 
-+// Can be overridden with kernel schedule tags in cutlass/gemm/dispatch_policy.hpp -+struct KernelScheduleAuto {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ class ArchTag, -+ class OpClass, -+ class ElementA, -+ class GmemLayoutA, -+ int AlignmentA, -+ class ElementB, -+ class GmemLayoutB, -+ int AlignmentB, -+ class ElementAccumulator, -+ class TileShape_MNK, -+ class ClusterShape_MNK, -+ class StageCountType, -+ class KernelScheduleType, -+ class Enable = void -+> -+struct CollectiveBuilder { -+ static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters."); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::collective -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "builders/sm90_gmma_builder.inl" -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/collective/collective_mma.hpp b/3rdparty/cutlass/include/cutlass/gemm/collective/collective_mma.hpp -new file mode 100644 -index 0000000..a2a9067 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/collective/collective_mma.hpp -@@ -0,0 +1,71 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm::collective { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ class DispatchPolicy, -+ class TileShape, -+ class ElementA, -+ class StrideA, -+ class ElementB, -+ class StrideB, -+ class TiledMma, -+ class GmemTiledCopyA, -+ class SmemLayoutAtomA, -+ class SmemCopyAtomA, -+ class TransformA, -+ class GmemTiledCopyB, -+ class SmemLayoutAtomB, -+ class SmemCopyAtomB, -+ class TransformB -+> -+struct CollectiveMma { -+ static_assert(sizeof(ElementA) == 0, "Could not find a mainloop specialization."); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::collective -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "sm70_mma_twostage.hpp" -+#include "sm80_mma_multistage.hpp" -+#include "sm90_mma_multistage_gmma_ss.hpp" -+#include "sm90_mma_tma_gmma_ss.hpp" -+#include "sm90_mma_tma_gmma_ss_warpspecialized.hpp" -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/collective/sm70_mma_twostage.hpp b/3rdparty/cutlass/include/cutlass/gemm/collective/sm70_mma_twostage.hpp -new file mode 100644 -index 0000000..11e5515 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/collective/sm70_mma_twostage.hpp -@@ -0,0 +1,588 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/dispatch_policy.hpp" -+ -+#include "cute/algorithm/functional.hpp" -+#include "cute/atom/mma_atom.hpp" -+#include "cute/algorithm/gemm.hpp" -+#include "cute/atom/mma_atom.hpp" -+#include "cute/tensor_predicate.hpp" -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm::collective { -+using namespace cute; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ class TileShape_, -+ class ElementA_, -+ class StrideA_, -+ class ElementB_, -+ class StrideB_, -+ class TiledMma_, -+ class GmemTiledCopyA_, -+ class SmemLayoutAtomA_, -+ class SmemCopyAtomA_, -+ class TransformA_, -+ class GmemTiledCopyB_, -+ class SmemLayoutAtomB_, -+ class SmemCopyAtomB_, -+ class TransformB_> -+struct CollectiveMma< -+ MainloopSm70TwoStageUnpredicated, -+ TileShape_, -+ ElementA_, -+ StrideA_, -+ ElementB_, -+ StrideB_, -+ TiledMma_, -+ GmemTiledCopyA_, -+ SmemLayoutAtomA_, -+ SmemCopyAtomA_, -+ TransformA_, -+ GmemTiledCopyB_, -+ SmemLayoutAtomB_, -+ SmemCopyAtomB_, -+ TransformB_> -+{ -+ // -+ // Type Aliases -+ // -+ using DispatchPolicy = MainloopSm70TwoStageUnpredicated; -+ using TileShape = TileShape_; -+ using ElementA = ElementA_; -+ using StrideA = StrideA_; -+ using ElementB = ElementB_; -+ using StrideB = StrideB_; -+ using TiledMma = TiledMma_; -+ using ElementAccumulator = typename TiledMma::ValTypeC; -+ using GmemTiledCopyA = GmemTiledCopyA_; -+ using GmemTiledCopyB = GmemTiledCopyB_; -+ using SmemLayoutAtomA = SmemLayoutAtomA_; -+ using SmemLayoutAtomB = SmemLayoutAtomB_; -+ using SmemCopyAtomA = SmemCopyAtomA_; -+ using SmemCopyAtomB = SmemCopyAtomB_; -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ using ArchTag = typename DispatchPolicy::ArchTag; -+ -+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ using SmemLayoutA = decltype(tile_to_shape( -+ SmemLayoutAtomA{}, -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})))); -+ using SmemLayoutB = decltype(tile_to_shape( -+ SmemLayoutAtomB{}, -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})))); -+ -+ struct SharedStorage -+ { -+ cute::array_aligned> smem_a; -+ cute::array_aligned> smem_b; -+ }; -+ -+ struct Params { -+ ElementA const* ptr_A; -+ StrideA dA; -+ ElementB const* ptr_B; -+ StrideB dB; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CollectiveMma() = default; -+ -+ template -+ static constexpr Params -+ to_underlying_arguments(Args const& args, void* workspace) { -+ (void) workspace; -+ return {args.ptr_A, args.dA, args.ptr_B, args.dB}; -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ template < -+ class FrgTensorD, -+ class TensorA, 
-+ class TensorB, -+ class FrgTensorC, -+ class KTileIterator, -+ class ResidueMNK -+ > -+ CUTLASS_DEVICE void -+ operator() ( -+ FrgTensorD &accum, -+ TensorA gA, -+ TensorB gB, -+ FrgTensorC const &src_accum, -+ KTileIterator k_tile_iter, int k_tile_count, -+ ResidueMNK residue_mnk, -+ int thread_idx, -+ char *smem_buf) -+ { -+ using namespace cute; -+ -+ (void)residue_mnk; -+ -+ static_assert(is_rmem::value, "D tensor must be rmem resident."); -+ static_assert(is_gmem::value, "A tensor must be gmem resident."); -+ static_assert(is_gmem::value, "B tensor must be gmem resident."); -+ static_assert(is_rmem::value, "C tensor must be rmem resident."); -+ static_assert(rank(SmemLayoutA{}) == 2, -+ "MainloopTwoStage must not have a smem shape with a pipeline mode."); -+ static_assert(rank(SmemLayoutB{}) == 2, -+ "MainloopTwoStage must not have a smem shape with a pipeline mode."); -+ -+ // Construct shared memory tiles -+ SharedStorage& storage = *reinterpret_cast(smem_buf); -+ Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) -+ Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) -+ -+ // Partition the copying of A and B tiles across the threads -+ GmemTiledCopyA gmem_tiled_copy_a; -+ GmemTiledCopyB gmem_tiled_copy_b; -+ auto copy_a_thr = gmem_tiled_copy_a.get_slice(thread_idx); -+ auto copy_b_thr = gmem_tiled_copy_b.get_slice(thread_idx); -+ -+ Tensor tAgA = copy_a_thr.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) -+ Tensor tAsA = copy_a_thr.partition_D(sA); // (ACPY,ACPY_M,ACPY_K) -+ Tensor tBgB = copy_b_thr.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) -+ Tensor tBsB = copy_b_thr.partition_D(sB); // (BCPY,BCPY_N,BCPY_K) -+ -+ // Allocate the register tiles for double buffering -- same shape as partitioned data -+ Tensor tArA = make_fragment_like(tAsA); // (ACPY,ACPY_M,ACPY_K) -+ Tensor tBrB = make_fragment_like(tBsB); // (BCPY,BCPY_N,BCPY_K) -+ -+ // Tile MMA compute thread partitions and allocate accumulators -+ TiledMma tiled_mma; -+ auto thr_mma = tiled_mma.get_thread_slice(thread_idx); -+ Tensor tCrA = thr_mma.partition_fragment_A(sA); // (MMA,MMA_M,MMA_K) -+ Tensor tCrB = thr_mma.partition_fragment_B(sB); // (MMA,MMA_M,MMA_K) -+ -+ CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M -+ CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(src_accum)); // MMA_M -+ CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); // MMA_N -+ CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(src_accum)); // MMA_N -+ CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K -+ -+ // -+ // Copy Atom retiling -+ // -+ -+ auto thr_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma).get_thread_slice(thread_idx); -+ Tensor tCsA = thr_copy_A.partition_S(sA); -+ Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA); -+ CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M -+ -+ auto thr_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma).get_thread_slice(thread_idx); -+ Tensor tCsB = thr_copy_B.partition_S(sB); -+ Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB); -+ CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N -+ -+ // -+ // Prologue -+ // -+ -+ // Copy gmem to rmem for the first k_tile -+ copy(gmem_tiled_copy_a, tAgA(_,_,_,*k_tile_iter), tArA); -+ copy(gmem_tiled_copy_b, tBgB(_,_,_,*k_tile_iter), tBrB); -+ if (--k_tile_count > 0) ++k_tile_iter; -+ // Copy rmem to smem -+ copy(tArA, tAsA); -+ copy(tBrB, tBsB); -+ // Clear accumulators -+ __syncthreads(); -+ -+ // Load 
A, B smem->rmem for k=0 -+ copy(tCsA(_,_,0), tCrA_copy_view(_,_,0)); -+ copy(tCsB(_,_,0), tCrB_copy_view(_,_,0)); -+ // -+ // Mainloop -+ // -+ -+ // Size of the k-tiles's outer product mode (k) -+ auto K_BLOCK_MAX = size<2>(tCrA); -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ while (k_tile_count > -1) -+ { -+ // Pipeline the outer products with a static for loop -+ for_each(make_int_sequence{}, [&] (auto k_block) -+ { -+ if (k_block == K_BLOCK_MAX - 1) -+ { -+ __syncthreads(); -+ -+ // Copy rmem to smem -+ copy(tArA, tAsA); -+ copy(tBrB, tBsB); -+ __syncthreads(); -+ } -+ -+ // Load A, B smem->rmem for k+1 -+ int k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static -+ copy(tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); -+ copy(tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); -+ if (k_block == 0) -+ { -+ // Copy gmem to rmem -+ copy(gmem_tiled_copy_a, tAgA(_,_,_,*k_tile_iter), tArA); -+ copy(gmem_tiled_copy_b, tBgB(_,_,_,*k_tile_iter), tBrB); -+ if (--k_tile_count > 0) ++k_tile_iter; -+ } -+ -+ // transform before compute -+ cute::transform(tCrA(_,_,k_block), TransformA{}); -+ cute::transform(tCrB(_,_,k_block), TransformB{}); -+ -+ // Thread-level register gemm for k -+ // disambiguate gemm (shared with the namespace name) -+ cute::gemm(tiled_mma, accum, tCrA(_,_,k_block), tCrB(_,_,k_block), src_accum); -+ }); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ class TileShape_, -+ class ElementA_, -+ class StrideA_, -+ class ElementB_, -+ class StrideB_, -+ class TiledMma_, -+ class GmemTiledCopyA_, -+ class SmemLayoutAtomA_, -+ class SmemCopyAtomA_, -+ class TransformA_, -+ class GmemTiledCopyB_, -+ class SmemLayoutAtomB_, -+ class SmemCopyAtomB_, -+ class TransformB_> -+struct CollectiveMma< -+ MainloopSm70TwoStage, -+ TileShape_, -+ ElementA_, -+ StrideA_, -+ ElementB_, -+ StrideB_, -+ TiledMma_, -+ GmemTiledCopyA_, -+ SmemLayoutAtomA_, -+ SmemCopyAtomA_, -+ TransformA_, -+ GmemTiledCopyB_, -+ SmemLayoutAtomB_, -+ SmemCopyAtomB_, -+ TransformB_> -+{ -+ // -+ // Type Aliases -+ // -+ using DispatchPolicy = MainloopSm70TwoStage; -+ using TileShape = TileShape_; -+ using ElementA = ElementA_; -+ using StrideA = StrideA_; -+ using ElementB = ElementB_; -+ using StrideB = StrideB_; -+ using TiledMma = TiledMma_; -+ using ElementAccumulator = typename TiledMma::ValTypeC; -+ using GmemTiledCopyA = GmemTiledCopyA_; -+ using GmemTiledCopyB = GmemTiledCopyB_; -+ using SmemLayoutAtomA = SmemLayoutAtomA_; -+ using SmemLayoutAtomB = SmemLayoutAtomB_; -+ using SmemCopyAtomA = SmemCopyAtomA_; -+ using SmemCopyAtomB = SmemCopyAtomB_; -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ using ArchTag = typename DispatchPolicy::ArchTag; -+ -+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ using SmemLayoutA = decltype(tile_to_shape( -+ SmemLayoutAtomA{}, -+ 
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})))); -+ using SmemLayoutB = decltype(tile_to_shape( -+ SmemLayoutAtomB{}, -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})))); -+ -+ struct SharedStorage -+ { -+ cute::array_aligned> smem_a; -+ cute::array_aligned> smem_b; -+ }; -+ -+ struct Params { -+ ElementA const* ptr_A; -+ StrideA dA; -+ ElementB const* ptr_B; -+ StrideB dB; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CollectiveMma() = default; -+ -+ template -+ static constexpr Params -+ to_underlying_arguments(Args const& args, void* workspace) { -+ (void) workspace; -+ return {args.ptr_A, args.dA, args.ptr_B, args.dB}; -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ template < -+ class FrgTensorD, -+ class TensorA, -+ class TensorB, -+ class FrgTensorC, -+ class KTileIterator, -+ class ResidueMNK -+ > -+ CUTLASS_DEVICE void -+ operator() ( -+ FrgTensorD &accum, -+ TensorA gA, -+ TensorB gB, -+ FrgTensorC const &src_accum, -+ KTileIterator k_tile_iter, int k_tile_count, -+ ResidueMNK residue_mnk, -+ int thread_idx, -+ char *smem_buf) -+ { -+ using namespace cute; -+ -+ static_assert(is_rmem::value, "D tensor must be rmem resident."); -+ static_assert(is_gmem::value, "A tensor must be gmem resident."); -+ static_assert(is_gmem::value, "B tensor must be gmem resident."); -+ static_assert(is_rmem::value, "C tensor must be rmem resident."); -+ static_assert(rank(SmemLayoutA{}) == 2, -+ "MainloopTwoStage must not have a smem shape with a pipeline mode."); -+ static_assert(rank(SmemLayoutB{}) == 2, -+ "MainloopTwoStage must not have a smem shape with a pipeline mode."); -+ -+ // Construct shared memory tiles -+ SharedStorage& storage = *reinterpret_cast(smem_buf); -+ Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) -+ Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) -+ -+ // Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k) -+ // This aligns the tensor with BLK_K for all but the 0th k_tile -+ gA.data() = &gA(0, get<2>(residue_mnk), 0); -+ gB.data() = &gB(0, get<2>(residue_mnk), 0); -+ -+ // Partition the copying of A and B tiles across the threads -+ GmemTiledCopyA gmem_tiled_copy_a; -+ GmemTiledCopyB gmem_tiled_copy_b; -+ auto gmem_thr_copy_a = gmem_tiled_copy_a.get_slice(thread_idx); -+ auto gmem_thr_copy_b = gmem_tiled_copy_b.get_slice(thread_idx); -+ -+ Tensor tAgA = gmem_thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) -+ Tensor tAsA = gmem_thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) -+ Tensor tBgB = gmem_thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) -+ Tensor tBsB = gmem_thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) -+ -+ // Allocate the register tiles for double buffering -- same shape as partitioned data -+ Tensor tArA = make_fragment_like(tAsA); // (ACPY,ACPY_M,ACPY_K) -+ Tensor tBrB = make_fragment_like(tBsB); // (BCPY,BCPY_N,BCPY_K) -+ -+ // -+ // PREDICATES -+ // -+ -+ // Allocate predicate tensors for m and n -+ Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); -+ Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); -+ -+ // Construct identity layout for sA and sB -+ Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) -+ Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) -+ -+ // Repeat the 
partitioning with identity layouts -+ Tensor tAcA = gmem_thr_copy_a.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) -+ Tensor tBcB = gmem_thr_copy_b.partition_S(cB); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) -+ -+ // Set predicates for m bounds -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < size<0>(tApA); ++m) { -+ tApA(m,0) = get<0>(tAcA(0,m,0)) < get<0>(residue_mnk); // blk_m coord < residue_m -+ } -+ // Set predicates for n bounds -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < size<0>(tBpB); ++n) { -+ tBpB(n,0) = get<0>(tBcB(0,n,0)) < get<1>(residue_mnk); // blk_n coord < residue_n -+ } -+ -+ // -+ // PREFETCH -+ // -+ -+ // Clear the rmem tiles to account for predicated off loads -+ clear(tArA); -+ clear(tBrB); -+ -+ // Start async loads for 0th k-tile, where we take care of the k residue -+ { -+ Tensor tAgAk = tAgA(_,_,_,*k_tile_iter); -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < size<2>(tArA); ++k) { -+ if (get<1>(tAcA(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gA shifted) -+ copy_if(gmem_tiled_copy_a, tApA(_,k), tAgAk(_,_,k), tArA(_,_,k)); -+ } -+ } -+ Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < size<2>(tBrB); ++k) { -+ if (get<1>(tBcB(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gB shifted) -+ copy_if(gmem_tiled_copy_b, tBpB(_,k), tBgBk(_,_,k), tBrB(_,_,k)); -+ } -+ } -+ ++k_tile_iter; -+ --k_tile_count; -+ } -+ -+ // Tile MMA compute thread partitions and allocate accumulators -+ TiledMma tiled_mma; -+ auto thr_mma = tiled_mma.get_thread_slice(thread_idx); -+ Tensor tCrA = thr_mma.make_fragment_A(thr_mma.partition_A(sA)); // (MMA,MMA_M,MMA_K) -+ Tensor tCrB = thr_mma.make_fragment_B(thr_mma.partition_B(sB)); // (MMA,MMA_M,MMA_K) -+ -+ CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M -+ CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(src_accum)); // MMA_M -+ CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); // MMA_N -+ CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(src_accum)); // MMA_N -+ CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K -+ -+ // -+ // Copy Atom retiling -+ // -+ -+ auto thr_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma).get_thread_slice(thread_idx); -+ Tensor tCsA = thr_copy_A.partition_S(sA); -+ Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA); -+ CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M -+ -+ auto thr_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma).get_thread_slice(thread_idx); -+ Tensor tCsB = thr_copy_B.partition_S(sB); -+ Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB); -+ CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N -+ -+ // -+ // Prologue -+ // -+ -+ // Copy rmem to smem -+ copy(tArA, tAsA); -+ copy(tBrB, tBsB); -+ // Clear accumulators -+ __syncthreads(); -+ -+ // Load A, B smem->rmem for k=0 -+ copy(tCsA(_,_,0), tCrA_copy_view(_,_,0)); -+ copy(tCsB(_,_,0), tCrB_copy_view(_,_,0)); -+ // -+ // Mainloop -+ // -+ -+ // Size of the k-tiles's outer product mode (k) -+ auto K_BLOCK_MAX = size<2>(tCrA); -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ while (k_tile_count > -1) -+ { -+ // Pipeline the outer products with a static for loop -+ for_each(make_int_sequence{}, [&] (auto k_block) -+ { -+ if (k_block == K_BLOCK_MAX - 1) -+ { -+ __syncthreads(); -+ -+ // Copy rmem to smem -+ copy(tArA, tAsA); -+ copy(tBrB, tBsB); -+ __syncthreads(); -+ } -+ -+ // Load A, B smem->rmem for k+1 -+ int k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static -+ copy(tCsA(_,_,k_block_next), 
tCrA_copy_view(_,_,k_block_next)); -+ copy(tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); -+ if (k_block == 0) -+ { -+ if (k_tile_count <= 0) { -+ clear(tApA); -+ clear(tBpB); -+ } -+ copy_if(gmem_tiled_copy_a, tApA, tAgA(_,_,_,*k_tile_iter), tArA); -+ copy_if(gmem_tiled_copy_b, tBpB, tBgB(_,_,_,*k_tile_iter), tBrB); -+ ++k_tile_iter; -+ --k_tile_count; -+ } -+ -+ // transform before compute -+ cute::transform(tCrA(_,_,k_block), TransformA{}); -+ cute::transform(tCrB(_,_,k_block), TransformB{}); -+ -+ // Thread-level register gemm for k -+ // disambiguate gemm (shared with the namespace name) -+ cute::gemm(tiled_mma, accum, tCrA(_,_,k_block), tCrB(_,_,k_block), src_accum); -+ }); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::collective -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/collective/sm80_mma_multistage.hpp b/3rdparty/cutlass/include/cutlass/gemm/collective/sm80_mma_multistage.hpp -new file mode 100644 -index 0000000..6ba6ccc ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/collective/sm80_mma_multistage.hpp -@@ -0,0 +1,680 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/dispatch_policy.hpp" -+ -+#include "cute/algorithm/functional.hpp" -+#include "cute/atom/mma_atom.hpp" -+#include "cute/algorithm/gemm.hpp" -+#include "cute/tensor_predicate.hpp" -+#include "cute/numeric/arithmetic_tuple.hpp" -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm::collective { -+using namespace cute; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ int Stages, -+ class TileShape_, -+ class ElementA_, -+ class StrideA_, -+ class ElementB_, -+ class StrideB_, -+ class TiledMma_, -+ class GmemTiledCopyA_, -+ class SmemLayoutAtomA_, -+ class SmemCopyAtomA_, -+ class TransformA_, -+ class GmemTiledCopyB_, -+ class SmemLayoutAtomB_, -+ class SmemCopyAtomB_, -+ class TransformB_> -+struct CollectiveMma< -+ MainloopSm80CpAsyncUnpredicated, -+ TileShape_, -+ ElementA_, -+ StrideA_, -+ ElementB_, -+ StrideB_, -+ TiledMma_, -+ GmemTiledCopyA_, -+ SmemLayoutAtomA_, -+ SmemCopyAtomA_, -+ TransformA_, -+ GmemTiledCopyB_, -+ SmemLayoutAtomB_, -+ SmemCopyAtomB_, -+ TransformB_> -+{ -+ // -+ // Type Aliases -+ // -+ using DispatchPolicy = MainloopSm80CpAsyncUnpredicated; -+ using TileShape = TileShape_; -+ using ElementA = ElementA_; -+ using StrideA = StrideA_; -+ using ElementB = ElementB_; -+ using StrideB = StrideB_; -+ using TiledMma = TiledMma_; -+ using ElementAccumulator = typename TiledMma::ValTypeC; -+ using GmemTiledCopyA = GmemTiledCopyA_; -+ using GmemTiledCopyB = GmemTiledCopyB_; -+ using SmemLayoutAtomA = SmemLayoutAtomA_; -+ using SmemLayoutAtomB = SmemLayoutAtomB_; -+ using SmemCopyAtomA = SmemCopyAtomA_; -+ using SmemCopyAtomB = SmemCopyAtomB_; -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ using ArchTag = typename DispatchPolicy::ArchTag; -+ -+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ using SmemLayoutA = decltype(tile_to_shape( -+ SmemLayoutAtomA{}, -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); -+ using SmemLayoutB = decltype(tile_to_shape( -+ SmemLayoutAtomB{}, -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); -+ -+ static_assert(DispatchPolicy::Stages >= 2, "CpAsync mainloop must have at least 2 stages in the pipeline."); -+ -+ struct SharedStorage -+ { -+ cute::array_aligned> smem_a; -+ cute::array_aligned> smem_b; -+ }; -+ -+ struct Params { -+ ElementA const* ptr_A; -+ StrideA dA; -+ ElementB const* ptr_B; -+ StrideB dB; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CollectiveMma() = default; -+ -+ template -+ static constexpr Params -+ to_underlying_arguments(Args const& args, void* workspace) { -+ (void) workspace; -+ return {args.ptr_A, 
args.dA, args.ptr_B, args.dB}; -+ } -+ -+ /// Perform a collective-scoped matrix multiply-accumulate -+ template < -+ class FrgTensorD, -+ class TensorA, -+ class TensorB, -+ class FrgTensorC, -+ class KTileIterator, -+ class ResidueMNK -+ > -+ CUTLASS_DEVICE void -+ operator() ( -+ FrgTensorD &accum, -+ TensorA gA, -+ TensorB gB, -+ FrgTensorC const &src_accum, -+ KTileIterator k_tile_iter, int k_tile_count, -+ ResidueMNK residue_mnk, -+ int thread_idx, -+ char *smem_buf) -+ { -+ using namespace cute; -+ -+ static_assert(is_rmem::value, "D tensor must be rmem resident."); -+ static_assert(is_gmem::value, "A tensor must be gmem resident."); -+ static_assert(is_gmem::value, "B tensor must be gmem resident."); -+ static_assert(is_rmem::value, "C tensor must be rmem resident."); -+ static_assert(rank(SmemLayoutA{}) == 3, -+ "MainloopSm80CpAsync must have a pipeline mode in the smem layout."); -+ static_assert(rank(SmemLayoutB{}) == 3, -+ "MainloopSm80CpAsync must have a pipeline mode in the smem layout."); -+ -+ // Construct shared memory tiles -+ SharedStorage& storage = *reinterpret_cast(smem_buf); -+ Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) -+ Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) -+ -+ CUTE_STATIC_ASSERT_V(size<0>(gA) == size<0>(sA)); // BLK_M -+ CUTE_STATIC_ASSERT_V(size<1>(gA) == size<1>(sA)); // BLK_K -+ CUTE_STATIC_ASSERT_V(size<0>(gB) == size<0>(sB)); // BLK_N -+ CUTE_STATIC_ASSERT_V(size<1>(gB) == size<1>(sB)); // BLK_K -+ CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // BLK_K -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE -+ -+ // Partition the copying of A and B tiles across the threads -+ GmemTiledCopyA gmem_tiled_copy_A; -+ GmemTiledCopyB gmem_tiled_copy_B; -+ auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); -+ auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); -+ -+ Tensor tAgA = gmem_thr_copy_A.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) -+ Tensor tAsA = gmem_thr_copy_A.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) -+ Tensor tBgB = gmem_thr_copy_B.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) -+ Tensor tBsB = gmem_thr_copy_B.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) -+ -+ // -+ // PREDICATES -+ // -+ -+ (void) residue_mnk; -+ //assert(residue_mnk == make_tuple(0,0,0)); -+ -+ // -+ // PREFETCH -+ // -+ -+ // Start async loads for all pipes but the last -+ CUTLASS_PRAGMA_UNROLL -+ for (int k_pipe = 0; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) { -+ copy(gmem_tiled_copy_A, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,k_pipe)); -+ copy(gmem_tiled_copy_B, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,k_pipe)); -+ cp_async_fence(); -+ --k_tile_count; -+ if (k_tile_count > 0) { ++k_tile_iter; } -+ } -+ -+ // -+ // MMA Atom partitioning -+ // -+ -+ // Tile MMA compute thread partitions and allocate accumulators -+ TiledMma tiled_mma; -+ auto thr_mma = tiled_mma.get_thread_slice(thread_idx); -+ Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0)); // (MMA,MMA_M,MMA_K) -+ Tensor tCrB = thr_mma.partition_fragment_B(sB(_,_,0)); // (MMA,MMA_N,MMA_K) -+ -+ CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M -+ CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(src_accum)); // MMA_M -+ CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); // MMA_N -+ CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(src_accum)); // MMA_N -+ CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K -+ 
CUTE_STATIC_ASSERT_V(size(gmem_tiled_copy_A) == size(tiled_mma)); -+ CUTE_STATIC_ASSERT_V(size(gmem_tiled_copy_B) == size(tiled_mma)); -+ -+ // -+ // Copy Atom retiling -+ // -+ -+ auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); -+ auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); -+ Tensor tCsA = smem_thr_copy_A.partition_S(sA); // (CPY,CPY_M,CPY_K,PIPE) -+ Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) -+ CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M -+ CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K -+ -+ auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); -+ auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); -+ Tensor tCsB = smem_thr_copy_B.partition_S(sB); // (CPY,CPY_N,CPY_K,PIPE) -+ Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_N,CPY_K) -+ CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // CPY_N -+ CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrB_copy_view)); // CPY_K -+ -+ // -+ // PIPELINED MAIN LOOP -+ // -+ -+ // Current pipe index in smem to read from -+ int smem_pipe_read = 0; -+ // Current pipe index in smem to write to -+ int smem_pipe_write = DispatchPolicy::Stages-1; -+ -+ Tensor tCsA_p = tCsA(_,_,_,smem_pipe_read); -+ Tensor tCsB_p = tCsB(_,_,_,smem_pipe_read); -+ -+ // Size of the register pipeline -+ auto K_BLOCK_MAX = size<2>(tCrA); -+ -+ // PREFETCH register pipeline -+ if (K_BLOCK_MAX > 1) { -+ // Wait until our first prefetched tile is loaded in -+ cp_async_wait(); -+ __syncthreads(); -+ -+ // Prefetch the first rmem from the first k-tile -+ copy(smem_tiled_copy_A, tCsA_p(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{})); -+ copy(smem_tiled_copy_B, tCsB_p(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{})); -+ } -+ -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count) -+ { -+ // Pipeline the outer products with a static for loop. -+ // -+ // Note, the for_each() function is required here to ensure `k_block` is of type Int. -+ for_each(make_int_sequence{}, [&] (auto k_block) -+ { -+ if (k_block == K_BLOCK_MAX - 1) -+ { -+ // Slice the smem_pipe_read smem -+ tCsA_p = tCsA(_,_,_,smem_pipe_read); -+ tCsB_p = tCsB(_,_,_,smem_pipe_read); -+ -+ // Commit the smem for smem_pipe_read -+ cp_async_wait(); -+ __syncthreads(); -+ } -+ -+ // Load A, B shmem->regs for k_block+1 -+ auto k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static -+ copy(smem_tiled_copy_A, tCsA_p(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); -+ copy(smem_tiled_copy_B, tCsB_p(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); -+ // Copy gmem to smem before computing gemm on each k-pipe -+ if (k_block == 0) -+ { -+ copy(gmem_tiled_copy_A, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write)); -+ copy(gmem_tiled_copy_B, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write)); -+ cp_async_fence(); -+ if (k_tile_count > 0) { ++k_tile_iter; } -+ -+ // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) -+ smem_pipe_write = smem_pipe_read; -+ ++smem_pipe_read; -+ smem_pipe_read = (smem_pipe_read == DispatchPolicy::Stages) ? 
0 : smem_pipe_read; -+ } -+ -+ // Transform before compute -+ cute::transform(tCrA(_,_,k_block), TransformA{}); -+ cute::transform(tCrB(_,_,k_block), TransformB{}); -+ // Thread-level register gemm for k_block -+ cute::gemm(tiled_mma, accum, tCrA(_,_,k_block), tCrB(_,_,k_block), src_accum); -+ }); -+ -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ int Stages, -+ class TileShape_, -+ class ElementA_, -+ class StrideA_, -+ class ElementB_, -+ class StrideB_, -+ class TiledMma_, -+ class GmemTiledCopyA_, -+ class SmemLayoutAtomA_, -+ class SmemCopyAtomA_, -+ class TransformA_, -+ class GmemTiledCopyB_, -+ class SmemLayoutAtomB_, -+ class SmemCopyAtomB_, -+ class TransformB_> -+struct CollectiveMma< -+ MainloopSm80CpAsync, -+ TileShape_, -+ ElementA_, -+ StrideA_, -+ ElementB_, -+ StrideB_, -+ TiledMma_, -+ GmemTiledCopyA_, -+ SmemLayoutAtomA_, -+ SmemCopyAtomA_, -+ TransformA_, -+ GmemTiledCopyB_, -+ SmemLayoutAtomB_, -+ SmemCopyAtomB_, -+ TransformB_> -+{ -+ // -+ // Type Aliases -+ // -+ using DispatchPolicy = MainloopSm80CpAsync; -+ using TileShape = TileShape_; -+ using ElementA = ElementA_; -+ using StrideA = StrideA_; -+ using ElementB = ElementB_; -+ using StrideB = StrideB_; -+ using TiledMma = TiledMma_; -+ using ElementAccumulator = typename TiledMma::ValTypeC; using GmemTiledCopyA = GmemTiledCopyA_; -+ using GmemTiledCopyB = GmemTiledCopyB_; -+ using SmemLayoutAtomA = SmemLayoutAtomA_; -+ using SmemLayoutAtomB = SmemLayoutAtomB_; -+ using SmemCopyAtomA = SmemCopyAtomA_; -+ using SmemCopyAtomB = SmemCopyAtomB_; -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ using ArchTag = typename DispatchPolicy::ArchTag; -+ -+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ using SmemLayoutA = decltype(tile_to_shape( -+ SmemLayoutAtomA{}, -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); -+ using SmemLayoutB = decltype(tile_to_shape( -+ SmemLayoutAtomB{}, -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); -+ -+ static_assert(DispatchPolicy::Stages >= 2, "CpAsync mainloop must have at least 2 stages in the pipeline."); -+ -+ struct SharedStorage -+ { -+ cute::array_aligned> smem_a; -+ cute::array_aligned> smem_b; -+ }; -+ -+ struct Params { -+ ElementA const* ptr_A; -+ StrideA dA; -+ ElementB const* ptr_B; -+ StrideB dB; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CollectiveMma() = default; -+ -+ template -+ static constexpr Params -+ to_underlying_arguments(Args const& args, void* workspace) { -+ (void) workspace; -+ return {args.ptr_A, args.dA, args.ptr_B, args.dB}; -+ } -+ -+ /// Perform a collective-scoped matrix multiply-accumulate -+ template < -+ class FrgTensorD, -+ class TensorA, -+ class TensorB, -+ class FrgTensorC, -+ class KTileIterator, -+ class ResidueMNK -+ > -+ CUTLASS_DEVICE void -+ operator() ( -+ 
FrgTensorD &accum, -+ TensorA gA, // (BLK_M, BLK_K, K_TILES) -+ TensorB gB, // (BLK_N, BLK_K, K_TILES) -+ FrgTensorC const &src_accum, -+ KTileIterator k_tile_iter, int k_tile_count, -+ ResidueMNK residue_mnk, -+ int thread_idx, -+ char *smem_buf) -+ { -+ using namespace cute; -+ -+ static_assert(is_rmem::value, "D tensor must be rmem resident."); -+ static_assert(is_gmem::value, "A tensor must be gmem resident."); -+ static_assert(is_gmem::value, "B tensor must be gmem resident."); -+ static_assert(is_rmem::value, "C tensor must be rmem resident."); -+ static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); -+ static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); -+ -+ // Construct shared memory tiles -+ SharedStorage& storage = *reinterpret_cast(smem_buf); -+ Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) -+ Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) -+ -+ CUTE_STATIC_ASSERT_V(size<0>(gA) == size<0>(sA)); // BLK_M -+ CUTE_STATIC_ASSERT_V(size<1>(gA) == size<1>(sA)); // BLK_K -+ CUTE_STATIC_ASSERT_V(size<0>(gB) == size<0>(sB)); // BLK_N -+ CUTE_STATIC_ASSERT_V(size<1>(gB) == size<1>(sB)); // BLK_K -+ CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // BLK_K -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE -+ -+ // Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k) -+ // This aligns the tensor with BLK_K for all but the 0th k_tile -+ gA.data() = &gA(0, get<2>(residue_mnk), 0); -+ gB.data() = &gB(0, get<2>(residue_mnk), 0); -+ -+ // Partition the copying of A and B tiles across the threads -+ GmemTiledCopyA gmem_tiled_copy_A; -+ GmemTiledCopyB gmem_tiled_copy_B; -+ auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); -+ auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); -+ -+ Tensor tAgA = gmem_thr_copy_A.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) -+ Tensor tAsA = gmem_thr_copy_A.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) -+ Tensor tBgB = gmem_thr_copy_B.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) -+ Tensor tBsB = gmem_thr_copy_B.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) -+ -+ // -+ // PREDICATES -+ // -+ -+ // Allocate predicate tensors for m and n -+ Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); -+ Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); -+ -+ // Construct identity layout for sA and sB -+ Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) -+ Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) -+ -+ // Repeat the partitioning with identity layouts -+ Tensor tAcA = gmem_thr_copy_A.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) -+ Tensor tBcB = gmem_thr_copy_B.partition_S(cB); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) -+ -+ // Set predicates for m bounds -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < size<0>(tApA); ++m) { -+ tApA(m,0) = get<0>(tAcA(0,m,0)) < get<0>(residue_mnk); // blk_m coord < residue_m -+ } -+ // Set predicates for n bounds -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < size<0>(tBpB); ++n) { -+ tBpB(n,0) = get<0>(tBcB(0,n,0)) < get<1>(residue_mnk); // blk_n coord < residue_n -+ } -+ -+ // -+ // PREFETCH -+ // -+ -+ // Clear the smem tiles to account for predicated off loads -+ clear(tAsA); -+ clear(tBsB); -+ -+ 
// Start async loads for 0th k-tile, where we take care of the k residue -+ { -+ constexpr int k_pipe = 0; -+ -+ Tensor tAgAk = tAgA(_,_,_,*k_tile_iter); -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < size<2>(tAsA); ++k) { -+ if (get<1>(tAcA(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gA shifted) -+ copy_if(gmem_tiled_copy_A, tApA(_,k), tAgAk(_,_,k), tAsA(_,_,k,k_pipe)); -+ } -+ } -+ Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < size<2>(tBsB); ++k) { -+ if (get<1>(tBcB(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gB shifted) -+ copy_if(gmem_tiled_copy_B, tBpB(_,k), tBgBk(_,_,k), tBsB(_,_,k,k_pipe)); -+ } -+ } -+ cp_async_fence(); -+ ++k_tile_iter; -+ --k_tile_count; -+ } -+ -+ // Start async loads for 1st k-tile onwards, no k-residue handling needed -+ CUTLASS_PRAGMA_UNROLL -+ for (int k_pipe = 1; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) { -+ if (k_tile_count <= 0) { -+ clear(tApA); -+ clear(tBpB); -+ } -+ copy_if(gmem_tiled_copy_A, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,k_pipe)); // CpAsync -+ copy_if(gmem_tiled_copy_B, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,k_pipe)); // CpAsync -+ cp_async_fence(); -+ ++k_tile_iter; -+ --k_tile_count; -+ } -+ -+ // -+ // MMA Atom partitioning -+ // -+ -+ // Tile MMA compute thread partitions and allocate accumulators -+ TiledMma tiled_mma; -+ auto thr_mma = tiled_mma.get_thread_slice(thread_idx); -+ Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0)); // (MMA,MMA_M,MMA_K) -+ Tensor tCrB = thr_mma.partition_fragment_B(sB(_,_,0)); // (MMA,MMA_N,MMA_K) -+ -+ CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M -+ CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(src_accum)); // MMA_M -+ CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); // MMA_N -+ CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(src_accum)); // MMA_N -+ CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K -+ -+ // -+ // Copy Atom retiling -+ // -+ -+ auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); -+ auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); -+ Tensor tCsA = smem_thr_copy_A.partition_S(sA); // (CPY,CPY_M,CPY_K,PIPE) -+ Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) -+ CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M -+ CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K -+ -+ auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); -+ auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); -+ Tensor tCsB = smem_thr_copy_B.partition_S(sB); // (CPY,CPY_N,CPY_K,PIPE) -+ Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_N,CPY_K) -+ CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // CPY_N -+ CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrB_copy_view)); // CPY_K -+ -+ // -+ // PIPELINED MAIN LOOP -+ // -+ -+ // Current pipe index in smem to read from -+ int smem_pipe_read = 0; -+ // Current pipe index in smem to write to -+ int smem_pipe_write = DispatchPolicy::Stages-1; -+ -+ Tensor tCsA_p = tCsA(_,_,_,smem_pipe_read); -+ Tensor tCsB_p = tCsB(_,_,_,smem_pipe_read); -+ -+ // Size of the register pipeline -+ auto K_BLOCK_MAX = size<2>(tCrA); -+ -+ // PREFETCH register pipeline -+ if (K_BLOCK_MAX > 1) { -+ // Wait until our first prefetched tile is loaded in -+ cp_async_wait(); -+ __syncthreads(); -+ -+ // Prefetch the first rmem from the first k-tile -+ copy(smem_tiled_copy_A, 
tCsA_p(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{})); -+ copy(smem_tiled_copy_B, tCsB_p(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{})); -+ } -+ -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count) -+ { -+ // Pipeline the outer products with a static for loop. -+ // -+ // Note, the for_each() function is required here to ensure `k_block` is of type Int. -+ for_each(make_int_sequence{}, [&] (auto k_block) -+ { -+ if (k_block == K_BLOCK_MAX - 1) -+ { -+ // Slice the smem_pipe_read smem -+ tCsA_p = tCsA(_,_,_,smem_pipe_read); -+ tCsB_p = tCsB(_,_,_,smem_pipe_read); -+ -+ // Commit the smem for smem_pipe_read -+ cp_async_wait(); -+ __syncthreads(); -+ } -+ -+ // Load A, B shmem->regs for k_block+1 -+ auto k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static -+ copy(smem_tiled_copy_A, tCsA_p(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); -+ copy(smem_tiled_copy_B, tCsB_p(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); -+ // Copy gmem to smem before computing gemm on each k-pipe -+ if (k_block == 0) -+ { -+ // Set all predicates to false if we are going to overshoot bounds -+ if (k_tile_count <= 0) { -+ clear(tApA); -+ clear(tBpB); -+ } -+ copy_if(gmem_tiled_copy_A, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write)); -+ copy_if(gmem_tiled_copy_B, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write)); -+ cp_async_fence(); -+ ++k_tile_iter; -+ -+ // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) -+ smem_pipe_write = smem_pipe_read; -+ ++smem_pipe_read; -+ smem_pipe_read = (smem_pipe_read == DispatchPolicy::Stages) ? 0 : smem_pipe_read; -+ } -+ -+ // Transform before compute -+ cute::transform(tCrA(_,_,k_block), TransformA{}); -+ cute::transform(tCrB(_,_,k_block), TransformB{}); -+ // Thread-level register gemm for k_block -+ cute::gemm(tiled_mma, accum, tCrA(_,_,k_block), tCrB(_,_,k_block), src_accum); -+ }); -+ -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::collective -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss.hpp b/3rdparty/cutlass/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss.hpp -new file mode 100644 -index 0000000..3b1921b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss.hpp -@@ -0,0 +1,596 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/dispatch_policy.hpp" -+#include "cutlass/pipeline.hpp" -+#include "cute/arch/cluster_sm90.hpp" -+#include "cutlass/arch/reg_reconfig.h" -+ -+#include "cute/arch/copy_sm90.hpp" -+#include "cute/atom/mma_atom.hpp" -+#include "cute/algorithm/gemm.hpp" -+ -+#include -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm::collective { -+using namespace cute; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ int Stages, -+ class ClusterShape, -+ class TileShape_, -+ class ElementA_, -+ class StrideA_, -+ class ElementB_, -+ class StrideB_, -+ class TiledMma_, -+ class GmemTiledCopyA_, -+ class SmemLayoutAtomA_, -+ class SmemCopyAtomA_, -+ class TransformA_, -+ class GmemTiledCopyB_, -+ class SmemLayoutAtomB_, -+ class SmemCopyAtomB_, -+ class TransformB_> -+struct CollectiveMma< -+ MainloopSm90CpAsyncGmmaUnpredicated, -+ TileShape_, -+ ElementA_, -+ StrideA_, -+ ElementB_, -+ StrideB_, -+ TiledMma_, -+ GmemTiledCopyA_, -+ SmemLayoutAtomA_, -+ SmemCopyAtomA_, -+ TransformA_, -+ GmemTiledCopyB_, -+ SmemLayoutAtomB_, -+ SmemCopyAtomB_, -+ TransformB_> -+{ -+ // -+ // Type Aliases -+ // -+ using DispatchPolicy = MainloopSm90CpAsyncGmmaUnpredicated; -+ using TileShape = TileShape_; -+ using ElementA = ElementA_; -+ using StrideA = StrideA_; -+ using ElementB = ElementB_; -+ using StrideB = StrideB_; -+ using TiledMma = TiledMma_; -+ using ElementAccumulator = typename TiledMma::ValTypeC; -+ using GmemTiledCopyA = GmemTiledCopyA_; -+ using GmemTiledCopyB = GmemTiledCopyB_; -+ using SmemLayoutAtomA = SmemLayoutAtomA_; -+ using SmemLayoutAtomB = SmemLayoutAtomB_; -+ using SmemCopyAtomA = SmemCopyAtomA_; -+ using SmemCopyAtomB = SmemCopyAtomB_; -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ using ArchTag = typename DispatchPolicy::ArchTag; -+ -+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) 
== 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ using SmemLayoutA = decltype(tile_to_shape( -+ SmemLayoutAtomA{}, -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); -+ using SmemLayoutB = decltype(tile_to_shape( -+ SmemLayoutAtomB{}, -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); -+ -+ static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); -+ static_assert(std::is_base_of::value && -+ std::is_base_of::value, -+ "MMA atom must source both A and B operand from smem_desc for this mainloop."); -+ -+ struct SharedStorage -+ { -+ cute::array_aligned> smem_a; -+ cute::array_aligned> smem_b; -+ }; -+ -+ struct Params { -+ ElementA const* ptr_A; -+ StrideA dA; -+ ElementB const* ptr_B; -+ StrideB dB; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CollectiveMma() = default; -+ -+ template -+ static constexpr Params -+ to_underlying_arguments(Args const& args, void* workspace) { -+ (void) workspace; -+ return {args.ptr_A, args.dA, args.ptr_B, args.dB}; -+ } -+ -+ /// Perform a collective-scoped matrix multiply-accumulate -+ template < -+ class TensorA, -+ class TensorB, -+ class FrgTensorC, -+ class KTileIterator, -+ class ResidueMNK -+ > -+ CUTLASS_DEVICE void -+ operator() ( -+ TensorA gA, -+ TensorB gB, -+ FrgTensorC& accum, -+ KTileIterator k_tile_iter, int k_tile_count, -+ ResidueMNK residue_mnk, -+ int thread_idx, -+ char *smem_buf, -+ Params const& mainloop_params) -+ { -+ using namespace cute; -+ -+ (void) residue_mnk; -+ -+ static_assert(is_gmem::value, "A tensor must be gmem resident."); -+ static_assert(is_gmem::value, "B tensor must be gmem resident."); -+ static_assert(is_rmem::value, "C tensor must be rmem resident."); -+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2."); -+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2."); -+ static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); -+ static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); -+ static_assert(std::is_same::value, -+ "SM90 warpgroup MMA must specify transforms through MMA_Atom."); -+ static_assert(std::is_same::value, -+ "SM90 warpgroup MMA must specify transforms through MMA_Atom."); -+ static_assert(std::is_same::value, -+ "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); -+ static_assert(std::is_same::value, -+ "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); -+ -+ SharedStorage& storage = *reinterpret_cast(smem_buf); -+ Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) -+ Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) -+ -+ // Partition the copying of A and B tiles across the threads -+ GmemTiledCopyA gmem_tiled_copy_a; -+ GmemTiledCopyB gmem_tiled_copy_b; -+ auto gmem_thr_copy_a = gmem_tiled_copy_a.get_slice(thread_idx); -+ auto gmem_thr_copy_b = gmem_tiled_copy_b.get_slice(thread_idx); -+ -+ Tensor tAgA = gmem_thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) -+ Tensor tAsA = gmem_thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) -+ Tensor tBgB = gmem_thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) -+ Tensor tBsB = gmem_thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) -+ -+ // Tile MMA atom and compute thread partitions across A, B and C -+ TiledMma tiled_mma; -+ auto thr_mma = tiled_mma.get_thread_slice(thread_idx); -+ -+ // Allocate 
registers for pipelining -+ Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) -+ Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) -+ -+ Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_N,MMA_K,PIPE) -+ Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_N,PIPE) -+ -+ CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M -+ CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N -+ CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K -+ CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE -+ CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tAsA)); // PIPE -+ CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tBsB)); // PIPE -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE -+ -+ // -+ // Prologue -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k_pipe = 0; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) { -+ copy(gmem_tiled_copy_a, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,k_pipe)); -+ copy(gmem_tiled_copy_b, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,k_pipe)); -+ cp_async_fence(); -+ ++k_tile_iter; -+ --k_tile_count; -+ } -+ -+ // Current pipe index in smem to read from -+ int smem_pipe_read = 0; -+ // Current pipe index in smem to write to -+ int smem_pipe_write = DispatchPolicy::Stages-1; -+ -+ // -+ // Pipelined Main Loop -+ // -+ CUTLASS_PRAGMA_NO_UNROLL -+ for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count) -+ { -+ // Copy gmem to smem before computing gemm on each k-pipe -+ // pipe index in smem where the next gmem tile will be read into -+ copy(gmem_tiled_copy_a, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write)); -+ copy(gmem_tiled_copy_b, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write)); -+ cp_async_fence(); -+ if (k_tile_count > 0) { ++k_tile_iter; } -+ -+ // -+ // Compute on k_tile -+ // -+ warpgroup_fence_operand(accum); -+ warpgroup_arrive(); -+ -+ cp_async_wait(); -+ cute::gemm(tiled_mma, tCrA(_,_,_,smem_pipe_read), tCrB(_,_,_,smem_pipe_read), accum); -+ warpgroup_commit_batch(); -+ -+ // -+ // Advance the pipe -+ // -+ ++smem_pipe_read; -+ smem_pipe_read = (smem_pipe_read == DispatchPolicy::Stages) ? smem_pipe_read = 0 : smem_pipe_read; -+ -+ ++smem_pipe_write; -+ smem_pipe_write = (smem_pipe_write == DispatchPolicy::Stages) ? 
smem_pipe_write = 0 : smem_pipe_write; -+ -+ // Wait for the pipeline MMAs to drain -+ warpgroup_wait<0>(); -+ warpgroup_fence_operand(accum); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ int Stages, -+ class ClusterShape, -+ class TileShape_, -+ class ElementA_, -+ class StrideA_, -+ class ElementB_, -+ class StrideB_, -+ class TiledMma_, -+ class GmemTiledCopyA_, -+ class SmemLayoutAtomA_, -+ class SmemCopyAtomA_, -+ class TransformA_, -+ class GmemTiledCopyB_, -+ class SmemLayoutAtomB_, -+ class SmemCopyAtomB_, -+ class TransformB_> -+struct CollectiveMma< -+ MainloopSm90CpAsyncGmma, -+ TileShape_, -+ ElementA_, -+ StrideA_, -+ ElementB_, -+ StrideB_, -+ TiledMma_, -+ GmemTiledCopyA_, -+ SmemLayoutAtomA_, -+ SmemCopyAtomA_, -+ TransformA_, -+ GmemTiledCopyB_, -+ SmemLayoutAtomB_, -+ SmemCopyAtomB_, -+ TransformB_> -+{ -+ // -+ // Type Aliases -+ // -+ using DispatchPolicy = MainloopSm90CpAsyncGmma; -+ using TileShape = TileShape_; -+ using ElementA = ElementA_; -+ using StrideA = StrideA_; -+ using ElementB = ElementB_; -+ using StrideB = StrideB_; -+ using TiledMma = TiledMma_; -+ using ElementAccumulator = typename TiledMma::ValTypeC; using GmemTiledCopyA = GmemTiledCopyA_; -+ using GmemTiledCopyB = GmemTiledCopyB_; -+ using SmemLayoutAtomA = SmemLayoutAtomA_; -+ using SmemLayoutAtomB = SmemLayoutAtomB_; -+ using SmemCopyAtomA = SmemCopyAtomA_; -+ using SmemCopyAtomB = SmemCopyAtomB_; -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ using ArchTag = typename DispatchPolicy::ArchTag; -+ -+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ using SmemLayoutA = decltype(tile_to_shape( -+ SmemLayoutAtomA{}, -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); -+ using SmemLayoutB = decltype(tile_to_shape( -+ SmemLayoutAtomB{}, -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); -+ -+ static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); -+ static_assert(std::is_base_of::value && -+ std::is_base_of::value, -+ "MMA atom must source both A and B operand from smem_desc for this mainloop."); -+ -+ struct SharedStorage -+ { -+ cute::array_aligned> smem_a; -+ cute::array_aligned> smem_b; -+ }; -+ -+ struct Params { -+ ElementA const* ptr_A; -+ StrideA dA; -+ ElementB const* ptr_B; -+ StrideB dB; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CollectiveMma() = default; -+ -+ template -+ static constexpr Params -+ to_underlying_arguments(Args const& args, void* workspace) { -+ (void) workspace; -+ return {args.ptr_A, args.dA, args.ptr_B, args.dB}; -+ } -+ -+ /// Perform a collective-scoped matrix multiply-accumulate -+ template < -+ class FrgTensorD, -+ class TensorA, -+ class TensorB, -+ class FrgTensorC, -+ class KTileIterator, -+ class ResidueMNK -+ > -+ CUTLASS_DEVICE void -+ 
operator() ( -+ FrgTensorD &accum, -+ TensorA gA, -+ TensorB gB, -+ FrgTensorC const &src_accum, -+ KTileIterator k_tile_iter, int k_tile_count, -+ ResidueMNK residue_mnk, -+ int thread_idx, -+ char *smem_buf) -+ { -+ using namespace cute; -+ -+ static_assert(is_rmem::value, "D tensor must be rmem resident."); -+ static_assert(is_gmem::value, "A tensor must be gmem resident."); -+ static_assert(is_gmem::value, "B tensor must be gmem resident."); -+ static_assert(is_rmem::value, "C tensor must be rmem resident."); -+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2."); -+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2."); -+ static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); -+ static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); -+ static_assert(std::is_same::value, -+ "SM90 warpgroup MMA must specify transforms through MMA_Atom."); -+ static_assert(std::is_same::value, -+ "SM90 warpgroup MMA must specify transforms through MMA_Atom."); -+ static_assert(std::is_same::value, -+ "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); -+ static_assert(std::is_same::value, -+ "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); -+ -+ SharedStorage& storage = *reinterpret_cast(smem_buf); -+ Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) -+ Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) -+ -+ // Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k) -+ // This aligns the tensor with BLK_K for all but the 0th k_tile -+ gA.data() = &gA(0, get<2>(residue_mnk), 0); -+ gB.data() = &gB(0, get<2>(residue_mnk), 0); -+ -+ // Partition the copying of A and B tiles across the threads -+ GmemTiledCopyA gmem_tiled_copy_a; -+ GmemTiledCopyB gmem_tiled_copy_b; -+ auto gmem_thr_copy_a = gmem_tiled_copy_a.get_slice(thread_idx); -+ auto gmem_thr_copy_b = gmem_tiled_copy_b.get_slice(thread_idx); -+ -+ Tensor tAgA = gmem_thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) -+ Tensor tAsA = gmem_thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) -+ Tensor tBgB = gmem_thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) -+ Tensor tBsB = gmem_thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) -+ -+ // -+ // PREDICATES -+ // -+ -+ // Allocate predicate tensors for m and n -+ Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); -+ Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); -+ -+ // Construct identity layout for sA and sB -+ Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) -+ Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) -+ -+ // Repeat the partitioning with identity layouts -+ Tensor tAcA = gmem_thr_copy_a.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) -+ Tensor tBcB = gmem_thr_copy_b.partition_S(cB); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) -+ -+ // Set predicates for m bounds -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < size<0>(tApA); ++m) { -+ tApA(m,0) = get<0>(tAcA(0,m,0)) < get<0>(residue_mnk); // blk_m coord < residue_m -+ } -+ // Set predicates for n bounds -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < size<0>(tBpB); ++n) { -+ tBpB(n,0) = get<0>(tBcB(0,n,0)) < get<1>(residue_mnk); // blk_n coord < residue_n -+ 
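[Editor's aside, not part of the patch] The predicate tensors tApA and tBpB built just above gate the M- and N-direction bounds of a partial tile, while the K residue is handled separately by shifting gA/gB and by the per-k guard in the prologue. A minimal host-side model of the same idea, with hypothetical names and no CuTe types, is sketched below; masked elements keep the cleared value, mirroring the clear(tAsA)/clear(tBsB) that precedes the predicated copies.

    #include <array>
    #include <cstdio>

    // Hypothetical stand-in for a predicated tile copy: a BLK_M x BLK_K tile where
    // only residue_m rows of the source are valid for the last (partial) tile.
    template <int BLK_M, int BLK_K>
    void copy_if_rows(const float* gmem, float (&smem)[BLK_M][BLK_K], int residue_m) {
      // Build the row predicate once (mirrors tApA(m,0) = coord_m < residue_m).
      std::array<bool, BLK_M> pred{};
      for (int m = 0; m < BLK_M; ++m) pred[m] = (m < residue_m);

      for (int m = 0; m < BLK_M; ++m) {
        for (int k = 0; k < BLK_K; ++k) {
          // Predicated-off rows keep the cleared value (the mainloop clears smem first).
          smem[m][k] = pred[m] ? gmem[m * BLK_K + k] : 0.0f;
        }
      }
    }

    int main() {
      constexpr int M = 4, K = 2;
      float g[M * K] = {1, 2, 3, 4, 5, 6, 7, 8};
      float s[M][K] = {};
      copy_if_rows<M, K>(g, s, /*residue_m=*/3);   // last row masked off
      for (int m = 0; m < M; ++m) std::printf("%g %g\n", s[m][0], s[m][1]);
    }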
} -+ -+ // -+ // Prologue/PREFETCH -+ // -+ -+ // Clear the smem tiles to account for predicated off loads -+ clear(tAsA); -+ clear(tBsB); -+ -+ // Start async loads for 0th k-tile, where we take care of the k residue -+ { -+ constexpr int k_pipe = 0; -+ -+ Tensor tAgAk = tAgA(_,_,_,*k_tile_iter); -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < size<2>(tAsA); ++k) { -+ if (get<1>(tAcA(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gA shifted) -+ copy_if(gmem_tiled_copy_a, tApA(_,k), tAgAk(_,_,k), tAsA(_,_,k,k_pipe)); -+ } -+ } -+ Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < size<2>(tBsB); ++k) { -+ if (get<1>(tBcB(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gB shifted) -+ copy_if(gmem_tiled_copy_b, tBpB(_,k), tBgBk(_,_,k), tBsB(_,_,k,k_pipe)); -+ } -+ } -+ cp_async_fence(); -+ ++k_tile_iter; -+ --k_tile_count; -+ } -+ -+ // Start async loads for 1st k-tile onwards, no k-residue handling needed -+ CUTLASS_PRAGMA_UNROLL -+ for (int k_pipe = 1; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) { -+ if (k_tile_count <= 0) { -+ clear(tApA); -+ clear(tBpB); -+ } -+ copy_if(gmem_tiled_copy_a, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,k_pipe)); // CpAsync -+ copy_if(gmem_tiled_copy_b, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,k_pipe)); // CpAsync -+ cp_async_fence(); -+ ++k_tile_iter; -+ --k_tile_count; -+ } -+ -+ // -+ // MMA Atom partitioning -+ // -+ -+ // Tile MMA atom and compute thread partitions across A, B and C -+ TiledMma tiled_mma; -+ auto thr_mma = tiled_mma.get_thread_slice(thread_idx); -+ -+ // Allocate registers for pipelining -+ Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) -+ Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) -+ -+ Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_N,MMA_K,PIPE) -+ Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_N,PIPE) -+ -+ CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M -+ CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(src_accum)); // M -+ CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N -+ CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(src_accum)); // N -+ CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K -+ CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE -+ CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tAsA)); // PIPE -+ CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tBsB)); // PIPE -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE -+ -+ // Current pipe index in smem to read from -+ int smem_pipe_read = 0; -+ // Current pipe index in smem to write to -+ int smem_pipe_write = DispatchPolicy::Stages-1; -+ -+ // -+ // Pipelined Main Loop -+ // -+ CUTLASS_PRAGMA_NO_UNROLL -+ for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count) -+ { -+ // -+ // Copy gmem to smem for *k_tile_iter -+ // -+ if (k_tile_count <= 0) { -+ clear(tApA); -+ clear(tBpB); -+ } -+ copy_if(gmem_tiled_copy_a, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write)); // CpAsync -+ copy_if(gmem_tiled_copy_b, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write)); // CpAsync -+ cp_async_fence(); -+ ++k_tile_iter; -+ -+ // -+ // Compute on k_tile -+ // -+ warpgroup_fence_operand(accum); -+ warpgroup_arrive(); -+ -+ cp_async_wait(); -+ cute::gemm(tiled_mma, accum, tCrA(_,_,_,smem_pipe_read), tCrB(_,_,_,smem_pipe_read), src_accum); -+ warpgroup_commit_batch(); -+ -+ // -+ // Advance the pipe -+ // -+ ++smem_pipe_read; -+ 
smem_pipe_read = (smem_pipe_read == DispatchPolicy::Stages) ? smem_pipe_read = 0 : smem_pipe_read; -+ -+ ++smem_pipe_write; -+ smem_pipe_write = (smem_pipe_write == DispatchPolicy::Stages) ? smem_pipe_write = 0 : smem_pipe_write; -+ -+ // Wait for the pipeline MMAs to drain -+ warpgroup_wait<0>(); -+ warpgroup_fence_operand(accum); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::collective -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp b/3rdparty/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp -new file mode 100644 -index 0000000..25eaffb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp -@@ -0,0 +1,480 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
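[Editor's aside, not part of the patch] Both CpAsync mainloops above follow the same schedule: the prologue issues Stages-1 k-tile loads, each steady-state iteration issues one more load and one GMMA, and the loop runs Stages-1 extra iterations (k_tile_count down to -(Stages-1)) so the last in-flight tiles are consumed. The sketch below is a single-threaded model of that index flow under those assumptions, not CUDA; the real loop keeps issuing (predicated or redundant) copies during the drain, which is omitted here.

    #include <cstdio>

    int main() {
      constexpr int Stages = 3;
      int k_tile_count = 5;            // total k-tiles for this CTA
      int k_tile = 0;                  // models k_tile_iter
      int smem_pipe_read  = 0;
      int smem_pipe_write = Stages - 1;

      // Prologue: fill Stages-1 buffers.
      for (int p = 0; p < Stages - 1; ++p) {
        std::printf("load  k=%d -> stage %d\n", k_tile++, p);
        --k_tile_count;
      }

      // Steady state plus drain.
      for (; k_tile_count > -(Stages - 1); --k_tile_count) {
        if (k_tile_count > 0)
          std::printf("load  k=%d -> stage %d\n", k_tile++, smem_pipe_write);
        std::printf("gemm  on stage %d\n", smem_pipe_read);
        smem_pipe_read  = (smem_pipe_read  + 1) % Stages;   // same effect as the conditional wrap in the patch
        smem_pipe_write = (smem_pipe_write + 1) % Stages;
      }
    }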
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cute/arch/cluster_sm90.hpp" -+#include "cute/arch/copy_sm90.hpp" -+#include "cutlass/gemm/dispatch_policy.hpp" -+ -+#include "cute/algorithm/functional.hpp" -+#include "cute/atom/mma_atom.hpp" -+#include "cute/algorithm/gemm.hpp" -+#include "cute/tensor_predicate.hpp" -+#include "cute/numeric/arithmetic_tuple.hpp" -+#include "cutlass/pipeline.hpp" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm::collective { -+using namespace cute; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ int Stages, -+ class ClusterShape, -+ int PipelineAsyncMmaStages, -+ class TileShape_, -+ class ElementA_, -+ class StrideA_, -+ class ElementB_, -+ class StrideB_, -+ class TiledMma_, -+ class GmemTiledCopyA_, -+ class SmemLayoutAtomA_, -+ class SmemCopyAtomA_, -+ class TransformA_, -+ class GmemTiledCopyB_, -+ class SmemLayoutAtomB_, -+ class SmemCopyAtomB_, -+ class TransformB_> -+struct CollectiveMma< -+ MainloopSm90TmaGmma, -+ TileShape_, -+ ElementA_, -+ StrideA_, -+ ElementB_, -+ StrideB_, -+ TiledMma_, -+ GmemTiledCopyA_, -+ SmemLayoutAtomA_, -+ SmemCopyAtomA_, -+ TransformA_, -+ GmemTiledCopyB_, -+ SmemLayoutAtomB_, -+ SmemCopyAtomB_, -+ TransformB_> -+{ -+ // -+ // Type Aliases -+ // -+ using DispatchPolicy = MainloopSm90TmaGmma; -+ using TileShape = TileShape_; -+ using ElementA = ElementA_; -+ using StrideA = StrideA_; -+ using ElementB = ElementB_; -+ using StrideB = StrideB_; -+ using TiledMma = TiledMma_; -+ using ElementAccumulator = typename TiledMma::ValTypeC; -+ using GmemTiledCopyA = GmemTiledCopyA_; -+ using GmemTiledCopyB = GmemTiledCopyB_; -+ using SmemLayoutAtomA = SmemLayoutAtomA_; -+ using SmemLayoutAtomB = SmemLayoutAtomB_; -+ using SmemCopyAtomA = SmemCopyAtomA_; -+ using SmemCopyAtomB = SmemCopyAtomB_; -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ using ArchTag = typename DispatchPolicy::ArchTag; -+ -+ using MainloopPipeline = cutlass::PipelineTmaAsync< -+ DispatchPolicy::Stages, -+ typename DispatchPolicy::ClusterShape>; -+ -+ using PipelineParams = typename MainloopPipeline::Params; -+ using PipelineState = typename cutlass::PipelineState; -+ -+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ // Tile along K mode first before tiling over MN. PIPE mode last as usual. -+ // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. 
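[Editor's aside, not part of the patch] The SmemLayoutA/B defined next hold Stages copies of the CTA tile of A and B, so the shared-memory footprint grows linearly with the stage count. A back-of-the-envelope check with hypothetical tile numbers:

    #include <cstddef>
    #include <cstdio>

    // Per-stage bytes = (BLK_M*BLK_K)*sizeof(A) + (BLK_N*BLK_K)*sizeof(B).
    constexpr std::size_t stage_bytes(int blk_m, int blk_n, int blk_k,
                                      std::size_t size_a, std::size_t size_b) {
      return std::size_t(blk_m) * blk_k * size_a + std::size_t(blk_n) * blk_k * size_b;
    }

    int main() {
      constexpr int BLK_M = 128, BLK_N = 128, BLK_K = 64, Stages = 4;   // assumed example tile
      constexpr std::size_t per_stage = stage_bytes(BLK_M, BLK_N, BLK_K, /*half*/2, /*half*/2);
      std::printf("per-stage %zu B, total %zu B (SM90 offers roughly 228 KB of smem per SM)\n",
                  per_stage, per_stage * Stages);
    }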
-+ using SmemLayoutA = decltype(tile_to_shape( -+ SmemLayoutAtomA{}, -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), -+ Step<_2,_1,_3>{})); -+ using SmemLayoutB = decltype(tile_to_shape( -+ SmemLayoutAtomB{}, -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), -+ Step<_2,_1,_3>{})); -+ -+ static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); -+ static_assert(std::is_base_of::value && -+ std::is_base_of::value, -+ "MMA atom must source both A and B operand from smem_desc for this mainloop."); -+ static_assert(std::is_same_v || std::is_same_v, -+ "GmemTiledCopy - invalid SM90 TMA copy atom specified."); -+ static_assert(std::is_same_v || std::is_same_v, -+ "GmemTiledCopy - invalid SM90 TMA copy atom specified."); -+ -+ // TMA converts f32 input to tf32 when copying from GMEM to SMEM -+ // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. -+ static constexpr bool ConvertF32toTF32A = std::is_same_v; -+ static constexpr bool ConvertF32toTF32B = std::is_same_v; -+ using InternalElementA = std::conditional_t>>; -+ using InternalElementB = std::conditional_t>>; -+ -+ struct SharedStorage -+ { -+ cute::array_aligned> smem_A; -+ cute::array_aligned> smem_B; -+ -+ using PipelineStorage = typename MainloopPipeline::SharedStorage; -+ alignas(16) PipelineStorage pipeline_storage; -+ }; -+ -+ struct Params { -+ InternalElementA const* ptr_A; -+ StrideA dA; -+ InternalElementB const* ptr_B; -+ StrideB dB; -+ // Assumption: StrideA is congruent with Problem_MK -+ using TMA_A = decltype(make_tma_copy( -+ GmemTiledCopyA{}, -+ make_tensor(ptr_A, repeat_like(StrideA{}, int32_t(0)), dA), -+ SmemLayoutA{}(_,_,0), -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), -+ size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any -+ // Assumption: StrideB is congruent with Problem_NK -+ using TMA_B = decltype(make_tma_copy( -+ GmemTiledCopyB{}, -+ make_tensor(ptr_B, repeat_like(StrideB{}, int32_t(0)), dB), -+ SmemLayoutB{}(_,_,0), -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), -+ size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any -+ TMA_A tma_load_a; -+ TMA_B tma_load_b; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ template -+ static constexpr Params -+ to_underlying_arguments(Args const& args, void* workspace) { -+ (void) workspace; -+ // Optionally append _1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) -+ auto problem_shape_MNKL = append<4>(args.problem_shape, Int<1>{}); -+ auto M = get<0>(problem_shape_MNKL); -+ auto N = get<1>(problem_shape_MNKL); -+ auto K = get<2>(problem_shape_MNKL); -+ auto L = get<3>(problem_shape_MNKL); -+ -+ auto reinterpreted_ptr_A = reinterpret_cast(args.ptr_A); -+ auto reinterpreted_ptr_B = reinterpret_cast(args.ptr_B); -+ -+ Tensor tensor_a = make_tensor(reinterpreted_ptr_A, make_layout(make_shape(M,K,L), args.dA)); -+ Tensor tensor_b = make_tensor(reinterpreted_ptr_B, make_layout(make_shape(N,K,L), args.dB)); -+ typename Params::TMA_A tma_load_a = make_tma_copy( -+ GmemTiledCopyA{}, -+ tensor_a, -+ SmemLayoutA{}(_,_,cute::Int<0>{}), -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), -+ size<1>(ClusterShape{})); // mcast along N mode for this M load, if any -+ typename Params::TMA_B tma_load_b = make_tma_copy( -+ GmemTiledCopyB{}, -+ tensor_b, -+ SmemLayoutB{}(_,_,cute::Int<0>{}), -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), -+ size<0>(ClusterShape{})); // mcast 
along M mode for this N load, if any -+ return { -+ reinterpreted_ptr_A, -+ args.dA, -+ reinterpreted_ptr_B, -+ args.dB, -+ tma_load_a, -+ tma_load_b -+ }; -+ } -+ -+ /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance -+ CUTLASS_DEVICE -+ static void prefetch_tma_descriptors(Params const& mainloop_params) -+ { -+ cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); -+ cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); -+ } -+ -+ /// Perform a collective-scoped matrix multiply-accumulate -+ template < -+ class TensorA, class TMA_LOAD_A, -+ class TensorB, class TMA_LOAD_B, -+ class FrgTensorC, -+ class KTileIterator -+ > -+ CUTLASS_DEVICE void -+ operator() ( -+ TensorA const& gA, TMA_LOAD_A& tma_load_a, -+ TensorB const& gB, TMA_LOAD_B& tma_load_b, -+ FrgTensorC& accum, -+ KTileIterator k_tile_iter, int k_tile_count, -+ int thread_idx, -+ char* shared_memory, -+ Params const& mainloop_params) -+ { -+ using namespace cute; -+ -+ static_assert(is_rmem::value, "C tensor must be rmem resident."); -+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2."); -+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2."); -+ static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); -+ static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); -+ static_assert(std::is_void_v, -+ "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); -+ static_assert(std::is_void_v, -+ "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); -+ -+ SharedStorage& storage = *reinterpret_cast(shared_memory); -+ Tensor sA = make_tensor(make_smem_ptr(storage.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) -+ Tensor sB = make_tensor(make_smem_ptr(storage.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) -+ -+ // -+ // Prepare the TMA loads for A and B -+ // -+ dim3 cluster_local_block_id = cute::block_id_in_cluster(); -+ auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y); -+ auto block_tma_b = tma_load_b.get_slice(cluster_local_block_id.x); -+ -+ // Applies the mapping from block_tma_a -+ Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) -+ Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) -+ -+ Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) -+ Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) -+ -+ // -+ // Prepare TMA membars and PREFETCH -+ // -+ -+ // Number of pipelined k-tiles in smem -+ constexpr int K_PIPE_MAX = DispatchPolicy::Stages; -+ -+ // NOTE: Another parameter: Partition the pipeline between active MMAs and active TMAs -+ // Tunable via the dispatch policy to tollerate latencies evenly across the math and compute stages -+ // K_PIPE_MMAS: The max number of active MMA pipes at beginning of every loop -+ // K_PIPE_TMAS: The max number of active TMA pipes at beginning of every loop (geq 1) -+ constexpr int K_PIPE_MMAS = DispatchPolicy::PipelineAsyncMmaStages; -+ constexpr int K_PIPE_TMAS = K_PIPE_MAX - K_PIPE_MMAS; -+ static_assert(0 <= K_PIPE_MMAS && K_PIPE_MMAS < K_PIPE_MAX); -+ static_assert(0 < K_PIPE_TMAS && K_PIPE_TMAS <= K_PIPE_MAX); -+ -+ static_assert(K_PIPE_MMAS < K_PIPE_MAX - 1); -+ -+ // Set the bytes transferred in this TMA transaction (may involve multiple issues) -+ constexpr uint32_t TmaTransactionBytes = static_cast( -+ (size<0>(sA) * size<1>(sA) * sizeof(InternalElementA)) + -+ 
(size<0>(sB) * size<1>(sB) * sizeof(InternalElementB))); -+ -+ -+ // Obtain warp index -+ int warp_idx = canonical_warp_idx(); -+ int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; -+ -+ PipelineParams params; -+ params.transaction_bytes = TmaTransactionBytes; -+ params.role = MainloopPipeline::ThreadCategory::ProducerConsumer; -+ params.is_leader = warp_group_thread_idx == 0; -+ params.num_consumers = NumThreadsPerWarpGroup; -+ -+ MainloopPipeline pipeline( -+ storage.pipeline_storage, -+ params); -+ -+ // State variables used for iterating the circular buffer -+ // smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA -+ // smem_pipe_write is used by the producer of SMEM data - i.e TMA -+ PipelineState smem_pipe_read; -+ PipelineState smem_pipe_release; -+ PipelineState smem_pipe_write = cutlass::make_producer_start_state(); -+ -+ // We need this to guarantee that the Pipeline init is visible -+ // To all producers and consumer blocks in the Cluster -+ if constexpr (size(ClusterShape{}) > 1) { -+ cute::cluster_arrive_relaxed(); -+ cute::cluster_wait(); -+ } -+ else { -+ __syncthreads(); -+ } -+ -+ // Set predicate for the lowest lane_id in the warp -+ int lane_predicate = cute::elect_one_sync(); -+ -+ uint16_t mcast_mask_a = 0; -+ uint16_t mcast_mask_b = 0; -+ // Keep a copy to know when to stop issuing loads -+ int k_tile_count_tma = k_tile_count; -+ -+ // Issue TmaLoads (Prologue fetches) -+ if (warp_idx == 0 && lane_predicate == 1) { -+ // Maps the tile -> block, value -+ if constexpr (std::is_same_v) { -+ auto block_layout = Layout{}; // (m,n) -> block_id -+ for (int n = 0; n < size<1>(block_layout); ++n) { -+ mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); -+ } -+ } -+ -+ if constexpr (std::is_same_v) { -+ auto block_layout = Layout{}; // (m,n) -> block_id -+ for (int m = 0; m < size<0>(block_layout); ++m) { -+ mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); -+ } -+ } -+ -+ // Issue the prologue loads -+ int prologue_tma_count = min(K_PIPE_MAX, k_tile_count); -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < prologue_tma_count; ++stage) { -+ pipeline.producer_acquire(smem_pipe_write); -+ using BarrierType = typename MainloopPipeline::ValueType; -+ BarrierType* tma_barrier = pipeline.producer_get_barrier(stage); -+ -+ copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,stage)); -+ copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,stage)); -+ ++k_tile_iter; -+ ++smem_pipe_write; -+ } -+ k_tile_count_tma -= prologue_tma_count; -+ } -+ -+ // -+ // Define C accumulators and A/B partitioning -+ // -+ -+ TiledMma tiled_mma; -+ auto thread_mma = tiled_mma.get_thread_slice(thread_idx); -+ -+ Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) -+ Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) -+ -+ // Allocate "fragments/descriptors" -+ Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) -+ Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) -+ -+ CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M -+ CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N -+ CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K -+ CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE -+ CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tAsA)); // PIPE -+ CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tBsB)); // PIPE -+ 
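[Editor's aside, not part of the patch] The two mask loops above decide which CTAs of the cluster receive each multicast TMA load: the A tile for a given cluster row is sent to every CTA in that row (all n), and the B tile for a given cluster column to every CTA in that column (all m). A plain-C++ model, assuming the column-major block-id numbering that a default layout of the cluster shape would give:

    #include <cstdint>
    #include <cstdio>

    int main() {
      constexpr int CLUSTER_M = 2, CLUSTER_N = 2;
      // Block ids in the cluster, column-major: id = m + n * CLUSTER_M (assumption).
      auto block_id = [](int m, int n) { return m + n * CLUSTER_M; };

      int my_m = 0, my_n = 1;                       // this CTA's coordinates in the cluster
      std::uint16_t mcast_mask_a = 0, mcast_mask_b = 0;
      for (int n = 0; n < CLUSTER_N; ++n) mcast_mask_a |= std::uint16_t(1) << block_id(my_m, n);
      for (int m = 0; m < CLUSTER_M; ++m) mcast_mask_b |= std::uint16_t(1) << block_id(m, my_n);

      std::printf("mask_a=0x%x mask_b=0x%x\n", mcast_mask_a, mcast_mask_b);  // 0x5 and 0xc here
    }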
CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE -+ -+ __syncthreads(); -+ -+ warpgroup_fence_operand(accum); -+ // Prologue MMAs -+ CUTLASS_PRAGMA_UNROLL -+ for (int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); -+ prologue_mma_count > 0; --prologue_mma_count) -+ { -+ // WAIT on smem_pipe_read until it's data is available -+ pipeline.consumer_wait(smem_pipe_read); -+ warpgroup_arrive(); -+ cute::gemm(tiled_mma, tCrA(_,_,_,smem_pipe_read.index()), tCrB(_,_,_,smem_pipe_read.index()), accum); // (V,M,K) x (V,N,K) => (V,M,N) -+ warpgroup_commit_batch(); -+ ++smem_pipe_read; -+ --k_tile_count; -+ } -+ warpgroup_fence_operand(accum); -+ -+ // -+ // PIPELINED MAIN LOOP -+ // -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for ( ; k_tile_count > 0; --k_tile_count) -+ { -+ // WAIT on smem_pipe_read until data is available -+ pipeline.consumer_wait(smem_pipe_read); -+ -+ // -+ // Compute on k_tile -+ // -+ -+ warpgroup_fence_operand(accum); -+ warpgroup_arrive(); -+ cute::gemm(tiled_mma, tCrA(_,_,_,smem_pipe_read.index()), tCrB(_,_,_,smem_pipe_read.index()), accum); // (V,M,K) x (V,N,K) => (V,M,N) -+ warpgroup_commit_batch(); -+ -+ /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed -+ warpgroup_wait(); -+ warpgroup_fence_operand(accum); -+ -+ pipeline.consumer_release(smem_pipe_release); // UNLOCK wr stage, done _computing_ on it -+ -+ // -+ // Copy gmem to smem for *k_tile_iter -+ // -+ -+ // Do Acquire & Load only if needed - helps with both performance and also corner case illegal barrier-ops -+ if (warp_idx == 0 && lane_predicate == 1 && (k_tile_count_tma > 0) ) { -+ pipeline.producer_acquire(smem_pipe_write); // LOCK wr stage, for _writing_ -+ -+ using BarrierType = typename MainloopPipeline::ValueType; -+ BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write.index()); -+ -+ copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write.index())); -+ copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write.index())); -+ ++smem_pipe_write; -+ ++k_tile_iter; -+ --k_tile_count_tma; -+ } -+ -+ // Advance consumer pipeline -+ ++smem_pipe_read; -+ ++smem_pipe_release; -+ } -+ -+ // Wait on all GMMAs -+ warpgroup_wait<0>(); -+ warpgroup_fence_operand(accum); -+ -+ // Workaround for ensuring Smem destruction doesn't happen accidentally -+ if constexpr (size(typename DispatchPolicy::ClusterShape{}) > 1) { -+ cute::cluster_arrive(); -+ cute::cluster_wait(); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::collective -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp b/3rdparty/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp -new file mode 100644 -index 0000000..41b0f13 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp -@@ -0,0 +1,494 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
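[Editor's aside, not part of the patch] The mainloop above coordinates one producing thread (issuing TMA loads) and a consuming warpgroup (issuing GMMAs) through PipelineTmaAsync: producer_acquire waits for a free stage, the TMA arrival marks it full, consumer_wait waits for a full stage, and consumer_release hands it back. The toy below is a single-threaded model of that ordering only; the real pipeline uses SM90 mbarriers and phase bits, and producer and consumer run concurrently.

    #include <array>
    #include <cstdio>

    template <int Stages>
    struct ToyPipeline {
      std::array<bool, Stages> full{};                  // does the stage hold valid data?
      bool producer_acquire(int s) { return !full[s]; } // may write only an empty stage
      void producer_commit(int s)  { full[s] = true;  } // load arrival flips the stage to full
      bool consumer_wait(int s)    { return  full[s]; } // MMA may read only a full stage
      void consumer_release(int s) { full[s] = false; } // MMA done: stage reusable
    };

    int main() {
      ToyPipeline<2> p;
      int wr = 0, rd = 0;
      for (int k = 0; k < 4; ++k) {                     // 4 k-tiles through a 2-stage pipe
        if (p.producer_acquire(wr)) { p.producer_commit(wr); wr = (wr + 1) % 2; }
        if (p.consumer_wait(rd))    { std::printf("mma k=%d on stage %d\n", k, rd);
                                      p.consumer_release(rd); rd = (rd + 1) % 2; }
      }
    }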
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cute/arch/cluster_sm90.hpp" -+#include "cute/arch/copy_sm90.hpp" -+#include "cutlass/gemm/dispatch_policy.hpp" -+ -+#include "cute/algorithm/functional.hpp" -+#include "cute/atom/mma_atom.hpp" -+#include "cute/algorithm/gemm.hpp" -+#include "cute/tensor_predicate.hpp" -+#include "cute/numeric/arithmetic_tuple.hpp" -+#include "cutlass/pipeline.hpp" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm::collective { -+using namespace cute; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// WarpSpecialized Mainloop -+template < -+ int Stages, -+ class ClusterShape, -+ class KernelSchedule, -+ class TileShape_, -+ class ElementA_, -+ class StrideA_, -+ class ElementB_, -+ class StrideB_, -+ class TiledMma_, -+ class GmemTiledCopyA_, -+ class SmemLayoutAtomA_, -+ class SmemCopyAtomA_, -+ class TransformA_, -+ class GmemTiledCopyB_, -+ class SmemLayoutAtomB_, -+ class SmemCopyAtomB_, -+ class TransformB_> -+struct CollectiveMma< -+ MainloopSm90TmaGmmaWarpSpecialized, -+ TileShape_, -+ ElementA_, -+ StrideA_, -+ ElementB_, -+ StrideB_, -+ TiledMma_, -+ GmemTiledCopyA_, -+ SmemLayoutAtomA_, -+ SmemCopyAtomA_, -+ TransformA_, -+ GmemTiledCopyB_, -+ SmemLayoutAtomB_, -+ SmemCopyAtomB_, -+ TransformB_> -+{ -+ // -+ // Type Aliases -+ // -+ using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecialized; -+ using TileShape = TileShape_; -+ using ElementA = ElementA_; -+ using StrideA = StrideA_; -+ using ElementB = ElementB_; -+ using StrideB = StrideB_; -+ using TiledMma = TiledMma_; -+ using ElementAccumulator = typename TiledMma::ValTypeC; -+ using GmemTiledCopyA = GmemTiledCopyA_; -+ using GmemTiledCopyB = GmemTiledCopyB_; -+ using 
SmemLayoutAtomA = SmemLayoutAtomA_; -+ using SmemLayoutAtomB = SmemLayoutAtomB_; -+ using SmemCopyAtomA = SmemCopyAtomA_; -+ using SmemCopyAtomB = SmemCopyAtomB_; -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ using ArchTag = typename DispatchPolicy::ArchTag; -+ -+ using MainloopPipeline = cutlass::PipelineTmaAsync< -+ DispatchPolicy::Stages, -+ typename DispatchPolicy::ClusterShape>; -+ using PipelineState = cutlass::PipelineState; -+ -+ using PipelineParams = typename MainloopPipeline::Params; -+ -+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ // Tile along K mode first before tiling over MN. PIPE mode last as usual. -+ // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. -+ using SmemLayoutA = decltype(tile_to_shape( -+ SmemLayoutAtomA{}, -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), -+ Step<_2,_1,_3>{})); -+ using SmemLayoutB = decltype(tile_to_shape( -+ SmemLayoutAtomB{}, -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), -+ Step<_2,_1,_3>{})); -+ -+ static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); -+ static_assert(std::is_base_of::value && -+ std::is_base_of::value, -+ "MMA atom must source both A and B operand from smem_desc for this mainloop."); -+ static_assert(std::is_same_v || std::is_same_v, -+ "GmemTiledCopy - invalid SM90 TMA copy atom specified."); -+ static_assert(std::is_same_v || std::is_same_v, -+ "GmemTiledCopy - invalid SM90 TMA copy atom specified."); -+ -+ // TMA converts f32 input to tf32 when copying from GMEM to SMEM -+ // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
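[Editor's aside, not part of the patch] The comment above is realized by the ConvertF32toTF32A/B flags and the InternalElementA/B aliases that follow: float is re-typed for TMA's native f32-to-tf32 path, and every other element type is viewed as a same-width unsigned integer so the copy stays bit-exact. A reduced stand-alone model of that type selection, with stand-in types rather than the CUTLASS ones:

    #include <cstdint>
    #include <type_traits>

    struct tf32_tag {};                       // stand-in for cutlass::tfloat32_t (assumption)

    // Simplified: the patch keys off sizeof_bits of the element; here 2-byte vs 1-byte.
    template <class Element>
    using InternalElement =
        std::conditional_t<std::is_same_v<Element, float>, tf32_tag,
                           std::conditional_t<sizeof(Element) == 2, std::uint16_t, std::uint8_t>>;

    static_assert(std::is_same_v<InternalElement<float>, tf32_tag>);
    static_assert(std::is_same_v<InternalElement<short>, std::uint16_t>);

    int main() {}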
-+ static constexpr bool ConvertF32toTF32A = std::is_same_v; -+ static constexpr bool ConvertF32toTF32B = std::is_same_v; -+ using InternalElementA = std::conditional_t>>; -+ using InternalElementB = std::conditional_t>>; -+ -+ struct SharedStorage -+ { -+ cute::array_aligned> smem_A; -+ cute::array_aligned> smem_B; -+ -+ using PipelineStorage = typename MainloopPipeline::SharedStorage; -+ alignas(16) PipelineStorage pipeline_storage; -+ }; -+ -+ struct Params { -+ InternalElementA const* ptr_A; -+ StrideA dA; -+ InternalElementB const* ptr_B; -+ StrideB dB; -+ // Assumption: StrideA is congruent with Problem_MK -+ using TMA_A = decltype(make_tma_copy( -+ GmemTiledCopyA{}, -+ make_tensor(ptr_A, repeat_like(StrideA{}, int32_t(0)), dA), -+ SmemLayoutA{}(_,_,0), -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), -+ size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any -+ // Assumption: StrideB is congruent with Problem_NK -+ using TMA_B = decltype(make_tma_copy( -+ GmemTiledCopyB{}, -+ make_tensor(ptr_B, repeat_like(StrideB{}, int32_t(0)), dB), -+ SmemLayoutB{}(_,_,0), -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), -+ size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any -+ TMA_A tma_load_a; -+ TMA_B tma_load_b; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ template -+ static constexpr Params -+ to_underlying_arguments(Args const& args, void* workspace) { -+ (void) workspace; -+ // Optionally append _1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) -+ auto problem_shape_MNKL = append<4>(args.problem_shape, Int<1>{}); -+ auto M = get<0>(problem_shape_MNKL); -+ auto N = get<1>(problem_shape_MNKL); -+ auto K = get<2>(problem_shape_MNKL); -+ auto L = get<3>(problem_shape_MNKL); -+ -+ auto reinterpreted_ptr_A = reinterpret_cast(args.ptr_A); -+ auto reinterpreted_ptr_B = reinterpret_cast(args.ptr_B); -+ -+ Tensor tensor_a = make_tensor(reinterpreted_ptr_A, make_layout(make_shape(M,K,L), args.dA)); -+ Tensor tensor_b = make_tensor(reinterpreted_ptr_B, make_layout(make_shape(N,K,L), args.dB)); -+ typename Params::TMA_A tma_load_a = make_tma_copy( -+ GmemTiledCopyA{}, -+ tensor_a, -+ SmemLayoutA{}(_,_,cute::Int<0>{}), -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), -+ size<1>(ClusterShape{})); // mcast along N mode for this M load, if any -+ typename Params::TMA_B tma_load_b = make_tma_copy( -+ GmemTiledCopyB{}, -+ tensor_b, -+ SmemLayoutB{}(_,_,cute::Int<0>{}), -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), -+ size<0>(ClusterShape{})); // mcast along M mode for this N load, if any -+ return { -+ reinterpreted_ptr_A, -+ args.dA, -+ reinterpreted_ptr_B, -+ args.dB, -+ tma_load_a, -+ tma_load_b -+ }; -+ } -+ -+ static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; -+ static constexpr int K_PIPE_MMAS = 1; -+ static constexpr uint32_t TmaTransactionBytes = -+ (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof(ElementA)))+ -+ (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof(ElementB))); -+ -+ CUTLASS_DEVICE -+ static MainloopPipeline make_pipeline(char* shared_memory, PipelineParams params){ -+ SharedStorage& shared_storage = *reinterpret_cast(shared_memory); -+ return {shared_storage.pipeline_storage, params}; -+ } -+ -+ /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance -+ CUTLASS_DEVICE -+ static void prefetch_tma_descriptors(Params const& mainloop_params) -+ { -+ 
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); -+ cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); -+ } -+ -+ /// Perform a collective-scoped matrix multiply-accumulate -+ /// Producer Perspective -+ template < -+ class TensorA, class TMA_LOAD_A, -+ class TensorB, class TMA_LOAD_B, -+ class KTileIterator -+ > -+ CUTLASS_DEVICE void -+ dma(MainloopPipeline pipeline, -+ PipelineState smem_pipe_write, -+ TensorA const& gA, TMA_LOAD_A& tma_load_a, -+ TensorB const& gB, TMA_LOAD_B& tma_load_b, -+ KTileIterator k_tile_iter, int k_tile_count, -+ int thread_idx, -+ char* shared_memory) -+ { -+ -+ using namespace cute; -+ int warp_idx = canonical_warp_idx(); -+ int warp_idx_in_warp_group = warp_idx % 4; -+ int lane_predicate = cute::elect_one_sync(); -+ -+ if (warp_idx_in_warp_group == 0 and lane_predicate) { -+ SharedStorage& shared_storage = *reinterpret_cast(shared_memory); -+ Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) -+ Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) -+ -+ // -+ // Prepare the TMA loads for A and B -+ // -+ -+ dim3 cluster_local_block_id = cute::block_id_in_cluster(); -+ auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y); -+ auto block_tma_b = tma_load_b.get_slice(cluster_local_block_id.x); -+ -+ // Applies the mapping from block_tma_a -+ Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) -+ Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) -+ -+ Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) -+ Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) -+ -+ uint16_t mcast_mask_a = 0; -+ uint16_t mcast_mask_b = 0; -+ -+ // Issue TmaLoads -+ // Maps the tile -> block, value -+ if constexpr (std::is_same_v) { -+ auto block_layout = Layout{}; // (m,n) -> block_id -+ for (int n = 0; n < size<1>(block_layout); ++n) { -+ mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); -+ } -+ } -+ -+ if constexpr (std::is_same_v) { -+ auto block_layout = Layout{}; // (m,n) -> block_id -+ for (int m = 0; m < size<0>(block_layout); ++m) { -+ mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); -+ } -+ } -+ -+ // Issue the prologue loads -+ int k_tile_prologue = min(k_tile_count, K_PIPE_MAX); -+ CUTLASS_PRAGMA_UNROLL -+ for (int count = 0; count < k_tile_prologue; ++count) { -+ pipeline.producer_acquire(smem_pipe_write); -+ int write_stage = smem_pipe_write.index(); -+ using BarrierType = typename MainloopPipeline::ValueType; -+ BarrierType* tma_barrier = pipeline.producer_get_barrier(write_stage); -+ -+ copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); -+ copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); -+ ++k_tile_iter; -+ ++smem_pipe_write; -+ } -+ k_tile_count -= k_tile_prologue; -+ -+ // Mainloop -+ CUTLASS_PRAGMA_NO_UNROLL -+ for ( ; k_tile_count > 0; --k_tile_count) -+ { -+ // LOCK smem_pipe_write for _writing_ -+ pipeline.producer_acquire(smem_pipe_write); -+ -+ // -+ // Copy gmem to smem for *k_tile_iter -+ // -+ -+ int write_stage = smem_pipe_write.index(); -+ using BarrierType = typename MainloopPipeline::ValueType; -+ BarrierType* tma_barrier = pipeline.producer_get_barrier(write_stage); -+ -+ copy(tma_load_a.with(*tma_barrier, mcast_mask_a), 
tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); -+ copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); -+ ++k_tile_iter; -+ -+ // Advance smem_pipe_write -+ ++smem_pipe_write; -+ } -+ } -+ } -+ -+ /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster -+ CUTLASS_DEVICE void -+ dma_epilogue(MainloopPipeline pipeline, -+ PipelineState smem_pipe_write) -+ { -+ int warp_idx = canonical_warp_idx(); -+ int warp_idx_in_warp_group = warp_idx % 4; -+ int lane_predicate = cute::elect_one_sync(); -+ -+ // Issue the epilogue waits -+ if (warp_idx_in_warp_group == 0 and lane_predicate) { -+ /* This helps avoid early exit of blocks in Cluster -+ * Waits for all stages to either be released (all -+ * Consumer UNLOCKs), or if the stage was never used -+ * then would just be acquired since the phase was -+ * still inverted from make_producer_start_state -+ */ -+ for (int count = 0; count < K_PIPE_MAX; ++count) { -+ pipeline.producer_acquire(smem_pipe_write); -+ ++smem_pipe_write; -+ } -+ } -+ } -+ -+ /// Perform a collective-scoped matrix multiply-accumulate -+ /// Consumer Perspective -+ template < -+ class FrgTensorC -+ > -+ CUTLASS_DEVICE void -+ mma(MainloopPipeline pipeline, -+ PipelineState smem_pipe_read, -+ FrgTensorC& accum, -+ int k_tile_count, -+ int thread_idx, -+ char* shared_memory, -+ Params const& mainloop_params -+ ) -+ { -+ using namespace cute; -+ -+ static_assert(is_rmem::value, "C tensor must be rmem resident."); -+ static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); -+ static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); -+ static_assert(std::is_void_v, -+ "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); -+ static_assert(std::is_void_v, -+ "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); -+ -+ SharedStorage& shared_storage = *reinterpret_cast(shared_memory); -+ Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) -+ Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) -+ -+ // -+ // Define C accumulators and A/B partitioning -+ // -+ -+ TiledMma tiled_mma; -+ auto thread_mma = tiled_mma.get_thread_slice(thread_idx); -+ -+ Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) -+ Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) -+ -+ // Allocate "fragments/descriptors" -+ Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) -+ Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) -+ -+ CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M -+ CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N -+ CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K -+ CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE -+ -+ // -+ // PIPELINED MAIN LOOP -+ // -+ static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), -+ "ERROR : Incorrect number of MMAs in flight"); -+ -+ // We release buffers to producer warps(dma) with some mmas in flight -+ PipelineState smem_pipe_release = smem_pipe_read; -+ -+ // Prologue GMMAs -+ int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); -+ -+ warpgroup_fence_operand(accum); -+ CUTLASS_PRAGMA_UNROLL -+ for (int k_tile_prologue = prologue_mma_count; 
k_tile_prologue > 0; --k_tile_prologue) -+ { -+ // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) -+ pipeline.consumer_wait(smem_pipe_read); -+ -+ int read_stage = smem_pipe_read.index(); -+ warpgroup_arrive(); -+ cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB(_,_,_,read_stage), accum); // (V,M,K) x (V,N,K) => (V,M,N) -+ warpgroup_commit_batch(); -+ -+ ++smem_pipe_read; -+ } -+ -+ warpgroup_fence_operand(accum); -+ // Mainloop GMMAs -+ k_tile_count -= prologue_mma_count; -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for ( ; k_tile_count > 0; --k_tile_count) -+ { -+ // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) -+ pipeline.consumer_wait(smem_pipe_read); -+ -+ // -+ // Compute on k_tile -+ // -+ -+ int read_stage = smem_pipe_read.index(); -+ warpgroup_fence_operand(accum); -+ warpgroup_arrive(); -+ cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB(_,_,_,read_stage), accum); // (V,M,K) x (V,N,K) => (V,M,N) -+ warpgroup_commit_batch(); -+ -+ /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed -+ warpgroup_wait(); -+ warpgroup_fence_operand(accum); -+ -+ pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it -+ -+ // Advance smem_pipe_read and smem_pipe_release -+ ++smem_pipe_read; -+ ++smem_pipe_release; -+ } -+ -+ // Wait on all GMMAs to complete -+ warpgroup_wait<0>(); -+ warpgroup_fence_operand(accum); -+ -+ for (int count = 0; count < prologue_mma_count; ++count) { -+ pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it -+ ++smem_pipe_release; -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::collective -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/base_grouped.h b/3rdparty/cutlass/include/cutlass/gemm/device/base_grouped.h -new file mode 100644 -index 0000000..2e9398a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/base_grouped.h -@@ -0,0 +1,479 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Base device-level grouped kernel. -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/gemm_universal.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_universal.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+#include "cutlass/trace.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// GEMM Grouped -+template -+class BaseGrouped { -+public: -+ -+ using BaseKernel = BaseKernel_; -+ -+ using ElementA = typename BaseKernel::ElementA; -+ using LayoutA = typename BaseKernel::LayoutA; -+ using TensorRefA = TensorRef; -+ static ComplexTransform const kTransformA = BaseKernel::kTransformA; -+ static int const kAlignmentA = BaseKernel::kAlignmentA; -+ -+ using ElementB = typename BaseKernel::ElementB; -+ using LayoutB = typename BaseKernel::LayoutB; -+ using TensorRefB = TensorRef; -+ static ComplexTransform const kTransformB = BaseKernel::kTransformB; -+ static int const kAlignmentB = BaseKernel::kAlignmentB; -+ -+ using ElementC = typename BaseKernel::ElementC; -+ using LayoutC = typename BaseKernel::LayoutC; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ static int const kAlignmentC = BaseKernel::kAlignmentC; -+ -+ using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC; -+ -+ using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; -+ using ThreadblockSwizzle = typename BaseKernel::ThreadblockSwizzle; -+ -+ using Operator = typename BaseKernel::Operator; -+ using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; -+ -+ using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; -+ using MathOperator = typename WarpMmaOperator::MathOperator; -+ using OperatorClass = typename WarpMmaOperator::OperatorClass; -+ using ArchTag = typename WarpMmaOperator::ArchTag; -+ using ThreadblockShape = typename BaseKernel::Mma::Shape; -+ using WarpShape = typename BaseKernel::WarpShape; -+ using InstructionShape = typename BaseKernel::InstructionShape; -+ static int const kStages = BaseKernel::Mma::kStages; -+ -+ /// Argument structure -+ using Arguments = typename BaseKernel::Arguments; -+ -+ using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo; -+ -+protected: -+ -+ /// Kernel parameters object -+ typename BaseKernel::Params params_; -+ -+private: -+ -+ /// Get the number of tiles across all problems in a group -+ static int32_t group_tile_count(const 
cutlass::gemm::GemmCoord* problem_sizes_ptr, int problem_count) { -+ int32_t tiles = 0; -+ for (int32_t i = 0; i < problem_count; ++i) { -+ cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i]; -+ BaseKernel::ProblemVisitor::possibly_transpose_problem(problem); -+ tiles += problem_tile_count(problem); -+ } -+ return tiles; -+ } -+ -+ /// Copy from `data` to `workspace` -+ Status copy_to_workspace(void* workspace, void* data, size_t bytes) { -+ cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice); -+ if (cuda_error != cudaSuccess) { -+ // Call cudaGetLastError() to clear the error bit -+ cuda_error = cudaGetLastError(); -+ CUTLASS_TRACE_HOST( -+ " cudaMemcpy() returned error " -+ << cudaGetErrorString(cuda_error)); -+ return Status::kErrorInternal; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Precomputes scheduling information for the grouped GEMM -+ Status precompute(Arguments const &args, int32_t tile_count, void* workspace) { -+ size_t workspace_bytes = get_workspace_size(args); -+ std::vector host_workspace(workspace_bytes); -+ BaseKernel::ProblemVisitor::host_precompute(args.host_problem_sizes, -+ args.problem_count, -+ args.threadblock_count, -+ (void*)host_workspace.data()); -+ return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes); -+ } -+ -+ /// Reorder `data` according to `indices` -+ template -+ static void reorder_array(T* data, const std::vector& indices) { -+ // For now, simply create a copy of the data and then copy over to the original. -+ std::vector copy(indices.size()); -+ for (int i = 0; i < indices.size(); ++i) { -+ copy.at(i) = data[indices[i]]; -+ } -+ -+ memcpy(data, copy.data(), indices.size() * sizeof(T)); -+ } -+ -+public: -+ -+ /// Constructs the GEMM. -+ BaseGrouped() { } -+ -+ /// Determines whether the GEMM can execute the given problem. 
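[Editor's aside, not part of the patch] The reorder_array helper above applies one permutation, given as an index vector, to each per-problem host array (leading dimensions and offsets). A standalone, compilable version of the same gather, assuming trivially copyable element types just as the patch does:

    #include <cstddef>
    #include <cstdio>
    #include <cstring>
    #include <vector>

    // Element i of the result is data[indices[i]]; applied in place via a temporary copy.
    template <typename T>
    void reorder_array(T* data, const std::vector<std::size_t>& indices) {
      std::vector<T> copy(indices.size());
      for (std::size_t i = 0; i < indices.size(); ++i) copy[i] = data[indices[i]];
      std::memcpy(data, copy.data(), indices.size() * sizeof(T));
    }

    int main() {
      long long lda[] = {10, 20, 30, 40};
      reorder_array(lda, {3, 1, 2, 0});                   // e.g. indices that sort by descending K
      for (long long v : lda) std::printf("%lld ", v);    // 40 20 30 10
    }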
-+ static Status can_implement(Arguments const &args) { -+ -+ return BaseKernel::can_implement(args); -+ } -+ -+ /// Get the number of tiles in a problem -+ static int32_t problem_tile_count(cutlass::gemm::GemmCoord const &problem) { -+ auto grid = BaseKernel::ProblemVisitor::grid_shape(problem); -+ return BaseKernel::ProblemVisitor::tile_count(grid); -+ } -+ -+ /// Get the number of tiles across all problems in a group -+ static int32_t group_tile_count(Arguments const &args) { -+ if (args.host_problem_sizes == nullptr) { -+ CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes"); -+ return -1; -+ } -+ -+ return group_tile_count(args.host_problem_sizes, args.problem_count); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { -+ return BaseKernel::ProblemVisitor::get_workspace_size(args.host_problem_sizes, -+ args.problem_count, -+ args.threadblock_count); -+ } else { -+ return 0; -+ } -+ } -+ -+ /// Computes the grid shape -+ static dim3 get_grid_shape(Arguments const &args) { -+ -+ return dim3(args.threadblock_count, 1, 1); -+ } -+ -+ /// Computes the maximum number of active blocks per multiprocessor -+ static int maximum_active_blocks(int smem_capacity = -1) { -+ -+ CUTLASS_TRACE_HOST("BaseGrouped::maximum_active_blocks()"); -+ -+ int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); -+ -+ CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); -+ -+ cudaError_t result; -+ if (smem_size > (48 << 10)) { -+ result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ // Call cudaGetLastError() to clear the error bit -+ result = cudaGetLastError(); -+ CUTLASS_TRACE_HOST( -+ " cudaFuncSetAttribute() returned error " -+ << cudaGetErrorString(result)); -+ return -1; -+ } -+ } -+ -+ int max_active_blocks = -1; -+ result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( -+ &max_active_blocks, -+ Kernel, -+ BaseKernel::kThreadCount, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ // Call cudaGetLastError() to clear the error bit -+ result = cudaGetLastError(); -+ CUTLASS_TRACE_HOST( -+ " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " -+ << cudaGetErrorString(result)); -+ return -1; -+ } -+ -+ CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); -+ return max_active_blocks; -+ } -+ -+ /// Sorts each pointer passed in according to the indices that sort -+ /// `problem_sizes_ptr` in descending order of problem-K dimension. 
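Before the sorting helpers below, a note on the occupancy logic in maximum_active_blocks() above: it is the standard CUDA pattern of opting the kernel into more than 48 KB of dynamic shared memory with cudaFuncSetAttribute, then asking the runtime how many blocks fit per SM. A minimal standalone sketch of that pattern; `my_kernel`, `kSmemBytes`, and `kThreadsPerBlock` are illustrative placeholders, not CUTLASS symbols:

#include <cuda_runtime.h>

__global__ void my_kernel() {}           // placeholder kernel (illustrative only)
constexpr int kSmemBytes = 64 << 10;     // assumed dynamic shared memory demand
constexpr int kThreadsPerBlock = 256;    // assumed block size

int max_active_blocks_per_sm() {
  if (kSmemBytes > (48 << 10)) {
    // Kernels needing more than 48 KB of dynamic smem must opt in explicitly.
    if (cudaFuncSetAttribute(my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                             kSmemBytes) != cudaSuccess) {
      return -1;
    }
  }
  int max_blocks = 0;
  if (cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks, my_kernel,
                                                    kThreadsPerBlock, kSmemBytes) != cudaSuccess) {
    return -1;
  }
  return max_blocks;
}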
-+ static void sort_problems(int problem_count, -+ cutlass::gemm::GemmCoord* problem_sizes_ptr, -+ int64_t* lda_host_ptr, -+ int64_t* ldb_host_ptr, -+ int64_t* ldc_host_ptr, -+ int64_t* ldd_host_ptr, -+ int64_t* offset_A_ptr, -+ int64_t* offset_B_ptr, -+ int64_t* offset_C_ptr, -+ int64_t* offset_D_ptr) -+ { -+ std::vector indices(problem_count); -+ std::iota(indices.begin(), indices.end(), 0); -+ std::stable_sort(indices.begin(), indices.end(), -+ [&problem_sizes_ptr](size_t i, size_t j) { -+ return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); -+ }); -+ -+ reorder_array(problem_sizes_ptr, indices); -+ reorder_array(lda_host_ptr, indices); -+ reorder_array(ldb_host_ptr, indices); -+ reorder_array(ldc_host_ptr, indices); -+ reorder_array(ldd_host_ptr, indices); -+ reorder_array(offset_A_ptr, indices); -+ reorder_array(offset_B_ptr, indices); -+ reorder_array(offset_C_ptr, indices); -+ reorder_array(offset_D_ptr, indices); -+ } -+ -+ /// Computes the number of threadblocks to launch for the grouped kernel -+ static int sufficient(const cutlass::gemm::GemmCoord* problem_sizes_ptr=nullptr, -+ int problem_count=0, -+ int available_sm_count=-1) { -+ // Determine the number of blocks that would be launched to fill up a single -+ // wave on the GPU with each SM having maximum occupancy. -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ if (result != cudaSuccess) { -+ // Call cudaGetLastError() to clear the error bit -+ result = cudaGetLastError(); -+ CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " -+ << cudaGetErrorString(result)); -+ return 0; -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ if (result != cudaSuccess) { -+ // Call cudaGetLastError() to clear the error bit -+ result = cudaGetLastError(); -+ CUTLASS_TRACE_HOST(" cudaGetDeviceProperties() returned error " -+ << cudaGetErrorString(result)); -+ return 0; -+ } -+ -+ bool override_sm_count = (available_sm_count < 0 || available_sm_count > properties.multiProcessorCount); -+ if (override_sm_count) { -+ available_sm_count = properties.multiProcessorCount; -+ } -+ -+ int max_active_blocks = maximum_active_blocks(); -+ if (max_active_blocks <= 0) { -+ return 0; -+ } -+ -+ int occupancy_based_block_count = available_sm_count * max_active_blocks; -+ -+ if (problem_sizes_ptr == nullptr || problem_count == 0) { -+ return occupancy_based_block_count; -+ } -+ -+ int total_tiles = group_tile_count(problem_sizes_ptr, problem_count); -+ -+ // If the group contains a single problem, launching the exact number of -+ // threadblocks needed to cover the problem minimizes the work performed -+ // per threadblock in finding the next tile to compute. We return total_tiles -+ // unless the user has provided the SM count. -+ if (problem_count == 1 && override_sm_count) { -+ return total_tiles; -+ } -+ -+ // Choose between the full wave of threadblocks and the tile count. If there -+ // are fewer tiles in the group than threadblocks in the full wave, only -+ // some threadblocks will be assigned tiles. Those threadblocks -+ // which are not assigned tiles still need to perform the work of iterating through -+ // problem sizes to determine that they have no work to do. This competes for cycles -+ // with those threadblocks that are assigned tiles to compute. -+ return min(total_tiles, occupancy_based_block_count); -+ } -+ -+ -+ /// Initializes GEMM state from arguments. 
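The sort_problems() routine above can be exercised in isolation: build an index permutation, stable-sort it by descending K, then apply it to every parallel per-problem array. A minimal host-side sketch under simplified types; `ProblemSize` stands in for cutlass::gemm::GemmCoord and only one leading-dimension array is shown:

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

struct ProblemSize { int m, n, k; };   // stand-in for cutlass::gemm::GemmCoord

// Reorders `problems` and a parallel array of leading dimensions so that
// problems with the largest K come first, mirroring sort_problems() above.
void sort_by_descending_k(std::vector<ProblemSize> &problems, std::vector<int64_t> &lda) {
  std::vector<size_t> idx(problems.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::stable_sort(idx.begin(), idx.end(),
                   [&](size_t i, size_t j) { return problems[i].k > problems[j].k; });

  auto apply = [&](auto &v) {            // apply the same permutation to each array
    auto copy = v;
    for (size_t i = 0; i < idx.size(); ++i) v[i] = copy[idx[i]];
  };
  apply(problems);
  apply(lda);
}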
-+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ CUTLASS_TRACE_HOST("BaseGrouped::initialize() - workspace " -+ << workspace << ", stream: " << (stream ? "non-null" : "null")); -+ -+ // Workspace -+ size_t workspace_bytes = get_workspace_size(args); -+ -+ if (workspace_bytes && !workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { -+ int32_t tile_count = group_tile_count(args); -+ Status status = precompute(args, tile_count, workspace); -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ params_ = typename BaseKernel::Params(args, workspace, tile_count); -+ } else { -+ params_ = typename BaseKernel::Params(args, workspace); -+ } -+ -+ // Specify shared memory capacity for kernel. -+ int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); -+ -+ if (smem_size >= (48 << 10)) { -+ cudaError_t result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ size_t workspace_bytes = get_workspace_size(args); -+ -+ if (workspace_bytes && !workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { -+ int32_t tile_count = group_tile_count(args); -+ Status status = precompute(args, tile_count, workspace); -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ params_.update(args, workspace, tile_count); -+ } else { -+ params_.update(args, workspace); -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ // -+ // Configure grid and block dimensions -+ // -+ -+ if (!params_.problem_visitor.problem_count) { -+ return Status::kSuccess; -+ } -+ -+ dim3 grid(params_.threadblock_count, 1, 1); -+ dim3 block(BaseKernel::kThreadCount, 1, 1); -+ -+ int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); -+ -+ // -+ // Launch kernel -+ // -+ -+ // Launch -+ cutlass::Kernel<<>>(params_); -+ -+ // -+ // Query for errors -+ // -+ cudaError_t result = cudaGetLastError(); -+ -+ if (result != cudaSuccess) { -+ // Call cudaGetLastError() to clear the error bit -+ result = cudaGetLastError(); -+ CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); -+ return Status::kErrorInternal; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Initializes and runs the kernel. 
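The initialize()/run() pair above follows the usual CUTLASS device-operator lifecycle: query the workspace size, allocate it, initialize, then launch. A hedged outline of that sequence; `GroupedOp` is any type exposing the BaseGrouped interface and `args` is assumed to be an already-populated Arguments object, neither is defined here:

#include <cuda_runtime.h>
#include "cutlass/cutlass.h"

template <typename GroupedOp>
cutlass::Status run_grouped(typename GroupedOp::Arguments const &args,
                            cudaStream_t stream = nullptr) {
  GroupedOp op;

  // Workspace holds the precomputed schedule (if the visitor requires it).
  void *workspace = nullptr;
  size_t workspace_bytes = GroupedOp::get_workspace_size(args);
  if (workspace_bytes && cudaMalloc(&workspace, workspace_bytes) != cudaSuccess) {
    return cutlass::Status::kErrorInternal;
  }

  cutlass::Status status = op.initialize(args, workspace, stream);
  if (status == cutlass::Status::kSuccess) {
    status = op.run(stream);
  }

  cudaStreamSynchronize(stream);   // ensure the kernel finished before releasing workspace
  cudaFree(workspace);
  return status;
}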
-+ Status operator()( -+ Arguments const &args, -+ void *workspace, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/default_gemm_configuration.h b/3rdparty/cutlass/include/cutlass/gemm/device/default_gemm_configuration.h -new file mode 100644 -index 0000000..46ef274 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/default_gemm_configuration.h -@@ -0,0 +1,818 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Definitions for GEMM structures -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/mma.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_clamp.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename OperatorClass, -+ typename ArchTag, -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename ElementAccumulator -+> -+struct DefaultGemmConfiguration; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ArchTag, -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename ElementAccumulator> -+struct DefaultGemmConfiguration< -+ arch::OpClassSimt, -+ ArchTag, -+ ElementA, -+ ElementB, -+ ElementC, -+ ElementAccumulator> { -+ -+ static int const kAlignmentA = 1; -+ static int const kAlignmentB = 1; -+ using ThreadblockShape = GemmShape<128, 128, 8>; -+ using WarpShape = GemmShape<32, 64, 8>; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >; -+ -+ using Operator = arch::OpMultiplyAdd; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ArchTag, -+ typename ElementC> -+struct DefaultGemmConfiguration { -+ -+ static int const kAlignmentA = 4; -+ static int const kAlignmentB = 4; -+ using ThreadblockShape = GemmShape<128, 128, 32>; -+ using WarpShape = GemmShape<32, 64, 32>; -+ using InstructionShape = GemmShape<1, 1, 4>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 1, -+ int32_t, -+ float -+ >; -+ -+ using Operator = arch::OpMultiplyAdd; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ArchTag, -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename ElementAccumulator> -+struct DefaultGemmConfiguration< -+ arch::OpClassWmmaTensorOp, -+ ArchTag, -+ ElementA, -+ ElementB, -+ ElementC, -+ ElementAccumulator> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >; -+ -+ using Operator = arch::OpMultiplyAdd; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename ElementAccumulator> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm70, -+ ElementA, -+ ElementB, -+ ElementC, -+ ElementAccumulator> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 32>; -+ using WarpShape = GemmShape<64, 64, 32>; -+ using InstructionShape = 
GemmShape<8, 8, 4>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >; -+ -+ using Operator = arch::OpMultiplyAdd; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename ElementAccumulator> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ ElementA, -+ ElementB, -+ ElementC, -+ ElementAccumulator> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ using ThreadblockShape = GemmShape<128, 256, 32>; -+ using WarpShape = GemmShape<64, 64, 32>; -+ using InstructionShape = GemmShape<16, 8, 8>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >; -+ -+ using Operator = typename platform::conditional< -+ (platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value), -+ arch::OpMultiplyAddSaturate, arch::OpMultiplyAdd>::type; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ int8_t, -+ int8_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 64>; -+ using WarpShape = GemmShape<64, 64, 64>; -+ using InstructionShape = GemmShape<8, 8, 16>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ int8_t, -+ uint8_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 64>; -+ using WarpShape = GemmShape<64, 64, 64>; -+ using InstructionShape = GemmShape<8, 8, 16>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ uint8_t, -+ int8_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 64>; -+ using WarpShape = GemmShape<64, 64, 64>; -+ using InstructionShape = GemmShape<8, 8, 16>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ 
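These DefaultGemmConfiguration specializations are pure compile-time trait bundles; device-level operators read their nested types and constants as template defaults. A small sketch of inspecting one of them, using the Sm75 tensor-op, half-precision parameters listed above (two-stage mainloop, 128-bit operand accesses); the alias names are illustrative:

#include "cutlass/arch/arch.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"

// Defaults selected for FP16 tensor-op GEMM tuned for SM75.
using Config = cutlass::gemm::device::DefaultGemmConfiguration<
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
    cutlass::half_t, cutlass::half_t, cutlass::half_t, float>;

static_assert(Config::kStages == 2, "Sm75 defaults use a two-stage mainloop");
static_assert(Config::kAlignmentA == 128 / cutlass::sizeof_bits<cutlass::half_t>::value,
              "A operand is accessed in 128-bit (8 x half) vectors");

using DefaultThreadblockShape = Config::ThreadblockShape;   // GemmShape<128, 256, 32> per the specialization above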
-+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ uint8_t, -+ uint8_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 64>; -+ using WarpShape = GemmShape<64, 64, 64>; -+ using InstructionShape = GemmShape<8, 8, 16>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ int4b_t, -+ int4b_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 128>; -+ using WarpShape = GemmShape<64, 64, 128>; -+ using InstructionShape = GemmShape<8, 8, 32>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ int4b_t, -+ uint4b_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 128>; -+ using WarpShape = GemmShape<64, 64, 128>; -+ using InstructionShape = GemmShape<8, 8, 32>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ uint4b_t, -+ int4b_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 128>; -+ using WarpShape = GemmShape<64, 64, 128>; -+ using InstructionShape = GemmShape<8, 8, 32>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ uint4b_t, -+ uint4b_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 128>; -+ using WarpShape = GemmShape<64, 64, 128>; -+ using InstructionShape = GemmShape<8, 8, 32>; -+ static int const kStages = 2; -+ -+ using 
EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ uint1b_t, -+ uint1b_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 512>; -+ using WarpShape = GemmShape<64, 64, 512>; -+ using InstructionShape = GemmShape<8, 8, 128>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpXorPopc; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct DefaultGemmConfiguration { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 64>; -+ using WarpShape = GemmShape<64, 64, 64>; -+ using InstructionShape = GemmShape<16, 8, 16>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombination< -+ ElementC, 128 / sizeof_bits::value, ElementAccumulator, -+ ElementAccumulator>; -+ -+ using Operator = typename platform::conditional< -+ (platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value), -+ arch::OpMultiplyAddSaturate, arch::OpMultiplyAdd>::type; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+template -+struct DefaultGemmConfiguration { -+ -+ static int const kAlignmentA = 1; -+ static int const kAlignmentB = 1; -+ -+ using ThreadblockShape = GemmShape<128, 256, 64>; -+ using WarpShape = GemmShape<64, 64, 64>; -+ using InstructionShape = GemmShape<16, 8, 16>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombination< -+ ElementC, 128 / sizeof_bits::value, ElementAccumulator, -+ ElementAccumulator>; -+ -+ using Operator = arch::OpMultiplyAdd; -+}; -+ -+ -+template <> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ complex, -+ complex, -+ complex, -+ complex -+ > { -+ -+ static int const kAlignmentA = 1; -+ static int const kAlignmentB = 1; -+ -+ using ThreadblockShape = GemmShape<64, 64, 16>; -+ using WarpShape = GemmShape<32, 32, 16>; -+ using InstructionShape = GemmShape<8, 8, 4>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombination< -+ complex, 1, complex, -+ complex>; -+ -+ using Operator = arch::OpMultiplyAddComplex; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ int8_t, -+ int8_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 64>; -+ using WarpShape = GemmShape<64, 64, 64>; -+ using InstructionShape = GemmShape<16, 8, 32>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / 
sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ int8_t, -+ uint8_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 64>; -+ using WarpShape = GemmShape<64, 64, 64>; -+ using InstructionShape = GemmShape<16, 8, 32>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ uint8_t, -+ int8_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 64>; -+ using WarpShape = GemmShape<64, 64, 64>; -+ using InstructionShape = GemmShape<16, 8, 32>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ uint8_t, -+ uint8_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 64>; -+ using WarpShape = GemmShape<64, 64, 64>; -+ using InstructionShape = GemmShape<16, 8, 32>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ int4b_t, -+ int4b_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 128>; -+ using WarpShape = GemmShape<64, 64, 128>; -+ using InstructionShape = GemmShape<16, 8, 64>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ int4b_t, -+ uint4b_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 128>; -+ using WarpShape = GemmShape<64, 64, 128>; -+ using 
InstructionShape = GemmShape<16, 8, 64>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ uint4b_t, -+ int4b_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 128>; -+ using WarpShape = GemmShape<64, 64, 128>; -+ using InstructionShape = GemmShape<16, 8, 64>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ uint4b_t, -+ uint4b_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 128>; -+ using WarpShape = GemmShape<64, 64, 128>; -+ using InstructionShape = GemmShape<16, 8, 64>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ uint1b_t, -+ uint1b_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 512>; -+ using WarpShape = GemmShape<64, 64, 512>; -+ using InstructionShape = GemmShape<16, 8, 256>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAdd; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct DefaultGemmConfiguration { -+ -+ static int const kAlignmentA = 1; -+ static int const kAlignmentB = 1; -+ -+ using ThreadblockShape = GemmShape<128, 256, 64>; -+ using WarpShape = GemmShape<64, 64, 64>; -+ using InstructionShape = GemmShape<16, 8, 4>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombination< -+ ElementC, 128 / sizeof_bits::value, ElementAccumulator, -+ ElementAccumulator>; -+ -+ using Operator = arch::OpMultiplyAdd; -+}; -+ -+template <> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm90, -+ complex, -+ complex, -+ complex, -+ complex -+ > { -+ -+ static int const kAlignmentA = 1; -+ static int const kAlignmentB = 1; -+ -+ using ThreadblockShape = GemmShape<64, 64, 16>; -+ using WarpShape = GemmShape<32, 32, 16>; -+ using InstructionShape = GemmShape<16, 8, 4>; -+ static 
int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombination< -+ complex, 1, complex, -+ complex>; -+ -+ using Operator = arch::OpMultiplyAddComplex; -+}; -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/ell_gemm.h b/3rdparty/cutlass/include/cutlass/gemm/device/ell_gemm.h -new file mode 100644 -index 0000000..d8698a7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/ell_gemm.h -@@ -0,0 +1,848 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a Block-Ell sparse gemm kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/ell_gemm.h" -+ -+#include "cutlass/gemm/kernel/default_ell_gemm.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*! Blocked-Ell sparse gemm device-level operator. This is an interface to efficient CUTLASS -+ Blocked-Ell kernels that may be invoked from host code. -+ -+ The contributions of this class are: -+ -+ 1. At compile time, it maps data types and high-level structural parameters onto -+ specific CUTLASS components. -+ -+ 2. 
At runtime, it maps logical arguments to Blocked-Ell problems to kernel parameters. -+ -+ 3. At runtime, it launches kernels on the device. -+ -+ Example of a CUTLASS EllGemm operator is as follows: -+ -+ // -+ // Instantiate the CUTLASS EllGemm operator. -+ // -+ -+ cutlass::gemm::device::EllGemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ cutlass::half_t, 128 / cutlass::sizeof_bits::value, -+ float, float>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 4, // Stages -+ 128 / cutlass::sizeof_bits::value, // Alignment A -+ 128 / cutlass::sizeof_bits::value // Alignment B -+ > ellgemm_op; -+ -+ // -+ // Launch the EllGemm operation on the device -+ // -+ -+ Description of parameters and tensors used to represent the Blocked-Ellpack (ELL) format: -+ a_rows - Rows in the sparse matrix. -+ a_cols - Colums in the sparse matrix. -+ BlockedEllA - Packed matrix (ellValue matrix) that stores non-zero values in -+ consecutive blocks, whose size is (a_rows * a_ell_num_columns) -+ ell_idx - Blocked-ELL Column indices (ellColInd) matrix, whose size is -+ (a_rows / a_ell_blocksize) * (a_ell_num_columns / a_ell_blocksize) -+ a_ell_blocksize - Size of the ELL-Blocks. -+ a_ell_num_columns - Number of columns in the Blocked-Ellpack format (ellValue columns) -+ B - Input dense matrix whose size is (a_cols * n) -+ C/D - Output dense matrix whose size is (a_rows * n) -+ -+ cutlass::Status status = ellgemm_op({ -+ {a_rows, n, a_cols}, // GemmCoord problem_size -+ {BlockedEllA, lda}, // TensorRef ref_BlockedEllA -+ {B, ldb}, // TensorRef ref_B, -+ {C, ldc}, // TensorRef ref_C, -+ {D, ldd}, // TensorRef ref_D, -+ ell_idx, // Blocked-ELL Column indices or ellColInd matrix (const int*) -+ a_ell_num_columns, // Columns in the Blocked-Ellpack (ellValue) matrix (int) -+ a_ell_blocksize, // Size of the ELL-Blocks (int) -+ a_ell_base, // Base index of ellColInd (int) - Zero or One -+ {alpha, beta} // EpilogueOutputOp::Params epilogue_op_params -+ }); -+ -+ A simplified view of the template is listed below. -+ -+ template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ -+ /// Element type for B matrix operand -+ typename ElementB, -+ -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ -+ /// Operator class tag -+ typename OperatorClass, -+ -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. 
-+ typename ArchTag, -+ -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ -+ /// Number of stages used in the pipelined mainloop -+ int Stages -+ -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ -+ /// Supports split-K with serial reduction -+ bool SplitKSerial, -+ -+ /// Operation performed by GEMM -+ typename Operator, -+ -+ /// Sparse matrix is A or not -+ bool IsASparse -+ > -+ class EllGemm; -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassTensorOp, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_ = arch::Sm80, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = -+ typename threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial = false, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator, -+ /// Sparse matrix is A or not -+ bool IsASparse = true -+ > -+class EllGemm { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using TensorRefC = TensorRef; -+ using TensorRefD = 
TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ static bool const kIsASparse = IsASparse; -+ -+ /// Define the kernel -+ using GemmKernel = typename kernel::DefaultEllGemm< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kSplitKSerial, -+ Operator, -+ kIsASparse -+ >::GemmKernel; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A; -+ TensorRef ref_B; -+ TensorRef ref_C; -+ TensorRef ref_D; -+ const int* ell_idx; -+ int ell_ncol; -+ int ell_blocksize; -+ int ell_base_idx; -+ typename EpilogueOutputOp::Params epilogue; -+ int split_k_slices; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments(): problem_size(0, 0, 0), split_k_slices(1) { -+ -+ } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A_, -+ TensorRef ref_B_, -+ TensorRef ref_C_, -+ TensorRef ref_D_, -+ const int* ell_idx_, -+ int ell_ncol_, -+ int ell_blocksize_, -+ int ell_base_idx_, -+ typename EpilogueOutputOp::Params epilogue_ = -+ typename EpilogueOutputOp::Params(), -+ int split_k_slices = 1 -+ ): -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ ref_B(ref_B_), -+ ref_C(ref_C_), -+ ref_D(ref_D_), -+ ell_idx(ell_idx_), -+ ell_ncol(ell_ncol_), -+ ell_blocksize(ell_blocksize_), -+ ell_base_idx(ell_base_idx_), -+ epilogue(epilogue_), -+ split_k_slices(split_k_slices) { -+ -+ } -+ }; -+ -+private: -+ -+ /// Kernel parameters object -+ typename GemmKernel::Params params_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ EllGemm() { } -+ -+ /// Determines whether the GEMM can execute the given problem. 
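The Arguments structure above carries the Blocked-Ellpack metadata (ell_idx, ell_ncol, ell_blocksize) described in the class comment. For reference, a host-side sketch of the storage sizes implied by that description, assuming the row count and ell_ncol are multiples of the block size; the struct and function names are illustrative only:

#include <cstddef>

// ellValue packs a_rows x a_ell_num_columns non-zero values; ellColInd holds
// one block-column index per ELL block.
struct EllExtents {
  size_t value_elements;   // elements in the packed ellValue matrix
  size_t colind_entries;   // entries in the ellColInd index matrix
};

EllExtents ell_extents(int a_rows, int a_ell_num_columns, int a_ell_blocksize) {
  EllExtents e;
  e.value_elements = size_t(a_rows) * size_t(a_ell_num_columns);
  e.colind_entries = size_t(a_rows / a_ell_blocksize) *
                     size_t(a_ell_num_columns / a_ell_blocksize);
  return e;
}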
-+ static Status can_implement(Arguments const &args) { -+ -+ if (!kSplitKSerial && args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = GemmKernel::can_implement( -+ args.problem_size, -+ args.ref_A.non_const_ref(), -+ args.ref_B.non_const_ref(), -+ args.ref_C.non_const_ref(), -+ args.ref_D -+ ); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {args.ell_blocksize, -+ ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ tiled_shape.m() *= (args.ell_blocksize + ThreadblockShape::kM - 1 ) / ThreadblockShape::kM; -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ -+ bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); -+ } -+ -+ return bytes; -+ } -+ -+ Status set(Arguments const &args, cutlass::gemm::GemmCoord const &grid_shape, void *workspace){ -+ // Initialize the Params structure -+ params_ = typename GemmKernel::Params{ -+ args.problem_size, -+ grid_shape, -+ args.ref_A.non_const_ref(), -+ args.ref_B.non_const_ref(), -+ args.ref_C.non_const_ref(), -+ args.ref_D, -+ args.ell_idx, -+ args.ell_ncol, -+ args.ell_blocksize, -+ args.ell_base_idx, -+ args.epilogue, -+ static_cast(workspace) -+ }; -+ return Status::kSuccess; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {args.ell_blocksize, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ grid_shape.m() *= (args.ell_blocksize + ThreadblockShape::kM - 1 ) / ThreadblockShape::kM; -+ -+ if (kSplitKSerial) { -+ if (args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ size_t bytes = get_workspace_size(args); -+ -+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ else { -+ -+ if (args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ return set(args, grid_shape, workspace); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ } -+ -+ params_.ref_A.reset(args.ref_A.non_const_ref().data()); -+ params_.ref_B.reset(args.ref_B.non_const_ref().data()); -+ params_.ref_C.reset(args.ref_C.non_const_ref().data()); -+ params_.ref_D.reset(args.ref_D.data()); -+ params_.output_op = args.epilogue; -+ params_.semaphore = static_cast(workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. 
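The workspace handling in get_workspace_size()/initialize() above reserves one int semaphore per output tile when serial split-K is enabled and zero-fills it before launch. A minimal sketch of that bookkeeping in isolation; tile counts are passed in directly here rather than derived from the threadblock swizzle, and the function name is illustrative:

#include <cstddef>
#include <cuda_runtime.h>

cudaError_t clear_splitk_semaphores(void *workspace, int tiled_m, int tiled_n,
                                    int split_k_slices, cudaStream_t stream) {
  if (split_k_slices <= 1) {
    return cudaSuccess;                 // no serial reduction, no semaphores needed
  }
  if (workspace == nullptr) {
    return cudaErrorInvalidValue;       // corresponds to Status::kErrorWorkspaceNull above
  }
  // One int per (m, n) output tile, zeroed so the first slice proceeds immediately.
  size_t bytes = sizeof(int) * size_t(tiled_m) * size_t(tiled_n);
  return cudaMemsetAsync(workspace, 0, bytes, stream);
}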
-+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(GemmKernel::kThreadCount, 1, 1); -+ -+ cudaError_t result; -+ -+ int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); -+ -+ if (smem_size >= (48 << 10)) { -+ result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ cutlass::Kernel<<>>(params_); -+ -+ result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for column-major output exchanges problem size and operand. -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ /// If true, kernel supports split-K as a serial reduction -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator_, -+ /// Sparse matrix is A or not -+ bool IsASparse> -+class EllGemm { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static ComplexTransform 
const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ static bool const kSplitKSerial = SplitKSerial; -+ static bool const kIsASparse = false; -+ -+ using UnderlyingOperator = EllGemm< -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ kAlignmentB, -+ kAlignmentA, -+ SplitKSerial, -+ Operator, -+ kIsASparse -+ >; -+ -+ using UnderlyingArguments = typename UnderlyingOperator::Arguments; -+ using GemmKernel = typename UnderlyingOperator::GemmKernel; -+ static int const kAlignmentC = UnderlyingOperator::kAlignmentC; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A; -+ TensorRef ref_B; -+ TensorRef ref_C; -+ TensorRef ref_D; -+ const int* ell_idx; -+ int ell_ncol; -+ int ell_blocksize; -+ int ell_base_idx; -+ typename EpilogueOutputOp::Params epilogue; -+ int split_k_slices; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A_, -+ TensorRef ref_B_, -+ TensorRef ref_C_, -+ TensorRef ref_D_, -+ const int* ell_idx_, -+ int ell_ncol_, -+ int ell_blocksize_, -+ int ell_base_idx_, -+ typename EpilogueOutputOp::Params epilogue_ = -+ typename EpilogueOutputOp::Params(), -+ int split_k_slices = 1 -+ ): -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ ref_B(ref_B_), -+ ref_C(ref_C_), -+ ref_D(ref_D_), -+ ell_idx(ell_idx_), -+ ell_ncol(ell_ncol_), -+ ell_blocksize(ell_blocksize_), -+ ell_base_idx(ell_base_idx_), -+ epilogue(epilogue_), -+ split_k_slices(split_k_slices) { } -+ }; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ EllGemm() { } -+ -+ /// Helper to construct a transposed equivalent for the underying GEMM operator -+ static UnderlyingArguments to_underlying_arguments(Arguments const &args) { -+ return UnderlyingArguments( -+ {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, -+ {args.ref_B.data(), args.ref_B.stride(0)}, -+ {args.ref_A.data(), args.ref_A.stride(0)}, -+ {args.ref_C.data(), args.ref_C.stride(0)}, -+ {args.ref_D.data(), args.ref_D.stride(0)}, -+ args.ell_idx, -+ args.ell_ncol, -+ args.ell_blocksize, -+ args.ell_base_idx, -+ args.epilogue, -+ args.split_k_slices -+ ); -+ } -+ -+ /// Determines whether the GEMM can execute the given problem. 
-+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, args.ell_blocksize, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ tiled_shape.n() *= (args.ell_blocksize + ThreadblockShape::kN - 1 ) / ThreadblockShape::kN; -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ -+ bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); -+ } -+ -+ return bytes; -+ } -+ -+ Status set(Arguments const &args, cutlass::gemm::GemmCoord const &grid_shape, void *workspace){ -+ // Initialize the Params structure -+ return underlying_operator_.set(to_underlying_arguments(args), grid_shape, workspace); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, -+ {ThreadblockShape::kM, args.ell_blocksize, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ grid_shape.n() *= (args.ell_blocksize + ThreadblockShape::kN - 1 ) / ThreadblockShape::kN; -+ -+ if (kSplitKSerial) { -+ if (args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ size_t bytes = get_workspace_size(args); -+ -+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ else { -+ -+ if (args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ // Initialize the Params structure -+ set(args, grid_shape, workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. 
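The column-major-output specialization above leans on the identity C = A * B if and only if C^T = B^T * A^T: a column-major C is produced by a row-major underlying kernel that sees the operands swapped and the M/N extents exchanged, exactly what to_underlying_arguments() does. A tiny host-side sketch of that remapping under simplified stand-in types (plain structs replace GemmCoord/TensorRef; all names here are illustrative):

struct Coord3 { int m, n, k; };
struct Ref    { void const *ptr; long long stride; };

struct MappedArgs { Coord3 problem; Ref a, b; };

// Swap M and N, and feed B where the underlying row-major kernel expects A
// (and vice versa), mirroring to_underlying_arguments() above.
MappedArgs to_row_major_underlying(Coord3 problem, Ref ref_A, Ref ref_B) {
  return MappedArgs{ Coord3{problem.n, problem.m, problem.k}, ref_B, ref_A };
}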
-+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm.h -new file mode 100644 -index 0000000..68fa29b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm.h -@@ -0,0 +1,771 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/gemm.h" -+ -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+#include "cutlass/layout/permute.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*! Gemm device-level operator. 
This is an interface to efficient CUTLASS GEMM kernels that may -+ be invoked from host code. -+ -+ The contributions of this class are: -+ -+ 1. At compile time, it maps data types and high-level structural parameters onto -+ specific CUTLASS components. -+ -+ 2. At runtime, it maps logical arguments to GEMM problems to kernel parameters. -+ -+ 3. At runtime, it launches kernels on the device. -+ -+ The intent is to provide a convenient mechanism for interacting with most plausible GEMM -+ configurations for each supported architecture. Consequently, not all parameters are exposed -+ to the top-level interface. Rather, sensible defaults at each level of the CUTLASS hierarchy -+ are selected to tradeoff simplicity of the interface with flexibility. We expect -+ most configurations to be specified at this level. Applications with more exotic requirements -+ may construct their kernels of interest using CUTLASS components at the threadblock, warp, -+ and thread levels of abstraction. -+ -+ CUTLASS exposes computations using the functor design pattern in which objects compose some -+ internal state with an overloaded function call operator. This enables decoupling of -+ initialization from execution, possibly reducing overhead during steady state phases of -+ application execution. -+ -+ CUTLASS device-level operators expose an Arguments structure encompassing each logical -+ input to the computation. This is distinct from the kernel-level Params structure pattern -+ which contains application-specific precomputed state needed by the device code. -+ -+ Example of a CUTLASS GEMM operator implementing the functionality of cuBLAS's SGEMM NN -+ is as follows: -+ -+ // -+ // Instantiate the CUTLASS GEMM operator. -+ // -+ -+ cutlass::gemm::device::Gemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ > gemm_op; -+ -+ // -+ // Launch the GEMM operation on the device -+ // -+ -+ cutlass::Status status = gemm_op({ -+ {m, n, k}, // GemmCoord problem_size, -+ {A, lda}, // TensorRef ref_A, -+ {B, ldb}, // TensorRef ref_B, -+ {C, ldc}, // TensorRef ref_C, -+ {D, ldd}, // TensorRef ref_D, -+ {alpha, beta} // EpilogueOutputOp::Params epilogue_op_params -+ }); -+ -+ -+ A simplified view of the template is listed below. -+ -+ template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ -+ /// Element type for B matrix operand -+ typename ElementB, -+ -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ -+ /// Operator class tag -+ typename OperatorClass, -+ -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. 
-+ typename ArchTag, -+ -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ -+ /// Number of stages used in the pipelined mainloop -+ int Stages -+ > -+ class Gemm; -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassSimt, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_ = arch::Sm70, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = -+ typename threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial = false, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator, -+ /// Gather operand A by using an index array -+ bool GatherA = false, -+ /// Gather operand B by using an index array -+ bool GatherB = false, -+ /// Scatter result D by using an index array -+ bool ScatterD = false, -+ /// Permute result D -+ typename PermuteDLayout = layout::NoPermute> -+class Gemm { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ 
using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+ /// Define the kernel -+ using GemmKernel = typename kernel::DefaultGemm< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kSplitKSerial, -+ Operator, -+ SharedMemoryClearOption::kNone, -+ GatherA, -+ GatherB, -+ ScatterD, -+ PermuteDLayout -+ >::GemmKernel; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A; -+ TensorRef ref_B; -+ TensorRef ref_C; -+ TensorRef ref_D; -+ typename EpilogueOutputOp::Params epilogue; -+ int split_k_slices; -+ // For gather+scatter operations -+ int const *gather_A_indices; -+ int const *gather_B_indices; -+ int const *scatter_D_indices; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments(): problem_size(0, 0, 0), split_k_slices(1) { -+ -+ } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A_, -+ TensorRef ref_B_, -+ TensorRef ref_C_, -+ TensorRef ref_D_, -+ typename EpilogueOutputOp::Params epilogue_ = -+ typename EpilogueOutputOp::Params(), -+ int split_k_slices = 1, -+ int const *gather_A_indices_ = nullptr, -+ int const *gather_B_indices_ = nullptr, -+ int const *scatter_D_indices_ = nullptr -+ ): -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ ref_B(ref_B_), -+ ref_C(ref_C_), -+ ref_D(ref_D_), -+ epilogue(epilogue_), -+ split_k_slices(split_k_slices), -+ gather_A_indices(gather_A_indices_), -+ gather_B_indices(gather_B_indices_), -+ scatter_D_indices(scatter_D_indices_) { -+ -+ } -+ }; -+ -+private: -+ -+ /// Kernel parameters object -+ typename GemmKernel::Params params_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ Gemm() { } -+ -+ /// Determines whether the GEMM can execute the given problem. 
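Taken together, the Arguments structure and the initialize/run split above give the host-side calling sequence for this operator. A minimal sketch, assuming scalars m, n, k, alpha, beta and device pointers A, B, C, D with leading dimensions lda, ldb, ldc, ldd (none of these names come from the patch); error handling is elided:

    using GemmOp = cutlass::gemm::device::Gemm<
        float, cutlass::layout::ColumnMajor,
        float, cutlass::layout::ColumnMajor,
        float, cutlass::layout::ColumnMajor>;

    GemmOp gemm_op;
    GemmOp::Arguments args({m, n, k},
                           {A, lda}, {B, ldb}, {C, ldc}, {D, ldd},
                           {alpha, beta},
                           /*split_k_slices=*/1);

    size_t workspace_bytes = GemmOp::get_workspace_size(args);   // nonzero only for serial split-K
    void *workspace = nullptr;
    if (workspace_bytes) { cudaMalloc(&workspace, workspace_bytes); }

    if (GemmOp::can_implement(args) == cutlass::Status::kSuccess &&
        gemm_op.initialize(args, workspace) == cutlass::Status::kSuccess) {
      gemm_op.run();
    }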
-+ static Status can_implement(Arguments const &args) { -+ -+ if (!kSplitKSerial && args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = GemmKernel::can_implement( -+ args.problem_size, -+ args.ref_A.non_const_ref(), -+ args.ref_B.non_const_ref(), -+ args.ref_C.non_const_ref(), -+ args.ref_D -+ ); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ -+ bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); -+ } -+ -+ return bytes; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ if (kSplitKSerial) { -+ if (args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ size_t bytes = get_workspace_size(args); -+ -+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ else { -+ -+ if (args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ // Initialize the Params structure -+ params_ = typename GemmKernel::Params{ -+ args.problem_size, -+ grid_shape, -+ args.ref_A.non_const_ref(), -+ args.ref_B.non_const_ref(), -+ args.ref_C.non_const_ref(), -+ args.ref_D, -+ args.epilogue, -+ static_cast(workspace), -+ args.gather_A_indices, -+ args.gather_B_indices, -+ args.scatter_D_indices -+ }; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ } -+ -+ params_.ref_A.reset(args.ref_A.non_const_ref().data()); -+ params_.ref_B.reset(args.ref_B.non_const_ref().data()); -+ params_.ref_C.reset(args.ref_C.non_const_ref().data()); -+ params_.ref_D.reset(args.ref_D.data()); -+ params_.output_op = args.epilogue; -+ params_.semaphore = static_cast(workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(GemmKernel::kThreadCount, 1, 1); -+ -+ cudaError_t result; -+ -+ int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); -+ -+ if (smem_size >= (48 << 10)) { -+ result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ cutlass::Kernel<<>>(params_); -+ -+ result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? 
Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for column-major output exchanges problem size and operand. -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ /// If true, kernel supports split-K as a serial reduction -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator_, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB, -+ /// Scatter result D by using an index array -+ bool ScatterD, -+ /// Permute result D -+ typename PermuteDLayout -+> -+class Gemm { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ static bool const kSplitKSerial = SplitKSerial; -+ -+ using UnderlyingOperator = Gemm< -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, 
-+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ kAlignmentB, -+ kAlignmentA, -+ SplitKSerial, -+ Operator, -+ GatherB, -+ GatherA, -+ ScatterD, -+ PermuteDLayout -+ >; -+ -+ using UnderlyingArguments = typename UnderlyingOperator::Arguments; -+ using GemmKernel = typename UnderlyingOperator::GemmKernel; -+ static int const kAlignmentC = UnderlyingOperator::kAlignmentC; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A; -+ TensorRef ref_B; -+ TensorRef ref_C; -+ TensorRef ref_D; -+ typename EpilogueOutputOp::Params epilogue; -+ int split_k_slices; -+ // For gather+scatter operations -+ int *gather_A_indices; -+ int *gather_B_indices; -+ int *scatter_D_indices; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A_, -+ TensorRef ref_B_, -+ TensorRef ref_C_, -+ TensorRef ref_D_, -+ typename EpilogueOutputOp::Params epilogue_ = -+ typename EpilogueOutputOp::Params(), -+ int split_k_slices = 1, -+ int *gather_A_indices_ = nullptr, -+ int *gather_B_indices_ = nullptr, -+ int *scatter_D_indices_ = nullptr -+ ): -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ ref_B(ref_B_), -+ ref_C(ref_C_), -+ ref_D(ref_D_), -+ epilogue(epilogue_), -+ split_k_slices(split_k_slices), -+ gather_A_indices(gather_A_indices_), -+ gather_B_indices(gather_B_indices_), -+ scatter_D_indices(scatter_D_indices_) { } -+ }; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ Gemm() { } -+ -+ /// Helper to construct a transposed equivalent for the underying GEMM operator -+ static UnderlyingArguments to_underlying_arguments(Arguments const &args) { -+ return UnderlyingArguments( -+ {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, -+ {args.ref_B.data(), args.ref_B.stride(0)}, -+ {args.ref_A.data(), args.ref_A.stride(0)}, -+ {args.ref_C.data(), args.ref_C.stride(0)}, -+ {args.ref_D.data(), args.ref_D.stride(0)}, -+ args.epilogue, -+ args.split_k_slices, -+ args.gather_B_indices, -+ args.gather_A_indices, -+ args.scatter_D_indices -+ ); -+ } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. 
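The to_underlying_arguments mapping above is the standard transposition trick: an m-by-n product D = alpha * A * B + beta * C written to column-major D occupies the same memory as the n-by-m product D^T = alpha * B^T * A^T + beta * C^T written to row-major D^T, so the specialization swaps m and n, exchanges the A and B operands (and their gather indices), and transposes the layouts before delegating to the row-major operator. A small sketch under assumed host-side names showing that the exchange stays internal to the class:

    // Hypothetical sizes and pointers; a column-major-output instantiation is
    // used exactly like the row-major one, the operand exchange happens inside.
    using GemmNT = cutlass::gemm::device::Gemm<
        float, cutlass::layout::RowMajor,       // A
        float, cutlass::layout::ColumnMajor,    // B
        float, cutlass::layout::ColumnMajor>;   // C/D column-major selects this specialization
    GemmNT gemm_nt;
    cutlass::Status status = gemm_nt({{m, n, k}, {A, lda}, {B, ldb}, {C, ldc}, {D, ldd}, {alpha, beta}});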
-+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_array.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_array.h -new file mode 100644 -index 0000000..dd244f8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_array.h -@@ -0,0 +1,737 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/gemm_array.h" -+ -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/*! Gemm device-level operator. This is an interface to efficient CUTLASS GEMM kernels that may -+ be invoked from host code. 
-+ -+ The contributions of this class are: -+ -+ 1. At compile time, it maps data types and high-level structural parameters onto -+ specific CUTLASS components. -+ -+ 2. At runtime, it maps logical arguments to GEMM problems to kernel parameters. -+ -+ 3. At runtime, it launches kernels on the device. -+ -+ The intent is to provide a convenient mechanism for interacting with most plausible GEMM -+ configurations for each supported architecture. Consequently, not all parameters are exposed -+ to the top-level interface. Rather, sensible defaults at each level of the CUTLASS hierarchy -+ are selected to tradeoff simplicity of the interface with flexibility. We expect -+ most configurations to be specified at this level. Applications with more exotic requirements -+ may construct their kernels of interest using CUTLASS components at the threadblock, warp, -+ and thread levels of abstraction. -+ -+ CUTLASS exposes computations using the functor design pattern in which objects compose some -+ internal state with an overloaded function call operator. This enables decoupling of -+ initialization from execution, possibly reducing overhead during steady state phases of -+ application execution. -+ -+ CUTLASS device-level operators expose an Arguments structure encompassing each logical -+ input to the computation. This is distinct from the kernel-level Params structure pattern -+ which contains application-specific precomputed state needed by the device code. -+ -+ Example of a CUTLASS GEMM operator implementing the functionality of cuBLAS's SGEMM NN -+ is as follows: -+ -+ // -+ // Instantiate the CUTLASS GEMM operator. -+ // -+ -+ cutlass::gemm::device::Gemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ > gemm_op; -+ -+ // -+ // Launch the GEMM operation on the device -+ // -+ -+ cutlass::Status status = gemm_op({ -+ {m, n, k}, // GemmCoord problem_size, -+ {A, lda}, // TensorRef ref_A, -+ {B, ldb}, // TensorRef ref_B, -+ {C, ldc}, // TensorRef ref_C, -+ {D, ldd}, // TensorRef ref_D, -+ {alpha, beta} // EpilogueOutputOp::Params epilogue_op_params -+ }); -+ -+ -+ A simplified view of the template is listed below. -+ -+ template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ -+ /// Element type for B matrix operand -+ typename ElementB, -+ -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ -+ /// Operator class tag -+ typename OperatorClass, -+ -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. 
-+ typename ArchTag, -+ -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ -+ /// Number of stages used in the pipelined mainloop -+ int Stages -+ > -+ class Gemm; -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassSimt, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_ = arch::Sm70, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator -+> -+class GemmArray { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const 
kAlignmentC = EpilogueOutputOp::kCount; -+ using Operator = Operator_; -+ -+ /// Define the kernel -+ using DefaultGemmKernel = typename kernel::DefaultGemm< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ false, -+ Operator -+ >::GemmKernel; -+ -+ using GemmKernel = kernel::GemmArray; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ -+ ElementA const * const *ptr_A; -+ LayoutA layout_A; -+ -+ ElementB const * const *ptr_B; -+ LayoutB layout_B; -+ -+ ElementC const * const *ptr_C; -+ LayoutC layout_C; -+ -+ ElementC * const * ptr_D; -+ LayoutC layout_D; -+ -+ typename EpilogueOutputOp::Params epilogue; -+ int batch_count; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ ElementA const * const *ptr_A_, -+ LayoutA layout_A_, -+ ElementB const * const *ptr_B_, -+ LayoutB layout_B_, -+ ElementC const * const *ptr_C_, -+ LayoutC layout_C_, -+ ElementC * const * ptr_D_, -+ LayoutC layout_D_, -+ typename EpilogueOutputOp::Params epilogue_, -+ int batch_count_ -+ ): -+ problem_size(problem_size_), -+ ptr_A(ptr_A_), -+ layout_A(layout_A_), -+ ptr_B(ptr_B_), -+ layout_B(layout_B_), -+ ptr_C(ptr_C_), -+ layout_C(layout_C_), -+ ptr_D(ptr_D_), -+ layout_D(layout_D_), -+ epilogue(epilogue_), -+ batch_count(batch_count_) { } -+ }; -+ -+private: -+ -+ /// Kernel parameters object -+ typename GemmKernel::Params params_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmArray() { } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ if (args.layout_A.stride(0) % kAlignmentA) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (args.layout_B.stride(0) % kAlignmentB) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (args.layout_C.stride(0) % kAlignmentC) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (args.layout_D.stride(0) % kAlignmentC) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ return 0; -+ } -+ -+ /// Initializes GEMM state from arguments. 
-+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ // Initialize the Params structure -+ params_ = typename GemmKernel::Params{ -+ args.problem_size, -+ grid_shape, -+ args.ptr_A, -+ args.layout_A, -+ args.ptr_B, -+ args.layout_B, -+ args.ptr_C, -+ args.layout_C, -+ args.ptr_D, -+ args.layout_D, -+ args.epilogue, -+ args.batch_count -+ }; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ args.batch_count, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}); -+ -+ params_ = typename GemmKernel::Params{ -+ args.problem_size, -+ grid_shape, -+ args.ptr_A, -+ args.layout_A, -+ args.ptr_B, -+ args.layout_B, -+ args.ptr_C, -+ args.layout_C, -+ args.ptr_D, -+ args.layout_D, -+ args.epilogue, -+ args.batch_count -+ }; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(GemmKernel::kThreadCount, 1, 1); -+ -+ cudaError_t result; -+ -+ int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); -+ if (smem_size >= (48 << 10)) { -+ result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ cutlass::Kernel<<>>(params_); -+ -+ result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for column-major output exchanges problem size and operand. 
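GemmArray's Arguments take per-batch pointer arrays rather than a single base pointer, so each batch entry may live at an arbitrary address. A minimal host-side sketch, assuming d_ptr_A, d_ptr_B, d_ptr_C are device arrays of batch_count const pointers, d_ptr_D a device array of batch_count output pointers, and lda, ldb, ldc, ldd shared leading dimensions (all names hypothetical):

    // d_ptr_A, d_ptr_B, d_ptr_C: float const * const *; d_ptr_D: float * const *
    using GemmArrayOp = cutlass::gemm::device::GemmArray<
        float, cutlass::layout::ColumnMajor,
        float, cutlass::layout::ColumnMajor,
        float, cutlass::layout::ColumnMajor>;

    GemmArrayOp gemm_array_op;
    cutlass::Status status = gemm_array_op({
        {m, n, k},
        d_ptr_A, cutlass::layout::ColumnMajor(lda),
        d_ptr_B, cutlass::layout::ColumnMajor(ldb),
        d_ptr_C, cutlass::layout::ColumnMajor(ldc),
        d_ptr_D, cutlass::layout::ColumnMajor(ldd),
        {alpha, beta},
        batch_count});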
-+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ typename Operator_ -+> -+class GemmArray< -+ ElementA_, -+ LayoutA_, -+ ElementB_, -+ LayoutB_, -+ ElementC_, -+ layout::ColumnMajor, -+ ElementAccumulator_, -+ OperatorClass_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ EpilogueOutputOp_, -+ ThreadblockSwizzle_, -+ Stages, -+ AlignmentA, -+ AlignmentB, -+ Operator_ -+> { -+public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static int const kStages = Stages; -+ -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = false; -+ -+ // -+ using UnderlyingOperator = GemmArray< -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ kAlignmentB, -+ kAlignmentA -+ >; -+ -+ using UnderlyingArguments = typename UnderlyingOperator::Arguments; -+ using GemmKernel = typename UnderlyingOperator::GemmKernel; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ -+ ElementA const * const *ptr_A; -+ LayoutA layout_A; -+ -+ ElementB const * const *ptr_B; -+ LayoutB layout_B; -+ -+ ElementC const * const *ptr_C; -+ LayoutC layout_C; -+ -+ ElementC * const * ptr_D; -+ LayoutC layout_D; -+ -+ typename EpilogueOutputOp::Params epilogue; -+ int batch_count; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord 
problem_size_, -+ ElementA const * const *ptr_A_, -+ LayoutA layout_A_, -+ ElementB const * const *ptr_B_, -+ LayoutB layout_B_, -+ ElementC const * const *ptr_C_, -+ LayoutC layout_C_, -+ ElementC * const * ptr_D_, -+ LayoutC layout_D_, -+ typename EpilogueOutputOp::Params epilogue_, -+ int batch_count_ -+ ): -+ problem_size(problem_size_), -+ ptr_A(ptr_A_), -+ layout_A(layout_A_), -+ ptr_B(ptr_B_), -+ layout_B(layout_B_), -+ ptr_C(ptr_C_), -+ layout_C(layout_C_), -+ ptr_D(ptr_D_), -+ layout_D(layout_D_), -+ epilogue(epilogue_), -+ batch_count(batch_count_) { } -+ }; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmArray() { } -+ -+ /// Helper to construct a transposed equivalent for the underying GEMM operator -+ static UnderlyingArguments to_underlying_arguments(Arguments const &args) { -+ -+ GemmCoord problem_size{ -+ args.problem_size.n(), -+ args.problem_size.m(), -+ args.problem_size.k() -+ }; -+ -+ return UnderlyingArguments( -+ problem_size, -+ args.ptr_B, -+ args.layout_B.stride(), -+ args.ptr_A, -+ args.layout_A.stride(), -+ args.ptr_C, -+ args.layout_C.stride(), -+ args.ptr_D, -+ args.layout_D.stride(), -+ args.epilogue, -+ args.batch_count -+ ); -+ } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_batched.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_batched.h -new file mode 100644 -index 0000000..6f510e9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_batched.h -@@ -0,0 +1,703 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined batch GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/gemm_batched.h" -+ -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/*! Gemm device-level operator. This is an interface to efficient CUTLASS GEMM kernels that may -+ be invoked from host code. -+ -+ The contributions of this class are: -+ -+ 1. At compile time, it maps data types and high-level structural parameters onto -+ specific CUTLASS components. -+ -+ 2. At runtime, it maps logical arguments to GEMM problems to kernel parameters. -+ -+ 3. At runtime, it launches kernels on the device. -+ -+ The intent is to provide a convenient mechanism for interacting with most plausible GEMM -+ configurations for each supported architecture. Consequently, not all parameters are exposed -+ to the top-level interface. Rather, sensible defaults at each level of the CUTLASS hierarchy -+ are selected to tradeoff simplicity of the interface with flexibility. We expect -+ most configurations to be specified at this level. Applications with more exotic requirements -+ may construct their kernels of interest using CUTLASS components at the threadblock, warp, -+ and thread levels of abstraction. 
-+ -+ CUTLASS exposes computations using the functor design pattern in which objects compose some -+ internal state with an overloaded function call operator. This enables decoupling of -+ initialization from execution, possibly reducing overhead during steady state phases of -+ application execution. -+ -+ CUTLASS device-level operators expose an Arguments structure encompassing each logical -+ input to the computation. This is distinct from the kernel-level Params structure pattern -+ which contains application-specific precomputed state needed by the device code. -+ -+ Example of a CUTLASS GEMM operator implementing the functionality of cuBLAS's SGEMM NN -+ is as follows: -+ -+ // -+ // Instantiate the CUTLASS GEMM operator. -+ // -+ -+ cutlass::gemm::device::Gemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ > gemm_op; -+ -+ // -+ // Launch the GEMM operation on the device -+ // -+ -+ cutlass::Status status = gemm_op({ -+ {m, n, k}, // GemmCoord problem_size, -+ {A, lda}, // TensorRef ref_A, -+ {B, ldb}, // TensorRef ref_B, -+ {C, ldc}, // TensorRef ref_C, -+ {D, ldd}, // TensorRef ref_D, -+ {alpha, beta} // EpilogueOutputOp::Params epilogue_op_params -+ }); -+ -+ -+ A simplified view of the template is listed below. -+ -+ template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ -+ /// Element type for B matrix operand -+ typename ElementB, -+ -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ -+ /// Operator class tag -+ typename OperatorClass, -+ -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. 
-+ typename ArchTag, -+ -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ -+ /// Number of stages used in the pipelined mainloop -+ int Stages -+ > -+ class Gemm; -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassSimt, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_ = arch::Sm70, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator -+> -+class GemmBatched { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const 
kAlignmentC = EpilogueOutputOp::kCount; -+ using Operator = Operator_; -+ -+ /// Define the kernel -+ using DefaultGemmKernel = typename kernel::DefaultGemm< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ false, -+ Operator -+ >::GemmKernel; -+ -+ using GemmKernel = kernel::GemmBatched; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A; -+ int64_t stride_A; -+ TensorRef ref_B; -+ int64_t stride_B; -+ TensorRef ref_C; -+ int64_t stride_C; -+ TensorRef ref_D; -+ int64_t stride_D; -+ typename EpilogueOutputOp::Params epilogue; -+ int batch_count; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A_, -+ int64_t stride_A_, -+ TensorRef ref_B_, -+ int64_t stride_B_, -+ TensorRef ref_C_, -+ int64_t stride_C_, -+ TensorRef ref_D_, -+ int64_t stride_D_, -+ typename EpilogueOutputOp::Params epilogue_, -+ int batch_count_ -+ ): -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ stride_A(stride_A_), -+ ref_B(ref_B_), -+ stride_B(stride_B_), -+ ref_C(ref_C_), -+ stride_C(stride_C_), -+ ref_D(ref_D_), -+ stride_D(stride_D_), -+ epilogue(epilogue_), -+ batch_count(batch_count_) { } -+ }; -+ -+private: -+ -+ /// Kernel parameters object -+ typename GemmKernel::Params params_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmBatched() { } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ if (!TensorRef_aligned(args.ref_A, kAlignmentA) || (args.stride_A % kAlignmentA)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(args.ref_B, kAlignmentB) || (args.stride_B % kAlignmentB)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(args.ref_C, kAlignmentC) || (args.stride_C % kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(args.ref_D, kAlignmentC) || (args.stride_D % kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ return 0; -+ } -+ -+ /// Initializes GEMM state from arguments. 
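The int64_t strides stored in Arguments above are most naturally read as batch strides in elements, i.e. the distance between matrix i and matrix i+1 of the same operand, which is also how the modulo checks in can_implement treat them. A hedged sketch of how such strides are typically derived for contiguously packed column-major batches (this interpretation is assumed, not spelled out in the patch):

    // Assumed packing: batch_count matrices of each operand stored back to back.
    // Column-major A is m x k with leading dimension lda, B is k x n with ldb,
    // C and D are m x n with ldc / ldd.
    int64_t stride_A = int64_t(lda) * k;   // elements between consecutive A matrices
    int64_t stride_B = int64_t(ldb) * n;   // elements between consecutive B matrices
    int64_t stride_C = int64_t(ldc) * n;   // elements between consecutive C matrices
    int64_t stride_D = int64_t(ldd) * n;   // elements between consecutive D matrices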
-+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ // Initialize the Params structure -+ params_ = typename GemmKernel::Params{ -+ args.problem_size, -+ grid_shape, -+ args.ref_A.non_const_ref(), -+ args.stride_A, -+ args.ref_B.non_const_ref(), -+ args.stride_B, -+ args.ref_C.non_const_ref(), -+ args.stride_C, -+ args.ref_D, -+ args.stride_D, -+ args.epilogue, -+ args.batch_count -+ }; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ params_.ref_A.reset(args.ref_A.non_const_ref().data()); -+ params_.ref_B.reset(args.ref_B.non_const_ref().data()); -+ params_.ref_C.reset(args.ref_C.non_const_ref().data()); -+ params_.ref_D.reset(args.ref_D.data()); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(GemmKernel::kThreadCount, 1, 1); -+ -+ cudaError_t result; -+ -+ int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); -+ if (smem_size >= (48 << 10)) { -+ result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ cutlass::Kernel<<>>(params_); -+ -+ result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for column-major output exchanges problem size and operand. 
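Editorial usage sketch, not part of the patch: one plausible way to instantiate and invoke the GemmBatched operator defined above for a strided-batched SGEMM. The element types, layouts, host pointers, leading dimensions and batch strides are illustrative assumptions; all remaining template parameters are left at their defaults.

#include "cutlass/gemm/device/gemm_batched.h"

// Hypothetical column-major float batched GEMM; defaulted template parameters fill in the rest.
using BatchedSgemm = cutlass::gemm::device::GemmBatched<
    float, cutlass::layout::ColumnMajor,   // A
    float, cutlass::layout::ColumnMajor,   // B
    float, cutlass::layout::ColumnMajor>;  // C / D

cutlass::Status run_batched_sgemm(
    int m, int n, int k, int batch_count,
    float const *A, int lda, int64_t batch_stride_A,
    float const *B, int ldb, int64_t batch_stride_B,
    float *C, int ldc, int64_t batch_stride_C,
    float alpha, float beta, cudaStream_t stream) {
  BatchedSgemm gemm_op;
  // Arguments mirror the struct defined above: problem size, one TensorRef plus a
  // batch stride per operand, the epilogue parameters, and the batch count.
  // D aliases C here; a distinct output tensor is equally valid.
  return gemm_op({{m, n, k},
                  {A, lda}, batch_stride_A,
                  {B, ldb}, batch_stride_B,
                  {C, ldc}, batch_stride_C,
                  {C, ldc}, batch_stride_C,
                  {alpha, beta},
                  batch_count},
                 /*workspace=*/nullptr, stream);
}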
-+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ typename Operator_ -+> -+class GemmBatched< -+ ElementA_, -+ LayoutA_, -+ ElementB_, -+ LayoutB_, -+ ElementC_, -+ layout::ColumnMajor, -+ ElementAccumulator_, -+ OperatorClass_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ EpilogueOutputOp_, -+ ThreadblockSwizzle_, -+ Stages, -+ AlignmentA, -+ AlignmentB, -+ Operator_ -+> { -+public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static int const kStages = Stages; -+ -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = false; -+ -+ // -+ using UnderlyingOperator = GemmBatched< -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ kAlignmentB, -+ kAlignmentA -+ >; -+ -+ using UnderlyingArguments = typename UnderlyingOperator::Arguments; -+ using GemmKernel = typename UnderlyingOperator::GemmKernel; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A; -+ int64_t stride_A; -+ TensorRef ref_B; -+ int64_t stride_B; -+ TensorRef ref_C; -+ int64_t stride_C; -+ TensorRef ref_D; -+ int64_t stride_D; -+ typename EpilogueOutputOp::Params epilogue; -+ int batch_count; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A_, -+ int64_t stride_A_, -+ 
TensorRef ref_B_, -+ int64_t stride_B_, -+ TensorRef ref_C_, -+ int64_t stride_C_, -+ TensorRef ref_D_, -+ int64_t stride_D_, -+ typename EpilogueOutputOp::Params epilogue_, -+ int batch_count_ -+ ): -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ stride_A(stride_A_), -+ ref_B(ref_B_), -+ stride_B(stride_B_), -+ ref_C(ref_C_), -+ stride_C(stride_C_), -+ ref_D(ref_D_), -+ stride_D(stride_D_), -+ epilogue(epilogue_), -+ batch_count(batch_count_) { } -+ }; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmBatched() { } -+ -+ /// Helper to construct a transposed equivalent for the underying GEMM operator -+ static UnderlyingArguments to_underlying_arguments(Arguments const &args) { -+ return UnderlyingArguments( -+ {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, -+ {args.ref_B.data(), args.ref_B.stride(0)}, -+ args.stride_B, -+ {args.ref_A.data(), args.ref_A.stride(0)}, -+ args.stride_A, -+ {args.ref_C.data(), args.ref_C.stride(0)}, -+ args.stride_C, -+ {args.ref_D.data(), args.ref_D.stride(0)}, -+ args.stride_D, -+ args.epilogue, -+ args.batch_count -+ ); -+ } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_complex.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_complex.h -new file mode 100644 -index 0000000..5bd856f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_complex.h -@@ -0,0 +1,717 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/gemm.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_complex.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*! Gemm device-level operator. This is an interface to efficient CUTLASS GEMM -+ kernels that may be invoked from host code. -+ -+ The contributions of this class are: -+ -+ 1. At compile time, it maps data types and high-level structural parameters -+ onto specific CUTLASS components. -+ -+ 2. At runtime, it maps logical arguments to GEMM problems to kernel -+ parameters. -+ -+ 3. At runtime, it launches kernels on the device. -+ -+ The intent is to provide a convenient mechanism for interacting with most -+ plausible GEMM configurations for each supported architecture. Consequently, -+ not all parameters are exposed to the top-level interface. Rather, sensible -+ defaults at each level of the CUTLASS hierarchy are selected to tradeoff -+ simplicity of the interface with flexibility. We expect most configurations to -+ be specified at this level. Applications with more exotic requirements may -+ construct their kernels of interest using CUTLASS components at the -+ threadblock, warp, and thread levels of abstraction. 
-+ -+ CUTLASS exposes computations using the functor design pattern in which objects -+ compose some internal state with an overloaded function call operator. This -+ enables decoupling of initialization from execution, possibly reducing -+ overhead during steady state phases of application execution. -+ -+ CUTLASS device-level operators expose an Arguments structure encompassing each -+ logical input to the computation. This is distinct from the kernel-level -+ Params structure pattern which contains application-specific precomputed state -+ needed by the device code. -+ -+ Example of a CUTLASS GEMM operator implementing the functionality of cuBLAS's -+ SGEMM NN is as follows: -+ -+ // -+ // Instantiate the CUTLASS GEMM operator. -+ // -+ -+ cutlass::gemm::device::Gemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ > gemm_op; -+ -+ // -+ // Launch the GEMM operation on the device -+ // -+ -+ cutlass::Status status = gemm_op({ -+ {m, n, k}, // GemmCoord problem_size, -+ {A, lda}, // TensorRef ref_A, -+ {B, ldb}, // TensorRef ref_B, -+ {C, ldc}, // TensorRef ref_C, -+ {D, ldd}, // TensorRef ref_D, -+ {alpha, beta} // EpilogueOutputOp::Params epilogue_op_params -+ }); -+ -+ -+ A simplified view of the template is listed below. -+ -+ template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ -+ /// Element type for B matrix operand -+ typename ElementB, -+ -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ -+ /// Operator class tag -+ typename OperatorClass, -+ -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag, -+ -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ -+ /// Number of stages used in the pipelined mainloop -+ int Stages -+ > -+ class Gemm; -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassSimt, -+ /// Tag indicating architecture to tune for. 
-+ typename ArchTag_ = arch::Sm70, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = -+ threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Multiply-add operator -+ // (selects complex or gaussian complex) -+ typename Operator_ = arch::OpMultiplyAddComplex, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial = false> -+class GemmComplex { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static bool const kSplitKSerial = SplitKSerial; -+ static int const kAlignmentA = 1; -+ static int const kAlignmentB = 1; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ -+ /// Define the kernel -+ using GemmKernel = typename kernel::DefaultGemmComplex< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kTransformA, -+ kTransformB, -+ Operator, -+ kSplitKSerial -+ >::GemmKernel; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A; -+ TensorRef ref_B; -+ TensorRef ref_C; -+ TensorRef ref_D; -+ typename EpilogueOutputOp::Params epilogue; -+ int split_k_slices; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments(): problem_size(0, 0, 0), split_k_slices(1) { -+ -+ } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ 
GemmCoord problem_size_, -+ TensorRef ref_A_, -+ TensorRef ref_B_, -+ TensorRef ref_C_, -+ TensorRef ref_D_, -+ typename EpilogueOutputOp::Params epilogue_ = -+ typename EpilogueOutputOp::Params(), -+ int split_k_slices = 1 -+ ): -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ ref_B(ref_B_), -+ ref_C(ref_C_), -+ ref_D(ref_D_), -+ epilogue(epilogue_), -+ split_k_slices(split_k_slices) { -+ -+ } -+ }; -+ -+private: -+ -+ /// Kernel parameters object -+ typename GemmKernel::Params params_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmComplex() { } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ if (!kSplitKSerial && args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ return sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); -+ } -+ -+ return 0; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ if (kSplitKSerial) { -+ if (args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ size_t bytes = get_workspace_size(args); -+ -+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ else { -+ -+ if (args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ // Initialize the Params structure -+ params_ = typename GemmKernel::Params{ -+ args.problem_size, -+ grid_shape, -+ args.ref_A.non_const_ref(), -+ args.ref_B.non_const_ref(), -+ args.ref_C.non_const_ref(), -+ args.ref_D, -+ args.epilogue, -+ static_cast(workspace) -+ }; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ } -+ -+ params_.ref_A.reset(args.ref_A.non_const_ref().data()); -+ params_.ref_B.reset(args.ref_B.non_const_ref().data()); -+ params_.ref_C.reset(args.ref_C.non_const_ref().data()); -+ params_.ref_D.reset(args.ref_D.data()); -+ params_.semaphore = static_cast(workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. 
-+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(GemmKernel::kThreadCount, 1, 1); -+ -+ cudaError_t result; -+ -+ int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); -+ if (smem_size >= (48 << 10)) { -+ result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ cutlass::Kernel<<>>(params_); -+ -+ result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for column-major output exchanges problem size and operand. -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Multiply-add operator -+ // (selects complex or gaussian complex) -+ typename Operator_, -+ /// If true, kernel supports split-K as a serial reduction -+ bool SplitKSerial -+> -+class GemmComplex< -+ ElementA_, -+ LayoutA_, -+ ElementB_, -+ LayoutB_, -+ ElementC_, -+ layout::ColumnMajor, // partially specialized on LayoutC -+ ElementAccumulator_, -+ OperatorClass_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ EpilogueOutputOp_, -+ ThreadblockSwizzle_, -+ Stages, -+ TransformA, -+ TransformB, -+ Operator_, -+ SplitKSerial -+> { -+public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; 
-+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static int const kStages = Stages; -+ using Operator = Operator_; -+ static bool const kSplitKSerial = SplitKSerial; -+ -+ using UnderlyingOperator = GemmComplex< -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ TransformB, -+ TransformA, -+ Operator, -+ SplitKSerial -+ >; -+ -+ static int const kAlignmentA = UnderlyingOperator::kAlignmentB; -+ static int const kAlignmentB = UnderlyingOperator::kAlignmentA; -+ static int const kAlignmentC = UnderlyingOperator::kAlignmentC; -+ static ComplexTransform const kTransformA = UnderlyingOperator::kTransformB; -+ static ComplexTransform const kTransformB = UnderlyingOperator::kTransformA; -+ -+ using UnderlyingArguments = typename UnderlyingOperator::Arguments; -+ using GemmKernel = typename UnderlyingOperator::GemmKernel; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A; -+ TensorRef ref_B; -+ TensorRef ref_C; -+ TensorRef ref_D; -+ typename EpilogueOutputOp::Params epilogue; -+ int split_k_slices; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A_, -+ TensorRef ref_B_, -+ TensorRef ref_C_, -+ TensorRef ref_D_, -+ typename EpilogueOutputOp::Params epilogue_ = -+ typename EpilogueOutputOp::Params(), -+ int split_k_slices = 1 -+ ): -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ ref_B(ref_B_), -+ ref_C(ref_C_), -+ ref_D(ref_D_), -+ epilogue(epilogue_), -+ split_k_slices(split_k_slices) { } -+ }; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmComplex() { } -+ -+ /// Helper to construct a transposed equivalent for the underying GEMM operator -+ static UnderlyingArguments to_underlying_arguments(Arguments const &args) { -+ return UnderlyingArguments( -+ {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, -+ {args.ref_B.data(), args.ref_B.stride(0)}, -+ {args.ref_A.data(), args.ref_A.stride(0)}, -+ {args.ref_C.data(), args.ref_C.stride(0)}, -+ {args.ref_D.data(), args.ref_D.stride(0)}, -+ args.epilogue, -+ args.split_k_slices -+ ); -+ } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Initializes GEMM state from arguments. 
-+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_grouped.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_grouped.h -new file mode 100644 -index 0000000..3e932eb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_grouped.h -@@ -0,0 +1,61 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Device-level grouped GEMM. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/gemm/device/base_grouped.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// GEMM Grouped -+template -+class GemmGrouped : public BaseGrouped { -+public: -+ using GemmKernel = GemmKernel_; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_layernorm_mainloop_fusion.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_layernorm_mainloop_fusion.h -new file mode 100644 -index 0000000..3ebb2a7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_layernorm_mainloop_fusion.h -@@ -0,0 +1,385 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Device-level GEMM with layernorm elementwise operations fused in mainloop -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/gemm_universal.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_layernorm_mainloop_fusion.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/gemm/device/gemm_universal_base.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*! -+ The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and -+ batched array variants. -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for Scale/Bias vectors -+ typename ElementScaleBias_, -+ /// Layout type for Scale/Bias vectors -+ typename LayoutScaleBias_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassSimt, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. 
-+ typename ArchTag_ = arch::Sm70, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator -+> -+class GemmLayernormMainloopFusion : -+ public GemmUniversalBase< -+ typename kernel::DefaultGemmLayernormMainloopFusion< -+ ElementA_, -+ LayoutA_, -+ AlignmentA, -+ ElementB_, -+ LayoutB_, -+ AlignmentB, -+ ElementScaleBias_, -+ LayoutScaleBias_, -+ ElementC_, -+ LayoutC_, -+ ElementAccumulator_, -+ OperatorClass_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ EpilogueOutputOp_, -+ ThreadblockSwizzle_, -+ Stages, -+ Operator_, -+ SharedMemoryClearOption::kNone -+ >::GemmKernel -+ > { -+ -+ public: -+ -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ -+ using Base = GemmUniversalBase< -+ typename kernel::DefaultGemmLayernormMainloopFusion< -+ ElementA_, -+ LayoutA_, -+ AlignmentA, -+ ElementB_, -+ LayoutB_, -+ AlignmentB, -+ ElementScaleBias_, -+ LayoutScaleBias_, -+ ElementC_, -+ LayoutC_, -+ ElementAccumulator_, -+ OperatorClass_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ EpilogueOutputOp_, -+ ThreadblockSwizzle_, -+ Stages, -+ Operator_, -+ SharedMemoryClearOption::kNone -+ >::GemmKernel -+ >; -+ -+ using Arguments = typename Base::Arguments; -+ using GemmKernel = typename Base::GemmKernel; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for column-major output exchanges problem size and operand. 
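Editorial note, not part of the patch: the column-major-output partial specializations in these headers (here, and in gemm_batched.h and gemm_complex.h above) can delegate to a row-major operator because of the transpose identity D = A * B  <=>  D^T = B^T * A^T. Reading a column-major m-by-n D as a row-major n-by-m D^T turns the original request into a row-major GEMM with the operands exchanged, each layout transposed, and the M/N extents swapped, which is exactly what to_underlying_arguments() / transposed_problem() implement in the specializations above and below. A worked mapping, with hypothetical names:

// Original request:        column-major D (m x n) = A (m x k) * B (k x n)
// Underlying operator sees: row-major  D^T (n x m) = B^T (n x k) * A^T (k x m)
//   problem {m, n, k}    -> {n, m, k}
//   operand pair (A, B)  -> (B, A), with layout::LayoutTranspose applied to each layout
//   kAlignmentA/kAlignmentB (and TransformA/TransformB, where present) exchanged accordingly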
-+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for Scale/Bias vectors -+ typename ElementScaleBias_, -+ /// Layout type for Scale/Bias vectors -+ typename LayoutScaleBias_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_ -+> -+class GemmLayernormMainloopFusion { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementScaleBias = ElementScaleBias_; -+ using LayoutScaleBias = LayoutScaleBias_; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ -+ using UnderlyingOperator = typename GemmLayernormMainloopFusion< -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ ElementScaleBias, -+ LayoutScaleBias, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ kAlignmentB, -+ kAlignmentA, -+ Operator -+ >::Base; -+ -+ using GemmKernel = typename UnderlyingOperator::GemmKernel; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ -+ /// Argument structure -+ using Arguments = typename UnderlyingOperator::Arguments; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the GEMM. 
-+ GemmLayernormMainloopFusion() { } -+ -+ /// Helper to construct a transposed equivalent for the underlying GEMM operator -+ static Arguments to_underlying_arguments(Arguments const &args) { -+ return args.transposed_problem(); -+ } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the grid shape -+ static dim3 get_grid_shape(Arguments const &args) { -+ return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the maximum number of active blocks per multiprocessor -+ static int maximum_active_blocks(int smem_capacity = -1) { -+ return UnderlyingOperator::maximum_active_blocks(smem_capacity); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_sparse.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_sparse.h -new file mode 100644 -index 0000000..0366b05 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_sparse.h -@@ -0,0 +1,514 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/sparse_gemm.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_sparse.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*! Gemm device-level operator. This is an interface to efficient CUTLASS GEMM kernels that may -+ be invoked from host code. -+ -+ The contributions of this class are: -+ -+ 1. At compile time, it maps data types and high-level structural parameters onto -+ specific CUTLASS components. -+ -+ 2. At runtime, it maps logical arguments to GEMM problems to kernel parameters. -+ -+ 3. At runtime, it launches kernels on the device. -+ -+ The intent is to provide a convenient mechanism for interacting with most plausible GEMM -+ configurations for each supported architecture. Consequently, not all parameters are exposed -+ to the top-level interface. Rather, sensible defaults at each level of the CUTLASS hierarchy -+ are selected to tradeoff simplicity of the interface with flexibility. We expect -+ most configurations to be specified at this level. Applications with more exotic requirements -+ may construct their kernels of interest using CUTLASS components at the threadblock, warp, -+ and thread levels of abstraction. -+ -+ CUTLASS exposes computations using the functor design pattern in which objects compose some -+ internal state with an overloaded function call operator. This enables decoupling of -+ initialization from execution, possibly reducing overhead during steady state phases of -+ application execution. -+ -+ CUTLASS device-level operators expose an Arguments structure encompassing each logical -+ input to the computation. This is distinct from the kernel-level Params structure pattern -+ which contains application-specific precomputed state needed by the device code. 
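A minimal editorial sketch, not part of the patch, of the decoupled flow these device-level operators expose (can_implement / get_workspace_size / initialize once, then repeated run calls), complementing the one-shot operator() usage shown in the doc comment below. GemmOp stands for any device-level operator in this patch; args, iterations, and stream are assumed to be supplied by the caller.

#include <cuda_runtime.h>
#include "cutlass/cutlass.h"

template <typename GemmOp>
cutlass::Status run_many(typename GemmOp::Arguments const &args,
                         int iterations, cudaStream_t stream) {
  GemmOp gemm_op;
  cutlass::Status status = GemmOp::can_implement(args);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }
  // Workspace is only needed for some configurations (e.g. serial split-K locks).
  void *workspace = nullptr;
  size_t workspace_bytes = GemmOp::get_workspace_size(args);
  if (workspace_bytes != 0 &&
      cudaMalloc(&workspace, workspace_bytes) != cudaSuccess) {
    return cutlass::Status::kErrorInternal;
  }
  // The kernel-level Params structure is precomputed once here ...
  status = gemm_op.initialize(args, workspace, stream);
  // ... and reused by every subsequent launch.
  for (int i = 0; status == cutlass::Status::kSuccess && i < iterations; ++i) {
    status = gemm_op.run(stream);
  }
  cudaFree(workspace);  // no-op when workspace == nullptr
  return status;
}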
-+ -+ Example of a CUTLASS GEMM operator implementing the functionality of cuBLAS's SGEMM NN -+ is as follows: -+ -+ // -+ // Instantiate the CUTLASS GEMM operator. -+ // -+ -+ cutlass::gemm::device::Gemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ > gemm_op; -+ -+ // -+ // Launch the GEMM operation on the device -+ // -+ -+ cutlass::Status status = gemm_op({ -+ {m, n, k}, // GemmCoord problem_size, -+ {A, lda}, // TensorRef ref_A, -+ {B, ldb}, // TensorRef ref_B, -+ {C, ldc}, // TensorRef ref_C, -+ {D, ldd}, // TensorRef ref_D, -+ {alpha, beta} // EpilogueOutputOp::Params epilogue_op_params -+ }); -+ -+ -+ A simplified view of the template is listed below. -+ -+ template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ -+ /// Element type for B matrix operand -+ typename ElementB, -+ -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ -+ /// Operator class tag -+ typename OperatorClass, -+ -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag, -+ -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ -+ /// Number of stages used in the pipelined mainloop -+ int Stages -+ > -+ class Gemm; -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassSimt, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_ = arch::Sm70, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// 
Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = -+ typename threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial = false, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator> -+class SparseGemm { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ using MathOperator = Operator; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+ /// Define the kernel -+ using GemmKernel = typename kernel::DefaultSparseGemm< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kSplitKSerial, -+ Operator -+ >::GemmKernel; -+ -+ using ElementE = typename GemmKernel::ElementE; -+ -+ using LayoutE = typename GemmKernel::LayoutE; -+ -+ static int const kAlignmentE = 128 / sizeof_bits::value; -+ -+ static int const kSparse = GemmKernel::kSparse; -+ static int const kMetaSizeInBits = GemmKernel::kMetaSizeInBits; -+ static int const kElementsPerElementE = GemmKernel::kElementsPerElementE; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A; -+ TensorRef ref_B; -+ TensorRef ref_C; -+ TensorRef ref_D; -+ TensorRef ref_E; -+ typename EpilogueOutputOp::Params epilogue; -+ int split_k_slices; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments(): problem_size(0, 0, 0), split_k_slices(1) { -+ -+ } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A_, -+ TensorRef ref_B_, -+ TensorRef ref_C_, -+ TensorRef ref_D_, -+ TensorRef ref_E_, -+ typename EpilogueOutputOp::Params epilogue_ = -+ typename EpilogueOutputOp::Params(), -+ int split_k_slices = 1 -+ ): -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ ref_B(ref_B_), -+ ref_C(ref_C_), -+ ref_D(ref_D_), -+ 
ref_E(ref_E_), -+ epilogue(epilogue_), -+ split_k_slices(split_k_slices) { -+ -+ } -+ }; -+ -+private: -+ -+ /// Kernel parameters object -+ typename GemmKernel::Params params_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ SparseGemm() { } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ if (!kSplitKSerial && args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = GemmKernel::can_implement( -+ args.problem_size, -+ args.ref_A.non_const_ref(), -+ args.ref_B.non_const_ref(), -+ args.ref_C.non_const_ref(), -+ args.ref_D, -+ args.ref_E.non_const_ref() -+ ); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ -+ bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); -+ } -+ -+ return bytes; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ if (kSplitKSerial) { -+ if (args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ size_t bytes = get_workspace_size(args); -+ -+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ else { -+ -+ if (args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ // Initialize the Params structure -+ params_ = typename GemmKernel::Params{ -+ args.problem_size, -+ grid_shape, -+ args.ref_A.non_const_ref(), -+ args.ref_B.non_const_ref(), -+ args.ref_C.non_const_ref(), -+ args.ref_D, -+ args.ref_E.non_const_ref(), -+ args.epilogue, -+ static_cast(workspace) -+ }; -+ -+ int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); -+ if (smem_size >= (48 << 10)) { -+ cudaError_t result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ } -+ -+ params_.ref_A.reset(args.ref_A.non_const_ref().data()); -+ params_.ref_B.reset(args.ref_B.non_const_ref().data()); -+ params_.ref_C.reset(args.ref_C.non_const_ref().data()); -+ params_.ref_D.reset(args.ref_D.data()); -+ params_.ref_E.reset(args.ref_E.non_const_ref().data()); -+ params_.output_op = args.epilogue; -+ params_.semaphore = static_cast(workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. 
-+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(GemmKernel::kThreadCount, 1, 1); -+ -+ int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); -+ -+ cutlass::Kernel<<>>(params_); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_splitk_parallel.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_splitk_parallel.h -new file mode 100644 -index 0000000..55db955 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_splitk_parallel.h -@@ -0,0 +1,638 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for GEMM performing a reduction over K partitions in parallel. 
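Editor's note: the structured-sparse device handle whose removal is recorded above follows the standard CUTLASS 2.x lifecycle (can_implement, then get_workspace_size, then initialize/operator()). A minimal usage sketch of that lifecycle is shown below; the header path, element types, architecture tag, and the device_memory allocation helper are assumptions for illustration, not values taken from this patch.

    #include "cutlass/gemm/device/gemm_sparse.h"   // assumed location of SparseGemm
    #include "cutlass/util/device_memory.h"        // assumed workspace-allocation helper

    using SparseGemmOp = cutlass::gemm::device::SparseGemm<
        cutlass::half_t, cutlass::layout::RowMajor,    // A
        cutlass::half_t, cutlass::layout::ColumnMajor, // B
        cutlass::half_t, cutlass::layout::RowMajor,    // C / D
        float, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80>;

    cutlass::Status run_sparse_gemm(SparseGemmOp::Arguments const &args) {
      // Reject unsupported problem shapes / split-K configurations up front.
      cutlass::Status status = SparseGemmOp::can_implement(args);
      if (status != cutlass::Status::kSuccess) {
        return status;
      }
      // The workspace is only non-empty for serial split-K with more than one slice.
      cutlass::device_memory::allocation<uint8_t> workspace(
          SparseGemmOp::get_workspace_size(args));
      // operator()(args, workspace, stream) initializes Params and launches the kernel.
      SparseGemmOp gemm_op;
      return gemm_op(args, workspace.get(), /*stream=*/nullptr);
    }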
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/gemm.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_splitk_parallel.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+#include "cutlass/epilogue/thread/conversion_op.h" -+#include "cutlass/reduction/kernel/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/*! -+ Gemm device-level operator performing parallel reduction over the K partition. -+ -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassSimt, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag_ = arch::Sm70, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Epilogue output operator -+ typename ConvertScaledOp_ = cutlass::epilogue::thread::Convert< -+ ElementAccumulator_, -+ DefaultGemmConfiguration::EpilogueOutputOp::kCount, -+ ElementAccumulator_>, -+ /// Reduction operator -+ typename ReductionOp_ = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator_, typename EpilogueOutputOp_::ElementAccumulator, -+ EpilogueOutputOp_::kCount>, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = -+ threadblock::GemmSplitKHorizontalThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ 
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator> -+class GemmSplitKParallel { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ConvertScaledOp = ConvertScaledOp_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ReductionOp = ReductionOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ -+ /// GEMM kernel -+ using GemmKernel = typename kernel::DefaultGemmSplitKParallel< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ ConvertScaledOp, -+ ThreadblockSwizzle, -+ kStages, -+ Operator -+ >::GemmKernel; -+ -+ /// Reduction kernel -+ using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< -+ cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, -+ EpilogueOutputOp, -+ ReductionOp -+ >; -+ -+ // -+ // -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A; -+ TensorRef ref_B; -+ TensorRef ref_C; -+ TensorRef ref_D; -+ typename EpilogueOutputOp::Params epilogue; -+ int split_k_slices; -+ typename ConvertScaledOp::Params convert; -+ typename ReductionOp::Params reduction; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A_, -+ TensorRef ref_B_, -+ TensorRef ref_C_, -+ TensorRef ref_D_, -+ typename EpilogueOutputOp::Params epilogue_ = -+ typename EpilogueOutputOp::Params(), -+ int split_k_slices = 1, -+ typename ConvertScaledOp::Params convert_ = -+ typename ConvertScaledOp::Params(), -+ typename ReductionOp::Params reduction_ = -+ typename ReductionOp::Params() -+ ): -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ ref_B(ref_B_), -+ ref_C(ref_C_), -+ ref_D(ref_D_), -+ epilogue(epilogue_), -+ split_k_slices(split_k_slices), -+ convert(convert_), -+ reduction(reduction_) { } -+ }; -+ -+private: -+ -+ /// Kernel parameters object -+ typename GemmKernel::Params gemm_params_; -+ -+ /// Reduction kernel parameters object -+ typename ReductionKernel::Params reduction_params_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmSplitKParallel() { } -+ -+ /// Determines whether the GEMM can execute the given problem. 
-+ static Status can_implement(Arguments const &args) { -+ -+ // TODO -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ return sizeof(ElementAccumulator_) * size_t(args.problem_size.m()) * size_t(args.problem_size.n()) * grid_shape.k(); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ // Define a reference to the workspace - this is an aligned region in device memory. -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ TensorRef ref_workspace( -+ static_cast(workspace), -+ args.problem_size.n()); -+ -+ int64_t partition_stride = int64_t(args.problem_size.m()) * int64_t(args.problem_size.n()); -+ -+ // Initialize the Params structure -+ gemm_params_ = typename GemmKernel::Params{ -+ args.problem_size, -+ grid_shape, -+ args.ref_A.non_const_ref(), -+ args.ref_B.non_const_ref(), -+ ref_workspace, -+ args.convert, -+ partition_stride -+ }; -+ -+ reduction_params_ = typename ReductionKernel::Params( -+ args.problem_size.mn(), -+ grid_shape.k(), -+ partition_stride, -+ ref_workspace, -+ args.ref_D, -+ args.ref_C.non_const_ref(), -+ args.epilogue -+ ); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ gemm_params_.ref_A.reset(args.ref_A.data()); -+ gemm_params_.ref_B.reset(args.ref_B.data()); -+ gemm_params_.ref_D.reset(workspace); -+ -+ reduction_params_.ref_D.reset(args.ref_D.data()); -+ reduction_params_.ref_C.reset(args.ref_C.data()); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ // -+ // Launch GEMM kernel -+ // -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(gemm_params_.grid_tiled_shape); -+ dim3 block(GemmKernel::kThreadCount, 1, 1); -+ -+ cudaError_t result; -+ -+ int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); -+ if (smem_size >= (48 << 10)) { -+ -+ result = cudaFuncSetAttribute( -+ Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ Kernel<<>>(gemm_params_); -+ -+ result = cudaGetLastError(); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ -+ // -+ // Launch reduction kernel -+ // -+ -+ block = ReductionKernel::block_shape(); -+ grid = ReductionKernel::grid_shape(gemm_params_.problem_size.mn()); -+ -+ Kernel<<< grid, block, 0, stream >>>(reduction_params_); -+ -+ result = cudaGetLastError(); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. 
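Editor's note: get_workspace_size() above reserves one accumulator-precision copy of the full output per K partition, because each partition writes its partial result to its own slice of the workspace before the reduction kernel combines them. A back-of-envelope check with made-up sizes (float accumulator, 1024 x 1024 output, 8 slices) is sketched below; none of these numbers come from the patch.

    #include <cstddef>
    #include <cstdio>

    int main() {
      std::size_t m = 1024, n = 1024;   // example problem size
      std::size_t slices = 8;           // grid_shape.k() == number of K partitions
      std::size_t bytes = sizeof(float) * m * n * slices;
      std::printf("parallel split-K workspace: %zu bytes (%.0f MiB)\n",
                  bytes, double(bytes) / (1 << 20));  // 33554432 bytes, 32 MiB
      return 0;
    }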
-+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for column-major output -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Epilogue output operator -+ typename ConvertScaledOp_, -+ /// Reduction operator -+ typename ReductionOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, int kAlignmentA, int kAlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_> -+class GemmSplitKParallel { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ConvertScaledOp = ConvertScaledOp_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ReductionOp = ReductionOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ -+ using UnderlyingOperator = GemmSplitKParallel< -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ConvertScaledOp, -+ ReductionOp, -+ ThreadblockSwizzle, -+ Stages, -+ kAlignmentA, -+ kAlignmentB, -+ Operator -+ >; -+ -+ using UnderlyingArguments = typename UnderlyingOperator::Arguments; -+ using GemmKernel = typename UnderlyingOperator::GemmKernel; -+ using ReductionKernel = typename UnderlyingOperator::ReductionKernel; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A; -+ TensorRef ref_B; -+ TensorRef ref_C; -+ TensorRef ref_D; -+ typename EpilogueOutputOp::Params epilogue; -+ int split_k_slices; -+ 
typename ConvertScaledOp::Params convert; -+ typename ReductionOp::Params reduction; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A_, -+ TensorRef ref_B_, -+ TensorRef ref_C_, -+ TensorRef ref_D_, -+ typename EpilogueOutputOp::Params epilogue_ = -+ typename EpilogueOutputOp::Params(), -+ int split_k_slices = 1, -+ typename ConvertScaledOp::Params convert_ = -+ typename ConvertScaledOp::Params(), -+ typename ReductionOp::Params reduction_ = -+ typename ReductionOp::Params() -+ ): -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ ref_B(ref_B_), -+ ref_C(ref_C_), -+ ref_D(ref_D_), -+ epilogue(epilogue_), -+ split_k_slices(split_k_slices), -+ convert(convert_), -+ reduction(reduction_) { } -+ }; -+ -+private: -+ -+ /// Kernel parameters object -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmSplitKParallel() { } -+ -+ /// Helper to construct a transposed equivalent for the underying GEMM operator -+ static UnderlyingArguments to_underlying_arguments(Arguments const &args) { -+ return UnderlyingArguments( -+ {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, -+ {args.ref_B.data(), args.ref_B.stride(0)}, -+ {args.ref_A.data(), args.ref_A.stride(0)}, -+ {args.ref_C.data(), args.ref_C.stride(0)}, -+ {args.ref_D.data(), args.ref_D.stride(0)}, -+ args.epilogue, -+ args.split_k_slices, -+ args.convert, -+ args.reduction -+ ); -+ } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. 
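Editor's note: the column-major specialization above (and the analogous ones later in this patch) relies on the identity D^T = (alpha * A * B + beta * C)^T = alpha * B^T * A^T + beta * C^T. Computing D^T in row-major with the operands swapped and transposed produces exactly the memory image of a column-major D, which is why to_underlying_arguments() exchanges m with n and passes ref_B where the underlying operator expects A. A tiny self-contained check of the layout equivalence (illustrative only):

    #include <cassert>

    int main() {
      const int m = 2, n = 3;
      for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
          int col_major_D  = i + j * m;   // D   (m x n), column-major offset of (i, j)
          int row_major_Dt = j * m + i;   // D^T (n x m), row-major offset of (j, i)
          assert(col_major_D == row_major_Dt);  // identical linear layout
        }
      }
      return 0;
    }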
-+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal.h -new file mode 100644 -index 0000000..6c19b8a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal.h -@@ -0,0 +1,420 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/gemm_universal.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_universal.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/gemm/device/gemm_universal_base.h" -+ -+#include "cutlass/layout/permute.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*! 
-+ GemmUniversal is a stateful, reusable GEMM handle. Once initialized for a given GEMM computation -+ (problem geometry and data references), it can be reused across different GEMM problems having the -+ geometry. (Once initialized, details regarding problem geometry and references to workspace memory -+ cannot be updated.) -+ -+ The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and -+ batched array variants. -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassSimt, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag_ = arch::Sm70, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Gather operand A by using an index array -+ bool GatherA = false, -+ /// Gather operand B by using an index array -+ bool GatherB = false, -+ /// Scatter result D by using an index array -+ bool ScatterD = false, -+ /// Permute result D -+ typename PermuteDLayout = layout::NoPermute -+> -+class GemmUniversal : -+ public GemmUniversalBase< -+ typename kernel::DefaultGemmUniversal< -+ ElementA_, -+ LayoutA_, -+ TransformA, -+ AlignmentA, -+ ElementB_, -+ LayoutB_, -+ TransformB, -+ AlignmentB, -+ ElementC_, -+ LayoutC_, -+ 
ElementAccumulator_, -+ OperatorClass_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ EpilogueOutputOp_, -+ ThreadblockSwizzle_, -+ Stages, -+ Operator_, -+ SharedMemoryClearOption::kNone, -+ GatherA, -+ GatherB, -+ ScatterD, -+ PermuteDLayout -+ >::GemmKernel -+ > { -+ -+ public: -+ -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ -+ using Base = GemmUniversalBase< -+ typename kernel::DefaultGemmUniversal< -+ ElementA_, -+ LayoutA_, -+ TransformA, -+ AlignmentA, -+ ElementB_, -+ LayoutB_, -+ TransformB, -+ AlignmentB, -+ ElementC_, -+ LayoutC_, -+ ElementAccumulator_, -+ OperatorClass_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ EpilogueOutputOp_, -+ ThreadblockSwizzle_, -+ Stages, -+ Operator_, -+ SharedMemoryClearOption::kNone, -+ GatherA, -+ GatherB, -+ ScatterD, -+ PermuteDLayout -+ >::GemmKernel -+ >; -+ -+ using Arguments = typename Base::Arguments; -+ using GemmKernel = typename Base::GemmKernel; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for column-major output exchanges problem size and operand. -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. 
-+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB, -+ /// Scatter result D by using an index array -+ bool ScatterD, -+ /// Permute result D -+ typename PermuteDLayout -+> -+class GemmUniversal { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ -+ using UnderlyingOperator = typename GemmUniversal< -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ kAlignmentB, -+ kAlignmentA, -+ Operator, -+ kTransformB, -+ kTransformA, -+ GatherB, -+ GatherA, -+ ScatterD, -+ PermuteDLayout -+ >::Base; -+ -+ using GemmKernel = typename UnderlyingOperator::GemmKernel; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ -+ /// Argument structure -+ using Arguments = typename UnderlyingOperator::Arguments; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmUniversal() { } -+ -+ /// Helper to construct a transposed equivalent for the underying GEMM operator -+ static Arguments to_underlying_arguments(Arguments const &args) { -+ return args.transposed_problem(); -+ } -+ -+ /// Determines whether the GEMM can execute the given problem. 
-+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the grid shape -+ static dim3 get_grid_shape(Arguments const &args) { -+ return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the maximum number of active blocks per multiprocessor -+ static int maximum_active_blocks(int smem_capacity = -1) { -+ return UnderlyingOperator::maximum_active_blocks(smem_capacity); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h -new file mode 100644 -index 0000000..66884fb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h -@@ -0,0 +1,549 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and -+ batched array variants. -+*/ -+ -+#pragma once -+ -+// common -+#include "cutlass/cutlass.h" -+#include "cutlass/trace.h" -+#include "cutlass/cluster_launch.hpp" -+#include "cutlass/device_kernel.h" -+#include "cutlass/gemm/gemm.h" -+ -+// 2.x -+#include "cutlass/gemm/device/gemm_universal_base.h" -+#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+// 3.x -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm::device { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/*! -+ GemmUniversalAdapter is a stateful, reusable GEMM handle built around a kernel -+ of type cutlass::gemm::kernel::Gemm or cutlass::gemm::kernel::GemmUniversal. -+ -+ It manages the lifetime of the underlying `kernel::Params` struct, and exposes APIs -+ to create it from the host facing arguments. For power users, new static methods -+ are exposed in 3.x APIs that bypass the stateful methods or args->params lowering. -+ -+ It supports kernel types that implement both the 2.x and 3.0 APIs, -+ however, this is done by specializing the implementation of GemmUniversalAdapter -+ on the two kernel API types, and thus, GemmUniversalAdapter's behaviour might -+ differ between the two specializations. 
-+*/ -+template -+class GemmUniversalAdapter; -+ -+//////////////////////////////////////////////////////////////////////////////// -+////////////////////////////// CUTLASS 3.x API ///////////////////////////////// -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+class GemmUniversalAdapter< -+ GemmKernel_, -+ std::enable_if_t::value>> -+{ -+public: -+ using GemmKernel = GemmKernel_; -+ using TileShape = typename GemmKernel::TileShape; -+ using ElementA = typename GemmKernel::ElementA; -+ using ElementB = typename GemmKernel::ElementB; -+ using ElementC = typename GemmKernel::ElementC; -+ using ElementAccumulator = typename GemmKernel::TiledMma::ValTypeC; -+ using DispatchPolicy = typename GemmKernel::DispatchPolicy; -+ using CollectiveMainloop = typename GemmKernel::CollectiveMainloop; -+ using CollectiveEpilogue = typename GemmKernel::CollectiveEpilogue; -+ -+ // Map back to 2.x type as best as possible -+ using LayoutA = gemm::detail::StrideToLayoutTagA_t; -+ using LayoutB = gemm::detail::StrideToLayoutTagB_t; -+ using LayoutC = gemm::detail::StrideToLayoutTagC_t; -+ using LayoutD = gemm::detail::StrideToLayoutTagC_t; -+ -+ // NOTE: 3.0 kernels do not support complex transforms for now ... -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+ // Legacy: Assume MultiplyAdd only since we do not use this tag type in 3.0 -+ using MathOperator = cutlass::arch::OpMultiplyAdd; -+ -+ // If our TiledMMA's instruction thread layout size is larger than 1, we know its a tensorop! -+ using OperatorClass = std::conditional_t< -+ (cute::size(typename GemmKernel::TiledMma::AtomThrID{}) > 1), -+ cutlass::arch::OpClassTensorOp, cutlass::arch::OpClassSimt>; -+ -+ using ArchTag = typename GemmKernel::ArchTag; -+ -+ // NOTE: Assume identity swizzle for now -+ static_assert(std::is_void_v, -+ "CUTLASS 3.x kernel types do not support grid swizzle functors yet."); -+ using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+ // Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape -+ using ThreadblockShape = cutlass::gemm::GemmShape< -+ cute::size<0>(TileShape{}), -+ cute::size<1>(TileShape{}), -+ cute::size<2>(TileShape{})>; -+ -+ using ClusterShape = cutlass::gemm::GemmShape< -+ cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), -+ cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), -+ cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})>; -+ -+ // Instruction shape is easy too, since we get that directly from our TiledMma's atom shape -+ using InstructionShape = cutlass::gemm::GemmShape< -+ cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), -+ cute::size<1>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), -+ cute::size<2>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{})>; -+ -+ // Legacy: provide a correct warp count, but no reliable warp shape -+ static int const kThreadCount = GemmKernel::MaxThreadsPerBlock; -+ -+ // Warp shape is not a primary API type in 3.x -+ // But we can best approximate it by inspecting the TiledMma::TiledShape_MNK -+ // For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K -+ // We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads -+ static constexpr int WarpsInMma = std::max(4, cute::size(typename GemmKernel::TiledMma{}) / 32); 
-+ static constexpr int WarpsInMmaM = 4; -+ static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM); -+ using WarpCount = cutlass::gemm::GemmShape; -+ using WarpShape = cutlass::gemm::GemmShape< -+ cute::size<0>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{}) / WarpsInMmaM, -+ cute::size<1>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{}) / WarpsInMmaN, -+ cute::size<2>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{})>; -+ -+ static int constexpr kStages = CollectiveMainloop::DispatchPolicy::Stages; -+ -+ // Inspect TiledCopy for A and B to compute the alignment size -+ static int constexpr kAlignmentA = gemm::detail::get_alignment_count_from_gmem_tiled_copy< -+ typename CollectiveMainloop::GmemTiledCopyA, ElementA>(); -+ static int constexpr kAlignmentB = gemm::detail::get_alignment_count_from_gmem_tiled_copy< -+ typename CollectiveMainloop::GmemTiledCopyB, ElementB>(); -+ -+ // NOTE: 3.0 DefaultEpilogues don't support vectorized stores (yet) -+ static int constexpr kAlignmentC = 1; -+ static int constexpr kAlignmentD = 1; -+ -+ using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp; -+ -+ // Split-K preserves splits that are 128b aligned -+ static int constexpr kSplitKAlignment = std::max( -+ 128 / sizeof_bits::value, 128 / sizeof_bits::value); -+ -+ /// Argument structure: User API -+ using Arguments = typename GemmKernel::Arguments; -+ /// Argument structure: Kernel API -+ using Params = typename GemmKernel::Params; -+ -+private: -+ -+ /// Kernel API parameters object -+ Params params_; -+ -+public: -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status -+ can_implement(Arguments const& args) { -+ if (GemmKernel::can_implement(args)) { -+ return Status::kSuccess; -+ } -+ else { -+ return Status::kInvalid; -+ } -+ } -+ -+ /// Gets the workspace size -+ static size_t -+ get_workspace_size(Arguments const& args) { -+ size_t workspace_bytes = 0; -+ if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ workspace_bytes += sizeof(int) * size_t(cute::size<0>(TileShape{})) * size_t(cute::size<1>(TileShape{})); -+ } -+ -+ CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); -+ -+ workspace_bytes += GemmKernel::get_workspace_size(args); -+ return workspace_bytes; -+ } -+ -+ /// Computes the grid shape -+ static dim3 -+ get_grid_shape(Arguments const& args) { -+ auto tmp_params = GemmKernel::to_underlying_arguments(args); -+ return GemmKernel::get_grid_shape(tmp_params); -+ } -+ -+ /// Computes the grid shape -+ static dim3 -+ get_grid_shape(Params const& params) { -+ return GemmKernel::get_grid_shape(params); -+ } -+ -+ /// Computes the maximum number of active blocks per multiprocessor -+ static int maximum_active_blocks(int /* smem_capacity */ = -1) { -+ CUTLASS_TRACE_HOST("GemmUniversal::maximum_active_blocks()"); -+ int max_active_blocks = -1; -+ int smem_size = GemmKernel::SharedStorageSize; -+ -+ // first, account for dynamic smem capacity if needed -+ cudaError_t result; -+ if (smem_size >= (48 << 10)) { -+ CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); -+ result = cudaFuncSetAttribute( -+ device_kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ if (cudaSuccess != result) { -+ result = cudaGetLastError(); // to clear the error bit -+ CUTLASS_TRACE_HOST( -+ " cudaFuncSetAttribute() returned error: " -+ << cudaGetErrorString(result)); -+ return -1; -+ } -+ } -+ -+ // query occupancy after setting smem size -+ result = 
cudaOccupancyMaxActiveBlocksPerMultiprocessor( -+ &max_active_blocks, -+ device_kernel, -+ GemmKernel::MaxThreadsPerBlock, -+ smem_size); -+ -+ if (cudaSuccess != result) { -+ result = cudaGetLastError(); // to clear the error bit -+ CUTLASS_TRACE_HOST( -+ " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " -+ << cudaGetErrorString(result)); -+ return -1; -+ } -+ -+ CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); -+ return max_active_blocks; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status -+ initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { -+ CUTLASS_TRACE_HOST("GemmUniversal::initialize() - workspace " -+ << workspace << ", stream: " << (stream ? "non-null" : "null")); -+ -+ size_t workspace_bytes = GemmKernel::get_workspace_size(args); -+ CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); -+ -+ if (workspace_bytes) { -+ if (!workspace) { -+ CUTLASS_TRACE_HOST(" error: device workspace must not be null"); -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ if (args.mode == GemmUniversalMode::kGemm) { -+ CUTLASS_TRACE_HOST(" clearing device workspace"); -+ cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); -+ if (cudaSuccess != result) { -+ result = cudaGetLastError(); // to clear the error bit -+ CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ -+ // Initialize the Params structure -+ params_ = GemmKernel::to_underlying_arguments(args, workspace); -+ -+ // account for dynamic smem capacity if needed -+ int smem_size = GemmKernel::SharedStorageSize; -+ if (smem_size >= (48 << 10)) { -+ CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); -+ cudaError_t result = cudaFuncSetAttribute( -+ device_kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ if (cudaSuccess != result) { -+ result = cudaGetLastError(); // to clear the error bit -+ CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); -+ return Status::kErrorInternal; -+ } -+ } -+ return Status::kSuccess; -+ } -+ -+ /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. -+ Status -+ update(Arguments const& args, void* workspace = nullptr) { -+ CUTLASS_TRACE_HOST("GemmUniversal()::update() - workspace: " << workspace); -+ -+ size_t workspace_bytes = get_workspace_size(args); -+ if (workspace_bytes > 0 && nullptr == workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ params_ = GemmKernel::to_underlying_arguments(args, workspace); -+ return Status::kSuccess; -+ } -+ -+ /// Primary run() entry point API that is static allowing users to create and manage their own params. 
-+ /// Supplied params struct must be construct by calling GemmKernel::to_underling_arguments() -+ static Status -+ run(Params& params, cudaStream_t stream = nullptr) { -+ CUTLASS_TRACE_HOST("GemmUniversal::run()"); -+ dim3 constexpr block = GemmKernel::get_block_shape(); -+ dim3 const grid = get_grid_shape(params); -+ -+ // configure smem size and carveout -+ int smem_size = GemmKernel::SharedStorageSize; -+ -+ Status launch_result; -+ // Use extended launch API only for mainloops that use it -+ if constexpr(GemmKernel::ArchTag::kMinComputeCapability >= 90) { -+ dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), -+ cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), -+ cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); -+ void const* kernel = (void const*) device_kernel; -+ void* kernel_params[] = {¶ms}; -+ launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); -+ } -+ else { -+ launch_result = Status::kSuccess; -+ device_kernel<<>>(params); -+ } -+ -+ cudaError_t result = cudaGetLastError(); -+ if (cudaSuccess == result && Status::kSuccess == launch_result) { -+ return Status::kSuccess; -+ } -+ else { -+ CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); -+ return Status::kErrorInternal; -+ } -+ } -+ -+ // -+ // Non-static launch overloads that first create and set the internal params struct of this kernel handle. -+ // -+ -+ /// Launches the kernel after first constructing Params internal state from supplied arguments. -+ Status -+ run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { -+ Status status = initialize(args, workspace, stream); -+ if (Status::kSuccess == status) { -+ status = run(params_, stream); -+ } -+ return status; -+ } -+ -+ /// Launches the kernel after first constructing Params internal state from supplied arguments. -+ Status -+ operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { -+ return run(args, workspace, stream); -+ } -+ -+ /// Overload that allows a user to re-launch the same kernel without updating internal params struct. -+ Status -+ run(cudaStream_t stream = nullptr) { -+ return run(params_, stream); -+ } -+ -+ /// Overload that allows a user to re-launch the same kernel without updating internal params struct. 
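Editor's note: the 3.x adapter above therefore supports two launch idioms: the stateful path, where initialize() caches the lowered Params and run(stream) re-launches, and the static power-user path, where the caller lowers Arguments to Params directly and invokes the static run(). A sketch of both follows; MyGemmKernel stands in for a concrete 3.x kernel type and is purely hypothetical.

    #include "cutlass/gemm/device/gemm_universal_adapter.h"

    using Gemm = cutlass::gemm::device::GemmUniversalAdapter<MyGemmKernel>;  // placeholder kernel

    cutlass::Status launch_stateful(Gemm &gemm, Gemm::Arguments const &args,
                                    void *workspace, cudaStream_t stream) {
      // Arguments are lowered to Params once; later calls can reuse run(stream).
      cutlass::Status status = gemm.initialize(args, workspace, stream);
      if (status != cutlass::Status::kSuccess) {
        return status;
      }
      return gemm.run(stream);
    }

    cutlass::Status launch_static(Gemm::Arguments const &args, void *workspace,
                                  cudaStream_t stream) {
      // Power-user path: bypass the stateful handle entirely.
      Gemm::Params params = MyGemmKernel::to_underlying_arguments(args, workspace);
      return Gemm::run(params, stream);
    }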
-+ Status -+ operator()(cudaStream_t stream = nullptr) const { -+ return run(params_, stream); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+////////////////////////////// CUTLASS 2.x API ///////////////////////////////// -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+class GemmUniversalAdapter< -+ GemmKernel_, -+ std::enable_if_t::value>> -+{ -+public: -+ -+ using GemmKernel = GemmKernel_; -+ -+ static bool const kInternalTranspose = -+ platform::is_same::value; -+ -+ using ThreadblockShape = typename GemmKernel::Mma::Shape; -+ using WarpShape = typename GemmKernel::WarpShape; -+ using InstructionShape = typename GemmKernel::InstructionShape; -+ -+ // warp-level, arch-level (instruction), math operator -+ using WarpMmaOperator = typename GemmKernel::Mma::Policy::Operator; -+ using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; -+ using MathOperator = typename WarpMmaOperator::MathOperator; -+ -+ // Operator class and arch tag extract bottom-up -+ // set it for top-level gemm device-level template -+ using OperatorClass = typename WarpMmaOperator::OperatorClass; -+ using ArchTag = typename WarpMmaOperator::ArchTag; -+ -+ // Type, layout, and complex transform deliberately exchanged with B -+ using MapArguments = kernel::detail::MapArguments< -+ typename GemmKernel::ElementA, -+ typename GemmKernel::LayoutA, -+ GemmKernel::kTransformA, -+ GemmKernel::kAlignmentA, -+ typename GemmKernel::ElementB, -+ typename GemmKernel::LayoutB, -+ GemmKernel::kTransformB, -+ GemmKernel::kAlignmentB, -+ typename GemmKernel::LayoutC, -+ kInternalTranspose -+ >; -+ -+ using ElementA = typename MapArguments::ElementA; -+ using LayoutA = typename MapArguments::LayoutA; -+ static ComplexTransform const kTransformA = MapArguments::kTransformA; -+ static int const kAlignmentA = MapArguments::kAlignmentA; -+ -+ using ElementB = typename MapArguments::ElementB; -+ using LayoutB = typename MapArguments::LayoutB; -+ static ComplexTransform const kTransformB = MapArguments::kTransformB; -+ static int const kAlignmentB = MapArguments::kAlignmentB; -+ -+ using ElementC = typename GemmKernel::ElementC; -+ using LayoutC = typename MapArguments::LayoutC; -+ static int const kAlignmentC = GemmKernel::kAlignmentC; -+ -+ using TensorRefA = TensorRef; -+ using TensorRefB = TensorRef; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ -+ static int const kStages = GemmKernel::Mma::kStages; -+ -+ using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; -+ using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; -+ using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; -+ using UnderlyingOperator = GemmUniversalBase; -+ using Arguments = typename UnderlyingOperator::Arguments; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmUniversalAdapter() { } -+ -+ /// Helper to construct a transposed equivalent for the underying GEMM operator -+ static Arguments to_underlying_arguments(Arguments const &args) { -+ if (kInternalTranspose) { -+ return args.transposed_problem(); -+ } -+ else { -+ return args; -+ } -+ } -+ -+ /// Determines whether the GEMM can execute the given problem. 
-+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the grid shape -+ static dim3 get_grid_shape(Arguments const &args) { -+ return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the maximum number of active blocks per multiprocessor -+ static int maximum_active_blocks(int smem_capacity = -1) { -+ return UnderlyingOperator::maximum_active_blocks(smem_capacity); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); -+ } -+ -+ /// Lightweight update given a subset of arguments. Problem geometry is assumed to -+ /// remain the same. -+ Status update(Arguments const &args) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args)); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::device -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal_base.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal_base.h -new file mode 100644 -index 0000000..cca768a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal_base.h -@@ -0,0 +1,416 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief The universal GEMM accommodates streamk, batched strided, and batched array variants. -+*/ -+ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_universal.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_universal.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+#include "cutlass/trace.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+template -+class GemmUniversalBase { -+public: -+ -+ using GemmKernel = GemmKernel_; -+ using ThreadblockShape = typename GemmKernel::Mma::Shape; -+ -+ using ElementA = typename GemmKernel::ElementA; -+ using LayoutA = typename GemmKernel::LayoutA; -+ using TensorRefA = TensorRef; -+ static ComplexTransform const kTransformA = GemmKernel::kTransformA; -+ -+ using ElementB = typename GemmKernel::ElementB; -+ using LayoutB = typename GemmKernel::LayoutB; -+ using TensorRefB = TensorRef; -+ static ComplexTransform const kTransformB = GemmKernel::kTransformB; -+ -+ using ElementC = typename GemmKernel::ElementC; -+ using LayoutC = typename GemmKernel::LayoutC; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ -+ /// Numerical accumulation element type -+ using ElementAccumulator = typename GemmKernel::Mma::ElementC; -+ -+ using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; -+ using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; -+ using Operator = typename GemmKernel::Operator; -+ -+ /// Argument structure -+ using Arguments = typename GemmKernel::Arguments; -+ -+protected: -+ -+ // -+ // Device properties (uniform across all instances of the current thread) -+ // -+ -+ // Device ordinal -+ thread_local static int device_ordinal_; -+ -+ /// Device SM count -+ thread_local static int device_sms_; -+ -+ /// Kernel SM occupancy (in thread blocks) -+ thread_local static int sm_occupancy_; -+ -+ /// Kernel dynamic shared memory allocation requirement -+ thread_local static int smem_size_; -+ -+ /// Initialize static thread-local members for the thread's current device, -+ /// if necessary. 
-+ static Status init_device_props() -+ { -+ CUTLASS_TRACE_HOST("GemmUniversalBase::init_device_props()"); -+ -+ cudaError_t cudart_result; -+ -+ // Get current device ordinal -+ int current_ordinal; -+ cudart_result = cudaGetDevice(¤t_ordinal); -+ if (cudart_result != cudaSuccess) { -+ CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(cudart_result)); -+ return Status::kErrorInternal; -+ } -+ -+ // Done if matches the current static member -+ if (current_ordinal == device_ordinal_) { -+ // Already initialized -+ return Status::kSuccess; -+ } -+ -+ // Update SM count member -+ cudart_result = cudaDeviceGetAttribute (&device_sms_, cudaDevAttrMultiProcessorCount, current_ordinal); -+ if (cudart_result != cudaSuccess) { -+ CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(cudart_result)); -+ return Status::kErrorInternal; -+ } -+ -+ // Update the kernel function's shared memory configuration for the current device -+ smem_size_ = int(sizeof(typename GemmKernel::SharedStorage)); -+ -+ // If requires more than 48KB: configure for extended, dynamic shared memory -+ if (smem_size_ >= (48 << 10)) -+ { -+ cudart_result = cudaFuncSetAttribute( -+ Kernel2, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size_); -+ if (cudart_result != cudaSuccess) { -+ CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result)); -+ return Status::kErrorInternal; -+ } -+ -+ cudart_result = cudaFuncSetAttribute( -+ Kernel2, -+ cudaFuncAttributePreferredSharedMemoryCarveout, 100); // 100% shared memory -+ if (cudart_result != cudaSuccess) { -+ CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result)); -+ return Status::kErrorInternal; -+ } -+ } -+ -+ // Update SM occupancy member -+ cudart_result = cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( -+ &sm_occupancy_, -+ Kernel2, -+ GemmKernel::kThreadCount, -+ smem_size_, -+ cudaOccupancyDisableCachingOverride); -+ if (cudart_result != cudaSuccess) { -+ CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned error " << cudaGetErrorString(cudart_result)); -+ return Status::kErrorInternal; -+ } -+ -+ // Update device ordinal member on success -+ device_ordinal_ = current_ordinal; -+ -+ CUTLASS_TRACE_HOST(" " -+ "device_ordinal: (" << device_ordinal_ << "), " -+ "device_sms: (" << device_sms_ << "), " -+ "sm_occupancy: (" << sm_occupancy_ << ") " -+ "smem_size: (" << smem_size_ << ") " -+ "GemmKernel::kThreadCount: (" << GemmKernel::kThreadCount << ")"); -+ -+ return Status::kSuccess; -+ } -+ -+ -+protected: -+ -+ // -+ // Instance data members -+ // -+ -+ /// Kernel parameters -+ typename GemmKernel::Params params_; -+ -+ -+ /// Initialize params member -+ Status init_params(Arguments const &args) -+ { -+ // Initialize static device properties, if necessary -+ Status result = init_device_props(); -+ if (result != Status::kSuccess) { -+ return result; -+ } -+ -+ // Initialize params member -+ params_ = typename GemmKernel::Params(args, device_sms_, sm_occupancy_); -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ //--------------------------------------------------------------------------------------------- -+ // Stateless API -+ //--------------------------------------------------------------------------------------------- -+ -+ /// Determines whether the GEMM can execute the given problem. 
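
// The same device-property bootstrap as init_device_props() above, shown on a stand-alone
// kernel: dynamic shared-memory requests above 48 KB need an explicit opt-in through
// cudaFuncSetAttribute before the first launch, and occupancy is queried with
// cudaOccupancyDisableCachingOverride, matching the code above. big_smem_kernel and its
// configuration are an illustrative sketch, not part of the patch.
#include <cuda_runtime.h>

__global__ void big_smem_kernel(int n) {
  extern __shared__ char dynamic_smem[];
  if (threadIdx.x < n) { dynamic_smem[threadIdx.x] = char(threadIdx.x); }
}

cudaError_t configure_big_smem_kernel(int smem_size, int threads_per_block, int* sm_occupancy) {
  if (smem_size >= (48 << 10)) {
    // Opt in to the extended (device-dependent) dynamic shared-memory limit.
    cudaError_t result = cudaFuncSetAttribute(
        big_smem_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    if (result != cudaSuccess) { return result; }
  }
  // How many thread blocks of this shape fit on one SM, given the requested shared memory.
  return cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
      sm_occupancy, big_smem_kernel, threads_per_block, smem_size,
      cudaOccupancyDisableCachingOverride);
}
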
-+ static Status can_implement(Arguments const &args) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversalBase::can_implement()"); -+ -+ // Initialize static kernel and device properties, if necessary. -+ Status result = init_device_props(); -+ if (result != Status::kSuccess) { -+ return result; -+ } -+ -+ dim3 grid = get_grid_shape(args); -+ -+ if (!(grid.y <= std::numeric_limits::max() && -+ grid.z <= std::numeric_limits::max())) -+ { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return GemmKernel::can_implement(args); -+ } -+ -+ -+ /// Returns the workspace size (in bytes) needed for the problem -+ /// geometry expressed by these arguments -+ static size_t get_workspace_size(Arguments const &args) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversalBase::get_workspace_size()"); -+ -+ // Initialize parameters from args -+ GemmUniversalBase base; -+ if (base.init_params(args) != Status::kSuccess) { -+ return 0; -+ } -+ -+ // Get size from parameters -+ size_t workspace_bytes = base.params_.get_workspace_size(); -+ -+ CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); -+ return workspace_bytes; -+ } -+ -+ -+ /// Returns the grid extents in thread blocks to launch -+ static dim3 get_grid_shape(Arguments const &args) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversalBase::get_grid_shape()"); -+ -+ // Initialize parameters from args -+ GemmUniversalBase base; -+ if (base.init_params(args) != Status::kSuccess) { -+ return dim3(0,0,0); -+ } -+ -+ // Get dims from parameters -+ dim3 grid_dims = base.params_.get_grid_dims(); -+ -+ CUTLASS_TRACE_HOST( -+ " tiled_shape: " << base.params_.get_tiled_shape() << "\n" -+ << " grid_dims: {" << grid_dims << "}"); -+ -+ return grid_dims; -+ } -+ -+ -+ /// Returns the maximum number of active thread blocks per multiprocessor -+ static int maximum_active_blocks() -+ { -+ CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()"); -+ -+ // Initialize static device properties, if necessary -+ if (init_device_props() != Status::kSuccess) { -+ return -1; -+ } -+ -+ CUTLASS_TRACE_HOST(" max_active_blocks: " << sm_occupancy_); -+ return sm_occupancy_; -+ } -+ -+ -+ //--------------------------------------------------------------------------------------------- -+ // Stateful API -+ //--------------------------------------------------------------------------------------------- -+ -+ /// Initializes GEMM state from arguments and workspace memory -+ Status initialize( -+ Arguments const &args, -+ void *workspace, -+ cudaStream_t stream = nullptr) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace " -+ << workspace << ", stream: " << (stream ? "non-null" : "null")); -+ -+ // Initialize parameters from args -+ Status result = init_params(args); -+ if (result != Status::kSuccess) { -+ return result; -+ } -+ -+ // Assign and prepare workspace memory -+ return params_.init_workspace(workspace, stream); -+ } -+ -+ -+ /// Lightweight update given a subset of arguments. Problem geometry is assumed to -+ /// remain the same. -+ Status update(Arguments const &args) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversalBase()::update()"); -+ params_.update(args); -+ return Status::kSuccess; -+ } -+ -+ -+ /// Runs the kernel using initialized state. 
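
// Host-side flow implied by the stateless/stateful split above, as a minimal sketch.
// `DeviceGemm` stands for any concrete GemmUniversalBase<...> instantiation and `args`
// for a fully populated Arguments struct; the helper name run_gemm is assumed, not
// taken from the patch.
#include <cuda_runtime.h>
#include "cutlass/cutlass.h"

template <typename DeviceGemm>
cutlass::Status run_gemm(typename DeviceGemm::Arguments const& args, cudaStream_t stream) {
  // Stateless checks first: reject unsupported problems before allocating anything.
  cutlass::Status status = DeviceGemm::can_implement(args);
  if (status != cutlass::Status::kSuccess) { return status; }

  // Workspace is sized from the arguments and handed to initialize() along with them.
  void* workspace = nullptr;
  size_t workspace_bytes = DeviceGemm::get_workspace_size(args);
  if (workspace_bytes && cudaMalloc(&workspace, workspace_bytes) != cudaSuccess) {
    return cutlass::Status::kErrorInternal;
  }

  DeviceGemm gemm_op;
  status = gemm_op.initialize(args, workspace, stream);
  if (status == cutlass::Status::kSuccess) {
    status = gemm_op.run(stream);   // may be re-run later without re-initializing
  }

  cudaFree(workspace);
  return status;
}
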
-+ Status run(cudaStream_t stream = nullptr) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversalBase::run()"); -+ -+ // Configure grid and block dimensions -+ dim3 block(GemmKernel::kThreadCount, 1, 1); -+ dim3 grid = params_.get_grid_dims(); -+ -+ // Launch kernel -+ CUTLASS_TRACE_HOST(" " -+ "grid: (" << grid << "), " -+ "block: (" << block << "), " -+ "SMEM: (" << smem_size_ << ")"); -+ -+ Kernel2<<>>(params_); -+ -+ // Query for errors -+ cudaError_t result = cudaGetLastError(); -+ if (result != cudaSuccess) { -+ CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); -+ return Status::kErrorInternal; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) -+ { -+ return run(stream); -+ } -+ -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) -+ { -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Static initializers -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Device ordinal -+template -+thread_local int GemmUniversalBase::device_ordinal_ = -1; -+ -+/// Device SM count -+template -+thread_local int GemmUniversalBase::device_sms_ = -1; -+ -+/// Kernel SM occupancy (in thread blocks) -+template -+thread_local int GemmUniversalBase::sm_occupancy_ = -1; -+ -+/// Kernel dynamic shared memory allocation requirement -+template -+thread_local int GemmUniversalBase::smem_size_ = -1; -+ -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal_with_broadcast.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal_with_broadcast.h -new file mode 100644 -index 0000000..34b3f6c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal_with_broadcast.h -@@ -0,0 +1,386 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Template for a GEMM kernel that can broadcast bias vector in the -+ epigloue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/gemm_universal.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_universal.h" -+#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/gemm/device/gemm_universal_base.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*! -+ The universal GEMM with a broadcast epilogue. -+ Supports -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassSimt, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. 
-+ typename ArchTag_ = arch::Sm70, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' -+ typename EpilogueOutputOp_ = cutlass::epilogue::thread::LinearCombinationBiasElementwise< -+ ElementC_, ElementAccumulator_, ElementAccumulator_, -+ ElementC_, ElementC_, 128 / cutlass::sizeof_bits::value>, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB = ComplexTransform::kNone -+> -+class GemmUniversalWithBroadcast : -+ public GemmUniversalBase< -+ typename kernel::DefaultGemmWithBroadcast< -+ ElementA_, -+ LayoutA_, -+ TransformA, -+ AlignmentA, -+ ElementB_, -+ LayoutB_, -+ TransformB, -+ AlignmentB, -+ ElementC_, -+ LayoutC_, -+ ElementAccumulator_, -+ OperatorClass_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ EpilogueOutputOp_, -+ ThreadblockSwizzle_, -+ Stages, -+ Operator_ -+ >::GemmKernel -+ > { -+ -+ public: -+ -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ -+ using Base = GemmUniversalBase< -+ typename kernel::DefaultGemmWithBroadcast< -+ ElementA_, -+ LayoutA_, -+ TransformA, -+ AlignmentA, -+ ElementB_, -+ LayoutB_, -+ TransformB, -+ AlignmentB, -+ ElementC_, -+ LayoutC_, -+ ElementAccumulator_, -+ OperatorClass_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ EpilogueOutputOp_, -+ ThreadblockSwizzle_, -+ Stages, -+ Operator_ -+ >::GemmKernel -+ >; -+ -+ using Arguments = typename Base::Arguments; -+ using GemmKernel 
= typename Base::GemmKernel; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for column-major output exchanges problem size and operand. -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB> -+class GemmUniversalWithBroadcast { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ -+ using UnderlyingOperator = typename GemmUniversalWithBroadcast< -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ kAlignmentB, -+ kAlignmentA, -+ Operator, -+ kTransformB, -+ kTransformA -+ >::Base; -+ -+ using GemmKernel = typename UnderlyingOperator::GemmKernel; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ -+ /// Argument structure -+ using Arguments = typename UnderlyingOperator::Arguments; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the GEMM. 
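
// Why the column-major specialization above can simply transpose the problem: a
// column-major C is the row-major storage of C^T, and C = A*B implies C^T = B^T * A^T,
// so the underlying row-major operator runs with A and B exchanged and their layouts
// transposed. A small numeric check of that identity on plain host arrays (illustrative only):
#include <cassert>
#include <vector>

// D (MxN), A (MxK), B (KxN), all row-major.
static void gemm_row_major(int M, int N, int K, float const* A, float const* B, float* D) {
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) acc += A[m * K + k] * B[k * N + n];
      D[m * N + n] = acc;
    }
}

// D (MxN), A (MxK), B (KxN), all column-major: a column-major X is bit-identical to a
// row-major X^T, so run the row-major kernel on the exchanged (N x M x K) problem.
static void gemm_col_major(int M, int N, int K, float const* A, float const* B, float* D) {
  gemm_row_major(N, M, K, /*A=*/B, /*B=*/A, D);
}

int main() {
  int const M = 2, N = 3, K = 4;
  std::vector<float> A(M * K), B(K * N), D(M * N), ref(M * N);
  for (size_t i = 0; i < A.size(); ++i) A[i] = float(i + 1);
  for (size_t i = 0; i < B.size(); ++i) B[i] = float(2 * i) - 3.f;

  gemm_col_major(M, N, K, A.data(), B.data(), D.data());

  // Direct column-major reference: D(m,n) = sum_k A(m,k) * B(k,n).
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) acc += A[m + k * M] * B[k + n * K];
      ref[m + n * M] = acc;
    }
  for (int i = 0; i < M * N; ++i) assert(D[i] == ref[i]);
  return 0;
}
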
-+ GemmUniversalWithBroadcast() { } -+ -+ /// Helper to construct a transposed equivalent for the underying GEMM operator -+ static Arguments to_underlying_arguments(Arguments const &args) { -+ return args.transposed_problem(); -+ } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the grid shape -+ static dim3 get_grid_shape(Arguments const &args) { -+ return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the maximum number of active blocks per multiprocessor -+ static int maximum_active_blocks(int smem_capacity = -1) { -+ return UnderlyingOperator::maximum_active_blocks(smem_capacity); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_with_k_reduction.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_with_k_reduction.h -new file mode 100644 -index 0000000..c671d7c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_with_k_reduction.h -@@ -0,0 +1,415 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a GEMM kernel that can reduce one of the input matrix -+ into a vector along the K dimension. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/gemm_with_k_reduction.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_with_k_reduction.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/gemm/device/gemm_universal_base.h" -+ -+#include "cutlass/layout/permute.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*! -+ The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and -+ batched array variants. -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassSimt, -+ /// Reduce A or B operand along the K dimension -+ bool ReduceKForA_ = true, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. 
-+ typename ArchTag_ = arch::Sm70, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Gather operand A by using an index array -+ bool GatherA = false, -+ /// Gather operand B by using an index array -+ bool GatherB = false, -+ /// Scatter result D by using an index array -+ bool ScatterD = false, -+ /// Permute result D -+ typename PermuteDLayout = layout::NoPermute -+> -+class GemmWithKReduction : -+ public GemmUniversalBase< -+ typename kernel::DefaultGemmWithKReduction< -+ ElementA_, -+ LayoutA_, -+ TransformA, -+ AlignmentA, -+ ElementB_, -+ LayoutB_, -+ TransformB, -+ AlignmentB, -+ ElementC_, -+ LayoutC_, -+ ElementAccumulator_, -+ OperatorClass_, -+ ReduceKForA_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ EpilogueOutputOp_, -+ ThreadblockSwizzle_, -+ Stages, -+ Operator_, -+ SharedMemoryClearOption::kNone -+ >::GemmKernel -+ > { -+ -+ public: -+ -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static constexpr int kStages = Stages; -+ static constexpr int kAlignmentA = AlignmentA; -+ static constexpr int kAlignmentB = AlignmentB; -+ static constexpr int kAlignmentC = EpilogueOutputOp::kCount; -+ static constexpr ComplexTransform kTransformA = TransformA; -+ static constexpr ComplexTransform kTransformB = TransformB; -+ -+ using Base = GemmUniversalBase< -+ typename kernel::DefaultGemmWithKReduction< -+ ElementA_, -+ LayoutA_, -+ TransformA, -+ AlignmentA, -+ ElementB_, -+ LayoutB_, -+ TransformB, -+ AlignmentB, -+ ElementC_, -+ LayoutC_, -+ 
ElementAccumulator_, -+ OperatorClass_, -+ ReduceKForA_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ EpilogueOutputOp_, -+ ThreadblockSwizzle_, -+ Stages, -+ Operator_, -+ SharedMemoryClearOption::kNone -+ >::GemmKernel -+ >; -+ -+ using Arguments = typename Base::Arguments; -+ using GemmKernel = typename Base::GemmKernel; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for column-major output exchanges problem size and operand. -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Reduce A or B operand along the K dimension -+ bool ReduceKForA_, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB, -+ /// Scatter result D by using an index array -+ bool ScatterD, -+ /// Permute result D -+ typename PermuteDLayout -+> -+class GemmWithKReduction { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ -+ using UnderlyingOperator = typename GemmWithKReduction< -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementA, -+ typename 
layout::LayoutTranspose::type, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ !ReduceKForA_, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ kAlignmentB, -+ kAlignmentA, -+ Operator, -+ kTransformB, -+ kTransformA, -+ GatherB, -+ GatherA, -+ ScatterD, -+ PermuteDLayout -+ >::Base; -+ -+ using GemmKernel = typename UnderlyingOperator::GemmKernel; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ -+ /// Argument structure -+ using Arguments = typename UnderlyingOperator::Arguments; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmWithKReduction() = default; -+ -+ /// Helper to construct a transposed equivalent for the underying GEMM operator -+ static Arguments to_underlying_arguments(Arguments const &args) { -+ return args.transposed_problem(); -+ } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the grid shape -+ static dim3 get_grid_shape(Arguments const &args) { -+ return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the maximum number of active blocks per multiprocessor -+ static int maximum_active_blocks(int smem_capacity = -1) { -+ return UnderlyingOperator::maximum_active_blocks(smem_capacity); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemv.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemv.h -new file mode 100644 -index 0000000..c62168f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemv.h -@@ -0,0 +1,174 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/gemm_universal.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_universal.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/gemm/device/gemm_universal_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class Gemv { -+public: -+ -+ using GemvKernel = GemvKernel_; -+ -+ -+ using ElementA = typename GemvKernel::ElementA; -+ using LayoutA = typename GemvKernel::LayoutA; -+ using ElementB = typename GemvKernel::ElementB; -+ using ElementC = typename GemvKernel::ElementC; -+ -+ using ElementAccumulator = typename GemvKernel::ElementAccumulator; -+ using EpilogueOutputOp = typename GemvKernel::EpilogueOutputOp; -+ -+ static ComplexTransform const kTransformA = GemvKernel::kTransformA; -+ static ComplexTransform const kTransformB = GemvKernel::kTransformB; -+ -+ static int const kThreadCount = GemvKernel::kThreadCount; -+ static int const kStages = GemvKernel::kStages; -+ -+ static int const kAlignmentA = GemvKernel::kAlignmentA; -+ static int const kAlignmentB = GemvKernel::kAlignmentB; -+ static int const kAlignmentC = GemvKernel::kAlignmentC; -+ -+ using Arguments = typename GemvKernel::Arguments; -+ using Params = typename GemvKernel::Params; -+ -+private: -+ -+ Params params_; -+ -+public: -+ -+ /// Constructs the Gemv. 
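
// Grid-shape arithmetic used by the Gemv handle below: one thread per output row, rounded
// up to whole blocks of kThreadCount threads, with the batch index carried in grid.z.
// The thread count of 256 and the explicit 65535 cap (the hardware limit on grid.z) are
// assumptions for this sketch; the real values come from the GemvKernel and its arguments.
#include <cassert>
#include <cuda_runtime.h>

inline dim3 gemv_grid_shape(int rows, int batch_count, int thread_count = 256) {
  unsigned int blocks_for_rows = (rows + thread_count - 1) / thread_count;  // ceil-div
  assert(batch_count <= 65535 && "batches beyond the grid.z limit need another mapping");
  return dim3(blocks_for_rows, 1, static_cast<unsigned int>(batch_count));
}

// Example: 10000 rows, 8 batches, 256 threads per block -> grid = (40, 1, 8).
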
-+ Gemv() { } -+ -+ /// Determines whether the Gemv can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return GemvKernel::can_implement(args); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return 0; -+ } -+ -+ /// Computes the grid shape -+ static dim3 get_grid_shape(Arguments const &args) { -+ return dim3((args.problem_size.row() + (kThreadCount - 1)) / kThreadCount, 1, args.batch_count % 65565); -+ } -+ -+ /// Initializes Gemv state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ params_ = Params(args); -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ return params_.update(args); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ dim3 grid = get_grid_shape(params_); -+ dim3 block(GemvKernel::kThreadCount, 1, 1); -+ -+ int smem_size = int(sizeof(typename GemvKernel::SharedStorage)); -+ -+ // Launch -+ cutlass::Kernel<<>>(params_); -+ -+ // -+ // Query for errors -+ // -+ cudaError_t result = cudaGetLastError(); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/rank_2k.h b/3rdparty/cutlass/include/cutlass/gemm/device/rank_2k.h -new file mode 100644 -index 0000000..d333ffa ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/rank_2k.h -@@ -0,0 +1,547 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined Rank2K kernel. Does not compute batching or support split-K. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/rank_2k_universal.h" -+ -+#include "cutlass/gemm/kernel/default_rank_2k_universal.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassTensorOp, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_ = arch::Sm80, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = -+ typename threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ 
int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial = false, -+ /// Operation performed by SYRK -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator, -+ /// Complex elementwise transformation -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex elementwise transformation -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Blas3 computation mode (symmetric/hermitian) -+ BlasMode BlasMode_ = BlasMode::kSymmetric> -+class Rank2K { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static FillMode const kFillModeC = FillModeC; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ static BlasMode const kBlasMode = BlasMode_; -+ static int const kUpdateRank = 2; -+ -+ // static asserts for rank 2k update kernel -+ static_assert(platform::is_same::value, -+ "Rank 2K update operator support same layouts for operandA and B"); -+ -+ /// Define the kernel -+ using Rank2Kkernel = typename kernel::DefaultRank2KUniversal< -+ ElementA, -+ LayoutA, -+ kTransformA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kTransformB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ kFillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kSplitKSerial, -+ Operator, -+ kBlasMode -+ >::Rank2Kkernel; -+ -+ using Arguments = typename Rank2Kkernel::Arguments; -+ -+private: -+ -+ /// Kernel parameters object -+ typename Rank2Kkernel::Params params_; -+public: -+ -+ /// Constructs the SYRK. -+ Rank2K() { } -+ -+ /// Determines whether the SYRK can execute the given problem. 
-+ static Status can_implement(Arguments const &args) { -+ -+ if (!kSplitKSerial && args.batch_count > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = Rank2Kkernel::can_implement(args); -+ -+ if (FillModeC != FillMode::kLower && FillModeC != FillMode::kUpper) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ if (kSplitKSerial && args.batch_count > 1) { -+ -+ bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); -+ } -+ -+ return bytes; -+ } -+ -+ /// Initializes SYRK state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ if (kSplitKSerial) { -+ if (args.batch_count > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ size_t bytes = get_workspace_size(args); -+ -+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ else { -+ -+ if (args.batch_count > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ int gemm_k_size = args.problem_size.k(); -+ -+ // Initialize the Params structure -+ params_ = typename Rank2Kkernel::Params{ -+ args, -+ grid_tiled_shape, -+ gemm_k_size, -+ static_cast(workspace) -+ }; -+ -+ int smem_size = int(sizeof(typename Rank2Kkernel::SharedStorage)); -+ -+ if (smem_size >= (48 << 10)) { -+ cudaError_t result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ if (kSplitKSerial && args.batch_count > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ } -+ -+ size_t workspace_bytes = get_workspace_size(args); -+ -+ if (workspace_bytes && !workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ params_.update(args, workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(Rank2Kkernel::kThreadCount, 1, 1); -+ -+ int smem_size = int(sizeof(typename Rank2Kkernel::SharedStorage)); -+ -+ cutlass::Kernel<<>>(params_); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. 
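
// Workspace sizing mirrored from get_workspace_size()/initialize() above, as a stand-alone
// sketch: with serial split-K, every output tile carries one int-sized semaphore, and the
// buffer must be zeroed on the target stream before the first launch. The 128x128 tile
// shape, helper names, and the split_k_slices parameter are illustrative assumptions.
#include <cstddef>
#include <cuda_runtime.h>

struct TileCount { int m, n; };

inline TileCount rank2k_tiled_shape(int problem_n, int tile_m, int tile_n) {
  // A rank-k update produces an N x N output, so both extents tile the same dimension.
  return { (problem_n + tile_m - 1) / tile_m,
           (problem_n + tile_n - 1) / tile_n };
}

inline cudaError_t allocate_splitk_workspace(int problem_n, int split_k_slices,
                                             void** workspace, cudaStream_t stream) {
  *workspace = nullptr;
  if (split_k_slices <= 1) { return cudaSuccess; }   // no semaphores needed

  TileCount tiles = rank2k_tiled_shape(problem_n, /*tile_m=*/128, /*tile_n=*/128);
  size_t bytes = sizeof(int) * size_t(tiles.m) * size_t(tiles.n);

  cudaError_t result = cudaMalloc(workspace, bytes);
  if (result != cudaSuccess) { return result; }
  // Semaphores must start at zero so the first split-K slice of each tile can proceed.
  return cudaMemsetAsync(*workspace, 0, bytes, stream);
}
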
-+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for column-major output exchange operand. -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial, -+ /// Operation performed by Rank2K update kernel -+ typename Operator_, -+ /// Complex elementwise transformation -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation -+ ComplexTransform TransformB, -+ /// Blas3 computation mode (symmetric/hermitian) -+ BlasMode BlasMode_ -+ > -+class Rank2K { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static FillMode const kFillModeC = FillModeC; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ static BlasMode const kBlasMode = BlasMode_; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ static int const kUpdateRank = 2; -+ -+ /// Define the kernel -+ using UnderlyingOperator = typename cutlass::gemm::device::Rank2K< -+ ElementB, -+ LayoutB, -+ ElementA, -+ LayoutA, -+ ElementC, -+ layout::RowMajor, -+ InvertFillMode::mode, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ 
WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kAlignmentB, -+ kAlignmentA, -+ kSplitKSerial, -+ Operator, -+ kTransformA, -+ kTransformB, -+ kBlasMode -+ >; -+ -+ -+ /// Argument structure -+ using Arguments = typename UnderlyingOperator::Arguments; -+ using Rank2Kkernel = typename UnderlyingOperator::Rank2Kkernel; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the Rank2K. -+ Rank2K() { } -+ -+ /// Helper to construct a transposed equivalent for the underying Rank2K operator -+ static Arguments to_underlying_arguments(Arguments const &args) { -+ return args.transposed_problem(); -+ } -+ -+ /// Determines whether the Rank2K can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the grid shape -+ static dim3 get_grid_shape(Arguments const &args) { -+ return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the maximum number of active blocks per multiprocessor -+ static int maximum_active_blocks(int smem_capacity = -1) { -+ return UnderlyingOperator::maximum_active_blocks(smem_capacity); -+ } -+ -+ /// Initializes Rank2K state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace Rank2K -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/rank_2k_grouped.h b/3rdparty/cutlass/include/cutlass/gemm/device/rank_2k_grouped.h -new file mode 100644 -index 0000000..f38b07a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/rank_2k_grouped.h -@@ -0,0 +1,63 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Device-level grouped Rank2K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/gemm/device/base_grouped.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Rank2K Grouped -+template -+class Rank2KGrouped : public BaseGrouped { -+public: -+ using Rank2Kkernel = Rank2Kkernel_; -+ static const cutlass::FillMode kFillModeC = Rank2Kkernel::kFillModeC; -+ static const cutlass::BlasMode kBlasMode = Rank2Kkernel::kBlasMode; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/rank_k.h b/3rdparty/cutlass/include/cutlass/gemm/device/rank_k.h -new file mode 100644 -index 0000000..a2101a7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/rank_k.h -@@ -0,0 +1,509 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined RankK kernel. Does not compute batching or support split-K. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/rank_k_universal.h" -+ -+#include "cutlass/gemm/kernel/default_rank_k_universal.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassTensorOp, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_ = arch::Sm80, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = -+ typename threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ 
DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial = false, -+ /// Operation performed by SYRK -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, -+ ElementAccumulator_>::Operator, -+ /// Complex elementwise transformation -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Blas3 computation mode (symmetric/hermitian) -+ BlasMode BlasMode_ = BlasMode::kSymmetric> -+class RankK { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static FillMode const kFillModeC = FillModeC; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ static ComplexTransform const kTransformA = TransformA; -+ static BlasMode const kBlasMode = BlasMode_; -+ static int const kUpdateRank = 1; -+ -+ /// Define the kernel -+ using RankKkernel = typename kernel::DefaultRankKUniversal< -+ ElementA, -+ LayoutA, -+ kTransformA, -+ kAlignmentA, -+ ElementC, -+ LayoutC, -+ kFillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kSplitKSerial, -+ Operator, -+ kBlasMode -+ >::RankKkernel; -+ -+ using Arguments = typename RankKkernel::Arguments; -+ -+private: -+ -+ /// Kernel parameters object -+ typename RankKkernel::Params params_; -+public: -+ -+ /// Constructs the SYRK. -+ RankK() { } -+ -+ /// Determines whether the SYRK can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ if (!kSplitKSerial && args.batch_count > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = RankKkernel::can_implement(args); -+ -+ if (FillModeC != FillMode::kLower && FillModeC != FillMode::kUpper) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ if (kSplitKSerial && args.batch_count > 1) { -+ -+ bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); -+ } -+ -+ return bytes; -+ } -+ -+ /// Initializes SYRK state from arguments. 
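The RankK operator declared above is the single-operand counterpart of Rank2K (BLAS SYRK/HERK; kUpdateRank is 1), so only A and C appear in its parameter list. A minimal instantiation, under the same assumed fp64 column-major configuration as the Rank2K sketch earlier, could read as follows; the alias name SyrkOp is hypothetical.

#include "cutlass/gemm/device/rank_k.h"

// Illustrative SYRK configuration (assumption): C = alpha * A * A^T + beta * C,
// with C stored lower-triangular in column-major order.
using SyrkOp = cutlass::gemm::device::RankK<
    double, cutlass::layout::ColumnMajor,   // ElementA, LayoutA
    double, cutlass::layout::ColumnMajor,   // ElementC, LayoutC
    cutlass::FillMode::kLower,              // FillModeC
    double>;                                // ElementAccumulator

// Invocation follows the same can_implement / get_workspace_size / initialize / run
// sequence shown for Rank2K; the one-shot operator()(args, workspace, stream) is equivalent.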
-+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ if (kSplitKSerial) { -+ if (args.batch_count > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ size_t bytes = get_workspace_size(args); -+ -+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ else { -+ -+ if (args.batch_count > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ int gemm_k_size = args.problem_size.k(); -+ -+ // Initialize the Params structure -+ params_ = typename RankKkernel::Params{ -+ args, -+ grid_tiled_shape, -+ gemm_k_size, -+ static_cast(workspace) -+ }; -+ -+ int smem_size = int(sizeof(typename RankKkernel::SharedStorage)); -+ -+ if (smem_size >= (48 << 10)) { -+ cudaError_t result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ if (kSplitKSerial && args.batch_count > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ } -+ -+ size_t workspace_bytes = get_workspace_size(args); -+ -+ if (workspace_bytes && !workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ params_.update(args, workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(RankKkernel::kThreadCount, 1, 1); -+ -+ int smem_size = int(sizeof(typename RankKkernel::SharedStorage)); -+ -+ cutlass::Kernel<<>>(params_); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for column-major output exchange operand. -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. 
-+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial, -+ /// Operation performed by RankK update kernel -+ typename Operator_, -+ /// Complex elementwise transformation -+ ComplexTransform TransformA, -+ /// Blas3 computation mode (symmetric/hermitian) -+ BlasMode BlasMode_ -+ > -+class RankK { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static FillMode const kFillModeC = FillModeC; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ static BlasMode const kBlasMode = BlasMode_; -+ static int const kUpdateRank = 1; -+ -+ // Complex transform for input A matrices (function on input layout) -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Define the kernel -+ using UnderlyingOperator = typename cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ layout::RowMajor, -+ InvertFillMode::mode, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kAlignmentA, -+ kSplitKSerial, -+ Operator, -+ kTransformA, -+ kBlasMode -+ >; -+ -+ -+ /// Argument structure -+ using Arguments = typename UnderlyingOperator::Arguments; -+ using RankKkernel = typename UnderlyingOperator::RankKkernel; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the RankK. -+ RankK() { } -+ -+ /// Helper to construct a transposed equivalent for the underying RankK operator -+ static Arguments to_underlying_arguments(Arguments const &args) { -+ return args; -+ } -+ -+ /// Determines whether the RankK can execute the given problem. 
-+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the grid shape -+ static dim3 get_grid_shape(Arguments const &args) { -+ return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the maximum number of active blocks per multiprocessor -+ static int maximum_active_blocks(int smem_capacity = -1) { -+ return UnderlyingOperator::maximum_active_blocks(smem_capacity); -+ } -+ -+ /// Initializes RankK state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace RankK -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/symm.h b/3rdparty/cutlass/include/cutlass/gemm/device/symm.h -new file mode 100755 -index 0000000..57bfeec ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/symm.h -@@ -0,0 +1,602 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined SYMM and HEMM kernels. Does not compute batching or support split-K. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/symm_universal.h" -+ -+#include "cutlass/gemm/kernel/default_symm_universal.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode SideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode FillModeA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassTensorOp, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_ = arch::Sm80, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = epilogue::thread::LinearCombination< -+ ElementC_, -+ 128 / sizeof_bits::value, -+ ElementAccumulator_, -+ ElementAccumulator_, -+ epilogue::thread::ScaleType::OnlyAlphaScaling -+ >, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ 
DefaultGemmConfiguration::kAlignmentB, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial = false, -+ /// Operation performed by SYMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator, -+ /// Blas3 computation mode (symmetric/hermitian) -+ BlasMode BlasMode_ = BlasMode::kSymmetric> -+class Symm { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using ElementAKernel = typename platform::conditional<(SideModeA == SideMode::kRight), ElementB_, ElementA_>::type; -+ using LayoutAKernel = typename platform::conditional<(SideModeA == SideMode::kRight), LayoutB_, LayoutA_>::type; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using ElementBKernel = typename platform::conditional<(SideModeA == SideMode::kRight), ElementA_, ElementB_>::type; -+ using LayoutBKernel = typename platform::conditional<(SideModeA == SideMode::kRight), LayoutA_, LayoutB_>::type; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static SideMode const kSideModeA = SideModeA; -+ static FillMode const kFillModeA = FillModeA; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentAKernel = (SideModeA == SideMode::kRight) ? AlignmentB : AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentBKernel = (SideModeA == SideMode::kRight) ? AlignmentA : AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ static BlasMode const kBlasMode = BlasMode_; -+ -+ // static asserts for symm update kernel -+ static_assert(platform::is_same::value, -+ "SYMM update operator support same layouts for operand A and B"); -+ -+ /// Define the kernel -+ using SymmKernel = typename kernel::DefaultSymmUniversal< -+ ElementAKernel, -+ LayoutAKernel, -+ kSideModeA, -+ kFillModeA, -+ kAlignmentAKernel, -+ ElementBKernel, -+ LayoutBKernel, -+ kAlignmentBKernel, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kSplitKSerial, -+ Operator, -+ kBlasMode -+ >::SymmKernel; -+ -+ using Arguments = typename SymmKernel::Arguments; -+ -+private: -+ -+ /// Kernel parameters object -+ typename SymmKernel::Params params_; -+public: -+ -+ /// Constructs the SYMM. -+ Symm() { } -+ -+ /// Determines whether the SYMM can execute the given problem. 
-+ static Status can_implement(Arguments const &args) { -+ -+ if (!kSplitKSerial && args.batch_count > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = SymmKernel::can_implement(args); -+ -+ if (SideModeA == SideMode::kInvalid) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (FillModeA != FillMode::kLower && FillModeA != FillMode::kUpper) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ if (kSplitKSerial && args.batch_count > 1) { -+ -+ bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); -+ } -+ -+ return bytes; -+ } -+ -+ /// Initializes SYMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ if (kSplitKSerial) { -+ if (args.batch_count > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ size_t bytes = get_workspace_size(args); -+ -+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ else { -+ -+ if (args.batch_count > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ int gemm_k_size = args.problem_size.k(); -+ -+ // Swapping argument for A and B, if A was on the right side (problem size doesn't need to change here). -+ if (kSideModeA == SideMode::kRight) { -+ // Initialize the Params structure -+ params_ = typename SymmKernel::Params{ -+ args.swapped_matrices(), -+ grid_tiled_shape, -+ gemm_k_size, -+ static_cast(workspace) -+ }; -+ -+ return Status::kSuccess; -+ } -+ -+ // Initialize the Params structure -+ params_ = typename SymmKernel::Params{ -+ args, -+ grid_tiled_shape, -+ gemm_k_size, -+ static_cast(workspace) -+ }; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ if (kSplitKSerial && args.batch_count > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ } -+ -+ size_t workspace_bytes = get_workspace_size(args); -+ -+ if (workspace_bytes && !workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ params_.update(args, workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. 
-+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(SymmKernel::kThreadCount, 1, 1); -+ -+ int smem_size = int(sizeof(typename SymmKernel::SharedStorage)); -+ -+ if (smem_size >= (48 << 10)) { -+ cudaError_t result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ cutlass::Kernel<<>>(params_); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+//////////////////////////////////////////////////////////////////////////////// -+ -+/******************************************************************************************************** -+ SYMM/HEMM has 4 combinations based on Layouts {RowMajor, ColumnMajor} x Side mode {LeftSide, RightSide} -+ In templates and arguments to cutlass kernel, `matrix A` is always symmetric/hermitian, and `matrix B` is rectangular. -+ (adhering to the cuBLAS convention) -+ -+ Although, cuBLAS SYMM/HEMM only supports ColumnMajor layouts for all matrices (A, B, C/D). -+ -+ For the mainloop and symm kernel, `A` and `B` points to left-side and right-side matrices, respectively. -+ -+ Thus, for LeftSide mode `A` and `B` points to `matrix A` and `matrix B`, respectively. While for -+ the RightSide mode `A` and `B` points to `matrix B` and `matrix A`, respectively. -+ -+ Additionally, CUTLASS GEMM epilogue is always RowMajor, and ColumnMajor output is achieved by -+ transposing the GEMM problem. Thus, ColumnMajor output layout for SYMM/HEMM requires: -+ - Transposing `matrix A` and `matrix B` layouts -+ - Swapping problem size m and n values -+ - Swapping LeftSide and RightSide mode -+ -+ RowMajor output: D = matrix A x matrix B -+ ColumnMajor output: D = matrix A x matrix B -> Transpose (D) = Transpose(matrix B) x Transpose(matrix A) -+ -+ {RowMajor, ColumnMajor} x Side Mode {LeftSide, RightSide} 4 cases: -+ 1. LeftSide mode and RowMajor output (default template) -+ 2. LeftSide mode and ColumnMajor output -+ 3. RightSide mode and RowMajor output -+ 4. RightSide mode and ColumnMajor output -+ -+ Mapping ColumnMajor output layout cases 2 and 4 to RowMajor efficient epilogue implementation: -+ -+ Case 2 -> Case 3: -+ D_col = matrix A x matrix B (LeftSide mode) -+ => Transpose(D_col) = Transpose(matrix B) x Transpose(matrix A) (RightSide mode) -+ -+ swap pointers for `A` and `B` call GEMM mainloop with RowMajor efficient-epilogue -+ -+ Case 4 -> Case 1: -+ D_col = matrix B x matrix A (RightSide mode) -+ => Transpose(D_col) = Transpose(matrix A) x Transpose(matrix B) (LeftSide mode) -+ -+ call GEMM mainloop for with RowMajor efficient-epilogue -+********************************************************************************************************/ -+ -+/// Parital specialization for column-major output exchanges problem size and operand. 
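The partial specialization declared next implements the Case 2 -> Case 3 mapping described above. A compile-time spot check of that inversion is sketched below; it is not part of the CUTLASS sources and assumes an fp64, column-major, left-side SYMM with the default Sm80 tensor-op configuration is instantiable.

#include <type_traits>

#include "cutlass/gemm/device/symm.h"

using SymmColMajorLeft = cutlass::gemm::device::Symm<
    double, cutlass::layout::ColumnMajor,   // symmetric matrix A
    cutlass::SideMode::kLeft,
    cutlass::FillMode::kLower,
    double, cutlass::layout::ColumnMajor,   // rectangular matrix B
    double, cutlass::layout::ColumnMajor>;  // C/D requested in column-major

// The wrapped row-major operator sees transposed layouts, the opposite side mode,
// and the inverted fill mode (Case 2 -> Case 3).
static_assert(std::is_same<SymmColMajorLeft::UnderlyingOperator::LayoutA,
                           cutlass::layout::RowMajor>::value,
              "operand layouts are transposed");
static_assert(SymmColMajorLeft::UnderlyingOperator::kSideModeA == cutlass::SideMode::kRight,
              "side mode is inverted");
static_assert(SymmColMajorLeft::UnderlyingOperator::kFillModeA == cutlass::FillMode::kUpper,
              "fill mode is inverted");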
-+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode SideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode FillModeA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial, -+ /// Operation performed by Symm update kernel -+ typename Operator_, -+ /// Blas3 computation mode (symmetric/hermitian) -+ BlasMode BlasMode_ -+ > -+class Symm { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static SideMode const kSideModeA = SideModeA; -+ static FillMode const kFillModeA = FillModeA; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ static BlasMode const kBlasMode = BlasMode_; -+ -+ /// Define the kernel -+ using UnderlyingOperator = typename cutlass::gemm::device::Symm< -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ InvertSideMode::mode, -+ InvertFillMode::mode, -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kAlignmentA, -+ kAlignmentB, -+ kSplitKSerial, -+ Operator, -+ kBlasMode -+ >; -+ -+ -+ /// Argument structure -+ using Arguments = typename UnderlyingOperator::Arguments; -+ using SymmKernel = typename UnderlyingOperator::SymmKernel; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the Symm. 
-+ Symm() { } -+ -+ /// Helper to construct a transposed equivalent for the underying SYMM operator -+ static Arguments to_underlying_arguments(Arguments const &args) { -+ return args.transposed_problem_size(); -+ } -+ -+ /// Determines whether the Symm can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the grid shape -+ static dim3 get_grid_shape(Arguments const &args) { -+ return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the maximum number of active blocks per multiprocessor -+ static int maximum_active_blocks(int smem_capacity = -1) { -+ return UnderlyingOperator::maximum_active_blocks(smem_capacity); -+ } -+ -+ /// Initializes Symm state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace Symm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/trmm.h b/3rdparty/cutlass/include/cutlass/gemm/device/trmm.h -new file mode 100644 -index 0000000..34816db ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/trmm.h -@@ -0,0 +1,758 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a TRMM kernel. Does not compute batching or support split-K. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/trmm_universal.h" -+ -+#include "cutlass/gemm/kernel/default_trmm_universal.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*! Trmm device-level operator. This is an interface to efficient CUTLASS TRMM kernels that may -+ be invoked from host code. -+ -+ The contributions of this class are: -+ -+ 1. At compile time, it maps data types and high-level structural parameters onto -+ specific CUTLASS components. -+ -+ 2. At runtime, it maps logical arguments to TRMM problems to kernel parameters. -+ -+ 3. At runtime, it launches kernels on the device. -+ -+ The intent is to provide a convenient mechanism for interacting with most plausible TRMM -+ configurations for each supported architecture. Consequently, not all parameters are exposed -+ to the top-level interface. Rather, sensible defaults at each level of the CUTLASS hierarchy -+ are selected to tradeoff simplicity of the interface with flexibility. We expect -+ most configurations to be specified at this level. Applications with more exotic requirements -+ may construct their kernels of interest using CUTLASS components at the threadblock, warp, -+ and thread levels of abstraction. -+ -+ CUTLASS exposes computations using the functor design pattern in which objects compose some -+ internal state with an overloaded function call operator. This enables decoupling of -+ initialization from execution, possibly reducing overhead during steady state phases of -+ application execution. -+ -+ CUTLASS device-level operators expose an Arguments structure encompassing each logical -+ input to the computation. This is distinct from the kernel-level Params structure pattern -+ which contains application-specific precomputed state needed by the device code. -+ -+ Example of a CUTLASS TRMM operator implementing the functionality of cuBLAS's STRMM NN -+ is as follows: -+ -+ // -+ // Instantiate the CUTLASS TRMM operator. 
-+ // -+ -+ cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ > trmm_op; -+ -+ // -+ // Launch the TRMM operation on the device -+ // -+ -+ cutlass::Status status = trmm_op({ -+ cutlass::gemm::GemmUniversalMode, // Trmm Problem Mode -+ {m, n, m/n}, // GemmCoord problem_size (k is based on left- or right-side mode) -+ batch_count, -+ {alpha}, // EpilogueOutputOp::Params epilogue_op_params -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int lda, -+ int ldb, -+ int ldc -+ }); -+ -+ A simplified view of the template is listed below. -+ -+ template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ -+ /// Side Mode for A (kLeft or kRight) -+ SideMode SideModeA, -+ -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode FillModeA, -+ -+ /// DiagType for A (kNonUnit or kUnit) -+ DiagType DiagTypeA, -+ -+ /// Element type for B matrix operand -+ typename ElementB, -+ -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ -+ /// Operator class tag -+ typename OperatorClass, -+ -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. 
-+ typename ArchTag, -+ -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial, -+ -+ /// Operation performed by TRMM -+ typename Operator, -+ -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA -+ > -+ class Trmm; -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Side Mode for A -+ SideMode SideModeA, -+ /// Fill Mode for A -+ FillMode FillModeA, -+ /// DiagType for A -+ DiagType DiagTypeA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassTensorOp, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_ = arch::Sm80, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = epilogue::thread::LinearCombination< -+ ElementC_, -+ 128 / sizeof_bits::value, -+ ElementAccumulator_, -+ ElementAccumulator_, -+ epilogue::thread::ScaleType::OnlyAlphaScaling -+ >, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial = false, -+ /// Operation performed by TRMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA = ComplexTransform::kNone> -+class Trmm { -+ public: -+ using ElementA = ElementA_; -+ using 
LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementAKernel = typename platform::conditional<(SideModeA == SideMode::kRight), ElementB_, ElementA_>::type; -+ using LayoutAKernel = typename platform::conditional<(SideModeA == SideMode::kRight), LayoutB_, LayoutA_>::type; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementBKernel = typename platform::conditional<(SideModeA == SideMode::kRight), ElementA_, ElementB_>::type; -+ using LayoutBKernel = typename platform::conditional<(SideModeA == SideMode::kRight), LayoutA_, LayoutB_>::type; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static SideMode const kSideMode = SideModeA; -+ static FillMode const kFillMode = FillModeA; -+ static DiagType const kDiagType = DiagTypeA; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentAKernel = (SideModeA == SideMode::kRight) ? AlignmentB : AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentBKernel = (SideModeA == SideMode::kRight) ? AlignmentA : AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ // Complex Transform don't appply to B -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ static ComplexTransform const kTransformAKernel = (SideModeA == SideMode::kRight) ? -+ ComplexTransform::kNone : TransformA; -+ static ComplexTransform const kTransformBKernel = (SideModeA == SideMode::kRight) ? -+ TransformA : ComplexTransform::kNone; -+ -+ /// Define the kernel -+ using TrmmKernel = typename kernel::DefaultTrmmUniversal< -+ ElementAKernel, -+ LayoutAKernel, -+ kTransformAKernel, -+ kAlignmentAKernel, -+ ElementBKernel, -+ LayoutBKernel, -+ kTransformBKernel, -+ kAlignmentBKernel, -+ kSideMode, -+ kFillMode, -+ kDiagType, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kSplitKSerial, -+ Operator -+ >::TrmmKernel; -+ -+ using Arguments = typename TrmmKernel::Arguments; -+ -+private: -+ -+ /// Kernel parameters object -+ typename TrmmKernel::Params params_; -+public: -+ -+ /// Constructs the TRMM. -+ Trmm() { } -+ -+ /// Determines whether the TRMM can execute the given problem. 
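To complement the file-level usage comment above, a concrete instantiation of that STRMM-NN configuration with a one-shot launch might look like the sketch below. The helper name launch_strmm is hypothetical, and passing a nullptr workspace assumes the default SplitKSerial = false with batch_count = 1, so get_workspace_size() would return zero.

#include <cuda_runtime.h>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/trmm.h"

// STRMM: left side, lower triangular, non-unit diagonal, column-major operands
// (mirrors the configuration shown in the file-level comment).
using StrmmOp = cutlass::gemm::device::Trmm<
    float, cutlass::layout::ColumnMajor,    // triangular matrix A
    cutlass::SideMode::kLeft,
    cutlass::FillMode::kLower,
    cutlass::DiagType::kNonUnit,
    float, cutlass::layout::ColumnMajor,    // rectangular matrix B
    float, cutlass::layout::ColumnMajor>;   // output D

cutlass::Status launch_strmm(StrmmOp::Arguments const &args, cudaStream_t stream) {
  StrmmOp trmm_op;
  // operator()(args, workspace, stream) chains initialize() and run().
  return trmm_op(args, /*workspace=*/nullptr, stream);
}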
-+ static Status can_implement(Arguments const &args) { -+ -+ if (!kSplitKSerial && args.batch_count > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = TrmmKernel::can_implement(args); -+ -+ if (SideModeA == SideMode::kInvalid) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (FillModeA == FillMode::kInvalid) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (DiagTypeA == DiagType::kInvalid) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ if (kSplitKSerial && args.batch_count > 1) { -+ -+ bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); -+ } -+ -+ return bytes; -+ } -+ -+ /// Initializes TRMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ if (kSplitKSerial) { -+ if (args.batch_count > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ size_t bytes = get_workspace_size(args); -+ -+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ else { -+ -+ if (args.batch_count > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ int gemm_k_size = args.problem_size.k(); -+ -+ // Swapping argument for A and B, if A was on the right side (problem size doesn't need to change here). -+ if (kSideMode == SideMode::kRight) { -+ // Initialize the Params structure -+ params_ = typename TrmmKernel::Params{ -+ args.swapped_matrices(), -+ grid_tiled_shape, -+ gemm_k_size, -+ static_cast(workspace) -+ }; -+ -+ return Status::kSuccess; -+ } -+ -+ // Initialize the Params structure -+ params_ = typename TrmmKernel::Params{ -+ args, -+ grid_tiled_shape, -+ gemm_k_size, -+ static_cast(workspace) -+ }; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ if (kSplitKSerial && args.batch_count > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ } -+ -+ size_t workspace_bytes = get_workspace_size(args); -+ -+ if (workspace_bytes && !workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ params_.update(args, workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. 
-+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(TrmmKernel::kThreadCount, 1, 1); -+ -+ int smem_size = int(sizeof(typename TrmmKernel::SharedStorage)); -+ -+ if (smem_size >= (48 << 10)) { -+ cudaError_t result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ cutlass::Kernel<<>>(params_); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+/******************************************************************************************************** -+ TRMM has 4 combinations based on Layouts {RowMajor, ColumnMajor} x Side mode {LeftSide, RightSide} -+ In templates and arguments to cutlass kernel, `matrix A` is always triangular, and `matrix B` is rectangular. -+ (adhering to the cuBLAS convention) -+ -+For the mainloop and trmm kernel, `A` and `B` points to left-side and right-side matrices, respectively. -+ -+ Thus, for LeftSide mode `A` and `B` points to `matrix A` and `matrix B`, respectively. While for -+ the RightSide mode `A` and `B` points to `matrix B` and `matrix A`, respectively. -+ -+ Additionally, CUTLASS GEMM epilogue is always RowMajor, and ColumnMajor output is achieved by -+ transposing the GEMM problem. Thus, ColumnMajor output layout for TRMM requires: -+ - Transposing `matrix A` and `matrix B` layouts -+ - Swapping problem size m and n values -+ - Swapping LeftSide and RightSide mode -+ -+ RowMajor output: D = matrix A x matrix B -+ ColumnMajor output: D = matrix A x matrix B -> Transpose (D) = Transpose(matrix B) x Transpose(matrix A) -+ -+ {RowMajor, ColumnMajor} x Side Mode {LeftSide, RightSide} 4 cases: -+ 1. LeftSide mode and RowMajor output (default template) -+ 2. LeftSide mode and ColumnMajor output -+ 3. RightSide mode and RowMajor output -+ 4. RightSide mode and ColumnMajor output -+ -+ Mapping ColumnMajor output layout cases 2 and 4 to RowMajor efficient epilogue implementation: -+ -+ Case 2 -> Case 3: -+ D_col = matrix A x matrix B (LeftSide mode) -+ => Transpose(D_col) = Transpose(matrix B) x Transpose(matrix A) (RightSide mode) -+ -+ swap pointers for `A` and `B` call GEMM mainloop with RowMajor efficient-epilogue -+ -+ Case 4 -> Case 1: -+ D_col = matrix B x matrix A (RightSide mode) -+ => Transpose(D_col) = Transpose(matrix A) x Transpose(matrix B) (LeftSide mode) -+ -+ call GEMM mainloop for with RowMajor efficient-epilogue -+********************************************************************************************************/ -+ -+/// Parital specialization for column-major output exchanges problem size and operand. 
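
For context on how the methods of the primary Trmm class above (can_implement, get_workspace_size, initialize, run) are typically driven, here is a minimal host-side sketch. It assumes the CUTLASS headers from this tree are on the include path; run_device_op is a hypothetical helper, and argument construction is elided.

#include <cuda_runtime.h>
#include "cutlass/cutlass.h"

// Hypothetical host-side driver: query support, size and allocate the split-K
// workspace, initialize the operator, then launch on the given stream.
template <typename DeviceOp>
cutlass::Status run_device_op(typename DeviceOp::Arguments const &args,
                              cudaStream_t stream = nullptr) {
  DeviceOp op;

  cutlass::Status status = DeviceOp::can_implement(args);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  // Zero bytes unless serial split-K reduction needs semaphore storage.
  size_t workspace_bytes = DeviceOp::get_workspace_size(args);
  void *workspace = nullptr;
  if (workspace_bytes && cudaMalloc(&workspace, workspace_bytes) != cudaSuccess) {
    return cutlass::Status::kErrorInternal;
  }

  status = op.initialize(args, workspace, stream);
  if (status == cutlass::Status::kSuccess) {
    status = op.run(stream);
  }

  // Wait for the launch to finish before releasing the workspace it may use.
  cudaStreamSynchronize(stream);
  if (workspace) {
    cudaFree(workspace);
  }
  return status;
}

The column-major partial specialization that follows exposes the same interface and delegates each of these calls to its underlying row-major operator.
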
-+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Side Mode for A -+ SideMode SideModeA, -+ /// Fill Mode for A -+ FillMode FillModeA, -+ /// DiagType for A -+ DiagType DiagTypeA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ /// If true, kernel supports split-K as a serial reduction -+ bool SplitKSerial, -+ /// Operation performed by TRMM -+ typename Operator_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA> -+class Trmm { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static SideMode const kSideMode = SideModeA; -+ static FillMode const kFillMode = FillModeA; -+ static DiagType const kDiagType = DiagTypeA; -+ // Changing SideMode as we change the layout -+ static SideMode const kSideModeT = (SideModeA == SideMode::kLeft) ? -+ SideMode::kRight : SideMode::kLeft; -+ // Changing FillMode as we change the layout -+ static FillMode const kFillModeT = (FillModeA == FillMode::kLower) ? 
-+ FillMode::kUpper : FillMode::kLower; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static ComplexTransform const kTransformA = TransformA; -+ // Complex Transform don't appply to B -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ static bool const kSplitKSerial = SplitKSerial; -+ -+ using UnderlyingOperator = Trmm< -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ kSideModeT, -+ kFillModeT, -+ kDiagType, -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kAlignmentA, -+ kAlignmentB, -+ kSplitKSerial, -+ Operator, -+ TransformA -+ >; -+ -+ using Arguments = typename UnderlyingOperator::Arguments; -+ using TrmmKernel = typename UnderlyingOperator::TrmmKernel; -+ static int const kAlignmentC = UnderlyingOperator::kAlignmentC; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the TRMM. -+ Trmm() { } -+ -+ /// Helper to construct a transposed equivalent for the underying TRMM operator which is identical -+ static Arguments to_underlying_arguments(Arguments const &args) { -+ return args.transposed_problem_size(); -+ } -+ -+ /// Determines whether the TRMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Initializes TRMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. 
-+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/dispatch_policy.hpp b/3rdparty/cutlass/include/cutlass/gemm/dispatch_policy.hpp -new file mode 100644 -index 0000000..a2cd9a1 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/dispatch_policy.hpp -@@ -0,0 +1,144 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/arch/arch.h" -+ -+#include "cute/layout.hpp" -+#include "cute/numeric/integral_constant.hpp" -+ -+////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm { -+using namespace cute; -+ -+////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Policies for categorical dispatch of mainloop against kernel grid schedules -+// -+struct KernelMultistage { }; -+struct KernelTma { }; -+struct KernelTmaWarpSpecialized { }; -+struct KernelTmaWarpSpecializedPersistent { }; -+ -+// -+// Collective Mainloop Policies -+// -+ -+// 2 stage pipeline through 1 stage in smem, 1 in rmem, WITHOUT predicated gmem loads -+struct MainloopSm70TwoStageUnpredicated { -+ constexpr static int Stages = 2; -+ using ArchTag = arch::Sm70; -+ using Schedule = KernelMultistage; -+ using ClusterShape = Shape<_1,_1,_1>; -+}; -+ -+// 2 stage pipeline through 1 stage in smem, 1 in rmem, with predicated gmem loads -+struct MainloopSm70TwoStage { -+ constexpr static int Stages = 2; -+ using ArchTag = arch::Sm70; -+ using Schedule = KernelMultistage; -+ using ClusterShape = Shape<_1,_1,_1>; -+}; -+ -+// n-buffer in smem (cp.async), pipelined with registers, WITHOUT predicated gmem loads -+template -+struct MainloopSm80CpAsyncUnpredicated { -+ constexpr static int Stages = Stages_; -+ using ArchTag = arch::Sm80; -+ using Schedule = KernelMultistage; -+ using ClusterShape = Shape<_1,_1,_1>; -+}; -+ -+// n-buffer in smem (cp.async), pipelined with registers, with predicated gmem loads -+template -+struct MainloopSm80CpAsync { -+ constexpr static int Stages = Stages_; -+ using ArchTag = arch::Sm80; -+ using Schedule = KernelMultistage; -+ using ClusterShape = Shape<_1,_1,_1>; -+}; -+ -+// n-buffer in smem (cp.async), pipelined with Hopper GMMA, WITHOUT predicated gmem loads -+template< -+ int Stages_, -+ class ClusterShape_ = Shape<_1,_1,_1> -+> -+struct MainloopSm90CpAsyncGmmaUnpredicated { -+ constexpr static int Stages = Stages_; -+ using ClusterShape = ClusterShape_; -+ using ArchTag = arch::Sm90; -+ using Schedule = KernelMultistage; -+}; -+ -+// n-buffer in smem (cp.async), pipelined with Hopper GMMA, with predicated gmem loads -+template< -+ int Stages_, -+ class ClusterShape_ = Shape<_1,_1,_1> -+> -+struct MainloopSm90CpAsyncGmma { -+ constexpr static int Stages = Stages_; -+ using ClusterShape = ClusterShape_; -+ using ArchTag = arch::Sm90; -+ using Schedule = KernelMultistage; -+}; -+ -+// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, static schedule between TMA and GMMA -+template< -+ int Stages_, -+ class ClusterShape_ = Shape<_1,_1,_1>, -+ int PipelineAsyncMmaStages_ = 1 -+> -+struct MainloopSm90TmaGmma { -+ constexpr static int Stages = Stages_; -+ using ClusterShape = ClusterShape_; -+ constexpr static int PipelineAsyncMmaStages = PipelineAsyncMmaStages_; -+ using ArchTag = arch::Sm90; -+ using Schedule = KernelTma; -+}; -+ -+// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule -+template< -+ int Stages_, -+ class ClusterShape_ = Shape<_1,_1,_1>, -+ class KernelSchedule = KernelTmaWarpSpecialized -+> -+struct MainloopSm90TmaGmmaWarpSpecialized { -+ constexpr static int Stages = Stages_; -+ using ClusterShape = ClusterShape_; -+ using ArchTag = arch::Sm90; -+ using Schedule = KernelSchedule; -+}; -+ 
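
To illustrate the "categorical dispatch" the comment above refers to, the sketch below branches at compile time on a policy's Schedule tag. It assumes the tags defined above are in scope (they live in namespace cutlass::gemm); describe_schedule is a hypothetical helper, not part of the CUTLASS API.

#include <type_traits>

// Assumes the schedule tags above (KernelMultistage, KernelTma, ...) are visible,
// e.g. because this is compiled inside namespace cutlass::gemm.
template <class DispatchPolicy>
constexpr char const *describe_schedule() {
  using Schedule = typename DispatchPolicy::Schedule;
  if constexpr (std::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
                std::is_same_v<Schedule, KernelTmaWarpSpecializedPersistent>) {
    return "TMA with warp-specialized producer/consumer warps";
  } else if constexpr (std::is_same_v<Schedule, KernelTma>) {
    return "TMA with a statically scheduled mainloop";
  } else {
    return "multistage (cp.async-style) mainloop";
  }
}

// For example, describe_schedule<MainloopSm80CpAsync<3>>() takes the multistage
// branch, while describe_schedule<MainloopSm90TmaGmma<4>>() takes the TMA branch.
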
-+////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm -diff --git a/3rdparty/cutlass/include/cutlass/gemm/gemm.h b/3rdparty/cutlass/include/cutlass/gemm/gemm.h -new file mode 100644 -index 0000000..4b76101 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/gemm.h -@@ -0,0 +1,574 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines common types used for all GEMM-like operators. 
-+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+#include "cutlass/layout/matrix.h" -+#include "cute/layout.hpp" -+#include "cute/arch/copy_sm90.hpp" -+ -+namespace cutlass { -+namespace gemm { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// GEMM operand enumeration: D = A * B + C -+enum class Operand { -+ kA, /// A multiplicand -+ kB, /// B multiplicand -+ kC, /// Source accumulator -+ kD /// Destination accumulator -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Shape of a matrix multiply-add operation -+template < -+ /// Rows of matrix product -+ int M = 1, -+ /// Columns of matrix product -+ int N = 1, -+ /// Inner dimension of matrix product -+ int K = 1 -+> -+struct GemmShape { -+ static int const kM = M; -+ static int const kN = N; -+ static int const kK = K; -+ -+ static int const kMN = M * N; -+ static int const kMK = M * K; -+ static int const kKN = N * K; -+ static int const kMNK = M * N * K; -+ -+ static int const kCount = kMNK; -+ -+ // -+ // Static member functions -+ // -+ -+ /// Returns a Coord object -+ CUTLASS_HOST_DEVICE -+ static Coord<3> toCoord() { -+ return make_Coord(kM, kN, kK); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Type alias of the transpose of a GemmShape -+template < -+ /// concept: GemmShape -+ typename Shape -+> -+using GemmShapeTranspose = GemmShape; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// GemmCoord is a structure derived from Coord<3> that specifies a location within the -+/// coordinate space of a GEMM problem. 
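
Before the GemmCoord definition that follows, a small compile-time sketch of GemmShape and GemmShapeTranspose usage; the alias names are arbitrary, and it assumes this header is included as cutlass/gemm/gemm.h.

#include "cutlass/gemm/gemm.h"

// Arbitrary tile chosen for illustration.
using ThreadblockTile = cutlass::gemm::GemmShape<128, 128, 32>;

static_assert(ThreadblockTile::kMN == 128 * 128, "elements of C covered per tile");
static_assert(ThreadblockTile::kCount == 128 * 128 * 32, "multiply-adds per tile");

// GemmShapeTranspose exchanges the M and N extents and keeps K.
using TransposedTile = cutlass::gemm::GemmShapeTranspose<ThreadblockTile>;
static_assert(TransposedTile::kM == ThreadblockTile::kN &&
              TransposedTile::kK == ThreadblockTile::kK, "transpose swaps M and N only");
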
-+struct GemmCoord : public Coord<3, int> { -+ -+ /// Integer-valued index -+ typedef int Index; -+ -+ /// Base type is a Coord of rank=3 -+ typedef Coord<3, Index> Base; -+ -+ /// GEMM M dimension - rows of the output C matrix -+ static int const kM = 0; -+ -+ /// GEMM N dimension - columns of the output C matrix -+ static int const kN = 1; -+ -+ /// GEMM K dimension - inner dimension of the GEMM problem -+ static int const kK = 2; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ GemmCoord() { } -+ -+ /// Constructs from Coord<3> and a batch -+ CUTLASS_HOST_DEVICE -+ GemmCoord(Coord<3, Index> const &coord): Base(make_Coord(coord[0], coord[1], coord[2])) { } -+ -+ /// Helper to construct from a K, N, M, batch variables -+ CUTLASS_HOST_DEVICE -+ GemmCoord(Index m, Index n, Index k): Base(make_Coord(m, n, k)) { } -+ -+ /// Returns the GEMM M coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & m() const { return this->at(kM); } -+ -+ /// Returns reference to the GEMM M coordinate -+ CUTLASS_HOST_DEVICE -+ Index & m() { return this->at(kM); } -+ -+ /// Returns the GEMM N coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & n() const { return this->at(kN); } -+ -+ /// Returns reference to the GEMM N coordinate -+ CUTLASS_HOST_DEVICE -+ Index & n() { return this->at(kN); } -+ -+ /// Returns the GEMM K coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & k() const { return this->at(kK); } -+ -+ /// Returns reference to the GEMM K coordinate -+ CUTLASS_HOST_DEVICE -+ Index & k() { return this->at(kK); } -+ -+ /// Obtains a Coord<3> from GemmCoord -+ CUTLASS_HOST_DEVICE -+ Coord<3> mnk() const { -+ return make_Coord(m(), n(), k()); -+ } -+ -+ /// Obtains a Coord<3> from GemmCoord -+ CUTLASS_HOST_DEVICE -+ Coord<3> knm() const { -+ return make_Coord(k(), n(), m()); -+ } -+ -+ /// Obtains a Coord<2> from GemmCoord -+ CUTLASS_HOST_DEVICE -+ Coord<2> nm() const { -+ return make_Coord(n(), m()); -+ } -+ -+ /// Obtains a Coord<2> from GemmCoord -+ CUTLASS_HOST_DEVICE -+ Coord<2> mn() const { -+ return make_Coord(m(), n()); -+ } -+ -+ /// Obtains a Coord<2> from GemmCoord -+ CUTLASS_HOST_DEVICE -+ Coord<2> mk() const { -+ return make_Coord(m(), k()); -+ } -+ -+ /// Obtains a Coord<2> from GemmCoord -+ CUTLASS_HOST_DEVICE -+ Coord<2> km() const { -+ return make_Coord(k(), m()); -+ } -+ -+ /// Obtains a Coord<2> from GemmCoord -+ CUTLASS_HOST_DEVICE -+ Coord<2> nk() const { -+ return make_Coord(n(), k()); -+ } -+ -+ /// Obtains a Coord<2> from GemmCoord -+ CUTLASS_HOST_DEVICE -+ Coord<2> kn() const { -+ return make_Coord(k(), n()); -+ } -+ -+ // -+ // Coord operators -+ // -+ -+ /// Element-wise addition -+ CUTLASS_HOST_DEVICE -+ GemmCoord operator+(Base const& b) const { -+ return GemmCoord(Base::operator+(b)); -+ } -+ -+ /// Element-wise subtraction -+ CUTLASS_HOST_DEVICE -+ GemmCoord operator-(Base const& b) const { -+ return GemmCoord(Base::operator-(b)); -+ } -+ -+ /// Element-wise multiplication -+ CUTLASS_HOST_DEVICE -+ GemmCoord operator*(Base const& b) const { -+ return GemmCoord(Base::operator*(b)); -+ } -+ -+ /// Element-wise division -+ CUTLASS_HOST_DEVICE -+ GemmCoord operator/(Base const& b) const { -+ return GemmCoord(Base::operator/(b)); -+ } -+ -+ /// In-place addition -+ CUTLASS_HOST_DEVICE -+ GemmCoord& operator+=(Base const& b) { -+ Base::operator+=(b); -+ return *this; -+ } -+ -+ /// In-place subtraction -+ CUTLASS_HOST_DEVICE -+ GemmCoord& operator-=(Base const& b) { -+ Base::operator-=(b); -+ return *this; -+ } -+ -+ /// In-place multiplication -+ 
CUTLASS_HOST_DEVICE -+ GemmCoord& operator*=(Base const& b) { -+ Base::operator*=(b); -+ return *this; -+ } -+ -+ /// In-place division -+ CUTLASS_HOST_DEVICE -+ GemmCoord& operator/=(Base const& b) { -+ Base::operator/=(b); -+ return *this; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// BatchedGemmCoord is a structure derived from Coord<4> that specifies a location within the -+/// coordinate space of a batched GEMM problem. -+struct BatchedGemmCoord : public Coord<4, int> { -+ -+ /// Integer-valued index -+ typedef int Index; -+ -+ /// Base type is a Coord of rank=4 -+ typedef Coord<4, Index> Base; -+ -+ /// GEMM M dimension - rows of the output C matrix -+ static int const kM = 0; -+ -+ /// GEMM N dimension - columns of the output C matrix -+ static int const kN = 1; -+ -+ /// GEMM K dimension - inner dimension of the GEMM problem -+ static int const kK = 2; -+ -+ /// GEMM Batch dimension - inner dimension of the GEMM problem -+ static int const kBatch = 3; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord() { } -+ -+ /// Constructs from Coord<4> -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord(Base const &coord): Base(coord) { } -+ -+ /// Helper to construct from a K, N, M, and batch variables -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord(Index m, Index n, Index k, Index b): Base(make_Coord(m, n, k, b)) { } -+ -+ /// Returns the GEMM M coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & m() const { return this->at(kM); } -+ -+ /// Returns reference to the GEMM M coordinate -+ CUTLASS_HOST_DEVICE -+ Index & m() { return this->at(kM); } -+ -+ /// Returns the GEMM N coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & n() const { return this->at(kN); } -+ -+ /// Returns reference to the GEMM N coordinate -+ CUTLASS_HOST_DEVICE -+ Index & n() { return this->at(kN); } -+ -+ /// Returns the GEMM K coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & k() const { return this->at(kK); } -+ -+ /// Returns reference to the GEMM K coordinate -+ CUTLASS_HOST_DEVICE -+ Index & k() { return this->at(kK); } -+ -+ /// Returns the GEMM batch coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & batch() const { return this->at(kBatch); } -+ -+ /// Returns reference to the GEMM batch coordinate -+ CUTLASS_HOST_DEVICE -+ Index & batch() { return this->at(kBatch); } -+ -+ /// Obtains a GemmCoord from BatchedGemmCoord -+ CUTLASS_HOST_DEVICE -+ GemmCoord mnk() const { -+ return GemmCoord(m(), n(), k()); -+ } -+ -+ /// Obtains a Coord<4> from BatchedGemmCoord -+ CUTLASS_HOST_DEVICE -+ Coord<4> mnkb() const { -+ return make_Coord(m(), n(), k(), batch()); -+ } -+ -+ // -+ // Coord operators -+ // -+ -+ /// Element-wise addition -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord operator+(Base const& b) const { -+ return BatchedGemmCoord(Base::operator+(b)); -+ } -+ -+ /// Element-wise subtraction -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord operator-(Base const& b) const { -+ return BatchedGemmCoord(Base::operator-(b)); -+ } -+ -+ /// Element-wise multiplication -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord operator*(Base const& b) const { -+ return BatchedGemmCoord(Base::operator*(b)); -+ } -+ -+ /// Element-wise division -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord operator/(Base const& b) const { -+ return BatchedGemmCoord(Base::operator/(b)); -+ } -+ -+ /// In-place addition -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord& operator+=(Base const& b) { -+ Base::operator+=(b); -+ return *this; -+ } -+ -+ /// In-place 
subtraction -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord& operator-=(Base const& b) { -+ Base::operator-=(b); -+ return *this; -+ } -+ -+ /// In-place multiplication -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord& operator*=(Base const& b) { -+ Base::operator*=(b); -+ return *this; -+ } -+ -+ /// In-place division -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord& operator/=(Base const& b) { -+ Base::operator/=(b); -+ return *this; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+enum class GemmUniversalMode { -+ kGemm, -+ kGemmSplitKParallel, -+ kBatched, -+ kArray, -+ kInvalid -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Some options for clearing shared memory -+enum class SharedMemoryClearOption { -+ kNone, ///< SMEM is in don't-care state -+ kZfill, ///< Kernels fill out of bounds accesses with zeros -+ kClearLastStage ///< Last SMEM stage is explicitly cleared. Mainloop uses 'kNone' -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// For each cutlass::layout, provides its corresponding cute stride types, 64b by default -+ -+template -+struct TagToStrideA {}; -+ -+// Maps to modes [M, K, L] -+template <> -+struct TagToStrideA { -+ using type = cute::Stride, int64_t>; -+ using tag = layout::RowMajor; -+}; -+ -+// Maps to modes [M, K, L] -+template <> -+struct TagToStrideA { -+ using type = cute::Stride, int64_t, int64_t>; -+ using tag = layout::ColumnMajor; -+}; -+ -+template -+struct TagToStrideB {}; -+ -+// Maps to modes [N, K, L] -+template <> -+struct TagToStrideB { -+ using type = cute::Stride, int64_t, int64_t>; -+ using tag = layout::RowMajor; -+}; -+ -+// Maps to modes [N, K, L] -+template <> -+struct TagToStrideB { -+ using type = cute::Stride, int64_t>; -+ using tag = layout::ColumnMajor; -+}; -+ -+ -+// Maps to modes [N, N, L] -+template -+struct TagToStrideC : TagToStrideA { }; -+ -+// Convenience aliases -+template -+using TagToStrideA_t = typename TagToStrideA::type; -+ -+template -+using TagToStrideB_t = typename TagToStrideB::type; -+ -+template -+using TagToStrideC_t = typename TagToStrideC::type; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// For 2.x compatibility APIs, provide stride->layout tag mappers -+ -+namespace detail { -+ -+// Note : This method can be used for deducing the Layout Tag of A, C, D Matrices -+template -+constexpr -+auto -+stride_to_layout_tag_A() { -+ // Account for stride types with and without batch mode and batch modes with static zero stride -+ if constexpr (cute::size<0>(StrideAC{}) == 1) { // M major -+ return layout::ColumnMajor{}; -+ } -+ else { // K major -+ return layout::RowMajor{}; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+constexpr -+auto -+stride_to_layout_tag_B() { -+ // Account for stride types with and without batch mode and batch modes with static zero stride -+ if constexpr (cute::size<0>(StrideB{}) == 1) { // N major -+ return layout::RowMajor{}; -+ } -+ else { // K major -+ return layout::ColumnMajor{}; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// Inspects a TiledCopy and returns its alignment in terms of element count -+template -+constexpr int -+get_alignment_count_from_gmem_tiled_copy() { -+ // For TMA tiled copies, we know the alignment has to be 128 bits -+ if constexpr (std::is_base_of_v || -+ std::is_base_of_v) { -+ return 128 / sizeof_bits::value; -+ } -+ else -+ { -+ 
// For non-TMA tiled copies, TiledCopy holds the alignment count directly in its TiledShape_MN -+ return GmemTiledCopy::NumValSrc; -+ } -+} -+ -+// Utilities to map Stride back on to their corresponding layout tags -+template -+struct StrideToLayoutTagA { -+ using type = decltype(detail::stride_to_layout_tag_A()); -+}; -+ -+template -+struct StrideToLayoutTagB { -+ using type = decltype(detail::stride_to_layout_tag_B()); -+}; -+ -+// Maps to modes [N, N, L] -+template -+struct StrideToLayoutTagC : StrideToLayoutTagA { }; -+ -+// Convenience aliases -+template -+using StrideToLayoutTagA_t = typename StrideToLayoutTagA::type; -+ -+template -+using StrideToLayoutTagB_t = typename StrideToLayoutTagB::type; -+ -+template -+using StrideToLayoutTagC_t = typename StrideToLayoutTagC::type; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// The following two metafunctions are used to detect whether a `kernel::Gemm` or `kernel::GemmUniversal` -+// is implementing the CUTLASS 3.x API or not, by checking if the problem shape type is aliased within or not. -+template -+struct IsCutlass3GemmKernel : std::false_type { }; -+ -+template -+struct IsCutlass3GemmKernel> -+ : std::true_type { }; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_ell_gemm.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_ell_gemm.h -new file mode 100644 -index 0000000..04b14a4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_ell_gemm.h -@@ -0,0 +1,837 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Default kernel-level Blocked-Ell sparse gemm operators. -+ This operator combines threadblock-scoped ELL MMA -+ with the appropriate threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm.h" -+#include "cutlass/gemm/kernel/gemm_pipelined.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+#include "cutlass/gemm/kernel/ell_gemm.h" -+#include "cutlass/gemm/threadblock/default_ell_mma.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages 
used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Sparse matrix is A or not -+ bool IsASparse> -+struct DefaultEllGemm; -+ -+//////////////////////////////////////////////////////////////////////////////// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Sparse matrix is A or not -+ bool IsASparse -+> -+struct DefaultEllGemm { -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. 
-+ using GemmKernel = kernel::EllGemm; -+}; -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Turing Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Sparse matrix is A or not -+ bool IsASparse -+> -+struct DefaultEllGemm< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC, layout::RowMajor, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ SplitKSerial, -+ Operator, -+ IsASparse -+> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ layout::RowMajor, -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ 2, -+ Operator -+ >::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ typename Mma::Operator, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. 
-+ using GemmKernel = kernel::EllGemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Number of Interleaved k -+ int InterleavedK, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Sparse matrix is A or not -+ bool IsASparse> -+struct DefaultEllGemm< -+ ElementA, layout::ColumnMajorInterleaved, kAlignmentA, -+ ElementB, layout::RowMajorInterleaved, kAlignmentB, ElementC, -+ layout::ColumnMajorInterleaved, int32_t, -+ arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, -+ InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ SplitKSerial, Operator, IsASparse> { -+ using LayoutA = layout::ColumnMajorInterleaved; -+ using LayoutB = layout::RowMajorInterleaved; -+ using LayoutC = layout::ColumnMajorInterleaved; -+ -+ using ElementAccumulator = int32_t; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, Operator, -+ true>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock:: -+ DefaultInterleavedEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ 64 / sizeof_bits::value, InterleavedK>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. 
-+ using GemmKernel = kernel::EllGemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Turing Integer Matrix Multiply Interleaved layout -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of Interleaved k -+ int InterleavedK, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Sparse matrix is A or not -+ bool IsASparse> -+struct DefaultEllGemm, -+ kAlignmentA, ElementB, -+ layout::RowMajorInterleaved, kAlignmentB, -+ ElementC, layout::ColumnMajorInterleaved, -+ int32_t, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, -+ WarpShape, InstructionShape, EpilogueOutputOp, -+ ThreadblockSwizzle, 2, SplitKSerial, Operator, IsASparse> { -+ using LayoutA = layout::ColumnMajorInterleaved; -+ using LayoutB = layout::RowMajorInterleaved; -+ using LayoutC = layout::ColumnMajorInterleaved; -+ -+ using ElementAccumulator = int32_t; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, -+ arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, -+ InstructionShape, 2, Operator, true>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock:: -+ DefaultInterleavedEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ 64 / sizeof_bits::value, InterleavedK>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. 
-+ using GemmKernel = kernel::EllGemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Partial specialization for Volta architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Sparse matrix is A or not -+ bool IsASparse -+> -+struct DefaultEllGemm< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC, layout::RowMajor, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ arch::Sm70, -+ ThreadblockShape, -+ WarpShape, -+ GemmShape<8, 8, 4>, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ SplitKSerial, -+ Operator, -+ IsASparse -+> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ layout::RowMajor, -+ arch::OpClassTensorOp, -+ arch::Sm70, -+ ThreadblockShape, -+ WarpShape, -+ GemmShape<8, 8, 4>, -+ 2, -+ Operator -+ >::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ ThreadblockShape, -+ typename Mma::Operator, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. 
-+ using GemmKernel = kernel::EllGemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for SIMT -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Sparse matrix is A or not -+ bool IsASparse -+ > -+struct DefaultEllGemm< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ GemmShape<1, 1, 1>, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ SplitKSerial, -+ Operator, -+ IsASparse> { -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ layout::RowMajor, -+ arch::OpClassSimt, -+ arch::Sm50, -+ ThreadblockShape, -+ WarpShape, -+ GemmShape<1, 1, 1>, -+ 2, -+ Operator>::ThreadblockMma; -+ -+ static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount; -+ static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars"); -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ typename Mma::Operator, -+ EpilogueOutputOp, -+ kEpilogueElementsPerAccess -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. 
-+ using GemmKernel = kernel::EllGemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Sparse matrix is A or not -+ bool IsASparse -+ > -+struct DefaultEllGemm, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ IsASparse> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, layout::RowMajor, arch::OpClassSimt, arch::Sm80, -+ ThreadblockShape, WarpShape, GemmShape<1, 1, 1>, Stages, -+ Operator>::ThreadblockMma; -+ -+ static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount; -+ static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars"); -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ typename Mma::Operator, -+ EpilogueOutputOp, -+ kEpilogueElementsPerAccess -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. 
-+ using GemmKernel = kernel::EllGemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization for SIMT DP4A -+ -+template < -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Layout type for C matrix operand -+ typename LayoutC, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Sparse matrix is A or not -+ bool IsASparse -+ > -+struct DefaultEllGemm, -+ EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, -+ Operator, IsASparse> { -+ using InstructionShape = GemmShape<1, 1, 4>; -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ -+ using OperatorClass = arch::OpClassSimt; -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultEllMma::ThreadblockMma; -+ -+ static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount; -+ static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars"); -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ typename Mma::Operator, -+ EpilogueOutputOp, -+ kEpilogueElementsPerAccess -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. 
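
For reference, a possible instantiation of the Ampere tensor-op specialization earlier in this file might look like the sketch below; the element types, tile sizes, epilogue, and swizzle chosen here are illustrative assumptions rather than a recommended or validated configuration.

#include "cutlass/gemm/kernel/default_ell_gemm.h"

// Illustrative parameter choices only.
using EllGemmKernel = cutlass::gemm::kernel::DefaultEllGemm<
    cutlass::half_t, cutlass::layout::RowMajor, 8,         // A: element, layout, alignment
    cutlass::half_t, cutlass::layout::ColumnMajor, 8,      // B: element, layout, alignment
    float, cutlass::layout::RowMajor,                      // C/D: element, layout
    float,                                                 // accumulator
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,                // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,                  // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,                   // tensor-op instruction shape
    cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,                                                     // pipeline stages
    false,                                                 // SplitKSerial
    cutlass::arch::OpMultiplyAdd,
    true                                                   // IsASparse: A is the Blocked-ELL operand
    >::GemmKernel;

The resulting GemmKernel is the kernel::EllGemm composition of the Mma and Epilogue types deduced by the specialization.
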
-+ using GemmKernel = kernel::EllGemm; -+}; -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+//////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization for Wmma Gemm Kernel -+template < -+ ///< Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Sparse matrix is A or not -+ bool IsASparse -+ > -+struct DefaultEllGemm< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ arch::OpClassWmmaTensorOp, -+ ArchTag, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ IsASparse> { -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, LayoutC, -+ arch::OpClassWmmaTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ Stages, -+ Operator>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp< -+ ThreadblockShape, -+ typename Mma::Operator, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::EllGemm; -+}; -+//////////////////////////////////////////////////////////////////////////////// -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm.h -new file mode 100644 -index 0000000..4432008 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm.h -@@ -0,0 +1,1060 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+
-+/*! \file
-+    \brief
-+      Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
-+      the appropriate threadblock-scoped epilogue.
-+
-+      Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
-+      accommodated by exchanging A and B operands and assuming transposed layouts. Partial
-+      specializations here choose 'device::GemmTransposed' to implement this functionality.
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm.h" -+#include "cutlass/gemm/kernel/gemm_pipelined.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#include "cutlass/layout/permute.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, -+ /// Gather operand A by using an index array -+ bool GatherA = false, -+ /// Gather operand B by using an index array -+ bool GatherB = false, -+ /// Scatter result D by using an index array -+ bool ScatterD = false, -+ /// Permute result D -+ typename PermuteDLayout = layout::NoPermute, -+ /// -+ typename Enable = void -+> -+struct DefaultGemm; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture -+template < -+ /// Element type for A matrix operand -+ 
typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB, -+ /// Scatter result D by using an index array -+ bool ScatterD, -+ /// Permute result D -+ typename PermuteDLayout -+> -+struct DefaultGemm { -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator, false, SharedMemoryClear, GatherA, GatherB>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, ScatterD, PermuteDLayout>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. 
-+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operand -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB, -+ /// Scatter result D by using an index array -+ bool ScatterD, -+ /// Permute result D -+ typename PermuteDLayout -+> -+struct DefaultGemm { -+ -+ static_assert((platform::is_same::value -+ || platform::is_same>::value), -+ "Epilogue in the kernel level must be row major"); -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator, false, SharedMemoryClear, GatherA, GatherB>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using RegularEpilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, ScatterD, PermuteDLayout>::Epilogue; -+ -+ using Affine2Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpAffineRankN< -+ 2, ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount>::Epilogue; -+ -+ using Epilogue = typename platform::conditional::value, -+ RegularEpilogue, -+ Affine2Epilogue>::type; -+ -+ /// Define the kernel-level GEMM operator. 
-+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Turing Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB, -+ /// Scatter result D by using an index array -+ bool ScatterD, -+ /// Permute result D -+ typename PermuteDLayout -+> -+struct DefaultGemm< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC, layout::RowMajor, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ SplitKSerial, -+ Operator, -+ SharedMemoryClear, -+ GatherA, -+ GatherB, -+ ScatterD, -+ PermuteDLayout -+> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ layout::RowMajor, -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ 2, -+ Operator, -+ false, -+ SharedMemoryClear, -+ GatherA, -+ GatherB -+ >::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ typename Mma::Operator, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount, -+ ScatterD, -+ PermuteDLayout -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. 
-+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Number of Interleaved k -+ int InterleavedK, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear> -+struct DefaultGemm< -+ ElementA, layout::ColumnMajorInterleaved, kAlignmentA, -+ ElementB, layout::RowMajorInterleaved, kAlignmentB, ElementC, -+ layout::ColumnMajorInterleaved, int32_t, -+ arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, -+ InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ SplitKSerial, Operator, SharedMemoryClear, false, false, false> { -+ -+ using LayoutA = layout::ColumnMajorInterleaved; -+ using LayoutB = layout::RowMajorInterleaved; -+ using LayoutC = layout::ColumnMajorInterleaved; -+ -+ using ElementAccumulator = int32_t; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, Operator, -+ true, SharedMemoryClear>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock:: -+ DefaultInterleavedEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ 64 / sizeof_bits::value, InterleavedK>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. 
-+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Turing Integer Matrix Multiply Interleaved layout -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of Interleaved k -+ int InterleavedK, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear> -+struct DefaultGemm, -+ kAlignmentA, ElementB, -+ layout::RowMajorInterleaved, kAlignmentB, -+ ElementC, layout::ColumnMajorInterleaved, -+ int32_t, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, -+ WarpShape, InstructionShape, EpilogueOutputOp, -+ ThreadblockSwizzle, 2, SplitKSerial, Operator, SharedMemoryClear, -+ false, false, false> { -+ -+ using LayoutA = layout::ColumnMajorInterleaved; -+ using LayoutB = layout::RowMajorInterleaved; -+ using LayoutC = layout::ColumnMajorInterleaved; -+ -+ using ElementAccumulator = int32_t; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, -+ arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, -+ InstructionShape, 2, Operator, true>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock:: -+ DefaultInterleavedEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ 64 / sizeof_bits::value, InterleavedK>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. 
-+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Volta architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB, -+ /// Scatter result D by using an index array -+ bool ScatterD, -+ /// Permute result D -+ typename PermuteDLayout -+> -+struct DefaultGemm< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC, layout::RowMajor, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ arch::Sm70, -+ ThreadblockShape, -+ WarpShape, -+ GemmShape<8, 8, 4>, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ SplitKSerial, -+ Operator, -+ SharedMemoryClear, -+ GatherA, -+ GatherB, -+ ScatterD, -+ PermuteDLayout -+> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ layout::RowMajor, -+ arch::OpClassTensorOp, -+ arch::Sm70, -+ ThreadblockShape, -+ WarpShape, -+ GemmShape<8, 8, 4>, -+ 2, -+ Operator, -+ false, -+ SharedMemoryClear, -+ GatherA, -+ GatherB -+ >::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ ThreadblockShape, -+ typename Mma::Operator, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount, -+ ScatterD, -+ PermuteDLayout -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. 
-+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for SIMT -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operand -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB, -+ /// Scatter result D by using an index array -+ bool ScatterD, -+ /// Permute result D -+ typename PermuteDLayout -+ > -+struct DefaultGemm< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ GemmShape<1, 1, 1>, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ SplitKSerial, -+ Operator, -+ SharedMemoryClear, -+ GatherA, -+ GatherB, -+ ScatterD, -+ PermuteDLayout, -+ typename platform::enable_if< ! platform::is_same::value >::type > { -+ -+ static_assert((platform::is_same::value -+ || platform::is_same>::value), -+ "Epilogue in the kernel level must be row major"); -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ LayoutC, -+ arch::OpClassSimt, -+ arch::Sm50, -+ ThreadblockShape, -+ WarpShape, -+ GemmShape<1, 1, 1>, -+ 2, -+ Operator, -+ false, -+ SharedMemoryClear, -+ GatherA, -+ GatherB>::ThreadblockMma; -+ -+ static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount; -+ static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars"); -+ -+ /// Define the epilogue -+ using RegularEpilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ typename Mma::Operator, -+ EpilogueOutputOp, -+ kEpilogueElementsPerAccess, -+ ScatterD, -+ PermuteDLayout -+ >::Epilogue; -+ -+ using Affine2Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimtAffineRankN< -+ 2, -+ ThreadblockShape, -+ typename Mma::Operator, -+ EpilogueOutputOp, -+ kEpilogueElementsPerAccess -+ >::Epilogue; -+ -+ using Epilogue = typename platform::conditional::value, -+ RegularEpilogue, -+ Affine2Epilogue>::type; -+ -+ /// Define the kernel-level GEMM operator. 
-+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operand -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB, -+ /// Scatter result D by using an index array -+ bool ScatterD, -+ /// Permute result D -+ typename PermuteDLayout -+> -+struct DefaultGemm, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ SharedMemoryClear, -+ GatherA, -+ GatherB, -+ ScatterD, -+ PermuteDLayout> { -+ -+ static_assert((platform::is_same::value -+ || platform::is_same>::value), -+ "Epilogue in the kernel level must be row major"); -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, LayoutC, arch::OpClassSimt, arch::Sm80, -+ ThreadblockShape, WarpShape, GemmShape<1, 1, 1>, Stages, -+ Operator, false, SharedMemoryClear, GatherA, GatherB>::ThreadblockMma; -+ -+ static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount; -+ static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars"); -+ -+ /// Define the epilogue -+ using RegularEpilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ typename Mma::Operator, -+ EpilogueOutputOp, -+ kEpilogueElementsPerAccess, -+ ScatterD, -+ PermuteDLayout -+ >::Epilogue; -+ -+ using Affine2Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimtAffineRankN< -+ 2, -+ ThreadblockShape, -+ typename Mma::Operator, -+ EpilogueOutputOp, -+ kEpilogueElementsPerAccess -+ >::Epilogue; -+ -+ using Epilogue = typename platform::conditional::value, -+ RegularEpilogue, -+ Affine2Epilogue>::type; -+ -+ /// Define the kernel-level GEMM operator. 
-+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization for SIMT DP4A -+ -+template < -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Layout type for C matrix operand -+ typename LayoutC, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear -+> -+struct DefaultGemm, -+ EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, -+ Operator, SharedMemoryClear, false, false, false> { -+ using InstructionShape = GemmShape<1, 1, 4>; -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ -+ using OperatorClass = arch::OpClassSimt; -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma::ThreadblockMma; -+ -+ static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount; -+ static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars"); -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ typename Mma::Operator, -+ EpilogueOutputOp, -+ kEpilogueElementsPerAccess -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. 
-+ using GemmKernel = kernel::Gemm; -+}; -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+//////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization for Wmma Gemm Kernel -+template < -+ ///< Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear -+> -+struct DefaultGemm< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ arch::OpClassWmmaTensorOp, -+ ArchTag, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ SharedMemoryClear, -+ false, -+ false, -+ false -+> { -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, LayoutC, -+ arch::OpClassWmmaTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ Stages, -+ Operator>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp< -+ ThreadblockShape, -+ typename Mma::Operator, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::Gemm; -+}; -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_complex.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_complex.h -new file mode 100644 -index 0000000..956068b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_complex.h -@@ -0,0 +1,404 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+
-+/*! \file
-+    \brief
-+      Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
-+      the appropriate threadblock-scoped epilogue.
-+
-+      Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
-+      accommodated by exchanging A and B operands and assuming transposed layouts. Partial
-+      specializations here choose 'device::GemmTransposed' to implement this functionality.
-+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm.h" -+#include "cutlass/gemm/kernel/gemm_pipelined.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_multistage_mma_complex.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+ -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Multiply-add operator -+ // (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial -+> -+struct DefaultGemmComplex; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level 
tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Multiply-add operator -+ // (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial -+ > -+struct DefaultGemmComplex< -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, Operator, SplitKSerial> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, -+ layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, ThreadblockShape, -+ WarpShape, InstructionShape, Stages, TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp< -+ ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Multiply-add operator -+ // (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial -+ > -+struct DefaultGemmComplex< -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, ElementAccumulator, arch::OpClassSimt, -+ arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, Operator, SplitKSerial> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, -+ WarpShape, -+ 
InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassSimt, -+ Stages, -+ Operator, -+ false, -+ cutlass::arch::CacheOperation::Global, -+ cutlass::arch::CacheOperation::Global, -+ TransformA, -+ TransformB -+ >; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, -+ typename MmaCore::IteratorThreadMapA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, -+ typename MmaCore::IteratorThreadMapB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using Mma = cutlass::gemm::threadblock::MmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ layout::RowMajor, typename MmaCore::MmaPolicy>; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ typename Mma::Operator, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Multiply-add operator -+ // (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial -+ > -+struct DefaultGemmComplex< -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, Operator, SplitKSerial> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, -+ layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, -+ WarpShape, InstructionShape, Stages, TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename 
cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp< -+ ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Multiply-add operator -+ // (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial -+ > -+struct DefaultGemmComplex< -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, ElementAccumulator, arch::OpClassSimt, -+ arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, Operator, SplitKSerial> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, -+ layout::RowMajor, arch::OpClassSimt, arch::Sm80, ThreadblockShape, -+ WarpShape, InstructionShape, Stages, TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ typename Mma::Operator, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped.h -new file mode 100644 -index 0000000..c44f060 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped.h -@@ -0,0 +1,384 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+
-+/*! \file
-+    \brief
-+      Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
-+      the appropriate threadblock-scoped epilogue.
-+
-+      Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
-+      accommodated by exchanging A and B operands and assuming transposed layouts. Partial
-+      specializations here choose 'device::GemmTransposed' to implement this functionality.
-+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/kernel/gemm_grouped.h" -+#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/kernel/default_gemm_complex.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+#include "cutlass/layout/permute.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Whether the schedule of problems to visit has been precomputed -+ GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly, -+ /// Operation performed by GEMM -+ typename Operator = typename device::DefaultGemmConfiguration< -+ OperatorClass, ArchTag, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator>::Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, -+ /// Permute result D -+ typename PermuteDLayout = layout::NoPermute, -+ /// -+ typename Enable = void -+ > -+struct DefaultGemmGrouped; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued GEMM kernels -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator 
class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Whether the schedule of problems to visit has been precomputed -+ GroupScheduleMode GroupScheduleMode_, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear, -+ /// Permute result D -+ typename PermuteDLayout -+> -+struct DefaultGemmGrouped< -+ ElementA, -+ LayoutA, -+ ComplexTransform::kNone, // transform A -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ ComplexTransform::kNone, // transform B -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ GroupScheduleMode_, -+ Operator, -+ SharedMemoryClear, -+ PermuteDLayout, -+ typename platform::enable_if< ! cutlass::is_complex::value>::type -+> { -+ -+ // If true, we must construct a 'transposed-and-exchanged' Mma operator. -+ static bool const kInternalTranspose = platform::is_same::value; -+ -+ using MapArguments = kernel::detail::MapArguments< -+ ElementA, -+ LayoutA, -+ ComplexTransform::kNone, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ ComplexTransform::kNone, -+ kAlignmentB, -+ LayoutC, -+ kInternalTranspose -+ >; -+ -+ // Define the default GEMM kernel -+ using DefaultGemmKernel = typename kernel::DefaultGemm< -+ typename MapArguments::ElementA, -+ typename MapArguments::LayoutA, -+ MapArguments::kAlignmentA, -+ typename MapArguments::ElementB, -+ typename MapArguments::LayoutB, -+ MapArguments::kAlignmentB, -+ ElementC, -+ typename MapArguments::LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ true, -+ Operator, -+ SharedMemoryClear, -+ false, /*GatherA*/ -+ false, /*GatherB*/ -+ false, /*ScatterD*/ -+ PermuteDLayout -+ >::GemmKernel; -+ -+ /// Define the kernel in terms of the default kernel -+ using GemmKernel = kernel::GemmGrouped< -+ typename DefaultGemmKernel::Mma, -+ typename DefaultGemmKernel::Epilogue, -+ ThreadblockSwizzle, -+ GroupScheduleMode_, -+ kInternalTranspose -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Complex-valued GEMM kernels -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C 
and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Whether the schedule of problems to visit has been precomputed -+ GroupScheduleMode GroupScheduleMode_, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear -+ > -+struct DefaultGemmGrouped< -+ ElementA, -+ LayoutA, -+ TransformA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ TransformB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ GroupScheduleMode_, -+ Operator, -+ SharedMemoryClear, -+ layout::NoPermute, /*PermuteDLayout*/ -+ typename platform::enable_if::value>::type -+> { -+ -+ // If true, we must construct a 'transposed-and-exchanged' Mma operator. -+ static bool const kInternalTranspose = platform::is_same::value; -+ -+ using MapArguments = kernel::detail::MapArguments< -+ ElementA, -+ LayoutA, -+ TransformA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ TransformB, -+ kAlignmentB, -+ LayoutC, -+ kInternalTranspose -+ >; -+ -+ using DefaultGemmKernel = typename kernel::DefaultGemmComplex< -+ typename MapArguments::ElementA, -+ typename MapArguments::LayoutA, -+ typename MapArguments::ElementB, -+ typename MapArguments::LayoutB, -+ ElementC, -+ typename MapArguments::LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MapArguments::kTransformA, -+ MapArguments::kTransformB, -+ Operator, -+ false -+ >::GemmKernel; -+ -+ /// Define the kernel in terms of the default kernel -+ using GemmKernel = kernel::GemmGrouped< -+ typename DefaultGemmKernel::Mma, -+ typename DefaultGemmKernel::Epilogue, -+ ThreadblockSwizzle, -+ GroupScheduleMode_, -+ kInternalTranspose -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped_softmax_mainloop_fusion.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped_softmax_mainloop_fusion.h -new file mode 100644 -index 0000000..323ae5d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped_softmax_mainloop_fusion.h -@@ -0,0 +1,164 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level softmax-grouped-GEMM -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/kernel/gemm_grouped_softmax_mainloop_fusion.h" -+#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/kernel/default_gemm_complex.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/gemm/threadblock/default_mma_softmax_mainloop_fusion.h" -+ -+#include "cutlass/layout/permute.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for Scale/Bias vectors -+ typename ElementScaleBias_, -+ /// Layout type for Scale/Bias vectors -+ typename LayoutScaleBias_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation 
-+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Whether the schedule of problems to visit has been precomputed -+ GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly, -+ /// Operation performed by GEMM -+ typename Operator = typename device::DefaultGemmConfiguration< -+ OperatorClass, ArchTag, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator>::Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone -+ > -+struct DefaultGemmGroupedSoftmaxMainloopFusion { -+ // If true, we must construct a 'transposed-and-exchanged' Mma operator. -+ static bool const kInternalTranspose = platform::is_same::value; -+ -+ using MapArguments = kernel::detail::MapArguments< -+ ElementA_, -+ LayoutA_, -+ ComplexTransform::kNone, -+ kAlignmentA, -+ ElementB_, -+ LayoutB_, -+ ComplexTransform::kNone, -+ kAlignmentB, -+ LayoutC_, -+ kInternalTranspose -+ >; -+ -+private: -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMmaSoftmaxMainloopFusion< -+ typename MapArguments::ElementA, typename MapArguments::LayoutA, MapArguments::kAlignmentA, -+ typename MapArguments::ElementB, typename MapArguments::LayoutB, MapArguments::kAlignmentB, -+ ElementScaleBias_, LayoutScaleBias_, ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, kInternalTranspose, -+ Operator, false, SharedMemoryClear>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount>::Epilogue; -+ -+public: -+ using GemmKernel = kernel::GemmGroupedSoftmaxMainloopFusion< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ GroupScheduleMode_, -+ kInternalTranspose -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_layernorm_mainloop_fusion.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_layernorm_mainloop_fusion.h -new file mode 100644 -index 0000000..76a405a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_layernorm_mainloop_fusion.h -@@ -0,0 +1,137 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial -+ specializations here choose 'device::GemmTransposed' to implement this functionality. 
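The default_gemm_grouped.h header removed above declares DefaultGemmGrouped with the parameter list shown there. As a rough sketch of how such a kernel-level default is typically instantiated and handed to CUTLASS's device-level grouped GEMM, consider the following; the element types, layouts, alignments, tile shapes, swizzle, and stage count are illustrative assumptions patterned after CUTLASS's own grouped-GEMM example, not values taken from this patch.

    #include "cutlass/gemm/kernel/default_gemm_grouped.h"
    #include "cutlass/gemm/device/gemm_grouped.h"
    #include "cutlass/gemm/threadblock/threadblock_swizzle.h"
    #include "cutlass/epilogue/thread/linear_combination.h"

    // Kernel-level type built from the DefaultGemmGrouped declaration.
    using GroupedGemmKernel = cutlass::gemm::kernel::DefaultGemmGrouped<
        cutlass::half_t, cutlass::layout::ColumnMajor,      // ElementA, LayoutA
        cutlass::ComplexTransform::kNone, 8,                // TransformA, kAlignmentA
        cutlass::half_t, cutlass::layout::ColumnMajor,      // ElementB, LayoutB
        cutlass::ComplexTransform::kNone, 8,                // TransformB, kAlignmentB
        cutlass::half_t, cutlass::layout::ColumnMajor,      // ElementC, LayoutC
        float,                                              // ElementAccumulator
        cutlass::arch::OpClassTensorOp,                     // OperatorClass
        cutlass::arch::Sm80,                                // ArchTag
        cutlass::gemm::GemmShape<128, 128, 32>,             // ThreadblockShape
        cutlass::gemm::GemmShape<64, 64, 32>,               // WarpShape
        cutlass::gemm::GemmShape<16, 8, 16>,                // InstructionShape
        cutlass::epilogue::thread::LinearCombination<
            cutlass::half_t, 8, float, float>,              // EpilogueOutputOp
        cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
        4                                                   // Stages
        >::GemmKernel;

    // Device-level wrapper that launches a single grid over all problems in the group.
    using GroupedGemm = cutlass::gemm::device::GemmGrouped<GroupedGemmKernel>;

At run time, per-problem sizes and operand pointers are then supplied through the device type's Arguments structure.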
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h" -+#include "cutlass/gemm/threadblock/default_mma_layernorm_mainloop_fusion.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for Scale/Bias vectors -+ typename ElementScaleBias, -+ /// Layout type for Scale/Bias vectors -+ typename LayoutScaleBias, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone> -+struct DefaultGemmLayernormMainloopFusion { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMmaLayernormMainloopFusion< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementScaleBias, LayoutScaleBias, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator, false, SharedMemoryClear>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. 
-+ using GemmKernel = kernel::GemmLayernormMainloopFusion; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_planar_complex_universal.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_planar_complex_universal.h -new file mode 100644 -index 0000000..e3b58cb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_planar_complex_universal.h -@@ -0,0 +1,352 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial -+ specializations here choose 'device::GemmTransposed' to implement this functionality. 
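The default_gemm_planar_complex_universal.h header added below selects between a pipelined mainloop (Stages <= 2) and a multistage, cp.async-based mainloop (Stages > 2) by constraining its two partial specializations with enable_if, as shown further down in that file. A self-contained sketch of that dispatch pattern follows; it uses std::enable_if and hypothetical names purely for illustration, whereas the CUTLASS code uses its own platform::enable_if.

    #include <iostream>
    #include <type_traits>

    // Primary template: never instantiated directly; one of the two
    // enable_if-guarded specializations below is always selected.
    template <int Stages, typename Enable = void>
    struct MainloopSelector;

    // Chosen for shallow (double-buffered) pipelines.
    template <int Stages>
    struct MainloopSelector<Stages, typename std::enable_if<(Stages <= 2)>::type> {
      static constexpr const char *name() { return "pipelined"; }
    };

    // Chosen for deeper software pipelines (multistage mainloop).
    template <int Stages>
    struct MainloopSelector<Stages, typename std::enable_if<(Stages > 2)>::type> {
      static constexpr const char *name() { return "multistage"; }
    };

    int main() {
      std::cout << MainloopSelector<2>::name() << "\n";  // prints "pipelined"
      std::cout << MainloopSelector<4>::name() << "\n";  // prints "multistage"
      return 0;
    }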
-+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/kernel/gemm_planar_complex.h" -+#include "cutlass/gemm/kernel/gemm_planar_complex_array.h" -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/kernel/default_gemm_complex.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_planar_complex.h" -+#include "cutlass/gemm/threadblock/default_mma_planar_complex_pipelined.h" -+#include "cutlass/gemm/threadblock/default_mma_planar_complex_multistage.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Math operation performed by GEMM (e.g. 
arch::OpMultiplyAdd) -+ typename Operator, -+ /// Conditional enabling to switch between stages -+ typename Enable = void -+ > -+struct DefaultGemmPlanarComplexUniversal; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for pipelined mainloop -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator -+ > -+struct DefaultGemmPlanarComplexUniversal< -+ ElementA, -+ LayoutA, -+ TransformA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ TransformB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ Operator, -+ typename platform::enable_if<(Stages <= 2)>::type -+> { -+ -+ /// Define planar complex valued variants instead -+ using Mma = typename gemm::threadblock::DefaultMmaPlanarComplexPipelined< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ LayoutC, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ Stages, -+ TransformA, -+ TransformB, -+ Operator -+ >::ThreadblockMma; -+ -+ /// Planar complex epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpiloguePlanarComplex< -+ ThreadblockShape, -+ typename Mma::Policy::Operator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape::kK / WarpShape::kK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ /// Define the kernel in terms of the default kernel -+ using GemmKernel = kernel::GemmPlanarComplex< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle -+ >; -+ -+ // Array variant -+ using GemmArrayKernel = kernel::GemmPlanarComplexArray< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for multiple pipeline stages. 
-+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator -+ > -+struct DefaultGemmPlanarComplexUniversal< -+ ElementA, -+ LayoutA, -+ TransformA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ TransformB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ Operator, -+ typename platform::enable_if<(Stages > 2)>::type -+> { -+ -+ /// Define planar complex valued variants instead -+ using Mma = typename gemm::threadblock::DefaultMmaPlanarComplexMultistage< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ LayoutC, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ Stages, -+ TransformA, -+ TransformB, -+ Operator -+ >::ThreadblockMma; -+ -+ /// Planar complex epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpiloguePlanarComplex< -+ ThreadblockShape, -+ typename Mma::Policy::Operator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape::kK / WarpShape::kK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ /// Define the kernel in terms of the default kernel -+ using GemmKernel = kernel::GemmPlanarComplex< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle -+ >; -+ -+ // Array variant -+ using GemmArrayKernel = kernel::GemmPlanarComplexArray< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse.h -new file mode 100644 -index 0000000..7303e01 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse.h -@@ -0,0 +1,191 @@ -+/*************************************************************************************************** -+ 
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial -+ specializations here choose 'device::GemmTransposed' to implement this functionality. 
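The DefaultSparseGemm declaration that follows takes the parameter list shown just below (A, B, C operand descriptions, accumulator, operator class, architecture, tile shapes, epilogue, swizzle, stages, split-K flag, math operator). A minimal instantiation sketch for 2:4 structured sparsity on Sm80 might look like the following; the element types, alignments, tile shapes, and stage count are illustrative assumptions (structured-sparse tensor-core instructions use a doubled K extent, hence the 16x8x32 instruction shape for FP16), not values taken from this patch.

    #include "cutlass/gemm/kernel/default_gemm_sparse.h"
    #include "cutlass/gemm/threadblock/threadblock_swizzle.h"
    #include "cutlass/epilogue/thread/linear_combination.h"

    using SparseGemmKernel = cutlass::gemm::kernel::DefaultSparseGemm<
        cutlass::half_t, cutlass::layout::RowMajor, 8,      // A: element, layout, alignment
        cutlass::half_t, cutlass::layout::ColumnMajor, 8,   // B: element, layout, alignment
        cutlass::half_t, cutlass::layout::RowMajor,         // C/D: element, layout
        float,                                              // accumulator
        cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<128, 128, 64>,             // threadblock tile
        cutlass::gemm::GemmShape<64, 64, 64>,               // warp tile
        cutlass::gemm::GemmShape<16, 8, 32>,                // sparse tensor-core instruction
        cutlass::epilogue::thread::LinearCombination<
            cutlass::half_t, 8, float, float>,              // epilogue output op
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
        3,                                                  // stages
        false,                                              // SplitKSerial
        cutlass::arch::OpMultiplyAdd                        // math operator
        >::GemmKernel;

The resulting GemmKernel is the kernel::SparseGemm type composed at the end of the header, which additionally consumes the compressed A operand and its metadata tensor.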
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm.h" -+#include "cutlass/gemm/kernel/sparse_gemm.h" -+#include "cutlass/gemm/kernel/gemm_pipelined.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h" -+#include "cutlass/gemm/threadblock/default_sparse_mma.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultSparseGemm; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// 
Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultSparseGemm { -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultSparseMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::SparseGemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_splitk_parallel.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_splitk_parallel.h -new file mode 100644 -index 0000000..7fc9da3 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_splitk_parallel.h -@@ -0,0 +1,136 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial -+ specializations here choose 'device::GemmTransposed' to implement this functionality. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/kernel/gemm_splitk_parallel.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator -+> -+struct DefaultGemmSplitKParallel { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate using the basic GEMM's -+ /// mainloop. -+ using Default = DefaultGemm< -+ ElementA_, -+ LayoutA_, -+ kAlignmentA, -+ ElementB_, -+ LayoutB_, -+ kAlignmentB, -+ ElementAccumulator, -+ LayoutC_, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ false, -+ Operator -+ >; -+ -+ /// Define the matrix multiply operator -+ using Mma = typename Default::Mma; -+ -+ /// Define the epilogue -+ using Epilogue = typename Default::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. 
-+ using GemmKernel = kernel::GemmSplitKParallel; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_universal.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_universal.h -new file mode 100644 -index 0000000..45a825d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_universal.h -@@ -0,0 +1,382 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial -+ specializations here choose 'device::GemmTransposed' to implement this functionality. 
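The default_gemm_universal.h header added below composes DefaultGemm's mainloop and epilogue into either kernel::GemmUniversal or kernel::GemmUniversalStreamk, chosen by the SelectBase specializations further down according to whether the threadblock swizzle exposes a StreamkFeature member. A minimal sketch of instantiating the kernel-level default follows; the element types, layouts, alignments, tile shapes, and stage count are illustrative assumptions, and the device-level adapter named at the end is the usual CUTLASS 2.x wrapper rather than anything defined in this patch.

    #include "cutlass/gemm/kernel/default_gemm_universal.h"
    #include "cutlass/gemm/device/gemm_universal_adapter.h"
    #include "cutlass/gemm/threadblock/threadblock_swizzle.h"
    #include "cutlass/epilogue/thread/linear_combination.h"

    using UniversalGemmKernel = cutlass::gemm::kernel::DefaultGemmUniversal<
        cutlass::half_t, cutlass::layout::RowMajor,
        cutlass::ComplexTransform::kNone, 8,                // A: element, layout, transform, alignment
        cutlass::half_t, cutlass::layout::ColumnMajor,
        cutlass::ComplexTransform::kNone, 8,                // B: element, layout, transform, alignment
        cutlass::half_t, cutlass::layout::RowMajor,         // C/D: element, layout
        float,                                              // accumulator
        cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<128, 128, 32>,             // threadblock tile
        cutlass::gemm::GemmShape<64, 64, 32>,               // warp tile
        cutlass::gemm::GemmShape<16, 8, 16>,                // tensor-core instruction
        cutlass::epilogue::thread::LinearCombination<
            cutlass::half_t, 8, float, float>,              // epilogue output op
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
        4,                                                  // stages
        cutlass::arch::OpMultiplyAdd                        // math operator
        >::GemmKernel;

    // Assumed CUTLASS 2.x device-level wrapper around a kernel-level GEMM type.
    using UniversalGemm = cutlass::gemm::device::GemmUniversalAdapter<UniversalGemmKernel>;

With a stream-K-capable swizzle type in place of GemmIdentityThreadblockSwizzle, the same DefaultGemmUniversal would instead resolve to the GemmUniversalStreamk kernel via the SelectBase mechanism shown below.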
-+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/kernel/gemm_universal.h" -+#include "cutlass/gemm/kernel/gemm_universal_streamk.h" -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/kernel/default_gemm_complex.h" -+ -+#include "cutlass/layout/permute.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, -+ /// Gather operand A by using an index array -+ bool GatherA = false, -+ /// Gather operand B by using an index array -+ bool GatherB = false, -+ /// Scatter result D by using an index array -+ bool ScatterD = false, -+ /// Permute result D -+ typename PermuteDLayout = layout::NoPermute, -+ /// -+ typename Enable = void -+ > -+struct DefaultGemmUniversal; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued GEMM kernels -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// 
Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB, -+ /// Scatter result D by using an index array -+ bool ScatterD, -+ /// Permute result D -+ typename PermuteDLayout -+> -+struct DefaultGemmUniversal< -+ ElementA, -+ LayoutA, -+ ComplexTransform::kNone, // transform A -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ ComplexTransform::kNone, // transform B -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ Operator, -+ SharedMemoryClear, -+ GatherA, -+ GatherB, -+ ScatterD, -+ PermuteDLayout, -+ typename platform::enable_if< ! cutlass::is_complex::value>::type -+> { -+ -+ using DefaultGemmKernel = typename kernel::DefaultGemm< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ true, -+ Operator, -+ SharedMemoryClear, -+ GatherA, -+ GatherB, -+ ScatterD, -+ PermuteDLayout -+ >::GemmKernel; -+ -+ /// Universal kernel without StreamkFeature member type -+ template -+ class SelectBase : -+ public kernel::GemmUniversal< -+ typename DefaultGemmKernel::Mma, -+ typename DefaultGemmKernel::Epilogue, -+ SwizzleT> -+ {}; -+ -+ /// Universal kernel with StreamkFeature member type -+ template -+ class SelectBase : -+ public kernel::GemmUniversalStreamk< -+ typename DefaultGemmKernel::Mma, -+ typename DefaultGemmKernel::Epilogue, -+ SwizzleT> -+ {}; -+ -+ /// Select kernel by ThreadblockSwizzle's support for StreamkFeature -+ using GemmKernel = SelectBase; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Complex-valued GEMM kernels -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: 
GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear -+ > -+struct DefaultGemmUniversal< -+ ElementA, -+ LayoutA, -+ TransformA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ TransformB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ Operator, -+ SharedMemoryClear, -+ false, -+ false, -+ false, -+ layout::NoPermute, -+ typename platform::enable_if::value>::type -+> { -+ -+ using DefaultGemmKernel = typename kernel::DefaultGemmComplex< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ TransformA, -+ TransformB, -+ Operator, -+ false -+ >::GemmKernel; -+ -+ /// Universal kernel without StreamkFeature member type -+ template -+ class SelectBase : -+ public kernel::GemmUniversal< -+ typename DefaultGemmKernel::Mma, -+ typename DefaultGemmKernel::Epilogue, -+ SwizzleT> -+ {}; -+ -+ /// Universal kernel with StreamkFeature member type -+ template -+ class SelectBase : -+ public kernel::GemmUniversalStreamk< -+ typename DefaultGemmKernel::Mma, -+ typename DefaultGemmKernel::Epilogue, -+ SwizzleT> -+ {}; -+ -+ /// Select kernel by ThreadblockSwizzle's support for StreamkFeature -+ using GemmKernel = SelectBase; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_with_broadcast.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_with_broadcast.h -new file mode 100644 -index 0000000..1356b49 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_with_broadcast.h -@@ -0,0 +1,243 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Defines a GEMM with Reduction based on an existing UniversalGemm kernel. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/kernel/gemm_with_fused_epilogue.h" -+#include "cutlass/gemm/kernel/default_gemm_universal.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" -+#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// -+ typename Enable = void -+> -+struct DefaultGemmWithBroadcast { -+ -+ using GemmBase = typename DefaultGemmUniversal< -+ ElementA_, LayoutA_, TransformA, kAlignmentA, -+ ElementB_, LayoutB_, TransformB, kAlignmentB, -+ ElementC_, 
LayoutC_, ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ Operator -+ >::GemmKernel; -+ -+ // Replace epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithBroadcastTensorOp< -+ typename GemmBase::Epilogue::Shape, -+ typename GemmBase::Epilogue::WarpMmaOperator, -+ GemmBase::Epilogue::kPartitionsK, -+ ElementC_, -+ typename EpilogueOutputOp::ElementT, -+ ElementC_, -+ EpilogueOutputOp, -+ GemmBase::Epilogue::kElementsPerAccess -+ >::Epilogue; -+ -+ // Compose the GEMM kernel -+ using GemmKernel = GemmWithFusedEpilogue< -+ typename GemmBase::Mma, -+ Epilogue, -+ ThreadblockSwizzle -+ >; -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization: ArchTag = cutlass::arch::Sm70 -+/// -+/// -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// -+ typename Enable -+> -+struct DefaultGemmWithBroadcast< -+ ElementA_, LayoutA_, TransformA, kAlignmentA, -+ ElementB_, LayoutB_, TransformB, kAlignmentB, -+ ElementC_, LayoutC_, -+ ElementAccumulator, -+ OperatorClass, -+ cutlass::arch::Sm70, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ Operator, -+ Enable -+ > { -+ -+ using GemmBase = typename DefaultGemmUniversal< -+ ElementA_, LayoutA_, TransformA, kAlignmentA, -+ ElementB_, LayoutB_, TransformB, kAlignmentB, -+ ElementC_, LayoutC_, ElementAccumulator, -+ OperatorClass, -+ cutlass::arch::Sm70, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ Operator -+ >::GemmKernel; -+ -+ // Replace epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithBroadcastVoltaTensorOp< -+ typename GemmBase::Epilogue::Shape, -+ typename GemmBase::Epilogue::WarpMmaOperator, -+ GemmBase::Epilogue::kPartitionsK, -+ ElementC_, -+ typename EpilogueOutputOp::ElementT, -+ ElementC_, -+ EpilogueOutputOp, -+ GemmBase::Epilogue::kElementsPerAccess -+ >::Epilogue; -+ -+ // Compose the GEMM 
kernel -+ using GemmKernel = GemmWithFusedEpilogue< -+ typename GemmBase::Mma, -+ Epilogue, -+ ThreadblockSwizzle -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h -new file mode 100644 -index 0000000..422db5c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h -@@ -0,0 +1,150 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial -+ specializations here choose 'device::GemmTransposed' to implement this functionality. 
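As a quick check of the note above, the identity it relies on can be verified with a small host-only C++ sketch (plain CPU loops, not part of the patch): a column-major C = A * B occupies exactly the same memory image as a row-major C^T = B^T * A^T, which is why exchanging the A and B operands and assuming transposed layouts lets a row-major-only epilogue serve column-major outputs.

#include <array>
#include <cassert>

int main() {
  // 2x2 example: A = [[1,2],[3,4]], B = [[5,6],[7,8]], C = A * B.
  std::array<std::array<int, 2>, 2> A{{{1, 2}, {3, 4}}}, B{{{5, 6}, {7, 8}}};

  // C stored column-major: col_major[j*2 + i] = C(i,j).
  std::array<int, 4> col_major{};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 2; ++k)
        col_major[j * 2 + i] += A[i][k] * B[k][j];

  // C^T = B^T * A^T stored row-major: row_major[i*2 + j] = C(j,i).
  std::array<int, 4> row_major{};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 2; ++k)
        row_major[i * 2 + j] += B[k][i] * A[j][k];

  assert(col_major == row_major);  // identical memory images
  return 0;
}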
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_with_k_reduction.h" -+#include "cutlass/gemm/threadblock/default_mma_with_reduction.h" -+#include "cutlass/gemm/threadblock/default_mma_core_with_reduction.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Reduce A or B along the K dimension -+ bool ReduceKForA_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, -+ /// -+ typename Enable = void> -+struct DefaultGemmWithKReduction { -+ -+ static const bool kReduceKForA = (platform::is_same::value) ? 
ReduceKForA_ : !ReduceKForA_; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMmaWithReduction< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kReduceKForA, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator, false, SharedMemoryClear>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount>::Epilogue; -+ -+ /// Define the epilogue of the reduction vector -+ using EpilogueGemmKReduction = -+ typename cutlass::epilogue::threadblock::EpilogueGemmKReduction< -+ ElementAccumulator, ElementC, ThreadblockShape, typename Mma::Operator, kReduceKForA>; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::GemmWithKReduction; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_with_reduction.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_with_reduction.h -new file mode 100644 -index 0000000..6e9e647 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_with_reduction.h -@@ -0,0 +1,246 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
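DefaultGemmWithKReduction above derives kPartitionsK as ThreadblockShape::kK / WarpShape::kK. A minimal stand-alone sketch with assumed tile shapes (GemmShapeLike is illustrative, not the CUTLASS GemmShape) shows what that quotient means: a value greater than one indicates several warp-level K tiles stacked along the threadblock's K extent, whose partial accumulators must be reduced before the epilogue writes the output.

// Illustrative stand-in for a GemmShape-like tile description.
template <int M, int N, int K>
struct GemmShapeLike { static constexpr int kM = M, kN = N, kK = K; };

// Assumed tile sizes, chosen only to make the arithmetic visible.
using ThreadblockTile = GemmShapeLike<128, 128, 64>;
using WarpTile        = GemmShapeLike<64, 64, 32>;

// Number of warp-level K partitions stacked along the threadblock's K extent.
constexpr int kPartitionsK = ThreadblockTile::kK / WarpTile::kK;
static_assert(kPartitionsK == 2, "two K partitions -> partial sums must be reduced");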
-+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Defines a GEMM with Reduction based on an existing UniversalGemm kernel. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/kernel/gemm_with_fused_epilogue.h" -+#include "cutlass/gemm/kernel/default_gemm_universal.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_with_reduction.h" -+#include "cutlass/epilogue/threadblock/epilogue_with_reduction.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Epilogue reduction operator -+ typename EpilogueReductionOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// -+ typename Enable = void -+> -+struct DefaultGemmWithReduction { -+ -+ using GemmBase = typename DefaultGemmUniversal< -+ ElementA_, LayoutA_, TransformA, kAlignmentA, -+ ElementB_, LayoutB_, TransformB, kAlignmentB, -+ ElementC_, LayoutC_, ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ Operator, -+ SharedMemoryClearOption::kClearLastStage -+ >::GemmKernel; -+ -+ // Replace epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ typename GemmBase::Epilogue::Shape, -+ typename GemmBase::Epilogue::WarpMmaOperator, -+ GemmBase::Epilogue::kPartitionsK, -+ ElementC_, -+ EpilogueOutputOp, -+ EpilogueReductionOp, -+ GemmBase::Epilogue::kElementsPerAccess -+ >::Epilogue; -+ -+ // Compose the GEMM kernel -+ using GemmKernel = GemmWithFusedEpilogue< -+ typename GemmBase::Mma, -+ Epilogue, -+ ThreadblockSwizzle -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization: ArchTag = cutlass::arch::Sm70 -+/// -+/// -+template < -+ /// Element type for A matrix operand -+ 
typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Epilogue reduction operator -+ typename EpilogueReductionOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// -+ typename Enable -+> -+struct DefaultGemmWithReduction< -+ ElementA_, LayoutA_, TransformA, kAlignmentA, -+ ElementB_, LayoutB_, TransformB, kAlignmentB, -+ ElementC_, LayoutC_, -+ ElementAccumulator, -+ OperatorClass, -+ cutlass::arch::Sm70, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ EpilogueReductionOp, -+ ThreadblockSwizzle, -+ Stages, -+ Operator, -+ Enable -+ > { -+ -+ using GemmBase = typename DefaultGemmUniversal< -+ ElementA_, LayoutA_, TransformA, kAlignmentA, -+ ElementB_, LayoutB_, TransformB, kAlignmentB, -+ ElementC_, LayoutC_, ElementAccumulator, -+ OperatorClass, -+ cutlass::arch::Sm70, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ Operator -+ >::GemmKernel; -+ -+ // Replace epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionVoltaTensorOp< -+ typename GemmBase::Epilogue::Shape, -+ typename GemmBase::Epilogue::WarpMmaOperator, -+ GemmBase::Epilogue::kPartitionsK, -+ ElementC_, -+ EpilogueOutputOp, -+ EpilogueReductionOp, -+ GemmBase::Epilogue::kElementsPerAccess -+ >::Epilogue; -+ -+ // Compose the GEMM kernel -+ using GemmKernel = GemmWithFusedEpilogue< -+ typename GemmBase::Mma, -+ Epilogue, -+ ThreadblockSwizzle -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemv.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemv.h -new file mode 100755 -index 0000000..263930c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemv.h -@@ -0,0 +1,132 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
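The Sm70 path above is selected purely by specializing on the architecture tag, which is how the Volta-flavored reduction epilogue is substituted without touching the primary template. A minimal stand-alone sketch of that dispatch pattern, with illustrative type names rather than CUTLASS APIs:

#include <type_traits>

struct Sm70 {};  // stand-in architecture tags
struct Sm80 {};

struct TensorOpEpilogue {};       // stand-in epilogue types
struct VoltaTensorOpEpilogue {};

// Primary template: the generic tensor-op epilogue.
template <typename ArchTag>
struct DefaultEpilogueFor {
  using Epilogue = TensorOpEpilogue;
};

// Specialization on the Volta tag swaps in the Volta-specific epilogue.
template <>
struct DefaultEpilogueFor<Sm70> {
  using Epilogue = VoltaTensorOpEpilogue;
};

static_assert(std::is_same_v<DefaultEpilogueFor<Sm70>::Epilogue, VoltaTensorOpEpilogue>);
static_assert(std::is_same_v<DefaultEpilogueFor<Sm80>::Epilogue, TensorOpEpilogue>);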
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "cutlass/gemm/threadblock/gemv.h" -+#include "cutlass/gemm/threadblock/default_gemv_core.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the ThreadBlock tile - concept: gemm::GemmShape<> -+ typename ThreadBlockShape_, -+ /// Size of the per-thread shape - concept: gemm::GemmShape<> -+ typename ThreadShape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C/D matrix -+ typename ElementCD_, -+ /// Layout of C/D matrix (concept: MatrixLayout) -+ typename LayoutCD_, -+ /// Data type of the accumulator -+ typename ElementAccumulator_ = ElementCD_> -+struct DefaultGemv { -+ -+ /// Shape of Threadblock-level matrix operation (concept: GemmShape) -+ using ThreadBlockShape = ThreadBlockShape_; -+ -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using ThreadShape = ThreadShape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = ElementA_; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = ElementB_; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulators -+ using ElementAccumulator = ElementAccumulator_; -+ -+ /// Data type of accumulators (same as C/D) -+ using LayoutAccumulator = LayoutCD_; -+ -+ /// Data type of input/output matrix C/D -+ using ElementCD = ElementCD_; -+ -+ /// Layout of input/output matrix C/D -+ using 
LayoutCD = LayoutCD_; -+ -+ // Define the core components -+ using Core = typename cutlass::gemm::threadblock::DefaultGemvCore< -+ ThreadBlockShape, ThreadShape, ElementA, LayoutA, ElementB, LayoutB, -+ ElementAccumulator, LayoutAccumulator>; -+ -+ // Define the threadblock-scoped gemv -+ using ThreadBlockGemv = cutlass::gemm::threadblock::Gemv; -+ -+ // Iterator for multiplicand A -+ using IteratorA = typename ThreadBlockGemv::IteratorA; -+ -+ // Iterator for multiplicand B -+ using IteratorB = typename ThreadBlockGemv::IteratorB; -+ -+ /// Policy for the iterator that reads/writes C/D -+ using IteratorPolicyCD = typename platform::conditional< -+ platform::is_same::value, -+ cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous< -+ layout::PitchLinearShape, Core::kThreadsPerN, ThreadShape::kN>, -+ cutlass::transform::PitchLinearTilePolicyStripminedThreadStrided< -+ layout::PitchLinearShape, Core::kThreadsPerN, ThreadShape::kM>>::type; -+ -+ /// Iterator that reads/writes C/D -+ using IteratorCD = cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, ElementCD, LayoutCD, 0, IteratorPolicyCD>; -+ -+ /// Fragment storage for C/D -+ using FragmentCD = typename IteratorCD::Fragment; -+ -+ // Define the threadblock swizzle -+ using ThreadBlockSwizzle = cutlass::gemm::threadblock::GemvBatchedStridedThreadblockDefaultSwizzle; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k.h -new file mode 100644 -index 0000000..4573a3a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k.h -@@ -0,0 +1,285 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
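In DefaultGemv above, IteratorPolicyCD is picked with platform::conditional between a thread-contiguous and a thread-strided pitch-linear policy; the comparison's template arguments are elided in the hunk, but the policy names suggest the choice tracks the C/D layout. A sketch of that selection with the standard-library equivalent, assuming (not confirmed by the hunk) that a row-major C/D takes the thread-contiguous policy:

#include <type_traits>

struct RowMajorLayout {};           // illustrative layout tags
struct ColumnMajorLayout {};
struct ThreadContiguousPolicy {};   // illustrative policy stand-ins
struct ThreadStridedPolicy {};

// Assumed mapping: row-major C/D -> thread-contiguous, otherwise thread-strided.
template <typename LayoutCD>
using IteratorPolicyCD =
    std::conditional_t<std::is_same_v<LayoutCD, RowMajorLayout>,
                       ThreadContiguousPolicy,
                       ThreadStridedPolicy>;

static_assert(std::is_same_v<IteratorPolicyCD<RowMajorLayout>, ThreadContiguousPolicy>);
static_assert(std::is_same_v<IteratorPolicyCD<ColumnMajorLayout>, ThreadStridedPolicy>);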
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level Rank2K definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_universal.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in 
the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Blas3 computation mode -+ BlasMode BlasMode_ = BlasMode::kSymmetric> -+struct DefaultRank2K; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultRank2K< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC,layout::RowMajor, FillModeC, -+ ElementAccumulator, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial, -+ Operator> { -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x BT) -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, LayoutA, -+ kAlignmentA, -+ ElementB, typename layout::LayoutTranspose::type, -+ kAlignmentB, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (B x AT) -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementB, LayoutB, -+ kAlignmentB, -+ ElementA, typename layout::LayoutTranspose::type, -+ kAlignmentA, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpBlas3< -+ ThreadblockShape, typename Mma1::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, BlasMode::kSymmetric>::Epilogue; -+ -+ /// Define the kernel-level Rank2K operator. 
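The Rank2K kernels above pair two threadblock-scoped mainloops, one for A x B^T and one for B x A^T, and only ever write a single fill mode of C. A short host-side check (plain C++, arbitrary values, not part of the patch) of the property that makes this sufficient: the combined update A*B^T + B*A^T is symmetric, so the lower or upper triangle determines the whole matrix.

#include <array>
#include <cassert>

int main() {
  constexpr int n = 2, k = 3;
  std::array<std::array<double, k>, n> A{{{1, 2, 3}, {4, 5, 6}}};
  std::array<std::array<double, k>, n> B{{{7, 8, 9}, {1, 0, 2}}};

  std::array<std::array<double, n>, n> C{};
  for (int i = 0; i < n; ++i)
    for (int j = 0; j < n; ++j)
      for (int p = 0; p < k; ++p)
        C[i][j] += A[i][p] * B[j][p]    // (A * B^T)(i,j)
                 + B[i][p] * A[j][p];   // (B * A^T)(i,j)

  assert(C[0][1] == C[1][0]);  // the rank-2k update is symmetric
  return 0;
}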
-+ using Rank2Kkernel = kernel::Rank2KUniversal; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultRank2K< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC,layout::RowMajor, FillModeC, -+ ElementAccumulator, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial, -+ Operator> { -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x BT) -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, LayoutA, -+ kAlignmentA, -+ ElementB, typename layout::LayoutTranspose::type, -+ kAlignmentB, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (B x AT) -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementB, LayoutB, -+ kAlignmentB, -+ ElementA, typename layout::LayoutTranspose::type, -+ kAlignmentA, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpBlas3< -+ ThreadblockShape, typename Mma1::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, BlasMode::kSymmetric>::Epilogue; -+ -+ /// Define the kernel-level Rank2K operator. 
-+ using Rank2Kkernel = kernel::Rank2KUniversal; -+}; -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k_complex.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k_complex.h -new file mode 100644 -index 0000000..dc34fe9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k_complex.h -@@ -0,0 +1,498 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level Rank2K definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. 
-+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_universal.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_multistage_mma_complex.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op_blas3.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Blas3 computation mode -+ BlasMode BlasMode_ = BlasMode::kSymmetric> -+struct DefaultRank2KComplex; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+namespace detail { -+ -+template < -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Complex elementwise transformation -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation -+ ComplexTransform TransformB, -+ /// Blas3 computation mode (symmetric/hermitian) -+ BlasMode BlasMode_ -+ > struct Rank2KTransposedComplexTransform { -+ -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ -+}; -+ -+ // partial specializations for HER2K CUBLAS_OP_N layout (ColumMajor) -+template <> -+ struct Rank2KTransposedComplexTransform < -+ layout::ColumnMajor, layout::ColumnMajor, -+ 
ComplexTransform::kNone, ComplexTransform::kNone, -+ BlasMode::kHermitian> { -+ -+ static ComplexTransform const kTransformA = ComplexTransform::kConjugate; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+}; -+ -+ // partial specializations for HER2K CUBLAS_OP_C layout (RowMajor + Complex conjugate) -+template <> -+ struct Rank2KTransposedComplexTransform < -+ layout::RowMajor, layout::RowMajor, -+ ComplexTransform::kConjugate, ComplexTransform::kConjugate, -+ BlasMode::kHermitian> { -+ -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kConjugate; -+ -+}; -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture complex datatype (symmetric) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultRank2KComplex< -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ TransformA, TransformB, Operator, SplitKSerial, BlasMode::kSymmetric> { -+ -+ static BlasMode const kBlasMode = BlasMode::kSymmetric; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x B^T) -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementA, LayoutA, -+ ElementB, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (B x A^T) -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementB, LayoutB, -+ ElementA, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3< -+ ThreadblockShape, typename Mma1::Operator, 1, 
EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue; -+ -+ /// Define the kernel-level Rank2K operator. -+ using Rank2Kkernel = kernel::Rank2KUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture complex datatype (hermitian) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultRank2KComplex< -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ TransformA, TransformB, Operator, SplitKSerial, BlasMode::kHermitian> { -+ -+ static BlasMode const kBlasMode = BlasMode::kHermitian; -+ -+ // Complex transform for input A and B matrices (function on input layout) -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ -+ using TransposedComplexTransform = detail::Rank2KTransposedComplexTransform< -+ LayoutA, LayoutB, -+ TransformA, TransformB, -+ kBlasMode>; -+ -+ // Complex transform on operandA and operandB (function of blas3 computation) -+ static ComplexTransform const kTransformOperandA = TransposedComplexTransform::kTransformA; -+ static ComplexTransform const kTransformOperandB = TransposedComplexTransform::kTransformB; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x B^H) -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementA, LayoutA, -+ ElementB, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ kTransformOperandA, kTransformOperandB, Operator>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (B x A^H) -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementB, LayoutB, -+ ElementA, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ kTransformOperandA, kTransformOperandB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ 
using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3< -+ ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue; -+ -+ /// Define the kernel-level Rank2K operator. -+ using Rank2Kkernel = kernel::Rank2KUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture complex datatype (symmetric) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultRank2KComplex< -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ TransformA, TransformB, Operator, SplitKSerial, BlasMode::kSymmetric> { -+ -+ static BlasMode const kBlasMode = BlasMode::kSymmetric; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x B^T) -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementA, LayoutA, -+ ElementB, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (B x A^T) -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementB, LayoutB, -+ ElementA, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3< -+ ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue; -+ -+ /// Define the kernel-level Rank2K operator. 
-+ using Rank2Kkernel = kernel::Rank2KUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture complex datatype (hermitian) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultRank2KComplex< -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ TransformA, TransformB, Operator, SplitKSerial, BlasMode::kHermitian> { -+ -+ static BlasMode const kBlasMode = BlasMode::kHermitian; -+ -+ // Complex transform for input A and B matrices (function on input layout) -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ -+ using TransposedComplexTransform = detail::Rank2KTransposedComplexTransform< -+ LayoutA, LayoutB, -+ TransformA, TransformB, -+ kBlasMode>; -+ -+ // Complex transform on operandA and operandB (function of blas3 computation) -+ static ComplexTransform const kTransformOperandA = TransposedComplexTransform::kTransformA; -+ static ComplexTransform const kTransformOperandB = TransposedComplexTransform::kTransformB; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x B^H) -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementA, LayoutA, -+ ElementB, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ kTransformOperandA, kTransformOperandB, Operator>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (B x A^H) -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementB, LayoutB, -+ ElementA, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ kTransformOperandA, kTransformOperandB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3< -+ ThreadblockShape, typename 
Mma1::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue; -+ -+ /// Define the kernel-level Rank2K operator. -+ using Rank2Kkernel = kernel::Rank2KUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k_grouped.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k_grouped.h -new file mode 100644 -index 0000000..a237125 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k_grouped.h -@@ -0,0 +1,355 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level grouped Rank2K. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/kernel/rank_2k_transpose_operands.h" -+#include "cutlass/gemm/kernel/default_rank_2k.h" -+#include "cutlass/gemm/kernel/default_rank_2k_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Blas3 computation mode -+ BlasMode BlasMode_ = BlasMode::kSymmetric, -+ /// Whether the schedule of problems to visit has been precomputed -+ GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly, -+ /// -+ typename Enable = void -+ > -+struct DefaultRank2KGrouped; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued grouped Rank2K -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// 
Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Blas3 computation mode -+ BlasMode BlasMode_, -+ /// Whether the schedule of problems to visit has been precomputed -+ GroupScheduleMode GroupScheduleMode_ -+ > -+struct DefaultRank2KGrouped::value>::type -+> { -+ // If true, we must construct a 'transposed-and-exchanged' Rank2K operator. -+ static bool const kInternalTranspose = platform::is_same::value; -+ -+ using MapArguments = kernel::detail::Rank2KMapArguments< -+ ElementA, -+ LayoutA, -+ TransformA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ TransformB, -+ kAlignmentB, -+ LayoutC, -+ FillModeC, -+ kInternalTranspose -+ >; -+ -+ // Define the default grouped Rank2K kernel -+ using DefaultRank2Kkernel = typename kernel::DefaultRank2K< -+ typename MapArguments::ElementA, -+ typename MapArguments::LayoutA, -+ MapArguments::kAlignmentA, -+ typename MapArguments::ElementB, -+ typename MapArguments::LayoutB, -+ MapArguments::kAlignmentB, -+ ElementC, -+ typename MapArguments::LayoutC, -+ MapArguments::kFillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ false, // SplitKSerial -+ Operator, -+ BlasMode_ -+ >::Rank2Kkernel; -+ -+ /// Define the kernel in terms of the default kernel -+ using Rank2Kkernel = kernel::Rank2KGrouped< -+ typename DefaultRank2Kkernel::Mma1, -+ typename DefaultRank2Kkernel::Mma2, -+ typename DefaultRank2Kkernel::Epilogue, -+ ThreadblockSwizzle, -+ TransformA, -+ TransformB, -+ DefaultRank2Kkernel::kFillModeC, -+ DefaultRank2Kkernel::kBlasMode, -+ GroupScheduleMode_, -+ kInternalTranspose -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Complex-valued grouped Rank2K -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename 
EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Blas3 computation mode -+ BlasMode BlasMode_, -+ /// Whether the schedule of problems to visit has been precomputed -+ GroupScheduleMode GroupScheduleMode_ -+ > -+struct DefaultRank2KGrouped::value>::type -+> { -+ // If true, we must construct a 'transposed-and-exchanged' Rank2K operator. -+ static bool const kInternalTranspose = platform::is_same::value; -+ -+ using MapArguments = kernel::detail::Rank2KMapArguments< -+ ElementA, -+ LayoutA, -+ TransformA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ TransformB, -+ kAlignmentB, -+ LayoutC, -+ FillModeC, -+ kInternalTranspose -+ >; -+ -+ // Define the default grouped Rank2K kernel -+ using DefaultRank2Kkernel = typename kernel::DefaultRank2KComplex< -+ typename MapArguments::ElementA, -+ typename MapArguments::LayoutA, -+ typename MapArguments::ElementB, -+ typename MapArguments::LayoutB, -+ ElementC, -+ typename MapArguments::LayoutC, -+ MapArguments::kFillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MapArguments::kTransformA, -+ MapArguments::kTransformB, -+ Operator, -+ false, // SplitKSerial -+ BlasMode_ -+ >::Rank2Kkernel; -+ -+ /// Define the kernel in terms of the default kernel -+ /// Pass through the user-provided TransformA and TransformB so as to -+ /// correctly set public-facing TransformA and TransformB in kernel::Rank2KGrouped. -+ /// This is needed because kernel::DefaultRank2KComplex may change TransformA and -+ /// TransformB that become template arguments to Mma1 and Mma2. -+ using Rank2Kkernel = kernel::Rank2KGrouped< -+ typename DefaultRank2Kkernel::Mma1, -+ typename DefaultRank2Kkernel::Mma2, -+ typename DefaultRank2Kkernel::Epilogue, -+ ThreadblockSwizzle, -+ TransformA, -+ TransformB, -+ DefaultRank2Kkernel::kFillModeC, -+ DefaultRank2Kkernel::kBlasMode, -+ GroupScheduleMode_, -+ kInternalTranspose -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k_universal.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k_universal.h -new file mode 100644 -index 0000000..9651300 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k_universal.h -@@ -0,0 +1,346 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level Rank 2k definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/gemm/kernel/rank_2k_universal.h" -+#include "cutlass/gemm/kernel/default_rank_2k.h" -+#include "cutlass/gemm/kernel/default_rank_2k_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ 
/// epilogue -+ bool SplitKSerial, -+ /// Operation performed by SYRK -+ typename Operator, -+ /// Blas3 computation mode (symmetric/hermitian) -+ BlasMode BlasMode_ = BlasMode::kSymmetric, -+ /// -+ typename Enable = void -+ > -+struct DefaultRank2KUniversal; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued Rank 2k update kernels -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by Rank2k -+ typename Operator> -+struct DefaultRank2KUniversal< -+ ElementA, -+ LayoutA, -+ ComplexTransform::kNone, // transform A -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ ComplexTransform::kNone, // transform B -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ FillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ BlasMode::kSymmetric, -+ typename std::enable_if< ! 
cutlass::is_complex::value>::type -+> { -+ -+ using DefaultRank2Kkernel = typename kernel::DefaultRank2K< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ FillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ BlasMode::kSymmetric -+ >::Rank2Kkernel; -+ -+ /// Define the kernel in terms of the default kernel -+ using Rank2Kkernel = kernel::Rank2KUniversal< -+ typename DefaultRank2Kkernel::Mma1, -+ typename DefaultRank2Kkernel::Mma2, -+ typename DefaultRank2Kkernel::Epilogue, -+ ThreadblockSwizzle, -+ FillModeC, -+ BlasMode::kSymmetric -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Complex-valued Rank 2K update kernels -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by SYRK -+ typename Operator, -+ // BlasMode -+ BlasMode kBlasMode -+ > -+ -+struct DefaultRank2KUniversal< -+ ElementA, -+ LayoutA, -+ TransformA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ TransformB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ FillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ kBlasMode, -+ typename std::enable_if::value>::type -+> { -+ -+ using DefaultRank2Kkernel = typename kernel::DefaultRank2KComplex< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ FillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ TransformA, -+ TransformB, -+ Operator, -+ SplitKSerial, -+ kBlasMode -+ >::Rank2Kkernel; -+ -+ /// Define the kernel in terms of the default kernel -+ using Rank2Kkernel = kernel::Rank2KUniversal< -+ 
typename DefaultRank2Kkernel::Mma1, -+ typename DefaultRank2Kkernel::Mma2, -+ typename DefaultRank2Kkernel::Epilogue, -+ ThreadblockSwizzle, -+ FillModeC, -+ kBlasMode -+ >; -+}; -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_k.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_k.h -new file mode 100644 -index 0000000..2c0c7a8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_k.h -@@ -0,0 +1,247 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level RankK definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. 
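// Illustrative aside, not from the CUTLASS sources: the kernels declared in
// this header compute a SYRK-style rank-k update, roughly: the threadblock-scoped
// Mma produces tiles of A * A^T and the epilogue applies alpha/beta to one
// triangle of C. A scalar reference of that computation, assuming row-major
// double storage and a hypothetical function name:

    #include <cstddef>

    // C := alpha * A * A^T + beta * C, updating only the lower triangle of C.
    inline void reference_syrk_lower(std::size_t n, std::size_t k, double alpha,
                                     const double *A,          // n x k, row-major
                                     double beta, double *C) { // n x n, row-major
      for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t j = 0; j <= i; ++j) {      // FillMode kLower
          double acc = 0.0;
          for (std::size_t t = 0; t < k; ++t) {
            acc += A[i * k + t] * A[j * k + t];     // (A * A^T)(i,j)
          }
          C[i * n + j] = alpha * acc + beta * C[i * n + j];
        }
      }
    }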
-+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_k_universal.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Blas3 computation mode -+ BlasMode BlasMode_ = BlasMode::kSymmetric> -+struct DefaultRankK; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output 
operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultRankK< -+ ElementA, LayoutA, kAlignmentA, -+ ElementC,layout::RowMajor, FillModeC, -+ ElementAccumulator, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial, -+ Operator> { -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x AT) -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, LayoutA, -+ kAlignmentA, -+ ElementA, typename layout::LayoutTranspose::type, -+ kAlignmentA, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator>::ThreadblockMma; -+ -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpBlas3< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, BlasMode::kSymmetric>::Epilogue; -+ -+ /// Define the kernel-level Rank2 operator. -+ using RankKkernel = kernel::RankKUniversal; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultRankK< -+ ElementA, LayoutA, kAlignmentA, -+ ElementC,layout::RowMajor, FillModeC, -+ ElementAccumulator, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial, -+ Operator> { -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x AT) -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, LayoutA, -+ kAlignmentA, -+ ElementA, typename layout::LayoutTranspose::type, -+ kAlignmentA, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator>::ThreadblockMma; -+ -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename 
cutlass::epilogue::threadblock::DefaultEpilogueTensorOpBlas3< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, BlasMode::kSymmetric>::Epilogue; -+ -+ /// Define the kernel-level Rank2 operator. -+ using RankKkernel = kernel::RankKUniversal; -+}; -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_k_complex.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_k_complex.h -new file mode 100644 -index 0000000..d7569a9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_k_complex.h -@@ -0,0 +1,429 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level RankK definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. 
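// Illustrative aside, not from the CUTLASS sources: the kHermitian
// specializations in this header cover HERK, C := alpha * A * A^H + beta * C
// with real alpha/beta, C Hermitian and only one triangle stored; the
// kTransformOperandA/B logic below decides which operand carries the
// conjugation depending on the input layout. A scalar reference of the
// computation itself, using std::complex and hypothetical names:

    #include <cstddef>
    #include <complex>

    // Lower-triangle HERK reference: C := alpha * A * A^H + beta * C.
    inline void reference_herk_lower(std::size_t n, std::size_t k, double alpha,
                                     const std::complex<double> *A,          // n x k, row-major
                                     double beta, std::complex<double> *C) { // n x n, row-major
      for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t j = 0; j <= i; ++j) {
          std::complex<double> acc(0.0, 0.0);
          for (std::size_t t = 0; t < k; ++t) {
            acc += A[i * k + t] * std::conj(A[j * k + t]);   // (A * A^H)(i,j)
          }
          std::complex<double> out = alpha * acc + beta * C[i * n + j];
          // The diagonal of a Hermitian result is real; drop rounding noise there.
          C[i * n + j] = (i == j) ? std::complex<double>(out.real(), 0.0) : out;
        }
      }
    }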
-+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_k_universal.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_multistage_mma_complex.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op_blas3.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Blas3 computation mode -+ BlasMode BlasMode_ = BlasMode::kSymmetric> -+struct DefaultRankKComplex; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+namespace detail { -+ -+template < -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Complex elementwise transformation -+ ComplexTransform TransformA, -+ /// Blas3 computation mode (symmetric/hermitian) -+ BlasMode BlasMode_ -+ > struct RankKTransposedComplexTransform { -+ -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformA; -+ -+}; -+ -+ // partial specializations for HERK CUBLAS_OP_N layout (ColumMajor) -+template <> -+ struct RankKTransposedComplexTransform < -+ layout::ColumnMajor, -+ ComplexTransform::kNone, -+ BlasMode::kHermitian> { -+ -+ static ComplexTransform const kTransformA = ComplexTransform::kConjugate; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+}; -+ -+ // partial specializations for HERK CUBLAS_OP_C layout (RowMajor + Complex conjugate) -+template <> -+ struct RankKTransposedComplexTransform < -+ layout::RowMajor, -+ 
ComplexTransform::kConjugate, -+ BlasMode::kHermitian> { -+ -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kConjugate; -+ -+}; -+ -+} -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture complex datatype (symmetric) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultRankKComplex< -+ ElementA, LayoutA, ElementC, -+ layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ TransformA, Operator, SplitKSerial, BlasMode::kSymmetric> { -+ -+ static BlasMode const kBlasMode = BlasMode::kSymmetric; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x B^T) -+ using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementA, LayoutA, -+ ElementA, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ TransformA, TransformA, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3< -+ ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue; -+ -+ /// Define the kernel-level RankK operator. 
-+ using RankKkernel = kernel::RankKUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture complex datatype (hermitian) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultRankKComplex< -+ ElementA, LayoutA, ElementC, -+ layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ TransformA, Operator, SplitKSerial, BlasMode::kHermitian> { -+ -+ static BlasMode const kBlasMode = BlasMode::kHermitian; -+ -+ // Complex transform for input A and B matrices (function on input layout) -+ static ComplexTransform const kTransformA = TransformA; -+ -+ using TransposedComplexTransform = detail::RankKTransposedComplexTransform< -+ LayoutA, -+ TransformA, -+ kBlasMode>; -+ -+ // Complex transform on operandA and operandB (function of blas3 computation) -+ static ComplexTransform const kTransformOperandA = TransposedComplexTransform::kTransformA; -+ static ComplexTransform const kTransformOperandB = TransposedComplexTransform::kTransformB; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x A^H) -+ using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementA, LayoutA, -+ ElementA, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ kTransformOperandA, kTransformOperandB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3< -+ ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue; -+ -+ /// Define the kernel-level RankK operator. 
-+ using RankKkernel = kernel::RankKUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture complex datatype (symmetric) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultRankKComplex< -+ ElementA, LayoutA, ElementC, -+ layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ TransformA, Operator, SplitKSerial, BlasMode::kSymmetric> { -+ -+ static BlasMode const kBlasMode = BlasMode::kSymmetric; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x B^T) -+ using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementA, LayoutA, -+ ElementA, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ TransformA, TransformA, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3< -+ ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue; -+ -+ /// Define the kernel-level RankK operator. 
-+ using RankKkernel = kernel::RankKUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture complex datatype (hermitian) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultRankKComplex< -+ ElementA, LayoutA, ElementC, -+ layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ TransformA, Operator, SplitKSerial, BlasMode::kHermitian> { -+ -+ static BlasMode const kBlasMode = BlasMode::kHermitian; -+ -+ // Complex transform for input A and B matrices (function on input layout) -+ static ComplexTransform const kTransformA = TransformA; -+ -+ using TransposedComplexTransform = detail::RankKTransposedComplexTransform< -+ LayoutA, -+ TransformA, -+ kBlasMode>; -+ -+ // Complex transform on operandA and operandB (function of blas3 computation) -+ static ComplexTransform const kTransformOperandA = TransposedComplexTransform::kTransformA; -+ static ComplexTransform const kTransformOperandB = TransposedComplexTransform::kTransformB; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x A^H) -+ using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementA, LayoutA, -+ ElementA, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ kTransformOperandA, kTransformOperandB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3< -+ ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue; -+ -+ /// Define the kernel-level RankK operator. 
-+ using RankKkernel = kernel::RankKUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_k_universal.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_k_universal.h -new file mode 100644 -index 0000000..b8ce45c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_k_universal.h -@@ -0,0 +1,305 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level Rank k definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. 
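// Illustrative aside, not from the CUTLASS sources: the remark above about
// row-major epilogues rests on a storage identity -- a column-major C and the
// row-major storage of C^T occupy identical bytes, and C^T = B^T * A^T -- so a
// column-major output is obtained by exchanging/transposing the operands and
// letting the kernel write row-major. A tiny self-contained check of that
// identity (dense double buffers; all names chosen for illustration):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    int main() {
      const std::size_t m = 2, n = 3, kk = 4;
      std::vector<double> A(m * kk), B(kk * n);           // row-major inputs
      for (std::size_t i = 0; i < A.size(); ++i) A[i] = 1.0 + double(i);
      for (std::size_t i = 0; i < B.size(); ++i) B[i] = 0.5 * double(i);

      // C = A * B stored column-major: element (i, j) lives at C[j * m + i].
      std::vector<double> C_col(m * n, 0.0);
      for (std::size_t j = 0; j < n; ++j)
        for (std::size_t i = 0; i < m; ++i)
          for (std::size_t t = 0; t < kk; ++t)
            C_col[j * m + i] += A[i * kk + t] * B[t * n + j];

      // C^T = B^T * A^T stored row-major: element (j, i) lives at Ct[j * m + i].
      std::vector<double> Ct_row(n * m, 0.0);
      for (std::size_t j = 0; j < n; ++j)
        for (std::size_t i = 0; i < m; ++i)
          for (std::size_t t = 0; t < kk; ++t)
            Ct_row[j * m + i] += B[t * n + j] * A[i * kk + t];

      assert(C_col == Ct_row);  // identical buffers: same result, same layout
      return 0;
    }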
-+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/gemm/kernel/rank_k_universal.h" -+#include "cutlass/gemm/kernel/default_rank_k.h" -+#include "cutlass/gemm/kernel/default_rank_k_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by SYRK -+ typename Operator, -+ /// Blas3 computation mode (symmetric/hermitian) -+ BlasMode BlasMode_ = BlasMode::kSymmetric, -+ /// -+ typename Enable = void -+ > -+struct DefaultRankKUniversal; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued Rank k update kernels -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by Rank2k -+ typename Operator> -+struct DefaultRankKUniversal< -+ ElementA, -+ LayoutA, -+ ComplexTransform::kNone, // transform A -+ 
kAlignmentA, -+ ElementC, -+ LayoutC, -+ FillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ BlasMode::kSymmetric, -+ typename std::enable_if< ! cutlass::is_complex::value>::type -+> { -+ -+ using DefaultRankKkernel = typename kernel::DefaultRankK< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementC, -+ LayoutC, -+ FillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ BlasMode::kSymmetric -+ >::RankKkernel; -+ -+ /// Define the kernel in terms of the default kernel -+ using RankKkernel = kernel::RankKUniversal< -+ typename DefaultRankKkernel::Mma, -+ typename DefaultRankKkernel::Epilogue, -+ ThreadblockSwizzle, -+ FillModeC -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Complex-valued Rank 2K update kernels -+// -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by SYRK -+ typename Operator, -+ // BlasMode -+ BlasMode kBlasMode -+ > -+ -+struct DefaultRankKUniversal< -+ ElementA, -+ LayoutA, -+ TransformA, -+ kAlignmentA, -+ ElementC, -+ LayoutC, -+ FillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ kBlasMode, -+ typename std::enable_if::value>::type -+> { -+ -+ using DefaultRankKkernel = typename kernel::DefaultRankKComplex< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ FillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ TransformA, -+ Operator, -+ SplitKSerial, -+ kBlasMode -+ >::RankKkernel; -+ -+ /// Define the kernel in terms of the default kernel -+ using RankKkernel = kernel::RankKUniversal< -+ typename DefaultRankKkernel::Mma, -+ typename DefaultRankKkernel::Epilogue, -+ ThreadblockSwizzle, -+ FillModeC -+ >; -+}; -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ 
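// Illustrative aside, not from the CUTLASS sources: both universal wrappers
// above select the real- or complex-valued default kernel through a partial
// specialization guarded by std::enable_if over an is_complex trait. A
// standalone analogue of that selection mechanism, with all names hypothetical:

    #include <complex>
    #include <cstdio>
    #include <type_traits>

    // Minimal stand-in for the role cutlass::is_complex plays above.
    template <typename T> struct is_complex_t : std::false_type {};
    template <typename T> struct is_complex_t<std::complex<T>> : std::true_type {};

    // Primary template is only declared, mirroring DefaultRankKUniversal.
    template <typename Element, typename Enable = void>
    struct DefaultKernelFor;

    // Real-valued path: chosen when the enable_if substitution succeeds.
    template <typename Element>
    struct DefaultKernelFor<Element,
        typename std::enable_if<!is_complex_t<Element>::value>::type> {
      static const char *name() { return "real-valued default kernel"; }
    };

    // Complex-valued path.
    template <typename Element>
    struct DefaultKernelFor<Element,
        typename std::enable_if<is_complex_t<Element>::value>::type> {
      static const char *name() { return "complex-valued default kernel"; }
    };

    int main() {
      std::printf("%s\n", DefaultKernelFor<float>::name());
      std::printf("%s\n", DefaultKernelFor<std::complex<float>>::name());
      return 0;
    }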
-+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_symm.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_symm.h -new file mode 100755 -index 0000000..1faf25d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_symm.h -@@ -0,0 +1,321 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level SYMM/HEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. 
-+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/symm_universal.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_trmm.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode kSideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode kFillModeA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Blas3 computation mode -+ BlasMode BlasMode_ = BlasMode::kSymmetric> -+struct DefaultSymm; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode kSideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode kFillModeA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B 
matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultSymm< -+ ElementA, LayoutA, kSideModeA, kFillModeA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC,layout::RowMajor, -+ ElementAccumulator, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial, -+ Operator> { -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - with diagonal: alpha * A * B or alpha * B * A -+ static const DiagType kDiagTypeMma1 = DiagType::kNonUnit; -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultTrmm< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ kSideModeA, kFillModeA, kDiagTypeMma1, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, Operator>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - withOUT diagonal: alpha * AT * B or alpha * B * AT -+ static const DiagType kDiagTypeMma2 = DiagType::kZero; -+ using LayoutAMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ typename layout::LayoutTranspose::type, -+ LayoutA -+ >::type; -+ using LayoutBMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ LayoutB, -+ typename layout::LayoutTranspose::type -+ >::type; -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultTrmm< -+ ElementA, LayoutAMma2, kAlignmentA, -+ ElementB, LayoutBMma2, kAlignmentB, -+ kSideModeA, InvertFillMode::mode, kDiagTypeMma2, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, Operator>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename Mma1::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount>::Epilogue; -+ -+ /// Define the kernel-level SYMM/HEMM operator. 
-+ using SymmKernel = kernel::SymmUniversal; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode kSideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode kFillModeA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultSymm< -+ ElementA, LayoutA, kSideModeA, kFillModeA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC,layout::RowMajor, -+ ElementAccumulator, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial, -+ Operator> { -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - with diagonal: alpha * A * B or alpha * B * A -+ static const DiagType kDiagTypeMma1 = DiagType::kNonUnit; -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultTrmm< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ kSideModeA, kFillModeA, kDiagTypeMma1, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, Operator>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - withOUT diagonal: alpha * AT * B or alpha * B * AT -+ static const DiagType kDiagTypeMma2 = DiagType::kZero; -+ using LayoutAMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ typename layout::LayoutTranspose::type, -+ LayoutA -+ >::type; -+ using LayoutBMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ LayoutB, -+ typename layout::LayoutTranspose::type -+ >::type; -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultTrmm< -+ ElementA, LayoutAMma2, kAlignmentA, -+ ElementB, LayoutBMma2, kAlignmentB, -+ kSideModeA, InvertFillMode::mode, kDiagTypeMma2, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, Operator>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename Mma1::Operator, kPartitionsK, EpilogueOutputOp, -+ 
EpilogueOutputOp::kCount>::Epilogue; -+ -+ /// Define the kernel-level SYMM/HEMM operator. -+ using SymmKernel = kernel::SymmUniversal; -+}; -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_symm_complex.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_symm_complex.h -new file mode 100755 -index 0000000..09cb7e5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_symm_complex.h -@@ -0,0 +1,508 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level SYMM/HEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. 
-+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/symm_universal.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_multistage_trmm_complex.h" -+#include "cutlass/gemm/threadblock/default_multistage_mma_complex.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode kSideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode kFillModeA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Blas3 computation mode -+ BlasMode BlasMode_ = BlasMode::kSymmetric> -+struct DefaultSymmComplex; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture complex datatype (symmetric) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode kSideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode kFillModeA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size 
(concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultSymmComplex< -+ ElementA, LayoutA, kSideModeA, kFillModeA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ Operator, SplitKSerial, BlasMode::kSymmetric> { -+ -+ static BlasMode const kBlasMode = BlasMode::kSymmetric; -+ // Complex Transform don't appply to A or B for SYMM -+ static ComplexTransform const TransformA = ComplexTransform::kNone; -+ static ComplexTransform const TransformB = ComplexTransform::kNone; -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - with diagonal: alpha * A * B or alpha * B * A -+ static const DiagType kDiagTypeMma1 = DiagType::kNonUnit; -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ kSideModeA, kFillModeA, kDiagTypeMma1, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - withOUT diagonal: alpha * AT * B or alpha * B * AT -+ static const DiagType kDiagTypeMma2 = DiagType::kZero; -+ using LayoutAMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ typename layout::LayoutTranspose::type, -+ LayoutA -+ >::type; -+ using LayoutBMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ LayoutB, -+ typename layout::LayoutTranspose::type -+ >::type; -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< -+ ElementA, LayoutAMma2, -+ ElementB, LayoutBMma2, -+ kSideModeA, InvertFillMode::mode, kDiagTypeMma2, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp< -+ ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator>::Epilogue; -+ -+ /// Define the kernel-level Symm operator. 
-+ using SymmKernel = kernel::SymmUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture complex datatype (hermitian) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode kSideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode kFillModeA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultSymmComplex< -+ ElementA, LayoutA, kSideModeA, kFillModeA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ Operator, SplitKSerial, BlasMode::kHermitian> { -+ -+ static BlasMode const kBlasMode = BlasMode::kHermitian; -+ -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - with diagonal: alpha * A * B or alpha * B * A -+ static const DiagType kDiagTypeMma1 = DiagType::kNonUnit; -+ static ComplexTransform const TransformAMma1 = ComplexTransform::kNone; -+ static ComplexTransform const TransformBMma1 = ComplexTransform::kNone; -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ kSideModeA, kFillModeA, kDiagTypeMma1, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, TransformAMma1, TransformBMma1, Operator, BlasMode::kHermitian>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - withOUT diagonal - with conjugate transpose: alpha * AT * B or alpha * B * AT -+ static const DiagType kDiagTypeMma2 = DiagType::kZero; -+ using LayoutAMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ typename layout::LayoutTranspose::type, -+ LayoutA -+ >::type; -+ using LayoutBMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ LayoutB, -+ typename layout::LayoutTranspose::type -+ >::type; -+ static ComplexTransform const TransformAMma2 = (kSideModeA == SideMode::kLeft) ? -+ ComplexTransform::kConjugate : ComplexTransform::kNone; -+ static ComplexTransform const TransformBMma2 = (kSideModeA == SideMode::kLeft) ? 
-+ ComplexTransform::kNone : ComplexTransform::kConjugate; -+ -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< -+ ElementA, LayoutAMma2, -+ ElementB, LayoutBMma2, -+ kSideModeA, InvertFillMode::mode, kDiagTypeMma2, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, TransformAMma2, TransformBMma2, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp< -+ ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator>::Epilogue; -+ -+ /// Define the kernel-level Symm operator. -+ using SymmKernel = kernel::SymmUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture complex datatype (symmetric) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode kSideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode kFillModeA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultSymmComplex< -+ ElementA, LayoutA, kSideModeA, kFillModeA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ Operator, SplitKSerial, BlasMode::kSymmetric> { -+ -+ static BlasMode const kBlasMode = BlasMode::kSymmetric; -+ // Complex Transform don't appply to A or B for SYMM -+ static ComplexTransform const TransformA = ComplexTransform::kNone; -+ static ComplexTransform const TransformB = ComplexTransform::kNone; -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - with diagonal: alpha * A * B or alpha * B * A -+ static const DiagType kDiagTypeMma1 = DiagType::kNonUnit; -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ kSideModeA, kFillModeA, kDiagTypeMma1, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - withOUT diagonal: alpha * AT * B or alpha * B * AT -+ static const DiagType kDiagTypeMma2 = DiagType::kZero; -+ using LayoutAMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), 
-+ typename layout::LayoutTranspose::type, -+ LayoutA -+ >::type; -+ using LayoutBMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ LayoutB, -+ typename layout::LayoutTranspose::type -+ >::type; -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< -+ ElementA, LayoutAMma2, -+ ElementB, LayoutBMma2, -+ kSideModeA, InvertFillMode::mode, kDiagTypeMma2, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp< -+ ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator>::Epilogue; -+ -+ /// Define the kernel-level Symm operator. -+ using SymmKernel = kernel::SymmUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture complex datatype (hermitian) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode kSideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode kFillModeA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultSymmComplex< -+ ElementA, LayoutA, kSideModeA, kFillModeA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ Operator, SplitKSerial, BlasMode::kHermitian> { -+ -+ static BlasMode const kBlasMode = BlasMode::kHermitian; -+ -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - with diagonal: alpha * A * B or alpha * B * A -+ static const DiagType kDiagTypeMma1 = DiagType::kNonUnit; -+ static ComplexTransform const TransformAMma1 = ComplexTransform::kNone; -+ static ComplexTransform const TransformBMma1 = ComplexTransform::kNone; -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ kSideModeA, kFillModeA, kDiagTypeMma1, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, TransformAMma1, TransformBMma1, Operator, BlasMode::kHermitian>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - withOUT diagonal - with conjugate transpose: alpha * AT * B 
or alpha * B * AT -+ static const DiagType kDiagTypeMma2 = DiagType::kZero; -+ using LayoutAMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ typename layout::LayoutTranspose::type, -+ LayoutA -+ >::type; -+ using LayoutBMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ LayoutB, -+ typename layout::LayoutTranspose::type -+ >::type; -+ static ComplexTransform const TransformAMma2 = (kSideModeA == SideMode::kLeft) ? -+ ComplexTransform::kConjugate : ComplexTransform::kNone; -+ static ComplexTransform const TransformBMma2 = (kSideModeA == SideMode::kLeft) ? -+ ComplexTransform::kNone : ComplexTransform::kConjugate; -+ -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< -+ ElementA, LayoutAMma2, -+ ElementB, LayoutBMma2, -+ kSideModeA, InvertFillMode::mode, kDiagTypeMma2, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, TransformAMma2, TransformBMma2, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp< -+ ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator>::Epilogue; -+ -+ /// Define the kernel-level Symm operator. -+ using SymmKernel = kernel::SymmUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_symm_universal.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_symm_universal.h -new file mode 100755 -index 0000000..adcf1ff ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_symm_universal.h -@@ -0,0 +1,342 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level SYMM/HEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/gemm/kernel/symm_universal.h" -+#include "cutlass/gemm/kernel/default_symm.h" -+#include "cutlass/gemm/kernel/default_symm_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode SideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode FillModeA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by SYRK -+ typename Operator, -+ /// Blas3 computation mode (symmetric/hermitian) -+ BlasMode BlasMode_ = BlasMode::kSymmetric, -+ /// -+ typename Enable = void -+ > -+struct DefaultSymmUniversal; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued SYMM/HEMM update kernels -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Side Mode for A (kLeft 
or kRight) -+ SideMode SideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode FillModeA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by SYMM/HEMM -+ typename Operator> -+struct DefaultSymmUniversal< -+ ElementA, -+ LayoutA, -+ SideModeA, -+ FillModeA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ BlasMode::kSymmetric, -+ typename std::enable_if< ! cutlass::is_complex::value>::type -+> { -+ -+ using DefaultSymmkernel = typename kernel::DefaultSymm< -+ ElementA, -+ LayoutA, -+ SideModeA, -+ FillModeA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ BlasMode::kSymmetric -+ >::SymmKernel; -+ -+ /// Define the kernel in terms of the default kernel -+ using SymmKernel = kernel::SymmUniversal< -+ typename DefaultSymmkernel::Mma1, -+ typename DefaultSymmkernel::Mma2, -+ typename DefaultSymmkernel::Epilogue, -+ ThreadblockSwizzle, -+ SideModeA, -+ FillModeA -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Complex-valued SYMM/HEMM update kernels -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode SideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode FillModeA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ 
/// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by SYRK -+ typename Operator, -+ // BlasMode -+ BlasMode kBlasMode -+ > -+ -+struct DefaultSymmUniversal< -+ ElementA, -+ LayoutA, -+ SideModeA, -+ FillModeA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ kBlasMode, -+ typename std::enable_if::value>::type -+> { -+ -+ using DefaultSymmkernel = typename kernel::DefaultSymmComplex< -+ ElementA, -+ LayoutA, -+ SideModeA, -+ FillModeA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ Operator, -+ SplitKSerial, -+ kBlasMode -+ >::SymmKernel; -+ -+ /// Define the kernel in terms of the default kernel -+ using SymmKernel = kernel::SymmUniversal< -+ typename DefaultSymmkernel::Mma1, -+ typename DefaultSymmkernel::Mma2, -+ typename DefaultSymmkernel::Epilogue, -+ ThreadblockSwizzle, -+ SideModeA, -+ FillModeA -+ >; -+}; -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_trmm.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_trmm.h -new file mode 100644 -index 0000000..cf2896a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_trmm.h -@@ -0,0 +1,269 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+// -+/*! \file -+ \brief -+ Default kernel-level TRMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/trmm_universal.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_trmm.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Side Mode for the kernel -+ SideMode SideMode_, -+ /// Fill Mode for the triangular matrix -+ FillMode FillMode_, -+ /// Diag Type for the triangular matrix -+ DiagType DiagType_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ 
typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultTrmm; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultTrmm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ kSideMode, kFillMode, kDiagType, ElementC, layout::RowMajor, ElementAccumulator, -+ arch::OpClassTensorOp, arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial, Operator> { -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultTrmm< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ kSideMode, kFillMode, kDiagType, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount>::Epilogue; -+ -+ /// Define the kernel-level TRMM operator.
-+ using TrmmKernel = kernel::TrmmUniversal<Mma, Epilogue, ThreadblockSwizzle, kSideMode, kFillMode, kDiagType>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultTrmm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ kSideMode, kFillMode, kDiagType, ElementC, layout::RowMajor, ElementAccumulator, -+ arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial, Operator> { -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultTrmm< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ kSideMode, kFillMode, kDiagType, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount>::Epilogue; -+ -+ /// Define the kernel-level TRMM operator. -+ using TrmmKernel = kernel::TrmmUniversal<Mma, Epilogue, ThreadblockSwizzle, kSideMode, kFillMode, kDiagType>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_trmm_complex.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_trmm_complex.h -new file mode 100644 -index 0000000..4909396 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_trmm_complex.h -@@ -0,0 +1,265 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2.
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level TRMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/trmm_universal.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_multistage_trmm_complex.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+ -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Side Mode for the kernel -+ SideMode SideMode_, -+ /// Fill Mode for the triangular matrix -+ FillMode FillMode_, -+ /// Diag Type for the triangular matrix -+ DiagType DiagType_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename 
ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Multiply-add operator -+ // (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial -+> -+struct DefaultTrmmComplex; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Multiply-add operator -+ // (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial -+ > -+struct DefaultTrmmComplex< -+ ElementA, LayoutA, ElementB, LayoutB, -+ kSideMode, kFillMode, kDiagType, -+ ElementC, layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, Operator, SplitKSerial> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< -+ ElementA, LayoutA, ElementB, LayoutB, -+ kSideMode, kFillMode, kDiagType, -+ ElementAccumulator,layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, ThreadblockShape, -+ WarpShape, InstructionShape, Stages, TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp< -+ ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp, -+ 
EpilogueOutputOp::kCount, Operator>::Epilogue; -+ -+ /// Define the kernel-level TRMM operator. -+ using TrmmKernel = kernel::TrmmUniversal; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Multiply-add operator -+ // (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial -+ > -+struct DefaultTrmmComplex< -+ ElementA, LayoutA, ElementB, LayoutB, -+ kSideMode, kFillMode, kDiagType, -+ ElementC, layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, Operator, SplitKSerial> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< -+ ElementA, LayoutA, ElementB, LayoutB, -+ kSideMode, kFillMode, kDiagType, -+ ElementAccumulator,layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, -+ WarpShape, InstructionShape, Stages, TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp< -+ ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator>::Epilogue; -+ -+ /// Define the kernel-level TRMM operator. -+ using TrmmKernel = kernel::TrmmUniversal; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_trmm_universal.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_trmm_universal.h -new file mode 100644 -index 0000000..50e8d8d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_trmm_universal.h -@@ -0,0 +1,359 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. 
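The Sm90 and Sm80 specializations of DefaultTrmmComplex above differ only in the architecture tag; both follow the composition pattern used throughout these default_* headers: derive a threadblock-scoped Mma from the operand, layout and architecture parameters, derive the epilogue from that Mma's warp-level operator, and expose the pair as the kernel-level TrmmUniversal type (whose template argument list is not legible in this diff). Below is a minimal, self-contained sketch of that trait-composition idiom; MyMma, MyEpilogue, ComposedKernel and DefaultKernelTrait are hypothetical stand-ins, not CUTLASS names.

    // Minimal illustration of the trait-composition idiom used by the
    // default_trmm*.h headers: derive the Mma, derive the epilogue from it,
    // expose the composed kernel. All names below are hypothetical stand-ins,
    // not CUTLASS types.
    #include <cstdio>

    template <typename Element, int Stages>
    struct MyMma {                     // stands in for a threadblock-scoped Mma
      static constexpr int kStages = Stages;
    };

    template <typename Mma>
    struct MyEpilogue {                // stands in for a threadblock epilogue
      using SourceMma = Mma;
    };

    template <typename Mma, typename Epilogue>
    struct ComposedKernel {            // stands in for the kernel-level operator
      using MmaType = Mma;
      using EpilogueType = Epilogue;
    };

    // The "default" trait ties the three together and exposes the kernel type.
    template <typename Element, int Stages>
    struct DefaultKernelTrait {
      using Mma = MyMma<Element, Stages>;
      using Epilogue = MyEpilogue<Mma>;
      using Kernel = ComposedKernel<Mma, Epilogue>;
    };

    int main() {
      using Trait = DefaultKernelTrait<float, 3>;
      static_assert(Trait::Kernel::MmaType::kStages == 3,
                    "pipeline depth propagates through the trait");
      std::printf("composed kernel with %d stages\n", Trait::Kernel::MmaType::kStages);
      return 0;
    }

Keeping the composition in a trait rather than in the kernel itself lets each architecture specialization swap in a different Mma while reusing the same kernel template.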
All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level TRMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. 
-+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/gemm/kernel/trmm_universal.h" -+#include "cutlass/gemm/kernel/default_trmm.h" -+#include "cutlass/gemm/kernel/default_trmm_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by TRMM -+ typename Operator, -+ /// -+ typename Enable = void -+ > -+struct DefaultTrmmUniversal; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued TRMM kernels -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// 
Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by TRMM -+ typename Operator> -+struct DefaultTrmmUniversal< -+ ElementA, -+ LayoutA, -+ ComplexTransform::kNone, // transform A -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ ComplexTransform::kNone, // transform B -+ kAlignmentB, -+ kSideMode, -+ kFillMode, -+ kDiagType, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ typename std::enable_if< ! cutlass::is_complex::value>::type -+> { -+ -+ using DefaultTrmmKernel = typename kernel::DefaultTrmm< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ kSideMode, -+ kFillMode, -+ kDiagType, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator -+ >::TrmmKernel; -+ -+ /// Define the kernel in terms of the default kernel -+ using TrmmKernel = kernel::TrmmUniversal< -+ typename DefaultTrmmKernel::Mma, -+ typename DefaultTrmmKernel::Epilogue, -+ ThreadblockSwizzle, -+ kSideMode, -+ kFillMode, -+ kDiagType -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Complex-valued TRMM kernels -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ 
/// Operation performed by TRMM -+ typename Operator -+ > -+struct DefaultTrmmUniversal< -+ ElementA, -+ LayoutA, -+ TransformA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ TransformB, -+ kAlignmentB, -+ kSideMode, -+ kFillMode, -+ kDiagType, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ typename std::enable_if::value>::type -+> { -+ -+ using DefaultTrmmKernel = typename kernel::DefaultTrmmComplex< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ kSideMode, -+ kFillMode, -+ kDiagType, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ TransformA, -+ TransformB, -+ Operator, -+ SplitKSerial -+ >::TrmmKernel; -+ -+ /// Define the kernel in terms of the default kernel -+ using TrmmKernel = kernel::TrmmUniversal< -+ typename DefaultTrmmKernel::Mma, -+ typename DefaultTrmmKernel::Epilogue, -+ ThreadblockSwizzle, -+ kSideMode, -+ kFillMode, -+ kDiagType -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/ell_gemm.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/ell_gemm.h -new file mode 100644 -index 0000000..88a1bd3 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/ell_gemm.h -@@ -0,0 +1,830 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
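DefaultTrmmUniversal above declares a primary template with a trailing `typename Enable = void` parameter and provides two partial specializations guarded by std::enable_if, so real-valued accumulators route to kernel::DefaultTrmm and complex-valued ones to kernel::DefaultTrmmComplex (the exact argument of cutlass::is_complex is garbled in this diff; the accumulator element type is the natural candidate). The sketch below shows the same SFINAE dispatch idiom with stand-in names; is_complex_like and Dispatch are illustrative, not CUTLASS APIs.

    // Sketch of the enable_if dispatch performed by DefaultTrmmUniversal:
    // real-valued accumulators select one default kernel, complex-valued
    // accumulators select another. is_complex_like and Dispatch are
    // illustrative stand-ins, not CUTLASS APIs.
    #include <complex>
    #include <cstdio>
    #include <type_traits>

    template <typename T> struct is_complex_like : std::false_type {};
    template <typename T> struct is_complex_like<std::complex<T>> : std::true_type {};

    template <typename Accumulator, typename Enable = void>
    struct Dispatch;   // primary template, never instantiated directly

    template <typename Accumulator>
    struct Dispatch<Accumulator,
                    typename std::enable_if<!is_complex_like<Accumulator>::value>::type> {
      static const char *name() { return "real-valued default TRMM kernel"; }
    };

    template <typename Accumulator>
    struct Dispatch<Accumulator,
                    typename std::enable_if<is_complex_like<Accumulator>::value>::type> {
      static const char *name() { return "complex-valued default TRMM kernel"; }
    };

    int main() {
      std::printf("%s\n", Dispatch<float>::name());                // real path
      std::printf("%s\n", Dispatch<std::complex<float>>::name());  // complex path
      return 0;
    }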
-+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Template for a Block-Ell sparse gemm kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/arch/arch.h" -+ -+#include "cutlass/transform/threadblock/ell_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ bool SplitKSerial, ///! If true, code supporting split-K via serial reduction is enabled. -+ bool IsASparse ///! If true, A is sparse matrix -+> -+struct EllGemm { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using OutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static bool const kSplitKSerial = SplitKSerial; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Parameters structure -+ struct Params { -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorA::TensorRef ref_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Mma::IteratorB::TensorRef ref_B; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::TensorRef ref_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ typename Epilogue::OutputTileIterator::TensorRef ref_D; -+ typename OutputOp::Params output_op; -+ int *semaphore; -+ int gemm_k_iterations; -+ int gemm_k_size; -+ const int* ell_idx; -+ int ell_ncol; -+ int ell_blocksize; -+ int ell_base_idx; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D, -+ const int* ell_idx, -+ int ell_ncol, -+ int ell_blocksize, -+ int ell_base_idx, -+ typename OutputOp::Params output_op = typename OutputOp::Params(), -+ int *workspace = nullptr -+ ): -+ problem_size(problem_size), -+ grid_tiled_shape(grid_tiled_shape), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A(ref_A.layout()), -+ ref_A(ref_A), -+ params_B(ref_B.layout()), -+ ref_B(ref_B), -+ params_C(ref_C.layout()), -+ ref_C(ref_C), -+ params_D(ref_D.layout()), -+ ref_D(ref_D), -+ output_op(output_op), -+ ell_idx(ell_idx), -+ ell_ncol(ell_ncol), -+ ell_blocksize(ell_blocksize), -+ ell_base_idx(ell_base_idx) -+ { -+ -+ int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); -+ -+ gemm_k_size = 
gemm_k_iterations * Mma::Shape::kK; -+ -+ semaphore = workspace; -+ } -+ }; -+ -+ /// Shared memory storage structure -+ struct SharedStorage { -+ union{ -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ typename cutlass::transform::threadblock::ell::SharedStorage ell; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ EllGemm() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D) { -+ -+ static int const kAlignmentA = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ if (!TensorRef_aligned(ref_A, kAlignmentA)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_B, kAlignmentB)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_C, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_D, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || -+ (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || -+ (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { -+ -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int tile_in_ell_block = (params.ell_blocksize + Mma::Shape::kM - 1 ) / Mma::Shape::kM; -+ int ell_block_offset_m = threadblock_tile_offset.m() / tile_in_ell_block; -+ int tile_offset_m = threadblock_tile_offset.m() % tile_in_ell_block; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. 
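// Worked example of the ELL tile indexing set up above (illustrative numbers,
// not values taken from this diff): with params.ell_blocksize = 64 and
// Mma::Shape::kM = 32, tile_in_ell_block = (64 + 32 - 1) / 32 = 2, i.e. two
// threadblock tiles span one ELL block in the M dimension. A threadblock with
// threadblock_tile_offset.m() = 5 then gets ell_block_offset_m = 5 / 2 = 2 and
// tile_offset_m = 5 % 2 = 1, so the A-row offset computed just below is
// 2 * 64 + 1 * 32 = 160. With params.ell_ncol = 256 there are 256 / 64 = 4 ELL
// blocks per block-row, so ell_idx_start = 2 * 4 = 8 selects the column
// indices of the third block-row. The __shfl_sync that follows only broadcasts
// lane 0's warp id so warp_idx is warp-uniform, as the preceding comment notes.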
-+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ int lane_idx = threadIdx.x % 32; -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // skip computation if matrix is 0 -+ if (params.ell_ncol > 0) { -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ ell_block_offset_m * params.ell_blocksize -+ + tile_offset_m * Mma::Shape::kM, -+ threadblock_tile_offset.k() * params.gemm_k_size -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ threadblock_tile_offset.k() * params.gemm_k_size, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ int ell_idx_start = -+ (threadblock_tile_offset.m() / tile_in_ell_block) * -+ (params.ell_ncol / params.ell_blocksize); -+ const int* ell_idx_ptr = &(params.ell_idx[ell_idx_start]); -+ -+ // Problem size is a function of threadblock index in the K dimension -+ int problem_size_k = min( -+ params.problem_size.k(), -+ (threadblock_tile_offset.k() + 1) * params.gemm_k_size); -+ problem_size_k = min(problem_size_k, params.ell_ncol); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = -+ (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ params.ref_A.data(), -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ params.ref_B.data(), -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ // Define coef for ELL index depending on LayoutB -+ int ell_stride = iterator_B.get_stride(); -+ -+ typename cutlass::transform::threadblock::ell::Iterator ell_iterator( -+ shared_storage.ell, -+ ell_idx_ptr, -+ params.ell_blocksize, -+ params.ell_base_idx, -+ Mma::Shape::kK, -+ problem_size_k, -+ ell_stride, -+ thread_idx -+ ); -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ if (!kSplitKSerial || gemm_k_iterations > 0) { -+ // check if index computations can be skipped -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ constexpr bool is_double = (sizeof(Mma::IteratorA::Element) == 8); -+ constexpr bool is_multiple_alignment = -+ (kAlignmentA > 1) && (kAlignmentB > 1) && (kAlignmentC > 1); -+ const bool is_specialized_blocksize = -+ ((params.ell_blocksize) & (params.ell_blocksize-1)) == 0 -+ && params.ell_blocksize >= Mma::Shape::kK; -+ // Compute threadblock-scoped matrix multiply-add -+ if ((is_double || is_multiple_alignment) && is_specialized_blocksize) { -+ mma.operator()( -+ gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator); -+ } -+ else { -+ mma.operator()( -+ gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator); -+ } -+ } -+ } // if (params.ell_ncols > 0) -+ -+ // -+ // Epilogue -+ // -+ -+ OutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ ell_block_offset_m = threadblock_tile_offset.m() / tile_in_ell_block; -+ tile_offset_m = threadblock_tile_offset.m() % tile_in_ell_block; -+ -+ //assume identity 
swizzle -+ MatrixCoord threadblock_offset( -+ ell_block_offset_m * params.ell_blocksize -+ + tile_offset_m * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ //avoid out of bounds -+ MatrixCoord threadblock_extent( -+ min(params.problem_size.m(), -+ ell_block_offset_m * params.ell_blocksize -+ + min((tile_offset_m + 1) * Mma::Shape::kM, params.ell_blocksize)), -+ min(params.problem_size.n(), -+ (threadblock_tile_offset.n()+1) * Mma::Shape::kN) -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ params.ref_C.data(), -+ threadblock_extent, -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ params.ref_D.data(), -+ threadblock_extent, -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(output_op, iterator_D, accumulators, iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+// B is Sparse -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. 
-+> -+struct EllGemm { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using OutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static bool const kSplitKSerial = SplitKSerial; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Parameters structure -+ struct Params { -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorA::TensorRef ref_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Mma::IteratorB::TensorRef ref_B; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::TensorRef ref_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ typename Epilogue::OutputTileIterator::TensorRef ref_D; -+ typename OutputOp::Params output_op; -+ int *semaphore; -+ int gemm_k_iterations; -+ int gemm_k_size; -+ const int* ell_idx; -+ int ell_ncol; -+ int ell_blocksize; -+ int ell_base_idx; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D, -+ const int* ell_idx, -+ int ell_ncol, -+ int ell_blocksize, -+ int ell_base_idx, -+ typename OutputOp::Params output_op = typename OutputOp::Params(), -+ int *workspace = nullptr -+ ): -+ problem_size(problem_size), -+ grid_tiled_shape(grid_tiled_shape), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A(ref_A.layout()), -+ ref_A(ref_A), -+ params_B(ref_B.layout()), -+ ref_B(ref_B), -+ params_C(ref_C.layout()), -+ ref_C(ref_C), -+ params_D(ref_D.layout()), -+ ref_D(ref_D), -+ output_op(output_op), -+ ell_idx(ell_idx), -+ ell_ncol(ell_ncol), -+ ell_blocksize(ell_blocksize), -+ ell_base_idx(ell_base_idx) -+ { -+ -+ int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); -+ -+ gemm_k_size = gemm_k_iterations * Mma::Shape::kK; -+ -+ semaphore = workspace; -+ } -+ }; -+ -+ /// Shared memory storage structure -+ struct SharedStorage { -+ union{ -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ typename cutlass::transform::threadblock::ell::SharedStorage ell; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ EllGemm() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D) { -+ -+ static int const kAlignmentA = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 
64 -+ : Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ if (!TensorRef_aligned(ref_A, kAlignmentA)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_B, kAlignmentB)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_C, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_D, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || -+ (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || -+ (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { -+ -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int tile_in_ell_block = (params.ell_blocksize + Mma::Shape::kN - 1 ) / Mma::Shape::kN; -+ int ell_block_offset_n = threadblock_tile_offset.n() / tile_in_ell_block; -+ int tile_offset_n = threadblock_tile_offset.n() % tile_in_ell_block; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. 
-+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ int lane_idx = threadIdx.x % 32; -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // skip computation if matrix is 0 -+ if (params.ell_ncol > 0) { -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.k() * params.gemm_k_size, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ threadblock_tile_offset.k() * params.gemm_k_size, -+ ell_block_offset_n * params.ell_blocksize -+ + tile_offset_n * Mma::Shape::kN, -+ }; -+ -+ int ell_idx_start = -+ (threadblock_tile_offset.n() / tile_in_ell_block) * -+ (params.ell_ncol / params.ell_blocksize); -+ const int* ell_idx_ptr = &(params.ell_idx[ell_idx_start]); -+ -+ // Problem size is a function of threadblock index in the K dimension -+ int problem_size_k = min( -+ params.problem_size.k(), -+ (threadblock_tile_offset.k() + 1) * params.gemm_k_size); -+ problem_size_k = min(problem_size_k, params.ell_ncol); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = -+ (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ params.ref_A.data(), -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ params.ref_B.data(), -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ // Define coef for ELL index depending on LayoutA -+ int ell_stride = iterator_A.get_stride(); -+ -+ typename cutlass::transform::threadblock::ell::Iterator ell_iterator( -+ shared_storage.ell, -+ ell_idx_ptr, -+ params.ell_blocksize, -+ params.ell_base_idx, -+ Mma::Shape::kK, -+ problem_size_k, -+ ell_stride, -+ thread_idx -+ ); -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ if (!kSplitKSerial || gemm_k_iterations > 0) { -+ // check if index computations can be skipped -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ constexpr bool is_double = (sizeof(Mma::IteratorA::Element) == 8); -+ constexpr bool is_multiple_alignment = -+ (kAlignmentA > 1) && (kAlignmentB > 1) && (kAlignmentC > 1); -+ const bool is_specialized_blocksize = -+ ((params.ell_blocksize) & (params.ell_blocksize-1)) == 0 -+ && params.ell_blocksize >= Mma::Shape::kK; -+ // Compute threadblock-scoped matrix multiply-add -+ if ((is_double || is_multiple_alignment) && is_specialized_blocksize) { -+ mma.operator()( -+ gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator); -+ } -+ else { -+ mma.operator()( -+ gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator); -+ } -+ } -+ } // if (params.ell_ncols > 0) -+ -+ // -+ // Epilogue -+ // -+ -+ OutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ ell_block_offset_n = threadblock_tile_offset.n() / tile_in_ell_block; -+ tile_offset_n = threadblock_tile_offset.n() % tile_in_ell_block; -+ -+ //assume 
identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ ell_block_offset_n * params.ell_blocksize -+ + tile_offset_n * Mma::Shape::kN -+ ); -+ -+ //avoid out of bounds -+ MatrixCoord threadblock_extent( -+ min(params.problem_size.m(), -+ (threadblock_tile_offset.m()+1) * Mma::Shape::kM), -+ min(params.problem_size.n(), -+ ell_block_offset_n * params.ell_blocksize -+ + min((tile_offset_n + 1) * Mma::Shape::kN, params.ell_blocksize)) -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ params.ref_C.data(), -+ threadblock_extent, -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ params.ref_D.data(), -+ threadblock_extent, -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(output_op, iterator_D, accumulators, iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm.h -new file mode 100644 -index 0000000..b5064ec ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm.h -@@ -0,0 +1,380 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/arch/arch.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. 
-+> -+struct Gemm { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using OutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static bool const kSplitKSerial = SplitKSerial; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Parameters structure -+ struct Params { -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorA::TensorRef ref_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Mma::IteratorB::TensorRef ref_B; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::TensorRef ref_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ typename Epilogue::OutputTileIterator::TensorRef ref_D; -+ typename OutputOp::Params output_op; -+ int *semaphore; -+ int gemm_k_size; -+ // For gather+scatter operations -+ int const *gather_A_indices; -+ int const *gather_B_indices; -+ int const *scatter_D_indices; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D, -+ typename OutputOp::Params output_op = typename OutputOp::Params(), -+ int *workspace = nullptr, -+ int const *gather_A_indices = nullptr, -+ int const *gather_B_indices = nullptr, -+ int const *scatter_D_indices = nullptr -+ ): -+ problem_size(problem_size), -+ grid_tiled_shape(grid_tiled_shape), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A(ref_A.layout()), -+ ref_A(ref_A), -+ params_B(ref_B.layout()), -+ ref_B(ref_B), -+ params_C(ref_C.layout()), -+ ref_C(ref_C), -+ params_D(ref_D.layout()), -+ ref_D(ref_D), -+ output_op(output_op), -+ gather_A_indices(gather_A_indices), -+ gather_B_indices(gather_B_indices), -+ scatter_D_indices(scatter_D_indices) { -+ -+ int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); -+ -+ gemm_k_size = gemm_k_iterations * Mma::Shape::kK; -+ -+ semaphore = workspace; -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Gemm() { } -+ -+ /// Determines whether kernel satisfies alignment -+ CUTLASS_HOST_DEVICE -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D) { -+ -+ static int const kAlignmentA = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 
64 -+ : Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ if (!TensorRef_aligned(ref_A, kAlignmentA)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_B, kAlignmentB)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_C, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_D, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.k() * params.gemm_k_size, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ threadblock_tile_offset.k() * params.gemm_k_size, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Problem size is a function of threadblock index in the K dimension -+ int problem_size_k = min( -+ params.problem_size.k(), -+ (threadblock_tile_offset.k() + 1) * params.gemm_k_size); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ params.ref_A.data(), -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A, -+ params.gather_A_indices); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ params.ref_B.data(), -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B, -+ params.gather_B_indices); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ if (!kSplitKSerial || gemm_k_iterations > 0) { -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); -+ } -+ -+ // -+ // Epilogue -+ // -+ -+ OutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ // Construct the semaphore. 
-+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ params.ref_C.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ params.scatter_D_indices -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ params.ref_D.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ params.scatter_D_indices -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(output_op, iterator_D, accumulators, iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_array.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_array.h -new file mode 100644 -index 0000000..1862e20 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_array.h -@@ -0,0 +1,264 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
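The split-K epilogue of kernel::Gemm above serializes partitions through a per-output-tile semaphore: partition k fetches the lock early, tags its output operator with set_k_partition(k, K), waits until the lock equals k, reads the running result from D when k > 0, and finally releases either k + 1 or 0 (the last partition resets the lock for the next grid). The host-side simulation below walks through that protocol for an assumed four partitions and a single output element; the Semaphore type itself is not modelled, only the lock values it passes around.

    // Host-side simulation of the serial split-K handshake used above:
    // partition k waits for the per-tile lock to reach k, folds its partial
    // sum into D, then releases k + 1 (or 0 if it is the final partition).
    #include <cstdio>

    int main() {
      const int kPartitions = 4;   // grid_tiled_shape.k()
      int lock = 0;                // the semaphore word for this output tile
      float d = 0.0f;              // the output element held in the D tensor

      for (int k = 0; k < kPartitions; ++k) {
        // semaphore.wait(k): in the kernel the CTA spins until lock == k;
        // here partitions already run in order, so the wait is immediate.
        float partial = static_cast<float>(k + 1);   // this partition's accumulator
        d = (k == 0) ? partial : d + partial;        // k > 0 reads the prior result from D

        // semaphore.release(): the last partition resets the lock to 0 so it is
        // ready for the next grid; every other partition passes the baton to k + 1.
        lock = (k + 1 == kPartitions) ? 0 : k + 1;
      }
      std::printf("D = %.1f, lock = %d\n", d, lock);  // D = 10.0, lock = 0
      return 0;
    }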
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function -+> -+struct GemmArray { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using OutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Parameters structure -+ struct Params { -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorA::Element const * const * ptr_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Mma::IteratorB::Element const * const * ptr_B; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::Element const * const * ptr_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ typename Epilogue::OutputTileIterator::Element * const * ptr_D; -+ int64_t stride_D; -+ typename OutputOp::Params epilogue; -+ int batch_count; -+ int gemm_k_iterations; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params() : -+ swizzle_log_tile(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ cutlass::gemm::GemmCoord const & problem_size_, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape_, -+ typename Mma::IteratorA::Element const * const * ptr_A_, -+ typename Mma::IteratorA::Layout layout_A, -+ typename Mma::IteratorB::Element const * const * ptr_B_, -+ typename Mma::IteratorB::Layout layout_B, -+ typename Epilogue::OutputTileIterator::Element const * const * ptr_C_, -+ typename Epilogue::OutputTileIterator::Layout layout_C, -+ typename Epilogue::OutputTileIterator::Element * const * ptr_D_, -+ typename Epilogue::OutputTileIterator::Layout layout_D, -+ typename OutputOp::Params epilogue_, -+ int batch_count_ -+ ): -+ problem_size(problem_size_), -+ grid_tiled_shape(grid_tiled_shape_), -+ 
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A(layout_A), -+ ptr_A(ptr_A_), -+ params_B(layout_B), -+ ptr_B(ptr_B_), -+ params_C(layout_C), -+ ptr_C(ptr_C_), -+ params_D(layout_D), -+ ptr_D(ptr_D_), -+ epilogue(epilogue_), -+ batch_count(batch_count_), -+ gemm_k_iterations((problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK) { -+ -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ GemmArray() { } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ -+ // Each CTA handles multiple batch indices to accommodate limited range of CUDA grid's Z dimension -+ for (int batch_idx = threadblock_swizzle.get_batch_idx(); -+ batch_idx < params.batch_count; -+ batch_idx += gridDim.z) { -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ 0 -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ 0, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ const_cast(params.ptr_A[batch_idx]), -+ params.problem_size.mk(), -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ const_cast(params.ptr_B[batch_idx]), -+ params.problem_size.kn(), -+ thread_idx, -+ tb_offset_B); -+ -+ // -+ // Main loop -+ // -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. 
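// --- Editor's note (illustrative, not part of the original patch): GemmArray receives one
// pointer per batch (params.ptr_A[batch_idx], params.ptr_B[batch_idx], ...), and the loop
// above strides by gridDim.z so that a grid whose Z extent is smaller than batch_count still
// covers every batch. A small host-side sketch of that coverage, using a hypothetical helper
// (batches_for_block):
#include <vector>

// Batch indices handled by the CTA at z-coordinate block_z, mirroring the kernel's
// `for (batch_idx = ...; batch_idx < batch_count; batch_idx += gridDim.z)` loop.
std::vector<int> batches_for_block(int block_z, int grid_dim_z, int batch_count) {
  std::vector<int> batches;
  for (int b = block_z; b < batch_count; b += grid_dim_z) {
    batches.push_back(b);
  }
  return batches;  // e.g. block_z = 1, grid_dim_z = 4, batch_count = 10 -> {1, 5, 9}
}
// --- End of editor's note.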
-+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ OutputOp output_op(params.epilogue); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ const_cast(params.ptr_C[batch_idx]), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ params.ptr_D[batch_idx], -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // run efficient epilogue -+ epilogue(output_op, iterator_D, accumulators, iterator_C); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_batched.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_batched.h -new file mode 100644 -index 0000000..464aeef ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_batched.h -@@ -0,0 +1,279 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function -+> -+struct GemmBatched { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using OutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Parameters structure -+ struct Params { -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorA::TensorRef ref_A; -+ int64_t stride_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Mma::IteratorB::TensorRef ref_B; -+ int64_t stride_B; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::TensorRef ref_C; -+ int64_t stride_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ typename Epilogue::OutputTileIterator::TensorRef ref_D; -+ int64_t stride_D; -+ typename OutputOp::Params epilogue; -+ int batch_count; -+ int gemm_k_iterations; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params() : swizzle_log_tile(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ cutlass::gemm::GemmCoord const & problem_size_, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape_, -+ typename Mma::IteratorA::TensorRef ref_A_, -+ int64_t stride_A_, -+ typename Mma::IteratorB::TensorRef ref_B_, -+ int64_t stride_B_, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C_, -+ int64_t stride_C_, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D_, -+ int64_t stride_D_, -+ typename OutputOp::Params epilogue_, -+ int batch_count_ -+ ): -+ problem_size(problem_size_), -+ grid_tiled_shape(grid_tiled_shape_), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A(ref_A_.layout()), -+ ref_A(ref_A_), -+ stride_A(stride_A_), -+ params_B(ref_B_.layout()), -+ ref_B(ref_B_), -+ stride_B(stride_B_), -+ params_C(ref_C_.layout()), -+ ref_C(ref_C_), -+ stride_C(stride_C_), -+ params_D(ref_D_.layout()), -+ ref_D(ref_D_), -+ stride_D(stride_D_), -+ epilogue(epilogue_), -+ batch_count(batch_count_), -+ gemm_k_iterations((problem_size.k() 
+ Mma::Shape::kK - 1) / Mma::Shape::kK) { -+ -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ GemmBatched() { } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ -+ // Each CTA handles multiple batch indices to accommodate limited range of CUDA grid's Z dimension -+ for (int batch_idx = threadblock_swizzle.get_batch_idx(); -+ batch_idx < params.batch_count; -+ batch_idx += gridDim.z) { -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ 0 -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ 0, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ params.ref_A.data(), -+ params.problem_size.mk(), -+ thread_idx, -+ tb_offset_A); -+ -+ iterator_A.add_pointer_offset(params.stride_A * batch_idx); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ params.ref_B.data(), -+ params.problem_size.kn(), -+ thread_idx, -+ tb_offset_B); -+ -+ iterator_B.add_pointer_offset(params.stride_B * batch_idx); -+ -+ -+ // -+ // Main loop -+ // -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. 
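// --- Editor's note (illustrative, not part of the original patch): unlike GemmArray's array
// of per-batch pointers, GemmBatched takes one base TensorRef per operand plus an element
// stride between consecutive batches, and the add_pointer_offset(stride * batch_idx) calls
// above move each operand iterator to its batch. A minimal sketch of that addressing, using
// a hypothetical helper (batch_pointer):

// Element pointer for batch `batch_idx`, given the batch-0 base pointer and the number of
// elements separating consecutive batches (the kernel's params.stride_A / stride_B / etc.).
template <typename Element>
Element *batch_pointer(Element *base, long long batch_stride, int batch_idx) {
  return base + batch_stride * static_cast<long long>(batch_idx);
}
// --- End of editor's note.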
-+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ OutputOp output_op(params.epilogue); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ params.ref_C.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ iterator_C.add_pointer_offset(params.stride_C * batch_idx); -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ params.ref_D.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ iterator_D.add_pointer_offset(params.stride_D * batch_idx); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // run efficient epilogue -+ epilogue(output_op, iterator_D, accumulators, iterator_C); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_grouped.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_grouped.h -new file mode 100644 -index 0000000..84dc4ae ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_grouped.h -@@ -0,0 +1,481 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Problem visitor for grouped GEMMs -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/trace.h" -+#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -+#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform -+ bool Transposed = false -+> -+struct GemmGrouped { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; -+ static bool const kTransposed = Transposed; -+ -+ // Optional transpose -+ using MapArguments = kernel::detail::MapArguments< -+ typename Mma::IteratorA::Element, -+ typename Mma::IteratorA::Layout, -+ Mma::kTransformA, -+ Mma::IteratorA::AccessType::kElements, -+ typename Mma::IteratorB::Element, -+ typename Mma::IteratorB::Layout, -+ Mma::kTransformB, -+ Mma::IteratorB::AccessType::kElements, -+ typename Mma::LayoutC, -+ kTransposed -+ >; -+ -+ // Public-facing type definitions related to operand element type, layout, and complex conjugate -+ // operation. Must interact with the 'kTransposed' notion. -+ using ElementA = typename MapArguments::ElementA; -+ using LayoutA = typename MapArguments::LayoutA; -+ using ElementB = typename MapArguments::ElementB; -+ using LayoutB = typename MapArguments::LayoutB; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename MapArguments::LayoutC; -+ -+ static ComplexTransform const kTransformA = MapArguments::kTransformA; -+ static ComplexTransform const kTransformB = MapArguments::kTransformB; -+ -+ // Type definitions about the mainloop. 
-+ using Operator = typename Mma::Operator; -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = MapArguments::kAlignmentA; -+ static int const kAlignmentB = MapArguments::kAlignmentB; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ using ProblemVisitor = GemmGroupedProblemVisitor< -+ ThreadblockShape, -+ kGroupScheduleMode, -+ kThreadCount, -+ kThreadCount, -+ kTransposed>; -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord *problem_sizes; -+ int problem_count; -+ int threadblock_count; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ ElementA ** ptr_A; -+ ElementB ** ptr_B; -+ ElementC ** ptr_C; -+ ElementC ** ptr_D; -+ -+ typename LayoutA::Stride::LongIndex *lda; -+ typename LayoutB::Stride::LongIndex *ldb; -+ typename LayoutC::Stride::LongIndex *ldc; -+ typename LayoutC::Stride::LongIndex *ldd; -+ -+ // Only used by device-level operator -+ GemmCoord *host_problem_sizes; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments(): -+ problem_count(0), -+ threadblock_count(0), -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ lda(nullptr), -+ ldb(nullptr), -+ ldc(nullptr), -+ ldd(nullptr), -+ host_problem_sizes(nullptr) -+ { -+ -+ } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord *problem_sizes, -+ int problem_count, -+ int threadblock_count, -+ typename EpilogueOutputOp::Params output_op, -+ ElementA ** ptr_A, -+ ElementB ** ptr_B, -+ ElementC ** ptr_C, -+ ElementC ** ptr_D, -+ typename LayoutA::Stride::LongIndex *lda, -+ typename LayoutB::Stride::LongIndex *ldb, -+ typename LayoutC::Stride::LongIndex *ldc, -+ typename LayoutC::Stride::LongIndex *ldd, -+ GemmCoord *host_problem_sizes=nullptr -+ ): -+ problem_sizes(problem_sizes), -+ problem_count(problem_count), -+ threadblock_count(threadblock_count), -+ output_op(output_op), -+ ptr_A(ptr_A), -+ ptr_B(ptr_B), -+ ptr_C(ptr_C), -+ ptr_D(ptr_D), -+ lda(lda), -+ ldb(ldb), -+ ldc(ldc), -+ ldd(ldd), -+ host_problem_sizes(host_problem_sizes) -+ { -+ -+ } -+ }; -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params { -+ -+ typename ProblemVisitor::Params problem_visitor; -+ int threadblock_count; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ ElementA ** ptr_A; -+ ElementB ** ptr_B; -+ ElementC ** ptr_C; -+ ElementC ** ptr_D; -+ -+ typename LayoutA::Stride::LongIndex *lda; -+ typename LayoutB::Stride::LongIndex *ldb; -+ typename LayoutC::Stride::LongIndex *ldc; -+ typename LayoutC::Stride::LongIndex *ldd; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ lda(nullptr), -+ ldb(nullptr), -+ ldc(nullptr), -+ ldd(nullptr) -+ { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args, -+ void *workspace = nullptr, -+ int tile_count = 0): -+ problem_visitor(args.problem_sizes, args.problem_count, 
workspace, tile_count), -+ threadblock_count(args.threadblock_count), -+ output_op(args.output_op), -+ ptr_A(args.ptr_A), -+ ptr_B(args.ptr_B), -+ ptr_C(args.ptr_C), -+ ptr_D(args.ptr_D), -+ lda(args.lda), -+ ldb(args.ldb), -+ ldc(args.ldc), -+ ldd(args.ldd) -+ { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void update( -+ Arguments const &args, -+ void *workspace = nullptr, -+ int tile_count = 0) { -+ -+ problem_visitor = typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, -+ workspace, tile_count); -+ threadblock_count = args.threadblock_count; -+ output_op = args.output_op; -+ ptr_A = args.ptr_A; -+ ptr_B = args.ptr_B; -+ ptr_C = args.ptr_C; -+ ptr_D = args.ptr_D; -+ lda = args.lda; -+ ldb = args.ldb; -+ ldc = args.ldc; -+ ldd = args.ldd; -+ } -+ }; -+ -+ /// Shared memory storage structure -+ struct SharedStorage { -+ union { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ } kernel; -+ -+ // ProblemVisitor shared storage can't be overlapped with others -+ typename ProblemVisitor::SharedStorage problem_visitor; -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ GemmGrouped() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) { -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return Status::kSuccess; -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // -+ // These types shadow the type-level definitions and support the ability to implement -+ // a 'transposed' GEMM that computes the transposed problems. -+ // -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ -+ // -+ // Problem visitor. -+ // -+ ProblemVisitor problem_visitor( -+ params.problem_visitor, -+ shared_storage.problem_visitor, -+ blockIdx.x); -+ -+ // Outer 'persistent' loop to iterate over tiles -+ while (problem_visitor.next_tile()) { -+ -+ GemmCoord problem_size = problem_visitor.problem_size(); -+ int32_t problem_idx = problem_visitor.problem_index(); -+ int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); -+ -+ GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); -+ -+ cutlass::gemm::GemmCoord threadblock_offset( -+ int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, -+ int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, -+ 0); -+ -+ // Load element pointers. Exchange pointers and strides if working on the transpose -+ ElementA *ptr_A = reinterpret_cast((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); -+ typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); -+ -+ ElementB *ptr_B = reinterpret_cast((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); -+ typename LayoutB::LongIndex ldm_B = (kTransposed ? 
params.lda[problem_idx] : params.ldb[problem_idx]); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_offset.m(), -+ 0, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ 0, -+ threadblock_offset.n() -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ LayoutA(ldm_A), -+ ptr_A, -+ {problem_size.m(), problem_size.k()}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ LayoutB(ldm_B), -+ ptr_B, -+ {problem_size.k(), problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Matrix multiply phase -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Wait for all threads to finish their epilogue phases from the previous tile. -+ __syncthreads(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ ElementC *ptr_C = params.ptr_C[problem_idx]; -+ ElementC *ptr_D = params.ptr_D[problem_idx]; -+ -+ LayoutC layout_C(params.ldc[problem_idx]); -+ LayoutC layout_D(params.ldd[problem_idx]); -+ -+ typename Epilogue::OutputTileIterator::Params params_C(layout_C); -+ typename Epilogue::OutputTileIterator::Params params_D(layout_D); -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params_C, -+ ptr_C, -+ problem_size.mn(), -+ thread_idx, -+ threadblock_offset.mn() -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params_D, -+ ptr_D, -+ problem_size.mn(), -+ thread_idx, -+ threadblock_offset.mn() -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.kernel.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ -+ // Next tile -+ problem_visitor.advance(gridDim.x); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h -new file mode 100644 -index 0000000..9df78c9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h -@@ -0,0 +1,122 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Scheduler for grouped GEMM -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/gemm/kernel/grouped_problem_visitor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+// Helper for correctly representing problem sizes in grouped kernels -+template < -+ typename ThreadblockShape, -+ bool Transposed -+> -+struct GemmGroupedProblemSizeHelper { -+ -+ static bool const kTransposed = Transposed; -+ -+ CUTLASS_HOST_DEVICE -+ static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { -+ return cutlass::gemm::GemmCoord( -+ ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), -+ ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), -+ 1); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) { -+ if (kTransposed) { -+ swap(problem.m(), problem.n()); -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) { -+ return grid.m() * grid.n(); -+ } -+}; -+ -+} // namespace detail -+ -+/// Visitor class to abstract away the algorithm for iterating over tiles -+template -+struct GemmGroupedProblemVisitor : public GroupedProblemVisitor< -+ detail::GemmGroupedProblemSizeHelper, -+ ThreadblockShape, -+ GroupScheduleMode_, -+ PrefetchTileCount, -+ ThreadCount> { -+ -+ static bool const kTransposed = Transposed; -+ -+ using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; -+ using Base = GroupedProblemVisitor; -+ using Params 
= typename Base::Params; -+ using SharedStorage = typename Base::SharedStorage; -+ -+ // -+ // Methods -+ // -+ CUTLASS_DEVICE -+ GemmGroupedProblemVisitor( -+ Params const ¶ms_, -+ SharedStorage &shared_storage_, -+ int32_t block_idx -+ ): Base (params_, shared_storage_, block_idx) -+ {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_grouped_softmax_mainloop_fusion.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_grouped_softmax_mainloop_fusion.h -new file mode 100644 -index 0000000..cac99f5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_grouped_softmax_mainloop_fusion.h -@@ -0,0 +1,510 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ \brief Problem visitor for grouped GEMMs with a softmax fused beforehand -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/trace.h" -+#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -+#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform -+ bool Transposed = false -+> -+struct GemmGroupedSoftmaxMainloopFusion { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; -+ static bool const kTransposed = Transposed; -+ -+ // Optional transpose -+ using MapArguments = kernel::detail::MapArguments< -+ typename Mma::IteratorA::Element, -+ typename Mma::IteratorA::Layout, -+ Mma::kTransformA, -+ Mma::IteratorA::AccessType::kElements, -+ typename Mma::IteratorB::Element, -+ typename Mma::IteratorB::Layout, -+ Mma::kTransformB, -+ Mma::IteratorB::AccessType::kElements, -+ typename Mma::LayoutC, -+ kTransposed -+ >; -+ -+ // Public-facing type definitions related to operand element type, layout, and complex conjugate -+ // operation. Must interact with the 'kTransposed' notion. -+ using ElementA = typename MapArguments::ElementA; -+ using LayoutA = typename MapArguments::LayoutA; -+ using ElementB = typename MapArguments::ElementB; -+ using LayoutB = typename MapArguments::LayoutB; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename MapArguments::LayoutC; -+ -+ using ElementScaleBias = typename Mma::IteratorNormSum::Element; -+ -+ static ComplexTransform const kTransformA = MapArguments::kTransformA; -+ static ComplexTransform const kTransformB = MapArguments::kTransformB; -+ -+ // Type definitions about the mainloop. 
-+ using Operator = typename Mma::Operator; -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = MapArguments::kAlignmentA; -+ static int const kAlignmentB = MapArguments::kAlignmentB; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ using ProblemVisitor = GemmGroupedProblemVisitor< -+ ThreadblockShape, -+ kGroupScheduleMode, -+ kThreadCount, -+ kThreadCount, -+ kTransposed>; -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord *problem_sizes; -+ int problem_count; -+ int threadblock_count; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ ElementA ** ptr_A; -+ ElementB ** ptr_B; -+ ElementC ** ptr_C; -+ ElementC ** ptr_D; -+ void ** ptr_norm; -+ void ** ptr_sum; -+ -+ typename LayoutA::Stride::LongIndex *lda; -+ typename LayoutB::Stride::LongIndex *ldb; -+ typename LayoutC::Stride::LongIndex *ldc; -+ typename LayoutC::Stride::LongIndex *ldd; -+ -+ // Only used by device-level operator -+ GemmCoord *host_problem_sizes; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments(): -+ problem_count(0), -+ threadblock_count(0), -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ ptr_norm(nullptr), -+ ptr_sum(nullptr), -+ lda(nullptr), -+ ldb(nullptr), -+ ldc(nullptr), -+ ldd(nullptr), -+ host_problem_sizes(nullptr) -+ { -+ -+ } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord *problem_sizes, -+ int problem_count, -+ int threadblock_count, -+ typename EpilogueOutputOp::Params output_op, -+ ElementA ** ptr_A, -+ ElementB ** ptr_B, -+ ElementC ** ptr_C, -+ ElementC ** ptr_D, -+ void ** ptr_norm, -+ void ** ptr_sum, -+ typename LayoutA::Stride::LongIndex *lda, -+ typename LayoutB::Stride::LongIndex *ldb, -+ typename LayoutC::Stride::LongIndex *ldc, -+ typename LayoutC::Stride::LongIndex *ldd, -+ GemmCoord *host_problem_sizes=nullptr -+ ): -+ problem_sizes(problem_sizes), -+ problem_count(problem_count), -+ threadblock_count(threadblock_count), -+ output_op(output_op), -+ ptr_A(ptr_A), -+ ptr_B(ptr_B), -+ ptr_C(ptr_C), -+ ptr_D(ptr_D), -+ ptr_norm(ptr_norm), -+ ptr_sum(ptr_sum), -+ lda(lda), -+ ldb(ldb), -+ ldc(ldc), -+ ldd(ldd), -+ host_problem_sizes(host_problem_sizes) -+ { -+ -+ } -+ }; -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params { -+ -+ typename ProblemVisitor::Params problem_visitor; -+ int threadblock_count; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ ElementA ** ptr_A; -+ ElementB ** ptr_B; -+ ElementC ** ptr_C; -+ ElementC ** ptr_D; -+ -+ void ** ptr_norm; -+ void ** ptr_sum; -+ -+ typename LayoutA::Stride::LongIndex *lda; -+ typename LayoutB::Stride::LongIndex *ldb; -+ typename LayoutC::Stride::LongIndex *ldc; -+ typename LayoutC::Stride::LongIndex *ldd; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ ptr_norm(nullptr), -+ 
ptr_sum(nullptr), -+ lda(nullptr), -+ ldb(nullptr), -+ ldc(nullptr), -+ ldd(nullptr) -+ { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args, -+ void *workspace = nullptr, -+ int tile_count = 0): -+ problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count), -+ threadblock_count(args.threadblock_count), -+ output_op(args.output_op), -+ ptr_A(args.ptr_A), -+ ptr_B(args.ptr_B), -+ ptr_C(args.ptr_C), -+ ptr_D(args.ptr_D), -+ ptr_norm(args.ptr_norm), -+ ptr_sum(args.ptr_sum), -+ lda(args.lda), -+ ldb(args.ldb), -+ ldc(args.ldc), -+ ldd(args.ldd) -+ { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void update( -+ Arguments const &args, -+ void *workspace = nullptr, -+ int tile_count = 0) { -+ -+ problem_visitor = typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, -+ workspace, tile_count); -+ threadblock_count = args.threadblock_count; -+ output_op = args.output_op; -+ ptr_A = args.ptr_A; -+ ptr_B = args.ptr_B; -+ ptr_C = args.ptr_C; -+ ptr_D = args.ptr_D; -+ ptr_norm = args.ptr_norm; -+ ptr_sum = args.ptr_sum; -+ lda = args.lda; -+ ldb = args.ldb; -+ ldc = args.ldc; -+ ldd = args.ldd; -+ } -+ }; -+ -+ /// Shared memory storage structure -+ struct SharedStorage { -+ union { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ } kernel; -+ -+ // ProblemVisitor shared storage can't be overlapped with others -+ typename ProblemVisitor::SharedStorage problem_visitor; -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ GemmGroupedSoftmaxMainloopFusion() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) { -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return Status::kSuccess; -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // -+ // These types shadow the type-level definitions and support the ability to implement -+ // a 'transposed' GEMM that computes the transposed problems. -+ // -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ -+ // -+ // Problem visitor. -+ // -+ ProblemVisitor problem_visitor( -+ params.problem_visitor, -+ shared_storage.problem_visitor, -+ blockIdx.x); -+ -+ // Outer 'persistent' loop to iterate over tiles -+ while (problem_visitor.next_tile()) { -+ -+ GemmCoord problem_size = problem_visitor.problem_size(); -+ int32_t problem_idx = problem_visitor.problem_index(); -+ int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); -+ -+ GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); -+ -+ cutlass::gemm::GemmCoord threadblock_offset( -+ int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, -+ int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, -+ 0); -+ -+ // Load element pointers. Exchange pointers and strides if working on the transpose -+ ElementA *ptr_A = reinterpret_cast((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); -+ typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); -+ -+ ElementB *ptr_B = reinterpret_cast((kTransposed ? 
params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); -+ typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_offset.m(), -+ 0, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ 0, -+ threadblock_offset.n() -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ LayoutA(ldm_A), -+ ptr_A, -+ {problem_size.m(), problem_size.k()}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ LayoutB(ldm_B), -+ ptr_B, -+ {problem_size.k(), problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ // Construct iterator to the softmax norm/sum vector -+ typename Mma::IteratorNormSum iterator_norm_sum( -+ problem_size.m(), -+ static_cast(params.ptr_norm[problem_idx]), -+ static_cast(params.ptr_sum[problem_idx]), -+ thread_idx, -+ MatrixCoord(0, threadblock_offset.m()) -+ ); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Matrix multiply phase -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Wait for all threads to finish their epilogue phases from the previous tile. -+ __syncthreads(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ iterator_norm_sum, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ ElementC *ptr_C = params.ptr_C[problem_idx]; -+ ElementC *ptr_D = params.ptr_D[problem_idx]; -+ -+ LayoutC layout_C(params.ldc[problem_idx]); -+ LayoutC layout_D(params.ldd[problem_idx]); -+ -+ typename Epilogue::OutputTileIterator::Params params_C(layout_C); -+ typename Epilogue::OutputTileIterator::Params params_D(layout_D); -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params_C, -+ ptr_C, -+ problem_size.mn(), -+ thread_idx, -+ threadblock_offset.mn() -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params_D, -+ ptr_D, -+ problem_size.mn(), -+ thread_idx, -+ threadblock_offset.mn() -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.kernel.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Execute the epilogue operator to update the destination tensor. 
-+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ -+ // Next tile -+ problem_visitor.advance(gridDim.x); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h -new file mode 100644 -index 0000000..94e2f1d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h -@@ -0,0 +1,777 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Template for a multistage GEMM kernel with layernorm operations fused in mainloop. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/gemm/kernel/params_universal_base.h" -+ -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/trace.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! 
Threadblock swizzling function -+> -+struct GemmLayernormMainloopFusion { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ -+ using ElementScaleBias = typename Mma::IteratorVarMean::Element; -+ using LayoutScaleBias = typename Mma::IteratorVarMean::Layout; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ using Operator = typename Mma::Operator; -+ -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments : UniversalArgumentsBase -+ { -+ // -+ // Data members -+ // -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ void const * ptr_A; -+ void const * ptr_B; -+ void const * ptr_var; -+ void const * ptr_mean; -+ void const * ptr_gamma; -+ void const * ptr_beta; -+ void const * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_var; -+ int64_t batch_stride_mean; -+ int64_t batch_stride_gamma; -+ int64_t batch_stride_beta; -+ int64_t batch_stride_C; -+ -+ typename LayoutA::Stride stride_a; -+ typename LayoutB::Stride stride_b; -+ typename LayoutScaleBias::Stride stride_var; -+ typename LayoutScaleBias::Stride stride_mean; -+ typename LayoutScaleBias::Stride stride_gamma; -+ typename LayoutScaleBias::Stride stride_beta; -+ typename LayoutC::Stride stride_c; -+ typename LayoutC::Stride stride_d; -+ -+ typename LayoutA::Stride::LongIndex lda; -+ typename LayoutB::Stride::LongIndex ldb; -+ typename LayoutScaleBias::Stride::LongIndex ld_var; -+ typename LayoutScaleBias::Stride::LongIndex ld_mean; -+ typename LayoutScaleBias::Stride::LongIndex ld_gamma; -+ typename LayoutScaleBias::Stride::LongIndex ld_beta; -+ typename LayoutC::Stride::LongIndex ldc; -+ typename LayoutC::Stride::LongIndex ldd; -+ -+ int const * ptr_gather_A_indices; -+ int const * ptr_gather_B_indices; -+ int const * ptr_scatter_D_indices; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), -+ ptr_var(nullptr), ptr_mean(nullptr), -+ ptr_gamma(nullptr), ptr_beta(nullptr), -+ ptr_gather_A_indices(nullptr), -+ ptr_gather_B_indices(nullptr), -+ ptr_scatter_D_indices(nullptr) -+ {} -+ -+ /// 
constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_var, -+ void const * ptr_mean, -+ void const * ptr_gamma, -+ void const * ptr_beta, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_var, -+ int64_t batch_stride_mean, -+ int64_t batch_stride_gamma, -+ int64_t batch_stride_beta, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride stride_a, -+ typename LayoutB::Stride stride_b, -+ typename LayoutScaleBias::Stride stride_var, -+ typename LayoutScaleBias::Stride stride_mean, -+ typename LayoutScaleBias::Stride stride_gamma, -+ typename LayoutScaleBias::Stride stride_beta, -+ typename LayoutC::Stride stride_c, -+ typename LayoutC::Stride stride_d, -+ int const *ptr_gather_A_indices = nullptr, -+ int const *ptr_gather_B_indices = nullptr, -+ int const *ptr_scatter_D_indices = nullptr) -+ : -+ UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), -+ ptr_var(ptr_var), ptr_mean(ptr_mean), -+ ptr_gamma(ptr_gamma), ptr_beta(ptr_beta), -+ batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), -+ batch_stride_var(batch_stride_var), batch_stride_mean(batch_stride_mean), -+ batch_stride_gamma(batch_stride_gamma), batch_stride_beta(batch_stride_beta), -+ lda(0), ldb(0), ldc(0), ldd(0), -+ ld_var(0), ld_mean(0), -+ ld_gamma(0), ld_beta(0), -+ stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), -+ stride_var(stride_var), stride_mean(stride_mean), -+ stride_gamma(stride_gamma), stride_beta(stride_beta), -+ ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), -+ ptr_scatter_D_indices(ptr_scatter_D_indices) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); -+ } -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_var, -+ void const * ptr_mean, -+ void const * ptr_gamma, -+ void const * ptr_beta, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_var, -+ int64_t batch_stride_mean, -+ int64_t batch_stride_gamma, -+ int64_t batch_stride_beta, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride::LongIndex lda, -+ typename LayoutB::Stride::LongIndex ldb, -+ typename LayoutScaleBias::Stride::LongIndex ld_var, -+ typename LayoutScaleBias::Stride::LongIndex ld_mean, -+ typename LayoutScaleBias::Stride::LongIndex ld_gamma, -+ typename LayoutScaleBias::Stride::LongIndex ld_beta, -+ typename LayoutC::Stride::LongIndex ldc, -+ typename LayoutC::Stride::LongIndex ldd, -+ int const *ptr_gather_A_indices = nullptr, -+ int const *ptr_gather_B_indices = nullptr, -+ int const *ptr_scatter_D_indices = nullptr) -+ : -+ UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), -+ ptr_var(ptr_var), ptr_mean(ptr_mean), -+ ptr_gamma(ptr_gamma), ptr_beta(ptr_beta), -+ batch_stride_A(batch_stride_A), 
batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), -+ batch_stride_var(batch_stride_var), batch_stride_mean(batch_stride_mean), -+ batch_stride_gamma(batch_stride_gamma), batch_stride_beta(batch_stride_beta), -+ lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), -+ ld_var(ld_var), ld_mean(ld_mean), -+ ld_gamma(ld_gamma), ld_beta(ld_beta), -+ ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), -+ ptr_scatter_D_indices(ptr_scatter_D_indices) -+ { -+ stride_a = make_Coord(lda); -+ stride_b = make_Coord(ldb); -+ stride_c = make_Coord(ldc); -+ stride_d = make_Coord(ldd); -+ stride_var = make_Coord(ld_var); -+ stride_mean = make_Coord(ld_mean); -+ stride_gamma = make_Coord(ld_gamma); -+ stride_beta = make_Coord(ld_beta); -+ CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); -+ } -+ -+ /// Returns arguments for the transposed problem -+ Arguments transposed_problem() const { -+ Arguments args(*this); -+ -+ std::swap(args.problem_size.m(), args.problem_size.n()); -+ std::swap(args.ptr_A, args.ptr_B); -+ std::swap(args.lda, args.ldb); -+ std::swap(args.stride_a, args.stride_b); -+ std::swap(args.batch_stride_A, args.batch_stride_B); -+ std::swap(args.ptr_gather_A_indices, args.ptr_gather_B_indices); -+ -+ return args; -+ } -+ }; -+ -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params : UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC> -+ { -+ using ParamsBase = UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC>; -+ -+ // -+ // Data members -+ // -+ -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_var; -+ void * ptr_mean; -+ void * ptr_gamma; -+ void * ptr_beta; -+ void * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_var; -+ int64_t batch_stride_mean; -+ int64_t batch_stride_gamma; -+ int64_t batch_stride_beta; -+ int64_t batch_stride_C; -+ -+ int * ptr_gather_A_indices; -+ int * ptr_gather_B_indices; -+ int * ptr_scatter_D_indices; -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Constructor -+ Params( -+ Arguments const &args, /// GEMM application arguments -+ int device_sms, /// Number of SMs on the device -+ int sm_occupancy) /// Kernel SM occupancy (in thread blocks) -+ : -+ ParamsBase(args, device_sms, sm_occupancy), -+ params_A(args.lda ? make_Coord_with_padding(args.lda) : args.stride_a), -+ params_B(args.ldb ? make_Coord_with_padding(args.ldb) : args.stride_b), -+ params_C(args.ldc ? make_Coord_with_padding(args.ldc) : args.stride_c), -+ params_D(args.ldd ? 
make_Coord_with_padding(args.ldd) : args.stride_d), -+ output_op(args.epilogue), -+ ptr_A(const_cast(args.ptr_A)), -+ ptr_B(const_cast(args.ptr_B)), -+ ptr_var(const_cast(args.ptr_var)), -+ ptr_mean(const_cast(args.ptr_mean)), -+ ptr_gamma(const_cast(args.ptr_gamma)), -+ ptr_beta(const_cast(args.ptr_beta)), -+ ptr_C(const_cast(args.ptr_C)), -+ ptr_D(args.ptr_D), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ batch_stride_var(args.batch_stride_var), -+ batch_stride_mean(args.batch_stride_mean), -+ batch_stride_gamma(args.batch_stride_gamma), -+ batch_stride_beta(args.batch_stride_beta), -+ batch_stride_C(args.batch_stride_C), -+ ptr_gather_A_indices(const_cast(args.ptr_gather_A_indices)), -+ ptr_gather_B_indices(const_cast(args.ptr_gather_B_indices)), -+ ptr_scatter_D_indices(const_cast(args.ptr_scatter_D_indices)) -+ {} -+ -+ /// Lightweight update given a subset of arguments. Problem geometry is assumed -+ /// to remain the same. -+ void update(Arguments const &args) -+ { -+ ptr_A = const_cast(args.ptr_A); -+ ptr_B = const_cast(args.ptr_B); -+ ptr_var = const_cast(args.ptr_var); -+ ptr_mean = const_cast(args.ptr_mean); -+ ptr_gamma = const_cast(args.ptr_gamma); -+ ptr_beta = const_cast(args.ptr_beta); -+ ptr_C = const_cast(args.ptr_C); -+ ptr_D = args.ptr_D; -+ -+ ptr_gather_A_indices = const_cast(args.ptr_gather_A_indices); -+ ptr_gather_B_indices = const_cast(args.ptr_gather_B_indices); -+ ptr_scatter_D_indices = const_cast(args.ptr_scatter_D_indices); -+ -+ output_op = args.epilogue; -+ -+ CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); -+ } -+ }; -+ -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+public: -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) { -+ -+ CUTLASS_TRACE_HOST("GemmUniversal::can_implement()"); -+ -+ static int const kAlignmentA = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 
64 -+ : Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ bool isAMisaligned = false; -+ bool isBMisaligned = false; -+ bool isCMisaligned = false; -+ -+ if (platform::is_same::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } else if (platform::is_same::value) { -+ isAMisaligned = problem_size.m() % kAlignmentA; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } -+ -+ if (platform::is_same::value) { -+ isBMisaligned = problem_size.n() % kAlignmentB; -+ } else if (platform::is_same::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } -+ -+ if (platform::is_same::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } else if (platform::is_same::value) { -+ isCMisaligned = problem_size.m() % kAlignmentC; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } -+ -+ if (isAMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isBMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isCMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ CUTLASS_TRACE_HOST(" returning kSuccess"); -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+public: -+ -+ // -+ // Device-only API -+ // -+ -+ // Factory invocation -+ CUTLASS_DEVICE -+ static void invoke( -+ Params const ¶ms, -+ SharedStorage &shared_storage) -+ { -+ GemmLayernormMainloopFusion op; -+ op(params, shared_storage); -+ } -+ -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ // -+ // Fetch pointers based on mode. 
-+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; -+ ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; -+ ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; -+ } -+ -+ __syncthreads(); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A, -+ params.ptr_gather_A_indices); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B, -+ params.ptr_gather_B_indices); -+ -+ // Construct iterators to A var/mean vector -+ typename Mma::IteratorVarMean iterator_var_mean( -+ params.problem_size.m(), -+ static_cast(params.ptr_var), -+ static_cast(params.ptr_mean), -+ thread_idx, -+ MatrixCoord(0, (threadblock_tile_offset.m() * Mma::Shape::kM)) -+ ); -+ -+ // Construct iterators to A scale/bias vector -+ typename Mma::IteratorGammaBeta iterator_gamma_beta( -+ problem_size_k, -+ static_cast(params.ptr_gamma), -+ static_cast(params.ptr_beta), -+ thread_idx, -+ MatrixCoord( -+ 0, (threadblock_tile_offset.k() * Mma::Shape::kK) -+ ) -+ ); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ iterator_var_mean, -+ iterator_gamma_beta, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C); -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ -+ // -+ // Fetch pointers based on mode. 
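// Editorial aside, not part of this patch: a minimal host-side reference for the
// kind of computation GemmLayernormMainloopFusion performs. The kernel above
// consumes precomputed per-row statistics of A (IteratorVarMean, indexed along M)
// and per-column scale/shift vectors (IteratorGammaBeta, indexed along K) and
// normalizes A inside the mainloop before the multiply-accumulate. This sketch
// assumes `inv_std` already holds the reciprocal standard deviation and uses the
// familiar layernorm form; the exact scaling convention of the fused kernel is
// defined by its Mma and may differ. All names here are illustrative.
#include <cstddef>
#include <vector>

// D[m][n] = sum_k ((A[m][k] - mean[m]) * inv_std[m] * gamma[k] + beta[k]) * B[k][n]
void reference_gemm_layernorm_fused_A(
    int M, int N, int K,
    const std::vector<float> &A,        // row-major M x K
    const std::vector<float> &B,        // row-major K x N
    const std::vector<float> &mean,     // size M
    const std::vector<float> &inv_std,  // size M
    const std::vector<float> &gamma,    // size K
    const std::vector<float> &beta,     // size K
    std::vector<float> &D) {            // row-major M x N
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        float a = (A[std::size_t(m) * K + k] - mean[m]) * inv_std[m] * gamma[k] + beta[k];
        acc += a * B[std::size_t(k) * N + n];
      }
      D[std::size_t(m) * N + n] = acc;
    }
  }
}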
-+ // -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; -+ ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; -+ } -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ ptr_C, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ params.ptr_scatter_D_indices -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ params.ptr_scatter_D_indices -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_params.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_params.h -new file mode 100755 -index 0000000..046ad75 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_params.h -@@ -0,0 +1,199 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" -+#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" -+ -+#include "cutlass/trace.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct GemmParams { -+ -+ // -+ // Type definitions -+ // -+ using Index = int32_t; -+ using LongIndex = int64_t; -+ -+ using MmaIteratorParams = typename cutlass::transform::threadblock::PredicatedTileAccessIteratorParams; -+ using EpilogueIteratorParams = typename cutlass::epilogue::threadblock::PredicatedTileIteratorParams; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ -+ // Data members for Mma::Iterator::Params -+ MmaIteratorParams params_itr_a; -+ MmaIteratorParams params_itr_b; -+ -+ // Data member for Epilogue::OutputTileIterator::Params -+ EpilogueIteratorParams params_itr_c; -+ EpilogueIteratorParams params_itr_d; -+ -+ -+ GemmUniversalMode mode; -+ int batch_count; -+ int gemm_k_size; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_C; -+ void * ptr_D; -+ -+ LongIndex lda; -+ LongIndex ldb; -+ LongIndex ldc; -+ LongIndex ldd; -+ -+ LongIndex batch_stride_A; -+ LongIndex batch_stride_B; -+ LongIndex batch_stride_C; -+ LongIndex 
batch_stride_D; -+ -+ int *semaphore; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ GemmParams() {} -+ -+ CUTLASS_HOST_DEVICE -+ GemmParams( -+ cutlass::gemm::GemmCoord problem_size_, -+ cutlass::gemm::GemmCoord grid_tiled_shape_, -+ int swizzle_log_tile_, -+ GemmUniversalMode mode_, -+ int batch_count_, -+ int gemm_k_size_, -+ void const * ptr_A_, -+ void const * ptr_B_, -+ void const * ptr_C_, -+ void * ptr_D_, -+ LongIndex lda_, -+ LongIndex ldb_, -+ LongIndex ldc_, -+ LongIndex ldd_, -+ int64_t batch_stride_A_, -+ int64_t batch_stride_B_, -+ int64_t batch_stride_C_, -+ int64_t batch_stride_D_, -+ MmaIteratorParams const & params_itr_a_, -+ MmaIteratorParams const & params_itr_b_, -+ EpilogueIteratorParams const & params_itr_c_, -+ EpilogueIteratorParams const & params_itr_d_, -+ void *workspace_ = nullptr) : -+ problem_size(problem_size_), -+ grid_tiled_shape(grid_tiled_shape_), -+ swizzle_log_tile(swizzle_log_tile_), -+ mode(mode_), -+ batch_count(batch_count_), -+ gemm_k_size(gemm_k_size_), -+ ptr_A(const_cast(ptr_A_)), -+ ptr_B(const_cast(ptr_B_)), -+ ptr_C(const_cast(ptr_C_)), -+ ptr_D(ptr_D_), -+ lda(lda_), -+ ldb(ldb_), -+ ldc(ldc_), -+ ldd(ldd_), -+ batch_stride_A(batch_stride_A_), -+ batch_stride_B(batch_stride_B_), -+ batch_stride_C(batch_stride_C_), -+ batch_stride_D(batch_stride_D_), -+ params_itr_a(params_itr_a_), -+ params_itr_b(params_itr_b_), -+ params_itr_c(params_itr_c_), -+ params_itr_d(params_itr_d_), -+ semaphore(static_cast(workspace_) -+ ) { } -+ -+ -+ CUTLASS_HOST_DEVICE -+ void update( -+ void const * ptr_A_, -+ void const * ptr_B_, -+ void const * ptr_C_, -+ void * ptr_D_, -+ int64_t batch_stride_A_, -+ int64_t batch_stride_B_, -+ int64_t batch_stride_C_, -+ int64_t batch_stride_D_, -+ void *workspace_ = nullptr) { -+ -+ ptr_A = const_cast(ptr_A_); -+ ptr_B = const_cast(ptr_B_); -+ ptr_C = const_cast(ptr_C_); -+ ptr_D = ptr_D_; -+ -+ batch_stride_A = batch_stride_A_; -+ batch_stride_B = batch_stride_B_; -+ batch_stride_C = batch_stride_C_; -+ batch_stride_D = batch_stride_D_; -+ -+ -+ semaphore = static_cast(workspace_); -+ CUTLASS_TRACE_HOST("GemmParams::update()"); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_pipelined.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_pipelined.h -new file mode 100644 -index 0000000..df450d0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_pipelined.h -@@ -0,0 +1,158 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void GemmPipelined( -+ cutlass::gemm::GemmCoord problem_size, -+ cutlass::gemm::GemmCoord grid_tiled_shape, -+ typename Mma::IteratorA::Params params_A, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::Params params_B, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Epilogue::Params params_epilogue -+ ) { -+ -+ // Shared storage needed by threadblock-scoped matrix multiply-accumulate -+ __shared__ union { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ } shared_storage; -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ int swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape); -+ -+ cutlass::gemm::GemmCoord tb_tile_offset = threadblock_swizzle.get_tile_offset(swizzle_log_tile); -+ -+ if (grid_tiled_shape.m() <= tb_tile_offset.m() || -+ grid_tiled_shape.n() <= tb_tile_offset.n()) { -+ -+ return; -+ } -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ tb_tile_offset.m() * Mma::Shape::kM, -+ tb_tile_offset.k() -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ tb_tile_offset.k(), -+ tb_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int tb_thread_id = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params_A, -+ ref_A.data(), -+ {problem_size.m(), problem_size.k()}, -+ tb_thread_id, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params_B, -+ ref_B.data(), -+ {problem_size.k(), problem_size.n()}, -+ tb_thread_id, -+ tb_offset_B); -+ -+ int warp_id = canonical_warp_idx(); -+ int lane_id = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma 
mma(shared_storage.main_loop, tb_thread_id, warp_id, lane_id); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(problem_size, accumulators, iterator_A, iterator_B, accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ Epilogue epilogue( -+ params_epilogue, -+ shared_storage.epilogue, -+ tb_thread_id, -+ warp_id, -+ lane_id); -+ -+ tb_tile_offset = threadblock_swizzle.get_tile_offset(swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ tb_tile_offset.m() * Mma::Shape::kM, -+ tb_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ // run efficient epilogue -+ epilogue({problem_size.m(), problem_size.n()}, accumulators, threadblock_offset); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_planar_complex.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_planar_complex.h -new file mode 100644 -index 0000000..7dbc592 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_planar_complex.h -@@ -0,0 +1,715 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/gemm/kernel/params_universal_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function -+> -+struct GemmPlanarComplex { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ using Operator = typename Mma::Operator; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = const_max( -+ 128 / sizeof_bits::value, -+ 128 / sizeof_bits::value); -+ -+ // -+ // Additional types needed for reflection -+ // -+ -+ using ElementAccumulator = typename Mma::Policy::Operator::ElementC; -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::Shape; -+ -+ static int const kStages = Mma::kStages; -+ -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ // -+ // Arguments structure -+ // -+ -+ /// Argument structure -+ struct Arguments : UniversalArgumentsBase -+ { -+ // -+ // Data members -+ // -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ void const * ptr_A_real; -+ void const * ptr_A_imag; -+ -+ void const * ptr_B_real; -+ void const * ptr_B_imag; -+ -+ void const * ptr_C_real; -+ void const * ptr_C_imag; -+ -+ void * ptr_D_real; -+ void * ptr_D_imag; -+ -+ typename LayoutA::Stride::Index lda_real; -+ typename LayoutA::Stride::Index lda_imag; -+ typename LayoutB::Stride::Index ldb_real; -+ typename LayoutB::Stride::Index ldb_imag; -+ typename LayoutC::Stride::Index ldc_real; -+ typename LayoutC::Stride::Index ldc_imag; -+ typename LayoutC::Stride::Index ldd_real; -+ typename LayoutC::Stride::Index ldd_imag; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_A_imag; -+ int64_t batch_stride_B; -+ int64_t batch_stride_B_imag; -+ int64_t batch_stride_C; -+ int64_t batch_stride_C_imag; -+ int64_t batch_stride_D_imag; -+ -+ // -+ // Methods -+ // -+ -+ Arguments() : -+ 
ptr_A_real(nullptr), -+ ptr_A_imag(nullptr), -+ ptr_B_real(nullptr), -+ ptr_B_imag(nullptr), -+ ptr_C_real(nullptr), -+ ptr_C_imag(nullptr), -+ ptr_D_real(nullptr), -+ ptr_D_imag(nullptr) -+ {} -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A_real, -+ void const * ptr_A_imag, -+ void const * ptr_B_real, -+ void const * ptr_B_imag, -+ void const * ptr_C_real, -+ void const * ptr_C_imag, -+ void * ptr_D_real, -+ void * ptr_D_imag, -+ typename LayoutA::Stride::Index lda_real, -+ typename LayoutA::Stride::Index lda_imag, -+ typename LayoutB::Stride::Index ldb_real, -+ typename LayoutB::Stride::Index ldb_imag, -+ typename LayoutC::Stride::Index ldc_real, -+ typename LayoutC::Stride::Index ldc_imag, -+ typename LayoutC::Stride::Index ldd_real, -+ typename LayoutC::Stride::Index ldd_imag, -+ int64_t batch_stride_A = 0, -+ int64_t batch_stride_A_imag = 0, -+ int64_t batch_stride_B = 0, -+ int64_t batch_stride_B_imag = 0, -+ int64_t batch_stride_C = 0, -+ int64_t batch_stride_C_imag = 0, -+ int64_t batch_stride_D = 0, -+ int64_t batch_stride_D_imag = 0) -+ : -+ UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), -+ epilogue(epilogue), -+ ptr_A_real(ptr_A_real), -+ ptr_A_imag(ptr_A_imag), -+ ptr_B_real(ptr_B_real), -+ ptr_B_imag(ptr_B_imag), -+ ptr_C_real(ptr_C_real), -+ ptr_C_imag(ptr_C_imag), -+ ptr_D_real(ptr_D_real), -+ ptr_D_imag(ptr_D_imag), -+ lda_real(lda_real), -+ lda_imag(lda_imag), -+ ldb_real(ldb_real), -+ ldb_imag(ldb_imag), -+ ldc_real(ldc_real), -+ ldc_imag(ldc_imag), -+ ldd_real(ldd_real), -+ ldd_imag(ldd_imag), -+ batch_stride_A(batch_stride_A), -+ batch_stride_A_imag(batch_stride_A_imag), -+ batch_stride_B(batch_stride_B), -+ batch_stride_B_imag(batch_stride_B_imag), -+ batch_stride_C(batch_stride_C), -+ batch_stride_C_imag(batch_stride_C_imag), -+ batch_stride_D_imag(batch_stride_D_imag) -+ {} -+ -+ /// Returns arguments for the transposed problem -+ Arguments transposed_problem() const { -+ Arguments args(*this); -+ -+ std::swap(args.problem_size.m(), args.problem_size.n()); -+ std::swap(args.ptr_A_real, args.ptr_B_real); -+ std::swap(args.ptr_A_imag, args.ptr_B_imag); -+ std::swap(args.lda_real, args.ldb_real); -+ std::swap(args.lda_imag, args.ldb_imag); -+ std::swap(args.batch_stride_A, args.batch_stride_B); -+ std::swap(args.batch_stride_A_imag, args.batch_stride_B_imag); -+ -+ return args; -+ } -+ }; -+ -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params : UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC> -+ { -+ using ParamsBase = UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC>; -+ -+ // -+ // Data members -+ // -+ -+ typename Mma::IteratorA::Params params_A_real; -+ typename Mma::IteratorA::Params params_A_imag; -+ typename Mma::IteratorB::Params params_B_real; -+ typename Mma::IteratorB::Params params_B_imag; -+ typename Epilogue::OutputTileIterator::Params params_C_real; -+ typename Epilogue::OutputTileIterator::Params params_C_imag; -+ typename Epilogue::OutputTileIterator::Params params_D_real; -+ typename Epilogue::OutputTileIterator::Params params_D_imag; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ void * ptr_A_real; -+ void * ptr_A_imag; -+ void * ptr_B_real; -+ void * ptr_B_imag; -+ void 
* ptr_C_real; -+ void * ptr_C_imag; -+ void * ptr_D_real; -+ void * ptr_D_imag; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ -+ int64_t batch_stride_A_imag; -+ int64_t batch_stride_B_imag; -+ int64_t batch_stride_C_imag; -+ int64_t batch_stride_D_imag; -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Constructor -+ Params( -+ Arguments const &args, /// GEMM application arguments -+ int device_sms, /// Number of SMs on the device -+ int sm_occupancy) /// Kernel SM occupancy (in thread blocks) -+ : -+ ParamsBase(args, device_sms, sm_occupancy), -+ params_A_real(args.lda_real), -+ params_A_imag(args.lda_imag), -+ params_B_real(args.ldb_real), -+ params_B_imag(args.ldb_imag), -+ params_C_real(args.ldc_real), -+ params_C_imag(args.ldc_imag), -+ params_D_real(args.ldd_real), -+ params_D_imag(args.ldd_imag), -+ output_op(args.epilogue), -+ ptr_A_real(const_cast(args.ptr_A_real)), -+ ptr_A_imag(const_cast(args.ptr_A_imag)), -+ ptr_B_real(const_cast(args.ptr_B_real)), -+ ptr_B_imag(const_cast(args.ptr_B_imag)), -+ ptr_C_real(const_cast(args.ptr_C_real)), -+ ptr_C_imag(const_cast(args.ptr_C_imag)), -+ ptr_D_real(args.ptr_D_real), -+ ptr_D_imag(args.ptr_D_imag), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ batch_stride_C(args.batch_stride_C), -+ batch_stride_A_imag(args.batch_stride_A_imag), -+ batch_stride_B_imag(args.batch_stride_B_imag), -+ batch_stride_C_imag(args.batch_stride_C_imag), -+ batch_stride_D_imag(args.batch_stride_D_imag) -+ {} -+ -+ /// Returns the workspace size (in bytes) needed for this problem geometry -+ size_t get_workspace_size() const -+ { -+ size_t workspace_bytes = ParamsBase::get_workspace_size(); -+ if (this->mode == GemmUniversalMode::kGemmSplitKParallel) -+ { -+ // Double the size returned by the base class because we need to -+ // accumulate two ElementC components -+ workspace_bytes *= 2; -+ } -+ -+ return workspace_bytes; -+ } -+ -+ /// Lightweight update given a subset of arguments. Problem geometry is assumed -+ /// to remain the same. 
-+ void update(Arguments const &args) -+ { -+ ptr_A_real = const_cast(args.ptr_A_real); -+ ptr_A_imag = const_cast(args.ptr_A_imag); -+ -+ ptr_B_real = const_cast(args.ptr_B_real); -+ ptr_B_imag = const_cast(args.ptr_B_imag); -+ -+ ptr_C_real = const_cast(args.ptr_C_real); -+ ptr_C_imag = const_cast(args.ptr_C_imag); -+ -+ ptr_D_real = const_cast(args.ptr_D_real); -+ ptr_D_imag = const_cast(args.ptr_D_imag); -+ -+ output_op = args.epilogue; -+ } -+ }; -+ -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+public: -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement(Arguments const &args) -+ { -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ bool isAMisaligned = false; -+ bool isBMisaligned = false; -+ bool isCMisaligned = false; -+ -+ if (platform::is_same::value) { -+ isAMisaligned = args.problem_size.k() % kAlignmentA; -+ } else if (platform::is_same::value) { -+ isAMisaligned = args.problem_size.m() % kAlignmentA; -+ } -+ -+ if (platform::is_same::value) { -+ isBMisaligned = args.problem_size.n() % kAlignmentB; -+ } else if (platform::is_same::value) { -+ isBMisaligned = args.problem_size.k() % kAlignmentB; -+ } -+ -+ if (platform::is_same::value) { -+ isCMisaligned = args.problem_size.n() % kAlignmentC; -+ } else if (platform::is_same::value) { -+ isCMisaligned = args.problem_size.m() % kAlignmentC; -+ } -+ -+ if (isAMisaligned || isBMisaligned || isCMisaligned) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ // -+ // Device-only API -+ // -+ -+ // Factory invocation -+ CUTLASS_DEVICE -+ static void invoke( -+ Params const ¶ms, -+ SharedStorage &shared_storage) -+ { -+ GemmPlanarComplex op; -+ op(params, shared_storage); -+ } -+ -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A_real = static_cast(params.ptr_A_real); -+ ElementA *ptr_A_imag = static_cast(params.ptr_A_imag); -+ -+ ElementB *ptr_B_real = static_cast(params.ptr_B_real); -+ ElementB *ptr_B_imag = static_cast(params.ptr_B_imag); -+ -+ // -+ // Fetch pointers based on mode. 
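// Editorial aside, not part of this patch: a simplified standalone version of the
// alignment test that can_implement() performs above. The real kernel derives
// kAlignmentA/B/C from the iterators' AccessType and also special-cases interleaved
// layouts; here the alignment is a plain parameter. The idea is the same: the
// contiguous (vectorizable) extent of each operand must be a multiple of the
// access width -- the trailing dimension for row-major, the leading dimension's
// extent for column-major.
#include <cstdio>

enum class Layout { kRowMajor, kColumnMajor };

bool is_operand_aligned(Layout layout, int rows, int cols, int alignment) {
  // Contiguous extent is `cols` for row-major, `rows` for column-major.
  int contiguous = (layout == Layout::kRowMajor) ? cols : rows;
  return (contiguous % alignment) == 0;
}

int main() {
  // A is M x K, B is K x N, C/D are M x N; an alignment of 8 elements corresponds
  // to 128-bit accesses on 16-bit operands, for example.
  int M = 128, N = 96, K = 20;
  bool ok_A = is_operand_aligned(Layout::kRowMajor,    M, K, 8);  // checks K
  bool ok_B = is_operand_aligned(Layout::kColumnMajor, K, N, 8);  // checks K
  bool ok_C = is_operand_aligned(Layout::kRowMajor,    M, N, 8);  // checks N
  std::printf("A %s, B %s, C %s\n",
              ok_A ? "aligned" : "misaligned",
              ok_B ? "aligned" : "misaligned",
              ok_C ? "aligned" : "misaligned");
  return 0;
}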
-+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_A_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_A; -+ ptr_A_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_A_imag; -+ ptr_B_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_B; -+ ptr_B_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_B_imag; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_A_real = static_cast(params.ptr_A_real)[threadblock_tile_offset.k()]; -+ ptr_A_imag = static_cast(params.ptr_A_imag)[threadblock_tile_offset.k()]; -+ ptr_B_real = static_cast(params.ptr_B_real)[threadblock_tile_offset.k()]; -+ ptr_B_imag = static_cast(params.ptr_B_imag)[threadblock_tile_offset.k()]; -+ } -+ -+ __syncthreads(); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A_real( -+ params.params_A_real, -+ ptr_A_real, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorA iterator_A_imag( -+ params.params_A_imag, -+ ptr_A_imag, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B_real( -+ params.params_B_real, -+ ptr_B_real, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ typename Mma::IteratorB iterator_B_imag( -+ params.params_B_imag, -+ ptr_B_imag, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. 
-+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A_real, -+ iterator_A_imag, -+ iterator_B_real, -+ iterator_B_imag, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_C_real = static_cast(params.ptr_C_real); -+ ElementC *ptr_C_imag = static_cast(params.ptr_C_imag); -+ ElementC *ptr_D_real = static_cast(params.ptr_D_real); -+ ElementC *ptr_D_imag = static_cast(params.ptr_D_imag); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D_real += threadblock_tile_offset.k() * params.batch_stride_D; -+ ptr_D_imag += threadblock_tile_offset.k() * params.batch_stride_D_imag; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_C_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_C; -+ ptr_C_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_C_imag; -+ ptr_D_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_D; -+ ptr_D_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_D_imag; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_C_real = static_cast(params.ptr_C_real)[threadblock_tile_offset.k()]; -+ ptr_C_imag = static_cast(params.ptr_C_imag)[threadblock_tile_offset.k()]; -+ ptr_D_real = static_cast(params.ptr_D_real)[threadblock_tile_offset.k()]; -+ ptr_D_imag = static_cast(params.ptr_D_imag)[threadblock_tile_offset.k()]; -+ } -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C_real( -+ params.params_C_real, -+ ptr_C_real, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ typename Epilogue::OutputTileIterator iterator_C_imag( -+ params.params_C_imag, -+ ptr_C_imag, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. 
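// Editorial aside, not part of this patch: a host-side analogue (std::thread /
// std::atomic) of the semaphore protocol used above for serial split-K reduction.
// Each k-partition waits until the lock equals its partition index, folds its
// partial result into the output, then publishes either the next index or 0 if it
// is the last partition (resetting the lock for subsequent launches) -- mirroring
// semaphore.wait() / semaphore.release() in the kernel. Values are illustrative.
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
  constexpr int kPartitions = 4;
  std::atomic<int> lock{0};
  double output = 0.0;  // stands in for the D tile

  std::vector<std::thread> workers;
  for (int k = 0; k < kPartitions; ++k) {
    workers.emplace_back([k, &lock, &output] {
      double partial = k + 1.0;  // stands in for this partition's accumulators
      // semaphore.wait(k): spin until it is this partition's turn.
      while (lock.load(std::memory_order_acquire) != k) {
        std::this_thread::yield();
      }
      // Partition 0 writes; later partitions read-modify-write the prior result.
      output = (k == 0) ? partial : output + partial;
      // semaphore.release(): last partition resets to 0, others pass the baton.
      int next = (k + 1 == kPartitions) ? 0 : k + 1;
      lock.store(next, std::memory_order_release);
    });
  }
  for (auto &w : workers) {
    w.join();
  }
  std::printf("reduced output = %f (expect 10)\n", output);
  return 0;
}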
-+ typename Epilogue::OutputTileIterator iterator_D_real( -+ params.params_D_real, -+ ptr_D_real, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ typename Epilogue::OutputTileIterator iterator_D_imag( -+ params.params_D_imag, -+ ptr_D_imag, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // -+ // Construct epilogue -+ // -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C_real = iterator_D_real; -+ iterator_C_imag = iterator_D_imag; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ __threadfence(); -+ } -+ -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ output_op, -+ iterator_D_real, -+ iterator_D_imag, -+ accumulators, -+ iterator_C_real, -+ iterator_C_imag); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_planar_complex_array.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_planar_complex_array.h -new file mode 100644 -index 0000000..21b8011 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_planar_complex_array.h -@@ -0,0 +1,618 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/gemm/kernel/params_universal_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function -+> -+struct GemmPlanarComplexArray { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ using Operator = typename Mma::Operator; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = const_max( -+ 128 / sizeof_bits::value, -+ 128 / sizeof_bits::value); -+ -+ // -+ // Additional types needed for reflection -+ // -+ -+ using ElementAccumulator = typename Mma::Policy::Operator::ElementC; -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::Shape; -+ -+ static int const kStages = Mma::kStages; -+ -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ // -+ // Arguments structure -+ // -+ -+ /// Argument structure -+ struct Arguments : UniversalArgumentsBase -+ { -+ // -+ // Data members -+ // -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ int const *ptr_M; -+ int const *ptr_N; -+ int const *ptr_K; -+ -+ void const * const * ptr_A_real; -+ void const * const * ptr_A_imag; -+ -+ void const * const * ptr_B_real; -+ void const * 
const * ptr_B_imag; -+ -+ void const * const * ptr_C_real; -+ void const * const * ptr_C_imag; -+ -+ void * const * ptr_D_real; -+ void * const * ptr_D_imag; -+ -+ typename LayoutA::Stride::Index lda_real; -+ typename LayoutA::Stride::Index lda_imag; -+ typename LayoutB::Stride::Index ldb_real; -+ typename LayoutB::Stride::Index ldb_imag; -+ typename LayoutC::Stride::Index ldc_real; -+ typename LayoutC::Stride::Index ldc_imag; -+ typename LayoutC::Stride::Index ldd_real; -+ typename LayoutC::Stride::Index ldd_imag; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ ptr_M(nullptr), -+ ptr_N(nullptr), -+ ptr_K(nullptr), -+ ptr_A_real(nullptr), -+ ptr_A_imag(nullptr), -+ ptr_B_real(nullptr), -+ ptr_B_imag(nullptr), -+ ptr_C_real(nullptr), -+ ptr_C_imag(nullptr), -+ ptr_D_real(nullptr), -+ ptr_D_imag(nullptr) -+ {} -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ int const *ptr_M, -+ int const *ptr_N, -+ int const *ptr_K, -+ void const * const * ptr_A_real, -+ void const * const * ptr_A_imag, -+ void const * const * ptr_B_real, -+ void const * const * ptr_B_imag, -+ void const * const * ptr_C_real, -+ void const * const * ptr_C_imag, -+ void * const * ptr_D_real, -+ void * const * ptr_D_imag, -+ typename LayoutA::Stride::Index lda_real, -+ typename LayoutA::Stride::Index lda_imag, -+ typename LayoutB::Stride::Index ldb_real, -+ typename LayoutB::Stride::Index ldb_imag, -+ typename LayoutC::Stride::Index ldc_real, -+ typename LayoutC::Stride::Index ldc_imag, -+ typename LayoutC::Stride::Index ldd_real, -+ typename LayoutC::Stride::Index ldd_imag) -+ : -+ UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), -+ epilogue(epilogue), -+ ptr_M(ptr_M), -+ ptr_N(ptr_N), -+ ptr_K(ptr_K), -+ ptr_A_real(ptr_A_real), -+ ptr_A_imag(ptr_A_imag), -+ ptr_B_real(ptr_B_real), -+ ptr_B_imag(ptr_B_imag), -+ ptr_C_real(ptr_C_real), -+ ptr_C_imag(ptr_C_imag), -+ ptr_D_real(ptr_D_real), -+ ptr_D_imag(ptr_D_imag), -+ lda_real(lda_real), -+ lda_imag(lda_imag), -+ ldb_real(ldb_real), -+ ldb_imag(ldb_imag), -+ ldc_real(ldc_real), -+ ldc_imag(ldc_imag), -+ ldd_real(ldd_real), -+ ldd_imag(ldd_imag) -+ {} -+ -+ /// Returns arguments for the transposed problem -+ Arguments transposed_problem() const { -+ Arguments args(*this); -+ -+ std::swap(args.problem_size.m(), args.problem_size.n()); -+ std::swap(args.ptr_M, args.ptr_N); -+ std::swap(args.ptr_A_real, args.ptr_B_real); -+ std::swap(args.ptr_A_imag, args.ptr_B_imag); -+ std::swap(args.lda_real, args.ldb_real); -+ std::swap(args.lda_imag, args.ldb_imag); -+ -+ return args; -+ } -+ }; -+ -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params : UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC> -+ { -+ using ParamsBase = UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC>; -+ -+ // -+ // Data members -+ // -+ -+ typename Mma::IteratorA::Params params_A_real; -+ typename Mma::IteratorA::Params params_A_imag; -+ typename Mma::IteratorB::Params params_B_real; -+ typename Mma::IteratorB::Params params_B_imag; -+ typename Epilogue::OutputTileIterator::Params params_C_real; -+ typename Epilogue::OutputTileIterator::Params params_C_imag; -+ typename Epilogue::OutputTileIterator::Params params_D_real; -+ typename Epilogue::OutputTileIterator::Params params_D_imag; -+ -+ 
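        // The eight iterator Params members above are precomputed on the host from the per-plane
        // leading dimensions carried in Arguments (lda_real/lda_imag, ldb_real/ldb_imag,
        // ldc_real/ldc_imag, ldd_real/ldd_imag); the members below are forwarded from Arguments
        // unchanged: the epilogue functor parameters, the runtime problem-size arrays ptr_M/N/K,
        // and the per-batch pointer arrays for the real and imaginary planes of A, B, C and D.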
typename EpilogueOutputOp::Params output_op; -+ -+ int const *ptr_M; -+ int const *ptr_N; -+ int const *ptr_K; -+ -+ void const * const * ptr_A_real; -+ void const * const * ptr_A_imag; -+ void const * const * ptr_B_real; -+ void const * const * ptr_B_imag; -+ void const * const * ptr_C_real; -+ void const * const * ptr_C_imag; -+ void * const * ptr_D_real; -+ void * const * ptr_D_imag; -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Constructor -+ Params( -+ Arguments const &args, /// GEMM application arguments -+ int device_sms, /// Number of SMs on the device -+ int sm_occupancy) /// Kernel SM occupancy (in thread blocks) -+ : -+ ParamsBase(args, device_sms, sm_occupancy), -+ ptr_M(args.ptr_M), -+ ptr_N(args.ptr_N), -+ ptr_K(args.ptr_K), -+ params_A_real(args.lda_real), -+ params_A_imag(args.lda_imag), -+ params_B_real(args.ldb_real), -+ params_B_imag(args.ldb_imag), -+ params_C_real(args.ldc_real), -+ params_C_imag(args.ldc_imag), -+ params_D_real(args.ldd_real), -+ params_D_imag(args.ldd_imag), -+ output_op(args.epilogue), -+ ptr_A_real(args.ptr_A_real), -+ ptr_A_imag(args.ptr_A_imag), -+ ptr_B_real(args.ptr_B_real), -+ ptr_B_imag(args.ptr_B_imag), -+ ptr_C_real(args.ptr_C_real), -+ ptr_C_imag(args.ptr_C_imag), -+ ptr_D_real(args.ptr_D_real), -+ ptr_D_imag(args.ptr_D_imag) -+ {} -+ -+ /// Lightweight update given a subset of arguments. Problem geometry is assumed -+ /// to remain the same. -+ void update(Arguments const &args) -+ { -+ ptr_M = args.ptr_M; -+ ptr_N = args.ptr_N; -+ ptr_K = args.ptr_K; -+ -+ ptr_A_real = args.ptr_A_real; -+ ptr_A_imag = args.ptr_A_imag; -+ -+ ptr_B_real = args.ptr_B_real; -+ ptr_B_imag = args.ptr_B_imag; -+ -+ ptr_C_real = args.ptr_C_real; -+ ptr_C_imag = args.ptr_C_imag; -+ -+ ptr_D_real = args.ptr_D_real; -+ ptr_D_imag = args.ptr_D_imag; -+ -+ output_op = args.epilogue; -+ } -+ }; -+ -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+public: -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement(Arguments const &args) { -+ -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ bool isAMisaligned = false; -+ bool isBMisaligned = false; -+ bool isCMisaligned = false; -+ -+ if (platform::is_same::value) { -+ isAMisaligned = args.problem_size.k() % kAlignmentA; -+ } else if (platform::is_same::value) { -+ isAMisaligned = args.problem_size.m() % kAlignmentA; -+ } -+ -+ if (platform::is_same::value) { -+ isBMisaligned = args.problem_size.n() % kAlignmentB; -+ } else if (platform::is_same::value) { -+ isBMisaligned = args.problem_size.k() % kAlignmentB; -+ } -+ -+ if (platform::is_same::value) { -+ isCMisaligned = args.problem_size.n() % kAlignmentC; -+ } else if (platform::is_same::value) { -+ isCMisaligned = args.problem_size.m() % kAlignmentC; -+ } -+ -+ if (isAMisaligned || isBMisaligned || isCMisaligned) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ -+public: -+ -+ // -+ // Device-only API -+ // -+ -+ // Factory invocation -+ CUTLASS_DEVICE -+ static void invoke( -+ Params const ¶ms, -+ SharedStorage &shared_storage) -+ { -+ GemmPlanarComplexArray op; -+ op(params, shared_storage); -+ } -+ -+ -+ 
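  // A minimal host-side sketch (illustration only, not part of this file) of how the per-batch
  // pointer arrays consumed by this kernel are typically prepared: one device pointer per batch,
  // with the real and imaginary planes of each operand kept in separate arrays. The names
  // batch_count, dA_real and dA_imag are placeholders invented for this sketch.
  //
  //   std::vector<void const *> h_ptr_A_real(batch_count), h_ptr_A_imag(batch_count);
  //   for (int b = 0; b < batch_count; ++b) {
  //     h_ptr_A_real[b] = dA_real[b];   // device pointer to the real plane of A for batch b
  //     h_ptr_A_imag[b] = dA_imag[b];   // device pointer to the imaginary plane of A for batch b
  //   }
  //   // The host arrays are then copied to device memory (e.g. with cudaMemcpy), and those
  //   // device addresses are what Arguments::ptr_A_real / ptr_A_imag (and likewise for B, C, D)
  //   // must point to, since operator() below reads ptr_A_real[batch_idx] on the device.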
/// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int batch_idx = threadblock_tile_offset.k(); -+ -+ int problem_size_m = params.problem_size.m(); -+ int problem_size_n = params.problem_size.n(); -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A_real = static_cast(const_cast(params.ptr_A_real[batch_idx])); -+ ElementA *ptr_A_imag = static_cast(const_cast(params.ptr_A_imag[batch_idx])); -+ -+ ElementB *ptr_B_real = static_cast(const_cast(params.ptr_B_real[batch_idx])); -+ ElementB *ptr_B_imag = static_cast(const_cast(params.ptr_B_imag[batch_idx])); -+ -+ // -+ // If pointers for problem sizes are specified, these are loaded from global memory -+ // -+ -+ if (params.ptr_M) { -+ problem_size_m = params.ptr_M[batch_idx]; -+ } -+ -+ if (params.ptr_N) { -+ problem_size_n = params.ptr_N[batch_idx]; -+ } -+ -+ if (params.ptr_K) { -+ problem_size_k = params.ptr_K[batch_idx]; -+ } -+ -+ int const kBlockCountM = (problem_size_m + Mma::Shape::kM - 1) / Mma::Shape::kM; -+ int const kBlockCountN = (problem_size_n + Mma::Shape::kN - 1) / Mma::Shape::kN; -+ -+ int const kGemmKIterations = (problem_size_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // -+ // Each threadblock loops over the logical problem size which the kernel may have discovered -+ // after the grid is launched. -+ // -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int block_m = threadblock_tile_offset.m(); -+ block_m < kBlockCountM; -+ block_m += params.grid_tiled_shape.m()) { -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int block_n = threadblock_tile_offset.n(); -+ block_n < kBlockCountN; -+ block_n += params.grid_tiled_shape.n()) { -+ -+ // -+ // Compute indices within threadblock and warp. -+ // -+ int thread_idx = threadIdx.x; -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Proceed with regular GEMM logic. 
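        // (Within each (block_m, block_n) iteration of the grid-strided loops above, the tile
        //  offsets and the A/B tile iterators below are rebuilt against the runtime problem
        //  extents, which may have been loaded per batch from ptr_M / ptr_N / ptr_K.)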
-+ // -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ block_m * Mma::Shape::kM, 0}; -+ cutlass::MatrixCoord tb_offset_B{ 0, block_n * Mma::Shape::kN }; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A_real( -+ params.params_A_real, -+ ptr_A_real, -+ {problem_size_m, problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorA iterator_A_imag( -+ params.params_A_imag, -+ ptr_A_imag, -+ {problem_size_m, problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B_real( -+ params.params_B_real, -+ ptr_B_real, -+ {problem_size_k, problem_size_n}, -+ thread_idx, -+ tb_offset_B); -+ -+ typename Mma::IteratorB iterator_B_imag( -+ params.params_B_imag, -+ ptr_B_imag, -+ {problem_size_k, problem_size_n}, -+ thread_idx, -+ tb_offset_B); -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ kGemmKIterations, -+ accumulators, -+ iterator_A_real, -+ iterator_A_imag, -+ iterator_B_real, -+ iterator_B_imag, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ block_m * Mma::Shape::kM, -+ block_n * Mma::Shape::kN -+ ); -+ -+ ElementC *ptr_C_real = static_cast(const_cast(params.ptr_C_real[batch_idx])); -+ ElementC *ptr_C_imag = static_cast(const_cast(params.ptr_C_imag[batch_idx])); -+ ElementC *ptr_D_real = static_cast(params.ptr_D_real[batch_idx]); -+ ElementC *ptr_D_imag = static_cast(params.ptr_D_imag[batch_idx]); -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C_real( -+ params.params_C_real, -+ ptr_C_real, -+ {problem_size_m, problem_size_n}, -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ typename Epilogue::OutputTileIterator iterator_C_imag( -+ params.params_C_imag, -+ ptr_C_imag, -+ {problem_size_m, problem_size_n}, -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D_real( -+ params.params_D_real, -+ ptr_D_real, -+ {problem_size_m, problem_size_n}, -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ typename Epilogue::OutputTileIterator iterator_D_imag( -+ params.params_D_imag, -+ ptr_D_imag, -+ {problem_size_m, problem_size_n}, -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // -+ // Construct epilogue -+ // -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Execute the epilogue operator to update the destination tensor. 
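        // (The planar-complex epilogue consumes a single accumulator tile carrying both planes
        //  together with a pair of destination and source iterators per tensor: it writes the
        //  real and imaginary planes of D, optionally blending in the matching planes of C.)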
-+ epilogue( -+ output_op, -+ iterator_D_real, -+ iterator_D_imag, -+ accumulators, -+ iterator_C_real, -+ iterator_C_imag); -+ -+ -+ } // for block_n -+ } // for block_m -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_splitk_parallel.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_splitk_parallel.h -new file mode 100644 -index 0000000..ffb928c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_splitk_parallel.h -@@ -0,0 +1,253 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for GEMM performing a reduction over K partitions in parallel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! 
Threadblock swizzling function -+> -+struct GemmSplitKParallel { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using OutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ static int const kAlignmentK = Mma::Operator::Shape::kK; -+ -+ /// Parameters structure -+ struct Params { -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorA::TensorRef ref_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Mma::IteratorB::TensorRef ref_B; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ typename Epilogue::OutputTileIterator::TensorRef ref_D; -+ typename OutputOp::Params output_op; -+ int64_t splitk_slice_stride; -+ int gemm_k_size; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): swizzle_log_tile(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D, -+ typename OutputOp::Params output_op, -+ int64_t splitk_slice_stride -+ ): -+ problem_size(problem_size), -+ grid_tiled_shape(grid_tiled_shape), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A(ref_A.layout()), -+ ref_A(ref_A), -+ params_B(ref_B.layout()), -+ ref_B(ref_B), -+ params_D(ref_D.layout()), -+ ref_D(ref_D), -+ output_op(output_op), -+ splitk_slice_stride(splitk_slice_stride) { -+ -+ int full_gemm_k_iterations = problem_size.k() / Mma::Shape::kK; -+ int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k(); -+ -+ gemm_k_size = gemm_k_iterations * Mma::Shape::kK; -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ GemmSplitKParallel() { } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.k() * params.gemm_k_size, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ threadblock_tile_offset.k() * params.gemm_k_size, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Problem size is a function of threadblock index in the K dimension -+ int problem_size_k; -+ if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k()) { -+ problem_size_k = params.problem_size.k(); -+ } -+ else { -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + 
Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ params.ref_A.data(), -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ params.ref_B.data(), -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ int warp_idx = threadIdx.x / 32; -+ int lane_idx = threadIdx.x % 32; -+ -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ OutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ params.ref_D.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ iterator_D.add_pointer_offset(params.splitk_slice_stride * threadblock_tile_offset.k()); -+ -+ // Execute the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Run efficient epilogue -+ epilogue(output_op, iterator_D, accumulators, iterator_D); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_transpose_operands.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_transpose_operands.h -new file mode 100644 -index 0000000..dec9935 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_transpose_operands.h -@@ -0,0 +1,124 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and -+ batched array variants. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA_, -+ typename LayoutA_, -+ ComplexTransform TransformA, -+ int AlignmentA, -+ typename ElementB_, -+ typename LayoutB_, -+ ComplexTransform TransformB, -+ int AlignmentB, -+ typename LayoutC_, -+ bool Transpose -+> -+struct MapArguments { -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ static ComplexTransform const kTransformA = TransformA; -+ static int const kAlignmentA = AlignmentA; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ static ComplexTransform const kTransformB = TransformB; -+ static int const kAlignmentB = AlignmentB; -+ using LayoutC = LayoutC_; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA_, -+ typename LayoutA_, -+ ComplexTransform TransformA, -+ int AlignmentA, -+ typename ElementB_, -+ typename LayoutB_, -+ ComplexTransform TransformB, -+ int AlignmentB, -+ typename LayoutC_ -+> -+struct MapArguments< -+ ElementA_, -+ LayoutA_, -+ TransformA, -+ AlignmentA, -+ ElementB_, -+ LayoutB_, -+ TransformB, -+ AlignmentB, -+ LayoutC_, -+ true -+> { -+ using ElementA = ElementB_; -+ using LayoutA = typename layout::LayoutTranspose::type; -+ static ComplexTransform const kTransformA = TransformB; -+ static int const kAlignmentA = AlignmentB; -+ using ElementB = ElementA_; -+ using LayoutB = typename layout::LayoutTranspose::type; -+ static ComplexTransform const kTransformB = TransformA; -+ static int const kAlignmentB = AlignmentA; -+ using LayoutC = typename layout::LayoutTranspose::type; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} -+} -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_universal.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_universal.h -new file mode 100644 -index 0000000..fc62c01 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_universal.h -@@ -0,0 +1,694 @@ 
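// (The Transpose=true specialization of MapArguments in gemm_transpose_operands.h above realizes
//  a row-major C output with a column-major-C kernel by computing the transposed problem
//  C^T = B^T * A^T: it swaps the A/B element types, complex transforms and alignments, and
//  transposes all three layouts.)
//
// A minimal illustration, with element/layout choices invented for this sketch only:
//
//   using Mapped = cutlass::gemm::kernel::detail::MapArguments<
//       cutlass::half_t, cutlass::layout::RowMajor,    cutlass::ComplexTransform::kNone, 8,  // A
//       cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 8,  // B
//       cutlass::layout::RowMajor,                                                           // C
//       true>;
//   // Mapped::LayoutA is RowMajor (the original B layout, transposed), Mapped::kAlignmentA is 8
//   // (taken from B), and Mapped::LayoutC is ColumnMajor (the original C layout, transposed).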
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/arch/arch.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/params_universal_base.h" -+ -+#include "cutlass/trace.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! 
Threadblock swizzling function -+> -+class GemmUniversal< -+ Mma_, -+ Epilogue_, -+ ThreadblockSwizzle_, -+ void, -+ // 3.x kernels use the first template argument to define the ProblemShape tuple -+ // We use this invariant to SFINAE dispatch against either the 2.x API or the 3.x API -+ std::enable_if_t::value> -+> { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ using Operator = typename Mma::Operator; -+ -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments : UniversalArgumentsBase -+ { -+ // -+ // Data members -+ // -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ void const * ptr_A; -+ void const * ptr_B; -+ void const * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ -+ typename LayoutA::Stride stride_a; -+ typename LayoutB::Stride stride_b; -+ typename LayoutC::Stride stride_c; -+ typename LayoutC::Stride stride_d; -+ -+ typename LayoutA::Stride::LongIndex lda; -+ typename LayoutB::Stride::LongIndex ldb; -+ typename LayoutC::Stride::LongIndex ldc; -+ typename LayoutC::Stride::LongIndex ldd; -+ -+ int const * ptr_gather_A_indices; -+ int const * ptr_gather_B_indices; -+ int const * ptr_scatter_D_indices; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), -+ ptr_gather_A_indices(nullptr), -+ ptr_gather_B_indices(nullptr), -+ ptr_scatter_D_indices(nullptr) -+ {} -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride stride_a, -+ typename LayoutB::Stride stride_b, -+ typename LayoutC::Stride stride_c, -+ typename LayoutC::Stride stride_d, -+ int const *ptr_gather_A_indices = nullptr, -+ int const *ptr_gather_B_indices = nullptr, -+ int 
const *ptr_scatter_D_indices = nullptr) -+ : -+ UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), -+ batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), -+ stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), -+ ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), -+ ptr_scatter_D_indices(ptr_scatter_D_indices) -+ { -+ lda = 0; -+ ldb = 0; -+ ldc = 0; -+ ldd = 0; -+ CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); -+ } -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride::LongIndex lda, -+ typename LayoutB::Stride::LongIndex ldb, -+ typename LayoutC::Stride::LongIndex ldc, -+ typename LayoutC::Stride::LongIndex ldd, -+ int const *ptr_gather_A_indices = nullptr, -+ int const *ptr_gather_B_indices = nullptr, -+ int const *ptr_scatter_D_indices = nullptr -+ ): -+ UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), -+ batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), -+ lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), -+ ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), -+ ptr_scatter_D_indices(ptr_scatter_D_indices) -+ { -+ stride_a = make_Coord(lda); -+ stride_b = make_Coord(ldb); -+ stride_c = make_Coord(ldc); -+ stride_d = make_Coord(ldd); -+ CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); -+ } -+ -+ /// Returns arguments for the transposed problem -+ Arguments transposed_problem() const -+ { -+ Arguments args(*this); -+ -+ std::swap(args.problem_size.m(), args.problem_size.n()); -+ std::swap(args.ptr_A, args.ptr_B); -+ std::swap(args.lda, args.ldb); -+ std::swap(args.stride_a, args.stride_b); -+ std::swap(args.batch_stride_A, args.batch_stride_B); -+ std::swap(args.ptr_gather_A_indices, args.ptr_gather_B_indices); -+ -+ return args; -+ } -+ }; -+ -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params : UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC> -+ { -+ using ParamsBase = UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC>; -+ -+ // -+ // Data members -+ // -+ -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ -+ int * ptr_gather_A_indices; -+ int * ptr_gather_B_indices; -+ int * ptr_scatter_D_indices; -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Constructor -+ 
Params( -+ Arguments const &args, /// GEMM application arguments -+ int device_sms, /// Number of SMs on the device -+ int sm_occupancy) /// Kernel SM occupancy (in thread blocks) -+ : -+ ParamsBase(args, device_sms, sm_occupancy), -+ params_A(args.lda ? make_Coord_with_padding(args.lda) : args.stride_a), -+ params_B(args.ldb ? make_Coord_with_padding(args.ldb) : args.stride_b), -+ params_C(args.ldc ? make_Coord_with_padding(args.ldc) : args.stride_c), -+ params_D(args.ldd ? make_Coord_with_padding(args.ldd) : args.stride_d), -+ output_op(args.epilogue), -+ ptr_A(const_cast(args.ptr_A)), -+ ptr_B(const_cast(args.ptr_B)), -+ ptr_C(const_cast(args.ptr_C)), -+ ptr_D(args.ptr_D), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ batch_stride_C(args.batch_stride_C), -+ ptr_gather_A_indices(const_cast(args.ptr_gather_A_indices)), -+ ptr_gather_B_indices(const_cast(args.ptr_gather_B_indices)), -+ ptr_scatter_D_indices(const_cast(args.ptr_scatter_D_indices)) -+ {} -+ -+ /// Lightweight update given a subset of arguments. Problem geometry is assumed -+ /// to remain the same. -+ void update(Arguments const &args) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); -+ -+ // Update input/output pointers -+ ptr_A = const_cast(args.ptr_A); -+ ptr_B = const_cast(args.ptr_B); -+ ptr_C = const_cast(args.ptr_C); -+ ptr_D = args.ptr_D; -+ -+ ptr_gather_A_indices = const_cast(args.ptr_gather_A_indices); -+ ptr_gather_B_indices = const_cast(args.ptr_gather_B_indices); -+ ptr_scatter_D_indices = const_cast(args.ptr_scatter_D_indices); -+ -+ output_op = args.epilogue; -+ } -+ }; -+ -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ -+public: -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversal::can_implement()"); -+ -+ static int const kAlignmentA = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 
64 -+ : Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ bool isAMisaligned = false; -+ bool isBMisaligned = false; -+ bool isCMisaligned = false; -+ -+ if (platform::is_same::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } else if (platform::is_same::value) { -+ isAMisaligned = problem_size.m() % kAlignmentA; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } -+ -+ if (platform::is_same::value) { -+ isBMisaligned = problem_size.n() % kAlignmentB; -+ } else if (platform::is_same::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } -+ -+ if (platform::is_same::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } else if (platform::is_same::value) { -+ isCMisaligned = problem_size.m() % kAlignmentC; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } -+ -+ if (isAMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isBMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isCMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ CUTLASS_TRACE_HOST(" returning kSuccess"); -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+ -+public: -+ -+ // -+ // Device-only API -+ // -+ -+ // Factory invocation -+ CUTLASS_DEVICE -+ static void invoke( -+ Params const ¶ms, -+ SharedStorage &shared_storage) -+ { -+ GemmUniversal op; -+ op(params, shared_storage); -+ } -+ -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()( -+ Params const ¶ms, -+ SharedStorage &shared_storage) -+ { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ // -+ // Fetch pointers based on mode. 
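      // (kGemm and kGemmSplitKParallel restrict this threadblock to its K-slice by deriving
      //  offset_k and problem_size_k from gemm_k_size; kBatched advances ptr_A / ptr_B by the
      //  batch strides; kArray treats ptr_A / ptr_B as arrays of per-batch pointers indexed by
      //  the threadblock's k coordinate, which doubles as the batch index.)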
-+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; -+ ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; -+ ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; -+ } -+ -+ __syncthreads(); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A, -+ params.ptr_gather_A_indices); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B, -+ params.ptr_gather_B_indices); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C); -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. 
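      // (This semaphore implements the serial split-K reduction: partition k == 0 writes its
      //  partial result to D first; each later partition waits until the semaphore reaches its
      //  own k index, re-reads D as its source operand, accumulates on top of it, and releases
      //  the semaphore with k + 1; the final partition resets the semaphore to 0.)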
-+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; -+ ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; -+ } -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ ptr_C, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ params.ptr_scatter_D_indices -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ params.ptr_scatter_D_indices -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ } -+ -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_universal.hpp b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_universal.hpp -new file mode 100644 -index 0000000..cdac6ca ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_universal.hpp -@@ -0,0 +1,72 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm::kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/* -+ * Stateless universal device GEMM kernel type that treats GEMM as -+ * a composition of a collective mainloop and a collective epilogue. -+ * -+ * Supports both the 2.x and 3.x APIs based on whether the first type is -+ * a cute::tuple<> or not. -+ * 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h -+ * 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp -+ * -+ * In the following declaration, the name preceding the 'Or' refers to -+ * 3.x API type argument order, and the name succeeding the 'Or' refers to -+ * 2.x API type argument order. Template arguments without two names -+ * belong to the 3.x API only. -+**/ -+template < -+ class ProblemShapeOrThreadblockMma_, // (m, n, k) or (m, n, k, l) -+ class CollectiveMainloopOrEpilogue_, -+ class CollectiveEpilogueOrThreadblockSwizzle_, -+ class GridSwizzle_ = void, -+ class Enable = void -+> -+class GemmUniversal; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::kernel -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#include "cutlass/gemm/kernel/sm70_gemm.hpp" -+#include "cutlass/gemm/kernel/sm90_gemm_tma.hpp" -+#include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp" -+#include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_persistent.hpp" -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_universal_streamk.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_universal_streamk.h -new file mode 100644 -index 0000000..27da66f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_universal_streamk.h -@@ -0,0 +1,1249 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/barrier.h" -+#include "cutlass/block_striped.h" -+ -+#include "cutlass/trace.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! 
Threadblock mapping function -+> -+struct GemmUniversalStreamk { -+public: -+ -+ -+ // -+ // Types and constants -+ // -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ -+ /// The per-thread tile of raw accumulators -+ using AccumulatorTile = typename Mma::FragmentC; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ using Operator = typename Mma::Operator; -+ -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Workspace bytes per thread block -+ static size_t const kWorkspaceBytesPerBlock = -+ __NV_STD_MAX( -+ kThreadCount * sizeof(AccumulatorTile), -+ Epilogue::kWorkspaceBytesPerBlock); -+ -+ /// Block-striped reduction utility -+ using BlockStripedReduceT = BlockStripedReduce; -+ -+ -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmUniversalMode mode; -+ GemmCoord problem_size; -+ int batch_count; // Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ void const * ptr_A; -+ void const * ptr_B; -+ void const * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+ -+ typename LayoutA::Stride stride_a; -+ typename LayoutB::Stride stride_b; -+ typename LayoutC::Stride stride_c; -+ typename LayoutC::Stride stride_d; -+ -+ typename LayoutA::Stride::LongIndex lda; -+ typename LayoutB::Stride::LongIndex ldb; -+ typename LayoutC::Stride::LongIndex ldc; -+ typename LayoutC::Stride::LongIndex ldd; -+ -+ int avail_sms; /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling) -+ -+ -+ // -+ // Methods -+ // -+ -+ /// Default Constructor -+ Arguments(): -+ mode(GemmUniversalMode::kGemm), -+ batch_count(1), -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ avail_sms(-1) -+ {} -+ -+ /// Constructor -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_split, /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K) -+ typename 
EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride stride_a, -+ typename LayoutB::Stride stride_b, -+ typename LayoutC::Stride stride_c, -+ typename LayoutC::Stride stride_d, -+ int avail_sms = -1 /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling) -+ ): -+ mode(mode), -+ problem_size(problem_size), -+ batch_count(batch_split), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), -+ batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), -+ stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), avail_sms(avail_sms) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversalStreamk::Arguments::Arguments() - problem_size: " << problem_size); -+ } -+ -+ /// Constructor -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_split, /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K) -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride::LongIndex lda, -+ typename LayoutB::Stride::LongIndex ldb, -+ typename LayoutC::Stride::LongIndex ldc, -+ typename LayoutC::Stride::LongIndex ldd, -+ int avail_sms = -1 /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling) -+ ): -+ mode(mode), -+ problem_size(problem_size), -+ batch_count(batch_split), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), -+ batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), -+ lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), avail_sms(avail_sms) -+ { -+ stride_a = make_Coord(lda); -+ stride_b = make_Coord(ldb); -+ stride_c = make_Coord(ldc); -+ stride_d = make_Coord(ldd); -+ CUTLASS_TRACE_HOST("GemmUniversalStreamk::Arguments::Arguments() - problem_size: " << problem_size); -+ } -+ -+ /// Returns arguments for the transposed problem -+ Arguments transposed_problem() const -+ { -+ Arguments args(*this); -+ -+ std::swap(args.problem_size.m(), args.problem_size.n()); -+ std::swap(args.ptr_A, args.ptr_B); -+ std::swap(args.lda, args.ldb); -+ std::swap(args.stride_a, args.stride_b); -+ std::swap(args.batch_stride_A, args.batch_stride_B); -+ -+ return args; -+ } -+ }; -+ -+ -+ /// Parameters structure -+ struct Params -+ { -+ public: -+ -+ // -+ // Data members -+ // -+ -+ void * ptr_A; -+ void * ptr_B; -+ -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorB::Params params_B; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ -+ GemmUniversalMode mode; -+ -+ ThreadblockSwizzle block_mapping; -+ -+ bool quick_dp; -+ -+ void *barrier_workspace; -+ void *partials_workspace; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ void * ptr_D; -+ void * ptr_C; -+ -+ typename Epilogue::OutputTileIterator::Params params_D; -+ 
typename Epilogue::OutputTileIterator::Params params_C; -+ -+ int64_t batch_stride_D; -+ int64_t batch_stride_C; -+ -+ -+ protected: -+ -+ // -+ // Host-only dispatch-utilities -+ // -+ -+ /// Pad the given allocation size up to the nearest cache line -+ static size_t cacheline_align_up(size_t size) -+ { -+ static const int CACHELINE_SIZE = 128; -+ return (size + CACHELINE_SIZE - 1) / CACHELINE_SIZE * CACHELINE_SIZE; -+ } -+ -+ /// Get the workspace size needed for barrier -+ size_t get_barrier_workspace_size() const -+ { -+ // For atomic reduction, each SK-block needs a synchronization flag. For parallel reduction, -+ // each reduction block needs its own synchronization flag. -+ int sk_blocks = block_mapping.sk_regions() * block_mapping.sk_blocks_per_region(); -+ int num_flags = fast_max(sk_blocks, block_mapping.reduction_blocks); -+ -+ return cacheline_align_up(sizeof(typename Barrier::T) * num_flags); -+ } -+ -+ /// Get the workspace size needed for intermediate partial sums -+ size_t get_partials_workspace_size() const -+ { -+ int sk_blocks = block_mapping.sk_regions() * block_mapping.sk_blocks_per_region(); -+ return cacheline_align_up(kWorkspaceBytesPerBlock * sk_blocks); -+ } -+ -+ -+ public: -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Default constructor -+ Params() = default; -+ -+ -+ /// Constructor -+ Params( -+ Arguments const &args, /// GEMM application arguments -+ int device_sms, /// Number of SMs on the device -+ int sm_occupancy) /// Kernel SM occupancy (in thread blocks) -+ : -+ params_A(args.lda ? make_Coord_with_padding(args.lda) : args.stride_a), -+ params_B(args.ldb ? make_Coord_with_padding(args.ldb) : args.stride_b), -+ params_C(args.ldc ? make_Coord_with_padding(args.ldc) : args.stride_c), -+ params_D(args.ldd ? make_Coord_with_padding(args.ldd) : args.stride_d), -+ output_op(args.epilogue), -+ mode(args.mode), -+ ptr_A(const_cast(args.ptr_A)), -+ ptr_B(const_cast(args.ptr_B)), -+ ptr_C(const_cast(args.ptr_C)), -+ ptr_D(args.ptr_D), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ batch_stride_C(args.batch_stride_C), -+ batch_stride_D(args.batch_stride_D), -+ barrier_workspace(nullptr), -+ partials_workspace(nullptr) -+ { -+ // Number of SMs to make available for StreamK decomposition -+ int avail_sms = (args.avail_sms == -1) ? -+ device_sms : -+ fast_min(args.avail_sms, device_sms); -+ -+ // Initialize the block mapping structure -+ block_mapping = ThreadblockSwizzle( -+ typename ThreadblockSwizzle::template KernelTraits(), -+ args.mode, -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count, -+ sm_occupancy, -+ device_sms, -+ avail_sms); -+ -+ quick_dp = -+ (block_mapping.sk_waves == 0) && -+ (mode == GemmUniversalMode::kGemm) && -+ !block_mapping.cohort_raster && -+ !EpilogueOutputOp(output_op).is_source_needed(); -+ -+ } -+ -+ -+ /// Returns the workspace size (in bytes) needed for these parameters -+ size_t get_workspace_size() const -+ { -+ return -+ get_barrier_workspace_size() + -+ get_partials_workspace_size(); -+ } -+ -+ -+ /// Assign and initialize the specified workspace buffer. Assumes -+ /// the memory allocated to workspace is at least as large as get_workspace_size(). 
-+ Status init_workspace( -+ void *workspace, -+ cudaStream_t stream = nullptr) -+ { -+ uint8_t *ptr = static_cast(workspace); -+ -+ // Establish partials workspace -+ partials_workspace = nullptr; -+ size_t partials_workspace_bytes = get_partials_workspace_size(); -+ if (partials_workspace_bytes > 0) -+ { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ partials_workspace = ptr; -+ ptr += partials_workspace_bytes; -+ } -+ -+ // Establish barrier workspace -+ barrier_workspace = nullptr; -+ size_t barrier_workspace_bytes = get_barrier_workspace_size(); -+ if (barrier_workspace_bytes > 0) -+ { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ barrier_workspace = ptr; -+ ptr += barrier_workspace_bytes; -+ } -+ -+ // Zero-initialize barrier workspace -+ if (barrier_workspace) -+ { -+ size_t barrier_workspace_bytes = get_barrier_workspace_size(); -+ -+ CUTLASS_TRACE_HOST(" Initialize " << barrier_workspace_bytes << " barrier bytes"); -+ -+ cudaError_t result = cudaMemsetAsync( -+ barrier_workspace, -+ 0, -+ barrier_workspace_bytes, -+ stream); -+ -+ if (result != cudaSuccess) { -+ CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); -+ return Status::kErrorInternal; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ -+ /// Returns the GEMM volume in thread block tiles -+ cutlass::gemm::GemmCoord get_tiled_shape() const -+ { -+ return block_mapping.tiled_shape(); -+ } -+ -+ -+ /// Returns the total number of thread blocks to launch -+ int get_grid_blocks() const -+ { -+ dim3 grid_dims = get_grid_dims(); -+ return grid_dims.x * grid_dims.y * grid_dims.z; -+ } -+ -+ -+ /// Returns the grid extents in thread blocks to launch -+ dim3 get_grid_dims() const -+ { -+ return block_mapping.get_grid_dims(); -+ } -+ -+ -+ /// Lightweight update given a subset of arguments. Problem geometry is assumed -+ /// to remain the same. 
-+ void update(Arguments const &args) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversalStreamK::Params::update()"); -+ -+ // Update input/output pointers -+ ptr_A = const_cast(args.ptr_A); -+ ptr_B = const_cast(args.ptr_B); -+ ptr_C = const_cast(args.ptr_C); -+ ptr_D = args.ptr_D; -+ -+ batch_stride_A = args.batch_stride_A; -+ batch_stride_B = args.batch_stride_B; -+ batch_stride_C = args.batch_stride_C; -+ batch_stride_D = args.batch_stride_D; -+ -+ output_op = args.epilogue; -+ } -+ -+ }; -+ -+ /// Tile work descriptor -+ struct TileWorkDesc -+ { -+ /// The linear tile index -+ int tile_idx; -+ -+ /// The location of this tile (in threadblock-tile coordinates) in the output matrix -+ cutlass::gemm::GemmCoord tiled_coord; -+ -+ // The first global-scoped MAC-iteration this threadblock will perform for this tile -+ int iter_begin; -+ -+ // The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile -+ int k_begin; -+ -+ // The ending index (one-past) in the k-domain for MAC-iterations this threadblock will perform for this tile -+ int k_end; -+ -+ /// The number of remaining MAC-iterations this threadblock will perform for this tile -+ int k_iters_remaining; -+ -+ // Whether this block will perform the first iteration of this tile -+ CUTLASS_DEVICE -+ bool tile_started() -+ { -+ return (k_begin == 0); -+ } -+ -+ // Whether this block will perform the last iteration of this tile -+ CUTLASS_DEVICE -+ bool tile_finished(Params const ¶ms) -+ { -+ return (k_end == params.block_mapping.problem_size.k()); -+ } -+ }; -+ -+ -+ /// Shared memory storage structure -+ union SharedStorage -+ { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ /// GEMM problem parameters -+ Params const ¶ms; -+ -+ /// Shared storage reference -+ SharedStorage &shared_storage; -+ -+ /// ID within the threadblock -+ int thread_idx; -+ -+ /// ID of warp -+ int warp_idx; -+ -+ /// ID of each thread within a warp -+ int lane_idx; -+ -+ /// Threadblock scoped epilogue -+ Epilogue epilogue; -+ -+ -+public: -+ -+ // -+ // Host-only dispatch API -+ // -+ -+ /// Determines whether the GEMM problem size satisfies this kernel's -+ /// alignment requirements -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversalStreamk::can_implement()"); -+ -+ static int const kAlignmentA = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 
64 -+ : Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ bool isAMisaligned = false; -+ bool isBMisaligned = false; -+ bool isCMisaligned = false; -+ -+ if (platform::is_same::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } else if (platform::is_same::value) { -+ isAMisaligned = problem_size.m() % kAlignmentA; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } -+ -+ if (platform::is_same::value) { -+ isBMisaligned = problem_size.n() % kAlignmentB; -+ } else if (platform::is_same::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } -+ -+ if (platform::is_same::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } else if (platform::is_same::value) { -+ isCMisaligned = problem_size.m() % kAlignmentC; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } -+ -+ if (isAMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isBMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isCMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ CUTLASS_TRACE_HOST(" returning kSuccess"); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Determines whether the GEMM problem satisfies this kernel's -+ /// alignment requirements -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+protected: -+ -+ // -+ // Device-only utility methods -+ // -+ -+ /// Iterator for fetching tile fragments from A -+ CUTLASS_DEVICE -+ typename Mma::IteratorA init_iterator_A( -+ TileWorkDesc &tile_work, -+ GemmUniversalMode mode) -+ { -+ // The input A matrix -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ -+ // Update input pointers based on batched/array mode -+ if (mode == GemmUniversalMode::kBatched) { -+ ptr_A += tile_work.tiled_coord.k() * params.batch_stride_A; -+ } -+ if (mode == GemmUniversalMode::kArray) { -+ ptr_A = static_cast(params.ptr_A)[tile_work.tiled_coord.k()]; -+ } -+ -+ int m_begin = tile_work.tiled_coord.m() * Mma::Shape::kM; -+ int m_end = params.block_mapping.problem_size.m(); -+ return Mma::IteratorA( -+ params.params_A, -+ ptr_A, -+ { m_end, tile_work.k_end }, -+ threadIdx.x, -+ { m_begin, tile_work.k_begin }); -+ -+ } -+ -+ -+ /// Iterator for fetching tile fragments from B -+ CUTLASS_DEVICE -+ typename Mma::IteratorB init_iterator_B( -+ TileWorkDesc &tile_work, -+ GemmUniversalMode mode) -+ { -+ // The input B matrix -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ // Update input pointers based on batched/array mode -+ if (mode == GemmUniversalMode::kBatched) { -+ ptr_B += tile_work.tiled_coord.k() * params.batch_stride_B; -+ } -+ if (mode == GemmUniversalMode::kArray) { -+ ptr_B = static_cast(params.ptr_B)[tile_work.tiled_coord.k()]; -+ } -+ -+ int n_begin = tile_work.tiled_coord.n() * Mma::Shape::kN; -+ int n_end = params.block_mapping.problem_size.n(); -+ return Mma::IteratorB( -+ params.params_B, -+ ptr_B, -+ { tile_work.k_end, n_end }, -+ threadIdx.x, -+ { tile_work.k_begin, n_begin }); -+ } -+ -+ -+ CUTLASS_DEVICE -+ void init_dp_tile_work( -+ 
TileWorkDesc &tile_work, -+ int tile_idx) -+ { -+ // The linear tile index -+ tile_work.tile_idx = tile_idx; -+ -+ // The first global-scoped MAC-iteration this threadblock will perform for this tile -+ tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile(); -+ -+ // The number of MAC-iterations this threadblock will perform for this tile -+ tile_work.k_iters_remaining = params.block_mapping.iters_per_tile(); -+ -+ // The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile -+ tile_work.k_begin = 0; -+ -+ // The ending index (one-past) in the k-domain for MAC-iterations this threadblock will perform for this tile -+ tile_work.k_end = params.block_mapping.problem_size.k(); -+ -+ // The location of this tile (in threadblock-tile coordinates) in the output matrix -+ tile_work.tiled_coord = params.block_mapping.get_tile_offset(tile_work.tile_idx); -+ } -+ -+ -+ CUTLASS_DEVICE -+ void init_sk_tile_work( -+ TileWorkDesc &tile_work, -+ int tile_idx, -+ int block_iter_begin, -+ int block_iter_end) -+ { -+ // The linear tile index -+ tile_work.tile_idx = tile_idx; -+ -+ // The first global-scoped MAC-iteration for this tile -+ int tile_iter_begin = tile_idx * params.block_mapping.iters_per_tile(); -+ -+ // The first global-scoped MAC-iteration this threadblock will perform for this tile -+ tile_work.iter_begin = max(block_iter_begin, tile_iter_begin); -+ -+ // The first tile-scoped MAC-iteration this threadblock will perform for this tile -+ int k_iter_begin = tile_work.iter_begin - tile_iter_begin; -+ -+ // The last (one past) tile-scoped MAC-iteration this threadblock will perform for this tile -+ int k_iter_end = block_iter_end - tile_iter_begin; -+ -+ // The number of MAC-iterations this threadblock will perform for this tile -+ tile_work.k_iters_remaining = k_iter_end - k_iter_begin; -+ -+ // The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile -+ tile_work.k_begin = k_iter_begin * Mma::Shape::kK; -+ -+ // The ending index (one-past) in the k-domain for MAC-iterations this threadblock will perform for this tile -+ tile_work.k_end = min( -+ params.block_mapping.problem_size.k(), // extent of k domain -+ (k_iter_end * Mma::Shape::kK)); // extent of the threadblock's global iteration assignment -+ -+ // The location of this tile (in threadblock-tile coordinates) in the output matrix -+ tile_work.tiled_coord = params.block_mapping.get_tile_offset(tile_work.tile_idx); -+ } -+ -+ -+ /// Share accumulators with peers -+ CUTLASS_DEVICE -+ void share_accumulators( -+ AccumulatorTile const &accumulator_tile, -+ int block_idx, -+ int first_block_idx) -+ { -+ AccumulatorTile *accum_tile_workspace = reinterpret_cast(params.partials_workspace); -+ -+ int accum_tile_offset = first_block_idx * kThreadCount; -+ -+ if (block_idx == first_block_idx) -+ { -+ // First peer initializes the workspace partials -+ BlockStripedReduceT::store(accum_tile_workspace + accum_tile_offset, accumulator_tile, thread_idx); -+ } -+ else -+ { -+ // Subsequent peers atomically accumulate into the workspace partials -+ if (ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kAtomic) -+ { -+ // Non-deterministic reduction order: wait for the first peer to have initialized the partials before we add to them -+ Barrier::wait_lt(params.barrier_workspace, thread_idx, first_block_idx, 1); -+ } -+ else -+ { -+ // Turnstile reduction order: wait until the previous peer has written -+ int wait_count = block_idx - first_block_idx; 
-+ Barrier::wait_eq(params.barrier_workspace, thread_idx, first_block_idx, wait_count); -+ } -+ -+ // Perform reduction in workspace -+ BlockStripedReduceT::reduce(accum_tile_workspace + accum_tile_offset, accumulator_tile, thread_idx); -+ } -+ -+ // Signal our arrival -+ Barrier::arrive_inc(params.barrier_workspace, thread_idx, first_block_idx); -+ } -+ -+ -+ /// Acquire accumulators from peers -+ CUTLASS_DEVICE -+ void acquire_accumulators( -+ AccumulatorTile &accumulator_tile, -+ int block_idx, -+ int first_block_idx) -+ { -+ AccumulatorTile *accum_tile_workspace = reinterpret_cast(params.partials_workspace); -+ -+ // Wait for arrival -+ int num_carry_in = block_idx - first_block_idx; -+ Barrier::wait_eq_reset(params.barrier_workspace, thread_idx, first_block_idx, num_carry_in); -+ -+ // Load and add peer-partials accumulator tile to local accumulator tile -+ int accum_tile_offset = first_block_idx * kThreadCount; -+ BlockStripedReduceT::load_add(accumulator_tile, accum_tile_workspace + accum_tile_offset, thread_idx); -+ } -+ -+ -+ /// Perform epilogue computations and output -+ CUTLASS_DEVICE -+ void do_epilogue( -+ TileWorkDesc &tile_work, -+ AccumulatorTile &accumulator_tile) -+ { -+ ElementC *ptr_C = static_cast(params.ptr_C); -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ -+ // Update pointers for batched/array mode(s) -+ if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_C += tile_work.tiled_coord.k() * params.batch_stride_C; -+ ptr_D += tile_work.tiled_coord.k() * params.batch_stride_D; -+ } -+ if (params.mode == GemmUniversalMode::kArray) { -+ ptr_C = static_cast(params.ptr_C)[tile_work.tiled_coord.k()]; -+ ptr_D = static_cast(params.ptr_D)[tile_work.tiled_coord.k()]; -+ } -+ -+ // Location of this tile in item-coords -+ MatrixCoord threadblock_item_begin( -+ tile_work.tiled_coord.m() * Mma::Shape::kM, -+ tile_work.tiled_coord.n() * Mma::Shape::kN -+ ); -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ ptr_C, -+ params.block_mapping.problem_size.mn(), -+ thread_idx, -+ threadblock_item_begin); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.block_mapping.problem_size.mn(), -+ thread_idx, -+ threadblock_item_begin); -+ -+ // Execute the epilogue operator to update the destination tensor. 
-+ epilogue.unified( -+ EpilogueOutputOp(params.output_op), -+ iterator_D, -+ accumulator_tile, -+ iterator_C); -+ } -+ -+ -+ CUTLASS_DEVICE -+ void separate_reduction(int reduce_idx) -+ { -+ int peer_idx_begin, peer_idx_last, reduce_tile_idx, reduce_fragment_idx; -+ -+ // Reduce by sk-tile (every tile contributed to by one or more blocks) -+ reduce_tile_idx = reduce_idx / Epilogue::kAccumulatorFragments; -+ reduce_fragment_idx = reduce_idx % Epilogue::kAccumulatorFragments; -+ -+ int iter_tile_first = reduce_tile_idx * params.block_mapping.iters_per_tile(); -+ int iter_tile_last = iter_tile_first + params.block_mapping.iters_per_tile() - 1; -+ -+ peer_idx_begin = params.block_mapping.get_sk_block_idx(iter_tile_first); -+ peer_idx_last = params.block_mapping.get_sk_block_idx(iter_tile_last); -+ -+ // Wait for peers to complete -+ int peer_idx_end = peer_idx_last + 1; -+ int num_peers = peer_idx_end - peer_idx_begin; -+ Barrier::wait_eq_reset( -+ params.barrier_workspace, -+ thread_idx, -+ (reduce_tile_idx * Epilogue::kAccumulatorFragments) + reduce_fragment_idx, -+ num_peers); -+ -+ /// The location of this tile (in threadblock-tile coordinates) in the output matrix -+ GemmCoord tiled_coord = params.block_mapping.get_tile_offset(reduce_tile_idx); -+ -+ // Location of this tile in item-coords -+ MatrixCoord threadblock_item_begin( -+ tiled_coord.m() * Mma::Shape::kM, -+ tiled_coord.n() * Mma::Shape::kN -+ ); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C); -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ ptr_C, -+ params.block_mapping.problem_size.mn(), -+ thread_idx, -+ threadblock_item_begin); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.block_mapping.problem_size.mn(), -+ thread_idx, -+ threadblock_item_begin); -+ -+ // Execute the epilogue operator to update the destination tensor. 
-+ epilogue.reduce( -+ peer_idx_begin, -+ peer_idx_end, -+ reduce_fragment_idx, -+ params.partials_workspace, -+ EpilogueOutputOp(params.output_op), -+ iterator_D, -+ iterator_C); -+ } -+ -+ -+ CUTLASS_DEVICE -+ void process_tile( -+ TileWorkDesc tile_work, -+ int block_idx, -+ int dp_start_block_idx, -+ int block_iter_begin) -+ { -+ // Initialize input iterators -+ typename Mma::IteratorA iterator_A = init_iterator_A(tile_work, params.mode); -+ typename Mma::IteratorB iterator_B = init_iterator_B(tile_work, params.mode); -+ -+ // Initialize accumulators -+ AccumulatorTile accumulator_tile; -+ accumulator_tile.clear(); -+ -+ // Perform this tile's range of multiply-accumulate (MAC) iterations -+ Mma mma( -+ shared_storage.main_loop, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile); -+ -+ if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kAtomic) || -+ (params.block_mapping.reduction_blocks == 0) || -+ (block_idx >= dp_start_block_idx)) -+ { -+ // -+ // Cooperative SK peer reduction or DP block -+ // -+ -+ int first_block_idx = params.block_mapping.get_first_block_idx(tile_work.tile_idx, block_idx); -+ -+ if (!tile_work.tile_finished(params)) { -+ // Non "finishing" SK blocks must share their partial accumulator sums through global scratch workspace -+ share_accumulators(accumulator_tile, block_idx, first_block_idx); -+ } -+ else -+ { -+ // DP blocks and "finishing" SK blocks must perform epilogue operations and write the output tile -+ if (!tile_work.tile_started()) -+ { -+ // A "finishing" SK block must first aggregate its accumulator partial sums with those shared by peer threadblocks -+ acquire_accumulators(accumulator_tile, block_idx, first_block_idx); -+ } -+ -+ do_epilogue(tile_work, accumulator_tile); -+ } -+ } -+ else -+ { -+ // -+ // Separate peer reduction -+ // -+ -+ // Share accumulator partial sums with peer threadblock(s) through scratch workspace -+ epilogue.share(block_idx, params.partials_workspace, accumulator_tile, tile_work.tile_started()); -+ -+ // Signal arrival -+ Barrier::arrive_range_inc( -+ params.barrier_workspace, -+ thread_idx, -+ tile_work.tile_idx * Epilogue::kAccumulatorFragments, -+ Epilogue::kAccumulatorFragments); -+ } -+ } -+ -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void gemm() -+ { -+ // Initialize block's iteration range -+ int tile_idx, block_iter_begin, block_iters_remaining; -+ -+ int sk_padding_start_block_idx = params.block_mapping.sk_regions() * params.block_mapping.sk_blocks_per_region(); -+ int dp_start_block_idx = params.block_mapping.sk_waves * params.block_mapping.avail_sms; -+ int reduce_start_block_idx = dp_start_block_idx + params.block_mapping.dp_blocks; -+ int grid_padding_start_block_idx = reduce_start_block_idx + params.block_mapping.reduction_blocks; -+ -+ int block_idx = params.block_mapping.get_block_idx(); -+ if (block_idx < sk_padding_start_block_idx) -+ { -+ // This is a SK block -+ int block_iter_end; -+ params.block_mapping.get_iter_extents(block_idx, block_iter_begin, block_iter_end); -+ block_iters_remaining = block_iter_end - block_iter_begin; -+ -+ tile_idx = params.block_mapping.get_sk_tile_idx(block_iter_end - 1); -+ } -+ else if (block_idx < dp_start_block_idx) -+ { -+ // This is a filler block -+ return; -+ } -+ else if (block_idx < reduce_start_block_idx) -+ { -+ // This is a DP block -+ int dp_block_idx = block_idx - dp_start_block_idx; -+ int first_dp_tile = (params.block_mapping.cohort_raster) ? 
0 : params.block_mapping.sk_tiles; -+ -+ // Blocks in first DP wave get configured number of tiles -+ tile_idx = first_dp_tile + dp_block_idx; -+ int tile_allottment = params.block_mapping.dp_first_wave_tiles; -+ -+ // Blocks in subsequent DP waves get 1 tile -+ if (dp_block_idx >= params.block_mapping.avail_sms) { -+ tile_allottment = 1; -+ tile_idx += (params.block_mapping.dp_first_wave_tiles - 1) * params.block_mapping.avail_sms; -+ } -+ -+ block_iter_begin = 0; -+ block_iters_remaining = params.block_mapping.iters_per_tile() * tile_allottment; -+ } -+ -+ else if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kMixed) && -+ (block_idx < grid_padding_start_block_idx)) -+ { -+ // This is a reduction threadblock -+ int reduce_block_idx = block_idx - reduce_start_block_idx; -+ separate_reduction(reduce_block_idx); -+ return; -+ } -+ else -+ { -+ // This is a filler block -+ return; -+ } -+ -+ // Iteration-processing loop body -+ CUTLASS_PRAGMA_NO_UNROLL -+ while (true) -+ { -+ // Initialize tile work descriptor -+ TileWorkDesc tile_work; -+ if (block_idx >= dp_start_block_idx) -+ { -+ init_dp_tile_work(tile_work, tile_idx); -+ -+ // DP blocks exit if out of bounds or overlap an SK tile (only possible during cohort rasterization, where dp_first_wave_tiles must be 1) -+ if ((tile_idx < params.block_mapping.sk_tiles) || -+ (tile_work.tiled_coord.m() >= params.block_mapping.tiled_shape().m()) || -+ (tile_work.tiled_coord.n() >= params.block_mapping.tiled_shape().n())) -+ { -+ break; -+ } -+ } -+ else -+ { -+ init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining); -+ } -+ -+ // Perform this block's share of work for this tile -+ process_tile(tile_work, block_idx, dp_start_block_idx, block_iter_begin); -+ -+ // Update remaining work for this block -+ block_iters_remaining -= tile_work.k_iters_remaining; -+ if (block_iters_remaining == 0) { -+ // Done -+ break; -+ } -+ -+ // Continue to next tile -+ __syncthreads(); -+ -+ if (block_idx >= dp_start_block_idx) -+ { -+ // DP block consume their tiles at stride -+ tile_idx += params.block_mapping.avail_sms; -+ } -+ else -+ { -+ // SK blocks consume their tiles in backwards order -+ tile_idx--; -+ } -+ } -+ -+ } -+ -+ -+ /// Executes one DP-only GEMM -+ CUTLASS_DEVICE -+ void gemm_dp() -+ { -+ int block_idx = blockIdx.x; -+ int tile_idx = block_idx; -+ -+ TileWorkDesc tile_work; -+ tile_work.tile_idx = tile_idx; -+ tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile(); -+ tile_work.k_iters_remaining = params.block_mapping.iters_per_tile(); -+ tile_work.k_begin = 0; -+ tile_work.k_end = params.block_mapping.problem_size.k(); -+ tile_work.tiled_coord = params.block_mapping.get_tile_offset_row_major(tile_work.tile_idx); -+ -+ // Initialize input iterators -+ typename Mma::IteratorA iterator_A = init_iterator_A(tile_work, params.mode); -+ typename Mma::IteratorB iterator_B = init_iterator_B(tile_work, params.mode); -+ -+ // Initialize accumulators -+ AccumulatorTile accumulator_tile; -+ accumulator_tile.clear(); -+ -+ // Perform this tile's range of multiply-accumulate (MAC) iterations -+ Mma mma( -+ shared_storage.main_loop, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile); -+ -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ -+ // Location of this tile in item-coords -+ MatrixCoord threadblock_item_begin( -+ tile_work.tiled_coord.m() * Mma::Shape::kM, -+ tile_work.tiled_coord.n() * 
Mma::Shape::kN -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.block_mapping.problem_size.mn(), -+ thread_idx, -+ threadblock_item_begin); -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ EpilogueOutputOp(params.output_op), -+ iterator_D, -+ accumulator_tile); -+ } -+ -+ -+ -+public: -+ -+ // -+ // Device-only API -+ // -+ -+ // Factory invocation -+ CUTLASS_DEVICE -+ static void invoke( -+ Params const ¶ms, -+ SharedStorage &shared_storage) -+ { -+ GemmUniversalStreamk op(params, shared_storage); -+ op(); -+ } -+ -+ -+ // Constructor -+ CUTLASS_DEVICE -+ GemmUniversalStreamk( -+ Params const ¶ms, -+ SharedStorage &shared_storage) -+ : -+ params(params), -+ shared_storage(shared_storage), -+ thread_idx(threadIdx.x), -+ warp_idx(__shfl_sync(0xffffffff, threadIdx.x / 32, 0)), // broadcast the warp_id computed by lane 0 to ensure dependent code -+ lane_idx(threadIdx.x % 32), -+ epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx) -+ {} -+ -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()() -+ { -+#if (__CUDACC_VER_MAJOR__ > 10) -+ if (params.quick_dp) -+ { -+ // Simple (low-bootstrap latency) GEMM code path for data-parallel only. (kBatched and kArray -+ // modes will only be launched using a data-parallel configurations) -+ gemm_dp(); -+ return; -+ } -+#endif -+ -+ // Generic SK code path -+ gemm(); -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h -new file mode 100644 -index 0000000..8f67bd4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h -@@ -0,0 +1,1487 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Gemm kernel with fused reduction operation. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/layout/layout.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/gemm/kernel/params_universal_base.h" -+ -+#include "cutlass/trace.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ bool IsSingleSource = Epilogue_::kIsSingleSource -+> -+struct GemmWithFusedEpilogue; -+ -+// GemmWithFusedEpilogue with two sources -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! 
Threadblock swizzling function -+> -+struct GemmWithFusedEpilogue { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ using Operator = typename Mma::Operator; -+ -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = const_max( -+ 128 / sizeof_bits::value, -+ 128 / sizeof_bits::value -+ ); -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments : UniversalArgumentsBase{ -+ -+ // -+ // Data members -+ // -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ void const * ptr_A; -+ void const * ptr_B; -+ void const * ptr_C1; -+ void const * ptr_C2; -+ void * ptr_D; -+ -+ void * ptr_Vector; -+ void * ptr_Tensor; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C1; -+ int64_t batch_stride_C2; -+ int64_t batch_stride_Vector; -+ int64_t batch_stride_Tensor; -+ -+ typename LayoutA::Stride::Index lda; -+ typename LayoutB::Stride::Index ldb; -+ typename LayoutC::Stride::Index ldc1; -+ typename LayoutC::Stride::Index ldc2; -+ typename LayoutC::Stride::Index ldd; -+ typename LayoutC::Stride::Index ldr; -+ typename LayoutC::Stride::Index ldt; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C1(nullptr), -+ ptr_C2(nullptr), -+ ptr_D(nullptr) -+ {} -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C1, -+ void const * ptr_C2, -+ void * ptr_D, -+ void * ptr_Vector, -+ void * ptr_Tensor, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C1, -+ int64_t batch_stride_C2, -+ int64_t batch_stride_D, -+ int64_t batch_stride_Vector, -+ int64_t batch_stride_Tensor, -+ typename LayoutA::Stride::Index lda, -+ typename LayoutB::Stride::Index ldb, -+ typename LayoutC::Stride::Index ldc1, -+ typename LayoutC::Stride::Index ldc2, -+ typename LayoutC::Stride::Index ldd, -+ typename LayoutC::Stride::Index ldr, -+ typename LayoutC::Stride::Index ldt) -+ : -+ UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), -+ epilogue(epilogue), -+ ptr_A(ptr_A), 
ptr_B(ptr_B), ptr_C1(ptr_C1), ptr_C2(ptr_C2), ptr_D(ptr_D), -+ ptr_Vector(ptr_Vector), -+ ptr_Tensor(ptr_Tensor), -+ batch_stride_A(batch_stride_A), -+ batch_stride_B(batch_stride_B), -+ batch_stride_C1(batch_stride_C1), -+ batch_stride_C2(batch_stride_C2), -+ batch_stride_Vector(batch_stride_Vector), -+ batch_stride_Tensor(batch_stride_Tensor), -+ lda(lda), ldb(ldb), ldc1(ldc1), ldc2(ldc2), ldd(ldd), ldr(ldr), ldt(ldt) -+ { -+ CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Arguments::Arguments() - problem_size: " << problem_size); -+ CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); -+ CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); -+ CUTLASS_TRACE_HOST(" ldr: " << this->ldr); -+ CUTLASS_TRACE_HOST(" ldt: " << this->ldt); -+ } -+ -+ /// Returns arguments for the transposed problem -+ Arguments transposed_problem() const { -+ Arguments args(*this); -+ -+ std::swap(args.problem_size.m(), args.problem_size.n()); -+ std::swap(args.ptr_A, args.ptr_B); -+ std::swap(args.lda, args.ldb); -+ std::swap(args.batch_stride_A, args.batch_stride_B); -+ -+ return args; -+ } -+ }; -+ -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params : UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC> -+ { -+ using ParamsBase = UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC>; -+ -+ // -+ // Data members -+ // -+ -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Epilogue::OutputTileIterator::Params params_C1; -+ typename Epilogue::OutputTileIterator::Params params_C2; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ typename Epilogue::TensorTileIterator::Params params_Tensor; -+ typename EpilogueOutputOp::Params output_op; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_C1; -+ void * ptr_C2; -+ void * ptr_D; -+ -+ void * ptr_Vector; -+ typename LayoutC::Stride::Index ldr; -+ -+ void * ptr_Tensor; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C1; -+ int64_t batch_stride_C2; -+ int64_t batch_stride_Vector; -+ int64_t batch_stride_Tensor; -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Constructor -+ Params( -+ Arguments const &args, /// GEMM application arguments -+ int device_sms, /// Number of SMs on the device -+ int sm_occupancy) /// Kernel SM occupancy (in thread blocks) -+ : -+ ParamsBase(args, device_sms, sm_occupancy), -+ params_A(args.lda), -+ params_B(args.ldb), -+ params_C1(args.ldc1), -+ params_C2(args.ldc2), -+ params_D(args.ldd), -+ params_Tensor(args.ldt), -+ output_op(args.epilogue), -+ ptr_A(const_cast(args.ptr_A)), -+ ptr_B(const_cast(args.ptr_B)), -+ ptr_C1(const_cast(args.ptr_C1)), -+ ptr_C2(const_cast(args.ptr_C2)), -+ ptr_D(args.ptr_D), -+ ptr_Vector(args.ptr_Vector), -+ ldr(args.ldr), -+ ptr_Tensor(args.ptr_Tensor), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ batch_stride_C1(args.batch_stride_C1), -+ batch_stride_C2(args.batch_stride_C2), -+ batch_stride_Vector(args.batch_stride_Vector), -+ batch_stride_Tensor(args.batch_stride_Tensor) -+ { -+ CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::Params() - problem_size: " << problem_size); -+ CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); -+ CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); -+ 
CUTLASS_TRACE_HOST(" ldr: " << this->ldr); -+ CUTLASS_TRACE_HOST(" ldt: " << args.ldt); -+ } -+ -+ /// Lightweight update given a subset of arguments. Problem geometry is assumed -+ /// to remain the same. -+ CUTLASS_HOST_DEVICE -+ void update(Arguments const &args) -+ { -+ ptr_A = const_cast(args.ptr_A); -+ ptr_B = const_cast(args.ptr_B); -+ ptr_C1 = const_cast(args.ptr_C1); -+ ptr_C2 = const_cast(args.ptr_C2); -+ ptr_D = args.ptr_D; -+ -+ ptr_Vector = args.ptr_Vector; -+ ldr = args.ldr; -+ ptr_Tensor = args.ptr_Tensor; -+ -+ output_op = args.epilogue; -+ -+ CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::update()"); -+ CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); -+ CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); -+ CUTLASS_TRACE_HOST(" ldr: " << this->ldr); -+ } -+ }; -+ -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+public: -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) { -+ -+ CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::can_implement()"); -+ -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ bool isAMisaligned = false; -+ bool isBMisaligned = false; -+ bool isCMisaligned = false; -+ -+ if (platform::is_same::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } else if (platform::is_same::value) { -+ isAMisaligned = problem_size.m() % kAlignmentA; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } -+ -+ if (platform::is_same::value) { -+ isBMisaligned = problem_size.n() % kAlignmentB; -+ } else if (platform::is_same::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } -+ -+ if (platform::is_same::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } else if (platform::is_same::value) { -+ isCMisaligned = problem_size.m() % kAlignmentC; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } -+ -+ if (isAMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isBMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isCMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ CUTLASS_TRACE_HOST(" returning kSuccess"); -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+public: -+ -+ // -+ // Device-only API -+ // -+ -+ // Factory invocation -+ CUTLASS_DEVICE -+ static void invoke( -+ Params const ¶ms, -+ SharedStorage &shared_storage) -+ { -+ GemmWithFusedEpilogue op; -+ op(params, shared_storage); -+ } -+ -+ #define SPLIT_K_ENABLED 1 -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { 
-+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ -+ #if SPLIT_K_ENABLED -+ // -+ // Fetch pointers based on mode. -+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; -+ ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; -+ ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; -+ } -+ #endif -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. 
-+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_C1 = static_cast(params.ptr_C1); -+ ElementC *ptr_C2 = static_cast(params.ptr_C2); -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ typename Epilogue::ElementTensor *ptr_Tensor = static_cast(params.ptr_Tensor); -+ -+ // Define the reduction output pointer and move to the appropriate place -+ typename Epilogue::ElementVector *ptr_Vector = -+ static_cast(params.ptr_Vector); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // -+ // Special path when split-K not enabled. -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() == 1) { -+ -+ // Tile iterators loading from source tensors. -+ typename Epilogue::OutputTileIterator iterator_C1( -+ params.params_C1, -+ ptr_C1, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ typename Epilogue::OutputTileIterator iterator_C2( -+ params.params_C2, -+ ptr_C2, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Additional tensor to load from -+ typename Epilogue::TensorTileIterator tensor_iterator( -+ params.params_Tensor, -+ // Only the final block outputs Tensor -+ ptr_Tensor, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset); -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Move to appropriate location for this output tile -+ if (ptr_Vector) { -+ ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr; -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(output_op, -+ ptr_Vector, -+ iterator_D, -+ accumulators, -+ iterator_C1, -+ iterator_C2, -+ tensor_iterator, -+ params.problem_size.mn(), -+ threadblock_offset); -+ -+ return; -+ } -+ -+ // -+ // Slower path when split-K or batching is needed -+ // -+ -+ -+ #if SPLIT_K_ENABLED -+ // Construct the semaphore. 
-+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1; -+ if (ptr_C2) { -+ ptr_C2 += threadblock_tile_offset.k() * params.batch_stride_C2; -+ } -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ if (ptr_Tensor) { -+ ptr_Tensor += threadblock_tile_offset.k() * params.batch_stride_Tensor; -+ } -+ if (ptr_Vector) { -+ ptr_Vector += threadblock_tile_offset.k() * params.batch_stride_Vector; -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_C1 = static_cast(params.ptr_C1)[threadblock_tile_offset.k()]; -+ if (ptr_C2) { -+ ptr_C2 = static_cast(params.ptr_C2)[threadblock_tile_offset.k()]; -+ } -+ ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; -+ if (ptr_Tensor) { -+ ptr_Tensor = static_cast(params.ptr_Tensor)[threadblock_tile_offset.k()]; -+ } -+ if (ptr_Vector) { -+ ptr_Vector = static_cast(params.ptr_Vector)[threadblock_tile_offset.k()]; -+ } -+ } -+ #endif -+ -+ // Tile iterators loading from source tensors. -+ typename Epilogue::OutputTileIterator iterator_C1( -+ params.params_C1, -+ ptr_C1, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ typename Epilogue::OutputTileIterator iterator_C2( -+ params.params_C2, -+ ptr_C2, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Additional tensor to load from -+ typename Epilogue::TensorTileIterator tensor_iterator( -+ params.params_Tensor, -+ // Only the final block outputs Tensor -+ ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && -+ (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) -+ ? nullptr -+ : ptr_Tensor, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset); -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ #if SPLIT_K_ENABLED -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C1 = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ } -+ #endif -+ -+ // Move to appropriate location for this output tile -+ if (ptr_Vector) { -+ ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr; -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. 
-+ epilogue(output_op, -+ // Only the final block uses Vector -+ ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && -+ (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) -+ ? nullptr -+ : ptr_Vector, -+ iterator_D, -+ accumulators, -+ iterator_C1, -+ iterator_C2, -+ tensor_iterator, -+ params.problem_size.mn(), -+ threadblock_offset); -+ -+ // -+ // Release the semaphore -+ // -+ -+ #if SPLIT_K_ENABLED -+ if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ #endif -+ } -+}; -+ -+// GemmWithFusedEpilogue with one source -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function -+> -+struct GemmWithFusedEpilogue { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ using Operator = typename Mma::Operator; -+ -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = const_max( -+ 128 / sizeof_bits::value, -+ 128 / sizeof_bits::value -+ ); -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments : UniversalArgumentsBase -+ { -+ // -+ // Data members -+ // -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ void const * ptr_A; -+ void const * ptr_B; -+ void const * ptr_C; -+ void * ptr_D; -+ -+ void * ptr_Vector; -+ void * ptr_Tensor; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_Vector; -+ int64_t batch_stride_Tensor; -+ -+ typename LayoutA::Stride::Index lda; -+ typename LayoutB::Stride::Index ldb; -+ typename LayoutC::Stride::Index ldc; -+ typename LayoutC::Stride::Index ldd; -+ typename LayoutC::Stride::Index ldr; -+ typename LayoutC::Stride::Index ldt; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ 
ptr_C(nullptr), -+ ptr_D(nullptr) -+ {} -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ void * ptr_Vector, -+ void * ptr_Tensor, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ int64_t batch_stride_Vector, -+ int64_t batch_stride_Tensor, -+ typename LayoutA::Stride::Index lda, -+ typename LayoutB::Stride::Index ldb, -+ typename LayoutC::Stride::Index ldc, -+ typename LayoutC::Stride::Index ldd, -+ typename LayoutC::Stride::Index ldr, -+ typename LayoutC::Stride::Index ldt) -+ : -+ UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), -+ ptr_Vector(ptr_Vector), -+ ptr_Tensor(ptr_Tensor), -+ batch_stride_A(batch_stride_A), -+ batch_stride_B(batch_stride_B), -+ batch_stride_C(batch_stride_C), -+ batch_stride_Vector(batch_stride_Vector), -+ batch_stride_Tensor(batch_stride_Tensor), -+ lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), ldr(ldr), ldt(ldt) -+ { -+ CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Arguments::Arguments() - problem_size: " << problem_size); -+ CUTLASS_TRACE_HOST(" ptr_Vector: " << (void *)this->ptr_Vector); -+ CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); -+ CUTLASS_TRACE_HOST(" ldr: " << this->ldr); -+ CUTLASS_TRACE_HOST(" ldt: " << this->ldt); -+ } -+ -+ /// Returns arguments for the transposed problem -+ Arguments transposed_problem() const { -+ Arguments args(*this); -+ -+ std::swap(args.problem_size.m(), args.problem_size.n()); -+ std::swap(args.ptr_A, args.ptr_B); -+ std::swap(args.lda, args.ldb); -+ std::swap(args.batch_stride_A, args.batch_stride_B); -+ -+ return args; -+ } -+ }; -+ -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params : UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC> -+ { -+ using ParamsBase = UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC>; -+ -+ // -+ // Data members -+ // -+ -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ typename Epilogue::TensorTileIterator::Params params_Tensor; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_C; -+ void * ptr_D; -+ -+ void * ptr_Vector; -+ typename LayoutC::Stride::Index ldr; -+ -+ void * ptr_Tensor; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_Vector; -+ int64_t batch_stride_Tensor; -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Constructor -+ Params( -+ Arguments const &args, /// GEMM application arguments -+ int device_sms, /// Number of SMs on the device -+ int sm_occupancy) /// Kernel SM occupancy (in thread blocks) -+ : -+ ParamsBase(args, device_sms, sm_occupancy), -+ params_A(args.lda), -+ params_B(args.ldb), -+ params_C(args.ldc), -+ params_D(args.ldd), -+ params_Tensor(args.ldt), -+ output_op(args.epilogue), -+ ptr_A(const_cast(args.ptr_A)), -+ ptr_B(const_cast(args.ptr_B)), -+ 
ptr_C(const_cast(args.ptr_C)), -+ ptr_D(args.ptr_D), -+ ptr_Vector(args.ptr_Vector), -+ ldr(args.ldr), -+ ptr_Tensor(args.ptr_Tensor), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ batch_stride_C(args.batch_stride_C), -+ batch_stride_Vector(args.batch_stride_Vector), -+ batch_stride_Tensor(args.batch_stride_Tensor) -+ { -+ CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::Params() - problem_size: " << problem_size); -+ CUTLASS_TRACE_HOST(" ptr_Vector: " << (void *)this->ptr_Vector); -+ CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); -+ CUTLASS_TRACE_HOST(" ldr: " << this->ldr); -+ CUTLASS_TRACE_HOST(" ldt: " << args.ldt); -+ } -+ -+ /// Lightweight update given a subset of arguments. Problem geometry is assumed -+ /// to remain the same. -+ CUTLASS_HOST_DEVICE -+ void update(Arguments const &args) -+ { -+ ptr_A = const_cast(args.ptr_A); -+ ptr_B = const_cast(args.ptr_B); -+ ptr_C = const_cast(args.ptr_C); -+ ptr_D = args.ptr_D; -+ -+ ptr_Vector = args.ptr_Vector; -+ ldr = args.ldr; -+ ptr_Tensor = args.ptr_Tensor; -+ -+ output_op = args.epilogue; -+ -+ CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::update()"); -+ CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); -+ CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); -+ CUTLASS_TRACE_HOST(" ldr: " << this->ldr); -+ } -+ }; -+ -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+public: -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) { -+ -+ CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::can_implement()"); -+ -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ bool isAMisaligned = false; -+ bool isBMisaligned = false; -+ bool isCMisaligned = false; -+ -+ if (platform::is_same::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } else if (platform::is_same::value) { -+ isAMisaligned = problem_size.m() % kAlignmentA; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } -+ -+ if (platform::is_same::value) { -+ isBMisaligned = problem_size.n() % kAlignmentB; -+ } else if (platform::is_same::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } -+ -+ if (platform::is_same::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } else if (platform::is_same::value) { -+ isCMisaligned = problem_size.m() % kAlignmentC; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } -+ -+ if (isAMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isBMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isCMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ CUTLASS_TRACE_HOST(" 
returning kSuccess"); -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+public: -+ -+ // -+ // Device-only API -+ // -+ -+ // Factory invocation -+ CUTLASS_DEVICE -+ static void invoke( -+ Params const ¶ms, -+ SharedStorage &shared_storage) -+ { -+ GemmWithFusedEpilogue op; -+ op(params, shared_storage); -+ } -+ -+ #define SPLIT_K_ENABLED 1 -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ -+ #if SPLIT_K_ENABLED -+ // -+ // Fetch pointers based on mode. -+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; -+ ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; -+ ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; -+ } -+ #endif -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. 
-+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C); -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ typename Epilogue::ElementTensor *ptr_Tensor = static_cast(params.ptr_Tensor); -+ -+ // Define the reduction output pointer and move to the appropriate place -+ typename Epilogue::ElementVector *ptr_Vector = -+ static_cast(params.ptr_Vector); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // -+ // Special path when split-K not enabled. -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() == 1) { -+ -+ // Tile iterators loading from source tensors. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ ptr_C, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Additional tensor to load from -+ typename Epilogue::TensorTileIterator tensor_iterator( -+ params.params_Tensor, -+ // Only the final block outputs Tensor -+ ptr_Tensor, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset); -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Move to appropriate location for this output tile -+ if (ptr_Vector) { -+ ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr; -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(output_op, -+ ptr_Vector, -+ iterator_D, -+ accumulators, -+ iterator_C, -+ tensor_iterator, -+ params.problem_size.mn(), -+ threadblock_offset); -+ -+ return; -+ } -+ -+ // -+ // Slower path when split-K or batching is needed -+ // -+ -+ -+ #if SPLIT_K_ENABLED -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. 
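-+        // Serial split-K: the K partitions contributing to this output tile are serialized through a per-tile semaphore.
-+        // Partition k waits below until the semaphore reads k, folds its partial result into D (partitions after the first read the running partial back in place of C1), then releases the semaphore for partition k+1.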
-+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ if (ptr_Tensor) { -+ ptr_Tensor += threadblock_tile_offset.k() * params.batch_stride_Tensor; -+ } -+ if (ptr_Vector) { -+ ptr_Vector += threadblock_tile_offset.k() * params.batch_stride_Vector; -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; -+ ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; -+ if (ptr_Tensor) { -+ ptr_Tensor = static_cast(params.ptr_Tensor)[threadblock_tile_offset.k()]; -+ } -+ if (ptr_Vector) { -+ ptr_Vector = static_cast(params.ptr_Vector)[threadblock_tile_offset.k()]; -+ } -+ } -+ #endif -+ -+ // Tile iterators loading from source tensors. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ ptr_C, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Additional tensor to load from -+ typename Epilogue::TensorTileIterator tensor_iterator( -+ params.params_Tensor, -+ // Only the final block outputs Tensor -+ ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && -+ (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) -+ ? nullptr -+ : ptr_Tensor, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset); -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ #if SPLIT_K_ENABLED -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ } -+ #endif -+ -+ // Move to appropriate location for this output tile -+ if (ptr_Vector) { -+ ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr; -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(output_op, -+ // Only the final block uses Vector -+ ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && -+ (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) -+ ? nullptr -+ : ptr_Vector, -+ iterator_D, -+ accumulators, -+ iterator_C, -+ tensor_iterator, -+ params.problem_size.mn(), -+ threadblock_offset); -+ -+ // -+ // Release the semaphore -+ // -+ -+ #if SPLIT_K_ENABLED -+ if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. 
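-+        // For example, with grid_tiled_shape.k() == 3 the semaphore for this tile advances 0 -> 1 -> 2 and is reset to 0 here by the last partition.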
-+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ #endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_with_k_reduction.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_with_k_reduction.h -new file mode 100644 -index 0000000..8e00e18 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_with_k_reduction.h -@@ -0,0 +1,695 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/gemm/kernel/params_universal_base.h" -+ -+#include "cutlass/trace.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename EpilogueGemmKReduction_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! 
Threadblock swizzling function -+> -+struct GemmWithKReduction { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using EpilogueGemmKReduction = EpilogueGemmKReduction_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ using LayoutGemmKReduction = cutlass::layout::PitchLinear; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ using Operator = typename Mma::Operator; -+ -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); -+ -+ static int const kReduceKForA = Mma::kReduceKForA; -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments : UniversalArgumentsBase -+ { -+ // -+ // Data members -+ // -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ void const * ptr_A; -+ void const * ptr_B; -+ void const * ptr_C; -+ void * ptr_D; -+ void * ptr_gemm_k_reduction; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_gemm_k_reduction; -+ -+ typename LayoutA::Stride::Index lda; -+ typename LayoutB::Stride::Index ldb; -+ typename LayoutC::Stride::Index ldc; -+ typename LayoutC::Stride::Index ldd; -+ typename LayoutGemmKReduction::Stride::Index ld_gemm_k_reduction; -+ -+ // -+ // Methods -+ // -+ -+ Arguments() : -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ ptr_gemm_k_reduction(nullptr) -+ {} -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ void * ptr_gemm_k_reduction, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ int64_t batch_stride_gemm_k_reduction, -+ typename LayoutA::Stride::Index lda, -+ typename LayoutB::Stride::Index ldb, -+ typename LayoutC::Stride::Index ldc, -+ typename LayoutC::Stride::Index ldd, -+ typename LayoutGemmKReduction::Stride::Index ld_gemm_k_reduction) -+ : -+ UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), ptr_gemm_k_reduction(ptr_gemm_k_reduction), -+ 
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_gemm_k_reduction(batch_stride_gemm_k_reduction), -+ lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), ld_gemm_k_reduction(ld_gemm_k_reduction) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); -+ } -+ -+ /// Returns arguments for the transposed problem -+ Arguments transposed_problem() const { -+ Arguments args(*this); -+ -+ std::swap(args.problem_size.m(), args.problem_size.n()); -+ std::swap(args.ptr_A, args.ptr_B); -+ std::swap(args.lda, args.ldb); -+ std::swap(args.batch_stride_A, args.batch_stride_B); -+ -+ return args; -+ } -+ }; -+ -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params : UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC> -+ { -+ using ParamsBase = UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC>; -+ -+ // -+ // Data members -+ // -+ -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_C; -+ void * ptr_D; -+ void * ptr_gemm_k_reduction; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_gemm_k_reduction; -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Constructor -+ Params( -+ Arguments const &args, /// GEMM application arguments -+ int device_sms, /// Number of SMs on the device -+ int sm_occupancy) /// Kernel SM occupancy (in thread blocks) -+ : -+ ParamsBase(args, device_sms, sm_occupancy), -+ params_A(args.lda), -+ params_B(args.ldb), -+ params_C(args.ldc), -+ params_D(args.ldd), -+ output_op(args.epilogue), -+ ptr_A(const_cast(args.ptr_A)), -+ ptr_B(const_cast(args.ptr_B)), -+ ptr_C(const_cast(args.ptr_C)), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ batch_stride_C(args.batch_stride_C), -+ batch_stride_gemm_k_reduction(args.batch_stride_gemm_k_reduction), -+ ptr_D(args.ptr_D), -+ ptr_gemm_k_reduction(args.ptr_gemm_k_reduction) -+ {} -+ -+ /// Assign and initialize the specified workspace buffer. Assumes -+ /// the memory allocated to workspace is at least as large as get_workspace_size(). 
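-+    /// For kGemmSplitKParallel, the workspace holds the partial D tensors for all K partitions followed by the partial K-reduction vectors; ptr_D and ptr_gemm_k_reduction are redirected into it accordingly.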
-+ Status init_workspace( -+ void *workspace, -+ cudaStream_t stream = nullptr) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversal::Params::Params() - problem_size: " << this->problem_size); -+ -+ if (this->mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D = workspace; -+ ptr_gemm_k_reduction = static_cast(workspace) -+ + sizeof(ElementC) * size_t(this->batch_stride_D) * size_t(this->grid_tiled_shape.k()); -+ -+ return Status::kSuccess; -+ } -+ -+ return ParamsBase::init_workspace(workspace, stream); -+ } -+ -+ /// Returns the workspace size (in bytes) needed for this problem geometry -+ size_t get_workspace_size() const -+ { -+ size_t workspace_bytes = ParamsBase::get_workspace_size(); -+ -+ if (this->mode == GemmUniversalMode::kGemmSplitKParallel) -+ { -+ // Split-K parallel always requires a temporary workspace -+ workspace_bytes += -+ sizeof(ElementC) * -+ size_t(batch_stride_gemm_k_reduction) * -+ size_t(this->grid_tiled_shape.k()); -+ } -+ -+ return workspace_bytes; -+ } -+ -+ /// Lightweight update given a subset of arguments. Problem geometry is assumed -+ /// to remain the same. -+ void update(Arguments const &args) -+ { -+ ptr_A = const_cast(args.ptr_A); -+ ptr_B = const_cast(args.ptr_B); -+ ptr_C = const_cast(args.ptr_C); -+ ptr_D = args.ptr_D; -+ ptr_gemm_k_reduction = args.ptr_gemm_k_reduction; -+ -+ output_op = args.epilogue; -+ -+ CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ -+public: -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) { -+ -+ CUTLASS_TRACE_HOST("GemmUniversal::can_implement()"); -+ -+ static int const kAlignmentA = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 
64 -+ : Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ bool isAMisaligned = false; -+ bool isBMisaligned = false; -+ bool isCMisaligned = false; -+ -+ if (platform::is_same::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } else if (platform::is_same::value) { -+ isAMisaligned = problem_size.m() % kAlignmentA; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } -+ -+ if (platform::is_same::value) { -+ isBMisaligned = problem_size.n() % kAlignmentB; -+ } else if (platform::is_same::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } -+ -+ if (platform::is_same::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } else if (platform::is_same::value) { -+ isCMisaligned = problem_size.m() % kAlignmentC; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } -+ -+ if (isAMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for operand A"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isBMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for operand B"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isCMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for operand C"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ CUTLASS_TRACE_HOST(" returning kSuccess"); -+ -+ return Status::kSuccess; -+ } -+ -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+ -+public: -+ -+ // -+ // Device-only API -+ // -+ -+ // Factory invocation -+ CUTLASS_DEVICE -+ static void invoke( -+ Params const ¶ms, -+ SharedStorage &shared_storage) -+ { -+ GemmWithKReduction op; -+ op(params, shared_storage); -+ } -+ -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ // -+ // Fetch pointers based on mode. 
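-+    // kGemm / kGemmSplitKParallel restrict this threadblock to its [offset_k, problem_size_k) slice of K; kBatched advances A and B by their batch strides; kArray treats ptr_A and ptr_B as arrays of pointers indexed by the batch coordinate.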
-+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; -+ ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; -+ ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; -+ } -+ -+ __syncthreads(); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ typename Mma::FragmentReduction gemm_k_accumulators; -+ -+ gemm_k_accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ accumulators, -+ gemm_k_accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C); -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ ElementC *ptr_gemm_k_reduction = static_cast(params.ptr_gemm_k_reduction); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. 
-+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ ptr_gemm_k_reduction += threadblock_tile_offset.k() * params.batch_stride_gemm_k_reduction; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; -+ ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; -+ } -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ ptr_C, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ -+ if ((kReduceKForA && threadblock_tile_offset.n() == 0) -+ || (!kReduceKForA && threadblock_tile_offset.m() == 0)) { -+ -+ int warp_idx_mn = warp_idx % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN); -+ int warp_idx_m = warp_idx_mn % Mma::Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Mma::Base::WarpCount::kM; -+ -+ if ((kReduceKForA && warp_idx_n == 0) -+ || (!kReduceKForA && warp_idx_m == 0)) { -+ -+ int reduction_warp_idx = kReduceKForA ? warp_idx_m : warp_idx_n; -+ int reduction_threadblock_offset = kReduceKForA ? threadblock_tile_offset.m() : -+ threadblock_tile_offset.n(); -+ int reduction_vector_size = kReduceKForA ? params.problem_size.m() -+ : params.problem_size.n(); -+ EpilogueGemmKReduction epilogue_gemm_k_reduction(thread_idx, -+ reduction_warp_idx, -+ lane_idx, -+ reduction_threadblock_offset, -+ ptr_gemm_k_reduction); -+ epilogue_gemm_k_reduction( -+ reduction_vector_size, -+ gemm_k_accumulators, -+ params.mode == GemmUniversalMode::kGemm -+ && (params.grid_tiled_shape.k() > 1) -+ && (threadblock_tile_offset.k() > 0)); -+ } -+ } -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. 
-+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemv.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemv.h -new file mode 100644 -index 0000000..acde3d5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemv.h -@@ -0,0 +1,289 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/tensor_ref.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/matrix.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA_, -+ typename LayoutA_, -+ typename ElementB_, -+ typename ElementC_, -+ typename ElementAccumulator_, -+ typename EpilogueOutputOp_ -+> -+struct Gemv { -+public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using TensorRefA = TensorRef; -+ -+ static_assert(platform::is_same::value, -+ "Only supported for column-major A matrix"); -+ -+ using ElementB = ElementB_; -+ using ElementC = ElementC_; -+ -+ using ElementAccumulator = ElementAccumulator_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+ static int const kThreadCount = 32; -+ static int const kStages = 1; -+ -+ static int const kAlignmentA = 1; -+ static int const kAlignmentB = 1; -+ static int const kAlignmentC = 1; -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ MatrixCoord problem_size; -+ int32_t batch_count; -+ typename EpilogueOutputOp::Params output_op; -+ -+ TensorRefA ref_A; -+ -+ ElementB const *ptr_B; -+ ElementC const *ptr_C; -+ ElementC *ptr_D; -+ -+ int64_t inc_B; -+ int64_t inc_C; -+ int64_t inc_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): batch_count(0) { } -+ -+ Arguments( -+ MatrixCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params output_op, -+ TensorRefA ref_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t inc_B, -+ int64_t inc_C, -+ int64_t inc_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D -+ ): -+ problem_size(problem_size), -+ batch_count(batch_count), -+ output_op(output_op), -+ ref_A(ref_A), -+ ptr_B(static_cast(ptr_B)), -+ ptr_C(static_cast(ptr_C)), -+ ptr_D(static_cast(ptr_D)), -+ inc_B(inc_B), -+ inc_C(inc_C), -+ inc_D(inc_D), -+ batch_stride_A(batch_stride_A), -+ batch_stride_B(batch_stride_B), -+ batch_stride_C(batch_stride_C), -+ batch_stride_D(batch_stride_D) -+ { } -+ -+ Arguments( -+ MatrixCoord problem_size, -+ typename EpilogueOutputOp::Params output_op, -+ TensorRefA ref_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t inc_B, -+ int64_t inc_C, -+ int64_t inc_D -+ ): -+ Arguments( -+ problem_size, -+ 1, -+ output_op, -+ ref_A, -+ ptr_B, -+ ptr_C, -+ ptr_D, -+ inc_B, -+ inc_C, -+ inc_D, -+ 1, -+ 1, -+ 1, -+ 1) -+ { } -+ -+ Status update(Arguments const &args) { -+ output_op = args.output_op; -+ ref_A = ref_A; -+ ptr_B = args.ptr_B; -+ ptr_C = args.ptr_C; -+ ptr_D = args.ptr_D; -+ -+ return Status::kSuccess; -+ } -+ }; -+ -+ using Params = Arguments; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ Gemv() { } -+ -+ /// Determines whether kernel satisfies 
alignment -+ static Status can_implement(cutlass::MatrixCoord const & problem_size) { -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Loop over batch indices -+ for (int batch_idx = blockIdx.z; batch_idx < params.batch_count; batch_idx += gridDim.z) { -+ -+ int i = blockIdx.x * kThreadCount + threadIdx.x; -+ -+ ElementA const *ptr_A = params.ref_A.data() + i; -+ ElementB const *ptr_B = params.ptr_B; -+ -+ ptr_A += batch_idx * params.batch_stride_A; -+ ptr_B += batch_idx * params.batch_stride_B; -+ -+ ElementAccumulator accum = ElementAccumulator(); -+ -+ // Compute inner product -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int k = 0; k < params.problem_size.column(); ++k) { -+ -+ // Fetch from A -+ ElementA a = ElementA(); -+ if (i < params.problem_size.row()) { -+ a = *ptr_A; -+ } -+ ptr_A += params.ref_A.stride(0); -+ -+ // Fetch from B -+ ElementB b = *ptr_B; -+ ptr_B += params.inc_B; -+ -+ // Math -+ accum += ElementAccumulator(a) * ElementAccumulator(b); -+ } -+ -+ // -+ // Epilogue phase -+ // -+ -+ ElementC const *ptr_C = params.ptr_C + i * params.inc_C + batch_idx * params.batch_stride_C; -+ ElementC *ptr_D = params.ptr_D + i * params.inc_D + batch_idx * params.batch_stride_D; -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ typename EpilogueOutputOp::FragmentAccumulator accum_fragment; -+ typename EpilogueOutputOp::FragmentOutput source_fragment; -+ typename EpilogueOutputOp::FragmentOutput output_fragment; -+ -+ accum_fragment[0] = accum; -+ -+ if (i < params.problem_size.row()) { -+ if (output_op.is_source_needed()) { -+ source_fragment[0] = *ptr_C; -+ output_fragment = output_op(accum_fragment, source_fragment); -+ } -+ else { -+ output_fragment = output_op(accum_fragment); -+ } -+ -+ *ptr_D = output_fragment[0]; -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemv_batched_strided.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemv_batched_strided.h -new file mode 100755 -index 0000000..613a279 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemv_batched_strided.h -@@ -0,0 +1,244 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+namespace detail -+{ -+ template -+ struct GemvBatchedStridedEpilogueScaling -+ { -+ ElementAlphaBeta const & alpha; -+ ElementAlphaBeta const & beta; -+ -+ CUTLASS_DEVICE -+ GemvBatchedStridedEpilogueScaling(ElementAlphaBeta& alpha_, ElementAlphaBeta& beta_) : -+ alpha(alpha_), beta(beta_) -+ { } -+ -+ template -+ CUTLASS_DEVICE -+ void operator()(FragmentAccumulator& accumulators, -+ FragmentCD const& fragment_C, -+ FragmentCD& fragment_D) const -+ { -+ using AccType = typename FragmentAccumulator::value_type; -+ using CDType = typename FragmentCD::value_type; -+ -+ static_assert(FragmentCD::kElements == FragmentAccumulator::kElements, -+ "Mistmatch in fragment sizes."); -+ -+ for (int i = 0; i < FragmentCD::kElements; ++i) -+ { -+ if (BetaIsZero) -+ { -+ fragment_D[i] = CDType(accumulators[i] * AccType(alpha)); -+ } -+ else -+ { -+ fragment_D[i] = CDType(accumulators[i] * AccType(alpha) -+ + AccType(fragment_C[i]) * AccType(beta)); -+ } -+ } -+ } -+ }; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+CUTLASS_DEVICE void GemvBatchedStridedDevice( -+ cutlass::gemm::BatchedGemmCoord problem_size, -+ ElementAlphaBeta alpha, -+ ElementAlphaBeta beta, -+ typename GemvKernel::IteratorA::TensorRef ref_A, -+ typename GemvKernel::IteratorA::TensorRef::LongIndex lda, -+ typename GemvKernel::IteratorB::TensorRef ref_B, -+ typename GemvKernel::IteratorB::TensorRef::LongIndex ldb, -+ typename GemvKernel::IteratorCD::TensorRef ref_C, -+ typename GemvKernel::IteratorCD::TensorRef::LongIndex ldc, -+ typename GemvKernel::IteratorCD::TensorRef ref_D, -+ typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd) -+{ -+ using ThreadBlockGemv = typename GemvKernel::ThreadBlockGemv; -+ using ThreadBlockSwizzle = typename GemvKernel::ThreadBlockSwizzle; -+ using EpilogueScale = detail::GemvBatchedStridedEpilogueScaling; -+ -+ ThreadBlockSwizzle swizzler; -+ -+ // Compute initial location in logical coordinates -+ BatchedGemmCoord tb_offset = swizzler.get_tile_offset(); -+ int const batch_idx = swizzler.get_batch_idx(); -+ -+ // Offset to the batch -+ ref_A.add_pointer_offset(batch_idx*lda); -+ ref_B.add_pointer_offset(batch_idx*ldb); -+ 
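-+  // Note that lda/ldb (and ldc/ldd below) act here as per-batch pointer strides; the leading dimension of each operand is taken from its TensorRef layout when the iterator Params are constructed.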
-+ // Construct iterators to A and B operands -+ typename GemvKernel::IteratorA::Params params_A(ref_A.layout()); -+ typename GemvKernel::IteratorA iterator_A( -+ params_A, -+ ref_A.data(), -+ { 1, problem_size.k() }, -+ 0, -+ { 0, 0 }); -+ -+ typename GemvKernel::IteratorB::Params params_B(ref_B.layout()); -+ typename GemvKernel::IteratorB iterator_B( -+ params_B, -+ ref_B.data(), -+ { problem_size.k(), problem_size.n() }, -+ threadIdx.x, -+ { 0, tb_offset.n()*ThreadBlockGemv::Shape::kN }); -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ ThreadBlockGemv mma; -+ -+ typename ThreadBlockGemv::FragmentC accumulators; -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped gemv -+ mma(problem_size.mnk(), accumulators, iterator_A, iterator_B, accumulators); -+ -+ // -+ // Epilogue (TODO: Epiloge as template argument) -+ // -+ typename GemvKernel::FragmentCD fragment_CD; -+ -+ // Load C (skip if beta is zero) -+ if (!BetaIsZero) -+ { -+ tb_offset = swizzler.get_tile_offset(); -+ ref_C.add_pointer_offset(batch_idx*ldc); -+ typename GemvKernel::IteratorCD::Params params_C(ref_C.layout()); -+ typename GemvKernel::IteratorCD iterator_C( -+ params_C, -+ ref_C.data(), -+ { 1, problem_size.n() }, -+ threadIdx.x, -+ { 0, tb_offset.n()*ThreadBlockGemv::Shape::kN }); -+ iterator_C.load(fragment_CD); -+ } -+ -+ // Apply alpha/beta scaling -+ EpilogueScale epilogue_scale(alpha, beta); -+ epilogue_scale(accumulators, fragment_CD, fragment_CD); -+ -+ // Store D -+ tb_offset = swizzler.get_tile_offset(); -+ ref_D.add_pointer_offset(batch_idx*ldd); -+ typename GemvKernel::IteratorCD::Params params_D(ref_D.layout()); -+ typename GemvKernel::IteratorCD iterator_D( -+ params_D, -+ ref_D.data(), -+ { 1, problem_size.n() }, -+ threadIdx.x, -+ { 0, tb_offset.n()*ThreadBlockGemv::Shape::kN }); -+ iterator_D.store(fragment_CD); -+} -+ -+template -+__global__ void GemvBatchedStrided( -+ cutlass::gemm::BatchedGemmCoord problem_size, -+ ElementAlphaBeta alpha, -+ ElementAlphaBeta beta, -+ typename GemvKernel::IteratorA::TensorRef ref_A, -+ typename GemvKernel::IteratorA::TensorRef::LongIndex lda, -+ typename GemvKernel::IteratorB::TensorRef ref_B, -+ typename GemvKernel::IteratorB::TensorRef::LongIndex ldb, -+ typename GemvKernel::IteratorCD::TensorRef ref_C, -+ typename GemvKernel::IteratorCD::TensorRef::LongIndex ldc, -+ typename GemvKernel::IteratorCD::TensorRef ref_D, -+ typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd) -+{ -+ GemvBatchedStridedDevice( -+ problem_size, alpha, beta, ref_A, lda, ref_B, ldb, ref_C, ldc, ref_D, ldd -+ ); -+} -+ -+template -+__global__ void GemvBatchedStrided( -+ cutlass::gemm::BatchedGemmCoord problem_size, -+ ElementAlphaBeta alpha, -+ typename GemvKernel::IteratorA::TensorRef ref_A, -+ typename GemvKernel::IteratorA::TensorRef::LongIndex lda, -+ typename GemvKernel::IteratorB::TensorRef ref_B, -+ typename GemvKernel::IteratorB::TensorRef::LongIndex ldb, -+ typename GemvKernel::IteratorCD::TensorRef ref_D, -+ typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd) -+{ -+ GemvBatchedStridedDevice( -+ problem_size, alpha, ElementAlphaBeta(0), ref_A, lda, ref_B, ldb, ref_D, ldd, ref_D, ldd -+ ); -+} -+ -+template -+__global__ void GemvBatchedStrided( -+ cutlass::gemm::BatchedGemmCoord problem_size, -+ typename GemvKernel::IteratorA::TensorRef ref_A, -+ typename GemvKernel::IteratorA::TensorRef::LongIndex lda, -+ typename GemvKernel::IteratorB::TensorRef ref_B, -+ typename GemvKernel::IteratorB::TensorRef::LongIndex ldb, -+ typename 
GemvKernel::IteratorCD::TensorRef ref_D, -+ typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd) -+{ -+ using ElementAlphaBeta = typename GemvKernel::IteratorCD::Element; -+ GemvBatchedStridedDevice( -+ problem_size, ElementAlphaBeta(1), ElementAlphaBeta(0), ref_A, lda, ref_B, ldb, ref_D, ldd, ref_D, ldd -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/grouped_problem_visitor.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/grouped_problem_visitor.h -new file mode 100644 -index 0000000..d9f0249 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/grouped_problem_visitor.h -@@ -0,0 +1,464 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ \brief Base scheduler for grouped problems -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Enumerated type describing the type of scheduling to perform for the ProblemVisitor -+enum class GroupScheduleMode { -+ // Perform all scheduling on device -+ kDeviceOnly, -+ // Precompute on the host the full sequence of problems to access -+ kHostPrecompute -+}; -+ -+/// Visitor class to abstract away the algorithm for iterating over tiles -+template -+struct BaseGroupedProblemVisitor { -+ using ThreadblockShape = ThreadblockShape_; -+ -+ struct ProblemInfo { -+ static int32_t const kNoPrefetchEntry = -1; -+ int32_t problem_idx; -+ int32_t problem_start; -+ -+ CUTLASS_DEVICE -+ ProblemInfo() : problem_idx(kNoPrefetchEntry), problem_start(kNoPrefetchEntry) {} -+ -+ CUTLASS_DEVICE -+ ProblemInfo(int32_t problem_idx_, int32_t problem_start_) : -+ problem_idx(problem_idx_), problem_start(problem_start_) {} -+ }; -+ -+ struct Params { -+ cutlass::gemm::GemmCoord const *problem_sizes; -+ int32_t problem_count; -+ void const *workspace; -+ int32_t tile_count; -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ Params(): problem_sizes(nullptr), problem_count(0), workspace(nullptr), tile_count(0) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ Params( -+ cutlass::gemm::GemmCoord const *problem_sizes, -+ int32_t problem_count, -+ void const *workspace = nullptr, -+ int32_t tile_count = 0 -+ ): -+ problem_sizes(problem_sizes), -+ problem_count(problem_count), -+ workspace(workspace), -+ tile_count(tile_count) -+ {} -+ -+ }; -+ -+ Params const ¶ms; -+ int32_t tile_idx; -+ int32_t problem_tile_start; -+ int32_t problem_idx; -+ -+ // -+ // Methods -+ // -+ CUTLASS_DEVICE -+ BaseGroupedProblemVisitor( -+ Params const ¶ms_, -+ int32_t block_idx -+ ): -+ params(params_), -+ tile_idx(block_idx), -+ problem_tile_start(0), -+ problem_idx(0) -+ {} -+ -+ /// Get the grid shape -+ CUTLASS_HOST_DEVICE -+ static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { -+ return ProblemSizeHelper::grid_shape(problem); -+ } -+ -+ /// Gets the global tile index -+ CUTLASS_HOST_DEVICE -+ int32_t tile_index() const { -+ return tile_idx; -+ } -+ -+ /// Gets the index of the problem -+ CUTLASS_HOST_DEVICE -+ int32_t problem_index() const { -+ return problem_idx; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int32_t threadblock_idx() const { -+ return tile_idx - problem_tile_start; -+ } -+ -+ CUTLASS_DEVICE -+ void advance(int32_t grid_size) { -+ tile_idx += grid_size; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) { -+ ProblemSizeHelper::possibly_transpose_problem(problem); -+ } -+ -+ /// Returns the problem size for the current problem -+ CUTLASS_HOST_DEVICE -+ cutlass::gemm::GemmCoord problem_size() const { -+ GemmCoord problem = params.problem_sizes[problem_idx]; -+ ProblemSizeHelper::possibly_transpose_problem(problem); -+ return problem; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) { -+ return ProblemSizeHelper::tile_count(grid); -+ } -+ -+ static int32_t group_tile_count(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, 
int32_t problem_count) { -+ int32_t total_tiles = 0; -+ for (int32_t i = 0; i < problem_count; ++i) { -+ auto problem = host_problem_sizes_ptr[i]; -+ possibly_transpose_problem(problem); -+ auto grid = grid_shape(problem); -+ total_tiles += tile_count(grid); -+ } -+ -+ return total_tiles; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ProblemSizeHelper, -+ typename ThreadblockShape, -+ GroupScheduleMode GroupScheduleMode_, -+ int PrefetchTileCount, -+ int ThreadCount -+> -+struct GroupedProblemVisitor; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// ProblemVisitor that performs all scheduling on device -+// -+template -+struct GroupedProblemVisitor: public BaseGroupedProblemVisitor { -+ using Base = BaseGroupedProblemVisitor; -+ using Params = typename Base::Params; -+ static int const kThreadCount = ThreadCount; -+ static bool const kRequiresPrecomputation = false; -+ static int const kThreadsPerWarp = 32; -+ -+ struct SharedStorage {}; -+ -+ // Final tile of the problem loaded by this thread. Each thread will hold -+ // a separate value. -+ int32_t problem_ending_tile; -+ -+ SharedStorage &shared_storage; -+ -+ // -+ // Methods -+ // -+ CUTLASS_DEVICE -+ GroupedProblemVisitor( -+ Params const ¶ms_, -+ SharedStorage &shared_storage_, -+ int32_t block_idx -+ ): Base(params_, block_idx), -+ problem_ending_tile(0), -+ shared_storage(shared_storage_) -+ { -+ this->problem_idx = -1 * kThreadsPerWarp; -+ this->problem_tile_start = 0; -+ } -+ -+ CUTLASS_DEVICE -+ bool next_tile() { -+ // Check whether the tile to compute is within the range of the current problem. -+ int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp); -+ if (this->tile_idx < problem_tile_end) { -+ return true; -+ } -+ -+ // Check whether the tile to compute is within the current group of problems fetched by the warp. -+ // The last tile for this group is the final tile of the problem held by the final thread in the warp. -+ int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp-1); -+ -+ // Keep the starting problem for this group in `problem_idx`. This is done to reduce -+ // register pressure. The starting problem for this group is simply the first problem -+ // in the group most recently fetched by the warp. -+ int32_t &group_problem_start = this->problem_idx; -+ group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp; -+ -+ // Keep the starting tile for this group in `problem_tile_start`. This is done to reduce -+ // register pressure. -+ int32_t &group_tile_start = this->problem_tile_start; -+ -+ // Each thread in the warp processes a separate problem to advance until -+ // reaching a problem whose starting tile is less less than tile_idx. -+ while (group_tile_end <= this->tile_idx) { -+ group_problem_start += kThreadsPerWarp; -+ if (group_problem_start > this->params.problem_count) { -+ return false; -+ } -+ -+ // Since `group_tile_start` is a reference to `this->problem_tile_start`, this -+ // also sets `this->problem_tile_start`. The fact that `this->problem_tile_start` -+ // is also set here is used later in `next_tile`. -+ group_tile_start = group_tile_end; -+ -+ int lane_idx = threadIdx.x % kThreadsPerWarp; -+ int32_t lane_problem = group_problem_start + lane_idx; -+ -+ // Compute the number of tiles in the problem assigned to each thread. 
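The loop that follows assigns one problem to each lane of the warp and then converts the per-lane tile counts into per-lane ending tiles with a warp-wide inclusive prefix sum built on __shfl_up_sync. A minimal standalone CUDA sketch of just that prefix sum, assuming a full 32-lane warp and an all-active mask (an illustration, not code from the patch):

    #include <cstdint>

    __device__ int32_t warp_inclusive_prefix_sum(int32_t value, int lane_idx) {
      // After the loop, lane i holds the sum of the values of lanes 0..i.
      #pragma unroll
      for (int offset = 1; offset < 32; offset <<= 1) {
        int32_t neighbor = __shfl_up_sync(0xffffffffu, value, offset);
        if (lane_idx >= offset) {
          value += neighbor;
        }
      }
      return value;
    }

next_tile() then reads the last lane's result to obtain the group's total tile count and uses __ballot_sync/__popc over the per-lane ending tiles to find which problem owns the current tile.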
-+ problem_ending_tile = 0; -+ if (lane_problem < this->params.problem_count) { -+ cutlass::gemm::GemmCoord problem = this->params.problem_sizes[lane_problem]; -+ this->possibly_transpose_problem(problem); -+ cutlass::gemm::GemmCoord grid = this->grid_shape(problem); -+ problem_ending_tile = this->tile_count(grid); -+ } -+ -+ // Compute a warp-wide inclusive prefix sum to compute the ending tile index of -+ // each thread's problem. -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < kThreadsPerWarp; i <<= 1) { -+ int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i); -+ if (lane_idx >= i) { -+ problem_ending_tile += val; -+ } -+ } -+ -+ // The total tile count for this group is now in the final position of the prefix sum -+ int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp-1); -+ -+ problem_ending_tile += group_tile_start; -+ group_tile_end += tiles_in_group; -+ } -+ -+ // The next problem to process is the first one that does not have ending tile position -+ // that is greater than or equal to tile index. -+ int32_t problem_idx_in_group = -+ __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx)); -+ -+ this->problem_idx = group_problem_start + problem_idx_in_group; -+ -+ // The starting tile for this problem is the ending tile of the previous problem. In cases -+ // where `problem_idx_in_group` is the first problem in the group, we do not need to reset -+ // `problem_tile_start`, because it is set to the previous group's ending tile in the while -+ // loop above. -+ if (problem_idx_in_group > 0) { -+ this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1); -+ } -+ -+ return true; -+ } -+ -+ static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, -+ int32_t problem_count, -+ int32_t block_count) { -+ return 0; -+ } -+ -+ static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, -+ int32_t problem_count, -+ int32_t block_count, -+ void* host_workspace_ptr) {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Precomputes schedule on host and prefetches into shared memory -+// -+template -+struct GroupedProblemVisitor : public BaseGroupedProblemVisitor { -+ static_assert(PrefetchTileCount > 0, -+ "GroupedProblemVisitor with GroupScheduleMode `kHostPrecompute` currently requires prefetching to shared memory"); -+ -+ using Base = BaseGroupedProblemVisitor; -+ using Params = typename Base::Params; -+ using ProblemInfo = typename Base::ProblemInfo; -+ static bool const kRequiresPrecomputation = true; -+ -+ static int const kPrefetchTileCount = PrefetchTileCount; -+ static int const kThreadCount = ThreadCount; -+ -+ struct SharedStorage { -+ // Sequence of problem IDs and starting tiles to compute -+ cutlass::Array prefetched_problems; -+ }; -+ -+ int32_t tiles_computed; -+ int32_t iterations_per_block; -+ int32_t block_load_start; -+ SharedStorage &shared_storage; -+ ProblemInfo const *problem_info_ptr; -+ -+ // -+ // Methods -+ // -+ CUTLASS_DEVICE -+ GroupedProblemVisitor( -+ Params const ¶ms_, -+ SharedStorage &shared_storage_, -+ int32_t block_idx -+ ): Base(params_, block_idx), -+ tiles_computed(0), -+ shared_storage(shared_storage_), -+ problem_info_ptr(reinterpret_cast(params_.workspace)) -+ { -+ iterations_per_block = (params_.tile_count - 1 + gridDim.x) / gridDim.x; -+ block_load_start = iterations_per_block * block_idx; -+ // Start prefetching the first set of 
tiles to compute -+ prefetch_tiles(); -+ } -+ -+ CUTLASS_DEVICE -+ bool next_tile() { -+ if (this->tile_idx >= this->params.tile_count) { -+ return false; -+ } -+ -+ int32_t prefetch_idx = (tiles_computed % kPrefetchTileCount); -+ if (prefetch_idx == 0) { -+ // Ensure all previous stores to shared memory have been completed -+ __syncthreads(); -+ } -+ -+ auto problem_info = shared_storage.prefetched_problems[prefetch_idx]; -+ ++tiles_computed; -+ -+ if ((tiles_computed % kPrefetchTileCount) == 0) { -+ // Begin prefetching next set of tiles. Synchronize first to ensure that -+ // we don't overwrite the current buffer while someone else is using it. -+ __syncthreads(); -+ prefetch_tiles(); -+ } -+ -+ this->problem_idx = problem_info.problem_idx; -+ this->problem_tile_start = problem_info.problem_start; -+ -+ return true; -+ } -+ -+ static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, -+ int32_t problem_count, -+ int32_t block_count) { -+ int32_t total_tiles = Base::group_tile_count(host_problem_sizes_ptr, problem_count); -+ int32_t entries_per_block = ((total_tiles - 1 + block_count) / block_count); -+ return sizeof(ProblemInfo) * entries_per_block * block_count; -+ } -+#if !defined(__CUDACC_RTC__) -+ static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, -+ int32_t problem_count, -+ int32_t block_count, -+ void* host_workspace_ptr) { -+ ProblemInfo* host_problem_info_ptr = reinterpret_cast(host_workspace_ptr); -+ int32_t total_tiles = Base::group_tile_count(host_problem_sizes_ptr, problem_count); -+ int32_t entries_per_block = (total_tiles - 1 + block_count) / block_count; -+ -+ int tile = 0; -+ int start_tile = 0; -+ for (int p_idx = 0; p_idx < problem_count; ++p_idx) { -+ auto problem = host_problem_sizes_ptr[p_idx]; -+ Base::possibly_transpose_problem(problem); -+ auto grid = Base::grid_shape(problem); -+ int tiles = Base::tile_count(grid); -+ ProblemInfo problem_info(p_idx, start_tile); -+ for (int i = 0; i < tiles; ++i, ++tile) { -+ host_problem_info_ptr[(entries_per_block * (tile % block_count)) + (tile / block_count)] = problem_info; -+ } -+ start_tile += tiles; -+ } -+ } -+#endif -+private: -+ CUTLASS_DEVICE -+ void prefetch_tiles() { -+ // TODO: Consider changing to use async copies from global to shared mem -+ CUTLASS_PRAGMA_UNROLL -+ for (int32_t i = 0; i < kPrefetchTileCount; i += kThreadCount) { -+ int32_t offset = threadIdx.x + i; -+ if (offset < kPrefetchTileCount && (tiles_computed + offset < iterations_per_block)) { -+ shared_storage.prefetched_problems[offset] = problem_info_ptr[block_load_start + tiles_computed + offset]; -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/params_universal_base.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/params_universal_base.h -new file mode 100644 -index 0000000..453379d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/params_universal_base.h -@@ -0,0 +1,245 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Base functionality for common types of universal GEMM kernel parameters -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/trace.h" -+#include "cutlass/gemm/gemm.h" -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Argument structure -+struct UniversalArgumentsBase -+{ -+ // -+ // Data members -+ // -+ -+ GemmUniversalMode mode; -+ GemmCoord problem_size; -+ int batch_count; -+ -+ int64_t batch_stride_D; -+ -+ // -+ // Methods -+ // -+ -+ UniversalArgumentsBase() : -+ mode(GemmUniversalMode::kGemm), -+ batch_count(1), -+ batch_stride_D(0) -+ {} -+ -+ /// constructs an arguments structure -+ UniversalArgumentsBase( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ int64_t batch_stride_D) -+ : -+ mode(mode), -+ problem_size(problem_size), -+ batch_count(batch_count), -+ batch_stride_D(batch_stride_D) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); -+ } -+}; -+ -+ -+/// Parameters structure -+template < -+ typename ThreadblockSwizzle, -+ typename ThreadblockShape, -+ typename ElementA, -+ typename ElementB, -+ typename ElementC> -+struct UniversalParamsBase -+{ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ -+ GemmUniversalMode mode; -+ int batch_count; -+ int gemm_k_size; -+ -+ int64_t batch_stride_D; -+ -+ int *semaphore; -+ -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Default constructor -+ UniversalParamsBase() = default; -+ -+ -+ /// Constructor -+ UniversalParamsBase( -+ 
UniversalArgumentsBase const &args, /// GEMM application arguments -+ int device_sms, /// Number of SMs on the device -+ int sm_occupancy) /// Kernel SM occupancy (in thread blocks) -+ : -+ problem_size(args.problem_size), -+ mode(args.mode), -+ batch_count(args.batch_count), -+ batch_stride_D(args.batch_stride_D), -+ semaphore(nullptr) -+ { -+ ThreadblockSwizzle swizzle; -+ -+ // Get GEMM volume in thread block tiles -+ grid_tiled_shape = swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ swizzle_log_tile = swizzle.get_log_tile(grid_tiled_shape); -+ -+ // Determine extent of K-dimension assigned to each block -+ gemm_k_size = args.problem_size.k(); -+ -+ if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) -+ { -+ int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); -+ -+ gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); -+ if (gemm_k_size) { -+ grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); -+ } -+ } -+ } -+ -+ -+ /// Returns the workspace size (in bytes) needed for this problem geometry -+ size_t get_workspace_size() const -+ { -+ size_t workspace_bytes = 0; -+ if (mode == GemmUniversalMode::kGemmSplitKParallel) -+ { -+ // Split-K parallel always requires a temporary workspace -+ workspace_bytes = -+ sizeof(ElementC) * -+ size_t(batch_stride_D) * -+ size_t(grid_tiled_shape.k()); -+ } -+ else if (mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) -+ { -+ // Serial split-K only requires a temporary workspace if the number of partitions along the -+ // GEMM K dimension is greater than one. -+ workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); -+ } -+ -+ return workspace_bytes; -+ } -+ -+ -+ /// Assign and initialize the specified workspace buffer. Assumes -+ /// the memory allocated to workspace is at least as large as get_workspace_size(). 
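Before init_workspace() below, the two host-side computations above are easy to sanity-check with plain integers. A sketch with hypothetical helper names standing in for GemmCoord and the CUTLASS ceil_div/round_up utilities (an illustration, not code from the patch):

    #include <cstddef>
    #include <cstdint>

    constexpr int ceil_div_i(int a, int b) { return (a + b - 1) / b; }
    constexpr int round_up_i(int a, int b) { return ceil_div_i(a, b) * b; }

    // K extent handled per threadblock under split-K, as in the constructor
    // above. E.g. k = 4096, batch_count = 3, kAlignK = 8 (half-precision
    // operands) gives gemm_k_size = 1368 and ceil_div(4096, 1368) = 3 slices.
    constexpr int split_gemm_k_size(int k, int batch_count, int align_k) {
      return round_up_i(ceil_div_i(k, batch_count), align_k);
    }

    // Workspace bytes, as in get_workspace_size(): one full D-sized slice per
    // K partition for parallel split-K, or one int semaphore per output tile
    // when serial split-K uses more than one K partition.
    size_t workspace_bytes(bool parallel_split_k, size_t element_size,
                           int64_t batch_stride_D, int tiles_m, int tiles_n,
                           int tiles_k) {
      if (parallel_split_k) {
        return element_size * size_t(batch_stride_D) * size_t(tiles_k);
      }
      if (tiles_k > 1) {
        return sizeof(int) * size_t(tiles_m) * size_t(tiles_n);
      }
      return 0;
    }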
-+ Status init_workspace( -+ void *workspace, -+ cudaStream_t stream = nullptr) -+ { -+ semaphore = static_cast(workspace); -+ // Zero-initialize entire workspace -+ if (semaphore) -+ { -+ size_t workspace_bytes = get_workspace_size(); -+ -+ CUTLASS_TRACE_HOST(" Initialize " << workspace_bytes << " workspace bytes"); -+ -+ cudaError_t result = cudaMemsetAsync( -+ semaphore, -+ 0, -+ workspace_bytes, -+ stream); -+ -+ if (result != cudaSuccess) { -+ CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); -+ return Status::kErrorInternal; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ -+ /// Returns the GEMM volume in thread block tiles -+ GemmCoord get_tiled_shape() const -+ { -+ return grid_tiled_shape; -+ } -+ -+ -+ /// Returns the total number of thread blocks to launch -+ int get_grid_blocks() const -+ { -+ dim3 grid_dims = get_grid_dims(); -+ return grid_dims.x * grid_dims.y * grid_dims.z; -+ } -+ -+ -+ /// Returns the grid extents in thread blocks to launch -+ dim3 get_grid_dims() const -+ { -+ return ThreadblockSwizzle().get_grid_shape(grid_tiled_shape); -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_grouped.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_grouped.h -new file mode 100644 -index 0000000..1c840e7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_grouped.h -@@ -0,0 +1,704 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Grouped Rank2K kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/trace.h" -+#include "cutlass/gemm/kernel/rank_2k_transpose_operands.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma1_, ///! Threadblock-scoped matrix multiply-accumulate (A*B^T) -+ typename Mma2_, ///! Threadblock-scoped matrix multiply-accumulate (B*A^T) -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ ComplexTransform OriginalTransformA_, ///! Public-facing transformation on A -+ ComplexTransform OriginalTransformB_, ///! Public-facing transformation on B -+ FillMode FillModeC_, ///! Fill Mode for C (kLower or kUpper) -+ BlasMode BlasMode_, ///! Blas3 computation mode -+ GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform -+ bool Transposed = false -+> -+struct Rank2KGrouped { -+public: -+ -+ using Mma1 = Mma1_; -+ using Mma2 = Mma2_; -+ -+ static_assert(platform::is_same::value && -+ platform::is_same::value, -+ "Kernel-level grouped Rank2K requires that LayoutC be row major."); -+ -+ // Define generic Mma for usecases that use Kernel::Mma -+ using Mma = Mma1_; -+ -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; -+ static bool const kTransposed = Transposed; -+ -+ // Public-facing type definitions related to operand element type, layout, and complex conjugate -+ // operation. Must interact with the 'kTransposed' notion to reflect the original layout, -+ // fill mode, etc. passed in. -+ // -+ // Recall that a Rank2K operation performs (A x BT) + (B x AT) -+ // This is performed via: -+ // Mma1 = (A x BT) -+ // Mma2 = (B x AT) -+ // -+ // However, if C needs to be transposed, then this is changed to the following: -+ // Mma1 = (B x AT) -+ // Mma2 = (A x BT) -+ // -+ // The transformation above is achieved by swapping the Layouts/Elements/Transforms/etc. -+ // of A and B as they are passed into the instantiations of Mma1 and Mma2. -+ // -+ // Now, given access to only Mma1 and Mma2, as well as whether a transposition has occurred, -+ // we wish to retrieve the original Layouts/Elements/etc. for A and B that were passed into -+ // the device-level call. -+ // -+ // The logic to do this (which is made clearer by referencing the above instantiations) is as follows: -+ // LayoutA = kTransposed ? Mma2::LayoutA : Mma1::LayoutA -+ // LayoutB = kTransposed ? 
Mma1::LayoutA : Mma2::LayoutA -+ // -+ // We achieve this swapping by passing Mma1::*A and Mma2::*B to Rank2KMapArguments: -+ using MapArgumentsA = kernel::detail::Rank2KMapArguments< -+ typename Mma1::IteratorA::Element, -+ typename Mma1::IteratorA::Layout, -+ Mma1::kTransformA, -+ Mma1::IteratorA::AccessType::kElements, -+ typename Mma2::IteratorA::Element, -+ typename Mma2::IteratorA::Layout, -+ Mma2::kTransformA, -+ Mma2::IteratorA::AccessType::kElements, -+ typename Mma1::LayoutC, -+ FillModeC_, -+ kTransposed -+ >; -+ -+ using ElementA = typename MapArgumentsA::ElementA; -+ using LayoutA = typename MapArgumentsA::LayoutA; -+ static int const kAlignmentA = MapArgumentsA::kAlignmentA; -+ -+ using MapArgumentsB = kernel::detail::Rank2KMapArguments< -+ typename Mma2::IteratorA::Element, -+ typename Mma2::IteratorA::Layout, -+ Mma2::kTransformA, -+ Mma2::IteratorA::AccessType::kElements, -+ typename Mma1::IteratorA::Element, -+ typename Mma1::IteratorA::Layout, -+ Mma1::kTransformA, -+ Mma1::IteratorA::AccessType::kElements, -+ typename Mma2::LayoutC, -+ FillModeC_, -+ kTransposed -+ >; -+ -+ using ElementB = typename MapArgumentsB::ElementA; -+ using LayoutB = typename MapArgumentsB::LayoutA; -+ static int const kAlignmentB = MapArgumentsB::kAlignmentA; -+ -+ // Use the user-provided TransformA and TransformB, rather than those -+ // resulting from MapArguments, because Mma1 and Mma2 may have different -+ // complex transforms than those passed in by the user. -+ // (See kernel/rank_2k_complex.h for an example of this) -+ static cutlass::ComplexTransform const kTransformA = OriginalTransformA_; -+ static cutlass::ComplexTransform const kTransformB = OriginalTransformB_; -+ -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename MapArgumentsA::LayoutC; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ static FillMode const kFillModeC = MapArgumentsA::kFillModeC; -+ -+ // Common type definitions for Mma1 and Mma2 -+ using Operator = typename Mma1::Operator; -+ using OperatorClass = typename Mma1::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma1::Shape; -+ using WarpShape = typename Mma1::Operator::Shape; -+ using InstructionShape = typename Mma1::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma1::ArchTag; -+ -+ static int const kStages = Mma1::kStages; -+ static BlasMode const kBlasMode = BlasMode_; -+ -+private: -+ static FillMode const kInternalFillModeC = FillModeC_; -+ -+public: -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma1::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ using ProblemVisitor = Rank2KGroupedProblemVisitor< -+ ThreadblockShape, -+ kGroupScheduleMode, -+ kThreadCount, -+ kThreadCount, -+ kInternalFillModeC>; -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmUniversalMode mode; -+ GemmCoord *problem_sizes; -+ int problem_count; -+ int threadblock_count; -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ ElementA ** ptr_A; -+ ElementB ** ptr_B; -+ ElementC ** ptr_C; -+ ElementC ** ptr_D; -+ -+ typename LayoutA::Stride::LongIndex *lda; -+ typename LayoutB::Stride::LongIndex *ldb; -+ typename LayoutC::Stride::LongIndex *ldc; -+ typename LayoutC::Stride::LongIndex *ldd; -+ -+ // Only used by device-level operator -+ GemmCoord *host_problem_sizes; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ 
Arguments(): -+ mode(GemmUniversalMode::kGemm), -+ problem_count(0), -+ threadblock_count(0), -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ lda(nullptr), -+ ldb(nullptr), -+ ldc(nullptr), -+ ldd(nullptr), -+ host_problem_sizes(nullptr) -+ { -+ -+ } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord *problem_sizes, -+ int problem_count, -+ int threadblock_count, -+ typename EpilogueOutputOp::Params epilogue, -+ ElementA ** ptr_A, -+ ElementB ** ptr_B, -+ ElementC ** ptr_C, -+ ElementC ** ptr_D, -+ typename LayoutA::Stride::LongIndex *lda, -+ typename LayoutB::Stride::LongIndex *ldb, -+ typename LayoutC::Stride::LongIndex *ldc, -+ typename LayoutC::Stride::LongIndex *ldd, -+ GemmCoord *host_problem_sizes=nullptr -+ ): -+ mode(mode), -+ problem_sizes(problem_sizes), -+ problem_count(problem_count), -+ threadblock_count(threadblock_count), -+ epilogue(epilogue), -+ ptr_A(ptr_A), -+ ptr_B(ptr_B), -+ ptr_C(ptr_C), -+ ptr_D(ptr_D), -+ lda(lda), -+ ldb(ldb), -+ ldc(ldc), -+ ldd(ldd), -+ host_problem_sizes(host_problem_sizes) -+ { -+ -+ } -+ -+ }; -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params { -+ -+ typename ProblemVisitor::Params problem_visitor; -+ int threadblock_count; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ GemmUniversalMode mode; -+ int batch_count; -+ -+ ElementA ** ptr_A; -+ ElementB ** ptr_B; -+ ElementC ** ptr_C; -+ ElementC ** ptr_D; -+ -+ typename LayoutA::Stride::LongIndex *lda; -+ typename LayoutB::Stride::LongIndex *ldb; -+ typename LayoutC::Stride::LongIndex *ldc; -+ typename LayoutC::Stride::LongIndex *ldd; -+ -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ mode(cutlass::gemm::GemmUniversalMode::kGemm), -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ lda(nullptr), -+ ldb(nullptr), -+ ldc(nullptr), -+ ldd(nullptr) -+ { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args, void *workspace = nullptr, int tile_count = 0): -+ problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count), -+ threadblock_count(args.threadblock_count), -+ output_op(args.epilogue), -+ ptr_A(args.ptr_A), -+ ptr_B(args.ptr_B), -+ ptr_C(args.ptr_C), -+ ptr_D(args.ptr_D), -+ lda(args.lda), -+ ldb(args.ldb), -+ ldc(args.ldc), -+ ldd(args.ldd) -+ { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void update( -+ Arguments const &args, -+ void *workspace = nullptr, -+ int tile_count = 0) { -+ -+ problem_visitor = typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count); -+ threadblock_count = args.threadblock_count; -+ output_op = args.output_op; -+ ptr_A = args.ptr_A; -+ ptr_B = args.ptr_B; -+ ptr_C = args.ptr_C; -+ ptr_D = args.ptr_D; -+ } -+ }; -+ -+ /// Shared memory storage structure -+ struct SharedStorage { -+ union { -+ typename Mma1::SharedStorage mma1_main_loop; -+ typename Mma2::SharedStorage mma2_main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ } kernel; -+ -+ // ProblemVisitor shared storage can't be overlapped with others -+ typename ProblemVisitor::SharedStorage problem_visitor; -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ Rank2KGrouped() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) { -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return 
Status::kSuccess; -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // -+ // Problem visitor. -+ // -+ -+ ProblemVisitor problem_visitor( -+ params.problem_visitor, -+ shared_storage.problem_visitor, -+ blockIdx.x); -+ -+ // Outer 'persistent' loop to iterate over tiles -+ while (problem_visitor.next_tile()) { -+ -+ GemmCoord problem_size = problem_visitor.problem_size(); -+ int32_t problem_idx = problem_visitor.problem_index(); -+ int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); -+ -+ GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = problem_visitor.threadblock_offset(threadblock_idx); -+ -+ // -+ // Perform checks to determine whether the results of this threadblock will be needed. -+ // An example of an unneeded threadblock is one that is assigned to compute in the upper -+ // portion of a Rank2K kernel filled with mode kLower. -+ // -+ // TODO: Consider pushing these checks into ProblemVisitor to avoid spuriously -+ // returning from `next_tile()`. -+ // -+ -+ // Early exit if threadblock is out of range -+ if (grid_shape.m() <= threadblock_tile_offset.m() || -+ grid_shape.n() <= threadblock_tile_offset.n()) { -+ // Next tile -+ problem_visitor.advance(gridDim.x); -+ continue; -+ } -+ -+ // Skip this tile if Fill Mode is Lower and -+ // if the entire tile is above the main diagonal (bottom-left corner is at or above the diagonal) -+ if (kInternalFillModeC == cutlass::FillMode::kLower && -+ (threadblock_tile_offset.m() + 1) * Mma1::Shape::kM <= threadblock_tile_offset.n() * Mma1::Shape::kN) { -+ // Next tile -+ problem_visitor.advance(gridDim.x); -+ continue; -+ } -+ -+ // Skip this tile if Fill Mode is Upper and -+ // if the entire tile is below the main diagonal (top-right corner is at or below the diagonal) -+ if (kInternalFillModeC == cutlass::FillMode::kUpper && -+ threadblock_tile_offset.m() * Mma1::Shape::kM >= (threadblock_tile_offset.n() + 1) * Mma1::Shape::kN) { -+ // Next tile -+ problem_visitor.advance(gridDim.x); -+ continue; -+ } -+ -+ bool tile_on_diagonal = false; -+ // Mark tiles that are being crossed by the main diagonal -+ // (top-right and bottom-left corners are on either side of the diagonal) -+ if ((threadblock_tile_offset.m() + 1) * Mma1::Shape::kM > threadblock_tile_offset.n() * Mma1::Shape::kN -+ && threadblock_tile_offset.m() * Mma1::Shape::kM < (threadblock_tile_offset.n() + 1) * Mma1::Shape::kN) { -+ tile_on_diagonal = true; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = problem_size.k(); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < grid_shape.k()) { -+ problem_size_k = (threadblock_tile_offset.k() + 1) * problem_size.k(); -+ } -+ -+ offset_k = threadblock_tile_offset.k() * problem_size.k(); -+ } -+ -+ ElementA *ptr_A = reinterpret_cast((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); -+ typename LayoutA::Stride::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); -+ -+ ElementB *ptr_B = reinterpret_cast((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); -+ typename LayoutB::Stride::LongIndex ldm_B = (kTransposed ? 
params.lda[problem_idx] : params.ldb[problem_idx]); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_MxK{ -+ threadblock_tile_offset.m() * Mma1::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_KxN{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma1::Shape::kN -+ }; -+ -+ // Assume identity swizzle -+ MatrixCoord tb_offset( -+ threadblock_tile_offset.m() * Mma1::Shape::kM, -+ threadblock_tile_offset.n() * Mma1::Shape::kN -+ ); -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands for Mma1 -+ typename Mma1::IteratorA iterator_A( -+ Mma1::IteratorA::Params(ldm_A), -+ ptr_A, -+ {problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_MxK); -+ -+ typename Mma1::IteratorB iterator_BT( -+ Mma1::IteratorB::Params(ldm_B), -+ ptr_B, -+ {problem_size_k, problem_size.n()}, -+ thread_idx, -+ tb_offset_KxN); -+ -+ // Construct iterators to A and B operands for Mma2 -+ typename Mma2::IteratorA iterator_B( -+ Mma2::IteratorA::Params(ldm_B), -+ ptr_B, -+ {problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_MxK); -+ -+ typename Mma2::IteratorB iterator_AT( -+ Mma2::IteratorB::Params(ldm_A), -+ ptr_A, -+ {problem_size_k, problem_size.n()}, -+ thread_idx, -+ tb_offset_KxN); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply for Mma1 (A x BT) -+ Mma1 mma1(shared_storage.kernel.mma1_main_loop, thread_idx, warp_idx, lane_idx); -+ -+ // Construct thread-scoped matrix multiply for Mma2 (B x AT) -+ Mma2 mma2(shared_storage.kernel.mma2_main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma1::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma1::Shape::kK - 1) / Mma1::Shape::kK; -+ -+ // Wait for all threads to finish their epilogue phases from the previous tile. -+ __syncthreads(); -+ -+ // Compute threadblock-scoped matrix multiply-add (A x BT) -+ mma1( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_BT, -+ accumulators); -+ -+ // HER2K kernel needs Alpha to be complex and is conj(Alpha) is applied to the second HERK. -+ if (kBlasMode == BlasMode::kHermitian) { -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * grid_shape.m(); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C[problem_idx]); -+ ElementC *ptr_D = static_cast(params.ptr_D[problem_idx]); -+ -+ // If TB not on diagonal, FillMode doesn't apply. -+ FillMode kFillModeTB = tile_on_diagonal ? kInternalFillModeC : FillMode::kNone; -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ Epilogue::OutputTileIterator::Params(params.ldc[problem_idx]), -+ ptr_C, -+ problem_size.mn(), -+ thread_idx, -+ tb_offset, -+ kFillModeTB -+ ); -+ -+ // Tile iterator writing to destination tensor. 
-+ typename Epilogue::OutputTileIterator iterator_D( -+ Epilogue::OutputTileIterator::Params(params.ldd[problem_idx]), -+ ptr_D, -+ problem_size.mn(), -+ thread_idx, -+ tb_offset, -+ kFillModeTB -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.kernel.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ -+ __syncthreads(); -+ -+ accumulators.clear(); -+ } -+ -+ // Compute threadblock-scoped matrix multiply-add (B x AT) -+ mma2( -+ gemm_k_iterations, -+ accumulators, -+ iterator_B, -+ iterator_AT, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ /* Needed for HER2K where the second HERK is multiplied by conj(alpha) */ -+ typename EpilogueOutputOp::Params second_her2k_params(conj(params.output_op.alpha), 1); -+ EpilogueOutputOp output_op_her2k(second_her2k_params); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * grid_shape.m(); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C[problem_idx]); -+ -+ // HER2K kernel needs Alpha to be complex and is conj(Alpha) is applied to the second HERK. -+ if (kBlasMode == BlasMode::kHermitian) { -+ ptr_C = static_cast(params.ptr_D[problem_idx]); -+ } -+ -+ ElementC *ptr_D = static_cast(params.ptr_D[problem_idx]); -+ -+ // If TB not on diagonal, FillMode doesn't apply. -+ FillMode kFillModeTB = tile_on_diagonal ? kInternalFillModeC : FillMode::kNone; -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ Epilogue::OutputTileIterator::Params(params.ldc[problem_idx]), -+ ptr_C, -+ problem_size.mn(), -+ thread_idx, -+ tb_offset, -+ kFillModeTB -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ Epilogue::OutputTileIterator::Params(params.ldd[problem_idx]), -+ ptr_D, -+ problem_size.mn(), -+ thread_idx, -+ tb_offset, -+ kFillModeTB -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.kernel.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Execute the epilogue operator to update the destination tensor. -+ if (kBlasMode == BlasMode::kSymmetric) { -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ } else { -+ epilogue( -+ output_op_her2k, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ } -+ -+ // Next tile -+ problem_visitor.advance(gridDim.x); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h -new file mode 100644 -index 0000000..92cc2a7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h -@@ -0,0 +1,376 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Problem visitor for grouped Rank2K operations. -+ -+ This problem visitor is specialized for Rank2K operations, for which matrix C is upper/lower -+ triangular. Using a problem visitor designed for GEMMs for Rank2K problems is inefficient -+ because threadblocks will be frequently assigned to tiles that exit early (e.g., due to -+ being assigned to a tile in the upper-triangular portion of a lower-triangular problem). -+ This can lead to load imbalance among threadblocks, as the GEMM-based scheduler -+ assigns all threadblocks to nearly the same number of tiles, regardless of whether -+ those tiles exit early. -+ -+ Consider an example of a group of four Rank2Ks with matrix C consisting of a grid of 2x2 tiles. -+ Consider a grid of 8 threadblocks. The default GEMM scheduler will assign threadblocks to -+ tiles in the following order: -+ Rank2K 0 Rank2K 1 Rank2K 2 Rank2K 3 -+ 0 1 4 5 0 1 4 5 -+ 2 3 6 7 2 3 6 7 -+ Assuming that the problems are lower triangular, blocks 1 and 5 are continuously assigned -+ to inactive tiles. -+ -+ This problem visitor aims to assign threadblocks to only those tiles which are in the -+ upper/lower triangular portion of a given problem. Using the example above, the resulting -+ assignment would be: -+ Rank2K 0 Rank2K 1 Rank2K 2 Rank2K 3 -+ 0 - 3 - 6 - 1 - -+ 1 2 4 5 7 0 2 3 -+ -+ Achieving the schedule above requires a mapping from threadblock ID to tile coordinates (i, j). -+ We will illustrate this by mapping on a lower-triangular matrix with a 3x3 grid. We first -+ calculate row and column indices assuming one-indexed rows, tiles, and threadblock IDs, and -+ then subtract one to convert to zero-indexed. 
-+ Col 1 Col 2 Col 3 -+ ---------------------- -+ Row 1 | 1 - - -+ Row 2 | 2 3 - -+ Row 3 | 4 5 6 -+ -+ We next outline this mapping, borrowing from: https://stackoverflow.com/a/40954159 -+ -+ Calculating row i given threadblock ID t -+ ---------------------------------------- -+ For a given row i, all threadblock IDs t in that row satisfy the following: -+ t <= 1 + 2 + 3 + ... + (i-1) + i -+ -+ The closed-form equation for the right-hand side is: i(i+1)/2. -+ Using this, we can solve for i given t: -+ t <= i(i+1)/2 -+ 2t <= i^2 + i -+ 2t <= i^2 + i + 0.25 - 0.25 -+ 2t + 0.25 <= i^2 + i + 0.25 -+ 2t + 0.25 <= (i + 0.5)^2 -+ sqrt(2t + 0.25) - 0.5 <= i -+ -+ To account for fractional values, we set: -+ i = ceil(sqrt(2t + 0.25) - 0.5) -+ -+ To turn this into a zero-indexed row and work with zero-indexed t, we perform: -+ i = ceil(sqrt(2(t+1) + 0.25) - 0.5) - 1 -+ = ceil(sqrt(2t + 2.25) - 0.5) - 1 -+ -+ Calculating column j given threadblock ID t and row i -+ ----------------------------------------------------- -+ For a given row i, all threadblock IDs t in that row also satisfy the following: -+ t > 1 + 2 + 3 + ... + (i-2) + (i-1) -+ --> t > i(i-1)/2 -+ -+ Threadblock IDs within a given row are sequential, so the one-indexed column ID -+ for one-indexed threadblock ID t and row i is: -+ j = t - (i(i-1)/2) -+ -+ The zero-indexed version becomes: -+ j = (t+1) - (i(i+1)/2) -1 -+ = t - (i(i+1)/2) -+ -+ Accounting for non-square grids -+ ------------------------------- -+ Though the overall output problem size for Rank2K problems is guranteed to be square, the -+ grids used in computing may not be square due to using non-square threadblock shapes. For -+ example, a threadblock shape of 64x32 operating on a problem of output size 128x128 would -+ result in a grid of 2x4 tiles. -+ -+ This case can be handled by noting that the output resembles a square grid of 2x2 "macro tiles" -+ each of which contains 2 "true tiles." We can thus first map a threadblock ID to its "macro tile" -+ using the equations above, and then map it to the "true tile" within its "macro tile." In the example -+ of a 2x4 grid, this mapping would look as follows: -+ "Macro grid" "True grid" -+ {0, 1} - 0 1 - - -+ {2, 3} {4, 5} 2 3 4 5 -+ -+ A zero-indexed threadblock ID t is mapped to its "macro tile ID" t_macro as: -+ t_macro = t // r -+ Where r is the ratio of the maximum dimension of the grid to the minimum dimension of the grid -+ (i.e., r = 4 / 2 = 2 in the previous example). -+ -+ One uses t_macro and the calculations above to find the row and column in the square matrix to -+ obtain i_macro and j_macro (zero-indexed). The mapping from (i_macro, j_macro) --> (i, j) -+ is simply the following: -+ if (ThreadblockShape::M > ThreadblockShape::N): -+ r = ThreadblockShape::M / ThreadblockShape::N -+ i = i_macro -+ j = (j_macro * r) + (t % r) -+ elif (ThreadblockShape::M < ThreadblockShape::N): -+ r = ThreadblockShape::N / ThreadblockShape::M -+ i = (i_macro * r) + (t % r) -+ j = j_macro -+ else: -+ i = i_macro -+ j = j_macro -+ -+ Handling cases with grid dimensions that aren't multiples of eachother -+ ---------------------------------------------------------------------- -+ Even though threadblock shapes M and N are typically multiples of one another, the grid -+ for a given problem may not have dimensions of the same ratio as that of the threadblock. -+ For example, a problem of size 132x132 using a threadblock of shape 64x32 will result -+ in a grid of 3x5 tiles. 
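Putting the closed-form row/column equations and the macro-tile expansion together, a host-side numeric check of the mapping might look as follows, assuming a zero-indexed threadblock ID t, fill mode kLower, and ThreadblockShape::kM an integer multiple of kN so that r = kM / kN (an illustration of the derivation above, not code from the patch):

    #include <cmath>
    #include <cstdint>
    #include <utility>

    // Maps zero-indexed threadblock ID t to a (row, column) tile coordinate
    // for a lower-triangular problem with skew ratio r.
    std::pair<int32_t, int32_t> lower_tile_coord(int32_t t, int32_t r) {
      int32_t t_macro = t / r;
      int32_t i_macro =
          int32_t(std::ceil(std::sqrt(2.0 * t_macro + 2.25) - 0.5)) - 1;
      int32_t j_macro = t_macro - (i_macro * (i_macro + 1)) / 2;
      // kM > kN keeps the row and expands the column by r; the kM < kN case
      // mirrors this on the row instead, and kUpper swaps i_macro and j_macro
      // before the expansion.
      return { i_macro, j_macro * r + (t % r) };
    }
    // With r = 2 this reproduces the "true grid" above:
    // t = 0..5  ->  (0,0) (0,1) (1,0) (1,1) (1,2) (1,3).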
In this case, there is not an integer number of "true tiles" -+ per "macro tile." -+ -+ When this scenario arises, we simply pad the larger dimension of the grid such that -+ there are an integer number of "true tiles" per "macro tile." Thus, the 3x5 grid in -+ the example above will be treated as a 3x6 grid. Row and column positions for each -+ tile are calculated as above. Any threadblocks that map to tiles that are outside the -+ problem range or upper/lower triangular portion (e.g., (2, 5)) will exit early from -+ this problem and may proceed to the next problem in the group. -+ -+ Handling upper-triangular matrices -+ ---------------------------------- -+ The only modification needed for upper-triangular matrices is to swap i_macro and j_macro -+ in the calculations above. -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+ -+#include "cutlass/gemm/kernel/grouped_problem_visitor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+namespace detail { -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Helpers for calculating offsets for Rank2K problem visitor. These helpers specifically pertain -+// to the conversion from "macro tiles" to "true tiles" in the description above. -+// -+template < -+ typename ThreadblockShape, -+ typename Enable = void -+> -+struct Rank2KGroupedProblemVisitorOffsetHelper; -+ -+// Partial specialization for the case where threadblock shape M > threadblock shape N -+template < -+ typename ThreadblockShape -+> -+struct Rank2KGroupedProblemVisitorOffsetHelper< -+ ThreadblockShape, -+ typename platform::enable_if< (ThreadblockShape::kM > ThreadblockShape::kN) >::type -+> { -+ static_assert(ThreadblockShape::kM % ThreadblockShape::kN == 0, -+ "Rank2KGroupedProblemVisitor with threadblock shape M > threadblock shape N " -+ "requires that threadblock shape M be a multiple of threadblock shape N."); -+ -+ static int32_t const kThreadblockSkewRatio = ThreadblockShape::kM / ThreadblockShape::kN; -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t min_dim(cutlass::gemm::GemmCoord grid) { -+ return grid.m(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t macro_row_to_row(int32_t row, int32_t threadblock_id) { -+ return row; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t macro_col_to_col(int32_t col, int32_t threadblock_id) { -+ return (col * kThreadblockSkewRatio) + (threadblock_id % kThreadblockSkewRatio); -+ } -+}; -+ -+// Partial specialization for the case where threadblock shape M < threadblock shape N -+template < -+ typename ThreadblockShape -+> -+struct Rank2KGroupedProblemVisitorOffsetHelper< -+ ThreadblockShape, -+ typename platform::enable_if< (ThreadblockShape::kM < ThreadblockShape::kN) >::type -+> { -+ -+ static_assert(ThreadblockShape::kN % ThreadblockShape::kM == 0, -+ "Rank2KGroupedProblemVisitor with threadblock shape M < threadblock shape N " -+ "requires that threadblock shape N be a multiple of threadblock shape M."); -+ -+ static int32_t const kThreadblockSkewRatio = ThreadblockShape::kN / ThreadblockShape::kM; -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t min_dim(cutlass::gemm::GemmCoord grid) { -+ return grid.n(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t macro_row_to_row(int32_t row, int32_t threadblock_id) { -+ return (row * kThreadblockSkewRatio) + (threadblock_id % kThreadblockSkewRatio); 
-+ } -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t macro_col_to_col(int32_t col, int32_t threadblock_id) { -+ return col; -+ } -+}; -+ -+// Partial specialization for the case where threadblock shape M == threadblock shape N -+// In this case, macro tiles are equivalent to true tiles, so the conversions are -+// identity functions. -+template < -+ typename ThreadblockShape -+> -+struct Rank2KGroupedProblemVisitorOffsetHelper< -+ ThreadblockShape, -+ typename platform::enable_if< (ThreadblockShape::kM == ThreadblockShape::kN) >::type -+> { -+ -+ static int32_t const kThreadblockSkewRatio = 1; -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t min_dim(cutlass::gemm::GemmCoord grid) { -+ return grid.m(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t macro_row_to_row(int32_t row, int32_t threadblock_id) { -+ return row; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t macro_col_to_col(int32_t col, int32_t threadblock_id) { -+ return col; -+ } -+}; -+ -+// Helper for correctly representing problem sizes in grouped kernels -+template -+struct Rank2KGroupedProblemSizeHelper { -+ using OffsetHelper = Rank2KGroupedProblemVisitorOffsetHelper; -+ -+ CUTLASS_HOST_DEVICE -+ static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { -+ return cutlass::gemm::GemmCoord( -+ ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), -+ ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), -+ 1); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) { -+ // Return the number of tiles at or below the diagonal (or at and above -+ // for mode kUpper). We do this by first calculating this value assuming -+ // we have a square matrix of tiles of size `dim x dim` where `dim` is the -+ // minimum among {grid.m(), grid.n()}. We then multiply the resulting value -+ // by OffsetHelper::kThreadblockSkewRatio to account for cases in which there -+ // are more tiles in one dimension than the other. -+ int32_t dim = OffsetHelper::min_dim(grid); -+ int32_t tiles_on_diagonal = dim; -+ int32_t tiles_below_diagonal = ((dim * (dim - 1)) / 2); -+ return (tiles_on_diagonal + tiles_below_diagonal) * OffsetHelper::kThreadblockSkewRatio; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) {} -+}; -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Default problem visitor for fill modes kUpper and kLower. 
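As a quick check of tile_count() above: a lower- or upper-triangular problem schedules the on-diagonal macro tiles plus the strictly off-diagonal half of the square grid, expanded by the skew ratio. A small sketch with two numeric checks taken from the examples in this file's comments (an illustration, not code from the patch):

    #include <cstdint>

    constexpr int32_t triangular_tile_count(int32_t dim, int32_t skew_ratio) {
      // dim on-diagonal macro tiles plus dim*(dim-1)/2 below (or above) them.
      return (dim + (dim * (dim - 1)) / 2) * skew_ratio;
    }

    static_assert(triangular_tile_count(3, 1) == 6, "3x3 grid schedules 6 of 9 tiles");
    static_assert(triangular_tile_count(2, 2) == 6, "2x4 grid from the 64x32 example");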
-+// -+template -+struct Rank2KGroupedProblemVisitor : public GroupedProblemVisitor< -+ detail::Rank2KGroupedProblemSizeHelper, -+ ThreadblockShape, -+ GroupScheduleMode_, -+ PrefetchTileCount, -+ ThreadCount> { -+ -+ static cutlass::FillMode const kFillModeC = FillModeC; -+ -+ static_assert(kFillModeC == cutlass::FillMode::kLower || kFillModeC == cutlass::FillMode::kUpper, -+ "Default Rank2KGroupedProblemVisitor requires fill mode of kLower or kUpper."); -+ -+ using ProblemSizeHelper = detail::Rank2KGroupedProblemSizeHelper; -+ using Base = GroupedProblemVisitor; -+ using OffsetHelper = typename ProblemSizeHelper::OffsetHelper; -+ using Params = typename Base::Params; -+ using SharedStorage = typename Base::SharedStorage; -+ -+ // -+ // Methods -+ // -+ CUTLASS_DEVICE -+ Rank2KGroupedProblemVisitor( -+ Params const ¶ms_, -+ SharedStorage &shared_storage_, -+ int32_t block_idx -+ ): Base(params_, shared_storage_, block_idx) -+ {} -+ -+ CUTLASS_DEVICE -+ cutlass::gemm::GemmCoord threadblock_offset(int32_t threadblock_id) const { -+ int32_t macro_id = threadblock_id / OffsetHelper::kThreadblockSkewRatio; -+ int32_t macro_row = ceil(cutlass::fast_sqrt((2*macro_id) + 2.25) - 0.5) - 1; -+ int32_t macro_col = macro_id - (((macro_row+1) * macro_row)/2); -+ -+ if (kFillModeC == cutlass::FillMode::kUpper) { -+ swap(macro_row, macro_col); -+ } -+ -+ int32_t row = OffsetHelper::macro_row_to_row(macro_row, threadblock_id); -+ int32_t col = OffsetHelper::macro_col_to_col(macro_col, threadblock_id); -+ -+ return cutlass::gemm::GemmCoord(row, col, 0); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_transpose_operands.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_transpose_operands.h -new file mode 100644 -index 0000000..0837a9d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_transpose_operands.h -@@ -0,0 +1,129 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Transpositions for Rank2K problems. -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA_, -+ typename LayoutA_, -+ ComplexTransform TransformA, -+ int AlignmentA, -+ typename ElementB_, -+ typename LayoutB_, -+ ComplexTransform TransformB, -+ int AlignmentB, -+ typename LayoutC_, -+ FillMode FillModeC_, -+ bool Transpose -+> -+struct Rank2KMapArguments { -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ static ComplexTransform const kTransformA = TransformA; -+ static int const kAlignmentA = AlignmentA; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ static ComplexTransform const kTransformB = TransformB; -+ static int const kAlignmentB = AlignmentB; -+ using LayoutC = LayoutC_; -+ static FillMode const kFillModeC = FillModeC_; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA_, -+ typename LayoutA_, -+ ComplexTransform TransformA, -+ int AlignmentA, -+ typename ElementB_, -+ typename LayoutB_, -+ ComplexTransform TransformB, -+ int AlignmentB, -+ typename LayoutC_, -+ FillMode FillModeC_ -+> -+struct Rank2KMapArguments< -+ ElementA_, -+ LayoutA_, -+ TransformA, -+ AlignmentA, -+ ElementB_, -+ LayoutB_, -+ TransformB, -+ AlignmentB, -+ LayoutC_, -+ FillModeC_, -+ true -+> { -+ using ElementA = ElementB_; -+ using LayoutA = LayoutB_; -+ static ComplexTransform const kTransformA = TransformB; -+ static int const kAlignmentA = AlignmentB; -+ using ElementB = ElementA_; -+ using LayoutB = LayoutA_; -+ static ComplexTransform const kTransformB = TransformA; -+ static int const kAlignmentB = AlignmentA; -+ using LayoutC = typename layout::LayoutTranspose::type; -+ static FillMode const kFillModeC = InvertFillMode::mode; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} -+} -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_universal.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_universal.h -new file mode 100644 -index 0000000..6d1f4ac ---- /dev/null -+++ 
b/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_universal.h -@@ -0,0 +1,778 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma1_, ///! Threadblock-scoped matrix multiply-accumulate (A*B^T) -+ typename Mma2_, ///! Threadblock-scoped matrix multiply-accumulate (B*A^T) -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ FillMode FillModeC_, ///! Fill Mode for C (kLower or kUpper) -+ BlasMode BlasMode_ ///! 
Blas3 computation mode -+> -+struct Rank2KUniversal { -+public: -+ -+ using Mma1 = Mma1_; -+ using Mma2 = Mma2_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma1::IteratorA::Element; -+ using ElementB = typename Mma1::IteratorB::Element; -+ -+ // Mma1 (A x B^T) -+ using LayoutA = typename Mma1::IteratorA::Layout; -+ using LayoutBT = typename Mma1::IteratorB::Layout; -+ static ComplexTransform const kMma1TransformA = Mma1::kTransformA; -+ static ComplexTransform const kMma1TransformB = Mma1::kTransformB; -+ -+ // Mma2 (B x A^T) -+ using LayoutB = typename Mma2::IteratorA::Layout; -+ using LayoutAT = typename Mma2::IteratorB::Layout; -+ static ComplexTransform const kMma2TransformA = Mma2::kTransformA; -+ static ComplexTransform const kMma2TransformB = Mma2::kTransformB; -+ -+ // Common type definitions for Mma1 and Mma2 -+ using Operator = typename Mma1::Operator; -+ using OperatorClass = typename Mma1::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma1::Shape; -+ using WarpShape = typename Mma1::Operator::Shape; -+ using InstructionShape = typename Mma1::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma1::ArchTag; -+ -+ static int const kStages = Mma1::kStages; -+ static int const kAlignmentA = Mma1::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma1::IteratorB::AccessType::kElements; -+ -+ // Output related typedefinitions -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ static FillMode const kFillModeC = FillModeC_; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ static BlasMode const kBlasMode = BlasMode_; -+ -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma1::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmUniversalMode mode; -+ GemmCoord problem_size; -+ int batch_count; -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ void const * ptr_A; -+ void const * ptr_B; -+ void const * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+ -+ typename LayoutA::Stride::Index lda; -+ typename LayoutB::Stride::Index ldb; -+ typename LayoutC::Stride::Index ldc; -+ typename LayoutC::Stride::Index ldd; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ mode(GemmUniversalMode::kGemm), -+ batch_count(1), -+ ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr) { } -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride::Index lda, -+ typename LayoutB::Stride::Index ldb, -+ typename LayoutC::Stride::Index ldc, -+ typename LayoutC::Stride::Index ldd -+ ): -+ mode(mode), -+ problem_size(problem_size), -+ batch_count(batch_count), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), -+ batch_stride_A(batch_stride_A), batch_stride_C(batch_stride_C), 
batch_stride_D(batch_stride_D), -+ lda(lda), ldb(ldb), ldc(ldc), ldd(ldd) { -+ -+ } -+ -+ /// Returns arguments for a the transposed problem -+ Arguments transposed_problem() const { -+ Arguments args(*this); -+ -+ std::swap(args.ptr_A, args.ptr_B); -+ std::swap(args.lda, args.ldb); -+ std::swap(args.batch_stride_A, args.batch_stride_B); -+ -+ return args; -+ } -+ -+ }; -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params { -+ -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ -+ // Mma1 Iterator A and B params -+ typename Mma1::IteratorA::Params params_A; -+ typename Mma1::IteratorB::Params params_BT; -+ -+ // Mma2 Iterator A and B params -+ typename Mma2::IteratorA::Params params_B; -+ typename Mma2::IteratorB::Params params_AT; -+ -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ GemmUniversalMode mode; -+ int batch_count; -+ int gemm_k_size; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+ -+ int *semaphore; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ swizzle_log_tile(0), -+ params_A(0), -+ params_BT(0), -+ params_B(0), -+ params_AT(0), -+ params_C(0), -+ params_D(0), -+ batch_count(0), -+ gemm_k_size(0), -+ mode(cutlass::gemm::GemmUniversalMode::kGemm), -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ batch_stride_A(0), -+ batch_stride_B(0), -+ batch_stride_C(0), -+ batch_stride_D(0), -+ semaphore(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Arguments const &args, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape, -+ int gemm_k_size, -+ void *workspace = nullptr -+ ): -+ problem_size(args.problem_size), -+ grid_tiled_shape(grid_tiled_shape), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A(args.lda), -+ params_BT(args.ldb), -+ params_B(args.ldb), -+ params_AT(args.lda), -+ params_C(args.ldc), -+ params_D(args.ldd), -+ output_op(args.epilogue), -+ mode(args.mode), -+ batch_count(args.batch_count), -+ gemm_k_size(gemm_k_size), -+ ptr_A(const_cast(args.ptr_A)), -+ ptr_B(const_cast(args.ptr_B)), -+ ptr_C(const_cast(args.ptr_C)), -+ ptr_D(const_cast(args.ptr_D)), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ batch_stride_C(args.batch_stride_C), -+ batch_stride_D(args.batch_stride_D), -+ semaphore(static_cast(workspace)) { -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void update( -+ Arguments const &args, -+ void *workspace = nullptr) { -+ -+ ptr_A = const_cast(args.ptr_A); -+ ptr_B = const_cast(args.ptr_B); -+ ptr_C = const_cast(args.ptr_C); -+ ptr_D = args.ptr_D; -+ -+ output_op = args.epilogue; -+ -+ semaphore = static_cast(workspace); -+ } -+ -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma1::SharedStorage mma1_main_loop; -+ typename Mma2::SharedStorage mma2_main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ Rank2KUniversal() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) { -+ -+ static int const kAlignmentA = 
Mma1::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma1::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || -+ (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || -+ (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { -+ -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ return; -+ } -+ -+ // Early exit if Fill Mode is Lower and -+ // if the entire tile is above the main diagonal (bottom-left corner is at or above the diagonal) -+ if (kFillModeC == cutlass::FillMode::kLower && -+ (threadblock_tile_offset.m() + 1) * Mma1::Shape::kM <= threadblock_tile_offset.n() * Mma1::Shape::kN) { -+ return; -+ } -+ -+ // Early exit if Fill Mode is Upper and -+ // if the entire tile is below the main diagonal (top-right corner is at or below the diagonal) -+ if (kFillModeC == cutlass::FillMode::kUpper && -+ threadblock_tile_offset.m() * Mma1::Shape::kM >= (threadblock_tile_offset.n() + 1) * Mma1::Shape::kN) { -+ return; -+ } -+ -+ bool tile_on_diagonal = false; -+ // Mark tiles that are being crossed by the main diagonal -+ // (top-right and bottom-left corners are on either side of the diagonal) -+ if ((threadblock_tile_offset.m() + 1) * Mma1::Shape::kM > threadblock_tile_offset.n() * Mma1::Shape::kN -+ && threadblock_tile_offset.m() * Mma1::Shape::kM < (threadblock_tile_offset.n() + 1) * Mma1::Shape::kN) { -+ tile_on_diagonal = true; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ // -+ // Fetch pointers based on mode. 
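The early-exit and tile_on_diagonal checks above amount to classifying each threadblock tile against the main diagonal: a tile entirely on the wrong side for the requested fill mode returns immediately, and only tiles the diagonal actually crosses need the masked, fill-mode-aware epilogue. The standalone restatement below assumes the same tile coordinates and extents used here; classify and TileClass are illustrative names, not CUTLASS code.

#include <cstdio>

enum class TileClass { Skipped, Interior, OnDiagonal };

// tile_m, tile_n are threadblock tile coordinates; blk_m, blk_n are tile extents.
// 'lower' selects FillMode::kLower semantics, otherwise kUpper.
TileClass classify(int tile_m, int tile_n, int blk_m, int blk_n, bool lower) {
  if (lower && (tile_m + 1) * blk_m <= tile_n * blk_n)   return TileClass::Skipped;  // wholly above the diagonal
  if (!lower && tile_m * blk_m >= (tile_n + 1) * blk_n)  return TileClass::Skipped;  // wholly below the diagonal
  bool crossed = (tile_m + 1) * blk_m > tile_n * blk_n &&
                 tile_m * blk_m < (tile_n + 1) * blk_n;                              // diagonal passes through the tile
  return crossed ? TileClass::OnDiagonal : TileClass::Interior;
}

int main() {
  // 128x128 tiles, lower fill mode: tile (0,1) is skipped, (1,0) is interior, (2,2) sits on the diagonal.
  std::printf("%d %d %d\n",
              int(classify(0, 1, 128, 128, true)),
              int(classify(1, 0, 128, 128, true)),
              int(classify(2, 2, 128, 128, true)));
  return 0;
}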
-+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ -+ __syncthreads(); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_MxK{ -+ threadblock_tile_offset.m() * Mma1::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_KxN{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma1::Shape::kN -+ }; -+ -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands for Mma1 -+ typename Mma1::IteratorA iterator_A( -+ params.params_A, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_MxK); -+ -+ typename Mma1::IteratorB iterator_BT( -+ params.params_BT, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_KxN); -+ -+ // Construct iterators to A and B operands for Mma2 -+ typename Mma2::IteratorA iterator_B( -+ params.params_B, -+ ptr_B, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_MxK); -+ -+ typename Mma2::IteratorB iterator_AT( -+ params.params_AT, -+ ptr_A, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_KxN); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply for Mma1 (A x BT) -+ Mma1 mma1(shared_storage.mma1_main_loop, thread_idx, warp_idx, lane_idx); -+ -+ // Construct thread-scoped matrix multiply for Mma2 (B x AT) -+ Mma2 mma2(shared_storage.mma2_main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma1::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma1::Shape::kK - 1) / Mma1::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add (A x BT) -+ mma1( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_BT, -+ accumulators); -+ -+ // HER2K kernel needs Alpha to be complex and is conj(Alpha) is applied to the second HERK. -+ if (kBlasMode == BlasMode::kHermitian) { -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma1::Shape::kM, -+ threadblock_tile_offset.n() * Mma1::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C); -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. 
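When the mode is kGemm and the grid has more than one K slice, the epilogue below serializes partial results through the per-tile semaphore: each slice fetches the lock early, waits until the lock value equals its own k index, accumulates onto the tensor written by the previous slice, and then releases the lock to k+1 (the final slice resets it to 0 for subsequent launches). The plain C++ sketch below models that handshake with std::atomic in place of cutlass::Semaphore; serial_splitk_reduce and its parameters are illustrative, not CUTLASS API.

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

// One output "tile": k_slices partial sums are reduced serially, ordered by a ticket-style lock,
// mimicking Semaphore::wait(k) / Semaphore::release(next) in the kernels above.
void serial_splitk_reduce(int k_slices, std::vector<double> const &partial,
                          double beta, double c, double &d) {
  std::atomic<int> lock{0};
  std::vector<std::thread> workers;
  for (int k = 0; k < k_slices; ++k) {
    workers.emplace_back([&, k] {
      while (lock.load(std::memory_order_acquire) != k) {  // Semaphore::wait(k)
        std::this_thread::yield();
      }
      double source = (k == 0) ? beta * c : d;       // slice 0 reads C; later slices read D
      d = source + partial[k];                       // epilogue: accumulate this slice
      int next = (k + 1 == k_slices) ? 0 : k + 1;    // final slice resets the lock
      lock.store(next, std::memory_order_release);   // Semaphore::release(next)
    });
  }
  for (auto &w : workers) w.join();
}

int main() {
  std::vector<double> partial = {1.0, 2.0, 3.0};     // per-slice contributions for one output tile
  double d = 0.0;
  serial_splitk_reduce(3, partial, 0.5, 10.0, d);
  std::printf("%f\n", d);                            // 0.5*10 + 1 + 2 + 3 = 11
  return 0;
}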
-+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; -+ ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; -+ } -+ -+ -+ // If CTA not on diagonal, FillMode doesn't apply. -+ FillMode kFillModeCTA = tile_on_diagonal ? kFillModeC : FillMode::kNone; -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ ptr_C, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ kFillModeCTA -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ kFillModeCTA -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ __threadfence(); -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. 
-+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ -+ __syncthreads(); -+ -+ accumulators.clear(); -+ } -+ -+ // Compute threadblock-scoped matrix multiply-add (B x AT) -+ mma2( -+ gemm_k_iterations, -+ accumulators, -+ iterator_B, -+ iterator_AT, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ /* Needed for HER2K where the second HERK is multiplied by conj(alpha) */ -+ typename EpilogueOutputOp::Params second_her2k_params(conj(params.output_op.alpha), 1); -+ EpilogueOutputOp output_op_her2k(second_her2k_params); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma1::Shape::kM, -+ threadblock_tile_offset.n() * Mma1::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C); -+ -+ // HER2K kernel needs Alpha to be complex and is conj(Alpha) is applied to the second HERK. -+ if (kBlasMode == BlasMode::kHermitian) { -+ ptr_C = static_cast(params.ptr_D); -+ } -+ -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ if (kBlasMode == BlasMode::kSymmetric) { -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } else { -+ output_op_her2k.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; -+ ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; -+ } -+ -+ -+ // If CTA not on diagonal, FillMode doesn't apply. -+ FillMode kFillModeCTA = tile_on_diagonal ? kFillModeC : FillMode::kNone; -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ ptr_C, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ kFillModeCTA -+ ); -+ -+ // Tile iterator writing to destination tensor. 
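Taken together, the Hermitian path performs the rank-2k update C = alpha * (A B^H) + conj(alpha) * (B A^H) + beta * C in two passes: the first epilogue writes alpha * (A B^H) + beta * C into D, and output_op_her2k then adds conj(alpha) * (B A^H) on top of that intermediate (which is why D is re-read as the source), keeping the output Hermitian. Below is a small complex-arithmetic check of that identity; her2k_entry is an illustrative reference, not anything exported by CUTLASS.

#include <cassert>
#include <complex>
#include <cstdio>
#include <vector>

using cplx = std::complex<double>;

// Reference for one entry of the HER2K update:
//   C(i,j) <- alpha * sum_l A(i,l)*conj(B(j,l)) + conj(alpha) * sum_l B(i,l)*conj(A(j,l)) + beta * C(i,j)
cplx her2k_entry(cplx alpha, double beta,
                 std::vector<std::vector<cplx>> const &A,
                 std::vector<std::vector<cplx>> const &B,
                 std::vector<std::vector<cplx>> const &C, int i, int j) {
  cplx first = 0, second = 0;
  for (size_t l = 0; l < A[i].size(); ++l) {
    first  += A[i][l] * std::conj(B[j][l]);
    second += B[i][l] * std::conj(A[j][l]);
  }
  return alpha * first + std::conj(alpha) * second + beta * C[i][j];
}

int main() {
  cplx alpha(1.5, -0.75);
  double beta = 0.5;
  std::vector<std::vector<cplx>> A = {{{1.0, 2.0}, {0.0, -1.0}}, {{3.0, 1.0}, {2.0, 2.0}}};
  std::vector<std::vector<cplx>> B = {{{0.0, 1.0}, {1.0, 1.0}}, {{-1.0, 2.0}, {4.0, 0.0}}};
  std::vector<std::vector<cplx>> C = {{{2.0, 0.0}, {1.0, -1.0}}, {{1.0, 1.0}, {5.0, 0.0}}};  // Hermitian input
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j) {
      cplx cij = her2k_entry(alpha, beta, A, B, C, i, j);
      cplx cji = her2k_entry(alpha, beta, A, B, C, j, i);
      assert(std::abs(cij - std::conj(cji)) < 1e-12);  // result stays Hermitian; diagonal stays real
    }
  std::printf("HER2K update preserves Hermitian structure\n");
  return 0;
}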
-+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ kFillModeCTA -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ __threadfence(); -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ if (kBlasMode == BlasMode::kSymmetric) { -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ } else { -+ epilogue( -+ output_op_her2k, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ } -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_k_universal.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_k_universal.h -new file mode 100644 -index 0000000..b7d1ad1 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_k_universal.h -@@ -0,0 +1,565 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ FillMode FillModeC_ ///! Fill Mode for C (kLower or kUpper) -+> -+struct RankKUniversal { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ static FillMode const kFillModeC = FillModeC_; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ using Operator = typename Mma::Operator; -+ -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = 128 / sizeof_bits::value; -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmUniversalMode mode; -+ GemmCoord problem_size; -+ int batch_count; -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ void const * ptr_A; -+ void const * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+ -+ typename LayoutA::Stride::Index lda; -+ typename LayoutB::Stride::Index ldb; -+ typename LayoutC::Stride::Index 
ldc; -+ typename LayoutC::Stride::Index ldd; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ mode(GemmUniversalMode::kGemm), -+ batch_count(1), -+ ptr_A(nullptr), ptr_C(nullptr), ptr_D(nullptr) { } -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride::Index lda, -+ typename LayoutC::Stride::Index ldc, -+ typename LayoutC::Stride::Index ldd -+ ): -+ mode(mode), -+ problem_size(problem_size), -+ batch_count(batch_count), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_C(ptr_C), ptr_D(ptr_D), -+ batch_stride_A(batch_stride_A), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), -+ lda(lda), ldb(ldb), ldc(ldc), ldd(ldd) { -+ -+ } -+ -+ }; -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params { -+ -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ GemmUniversalMode mode; -+ int batch_count; -+ int gemm_k_size; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+ -+ int *semaphore; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ swizzle_log_tile(0), -+ params_A(0), -+ params_B(0), -+ params_C(0), -+ params_D(0), -+ batch_count(0), -+ gemm_k_size(0), -+ mode(cutlass::gemm::GemmUniversalMode::kGemm), -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ batch_stride_A(0), -+ batch_stride_B(0), -+ batch_stride_C(0), -+ batch_stride_D(0), -+ semaphore(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Arguments const &args, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape, -+ int gemm_k_size, -+ void *workspace = nullptr -+ ): -+ problem_size(args.problem_size), -+ grid_tiled_shape(grid_tiled_shape), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A(args.lda), -+ params_B(args.lda), -+ params_C(args.ldc), -+ params_D(args.ldd), -+ output_op(args.epilogue), -+ mode(args.mode), -+ batch_count(args.batch_count), -+ gemm_k_size(gemm_k_size), -+ ptr_A(const_cast(args.ptr_A)), -+ ptr_B(const_cast(args.ptr_A)), -+ ptr_C(const_cast(args.ptr_C)), -+ ptr_D(const_cast(args.ptr_D)), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_A), -+ batch_stride_C(args.batch_stride_C), -+ batch_stride_D(args.batch_stride_D), -+ semaphore(static_cast(workspace)) { -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void update( -+ Arguments const &args, -+ void *workspace = nullptr) { -+ -+ ptr_A = const_cast(args.ptr_A); -+ ptr_B = const_cast(args.ptr_A); -+ ptr_C = const_cast(args.ptr_C); -+ ptr_D = args.ptr_D; -+ -+ output_op = args.epilogue; -+ -+ semaphore = static_cast(workspace); -+ } -+ -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+public: -+ -+ // -+ // 
Methods -+ // -+ -+ CUTLASS_DEVICE -+ RankKUniversal() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) { -+ -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || -+ (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || -+ (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { -+ -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ return; -+ } -+ -+ // Early exit if Fill Mode is Lower and -+ // if the entire tile is above the main diagonal (bottom-left corner is at or above the diagonal) -+ if (kFillModeC == cutlass::FillMode::kLower && -+ (threadblock_tile_offset.m() + 1) * Mma::Shape::kM <= threadblock_tile_offset.n() * Mma::Shape::kN) { -+ return; -+ } -+ -+ // Early exit if Fill Mode is Upper and -+ // if the entire tile is below the main diagonal (top-right corner is at or below the diagonal) -+ if (kFillModeC == cutlass::FillMode::kUpper && -+ threadblock_tile_offset.m() * Mma::Shape::kM >= (threadblock_tile_offset.n() + 1) * Mma::Shape::kN) { -+ return; -+ } -+ -+ bool tile_on_diagonal = false; -+ // Mark tiles that are being crossed by the main diagonal -+ // (top-right and bottom-left corners are on either side of the diagonal) -+ if ((threadblock_tile_offset.m() + 1) * Mma::Shape::kM > threadblock_tile_offset.n() * Mma::Shape::kN -+ && threadblock_tile_offset.m() * Mma::Shape::kM < (threadblock_tile_offset.n() + 1) * Mma::Shape::kN) { -+ tile_on_diagonal = true; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ // -+ // Fetch pointers based on mode. 
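Note that RankKUniversal drives both Mma operand iterators from the single A-side argument (the Params constructor above wires ptr_B, params_B and batch_stride_B to the A pointer, leading dimension and stride), so with FillMode::kLower the kernel produces the lower triangle of C = alpha * A * A^T + beta * C (with the appropriate conjugation in the HERK case). A naive real-valued reference of that contract is sketched below; syrk_lower_ref is an illustrative name, not a CUTLASS entry point.

#include <cstdio>
#include <vector>

// Lower-triangle rank-k update on row-major data: n x k matrix A, n x n matrix C.
// Only entries with j <= i are computed or stored; tiles above the diagonal are never touched.
void syrk_lower_ref(int n, int k, double alpha, double beta,
                    std::vector<double> const &A, std::vector<double> &C) {
  for (int i = 0; i < n; ++i)
    for (int j = 0; j <= i; ++j) {
      double acc = 0.0;
      for (int l = 0; l < k; ++l)
        acc += A[i * k + l] * A[j * k + l];   // both operands read the same A, as in Params (ptr_B = ptr_A)
      C[i * n + j] = alpha * acc + beta * C[i * n + j];
    }
}

int main() {
  int n = 3, k = 2;
  std::vector<double> A = {1, 2, 3, 4, 5, 6};   // 3x2, rows (1,2), (3,4), (5,6)
  std::vector<double> C(n * n, 0.0);
  syrk_lower_ref(n, k, 1.0, 0.0, A, C);
  std::printf("C(2,1) = %g\n", C[2 * n + 1]);   // 5*3 + 6*4 = 39
  return 0;
}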
-+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; -+ ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; -+ ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; -+ } -+ -+ __syncthreads(); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C); -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. 
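As in the rank-2k kernel, each threadblock along the grid's k dimension covers a gemm_k_size-wide slab of the K extent: its slice spans [offset_k, problem_size_k), the last slice absorbs the remainder, and the main-loop trip count is the ceil-division of that span by ThreadblockShape::kK. The host-side sketch below mirrors that slicing; split_k_slice and KSlice are illustrative names.

#include <cassert>
#include <cstdio>

// Slice k covers [offset_k, problem_size_k) and runs ceil_div(extent, BLK_K) main-loop iterations.
struct KSlice { int offset_k; int problem_size_k; int gemm_k_iterations; };

KSlice split_k_slice(int total_k, int gemm_k_size, int grid_k, int k_idx, int blk_k) {
  KSlice s;
  s.offset_k = k_idx * gemm_k_size;
  s.problem_size_k = (k_idx + 1 < grid_k) ? (k_idx + 1) * gemm_k_size : total_k;  // last slice takes the remainder
  s.gemm_k_iterations = (s.problem_size_k - s.offset_k + blk_k - 1) / blk_k;      // ceil-div by ThreadblockShape::kK
  return s;
}

int main() {
  int total_k = 300, gemm_k_size = 128, grid_k = 3, blk_k = 32;
  int covered = 0;
  for (int k = 0; k < grid_k; ++k) {
    KSlice s = split_k_slice(total_k, gemm_k_size, grid_k, k, blk_k);
    covered += s.problem_size_k - s.offset_k;
    std::printf("slice %d: K [%d, %d), %d iterations\n", k, s.offset_k, s.problem_size_k, s.gemm_k_iterations);
  }
  assert(covered == total_k);   // the slices tile the full K extent
  return 0;
}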
-+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; -+ ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; -+ } -+ -+ -+ // If CTA not on diagonal, FillMode doesn't apply. -+ FillMode kFillModeCTA = tile_on_diagonal ? kFillModeC : FillMode::kNone; -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ ptr_C, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ kFillModeCTA -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ kFillModeCTA -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ __threadfence(); -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/sm70_gemm.hpp b/3rdparty/cutlass/include/cutlass/gemm/kernel/sm70_gemm.hpp -new file mode 100644 -index 0000000..efe51e2 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/sm70_gemm.hpp -@@ -0,0 +1,252 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/kernel_hardware_info.hpp" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/dispatch_policy.hpp" -+ -+#include "cute/tensor.hpp" -+ -+namespace cutlass::gemm::kernel { -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ class ProblemShape_, -+ class CollectiveMainloop_, -+ class CollectiveEpilogue_, -+ class GridSwizzle_ -+> -+class GemmUniversal< -+ ProblemShape_, -+ CollectiveMainloop_, -+ CollectiveEpilogue_, -+ GridSwizzle_, -+ std::enable_if_t>> -+{ -+public: -+ // -+ // Type Aliases -+ // -+ using ProblemShape = ProblemShape_; -+ using GridSwizzle = GridSwizzle_; -+ static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, -+ "ProblemShape{} should be or "); -+ -+ // Mainloop derived types -+ using CollectiveMainloop = CollectiveMainloop_; -+ using TileShape = typename CollectiveMainloop::TileShape; -+ using TiledMma = typename CollectiveMainloop::TiledMma; -+ using ArchTag = typename CollectiveMainloop::ArchTag; -+ using ElementA = typename CollectiveMainloop::ElementA; -+ using StrideA = typename CollectiveMainloop::StrideA; -+ using ElementB = typename CollectiveMainloop::ElementB; -+ using StrideB = typename CollectiveMainloop::StrideB; -+ using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; -+ using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; -+ using MainloopParams = typename CollectiveMainloop::Params; -+ -+ // Epilogue derived types -+ using CollectiveEpilogue = CollectiveEpilogue_; -+ using ElementC = typename CollectiveEpilogue::ElementC; -+ using StrideC = typename CollectiveEpilogue::StrideC; -+ using ElementD = typename CollectiveEpilogue::ElementD; -+ using StrideD = typename CollectiveEpilogue::StrideD; -+ using EpilogueParams = typename CollectiveEpilogue::Params; -+ static_assert(std::is_same_v, -+ "Mainloop and epilogue do not agree on accumulator value type."); -+ -+ static constexpr int SharedStorageSize = cute::max( -+ sizeof(typename CollectiveMainloop::SharedStorage), -+ sizeof(typename CollectiveEpilogue::SharedStorage)); -+ -+ static constexpr uint32_t 
MaxThreadsPerBlock = cute::size(TiledMma{}); -+ static constexpr uint32_t MinBlocksPerMultiprocessor = 1; -+ -+ // Device side arguments -+ struct Arguments { -+ GemmUniversalMode mode{}; -+ ProblemShape problem_shape{}; -+ ElementA const* ptr_A = nullptr; -+ StrideA dA{}; -+ ElementB const* ptr_B = nullptr; -+ StrideB dB{}; -+ EpilogueParams epilogue_params{}; -+ KernelHardwareInfo hw_info; -+ }; -+ -+ // Kernel entry point API -+ struct Params { -+ GemmUniversalMode mode; -+ ProblemShape problem_shape; -+ MainloopParams mainloop; -+ EpilogueParams epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ // Convert to underlying arguments. In this case, a simple copy for the aliased type. -+ static -+ Params -+ to_underlying_arguments(Arguments const& args, void* workspace) { -+ (void) workspace; -+ return { -+ args.mode, -+ args.problem_shape, -+ CollectiveMainloop::to_underlying_arguments(args, workspace), -+ CollectiveEpilogue::to_underlying_arguments(args, workspace) -+ }; -+ } -+ -+ static -+ bool -+ can_implement(Arguments const& args) { -+ return args.mode == GemmUniversalMode::kGemm or -+ (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); -+ } -+ -+ static -+ int -+ get_workspace_size(Arguments const& args) { -+ return 0; -+ } -+ -+ static constexpr -+ dim3 -+ get_grid_shape(Params const& params) { -+ int batch_count = 1; -+ if constexpr (rank(ProblemShape{}) == 4) { -+ batch_count = cute::size<3>(params.problem_shape); -+ } -+ -+ return dim3( -+ cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))), -+ cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))), -+ batch_count -+ ); -+ } -+ -+ static constexpr -+ dim3 -+ get_block_shape() { -+ return dim3(MaxThreadsPerBlock, 1, 1); -+ } -+ -+ CUTLASS_DEVICE -+ void -+ operator()(Params const& params, char* smem_buf) { -+ using namespace cute; -+ using X = Underscore; -+ -+ // Preconditions -+ CUTE_STATIC_ASSERT(is_static::value); -+ -+ // Separate out problem shape for convenience -+ // Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK) -+ auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); -+ auto M = get<0>(problem_shape_MNKL); -+ auto N = get<1>(problem_shape_MNKL); -+ auto K = get<2>(problem_shape_MNKL); -+ auto L = get<3>(problem_shape_MNKL); -+ -+ // Preconditions -+ static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. 
If batch mode is not needed, set L stride to Int<0>."); -+ -+ // Get the appropriate blocks for this thread block -- potential for thread block locality -+ int thread_idx = int(threadIdx.x); -+ auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) -+ auto [m_coord, n_coord, l_coord] = blockIdx; -+ auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord); // (m,n,k,l) -+ -+ // Represent the full tensors -+ Tensor mA_mkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA); //(m,k,l) -+ Tensor mB_nkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB); //(n,k,l) -+ -+ // Get batch slice -+ Tensor mA_mk = mA_mkl(_,_,l_coord); // (m,k) -+ Tensor mB_nk = mB_nkl(_,_,l_coord); // (n,k) -+ -+ // Slice to get the tiles this thread block is responsible for -+ Tensor gA = local_tile(mA_mk, blk_shape, take<0,3>(blk_coord_mnkl), Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) -+ Tensor gB = local_tile(mB_nk, blk_shape, take<0,3>(blk_coord_mnkl), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) -+ -+ // Compute tile residues for predication -+ auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord_mnkl); // M - BLK_M * m_coord -+ auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord_mnkl); // N - BLK_N * n_coord -+ auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max -+ auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); -+ -+ // Allocate the tiled_mma and the accumulators for the (M,N) blk_shape -+ TiledMma tiled_mma; -+ Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) -+ clear(accumulators); -+ -+ auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); -+ int k_tile_count = size<2>(gA); -+ -+ // Perform the collective scoped MMA -+ CollectiveMainloop collective_mma; -+ collective_mma( -+ accumulators, -+ gA, -+ gB, -+ accumulators, -+ k_tile_iter, k_tile_count, -+ residue_mnk, -+ thread_idx, -+ smem_buf -+ ); -+ -+ // Epilogue and write to gD -+ CollectiveEpilogue epilogue{params.epilogue}; -+ epilogue( -+ problem_shape_MNKL, -+ blk_shape, -+ blk_coord_mnkl, -+ accumulators, -+ tiled_mma, -+ residue_mnk, -+ thread_idx, -+ smem_buf -+ ); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::kernel -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp b/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp -new file mode 100644 -index 0000000..bd82ed1 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp -@@ -0,0 +1,301 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/kernel_hardware_info.hpp" -+#include "cute/arch/cluster_sm90.hpp" -+#include "cutlass/arch/mma_sm90.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/dispatch_policy.hpp" -+#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" -+ -+#include "cute/tensor.hpp" -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm::kernel { -+ -+namespace detail { -+ -+// IF_SWAP_AB::value will be true only if: -+// class T has member SwapAB and T::SwapAB is true -+template -+struct IF_SWAP_AB { static constexpr bool value = false; }; -+ -+template -+struct IF_SWAP_AB > -+{ static constexpr bool value = T::SwapAB; }; -+ -+} // namespace -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ class ProblemShape_, -+ class CollectiveMainloop_, -+ class CollectiveEpilogue_, -+ class GridSwizzle_ -+> -+class GemmUniversal< -+ ProblemShape_, -+ CollectiveMainloop_, -+ CollectiveEpilogue_, -+ GridSwizzle_, -+ std::enable_if_t>> -+{ -+public: -+ // -+ // Type Aliases -+ // -+ using ProblemShape = ProblemShape_; -+ using GridSwizzle = GridSwizzle_; -+ static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, -+ "ProblemShape{} should be or "); -+ -+ // Mainloop derived types -+ using CollectiveMainloop = CollectiveMainloop_; -+ using TileShape = typename CollectiveMainloop::TileShape; -+ using TiledMma = typename CollectiveMainloop::TiledMma; -+ using ArchTag = typename CollectiveMainloop::ArchTag; -+ using ElementA = typename CollectiveMainloop::ElementA; -+ using StrideA = typename CollectiveMainloop::StrideA; -+ using ElementB = typename CollectiveMainloop::ElementB; -+ using StrideB = typename CollectiveMainloop::StrideB; -+ using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; -+ using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; -+ using ClusterShape = typename DispatchPolicy::ClusterShape; -+ using MainloopParams = typename CollectiveMainloop::Params; -+ static_assert(ArchTag::kMinComputeCapability >= 90); -+ -+ // Epilogue derived types -+ using CollectiveEpilogue = CollectiveEpilogue_; -+ using ElementC = typename CollectiveEpilogue::ElementC; -+ using StrideC = typename CollectiveEpilogue::StrideC; -+ using ElementD = typename 
CollectiveEpilogue::ElementD; -+ using StrideD = typename CollectiveEpilogue::StrideD; -+ using EpilogueParams = typename CollectiveEpilogue::Params; -+ static_assert(std::is_same_v, -+ "Mainloop and epilogue do not agree on accumulator value type."); -+ -+ static constexpr int SharedStorageSize = cute::max( -+ sizeof(typename CollectiveMainloop::SharedStorage), -+ sizeof(typename CollectiveEpilogue::SharedStorage)); -+ -+ static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); -+ static constexpr uint32_t MinBlocksPerMultiprocessor = 1; -+ -+ // Device side arguments -+ struct Arguments { -+ GemmUniversalMode mode{}; -+ ProblemShape problem_shape{}; -+ ElementA const* ptr_A = nullptr; -+ StrideA dA{}; -+ ElementB const* ptr_B = nullptr; -+ StrideB dB{}; -+ EpilogueParams epilogue_params{}; -+ KernelHardwareInfo hw_info; -+ }; -+ -+ // Kernel entry point API -+ struct Params { -+ GemmUniversalMode mode; -+ ProblemShape problem_shape; -+ MainloopParams mainloop; -+ EpilogueParams epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ // Convert to underlying arguments. In this case, a simple copy for the aliased type. -+ static -+ Params -+ to_underlying_arguments(Arguments const& args, void* workspace) { -+ (void) workspace; -+ auto problem_shape = args.problem_shape; -+ if constexpr (detail::IF_SWAP_AB::value) { -+ // swap M/N -+ get<0>(problem_shape) = get<1>(args.problem_shape); -+ get<1>(problem_shape) = get<0>(args.problem_shape); -+ } -+ return { -+ args.mode, -+ problem_shape, -+ CollectiveMainloop::to_underlying_arguments(args, workspace), -+ CollectiveEpilogue::to_underlying_arguments(args, workspace) -+ }; -+ } -+ -+ CUTLASS_HOST_DEVICE static -+ bool -+ can_implement(Arguments const& args) { -+ return args.mode == GemmUniversalMode::kGemm or -+ (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); -+ } -+ -+ static -+ int -+ get_workspace_size(Arguments const& args) { -+ return 0; -+ } -+ -+ // Computes the kernel launch grid shape based on runtime parameters -+ static constexpr -+ dim3 -+ get_grid_shape(Params const& params) { -+ auto cluster_shape = ClusterShape{}; -+ auto tile_shape = TileShape{}; -+ auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); -+ return detail::PersistentTileSchedulerSm90::get_tiled_blk_shape_mnl( -+ problem_shape_MNKL, tile_shape, cluster_shape); -+ } -+ -+ static constexpr -+ dim3 -+ get_block_shape() { -+ return dim3(MaxThreadsPerBlock, 1, 1); -+ } -+ -+ CUTLASS_DEVICE -+ void -+ operator()(Params const& params, char* smem_buf) { -+ using namespace cute; -+ using X = Underscore; -+ -+ // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. -+ #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) -+ if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { -+ printf("ERROR : Arch conditional MMA instruction used without targetting sm90a compute capability. Aborting.\n"); -+ return; -+ } -+ #endif -+ -+ // Preconditions -+ static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. 
If batch mode is not needed, set L stride to Int<0>."); -+ -+ int thread_idx = int(threadIdx.x); -+ int warp_idx = canonical_warp_idx(); -+ int lane_predicate = cute::elect_one_sync(); -+ -+ // Issue Tma Descriptor Prefetch from a single thread -+ if ((warp_idx == 0) && lane_predicate) { -+ CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); -+ } -+ -+ // Separate out problem shape for convenience -+ // Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK) -+ auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); -+ auto M = get<0>(problem_shape_MNKL); -+ auto N = get<1>(problem_shape_MNKL); -+ auto K = get<2>(problem_shape_MNKL); -+ auto L = get<3>(problem_shape_MNKL); -+ -+ // TMA requires special handling of strides to deal with coord codomain mapping -+ // Represent the full tensors -- get these from TMA -+ Tensor mA_mkl = params.mainloop.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) -+ Tensor mB_nkl = params.mainloop.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) -+ -+ // Get the appropriate blocks for this thread block -- potential for thread block locality -+ auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) -+ auto blk_coord = make_coord(_,_,_); // (m,n,k) -- defer the slice -+ -+ // Make tiled views -+ Tensor gA_mkl = local_tile(mA_mkl, blk_shape, blk_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) -+ Tensor gB_nkl = local_tile(mB_nkl, blk_shape, blk_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) -+ -+ // Compute m_coord, n_coord, and l_coord with their post-tiled shapes -+ auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); -+ auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl)); -+ auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl)); -+ auto output_tile_coord = make_coord(m_coord, n_coord, _, l_coord); -+ -+ // Slice with m_coord and n_coord -+ Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) -+ Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) -+ -+ // Allocate the tiled_mma and the accumulators for the (M,N) blk_shape -+ TiledMma tiled_mma; -+ Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) -+ -+ clear(accumulators); -+ -+ auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); -+ auto k_tile_count = size<2>(gA); -+ -+ // Perform the collective scoped MMA -+ CollectiveMainloop collective_mma; -+ collective_mma( -+ gA, params.mainloop.tma_load_a, -+ gB, params.mainloop.tma_load_b, -+ accumulators, -+ k_tile_iter, k_tile_count, -+ thread_idx, -+ smem_buf, -+ params.mainloop -+ ); -+ -+ constexpr int BLK_M_RANK = rank<0>(blk_shape); -+ bool m_oob = int(blockIdx.x) >= size<2>(gA_mkl); -+ auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { -+ return m_oob ? 0 : get(M) - get<0,i>(blk_shape) * get(m_coord); -+ })); -+ -+ constexpr int BLK_N_RANK = rank<1>(blk_shape); -+ bool n_oob = int(blockIdx.y) >= size<2>(gB_nkl); -+ auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { -+ return n_oob ? 
0 : get(N) - get<1,i>(blk_shape) * get(n_coord); -+ })); -+ auto residue_mnk = make_tuple(m_max_coord, n_max_coord, Int<0>{}); -+ -+ // Epilogue and write to gD -+ CollectiveEpilogue epilogue{params.epilogue}; -+ epilogue( -+ problem_shape_MNKL, -+ blk_shape, -+ output_tile_coord, -+ accumulators, -+ tiled_mma, -+ residue_mnk, -+ thread_idx, -+ smem_buf -+ ); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::kernel -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp b/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp -new file mode 100644 -index 0000000..9fc719e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp -@@ -0,0 +1,351 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/kernel_hardware_info.hpp" -+#include "cute/arch/cluster_sm90.hpp" -+#include "cutlass/arch/reg_reconfig.h" -+#include "cutlass/arch/mma_sm90.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/dispatch_policy.hpp" -+#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" -+#include "cutlass/pipeline.hpp" -+#include "cute/tensor.hpp" -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm::kernel { -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ class ProblemShape_, -+ class CollectiveMainloop_, -+ class CollectiveEpilogue_, -+ class GridSwizzle_ -+> -+class GemmUniversal< -+ ProblemShape_, -+ CollectiveMainloop_, -+ CollectiveEpilogue_, -+ GridSwizzle_, -+ std::enable_if_t>> -+{ -+public: -+ // -+ // Type Aliases -+ // -+ using ProblemShape = ProblemShape_; -+ using GridSwizzle = GridSwizzle_; -+ static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, -+ "ProblemShape{} should be or "); -+ -+ // Mainloop derived types -+ using CollectiveMainloop = CollectiveMainloop_; -+ using TileShape = typename CollectiveMainloop::TileShape; -+ using TiledMma = typename CollectiveMainloop::TiledMma; -+ using ArchTag = typename CollectiveMainloop::ArchTag; -+ using ElementA = typename CollectiveMainloop::ElementA; -+ using StrideA = typename CollectiveMainloop::StrideA; -+ using ElementB = typename CollectiveMainloop::ElementB; -+ using StrideB = typename CollectiveMainloop::StrideB; -+ using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; -+ using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; -+ using ClusterShape = typename DispatchPolicy::ClusterShape; -+ using MainloopParams = typename CollectiveMainloop::Params; -+ static_assert(ArchTag::kMinComputeCapability >= 90); -+ -+ // Epilogue derived types -+ using CollectiveEpilogue = CollectiveEpilogue_; -+ using ElementC = typename CollectiveEpilogue::ElementC; -+ using StrideC = typename CollectiveEpilogue::StrideC; -+ using ElementD = typename CollectiveEpilogue::ElementD; -+ using StrideD = typename CollectiveEpilogue::StrideD; -+ using EpilogueParams = typename CollectiveEpilogue::Params; -+ static_assert(std::is_same_v, -+ "Mainloop and epilogue do not agree on accumulator value type."); -+ -+ static constexpr int SharedStorageSize = cute::max( -+ sizeof(typename CollectiveMainloop::SharedStorage), -+ sizeof(typename CollectiveEpilogue::SharedStorage)); -+ -+ static constexpr uint32_t NumDmaWarpGroups = 1; -+ static constexpr uint32_t NumMmaWarpGroups = 1; -+ static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}) + (NumDmaWarpGroups * NumThreadsPerWarpGroup); -+ static constexpr uint32_t MinBlocksPerMultiprocessor = 1; -+ -+ // Device side arguments -+ struct Arguments { -+ GemmUniversalMode mode{}; -+ ProblemShape problem_shape{}; -+ ElementA const* ptr_A = nullptr; -+ StrideA dA{}; -+ ElementB const* ptr_B = nullptr; -+ StrideB dB{}; -+ EpilogueParams epilogue_params{}; -+ KernelHardwareInfo hw_info; -+ }; -+ -+ // Kernel entry point API -+ struct Params { -+ GemmUniversalMode mode; -+ ProblemShape problem_shape; -+ MainloopParams mainloop; -+ EpilogueParams epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ // Convert to underlying arguments. 
In this case, a simple copy for the aliased type. -+ static -+ Params -+ to_underlying_arguments(Arguments const& args, void* workspace) { -+ (void) workspace; -+ auto problem_shape = args.problem_shape; -+ if constexpr (detail::IF_SWAP_AB::value) { -+ // swap M/N -+ get<0>(problem_shape) = get<1>(args.problem_shape); -+ get<1>(problem_shape) = get<0>(args.problem_shape); -+ } -+ return { -+ args.mode, -+ problem_shape, -+ CollectiveMainloop::to_underlying_arguments(args, workspace), -+ CollectiveEpilogue::to_underlying_arguments(args, workspace) -+ }; -+ } -+ -+ CUTLASS_HOST_DEVICE static -+ bool -+ can_implement(Arguments const& args) { -+ return args.mode == GemmUniversalMode::kGemm or -+ (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); -+ } -+ -+ static -+ int -+ get_workspace_size(Arguments const& args) { -+ return 0; -+ } -+ -+ // Computes the kernel launch grid shape based on runtime parameters -+ static constexpr -+ dim3 -+ get_grid_shape(Params const& params) { -+ auto cluster_shape = ClusterShape{}; -+ auto tile_shape = TileShape{}; -+ auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); -+ return detail::PersistentTileSchedulerSm90::get_tiled_blk_shape_mnl( -+ problem_shape_MNKL, tile_shape, cluster_shape); -+ } -+ -+ static constexpr -+ dim3 -+ get_block_shape() { -+ return dim3(MaxThreadsPerBlock, 1, 1); -+ } -+ -+ CUTLASS_DEVICE -+ void -+ operator()(Params const& params, char* smem_buf) { -+ using namespace cute; -+ using X = Underscore; -+ -+ // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. -+ #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) -+ if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { -+ printf("ERROR : Arch conditional MMA instruction used without targetting sm90a compute capability. 
Aborting.\n"); -+ return; -+ } -+ #endif -+ -+ enum class WarpGroupRole { -+ Producer = 0, -+ Consumer = 1, -+ }; -+ -+ int thread_idx = int(threadIdx.x); -+ int warp_idx = canonical_warp_idx(); -+ int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; -+ auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); -+ int lane_predicate = cute::elect_one_sync(); -+ -+ // Issue Tma Descriptor Prefetch from a single thread -+ if ((warp_idx == 0) && lane_predicate) { -+ CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); -+ } -+ -+ using Pipeline = typename CollectiveMainloop::MainloopPipeline; -+ -+ using PipelineParams = typename CollectiveMainloop::PipelineParams; -+ PipelineParams params_pipeline; -+ params_pipeline.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; -+ if (warp_group_role == WarpGroupRole::Producer) { -+ params_pipeline.role = Pipeline::ThreadCategory::Producer; -+ } -+ else { -+ params_pipeline.role = Pipeline::ThreadCategory::Consumer; -+ } -+ params_pipeline.is_leader = warp_group_thread_idx == 0; -+ params_pipeline.num_consumers = NumThreadsPerWarpGroup; -+ -+ // Initialize pipeline and setup starting pipeline state for the collectives -+ Pipeline pipeline = CollectiveMainloop::make_pipeline(smem_buf, params_pipeline); -+ -+ auto cluster_wait_fn = [&] () { -+ // We need this to guarantee that the Pipeline init is visible -+ // To all producers and consumer thread blocks in the Cluster -+ if constexpr (size(ClusterShape{}) > 1) { -+ cute::cluster_arrive_relaxed(); -+ return [] () { cute::cluster_wait(); }; -+ } -+ else { -+ __syncthreads(); -+ return [] () {}; // do nothing -+ } -+ } (); -+ -+ // Preconditions -+ static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. 
If batch mode is not needed, set L stride to Int<0>."); -+ -+ // Separate out problem shape for convenience -+ // Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK) -+ auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); -+ auto M = get<0>(problem_shape_MNKL); -+ auto N = get<1>(problem_shape_MNKL); -+ auto K = get<2>(problem_shape_MNKL); -+ auto L = get<3>(problem_shape_MNKL); -+ -+ // TMA requires special handling of strides to deal with coord codomain mapping -+ // Represent the full tensors -- get these from TMA -+ Tensor mA_mkl = params.mainloop.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) -+ Tensor mB_nkl = params.mainloop.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) -+ -+ // Get the appropriate blocks for this thread block -- potential for thread block locality -+ auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) -+ auto blk_coord = make_coord(_,_,_); // (m,n,k) -- defer the slice -+ -+ // Make tiled views -+ Tensor gA_mkl = local_tile(mA_mkl, blk_shape, blk_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) -+ Tensor gB_nkl = local_tile(mB_nkl, blk_shape, blk_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) -+ -+ // Compute m_coord, n_coord, and l_coord with their post-tiled shapes -+ auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); -+ auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl)); -+ auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl)); -+ auto output_tile_coord = make_coord(m_coord, n_coord, _, l_coord); -+ -+ // Slice with m_coord and n_coord -+ Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) -+ Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) -+ -+ auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); -+ auto k_tile_count = size<2>(gA); -+ -+ // Wait for all thread blocks in the Cluster -+ cluster_wait_fn(); -+ -+ // In a warp specialized kernel, CollectiveMainloop exposes data movement and compute operations separately -+ CollectiveMainloop collective_mainloop; -+ -+ if (warp_group_role == WarpGroupRole::Producer) { -+ // For the DMA (prologue) - we start with an opposite phase - since we skip all waits -+ // i.e., we know that the buffer is indeed empty -+ typename CollectiveMainloop::PipelineState smem_pipe_write = cutlass::make_producer_start_state(); -+ collective_mainloop.dma( -+ pipeline, -+ smem_pipe_write, -+ gA, params.mainloop.tma_load_a, -+ gB, params.mainloop.tma_load_b, -+ k_tile_iter, k_tile_count, -+ thread_idx, -+ smem_buf -+ ); -+ // Update starting pipeline state for the next tile -+ smem_pipe_write.advance(k_tile_count); -+ // Make sure all Consumer Warp Groups have been waited upon -+ collective_mainloop.dma_epilogue(pipeline, smem_pipe_write); -+ } -+ else if (warp_group_role == WarpGroupRole::Consumer) { -+ typename CollectiveMainloop::PipelineState smem_pipe_read; -+ TiledMma tiled_mma; -+ Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) -+ clear(accumulators); -+ -+ collective_mainloop.mma( -+ pipeline, -+ smem_pipe_read, -+ accumulators, -+ k_tile_count, -+ thread_idx, -+ smem_buf, -+ params.mainloop -+ ); -+ -+ constexpr int BLK_M_RANK = rank<0>(blk_shape); -+ bool m_oob = int(blockIdx.x) >= size<2>(gA_mkl); -+ auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { -+ return m_oob ? 
0 : get(M) - get<0,i>(blk_shape) * get(m_coord); -+ })); -+ -+ constexpr int BLK_N_RANK = rank<1>(blk_shape); -+ bool n_oob = int(blockIdx.y) >= size<2>(gB_nkl); -+ auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { -+ return n_oob ? 0 : get(N) - get<1,i>(blk_shape) * get(n_coord); -+ })); -+ auto residue_mnk = make_tuple(m_max_coord, n_max_coord, Int<0>{}); -+ -+ // Epilogue and write to gD -+ CollectiveEpilogue epilogue{params.epilogue}; -+ epilogue( -+ problem_shape_MNKL, -+ blk_shape, -+ output_tile_coord, -+ accumulators, -+ tiled_mma, -+ residue_mnk, -+ warp_group_thread_idx, -+ smem_buf -+ ); -+ } -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::kernel -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_persistent.hpp b/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_persistent.hpp -new file mode 100644 -index 0000000..498bfad ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_persistent.hpp -@@ -0,0 +1,487 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/kernel_hardware_info.hpp" -+#include "cutlass/fast_math.h" -+#include "cute/arch/cluster_sm90.hpp" -+#include "cutlass/arch/reg_reconfig.h" -+#include "cutlass/arch/mma_sm90.h" -+#include "cutlass/pipeline.hpp" -+#include "cutlass/trace.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/dispatch_policy.hpp" -+#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" -+ -+#include "cute/tensor.hpp" -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm::kernel { -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ class ProblemShape_, -+ class CollectiveMainloop_, -+ class CollectiveEpilogue_, -+ class GridSwizzle_ -+> -+class GemmUniversal< -+ ProblemShape_, -+ CollectiveMainloop_, -+ CollectiveEpilogue_, -+ GridSwizzle_, -+ std::enable_if_t>> -+{ -+public: -+ // -+ // Type Aliases -+ // -+ using ProblemShape = ProblemShape_; -+ using GridSwizzle = GridSwizzle_; -+ static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, -+ "ProblemShape{} should be or "); -+ -+ // Mainloop derived types -+ using CollectiveMainloop = CollectiveMainloop_; -+ using TileShape = typename CollectiveMainloop::TileShape; -+ using TiledMma = typename CollectiveMainloop::TiledMma; -+ using ArchTag = typename CollectiveMainloop::ArchTag; -+ using ElementA = typename CollectiveMainloop::ElementA; -+ using StrideA = typename CollectiveMainloop::StrideA; -+ using ElementB = typename CollectiveMainloop::ElementB; -+ using StrideB = typename CollectiveMainloop::StrideB; -+ using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; -+ using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; -+ using ClusterShape = typename DispatchPolicy::ClusterShape; -+ using MainloopParams = typename CollectiveMainloop::Params; -+ static_assert(ArchTag::kMinComputeCapability >= 90); -+ -+ // Epilogue derived types -+ using CollectiveEpilogue = CollectiveEpilogue_; -+ using ElementC = typename CollectiveEpilogue::ElementC; -+ using StrideC = typename CollectiveEpilogue::StrideC; -+ using ElementD = typename CollectiveEpilogue::ElementD; -+ using StrideD = typename CollectiveEpilogue::StrideD; -+ using EpilogueParams = typename CollectiveEpilogue::Params; -+ static_assert(std::is_same_v, -+ "Mainloop and epilogue do not agree on accumulator value type."); -+ -+ static constexpr uint32_t NumDmaWarpGroups = 1; -+ static constexpr uint32_t NumMmaWarpGroups = 2; -+ static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}) + (NumMmaWarpGroups * NumThreadsPerWarpGroup); -+ static constexpr uint32_t MinBlocksPerMultiprocessor = 1; -+ -+ /// Register requirement for DMA and MATH WGs -+ static constexpr uint32_t DmaRegisterRequirement = 40; -+ static constexpr uint32_t MmaRegisterRequirement = 232; -+ -+ /* Order Sequence barrier with two stages: one for Mainloop and one for Epilogue */ -+ static constexpr uint32_t StagesPerMathWarpGroup = 2; -+ using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier< -+ StagesPerMathWarpGroup, NumMmaWarpGroups>; -+ -+ // Kernel level shared memory storage -+ struct SharedStorage { -+ using MainloopSharedStorage = typename CollectiveMainloop::SharedStorage; -+ using EpilogueSharedStorage = typename CollectiveEpilogue::SharedStorage; -+ using 
MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage; -+ -+ MainloopSharedStorage mainloop; -+ EpilogueSharedStorage epilogue; -+ alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order_barrier_storage; -+ }; -+ -+ static constexpr int SharedStorageSize = sizeof(SharedStorage); -+ -+ // Device side arguments -+ struct Arguments { -+ GemmUniversalMode mode{}; -+ ProblemShape problem_shape{}; -+ ElementA const* ptr_A = nullptr; -+ StrideA dA{}; -+ ElementB const* ptr_B = nullptr; -+ StrideB dB{}; -+ EpilogueParams epilogue_params{}; -+ KernelHardwareInfo hw_info; -+ }; -+ -+ // Kernel entry point API -+ struct Params { -+ GemmUniversalMode mode; -+ ProblemShape problem_shape; -+ MainloopParams mainloop; -+ EpilogueParams epilogue; -+ KernelHardwareInfo hw_info; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ // Convert to underlying arguments. In this case, a simple copy for the aliased type. -+ static -+ Params -+ to_underlying_arguments(Arguments const& args, void* workspace) { -+ CUTLASS_TRACE_HOST("to_underlying_arguments():"); -+ -+ (void) workspace; -+ auto problem_shape = args.problem_shape; -+ if constexpr (detail::IF_SWAP_AB::value) { -+ // swap M/N -+ get<0>(problem_shape) = get<1>(args.problem_shape); -+ get<1>(problem_shape) = get<0>(args.problem_shape); -+ } -+ -+ // Get SM count if needed, otherwise use user supplied SM count -+ int sm_count = args.hw_info.sm_count; -+ if (sm_count <= 0) { -+ CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" -+ " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); -+ sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); -+ } -+ -+ CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); -+ return { -+ args.mode, -+ problem_shape, -+ CollectiveMainloop::to_underlying_arguments(args, workspace), -+ CollectiveEpilogue::to_underlying_arguments(args, workspace), -+ {args.hw_info.device_id, sm_count} -+ }; -+ } -+ -+ CUTLASS_HOST_DEVICE static -+ bool -+ can_implement(Arguments const& args) { -+ bool implementable = args.mode == GemmUniversalMode::kGemm or -+ (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); -+ -+ // Number of blocks per problem (without batch) must not exceed 2^31 for the persistent scheduler to calculate using FastDivmod -+ auto problem_shape_MNKL = append<4>(args.problem_shape, Int<1>{}); -+ auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = -+ detail::PersistentTileSchedulerSm90::get_tiled_blk_shape_mnl(problem_shape_MNKL, TileShape{}, ClusterShape{}); -+ uint64_t problem_blocks = problem_blocks_m * problem_blocks_n * problem_blocks_l; -+ implementable = implementable && (problem_blocks < (uint64_t(1) << 31)); -+ -+ return implementable; -+ } -+ -+ static -+ int -+ get_workspace_size(Arguments const& args) { -+ return 0; -+ } -+ -+ // Computes the kernel launch grid shape based on runtime parameters -+ static constexpr -+ dim3 -+ get_grid_shape(Params const& params) { -+ int sm_count = params.hw_info.sm_count; -+ CUTLASS_TRACE_HOST("get_grid_shape(): Persistent schedule grid plan using SM count = " << sm_count); -+ -+ // Compute the total number of output tiles our problem has -+ auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); -+ auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = -+ detail::PersistentTileSchedulerSm90::get_tiled_blk_shape_mnl(problem_shape_MNKL, TileShape{}, ClusterShape{}); 
-+ int problem_blocks_total = problem_blocks_m * problem_blocks_n * problem_blocks_l; -+ -+ // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently -+ dim3 launch_grid(1, cute::size<1>(ClusterShape{}), 1); -+ -+ // The else path is generic, however, we can avoid some divs if we know Cluster size is 1 -+ if constexpr (size(ClusterShape{}) == 1) { -+ launch_grid.x = std::min(sm_count, problem_blocks_total); -+ } -+ else { -+ /* -+ * Optimal grid size calculation is based on -+ * GH100: 8 GPCs, 72 TPCs (9 TPCs/GPC), 2 SMs/TPC, 144 SMs per full GPU -+ * Hence, maximum SMs per GPC = 18 -+ */ -+ constexpr int max_sm_per_gpc = 18; -+ // Provided SM count could possibly be less than the assumed maximum SMs per GPC -+ int min_num_gpc = sm_count < max_sm_per_gpc ? 1 : sm_count / max_sm_per_gpc; -+ int max_blk_occupancy_per_gpc = max_sm_per_gpc - (max_sm_per_gpc % size(ClusterShape{})); -+ int blk_per_device = min_num_gpc * max_blk_occupancy_per_gpc; -+ -+ launch_grid.x = std::min( -+ blk_per_device / size<1>(ClusterShape{}), -+ problem_blocks_total / size<1>(ClusterShape{})); -+ } -+ -+ return launch_grid; -+ } -+ -+ static constexpr -+ dim3 -+ get_block_shape() { -+ return dim3(MaxThreadsPerBlock, 1, 1); -+ } -+ -+ CUTLASS_DEVICE -+ void -+ operator()(Params const& params, char* smem_buf) { -+ using namespace cute; -+ using X = Underscore; -+ -+ // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. -+ #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) -+ if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { -+ printf("ERROR : Arch conditional MMA instruction used without targetting sm90a compute capability. Aborting.\n"); -+ return; -+ } -+ #endif -+ -+ // Preconditions -+ static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. 
If batch mode is not needed, set L stride to Int<0>."); -+ -+ enum class WarpGroupRole { -+ Producer = 0, -+ Consumer0 = 1, -+ Consumer1 = 2 -+ }; -+ -+ // Kernel level shared memory storage -+ SharedStorage& shared_storage = *reinterpret_cast(smem_buf); -+ -+ int thread_idx = int(threadIdx.x); -+ int warp_idx = canonical_warp_idx(); -+ int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; -+ auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); -+ int lane_predicate = cute::elect_one_sync(); -+ -+ // Issue Tma Descriptor Prefetch from a single thread -+ if ((warp_idx == 0) && lane_predicate) { -+ CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); -+ } -+ -+ using Pipeline = typename CollectiveMainloop::MainloopPipeline; -+ using PipelineParams = typename CollectiveMainloop::PipelineParams; -+ PipelineParams params_pipeline; -+ params_pipeline.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; -+ if (warp_group_role == WarpGroupRole::Producer) { -+ params_pipeline.role = Pipeline::ThreadCategory::Producer; -+ } -+ else { -+ params_pipeline.role = Pipeline::ThreadCategory::Consumer; -+ } -+ params_pipeline.is_leader = warp_group_thread_idx == 0; -+ params_pipeline.num_consumers = NumThreadsPerWarpGroup; -+ -+ // Initialize pipeline and setup starting pipeline state for the collectives -+ Pipeline pipeline = CollectiveMainloop::make_pipeline(smem_buf, params_pipeline); -+ typename CollectiveMainloop::PipelineState collective_start_state_pipe; -+ -+ typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier; -+ // DMA WG will not participate in these Ordered Barrier syncs -+ params_math_wg_order_barrier.group_id = canonical_warp_group_idx() - static_cast(WarpGroupRole::Consumer0); -+ params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group -+ MathWarpGroupOrderBarrier math_wg_order_barrier(shared_storage.math_wg_order_barrier_storage, params_math_wg_order_barrier); -+ -+ auto cluster_wait_fn = [&] () { -+ // We need this to guarantee that the Pipeline init is visible -+ // To all producers and consumer thread blocks in the Cluster -+ if constexpr (size(ClusterShape{}) > 1) { -+ cute::cluster_arrive_relaxed(); -+ return [] () { cute::cluster_wait(); }; -+ } -+ else { -+ __syncthreads(); -+ return [] () {}; // do nothing -+ } -+ } (); -+ -+ // Separate out problem shape for convenience -+ // Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK) -+ auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); -+ auto M = get<0>(problem_shape_MNKL); -+ auto N = get<1>(problem_shape_MNKL); -+ auto K = get<2>(problem_shape_MNKL); -+ auto L = get<3>(problem_shape_MNKL); -+ -+ // TMA requires special handling of strides to deal with coord codomain mapping -+ // Represent the full tensors -- get these from TMA -+ Tensor mA_mkl = params.mainloop.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) -+ Tensor mB_nkl = params.mainloop.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) -+ -+ // Get the appropriate blocks for this thread block -- potential for thread block locality -+ auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) -+ auto blk_coord = make_coord(_,_,_); // (m,n,k) -- defer the slice -+ -+ // Slice to get the tiles this thread block is responsible for -+ Tensor gA_mkl = local_tile(mA_mkl, blk_shape, blk_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) -+ Tensor gB_nkl = local_tile(mB_nkl, blk_shape, blk_coord, Step< 
X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) -+ -+ // Get iterations along k-dimension -+ auto k_tile_count = size<3>(gA_mkl); -+ -+ detail::PersistentTileSchedulerSm90 scheduler(problem_shape_MNKL, blk_shape, ClusterShape{}); -+ -+ if (warp_group_role == WarpGroupRole::Consumer1) { -+ /* Advance 2nd Math WG to the next work tile for the startup */ -+ scheduler.advance_to_next_work(); -+ /* Advance 2nd Math WG pipeline state to the end of 1st Math WG */ -+ collective_start_state_pipe.advance(k_tile_count); -+ } -+ auto work_tile_info = scheduler.get_current_work(); -+ -+ // Perform the collective scoped MMA -+ CollectiveMainloop collective_mainloop; -+ -+ // Wait for all thread blocks in the Cluster -+ cluster_wait_fn(); -+ -+ if (warp_group_role == WarpGroupRole::Producer) { -+ cutlass::arch::warpgroup_reg_dealloc(); -+ -+ // For the DMA (prologue) - we start with an opposite phase - since we skip all waits -+ // i.e., we know that the buffer is indeed empty -+ typename CollectiveMainloop::PipelineState smem_pipe_write = cutlass::make_producer_start_state(); -+ while (work_tile_info.is_valid_tile) { -+ // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape -+ auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); -+ auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); -+ auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); -+ auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); -+ -+ // Slice with our work tile coordinates to construct mainloop tensor views -+ Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) -+ Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) -+ -+ auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); -+ -+ collective_mainloop.dma( -+ pipeline, -+ smem_pipe_write, -+ gA, params.mainloop.tma_load_a, -+ gB, params.mainloop.tma_load_b, -+ k_tile_iter, k_tile_count, -+ thread_idx, -+ reinterpret_cast(&shared_storage.mainloop) -+ ); -+ // Update starting pipeline state for the next tile -+ smem_pipe_write.advance(k_tile_count); -+ scheduler.advance_to_next_work(); -+ work_tile_info = scheduler.get_current_work(); -+ } // Scheduler work fetch loop -+ -+ // Make sure all Consumer Warp Groups have been waited upon -+ collective_mainloop.dma_epilogue(pipeline, smem_pipe_write); -+ } // Producer Warp Group End -+ -+ else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { -+ // Allocate the tiled_mma and the accumulators for the (M,N) blk_shape -+ cutlass::arch::warpgroup_reg_alloc(); -+ -+ while (work_tile_info.is_valid_tile) { -+ // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape -+ auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); -+ auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); -+ auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); -+ auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); -+ -+ // Slice with our work tile coordinates to construct mainloop tensor views -+ Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) -+ Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) -+ -+ auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); -+ -+ TiledMma tiled_mma; -+ Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) -+ clear(accumulators); -+ -+ /* Order two Math WG's MMA one after the other, helps hide Epilogue */ -+ math_wg_order_barrier.wait(); -+ -+ collective_mainloop.mma( -+ pipeline, -+ 
collective_start_state_pipe, -+ accumulators, -+ k_tile_count, -+ thread_idx, -+ reinterpret_cast(&shared_storage.mainloop), -+ params.mainloop -+ ); -+ -+ /* Cue for next Math WG's MMA to start */ -+ math_wg_order_barrier.arrive(); -+ -+ /* Order two Math WG's Epilogue one after the other */ -+ math_wg_order_barrier.wait(); -+ -+ constexpr int BLK_M_RANK = rank<0>(blk_shape); -+ bool m_oob = int(work_tile_info.M_idx) >= size<2>(gA_mkl); -+ auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { -+ return m_oob ? 0 : get(M) - get<0,i>(blk_shape) * get(m_coord); -+ })); -+ -+ constexpr int BLK_N_RANK = rank<1>(blk_shape); -+ bool n_oob = int(work_tile_info.N_idx) >= size<2>(gB_nkl); -+ auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { -+ return n_oob ? 0 : get(N) - get<1,i>(blk_shape) * get(n_coord); -+ })); -+ auto residue_mnk = make_tuple(m_max_coord, n_max_coord, Int<0>{}); -+ -+ // Epilogue and write to gD -+ CollectiveEpilogue epilogue{params.epilogue}; -+ epilogue( -+ problem_shape_MNKL, -+ blk_shape, -+ blk_coord, -+ accumulators, -+ tiled_mma, -+ residue_mnk, -+ warp_group_thread_idx, -+ reinterpret_cast(&shared_storage.epilogue) -+ ); -+ -+ /* Cue for next Math WG's Epilogue to start */ -+ math_wg_order_barrier.arrive(); -+ -+ // Update starting pipeline state for the next tile -+ collective_start_state_pipe.advance(k_tile_count * NumMmaWarpGroups); -+ -+ scheduler.advance_to_next_work(NumMmaWarpGroups); -+ work_tile_info = scheduler.get_current_work(); -+ } // Scheduler work fetch loop -+ } // Consumer Warp Groups End -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::kernel -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp b/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp -new file mode 100644 -index 0000000..496d5e0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp -@@ -0,0 +1,133 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/fast_math.h" -+#include "cute/layout.hpp" -+ -+namespace cutlass::gemm::kernel::detail { -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// Persistent Thread Block (TB) scheduler -+class PersistentTileSchedulerSm90 { -+ // -+ // Data members -+ // -+ -+private: -+ uint32_t blocks_per_problem_; -+ uint32_t current_work_linear_idx_; -+ uint32_t grid_blocks_total_; -+ -+ FastDivmod divmod_batch_; -+ FastDivmod divmod_grid_y_; -+ FastDivmod divmod_blk_m_; -+ -+ struct WorkTileInfo { -+ int32_t M_idx = 0; -+ int32_t N_idx = 0; -+ int32_t L_idx = 0; -+ uint32_t is_valid_tile = false; -+ }; -+ -+ // -+ // Methods -+ // -+ -+public: -+ -+ template -+ CUTLASS_DEVICE -+ PersistentTileSchedulerSm90(ProblemShapeMNKL problem_shape_mnkl, TileShape tile_shape, ClusterShape cluster_shape) { -+ // We only need the tile and cluster shape during scheduler setup, so let FTAD do the magic -+ static_assert(is_static::value); -+ static_assert(is_static::value); -+ -+ // Round up to nearest multiple of cluster dim along each mode -+ auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = get_tiled_blk_shape_mnl( -+ problem_shape_mnkl, tile_shape, cluster_shape); -+ -+ blocks_per_problem_ = problem_blocks_m * problem_blocks_n * problem_blocks_l; -+ current_work_linear_idx_ = (int(blockIdx.x) * int(gridDim.y)) + int(blockIdx.y); -+ grid_blocks_total_ = int(gridDim.x) * int(gridDim.y); -+ -+ // Pre-compute our fast div/mods for rasterization so we don't have to pay for DIVs -+ divmod_batch_ = FastDivmod(problem_blocks_m * problem_blocks_n); -+ divmod_grid_y_ = FastDivmod(size<1>(cluster_shape)); -+ divmod_blk_m_ = FastDivmod(problem_blocks_m); -+ } -+ -+ CUTLASS_DEVICE -+ WorkTileInfo -+ get_current_work() const { -+ // Map worker's linear index into the CTA tiled problem shape to the corresponding MNL indices -+ int work_idx_l, remainder; -+ divmod_batch_(work_idx_l, remainder, current_work_linear_idx_); -+ -+ int blk_per_grid_dim, dontcare; -+ divmod_grid_y_(blk_per_grid_dim, dontcare, remainder); -+ -+ int block_idx_m, block_idx_n; -+ divmod_blk_m_(block_idx_n, block_idx_m, blk_per_grid_dim); -+ int work_idx_m = block_idx_m; -+ int work_idx_n = (block_idx_n * gridDim.y) + blockIdx.y; -+ -+ return {work_idx_m, work_idx_n, work_idx_l, current_work_linear_idx_ < blocks_per_problem_}; -+ } -+ -+ CUTLASS_DEVICE -+ void -+ advance_to_next_work(uint32_t advance_count = 1) { -+ current_work_linear_idx_ += grid_blocks_total_ * advance_count; -+ } -+ -+ // Given the inputs, computes the total number of output blocks this problem will compute over -+ // Note that this is only the logical size of our grid, not the physical grid we will actually launch. 
-+ template -+ CUTLASS_HOST_DEVICE constexpr static -+ dim3 -+ get_tiled_blk_shape_mnl(ProblemShapeMNKL problem_shape_mnkl, BlockShape blk_shape, ClusterShape cluster_shape) { -+ // Across M and N is our Cluster tile, so we must round up the blocks to the nearest whole number of Cluster tiles -+ auto blk_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shape_mnkl), cute::shape<0>(blk_shape))); -+ auto blk_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shape_mnkl), cute::shape<1>(blk_shape))); -+ -+ // Round up to nearest multiple of cluster dim along each mode -+ int problem_blocks_m = round_up(blk_m, cute::size<0>(cluster_shape)); -+ int problem_blocks_n = round_up(blk_n, cute::size<1>(cluster_shape)); -+ -+ // Cluster tile does not span the batch mode, so no extra rounding up required for it -+ int problem_blocks_l = int(cute::size<3>(problem_shape_mnkl)); -+ return {uint32_t(problem_blocks_m), uint32_t(problem_blocks_n), uint32_t(problem_blocks_l)}; -+ } -+}; -+ -+} // namespace cutlass::gemm::kernel::detail -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/sparse_gemm.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/sparse_gemm.h -new file mode 100644 -index 0000000..eba95aa ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/sparse_gemm.h -@@ -0,0 +1,400 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/semaphore.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. -+> -+struct SparseGemm { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using OutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static bool const kSplitKSerial = SplitKSerial; -+ -+ static int const kSparse = Mma::kSparse; -+ static int const kMetaSizeInBits = Mma::kMetaSizeInBits; -+ static int const kMaxID2 = Mma::kMaxID2; -+ static int const kElementsPerElementE = Mma::kElementsPerElementE; -+ -+ using ElementE = typename Mma::ElementE; -+ using LayoutE = typename Mma::LayoutE; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Parameters structure -+ struct Params { -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorA::TensorRef ref_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Mma::IteratorB::TensorRef ref_B; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::TensorRef ref_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ typename Epilogue::OutputTileIterator::TensorRef ref_D; -+ typename Mma::IteratorE::Params params_E; -+ typename Mma::IteratorE::TensorRef ref_E; -+ typename OutputOp::Params output_op; -+ int *semaphore; -+ int gemm_k_iterations; -+ int gemm_k_size; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D, -+ typename Mma::IteratorE::TensorRef ref_E, -+ typename OutputOp::Params output_op = typename OutputOp::Params(), -+ int *workspace = nullptr -+ ): -+ problem_size(problem_size), -+ grid_tiled_shape(grid_tiled_shape), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A(ref_A.layout()), -+ ref_A(ref_A), -+ params_B(ref_B.layout()), -+ ref_B(ref_B), -+ params_C(ref_C.layout()), -+ ref_C(ref_C), -+ params_D(ref_D.layout()), -+ ref_D(ref_D), -+ params_E(ref_E.layout()), -+ ref_E(ref_E), -+ output_op(output_op) { -+ -+ int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); -+ -+ gemm_k_size = gemm_k_iterations * Mma::Shape::kK; -+ -+ semaphore = workspace; -+ } -+ }; -+ -+ /// Shared memory storage 
structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ SparseGemm() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D, -+ typename Mma::IteratorE::TensorRef ref_E) { -+ -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ static int const kAlignmentE = Mma::IteratorE::AccessType::kElements; -+ -+ if (!TensorRef_aligned(ref_A, kAlignmentA)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_B, kAlignmentB)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_C, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_D, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_E, kAlignmentE)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if ((problem_size.m() % kAlignmentA) || ((problem_size.k() / kSparse) % kAlignmentA) || -+ (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || -+ (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC) || -+ (problem_size.m() % kAlignmentE) || ((problem_size.k() / kSparse) % kAlignmentE)) { -+ -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ // The k dimension has to be the multiple of the Threadblock k because out -+ // of bound meta data would be initialized to 0 by acync.zfill but 0 is not -+ // a valid meta data. -+ if (problem_size.k() % Mma::Shape::kK) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ // M dimension has to be multiple of 32 (sparse float) or 16 (sparse int) -+ // because of the row reordering of operand E -+ static int const kAlignmentM = (sizeof(ElementE) == 2) ? 
32 : 16; -+ -+ if (problem_size.m() % kAlignmentM) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.k() * params.gemm_k_size / kSparse, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ threadblock_tile_offset.k() * params.gemm_k_size, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ cutlass::MatrixCoord tb_offset_E{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.k() * params.gemm_k_size / kSparse, -+ }; -+ -+ // Problem size is a function of threadblock index in the K dimension -+ int problem_size_k = min( -+ params.problem_size.k(), -+ (threadblock_tile_offset.k() + 1) * params.gemm_k_size); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - tb_offset_B.row() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A, B, and E operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ params.ref_A.data(), -+ {params.problem_size.m(), problem_size_k / kSparse}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ params.ref_B.data(), -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ typename Mma::IteratorE iterator_E( -+ params.params_E, params.ref_E.data(), -+ {params.problem_size.m(), -+ problem_size_k / kSparse / kElementsPerElementE}, -+ thread_idx, tb_offset_E); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ if (!kSplitKSerial || gemm_k_iterations > 0) { -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_E, accumulators); -+ } -+ -+ // -+ // Epilogue -+ // -+ -+ OutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ // Construct the semaphore. 
-+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ params.ref_C.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ params.ref_D.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ __threadfence(); -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(output_op, iterator_D, accumulators, iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ __threadfence(); -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/symm_universal.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/symm_universal.h -new file mode 100755 -index 0000000..47e7035 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/symm_universal.h -@@ -0,0 +1,698 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma1_, ///! Threadblock-scoped triangular matrix multiply-accumulate (A*B or B*A) -+ typename Mma2_, ///! Threadblock-scoped triangular matrix multiply-accumulate (AT*B or B*AT) -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ SideMode SideMode_, ///! Side Mode for the kernel (kLeft or kRight) -+ FillMode FillMode_ ///! 
Fill Mode for triangular matrix (kLower or kUpper) -+> -+struct SymmUniversal { -+public: -+ -+ using Mma1 = Mma1_; -+ using Mma2 = Mma2_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma1::IteratorA::Element; -+ using ElementB = typename Mma1::IteratorB::Element; -+ -+ // Mma1 (TRMM - with diagonal: C_tmp = alpha * A * B) -+ using LayoutA = typename Mma1::IteratorA::Layout; -+ using LayoutBT = typename Mma1::IteratorB::Layout; -+ static ComplexTransform const kMma1TransformA = Mma1::kTransformA; -+ static ComplexTransform const kMma1TransformB = Mma1::kTransformB; -+ -+ // Mma2 (TRMM - withOUT diagonal: alpha * AT * B) -+ using LayoutB = typename Mma2::IteratorA::Layout; -+ using LayoutAT = typename Mma2::IteratorB::Layout; -+ static ComplexTransform const kMma2TransformA = Mma2::kTransformA; -+ static ComplexTransform const kMma2TransformB = Mma2::kTransformB; -+ -+ // Common type definitions for Mma1 and Mma2 -+ using Operator = typename Mma1::Operator; -+ using OperatorClass = typename Mma1::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma1::Shape; -+ using WarpShape = typename Mma1::Operator::Shape; -+ using InstructionShape = typename Mma1::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma1::ArchTag; -+ -+ static int const kStages = Mma1::kStages; -+ static int const kAlignmentA = Mma1::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma1::IteratorB::AccessType::kElements; -+ -+ // Output related typedefinitions -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ static SideMode const kSideModeA = SideMode_; -+ static FillMode const kFillModeA = FillMode_; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma1::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmUniversalMode mode; -+ GemmCoord problem_size; -+ int batch_count; -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ void const * ptr_A; -+ void const * ptr_B; -+ void const * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+ -+ typename LayoutA::Stride::Index lda; -+ typename LayoutB::Stride::Index ldb; -+ typename LayoutC::Stride::Index ldc; -+ typename LayoutC::Stride::Index ldd; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ mode(GemmUniversalMode::kGemm), -+ batch_count(1), -+ ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr) { } -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride::Index lda, -+ typename LayoutB::Stride::Index ldb, -+ typename LayoutC::Stride::Index ldc, -+ typename LayoutC::Stride::Index ldd -+ ): -+ mode(mode), -+ problem_size(problem_size), -+ batch_count(batch_count), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), 
ptr_D(ptr_D), -+ batch_stride_A(batch_stride_A), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), -+ lda(lda), ldb(ldb), ldc(ldc), ldd(ldd) { -+ -+ } -+ -+ /// Returns arguments for the transposed problem sizes -+ Arguments transposed_problem_size() const { -+ Arguments args(*this); -+ -+ std::swap(args.problem_size.m(), args.problem_size.n()); -+ -+ return args; -+ } -+ -+ /// Returns arguments for the transposed matrices -+ Arguments swapped_matrices() const { -+ Arguments args(*this); -+ -+ std::swap(args.ptr_A, args.ptr_B); -+ std::swap(args.lda, args.ldb); -+ std::swap(args.batch_stride_A, args.batch_stride_B); -+ -+ return args; -+ } -+ }; -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params { -+ -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ -+ // Mma1 Iterator A and B params -+ typename Mma1::IteratorA::Params params_A_mma1; -+ typename Mma1::IteratorB::Params params_B_mma1; -+ -+ // Mma2 Iterator A and B params -+ typename Mma2::IteratorA::Params params_A_mma2; -+ typename Mma2::IteratorB::Params params_B_mma2; -+ -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ GemmUniversalMode mode; -+ int batch_count; -+ int gemm_k_size; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+ -+ int *semaphore; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ swizzle_log_tile(0), -+ params_A_mma1(0), -+ params_B_mma1(0), -+ params_A_mma2(0), -+ params_B_mma2(0), -+ params_C(0), -+ params_D(0), -+ batch_count(0), -+ gemm_k_size(0), -+ mode(cutlass::gemm::GemmUniversalMode::kGemm), -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ batch_stride_A(0), -+ batch_stride_B(0), -+ batch_stride_C(0), -+ batch_stride_D(0), -+ semaphore(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Arguments const &args, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape, -+ int gemm_k_size, -+ void *workspace = nullptr -+ ): -+ problem_size(args.problem_size), -+ grid_tiled_shape(grid_tiled_shape), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A_mma1(args.lda), -+ params_B_mma1(args.ldb), -+ params_A_mma2(args.lda), -+ params_B_mma2(args.ldb), -+ params_C(args.ldc), -+ params_D(args.ldd), -+ output_op(args.epilogue), -+ mode(args.mode), -+ batch_count(args.batch_count), -+ gemm_k_size(gemm_k_size), -+ ptr_A(const_cast(args.ptr_A)), -+ ptr_B(const_cast(args.ptr_B)), -+ ptr_C(const_cast(args.ptr_C)), -+ ptr_D(const_cast(args.ptr_D)), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ batch_stride_C(args.batch_stride_C), -+ batch_stride_D(args.batch_stride_D), -+ semaphore(static_cast(workspace)) { -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void update( -+ Arguments const &args, -+ void *workspace = nullptr) { -+ -+ ptr_A = const_cast(args.ptr_A); -+ ptr_B = const_cast(args.ptr_B); -+ ptr_C = const_cast(args.ptr_C); -+ ptr_D = args.ptr_D; -+ -+ output_op = args.epilogue; -+ -+ semaphore = static_cast(workspace); -+ } -+ -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma1::SharedStorage mma1_main_loop; -+ typename Mma2::SharedStorage mma2_main_loop; 
-+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ SymmUniversal() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) { -+ -+ static int const kAlignmentA = Mma1::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma1::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || -+ (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || -+ (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { -+ -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+ /// Executes two GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ return; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ -+ __syncthreads(); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_MxK_mma1{ -+ threadblock_tile_offset.m() * Mma1::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_KxN_mma1{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma1::Shape::kN -+ }; -+ -+ cutlass::MatrixCoord tb_offset_MxK_mma2{ -+ threadblock_tile_offset.m() * Mma1::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_KxN_mma2{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma1::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. 
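The two-line comment above is easy to read past: `threadIdx.x / 32` is already the warp index, but unless the compiler can prove the value is identical across the warp it may treat everything derived from it as divergent. `canonical_warp_idx()` broadcasts lane 0's copy for exactly that reason. A standalone `.cu` sketch of the pattern (assumption: this illustrates the idea, it is not the CUTLASS implementation):

```cpp
#include <cstdio>

__device__ int warp_uniform_warp_idx() {
  // Every lane computes the same quotient, but broadcasting lane 0's copy lets
  // the compiler treat the result as warp-uniform for dependent addressing.
  return __shfl_sync(0xffffffffu, static_cast<int>(threadIdx.x) / 32, 0);
}

__global__ void show_warp_idx() {
  int warp_idx = warp_uniform_warp_idx();
  if (threadIdx.x % 32 == 0) {
    printf("thread %d -> warp %d\n", threadIdx.x, warp_idx);
  }
}

int main() {
  show_warp_idx<<<1, 128>>>();
  cudaDeviceSynchronize();
  return 0;
}
```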
-+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply for Mma1 -+ Mma1 mma1(shared_storage.mma1_main_loop, thread_idx, warp_idx, lane_idx); -+ -+ // Construct thread-scoped matrix multiply for Mma2 -+ Mma2 mma2(shared_storage.mma2_main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma1::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma1::Shape::kK - 1) / Mma1::Shape::kK; -+ int gemm_k_iterations_mma1 = gemm_k_iterations; -+ int gemm_k_iterations_mma2 = gemm_k_iterations; -+ -+ -+ /****************************************************************************************************** -+ * SYMM (Side Mode, Fill Mode) is made of two TRMMs: -+ First TRMM (Mma1: Side Mode, Fill Mode, Non-Unit Diag): (A * B) or (B * A) -+ Second TRMM (Mma2: Side Mode, Inverted Fill Mode, Unit Diag): (AT * B) or (B * AT) -+ -+ * For the first TRMM (Mma1) of SYMM, the following method is used to calculate the k-iterations: -+ First two cases: (Left Side, Lower Fill) and (Right Side, Upper Fill) are transpose of each other -+ - (Left Side, Lower Fill): calculate bottom of the CTA tile, then find the k-iterations -+ needed to process all elements till that coordinate. -+ - (Right Side, Upper Fill): calculate right end of the CTA tile, then find the k-iterations -+ needed to process all elements till that coordinate. -+ -+ Last two cases: (Left Side, Upper Fill) and (Right Side, Lower Fill) are transpose of each other -+ - (Left Side, Upper Fill): calculate the top of the CTA tile, then find k-iterations -+ that can be skipped for all elements of this tile. -+ - (Right Side, Lower Fill): calculate the left start of the CTA tile, then find k-iterations -+ that can be skipped for all elements of this tile. -+ -+ * For the second TRMM (Mma2) of SYMM, the k-iterations and threadblock offsets are calculated -+ the same way as the first TRMM (Mma1) of same side mode but with inverted fill mode. -+ For example, if the first TRMM is left sided with lower fill, the second TRMM would be -+ left sided with upper fill. 
-+ ********************************************************************************************************/ -+ -+ if (kSideModeA == SideMode::kLeft && kFillModeA == FillMode::kLower) { -+ -+ int k_iterations_till_diagonal_mma1 = ((threadblock_tile_offset.m() + 1) * Mma1::Shape::kM + Mma1::Shape::kK - 1) / Mma1::Shape::kK; -+ if (k_iterations_till_diagonal_mma1 < gemm_k_iterations) { -+ gemm_k_iterations_mma1 = k_iterations_till_diagonal_mma1; -+ } -+ -+ int k_iterations_till_diagonal_mma2 = ((threadblock_tile_offset.m()) * Mma1::Shape::kM) / Mma1::Shape::kK; -+ if (k_iterations_till_diagonal_mma2 != 0) { -+ tb_offset_MxK_mma2 += cutlass::MatrixCoord({0, k_iterations_till_diagonal_mma2 * Mma1::Shape::kK}); -+ tb_offset_KxN_mma2 += cutlass::MatrixCoord({k_iterations_till_diagonal_mma2 * Mma1::Shape::kK, 0}); -+ gemm_k_iterations_mma2 -= k_iterations_till_diagonal_mma2; -+ } -+ -+ } else if (kSideModeA == SideMode::kRight && kFillModeA == FillMode::kUpper) { -+ -+ int k_iterations_till_diagonal_mma1 = ((threadblock_tile_offset.n() + 1) * Mma1::Shape::kN + Mma1::Shape::kK - 1) / Mma1::Shape::kK; -+ if (k_iterations_till_diagonal_mma1 < gemm_k_iterations) { -+ gemm_k_iterations_mma1 = k_iterations_till_diagonal_mma1; -+ } -+ -+ int k_iterations_till_diagonal_mma2 = ((threadblock_tile_offset.n()) * Mma1::Shape::kN) / Mma1::Shape::kK; -+ if (k_iterations_till_diagonal_mma2 != 0) { -+ tb_offset_MxK_mma2 += cutlass::MatrixCoord({0, k_iterations_till_diagonal_mma2 * Mma1::Shape::kK}); -+ tb_offset_KxN_mma2 += cutlass::MatrixCoord({k_iterations_till_diagonal_mma2 * Mma1::Shape::kK, 0}); -+ gemm_k_iterations_mma2 -= k_iterations_till_diagonal_mma2; -+ } -+ -+ } else if (kSideModeA == SideMode::kLeft && kFillModeA == FillMode::kUpper) { -+ -+ int k_iterations_till_diagonal_mma1 = ((threadblock_tile_offset.m()) * Mma1::Shape::kM) / Mma1::Shape::kK; -+ if (k_iterations_till_diagonal_mma1 != 0) { -+ tb_offset_MxK_mma1 += cutlass::MatrixCoord({0, k_iterations_till_diagonal_mma1 * Mma1::Shape::kK}); -+ tb_offset_KxN_mma1 += cutlass::MatrixCoord({k_iterations_till_diagonal_mma1 * Mma1::Shape::kK, 0}); -+ gemm_k_iterations_mma1 -= k_iterations_till_diagonal_mma1; -+ } -+ -+ int k_iterations_till_diagonal_mma2 = ((threadblock_tile_offset.m() + 1) * Mma1::Shape::kM + Mma1::Shape::kK - 1) / Mma1::Shape::kK; -+ if (k_iterations_till_diagonal_mma2 < gemm_k_iterations) { -+ gemm_k_iterations_mma2 = k_iterations_till_diagonal_mma2; -+ } -+ -+ } else if (kSideModeA == SideMode::kRight && kFillModeA == FillMode::kLower) { -+ -+ int k_iterations_till_diagonal_mma1 = ((threadblock_tile_offset.n()) * Mma1::Shape::kN) / Mma1::Shape::kK; -+ -+ if (k_iterations_till_diagonal_mma1 != 0) { -+ tb_offset_MxK_mma1 += cutlass::MatrixCoord({0, k_iterations_till_diagonal_mma1 * Mma1::Shape::kK}); -+ tb_offset_KxN_mma1 += cutlass::MatrixCoord({k_iterations_till_diagonal_mma1 * Mma1::Shape::kK, 0}); -+ gemm_k_iterations_mma1 -= k_iterations_till_diagonal_mma1; -+ } -+ -+ int k_iterations_till_diagonal_mma2 = ((threadblock_tile_offset.n() + 1) * Mma1::Shape::kN + Mma1::Shape::kK - 1) / Mma1::Shape::kK; -+ if (k_iterations_till_diagonal_mma2 < gemm_k_iterations) { -+ gemm_k_iterations_mma2 = k_iterations_till_diagonal_mma2; -+ } -+ -+ } -+ -+ // Construct iterators to A and B operands for Mma1 -+ typename Mma1::IteratorA iterator_A_mma1( -+ params.params_A_mma1, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_MxK_mma1); -+ -+ typename Mma1::IteratorB iterator_B_mma1( -+ params.params_B_mma1, -+ 
ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_KxN_mma1); -+ -+ // Construct iterators to A and B operands for Mma2 -+ typename Mma2::IteratorA iterator_A_mma2( -+ params.params_A_mma2, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_MxK_mma2); -+ -+ typename Mma2::IteratorB iterator_B_mma2( -+ params.params_B_mma2, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_KxN_mma2); -+ -+ // Compute threadblock-scoped matrix multiply-add (A x B) or (B x A) -+ mma1( -+ gemm_k_iterations_mma1, -+ accumulators, -+ iterator_A_mma1, -+ iterator_B_mma1, -+ accumulators); -+ -+ // Compute threadblock-scoped matrix multiply-add (AT x B) or (B x AT) -+ mma2( -+ gemm_k_iterations_mma2, -+ accumulators, -+ iterator_A_mma2, -+ iterator_B_mma2, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma1::Shape::kM, -+ threadblock_tile_offset.n() * Mma1::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C); -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; -+ ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; -+ } -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ ptr_C, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. 
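The comment above ("the source matrix is held in the 'D' tensor") is the heart of serial split-K: only partition 0 blends against the user-supplied C, while every later partition reads the running partial from D and, via `output_op.set_k_partition()`, applies an effective beta of 1 (assuming a LinearCombination-style epilogue). A small numeric sketch of that arithmetic:

```cpp
// Sketch of serial split-K epilogue math (assumption: LinearCombination-style
// output op; values are illustrative).
#include <cstdio>

int main() {
  constexpr int kPartitions = 3;
  const float alpha = 2.0f, beta = 0.5f;
  const float c_source = 4.0f;                                // one element of C
  const float partial_acc[kPartitions] = {1.0f, 2.0f, 3.0f};  // per-partition accumulators

  float d = 0.0f;                                             // the shared D element
  for (int k = 0; k < kPartitions; ++k) {
    // Partition 0 blends against the real source C with the user's beta; later
    // partitions take their "source" from D with beta forced to 1.
    float source = (k == 0) ? c_source : d;
    float beta_k = (k == 0) ? beta : 1.0f;
    d = alpha * partial_acc[k] + beta_k * source;
  }

  // Matches the single-pass result: alpha * (1 + 2 + 3) + beta * 4 = 14.
  std::printf("serial split-K result: %f\n", d);
  return 0;
}
```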
-+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ __threadfence(); -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/trmm_universal.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/trmm_universal.h -new file mode 100644 -index 0000000..7ba223b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/trmm_universal.h -@@ -0,0 +1,599 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ \brief -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/core_io.h" -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ SideMode SideMode_, ///! Side Mode for the kernel (kLeft or kRight) -+ FillMode FillMode_, ///! Fill Mode for triangular matrix (kLower or kUpper) -+ DiagType DiagType_ ///! Diag Type for triangular matrix (kNonUnit or kUnit) -+> -+struct TrmmUniversal { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ static SideMode const kSideMode = SideMode_; -+ static FillMode const kFillMode = FillMode_; -+ static DiagType const kDiagType = DiagType_; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ using Operator = typename Mma::Operator; -+ -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmUniversalMode mode; -+ GemmCoord problem_size; -+ int batch_count; -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ void const * ptr_A; -+ void const * ptr_B; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_D; -+ -+ typename LayoutA::Stride::Index lda; -+ typename LayoutB::Stride::Index ldb; -+ typename LayoutC::Stride::Index ldd; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ mode(GemmUniversalMode::kGemm), -+ batch_count(1), -+ ptr_A(nullptr), ptr_B(nullptr), ptr_D(nullptr) { } -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, 
-+ void const * ptr_A, -+ void const * ptr_B, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride::Index lda, -+ typename LayoutB::Stride::Index ldb, -+ typename LayoutC::Stride::Index ldd -+ ): -+ mode(mode), -+ problem_size(problem_size), -+ batch_count(batch_count), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_D(ptr_D), -+ batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_D(batch_stride_D), -+ lda(lda), ldb(ldb), ldd(ldd) { -+ } -+ -+ /// Returns arguments for the transposed problem sizes -+ Arguments transposed_problem_size() const { -+ Arguments args(*this); -+ -+ std::swap(args.problem_size.m(), args.problem_size.n()); -+ -+ return args; -+ } -+ -+ /// Returns arguments for the transposed matrices -+ Arguments swapped_matrices() const { -+ Arguments args(*this); -+ -+ std::swap(args.ptr_A, args.ptr_B); -+ std::swap(args.lda, args.ldb); -+ std::swap(args.batch_stride_A, args.batch_stride_B); -+ -+ return args; -+ } -+ }; -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params { -+ -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ GemmUniversalMode mode; -+ int batch_count; -+ int gemm_k_size; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_D; -+ -+ int *semaphore; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ swizzle_log_tile(0), -+ params_A(0), -+ params_B(0), -+ params_D(0), -+ batch_count(0), -+ gemm_k_size(0), -+ mode(cutlass::gemm::GemmUniversalMode::kGemm), -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_D(nullptr), -+ batch_stride_A(0), -+ batch_stride_B(0), -+ batch_stride_D(0), -+ semaphore(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Arguments const &args, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape, -+ int gemm_k_size, -+ void *workspace = nullptr -+ ): -+ problem_size(args.problem_size), -+ grid_tiled_shape(grid_tiled_shape), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A(args.lda), -+ params_B(args.ldb), -+ params_D(args.ldd), -+ output_op(args.epilogue), -+ mode(args.mode), -+ batch_count(args.batch_count), -+ gemm_k_size(gemm_k_size), -+ ptr_A(const_cast(args.ptr_A)), -+ ptr_B(const_cast(args.ptr_B)), -+ ptr_D(args.ptr_D), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ batch_stride_D(args.batch_stride_D), -+ semaphore(static_cast(workspace)) { -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void update( -+ Arguments const &args, -+ void *workspace = nullptr) { -+ -+ ptr_A = const_cast(args.ptr_A); -+ ptr_B = const_cast(args.ptr_B); -+ ptr_D = args.ptr_D; -+ -+ batch_stride_A = args.batch_stride_A; -+ batch_stride_B = args.batch_stride_B; -+ batch_stride_D = args.batch_stride_D; -+ -+ output_op = args.epilogue; -+ -+ semaphore = static_cast(workspace); -+ } -+ -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ TrmmUniversal() { } -+ -+ /// 
Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) { -+ -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || -+ (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || -+ (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { -+ -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; -+ ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; -+ ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; -+ } -+ -+ __syncthreads(); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ /****************************************************************************************************** -+ First two cases: (Left Side, Lower Fill) and (Right Side, Upper Fill) are transpose of each other -+ - (Left Side, Lower Fill): calculate bottom of the CTA tile, then find the k-iterations -+ needed to process all elements till that coordinate. 
-+ - (Right Side, Upper Fill): calculate right end of the CTA tile, then find the k-iterations -+ needed to process all elements till that coordinate. -+ -+ Last two cases: (Left Side, Upper Fill) and (Right Side, Lower Fill) are transpose of each other -+ - (Left Side, Upper Fill): calculate the top of the CTA tile, then find k-iterations -+ that can be skipped for all elements of this tile. -+ - (Right Side, Lower Fill): calculate the left start of the CTA tile, then find k-iterations -+ that can be skipped for all elements of this tile. -+ ********************************************************************************************************/ -+ -+ if (kSideMode == SideMode::kLeft && kFillMode == FillMode::kLower) { -+ -+ int k_iterations_till_diagonal = ((threadblock_tile_offset.m() + 1) * Mma::Shape::kM + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ if (k_iterations_till_diagonal < gemm_k_iterations) { -+ gemm_k_iterations = k_iterations_till_diagonal; -+ } -+ -+ } else if (kSideMode == SideMode::kRight && kFillMode == FillMode::kUpper) { -+ -+ int k_iterations_till_diagonal = ((threadblock_tile_offset.n() + 1) * Mma::Shape::kN + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ if (k_iterations_till_diagonal < gemm_k_iterations) { -+ gemm_k_iterations = k_iterations_till_diagonal; -+ } -+ -+ } else if (kSideMode == SideMode::kLeft && kFillMode == FillMode::kUpper) { -+ -+ int k_iterations_till_diagonal = ((threadblock_tile_offset.m()) * Mma::Shape::kM) / Mma::Shape::kK; -+ -+ if (k_iterations_till_diagonal != 0) { -+ tb_offset_A += cutlass::MatrixCoord({0, k_iterations_till_diagonal * Mma::Shape::kK}); -+ tb_offset_B += cutlass::MatrixCoord({k_iterations_till_diagonal * Mma::Shape::kK, 0}); -+ gemm_k_iterations -= k_iterations_till_diagonal; -+ } -+ -+ } else if (kSideMode == SideMode::kRight && kFillMode == FillMode::kLower) { -+ -+ int k_iterations_till_diagonal = ((threadblock_tile_offset.n()) * Mma::Shape::kN) / Mma::Shape::kK; -+ -+ if (k_iterations_till_diagonal != 0) { -+ tb_offset_A += cutlass::MatrixCoord({0, k_iterations_till_diagonal * Mma::Shape::kK}); -+ tb_offset_B += cutlass::MatrixCoord({k_iterations_till_diagonal * Mma::Shape::kK, 0}); -+ gemm_k_iterations -= k_iterations_till_diagonal; -+ } -+ -+ } -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // Construct the semaphore. 
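The semaphore construction and split-K handling below follow the same pattern as the GEMM kernels earlier in this patch; the TRMM-specific piece is the four-way dispatch above, which trims the K loop so a threadblock never streams tiles of the triangular operand that are identically zero. A host-side sketch of that trimming for the two left-side cases, assuming hypothetical 128x32 (M x K) threadblock tiles and K = 512 (the right-side cases apply the same formulas to the N tile index):

```cpp
// Host-side sketch of the k-iteration trimming above (assumption: illustrative
// tile sizes; the kernel uses Mma::Shape::{kM,kN,kK}).
#include <algorithm>
#include <cstdio>

int main() {
  constexpr int kTileM = 128, kTileK = 32;
  const int problem_k = 512;
  const int full_iterations = (problem_k + kTileK - 1) / kTileK;   // 16

  for (int tile_m = 0; tile_m < 4; ++tile_m) {
    // (Left side, Lower fill): everything right of the diagonal is zero, so a
    // CTA at row-tile tile_m only needs K tiles up to its bottom edge.
    int till_diagonal = ((tile_m + 1) * kTileM + kTileK - 1) / kTileK;
    int lower_iters = std::min(full_iterations, till_diagonal);

    // (Left side, Upper fill): everything left of the diagonal is zero, so the
    // CTA instead skips whole K tiles before its top edge.
    int skip = (tile_m * kTileM) / kTileK;
    int upper_iters = full_iterations - skip;

    std::printf("row-tile %d: lower-fill iters %d, upper-fill iters %d\n",
                tile_m, lower_iters, upper_iters);
  }
  return 0;
}
```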
-+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; -+ } -+ -+ -+ // Tile iterator loading from source tensor (although irrelevant to this kernel as beta is zero). -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ __threadfence(); -+ } -+ -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/thread/mma.h b/3rdparty/cutlass/include/cutlass/gemm/thread/mma.h -new file mode 100644 -index 0000000..d1f9b69 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/thread/mma.h -@@ -0,0 +1,90 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates exposing architecture support for warp-level multiply-add operations -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/arch/mma.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape, -+ /// Data type of A elements -+ typename ElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Element type of C matrix -+ typename ElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Concept: arch::OpMultiplyAdd or arch::Mma<> -+ typename Operator = arch::OpMultiplyAdd, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+struct Mma; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Overloads specialized for existing architectures -+// -+ -+#include "cutlass/gemm/thread/mma_sm50.h" -+#include "cutlass/gemm/thread/mma_sm60.h" -+#include "cutlass/gemm/thread/mma_sm61.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/thread/mma_sm50.h b/3rdparty/cutlass/include/cutlass/gemm/thread/mma_sm50.h -new file mode 100644 -index 0000000..1573e64 ---- /dev/null -+++ 
b/3rdparty/cutlass/include/cutlass/gemm/thread/mma_sm50.h -@@ -0,0 +1,539 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Templates exposing architecture support for multiply-add operations -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/arch/mma.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/thread/mma.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Gemplate that handles all packed matrix layouts -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: layout::MapFunc) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: layout::MapFunc) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: layout::MapFunc) -+ typename LayoutC_, -+ /// Operator used to compute GEMM -+ typename Operator_ -+> -+struct MmaGeneric { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// Data type of operand A -+ using ElementA = ElementA_; -+ -+ /// Layout of A matrix (concept: layout::MapFunc) -+ using LayoutA = LayoutA_; -+ -+ /// Data type of operand B -+ using ElementB = ElementB_; -+ -+ /// Layout of B matrix (concept: layout::MapFunc) -+ using LayoutB = LayoutB_; -+ -+ /// Element type of operand C -+ using ElementC = ElementC_; -+ -+ /// Layout of C matrix (concept: layout::MapFunc) -+ using LayoutC = LayoutC_; -+ -+ /// Underlying mathematical operator -+ using Operator = Operator_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Instruction -+ using MmaOp = arch::Mma< -+ gemm::GemmShape<1,1,1>, -+ 1, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ Operator>; -+ -+ static bool const kMultipleOf2 = ((Shape::kM % 2 == 0) && (Shape::kN % 2 == 0)); -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ TensorRef a_ref( -+ reinterpret_cast(&A), LayoutA::packed({Shape::kM, Shape::kK})); -+ -+ TensorRef b_ref( -+ reinterpret_cast(&B), LayoutB::packed({Shape::kK, Shape::kN})); -+ -+ TensorRef d_ref( -+ reinterpret_cast(&D), LayoutC::packed(make_Coord(Shape::kM, Shape::kN))); -+ -+ MmaOp mma_op; -+ -+ // Copy accumulators -+ D = C; -+ -+ // Compute matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Shape::kK; ++k) { -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 860) -+ if (kMultipleOf2 && -+ platform::is_same::value && -+ platform::is_same::value && -+ platform::is_same::value) { -+ -+ //2x2 zigzag - m and n loops to increment by 2. Inner loop to process 4 multiply-adds in a 2x2 tile. -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Shape::kN; n+=2) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Shape::kM; m+=2) { -+ -+ int m_serpentine = (n % 4) ? 
(Shape::kM - 2 - m) : m; -+ -+ //top-left element in 2x2 tile -+ { -+ MatrixCoord mn(m_serpentine, n); -+ MatrixCoord mk(m_serpentine, k); -+ MatrixCoord kn(k, n); -+ Array d; -+ Array a; -+ Array b; -+ d[0] = d_ref.at(mn); -+ a[0] = a_ref.at(mk); -+ b[0] = b_ref.at(kn); -+ mma_op(d, a, b, d); -+ d_ref.at(mn) = d[0]; -+ } -+ -+ //bottom-left element in 2x2 tile -+ { -+ MatrixCoord mn(m_serpentine+1, n); -+ MatrixCoord mk(m_serpentine+1, k); -+ MatrixCoord kn(k, n); -+ Array d; -+ Array a; -+ Array b; -+ d[0] = d_ref.at(mn); -+ a[0] = a_ref.at(mk); -+ b[0] = b_ref.at(kn); -+ mma_op(d, a, b, d); -+ d_ref.at(mn) = d[0]; -+ } -+ -+ //bottom-right element in 2x2 tile -+ { -+ MatrixCoord mn(m_serpentine+1, n+1); -+ MatrixCoord mk(m_serpentine+1, k); -+ MatrixCoord kn(k, n+1); -+ Array d; -+ Array a; -+ Array b; -+ d[0] = d_ref.at(mn); -+ a[0] = a_ref.at(mk); -+ b[0] = b_ref.at(kn); -+ mma_op(d, a, b, d); -+ d_ref.at(mn) = d[0]; -+ } -+ -+ //top-right element in 2x2 tile -+ { -+ MatrixCoord mn(m_serpentine, n+1); -+ MatrixCoord mk(m_serpentine, k); -+ MatrixCoord kn(k, n+1); -+ Array d; -+ Array a; -+ Array b; -+ d[0] = d_ref.at(mn); -+ a[0] = a_ref.at(mk); -+ b[0] = b_ref.at(kn); -+ mma_op(d, a, b, d); -+ d_ref.at(mn) = d[0]; -+ } -+ } -+ } -+ } else -+ #endif -+ { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Shape::kN; ++n) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Shape::kM; ++m) { -+ -+ int m_serpentine = (n % 2) ? (Shape::kM - 1 - m) : m; -+ -+ MatrixCoord mn(m_serpentine, n); -+ MatrixCoord mk(m_serpentine, k); -+ MatrixCoord kn(k, n); -+ -+ Array d; -+ Array a; -+ Array b; -+ -+ d[0] = d_ref.at(mn); -+ a[0] = a_ref.at(mk); -+ b[0] = b_ref.at(kn); -+ -+ mma_op(d, a, b, d); -+ -+ d_ref.at(mn) = d[0]; -+ } -+ } -+ } -+ } -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Matrix multiply-add operation - assumes operand B is not changing -+struct MmaComplexF32_Column { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using ElementC = complex; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array, 1> &d, -+ Array, 1> const &a, -+ Array, 1> const &b, -+ Array, 1> const &c -+ ) { -+ -+ d[0].real() = a[0].real() * b[0].real() + c[0].real(); -+ d[0].imag() = a[0].real() * b[0].imag() + d[0].imag(); -+ d[0].real() = -a[0].imag() * b[0].imag() + d[0].real(); -+ d[0].imag() = a[0].imag() * b[0].real() + c[0].imag(); -+ } -+}; -+ -+/// Matrix multiply-add operation - assumes operand A is not changing -+struct MmaComplexF32_Corner { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using ElementC = complex; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array, 1> &d, -+ Array, 1> const &a, -+ Array, 1> const &b, -+ Array, 1> const &c -+ ) { -+ -+ d[0].real() = -a[0].imag() * b[0].imag() + d[0].real(); -+ d[0].imag() = a[0].real() * b[0].imag() + d[0].imag(); -+ d[0].real() = a[0].real() * b[0].real() + c[0].real(); -+ d[0].imag() = a[0].imag() * b[0].real() + c[0].imag(); -+ } -+}; -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Gemplate that handles all packed matrix layouts -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Layout of A matrix (concept: layout::MapFunc) -+ typename LayoutA_, -+ /// Layout of B matrix (concept: layout::MapFunc) -+ typename LayoutB_, -+ /// Layout of C matrix (concept: layout::MapFunc) -+ typename LayoutC_ -+> -+struct 
MmaGeneric< -+ Shape_, -+ complex, -+ LayoutA_, -+ complex, -+ LayoutB_, -+ complex, -+ LayoutC_, -+ arch::OpMultiplyAdd> { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// Data type of operand A -+ using ElementA = complex; -+ -+ /// Layout of A matrix (concept: layout::MapFunc) -+ using LayoutA = LayoutA_; -+ -+ /// Data type of operand B -+ using ElementB = complex; -+ -+ /// Layout of B matrix (concept: layout::MapFunc) -+ using LayoutB = LayoutB_; -+ -+ /// Element type of operand C -+ using ElementC = complex; -+ -+ /// Layout of C matrix (concept: layout::MapFunc) -+ using LayoutC = LayoutC_; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Instruction -+ using MmaOp = arch::Mma< -+ gemm::GemmShape<1,1,1>, -+ 1, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ Operator>; -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ TensorRef a_ref( -+ reinterpret_cast(&A), LayoutA::packed({Shape::kM, Shape::kK})); -+ -+ TensorRef b_ref( -+ reinterpret_cast(&B), LayoutB::packed({Shape::kK, Shape::kN})); -+ -+ TensorRef d_ref( -+ reinterpret_cast(&D), LayoutC::packed(make_Coord(Shape::kM, Shape::kN))); -+ -+ detail::MmaComplexF32_Column mma_column; -+ detail::MmaComplexF32_Corner mma_corner; -+ -+ // Copy accumulators -+ D = C; -+ -+ // Compute matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Shape::kK; ++k) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Shape::kN; ++n) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Shape::kM; ++m) { -+ -+ int m_serpentine = (n % 2) ? 
(Shape::kM - 1 - m) : m; -+ -+ MatrixCoord mn(m_serpentine, n); -+ MatrixCoord mk(m_serpentine, k); -+ MatrixCoord kn(k, n); -+ -+ Array d; -+ Array a; -+ Array b; -+ -+ d[0] = d_ref.at(mn); -+ a[0] = a_ref.at(mk); -+ b[0] = b_ref.at(kn); -+ -+ if ((m == 0 && n) || m == Shape::kM - 1) { -+ mma_corner(d, a, b, d); -+ } -+ else { -+ mma_column(d, a, b, d); -+ } -+ -+ d_ref.at(mn) = d[0]; -+ } -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Gemplate that handles conventional layouts for FFMA and DFMA GEMM -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: layout::MapFunc) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: layout::MapFunc) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: layout::MapFunc) -+ typename LayoutC_ -+> -+struct Mma< -+ Shape_, -+ ElementA_, -+ LayoutA_, -+ ElementB_, -+ LayoutB_, -+ ElementC_, -+ LayoutC_, -+ arch::OpMultiplyAdd, -+ bool> { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// Data type of operand A -+ using ElementA = ElementA_; -+ -+ /// Layout of A matrix (concept: layout::MapFunc) -+ using LayoutA = LayoutA_; -+ -+ /// Data type of operand B -+ using ElementB = ElementB_; -+ -+ /// Layout of B matrix (concept: layout::MapFunc) -+ using LayoutB = LayoutB_; -+ -+ /// Element type of operand C -+ using ElementC = ElementC_; -+ -+ /// Layout of C matrix (concept: layout::MapFunc) -+ using LayoutC = LayoutC_; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename MmaGeneric< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Operator>::MmaOp; -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ MmaGeneric< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Operator> mma; -+ -+ mma(D, A, B, C); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/thread/mma_sm60.h b/3rdparty/cutlass/include/cutlass/gemm/thread/mma_sm60.h -new file mode 100644 -index 0000000..e4bcb70 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/thread/mma_sm60.h -@@ -0,0 +1,1178 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
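// Illustrative sketch (not part of the upstream patch): the complex<float>
// specialization accumulates d = a * b + c with four real fused multiply-adds,
// split across the "column" and "corner" helpers above. A plain host-side
// equivalent of the combined update, using std::complex purely for illustration.
#include <complex>

std::complex<float> complex_fma(std::complex<float> a,
                                std::complex<float> b,
                                std::complex<float> c) {
  float re = a.real() * b.real() + c.real();  // partial real: ar*br + cr
  float im = a.real() * b.imag() + c.imag();  // partial imag: ar*bi + ci
  re = -a.imag() * b.imag() + re;             // real -= ai*bi
  im = a.imag() * b.real() + im;              // imag += ai*br
  return {re, im};                            // equals a*b + c
}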
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates exposing architecture support for multiply-add operations -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/thread/mma.h" -+#include "cutlass/functional.h" -+#include "cutlass/reduction/thread/reduce.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Structure to compute the matrix product for HFMA -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape, -+ -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ -+ /// Type of GEMM inner vs outer product -+ bool -+> -+struct Mma_HFMA2; -+ -+ -+///////////////////////////// -+// Specialization for NNN // -+///////////////////////////// -+ -+template -+struct Mma_HFMA2 < -+ Shape_, -+ layout::ColumnMajor, -+ layout::ColumnMajor, -+ layout::ColumnMajor, -+ true -+ > { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ static_assert( -+ !(Shape::kM % 2), -+ "Mma_HFMA2 requires the M dimension to be divisible by 2." 
-+ ); -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ /// Initialize output with input -+ D = C; -+ -+ /// Use 1x1x1 HFMA2 sequence for bulk of computation -+ using Mma = arch::Mma< -+ gemm::GemmShape<2,1,1>, -+ 1, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::ColumnMajor, -+ arch::OpMultiplyAdd>; -+ -+ Array *ptr_D = reinterpret_cast *>(&D); -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ Mma mma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ -+ -+ Array tmp; -+ Array *ptr_tmp = &tmp; -+ ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m]; -+ -+ mma( -+ tmp, -+ ptr_A[k*Shape::kM/2 + m], -+ ptr_B[n*Shape::kK + k], -+ tmp); -+ -+ ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0]; -+ } -+ } -+ } -+ } -+}; -+ -+///////////////////////////// -+// Specialization for NNT // -+///////////////////////////// -+ -+template -+struct Mma_HFMA2< -+ Shape_, -+ layout::ColumnMajor, -+ layout::ColumnMajor, -+ layout::RowMajor, -+ true -+ > { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ static_assert( -+ !(Shape::kN % 2), -+ "Mma_HFMA2 requires the N dimension to be divisible by 2." 
-+ ); -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ /// Initialize output with input -+ D = C; -+ -+ /// Use 1x2x1 HFMA2 sequence for bulk of computation -+ using Mma = arch::Mma< -+ gemm::GemmShape<1,2,1>, -+ 1, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ arch::OpMultiplyAdd>; -+ -+ Array *ptr_D = reinterpret_cast *>(&D); -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ Mma mma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ -+ -+ Array tmp; -+ Array *ptr_tmp = &tmp; -+ ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n]; -+ -+ Array tmp_B; -+ tmp_B[0] = ptr_B->at(2*n*Shape::kK + k); -+ tmp_B[1] = ptr_B->at((2*n+1)*Shape::kK + k); -+ -+ mma( -+ tmp, -+ ptr_A[k*Shape::kM + m], -+ tmp_B, -+ tmp); -+ -+ ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0]; -+ } -+ } -+ } -+ } -+}; -+ -+ -+///////////////////////////// -+// Specialization for NTN // -+///////////////////////////// -+ -+template -+struct Mma_HFMA2 < -+ Shape_, -+ layout::ColumnMajor, -+ layout::RowMajor, -+ layout::ColumnMajor, -+ true -+ > { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ static_assert( -+ !(Shape::kM % 2), -+ "Mma_HFMA2 requires the GEMM M dimension to be divisible by 2." 
-+ ); -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ /// Initialize output with input -+ D = C; -+ -+ using Mma = arch::Mma< -+ gemm::GemmShape<2,1,1>, -+ 1, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ arch::OpMultiplyAdd>; -+ -+ Array *ptr_D = reinterpret_cast *>(&D); -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ Mma mma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Shape::kK / Mma::Shape::kK; ++k) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Shape::kM / Mma::Shape::kM; ++m) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Shape::kN / Mma::Shape::kN; ++n) { -+ -+ Array tmp; -+ Array *ptr_tmp = &tmp; -+ -+ ptr_tmp[0] = ptr_D[m + n * Shape::kM/2]; -+ -+ mma( -+ tmp, -+ ptr_A[m + k * Shape::kM/2], -+ ptr_B[k * Shape::kN + n], -+ tmp); -+ -+ ptr_D[m + n * Shape::kM/2] = ptr_tmp[0]; -+ } -+ } -+ } -+ } -+}; -+ -+///////////////////////////// -+// Specialization for NTT // -+///////////////////////////// -+ -+template -+struct Mma_HFMA2< -+ Shape_, -+ layout::ColumnMajor, -+ layout::RowMajor, -+ layout::RowMajor, -+ true -+ > { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ static_assert( -+ !(Shape::kN % 2), -+ "Mma_HFMA2 requires the N dimension to be divisible by 2." 
-+ ); -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ /// Initialize output with input -+ D = C; -+ -+ /// Use 1x2x1 HFMA2 sequence for bulk of computation -+ using Mma = arch::Mma< -+ gemm::GemmShape<1,2,1>, -+ 1, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::RowMajor, -+ arch::OpMultiplyAdd>; -+ -+ Array *ptr_D = reinterpret_cast *>(&D); -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ Mma mma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ -+ -+ Array tmp; -+ Array *ptr_tmp = &tmp; -+ ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n]; -+ -+ mma( -+ tmp, -+ ptr_A[k*Shape::kM + m], -+ ptr_B[k*Shape::kN/2 + n], -+ tmp); -+ -+ ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0]; -+ } -+ } -+ } -+ } -+}; -+ -+ -+///////////////////////////// -+// Specialization for TNN // -+///////////////////////////// -+ -+template -+struct Mma_HFMA2 < -+ Shape_, -+ layout::RowMajor, -+ layout::ColumnMajor, -+ layout::ColumnMajor, -+ true -+ > { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ static_assert( -+ !(Shape::kM % 2), -+ "Mma_HFMA2 requires the M dimension to be divisible by 2." 
-+ ); -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ /// Initialize output with input -+ D = C; -+ -+ /// Use 1x1x1 HFMA2 sequence for bulk of computation -+ using Mma = arch::Mma< -+ gemm::GemmShape<2,1,1>, -+ 1, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::ColumnMajor, -+ arch::OpMultiplyAdd>; -+ -+ Array *ptr_D = reinterpret_cast *>(&D); -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ Mma mma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ -+ -+ Array tmp; -+ Array *ptr_tmp = &tmp; -+ ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m]; -+ -+ Array tmp_A; -+ tmp_A[0] = ptr_A->at(2*m*Shape::kK + k); -+ tmp_A[1] = ptr_A->at((2*m+1)*Shape::kK + k); -+ -+ mma( -+ tmp, -+ tmp_A, -+ ptr_B[n*Shape::kK + k], -+ tmp); -+ -+ ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0]; -+ } -+ } -+ } -+ } -+}; -+ -+///////////////////////////// -+// Specialization for TNT // -+///////////////////////////// -+ -+template -+struct Mma_HFMA2 < -+ Shape_, -+ layout::RowMajor, -+ layout::ColumnMajor, -+ layout::RowMajor, -+ true -+ > { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ static_assert( -+ !(Shape::kN % 2), -+ "Mma_HFMA2 requires the N dimension to be divisible by 2." 
-+ ); -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ /// Initialize output with input -+ D = C; -+ -+ /// Use 1x2x1 HFMA2 sequence for bulk of computation -+ using Mma = arch::Mma< -+ gemm::GemmShape<1,2,1>, -+ 1, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ arch::OpMultiplyAdd>; -+ -+ Array *ptr_D = reinterpret_cast *>(&D); -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ Mma mma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ -+ -+ Array tmp; -+ Array *ptr_tmp = &tmp; -+ ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n]; -+ -+ Array tmp_B; -+ tmp_B[0] = ptr_B->at(2*n*Shape::kK + k); -+ tmp_B[1] = ptr_B->at((2*n+1)*Shape::kK + k); -+ -+ mma( -+ tmp, -+ ptr_A[m*Shape::kK + k], -+ tmp_B, -+ tmp); -+ -+ ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0]; -+ } -+ } -+ } -+ } -+}; -+ -+///////////////////////////// -+// Specialization for TTN // -+///////////////////////////// -+ -+template -+struct Mma_HFMA2 < -+ Shape_, -+ layout::RowMajor, -+ layout::RowMajor, -+ layout::ColumnMajor, -+ true -+ > { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ static_assert( -+ !(Shape::kM % 2), -+ "Mma_HFMA2 requires the M dimension to be divisible by 2." 
-+ ); -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ /// Initialize output with input -+ D = C; -+ -+ /// Use 1x2x1 HFMA2 sequence for bulk of computation -+ using Mma = arch::Mma< -+ gemm::GemmShape<2,1,1>, -+ 1, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ arch::OpMultiplyAdd>; -+ -+ Array *ptr_D = reinterpret_cast *>(&D); -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ Mma mma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ -+ -+ Array tmp; -+ Array *ptr_tmp = &tmp; -+ ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m]; -+ -+ Array tmp_A; -+ tmp_A[0] = ptr_A->at(2*m*Shape::kK + k); -+ tmp_A[1] = ptr_A->at((2*m+1)*Shape::kK + k); -+ -+ mma( -+ tmp, -+ tmp_A, -+ ptr_B[k*Shape::kN + n], -+ tmp); -+ -+ ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0]; -+ } -+ } -+ } -+ } -+}; -+ -+ -+///////////////////////////// -+// Specialization for TTT // -+///////////////////////////// -+ -+template -+struct Mma_HFMA2< -+ Shape_, -+ layout::RowMajor, -+ layout::RowMajor, -+ layout::RowMajor, -+ true -+ > { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ static_assert( -+ !(Shape::kN % 2), -+ "Mma_HFMA2 requires the N dimension to be divisible by 2." 
-+ ); -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ /// Initialize output with input -+ D = C; -+ -+ /// Use 1x2x1 HFMA2 sequence for bulk of computation -+ using Mma = arch::Mma< -+ gemm::GemmShape<1,2,1>, -+ 1, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::RowMajor, -+ arch::OpMultiplyAdd>; -+ -+ Array *ptr_D = reinterpret_cast *>(&D); -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ Mma mma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ -+ -+ Array tmp; -+ Array *ptr_tmp = &tmp; -+ ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n]; -+ -+ mma( -+ tmp, -+ ptr_A[m*Shape::kK + k], -+ ptr_B[k*Shape::kN/2 + n], -+ tmp); -+ -+ ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0]; -+ } -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////// -+// Specialization for TNT + Inner Product or 1x1x2K + LayoutC = T // -+///////////////////////////////////////////////////////////////////// -+ -+template -+struct Mma_HFMA2< -+ Shape_, -+ LayoutA, -+ LayoutB, -+ layout::RowMajor, -+ false -+ > { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ static_assert( -+ !(Shape::kK % 2), -+ "Mma_HFMA2 requires the K dimension to be divisible by 2." 
-+ ); -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ /// Initialize output with input -+ D = C; -+ -+ /// Use 1x1x2 HFMA2 sequence for bulk of computation -+ using GemmShape = gemm::GemmShape<1,1,2>; -+ -+ Array *ptr_D = reinterpret_cast *>(&D); -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ // Inner product is calculated using MACs, followed by final reduction -+ multiply_add> mac; -+ cutlass::reduction::thread::Reduce< plus, Array > reduce; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto n=0; n < Shape::kN / GemmShape::kN; n++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto m=0; m < Shape::kM / GemmShape::kM; m++){ -+ -+ Array tmp_C; -+ tmp_C.clear(); -+ Array *ptr_tmp_C = reinterpret_cast *>(&tmp_C); -+ ptr_tmp_C[0] = ptr_D[n*Shape::kM + m]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto k=0; k < Shape::kK / GemmShape::kK; k++){ -+ tmp_C = mac(ptr_A[m*Shape::kK/2 + k], ptr_B[n*Shape::kK/2 + k], tmp_C); -+ } -+ -+ Array res; -+ Array *ptr_res = &res; -+ res = reduce(tmp_C); -+ -+ ptr_D[m*Shape::kN + n] = ptr_res[0]; -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////// -+// Specialization for TNN + Inner Product or 1x1x2K + LayoutC = N // -+///////////////////////////////////////////////////////////////////// -+ -+template -+struct Mma_HFMA2< -+ Shape_, -+ LayoutA, -+ LayoutB, -+ layout::ColumnMajor, -+ false -+ > { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ static_assert( -+ !(Shape::kK % 2), -+ "Mma_HFMA2 requires the K dimension to be divisible by 2." 
-+ ); -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ /// Initialize output with input -+ D = C; -+ -+ /// Use 1x1x2 HFMA2 sequence for bulk of computation -+ using GemmShape= gemm::GemmShape<1,1,2>; -+ -+ Array *ptr_D = reinterpret_cast *>(&D); -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ // Inner product is calculated using MACs, followed by final reduction -+ multiply_add> mac; -+ cutlass::reduction::thread::Reduce< plus, Array > reduce; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto n=0; n < Shape::kN / GemmShape::kN; n++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto m=0; m < Shape::kM / GemmShape::kM; m++){ -+ -+ Array tmp_C; -+ tmp_C.clear(); -+ Array *ptr_tmp_C = reinterpret_cast *>(&tmp_C); -+ ptr_tmp_C[0] = ptr_D[n*Shape::kM + m]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto k=0; k < Shape::kK / GemmShape::kK; k++){ -+ -+ tmp_C = mac(ptr_A[m*Shape::kK/2 + k], ptr_B[n*Shape::kK/2 + k], tmp_C); -+ -+ } -+ -+ Array res; -+ Array *ptr_res = &res; -+ res = reduce(tmp_C); -+ -+ ptr_D[n*Shape::kM + m] = ptr_res[0]; -+ } -+ } -+ } -+}; -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, typename LayoutA, typename LayoutB, typename LayoutC -+> -+struct Mma< -+ Shape_, -+ half_t, -+ LayoutA, -+ half_t, -+ LayoutB, -+ half_t, -+ LayoutC, -+ arch::OpMultiplyAdd -+ > { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// Data type of operand A -+ using ElementA = half_t; -+ -+ /// Data type of operand B -+ using ElementB = half_t; -+ -+ /// Element type of operand C -+ using ElementC = half_t; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ static bool const a_row_major = platform::is_same< LayoutA, layout::RowMajor>::value; -+ static bool const b_column_major = platform::is_same< LayoutB, layout::ColumnMajor>::value; -+ static bool const c_row_major = platform::is_same< LayoutC, layout::RowMajor>::value; -+ static bool const c_column_major = platform::is_same< LayoutC, layout::ColumnMajor>::value; -+ -+ static bool const m_mod2 = !(Shape::kM % 2); -+ static bool const n_mod2 = !(Shape::kN % 2); -+ static bool const k_mod2 = !(Shape::kK % 2); -+ -+ // HFMA based MMA optimizations are of 2 types : -+ // 1. Inner product -+ // 2. 
Outer product -+ // It is chosen based on LayoutC (for outer product gemm) or -+ // Using LayoutA and LayoutB or shape=1x1x2K (for inner product gemms) -+ // If all fails, we choose the generic MMA -+ static bool const use_outer_prod = (c_column_major && m_mod2) || (c_row_major && n_mod2); -+ static bool const use_inner_prod = (a_row_major && b_column_major && k_mod2) || (Shape::kM==1 && Shape::kN==1 && k_mod2); -+ static bool const use_optimized = (use_outer_prod || use_inner_prod); -+ -+ using ArchMmaOperator = typename platform::conditional< use_optimized, -+ detail::Mma_HFMA2, -+ MmaGeneric -+ >::type; -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ ArchMmaOperator mma; -+ -+ mma(D, A, B, C); -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+ /// Determines whether to enable thread::Gemm<> specializations compatible with SM50 -+ template < -+ typename LayoutA, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB> -+ struct EnableMma_Crow_SM60 { -+ -+ static bool const kIsConventionalLayout = -+ (platform::is_same::value || -+ platform::is_same::value) && -+ (platform::is_same::value || -+ platform::is_same::value); -+ -+ static bool const value = kIsConventionalLayout; -+ }; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes matrix product when C is row-major -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ typename LayoutA_, -+ typename LayoutB_ -+> -+struct Mma< -+ Shape_, -+ half_t, -+ LayoutA_, -+ half_t, -+ LayoutB_, -+ half_t, -+ layout::RowMajor, -+ arch::OpMultiplyAdd, -+ typename platform::enable_if::value>::type>{ -+ -+ using Shape = Shape_; -+ using ElementA = half_t; -+ using LayoutA = LayoutA_; -+ using ElementB = half_t; -+ using LayoutB = LayoutB_; -+ using ElementC = half_t; -+ using LayoutC = layout::RowMajor; -+ using Operator = arch::OpMultiplyAdd; -+ -+ using TransposeMma = Mma< -+ GemmShapeTranspose, -+ half_t, -+ typename layout::LayoutTranspose::type, -+ half_t, -+ typename layout::LayoutTranspose::type, -+ half_t, -+ layout::ColumnMajor, -+ arch::OpMultiplyAdd, -+ bool>; -+ -+ using FragmentA = Array; -+ using FragmentB = Array; -+ using FragmentC = Array; -+ -+ using ArchMmaOperator = typename TransposeMma::ArchMmaOperator; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ TransposeMma mma; -+ -+ mma(D, B, A, C); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/thread/mma_sm61.h b/3rdparty/cutlass/include/cutlass/gemm/thread/mma_sm61.h -new file mode 100644 -index 0000000..7ef1efb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/thread/mma_sm61.h -@@ -0,0 +1,284 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
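// Illustrative sketch (not part of the upstream patch): the half-precision Mma
// above takes the paired HFMA2 path only when the layouts and tile shape allow
// packing two halves per instruction; otherwise it falls back to MmaGeneric.
// A compile-time restatement of that dispatch, with hypothetical helper names.
struct HFMA2Dispatch {
  bool c_column_major, c_row_major;  // layout of C
  bool a_row_major, b_column_major;  // layouts of A and B
  int kM, kN, kK;                    // thread-level tile shape

  constexpr bool use_outer_product() const {
    return (c_column_major && kM % 2 == 0) || (c_row_major && kN % 2 == 0);
  }
  constexpr bool use_inner_product() const {
    return (a_row_major && b_column_major && kK % 2 == 0) ||
           (kM == 1 && kN == 1 && kK % 2 == 0);
  }
  constexpr bool use_hfma2() const {  // otherwise: scalar MmaGeneric fallback
    return use_outer_product() || use_inner_product();
  }
};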
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates exposing architecture support for multiply-add operations -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/thread/mma.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Gemplate that handles conventional layouts for IDP4A -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_ -+> -+struct Mma< -+ Shape_, -+ int8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int32_t, -+ LayoutC_, -+ arch::OpMultiplyAdd, -+ bool> { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// Data type of operand A -+ using ElementA = int8_t; -+ -+ /// Layout of A matrix (concept: layout::MapFunc) -+ using LayoutA = layout::RowMajor; -+ -+ /// Data type of operand B -+ using ElementB = int8_t; -+ -+ /// Layout of B matrix (concept: layout::MapFunc) -+ using LayoutB = layout::ColumnMajor; -+ -+ /// Element type of operand C -+ using ElementC = int32_t; -+ -+ /// Layout of C matrix (concept: layout::MapFunc) -+ using LayoutC = LayoutC_; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ // Use 1x1x4 IDP4A sequence for bulk 
of computation -+ using ArchMmaOperator = arch::Mma< -+ gemm::GemmShape<1,1,4>, -+ 1, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ arch::OpMultiplyAdd>; -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ TensorRef d( -+ reinterpret_cast(&D), LayoutC::packed({ Shape::kM, Shape::kN })); -+ -+ // Copy accumulators -+ D = C; -+ -+ /// Use 1x1x4 IDP4A sequence for bulk of computation -+ ArchMmaOperator mma; -+ -+ // Compute matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Shape::kK / ArchMmaOperator::Shape::kK; ++k) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Shape::kN; ++n) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Shape::kM; ++m) { -+ MatrixCoord mn(m, n); -+ -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ Array tmp = reinterpret_cast &>(d.at(mn)); -+ -+ mma( -+ tmp, -+ ptr_A[m * Shape::kK / ArchMmaOperator::Shape::kK + k], -+ ptr_B[n * Shape::kK / ArchMmaOperator::Shape::kK + k], -+ tmp); -+ -+ d.at(mn) = reinterpret_cast(tmp); -+ } -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Gemplate that handles conventional layouts for IDP4A -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_ -+> -+struct Mma< -+ Shape_, -+ int8_t, -+ layout::ColumnMajor, -+ int8_t, -+ layout::RowMajor, -+ int32_t, -+ LayoutC_, -+ arch::OpMultiplyAdd, -+ int8_t> { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// Data type of operand A -+ using ElementA = int8_t; -+ -+ /// Layout of A matrix (concept: layout::MapFunc) -+ using LayoutA = layout::ColumnMajor; -+ -+ /// Data type of operand B -+ using ElementB = int8_t; -+ -+ /// Layout of B matrix (concept: layout::MapFunc) -+ using LayoutB = layout::RowMajor; -+ -+ /// Element type of operand C -+ using ElementC = int32_t; -+ -+ /// Layout of C matrix (concept: layout::MapFunc) -+ using LayoutC = LayoutC_; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ /// Use 1x1x4 IDP4A sequence for bulk of computation -+ using ArchMmaOperator = arch::Mma< -+ gemm::GemmShape<1,1,4>, -+ 1, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ arch::OpMultiplyAdd>; -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ TensorRef d( -+ reinterpret_cast(&D), LayoutC::packed({ Shape::kM, Shape::kN })); -+ -+ // Copy accumulators -+ D = C; -+ -+ /// Underlying matrix multiply operator -+ ArchMmaOperator mma; -+ -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ // Compute matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Shape::kK / ArchMmaOperator::Shape::kK; ++k) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < 
Shape::kN; ++n) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Shape::kM; ++m) { -+ MatrixCoord mn(m, n); -+ -+ Array tmp = reinterpret_cast &>(d.at(mn)); -+ -+ mma( -+ tmp, -+ ptr_A[m + k * Shape::kM], -+ ptr_B[n + k * Shape::kN], -+ tmp); -+ -+ d.at(mn) = reinterpret_cast(tmp); -+ } -+ } -+ } -+ } -+}; -+ -+} // namespace thread -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_ell_mma.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_ell_mma.h -new file mode 100644 -index 0000000..7e4d765 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_ell_mma.h -@@ -0,0 +1,734 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Default template for a Blocked-Ell MMA. 
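// Illustrative sketch (not part of the upstream patch): each 1x1x4 IDP4A step
// above is a four-way signed 8-bit dot product accumulated into a 32-bit
// integer (what the SM61 dp4a instruction computes). A scalar host-side
// equivalent of one step, assuming packed groups of four int8 values per operand.
#include <cstdint>

int32_t idp4a_step(const int8_t a[4], const int8_t b[4], int32_t c) {
  int32_t acc = c;
  for (int i = 0; i < 4; ++i) {
    acc += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
  }
  return acc;  // D(m, n) after consuming four values along K
}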
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/gemm/threadblock/default_mma_core_wmma.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+#include "cutlass/gemm/threadblock/ell_mma_pipelined.h" -+#include "cutlass/gemm/threadblock/ell_mma_multistage.h" -+#include "cutlass/transform/threadblock/ell_predicated_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. 
-+ bool AccumulatorsInRowMajor = false -+ > -+struct DefaultEllMma; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output (OperatorClass Simt) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultEllMma { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, -+ arch::OpClassSimt, 2, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ layout::RowMajor, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output (OperatorClass TensorOp) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator -+ > -+struct DefaultEllMma { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, 2, Operator>; -+ -+ // Define iterators 
over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ layout::RowMajor, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Specialization for row-major output (OperatorClass TensorOp) -+template < -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator -+ > -+struct DefaultEllMma { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, float, LayoutA, float, -+ LayoutB, float, layout::RowMajor, arch::OpClassTensorOp, 2, -+ arch::OpMultiplyAddFastF16>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, -+ float, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, -+ float, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, float, -+ layout::RowMajor, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for column-major-interleaved output -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ 
typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Number of Interleaved K -+ int InterleavedK> -+struct DefaultEllMma, OperatorClass, -+ ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, -+ Operator, true> { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, -+ true>; -+ -+ static_assert(kAlignmentA == 128 / sizeof_bits::value, -+ "Alignment must match thread data map's vector length"); -+ -+ static_assert(kAlignmentB ==128 / sizeof_bits::value, -+ "Alignment must match thread data map's vector length"); -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, ElementA, -+ LayoutA, 1, typename MmaCore::IteratorThreadMapA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, ElementB, -+ LayoutB, 0, typename MmaCore::IteratorThreadMapB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, -+ typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator -+ > -+struct DefaultEllMma { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ Stages, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::EllPredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename 
MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::EllPredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::EllMmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output (OperatorClass TensorOp) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator -+ > -+struct DefaultEllMma { -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? 
cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::EllPredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::EllPredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::EllMmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for column-major-interleaved output -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Number of Interleaved K -+ int InterleavedK> -+struct DefaultEllMma, OperatorClass, -+ ArchTag, ThreadblockShape, WarpShape, InstructionShape, -+ Stages, Operator, true> { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, Stages, -+ Operator, true>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::EllPredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ 
cutlass::transform::threadblock::EllPredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::EllMmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for SIMT IDP4A Kernels -+template < -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape> -+struct DefaultEllMma, 2, -+ Operator, false> { -+ using InstructionShape = GemmShape<1, 1, 4>; -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using OperatorClass = arch::OpClassSimt; -+ -+ static const bool transposeA = cutlass::platform::is_same< LayoutA, layout::ColumnMajor >::value; -+ static const bool transposeB = cutlass::platform::is_same< LayoutB, layout::RowMajor >::value; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, -+ OperatorClass, 2, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ layout::RowMajor, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+/// Specialization for Wmma TensorOp operator with 2 staged pipeline -+template < -+ ///< Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Layout type for C and D matrix 
operands -+ typename LayoutC, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultEllMma { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, LayoutC, -+ arch::OpClassWmmaTensorOp, 2, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ LayoutC, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for Wmma TensorOp operator with 1 staged pipeline -+template < -+ ///< Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultEllMma { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, LayoutC, -+ arch::OpClassWmmaTensorOp, 1, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; -+ -+ // Define the threadblock-scoped singlestage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaSingleStage< -+ typename MmaCore::Shape, IteratorA, 
typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ LayoutC, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_gemv_core.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_gemv_core.h -new file mode 100755 -index 0000000..afb74e7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_gemv_core.h -@@ -0,0 +1,151 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic properties needed by CTA-level batched GEMV assuming expectations about data -+ layout of the global memory fragments, data types, and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting SIMT instructions. 
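
The TensorOp multistage specializations above derive their cp.async cache hints purely from the per-operand access width. The following stand-alone sketch mirrors that selection rule; the helper name select_cache_op and the element/alignment combinations are illustrative assumptions, not part of the patch.

#include "cutlass/arch/cache_operation.h"
#include "cutlass/numeric_types.h"

// Same rule as CacheOpA / CacheOpB above: a full 128-bit access may bypass
// L1 (CacheOperation::Global); anything narrower uses the default caching
// path (CacheOperation::Always).
template <typename Element, int Alignment>
constexpr cutlass::arch::CacheOperation::Kind select_cache_op() {  // illustrative helper
  return (cutlass::sizeof_bits<Element>::value * Alignment) == 128
             ? cutlass::arch::CacheOperation::Global
             : cutlass::arch::CacheOperation::Always;
}

static_assert(select_cache_op<cutlass::half_t, 8>() ==
                  cutlass::arch::CacheOperation::Global,
              "8 x 16-bit elements form one 128-bit access");
static_assert(select_cache_op<float, 1>() ==
                  cutlass::arch::CacheOperation::Always,
              "a lone 32-bit element is narrower than 128 bits");
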
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/thread/mma.h" -+ -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+ -+#include "cutlass/gemm/threadblock/gemv.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+/// Template defininng default vector-matrix multiply operators inferred from threadblock tile size, -+/// global memory data layout. -+template < -+ typename Shape_, /// Shape of the threadblock vector-matrix multiply operator -+ typename ThreadShape_, /// Shape of per-thread vector-matrix multiply operator -+ typename ElementA_, /// Element data type of A operand -+ typename LayoutA_, /// Layout of operand A -+ typename ElementB_, /// Element data type of B operand -+ typename LayoutB_, /// Layout of operand B -+ typename ElementC_, /// Data type of accumulator -+ typename LayoutC_ /// Layout of accumulator -+> -+struct DefaultGemvCore { -+ -+ using Shape = Shape_; -+ using ThreadShape = ThreadShape_; -+ -+ using LayoutA = LayoutA_; -+ using LayoutB = LayoutB_; -+ using LayoutC = LayoutC_; -+ -+ using ElementA = ElementA_; -+ using ElementB = ElementB_; -+ using ElementC = ElementC_; -+ -+ static int const kThreadsPerN = Shape::kN / ThreadShape::kN; -+ -+ using IteratorPolicyA = typename platform::conditional< -+ platform::is_same::value, -+ cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous< -+ layout::PitchLinearShape, 1, ThreadShape::kK>, -+ cutlass::transform::PitchLinearTilePolicyStripminedThreadStrided< -+ layout::PitchLinearShape, 1, ThreadShape::kM>>::type; -+ -+ using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, ElementA, LayoutA, 1, IteratorPolicyA>; -+ -+ using IteratorPolicyB = typename platform::conditional< -+ platform::is_same::value, -+ cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous< -+ layout::PitchLinearShape, kThreadsPerN, ThreadShape::kN>, -+ cutlass::transform::PitchLinearTilePolicyStripminedThreadStrided< -+ layout::PitchLinearShape, kThreadsPerN, ThreadShape::kK>>::type; -+ -+ using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, ElementB, LayoutB, 0, IteratorPolicyB>; -+ -+ using IteratorPolicyC = typename platform::conditional< -+ platform::is_same::value, -+ cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous< -+ layout::PitchLinearShape, kThreadsPerN, ThreadShape::kN>, -+ cutlass::transform::PitchLinearTilePolicyStripminedThreadStrided< -+ layout::PitchLinearShape, kThreadsPerN, ThreadShape::kM>>::type; -+ -+ using IteratorC = cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, ElementC, LayoutC, 0, IteratorPolicyC>; -+ -+ using MmaSimtOp = typename cutlass::gemm::thread::Mma< -+ cutlass::gemm::GemmShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC>; -+ -+ using Operator = MmaSimtOp; -+ -+ // Assertions for correctness -+ static_assert((Shape::kM == 1), "M=1 is required for GEMV"); -+ -+ static_assert((ThreadShape::kM == 1), "M=1 is required for GEMV"); -+ -+ static_assert(Shape::kK % 
ThreadShape::kK == 0, "Shape::K must be a multiple of ThreadShape::K"); -+ -+ static_assert(((ThreadShape::kK == 1) || -+ (ThreadShape::kK == 2) || -+ (ThreadShape::kK == 4) || -+ (ThreadShape::kK == 8) || -+ (ThreadShape::kK == 16) || -+ (ThreadShape::kK == 32) -+ ), -+ "ThreadShape::K must be a 1, 2, 4, 8, 16 or 32"); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma.h -new file mode 100644 -index 0000000..7e0b206 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma.h -@@ -0,0 +1,791 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. 
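
As a compile-time check of the constraints that DefaultGemvCore (added above) enforces, the sketch below instantiates it with example tile shapes; the 1x64x4 threadblock tile, the 1x4x4 per-thread tile, and the element/layout choices are assumptions picked to satisfy its static_asserts, not values taken from the patch.

#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/threadblock/default_gemv_core.h"

using GemvCore = cutlass::gemm::threadblock::DefaultGemvCore<
    cutlass::gemm::GemmShape<1, 64, 4>,           // threadblock tile: kM must be 1 for GEMV
    cutlass::gemm::GemmShape<1, 4, 4>,            // per-thread tile: kK must be 1/2/4/8/16/32
    cutlass::half_t, cutlass::layout::RowMajor,   // operand A (the vector)
    cutlass::half_t, cutlass::layout::RowMajor,   // operand B (the matrix)
    float, cutlass::layout::RowMajor>;            // accumulator / output

// Shape::kN / ThreadShape::kN threads cooperate along the N dimension.
static_assert(GemvCore::kThreadsPerN == 16, "64 / 4 = 16 threads along N");

GemvCore exposes the per-thread SIMT operator as GemvCore::Operator and the global-memory tile iterators as IteratorA / IteratorB / IteratorC, which is what the threadblock-level Gemv consumes.
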
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/gemm/threadblock/default_mma_core_wmma.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. 
-+ bool AccumulatorsInRowMajor = false, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, -+ /// Gather operand A by using an index array -+ bool GatherA = false, -+ /// Gather operand B by using an index array -+ bool GatherB = false -+ > -+struct DefaultMma; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output (OperatorClass Simt) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Layout type for C and D matrix operand -+ typename LayoutC, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB -+ > -+struct DefaultMma { -+ -+ static_assert(platform::is_same::value -+ || platform::is_same>::value, -+ "simt epilogue must be row major"); -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, LayoutC, -+ arch::OpClassSimt, 2, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ LayoutC, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output (OperatorClass TensorOp) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: 
GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB -+ > -+struct DefaultMma { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, 2, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA, -+ GatherA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB, -+ GatherB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ layout::RowMajor, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Specialization for row-major output (OperatorClass TensorOp) -+template < -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB -+ > -+struct DefaultMma { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, float, LayoutA, float, -+ LayoutB, float, layout::RowMajor, arch::OpClassTensorOp, 2, -+ arch::OpMultiplyAddFastF16>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ float, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ float, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename 
MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, float, -+ layout::RowMajor, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for column-major-interleaved output -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Number of Interleaved K -+ int InterleavedK> -+struct DefaultMma, OperatorClass, -+ ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, -+ Operator, true, SharedMemoryClearOption::kNone, false, false> { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, -+ true>; -+ -+ static_assert(kAlignmentA == 128 / sizeof_bits::value, -+ "Alignment must match thread data map's vector length"); -+ -+ static_assert(kAlignmentB ==128 / sizeof_bits::value, -+ "Alignment must match thread data map's vector length"); -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, ElementA, -+ LayoutA, 1, typename MmaCore::IteratorThreadMapA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, ElementB, -+ LayoutB, 0, typename MmaCore::IteratorThreadMapB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, -+ typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Layout type for C and D matrix operand -+ typename LayoutC, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile 
size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB -+ > -+struct DefaultMma { -+ -+ static_assert(platform::is_same::value -+ || platform::is_same>::value, -+ "simt epilogue must be row major"); -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, LayoutC, arch::OpClassSimt, -+ Stages, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA, GatherA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB, GatherB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, LayoutC, -+ typename MmaCore::MmaPolicy, Stages>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output (OperatorClass TensorOp) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Layout type for C and D matrix operand -+ typename LayoutC, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB -+ > -+struct DefaultMma { -+ -+ static_assert(platform::is_same::value -+ || platform::is_same>::value, -+ "simt epilogue must be row major"); -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 
128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA, GatherA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB, GatherB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, LayoutC, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClear>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for column-major-interleaved output -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Number of Interleaved K -+ int InterleavedK> -+struct DefaultMma, OperatorClass, -+ ArchTag, ThreadblockShape, WarpShape, InstructionShape, -+ Stages, Operator, true, SharedMemoryClearOption::kNone, false, false> { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, Stages, -+ Operator, true>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ 
cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for SIMT IDP4A Kernels -+template < -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape> -+struct DefaultMma, 2, -+ Operator, false, SharedMemoryClearOption::kNone, false, false> { -+ using InstructionShape = GemmShape<1, 1, 4>; -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using OperatorClass = arch::OpClassSimt; -+ -+ static const bool transposeA = platform::is_same< LayoutA, layout::ColumnMajor >::value; -+ static const bool transposeB = platform::is_same< LayoutB, layout::RowMajor >::value; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, -+ OperatorClass, 2, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ layout::RowMajor, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+/// Specialization for Wmma TensorOp operator with 2 staged pipeline -+template < -+ ///< Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int 
kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultMma { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, LayoutC, -+ arch::OpClassWmmaTensorOp, 2, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ LayoutC, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for Wmma TensorOp operator with 1 staged pipeline -+template < -+ ///< Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultMma { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, LayoutC, -+ arch::OpClassWmmaTensorOp, 1, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ 
cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; -+ -+ // Define the threadblock-scoped singlestage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaSingleStage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ LayoutC, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core.h -new file mode 100644 -index 0000000..3d7ffe9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core.h -@@ -0,0 +1,116 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data -+ layout of the global memory fragments, data types, and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting TensorOp instructions. 
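
With default_mma.h complete, a typical use is to let DefaultMma select the threadblock mainloop for a given tile configuration. The sketch below is an assumed SM80 f16 TensorOp example; the tile shapes, alignments, and stage count are illustrative choices, not values taken from the patch.

#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/default_mma.h"

using Mma = cutlass::gemm::threadblock::DefaultMma<
    cutlass::half_t, cutlass::layout::RowMajor, 8,     // A: f16, row-major, 128-bit accesses
    cutlass::half_t, cutlass::layout::ColumnMajor, 8,  // B: f16, column-major, 128-bit accesses
    float, cutlass::layout::RowMajor,                  // accumulate in f32, row-major epilogue
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,            // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,              // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,               // tensor-core instruction tile
    3,                                                 // multistage depth (Stages > 2 selects MmaMultistage)
    cutlass::arch::OpMultiplyAdd>;

// The selected mainloop and its global-memory iterators are nested types:
using ThreadblockMma = Mma::ThreadblockMma;
using IteratorA = Mma::IteratorA;
using IteratorB = Mma::IteratorB;

Because Stages is greater than 2 and the operator class is OpClassTensorOp, this instantiation resolves to the multistage row-major TensorOp specialization defined above; a Stages value of 2 would instead pick the pipelined specialization.
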
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/warp/mma.h" -+#include "cutlass/gemm/threadblock/mma_pipelined.h" -+#include "cutlass/gemm/threadblock/mma_singlestage.h" -+#include "cutlass/arch/cache_operation.h" -+#include "cutlass/arch/mma.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template defininng default matrix multiply operators inferred from threadblock tile size, -+/// global memory data layout, and target math instruction. -+template < -+ /// Shape of threadblock-scoped matrix multiply operator -+ typename Shape, -+ /// Shape of warp-level matrix multiply operator -+ typename WarpShape, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape, -+ /// Element data type of A operand -+ typename ElementA, -+ /// Layout of operand A -+ typename LayoutA, -+ /// Element data type of B operand -+ typename ElementB, -+ /// Layout of operand B -+ typename LayoutB, -+ /// Data type of accumulator -+ typename ElementC, -+ /// Layout of accumulator -+ typename LayoutC, -+ /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) -+ typename OperatorClass, -+ /// Number of stages -+ int Stages = 2, -+ /// Operation performed by MMA -+ typename Operator = typename platform::conditional< -+ (platform::is_same::value) && -+ (platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value), -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::arch::OpMultiplyAdd>::type, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA = -+ cutlass::arch::CacheOperation::Global, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB = -+ cutlass::arch::CacheOperation::Global, -+ /// per-element transformation for elements of A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// per-element transformation for elements of B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ bool IsComplex = false // (is_complex::value || is_complex::value) -+> -+struct DefaultMmaCore; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_simt.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_simt.h -new file mode 100644 -index 0000000..a6d8ec0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_simt.h -@@ -0,0 +1,1723 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data -+ layout of the global memory fragments, data types, and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting simt instructions. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/fast_math.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h" -+ -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+#include "cutlass/gemm/warp/mma_simt.h" -+#include "cutlass/gemm/threadblock/default_mma_core.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+namespace detail { -+ -+// convert a WarpShape which is the whole tile of elements into warp num threads. -+// The goal is for each thread's tile of elements to be as square as possible -+// for performance (4x4 will be faster than 2x8). -+template -+constexpr int simt_get_warp_threads_m() { -+ return (WarpShape::kM > WarpShape::kN) ? 8 : 4; -+} -+ -+/// Computes padding in shared memory to perform efficient transpose without bank conflicts. -+constexpr int simt_transpose_padding(int threads, int crosswise, int size_in_bits) { -+ return (size_in_bits >= 32 ? 
-+ threads / crosswise / (size_in_bits / 32) : -+ threads / crosswise * (32 / size_in_bits) -+ ); -+} -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::ColumnMajor, ElementB_, layout::RowMajor, -+ ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ static int const PartitionsK = Shape::kK / WarpShape::kK; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ PartitionsK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." 
-+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); -+ static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ >; /// Used for partial specialization -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename 
WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::RowMajor, ElementB_, layout::ColumnMajor, -+ ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ static int const PartitionsK = Shape::kK / WarpShape::kK; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ PartitionsK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapA = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ SmemThreadMapA // was IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ SmemThreadMapB // was IteratorThreadMapA -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); -+ static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 
2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ -+ static int const kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ -+ static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), -+ "Padding must be divisible by Lane"); -+ -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape, // skew for A matrix to avoid SMEM bank conflicts -+ MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::RowMajor, ElementB_, layout::RowMajor, ElementC_, -+ LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ static int const PartitionsK = Shape::kK / WarpShape::kK; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ PartitionsK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." 
-+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapA = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ SmemThreadMapA -+ >; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); -+ static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 
2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ -+ static int const kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ -+ static_assert(!(kPaddingM % LaneM), -+ "Padding must be divisible by Lane"); -+ -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape, // skew for A matrix to avoid SMEM bank conflicts -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::ColumnMajor, ElementB_, layout::ColumnMajor, -+ ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ static int const PartitionsK = Shape::kK / WarpShape::kK; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ PartitionsK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." 
-+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ SmemThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); -+ static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 
2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ -+ static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ -+ static_assert(!(kPaddingN % LaneN), -+ "Padding must be divisible by Lane"); -+ -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape<0, 0>, -+ MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::AffineRank2ColumnMajor, ElementB_, layout::AffineRank2RowMajor, -+ ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::AffineRank2ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::AffineRank2RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // 
Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::AffineRank2RowMajor, ElementB_, layout::AffineRank2ColumnMajor, -+ ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::AffineRank2RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::AffineRank2ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::AffineRank2RowMajor, ElementB_, layout::AffineRank2RowMajor, ElementC_, -+ LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::AffineRank2RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::AffineRank2RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = 
LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::AffineRank2ColumnMajor, ElementB_, layout::AffineRank2ColumnMajor, -+ ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::AffineRank2ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::AffineRank2ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: simt class, for dp4a -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: 
GemmShape) -+ typename WarpShape_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, int8_t, -+ layout::ColumnMajor, int8_t, layout::RowMajor, ElementC_, -+ LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 4>; -+ using ElementA = int8_t; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = int8_t; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ static int const PartitionsK = Shape::kK / WarpShape::kK; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ PartitionsK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorInterleaved<4>; -+ using SmemLayoutB = layout::RowMajorInterleaved<4>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinear2DThreadTileStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 4> -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator2dThreadTile< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinear2DThreadTileStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 4> -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator2dThreadTile< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); -+ static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 
2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(4, ThreadTileM); -+ static const int LaneN = cutlass::const_min(4, ThreadTileN); -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 4>; -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::ColumnMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy, /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ PartitionsK /// Number of partitions along K dimension -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization: -+// -+/// -+/// A: Row-major -+/// B: Column-major -+/// Operator: simt class, for dp4a -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, int8_t, -+ layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, -+ LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 4>; -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ static int const PartitionsK = Shape::kK / WarpShape::kK; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ PartitionsK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." 
-+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorInterleaved<4>; -+ using SmemLayoutB = layout::RowMajorInterleaved<4>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinear2DThreadTileStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 4> -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapA = transform::TransposePitchLinearThreadMap2DThreadTile; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator2dThreadTile< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ SmemThreadMapA -+ >; -+ -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinear2DThreadTileStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 4> -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapB = transform::TransposePitchLinearThreadMap2DThreadTile; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator2dThreadTile< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ SmemThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); -+ static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 
2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(4, ThreadTileM); -+ static const int LaneN = cutlass::const_min(4, ThreadTileN); -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 4>; -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::ColumnMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy, /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ PartitionsK /// Number of partitions along K dimension -+ >; -+ -+ static int const kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape, -+ MatrixShape<0, kPaddingN>, -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization: -+// -+/// -+/// A: Row-major -+/// B: Row-major -+/// Operator: simt class, for dp4a -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, int8_t, -+ layout::RowMajor, int8_t, layout::RowMajor, ElementC_, -+ LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 4>; -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ static int const PartitionsK = Shape::kK / WarpShape::kK; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ PartitionsK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." 
-+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorInterleaved<4>; -+ using SmemLayoutB = layout::RowMajorInterleaved<4>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinear2DThreadTileStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 4> -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapA = transform::TransposePitchLinearThreadMap2DThreadTile; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator2dThreadTile< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ SmemThreadMapA -+ >; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinear2DThreadTileStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 4> -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator2dThreadTile< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); -+ static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 
2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(4, ThreadTileM); -+ static const int LaneN = cutlass::const_min(4, ThreadTileN); -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 4>; -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::ColumnMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy, /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ PartitionsK /// Number of partitions along K dimension -+ >; -+ -+ static int const kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization: -+// -+/// -+/// A: Column-major -+/// B: Column-major -+/// Operator: simt class, for dp4a -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, int8_t, -+ layout::ColumnMajor, int8_t, layout::ColumnMajor, ElementC_, -+ LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 4>; -+ using ElementA = int8_t; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ static int const PartitionsK = Shape::kK / WarpShape::kK; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ PartitionsK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." 
-+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorInterleaved<4>; -+ using SmemLayoutB = layout::RowMajorInterleaved<4>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinear2DThreadTileStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 4> -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator2dThreadTile< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinear2DThreadTileStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 4> -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapB = transform::TransposePitchLinearThreadMap2DThreadTile; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator2dThreadTile< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ SmemThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); -+ static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 
2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(4, ThreadTileM); -+ static const int LaneN = cutlass::const_min(4, ThreadTileN); -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 4>; -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::ColumnMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy, /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ PartitionsK /// Number of partitions along K dimension -+ >; -+ -+ static int const kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape<0, 0>, -+ MatrixShape<0, kPaddingN>, -+ WarpCount::kK -+ >; -+}; -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm70.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm70.h -new file mode 100644 -index 0000000..fc83965 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm70.h -@@ -0,0 +1,682 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data -+ layout of the global memory fragments, data types, and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting TensorOp instructions. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+ -+#include "cutlass/layout/tensor_op_multiplicand_sm70.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::ColumnMajor, ElementB_, layout::RowMajor, -+ ElementC_, LayoutC_, arch::OpClassTensorOp, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<8, 8, 4>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." 
-+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = -+ layout::ColumnMajorVoltaTensorOpMultiplicandCongruous< -+ sizeof_bits::value>; -+ -+ // Shared memory layout -+ using SmemLayoutB = -+ layout::RowMajorVoltaTensorOpMultiplicandBCongruous< -+ sizeof_bits::value>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ SmemLayoutA, -+ ElementB, -+ SmemLayoutB, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::RowMajor, ElementB_, layout::ColumnMajor, -+ ElementC_, LayoutC_, arch::OpClassTensorOp, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<8, 8, 4>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ 
Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorVoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 8>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 0, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 8>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 1, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ SmemLayoutA, -+ ElementB, -+ SmemLayoutB, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::RowMajor, ElementB_, layout::RowMajor, ElementC_, -+ LayoutC_, arch::OpClassTensorOp, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<8, 8, 4>; -+ using ElementA = ElementA_; -+ using 
LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorVoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajorVoltaTensorOpMultiplicandBCongruous< -+ sizeof_bits::value>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 8>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 0, -+ IteratorThreadMapA -+ >; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ SmemLayoutA, -+ ElementB, -+ SmemLayoutB, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ 
/// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::ColumnMajor, ElementB_, layout::ColumnMajor, -+ ElementC_, LayoutC_, arch::OpClassTensorOp, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<8, 8, 4>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorVoltaTensorOpMultiplicandCongruous< -+ sizeof_bits::value>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 8>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 1, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ SmemLayoutA, -+ ElementB, -+ SmemLayoutB, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm75.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm75.h -new file mode 100644 -index 0000000..697c45f ---- /dev/null -+++ 
b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm75.h -@@ -0,0 +1,1279 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data -+ layout of the global memory fragments, data types, and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting TensorOp instructions. 
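For orientation, here is a minimal compile-time sketch of how a threadblock-scoped DefaultMmaCore specialization of the kind added in this file is typically instantiated. The tile sizes, element types, and the Example* names are illustrative assumptions, not taken from this patch.

// Illustrative sketch only: instantiating an SM75 tensor-op DefaultMmaCore
// with example tile sizes and element types (assumed, not from this patch).
#include "cutlass/numeric_types.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"

using ExampleCore = cutlass::gemm::threadblock::DefaultMmaCore<
    cutlass::gemm::GemmShape<128, 128, 32>,        // threadblock tile (M, N, K)
    cutlass::gemm::GemmShape<64, 64, 32>,          // warp tile
    cutlass::gemm::GemmShape<16, 8, 8>,            // one mma.sync instruction shape on SM75
    cutlass::half_t, cutlass::layout::RowMajor,    // operand A
    cutlass::half_t, cutlass::layout::ColumnMajor, // operand B
    float, cutlass::layout::RowMajor,              // accumulator C
    cutlass::arch::OpClassTensorOp,
    2,                                             // two-stage (pipelined) mainloop
    cutlass::arch::OpMultiplyAdd>;

// The specialization derives what the threadblock-scoped MMA consumes:
// shared-memory layouts, shared-memory tile iterators, the warp-level
// tensor op, and the MmaPolicy handed to MmaPipelined.
using ExampleSmemLayoutA = ExampleCore::SmemLayoutA;
using ExampleSmemLayoutB = ExampleCore::SmemLayoutB;
using ExampleWarpMma     = ExampleCore::MmaTensorOp;
using ExamplePolicy      = ExampleCore::MmaPolicy;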
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h" -+ -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+#include "cutlass/gemm/threadblock/default_mma_core.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." 
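As a worked example of the divisibility requirement and the derived warp and thread counts used throughout these specializations (numbers are illustrative, not from the patch): a 128x128x32 threadblock tile with 64x64x32 warp tiles gives WarpCount = <2, 2, 1>, four warps, and 128 threads.

// Illustrative numbers only (assumed, not from the patch): the constants these
// specializations compute, for Shape = <128, 128, 32>, WarpShape = <64, 64, 32>,
// and kWarpSize = 32.
static_assert(128 % 64 == 0, "Shape::kM divisible by WarpShape::kM");
static_assert(128 % 64 == 0, "Shape::kN divisible by WarpShape::kN");
static_assert((128 / 64) * (128 / 64) * (32 / 32) == 4, "WarpCount::kCount == 4");
static_assert(4 * 32 == 128, "kThreads == 128");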
-+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = -+ layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementA))>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementB))>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by MMA -+ typename Operator_> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % 
WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousA = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedA = -+ kWarpSize / kWarpThreadArrangementContiguousA; -+ -+ static int const kWarpThreadArrangementContiguousB = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedB = -+ kWarpSize / kWarpThreadArrangementContiguousB; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 0, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 1, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by MMA -+ typename Operator_> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = 
ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousA = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedA = -+ kWarpSize / kWarpThreadArrangementContiguousA; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementB))>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 0, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type 
of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by MMA -+ typename Operator_> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousB = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedB = -+ kWarpSize / kWarpThreadArrangementContiguousB; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementA))>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Below is for arch::OpMultiplyAddFastF16 -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: 
GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = float; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = float; -+ using LayoutB = layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 256; -+ -+ /// Default Operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(half_t))>; -+ -+ // Shared memory layout -+ using SmemLayoutB = -+ layout::RowMajorTensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(half_t))>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ half_t, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ half_t, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, half_t, SmemLayoutA, half_t, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of 
accumulator -+ typename LayoutC_> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = float; -+ using LayoutA = layout::RowMajor; -+ using ElementB = float; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 256; -+ -+ /// Default Operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousA = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedA = -+ kWarpSize / kWarpThreadArrangementContiguousA; -+ -+ static int const kWarpThreadArrangementContiguousB = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedB = -+ kWarpSize / kWarpThreadArrangementContiguousB; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = -+ layout::RowMajorTensorOpMultiplicandCrosswise::value, -+ Shape::kK>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ half_t, -+ SmemLayoutA, -+ 0, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ half_t, -+ SmemLayoutB, -+ 1, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, half_t, SmemLayoutA, half_t, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: 
-+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = float; -+ using LayoutA = layout::RowMajor; -+ using ElementB = float; -+ using LayoutB = layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 256; -+ -+ /// Default Operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousA = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedA = -+ kWarpSize / kWarpThreadArrangementContiguousA; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(half_t))>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ half_t, -+ SmemLayoutA, -+ 0, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ half_t, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, half_t, SmemLayoutA, half_t, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ 
/// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = float; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = float; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 256; -+ -+ /// Default Operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousB = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedB = -+ kWarpSize / kWarpThreadArrangementContiguousB; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(half_t))>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, half_t, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, half_t, SmemLayoutB, 1, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, half_t, SmemLayoutA, half_t, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, MatrixShape<0, 0>, -+ WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major-interleave -+/// B: row-major-interleave -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+/// -+/// Column/RowMajorInterleved(m, n) is mapped to Column/RowMajor(m -+/// x InterleavedK, 
n / InterleavedK) so that Column/RowMajor global iterators -+/// can be reused. The shared store iterator is the same as the crosswise shared -+/// store iterator. So, the only thing we need to do is to swap the coordinates -+/// (contiguous <=> strided) used by the global iterator and the shared store -+/// iterator. -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor, -+ /// Number of interleaved k -+ int InterleavedK> -+struct DefaultMmaCore, ElementB_, -+ layout::RowMajorInterleaved, ElementC_, -+ LayoutC_, arch::OpClassTensorOp, 2, Operator_, -+ AccumulatorsInRowMajor> { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajorInterleaved; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajorInterleaved; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ static int const kInterleavedK = InterleavedK; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kElementsPerAccess = -+ kAccessSizeInBits / sizeof_bits::value; -+ -+ static int const kWarpThreadArrangementContiguous = -+ kInterleavedK / kElementsPerAccess; -+ -+ static int const kWarpThreadArrangementStrided = -+ kWarpSize / kWarpThreadArrangementContiguous; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kInterleavedK>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kInterleavedK>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, layout::PitchLinearShape<32, 1>, kElementsPerAccess>; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapA = transform::TransposePitchLinearThreadMap< -+ IteratorThreadMapA, -+ layout::PitchLinearShape>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ SmemThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using 
IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, layout::PitchLinearShape<32, 1>, kElementsPerAccess>; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapB = transform::TransposePitchLinearThreadMap< -+ IteratorThreadMapB, -+ layout::PitchLinearShape>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ SmemThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK, AccumulatorsInRowMajor>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm80.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm80.h -new file mode 100644 -index 0000000..ad232fc ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm80.h -@@ -0,0 +1,2916 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming -+ expectations about data layout of the global memory fragments, data types, -+ and internal tile sizes. 
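The SM80 specializations in this file follow the same structure but feed the multistage (cp.async) mainloop, which the file's own note says expects at least three stages. Below is a minimal sketch under assumed tile sizes and element types; ExampleSm80Core and the chosen parameters are hypothetical, not taken from this patch.

// Illustrative sketch only: an SM80 tensor-op DefaultMmaCore with a
// multistage mainloop (example tile sizes and element types, not from
// this patch). Stages >= 3 is expected for the cp.async pipeline.
#include "cutlass/numeric_types.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"

using ExampleSm80Core = cutlass::gemm::threadblock::DefaultMmaCore<
    cutlass::gemm::GemmShape<128, 128, 32>,        // threadblock tile (M, N, K)
    cutlass::gemm::GemmShape<64, 64, 32>,          // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,           // mma instruction shape for f16 on SM80
    cutlass::half_t, cutlass::layout::RowMajor,    // operand A
    cutlass::half_t, cutlass::layout::ColumnMajor, // operand B
    float, cutlass::layout::RowMajor,              // accumulator C
    cutlass::arch::OpClassTensorOp,
    3,                                             // >= 3 stages for the asynchronous-copy mainloop
    cutlass::arch::OpMultiplyAdd>;

// As on SM75, the specialization exposes shared-memory layouts, tile access
// iterators, the warp-level tensor op, and the policy used by MmaMultistage.
using ExampleSm80Policy = ExampleSm80Core::MmaPolicy;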
-+ -+ Partial specializations for threadblock::Mma operations targeting TensorOp -+ instructions. -+ -+ SM80 Multi stage kernel expects stage number to be larger or equal to 3 -+ to use asyncronous copy. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+ -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+#include "cutlass/gemm/warp/mma_simt.h" -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+ -+#include "cutlass/gemm/threadblock/default_mma_core.h" -+#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core.h" -+#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h" -+ -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" -+#include "cutlass/gemm/threadblock/mma_multistage.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for double-precision -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = double; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const 
kAccessSizeInBits = 64; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpStripedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+/// Partial specialization for double-precision -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = double; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const 
kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 64; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpStripedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpStripedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for double-precision -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = double; -+ using LayoutA = layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = 
warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 64; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// -+/// Partial specialization for double-precision -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = double; -+ using LayoutA = layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount 
> 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 64; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpStripedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for double-precision -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = double; -+ using LayoutA = layout::AffineRank2ColumnMajor; -+ using ElementB = double; -+ using LayoutB = layout::AffineRank2ColumnMajor; -+ using ElementC = double; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ 
// -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+}; -+ -+/// Partial specialization for double-precision -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = double; -+ using LayoutA = layout::AffineRank2ColumnMajor; -+ using ElementB = double; -+ using LayoutB = layout::AffineRank2RowMajor; -+ using ElementC = double; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for double-precision -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: 
GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = double; -+ using LayoutA = layout::AffineRank2RowMajor; -+ using ElementB = double; -+ using LayoutB = layout::AffineRank2ColumnMajor; -+ using ElementC = double; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// -+/// Partial specialization for double-precision -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = double; -+ using LayoutA = layout::AffineRank2RowMajor; -+ using ElementB = double; -+ using LayoutB = layout::AffineRank2RowMajor; -+ using ElementC = double; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // 
-+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for float-precision -+/// -+/// ElementA: complex -+/// ElementB: complex -+/// ElementC: complex -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Layout for A operand -+ typename LayoutA_, -+ /// Layout for B operand -+ typename LayoutB_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// per-element transformation for elements of A -+ ComplexTransform TransformA_, -+ /// per-element transformation for elements of B -+ ComplexTransform TransformB_ -+ > -+struct DefaultMmaCore< -+ Shape_, WarpShape_, GemmShape<16, 8, 8>, -+ complex, LayoutA_, -+ complex, LayoutB_, -+ complex, LayoutC_, -+ arch::OpClassTensorOp, -+ Stages, -+ Operator_, -+ false, -+ CacheOpA, -+ CacheOpB, -+ TransformA_, TransformB_, true> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<16, 8, 8>; -+ using ElementA = complex; -+ using LayoutA = LayoutA_; -+ using ElementB = complex; -+ using LayoutB = LayoutB_; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ static const ComplexTransform TransformA = TransformA_; -+ static const ComplexTransform TransformB = TransformB_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ static_assert( -+ platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value, -+ "The operator tag must indicate complex multiplication."); -+ -+ // -+ // Underlying template -+ // -+ -+ using 
MmaComplexCore = DefaultMultistageMmaComplexCore< -+ Shape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ arch::OpClassTensorOp, -+ kStages, -+ TransformA, -+ TransformB, -+ Operator, -+ kCacheOpA, -+ kCacheOpB -+ >; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename MmaComplexCore::SmemLayoutA; -+ -+ // Shared memory layout -+ using SmemLayoutB = typename MmaComplexCore::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename MmaComplexCore::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename MmaComplexCore::SmemIteratorA; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = typename MmaComplexCore::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename MmaComplexCore::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename MmaComplexCore::MmaTensorOp; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename MmaComplexCore::MmaPolicy; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for double-precision -+/// -+/// ElementA: complex -+/// ElementB: complex -+/// ElementC: complex -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout for A operand -+ typename LayoutA_, -+ /// Layout for B operand -+ typename LayoutB_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// per-element transformation for elements of A -+ ComplexTransform TransformA_, -+ /// per-element transformation for elements of B -+ ComplexTransform TransformB_ -+ > -+struct DefaultMmaCore< -+ Shape_, WarpShape_, InstructionShape_, -+ complex, LayoutA_, -+ complex, LayoutB_, -+ complex, LayoutC_, -+ arch::OpClassTensorOp, -+ Stages, -+ Operator_, -+ false, -+ CacheOpA, -+ CacheOpB, -+ TransformA_, TransformB_, true> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = complex; -+ using LayoutA = LayoutA_; -+ using ElementB = complex; -+ using LayoutB = LayoutB_; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ static const ComplexTransform TransformA = TransformA_; -+ static const ComplexTransform TransformB = TransformB_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ 
"Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 64; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ static_assert( -+ platform::is_same::value || -+ platform::is_same::value, -+ "The operator tag must indicate complex multiplication."); -+ -+ // -+ // Underlying template -+ // -+ -+ using MmaComplexCore = DefaultMultistageMmaComplexCore< -+ Shape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ arch::OpClassTensorOp, -+ kStages, -+ TransformA, -+ TransformB, -+ Operator, -+ kCacheOpA, -+ kCacheOpB -+ >; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename MmaComplexCore::SmemLayoutA; -+ -+ // Shared memory layout -+ using SmemLayoutB = typename MmaComplexCore::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename MmaComplexCore::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename MmaComplexCore::SmemIteratorA; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = typename MmaComplexCore::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename MmaComplexCore::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename MmaComplexCore::MmaTensorOp; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename MmaComplexCore::MmaPolicy; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static 
cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementA))>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementB))>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; 
-+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousA = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedA = -+ kWarpSize / kWarpThreadArrangementContiguousA; -+ -+ static int const kWarpThreadArrangementContiguousB = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedB = -+ kWarpSize / kWarpThreadArrangementContiguousB; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename 
InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousB = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedB = -+ kWarpSize / kWarpThreadArrangementContiguousB; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementA))>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ 
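// --- Illustrative sketch (editorial aside, not part of the deleted header) ---
// The tensor-op specializations above all derive their shared-memory thread
// arrangement the same way: a 128-bit access is split into
// kAccessSizeInBits / sizeof_bits<Element>::value elements, the contiguous
// dimension of the warp arrangement spans the threadblock K extent, and the
// remaining warp lanes cover the strided dimension. The standalone sketch
// below restates only that arithmetic; every name here is a hypothetical
// stand-in (it does not use or reproduce the CUTLASS templates), and it
// assumes byte-sized elements so that 8 * sizeof(Element) matches sizeof_bits.
#include <cstdint>

namespace sketch {

constexpr int kWarpSize = 32;
constexpr int kAccessSizeInBits = 128;

// Elements carried by one 128-bit access.
template <typename Element>
constexpr int elements_per_access =
    kAccessSizeInBits / (8 * static_cast<int>(sizeof(Element)));

// Warp lanes laid out along the contiguous (K) dimension.
template <typename Element, int ShapeK>
constexpr int warp_arrangement_contiguous =
    ShapeK / elements_per_access<Element>;

// Remaining warp lanes cover the strided dimension.
template <typename Element, int ShapeK>
constexpr int warp_arrangement_strided =
    kWarpSize / warp_arrangement_contiguous<Element, ShapeK>;

// Example: 16-bit operands with a threadblock K extent of 64 give an
// 8-element access and an 8 x 4 warp arrangement.
static_assert(elements_per_access<std::uint16_t> == 8, "8 halves per access");
static_assert(warp_arrangement_contiguous<std::uint16_t, 64> == 8, "8 lanes in K");
static_assert(warp_arrangement_strided<std::uint16_t, 64> == 4, "4 lanes strided");

}  // namespace sketch
// --- end of illustrative sketch ----------------------------------------------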
-+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousA = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedA = -+ kWarpSize / kWarpThreadArrangementContiguousA; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementB))>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, 
ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major-interleaved -+/// B: row-major-interleaved -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+/// -+/// Column/RowMajorInterleved(m, n) is mapped to Column/RowMajor(m -+/// x InterleavedK, n / InterleavedK) so that Column/RowMajor global iterators -+/// can be reused. The shared store iterator is the same as the crosswise shared -+/// store iterator. So, the only thing we need to do is to swap the coordinates -+/// (contiguous <=> strided) used by the global iterator and the shared store -+/// iterator. -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. 
-+ bool AccumulatorsInRowMajor, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Number of interleaved K -+ int InterleavedK> -+struct DefaultMmaCore, ElementB_, -+ layout::RowMajorInterleaved, ElementC_, -+ LayoutC_, arch::OpClassTensorOp, Stages, Operator_, -+ AccumulatorsInRowMajor, CacheOpA, CacheOpB> { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajorInterleaved; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajorInterleaved; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ static int const kInterleavedK = InterleavedK; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kElementsPerAccess = -+ kAccessSizeInBits / sizeof_bits::value; -+ -+ static int const kWarpThreadArrangementContiguous = -+ kInterleavedK / kElementsPerAccess; -+ -+ static int const kWarpThreadArrangementStrided = -+ kWarpSize / kWarpThreadArrangementContiguous; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kInterleavedK>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kInterleavedK>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, layout::PitchLinearShape<32, 1>, kElementsPerAccess>; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapA = transform::TransposePitchLinearThreadMap< -+ IteratorThreadMapA, -+ layout::PitchLinearShape>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ SmemThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, layout::PitchLinearShape<32, 1>, kElementsPerAccess>; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapB = transform::TransposePitchLinearThreadMap< -+ IteratorThreadMapB, -+ layout::PitchLinearShape>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ SmemThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ 
WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK, AccumulatorsInRowMajor>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for SIMT GEMMs using multistage pipeline. -+/// -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by Simt -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ IteratorThreadMapA>; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator B -+ using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ 
MatrixShape, ElementB, SmemLayoutB, 1, -+ SmemThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = 4; // TODO need to extract these from template data -+ static const int WarpNumThreadsN = 8; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ -+ static_assert(!((Shape::kK / 32) % LaneN), -+ "Padding must be divisible by Lane"); -+ -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ >; /// Used for partial specialization -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape<0, 0>, -+ MatrixShape<0, Shape::kK / 32>, -+ WarpCount::kK>; -+}; -+ -+/// Partial specialization for SIMT GEMMs using multistage pipeline. 
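// --- Illustrative sketch (editorial aside, not part of the deleted header) ---
// The SIMT specializations fix a warp as a 4 x 8 grid of threads, give each
// thread a WarpShape::kM/4 x WarpShape::kN/8 tile of the accumulator, and cap
// the per-lane MMA shape by how many elements fit in a 128-bit fragment. The
// self-contained sketch below restates that compile-time arithmetic only; all
// names are hypothetical stand-ins, not CUTLASS APIs.
#include <algorithm>

namespace sketch_simt {

constexpr int kWarpNumThreadsM = 4;
constexpr int kWarpNumThreadsN = 8;

struct LanePartition {
  int thread_tile_m;  // accumulator rows owned by one thread
  int thread_tile_n;  // accumulator columns owned by one thread
  int lane_layout;    // 2 if both tile extents exceed 4, else 1
  int lane_m;         // per-lane MMA rows, capped by the 128-bit fragment
  int lane_n;         // per-lane MMA columns, capped by the 128-bit fragment
};

constexpr LanePartition partition(int warp_m, int warp_n, int element_bits) {
  int tile_m = warp_m / kWarpNumThreadsM;
  int tile_n = warp_n / kWarpNumThreadsN;
  int num_elements = 128 / element_bits;  // elements per 128-bit fragment
  return LanePartition{tile_m,
                       tile_n,
                       (tile_m > 4 && tile_n > 4) ? 2 : 1,
                       std::min(num_elements, tile_m),
                       std::min(num_elements, tile_n)};
}

// Example: a 64 x 64 warp tile on 32-bit operands yields 16 x 8 thread tiles,
// a 2 x 2 interleaved lane layout, and a 4 x 4 per-lane MMA shape.
static_assert(partition(64, 64, 32).thread_tile_m == 16, "16-row thread tile");
static_assert(partition(64, 64, 32).lane_layout == 2, "interleaved lanes");
static_assert(partition(64, 64, 32).lane_m == 4, "4-row lane MMA");
static_assert(partition(64, 64, 32).lane_n == 4, "4-column lane MMA");

}  // namespace sketch_simt
// --- end of illustrative sketch ----------------------------------------------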
-+/// -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by Simt -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ IteratorThreadMapA>; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = 4; // TODO need to extract these from template data -+ static const int WarpNumThreadsN = 8; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 
2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ >; /// Used for partial specialization -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK>; -+}; -+ -+/// Partial specialization for SIMT GEMMs using multistage pipeline. -+/// -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by Simt -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ -+ // Shared memory layout -+ using SmemLayoutB 
= layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapA = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ SmemThreadMapA>; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator B -+ using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ SmemThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = 4; // TODO need to extract these from template data -+ static const int WarpNumThreadsN = 8; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ -+ static_assert(!((Shape::kK / 32) % LaneM) && !((Shape::kK / 32) % LaneN), -+ "Padding must be divisible by Lane"); -+ -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ >; /// Used for partial specialization -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape, -+ MatrixShape<0, Shape::kK / 32>, -+ WarpCount::kK>; -+}; -+ -+/// Partial specialization for SIMT GEMMs using multistage pipeline. 
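The SIMT cores above derive their per-lane MMA shape from the 128-bit access width: LaneM is the smaller of ThreadTileM and 128 / sizeof_bits<ElementA>::value (likewise LaneN for operand B), and lanes are interleaved two-wide once both thread-tile extents exceed 4. A minimal standalone sketch of that arithmetic for 32-bit operands, with illustrative constants (kWarpM, kWarpNumThreadsM, and the helper names are assumptions, not values from this header):

    // Per-lane MMA shape derivation for a hypothetical 32x64 warp tile split
    // across a 4x8 thread arrangement, using 32-bit (float) operands.
    constexpr int const_min_sketch(int a, int b) { return a < b ? a : b; }

    template <typename T>
    constexpr int kSizeofBits = int(sizeof(T)) * 8;   // stand-in for cutlass::sizeof_bits<T>::value

    constexpr int kWarpM = 32, kWarpN = 64;
    constexpr int kWarpNumThreadsM = 4, kWarpNumThreadsN = 8;
    constexpr int kThreadTileM = kWarpM / kWarpNumThreadsM;   // 8 rows per thread
    constexpr int kThreadTileN = kWarpN / kWarpNumThreadsN;   // 8 columns per thread

    // A 128-bit access holds 4 floats, so each lane multiplies a 4x4x1 fragment.
    constexpr int kLaneM = const_min_sketch(128 / kSizeofBits<float>, kThreadTileM);   // 4
    constexpr int kLaneN = const_min_sketch(128 / kSizeofBits<float>, kThreadTileN);   // 4

    // Lanes are interleaved 2-wide only when both thread-tile extents exceed 4.
    constexpr int kLaneLayout = (kThreadTileM > 4 && kThreadTileN > 4) ? 2 : 1;        // 2

    static_assert(kLaneM == 4 && kLaneN == 4 && kLaneLayout == 2, "derivation check");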
-+/// -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by Simt -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapA = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ SmemThreadMapA>; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = 4; // TODO need to extract these from template data -+ static const int WarpNumThreadsN = 8; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int 
ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ -+ static_assert(!((Shape::kK / 32) % LaneM), -+ "Padding must be divisible by Lane"); -+ -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ >; /// Used for partial specialization -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape, -+ MatrixShape<0, 0>, -+ WarpCount::kK>; -+}; -+ -+/// Partial specialization for SIMT GEMMs using multistage pipeline. -+/// -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by Simt -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::AffineRank2ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::AffineRank2RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename 
Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+}; -+ -+/// Partial specialization for SIMT GEMMs using multistage pipeline. -+/// -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by Simt -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::AffineRank2RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::AffineRank2ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+}; -+ -+/// Partial specialization for SIMT GEMMs using multistage pipeline. 
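The affine-rank-2 specializations above add no new machinery: each one names a Base core instantiated with the corresponding canonical layouts and re-exports its shared-memory layouts, iterators, and MmaPolicy. A minimal sketch of that delegation pattern (CoreSketch and the layout tags are illustrative names, not CUTLASS types):

    #include <type_traits>

    struct ColumnMajor {};
    struct RowMajor {};
    struct AffineRank2ColumnMajor {};
    struct AffineRank2RowMajor {};

    // Primary template: the "real" core, keyed on canonical layouts.
    template <typename LayoutA, typename LayoutB>
    struct CoreSketch {
      using SmemLayoutA = LayoutA;
      using SmemLayoutB = LayoutB;
      struct MmaPolicy {};  // placeholder for the warp-level policy
    };

    // Affine specialization: delegate everything to the canonical-layout core.
    template <>
    struct CoreSketch<AffineRank2ColumnMajor, AffineRank2RowMajor> {
      using Base = CoreSketch<ColumnMajor, RowMajor>;
      using SmemLayoutA = typename Base::SmemLayoutA;
      using SmemLayoutB = typename Base::SmemLayoutB;
      using MmaPolicy = typename Base::MmaPolicy;
    };

    static_assert(std::is_same<
        CoreSketch<AffineRank2ColumnMajor, AffineRank2RowMajor>::MmaPolicy,
        CoreSketch<ColumnMajor, RowMajor>::MmaPolicy>::value,
        "the affine variant reuses the Base policy");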
-+/// -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by Simt -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::AffineRank2ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::AffineRank2ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+ -+}; -+ -+/// Partial specialization for SIMT GEMMs using multistage pipeline. 
-+/// -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by Simt -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::AffineRank2RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::AffineRank2RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h -new file mode 100644 -index 0000000..870845f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h -@@ -0,0 +1,834 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming -+ expectations about data layout of the global memory fragments, data types, -+ and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting sparse -+ TensorOp instructions. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+ -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+#include "cutlass/gemm/warp/mma_simt.h" -+#include "cutlass/gemm/warp/default_mma_sparse_tensor_op.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+ -+#include "cutlass/gemm/threadblock/default_mma_core.h" -+ -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" -+#include "cutlass/gemm/threadblock/mma_sparse_multistage.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template defininng default matrix multiply operators inferred from threadblock tile size, -+/// global memory data layout, and target math instruction. 
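The sparse cores declared below assume 2:4 structured sparsity (kSparse == 2): operand A carries only K/2 values per row, and a small metadata operand E selects which two of every four positions are present. Because E is tiny, its iterator is handed only as many threads as it can actually use. A worked sketch of that clamp with assumed tile and packing constants (kM, kK, kElementsPerElementE, kBitsPerE are illustrative, not taken from this header):

    constexpr int kSparse = 2;
    constexpr int kThreads = 128;               // 4 warps x 32 threads (assumed)
    constexpr int kM = 128, kK = 64;            // threadblock tile (assumed)
    constexpr int kElementsPerElementE = 8;     // assumed metadata packing factor
    constexpr int kAccessSizeInBits = 128;
    constexpr int kBitsPerE = 16;               // assumed metadata element width
    constexpr int kElementsPerAccessE = kAccessSizeInBits / kBitsPerE;   // 8

    // Total 128-bit metadata accesses for the tile; clamp the thread count to that.
    constexpr int kAccessesE =
        kM * kK / kSparse / kElementsPerElementE / kElementsPerAccessE;  // 64
    constexpr int kThreadsE = kAccessesE > kThreads ? kThreads : kAccessesE;

    static_assert(kAccessesE == 64 && kThreadsE == 64,
                  "E needs fewer threads than the CTA provides");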
-+template < -+ /// Shape of threadblock-scoped matrix multiply operator -+ typename Shape, -+ /// Shape of warp-level matrix multiply operator -+ typename WarpShape, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape, -+ /// Element data type of A operand -+ typename ElementA, -+ /// Layout of operand A -+ typename LayoutA, -+ /// Element data type of B operand -+ typename ElementB, -+ /// Layout of operand B -+ typename LayoutB, -+ /// Data type of accumulator -+ typename ElementC, -+ /// Layout of accumulator -+ typename LayoutC, -+ /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) -+ typename OperatorClass, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator = typename platform::conditional< -+ (platform::is_same::value) && -+ (platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value), -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::arch::OpMultiplyAdd>::type, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false -+ /// Cache operation of operand A -+ , cutlass::arch::CacheOperation::Kind CacheOpA = -+ cutlass::arch::CacheOperation::Global, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB = -+ cutlass::arch::CacheOperation::Global -+> -+struct DefaultSparseMmaCore; -+ -+//////////////////////////////////////////////////////////////////////////////// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultSparseMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ static int const kSparse = 2; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = 
warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementA))>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementB))>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Cache operation of operand E -+ static cutlass::arch::CacheOperation::Kind const kCacheOpE = -+ cutlass::arch::CacheOperation::Global; -+ -+ static int const kInterleavedE = MmaTensorOp::kInterleaved; -+ static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits; -+ static int const kMaxID2 = MmaTensorOp::kMaxID2; -+ static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE; -+ -+ using ElementE = typename MmaTensorOp::ElementE; -+ using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved; -+ -+ // Shared memory layout. Interleaved layout is mapped to PitchLinear layout. -+ using SmemLayoutE = typename MmaTensorOp::LayoutE; -+ -+ /// ThreadMap of iterator E -+ static int const kElementsPerAccessE = -+ kAccessSizeInBits / sizeof_bits::value; -+ -+ /// E is tiny. Not all warps are needed. -+ static int const kThreadsE = -+ (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / -+ (kAccessSizeInBits / sizeof_bits::value) > -+ kThreads) -+ ? 
kThreads -+ : (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / -+ (kAccessSizeInBits / sizeof_bits::value)); -+ -+ using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreadsE, kElementsPerAccessE>; -+ -+ /// Shared memory iterator to E operand -+ using SmemIteratorE = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, -+ ElementE, SmemLayoutE, 0, IteratorThreadMapE>; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = -+ SparseMmaPolicy, MatrixShape<0, 0>, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultSparseMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ static int const kSparse = 2; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousA = -+ Shape::kK / kSparse / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedA = -+ kWarpSize / kWarpThreadArrangementContiguousA; -+ -+ // crosswise cannot be larger than 1024 bit. -+ static int const kCrosswiseB = -+ (Shape::kK > (1024 / sizeof_bits::value)) -+ ? 
(1024 / sizeof_bits::value) -+ : Shape::kK; -+ -+ static int const kWarpThreadArrangementContiguousB = -+ kCrosswiseB / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedB = -+ kWarpSize / kWarpThreadArrangementContiguousB; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK / kSparse>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kCrosswiseB>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Cache operation of operand E -+ static cutlass::arch::CacheOperation::Kind const kCacheOpE = -+ cutlass::arch::CacheOperation::Global; -+ -+ static int const kInterleavedE = MmaTensorOp::kInterleaved; -+ static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits; -+ static int const kMaxID2 = MmaTensorOp::kMaxID2; -+ static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE; -+ -+ using ElementE = typename MmaTensorOp::ElementE; -+ using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved; -+ -+ // Shared memory layout. Interleaved layout is mapped to PitchLinear layout. -+ using SmemLayoutE = typename MmaTensorOp::LayoutE; -+ -+ /// ThreadMap of iterator E -+ static int const kElementsPerAccessE = -+ kAccessSizeInBits / sizeof_bits::value; -+ -+ /// E is tiny. Not all warps are needed. -+ static int const kThreadsE = -+ (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / -+ (kAccessSizeInBits / sizeof_bits::value) > -+ kThreads) -+ ? 
kThreads -+ : (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / -+ (kAccessSizeInBits / sizeof_bits::value)); -+ -+ using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreadsE, kElementsPerAccessE>; -+ -+ -+ /// Shared memory iterator to E operand -+ using SmemIteratorE = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, -+ ElementE, SmemLayoutE, 0, IteratorThreadMapE>; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = -+ SparseMmaPolicy, MatrixShape<0, 0>, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultSparseMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ static int const kSparse = 2; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ // crosswise cannot be larger than 1024 bit. -+ static int const kCrosswiseB = -+ (Shape::kK > (1024 / sizeof_bits::value)) -+ ? 
(1024 / sizeof_bits::value) -+ : Shape::kK; -+ -+ static int const kWarpThreadArrangementContiguousB = -+ kCrosswiseB / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedB = -+ kWarpSize / kWarpThreadArrangementContiguousB; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementA))>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kCrosswiseB>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Cache operation of operand E -+ static cutlass::arch::CacheOperation::Kind const kCacheOpE = -+ cutlass::arch::CacheOperation::Global; -+ -+ static int const kInterleavedE = MmaTensorOp::kInterleaved; -+ static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits; -+ static int const kMaxID2 = MmaTensorOp::kMaxID2; -+ static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE; -+ -+ using ElementE = typename MmaTensorOp::ElementE; -+ using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved; -+ -+ // Shared memory layout. Interleaved layout is mapped to PitchLinear layout. -+ using SmemLayoutE = typename MmaTensorOp::LayoutE; -+ -+ /// ThreadMap of iterator E -+ static int const kElementsPerAccessE = -+ kAccessSizeInBits / sizeof_bits::value; -+ -+ /// E is tiny. Not all warps are needed. -+ static int const kThreadsE = -+ (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / -+ (kAccessSizeInBits / sizeof_bits::value) > -+ kThreads) -+ ? 
kThreads -+ : (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / -+ (kAccessSizeInBits / sizeof_bits::value)); -+ -+ using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreadsE, kElementsPerAccessE>; -+ -+ /// Shared memory iterator to E operand -+ using SmemIteratorE = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, -+ ElementE, SmemLayoutE, 0, IteratorThreadMapE>; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = -+ SparseMmaPolicy, MatrixShape<0, 0>, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultSparseMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ static int const kSparse = 2; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousA = -+ Shape::kK / kSparse / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedA = -+ kWarpSize / kWarpThreadArrangementContiguousA; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK / kSparse>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementB))>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ 
layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Cache operation of operand E -+ static cutlass::arch::CacheOperation::Kind const kCacheOpE = -+ cutlass::arch::CacheOperation::Global; -+ -+ static int const kInterleavedE = MmaTensorOp::kInterleaved; -+ static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits; -+ static int const kMaxID2 = MmaTensorOp::kMaxID2; -+ static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE; -+ -+ using ElementE = typename MmaTensorOp::ElementE; -+ using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved; -+ -+ // Shared memory layout. Interleaved layout is mapped to PitchLinear layout. -+ using SmemLayoutE = typename MmaTensorOp::LayoutE; -+ -+ /// ThreadMap of iterator E -+ static int const kElementsPerAccessE = -+ kAccessSizeInBits / sizeof_bits::value; -+ -+ /// E is tiny. Not all warps are needed. -+ static int const kThreadsE = -+ (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / -+ (kAccessSizeInBits / sizeof_bits::value) > -+ kThreads) -+ ? kThreads -+ : (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / -+ (kAccessSizeInBits / sizeof_bits::value)); -+ -+ using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreadsE, kElementsPerAccessE>; -+ -+ /// Shared memory iterator to E operand -+ using SmemIteratorE = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, -+ ElementE, SmemLayoutE, 0, IteratorThreadMapE>; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = -+ SparseMmaPolicy, MatrixShape<0, 0>, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_with_access_size.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_with_access_size.h -new file mode 100644 -index 0000000..0345084 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_with_access_size.h -@@ -0,0 +1,328 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data -+ layout of the global memory fragments, data types, and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting simt instructions. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/warp/mma.h" -+#include "cutlass/gemm/threadblock/mma_pipelined.h" -+#include "cutlass/gemm/threadblock/mma_singlestage.h" -+#include "cutlass/arch/cache_operation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+template < -+ /// Shape of threadblock-scoped matrix multiply operator -+ typename Shape, -+ /// Shape of warp-level matrix multiply operator -+ typename WarpShape, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape, -+ /// Element data type of A operand -+ typename ElementA, -+ /// Layout of operand A -+ typename LayoutA, -+ /// Element data type of B operand -+ typename ElementB, -+ /// Layout of operand B -+ typename LayoutB, -+ /// Data type of accumulator -+ typename ElementC, -+ /// Layout of accumulator -+ typename LayoutC, -+ /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) -+ typename OperatorClass, -+ /// Size of a threadblock-scoped access -+ int kAccessSizeInBits = -1, // -1 denoting the default -+ /// Number of stages -+ int Stages = 2, -+ /// Operation performed by MMA -+ typename Operator = typename platform::conditional< -+ (platform::is_same::value) && -+ (platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value), -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::arch::OpMultiplyAdd>::type, -+ /// Store the accumulators in row major or column major. 
Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA = -+ cutlass::arch::CacheOperation::Global, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB = -+ cutlass::arch::CacheOperation::Global, -+ /// per-element transformation for elements of A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// per-element transformation for elements of B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ bool IsComplex = false // (is_complex::value || is_complex::value) -+> -+struct DefaultMmaCoreWithAccessSize; -+ -+template < -+ /// Shape of threadblock-scoped matrix multiply operator -+ typename Shape, -+ /// Shape of warp-level matrix multiply operator -+ typename WarpShape, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape, -+ /// Element data type of A operand -+ typename ElementA, -+ /// Layout of operand A -+ typename LayoutA, -+ /// Element data type of B operand -+ typename ElementB, -+ /// Layout of operand B -+ typename LayoutB, -+ /// Data type of accumulator -+ typename ElementC, -+ /// Layout of accumulator -+ typename LayoutC, -+ /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) -+ typename OperatorClass, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// per-element transformation for elements of A -+ ComplexTransform TransformA, -+ /// per-element transformation for elements of B -+ ComplexTransform TransformB, -+ bool IsComplex -+> -+struct DefaultMmaCoreWithAccessSize< -+ Shape, WarpShape, InstructionShape, -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, -+ OperatorClass, -1, Stages, Operator, AccumulatorsInRowMajor, -+ CacheOpA, CacheOpB, TransformA, TransformB, IsComplex -+> : DefaultMmaCore< -+ Shape, WarpShape, InstructionShape, -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, -+ OperatorClass, Stages, Operator, AccumulatorsInRowMajor, -+ CacheOpA, CacheOpB, TransformA, TransformB, IsComplex -+> {}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Size of a threadblock-scoped access (a value of -1 indicates the default) -+ int kAccessSizeInBits_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCoreWithAccessSize>::type, ElementA_, -+ 
layout::ColumnMajor, ElementB_, layout::RowMajor, -+ ElementC_, LayoutC_, arch::OpClassSimt, kAccessSizeInBits_, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ static int const PartitionsK = Shape::kK / WarpShape::kK; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ PartitionsK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ static int const kElementsPerAccessDefault = 1; -+ static_assert(kAccessSizeInBits_ == -1 || -+ sizeof_bits::value == sizeof_bits::value || -+ kAccessSizeInBits_ / sizeof_bits::value == kElementsPerAccessDefault, -+ "Non-default value for kAccessSizeInBits_ is only allowed if size(elementA) == sizeof(elementB)"); -+ static int const kElementsPerAccess = (kAccessSizeInBits_ != -1) ? kAccessSizeInBits_ / sizeof_bits::value : kElementsPerAccessDefault; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); -+ static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 
2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ >; /// Used for partial specialization -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_with_reduction.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_with_reduction.h -new file mode 100644 -index 0000000..d150791 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_with_reduction.h -@@ -0,0 +1,167 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming -+ expectations about data layout of the global memory fragments, data types, -+ and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting TensorOp -+ instructions. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+ -+#include "cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+ -+#include "cutlass/gemm/threadblock/default_mma_core.h" -+ -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" -+#include "cutlass/gemm/threadblock/mma_with_reduction_multistage.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template defininng default matrix multiply operators inferred from threadblock tile size, -+/// global memory data layout, and target math instruction. -+template < -+ /// Shape of threadblock-scoped matrix multiply operator -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator -+ typename WarpShape, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape, -+ /// Element data type of A operand -+ typename ElementA, -+ /// Layout of operand A -+ typename LayoutA, -+ /// Element data type of B operand -+ typename ElementB, -+ /// Layout of operand B -+ typename LayoutB, -+ /// Data type of accumulator -+ typename ElementC, -+ /// Layout of accumulator -+ typename LayoutC, -+ /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) -+ typename OperatorClass, -+ /// Reduce operand A or B along K dimension -+ bool ReduceKForA_, -+ /// Number of stages -+ int Stages = 2, -+ /// Operation performed by MMA -+ typename Operator = typename platform::conditional< -+ (platform::is_same::value) && -+ (platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value), -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::arch::OpMultiplyAdd>::type, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. 
-+ bool AccumulatorsInRowMajor = false, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA = -+ cutlass::arch::CacheOperation::Global, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB = -+ cutlass::arch::CacheOperation::Global, -+ /// per-element transformation for elements of A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// per-element transformation for elements of B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ bool IsComplex = false// (is_complex::value || is_complex::value) -+> -+struct DefaultMmaWithReductionCore { -+ using Base = DefaultMmaCore; -+ using Shape = Shape_; -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ using WarpCount = typename Base::WarpCount; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaWithReductionTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, ReduceKForA_, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_wmma.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_wmma.h -new file mode 100644 -index 0000000..f4d0a23 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_wmma.h -@@ -0,0 +1,712 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data -+ layout of the global memory fragments, data types, and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting TensorOp instructions. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/arch/wmma.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_wmma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+#include "cutlass/gemm/threadblock/default_mma_core.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: wmma tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ ///< Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_, -+ /// Number of stages -+ int Stages> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassWmmaTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." 
-+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // -+ // Shared memory layouts -+ // -+ // NOTE: shared memory layout for wmma is same as the operands' layout in the global memory -+ using SmemLayoutA = LayoutA; -+ using SmemLayoutB = LayoutB; -+ -+ // Pad shared memory to avoid bank conflicts -+ static int const kPaddingA = 128 / sizeof_bits::value; -+ static int const kPaddingB = 128 / sizeof_bits::value; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Wmma< -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Operator -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma< -+ WarpShape, -+ ElementA, -+ SmemLayoutA, -+ ElementB, -+ SmemLayoutB, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape, -+ MatrixShape<0, kPaddingB>, -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: wmma tensorop class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ ///< Shape of threadblock-scoped matrix multiply operator -+ ///< (concept:GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) [allowed -+ /// wmma instruction shapes, e.g., 16x16x16, 32x8x16, 8x32x16,...] 
-+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_, -+ /// Number of stages -+ int Stages> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassWmmaTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads per threadblock -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousA = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedA = -+ kWarpSize / kWarpThreadArrangementContiguousA; -+ -+ static int const kWarpThreadArrangementContiguousB = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedB = -+ kWarpSize / kWarpThreadArrangementContiguousB; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ // shared memory layout for wmma is same as the operands' layout in global memory -+ using SmemLayoutA = LayoutA; -+ using SmemLayoutB = LayoutB; -+ -+ // Pad shared memory to avoid bank conflicts -+ static int const kPaddingA = 128 / sizeof_bits::value; -+ static int const kPaddingB = 128 / sizeof_bits::value; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB // SmemThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Wmma< -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Operator -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma< -+ WarpShape, -+ ElementA, -+ SmemLayoutA, -+ ElementB, -+ 
SmemLayoutB, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, kPaddingA>, -+ MatrixShape, -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Number of stages -+ int Stages> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassWmmaTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." 
-+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousA = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedA = -+ kWarpSize / kWarpThreadArrangementContiguousA; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ // shared memory layout for wmma is same as the operands' layout in global memory -+ using SmemLayoutA = LayoutA; -+ using SmemLayoutB = LayoutB; -+ -+ // Pad shared memory to avoid bank conflicts -+ static int const kPaddingA = 128 / sizeof_bits::value; -+ static int const kPaddingB = 128 / sizeof_bits::value; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Wmma< -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Operator -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma< -+ WarpShape, -+ ElementA, -+ SmemLayoutA, -+ ElementB, -+ SmemLayoutB, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, kPaddingA>, -+ MatrixShape<0, kPaddingB>, -+ WarpCount::kK -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Number of stages -+ int Stages> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ 
using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassWmmaTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = -+ GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousB = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedB = -+ kWarpSize / kWarpThreadArrangementContiguousB; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ // shared memory layout for wmma is same as the operands' layout in global memory -+ using SmemLayoutA = LayoutA; -+ using SmemLayoutB = LayoutB; -+ -+ // Pad shared memory to avoid bank conflicts -+ static int const kPaddingA = 128 / sizeof_bits::value; -+ static int const kPaddingB = 128 / sizeof_bits::value; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Wmma< -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Operator -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma< -+ WarpShape, -+ ElementA, -+ SmemLayoutA, -+ ElementB, -+ SmemLayoutB, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape, -+ MatrixShape, -+ WarpCount::kK -+ >; -+}; -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+#endif // defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_layernorm_mainloop_fusion.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_layernorm_mainloop_fusion.h -new file mode 100644 -index 0000000..b05c634 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_layernorm_mainloop_fusion.h -@@ -0,0 +1,178 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. 
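-+
-+    Illustrative instantiation (a minimal sketch; the element types, alignments,
-+    tile shapes, and stage count below are assumed values chosen only to show the
-+    parameter order, not taken from this header):
-+
-+    \code
-+    using Mma = typename cutlass::gemm::threadblock::DefaultMmaLayernormMainloopFusion<
-+        cutlass::half_t, cutlass::layout::RowMajor, 8,        // ElementA, LayoutA, kAlignmentA
-+        cutlass::half_t, cutlass::layout::ColumnMajor, 8,     // ElementB, LayoutB, kAlignmentB
-+        cutlass::half_t, cutlass::layout::RowMajor,           // ElementScaleBias, LayoutScaleBias
-+        float, cutlass::layout::RowMajor,                     // ElementAccumulator, LayoutC
-+        cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,  // OperatorClass, ArchTag
-+        cutlass::gemm::GemmShape<128, 128, 32>,               // ThreadblockShape
-+        cutlass::gemm::GemmShape<64, 64, 32>,                 // WarpShape
-+        cutlass::gemm::GemmShape<16, 8, 16>,                  // InstructionShape
-+        3,                                                    // Stages
-+        cutlass::arch::OpMultiplyAdd                          // Operator
-+    >::ThreadblockMma;
-+    \endcode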
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/threadblock/default_mma_core.h" -+#include "cutlass/gemm/threadblock/mma_layernorm_mainloop_fusion_multistage.h" -+#include "cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h" -+#include "cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h" -+#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h" -+#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for Scale/Bias vectors -+ typename ElementScaleBias, -+ /// Layout type for Scale/Bias vectors -+ typename LayoutScaleBias, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false, -+ /// Use zfill or predicate for SM80 out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone -+ > -+struct DefaultMmaLayernormMainloopFusion { -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? 
cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpGammaBeta = CacheOpA; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using IteratorVarMean = -+ cutlass::transform::threadblock::PredicatedScaleBiasVectorIterator< -+ cutlass::MatrixShape<1, WarpShape::kN>, -+ ElementScaleBias, -+ LayoutScaleBias>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using IteratorGammaBeta = -+ cutlass::transform::threadblock::PredicatedScaleBiasVectorAccessIterator< -+ cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, -+ LayoutScaleBias>; -+ -+ using SmemIteratorGammaBeta = -+ cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< -+ cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, -+ LayoutScaleBias>; -+ -+ static int const kThreadCount = 32; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using WarpIteratorGammaBeta = cutlass::gemm::warp::ScaleBiasTileIterator< -+ MatrixShape, ElementScaleBias, -+ LayoutScaleBias, MatrixShape, -+ typename MmaCore::MmaTensorOp::IteratorA::Base::Policy, kThreadCount, -+ MmaCore::WarpCount::kK>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaLayernormMainloopFusionMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, IteratorVarMean, IteratorGammaBeta, SmemIteratorGammaBeta, -+ CacheOpGammaBeta, -+ ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, WarpIteratorGammaBeta, Stages, SharedMemoryClear>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_multistage.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_multistage.h -new file mode 100644 -index 0000000..6915b20 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_multistage.h -@@ -0,0 +1,136 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Template for a multistage GEMM kernel. Does not compute batching or support split-K. 
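-+
-+    Illustrative instantiation (a minimal sketch; the element types, alignments,
-+    tile shapes, and stage count are assumed values chosen only to show the
-+    parameter order; the remaining parameters keep their defaults):
-+
-+    \code
-+    using Mma = typename cutlass::gemm::threadblock::DefaultMmaPlanarComplexMultistage<
-+        cutlass::half_t, cutlass::layout::RowMajor, 8,        // ElementA, LayoutA, kAlignmentA
-+        cutlass::half_t, cutlass::layout::ColumnMajor, 8,     // ElementB, LayoutB, kAlignmentB
-+        float, cutlass::layout::RowMajor,                     // ElementAccumulator, LayoutC
-+        cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,  // OperatorClass, ArchTag
-+        cutlass::gemm::GemmShape<64, 64, 32>,                 // ThreadblockShape
-+        cutlass::gemm::GemmShape<32, 32, 32>,                 // WarpShape
-+        cutlass::gemm::GemmShape<16, 8, 16>,                  // InstructionShape
-+        3                                                     // Stages
-+    >::ThreadblockMma;
-+    \endcode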
-+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/arch.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/mma_planar_complex_multistage.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Math operator tag (e.g. 
arch::OpMultiplyAdd) -+ typename Operator = arch::OpMultiplyAdd -+> -+struct DefaultMmaPlanarComplexMultistage { -+ -+ // Construct a planar complex variant from the real-valued variant -+ using RealMmaMultistage = typename DefaultMma< -+ ElementA_, -+ LayoutA_, -+ kAlignmentA, -+ ElementB_, -+ LayoutB_, -+ kAlignmentB, -+ ElementAccumulator_, -+ LayoutC_, -+ OperatorClass_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ Stages, -+ Operator -+ >::ThreadblockMma; -+ -+ using ThreadblockMma = MmaPlanarComplexMultistage< -+ ThreadblockShape_, -+ typename RealMmaMultistage::IteratorA, -+ typename RealMmaMultistage::SmemIteratorA, -+ cutlass::arch::CacheOperation::Global, -+ typename RealMmaMultistage::IteratorB, -+ typename RealMmaMultistage::SmemIteratorB, -+ cutlass::arch::CacheOperation::Global, -+ ElementAccumulator_, -+ LayoutC_, -+ typename RealMmaMultistage::Policy, -+ Stages, -+ TransformA, -+ TransformB -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_pipelined.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_pipelined.h -new file mode 100644 -index 0000000..a7ae5a4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_pipelined.h -@@ -0,0 +1,130 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#include "cutlass/gemm/warp/mma_planar_complex.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/mma_planar_complex_pipelined.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Math operator tag (e.g. arch::OpMultiplyAdd) -+ typename Operator = arch::OpMultiplyAdd -+> -+struct DefaultMmaPlanarComplexPipelined { -+ -+ // Construct a planar complex variant from the real-valued variant -+ using RealMma = typename DefaultMma< -+ ElementA_, -+ LayoutA_, -+ kAlignmentA, -+ ElementB_, -+ LayoutB_, -+ kAlignmentB, -+ ElementAccumulator_, -+ LayoutC_, -+ OperatorClass_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ Stages, -+ Operator -+ >::ThreadblockMma; -+ -+ using ThreadblockMma = MmaPlanarComplexPipelined< -+ ThreadblockShape_, -+ typename RealMma::IteratorA, -+ typename RealMma::SmemIteratorA, -+ typename RealMma::IteratorB, -+ typename RealMma::SmemIteratorB, -+ ElementAccumulator_, -+ LayoutC_, -+ typename RealMma::Policy, -+ Stages, -+ TransformA, -+ TransformB -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_softmax_mainloop_fusion.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_softmax_mainloop_fusion.h -new file mode 100644 -index 0000000..e8db4d8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_softmax_mainloop_fusion.h -@@ -0,0 +1,160 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined softmax-GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/threadblock/default_mma_core.h" -+#include "cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h" -+#include "cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h" -+#include "cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h" -+#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h" -+#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for Scale/Bias vectors -+ typename ElementScaleBias, -+ /// Layout type for Scale/Bias vectors -+ typename LayoutScaleBias, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// 
Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Whether problem has been transformed. This determines to which operand -+ /// the softmax is applied. -+ bool InternalTranspose, -+ /// Operation perfomed by GEMM -+ typename Operator, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false, -+ /// Use zfill or predicate for SM80 out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone -+ > -+struct DefaultMmaSoftmaxMainloopFusion { -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpGammaBeta = CacheOpA; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using IteratorNormSum = -+ cutlass::transform::threadblock::PredicatedScaleBiasVectorIterator< -+ cutlass::MatrixShape<1, WarpShape::kN>, -+ ElementScaleBias, -+ LayoutScaleBias>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaSoftmaxMainloopFusionMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, IteratorNormSum, -+ ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, InternalTranspose, SharedMemoryClear>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_with_reduction.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_with_reduction.h -new file mode 100644 -index 0000000..bc6957a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_with_reduction.h 
-@@ -0,0 +1,141 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. 
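-+
-+    Illustrative instantiation (a minimal sketch; the element types, alignments,
-+    tile shapes, reduction operand, and stage count are assumed values chosen only
-+    to show the parameter order):
-+
-+    \code
-+    using Mma = typename cutlass::gemm::threadblock::DefaultMmaWithReduction<
-+        cutlass::half_t, cutlass::layout::RowMajor, 8,        // ElementA, LayoutA, kAlignmentA
-+        cutlass::half_t, cutlass::layout::ColumnMajor, 8,     // ElementB, LayoutB, kAlignmentB
-+        float, cutlass::layout::RowMajor,                     // ElementAccumulator, LayoutC
-+        cutlass::arch::OpClassTensorOp,                       // OperatorClass
-+        true,                                                 // ReduceKForA_
-+        cutlass::arch::Sm80,                                  // ArchTag
-+        cutlass::gemm::GemmShape<128, 128, 32>,               // ThreadblockShape
-+        cutlass::gemm::GemmShape<64, 64, 32>,                 // WarpShape
-+        cutlass::gemm::GemmShape<16, 8, 16>,                  // InstructionShape
-+        3,                                                    // Stages
-+        cutlass::arch::OpMultiplyAdd                          // Operator
-+    >::ThreadblockMma;
-+    \endcode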
-+*/
-+
-+#pragma once
-+
-+#include "cutlass/cutlass.h"
-+#include "cutlass/numeric_types.h"
-+#include "cutlass/arch/arch.h"
-+
-+#include "cutlass/layout/matrix.h"
-+#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
-+#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h"
-+#include "cutlass/gemm/threadblock/default_mma_core_with_reduction.h"
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+namespace cutlass {
-+namespace gemm {
-+namespace threadblock {
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+template <
-+    /// Element type for A matrix operand
-+    typename ElementA,
-+    /// Layout type for A matrix operand
-+    typename LayoutA,
-+    /// Access granularity of A matrix in units of elements
-+    int kAlignmentA,
-+    /// Element type for B matrix operand
-+    typename ElementB,
-+    /// Layout type for B matrix operand
-+    typename LayoutB,
-+    /// Access granularity of B matrix in units of elements
-+    int kAlignmentB,
-+    /// Element type for internal accumulation
-+    typename ElementAccumulator,
-+    /// Layout type for C and D matrix operands
-+    typename LayoutC,
-+    /// Operator class tag
-+    typename OperatorClass,
-+    ///
-+    bool ReduceKForA_,
-+    /// Tag indicating architecture to tune for
-+    typename ArchTag,
-+    /// Threadblock-level tile size (concept: GemmShape)
-+    typename ThreadblockShape,
-+    /// Warp-level tile size (concept: GemmShape)
-+    typename WarpShape,
-+    /// Instruction-level tile size (concept: GemmShape)
-+    typename InstructionShape,
-+    /// Number of stages used in the pipelined mainloop
-+    int Stages,
-+    /// Operation perfomed by GEMM
-+    typename Operator,
-+    /// Store the accumulators in row major or column major.  Row major is used
-+    /// when output layout is interleaved.
-+    bool AccumulatorsInRowMajor = false,
-+    /// Use zfill or predicate for SM80 out-of-bound cp.async
-+    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone
-+    >
-+struct DefaultMmaWithReduction {
-+
-+  static cutlass::arch::CacheOperation::Kind const CacheOpA =
-+      ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
-+          ? cutlass::arch::CacheOperation::Global
-+          : cutlass::arch::CacheOperation::Always;
-+
-+  static cutlass::arch::CacheOperation::Kind const CacheOpB =
-+      ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
-+          ? cutlass::arch::CacheOperation::Global
-+          : cutlass::arch::CacheOperation::Always;
-+
-+  // Define the MmaCore components
-+  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaWithReductionCore<
-+      ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
-+      ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
-+      ReduceKForA_, Stages, Operator, false, CacheOpA, CacheOpB>;
-+
-+  // Define iterators over tiles from the A operand
-+  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
-+  using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
-+  using IteratorA =
-+      cutlass::transform::threadblock::PredicatedTileAccessIterator<
-+          cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
-+          ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>;
-+
-+  // Define iterators over tiles from the B operand
-+  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
-+  using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
-+  using IteratorB =
-+      cutlass::transform::threadblock::PredicatedTileAccessIterator<
-+          cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
-+          ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>;
-+
-+  // Define the threadblock-scoped multistage matrix multiply
-+  using ThreadblockMma = cutlass::gemm::threadblock::MmaWithReductionMultistage<
-+      typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,
-+      MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
-+      MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor,
-+      typename MmaCore::MmaPolicy, Stages, SharedMemoryClear>;
-+};
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+} // namespace threadblock
-+} // namespace gemm
-+} // namespace cutlass
-+
-+////////////////////////////////////////////////////////////////////////////////
-diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex.h
-new file mode 100644
-index 0000000..4bd3530
---- /dev/null
-+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex.h
-@@ -0,0 +1,159 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Template for a multistage GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/arch.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator = arch::OpMultiplyAddComplex, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. 
-+ bool AccumulatorsInRowMajor = false> -+struct DefaultMultistageMmaComplex; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator> -+struct DefaultMultistageMmaComplex { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, TransformA, TransformB, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages>; -+}; -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core.h -new file mode 100644 -index 0000000..79b4ec3 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core.h -@@ -0,0 +1,119 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming -+ expectations about data layout of the global memory fragments, data types, -+ and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting TensorOp -+ instructions. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/complex.h" -+ -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+ -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+#include "cutlass/gemm/warp/mma_simt.h" -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+ -+#include "cutlass/gemm/threadblock/default_mma_core.h" -+ -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template defininng default matrix multiply operators inferred from -+/// threadblock tile size, global memory data layout, and target math -+/// instruction. 
-+template < -+ /// Shape of threadblock-scoped matrix multiply operator -+ typename Shape, -+ /// Shape of warp-level matrix multiply operator -+ typename WarpShape, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape, -+ /// Element data type of A operand -+ typename ElementA, -+ /// Layout of operand A -+ typename LayoutA, -+ /// Element data type of B operand -+ typename ElementB, -+ /// Layout of operand B -+ typename LayoutB, -+ /// Data type of accumulator -+ typename ElementC, -+ /// Layout of accumulator -+ typename LayoutC, -+ /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) -+ typename OperatorClass, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator = arch::OpMultiplyAddComplex, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA = -+ cutlass::arch::CacheOperation::Global, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB = -+ cutlass::arch::CacheOperation::Global> -+struct DefaultMultistageMmaComplexCore; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h -new file mode 100644 -index 0000000..1a7065b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h -@@ -0,0 +1,1808 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming -+ expectations about data layout of the global memory fragments, data types, -+ and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting TensorOp -+ instructions. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+ -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+#include "cutlass/gemm/warp/mma_simt.h" -+#include "cutlass/gemm/warp/default_mma_complex_tensor_op.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+ -+#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core.h" -+ -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" -+#include "cutlass/gemm/threadblock/mma_multistage.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex double-precision -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, InstructionShape_, -+ complex, layout::ColumnMajor, -+ complex, layout::RowMajor, -+ complex, LayoutC_, -+ arch::OpClassTensorOp, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape 
= WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = complex; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = complex; -+ using LayoutB = layout::RowMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped 128 -+ static int const kAccessSizeInBits = 128; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ WarpShape, InstructionShape, -+ ElementA, SmemLayoutA, -+ ElementB, SmemLayoutB, -+ ElementC, LayoutC, -+ kTransformA, kTransformB, -+ Operator>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+ -+/// Partial specialization for complex double-precision -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform 
TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, InstructionShape_, -+ complex, layout::ColumnMajor, -+ complex, layout::ColumnMajor, -+ complex, LayoutC_, -+ arch::OpClassTensorOp, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = complex; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = complex; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ using Operator = Operator_; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped 128 -+ static int const kAccessSizeInBits = 128; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ WarpShape, InstructionShape, -+ ElementA, SmemLayoutA, -+ ElementB, SmemLayoutB, -+ ElementC, LayoutC, -+ kTransformA, kTransformB, -+ Operator>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ 
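For orientation, here is a minimal usage sketch (not part of the patch) of how the defaults added above compose into a threadblock-scoped mainloop. The class names and template-argument order follow DefaultMultistageMmaComplex as declared in default_multistage_mma_complex.h; the element type, layouts, tile shapes, stage count, and instruction shape below are illustrative assumptions only.

// Illustrative sketch only -- tile shapes, stages, and element types are assumptions.
#include "cutlass/complex.h"                                       // cutlass::complex, ComplexTransform
#include "cutlass/layout/matrix.h"                                 // ColumnMajor, RowMajor
#include "cutlass/arch/arch.h"                                     // arch::Sm80
#include "cutlass/arch/mma.h"                                      // OpClassTensorOp, OpMultiplyAddComplex
#include "cutlass/gemm/gemm.h"                                     // GemmShape
#include "cutlass/gemm/threadblock/default_multistage_mma_complex.h"

// A 64x64x16 threadblock tile split into 2x2 warps of 32x32x16, which satisfies
// the "at least two warps" static_assert in the core specializations above.
using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>;
using WarpShape        = cutlass::gemm::GemmShape<32, 32, 16>;
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;        // assumed f64 tensor-op MMA shape

// Compose the default threadblock-scoped multistage mainloop for a
// complex<double> GEMM: column-major A, row-major B, row-major accumulators.
using DefaultMma = cutlass::gemm::threadblock::DefaultMultistageMmaComplex<
    cutlass::complex<double>, cutlass::layout::ColumnMajor,        // ElementA, LayoutA
    cutlass::complex<double>, cutlass::layout::RowMajor,           // ElementB, LayoutB
    cutlass::complex<double>, cutlass::layout::RowMajor,           // ElementAccumulator, LayoutC
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    ThreadblockShape, WarpShape, InstructionShape,
    /*Stages=*/3,
    cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone,
    cutlass::arch::OpMultiplyAddComplex>;

// Types produced by the default: global-memory tile iterators and the mainloop.
using IteratorA      = DefaultMma::IteratorA;
using IteratorB      = DefaultMma::IteratorB;
using ThreadblockMma = DefaultMma::ThreadblockMma;

A kernel would then drive ThreadblockMma with IteratorA/IteratorB across the K dimension and hand the accumulators to an epilogue; the specializations that follow provide the shared-memory layouts and thread maps this composition relies on.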
-+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex double-precision -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, InstructionShape_, -+ complex, layout::RowMajor, -+ complex, layout::ColumnMajor, -+ complex, LayoutC_, -+ arch::OpClassTensorOp, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = complex; -+ using LayoutA = layout::RowMajor; -+ using ElementB = complex; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped 128 -+ static int const kAccessSizeInBits = 128; -+ -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise128x4; -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = 
transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ WarpShape, InstructionShape, -+ ElementA, SmemLayoutA, -+ ElementB, SmemLayoutB, -+ ElementC, LayoutC, -+ kTransformA, kTransformB, -+ Operator>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+ -+/// Partial specialization for complex double-precision -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, InstructionShape_, -+ complex, layout::RowMajor, -+ complex, layout::RowMajor, -+ complex, LayoutC_, -+ arch::OpClassTensorOp, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = complex; -+ using LayoutA = layout::RowMajor; -+ using ElementB = complex; -+ using LayoutB = layout::RowMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped 128 -+ static int const kAccessSizeInBits = 128; -+ -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using 
SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise128x4; -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ WarpShape, InstructionShape, -+ ElementA, SmemLayoutA, -+ ElementB, SmemLayoutB, -+ ElementC, LayoutC, -+ kTransformA, kTransformB, -+ Operator>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex floating-point -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: arch::OpMultiplyAddComplex -+/// Math Instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, GemmShape<16, 8, 8>, -+ complex, layout::ColumnMajor, -+ complex, layout::ColumnMajor, -+ complex, LayoutC_, -+ arch::OpClassTensorOp, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<16, 8, 8>; -+ using ElementA = complex; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = complex; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ 
-+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped -+ static int const kAccessSizeInBits = 64; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpStripedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ WarpShape, InstructionShape, -+ ElementA, SmemLayoutA, -+ ElementB, SmemLayoutB, -+ ElementC, LayoutC, -+ kTransformA, kTransformB, -+ Operator>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+ -+/// Partial specialization for complex floating-point -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: arch::OpMultiplyAddComplex -+/// Math Instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, GemmShape<16, 8, 8>, -+ complex, layout::ColumnMajor, -+ complex, layout::RowMajor, -+ complex, LayoutC_, -+ arch::OpClassTensorOp, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<16, 8, 8>; -+ 
using ElementA = complex; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = complex; -+ using LayoutB = layout::RowMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped -+ static int const kAccessSizeInBits = 64; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpStripedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpStripedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ WarpShape, InstructionShape, -+ ElementA, SmemLayoutA, -+ ElementB, SmemLayoutB, -+ ElementC, LayoutC, -+ kTransformA, kTransformB, -+ Operator>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex floating-point -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: arch::OpMultiplyAddComplex -+/// Math Instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ 
ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, GemmShape<16, 8, 8>, -+ complex, layout::RowMajor, -+ complex, layout::ColumnMajor, -+ complex, LayoutC_, -+ arch::OpClassTensorOp, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<16, 8, 8>; -+ using ElementA = complex; -+ using LayoutA = layout::RowMajor; -+ using ElementB = complex; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped -+ static int const kAccessSizeInBits = 64; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ WarpShape, InstructionShape, -+ ElementA, SmemLayoutA, -+ ElementB, SmemLayoutB, -+ ElementC, LayoutC, -+ kTransformA, kTransformB, -+ Operator>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex 
floating-point -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: arch::OpMultiplyAddComplex -+/// Math Instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, GemmShape<16, 8, 8>, -+ complex, layout::RowMajor, -+ complex, layout::RowMajor, -+ complex, LayoutC_, -+ arch::OpClassTensorOp, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<16, 8, 8>; -+ using ElementA = complex; -+ using LayoutA = layout::RowMajor; -+ using ElementB = complex; -+ using LayoutB = layout::RowMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped -+ static int const kAccessSizeInBits = 64; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpStripedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = 
transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ WarpShape, InstructionShape, -+ ElementA, SmemLayoutA, -+ ElementB, SmemLayoutB, -+ ElementC, LayoutC, -+ kTransformA, kTransformB, -+ Operator>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex SIMT operation -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ typename RealA, -+ typename RealB, -+ typename RealC, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, GemmShape<1, 1, 1>, -+ complex, layout::ColumnMajor, -+ complex, layout::ColumnMajor, -+ complex, LayoutC_, -+ arch::OpClassSimt, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = complex; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = complex; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of access -+ static int const kAccessSizeInBits = sizeof_bits::value; -+ -+ /// No vectorized accesses -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to 
write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ IteratorThreadMapA>; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator B -+ using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ SmemThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = 4; // TODO need to extract these from template data -+ static const int WarpNumThreadsN = 8; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy, /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ 1, /// 1 partition along K dimension -+ kTransformA, /// Transform for A -+ kTransformB /// Transform for B -+ >; /// Used for partial specialization -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape<0, 0>, -+ MatrixShape<0, Shape::kK / 32>, -+ WarpCount::kK>; -+}; -+ -+/// Partial specialization for complex SIMT operation -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ typename RealA, -+ typename RealB, -+ typename RealC, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ 
ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, GemmShape<1, 1, 1>, -+ complex, layout::ColumnMajor, -+ complex, layout::RowMajor, -+ complex, LayoutC_, -+ arch::OpClassSimt, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = complex; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = complex; -+ using LayoutB = layout::RowMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of access -+ static int const kAccessSizeInBits = sizeof_bits::value; -+ -+ /// No vectorized accesses -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ IteratorThreadMapA>; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = 4; // TODO need to extract these from template data -+ static const int WarpNumThreadsN = 8; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 
2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy, /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ 1, /// 1 partition along K dimension -+ kTransformA, /// Transform for A -+ kTransformB /// Transform for B -+ >; /// Used for partial specialization -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, // or Shape::kK / 32 -+ WarpCount::kK>; -+}; -+ -+/// Partial specialization for complex SIMT operation -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ typename RealA, -+ typename RealB, -+ typename RealC, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, GemmShape<1, 1, 1>, -+ complex, layout::RowMajor, -+ complex, layout::ColumnMajor, -+ complex, LayoutC_, -+ arch::OpClassSimt, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = complex; -+ using LayoutA = layout::RowMajor; -+ using ElementB = complex; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility 
requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of access -+ static int const kAccessSizeInBits = sizeof_bits::value; -+ -+ /// No vectorized accesses -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapA = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ SmemThreadMapA>; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator B -+ using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ SmemThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = 4; // TODO need to extract these from template data -+ static const int WarpNumThreadsN = 8; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 
2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy, /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ 1, /// 1 partition along K dimension -+ kTransformA, /// Transform for A -+ kTransformB /// Transform for B -+ >; /// Used for partial specialization -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape, -+ MatrixShape<0, Shape::kK / 32>, -+ WarpCount::kK>; -+}; -+ -+/// Partial specialization for complex SIMT operation -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ typename RealA, -+ typename RealB, -+ typename RealC, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, GemmShape<1, 1, 1>, -+ complex, layout::RowMajor, -+ complex, layout::RowMajor, -+ complex, LayoutC_, -+ arch::OpClassSimt, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = complex; -+ using LayoutA = layout::RowMajor; -+ using ElementB = complex; -+ using LayoutB = layout::RowMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( 
-+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of access -+ static int const kAccessSizeInBits = sizeof_bits::value; -+ -+ /// No vectorized accesses -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapA = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ SmemThreadMapA>; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = 4; // TODO need to extract these from template data -+ static const int WarpNumThreadsN = 8; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 
2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy, /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ 1, /// 1 partition along K dimension -+ kTransformA, /// Transform for A -+ kTransformB /// Transform for B -+ >; /// Used for partial specialization -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape, -+ MatrixShape<0, 0>, // or Shape::kK / 32 -+ WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_trmm_complex.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_trmm_complex.h -new file mode 100644 -index 0000000..367869e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_trmm_complex.h -@@ -0,0 +1,556 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Template for a multistage GEMM kernel. Does not compute batching or support split-K. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h" -+#include "cutlass/gemm/threadblock/mma_blas3_multistage.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator = arch::OpMultiplyAddComplex, -+ /// Blas3 computation mode -+ BlasMode BlasMode_ = BlasMode::kTriangular, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. 
-+ bool AccumulatorsInRowMajor = false> -+struct DefaultMultistageTrmmComplex; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator> -+struct DefaultMultistageTrmmComplex { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, TransformA, TransformB, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, -+ kSideMode, kFillMode, kDiagType, -+ AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, -+ kSideMode, FillMode::kFull, DiagType::kInvalid, -+ AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output and right-side mode -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for 
internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator> -+struct DefaultMultistageTrmmComplex { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, TransformA, TransformB, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, -+ SideMode::kRight, FillMode::kFull, DiagType::kInvalid, -+ AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, -+ SideMode::kRight, kFillMode, kDiagType, -+ AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output with unit diagonal -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add 
operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator> -+struct DefaultMultistageTrmmComplex { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, TransformA, TransformB, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, -+ kSideMode, kFillMode, DiagType::kUnit, -+ AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, -+ kSideMode, FillMode::kFull, DiagType::kInvalid, -+ AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaBlas3Multistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output and right-side mode, unit diagonal -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator> -+struct DefaultMultistageTrmmComplex { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, TransformA, TransformB, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ 
cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, -+ SideMode::kRight, FillMode::kFull, DiagType::kInvalid, -+ AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, -+ SideMode::kRight, kFillMode, DiagType::kUnit, -+ AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaBlas3Multistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output (for TRMM where diagonal imag part is ignored - used by HEMM) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator> -+struct DefaultMultistageTrmmComplex { -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, TransformA, TransformB, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ // PredicatedTileAccessIteratorTriangularMatrix only tracks diagonal elements, -+ // when DiagType is kUnit -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, -+ kSideMode, kFillMode, DiagType::kUnit, -+ AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, -+ kSideMode, FillMode::kFull, DiagType::kInvalid, -+ AccessTypeB>; -+ 
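The SIMT specializations earlier in this diff repeat the same lane-tile sizing arithmetic (WarpNumThreadsM/N fixed at 4/8, ThreadTileM/N derived from the warp shape, LaneLayout chosen by a ternary, and LaneM/LaneN capped by what fits in a 128-bit access). The standalone C++ sketch below restates that arithmetic with plain integers so the resulting shapes can be inspected; the 64x64 warp tile and 64-bit element size used in main() are illustrative assumptions, and none of this is CUTLASS code.

#include <algorithm>
#include <cstdio>

// Sketch of the lane-tile sizing arithmetic repeated by the SIMT
// specializations above (WarpNumThreadsM/N, ThreadTileM/N, LaneLayout,
// LaneM/LaneN). Plain ints stand in for the CUTLASS template parameters.
struct LaneTileSizing {
  int ThreadTileM, ThreadTileN;  // per-thread tile inside the warp tile
  int LaneLayout;                // 1 or 2, mirroring the ternary in the diff
  int LaneM, LaneN;              // per-lane MMA shape, capped by 128-bit loads
};

constexpr LaneTileSizing ComputeLaneTile(int WarpShapeM, int WarpShapeN,
                                         int ElementBits,
                                         int WarpNumThreadsM = 4,
                                         int WarpNumThreadsN = 8) {
  const int ThreadTileM = WarpShapeM / WarpNumThreadsM;
  const int ThreadTileN = WarpShapeN / WarpNumThreadsN;
  const int LaneLayout = (ThreadTileM > 4 && ThreadTileN > 4) ? 2 : 1;
  const int numElements = 128 / ElementBits;  // elements per 128-bit access
  return {ThreadTileM, ThreadTileN, LaneLayout,
          std::min(numElements, ThreadTileM),
          std::min(numElements, ThreadTileN)};
}

int main() {
  // Illustrative assumption: 64x64 warp tile of 64-bit complex elements.
  constexpr LaneTileSizing t = ComputeLaneTile(64, 64, 64);
  std::printf("ThreadTile %dx%d, LaneLayout %d, LaneMmaShape %dx%dx1\n",
              t.ThreadTileM, t.ThreadTileN, t.LaneLayout, t.LaneM, t.LaneN);
  return 0;
}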
-+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaBlas3Multistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill, -+ BlasMode::kHermitian>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output and right-side mode (for TRMM where diagonal imag part is ignored - used by HEMM) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator> -+struct DefaultMultistageTrmmComplex { -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, TransformA, TransformB, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, -+ SideMode::kRight, FillMode::kFull, DiagType::kInvalid, -+ AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ // PredicatedTileAccessIteratorTriangularMatrix only tracks diagonal elements, -+ // when DiagType is kUnit -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, -+ SideMode::kRight, kFillMode, DiagType::kUnit, -+ AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaBlas3Multistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill, -+ BlasMode::kHermitian>; -+}; -+ 
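The DefaultMultistageTrmmComplex specializations above differ mainly in where the triangular attributes land: for left-side TRMM the (FillMode, DiagType) pair attaches to operand A while B is read as a full matrix with an invalid diag tag, for right-side TRMM the roles swap, and the HEMM-oriented variants force DiagType::kUnit on the triangular operand (the comments note the iterator then tracks the diagonal so its imaginary part can be handled separately). The sketch below restates that dispatch with stand-in enums mirroring cutlass::SideMode, cutlass::FillMode and cutlass::DiagType; it illustrates the selection pattern only and is not CUTLASS code.

#include <cstdio>

// Stand-in enums; names mirror the CUTLASS tags used in the diff above.
enum class SideMode { kLeft, kRight };
enum class FillMode { kFull, kLower, kUpper };
enum class DiagType { kInvalid, kNonUnit, kUnit };

struct OperandAttrs { FillMode fill; DiagType diag; };
struct TrmmOperandAttrs { OperandAttrs a; OperandAttrs b; };

// Left side: A carries the triangular attributes, B is dense.
// Right side: B carries the triangular attributes, A is dense.
constexpr TrmmOperandAttrs SelectOperandAttrs(SideMode side, FillMode fill,
                                              DiagType diag) {
  return side == SideMode::kLeft
             ? TrmmOperandAttrs{{fill, diag},
                                {FillMode::kFull, DiagType::kInvalid}}
             : TrmmOperandAttrs{{FillMode::kFull, DiagType::kInvalid},
                                {fill, diag}};
}

int main() {
  // HEMM-style variant: the triangular operand is forced to DiagType::kUnit,
  // mirroring the specializations above that feed MmaBlas3Multistage with
  // BlasMode::kHermitian.
  constexpr TrmmOperandAttrs left =
      SelectOperandAttrs(SideMode::kLeft, FillMode::kLower, DiagType::kUnit);
  std::printf("A: fill=%d diag=%d, B: fill=%d diag=%d\n",
              static_cast<int>(left.a.fill), static_cast<int>(left.a.diag),
              static_cast<int>(left.b.fill), static_cast<int>(left.b.diag));
  return 0;
}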
-+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_sparse_mma.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_sparse_mma.h -new file mode 100644 -index 0000000..5faa76b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_sparse_mma.h -@@ -0,0 +1,196 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h" -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/gemm/threadblock/default_mma_core_wmma.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false -+ > -+struct DefaultSparseMma; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output (OperatorClass TensorOp) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator -+ > -+struct DefaultSparseMma { -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 128) -+ ? 
cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ static int const kSparse = MmaCore::kSparse; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ // Define iterators over tiles from the E operand -+ using ElementE = typename MmaCore::ElementE; -+ using LayoutE = typename MmaCore::GmemLayoutE; -+ using ThreadMapE = typename MmaCore::IteratorThreadMapE; -+ using AccessTypeE = -+ cutlass::Array::value>; -+ using IteratorE = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementE, LayoutE, 1, ThreadMapE, AccessTypeE>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::SparseMmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ IteratorE, typename MmaCore::SmemIteratorE, MmaCore::kCacheOpE, -+ typename MmaCore::MmaPolicy, Stages>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_trmm.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_trmm.h -new file mode 100644 -index 0000000..8c13d17 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_trmm.h -@@ -0,0 +1,445 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+// -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h" -+#include "cutlass/gemm/threadblock/mma_blas3_multistage.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/gemm/threadblock/default_mma_core_wmma.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation perfomed by 
GEMM -+ typename Operator, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false -+ > -+struct DefaultTrmm; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output (OperatorClass TensorOp) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator -+ > -+struct DefaultTrmm { -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? 
cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, kSideMode, kFillMode, kDiagType, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, kSideMode, FillMode::kFull, DiagType::kInvalid, AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output, right side mode (OperatorClass TensorOp) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator -+ > -+struct DefaultTrmm { -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? 
cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, SideMode::kRight, FillMode::kFull, DiagType::kInvalid, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, SideMode::kRight, kFillMode, kDiagType, AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output with unit diagonal (OperatorClass TensorOp) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator -+ > -+struct DefaultTrmm { -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? 
cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, kSideMode, kFillMode, DiagType::kUnit, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, kSideMode, FillMode::kFull, DiagType::kInvalid, AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaBlas3Multistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output, right side mode, unit diagonal (OperatorClass TensorOp) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator -+ > -+struct DefaultTrmm { -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? 
cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, SideMode::kRight, FillMode::kFull, DiagType::kInvalid, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, SideMode::kRight, kFillMode, DiagType::kUnit, AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaBlas3Multistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/ell_mma_multistage.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/ell_mma_multistage.h -new file mode 100644 -index 0000000..3f73b9e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/ell_mma_multistage.h -@@ -0,0 +1,642 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a multistage threadblock-scoped Blocked-Ell MMA. -+*/ -+ -+#pragma once -+ -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/threadblock/mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class EllMmaMultistage : -+ public MmaBase { -+public: -+ ///< Base class -+ using Base = MmaBase; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; -+ -+ 
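// Illustrative sketch only -- this is not the real cutlass::transform::threadblock::ell::Iterator
// API, and the field and function names below are hypothetical. The mainloop in this class only
// asks the ELL iterator for a block size and a per-K offset, so a simplified host-side model of
// that lookup is enough to follow the indexing.
#include <cstdint>

struct SimpleEllIndex {
  const int32_t *ell_cols;   // block-column index per stored slot, -1 marks an empty block
  int cols_per_tile_row;     // number of stored blocks in each row of tiles
  int blocksize;             // edge length of one ELL block, in elements

  int get_blocksize() const { return blocksize; }

  // Map a logical K coordinate to the K coordinate of the block actually stored there,
  // or return -1 so the caller can skip the load and zero-fill instead.
  int64_t gathered_k(int tile_row, int k) const {
    int slot = k / blocksize;                                   // stored slot covering this K
    int col  = ell_cols[tile_row * cols_per_tile_row + slot];
    if (col < 0) return -1;                                     // empty block contributes zeros
    return int64_t(col) * blocksize + (k % blocksize);
  }
};
// The real iterator returns an element offset relative to the pointer the tile iterator has
// already formed; below it is applied as gmem_ptr += ell_offset * sizeof(Element) / kSrcBytes,
// and a negative offset simply turns the cp.async into a zero fill (cp_async_zfill).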
// -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ -+ static_assert(Base::kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA = -+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ EllMmaMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, 
warp_idx_n}); -+ } -+ -+ template -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, EllIterator &ell_iter, -+ int group_start_A = 0, int group_start_B = 0) { -+ iterator_A.set_iteration_index(group_start_A * -+ IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_A.get(); -+ bool is_valid = iterator_A.valid(); -+ -+ if (!is_A_sparse){ -+ if (is_offset_constant){ -+ auto ell_offset = ell_iter.get_offset_fast(); -+ is_valid = is_valid && (ell_offset >= 0); -+ gmem_ptr += ell_offset * sizeof(IteratorA::Element) / kSrcBytes; -+ } else { -+ int k_offset = iterator_A.get_k(); -+ auto ell_offset = ell_iter.get_offset(k_offset); -+ is_valid = is_valid && (ell_offset >= 0); -+ gmem_ptr += (ell_offset * sizeof(IteratorA::Element)) / kSrcBytes; -+ } -+ } -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, is_valid); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ iterator_B.set_iteration_index(group_start_B * -+ IteratorB::kAccessesPerVector); -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B.get(); -+ bool is_valid = iterator_B.valid(); -+ -+ if (is_A_sparse){ -+ if (is_offset_constant){ -+ auto ell_offset = ell_iter.get_offset_fast(); -+ is_valid = is_valid && (ell_offset >= 0); -+ gmem_ptr += ell_offset * sizeof(IteratorB::Element) / kSrcBytes; -+ } else { -+ int k_offset = iterator_B.get_k(); -+ auto ell_offset = ell_iter.get_offset(k_offset); -+ is_valid = is_valid && (ell_offset >= 0); -+ gmem_ptr += ( ell_offset * sizeof(IteratorB::Element)) / kSrcBytes; -+ } -+ } -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, is_valid); -+ -+ ++iterator_B; -+ } -+ ++this->smem_iterator_B_; -+ } -+ } -+ } -+ -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ template -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ ///< initial value of accumulator -+ FragmentC const &src_accum, -+ EllIterator &ell_iterator -+ ) { -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations) { -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ 
iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ auto gmem_ptr = iterator_A.get(); -+ bool is_valid = iterator_A.valid(); -+ -+ if (!is_A_sparse){ -+ if (is_offset_constant){ -+ auto ell_offset = ell_iterator.get_offset_fast(); -+ is_valid = is_valid && (ell_offset >= 0); -+ gmem_ptr += ell_offset * sizeof(IteratorA::Element) / kSrcBytes; -+ } else { -+ int k_offset = iterator_A.get_k(); -+ auto ell_offset = ell_iterator.get_offset(k_offset); -+ is_valid = is_valid && (ell_offset >= 0); -+ gmem_ptr += (ell_offset * sizeof(IteratorA::Element)) / kSrcBytes; -+ } -+ } -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, is_valid); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ iterator_B.set_iteration_index(0); -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ auto gmem_ptr = iterator_B.get(); -+ bool is_valid = iterator_B.valid(); -+ -+ if (is_A_sparse){ -+ if (is_offset_constant){ -+ auto ell_offset = ell_iterator.get_offset_fast(); -+ is_valid = is_valid && (ell_offset >= 0); -+ gmem_ptr += ell_offset * sizeof(IteratorB::Element) / kSrcBytes; -+ } else { -+ int k_offset = iterator_B.get_k(); -+ auto ell_offset = ell_iterator.get_offset(k_offset); -+ is_valid = is_valid && (ell_offset >= 0); -+ gmem_ptr += ( ell_offset * sizeof(IteratorB::Element)) / kSrcBytes; -+ } -+ } -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, is_valid); -+ -+ ++iterator_B; -+ } -+ -+ ++this->smem_iterator_B_; -+ } -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ ++ell_iterator; -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ // Waits until kStages-2 stages have committed. 
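// The prologue above issued Base::kStages - 1 complete tile copies, each closed with a
// cp_async_fence(). The wait below (written in upstream CUTLASS with the template argument
// Base::kStages - 2) blocks until at most kStages - 2 of those copy groups are still in
// flight, i.e. until the first stage has fully landed in shared memory, while the newer
// stages keep streaming in the background so the warp-level MMAs can start immediately.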
-+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B[2]; -+ -+ Operator warp_mma; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ if (is_A_sparse){ -+ iterator_A.ell_add_mask(ell_iterator.get_blocksize()); -+ } -+ else { -+ iterator_B.ell_add_mask(ell_iterator.get_blocksize()); -+ } -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B[0]); -+ -+ // tf32x3 kernels use staging accumulation. warp_mma uses a temporary -+ // accumulator and this temporary accumulator is added to the final -+ // accumulator once in every mainloop iteration. -+ plus plus_accum; -+ -+ FragmentC tmp_accum; -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ -+ tmp_accum.clear(); -+ } -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. 
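// Note the prefetch pattern of the inner loop: each warp_mma_k step first loads the fragments
// for iteration warp_mma_k + 1 into the other half of the two-entry warp_loaded_frag_A/B
// ping-pong ((warp_mma_k + 1) % 2) and only then issues math on the current half, so the
// shared-memory load latency is hidden behind the tensor-core MMAs.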
-+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k > 0) -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B[warp_mma_k % 2]); -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ -+ warp_mma( -+ tmp_accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ tmp_accum -+ ); -+ -+ if (warp_mma_k == 0) { -+ accum = plus_accum(accum, tmp_accum); -+ tmp_accum.clear(); -+ } -+ } else { -+ warp_mma( -+ accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ accum -+ ); -+ } -+ -+ // Issue global->shared copies for the this stage -+ if (warp_mma_k < Base::kWarpGemmIterations - 1) { -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance( -+ iterator_A, iterator_B, ell_iterator, group_start_iteration_A, -+ group_start_iteration_B); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ int group_start_iteration_A, group_start_iteration_B; -+ group_start_iteration_A = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance( -+ iterator_A, iterator_B, ell_iterator, group_start_iteration_A, -+ group_start_iteration_B); -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. 
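// After the wait and the __syncthreads() that follow, the next read stage is visible to every
// warp, so the code below advances the global and shared-memory iterators by one stage and,
// whenever smem_write_stage_idx or smem_read_stage_idx reaches Base::kStages - 1, rewinds the
// corresponding iterators with a negative tile offset: shared memory is used as a circular
// buffer of kStages tile-sized slots rather than being re-filled from scratch.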
-+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ ++ell_iterator; -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ } -+ -+ } -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ accum = plus_accum(accum, tmp_accum); -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/ell_mma_pipelined.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/ell_mma_pipelined.h -new file mode 100644 -index 0000000..10ff6df ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/ell_mma_pipelined.h -@@ -0,0 +1,376 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped Blocked-Ell MMA. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Transformation applied to A operand -+ typename TransformA_ = NumericArrayConverter< -+ typename SmemIteratorA_::Element, -+ typename IteratorA_::Element, -+ IteratorA_::Fragment::kElements>, -+ /// -+ /// Transformation applied to B operand -+ typename TransformB_ = NumericArrayConverter< -+ typename SmemIteratorB_::Element, -+ typename IteratorB_::Element, -+ IteratorB_::Fragment::kElements>, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class EllMmaPipelined : public MmaBase { -+public: -+ -+ ///< Base class -+ using Base = MmaBase; -+ -+ using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory -+ using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory -+ using ElementC = ElementC_; ///< Data type of accumulator matrix -+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix -+ using Policy = Policy_; ///< Policy describing tuning details -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ using TransformA = TransformA_; -+ using TransformB = 
TransformB_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of operand A loaded from global memory -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Obtain the arch tag from the warp-level operator -+ using ArchTag = typename Policy::Operator::ArchTag; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ // staticaly assert kStages for EllMmaPipelined is two (Double-buffered pipeline) -+ static_assert((Base::kStages==2), "EllMmaPipelined requires kStages set to value 2"); -+ -+private: -+ -+ using WarpFragmentA = typename Operator::FragmentA; -+ using WarpFragmentB = typename Operator::FragmentB; -+ -+protected: -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+ using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; -+ -+public: -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ EllMmaPipelined( -+ typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ int thread_idx, ///< ID within the threadblock -+ int warp_idx, ///< ID of warp -+ int lane_idx ///< ID of each thread within a warp -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ template -+ CUTLASS_DEVICE -+ void operator()( -+ int gemm_k_iterations, ///< number of iterations of the mainloop -+ FragmentC &accum, ///< destination accumulator tile -+ IteratorA iterator_A, ///< iterator over A operand in global memory -+ IteratorB iterator_B, ///< iterator over B operand in global memory -+ FragmentC const &src_accum, ///< source accumulator tile -+ EllIterator &ell_iterator, -+ TransformA transform_A = TransformA(), ///< transformation applied to A fragment -+ TransformB transform_B = TransformB()) { ///< transformation applied to B fragment -+ -+ // -+ // Prologue -+ // -+ -+ // Perform accumulation in the 'd' output operand -+ 
accum = src_accum; -+ -+ FragmentA tb_frag_A; -+ FragmentB tb_frag_B; -+ -+ tb_frag_A.clear(); -+ tb_frag_B.clear(); -+ -+ // load sparse matrix -+ if (is_A_sparse){ -+ iterator_A.load(tb_frag_A); -+ } else { -+ iterator_B.load(tb_frag_B); -+ } -+ -+ // load dense matrix -+ if (is_offset_constant){ -+ if (is_A_sparse){ -+ iterator_B.load_with_ell_index_fast(tb_frag_B, ell_iterator); -+ } else { -+ iterator_A.load_with_ell_index_fast(tb_frag_A, ell_iterator); -+ } -+ } else { -+ if (is_A_sparse){ -+ iterator_B.load_with_ell_index(tb_frag_B, ell_iterator); -+ } else { -+ iterator_A.load_with_ell_index(tb_frag_A, ell_iterator); -+ } -+ } -+ -+ ++iterator_A; -+ ++iterator_B; -+ ++ell_iterator; -+ -+ this->smem_iterator_A_.store(transform_A(tb_frag_A)); -+ this->smem_iterator_B_.store(transform_B(tb_frag_B)); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA warp_frag_A[2]; -+ WarpFragmentB warp_frag_B[2]; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ Operator warp_mma; -+ -+ int smem_write_stage_idx = 1; -+ -+ // Avoid reading out of bounds -+ iterator_A.clear_mask(gemm_k_iterations <= 1); -+ iterator_B.clear_mask(gemm_k_iterations <= 1); -+ -+ if (is_A_sparse){ -+ iterator_A.ell_add_mask(ell_iterator.get_blocksize()); -+ } -+ else { -+ iterator_B.ell_add_mask(ell_iterator.get_blocksize()); -+ } -+ -+ // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing -+ // shared memory loads (which have the tighest latency requirement). -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > 0; --gemm_k_iterations) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. 
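// Double-buffering bookkeeping for this two-stage pipelined variant: warp_frag_A/B form a
// two-entry ping-pong indexed by (warp_mma_k + 1) % 2, and in the block below
// smem_write_stage_idx ^= 1 alternates between the two shared-memory tile slots; on one parity
// the shared-memory write iterators are rewound with a negative tile offset, on the other the
// warp-level read iterators are, so both sides keep cycling through the same two buffers.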
-+ -+ if (warp_mma_k == Base::kWarpGemmIterations - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_A_.store(transform_A(tb_frag_A)); -+ -+ this->smem_iterator_B_.store(transform_B(tb_frag_B)); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ } -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k == 0) { -+ // load sparse matrix -+ if (is_A_sparse){ -+ iterator_A.load(tb_frag_A); -+ } else { -+ iterator_B.load(tb_frag_B); -+ } -+ -+ // load dense matrix -+ if (is_offset_constant){ -+ if (is_A_sparse){ -+ iterator_B.load_with_ell_index_fast(tb_frag_B, ell_iterator); -+ } else { -+ iterator_A.load_with_ell_index_fast(tb_frag_A, ell_iterator); -+ } -+ } else { -+ if (is_A_sparse){ -+ iterator_B.load_with_ell_index(tb_frag_B, ell_iterator); -+ } else { -+ iterator_A.load_with_ell_index(tb_frag_A, ell_iterator); -+ } -+ } -+ -+ ++iterator_A; -+ ++iterator_B; -+ ++ell_iterator; -+ -+ // Avoid reading out of bounds if this was the last loop iteration -+ iterator_A.clear_mask(gemm_k_iterations <= 2); -+ iterator_B.clear_mask(gemm_k_iterations <= 2); -+ } -+ -+ warp_mma(accum, warp_frag_A[warp_mma_k % 2], -+ warp_frag_B[warp_mma_k % 2], accum); -+ } -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/gemv.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/gemv.h -new file mode 100755 -index 0000000..f0a4b1d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/gemv.h -@@ -0,0 +1,147 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Template for a threadblock-scoped GEMV kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix-vector product using SIMT math instructions. -+template < -+ class Core_ //< GemvCore -+> -+class Gemv { -+public: -+ using Shape = typename Core_::Shape; -+ -+ /// The MMA operator that computes GEMV -+ using Operator = typename Core_::Operator; -+ -+ /// Iterates over A in global memory -+ using IteratorA = typename Core_::IteratorA; -+ -+ /// Iterates over B in global memory -+ using IteratorB = typename Core_::IteratorB; -+ -+ /// Fragment of operand C loaded from global memory -+ using IteratorC = typename Core_::IteratorC; -+ -+ /// Fragment of operand A loaded from global memory -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Fragment of operand accumulator loaded/stored to global memory -+ using FragmentC = typename Operator::FragmentC; -+ -+ /// Shape of the per-thread GEMV operation -+ using ThreadShape = typename Core_::ThreadShape; -+ -+public: -+ CUTLASS_DEVICE -+ Gemv() { } -+ -+ CUTLASS_DEVICE -+ void operator()( -+ GemmCoord const &problem_size, ///< problem size of batched GEMV -+ FragmentC &accum, ///< destination accumulator tile -+ IteratorA iterator_A, ///< iterator over A operand in global memory -+ IteratorB iterator_B, ///< iterator over B operand in global memory -+ FragmentC const &src_accum) { ///< source accumualtor tile -+ -+ // -+ // Prologue -+ // -+ -+ FragmentA frag_A; -+ FragmentB frag_B; -+ frag_A.clear(); -+ frag_B.clear(); -+ -+ iterator_A.load(frag_A); -+ iterator_B.load(frag_B); -+ ++iterator_A; -+ ++iterator_B; -+ -+ // -+ // Mainloop -+ // -+ Operator thread_mma; -+ int gemm_k = problem_size.k(); -+ -+ if (gemm_k < Shape::kK) -+ { -+ iterator_A.clear_mask(); -+ iterator_B.clear_mask(); -+ } -+ -+ // iterate over K to accumulate result -+ CUTLASS_GEMM_LOOP -+ for 
(; gemm_k > 0; gemm_k -= Shape::kK) { -+ thread_mma(accum, frag_A, frag_B, accum); -+ -+ iterator_A.load(frag_A); -+ iterator_B.load(frag_B); -+ ++iterator_A; -+ ++iterator_B; -+ -+ if (gemm_k < Shape::kK) -+ { -+ iterator_A.clear_mask(); -+ iterator_B.clear_mask(); -+ } -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/index_remat.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/index_remat.h -new file mode 100644 -index 0000000..1e24568 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/index_remat.h -@@ -0,0 +1,107 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Helpers for rematerializing indices/dimensions in the thread hierarchy from special registers -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper to rematerialize block Idx. Reduces register liveness. -+CUTLASS_DEVICE -+int RematerializeThreadIdxX() { -+ return threadIdx.x; -+} -+ -+/// Helper to rematerialize block Idx. Reduces register liveness. -+CUTLASS_DEVICE -+int RematerializeThreadIdxY() { -+ return threadIdx.y; -+} -+ -+/// Helper to rematerialize block Idx. Reduces register liveness. 
-+CUTLASS_DEVICE -+int RematerializeThreadIdxZ() { -+ return threadIdx.z; -+} -+ -+/// Helper to rematerialize block Idx. Reduces register liveness. -+CUTLASS_DEVICE -+int RematerializeBlockIdxX() { -+ return blockIdx.x; -+} -+ -+/// Helper to rematerialize block Idx. Reduces register liveness. -+CUTLASS_DEVICE -+int RematerializeBlockIdxY() { -+ return blockIdx.y; -+} -+ -+/// Helper to rematerialize block Idx. Reduces register liveness. -+CUTLASS_DEVICE -+int RematerializeBlockIdxZ() { -+ return blockIdx.z; -+} -+ -+/// Helper to rematerialize block Dim. Reduces register liveness. -+CUTLASS_DEVICE -+int RematerializeBlockDimX() { -+ return blockDim.x; -+} -+ -+/// Helper to rematerialize block Dim. Reduces register liveness. -+CUTLASS_DEVICE -+int RematerializeBlockDimY() { -+ return blockDim.y; -+} -+ -+/// Helper to rematerialize block Dim. Reduces register liveness. -+CUTLASS_DEVICE -+int RematerializeBlockDimZ() { -+ return blockDim.z; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_base.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_base.h -new file mode 100644 -index 0000000..524fdf9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_base.h -@@ -0,0 +1,236 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/tensor_ref.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Policy object describing MmaTensorOp -+template < -+ /// Warp-level GEMM operator (concept: gemm::warp::Mma) -+ typename Operator_, -+ /// Padding used for A operand in shared memory (concept: MatrixShape) -+ typename SmemPaddingA_, -+ /// Padding used for B operand in shared memory (concept: MatrixShape) -+ typename SmemPaddingB_, -+ /// Number of partitions of K dimension of GEMM -+ int PartitionsK = 1> -+struct MmaPolicy { -+ /// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt) -+ using Operator = Operator_; -+ -+ /// Padding used for A operand in shared memory -+ using SmemPaddingA = SmemPaddingA_; -+ -+ /// Padding used for B operand in shared memory -+ using SmemPaddingB = SmemPaddingB_; -+ -+ /// Number of partitions of K dimension -+ static int const kPartitionsK = PartitionsK; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. 
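// A concrete (purely illustrative) instantiation helps read the aliases that follow: with a
// threadblock tile Shape = GemmShape<128, 128, 32>, a warp tile WarpGemm = GemmShape<64, 64, 32>
// and an instruction K of 16 (Operator::Policy::MmaShape::kK == 16), the upstream definition
// WarpCount = Shape / WarpGemm (taken dimension-wise) gives 2 x 2 x 1 = 4 warps per threadblock,
// and kWarpGemmIterations = WarpGemm::kK / MmaShape::kK = 32 / 16 = 2, which satisfies both
// static_asserts below (more than one iteration, and an even count).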
-+ using WarpGemm = typename Policy::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount = GemmShape; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations = -+ (WarpGemm::kK / Operator::Policy::MmaShape::kK); -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Tensor reference to the A operand -+ using TensorRefA = TensorRef; -+ -+ /// Tensor reference to the B operand -+ using TensorRefB = TensorRef; -+ -+ static_assert(kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ static_assert((kWarpGemmIterations % 2) == 0, -+ "Inner loop iteration must be an even number."); -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ class SharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ -+ /// Shape of the A matrix operand in shared memory -+ using ShapeA = MatrixShape; -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = -+ MatrixShape; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for A operand -+ AlignedBuffer operand_A; -+ -+ /// Buffer for B operand -+ AlignedBuffer operand_B; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the A matrix -+ CUTLASS_DEVICE -+ static typename Operator::LayoutA LayoutA() { -+ return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); -+ } -+ -+ /// Returns a layout object for the B matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutB LayoutB() { -+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the A operand -+ CUTLASS_HOST_DEVICE -+ TensorRefA operand_A_ref() { -+ return TensorRefA{operand_A.data(), LayoutA()}; -+ } -+ -+ /// Returns a TensorRef to the B operand -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B_ref() { -+ return TensorRefB{operand_B.data(), LayoutB()}; -+ } -+ }; -+ -+ protected: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A operand from shared memory -+ typename Operator::IteratorA warp_tile_iterator_A_; -+ -+ /// Iterator to load a warp-scoped tile of B operand from shared memory -+ typename Operator::IteratorB warp_tile_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaBase( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), -+ warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) { -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_blas3_multistage.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_blas3_multistage.h -new file mode 100644 -index 0000000..fa05aac ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_blas3_multistage.h -@@ -0,0 +1,702 @@ -+/*************************************************************************************************** -+ * Copyright 
(c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+ Used by BLAS3 kernels that need to treat diagonal elements of a input iterator as a special case. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/threadblock/mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. 
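[Editor's note] The special case this file adds over the generic multistage mainloop shows up in the copy code further down: each global-to-shared access checks getOnDiag() and, for diagonal elements, either writes a literal one (triangular inputs with an implicit unit diagonal) or keeps only the real part (Hermitian inputs). A scalar host-side sketch of that per-element decision; the names and types here are stand-ins, not CUTLASS APIs, and the kernel does this on vectorized cp.async writes rather than single elements:

// Scalar sketch of the diagonal handling done per cp.async access below.
#include <complex>
#include <cstdio>

enum class BlasModeLike { kTriangular, kHermitian };

std::complex<float> load_with_diag_fixup(std::complex<float> gmem_value,
                                         bool on_diag, bool valid,
                                         BlasModeLike mode) {
  if (!valid) return {0.0f, 0.0f};            // out-of-bounds access -> zero fill
  if (on_diag) {
    if (mode == BlasModeLike::kHermitian)
      return {gmem_value.real(), 0.0f};       // keep real part, zero the imaginary part
    return {1.0f, 0.0f};                      // unit diagonal: write one directly
  }
  return gmem_value;                          // ordinary off-diagonal element
}

int main() {
  auto v = load_with_diag_fixup({3.0f, 2.0f}, true, true, BlasModeLike::kHermitian);
  std::printf("(%g, %g)\n", v.real(), v.imag());
}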
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kZfill, -+ /// Blas3 computation mode -+ BlasMode BlasMode_ = BlasMode::kTriangular, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaBlas3Multistage : -+ public MmaBase { -+public: -+ ///< Base class -+ using Base = MmaBase; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ ///< Blas Mode -+ static BlasMode const kBlasMode = BlasMode_; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ /// Internal structure exposed for introspection. 
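[Editor's note] The Detail struct that follows spreads the per-stage cp.async traffic evenly across the warp-level GEMM iterations via a ceiling division, and copy_tiles_and_advance() guards each group against running past the real copy count. A sketch of that bookkeeping with assumed counts (the values 9 and 4 are illustrative only):

// Ceiling division behind kAccessesPerGroupA/B in the Detail struct below.
#include <cstdio>

int main() {
  int copies_per_stage = 9;   // AsyncCopyIterationsPerStage (assumed value)
  int warp_gemm_iters  = 4;   // Base::kWarpGemmIterations   (assumed value)

  int accesses_per_group = (copies_per_stage + warp_gemm_iters - 1) / warp_gemm_iters;  // = 3

  // One group of copies is issued per warp-level MMA step; the guard below
  // mirrors "group_start_A + j < Detail::AsyncCopyIterationsPerStageA".
  for (int step = 0; step < warp_gemm_iters; ++step) {
    int group_start = step * accesses_per_group;
    for (int j = 0; j < accesses_per_group; ++j) {
      if (group_start + j < copies_per_stage) {
        std::printf("MMA step %d issues copy %d\n", step, group_start + j);
      }
    }
  }
}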
-+ struct Detail { -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA = -+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaBlas3Multistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, -+ int group_start_A = 0, int group_start_B = 0) { -+ iterator_A.set_iteration_index(group_start_A * -+ IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ 
IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_A.get(); -+ bool isvalid = iterator_A.valid(); -+ -+ if (isvalid && iterator_A.getOnDiag()) { -+ // Elements that are on diagonal -+ if (kBlasMode == BlasMode::kHermitian && cutlass::is_complex::value) { -+ /* Copy real part from gmem, write zero for imag part in smem */ -+ /* The following logic to determine kSizeRealBytes is so that compiler doesn't complain when -+ * compiling for not complex datatype and using half the size for cp_async_zfill */ -+ int const kSizeRealBytes = (platform::is_same>::value) ? 8 : 4; -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, true); -+ cutlass::arch::cp_async_diag( -+ reinterpret_cast (dst_ptr + v) + kSizeRealBytes); -+ } else { -+ /* Write one (1) directly to smem*/ -+ cutlass::arch::cp_async_diag(dst_ptr + v); -+ } -+ } else { -+ // Elements that are not of diagonal -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, isvalid); -+ } -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ iterator_B.set_iteration_index(group_start_B * -+ IteratorB::kAccessesPerVector); -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B.get(); -+ bool isvalid = iterator_B.valid(); -+ -+ if (isvalid && iterator_B.getOnDiag()) { -+ // Elements that are on diagonal -+ if (kBlasMode == BlasMode::kHermitian && cutlass::is_complex::value) { -+ /* Copy real part from gmem, write zero for imag part in smem */ -+ int const kSizeRealBytes = (platform::is_same>::value) ? 
8 : 4; -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, true); -+ cutlass::arch::cp_async_diag( -+ reinterpret_cast (dst_ptr + v) + kSizeRealBytes); -+ } else { -+ /* Write one (1) directly to smem*/ -+ cutlass::arch::cp_async_diag(dst_ptr + v); -+ } -+ } else { -+ // Elements that are not of diagonal -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, isvalid); -+ } -+ -+ ++iterator_B; -+ } -+ ++this->smem_iterator_B_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ ///< initial value of accumulator -+ FragmentC const &src_accum) { -+ -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations) { -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ auto gmem_ptr = iterator_A.get(); -+ bool isvalid = iterator_A.valid(); -+ -+ if (isvalid && iterator_A.getOnDiag()) { -+ // Elements that are on diagonal -+ if (kBlasMode == BlasMode::kHermitian && cutlass::is_complex::value) { -+ /* Copy real part from gmem, write zero for imag part in smem */ -+ int const kSizeRealBytes = (platform::is_same>::value) ? 8 : 4; -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, true); -+ cutlass::arch::cp_async_diag( -+ reinterpret_cast (dst_ptr + v) + kSizeRealBytes); -+ } else { -+ /* Write one (1) directly to smem*/ -+ cutlass::arch::cp_async_diag(dst_ptr + v); -+ } -+ } else { -+ // Elements that are not of diagonal -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, isvalid); -+ } -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ iterator_B.set_iteration_index(0); -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ auto gmem_ptr = iterator_B.get(); -+ bool isvalid = iterator_B.valid(); -+ -+ if (isvalid && iterator_B.getOnDiag()) { -+ // Elements that are on diagonal -+ if (kBlasMode == BlasMode::kHermitian && cutlass::is_complex::value) { -+ /* Copy real part from gmem, write zero for imag part in smem */ -+ int const kSizeRealBytes = (platform::is_same>::value) ? 
8 : 4; -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, true); -+ cutlass::arch::cp_async_diag( -+ reinterpret_cast (dst_ptr + v) + kSizeRealBytes); -+ } else { -+ /* Write one (1) directly to smem*/ -+ cutlass::arch::cp_async_diag(dst_ptr + v); -+ } -+ } else { -+ // Elements that are not of diagonal -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, isvalid); -+ } -+ -+ ++iterator_B; -+ } -+ -+ ++this->smem_iterator_B_; -+ } -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ // -+ // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels -+ // so that all accumulator elements outside the GEMM footprint are zero. -+ // -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); -+ -+ typename IteratorA::AccessType zero_A; -+ zero_A.clear(); -+ -+ last_smem_iterator_A.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ last_smem_iterator_A.get()); -+ -+ *dst_ptr = zero_A; -+ -+ ++last_smem_iterator_A; -+ } -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); -+ typename IteratorB::AccessType zero_B; -+ -+ zero_B.clear(); -+ last_smem_iterator_B.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ last_smem_iterator_B.get()); -+ -+ *dst_ptr = zero_B; -+ -+ ++last_smem_iterator_B; -+ } -+ } -+ -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B[2]; -+ -+ Operator warp_mma; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B[0]); -+ -+ // tf32x3 kernels use staging accumulation. warp_mma uses a temporary -+ // accumulator and this temporary accumulator is added to the final -+ // accumulator once in every mainloop iteration. 
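[Editor's note] The comment above describes the staged accumulation used by tf32x3-style kernels: the warp MMA accumulates into a freshly cleared temporary fragment, which is folded into the real accumulator once per mainloop iteration and drained once more after the loop. A scalar sketch of that pattern (a float stands in for the accumulator fragment; the iteration counts are assumed):

// Scalar sketch of staged accumulation (tf32x3-style), not the CUTLASS code.
#include <cstdio>

int main() {
  float accum = 0.0f;      // final accumulator (FragmentC in the kernel)
  float tmp_accum = 0.0f;  // temporary accumulator, cleared before use

  const int mainloop_iterations = 4;
  const int warp_gemm_iterations = 2;

  for (int it = 0; it < mainloop_iterations; ++it) {
    for (int k = 0; k < warp_gemm_iterations; ++k) {
      tmp_accum += 1.0f;   // stands in for warp_mma(tmp_accum, A, B, tmp_accum)
      if (k == 0) {        // fold into the final accumulator once per iteration
        accum += tmp_accum;
        tmp_accum = 0.0f;
      }
    }
  }
  accum += tmp_accum;      // drain whatever is left after the mainloop

  std::printf("accum = %g\n", accum);  // 8 products accumulated in total
}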
-+ plus plus_accum; -+ -+ FragmentC tmp_accum; -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ -+ tmp_accum.clear(); -+ } -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k > 0) -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B[warp_mma_k % 2]); -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ -+ warp_mma( -+ tmp_accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ tmp_accum -+ ); -+ -+ if (warp_mma_k == 0) { -+ accum = plus_accum(accum, tmp_accum); -+ tmp_accum.clear(); -+ } -+ } else { -+ warp_mma( -+ accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ accum -+ ); -+ } -+ -+ // Issue global->shared copies for the this stage -+ if (warp_mma_k < Base::kWarpGemmIterations - 1) { -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, -+ group_start_iteration_B); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ int group_start_iteration_A, group_start_iteration_B; -+ group_start_iteration_A = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, -+ group_start_iteration_B); -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. 
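[Editor's note] The mainloop above keeps two copies of each warp fragment and indexes them with "% 2": while fragment warp_mma_k % 2 feeds the math, fragment (warp_mma_k + 1) % 2 is being loaded from shared memory. A minimal sketch of that ping-pong indexing with placeholder values standing in for fragment loads and MMAs:

// Ping-pong (double-buffered) fragment indexing used in the mainloop above.
#include <cstdio>

int main() {
  const int kWarpGemmIterations = 4;
  int frag[2] = {0, 0};          // stands in for warp_loaded_frag_A/B[2]

  frag[0] = 100;                 // prologue: load the first fragment (k = 0)

  for (int warp_mma_k = 0; warp_mma_k < kWarpGemmIterations; ++warp_mma_k) {
    // Load the *next* fragment while the current one is consumed by the MMA.
    frag[(warp_mma_k + 1) % 2] = 100 + warp_mma_k + 1;

    int consumed = frag[warp_mma_k % 2];   // stands in for warp_mma(..., frag, ...)
    std::printf("step %d consumes fragment value %d\n", warp_mma_k, consumed);
  }
}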
-+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ } -+ -+ } -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ accum = plus_accum(accum, tmp_accum); -+ } -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ // commit and drain all pending and predicated cp.async pnz from the GEMM mainloop -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_layernorm_mainloop_fusion_multistage.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_layernorm_mainloop_fusion_multistage.h -new file mode 100644 -index 0000000..03055ee ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_layernorm_mainloop_fusion_multistage.h -@@ -0,0 +1,865 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+ -+ It loads two loop invariant vectors, mean and var, in the prologue and -+ stores them in the register file. In the mainloop, it loads two loop -+ variant vectors, gamma and beta, by using cp.async. We will call -+ elementwise operation to apply var, mean, gamma, beta between ldmatrix and -+ warp mma. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+#include "cutlass/gemm/warp/layernorm_scale_bias_transform.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. 
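[Editor's note] As the file comment above describes, this fused mainloop keeps the loop-invariant mean and variance vectors in registers, streams gamma and beta through shared memory, and applies the layernorm affine transform to each A fragment between the shared-memory load and the warp MMA. A scalar sketch of that elementwise step, assuming the textbook y = (x - mean) / sqrt(var + eps) * gamma + beta form; the kernel's LayernormScaleBiasTransform may store these vectors in a pre-combined or pre-inverted form, so treat this as illustration only:

// Scalar sketch of the layernorm scale/bias transform fused into the mainloop.
#include <cmath>
#include <cstdio>

float layernorm_scale_bias(float x, float mean, float var,
                           float gamma, float beta, float eps = 1e-5f) {
  return (x - mean) / std::sqrt(var + eps) * gamma + beta;
}

int main() {
  // One element of an A fragment, with illustrative statistics.
  float y = layernorm_scale_bias(/*x=*/2.0f, /*mean=*/1.0f, /*var=*/4.0f,
                                 /*gamma=*/0.5f, /*beta=*/0.25f);
  std::printf("normalized value = %g\n", y);  // ~ (1/2) * 0.5 + 0.25 = 0.5
}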
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Element type of scale and bias vectors -+ typename ElementScaleBias_, -+ /// Layout of scale and bias vectors -+ typename LayoutScaleBias_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// WarpIterator to load Scale or Bias vector from the shared memory -+ typename WarpIteratorGammaBeta_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaMainloopFusionBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Element type of scale and bias vectors -+ using ElementScaleBias = ElementScaleBias_; -+ -+ /// Layout of scale and bias vectors -+ using LayoutScaleBias = LayoutScaleBias_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ ///< WarpIterator to load Scale or Bias vector from the shared memory -+ using WarpIteratorGammaBeta = WarpIteratorGammaBeta_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. -+ using WarpGemm = typename Policy::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount = cutlass::gemm::GemmShape; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations = -+ (WarpGemm::kK / Operator::Policy::MmaShape::kK); -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Tensor reference to the A operand -+ using TensorRefA = TensorRef; -+ -+ /// Tensor reference to the scale and bias vectors -+ using TensorRefGammaBeta = TensorRef; -+ -+ /// Tensor reference to the B operand -+ using TensorRefB = TensorRef; -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ class SharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ -+ /// Shape of the A matrix operand in shared memory -+ using ShapeA = MatrixShape; -+ -+ /// Shape of the A scale and bias vectors in shared memory -+ using ShapeGammaBeta = -+ MatrixShape<1 + Policy::SmemPaddingA::kRow, -+ 2 * Shape::kK * kStages + Policy::SmemPaddingA::kColumn>; -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = -+ MatrixShape; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for A operand -+ AlignedBuffer operand_A; -+ -+ /// Buffer for B operand -+ AlignedBuffer operand_B; -+ -+ /// Buffer for A operand Scale and Bias -+ AlignedBuffer operand_A_gamma_beta; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the A matrix -+ CUTLASS_DEVICE -+ static typename Operator::LayoutA LayoutA() { -+ return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); -+ } -+ -+ /// Returns a layout object for the B matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutB LayoutB() { -+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); -+ } -+ -+ /// Returns a layout object for the A scale and bias vectors -+ CUTLASS_DEVICE -+ static LayoutScaleBias LayoutScaleBias() { -+ return LayoutScaleBias::packed( -+ {ShapeGammaBeta::kRow, ShapeGammaBeta::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the A operand -+ CUTLASS_HOST_DEVICE -+ TensorRefA operand_A_ref() { -+ return TensorRefA{operand_A.data(), LayoutA()}; -+ } -+ -+ /// Returns a TensorRef to the B operand -+ 
CUTLASS_HOST_DEVICE -+ TensorRefB operand_B_ref() { -+ return TensorRefB{operand_B.data(), LayoutB()}; -+ } -+ -+ /// Returns a TensorRef to the A operand Scale vector -+ CUTLASS_HOST_DEVICE -+ TensorRefGammaBeta operand_A_gamma_beta_ref() { -+ return TensorRefGammaBeta{operand_A_gamma_beta.data(), LayoutScaleBias()}; -+ } -+ }; -+ -+ protected: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A operand from shared memory -+ typename Operator::IteratorA warp_tile_iterator_A_; -+ -+ /// Iterator to load a warp-scoped tile of A operand scale and bias vector -+ /// from shared memory -+ WarpIteratorGammaBeta warp_tile_iterator_A_gamma_beta_; -+ -+ /// Iterator to load a warp-scoped tile of B operand from shared memory -+ typename Operator::IteratorB warp_tile_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaMainloopFusionBase( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), -+ warp_tile_iterator_A_gamma_beta_( -+ shared_storage.operand_A_gamma_beta_ref(), lane_idx), -+ warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} -+}; -+ -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Iterates over vectors of var and mean vector in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorVarMean_, -+ /// Iterates over vectors of scale and bias vector in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorGammaBeta_, -+ /// Iterates over vectors of scale and bias vector in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorGammaBeta_, -+ /// Cache operation for scale/bias operand -+ cutlass::arch::CacheOperation::Kind CacheOpGammaBeta, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// WarpIterator to load Scale or Bias vector from the shared memory -+ typename WarpIteratorGammaBeta_, -+ /// Number of stages, -+ int Stages, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = 
SharedMemoryClearOption::kNone, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaLayernormMainloopFusionMultistage : -+ public MmaMainloopFusionBase { -+public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Iterates over tiles of the var and mean vectors in global memory -+ using IteratorVarMean = IteratorVarMean_; -+ ///< Iterates over tiles of the scale and bias vectors in global memory -+ using IteratorGammaBeta = IteratorGammaBeta_; -+ ///< WarpIterator to load Scale or Bias vector from the shared memory -+ using WarpIteratorGammaBeta = WarpIteratorGammaBeta_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ ///< Base class -+ using Base = MmaMainloopFusionBase; -+ -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ using SmemIteratorGammaBeta = SmemIteratorGammaBeta_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpGammaBeta = -+ CacheOpGammaBeta; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ /// Internal structure exposed for introspection. 
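[Editor's note] For sizing intuition: the SharedStorage in the base class above stages both gamma and beta for every K column of every in-flight pipeline stage, which appears to be where the 2 * Shape::kK * kStages extent of ShapeGammaBeta comes from. A back-of-the-envelope computation with assumed tile sizes (the interpretation of the factor of two as "one gamma plus one beta per column" is an assumption, not stated in the patch):

// Element count of the gamma/beta staging buffer (2 * kK * kStages), sketch only.
#include <cstdio>

int main() {
  const int kK      = 32;  // threadblock K extent (assumed example)
  const int kStages = 4;   // pipeline depth       (assumed example)

  // One gamma and one beta value per K column, per in-flight stage (assumed layout).
  int gamma_beta_elements = 2 * kK * kStages;

  std::printf("gamma/beta staging buffer holds %d elements\n", gamma_beta_elements);
  // With half-precision scale/bias this would be 2 * 32 * 4 * 2 bytes = 512 bytes.
}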
-+ struct Detail { -+ -+ static_assert(Base::kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA = -+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ using WarpLoadedFragmentVarMean = typename IteratorVarMean::Fragment; -+ using WarpLoadedFragmentGammaBeta = -+ typename WarpIteratorGammaBeta::Fragment; -+ -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of A operand scale vector to shared memory -+ SmemIteratorGammaBeta smem_iterator_A_gamma_beta_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+ int warp_idx_m_; -+ -+ int warp_idx_n_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaLayernormMainloopFusionMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_A_gamma_beta_(shared_storage.operand_A_gamma_beta_ref(), -+ thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ warp_idx_m_ = warp_idx_mn % Base::WarpCount::kM; -+ warp_idx_n_ = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m_, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_A_gamma_beta_.add_tile_offset( -+ {warp_idx_m_, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, 
warp_idx_n_}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance(IteratorA &iterator_A, -+ IteratorGammaBeta &iterator_A_gamma_beta, -+ IteratorB &iterator_B, -+ int group_start_A = 0, int group_start_B = 0) { -+ iterator_A.set_iteration_index(group_start_A * -+ IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_A.get(); -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ // Async Copy for operand A scale and bias vector. Scale and bias vectors -+ // are small. One iteration is enough. -+ if (group_start_A == 0) { -+ typename IteratorGammaBeta::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_gamma_beta_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorGammaBeta::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async( -+ dst_ptr, iterator_A_gamma_beta.get(), iterator_A_gamma_beta.valid()); -+ } -+ -+ iterator_B.set_iteration_index(group_start_B * -+ IteratorB::kAccessesPerVector); -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B.get(); -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_B.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B.valid()); -+ } -+ -+ ++iterator_B; -+ } -+ ++this->smem_iterator_B_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ ///< iterator over B operand in global memory -+ IteratorVarMean iterator_var_mean, -+ ///< iterator over scale and bias vectors in global memory -+ IteratorGammaBeta iterator_A_gamma_beta, -+ ///< initial value of accumulator -+ FragmentC const &src_accum) { -+ -+ // -+ // Prologue -+ // -+ // Issue several complete stages -+ -+ WarpLoadedFragmentVarMean warp_loaded_frag_var_mean; -+ iterator_var_mean.add_tile_offset({0, warp_idx_m_}); -+ iterator_var_mean.load(warp_loaded_frag_var_mean); 
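[Editor's note] All the constructors in this patch decompose the linear warp index the same way: the warp's M coordinate varies fastest, then N, then K. A host-side sketch of that mapping (the WarpCount values are assumed for the example):

// Mapping of a linear warp_idx to (m, n, k) warp coordinates, as in the
// constructors above.
#include <cstdio>

int main() {
  const int kWarpCountM = 2, kWarpCountN = 2, kWarpCountK = 2;  // assumed

  for (int warp_idx = 0; warp_idx < kWarpCountM * kWarpCountN * kWarpCountK; ++warp_idx) {
    int warp_idx_mn = warp_idx % (kWarpCountM * kWarpCountN);
    int warp_idx_k  = warp_idx / (kWarpCountM * kWarpCountN);

    int warp_idx_m = warp_idx_mn % kWarpCountM;
    int warp_idx_n = warp_idx_mn / kWarpCountM;

    std::printf("warp %d -> (m=%d, n=%d, k=%d)\n",
                warp_idx, warp_idx_m, warp_idx_n, warp_idx_k);
  }
}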
-+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations) { -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ // Async Copy for operand A scale and bias vectors. Scale and bias -+ // vectors are small. One iteration is enough. -+ { -+ typename IteratorGammaBeta::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_gamma_beta_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorGammaBeta::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async( -+ dst_ptr, iterator_A_gamma_beta.get(), iterator_A_gamma_beta.valid()); -+ } -+ -+ iterator_B.set_iteration_index(0); -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ } -+ -+ ++this->smem_iterator_B_; -+ } -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_A_gamma_beta.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_A_gamma_beta_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ // Waits until kStages-2 stages have committed. 
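[Editor's note] The prologue above issues kStages - 1 complete stages of asynchronous copies before any math, with one cp.async fence per stage; the wait that follows lets execution proceed once at most one of those stages is still in flight. A simple counter model of that bookkeeping (purely illustrative; real cp.async commit order is tracked by hardware commit groups):

// Counter model of the cp.async prologue: issue kStages-1 stages, then wait
// until no more than kStages-2 stages remain outstanding. Sketch only.
#include <cstdio>

int main() {
  const int kStages = 4;

  int stages_committed = 0;   // stages whose copies have landed in shared memory
  int stages_in_flight = 0;   // stages issued but not yet waited on

  // Prologue: issue kStages - 1 stages of copies, one fence per stage.
  for (int stage = 0; stage < kStages - 1; ++stage) {
    ++stages_in_flight;
    std::printf("issued stage %d (in flight: %d)\n", stage, stages_in_flight);
  }

  // Models cp_async_wait<kStages - 2>(): block until at most kStages-2 stages
  // are outstanding, i.e. at least one issued stage has fully committed.
  while (stages_in_flight > kStages - 2) {
    --stages_in_flight;
    ++stages_committed;
  }
  std::printf("committed: %d, still in flight: %d\n",
              stages_committed, stages_in_flight);
}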
-+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B[2]; -+ WarpLoadedFragmentGammaBeta warp_loaded_frag_A_gamma_beta[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B[2]; -+ -+ Operator warp_mma; -+ cutlass::gemm::warp::LayernormScaleBiasTransform -+ elementwise_transform; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_A_gamma_beta_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_A_gamma_beta_.load( -+ warp_loaded_frag_A_gamma_beta[0]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_A_gamma_beta_; -+ ++this->warp_tile_iterator_B_; -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B[0]); -+ -+ elementwise_transform(warp_transformed_frag_A[0], -+ warp_loaded_frag_var_mean, -+ warp_loaded_frag_A_gamma_beta[0]); -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. 
-+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_A_gamma_beta_.set_kgroup_index( -+ (warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_A_gamma_beta_.load( -+ warp_loaded_frag_A_gamma_beta[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_A_gamma_beta_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k > 0) { -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B[warp_mma_k % 2]); -+ -+ elementwise_transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_var_mean, -+ warp_loaded_frag_A_gamma_beta[warp_mma_k % 2]); -+ } -+ -+ warp_mma( -+ accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ accum -+ ); -+ -+ // Issue global->shared copies for the this stage -+ if (warp_mma_k < Base::kWarpGemmIterations - 1) { -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance(iterator_A, iterator_A_gamma_beta, iterator_B, -+ group_start_iteration_A, -+ group_start_iteration_B); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ int group_start_iteration_A, group_start_iteration_B; -+ group_start_iteration_A = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance(iterator_A, iterator_A_gamma_beta, iterator_B, -+ group_start_iteration_A, -+ group_start_iteration_B); -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. 
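[Editor's note] The stage-advance code that follows, like the matching logic earlier in this patch, treats shared memory as a circular buffer of kStages tiles: the write and read stage indices advance by one per stage and, on wrap-around, the iterators are rewound with a negative tile offset rather than being reconstructed. A sketch of that index bookkeeping (offsets are modeled in generic "tiles"; the real read-side rewind is scaled by kPartitionsK * kWarpGemmIterations):

// Circular-buffer stage indices for the multistage pipeline (sketch only).
#include <cstdio>

int main() {
  const int kStages = 3;
  int smem_write_stage_idx = kStages - 1;  // prologue already filled stages 0..kStages-2
  int smem_read_stage_idx = 0;
  int write_tile_offset = kStages - 1;     // models the smem write iterator position
  int read_tile_offset = 0;                // models the warp tile read iterator position

  for (int step = 0; step < 6; ++step) {
    // Advance the write stage; rewind by kStages when falling off the end.
    ++write_tile_offset;
    if (smem_write_stage_idx == kStages - 1) {
      write_tile_offset -= kStages;
      smem_write_stage_idx = 0;
    } else {
      ++smem_write_stage_idx;
    }

    // Advance the read stage with the same wrap-around rule.
    ++read_tile_offset;
    if (smem_read_stage_idx == kStages - 1) {
      read_tile_offset -= kStages;
      smem_read_stage_idx = 0;
    } else {
      ++smem_read_stage_idx;
    }

    std::printf("step %d: write stage %d (tile %d), read stage %d (tile %d)\n",
                step, smem_write_stage_idx, write_tile_offset,
                smem_read_stage_idx, read_tile_offset);
  }
}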
-+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_A_gamma_beta.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_A_gamma_beta_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_A_gamma_beta_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_A_gamma_beta_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) { -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ elementwise_transform( -+ warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_var_mean, -+ warp_loaded_frag_A_gamma_beta[(warp_mma_k + 1) % 2]); -+ } -+ } -+ -+ } -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ // commit and drain all pending and predicated cp.async pnz from the GEMM mainloop -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_multistage.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_multistage.h -new file mode 100644 -index 0000000..5f6f852 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_multistage.h -@@ -0,0 +1,746 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/threadblock/mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. 
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaMultistage : -+ public MmaBase { -+public: -+ ///< Base class -+ using Base = MmaBase; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ /// Internal structure exposed for introspection. 
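[Editor's note] Throughout these mainloops the global-to-shared copy picks between a plain predicated cp.async and a zero-filling variant, depending on the SharedMemoryClear template option. A compile-time sketch of that dispatch; the enum and function names here are stand-ins for the CUTLASS intrinsics, and a plain assignment stands in for the asynchronous copy:

// Compile-time choice between "predicate only" and "zero-fill" copies,
// mirroring the SharedMemoryClearOption dispatch in copy_tiles_and_advance.
#include <cstdio>

enum class SharedMemoryClearOptionLike { kNone, kZfill, kClearLastStage };

template <SharedMemoryClearOptionLike kClear>
void copy_element(float* dst, const float* src, bool valid) {
  if constexpr (kClear == SharedMemoryClearOptionLike::kZfill) {
    *dst = valid ? *src : 0.0f;   // out-of-bounds accesses write zeros to smem
  } else {
    if (valid) *dst = *src;       // out-of-bounds accesses are simply skipped
  }
}

int main() {
  float src = 7.0f, dst = -1.0f;
  copy_element<SharedMemoryClearOptionLike::kZfill>(&dst, &src, /*valid=*/false);
  std::printf("zfill, invalid -> %g\n", dst);   // 0
  copy_element<SharedMemoryClearOptionLike::kNone>(&dst, &src, /*valid=*/false);
  std::printf("none,  invalid -> %g\n", dst);   // destination left unchanged (still 0)
}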
-+ struct Detail { -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA = -+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ // Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical -+ // accuracy, where each mainloop iteration first accumulates into a temporary -+ // set of freshly-cleared accumulators, which are subsequently added to the -+ // final accumulator set. -+ static bool const kStagedAccumulation = -+ platform::is_same::value || -+ platform::is_same::value; -+ -+ }; -+ -+ private: -+ -+ -+ // Structure encapsulating pipeline state live from one iteration to the next -+ struct PipeState { -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ /// Temporary accumulator to facilitate staged-accumulation -+ FragmentC tmp_accum_; -+ -+ /// Pair of A fragments used to overlap shared memory loads and math instructions -+ WarpLoadedFragmentA warp_loaded_frag_A_[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A_[2]; -+ -+ /// Pair of B fragments used to overlap shared memory loads and math instructions -+ WarpLoadedFragmentB warp_loaded_frag_B_[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B_[2]; -+ }; -+ -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Warp-level MMA operator -+ Operator warp_mma_; -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+ /// Shared memory write stage index -+ int smem_write_stage_idx_; -+ -+ /// Shared memory read stage index -+ int smem_read_stage_idx_; -+ -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), -+ smem_write_stage_idx_(0), -+ smem_read_stage_idx_(0) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int 
warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ /// Advance shared memory read-iterators to the next stage -+ CUTLASS_DEVICE -+ void advance_smem_read_stage() -+ { -+ ++smem_read_stage_idx_; -+ -+ if (smem_read_stage_idx_ == Base::kStages) { -+ // Wrap back around to the 'start' of the circular buffer in shared memory -+ this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); -+ smem_read_stage_idx_ = 0; -+ } -+ } -+ -+ /// Advance global memory read-iterators and shared memory write-iterators to the stage -+ CUTLASS_DEVICE -+ void advance_smem_write_stage( -+ IteratorA &iterator_A, -+ IteratorB &iterator_B) -+ { -+ // Advance global iterators -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ -+ // Advance shared iterators -+ smem_iterator_A_.add_tile_offset({0, 1}); -+ smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Increment shared memory write stage index -+ ++smem_write_stage_idx_; -+ -+ if (smem_write_stage_idx_ == Base::kStages) { -+ // Wrap back around to the 'start' of the circular buffer in shared memory -+ smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx_ = 0; -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, -+ int group_start_A = 0, int group_start_B = 0) { -+ iterator_A.set_iteration_index(group_start_A * -+ IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_A.get(); -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ iterator_B.set_iteration_index(group_start_B * -+ IteratorB::kAccessesPerVector); -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ 
IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B.get(); -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_B.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B.valid()); -+ } -+ -+ ++iterator_B; -+ } -+ ++this->smem_iterator_B_; -+ } -+ } -+ } -+ -+ /// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching -+ /// the global fragments needed by the first kStages-1 threadblock mainloop iterations -+ CUTLASS_DEVICE -+ void prologue( -+ IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory -+ IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory -+ int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining -+ { -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { -+ -+ // Disable global fetching if done with global fetch iterations -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ iterator_B.set_iteration_index(0); -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ } -+ -+ ++this->smem_iterator_B_; -+ } -+ -+ // Move to the next write stage -+ advance_smem_write_stage(iterator_A, iterator_B); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Optionally clear the remaining stages of SMEM. This is a functional requirement for -+ // some kernels so that all accumulator elements outside the GEMM footprint are zero. 
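The prologue above predicates every copy on iterator.valid() and uses cp_async_zfill, so out-of-bounds elements land in shared memory as zeros, and the kClearLastStage branch below zero-fills entire untouched stages for the same reason: the warp-level MMAs can then sweep full tiles without bounds checks. A reduced host-side model of that zero-fill predication, with hypothetical names and ordinary loads in place of cp.async:

#include <cstdio>
#include <vector>

// cp_async_zfill analogue: read when the predicate is true, otherwise zero.
int load_or_zero(const std::vector<int>& src, int idx, bool valid) {
  return valid ? src[idx] : 0;
}

int main() {
  constexpr int kTile = 8;                          // staged tile extent
  const std::vector<int> global = {1, 2, 3, 4, 5};  // only 5 valid elements
  const int problem_k = static_cast<int>(global.size());

  int smem_tile[kTile];
  for (int i = 0; i < kTile; ++i)
    smem_tile[i] = load_or_zero(global, i, i < problem_k);  // iterator.valid()

  // The consumer now sweeps the full tile with no bounds checks.
  int acc = 0;
  for (int i = 0; i < kTile; ++i) acc += smem_tile[i];
  std::printf("acc = %d\n", acc);  // 15: out-of-bounds elements contribute 0
  return 0;
}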
-+ if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); -+ typename IteratorA::AccessType zero_A; -+ -+ zero_A.clear(); -+ last_smem_iterator_A.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ last_smem_iterator_A.get()); -+ -+ *dst_ptr = zero_A; -+ -+ ++last_smem_iterator_A; -+ } -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); -+ typename IteratorB::AccessType zero_B; -+ -+ zero_B.clear(); -+ last_smem_iterator_B.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ last_smem_iterator_B.get()); -+ -+ *dst_ptr = zero_B; -+ -+ ++last_smem_iterator_B; -+ } -+ } -+ } -+ -+ -+ /// Wait until we have at least one completed global fetch stage -+ CUTLASS_DEVICE -+ void gmem_wait() -+ { -+ // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ } -+ -+ -+ /// Perform a threadblock mainloop iteration of matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void mac_loop_iter( -+ PipeState &pipe_state, ///< [in|out] loop-carried pipeline state -+ FragmentC &accum, ///< [in|out] destination accumulator tile -+ IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory -+ IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory -+ int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining -+ { -+ // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { -+ -+ // Load the next warp-tile's A fragment from shared memory -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); -+ ++this->warp_tile_iterator_A_; -+ -+ // Load the next warp-tile's B fragment from shared memory -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); -+ ++this->warp_tile_iterator_B_; -+ -+ // Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary -+ if (warp_mma_k > 0) { -+ warp_mma_.transform( -+ pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], -+ pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], -+ pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], -+ pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]); -+ } -+ -+ // Execute the current warp-tile of MMA operations -+ if (Detail::kStagedAccumulation) { -+ warp_mma_( -+ pipe_state.tmp_accum_, -+ pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], -+ pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], -+ pipe_state.tmp_accum_ -+ ); -+ -+ if (warp_mma_k == 0) { -+ plus plus_accum; -+ accum = plus_accum(accum, pipe_state.tmp_accum_); -+ pipe_state.tmp_accum_.clear(); -+ } -+ } else { 
-+ warp_mma_( -+ accum, -+ pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], -+ pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], -+ accum -+ ); -+ } -+ -+ // Except for the last warp-tile, all warp-tiles issue their share of -+ // global->shared fragment copies -+ if (warp_mma_k < Base::kWarpGemmIterations - 1) { -+ -+ int group_start_iteration_A, group_start_iteration_B; -+ group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance( -+ iterator_A, -+ iterator_B, -+ group_start_iteration_A, -+ group_start_iteration_B); -+ } -+ -+ // The second-to-last warp-tile also: -+ // - performs the last warp-tile's share of global->shared fragment copies -+ // - moves to the next global fetch stage -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ -+ // Performs the last warp-tile's share of global->shared fragment copies -+ int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ int group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance( -+ iterator_A, -+ iterator_B, -+ group_start_iteration_A, -+ group_start_iteration_B); -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Wait until we have at least one completed global fetch stage -+ gmem_wait(); -+ -+ // Move to the next global fetch stage -+ advance_smem_write_stage(iterator_A, iterator_B); -+ advance_smem_read_stage(); -+ -+ // Disable global fetching when done with global fetch iterations -+ --gemm_k_iterations; -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ } -+ -+ // The last warp-tile also converts the shared memory fragments used by -+ // the first warp-tile of the next iteration, if necessary (so we can -+ // immediately start issuing MMA instructions at the top of the loop ) -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) { -+ -+ warp_mma_.transform( -+ pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2], -+ pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], -+ pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2], -+ pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); -+ } -+ -+ } -+ } -+ -+ -+ /// Perform the specified number of threadblock mainloop iterations of matrix -+ /// multiply-accumulate. Assumes prologue has been initiated. 
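mac_loop_iter above keeps two register fragments per operand and rotates them with a modulo-2 index: slot warp_mma_k % 2 feeds the current MMA while slot (warp_mma_k + 1) % 2 is being loaded from shared memory for the next iteration. A scalar C++ sketch of that two-slot rotation, with plain integers standing in for fragments and a multiply-add standing in for the warp MMA (hypothetical names, not the real iterators):

#include <cstdio>
#include <vector>

int main() {
  // "Shared memory" contents for one stage: one value per k-group.
  std::vector<int> smem_a = {1, 2, 3, 4};
  std::vector<int> smem_b = {10, 20, 30, 40};
  const int kWarpGemmIterations = 4;

  int frag_a[2], frag_b[2];            // double-buffered warp fragments
  frag_a[0] = smem_a[0];               // preload slot 0 (done in gemm_iters)
  frag_b[0] = smem_b[0];

  long long accum = 0;
  for (int k = 0; k < kWarpGemmIterations; ++k) {
    // Load the *next* k-group into the other slot; the wrap back to group 0
    // mirrors the real loop, where the last load of an iteration fetches the
    // first k-group of the next stage.
    int next = (k + 1) % kWarpGemmIterations;
    frag_a[(k + 1) % 2] = smem_a[next];
    frag_b[(k + 1) % 2] = smem_b[next];

    // "MMA" on the slot loaded in the previous iteration.
    accum += static_cast<long long>(frag_a[k % 2]) * frag_b[k % 2];
  }
  std::printf("accum = %lld\n", accum);  // 1*10 + 2*20 + 3*30 + 4*40 = 300
  return 0;
}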
-+ CUTLASS_DEVICE -+ void gemm_iters( -+ int gemm_k_iterations, ///< number of threadblock mainloop iterations -+ FragmentC &accum, ///< [in|out] accumulator tile -+ IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory -+ IteratorB &iterator_B) ///< [in|out] iterator over B operand in global memory -+ { -+ PipeState pipe_state; -+ -+ // Disable global fetching if done with global fetch iterations -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ // Load first warp-tile's A fragment from shared memory -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]); -+ ++this->warp_tile_iterator_A_; -+ -+ // Load first warp-tile's B fragment from shared memory -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]); -+ ++this->warp_tile_iterator_B_; -+ -+ // Transform, if necessary, the first warp-tile's shared memory fragments -+ warp_mma_.transform( -+ pipe_state.warp_transformed_frag_A_[0], -+ pipe_state.warp_transformed_frag_B_[0], -+ pipe_state.warp_loaded_frag_A_[0], -+ pipe_state.warp_loaded_frag_B_[0]); -+ -+ if (Detail::kStagedAccumulation) { -+ pipe_state.tmp_accum_.clear(); -+ } -+ -+ // Mainloop -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ mac_loop_iter( -+ pipe_state, -+ accum, -+ iterator_A, -+ iterator_B, -+ gemm_k_iterations); -+ } -+ -+ if (Detail::kStagedAccumulation) { -+ plus plus_accum; -+ accum = plus_accum(accum, pipe_state.tmp_accum_); -+ } -+ -+ // Optionally commit and drain all pending and predicated cp.async pnz from the GEMM mainloop -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ } -+ -+ } -+ -+ -+ /// Prepares the class for another prologue. -+ CUTLASS_DEVICE -+ void wind_down() -+ { -+ // Catch-up the smem-read iterator to the smem-write iterator (so this class can be reused for another tile's prologue) -+ -+ // First, increment remaining warp tiles to get to the next full stage. (Ideally we would -+ // just decrement one tile, but not all iterators implement --() decrement.) 
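gemm_iters above finishes the staged-accumulation path by folding pipe_state.tmp_accum_ into accum one final time, matching the Detail::kStagedAccumulation scheme described earlier: partial products go into a freshly cleared temporary accumulator that is periodically added to the long-lived one. A scalar sketch of that pattern, with toy values and a hypothetical fold interval:

#include <cstdio>
#include <vector>

int main() {
  // Staged-accumulation model: the real kernel uses this for tf32x3-style
  // operators to limit accumulated rounding error.
  std::vector<double> products(12, 0.25);   // toy partial products

  double accum = 0.0;      // long-lived accumulator
  double tmp_accum = 0.0;  // freshly cleared staging accumulator
  const int kFoldEvery = 4;

  for (int i = 0; i < static_cast<int>(products.size()); ++i) {
    tmp_accum += products[i];
    if ((i + 1) % kFoldEvery == 0) {   // fold at a stage boundary
      accum += tmp_accum;
      tmp_accum = 0.0;
    }
  }
  accum += tmp_accum;  // final fold after the mainloop, as in gemm_iters()
  std::printf("accum = %f\n", accum);  // 12 * 0.25 = 3.0
  return 0;
}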
-+ #pragma unroll -+ for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) -+ { -+ this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k); -+ this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ } -+ smem_read_stage_idx_++; -+ -+ // Then wrap back two full stages (one for the tile advancing we just did, and one to catch the write iterators) -+ static const int kStageIters = Policy::kPartitionsK * Base::kWarpGemmIterations; -+ if (smem_read_stage_idx_ > 1) -+ { -+ this->warp_tile_iterator_A_.add_tile_offset({0, (-2 * kStageIters)}); -+ this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0}); -+ } -+ else -+ { -+ this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)}); -+ this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0}); -+ } -+ smem_read_stage_idx_ = smem_write_stage_idx_; -+ } -+ -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ ///< initial value of accumulator -+ FragmentC const &src_accum) { -+ -+ // Prologue (start fetching iterations of global fragments into shared memory) -+ prologue(iterator_A, iterator_B, gemm_k_iterations); -+ -+ // Wait until we have at least one completed global fetch stage -+ gmem_wait(); -+ -+ // Initialize destination accumulators with source accumulators -+ accum = src_accum; -+ -+ // Perform the MAC-iterations -+ gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_pipelined.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_pipelined.h -new file mode 100644 -index 0000000..8ada21c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_pipelined.h -@@ -0,0 +1,439 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Transformation applied to A operand -+ typename TransformA_ = NumericArrayConverter< -+ typename SmemIteratorA_::Element, -+ typename IteratorA_::Element, -+ IteratorA_::Fragment::kElements>, -+ /// -+ /// Transformation applied to B operand -+ typename TransformB_ = NumericArrayConverter< -+ typename SmemIteratorB_::Element, -+ typename IteratorB_::Element, -+ IteratorB_::Fragment::kElements>, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class MmaPipelined : public MmaBase { -+public: -+ -+ ///< Base class -+ using Base = MmaBase; -+ -+ using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory -+ using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory -+ using ElementC = ElementC_; ///< Data type of accumulator matrix -+ using LayoutC = 
LayoutC_; ///< Layout of accumulator matrix -+ using Policy = Policy_; ///< Policy describing tuning details -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of operand A loaded from global memory -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Obtain the arch tag from the warp-level operator -+ using ArchTag = typename Policy::Operator::ArchTag; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) -+ static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ /// Warp-level MMA operator -+ Operator warp_mma; -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+ ///< transformation applied to A fragment -+ TransformA transform_A_; -+ -+ ///< transformation applied to B fragment -+ TransformB transform_B_; -+ -+ /// Shared memory write stage index -+ int smem_write_stage_idx; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaPipelined( -+ typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ int thread_idx, ///< ID within the threadblock -+ int warp_idx, ///< ID of warp -+ int lane_idx, ///< ID of each thread within a warp -+ TransformA transform_A = TransformA(), ///< transformation applied to A fragment -+ TransformB transform_B = TransformB() ///< transformation applied to B fragment -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), -+ transform_A_(transform_A), -+ transform_B_(transform_B), -+ smem_write_stage_idx(0) -+ { -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ -+ /// Advance shared memory write-iterators to the next stage -+ CUTLASS_DEVICE -+ void advance_smem_write_stage() -+ 
{ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ } -+ -+ /// Advance shared memory read- and write-iterators to the next stage -+ CUTLASS_DEVICE -+ void advance_smem_stages() -+ { -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ // wrap write stage -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else -+ { -+ // wrap read stage -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ } -+ -+ -+ /// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching -+ /// the global fragments needed by the first kStages-1 threadblock mainloop iterations -+ CUTLASS_DEVICE -+ void prologue( -+ IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory -+ IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory -+ int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining -+ { -+ // The last kblock is loaded in the prolog -+ -+ // Load A fragment from global A -+ FragmentA tb_frag_A; -+ tb_frag_A.clear(); -+ iterator_A.load(tb_frag_A); -+ ++iterator_A; -+ -+ // Load B fragment from global B -+ FragmentB tb_frag_B; -+ tb_frag_B.clear(); -+ iterator_B.load(tb_frag_B); -+ ++iterator_B; -+ -+ // Store A and B fragments to shared -+ this->smem_iterator_A_.store(transform_A_(tb_frag_A)); -+ this->smem_iterator_B_.store(transform_B_(tb_frag_B)); -+ -+ // Advance write stage -+ advance_smem_write_stage(); -+ } -+ -+ /// Wait until we have at least one completed global fetch stage -+ CUTLASS_DEVICE -+ void gmem_wait() -+ { -+ __syncthreads(); -+ } -+ -+ -+ /// Perform the specified number of threadblock mainloop iterations of matrix -+ /// multiply-accumulate. Assumes prologue has been initiated. 
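MmaPipelined is the two-stage (double-buffered) specialization: advance_smem_write_stage above toggles smem_write_stage_idx with an XOR and wraps the iterators when leaving stage 1, so the producer always writes the stage the consumer is not reading. A host-side sketch of that ping-pong, assuming one value per "tile" and plain stores in place of shared-memory writes (hypothetical names):

#include <cstdio>

int main() {
  // Double-buffer (kStages == 2) model: the prologue fills one stage, then
  // each iteration computes on the previously staged tile while the next
  // tile is written into the other buffer. The XOR toggles mirror
  // advance_smem_write_stage() / advance_smem_stages().
  int stages[2] = {0, 0};
  int write_idx = 0, read_idx = 0;
  const int total_tiles = 5;

  stages[write_idx] = 0 * 10;   // prologue: tile 0 staged
  write_idx ^= 1;

  int acc = 0;
  for (int tile = 1; tile <= total_tiles; ++tile) {
    if (tile < total_tiles) {          // overlap: stage the next tile
      stages[write_idx] = tile * 10;
      write_idx ^= 1;
    }
    acc += stages[read_idx];           // "math" on the previously staged tile
    read_idx ^= 1;
  }
  std::printf("acc = %d\n", acc);      // 0 + 10 + 20 + 30 + 40 = 100
  return 0;
}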
-+ CUTLASS_DEVICE -+ void gemm_iters( -+ int gemm_k_iterations, ///< number of threadblock mainloop iterations -+ FragmentC &accum, ///< [in|out] accumulator tile -+ IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory -+ IteratorB &iterator_B) ///< [in|out] iterator over B operand in global memory -+ { -+ using WarpFragmentA = typename Operator::FragmentA; -+ using WarpFragmentB = typename Operator::FragmentB; -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA warp_frag_A[2]; -+ WarpFragmentB warp_frag_B[2]; -+ -+ // Load A fragment from shared A -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_A_.load(warp_frag_A[0]); -+ ++this->warp_tile_iterator_A_; -+ -+ // Load B fragment from shared B -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.load(warp_frag_B[0]); -+ ++this->warp_tile_iterator_B_; -+ -+ // Pair of fragments used to overlap global memory loads and math instructions; -+ FragmentA tb_frag_A; -+ FragmentB tb_frag_B; -+ -+ // Avoid reading out of bounds -+ iterator_A.clear_mask(gemm_k_iterations <= 1); -+ iterator_B.clear_mask(gemm_k_iterations <= 1); -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > 0; --gemm_k_iterations) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_A_.store(transform_A_(tb_frag_A)); -+ -+ this->smem_iterator_B_.store(transform_B_(tb_frag_B)); -+ -+ // Wait until we have at least one completed global fetch stage -+ gmem_wait(); -+ -+ // Advance smem read and write stages -+ advance_smem_stages(); -+ } -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k == 0) { -+ -+ // Load fragment from global A -+ tb_frag_A.clear(); -+ iterator_A.load(tb_frag_A); -+ ++iterator_A; -+ -+ // Load fragment from global B -+ tb_frag_B.clear(); -+ iterator_B.load(tb_frag_B); -+ ++iterator_B; -+ -+ // Avoid reading out of bounds if this was the last loop iteration -+ iterator_A.clear_mask(gemm_k_iterations <= 2); -+ iterator_B.clear_mask(gemm_k_iterations <= 2); -+ } -+ -+ warp_mma( -+ accum, -+ warp_frag_A[warp_mma_k % 2], -+ warp_frag_B[warp_mma_k % 2], -+ accum); -+ } -+ } -+ -+ } -+ -+ -+ /// Prepares the class for another prologue. -+ CUTLASS_DEVICE -+ void wind_down() -+ { -+ // First, increment remaining warp tiles to catch it up with the write stage. 
-+ #pragma unroll -+ for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) -+ { -+ this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k); -+ this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ } -+ -+ // If we bumped the read iterators to the end of the circular buffer, wrap them around to -+ // align them with the write iterators -+ if (smem_write_stage_idx == 0) -+ { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ int gemm_k_iterations, ///< number of iterations of the mainloop -+ FragmentC &accum, ///< destination accumulator tile -+ IteratorA iterator_A, ///< iterator over A operand in global memory -+ IteratorB iterator_B, ///< iterator over B operand in global memory -+ FragmentC const &src_accum) ///< source accumulator tile -+ { -+ // Prologue -+ prologue(iterator_A, iterator_B, gemm_k_iterations); -+ -+ // Wait until we have at least one completed global fetch stage -+ gmem_wait(); -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ // Perform the MAC-iterations -+ gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B); -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_base.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_base.h -new file mode 100644 -index 0000000..d21600e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_base.h -@@ -0,0 +1,208 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaPlanarComplexBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. 
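The WarpGemm, WarpCount, and kWarpGemmIterations aliases defined just below are pure compile-time arithmetic: the threadblock tile is divided by the warp tile in each dimension, and the warp tile's K extent is divided by the K extent of one warp-level MMA. A standalone constexpr sketch of that arithmetic with toy shapes (the struct name Shape here is a stand-in for gemm::GemmShape, and K partitioning is omitted):

#include <cstdio>

// Minimal stand-in for gemm::GemmShape<> (hypothetical, compile-time only).
template <int M, int N, int K>
struct Shape { static constexpr int kM = M, kN = N, kK = K; };

using ThreadblockShape = Shape<128, 128, 32>;  // CTA-level tile (toy choice)
using WarpShape        = Shape<64, 64, 32>;    // warp-level tile
using InstructionShape = Shape<16, 8, 8>;      // one warp-level MMA

// WarpCount: how many warps tile the CTA in M and N.
constexpr int kWarpCountM = ThreadblockShape::kM / WarpShape::kM;  // 2
constexpr int kWarpCountN = ThreadblockShape::kN / WarpShape::kN;  // 2

// kWarpGemmIterations: warp-level MMAs needed to cover the warp tile's K.
constexpr int kWarpGemmIterations = WarpShape::kK / InstructionShape::kK;  // 4

int main() {
  std::printf("warps: %d x %d, k-iterations per warp tile: %d\n",
              kWarpCountM, kWarpCountN, kWarpGemmIterations);
  return 0;
}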
-+ using WarpGemm = typename Policy::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount = GemmShape; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations = -+ (WarpGemm::kK / Operator::Policy::MmaShape::kK); -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Tensor reference to the A operand -+ using TensorRefA = TensorRef; -+ -+ /// Tensor reference to the B operand -+ using TensorRefB = TensorRef; -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ class SharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ -+ /// Shape of the A matrix operand in shared memory -+ using ShapeA = MatrixShape; -+ -+ /// Stride to the imaginary part of the A operand -+ static int const kImaginaryStrideA = ShapeA::kCount; -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = -+ MatrixShape; -+ -+ /// Stride to the imaginary part of the A operand -+ static int const kImaginaryStrideB = ShapeB::kCount; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for A operand -+ AlignedBuffer operand_A; -+ -+ /// Buffer for B operand -+ AlignedBuffer operand_B; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the A matrix -+ CUTLASS_DEVICE -+ static typename Operator::LayoutA LayoutA() { -+ return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); -+ } -+ -+ /// Returns a layout object for the B matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutB LayoutB() { -+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the A operand -+ CUTLASS_HOST_DEVICE -+ TensorRefA operand_A_ref() { -+ return TensorRefA{operand_A.data(), LayoutA()}; -+ } -+ -+ /// Returns a TensorRef to the B operand -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B_ref() { -+ return TensorRefB{operand_B.data(), LayoutB()}; -+ } -+ }; -+ -+ protected: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A operand from shared memory -+ typename Operator::IteratorA warp_tile_iterator_A_; -+ -+ /// Iterator to load a warp-scoped tile of B operand from shared memory -+ typename Operator::IteratorB warp_tile_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaPlanarComplexBase( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), -+ warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) { -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h -new file mode 100644 -index 0000000..b7edd51 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h -@@ -0,0 +1,640 @@ 
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/array_planar_complex.h" -+#include "cutlass/functional.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/mma_planar_complex_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. 
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Transformation applied to A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Transformation applied to B -+ ComplexTransform TransformB = ComplexTransform::kNone -+> -+class MmaPlanarComplexMultistage : -+ public MmaPlanarComplexBase { -+public: -+ ///< Base class -+ using Base = MmaPlanarComplexBase; -+ -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ ///< Archtecture tag -+ using ArchTag = arch::Sm80; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ /// Transformation applied to A -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Transformation applied to B -+ static ComplexTransform const kTransformB = TransformB; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC = ArrayPlanarComplex< -+ typename Policy::Operator::FragmentC::Element, -+ Policy::Operator::FragmentC::kElements -+ >; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Internal structure exposed for introspection. 
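MmaPlanarComplexMultistage stores the real and imaginary planes of each operand separately and, in warp_mma_planar_complex further down, forms the complex product from four real multiply-accumulates with optional conjugation of either operand. The sketch below is an algebraically equivalent scalar model of that accumulation (hypothetical names; the real code negates whole B fragments rather than individual scalars):

#include <complex>
#include <cstdio>

// Scalar model of planar-complex multiply-accumulate: four real MACs produce
// the real and imaginary parts, honoring optional operand conjugation.
struct PlanarAcc { double real = 0.0, imag = 0.0; };

void planar_complex_mac(PlanarAcc& acc, double a_r, double a_i,
                        double b_r, double b_i, bool conj_a, bool conj_b) {
  if (conj_a) a_i = -a_i;   // ComplexTransform::kConjugate on A
  if (conj_b) b_i = -b_i;   // ComplexTransform::kConjugate on B
  acc.real += a_r * b_r - a_i * b_i;
  acc.imag += a_r * b_i + a_i * b_r;
}

int main() {
  PlanarAcc acc;
  planar_complex_mac(acc, 1.0, 2.0, 3.0, 4.0,
                     /*conj_a=*/false, /*conj_b=*/false);
  std::complex<double> ref = std::complex<double>(1, 2) * std::complex<double>(3, 4);
  std::printf("acc = (%f, %f), ref = (%f, %f)\n",
              acc.real, acc.imag, ref.real(), ref.imag());  // (-5, 10) both
  return 0;
}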
-+ struct Detail { -+ -+ static_assert(Base::kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const TBLoadIterationsA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const TBLoadIterationsB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ static int const kAccessesPerGroupA = -+ (TBLoadIterationsA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ static int const kAccessesPerGroupB = -+ (TBLoadIterationsB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ }; -+ -+ private: -+ -+ using WarpFragmentA = typename Operator::FragmentA; -+ using WarpFragmentB = typename Operator::FragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaPlanarComplexMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+private: -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance( -+ IteratorA &iterator_A_real, -+ IteratorA &iterator_A_imag, -+ -+ IteratorB &iterator_B_real, -+ IteratorB &iterator_B_imag, -+ -+ int group_start_A = 0, -+ int group_start_B = 0) { -+ -+ iterator_A_real.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); -+ iterator_A_imag.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Load for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast(this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < 
IteratorA::kAccessesPerVector; ++v) { -+ -+ auto gmem_ptr_real = iterator_A_real.get(); -+ auto gmem_ptr_imag = iterator_A_imag.get(); -+ -+ bool pred_guard = iterator_A_real.valid(); -+ cutlass::arch::cp_async( -+ dst_ptr + v, -+ gmem_ptr_real, -+ pred_guard); -+ cutlass::arch::cp_async( -+ dst_ptr + v + (Base::SharedStorage::kImaginaryStrideA / IteratorA::ThreadMap::kElementsPerAccess), -+ reinterpret_cast(gmem_ptr_imag), -+ pred_guard); -+ -+ ++iterator_A_real; -+ ++iterator_A_imag; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ iterator_B_real.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); -+ iterator_B_imag.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // Load for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast(this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ auto gmem_ptr_real = iterator_B_real.get(); -+ auto gmem_ptr_imag = iterator_B_imag.get(); -+ -+ bool pred_guard = iterator_B_real.valid(); -+ cutlass::arch::cp_async( -+ dst_ptr + v, -+ gmem_ptr_real, -+ pred_guard); -+ cutlass::arch::cp_async( -+ dst_ptr + v + (Base::SharedStorage::kImaginaryStrideB / IteratorB::ThreadMap::kElementsPerAccess), -+ reinterpret_cast(gmem_ptr_imag), -+ pred_guard); -+ -+ ++iterator_B_real; -+ ++iterator_B_imag; -+ } -+ ++this->smem_iterator_B_; -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void warp_mma_planar_complex( -+ Operator & warp_mma, -+ FragmentC &accum, -+ WarpFragmentA const & real_A, -+ WarpFragmentA const & imag_A, -+ WarpFragmentB const & real_B, -+ WarpFragmentB const & imag_B) { -+ -+ cutlass::negate> neg_op_B; -+ -+ WarpFragmentB neg_real_B = neg_op_B(real_B); -+ WarpFragmentB neg_imag_B = neg_op_B(imag_B); -+ -+ warp_mma(accum.real, real_A, real_B, accum.real); -+ -+ if (kTransformB == ComplexTransform::kNone) { -+ warp_mma(accum.imag, real_A, imag_B, accum.imag); -+ } -+ else { -+ warp_mma(accum.imag, real_A, neg_imag_B, accum.imag); -+ } -+ -+ if (kTransformA == ComplexTransform::kNone) { -+ warp_mma(accum.imag, imag_A, real_B, accum.imag); -+ } -+ else { -+ warp_mma(accum.imag, imag_A, neg_real_B, accum.imag); -+ } -+ -+ if (kTransformA == ComplexTransform::kNone ^ kTransformB == ComplexTransform::kNone) { -+ warp_mma(accum.real, imag_A, imag_B, accum.real); -+ } -+ else { -+ warp_mma(accum.real, imag_A, neg_imag_B, accum.real); -+ } -+ } -+ -+public: -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A_real, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A_imag, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B_real, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B_imag, -+ ///< initial value of accumulator -+ FragmentC const &src_accum) { -+ -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations) { -+ -+ iterator_A_real.clear_mask(gemm_k_iterations 
== 0); -+ iterator_A_imag.clear_mask(gemm_k_iterations == 0); -+ iterator_B_real.clear_mask(gemm_k_iterations == 0); -+ iterator_B_imag.clear_mask(gemm_k_iterations == 0); -+ -+ iterator_A_real.set_iteration_index(0); -+ iterator_A_imag.set_iteration_index(0); -+ -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Load for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsA; ++j) { -+ -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast(this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; -+ -+ bool pred_guard = iterator_A_real.valid(); -+ -+ auto src_ptr_real = iterator_A_real.get(); -+ auto src_ptr_imag = iterator_A_imag.get(); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, src_ptr_real, pred_guard); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v + -+ Base::SharedStorage::kImaginaryStrideA / -+ IteratorA::ThreadMap::kElementsPerAccess, -+ reinterpret_cast(src_ptr_imag), -+ pred_guard); -+ -+ ++iterator_A_real; -+ ++iterator_A_imag; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ iterator_B_real.set_iteration_index(0); -+ iterator_B_imag.set_iteration_index(0); -+ -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // Load for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsB; ++j) { -+ -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast(this->smem_iterator_B_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; -+ -+ bool pred_guard = iterator_B_real.valid(); -+ -+ auto src_ptr_real = iterator_B_real.get(); -+ auto src_ptr_imag = iterator_B_imag.get(); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, src_ptr_real, pred_guard); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v + -+ Base::SharedStorage::kImaginaryStrideB / -+ IteratorB::ThreadMap::kElementsPerAccess, -+ reinterpret_cast(src_ptr_imag), -+ pred_guard); -+ -+ ++iterator_B_real; -+ ++iterator_B_imag; -+ } -+ -+ ++this->smem_iterator_B_; -+ } -+ -+ // Move to the next stage -+ iterator_A_real.add_tile_offset({0, 1}); -+ iterator_A_imag.add_tile_offset({0, 1}); -+ -+ iterator_B_real.add_tile_offset({1, 0}); -+ iterator_B_imag.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Inserts a memory fence between stages of cp.async instructions -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ // Blocks until all but kStages-2 cp.async stages have committed. 
-+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ -+ WarpFragmentA warp_frag_real_A[2]; -+ WarpFragmentA warp_frag_imag_A[2]; -+ -+ WarpFragmentB warp_frag_real_B[2]; -+ WarpFragmentB warp_frag_imag_B[2]; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_real_A[0]); -+ this->warp_tile_iterator_A_.load_with_pointer_offset(warp_frag_imag_A[0], Base::SharedStorage::kImaginaryStrideA); -+ -+ this->warp_tile_iterator_B_.load(warp_frag_real_B[0]); -+ this->warp_tile_iterator_B_.load_with_pointer_offset(warp_frag_imag_B[0], Base::SharedStorage::kImaginaryStrideB); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ iterator_A_real.clear_mask(gemm_k_iterations == 0); -+ iterator_A_imag.clear_mask(gemm_k_iterations == 0); -+ iterator_B_real.clear_mask(gemm_k_iterations == 0); -+ iterator_B_imag.clear_mask(gemm_k_iterations == 0); -+ -+ // Start issuing the first group of the next stage outside of the mainloop -+ copy_tiles_and_advance(iterator_A_real, iterator_A_imag, iterator_B_real, iterator_B_imag); -+ -+ Operator warp_mma; -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_real_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_A_.load_with_pointer_offset(warp_frag_imag_A[(warp_mma_k + 1) % 2], Base::SharedStorage::kImaginaryStrideA); -+ -+ this->warp_tile_iterator_B_.load(warp_frag_real_B[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load_with_pointer_offset(warp_frag_imag_B[(warp_mma_k + 1) % 2], Base::SharedStorage::kImaginaryStrideB); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ // Issue global->shared copies for the next stage -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) { -+ group_start_iteration_A = 0; -+ group_start_iteration_B = 0; -+ } -+ else { -+ group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ } -+ -+ copy_tiles_and_advance( -+ iterator_A_real, -+ iterator_A_imag, -+ iterator_B_real, -+ iterator_B_imag, -+ group_start_iteration_A, -+ group_start_iteration_B); -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ // Inserts a memory fence between stages of cp.async instructions -+ cutlass::arch::cp_async_fence(); -+ -+ // Blocks until all but kStages-2 cp.async stages have committed. 
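The mainloop that follows feeds warp_mma_planar_complex, which builds the complex product from four real-valued MMAs and flips signs according to the ComplexTransform flags. A small host-side check of that decomposition against std::complex, with scalars standing in for warp fragments:

#include <cassert>
#include <complex>
#include <cstdio>

// Scalar model of warp_mma_planar_complex: accumulate a complex product from
// four real multiplies, with optional conjugation of A and/or B.
static std::complex<float> planar_mma(std::complex<float> acc,
                                      std::complex<float> a, bool conj_a,
                                      std::complex<float> b, bool conj_b) {
  float ra = a.real(), ia = a.imag();
  float rb = b.real(), ib = b.imag();
  float acc_r = acc.real(), acc_i = acc.imag();

  acc_r += ra * rb;                          // real_A * real_B
  acc_i += ra * (conj_b ? -ib : ib);         // real_A * (+/-)imag_B
  acc_i += (conj_a ? -ia : ia) * rb;         // (+/-)imag_A * real_B
  // imag_A * imag_B contributes to the real part; the sign is positive only
  // when exactly one operand is conjugated (the XOR test in the code above).
  float sign = (conj_a != conj_b) ? 1.f : -1.f;
  acc_r += sign * ia * ib;
  return {acc_r, acc_i};
}

int main() {
  std::complex<float> a(1.5f, -2.f), b(0.5f, 3.f);
  for (bool ca : {false, true})
    for (bool cb : {false, true}) {
      auto ref = (ca ? std::conj(a) : a) * (cb ? std::conj(b) : b);
      auto got = planar_mma({0.f, 0.f}, a, ca, b, cb);
      assert(std::abs(ref - got) < 1e-5f);
    }
  std::puts("planar-complex decomposition matches std::complex");
  return 0;
}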
-+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A_real.add_tile_offset({0, 1}); -+ iterator_A_imag.add_tile_offset({0, 1}); -+ -+ iterator_B_real.add_tile_offset({1, 0}); -+ iterator_B_imag.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ iterator_A_real.clear_mask(gemm_k_iterations == 0); -+ iterator_A_imag.clear_mask(gemm_k_iterations == 0); -+ iterator_B_real.clear_mask(gemm_k_iterations == 0); -+ iterator_B_imag.clear_mask(gemm_k_iterations == 0); -+ } -+ -+ warp_mma_planar_complex( -+ warp_mma, -+ accum, -+ warp_frag_real_A[warp_mma_k % 2], -+ warp_frag_imag_A[warp_mma_k % 2], -+ warp_frag_real_B[warp_mma_k % 2], -+ warp_frag_imag_B[warp_mma_k % 2]); -+ } -+ -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h -new file mode 100644 -index 0000000..160c548 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h -@@ -0,0 +1,424 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/aligned_buffer.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/mma_planar_complex_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Transformation applied to A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Transformation applied to B -+ ComplexTransform TransformB = ComplexTransform::kNone -+> -+class MmaPlanarComplexPipelined : -+ public MmaPlanarComplexBase { -+public: -+ ///< Base class -+ using Base = MmaPlanarComplexBase; -+ -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ using ArchTag = typename Policy::Operator::ArchTag; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ /// Transformation applied to A -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Transformation applied to B -+ static ComplexTransform const kTransformB = TransformB; -+ 
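The accumulator used by these mainloops is planar complex: its real and imaginary parts are held as two separate arrays (hence the accum.real / accum.imag updates above), so each plane can be updated by ordinary real-valued MMAs. A minimal stand-in for such a fragment, with an arbitrary element count:

#include <array>
#include <cstdio>

// Minimal stand-in for a planar-complex accumulator fragment: real and
// imaginary parts live in two independent arrays. kElements is illustrative.
template <typename T, int kElements>
struct ArrayPlanarComplexSketch {
  std::array<T, kElements> real{};
  std::array<T, kElements> imag{};
};

int main() {
  ArrayPlanarComplexSketch<float, 8> accum;
  accum.real[0] += 1.0f;   // e.g. contribution of real_A * real_B
  accum.imag[0] += 2.0f;   // e.g. contribution of real_A * imag_B
  std::printf("accum[0] = (%.1f, %.1f)\n", accum.real[0], accum.imag[0]);
  return 0;
}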
-+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC = ArrayPlanarComplex< -+ typename Policy::Operator::FragmentC::Element, -+ Policy::Operator::FragmentC::kElements -+ >; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ private: -+ -+ using FragmentA = typename IteratorA::Fragment; -+ using FragmentB = typename IteratorB::Fragment; -+ using WarpFragmentA = typename Operator::FragmentA; -+ using WarpFragmentB = typename Operator::FragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaPlanarComplexPipelined( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+private: -+ -+ CUTLASS_DEVICE -+ void warp_mma_planar_complex( -+ Operator & warp_mma, -+ FragmentC &accum, -+ WarpFragmentA const & real_A, -+ WarpFragmentA const & imag_A, -+ WarpFragmentB const & real_B, -+ WarpFragmentB const & imag_B) { -+ -+ cutlass::negate> neg_op_B; -+ -+ WarpFragmentB neg_real_B = neg_op_B(real_B); -+ WarpFragmentB neg_imag_B = neg_op_B(imag_B); -+ -+ warp_mma(accum.real, real_A, real_B, accum.real); -+ -+ if (kTransformB == ComplexTransform::kNone) { -+ warp_mma(accum.imag, real_A, imag_B, accum.imag); -+ } -+ else { -+ warp_mma(accum.imag, real_A, neg_imag_B, accum.imag); -+ } -+ -+ if (kTransformA == ComplexTransform::kNone) { -+ warp_mma(accum.imag, imag_A, real_B, accum.imag); -+ } -+ else { -+ warp_mma(accum.imag, imag_A, neg_real_B, accum.imag); -+ } -+ -+ if (kTransformA == ComplexTransform::kNone ^ kTransformB == ComplexTransform::kNone) { -+ warp_mma(accum.real, imag_A, imag_B, accum.real); -+ } -+ else { -+ warp_mma(accum.real, imag_A, neg_imag_B, accum.real); -+ } -+ } -+ -+public: -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA 
iterator_A_real, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A_imag, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B_real, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B_imag, -+ ///< initial value of accumulator -+ FragmentC const &src_accum) { -+ -+ // -+ // Prologue -+ // -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ FragmentA tb_frag_A_real; -+ FragmentA tb_frag_A_imag; -+ -+ FragmentB tb_frag_B_real; -+ FragmentB tb_frag_B_imag; -+ -+ tb_frag_A_real.clear(); -+ tb_frag_A_imag.clear(); -+ -+ tb_frag_B_real.clear(); -+ tb_frag_B_imag.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_A_real.load(tb_frag_A_real); -+ iterator_A_imag.load(tb_frag_A_imag); -+ -+ iterator_B_real.load(tb_frag_B_real); -+ iterator_B_imag.load(tb_frag_B_imag); -+ -+ ++iterator_A_real; -+ ++iterator_A_imag; -+ -+ ++iterator_B_real; -+ ++iterator_B_imag; -+ -+ this->smem_iterator_A_.store(tb_frag_A_real); -+ this->smem_iterator_A_.store_with_pointer_offset(tb_frag_A_imag, Base::SharedStorage::kImaginaryStrideA); -+ -+ this->smem_iterator_B_.store(tb_frag_B_real); -+ this->smem_iterator_B_.store_with_pointer_offset(tb_frag_B_imag, Base::SharedStorage::kImaginaryStrideB); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA warp_frag_real_A[2]; -+ WarpFragmentA warp_frag_imag_A[2]; -+ -+ WarpFragmentB warp_frag_real_B[2]; -+ WarpFragmentB warp_frag_imag_B[2]; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_real_A[0]); -+ this->warp_tile_iterator_A_.load_with_pointer_offset(warp_frag_imag_A[0], Base::SharedStorage::kImaginaryStrideA); -+ -+ this->warp_tile_iterator_B_.load(warp_frag_real_B[0]); -+ this->warp_tile_iterator_B_.load_with_pointer_offset(warp_frag_imag_B[0], Base::SharedStorage::kImaginaryStrideB); -+ -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ Operator warp_mma; -+ -+ int smem_write_stage_idx = 1; -+ -+ // Avoid reading out of bounds -+ iterator_A_real.clear_mask(gemm_k_iterations <= 1); -+ iterator_A_imag.clear_mask(gemm_k_iterations <= 1); -+ -+ iterator_B_real.clear_mask(gemm_k_iterations <= 1); -+ iterator_B_imag.clear_mask(gemm_k_iterations <= 1); -+ -+ // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing -+ // shared memory loads (which have the tighest latency requirement). -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > 0; --gemm_k_iterations) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. 
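The pipelined mainloop below keeps two tiles of each operand in shared memory and ping-pongs between them: while warps compute on one stage, the next global tile is stored into the other, then smem_write_stage_idx is toggled and the iterators are pulled back with negative tile offsets. A host-side sketch of that two-buffer schedule, with placeholder tile contents and sizes:

#include <cstdio>
#include <vector>

// Host model of a 2-stage (double-buffered) k-loop: compute on the buffer
// filled in the previous iteration while "loading" the next tile into the
// other buffer, then toggle the write index (smem_write_stage_idx ^= 1).
int main() {
  constexpr int kStages = 2;
  constexpr int kTileElems = 4;
  std::vector<int> smem(kStages * kTileElems);

  int write_stage = 1;                 // stage 0 was filled in the prologue
  int read_stage = 0;
  int gemm_k_iterations = 5;

  // Prologue: stage 0 holds tile 0.
  for (int e = 0; e < kTileElems; ++e) smem[e] = 0;

  for (int k = 1; k <= gemm_k_iterations; ++k) {
    // "Load" the next tile (if any) into the write stage.
    if (k < gemm_k_iterations)
      for (int e = 0; e < kTileElems; ++e)
        smem[write_stage * kTileElems + e] = k;

    // "Compute" on the read stage.
    std::printf("iter %d computes on stage %d (tile %d)\n",
                k, read_stage, smem[read_stage * kTileElems]);

    // Toggle stages, as the mainloop does with smem_write_stage_idx ^= 1.
    write_stage ^= 1;
    read_stage ^= 1;
  }
  return 0;
}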
-+ -+ if (warp_mma_k == Base::kWarpGemmIterations - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_A_.store(tb_frag_A_real); -+ this->smem_iterator_A_.store_with_pointer_offset(tb_frag_A_imag, Base::SharedStorage::kImaginaryStrideA); -+ -+ this->smem_iterator_B_.store(tb_frag_B_real); -+ this->smem_iterator_B_.store_with_pointer_offset(tb_frag_B_imag, Base::SharedStorage::kImaginaryStrideB); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_B_; -+ ++this->smem_iterator_A_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ } -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_real_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_A_.load_with_pointer_offset(warp_frag_imag_A[(warp_mma_k + 1) % 2], Base::SharedStorage::kImaginaryStrideA); -+ -+ this->warp_tile_iterator_B_.load(warp_frag_real_B[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load_with_pointer_offset(warp_frag_imag_B[(warp_mma_k + 1) % 2], Base::SharedStorage::kImaginaryStrideB); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k == 0) { -+ -+ iterator_A_real.load(tb_frag_A_real); -+ iterator_A_imag.load(tb_frag_A_imag); -+ -+ iterator_B_real.load(tb_frag_B_real); -+ iterator_B_imag.load(tb_frag_B_imag); -+ -+ ++iterator_A_real; -+ ++iterator_A_imag; -+ ++iterator_B_real; -+ ++iterator_B_imag; -+ -+ // Avoid reading out of bounds if this was the last loop iteration -+ iterator_A_real.clear_mask(gemm_k_iterations <= 2); -+ iterator_A_imag.clear_mask(gemm_k_iterations <= 2); -+ iterator_B_real.clear_mask(gemm_k_iterations <= 2); -+ iterator_B_imag.clear_mask(gemm_k_iterations <= 2); -+ } -+ -+ warp_mma_planar_complex( -+ warp_mma, -+ accum, -+ warp_frag_real_A[warp_mma_k % 2], -+ warp_frag_imag_A[warp_mma_k % 2], -+ warp_frag_real_B[warp_mma_k % 2], -+ warp_frag_imag_B[warp_mma_k % 2]); -+ } -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_singlestage.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_singlestage.h -new file mode 100644 -index 0000000..3ce8ac8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_singlestage.h -@@ -0,0 +1,265 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/aligned_buffer.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+ -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class MmaSingleStage : public MmaBase { -+public: -+ -+ ///< Base class -+ using Base = MmaBase; -+ -+ using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory -+ using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory -+ using ElementC = ElementC_; ///< Data type of accumulator matrix -+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix -+ using Policy = Policy_; ///< Policy describing tuning details -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of operand A loaded from global memory -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ using ArchTag = arch::Sm70; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ // staticaly assert kStages for MmaSingleStage is 1 (single stage mma pipeline) -+ static_assert((Base::kStages==1), "MmaSingleStage requires kStages set to value 1"); -+private: -+ -+ using WarpFragmentA = typename Operator::FragmentA; -+ using WarpFragmentB = typename Operator::FragmentB; -+ -+protected: -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaSingleStage( -+ typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ int thread_idx, ///< ID within the threadblock -+ int warp_idx, ///< ID of warp -+ int lane_idx ///< ID of each thread within a warp -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within 
the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ int gemm_k_iterations, ///< number of iterations of the mainloop -+ FragmentC &accum, ///< destination accumulator tile -+ IteratorA iterator_A, ///< iterator over A operand in global memory -+ IteratorB iterator_B, ///< iterator over B operand in global memory -+ FragmentC const &src_accum) { ///< source accumualtor tile -+ -+ // -+ // Prologue -+ // -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ FragmentA tb_frag_A; -+ FragmentB tb_frag_B; -+ -+ tb_frag_A.clear(); -+ tb_frag_B.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_A.load(tb_frag_A); -+ iterator_B.load(tb_frag_B); -+ -+ ++iterator_A; -+ ++iterator_B; -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA warp_frag_A; -+ WarpFragmentB warp_frag_B; -+ -+ Operator warp_mma; -+ -+ // Avoid reading out of bounds -+ iterator_A.clear_mask(gemm_k_iterations <= 1); -+ iterator_B.clear_mask(gemm_k_iterations <= 1); -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > 0; --gemm_k_iterations) { -+ this->smem_iterator_A_.store(tb_frag_A); -+ this->smem_iterator_B_.store(tb_frag_B); -+ -+ __syncthreads(); -+ -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. 
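Each of these constructors maps the linear warp_idx onto (m, n, k) coordinates within the threadblock before adding per-warp tile offsets. A host-side check of that decomposition for an assumed 4x2x2 warp arrangement:

#include <cassert>
#include <cstdio>

int main() {
  // Assumed warp arrangement within the threadblock (WarpCount): 4 x 2 x 2.
  constexpr int kWarpCountM = 4, kWarpCountN = 2, kWarpCountK = 2;

  for (int warp_idx = 0; warp_idx < kWarpCountM * kWarpCountN * kWarpCountK; ++warp_idx) {
    // Same arithmetic as the constructors: split off the K coordinate first,
    // then unpack M (fastest varying) and N from the remainder.
    int warp_idx_mn = warp_idx % (kWarpCountM * kWarpCountN);
    int warp_idx_k  = warp_idx / (kWarpCountM * kWarpCountN);
    int warp_idx_m  = warp_idx_mn % kWarpCountM;
    int warp_idx_n  = warp_idx_mn / kWarpCountM;

    // Re-linearize and confirm the mapping is a bijection.
    int rebuilt = (warp_idx_k * kWarpCountN + warp_idx_n) * kWarpCountM + warp_idx_m;
    assert(rebuilt == warp_idx);
    std::printf("warp %2d -> (m=%d, n=%d, k=%d)\n", warp_idx, warp_idx_m, warp_idx_n, warp_idx_k);
  }
  return 0;
}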
-+ -+ this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A); -+ this->warp_tile_iterator_B_.load(warp_frag_B); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ warp_mma(accum, warp_frag_A, warp_frag_B, accum); -+ } -+ -+ // Add negative offsets to return smem load iterators to the 'start' of the shared memory -+ this->warp_tile_iterator_A_.add_tile_offset({0, -Policy::kPartitionsK * Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); -+ -+ __syncthreads(); -+ -+ iterator_A.load(tb_frag_A); -+ iterator_B.load(tb_frag_B); -+ -+ ++iterator_A; -+ ++iterator_B; -+ -+ // Avoid reading out of bounds if this was the last loop iteration -+ iterator_A.clear_mask(gemm_k_iterations <= 2); -+ iterator_B.clear_mask(gemm_k_iterations <= 2); -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h -new file mode 100644 -index 0000000..905283e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h -@@ -0,0 +1,751 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+ -+ It loads two loop invariant vectors, norm and sum, in the prologue and -+ stores them in the register file. We will call elementwise operation to -+ apply norm and sum between ldmatrix and warp mma. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+#include "cutlass/gemm/warp/softmax_scale_bias_transform.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaMainloopFusionBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. 
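The fused mainloop below applies the loop-invariant norm and sum vectors to one operand fragment between the shared-memory load and the warp MMA; the actual functor is SoftmaxScaleBiasTransform from softmax_scale_bias_transform.h. The sketch below only illustrates the usual softmax-style rewrite, exp(x - norm) / sum, on a plain array; it does not reproduce that functor's real fragment layout or arithmetic.

#include <array>
#include <cmath>
#include <cstddef>
#include <cstdio>

// Illustrative element-wise normalization applied to a loaded fragment:
// each element is rewritten as exp(x - norm) / sum, where norm and sum are
// loop-invariant per-row values loaded once in the prologue. This mirrors
// the role of the transform, not its exact code.
template <std::size_t kElements>
void apply_norm_sum(std::array<float, kElements>& frag, float norm, float sum) {
  for (float& x : frag)
    x = std::exp(x - norm) / sum;
}

int main() {
  std::array<float, 4> frag{1.f, 2.f, 3.f, 4.f};
  float norm = 4.f;                       // e.g. the row maximum
  float sum = 0.f;
  for (float x : frag) sum += std::exp(x - norm);

  apply_norm_sum(frag, norm, sum);

  float total = 0.f;
  for (float x : frag) total += x;
  std::printf("normalized row sums to %.3f\n", total);   // ~1.000
  return 0;
}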
-+ using WarpGemm = typename Policy::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount = cutlass::gemm::GemmShape; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations = -+ (WarpGemm::kK / Operator::Policy::MmaShape::kK); -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Tensor reference to the A operand -+ using TensorRefA = TensorRef; -+ -+ /// Tensor reference to the B operand -+ using TensorRefB = TensorRef; -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ class SharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ -+ /// Shape of the A matrix operand in shared memory -+ using ShapeA = MatrixShape; -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = -+ MatrixShape; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for A operand -+ AlignedBuffer operand_A; -+ -+ /// Buffer for B operand -+ AlignedBuffer operand_B; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the A matrix -+ CUTLASS_DEVICE -+ static typename Operator::LayoutA LayoutA() { -+ return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); -+ } -+ -+ /// Returns a layout object for the B matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutB LayoutB() { -+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the A operand -+ CUTLASS_HOST_DEVICE -+ TensorRefA operand_A_ref() { -+ return TensorRefA{operand_A.data(), LayoutA()}; -+ } -+ -+ /// Returns a TensorRef to the B operand -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B_ref() { -+ return TensorRefB{operand_B.data(), LayoutB()}; -+ } -+ }; -+ -+ protected: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A operand from shared memory -+ typename Operator::IteratorA warp_tile_iterator_A_; -+ -+ /// Iterator to load a warp-scoped tile of B operand from shared memory -+ typename Operator::IteratorB warp_tile_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaMainloopFusionBase( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), -+ warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} -+}; -+ -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. 
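SharedStorage above sizes one buffer per operand large enough to hold all kStages k-tiles plus the policy's padding, and exposes them through packed TensorRefs. A rough host-side footprint estimate under assumed tile sizes and element width follows; padding is ignored here and the real shapes come from the MatrixShape definitions above.

#include <cstddef>
#include <cstdio>

// Rough shared-memory footprint for a multistage mainloop, assuming operand A
// occupies M x (K * Stages) elements and operand B occupies (K * Stages) x N
// elements; the policy's skew padding is ignored. All numbers are assumptions
// for illustration.
int main() {
  constexpr int kM = 128, kN = 128, kK = 32;
  constexpr int kStages = 3;
  constexpr int kBytesPerElement = 2;          // e.g. half precision

  std::size_t elems_A = std::size_t(kM) * (kK * kStages);
  std::size_t elems_B = std::size_t(kK * kStages) * kN;
  std::size_t bytes = (elems_A + elems_B) * kBytesPerElement;

  std::printf("A: %zu elems, B: %zu elems, total smem ~ %zu KiB\n",
              elems_A, elems_B, bytes / 1024);
  return 0;
}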
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Iterates over vectors of var and mean vector in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorNormSum_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Whether problem has been transformed. This determines to which operand -+ /// the softmax is applied. -+ bool InternalTranspose, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaSoftmaxMainloopFusionMultistage : -+ public MmaMainloopFusionBase { -+public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Iterates over tiles of the var and mean vectors in global memory -+ using IteratorNormSum = IteratorNormSum_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ ///< Base class -+ using Base = MmaMainloopFusionBase; -+ -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ /// Internal structure exposed for introspection. 
-+ struct Detail { -+ -+ static_assert(Base::kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA = -+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ using WarpLoadedFragmentNormSum = typename IteratorNormSum::Fragment; -+ -+ static bool const kInternalTranspose = InternalTranspose; -+ -+ using SoftmaxFragment = typename platform::conditional::type; -+ -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+ int warp_idx_m_; -+ -+ int warp_idx_n_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaSoftmaxMainloopFusionMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ warp_idx_m_ = warp_idx_mn % Base::WarpCount::kM; -+ warp_idx_n_ = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m_, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n_}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance(IteratorA &iterator_A, -+ IteratorB &iterator_B, -+ int group_start_A = 0, int group_start_B = 0) { -+ iterator_A.set_iteration_index(group_start_A * -+ IteratorA::kAccessesPerVector); -+ 
this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_A.get(); -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ iterator_B.set_iteration_index(group_start_B * -+ IteratorB::kAccessesPerVector); -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B.get(); -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_B.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B.valid()); -+ } -+ -+ ++iterator_B; -+ } -+ ++this->smem_iterator_B_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ ///< iterator over B operand in global memory -+ IteratorNormSum iterator_norm_sum, -+ ///< initial value of accumulator -+ FragmentC const &src_accum) { -+ -+ // -+ // Prologue -+ // -+ // Issue several complete stages -+ -+ WarpLoadedFragmentNormSum warp_loaded_frag_norm_sum; -+ iterator_norm_sum.add_tile_offset({0, warp_idx_m_}); -+ iterator_norm_sum.load(warp_loaded_frag_norm_sum); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations) { -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ int src_bytes = (iterator_A.valid() ? 
kSrcBytes : 0); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ iterator_B.set_iteration_index(0); -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ } -+ -+ ++this->smem_iterator_B_; -+ } -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B[2]; -+ -+ Operator warp_mma; -+ cutlass::gemm::warp::SoftmaxScaleBiasTransform< -+ SoftmaxFragment, WarpLoadedFragmentNormSum> elementwise_transform; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ // Start issuing the first group of the next stage outside of the mainloop -+ copy_tiles_and_advance(iterator_A, iterator_B); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B[0]); -+ -+ if (kInternalTranspose) { -+ elementwise_transform(warp_transformed_frag_B[0], -+ warp_loaded_frag_norm_sum); -+ } else { -+ elementwise_transform(warp_transformed_frag_A[0], -+ warp_loaded_frag_norm_sum); -+ } -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. 
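copy_tiles_and_advance above spreads the cp.async traffic for one stage across the warp-level MMA iterations: each call issues one group of ceil(AsyncCopyIterationsPerStage / kWarpGemmIterations) copies, with a guard so the final group does not run past the stage. A host sketch of that partitioning, with assumed counts:

#include <cstdio>

// Host model of how one stage's async copies are split into groups, one group
// issued per warp-level MMA iteration. The counts below are illustrative.
int main() {
  constexpr int kAsyncCopyIterationsPerStage = 10;   // copies needed per stage (assumed)
  constexpr int kWarpGemmIterations = 4;             // warp MMAs per stage (assumed)

  // Same ceil-division as Detail::kAccessesPerGroupA/B.
  constexpr int kAccessesPerGroup =
      (kAsyncCopyIterationsPerStage + kWarpGemmIterations - 1) / kWarpGemmIterations;

  int issued = 0;
  for (int warp_mma_k = 0; warp_mma_k < kWarpGemmIterations; ++warp_mma_k) {
    int group_start = warp_mma_k * kAccessesPerGroup;
    for (int j = 0; j < kAccessesPerGroup; ++j) {
      // Same guard as copy_tiles_and_advance: skip copies past the stage.
      if (group_start + j < kAsyncCopyIterationsPerStage) {
        ++issued;
      }
    }
    int group_end = group_start + kAccessesPerGroup;
    if (group_end > kAsyncCopyIterationsPerStage) group_end = kAsyncCopyIterationsPerStage;
    std::printf("warp_mma_k=%d issues copies [%d, %d)\n", warp_mma_k, group_start, group_end);
  }
  std::printf("total issued = %d (expected %d)\n", issued, kAsyncCopyIterationsPerStage);
  return 0;
}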
-+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k > 0) { -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B[warp_mma_k % 2]); -+ -+ if (kInternalTranspose) { -+ elementwise_transform(warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_norm_sum); -+ } else { -+ elementwise_transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_norm_sum); -+ } -+ } -+ -+ // Issue global->shared copies for the next stage -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) { -+ group_start_iteration_A = 0; -+ group_start_iteration_B = 0; -+ } else { -+ group_start_iteration_A = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ } -+ -+ copy_tiles_and_advance(iterator_A, iterator_B, -+ group_start_iteration_A, -+ group_start_iteration_B); -+ -+ warp_mma( -+ accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ accum -+ ); -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) { -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ if (kInternalTranspose) { -+ elementwise_transform(warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_norm_sum); -+ } else { -+ elementwise_transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_norm_sum); -+ } -+ } -+ } 
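At the stage boundary handled above, the write and read stage indices walk through the kStages-deep circular buffer and, on reaching the last slot, the corresponding iterators are rewound with a negative tile offset of kStages. A host sketch of that bookkeeping:

#include <cstdio>

// Host model of the circular-buffer bookkeeping at each stage boundary:
// advance the write/read stage indices modulo kStages and note when the
// iterators would be rewound with a negative tile offset.
int main() {
  constexpr int kStages = 3;
  int smem_write_stage_idx = kStages - 1;   // prologue already filled kStages-1 slots
  int smem_read_stage_idx = 0;

  for (int boundary = 0; boundary < 6; ++boundary) {
    bool rewound_write = false, rewound_read = false;

    if (smem_write_stage_idx == kStages - 1) {
      smem_write_stage_idx = 0;             // smem iterators get add_tile_offset(-kStages) here
      rewound_write = true;
    } else {
      ++smem_write_stage_idx;
    }

    if (smem_read_stage_idx == kStages - 1) {
      smem_read_stage_idx = 0;              // warp tile iterators are rewound likewise
      rewound_read = true;
    } else {
      ++smem_read_stage_idx;
    }

    std::printf("boundary %d: write=%d%s read=%d%s\n", boundary,
                smem_write_stage_idx, rewound_write ? " (rewound)" : "",
                smem_read_stage_idx, rewound_read ? " (rewound)" : "");
  }
  return 0;
}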
-+ -+ } -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ // commit and drain all pending and predicated cp.async pnz from the GEMM mainloop -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_sparse_base.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_sparse_base.h -new file mode 100644 -index 0000000..9f82a7f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_sparse_base.h -@@ -0,0 +1,273 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Policy object describing MmaTensorOp -+template < -+ /// Warp-level GEMM operator (concept: gemm::warp::Mma) -+ typename Operator_, -+ /// Padding used for A operand in shared memory (concept: MatrixShape) -+ typename SmemPaddingA_, -+ /// Padding used for B operand in shared memory (concept: MatrixShape) -+ typename SmemPaddingB_, -+ /// Padding used for E operand in shared memory (concept: MatrixShape) -+ typename SmemPaddingE_, -+ /// Number of partitions of K dimension of GEMM -+ int PartitionsK = 1> -+struct SparseMmaPolicy { -+ /// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt) -+ using Operator = Operator_; -+ -+ /// Padding used for A operand in shared memory -+ using SmemPaddingA = SmemPaddingA_; -+ -+ /// Padding used for B operand in shared memory -+ using SmemPaddingB = SmemPaddingB_; -+ -+ /// Padding used for B operand in shared memory -+ using SmemPaddingE = SmemPaddingE_; -+ -+ /// Number of partitions of K dimension -+ static int const kPartitionsK = PartitionsK; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class SparseMmaBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. 
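SparseMmaBase below adds a third shared-memory buffer, E, for sparsity metadata: operand A is stored compressed to K/kSparse columns while E records which elements of each group are kept (kSparse is typically 2, i.e. 2:4 structured sparsity on Sm80). The host sketch below compresses one row into values plus 2-bit position indices; the packing is illustrative and not the exact metadata encoding the sparse tensor core instructions consume.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative 2:4 structured-sparsity compression of one row: out of every
// four elements, keep the two non-zeros and record their positions as 2-bit
// indices. This is a simplification of the real metadata format.
int main() {
  const std::vector<float> dense = {0.f, 1.5f, 0.f, -2.f,   3.f, 0.f, 0.25f, 0.f};

  std::vector<float> compressed;          // K/2 values (the compressed "A" buffer)
  std::vector<std::uint8_t> metadata;     // one byte per group of 4 (the "E" buffer)

  for (std::size_t g = 0; g < dense.size(); g += 4) {
    std::uint8_t meta = 0;
    int kept = 0;
    for (int i = 0; i < 4 && kept < 2; ++i) {
      if (dense[g + i] != 0.f) {
        compressed.push_back(dense[g + i]);
        meta |= std::uint8_t(i) << (2 * kept);   // 2-bit position per kept value
        ++kept;
      }
    }
    metadata.push_back(meta);
  }

  std::printf("dense %zu -> compressed %zu values, %zu metadata bytes\n",
              dense.size(), compressed.size(), metadata.size());
  return 0;
}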
-+ using WarpGemm = typename Policy::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount = GemmShape; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations = -+ (WarpGemm::kK / Operator::Policy::MmaShape::kK); -+ -+ static_assert(kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ static_assert((kWarpGemmIterations % 2) == 0, -+ "Inner loop iteration must be an even number."); -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ static int const kSparse = Operator::kSparse; -+ -+ static int const kElementsPerElementE = Operator::kElementsPerElementE; -+ -+ /// Tensor reference to the A operand -+ using TensorRefA = TensorRef; -+ -+ /// Tensor reference to the B operand -+ using TensorRefB = TensorRef; -+ -+ /// Tensor reference to the E operand -+ using TensorRefE = TensorRef; -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ class SharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ -+ /// Shape of the A matrix operand in shared memory -+ using ShapeA = MatrixShape; -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = -+ MatrixShape; -+ -+ /// Shape of the E matrix operand in shared memory -+ using ShapeE = -+ MatrixShape; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for A operand -+ AlignedBuffer operand_A; -+ -+ /// Buffer for B operand -+ AlignedBuffer operand_B; -+ -+ /// Buffer for E operand -+ AlignedBuffer operand_E; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the A matrix -+ CUTLASS_DEVICE -+ static typename Operator::LayoutA LayoutA() { -+ return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); -+ } -+ -+ /// Returns a layout object for the B matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutB LayoutB() { -+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); -+ } -+ -+ /// Returns a layout object for the E matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutE LayoutE() { -+ return Operator::LayoutE::packed({ShapeE::kRow, ShapeE::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the A operand -+ CUTLASS_HOST_DEVICE -+ TensorRefA operand_A_ref() { -+ return TensorRefA{operand_A.data(), LayoutA()}; -+ } -+ -+ /// Returns a TensorRef to the B operand -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B_ref() { -+ return TensorRefB{operand_B.data(), LayoutB()}; -+ } -+ -+ /// Returns a TensorRef to the E operand -+ CUTLASS_HOST_DEVICE -+ TensorRefE operand_E_ref() { -+ return TensorRefE{operand_E.data(), LayoutE()}; -+ } -+ }; -+ -+ protected: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A operand from shared memory -+ typename Operator::IteratorA warp_tile_iterator_A_; -+ -+ /// Iterator to load a warp-scoped tile of B operand from shared memory -+ typename Operator::IteratorB warp_tile_iterator_B_; -+ -+ /// Iterator to load a warp-scoped tile of E operand from shared memory -+ typename Operator::IteratorE warp_tile_iterator_E_; -+ -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ SparseMmaBase( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ 
warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), -+ warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx), -+ warp_tile_iterator_E_(shared_storage.operand_E_ref(), lane_idx) { -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_sparse_multistage.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_sparse_multistage.h -new file mode 100644 -index 0000000..beb58c8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_sparse_multistage.h -@@ -0,0 +1,662 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/threadblock/mma_sparse_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. 
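// Context for the class declared below (an annotation inferred from the code
// that follows, not text from the upstream header): SparseMmaMultistage is the
// multistage mainloop for sparse tensor-core GEMM. Besides the usual A and B
// tiles it stages a third operand E through shared memory -- the
// structured-sparsity metadata consumed by the warp-level sparse MMA (note
// kSparse, kMetaSizeInBits and kElementsPerElementE below). E is small
// compared to A and B, so only the first Detail::kValidWarps warps issue
// cp.async copies for it; the remaining warps pass a false predicate
// (is_warp_valid_) and skip that global-memory traffic.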
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Iterates over tiles of E operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorE_, -+ /// Iterates over tiles of E operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorE_, -+ /// Cache operation for operand E -+ cutlass::arch::CacheOperation::Kind CacheOpE, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class SparseMmaMultistage : -+ public SparseMmaBase { -+public: -+ ///< Base class -+ using Base = SparseMmaBase; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Iterates over tiles of E operand in global memory -+ using IteratorE = IteratorE_; -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ using SmemIteratorE = SmemIteratorE_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpE = CacheOpE; -+ -+ static int const kSparse = Policy::Operator::kSparse; -+ static int const kMetaSizeInBits = Policy::Operator::kMetaSizeInBits; -+ static int const kMaxID2 = Policy::Operator::kMaxID2; -+ static int const kElementsPerElementE = -+ Policy::Operator::kElementsPerElementE; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// ElementE -+ using ElementE = typename IteratorE::Element; -+ -+ /// LayoutE -+ using LayoutE = typename IteratorE::Layout; -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const 
kTransformB = Operator::kTransformB; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ -+ /// Number of async copies to load one stage of operand A -+ static int const TBLoadIterationsA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of async copies to load one stage of operand B -+ static int const TBLoadIterationsB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of async copies to load one stage of operand E -+ static int const TBLoadIterationsE = -+ IteratorE::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of async copies to load one group of operand A -+ static int const kAccessesPerGroupA = -+ (TBLoadIterationsA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of async copies to load one group of operand B -+ static int const kAccessesPerGroupB = -+ (TBLoadIterationsB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of async copies to load one group of operand E -+ static int const kAccessesPerGroupE = -+ (TBLoadIterationsE + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// E operand is tiny. For the most of time, not all the warps are needed -+ /// to load it from the global memory. -+ static int const kValidWarps = IteratorE::ThreadMap::kThreads / 32; -+ -+ /// B operand is twice as big as A which brings very high register pressure. -+ /// We have to sacrifice the double buffer when the warp tile size is big. -+ static int const kBBufferSize = -+ ((sizeof(typename Operator::ElementC) == 4) && -+ ((platform::is_same::value && -+ platform::is_same::value)) && -+ (Operator::Shape::kM >= 64 && Operator::Shape::kN >= 64)) -+ ? 1 -+ : 2; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ using WarpFragmentE = typename Operator::FragmentE; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+ /// Iterator to write threadblock-scoped tile of E operand to shared memory -+ SmemIteratorE smem_iterator_E_; -+ -+ /// Warp id -+ bool is_warp_valid_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ SparseMmaMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), -+ smem_iterator_E_(shared_storage.operand_E_ref(), thread_idx) -+ { -+ is_warp_valid_ = warp_idx < Detail::kValidWarps; -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the 
threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ this->warp_tile_iterator_E_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, -+ IteratorE &iterator_E, int group_start_A = 0, -+ int group_start_B = 0, int group_start_E = 0) { -+ iterator_A.set_iteration_index(group_start_A * -+ IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // async copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ if (group_start_A + j < Detail::TBLoadIterationsA) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_A.get(); -+ -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ iterator_B.set_iteration_index(group_start_B * -+ IteratorB::kAccessesPerVector); -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // async copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::TBLoadIterationsB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B.get(); -+ -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B.valid()); -+ -+ ++iterator_B; -+ } -+ ++this->smem_iterator_B_; -+ } -+ } -+ -+ iterator_E.set_iteration_index(group_start_E); -+ this->smem_iterator_E_.set_iteration_index(group_start_E); -+ -+ // async copy for operand E -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupE; ++j) { -+ if (group_start_E + j < Detail::TBLoadIterationsE) { -+ typename IteratorE::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_E_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorE::ThreadMap::kElementsPerAccess / 8; -+ -+ auto gmem_ptr = iterator_E.get(); -+ -+ cutlass::arch::cp_async( -+ dst_ptr, gmem_ptr, iterator_E.valid() && is_warp_valid_); -+ -+ ++iterator_E; -+ ++this->smem_iterator_E_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory 
-+ IteratorB iterator_B, -+ ///< iterator over E operand in global memory -+ IteratorE iterator_E, -+ ///< initial value of accumulator -+ FragmentC const &src_accum) { -+ -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations) { -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ iterator_E.clear_mask(gemm_k_iterations == 0); -+ -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // async copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ iterator_B.set_iteration_index(0); -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // async copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ } -+ -+ ++this->smem_iterator_B_; -+ } -+ -+ iterator_E.set_iteration_index(0); -+ this->smem_iterator_E_.set_iteration_index(0); -+ -+ // async copy for operand E -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsE; ++j) { -+ typename IteratorE::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_E_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorE::ThreadMap::kElementsPerAccess / 8; -+ if (is_warp_valid_) -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_E.get(), iterator_E.valid()); -+ -+ ++iterator_E; -+ -+ ++this->smem_iterator_E_; -+ } -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ iterator_E.add_tile_offset({0, 1}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ this->smem_iterator_E_.add_tile_offset({0, 1}); -+ -+ // cp.async.commit_group - completes a stage -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B[Detail::kBBufferSize]; -+ WarpTransformedFragmentA warp_transformed_frag_A[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B[Detail::kBBufferSize]; -+ WarpFragmentE warp_frag_E[2]; -+ -+ Operator warp_mma; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ this->warp_tile_iterator_E_.set_kgroup_index(0); -+ -+ 
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ this->warp_tile_iterator_E_.load(warp_frag_E[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ ++this->warp_tile_iterator_E_; -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ iterator_E.clear_mask(gemm_k_iterations == 0); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B[0]); -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_E_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_E_.load(warp_frag_E[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_E_; -+ -+ if (Detail::kBBufferSize == 2) { -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.load( -+ warp_loaded_frag_B[(warp_mma_k + 1) % Detail::kBBufferSize]); -+ ++this->warp_tile_iterator_B_; -+ } -+ -+ if (warp_mma_k > 0) -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % Detail::kBBufferSize], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B[warp_mma_k % Detail::kBBufferSize]); -+ -+ warp_mma( -+ accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % Detail::kBBufferSize], accum, -+ warp_frag_E[warp_mma_k % 2] -+ ); -+ -+ if (Detail::kBBufferSize == 1) { -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ ++this->warp_tile_iterator_B_; -+ -+ } -+ -+ // Issue global->shared copies for the this stage -+ if (warp_mma_k < Base::kWarpGemmIterations - 1) { -+ int group_start_iteration_A, group_start_iteration_B, group_start_iteration_E; -+ -+ group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; -+ group_start_iteration_E = warp_mma_k * Detail::kAccessesPerGroupE; -+ -+ copy_tiles_and_advance( -+ iterator_A, iterator_B, iterator_E, group_start_iteration_A, -+ group_start_iteration_B, group_start_iteration_E); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ int group_start_iteration_A, group_start_iteration_B, group_start_iteration_E; -+ group_start_iteration_A = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ group_start_iteration_E = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupE; -+ -+ copy_tiles_and_advance( -+ iterator_A, iterator_B, iterator_E, 
group_start_iteration_A, -+ group_start_iteration_B, group_start_iteration_E); -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ iterator_E.add_tile_offset({0, 1}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ this->smem_iterator_E_.add_tile_offset({0, 1}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ this->smem_iterator_E_.add_tile_offset({0, -Base::kStages}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ this->warp_tile_iterator_E_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ iterator_E.clear_mask(gemm_k_iterations == 0); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ } -+ -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h -new file mode 100644 -index 0000000..fb0e92e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h -@@ -0,0 +1,547 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/threadblock/mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. 
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaWithReductionMultistage : -+ public MmaBase { -+public: -+ ///< Base class -+ using Base = MmaBase; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ using FragmentReduction = typename Operator::FragmentReduction; -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ static int const kReduceKForA = Operator::kReduceKForA; -+ -+ /// Internal structure exposed for introspection. 
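// Annotations on the surrounding declarations (inferred from this file, not
// upstream comments):
//  * Compared with the plain multistage mainloop, operator() here threads an
//    extra accumulator through the warp-level MMA: a FragmentReduction
//    (gemm_k_reduction_accum) that accumulates a reduction along the K
//    dimension, with Operator::kReduceKForA indicating which operand feeds it.
//  * In Detail below, kAccessesPerGroupA/B divide the per-stage cp.async count
//    (AsyncCopyIterationsPerStageA/B) across the kWarpGemmIterations
//    warp-level MMA steps of one stage, rounding up. For example (numbers
//    purely illustrative), 8 copies per stage spread over 4 MMA steps gives
//    (8 + 4 - 1) / 4 = 2 copies issued alongside each warp-level MMA.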
-+ struct Detail { -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA = -+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaWithReductionMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, -+ int group_start_A = 0, int group_start_B = 0) { -+ iterator_A.set_iteration_index(group_start_A * -+ IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ 
IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_A.get(); -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ iterator_B.set_iteration_index(group_start_B * -+ IteratorB::kAccessesPerVector); -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B.get(); -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_B.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B.valid()); -+ } -+ -+ ++iterator_B; -+ } -+ ++this->smem_iterator_B_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ ///< initial value of accumulator -+ FragmentC const &src_accum, -+ FragmentReduction &gemm_k_reduction_accum) { -+ -+ // -+ // Prologue -+ // -+ // Issue several complete stages -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations) { -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ int src_bytes = (iterator_A.valid() ? 
kSrcBytes : 0); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ iterator_B.set_iteration_index(0); -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ } -+ -+ ++this->smem_iterator_B_; -+ } -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B[2]; -+ -+ Operator warp_mma; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B[0]); -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. 
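// At this point the prologue has already queued cp.async copies for the first
// (kStages - 1) k-tiles, and the k-group-0 fragments have been loaded and
// transformed, so the mainloop starts on math immediately. The
// (warp_mma_k + 1) % 2 indexing below is a two-slot ping-pong of register
// fragments: while warp_mma consumes the fragments at index warp_mma_k % 2,
// the shared-memory loads for the next k-group fill the other slot, letting
// the loads overlap the math instructions.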
-+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k > 0) -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B[warp_mma_k % 2]); -+ -+ warp_mma( -+ accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ accum, -+ gemm_k_reduction_accum -+ ); -+ -+ // Issue global->shared copies for the this stage -+ if (warp_mma_k < Base::kWarpGemmIterations - 1) { -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, -+ group_start_iteration_B); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ int group_start_iteration_A, group_start_iteration_B; -+ group_start_iteration_A = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, -+ group_start_iteration_B); -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ } -+ -+ } -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ // commit and drain all pending and predicated cp.async pnz from the GEMM mainloop -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ } 
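The stage bookkeeping above (smem_write_stage_idx starting at Base::kStages - 1, smem_read_stage_idx at 0, each reset to 0 once it reaches kStages - 1) implements a circular buffer of kStages shared-memory tiles: the prologue fills stages 0 .. kStages - 2, and every mainloop k-tile consumes one stage while cp.async refills another. A minimal host-side sketch of that schedule, using arbitrary example values for kStages and the k-tile count rather than anything taken from this patch:

    #include <cstdio>

    int main() {
      int const kStages = 3;     // illustrative pipeline depth
      int write = kStages - 1;   // next stage to be filled by cp.async
      int read  = 0;             // stage currently consumed by the warp MMAs

      // Replay eight k-tiles of the circular schedule used by the mainloop.
      for (int k_tile = 0; k_tile < 8; ++k_tile) {
        std::printf("k_tile %d: mma reads stage %d, cp.async fills stage %d\n",
                    k_tile, read, write);

        // Advance both indices, wrapping at the end of the circular buffer,
        // mirroring the `== kStages - 1 ? reset : ++` logic above.
        write = (write == kStages - 1) ? 0 : write + 1;
        read  = (read  == kStages - 1) ? 0 : read  + 1;
      }
      return 0;
    }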
-+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle.h -new file mode 100644 -index 0000000..48c1737 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle.h -@@ -0,0 +1,459 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Implements several possible threadblock-swizzling functions mapping blockIdx to -+ GEMM problems. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/platform/platform.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/gemm/threadblock/index_remat.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle_streamk.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Threadblock swizzling function for GEMMs -+template -+struct GemmIdentityThreadblockSwizzle { -+ -+ CUTLASS_HOST_DEVICE -+ GemmIdentityThreadblockSwizzle() { } -+ -+ /// Returns the shape of the problem in units of logical tiles -+ /// *Gemm* problem size: gemm(M, N, K) -+ CUTLASS_HOST_DEVICE -+ GemmCoord get_tiled_shape( -+ GemmCoord problem_size, -+ GemmCoord tile_size, -+ int split_k_slices) const { -+ -+ return GemmCoord( -+ (problem_size.m() + tile_size.m() - 1) / tile_size.m(), -+ (problem_size.n() + tile_size.n() - 1) / tile_size.n(), -+ split_k_slices); -+ } -+ -+ /// Returns the shape of the problem in units of logical tiles -+ /// *ImplicitGemm* Conv2d problem size: conv_operator(NPQK, NHWC, KRSC) -+ CUTLASS_HOST_DEVICE -+ GemmCoord get_tiled_shape( -+ cutlass::conv::Operator conv_operator, -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ GemmCoord tile_size, -+ int split_k_slices) const { -+ -+ gemm::GemmCoord implicit_gemm_problem_size = -+ cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); -+ -+ return get_tiled_shape( -+ implicit_gemm_problem_size, tile_size, split_k_slices); -+ } -+ -+ /// Returns the shape of the problem in units of logical tiles -+ /// *ImplicitGemm* Conv3d problem size: conv_operator(NZPQK, NDHWC, KTRSC) -+ CUTLASS_HOST_DEVICE -+ GemmCoord get_tiled_shape( -+ cutlass::conv::Operator conv_operator, -+ cutlass::conv::Conv3dProblemSize const &problem_size, -+ GemmCoord tile_size, -+ int split_k_slices) const { -+ -+ gemm::GemmCoord implicit_gemm_problem_size = -+ cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); -+ -+ return get_tiled_shape( -+ implicit_gemm_problem_size, tile_size, split_k_slices); -+ } -+ -+ /// Computes CUDA grid dimensions given a size in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ dim3 get_grid_shape(GemmCoord tiled_shape) const { -+ int tile = 1 << get_log_tile(tiled_shape); -+ return dim3(tiled_shape.m() * tile, (tiled_shape.n() + tile - 1) / tile, tiled_shape.k()); -+ } -+ -+ /// Calculates optimal swizzle width -+ CUTLASS_HOST_DEVICE -+ int get_log_tile(GemmCoord tiled_shape) const { -+ auto n = tiled_shape.n(); -+ // Thresholds picked so that it doesn't cause too many no-op CTAs -+ if (N >= 8 && n >= 6) -+ return 3; -+ else if (N >= 4 && n >= 3) -+ return 2; -+ else if (N >= 2 && n >= 2) -+ return 1; -+ else -+ return 0; -+ } -+ -+ /// Obtains the threadblock offset (in units of threadblock-scoped tiles) -+ CUTLASS_DEVICE -+ GemmCoord get_tile_offset(int log_tile) const { -+ int block_idx_x = RematerializeBlockIdxX(); -+ int block_idx_y = RematerializeBlockIdxY(); -+ int block_idx_z = RematerializeBlockIdxZ(); -+ -+ return GemmCoord{(block_idx_x >> log_tile), // -+ (block_idx_y << log_tile) + ((block_idx_x) & ((1 << (log_tile)) - 1)), -+ block_idx_z}; -+ } -+ -+ /// Obtains the threadblock offset (in units of 
threadblock-scoped tiles) -+ CUTLASS_DEVICE -+ GemmCoord get_tile_offset(GemmCoord tiled_shape) const { -+ -+ int const kTile = N; -+ int block_idx_x = RematerializeBlockIdxX(); -+ int block_idx_y = RematerializeBlockIdxY(); -+ -+ if ((tiled_shape.m() < kTile) || (tiled_shape.n() < kTile)) -+ return GemmCoord{block_idx_x, block_idx_y, RematerializeBlockIdxZ()}; -+ -+ return GemmCoord{ -+ (block_idx_x / kTile), -+ (block_idx_y * kTile) + (block_idx_x % kTile), -+ RematerializeBlockIdxZ() -+ }; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Threadblock swizzling function for GEMMs -+struct GemmHorizontalThreadblockSwizzle { -+ -+ CUTLASS_HOST_DEVICE -+ GemmHorizontalThreadblockSwizzle() { } -+ -+ /// Returns the shape of the problem in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ GemmCoord get_tiled_shape( -+ GemmCoord problem_size, -+ GemmCoord tile_size, -+ int split_k_slices) const { -+ -+ return GemmCoord( -+ (problem_size.m() + tile_size.m() - 1) / tile_size.m(), -+ (problem_size.n() + tile_size.n() - 1) / tile_size.n(), -+ split_k_slices); -+ } -+ -+ /// Computes CUDA grid dimensions given a size in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ dim3 get_grid_shape(GemmCoord tiled_shape) const { -+ return dim3(tiled_shape.n(), tiled_shape.m(), tiled_shape.k()); -+ } -+ -+ /// Calculates optimal swizzle width -+ CUTLASS_HOST_DEVICE -+ int get_log_tile(GemmCoord tiled_shape) const { -+ return 0; -+ } -+ -+ /// Obtains the threadblock offset (in units of threadblock-scoped tiles) -+ CUTLASS_DEVICE -+ GemmCoord get_tile_offset(GemmCoord tiled_shape) const { -+ return GemmCoord{ -+ RematerializeBlockIdxY(), -+ RematerializeBlockIdxX(), -+ RematerializeBlockIdxZ() -+ }; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Threadblock swizzling function for batched GEMMs -+struct GemmBatchedIdentityThreadblockSwizzle { -+ -+ /// Returns the shape of the problem in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ GemmCoord get_tiled_shape( -+ GemmCoord problem_size, -+ GemmCoord tile_size, -+ int batch_count) const { -+ -+ return GemmCoord( -+ (problem_size.m() + tile_size.m() - 1) / tile_size.m(), -+ (problem_size.n() + tile_size.n() - 1) / tile_size.n(), -+ batch_count % (1 << 16)); -+ } -+ -+ /// Computes CUDA grid dimensions given a size in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ dim3 get_grid_shape(GemmCoord tiled_shape) const { -+ return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k()); -+ } -+ -+ /// Calculates optimal swizzle width -+ CUTLASS_HOST_DEVICE -+ int get_log_tile(GemmCoord tiled_shape) const { -+ return 0; -+ } -+ -+ /// Obtains the threadblock offset (in units of threadblock-scoped tiles) -+ CUTLASS_DEVICE -+ GemmCoord get_tile_offset(GemmCoord tiled_shape) const { -+ return GemmCoord{ -+ RematerializeBlockIdxX(), -+ RematerializeBlockIdxY(), -+ RematerializeBlockIdxZ() -+ }; -+ } -+ -+ /// Obtains the threadblock offset (in units of threadblock-scoped tiles) -+ CUTLASS_DEVICE -+ GemmCoord get_tile_offset(int log_tile) const { -+ int block_idx_x = RematerializeBlockIdxX(); -+ int block_idx_y = RematerializeBlockIdxY(); -+ int block_idx_z = RematerializeBlockIdxZ(); -+ -+ return GemmCoord{(block_idx_x >> log_tile), // -+ (block_idx_y << log_tile) + ((block_idx_x) & ((1 << (log_tile)) - 1)), -+ block_idx_z}; -+ } -+ -+ /// Gets the batch index -+ CUTLASS_DEVICE -+ int get_batch_idx() const { -+ return 
RematerializeBlockIdxZ(); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Threadblock swizzling function for split-K GEMMs -+template -+struct GemmSplitKIdentityThreadblockSwizzle { -+ -+ int const kTile = N; -+ -+ /// Returns the shape of the problem in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ GemmCoord get_tiled_shape( -+ GemmCoord problem_size, -+ GemmCoord tile_size, -+ int partitions) const { -+ -+ return GemmCoord( -+ (problem_size.m() + tile_size.m() - 1) / tile_size.m(), -+ (problem_size.n() + tile_size.n() - 1) / tile_size.n(), -+ partitions); -+ } -+ -+ /// Calculates optimal swizzle width -+ CUTLASS_HOST_DEVICE -+ int get_log_tile(GemmCoord tiled_shape) const { -+ auto n = tiled_shape.n(); -+ // Thresholds picked so that it doesn't cause too many no-op CTAs -+ if (N >= 8 && n >= 6) -+ return 3; -+ else if (N >= 4 && n >= 3) -+ return 2; -+ else if (N >= 2 && n >= 2) -+ return 1; -+ else -+ return 0; -+ } -+ -+ /// Computes CUDA grid dimensions given a size in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ dim3 get_grid_shape(GemmCoord tiled_shape) const { -+ int tile = 1 << get_log_tile(tiled_shape); -+ return dim3(tiled_shape.m() * tile, (tiled_shape.n() + tile - 1) / tile, tiled_shape.k()); -+ } -+ -+ /// Obtains the threadblock offset (in units of threadblock-scoped tiles) -+ CUTLASS_DEVICE -+ GemmCoord get_tile_offset(int log_tile) const { -+ int block_idx_x = RematerializeBlockIdxX(); -+ int block_idx_y = RematerializeBlockIdxY(); -+ int block_idx_z = RematerializeBlockIdxZ(); -+ -+ return GemmCoord{(block_idx_x >> log_tile), // -+ (block_idx_y << log_tile) + ((block_idx_x) & ((1 << (log_tile)) - 1)), -+ block_idx_z}; -+ } -+ -+ /// Obtains the threadblock offset (in units of threadblock-scoped tiles) -+ CUTLASS_DEVICE -+ GemmCoord get_tile_offset(GemmCoord tiled_shape) const { -+ -+ int const kTile = N; -+ int block_idx_x = RematerializeBlockIdxX(); -+ int block_idx_y = RematerializeBlockIdxY(); -+ -+ if ((tiled_shape.m() < kTile) || (tiled_shape.n() < kTile)) -+ return GemmCoord{block_idx_x, block_idx_y, RematerializeBlockIdxZ()}; -+ -+ return GemmCoord{ -+ (block_idx_x / kTile), -+ (block_idx_y * kTile) + (block_idx_x % kTile), -+ RematerializeBlockIdxZ() -+ }; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Threadblock swizzling function for split-K GEMMs -+struct GemmSplitKHorizontalThreadblockSwizzle { -+ -+ /// Returns the shape of the problem in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ GemmCoord get_tiled_shape( -+ GemmCoord problem_size, -+ GemmCoord tile_size, -+ int partitions) const { -+ -+ return GemmCoord( -+ (problem_size.m() + tile_size.m() - 1) / tile_size.m(), -+ (problem_size.n() + tile_size.n() - 1) / tile_size.n(), -+ partitions); -+ } -+ -+ /// Computes CUDA grid dimensions given a size in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ dim3 get_grid_shape(GemmCoord tiled_shape) const { -+ return dim3(tiled_shape.n(), tiled_shape.m(), tiled_shape.k()); -+ } -+ -+ /// Calculates optimal swizzle width -+ CUTLASS_HOST_DEVICE -+ int get_log_tile(GemmCoord tiled_shape) const { -+ return 0; -+ } -+ -+ /// Obtains the threadblock offset (in units of threadblock-scoped tiles) -+ CUTLASS_DEVICE -+ GemmCoord get_tile_offset(int log_tile) const { -+ return GemmCoord{ -+ RematerializeBlockIdxY(), -+ RematerializeBlockIdxX(), -+ RematerializeBlockIdxZ() -+ }; -+ } -+ -+ /// Obtains the threadblock 
offset (in units of threadblock-scoped tiles) -+ CUTLASS_DEVICE -+ GemmCoord get_tile_offset(GemmCoord tiled_shape) const { -+ return GemmCoord{ -+ RematerializeBlockIdxY(), -+ RematerializeBlockIdxX(), -+ RematerializeBlockIdxZ() -+ }; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Threadblock swizzling function for batched GEMVs -+struct GemvBatchedStridedThreadblockDefaultSwizzle { -+ -+ /// Returns the shape of the problem in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord get_tiled_shape( -+ BatchedGemmCoord problem_size, -+ BatchedGemmCoord tile_size) const { -+ -+ return BatchedGemmCoord( -+ 1, // M is always 1 -+ (problem_size.n() + tile_size.n() - 1) / tile_size.n(), -+ (problem_size.k() + tile_size.k() - 1) / tile_size.k(), -+ (problem_size.batch() + tile_size.batch() - 1) / tile_size.batch()); -+ } -+ -+ /// Computes CUDA grid dimensions given a size in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ dim3 get_grid_shape(BatchedGemmCoord tiled_shape) const { -+ return dim3(tiled_shape.n(), tiled_shape.batch(), tiled_shape.k()); -+ } -+ -+ /// Calculates optimal swizzle width -+ CUTLASS_HOST_DEVICE -+ int get_log_tile(GemmCoord tiled_shape) const { -+ return 0; -+ } -+ -+ /// Obtains the threadblock offset (in units of threadblock-scoped tiles) -+ CUTLASS_DEVICE -+ BatchedGemmCoord get_tile_offset(int log_tile) const { -+ return BatchedGemmCoord{ -+ 0, // M is always 1 -+ RematerializeBlockIdxX(), -+ RematerializeBlockIdxZ(), -+ RematerializeBlockIdxY(), -+ }; -+ } -+ -+ /// Obtains the threadblock offset (in units of threadblock-scoped tiles) -+ CUTLASS_DEVICE -+ BatchedGemmCoord get_tile_offset() const { -+ return BatchedGemmCoord{ -+ 0, // M is always 1 -+ RematerializeBlockIdxX(), -+ RematerializeBlockIdxZ(), -+ RematerializeBlockIdxY(), -+ }; -+ } -+ -+ /// Gets the batch tile index -+ CUTLASS_DEVICE -+ int get_batch_tile_idx() const { -+ return RematerializeBlockIdxY(); -+ } -+ -+ /// Gets the absolute batch index -+ CUTLASS_DEVICE -+ int get_batch_idx() const { -+ return RematerializeBlockDimY()*RematerializeBlockIdxY() + RematerializeThreadIdxY(); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h -new file mode 100644 -index 0000000..b91046e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h -@@ -0,0 +1,813 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Implements streamk threadblock mapping blockIdx to GEMM problems. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/platform/platform.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/gemm/threadblock/index_remat.h" -+ -+#include -+#include "cutlass/core_io.h" -+#include "cutlass/trace.h" -+ -+ -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Threadblock mapping control for GEMMs -+struct ThreadblockSwizzleStreamK { -+ -+ /// Advertise StreamkFeature -+ using StreamkFeature = void; -+ -+ -+ /// Kernel traits -+ template -+ struct KernelTraits {}; -+ -+ -+ /// Reduction strategy -+ enum ReductionStrategy -+ { -+ kNone, // Data-parallel strategy (no seams, fixup, etc.) 
-+ -+ kAtomic, // Non-deterministic reduction of SK-block partials using atomic aggregation in L2 -+ -+ kMixed, // Deterministic reduction of SK-block partials employing either: -+ // (a) A separate wave of reduction thread blocks" (for scenarios with lots of -+ // SK-blocks per SK-tile) -+ // (b) Turnstile-ordered atomic aggregation in L2 (for scenarios with few -+ // SK-blocks per SK-tile) -+ }; -+ -+ static ReductionStrategy const kReductionStrategy = kMixed; -+ -+ -+ // -+ // Heuristics -+ // -+ -+ /// Data-parallel wave-quantization efficiency threshold (above which we go data-parallel) -+ static float constexpr kDpEfficiencyThreshold = 0.92f; -+ -+ /// Minimum number of MAC-iterations per streamk block -+ static int const kMinItersPerSkBlock = 2; -+ -+ /// Height in CTAs of a grid rasterization cohort -+ static int const kCohortCtasM = 8; -+ -+ /// Width in CTAs of a grid rasterization cohort -+ static int const kCohortCtasN = 4; -+ -+ /// Number of CTAs per cohort -+ static int const kCtasPerCohort = kCohortCtasN * kCohortCtasM; -+ -+ /// Cost-equivalent number of SM-iterations for fixup I/O -+ static int const kFixupStartupIterEquiv = 10; -+ static int const kFixupPeerIterEquiv = 3; -+ -+ -+ // -+ // Member state -+ // -+ -+ -+ /// The 3D value-extents of the GEMM computation volume (m,n,k) -+ GemmCoord problem_size; -+ -+ /// Div/mod accelerators -+ FastDivmod div_mod_tiled_shape_m; -+ FastDivmod div_mod_tiled_shape_n; -+ FastDivmod div_mod_tiled_cohort_shape_n; -+ FastDivmod div_mod_iters_per_tile; -+ -+ /// Whether to perform cohort CTA rasterization -+ bool cohort_raster; -+ -+ // Whether to pad and remap block indices -+ bool remap_block_indices; -+ -+ /// CTA occupancy per SM -+ int sm_occupancy; -+ -+ /// Number of SMs for dispatch heuristics to load-balance using Stream-K CTAs (wave size) -+ int avail_sms; -+ -+ int dp_blocks; /// Number of data-parallel thread blocks in the grid -+ int dp_first_wave_tiles; /// Number of output tiles each CTA in the first DP wave will produce -+ -+ /// Number of reduction blocks in the grid -+ int reduction_blocks; -+ -+ int sk_waves; -+ int sk_tiles; -+ int sk_big_blocks_per_region; -+ int sk_iters_per_region; -+ -+ /// Div/mod accelerators -+ FastDivmod div_mod_sk_iters_per_normal_block; -+ FastDivmod div_mod_sk_iters_per_big_block; -+ FastDivmod div_mod_sk_iters_per_region; -+ FastDivmod div_mod_sk_regions; //!! used in block map -+ FastDivmod div_mod_sk_blocks_per_region; //!! 
used in block map -+ -+ /// The batch count -+ int batch_count; -+ -+ -+ // -+ // Host+device interface -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ ThreadblockSwizzleStreamK() {} -+ -+ /// Returns the GEMM volume in thread block tiles -+ CUTLASS_HOST_DEVICE -+ GemmCoord tiled_shape() const -+ { -+ return GemmCoord( -+ static_cast(div_mod_tiled_shape_m), -+ static_cast(div_mod_tiled_shape_n), -+ batch_count); -+ } -+ -+ /// Number of iterations per output tile -+ CUTLASS_HOST_DEVICE -+ int iters_per_tile() const -+ { -+ return static_cast(div_mod_iters_per_tile); -+ } -+ -+ /// Number of iterations for normal SK-blocks -+ CUTLASS_HOST_DEVICE -+ int sk_iters_per_normal_block() const -+ { -+ return static_cast(div_mod_sk_iters_per_normal_block); -+ } -+ -+ /// Number of SK regions -+ CUTLASS_HOST_DEVICE -+ int sk_regions() const -+ { -+ return static_cast(div_mod_sk_regions); -+ } -+ -+ /// Number of SK blocks per region (splitting factor) -+ CUTLASS_HOST_DEVICE -+ int sk_blocks_per_region() const -+ { -+ return static_cast(div_mod_sk_blocks_per_region); -+ } -+ -+ -+ // -+ // Host-side interface -+ // -+ -+ /// Debug print -+ void Print() -+ { -+#ifndef __CUDA_ARCH__ -+ auto tiles = tiled_shape().mn().product(); -+ std::cout << -+ "problem_size: (" << problem_size.m() << "," << problem_size.n() << ")" << -+ ", tiled_shape: (" << tiled_shape().m() << "," << tiled_shape().n() << ")" << -+ ", tiles: " << tiles << -+ ", dp_tiles: " << tiles - sk_tiles << -+ ", sk_tiles: " << sk_tiles << -+ ", iters_per_tile: " << iters_per_tile() << -+ ", reduction_blocks: " << reduction_blocks << -+ ", dp_blocks: " << dp_blocks << -+ ", dp_waves: " << dp_blocks / avail_sms << -+ ", dp_first_wave_tiles: " << dp_first_wave_tiles << -+ ", sk_blocks_per_region: " << sk_blocks_per_region() << -+ ", sk_regions: " << sk_regions() << -+ ", sk_waves: " << sk_waves << -+ ", sk_iters_per_normal_block: " << sk_iters_per_normal_block() << -+ ", sk_big_blocks_per_region: " << sk_big_blocks_per_region << -+ ", remap_block_indices: " << remap_block_indices << -+ ", cohort_raster: " << cohort_raster << -+ ", sm_occupancy: " << sm_occupancy << -+ ", avail_sms: " << avail_sms << -+ ", num_blocks: " << get_num_blocks() << -+ "\n\n"; -+#endif -+ } -+ -+ -+ // Compute sk_blocks to dispatch for a given number of sk_tiles -+ static void get_sk_blocks( -+ int &sk_blocks, /// [out] -+ int &savings_iters, /// [out] -+ int sk_tiles, -+ int iters_per_tile, -+ int avail_sms, -+ int max_sk_occupancy, -+ bool allow_partial_wave) -+ { -+ savings_iters = INT_MIN; -+ sk_blocks = 0; -+ -+ if (sk_tiles == 0) { -+ return; -+ } -+ -+ int sk_iters = sk_tiles * iters_per_tile; -+ -+ int dp_equiv_waves = (sk_tiles + avail_sms - 1) / avail_sms; -+ int dp_equiv_iters = iters_per_tile * dp_equiv_waves; -+ -+ int min_sk_blocks = (allow_partial_wave) ? 
fast_min(avail_sms, sk_tiles + 1) : avail_sms; -+ int max_sk_blocks = fast_min(avail_sms * max_sk_occupancy, sk_iters / kMinItersPerSkBlock); -+ -+ for (int trial_sk_blocks = min_sk_blocks; trial_sk_blocks <= max_sk_blocks; ++trial_sk_blocks) -+ { -+ int sk_waves = (trial_sk_blocks + avail_sms - 1) / avail_sms; -+ int max_sk_iters_per_block = (sk_iters + trial_sk_blocks - 1) / trial_sk_blocks; -+ int sk_iter_equiv = max_sk_iters_per_block * sk_waves; -+ -+ int num_peers = ((trial_sk_blocks + sk_tiles - 1) / sk_tiles) + 1; // add one for alignment skew -+ -+ float iter_cost = 0.02f * float(num_peers) * float(sk_iter_equiv); -+ -+ if (trial_sk_blocks % sk_tiles == 0) -+ { -+ // aligned -+ num_peers = (trial_sk_blocks / sk_tiles); -+ -+ iter_cost = 0.0f; -+ } -+ -+ float peer_cost = 2.0f * float(num_peers); -+ -+ float base_cost = 2.0f * float(sk_waves); -+ -+ int fixup_iter_equiv = int(base_cost + iter_cost + peer_cost); -+ -+ int trial_savings_iters = dp_equiv_iters - sk_iter_equiv - fixup_iter_equiv; -+ -+ if (trial_savings_iters >= savings_iters) { -+ savings_iters = trial_savings_iters; -+ sk_blocks = trial_sk_blocks; -+ } -+ } -+ } -+ -+ -+ /// Determine the populations of DP and SK blocks to invoke for the given number of output tiles -+ static void get_blocks( -+ int &dp_tiles, /// [out] -+ int &sk_blocks, /// [out] -+ int output_tiles, -+ int iters_per_tile, -+ int avail_sms, -+ int sm_occupancy) -+ { -+ int full_waves = output_tiles / avail_sms; -+ int full_wave_tiles = full_waves * avail_sms; -+ int partial_wave_tiles = output_tiles - full_wave_tiles; -+ -+ int score = -1; -+ dp_tiles = output_tiles; -+ sk_blocks = 0; -+ -+ if (partial_wave_tiles == 0) -+ { -+ // Perfect quantization -+ return; -+ } -+ -+ if (full_waves < sm_occupancy) -+ { -+ // We're less than full GPU occupancy -+ -+ // Form the SK wave from the partial wave to get us up to full GPU occupancy -+ int max_sk_occupancy = sm_occupancy - full_waves; -+ -+ dp_tiles = full_wave_tiles; -+ -+ get_sk_blocks( -+ sk_blocks, -+ score, -+ partial_wave_tiles, -+ iters_per_tile, -+ avail_sms, -+ max_sk_occupancy, -+ true); // we can run with less than a full wave of SK-blocks -+ -+ if (score < 0) { -+ // not profitable -+ sk_blocks = 0; -+ dp_tiles = output_tiles; -+ } -+ -+ return; -+ } -+ -+ // We're at (or greater) than GPU occupancy -+ -+ if ((sm_occupancy > 1 ) && (full_waves % sm_occupancy == sm_occupancy - 1)) -+ { -+ // If occupancy is more than one CTA per SM, form the SK wave from the partial -+ // wave to get us to full GPU occupancy -+ int max_sk_occupancy = 1; -+ -+ dp_tiles = full_wave_tiles; -+ -+ get_sk_blocks( -+ sk_blocks, -+ score, -+ partial_wave_tiles, -+ iters_per_tile, -+ avail_sms, -+ max_sk_occupancy, -+ true); // we can run with less than a full wave of SK-blocks -+ -+ if (score >= 0) { -+ return; -+ } -+ } -+ -+ // Form the SK wave by combining the last full wave and the partial wave -+ // We're less than full GPU occupancy -+ dp_tiles = full_wave_tiles - avail_sms; -+ -+ int max_sk_occupancy = sm_occupancy - ((full_waves - 1) % sm_occupancy); -+ -+ get_sk_blocks( -+ sk_blocks, -+ score, -+ partial_wave_tiles + avail_sms, -+ iters_per_tile, -+ avail_sms, -+ max_sk_occupancy, -+ false); // we cannot run with less than a full wave of SK-blocks -+ -+ if (score < 0) { -+ // not profitable -+ sk_blocks = 0; -+ dp_tiles = output_tiles; -+ } -+ -+ } -+ -+ /// Constructor: *Gemm* problem size (m, n, k) -+ template -+ ThreadblockSwizzleStreamK( -+ KernelTraits const kernel_traits_, -+ GemmUniversalMode const 
mode_, -+ GemmCoord const problem_size_, -+ GemmCoord const tile_size_, -+ int const batch_split_, /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K) -+ int const sm_occupancy_, -+ int const device_sms_, -+ int const avail_sms_) /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling) -+ : -+ problem_size(problem_size_), -+ batch_count((mode_ == GemmUniversalMode::kBatched) ? batch_split_ : 1), -+ reduction_blocks(0), -+ dp_blocks(0), -+ dp_first_wave_tiles(1), // Default: one tile per DP-block in the first wave of DP blocks -+ sk_tiles(0), -+ sk_big_blocks_per_region(0), -+ sk_iters_per_region(0), -+ sk_waves(0), -+ sm_occupancy(sm_occupancy_), -+ remap_block_indices(false), -+ avail_sms(fast_max(1, avail_sms_)), -+ cohort_raster(false) -+ { -+ int gpu_occupancy = device_sms_ * sm_occupancy; -+ int iters_per_tile = (problem_size.k() + tile_size_.k() - 1) / tile_size_.k(); -+ int sk_iters_per_normal_block = 0; -+ -+ int sk_regions = 1; // Default: a single region of iteration space (across all SK tiles) -+ int sk_blocks_per_region = 0; -+ -+ GemmCoord tiled_shape( -+ (problem_size.m() + tile_size_.m() - 1) / tile_size_.m(), -+ (problem_size.n() + tile_size_.n() - 1) / tile_size_.n(), -+ batch_count); -+ -+ size_t problem_bytes = -+ (sizeof(typename GemmKernel::ElementC) * problem_size.m() * problem_size.n()) + -+ (sizeof(typename GemmKernel::ElementA) * problem_size.m() * problem_size.k()) + -+ (sizeof(typename GemmKernel::ElementB) * problem_size.k() * problem_size.n()); -+ -+ size_t problem_flops = size_t(problem_size.m()) * size_t(problem_size.n()) * size_t(problem_size.k()) * 2; -+ -+ float flops_per_byte = float(problem_flops) / float(problem_bytes); -+ -+ int output_tiles = tiled_shape.m() * tiled_shape.n(); -+ int waves = (output_tiles + avail_sms - 1) / avail_sms; -+ float dp_efficiency = float(output_tiles) / float(waves * avail_sms); -+ -+ // -+ // Determine dispatch composition of DP-tiles and SK-blocks -+ // -+ -+ // Start with a DP-only configuration -+ int dp_tiles = output_tiles; // Number of data-parallel tiles -+ int sk_blocks = 0; // Number of thread blocks to produce the remaining SK tiles -+ -+ // Only kGemm mode allows for SK load balancing -+ if (mode_ == GemmUniversalMode::kGemm) -+ { -+ int split_factor = batch_split_; -+ if (split_factor > 1) -+ { -+ // Split-K override -+ dp_tiles = 0; -+ sk_blocks = output_tiles * split_factor; -+ } -+ else if ((kReductionStrategy != kNone) && // Load-balancing strategy statically enabled -+ (avail_sms > 1)) // Plurality of SMs to load balance across -+ { -+ // Use heuristics -+ get_blocks( -+ dp_tiles, /// [out] -+ sk_blocks, /// [out] -+ output_tiles, -+ iters_per_tile, -+ avail_sms, -+ sm_occupancy); -+ } -+ } -+ -+ sk_tiles = output_tiles - dp_tiles; -+ -+ -+ // Compute SK block iteration details -+ if (sk_blocks > 0) -+ { -+ sk_waves = (sk_blocks + avail_sms - 1) / avail_sms; -+ -+ int sk_iters = sk_tiles * iters_per_tile; -+ sk_blocks = fast_min(sk_blocks, sk_iters); -+ -+ sk_iters_per_normal_block = sk_iters / sk_blocks; -+ int extra_sk_iters = sk_iters - (sk_iters_per_normal_block * sk_blocks); -+ int sk_big_blocks = extra_sk_iters; -+ -+ if ((sk_blocks > sk_tiles) && (sk_blocks % sk_tiles == 0)) -+ { -+ // Split-K decomposition -+ sk_regions = sk_tiles; -+ } -+ -+ sk_blocks_per_region = 
sk_blocks / sk_regions; -+ sk_big_blocks_per_region = sk_big_blocks / sk_regions; -+ sk_iters_per_region = sk_iters / sk_regions; -+ -+ // Use a separate reduction wave when all of: -+ // - Non-atomic reduction strategy -+ // - The number of SK waves won't fully occupy the GPU (Otherwise we don't have -+ // a strong-scaling case for more parallel reduction) -+ // - More than three peers working on an SK tile. (This occurs when the ratio of -+ // SK-blocks to SK-tiles > 2, as a single tile may be covered by four SK-blocks, -+ // e.g.:[partial-block | block | block | partial-block] ). With three or -+ // fewer peers, the two non-finishing SK-blocks are not expected to contend. -+ if ((kReductionStrategy == kMixed) && -+ (sk_waves < sm_occupancy) && -+ (sk_blocks > 2 * sk_tiles)) -+ { -+ // Launch a reduction block for every accumulator fragment in each SK-tile -+ static const int kAccumulatorFragments = GemmKernel::Epilogue::kAccumulatorFragments; -+ reduction_blocks = sk_tiles * kAccumulatorFragments; -+ -+ } -+ -+ // When we have a multi-occupancy kernel and at least two waves of active blocks (where -+ // at least one wave is SK blocks), we need to (1) dispatch at least four waves, and (2) -+ // remap the block indices so that we can reliably spread the SK blocks evenly across the -+ // device's first SM occupancy valence. Also see get_num_blocks() and get_block_idx(). -+ remap_block_indices = ( -+ (sm_occupancy > 1) && -+ (device_sms_ == avail_sms) && -+ (get_num_active_blocks() > avail_sms * 2)); -+ -+ // Initialize fast div/mod members related to SK -+ div_mod_sk_iters_per_normal_block = FastDivmod(sk_iters_per_normal_block); -+ div_mod_sk_iters_per_big_block = FastDivmod(sk_iters_per_normal_block + 1); -+ div_mod_sk_iters_per_region = FastDivmod(sk_iters_per_region); -+ div_mod_sk_regions = FastDivmod(sk_regions); -+ div_mod_sk_blocks_per_region = FastDivmod(sk_blocks_per_region); -+ } -+ -+ // -+ // Compute DP blocks -+ // -+ -+ dp_blocks = dp_tiles; -+ -+ cutlass::gemm::GemmCoord tiled_cohort_shape( -+ (tiled_shape.m() + kCohortCtasM - 1) / kCohortCtasM, -+ (tiled_shape.n() + kCohortCtasN - 1) / kCohortCtasN, -+ tiled_shape.k()); -+ int cohort_blocks = (tiled_cohort_shape.m() * tiled_cohort_shape.n()) * kCtasPerCohort; -+ float cohort_efficiency = float(dp_blocks) / float(cohort_blocks); -+ -+ // Check if the SK tiles would be in cohorts that are in-bounds -+ bool sk_in_range = true; -+ if (sk_tiles > 0) -+ { -+ int last_sk_tile = sk_tiles - 1; -+ int cohort_tile_idx = last_sk_tile / kCtasPerCohort; -+ int cohort_grid_m = cohort_tile_idx / tiled_cohort_shape.n(); -+ int cohort_grid_n = (cohort_grid_m > 0) ? 
-+ tiled_cohort_shape.n() - 1 : -+ cohort_tile_idx % tiled_cohort_shape.n(); -+ -+ if ((((cohort_grid_m + 1) * kCohortCtasM) >= tiled_shape.m()) || -+ (((cohort_grid_n + 1) * kCohortCtasN) >= tiled_shape.n())) -+ { -+ sk_in_range = false; -+ } -+ -+ } -+ -+ // Decide if we're going to be doing cohort raster -+ if (sk_in_range && -+ (dp_blocks >= gpu_occupancy * 2) && -+ (cohort_efficiency > 0.85f)) -+ { -+ cohort_raster = true; -+ dp_blocks = cohort_blocks; -+ } -+ else if (sk_waves > 0) -+ { -+ // Update semi-persistence of first DP wave to ensure full grid wavesets -+ // (Only applies when there's an SK component and we're not doing blocked cohort rasterization) -+ int dp_tile_waves = (dp_tiles + avail_sms - 1) / avail_sms; -+ int full_dp_tile_waves = dp_tiles / avail_sms; -+ int waveset_excess = (sk_waves + dp_tile_waves) % sm_occupancy; -+ -+ if (dp_first_wave_tiles + waveset_excess <= full_dp_tile_waves) -+ { -+ dp_first_wave_tiles += waveset_excess; -+ dp_blocks -= (waveset_excess * avail_sms); -+ } -+ } -+ -+ // Setup fast-div/mod for device-side usage -+ div_mod_tiled_shape_m = FastDivmod(tiled_shape.m()); -+ div_mod_tiled_shape_n = FastDivmod(tiled_shape.n()); -+ div_mod_tiled_cohort_shape_n = FastDivmod(tiled_cohort_shape.n()); -+ div_mod_iters_per_tile = FastDivmod(iters_per_tile); -+ -+ } -+ -+ /// Number of blocks performing useful work -+ int get_num_active_blocks() const -+ { -+ return (sk_waves * avail_sms) + dp_blocks + reduction_blocks; -+ } -+ -+ /// Obtains number of threadblocks per GEMM -+ int get_num_blocks() const -+ { -+ int active_blocks = get_num_active_blocks(); -+ if (remap_block_indices) -+ { -+ // Add padding blocks if we are performing remapping in order to dispatch a grid of at least four waves -+ return fast_max(active_blocks, avail_sms * 4); -+ } -+ -+ return active_blocks; -+ } -+ -+ -+ /// Obtains grid extents in CTAs -+ dim3 get_grid_dims() const -+ { -+ return dim3(get_num_blocks(), 1, batch_count); -+ } -+ -+ -+// Guards needed for PyCUTLASS library generation -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) -+ -+ // -+ // Device-side interface -+ // -+ -+ /// Proves to the compiler that val is warp-uniform -+ CUTLASS_DEVICE -+ int uniform(int val) const -+ { -+ return __shfl_sync(0xffffffff, val, 0); -+ } -+ -+ /// Obtains number of threadblocks per GEMM -+ CUTLASS_DEVICE -+ int device_num_blocks() const -+ { -+ return gridDim.x; -+ } -+ -+ /// Obtains tile index for the given sk iteration -+ CUTLASS_DEVICE -+ int get_sk_tile_idx(int iter) const -+ { -+ int tile_idx = div_mod_iters_per_tile.div(iter); -+ return uniform(tile_idx); -+ } -+ -+ /// Obtains the batch index -+ CUTLASS_DEVICE -+ int get_batch_idx() const -+ { -+ return RematerializeBlockIdxZ(); -+ } -+ -+ /// Obtains the calling threadblock's tiled coordinates for the given tile index -+ CUTLASS_DEVICE -+ GemmCoord get_tile_offset(int tile_idx) const -+ { -+ int m, n; -+ -+ // row-major raster -+ div_mod_tiled_shape_n(m, n, tile_idx); -+ -+ if (tiled_shape().m() < tiled_shape().n()) -+ { -+ // column-major raster -+ div_mod_tiled_shape_m(n, m, tile_idx); -+ } -+ -+ if (cohort_raster) -+ { -+ // tiled cohort raster -+ int cohort_tile_idx = tile_idx / kCtasPerCohort; -+ int cohort_grid_m, cohort_grid_n; -+ div_mod_tiled_cohort_shape_n(cohort_grid_m, cohort_grid_n, cohort_tile_idx); -+ -+ int block_idx_cohort = tile_idx % kCtasPerCohort; -+ int block_cohort_m = block_idx_cohort / kCohortCtasN; -+ int block_cohort_n = block_idx_cohort % 
kCohortCtasN; -+ -+ m = (cohort_grid_m * kCohortCtasM) + block_cohort_m; -+ n = (cohort_grid_n * kCohortCtasN) + block_cohort_n; -+ } -+ -+ return GemmCoord(m, n, get_batch_idx()); -+ } -+ -+ /// Obtains the calling threadblock's tiled coordinates for the given tile index (row-major rasterization) -+ CUTLASS_DEVICE -+ GemmCoord get_tile_offset_row_major(int tile_idx) const -+ { -+ // row-major raster -+ int m, n; -+ div_mod_tiled_shape_n(m, n, tile_idx); -+ return GemmCoord(m, n, get_batch_idx()); -+ } -+ -+ /// Obtains calling threadblock's linear threadblock index -+ CUTLASS_DEVICE -+ int get_block_idx() const -+ { -+ int block_idx = RematerializeBlockIdxX(); -+ -+ // Remap the block indices for the first two waves of thread blocks if -+ // we have multi-occupancy and the grid constitutes four or more waves -+ if (remap_block_indices && (block_idx < avail_sms * 2)) -+ { -+ int dest_sm = block_idx / 2; -+ int dest_wave = block_idx % 2; -+ int remapped_block_idx = dest_sm + (dest_wave * avail_sms); -+ block_idx = remapped_block_idx; -+ } -+ -+ // Remap block indices to interleave SK regions to limit intra-region waiting -+ if (block_idx < sk_regions() * sk_blocks_per_region()) -+ { -+ int block_in_region; -+ int region; -+ div_mod_sk_regions(block_in_region, region, block_idx); -+ block_idx = (region * sk_blocks_per_region()) + block_in_region; -+ } -+ -+ return uniform(block_idx); -+ } -+ -+ -+ /// Obtains calling linear threadblock index of the first block to work on the given tile -+ CUTLASS_DEVICE -+ int get_sk_block_idx(int iter) const -+ { -+ int region_idx; -+ int iter_in_region; -+ div_mod_sk_iters_per_region(region_idx, iter_in_region, iter); -+ -+ int big_block_iters = (sk_big_blocks_per_region * sk_iters_per_normal_block()) + sk_big_blocks_per_region; // number of iterations in the region's big blocks -+ int normal_block_iters = iter_in_region - big_block_iters; // number of iterations in the region's normal blocks -+ -+ int big_block_idx_in_region = div_mod_sk_iters_per_big_block.div(iter_in_region); -+ int normal_block_idx_in_region = sk_big_blocks_per_region + div_mod_sk_iters_per_normal_block.div(normal_block_iters); -+ -+ int block_idx_in_region = (big_block_idx_in_region < sk_big_blocks_per_region) ? 
-+ big_block_idx_in_region : -+ normal_block_idx_in_region; -+ -+ int owning_block_idx = (sk_blocks_per_region() * region_idx) + block_idx_in_region; -+ -+ return owning_block_idx; -+ } -+ -+ /// Obtains iteration extends for the given SK block index -+ CUTLASS_DEVICE -+ void get_iter_extents( -+ int sk_block_idx, -+ int &block_iter_begin, -+ int &block_iter_end) const -+ { -+ int region_idx; -+ int block_idx_in_region; -+ div_mod_sk_blocks_per_region(region_idx, block_idx_in_region, sk_block_idx); -+ -+ block_iter_begin = (region_idx * sk_iters_per_region) + (block_idx_in_region * sk_iters_per_normal_block()); -+ -+ // Adjust extents for the first "num_big_blocks" blocks that get one extra iteration -+ int block_iters = sk_iters_per_normal_block(); -+ if (block_idx_in_region < sk_big_blocks_per_region) { -+ // This is a +1 iteration block -+ block_iter_begin += block_idx_in_region; -+ block_iters++; -+ } else { -+ // This is a regular block -+ block_iter_begin += sk_big_blocks_per_region; -+ } -+ block_iter_end = block_iter_begin + block_iters; -+ } -+ -+ -+ /// Obtains calling linear threadblock index of the first block to work on the given tile -+ CUTLASS_DEVICE -+ int get_first_block_idx(int tile_idx, int block_idx) const -+ { -+ if (tile_idx >= sk_tiles) { -+ // DP tile -+ return block_idx; -+ } -+ -+ int iter = tile_idx * iters_per_tile(); -+ return get_sk_block_idx(iter); -+ } -+ -+#endif // defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_complex_tensor_op.h b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_complex_tensor_op.h -new file mode 100644 -index 0000000..1c794b1 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_complex_tensor_op.h -@@ -0,0 +1,612 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/warp/mma_complex_tensor_op.h" -+#include "cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h" -+#include "cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Complex transform on A operand -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex transform on B operand -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator_ = arch::OpMultiplyAddComplex> -+struct DefaultMmaComplexTensorOp; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex*complex case -+// 4 real-valued mma operations -+// A = (ar + j ai), B (br +j bi), D = AB -+// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Real-valued underlying type of complex-valued A operand -+ typename RealElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Real-valued underlying type of complex-valued B operand -+ typename RealElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Real-valued underlying type of complex-valued C operand -+ typename RealElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB> -+struct DefaultMmaComplexTensorOp< -+ WarpShape_, -+ InstructionShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ TransformA, -+ 
TransformB, -+ arch::OpMultiplyAddComplex> { -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ InstructionShape_, -+ 32, -+ RealElementA, -+ cutlass::layout::RowMajor, -+ RealElementB, -+ cutlass::layout::ColumnMajor, -+ RealElementC, -+ cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd>, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaComplexTensorOp< -+ WarpShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ Policy, -+ TransformA, -+ TransformB>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex*complex case using GaussianComplex operation -+// 3 real-valued mma operations -+// A = (ar + j ai), B = (br +j bi), D = AB -+// P1 = (ar + ai) * br, P2 = - ar * (br - bi), P3 = ai * (br + bi) -+// D = dr + j di = (P1 - P3) + j (P1 + P2) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Real-valued underlying type of complex-valued A operand -+ typename RealElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Real-valued underlying type of complex-valued B operand -+ typename RealElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Real-valued underlying type of complex-valued C operand -+ typename RealElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB> -+struct DefaultMmaComplexTensorOp< -+ WarpShape_, -+ InstructionShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ TransformA, -+ TransformB, -+ arch::OpMultiplyAddGaussianComplex> { -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ InstructionShape_, -+ 32, -+ RealElementA, -+ cutlass::layout::RowMajor, -+ RealElementB, -+ cutlass::layout::ColumnMajor, -+ RealElementC, -+ cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd>, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaGaussianComplexTensorOp< -+ WarpShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ Policy, -+ TransformA, -+ TransformB>; -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization - input and output types are complex*complex -+// Use TF32 tensor operation internally -+// 4 real-valued mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 operations on TF32 -+// A = (ar + j ai), B (br +j bi), D = AB -+// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Layout 
of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB> -+struct DefaultMmaComplexTensorOp< -+ WarpShape_, -+ InstructionShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ TransformA, -+ TransformB, -+ arch::OpMultiplyAddComplex> { -+ -+ // Complex floating point tensor operation use mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 mma instruction -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ InstructionShape_, -+ 32, -+ tfloat32_t, -+ cutlass::layout::RowMajor, -+ tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd>, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaComplexTensorOp< -+ WarpShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ Policy, -+ TransformA, -+ TransformB>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization - input and output types are complex*complex -+// Use BF16 tensor operation internally -+// 4 real-valued mma.sync.aligned.m16n8k8.f32.bf16.bf16.f32 operations on BF16 -+// A = (ar + j ai), B (br +j bi), D = AB -+// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB> -+struct DefaultMmaComplexTensorOp< -+ WarpShape_, -+ InstructionShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ TransformA, -+ TransformB, -+ arch::OpMultiplyAddFastBF16> { -+ -+ // Complex floating point tensor operation use mma.sync.aligned.m16n8k8.f32.bf16.bf16.f32 mma instruction -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ InstructionShape_, -+ 32, -+ bfloat16_t, -+ cutlass::layout::RowMajor, -+ bfloat16_t, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd>, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaComplexTensorOp< -+ WarpShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ Policy, -+ TransformA, -+ TransformB>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization - input and output types are complex*complex -+// Use F16 tensor operation internally -+// 4 real-valued mma.sync.aligned.m16n8k8.f32.f16.f16.f32 operations on F16 -+// A = (ar + j ai), B (br +j bi), D = AB -+// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Size 
of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB> -+struct DefaultMmaComplexTensorOp< -+ WarpShape_, -+ InstructionShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ TransformA, -+ TransformB, -+ arch::OpMultiplyAddFastF16> { -+ -+ // Complex floating point tensor operation use mma.sync.aligned.m16n8k8.f32.f16.f16.f32 mma instruction -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ InstructionShape_, -+ 32, -+ half_t, -+ cutlass::layout::RowMajor, -+ half_t, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd>, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaComplexTensorOp< -+ WarpShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ Policy, -+ TransformA, -+ TransformB>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// 3xTF32 or 4xTF32 (fast and accurate complex operation) -+/// Partial specialization - input and output types are complex * complex -+// Use 3xTF32 or 4xTF32 tensor operation internally -+// 4 real-valued mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 operations on TF32 -+// A = (ar + j ai), B (br +j bi), D = AB -+// D = dr + j di = 3x[(ar*br - ai*bi) + j (ar*bi + ai*br)] -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB> -+struct DefaultMmaComplexTensorOp< -+ WarpShape_, -+ InstructionShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ TransformA, -+ TransformB, -+ arch::OpMultiplyAddComplexFastF32> { -+ -+ // Complex floating point tensor operation use mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 mma instruction -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ InstructionShape_, -+ 32, -+ tfloat32_t, -+ cutlass::layout::RowMajor, -+ tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd>, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaComplexTensorOpFastF32< -+ WarpShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ Policy, -+ TransformA, -+ TransformB>; -+}; -+ 
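[Editor's note] The comment blocks in the specializations above describe how each complex multiply-accumulate is lowered to real-valued tensor-core MMAs: the standard OpMultiplyAddComplex forms issue four real products per complex product, D = (ar*br - ai*bi) + j(ar*bi + ai*br), while the Gaussian form trades one multiply for extra additions via P1 = (ar + ai)*br, P2 = -ar*(br - bi), P3 = ai*(br + bi), D = (P1 - P3) + j(P1 + P2). A minimal scalar sketch of that arithmetic only (standalone C++, not CUTLASS code; float and the names cmad4/cmad3 are illustrative stand-ins for the real-valued element type and the warp-level operators):

    #include <cassert>
    #include <cmath>

    // Standard 4-multiply complex multiply-accumulate:
    //   d += a * b = (ar*br - ai*bi) + j*(ar*bi + ai*br)
    void cmad4(float ar, float ai, float br, float bi, float &dr, float &di) {
      dr += ar * br - ai * bi;   // real part: two real products
      di += ar * bi + ai * br;   // imaginary part: two real products
    }

    // Gaussian 3-multiply variant (as in the OpMultiplyAddGaussianComplex comments):
    //   P1 = (ar + ai) * br,  P2 = -ar * (br - bi),  P3 = ai * (br + bi)
    //   d += (P1 - P3) + j*(P1 + P2)
    void cmad3(float ar, float ai, float br, float bi, float &dr, float &di) {
      float p1 = (ar + ai) * br;
      float p2 = -ar * (br - bi);
      float p3 = ai * (br + bi);
      dr += p1 - p3;
      di += p1 + p2;
    }

    int main() {
      float dr4 = 0, di4 = 0, dr3 = 0, di3 = 0;
      cmad4(1.5f, -2.0f, 0.25f, 3.0f, dr4, di4);  // (1.5 - 2j) * (0.25 + 3j) = 6.375 + 4j
      cmad3(1.5f, -2.0f, 0.25f, 3.0f, dr3, di3);
      // Both decompositions accumulate the same complex product (up to rounding).
      assert(std::fabs(dr4 - dr3) < 1e-5f && std::fabs(di4 - di3) < 1e-5f);
      return 0;
    }

Each real product in the sketch corresponds to one real-valued mma operation in the warp-level operator (four versus three), which is the instruction-count trade-off that motivates the GaussianComplex specializations.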
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex*complex case -+// 4 real-valued mma.sync.aligned.m16n8k4.f64.f64.f64.f64 operations -+// A = (ar + j ai), B (br +j bi), D = AB -+// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Real-valued underlying type of complex-valued A operand -+ typename RealElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Real-valued underlying type of complex-valued B operand -+ typename RealElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Real-valued underlying type of complex-valued C operand -+ typename RealElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB> -+struct DefaultMmaComplexTensorOp< -+ WarpShape_, -+ GemmShape<16, 8, 4>, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ TransformA, -+ TransformB, -+ arch::OpMultiplyAddComplex> { -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ GemmShape<16, 8, 4>, -+ 32, -+ RealElementA, -+ cutlass::layout::RowMajor, -+ RealElementB, -+ cutlass::layout::ColumnMajor, -+ RealElementC, -+ cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd>, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaComplexTensorOp< -+ WarpShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ Policy, -+ TransformA, -+ TransformB, -+ true>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization for complex*complex case using GaussianComplex operation -+// 3 real-valued mma.sync.aligned.m16n8k4.f64.f64.f64.f64 operations -+// A = (ar + j ai), B = (br +j bi), D = AB -+// P1 = (ar + ai) * br, P2 = - ar * (br - bi), P3 = ai * (br + bi) -+// D = dr + j di = (P1 - P3) + j (P1 + P2) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Real-valued underlying type of complex-valued A operand -+ typename RealElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Real-valued underlying type of complex-valued B operand -+ typename RealElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Real-valued underlying type of complex-valued C operand -+ typename RealElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB> -+struct DefaultMmaComplexTensorOp< -+ WarpShape_, -+ GemmShape<16, 8, 4>, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ TransformA, -+ TransformB, -+ arch::OpMultiplyAddGaussianComplex> { -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ GemmShape<16, 8, 4>, -+ 32, -+ RealElementA, -+ cutlass::layout::RowMajor, -+ RealElementB, -+ 
cutlass::layout::ColumnMajor, -+ RealElementC, -+ cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd>, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaGaussianComplexTensorOp< -+ WarpShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ Policy, -+ TransformA, -+ TransformB, -+ true>; -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_sparse_tensor_op.h b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_sparse_tensor_op.h -new file mode 100644 -index 0000000..89f8f1c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_sparse_tensor_op.h -@@ -0,0 +1,165 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/warp/mma_sparse_tensor_op.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Operator describing the tensor operation -+ typename Operator_ = arch::OpMultiplyAdd, -+ /// Number of partitions along K dimension -+ int PartitionsK = 1, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false -+> -+struct DefaultSparseMmaTensorOp; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial Specialization - inputs and output types are float - uses TF32 internally -+template < -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of target matrix multiply instruction (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. 
-+ bool AccumulatorsInRowMajor> -+struct DefaultSparseMmaTensorOp< -+ WarpShape_, -+ InstructionShape_, -+ float, LayoutA, -+ float, LayoutB, -+ float, LayoutC, -+ arch::OpMultiplyAdd, PartitionsK, AccumulatorsInRowMajor> { -+ -+ // Uses TF32 internally -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::SparseMma< -+ InstructionShape_, -+ 32, -+ tfloat32_t, cutlass::layout::RowMajor, -+ tfloat32_t, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::SparseMmaTensorOp< -+ WarpShape_, float, LayoutA, float, LayoutB, float, LayoutC, -+ Policy, PartitionsK, AccumulatorsInRowMajor>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for m-by-n-by-kgroup -+template < -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A elements -+ typename ElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Element type of C matrix -+ typename ElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Operator describing the tensor operation -+ typename Operator_, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor> -+struct DefaultSparseMmaTensorOp { -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::SparseMma, -+ cutlass::MatrixShape<1, 1> >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::SparseMmaTensorOp< -+ WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, -+ Policy, PartitionsK, AccumulatorsInRowMajor>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op.h b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op.h -new file mode 100644 -index 0000000..3421de9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op.h -@@ -0,0 +1,123 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/warp/mma_tensor_op.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Operator describing the tensor operation -+ typename Operator_ = arch::OpMultiplyAdd, -+ /// Number of partitions along K dimension -+ int PartitionsK = 1, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false> -+struct DefaultMmaTensorOp; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for m-by-n-by-kgroup -+template < -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A elements -+ typename ElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Element type of C matrix -+ typename ElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Operator describing the tensor operation -+ typename Operator_, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. 
-+ bool AccumulatorsInRowMajor> -+struct DefaultMmaTensorOp { -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma, -+ cutlass::MatrixShape<1, 1> >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaTensorOp< -+ WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, -+ Policy, PartitionsK, AccumulatorsInRowMajor>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "default_mma_tensor_op_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h -new file mode 100644 -index 0000000..d4d8026 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h -@@ -0,0 +1,238 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/mma.h" -+#include "cutlass/gemm/warp/mma_tensor_op.h" -+#include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h" -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial Specialization - inputs and output types are float - uses BF16 internally -+template < -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename WarpShape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor> -+struct DefaultMmaTensorOp< -+ WarpShape_, -+ GemmShape<16, 8, 8>, -+ float, LayoutA, -+ float, LayoutB, -+ float, LayoutC, -+ arch::OpMultiplyAddFastBF16, -+ PartitionsK, AccumulatorsInRowMajor> { -+ -+ // Uses BF16 internally -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ GemmShape<16, 8, 8>, -+ 32, -+ bfloat16_t, cutlass::layout::RowMajor, -+ bfloat16_t, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaTensorOp< -+ WarpShape_, float, LayoutA, float, LayoutB, float, LayoutC, -+ Policy, PartitionsK, AccumulatorsInRowMajor>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial Specialization - inputs and output types are float - uses F16 internally -+template < -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename WarpShape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. 
-+ bool AccumulatorsInRowMajor> -+struct DefaultMmaTensorOp< -+ WarpShape_, -+ GemmShape<16, 8, 8>, -+ float, LayoutA, -+ float, LayoutB, -+ float, LayoutC, -+ arch::OpMultiplyAddFastF16, -+ PartitionsK, AccumulatorsInRowMajor> { -+ -+ // Uses F16 internally -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ GemmShape<16, 8, 8>, -+ 32, -+ half_t, cutlass::layout::RowMajor, -+ half_t, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaTensorOp< -+ WarpShape_, float, LayoutA, float, LayoutB, float, LayoutC, -+ Policy, PartitionsK, AccumulatorsInRowMajor>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial Specialization - inputs and output types are float - uses TF32 internally -+template < -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of target matrix multiply instruction (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor> -+struct DefaultMmaTensorOp< -+ WarpShape_, -+ InstructionShape_, -+ float, LayoutA, -+ float, LayoutB, -+ float, LayoutC, -+ arch::OpMultiplyAdd, PartitionsK, AccumulatorsInRowMajor> { -+ -+ // Uses TF32 internally -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ InstructionShape_, -+ 32, -+ tfloat32_t, cutlass::layout::RowMajor, -+ tfloat32_t, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaTensorOp< -+ WarpShape_, float, LayoutA, float, LayoutB, float, LayoutC, -+ Policy, PartitionsK, AccumulatorsInRowMajor>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial Specialization - inputs and output types are float - uses TF32 for Fast Accurate FP32 -+template < -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of target matrix multiply instruction (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. 
-+ bool AccumulatorsInRowMajor> -+struct DefaultMmaTensorOp< -+ WarpShape_, -+ InstructionShape_, -+ float, LayoutA, -+ float, LayoutB, -+ float, LayoutC, -+ arch::OpMultiplyAddFastF32, PartitionsK, AccumulatorsInRowMajor> { -+ -+ // Uses TF32 internally -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ InstructionShape_, -+ 32, -+ cutlass::tfloat32_t, cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaTensorOpFastF32< -+ WarpShape_, float, LayoutA, float, LayoutB, float, LayoutC, -+ Policy, PartitionsK, AccumulatorsInRowMajor>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h -new file mode 100644 -index 0000000..63effe8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h -@@ -0,0 +1,92 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/warp/mma_with_reduction_tensor_op.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A elements -+ typename ElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Element type of C matrix -+ typename ElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Operator describing the tensor operation -+ typename Operator_, -+ /// Reduce operand A or B along K dimension -+ bool ReduceKForA_, -+ /// Number of partitions along K dimension -+ int PartitionsK = 1, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false> -+struct DefaultMmaWithReductionTensorOp { -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma, -+ cutlass::MatrixShape<1, 1> >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaWithReductionTensorOp< -+ WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, -+ Policy, ReduceKForA_, PartitionsK, AccumulatorsInRowMajor>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_wmma_tensor_op.h b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_wmma_tensor_op.h -new file mode 100644 -index 0000000..4f951d4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_wmma_tensor_op.h -@@ -0,0 +1,130 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. -+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/wmma.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/warp/mma_tensor_op_wmma.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ ///< Size of the Gemm problem (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Operator describing the tensor operation -+ typename Operator_ = arch::OpMultiplyAdd, -+ /// Number of partitions along K dimension -+ int PartitionsK = 1 -+> -+struct DefaultMmaTensorOpWmma; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for m-by-n-by-kgroup -+template < -+ ///< Shape of one matrix production operation (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A elements -+ typename ElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Element type of C matrix -+ typename ElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Operator describing the tensor operation -+ typename Operator_, -+ /// Number of partitions along K dimension -+ int PartitionsK> -+struct DefaultMmaTensorOpWmma { -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Wmma< -+ InstructionShape_, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Operator_>, -+ cutlass::MatrixShape<1, 1> >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaTensorOpWmma< -+ WarpShape_, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Policy, -+ PartitionsK>; -+}; -+ 
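For orientation, a minimal sketch of how a default warp-level trait such as DefaultMmaTensorOpWmma above is typically instantiated: the trait is parameterized by warp tile shape, instruction shape, element types, and layouts, and its nested Type alias names the concrete warp-level operator. The shapes and types below are illustrative assumptions, not values taken from this patch, and the WMMA path is only available when CUTLASS_ARCH_WMMA_ENABLED is defined.

    // Illustrative sketch only: hypothetical warp/instruction shapes and element types.
    // Include path follows the header added by this patch.
    #include "cutlass/gemm/warp/default_mma_wmma_tensor_op.h"

    using WarpMmaWmma = cutlass::gemm::warp::DefaultMmaTensorOpWmma<
        cutlass::gemm::GemmShape<64, 64, 16>,            // warp tile (M, N, K) - assumed
        cutlass::gemm::GemmShape<16, 16, 16>,            // WMMA instruction shape - assumed
        cutlass::half_t, cutlass::layout::RowMajor,      // A element / layout
        cutlass::half_t, cutlass::layout::ColumnMajor,   // B element / layout
        float, cutlass::layout::RowMajor                 // C element / layout
    >::Type;                                             // Operator_ and PartitionsK take their defaults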
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+#endif -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/layernorm_scale_bias_transform.h b/3rdparty/cutlass/include/cutlass/gemm/warp/layernorm_scale_bias_transform.h -new file mode 100644 -index 0000000..c604ef3 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/layernorm_scale_bias_transform.h -@@ -0,0 +1,140 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level per channel scale+bias+relu before -+ matrix multiply-accumulate operations targeting Tensor Cores. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/arch/mma_sm75.h" -+#include "cutlass/arch/mma_sm80.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct LayernormScaleBiasTransform { -+ -+ using T = typename FragmentActivations::Element; -+ -+ static int const NumActivations = FragmentActivations::kElements; -+ static int const NumVarMean = FragmentVarMean::kElements; -+ static int const NumGammaBeta = FragmentGammaBeta::kElements; -+ static int const MmaElements = 2; -+ // One element has one scale and one bias -+ static int const MmaScaleBiasPair = 2; -+ // 16816 has 2 columns and 2 rows -+ static int const MmaCols = 2; -+ static int const MmaRows = 2; -+ -+ using MmaOperand = Array; -+ using VarMeanOperand = Array<__half2, MmaScaleBiasPair>; -+ using GammaBetaOperand = Array; -+ -+ CUTLASS_DEVICE -+ void transform(MmaOperand &activations, -+ VarMeanOperand const &var_mean, -+ GammaBetaOperand const &gamma_beta) { -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) -+ uint32_t *ptr_activations = reinterpret_cast(&activations); -+ uint32_t const *ptr_var_mean = reinterpret_cast(&var_mean); -+ uint32_t const *ptr_gamma_beta = reinterpret_cast(&gamma_beta); -+ -+ // Apply per channel scale+bias+relu if the data is not a special NaN -+ // (0x7eff). If it is a special NaN (0x7eff), hard code the output to 0. -+ -+ // We assumes the pair of FP16 are either both inbound or both out-of-bound. -+ // It requires C to be an even number. 
-+ asm volatile( -+ "{\n\t" -+ " fma.rn.f16x2 %0, %1, %2, %3;\n" -+ " fma.rn.f16x2 %0, %4, %0, %5;\n" -+ "}\n" -+ : "=r"(ptr_activations[0]) -+ : "r"(ptr_var_mean[0]), "r"(ptr_activations[0]), -+ "r"(ptr_var_mean[1]), -+ "r"(ptr_gamma_beta[0]), "r"(ptr_gamma_beta[1])); -+#else -+ // TODO: write emulation code -+ assert(0); -+#endif -+ } -+ -+ CUTLASS_DEVICE -+ void operator()(FragmentActivations &activations, -+ FragmentVarMean const &var_mean, -+ FragmentGammaBeta const &gamma_beta) { -+ MmaOperand *ptr_activations = reinterpret_cast(&activations); -+ VarMeanOperand const *ptr_var_mean = -+ reinterpret_cast(&var_mean); -+ GammaBetaOperand const *ptr_gamma_beta = -+ reinterpret_cast(&gamma_beta); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < (NumActivations / MmaElements); ++i) { -+ transform(ptr_activations[i], -+ ptr_var_mean[i / (MmaCols * MmaRows) * MmaRows + i % MmaRows], -+ ptr_gamma_beta[(i / MmaScaleBiasPair) % MmaCols]); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma.h -new file mode 100644 -index 0000000..1f3ca94 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma.h -@@ -0,0 +1,60 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Templates exposing architecture support for warp-level multiply-add operations -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Query the number of threads per warp -+template -+struct WarpSize { -+ static int const value = 32; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op.h -new file mode 100644 -index 0000000..7bcf7fe ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op.h -@@ -0,0 +1,1167 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations targeting -+ Tensor Cores. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/arch/mma_sm75.h" -+#include "cutlass/arch/mma_sm80.h" -+#include "cutlass/arch/mma_sm90.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+#include "cutlass/gemm/warp/mma_tensor_op.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template < -+ /// Data type of real & imag members of complex numbers in the SourceFragment -+ typename RealElement, -+ /// Destination fragment required by the mma operation -+ typename DestinationFragment, -+ /// Source fragment holding complex elements -+ typename SourceFragment, -+ /// Number of mma operations performed -+ typename MmaIterations, -+ /// Shape of operand elements -+ typename MmaOperandShape, -+ /// Complex transform on A operand -+ ComplexTransform Transform_, -+ /// Operand A or Operand B -+ Operand Operand_, -+ /// Floating-point rounding style -+ FloatRoundStyle Round_> -+struct UnpackComplexConvertAndPackForMma; -+ -+// Partial specialization for OperandA and Congruous smem layout -+template < -+ typename RealElement, -+ typename DestinationFragment, -+ typename SourceFragment, -+ typename MmaIterations, -+ typename MmaOperandShape, -+ ComplexTransform Transform_, -+ FloatRoundStyle Round_> -+struct UnpackComplexConvertAndPackForMma < -+ RealElement, -+ DestinationFragment, -+ SourceFragment, -+ MmaIterations, -+ MmaOperandShape, -+ Transform_, -+ Operand::kA, -+ Round_> { -+ -+ // -+ // Type definitions -+ // -+ static Operand const kOperand = Operand::kA; -+ static ComplexTransform const kTransform = Transform_; -+ static FloatRoundStyle const kRound = Round_; -+ -+ // Data type of elements in the destination fragment -+ using MmaElement = typename DestinationFragment::Element; -+ -+ // Numeric convertor MmaElement <= RealElement -+ using Converter = NumericConverter; -+ -+ // Operand layout parameters -+ using SourceFragmentLayout = layout::ColumnMajor; -+ static int const kLdm = MmaIterations::kRow * MmaOperandShape::kRow; -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ UnpackComplexConvertAndPackForMma() {} -+ -+ CUTLASS_DEVICE -+ void operator()(DestinationFragment *dest, SourceFragment const &source) { -+ -+ Converter convert_op; -+ SourceFragmentLayout layout(kLdm); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int i=0; i and apply rounding on real and imag parts -+ MmaElement a = convert_op(source[layout(MatrixCoord{row,col})].real()); -+ MmaElement b = convert_op(source[layout(MatrixCoord{row,col})].imag()); -+ -+ // Unpack rounded complex and pack into DestinationFragment for mma operation -+ dest[i][pos] = a; -+ dest[i+MmaIterations::kRow][pos++] = (kTransform == ComplexTransform::kConjugate ? 
-b : b); -+ -+ } -+ } -+ } -+ } -+}; -+ -+// Partial specialization for OperandB and Congruous smem layout -+template < -+ typename RealElement, -+ typename DestinationFragment, -+ typename SourceFragment, -+ typename MmaIterations, -+ typename MmaOperandShape, -+ ComplexTransform Transform_, -+ FloatRoundStyle Round_> -+struct UnpackComplexConvertAndPackForMma < -+ RealElement, -+ DestinationFragment, -+ SourceFragment, -+ MmaIterations, -+ MmaOperandShape, -+ Transform_, -+ Operand::kB, -+ Round_> { -+ -+ // -+ // Type definitions -+ // -+ static Operand const kOperand = Operand::kB; -+ static ComplexTransform const kTransform = Transform_; -+ static FloatRoundStyle const kRound = Round_; -+ -+ // Data type of elements in the destination fragment -+ using MmaElement = typename DestinationFragment::Element; -+ -+ // Numeric convertor MmaElement <= RealElement -+ using Converter = NumericConverter; -+ -+ // Operand layout parameters -+ using SourceFragmentLayout = layout::RowMajor; -+ static int const kLdm = MmaIterations::kColumn * MmaOperandShape::kColumn; -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ UnpackComplexConvertAndPackForMma() {} -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(DestinationFragment *dest, SourceFragment const &source) { -+ -+ Converter convert_op; -+ SourceFragmentLayout layout(kLdm); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int i=0; i apply rounding on real and imag parts -+ MmaElement a = convert_op(source[layout(MatrixCoord{row,col})].real()); -+ MmaElement b = convert_op(source[layout(MatrixCoord{row,col})].imag()); -+ -+ // Unpack rounded complex and pack into DestinationFragment for mma operation -+ dest[i][pos] = a; -+ dest[i+MmaIterations::kColumn][pos++] = (kTransform == ComplexTransform::kConjugate ? -b : b); -+ } -+ } -+ } -+ } -+}; -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename RealElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename RealElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename RealElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Complex transform on A operand -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex transform on B operand -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Do source operands need more than one elements -+ bool GeneralizedOperatorElements = false, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class MmaComplexTensorOp; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex*complex+complex => complex using real-valued TensorOps -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename RealElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename RealElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename RealElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename 
LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB -+> -+class MmaComplexTensorOp< -+ Shape_, -+ complex, -+ LayoutA_, -+ complex, -+ LayoutB_, -+ complex, -+ LayoutC_, -+ Policy_, -+ TransformA, -+ TransformB> { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = complex; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = complex; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = complex; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicyTensorOp) -+ using Policy = Policy_; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Policy::Operator; -+ -+ /// Architecture tag from underlying instruction -+ using ArchTag = typename ArchMmaOperator::ArchTag; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Shape of underlying instruction -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ /// Indicates math operator -+ using MathOperator = arch::OpMultiplyAddComplex; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = TransformB; -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kA, -+ ElementA, -+ LayoutA, -+ MatrixShape, -+ Policy::OpDelta::kRow, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentA = FragmentA; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kB, -+ ElementB, -+ LayoutB, -+ MatrixShape, -+ Policy::OpDelta::kColumn, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed B tile -+ using TransformedFragmentB = FragmentB; -+ -+ static_assert( -+ !(Shape::kM % ArchMmaOperator::Shape::kM) && -+ !(Shape::kN % ArchMmaOperator::Shape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ Shape::kM / ArchMmaOperator::Shape::kM, -+ Shape::kN / ArchMmaOperator::Shape::kN -+ >; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaTensorOpAccumulatorTileIterator< -+ MatrixShape, -+ ElementC, -+ LayoutC, -+ typename ArchMmaOperator::Shape, -+ typename Policy::OpDelta>; -+ -+ /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this -+ /// storage arrangement is to be considered 'planar complex' in the sense that all real-valued -+ /// parts are stored consecutively followed by all imaginary parts. 
This matches the structure -+ /// of Tensor Cores which are always real-valued matrix multiplies. -+ using FragmentC = typename IteratorC::Fragment; -+ -+ static_assert( -+ FragmentC::kElements == 2 * MmaIterations::kCount * ArchMmaOperator::FragmentC::kElements, -+ "Unexpected planar complex fragment length."); -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying real-valued matrix multiply operator (concept: arch::Mma) -+ ArchMmaOperator mma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaComplexTensorOp() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A, -+ FragmentB const &B, -+ FragmentC const &C -+ ) const { -+ -+ // Alias types for underlying real-valued matrix multiply operator -+ using MmaOperandA = typename ArchMmaOperator::FragmentA; -+ using MmaOperandB = typename ArchMmaOperator::FragmentB; -+ using MmaOperandC = typename ArchMmaOperator::FragmentC; -+ -+ static_assert(MmaOperandA::kElements == 1, -+ "This implementation only supports math instructions in which exactly one element is needed for the A operand." -+ "We can geneneralize later."); -+ -+ static_assert(MmaOperandB::kElements == 1, -+ "This implementation only supports math instructions in which exactly one element is needed for the B operand." -+ "We can geneneralize later."); -+ -+ D = C; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; ++m) { -+ -+ // mma(accum.real(), a.real(), b.real(), accum.real()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_A; -+ MmaOperandB operand_B; -+ -+ operand_A[0] = A[m].real(); -+ operand_B[0] = B[n].real(); -+ -+ // Real-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow); -+ -+ mma(*accum, operand_A, operand_B, *accum); -+ } -+ -+ // mma(accum.imag(), a.real(), b.imag(), accum.imag()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_A; -+ MmaOperandB operand_B; -+ -+ operand_A[0] = A[m].real(); -+ operand_B[0] = (kTransformB == ComplexTransform::kConjugate ? -B[n].imag() : B[n].imag()); -+ -+ // Complex-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + MmaIterations::kCount; -+ -+ mma(*accum, operand_A, operand_B, *accum); -+ } -+ -+ // mma(accum.real(), -a.imag(), b.imag(), accum.real()) -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_A; -+ MmaOperandB operand_B; -+ -+ // A imaginary part is intentionally negated -+ operand_A[0] = (kTransformA == ComplexTransform::kConjugate ? A[m].imag() : -A[m].imag()); -+ operand_B[0] = (kTransformB == ComplexTransform::kConjugate ? -B[n].imag() : B[n].imag()); -+ -+ // Real-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow); -+ -+ mma(*accum, operand_A, operand_B, *accum); -+ } -+ -+ // mma(accum.imag(), a.imag(), b.real(), accum.imag()) -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { -+ -+ // Pack operands together. 
This may result in actual MOVs -+ MmaOperandA operand_A; -+ MmaOperandB operand_B; -+ -+ operand_A[0] = (kTransformA == ComplexTransform::kConjugate ? -A[m].imag() : A[m].imag()); -+ operand_B[0] = B[n].real(); -+ -+ // Complex-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + MmaIterations::kCount; -+ -+ mma(*accum, operand_A, operand_B, *accum); -+ } -+ } -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ //TODO: Implement this -+ dst_A = A; -+ dst_B = B; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex*complex+complex => complex: -+// Operands data type: complex -+// Rounding: float -> tfloat32_t (round half_ulp_truncate nearest) -+// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -+// Output data type: complex -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB -+> -+class MmaComplexTensorOp< -+ Shape_, -+ complex, -+ LayoutA_, -+ complex, -+ LayoutB_, -+ complex, -+ LayoutC_, -+ Policy_, -+ TransformA, -+ TransformB> { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of members of complex multiplicand A -+ using RealElementA = float; -+ -+ /// Data type of multiplicand A -+ using ElementA = complex; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of members of complex multiplicand B -+ using RealElementB = float; -+ -+ /// Data type of multiplicand B -+ using ElementB = complex; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of members of complex accumulator matrix C -+ using RealElementC = float; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = complex; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Policy::Operator; -+ -+ /// Shape of underlying instruction -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ /// Underlying arch tag -+ using ArchTag = typename ArchMmaOperator::ArchTag; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Indicates math operator -+ using MathOperator = typename arch::OpMultiplyAddComplex; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = TransformB; -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const 
kThreadCount = 32; -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kA, -+ ElementA, -+ LayoutA, -+ MatrixShape, -+ Policy::OpDelta::kRow, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentA = -+ Array; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kB, -+ ElementB, -+ LayoutB, -+ MatrixShape, -+ Policy::OpDelta::kColumn, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed B tile -+ using TransformedFragmentB = -+ Array; -+ -+ static_assert( -+ !(Shape::kM % ArchMmaOperator::Shape::kM) && -+ !(Shape::kN % ArchMmaOperator::Shape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ /// Number of complex products operations performed (one complex product needs four mma instructions) -+ using MmaIterations = MatrixShape< -+ Shape::kM / ArchMmaOperator::Shape::kM, -+ Shape::kN / ArchMmaOperator::Shape::kN -+ >; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaTensorOpAccumulatorTileIterator< -+ MatrixShape, -+ ElementC, -+ LayoutC, -+ typename ArchMmaOperator::Shape, -+ typename Policy::OpDelta>; -+ -+ /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this -+ /// storage arrangement is to be considered 'planar complex' in the sense that all real-valued -+ /// parts are stored consecutively followed by all imaginary parts. This matches the structure -+ /// of Tensor Cores which are always real-valued matrix multiplies. -+ using FragmentC = typename IteratorC::Fragment; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying real-valued matrix multiply operator (concept: arch::Mma) -+ ArchMmaOperator mma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaComplexTensorOp() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ TransformedFragmentA const &A, -+ TransformedFragmentB const &B, -+ FragmentC const &C -+ ) const { -+ -+ // Alias types for underlying real-valued matrix multiply operator -+ using InstMmaOperandA = typename ArchMmaOperator::FragmentA; -+ using InstMmaOperandB = typename ArchMmaOperator::FragmentB; -+ using MmaOperandC = typename ArchMmaOperator::FragmentC; -+ -+ static_assert(platform::is_same, typename ArchMmaOperator::Shape>::value, -+ "This implementation only supports mma.m16n8k8 math instructions."); -+ -+ static_assert(InstMmaOperandA::kElements == 4, -+ "This implementation only supports math instructions in which exactly four element is needed for the A operand." -+ "We can geneneralize later."); -+ -+ static_assert(InstMmaOperandB::kElements == 2, -+ "This implementation only supports math instructions in which exactly two element is needed for the B operand." 
-+ "We can geneneralize later."); -+ -+ // Instruction Operands A & B holding real part followed by imaginary part for mma operations -+ InstMmaOperandA const *operand_A = reinterpret_cast(&A); -+ InstMmaOperandB const *operand_B = reinterpret_cast(&B); -+ -+ // -+ // Accumulate in place -+ // -+ D = C; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; ++m) { -+ -+ // mma(accum.real(), a.real(), b.real(), accum.real()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // Real-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow); -+ -+ mma(*accum, operand_A[m], operand_B[n], *accum); -+ } -+ -+ // mma(accum.imag(), a.real(), b.imag(), accum.imag()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { -+ -+ // Complex-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + MmaIterations::kCount; -+ -+ mma(*accum, operand_A[m], operand_B[n+MmaIterations::kColumn], *accum); -+ } -+ -+ // mma(accum.real(), a.imag(), -b.imag(), accum.real()) -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // negate OperandB to accumulate -(a.imag()*b.imag()) -+ // negating OperandB emits less instrucitons than negating OperandA as OperandB has less elements -+ negate negate_op; -+ -+ // Real-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow); -+ -+ mma(*accum, operand_A[m+MmaIterations::kRow], negate_op(operand_B[n+MmaIterations::kColumn]), *accum); -+ } -+ -+ // mma(accum.imag(), a.imag(), b.real(), accum.imag()) -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { -+ -+ // Complex-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + MmaIterations::kCount; -+ -+ mma(*accum, operand_A[m+MmaIterations::kRow], operand_B[n], *accum); -+ } -+ } -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ // Alias types for underlying real-valued matrix multiply operator -+ using InstMmaOperandA = typename ArchMmaOperator::FragmentA; -+ using InstMmaOperandB = typename ArchMmaOperator::FragmentB; -+ -+ // -+ // Define conversions from source type to instruction operands' type -+ // -+ -+ FloatRoundStyle const kRoundA = FloatRoundStyle::round_half_ulp_trunc_dntz; -+ FloatRoundStyle const kRoundB = FloatRoundStyle::round_half_ulp_trunc_dntz; -+ -+ detail::UnpackComplexConvertAndPackForMma < -+ RealElementA, -+ InstMmaOperandA, -+ FragmentA, -+ MmaIterations, -+ MatrixShape<2, 2>, -+ kTransformA, -+ Operand::kA, -+ kRoundA> convert_A; -+ -+ detail::UnpackComplexConvertAndPackForMma < -+ RealElementB, -+ InstMmaOperandB, -+ FragmentB, -+ MmaIterations, -+ MatrixShape<2, 1>, -+ kTransformB, -+ Operand::kB, -+ kRoundB> convert_B; -+ -+ // Convert Fragment[A|B] holding complex to InstMmaOperand[A|B] holding InstMmaOperand[A|B]::Element -+ convert_A(reinterpret_cast(&dst_A), A); -+ convert_B(reinterpret_cast(&dst_B), B); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization for complex*complex+complex => complex: -+// Operands data type: complex -+// Math instruction: mma.sync.aligned.m16n8k4.f64.f64.f64.f64 -+// Output data type: 
complex -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB -+> -+class MmaComplexTensorOp< -+ Shape_, -+ complex, -+ LayoutA_, -+ complex, -+ LayoutB_, -+ complex, -+ LayoutC_, -+ Policy_, -+ TransformA, -+ TransformB, -+ true> { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of members of complex multiplicand A -+ using RealElementA = double; -+ -+ /// Data type of multiplicand A -+ using ElementA = complex; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of members of complex multiplicand B -+ using RealElementB = double; -+ -+ /// Data type of multiplicand B -+ using ElementB = complex; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of members of complex accumulator matrix C -+ using RealElementC = double; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = complex; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicyTensorOp) -+ using Policy = Policy_; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Policy::Operator; -+ -+ /// Shape of underlying instruction -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ /// Underlying arch tag -+ using ArchTag = typename ArchMmaOperator::ArchTag; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Indicates math operator -+ using MathOperator = typename arch::OpMultiplyAddComplex; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = TransformB; -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kA, -+ ElementA, -+ LayoutA, -+ MatrixShape, -+ Policy::OpDelta::kRow, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentA = FragmentA; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kB, -+ ElementB, -+ LayoutB, -+ MatrixShape, -+ Policy::OpDelta::kColumn, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed B tile -+ using TransformedFragmentB = FragmentB; -+ -+ static_assert( -+ !(Shape::kM % ArchMmaOperator::Shape::kM) && -+ !(Shape::kN % ArchMmaOperator::Shape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ /// Number of mma operations performed -+ using 
MmaIterations = MatrixShape< -+ Shape::kM / ArchMmaOperator::Shape::kM, -+ Shape::kN / ArchMmaOperator::Shape::kN -+ >; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaTensorOpAccumulatorTileIterator< -+ MatrixShape, -+ ElementC, -+ LayoutC, -+ typename ArchMmaOperator::Shape, -+ typename Policy::OpDelta>; -+ -+ /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this -+ /// storage arrangement is to be considered 'planar complex' in the sense that all real-valued -+ /// parts are stored consecutively followed by all imaginary parts. This matches the structure -+ /// of Tensor Cores which are always real-valued matrix multiplies. -+ using FragmentC = typename IteratorC::Fragment; -+ -+ static_assert( -+ FragmentC::kElements == 2 * MmaIterations::kCount * ArchMmaOperator::FragmentC::kElements, -+ "Unexpected planar complex fragment length."); -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying real-valued matrix multiply operator (concept: arch::Mma) -+ ArchMmaOperator mma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaComplexTensorOp() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A, -+ FragmentB const &B, -+ FragmentC const &C -+ ) const { -+ -+ // Alias types for underlying real-valued matrix multiply operator -+ using MmaOperandA = typename ArchMmaOperator::FragmentA; -+ using MmaOperandB = typename ArchMmaOperator::FragmentB; -+ using MmaOperandC = typename ArchMmaOperator::FragmentC; -+ -+ D = C; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; ++m) { -+ -+ // mma(accum.real(), a.real(), b.real(), accum.real()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_A; -+ MmaOperandB operand_B; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mk = 0; mk < MmaOperandA::kElements; ++mk) -+ operand_A[mk] = A[m*MmaOperandA::kElements + mk].real(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int nk = 0; nk < MmaOperandB::kElements; ++nk) -+ operand_B[nk] = B[n*MmaOperandB::kElements + nk].real(); -+ -+ // Real-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow); -+ -+ mma(*accum, operand_A, operand_B, *accum); -+ } -+ -+ // mma(accum.imag(), a.real(), b.imag(), accum.imag()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_A; -+ MmaOperandB operand_B; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mk = 0; mk < MmaOperandA::kElements; ++mk) -+ operand_A[mk] = A[m*MmaOperandA::kElements + mk].real(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int nk = 0; nk < MmaOperandB::kElements; ++nk) -+ operand_B[nk] = (kTransformB == ComplexTransform::kConjugate ? -+ -B[n*MmaOperandB::kElements + nk].imag() : B[n*MmaOperandB::kElements + nk].imag()); -+ -+ // Complex-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + MmaIterations::kCount; -+ -+ mma(*accum, operand_A, operand_B, *accum); -+ } -+ -+ // mma(accum.real(), -a.imag(), b.imag(), accum.real()) -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // Pack operands together. 
This may result in actual MOVs -+ MmaOperandA operand_A; -+ MmaOperandB operand_B; -+ -+ // A imaginary part is intentionally negated -+ CUTLASS_PRAGMA_UNROLL -+ for (int mk = 0; mk < MmaOperandA::kElements; ++mk) -+ operand_A[mk] = (kTransformA == ComplexTransform::kConjugate ? -+ A[m*MmaOperandA::kElements + mk].imag() : -A[m*MmaOperandA::kElements + mk].imag()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int nk = 0; nk < MmaOperandB::kElements; ++nk) -+ operand_B[nk] = (kTransformB == ComplexTransform::kConjugate ? -+ -B[n*MmaOperandB::kElements + nk].imag() : B[n*MmaOperandB::kElements + nk].imag()); -+ -+ // Real-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow); -+ -+ mma(*accum, operand_A, operand_B, *accum); -+ } -+ -+ // mma(accum.imag(), a.imag(), b.real(), accum.imag()) -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_A; -+ MmaOperandB operand_B; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mk = 0; mk < MmaOperandA::kElements; ++mk) -+ operand_A[mk] = (kTransformA == ComplexTransform::kConjugate ? -+ -A[m*MmaOperandA::kElements + mk].imag() : A[m*MmaOperandA::kElements + mk].imag()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int nk = 0; nk < MmaOperandB::kElements; ++nk) -+ operand_B[nk] = B[n*MmaOperandB::kElements + nk].real(); -+ -+ // Complex-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + MmaIterations::kCount; -+ -+ mma(*accum, operand_A, operand_B, *accum); -+ } -+ } -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ dst_A = A; -+ dst_B = B; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// TODO - partial specializations of real*complex and complex*real -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h -new file mode 100644 -index 0000000..4db983d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h -@@ -0,0 +1,663 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations targeting -+ Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/arch/mma_sm75.h" -+#include "cutlass/arch/mma_sm80.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+#include "cutlass/gemm/warp/mma_tensor_op.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+namespace detail { -+ -+template < -+ /// Data type of real & imag members of complex numbers in the SourceFragment -+ typename RealElement, -+ /// Destination fragment required by the mma operation -+ typename DestinationFragment, -+ /// Source fragment holding complex elements -+ typename SourceFragment, -+ /// Number of mma operations performed -+ typename MmaIterations, -+ /// Shape of operand elements -+ typename MmaOperandShape, -+ /// Complex transform on A operand -+ ComplexTransform Transform_, -+ /// Operand A or Operand B -+ Operand Operand_, -+ /// Floating-point rounding style for big part -+ FloatRoundStyle RoundBig_, -+ /// Floating-point rounding style for small part -+ FloatRoundStyle RoundSmall_> -+struct UnpackComplexConvertAndPackForMmaFastF32; -+ -+// Partial specialization for OperandA and Congruous smem layout -+template < -+ typename RealElement, -+ typename DestinationFragment, -+ typename SourceFragment, -+ typename MmaIterations, -+ typename MmaOperandShape, -+ ComplexTransform Transform_, -+ FloatRoundStyle RoundBig_, -+ FloatRoundStyle RoundSmall_> -+struct UnpackComplexConvertAndPackForMmaFastF32 < -+ RealElement, -+ DestinationFragment, -+ SourceFragment, -+ MmaIterations, -+ MmaOperandShape, -+ Transform_, -+ Operand::kA, -+ RoundBig_, -+ RoundSmall_> { -+ -+ // -+ // Type definitions -+ // -+ static Operand const kOperand = Operand::kA; -+ static 
ComplexTransform const kTransform = Transform_; -+ static FloatRoundStyle const kRoundBig = RoundBig_; -+ static FloatRoundStyle const kRoundSmall = RoundSmall_; -+ -+ // Data type of elements in the destination fragment -+ using MmaElement = typename DestinationFragment::Element; -+ -+ // Numeric convertor MmaElementBig, MmaElementSmall <= RealElement -+ using Converter = NumericConverterFastF32; -+ -+ // Operand layout parameters -+ using SourceFragmentLayout = layout::ColumnMajor; -+ static int const kLdm = MmaIterations::kRow * MmaOperandShape::kRow; -+ -+ // BigSmall Fragment holding two TF32 elements (big, small) for every float -+ using BigSmallFragment = Array; -+ -+ /// Index in fargments for the big and small part -+ static int const kBigIndex = 0; -+ static int const kSmallIndex = 1; -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ UnpackComplexConvertAndPackForMmaFastF32() {} -+ -+ CUTLASS_DEVICE -+ void operator()(DestinationFragment *dest, SourceFragment const &source) { -+ -+ Converter convert_op; -+ SourceFragmentLayout layout(kLdm); -+ -+ DestinationFragment *dest_big_ = reinterpret_cast(dest); -+ DestinationFragment *dest_small_ = reinterpret_cast(&dest[MmaIterations::kRow * 2]); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int i=0; i and apply rounding on real and imag parts -+ BigSmallFragment a = convert_op(source[layout(MatrixCoord{row,col})].real()); -+ BigSmallFragment b = convert_op(source[layout(MatrixCoord{row,col})].imag()); -+ -+ // Unpack rounded complex and pack into DestinationFragment for mma operation -+ dest_big_[i][pos] = a[kBigIndex]; -+ dest_big_[i+MmaIterations::kRow][pos] = (kTransform == ComplexTransform::kConjugate ? -b[kBigIndex] : b[kBigIndex]); -+ -+ // Unpack rounded complex and pack into DestinationFragment for mma operation -+ dest_small_[i][pos] = a[kSmallIndex]; -+ dest_small_[i+MmaIterations::kRow][pos] = (kTransform == ComplexTransform::kConjugate ? 
-b[kSmallIndex] : b[kSmallIndex]); -+ -+ // Next position -+ pos++; -+ } -+ } -+ } -+ } -+}; -+ -+// Partial specialization for OperandB and Congruous smem layout -+template < -+ typename RealElement, -+ typename DestinationFragment, -+ typename SourceFragment, -+ typename MmaIterations, -+ typename MmaOperandShape, -+ ComplexTransform Transform_, -+ FloatRoundStyle RoundBig_, -+ FloatRoundStyle RoundSmall_> -+struct UnpackComplexConvertAndPackForMmaFastF32 < -+ RealElement, -+ DestinationFragment, -+ SourceFragment, -+ MmaIterations, -+ MmaOperandShape, -+ Transform_, -+ Operand::kB, -+ RoundBig_, -+ RoundSmall_> { -+ -+ // -+ // Type definitions -+ // -+ static Operand const kOperand = Operand::kB; -+ static ComplexTransform const kTransform = Transform_; -+ static FloatRoundStyle const kRoundBig = RoundBig_; -+ static FloatRoundStyle const kRoundSmall = RoundSmall_; -+ -+ // Data type of elements in the destination fragment -+ using MmaElement = typename DestinationFragment::Element; -+ -+ // Numeric convertor MmaElementBig, MmaElementSmall <= RealElement -+ using Converter = NumericConverterFastF32; -+ -+ // Operand layout parameters -+ using SourceFragmentLayout = layout::RowMajor; -+ static int const kLdm = MmaIterations::kColumn * MmaOperandShape::kColumn; -+ -+ // BigSmall Fragment holding two TF32 elements (big, small) for every float -+ using BigSmallFragment = Array; -+ -+ /// Index in fargments for the big and small part -+ static int const kBigIndex = 0; -+ static int const kSmallIndex = 1; -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ UnpackComplexConvertAndPackForMmaFastF32() {} -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(DestinationFragment *dest, SourceFragment const &source) { -+ -+ Converter convert_op; -+ SourceFragmentLayout layout(kLdm); -+ -+ DestinationFragment *dest_big_ = reinterpret_cast(dest); -+ DestinationFragment *dest_small_ = reinterpret_cast(&dest[MmaIterations::kColumn * 2]); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int i=0; i apply rounding on real and imag parts -+ BigSmallFragment a = convert_op(source[layout(MatrixCoord{row,col})].real()); -+ BigSmallFragment b = convert_op(source[layout(MatrixCoord{row,col})].imag()); -+ -+ // Unpack rounded complex and pack into DestinationFragment for mma operation -+ dest_big_[i][pos] = a[kBigIndex]; -+ dest_big_[i+MmaIterations::kColumn][pos] = (kTransform == ComplexTransform::kConjugate ? -b[kBigIndex] : b[kBigIndex]); -+ -+ // Unpack rounded complex and pack into DestinationFragment for mma operation -+ dest_small_[i][pos] = a[kSmallIndex]; -+ dest_small_[i+MmaIterations::kColumn][pos] = (kTransform == ComplexTransform::kConjugate ? 
-b[kSmallIndex] : b[kSmallIndex]); -+ -+ // next position -+ pos++; -+ } -+ } -+ } -+ } -+}; -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename RealElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename RealElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename RealElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Complex transform on A operand -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex transform on B operand -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class MmaComplexTensorOpFastF32; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex*complex+complex => complex: -+// Operands data type: complex -+// Rounding: float -> tfloat32_t (round half_ulp_truncate nearest) -+// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -+// Output data type: complex -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB, -+ /// Used for partial specialization -+ typename Enable -+> -+class MmaComplexTensorOpFastF32< -+ Shape_, -+ complex, -+ LayoutA_, -+ complex, -+ LayoutB_, -+ complex, -+ LayoutC_, -+ Policy_, -+ TransformA, -+ TransformB, -+ Enable> { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of members of complex multiplicand A -+ using RealElementA = float; -+ -+ /// Data type of multiplicand A -+ using ElementA = complex; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of members of complex multiplicand B -+ using RealElementB = float; -+ -+ /// Data type of multiplicand B -+ using ElementB = complex; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of members of complex accumulator matrix C -+ using RealElementC = float; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = complex; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Policy::Operator; -+ -+ /// Shape of underlying instruction -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ /// Underlying arch tag -+ using ArchTag = typename ArchMmaOperator::ArchTag; -+ 
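// --- Editorial sketch (not part of the patched file) ---------------------------------
// The OpMultiplyAddComplexFastF32 path below emulates an fp32 complex product with
// TF32 tensor-core instructions by splitting every float into a "big" and a "small"
// TF32 value. A minimal host-side sketch of that decomposition follows; round_to_tf32
// simply truncates the mantissa to 10 bits, which only approximates the rounding modes
// (round_toward_zero / round_half_ulp_truncate) selected by ComplexFastF32.
#include <cstdint>
#include <cstring>
#include <cstdio>

static float round_to_tf32(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFFE000u;                 // keep sign, exponent and the top 10 mantissa bits
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

int main() {
  float a = 1.2345678f, b = 7.6543213f;

  float a_big = round_to_tf32(a), a_small = round_to_tf32(a - a_big);
  float b_big = round_to_tf32(b), b_small = round_to_tf32(b - b_big);

  // k3xTF32: three TF32 products recover most of the fp32 result; the small*small
  // term is dropped. k4xTF32 adds it back for slightly better accuracy at extra cost.
  float approx_3x = a_big * b_big + a_big * b_small + a_small * b_big;
  float approx_1x = a_big * b_big;

  std::printf("fp32 %.9f  3xTF32 %.9f  1xTF32 %.9f\n", a * b, approx_3x, approx_1x);
  return 0;
}
// --------------------------------------------------------------------------------------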
-+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Indicates math operator -+ using MathOperator = arch::OpMultiplyAddComplexFastF32; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = TransformB; -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+ -+ /// Tune F32 to TF32 big small conversion for complex operation -+ /// Different combination of big small conversin can cause different tradeoff -+ /// between speed and accuracy. Generally, use round_half_ulp_truncate can -+ /// improve the performance but hur the accuracy. -+ using ComplexFastF32 = FastF32 < -+ FloatRoundStyle::round_toward_zero, // kRoundBigA -+ FloatRoundStyle::round_half_ulp_truncate, // kRoundSmallA -+ FloatRoundStyle::round_toward_zero, // kRoundBigB -+ FloatRoundStyle::round_half_ulp_truncate, // kRoundSmallB -+ TensorFloat32Op::k3xTF32 // Number of TF32 operations -+ >; -+ -+ /// Index in fargments for the big and small part -+ static int const kBigIndex = 0; -+ static int const kSmallIndex = 1; -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kA, -+ ElementA, -+ LayoutA, -+ MatrixShape, -+ Policy::OpDelta::kRow, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ // (4 times the original FragmentA::kElements) -+ // (real_big), (imag_big), (real_small), (imag_small) -+ using TransformedFragmentA = Array; -+ -+ // Fragment bisecting big and small sections -+ // (real_big, imag_big), (real_small, imag_small) -+ using AccessTypeFragmentA = Array; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kB, -+ ElementB, -+ LayoutB, -+ MatrixShape, -+ Policy::OpDelta::kColumn, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed B tile -+ // (4 times the original FragmentB::kElements) -+ // (real_big), (imag_big), (real_small), (imag_small) -+ using TransformedFragmentB = Array; -+ -+ // Fragment bisecting big and small sections -+ // (real_big, imag_big), (real_small, imag_small) -+ using AccessTypeFragmentB = Array; -+ -+ static_assert( -+ !(Shape::kM % ArchMmaOperator::Shape::kM) && -+ !(Shape::kN % ArchMmaOperator::Shape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ /// Number of complex products operations performed (one complex product needs four mma instructions) -+ using MmaIterations = MatrixShape< -+ Shape::kM / ArchMmaOperator::Shape::kM, -+ Shape::kN / ArchMmaOperator::Shape::kN -+ >; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaTensorOpAccumulatorTileIterator< -+ MatrixShape, -+ ElementC, -+ LayoutC, -+ typename ArchMmaOperator::Shape, -+ typename Policy::OpDelta>; -+ -+ /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this -+ /// storage arrangement is to be considered 'planar complex' in the sense that all real-valued -+ /// parts are stored consecutively followed by all imaginary parts. This matches the structure -+ /// of Tensor Cores which are always real-valued matrix multiplies. 
-+ using FragmentC = typename IteratorC::Fragment; -+ -+ // -+ // Alias types for underlying real-valued matrix multiply operator -+ // -+ using InstMmaOperandA = typename ArchMmaOperator::FragmentA; -+ using InstMmaOperandB = typename ArchMmaOperator::FragmentB; -+ using MmaOperandC = typename ArchMmaOperator::FragmentC; -+ -+ static_assert(platform::is_same, typename ArchMmaOperator::Shape>::value, -+ "This implementation only supports mma.m16n8k8 math instructions."); -+ -+ static_assert(InstMmaOperandA::kElements == 4, -+ "This implementation only supports math instructions in which exactly four element is needed for the A operand." -+ "We can geneneralize later."); -+ -+ static_assert(InstMmaOperandB::kElements == 2, -+ "This implementation only supports math instructions in which exactly two element is needed for the B operand." -+ "We can geneneralize later."); -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying real-valued matrix multiply operator (concept: arch::Mma) -+ ArchMmaOperator mma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaComplexTensorOpFastF32() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ TransformedFragmentA const &A, -+ TransformedFragmentB const &B, -+ FragmentC const &C -+ ) const { -+ -+ AccessTypeFragmentA const *complex_A = reinterpret_cast(&A); -+ AccessTypeFragmentB const *complex_B = reinterpret_cast(&B); -+ -+ // -+ // Accumulate in place -+ // -+ D = C; -+ -+ -+ complex_mma_operator(D, complex_A[kSmallIndex], complex_B[kBigIndex], D); -+ -+ complex_mma_operator(D, complex_A[kBigIndex], complex_B[kSmallIndex], D); -+ -+ complex_mma_operator(D, complex_A[kBigIndex], complex_B[kBigIndex], D); -+ -+ if (ComplexFastF32::kPrecision == TensorFloat32Op::k4xTF32) -+ complex_mma_operator(D, complex_A[kSmallIndex], complex_B[kSmallIndex], D); -+ } -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void complex_mma_operator( -+ FragmentC &D, -+ AccessTypeFragmentA const &complex_A, -+ AccessTypeFragmentB const &complex_B, -+ FragmentC const &C -+ ) const { -+ -+ // Instruction Operands A & B holding real part followed by imaginary part for mma operations -+ InstMmaOperandA const *operand_A = reinterpret_cast(&complex_A); -+ InstMmaOperandB const *operand_B = reinterpret_cast(&complex_B); -+ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; ++m) { -+ -+ // mma(accum.real(), a.real(), b.real(), accum.real()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // Real-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow); -+ -+ mma(*accum, operand_A[m], operand_B[n], *accum); -+ } -+ -+ // mma(accum.imag(), a.real(), b.imag(), accum.imag()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { -+ -+ // Complex-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + MmaIterations::kCount; -+ -+ mma(*accum, operand_A[m], operand_B[n+MmaIterations::kColumn], *accum); -+ } -+ -+ // mma(accum.real(), a.imag(), -b.imag(), accum.real()) -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // negate OperandB to accumulate -(a.imag()*b.imag()) -+ // negating OperandB emits less instrucitons than negating OperandA as OperandB has less elements -+ negate negate_op; -+ -+ // Real-valued 
accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow); -+ -+ mma(*accum, operand_A[m+MmaIterations::kRow], negate_op(operand_B[n+MmaIterations::kColumn]), *accum); -+ } -+ -+ // mma(accum.imag(), a.imag(), b.real(), accum.imag()) -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { -+ -+ // Complex-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + MmaIterations::kCount; -+ -+ mma(*accum, operand_A[m+MmaIterations::kRow], operand_B[n], *accum); -+ } -+ } -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ -+ detail::UnpackComplexConvertAndPackForMmaFastF32 < -+ RealElementA, -+ InstMmaOperandA, -+ FragmentA, -+ MmaIterations, -+ MatrixShape<2, 2>, -+ kTransformA, -+ Operand::kA, -+ ComplexFastF32::kRoundBigA, -+ ComplexFastF32::kRoundSmallA> convert_A; -+ -+ detail::UnpackComplexConvertAndPackForMmaFastF32 < -+ RealElementB, -+ InstMmaOperandB, -+ FragmentB, -+ MmaIterations, -+ MatrixShape<2, 1>, -+ kTransformB, -+ Operand::kB, -+ ComplexFastF32::kRoundBigB, -+ ComplexFastF32::kRoundSmallB> convert_B; -+ -+ // Convert Fragment[A|B] holding complex to InstMmaOperand[A|B] holding InstMmaOperand[A|B]::Element -+ convert_A(reinterpret_cast(&dst_A), A); -+ convert_B(reinterpret_cast(&dst_B), B); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h -new file mode 100644 -index 0000000..d872012 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h -@@ -0,0 +1,2493 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+ -+#include "cutlass/platform/platform.h" -+#include "cutlass/fast_math.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for loading 128b vectors of 128b elements. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::TensorOpMultiplicandCongruous128b, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ static_assert(!(Shape::kContiguous % 8) && !(Shape::kStrided % 4), "Divisibility."); -+ -+ static_assert(sizeof_bits::value == 128, "This is specialized for 128b accesses."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::TensorOpMultiplicandCongruous128b; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// TensorRef type for 
loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Load two elements per access -+ static int const kElementsPerAccess = 1; -+ -+ /// Policy defining internal details of tile iterator -+ struct Policy { -+ -+ /// Shape of one access -+ using Delta = layout::PitchLinearShape<8, 4>; -+ -+ /// Number of iterations to load -+ using Iterations = layout::PitchLinearShape< -+ Shape::kContiguous / Delta::kContiguous, -+ InstructionShape::kStrided / Delta::kStrided -+ >; -+ }; -+ -+private: -+ -+ /// Not working on this feature at the moment. -+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Pointer type used for accesses -+ using AccessType = AlignedArray; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = -+ Array; -+ -+private: -+ -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0) { -+ -+ int quad_pair = lane_id / 8; -+ int quad = lane_id / 4; -+ int lane = lane_id % 4; -+ -+ int row = (quad & 1) * 4 + (lane ^ quad_pair); -+ -+ byte_offset_ = (row + quad_pair * stride_) * sizeof(AccessType); -+ -+ pointer_= reinterpret_cast(ref.data()); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ pointer_ += offset; -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ int offset = -+ (tile_offset.contiguous() * Shape::kContiguous) + -+ (tile_offset.strided() * InstructionShape::kStrided * stride_); -+ -+ add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ pointer_ += stride_ * InstructionShape::kStrided; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. 
-+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ -+ AccessType *fetch_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::Iterations::kStrided; ++s) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { -+ -+ int access_idx = c + s * Policy::Iterations::kContiguous; -+ -+ AccessType const *source_ptr = pointer_ + -+ Policy::Delta::kContiguous * c + -+ Policy::Delta::kStrided * s * stride_; -+ -+ char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; -+ -+ AccessType const *source = reinterpret_cast(source_byte_ptr); -+ -+ fetch_ptr[access_idx] = *source; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = -+ tile_offset.contiguous() * Shape::kContiguous + -+ tile_offset.strided() * InstructionShape::kStrided * stride_; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. 
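// --- Editorial sketch (not part of the patched file) ---------------------------------
// The Congruous128b iterator constructor above maps each of the 32 lanes to a swizzled
// shared-memory offset: row = (quad & 1) * 4 + (lane ^ quad_pair). The standalone loop
// below tabulates that arithmetic so the XOR access pattern is easy to inspect;
// 'stride' stands in for stride_ and its value here is purely illustrative.
#include <cstdio>

int main() {
  int const stride = 8;  // example stride, in units of 128b AccessType elements
  for (int lane_id = 0; lane_id < 32; ++lane_id) {
    int quad_pair = lane_id / 8;
    int quad = lane_id / 4;
    int lane = lane_id % 4;
    int row = (quad & 1) * 4 + (lane ^ quad_pair);
    int offset = row + quad_pair * stride;  // the iterator scales this by sizeof(AccessType)
    std::printf("lane %2d -> row %d, quad_pair %d, element offset %2d\n",
                lane_id, row, quad_pair, offset);
  }
  return 0;
}
// --------------------------------------------------------------------------------------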
-+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicandCongruous128b, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// 
Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(layout::PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(layout::PitchLinearCoord(-tile_offset.column(), -tile_offset.row())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.strided(), tile_offset.contiguous()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. 
-+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicandCongruous128b, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.row(), 
tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(layout::PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(layout::PitchLinearCoord(-tile_offset.row(), -tile_offset.column())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.contiguous(), tile_offset.strided()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. 
-+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// -+/// Partial specialization for complex -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of underlying field of reals. -+ typename RealElement, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions, concept: MatrixShape) -+ typename OpDelta_> -+class MmaTensorOpAccumulatorTileIterator< -+ Shape_, complex, cutlass::layout::RowMajor, InstructionShape_, OpDelta_> { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kC; -+ -+ /// Element type -+ using Element = complex; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ using OpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kRow % InstructionShape::kM) && -+ !(Shape::kColumn % InstructionShape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ static_assert(platform::is_same::value, -+ "Layouts must be defined for logical MatrixCoord coordinate space."); -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape; -+ }; -+ -+private: -+ -+ // Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over the entire -+ // shape, with each quad mapped to one row and each thread mapped to 1/4 of the elements -+ // of that row. The accumulators within one row are assumed to be consecutive. -+ static int const kElementsPerAccess = InstructionShape::kN / 4; -+ static int const kRowsPerTile = 8; -+ static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile. It is assumed that the accumulators -+ /// are stored in a planar complex arrangement with the real parts as entirely contiguous -+ /// followed by the imaginary parts. 
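// --- Editorial sketch (not part of the patched file) ---------------------------------
// The 'planar complex' arrangement described above keeps a thread's N real accumulator
// parts contiguous, followed by the N imaginary parts, mirroring kRealIndex and
// kImaginaryIndex defined just below. A minimal standalone model of that layout:
#include <array>
#include <complex>

template <typename Real, int N>
struct PlanarComplexFragment {
  static constexpr int kRealIndex = 0;       // real block starts at the beginning
  static constexpr int kImaginaryIndex = N;  // imaginary block starts after all reals

  std::array<Real, 2 * N> storage{};

  std::complex<Real> at(int i) const {
    return {storage[kRealIndex + i], storage[kImaginaryIndex + i]};
  }

  void set(int i, std::complex<Real> z) {
    storage[kRealIndex + i] = z.real();
    storage[kImaginaryIndex + i] = z.imag();
  }
};

// Usage: PlanarComplexFragment<float, 4> frag; frag.set(2, {1.0f, -3.0f});
// frag.at(2) reassembles the complex value from the two planes.
// --------------------------------------------------------------------------------------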
-+ using Fragment = Array; -+ -+ static int const kRealIndex = 0; -+ static int const kImaginaryIndex = Shape::kCount / kThreads; -+ -+private: -+ -+ /// Reference to output tensor -+ TensorRef ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ ref_(ref) { -+ -+ int quad = (lane_id >> 2); -+ int lane_in_quad = (lane_id & 3); -+ -+ MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess); -+ -+ ref_.add_coord_offset(lane_offset); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator++() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator--() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. 
-+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index pointer_offset) const { ///< loads a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess * -+ (mma_n * Policy::MmaIterations::kRow + mma_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < kAccumulatorRows; ++row) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < kElementsPerAccess; ++col) { -+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + -+ row * kRowsPerTile; -+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; -+ -+ Element z = offset_ref.at({accum_m, accum_n}); -+ -+ frag[mma_accum_start + row * kElementsPerAccess + col + kRealIndex] = z.real(); -+ frag[mma_accum_start + row * kElementsPerAccess + col + kImaginaryIndex] = z.imag(); -+ } -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index byte_offset) const { ///< loads a tile with a linear offset -+ -+ load_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles -+ -+ load(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. 
-+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles -+ Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset -+ -+ load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index pointer_offset) const { ///< store a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess * -+ (mma_n * Policy::MmaIterations::kRow + mma_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < kAccumulatorRows; ++row) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < kElementsPerAccess; ++col) { -+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + -+ row * kRowsPerTile; -+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; -+ int idx = mma_accum_start + row * kElementsPerAccess + col; -+ -+ Element z(frag[kRealIndex + idx], frag[kImaginaryIndex + idx]); -+ -+ offset_ref.at({accum_m, accum_n}) = z; -+ } -+ } -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_byte_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index byte_offset) const { ///< store a tile with a linear offset -+ -+ store_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Stores a fragment to memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ Fragment &frag, ///< fragment to store to the tensor -+ TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles -+ -+ store(frag, tile_offset, 0); -+ } -+ -+ /// Stores a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ /// fragment to store to the tensor -+ Fragment const &frag, -+ /// stores a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// stores a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for loading 128b vectors of 128b elements. 
-+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::TensorOpMultiplicandCrosswise128x4, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ static_assert(!(Shape::kContiguous % 4) && !(Shape::kStrided % 8), "Divisibility."); -+ -+ static_assert(sizeof_bits::value == 128, "This is specialized for 128b accesses."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::TensorOpMultiplicandCrosswise128x4; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Load two elements per access -+ static int const kElementsPerAccess = 1; -+ -+ /// Policy defining internal details of tile iterator -+ struct Policy { -+ -+ /// Shape of one access -+ using Delta = layout::PitchLinearShape<4, 8>; -+ -+ /// Number of iterations to load -+ using Iterations = layout::PitchLinearShape< -+ InstructionShape::kContiguous / Delta::kContiguous, -+ Shape::kStrided / Delta::kStrided -+ >; -+ }; -+ -+private: -+ -+ /// Not working on this feature at the moment. 
-+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Pointer type used for accesses -+ using AccessType = AlignedArray; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = -+ Array; -+ -+private: -+ -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0) { -+ -+ int quad = lane_id / 4; -+ int liq = lane_id % 4; -+ -+ int c = liq + (quad & 1) * 4; -+ int s = (quad / 2); -+ -+ byte_offset_ = (c + s * stride_) * sizeof(AccessType); -+ -+ pointer_= reinterpret_cast(ref.data()); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ pointer_ += offset; -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ // Compute the offset in units of elements. Note, the external coordinate system is -+ // approximately transposed with respect to the tiled internal structure -+ int offset = -+ (tile_offset.contiguous() * InstructionShape::kContiguous) * stride_ + -+ (tile_offset.strided() * Shape::kStrided); -+ -+ add_pointer_offset(offset); -+ -+ byte_offset_ ^= (tile_offset.contiguous() & 1) * 4 * sizeof(AccessType); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ pointer_ += stride_ * InstructionShape::kContiguous; -+ -+ byte_offset_ ^= 4 * sizeof(AccessType); -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. 
-+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ -+ AccessType *fetch_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::Iterations::kStrided; ++s) { -+ -+ int access_idx = s + c * Policy::Iterations::kStrided; -+ -+ AccessType const *source_ptr = pointer_ + -+ Policy::Delta::kContiguous * c * stride_ + -+ Policy::Delta::kStrided * s; -+ -+ char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; -+ -+ AccessType const *source = reinterpret_cast(source_byte_ptr); -+ -+ fetch_ptr[access_idx] = *source; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = -+ tile_offset.contiguous() * InstructionShape::kContiguous * stride_ + -+ tile_offset.strided() * Shape::kStrided; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. 
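// Standalone illustrative sketch (not part of the patch above): the lane -> (contiguous,
// strided) starting position computed by the Crosswise128x4 iterator constructor, and the
// effect of the XOR applied by operator++. The strided contribution (s * stride_) is
// omitted because stride_ is a runtime value; assuming the stride is a multiple of eight
// accesses, XOR-ing the byte offset with 4 * sizeof(AccessType) toggles the contiguous
// access index between c and c ^ 4, so successive k-groups read complementary halves.
#include <cstdio>

int main() {
  const int kAccessBytes = 16;  // one 128b access
  for (int lane_id = 0; lane_id < 32; ++lane_id) {
    int quad = lane_id / 4;
    int liq = lane_id % 4;
    int c = liq + (quad & 1) * 4;            // contiguous access index: 0..7
    int s = quad / 2;                        // strided index: 0..3 (stride term omitted)
    int byte_offset = c * kAccessBytes;
    int toggled = byte_offset ^ (4 * kAccessBytes);  // what operator++ does to the offset
    std::printf("lane %2d: c=%d s=%d, access %d <-> %d\n",
                lane_id, c, s, byte_offset / kAccessBytes, toggled / kAccessBytes);
  }
  return 0;
}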
-+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicandCrosswise128x4, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ 
/// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(layout::PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(layout::PitchLinearCoord(-tile_offset.column(), -tile_offset.row())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.strided(), tile_offset.contiguous()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. 
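// Standalone illustrative sketch (not part of the patch above): the row-major adapter
// above and its column-major counterpart that follows only translate MatrixCoord tile
// offsets into the pitch-linear coordinates of the shared underlying iterator. The two
// hypothetical helpers below mirror the coordinate order used by their add_tile_offset().
#include <cstdio>
#include <utility>

std::pair<int, int> row_major_to_pitch_linear(int row, int col) {
  return {col, row};  // row-major: columns are the contiguous dimension
}
std::pair<int, int> column_major_to_pitch_linear(int row, int col) {
  return {row, col};  // column-major: rows are the contiguous dimension
}

int main() {
  auto rm = row_major_to_pitch_linear(2, 5);
  auto cm = column_major_to_pitch_linear(2, 5);
  std::printf("row-major (2,5) -> (%d,%d), column-major (2,5) -> (%d,%d)\n",
              rm.first, rm.second, cm.first, cm.second);
  return 0;
}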
-+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicandCrosswise128x4, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.row(), 
tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(layout::PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(layout::PitchLinearCoord(-tile_offset.row(), -tile_offset.column())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.contiguous(), tile_offset.strided()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. 
-+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Congruous shared memory layout -+// Warp-level iterators for complex*complex + complex => complex -+// The underlying iterators are similar to that for MMA f64*f64 + f64 = f64 -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for loading 128b vectors of 64b elements. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, cutlass::complex, -+ cutlass::layout::TensorOpMultiplicandCongruous64b, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ static_assert(!(Shape::kContiguous % 16) && !(Shape::kStrided % 8), "Divisibility."); -+ -+ /// Element type -+ using Element = cutlass::complex; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::TensorOpMultiplicandCongruous64b; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Load two elements per access -+ static int const kElementsPerAccess = 2; -+ -+ /// Policy defining internal details of tile iterator -+ struct Policy { -+ -+ /// Shape of one access -+ using Delta = layout::PitchLinearShape<8, 4>; -+ -+ /// Number of iterations to load -+ using Iterations = layout::PitchLinearShape< -+ Shape::kContiguous / kElementsPerAccess / Delta::kContiguous, -+ InstructionShape::kStrided / Delta::kStrided -+ >; -+ -+ }; -+ -+private: -+ -+ /// Not working on this feature at the moment. 
-+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Pointer type used for accesses -+ using AccessType = AlignedArray; -+ -+ /// Internal counter used to jump to next K partition -+ int k_group_idx_; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = -+ Array; -+ -+private: -+ -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0), -+ k_group_idx_(0) { -+ -+ int access_strided = lane_id / Policy::Delta::kContiguous; -+ int access_contiguous = (lane_id % Policy::Delta::kContiguous) ^ access_strided; -+ -+ pointer_= reinterpret_cast(ref.data()) + -+ access_contiguous + access_strided * stride_; -+ -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ byte_offset_ += offset * sizeof(Element); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ int offset = -+ (tile_offset.strided() * InstructionShape::kStrided) * stride_ * kElementsPerAccess + -+ tile_offset.contiguous() * Shape::kContiguous; -+ -+ add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ add_tile_offset({0, 1}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the opposite of the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ add_tile_offset({0, -1}); -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. 
-+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ -+ AccessType *fetch_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::Iterations::kStrided; ++s) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { -+ -+ int access_idx = c + s * Policy::Iterations::kContiguous; -+ -+ AccessType const *source_ptr = pointer_ + -+ Policy::Delta::kContiguous * c + -+ Policy::Delta::kStrided * s * stride_; -+ -+ char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; -+ -+ AccessType const *source = reinterpret_cast(source_byte_ptr); -+ -+ fetch_ptr[access_idx] = *source; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ -+ Index pointer_offset = -+ tile_offset.contiguous() * Shape::kContiguous / Layout::kElementsPerAccess + -+ tile_offset.strided() * InstructionShape::kStrided * stride_; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. 
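// Standalone illustrative sketch (not part of the patch above): the XOR'd lane mapping
// used by the Congruous64b iterator constructor. Each 128b access covers two 64b
// complex<float> elements, and the (lane % 8) ^ (lane / 8) permutation staggers which
// contiguous slot a lane reads in each strided row, a swizzle commonly used to keep
// shared-memory accesses conflict-free (the bank-conflict rationale is an inference,
// not stated in the file itself).
#include <cstdio>

int main() {
  const int kDeltaContiguous = 8;  // Policy::Delta = PitchLinearShape<8, 4>
  for (int lane_id = 0; lane_id < 32; ++lane_id) {
    int access_strided = lane_id / kDeltaContiguous;                       // 0..3
    int access_contiguous = (lane_id % kDeltaContiguous) ^ access_strided; // 0..7, permuted
    std::printf("lane %2d -> contiguous %d, strided %d\n",
                lane_id, access_contiguous, access_strided);
  }
  return 0;
}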
-+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Crosswise shared memory layout -+// Warp-level iterators for complex*complex + complex => complex -+// The underlying iterators are similar to that for f64*f64 + f64 = f64 -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for loading 128b vectors of 64b elements. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, complex, -+ cutlass::layout::TensorOpMultiplicand64bCrosswise, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ static_assert(!(Shape::kContiguous % 4) && !(Shape::kStrided % 16), "Divisibility."); -+ -+ static_assert(sizeof_bits>::value == 64, "This is specialized for 64b accesses."); -+ -+ /// Element type -+ using Element = complex; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::TensorOpMultiplicand64bCrosswise; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Load two elements per access -+ static int const kElementsPerAccess = 2; -+ -+ /// Policy defining internal details of tile iterator -+ struct Policy { -+ -+ /// Shape of one access -+ using Delta = layout::PitchLinearShape<4, 16>; -+ -+ /// Number of iterations to load -+ using Iterations = layout::PitchLinearShape< -+ InstructionShape::kContiguous / Delta::kContiguous, -+ Shape::kStrided / Delta::kStrided -+ >; -+ -+ }; -+ -+private: -+ -+ /// Not working on this feature at the moment. 
-+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Pointer type used for accesses -+ using AccessType = AlignedArray; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = -+ Array; -+ -+private: -+ -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+ /// Internal counter for tracking K-group -+ Index k_group_idx_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0), -+ k_group_idx_(0) { -+ -+ int access_strided = lane_id / 8; -+ int access_contiguous = (lane_id % 8); -+ -+ byte_offset_ = (access_contiguous + access_strided * stride_) * sizeof(AccessType); -+ -+ pointer_= reinterpret_cast(ref.data()); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ pointer_ += offset / kElementsPerAccess; -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ int offset = (tile_offset.contiguous() * InstructionShape::kContiguous) * -+ stride_ * kElementsPerAccess + -+ tile_offset.strided() * Shape::kStrided; -+ -+ add_pointer_offset(offset); -+ -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative(TensorCoord const &tile_offset) { -+ -+ add_tile_offset(tile_offset); -+ -+ if (k_group_idx_ & 1) -+ byte_offset_ ^= 0x40; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ pointer_ += stride_ * InstructionShape::kContiguous; -+ -+ // xor ptr -+ byte_offset_ ^= 0x40; -+ -+ ++k_group_idx_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. 
-+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ -+ AccessType *fetch_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::Iterations::kStrided; ++s) { -+ -+ int access_idx = c * Policy::Iterations::kStrided + s; -+ -+ AccessType const *source_ptr = pointer_ + -+ Policy::Delta::kContiguous * c * stride_ + -+ Policy::Delta::kStrided * s / kElementsPerAccess; -+ -+ char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; -+ -+ AccessType const *source = reinterpret_cast(source_byte_ptr); -+ -+ fetch_ptr[access_idx] = *source; -+ } -+ } -+ -+ Element *exchange_ptr = reinterpret_cast(&frag); -+ -+ // exchange on 64b granularity only for fragments held in k=8/2 to k=8 -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = Fragment::kElements/2; i < Fragment::kElements; i += 2) { -+ Element tmp = exchange_ptr[i]; -+ exchange_ptr[i] = exchange_ptr[i + 1]; -+ exchange_ptr[i + 1] = tmp; -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = tile_offset.contiguous() * -+ InstructionShape::kContiguous / -+ Layout::kElementsPerAccess + -+ tile_offset.strided() * Shape::kStrided * stride_; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. 
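// Standalone illustrative sketch (not part of the patch above): the post-load exchange
// performed at the end of load_with_byte_offset for the 64b crosswise layout. Only the
// second half of the fragment (the accesses covering the upper half of the k-group) has
// each adjacent pair swapped; the first half is left untouched. The fragment size of 8
// is assumed purely for illustration.
#include <algorithm>
#include <cstdio>

int main() {
  const int kElements = 8;  // assumed fragment size for illustration
  int frag[kElements] = {0, 1, 2, 3, 4, 5, 6, 7};
  for (int i = kElements / 2; i < kElements; i += 2) {
    std::swap(frag[i], frag[i + 1]);  // same exchange as in the iterator
  }
  for (int i = 0; i < kElements; ++i) std::printf("%d ", frag[i]);  // prints: 0 1 2 3 5 4 7 6
  std::printf("\n");
  return 0;
}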
-+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ k_group_idx_ = k_group; -+ } -+}; -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h -new file mode 100644 -index 0000000..00760a6 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h -@@ -0,0 +1,643 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations targeting -+ Tensor Cores. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/arch/mma_sm75.h" -+#include "cutlass/arch/mma_sm80.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+#include "cutlass/gemm/warp/mma_tensor_op.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_gaussian_complex_tensor_op_tile_iterator_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename RealElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename RealElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename RealElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Complex transform on A operand -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex transform on B operand -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Do source operands need more than one elements -+ bool GeneralizedOperatorElements = false, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class MmaGaussianComplexTensorOp; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex*complex+complex => complex using real-valued TensorOps -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename RealElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename RealElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename RealElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB -+> -+class MmaGaussianComplexTensorOp< -+ Shape_, -+ complex, -+ LayoutA_, -+ complex, -+ LayoutB_, -+ complex, -+ LayoutC_, -+ Policy_, -+ TransformA, -+ TransformB> { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = complex; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = complex; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = complex; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: 
MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Policy::Operator; -+ -+ /// Shape of underlying instruction -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ /// Underlying arch tag -+ using ArchTag = typename ArchMmaOperator::ArchTag; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Indicates math operator -+ using MathOperator = arch::OpMultiplyAddGaussianComplex; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = TransformB; -+ -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kA, -+ ElementA, -+ LayoutA, -+ MatrixShape, -+ Policy::OpDelta::kRow, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentA = FragmentA; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kB, -+ ElementB, -+ LayoutB, -+ MatrixShape, -+ Policy::OpDelta::kColumn, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed B tile -+ using TransformedFragmentB = FragmentB; -+ -+ static_assert( -+ !(Shape::kM % ArchMmaOperator::Shape::kM) && -+ !(Shape::kN % ArchMmaOperator::Shape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ Shape::kM / ArchMmaOperator::Shape::kM, -+ Shape::kN / ArchMmaOperator::Shape::kN -+ >; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaTensorOpGaussianComplexAccumulatorTileIterator< -+ MatrixShape, -+ ElementC, -+ LayoutC, -+ typename ArchMmaOperator::Shape, -+ typename Policy::OpDelta>; -+ -+ /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this -+ /// storage arrangement is to be considered 'gaussian complex' in the sense that the accumulation is -+ /// done in three parts namely part1, part2, and part3. The parts 1, 2, and 3 are stored consecutively -+ /// in InteratorC::Frament. This matches the structure of Tensor Cores which are always real-valued matrix multiplies. 
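// Standalone illustrative sketch (not part of the patch above): the three real-valued
// products that the 'gaussian complex' accumulator splits one complex multiply into,
// matching the part1/part2/part3 updates performed by operator() below. The
// recombination shown here (real = part1 - part3, imag = part1 + part2) follows
// algebraically from those formulas; the actual reduction is done outside this file.
#include <cassert>
#include <complex>

int main() {
  std::complex<double> a(3.0, -2.0), b(-1.5, 4.0);

  double part1 = (a.real() + a.imag()) * b.real();   // (ar + ai) * br
  double part2 = -a.real() * (b.real() - b.imag());  // -ar * (br - bi)
  double part3 = a.imag() * (b.real() + b.imag());   // ai * (br + bi)

  std::complex<double> via_parts(part1 - part3, part1 + part2);
  assert(std::abs(via_parts - a * b) < 1e-12);  // three multiplies instead of four
  return 0;
}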
-+ using FragmentC = typename IteratorC::Fragment; -+ -+ static_assert( -+ FragmentC::kElements == 3 * MmaIterations::kCount * ArchMmaOperator::FragmentC::kElements, -+ "Unexpected gaussian complex fragment length."); -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying real-valued matrix multiply operator (concept: arch::Mma) -+ ArchMmaOperator mma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaGaussianComplexTensorOp() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A, -+ FragmentB const &B, -+ FragmentC const &C -+ ) const { -+ -+ // Alias types for underlying real-valued matrix multiply operator -+ using MmaOperandA = typename ArchMmaOperator::FragmentA; -+ using MmaOperandB = typename ArchMmaOperator::FragmentB; -+ using MmaOperandC = typename ArchMmaOperator::FragmentC; -+ -+ static_assert(MmaOperandA::kElements == 1, -+ "This implementation only supports math instructions in which exactly one element is needed for the A operand." -+ "We can geneneralize later."); -+ -+ static_assert(MmaOperandB::kElements == 1, -+ "This implementation only supports math instructions in which exactly one element is needed for the B operand." -+ "We can geneneralize later."); -+ -+ D = C; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; ++m) { -+ -+ // mma(accum.part1(), (a.real() + a.imag()), b.real(), accum.part1()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_Asum; -+ MmaOperandB operand_Br; -+ -+ operand_Asum[0] = A[m].real() + ((kTransformA == ComplexTransform::kConjugate) ? -A[m].imag() : +A[m].imag()); -+ operand_Br[0] = B[n].real(); -+ -+ // accumulator part1 -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow); -+ -+ mma(*accum, operand_Asum, operand_Br, *accum); -+ } -+ -+ // mma(accum.part2(), -a.real(), (b.real() - b.imag()), accum.part2()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_Ar; -+ MmaOperandB operand_Bdiff; -+ -+ operand_Ar[0] = -A[m].real(); -+ operand_Bdiff[0] = B[n].real() - ((kTransformB == ComplexTransform::kConjugate) ? -B[n].imag() : +B[n].imag()); -+ -+ // accumulator part2 -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + MmaIterations::kCount; -+ -+ mma(*accum, operand_Ar, operand_Bdiff, *accum); -+ } -+ -+ // mma(accum.part3(), a.imag(), (b.real() + b.imag()), accum.part3()) -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_Ai; -+ MmaOperandB operand_Bsum; -+ -+ operand_Ai[0] = (kTransformA == ComplexTransform::kConjugate) ? -A[m].imag() : +A[m].imag(); -+ operand_Bsum[0] = B[n].real() + ((kTransformB == ComplexTransform::kConjugate) ? 
-B[n].imag() : +B[n].imag()); -+ -+ // accumulator part3 -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + 2 * MmaIterations::kCount; -+ -+ mma(*accum, operand_Ai, operand_Bsum, *accum); -+ } -+ } -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ //TODO: Implement this -+ dst_A = A; -+ dst_B = B; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex*complex+complex => complex using real-valued TensorOps -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename RealElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename RealElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename RealElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB -+> -+class MmaGaussianComplexTensorOp< -+ Shape_, -+ complex, -+ LayoutA_, -+ complex, -+ LayoutB_, -+ complex, -+ LayoutC_, -+ Policy_, -+ TransformA, -+ TransformB, -+ true> { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = complex; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = complex; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = complex; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Policy::Operator; -+ -+ /// Shape of underlying instruction -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ /// Underlying arch tag -+ using ArchTag = typename ArchMmaOperator::ArchTag; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Indicates math operator -+ using MathOperator = arch::OpMultiplyAddGaussianComplex; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = TransformB; -+ -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kA, -+ ElementA, -+ LayoutA, -+ MatrixShape, -+ Policy::OpDelta::kRow, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentA = FragmentA; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ 
Operand::kB, -+ ElementB, -+ LayoutB, -+ MatrixShape, -+ Policy::OpDelta::kColumn, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed B tile -+ using TransformedFragmentB = FragmentB; -+ -+ static_assert( -+ !(Shape::kM % ArchMmaOperator::Shape::kM) && -+ !(Shape::kN % ArchMmaOperator::Shape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ Shape::kM / ArchMmaOperator::Shape::kM, -+ Shape::kN / ArchMmaOperator::Shape::kN -+ >; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaTensorOpGaussianComplexAccumulatorTileIterator< -+ MatrixShape, -+ ElementC, -+ LayoutC, -+ typename ArchMmaOperator::Shape, -+ typename Policy::OpDelta>; -+ -+ /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this -+ /// storage arrangement is to be considered 'gaussian complex' in the sense that the accumulation is -+ /// done in three parts namely part1, part2, and part3. The parts 1, 2, and 3 are stored consecutively -+ /// in InteratorC::Frament. This matches the structure of Tensor Cores which are always real-valued matrix multiplies. -+ using FragmentC = typename IteratorC::Fragment; -+ -+ static_assert( -+ FragmentC::kElements == 3 * MmaIterations::kCount * ArchMmaOperator::FragmentC::kElements, -+ "Unexpected gaussian complex fragment length."); -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying real-valued matrix multiply operator (concept: arch::Mma) -+ ArchMmaOperator mma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaGaussianComplexTensorOp() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A, -+ FragmentB const &B, -+ FragmentC const &C -+ ) const { -+ -+ // Alias types for underlying real-valued matrix multiply operator -+ using MmaOperandA = typename ArchMmaOperator::FragmentA; -+ using MmaOperandB = typename ArchMmaOperator::FragmentB; -+ using MmaOperandC = typename ArchMmaOperator::FragmentC; -+ -+ D = C; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; ++m) { -+ -+ // mma(accum.part1(), (a.real() + a.imag()), b.real(), accum.part1()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_Asum; -+ MmaOperandB operand_Br; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mk = 0; mk < MmaOperandA::kElements; ++mk) -+ operand_Asum[mk] = A[m*MmaOperandA::kElements + mk].real() + ((kTransformA == ComplexTransform::kConjugate) ? -+ -A[m*MmaOperandA::kElements + mk].imag() : +A[m*MmaOperandA::kElements + mk].imag()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int nk = 0; nk < MmaOperandB::kElements; ++nk) -+ operand_Br[nk] = B[n*MmaOperandB::kElements + nk].real(); -+ -+ // accumulator part1 -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow); -+ -+ mma(*accum, operand_Asum, operand_Br, *accum); -+ } -+ -+ // mma(accum.part2(), -a.real(), (b.real() - b.imag()), accum.part2()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { -+ -+ // Pack operands together. 
This may result in actual MOVs -+ MmaOperandA operand_Ar; -+ MmaOperandB operand_Bdiff; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mk = 0; mk < MmaOperandA::kElements; ++mk) -+ operand_Ar[mk] = -A[m*MmaOperandA::kElements + mk].real(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int nk = 0; nk < MmaOperandB::kElements; ++nk) -+ operand_Bdiff[nk] = B[n*MmaOperandB::kElements + nk].real() - ((kTransformB == ComplexTransform::kConjugate) ? -+ -B[n*MmaOperandB::kElements + nk].imag() : +B[n*MmaOperandB::kElements + nk].imag()); -+ -+ // accumulator part2 -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + MmaIterations::kCount; -+ -+ mma(*accum, operand_Ar, operand_Bdiff, *accum); -+ } -+ -+ // mma(accum.part3(), a.imag(), (b.real() + b.imag()), accum.part3()) -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_Ai; -+ MmaOperandB operand_Bsum; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mk = 0; mk < MmaOperandA::kElements; ++mk) -+ operand_Ai[mk] = (kTransformA == ComplexTransform::kConjugate) ? -+ -A[m*MmaOperandA::kElements + mk].imag() : +A[m*MmaOperandA::kElements + mk].imag(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int nk = 0; nk < MmaOperandB::kElements; ++nk) -+ operand_Bsum[nk] = B[n*MmaOperandB::kElements + nk].real() + ((kTransformB == ComplexTransform::kConjugate) ? -+ -B[n*MmaOperandB::kElements + nk].imag() : +B[n*MmaOperandB::kElements + nk].imag()); -+ -+ // accumulator part3 -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + 2 * MmaIterations::kCount; -+ -+ mma(*accum, operand_Ai, operand_Bsum, *accum); -+ } -+ } -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ dst_A = A; -+ dst_B = B; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op_tile_iterator_sm80.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op_tile_iterator_sm80.h -new file mode 100644 -index 0000000..1903622 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op_tile_iterator_sm80.h -@@ -0,0 +1,390 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
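// A minimal host-side sketch of the Gauss three-multiplication identity that the three
// mma calls above accumulate (ignoring the optional conjugation transforms): part1 uses
// (a.real + a.imag) * b.real, part2 uses -a.real * (b.real - b.imag), and part3 uses
// a.imag * (b.real + b.imag), so the complex result is recovered as
// (part1 - part3) + i * (part1 + part2). The std::complex<float> element type here is
// only an assumption for illustration.
#include <complex>

inline std::complex<float> gaussian_complex_mad(std::complex<float> a,
                                                std::complex<float> b,
                                                std::complex<float> c) {
  float part1 = (a.real() + a.imag()) * b.real();   // accum.part1()
  float part2 = -a.real() * (b.real() - b.imag());  // accum.part2()
  float part3 = a.imag() * (b.real() + b.imag());   // accum.part3()
  // Recombination performed later by the gaussian-complex accumulator tile iterator.
  return c + std::complex<float>(part1 - part3, part1 + part2);
}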
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h" -+ -+#include "cutlass/platform/platform.h" -+#include "cutlass/fast_math.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Element type -+ typename Element_, -+ /// Layout of operand in memory -+ typename Layout_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions, concept: MatrixShape) -+ typename OpDelta_> -+class MmaTensorOpGaussianComplexAccumulatorTileIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// -+/// Partial specialization for complex -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of underlying field of reals. 
-+ typename RealElement, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions, concept: MatrixShape) -+ typename OpDelta_> -+class MmaTensorOpGaussianComplexAccumulatorTileIterator< -+ Shape_, complex, cutlass::layout::RowMajor, InstructionShape_, OpDelta_> { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kC; -+ -+ /// Element type -+ using Element = complex; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ using OpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kRow % InstructionShape::kM) && -+ !(Shape::kColumn % InstructionShape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ static_assert(platform::is_same::value, -+ "Layouts must be defined for logical MatrixCoord coordinate space."); -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape; -+ }; -+ -+private: -+ -+ // Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over the entire -+ // shape, with each quad mapped to one row and each thread mapped to 1/4 of the elements -+ // of that row. The accumulators within one row are assumed to be consecutive. -+ static int const kElementsPerAccess = InstructionShape::kN / 4; -+ static int const kRowsPerTile = 8; -+ static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile. 
It is assumed that the accumulators -+ /// are stored in a gaussian complex arrangement with parts 1, 2, and 3 as entirely contiguous -+ /// arranged as [part1, part2, part3] -+ using Fragment = Array; -+ -+ static int const kPart1Index = (Shape::kCount / kThreads) * 0; -+ static int const kPart2Index = (Shape::kCount / kThreads) * 1; -+ static int const kPart3Index = (Shape::kCount / kThreads) * 2; -+ -+private: -+ -+ /// Reference to output tensor -+ TensorRef ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpGaussianComplexAccumulatorTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpGaussianComplexAccumulatorTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ ref_(ref) { -+ -+ int quad = (lane_id >> 2); -+ int lane_in_quad = (lane_id & 3); -+ -+ MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess); -+ -+ ref_.add_coord_offset(lane_offset); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpGaussianComplexAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpGaussianComplexAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpGaussianComplexAccumulatorTileIterator & operator++() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpGaussianComplexAccumulatorTileIterator & operator--() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpGaussianComplexAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpGaussianComplexAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. 
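// A small sketch of the lane-to-accumulator mapping set up in the constructor above,
// following the 8x8 accumulator arrangement described in the comments: each quad of four
// threads owns one row and each thread in the quad owns InstructionShape::kN / 4
// consecutive columns. The kN value below is an assumed example, not taken from the patch.
struct GaussianAccumulatorLaneOrigin {
  static constexpr int kN = 8;                       // assumed InstructionShape::kN
  static constexpr int kElementsPerAccess = kN / 4;  // columns held per thread
  static int row(int lane_id) { return lane_id >> 2; }                        // quad -> row
  static int col(int lane_id) { return (lane_id & 3) * kElementsPerAccess; }  // lane in quad -> column
};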
-+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index pointer_offset) const { ///< loads a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess * -+ (mma_n * Policy::MmaIterations::kRow + mma_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < kAccumulatorRows; ++row) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < kElementsPerAccess; ++col) { -+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + -+ row * kRowsPerTile; -+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; -+ -+ Element z = offset_ref.at({accum_m, accum_n}); -+ -+ frag[mma_accum_start + row * kElementsPerAccess + col + kPart1Index] = z.real() + z.imag(); -+ frag[mma_accum_start + row * kElementsPerAccess + col + kPart2Index] = -z.real(); -+ frag[mma_accum_start + row * kElementsPerAccess + col + kPart3Index] = z.imag(); -+ } -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index byte_offset) const { ///< loads a tile with a linear offset -+ -+ load_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles -+ -+ load(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. 
-+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles -+ Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset -+ -+ load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index pointer_offset) const { ///< store a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess * -+ (mma_n * Policy::MmaIterations::kRow + mma_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < kAccumulatorRows; ++row) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < kElementsPerAccess; ++col) { -+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + -+ row * kRowsPerTile; -+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; -+ int idx = mma_accum_start + row * kElementsPerAccess + col; -+ -+ Element z(frag[kPart1Index + idx] - frag[kPart3Index + idx], -+ frag[kPart1Index + idx] + frag[kPart2Index + idx]); -+ -+ offset_ref.at({accum_m, accum_n}) = z; -+ } -+ } -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_byte_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index byte_offset) const { ///< store a tile with a linear offset -+ -+ store_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Stores a fragment to memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ Fragment &frag, ///< fragment to store to the tensor -+ TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles -+ -+ store(frag, tile_offset, 0); -+ } -+ -+ /// Stores a fragment from memory with logical offset in units of whole tiles. 
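// A brief sketch of the load/store round trip implemented above: load() splits a complex
// accumulator z into part1 = z.real + z.imag, part2 = -z.real, part3 = z.imag, and
// store() rebuilds z as (part1 - part3) + i * (part1 + part2). float is assumed only for
// illustration.
#include <complex>

inline std::complex<float> gaussian_fragment_round_trip(std::complex<float> z) {
  // load_with_pointer_offset(): split into the three real parts.
  float part1 = z.real() + z.imag();
  float part2 = -z.real();
  float part3 = z.imag();
  // store_with_pointer_offset(): recombine. Algebraically (z.r + z.i) - z.i == z.r and
  // (z.r + z.i) - z.r == z.i, so the round trip recovers z (up to rounding).
  return std::complex<float>(part1 - part3, part1 + part2);
}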
-+ CUTLASS_DEVICE -+ void store( -+ /// fragment to store to the tensor -+ Fragment const &frag, -+ /// stores a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// stores a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_planar_complex.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_planar_complex.h -new file mode 100644 -index 0000000..894efd7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_planar_complex.h -@@ -0,0 +1,182 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/array_planar_complex.h" -+#include "cutlass/gemm/warp/tile_iterator_planar_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Underlying real-valued warp-level matrix multiply -+ typename Operator_, -+ /// Transformation applied to A operand (typically folded into math instruction) -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Transformation applied to B operand (typically folded into math instruction) -+ ComplexTransform TransformB = ComplexTransform::kNone -+> -+class MmaPlanarComplex { -+public: -+ -+ /// Underlying real-valued warp-level matrix multiply -+ using Operator = Operator_; -+ -+ /// Shape of warp-level matrix multipy -+ using Shape = typename Operator::Shape; -+ -+ /// Transformation applied to A operand (typically folded into math instruction) -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Transformation applied to B operand (typically folded into math instruction) -+ static ComplexTransform const kTransformB = TransformB; -+ -+ /// Fragment of elements -+ using FragmentA = ArrayPlanarComplex; -+ -+ /// Iterator into planar complex -+ using IteratorA = TileIteratorPlanarComplex; -+ -+ /// Layout in memory of the A operand -+ using LayoutA = typename Operator::LayoutA; -+ -+ using FragmentB = ArrayPlanarComplex; -+ -+ /// Iterator into planar complex -+ using IteratorB = TileIteratorPlanarComplex; -+ -+ /// Layout in memory of the B operand -+ using LayoutB = typename Operator::LayoutB; -+ -+ /// Tile iterator for accumulator -+ using IteratorC = TileIteratorPlanarComplex; -+ -+ /// Accumulator fragment -+ using FragmentC = ArrayPlanarComplex; -+ -+ /// Layout of accumulator fragment in memory -+ using LayoutC = typename Operator::LayoutC; -+ -+private: -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ Operator::Shape::kM / Operator::Policy::Operator::Shape::kM, -+ Operator::Shape::kN / Operator::Policy::Operator::Shape::kN -+ >; -+ -+public: -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaPlanarComplex() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A_in, -+ FragmentB const &B_in, -+ FragmentC const &C) const { -+ -+ D.real = C.real; -+ D.imag = C.imag; -+ -+ // -+ // Transform fragments based on conjugate operations. 
-+ // -+ -+ negate neg_A; -+ -+ FragmentA frag_A; -+ frag_A.real = A_in.real; -+ -+ if (kTransformA == ComplexTransform::kConjugate) { -+ frag_A.imag = neg_A(frag_A.imag); -+ } -+ else { -+ frag_A.imag = frag_A.imag; -+ } -+ -+ FragmentB frag_B; -+ frag_B.real = B_in.real; -+ -+ if (kTransformB == ComplexTransform::kConjugate) { -+ negate neg; -+ frag_B.imag = neg(frag_B.imag); -+ } -+ else { -+ frag_B.imag = frag_B.imag; -+ } -+ -+ // -+ // Accumulated real-valued matrix multiplies -+ // -+ -+ Operator real_mma; -+ -+ // D.i += A.i * B.r -+ real_mma(D.imag, frag_A.imag, frag_B.real, D.imag); -+ -+ // D.r += A.r * B.r -+ real_mma(D.real, frag_A.real, frag_B.real, D.real); -+ -+ // D.i += A.r * B.i -+ real_mma(D.imag, frag_A.real, frag_B.imag, D.imag); -+ -+ // D.r += -A.i * B.i -+ frag_A.imag = neg_A(frag_A.imag); -+ real_mma(D.real, frag_A.imag, frag_B.imag, D.real); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_simt.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_simt.h -new file mode 100644 -index 0000000..9790792 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_simt.h -@@ -0,0 +1,264 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations. 
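// A scalar sketch of the planar-complex decomposition performed by operator() above:
// one complex multiply-accumulate becomes four real-valued warp MMAs over the separately
// stored real and imaginary planes, with a conjugation transform simply negating the
// corresponding imaginary plane first. std::complex<float> stands in for the planar
// fragments purely for illustration.
#include <complex>

inline std::complex<float> planar_complex_mad(std::complex<float> a,
                                              std::complex<float> b,
                                              std::complex<float> c) {
  float d_real = c.real();
  float d_imag = c.imag();
  d_imag += a.imag() * b.real();   // D.i += A.i * B.r
  d_real += a.real() * b.real();   // D.r += A.r * B.r
  d_imag += a.real() * b.imag();   // D.i += A.r * B.i
  d_real += -a.imag() * b.imag();  // D.r += -A.i * B.i
  return std::complex<float>(d_real, d_imag);
}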
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/thread/mma.h" -+ -+#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK = 1, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class MmaSimt { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = ElementA_; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = ElementB_; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = ElementC_; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassSimt; -+ -+ /// Hard-coded for now -+ using ArchTag = arch::Sm50; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = TransformB; -+ -+ /// Layout of threads -+ using ThreadLayoutA = typename platform::conditional< platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA >::value, -+ layout::ColumnMajor, -+ typename platform::conditional < platform::is_same< layout::RowMajorInterleaved<4>, LayoutA >::value, -+ layout::RowMajor, -+ LayoutA>::type -+ >::type; -+ -+ using ThreadLayoutB = typename platform::conditional< platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutB >::value, -+ layout::ColumnMajor, -+ typename platform::conditional < platform::is_same< layout::RowMajorInterleaved<4>, LayoutB >::value, -+ layout::RowMajor, -+ LayoutB>::type -+ >::type; -+ -+ static constexpr bool use_dp4a = (platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA>::value || -+ platform::is_same< layout::RowMajorInterleaved<4>, LayoutA >::value) && -+ platform::is_same< ElementA, int8_t >::value && -+ platform::is_same< ElementB, 
int8_t >::value; -+ -+ using dp4a_type = typename platform::conditional< use_dp4a , int8_t, bool >::type; -+ -+ /// Thread-level matrix multiply accumulate operator -+ using ThreadMma = thread::Mma< -+ GemmShape< -+ Shape::kM / Policy::WarpShape::kRow, -+ Shape::kN / Policy::WarpShape::kColumn, -+ Policy::LaneMmaShape::kK>, -+ ElementA, -+ ThreadLayoutA, -+ ElementB, -+ ThreadLayoutB, -+ ElementC, -+ LayoutC, -+ arch::OpMultiplyAdd, -+ dp4a_type -+ >; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename ThreadMma::ArchMmaOperator; -+ -+ /// Indicates math operator -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ /// Shape of the underlying instruction -+ using InstructionShape = GemmShape<1,1,use_dp4a ? 4 : 1>; -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaSimtTileIterator< -+ MatrixShape, -+ Operand::kA, -+ ElementA, -+ LayoutA, -+ Policy, -+ PartitionsK, -+ Shape::kK -+ >; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentA = FragmentA; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaSimtTileIterator< -+ MatrixShape, -+ Operand::kB, -+ ElementB, -+ LayoutB, -+ Policy, -+ PartitionsK, -+ Shape::kK -+ >; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentB = FragmentB; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaSimtTileIterator< -+ MatrixShape, -+ Operand::kC, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ /// Storage for C tile -+ using FragmentC = typename ThreadMma::FragmentC; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaSimt() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA a, -+ FragmentB b, -+ FragmentC const &c, int group_idx = 0) const { -+ -+ ThreadMma mma; -+ -+ if (kTransformA == ComplexTransform::kConjugate) { -+ conjugate conj_a; -+ a = conj_a(a); -+ } -+ -+ if (kTransformB == ComplexTransform::kConjugate) { -+ conjugate conj_b; -+ b = conj_b(b); -+ } -+ -+ mma(d, a, b, c); -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ //TODO: Implement this -+ dst_A = A; -+ dst_B = B; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_simt_policy.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_simt_policy.h -new file mode 100644 -index 0000000..a0b0a75 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_simt_policy.h -@@ -0,0 +1,69 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
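// A sketch of how MmaSimt above sizes the per-thread work: the warp tile is divided by
// the lane arrangement, so each thread computes a
// (Shape::kM / WarpShape::kRow) x (Shape::kN / WarpShape::kColumn) x LaneMmaShape::kK
// product, and dp4a is selected only for int8_t operands in 4-element interleaved
// layouts (InstructionShape becomes 1x1x4). The concrete numbers below are assumed
// examples, not values from the patch.
struct MmaSimtThreadTileExample {
  static constexpr int kWarpM = 64, kWarpN = 64;      // assumed warp-level tile
  static constexpr int kLanesRow = 4, kLanesCol = 8;  // assumed Policy::WarpShape (32 lanes)
  static constexpr int kThreadM = kWarpM / kLanesRow; // 16 rows per thread
  static constexpr int kThreadN = kWarpN / kLanesCol; // 8 columns per thread
  static_assert(kWarpM % kLanesRow == 0 && kWarpN % kLanesCol == 0,
                "Warp tile must be divisible by the lane arrangement.");
};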
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Describes the lane policy used by warp-level matrix multiply operators targeting SIMT -+ instructions -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Describes the arrangement and configuration of per-lane operations in warp-level matrix multiply -+template < -+ typename WarpShape_, ///< shape of the warp in lanes (concept: MatrixShape) -+ typename LaneLayout_, ///< layout function of lanes -+ typename LaneMmaShape_ ///< size of each lane's thread-level matrix product (concept: GemmShape) -+> -+struct MmaSimtPolicy { -+ using WarpShape = WarpShape_; -+ using LaneLayout = LaneLayout_; -+ using LaneMmaShape = LaneMmaShape_; -+ using MmaShape = LaneMmaShape; -+ -+ /// Returns a layout functor mapping lane position in the warp to thread ID -+ CUTLASS_HOST_DEVICE -+ static LaneLayout get_lane_layout() { -+ return LaneLayout::packed({WarpShape::kRow, WarpShape::kColumn}); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_simt_tile_iterator.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_simt_tile_iterator.h -new file mode 100644 -index 0000000..53c1c36 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_simt_tile_iterator.h -@@ -0,0 +1,1890 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
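// A sketch of the lane layout produced by get_lane_layout() above: lanes are packed into a
// WarpShape::kRow x WarpShape::kColumn grid, and the tile iterators call
// lane_layout.inverse(lane_id) to recover each lane's (row, column) position. For a
// hypothetical row-major 4x8 arrangement this reduces to:
struct MmaSimtLaneLayoutExample {
  static constexpr int kRow = 4, kColumn = 8;  // assumed MmaSimtPolicy::WarpShape
  static int lane_row(int lane_id) { return lane_id / kColumn; }
  static int lane_col(int lane_id) { return lane_id % kColumn; }
};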
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Describes the lane policy used by warp-level matrix multiply operators targeting SIMT -+ instructions -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+ -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Iterates over operands to warp-level matrix multiply operations targeting SIMT instructions -+/// -+/// concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Operand identity -+ Operand Operand, -+ /// Data type of A elements -+ typename Element_, -+ /// Layout of operand -+ typename Layout_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Number of partitions along K dimension - used in sliced-K -+ int PartitionsK = 1, -+ /// Group Size along kPartition - used in sliced-K -+ int PartitionGroupSize = 1 -+> -+class MmaSimtTileIterator; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for A operands of column-major layouts -+/// -+/// Concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Number of partitions along K dimension - used in sliced-K -+ int PartitionsK, -+ /// Group Size along kPartition - used in sliced-K -+ int PartitionGroupSize -+> 
-+class MmaSimtTileIterator { -+public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kA; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of policy -+ using Layout = layout::ColumnMajor; -+ -+ /// Decomposition of elements among threads -+ using Policy = Policy_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ // -+ // Derived quantities -+ // -+ -+ static_assert(!(Shape::kRow % Policy::WarpShape::kRow), -+ "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension."); -+ -+ static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); -+ static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); -+ static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); -+ -+ /// Thread-level shape of a fragment -+ using ThreadShape = MatrixShape< -+ Shape::kRow / Policy::WarpShape::kRow, -+ Shape::kColumn -+ >; -+ -+ static_assert(!(ThreadShape::kRow % Policy::LaneMmaShape::kM), -+ "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); -+ -+ /// Number of individual loads -+ using Iterations = MatrixShape< -+ ThreadShape::kRow / Policy::LaneMmaShape::kM, -+ ThreadShape::kColumn -+ >; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+private: -+ -+ /// Internal reference -+ cutlass::TensorRef, layout::ColumnMajor> ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator( -+ TensorRef ref, -+ int lane_id -+ ) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ MatrixCoord(Policy::LaneMmaShape::kM, 0); -+ -+ ref.add_coord_offset(lane_offset); -+ -+ ref_.reset( -+ reinterpret_cast *>(ref.data()), -+ ref.stride(0) / Policy::LaneMmaShape::kM); -+ } -+ -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { -+ -+ ref_.add_coord_offset({ -+ coord.row() * Shape::kRow / Policy::LaneMmaShape::kM, -+ coord.column() * Shape::kColumn}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator++() { -+ -+ ref_.add_coord_offset({0, Shape::kColumn}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator--() { -+ -+ ref_.add_coord_offset({0, -Shape::kColumn}); -+ -+ return *this; -+ } -+ -+ /// Loads a fragment from 
memory at the location pointed to by the iterator. (vector loads) -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ Array *dst_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kColumn; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Iterations::kRow; ++m) { -+ -+ // This logic has been replaced with calls to inline PTX to guarantee vectorization. -+ #if 0 -+ dst_ptr[m + k * Iterations::kRow] = -+ *(ref_.data() + ref_.offset({m * Policy::WarpShape::kRow, k}) + pointer_offset / Policy::LaneMmaShape::kM); -+ #endif -+ -+ auto ptr = ref_.data() + ref_.offset({m * Policy::WarpShape::kRow, k}) + pointer_offset / Policy::LaneMmaShape::kM; -+ arch::shared_load(dst_ptr[m + k * Iterations::kRow], ptr); -+ } -+ } -+ } -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { -+ -+ Array const *src_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kN; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Iterations::kM; ++m) { -+ *(ref_.data() + ref_.offset(m * Policy::WarpShape::kM, k) + pointer_offset / Policy::LaneMmaShape::kM) = -+ src_ptr[m + k * Iterations::kM]; -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. 
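// A short sketch of the fragment indexing used by the vectorized load above: vector
// (m, k) is written to slot m + k * Iterations::kRow, and each slot holds
// LaneMmaShape::kM consecutive rows owned by this lane, so the scalar position of
// element i of that vector is:
inline int simt_a_fragment_index(int m, int k, int i,
                                 int iterations_row,  // Iterations::kRow
                                 int lane_mma_m) {    // Policy::LaneMmaShape::kM
  return (m + k * iterations_row) * lane_mma_m + i;
}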
-+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for A operands of row-major layouts -+/// -+/// Concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Number of partitions along K dimension - used in sliced-K -+ int PartitionsK, -+ /// Group Size along kPartition - used in sliced-K -+ int PartitionGroupSize -+> -+class MmaSimtTileIterator { -+public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kA; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of policy -+ using Layout = layout::RowMajor; -+ -+ /// Decomposition of elements among threads -+ using Policy = Policy_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ // -+ // Derived quantities -+ // -+ -+ static_assert(!(Shape::kRow % Policy::WarpShape::kRow), -+ "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension."); -+ -+ static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); -+ static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); -+ static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); -+ -+ /// Thread-level shape of a fragment -+ using ThreadShape = MatrixShape< -+ Shape::kRow / Policy::WarpShape::kRow, -+ Shape::kColumn -+ >; -+ -+ static_assert(!(ThreadShape::kRow % Policy::LaneMmaShape::kM), -+ "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); -+ -+ /// Number of individual loads (scalar loads) -+ using Iterations = MatrixShape< -+ ThreadShape::kRow / Policy::LaneMmaShape::kM, -+ ThreadShape::kColumn -+ >; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+private: -+ -+ /// Internal reference -+ cutlass::TensorRef ref_; -+ -+ /// Extent of tensor -+ MatrixCoord extent_; -+ -+ /// Origin -+ MatrixCoord origin_; -+ -+ /// Used to conditionally enable extents checking -+ bool divisible_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator() : divisible_(true) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator( -+ TensorRef ref, -+ int lane_id -+ ) : extent_(Shape::kRow, Shape::kColumn), divisible_ (true) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ MatrixCoord(Policy::LaneMmaShape::kM, 0); -+ -+ origin_ = lane_offset; -+ -+ ref.add_coord_offset(lane_offset); -+ -+ ref_.reset(ref.data(), ref.stride(0)); -+ -+ } -+ -+ /// Constructor from TensorRef -+ 
CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator( -+ TensorRef ref, -+ TensorCoord extent, -+ int lane_id -+ ) : extent_(extent), divisible_ (false) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ MatrixCoord(Policy::LaneMmaShape::kM, 0); -+ -+ origin_ = lane_offset; -+ -+ ref.add_coord_offset(lane_offset); -+ -+ ref_.reset(ref.data(), ref.stride(0)); -+ -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { -+ -+ TensorCoord coord_offset( -+ coord.row() * Shape::kRow, -+ coord.column() * Shape::kColumn); -+ -+ origin_ += coord_offset; -+ -+ ref_.add_coord_offset(coord_offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator++() { -+ -+ ref_.add_coord_offset({0, Shape::kColumn}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator--() { -+ -+ ref_.add_coord_offset({0, -Shape::kColumn}); -+ -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. (scalar loads) -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kColumn; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Iterations::kRow; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Policy::LaneMmaShape::kM; i++) { -+ -+ MatrixCoord offset(m * Policy::WarpShape::kRow * Policy::LaneMmaShape::kM + i, k); -+ -+ MatrixCoord access_coord = origin_ + offset; -+ -+ int frag_idx = m * Policy::LaneMmaShape::kM + i + k * Iterations::kRow; -+ -+ if (divisible_ || -+ (access_coord.row() < extent_.row() && access_coord.column() < extent_.column())) { -+ -+ frag[frag_idx] = *(ref_.data() + ref_.offset(offset) + pointer_offset); -+ } -+ else { -+ frag[frag_idx] = Element(); -+ } -+ } -+ } -+ } -+ } -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kColumn; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Iterations::kRow; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Policy::LaneMmaShape::kM; i++) { -+ -+ *(ref_.data() + ref_.offset(m * Policy::WarpShape::kM * Policy::LaneMmaShape::kM + i, k) + pointer_offset) = -+ frag[m * Policy::LaneMmaShape::kM + i + k * Iterations::kM]; -+ } -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. 
Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for B operands of row-major layouts -+/// -+/// Concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Group Size along kPartition - used in sliced-K -+ int PartitionGroupSize -+> -+class MmaSimtTileIterator { -+public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kB; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of policy -+ using Layout = layout::RowMajor; -+ -+ /// Decomposition of elements among threads -+ using Policy = Policy_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ // -+ // Derived quantities -+ // -+ -+ static_assert(!(Shape::kColumn % Policy::WarpShape::kColumn), -+ "The warp-level GEMM N size must be divisible by the number of threads arranged along the N dimension."); -+ -+ static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); -+ static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero."); -+ static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero."); -+ -+ /// Thread-level shape of a fragment -+ using ThreadShape = MatrixShape< -+ Shape::kRow, -+ Shape::kColumn / Policy::WarpShape::kColumn -+ >; -+ -+ static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN), -+ "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); -+ -+ /// Number of individual loads -+ using Iterations = MatrixShape< -+ ThreadShape::kRow, -+ ThreadShape::kColumn / Policy::LaneMmaShape::kN -+ >; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+protected: -+ -+ /// Internal reference -+ cutlass::TensorRef, layout::RowMajor> ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator( -+ TensorRef ref, -+ int lane_id -+ ) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ MatrixCoord(0, Policy::LaneMmaShape::kN); -+ -+ ref.add_coord_offset(lane_offset); -+ -+ ref_.reset( -+ reinterpret_cast *>(ref.data()), -+ ref.stride(0) / Policy::LaneMmaShape::kN); -+ } -+ -+ /// Adds a 
pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { -+ -+ ref_.add_coord_offset({ -+ coord.row() * Shape::kRow, -+ coord.column() * Shape::kColumn / Policy::LaneMmaShape::kN}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator++() { -+ -+ ref_.add_coord_offset({Shape::kRow, 0}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator--() { -+ -+ ref_.add_coord_offset({-Shape::kRow, 0}); -+ -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ Array *dst_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kRow; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kColumn; ++n) { -+ -+ #if 0 -+ dst_ptr[n + k * Iterations::kColumn] = -+ *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kColumn}) + pointer_offset / Policy::LaneMmaShape::kN); -+ #endif -+ -+ void const *ptr = ref_.data() + ref_.offset({k, n * Policy::WarpShape::kColumn}) + pointer_offset / Policy::LaneMmaShape::kN; -+ arch::shared_load(dst_ptr[n + k * Iterations::kColumn], ptr); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { -+ -+ Array const *src_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kM; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kN; ++n) { -+ *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kN}) + pointer_offset / Policy::LaneMmaShape::kN) = -+ src_ptr[n + k * Iterations::kN]; -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag, Index pointer_offset) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. 
-+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for B operands of column-major layouts -+/// -+/// Concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Group Size along kPartition - used in sliced-K -+ int PartitionGroupSize -+> -+class MmaSimtTileIterator { -+public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kB; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of policy -+ using Layout = layout::ColumnMajor; -+ -+ /// Decomposition of elements among threads -+ using Policy = Policy_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ // -+ // Derived quantities -+ // -+ -+ static_assert(!(Shape::kColumn % Policy::WarpShape::kColumn), -+ "The warp-level GEMM N size must be divisible by the number of threads arranged along the N dimension."); -+ -+ static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); -+ static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero."); -+ static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero."); -+ -+ /// Thread-level shape of a fragment -+ using ThreadShape = MatrixShape< -+ Shape::kRow, -+ Shape::kColumn / Policy::WarpShape::kColumn -+ >; -+ -+ static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN), -+ "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); -+ -+ /// Number of individual loads -+ using Iterations = MatrixShape< -+ ThreadShape::kRow, -+ ThreadShape::kColumn / Policy::LaneMmaShape::kN -+ >; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+private: -+ -+ /// Internal reference -+ cutlass::TensorRef ref_; -+ -+ /// Extent of tensor -+ MatrixCoord extent_; -+ -+ /// Origin -+ MatrixCoord origin_; -+ -+ /// Used to conditionally enable extents checking -+ bool divisible_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator(): divisible_(true) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator( -+ TensorRef ref, -+ int lane_id -+ ): extent_(Shape::kRow, Shape::kColumn), divisible_(true) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ MatrixCoord(0, Policy::LaneMmaShape::kN); -+ -+ origin_ = lane_offset; -+ -+ ref.add_coord_offset(lane_offset); -+ -+ ref_.reset(ref.data(), ref.stride(0)); -+ } -+ -+ /// Constructor from TensorRef -+ 
CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator( -+ TensorRef ref, -+ TensorCoord extent, -+ int lane_id -+ ): extent_(extent), divisible_(false) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ MatrixCoord(0, Policy::LaneMmaShape::kN); -+ -+ origin_ = lane_offset; -+ -+ ref.add_coord_offset(lane_offset); -+ -+ ref_.reset(ref.data(), ref.stride(0)); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { -+ -+ TensorCoord coord_offset( -+ coord.row() * Shape::kRow, -+ coord.column() * Shape::kColumn); -+ -+ origin_ += coord_offset; -+ -+ ref_.add_coord_offset(coord_offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator++() { -+ -+ ref_.add_coord_offset({Shape::kRow, 0}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator--() { -+ -+ ref_.add_coord_offset({-Shape::kRow, 0}); -+ -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. (scalar loads) -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kRow; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kColumn; ++n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Policy::LaneMmaShape::kN; ++i) { -+ -+ MatrixCoord offset(k, n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN + i); -+ -+ MatrixCoord access_coord = origin_ + offset; -+ -+ int frag_idx = n * Policy::LaneMmaShape::kN + i + k * Iterations::kColumn; -+ -+ if (divisible_ || -+ (access_coord.row() < extent_.row() && access_coord.column() < extent_.column())) { -+ -+ frag[frag_idx] = *(ref_.data() + ref_.offset(offset) + pointer_offset); -+ } -+ else { -+ frag[frag_idx] = Element(); -+ } -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { -+ -+ Array const *src_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kM; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kN; ++n) { -+ *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kN}) + pointer_offset / Policy::LaneMmaShape::kN) = -+ src_ptr[n + k * Iterations::kN]; -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag, Index pointer_offset) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. 
Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for C operands of column-major layouts -+/// -+/// Concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_ -+> -+class MmaSimtTileIterator { -+public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kC; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of accumulators in memory -+ using Layout = layout::ColumnMajor; -+ -+ /// Decomposition of elements among threads -+ using Policy = Policy_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ // -+ // Derived quantities -+ // -+ -+ static_assert( -+ (!(Shape::kRow % Policy::WarpShape::kRow)) && (!(Shape::kColumn % Policy::WarpShape::kColumn)), -+ "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp."); -+ -+ static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); -+ static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); -+ static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero."); -+ static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero."); -+ -+ /// Thraed-level shape of a fragment -+ using ThreadShape = MatrixShape< -+ Shape::kRow / Policy::WarpShape::kRow, -+ Shape::kColumn / Policy::WarpShape::kColumn -+ >; -+ -+ static_assert( -+ (!(ThreadShape::kRow % Policy::LaneMmaShape::kM)) && (!(ThreadShape::kColumn % Policy::LaneMmaShape::kN)), -+ "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp."); -+ -+ /// Number of individual loads -+ using Iterations = MatrixShape< -+ ThreadShape::kRow / Policy::LaneMmaShape::kM, -+ ThreadShape::kColumn / Policy::LaneMmaShape::kN -+ >; -+ -+ using Delta = MatrixShape< -+ Policy::WarpShape::kRow * Policy::LaneMmaShape::kM, -+ Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN -+ >; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+private: -+ -+ TensorRef ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ ref_(ref) { -+ -+ // compute 
offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN); -+ -+ ref_.add_coord_offset(lane_offset); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { -+ -+ ref_.add_coord_offset({ -+ coord.row() * Shape::kRow, -+ coord.column() * Shape::kColumn}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator++() { -+ -+ ref_.add_coord_offset({Shape::kRow, 0}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator--() { -+ -+ ref_.add_coord_offset({-Shape::kRow, 0}); -+ -+ return *this; -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset( -+ Fragment &frag, ///< fragment to be loaded from memory -+ Index pointer_offset) const { ///< linear offset (in units of Element) when loading -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Iterations::kN; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { -+ -+ Array const *src_ptr = -+ reinterpret_cast const *>( -+ ref_.data() + pointer_offset + ref_.offset({0, mma_n * Delta::kN + n})); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Iterations::kM; ++mma_m) { -+ -+ Array *dst_ptr = -+ reinterpret_cast *>(&frag) + -+ mma_m + Iterations::kM * (n + mma_n * Policy::LaneMmaShape::kN); -+ -+ *dst_ptr = src_ptr[mma_m * Policy::WarpShape::kM]; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. 
-+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { -+ -+ Array *dst_ptr= -+ reinterpret_cast *>( -+ ref_.data() + pointer_offset + ref_.offset({0, mma_n * Delta::kColumn + n})); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { -+ -+ Array const *src_ptr = -+ reinterpret_cast const *>(&frag) + -+ mma_m + Iterations::kRow * (n + mma_n * Policy::LaneMmaShape::kN); -+ -+ dst_ptr[mma_m * Policy::WarpShape::kRow] = *src_ptr; -+ } -+ } -+ } -+ } -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for C operands of row-major layouts -+/// -+/// Concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_ -+> -+class MmaSimtTileIterator { -+public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kC; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of accumulators in memory -+ using Layout = layout::RowMajor; -+ -+ /// Decomposition of elements among threads -+ using Policy = Policy_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ // -+ // Derived quantities -+ // -+ -+ static_assert( -+ (!(Shape::kRow % Policy::WarpShape::kRow)) && (!(Shape::kColumn % Policy::WarpShape::kColumn)), -+ "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp."); -+ -+ static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); -+ static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); -+ static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero."); -+ static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero."); -+ -+ /// Thraed-level shape of a fragment -+ using ThreadShape = MatrixShape< -+ Shape::kRow / Policy::WarpShape::kRow, -+ Shape::kColumn / Policy::WarpShape::kColumn -+ >; -+ -+ static_assert( -+ (!(ThreadShape::kRow % Policy::LaneMmaShape::kM)) && (!(ThreadShape::kColumn % Policy::LaneMmaShape::kN)), -+ "Warp-level GEMM shape must be divisible by the 
arrangement of threads in the warp."); -+ -+ /// Number of individual loads -+ using Iterations = MatrixShape< -+ ThreadShape::kRow / Policy::LaneMmaShape::kM, -+ ThreadShape::kColumn / Policy::LaneMmaShape::kN -+ >; -+ -+ using Delta = MatrixShape< -+ Policy::WarpShape::kRow * Policy::LaneMmaShape::kM, -+ Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN -+ >; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+private: -+ -+ TensorRef ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ ref_(ref) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN); -+ -+ ref_.add_coord_offset(lane_offset); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { -+ -+ ref_.add_coord_offset({ -+ coord.row() * Shape::kRow, -+ coord.column() * Shape::kColumn}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator++() { -+ -+ ref_.add_coord_offset({Shape::kRow, 0}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator--() { -+ -+ ref_.add_coord_offset({-Shape::kRow, 0}); -+ -+ return *this; -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset( -+ Fragment &frag, ///< fragment to be loaded from memory -+ Index pointer_offset) const { ///< linear offset (in units of Element) when loading -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { -+ -+ Array const *src_ptr = -+ reinterpret_cast const *>( -+ ref_.data() + pointer_offset + ref_.offset({mma_m * Delta::kRow + m, 0})); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { -+ -+ Array *dst_ptr = -+ reinterpret_cast *>(&frag) + -+ mma_n + Iterations::kColumn * (m + mma_m * Policy::LaneMmaShape::kM); -+ -+ *dst_ptr = src_ptr[mma_n * Policy::WarpShape::kColumn]; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. 
-+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { -+ -+ Array *dst_ptr = -+ reinterpret_cast *>( -+ ref_.data() + pointer_offset + ref_.offset({mma_m * Delta::kRow + m, 0})); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { -+ -+ Array const *src_ptr = -+ reinterpret_cast const *>(&frag) + -+ mma_n + Iterations::kColumn * (m + mma_m * Policy::LaneMmaShape::kM); -+ -+ dst_ptr[mma_n * Policy::WarpShape::kColumn] = *src_ptr; -+ } -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for A operands of column-major-K interleaved layouts -+/// -+/// Concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Number of KGroups per kPartition -+ int PartitionGroupSize -+> -+class MmaSimtTileIterator, Policy_, PartitionsK, PartitionGroupSize> { -+public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kA; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of policy -+ using Layout = layout::ColumnMajorInterleaved<4> ; -+ -+ /// Decomposition of elements among threads -+ using Policy = Policy_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Iterleave factor -+ static const int kInterleave = 4; -+ -+ /// Number of partitions along K dimension -+ static const int kPartitionsK = PartitionsK; -+ -+ /// Number of KGroups per kPartition -+ static const int kGroupPerTile = PartitionGroupSize / Shape::kColumn; -+ -+ // -+ // Derived quantities -+ // -+ -+ static_assert(!(Shape::kRow % Policy::WarpShape::kRow), -+ "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension."); -+ -+ static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); -+ static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); -+ static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); -+ -+ /// Thread-level shape of a fragment -+ using ThreadShape = 
MatrixShape< -+ Shape::kRow / Policy::WarpShape::kRow, -+ Shape::kColumn -+ >; -+ -+ static_assert(!(ThreadShape::kRow % Policy::LaneMmaShape::kM) && !(ThreadShape::kColumn % Policy::LaneMmaShape::kK), -+ "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); -+ -+ /// Number of individual loads -+ using Iterations = MatrixShape< -+ ThreadShape::kRow / Policy::LaneMmaShape::kM, -+ ThreadShape::kColumn / Policy::LaneMmaShape::kK -+ >; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+private: -+ -+ /// Internal reference -+ cutlass::TensorRef, layout::ColumnMajorInterleaved<4>> ref_; -+ -+ /// group index within tile -+ int k_group_idx_; -+ -+public: -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator( -+ TensorRef ref, -+ int lane_id -+ ) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ MatrixCoord(Policy::LaneMmaShape::kM, 0); -+ -+ ref.add_coord_offset(lane_offset); -+ -+ k_group_idx_ = 0; -+ ref_.reset(reinterpret_cast *>(ref.data()), ref.stride(0)/Policy::LaneMmaShape::kMK); -+ } -+ -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { -+ -+ ref_.add_coord_offset({ -+ coord.row() * Shape::kRow / Policy::LaneMmaShape::kMK, -+ coord.column() * Shape::kColumn}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator++() { -+ -+ add_tile_offset({0, 1}); -+ -+ if (kPartitionsK > 1) { -+ ++k_group_idx_; -+ // Jump to next stage -+ if (k_group_idx_ == kGroupPerTile) { -+ k_group_idx_ = 0; -+ add_tile_offset({0, kGroupPerTile * (kPartitionsK-1)}); -+ } -+ } -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator--() { -+ -+ ref_.add_coord_offset({0, -Shape::kColumn}); -+ -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ Array *dst_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kColumn; ++k) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Iterations::kRow; ++m) { -+ -+ dst_ptr[m + k * Iterations::kRow] = -+ *((ref_.data() + ref_.offset({m * Policy::WarpShape::kRow / kInterleave, -+ k*Policy::LaneMmaShape::kK}) + pointer_offset / Policy::LaneMmaShape::kM)); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. 
-+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { -+ -+ Array const *src_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kN; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Iterations::kM; ++m) { -+ *(ref_.data() + ref_.offset(m * Policy::WarpShape::kM, k) + pointer_offset / Policy::LaneMmaShape::kM) = -+ src_ptr[m + k * Iterations::kM]; -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for B operands of row-major k-interleaved layouts -+/// -+/// Concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Number of KGroups per kPartition -+ int PartitionGroupSize -+> -+class MmaSimtTileIterator, Policy_, PartitionsK, PartitionGroupSize> { -+public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kB; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of policy -+ using Layout = layout::RowMajorInterleaved<4>; -+ -+ /// Decomposition of elements among threads -+ using Policy = Policy_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Interleave factor -+ static const int kInterleave = 4; -+ -+ /// Number of partitions along K dimension -+ static const int kPartitionsK = PartitionsK; -+ -+ /// Number of KGroups per kPartition -+ static const int kGroupPerTile = PartitionGroupSize / Shape::kRow; -+ -+ // -+ // Derived quantities -+ // -+ -+ static_assert(!(Shape::kColumn % Policy::WarpShape::kColumn), -+ "The warp-level GEMM N size must be divisible by the number of threads arranged along the N dimension."); -+ -+ static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); -+ static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero."); -+ static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / 
Policy::WarpShape::kColumn must be greater than zero."); -+ -+ /// Thread-level shape of a fragment -+ using ThreadShape = MatrixShape< -+ Shape::kRow, -+ Shape::kColumn / Policy::WarpShape::kColumn -+ >; -+ -+ static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN) && !(ThreadShape::kRow % Policy::LaneMmaShape::kK), -+ "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); -+ -+ /// Number of individual loads -+ using Iterations = MatrixShape< -+ ThreadShape::kRow / Policy::LaneMmaShape::kK, -+ ThreadShape::kColumn / Policy::LaneMmaShape::kN -+ >; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+ -+private: -+ -+ /// Internal reference -+ cutlass::TensorRef, layout::RowMajorInterleaved<4>> ref_; -+ -+ /// group index within tile -+ int k_group_idx_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator( -+ TensorRef ref, -+ int lane_id -+ ) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ MatrixCoord(0, Policy::LaneMmaShape::kN); -+ -+ ref.add_coord_offset(lane_offset); -+ -+ k_group_idx_ = 0; -+ -+ ref_.reset( -+ reinterpret_cast *>(ref.data()), -+ ref.stride(0) / Policy::LaneMmaShape::kKN); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { -+ -+ ref_.add_coord_offset({ -+ coord.row() * Shape::kRow, -+ coord.column() * Shape::kColumn / Policy::LaneMmaShape::kKN}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator++() { -+ -+ add_tile_offset({1, 0}); -+ -+ if (kPartitionsK > 1) { -+ ++k_group_idx_; -+ // Jump to next stage -+ if (k_group_idx_ == kGroupPerTile) { -+ k_group_idx_ = 0; -+ add_tile_offset({kGroupPerTile * (kPartitionsK-1), 0}); -+ } -+ } -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator--() { -+ -+ ref_.add_coord_offset({-Shape::kRow, 0}); -+ -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ Array *dst_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kRow; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kColumn; ++n) { -+ dst_ptr[n + k * Iterations::kColumn] = -+ *(ref_.data() + ref_.offset({k * Policy::LaneMmaShape::kK, -+ n * Policy::WarpShape::kColumn / kInterleave}) + pointer_offset / Policy::LaneMmaShape::kN); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. 
-+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { -+ -+ Array const *src_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kM; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kN; ++n) { -+ *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kN}) + pointer_offset / Policy::LaneMmaShape::kN) = -+ src_ptr[n + k * Iterations::kN]; -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag, Index pointer_offset) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_sparse_tensor_op.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_sparse_tensor_op.h -new file mode 100644 -index 0000000..e049f4f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_sparse_tensor_op.h -@@ -0,0 +1,339 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate -+ operations targeting sparse Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/arch/mma_sm75.h" -+#include "cutlass/arch/mma_sm80.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+#include "cutlass/gemm/warp/mma_tensor_op.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK_ = 1, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. 
-+ bool AccumulatorsInRowMajor = false, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class SparseMmaTensorOp { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = ElementA_; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = ElementB_; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = ElementC_; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Equivalant base dense mma -+ using Base = MmaTensorOp; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Base::ArchMmaOperator; -+ -+ /// Indicates math operator -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ /// Architecture tag from underlying instruction -+ using ArchTag = typename Base::ArchTag; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = typename Base::OperatorClass; -+ -+ /// Shape of underlying instruction -+ using InstructionShape = typename Base::InstructionShape; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Base::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Base::kTransformB; -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// Sparsity in Operand A -+ static int const kSparse = Policy::Operator::kSparse; -+ -+ /// Meta data size in bits -+ static int const kMetaSizeInBits = Policy::Operator::kMetaSizeInBits; -+ -+ /// Max ID2 -+ static int const kMaxID2 = Policy::Operator::kMaxID2; -+ -+ /// Data type of meta E that is moved at the same time -+ using ElementE = -+ typename cutlass::platform::conditional::type; -+ -+ /// Number of ElementA that is associated with one ElementE -+ static int const kElementsPerElementE = -+ 128 / cutlass::sizeof_bits::value; -+ -+ /// Meta data is essentially interleaved but mapped to ColumnMajor internally -+ static int const kInterleaved = 2; -+ -+ /// Layout of meta E -+ using LayoutE = cutlass::layout::ColumnMajor; -+ -+ public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, Operand::kA, ElementA, -+ LayoutA, -+ MatrixShape, -+ Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentA = -+ Array; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = typename Base::IteratorB; -+ -+ /// Storage for B tile -+ using FragmentB = typename Base::FragmentB; -+ -+ /// Storage for transformed B tile -+ using TransformedFragmentB = typename Base::TransformedFragmentB; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = typename Base::IteratorC; -+ -+ /// Storage for C tile -+ using FragmentC = typename Base::FragmentC; -+ -+ /// Iterates over the E operand in memory -+ using IteratorE = SparseMmaTensorOpMetaTileIterator< -+ MatrixShape, -+ ElementE, LayoutE, -+ MatrixShape, -+ Policy::OpDelta::kRow, 
kThreadCount, kPartitionsK>; -+ -+ /// Storage for E tile -+ using FragmentE = typename IteratorE::Fragment; -+ -+ /// Number of mma operations performed -+ using MmaIterations = typename Base::MmaIterations; -+ -+public: -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ ArchMmaOperator mma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ SparseMmaTensorOp() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ TransformedFragmentA const &A, -+ TransformedFragmentB const &B, -+ FragmentC const &C, -+ FragmentE const &E -+ ) const { -+ -+ using MmaOperandA = typename Policy::Operator::FragmentA; -+ using MmaOperandB = typename Policy::Operator::FragmentB; -+ using MmaOperandC = typename Policy::Operator::FragmentC; -+ using MmaOperandE = typename Policy::Operator::FragmentE; -+ -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ D = C; -+ -+ MmaOperandA const *ptr_A = reinterpret_cast(&A); -+ MmaOperandB const *ptr_B = reinterpret_cast(&B); -+ MmaOperandC *ptr_D = reinterpret_cast(&D); -+ MmaOperandE const *ptr_E = reinterpret_cast(&E); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; ++m) { -+ -+ int id2 = m % kMaxID2; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); -+ -+ if (AccumulatorsInRowMajor) { // matrix B is reordered -+ mma( -+ ptr_D[n_serpentine + m * MmaIterations::kColumn], -+ ptr_A[m], -+ ptr_B[n_serpentine], -+ ptr_D[n_serpentine + m * MmaIterations::kColumn], -+ ptr_E[(m / kMaxID2)], -+ id2); -+ } else { -+ mma(ptr_D[m + n_serpentine * MmaIterations::kRow], -+ ptr_A[m], -+ ptr_B[n_serpentine], -+ ptr_D[m + n_serpentine * MmaIterations::kRow], -+ ptr_E[(m / kMaxID2)], -+ id2); -+ } -+ } -+ } -+ #else -+ assert(0); -+ #endif -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ // -+ // Define conversions from source type to instruction type -+ // -+ FloatRoundStyle const kRoundA = -+ PreferredRoundingMode::kRound; -+ FloatRoundStyle const kRoundB = -+ PreferredRoundingMode::kRound; -+ detail::ConvertAndPack -+ convert_A; -+ NumericArrayConverter -+ convert_B; -+ Array const *ptr_A = -+ reinterpret_cast const *>(&A); -+ Array * -+ ptr_dst_A = reinterpret_cast *>(&dst_A); -+ -+ dst_B = convert_B(B); -+ -+ ptr_dst_A[0] = convert_A(ptr_A[0]); -+ ptr_dst_A[1] = convert_A(ptr_A[1]); -+ #else -+ assert(0); -+ #endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op.h -new file mode 100644 -index 0000000..3124618 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op.h -@@ -0,0 +1,431 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations targeting -+ Tensor Cores. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/arch/mma_sm75.h" -+#include "cutlass/arch/mma_sm80.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template -+struct ConvertAndPack { -+ -+ using Converter = NumericArrayConverter; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &source) { -+ Converter converter; -+ -+ return converter(source); -+ } -+}; -+ -+template -+struct ConvertAndPack { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &source) { -+ return source; -+ } -+}; -+ -+template -+struct ConvertAndPack { -+ -+ using Converter = NumericArrayConverter; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &source) { -+ Converter converter; -+ -+ Array tmp; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ int idx = (((i << 1) & 2) | ((i >> 1) & 1) | (i & 0xfffffffc)); -+ tmp[i] = source[idx]; -+ } -+ -+ return converter(tmp); -+ } -+}; -+ -+template -+struct ConvertAndPack { -+ -+ using Converter = NumericArrayConverter; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &source) { -+ Converter converter; -+ -+ Array tmp; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ int idx = (((i << 1) & 2) | ((i >> 1) & 1) | (i & 0xfffffffc)); -+ tmp[i] = source[idx]; -+ } -+ -+ return converter(tmp); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK_ = 1, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. 
-+ bool AccumulatorsInRowMajor = false, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class MmaTensorOp { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = ElementA_; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = ElementB_; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = ElementC_; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Policy::Operator; -+ -+ /// Indicates math operator -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ /// Architecture tag from underlying instruction -+ using ArchTag = typename ArchMmaOperator::ArchTag; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Shape of underlying instruction -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, Operand::kA, ElementA, LayoutA, -+ MatrixShape, -+ Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentA = -+ Array; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, Operand::kB, ElementB, LayoutB, -+ MatrixShape, -+ Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed B tile -+ using TransformedFragmentB = -+ Array; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaTensorOpAccumulatorTileIterator< -+ MatrixShape, ElementC, LayoutC, -+ typename ArchMmaOperator::Shape, typename Policy::OpDelta>; -+ -+ /// Storage for C tile -+ using FragmentC = typename IteratorC::Fragment; -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, -+ (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN -+ >; -+ -+public: -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ ArchMmaOperator mma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaTensorOp() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ TransformedFragmentA const &A, -+ TransformedFragmentB const &B, -+ FragmentC const &C -+ ) const { -+ -+ using MmaOperandA = typename ArchMmaOperator::FragmentA; -+ using MmaOperandB = typename 
ArchMmaOperator::FragmentB; -+ using MmaOperandC = typename ArchMmaOperator::FragmentC; -+ -+ D = C; -+ -+ MmaOperandA const *ptr_A = reinterpret_cast(&A); -+ MmaOperandB const *ptr_B = reinterpret_cast(&B); -+ MmaOperandC *ptr_D = reinterpret_cast(&D); -+ -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) -+ // Serpentine visitation order maximizing reuse of Rb -+ // The visitation order is like -+ // _ -+ // | | | | -+ // | | | | -+ // |_| |_| -+ // -+ // Down Up Down Up -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; ++m) { -+ -+ int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); -+ -+ if (AccumulatorsInRowMajor) { // matrix B is reordered -+ mma( -+ ptr_D[n + m_serpentine * MmaIterations::kColumn], -+ ptr_A[m_serpentine], -+ ptr_B[n], -+ ptr_D[n + m_serpentine * MmaIterations::kColumn]); -+ } else { -+ mma( -+ ptr_D[m_serpentine + n * MmaIterations::kRow], -+ ptr_A[m_serpentine], -+ ptr_B[n], -+ ptr_D[m_serpentine + n * MmaIterations::kRow]); -+ } -+ } -+ } -+ #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ // Serpentine visitation order maximizing reuse of Ra -+ // The visitation order is like -+ // _________ -+ // _________| -+ // |_________ -+ // __________| -+ // -+ // Right Left Right Left -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; ++m) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); -+ -+ if (AccumulatorsInRowMajor) { // matrix B is reordered -+ mma( -+ ptr_D[n_serpentine + m * MmaIterations::kColumn], -+ ptr_A[m], -+ ptr_B[n_serpentine], -+ ptr_D[n_serpentine + m * MmaIterations::kColumn]); -+ } else { -+ mma(ptr_D[m + n_serpentine * MmaIterations::kRow], -+ ptr_A[m], -+ ptr_B[n_serpentine], -+ ptr_D[m + n_serpentine * MmaIterations::kRow]); -+ } -+ } -+ } -+ #else -+ assert(0); -+ #endif -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ -+ // -+ // Define conversions from source type to instruction type -+ // -+ FloatRoundStyle const kRoundA = -+ PreferredRoundingMode::kRound; -+ FloatRoundStyle const kRoundB = -+ PreferredRoundingMode::kRound; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) -+ detail::ConvertAndPack -+ convert_A; -+ NumericArrayConverter -+ convert_B; -+ Array const *ptr_B = -+ reinterpret_cast const *>(&B); -+ Array * -+ ptr_dst_B = reinterpret_cast *>(&dst_B); -+ -+ dst_A = convert_A(A); -+ -+ ptr_dst_B[0] = convert_B(ptr_B[0]); -+ ptr_dst_B[1] = convert_B(ptr_B[1]); -+ -+ #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ detail::ConvertAndPack -+ convert_A; -+ NumericArrayConverter -+ convert_B; -+ Array const *ptr_A = -+ reinterpret_cast const *>(&A); -+ Array * -+ ptr_dst_A = reinterpret_cast *>(&dst_A); -+ -+ dst_B = convert_B(B); -+ -+ ptr_dst_A[0] = convert_A(ptr_A[0]); -+ ptr_dst_A[1] = convert_A(ptr_A[1]); -+ #else -+ assert(0); -+ #endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h" -+ 
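For context, a minimal host-side sketch of the serpentine visitation order the SM80 branch above describes: the column index is reversed on odd rows so that consecutive mma issues keep reusing the same A-operand fragment. The loop bounds below are illustrative stand-ins for MmaIterations::kRow/kColumn, not values taken from a concrete tile shape.

#include <cstdio>

int main() {
  const int kRow = 4, kColumn = 4;  // example stand-ins for MmaIterations::kRow / kColumn
  for (int m = 0; m < kRow; ++m) {
    for (int n = 0; n < kColumn; ++n) {
      // Reverse the column direction on odd rows (the "Right Left Right Left" pattern).
      int n_serpentine = (m % 2) ? (kColumn - 1 - n) : n;
      std::printf("mma(A[%d], B[%d])\n", m, n_serpentine);
    }
  }
  return 0;
}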
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h
-new file mode 100644
-index 0000000..d17edc1
---- /dev/null
-+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h
-@@ -0,0 +1,471 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+
-+/*! \file
-+    \brief Templates implementing warp-level matrix multiply-accumulate operations targeting
-+      Tensor Cores.
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/mma_sm80.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+#include "cutlass/gemm/warp/mma_tensor_op.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+enum class TensorFloat32Op { -+ k3xTF32, -+ k4xTF32 -+}; -+ -+template < -+ /// Floating-point rounding style -+ FloatRoundStyle RoundBigA_, -+ /// Floating-point rounding style -+ FloatRoundStyle RoundSmallA_, -+ /// Floating-point rounding style -+ FloatRoundStyle RoundBigB_ = RoundBigA_, -+ /// Floating-point rounding style -+ FloatRoundStyle RoundSmallB_ = RoundSmallA_, -+ /// Precision for TensorFloat32Op -+ // (k3xTF32: BigxBig, BigxSmall, SmallxBig) -+ // (k4xTF32: BigxBig, BigxSmall, SmallxBig, SmallxSmall) -+ TensorFloat32Op Precision_ = TensorFloat32Op::k3xTF32 -+ > -+struct FastF32 { -+ -+ static FloatRoundStyle const kRoundBigA = RoundBigA_; -+ static FloatRoundStyle const kRoundSmallA = RoundSmallA_; -+ static FloatRoundStyle const kRoundBigB = RoundBigB_; -+ static FloatRoundStyle const kRoundSmallB = RoundSmallB_; -+ static TensorFloat32Op const kPrecision = Precision_; -+}; -+ -+ -+namespace detail { -+ -+ template< -+ int N, -+ FloatRoundStyle RoundBig = FloatRoundStyle::round_toward_zero, -+ FloatRoundStyle RoundSmall = FloatRoundStyle::round_half_ulp_truncate -+ > -+ struct ConvertAndPackAccurateF32 { -+ -+ /// Rounding styles for big and small part -+ static FloatRoundStyle const kRoundBig = RoundBig; -+ static FloatRoundStyle const kRoundSmall = RoundSmall; -+ -+ /// Converter type -+ using Converter = NumericConverterFastF32; -+ -+ /// Source fragement -+ using SourceFragment = Array; -+ -+ /// Destination fragment -+ using DestinationFragment = Array; -+ -+ /// Converter Fragment holding two tfloat32_t elements for every float -+ using ConverterFragment = Array; -+ -+ /// Index in fargments for the big and small part -+ static int const kBigIndex = 0; -+ static int const kSmallIndex = 1; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(SourceFragment const &source, -+ DestinationFragment &dst_big, -+ DestinationFragment &dst_small) { -+ -+ Converter convert_; -+ ConverterFragment result_; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ // convert source to result fragment -+ result_ = convert_(source[i]); -+ -+ // store converted result fragments to destination fragment -+ dst_big[i] = result_[kBigIndex]; -+ dst_small[i] = result_[kSmallIndex]; -+ } -+ } -+ }; -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
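For context, a minimal host-side sketch of the big/small decomposition that FastF32 and detail::ConvertAndPackAccurateF32 above describe: each fp32 operand is split into a TF32-precision "big" part (rounded toward zero) plus a "small" residual, and with k3xTF32 the product is approximated by big*big + big*small + small*big. Names below are local to the sketch; the real converter also rounds the small part to tfloat32_t, which is omitted here.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Keep the sign, exponent and the 10 high mantissa bits (TF32 precision), rounding toward zero.
static float to_tf32_rtz(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFFE000u;  // clear the low 13 mantissa bits
  std::memcpy(&x, &bits, sizeof(bits));
  return x;
}

int main() {
  float a = 1.2345678f, b = 7.6543219f;
  float a_big = to_tf32_rtz(a), a_small = a - a_big;  // residual is exact for normal values
  float b_big = to_tf32_rtz(b), b_small = b - b_big;

  // k3xTF32: three TF32-precision products; k4xTF32 would also add a_small * b_small.
  float fast = a_big * b_big + a_big * b_small + a_small * b_big;

  std::printf("fp64 reference: %.10f\n", static_cast<double>(a) * static_cast<double>(b));
  std::printf("3xTF32 approx : %.10f\n", static_cast<double>(fast));
  return 0;
}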
-+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK_ = 1, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class MmaTensorOpFastF32; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for float*float+float => float using TF32 TensorOps -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK_, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor, -+ /// Used for partial specialization -+ typename Enable -+> -+class MmaTensorOpFastF32< -+ Shape_, -+ float, LayoutA_, -+ float, LayoutB_, -+ float, LayoutC_, -+ Policy_, PartitionsK_, -+ AccumulatorsInRowMajor, Enable> { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = float; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = float; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = float; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Policy::Operator; -+ -+ /// Indicates math operator -+ using MathOperator = arch::OpMultiplyAddFastF32; -+ -+ /// Architecture tag from underlying instruction -+ using ArchTag = typename ArchMmaOperator::ArchTag; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Shape of underlying instruction -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// Tune F32 to TF32 big small conversion for float operation 
-+ /// Different combination of big small conversin can cause different tradeoff -+ /// between speed and accuracy. Generally, use round_half_ulp_truncate can -+ /// improve the performance but hur the accuracy. -+ using MmaFastF32 = FastF32 < -+ FloatRoundStyle::round_toward_zero, // kRoundBigA -+ FloatRoundStyle::round_half_ulp_truncate, // kRoundSmallA -+ FloatRoundStyle::round_toward_zero, // kRoundBigB -+ FloatRoundStyle::round_half_ulp_truncate, // kRoundSmallB -+ TensorFloat32Op::k3xTF32 // Number of TF32 operations -+ >; -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kA, -+ ElementA, -+ LayoutA, -+ MatrixShape, -+ Policy::OpDelta::kRow, -+ kThreadCount, -+ kPartitionsK -+ >; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentA = -+ Array; -+ -+ /// Fragment bisecting big and small sections -+ using AccessTypeFragmentA = -+ Array; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kB, -+ ElementB, -+ LayoutB, -+ MatrixShape, -+ Policy::OpDelta::kRow, -+ kThreadCount, -+ kPartitionsK -+ >; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed B tile -+ using TransformedFragmentB = -+ Array; -+ -+ /// Fragment bisecting big and small sections -+ using AccessTypeFragmentB = -+ Array; -+ -+ /// Index in fargments for the big and small part -+ static int const kBigIndex = 0; -+ static int const kSmallIndex = 1; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaTensorOpAccumulatorTileIterator< -+ MatrixShape, ElementC, LayoutC, -+ typename ArchMmaOperator::Shape, typename Policy::OpDelta>; -+ -+ /// Storage for C tile -+ using FragmentC = typename IteratorC::Fragment; -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, -+ (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN -+ >; -+ -+public: -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ ArchMmaOperator mma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaTensorOpFastF32() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ TransformedFragmentA const &A, -+ TransformedFragmentB const &B, -+ FragmentC const &C -+ ) const { -+ -+ AccessTypeFragmentA const *ptr_A = reinterpret_cast(&A); -+ AccessTypeFragmentB const *ptr_B = reinterpret_cast(&B); -+ -+ // -+ // Accumulate in place -+ // -+ D = C; -+ -+ mma_operator(D, ptr_A[kSmallIndex], ptr_B[kBigIndex], D); -+ -+ mma_operator(D, ptr_A[kBigIndex], ptr_B[kSmallIndex], D); -+ -+ mma_operator(D, ptr_A[kBigIndex], ptr_B[kBigIndex], D); -+ -+ if (MmaFastF32::kPrecision == TensorFloat32Op::k4xTF32) -+ mma_operator(D, ptr_A[kSmallIndex], ptr_B[kSmallIndex], D); -+ } -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void mma_operator( -+ FragmentC &D, -+ AccessTypeFragmentA const &A, -+ AccessTypeFragmentB const &B, -+ FragmentC const &C -+ ) const { -+ -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ using MmaOperandA = typename ArchMmaOperator::FragmentA; -+ using MmaOperandB = typename ArchMmaOperator::FragmentB; -+ using 
MmaOperandC = typename ArchMmaOperator::FragmentC; -+ -+ MmaOperandA const *ptr_A = reinterpret_cast(&A); -+ MmaOperandB const *ptr_B = reinterpret_cast(&B); -+ MmaOperandC *ptr_D = reinterpret_cast(&D); -+ -+ // Serpentine visitation order maximizing reuse of Ra -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; ++m) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // This allows to reuse of Rb when at serpentine turns -+ int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); -+ -+ if (AccumulatorsInRowMajor) { // matrix B is reordered -+ mma( -+ ptr_D[n_serpentine + m * MmaIterations::kColumn], -+ ptr_A[m], -+ ptr_B[n_serpentine], -+ ptr_D[n_serpentine + m * MmaIterations::kColumn]); -+ } else { -+ mma( -+ ptr_D[m + n_serpentine * MmaIterations::kRow], -+ ptr_A[m], -+ ptr_B[n_serpentine], -+ ptr_D[m + n_serpentine * MmaIterations::kRow]); -+ } -+ } // end n loop -+ } // end m loop -+ #else -+ assert(0); -+ #endif -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ -+ // -+ // Define conversions from source type to instruction type -+ // -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ detail::ConvertAndPackAccurateF32< -+ FragmentA::kElements / 2, -+ MmaFastF32::kRoundBigA, -+ MmaFastF32::kRoundSmallA> convert_A; -+ -+ detail::ConvertAndPackAccurateF32< -+ FragmentB::kElements, -+ MmaFastF32::kRoundBigB, -+ MmaFastF32::kRoundSmallB> convert_B; -+ -+ Array *ptr_dst_B = -+ reinterpret_cast *>(&dst_B); -+ -+ convert_B(B, ptr_dst_B[0], ptr_dst_B[1]); -+ -+ Array *ptr_dst_A = -+ reinterpret_cast *>(&dst_A); -+ -+ Array const *ptr_A = -+ reinterpret_cast const *>(&A); -+ -+ convert_A(ptr_A[0], ptr_dst_A[0], ptr_dst_A[2]); -+ -+ convert_A(ptr_A[1], ptr_dst_A[1], ptr_dst_A[3]); -+ #else -+ assert(0); -+ #endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h -new file mode 100644 -index 0000000..aa2806d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h -@@ -0,0 +1,528 @@ -+/*! \file -+ \brief This defines a "fragment" iterator for visiting the fragments of a warp tile -+ that participate in one warp-level mma operation. -+ -+ Typically, this is used to access the accumulator tile/fragement of a warp-level mma operation. -+ The accumulator tile is then partitioned into smaller tiles/fragments that can be fed into -+ next warp-level mma operation. -+ -+ This iterator is necessary to accomplish warp-level mma fusion where the accumulator tile is -+ reused as multiplicand tile for the next mma. 
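For context, a small sketch of the residual-tile bookkeeping performed by add_offset() in the fragment iterator that follows: the (possibly partial) residual K block is visited first, and once the running index passes the full-K-block iteration count it is rebased by the residual iteration count so later full K blocks index the accumulator correctly. The constants below are made-up example values, not derived from a real accumulator shape.

#include <cstdio>

int main() {
  const int kKBlockColumnIterations = 4;  // iterations that would cover one full K block
  const int kResidualIndex = 1;           // iterations actually covered by the residual block
  int index = 0;
  bool is_residual_tile = true;
  for (int step = 0; step < 6; ++step) {
    ++index;  // mirrors add_offset(1)
    if (is_residual_tile && index >= kKBlockColumnIterations) {
      index = index - kKBlockColumnIterations + kResidualIndex;
      is_residual_tile = false;
    }
    std::printf("step %d -> index %d, residual=%d\n", step, index, is_residual_tile ? 1 : 0);
  }
  return 0;
}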
-+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/numeric_conversion.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Size of the accumulation tile shape (concept: MatrixShape) -+ typename AccumulatorShape_, -+ /// KBlocks columns to compute residual -+ int KBlocksColumn_, -+ /// Accumulator Element type -+ typename ElementAccumulator_, -+ /// Element type -+ typename Element_, -+ /// Layout of operand in memory -+ typename Layout_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Output operation on the fragment -+ typename OutputOp_> -+class MmaTensorOpFragmentIterator; -+ -+ -+// Partial specialization for col-major accumulator tile -+ -+template < -+ /// Shape of warp tile to load (concept: MatrixShape) -+ typename Shape_, -+ /// Shape of the warp accumulation tile (concept: MatrixShape) -+ typename AccumulatorShape_, -+ /// KBlocks columns to compute residual -+ int KBlocksColumn_, -+ /// Accumulator Element type -+ typename ElementAccumulator_, -+ /// Element type -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Output operation on fragment -+ typename OutputOp_> -+class MmaTensorOpFragmentIterator { -+ public: -+ -+ /// Shape of warp tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Shape of the warp accumulation tile (concept: MatrixShape) -+ using AccumulatorShape = AccumulatorShape_; -+ -+ /// KBlocks columns to compute residual -+ static int const kKBlockColumn = KBlocksColumn_; -+ -+ /// Accumulator Element type -+ using ElementAccumulator = ElementAccumulator_; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Output operation on fragment -+ using OutputOp = OutputOp_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kRow % InstructionShape::kM) && -+ !(Shape::kColumn % InstructionShape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ static_assert( -+ AccumulatorShape::kRow == Shape::kRow, -+ "Rows of Warp Accumulator must be the same as rows of warp"); -+ static_assert( -+ !(AccumulatorShape::kColumn % Shape::kColumn), -+ "Shape of Warp Accumulator must be divisible by warp shape."); -+ static_assert( -+ !(kKBlockColumn % Shape::kColumn), -+ "KBlock size must be divisible by warp shape."); -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = AccumulatorShape::kCount / Shape::kCount; -+ }; -+ -+private: -+ -+ static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads; -+ -+ /// Number of mma operations performed by a warp -+ using MmaIterations = MatrixShape; -+ /// Number of mma operations performed by the entire accumulator -+ using AccumulatorIterations = MatrixShape; -+ -+ /// Number of K 
iterations -+ static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn; -+ static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn; -+ static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn -+ * (AccumulatorShape::kRow / Shape::kRow); -+ static int const kResidualIndex = kResidualColumn / Shape::kColumn -+ * (AccumulatorShape::kRow / Shape::kRow); -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array; -+ -+ /// Accumulator Fragment object -+ using AccumulatorFragment = Array; -+ -+ /// Scale Bias Element Type -+ using ElementScaleBias = typename OutputOp::ElementCompute; -+ -+ /// Scale Bias Fragment object -+ using ScaleBiasFragment = Array; -+ -+ -+private: -+ -+ /// Internal access type -+ using AccessType = Array; -+ using FragmentAccessType = Array; -+ -+ using ScaleBiasAccessType = Array; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+ /// Used to access residual tile first -+ bool is_residual_tile_; -+ -+public: -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpFragmentIterator(AccumulatorFragment const &accum) -+ : accumulators_(reinterpret_cast(&accum)), -+ index_(0), is_residual_tile_(true) {} -+ -+ /// Add offset -+ CUTLASS_HOST_DEVICE -+ void add_offset(int index_offset) { -+ index_ += index_offset; -+ if(is_residual_tile_ && index_ >= kKBlockColumnIterations) { -+ index_ = index_ - kKBlockColumnIterations + kResidualIndex; -+ is_residual_tile_ = false; -+ } -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpFragmentIterator &operator++() { -+ add_offset(1); -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpFragmentIterator &operator--() { -+ add_offset(-1); -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, OutputOp output_op) const { -+ -+ if (output_op.is_source_needed()) //beta must be zero -+ assert(0); -+ -+ FragmentAccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ int index = index_ * MmaIterations::kCount; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; n++) { -+ for (int m = 0; m < MmaIterations::kRow; m++) { -+ int accumulator_access_offset = -+ n * AccumulatorIterations::kRow + m + index; -+ -+ frag_ptr[m * MmaIterations::kColumn + n].clear(); -+ if(!(is_residual_tile_ && index_ >= kResidualIndex)) -+ frag_ptr[m * MmaIterations::kColumn + n] = output_op(accumulators_[accumulator_access_offset]); -+ } -+ } -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ /// Then apply per-channel scale and bias -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, ScaleBiasFragment &scale, -+ ScaleBiasFragment &bias, OutputOp output_op) const { -+ -+ if (output_op.is_source_needed()) //beta must be zero -+ assert(0); -+ -+ FragmentAccessType *frag_ptr = reinterpret_cast(&frag); -+ ScaleBiasAccessType * scale_ptr = reinterpret_cast(&scale); -+ ScaleBiasAccessType * bias_ptr = reinterpret_cast(&bias); -+ -+ int index = index_ * MmaIterations::kCount; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; n++) { -+ for (int m = 0; m < MmaIterations::kRow; m++) { -+ int 
accumulator_access_offset = -+ n * AccumulatorIterations::kRow + m + index; -+ -+ frag_ptr[m * MmaIterations::kColumn + n].clear(); -+ if(!(is_residual_tile_ && index_ >= kResidualIndex)) -+ frag_ptr[m * MmaIterations::kColumn + n] = -+ output_op(accumulators_[accumulator_access_offset], -+ scale_ptr[n] /*scale*/, bias_ptr[n] /*bias*/); -+ } -+ } -+ } -+ -+ -+ -+}; -+ -+// Partial specialization for row-major accumulator tile -+ -+template < -+ /// Shape of warp tile to load (concept: MatrixShape) -+ typename Shape_, -+ /// Shape of the warp accumulation tile (concept: MatrixShape) -+ typename AccumulatorShape_, -+ /// KBlocks columns to compute residual -+ int KBlocksColumn_, -+ /// Accumulator Element type -+ typename ElementAccumulator_, -+ /// Element type -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Output operation on fragment -+ typename OutputOp_> -+class MmaTensorOpFragmentIterator { -+ public: -+ -+ /// Shape of warp tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Shape of the warp accumulation tile (concept: MatrixShape) -+ using AccumulatorShape = AccumulatorShape_; -+ -+ /// KBlocks columns to compute residual -+ static int const kKBlockColumn = KBlocksColumn_; -+ -+ /// Accumulator Element type -+ using ElementAccumulator = ElementAccumulator_; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Output operation on fragment -+ using OutputOp = OutputOp_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kRow % InstructionShape::kM) && -+ !(Shape::kColumn % InstructionShape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ static_assert( -+ AccumulatorShape::kRow == Shape::kRow, -+ "Rows of Warp Accumulator must be the same as rows of warp"); -+ static_assert( -+ !(AccumulatorShape::kColumn % Shape::kColumn), -+ "Shape of Warp Accumulator must be divisible by warp shape."); -+ static_assert( -+ !(kKBlockColumn % Shape::kColumn), -+ "KBlock size must be divisible by warp shape."); -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = AccumulatorShape::kCount / Shape::kCount; -+ }; -+ -+private: -+ -+ static int const kRowsPerIteration = 8; -+ static int const kColumnsPerIteration = 16; -+ static int const kElementsPerIteration = kRowsPerIteration * InstructionShape::kN / kThreads; -+ static int const kElementsPerAccess = kRowsPerIteration * kColumnsPerIteration / kThreads; -+ static int const kIterationsPerAccess = kElementsPerAccess / kElementsPerIteration; -+ -+ // Number of iterations per actual instruction -+ static int const kIterationsPerInstruction = InstructionShape::kM / kRowsPerIteration; -+ -+ static int const kAccessStride = kIterationsPerInstruction; -+ -+ /// Number of mma operations performed by a warp -+ using MmaIterations = MatrixShape; -+ /// Number of mma operations performed by the entire accumulator -+ using AccumulatorIterations = MatrixShape; -+ -+ /// Number of Accesses in a warp -+ using AccessIterations = MatrixShape; -+ -+ /// Number of K iterations -+ static int const kKBlockIterations = (AccumulatorShape::kColumn + 
kKBlockColumn - 1) / kKBlockColumn; -+ static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn; -+ static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn; -+ static int const kResidualIndex = kResidualColumn / Shape::kColumn; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array; -+ -+ /// Accumulator Fragment object -+ using AccumulatorFragment = Array; -+ -+ /// Scale Bias Element Type -+ using ElementScaleBias = typename OutputOp::ElementCompute; -+ -+ /// Scale Bias Fragment object -+ using ScaleBiasFragment = Array; -+ -+ -+private: -+ -+ /// Internal access type -+ using AccessType = Array; -+ using FragmentAccessType = Array; -+ using ScaleBiasAccessType = Array; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+ /// Used to access residual tile first -+ bool is_residual_tile_; -+ -+public: -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpFragmentIterator(AccumulatorFragment const &accum) -+ : accumulators_(reinterpret_cast(&accum)), -+ index_(0), is_residual_tile_(true) {} -+ -+ /// Add offset -+ CUTLASS_HOST_DEVICE -+ void add_offset(int index_offset) { -+ index_ += index_offset; -+ if(is_residual_tile_ && index_ >= kKBlockColumnIterations) { -+ index_ = index_ - kKBlockColumnIterations + kResidualIndex; -+ is_residual_tile_ = false; -+ } -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpFragmentIterator &operator++() { -+ add_offset(1); -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpFragmentIterator &operator--() { -+ add_offset(-1); -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_index(int idx) { -+ index_ = idx; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, OutputOp output_op) const { -+ -+ if (output_op.is_source_needed()) //beta must be zero -+ assert(0); -+ -+ FragmentAccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ int index = index_ * AccessIterations::kCount; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < AccessIterations::kCount; i++) { -+ -+ int accumulator_access_offset = index / AccessIterations::kCount * (MmaIterations::kColumn * kIterationsPerInstruction) + -+ (index % AccessIterations::kCount) / (AccessIterations::kColumn * kIterationsPerInstruction) * -+ AccumulatorIterations::kColumn * kIterationsPerInstruction + -+ (index % (AccessIterations::kColumn * kIterationsPerInstruction)) / kIterationsPerInstruction * -+ (kIterationsPerInstruction * kIterationsPerAccess) + -+ (index % kIterationsPerInstruction); -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < kIterationsPerAccess; j++) { -+ -+ frag_ptr[i*kIterationsPerAccess + j].clear(); -+ if(!(is_residual_tile_ && index_ >= kResidualIndex)) -+ frag_ptr[i*kIterationsPerAccess + j] = output_op(accumulators_[accumulator_access_offset + j * kAccessStride]); -+ } -+ index++; -+ } -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ /// Then apply per-channel scale and bias -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, ScaleBiasFragment &scale, -+ ScaleBiasFragment & bias, OutputOp output_op) const { -+ -+ if (output_op.is_source_needed()) //beta must be zero -+ assert(0); -+ -+ FragmentAccessType 
*frag_ptr = reinterpret_cast(&frag); -+ ScaleBiasAccessType * scale_ptr = reinterpret_cast(&scale); -+ ScaleBiasAccessType * bias_ptr = reinterpret_cast(&bias); -+ -+ int index = index_ * AccessIterations::kCount; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < AccessIterations::kCount; i++) { -+ -+ int accumulator_access_offset = index / AccessIterations::kCount * (MmaIterations::kColumn * kIterationsPerInstruction) + -+ (index % AccessIterations::kCount) / (AccessIterations::kColumn * kIterationsPerInstruction) * -+ AccumulatorIterations::kColumn * kIterationsPerInstruction + -+ (index % (AccessIterations::kColumn * kIterationsPerInstruction)) / kIterationsPerInstruction * -+ (kIterationsPerInstruction * kIterationsPerAccess) + -+ (index % kIterationsPerInstruction); -+ -+ int scale_bias_offset = (index -+ % (kIterationsPerInstruction * AccessIterations::kColumn)) -+ * kIterationsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < kIterationsPerAccess; j++) { -+ -+ -+ frag_ptr[i*kIterationsPerAccess + j].clear(); -+ if(!(is_residual_tile_ && index_ >= kResidualIndex)) -+ frag_ptr[i*kIterationsPerAccess + j] = output_op( -+ accumulators_[accumulator_access_offset + j * kAccessStride], -+ scale_ptr[scale_bias_offset + j], bias_ptr[scale_bias_offset + j]); -+ } -+ index++; -+ } -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_policy.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_policy.h -new file mode 100644 -index 0000000..f73ede6 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_policy.h -@@ -0,0 +1,65 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Policy describing implementation details of warp-level GEMM targeting Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Policy -+template < -+ typename Operator_, ///< hardware instruction(s) performing TensorOp (concept: arch::Mma) -+ typename OpDelta_ ///< distance between operations (concept: MatrixShape) -+> -+struct MmaTensorOpPolicy { -+ -+ using Operator = Operator_; ///< hardware instruction(s) performing TensorOp (concept: arch::Mma) -+ using OpDelta = OpDelta_; ///< distance between operations (concept: MatrixShape) -+ using MmaShape = typename Operator::Shape; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_sm70.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_sm70.h -new file mode 100644 -index 0000000..0a2449d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_sm70.h -@@ -0,0 +1,280 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations targeting -+ Tensor Cores. -+ -+ This is a work in progress. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/mma.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class MmaVoltaTensorOp { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = ElementA_; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = ElementB_; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = ElementC_; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Architecture tag -+ using ArchTag = arch::Sm70; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Policy::Operator; -+ -+ /// Indicates math operator -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ /// Underlying instruction shape -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+ /// 
Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+ /// interleaved 32x32 tiles -+ using InterleavedTileShape = GemmShape<32, 32, 4>; -+ -+ static_assert(!(Shape::kM % InterleavedTileShape::kM) && -+ !(Shape::kN % InterleavedTileShape::kN), -+ "Shape must be a multiple of InterleavedTileShape."); -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaVoltaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kA, -+ ElementA, -+ LayoutA, -+ MatrixShape< -+ ArchMmaOperator::Shape::kM, -+ ArchMmaOperator::Shape::kK -+ >, -+ Policy::OpDelta::kRow, -+ kThreadCount -+ >; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaVoltaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kB, -+ ElementB, -+ LayoutB, -+ MatrixShape< -+ ArchMmaOperator::Shape::kK, -+ ArchMmaOperator::Shape::kN -+ >, -+ Policy::OpDelta::kRow, -+ kThreadCount -+ >; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaVoltaTensorOpAccumulatorTileIterator< -+ MatrixShape, -+ ElementC, -+ LayoutC, -+ typename ArchMmaOperator::Shape, -+ typename Policy::OpDelta -+ >; -+ -+ /// Storage for C tile -+ using FragmentC = typename IteratorC::Fragment; -+ -+private: -+ -+ static_assert( -+ !(Shape::kM % ArchMmaOperator::Shape::kM) && -+ !(Shape::kN % ArchMmaOperator::Shape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ InterleavedTileShape::kM / ArchMmaOperator::Shape::kM, -+ InterleavedTileShape::kN / ArchMmaOperator::Shape::kN -+ >; -+ using TileIterations = MatrixShape< -+ Shape::kM / InterleavedTileShape::kM, -+ Shape::kN / InterleavedTileShape::kN -+ >; -+ -+ // Whether matrix B is reordered -+ bool reorder_B_; -+ -+public: -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ ArchMmaOperator mma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOp() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A, -+ FragmentB const &B, -+ FragmentC const &C) { -+ -+ using MmaOperandA = typename ArchMmaOperator::FragmentA; -+ using MmaOperandB = typename ArchMmaOperator::FragmentB; -+ using MmaOperandC = typename ArchMmaOperator::FragmentC; -+ -+ D = C; -+ -+ MmaOperandA const *ptr_A = reinterpret_cast(&A); -+ MmaOperandB const *ptr_B = reinterpret_cast(&B); -+ MmaOperandC *ptr_D = reinterpret_cast(&D); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int outer_col = 0; outer_col < TileIterations::kColumn; ++outer_col) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int inner_col = 0; inner_col < MmaIterations::kColumn; ++inner_col) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int outer_row = 0; outer_row < TileIterations::kRow; ++outer_row) { -+ CUTLASS_PRAGMA_UNROLL -+ -+ for (int inner_row = 0; inner_row < MmaIterations::kRow; ++inner_row) { -+ -+ int op_col = inner_col + MmaIterations::kColumn * outer_col; -+ -+ // Column-major serpentine sequence to maximize reuse of A operand. 
-+ int inner_row_serp = inner_row; -+ int outer_row_serp = outer_row; -+ if (op_col & 1) { -+ inner_row_serp = MmaIterations::kRow - inner_row - 1; -+ outer_row_serp = TileIterations::kRow - outer_row - 1; -+ } -+ int op_row = inner_row_serp + MmaIterations::kRow * outer_row_serp; -+ int op_idx = inner_row_serp + MmaIterations::kRow * -+ (inner_col + MmaIterations::kColumn * -+ (outer_row_serp + TileIterations::kRow * outer_col)); -+ mma( -+ ptr_D[op_idx], -+ ptr_A[op_row], -+ ptr_B[op_col], -+ ptr_D[op_idx]); -+ -+ } -+ } -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h -new file mode 100644 -index 0000000..5e4de60 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h -@@ -0,0 +1,362 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. 
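For context, a minimal sketch of the lane-to-coordinate mapping this access iterator uses for operand A (origin row = lane_id / 4, origin column = (lane_id % 4) * kElementsPerAccess), i.e. the Tensor Op 8 x 4-thread access pattern referred to in load() further down. The elements-per-access value below is an assumed example for 16-bit data, not taken from a concrete instantiation.

#include <cstdio>

int main() {
  const int kElementsPerAccess = 2;  // e.g. 32 / sizeof_bits<Element> for 16-bit elements
  for (int lane_id = 0; lane_id < 32; ++lane_id) {
    int row = lane_id / 4;                         // 8 rows of the operand-A tile
    int col = (lane_id % 4) * kElementsPerAccess;  // 4 thread groups along the K dimension
    std::printf("lane %2d -> origin (%d, %d)\n", lane_id, row, col);
  }
  return 0;
}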
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+ -+#include "cutlass/platform/platform.h" -+#include "cutlass/fast_math.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+ -+/// Tile access iterator -+/// Each iteration acess in the tile is -+/// used as multiplicand for one -+/// warp-level matrix multiplication -+template < -+ /// Size of the tile (concept: MatrixShape) -+ typename Shape_, -+ /// Operand identity -+ Operand Operand_, -+ /// Data type of A elements -+ typename Element_, -+ /// Layout of operand -+ typename Layout_, -+ /// Shape of one matrix production operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ int OpDelta_, -+ /// Number of threads participating in one matrix operation -+ int Threads = 32, -+ /// Enable Residual Support -+ bool EnableResidual = false, -+ /// Number of partitions along K dimension -+ int PartitionsK_ = 1 -+> -+class MmaTensorOpMultiplicandTileAccessIterator { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ /// Basic check -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = Layout_; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Number of elements accessed per Shared Memory load -+ static int const kElementsPerAccess = -+ (sizeof_bits::value >= 32 ? 1 : 32 / sizeof_bits::value); -+ -+ using InstructionCount = MatrixShape< -+ Shape::kRow / InstructionShape::kRow, -+ Shape::kColumn / InstructionShape::kColumn -+ >; -+ -+ static int const kIterations = (kOperand == Operand::kA) ? -+ InstructionCount::kColumn : InstructionCount::kRow; -+ -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array< -+ Element, -+ (kOperand == Operand::kA) ? 
-+ (Shape::kRow * InstructionShape::kColumn / kThreads) : -+ (Shape::kColumn * InstructionShape::kRow / kThreads) -+ >; -+ -+ /// Memory access type -+ using AccessType = AlignedArray; -+ -+private: -+ -+ /// Underlying tensor reference -+ TensorRef ref_; -+ -+ /// Extent of tensor -+ MatrixCoord extent_; -+ -+ /// Origin -+ MatrixCoord origin_; -+ -+ /// Used to load residual tile -+ bool is_residual_; -+ -+ /// residual offset of each thread -+ TensorCoord residual_offset_; -+ -+ /// Iterations in a tile -+ int iterations_; -+ -+public: -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileAccessIterator( -+ TensorRef const &ref, -+ TensorCoord extent, -+ int lane_id -+ ): ref_(ref), extent_(extent), is_residual_(false), iterations_(0) { -+ -+ if (kOperand == Operand::kA) { -+ origin_ = MatrixCoord(lane_id / 4, (lane_id % 4) * kElementsPerAccess); -+ } -+ else { -+ origin_ = MatrixCoord((lane_id % 4) * kElementsPerAccess, lane_id / 4); -+ } -+ -+ ref_.add_coord_offset(origin_); -+ -+ if(EnableResidual) { -+ // compute residual offset -+ if (kOperand == Operand::kA) { -+ typename TensorCoord::Index residual_size = -+ extent_.column() % Shape::kColumn; -+ if(residual_size) { -+ is_residual_ = true; -+ residual_offset_ = make_Coord(0, residual_size); -+ } -+ } -+ else { -+ typename TensorCoord::Index residual_size = -+ extent_.row() % Shape::kRow; -+ if(residual_size) { -+ is_residual_ = true; -+ residual_offset_ = make_Coord(residual_size, 0); -+ } -+ } -+ } -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileAccessIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): MmaTensorOpMultiplicandTileAccessIterator(ref, -+ {Shape::kRow, Shape::kColumn}, lane_id) { -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileAccessIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ origin_ += coord_offset; -+ -+ ref_.add_coord_offset(coord_offset); -+ -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ void advance() { -+ -+ if(EnableResidual && is_residual_) { -+ is_residual_ = false; -+ -+ origin_ += residual_offset_; -+ ref_.add_coord_offset(residual_offset_); -+ -+ } -+ -+ else { -+ if (kOperand == Operand::kA) { -+ add_tile_offset({0, 1}); -+ } -+ else { -+ add_tile_offset({1, 0}); -+ } -+ } -+ -+ iterations_ = 0; -+ } -+ -+ /// increase iterations in a tile -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileAccessIterator & operator++() { -+ -+ iterations_++; -+ -+ if(iterations_ >= kIterations) -+ advance(); -+ -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ int const kWarpShapeDivisibleInner = -+ (kOperand == Operand::kA ? 
InstructionShape::kColumn : InstructionShape::kRow); -+ -+ // Take advantage of Tensor Op's 8 x 4T access pattern -+ int const kAccessesInner = (kWarpShapeDivisibleInner / kElementsPerAccess) / 4; -+ -+ AccessType *access_ptr = reinterpret_cast(&frag); -+ -+ if (kOperand == Operand::kA) { -+ int const kTilesPerInstruction = InstructionShape::kRow / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow; ++inst_m_idx) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction; ++access_m_idx) { -+ int access_idx = -+ access_m_idx + kTilesPerInstruction * (inner_idx + kAccessesInner * inst_m_idx); -+ -+ MatrixCoord offset( -+ access_m_idx * 8 + inst_m_idx * InstructionShape::kRow, -+ inner_idx * 4 * kElementsPerAccess + iterations_ * InstructionShape::kColumn); -+ -+ MatrixCoord access_coord = origin_ + offset; -+ -+// if(access_coord.row() < extent_.row() && access_coord.column() < extent_.column()) { -+ -+ access_ptr[access_idx] = *reinterpret_cast( -+ ref_.data() + ref_.offset(offset)); -+// } -+// else { -+// AccessType zero; -+// zero.clear(); -+// access_ptr[access_idx] = zero; -+// } -+ } -+ } -+ } -+ } -+ else { -+ CUTLASS_PRAGMA_UNROLL -+ for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; ++inst_n_idx) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { -+ int access_idx = inner_idx + kAccessesInner * inst_n_idx; -+ -+ MatrixCoord offset( -+ inner_idx * 4 * kElementsPerAccess + iterations_ * InstructionShape::kRow, -+ inst_n_idx * 8); -+ -+ MatrixCoord access_coord = origin_ + offset; -+ -+// if(access_coord.row() < extent_.row() && access_coord.column() < extent_.column()) { -+ -+ access_ptr[access_idx] = *reinterpret_cast( -+ ref_.data() + ref_.offset(offset)); -+// } -+// else { -+// AccessType zero; -+// zero.clear(); -+// access_ptr[access_idx] = zero; -+// } -+ } -+ } -+ } -+ } -+ -+}; -+ -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h -new file mode 100644 -index 0000000..54f194f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h -@@ -0,0 +1,3982 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+ -+#include "cutlass/platform/platform.h" -+#include "cutlass/fast_math.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Operand identity -+ Operand Operand, -+ /// Data type of A elements -+ typename Element_, -+ /// Layout of operand -+ typename Layout_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ int OpDelta_, -+ /// Number of threads participating in one matrix operation -+ int Threads, -+ /// Number of partitions along K dimension -+ int PartitionsK_ = 1> -+class MmaTensorOpMultiplicandTileIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to load from shared -+/// memory and therefore must be initialized with a TensorRef to shared memory. 
-+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::TensorOpMultiplicandCongruous::value, -+ 64>, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::TensorOpMultiplicandCongruous< -+ sizeof_bits::value, 64>; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kContiguous % InstructionShape::kContiguous), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ // Determine number of elements along outer dimension per individual LDSM op -+ static int const kLdsmOpOuter = Layout::kElementsPerAccess; -+ static int const kLdsmOpInner = 8; -+ -+ static_assert(!(Shape::kContiguous % kLdsmOpOuter), -+ "Shape of warp-level mma must be divisible by LDSM's fundamental tile size."); -+ -+ static_assert(!(Shape::kStrided % kLdsmOpInner), -+ "Shape of warp-level mma must be divisible by LDSM's fundamental tile size."); -+ -+ /// Shape of one individual LDSM instruction -+ static int const LdsmShapeStrided = -+ InstructionShape::kStrided / kLdsmOpInner; -+ static int const LdsmShapeContiguous = 4 / LdsmShapeStrided; -+ using LdsmShape = -+ layout::PitchLinearShape; -+ -+ /// Number and arrangement of LDSM instructions -+ using LdsmIterations = layout::PitchLinearShape< -+ Shape::kContiguous / Layout::kElementsPerAccess / LdsmShapeContiguous, -+ 1>; -+ -+ /// Number of groups for each tile -+ static int const kGroupsPerTile = -+ Shape::kStrided / InstructionShape::kStrided; -+ }; -+ -+private: -+ -+ /// Not working on this feature at the moment. 
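
The Policy above turns the shape parameters into a count and arrangement of LDSM instructions by pure integer arithmetic. A compile-time sketch of that arithmetic for one assumed configuration (a 64x32 warp tile of 16-bit elements, 128-bit accesses, 16x8 instruction footprint; these numbers are illustrative, not taken from the diff):

// All values below are assumptions for illustration; the real ones come from
// Shape, InstructionShape and Layout in the deleted header.
constexpr int kShapeContiguous   = 64;  // warp tile, contiguous dimension
constexpr int kShapeStrided      = 32;  // warp tile, strided (K) dimension
constexpr int kInstStrided       = 8;   // instruction footprint along K
constexpr int kElementsPerAccess = 8;   // 128 bits of 16-bit data
constexpr int kLdsmOpInner       = 8;   // rows covered by one LDSM matrix

constexpr int kLdsmShapeStrided    = kInstStrided / kLdsmOpInner;       // 1
constexpr int kLdsmShapeContiguous = 4 / kLdsmShapeStrided;             // 4
constexpr int kLdsmIterContiguous  =
    kShapeContiguous / kElementsPerAccess / kLdsmShapeContiguous;       // 2
constexpr int kGroupsPerTile       = kShapeStrided / kInstStrided;      // 4

static_assert(kLdsmIterContiguous == 2 && kGroupsPerTile == 4,
              "one warp issues 2 x LDSM.x4 per k-group, 4 k-groups per tile");
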
-+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Number of internal pointers needed to reference shared memory -+ static int const kPointerCount = -+ Layout::TileShape::kContiguous / Policy::LdsmShape::kContiguous; -+ -+ /// Pointer type used for accesses -+ using AccessType = Array; -+ -+ /// Internal counter used to jump to next K partition -+ int k_group_idx_; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = -+ Array; -+ -+private: -+ -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_[kPointerCount]; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ stride_(ref.stride(0) / Layout::kElementsPerAccess), byte_offset_(0), -+ k_group_idx_(0) { -+ -+ int quad_pair = (lane_id >> 3); -+ int quad_quad = (lane_id >> 4); -+ int lane_in_quad = (lane_id & 3); -+ int lane_in_quad_pair = (lane_id & 7); -+ int lane_in_quad_quad = (lane_id & 15); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPointerCount; ++i) { -+ int partition_contiguous_idx = -1; -+ int access_contiguous_idx = -1; -+ int access_strided_idx = -1; -+ -+ if (Policy::LdsmShape::kContiguous == 4) { -+ // Matrix multiply 1688 A/B -+ // Q0 Q1 Q2 Q3 (Q stands for 1 8x128bit block). -+ // Four blocks are next to each other in the contiguous dimension. 
-+ partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ i); -+ access_contiguous_idx = (quad_pair ^ lane_in_quad); -+ access_strided_idx = lane_in_quad_pair; -+ } -+ else if (Policy::LdsmShape::kContiguous == 2 && -+ kOperand == Operand::kA) { -+ // Matrix multiply 16816 A -+ // Q0 Q1 -+ // Q2 Q3 -+ partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ (i >> 1)); -+ access_contiguous_idx = -+ (((quad_pair & 1) + ((i & 1) << 1)) ^ lane_in_quad); -+ access_strided_idx = lane_in_quad_pair + (lane_id >> 4 << 3); -+ } else if (Policy::LdsmShape::kContiguous == 2 && -+ kOperand == Operand::kB) { -+ // Matrix multiply 16816 B -+ // Q0 Q2 -+ // Q1 Q3 -+ partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ (i >> 1)); -+ access_contiguous_idx = ((quad_quad + ((i & 1) << 1)) ^ lane_in_quad); -+ access_strided_idx = lane_in_quad_quad; -+ } else if (Policy::LdsmShape::kContiguous == 1) { -+ // Matrix multiply 16832.SP B -+ // Q0 -+ // Q1 -+ // Q2 -+ // Q3 -+ partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ (i >> 2)); -+ access_contiguous_idx = ((i & 3) ^ lane_in_quad); -+ access_strided_idx = lane_id; -+ } -+ -+ int access_contiguous = -+ partition_contiguous_idx * Layout::PartitionShape::kContiguous + -+ access_contiguous_idx; -+ -+ int access_strided = access_strided_idx; -+ -+ pointer_[i] = reinterpret_cast(ref.data()) + -+ access_contiguous + access_strided * stride_; -+ } -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ byte_offset_ += offset * sizeof(Element); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ int contiguous_offset = tile_offset.contiguous(); -+ if (Shape::kContiguous == -+ Layout::PartitionShape::kContiguous * Layout::kElementsPerAccess) { -+ if (tile_offset.contiguous() % 2) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPointerCount / 2; ++i) { -+ AccessType const *tmp_pointer = pointer_[i]; -+ pointer_[i] = pointer_[i + kPointerCount / 2]; -+ pointer_[i + kPointerCount / 2] = tmp_pointer; -+ } -+ } -+ contiguous_offset = (tile_offset.contiguous() >> 1) << 1; -+ } -+ -+ int offset = (tile_offset.strided() * InstructionShape::kStrided) * -+ stride_ * Layout::kElementsPerAccess + -+ contiguous_offset * Shape::kContiguous; -+ -+ add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ add_tile_offset({0, 1}); -+ -+ if (kPartitionsK > 1) { -+ ++k_group_idx_; -+ // Jump to next stage -+ if (k_group_idx_ == Policy::kGroupsPerTile) { -+ k_group_idx_ = 0; -+ add_tile_offset( -+ {0, ((kPartitionsK - 1) * Policy::kGroupsPerTile)}); -+ } -+ } -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the opposite of the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ byte_offset_ -= stride_ * InstructionShape::kStrided * sizeof(Element) * -+ Layout::kElementsPerAccess; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in 
units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ -+ Array *fetch_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::LdsmIterations::kStrided; ++s) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::LdsmIterations::kContiguous; ++c) { -+ -+ int access_idx = c + s * Policy::LdsmIterations::kContiguous; -+ -+ AccessType const *source_ptr = -+ pointer_[c % kPointerCount] + -+ Layout::TileShape::kContiguous * (c / kPointerCount) + -+ Policy::kLdsmOpInner * Policy::LdsmShape::kStrided * s * stride_; -+ -+ char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; -+ -+ cutlass::arch::ldsm( -+ fetch_ptr[access_idx], -+ source_byte_ptr -+ ); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = -+ tile_offset.contiguous() * Shape::kContiguous / Layout::kElementsPerAccess + -+ tile_offset.strided() * InstructionShape::kStrided * stride_; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. 
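
The constructor of this congruous-layout iterator, earlier in this hunk, spreads the four 8x128-bit Q-blocks across lanes with an XOR pattern that is there to keep the 128-bit LDSM accesses spread across shared-memory banks. A host-side sketch of just that lane-to-block index mapping for the LdsmShape::kContiguous == 4 ("1688 A/B") case, with the internal pointer index i passed in:

#include <cstdio>

struct LaneMapping {
  int partition_contiguous;  // which partition of the tile
  int access_contiguous;     // XOR-swizzled column of 128-bit accesses
  int access_strided;        // row within the 8-row block
};

// Mirrors the LdsmShape::kContiguous == 4 branch of the constructor above.
LaneMapping map_lane_1688(int lane_id, int i) {
  int quad_pair         = lane_id >> 3;
  int lane_in_quad      = lane_id & 3;
  int lane_in_quad_pair = lane_id & 7;
  return {(lane_in_quad_pair >> 2) ^ i,
          quad_pair ^ lane_in_quad,
          lane_in_quad_pair};
}

int main() {
  for (int lane = 0; lane < 32; ++lane) {
    LaneMapping m = map_lane_1688(lane, /*i=*/0);
    std::printf("lane %2d -> partition %d, access %d, row %d\n",
                lane, m.partition_contiguous, m.access_contiguous,
                m.access_strided);
  }
}
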
-+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no op -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread MMA.TF32 NT TensorOps. It -+/// uses LDS.32 to load from shared memory and therefore must be initialized -+/// with a TensorRef to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::TensorOpMultiplicandCongruous<32, 32>, InstructionShape_, -+ OpDelta_, 32, PartitionsK_> { -+ public: -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand == Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for " -+ "A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::TensorOpMultiplicandCongruous<32, 32>; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kContiguous % InstructionShape::kContiguous), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ // Determine number of elements along outer dimension per individual 32bit -+ // shared memory load op. 
Every one warp of 32bit shared memory load loads -+ // 8x4 elements -+ static int const kLdsOpInner = Layout::TileShape::kStrided; -+ static int const kLdsOpOuter = kThreads / kLdsOpInner; -+ -+ static_assert(!(Shape::kContiguous % kLdsOpOuter), -+ "Shape of warp-level mma must be divisible by 32bit " -+ "fundamental tile size."); -+ -+ static_assert(!(Shape::kStrided % kLdsOpInner), -+ "Shape of warp-level mma must be divisible by 32bit " -+ "fundamental tile size."); -+ -+ /// Number of 32 bit shared memory load instructions needed by one MMA instruction -+ /// 1688 A 2x2 -+ /// 1688 B 1x2 -+ /// 16816 B 1x4 -+ static int const LdsShapeContiguous = -+ InstructionShape::kContiguous / kLdsOpOuter; -+ static int const LdsShapeStrided = InstructionShape::kStrided / kLdsOpInner; -+ using LdsShape = -+ layout::PitchLinearShape; -+ -+ /// Number and arrangement of LDS instructions -+ using LdsIterations = layout::PitchLinearShape< -+ Shape::kContiguous / LdsShapeContiguous / kLdsOpOuter, 1>; -+ -+ /// Number of groups for each tile -+ static int const kGroupsPerTile = -+ Shape::kStrided / InstructionShape::kStrided; -+ }; -+ -+ private: -+ /// Not working on this feature at the moment. -+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Number of internal pointers needed to reference shared memory -+ static int const kPointerCount = Layout::TileShape::kContiguous * -+ Layout::kElementsPerAccess / -+ Policy::kLdsOpOuter; -+ -+ /// Vectorized access is not used -+ static int const kElementsPerAccess = 1; -+ -+ /// Pointer type used for accesses -+ using AccessType = Element; -+ -+ /// Internal counter used to jump to next K partition -+ int k_group_idx_; -+ -+ public: -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = -+ Array; -+ -+ private: -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_[kPointerCount]; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+ public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() : stride_(0), byte_offset_(0) {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator(TensorRef const &ref, int lane_id) -+ : stride_(ref.stride(0)), byte_offset_(0), k_group_idx_(0) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPointerCount; ++i) { -+ int access_strided = lane_id % Policy::kLdsOpInner; -+ int access_contiguous = (lane_id / Policy::kLdsOpInner) + -+ (access_strided ^ i) * Policy::kLdsOpOuter; -+ -+ pointer_[i] = reinterpret_cast(ref.data()) + -+ access_contiguous + access_strided * stride_; -+ } -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ byte_offset_ += offset * sizeof(Element); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset( -+ TensorCoord const &tile_offset) { -+ int contiguous_offset = tile_offset.contiguous(); -+ if (Shape::kContiguous == -+ Layout::TileShape::kContiguous * Layout::kElementsPerAccess / 2) { -+ if (tile_offset.contiguous() % 2) { -+ // Matrix multiply 1688 pointer_[0] <=> pointer_[4] pointer_[1] <=> pointer_[5] -+ 
// pointer_[2] <=> pointer_[6] pointer_[3] <=> pointer_[7] -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPointerCount / 2; ++i) { -+ AccessType const *tmp_pointer = pointer_[i]; -+ pointer_[i] = pointer_[i + kPointerCount / 2]; -+ pointer_[i + kPointerCount / 2] = tmp_pointer; -+ } -+ } -+ contiguous_offset = (tile_offset.contiguous() >> 1) << 1; -+ } -+ -+ int offset = (tile_offset.strided() * InstructionShape::kStrided) * stride_ + -+ contiguous_offset * Shape::kContiguous; -+ -+ add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator++() { -+ add_tile_offset({0, 1}); -+ -+ if (kPartitionsK > 1) { -+ ++k_group_idx_; -+ // Jump to next stage -+ if (k_group_idx_ == Policy::kGroupsPerTile) { -+ k_group_idx_ = 0; -+ add_tile_offset( -+ {0, ((kPartitionsK - 1) * Policy::kGroupsPerTile)}); -+ } -+ } -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the opposite of the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator--() { -+ byte_offset_ -= stride_ * InstructionShape::kStrided * sizeof(Element) * -+ kElementsPerAccess; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator+=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator-=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. 
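
add_tile_offset above copes with an odd contiguous tile step by swapping the two halves of the precomputed pointer table and keeping only the even part of the offset, so the swizzled base addresses stay consistent. A minimal sketch of that bookkeeping, with plain integers standing in for the shared-memory pointers and an assumed, fixed pointer count:

#include <algorithm>
#include <array>

constexpr int kPointerCount = 8;  // assumed; the real value is derived from the layout

// Returns the adjusted contiguous offset after emulating the odd-offset swap.
int apply_contiguous_tile_offset(std::array<int, kPointerCount> &pointers,
                                 int tile_offset_contiguous) {
  if (tile_offset_contiguous % 2) {
    for (int i = 0; i < kPointerCount / 2; ++i) {
      std::swap(pointers[i], pointers[i + kPointerCount / 2]);
    }
  }
  // Drop the odd step; it has already been absorbed by the pointer swap.
  return (tile_offset_contiguous >> 1) << 1;
}
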
-+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { load_with_byte_offset(frag, 0); } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ Element *fetch_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::LdsIterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::LdsIterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int ss = 0; ss < Policy::LdsShape::kStrided; ++ss) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int cc = 0; cc < Policy::LdsShape::kContiguous; ++cc) { -+ int access_idx = -+ cc + (ss + (c + s * Policy::LdsIterations::kContiguous) * -+ Policy::LdsShape::kStrided) * -+ Policy::LdsShape::kContiguous; -+ int access_idx_contiguous = cc + c * Policy::LdsShape::kContiguous; -+ int access_idx_strided = -+ (ss + s * Policy::LdsShape::kStrided) * Policy::kLdsOpInner; -+ -+ AccessType const *source_ptr = -+ pointer_[access_idx_contiguous % kPointerCount] + -+ Layout::TileShape::kContiguous * Layout::kElementsPerAccess * -+ (access_idx_contiguous / kPointerCount) + -+ access_idx_strided * stride_; -+ -+ char const *source_byte_ptr = -+ reinterpret_cast(source_ptr) + byte_offset + -+ byte_offset_; -+ -+ fetch_ptr[access_idx] = -+ *reinterpret_cast(source_byte_ptr); -+ } -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = -+ tile_offset.contiguous() * Shape::kContiguous / -+ Layout::kElementsPerAccess + -+ tile_offset.strided() * InstructionShape::kStrided * stride_; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. 
Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no op -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to load from shared -+/// memory and therefore must be initialized with a TensorRef to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(Element_))>, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA, -+ "MmaTensorOpMultiplicandIterator for ColumnMajor Congruous may " -+ "only be instantiated for A operand to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(Element_))>; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element_))>, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ 
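
This column-major specialization, like the row-major one that follows it, is a thin adapter: it holds a pitch-linear Base iterator and translates matrix coordinates into (contiguous, strided) coordinates, swapping row and column only in the row-major case. A stripped-down sketch of that delegation pattern (the coordinate types and the dummy iterator here are stand-ins, not the CUTLASS classes):

struct PitchLinearCoordLite { int contiguous, strided; };
struct MatrixCoordLite      { int row, column; };

// Stand-in pitch-linear iterator, just enough to instantiate the adapters.
struct DummyPitchLinearIterator {
  PitchLinearCoordLite last{0, 0};
  void add_tile_offset(PitchLinearCoordLite c) { last = c; }
};

// Column-major A operand: contiguous dimension is the row, strided is the column.
template <typename Base>
struct ColumnMajorAdapter {
  Base base;
  void add_tile_offset(MatrixCoordLite c) { base.add_tile_offset({c.row, c.column}); }
};

// Row-major B operand: contiguous dimension is the column, strided is the row.
template <typename Base>
struct RowMajorAdapter {
  Base base;
  void add_tile_offset(MatrixCoordLite c) { base.add_tile_offset({c.column, c.row}); }
};

int main() {
  RowMajorAdapter<DummyPitchLinearIterator> b;
  b.add_tile_offset({/*row=*/3, /*column=*/5});
  return (b.base.last.contiguous == 5 && b.base.last.strided == 3) ? 0 : 1;
}
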
/// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. 
-+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.contiguous(), tile_offset.strided()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to load from shared -+/// memory and therefore must be initialized with a TensorRef to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(Element_))>, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kB, -+ "MmaTensorOpMultiplicandIterator for RowMajor Congruous may " -+ "only be instantiated for B operand to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(Element_))>; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element_))>, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a 
thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. 
-+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.strided(), tile_offset.contiguous()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to -+/// load from shared memory and therefore must be initialized with a TensorRef -+/// to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Element number when the layout crosses (in units of elements) -+ int Crosswise, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::TensorOpMultiplicandCrosswise::value, -+ Crosswise>, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand == Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for " -+ "A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Element number when the layout crosses -+ static int const kCrosswise = Crosswise; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::TensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kCrosswise>; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// 
Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kContiguous % InstructionShape::kContiguous), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ // Determine number of elements along outer dimension per individual LDSM op -+ static int const kLdsmOpOuter = Layout::kElementsPerAccess; -+ static int const kLdsmOpInner = 8; -+ -+ static_assert(!(Shape::kContiguous % kLdsmOpOuter), -+ "Shape of warp-level mma must be divisible by LDSM's " -+ "fundamental tile size."); -+ -+ static_assert(!(Shape::kStrided % kLdsmOpInner), -+ "Shape of warp-level mma must be divisible by LDSM's " -+ "fundamental tile size."); -+ -+ /// Shape of one individual LDSM instruction -+ static int const LdsmShapeContiguous = -+ InstructionShape::kContiguous / kLdsmOpOuter; -+ static int const LdsmShapeStrided = -+ ((4 / LdsmShapeContiguous * kLdsmOpInner) > Shape::kStrided) -+ ? (Shape::kStrided / kLdsmOpInner) -+ : (4 / LdsmShapeContiguous); -+ using LdsmShape = -+ layout::PitchLinearShape; -+ -+ /// Number and arrangement of LDSM instructions -+ using LdsmIterations = -+ layout::PitchLinearShape<1, Shape::kStrided / kLdsmOpInner / -+ LdsmShape::kStrided>; -+ -+ /// -+ static int const kGroupsPerTile = Layout::TileShape::kContiguous / -+ Layout::kFactor / LdsmShape::kContiguous; -+ }; -+ -+ private: -+ /// Not working on this feature at the moment. -+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Pointer type used for accesses -+ using AccessType = Array; -+ -+ public: -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+ private: -+ -+ /// Total number of sections. The memory is divided into stages. One stage -+ /// can store one tile. Stage is divided into sections. Interleaved layout -+ /// can have multiple sections in a stage. The rest layout only has one section -+ /// in a stage. -+ int sections_; -+ -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+ /// Internal counter used to determine when to increment byte offset and when -+ /// to XOR it -+ int k_group_idx_; -+ -+ public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() -+ : pointer_(nullptr), -+ sections_(0), -+ stride_(0), -+ byte_offset_(0), -+ k_group_idx_(0) {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator(TensorRef const &ref, int lane_id) -+ : pointer_(reinterpret_cast(ref.data())), -+ sections_(ref.stride(0) / kCrosswise), -+ // stride_ = kCrosswise x sections_ x kFactor -+ stride_(ref.stride(0) * Layout::kFactor / Layout::kElementsPerAccess), -+ byte_offset_(0), -+ k_group_idx_(0) { -+ // Warp level iterator at most use double buffer to hide latency. If there -+ // are more than 2 sections, every stage should have more than 1 section. 
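
The crosswise constructor above derives two quantities from the shared-memory stride alone: how many sections one stage holds, and the stride measured in vectorized accesses. A small numeric sketch of that arithmetic under assumed values (the crosswise extent, packing factor and access width below are illustrative only):

#include <cstdio>

// Assumed illustrative values; the real ones depend on Element, Crosswise and
// the shared-memory allocation in the deleted header.
constexpr int kCrosswise         = 64;  // elements the layout crosses
constexpr int kFactor            = 2;   // layout packing factor
constexpr int kElementsPerAccess = 8;   // 128-bit accesses of 16-bit data

int main() {
  int smem_stride_elements = 128;  // plays the role of ref.stride(0)
  int sections = smem_stride_elements / kCrosswise;                 // 2
  int stride_in_accesses =
      smem_stride_elements * kFactor / kElementsPerAccess;          // 32
  std::printf("sections=%d stride(in 128-bit accesses)=%d\n",
              sections, stride_in_accesses);
}
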
-+ -+ // Turing silicon requires all 32 threads in a warp provide valid addresses -+ // even for LDSM.1 and LDSM.2 -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 750)) -+ lane_id = lane_id % (Policy::LdsmShape::kCount * Policy::kLdsmOpInner); -+#endif -+ -+ int quad_quad = (lane_id >> 4); -+ int quad_pair = (lane_id >> 3); -+ int lane_in_pair = (lane_id & 1); -+ int lane_in_quad = (lane_id & 3); -+ int lane_in_quad_pair = (lane_id & 7); -+ int lane_in_quad_quad = (lane_id & 15); -+ -+ int partition_contiguous_idx = -1; -+ int access_contiguous_idx = -1; -+ int access_strided_idx = -1; -+ -+ if (Layout::kFactor == 4) { -+ // Super Integer matrix multiply Interleaved-32 -+ -+ int factor_in_partition = -+ (Layout::PartitionShape::kContiguous * Layout::kFactor / -+ Layout::TileShape::kContiguous); -+ -+ if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) { -+ // Integer matrix multiply 8816 A/B -+ partition_contiguous_idx = lane_in_quad / factor_in_partition; -+ access_contiguous_idx = ((lane_in_pair * factor_in_partition) ^ -+ (lane_in_quad_quad / Layout::kFactor)); -+ access_strided_idx = lane_id / Layout::kFactor; -+ } -+ else if (Policy::LdsmShape::kStrided == -+ (Policy::LdsmShape::kCount / 2) && -+ kOperand == Operand::kA) { -+ // Integer matrix multiply 16832 A -+ partition_contiguous_idx = lane_in_quad / factor_in_partition; -+ access_strided_idx = lane_in_quad_quad / Layout::kFactor; -+ access_contiguous_idx = -+ ((lane_in_pair * factor_in_partition + quad_quad) ^ -+ access_strided_idx); -+ } -+ else if (Policy::LdsmShape::kStrided == -+ (Policy::LdsmShape::kCount / 2) && -+ kOperand == Operand::kB) { -+ // Integer matrix multiply 16832 B -+ partition_contiguous_idx = lane_in_quad / factor_in_partition; -+ access_strided_idx = lane_in_quad_pair / Layout::kFactor + quad_quad * 2; -+ access_contiguous_idx = -+ ((lane_in_pair * factor_in_partition + ((lane_id & 8) >> 3)) ^ -+ access_strided_idx); -+ } -+ } else if (Layout::kFactor == 2) { -+ // Super Matrix multiply kBlock = 32 -+ if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) { -+ // Matrix multiply 1688 A/B -+ // (Q stands for 1 8x128bit block). -+ // Q0 -+ // Q1 -+ // Q2 -+ // Q3 -+ // Four blocks are next to each other in the strided dimension. 
-+ partition_contiguous_idx = (lane_id % Layout::kFactor); -+ access_contiguous_idx = (lane_in_quad_pair / Layout::kFactor); -+ access_strided_idx = lane_id / Layout::kFactor; -+ } -+ else if (Policy::LdsmShape::kStrided == -+ (Policy::LdsmShape::kCount / 2) && -+ kOperand == Operand::kA) { -+ // Matrix multiply 16816|1688.TF32 A -+ // Q0 Q2 -+ // Q1 Q3 -+ partition_contiguous_idx = (lane_id % Layout::kFactor); -+ access_contiguous_idx = -+ (quad_quad ^ (lane_in_quad_pair / Layout::kFactor)); -+ access_strided_idx = (lane_in_quad_quad / Layout::kFactor); -+ } else if (Policy::LdsmShape::kStrided == -+ (Policy::LdsmShape::kCount / 2) && -+ kOperand == Operand::kB) { -+ // Matrix multiply 16816|1688.TF32 B -+ // Q0 Q1 -+ // Q2 Q3 -+ partition_contiguous_idx = (lane_id % Layout::kFactor); -+ access_contiguous_idx = -+ ((quad_pair & 1) ^ (lane_in_quad_pair / Layout::kFactor)); -+ access_strided_idx = -+ (lane_in_quad_pair + (lane_id >> 4 << 3)) / Layout::kFactor; -+ } -+ else if (Policy::LdsmShape::kContiguous == Policy::LdsmShape::kCount) { -+ // Matrix multiply 16832.SP B -+ // Q0 Q1 Q2 Q3 -+ partition_contiguous_idx = (lane_id % Layout::kFactor); -+ access_contiguous_idx = -+ (quad_pair ^ (lane_in_quad_pair / Layout::kFactor)); -+ access_strided_idx = lane_in_quad_pair / Layout::kFactor; -+ } -+ } else if (Layout::kFactor == 1) { -+ // Super Matrix multiply kBlock = 64 -+ if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) { -+ // Q0 -+ // Q1 -+ // Q2 -+ // Q3 -+ partition_contiguous_idx = (lane_in_quad_pair >> 2); -+ access_contiguous_idx = lane_in_quad; -+ access_strided_idx = lane_id; -+ } -+ else if (Policy::LdsmShape::kStrided == -+ (Policy::LdsmShape::kCount / 2) && -+ kOperand == Operand::kA) { -+ // Matrix multiply 16816|1688.TF32 A -+ // Q0 Q2 -+ // Q1 Q3 -+ partition_contiguous_idx = (lane_in_quad_pair >> 2); -+ access_contiguous_idx = (quad_quad ^ lane_in_quad); -+ access_strided_idx = lane_in_quad_quad; -+ } else if (Policy::LdsmShape::kStrided == -+ (Policy::LdsmShape::kCount / 2) && -+ kOperand == Operand::kB) { -+ // Matrix multiply 16816|1688.TF32 B -+ // Q0 Q1 -+ // Q2 Q3 -+ partition_contiguous_idx = (lane_in_quad_pair >> 2); -+ access_contiguous_idx = ((quad_pair & 1) ^ lane_in_quad); -+ access_strided_idx = lane_in_quad_pair + (lane_id >> 4 << 3); -+ } -+ else if (Policy::LdsmShape::kContiguous == Policy::LdsmShape::kCount) { -+ // Matrix multiply 16832.SP B -+ // Q0 Q1 Q2 Q3 -+ partition_contiguous_idx = (lane_in_quad_pair >> 2); -+ access_contiguous_idx = (quad_pair ^ lane_in_quad); -+ access_strided_idx = lane_in_quad_pair; -+ } -+ } -+ -+ int access_contiguous = -+ partition_contiguous_idx * Layout::PartitionShape::kContiguous + -+ access_contiguous_idx; -+ -+ int access_strided = access_strided_idx; -+ -+ byte_offset_ = (access_contiguous + access_strided * stride_) * -+ sizeof_bits::value * Layout::kElementsPerAccess / 8; -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ byte_offset_ += offset * sizeof_bits::value / 8; -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset( -+ TensorCoord const &tile_offset) { -+ int whole_tiles = tile_offset.contiguous() / Policy::kGroupsPerTile; -+ int k_groups_delta = tile_offset.contiguous() % Policy::kGroupsPerTile; -+ -+ byte_offset_ ^= 
k_groups_delta * sizeof_bits::value * -+ Layout::kElementsPerAccess * -+ Policy::LdsmShape::kContiguous / 8; -+ pointer_ += -+ tile_offset.strided() * stride_ * Shape::kStrided / Layout::kFactor + -+ whole_tiles * stride_ / sections_; -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative( -+ TensorCoord const &tile_offset) { -+ -+ int whole_tiles = tile_offset.contiguous() / Policy::kGroupsPerTile; -+ int k_groups_delta = tile_offset.contiguous() % Policy::kGroupsPerTile; -+ if (k_groups_delta < 0) { -+ whole_tiles -= 1; -+ k_groups_delta += Policy::kGroupsPerTile; -+ } -+ -+ if ((Policy::kGroupsPerTile / kPartitionsK) >= 2) { -+ byte_offset_ ^= (k_groups_delta & 1) * Policy::LdsmShape::kContiguous * -+ sizeof_bits::value * -+ Layout::kElementsPerAccess / 8; -+ } -+ if ((Policy::kGroupsPerTile / kPartitionsK) >= 4) { -+ byte_offset_ ^= ((k_groups_delta + (k_group_idx_ & 1)) & 2) * -+ Policy::LdsmShape::kContiguous * -+ sizeof_bits::value * -+ Layout::kElementsPerAccess / 8; -+ } -+ if ((Policy::kGroupsPerTile / kPartitionsK) == 8) { -+ byte_offset_ ^= ((k_groups_delta + (k_group_idx_ & 3)) & 4) * -+ Policy::LdsmShape::kContiguous * -+ sizeof_bits::value * -+ Layout::kElementsPerAccess / 8; -+ } -+ -+ k_group_idx_ += k_groups_delta; -+ whole_tiles += k_group_idx_ / (Policy::kGroupsPerTile / kPartitionsK); -+ k_group_idx_ = k_group_idx_ % (Policy::kGroupsPerTile / kPartitionsK); -+ -+ pointer_ += -+ tile_offset.strided() * stride_ * Shape::kStrided / Layout::kFactor + -+ whole_tiles * stride_ / sections_; -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator++() { -+ -+ // Integer matrix multiply 16832 Interleaved-32 -+ // NONE -+ // Integer matrix multiply 16816 Interleaved-32 || Integer matrix multiply 16816 kblock=32 -+ -+ // Integer matrix multiply 8816 Interleaved-32 -+ // ^1 ^1 -+ // Matrix multiply 1684.TF32 kblock=16 || Integer matrix multiply 16816 kblock=64 -+ // Matrix multiply 1688 kblock=32 || Integer matrix multiply 8816 kblock=64 -+ // ^1 ^3 ^1 ^3 -+ // Matrix multiply 1688 kblock=64 -+ // ^1 ^3 ^1 ^7 ^1 ^3 ^1 ^7 -+ -+ // Matrix multiply 16816 kblock=32 | 1688.TF32 kblock=16 || Integer matrix multiply 16832 kblock=64 -+ // ^2 ^2 -+ // Matrix multiply 16816 kblock=64 | 1688.TF32 kblock=32 || Integer matrix multiply 16832 kblock=128 -+ // ^2 ^6 ^2 ^6 -+ -+ if ((Policy::kGroupsPerTile / kPartitionsK) > 1) { -+ int mask = ((Policy::kGroupsPerTile / kPartitionsK) == 8) -+ ? 3 -+ : (((Policy::kGroupsPerTile / kPartitionsK) == 4) ? 
1 : 0); -+ -+ if (((k_group_idx_ & mask) % 2) == 0) -+ byte_offset_ ^= 1 * Policy::LdsmShape::kContiguous * -+ sizeof_bits::value * -+ Layout::kElementsPerAccess / 8; -+ else if ((k_group_idx_ & mask) == 1) -+ byte_offset_ ^= 3 * Policy::LdsmShape::kContiguous * -+ sizeof_bits::value * -+ Layout::kElementsPerAccess / 8; -+ else if ((k_group_idx_ & mask) == 3) -+ byte_offset_ ^= 7 * Policy::LdsmShape::kContiguous * -+ sizeof_bits::value * -+ Layout::kElementsPerAccess / 8; -+ } -+ -+ k_group_idx_++; -+ -+ if (k_group_idx_ == (Policy::kGroupsPerTile / kPartitionsK)) { -+ k_group_idx_ = 0; -+ add_tile_offset({Policy::kGroupsPerTile, 0}); -+ } -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator--() { assert(0); } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator+=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator-=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { load_with_byte_offset(frag, 0); } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ Array *fetch_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::LdsmIterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::LdsmIterations::kContiguous; ++c) { -+ int access_idx = c + s * Policy::LdsmIterations::kContiguous; -+ -+ AccessType const *source_ptr = -+ pointer_ + Policy::LdsmShape::kContiguous * c + -+ Policy::kLdsmOpInner / Layout::kFactor * -+ Policy::LdsmShape::kStrided * s * stride_; -+ -+ char const *source_byte_ptr = -+ reinterpret_cast(source_ptr) + byte_offset + -+ byte_offset_; -+ -+ cutlass::arch::ldsm( -+ fetch_ptr[access_idx], source_byte_ptr); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. 
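(Illustrative host-side sketch, not part of the patch: operator++ above advances one k-group per call by XOR-ing byte_offset_ with one, three, or seven times a single LDSM access width. The helper below reproduces only the multiplier sequence for 2, 4, and 8 k-groups per partition, matching the ^1/^3/^7 annotations in the comments.)

    #include <cstdio>

    // Multiplier applied to one LDSM access width when stepping past k-group k.
    int xor_multiplier(int k_group_idx, int groups_per_partition) {
      int mask = (groups_per_partition == 8) ? 3
               : (groups_per_partition == 4) ? 1 : 0;
      int g = k_group_idx & mask;
      if ((g % 2) == 0) return 1;   // ^1
      if (g == 1)       return 3;   // ^3
      if (g == 3)       return 7;   // ^7
      return 0;
    }

    int main() {
      int cases[] = {2, 4, 8};
      for (int groups : cases) {
        std::printf("groups=%d:", groups);
        for (int k = 0; k < groups; ++k) {
          std::printf(" ^%d", xor_multiplier(k, groups));
        }
        std::printf("\n");          // groups=8 prints: ^1 ^3 ^1 ^7 ^1 ^3 ^1 ^7
      }
      return 0;
    }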
-+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = tile_offset.contiguous() * -+ InstructionShape::kContiguous / -+ Layout::kElementsPerAccess + -+ tile_offset.strided() * Shape::kStrided * stride_; -+ -+ byte_offset += sizeof_bits::value * pointer_offset / 8; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ k_group_idx_ = k_group % (Policy::kGroupsPerTile / kPartitionsK); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to -+/// load from shared memory and therefore must be initialized with a TensorRef -+/// to shared memory. 
-+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Element number when the layout crosses (in units of elements) -+ int Crosswise, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Crosswise>, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kB, -+ "MmaTensorOpMultiplicandIterator for ColumnMajor Crosswise may " -+ "only be instantiated for B operand to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// KBlock size -+ static int const kCrosswise = Crosswise; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kCrosswise>; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicandCrosswise::value, -+ kCrosswise>, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+ private: -+ /// Underlying tile iterator -+ Base iterator_; -+ -+ public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator(TensorRef const &ref, int lane_id) -+ : iterator_({ref.data(), ref.stride()}, lane_id) {} -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset( -+ TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator 
along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative( -+ TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset_negative({tile_offset.row(), tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator++() { -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator--() { -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator+=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator-=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { iterator_.load(frag); } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ assert(0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ assert(0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, {tile_offset.contiguous(), tile_offset.strided()}, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. 
Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to -+/// load from shared memory and therefore must be initialized with a TensorRef -+/// to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Element number when the layout crosses (in units of elements) -+ int Crosswise, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Crosswise>, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA, -+ "MmaTensorOpMultiplicandIterator for RowMajor Crosswise may " -+ "only be instantiated for A operand to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Element number when the layout crosses -+ static int const kCrosswise = Crosswise; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kCrosswise>; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicandCrosswise::value, -+ kCrosswise>, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+ private: -+ /// Underlying tile iterator -+ Base iterator_; -+ -+ public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator(TensorRef const &ref, int lane_id) -+ : 
iterator_({ref.data(), ref.stride()}, lane_id) {} -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset( -+ TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative( -+ TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset_negative({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator++() { -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator--() { -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator+=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator-=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { iterator_.load(frag); } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ assert(0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ assert(0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. 
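(Illustrative sketch, not part of the patch: both crosswise wrappers above are thin adapters over the pitch-linear base iterator. They only reinterpret a matrix (row, column) tile offset as a (contiguous, strided) one before forwarding; the types below are hypothetical stand-ins for the CUTLASS coordinate types.)

    struct MatrixCoordLite { int row; int column; };
    struct PitchLinearCoordLite { int contiguous; int strided; };

    // RowMajor, A operand: the contiguous (crosswise K) dimension is the column.
    inline PitchLinearCoordLite row_major_to_pitch_linear(MatrixCoordLite c) {
      return {c.column, c.row};
    }

    // ColumnMajor, B operand: the contiguous (crosswise K) dimension is the row.
    inline PitchLinearCoordLite column_major_to_pitch_linear(MatrixCoordLite c) {
      return {c.row, c.column};
    }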
-+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, {tile_offset.strided(), tile_offset.contiguous()}, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Element type -+ typename Element_, -+ /// Layout of operand in memory -+ typename Layout_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions, concept: MatrixShape) -+ typename OpDelta_> -+class MmaTensorOpAccumulatorTileIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It is used to load or store -+/// accumulators from memory and is agnostic to layout. It could be faster if it assumed row-major -+/// accumulator layout. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept | -+/// WriteableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Element type -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions, concept: MatrixShape) -+ typename OpDelta_> -+class MmaTensorOpAccumulatorTileIterator< -+ Shape_, Element_, cutlass::layout::RowMajor, InstructionShape_, OpDelta_> { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kC; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ using OpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static bool const kDivisible = -+ !(Shape::kRow % InstructionShape::kM) && -+ !(Shape::kColumn % InstructionShape::kN); -+ -+ 
static_assert(platform::is_same::value, -+ "Layouts must be defined for logical MatrixCoord coordinate space."); -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ (Shape::kRow + InstructionShape::kM - 1) / InstructionShape::kM, -+ (Shape::kColumn + InstructionShape::kN - 1) / InstructionShape::kN -+ >; -+ }; -+ -+private: -+ -+ // Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over the entire -+ // shape, with each quad mapped to one row and each thread mapped to 1/4 of the elements -+ // of that row. The accumulators within one row are assumed to be consecutive. -+ static int const kElementsPerAccess = InstructionShape::kN / 4; -+ static int const kRowsPerTile = 8; -+ static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array< -+ Element, -+ Policy::MmaIterations::kCount * InstructionShape::kMN / kThreads>; -+ -+private: -+ -+ /// Reference to output tensor -+ TensorRef ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ ref_(ref) { -+ -+ int quad = (lane_id >> 2); -+ int lane_in_quad = (lane_id & 3); -+ -+ MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess); -+ -+ ref_.add_coord_offset(lane_offset); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator++() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator--() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. 
-+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index pointer_offset) const { ///< loads a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess * -+ (mma_n * Policy::MmaIterations::kRow + mma_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < kAccumulatorRows; ++row) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < kElementsPerAccess; ++col) { -+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + -+ row * kRowsPerTile; -+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; -+ -+ frag[mma_accum_start + row * kElementsPerAccess + col] = offset_ref.at({accum_m, accum_n}); -+ } -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index byte_offset) const { ///< loads a tile with a linear offset -+ -+ load_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles -+ -+ load(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. 
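(Host-side sketch, illustrative only and not part of the patch: the (quad, lane_in_quad) mapping used by the constructor and the load loops above, specialized to a single m16n8 accumulator tile with OpDelta 1x1. Each lane owns two adjacent columns of one row in the top 8x8 block and the same columns eight rows further down.)

    #include <cstdio>

    int main() {
      const int kInstrM = 16, kInstrN = 8;                    // m16n8 instruction
      const int kElementsPerAccess = kInstrN / 4;             // = 2
      const int kRowsPerTile       = 8;
      const int kAccumulatorRows   = kInstrM / kRowsPerTile;  // = 2

      int lane_id      = 5;                                   // any lane in 0..31
      int quad         = lane_id >> 2;
      int lane_in_quad = lane_id & 3;

      for (int row = 0; row < kAccumulatorRows; ++row) {
        for (int col = 0; col < kElementsPerAccess; ++col) {
          int accum_m = quad + row * kRowsPerTile;
          int accum_n = lane_in_quad * kElementsPerAccess + col;
          std::printf("lane %d holds C[%d][%d]\n", lane_id, accum_m, accum_n);
        }
      }
      return 0;   // lane 5: C[1][2] C[1][3] C[9][2] C[9][3]
    }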
-+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles -+ Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset -+ -+ load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index pointer_offset) const { ///< store a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess * -+ (mma_n * Policy::MmaIterations::kRow + mma_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < kAccumulatorRows; ++row) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < kElementsPerAccess; ++col) { -+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + -+ row * kRowsPerTile; -+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; -+ int idx = mma_accum_start + row * kElementsPerAccess + col; -+ -+ offset_ref.at({accum_m, accum_n}) = frag[idx]; -+ } -+ } -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_byte_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index byte_offset) const { ///< store a tile with a linear offset -+ -+ store_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Stores a fragment to memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ Fragment &frag, ///< fragment to store to the tensor -+ TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles -+ -+ store(frag, tile_offset, 0); -+ } -+ -+ /// Stores a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ /// fragment to store to the tensor -+ Fragment const &frag, -+ /// stores a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// stores a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It is used to load or store -+/// accumulators from memory and is agnostic to layout. -+/// -+/// This iterator is not tested. 
-+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept | -+/// WriteableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Element type -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions, concept: MatrixShape) -+ typename OpDelta_> -+class MmaTensorOpAccumulatorTileIterator< -+ Shape_, Element_, cutlass::layout::AffineRankN<2>, InstructionShape_, OpDelta_> { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kC; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ using OpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static bool const kDivisible = -+ !(Shape::kRow % InstructionShape::kM) && -+ !(Shape::kColumn % InstructionShape::kN); -+ -+ static_assert(platform::is_same::value, -+ "Layouts must be defined for logical MatrixCoord coordinate space."); -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ (Shape::kRow + InstructionShape::kM - 1) / InstructionShape::kM, -+ (Shape::kColumn + InstructionShape::kN - 1) / InstructionShape::kN -+ >; -+ }; -+ -+private: -+ -+ // Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over the entire -+ // shape, with each quad mapped to one row and each thread mapped to 1/4 of the elements -+ // of that row. The accumulators within one row are assumed to be consecutive. 
-+ static int const kElementsPerAccess = InstructionShape::kN / 4; -+ static int const kRowsPerTile = 8; -+ static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array< -+ Element, -+ Policy::MmaIterations::kCount * InstructionShape::kMN / kThreads>; -+ -+private: -+ -+ /// Reference to output tensor -+ TensorRef ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ ref_(ref) { -+ -+ int quad = (lane_id >> 2); -+ int lane_in_quad = (lane_id & 3); -+ -+ MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess); -+ -+ ref_.add_coord_offset(lane_offset); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator++() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator--() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. 
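(Worked example, not from the patch: the size implied by the Fragment typedef above for a hypothetical 64x64 warp tile and an m16n8 instruction.)

    // MmaIterations = (64/16) x (64/8) = 4 x 8, so kCount = 32 MMA operations.
    // InstructionShape::kMN = 16 * 8 = 128 accumulators per MMA, shared by 32 lanes.
    static_assert((64 / 16) * (64 / 8) * (16 * 8) / 32 == 128,
                  "each lane holds 128 accumulator elements for a 64x64 warp tile");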
-+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index pointer_offset) const { ///< loads a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess * -+ (mma_n * Policy::MmaIterations::kRow + mma_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < kAccumulatorRows; ++row) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < kElementsPerAccess; ++col) { -+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + -+ row * kRowsPerTile; -+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; -+ -+ frag[mma_accum_start + row * kElementsPerAccess + col] = offset_ref.at({accum_m, accum_n}); -+ } -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index byte_offset) const { ///< loads a tile with a linear offset -+ -+ load_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles -+ -+ load(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. 
-+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles -+ Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset -+ -+ load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index pointer_offset) const { ///< store a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess * -+ (mma_n * Policy::MmaIterations::kRow + mma_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < kAccumulatorRows; ++row) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < kElementsPerAccess; ++col) { -+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + -+ row * kRowsPerTile; -+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; -+ int idx = mma_accum_start + row * kElementsPerAccess + col; -+ -+ offset_ref.at({accum_m, accum_n}) = frag[idx]; -+ } -+ } -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_byte_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index byte_offset) const { ///< store a tile with a linear offset -+ -+ store_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Stores a fragment to memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ Fragment &frag, ///< fragment to store to the tensor -+ TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles -+ -+ store(frag, tile_offset, 0); -+ } -+ -+ /// Stores a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ /// fragment to store to the tensor -+ Fragment const &frag, -+ /// stores a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// stores a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It is used to load or store -+/// accumulators from memory and is agnostic to layout. It could be faster if it assumed row-major -+/// accumulator layout. 
-+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept | -+/// WriteableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Element type -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions, concept: MatrixShape) -+ typename OpDelta_> -+class MmaTensorOpAccumulatorTileIterator { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kC; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ using OpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static bool const kDivisible = -+ !(Shape::kRow % InstructionShape::kM) && -+ !(Shape::kColumn % InstructionShape::kN); -+ -+ static_assert(platform::is_same::value, -+ "Layouts must be defined for logical MatrixCoord coordinate space."); -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ (Shape::kRow + InstructionShape::kM - 1) / InstructionShape::kM, -+ (Shape::kColumn + InstructionShape::kN - 1) / InstructionShape::kN -+ >; -+ }; -+ -+private: -+ -+ // Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over the entire -+ // shape, with each quad mapped to one row and each thread mapped to 1/4 of the elements -+ // of that row. The accumulators within one row are assumed to be consecutive. 
-+ static int const kElementsPerAccess = InstructionShape::kN / 4; -+ static int const kRowsPerTile = 8; -+ static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+private: -+ -+ /// Reference to output tensor -+ TensorRef ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ ref_(ref) { -+ -+ int quad = (lane_id >> 2); -+ int lane_in_quad = (lane_id & 3); -+ -+ MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess); -+ -+ ref_.add_coord_offset(lane_offset); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator++() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator--() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. 
-+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index pointer_offset) const { ///< loads a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess * -+ (mma_n * Policy::MmaIterations::kRow + mma_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < kAccumulatorRows; ++row) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < kElementsPerAccess; ++col) { -+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + -+ row * kRowsPerTile; -+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; -+ int idx = mma_accum_start + row * kElementsPerAccess + col; -+ -+ frag[idx] = offset_ref.at({accum_m, accum_n}); -+ } -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index byte_offset) const { ///< loads a tile with a linear offset -+ -+ load_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles -+ -+ load(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. 
-+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles -+ Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset -+ -+ load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index pointer_offset) const { ///< store a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess * -+ (mma_n * Policy::MmaIterations::kRow + mma_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < kAccumulatorRows; ++row) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < kElementsPerAccess; ++col) { -+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + -+ row * kRowsPerTile; -+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; -+ int idx = mma_accum_start + row * kElementsPerAccess + col; -+ -+ offset_ref.at({accum_m, accum_n}) = frag[idx]; -+ } -+ } -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_byte_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index byte_offset) const { ///< store a tile with a linear offset -+ -+ store_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Stores a fragment to memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ Fragment &frag, ///< fragment to store to the tensor -+ TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles -+ -+ store(frag, tile_offset, 0); -+ } -+ -+ /// Stores a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ /// fragment to store to the tensor -+ Fragment const &frag, -+ /// stores a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// stores a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It is used to load or store -+/// accumulators from memory and is agnostic to layout. It could be faster if it assumed row-major -+/// accumulator layout. 
-+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept | -+/// WriteableRandomAccessContiguousTileIteratorConcept -+/// -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Element typ -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions, concept: MatrixShape) -+ typename OpDelta_, -+ /// Interleaved N -+ int InterleavedN> -+class MmaTensorOpAccumulatorTileIterator< -+ Shape_, Element_, cutlass::layout::ColumnMajorInterleaved, -+ InstructionShape_, OpDelta_> { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kC; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajorInterleaved; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ using OpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kRow % InstructionShape::kM) && -+ !(Shape::kColumn % InstructionShape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ static_assert(platform::is_same::value, -+ "Layouts must be defined for logical MatrixCoord coordinate space."); -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape; -+ }; -+ -+private: -+ -+ static int const kElementsPerAccess = 2; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ using AccessType = Array; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+private: -+ -+ /// Reference to output tensor -+ TensorRef ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ ref_(ref) { -+ -+ int quad = (lane_id >> 2); -+ int lane_in_quad = (lane_id & 3); -+ -+ MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess); -+ -+ ref_.add_coord_offset(lane_offset); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ 
MmaTensorOpAccumulatorTileIterator & operator++() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator--() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index pointer_offset) const { ///< loads a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ AccessType* frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ int accum_m = mma_m * InstructionShape::kM; -+ int accum_n = mma_n * InstructionShape::kN; -+ -+ int idx = mma_m + mma_n * Policy::MmaIterations::kRow; -+ -+ AccessType* access_ptr = reinterpret_cast(offset_ref.data() + -+ offset_ref.offset(TensorCoord(accum_m, accum_n))); -+ -+ frag_ptr[idx] = access_ptr[0]; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index byte_offset) const { ///< loads a tile with a linear offset -+ -+ load_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles -+ -+ load(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. 
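(Illustrative sketch, not part of the patch: the fragment ordering used by the interleaved accumulator iterator above. Each fragment slot is a two-element AccessType placed at the top-left of its MMA tile; kRowIterations stands in for Policy::MmaIterations::kRow.)

    struct InterleavedSlot { int idx; int accum_m; int accum_n; };

    // Fragment slot and tile origin for MMA iteration (mma_m, mma_n).
    inline InterleavedSlot interleaved_slot(int mma_m, int mma_n,
                                            int kRowIterations,
                                            int kInstrM, int kInstrN) {
      InterleavedSlot s;
      s.idx     = mma_m + mma_n * kRowIterations;  // column-major over iterations
      s.accum_m = mma_m * kInstrM;                 // row of the 2-element access
      s.accum_n = mma_n * kInstrN;                 // starting column of the access
      return s;
    }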
-+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles -+ Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset -+ -+ load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index pointer_offset) const { ///< store a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ int accum_m = mma_m * InstructionShape::kM; -+ int accum_n = mma_n * InstructionShape::kN; -+ -+ int idx = mma_m + mma_n * Policy::MmaIterations::kRow; -+ -+ AccessType* access_ptr = reinterpret_cast(offset_ref.data() + -+ offset_ref.offset(TensorCoord(accum_m, accum_n))); -+ -+ access_ptr[0] = frag_ptr[idx]; -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_byte_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index byte_offset) const { ///< store a tile with a linear offset -+ -+ store_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Stores a fragment to memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ Fragment &frag, ///< fragment to store to the tensor -+ TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles -+ -+ store(frag, tile_offset, 0); -+ } -+ -+ /// Stores a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ /// fragment to store to the tensor -+ Fragment const &frag, -+ /// stores a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// stores a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It is used to load or store -+/// accumulators from memory and is agnostic to layout. It could be faster if it assumed row-major -+/// accumulator layout. 
-+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept | -+/// WriteableRandomAccessContiguousTileIteratorConcept -+/// -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Element typ -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions, concept: MatrixShape) -+ typename OpDelta_, -+ /// Interleaved N -+ int InterleavedN> -+class MmaTensorOpAccumulatorTileIterator< -+ Shape_, Element_, cutlass::layout::TensorNCxHWx, -+ InstructionShape_, OpDelta_> { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kC; -+ -+ /// Element type -+ using Element = int8_t; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::TensorNCxHWx; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ using OpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kRow % InstructionShape::kM) && -+ !(Shape::kColumn % InstructionShape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ /// Number of elements in strided dimension that each STG writes -+ static int const kStridedPerSTG = 8; -+ -+ /// Factor to calculate reorder index to pack accumulator. 
-+ static int const kPackedFactor = Shape::kColumn / 32; -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape; -+ }; -+ -+private: -+ -+ static int const kElementsPerAccess = InterleavedN / 4; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ struct alignas((kElementsPerAccess * sizeof_bits::value / 8)) AccessType { -+ Array storage; -+ }; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+private: -+ -+ /// Reference to output tensor -+ TensorRef ref_; -+ -+ /// Row offset index globally -+ LongIndex global_offset_row_; -+ -+ /// Column offset index globally -+ LongIndex global_offset_col_; -+ -+ /// Output tensor size -+ TensorCoord extent_; -+ -+ /// Alpha -+ float alpha_; -+ -+ /// Beta -+ float beta_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator( -+ TensorRef const &ref, -+ int const lane_id, -+ TensorCoord extent, -+ float alpha = 1.0f, -+ float beta = 0.0f -+ ): -+ ref_(ref), -+ extent_(extent), -+ alpha_(alpha), -+ beta_(beta) { -+ -+ int quad = (lane_id >> 2); -+ int lane_in_quad = (lane_id & 3); -+ -+ global_offset_row_ = quad; -+ -+ global_offset_col_ = lane_in_quad * kElementsPerAccess; -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_tile_offset(MatrixCoord const &tile_offset) { -+ -+ global_offset_row_ += tile_offset.row() * Shape::kRow; -+ -+ global_offset_col_ += tile_offset.column() * Shape::kColumn; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator++() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator--() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. 
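The `AccessType` defined just above wraps `kElementsPerAccess` elements in a struct whose alignment is computed from the element width, so each load or store moves one aligned chunk. A minimal sketch of that sizing, assuming `int8_t` elements and `InterleavedN = 32` as in the `TensorNCxHWx` specialization this class targets:

```cpp
// Minimal sketch (assumed values, not taken from the patch): the alignment of one
// access equals kElementsPerAccess * sizeof(Element), so 8 int8_t elements travel
// as a single 8-byte aligned chunk.
#include <cstdint>

constexpr int kInterleavedN = 32;                      // assumed interleave factor
constexpr int kElementsPerAccess = kInterleavedN / 4;  // 8 elements per access
using Element = int8_t;

struct alignas(kElementsPerAccess * sizeof(Element)) AccessType {
  Element storage[kElementsPerAccess];
};

static_assert(sizeof(AccessType) == 8, "one access moves 8 bytes");
static_assert(alignof(AccessType) == 8, "accesses are 8-byte aligned");

int main() { return 0; }
```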
-+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index pointer_offset) const { ///< loads a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ AccessType* frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kN; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kM; ++mma_m) { -+ int accum_m = mma_m * InstructionShape::kM; -+ int accum_n = mma_n * InstructionShape::kN; -+ -+ int idx = mma_m + mma_n * Policy::MmaIterations::kM; -+ -+ AccessType* access_ptr = reinterpret_cast(offset_ref.data() + -+ accum_m * offset_ref.stride(0) + accum_n); -+ -+ frag_ptr[idx] = access_ptr[0]; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index byte_offset) const { ///< loads a tile with a linear offset -+ -+ load_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles -+ -+ load(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles -+ Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset -+ -+ load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index pointer_offset) const { ///< store a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ Array output_frag_f; -+ Array output_frag; -+ -+ LongIndex pq = extent_.h() * extent_.w(); -+ -+ LongIndex extent_row = extent_.n() * pq; -+ LongIndex extent_col = extent_.c(); -+ -+ LongIndex k_major = (global_offset_col_ / InterleavedN) * pq; -+ Index k_minor = global_offset_col_ % InterleavedN; -+ LongIndex k_offset = k_major * InterleavedN + k_minor; -+ LongIndex k_offset_delta = pq * InterleavedN; -+ -+ LongIndex stride_n = pq * extent_.c(); -+ -+ Index n; -+ LongIndex pq_rem; -+ -+ unsigned int pq_mul, pq_shr; -+ find_divisor(pq_mul, pq_shr, pq); -+ -+ if(beta_ == 0.0f) { -+ CUTLASS_PRAGMA_UNROLL -+ for(int i = 0; i < frag.size(); ++i) { -+ output_frag_f[i] = frag[i]; -+ } -+ -+ if(InstructionShape::kM == Policy::kStridedPerSTG) { -+ CUTLASS_PRAGMA_UNROLL -+ for(int i = 0; i < frag.size(); ++i) { -+ output_frag[i] = (Element)(output_frag_f[i] * alpha_); -+ } -+ } else { -+ CUTLASS_PRAGMA_UNROLL -+ for(int i = 0; i < frag.size(); ++i) { -+ int map_i = (i 
/ (16 * Policy::kPackedFactor)) * (16 * Policy::kPackedFactor) -+ + (i % (8 * Policy::kPackedFactor)) / 2 * 4 -+ + (i % (8 * Policy::kPackedFactor)) % 2 -+ + (i / (8 * Policy::kPackedFactor)) % 2 * 2; -+ output_frag[i] = (Element)(output_frag_f[map_i] * alpha_); -+ } -+ } -+ -+ AccessType const *frag_ptr = reinterpret_cast(&output_frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ int accum_m = mma_m * Policy::kStridedPerSTG; -+ -+ fast_divmod(n, pq_rem, global_offset_row_ + accum_m, pq, pq_mul, pq_shr); -+ LongIndex offset_m = n * stride_n + k_offset + pq_rem * InterleavedN; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ -+ int accum_n = mma_n * InterleavedN; -+ -+ int idx = mma_n + mma_m * Policy::MmaIterations::kColumn; -+ -+ if((global_offset_row_ + accum_m < extent_row) && (global_offset_col_ + accum_n < extent_col)) { -+ AccessType* access_ptr = reinterpret_cast(offset_ref.data() + -+ offset_m + mma_n * k_offset_delta); -+ -+ access_ptr[0] = frag_ptr[idx]; -+ } -+ } -+ } -+ } else { -+ if(InstructionShape::kM == Policy::kStridedPerSTG) { -+ CUTLASS_PRAGMA_UNROLL -+ for(int i = 0; i < frag.size(); ++i) { -+ output_frag_f[i] = frag[i]; -+ } -+ } else { -+ CUTLASS_PRAGMA_UNROLL -+ for(int i = 0; i < frag.size(); ++i) { -+ int map_i = (i / (16 * Policy::kPackedFactor)) * (16 * Policy::kPackedFactor) -+ + (i % (8 * Policy::kPackedFactor)) / 2 * 4 -+ + (i % (8 * Policy::kPackedFactor)) % 2 -+ + (i / (8 * Policy::kPackedFactor)) % 2 * 2; -+ output_frag_f[i] = frag[map_i]; -+ } -+ } -+ -+ AccessType const *frag_ptr = reinterpret_cast(&output_frag); -+ -+ Array ref_frag; -+ AccessType *ref_frag_ptr = reinterpret_cast(&ref_frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ int accum_m = mma_m * Policy::kStridedPerSTG; -+ -+ fast_divmod(n, pq_rem, global_offset_row_ + accum_m, pq, pq_mul, pq_shr); -+ LongIndex offset_m = n * stride_n + k_offset + pq_rem * InterleavedN; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ -+ int accum_n = mma_n * InterleavedN; -+ -+ int idx = mma_n + mma_m * Policy::MmaIterations::kColumn; -+ -+ if((global_offset_row_ + accum_m < extent_row) && (global_offset_col_ + accum_n < extent_col)) { -+ AccessType* access_ptr = reinterpret_cast(offset_ref.data() + -+ offset_m + mma_n * k_offset_delta); -+ -+ ref_frag_ptr[0] = access_ptr[0]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int i = 0; i < kElementsPerAccess; ++i) { -+ output_frag[idx * kElementsPerAccess + i] = Element(alpha_ * output_frag_f[idx * kElementsPerAccess + i] -+ + beta_ * ref_frag[i]); -+ } -+ -+ access_ptr[0] = frag_ptr[idx]; -+ } -+ } -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_byte_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index byte_offset) const { ///< store a tile with a linear offset -+ -+ store_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Stores a fragment to memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ Fragment &frag, ///< fragment to store to the tensor -+ TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles -+ -+ store(frag, tile_offset, 0); -+ } -+ -+ /// Stores a fragment from memory with logical offset in units of whole tiles. 
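The store path above maps each logical GEMM output element `(row, col)` into the interleaved NCxHWx tensor: the row is split into an image index `n` and a spatial position `pq_rem` by dividing by `pq = H * W` (done on device with `find_divisor`/`fast_divmod`), the column is split into a block of `InterleavedN` channels plus a position inside that block, and `alpha`/`beta` are applied before the write. The sketch below is not from the patch; it redoes the same arithmetic with ordinary division and made-up extents so the final linear offset can be checked by hand.

```cpp
// Host-side sketch (example extents and values are assumptions): address arithmetic
// of the NCxHWx store path plus the alpha/beta blend applied per element.
#include <cstdio>

int main() {
  constexpr long long kInterleavedN = 32;  // interleave factor (assumed, NCxHWx<32>)
  long long N = 2, H = 7, W = 7, C = 64;   // example output tensor extents
  (void)N;

  long long pq = H * W;                    // spatial size; device code divides by this
  long long stride_n = pq * C;             // elements between consecutive images

  // Logical GEMM coordinates: row indexes (n, h, w); col indexes output channels.
  long long row = 60, col = 37;
  float accum = 1.5f, prior = 0.25f, alpha = 2.0f, beta = 0.5f;

  // fast_divmod replaces this division with a multiply+shift on device; the
  // quotient and remainder are the same.
  long long n = row / pq, pq_rem = row % pq;

  // Split the channel index into a block of kInterleavedN channels (k_major) and a
  // position inside the block (k_minor); blocks are pq * kInterleavedN apart.
  long long k_major = (col / kInterleavedN) * pq;
  long long k_minor = col % kInterleavedN;
  long long offset  = n * stride_n + k_major * kInterleavedN
                    + pq_rem * kInterleavedN + k_minor;

  float blended = alpha * accum + beta * prior;  // epilogue applied before the store
  std::printf("element (%lld, %lld) -> linear offset %lld, value %f\n",
              row, col, offset, blended);
  return 0;
}
```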
-+ CUTLASS_DEVICE -+ void store( -+ /// fragment to store to the tensor -+ Fragment const &frag, -+ /// stores a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// stores a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h -new file mode 100644 -index 0000000..bf192e6 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h -@@ -0,0 +1,3106 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm70.h" -+ -+#include "cutlass/platform/platform.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Operand identity -+ Operand Operand, -+ /// Data type of A elements -+ typename Element_, -+ /// Layout of operand -+ typename Layout_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ int OpDelta_, -+ /// Number of threads participating in one matrix operation -+ int Threads> -+class MmaVoltaTensorOpMultiplicandTileIterator; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_> -+class MmaVoltaTensorOpMultiplicandTileIterator< -+ Shape_, Operand::kA, Element_, -+ cutlass::layout::VoltaTensorOpMultiplicandCongruous< -+ sizeof_bits::value>, -+ InstructionShape_, OpDelta_, 32> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kA; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::VoltaTensorOpMultiplicandCongruous::value>; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kContiguous % InstructionShape::kContiguous), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ // Shape of one individual LDS.128 -+ // TODO: 32 and 4 are hardcoded, 32-by-4 is logical shape -+ using LdsShape = layout::PitchLinearShape< 
-+ 32, -+ 4 -+ >; -+ -+ // LdsShapes are arranged in the strided direction in SMEM -+ using LdsIterations = layout::PitchLinearShape< -+ InstructionShape::kStrided / LdsShape::kStrided, -+ Shape::kContiguous / LdsShape::kContiguous -+ >; -+ }; -+ -+private: -+ -+ /// Not working on this feature at the moment. -+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Number of internal pointers needed to reference shared memory -+ static int const kPointerCount = 2; -+ -+ /// Pointer type used for accesses -+ using AccessType = AlignedArray; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+private: -+ -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_[kPointerCount]; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ stride_(ref.stride(0) / Layout::kElementsPerAccess), byte_offset_(0) { -+ // swizzle patterns for operandA LDS are -+ // 1. (tid[4] << 3) | (tid[2:0] ^ tid[4]) -+ // 2. (tid[4] << 3) | (tid[2:0] ^ tid[4] ^ 0b10010) -+ -+ int vec_row = (lane_id >> 4); // tid[4] -+ int vec_col = ((lane_id & 4) >> 2); // tid[2] -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPointerCount; ++i) { -+ -+ if(i == 1) { -+ vec_row |= 2; -+ } -+ int access_contiguous_idx = (vec_col << 2) | ((lane_id & 3) ^ vec_row); -+ int access_contiguous = access_contiguous_idx; -+ -+ int access_strided = vec_row; -+ pointer_[i] = reinterpret_cast(ref.data()) + -+ access_contiguous + access_strided * stride_; -+ } -+ -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ byte_offset_ += offset * sizeof(Element); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ int contiguous_offset = tile_offset.contiguous(); -+ int strided_offset = tile_offset.strided(); -+ -+ // To support 32x32 tile size -+ if (Shape::kContiguous == Policy::LdsShape::kContiguous) { -+ if (contiguous_offset % 2) { -+ AccessType const *tmp_pointer = pointer_[0]; -+ pointer_[0] = pointer_[1]; -+ pointer_[1] = tmp_pointer; -+ } -+ contiguous_offset = contiguous_offset / 2 * 2; -+ } -+ -+ int offset = (strided_offset * InstructionShape::kStrided) * stride_ * -+ Layout::kElementsPerAccess + -+ contiguous_offset * Shape::kContiguous; -+ -+ add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator++() { -+ byte_offset_ += stride_ * InstructionShape::kStrided * sizeof(Element) * -+ Layout::kElementsPerAccess; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator--() { -+ byte_offset_ -= stride_ * InstructionShape::kStrided * 
sizeof(Element) * -+ Layout::kElementsPerAccess; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ -+ AccessType * fetch_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::LdsIterations::kStrided; ++s) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::LdsIterations::kContiguous; ++c) { -+ -+ int access_idx = c + s * Policy::LdsIterations::kContiguous; -+ -+ AccessType const *source_ptr = pointer_[s & 1] + -+ Policy::LdsShape::kContiguous * c + -+ Policy::LdsShape::kStrided * (s / 2) * stride_; -+ -+ char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; -+ fetch_ptr[access_idx] = *(reinterpret_cast (source_byte_ptr)); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = -+ tile_offset.contiguous() * Shape::kContiguous / -+ Layout::kElementsPerAccess + -+ tile_offset.strided() * InstructionShape::kStrided * stride_; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. 
-+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+////////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_> -+ -+class MmaVoltaTensorOpMultiplicandTileIterator< -+ Shape_, Operand::kB, Element_, -+ cutlass::layout::VoltaTensorOpMultiplicandBCongruous< -+ sizeof_bits::value>, -+ InstructionShape_, OpDelta_, 32> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kB; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::VoltaTensorOpMultiplicandBCongruous::value>; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kContiguous % InstructionShape::kContiguous), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ // Shape of one individual LDS -+ // TODO: remove hardcoded 32 and 4 -+ using LdsShape = layout::PitchLinearShape< -+ 32, -+ 4 -+ >; -+ -+ using LdsIterations = layout::PitchLinearShape< -+ Shape::kContiguous / LdsShape::kContiguous, -+ InstructionShape::kStrided / LdsShape::kStrided -+ >; -+ }; -+ -+private: -+ -+ /// Not working on this feature at the moment. 
-+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Pointer type used for accesses -+ using AccessType = AlignedArray; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile, needs on more time number of registers -+ using Fragment = Array; -+ -+private: -+ -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ stride_(ref.stride(0) / Layout::kElementsPerAccess), byte_offset_(0) { -+ -+ // swizzle pattern is (tid & (3 << 3) | (tid[1:0] ^ tid[4:3])) -+ int access_strided = (lane_id >> 3) & 0x3; -+ int access_contiguous = ((lane_id ^ (lane_id >> 3)) & 0x3); -+ -+ pointer_ = reinterpret_cast(ref.data()) + -+ access_contiguous + access_strided * stride_; -+ -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ byte_offset_ += offset * sizeof(Element); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ int contiguous_offset = tile_offset.contiguous(); -+ int strided_offset = tile_offset.strided(); -+ -+ int offset = (strided_offset * InstructionShape::kStrided) * stride_ * -+ Layout::kElementsPerAccess + -+ contiguous_offset * Shape::kContiguous; -+ -+ add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator++() { -+ byte_offset_ += stride_ * InstructionShape::kStrided * sizeof(Element) * -+ Layout::kElementsPerAccess; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator--() { -+ byte_offset_ += stride_ * InstructionShape::kStrided * sizeof(Element) * -+ Layout::kElementsPerAccess; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. 
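The Volta operand iterators above avoid shared-memory bank conflicts by XOR-swizzling each lane's base access; the operand-B constructor encodes the pattern `(tid & (3 << 3)) | (tid[1:0] ^ tid[4:3])`, and operand A uses the two per-lane pointers shown in its constructor. The short sketch below simply enumerates the operand-B mapping for all 32 lanes so the swizzle is visible; it is illustrative only and not part of the patch.

```cpp
// Standalone sketch: per-lane (contiguous, strided) base access for the Volta
// operand-B congruous layout, following the constructor's bit arithmetic above.
#include <cstdio>

int main() {
  for (int lane_id = 0; lane_id < 32; ++lane_id) {
    // swizzle pattern: (tid & (3 << 3)) | (tid[1:0] ^ tid[4:3])
    int access_strided    = (lane_id >> 3) & 0x3;
    int access_contiguous = (lane_id ^ (lane_id >> 3)) & 0x3;
    std::printf("lane %2d -> contiguous %d, strided %d\n",
                lane_id, access_contiguous, access_strided);
  }
  return 0;
}
```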
-+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ -+ AccessType * fetch_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::LdsIterations::kStrided; ++s) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::LdsIterations::kContiguous; ++c) { -+ -+ int access_idx = c + s * Policy::LdsIterations::kContiguous; -+ -+ AccessType const *source_ptr = pointer_ + -+ Policy::LdsShape::kContiguous / Layout::kElementsPerAccess * c + -+ Policy::LdsShape::kStrided * s * stride_; -+ -+ char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; -+ fetch_ptr[access_idx] = *(reinterpret_cast (source_byte_ptr)); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = -+ tile_offset.contiguous() * Shape::kContiguous / -+ Layout::kElementsPerAccess + -+ tile_offset.strided() * InstructionShape::kStrided * stride_; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+////////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to load from shared -+/// memory and therefore must be initialized with a TensorRef to shared memory. 
-+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_> -+class MmaVoltaTensorOpMultiplicandTileIterator< -+ Shape_, Operand::kA, Element_, -+ cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous< -+ sizeof_bits::value>, -+ InstructionShape_, OpDelta_, 32> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kA; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaVoltaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::VoltaTensorOpMultiplicandCongruous::value>, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ 
MmaVoltaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.contiguous(), tile_offset.strided()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to load from shared -+/// memory and therefore must be initialized with a TensorRef to shared memory. 
-+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_> -+class MmaVoltaTensorOpMultiplicandTileIterator< -+ Shape_, Operand::kB, Element_, -+ cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous< -+ sizeof_bits::value>, -+ InstructionShape_, OpDelta_, 32> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kB; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaVoltaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::VoltaTensorOpMultiplicandBCongruous::value>, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator--() { -+ 
-+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.strided(), tile_offset.contiguous()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It is used to load or store -+/// accumulators from memory and is agnostic to layout. It could be faster if it assumed row-major -+/// accumulator layout. 
-+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept | -+/// WriteableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Layout of operand in memory -+ typename Layout_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions, concept: MatrixShape) -+ typename OpDelta_> -+class MmaVoltaTensorOpAccumulatorTileIterator { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kC; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = Layout_; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ using OpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ -+ /// Volta Tensor Op uses 32x32 interleaved tile -+ using InterleavedTile = MatrixShape<32, 32>; -+ -+ static_assert(!(Shape::kRow % InterleavedTile::kRow) && !(Shape::kColumn % InterleavedTile::kColumn), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ static_assert(platform::is_same::value, -+ "Layouts must be defined for logical MatrixCoord coordinate space."); -+ -+ /// Number of mma operations performed -+ using TileIterations = MatrixShape< -+ Shape::kRow / InterleavedTile::kRow, -+ Shape::kColumn / InterleavedTile::kColumn -+ >; -+ -+ using MmaIterations = -+ MatrixShape; -+ }; -+ -+private: -+ -+ // Assume accumulator tile is multipile interleaved 32x32 tile. 
-+ static int const kElementsPerPartial = 4; -+ using EleShapePerPatial = typename platform::conditional< -+ platform::is_same::value, -+ MatrixShape<2, 2>, -+ MatrixShape<1, 4> >::type; -+ static int const kElementsPerMma = 8; -+ static int const kAccumulatorPatials = 2; -+ using QuadShapePerPatialMma = MatrixShape<4, 4>; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+private: -+ -+ /// Reference to output tensor -+ TensorRef ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpAccumulatorTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpAccumulatorTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ ref_(ref) { -+ -+ int quad = (lane_id >> 2); -+ int lane_in_quad = (lane_id & 3); -+ int accum_m, accum_n; -+ -+ if (platform::is_same::value) { -+ // (quad[2],quad[0])+lane_in_quad[0] -+ accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); -+ // (quad[1])+lane_in_quad[1] -+ accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + -+ (lane_in_quad & 2); -+ } else { -+ accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + lane_in_quad; // (quad[2],quad[0]) -+ accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; -+ } -+ MatrixCoord lane_offset(accum_m, accum_n); -+ -+ ref_.add_coord_offset(lane_offset); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpAccumulatorTileIterator & operator++() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpAccumulatorTileIterator & operator--() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. 
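For float accumulators, the Volta accumulator constructor above picks each lane's starting row from bits `quad[2]`, `quad[0]`, and `lane_in_quad[0]`, and its starting column from `quad[1]` (which selects the partial-accumulator half) and `lane_in_quad[1]`. The sketch below, assuming the float case, prints that mapping for every lane; it mirrors the constructor's arithmetic but is not taken from the patch.

```cpp
// Host-side sketch (assumes float accumulators): starting (row, column) each lane
// owns inside the Volta 32x32 interleaved accumulator tile.
#include <cstdio>

int main() {
  constexpr int kElementsPerPartial = 4;
  constexpr int kAccumulatorPartials = 2;

  for (int lane_id = 0; lane_id < 32; ++lane_id) {
    int quad = lane_id >> 2;
    int lane_in_quad = lane_id & 3;

    // (quad[2], quad[0]) select the 8-row block; lane_in_quad[0] the row parity.
    int accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1);
    // quad[1] selects the partial-accumulator half; lane_in_quad[1] the column pair.
    int accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPartials
                + (lane_in_quad & 2);

    std::printf("lane %2d -> (%2d, %2d)\n", lane_id, accum_m, accum_n);
  }
  return 0;
}
```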
-+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index pointer_offset) const { ///< loads a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = -+ (((tile_n * Policy::TileIterations::kRow + tile_m) * -+ Policy::MmaIterations::kColumn + mma_n) * -+ Policy::MmaIterations::kRow + mma_m) * -+ kElementsPerMma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < kAccumulatorPatials; ++p) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < EleShapePerPatial::kRow; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { -+ int accum_m = tile_m * Policy::InterleavedTile::kRow + -+ mma_m * QuadShapePerPatialMma::kRow + m * 2; -+ int accum_n = tile_n * Policy::InterleavedTile::kColumn + -+ mma_n * QuadShapePerPatialMma::kColumn + -+ p * Policy::InterleavedTile::kColumn/2 + n; -+ int idx = mma_accum_start + p * kElementsPerPartial + -+ m * EleShapePerPatial::kColumn + n; -+ frag[idx] = offset_ref.at({accum_m, accum_n}); -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index byte_offset) const { ///< loads a tile with a linear offset -+ -+ load_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_HOST_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles -+ -+ load(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. 
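Each MMA's eight accumulator elements land in the fragment starting at mma_accum_start, which flattens (tile_n, tile_m, mma_n, mma_m) into a linear index. The short check below uses hypothetical 2x2 tile and MMA iteration counts (stand-ins for Policy::TileIterations and Policy::MmaIterations) to confirm that the same formula visits every fragment slot exactly once.

#include <cstdio>
#include <vector>

int main() {
  // Hypothetical iteration counts for illustration only.
  int const kTileRow = 2, kTileCol = 2, kMmaRow = 2, kMmaCol = 2;
  int const kElementsPerMma = 8;

  std::vector<int> hits(kTileRow * kTileCol * kMmaRow * kMmaCol * kElementsPerMma, 0);

  for (int tile_n = 0; tile_n < kTileCol; ++tile_n)
    for (int tile_m = 0; tile_m < kTileRow; ++tile_m)
      for (int mma_n = 0; mma_n < kMmaCol; ++mma_n)
        for (int mma_m = 0; mma_m < kMmaRow; ++mma_m) {
          // Same flattening as mma_accum_start in load_with_pointer_offset.
          int mma_accum_start =
              (((tile_n * kTileRow + tile_m) * kMmaCol + mma_n) * kMmaRow + mma_m) *
              kElementsPerMma;
          for (int e = 0; e < kElementsPerMma; ++e) ++hits[mma_accum_start + e];
        }

  bool one_to_one = true;
  for (int h : hits) one_to_one &= (h == 1);
  std::printf("fragment slots covered exactly once: %s\n", one_to_one ? "yes" : "no");
  return 0;
}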
-+ CUTLASS_HOST_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles -+ Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset -+ -+ load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index pointer_offset) const { ///< store a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = -+ (((tile_n * Policy::TileIterations::kRow + tile_m) * -+ Policy::MmaIterations::kColumn + mma_n) * -+ Policy::MmaIterations::kRow + mma_m) * -+ kElementsPerMma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < kAccumulatorPatials; ++p) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < EleShapePerPatial::kRow; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { -+ int accum_m = tile_m * Policy::InterleavedTile::kRow + -+ mma_m * QuadShapePerPatialMma::kRow + m * 2; -+ int accum_n = tile_n * Policy::InterleavedTile::kColumn + -+ mma_n * QuadShapePerPatialMma::kColumn + -+ p * Policy::InterleavedTile::kColumn/2 + n; -+ int idx = mma_accum_start + p * kElementsPerPartial + -+ m * EleShapePerPatial::kColumn + n; -+ offset_ref.at({accum_m, accum_n}) = frag[idx]; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_HOST_DEVICE -+ void store_with_byte_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index byte_offset) const { ///< store a tile with a linear offset -+ -+ store_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Stores a fragment to memory with logical offset in units of whole tiles. -+ CUTLASS_HOST_DEVICE -+ void store( -+ Fragment &frag, ///< fragment to store to the tensor -+ TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles -+ -+ store(frag, tile_offset, 0); -+ } -+ -+ /// Stores a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_HOST_DEVICE -+ void store( -+ /// fragment to store to the tensor -+ Fragment const &frag, -+ /// stores a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// stores a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+}; -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDS to -+/// load from shared memory and therefore must be initialized with a TensorRef -+/// to shared memory. 
-+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// KBlock size (in units of elements) -+ int KBlock> -+class MmaVoltaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::VoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, KBlock>, -+ InstructionShape_, OpDelta_, 32> { -+ public: -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand == Operand::kB, -+ "MmaVoltaTensorOpMultiplicandIterator may only be instantiated for " -+ "A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// KBlock size -+ static int const kKBlock = KBlock; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::VoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kKBlock>; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ -+ /// Shape of one individual LDS instruction -+ using LdsShape = layout::PitchLinearShape<1, 32>; -+ -+ /// Number and arrangement of LDSM instructions -+ using LdsIterations = layout::PitchLinearShape<1, Shape::kStrided / 32>; -+ -+ /// Using LDS.128 -+ static int const kElementsPerAccess = 8; -+ -+ /// Contiguous elements per line -+ static int const kContiguousElementsPerLine = 4; -+ }; -+ -+ private: -+ /// Not working on this feature at the moment. 
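The Policy above issues one 128-bit shared-memory load per access (kElementsPerAccess = 8) and performs Shape::kStrided / 32 strided iterations. Assuming 16-bit multiplicand elements (the element-width parameter of the crosswise layout is not visible here) and a hypothetical strided extent of 64, the accounting works out as in this sketch.

#include <cstdio>

int main() {
  // Assumptions for illustration only: 16-bit elements, Shape::kStrided == 64.
  int const kElementBits = 16;
  int const kShapeStrided = 64;

  int const kElementsPerAccess = 8;                      // Policy::kElementsPerAccess (LDS.128)
  int const kAccessBytes = kElementsPerAccess * kElementBits / 8;
  int const kLdsIterationsStrided = kShapeStrided / 32;  // Policy::LdsIterations along strided dim

  std::printf("bytes per LDS access   : %d\n", kAccessBytes);           // 16 bytes = one LDS.128
  std::printf("strided LDS iterations : %d\n", kLdsIterationsStrided);  // 64 / 32 = 2
  return 0;
}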
-+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Pointer type used for accesses -+ using AccessType = AlignedArray; -+ -+ public: -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = -+ Array; -+ -+ private: -+ -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+ /// Crosswised elements are arranged in a SMEM line -+ /// in units of AccessType -+ Index line_size; -+ -+ /// Internal counter used to determine load addr offset -+ /// and when to swap higher 64bit with lower 64bit -+ int k_group_idx_; -+ -+ public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator() -+ : pointer_(nullptr), -+ stride_(0), -+ line_size(0), -+ byte_offset_(0), -+ k_group_idx_(0) {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator(TensorRef const &ref, int lane_id) -+ : pointer_(reinterpret_cast(ref.data())), -+ stride_(ref.stride(0) * Policy::kElementsPerAccess), -+ line_size((ref.stride(0) * Policy::kContiguousElementsPerLine) / -+ Policy::kElementsPerAccess), -+ k_group_idx_(0), -+ byte_offset_(0) { -+ -+ int quad = (lane_id / 4); -+ int lane_in_quad = (lane_id % 4); -+ int access_contiguous; -+ -+ if(kOperand == Operand::kA) { -+ -+ // swizzle id: tid[4]|tid[1:0]|(tid[2]^tid[4]) -+ access_contiguous = ((quad & 0x4) << 1) + ((lane_in_quad) << 1) + -+ ((quad & 0x1) ^ ((quad & 0x4) >> 2)); -+ } else { -+ -+ // swizzle id: tid[4]|tid[1:0]|tid[3] -+ access_contiguous = ((quad & 0x4) << 1) + (lane_in_quad << 1) + -+ ((quad & 0x2) >> 1 ^ ((quad & 0x4) >> 2)); -+ } -+ -+ byte_offset_ = access_contiguous * -+ sizeof(Element) * Policy::kElementsPerAccess; -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ byte_offset_ += offset * sizeof(Element); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_tile_offset( -+ TensorCoord const &tile_offset) { -+ -+ int contiguous_offset = tile_offset.contiguous(); -+ int strided_offset = tile_offset.strided(); -+ k_group_idx_ = 0; -+ -+ pointer_ += contiguous_offset * -+ (InstructionShape::kContiguous / -+ Policy::kContiguousElementsPerLine) * -+ line_size + -+ strided_offset * Shape::kStrided / 2; -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator++() { -+ k_group_idx_ = (k_group_idx_ + 1) % 8; -+ -+ if (k_group_idx_ == 4 || k_group_idx_ == 0) { -+ byte_offset_ ^= 1 * sizeof(Element) * Policy::kElementsPerAccess; -+ } -+ -+ pointer_ += line_size; -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator--() { assert(0); } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator+=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in 
units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator-=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { load_with_byte_offset(frag, 0); } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ -+ AccessType * fetch_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::LdsIterations::kStrided; ++s) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::LdsIterations::kContiguous; ++c) { -+ -+ int access_idx = c + s * Policy::LdsIterations::kContiguous; -+ -+ AccessType const *source_ptr = pointer_ + -+ Policy::LdsShape::kContiguous * c * line_size + -+ Policy::LdsShape::kStrided * s / 2; -+ -+ char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; -+ fetch_ptr[access_idx] = *(reinterpret_cast (source_byte_ptr)); -+ -+ // swap higher 64bit and lower 64bit -+ if (k_group_idx_ & 0x2) { -+ uint64_t *low = reinterpret_cast(&frag) + access_idx * 2; -+ uint64_t *high = reinterpret_cast(&frag) + access_idx * 2 + 1; -+ uint64_t tmp = *low; -+ *low = *high; -+ *high = tmp; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = tile_offset.contiguous() * -+ InstructionShape::kContiguous / -+ Policy::kElementsPerAccess + -+ tile_offset.strided() * Shape::kStrided * stride_; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. 
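Two details of the crosswise iterator above are easy to miss: the constructor's per-lane swizzle of the starting byte offset, and the k-group bookkeeping, where operator++ XORs the access offset every four k-groups and load_with_byte_offset swaps the 64-bit halves whenever k_group_idx_ & 0x2 is set. The host-side sketch below replays both pieces of arithmetic, assuming 16-bit elements for the byte math (an assumption, since the element type is not shown).

#include <cstdio>

int main() {
  int const kElementsPerAccess = 8;
  int const kElementBytes = 2;  // assumed 16-bit elements

  // Per-lane swizzled starting offset, as computed in the constructor above.
  for (int lane_id = 0; lane_id < 32; ++lane_id) {
    int quad = lane_id / 4;
    int lane_in_quad = lane_id % 4;

    // Operand A swizzle: tid[4] | tid[1:0] | (tid[2] ^ tid[4])
    int access_a = ((quad & 0x4) << 1) + (lane_in_quad << 1) +
                   ((quad & 0x1) ^ ((quad & 0x4) >> 2));
    // Operand B swizzle, as written in the constructor above.
    int access_b = ((quad & 0x4) << 1) + (lane_in_quad << 1) +
                   (((quad & 0x2) >> 1) ^ ((quad & 0x4) >> 2));

    std::printf("lane %2d: A byte_offset %3d, B byte_offset %3d\n", lane_id,
                access_a * kElementBytes * kElementsPerAccess,
                access_b * kElementBytes * kElementsPerAccess);
  }

  // k-group progression driven by operator++ and consulted in load_with_byte_offset.
  int k_group_idx = 0;
  for (int step = 0; step < 8; ++step) {
    k_group_idx = (k_group_idx + 1) % 8;
    bool toggles_offset = (k_group_idx == 4 || k_group_idx == 0);
    bool swaps_halves = (k_group_idx & 0x2) != 0;
    std::printf("after increment %d: k_group %d, XOR offset: %d, swap 64b halves on load: %d\n",
                step, k_group_idx, toggles_offset, swaps_halves);
  }
  return 0;
}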
Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ k_group_idx_ = k_group; -+ } -+}; -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDS to -+/// load from shared memory and therefore must be initialized with a TensorRef -+/// to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// KBlock size (in units of elements) -+ int KBlock> -+class MmaVoltaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, KBlock>, -+ InstructionShape_, OpDelta_, 32> { -+ public: -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand == Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for " -+ "A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// KBlock size -+ static int const kKBlock = KBlock; -+ -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kKBlock>; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaVoltaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::VoltaTensorOpMultiplicandCrosswise::value, -+ kKBlock>, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads>; -+ -+ public: -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+ private: -+ /// Underlying tile iterator -+ Base iterator_; -+ -+ public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator() {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator(TensorRef const &ref, int lane_id) -+ : iterator_({ref.data(), ref.stride()}, lane_id) {} -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator 
&add_pointer_offset(LongIndex offset) { -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_tile_offset( -+ TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator++() { -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator--() { -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator+=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator-=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { iterator_.load(frag); } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ assert(0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ assert(0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, {tile_offset.contiguous(), tile_offset.strided()}, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. 
-+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDS to -+/// load from shared memory and therefore must be initialized with a TensorRef -+/// to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// KBlock size (in units of elements) -+ int KBlock> -+class MmaVoltaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, KBlock>, -+ InstructionShape_, OpDelta_, 32> { -+ public: -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand == Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for " -+ "A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// KBlock size -+ static int const kKBlock = KBlock; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kKBlock>; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaVoltaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::VoltaTensorOpMultiplicandCrosswise::value, -+ kKBlock>, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads>; -+ -+ public: -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+ private: -+ /// Underlying tile iterator -+ Base iterator_; -+ -+ public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator() {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator(TensorRef const &ref, int lane_id) -+ : iterator_({ref.data(), ref.stride()}, lane_id) {} -+ -+ /// Adds a 
pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_tile_offset( -+ TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator++() { -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator--() { -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator+=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator-=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { iterator_.load(frag); } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ assert(0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ assert(0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. 
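The ColumnMajor and RowMajor crosswise wrappers both delegate to the same pitch-linear Base iterator; the only difference is how a MatrixCoord tile offset is mapped onto (contiguous, strided) coordinates, as in the add_tile_offset calls above. A minimal sketch of that adapter pattern, using hypothetical stand-in coordinate types rather than the CUTLASS ones:

#include <cstdio>

// Hypothetical stand-ins for MatrixCoord / PitchLinearCoord, for illustration only.
struct MatrixCoordSketch { int row, column; };
struct PitchLinearCoordSketch { int contiguous, strided; };

// ColumnMajor wrapper: rows are contiguous in the underlying layout, so (row, column)
// maps straight through, mirroring add_tile_offset({tile_offset.row(), tile_offset.column()}).
PitchLinearCoordSketch column_major_to_pitch_linear(MatrixCoordSketch c) {
  return {c.row, c.column};
}

// RowMajor wrapper: the coordinates are transposed first, mirroring
// add_tile_offset({tile_offset.column(), tile_offset.row()}).
PitchLinearCoordSketch row_major_to_pitch_linear(MatrixCoordSketch c) {
  return {c.column, c.row};
}

int main() {
  MatrixCoordSketch tile{1, 3};
  PitchLinearCoordSketch a = column_major_to_pitch_linear(tile);
  PitchLinearCoordSketch b = row_major_to_pitch_linear(tile);
  std::printf("column-major -> (contiguous %d, strided %d)\n", a.contiguous, a.strided);
  std::printf("row-major    -> (contiguous %d, strided %d)\n", b.contiguous, b.strided);
  return 0;
}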
-+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, {tile_offset.strided(), tile_offset.contiguous()}, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for 'TN' arrangement -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Operand identity -+ Operand Operand_, -+ /// Data type of A elements -+ typename Element_, -+ /// Layout of matrix operand -+ typename Layout_, -+ /// Shape of one matrix production operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ int OpDelta_, -+ /// Number of threads participating in one matrix operation -+ int Threads = 32, -+ /// Number of partitions along K dimension -+ int PartitionsK_ = 1> -+class MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ /// Basic check -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaVoltaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = Layout_; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Number of elements accessed per Shared Memory load -+ static int const kElementsPerAccess = 4; -+ -+private: -+ -+ static int const kInterleavedTileRows = 32; -+ static int const kInterleavedTileColumns = 32; -+ static int const kInstructionsPerTile = 2; -+ -+ /// Rounded up instruction counts -+ using TileCount = MatrixShape< -+ Shape::kRow / kInterleavedTileRows, -+ Shape::kColumn / kInterleavedTileColumns -+ >; -+ -+ using FragmentCount = MatrixShape< -+ TileCount::kRow * kInstructionsPerTile, -+ TileCount::kColumn * kInstructionsPerTile -+ >; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array< -+ 
Element, -+ (kOperand == Operand::kA ? FragmentCount::kRow : FragmentCount::kColumn) * kElementsPerAccess -+ >; -+ -+ /// Memory access type -+ using AccessType = AlignedArray; -+ -+private: -+ -+ /// Underlying tensor reference -+ TensorRef ref_; -+ -+ /// Extent of tensor -+ MatrixCoord extent_; -+ -+ /// Origin -+ MatrixCoord origin_; -+ -+ /// Used to conditionally enable extents checking -+ bool divisible_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner(): divisible_(true) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ ref_(ref), extent_(Shape::kRow, Shape::kColumn), divisible_(true) { -+ -+ int quad_id = lane_id / 4; -+ int lane_in_quad = (lane_id % 4); -+ -+ if (kOperand == Operand::kA) { -+ -+ int row_idx = ((quad_id & 1) + ((quad_id & 4) / 2)) * 4 * kInstructionsPerTile + lane_in_quad; -+ int col_idx = 0; -+ -+ origin_ = MatrixCoord(row_idx, col_idx); -+ } -+ else { -+ -+ int row_idx = 0; -+ int col_idx = (quad_id / 2) * 4 * kInstructionsPerTile + lane_in_quad; -+ -+ origin_ = MatrixCoord(row_idx, col_idx); -+ } -+ -+ ref_.add_coord_offset(origin_); -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner( -+ TensorRef const &ref, -+ TensorCoord extent, -+ int lane_id -+ ): ref_(ref), extent_(extent), divisible_(false) { -+ -+ int quad_id = lane_id / 4; -+ int lane_in_quad = (lane_id % 4); -+ -+ if (kOperand == Operand::kA) { -+ -+ int row_idx = ((quad_id & 1) + ((quad_id & 4) / 2)) * 4 * kInstructionsPerTile + lane_in_quad; -+ int col_idx = 0; -+ -+ origin_ = MatrixCoord(row_idx, col_idx); -+ } -+ else { -+ -+ int row_idx = 0; -+ int col_idx = (quad_id / 2) * 4 * kInstructionsPerTile + lane_in_quad; -+ -+ origin_ = MatrixCoord(row_idx, col_idx); -+ } -+ -+ #if defined(__CUDA_ARCH__) -+ __syncthreads(); -+ #endif -+ -+ ref_.add_coord_offset(origin_); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner &add_pointer_offset(LongIndex offset) { -+ -+ ref_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ origin_ += coord_offset; -+ -+ ref_.add_coord_offset(coord_offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner & operator++() { -+ -+ if (kOperand == Operand::kA) { -+ add_tile_offset({0, 1}); -+ } -+ else { -+ add_tile_offset({1, 0}); -+ } -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner & operator--() { -+ -+ if (kOperand == Operand::kA) { -+ add_tile_offset({0, -1}); -+ } -+ else { -+ add_tile_offset({-1, 0}); -+ } -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner & operator+=(TensorCoord 
const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ AccessType const *access_ptr = reinterpret_cast(ref_.data()); -+ int ldm = ref_.stride()[0]; -+ -+ if (kOperand == Operand::kA) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx = 0; idx < FragmentCount::kRow; ++idx) { -+ -+ int tile_idx = idx / 2; -+ int quad_idx = idx % 2; -+ -+ int row_offset = tile_idx * kInterleavedTileRows + quad_idx * 4; -+ frag_ptr[idx] = access_ptr[row_offset * ldm / kElementsPerAccess]; -+ } -+ } -+ else { -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx = 0; idx < FragmentCount::kColumn; ++idx) { -+ -+ int tile_idx = idx / 2; -+ int quad_idx = idx % 2; -+ -+ int col_offset = tile_idx * kInterleavedTileColumns + quad_idx * 4; -+ frag_ptr[idx] = access_ptr[col_offset * ldm / kElementsPerAccess]; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ -+ load_with_pointer_offset(frag, byte_offset * 8 / sizeof_bits::value); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ -+ load_with_pointer_offset(frag, ref_.offset(coord_offset)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ -+ load_with_pointer_offset(frag, ref_.offset(coord_offset) + pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. 
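In the canonical 'TN' iterator above, operand A fragments are gathered one 4-element access per row, with the two accesses of an interleaved tile four rows apart and successive tiles 32 rows apart, each scaled by the leading dimension. The sketch below reproduces that offset arithmetic for a hypothetical 64-row tile and a leading dimension of 64 elements (both values are assumptions chosen for illustration).

#include <cstdio>

int main() {
  // Hypothetical values standing in for Shape::kRow and the tensor's leading dimension.
  int const kShapeRow = 64;
  int const kLdm = 64;

  int const kInterleavedTileRows = 32;
  int const kInstructionsPerTile = 2;
  int const kElementsPerAccess = 4;

  // FragmentCount::kRow = (Shape::kRow / 32) * 2, as in the iterator.
  int const kFragmentRows = (kShapeRow / kInterleavedTileRows) * kInstructionsPerTile;

  for (int idx = 0; idx < kFragmentRows; ++idx) {
    int tile_idx = idx / 2;
    int quad_idx = idx % 2;

    // Operand A path of load_with_pointer_offset: one AccessType (4 elements) per row.
    int row_offset = tile_idx * kInterleavedTileRows + quad_idx * 4;
    int access_index = row_offset * kLdm / kElementsPerAccess;

    std::printf("fragment access %d -> row %2d, AccessType index %d\n",
                idx, row_offset, access_index);
  }
  return 0;
}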
-+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ -+ load_with_pointer_offset(frag, ref_.offset(coord_offset) + byte_offset * 8 / sizeof_bits::value); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation -+ } -+}; -+ -+ -+/// Tile iterator specialized for 'NT' arrangement -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Operand identity -+ Operand Operand_, -+ /// Data type of A elements -+ typename Element_, -+ /// Layout of matrix operand -+ typename Layout_, -+ /// Shape of one matrix production operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ int OpDelta_, -+ /// Number of threads participating in one matrix operation -+ int Threads = 32, -+ /// Number of partitions along K dimension -+ int PartitionsK_ = 1> -+class MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ /// Basic check -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaVoltaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = Layout_; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Number of elements accessed per Shared Memory load -+ static int const kElementsPerAccess = 4; -+ -+private: -+ -+ static int const kInterleavedTileRows = 32; -+ static int const kInterleavedTileColumns = 32; -+ static int const kInstructionsPerTile = 2; -+ -+ /// Rounded up instruction counts -+ using TileCount = MatrixShape< -+ Shape::kRow / kInterleavedTileRows, -+ Shape::kColumn / kInterleavedTileColumns -+ >; -+ -+ using FragmentCount = MatrixShape< -+ TileCount::kRow * kInstructionsPerTile, -+ TileCount::kColumn * kInstructionsPerTile -+ >; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array< -+ Element, -+ (kOperand 
== Operand::kA ? FragmentCount::kRow : FragmentCount::kColumn) * kElementsPerAccess -+ >; -+ -+ /// Memory access type -+ using AccessType = AlignedArray; -+ -+private: -+ -+ /// Underlying tensor reference -+ TensorRef ref_; -+ -+ /// Extent of tensor -+ MatrixCoord extent_; -+ -+ /// Origin -+ MatrixCoord origin_; -+ -+ /// Used to conditionally enable extents checking -+ bool divisible_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter(): divisible_(true) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ ref_(ref), extent_(Shape::kRow, Shape::kColumn), divisible_(true) { -+ -+ int quad_id = lane_id / 4; -+ int lane_in_quad = (lane_id % 4); -+ -+ if (kOperand == Operand::kA) { -+ -+ int row_idx = ((quad_id & 1) + ((quad_id & 4) / 2)) * 4 * kInstructionsPerTile; -+ int col_idx = lane_in_quad; -+ -+ origin_ = MatrixCoord(row_idx, col_idx); -+ } -+ else { -+ -+ int row_idx = lane_in_quad; -+ int col_idx = (quad_id / 2) * 4 * kInstructionsPerTile; -+ -+ origin_ = MatrixCoord(row_idx, col_idx); -+ } -+ -+ ref_.add_coord_offset(origin_); -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter( -+ TensorRef const &ref, -+ TensorCoord extent, -+ int lane_id -+ ): ref_(ref), extent_(extent), divisible_(false) { -+ -+ int quad_id = lane_id / 4; -+ int lane_in_quad = (lane_id % 4); -+ -+ if (kOperand == Operand::kA) { -+ -+ int row_idx = ((quad_id & 1) + ((quad_id & 4) / 2)) * 4 * kInstructionsPerTile; -+ int col_idx = lane_in_quad; -+ -+ origin_ = MatrixCoord(row_idx, col_idx); -+ } -+ else { -+ -+ int row_idx = lane_in_quad; -+ int col_idx = (quad_id / 2) * 4 * kInstructionsPerTile; -+ -+ origin_ = MatrixCoord(row_idx, col_idx); -+ } -+ -+ #if defined(__CUDA_ARCH__) -+ __syncthreads(); -+ #endif -+ -+ ref_.add_coord_offset(origin_); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter &add_pointer_offset(LongIndex offset) { -+ -+ ref_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ origin_ += coord_offset; -+ -+ ref_.add_coord_offset(coord_offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter & operator++() { -+ -+ if (kOperand == Operand::kA) { -+ add_tile_offset({0, 1}); -+ } -+ else { -+ add_tile_offset({1, 0}); -+ } -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter & operator--() { -+ -+ if (kOperand == Operand::kA) { -+ add_tile_offset({0, -1}); -+ } -+ else { -+ add_tile_offset({-1, 0}); -+ } -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter & operator+=(TensorCoord const &tile_offset) { -+ 
add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ AccessType const *access_ptr = reinterpret_cast(ref_.data()); -+ int ldm = ref_.stride()[0]; -+ -+ if (kOperand == Operand::kA) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx = 0; idx < FragmentCount::kRow; ++idx) { -+ -+ int tile_idx = idx / 2; -+ int quad_idx = idx % 2; -+ -+ int row_offset = tile_idx * kInterleavedTileRows; -+ frag_ptr[idx] = access_ptr[row_offset / kElementsPerAccess + quad_idx]; -+ } -+ } -+ else { -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx = 0; idx < FragmentCount::kColumn; ++idx) { -+ -+ int tile_idx = idx / 2; -+ int quad_idx = idx % 2; -+ -+ int col_offset = tile_idx * kInterleavedTileColumns; -+ frag_ptr[idx] = access_ptr[col_offset / kElementsPerAccess + quad_idx]; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ -+ load_with_pointer_offset(frag, byte_offset * 8 / sizeof_bits::value); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ -+ load_with_pointer_offset(frag, ref_.offset(coord_offset)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ -+ load_with_pointer_offset(frag, ref_.offset(coord_offset) + pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. 
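The 'NT' (canonical outer) iterator above differs from the inner variant only in its addressing: the two accesses belonging to an interleaved tile occupy adjacent AccessType slots along the contiguous dimension (col_offset / kElementsPerAccess + quad_idx) instead of rows scaled by the leading dimension. A compact illustration of the operand-B path, assuming a hypothetical 64-column tile:

#include <cstdio>

int main() {
  // Hypothetical stand-in for Shape::kColumn.
  int const kShapeColumn = 64;

  int const kInterleavedTileColumns = 32;
  int const kInstructionsPerTile = 2;
  int const kElementsPerAccess = 4;

  int const kFragmentCols = (kShapeColumn / kInterleavedTileColumns) * kInstructionsPerTile;

  for (int idx = 0; idx < kFragmentCols; ++idx) {
    int tile_idx = idx / 2;
    int quad_idx = idx % 2;

    // Operand B path of load_with_pointer_offset in the outer iterator.
    int col_offset = tile_idx * kInterleavedTileColumns;
    int access_index = col_offset / kElementsPerAccess + quad_idx;

    std::printf("fragment access %d -> column base %2d, AccessType index %d\n",
                idx, col_offset, access_index);
  }
  return 0;
}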
-+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ -+ load_with_pointer_offset(frag, ref_.offset(coord_offset) + byte_offset * 8 / sizeof_bits::value); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_> -+class MmaVoltaTensorOpMultiplicandTileIterator< -+ Shape_, -+ Operand::kA, -+ Element_, -+ cutlass::layout::RowMajor, -+ InstructionShape_, -+ OpDelta_, -+ 32 -+> : public MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner< -+ Shape_, Operand::kA, Element_, cutlass::layout::RowMajor, InstructionShape_, OpDelta_> { -+ -+public: -+ using Base = MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner< -+ Shape_, Operand::kA, Element_, cutlass::layout::RowMajor, InstructionShape_, OpDelta_> ; -+ -+ using TensorRef = typename Base::TensorRef; -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): Base(ref, lane_id) { } -+ -+}; -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_> -+class MmaVoltaTensorOpMultiplicandTileIterator< -+ Shape_, -+ Operand::kA, -+ Element_, -+ cutlass::layout::ColumnMajor, -+ InstructionShape_, -+ OpDelta_, -+ 32 -+> : public MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter< -+ Shape_, Operand::kA, Element_, cutlass::layout::ColumnMajor, InstructionShape_, OpDelta_> { -+ -+public: -+ using Base = MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter< -+ Shape_, Operand::kA, Element_, cutlass::layout::ColumnMajor, InstructionShape_, OpDelta_> ; -+ -+ using TensorRef = typename Base::TensorRef; -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): Base(ref, lane_id) { } -+ -+}; -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// 
instructions) -+ int OpDelta_> -+class MmaVoltaTensorOpMultiplicandTileIterator< -+ Shape_, Operand::kB, Element_, -+ cutlass::layout::ColumnMajor, -+ InstructionShape_, OpDelta_, 32 -+> : public MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner< -+ Shape_, Operand::kB, Element_, cutlass::layout::ColumnMajor, InstructionShape_, OpDelta_> { -+ -+public: -+ using Base = MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner< -+ Shape_, Operand::kB, Element_, cutlass::layout::ColumnMajor, InstructionShape_, OpDelta_>; -+ -+ using TensorRef = typename Base::TensorRef; -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): Base(ref, lane_id) { } -+}; -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_> -+class MmaVoltaTensorOpMultiplicandTileIterator< -+ Shape_, Operand::kB, Element_, -+ cutlass::layout::RowMajor, -+ InstructionShape_, OpDelta_, 32 -+> : public MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter< -+ Shape_, Operand::kB, Element_, cutlass::layout::RowMajor, InstructionShape_, OpDelta_> { -+ -+public: -+ using Base = MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter< -+ Shape_, Operand::kB, Element_, cutlass::layout::RowMajor, InstructionShape_, OpDelta_>; -+ -+ using TensorRef = typename Base::TensorRef; -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): Base(ref, lane_id) { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h -new file mode 100644 -index 0000000..29cc3d9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h -@@ -0,0 +1,2452 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+ -+#include "cutlass/platform/platform.h" -+#include "cutlass/fast_math.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for loading 128b vectors of 64b elements. 
-+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::TensorOpMultiplicandCongruous64b, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ static_assert(!(Shape::kContiguous % 16) && !(Shape::kStrided % 4), "Divisibility."); -+ -+ static_assert(sizeof_bits::value == 64, "This is specialized for 64b accesses."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::TensorOpMultiplicandCongruous64b; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Load two elements per access -+ static int const kElementsPerAccess = 2; -+ -+ /// Policy defining internal details of tile iterator -+ struct Policy { -+ -+ /// Shape of one access -+ using Delta = layout::PitchLinearShape<8, 4>; -+ -+ /// Number of iterations to load -+ using Iterations = layout::PitchLinearShape< -+ Shape::kContiguous / kElementsPerAccess / Delta::kContiguous, -+ InstructionShape::kStrided / Delta::kStrided -+ >; -+ -+ }; -+ -+private: -+ -+ /// Not working on this feature at the moment. 
-+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Pointer type used for accesses -+ using AccessType = AlignedArray; -+ -+ /// Internal counter used to jump to next K partition -+ int k_group_idx_; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = -+ Array; -+ -+private: -+ -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0), -+ k_group_idx_(0) { -+ -+ int access_strided = lane_id / Policy::Delta::kContiguous; -+ int access_contiguous = (lane_id % Policy::Delta::kContiguous) ^ access_strided; -+ -+ pointer_= reinterpret_cast(ref.data()) + -+ access_contiguous + access_strided * stride_; -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ byte_offset_ += offset * sizeof(Element); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ int offset = -+ (tile_offset.strided() * InstructionShape::kStrided) * stride_ * kElementsPerAccess + -+ tile_offset.contiguous() * Shape::kContiguous; -+ -+ add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ add_tile_offset({0, 1}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the opposite of the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ add_tile_offset({0, -1}); -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. 
-+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ -+ AccessType *fetch_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::Iterations::kStrided; ++s) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { -+ -+ int access_idx = c + s * Policy::Iterations::kContiguous; -+ -+ AccessType const *source_ptr = pointer_ + -+ Policy::Delta::kContiguous * c + -+ Policy::Delta::kStrided * s * stride_; -+ -+ char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; -+ -+ AccessType const *source = reinterpret_cast(source_byte_ptr); -+ -+ fetch_ptr[access_idx] = *source; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ -+ Index pointer_offset = -+ tile_offset.contiguous() * Shape::kContiguous / Layout::kElementsPerAccess + -+ tile_offset.strided() * InstructionShape::kStrided * stride_; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. 
-+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicandCongruous64b, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ 
MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.strided(), tile_offset.contiguous()}, -+ byte_offset); -+ } -+ -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. 
It uses LDSM to load from shared -+/// memory and therefore must be initialized with a TensorRef to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicandCongruous64b, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ 
++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.contiguous(), tile_offset.strided()}, -+ byte_offset); -+ } -+ -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for loading 128b vectors of 64b elements. 
-+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::TensorOpMultiplicand64bCrosswise, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ static_assert(!(Shape::kContiguous % 4) && !(Shape::kStrided % 16), "Divisibility."); -+ -+ static_assert(sizeof_bits::value == 64, "This is specialized for 64b accesses."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::TensorOpMultiplicand64bCrosswise; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Load two elements per access -+ static int const kElementsPerAccess = 2; -+ -+ /// Policy defining internal details of tile iterator -+ struct Policy { -+ -+ /// Shape of one access -+ using Delta = layout::PitchLinearShape<4, 16>; -+ -+ /// Number of iterations to load -+ using Iterations = layout::PitchLinearShape< -+ InstructionShape::kContiguous / Delta::kContiguous, -+ Shape::kStrided / Delta::kStrided -+ >; -+ -+ }; -+ -+private: -+ -+ /// Not working on this feature at the moment. 
-+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Pointer type used for accesses -+ using AccessType = AlignedArray; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = -+ Array; -+ -+private: -+ -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+ /// Internal counter for tracking K-group -+ Index k_group_idx_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0), -+ k_group_idx_(0) { -+ -+ int access_strided = lane_id / 8; -+ int access_contiguous = (lane_id % 8); -+ -+ byte_offset_ = (access_contiguous + access_strided * stride_) * sizeof(AccessType); -+ -+ pointer_= reinterpret_cast(ref.data()); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ pointer_ += offset / kElementsPerAccess; -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ int offset = (tile_offset.contiguous() * InstructionShape::kContiguous) * -+ stride_ * kElementsPerAccess + -+ tile_offset.strided() * Shape::kStrided; -+ -+ add_pointer_offset(offset); -+ -+ int old_k_group_idx = k_group_idx_; -+ -+ k_group_idx_ += tile_offset.contiguous(); -+ -+ if ((k_group_idx_ & 2) ^ (old_k_group_idx & 2)) { -+ byte_offset_ ^= 0x40; -+ } -+ -+ return *this; -+ } -+ -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative(TensorCoord const &tile_offset) { -+ -+ add_tile_offset(tile_offset); // TODO fix this if it becomes an issue during warp it reset -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ pointer_ += stride_ * InstructionShape::kContiguous; -+ -+ if (k_group_idx_ & 0x1) { -+ // xor ptr -+ byte_offset_ ^= 0x40; -+ } -+ -+ ++k_group_idx_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. 
-+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ -+ AccessType *fetch_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::Iterations::kStrided; ++s) { -+ -+ int access_idx = c + s * Policy::Iterations::kContiguous; -+ -+ AccessType const *source_ptr = pointer_ + -+ Policy::Delta::kContiguous * c * stride_ + -+ Policy::Delta::kStrided * s / kElementsPerAccess; -+ -+ char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; -+ -+ AccessType const *source = reinterpret_cast(source_byte_ptr); -+ -+ fetch_ptr[access_idx] = *source; -+ } -+ } -+ -+ Element *exchange_ptr = reinterpret_cast(&frag); -+ -+ if (k_group_idx_ & 1) { -+ // exchange on 64b granularity -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Fragment::kElements; i += 2) { -+ Element tmp = exchange_ptr[i]; -+ exchange_ptr[i] = exchange_ptr[i + 1]; -+ exchange_ptr[i + 1] = tmp; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = tile_offset.contiguous() * -+ InstructionShape::kContiguous / -+ Layout::kElementsPerAccess + -+ tile_offset.strided() * Shape::kStrided * stride_; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. 
-+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ k_group_idx_ = k_group; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicand64bCrosswise, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole 
tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset_negative({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.strided(), tile_offset.contiguous()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. 
-+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicand64bCrosswise, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix 
in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset_negative({tile_offset.row(), tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.contiguous(), tile_offset.strided()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. 
-+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Tile iterator specialized for canonical matrix layouts -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Operand identity -+ Operand Operand_, -+ /// Data type of A elements -+ typename Element_, -+ /// Layout of operand -+ typename Layout_, -+ /// Shape of one matrix production operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ int OpDelta_, -+ /// Number of threads participating in one matrix operation -+ int Threads = 32, -+ /// Number of partitions along K dimension -+ int PartitionsK_ = 1> -+class MmaTensorOpMultiplicandTileIteratorCanonical { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ /// Basic check -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = Layout_; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Number of elements accessed per Shared Memory load -+ static int const kElementsPerAccess = -+ (sizeof_bits::value >= 32 ? 1 : 32 / sizeof_bits::value); -+ -+private: -+ -+ static int const kWarpShapeOuter = -+ (kOperand == Operand::kA ? Shape::kRow : Shape::kColumn); -+ -+ static int const kWarpShapeInner = -+ (kOperand == Operand::kA ? 
Shape::kColumn : Shape::kRow); -+ -+ -+ /// Rounded up instruction counts -+ using InstructionCount = MatrixShape< -+ Shape::kRow / InstructionShape::kRow, -+ Shape::kColumn / InstructionShape::kColumn -+ >; -+ -+ /// Rounded up tile dimensions -+ using WarpShapeDivisible = MatrixShape< -+ InstructionCount::kRow * InstructionShape::kRow, -+ InstructionCount::kColumn * InstructionShape::kColumn -+ >; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array< -+ Element, -+ WarpShapeDivisible::kRow * WarpShapeDivisible::kColumn / kThreads -+ >; -+ -+ /// Memory access type -+ using AccessType = AlignedArray; -+ -+private: -+ -+ /// Underlying tensor reference -+ TensorRef ref_; -+ -+ /// Extent of tensor -+ MatrixCoord extent_; -+ -+ /// Origin -+ MatrixCoord origin_; -+ -+ /// Used to conditionally enable extents checking -+ bool divisible_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIteratorCanonical(): divisible_(true) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIteratorCanonical( -+ TensorRef const &ref, -+ int lane_id -+ ): ref_(ref), extent_(Shape::kRow, Shape::kColumn), divisible_(true) { -+ -+ if (kOperand == Operand::kA) { -+ origin_ = MatrixCoord(lane_id / 4, (lane_id % 4) * kElementsPerAccess); -+ } -+ else { -+ origin_ = MatrixCoord((lane_id % 4) * kElementsPerAccess, lane_id / 4); -+ } -+ -+ ref_.add_coord_offset(origin_); -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIteratorCanonical( -+ TensorRef const &ref, -+ TensorCoord extent, -+ int lane_id -+ ): ref_(ref), extent_(extent), divisible_(false) { -+ -+ if (kOperand == Operand::kA) { -+ origin_ = MatrixCoord(lane_id / 4, (lane_id % 4) * kElementsPerAccess); -+ } -+ else { -+ origin_ = MatrixCoord((lane_id % 4) * kElementsPerAccess, lane_id / 4); -+ } -+ -+ ref_.add_coord_offset(origin_); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIteratorCanonical &add_pointer_offset(LongIndex offset) { -+ -+ ref_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIteratorCanonical &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ origin_ += coord_offset; -+ -+ ref_.add_coord_offset(coord_offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIteratorCanonical & operator++() { -+ -+ if (kOperand == Operand::kA) { -+ add_tile_offset({0, 1}); -+ } -+ else { -+ add_tile_offset({1, 0}); -+ } -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIteratorCanonical & operator--() { -+ -+ if (kOperand == Operand::kA) { -+ add_tile_offset({0, -1}); -+ } -+ else { -+ add_tile_offset({-1, 0}); -+ } -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIteratorCanonical & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< 
advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIteratorCanonical & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ -+ int const kWarpShapeDivisibleInner = -+ (kOperand == Operand::kA ? WarpShapeDivisible::kColumn : WarpShapeDivisible::kRow); -+ -+ // Take advantage of Tensor Op's 8 x 4T access pattern -+ int const kAccessesInner = (kWarpShapeDivisibleInner / kElementsPerAccess) / 4; -+ -+ AccessType *access_ptr = reinterpret_cast(&frag); -+ -+ if (kOperand == Operand::kA) { -+ int const kTilesPerInstruction = InstructionShape::kRow / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow; ++inst_m_idx) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction; ++access_m_idx) { -+ int access_idx = -+ access_m_idx + kTilesPerInstruction * (inner_idx + kAccessesInner * inst_m_idx); -+ -+ MatrixCoord offset( -+ access_m_idx * 8 + inst_m_idx * InstructionShape::kRow, -+ inner_idx * 4 * kElementsPerAccess); -+ -+ MatrixCoord access_coord = origin_ + offset; -+ -+ if (divisible_ || -+ (access_coord.row() < extent_.row() && access_coord.column() < extent_.column())) { -+ -+ access_ptr[access_idx] = *reinterpret_cast( -+ ref_.data() + ref_.offset(offset)); -+ } -+ else { -+ AccessType zero; -+ zero.clear(); -+ access_ptr[access_idx] = zero; -+ } -+ } -+ } -+ } -+ } -+ else { -+ CUTLASS_PRAGMA_UNROLL -+ for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; ++inst_n_idx) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { -+ int access_idx = inner_idx + kAccessesInner * inst_n_idx; -+ -+ MatrixCoord offset( -+ inner_idx * 4 * kElementsPerAccess, -+ inst_n_idx * 8); -+ -+ MatrixCoord access_coord = origin_ + offset; -+ -+ if (divisible_ || -+ (access_coord.row() < extent_.row() && access_coord.column() < extent_.column())) { -+ -+ access_ptr[access_idx] = *reinterpret_cast( -+ ref_.data() + ref_.offset(offset)); -+ } -+ else { -+ AccessType zero; -+ zero.clear(); -+ access_ptr[access_idx] = zero; -+ } -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ -+ load_with_pointer_offset(frag, byte_offset * 8 / sizeof_bits::value); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. 
-+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ -+ load_with_pointer_offset(frag, ref_.offset(coord_offset)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ -+ load_with_pointer_offset(frag, ref_.offset(coord_offset) + pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ -+ load_with_pointer_offset(frag, ref_.offset(coord_offset) + byte_offset * 8 / sizeof_bits::value); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. 
-+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation -+ } -+}; -+ -+/// Wrapper for ColumnMajor -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::ColumnMajor, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIteratorCanonical< -+ Shape, kOperand, Element, -+ layout::ColumnMajor, -+ InstructionShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ TensorCoord const & extent, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, extent, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ 
MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.contiguous(), tile_offset.strided()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. 
-+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+ -+/// Wrapper for RowMajor -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::RowMajor, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIteratorCanonical< -+ Shape, kOperand, Element, -+ layout::RowMajor, -+ InstructionShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ TensorCoord const &extent, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, extent, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ 
CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.contiguous(), tile_offset.strided()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. 
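Both the ColumnMajor and RowMajor wrappers above are thin adapters: each holds a canonical iterator and forwards every operation, only repackaging coordinates where the layouts differ. A hedged sketch of that delegation pattern, with illustrative names rather than the actual CUTLASS types:

// Adapter that owns a canonical iterator and forwards operations to it.
template <typename Canonical>
class RowMajorAdapter {
 public:
  explicit RowMajorAdapter(Canonical iter) : iterator_(iter) {}

  RowMajorAdapter& add_tile_offset(int row, int col) {
    iterator_.add_tile_offset(row, col);   // forward unchanged
    return *this;
  }

  RowMajorAdapter& operator++() {
    ++iterator_;                           // advance the underlying iterator
    return *this;
  }

  template <typename Fragment>
  void load(Fragment& frag) const {
    iterator_.load(frag);                  // delegate the actual load
  }

 private:
  Canonical iterator_;                     // underlying canonical iterator
};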
-+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h -new file mode 100644 -index 0000000..f7370a6 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h -@@ -0,0 +1,380 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines iterators to load sparse meta data used by warp-level matrix multiply operations -+ targeting Sparse Tensor Cores. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+ -+#include "cutlass/platform/platform.h" -+#include "cutlass/fast_math.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Layout of operand -+ typename Layout_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ int OpDelta_, -+ /// Number of threads participating in one matrix operation -+ int Threads, -+ /// Number of partitions along K dimension -+ int PartitionsK_ = 1> -+class SparseMmaTensorOpMetaTileIterator { -+ public: -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = Layout_; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ static int const kSparse = 2; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kColumn % InstructionShape::kColumn), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ static int const kElementsPerAccess = 128 / sizeof_bits::value; -+ -+ // Determine number of elements along outer dimension per individual LDSM op -+ static int const kLdsmOpOuter = InstructionShape::kColumn; -+ static int const kLdsmOpInner = 8 * kElementsPerAccess / kLdsmOpOuter; -+ -+ static_assert(!(Shape::kColumn % kLdsmOpOuter), -+ "Shape of warp-level mma must be divisible by LDSM's " -+ "fundamental tile size."); -+ -+ static_assert(!(Shape::kRow % kLdsmOpInner), -+ "Shape of warp-level mma must be divisible by LDSM's " -+ "fundamental tile size."); -+ -+ /// Shape of one individual LDSM instruction -+ static int const LdsmShapeColumn = -+ InstructionShape::kColumn / kLdsmOpOuter; -+ static int const LdsmShapeRow = -+ ((4 / LdsmShapeColumn * kLdsmOpInner) > Shape::kRow) -+ ? 
(Shape::kRow / kLdsmOpInner) -+ : (4 / LdsmShapeColumn); -+ using LdsmShape = -+ layout::PitchLinearShape; -+ -+ /// Number and arrangement of LDSM instructions -+ using LdsmIterations = layout::PitchLinearShape< -+ Shape::kRow / kLdsmOpInner / LdsmShapeRow, -+ 1>; -+ -+ /// Number of groups for each tile -+ static int const kGroupsPerTile = -+ Shape::kColumn / InstructionShape::kColumn; -+ }; -+ -+ private: -+ /// Not working on this feature at the moment. -+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Pointer type used for accesses -+ using AccessType = Array; -+ -+ public: -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = -+ Array; -+ -+ private: -+ -+ /// Layout object storing stride values -+ Index stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+ /// Internal counter used to determine when to increment byte offset and when -+ /// to XOR it -+ int k_group_idx_; -+ -+ public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ SparseMmaTensorOpMetaTileIterator() -+ : pointer_(nullptr), -+ stride_(0), -+ byte_offset_(0), -+ k_group_idx_(0) {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ SparseMmaTensorOpMetaTileIterator(TensorRef const &ref, int lane_id) -+ : pointer_(reinterpret_cast(ref.data())), -+ stride_(ref.stride(0) / Policy::kElementsPerAccess), -+ byte_offset_(0), -+ k_group_idx_(0) { -+ -+ int access_contiguous = (lane_id % (Shape::kRow / Policy::kElementsPerAccess)); -+ int access_strided = (lane_id / (Shape::kRow / Policy::kElementsPerAccess)); -+ -+ byte_offset_ = (access_contiguous + access_strided * stride_) * -+ sizeof_bits::value * Policy::kElementsPerAccess / 8; -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ SparseMmaTensorOpMetaTileIterator &add_pointer_offset(LongIndex offset) { -+ byte_offset_ += offset * sizeof_bits::value / 8; -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_DEVICE -+ SparseMmaTensorOpMetaTileIterator &add_tile_offset( -+ TensorCoord const &tile_offset) { -+ int offset = tile_offset.row() * Shape::kRow + -+ tile_offset.column() * InstructionShape::kColumn * stride_ * -+ Policy::kElementsPerAccess; -+ -+ add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ SparseMmaTensorOpMetaTileIterator &operator++() { -+ add_tile_offset({0, 1}); -+ -+ if (kPartitionsK > 1) { -+ ++k_group_idx_; -+ // Jump to next stage -+ if (k_group_idx_ == Policy::kGroupsPerTile) { -+ k_group_idx_ = 0; -+ add_tile_offset( -+ {0, ((kPartitionsK - 1) * Policy::kGroupsPerTile)}); -+ } -+ } -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ SparseMmaTensorOpMetaTileIterator &operator--(){ -+ byte_offset_ -= stride_ * InstructionShape::kColumn * -+ sizeof_bits::value * Policy::kElementsPerAccess / -+ 8; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE SparseMmaTensorOpMetaTileIterator & -+ operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical 
coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ SparseMmaTensorOpMetaTileIterator &operator-=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { load_with_byte_offset(frag, 0); } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ Array *fetch_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::LdsmIterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::LdsmIterations::kContiguous; ++c) { -+ -+ int access_idx = c + s * Policy::LdsmIterations::kContiguous; -+ -+ AccessType const *source_ptr = -+ pointer_ + -+ Policy::LdsmShape::kContiguous * Policy::kLdsmOpInner * c + -+ Policy::LdsmShape::kStrided * s * stride_; -+ -+ char const *source_byte_ptr = reinterpret_cast(source_ptr) + -+ byte_offset + byte_offset_; -+ -+ cutlass::arch::ldsm( -+ fetch_ptr[access_idx], source_byte_ptr); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = -+ tile_offset.contiguous() * Shape::kRow / Layout::kElementsPerAccess + -+ tile_offset.strided() * InstructionShape::kColumn * stride_; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. 
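The constructor of SparseMmaTensorOpMetaTileIterator splits the lane id into a contiguous and a strided access index and converts that pair into the starting byte offset used by the loads above. A simplified sketch of the same arithmetic; the 16-bit metadata element and 16-row shape are assumptions for illustration only:

#include <cstdint>

constexpr int kElementBits = 16;                        // sizeof_bits<Element>
constexpr int kElementsPerAccess = 128 / kElementBits;  // one 128-bit access
constexpr int kShapeRow = 16;                           // Shape::kRow (assumed)

// Mirrors: byte_offset_ = (contig + strided * stride_) * bits * epa / 8.
std::int64_t initial_byte_offset(int lane_id, std::int64_t stride_in_accesses) {
  int access_contiguous = lane_id % (kShapeRow / kElementsPerAccess);
  int access_strided    = lane_id / (kShapeRow / kElementsPerAccess);
  return (access_contiguous + access_strided * stride_in_accesses) *
         kElementBits * kElementsPerAccess / 8;
}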
-+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no op -+ } -+}; -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h -new file mode 100644 -index 0000000..d841d2b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h -@@ -0,0 +1,805 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. 
-+*/ -+ -+#pragma once -+ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/arch/wmma.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+#include "cutlass/wmma_array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+ -+#include "cutlass/platform/platform.h" -+#include "cutlass/fast_math.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+template < -+ ///< Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Operand identity (A or B) -+ Operand Operand, -+ /// Data type of operand -+ typename Element_, -+ /// Layout of operand -+ typename Layout_, -+ /// Delta between *MMA operations (in units of *WMMA operations, concept:MatrixShape) -+ int OpDelta_, -+ /// Number of threads participating in one matrix operation -+ int Threads, -+ /// Shape of the warp in units of thread (concept: MmaTensorOpPolicy) -+ typename Policy_> -+class MmaTensorOpWmmaMultiplicandTileIterator; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// This tile iterator is specialized for 32-thread WMMA operation. -+/// It uses nvcuda::wmma::load_matrix_sync to load from shared -+/// memory and therefore must be initialized with a TensorRef to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+//////////////////////////////////////////////////////////////////////////////// -+template < -+ ///< Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Layout of operand -+ typename Layout_, -+ /// Interval between adjacent *WMMA instructions (in units of WMMA instructions) -+ int OpDelta_, -+ /// Shape of the warp in units of thread (concept: MmaTensorOpPolicy) -+ typename Policy_> -+class MmaTensorOpWmmaMultiplicandTileIterator< -+ Shape_, Operand::kA, Element_, Layout_, -+ OpDelta_, 32, Policy_> { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kA; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = Layout_; -+ -+ /// Delta between *WMMA operations -+ static int const kOpDelta = OpDelta_; -+ -+ /// Wmma Operator information and operation delta -+ using Policy = Policy_; -+ -+ -+ // -+ // Derived quantities -+ // -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Stride Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Native Wmma shape for operand A (concept MatrixShape) -+ using WmmaShape = MatrixShape< -+ Policy::Operator::Shape::kM, -+ Policy::Operator::Shape::kK -+ >; -+ -+ /// Map cutlass dataype to nvcuda::wmma datatype -+ using WmmaDataType = typename 
cutlass::arch::CutlassToWmmaDataType::Type; -+ -+ /// Shape of individual WMMA load / stores for operand A -+ using Iterations = MatrixShape< -+ Shape::kRow / WmmaShape::kRow, -+ 1 -+ >; -+ -+ /// Fragment object holding a warps part -+ using Fragment = WmmaFragmentArray; -+ -+ -+ ////////////////////////////////////////////////////////////////////////////////////////////////////// -+ /// statically assert this specialization -+ ///////////////////////////////////////////////////////////////////////////////////////////////////// -+ /// This iterator is specalized for Operand A -+ static_assert(kOperand == Operand::kA, -+ "MmaTensorOpWmmaMultiplicandTileIterator may only be instantiated for A operands to warp-level Mma."); -+ -+ /// Supported memory layouts -+ static_assert( -+ platform::is_same::value || -+ platform::is_same::value, -+ "Supported list of memory layouts for WMMA are: RowMajor, ColumnMajor"); -+ -+ /// Not working on this feature at the moment. -+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ ///////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+private: -+ -+ /// Shared memory base pointers - not advanced -+ char const *pointer_; -+ -+ /// Byte offset into shared memory - advanced -+ Index byte_offset_; -+ -+ /// Stride in units of number of elements -+ StrideIndex stride_; -+ -+ /// Layout of shared memory -+ Layout layout_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): pointer_(reinterpret_cast(ref.data())), byte_offset_(0), stride_(ref.stride(0)), layout_(ref.stride(0)) { -+ -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ byte_offset_ += (offset * sizeof_bits::value) / 8; -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ Index elements_offset = layout_({tile_offset.row() * Shape::kRow, tile_offset.column() * WmmaShape::kColumn}); -+ -+ byte_offset_ += (elements_offset * sizeof_bits::value) / 8; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator & operator++() { -+ -+ Index elements_offset = layout_({0, WmmaShape::kColumn}); -+ -+ byte_offset_ += (elements_offset * sizeof_bits::value) / 8; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the opposite of the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator & operator--() { -+ -+ Index elements_offset = layout_({0, WmmaShape::kColumn}); -+ -+ byte_offset_ -= (elements_offset * sizeof_bits::value) / 8; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator & 
operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load_with_byte_offset(Fragment &frag, Index byte_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kColumn; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Iterations::kRow; ++m) { -+ -+ Index load_byte_offset = layout_({m * WmmaShape::kRow, k * WmmaShape::kColumn}) * sizeof_bits::value / 8; -+ -+ const WmmaDataType *ptr = reinterpret_cast(pointer_ + byte_offset_ + load_byte_offset + byte_offset); -+ -+ nvcuda::wmma::load_matrix_sync(frag[m], ptr, stride_); -+ -+ } -+ } -+ } -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_byte_offset(Fragment const &frag, Index byte_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kColumn; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Iterations::kRow; ++m) { -+ -+ Index store_byte_offset = layout_({m * WmmaShape::kRow, k * WmmaShape::kColumn}) * sizeof_bits::value / 8; -+ -+ WmmaDataType *ptr = reinterpret_cast(pointer_ + byte_offset_ + store_byte_offset + byte_offset); -+ -+ nvcuda::wmma::store_matrix_sync(ptr, frag[m], stride_); -+ -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_byte_offset(frag, 0); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// This tile iterator is specialized for 32-thread WMMA operation. -+/// It uses nvcuda::wmma::load_matrix_sync to load from shared -+/// memory and therefore must be initialized with a TensorRef to shared memory. 
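The WMMA operand iterators above all advance the same way: a logical (row, column) step is mapped through the stored layout to an element offset and then scaled into bytes. A small sketch of that conversion for a row-major layout of half-precision elements (the element width is an assumption for illustration):

#include <cstdint>

constexpr int kElementBits = 16;  // sizeof_bits<half> (assumed)

// Row-major layout: element offset = row * stride + column; then scale to bytes,
// mirroring byte_offset_ += (elements_offset * sizeof_bits<Element>::value) / 8.
std::int64_t byte_offset(int row, int column, std::int64_t stride) {
  std::int64_t element_offset =
      static_cast<std::int64_t>(row) * stride + column;
  return element_offset * kElementBits / 8;
}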
-+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ ///< Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Layout of operand -+ typename Layout_, -+ /// Interval between adjacent *WMMA instructions (in units of WMMA instructions) -+ int OpDelta_, -+ /// Shape of the warp in units of thread (concept: MmaTensorOpPolicy) -+ typename Policy_> -+class MmaTensorOpWmmaMultiplicandTileIterator< -+ Shape_, Operand::kB, Element_, Layout_, -+ OpDelta_, 32, Policy_> { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kB; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = Layout_; -+ -+ /// Delta between *WMMA operations -+ static int const kOpDelta = OpDelta_; -+ -+ /// Wmma Operator information and operation delta -+ using Policy = Policy_; -+ -+ -+ // -+ // Derived quantities -+ // -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Stride Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Native Wmma shape (concept MatrixShape) -+ using WmmaShape = MatrixShape< -+ Policy::Operator::Shape::kK, -+ Policy::Operator::Shape::kN -+ >; -+ -+ /// Map cutlass dataype to nvcuda::wmma datatype -+ using WmmaDataType = typename cutlass::arch::CutlassToWmmaDataType::Type; -+ -+ /// Shape of individual WMMA load / stores for operand B -+ using Iterations = MatrixShape< -+ 1, -+ Shape::kColumn / WmmaShape::kColumn -+ >; -+ -+ /// Fragment object holding a warps part -+ using Fragment = WmmaFragmentArray; -+ -+ -+ ////////////////////////////////////////////////////////////////////////////////////////////////////// -+ /// statically asserts this specialization -+ ///////////////////////////////////////////////////////////////////////////////////////////////////// -+ /// This iterator is specalized for Operand B -+ static_assert(kOperand == Operand::kB, -+ "MmaTensorOpWmmaMultiplicandTileIterator may only be instantiated for B operands to warp-level Mma."); -+ -+ /// Supported memory layouts -+ static_assert( -+ platform::is_same::value || -+ platform::is_same::value, -+ "Supported list of memory layouts for WMMA are: RowMajor, ColumnMajor"); -+ -+ /// Not working on this feature at the moment. 
-+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ ///////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+private: -+ -+ /// Shared memory base pointers - not advanced -+ char const *pointer_; -+ -+ /// Byte offset into shared memory - advanced -+ Index byte_offset_; -+ -+ /// Stride in units of number of elements -+ StrideIndex stride_; -+ -+ /// Layout of shared memory -+ Layout layout_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): pointer_(reinterpret_cast(ref.data())), byte_offset_(0), stride_(ref.stride(0)), layout_(ref.stride(0)) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ byte_offset_ += (offset * sizeof_bits::value) / 8; -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ Index elements_offset = layout_({tile_offset.row() * WmmaShape::kRow, tile_offset.column() * Shape::kColumn}); -+ -+ byte_offset_ += (elements_offset * sizeof_bits::value) / 8; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator & operator++() { -+ -+ Index elements_offset = layout_({WmmaShape::kRow, 0}); -+ -+ byte_offset_ += (elements_offset * sizeof_bits::value) / 8; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the opposite of the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator & operator--() { -+ -+ Index elements_offset = layout_({WmmaShape::kRow, 0}); -+ -+ byte_offset_ -= (elements_offset * sizeof_bits::value) / 8; -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load_with_byte_offset(Fragment &frag, Index byte_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kRow; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kColumn; ++n) { -+ -+ Index load_byte_offset = layout_({k * WmmaShape::kRow, n * WmmaShape::kColumn}) * sizeof_bits::value / 8; -+ -+ const WmmaDataType *ptr = reinterpret_cast(pointer_ + byte_offset_ + load_byte_offset + byte_offset); -+ -+ nvcuda::wmma::load_matrix_sync(frag[n], ptr, stride_); -+ } -+ } -+ } -+ /// Loads a fragment from memory at the location pointed to by the iterator. 
-+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_byte_offset(Fragment const &frag, Index byte_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kRow; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kColumn; ++n) { -+ -+ Index store_byte_offset = layout_({k * WmmaShape::kRow, n * WmmaShape::kColumn}) * sizeof_bits::value / 8; -+ -+ WmmaDataType *ptr = reinterpret_cast(pointer_ + byte_offset_ + store_byte_offset + byte_offset); -+ -+ nvcuda::wmma::store_matrix_sync(ptr, frag[n], stride_); -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_byte_offset(frag, 0); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+template < -+ ///< Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Element type -+ typename Element_, -+ /// Layout of operand in memory -+ typename Layout_, -+ /// Interval between adjacent *WMMA instructions (in units of WMMA instructions, concept: MatrixShape) -+ typename OpDelta_, -+ /// Shape of the warp in units of thread (concept: MmaTensorOpPolicy) -+ typename Policy_> -+class MmaTensorOpWmmaAccumulatorTileIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// This tile iterator is specialized for 32-thread WMMA operation. -+/// It uses nvcuda::wmma::store_matrix_sync to load from shared -+/// memory and therefore must be initialized with a TensorRef to shared memory. 
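Underneath these iterators sit the standard nvcuda::wmma intrinsics. The following standalone CUDA kernel is a generic illustration of the load_matrix_sync / mma_sync / store_matrix_sync sequence that the operand and accumulator iterators wrap; the 16x16x16 half/float configuration is an assumption, not taken from the CUTLASS policy:

// Requires a Tensor Core capable GPU (compile with -arch=sm_70 or newer).
#include <mma.h>
#include <cuda_fp16.h>

using namespace nvcuda;

__global__ void wmma_16x16x16(const half* a, const half* b, float* c,
                              int lda, int ldb, int ldc) {
  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;

  wmma::fill_fragment(c_frag, 0.0f);       // zero the accumulator tile
  wmma::load_matrix_sync(a_frag, a, lda);  // warp-cooperative operand loads
  wmma::load_matrix_sync(b_frag, b, ldb);
  wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
  wmma::store_matrix_sync(c, c_frag, ldc, wmma::mem_row_major);
}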
-+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept | -+/// WriteableRandomAccessContiguousTileIteratorConcept -+/// -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ ///< Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Layout of operand in memory -+ typename Layout_, -+ /// Interval between adjacent *WMMA instructions (in units of WMMA instructions) -+ typename OpDelta_, -+ /// Shape of the warp in units of thread (concept: MmaTensorOpPolicy) -+ typename Policy_> -+class MmaTensorOpWmmaAccumulatorTileIterator -+{ -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = Layout_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ using OpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Wmma Operator information and operation delta -+ using Policy = Policy_; -+ -+ -+ // -+ // Derived quantities -+ // -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Native Wmma shape (concept MatrixShape) -+ using WmmaShape = MatrixShape< -+ Policy::Operator::Shape::kM, -+ Policy::Operator::Shape::kN -+ >; -+ -+ /// Map cutlass dataype to nvcuda::wmma datatype -+ using WmmaDataType = typename cutlass::arch::CutlassToWmmaDataType::Type; -+ -+ /// Map cutlass::layout to nvuda::wmma::layout_t enum -+ static nvcuda::wmma::layout_t const WmmaLayout = cutlass::arch::CutlassToWmmaLayout::value; -+ -+ /// Shape of individual WMMA load / stores for accumulator -+ using Iterations = MatrixShape< -+ Shape::kRow / WmmaShape::kRow, -+ Shape::kColumn / WmmaShape::kColumn -+ >; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = WmmaFragmentArray; -+ -+ ////////////////////////////////////////////////////////////////////////////////////////////////////// -+ /// statically asserts this specialization -+ ///////////////////////////////////////////////////////////////////////////////////////////////////// -+ /// Supported layouts -+ static_assert( -+ platform::is_same::value || -+ platform::is_same::value, -+ "Supported list of memory layouts for WMMA are: RowMajor, ColumnMajor"); -+ -+private: -+ -+ /// Internal reference -+ cutlass::TensorRef ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpWmmaAccumulatorTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaAccumulatorTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): ref_(ref) { } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpWmmaAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ ref_.add_coord_offset({tile_offset.row() * 
Shape::kRow, tile_offset.column() * Shape::kColumn}); -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaAccumulatorTileIterator & operator++() { -+ ref_.add_coord_offset({Shape::kRow, 0}); -+ return *this; -+ } -+ -+ /// Advances the iterator along the opposite of the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpWmmaAccumulatorTileIterator & operator--() { -+ ref_.add_coord_offset({-Shape::kRow, 0}); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Iterations::kRow; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kColumn; ++n) { -+ -+ const WmmaDataType * ptr = reinterpret_cast (ref_.data() + ref_.offset({m * WmmaShape::kRow, n * WmmaShape::kColumn}) + pointer_offset); -+ -+ nvcuda::wmma::load_matrix_sync(frag[m * Iterations::kColumn + n], ptr, ref_.stride()[0], WmmaLayout); -+ -+ } -+ } -+ } -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Iterations::kRow; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kColumn; ++n) { -+ -+ WmmaDataType * ptr = reinterpret_cast (ref_.data() + ref_.offset({m * WmmaShape::kRow, n * WmmaShape::kColumn}) + pointer_offset); -+ -+ nvcuda::wmma::store_matrix_sync(ptr, frag[m * Iterations::kColumn + n], ref_.stride()[0], WmmaLayout); -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. 
-+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+ -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_wmma.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_wmma.h -new file mode 100644 -index 0000000..c3954f3 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_wmma.h -@@ -0,0 +1,223 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations targeting -+ Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/arch/wmma.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+#include "cutlass/wmma_array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/arch/mma_sm75.h" -+#include "cutlass/arch/mma_sm80.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///< Structure to compute the matrix product targeting CUDA cores via WMMA. 
-+template < -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ ///< Data type of A elements -+ typename ElementA_, -+ ///< Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ ///< Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ ///< Element type of C matrix -+ typename ElementC_, -+ ///< Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ ///< Policy describing warp-level Wmma operation (concept: MmaTensorOpPolicy) -+ typename Policy_, -+ ///< Number of partitions along K dimension -+ int PartitionsK_ = 1, -+ ///< Used for partial specialization -+ typename Enable = bool -+> -+class MmaTensorOpWmma { -+public: -+ ///< Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ ///< Data type of multiplicand A -+ using ElementA = ElementA_; -+ -+ ///< Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ ///< Data type of multiplicand B -+ using ElementB = ElementB_; -+ -+ ///< Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ ///< Data type of accumulator matrix C -+ using ElementC = ElementC_; -+ -+ ///< Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaTensorOpPolicy) -+ using Policy = Policy_; -+ -+ /// Underlying instruction shape -+ using InstructionShape = typename Policy::Operator::Shape; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Policy::Operator; -+ -+ /// Indicates math operator -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ /// Underlying architecture tag -+ using ArchTag = typename Policy::Operator::ArchTag; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassWmmaTensorOp; -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaTensorOpWmmaMultiplicandTileIterator< -+ MatrixShape, Operand::kA, ElementA, LayoutA, -+ Policy::OpDelta::kRow, kThreadCount, Policy>; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaTensorOpWmmaMultiplicandTileIterator< -+ MatrixShape, Operand::kB, ElementB, LayoutB, -+ Policy::OpDelta::kRow, kThreadCount, Policy>; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaTensorOpWmmaAccumulatorTileIterator< -+ MatrixShape, ElementC, LayoutC, -+ typename Policy::OpDelta, Policy>; -+ -+ /// Storage for C tile -+ using FragmentC = typename IteratorC::Fragment; -+ -+private: -+ -+ static_assert( -+ !(Shape::kM % Policy::Operator::Shape::kM) && -+ !(Shape::kN % Policy::Operator::Shape::kN), -+ "Shape of warp-level Wmma must be divisible by operator shape (wmma native size)"); -+ -+ /// Number of wmma operations performed -+ using WmmaIterations = MatrixShape< -+ Shape::kM / Policy::Operator::Shape::kM, -+ Shape::kN / Policy::Operator::Shape::kN -+ >; 
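-+  // For example (illustrative): a 64x64 warp tile with a 16x16x16 native WMMA shape
-+  // yields WmmaIterations of 4x4, i.e. 16 nvcuda::wmma calls per warp-level multiply-add.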
-+ -+public: -+ -+ /// Underlying matrix multiply operator (concept: cutlass::arch::Wmma) -+ typename Policy::Operator wmma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaTensorOpWmma() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A, -+ FragmentB const &B, -+ FragmentC const &C) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < WmmaIterations::kColumn; ++n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < WmmaIterations::kRow; ++m) { -+ -+ // accumulate wmma mma -+ wmma(D[m * WmmaIterations::kColumn + n], A[m], B[n], C[m * WmmaIterations::kColumn + n]); -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h -new file mode 100644 -index 0000000..9957967 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h -@@ -0,0 +1,449 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations targeting -+ Tensor Cores. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/arch/mma_sm75.h" -+#include "cutlass/arch/mma_sm80.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+#include "cutlass/gemm/warp/mma_tensor_op.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Reduce operand A or B along K dimension -+ bool ReduceKForA_, -+ /// Number of partitions along K dimension -+ int PartitionsK_ = 1, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. 
-+ bool AccumulatorsInRowMajor = false, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class MmaWithReductionTensorOp { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = ElementA_; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = ElementB_; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = ElementC_; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Policy::Operator; -+ -+ /// Indicates math operator -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ /// Architecture tag from underlying instruction -+ using ArchTag = typename ArchMmaOperator::ArchTag; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Shape of underlying instruction -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ static bool const kReduceKForA = ReduceKForA_; -+ -+ static_assert(platform::is_same::value || -+ platform::is_same::value, -+ "ElementA needs to be fp16 or bf16."); -+ -+ static_assert(platform::is_same::value || -+ platform::is_same::value, -+ "ElementB needs to be fp16 or bf16."); -+ -+ static_assert(platform::is_same>::value, -+ "Only supports 16x8x16 tensor core instruction."); -+ -+ static_assert(!AccumulatorsInRowMajor, -+ "Only calls tensor core instructions in column major."); -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, Operand::kA, ElementA, LayoutA, -+ MatrixShape, -+ Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentA = -+ Array; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, Operand::kB, ElementB, LayoutB, -+ MatrixShape, -+ Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed B tile -+ using TransformedFragmentB = -+ Array; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaTensorOpAccumulatorTileIterator< -+ MatrixShape, ElementC, LayoutC, -+ typename ArchMmaOperator::Shape, typename Policy::OpDelta>; -+ -+ /// Storage for C tile -+ using FragmentC = typename IteratorC::Fragment; -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, -+ (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN -+ >; -+ -+ using FragmentReduction 
= Array; -+ -+public: -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ ArchMmaOperator mma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaWithReductionTensorOp() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ TransformedFragmentA const &A, -+ TransformedFragmentB const &B, -+ FragmentC const &C, -+ FragmentReduction &gemm_k_reduction -+ ) const { -+ -+ using MmaOperandA = typename ArchMmaOperator::FragmentA; -+ using MmaOperandB = typename ArchMmaOperator::FragmentB; -+ using MmaOperandC = typename ArchMmaOperator::FragmentC; -+ -+ D = C; -+ -+ [[maybe_unused]] MmaOperandA const *ptr_A = reinterpret_cast(&A); -+ [[maybe_unused]] MmaOperandB const *ptr_B = reinterpret_cast(&B); -+ [[maybe_unused]] MmaOperandC *ptr_D = reinterpret_cast(&D); -+ -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) -+ assert(0); -+ #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ // Serpentine visitation order maximizing reuse of Ra -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; ++m) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); -+ -+ mma(ptr_D[m + n_serpentine * MmaIterations::kRow], -+ ptr_A[m], -+ ptr_B[n_serpentine], -+ ptr_D[m + n_serpentine * MmaIterations::kRow]); -+ -+ if (!kReduceKForA && m == 0) { -+ #if 0 -+ gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4]); -+ gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4 + 1]); -+ gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4 + 2]); -+ gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4 + 3]); -+ #else -+ uint32_t const *tmp = reinterpret_cast(&B); -+ -+ if (platform::is_same::value) { -+ asm volatile( -+ "{\n\t" -+ " .reg .f16 low, high;\n\t" -+ " .reg .f32 tmp;\n\t" -+ " mov.b32 {low, high}, %1;\n\t" -+ " cvt.f32.f16 tmp, low;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " cvt.f32.f16 tmp, high;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " mov.b32 {low, high}, %2;\n\t" -+ " cvt.f32.f16 tmp, low;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " cvt.f32.f16 tmp, high;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ "}\n\t" -+ : "+f"(gemm_k_reduction[n_serpentine]) -+ : "r"(tmp[n_serpentine * 2]), "r"(tmp[n_serpentine * 2 + 1])); -+ } else if (platform::is_same::value) { -+ asm volatile( -+ "{\n\t" -+ " .reg .f32 tmp;\n\t" -+ " shl.b32 tmp, %1, 16;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " and.b32 tmp, %1, 0xffff0000;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " shl.b32 tmp, %2, 16;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " and.b32 tmp, %2, 0xffff0000;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ "}\n\t" -+ : "+f"(gemm_k_reduction[n_serpentine]) -+ : "r"(tmp[n_serpentine * 2]), "r"(tmp[n_serpentine * 2 + 1])); -+ } else { -+ assert(0); -+ } -+ #endif -+ } -+ -+ if (kReduceKForA && (n == 0)) { -+ #if 0 -+ gemm_k_reduction[m * 2] += float(A[m * 8]); -+ gemm_k_reduction[m * 2] += float(A[m * 8 + 1]); -+ gemm_k_reduction[m * 2] += float(A[m * 8 + 4]); -+ gemm_k_reduction[m * 2] += float(A[m * 8 + 5]); -+ -+ gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 2]); -+ gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 3]); -+ gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 6]); -+ gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 7]); -+ #else -+ uint32_t const *tmp = reinterpret_cast(&A); -+ -+ if (platform::is_same::value) { -+ asm volatile( -+ "{\n\t" -+ " .reg .f16 low, high;\n\t" -+ " .reg .f32 
tmp;\n\t" -+ " mov.b32 {low, high}, %2;\n\t" -+ " cvt.f32.f16 tmp, low;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " cvt.f32.f16 tmp, high;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " mov.b32 {low, high}, %3;\n\t" -+ " cvt.f32.f16 tmp, low;\n\t" -+ " add.f32 %1, tmp, %1;\n\t" -+ " cvt.f32.f16 tmp, high;\n\t" -+ " add.f32 %1, tmp, %1;\n\t" -+ " mov.b32 {low, high}, %4;\n\t" -+ " cvt.f32.f16 tmp, low;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " cvt.f32.f16 tmp, high;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " mov.b32 {low, high}, %5;\n\t" -+ " cvt.f32.f16 tmp, low;\n\t" -+ " add.f32 %1, tmp, %1;\n\t" -+ " cvt.f32.f16 tmp, high;\n\t" -+ " add.f32 %1, tmp, %1;\n\t" -+ "}\n\t" -+ : "+f"(gemm_k_reduction[m * 2]), "+f"(gemm_k_reduction[m * 2 + 1]) -+ : "r"(tmp[m * 4]), "r"(tmp[m * 4 + 1]),"r"(tmp[m * 4 + 2]), "r"(tmp[m * 4 + 3])); -+ -+ } else if (platform::is_same::value) { -+ -+ asm volatile( -+ "{\n\t" -+ " .reg .f32 tmp;\n\t" -+ " shl.b32 tmp, %2, 16;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " and.b32 tmp, %2, 0xffff0000;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " shl.b32 tmp, %3, 16;\n\t" -+ " add.f32 %1, tmp, %1;\n\t" -+ " and.b32 tmp, %3, 0xffff0000;\n\t" -+ " add.f32 %1, tmp, %1;\n\t" -+ " shl.b32 tmp, %4, 16;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " and.b32 tmp, %4, 0xffff0000;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " shl.b32 tmp, %5, 16;\n\t" -+ " add.f32 %1, tmp, %1;\n\t" -+ " and.b32 tmp, %5, 0xffff0000;\n\t" -+ " add.f32 %1, tmp, %1;\n\t" -+ "}\n\t" -+ : "+f"(gemm_k_reduction[m * 2]), "+f"(gemm_k_reduction[m * 2 + 1]) -+ : "r"(tmp[m * 4]), "r"(tmp[m * 4 + 1]),"r"(tmp[m * 4 + 2]), "r"(tmp[m * 4 + 3])); -+ -+ } else { -+ assert(0); -+ } -+ #endif -+ } -+ } -+ } -+ #else -+ assert(0); -+ #endif -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ -+ // -+ // Define conversions from source type to instruction type -+ // -+ FloatRoundStyle const kRoundA = -+ PreferredRoundingMode::kRound; -+ FloatRoundStyle const kRoundB = -+ PreferredRoundingMode::kRound; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) -+ detail::ConvertAndPack -+ convert_A; -+ NumericArrayConverter -+ convert_B; -+ Array const *ptr_B = -+ reinterpret_cast const *>(&B); -+ Array * -+ ptr_dst_B = reinterpret_cast *>(&dst_B); -+ -+ dst_A = convert_A(A); -+ -+ ptr_dst_B[0] = convert_B(ptr_B[0]); -+ ptr_dst_B[1] = convert_B(ptr_B[1]); -+ -+ #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ detail::ConvertAndPack -+ convert_A; -+ NumericArrayConverter -+ convert_B; -+ Array const *ptr_A = -+ reinterpret_cast const *>(&A); -+ Array * -+ ptr_dst_A = reinterpret_cast *>(&dst_A); -+ -+ dst_B = convert_B(B); -+ -+ ptr_dst_A[0] = convert_A(ptr_A[0]); -+ ptr_dst_A[1] = convert_A(ptr_A[1]); -+ #else -+ assert(0); -+ #endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/scale_bias_tile_iterator.h b/3rdparty/cutlass/include/cutlass/gemm/warp/scale_bias_tile_iterator.h -new file mode 100644 -index 0000000..9c9b90b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/scale_bias_tile_iterator.h -@@ -0,0 +1,574 @@ 
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Defines iterators used by warp-level loading scale and bias vectors. -+ Every scale/bias data only needs to be loaded once for every channel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+ -+#include "cutlass/platform/platform.h" -+#include "cutlass/fast_math.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Layout of operand -+ typename Layout_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Policy of the details of LDSM shape and iterations -+ typename Policy_, -+ /// Number of threads participating in one matrix operation -+ int Threads, -+ /// Number of partitions along K dimension -+ int PartitionsK_ = 1> -+class ScaleBiasTileIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. 
It uses LDSM to -+/// load from shared memory and therefore must be initialized with a TensorRef -+/// to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Policy of the details of LDSM shape and iterations -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class ScaleBiasTileIterator { -+ public: -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::PitchLinear; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// Number of partitions along K dimension -+ static int const kElementsPerAccess = 128 / sizeof_bits::value; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ using Policy = Policy_; -+ -+ private: -+ -+ /// Pointer type used for accesses -+ using AccessType = Array; -+ -+ public: -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+ private: -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+ /// Internal counter used to determine when to increment byte offset and when -+ /// to XOR it -+ int k_group_idx_; -+ -+ public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ ScaleBiasTileIterator() -+ : pointer_(nullptr), -+ byte_offset_(0), -+ k_group_idx_(0) {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ ScaleBiasTileIterator(TensorRef const &ref_scale_bias, -+ int lane_id) -+ : byte_offset_(0), k_group_idx_(0) { -+ /// 16816 only -+ pointer_ = reinterpret_cast(ref_scale_bias.data()) + -+ ((lane_id >> 3) & 1) * Shape::kContiguous / kElementsPerAccess + -+ (lane_id >> 4); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ ScaleBiasTileIterator &add_pointer_offset(LongIndex offset) { -+ byte_offset_ += offset * sizeof_bits::value / 8; -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_DEVICE -+ ScaleBiasTileIterator &add_tile_offset( -+ TensorCoord const &tile_offset) { -+ int whole_tiles = tile_offset.contiguous() / Policy::kGroupsPerTile; -+ int k_groups_delta = tile_offset.contiguous() % Policy::kGroupsPerTile; -+ -+ byte_offset_ += k_groups_delta * sizeof_bits::value * -+ kElementsPerAccess * Policy::LdsmShape::kContiguous / 8; -+ -+ // Multiply by 2 because scale and bias belonging to the same stage are next -+ // to each other in 
the shared memory. -+ pointer_ += (2 * whole_tiles * Shape::kContiguous / kElementsPerAccess); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ ScaleBiasTileIterator &operator++() { -+ byte_offset_ += Policy::LdsmShape::kContiguous * -+ sizeof_bits::value * kElementsPerAccess / 8; -+ -+ k_group_idx_++; -+ -+ if (k_group_idx_ == (Policy::kGroupsPerTile / kPartitionsK)) { -+ k_group_idx_ = 0; -+ byte_offset_ -= (Policy::kGroupsPerTile / kPartitionsK) * -+ Policy::LdsmShape::kContiguous * -+ sizeof_bits::value * kElementsPerAccess / 8; -+ add_tile_offset({Policy::kGroupsPerTile, 0}); -+ } -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ ScaleBiasTileIterator &operator--() { assert(0); } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ ScaleBiasTileIterator &operator+=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ ScaleBiasTileIterator &operator-=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { load_with_byte_offset(frag, 0); } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ Array *fetch_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < 1; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::LdsmIterations::kContiguous; ++c) { -+ int access_idx = c + s * Policy::LdsmIterations::kContiguous; -+ -+ AccessType const *source_ptr = -+ pointer_ + Policy::LdsmShape::kContiguous * c; -+ -+ char const *source_byte_ptr = -+ reinterpret_cast(source_ptr) + byte_offset + -+ byte_offset_; -+ -+ cutlass::arch::ldsm( -+ fetch_ptr[access_idx], source_byte_ptr); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. 
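-+  /// The contiguous tile coordinate is first converted to an element offset and then to
-+  /// bytes before delegating to load_with_byte_offset(frag, byte_offset).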
-+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = tile_offset.contiguous() * -+ InstructionShape::kContiguous / -+ kElementsPerAccess; -+ -+ byte_offset += sizeof_bits::value * pointer_offset / 8; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ k_group_idx_ = k_group % (Policy::kGroupsPerTile / kPartitionsK); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to -+/// load from shared memory and therefore must be initialized with a TensorRef -+/// to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Policy of the details of LDSM shape and iterations -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class ScaleBiasTileIterator { -+ public: -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ using Policy = Policy_; -+ -+ /// Underlying tile iterator implementation -+ using Base = ScaleBiasTileIterator< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, -+ layout::PitchLinearShape, -+ Policy, kThreads, PartitionsK_>; -+ -+ public: -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+ private: -+ /// Underlying tile iterator -+ Base iterator_; -+ -+ public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ ScaleBiasTileIterator() {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ ScaleBiasTileIterator(TensorRef const &ref_scale_bias, int lane_id) -+ : iterator_({ref_scale_bias.data(), ref_scale_bias.stride()}, lane_id) {} -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ ScaleBiasTileIterator 
&add_pointer_offset(LongIndex offset) { -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ ScaleBiasTileIterator &add_tile_offset( -+ TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_DEVICE -+ ScaleBiasTileIterator &add_tile_offset_negative( -+ TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset_negative({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ ScaleBiasTileIterator &operator++() { -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ ScaleBiasTileIterator &operator--() { -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ ScaleBiasTileIterator &operator+=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ ScaleBiasTileIterator &operator-=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { iterator_.load(frag); } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ assert(0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ assert(0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. 
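-+  /// Coordinates are reordered into the underlying pitch-linear iterator's
-+  /// (contiguous, strided) convention before delegating the load.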
-+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, {tile_offset.strided(), tile_offset.contiguous()}, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/softmax_scale_bias_transform.h b/3rdparty/cutlass/include/cutlass/gemm/warp/softmax_scale_bias_transform.h -new file mode 100644 -index 0000000..bf8efe9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/softmax_scale_bias_transform.h -@@ -0,0 +1,117 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level per-channel softmax before -+ matrix multiply-accumulate operations targeting Tensor Cores. 
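-+
-+    Each packed __half2 activation x is updated in place as
-+    h2exp(x - norm_sum[2*i]) * norm_sum[2*i + 1]; typically norm_sum packs the
-+    running row maximum and the reciprocal of the softmax sum.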
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/arch/mma_sm75.h" -+#include "cutlass/arch/mma_sm80.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct SoftmaxScaleBiasTransform { -+ -+ using T = typename FragmentActivations::Element; -+ -+ static int const NumActivations = FragmentActivations::kElements; -+ static int const NumNormSum = FragmentNormSum::kElements; -+ static int const MmaElements = 2; -+ // One element has one scale and one bias -+ static int const MmaScaleBiasPair = 2; -+ // 16816 has 2 columns and 2 rows -+ static int const MmaCols = 2; -+ static int const MmaRows = 2; -+ -+ using MmaOperand = Array; -+ using NormSumOperand = Array<__half2, MmaScaleBiasPair>; -+ -+ CUTLASS_DEVICE -+ void transform(MmaOperand &activations, -+ NormSumOperand const &norm_sum) { -+ -+ __half2* packed_activations = reinterpret_cast<__half2*>(&activations); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < MmaElements / 2; ++i) { -+ __half2 out = ::h2exp(__hsub2(packed_activations[i], norm_sum[2*i])); -+ packed_activations[i] = __hmul2(out, norm_sum[2*i + 1]); -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void operator()(FragmentActivations &activations, -+ FragmentNormSum const &norm_sum) { -+ MmaOperand *ptr_activations = reinterpret_cast(&activations); -+ NormSumOperand const *ptr_norm_sum = -+ reinterpret_cast(&norm_sum); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < (NumActivations / MmaElements); ++i) { -+ transform(ptr_activations[i], -+ ptr_norm_sum[i / (MmaCols * MmaRows) * MmaRows + i % MmaRows]); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/tile_iterator_planar_complex.h b/3rdparty/cutlass/include/cutlass/gemm/warp/tile_iterator_planar_complex.h -new file mode 100644 -index 0000000..1633dd2 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/tile_iterator_planar_complex.h -@@ -0,0 +1,250 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/array_planar_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class TileIteratorPlanarComplex { -+public: -+ -+ /// Underlying iterator over real-valued tiles -+ using TileIterator = TileIterator_; -+ -+ /// Underlying element type -+ using Element = typename TileIterator::Element; -+ -+ /// Underlying layout type -+ using Layout = typename TileIterator::Layout; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = typename TileIterator::TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Planar complex fragment -+ using Fragment = ArrayPlanarComplex; -+ -+public: -+ -+ /// Underlying tile iterator -+ TileIterator tile_iterator_; -+ -+ /// Offset (in units of bytes) to the imaginary part of the planar complex matrix -+ LongIndex imaginary_offset_; -+ -+public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ TileIteratorPlanarComplex(): imaginary_offset_(0) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ TileIteratorPlanarComplex( -+ TensorRef const &ref, -+ int lane_id, -+ LongIndex imaginary_offset -+ ): -+ tile_iterator_(ref, lane_id), -+ imaginary_offset_((imaginary_offset * sizeof_bits::value) / 8) { } -+ -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ TileIteratorPlanarComplex &add_pointer_offset(LongIndex offset) { -+ -+ tile_iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ 
-+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ TileIteratorPlanarComplex &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ tile_iterator_.add_tile_offset(tile_offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ TileIteratorPlanarComplex & operator++() { -+ ++tile_iterator_; -+ return *this; -+ } -+ -+ // -+ // WIP -+ // -+ -+ /// Advances the iterator along the opposite of the advance dimension -+ CUTLASS_HOST_DEVICE -+ TileIteratorPlanarComplex & operator--() { -+ --tile_iterator_; -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ TileIteratorPlanarComplex & operator+=(TensorCoord const &tile_offset) { -+ tile_iterator_.add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ TileIteratorPlanarComplex & operator-=(TensorCoord const &tile_offset) { -+ tile_iterator_.add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ tile_iterator_.load_with_byte_offset(frag.real, 0); -+ tile_iterator_.load_with_byte_offset(frag.imag, imaginary_offset_); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ -+ tile_iterator_.load_with_byte_offset(frag.real, byte_offset); -+ tile_iterator_.load_with_byte_offset(frag.imag, byte_offset + imaginary_offset_); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ -+ Index byte_offset = (pointer_offset * sizeof_bits::value)/8; -+ -+ tile_iterator_.load_with_byte_offset(frag.real, byte_offset); -+ tile_iterator_.load_with_byte_offset(frag.imag, byte_offset + imaginary_offset_); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ -+ tile_iterator_.load_with_byte_offset(frag.real, tile_offset, 0); -+ tile_iterator_.load_with_byte_offset(frag.imag, tile_offset, imaginary_offset_); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ -+ Index byte_offset = (pointer_offset * sizeof_bits::value)/8; -+ -+ tile_iterator_.load_with_byte_offset(frag.real, tile_offset, byte_offset); -+ tile_iterator_.load_with_byte_offset(frag.real, tile_offset, byte_offset + imaginary_offset_); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. 
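-+  /// The real part of the fragment is loaded at the given byte offset and the
-+  /// imaginary part at byte_offset + imaginary_offset_.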
-+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ -+ tile_iterator_.load_with_byte_offset(frag.real, tile_offset, byte_offset); -+ tile_iterator_.load_with_byte_offset(frag.imag, tile_offset, byte_offset + imaginary_offset_); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ tile_iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/half.h b/3rdparty/cutlass/include/cutlass/half.h -new file mode 100644 -index 0000000..8d90b26 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/half.h -@@ -0,0 +1,919 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Defines a class for using IEEE half-precision floating-point types in host or -+ device code. 
-+*/ -+#pragma once -+ -+#ifndef CUTLASS_ENABLE_F16C -+#define CUTLASS_ENABLE_F16C 0 -+#endif -+ -+#if defined(__CUDACC_RTC__) -+ -+#include "cutlass/floating_point_nvrtc.h" -+ -+// F16C extensions are not meaningful when compiling for NVRTC which only accommodates device code. -+#undef CUTLASS_ENABLE_F16C -+#define CUTLASS_ENABLE_F16C 0 -+ -+#else -+// -+// Standard Library headers belong here to avoid conflicts with NVRTC. -+// -+#include -+#include -+#include -+#include -+#endif -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/float8.h" -+#include "cutlass/platform/platform.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Optionally target F16C extentions to accelerate half-precision conversion. -+#if !defined(__CUDA_ARCH__) && (CUTLASS_ENABLE_F16C) -+#if defined(_MSC_VER) -+ -+#include -+ -+#if defined(__i386__) || defined(__x86_64__) -+#include -+#endif -+ -+#define F16C_ROUND_NEAREST 0 -+ -+#if !defined(__CUDA_ARCH__) -+extern __inline float _cvtsh_ss (unsigned short __S) { -+ __m128i packed; -+ std::memcpy(&packed, &__S, sizeof(__S)); -+ -+ __m128 result = _mm_cvtph_ps(packed); -+ -+ float flt; -+ std::memcpy(&flt, &result, sizeof(flt)); -+ -+ return flt; -+} -+ -+__inline unsigned short _cvtss_sh (float __F, const int) { -+ __m128 packed; -+ std::memcpy(&packed, &__F, sizeof(__F)); -+ -+ __m128i result = _mm_cvtps_ph(packed, F16C_ROUND_NEAREST); -+ -+ unsigned short u; -+ std::memcpy(&u, &result, sizeof(u)); -+ -+ return u; -+} -+#endif -+ -+#else -+ -+// Linux -+#include -+ -+#if defined(__i386__) || defined(__x86_64__) -+#include -+#endif -+ -+#define F16C_ROUND_NEAREST (_MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC) -+ -+#endif // _MSC_VER -+ -+class CpuId { -+ -+ bool f16c_enabled; -+ -+ CpuId() { -+ #if defined(__i386__) || defined(__x86_64__) -+ #if defined(_MSC_VER) -+ int exx[4]; -+ -+ __cpuid (exx, 1); -+ f16c_enabled = exx[2] & 0x20000000; -+ -+ #else -+ // GCC / Clang -+ int eax, ebx, ecx, edx; -+ -+ __cpuid (1 , eax, ebx, ecx, edx); -+ f16c_enabled = ecx & 0x20000000; -+ #endif -+ #else -+ // Arm / PowerPC etc. 
-+ f16c_enabled = false; -+ #endif -+ } -+ -+public: -+ -+ bool is_f16c_supported() const { -+ return f16c_enabled; -+ } -+ -+ static const CpuId& instance() { -+ static CpuId cpu; -+ return cpu; -+ } -+}; -+#endif // !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// IEEE half-precision floating-point type -+struct alignas(2) half_t { -+ -+ // -+ // Data members -+ // -+ -+ /// Storage type -+ uint16_t storage; -+ -+ // -+ // Static conversion operators -+ // -+ -+ /// Constructs from an unsigned short -+ CUTLASS_HOST_DEVICE -+ static half_t bitcast(uint16_t x) { -+ half_t h; -+ h.storage = x; -+ return h; -+ } -+ -+ /// FP32 -> FP16 conversion - rounds to nearest even -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) -+ // Avoid inlining in device code if no hardware support -+ __device__ __noinline__ -+ #else -+ CUTLASS_HOST_DEVICE -+ #endif -+ static half_t convert(float const& flt) { -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return half_t(__float2half_rn(flt)); -+ #else -+ -+ #if !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C -+ if( CpuId::instance().is_f16c_supported() ) { -+ unsigned short u = _cvtss_sh(flt, F16C_ROUND_NEAREST); -+ return bitcast(u); -+ } -+ #endif -+ -+ // software implementation rounds toward nearest even -+ unsigned s; -+ -+ #if defined(__CUDA_ARCH__) -+ s = reinterpret_cast(flt); -+ #else -+ std::memcpy(&s, &flt, sizeof(s)); -+ #endif -+ -+ uint16_t sign = uint16_t((s >> 16) & 0x8000); -+ int16_t exp = uint16_t(((s >> 23) & 0xff) - 127); -+ int mantissa = s & 0x7fffff; -+ uint16_t u = 0; -+ -+ if ((s & 0x7fffffff) == 0) { -+ // sign-preserving zero -+ return bitcast(sign); -+ } -+ -+ if (exp > 15) { -+ if (exp == 128 && mantissa) { -+ // not a number -+ u = 0x7fff; -+ } else { -+ // overflow to infinity -+ u = sign | 0x7c00; -+ } -+ return bitcast(u); -+ } -+ -+ int sticky_bit = 0; -+ -+ if (exp >= -14) { -+ // normal fp32 to normal fp16 -+ exp = uint16_t(exp + uint16_t(15)); -+ u = uint16_t(((exp & 0x1f) << 10)); -+ u = uint16_t(u | (mantissa >> 13)); -+ } else { -+ // normal single-precision to subnormal half_t-precision representation -+ int rshift = (-14 - exp); -+ if (rshift < 32) { -+ mantissa |= (1 << 23); -+ -+ sticky_bit = ((mantissa & ((1 << rshift) - 1)) != 0); -+ -+ mantissa = (mantissa >> rshift); -+ u = (uint16_t(mantissa >> 13) & 0x3ff); -+ } else { -+ mantissa = 0; -+ u = 0; -+ } -+ } -+ -+ // round to nearest even -+ int round_bit = ((mantissa >> 12) & 1); -+ sticky_bit |= ((mantissa & ((1 << 12) - 1)) != 0); -+ -+ if ((round_bit && sticky_bit) || (round_bit && (u & 1))) { -+ u = uint16_t(u + 1); -+ } -+ -+ u |= sign; -+ -+ return bitcast(u); -+ #endif -+ } -+ -+ /// FP32 -> FP16 conversion - rounds to nearest even -+ CUTLASS_HOST_DEVICE -+ static half_t convert(int const& n) { -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return half_t(__int2half_rn(n)); -+ #else -+ return convert(float(n)); -+ #endif -+ } -+ -+ /// FP32 -> FP16 conversion - rounds to nearest even -+ CUTLASS_HOST_DEVICE -+ static half_t convert(unsigned const& n) { -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return half_t(__uint2half_rn(n)); -+ #else -+ return convert(float(n)); -+ #endif -+ } -+ -+ /// Converts a half-precision value stored as a uint16_t to a float -+ #if 
defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) -+ // Avoid inlining in device code if no hardware support -+ __device__ __noinline__ -+ #else -+ CUTLASS_HOST_DEVICE -+ #endif -+ static float convert(half_t const& x) { -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return __half2float(x.to_half()); -+ #else -+ -+ #if !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C -+ if( CpuId::instance().is_f16c_supported() ) { -+ unsigned short u = x.storage; -+ return _cvtsh_ss(u); -+ } -+ #endif -+ -+ uint16_t const &h = x.storage; -+ int sign = ((h >> 15) & 1); -+ int exp = ((h >> 10) & 0x1f); -+ int mantissa = (h & 0x3ff); -+ unsigned f = 0; -+ -+ if (exp > 0 && exp < 31) { -+ // normal -+ exp += 112; -+ f = (sign << 31) | (exp << 23) | (mantissa << 13); -+ } else if (exp == 0) { -+ if (mantissa) { -+ // subnormal -+ exp += 113; -+ while ((mantissa & (1 << 10)) == 0) { -+ mantissa <<= 1; -+ exp--; -+ } -+ mantissa &= 0x3ff; -+ f = (sign << 31) | (exp << 23) | (mantissa << 13); -+ } else { -+ // sign-preserving zero -+ f = (sign << 31); -+ } -+ } else if (exp == 31) { -+ if (mantissa) { -+ f = 0x7fffffff; // not a number -+ } else { -+ f = (0xff << 23) | (sign << 31); // inf -+ } -+ } -+ #if defined(__CUDA_ARCH__) -+ return reinterpret_cast(f); -+ #else -+ float flt; -+ std::memcpy(&flt, &f, sizeof(flt)); -+ return flt; -+ #endif -+ #endif -+ } -+ -+ // -+ // Methods -+ // -+ -+ /// Default constructor -+ half_t() = default; -+ -+ /// Reinterpret cast from CUDA's half type -+ CUTLASS_HOST_DEVICE -+ explicit half_t(half const & x) { -+ #if defined(__CUDA_ARCH__) -+ storage = reinterpret_cast(x); -+ #else -+ __half_raw raw(x); -+ std::memcpy(&storage, &raw.x, sizeof(storage)); -+ #endif -+ } -+ -+ /// Floating point conversion -+ CUTLASS_HOST_DEVICE -+ explicit half_t(float x) { -+ storage = convert(x).storage; -+ } -+ -+ /// Floating point conversion -+ CUTLASS_HOST_DEVICE -+ explicit half_t(double x): half_t(float(x)) { -+ -+ } -+ -+ /// float_e4m3_t conversion -+ CUTLASS_HOST_DEVICE -+ explicit half_t(float_e4m3_t x): half_t(float(x)) { -+ -+ } -+ -+ /// float_e5m2_t conversion -+ CUTLASS_HOST_DEVICE -+ explicit half_t(float_e5m2_t x): half_t(float(x)) { -+ -+ } -+ -+ /// Integer conversion - round to nearest even -+ CUTLASS_HOST_DEVICE -+ explicit half_t(int x) { -+ storage = convert(x).storage; -+ } -+ -+ /// Integer conversion - round toward zero -+ CUTLASS_HOST_DEVICE -+ explicit half_t(unsigned x) { -+ storage = convert(x).storage; -+ } -+ -+ /// Assignment -+ CUTLASS_HOST_DEVICE -+ half_t & operator=(half const &x) { -+ #if defined(__CUDA_ARCH__) -+ storage = reinterpret_cast(x); -+ #else -+ __half_raw raw(x); -+ std::memcpy(&storage, &raw.x, sizeof(storage)); -+ #endif -+ return *this; -+ } -+ -+ /// Converts to float -+ CUTLASS_HOST_DEVICE -+ operator float() const { -+ return convert(*this); -+ } -+ -+ /// Converts to float -+ CUTLASS_HOST_DEVICE -+ explicit operator double() const { -+ return double(convert(*this)); -+ } -+ -+ /// Converts to float -+ CUTLASS_HOST_DEVICE -+ explicit operator int() const { -+ return int(convert(*this)); -+ } -+ -+ /// Casts to bool -+ CUTLASS_HOST_DEVICE -+ explicit operator bool() const { -+ return (convert(*this) != 0.0f); -+ } -+ -+ /// Bitcasts to CUDA's half type -+ CUTLASS_HOST_DEVICE -+ half to_half() const { -+ #if defined(__CUDA_ARCH__) -+ return reinterpret_cast(storage); -+ #else -+ __half_raw raw; -+ std::memcpy(&raw.x, &storage, sizeof(raw.x)); -+ return half(raw); -+ #endif -+ } -+ -+ /// Accesses raw internal state -+ 
CUTLASS_HOST_DEVICE -+ uint16_t& raw() { -+ return storage; -+ } -+ -+ /// Accesses raw internal state -+ CUTLASS_HOST_DEVICE -+ uint16_t raw() const { -+ return storage; -+ } -+ -+ /// Returns the sign bit -+ CUTLASS_HOST_DEVICE -+ bool signbit() const { -+ return ((storage & 0x8000) != 0); -+ } -+ -+ /// Returns the biased exponent -+ CUTLASS_HOST_DEVICE -+ int exponent_biased() const { -+ return int((storage >> 10) & 0x1f); -+ } -+ -+ /// Returns the unbiased exponent -+ CUTLASS_HOST_DEVICE -+ int exponent() const { -+ return exponent_biased() - 15; -+ } -+ -+ /// Returns the mantissa -+ CUTLASS_HOST_DEVICE -+ int mantissa() const { -+ return int(storage & 0x3ff); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_HOST_DEVICE -+bool signbit(cutlass::half_t const& h) { -+ return ((h.raw() & 0x8000) != 0); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::half_t abs(cutlass::half_t const& h) { -+ return cutlass::half_t::bitcast(h.raw() & 0x7fff); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isnan(cutlass::half_t const& h) { -+ return (h.exponent_biased() == 0x1f) && h.mantissa(); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isfinite(cutlass::half_t const& h) { -+ return (h.exponent_biased() != 0x1f); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::half_t nanh(const char*) { -+ // NVIDIA canonical NaN -+ return cutlass::half_t::bitcast(0x7fff); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isinf(cutlass::half_t const& h) { -+ return (h.exponent_biased() == 0x1f) && !h.mantissa(); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isnormal(cutlass::half_t const& h) { -+ return h.exponent_biased() && h.exponent_biased() != 0x1f; -+} -+ -+CUTLASS_HOST_DEVICE -+int fpclassify(cutlass::half_t const& h) { -+ int exp = h.exponent_biased(); -+ int mantissa = h.mantissa(); -+ if (exp == 0x1f) { -+ if (mantissa) { -+ return FP_NAN; -+ } -+ else { -+ return FP_INFINITE; -+ } -+ } -+ else if (!exp) { -+ if (mantissa) { -+ return FP_SUBNORMAL; -+ } -+ else { -+ return FP_ZERO; -+ } -+ } -+ return FP_NORMAL; -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::half_t sqrt(cutlass::half_t const& h) { -+#if defined(__CUDACC_RTC__) -+ return cutlass::half_t(sqrtf(float(h))); -+#else -+ return cutlass::half_t(std::sqrt(float(h))); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+half_t copysign(half_t const& a, half_t const& b) { -+ -+ uint16_t a_mag = (a.raw() & 0x7fff); -+ uint16_t b_sign = (b.raw() & 0x8000); -+ uint16_t result = (a_mag | b_sign); -+ -+ return half_t::bitcast(result); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Standard Library operations and definitions -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if !defined(__CUDACC_RTC__) -+namespace std { -+ -+/// Numeric limits -+template <> -+struct numeric_limits { -+ static bool const is_specialized = true; -+ static bool const is_signed = true; -+ static bool const is_integer = false; -+ static bool const is_exact = false; -+ static bool const has_infinity = true; -+ static bool const has_quiet_NaN = true; -+ static bool const has_signaling_NaN = false; -+ static std::float_denorm_style const has_denorm = std::denorm_present; -+ static bool const has_denorm_loss = true; -+ static std::float_round_style const round_style = std::round_to_nearest; -+ static bool const is_iec559 = 
true; -+ static bool const is_bounded = true; -+ static bool const is_modulo = false; -+ static int const digits = 10; -+ -+ /// Least positive value -+ static cutlass::half_t min() { return cutlass::half_t::bitcast(0x0001); } -+ -+ /// Minimum finite value -+ static cutlass::half_t lowest() { return cutlass::half_t::bitcast(0xfbff); } -+ -+ /// Maximum finite value -+ static cutlass::half_t max() { return cutlass::half_t::bitcast(0x7bff); } -+ -+ /// Returns smallest finite value -+ static cutlass::half_t epsilon() { return cutlass::half_t::bitcast(0x1800); } -+ -+ /// Returns maximum rounding error -+ static cutlass::half_t round_error() { return cutlass::half_t(0.5f); } -+ -+ /// Returns positive infinity value -+ static cutlass::half_t infinity() { return cutlass::half_t::bitcast(0x7c00); } -+ -+ /// Returns quiet NaN value -+ static cutlass::half_t quiet_NaN() { return cutlass::half_t::bitcast(0x7fff); } -+ -+ /// Returns signaling NaN value -+ static cutlass::half_t signaling_NaN() { return cutlass::half_t::bitcast(0x7fff); } -+ -+ /// Returns smallest positive subnormal value -+ static cutlass::half_t denorm_min() { return cutlass::half_t::bitcast(0x0001); } -+}; -+} // namespace std -+#endif -+ -+namespace platform { -+ -+/// std::numeric_limits -+template -+struct numeric_limits; -+ -+/// Numeric limits -+template <> -+struct numeric_limits { -+ static bool const is_specialized = true; -+ static bool const is_signed = true; -+ static bool const is_integer = false; -+ static bool const is_exact = false; -+ static bool const has_infinity = true; -+ static bool const has_quiet_NaN = true; -+ static bool const has_signaling_NaN = false; -+#if !defined(__CUDACC_RTC__) -+ static std::float_denorm_style const has_denorm = std::denorm_present; -+#endif -+ static bool const has_denorm_loss = true; -+#if !defined(__CUDACC_RTC__) -+ static std::float_round_style const round_style = std::round_to_nearest; -+#endif -+ static bool const is_iec559 = true; -+ static bool const is_bounded = true; -+ static bool const is_modulo = false; -+ static int const digits = 10; -+ -+ /// Least positive value -+ CUTLASS_HOST_DEVICE -+ static cutlass::half_t min() { return cutlass::half_t::bitcast(0x0001); } -+ -+ /// Minimum finite value -+ CUTLASS_HOST_DEVICE -+ static cutlass::half_t lowest() { return cutlass::half_t::bitcast(0xfbff); } -+ -+ /// Maximum finite value -+ CUTLASS_HOST_DEVICE -+ static cutlass::half_t max() { return cutlass::half_t::bitcast(0x7bff); } -+ -+ /// Returns smallest finite value -+ CUTLASS_HOST_DEVICE -+ static cutlass::half_t epsilon() { return cutlass::half_t::bitcast(0x1800); } -+ -+ /// Returns maximum rounding error -+ CUTLASS_HOST_DEVICE -+ static cutlass::half_t round_error() { return cutlass::half_t(0.5f); } -+ -+ /// Returns positive infinity value -+ CUTLASS_HOST_DEVICE -+ static cutlass::half_t infinity() { return cutlass::half_t::bitcast(0x7c00); } -+ -+ /// Returns quiet NaN value -+ CUTLASS_HOST_DEVICE -+ static cutlass::half_t quiet_NaN() { return cutlass::half_t::bitcast(0x7fff); } -+ -+ /// Returns signaling NaN value -+ CUTLASS_HOST_DEVICE -+ static cutlass::half_t signaling_NaN() { return cutlass::half_t::bitcast(0x7fff); } -+ -+ /// Returns smallest positive subnormal value -+ CUTLASS_HOST_DEVICE -+ static cutlass::half_t denorm_min() { return cutlass::half_t::bitcast(0x0001); } -+}; -+} // namespace platform -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Arithmetic operators -+// 
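// Host-only sketch of the round-to-nearest-even FP32 -> FP16 conversion that
// half_t::convert(float) earlier in this header implements in software when
// neither the SM53+ intrinsics nor the F16C path are available. It mirrors
// that algorithm as a standalone free function purely for illustration; the
// member function above is the authoritative version. The expected encodings
// in the printfs line up with the numeric_limits constants defined just above
// (0x7bff == max(), 0x7c00 == infinity()).
#include <cstdint>
#include <cstdio>
#include <cstring>

static uint16_t float_to_half_rn(float flt) {
  uint32_t s;
  std::memcpy(&s, &flt, sizeof(s));

  uint16_t sign = uint16_t((s >> 16) & 0x8000);
  int16_t exp = int16_t(((s >> 23) & 0xff) - 127);   // unbiased exponent
  int mantissa = int(s & 0x7fffff);
  uint16_t u = 0;

  if ((s & 0x7fffffff) == 0) {
    return sign;                                      // sign-preserving zero
  }
  if (exp > 15) {
    return (exp == 128 && mantissa) ? uint16_t(0x7fff)          // NaN
                                    : uint16_t(sign | 0x7c00);  // +/- infinity
  }

  int sticky_bit = 0;
  if (exp >= -14) {
    // normal fp32 -> normal fp16
    exp = int16_t(exp + 15);
    u = uint16_t((exp & 0x1f) << 10);
    u = uint16_t(u | (mantissa >> 13));
  } else {
    // normal fp32 -> subnormal fp16
    int rshift = -14 - exp;
    if (rshift < 32) {
      mantissa |= (1 << 23);                          // make the implicit bit explicit
      sticky_bit = ((mantissa & ((1 << rshift) - 1)) != 0);
      mantissa >>= rshift;
      u = uint16_t(mantissa >> 13) & 0x3ff;
    } else {
      mantissa = 0;
    }
  }

  // round to nearest even
  int round_bit = (mantissa >> 12) & 1;
  sticky_bit |= ((mantissa & ((1 << 12) - 1)) != 0);
  if ((round_bit && sticky_bit) || (round_bit && (u & 1))) {
    u = uint16_t(u + 1);
  }
  return uint16_t(u | sign);
}

int main() {
  std::printf("1.0f     -> 0x%04x (expect 0x3c00)\n", float_to_half_rn(1.0f));
  std::printf("-2.0f    -> 0x%04x (expect 0xc000)\n", float_to_half_rn(-2.0f));
  std::printf("65504.0f -> 0x%04x (expect 0x7bff, largest finite fp16)\n", float_to_half_rn(65504.0f));
  std::printf("1.0e30f  -> 0x%04x (expect 0x7c00, overflow to +inf)\n", float_to_half_rn(1.0e30f));
  return 0;
}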
-+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_HOST_DEVICE -+bool operator==(half_t const& lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return __heq(lhs.to_half(), rhs.to_half()); -+#else -+ return float(lhs) == float(rhs); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator!=(half_t const& lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return __hne(lhs.to_half(), rhs.to_half()); -+#else -+ return float(lhs) != float(rhs); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator<(half_t const& lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return __hlt(lhs.to_half(), rhs.to_half()); -+#else -+ return float(lhs) < float(rhs); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator<=(half_t const& lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return __hle(lhs.to_half(), rhs.to_half()); -+#else -+ return float(lhs) <= float(rhs); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator>(half_t const& lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return __hgt(lhs.to_half(), rhs.to_half()); -+#else -+ return float(lhs) > float(rhs); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator>=(half_t const& lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return __hge(lhs.to_half(), rhs.to_half()); -+#else -+ return float(lhs) >= float(rhs); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+half_t operator+(half_t const& lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return half_t(__hadd(lhs.to_half(), rhs.to_half())); -+#else -+ return half_t(float(lhs) + float(rhs)); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+half_t operator-(half_t const& lhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return half_t(__hneg(lhs.to_half())); -+#else -+ return half_t(-float(lhs)); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+half_t operator-(half_t const& lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return half_t(__hsub(lhs.to_half(), rhs.to_half())); -+#else -+ return half_t(float(lhs) - float(rhs)); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+half_t operator*(half_t const& lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return half_t(__hmul(lhs.to_half(), rhs.to_half())); -+#else -+ return half_t(float(lhs) * float(rhs)); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+half_t operator/(half_t const& lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return half_t(__hdiv(lhs.to_half(), rhs.to_half())); -+#else -+ return half_t(float(lhs) / float(rhs)); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+half_t& operator+=(half_t & lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ lhs = half_t(__hadd(lhs.to_half(), rhs.to_half())); -+#else -+ lhs = half_t(float(lhs) + float(rhs)); -+#endif -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+half_t& operator-=(half_t & lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ lhs = half_t(__hsub(lhs.to_half(), rhs.to_half())); -+#else -+ lhs = half_t(float(lhs) - float(rhs)); -+#endif -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+half_t& operator*=(half_t & lhs, half_t const& rhs) { -+#if 
defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ lhs = half_t(__hmul(lhs.to_half(), rhs.to_half())); -+#else -+ lhs = half_t(float(lhs) * float(rhs)); -+#endif -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+half_t& operator/=(half_t & lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ lhs = half_t(__hdiv(lhs.to_half(), rhs.to_half())); -+#else -+ lhs = half_t(float(lhs) / float(rhs)); -+#endif -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+half_t& operator++(half_t & lhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ lhs = half_t(__hadd(lhs.to_half(), half_t(1.0f).to_half())); -+#else -+ float tmp(lhs); -+ ++tmp; -+ lhs = half_t(tmp); -+#endif -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+half_t& operator--(half_t & lhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ lhs = half_t(__hsub(lhs.to_half(), half_t(1.0f).to_half())); -+#else -+ float tmp(lhs); -+ --tmp; -+ lhs = half_t(tmp); -+#endif -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+half_t operator++(half_t & lhs, int) { -+ half_t ret(lhs); -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ lhs = half_t(__hadd(lhs.to_half(), half_t(1.0f).to_half())); -+#else -+ float tmp(lhs); -+ tmp++; -+ lhs = half_t(tmp); -+#endif -+ return ret; -+} -+ -+CUTLASS_HOST_DEVICE -+half_t operator--(half_t & lhs, int) { -+ half_t ret(lhs); -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ lhs = half_t(__hsub(lhs.to_half(), half_t(1.0f).to_half())); -+#else -+ float tmp(lhs); -+ tmp--; -+ lhs = half_t(tmp); -+#endif -+ return ret; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// User-defined literals -+// -+ -+CUTLASS_HOST_DEVICE -+cutlass::half_t operator "" _hf(long double x) { -+ return cutlass::half_t(float(x)); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::half_t operator "" _hf(unsigned long long int x) { -+ return cutlass::half_t(int(x)); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/integer_subbyte.h b/3rdparty/cutlass/include/cutlass/integer_subbyte.h -new file mode 100644 -index 0000000..f02a7d3 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/integer_subbyte.h -@@ -0,0 +1,240 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
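// Brief usage sketch for the half_t operators and the _hf user-defined
// literals defined at the end of half.h above. On host, and on devices older
// than SM53, every operator falls back through float arithmetic as the #else
// branches show; on SM53+ the __h* device intrinsics are used instead. This
// assumes the CUTLASS include path and the CUDA toolkit headers (for
// cuda_fp16.h) are available to the host compiler.
#include <cstdio>
#include "cutlass/half.h"

int main() {
  using cutlass::half_t;

  half_t a = 2.5_hf;                // from a long double literal
  half_t b = 3_hf;                  // from an integer literal
  half_t c = a * b + half_t(0.5f);  // 8.0, computed through float on host

  c /= half_t(2.0f);                // 4.0

  std::printf("c = %f  signbit = %d  isfinite = %d\n",
              float(c), int(cutlass::signbit(c)), int(cutlass::isfinite(c)));
  return 0;
}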
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Defines a class for using integer types smaller than one byte in host or -+ device code. -+*/ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/platform/platform.h" -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 4-bit signed integer type -+template -+struct integer_subbyte { -+ -+ /// Number of bits -+ static int const kBits = Bits; -+ -+ /// Whether type is signed -+ static bool const kSigned = Signed; -+ -+ /// External type -+ using T = typename platform::conditional::type; -+ -+ /// Storage type -+ using Storage = uint8_t; -+ -+ /// Bitmask used to truncate from larger integers -+ static Storage const kMask = Storage((1 << kBits) - 1); -+ -+ // -+ // Data members -+ // -+ -+ Storage storage; -+ -+ // -+ // Methods -+ // -+ -+ /// No operation -+ integer_subbyte() = default; -+ -+ /// Conversion from integer type -+ CUTLASS_HOST_DEVICE -+ integer_subbyte(int value) -+ : storage(reinterpret_cast(value) & kMask) {} -+ -+ CUTLASS_HOST_DEVICE -+ integer_subbyte(unsigned value) -+ : storage(reinterpret_cast(value) & kMask) {} -+ -+ CUTLASS_HOST_DEVICE -+ integer_subbyte(double value) { -+ T tmp = static_cast(value); -+ storage = Storage(reinterpret_cast(tmp) & kMask); -+ } -+ -+ /// -+ CUTLASS_HOST_DEVICE -+ operator T() const { -+ if (kSigned) { -+ // Sign extend -+ if (storage & Storage(1 << (kBits - 1))) { -+ return T(storage) | ~T(kMask); -+ } -+ } -+ return T(storage); -+ } -+ -+ /// Equality -+ CUTLASS_HOST_DEVICE -+ bool operator==(integer_subbyte const &rhs) const { -+ return storage == rhs.storage; -+ } -+ -+ /// Inequality -+ CUTLASS_HOST_DEVICE -+ bool operator!=(integer_subbyte const &rhs) const { -+ return storage != rhs.storage; -+ } -+ -+ /// Less than or equal -+ CUTLASS_HOST_DEVICE -+ bool operator<=(integer_subbyte const &rhs) const { -+ if (kSigned) { -+ if (storage & (1 << (kBits - 1))) { -+ return !(rhs.storage < storage); -+ } -+ } -+ return storage < rhs.storage; -+ } -+ -+ /// Less than -+ CUTLASS_HOST_DEVICE -+ bool operator<(integer_subbyte const &rhs) const { -+ if (kSigned) { -+ if (storage & (1 << (kBits - 1))) { -+ return !(rhs.storage <= storage); -+ } -+ } -+ return storage < rhs.storage; -+ } -+ -+ /// Greater than or equal -+ CUTLASS_HOST_DEVICE -+ bool operator>=(integer_subbyte const &rhs) const { -+ return !(*this < rhs); -+ } -+ -+ /// Greater than -+ CUTLASS_HOST_DEVICE -+ bool operator>(integer_subbyte const &rhs) const { -+ return !(*this <= rhs); -+ } -+}; -+ 
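// Host-side sketch of the truncate-and-sign-extend behaviour of the
// integer_subbyte template defined above: the constructors mask the input to
// kBits of storage, and operator T() sign-extends when kSigned is true, so
// out-of-range inputs wrap. The convenience aliases int4b_t / uint4b_t appear
// just below in this header; the raw template is used here so the sketch only
// depends on what has been defined so far. Assumes the CUTLASS include path.
#include <cstdio>
#include "cutlass/integer_subbyte.h"

int main() {
  using int4  = cutlass::integer_subbyte<4, true>;    // becomes int4b_t below
  using uint4 = cutlass::integer_subbyte<4, false>;   // becomes uint4b_t below

  int4  a(7);    // 0b0111 -> 7
  int4  b(-3);   // masked to 0b1101, sign-extended back to -3
  int4  c(9);    // 9 & 0xF = 0b1001, sign bit set -> -7 (wraps)
  uint4 d(9);    // unsigned 4-bit: stays 9

  std::printf("a=%d  b=%d  c=%d  d=%u\n", int(a), int(b), int(c), unsigned(d));
  return 0;
}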
-+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// 1-bit Unsigned integer type -+using uint1b_t = integer_subbyte<1, false>; -+ -+/// 2-bit Integer type -+using int2b_t = integer_subbyte<2, true>; -+ -+/// 2-bit Unsigned integer type -+using uint2b_t = integer_subbyte<2, false>; -+ -+/// 4-bit Integer type -+using int4b_t = integer_subbyte<4, true>; -+ -+/// 4-bit Unsigned integer type -+using uint4b_t = integer_subbyte<4, false>; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the size of an element in bits - specialized for uint1b_t -+template <> -+struct sizeof_bits { -+ static int const value = 1; -+}; -+ -+/// Defines the size of an element in bits - specialized for int2b_t -+template <> -+struct sizeof_bits { -+ static int const value = 2; -+}; -+ -+/// Defines the size of an element in bits - specialized for uint2b_t -+template <> -+struct sizeof_bits { -+ static int const value = 2; -+}; -+ -+/// Defines the size of an element in bits - specialized for int4b_t -+template <> -+struct sizeof_bits { -+ static int const value = 4; -+}; -+ -+/// Defines the size of an element in bits - specialized for uint4b_t -+template <> -+struct sizeof_bits { -+ static int const value = 4; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace platform { -+ -+template <> -+struct numeric_limits { -+ CUTLASS_HOST_DEVICE -+ static cutlass::int4b_t const lowest() noexcept { return -8;} -+ CUTLASS_HOST_DEVICE -+ static cutlass::int4b_t const max() noexcept { return 7;} -+ static constexpr bool is_integer = true; -+}; -+ -+template <> -+struct numeric_limits { -+ CUTLASS_HOST_DEVICE -+ static cutlass::uint4b_t const lowest() noexcept { return 0;} -+ CUTLASS_HOST_DEVICE -+ static cutlass::uint4b_t const max() noexcept { return 15;} -+ static constexpr bool is_integer = true; -+}; -+ -+template <> -+struct numeric_limits { -+ CUTLASS_HOST_DEVICE -+ static cutlass::uint1b_t const lowest() noexcept { return 0;} -+ CUTLASS_HOST_DEVICE -+ static cutlass::uint1b_t const max() noexcept { return 1;} -+ static constexpr bool is_integer = true; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace platform -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/kernel_hardware_info.hpp b/3rdparty/cutlass/include/cutlass/kernel_hardware_info.hpp -new file mode 100644 -index 0000000..3ae0932 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/kernel_hardware_info.hpp -@@ -0,0 +1,71 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cuda_runtime.h" -+ -+#include "cutlass/trace.h" -+ -+namespace cutlass { -+ -+struct KernelHardwareInfo { -+ // -+ // Data members -+ // -+ int device_id = 0; -+ int sm_count = 0; -+ -+ // -+ // Methods -+ // -+ -+ static int -+ query_device_multiprocessor_count(int device_id = 0) { -+ cudaError_t result = cudaGetDevice(&device_id); -+ if (result != cudaSuccess) { -+ CUTLASS_TRACE_HOST( -+ " cudaGetDevice() returned error " -+ << cudaGetErrorString(result)); -+ return 0; -+ } -+ cudaDeviceProp properties; -+ result = cudaGetDeviceProperties(&properties, device_id); -+ if (result != cudaSuccess) { -+ CUTLASS_TRACE_HOST( -+ " cudaGetDeviceProperties() returned error " -+ << cudaGetErrorString(result)); -+ return 0; -+ } -+ return properties.multiProcessorCount; -+ } -+}; -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/kernel_launch.h b/3rdparty/cutlass/include/cutlass/kernel_launch.h -new file mode 100644 -index 0000000..c54f1fa ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/kernel_launch.h -@@ -0,0 +1,73 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
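// Short usage sketch for the KernelHardwareInfo struct above: CUTLASS 3.x
// kernels that use persistent tile schedulers take the SM count as part of
// their arguments, and query_device_multiprocessor_count() retrieves it once
// on the host via the CUDA runtime. Link against cudart; the Arguments line in
// the trailing comment is illustrative only and not part of this header.
#include <cstdio>
#include "cutlass/kernel_hardware_info.hpp"

int main() {
  int device_id = 0;

  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = device_id;
  hw_info.sm_count =
      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(device_id);

  if (hw_info.sm_count == 0) {
    std::printf("CUDA device query failed\n");
    return 1;
  }
  std::printf("device %d reports %d SMs\n", hw_info.device_id, hw_info.sm_count);

  // hw_info is then typically forwarded through a kernel's Arguments struct,
  // e.g. typename Gemm::Arguments args{..., hw_info};  (illustrative only)
  return 0;
}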
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines structures and helpers to launch CUDA kernels within CUTLASS. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure containing the basic launch configuration of a CUDA kernel. -+struct KernelLaunchConfiguration { -+ -+ /// CUDA grid dimensions -+ dim3 grid; -+ -+ /// CUDA threablock dimensions -+ dim3 block; -+ -+ /// Bytes of dynamically allocated SMEM in addition to static SMEM -+ size_t dynamic_smem; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a KernellaunchConfiguration object -+ CUTLASS_HOST_DEVICE -+ KernelLaunchConfiguration( -+ dim3 _grid = dim3(1,1,1), -+ dim3 _block = dim3(1,1,1), -+ size_t _dynamic_smem = 0 -+ ): -+ grid(_grid), -+ block(_block), -+ dynamic_smem(_dynamic_smem) { } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/layout/layout.h b/3rdparty/cutlass/include/cutlass/layout/layout.h -new file mode 100644 -index 0000000..6f638eb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/layout/layout.h -@@ -0,0 +1,64 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
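// Minimal sketch of driving a CUDA launch from the KernelLaunchConfiguration
// struct above. The fill_kernel here is a stand-in written for illustration;
// only the KernelLaunchConfiguration type comes from this header. Compile with
// nvcc so the <<<...>>> launch syntax and __global__ are available.
#include <cstdio>
#include "cutlass/kernel_launch.h"

__global__ void fill_kernel(int *out, int value) {
  out[blockIdx.x * blockDim.x + threadIdx.x] = value;
}

int main() {
  cutlass::KernelLaunchConfiguration config(
      dim3(4, 1, 1),     // grid
      dim3(128, 1, 1),   // threadblock
      0);                // no dynamic shared memory beyond static SMEM

  int *out = nullptr;
  cudaMalloc(&out, 4 * 128 * sizeof(int));

  fill_kernel<<<config.grid, config.block, config.dynamic_smem>>>(out, 42);
  cudaError_t result = cudaDeviceSynchronize();
  std::printf("launch %s\n",
              result == cudaSuccess ? "succeeded" : cudaGetErrorString(result));

  cudaFree(out);
  return 0;
}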
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines layout functions used by TensorRef and derived classes. -+ -+ Layout functions map logical coordinates to linear memory. They often require additional -+ data to describe strides between elements. -+ -+ Layout functions must implement all members in the public interface of IdentityTensorLayout<> -+ defined in cutlass/tensor_ref.h. -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/vector.h" -+ -+#include "cutlass/layout/tensor_op_multiplicand_sm70.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace layout { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace layout -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/layout/matrix.h b/3rdparty/cutlass/include/cutlass/layout/matrix.h -new file mode 100644 -index 0000000..fe7a848 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/layout/matrix.h -@@ -0,0 +1,1371 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines layout functions used by TensorRef and derived classes. -+ -+ Layout functions map logical coordinates to linear memory. They often require additional -+ data to describe strides between elements. -+ -+ Layout functions must implement all members in the public interface of IdentityTensorLayout<> -+ defined in cutlass/tensor_ref.h. -+*/ -+#pragma once -+ -+#include "cute/layout.hpp" -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/pitch_linear_coord.h" -+ -+namespace cutlass { -+namespace layout { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Defines data layouts of various matrix formats usable by TensorRef and other classes. -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Mapping function for row-major matrices. -+class RowMajor { -+public: -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ RowMajor(LongIndex ldm = 0): stride_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajor(Stride stride): stride_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajor packed(MatrixCoord const &extent) { -+ return RowMajor(extent.column()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (row, column) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord const &coord) const { -+ return LongIndex(coord.row()) * LongIndex(stride_[0]) + coord.column(); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ MatrixCoord inverse(LongIndex offset) const { -+ return MatrixCoord(Index(offset / stride_[0]), Index(offset % stride_[0])); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index stride(int idx) const { -+ return stride_[idx]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index & stride(int idx) { -+ return stride_[idx]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(MatrixCoord const &extent) const { -+ return LongIndex(extent.row()) * LongIndex(stride_[0]); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ cute::Layout, cute::Stride > > -+ to_cute_layout(MatrixCoord const &extent) const { -+ return cute::Layout, cute::Stride > >{ -+ {extent[0], extent[1]}, -+ {stride(0), cute::Int<1>{}} -+ }; -+ } -+}; -+ -+/// Mapping function for column-major matrices. -+class ColumnMajor { -+public: -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajor(LongIndex ldm = 0): stride_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajor(Stride stride): stride_(stride) { } -+ -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajor packed(MatrixCoord const &extent) { -+ return ColumnMajor(extent.row()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (row, column) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord const &coord) const { -+ return LongIndex(coord.column()) * LongIndex(stride_[0]) + coord.row(); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ MatrixCoord inverse(LongIndex offset) const { -+ return MatrixCoord(Index(offset % stride_[0]), Index(offset / stride_[0])); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index stride(int idx) const { -+ return stride_[idx]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index & stride(int idx) { -+ return stride_[idx]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(MatrixCoord const &extent) const { -+ return LongIndex(extent.column()) * LongIndex(stride_[0]); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ cute::Layout, cute::Stride< cute::Int<1>, int64_t> > -+ to_cute_layout(MatrixCoord const &extent) const { -+ return cute::Layout, cute::Stride, int64_t> >{ -+ {extent[0], extent[1]}, -+ {cute::Int<1>{}, stride(0)} -+ }; -+ } -+}; -+ -+/// Mapping function for interleaved matrices. Matrix is structured -+/// as row-major arrangement of fixed-size columns. -+template -+struct RowMajorInterleaved { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ /// Size of interleaved columns -+ static int const kInterleave = Interleave; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorInterleaved(LongIndex ldm = 0): stride_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorInterleaved(Stride stride): stride_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajorInterleaved packed(MatrixCoord const &extent) { -+ return RowMajorInterleaved(extent.column() * kInterleave); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (row, column) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord const &coord) const { -+ Index row_major = coord.row() / kInterleave; -+ Index row_minor = coord.row() % kInterleave; -+ return LongIndex(row_major) * LongIndex(stride_[0]) + LongIndex(coord.column()) * kInterleave + row_minor; -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ MatrixCoord inverse(LongIndex offset) const { -+ -+ Index row_major = Index(offset / stride_[0]); -+ Index residual = Index(offset % stride_[0]); -+ -+ Index column = residual / kInterleave; -+ Index row_minor = residual % kInterleave; -+ -+ return MatrixCoord(row_major * kInterleave + row_minor, column); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index stride(int idx) const { -+ return stride_[idx]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index & stride(int idx) { -+ return stride_[idx]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(MatrixCoord const &extent) const { -+ return (extent.row() + kInterleave - 1) / kInterleave * stride_[0]; -+ } -+}; -+ -+/// Mapping function for interleaved matrices. Matrix is structured -+/// as column-major arrangement of fixed-size rows. -+template -+struct ColumnMajorInterleaved { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ /// Size of interleaved columns -+ static int const kInterleave = Interleave; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorInterleaved(LongIndex ldm = 0): stride_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorInterleaved(Stride stride): stride_(stride) { } -+ -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajorInterleaved packed(MatrixCoord const &extent) { -+ return ColumnMajorInterleaved(extent.row() * kInterleave); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (row, column) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord const &coord) const { -+ Index column_major = coord.column() / kInterleave; -+ Index column_minor = coord.column() % kInterleave; -+ return LongIndex(column_major) * LongIndex(stride_[0]) + LongIndex(coord.row()) * kInterleave + column_minor; -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ MatrixCoord inverse(LongIndex offset) const { -+ -+ Index column_major = Index(offset / stride_[0]); -+ Index residual = Index(offset % stride_[0]); -+ -+ Index row = residual / kInterleave; -+ Index column_minor = residual % kInterleave; -+ -+ return MatrixCoord(row, column_major * kInterleave + column_minor); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index stride(int idx) const { -+ return stride_[idx]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index & stride(int idx) { -+ return stride_[idx]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(MatrixCoord const &extent) const { -+ return (extent.column() + kInterleave - 1) / kInterleave * stride_[0]; -+ } -+}; -+ -+/// Enumerated type for canonical pitch-linear matrix layouts -+enum class Matrix { -+ kColumnMajor, ///< leading dimension refers to stride between columns; stride along rows is 1 -+ kRowMajor ///< leading dimension refers to stride between rows; stride along columns is 1 -+}; -+ -+/// Mapping function for scenario in which layout is row-major or column-major but this information -+/// is only available at runtime. -+struct ContiguousMatrix { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+ /// Enumerated type indicating canonical matrix layout -+ Matrix layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ContiguousMatrix( -+ Index ldm = 0, -+ Matrix layout = Matrix::kColumnMajor -+ ): -+ stride_(ldm), layout_(layout) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ContiguousMatrix packed( -+ MatrixCoord const &extent, -+ Matrix layout = Matrix::kColumnMajor) { -+ -+ Index ldm = 0; -+ if (layout == Matrix::kColumnMajor) { -+ ldm = extent.row(); -+ } -+ else if (layout == Matrix::kRowMajor) { -+ ldm = extent.column(); -+ } -+ return ContiguousMatrix(ldm, layout); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (row, column) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord const &coord) const { -+ if (layout_ == Matrix::kColumnMajor) { -+ return coord.row() + coord.column() * stride_[0]; -+ } -+ else if (layout_ == Matrix::kRowMajor) { -+ return coord.row() * stride_[0] + coord.column(); -+ } -+ else { -+ // degenerate case -+ return 0; -+ } -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ MatrixCoord inverse(LongIndex offset) const { -+ // TODO -+ return MatrixCoord(0, 0); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index stride(int idx) const { -+ return stride_[idx]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index & stride(int idx) { -+ return stride_[idx]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(MatrixCoord const &extent) const { -+ if (layout_ == Matrix::kColumnMajor) { -+ return stride_[0] * extent.column(); -+ } -+ else if (layout_ == Matrix::kRowMajor) { -+ return stride_[0] * extent.row(); -+ } -+ else { -+ // degenerate case -+ return 0; -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Mapping function for scenario in which both rows and columns are separated by a stride. -+template -+struct AffineRankN { -+ -+ /// Logical rank of tensor -+ static int const kRank = Rank; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = kRank; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = Coord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ AffineRankN( -+ Stride const &stride = Stride() -+ ): -+ stride_(stride) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ AffineRankN( -+ Coord const &stride_m, -+ Coord const &stride_n -+ ) { -+ -+ // Concatenate the strides -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kRank/2; ++m) { -+ stride_[m] = stride_m[m]; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kRank/2; ++n) { -+ stride_[n + kRank/2] = stride_n[n]; -+ } -+ } -+ -+ /// Ctor for N = 2 -+ CUTLASS_HOST_DEVICE -+ AffineRankN( -+ LongIndex const &stride_m, -+ LongIndex const &stride_n -+ ) { -+ stride_[0] = stride_m; -+ stride_[1] = stride_n; -+ } -+ -+ /// Ctor for N = 2 -+ CUTLASS_HOST_DEVICE -+ AffineRankN( -+ LongIndex const &stride -+ ) { -+ stride_[0] = stride; -+ stride_[1] = 1; -+ } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static AffineRankN packed(TensorCoord const &extent) { -+ -+ AffineRankN layout; -+ layout.stride_[kRank - 1] = 1; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = kRank - 1; i > 0; --i) { -+ layout.stride_[i - 1] = layout.stride_[i] * extent[i]; -+ } -+ -+ return layout; -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (row, column) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return dot(coord, stride_); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ // TODO -+ return TensorCoord(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index stride(int idx) const { -+ return stride_[idx]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index & stride(int idx) { -+ return stride_[idx]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ int idx = stride_.max_dim_index(); -+ return extent[idx] * stride_[idx]; -+ } -+}; -+ -+/// Mapping function for scenario in which both rows and columns are separated by a stride. -+/// Row stride is smaller than column stride in AffineRank2ColumnMajor. -+struct AffineRank2ColumnMajor { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 2; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ AffineRank2ColumnMajor( -+ Stride const &stride = Stride() -+ ): -+ stride_(stride) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ AffineRank2ColumnMajor( -+ LongIndex row_stride, ///< stride between elements in consecutive rows -+ LongIndex column_stride ///< stride between elements in consecutive columns -+ ) -+ { stride_[0] = row_stride; stride_[1] = column_stride;} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ AffineRank2ColumnMajor( -+ LongIndex stride -+ ) -+ { stride_[0] = 1; stride_[1] = stride;} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static AffineRank2ColumnMajor packed(MatrixCoord const &extent) { -+ return AffineRank2ColumnMajor(extent.column(), 1); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (row, column) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord const &coord) const { -+ return dot(coord, stride_); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ MatrixCoord inverse(LongIndex offset) const { -+ // TODO -+ return MatrixCoord(0, 0); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index stride(int idx) const { -+ return stride_[idx]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index & stride(int idx) { -+ return stride_[idx]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(MatrixCoord const &extent) const { -+ return extent.column() * stride_[1]; -+ } -+}; -+ -+/// Mapping function for scenario in which both rows and columns are separated by a stride. -+/// Column stride is smaller than row stride in AffineRank2RowMajor. -+struct AffineRank2RowMajor { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 2; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ AffineRank2RowMajor( -+ Stride const &stride = Stride() -+ ): -+ stride_(stride) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ AffineRank2RowMajor( -+ LongIndex row_stride, ///< stride between elements in consecutive rows -+ LongIndex column_stride ///< stride between elements in consecutive columns -+ ) { stride_[0] = row_stride; stride_[1] = column_stride;} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ AffineRank2RowMajor( -+ LongIndex stride -+ ) { stride_[0] = stride; stride_[1] = 1;} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static AffineRank2RowMajor packed(MatrixCoord const &extent) { -+ return AffineRank2RowMajor(extent.column(), 1); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (row, column) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord const &coord) const { -+ return dot(coord, stride_); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ MatrixCoord inverse(LongIndex offset) const { -+ // TODO -+ return MatrixCoord(0, 0); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index stride(int idx) const { -+ return stride_[idx]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index & stride(int idx) { -+ return stride_[idx]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(MatrixCoord const &extent) const { -+ return extent.row() * stride_[0]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Utility functions to convert stride_factor to the strides used by the Affine2 layout. -+// -+// stride_factor is the logical distance between two coorinates. -+// -+// All Coodinates used here are matrix coordinates. stride[0] and extent[0] are for the -+// rows. stride[1] and extent[1] are for the columns. -+template -+ struct Affine2Layout_Factory { -+ CUTLASS_HOST_DEVICE -+ static Affine2Layout layout_factory(cutlass::Coord<2> const &extent, typename Affine2Layout::Stride stride_factor) { -+ return Affine2Layout::packed(extent); -+ } -+}; -+ -+template <> -+struct Affine2Layout_Factory { -+CUTLASS_HOST_DEVICE -+static cutlass::layout::AffineRank2ColumnMajor layout_factory( -+ cutlass::Coord<2> const &extent, -+ typename cutlass::layout::AffineRank2ColumnMajor::Stride stride_factor) { -+ return cutlass::layout::AffineRank2ColumnMajor({ stride_factor[0], stride_factor[0] * stride_factor[1] * extent[0] }); -+ } -+}; -+ -+template <> -+struct Affine2Layout_Factory { -+CUTLASS_HOST_DEVICE -+static cutlass::layout::AffineRank2RowMajor layout_factory( -+ cutlass::Coord<2> const &extent, -+ typename cutlass::layout::AffineRank2RowMajor::Stride stride_factor) { -+ return cutlass::layout::AffineRank2RowMajor({ stride_factor[0] * stride_factor[1] * extent[1], stride_factor[1] }); -+ } -+}; -+ -+// The base layout cutlass::layout::AffineRankN<2> is similar to AffineRank2ColumnMajor -+template <> -+struct Affine2Layout_Factory> { -+CUTLASS_HOST_DEVICE -+static cutlass::layout::AffineRankN<2> layout_factory( -+ cutlass::Coord<2> const &extent, -+ typename cutlass::layout::AffineRankN<2>::Stride stride_factor) { -+ return cutlass::layout::AffineRankN<2>({ stride_factor[0], stride_factor[0] * stride_factor[1] * extent[0] }); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Mapping function for block-linear matrices. Matrix is structured -+/// as column-major arrangement of 2D tiles (that are column-major). 
-+template <int BlockRows, int BlockColumns>
-+struct ColumnMajorBlockLinear {
-+  /// Logical rank of tensor
-+  static int const kRank = 2;
-+
-+  /// Rank of stride vector
-+  static int const kStrideRank = 1;
-+
-+  /// Index type used for coordinates
-+  using Index = int32_t;
-+
-+  /// Long index type used for offsets
-+  using LongIndex = int64_t;
-+
-+  /// Logical coordinate
-+  using TensorCoord = MatrixCoord;
-+
-+  /// Stride vector
-+  using Stride = Coord;
-+
-+  /// Size of a block in rows
-+  static int const kBlockRows = BlockRows;
-+
-+  /// Size of a block in columns
-+  static int const kBlockColumns = BlockColumns;
-+
-+private:
-+  //
-+  // Data members
-+  //
-+
-+  /// Stride data member
-+  Stride stride_;
-+
-+public:
-+  //
-+  // Methods
-+  //
-+
-+  /// Ctor
-+  CUTLASS_HOST_DEVICE
-+  ColumnMajorBlockLinear(Index ldm = 0): stride_(ldm) { }
-+
-+  /// Helper returns a layout to a tightly packed tensor
-+  CUTLASS_HOST_DEVICE
-+  static ColumnMajorBlockLinear packed(MatrixCoord const &extent) {
-+    return ColumnMajorBlockLinear(extent.row() * kBlockRows * kBlockColumns);
-+  }
-+
-+  /// Returns the offset of a coordinate in linear memory.
-+  /// Assumes coordinate has convention (row, column)
-+  CUTLASS_HOST_DEVICE
-+  LongIndex operator()(MatrixCoord const &coord) const {
-+    return
-+      (coord.row() % kBlockRows) +
-+      (coord.column() % kBlockColumns) * kBlockRows +
-+      (coord.row() / kBlockRows) * kBlockRows * kBlockColumns +
-+      (coord.column() / kBlockColumns) * stride_[0];
-+  }
-+
-+  /// Inverse of layout function, mapping linear offset to logical coordinate
-+  CUTLASS_HOST_DEVICE
-+  MatrixCoord inverse(LongIndex offset) const {
-+
-+    // TODO
-+    return MatrixCoord(0, 0);
-+  }
-+
-+  /// Returns the stride of the layout
-+  CUTLASS_HOST_DEVICE
-+  Stride stride() const {
-+    return stride_;
-+  }
-+
-+  /// Returns the stride of the layout
-+  CUTLASS_HOST_DEVICE
-+  Stride & stride() {
-+    return stride_;
-+  }
-+
-+  /// Returns the stride of the layout
-+  CUTLASS_HOST_DEVICE
-+  typename Stride::Index stride(int idx) const {
-+    return stride_[idx];
-+  }
-+
-+  /// Returns the stride of the layout
-+  CUTLASS_HOST_DEVICE
-+  typename Stride::Index & stride(int idx) {
-+    return stride_[idx];
-+  }
-+
-+  /// Compute the number of contiguous elements needed to store a tensor with the given size
-+  CUTLASS_HOST_DEVICE
-+  LongIndex capacity(MatrixCoord const &extent) const {
-+    return (extent.column() + kBlockColumns - 1) / kBlockColumns * stride_[0];
-+  }
-+};
-+
-+/// Mapping function for block-linear matrices.
Matrix is structured -+/// as row-major arrangement of 2D tiles (that are row-major) -+template -+struct RowMajorBlockLinear { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ /// Size of a block in rows -+ static int const kBlockRows = BlockRows; -+ -+ /// Size of a block in columns -+ static int const kBlockColumns = BlockColumns; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorBlockLinear(Index ldm = 0): stride_(ldm) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajorBlockLinear packed(MatrixCoord const &extent) { -+ return RowMajorBlockLinear(extent.column() * kBlockRows * kBlockColumns); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (row, column) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord const &coord) const { -+ return -+ (coord.column() % kBlockColumns) + -+ (coord.row() % kBlockRows) * kBlockColumns + -+ (coord.column() / kBlockColumns) * kBlockRows * kBlockColumns + -+ (coord.row() / kBlockRows) * stride_[0]; -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ MatrixCoord inverse(LongIndex offset) const { -+ // TODO -+ return MatrixCoord(0, 0); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index stride(int idx) const { -+ return stride_[idx]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index & stride(int idx) { -+ return stride_[idx]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(MatrixCoord const &extent) const { -+ return (extent.row() + kBlockRows - 1) / kBlockRows * stride_[0]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct GeneralMatrix { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 2; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ Matrix layout_id_; -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ GeneralMatrix(): layout_id_(Matrix::kColumnMajor), stride_(make_Coord(0, 1)) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ GeneralMatrix( -+ Matrix layout_id, -+ Index ldm, -+ Index interleave): layout_id_(layout_id), stride_(make_Coord(ldm, interleave)) { } -+ -+ /// Helper returns a layout to a tightly packed tensor 
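For the block-linear layouts defined in this hunk, the offset is the element's position inside its block plus the block coordinate scaled by the block size and the leading stride. A standalone sketch (hypothetical helper, independent of the patch) of the RowMajorBlockLinear arithmetic:

// Sketch only: row-major arrangement of row-major BlockRows x BlockColumns tiles.
#include <cassert>
#include <cstdint>

template <int BlockRows, int BlockColumns>
int64_t row_major_block_linear_offset(int row, int column, int64_t stride) {
  return (column % BlockColumns) +
         (row % BlockRows) * BlockColumns +
         (column / BlockColumns) * BlockRows * BlockColumns +
         (row / BlockRows) * stride;
}

int main() {
  // 4x4 matrix of 2x2 blocks with the packed leading stride columns * BlockRows * BlockColumns = 16:
  // coordinate (3, 1) is element 3 of block (1, 0), giving 3 + 0 + 16 = 19.
  assert((row_major_block_linear_offset<2, 2>(3, 1, 16)) == 19);
  return 0;
}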
-+ CUTLASS_HOST_DEVICE -+ static GeneralMatrix packed( -+ MatrixCoord const &extent, -+ Matrix layout_id = Matrix::kColumnMajor, -+ Index interleave = 1) { -+ -+ Index c; -+ if (layout_id == Matrix::kRowMajor) { -+ c = extent.column(); -+ } -+ else { -+ c = extent.row(); -+ } -+ -+ Index ldm = c * interleave; -+ -+ return GeneralMatrix(layout_id, ldm, interleave); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (row, column) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord const &coord) const { -+ Index c, s; -+ if (layout_id_ == Matrix::kRowMajor) { -+ c = coord.column(); -+ s = coord.row(); -+ } -+ else { -+ s = coord.column(); -+ c = coord.row(); -+ } -+ -+ Index v = s / stride_[1]; -+ Index residual = (s % stride_[1]); -+ -+ return LongIndex(c) * LongIndex(stride_[1]) + LongIndex(v) * LongIndex(stride_[0]) + residual; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix layout_id() const { -+ return layout_id_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix & layout_id() { -+ return layout_id_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index stride(int idx) const { -+ return stride_[idx]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index & stride(int idx) { -+ return stride_[idx]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(MatrixCoord const &extent) const { -+ Index s; -+ if (layout_id_ == Matrix::kRowMajor) { -+ s = extent.row(); -+ } -+ else { -+ s = extent.column(); -+ } -+ -+ Index v = Index((s + stride_[1] - 1) / stride_[1]); -+ return LongIndex(v) * LongIndex(stride_[0]); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines transposes of matrix layouts -+template -+struct LayoutTranspose; -+ -+/// Transpose of row-major is column-major -+template <> -+struct LayoutTranspose { -+ using type = layout::ColumnMajor; -+}; -+ -+/// Transpose of column-major is row-major -+template <> -+struct LayoutTranspose { -+ using type = layout::RowMajor; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace layout -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/layout/permute.h b/3rdparty/cutlass/include/cutlass/layout/permute.h -new file mode 100644 -index 0000000..693425b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/layout/permute.h -@@ -0,0 +1,314 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines layout functions used by GEMM+permute path for common tensor or matrix formats. -+ -+ Like Layout functions, permute layout functions map logical coordinates to linear memory. They often require additional -+ data to describe strides between elements. -+ -+ Permute layout functions must implement all members in the interface of NoPermute<> defined in this file. Address offset -+ computation lies in operator() with private member variables {col_permute_, row_permute_ and stride_permute_} as new addresses after permute op. -+*/ -+#pragma once -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include "assert.h" -+#endif -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/coord.h" -+#include "cutlass/tensor_coord.h" -+ -+namespace cutlass { -+namespace layout { -+ -+class NoPermute { -+public: -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+private: -+ // -+ // Data members -+ // -+ -+ MatrixCoord extent_; -+ -+ Index stride_unit_; // sizeof(AccessType) / kElementsPerAccess in epilogue's predicated_tile_iterator -+ -+ Index stride_permute_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ NoPermute() { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ NoPermute(MatrixCoord extent, Index stride_init): extent_(extent) { } -+ -+ /// Computes the address offset after Permute Op in Bytes -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord offset_init) { return 0; } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Defines permute layouts of various tensor formats. -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Permute layout function for 4-D permuted tensors with output matrix (dimension as [M, N]) reshaped -+/// as [M/D1, D1, D2, N/D2]. Then perform permute([0, 2, 1, 3]) on the corresponding output tensor. 
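Concretely, the index arithmetic this layout implements can be checked on the host with a standalone sketch (hypothetical function name, not part of CUTLASS or of this patch): split the row into (i, j) by D1 and the column into (k, l) by D3 = N / D2, then swap j and k.

// Sketch only: reshape [M, N] as [M/D1, D1, D2, N/D2], permute([0, 2, 1, 3]), return the linear offset.
#include <cassert>
#include <cstdint>

template <int D1, int D2>
int64_t permute0213_offset(int row, int column, int columns_total) {
  int const D3 = columns_total / D2;
  int i = row / D1, j = row % D1;        // row    -> (i, j)
  int k = column / D3, l = column % D3;  // column -> (k, l)
  int64_t row_permute = k + int64_t(i) * D2;
  int64_t col_permute = l + int64_t(j) * D3;
  int64_t stride_permute = int64_t(D1) * D3;  // leading stride of the permuted [.., D1 * D3] view
  return row_permute * stride_permute + col_permute;
}

int main() {
  // M = 4, N = 6, D1 = 2, D2 = 3 (so D3 = 2): element (3, 5) lands at offset 23 after the permute.
  assert((permute0213_offset<2, 3>(3, 5, 6)) == 23);
  return 0;
}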
-+template <int D1, int D2>
-+class Tensor4DPermute0213 {
-+public:
-+  /// Index type used for coordinates
-+  using Index = int32_t;
-+
-+  /// Long index type used for offsets
-+  using LongIndex = int64_t;
-+
-+private:
-+  //
-+  // Data members
-+  //
-+
-+  MatrixCoord extent_;
-+
-+  Index stride_permute_;
-+
-+public:
-+  //
-+  // Methods
-+  //
-+
-+  /// Constructor
-+  CUTLASS_HOST_DEVICE
-+  Tensor4DPermute0213() { }
-+
-+  /// Constructor
-+  CUTLASS_HOST_DEVICE
-+  Tensor4DPermute0213(MatrixCoord extent, Index stride_init): extent_(extent) {
-+
-+    /// Update stride_permute with stride_init
-+    stride_permute_ = stride_init / D2 * D1; // stride in Elements
-+
-+  }
-+
-+  /// Computes the address offset after Permute Op in Bytes
-+  CUTLASS_HOST_DEVICE
-+  LongIndex operator()(MatrixCoord offset_init) {
-+    // Permute as torch.permute(X1, [0, 2, 1, 3]) -> 4D Tensor indices as [i,j,k,l], the dimension of X
-+    // is [D0, D1, D2, D3], after permutation the dim of X1 is [D0, D2, D1, D3].
-+    assert(extent_.row() % D1 == 0);
-+    assert(extent_.column() % D2 == 0);
-+
-+    int D3 = extent_.column() / D2;
-+
-+    Index col_init = offset_init.column();
-+    Index row_init = offset_init.row();
-+
-+    int l = col_init % D3;
-+    int k = col_init / D3;
-+    int j = row_init % D1;
-+    int i = row_init / D1;
-+
-+    // After the Permute Op
-+    Index col_permute = l + j * D3;
-+    Index row_permute = k + i * D2;
-+
-+    return LongIndex(row_permute) * LongIndex(stride_permute_) + LongIndex(col_permute);
-+  }
-+
-+  /// Return D1
-+  CUTLASS_HOST_DEVICE
-+  Index d1() const {
-+    return D1;
-+  }
-+
-+  /// Return D2
-+  CUTLASS_HOST_DEVICE
-+  Index d2() const {
-+    return D2;
-+  }
-+};
-+
-+/// Permute layout function for 4-D permuted tensors for BMM with BMM output tensor (dimension as [B, M, N]) reshaped
-+/// as [B/D1, D1, M, N]. Then perform permute([0, 2, 1, 3]) on the corresponding whole BMM output tensor.
-+template <int D1>
-+class Tensor4DPermuteBMM0213 {
-+public:
-+  /// Index type used for coordinates
-+  using Index = int32_t;
-+
-+  /// Long index type used for offsets
-+  using LongIndex = int64_t;
-+
-+private:
-+  //
-+  // Data members
-+  //
-+
-+  MatrixCoord extent_;
-+
-+  Index stride_permute_;
-+
-+public:
-+  //
-+  // Methods
-+  //
-+
-+  /// Constructor
-+  CUTLASS_HOST_DEVICE
-+  Tensor4DPermuteBMM0213() { }
-+
-+  /// Constructor
-+  CUTLASS_HOST_DEVICE
-+  Tensor4DPermuteBMM0213(MatrixCoord extent, Index stride_init): extent_(extent) {
-+
-+    /// Update stride_permute with stride_init
-+    stride_permute_ = stride_init * D1; // stride in Elements
-+
-+  }
-+
-+  /// Computes the address offset after Permute Op in Bytes
-+  CUTLASS_HOST_DEVICE
-+  LongIndex operator()(MatrixCoord offset_init) {
-+
-+    // The batch index for BMM
-+    Index BMM_batch_idx = blockIdx.z;
-+
-+    // Permute as torch.permute(X1, [0, 2, 1, 3]) -> 4D Tensor indices as [i,j,k,l], the dimension of X
-+    // is [D0, D1, D2, D3], after permutation the dim of X1 is [D0, D2, D1, D3].
-+ int D2 = extent_.row(); -+ int D3 = extent_.column(); -+ -+ Index col_init = offset_init.column(); -+ Index row_init = offset_init.row(); -+ -+ int l = col_init; -+ int k = row_init; -+ int j = BMM_batch_idx % D1; -+ int i = BMM_batch_idx / D1; -+ -+ // After the Permute Op -+ Index col_permute = l + j * D3; -+ Index row_permute = k + i * D2; -+ -+ return LongIndex(row_permute) * LongIndex(stride_permute_) + LongIndex(col_permute); -+ } -+ -+ /// Return D1 -+ CUTLASS_HOST_DEVICE -+ Index d1() const { -+ return D1; -+ } -+}; -+ -+/// Permute layout function for 5-D permuted tensors with output matrix (dimension as [M, N]) reshaped -+/// as [M/T1, T1, T2, T3, N/T2/T3]. Then perform permute([2, 0, 3, 1, 4]) on the corresponding output tensor. -+template -+class Tensor5DPermute20314 { -+public: -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+private: -+ // -+ // Data members -+ // -+ -+ MatrixCoord extent_; -+ -+ Index stride_permute_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ Tensor5DPermute20314() { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ Tensor5DPermute20314(MatrixCoord extent, Index stride_init): extent_(extent) { -+ -+ /// Update stride_permute with stride_init -+ stride_permute_ = stride_init / T2 * T1; // stride in Elements -+ -+ } -+ -+ /// Computes the address offset after Permute Op in Bytes -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord offset_init) { -+ -+ // Permute as torch.permute(X1, [2, 0, 3, 1, 4]) -> 5D Tensor indices as [i,j,k,l,m], the dimension of X -+ // is [T0, T1, T2, T3, T4], after permutation the dim of X1 is [T2, T0, T3, T1, T4]. -+ int T0 = extent_.row() / T1; -+ int T4 = extent_.column() / T2 / T3; -+ -+ Index col_init = offset_init.column(); -+ Index row_init = offset_init.row(); -+ -+ int m = col_init % T4; -+ int l = int(col_init / T4) % T3; -+ int k = int(col_init / T4) / T3; -+ int j = row_init % T1; -+ int i = row_init / T1; -+ -+ // After the Permute Op -+ Index col_permute = m + j * T4 + l * T1 * T4; -+ Index row_permute = i + k * T0; -+ -+ return LongIndex(row_permute) * LongIndex(stride_permute_) + LongIndex(col_permute); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace layout -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/layout/pitch_linear.h b/3rdparty/cutlass/include/cutlass/layout/pitch_linear.h -new file mode 100644 -index 0000000..b49ab95 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/layout/pitch_linear.h -@@ -0,0 +1,148 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines layout functions used by TensorRef and derived classes for pitch-linear memory. -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+#include "cutlass/pitch_linear_coord.h" -+ -+namespace cutlass { -+namespace layout { -+ -+template -+ using PitchLinearShape = cutlass::PitchLinearShape < Contiguous, Strided >; -+ using PitchLinearCoord = PitchLinearCoord; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Mapping function for pitch-linear memory -+class PitchLinear { -+public: -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ PitchLinear(LongIndex ldm = 0): stride_(ldm) { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ PitchLinear(Stride _stride): stride_(_stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static PitchLinear packed(TensorCoord const &extent) { -+ return PitchLinear(extent.contiguous()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return LongIndex(coord.contiguous()) + LongIndex(coord.strided()) * LongIndex(stride_[0]); -+ } -+ -+ /// Returns the logical coordinate given an offset. 
-+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex index) const { -+ return make_Coord( -+ TensorCoord::Index(index % stride_[0]), -+ TensorCoord::Index(index / stride_[0]) -+ ); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ LongIndex stride(int rank) const { -+ return stride_[rank]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ LongIndex & stride(int rank) { -+ return stride_[rank]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent.strided() * stride_[0]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace layout -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/layout/tensor.h b/3rdparty/cutlass/include/cutlass/layout/tensor.h -new file mode 100644 -index 0000000..29ac570 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/layout/tensor.h -@@ -0,0 +1,636 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines layout functions used by TensorRef and derived classes for common 4-D and 5-D -+ tensor formats. -+ -+ Layout functions map logical coordinates to linear memory. They often require additional -+ data to describe strides between elements. 
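The pitch-linear layout above and its inverse are a simple divmod pair; a standalone round-trip sketch (illustrative names, not from the patch):

// Sketch only: pitch-linear offset = contiguous + strided * ldm, inverse = (offset % ldm, offset / ldm).
#include <cassert>
#include <cstdint>

struct PitchLinearSketch {
  int64_t ldm;  // elements between consecutive strided slices (the leading dimension)

  int64_t offset(int contiguous, int strided) const {
    return contiguous + int64_t(strided) * ldm;
  }
  void inverse(int64_t off, int &contiguous, int &strided) const {
    contiguous = int(off % ldm);
    strided = int(off / ldm);
  }
};

int main() {
  PitchLinearSketch layout{128};      // e.g. a padded leading dimension of 128 elements
  int64_t off = layout.offset(5, 3);  // 5 + 3 * 128 = 389
  int c = 0, s = 0;
  layout.inverse(off, c, s);
  assert(off == 389 && c == 5 && s == 3);
  return 0;
}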
-+ -+ Layout functions must implement all members in the public interface of IdentityTensorLayout<> -+ defined in cutlass/tensor_ref.h. -+*/ -+#pragma once -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include "assert.h" -+#endif -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/coord.h" -+#include "cutlass/tensor_coord.h" -+ -+namespace cutlass { -+namespace layout { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Defines data layouts of various tensor formats usable by TensorRef and other classes. -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Mapping function for 4-D NHWC tensors. -+class TensorNHWC { -+public: -+ /// Logical rank of tensor -+ static int const kRank = 4; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 3; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate (n, h, w, c) -+ using TensorCoord = Tensor4DCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member - [stride_w, stride_h, stride_n] -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TensorNHWC(Stride const &stride = Stride(0)): stride_(stride) { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TensorNHWC( -+ typename Stride::Index stride_w, ///< number of elements between adjacent W coordinates -+ typename Stride::Index stride_h, ///< number of elements between adjacent H coordinates -+ typename Stride::Index stride_n ///< number of elements between adjacent N coordinates -+ ): -+ stride_(make_Coord(stride_w, stride_h, stride_n)) { } -+ -+ /// Constructor -+ // Once convolutions implement 64b stride this ctor can be deleted -+ CUTLASS_HOST_DEVICE -+ TensorNHWC(Coord const &stride): -+ stride_(make_Coord( -+ static_cast(stride[0]), -+ static_cast(stride[1]), -+ static_cast(stride[2])) -+ ) { } -+ -+ /// Helper returns a layout to a tightly packed NHWC tensor. -+ CUTLASS_HOST_DEVICE -+ static TensorNHWC packed(TensorCoord const &extent) { -+ return TensorNHWC( -+ make_Coord( -+ extent.c(), -+ extent.w() * extent.c(), -+ extent.h() * extent.w() * extent.c() -+ ) -+ ); -+ } -+ -+ /// Returns the offset of a coordinate (n, h, w, c) in linear memory. -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return coord.c() + -+ LongIndex(stride_[0] * coord.w()) + -+ LongIndex(stride_[1] * coord.h()) + -+ LongIndex(stride_[2] * coord.n()); -+ } -+ -+ /// Returns the offset of a pitchlinear coordinate in linear memory. -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(PitchLinearCoord coord) const { -+ return coord.contiguous() + LongIndex(coord.strided() * stride_[2]); -+ } -+ -+ /// Returns the logical coordinate (n, h, w, c) from a given offset in linear memory. 
-+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex index) const { -+ -+ int n = 0, h = 0, w = 0, c = 0; -+ -+ #if defined(__CUDA_ARCH__) -+ int tmp = 0; -+ c = int(index % static_cast(stride_[0])); -+ -+ unsigned int hw_mul, hw_shr, w_mul, w_shr, c_mul, c_shr; -+ -+ find_divisor(hw_mul, hw_shr, stride_[2]); -+ find_divisor(w_mul, w_shr, stride_[1]); -+ find_divisor(c_mul, c_shr, stride_[0]); -+ -+ fast_divmod(n, tmp, index, int(stride_[2]), hw_mul, hw_shr); -+ fast_divmod(h, w, tmp, int(stride_[1]), w_mul, w_shr); -+ fast_divmod(w, tmp, w, int(stride_[0]), c_mul, c_shr); -+ #else -+ -+ n = int(index / stride_[2]); -+ LongIndex residual = index % stride_[2]; -+ -+ h = int(residual / stride_[1]); -+ residual = (residual % stride_[1]); -+ -+ w = int(residual / stride_[0]); -+ c = int(residual % stride_[0]); -+ -+ #endif -+ return TensorCoord(n, h, w, c); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ // it does not make sense if the extent is larger than stride -+ // and we could not rely on the capacity calculation in such cases -+ // we could move this checkers to debug code only -+ if ((extent.c() > stride_[0]) -+ || (extent.w() * stride_[0] > stride_[1]) -+ || (extent.h() * stride_[1] > stride_[2])) { -+ assert(0); -+ } -+ return extent.n() * stride_[2]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Mapping function for 4-D NCHW tensors. -+class TensorNCHW { -+public: -+ /// Logical rank of tensor -+ static int const kRank = 4; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 3; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = Tensor4DCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member - [w, hw, chw] -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TensorNCHW(Stride const &stride = Stride(0)): stride_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorNCHW packed(TensorCoord const &extent) { -+ return TensorNCHW( -+ make_Coord( -+ extent.w(), -+ extent.w() * extent.h(), -+ extent.h() * extent.w() * extent.c() -+ ) -+ ); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
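For the NHWC layout above, tight packing gives strides (C, W*C, H*W*C) and the offset is a plain dot product with (w, h, n) plus the channel. A standalone sketch (illustrative names, not part of the patch):

// Sketch only: packed NHWC strides and the linear offset of (n, h, w, c).
#include <cassert>
#include <cstdint>

struct NHWCSketch {
  int64_t stride_w, stride_h, stride_n;  // [C, W*C, H*W*C] when tightly packed

  static NHWCSketch packed(int H, int W, int C) {
    return {int64_t(C), int64_t(W) * C, int64_t(H) * W * C};
  }
  int64_t offset(int n, int h, int w, int c) const {
    return c + w * stride_w + h * stride_h + n * stride_n;
  }
};

int main() {
  NHWCSketch layout = NHWCSketch::packed(/*H=*/4, /*W=*/4, /*C=*/8);
  // (n=1, h=2, w=3, c=5) -> 5 + 3*8 + 2*32 + 1*128 = 221
  assert(layout.offset(1, 2, 3, 5) == 221);
  return 0;
}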
-+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return coord.w() + -+ LongIndex(stride_[0] * coord.h()) + -+ LongIndex(stride_[1] * coord.c()) + -+ LongIndex(stride_[2] * coord.n()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent.n() * stride_[2]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Mapping function for 4-D NC/xHWx tensors. -+template -+class TensorNCxHWx { -+public: -+ -+ /// Interleaving quantity -+ static int const kInterleave = Interleave; -+ -+ /// Logical rank of tensor -+ static int const kRank = 4; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 3; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = Tensor4DCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member - [Interleave x w, Interleave x wh, hwc] -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TensorNCxHWx(Stride const &stride = Stride(0)): stride_(stride) { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TensorNCxHWx( -+ typename Stride::Index stride_w, ///< number of elements between adjacent W coordinates -+ typename Stride::Index stride_h, ///< number of elements between adjacent H coordinates -+ typename Stride::Index stride_n ///< number of elements between adjacent N coordinates -+ ): -+ stride_(make_Coord(stride_w, stride_h, stride_n)) { } -+ -+ /// Constructor -+ // Once convolutions implement 64b stride this ctor can be deleted -+ CUTLASS_HOST_DEVICE -+ TensorNCxHWx(Coord const &stride): -+ stride_(make_Coord( -+ static_cast(stride[0]), -+ static_cast(stride[1]), -+ static_cast(stride[2])) -+ ) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorNCxHWx packed(TensorCoord const &extent) { -+ return TensorNCxHWx( -+ make_Coord( -+ kInterleave * extent.w(), -+ kInterleave * extent.w() * extent.h(), -+ extent.h() * extent.w() * extent.c() -+ ) -+ ); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ -+ Index c_minor = (coord.c() % kInterleave); -+ Index c_major = (coord.c() / kInterleave); -+ -+ return c_minor + -+ LongIndex(kInterleave * coord.w()) + -+ LongIndex(stride_[0] * coord.h()) + -+ LongIndex(stride_[1] * c_major) + -+ LongIndex(stride_[2] * coord.n()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent.n() * stride_[2]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Mapping function for 4-D CxRSKx tensors. -+template -+class TensorCxRSKx { -+public: -+ -+ /// Interleaving quantity -+ static int const kInterleave = Interleave; -+ -+ /// Logical rank of tensor -+ static int const kRank = 4; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 3; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = Tensor4DCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member - [Interleave x n, Interleave x nw, Interleave x nwh] -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TensorCxRSKx(Stride const &stride = Stride(0)): stride_(stride) { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TensorCxRSKx( -+ typename Stride::Index stride_w, ///< number of elements between adjacent W coordinates -+ typename Stride::Index stride_h, ///< number of elements between adjacent H coordinates -+ typename Stride::Index stride_n ///< number of elements between adjacent N coordinates -+ ): -+ stride_(make_Coord(stride_w, stride_h, stride_n)) { } -+ -+ /// Constructor -+ // Once convolutions implement 64b stride this ctor can be deleted -+ CUTLASS_HOST_DEVICE -+ TensorCxRSKx(Coord const &stride): -+ stride_(make_Coord( -+ static_cast(stride[0]), -+ static_cast(stride[1]), -+ static_cast(stride[2])) -+ ) { } -+ -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorCxRSKx packed(TensorCoord const &extent) { -+ return TensorCxRSKx( -+ make_Coord( -+ kInterleave * extent.n(), -+ kInterleave * extent.n() * extent.w(), -+ kInterleave * extent.n() * extent.w() * extent.h() -+ ) -+ ); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ -+ Index c_minor = (coord.c() % kInterleave); -+ Index c_major = (coord.c() / kInterleave); -+ -+ return c_minor + -+ LongIndex(kInterleave * coord.n()) + -+ LongIndex(stride_[0] * coord.w()) + -+ LongIndex(stride_[1] * coord.h()) + -+ LongIndex(stride_[2] * c_major); -+ } -+ -+ /// Returns the offset of a pitchlinear coordinate in linear memory. 
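The interleaved NC/xHWx layout above splits the channel into a major part that is strided and a minor part that stays contiguous with W. A standalone sketch of that offset computation (hypothetical function, not from the patch):

// Sketch only: NC/xHWx offset with packed strides [Interleave*W, Interleave*W*H, H*W*C].
#include <cassert>
#include <cstdint>

template <int Interleave>
int64_t ncxhwx_offset(int n, int h, int w, int c, int H, int W, int C) {
  int c_minor = c % Interleave;
  int c_major = c / Interleave;
  int64_t stride_h = int64_t(Interleave) * W;
  int64_t stride_c_major = int64_t(Interleave) * W * H;
  int64_t stride_n = int64_t(H) * W * C;
  return c_minor + int64_t(Interleave) * w + stride_h * h + stride_c_major * c_major + stride_n * n;
}

int main() {
  // NC/32HW32 with H = W = 2, C = 64: coordinate (n=0, h=1, w=1, c=40) -> 8 + 32 + 64 + 128 = 232.
  assert((ncxhwx_offset<32>(0, 1, 1, 40, 2, 2, 64)) == 232);
  return 0;
}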
-+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(PitchLinearCoord const &coord) const { -+ return (coord.contiguous() % kInterleave) + -+ LongIndex((coord.contiguous() / kInterleave) * stride_[2]) + -+ LongIndex(coord.strided() * kInterleave); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return (extent.c() / kInterleave * stride_[2]); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Mapping function for 5-D NDHWC tensors. -+class TensorNDHWC { -+public: -+ /// Logical rank of tensor -+ static int const kRank = 5; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 4; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate (n, d, h, w, c) -+ using TensorCoord = Tensor5DCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member - [c, wc, hwc, dhwc] -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TensorNDHWC(Stride const &stride = Stride(0)): stride_(stride) { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TensorNDHWC( -+ typename Stride::Index c, -+ typename Stride::Index wc, -+ typename Stride::Index hwc, -+ typename Stride::Index dhwc): -+ stride_(make_Coord(c, wc, hwc, dhwc)) { } -+ -+ /// Constructor -+ // Once convolutions implement 64b stride this ctor can be deleted -+ CUTLASS_HOST_DEVICE -+ TensorNDHWC(Coord const &stride): -+ stride_(make_Coord( -+ static_cast(stride[0]), -+ static_cast(stride[1]), -+ static_cast(stride[2]), -+ static_cast(stride[3])) -+ ) { } -+ -+ /// Helper returns a layout to a tightly packed NHWC tensor. -+ CUTLASS_HOST_DEVICE -+ static TensorNDHWC packed(TensorCoord const &extent) { -+ return TensorNDHWC( -+ make_Coord( -+ extent.c(), -+ extent.w() * extent.c(), -+ extent.h() * extent.w() * extent.c(), -+ extent.d() * extent.h() * extent.w() * extent.c() -+ ) -+ ); -+ } -+ -+ /// Returns the offset of a coordinate (n, d, h, w, c) in linear memory. -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return coord.c() + -+ LongIndex(stride_[0] * coord.w()) + -+ LongIndex(stride_[1] * coord.h()) + -+ LongIndex(stride_[2] * coord.d()) + -+ LongIndex(stride_[3] * coord.n()); -+ } -+ -+ /// Returns the offset of a pitchlinear coordinate in linear memory. 
-+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(PitchLinearCoord coord) const { -+ return coord.contiguous() + LongIndex(coord.strided() * stride_[3]); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ // it does not make sense if the extent is larger than stride -+ // and we could not rely on the capacity calculation in such cases -+ // we could move this checkers to debug code only -+ if ((extent.c() > stride_[0]) -+ || (extent.w() * stride_[0] > stride_[1]) -+ || (extent.h() * stride_[1] > stride_[2]) -+ || (extent.d() * stride_[2] > stride_[3])) { -+ assert(0); -+ } -+ return extent.n() * stride_[3]; -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace layout -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm70.h b/3rdparty/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm70.h -new file mode 100644 -index 0000000..b127bff ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm70.h -@@ -0,0 +1,1044 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace layout { -+ -+// template < -+// int ElementSize, -+// gemm::Operand Operand -+// > -+// struct VoltaTensorOpMultiplicandCongruous; -+ -+// template < -+// int ElementSize, -+// gemm::Operand Operand -+// > -+// struct ColumnMajorVoltaTensorOpMultiplicandCongruous; -+// template < -+// int ElementSize, -+// gemm::Operand Operand -+// > -+// struct RowMajorVoltaTensorOpMultiplicandCongruous; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear memory. -+template -+struct VoltaTensorOpMultiplicandCongruous { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = 128; -+ -+ /// Fundamental tile shape in units of vectors -+ using TileShape = PitchLinearShape<8, 4>; -+ -+ /// Fundamental partition shape in units of vectors -+ using PartitionShape = PitchLinearShape<8, 2>; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = ElementSize; -+ static int const kElementsPerAccess = kAccessSize / kElementSize; -+ -+ using PartitionCount = PitchLinearShape< -+ TileShape::kContiguous / PartitionShape::kContiguous, -+ TileShape::kStrided / PartitionShape::kStrided -+ >; -+ -+ using AccessCount = PitchLinearShape< -+ PartitionShape::kContiguous, -+ PartitionShape::kStrided -+ >; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ VoltaTensorOpMultiplicandCongruous(Index ldm = 0): stride_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ VoltaTensorOpMultiplicandCongruous(Stride stride): stride_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static VoltaTensorOpMultiplicandCongruous packed(TensorCoord const &extent) { -+ return VoltaTensorOpMultiplicandCongruous(extent[0]); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ -+ // First, compute c and s of vector within source (in units of vector accesses) -+ int vec_contiguous_idx = coord.contiguous() / kElementsPerAccess; -+ int vec_strided_idx = coord.strided(); -+ -+ // Compute the fundamental tile being accessed -+ int tile_contiguous_idx = vec_contiguous_idx / TileShape::kContiguous; -+ int tile_strided_idx = vec_strided_idx / TileShape::kStrided; -+ -+ int tile_contiguous_residual = vec_contiguous_idx % TileShape::kContiguous; -+ int tile_strided_residual = vec_strided_idx % TileShape::kStrided; -+ -+ // Then swizzle in a tile -+ // Swizzle pattern is (tid[2:0] << 2)|(tid[4:3] ^ tid[2:1]) -+ int permuted_strided_within_tile = (tile_contiguous_residual >> 1); -+ int permuted_contiguous_within_tile = (tile_strided_residual ^ permuted_strided_within_tile) | -+ ((tile_contiguous_residual & 1) << 2); -+ // Compute final element location -+ int element_contiguous = (tile_contiguous_idx * TileShape::kContiguous + -+ permuted_contiguous_within_tile) * kElementsPerAccess + (coord.contiguous() % kElementsPerAccess); -+ -+ int element_strided = tile_strided_idx * TileShape::kStrided + permuted_strided_within_tile; -+ -+ return element_contiguous + element_strided * stride_[0]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent[1] * stride_[0]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a column-major view of pitch-linear memory to VoltaTensorOpMultiplicandCongruous -+template -+struct ColumnMajorVoltaTensorOpMultiplicandCongruous { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = VoltaTensorOpMultiplicandCongruous; -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ using TileShape = typename Base::TileShape; -+ using PartitionShape = typename Base::PartitionShape; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ using PartitionCount = typename Base::PartitionCount; -+ using AccessCount = typename Base::AccessCount; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorVoltaTensorOpMultiplicandCongruous(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorVoltaTensorOpMultiplicandCongruous(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajorVoltaTensorOpMultiplicandCongruous 
packed(TensorCoord const &extent) { -+ return ColumnMajorVoltaTensorOpMultiplicandCongruous(extent.row()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.row(), coord.column())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.contiguous(), coord.strided()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); -+ } -+}; -+ -+/// Template mapping a row-major view of pitch-linear memory to VoltaTensorOpMultiplicandCongruous -+template -+struct RowMajorVoltaTensorOpMultiplicandCongruous { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = VoltaTensorOpMultiplicandCongruous; -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ using TileShape = typename Base::TileShape; -+ using PartitionShape = typename Base::PartitionShape; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ using PartitionCount = typename Base::PartitionCount; -+ using AccessCount = typename Base::AccessCount; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorVoltaTensorOpMultiplicandCongruous(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorVoltaTensorOpMultiplicandCongruous(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajorVoltaTensorOpMultiplicandCongruous packed(TensorCoord const &extent) { -+ return RowMajorVoltaTensorOpMultiplicandCongruous(extent.column()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
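To make the swizzle of the congruous multiplicand layout above easier to inspect, here is a standalone host-side sketch specialized to 16-bit elements (8 elements per 128-bit access) and the 8x4-vector tile; the function name and the ldm value are illustrative, not part of the patch:

// Sketch only: swizzled offset for a congruous Volta tensor-op multiplicand, 16-bit elements.
#include <cassert>
#include <cstdint>

int64_t volta_congruous_offset_b16(int contiguous, int strided, int64_t ldm) {
  int const kElementsPerAccess = 128 / 16;  // 8 elements per 128-bit access
  int const kTileContiguous = 8, kTileStrided = 4;

  int vec_contiguous = contiguous / kElementsPerAccess;
  int tile_contiguous = vec_contiguous / kTileContiguous;
  int tile_strided = strided / kTileStrided;
  int res_contiguous = vec_contiguous % kTileContiguous;
  int res_strided = strided % kTileStrided;

  // Swizzle within the 8x4 tile, following the pattern documented in the layout above.
  int perm_strided = res_contiguous >> 1;
  int perm_contiguous = (res_strided ^ perm_strided) | ((res_contiguous & 1) << 2);

  int element_contiguous =
      (tile_contiguous * kTileContiguous + perm_contiguous) * kElementsPerAccess +
      (contiguous % kElementsPerAccess);
  int element_strided = tile_strided * kTileStrided + perm_strided;
  return element_contiguous + int64_t(element_strided) * ldm;
}

int main() {
  // With a leading dimension of 64 elements, logical (contiguous=26, strided=3) maps to offset 114.
  assert(volta_congruous_offset_b16(26, 3, 64) == 114);
  return 0;
}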
-+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.column(), coord.row())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.strided(), coord.contiguous()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); -+ } -+}; -+ -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear memory. -+// template -+template -+struct VoltaTensorOpMultiplicandBCongruous { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = 128; -+ -+ /// Fundamental tile shape in units of vectors -+ using TileShape = PitchLinearShape<8, 4>; -+ -+ /// Fundamental partition shape in units of vectors -+ using PartitionShape = PitchLinearShape<4, 4>; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = ElementSize; -+ static int const kElementsPerAccess = kAccessSize / kElementSize; -+ -+ using PartitionCount = PitchLinearShape< -+ TileShape::kContiguous / PartitionShape::kContiguous, -+ TileShape::kStrided / PartitionShape::kStrided -+ >; -+ -+ using AccessCount = PitchLinearShape< -+ PartitionShape::kContiguous, -+ PartitionShape::kStrided -+ >; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ VoltaTensorOpMultiplicandBCongruous(Index ldm = 0): stride_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ VoltaTensorOpMultiplicandBCongruous(Stride stride): stride_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static VoltaTensorOpMultiplicandBCongruous packed(TensorCoord const &extent) { -+ return VoltaTensorOpMultiplicandBCongruous(extent[0]); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ -+ // First, compute c and s of vector within source (in units of vector accesses) -+ int vec_contiguous_idx = coord.contiguous() / kElementsPerAccess; -+ int vec_strided_idx = coord.strided(); -+ -+ // Compute the fundamental tile being accessed -+ int tile_contiguous_idx = vec_contiguous_idx / TileShape::kContiguous; -+ int tile_strided_idx = vec_strided_idx / TileShape::kStrided; -+ -+ int tile_contiguous_residual = vec_contiguous_idx % TileShape::kContiguous; -+ int tile_strided_residual = vec_strided_idx % TileShape::kStrided; -+ -+ // Then swizzle in a tile -+ // Swizzle pattern is (tid[1:0] << 3)|(tid & 0x4)|(tid[1:0]) -+ int permuted_strided_within_tile = (tile_contiguous_residual & 0x3); -+ int permuted_contiguous_within_tile = (tile_strided_residual ^ permuted_strided_within_tile) | -+ (tile_contiguous_residual & 0x4); -+ -+ // Compute final element location -+ int element_contiguous = (tile_contiguous_idx * TileShape::kContiguous + -+ permuted_contiguous_within_tile) * kElementsPerAccess + (coord.contiguous() % kElementsPerAccess); -+ -+ int element_strided = tile_strided_idx * TileShape::kStrided + permuted_strided_within_tile; -+ -+ return element_contiguous + element_strided * stride_[0]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent[1] * stride_[0]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a column-major view of pitch-linear memory to VoltaTensorOpMultiplicandCongruous -+template -+struct ColumnMajorVoltaTensorOpMultiplicandBCongruous { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = VoltaTensorOpMultiplicandBCongruous; -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ using TileShape = typename Base::TileShape; -+ using PartitionShape = typename Base::PartitionShape; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ using PartitionCount = typename Base::PartitionCount; -+ using AccessCount = typename Base::AccessCount; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorVoltaTensorOpMultiplicandBCongruous(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorVoltaTensorOpMultiplicandBCongruous(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajorVoltaTensorOpMultiplicandBCongruous 
packed(TensorCoord const &extent) { -+ return ColumnMajorVoltaTensorOpMultiplicandBCongruous(extent.row()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.row(), coord.column())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.contiguous(), coord.strided()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); -+ } -+}; -+ -+/// Template mapping a row-major view of pitch-linear memory to VoltaTensorOpMultiplicandCongruous -+template -+struct RowMajorVoltaTensorOpMultiplicandBCongruous { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = VoltaTensorOpMultiplicandBCongruous; -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ using TileShape = typename Base::TileShape; -+ using PartitionShape = typename Base::PartitionShape; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ using PartitionCount = typename Base::PartitionCount; -+ using AccessCount = typename Base::AccessCount; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorVoltaTensorOpMultiplicandBCongruous(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorVoltaTensorOpMultiplicandBCongruous(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajorVoltaTensorOpMultiplicandBCongruous packed(TensorCoord const &extent) { -+ return RowMajorVoltaTensorOpMultiplicandBCongruous(extent.column()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.column(), coord.row())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.strided(), coord.contiguous()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); -+ } -+}; -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear -+/// memory and KBlock size (in elements). -+template -+struct VoltaTensorOpMultiplicandCrosswise { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ /// This layout is optimized for 64b accesses -+ static int const kAccessSize = 64; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = ElementSize; -+ static int const kElementsPerAccess = kAccessSize / kElementSize; -+ static int const kKBlock = KBlock; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member. For GEMM, it equals to KBlock x stage. -+ Stride stride_; -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ VoltaTensorOpMultiplicandCrosswise(Index ldm = 0) : stride_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ VoltaTensorOpMultiplicandCrosswise(Stride stride) : stride_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static VoltaTensorOpMultiplicandCrosswise packed(TensorCoord const &extent) { -+ return VoltaTensorOpMultiplicandCrosswise(extent[1]); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ -+ // -+ // First, compute c and s of vector within source (in units of vector -+ // accesses) -+ // -+ int vec_contiguous_idx = coord.contiguous() / kElementsPerAccess; -+ int vec_strided_idx = coord.strided(); -+ -+ // -+ // Then swizzle -+ // The mapping is like this: -+ // id[1:0]|(id[3]^id[4])|id[2] -+ -+ int vec_strided_within_tile = vec_contiguous_idx & 0x7; -+ int permuted_vec_contiguous = -+ (vec_strided_idx & (~0xF)) + (vec_strided_idx & 0x3) * 4 + -+ (((vec_strided_idx >> 2) ^ ((vec_strided_idx & 0x10) >> 3)) & 0x3); -+ -+ permuted_vec_contiguous ^= ((vec_strided_within_tile >> 1) & 0x3); -+ -+ int permuted_vec_strided = vec_contiguous_idx; -+ -+ // -+ // Compute final element location -+ // -+ -+ int element_contiguous = permuted_vec_contiguous * kElementsPerAccess + -+ (coord.contiguous() % kElementsPerAccess); -+ -+ return element_contiguous + permuted_vec_strided * (stride_[0] * kElementsPerAccess); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return stride_; } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return stride_; } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent[0] * stride_[0]; -+ } -+}; -+ -+/// Template mapping a column-major view of pitch-linear memory to -+/// VoltaTensorOpMultiplicandCrosswise -+template -+struct ColumnMajorVoltaTensorOpMultiplicandCrosswise { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = VoltaTensorOpMultiplicandCrosswise; -+ -+ /// This layout is optimized for 64b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorVoltaTensorOpMultiplicandCrosswise(Index ldm = 0) : layout_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorVoltaTensorOpMultiplicandCrosswise(Stride stride) : layout_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajorVoltaTensorOpMultiplicandCrosswise packed( -+ TensorCoord const &extent) { -+ return ColumnMajorVoltaTensorOpMultiplicandCrosswise(extent.column()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
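The adapters that follow only reorder coordinates; the interesting part is the swizzle in VoltaTensorOpMultiplicandCrosswise::operator() above. As an illustrative aid (not CUTLASS code), its index arithmetic can be copied into a standalone host function and tabulated; kElementsPerAccess is assumed to be 4 here (64-bit access over 16-bit elements) and the function name is made up:

    #include <cstdint>
    #include <cstdio>

    // Host copy of the swizzle arithmetic from
    // VoltaTensorOpMultiplicandCrosswise::operator() above,
    // specialized to kElementsPerAccess = 4 (64b access / 16b elements).
    std::int64_t volta_crosswise_offset(int contiguous, int strided, std::int64_t stride0) {
      const int kElementsPerAccess = 4;

      int vec_contiguous_idx = contiguous / kElementsPerAccess;
      int vec_strided_idx = strided;

      int vec_strided_within_tile = vec_contiguous_idx & 0x7;

      int permuted_vec_contiguous =
          (vec_strided_idx & (~0xF)) + (vec_strided_idx & 0x3) * 4 +
          (((vec_strided_idx >> 2) ^ ((vec_strided_idx & 0x10) >> 3)) & 0x3);
      permuted_vec_contiguous ^= ((vec_strided_within_tile >> 1) & 0x3);

      int permuted_vec_strided = vec_contiguous_idx;

      int element_contiguous = permuted_vec_contiguous * kElementsPerAccess +
                               (contiguous % kElementsPerAccess);

      return element_contiguous +
             permuted_vec_strided * (stride0 * kElementsPerAccess);
    }

    int main() {
      // Print where the first few strided indices of vector column 0 land,
      // to visualize the permutation applied by the swizzle.
      for (int s = 0; s < 8; ++s) {
        std::printf("strided %d -> offset %lld\n", s,
                    static_cast<long long>(volta_crosswise_offset(0, s, 32)));
      }
      return 0;
    }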
-+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.row(), coord.column())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.contiguous(), coord.strided()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return layout_.stride(); } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return layout_.stride(); } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); -+ } -+}; -+ -+/// Template mapping a row-major view of pitch-linear memory to -+/// TensorOpMultiplicandCrosswise -+template -+struct RowMajorVoltaTensorOpMultiplicandCrosswise { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = VoltaTensorOpMultiplicandCrosswise; -+ -+ /// This layout is optimized for 64b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorVoltaTensorOpMultiplicandCrosswise(Index ldm = 0) : layout_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorVoltaTensorOpMultiplicandCrosswise(Stride stride) : layout_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajorVoltaTensorOpMultiplicandCrosswise packed( -+ TensorCoord const &extent) { -+ return RowMajorVoltaTensorOpMultiplicandCrosswise(extent.row()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.column(), coord.row())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.strided(), coord.contiguous()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return layout_.stride(); } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return layout_.stride(); } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); -+ } -+}; -+ -+} // namespace layout -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm75.h b/3rdparty/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm75.h -new file mode 100644 -index 0000000..14148b7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm75.h -@@ -0,0 +1,1161 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace layout { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear -+/// memory and Crosswise size (in elements). -+/// This one is the base class of all Ampere/Turing fp16/bf16/int8/int4/int1 -+/// tensor core kernels. tf32 TN uses this too. -+template -+struct TensorOpMultiplicand { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Static constants -+ // -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = 128; -+ -+ static int const kElementSize = ElementSize; -+ static int const kElementsPerAccess = kAccessSize / kElementSize; -+ static int const kCrosswise = Crosswise; -+ -+ /// Contiguous dimension of the tile shape matches one shared memory cache -+ /// line - 128B. For 128bit access size, it equals to 8 accesses. -+ static int const kTileShapeContiguous = 128 / (kAccessSize / 8); -+ -+ /// Number of kblocks to store PartitionShape::kContiguous Elements -+ static int const kFactor = -+ kTileShapeContiguous * kElementsPerAccess / kCrosswise; -+ -+ static_assert( -+ (kFactor > 0), -+ "kCrosswise should be no large than one shared memory cache line."); -+ -+ /// The strided dimension needs to be at least (WarpSize(32) / -+ /// kTileShapeContiguous) for a warp to access. To ensure conflict free -+ /// access, it also needs to be at least (kTileShapeContiguous / kFactor). -+ /// See comments below -+ static int const kTileShapeStride = -+ ((kTileShapeContiguous / kFactor) > (32 / kTileShapeContiguous)) -+ ? (kTileShapeContiguous / kFactor) -+ : (32 / kTileShapeContiguous); -+ -+ /// Fundamental tile shape in units of vectors to guarantee bank conflict free -+ /// shared memory load/store. -+ /// For kFactor = 1, TileShape = <8, 8> -+ /// For kFactor > 1, TileShape = <8, 4> -+ using TileShape = PitchLinearShape; -+ -+ /// Fundamental partition shape in units of vectors -+ using PartitionShape = PitchLinearShape<4, 4>; -+ -+ using PartitionCount = -+ PitchLinearShape; -+ -+ using AccessCount = -+ PitchLinearShape; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member. For GEMM, it equals to kCrosswise x stage. -+ Stride stride_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicand(Index ldm = 0) : stride_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicand(Stride stride) : stride_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorOpMultiplicand packed(TensorCoord const &extent) { -+ return TensorOpMultiplicand(extent[0]); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ // -+ // First, compute c and s of vector within source (in units of vector -+ // accesses) -+ // -+ -+ int vec_contiguous_idx = coord.contiguous() / kElementsPerAccess; -+ int vec_strided_idx = coord.strided() / kFactor; -+ -+ // Compute the fundamental tile being accessed -+ int tile_contiguous_idx = -+ vec_contiguous_idx / (TileShape::kContiguous / kFactor); -+ -+ int tile_contiguous_residual = -+ vec_contiguous_idx % (TileShape::kContiguous / kFactor) + -+ ((coord.strided() % kFactor) * (TileShape::kContiguous / kFactor)); -+ int tile_strided_residual = vec_strided_idx % TileShape::kStrided; -+ -+ // Compute the 'partition' within the fundamental tile -+ int partition_contiguous_idx = -+ tile_contiguous_residual / PartitionShape::kContiguous; -+ int partition_strided_idx = -+ tile_strided_residual / PartitionShape::kStrided; -+ -+ int partition_contiguous_residual = -+ tile_contiguous_residual % PartitionShape::kContiguous; -+ int partition_strided_residual = -+ tile_strided_residual % PartitionShape::kStrided; -+ -+ // -+ // Then swizzle -+ // -+ -+ int permuted_vec_contiguous_within_partition = -+ partition_contiguous_residual ^ (partition_strided_residual % 4); -+ -+ int permuted_partition_contiguous_within_tile = -+ partition_contiguous_idx ^ (partition_strided_idx % 2); -+ -+ // -+ // Compute final element location -+ // -+ -+ int element_contiguous = (tile_contiguous_idx * TileShape::kContiguous + -+ permuted_partition_contiguous_within_tile * -+ PartitionShape::kContiguous + -+ permuted_vec_contiguous_within_partition) * -+ kElementsPerAccess + -+ (coord.contiguous() % kElementsPerAccess); -+ -+ int element_strided = vec_strided_idx; -+ -+ return element_contiguous + element_strided * stride_[0] * kFactor; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return stride_; } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return stride_; } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent[1] * stride_[0]; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear -+/// memory and Crosswise size (in elements). 
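The heart of TensorOpMultiplicand::operator() above is a pair of XOR swizzles: the contiguous residual within a 4-wide partition is XOR-ed with (strided residual % 4), and the partition index is XOR-ed with (partition strided index % 2). XOR with a fixed mask permutes the group, so the four accesses that share a strided residual still land in four distinct vector positions. A tiny host check of the inner swizzle (illustrative only, not CUTLASS code):

    #include <cstdio>

    int main() {
      // For each strided residual s, c -> c ^ (s % 4) must be a bijection on
      // {0,1,2,3}; this is what keeps the four accesses of a partition row in
      // four distinct vector slots (and hence distinct bank groups).
      for (int s = 0; s < 4; ++s) {
        bool seen[4] = {false, false, false, false};
        for (int c = 0; c < 4; ++c) {
          seen[c ^ (s % 4)] = true;
        }
        bool is_permutation = seen[0] && seen[1] && seen[2] && seen[3];
        std::printf("s=%d: XOR with %d is %s\n", s, s % 4,
                    is_permutation ? "a permutation of {0..3}" : "NOT a permutation");
      }
      return 0;
    }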
-+template -+struct TensorOpMultiplicandCongruous { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicand; -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ using TileShape = typename Base::TileShape; -+ using PartitionShape = typename Base::PartitionShape; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ using PartitionCount = typename Base::PartitionCount; -+ using AccessCount = typename Base::AccessCount; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCongruous(Index ldm = 0) : layout_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCongruous(Stride stride) : layout_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorOpMultiplicandCongruous packed(TensorCoord const &extent) { -+ return TensorOpMultiplicandCongruous(extent[0]); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(coord); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return coord; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return layout_.stride(); } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return layout_.stride(); } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(extent); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear -+/// memory and Crosswise size (in elements). -+/// This one is just for TF32 NT kernel. 
-+template -+struct TensorOpMultiplicandCongruous<32, Crosswise> { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = 128; -+ -+ /// Fundamental tile shape in units of vectors -+ using TileShape = PitchLinearShape<8, 4>; -+ -+ /// Partitionshape is the same as TileShape for this layout -+ using PartitionShape = PitchLinearShape<8, 4>; -+ -+ using PartitionCount = -+ PitchLinearShape; -+ -+ using AccessCount = -+ PitchLinearShape; -+ -+ // -+ // Static constants -+ // -+ static int const kElementSize = 32; -+ static int const kElementsPerAccess = kAccessSize / kElementSize; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member. -+ Stride stride_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCongruous(Index ldm = 0) : stride_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCongruous(Stride stride) : stride_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorOpMultiplicandCongruous packed(TensorCoord const &extent) { -+ return TensorOpMultiplicandCongruous(extent[0]); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ int tc = coord.contiguous() / 32; -+ int ts = coord.strided() / 4; -+ -+ int c = (coord.contiguous() % 32) / kElementsPerAccess; -+ int s = coord.strided() % 4; -+ -+ LongIndex offset = (c ^ (2 * s)) * kElementsPerAccess + s * stride_[0] + -+ tc * 32 + ts * stride_[0] * 4 + coord.contiguous() % 4; -+ -+ return offset; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return stride_; } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return stride_; } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent[1] * stride_[0]; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a column-major view of pitch-linear memory to -+/// TensorOpMultiplicand -+template -+struct ColumnMajorTensorOpMultiplicandCongruous { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicandCongruous; -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ using TileShape = typename Base::TileShape; -+ using PartitionShape = typename Base::PartitionShape; -+ -+ // -+ // Static constants -+ // -+ -+ 
static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ using PartitionCount = typename Base::PartitionCount; -+ using AccessCount = typename Base::AccessCount; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicandCongruous(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicandCongruous(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajorTensorOpMultiplicandCongruous packed(TensorCoord const &extent) { -+ return ColumnMajorTensorOpMultiplicandCongruous(extent.row()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.row(), coord.column())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.contiguous(), coord.strided()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a row-major view of pitch-linear memory to -+/// TensorOpMultiplicand -+template -+struct RowMajorTensorOpMultiplicandCongruous { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicandCongruous; -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ using TileShape = typename Base::TileShape; -+ using PartitionShape = typename Base::PartitionShape; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ using PartitionCount = typename Base::PartitionCount; -+ using AccessCount = typename Base::AccessCount; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicandCongruous(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicandCongruous(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajorTensorOpMultiplicandCongruous packed(TensorCoord const 
&extent) { -+ return RowMajorTensorOpMultiplicandCongruous(extent.column()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.column(), coord.row())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.strided(), coord.contiguous()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear -+/// memory and Crosswise size (in elements). -+template -+struct TensorOpMultiplicandCrosswise { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicand; -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ using TileShape = typename Base::TileShape; -+ using PartitionShape = typename Base::PartitionShape; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ static int const kCrosswise = Base::kCrosswise; -+ static int const kFactor = Base::kFactor; -+ using PartitionCount = typename Base::PartitionCount; -+ using AccessCount = typename Base::AccessCount; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCrosswise(Index ldm = 0) : layout_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCrosswise(Stride stride) : layout_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorOpMultiplicandCrosswise packed(TensorCoord const &extent) { -+ return TensorOpMultiplicandCrosswise(extent[0]); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(coord); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return coord; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return layout_.stride(); } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return layout_.stride(); } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(extent); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a column-major view of pitch-linear memory to -+/// TensorOpMultiplicandCrosswise -+template -+struct ColumnMajorTensorOpMultiplicandCrosswise { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicandCrosswise; -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ using TileShape = typename Base::TileShape; -+ using PartitionShape = typename Base::PartitionShape; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ using PartitionCount = typename Base::PartitionCount; -+ using AccessCount = typename Base::AccessCount; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicandCrosswise(Index ldm = 0) : layout_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicandCrosswise(Stride stride) : layout_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajorTensorOpMultiplicandCrosswise packed( -+ TensorCoord const &extent) { -+ return ColumnMajorTensorOpMultiplicandCrosswise(extent.row()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.row(), coord.column())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.contiguous(), coord.strided()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return layout_.stride(); } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return layout_.stride(); } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a row-major view of pitch-linear memory to -+/// TensorOpMultiplicandCrosswise -+template -+struct RowMajorTensorOpMultiplicandCrosswise { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicandCrosswise; -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ using TileShape = typename Base::TileShape; -+ using PartitionShape = typename Base::PartitionShape; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ using PartitionCount = typename Base::PartitionCount; -+ using AccessCount = typename Base::AccessCount; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicandCrosswise(Index ldm = 0) : layout_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicandCrosswise(Stride stride) : layout_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajorTensorOpMultiplicandCrosswise packed( -+ TensorCoord const &extent) { -+ return RowMajorTensorOpMultiplicandCrosswise(extent.column()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.column(), coord.row())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.strided(), coord.contiguous()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return layout_.stride(); } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return layout_.stride(); } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear memory. -+template -+struct TensorOpMultiplicandColumnMajorInterleaved { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = 128; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = ElementSize; -+ static int const kElementsPerAccess = kAccessSize / kElementSize; -+ -+ //static int const kThreadBlockStrided = ThreadBlockStrided; -+ static int const kInterleavedK = InterleavedK; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandColumnMajorInterleaved(Index ldm = 0): stride_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandColumnMajorInterleaved(Stride stride): stride_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorOpMultiplicandColumnMajorInterleaved packed(TensorCoord const &extent) { -+ return TensorOpMultiplicandColumnMajorInterleaved(extent[0] * kInterleavedK); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ int const rows_per_smem_cache_line = 128 / kInterleavedK; -+ -+ int row_id = coord.strided() / rows_per_smem_cache_line; -+ int col_id = (coord.strided() % rows_per_smem_cache_line) * kInterleavedK + coord.contiguous(); -+ -+ int access_block_id = col_id >> 4; -+ int swizzle_access_block_id = access_block_id ^ (row_id & 1); -+ -+ int swizzle_col_id = swizzle_access_block_id << 4; -+ -+ return row_id * 128 + swizzle_col_id; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return (extent[1] / kInterleavedK) * stride_[0]; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear memory. -+template -+struct TensorOpMultiplicandRowMajorInterleaved { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = 128; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = ElementSize; -+ static int const kElementsPerAccess = kAccessSize / kElementSize; -+ -+ //static int const kThreadBlockStrided = ThreadBlockStrided; -+ static int const kInterleavedK = InterleavedK; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandRowMajorInterleaved(Index ldm = 0): stride_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandRowMajorInterleaved(Stride stride): stride_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorOpMultiplicandRowMajorInterleaved packed(TensorCoord const &extent) { -+ return TensorOpMultiplicandRowMajorInterleaved(extent[1] * kInterleavedK); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
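The interleaved layout's operator() above keeps 128 elements per shared-memory cache line, splits the line into 16-element access blocks (col_id >> 4), and XORs the block index with the low bit of the row id, so odd rows swap each adjacent pair of blocks; the row-major variant below repeats the same arithmetic. A small host illustration of the block swap (hypothetical helper name, mirroring the expression above):

    #include <cstdio>

    // Mirrors the block swizzle in the interleaved operator() above:
    // 128 elements per shared-memory line, 16-element access blocks, and the
    // block index XOR-ed with the row parity.
    int swizzled_block(int access_block_id, int row_id) {
      return access_block_id ^ (row_id & 1);
    }

    int main() {
      for (int row = 0; row < 2; ++row) {
        std::printf("row %d:", row);
        for (int block = 0; block < 8; ++block) {  // 8 blocks of 16 = 128 elements
          std::printf(" %d->%d", block, swizzled_block(block, row));
        }
        std::printf("\n");
      }
      return 0;
    }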
-+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ int const rows_per_smem_cache_line = 128 / kInterleavedK; -+ -+ int row_id = coord.strided() / rows_per_smem_cache_line; -+ int col_id = (coord.strided() % rows_per_smem_cache_line) * kInterleavedK + coord.contiguous(); -+ -+ int access_block_id = col_id >> 4; -+ int swizzle_access_block_id = access_block_id ^ (row_id & 1); -+ -+ int swizzle_col_id = swizzle_access_block_id << 4; -+ -+ return row_id * 128 + swizzle_col_id; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return (extent[0] / kInterleavedK) * stride_[0]; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace layout -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm80.h b/3rdparty/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm80.h -new file mode 100644 -index 0000000..f75c2a8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm80.h -@@ -0,0 +1,1139 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief layouts needed by Ampere fp64 tensor core kernels. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace layout { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear -+/// memory and Crosswise size (in elements). -+struct TensorOpMultiplicandCongruous64b { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = 64; -+ static int const kElementsPerAccess = 1; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Stride data member. -+ Stride stride_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCongruous64b(Index ldm = 0) : stride_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCongruous64b(Stride stride) : stride_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorOpMultiplicandCongruous64b packed(TensorCoord const &extent) { -+ return TensorOpMultiplicandCongruous64b(extent[0]); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ -+ int tc = coord.contiguous() / 16; -+ int ts = coord.strided() / 4; -+ -+ int c = coord.contiguous() % 16; -+ int s = coord.strided() % 4; -+ -+ -+ int bank = ((((c & 1) * 4 + (c & 6) / 2)) ^ (s & 1)) * 2 + (c / 8); -+ int row = (c & 6) / 2; -+ -+ bank ^= ((s & 2) * 2); -+ -+ LongIndex offset = tc * 16 + bank + (ts * 4 + row) * stride_[0]; -+ -+ return offset; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return stride_; } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return stride_; } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent[1] * stride_[0]; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ return TensorCoord(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a column-major view of pitch-linear memory to -+/// TensorOpMultiplicand -+struct ColumnMajorTensorOpMultiplicandCongruous64b { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicandCongruous64b; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; 
-+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicandCongruous64b(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicandCongruous64b(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajorTensorOpMultiplicandCongruous64b packed(TensorCoord const &extent) { -+ return ColumnMajorTensorOpMultiplicandCongruous64b(extent.row()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.row(), coord.column())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.contiguous(), coord.strided()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a row-major view of pitch-linear memory to -+/// TensorOpMultiplicand -+struct RowMajorTensorOpMultiplicandCongruous64b { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicandCongruous64b; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicandCongruous64b(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicandCongruous64b(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajorTensorOpMultiplicandCongruous64b packed(TensorCoord const &extent) { -+ return RowMajorTensorOpMultiplicandCongruous64b(extent.column()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.column(), coord.row())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.strided(), coord.contiguous()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear -+/// memory and Crosswise size (in elements). -+struct TensorOpMultiplicand64bCrosswise { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = 64; -+ static int const kElementsPerAccess = 1; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Stride data member. -+ Stride stride_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicand64bCrosswise(Index ldm = 0) : stride_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicand64bCrosswise(Stride stride) : stride_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorOpMultiplicand64bCrosswise packed(TensorCoord const &extent) { -+ return TensorOpMultiplicand64bCrosswise(extent[0]); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ -+ int tc = coord.contiguous() / 16; -+ int ts = coord.strided() / 16; -+ -+ int c = coord.contiguous() % 16; -+ int s = coord.strided() % 16; -+ -+ int k_group = c / 4; -+ int access_s = s / 2; -+ -+ int row = access_s % 4; -+ int bank = ((k_group & 2) << 2) ^ ((s % 2) << 3) + (c % 4) * 2 + (access_s / 4) ^ (k_group & 1); -+ -+ int smem_row = (k_group * 4 + row) + tc * 16; -+ int smem_col = ts * 16 + bank; -+ -+ LongIndex offset = smem_row * stride_[0] + smem_col; -+ -+ return offset; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return stride_; } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return stride_; } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent[1] * stride_[0]; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear -+/// memory and Crosswise size (in elements). -+struct ColumnMajorTensorOpMultiplicand64bCrosswise { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicand64bCrosswise; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicand64bCrosswise(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicand64bCrosswise(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajorTensorOpMultiplicand64bCrosswise packed(TensorCoord const &extent) { -+ return ColumnMajorTensorOpMultiplicand64bCrosswise(extent.column()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.row(), coord.column())); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear -+/// memory and Crosswise size (in elements). 
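The column-major adapter above and the row-major adapter that follows do no index math of their own; each stores the pitch-linear base layout and swaps the MatrixCoord axes on the way in. The stand-alone sketch below is an editorial illustration of that delegation pattern only; the lightweight coord structs and the BaseLayout parameter are assumptions made to keep it self-contained.

// Generic form of the adapter pattern used by the ColumnMajor*/RowMajor* wrappers.
#include <cstdint>

struct PitchLinearCoordLite { int contiguous, strided; };
struct MatrixCoordLite { int row, column; };

template <typename BaseLayout, bool kRowMajor>
struct MatrixViewOfPitchLinear {
  BaseLayout base;

  int64_t operator()(MatrixCoordLite coord) const {
    // Row-major: the column index is the contiguous (fast-moving) dimension.
    // Column-major: the row index is the contiguous dimension.
    PitchLinearCoordLite pl = kRowMajor
        ? PitchLinearCoordLite{coord.column, coord.row}
        : PitchLinearCoordLite{coord.row, coord.column};
    return base(pl);
  }
};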
-+struct RowMajorTensorOpMultiplicand64bCrosswise { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicand64bCrosswise; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicand64bCrosswise(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicand64bCrosswise(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajorTensorOpMultiplicand64bCrosswise packed(TensorCoord const &extent) { -+ return RowMajorTensorOpMultiplicand64bCrosswise(extent.row()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.column(), coord.row())); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear -+/// memory and Crosswise size (in elements). -+struct TensorOpMultiplicandCongruous128b { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = 128; -+ static int const kElementsPerAccess = 1; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Stride data member. -+ Stride stride_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCongruous128b(Index ldm = 0) : stride_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCongruous128b(Stride stride) : stride_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorOpMultiplicandCongruous128b packed(TensorCoord const &extent) { -+ return TensorOpMultiplicandCongruous128b(extent[0]); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ -+ Index tc = coord.contiguous() / 8; -+ Index ts = coord.strided() / 4; -+ -+ Index c = coord.contiguous() % 8; -+ Index s = coord.strided() % 4; -+ -+ Index k_index = (c / 2); -+ -+ Index bank = (((c & 1) * 4) | (s ^ k_index)); -+ -+ LongIndex offset = tc * 8 + bank + (ts * 4 + k_index) * stride_[0]; -+ -+ return offset; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return stride_; } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return stride_; } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent[1] * stride_[0]; -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ return TensorCoord(); -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a column-major view of pitch-linear memory to -+/// TensorOpMultiplicand -+struct ColumnMajorTensorOpMultiplicandCongruous128b { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicandCongruous128b; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicandCongruous128b(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicandCongruous128b(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajorTensorOpMultiplicandCongruous128b packed(TensorCoord const &extent) { -+ return ColumnMajorTensorOpMultiplicandCongruous128b(extent.row()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
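As a quick sanity check on the 128-bit congruous swizzle, the arithmetic of TensorOpMultiplicandCongruous128b::operator() above can be evaluated at compile time. This worked example is an editorial addition (not part of the upstream patch); the leading dimension of 8 is assumed.

#include <cstdint>

// Mirrors TensorOpMultiplicandCongruous128b::operator() for host-side inspection.
constexpr int64_t congruous128b_offset(int contiguous, int strided, int64_t ldm) {
  int tc = contiguous / 8, ts = strided / 4;
  int c  = contiguous % 8, s  = strided % 4;
  int k_index = c / 2;
  int bank = ((c & 1) * 4) | (s ^ k_index);
  return tc * 8 + bank + (ts * 4 + k_index) * ldm;
}

// coord (3, 1): tc = 0, ts = 0, c = 3, s = 1, k_index = 1,
// bank = (1 * 4) | (1 ^ 1) = 4, offset = 0 + 4 + (0 + 1) * 8 = 12.
static_assert(congruous128b_offset(3, 1, 8) == 12, "worked example");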
-+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.row(), coord.column())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.contiguous(), coord.strided()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a row-major view of pitch-linear memory to -+/// TensorOpMultiplicand -+struct RowMajorTensorOpMultiplicandCongruous128b { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicandCongruous128b; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicandCongruous128b(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicandCongruous128b(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajorTensorOpMultiplicandCongruous128b packed(TensorCoord const &extent) { -+ return RowMajorTensorOpMultiplicandCongruous128b(extent.column()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.column(), coord.row())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.strided(), coord.contiguous()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear -+/// memory and Crosswise size (in elements). -+struct TensorOpMultiplicandCrosswise128x4 { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = 128; -+ static int const kElementsPerAccess = 1; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Stride data member. -+ Stride stride_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCrosswise128x4(Index ldm = 0) : stride_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCrosswise128x4(Stride stride) : stride_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorOpMultiplicandCrosswise128x4 packed(TensorCoord const &extent) { -+ return TensorOpMultiplicandCrosswise128x4(extent[0]); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
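Every layout in this header models the same minimal interface: a static packed(extent) factory, operator()(coord) for the linear offset, stride(), and capacity(extent). A generic helper written against that interface therefore works with any of them. The sketch below is an editorial illustration, not code from the patch; the Layout template parameter and the float buffer are assumptions.

// Touches every element of an extent through an arbitrary pitch-linear layout,
// letting the layout decide where each (contiguous, strided) coordinate lands.
template <typename Layout>
void clear_tile(typename Layout::TensorCoord extent, float *buffer) {
  Layout layout = Layout::packed(extent);
  for (int s = 0; s < extent.strided(); ++s) {
    for (int c = 0; c < extent.contiguous(); ++c) {
      typename Layout::TensorCoord coord(c, s);
      buffer[layout(coord)] = 0.0f;   // swizzled offset computed by the layout
    }
  }
}
// e.g. clear_tile<cutlass::layout::TensorOpMultiplicandCrosswise128x4>({16, 16}, smem);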
-+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ -+ Index tc = coord.contiguous() / 8; -+ Index ts = coord.strided() / 8; -+ -+ Index c = coord.contiguous() % 8; -+ Index s = coord.strided() % 8; -+ -+ Index liq = c % 4; -+ -+ Index bank = liq + ((s & 1) * 4) ^ (c & 4); -+ -+ Index k_index = (c & 4) + (s / 4) * 2 + ((s & 2) / 2); -+ -+ LongIndex offset = (tc * 8 + k_index) * stride_[0] + ts * 8 + bank; -+ -+ return offset; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return stride_; } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return stride_; } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent[1] * stride_[0]; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a column-major view of pitch-linear memory to -+/// TensorOpMultiplicand -+struct ColumnMajorTensorOpMultiplicandCrosswise128x4 { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicandCrosswise128x4; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicandCrosswise128x4(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicandCrosswise128x4(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajorTensorOpMultiplicandCrosswise128x4 packed(TensorCoord const &extent) { -+ return ColumnMajorTensorOpMultiplicandCrosswise128x4(extent.column()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. 
-+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.row(), coord.column())); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a row-major view of pitch-linear memory to -+/// TensorOpMultiplicand -+struct RowMajorTensorOpMultiplicandCrosswise128x4 { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicandCrosswise128x4; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicandCrosswise128x4(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicandCrosswise128x4(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajorTensorOpMultiplicandCrosswise128x4 packed(TensorCoord const &extent) { -+ return RowMajorTensorOpMultiplicandCrosswise128x4(extent.row()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.column(), coord.row())); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace layout -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/layout/vector.h b/3rdparty/cutlass/include/cutlass/layout/vector.h -new file mode 100644 -index 0000000..e9ad6da ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/layout/vector.h -@@ -0,0 +1,104 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines layout functions used for rank=1 vectors. -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+ -+namespace cutlass { -+namespace layout { -+ -+/// Tensor layout for densely packed vectors. 
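PackedVectorLayout, defined just below, is the identity mapping for rank-1 tensors: the offset of a coordinate is coord[0] and capacity equals the extent. The minimal usage sketch here is an editorial addition; the Lite stand-in mirrors the real class only to keep the example self-contained.

#include <cstdint>
#include <vector>

// Hypothetical stand-in for cutlass::layout::PackedVectorLayout; the real class
// behaves the same way for offsets and capacity.
struct PackedVectorLayoutLite {
  int64_t operator()(int coord) const { return coord; }
  int64_t capacity(int size) const { return size; }
};

int main() {
  PackedVectorLayoutLite layout;
  std::vector<float> storage(layout.capacity(8));   // 8 elements, densely packed
  storage[layout(3)] = 1.0f;                        // element 3 maps to offset 3
  return 0;
}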
-+class PackedVectorLayout { -+public: -+ /// Logical rank of tensor -+ static int const kRank = 1; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = Coord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ -+ // -+ // No actual stride vector stored -+ // -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ PackedVectorLayout() { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static PackedVectorLayout packed(TensorCoord const &size) { -+ return PackedVectorLayout(); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return coord[0]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return make_Coord(1); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &size) const { -+ return size[0]; -+ } -+}; -+ -+} // namespace layout -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/matrix.h b/3rdparty/cutlass/include/cutlass/matrix.h -new file mode 100644 -index 0000000..ba9ffbb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/matrix.h -@@ -0,0 +1,14129 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* -+ \file -+ \brief Matrix classes with value semantics. 
-+*/ -+ -+#pragma once -+ -+#if !defined(__CUDACC_RTC__) -+#include -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/layout/matrix.h" -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Primary template with partial specializations to follow -+template struct Matrix; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 1-by-2 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 1; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 2; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 2; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 1-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 1-by-2 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[1] = data[1]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 1 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 1 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { 
-+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x2(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x2(v, i, 0); -+ } -+ -+ /// Forms a 1-by-2 matrix by horizontally concatenating an Element with an Element -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Element lhs, Element rhs) { -+ return Matrix( -+ lhs, rhs); -+ } -+ -+ /// Concatenates this matrix with a an Element to form a 1-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Element rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 1-by-2 matrix to form a 1-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 1-by-2 matrix to form a 2-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-2 matrix to form a 3-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 3-by-2 matrix to form a 4-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Elementwise add operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (1-by-2) -+ 
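Because every Matrix specialization in this file has value semantics and a flat row-major Array for storage, typical usage is plain arithmetic on stack objects. The sketch below is an editorial illustration of the 1-by-2 specialization being defined here; it assumes the standard CUTLASS spellings (cutlass/matrix.h, Matrix1x2, make_Matrix1x2) that appear elsewhere in this header.

#include "cutlass/matrix.h"

// Host-side demo of the 1-by-2 specialization; all objects are plain values.
float matrix1x2_demo() {
  cutlass::Matrix1x2<float> a = cutlass::make_Matrix1x2(1.0f, 2.0f);
  cutlass::Matrix1x2<float> b = cutlass::Matrix1x2<float>::ones();

  cutlass::Matrix1x2<float> sum = a + b;                 // elementwise: (2, 3)
  cutlass::Matrix<float, 2, 1> column = a.transpose();   // 2-by-1 column vector

  return sum.at(0, 1) + a.dot(b) + column.at(1, 0);      // 3 + 3 + 2 = 8
}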
CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 1-by-1-by-2 -+ CUTLASS_HOST_DEVICE -+ Element product(Matrix const &rhs, Element accum = Element()) const { -+ -+ // k=0 -+ accum += data[0] * rhs.data[0]; -+ -+ // k=1 -+ accum += data[1] * rhs.data[1]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-1-by-2 -+ CUTLASS_HOST_DEVICE -+ Element operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 1-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 1-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Matrix product of size 1-by-3-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-3-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 1-by-4-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * 
rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-4-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Dot product of vectors with extent 2 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ return accum; -+ } -+ -+ /// Dot product of vectors with extent 2 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ return accum; -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ -+ return accum; -+ } -+ -+}; -+ -+/// Template alias for 1-by-2 matrix -+template -+using Matrix1x2 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix1x2 make_Matrix1x2( -+ Element _0_0, Element _0_1 -+) { -+ return Matrix1x2( -+ _0_0, _0_1 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 1-by-3 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 1; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 3; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 3; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 1-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 1-by-3 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1, Element _0_2 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() 
const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[1] = data[1]; -+ mt.data[2] = data[2]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 1 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 1 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 2] = m.data[2]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x3(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x3(v, i, 0); -+ } -+ -+ /// Forms a 1-by-3 matrix by horizontally concatenating an Element with a 1-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Element lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs, rhs.at(0, 0), rhs.at(0, 1)); -+ } -+ -+ /// Forms a 1-by-3 matrix by horizontally concatenating a 1-by-2 matrix with an Element -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Element rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), rhs); -+ } -+ -+ /// Concatenates this matrix with a an Element to form a 1-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Element rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 1-by-3 matrix to form a 2-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, 
rhs); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-3 matrix to form a 3-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 3-by-3 matrix to form a 4-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Elementwise add operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ result.data[2] = data[2] + rhs.data[2]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ data[2] += rhs.data[2]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ result.data[2] = data[2] - rhs.data[2]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ data[2] -= rhs.data[2]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ result.data[2] = data[2] * rhs.data[2]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ result.data[2] = data[2] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ data[2] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ result.data[2] = data[2] / rhs.data[2]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ result.data[2] = data[2] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ data[2] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ 
Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ data[2] /= rhs.data[2]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 1-by-1-by-3 -+ CUTLASS_HOST_DEVICE -+ Element product(Matrix const &rhs, Element accum = Element()) const { -+ -+ // k=0 -+ accum += data[0] * rhs.data[0]; -+ -+ // k=1 -+ accum += data[1] * rhs.data[1]; -+ -+ // k=2 -+ accum += data[2] * rhs.data[2]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-1-by-3 -+ CUTLASS_HOST_DEVICE -+ Element operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 1-by-2-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[4]; -+ accum.data[1] += data[2] * rhs.data[5]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-2-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 1-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[6]; -+ accum.data[1] += data[2] * rhs.data[7]; -+ accum.data[2] += data[2] * rhs.data[8]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 1-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Matrix product of size 1-by-4-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[8]; -+ accum.data[1] += data[2] * rhs.data[9]; -+ accum.data[2] += data[2] * rhs.data[10]; -+ accum.data[3] += data[2] * rhs.data[11]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-4-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Dot product of vectors with extent 3 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += 
data[1] * rhs.data[1]; -+ accum += data[2] * rhs.data[2]; -+ return accum; -+ } -+ -+ /// Dot product of vectors with extent 3 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ accum += data[2] * rhs.data[2]; -+ return accum; -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ -+ return accum; -+ } -+ -+ /// Cross product -+ CUTLASS_HOST_DEVICE -+ Matrix cross(Matrix const &rhs) const { -+ return Matrix( -+ data[1] * rhs.data[2] - data[2] * rhs.data[1], -+ data[0] * rhs.data[2] - data[2] * rhs.data[1], -+ data[0] * rhs.data[1] - data[1] * rhs.data[0] -+ ); -+ } -+ -+}; -+ -+/// Template alias for 1-by-3 matrix -+template -+using Matrix1x3 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix1x3 make_Matrix1x3( -+ Element _0_0, Element _0_1, Element _0_2 -+) { -+ return Matrix1x3( -+ _0_0, _0_1, _0_2 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 1-by-4 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 1; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 4; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 4; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 1-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 1-by-4 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1, Element _0_2, Element _0_3 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; data[3] = _0_3; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ m.data[3] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[1] = data[1]; -+ mt.data[2] = data[2]; -+ mt.data[3] = data[3]; -+ -+ return mt; -+ 
} -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 1 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 1 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x4(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 3]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x4(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 3] = m.data[3]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x4(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x4(v, i, 0); -+ } -+ -+ /// Forms a 1-by-4 matrix by horizontally concatenating an Element with a 1-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Element lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs, rhs.at(0, 0), rhs.at(0, 1), rhs.at(0, 2)); -+ } -+ -+ /// Forms a 1-by-4 matrix by horizontally concatenating a 1-by-2 
matrix with a 1-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0), rhs.at(0, 1)); -+ } -+ -+ /// Forms a 1-by-4 matrix by horizontally concatenating a 1-by-3 matrix with an Element -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Element rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), lhs.at(0, 2), rhs); -+ } -+ -+ /// Concatenates this matrix with a a 1-by-4 matrix to form a 2-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-4 matrix to form a 3-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 3-by-4 matrix to form a 4-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Elementwise add operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ result.data[2] = data[2] + rhs.data[2]; -+ result.data[3] = data[3] + rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ data[2] += rhs.data[2]; -+ data[3] += rhs.data[3]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ result.data[2] = data[2] - rhs.data[2]; -+ result.data[3] = data[3] - rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ data[2] -= rhs.data[2]; -+ data[3] -= rhs.data[3]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ result.data[2] = data[2] * rhs.data[2]; -+ result.data[3] = data[3] * rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ result.data[2] = data[2] * s; -+ result.data[3] = data[3] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ data[2] *= s; -+ data[3] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix 
const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ result.data[2] = data[2] / rhs.data[2]; -+ result.data[3] = data[3] / rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ result.data[2] = data[2] / s; -+ result.data[3] = data[3] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ data[2] /= s; -+ data[3] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ data[2] /= rhs.data[2]; -+ data[3] /= rhs.data[3]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ m.data[3] = -m.data[3]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 1-by-1-by-4 -+ CUTLASS_HOST_DEVICE -+ Element product(Matrix const &rhs, Element accum = Element()) const { -+ -+ // k=0 -+ accum += data[0] * rhs.data[0]; -+ -+ // k=1 -+ accum += data[1] * rhs.data[1]; -+ -+ // k=2 -+ accum += data[2] * rhs.data[2]; -+ -+ // k=3 -+ accum += data[3] * rhs.data[3]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-1-by-4 -+ CUTLASS_HOST_DEVICE -+ Element operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 1-by-2-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[4]; -+ accum.data[1] += data[2] * rhs.data[5]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[6]; -+ accum.data[1] += data[3] * rhs.data[7]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-2-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 1-by-3-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[6]; -+ accum.data[1] += data[2] * rhs.data[7]; -+ accum.data[2] += data[2] * rhs.data[8]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[9]; -+ accum.data[1] += data[3] * rhs.data[10]; -+ accum.data[2] += data[3] * rhs.data[11]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-3-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix 
operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 1-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[8]; -+ accum.data[1] += data[2] * rhs.data[9]; -+ accum.data[2] += data[2] * rhs.data[10]; -+ accum.data[3] += data[2] * rhs.data[11]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[12]; -+ accum.data[1] += data[3] * rhs.data[13]; -+ accum.data[2] += data[3] * rhs.data[14]; -+ accum.data[3] += data[3] * rhs.data[15]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 1-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Dot product of vectors with extent 4 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ accum += data[2] * rhs.data[2]; -+ accum += data[3] * rhs.data[3]; -+ return accum; -+ } -+ -+ /// Dot product of vectors with extent 4 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ accum += data[2] * rhs.data[2]; -+ accum += data[3] * rhs.data[3]; -+ return accum; -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ accum += data[3]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ accum += data[3] * data[3]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ -+ return accum; -+ } -+ -+}; -+ -+/// Template alias for 1-by-4 matrix -+template -+using Matrix1x4 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix1x4 make_Matrix1x4( -+ Element _0_0, Element _0_1, Element _0_2, Element _0_3 -+) { -+ return Matrix1x4( -+ _0_0, _0_1, _0_2, _0_3 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 2-by-1 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 2; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 1; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of 
elements in matrix -+ static int const kCount = 2; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 2-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 2-by-1 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, -+ Element _1_0 -+ ) { -+ -+ data[0] = _0_0; -+ data[1] = _1_0; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[1] = data[1]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 2 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 2 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 1 + j + 0]; -+ m.data[1] = data[i * 1 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 1 + j + 0] = m.data[0]; -+ data[i * 1 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_2x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_2x1(v, 0, j); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-1 matrix to form a 2-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-2 matrix to form a 2-by-3 matrix 
-+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-3 matrix to form a 2-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Forms a 2-by-1 matrix by vertically concatenating an Element with an Element -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Element upper, Element lower) { -+ return Matrix( -+ upper -+ , lower); -+ } -+ -+ /// Concatenates this matrix with a an Element to form a 3-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Element rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-1 matrix to form a 4-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Elementwise add operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ -+ result.data[1] = data[1] + rhs.data[1]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ -+ data[1] += rhs.data[1]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ -+ result.data[1] = data[1] - rhs.data[1]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ -+ data[1] -= rhs.data[1]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ -+ result.data[1] = data[1] * rhs.data[1]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ -+ result.data[1] = data[1] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ -+ data[1] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ -+ result.data[1] = data[1] / rhs.data[1]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ -+ result.data[1] = data[1] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (2-by-1) -+ CUTLASS_HOST_DEVICE 
-+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ -+ data[1] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ -+ data[1] /= rhs.data[1]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 2-by-1-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[1] * rhs.data[0]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-1-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-1-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Matrix product of size 2-by-2-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[1] * rhs.data[0]; -+ accum.data[3] += data[1] * rhs.data[1]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-2-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-3-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[1] * rhs.data[0]; -+ accum.data[4] += data[1] * rhs.data[1]; -+ accum.data[5] += data[1] * rhs.data[2]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-3-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-4-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[1] * rhs.data[0]; -+ accum.data[5] += data[1] * rhs.data[1]; -+ accum.data[6] += data[1] * rhs.data[2]; -+ accum.data[7] += data[1] * rhs.data[3]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-4-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Dot product of vectors with extent 2 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ return accum; -+ } -+ -+ /// Dot product of vectors with extent 2 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ return accum; -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = 
Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ -+ return accum; -+ } -+ -+}; -+ -+/// Template alias for 2-by-1 matrix -+template -+using Matrix2x1 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix2x1 make_Matrix2x1( -+ Element _0_0, -+ Element _1_0 -+) { -+ return Matrix2x1( -+ _0_0, -+ _1_0 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 2-by-2 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 2; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 2; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 4; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 2-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 2-by-2 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1, -+ Element _1_0, Element _1_1 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; -+ data[2] = _1_0; data[3] = _1_1; -+ } -+ -+ /// Constucts a 2-by-2 matrix from row vectors -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Matrix const &row_0, -+ Matrix const &row_1 -+ ) { -+ data[0] = row_0.data[0]; -+ data[1] = row_0.data[1]; -+ data[2] = row_1.data[0]; -+ data[3] = row_1.data[1]; -+ } -+ -+ /// Static method to construct a 2-by-2 matrix from column vectors -+ CUTLASS_HOST_DEVICE -+ static Matrix from_columns( -+ Matrix const &column_0, -+ Matrix const &column_1 -+ ) { -+ Matrix result; -+ -+ result.data[0] = column_0.data[0]; -+ result.data[1] = column_1.data[0]; -+ result.data[2] = column_0.data[1]; -+ result.data[3] = column_1.data[1]; -+ return result; -+ } -+ -+ /// Constructs an identity matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix identity() { -+ Matrix m; -+ -+ m.data[0] = Element(1); -+ m.data[3] = Element(1); -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ m.data[3] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = 
diag.data[0]; -+ m.data[3] = diag.data[1]; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[3] = diag.data[1]; -+ -+ return m; -+ } -+ -+ /// Gets an array of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Matrix diagonal() const { -+ Matrix diag; -+ -+ diag.data[0] = data[0]; -+ diag.data[1] = data[3]; -+ -+ return diag; -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[2] = data[1]; -+ mt.data[1] = data[2]; -+ mt.data[3] = data[3]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 2 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 2 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x2(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x2(v, i, 0); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 2] = m.data[1]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_2x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_2x1(v, 0, j); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix 
slice_2x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 1]; -+ m.data[2] = data[i * 2 + j + 2]; -+ m.data[3] = data[i * 2 + j + 3]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 1] = m.data[1]; -+ data[i * 2 + j + 2] = m.data[2]; -+ data[i * 2 + j + 3] = m.data[3]; -+ -+ return *this; -+ } -+ -+ /// Forms a 2-by-2 matrix by horizontally concatenating a 2-by-1 matrix with a 2-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), rhs.at(0, 0) -+ , lhs.at(1, 0), rhs.at(1, 0)); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-1 matrix to form a 2-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-2 matrix to form a 2-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Forms a 2-by-2 matrix by vertically concatenating a 1-by-2 matrix with a 1-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1) -+ , lower.at(0, 0), lower.at(0, 1)); -+ } -+ -+ /// Concatenates this matrix with a a 1-by-2 matrix to form a 3-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-2 matrix to form a 4-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Forms a 2-by-2 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Element A, Element B, -+ Element C, Element D) { -+ return Matrix( -+ A, B -+ , C, D -+ ); -+ } -+ -+ /// Elementwise add operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ -+ result.data[2] = data[2] + rhs.data[2]; -+ result.data[3] = data[3] + rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ -+ data[2] += rhs.data[2]; -+ data[3] += rhs.data[3]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ -+ result.data[2] = data[2] - rhs.data[2]; -+ result.data[3] = data[3] - rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ -+ data[2] -= rhs.data[2]; -+ data[3] -= rhs.data[3]; -+ -+ return *this; -+ } -+ -+ /// Elementwise 
multiply operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ -+ result.data[2] = data[2] * rhs.data[2]; -+ result.data[3] = data[3] * rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ -+ result.data[2] = data[2] * s; -+ result.data[3] = data[3] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ -+ data[2] *= s; -+ data[3] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ -+ result.data[2] = data[2] / rhs.data[2]; -+ result.data[3] = data[3] / rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ -+ result.data[2] = data[2] / s; -+ result.data[3] = data[3] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ -+ data[2] /= s; -+ data[3] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ -+ data[2] /= rhs.data[2]; -+ data[3] /= rhs.data[3]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ m.data[3] = -m.data[3]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 2-by-1-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[2] * rhs.data[0]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[1]; -+ accum.data[1] += data[3] * rhs.data[1]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-1-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[2] * rhs.data[0]; -+ accum.data[3] += data[2] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ 
accum.data[2] += data[3] * rhs.data[2]; -+ accum.data[3] += data[3] * rhs.data[3]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Matrix product of size 2-by-3-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[2] * rhs.data[0]; -+ accum.data[4] += data[2] * rhs.data[1]; -+ accum.data[5] += data[2] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ accum.data[3] += data[3] * rhs.data[3]; -+ accum.data[4] += data[3] * rhs.data[4]; -+ accum.data[5] += data[3] * rhs.data[5]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-3-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-4-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[2] * rhs.data[0]; -+ accum.data[5] += data[2] * rhs.data[1]; -+ accum.data[6] += data[2] * rhs.data[2]; -+ accum.data[7] += data[2] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ accum.data[4] += data[3] * rhs.data[4]; -+ accum.data[5] += data[3] * rhs.data[5]; -+ accum.data[6] += data[3] * rhs.data[6]; -+ accum.data[7] += data[3] * rhs.data[7]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-4-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ accum += data[3]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ accum += data[3] * data[3]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[3]; -+ -+ return accum; -+ } -+ -+ /// Returns 2-by-2 rotation matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix rotation(Element theta) { -+ Element c = fast_cos(theta); -+ Element s = fast_sin(theta); -+ -+ return Matrix( -+ c, -s, -+ s, c -+ ); -+ } -+ -+ /// Computes the determinant of a 2-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Element determinant(Element accum = Element()) const { -+ accum += data[0] * data[3] - data[1] * data[2]; -+ -+ return accum; -+ } -+ 
-+ /// Computes the inverse of a 2-by-2 matrix given -+ /// the matrix's determinant -+ CUTLASS_HOST_DEVICE -+ Matrix inverse(Element det) const { -+ return Matrix( -+ data[3], -data[1], -+ -data[2], data[0] -+ ) * (Element(1) / det); -+ } -+ -+ /// Computes the inverse of a 2-by-2 matrix. -+ CUTLASS_HOST_DEVICE -+ Matrix inverse() const { -+ return inverse(determinant()); -+ } -+ -+}; -+ -+/// Template alias for 2-by-2 matrix -+template -+using Matrix2x2 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix2x2 make_Matrix2x2( -+ Element _0_0, Element _0_1, -+ Element _1_0, Element _1_1 -+) { -+ return Matrix2x2( -+ _0_0, _0_1, -+ _1_0, _1_1 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 2-by-3 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 2; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 3; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 6; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 2-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 2-by-3 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1, Element _0_2, -+ Element _1_0, Element _1_1, Element _1_2 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; -+ data[3] = _1_0; data[4] = _1_1; data[5] = _1_2; -+ } -+ -+ /// Constucts a 2-by-3 matrix from row vectors -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Matrix const &row_0, -+ Matrix const &row_1 -+ ) { -+ data[0] = row_0.data[0]; -+ data[1] = row_0.data[1]; -+ data[2] = row_0.data[2]; -+ data[3] = row_1.data[0]; -+ data[4] = row_1.data[1]; -+ data[5] = row_1.data[2]; -+ } -+ -+ /// Static method to construct a 2-by-3 matrix from column vectors -+ CUTLASS_HOST_DEVICE -+ static Matrix from_columns( -+ Matrix const &column_0, -+ Matrix const &column_1, -+ Matrix const &column_2 -+ ) { -+ Matrix result; -+ -+ result.data[0] = column_0.data[0]; -+ result.data[1] = column_1.data[0]; -+ result.data[2] = column_2.data[0]; -+ result.data[3] = column_0.data[1]; -+ result.data[4] = column_1.data[1]; -+ result.data[5] = column_2.data[1]; -+ return result; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ m.data[3] = s; -+ m.data[4] = s; -+ m.data[5] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[3] = diag.data[1]; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from elements 
along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[3] = diag.data[1]; -+ -+ return m; -+ } -+ -+ /// Gets an array of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Matrix diagonal() const { -+ Matrix diag; -+ -+ diag.data[0] = data[0]; -+ diag.data[1] = data[3]; -+ -+ return diag; -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[2] = data[1]; -+ mt.data[4] = data[2]; -+ mt.data[1] = data[3]; -+ mt.data[3] = data[4]; -+ mt.data[5] = data[5]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 2 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 2 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 2] = m.data[2]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x3(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x3(v, i, 0); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 3]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional 
offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 3] = m.data[1]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_2x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_2x1(v, 0, j); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 3]; -+ m.data[3] = data[i * 3 + j + 4]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 3] = m.data[2]; -+ data[i * 3 + j + 4] = m.data[3]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 2]; -+ m.data[3] = data[i * 3 + j + 3]; -+ m.data[4] = data[i * 3 + j + 4]; -+ m.data[5] = data[i * 3 + j + 5]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 2] = m.data[2]; -+ data[i * 3 + j + 3] = m.data[3]; -+ data[i * 3 + j + 4] = m.data[4]; -+ data[i * 3 + j + 5] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Forms a 2-by-3 matrix by horizontally concatenating a 2-by-1 matrix with a 2-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1) -+ , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1)); -+ } -+ -+ /// Forms a 2-by-3 matrix by horizontally concatenating a 2-by-2 matrix with a 2-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0) -+ , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0)); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-1 matrix to form a 2-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Forms a 2-by-3 matrix by vertically concatenating a 1-by-3 matrix with a 1-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2)); -+ } -+ -+ /// Concatenates this matrix with a a 1-by-3 matrix to form a 3-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-3 matrix to form a 4-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Forms a 2-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Element A, Matrix const & B, -+ Element C, Matrix const & D) { -+ return Matrix( -+ A, B.at(0, 0), B.at(0, 1) -+ , C, D.at(0, 0), D.at(0, 1) -+ ); -+ } -+ -+ /// Forms a 
2-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Element B, -+ Matrix const & C, Element D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B -+ , C.at(0, 0), C.at(0, 1), D -+ ); -+ } -+ -+ /// Elementwise add operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ result.data[2] = data[2] + rhs.data[2]; -+ -+ result.data[3] = data[3] + rhs.data[3]; -+ result.data[4] = data[4] + rhs.data[4]; -+ result.data[5] = data[5] + rhs.data[5]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ data[2] += rhs.data[2]; -+ -+ data[3] += rhs.data[3]; -+ data[4] += rhs.data[4]; -+ data[5] += rhs.data[5]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ result.data[2] = data[2] - rhs.data[2]; -+ -+ result.data[3] = data[3] - rhs.data[3]; -+ result.data[4] = data[4] - rhs.data[4]; -+ result.data[5] = data[5] - rhs.data[5]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ data[2] -= rhs.data[2]; -+ -+ data[3] -= rhs.data[3]; -+ data[4] -= rhs.data[4]; -+ data[5] -= rhs.data[5]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ result.data[2] = data[2] * rhs.data[2]; -+ -+ result.data[3] = data[3] * rhs.data[3]; -+ result.data[4] = data[4] * rhs.data[4]; -+ result.data[5] = data[5] * rhs.data[5]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ result.data[2] = data[2] * s; -+ -+ result.data[3] = data[3] * s; -+ result.data[4] = data[4] * s; -+ result.data[5] = data[5] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ data[2] *= s; -+ -+ data[3] *= s; -+ data[4] *= s; -+ data[5] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ result.data[2] = data[2] / rhs.data[2]; -+ -+ result.data[3] = data[3] / rhs.data[3]; -+ result.data[4] = data[4] / rhs.data[4]; -+ result.data[5] 
= data[5] / rhs.data[5]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ result.data[2] = data[2] / s; -+ -+ result.data[3] = data[3] / s; -+ result.data[4] = data[4] / s; -+ result.data[5] = data[5] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ data[2] /= s; -+ -+ data[3] /= s; -+ data[4] /= s; -+ data[5] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ data[2] /= rhs.data[2]; -+ -+ data[3] /= rhs.data[3]; -+ data[4] /= rhs.data[4]; -+ data[5] /= rhs.data[5]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ m.data[3] = -m.data[3]; -+ m.data[4] = -m.data[4]; -+ m.data[5] = -m.data[5]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 2-by-1-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[3] * rhs.data[0]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[1]; -+ accum.data[1] += data[4] * rhs.data[1]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[2]; -+ accum.data[1] += data[5] * rhs.data[2]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-1-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-2-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[3] * rhs.data[0]; -+ accum.data[3] += data[3] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ accum.data[2] += data[4] * rhs.data[2]; -+ accum.data[3] += data[4] * rhs.data[3]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[4]; -+ accum.data[1] += data[2] * rhs.data[5]; -+ accum.data[2] += data[5] * rhs.data[4]; -+ accum.data[3] += data[5] * rhs.data[5]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-2-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[3] * rhs.data[0]; -+ accum.data[4] += data[3] * rhs.data[1]; -+ accum.data[5] += data[3] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ 
accum.data[2] += data[1] * rhs.data[5]; -+ accum.data[3] += data[4] * rhs.data[3]; -+ accum.data[4] += data[4] * rhs.data[4]; -+ accum.data[5] += data[4] * rhs.data[5]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[6]; -+ accum.data[1] += data[2] * rhs.data[7]; -+ accum.data[2] += data[2] * rhs.data[8]; -+ accum.data[3] += data[5] * rhs.data[6]; -+ accum.data[4] += data[5] * rhs.data[7]; -+ accum.data[5] += data[5] * rhs.data[8]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Matrix product of size 2-by-4-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[3] * rhs.data[0]; -+ accum.data[5] += data[3] * rhs.data[1]; -+ accum.data[6] += data[3] * rhs.data[2]; -+ accum.data[7] += data[3] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ accum.data[4] += data[4] * rhs.data[4]; -+ accum.data[5] += data[4] * rhs.data[5]; -+ accum.data[6] += data[4] * rhs.data[6]; -+ accum.data[7] += data[4] * rhs.data[7]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[8]; -+ accum.data[1] += data[2] * rhs.data[9]; -+ accum.data[2] += data[2] * rhs.data[10]; -+ accum.data[3] += data[2] * rhs.data[11]; -+ accum.data[4] += data[5] * rhs.data[8]; -+ accum.data[5] += data[5] * rhs.data[9]; -+ accum.data[6] += data[5] * rhs.data[10]; -+ accum.data[7] += data[5] * rhs.data[11]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-4-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ accum += data[3]; -+ accum += data[4]; -+ accum += data[5]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ accum += data[3] * data[3]; -+ accum += data[4] * data[4]; -+ accum += data[5] * data[5]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[4]; -+ -+ return accum; -+ } -+ -+}; -+ -+/// Template alias for 2-by-3 matrix -+template -+using Matrix2x3 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix2x3 make_Matrix2x3( -+ Element _0_0, Element _0_1, Element _0_2, -+ Element _1_0, Element _1_1, Element _1_2 -+) { -+ return Matrix2x3( -+ _0_0, _0_1, _0_2, -+ _1_0, _1_1, _1_2 -+ ); -+} -+ -+ 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 2-by-4 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 2; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 4; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 8; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 2-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 2-by-4 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1, Element _0_2, Element _0_3, -+ Element _1_0, Element _1_1, Element _1_2, Element _1_3 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; data[3] = _0_3; -+ data[4] = _1_0; data[5] = _1_1; data[6] = _1_2; data[7] = _1_3; -+ } -+ -+ /// Constucts a 2-by-4 matrix from row vectors -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Matrix const &row_0, -+ Matrix const &row_1 -+ ) { -+ data[0] = row_0.data[0]; -+ data[1] = row_0.data[1]; -+ data[2] = row_0.data[2]; -+ data[3] = row_0.data[3]; -+ data[4] = row_1.data[0]; -+ data[5] = row_1.data[1]; -+ data[6] = row_1.data[2]; -+ data[7] = row_1.data[3]; -+ } -+ -+ /// Static method to construct a 2-by-4 matrix from column vectors -+ CUTLASS_HOST_DEVICE -+ static Matrix from_columns( -+ Matrix const &column_0, -+ Matrix const &column_1, -+ Matrix const &column_2, -+ Matrix const &column_3 -+ ) { -+ Matrix result; -+ -+ result.data[0] = column_0.data[0]; -+ result.data[1] = column_1.data[0]; -+ result.data[2] = column_2.data[0]; -+ result.data[3] = column_3.data[0]; -+ result.data[4] = column_0.data[1]; -+ result.data[5] = column_1.data[1]; -+ result.data[6] = column_2.data[1]; -+ result.data[7] = column_3.data[1]; -+ return result; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ m.data[3] = s; -+ m.data[4] = s; -+ m.data[5] = s; -+ m.data[6] = s; -+ m.data[7] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[3] = diag.data[1]; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[3] = diag.data[1]; -+ -+ return m; -+ } -+ -+ /// Gets an array of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Matrix diagonal() const { -+ Matrix diag; -+ -+ diag.data[0] = data[0]; -+ diag.data[1] = data[3]; -+ -+ return diag; -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ 
Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[2] = data[1]; -+ mt.data[4] = data[2]; -+ mt.data[6] = data[3]; -+ mt.data[1] = data[4]; -+ mt.data[3] = data[5]; -+ mt.data[5] = data[6]; -+ mt.data[7] = data[7]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 2 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 2 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x4(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 3]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x4(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 3] = m.data[3]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x4(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x4(v, i, 0); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix 
slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 4]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 4] = m.data[1]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_2x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_2x1(v, 0, j); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 4]; -+ m.data[3] = data[i * 4 + j + 5]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 4] = m.data[2]; -+ data[i * 4 + j + 5] = m.data[3]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 4]; -+ m.data[4] = data[i * 4 + j + 5]; -+ m.data[5] = data[i * 4 + j + 6]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 4] = m.data[3]; -+ data[i * 4 + j + 5] = m.data[4]; -+ data[i * 4 + j + 6] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x4(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 3]; -+ m.data[4] = data[i * 4 + j + 4]; -+ m.data[5] = data[i * 4 + j + 5]; -+ m.data[6] = data[i * 4 + j + 6]; -+ m.data[7] = data[i * 4 + j + 7]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x4(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 3] = m.data[3]; -+ data[i * 4 + j + 4] = m.data[4]; -+ data[i * 4 + j + 5] = m.data[5]; -+ data[i * 4 + j + 6] = m.data[6]; -+ data[i * 4 + j + 7] = m.data[7]; -+ -+ return *this; -+ } -+ -+ /// Forms a 2-by-4 matrix by horizontally concatenating a 2-by-1 matrix with a 2-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1), rhs.at(0, 2) -+ , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1), rhs.at(1, 2)); -+ } -+ -+ /// Forms a 2-by-4 matrix by horizontally concatenating a 2-by-2 matrix with a 2-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0), rhs.at(0, 1) -+ , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0), rhs.at(1, 1)); -+ } -+ -+ /// 
Forms a 2-by-4 matrix by horizontally concatenating a 2-by-3 matrix with a 2-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), lhs.at(0, 2), rhs.at(0, 0) -+ , lhs.at(1, 0), lhs.at(1, 1), lhs.at(1, 2), rhs.at(1, 0)); -+ } -+ -+ /// Forms a 2-by-4 matrix by vertically concatenating a 1-by-4 matrix with a 1-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3)); -+ } -+ -+ /// Concatenates this matrix with a a 1-by-4 matrix to form a 3-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-4 matrix to form a 4-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Forms a 2-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Element A, Matrix const & B, -+ Element C, Matrix const & D) { -+ return Matrix( -+ A, B.at(0, 0), B.at(0, 1), B.at(0, 2) -+ , C, D.at(0, 0), D.at(0, 1), D.at(0, 2) -+ ); -+ } -+ -+ /// Forms a 2-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) -+ , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) -+ ); -+ } -+ -+ /// Forms a 2-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Element B, -+ Matrix const & C, Element D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), A.at(0, 2), B -+ , C.at(0, 0), C.at(0, 1), C.at(0, 2), D -+ ); -+ } -+ -+ /// Elementwise add operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ result.data[2] = data[2] + rhs.data[2]; -+ result.data[3] = data[3] + rhs.data[3]; -+ -+ result.data[4] = data[4] + rhs.data[4]; -+ result.data[5] = data[5] + rhs.data[5]; -+ result.data[6] = data[6] + rhs.data[6]; -+ result.data[7] = data[7] + rhs.data[7]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ data[2] += rhs.data[2]; -+ data[3] += rhs.data[3]; -+ -+ data[4] += rhs.data[4]; -+ data[5] += rhs.data[5]; -+ data[6] += rhs.data[6]; -+ data[7] += rhs.data[7]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ result.data[2] = data[2] - rhs.data[2]; -+ result.data[3] = data[3] - rhs.data[3]; -+ -+ result.data[4] = data[4] - rhs.data[4]; -+ result.data[5] = data[5] - rhs.data[5]; -+ result.data[6] = data[6] - rhs.data[6]; -+ result.data[7] = data[7] - rhs.data[7]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix 
operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ data[2] -= rhs.data[2]; -+ data[3] -= rhs.data[3]; -+ -+ data[4] -= rhs.data[4]; -+ data[5] -= rhs.data[5]; -+ data[6] -= rhs.data[6]; -+ data[7] -= rhs.data[7]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ result.data[2] = data[2] * rhs.data[2]; -+ result.data[3] = data[3] * rhs.data[3]; -+ -+ result.data[4] = data[4] * rhs.data[4]; -+ result.data[5] = data[5] * rhs.data[5]; -+ result.data[6] = data[6] * rhs.data[6]; -+ result.data[7] = data[7] * rhs.data[7]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ result.data[2] = data[2] * s; -+ result.data[3] = data[3] * s; -+ -+ result.data[4] = data[4] * s; -+ result.data[5] = data[5] * s; -+ result.data[6] = data[6] * s; -+ result.data[7] = data[7] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ data[2] *= s; -+ data[3] *= s; -+ -+ data[4] *= s; -+ data[5] *= s; -+ data[6] *= s; -+ data[7] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ result.data[2] = data[2] / rhs.data[2]; -+ result.data[3] = data[3] / rhs.data[3]; -+ -+ result.data[4] = data[4] / rhs.data[4]; -+ result.data[5] = data[5] / rhs.data[5]; -+ result.data[6] = data[6] / rhs.data[6]; -+ result.data[7] = data[7] / rhs.data[7]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ result.data[2] = data[2] / s; -+ result.data[3] = data[3] / s; -+ -+ result.data[4] = data[4] / s; -+ result.data[5] = data[5] / s; -+ result.data[6] = data[6] / s; -+ result.data[7] = data[7] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ data[2] /= s; -+ data[3] /= s; -+ -+ data[4] /= s; -+ data[5] /= s; -+ data[6] /= s; -+ data[7] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ data[2] /= rhs.data[2]; -+ data[3] /= rhs.data[3]; -+ -+ data[4] /= rhs.data[4]; -+ 
data[5] /= rhs.data[5]; -+ data[6] /= rhs.data[6]; -+ data[7] /= rhs.data[7]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ m.data[3] = -m.data[3]; -+ m.data[4] = -m.data[4]; -+ m.data[5] = -m.data[5]; -+ m.data[6] = -m.data[6]; -+ m.data[7] = -m.data[7]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 2-by-1-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[4] * rhs.data[0]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[1]; -+ accum.data[1] += data[5] * rhs.data[1]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[2]; -+ accum.data[1] += data[6] * rhs.data[2]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[3]; -+ accum.data[1] += data[7] * rhs.data[3]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-1-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-2-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[4] * rhs.data[0]; -+ accum.data[3] += data[4] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ accum.data[2] += data[5] * rhs.data[2]; -+ accum.data[3] += data[5] * rhs.data[3]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[4]; -+ accum.data[1] += data[2] * rhs.data[5]; -+ accum.data[2] += data[6] * rhs.data[4]; -+ accum.data[3] += data[6] * rhs.data[5]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[6]; -+ accum.data[1] += data[3] * rhs.data[7]; -+ accum.data[2] += data[7] * rhs.data[6]; -+ accum.data[3] += data[7] * rhs.data[7]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-2-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-3-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[4] * rhs.data[0]; -+ accum.data[4] += data[4] * rhs.data[1]; -+ accum.data[5] += data[4] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ accum.data[3] += data[5] * rhs.data[3]; -+ accum.data[4] += data[5] * rhs.data[4]; -+ accum.data[5] += data[5] * rhs.data[5]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[6]; -+ accum.data[1] += data[2] * rhs.data[7]; -+ accum.data[2] += data[2] * rhs.data[8]; -+ accum.data[3] += data[6] * rhs.data[6]; -+ accum.data[4] += data[6] * rhs.data[7]; -+ accum.data[5] += data[6] * rhs.data[8]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[9]; -+ accum.data[1] += data[3] * rhs.data[10]; -+ accum.data[2] += data[3] * rhs.data[11]; -+ accum.data[3] += data[7] * rhs.data[9]; -+ accum.data[4] += data[7] * rhs.data[10]; -+ accum.data[5] += data[7] * rhs.data[11]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-3-by-4 -+ 
CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[4] * rhs.data[0]; -+ accum.data[5] += data[4] * rhs.data[1]; -+ accum.data[6] += data[4] * rhs.data[2]; -+ accum.data[7] += data[4] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ accum.data[4] += data[5] * rhs.data[4]; -+ accum.data[5] += data[5] * rhs.data[5]; -+ accum.data[6] += data[5] * rhs.data[6]; -+ accum.data[7] += data[5] * rhs.data[7]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[8]; -+ accum.data[1] += data[2] * rhs.data[9]; -+ accum.data[2] += data[2] * rhs.data[10]; -+ accum.data[3] += data[2] * rhs.data[11]; -+ accum.data[4] += data[6] * rhs.data[8]; -+ accum.data[5] += data[6] * rhs.data[9]; -+ accum.data[6] += data[6] * rhs.data[10]; -+ accum.data[7] += data[6] * rhs.data[11]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[12]; -+ accum.data[1] += data[3] * rhs.data[13]; -+ accum.data[2] += data[3] * rhs.data[14]; -+ accum.data[3] += data[3] * rhs.data[15]; -+ accum.data[4] += data[7] * rhs.data[12]; -+ accum.data[5] += data[7] * rhs.data[13]; -+ accum.data[6] += data[7] * rhs.data[14]; -+ accum.data[7] += data[7] * rhs.data[15]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ accum += data[3]; -+ accum += data[4]; -+ accum += data[5]; -+ accum += data[6]; -+ accum += data[7]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ accum += data[3] * data[3]; -+ accum += data[4] * data[4]; -+ accum += data[5] * data[5]; -+ accum += data[6] * data[6]; -+ accum += data[7] * data[7]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[5]; -+ -+ return accum; -+ } -+ -+}; -+ -+/// Template alias for 2-by-4 matrix -+template -+using Matrix2x4 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix2x4 make_Matrix2x4( -+ Element _0_0, Element _0_1, Element _0_2, Element _0_3, -+ Element _1_0, Element _1_1, Element _1_2, Element _1_3 -+) { -+ return Matrix2x4( -+ _0_0, _0_1, _0_2, _0_3, -+ _1_0, _1_1, _1_2, _1_3 -+ ); -+} -+ -+ 
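Reviewer note (not part of the deleted patch): the product() overloads of the 2-by-4 specialization above take the accumulator by value and return it, so each one behaves like a small GEMM update C = A * B + C_in, while operator* is the same product seeded with a zero accumulator. The sketch below illustrates that under the same assumptions as before; the Matrix2x1 and Matrix4x1 aliases, their scalar constructors, and uniform() are inferred from the repeating pattern of this header rather than from the hunk above.

#include <cutlass/matrix.h>  // assumed install path of the header shown in this patch

int main() {
  using Matrix2x1f = cutlass::Matrix2x1<float>;  // assumed alias, defined earlier in the header
  using Matrix4x1f = cutlass::Matrix4x1<float>;  // assumed alias, defined later in the header

  auto A = cutlass::make_Matrix2x4<float>(1, 2, 3, 4,
                                          5, 6, 7, 8);
  Matrix4x1f x(1, 1, 1, 1);                      // scalar-element constructor, mirroring the 2x4 one above

  // GEMM-style accumulate: y = A * x + y0 via the 2-by-1-by-4 product() shown above.
  Matrix2x1f y0 = Matrix2x1f::uniform(10.0f);
  Matrix2x1f y  = A.product(x, y0);              // rows of A sum to 10 and 26, so y == {20, 36}

  // operator* is the same product with a zero accumulator.
  Matrix2x1f z = A * x;                          // z == {10, 26}

  return (y.at(0, 0) == 20.0f && z.at(1, 0) == 26.0f) ? 0 : 1;
}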
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 3-by-1 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 3; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 1; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 3; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 3-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 3-by-1 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, -+ Element _1_0, -+ Element _2_0 -+ ) { -+ -+ data[0] = _0_0; -+ data[1] = _1_0; -+ data[2] = _2_0; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[1] = data[1]; -+ mt.data[2] = data[2]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 3 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 3 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 1 + j + 0]; -+ m.data[1] = data[i * 1 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 
0, int j = 0) { -+ -+ data[i * 1 + j + 0] = m.data[0]; -+ data[i * 1 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 1 + j + 0]; -+ m.data[1] = data[i * 1 + j + 1]; -+ m.data[2] = data[i * 1 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 1 + j + 0] = m.data[0]; -+ data[i * 1 + j + 1] = m.data[1]; -+ data[i * 1 + j + 2] = m.data[2]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_3x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_3x1(v, 0, j); -+ } -+ -+ /// Concatenates this matrix with a a 3-by-1 matrix to form a 3-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 3-by-2 matrix to form a 3-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 3-by-3 matrix to form a 3-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Forms a 3-by-1 matrix by vertically concatenating an Element with a 2-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Element upper, Matrix const & lower) { -+ return Matrix( -+ upper -+ , lower.at(0, 0) -+ , lower.at(1, 0)); -+ } -+ -+ /// Forms a 3-by-1 matrix by vertically concatenating a 2-by-1 matrix with an Element -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Element lower) { -+ return Matrix( -+ upper.at(0, 0) -+ , upper.at(1, 0) -+ , lower); -+ } -+ -+ /// Concatenates this matrix with a an Element to form a 4-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Element rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Elementwise add operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ -+ result.data[1] = data[1] + rhs.data[1]; -+ -+ result.data[2] = data[2] + rhs.data[2]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ -+ data[1] += rhs.data[1]; -+ -+ data[2] += rhs.data[2]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ -+ result.data[1] = data[1] - rhs.data[1]; -+ -+ result.data[2] = data[2] - rhs.data[2]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ -+ data[1] -= rhs.data[1]; -+ -+ data[2] -= rhs.data[2]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix 
const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ -+ result.data[1] = data[1] * rhs.data[1]; -+ -+ result.data[2] = data[2] * rhs.data[2]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ -+ result.data[1] = data[1] * s; -+ -+ result.data[2] = data[2] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ -+ data[1] *= s; -+ -+ data[2] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ -+ result.data[1] = data[1] / rhs.data[1]; -+ -+ result.data[2] = data[2] / rhs.data[2]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ -+ result.data[1] = data[1] / s; -+ -+ result.data[2] = data[2] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ -+ data[1] /= s; -+ -+ data[2] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ -+ data[1] /= rhs.data[1]; -+ -+ data[2] /= rhs.data[2]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 3-by-1-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[1] * rhs.data[0]; -+ accum.data[2] += data[2] * rhs.data[0]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-1-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-1-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Matrix product of size 3-by-2-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[1] * rhs.data[0]; -+ accum.data[3] += data[1] * rhs.data[1]; -+ accum.data[4] += data[2] * rhs.data[0]; -+ accum.data[5] += data[2] * rhs.data[1]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-2-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-3-by-1 -+ 
CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[1] * rhs.data[0]; -+ accum.data[4] += data[1] * rhs.data[1]; -+ accum.data[5] += data[1] * rhs.data[2]; -+ accum.data[6] += data[2] * rhs.data[0]; -+ accum.data[7] += data[2] * rhs.data[1]; -+ accum.data[8] += data[2] * rhs.data[2]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-3-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-4-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[1] * rhs.data[0]; -+ accum.data[5] += data[1] * rhs.data[1]; -+ accum.data[6] += data[1] * rhs.data[2]; -+ accum.data[7] += data[1] * rhs.data[3]; -+ accum.data[8] += data[2] * rhs.data[0]; -+ accum.data[9] += data[2] * rhs.data[1]; -+ accum.data[10] += data[2] * rhs.data[2]; -+ accum.data[11] += data[2] * rhs.data[3]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-4-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Dot product of vectors with extent 3 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ accum += data[2] * rhs.data[2]; -+ return accum; -+ } -+ -+ /// Dot product of vectors with extent 3 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ accum += data[2] * rhs.data[2]; -+ return accum; -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ -+ return accum; -+ } -+ -+ /// Cross product -+ CUTLASS_HOST_DEVICE -+ Matrix cross(Matrix const &rhs) const { -+ return Matrix( -+ data[1] * rhs.data[2] - data[2] * rhs.data[1], -+ data[0] * rhs.data[2] - data[2] * rhs.data[1], -+ data[0] * rhs.data[1] - data[1] * rhs.data[0] -+ ); -+ } -+ -+}; -+ -+/// Template alias for 3-by-1 matrix -+template -+using Matrix3x1 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix3x1 make_Matrix3x1( -+ Element _0_0, -+ Element _1_0, -+ Element _2_0 -+) { -+ return Matrix3x1( -+ _0_0, -+ _1_0, -+ _2_0 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 3-by-2 matrix template class 
definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 3; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 2; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 6; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 3-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 3-by-2 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1, -+ Element _1_0, Element _1_1, -+ Element _2_0, Element _2_1 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; -+ data[2] = _1_0; data[3] = _1_1; -+ data[4] = _2_0; data[5] = _2_1; -+ } -+ -+ /// Constucts a 3-by-2 matrix from row vectors -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Matrix const &row_0, -+ Matrix const &row_1, -+ Matrix const &row_2 -+ ) { -+ data[0] = row_0.data[0]; -+ data[1] = row_0.data[1]; -+ data[2] = row_1.data[0]; -+ data[3] = row_1.data[1]; -+ data[4] = row_2.data[0]; -+ data[5] = row_2.data[1]; -+ } -+ -+ /// Static method to construct a 3-by-2 matrix from column vectors -+ CUTLASS_HOST_DEVICE -+ static Matrix from_columns( -+ Matrix const &column_0, -+ Matrix const &column_1 -+ ) { -+ Matrix result; -+ -+ result.data[0] = column_0.data[0]; -+ result.data[1] = column_1.data[0]; -+ result.data[2] = column_0.data[1]; -+ result.data[3] = column_1.data[1]; -+ result.data[4] = column_0.data[2]; -+ result.data[5] = column_1.data[2]; -+ return result; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ m.data[3] = s; -+ m.data[4] = s; -+ m.data[5] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[4] = diag.data[1]; -+ m.data[8] = diag.data[2]; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[4] = diag.data[1]; -+ m.data[8] = diag.data[2]; -+ -+ return m; -+ } -+ -+ /// Gets an array of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Matrix diagonal() const { -+ Matrix diag; -+ -+ diag.data[0] = data[0]; -+ diag.data[1] = data[4]; -+ diag.data[2] = data[8]; -+ -+ return diag; -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[3] = data[1]; -+ mt.data[1] = data[2]; -+ mt.data[4] = data[3]; -+ mt.data[2] = data[4]; -+ mt.data[5] = data[5]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 
3 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 3 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x2(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x2(v, i, 0); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 2] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 1]; -+ m.data[2] = data[i * 2 + j + 2]; -+ m.data[3] = data[i * 2 + j + 3]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 1] = m.data[1]; -+ data[i * 2 + j + 2] = m.data[2]; -+ data[i * 2 + j + 3] = m.data[3]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 2]; -+ m.data[2] = data[i * 2 + j + 4]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 2] = m.data[1]; -+ data[i * 2 + j + 4] = m.data[2]; 
-+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_3x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_3x1(v, 0, j); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 1]; -+ m.data[2] = data[i * 2 + j + 2]; -+ m.data[3] = data[i * 2 + j + 3]; -+ m.data[4] = data[i * 2 + j + 4]; -+ m.data[5] = data[i * 2 + j + 5]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 1] = m.data[1]; -+ data[i * 2 + j + 2] = m.data[2]; -+ data[i * 2 + j + 3] = m.data[3]; -+ data[i * 2 + j + 4] = m.data[4]; -+ data[i * 2 + j + 5] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Forms a 3-by-2 matrix by horizontally concatenating a 3-by-1 matrix with a 3-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), rhs.at(0, 0) -+ , lhs.at(1, 0), rhs.at(1, 0) -+ , lhs.at(2, 0), rhs.at(2, 0)); -+ } -+ -+ /// Concatenates this matrix with a a 3-by-1 matrix to form a 3-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 3-by-2 matrix to form a 3-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Forms a 3-by-2 matrix by vertically concatenating a 1-by-2 matrix with a 2-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1) -+ , lower.at(0, 0), lower.at(0, 1) -+ , lower.at(1, 0), lower.at(1, 1)); -+ } -+ -+ /// Forms a 3-by-2 matrix by vertically concatenating a 2-by-2 matrix with a 1-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1) -+ , upper.at(1, 0), upper.at(1, 1) -+ , lower.at(0, 0), lower.at(0, 1)); -+ } -+ -+ /// Concatenates this matrix with a a 1-by-2 matrix to form a 4-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Forms a 3-by-2 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Element A, Element B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A, B -+ , C.at(0, 0), D.at(0, 0) -+ , C.at(1, 0), D.at(1, 0) -+ ); -+ } -+ -+ /// Forms a 3-by-2 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Element C, Element D) { -+ return Matrix( -+ A.at(0, 0), B.at(0, 0) -+ , A.at(1, 0), B.at(1, 0) -+ , C, D -+ ); -+ } -+ -+ /// Elementwise add operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ -+ result.data[2] = data[2] + rhs.data[2]; -+ result.data[3] = data[3] + rhs.data[3]; -+ -+ result.data[4] = data[4] + rhs.data[4]; -+ result.data[5] = data[5] + rhs.data[5]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator 
+(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ -+ data[2] += rhs.data[2]; -+ data[3] += rhs.data[3]; -+ -+ data[4] += rhs.data[4]; -+ data[5] += rhs.data[5]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ -+ result.data[2] = data[2] - rhs.data[2]; -+ result.data[3] = data[3] - rhs.data[3]; -+ -+ result.data[4] = data[4] - rhs.data[4]; -+ result.data[5] = data[5] - rhs.data[5]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ -+ data[2] -= rhs.data[2]; -+ data[3] -= rhs.data[3]; -+ -+ data[4] -= rhs.data[4]; -+ data[5] -= rhs.data[5]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ -+ result.data[2] = data[2] * rhs.data[2]; -+ result.data[3] = data[3] * rhs.data[3]; -+ -+ result.data[4] = data[4] * rhs.data[4]; -+ result.data[5] = data[5] * rhs.data[5]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ -+ result.data[2] = data[2] * s; -+ result.data[3] = data[3] * s; -+ -+ result.data[4] = data[4] * s; -+ result.data[5] = data[5] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ -+ data[2] *= s; -+ data[3] *= s; -+ -+ data[4] *= s; -+ data[5] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ -+ result.data[2] = data[2] / rhs.data[2]; -+ result.data[3] = data[3] / rhs.data[3]; -+ -+ result.data[4] = data[4] / rhs.data[4]; -+ result.data[5] = data[5] / rhs.data[5]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ -+ result.data[2] = data[2] / s; -+ result.data[3] = data[3] / s; -+ -+ result.data[4] = data[4] / s; -+ result.data[5] = data[5] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ -+ data[2] /= s; -+ data[3] /= s; 
-+ -+ data[4] /= s; -+ data[5] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ -+ data[2] /= rhs.data[2]; -+ data[3] /= rhs.data[3]; -+ -+ data[4] /= rhs.data[4]; -+ data[5] /= rhs.data[5]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ m.data[3] = -m.data[3]; -+ m.data[4] = -m.data[4]; -+ m.data[5] = -m.data[5]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 3-by-1-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[2] * rhs.data[0]; -+ accum.data[2] += data[4] * rhs.data[0]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[1]; -+ accum.data[1] += data[3] * rhs.data[1]; -+ accum.data[2] += data[5] * rhs.data[1]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-1-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[2] * rhs.data[0]; -+ accum.data[3] += data[2] * rhs.data[1]; -+ accum.data[4] += data[4] * rhs.data[0]; -+ accum.data[5] += data[4] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ accum.data[2] += data[3] * rhs.data[2]; -+ accum.data[3] += data[3] * rhs.data[3]; -+ accum.data[4] += data[5] * rhs.data[2]; -+ accum.data[5] += data[5] * rhs.data[3]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Matrix product of size 3-by-3-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[2] * rhs.data[0]; -+ accum.data[4] += data[2] * rhs.data[1]; -+ accum.data[5] += data[2] * rhs.data[2]; -+ accum.data[6] += data[4] * rhs.data[0]; -+ accum.data[7] += data[4] * rhs.data[1]; -+ accum.data[8] += data[4] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ accum.data[3] += data[3] * rhs.data[3]; -+ accum.data[4] += data[3] * rhs.data[4]; -+ accum.data[5] += data[3] * rhs.data[5]; -+ accum.data[6] += data[5] * rhs.data[3]; -+ accum.data[7] += data[5] * rhs.data[4]; -+ accum.data[8] += data[5] * rhs.data[5]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-3-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } 
-+ -+ /// Matrix product of size 3-by-4-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[2] * rhs.data[0]; -+ accum.data[5] += data[2] * rhs.data[1]; -+ accum.data[6] += data[2] * rhs.data[2]; -+ accum.data[7] += data[2] * rhs.data[3]; -+ accum.data[8] += data[4] * rhs.data[0]; -+ accum.data[9] += data[4] * rhs.data[1]; -+ accum.data[10] += data[4] * rhs.data[2]; -+ accum.data[11] += data[4] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ accum.data[4] += data[3] * rhs.data[4]; -+ accum.data[5] += data[3] * rhs.data[5]; -+ accum.data[6] += data[3] * rhs.data[6]; -+ accum.data[7] += data[3] * rhs.data[7]; -+ accum.data[8] += data[5] * rhs.data[4]; -+ accum.data[9] += data[5] * rhs.data[5]; -+ accum.data[10] += data[5] * rhs.data[6]; -+ accum.data[11] += data[5] * rhs.data[7]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-4-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ accum += data[3]; -+ accum += data[4]; -+ accum += data[5]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ accum += data[3] * data[3]; -+ accum += data[4] * data[4]; -+ accum += data[5] * data[5]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[3]; -+ -+ return accum; -+ } -+ -+}; -+ -+/// Template alias for 3-by-2 matrix -+template -+using Matrix3x2 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix3x2 make_Matrix3x2( -+ Element _0_0, Element _0_1, -+ Element _1_0, Element _1_1, -+ Element _2_0, Element _2_1 -+) { -+ return Matrix3x2( -+ _0_0, _0_1, -+ _1_0, _1_1, -+ _2_0, _2_1 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 3-by-3 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 3; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 3; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 9; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 3-by-3 matrix -+ 
CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 3-by-3 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1, Element _0_2, -+ Element _1_0, Element _1_1, Element _1_2, -+ Element _2_0, Element _2_1, Element _2_2 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; -+ data[3] = _1_0; data[4] = _1_1; data[5] = _1_2; -+ data[6] = _2_0; data[7] = _2_1; data[8] = _2_2; -+ } -+ -+ /// Constucts a 3-by-3 matrix from row vectors -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Matrix const &row_0, -+ Matrix const &row_1, -+ Matrix const &row_2 -+ ) { -+ data[0] = row_0.data[0]; -+ data[1] = row_0.data[1]; -+ data[2] = row_0.data[2]; -+ data[3] = row_1.data[0]; -+ data[4] = row_1.data[1]; -+ data[5] = row_1.data[2]; -+ data[6] = row_2.data[0]; -+ data[7] = row_2.data[1]; -+ data[8] = row_2.data[2]; -+ } -+ -+ /// Static method to construct a 3-by-3 matrix from column vectors -+ CUTLASS_HOST_DEVICE -+ static Matrix from_columns( -+ Matrix const &column_0, -+ Matrix const &column_1, -+ Matrix const &column_2 -+ ) { -+ Matrix result; -+ -+ result.data[0] = column_0.data[0]; -+ result.data[1] = column_1.data[0]; -+ result.data[2] = column_2.data[0]; -+ result.data[3] = column_0.data[1]; -+ result.data[4] = column_1.data[1]; -+ result.data[5] = column_2.data[1]; -+ result.data[6] = column_0.data[2]; -+ result.data[7] = column_1.data[2]; -+ result.data[8] = column_2.data[2]; -+ return result; -+ } -+ -+ /// Constructs an identity matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix identity() { -+ Matrix m; -+ -+ m.data[0] = Element(1); -+ m.data[4] = Element(1); -+ m.data[8] = Element(1); -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ m.data[3] = s; -+ m.data[4] = s; -+ m.data[5] = s; -+ m.data[6] = s; -+ m.data[7] = s; -+ m.data[8] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[4] = diag.data[1]; -+ m.data[8] = diag.data[2]; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[4] = diag.data[1]; -+ m.data[8] = diag.data[2]; -+ -+ return m; -+ } -+ -+ /// Gets an array of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Matrix diagonal() const { -+ Matrix diag; -+ -+ diag.data[0] = data[0]; -+ diag.data[1] = data[4]; -+ diag.data[2] = data[8]; -+ -+ return diag; -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[3] = data[1]; -+ mt.data[6] = data[2]; -+ mt.data[1] = data[3]; -+ mt.data[4] = data[4]; -+ mt.data[7] = data[5]; -+ mt.data[2] = data[6]; -+ mt.data[5] = data[7]; -+ mt.data[8] = data[8]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 3 + j]; -+ } -+ -+ /// Accesses an element by 
coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 3 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 2] = m.data[2]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x3(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x3(v, i, 0); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 3]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 3] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 3]; -+ m.data[3] = data[i * 3 + j + 4]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 3] = m.data[2]; -+ data[i * 3 + j + 4] = m.data[3]; -+ -+ return *this; -+ } -+ -+ /// Gets a 
submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 2]; -+ m.data[3] = data[i * 3 + j + 3]; -+ m.data[4] = data[i * 3 + j + 4]; -+ m.data[5] = data[i * 3 + j + 5]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 2] = m.data[2]; -+ data[i * 3 + j + 3] = m.data[3]; -+ data[i * 3 + j + 4] = m.data[4]; -+ data[i * 3 + j + 5] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 3]; -+ m.data[2] = data[i * 3 + j + 6]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 3] = m.data[1]; -+ data[i * 3 + j + 6] = m.data[2]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_3x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_3x1(v, 0, j); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 3]; -+ m.data[3] = data[i * 3 + j + 4]; -+ m.data[4] = data[i * 3 + j + 6]; -+ m.data[5] = data[i * 3 + j + 7]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 3] = m.data[2]; -+ data[i * 3 + j + 4] = m.data[3]; -+ data[i * 3 + j + 6] = m.data[4]; -+ data[i * 3 + j + 7] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 2]; -+ m.data[3] = data[i * 3 + j + 3]; -+ m.data[4] = data[i * 3 + j + 4]; -+ m.data[5] = data[i * 3 + j + 5]; -+ m.data[6] = data[i * 3 + j + 6]; -+ m.data[7] = data[i * 3 + j + 7]; -+ m.data[8] = data[i * 3 + j + 8]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 2] = m.data[2]; -+ data[i * 3 + j + 3] = m.data[3]; -+ data[i * 3 + j + 4] = m.data[4]; -+ data[i * 3 + j + 5] = m.data[5]; -+ data[i * 3 + j + 6] = m.data[6]; -+ data[i * 3 + j + 7] = m.data[7]; -+ data[i * 3 + j + 8] = m.data[8]; -+ -+ return *this; -+ } -+ -+ /// Forms a 3-by-3 matrix by horizontally concatenating a 3-by-1 matrix with a 3-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1) -+ , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1) -+ , lhs.at(2, 0), 
rhs.at(2, 0), rhs.at(2, 1)); -+ } -+ -+ /// Forms a 3-by-3 matrix by horizontally concatenating a 3-by-2 matrix with a 3-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0) -+ , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0) -+ , lhs.at(2, 0), lhs.at(2, 1), rhs.at(2, 0)); -+ } -+ -+ /// Concatenates this matrix with a a 3-by-1 matrix to form a 3-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Forms a 3-by-3 matrix by vertically concatenating a 1-by-3 matrix with a 2-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2) -+ , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2)); -+ } -+ -+ /// Forms a 3-by-3 matrix by vertically concatenating a 2-by-3 matrix with a 1-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) -+ , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2)); -+ } -+ -+ /// Concatenates this matrix with a a 1-by-3 matrix to form a 4-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Forms a 3-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Element A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A, B.at(0, 0), B.at(0, 1) -+ , C.at(0, 0), D.at(0, 0), D.at(0, 1) -+ , C.at(1, 0), D.at(1, 0), D.at(1, 1) -+ ); -+ } -+ -+ /// Forms a 3-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Element B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B -+ , C.at(0, 0), C.at(0, 1), D.at(0, 0) -+ , C.at(1, 0), C.at(1, 1), D.at(1, 0) -+ ); -+ } -+ -+ /// Forms a 3-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Element C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), B.at(0, 0), B.at(0, 1) -+ , A.at(1, 0), B.at(1, 0), B.at(1, 1) -+ , C, D.at(0, 0), D.at(0, 1) -+ ); -+ } -+ -+ /// Forms a 3-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Element D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B.at(0, 0) -+ , A.at(1, 0), A.at(1, 1), B.at(1, 0) -+ , C.at(0, 0), C.at(0, 1), D -+ ); -+ } -+ -+ /// Elementwise add operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ result.data[2] = data[2] + rhs.data[2]; -+ -+ result.data[3] = data[3] + rhs.data[3]; -+ result.data[4] = data[4] + rhs.data[4]; -+ result.data[5] = data[5] + rhs.data[5]; -+ -+ result.data[6] = data[6] + rhs.data[6]; -+ result.data[7] = data[7] + rhs.data[7]; -+ result.data[8] = data[8] + rhs.data[8]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & 
operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ data[2] += rhs.data[2]; -+ -+ data[3] += rhs.data[3]; -+ data[4] += rhs.data[4]; -+ data[5] += rhs.data[5]; -+ -+ data[6] += rhs.data[6]; -+ data[7] += rhs.data[7]; -+ data[8] += rhs.data[8]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ result.data[2] = data[2] - rhs.data[2]; -+ -+ result.data[3] = data[3] - rhs.data[3]; -+ result.data[4] = data[4] - rhs.data[4]; -+ result.data[5] = data[5] - rhs.data[5]; -+ -+ result.data[6] = data[6] - rhs.data[6]; -+ result.data[7] = data[7] - rhs.data[7]; -+ result.data[8] = data[8] - rhs.data[8]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ data[2] -= rhs.data[2]; -+ -+ data[3] -= rhs.data[3]; -+ data[4] -= rhs.data[4]; -+ data[5] -= rhs.data[5]; -+ -+ data[6] -= rhs.data[6]; -+ data[7] -= rhs.data[7]; -+ data[8] -= rhs.data[8]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ result.data[2] = data[2] * rhs.data[2]; -+ -+ result.data[3] = data[3] * rhs.data[3]; -+ result.data[4] = data[4] * rhs.data[4]; -+ result.data[5] = data[5] * rhs.data[5]; -+ -+ result.data[6] = data[6] * rhs.data[6]; -+ result.data[7] = data[7] * rhs.data[7]; -+ result.data[8] = data[8] * rhs.data[8]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ result.data[2] = data[2] * s; -+ -+ result.data[3] = data[3] * s; -+ result.data[4] = data[4] * s; -+ result.data[5] = data[5] * s; -+ -+ result.data[6] = data[6] * s; -+ result.data[7] = data[7] * s; -+ result.data[8] = data[8] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ data[2] *= s; -+ -+ data[3] *= s; -+ data[4] *= s; -+ data[5] *= s; -+ -+ data[6] *= s; -+ data[7] *= s; -+ data[8] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ result.data[2] = data[2] / rhs.data[2]; -+ -+ result.data[3] = data[3] / rhs.data[3]; -+ result.data[4] = data[4] / rhs.data[4]; -+ result.data[5] = data[5] / rhs.data[5]; -+ -+ result.data[6] = data[6] / rhs.data[6]; -+ result.data[7] = data[7] / rhs.data[7]; -+ result.data[8] = data[8] / rhs.data[8]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ 
Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ result.data[2] = data[2] / s; -+ -+ result.data[3] = data[3] / s; -+ result.data[4] = data[4] / s; -+ result.data[5] = data[5] / s; -+ -+ result.data[6] = data[6] / s; -+ result.data[7] = data[7] / s; -+ result.data[8] = data[8] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ data[2] /= s; -+ -+ data[3] /= s; -+ data[4] /= s; -+ data[5] /= s; -+ -+ data[6] /= s; -+ data[7] /= s; -+ data[8] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ data[2] /= rhs.data[2]; -+ -+ data[3] /= rhs.data[3]; -+ data[4] /= rhs.data[4]; -+ data[5] /= rhs.data[5]; -+ -+ data[6] /= rhs.data[6]; -+ data[7] /= rhs.data[7]; -+ data[8] /= rhs.data[8]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ m.data[3] = -m.data[3]; -+ m.data[4] = -m.data[4]; -+ m.data[5] = -m.data[5]; -+ m.data[6] = -m.data[6]; -+ m.data[7] = -m.data[7]; -+ m.data[8] = -m.data[8]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 3-by-1-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[3] * rhs.data[0]; -+ accum.data[2] += data[6] * rhs.data[0]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[1]; -+ accum.data[1] += data[4] * rhs.data[1]; -+ accum.data[2] += data[7] * rhs.data[1]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[2]; -+ accum.data[1] += data[5] * rhs.data[2]; -+ accum.data[2] += data[8] * rhs.data[2]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-1-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-2-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[3] * rhs.data[0]; -+ accum.data[3] += data[3] * rhs.data[1]; -+ accum.data[4] += data[6] * rhs.data[0]; -+ accum.data[5] += data[6] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ accum.data[2] += data[4] * rhs.data[2]; -+ accum.data[3] += data[4] * rhs.data[3]; -+ accum.data[4] += data[7] * rhs.data[2]; -+ accum.data[5] += data[7] * rhs.data[3]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[4]; -+ accum.data[1] += data[2] * rhs.data[5]; -+ accum.data[2] += data[5] * rhs.data[4]; -+ accum.data[3] += data[5] * rhs.data[5]; -+ accum.data[4] += data[8] * rhs.data[4]; -+ accum.data[5] += data[8] * rhs.data[5]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-2-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return 
product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[3] * rhs.data[0]; -+ accum.data[4] += data[3] * rhs.data[1]; -+ accum.data[5] += data[3] * rhs.data[2]; -+ accum.data[6] += data[6] * rhs.data[0]; -+ accum.data[7] += data[6] * rhs.data[1]; -+ accum.data[8] += data[6] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ accum.data[3] += data[4] * rhs.data[3]; -+ accum.data[4] += data[4] * rhs.data[4]; -+ accum.data[5] += data[4] * rhs.data[5]; -+ accum.data[6] += data[7] * rhs.data[3]; -+ accum.data[7] += data[7] * rhs.data[4]; -+ accum.data[8] += data[7] * rhs.data[5]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[6]; -+ accum.data[1] += data[2] * rhs.data[7]; -+ accum.data[2] += data[2] * rhs.data[8]; -+ accum.data[3] += data[5] * rhs.data[6]; -+ accum.data[4] += data[5] * rhs.data[7]; -+ accum.data[5] += data[5] * rhs.data[8]; -+ accum.data[6] += data[8] * rhs.data[6]; -+ accum.data[7] += data[8] * rhs.data[7]; -+ accum.data[8] += data[8] * rhs.data[8]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Matrix product of size 3-by-4-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[3] * rhs.data[0]; -+ accum.data[5] += data[3] * rhs.data[1]; -+ accum.data[6] += data[3] * rhs.data[2]; -+ accum.data[7] += data[3] * rhs.data[3]; -+ accum.data[8] += data[6] * rhs.data[0]; -+ accum.data[9] += data[6] * rhs.data[1]; -+ accum.data[10] += data[6] * rhs.data[2]; -+ accum.data[11] += data[6] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ accum.data[4] += data[4] * rhs.data[4]; -+ accum.data[5] += data[4] * rhs.data[5]; -+ accum.data[6] += data[4] * rhs.data[6]; -+ accum.data[7] += data[4] * rhs.data[7]; -+ accum.data[8] += data[7] * rhs.data[4]; -+ accum.data[9] += data[7] * rhs.data[5]; -+ accum.data[10] += data[7] * rhs.data[6]; -+ accum.data[11] += data[7] * rhs.data[7]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[8]; -+ accum.data[1] += data[2] * rhs.data[9]; -+ accum.data[2] += data[2] * rhs.data[10]; -+ accum.data[3] += data[2] * rhs.data[11]; -+ accum.data[4] += data[5] * rhs.data[8]; -+ accum.data[5] += data[5] * rhs.data[9]; -+ accum.data[6] += data[5] * rhs.data[10]; -+ accum.data[7] += data[5] * rhs.data[11]; -+ accum.data[8] += data[8] * rhs.data[8]; -+ accum.data[9] += data[8] * rhs.data[9]; -+ accum.data[10] += data[8] * rhs.data[10]; -+ accum.data[11] += data[8] * rhs.data[11]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-4-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix 
operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ accum += data[3]; -+ accum += data[4]; -+ accum += data[5]; -+ accum += data[6]; -+ accum += data[7]; -+ accum += data[8]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ accum += data[3] * data[3]; -+ accum += data[4] * data[4]; -+ accum += data[5] * data[5]; -+ accum += data[6] * data[6]; -+ accum += data[7] * data[7]; -+ accum += data[8] * data[8]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[4]; -+ accum += data[8]; -+ -+ return accum; -+ } -+ -+ /// Returns 3-by-3 rotation matrix around the X axis -+ CUTLASS_HOST_DEVICE -+ static Matrix rotation_X(Element theta) { -+ Matrix m = identity(); -+ -+ Element c = fast_cos(theta); -+ Element s = fast_sin(theta); -+ -+ m.at(1, 1) = c; -+ m.at(1, 2) = -s; -+ m.at(2, 1) = s; -+ m.at(2, 2) = c; -+ -+ return m; -+ } -+ -+ /// Returns 3-by-3 rotation matrix around the Y axis -+ CUTLASS_HOST_DEVICE -+ static Matrix rotation_Y(Element theta) { -+ Matrix m = identity(); -+ -+ Element c = fast_cos(theta); -+ Element s = fast_sin(theta); -+ -+ m.at(0, 0) = c; -+ m.at(2, 0) = -s; -+ m.at(0, 2) = s; -+ m.at(2, 2) = c; -+ -+ return m; -+ } -+ -+ /// Returns 3-by-3 rotation matrix around the Z axis -+ CUTLASS_HOST_DEVICE -+ static Matrix rotation_Z(Element theta) { -+ Matrix m = Matrix::identity(); -+ -+ Element c = fast_cos(theta); -+ Element s = fast_sin(theta); -+ -+ m.at(0, 0) = c; -+ m.at(0, 1) = -s; -+ m.at(1, 0) = s; -+ m.at(1, 1) = c; -+ -+ return m; -+ } -+ -+ /// Returns a 3-by-3 rotation matrix around a unit-length axis -+ CUTLASS_HOST_DEVICE -+ static Matrix rotation(Element theta, Matrix const &u) { -+ Element x = u.data[0]; -+ Element y = u.data[1]; -+ Element z = u.data[2]; -+ -+ Element c = fast_cos(theta); -+ Element s = fast_sin(theta); -+ -+ Element one_minus_cos = Element(1) - fast_cos(theta); -+ -+ Matrix m; -+ -+ m.set_slice3x3({ -+ c + x * x * one_minus_cos, x * y * one_minus_cos - z * s, x * z * one_minus_cos + y * s, -+ y * x * one_minus_cos * z * s, c + y * y * one_minus_cos, y * z * one_minus_cos - x * s, -+ z * x * one_minus_cos - y * s, z * y * one_minus_cos + x * s, c + z * z * one_minus_cos -+ }); -+ -+ return m; -+ } -+ -+ /// Returns a 3-by-3 reflection about the plane specified by the -+ /// unit-length normal vector n_unit -+ CUTLASS_HOST_DEVICE -+ static Matrix reflection(Matrix const &n_unit) { -+ -+ Element a = n_unit.data[0]; -+ Element b = n_unit.data[1]; -+ Element c = n_unit.data[2]; -+ -+ Matrix m = Matrix::identity(); -+ -+ m.set_slice3x3({ -+ Element(1) - Element(2) * a * a, Element(-2) * a * b, Element(-2) * a * c, -+ Element(-2) * a * b, Element(1) - Element(2) * b * b, Element(-2) * b * c, -+ Element(-2) * a * c, Element(-2) * b * c, Element(1) - Element(2) * c * c -+ }); -+ -+ return m; -+ } -+ -+ /// Computes the determinant of a 3-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Element determinant(Element accum = 
Element()) const { -+ -+ accum += at(0, 0) * Matrix({ at(1, 1), at(1, 2), at(2, 1), at(2, 2) }).determinant(); -+ accum -= at(0, 1) * Matrix({ at(1, 0), at(1, 2), at(2, 0), at(2, 2) }).determinant(); -+ accum += at(0, 2) * Matrix({ at(1, 0), at(1, 1), at(2, 0), at(2, 1) }).determinant(); -+ -+ return accum; -+ } -+ -+ /// Computes the inverse of a 3-by-3 matrix given -+ /// the matrix's determinant -+ CUTLASS_HOST_DEVICE -+ Matrix inverse(Element det) const { -+ return Matrix( -+ at(1, 1) * at(2, 2) - at(1, 2) * at(2, 1), -+ at(0, 2) * at(2, 1) - at(0, 1) * at(2, 2), -+ at(0, 1) * at(1, 2) - at(0, 2) * at(1, 1), -+ -+ at(1, 2) * at(2, 0) - at(1, 0) * at(2, 2), -+ at(0, 0) * at(2, 2) - at(0, 2) * at(2, 0), -+ at(0, 2) * at(1, 0) - at(0, 0) * at(1, 2), -+ -+ at(1, 0) * at(2, 1) - at(1, 1) * at(2, 0), -+ at(0, 1) * at(2, 0) - at(0, 0) * at(2, 1), -+ at(0, 0) * at(1, 1) - at(0, 1) * at(1, 0) -+ ) * (Element(1) / det); -+ } -+ /// Computes the inverse of a 3-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix inverse() const { -+ return inverse(determinant()); -+ } -+ -+}; -+ -+/// Template alias for 3-by-3 matrix -+template -+using Matrix3x3 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix3x3 make_Matrix3x3( -+ Element _0_0, Element _0_1, Element _0_2, -+ Element _1_0, Element _1_1, Element _1_2, -+ Element _2_0, Element _2_1, Element _2_2 -+) { -+ return Matrix3x3( -+ _0_0, _0_1, _0_2, -+ _1_0, _1_1, _1_2, -+ _2_0, _2_1, _2_2 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 3-by-4 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 3; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 4; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 12; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 3-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 3-by-4 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1, Element _0_2, Element _0_3, -+ Element _1_0, Element _1_1, Element _1_2, Element _1_3, -+ Element _2_0, Element _2_1, Element _2_2, Element _2_3 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; data[3] = _0_3; -+ data[4] = _1_0; data[5] = _1_1; data[6] = _1_2; data[7] = _1_3; -+ data[8] = _2_0; data[9] = _2_1; data[10] = _2_2; data[11] = _2_3; -+ } -+ -+ /// Constucts a 3-by-4 matrix from row vectors -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Matrix const &row_0, -+ Matrix const &row_1, -+ Matrix const &row_2 -+ ) { -+ data[0] = row_0.data[0]; -+ data[1] = row_0.data[1]; -+ data[2] = row_0.data[2]; -+ data[3] = row_0.data[3]; -+ data[4] = row_1.data[0]; -+ data[5] = row_1.data[1]; -+ data[6] = row_1.data[2]; -+ data[7] = row_1.data[3]; -+ data[8] = row_2.data[0]; -+ data[9] = row_2.data[1]; -+ data[10] = row_2.data[2]; -+ data[11] = row_2.data[3]; -+ } -+ -+ /// Static method to construct a 3-by-4 matrix from column vectors -+ CUTLASS_HOST_DEVICE -+ static Matrix 
from_columns( -+ Matrix const &column_0, -+ Matrix const &column_1, -+ Matrix const &column_2, -+ Matrix const &column_3 -+ ) { -+ Matrix result; -+ -+ result.data[0] = column_0.data[0]; -+ result.data[1] = column_1.data[0]; -+ result.data[2] = column_2.data[0]; -+ result.data[3] = column_3.data[0]; -+ result.data[4] = column_0.data[1]; -+ result.data[5] = column_1.data[1]; -+ result.data[6] = column_2.data[1]; -+ result.data[7] = column_3.data[1]; -+ result.data[8] = column_0.data[2]; -+ result.data[9] = column_1.data[2]; -+ result.data[10] = column_2.data[2]; -+ result.data[11] = column_3.data[2]; -+ return result; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ m.data[3] = s; -+ m.data[4] = s; -+ m.data[5] = s; -+ m.data[6] = s; -+ m.data[7] = s; -+ m.data[8] = s; -+ m.data[9] = s; -+ m.data[10] = s; -+ m.data[11] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[4] = diag.data[1]; -+ m.data[8] = diag.data[2]; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[4] = diag.data[1]; -+ m.data[8] = diag.data[2]; -+ -+ return m; -+ } -+ -+ /// Gets an array of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Matrix diagonal() const { -+ Matrix diag; -+ -+ diag.data[0] = data[0]; -+ diag.data[1] = data[4]; -+ diag.data[2] = data[8]; -+ -+ return diag; -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[3] = data[1]; -+ mt.data[6] = data[2]; -+ mt.data[9] = data[3]; -+ mt.data[1] = data[4]; -+ mt.data[4] = data[5]; -+ mt.data[7] = data[6]; -+ mt.data[10] = data[7]; -+ mt.data[2] = data[8]; -+ mt.data[5] = data[9]; -+ mt.data[8] = data[10]; -+ mt.data[11] = data[11]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 3 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 3 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return 
at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x4(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 3]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x4(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 3] = m.data[3]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x4(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x4(v, i, 0); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 4]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 4] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 4]; -+ m.data[3] = data[i * 4 + j + 5]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 4] = m.data[2]; -+ data[i * 4 + j + 5] = m.data[3]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 4]; -+ m.data[4] = data[i * 4 + j + 5]; -+ m.data[5] 
= data[i * 4 + j + 6]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 4] = m.data[3]; -+ data[i * 4 + j + 5] = m.data[4]; -+ data[i * 4 + j + 6] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x4(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 3]; -+ m.data[4] = data[i * 4 + j + 4]; -+ m.data[5] = data[i * 4 + j + 5]; -+ m.data[6] = data[i * 4 + j + 6]; -+ m.data[7] = data[i * 4 + j + 7]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x4(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 3] = m.data[3]; -+ data[i * 4 + j + 4] = m.data[4]; -+ data[i * 4 + j + 5] = m.data[5]; -+ data[i * 4 + j + 6] = m.data[6]; -+ data[i * 4 + j + 7] = m.data[7]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 4]; -+ m.data[2] = data[i * 4 + j + 8]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 4] = m.data[1]; -+ data[i * 4 + j + 8] = m.data[2]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_3x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_3x1(v, 0, j); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 4]; -+ m.data[3] = data[i * 4 + j + 5]; -+ m.data[4] = data[i * 4 + j + 8]; -+ m.data[5] = data[i * 4 + j + 9]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 4] = m.data[2]; -+ data[i * 4 + j + 5] = m.data[3]; -+ data[i * 4 + j + 8] = m.data[4]; -+ data[i * 4 + j + 9] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 4]; -+ m.data[4] = data[i * 4 + j + 5]; -+ m.data[5] = data[i * 4 + j + 6]; -+ m.data[6] = data[i * 4 + j + 8]; -+ m.data[7] = data[i * 4 + j + 9]; -+ m.data[8] = data[i * 4 + j + 10]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; 
-+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 4] = m.data[3]; -+ data[i * 4 + j + 5] = m.data[4]; -+ data[i * 4 + j + 6] = m.data[5]; -+ data[i * 4 + j + 8] = m.data[6]; -+ data[i * 4 + j + 9] = m.data[7]; -+ data[i * 4 + j + 10] = m.data[8]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x4(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 3]; -+ m.data[4] = data[i * 4 + j + 4]; -+ m.data[5] = data[i * 4 + j + 5]; -+ m.data[6] = data[i * 4 + j + 6]; -+ m.data[7] = data[i * 4 + j + 7]; -+ m.data[8] = data[i * 4 + j + 8]; -+ m.data[9] = data[i * 4 + j + 9]; -+ m.data[10] = data[i * 4 + j + 10]; -+ m.data[11] = data[i * 4 + j + 11]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x4(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 3] = m.data[3]; -+ data[i * 4 + j + 4] = m.data[4]; -+ data[i * 4 + j + 5] = m.data[5]; -+ data[i * 4 + j + 6] = m.data[6]; -+ data[i * 4 + j + 7] = m.data[7]; -+ data[i * 4 + j + 8] = m.data[8]; -+ data[i * 4 + j + 9] = m.data[9]; -+ data[i * 4 + j + 10] = m.data[10]; -+ data[i * 4 + j + 11] = m.data[11]; -+ -+ return *this; -+ } -+ -+ /// Forms a 3-by-4 matrix by horizontally concatenating a 3-by-1 matrix with a 3-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1), rhs.at(0, 2) -+ , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1), rhs.at(1, 2) -+ , lhs.at(2, 0), rhs.at(2, 0), rhs.at(2, 1), rhs.at(2, 2)); -+ } -+ -+ /// Forms a 3-by-4 matrix by horizontally concatenating a 3-by-2 matrix with a 3-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0), rhs.at(0, 1) -+ , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0), rhs.at(1, 1) -+ , lhs.at(2, 0), lhs.at(2, 1), rhs.at(2, 0), rhs.at(2, 1)); -+ } -+ -+ /// Forms a 3-by-4 matrix by horizontally concatenating a 3-by-3 matrix with a 3-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), lhs.at(0, 2), rhs.at(0, 0) -+ , lhs.at(1, 0), lhs.at(1, 1), lhs.at(1, 2), rhs.at(1, 0) -+ , lhs.at(2, 0), lhs.at(2, 1), lhs.at(2, 2), rhs.at(2, 0)); -+ } -+ -+ /// Forms a 3-by-4 matrix by vertically concatenating a 1-by-4 matrix with a 2-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3) -+ , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2), lower.at(1, 3)); -+ } -+ -+ /// Forms a 3-by-4 matrix by vertically concatenating a 2-by-4 matrix with a 1-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) -+ , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2), upper.at(1, 3) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3)); -+ } -+ -+ /// Concatenates this matrix with a a 1-by-4 matrix to form a 4-by-4 
matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Forms a 3-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Element A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A, B.at(0, 0), B.at(0, 1), B.at(0, 2) -+ , C.at(0, 0), D.at(0, 0), D.at(0, 1), D.at(0, 2) -+ , C.at(1, 0), D.at(1, 0), D.at(1, 1), D.at(1, 2) -+ ); -+ } -+ -+ /// Forms a 3-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) -+ , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) -+ , C.at(1, 0), C.at(1, 1), D.at(1, 0), D.at(1, 1) -+ ); -+ } -+ -+ /// Forms a 3-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Element B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), A.at(0, 2), B -+ , C.at(0, 0), C.at(0, 1), C.at(0, 2), D.at(0, 0) -+ , C.at(1, 0), C.at(1, 1), C.at(1, 2), D.at(1, 0) -+ ); -+ } -+ -+ /// Forms a 3-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Element C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), B.at(0, 0), B.at(0, 1), B.at(0, 2) -+ , A.at(1, 0), B.at(1, 0), B.at(1, 1), B.at(1, 2) -+ , C, D.at(0, 0), D.at(0, 1), D.at(0, 2) -+ ); -+ } -+ -+ /// Forms a 3-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) -+ , A.at(1, 0), A.at(1, 1), B.at(1, 0), B.at(1, 1) -+ , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) -+ ); -+ } -+ -+ /// Forms a 3-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Element D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), A.at(0, 2), B.at(0, 0) -+ , A.at(1, 0), A.at(1, 1), A.at(1, 2), B.at(1, 0) -+ , C.at(0, 0), C.at(0, 1), C.at(0, 2), D -+ ); -+ } -+ -+ /// Elementwise add operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ result.data[2] = data[2] + rhs.data[2]; -+ result.data[3] = data[3] + rhs.data[3]; -+ -+ result.data[4] = data[4] + rhs.data[4]; -+ result.data[5] = data[5] + rhs.data[5]; -+ result.data[6] = data[6] + rhs.data[6]; -+ result.data[7] = data[7] + rhs.data[7]; -+ -+ result.data[8] = data[8] + rhs.data[8]; -+ result.data[9] = data[9] + rhs.data[9]; -+ result.data[10] = data[10] + rhs.data[10]; -+ result.data[11] = data[11] + rhs.data[11]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ data[2] += rhs.data[2]; -+ data[3] += rhs.data[3]; -+ -+ data[4] += rhs.data[4]; -+ data[5] += rhs.data[5]; -+ data[6] += rhs.data[6]; -+ data[7] += rhs.data[7]; -+ -+ data[8] += rhs.data[8]; -+ data[9] += rhs.data[9]; -+ data[10] += rhs.data[10]; -+ 
data[11] += rhs.data[11]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ result.data[2] = data[2] - rhs.data[2]; -+ result.data[3] = data[3] - rhs.data[3]; -+ -+ result.data[4] = data[4] - rhs.data[4]; -+ result.data[5] = data[5] - rhs.data[5]; -+ result.data[6] = data[6] - rhs.data[6]; -+ result.data[7] = data[7] - rhs.data[7]; -+ -+ result.data[8] = data[8] - rhs.data[8]; -+ result.data[9] = data[9] - rhs.data[9]; -+ result.data[10] = data[10] - rhs.data[10]; -+ result.data[11] = data[11] - rhs.data[11]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ data[2] -= rhs.data[2]; -+ data[3] -= rhs.data[3]; -+ -+ data[4] -= rhs.data[4]; -+ data[5] -= rhs.data[5]; -+ data[6] -= rhs.data[6]; -+ data[7] -= rhs.data[7]; -+ -+ data[8] -= rhs.data[8]; -+ data[9] -= rhs.data[9]; -+ data[10] -= rhs.data[10]; -+ data[11] -= rhs.data[11]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ result.data[2] = data[2] * rhs.data[2]; -+ result.data[3] = data[3] * rhs.data[3]; -+ -+ result.data[4] = data[4] * rhs.data[4]; -+ result.data[5] = data[5] * rhs.data[5]; -+ result.data[6] = data[6] * rhs.data[6]; -+ result.data[7] = data[7] * rhs.data[7]; -+ -+ result.data[8] = data[8] * rhs.data[8]; -+ result.data[9] = data[9] * rhs.data[9]; -+ result.data[10] = data[10] * rhs.data[10]; -+ result.data[11] = data[11] * rhs.data[11]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ result.data[2] = data[2] * s; -+ result.data[3] = data[3] * s; -+ -+ result.data[4] = data[4] * s; -+ result.data[5] = data[5] * s; -+ result.data[6] = data[6] * s; -+ result.data[7] = data[7] * s; -+ -+ result.data[8] = data[8] * s; -+ result.data[9] = data[9] * s; -+ result.data[10] = data[10] * s; -+ result.data[11] = data[11] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ data[2] *= s; -+ data[3] *= s; -+ -+ data[4] *= s; -+ data[5] *= s; -+ data[6] *= s; -+ data[7] *= s; -+ -+ data[8] *= s; -+ data[9] *= s; -+ data[10] *= s; -+ data[11] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ result.data[2] = data[2] / rhs.data[2]; -+ result.data[3] = data[3] / rhs.data[3]; -+ -+ result.data[4] = data[4] / rhs.data[4]; -+ result.data[5] = data[5] / rhs.data[5]; -+ result.data[6] = 
data[6] / rhs.data[6]; -+ result.data[7] = data[7] / rhs.data[7]; -+ -+ result.data[8] = data[8] / rhs.data[8]; -+ result.data[9] = data[9] / rhs.data[9]; -+ result.data[10] = data[10] / rhs.data[10]; -+ result.data[11] = data[11] / rhs.data[11]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ result.data[2] = data[2] / s; -+ result.data[3] = data[3] / s; -+ -+ result.data[4] = data[4] / s; -+ result.data[5] = data[5] / s; -+ result.data[6] = data[6] / s; -+ result.data[7] = data[7] / s; -+ -+ result.data[8] = data[8] / s; -+ result.data[9] = data[9] / s; -+ result.data[10] = data[10] / s; -+ result.data[11] = data[11] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ data[2] /= s; -+ data[3] /= s; -+ -+ data[4] /= s; -+ data[5] /= s; -+ data[6] /= s; -+ data[7] /= s; -+ -+ data[8] /= s; -+ data[9] /= s; -+ data[10] /= s; -+ data[11] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ data[2] /= rhs.data[2]; -+ data[3] /= rhs.data[3]; -+ -+ data[4] /= rhs.data[4]; -+ data[5] /= rhs.data[5]; -+ data[6] /= rhs.data[6]; -+ data[7] /= rhs.data[7]; -+ -+ data[8] /= rhs.data[8]; -+ data[9] /= rhs.data[9]; -+ data[10] /= rhs.data[10]; -+ data[11] /= rhs.data[11]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ m.data[3] = -m.data[3]; -+ m.data[4] = -m.data[4]; -+ m.data[5] = -m.data[5]; -+ m.data[6] = -m.data[6]; -+ m.data[7] = -m.data[7]; -+ m.data[8] = -m.data[8]; -+ m.data[9] = -m.data[9]; -+ m.data[10] = -m.data[10]; -+ m.data[11] = -m.data[11]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 3-by-1-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[4] * rhs.data[0]; -+ accum.data[2] += data[8] * rhs.data[0]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[1]; -+ accum.data[1] += data[5] * rhs.data[1]; -+ accum.data[2] += data[9] * rhs.data[1]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[2]; -+ accum.data[1] += data[6] * rhs.data[2]; -+ accum.data[2] += data[10] * rhs.data[2]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[3]; -+ accum.data[1] += data[7] * rhs.data[3]; -+ accum.data[2] += data[11] * rhs.data[3]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-1-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-2-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ 
accum.data[2] += data[4] * rhs.data[0]; -+ accum.data[3] += data[4] * rhs.data[1]; -+ accum.data[4] += data[8] * rhs.data[0]; -+ accum.data[5] += data[8] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ accum.data[2] += data[5] * rhs.data[2]; -+ accum.data[3] += data[5] * rhs.data[3]; -+ accum.data[4] += data[9] * rhs.data[2]; -+ accum.data[5] += data[9] * rhs.data[3]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[4]; -+ accum.data[1] += data[2] * rhs.data[5]; -+ accum.data[2] += data[6] * rhs.data[4]; -+ accum.data[3] += data[6] * rhs.data[5]; -+ accum.data[4] += data[10] * rhs.data[4]; -+ accum.data[5] += data[10] * rhs.data[5]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[6]; -+ accum.data[1] += data[3] * rhs.data[7]; -+ accum.data[2] += data[7] * rhs.data[6]; -+ accum.data[3] += data[7] * rhs.data[7]; -+ accum.data[4] += data[11] * rhs.data[6]; -+ accum.data[5] += data[11] * rhs.data[7]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-2-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-3-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[4] * rhs.data[0]; -+ accum.data[4] += data[4] * rhs.data[1]; -+ accum.data[5] += data[4] * rhs.data[2]; -+ accum.data[6] += data[8] * rhs.data[0]; -+ accum.data[7] += data[8] * rhs.data[1]; -+ accum.data[8] += data[8] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ accum.data[3] += data[5] * rhs.data[3]; -+ accum.data[4] += data[5] * rhs.data[4]; -+ accum.data[5] += data[5] * rhs.data[5]; -+ accum.data[6] += data[9] * rhs.data[3]; -+ accum.data[7] += data[9] * rhs.data[4]; -+ accum.data[8] += data[9] * rhs.data[5]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[6]; -+ accum.data[1] += data[2] * rhs.data[7]; -+ accum.data[2] += data[2] * rhs.data[8]; -+ accum.data[3] += data[6] * rhs.data[6]; -+ accum.data[4] += data[6] * rhs.data[7]; -+ accum.data[5] += data[6] * rhs.data[8]; -+ accum.data[6] += data[10] * rhs.data[6]; -+ accum.data[7] += data[10] * rhs.data[7]; -+ accum.data[8] += data[10] * rhs.data[8]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[9]; -+ accum.data[1] += data[3] * rhs.data[10]; -+ accum.data[2] += data[3] * rhs.data[11]; -+ accum.data[3] += data[7] * rhs.data[9]; -+ accum.data[4] += data[7] * rhs.data[10]; -+ accum.data[5] += data[7] * rhs.data[11]; -+ accum.data[6] += data[11] * rhs.data[9]; -+ accum.data[7] += data[11] * rhs.data[10]; -+ accum.data[8] += data[11] * rhs.data[11]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-3-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[4] * rhs.data[0]; -+ accum.data[5] += data[4] * rhs.data[1]; -+ accum.data[6] += data[4] * rhs.data[2]; -+ 
accum.data[7] += data[4] * rhs.data[3]; -+ accum.data[8] += data[8] * rhs.data[0]; -+ accum.data[9] += data[8] * rhs.data[1]; -+ accum.data[10] += data[8] * rhs.data[2]; -+ accum.data[11] += data[8] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ accum.data[4] += data[5] * rhs.data[4]; -+ accum.data[5] += data[5] * rhs.data[5]; -+ accum.data[6] += data[5] * rhs.data[6]; -+ accum.data[7] += data[5] * rhs.data[7]; -+ accum.data[8] += data[9] * rhs.data[4]; -+ accum.data[9] += data[9] * rhs.data[5]; -+ accum.data[10] += data[9] * rhs.data[6]; -+ accum.data[11] += data[9] * rhs.data[7]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[8]; -+ accum.data[1] += data[2] * rhs.data[9]; -+ accum.data[2] += data[2] * rhs.data[10]; -+ accum.data[3] += data[2] * rhs.data[11]; -+ accum.data[4] += data[6] * rhs.data[8]; -+ accum.data[5] += data[6] * rhs.data[9]; -+ accum.data[6] += data[6] * rhs.data[10]; -+ accum.data[7] += data[6] * rhs.data[11]; -+ accum.data[8] += data[10] * rhs.data[8]; -+ accum.data[9] += data[10] * rhs.data[9]; -+ accum.data[10] += data[10] * rhs.data[10]; -+ accum.data[11] += data[10] * rhs.data[11]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[12]; -+ accum.data[1] += data[3] * rhs.data[13]; -+ accum.data[2] += data[3] * rhs.data[14]; -+ accum.data[3] += data[3] * rhs.data[15]; -+ accum.data[4] += data[7] * rhs.data[12]; -+ accum.data[5] += data[7] * rhs.data[13]; -+ accum.data[6] += data[7] * rhs.data[14]; -+ accum.data[7] += data[7] * rhs.data[15]; -+ accum.data[8] += data[11] * rhs.data[12]; -+ accum.data[9] += data[11] * rhs.data[13]; -+ accum.data[10] += data[11] * rhs.data[14]; -+ accum.data[11] += data[11] * rhs.data[15]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ accum += data[3]; -+ accum += data[4]; -+ accum += data[5]; -+ accum += data[6]; -+ accum += data[7]; -+ accum += data[8]; -+ accum += data[9]; -+ accum += data[10]; -+ accum += data[11]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ accum += data[3] * data[3]; -+ accum += data[4] * data[4]; -+ accum += data[5] * data[5]; -+ accum += data[6] * data[6]; -+ accum += data[7] * data[7]; -+ accum += data[8] * data[8]; -+ accum += data[9] * data[9]; -+ accum += data[10] * data[10]; -+ accum += data[11] * data[11]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[5]; -+ accum += data[10]; -+ -+ return accum; -+ } -+ -+}; -+ -+/// Template alias for 3-by-4 matrix -+template -+using Matrix3x4 = Matrix; -+ -+ -+/// Free funciton to infer element 
type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix3x4 make_Matrix3x4( -+ Element _0_0, Element _0_1, Element _0_2, Element _0_3, -+ Element _1_0, Element _1_1, Element _1_2, Element _1_3, -+ Element _2_0, Element _2_1, Element _2_2, Element _2_3 -+) { -+ return Matrix3x4( -+ _0_0, _0_1, _0_2, _0_3, -+ _1_0, _1_1, _1_2, _1_3, -+ _2_0, _2_1, _2_2, _2_3 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 4-by-1 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 4; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 1; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 4; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 4-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 4-by-1 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, -+ Element _1_0, -+ Element _2_0, -+ Element _3_0 -+ ) { -+ -+ data[0] = _0_0; -+ data[1] = _1_0; -+ data[2] = _2_0; -+ data[3] = _3_0; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ m.data[3] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[1] = data[1]; -+ mt.data[2] = data[2]; -+ mt.data[3] = data[3]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 4 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 4 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an 
element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 1 + j + 0]; -+ m.data[1] = data[i * 1 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 1 + j + 0] = m.data[0]; -+ data[i * 1 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 1 + j + 0]; -+ m.data[1] = data[i * 1 + j + 1]; -+ m.data[2] = data[i * 1 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 1 + j + 0] = m.data[0]; -+ data[i * 1 + j + 1] = m.data[1]; -+ data[i * 1 + j + 2] = m.data[2]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_4x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 1 + j + 0]; -+ m.data[1] = data[i * 1 + j + 1]; -+ m.data[2] = data[i * 1 + j + 2]; -+ m.data[3] = data[i * 1 + j + 3]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_4x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 1 + j + 0] = m.data[0]; -+ data[i * 1 + j + 1] = m.data[1]; -+ data[i * 1 + j + 2] = m.data[2]; -+ data[i * 1 + j + 3] = m.data[3]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_4x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_4x1(v, 0, j); -+ } -+ -+ /// Concatenates this matrix with a a 4-by-1 matrix to form a 4-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 4-by-2 matrix to form a 4-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 4-by-3 matrix to form a 4-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Forms a 4-by-1 matrix by vertically concatenating an Element with a 3-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Element upper, Matrix const & lower) { -+ return Matrix( -+ upper -+ , lower.at(0, 0) -+ , lower.at(1, 0) -+ , lower.at(2, 0)); -+ } -+ -+ /// Forms a 4-by-1 matrix by vertically concatenating a 2-by-1 matrix with a 2-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0) -+ , upper.at(1, 0) -+ , lower.at(0, 0) -+ , lower.at(1, 0)); -+ } -+ -+ /// Forms a 4-by-1 matrix by vertically concatenating a 3-by-1 matrix with an Element -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Element lower) { -+ return Matrix( -+ upper.at(0, 0) -+ , upper.at(1, 0) -+ , upper.at(2, 0) -+ , lower); -+ } -+ -+ /// Elementwise add operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ -+ result.data[1] = 
data[1] + rhs.data[1]; -+ -+ result.data[2] = data[2] + rhs.data[2]; -+ -+ result.data[3] = data[3] + rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ -+ data[1] += rhs.data[1]; -+ -+ data[2] += rhs.data[2]; -+ -+ data[3] += rhs.data[3]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ -+ result.data[1] = data[1] - rhs.data[1]; -+ -+ result.data[2] = data[2] - rhs.data[2]; -+ -+ result.data[3] = data[3] - rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ -+ data[1] -= rhs.data[1]; -+ -+ data[2] -= rhs.data[2]; -+ -+ data[3] -= rhs.data[3]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ -+ result.data[1] = data[1] * rhs.data[1]; -+ -+ result.data[2] = data[2] * rhs.data[2]; -+ -+ result.data[3] = data[3] * rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ -+ result.data[1] = data[1] * s; -+ -+ result.data[2] = data[2] * s; -+ -+ result.data[3] = data[3] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ -+ data[1] *= s; -+ -+ data[2] *= s; -+ -+ data[3] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ -+ result.data[1] = data[1] / rhs.data[1]; -+ -+ result.data[2] = data[2] / rhs.data[2]; -+ -+ result.data[3] = data[3] / rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ -+ result.data[1] = data[1] / s; -+ -+ result.data[2] = data[2] / s; -+ -+ result.data[3] = data[3] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ -+ data[1] /= s; -+ -+ data[2] /= s; -+ -+ data[3] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix 
const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ -+ data[1] /= rhs.data[1]; -+ -+ data[2] /= rhs.data[2]; -+ -+ data[3] /= rhs.data[3]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ m.data[3] = -m.data[3]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 4-by-1-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[1] * rhs.data[0]; -+ accum.data[2] += data[2] * rhs.data[0]; -+ accum.data[3] += data[3] * rhs.data[0]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-1-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-1-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Matrix product of size 4-by-2-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[1] * rhs.data[0]; -+ accum.data[3] += data[1] * rhs.data[1]; -+ accum.data[4] += data[2] * rhs.data[0]; -+ accum.data[5] += data[2] * rhs.data[1]; -+ accum.data[6] += data[3] * rhs.data[0]; -+ accum.data[7] += data[3] * rhs.data[1]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-2-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-3-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[1] * rhs.data[0]; -+ accum.data[4] += data[1] * rhs.data[1]; -+ accum.data[5] += data[1] * rhs.data[2]; -+ accum.data[6] += data[2] * rhs.data[0]; -+ accum.data[7] += data[2] * rhs.data[1]; -+ accum.data[8] += data[2] * rhs.data[2]; -+ accum.data[9] += data[3] * rhs.data[0]; -+ accum.data[10] += data[3] * rhs.data[1]; -+ accum.data[11] += data[3] * rhs.data[2]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-3-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-4-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[1] * rhs.data[0]; -+ accum.data[5] += data[1] * rhs.data[1]; -+ accum.data[6] += data[1] * rhs.data[2]; -+ accum.data[7] += data[1] * rhs.data[3]; -+ accum.data[8] += data[2] * rhs.data[0]; -+ accum.data[9] += data[2] * rhs.data[1]; -+ accum.data[10] += data[2] * rhs.data[2]; -+ accum.data[11] += data[2] * rhs.data[3]; -+ accum.data[12] += data[3] * rhs.data[0]; -+ accum.data[13] += data[3] * rhs.data[1]; -+ accum.data[14] += data[3] * rhs.data[2]; -+ accum.data[15] += data[3] * rhs.data[3]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-4-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix 
operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Dot product of vectors with extent 4 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ accum += data[2] * rhs.data[2]; -+ accum += data[3] * rhs.data[3]; -+ return accum; -+ } -+ -+ /// Dot product of vectors with extent 4 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ accum += data[2] * rhs.data[2]; -+ accum += data[3] * rhs.data[3]; -+ return accum; -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ accum += data[3]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ accum += data[3] * data[3]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ -+ return accum; -+ } -+ -+}; -+ -+/// Template alias for 4-by-1 matrix -+template -+using Matrix4x1 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix4x1 make_Matrix4x1( -+ Element _0_0, -+ Element _1_0, -+ Element _2_0, -+ Element _3_0 -+) { -+ return Matrix4x1( -+ _0_0, -+ _1_0, -+ _2_0, -+ _3_0 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 4-by-2 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 4; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 2; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 8; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 4-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 4-by-2 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1, -+ Element _1_0, Element _1_1, -+ Element _2_0, Element _2_1, -+ Element _3_0, Element _3_1 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; -+ data[2] = _1_0; data[3] = _1_1; -+ data[4] = _2_0; data[5] = _2_1; -+ data[6] = _3_0; data[7] = _3_1; -+ } -+ -+ /// Constucts a 4-by-2 matrix from row vectors -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Matrix const &row_0, -+ Matrix const &row_1, -+ Matrix const &row_2, -+ Matrix const &row_3 -+ ) { -+ data[0] = row_0.data[0]; -+ data[1] = row_0.data[1]; -+ data[2] = row_1.data[0]; -+ data[3] = row_1.data[1]; -+ data[4] = row_2.data[0]; -+ data[5] = row_2.data[1]; -+ data[6] = row_3.data[0]; -+ data[7] = row_3.data[1]; -+ } -+ -+ /// Static 
method to construct a 4-by-2 matrix from column vectors -+ CUTLASS_HOST_DEVICE -+ static Matrix from_columns( -+ Matrix const &column_0, -+ Matrix const &column_1 -+ ) { -+ Matrix result; -+ -+ result.data[0] = column_0.data[0]; -+ result.data[1] = column_1.data[0]; -+ result.data[2] = column_0.data[1]; -+ result.data[3] = column_1.data[1]; -+ result.data[4] = column_0.data[2]; -+ result.data[5] = column_1.data[2]; -+ result.data[6] = column_0.data[3]; -+ result.data[7] = column_1.data[3]; -+ return result; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ m.data[3] = s; -+ m.data[4] = s; -+ m.data[5] = s; -+ m.data[6] = s; -+ m.data[7] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[5] = diag.data[1]; -+ m.data[10] = diag.data[2]; -+ m.data[15] = diag.data[3]; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[5] = diag.data[1]; -+ m.data[10] = diag.data[2]; -+ m.data[15] = diag.data[3]; -+ -+ return m; -+ } -+ -+ /// Gets an array of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Matrix diagonal() const { -+ Matrix diag; -+ -+ diag.data[0] = data[0]; -+ diag.data[1] = data[5]; -+ diag.data[2] = data[10]; -+ diag.data[3] = data[15]; -+ -+ return diag; -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[4] = data[1]; -+ mt.data[1] = data[2]; -+ mt.data[5] = data[3]; -+ mt.data[2] = data[4]; -+ mt.data[6] = data[5]; -+ mt.data[3] = data[6]; -+ mt.data[7] = data[7]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 4 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 4 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ 
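The constructors above all assume a row-major layout: element (i, j) of an M-by-N matrix lives at flat offset i * N + j, which is why from_columns() interleaves the column vectors and transpose() swaps the two strides. A small standalone sketch of that layout rule for the 4-by-2 case, using plain std::array and hypothetical helper names rather than the CUTLASS types recorded in this patch:

#include <array>
#include <cstdio>

// Row-major 4-by-2 storage: element (i, j) sits at flat offset i * 2 + j.
// build_from_columns mirrors the from_columns() idea above; transpose4x2
// produces the 2-by-4 transpose by swapping the row and column strides.
std::array<float, 8> build_from_columns(std::array<float, 4> const &c0,
                                        std::array<float, 4> const &c1) {
  std::array<float, 8> m{};
  for (int i = 0; i < 4; ++i) {
    m[i * 2 + 0] = c0[i];
    m[i * 2 + 1] = c1[i];
  }
  return m;
}

std::array<float, 8> transpose4x2(std::array<float, 8> const &m) {
  std::array<float, 8> mt{};  // 2-by-4 result, also row-major
  for (int i = 0; i < 4; ++i)
    for (int j = 0; j < 2; ++j)
      mt[j * 4 + i] = m[i * 2 + j];
  return mt;
}

int main() {
  auto m = build_from_columns({1, 2, 3, 4}, {5, 6, 7, 8});
  auto t = transpose4x2(m);
  // m(2,1) and t(1,2) address the same value (7) through the two layouts.
  std::printf("m(2,1) = %g, t(1,2) = %g\n", m[2 * 2 + 1], t[1 * 4 + 2]);
  return 0;
}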
CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x2(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x2(v, i, 0); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 2] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 1]; -+ m.data[2] = data[i * 2 + j + 2]; -+ m.data[3] = data[i * 2 + j + 3]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 1] = m.data[1]; -+ data[i * 2 + j + 2] = m.data[2]; -+ data[i * 2 + j + 3] = m.data[3]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 2]; -+ m.data[2] = data[i * 2 + j + 4]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 2] = m.data[1]; -+ data[i * 2 + j + 4] = m.data[2]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 1]; -+ m.data[2] = data[i * 2 + j + 2]; -+ m.data[3] = data[i * 2 + j + 3]; -+ m.data[4] = data[i * 2 + j + 4]; -+ m.data[5] = data[i * 2 + j + 5]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 1] = m.data[1]; -+ data[i * 2 + j + 2] = m.data[2]; -+ data[i * 2 + j + 3] = m.data[3]; -+ data[i * 2 + j + 4] = m.data[4]; -+ data[i * 2 + j + 5] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_4x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 2]; -+ m.data[2] = data[i * 2 + j + 4]; -+ m.data[3] = data[i * 2 + j + 6]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix 
& set_slice_4x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 2] = m.data[1]; -+ data[i * 2 + j + 4] = m.data[2]; -+ data[i * 2 + j + 6] = m.data[3]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_4x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_4x1(v, 0, j); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_4x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 1]; -+ m.data[2] = data[i * 2 + j + 2]; -+ m.data[3] = data[i * 2 + j + 3]; -+ m.data[4] = data[i * 2 + j + 4]; -+ m.data[5] = data[i * 2 + j + 5]; -+ m.data[6] = data[i * 2 + j + 6]; -+ m.data[7] = data[i * 2 + j + 7]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_4x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 1] = m.data[1]; -+ data[i * 2 + j + 2] = m.data[2]; -+ data[i * 2 + j + 3] = m.data[3]; -+ data[i * 2 + j + 4] = m.data[4]; -+ data[i * 2 + j + 5] = m.data[5]; -+ data[i * 2 + j + 6] = m.data[6]; -+ data[i * 2 + j + 7] = m.data[7]; -+ -+ return *this; -+ } -+ -+ /// Forms a 4-by-2 matrix by horizontally concatenating a 4-by-1 matrix with a 4-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), rhs.at(0, 0) -+ , lhs.at(1, 0), rhs.at(1, 0) -+ , lhs.at(2, 0), rhs.at(2, 0) -+ , lhs.at(3, 0), rhs.at(3, 0)); -+ } -+ -+ /// Concatenates this matrix with a a 4-by-1 matrix to form a 4-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 4-by-2 matrix to form a 4-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Forms a 4-by-2 matrix by vertically concatenating a 1-by-2 matrix with a 3-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1) -+ , lower.at(0, 0), lower.at(0, 1) -+ , lower.at(1, 0), lower.at(1, 1) -+ , lower.at(2, 0), lower.at(2, 1)); -+ } -+ -+ /// Forms a 4-by-2 matrix by vertically concatenating a 2-by-2 matrix with a 2-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1) -+ , upper.at(1, 0), upper.at(1, 1) -+ , lower.at(0, 0), lower.at(0, 1) -+ , lower.at(1, 0), lower.at(1, 1)); -+ } -+ -+ /// Forms a 4-by-2 matrix by vertically concatenating a 3-by-2 matrix with a 1-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1) -+ , upper.at(1, 0), upper.at(1, 1) -+ , upper.at(2, 0), upper.at(2, 1) -+ , lower.at(0, 0), lower.at(0, 1)); -+ } -+ -+ /// Forms a 4-by-2 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Element A, Element B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A, B -+ , C.at(0, 0), D.at(0, 0) -+ , C.at(1, 0), D.at(1, 0) -+ , C.at(2, 0), D.at(2, 0) -+ ); -+ } -+ -+ /// Forms a 4-by-2 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, 
Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), B.at(0, 0) -+ , A.at(1, 0), B.at(1, 0) -+ , C.at(0, 0), D.at(0, 0) -+ , C.at(1, 0), D.at(1, 0) -+ ); -+ } -+ -+ /// Forms a 4-by-2 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Element C, Element D) { -+ return Matrix( -+ A.at(0, 0), B.at(0, 0) -+ , A.at(1, 0), B.at(1, 0) -+ , A.at(2, 0), B.at(2, 0) -+ , C, D -+ ); -+ } -+ -+ /// Elementwise add operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ -+ result.data[2] = data[2] + rhs.data[2]; -+ result.data[3] = data[3] + rhs.data[3]; -+ -+ result.data[4] = data[4] + rhs.data[4]; -+ result.data[5] = data[5] + rhs.data[5]; -+ -+ result.data[6] = data[6] + rhs.data[6]; -+ result.data[7] = data[7] + rhs.data[7]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ -+ data[2] += rhs.data[2]; -+ data[3] += rhs.data[3]; -+ -+ data[4] += rhs.data[4]; -+ data[5] += rhs.data[5]; -+ -+ data[6] += rhs.data[6]; -+ data[7] += rhs.data[7]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ -+ result.data[2] = data[2] - rhs.data[2]; -+ result.data[3] = data[3] - rhs.data[3]; -+ -+ result.data[4] = data[4] - rhs.data[4]; -+ result.data[5] = data[5] - rhs.data[5]; -+ -+ result.data[6] = data[6] - rhs.data[6]; -+ result.data[7] = data[7] - rhs.data[7]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ -+ data[2] -= rhs.data[2]; -+ data[3] -= rhs.data[3]; -+ -+ data[4] -= rhs.data[4]; -+ data[5] -= rhs.data[5]; -+ -+ data[6] -= rhs.data[6]; -+ data[7] -= rhs.data[7]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ -+ result.data[2] = data[2] * rhs.data[2]; -+ result.data[3] = data[3] * rhs.data[3]; -+ -+ result.data[4] = data[4] * rhs.data[4]; -+ result.data[5] = data[5] * rhs.data[5]; -+ -+ result.data[6] = data[6] * rhs.data[6]; -+ result.data[7] = data[7] * rhs.data[7]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ -+ result.data[2] = data[2] * s; -+ result.data[3] = data[3] * s; -+ -+ result.data[4] = data[4] * s; -+ result.data[5] = data[5] * s; -+ -+ result.data[6] = data[6] * s; -+ result.data[7] = data[7] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (4-by-2) 
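The hcat()/vcat()/block() overloads above are pure layout operations: horizontal concatenation interleaves columns within each row, while vertical concatenation appends whole rows, so the same matrix can be assembled either way. A minimal sketch of both, again against plain arrays with hypothetical helper names rather than the Matrix types defined here:

#include <array>
#include <cstdio>

// hcat: place one 4-by-1 column next to another (row-major 4-by-2 result).
std::array<float, 8> hcat4x1(std::array<float, 4> const &lhs,
                             std::array<float, 4> const &rhs) {
  std::array<float, 8> m{};
  for (int i = 0; i < 4; ++i) {
    m[i * 2 + 0] = lhs[i];
    m[i * 2 + 1] = rhs[i];
  }
  return m;
}

// vcat: stack one 2-by-2 block on top of another (row-major 4-by-2 result).
std::array<float, 8> vcat2x2(std::array<float, 4> const &upper,
                             std::array<float, 4> const &lower) {
  std::array<float, 8> m{};
  for (int k = 0; k < 4; ++k) {
    m[k] = upper[k];
    m[4 + k] = lower[k];
  }
  return m;
}

int main() {
  // The same 4-by-2 matrix assembled two ways.
  auto a = hcat4x1({1, 3, 5, 7}, {2, 4, 6, 8});
  auto b = vcat2x2({1, 2, 3, 4}, {5, 6, 7, 8});
  bool same = true;
  for (int k = 0; k < 8; ++k) same = same && (a[k] == b[k]);
  std::printf("hcat and vcat agree: %s\n", same ? "yes" : "no");
  return 0;
}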
-+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ -+ data[2] *= s; -+ data[3] *= s; -+ -+ data[4] *= s; -+ data[5] *= s; -+ -+ data[6] *= s; -+ data[7] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ -+ result.data[2] = data[2] / rhs.data[2]; -+ result.data[3] = data[3] / rhs.data[3]; -+ -+ result.data[4] = data[4] / rhs.data[4]; -+ result.data[5] = data[5] / rhs.data[5]; -+ -+ result.data[6] = data[6] / rhs.data[6]; -+ result.data[7] = data[7] / rhs.data[7]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ -+ result.data[2] = data[2] / s; -+ result.data[3] = data[3] / s; -+ -+ result.data[4] = data[4] / s; -+ result.data[5] = data[5] / s; -+ -+ result.data[6] = data[6] / s; -+ result.data[7] = data[7] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ -+ data[2] /= s; -+ data[3] /= s; -+ -+ data[4] /= s; -+ data[5] /= s; -+ -+ data[6] /= s; -+ data[7] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ -+ data[2] /= rhs.data[2]; -+ data[3] /= rhs.data[3]; -+ -+ data[4] /= rhs.data[4]; -+ data[5] /= rhs.data[5]; -+ -+ data[6] /= rhs.data[6]; -+ data[7] /= rhs.data[7]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ m.data[3] = -m.data[3]; -+ m.data[4] = -m.data[4]; -+ m.data[5] = -m.data[5]; -+ m.data[6] = -m.data[6]; -+ m.data[7] = -m.data[7]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 4-by-1-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[2] * rhs.data[0]; -+ accum.data[2] += data[4] * rhs.data[0]; -+ accum.data[3] += data[6] * rhs.data[0]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[1]; -+ accum.data[1] += data[3] * rhs.data[1]; -+ accum.data[2] += data[5] * rhs.data[1]; -+ accum.data[3] += data[7] * rhs.data[1]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-1-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ 
accum.data[2] += data[2] * rhs.data[0]; -+ accum.data[3] += data[2] * rhs.data[1]; -+ accum.data[4] += data[4] * rhs.data[0]; -+ accum.data[5] += data[4] * rhs.data[1]; -+ accum.data[6] += data[6] * rhs.data[0]; -+ accum.data[7] += data[6] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ accum.data[2] += data[3] * rhs.data[2]; -+ accum.data[3] += data[3] * rhs.data[3]; -+ accum.data[4] += data[5] * rhs.data[2]; -+ accum.data[5] += data[5] * rhs.data[3]; -+ accum.data[6] += data[7] * rhs.data[2]; -+ accum.data[7] += data[7] * rhs.data[3]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Matrix product of size 4-by-3-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[2] * rhs.data[0]; -+ accum.data[4] += data[2] * rhs.data[1]; -+ accum.data[5] += data[2] * rhs.data[2]; -+ accum.data[6] += data[4] * rhs.data[0]; -+ accum.data[7] += data[4] * rhs.data[1]; -+ accum.data[8] += data[4] * rhs.data[2]; -+ accum.data[9] += data[6] * rhs.data[0]; -+ accum.data[10] += data[6] * rhs.data[1]; -+ accum.data[11] += data[6] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ accum.data[3] += data[3] * rhs.data[3]; -+ accum.data[4] += data[3] * rhs.data[4]; -+ accum.data[5] += data[3] * rhs.data[5]; -+ accum.data[6] += data[5] * rhs.data[3]; -+ accum.data[7] += data[5] * rhs.data[4]; -+ accum.data[8] += data[5] * rhs.data[5]; -+ accum.data[9] += data[7] * rhs.data[3]; -+ accum.data[10] += data[7] * rhs.data[4]; -+ accum.data[11] += data[7] * rhs.data[5]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-3-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-4-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[2] * rhs.data[0]; -+ accum.data[5] += data[2] * rhs.data[1]; -+ accum.data[6] += data[2] * rhs.data[2]; -+ accum.data[7] += data[2] * rhs.data[3]; -+ accum.data[8] += data[4] * rhs.data[0]; -+ accum.data[9] += data[4] * rhs.data[1]; -+ accum.data[10] += data[4] * rhs.data[2]; -+ accum.data[11] += data[4] * rhs.data[3]; -+ accum.data[12] += data[6] * rhs.data[0]; -+ accum.data[13] += data[6] * rhs.data[1]; -+ accum.data[14] += data[6] * rhs.data[2]; -+ accum.data[15] += data[6] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ accum.data[4] += data[3] * rhs.data[4]; -+ accum.data[5] += data[3] * rhs.data[5]; -+ accum.data[6] += data[3] * rhs.data[6]; -+ accum.data[7] += data[3] * rhs.data[7]; -+ accum.data[8] += 
data[5] * rhs.data[4]; -+ accum.data[9] += data[5] * rhs.data[5]; -+ accum.data[10] += data[5] * rhs.data[6]; -+ accum.data[11] += data[5] * rhs.data[7]; -+ accum.data[12] += data[7] * rhs.data[4]; -+ accum.data[13] += data[7] * rhs.data[5]; -+ accum.data[14] += data[7] * rhs.data[6]; -+ accum.data[15] += data[7] * rhs.data[7]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-4-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ accum += data[3]; -+ accum += data[4]; -+ accum += data[5]; -+ accum += data[6]; -+ accum += data[7]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ accum += data[3] * data[3]; -+ accum += data[4] * data[4]; -+ accum += data[5] * data[5]; -+ accum += data[6] * data[6]; -+ accum += data[7] * data[7]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[3]; -+ -+ return accum; -+ } -+ -+}; -+ -+/// Template alias for 4-by-2 matrix -+template -+using Matrix4x2 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix4x2 make_Matrix4x2( -+ Element _0_0, Element _0_1, -+ Element _1_0, Element _1_1, -+ Element _2_0, Element _2_1, -+ Element _3_0, Element _3_1 -+) { -+ return Matrix4x2( -+ _0_0, _0_1, -+ _1_0, _1_1, -+ _2_0, _2_1, -+ _3_0, _3_1 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 4-by-3 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 4; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 3; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 12; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 4-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 4-by-3 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1, Element _0_2, -+ Element _1_0, Element _1_1, Element _1_2, -+ Element _2_0, Element _2_1, Element _2_2, -+ Element _3_0, Element _3_1, Element _3_2 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; -+ data[3] = _1_0; data[4] = _1_1; data[5] = _1_2; -+ data[6] = _2_0; data[7] = _2_1; data[8] = _2_2; -+ data[9] = _3_0; data[10] = _3_1; data[11] = _3_2; -+ } -+ -+ /// Constucts a 4-by-3 matrix from row vectors -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Matrix const &row_0, -+ Matrix const &row_1, -+ Matrix const &row_2, -+ Matrix const 
&row_3 -+ ) { -+ data[0] = row_0.data[0]; -+ data[1] = row_0.data[1]; -+ data[2] = row_0.data[2]; -+ data[3] = row_1.data[0]; -+ data[4] = row_1.data[1]; -+ data[5] = row_1.data[2]; -+ data[6] = row_2.data[0]; -+ data[7] = row_2.data[1]; -+ data[8] = row_2.data[2]; -+ data[9] = row_3.data[0]; -+ data[10] = row_3.data[1]; -+ data[11] = row_3.data[2]; -+ } -+ -+ /// Static method to construct a 4-by-3 matrix from column vectors -+ CUTLASS_HOST_DEVICE -+ static Matrix from_columns( -+ Matrix const &column_0, -+ Matrix const &column_1, -+ Matrix const &column_2 -+ ) { -+ Matrix result; -+ -+ result.data[0] = column_0.data[0]; -+ result.data[1] = column_1.data[0]; -+ result.data[2] = column_2.data[0]; -+ result.data[3] = column_0.data[1]; -+ result.data[4] = column_1.data[1]; -+ result.data[5] = column_2.data[1]; -+ result.data[6] = column_0.data[2]; -+ result.data[7] = column_1.data[2]; -+ result.data[8] = column_2.data[2]; -+ result.data[9] = column_0.data[3]; -+ result.data[10] = column_1.data[3]; -+ result.data[11] = column_2.data[3]; -+ return result; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ m.data[3] = s; -+ m.data[4] = s; -+ m.data[5] = s; -+ m.data[6] = s; -+ m.data[7] = s; -+ m.data[8] = s; -+ m.data[9] = s; -+ m.data[10] = s; -+ m.data[11] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[5] = diag.data[1]; -+ m.data[10] = diag.data[2]; -+ m.data[15] = diag.data[3]; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[5] = diag.data[1]; -+ m.data[10] = diag.data[2]; -+ m.data[15] = diag.data[3]; -+ -+ return m; -+ } -+ -+ /// Gets an array of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Matrix diagonal() const { -+ Matrix diag; -+ -+ diag.data[0] = data[0]; -+ diag.data[1] = data[5]; -+ diag.data[2] = data[10]; -+ diag.data[3] = data[15]; -+ -+ return diag; -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[4] = data[1]; -+ mt.data[8] = data[2]; -+ mt.data[1] = data[3]; -+ mt.data[5] = data[4]; -+ mt.data[9] = data[5]; -+ mt.data[2] = data[6]; -+ mt.data[6] = data[7]; -+ mt.data[10] = data[8]; -+ mt.data[3] = data[9]; -+ mt.data[7] = data[10]; -+ mt.data[11] = data[11]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 4 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 4 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ 
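The at(i, j), Coord<2>, and slice_RxC accessors reduce to one piece of row-major index arithmetic: the element offset is the row index times the column count plus the column index, and a slice reuses the parent's column count as its row stride. A generic standalone sketch of that rule, written under that assumption rather than copied from the concrete bodies in this header (RowMajor and its members are illustrative names):

#include <array>
#include <cassert>
#include <cstdio>

// Generic row-major view: an M-by-N matrix stored as a flat array of M * N
// elements, with element (i, j) at offset i * N + j.
template <typename T, int M, int N>
struct RowMajor {
  std::array<T, M * N> data{};

  T &at(int i, int j) {
    assert(0 <= i && i < M && 0 <= j && j < N);
    return data[i * N + j];
  }

  // Copy an R-by-C submatrix whose top-left corner is (i, j), mirroring the
  // slice_RxC helpers: the parent's column count N is the row stride.
  template <int R, int C>
  std::array<T, R * C> slice(int i, int j) const {
    std::array<T, R * C> s{};
    for (int r = 0; r < R; ++r)
      for (int c = 0; c < C; ++c)
        s[r * C + c] = data[(i + r) * N + (j + c)];
    return s;
  }
};

int main() {
  RowMajor<float, 4, 3> m;
  for (int i = 0; i < 4; ++i)
    for (int j = 0; j < 3; ++j)
      m.at(i, j) = float(10 * i + j);
  auto s = m.slice<2, 2>(1, 1);  // rows 1..2, columns 1..2 -> [11 12; 21 22]
  std::printf("s = [%g %g; %g %g]\n", s[0], s[1], s[2], s[3]);
  return 0;
}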
/// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 2] = m.data[2]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x3(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x3(v, i, 0); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 3]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 3] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 3]; -+ m.data[3] = data[i * 3 + j + 4]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 3] = m.data[2]; -+ data[i * 3 + j + 4] = m.data[3]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 2]; -+ m.data[3] = data[i * 3 + j + 3]; -+ m.data[4] = data[i * 3 + j + 4]; -+ m.data[5] = data[i * 3 + j + 5]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional 
offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 2] = m.data[2]; -+ data[i * 3 + j + 3] = m.data[3]; -+ data[i * 3 + j + 4] = m.data[4]; -+ data[i * 3 + j + 5] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 3]; -+ m.data[2] = data[i * 3 + j + 6]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 3] = m.data[1]; -+ data[i * 3 + j + 6] = m.data[2]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 3]; -+ m.data[3] = data[i * 3 + j + 4]; -+ m.data[4] = data[i * 3 + j + 6]; -+ m.data[5] = data[i * 3 + j + 7]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 3] = m.data[2]; -+ data[i * 3 + j + 4] = m.data[3]; -+ data[i * 3 + j + 6] = m.data[4]; -+ data[i * 3 + j + 7] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 2]; -+ m.data[3] = data[i * 3 + j + 3]; -+ m.data[4] = data[i * 3 + j + 4]; -+ m.data[5] = data[i * 3 + j + 5]; -+ m.data[6] = data[i * 3 + j + 6]; -+ m.data[7] = data[i * 3 + j + 7]; -+ m.data[8] = data[i * 3 + j + 8]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 2] = m.data[2]; -+ data[i * 3 + j + 3] = m.data[3]; -+ data[i * 3 + j + 4] = m.data[4]; -+ data[i * 3 + j + 5] = m.data[5]; -+ data[i * 3 + j + 6] = m.data[6]; -+ data[i * 3 + j + 7] = m.data[7]; -+ data[i * 3 + j + 8] = m.data[8]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_4x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 3]; -+ m.data[2] = data[i * 3 + j + 6]; -+ m.data[3] = data[i * 3 + j + 9]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_4x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 3] = m.data[1]; -+ data[i * 3 + j + 6] = m.data[2]; -+ data[i * 3 + j + 9] = m.data[3]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_4x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_4x1(v, 0, j); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_4x2(int i = 0, 
int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 3]; -+ m.data[3] = data[i * 3 + j + 4]; -+ m.data[4] = data[i * 3 + j + 6]; -+ m.data[5] = data[i * 3 + j + 7]; -+ m.data[6] = data[i * 3 + j + 9]; -+ m.data[7] = data[i * 3 + j + 10]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_4x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 3] = m.data[2]; -+ data[i * 3 + j + 4] = m.data[3]; -+ data[i * 3 + j + 6] = m.data[4]; -+ data[i * 3 + j + 7] = m.data[5]; -+ data[i * 3 + j + 9] = m.data[6]; -+ data[i * 3 + j + 10] = m.data[7]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_4x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 2]; -+ m.data[3] = data[i * 3 + j + 3]; -+ m.data[4] = data[i * 3 + j + 4]; -+ m.data[5] = data[i * 3 + j + 5]; -+ m.data[6] = data[i * 3 + j + 6]; -+ m.data[7] = data[i * 3 + j + 7]; -+ m.data[8] = data[i * 3 + j + 8]; -+ m.data[9] = data[i * 3 + j + 9]; -+ m.data[10] = data[i * 3 + j + 10]; -+ m.data[11] = data[i * 3 + j + 11]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_4x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 2] = m.data[2]; -+ data[i * 3 + j + 3] = m.data[3]; -+ data[i * 3 + j + 4] = m.data[4]; -+ data[i * 3 + j + 5] = m.data[5]; -+ data[i * 3 + j + 6] = m.data[6]; -+ data[i * 3 + j + 7] = m.data[7]; -+ data[i * 3 + j + 8] = m.data[8]; -+ data[i * 3 + j + 9] = m.data[9]; -+ data[i * 3 + j + 10] = m.data[10]; -+ data[i * 3 + j + 11] = m.data[11]; -+ -+ return *this; -+ } -+ -+ /// Forms a 4-by-3 matrix by horizontally concatenating a 4-by-1 matrix with a 4-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1) -+ , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1) -+ , lhs.at(2, 0), rhs.at(2, 0), rhs.at(2, 1) -+ , lhs.at(3, 0), rhs.at(3, 0), rhs.at(3, 1)); -+ } -+ -+ /// Forms a 4-by-3 matrix by horizontally concatenating a 4-by-2 matrix with a 4-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0) -+ , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0) -+ , lhs.at(2, 0), lhs.at(2, 1), rhs.at(2, 0) -+ , lhs.at(3, 0), lhs.at(3, 1), rhs.at(3, 0)); -+ } -+ -+ /// Concatenates this matrix with a a 4-by-1 matrix to form a 4-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Forms a 4-by-3 matrix by vertically concatenating a 1-by-3 matrix with a 3-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2) -+ , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2) -+ , lower.at(2, 0), lower.at(2, 1), lower.at(2, 2)); -+ } -+ -+ /// Forms a 4-by-3 matrix by vertically concatenating a 2-by-3 matrix with a 2-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix 
const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) -+ , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2) -+ , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2)); -+ } -+ -+ /// Forms a 4-by-3 matrix by vertically concatenating a 3-by-3 matrix with a 1-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) -+ , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2) -+ , upper.at(2, 0), upper.at(2, 1), upper.at(2, 2) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2)); -+ } -+ -+ /// Forms a 4-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Element A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A, B.at(0, 0), B.at(0, 1) -+ , C.at(0, 0), D.at(0, 0), D.at(0, 1) -+ , C.at(1, 0), D.at(1, 0), D.at(1, 1) -+ , C.at(2, 0), D.at(2, 0), D.at(2, 1) -+ ); -+ } -+ -+ /// Forms a 4-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Element B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B -+ , C.at(0, 0), C.at(0, 1), D.at(0, 0) -+ , C.at(1, 0), C.at(1, 1), D.at(1, 0) -+ , C.at(2, 0), C.at(2, 1), D.at(2, 0) -+ ); -+ } -+ -+ /// Forms a 4-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), B.at(0, 0), B.at(0, 1) -+ , A.at(1, 0), B.at(1, 0), B.at(1, 1) -+ , C.at(0, 0), D.at(0, 0), D.at(0, 1) -+ , C.at(1, 0), D.at(1, 0), D.at(1, 1) -+ ); -+ } -+ -+ /// Forms a 4-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B.at(0, 0) -+ , A.at(1, 0), A.at(1, 1), B.at(1, 0) -+ , C.at(0, 0), C.at(0, 1), D.at(0, 0) -+ , C.at(1, 0), C.at(1, 1), D.at(1, 0) -+ ); -+ } -+ -+ /// Forms a 4-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Element C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), B.at(0, 0), B.at(0, 1) -+ , A.at(1, 0), B.at(1, 0), B.at(1, 1) -+ , A.at(2, 0), B.at(2, 0), B.at(2, 1) -+ , C, D.at(0, 0), D.at(0, 1) -+ ); -+ } -+ -+ /// Forms a 4-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Element D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B.at(0, 0) -+ , A.at(1, 0), A.at(1, 1), B.at(1, 0) -+ , A.at(2, 0), A.at(2, 1), B.at(2, 0) -+ , C.at(0, 0), C.at(0, 1), D -+ ); -+ } -+ -+ /// Elementwise add operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ result.data[2] = data[2] + rhs.data[2]; -+ -+ result.data[3] = data[3] + rhs.data[3]; -+ result.data[4] = data[4] + rhs.data[4]; -+ result.data[5] = data[5] + rhs.data[5]; -+ -+ result.data[6] = data[6] + rhs.data[6]; -+ result.data[7] = data[7] + rhs.data[7]; -+ result.data[8] = data[8] + rhs.data[8]; -+ -+ result.data[9] = data[9] + rhs.data[9]; -+ result.data[10] = data[10] + rhs.data[10]; -+ result.data[11] = data[11] + 
rhs.data[11]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ data[2] += rhs.data[2]; -+ -+ data[3] += rhs.data[3]; -+ data[4] += rhs.data[4]; -+ data[5] += rhs.data[5]; -+ -+ data[6] += rhs.data[6]; -+ data[7] += rhs.data[7]; -+ data[8] += rhs.data[8]; -+ -+ data[9] += rhs.data[9]; -+ data[10] += rhs.data[10]; -+ data[11] += rhs.data[11]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ result.data[2] = data[2] - rhs.data[2]; -+ -+ result.data[3] = data[3] - rhs.data[3]; -+ result.data[4] = data[4] - rhs.data[4]; -+ result.data[5] = data[5] - rhs.data[5]; -+ -+ result.data[6] = data[6] - rhs.data[6]; -+ result.data[7] = data[7] - rhs.data[7]; -+ result.data[8] = data[8] - rhs.data[8]; -+ -+ result.data[9] = data[9] - rhs.data[9]; -+ result.data[10] = data[10] - rhs.data[10]; -+ result.data[11] = data[11] - rhs.data[11]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ data[2] -= rhs.data[2]; -+ -+ data[3] -= rhs.data[3]; -+ data[4] -= rhs.data[4]; -+ data[5] -= rhs.data[5]; -+ -+ data[6] -= rhs.data[6]; -+ data[7] -= rhs.data[7]; -+ data[8] -= rhs.data[8]; -+ -+ data[9] -= rhs.data[9]; -+ data[10] -= rhs.data[10]; -+ data[11] -= rhs.data[11]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ result.data[2] = data[2] * rhs.data[2]; -+ -+ result.data[3] = data[3] * rhs.data[3]; -+ result.data[4] = data[4] * rhs.data[4]; -+ result.data[5] = data[5] * rhs.data[5]; -+ -+ result.data[6] = data[6] * rhs.data[6]; -+ result.data[7] = data[7] * rhs.data[7]; -+ result.data[8] = data[8] * rhs.data[8]; -+ -+ result.data[9] = data[9] * rhs.data[9]; -+ result.data[10] = data[10] * rhs.data[10]; -+ result.data[11] = data[11] * rhs.data[11]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ result.data[2] = data[2] * s; -+ -+ result.data[3] = data[3] * s; -+ result.data[4] = data[4] * s; -+ result.data[5] = data[5] * s; -+ -+ result.data[6] = data[6] * s; -+ result.data[7] = data[7] * s; -+ result.data[8] = data[8] * s; -+ -+ result.data[9] = data[9] * s; -+ result.data[10] = data[10] * s; -+ result.data[11] = data[11] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ 
data[2] *= s; -+ -+ data[3] *= s; -+ data[4] *= s; -+ data[5] *= s; -+ -+ data[6] *= s; -+ data[7] *= s; -+ data[8] *= s; -+ -+ data[9] *= s; -+ data[10] *= s; -+ data[11] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ result.data[2] = data[2] / rhs.data[2]; -+ -+ result.data[3] = data[3] / rhs.data[3]; -+ result.data[4] = data[4] / rhs.data[4]; -+ result.data[5] = data[5] / rhs.data[5]; -+ -+ result.data[6] = data[6] / rhs.data[6]; -+ result.data[7] = data[7] / rhs.data[7]; -+ result.data[8] = data[8] / rhs.data[8]; -+ -+ result.data[9] = data[9] / rhs.data[9]; -+ result.data[10] = data[10] / rhs.data[10]; -+ result.data[11] = data[11] / rhs.data[11]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ result.data[2] = data[2] / s; -+ -+ result.data[3] = data[3] / s; -+ result.data[4] = data[4] / s; -+ result.data[5] = data[5] / s; -+ -+ result.data[6] = data[6] / s; -+ result.data[7] = data[7] / s; -+ result.data[8] = data[8] / s; -+ -+ result.data[9] = data[9] / s; -+ result.data[10] = data[10] / s; -+ result.data[11] = data[11] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ data[2] /= s; -+ -+ data[3] /= s; -+ data[4] /= s; -+ data[5] /= s; -+ -+ data[6] /= s; -+ data[7] /= s; -+ data[8] /= s; -+ -+ data[9] /= s; -+ data[10] /= s; -+ data[11] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ data[2] /= rhs.data[2]; -+ -+ data[3] /= rhs.data[3]; -+ data[4] /= rhs.data[4]; -+ data[5] /= rhs.data[5]; -+ -+ data[6] /= rhs.data[6]; -+ data[7] /= rhs.data[7]; -+ data[8] /= rhs.data[8]; -+ -+ data[9] /= rhs.data[9]; -+ data[10] /= rhs.data[10]; -+ data[11] /= rhs.data[11]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ m.data[3] = -m.data[3]; -+ m.data[4] = -m.data[4]; -+ m.data[5] = -m.data[5]; -+ m.data[6] = -m.data[6]; -+ m.data[7] = -m.data[7]; -+ m.data[8] = -m.data[8]; -+ m.data[9] = -m.data[9]; -+ m.data[10] = -m.data[10]; -+ m.data[11] = -m.data[11]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 4-by-1-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[3] * rhs.data[0]; -+ accum.data[2] += data[6] * rhs.data[0]; -+ accum.data[3] += data[9] * rhs.data[0]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[1]; -+ accum.data[1] += data[4] * rhs.data[1]; -+ accum.data[2] += data[7] * rhs.data[1]; -+ accum.data[3] += data[10] * 
rhs.data[1]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[2]; -+ accum.data[1] += data[5] * rhs.data[2]; -+ accum.data[2] += data[8] * rhs.data[2]; -+ accum.data[3] += data[11] * rhs.data[2]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-1-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-2-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[3] * rhs.data[0]; -+ accum.data[3] += data[3] * rhs.data[1]; -+ accum.data[4] += data[6] * rhs.data[0]; -+ accum.data[5] += data[6] * rhs.data[1]; -+ accum.data[6] += data[9] * rhs.data[0]; -+ accum.data[7] += data[9] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ accum.data[2] += data[4] * rhs.data[2]; -+ accum.data[3] += data[4] * rhs.data[3]; -+ accum.data[4] += data[7] * rhs.data[2]; -+ accum.data[5] += data[7] * rhs.data[3]; -+ accum.data[6] += data[10] * rhs.data[2]; -+ accum.data[7] += data[10] * rhs.data[3]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[4]; -+ accum.data[1] += data[2] * rhs.data[5]; -+ accum.data[2] += data[5] * rhs.data[4]; -+ accum.data[3] += data[5] * rhs.data[5]; -+ accum.data[4] += data[8] * rhs.data[4]; -+ accum.data[5] += data[8] * rhs.data[5]; -+ accum.data[6] += data[11] * rhs.data[4]; -+ accum.data[7] += data[11] * rhs.data[5]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-2-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[3] * rhs.data[0]; -+ accum.data[4] += data[3] * rhs.data[1]; -+ accum.data[5] += data[3] * rhs.data[2]; -+ accum.data[6] += data[6] * rhs.data[0]; -+ accum.data[7] += data[6] * rhs.data[1]; -+ accum.data[8] += data[6] * rhs.data[2]; -+ accum.data[9] += data[9] * rhs.data[0]; -+ accum.data[10] += data[9] * rhs.data[1]; -+ accum.data[11] += data[9] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ accum.data[3] += data[4] * rhs.data[3]; -+ accum.data[4] += data[4] * rhs.data[4]; -+ accum.data[5] += data[4] * rhs.data[5]; -+ accum.data[6] += data[7] * rhs.data[3]; -+ accum.data[7] += data[7] * rhs.data[4]; -+ accum.data[8] += data[7] * rhs.data[5]; -+ accum.data[9] += data[10] * rhs.data[3]; -+ accum.data[10] += data[10] * rhs.data[4]; -+ accum.data[11] += data[10] * rhs.data[5]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[6]; -+ accum.data[1] += data[2] * rhs.data[7]; -+ accum.data[2] += data[2] * rhs.data[8]; -+ accum.data[3] += data[5] * rhs.data[6]; -+ accum.data[4] += data[5] * rhs.data[7]; -+ accum.data[5] += data[5] * rhs.data[8]; -+ accum.data[6] += data[8] * rhs.data[6]; -+ accum.data[7] += data[8] * rhs.data[7]; -+ accum.data[8] += data[8] * rhs.data[8]; -+ accum.data[9] += data[11] * rhs.data[6]; -+ accum.data[10] += data[11] * rhs.data[7]; -+ accum.data[11] += data[11] * rhs.data[8]; -+ -+ return accum; -+ } -+ -+ /// Matrix 
product of size 4-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Matrix product of size 4-by-4-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[3] * rhs.data[0]; -+ accum.data[5] += data[3] * rhs.data[1]; -+ accum.data[6] += data[3] * rhs.data[2]; -+ accum.data[7] += data[3] * rhs.data[3]; -+ accum.data[8] += data[6] * rhs.data[0]; -+ accum.data[9] += data[6] * rhs.data[1]; -+ accum.data[10] += data[6] * rhs.data[2]; -+ accum.data[11] += data[6] * rhs.data[3]; -+ accum.data[12] += data[9] * rhs.data[0]; -+ accum.data[13] += data[9] * rhs.data[1]; -+ accum.data[14] += data[9] * rhs.data[2]; -+ accum.data[15] += data[9] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ accum.data[4] += data[4] * rhs.data[4]; -+ accum.data[5] += data[4] * rhs.data[5]; -+ accum.data[6] += data[4] * rhs.data[6]; -+ accum.data[7] += data[4] * rhs.data[7]; -+ accum.data[8] += data[7] * rhs.data[4]; -+ accum.data[9] += data[7] * rhs.data[5]; -+ accum.data[10] += data[7] * rhs.data[6]; -+ accum.data[11] += data[7] * rhs.data[7]; -+ accum.data[12] += data[10] * rhs.data[4]; -+ accum.data[13] += data[10] * rhs.data[5]; -+ accum.data[14] += data[10] * rhs.data[6]; -+ accum.data[15] += data[10] * rhs.data[7]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[8]; -+ accum.data[1] += data[2] * rhs.data[9]; -+ accum.data[2] += data[2] * rhs.data[10]; -+ accum.data[3] += data[2] * rhs.data[11]; -+ accum.data[4] += data[5] * rhs.data[8]; -+ accum.data[5] += data[5] * rhs.data[9]; -+ accum.data[6] += data[5] * rhs.data[10]; -+ accum.data[7] += data[5] * rhs.data[11]; -+ accum.data[8] += data[8] * rhs.data[8]; -+ accum.data[9] += data[8] * rhs.data[9]; -+ accum.data[10] += data[8] * rhs.data[10]; -+ accum.data[11] += data[8] * rhs.data[11]; -+ accum.data[12] += data[11] * rhs.data[8]; -+ accum.data[13] += data[11] * rhs.data[9]; -+ accum.data[14] += data[11] * rhs.data[10]; -+ accum.data[15] += data[11] * rhs.data[11]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-4-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ accum += data[3]; -+ accum += data[4]; -+ accum += data[5]; -+ accum += data[6]; -+ accum += data[7]; -+ accum += data[8]; -+ accum += data[9]; -+ accum += data[10]; -+ accum += data[11]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ accum += data[3] * data[3]; -+ accum += data[4] * data[4]; -+ accum += data[5] * data[5]; -+ accum += data[6] * data[6]; -+ accum += data[7] * data[7]; -+ accum += data[8] * data[8]; -+ accum += 
data[9] * data[9]; -+ accum += data[10] * data[10]; -+ accum += data[11] * data[11]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[4]; -+ accum += data[8]; -+ -+ return accum; -+ } -+ -+}; -+ -+/// Template alias for 4-by-3 matrix -+template -+using Matrix4x3 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix4x3 make_Matrix4x3( -+ Element _0_0, Element _0_1, Element _0_2, -+ Element _1_0, Element _1_1, Element _1_2, -+ Element _2_0, Element _2_1, Element _2_2, -+ Element _3_0, Element _3_1, Element _3_2 -+) { -+ return Matrix4x3( -+ _0_0, _0_1, _0_2, -+ _1_0, _1_1, _1_2, -+ _2_0, _2_1, _2_2, -+ _3_0, _3_1, _3_2 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 4-by-4 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 4; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 4; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 16; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 4-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 4-by-4 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1, Element _0_2, Element _0_3, -+ Element _1_0, Element _1_1, Element _1_2, Element _1_3, -+ Element _2_0, Element _2_1, Element _2_2, Element _2_3, -+ Element _3_0, Element _3_1, Element _3_2, Element _3_3 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; data[3] = _0_3; -+ data[4] = _1_0; data[5] = _1_1; data[6] = _1_2; data[7] = _1_3; -+ data[8] = _2_0; data[9] = _2_1; data[10] = _2_2; data[11] = _2_3; -+ data[12] = _3_0; data[13] = _3_1; data[14] = _3_2; data[15] = _3_3; -+ } -+ -+ /// Constucts a 4-by-4 matrix from row vectors -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Matrix const &row_0, -+ Matrix const &row_1, -+ Matrix const &row_2, -+ Matrix const &row_3 -+ ) { -+ data[0] = row_0.data[0]; -+ data[1] = row_0.data[1]; -+ data[2] = row_0.data[2]; -+ data[3] = row_0.data[3]; -+ data[4] = row_1.data[0]; -+ data[5] = row_1.data[1]; -+ data[6] = row_1.data[2]; -+ data[7] = row_1.data[3]; -+ data[8] = row_2.data[0]; -+ data[9] = row_2.data[1]; -+ data[10] = row_2.data[2]; -+ data[11] = row_2.data[3]; -+ data[12] = row_3.data[0]; -+ data[13] = row_3.data[1]; -+ data[14] = row_3.data[2]; -+ data[15] = row_3.data[3]; -+ } -+ -+ /// Static method to construct a 4-by-4 matrix from column vectors -+ CUTLASS_HOST_DEVICE -+ static Matrix from_columns( -+ Matrix const &column_0, -+ Matrix const &column_1, -+ Matrix const &column_2, -+ Matrix const &column_3 -+ ) { -+ Matrix result; -+ -+ result.data[0] = column_0.data[0]; -+ result.data[1] = column_1.data[0]; -+ result.data[2] = column_2.data[0]; -+ result.data[3] = 
column_3.data[0]; -+ result.data[4] = column_0.data[1]; -+ result.data[5] = column_1.data[1]; -+ result.data[6] = column_2.data[1]; -+ result.data[7] = column_3.data[1]; -+ result.data[8] = column_0.data[2]; -+ result.data[9] = column_1.data[2]; -+ result.data[10] = column_2.data[2]; -+ result.data[11] = column_3.data[2]; -+ result.data[12] = column_0.data[3]; -+ result.data[13] = column_1.data[3]; -+ result.data[14] = column_2.data[3]; -+ result.data[15] = column_3.data[3]; -+ return result; -+ } -+ -+ /// Constructs an identity matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix identity() { -+ Matrix m; -+ -+ m.data[0] = Element(1); -+ m.data[5] = Element(1); -+ m.data[10] = Element(1); -+ m.data[15] = Element(1); -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ m.data[3] = s; -+ m.data[4] = s; -+ m.data[5] = s; -+ m.data[6] = s; -+ m.data[7] = s; -+ m.data[8] = s; -+ m.data[9] = s; -+ m.data[10] = s; -+ m.data[11] = s; -+ m.data[12] = s; -+ m.data[13] = s; -+ m.data[14] = s; -+ m.data[15] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[5] = diag.data[1]; -+ m.data[10] = diag.data[2]; -+ m.data[15] = diag.data[3]; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[5] = diag.data[1]; -+ m.data[10] = diag.data[2]; -+ m.data[15] = diag.data[3]; -+ -+ return m; -+ } -+ -+ /// Gets an array of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Matrix diagonal() const { -+ Matrix diag; -+ -+ diag.data[0] = data[0]; -+ diag.data[1] = data[5]; -+ diag.data[2] = data[10]; -+ diag.data[3] = data[15]; -+ -+ return diag; -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[4] = data[1]; -+ mt.data[8] = data[2]; -+ mt.data[12] = data[3]; -+ mt.data[1] = data[4]; -+ mt.data[5] = data[5]; -+ mt.data[9] = data[6]; -+ mt.data[13] = data[7]; -+ mt.data[2] = data[8]; -+ mt.data[6] = data[9]; -+ mt.data[10] = data[10]; -+ mt.data[14] = data[11]; -+ mt.data[3] = data[12]; -+ mt.data[7] = data[13]; -+ mt.data[11] = data[14]; -+ mt.data[15] = data[15]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 4 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 4 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by 
offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x4(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 3]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x4(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 3] = m.data[3]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x4(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x4(v, i, 0); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 4]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 4] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 4]; -+ m.data[3] = data[i * 4 + j + 5]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 4] = m.data[2]; -+ data[i * 4 
+ j + 5] = m.data[3]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 4]; -+ m.data[4] = data[i * 4 + j + 5]; -+ m.data[5] = data[i * 4 + j + 6]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 4] = m.data[3]; -+ data[i * 4 + j + 5] = m.data[4]; -+ data[i * 4 + j + 6] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x4(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 3]; -+ m.data[4] = data[i * 4 + j + 4]; -+ m.data[5] = data[i * 4 + j + 5]; -+ m.data[6] = data[i * 4 + j + 6]; -+ m.data[7] = data[i * 4 + j + 7]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x4(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 3] = m.data[3]; -+ data[i * 4 + j + 4] = m.data[4]; -+ data[i * 4 + j + 5] = m.data[5]; -+ data[i * 4 + j + 6] = m.data[6]; -+ data[i * 4 + j + 7] = m.data[7]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 4]; -+ m.data[2] = data[i * 4 + j + 8]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 4] = m.data[1]; -+ data[i * 4 + j + 8] = m.data[2]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 4]; -+ m.data[3] = data[i * 4 + j + 5]; -+ m.data[4] = data[i * 4 + j + 8]; -+ m.data[5] = data[i * 4 + j + 9]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 4] = m.data[2]; -+ data[i * 4 + j + 5] = m.data[3]; -+ data[i * 4 + j + 8] = m.data[4]; -+ data[i * 4 + j + 9] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 4]; -+ m.data[4] = data[i * 4 + j + 5]; -+ m.data[5] = data[i * 4 + j + 6]; -+ m.data[6] = data[i * 4 + j + 8]; -+ m.data[7] = data[i * 4 + j + 9]; -+ m.data[8] = data[i * 4 + j + 10]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional 
offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 4] = m.data[3]; -+ data[i * 4 + j + 5] = m.data[4]; -+ data[i * 4 + j + 6] = m.data[5]; -+ data[i * 4 + j + 8] = m.data[6]; -+ data[i * 4 + j + 9] = m.data[7]; -+ data[i * 4 + j + 10] = m.data[8]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x4(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 3]; -+ m.data[4] = data[i * 4 + j + 4]; -+ m.data[5] = data[i * 4 + j + 5]; -+ m.data[6] = data[i * 4 + j + 6]; -+ m.data[7] = data[i * 4 + j + 7]; -+ m.data[8] = data[i * 4 + j + 8]; -+ m.data[9] = data[i * 4 + j + 9]; -+ m.data[10] = data[i * 4 + j + 10]; -+ m.data[11] = data[i * 4 + j + 11]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x4(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 3] = m.data[3]; -+ data[i * 4 + j + 4] = m.data[4]; -+ data[i * 4 + j + 5] = m.data[5]; -+ data[i * 4 + j + 6] = m.data[6]; -+ data[i * 4 + j + 7] = m.data[7]; -+ data[i * 4 + j + 8] = m.data[8]; -+ data[i * 4 + j + 9] = m.data[9]; -+ data[i * 4 + j + 10] = m.data[10]; -+ data[i * 4 + j + 11] = m.data[11]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_4x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 4]; -+ m.data[2] = data[i * 4 + j + 8]; -+ m.data[3] = data[i * 4 + j + 12]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_4x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 4] = m.data[1]; -+ data[i * 4 + j + 8] = m.data[2]; -+ data[i * 4 + j + 12] = m.data[3]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_4x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_4x1(v, 0, j); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_4x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 4]; -+ m.data[3] = data[i * 4 + j + 5]; -+ m.data[4] = data[i * 4 + j + 8]; -+ m.data[5] = data[i * 4 + j + 9]; -+ m.data[6] = data[i * 4 + j + 12]; -+ m.data[7] = data[i * 4 + j + 13]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_4x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 4] = m.data[2]; -+ data[i * 4 + j + 5] = m.data[3]; -+ data[i * 4 + j + 8] = m.data[4]; -+ data[i * 4 + j + 9] = m.data[5]; -+ data[i * 4 + j + 12] = m.data[6]; -+ data[i * 4 + j + 13] = m.data[7]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_4x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j 
+ 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 4]; -+ m.data[4] = data[i * 4 + j + 5]; -+ m.data[5] = data[i * 4 + j + 6]; -+ m.data[6] = data[i * 4 + j + 8]; -+ m.data[7] = data[i * 4 + j + 9]; -+ m.data[8] = data[i * 4 + j + 10]; -+ m.data[9] = data[i * 4 + j + 12]; -+ m.data[10] = data[i * 4 + j + 13]; -+ m.data[11] = data[i * 4 + j + 14]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_4x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 4] = m.data[3]; -+ data[i * 4 + j + 5] = m.data[4]; -+ data[i * 4 + j + 6] = m.data[5]; -+ data[i * 4 + j + 8] = m.data[6]; -+ data[i * 4 + j + 9] = m.data[7]; -+ data[i * 4 + j + 10] = m.data[8]; -+ data[i * 4 + j + 12] = m.data[9]; -+ data[i * 4 + j + 13] = m.data[10]; -+ data[i * 4 + j + 14] = m.data[11]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_4x4(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 3]; -+ m.data[4] = data[i * 4 + j + 4]; -+ m.data[5] = data[i * 4 + j + 5]; -+ m.data[6] = data[i * 4 + j + 6]; -+ m.data[7] = data[i * 4 + j + 7]; -+ m.data[8] = data[i * 4 + j + 8]; -+ m.data[9] = data[i * 4 + j + 9]; -+ m.data[10] = data[i * 4 + j + 10]; -+ m.data[11] = data[i * 4 + j + 11]; -+ m.data[12] = data[i * 4 + j + 12]; -+ m.data[13] = data[i * 4 + j + 13]; -+ m.data[14] = data[i * 4 + j + 14]; -+ m.data[15] = data[i * 4 + j + 15]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_4x4(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 3] = m.data[3]; -+ data[i * 4 + j + 4] = m.data[4]; -+ data[i * 4 + j + 5] = m.data[5]; -+ data[i * 4 + j + 6] = m.data[6]; -+ data[i * 4 + j + 7] = m.data[7]; -+ data[i * 4 + j + 8] = m.data[8]; -+ data[i * 4 + j + 9] = m.data[9]; -+ data[i * 4 + j + 10] = m.data[10]; -+ data[i * 4 + j + 11] = m.data[11]; -+ data[i * 4 + j + 12] = m.data[12]; -+ data[i * 4 + j + 13] = m.data[13]; -+ data[i * 4 + j + 14] = m.data[14]; -+ data[i * 4 + j + 15] = m.data[15]; -+ -+ return *this; -+ } -+ -+ /// Forms a 4-by-4 matrix by horizontally concatenating a 4-by-1 matrix with a 4-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1), rhs.at(0, 2) -+ , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1), rhs.at(1, 2) -+ , lhs.at(2, 0), rhs.at(2, 0), rhs.at(2, 1), rhs.at(2, 2) -+ , lhs.at(3, 0), rhs.at(3, 0), rhs.at(3, 1), rhs.at(3, 2)); -+ } -+ -+ /// Forms a 4-by-4 matrix by horizontally concatenating a 4-by-2 matrix with a 4-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0), rhs.at(0, 1) -+ , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0), rhs.at(1, 1) -+ , lhs.at(2, 0), lhs.at(2, 1), rhs.at(2, 0), rhs.at(2, 1) -+ , lhs.at(3, 0), lhs.at(3, 1), rhs.at(3, 0), rhs.at(3, 1)); -+ } -+ -+ /// Forms a 4-by-4 matrix by horizontally concatenating a 4-by-3 matrix with a 4-by-1 matrix -+ 
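
Every slice_MxN / set_slice_MxN member above follows the same row-major addressing rule: element (r, c) of an M-by-N submatrix anchored at (i, j) of this 4-column matrix lives at data[(i + r) * 4 + (j + c)]. A minimal standalone sketch of that rule, using a hypothetical slice helper over a plain std::array rather than the CUTLASS types:

    #include <array>
    #include <cassert>

    // Copy an M-by-N block anchored at (i, j) out of a row-major 4x4 array.
    // This mirrors the unrolled data[i * 4 + j + ...] pattern used by slice_MxN.
    template <int M, int N>
    std::array<float, M * N> slice(std::array<float, 16> const &src, int i, int j) {
      std::array<float, M * N> out{};
      for (int r = 0; r < M; ++r) {
        for (int c = 0; c < N; ++c) {
          out[r * N + c] = src[(i + r) * 4 + (j + c)];  // row stride is the full width, 4
        }
      }
      return out;
    }

    int main() {
      std::array<float, 16> m{};
      for (int k = 0; k < 16; ++k) m[k] = float(k);     // m(r, c) == r * 4 + c
      auto block = slice<2, 2>(m, 1, 2);                // rows 1..2, cols 2..3
      assert(block[0] == 6 && block[1] == 7 && block[2] == 10 && block[3] == 11);
      return 0;
    }

The CUTLASS code simply unrolls the two loops by hand for each fixed M and N, which is why slice_2x2, for instance, reads offsets 0, 1, 4 and 5.
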
CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), lhs.at(0, 2), rhs.at(0, 0) -+ , lhs.at(1, 0), lhs.at(1, 1), lhs.at(1, 2), rhs.at(1, 0) -+ , lhs.at(2, 0), lhs.at(2, 1), lhs.at(2, 2), rhs.at(2, 0) -+ , lhs.at(3, 0), lhs.at(3, 1), lhs.at(3, 2), rhs.at(3, 0)); -+ } -+ -+ /// Forms a 4-by-4 matrix by vertically concatenating a 1-by-4 matrix with a 3-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3) -+ , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2), lower.at(1, 3) -+ , lower.at(2, 0), lower.at(2, 1), lower.at(2, 2), lower.at(2, 3)); -+ } -+ -+ /// Forms a 4-by-4 matrix by vertically concatenating a 2-by-4 matrix with a 2-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) -+ , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2), upper.at(1, 3) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3) -+ , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2), lower.at(1, 3)); -+ } -+ -+ /// Forms a 4-by-4 matrix by vertically concatenating a 3-by-4 matrix with a 1-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) -+ , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2), upper.at(1, 3) -+ , upper.at(2, 0), upper.at(2, 1), upper.at(2, 2), upper.at(2, 3) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3)); -+ } -+ -+ /// Forms a 4-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Element A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A, B.at(0, 0), B.at(0, 1), B.at(0, 2) -+ , C.at(0, 0), D.at(0, 0), D.at(0, 1), D.at(0, 2) -+ , C.at(1, 0), D.at(1, 0), D.at(1, 1), D.at(1, 2) -+ , C.at(2, 0), D.at(2, 0), D.at(2, 1), D.at(2, 2) -+ ); -+ } -+ -+ /// Forms a 4-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) -+ , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) -+ , C.at(1, 0), C.at(1, 1), D.at(1, 0), D.at(1, 1) -+ , C.at(2, 0), C.at(2, 1), D.at(2, 0), D.at(2, 1) -+ ); -+ } -+ -+ /// Forms a 4-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Element B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), A.at(0, 2), B -+ , C.at(0, 0), C.at(0, 1), C.at(0, 2), D.at(0, 0) -+ , C.at(1, 0), C.at(1, 1), C.at(1, 2), D.at(1, 0) -+ , C.at(2, 0), C.at(2, 1), C.at(2, 2), D.at(2, 0) -+ ); -+ } -+ -+ /// Forms a 4-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), B.at(0, 0), B.at(0, 1), B.at(0, 2) -+ , A.at(1, 0), B.at(1, 0), B.at(1, 1), B.at(1, 2) -+ , C.at(0, 0), D.at(0, 0), D.at(0, 1), D.at(0, 2) -+ , C.at(1, 0), D.at(1, 0), D.at(1, 1), D.at(1, 2) -+ ); -+ } -+ -+ /// Forms a 4-by-4 matrix by concatenating four components -+ 
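
The hcat, vcat and block overloads above are fully unrolled element lists: each block(A, B, C, D) drops its four arguments into the four quadrants of the result, and hcat / vcat are the corresponding one-dimensional splits. The loop form of the same composition for four 2-by-2 quadrants, as a sketch with a hypothetical block2x2 helper over plain arrays:

    #include <array>
    #include <cassert>

    // Assemble a row-major 4x4 from four 2x2 quadrants: A top-left, B top-right,
    // C bottom-left, D bottom-right -- the loop form of the unrolled block() overloads.
    std::array<float, 16> block2x2(std::array<float, 4> const &A, std::array<float, 4> const &B,
                                   std::array<float, 4> const &C, std::array<float, 4> const &D) {
      std::array<float, 16> out{};
      for (int r = 0; r < 2; ++r) {
        for (int c = 0; c < 2; ++c) {
          out[r * 4 + c]           = A[r * 2 + c];   // rows 0..1, cols 0..1
          out[r * 4 + c + 2]       = B[r * 2 + c];   // rows 0..1, cols 2..3
          out[(r + 2) * 4 + c]     = C[r * 2 + c];   // rows 2..3, cols 0..1
          out[(r + 2) * 4 + c + 2] = D[r * 2 + c];   // rows 2..3, cols 2..3
        }
      }
      return out;
    }

    int main() {
      std::array<float, 4> A{1, 2, 5, 6}, B{3, 4, 7, 8}, C{9, 10, 13, 14}, D{11, 12, 15, 16};
      auto m = block2x2(A, B, C, D);
      for (int k = 0; k < 16; ++k) assert(m[k] == k + 1);   // reassembles 1..16 in row-major order
      return 0;
    }
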
CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) -+ , A.at(1, 0), A.at(1, 1), B.at(1, 0), B.at(1, 1) -+ , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) -+ , C.at(1, 0), C.at(1, 1), D.at(1, 0), D.at(1, 1) -+ ); -+ } -+ -+ /// Forms a 4-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), A.at(0, 2), B.at(0, 0) -+ , A.at(1, 0), A.at(1, 1), A.at(1, 2), B.at(1, 0) -+ , C.at(0, 0), C.at(0, 1), C.at(0, 2), D.at(0, 0) -+ , C.at(1, 0), C.at(1, 1), C.at(1, 2), D.at(1, 0) -+ ); -+ } -+ -+ /// Forms a 4-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Element C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), B.at(0, 0), B.at(0, 1), B.at(0, 2) -+ , A.at(1, 0), B.at(1, 0), B.at(1, 1), B.at(1, 2) -+ , A.at(2, 0), B.at(2, 0), B.at(2, 1), B.at(2, 2) -+ , C, D.at(0, 0), D.at(0, 1), D.at(0, 2) -+ ); -+ } -+ -+ /// Forms a 4-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) -+ , A.at(1, 0), A.at(1, 1), B.at(1, 0), B.at(1, 1) -+ , A.at(2, 0), A.at(2, 1), B.at(2, 0), B.at(2, 1) -+ , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) -+ ); -+ } -+ -+ /// Forms a 4-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Element D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), A.at(0, 2), B.at(0, 0) -+ , A.at(1, 0), A.at(1, 1), A.at(1, 2), B.at(1, 0) -+ , A.at(2, 0), A.at(2, 1), A.at(2, 2), B.at(2, 0) -+ , C.at(0, 0), C.at(0, 1), C.at(0, 2), D -+ ); -+ } -+ -+ /// Elementwise add operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ result.data[2] = data[2] + rhs.data[2]; -+ result.data[3] = data[3] + rhs.data[3]; -+ -+ result.data[4] = data[4] + rhs.data[4]; -+ result.data[5] = data[5] + rhs.data[5]; -+ result.data[6] = data[6] + rhs.data[6]; -+ result.data[7] = data[7] + rhs.data[7]; -+ -+ result.data[8] = data[8] + rhs.data[8]; -+ result.data[9] = data[9] + rhs.data[9]; -+ result.data[10] = data[10] + rhs.data[10]; -+ result.data[11] = data[11] + rhs.data[11]; -+ -+ result.data[12] = data[12] + rhs.data[12]; -+ result.data[13] = data[13] + rhs.data[13]; -+ result.data[14] = data[14] + rhs.data[14]; -+ result.data[15] = data[15] + rhs.data[15]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ data[2] += rhs.data[2]; -+ data[3] += rhs.data[3]; -+ -+ data[4] += rhs.data[4]; -+ data[5] += rhs.data[5]; -+ data[6] += rhs.data[6]; -+ data[7] += rhs.data[7]; -+ -+ data[8] += rhs.data[8]; -+ data[9] += rhs.data[9]; -+ data[10] += rhs.data[10]; -+ data[11] += rhs.data[11]; -+ -+ data[12] += rhs.data[12]; -+ data[13] += 
rhs.data[13]; -+ data[14] += rhs.data[14]; -+ data[15] += rhs.data[15]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ result.data[2] = data[2] - rhs.data[2]; -+ result.data[3] = data[3] - rhs.data[3]; -+ -+ result.data[4] = data[4] - rhs.data[4]; -+ result.data[5] = data[5] - rhs.data[5]; -+ result.data[6] = data[6] - rhs.data[6]; -+ result.data[7] = data[7] - rhs.data[7]; -+ -+ result.data[8] = data[8] - rhs.data[8]; -+ result.data[9] = data[9] - rhs.data[9]; -+ result.data[10] = data[10] - rhs.data[10]; -+ result.data[11] = data[11] - rhs.data[11]; -+ -+ result.data[12] = data[12] - rhs.data[12]; -+ result.data[13] = data[13] - rhs.data[13]; -+ result.data[14] = data[14] - rhs.data[14]; -+ result.data[15] = data[15] - rhs.data[15]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ data[2] -= rhs.data[2]; -+ data[3] -= rhs.data[3]; -+ -+ data[4] -= rhs.data[4]; -+ data[5] -= rhs.data[5]; -+ data[6] -= rhs.data[6]; -+ data[7] -= rhs.data[7]; -+ -+ data[8] -= rhs.data[8]; -+ data[9] -= rhs.data[9]; -+ data[10] -= rhs.data[10]; -+ data[11] -= rhs.data[11]; -+ -+ data[12] -= rhs.data[12]; -+ data[13] -= rhs.data[13]; -+ data[14] -= rhs.data[14]; -+ data[15] -= rhs.data[15]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ result.data[2] = data[2] * rhs.data[2]; -+ result.data[3] = data[3] * rhs.data[3]; -+ -+ result.data[4] = data[4] * rhs.data[4]; -+ result.data[5] = data[5] * rhs.data[5]; -+ result.data[6] = data[6] * rhs.data[6]; -+ result.data[7] = data[7] * rhs.data[7]; -+ -+ result.data[8] = data[8] * rhs.data[8]; -+ result.data[9] = data[9] * rhs.data[9]; -+ result.data[10] = data[10] * rhs.data[10]; -+ result.data[11] = data[11] * rhs.data[11]; -+ -+ result.data[12] = data[12] * rhs.data[12]; -+ result.data[13] = data[13] * rhs.data[13]; -+ result.data[14] = data[14] * rhs.data[14]; -+ result.data[15] = data[15] * rhs.data[15]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ result.data[2] = data[2] * s; -+ result.data[3] = data[3] * s; -+ -+ result.data[4] = data[4] * s; -+ result.data[5] = data[5] * s; -+ result.data[6] = data[6] * s; -+ result.data[7] = data[7] * s; -+ -+ result.data[8] = data[8] * s; -+ result.data[9] = data[9] * s; -+ result.data[10] = data[10] * s; -+ result.data[11] = data[11] * s; -+ -+ result.data[12] = data[12] * s; -+ result.data[13] = data[13] * s; -+ result.data[14] = data[14] * s; -+ result.data[15] = data[15] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator 
*=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ data[2] *= s; -+ data[3] *= s; -+ -+ data[4] *= s; -+ data[5] *= s; -+ data[6] *= s; -+ data[7] *= s; -+ -+ data[8] *= s; -+ data[9] *= s; -+ data[10] *= s; -+ data[11] *= s; -+ -+ data[12] *= s; -+ data[13] *= s; -+ data[14] *= s; -+ data[15] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ result.data[2] = data[2] / rhs.data[2]; -+ result.data[3] = data[3] / rhs.data[3]; -+ -+ result.data[4] = data[4] / rhs.data[4]; -+ result.data[5] = data[5] / rhs.data[5]; -+ result.data[6] = data[6] / rhs.data[6]; -+ result.data[7] = data[7] / rhs.data[7]; -+ -+ result.data[8] = data[8] / rhs.data[8]; -+ result.data[9] = data[9] / rhs.data[9]; -+ result.data[10] = data[10] / rhs.data[10]; -+ result.data[11] = data[11] / rhs.data[11]; -+ -+ result.data[12] = data[12] / rhs.data[12]; -+ result.data[13] = data[13] / rhs.data[13]; -+ result.data[14] = data[14] / rhs.data[14]; -+ result.data[15] = data[15] / rhs.data[15]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ result.data[2] = data[2] / s; -+ result.data[3] = data[3] / s; -+ -+ result.data[4] = data[4] / s; -+ result.data[5] = data[5] / s; -+ result.data[6] = data[6] / s; -+ result.data[7] = data[7] / s; -+ -+ result.data[8] = data[8] / s; -+ result.data[9] = data[9] / s; -+ result.data[10] = data[10] / s; -+ result.data[11] = data[11] / s; -+ -+ result.data[12] = data[12] / s; -+ result.data[13] = data[13] / s; -+ result.data[14] = data[14] / s; -+ result.data[15] = data[15] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ data[2] /= s; -+ data[3] /= s; -+ -+ data[4] /= s; -+ data[5] /= s; -+ data[6] /= s; -+ data[7] /= s; -+ -+ data[8] /= s; -+ data[9] /= s; -+ data[10] /= s; -+ data[11] /= s; -+ -+ data[12] /= s; -+ data[13] /= s; -+ data[14] /= s; -+ data[15] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ data[2] /= rhs.data[2]; -+ data[3] /= rhs.data[3]; -+ -+ data[4] /= rhs.data[4]; -+ data[5] /= rhs.data[5]; -+ data[6] /= rhs.data[6]; -+ data[7] /= rhs.data[7]; -+ -+ data[8] /= rhs.data[8]; -+ data[9] /= rhs.data[9]; -+ data[10] /= rhs.data[10]; -+ data[11] /= rhs.data[11]; -+ -+ data[12] /= rhs.data[12]; -+ data[13] /= rhs.data[13]; -+ data[14] /= rhs.data[14]; -+ data[15] /= rhs.data[15]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ m.data[3] = -m.data[3]; -+ m.data[4] = -m.data[4]; -+ m.data[5] = -m.data[5]; -+ m.data[6] = -m.data[6]; -+ m.data[7] = 
-m.data[7]; -+ m.data[8] = -m.data[8]; -+ m.data[9] = -m.data[9]; -+ m.data[10] = -m.data[10]; -+ m.data[11] = -m.data[11]; -+ m.data[12] = -m.data[12]; -+ m.data[13] = -m.data[13]; -+ m.data[14] = -m.data[14]; -+ m.data[15] = -m.data[15]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 4-by-1-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[4] * rhs.data[0]; -+ accum.data[2] += data[8] * rhs.data[0]; -+ accum.data[3] += data[12] * rhs.data[0]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[1]; -+ accum.data[1] += data[5] * rhs.data[1]; -+ accum.data[2] += data[9] * rhs.data[1]; -+ accum.data[3] += data[13] * rhs.data[1]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[2]; -+ accum.data[1] += data[6] * rhs.data[2]; -+ accum.data[2] += data[10] * rhs.data[2]; -+ accum.data[3] += data[14] * rhs.data[2]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[3]; -+ accum.data[1] += data[7] * rhs.data[3]; -+ accum.data[2] += data[11] * rhs.data[3]; -+ accum.data[3] += data[15] * rhs.data[3]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-1-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-2-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[4] * rhs.data[0]; -+ accum.data[3] += data[4] * rhs.data[1]; -+ accum.data[4] += data[8] * rhs.data[0]; -+ accum.data[5] += data[8] * rhs.data[1]; -+ accum.data[6] += data[12] * rhs.data[0]; -+ accum.data[7] += data[12] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ accum.data[2] += data[5] * rhs.data[2]; -+ accum.data[3] += data[5] * rhs.data[3]; -+ accum.data[4] += data[9] * rhs.data[2]; -+ accum.data[5] += data[9] * rhs.data[3]; -+ accum.data[6] += data[13] * rhs.data[2]; -+ accum.data[7] += data[13] * rhs.data[3]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[4]; -+ accum.data[1] += data[2] * rhs.data[5]; -+ accum.data[2] += data[6] * rhs.data[4]; -+ accum.data[3] += data[6] * rhs.data[5]; -+ accum.data[4] += data[10] * rhs.data[4]; -+ accum.data[5] += data[10] * rhs.data[5]; -+ accum.data[6] += data[14] * rhs.data[4]; -+ accum.data[7] += data[14] * rhs.data[5]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[6]; -+ accum.data[1] += data[3] * rhs.data[7]; -+ accum.data[2] += data[7] * rhs.data[6]; -+ accum.data[3] += data[7] * rhs.data[7]; -+ accum.data[4] += data[11] * rhs.data[6]; -+ accum.data[5] += data[11] * rhs.data[7]; -+ accum.data[6] += data[15] * rhs.data[6]; -+ accum.data[7] += data[15] * rhs.data[7]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-2-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-3-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[4] * rhs.data[0]; -+ accum.data[4] += data[4] * rhs.data[1]; -+ accum.data[5] += data[4] * rhs.data[2]; -+ accum.data[6] += data[8] * rhs.data[0]; -+ 
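
Each product() above is a fully unrolled small GEMM with the reduction loop outermost: every "// k=..." section is one rank-1 update of the accumulator, and operator* calls product() with a zero accumulator. The generic loop nest being unrolled, as a standalone sketch with a hypothetical product_accumulate over plain row-major arrays:

    #include <array>
    #include <cassert>

    // C (M x N) += A (M x K) * B (K x N), all row-major, with the reduction loop
    // outermost -- the same ordering as the unrolled "// k=0 ... k=3" sections above.
    template <int M, int N, int K>
    void product_accumulate(std::array<float, M * K> const &A,
                            std::array<float, K * N> const &B,
                            std::array<float, M * N> &C) {
      for (int k = 0; k < K; ++k) {                  // one rank-1 update per k
        for (int i = 0; i < M; ++i) {
          for (int j = 0; j < N; ++j) {
            C[i * N + j] += A[i * K + k] * B[k * N + j];
          }
        }
      }
    }

    int main() {
      std::array<float, 16> A{}, B{}, C{};
      for (int k = 0; k < 16; ++k) { A[k] = 1.0f; B[k] = float(k); }
      product_accumulate<4, 4, 4>(A, B, C);          // rows of ones: C(i, j) is the column sum of B
      assert(C[0] == 0 + 4 + 8 + 12);                // column 0 of B
      assert(C[5] == 1 + 5 + 9 + 13);                // column 1 of B
      return 0;
    }

Keeping k outermost lets each pass reuse one column of the left operand and one row of the right operand, which is what the unrolled code expresses with the repeated data[...] factors inside each k section.
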
accum.data[7] += data[8] * rhs.data[1]; -+ accum.data[8] += data[8] * rhs.data[2]; -+ accum.data[9] += data[12] * rhs.data[0]; -+ accum.data[10] += data[12] * rhs.data[1]; -+ accum.data[11] += data[12] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ accum.data[3] += data[5] * rhs.data[3]; -+ accum.data[4] += data[5] * rhs.data[4]; -+ accum.data[5] += data[5] * rhs.data[5]; -+ accum.data[6] += data[9] * rhs.data[3]; -+ accum.data[7] += data[9] * rhs.data[4]; -+ accum.data[8] += data[9] * rhs.data[5]; -+ accum.data[9] += data[13] * rhs.data[3]; -+ accum.data[10] += data[13] * rhs.data[4]; -+ accum.data[11] += data[13] * rhs.data[5]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[6]; -+ accum.data[1] += data[2] * rhs.data[7]; -+ accum.data[2] += data[2] * rhs.data[8]; -+ accum.data[3] += data[6] * rhs.data[6]; -+ accum.data[4] += data[6] * rhs.data[7]; -+ accum.data[5] += data[6] * rhs.data[8]; -+ accum.data[6] += data[10] * rhs.data[6]; -+ accum.data[7] += data[10] * rhs.data[7]; -+ accum.data[8] += data[10] * rhs.data[8]; -+ accum.data[9] += data[14] * rhs.data[6]; -+ accum.data[10] += data[14] * rhs.data[7]; -+ accum.data[11] += data[14] * rhs.data[8]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[9]; -+ accum.data[1] += data[3] * rhs.data[10]; -+ accum.data[2] += data[3] * rhs.data[11]; -+ accum.data[3] += data[7] * rhs.data[9]; -+ accum.data[4] += data[7] * rhs.data[10]; -+ accum.data[5] += data[7] * rhs.data[11]; -+ accum.data[6] += data[11] * rhs.data[9]; -+ accum.data[7] += data[11] * rhs.data[10]; -+ accum.data[8] += data[11] * rhs.data[11]; -+ accum.data[9] += data[15] * rhs.data[9]; -+ accum.data[10] += data[15] * rhs.data[10]; -+ accum.data[11] += data[15] * rhs.data[11]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-3-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[4] * rhs.data[0]; -+ accum.data[5] += data[4] * rhs.data[1]; -+ accum.data[6] += data[4] * rhs.data[2]; -+ accum.data[7] += data[4] * rhs.data[3]; -+ accum.data[8] += data[8] * rhs.data[0]; -+ accum.data[9] += data[8] * rhs.data[1]; -+ accum.data[10] += data[8] * rhs.data[2]; -+ accum.data[11] += data[8] * rhs.data[3]; -+ accum.data[12] += data[12] * rhs.data[0]; -+ accum.data[13] += data[12] * rhs.data[1]; -+ accum.data[14] += data[12] * rhs.data[2]; -+ accum.data[15] += data[12] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ accum.data[4] += data[5] * rhs.data[4]; -+ accum.data[5] += data[5] * rhs.data[5]; -+ accum.data[6] += data[5] * rhs.data[6]; -+ accum.data[7] += data[5] * rhs.data[7]; -+ accum.data[8] += data[9] * rhs.data[4]; -+ accum.data[9] += data[9] * rhs.data[5]; -+ accum.data[10] += data[9] * rhs.data[6]; -+ accum.data[11] += data[9] * rhs.data[7]; -+ accum.data[12] += data[13] * rhs.data[4]; -+ accum.data[13] += data[13] * rhs.data[5]; -+ accum.data[14] += data[13] * rhs.data[6]; -+ accum.data[15] += 
data[13] * rhs.data[7]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[8]; -+ accum.data[1] += data[2] * rhs.data[9]; -+ accum.data[2] += data[2] * rhs.data[10]; -+ accum.data[3] += data[2] * rhs.data[11]; -+ accum.data[4] += data[6] * rhs.data[8]; -+ accum.data[5] += data[6] * rhs.data[9]; -+ accum.data[6] += data[6] * rhs.data[10]; -+ accum.data[7] += data[6] * rhs.data[11]; -+ accum.data[8] += data[10] * rhs.data[8]; -+ accum.data[9] += data[10] * rhs.data[9]; -+ accum.data[10] += data[10] * rhs.data[10]; -+ accum.data[11] += data[10] * rhs.data[11]; -+ accum.data[12] += data[14] * rhs.data[8]; -+ accum.data[13] += data[14] * rhs.data[9]; -+ accum.data[14] += data[14] * rhs.data[10]; -+ accum.data[15] += data[14] * rhs.data[11]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[12]; -+ accum.data[1] += data[3] * rhs.data[13]; -+ accum.data[2] += data[3] * rhs.data[14]; -+ accum.data[3] += data[3] * rhs.data[15]; -+ accum.data[4] += data[7] * rhs.data[12]; -+ accum.data[5] += data[7] * rhs.data[13]; -+ accum.data[6] += data[7] * rhs.data[14]; -+ accum.data[7] += data[7] * rhs.data[15]; -+ accum.data[8] += data[11] * rhs.data[12]; -+ accum.data[9] += data[11] * rhs.data[13]; -+ accum.data[10] += data[11] * rhs.data[14]; -+ accum.data[11] += data[11] * rhs.data[15]; -+ accum.data[12] += data[15] * rhs.data[12]; -+ accum.data[13] += data[15] * rhs.data[13]; -+ accum.data[14] += data[15] * rhs.data[14]; -+ accum.data[15] += data[15] * rhs.data[15]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ accum += data[3]; -+ accum += data[4]; -+ accum += data[5]; -+ accum += data[6]; -+ accum += data[7]; -+ accum += data[8]; -+ accum += data[9]; -+ accum += data[10]; -+ accum += data[11]; -+ accum += data[12]; -+ accum += data[13]; -+ accum += data[14]; -+ accum += data[15]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ accum += data[3] * data[3]; -+ accum += data[4] * data[4]; -+ accum += data[5] * data[5]; -+ accum += data[6] * data[6]; -+ accum += data[7] * data[7]; -+ accum += data[8] * data[8]; -+ accum += data[9] * data[9]; -+ accum += data[10] * data[10]; -+ accum += data[11] * data[11]; -+ accum += data[12] * data[12]; -+ accum += data[13] * data[13]; -+ accum += data[14] * data[14]; -+ accum += data[15] * data[15]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[5]; -+ accum += data[10]; -+ accum += data[15]; -+ -+ return accum; -+ } -+ -+ /// Returns 4-by-4 rotation matrix around the X axis -+ CUTLASS_HOST_DEVICE -+ static Matrix rotation_X(Element theta) { -+ Matrix m = identity(); -+ -+ Element c = fast_cos(theta); -+ Element s = fast_sin(theta); -+ -+ m.at(1, 1) = c; -+ 
m.at(1, 2) = -s; -+ m.at(2, 1) = s; -+ m.at(2, 2) = c; -+ -+ return m; -+ } -+ -+ /// Returns 4-by-4 rotation matrix around the Y axis -+ CUTLASS_HOST_DEVICE -+ static Matrix rotation_Y(Element theta) { -+ Matrix m = identity(); -+ -+ Element c = fast_cos(theta); -+ Element s = fast_sin(theta); -+ -+ m.at(0, 0) = c; -+ m.at(2, 0) = -s; -+ m.at(0, 2) = s; -+ m.at(2, 2) = c; -+ -+ return m; -+ } -+ -+ /// Returns 4-by-4 rotation matrix around the Z axis -+ CUTLASS_HOST_DEVICE -+ static Matrix rotation_Z(Element theta) { -+ Matrix m = Matrix::identity(); -+ -+ Element c = fast_cos(theta); -+ Element s = fast_sin(theta); -+ -+ m.at(0, 0) = c; -+ m.at(0, 1) = -s; -+ m.at(1, 0) = s; -+ m.at(1, 1) = c; -+ -+ return m; -+ } -+ -+ /// Returns a 4-by-4 rotation matrix around a unit-length axis -+ CUTLASS_HOST_DEVICE -+ static Matrix rotation(Element theta, Matrix const &u) { -+ Element x = u.data[0]; -+ Element y = u.data[1]; -+ Element z = u.data[2]; -+ -+ Element c = fast_cos(theta); -+ Element s = fast_sin(theta); -+ -+ Element one_minus_cos = Element(1) - fast_cos(theta); -+ -+ Matrix m; -+ -+ m.set_slice3x3({ -+ c + x * x * one_minus_cos, x * y * one_minus_cos - z * s, x * z * one_minus_cos + y * s, -+ y * x * one_minus_cos * z * s, c + y * y * one_minus_cos, y * z * one_minus_cos - x * s, -+ z * x * one_minus_cos - y * s, z * y * one_minus_cos + x * s, c + z * z * one_minus_cos -+ }); -+ -+ return m; -+ } -+ -+ /// Returns a 4-by-4 reflection about the plane specified by the -+ /// unit-length normal vector n_unit -+ CUTLASS_HOST_DEVICE -+ static Matrix reflection(Matrix const &n_unit) { -+ -+ Element a = n_unit.data[0]; -+ Element b = n_unit.data[1]; -+ Element c = n_unit.data[2]; -+ -+ Matrix m = Matrix::identity(); -+ -+ m.set_slice3x3({ -+ Element(1) - Element(2) * a * a, Element(-2) * a * b, Element(-2) * a * c, -+ Element(-2) * a * b, Element(1) - Element(2) * b * b, Element(-2) * b * c, -+ Element(-2) * a * c, Element(-2) * b * c, Element(1) - Element(2) * c * c -+ }); -+ -+ return m; -+ } -+ -+ /// Returns a perspective projection matrix typical of OpenGL applications -+ CUTLASS_HOST_DEVICE -+ static Matrix perspective(Element near_plane, Element far_plane, Element fovH, Element fovV) { -+ Element aspect = fovH / fovV; -+ Element f = Element(cos(fovV)) / Element(fovH); -+ Element Q = near_plane - far_plane; -+ -+ return Matrix( -+ f / aspect, 0, 0, 0, -+ 0, f, 0, 0, -+ 0, 0, (near_plane + far_plane) / Q, Element(2) * far_plane * near_plane / Q, -+ 0, 0, -1, 0 -+ ); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Matrix translation(Matrix const &v) { -+ return Matrix( -+ 1, 0, 0, v.data[0], -+ 0, 1, 0, v.data[1], -+ 0, 0, 1, v.data[2], -+ 0, 0, 0, 1 -+ ); -+ } -+ -+ /// Computes the determinant of a 4-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Element determinant(Element accum = Element()) const { -+ -+ accum += at(0, 0) * Matrix({ at(1, 1), at(1, 2), at(1, 3), at(2, 1), at(2, 2), at(2, 3), at(3, 1), at(3, 2), at(3, 3) }).determinant(); -+ accum -= at(0, 1) * Matrix({ at(1, 0), at(1, 2), at(1, 3), at(2, 0), at(2, 2), at(2, 3), at(3, 0), at(3, 2), at(3, 3) }).determinant(); -+ accum += at(0, 2) * Matrix({ at(1, 0), at(1, 1), at(1, 3), at(2, 0), at(2, 1), at(2, 3), at(3, 0), at(3, 1), at(3, 3) }).determinant(); -+ accum -= at(0, 3) * Matrix({ at(1, 0), at(1, 1), at(1, 2), at(2, 0), at(2, 1), at(2, 2), at(3, 0), at(3, 1), at(3, 2) }).determinant(); -+ -+ return accum; -+ } -+ -+ /// Computes the inverse of a 4-by-4 matrix (ignores the optional argument) -+ CUTLASS_HOST_DEVICE -+ Matrix inverse(Element 
ignore = 1) const { -+ Matrix B = slice_2x2(0, 2); -+ Matrix A = slice_2x2(0, 0); -+ Matrix C = slice_2x2(2, 0); -+ Matrix D = slice_2x2(2, 2); -+ -+ Matrix D_inv = D.inverse(); -+ -+ Matrix E = (A - B * D_inv * C).inverse(); -+ -+ return Matrix::block( -+ E, -E * B * D_inv, -+ -D_inv * C * E, D_inv + D_inv * C * E * B * D_inv -+ ); -+ } -+ -+}; -+ -+/// Template alias for 4-by-4 matrix -+template -+using Matrix4x4 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix4x4 make_Matrix4x4( -+ Element _0_0, Element _0_1, Element _0_2, Element _0_3, -+ Element _1_0, Element _1_1, Element _1_2, Element _1_3, -+ Element _2_0, Element _2_1, Element _2_2, Element _2_3, -+ Element _3_0, Element _3_1, Element _3_2, Element _3_3 -+) { -+ return Matrix4x4( -+ _0_0, _0_1, _0_2, _0_3, -+ _1_0, _1_1, _1_2, _1_3, -+ _2_0, _2_1, _2_2, _2_3, -+ _3_0, _3_1, _3_2, _3_3 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Elementwise scalar multiplication -+template -+CUTLASS_HOST_DEVICE -+Matrix operator*(Element s, Matrix const &rhs) { -+ return rhs.multiply(s); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/matrix_coord.h b/3rdparty/cutlass/include/cutlass/matrix_coord.h -new file mode 100644 -index 0000000..1563575 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/matrix_coord.h -@@ -0,0 +1,164 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
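A note on the 4-by-4 inverse() just added in matrix.h above: rather than expanding cofactors, it partitions the matrix into 2-by-2 blocks A, B, C, D and inverts through the Schur complement E = (A - B D^-1 C)^-1, which keeps every step a cheap 2-by-2 operation. The host-only sketch below (plain double-precision C++, no CUTLASS types; Mat2 and inverse4x4 are illustrative names, not part of the patch) reproduces the same block formula; like the in-tree code, it assumes D and the Schur complement are invertible.

#include <cstdio>

// 2-by-2 block in row-major order: { a, b, c, d } = [a b; c d].
struct Mat2 {
  double a, b, c, d;
  Mat2 operator*(Mat2 const &r) const {
    return { a * r.a + b * r.c, a * r.b + b * r.d,
             c * r.a + d * r.c, c * r.b + d * r.d };
  }
  Mat2 operator+(Mat2 const &r) const { return { a + r.a, b + r.b, c + r.c, d + r.d }; }
  Mat2 operator-(Mat2 const &r) const { return { a - r.a, b - r.b, c - r.c, d - r.d }; }
  Mat2 operator-() const { return { -a, -b, -c, -d }; }
  Mat2 inverse() const {
    double det = a * d - b * c;
    return { d / det, -b / det, -c / det, a / det };
  }
};

// Blockwise inverse of [A B; C D] via the Schur complement E = (A - B D^-1 C)^-1:
//   [ E              -E B D^-1               ]
//   [ -D^-1 C E       D^-1 + D^-1 C E B D^-1 ]
void inverse4x4(Mat2 const &A, Mat2 const &B, Mat2 const &C, Mat2 const &D,
                Mat2 &out00, Mat2 &out01, Mat2 &out10, Mat2 &out11) {
  Mat2 D_inv = D.inverse();
  Mat2 E = (A - B * D_inv * C).inverse();
  out00 = E;
  out01 = -(E * B * D_inv);
  out10 = -(D_inv * C * E);
  out11 = D_inv + D_inv * C * E * B * D_inv;
}

int main() {
  // Example: a well-conditioned 4x4 matrix split into 2x2 blocks.
  Mat2 A{4, 1, 0, 3}, B{0, 1, 1, 0}, C{1, 0, 0, 1}, D{5, 0, 2, 6};
  Mat2 i00, i01, i10, i11;
  inverse4x4(A, B, C, D, i00, i01, i10, i11);
  std::printf("top-left block of the inverse: %f %f / %f %f\n", i00.a, i00.b, i00.c, i00.d);
  return 0;
}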
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines a canonical coordinate for rank=2 matrices offering named indices. -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// MatrixCoord wraps Coord<2, int> to provide a helper for accessing named dimensions. Classes -+/// expecting a coordinate in the rank=2 index space of a matrix should use MatrixCoord. -+struct MatrixCoord : public Coord<2, int> { -+ -+public: -+ -+ /// Integer-valued index -+ using Index = int; -+ -+ /// Base type is a Coord of rank=2 -+ using Base = Coord<2, Index>; -+ -+ /// LongIndex type -+ using LongIndex = typename Base::LongIndex; -+ -+private: -+ -+ /// Rows dimension -+ static int const kRow = 0; -+ -+ /// Columns dimension -+ static int const kColumn = 1; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ MatrixCoord() { } -+ -+ /// Constructs from Coord<2> -+ CUTLASS_HOST_DEVICE -+ MatrixCoord(Coord<2, Index> const &coord): Base(coord) { } -+ -+ /// Helper to construct from a row and column -+ CUTLASS_HOST_DEVICE -+ MatrixCoord(Index row, Index column): Base(make_Coord(row, column)) { } -+ -+ /// Helper to construct from a row and column, which are LongIndex based -+ CUTLASS_HOST_DEVICE -+ MatrixCoord(LongIndex row, LongIndex column): Base(make_Coord(Index(row), Index(column))) { } -+ -+ /// Returns the row of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & row() const { return this->at(kRow); } -+ -+ /// Returns the row of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & row() { return this->at(kRow); } -+ -+ /// Returns the column of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & column() const { return this->at(kColumn); } -+ -+ /// Returns the column of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & column() { return this->at(kColumn); } -+ -+ // -+ // Coord operators -+ // -+ -+ /// Element-wise addition -+ CUTLASS_HOST_DEVICE -+ MatrixCoord operator+(Base const& b) const { -+ return MatrixCoord(Base::operator+(b)); -+ } -+ -+ /// Element-wise subtraction -+ CUTLASS_HOST_DEVICE -+ MatrixCoord operator-(Base const& b) const { -+ return MatrixCoord(Base::operator-(b)); -+ } -+ -+ /// Element-wise multiplication -+ CUTLASS_HOST_DEVICE -+ MatrixCoord operator*(Base const& b) const { -+ return MatrixCoord(Base::operator*(b)); -+ } -+ -+ /// Element-wise division -+ CUTLASS_HOST_DEVICE -+ MatrixCoord operator/(Base const& b) const { -+ return MatrixCoord(Base::operator/(b)); -+ } -+ -+ /// In-place addition -+ CUTLASS_HOST_DEVICE -+ MatrixCoord& operator+=(Base const& b) { -+ Base::operator+=(b); -+ return *this; -+ } -+ -+ /// In-place subtraction -+ CUTLASS_HOST_DEVICE -+ MatrixCoord& operator-=(Base const& b) { -+ Base::operator-=(b); -+ return *this; -+ } -+ -+ /// In-place multiplication -+ CUTLASS_HOST_DEVICE -+ MatrixCoord& operator*=(Base const& b) { -+ Base::operator*=(b); -+ return *this; -+ } -+ -+ /// In-place division -+ CUTLASS_HOST_DEVICE -+ MatrixCoord& operator/=(Base const& b) { -+ Base::operator/=(b); -+ return *this; -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/matrix_shape.h b/3rdparty/cutlass/include/cutlass/matrix_shape.h -new file mode 
100644 -index 0000000..deae47c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/matrix_shape.h -@@ -0,0 +1,65 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines a Shape template for matrix tiles -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Describes the size of a matrix tile -+template < -+ int Row_, ///< rows of a matrix -+ int Column_ ///< columns of a matrix -+> -+struct MatrixShape { -+ static int const kRow = Row_; ///< rows of a matrix -+ static int const kColumn = Column_; ///< columns of a matrix -+ static int const kCount = Row_ * Column_; ///< total number of elements in a matrix -+ -+ // -+ // Static member functions -+ // -+ -+ CUTLASS_HOST_DEVICE -+ static Coord<2> toCoord() { -+ return make_Coord(kRow, kColumn); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/numeric_conversion.h b/3rdparty/cutlass/include/cutlass/numeric_conversion.h -new file mode 100644 -index 0000000..3095cec ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/numeric_conversion.h -@@ -0,0 +1,2481 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
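Taken together, matrix_coord.h and matrix_shape.h (added above) separate runtime coordinates from compile-time extents: MatrixCoord is a named (row, column) pair, while MatrixShape carries the tile size purely in the type. A minimal host-only sketch of that split, under stand-in names (TileShape, Coord2 and linear_offset are illustrative only, not CUTLASS API):

#include <cstdio>

// Compile-time tile extents, mirroring the MatrixShape pattern above.
template <int Row_, int Column_>
struct TileShape {
  static int const kRow = Row_;
  static int const kColumn = Column_;
  static int const kCount = Row_ * Column_;
};

// Runtime coordinate, mirroring MatrixCoord: a named (row, column) pair.
struct Coord2 {
  int row;
  int column;
};

// Row-major linear offset of a coordinate inside a tile whose extents are
// known at compile time; the extents are constants, so the compiler can fold
// the multiplication whenever the coordinate is constant too.
template <typename Shape>
constexpr int linear_offset(Coord2 coord) {
  return coord.row * Shape::kColumn + coord.column;
}

int main() {
  using Tile = TileShape<8, 16>;
  Coord2 c{3, 5};
  std::printf("offset of (3,5) in an 8x16 tile: %d of %d\n",
              linear_offset<Tile>(c), Tile::kCount);
  return 0;
}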
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Boost-like numeric conversion operator for CUTLASS numeric types -+*/ -+#pragma once -+ -+#if !defined(__CUDACC_RTC__) -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/thread/unary_op.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/half.h" -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Floating-point rounding style similare to Standard Library's formats but supporting -+/// additional rounding options. 
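The FloatRoundStyle enumeration follows directly below. On the host, the float-to-integer specializations later in this file realize round_to_nearest and round_toward_zero by switching the FPU rounding mode with std::fesetround before calling std::nearbyint; the standalone sketch here shows just that mechanism (no CUTLASS types; note that std::fesetround changes thread-global state, so the sketch restores the previous mode, which the in-tree converters do not do):

#include <cfenv>
#include <cmath>
#include <cstdio>

// Convert float to int under an explicit FPU rounding mode, mirroring the
// host-side NumericConverter specializations below. The previous rounding
// mode is restored before returning.
int convert_with_mode(float x, int mode) {
  int const saved = std::fegetround();
  std::fesetround(mode);
  int const result = static_cast<int>(std::nearbyint(x));
  std::fesetround(saved);
  return result;
}

int main() {
  std::printf("2.5 to nearest even: %d\n", convert_with_mode(2.5f, FE_TONEAREST));   // 2
  std::printf("2.5 toward zero:     %d\n", convert_with_mode(2.5f, FE_TOWARDZERO));  // 2
  std::printf("3.5 to nearest even: %d\n", convert_with_mode(3.5f, FE_TONEAREST));   // 4
  std::printf("3.5 toward zero:     %d\n", convert_with_mode(3.5f, FE_TOWARDZERO));  // 3
  return 0;
}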
-+enum class FloatRoundStyle { -+ round_indeterminate, ///< rounding mode unknown -+ round_toward_zero, ///< round toward zero -+ round_to_nearest, ///< round to nearest even -+ round_toward_infinity, ///< round toward infinity -+ round_toward_neg_infinity, ///< round toward negative infinity -+ round_half_ulp_truncate, ///< add 0.5ulp to integer representation then round toward zero -+ round_half_ulp_trunc_dntz ///< like round_half_ulp_truncate, except denorms are rounded *toward* zero -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename T, -+ typename S, -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+struct NumericConverter { -+ -+ using result_type = T; -+ using source_type = S; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ return static_cast(s); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations for float => int32_t -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(__CUDA_ARCH__) -+template <> -+struct NumericConverter { -+ -+ using result_type = int32_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ return __float2int_rn(s); -+ } -+ -+ CUTLASS_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+template <> -+struct NumericConverter { -+ -+ using result_type = int32_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ return __float2int_rz(s); -+ } -+ -+ CUTLASS_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+#elif !defined(__CUDACC_RTC__) -+ -+template <> -+struct NumericConverter { -+ -+ using result_type = int32_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; -+ -+ static result_type convert(source_type const & s) { -+ std::fesetround(FE_TONEAREST); -+ return (result_type)std::nearbyint(s); -+ } -+ -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+template <> -+struct NumericConverter { -+ -+ using result_type = int32_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; -+ -+ static result_type convert(source_type const & s) { -+ std::fesetround(FE_TOWARDZERO); -+ return (result_type)std::nearbyint(s); -+ } -+ -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations for float => int8_t -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(__CUDA_ARCH__) -+template <> -+struct NumericConverter { -+ -+ using result_type = int8_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; -+ -+ 
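  // The device path below fuses rounding and saturation into a single PTX
  // convert (cvt.rni.sat.s8.f32 here; the round-toward-zero variant uses
  // cvt.rzi.sat.s8.f32), so no explicit clamp is needed. The host-only
  // specializations further down reproduce the same saturation by clamping an
  // int32_t intermediate to the int8_t range before the final cast.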
CUTLASS_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ int32_t intermediate; -+ asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); -+ -+ return static_cast(intermediate); -+ } -+ -+ CUTLASS_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+template <> -+struct NumericConverter { -+ -+ using result_type = int8_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ int32_t intermediate; -+ asm volatile("cvt.rzi.sat.s8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); -+ -+ return static_cast(intermediate); -+ } -+ -+ CUTLASS_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+#elif !defined(__CUDACC_RTC__) -+ -+template <> -+struct NumericConverter { -+ -+ using result_type = int8_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; -+ -+ static result_type convert(source_type const & s) { -+ std::fesetround(FE_TONEAREST); -+ int32_t intermediate = (int32_t)std::nearbyint(s); -+ -+ // Low-end saturation -+ intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); -+ -+ // High-end saturation -+ intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); -+ -+ return static_cast(intermediate); -+ } -+ -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+template <> -+struct NumericConverter { -+ -+ using result_type = int8_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; -+ -+ static result_type convert(source_type const & s) { -+ std::fesetround(FE_TOWARDZERO); -+ int32_t intermediate = (int32_t)std::nearbyint(s); -+ -+ // Low-end saturation -+ intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); -+ -+ // High-end saturation -+ intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); -+ -+ return static_cast(intermediate); -+ } -+ -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for float <= half_t -+template -+struct NumericConverter { -+ -+ using result_type = T; -+ using source_type = T; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ return s; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations for float <=> half_t -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for float <= half_t -+template -+struct NumericConverter { -+ -+ using result_type = float; -+ using source_type = half_t; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ result_type result = static_cast(s); -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Specialization for round-to-nearest 
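(The half_t round-to-nearest and round-toward-zero specializations continue directly below.) The round-toward-zero variant carries a pure software float-to-half path for targets where __float2half_rz is unavailable; the host-only sketch here isolates the core of that bit manipulation for the normal range and zero only, deliberately omitting the NaN/Inf and subnormal handling that the full implementation performs:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Truncating (round-toward-zero) float -> IEEE binary16 conversion for the
// normal range only; zero is handled, but NaN/Inf and subnormal results are
// intentionally left out to keep the sketch short.
uint16_t float_to_half_rz(float f) {
  uint32_t s;
  std::memcpy(&s, &f, sizeof(s));
  uint16_t sign = uint16_t((s >> 16) & 0x8000);
  int exp = int((s >> 23) & 0xff) - 127;        // unbiased exponent
  uint32_t mantissa = s & 0x7fffff;
  if ((s & 0x7fffffff) == 0) {
    return sign;                                 // signed zero
  }
  // Re-bias for half precision (bias 15) and truncate the low 13 mantissa bits.
  uint16_t h = uint16_t(((exp + 15) & 0x1f) << 10);
  h = uint16_t(h | (mantissa >> 13));
  return uint16_t(h | sign);
}

int main() {
  std::printf("1.0f    -> 0x%04x (expect 0x3c00)\n", float_to_half_rz(1.0f));
  std::printf("-0.5f   -> 0x%04x (expect 0xb800)\n", float_to_half_rz(-0.5f));
  std::printf("65504.f -> 0x%04x (expect 0x7bff)\n", float_to_half_rz(65504.0f));
  return 0;
}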
-+template <> -+struct NumericConverter { -+ -+ using result_type = half_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ result_type result = static_cast(s); -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Specialization for round-toward-zero -+template <> -+struct NumericConverter { -+ -+ using result_type = half_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; -+ -+ /// Round toward zero -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & flt) { -+ -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return half_t(__float2half_rz(flt)); -+ #else -+ // software implementation rounds toward nearest even -+ unsigned const& s = reinterpret_cast(flt); -+ uint16_t sign = uint16_t((s >> 16) & 0x8000); -+ int16_t exp = uint16_t(((s >> 23) & 0xff) - 127); -+ int mantissa = s & 0x7fffff; -+ uint16_t u = 0; -+ -+ if ((s & 0x7fffffff) == 0) { -+ // sign-preserving zero -+ return half_t::bitcast(sign); -+ } -+ -+ if (exp > 15) { -+ if (exp == 128 && mantissa) { -+ // not a number -+ u = 0x7fff; -+ } else { -+ // overflow to infinity -+ u = sign | 0x7c00; -+ } -+ return half_t::bitcast(u); -+ } -+ -+ if (exp >= -14) { -+ // normal fp32 to normal fp16 -+ exp = uint16_t(exp + uint16_t(15)); -+ u = uint16_t(((exp & 0x1f) << 10)); -+ u = uint16_t(u | (mantissa >> 13)); -+ } else { -+ // normal single-precision to subnormal half_t-precision representation -+ int rshift = (-14 - exp); -+ if (rshift < 32) { -+ mantissa |= (1 << 23); -+ mantissa = (mantissa >> rshift); -+ u = (uint16_t(mantissa >> 13) & 0x3ff); -+ } else { -+ mantissa = 0; -+ u = 0; -+ } -+ } -+ -+ u |= sign; -+ -+ return half_t::bitcast(u); -+ -+ #endif // defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations for float <=> bfloat16_t -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for float <= bfloat16_t -+template -+struct NumericConverter { -+ -+ using result_type = float; -+ using source_type = bfloat16_t; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ return static_cast(s); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+template <> -+struct NumericConverter { -+ using result_type = bfloat16_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ return static_cast(s); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+template <> -+struct NumericConverter { -+ using result_type = bfloat16_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_half_ulp_truncate; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ uint32_t x32 = 
reinterpret_cast(s); -+ -+ #if defined(__CUDA_ARCH__) -+ if (::isfinite(s)) { -+ x32 += 0x8000; -+ } -+ #else -+ if (std::isfinite(s)) { -+ x32 += 0x8000; -+ } -+ #endif -+ -+ uint16_t x16 = uint16_t((x32 >> 16) & 0xffff); -+ return bfloat16_t::bitcast(x16); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+template <> -+struct NumericConverter { -+ using result_type = bfloat16_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ uint32_t x32 = reinterpret_cast(s); -+ uint16_t x16 = uint16_t(x32 >> 16); -+ -+ return bfloat16_t::bitcast(x16); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations for float <=> tfloat32_t -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for float <= tfloat32_t -+template -+struct NumericConverter { -+ -+ using result_type = float; -+ using source_type = tfloat32_t; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ return static_cast(s); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+template <> -+struct NumericConverter { -+ using result_type = tfloat32_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ unsigned storage = reinterpret_cast(s); -+ -+ if ((storage & 0x7f800000) != 0x7f800000) { -+ -+ bool mantissa_bit = ((storage & (1 << 13)) != 0); -+ bool round_bit = ((storage & (1 << 12)) != 0); -+ bool sticky_bit = ((storage & ((1 << 12) - 1)) != 0); -+ -+ if ((round_bit && sticky_bit) || (round_bit && mantissa_bit)) { -+ storage += uint32_t(1 << 13); -+ } -+ -+ // Note, the following is intentionally commented out. TF32 -+ // does not define the low order bits, so they may be left in -+ // an undefined state. -+ // -+ // By not truncating these bit explicitly, we avoid an extra logical -+ // operation. -+ // -+ // TF32 may be implicitly converted to float by performing this -+ // operation as needed. -+ // -+ // storage = (storage & ~0x1fff); -+ } -+ else if (storage & ~0xff800000) { -+ storage = 0x7fffffff; -+ } -+ -+ return tfloat32_t::bitcast(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+template <> -+struct NumericConverter { -+ using result_type = tfloat32_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_half_ulp_truncate; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ return tfloat32_t::round_half_ulp_truncate(s); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// This rounding operation is similar to half_ulp_truncate except it rounds denorms toward zero. -+/// It avoids predicated code, though it requires a temporary register. 
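Concretely, the trick described above builds d with the input's sign and exponent but a zero mantissa, then adds d / 2^11, which is exactly half a ULP of a 10-bit-mantissa (TF32) value at that exponent; the floating-point add carries the rounding decision into the bits TF32 keeps, and for subnormal inputs d is zero, so they fall through unrounded, i.e. toward zero. A host-only sketch with plain uint32_t bit operations; unlike the specialization below, it masks the low 13 bits explicitly at the end (the in-tree version may leave them undefined, as TF32 permits):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Round a float to TF32 precision (10 explicit mantissa bits) using the
// "half-ulp truncate, denorms toward zero" trick: add half a TF32 ulp derived
// from the input's own exponent, then drop the low 13 mantissa bits. For
// subnormal inputs the derived half-ulp is zero, so they are truncated.
float round_to_tf32_dntz(float s) {
  uint32_t bits;
  std::memcpy(&bits, &s, sizeof(bits));

  uint32_t exp_only = bits & 0xff800000u;   // same sign/exponent, zero mantissa
  float d;
  std::memcpy(&d, &exp_only, sizeof(d));

  float z = d / float(1 << 11) + s;         // carries rounding into the kept bits

  uint32_t zbits;
  std::memcpy(&zbits, &z, sizeof(zbits));
  zbits &= 0xffffe000u;                     // explicit truncation to 10 mantissa bits
  float out;
  std::memcpy(&out, &zbits, sizeof(out));
  return out;
}

int main() {
  float x = 1.00048828125f;   // 1 + 2^-11: exactly half a TF32 ulp above 1
  std::printf("%.10f -> %.10f\n", x, round_to_tf32_dntz(x));               // rounds up to 1 + 2^-10
  std::printf("%.10f -> %.10f\n", 1.0001f, round_to_tf32_dntz(1.0001f));   // truncates back to 1.0
  return 0;
}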
-+template <> -+struct NumericConverter { -+ using result_type = tfloat32_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_half_ulp_trunc_dntz; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ unsigned y = reinterpret_cast(s); -+ y = y & 0xff800000; -+ float d = reinterpret_cast(y); -+ float z = d / float(1 << 11) + s; -+ -+ return reinterpret_cast(z); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+template <> -+struct NumericConverter { -+ using result_type = tfloat32_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ uint32_t x = reinterpret_cast(s); -+ return tfloat32_t::bitcast(x & 0xffffe000); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Conversion operator for float to tfloat32_t big and small values -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ FloatRoundStyle RoundBig = FloatRoundStyle::round_toward_zero, -+ FloatRoundStyle RoundSmall = FloatRoundStyle::round_half_ulp_truncate -+> -+struct NumericConverterFastF32 { -+ -+ // result_type holds big tfloat32_t at idx(0) and small tfloat32_t at idx(1) -+ using result_type = Array; -+ -+ // source data type -+ using source_type = float; -+ -+ // rounding styles for big and small part -+ static FloatRoundStyle const kRoundBig = RoundBig; -+ static FloatRoundStyle const kRoundSmall = RoundSmall; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ result_type result; -+ NumericConverter convert_big_; -+ NumericConverter convert_small_; -+ -+ // convert and fill tfloat32_t big at idx 0 -+ result[0] = convert_big_(source); -+ -+ // convert and fill tfloat32_t small at idx 1 -+ result[1] = convert_small_(source - static_cast(result[0])); -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Conversion and Clamp operator for Integers -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename T, -+ typename S -+> -+struct NumericConverterClamp { -+ -+ using result_type = T; -+ using source_type = S; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ NumericConverter convert_op; -+ result_type const kClamp_max = platform::numeric_limits::max(); -+ result_type const kClamp_min = platform::numeric_limits::lowest(); -+ if (s < (source_type)kClamp_min) -+ return kClamp_min; -+ if (s > (source_type)kClamp_max) -+ return kClamp_max; -+ return convert_op(s); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Conversion operator for Array -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Conversion operator for Array -+template < 
-+ typename T, -+ typename S, -+ int N, -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest, -+ typename Transform = cutlass::transform::thread::UnaryTransform::Identity -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ static_assert(platform::is_same::value || -+ platform::is_same::value, -+ "Unary Operator not supported."); -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ result_type result; -+ NumericConverter convert_; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ if( platform::is_same::value ) -+ { -+ result[i] = convert_(s[i]); -+ } else { // conjugate -+ result[i] = conj(convert_(s[i])); -+ } -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+template < -+ typename T, -+ int N, -+ FloatRoundStyle Round, -+ typename Transform -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ static_assert(platform::is_same::value || -+ platform::is_same::value, -+ "Unary Operator not supported."); -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ if( platform::is_same::value ) -+ { -+ return s; -+ } else { -+ result_type result; -+ for (int i = 0; i < N; ++i) { -+ result[i] = conj(s[i]); -+ } -+ return result; -+ } -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Array <= Array, round to nearest -+template <> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ Array result; -+ -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ reinterpret_cast<__half2 &>(result) = __float22half2_rn(reinterpret_cast(source)); -+ #else -+ NumericConverter convert_; -+ result[0] = convert_(source[0]); -+ result[1] = convert_(source[1]); -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array, round to nearest -+template -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ Array result; -+ -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ reinterpret_cast(result) = __half22float2(reinterpret_cast<__half2 const &>(source)); -+ #else -+ NumericConverter convert_; -+ result[0] = convert_(source[0]); -+ result[1] = convert_(source[1]); -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Array <= Array -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ 
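  // The N-wide conversions below proceed two elements at a time through the
  // packed 2-element converter above (which maps to __float22half2_rn /
  // __half22float2 on SM53 and newer), with a single scalar conversion
  // handling the last element whenever N is odd.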
CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ NumericArrayConverter convert_vector_; -+ NumericConverter convert_element_; -+ -+ result_type result; -+ -+ Array *result_ptr = reinterpret_cast *>(&result); -+ Array const *source_ptr = reinterpret_cast const *>(&source); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = convert_vector_(source_ptr[i]); -+ } -+ -+ if (N % 2) { -+ result[N - 1] = convert_element_(source[N - 1]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+ -+/// Partial specialization for Array <= Array -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ NumericArrayConverter convert_vector_; -+ NumericConverter convert_element_; -+ -+ result_type result; -+ -+ Array *result_ptr = reinterpret_cast *>(&result); -+ Array const *source_ptr = reinterpret_cast const *>(&source); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = convert_vector_(source_ptr[i]); -+ } -+ -+ if (N % 2) { -+ result[N - 1] = convert_element_(source[N - 1]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Array <= Array, round to nearest -+template <> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ unsigned d; -+ -+ asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(d) : "f"(source[1]), "f"(source[0]) ); -+ -+ return reinterpret_cast(d); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Array <= Array -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ NumericArrayConverter convert_vector_; -+ NumericConverter convert_element_; -+ -+ result_type result; -+ -+ Array *result_ptr = reinterpret_cast *>(&result); -+ Array const *source_ptr = reinterpret_cast const *>(&source); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = convert_vector_(source_ptr[i]); -+ } -+ -+ if (N % 2) { -+ result[N - 1] = convert_element_(source[N - 1]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+#endif // if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ 
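The array converters above all share one decomposition: view the N-element array as packed 2-wide (or 4-wide) fragments, convert those through the fastest available path, and let a scalar conversion pick up any odd tail. The standalone host sketch below shows that structure with a trivial float-to-int rounding standing in for the packed hardware instruction; it copies elements into fragments for clarity, whereas the in-tree converters reinterpret_cast the arrays into packed fragments in place.

#include <array>
#include <cmath>
#include <cstdio>

// Scalar fallback: one element at a time.
struct ScalarConverter {
  int operator()(float s) const { return static_cast<int>(std::lround(s)); }
};

// "Vectorized" path: converts a fixed-size fragment of two elements. In the
// real code this is where the packed instruction (half2, bf16x2, ...) lives.
struct PairConverter {
  std::array<int, 2> operator()(std::array<float, 2> const &s) const {
    ScalarConverter scalar;
    std::array<int, 2> out{ scalar(s[0]), scalar(s[1]) };
    return out;
  }
};

// N-wide converter built the same way as the NumericArrayConverter partial
// specializations above: the bulk of the array goes through the 2-wide path,
// a single scalar conversion handles the tail when N is odd.
template <int N>
std::array<int, N> convert_array(std::array<float, N> const &source) {
  std::array<int, N> result{};
  PairConverter pair;
  ScalarConverter scalar;
  for (int i = 0; i < N / 2; ++i) {
    std::array<float, 2> frag{ source[2 * i], source[2 * i + 1] };
    std::array<int, 2> out = pair(frag);
    result[2 * i] = out[0];
    result[2 * i + 1] = out[1];
  }
  if (N % 2) {
    result[N - 1] = scalar(source[N - 1]);
  }
  return result;
}

int main() {
  std::array<float, 5> src{ 0.4f, 1.6f, 2.5f, -3.7f, 9.1f };
  std::array<int, 5> dst = convert_array<5>(src);
  for (int v : dst) std::printf("%d ", v);   // 0 2 3 -4 9 (std::lround rounds ties away from zero)
  std::printf("\n");
  return 0;
}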
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conditional guards to enable partial specialization for packed integers -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && \ -+ ((__CUDACC_VER_MAJOR__ > 10) || \ -+ ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ NumericConverter convert_element_; -+ -+ result_type result; -+ -+ result[0] = convert_element_(source[0]); -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ uint32_t tmp; -+ -+ asm volatile( -+ "cvt.pack.sat.s8.s32.b32 %0, %2, %1, 0;\n" -+ : "=r"(tmp) : "r"(source[0]), "r"(source[1])); -+ -+ uint16_t out = (tmp & 0xffff); -+ return reinterpret_cast(out); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ unsigned out; -+ -+ asm volatile( -+ "{ .reg .u32 r4;" -+ "cvt.pack.sat.s8.s32.b32 r4, %4, %3, 0;" -+ "cvt.pack.sat.s8.s32.b32 %0, %2, %1, r4;" -+ "}" -+ : "=r"(out) : "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3])); -+ -+ return reinterpret_cast(out); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ static_assert(!(N % 4), "N must be multiple of 4."); -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ NumericArrayConverter convert_vector_; -+ -+ result_type result; -+ -+ Array *result_ptr = reinterpret_cast *>(&result); -+ Array const *source_ptr = reinterpret_cast const *>(&source); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 4; ++i) { -+ result_ptr[i] = convert_vector_(source_ptr[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ NumericConverter convert_element_; -+ -+ result_type result; -+ -+ result[0] = 
convert_element_(source[0]); -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ uint32_t tmp; -+ -+ asm volatile( -+ "cvt.pack.sat.u8.s32.b32 %0, %2, %1, 0;\n" -+ : "=r"(tmp) : "r"(source[0]), "r"(source[1])); -+ -+ uint16_t out = (tmp & 0xffff); -+ return reinterpret_cast(out); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ unsigned out; -+ -+ asm volatile( -+ "{ .reg .u32 r4;" -+ "cvt.pack.sat.u8.s32.b32 r4, %4, %3, 0;" -+ "cvt.pack.sat.u8.s32.b32 %0, %2, %1, r4;" -+ "}" -+ : "=r"(out) : "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3])); -+ -+ return reinterpret_cast(out); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ static_assert(!(N % 4), "N must be multiple of 4."); -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ NumericArrayConverter convert_vector_; -+ -+ result_type result; -+ -+ Array *result_ptr = reinterpret_cast *>(&result); -+ Array const *source_ptr = reinterpret_cast const *>(&source); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 4; ++i) { -+ result_ptr[i] = convert_vector_(source_ptr[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations for Array <=> Array -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float; -+ using source_element = float_e4m3_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint32_t out_fp16[2]; -+ uint32_t const& src_packed = reinterpret_cast(source); -+ -+ asm volatile( \ -+ "{\n" \ -+ ".reg .b16 lo, hi;\n" \ -+ "mov.b32 {lo, hi}, %2;\n" \ -+ "cvt.rn.f16x2.e4m3x2 %0, lo;\n" \ -+ "cvt.rn.f16x2.e4m3x2 %1, hi;\n" \ -+ "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) : "r"(src_packed)); -+ -+ float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[0])); -+ float2 res1 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[1])); -+ 
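      // res0 and res1 now hold the four converted values as two float2 pairs;
      // the lines below unpack them into the four scalar lanes of the result.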
-+ result_type out; -+ out[0] = res0.x; -+ out[1] = res0.y; -+ out[2] = res1.x; -+ out[3] = res1.y; -+ return out; -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float_e4m3_t; -+ using source_element = float; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint32_t out; -+ -+ asm volatile( \ -+ "{\n" \ -+ ".reg .b16 lo;\n" \ -+ ".reg .b16 hi;\n" \ -+ "cvt.rn.satfinite.e4m3x2.f32 lo, %2, %1;\n" \ -+ "cvt.rn.satfinite.e4m3x2.f32 hi, %4, %3;\n" \ -+ "mov.b32 %0, {lo, hi};\n" \ -+ "}" \ -+ : "=r"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); -+ -+ return reinterpret_cast(out); -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float; -+ using source_element = float_e5m2_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint32_t out_fp16[2]; -+ uint32_t const& src_packed = reinterpret_cast(source); -+ -+ asm volatile( \ -+ "{\n" \ -+ ".reg .b16 lo, hi;\n" \ -+ "mov.b32 {lo, hi}, %2;\n" \ -+ "cvt.rn.f16x2.e5m2x2 %0, lo;\n" \ -+ "cvt.rn.f16x2.e5m2x2 %1, hi;\n" \ -+ "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) : "r"(src_packed)); -+ -+ float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[0])); -+ float2 res1 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[1])); -+ -+ result_type out; -+ out[0] = res0.x; -+ out[1] = res0.y; -+ out[2] = res1.x; -+ out[3] = res1.y; -+ return out; -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float_e5m2_t; -+ using source_element = float; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint32_t out; -+ -+ asm volatile( \ -+ "{\n" \ -+ ".reg .b16 lo;\n" \ -+ ".reg .b16 hi;\n" \ -+ "cvt.rn.satfinite.e5m2x2.f32 lo, %2, %1;\n" \ -+ "cvt.rn.satfinite.e5m2x2.f32 hi, %4, %3;\n" \ -+ "mov.b32 %0, {lo, hi};\n" \ -+ "}" \ -+ : "=r"(out) : 
"f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); -+ -+ return reinterpret_cast(out); -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations for Array <=> Array -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = half_t; -+ using source_element = float_e4m3_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint32_t out[2]; -+ uint32_t const& src_packed = reinterpret_cast(source); -+ asm volatile( \ -+ "{\n" \ -+ ".reg .b16 lo, hi;\n" \ -+ "mov.b32 {lo, hi}, %2;\n" \ -+ "cvt.rn.f16x2.e4m3x2 %0, lo;\n" \ -+ "cvt.rn.f16x2.e4m3x2 %1, hi;\n" \ -+ "}\n" : "=r"(out[0]), "=r"(out[1]) : "r"(src_packed)); -+ return reinterpret_cast(out); -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float_e4m3_t; -+ using source_element = half_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint32_t out; -+ uint32_t const* src_packed = reinterpret_cast(&source); -+ -+ asm volatile( \ -+ "{\n" \ -+ ".reg .b16 lo;\n" \ -+ ".reg .b16 hi;\n" \ -+ "cvt.rn.satfinite.e4m3x2.f16x2 lo, %1;\n" \ -+ "cvt.rn.satfinite.e4m3x2.f16x2 hi, %2;\n" \ -+ "mov.b32 %0, {lo, hi};\n" \ -+ "}" \ -+ : "=r"(out) : "r"(src_packed[0]), "r"(src_packed[1])); -+ -+ return reinterpret_cast(out); -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = half_t; -+ using source_element = float_e5m2_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint32_t out[2]; -+ uint32_t const& src_packed = reinterpret_cast(source); -+ asm volatile( \ -+ "{\n" \ -+ ".reg .b16 lo, hi;\n" \ -+ "mov.b32 {lo, hi}, %2;\n" \ -+ "cvt.rn.f16x2.e5m2x2 %0, lo;\n" \ -+ 
"cvt.rn.f16x2.e5m2x2 %1, hi;\n" \ -+ "}\n" : "=r"(out[0]), "=r"(out[1]) : "r"(src_packed)); -+ return reinterpret_cast(out); -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float_e5m2_t; -+ using source_element = half_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint32_t out; -+ uint32_t const* src_packed = reinterpret_cast(&source); -+ -+ asm volatile( \ -+ "{\n" \ -+ ".reg .b16 lo;\n" \ -+ ".reg .b16 hi;\n" \ -+ "cvt.rn.satfinite.e5m2x2.f16x2 lo, %1;\n" \ -+ "cvt.rn.satfinite.e5m2x2.f16x2 hi, %2;\n" \ -+ "mov.b32 %0, {lo, hi};\n" \ -+ "}" \ -+ : "=r"(out) : "r"(src_packed[0]), "r"(src_packed[1])); -+ -+ return reinterpret_cast(out); -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations for Array <=> Array -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = bfloat16_t; -+ using source_element = float_e4m3_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ // Convert f8 to float -+ NumericArrayConverter src2float; -+ Array tmp_floats = src2float(source); -+ -+ // Convert float to bf16 -+ result_type out; -+ Array* packed_tmp = reinterpret_cast*>(&tmp_floats); -+ Array* packed_out = reinterpret_cast*>(&out); -+ NumericArrayConverter float2result; -+ packed_out[0] = float2result(packed_tmp[0]); -+ packed_out[1] = float2result(packed_tmp[1]); -+ -+ return out; -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float_e4m3_t; -+ using source_element = bfloat16_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ // Convert bf16 to float -+ Array tmp; -+ Array* packed_tmp = reinterpret_cast*>(&tmp); -+ Array const* 
packed_source = reinterpret_cast const*>(&source); -+ NumericArrayConverter src2float; -+ packed_tmp[0] = src2float(packed_source[0]); -+ packed_tmp[1] = src2float(packed_source[1]); -+ -+ // Convert float to f8 -+ NumericArrayConverter float2result; -+ return float2result(tmp); -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = bfloat16_t; -+ using source_element = float_e5m2_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ // Convert f8 to float -+ NumericArrayConverter src2float; -+ Array tmp_floats = src2float(source); -+ -+ // Convert float to bf16 -+ result_type out; -+ Array* packed_tmp = reinterpret_cast*>(&tmp_floats); -+ Array* packed_out = reinterpret_cast*>(&out); -+ NumericArrayConverter float2result; -+ packed_out[0] = float2result(packed_tmp[0]); -+ packed_out[1] = float2result(packed_tmp[1]); -+ -+ return out; -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float_e5m2_t; -+ using source_element = bfloat16_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ // Convert bf16 to float -+ Array tmp; -+ Array* packed_tmp = reinterpret_cast*>(&tmp); -+ Array const* packed_source = reinterpret_cast const*>(&source); -+ NumericArrayConverter src2float; -+ packed_tmp[0] = src2float(packed_source[0]); -+ packed_tmp[1] = src2float(packed_source[1]); -+ -+ // Convert float to f8 -+ NumericArrayConverter float2result; -+ return float2result(tmp); -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations for Array <=> Array -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float_e4m3_t; -+ using source_element = float_e5m2_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & 
source) { -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float_e5m2_t; -+ using source_element = float_e4m3_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations for: -+// Array <=> Array -+// Array <=> Array -+// -+// These are needed to avoid multiple-matching-template compilation errors (e.g., when -+// compiling float_e4m3_t <=> float_e4m3_t, which among T <= float_e4m3_t and float_e4m3_t <= T -+// should be used?) -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float_e4m3_t; -+ using source_element = float_e4m3_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return s; -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float_e5m2_t; -+ using source_element = float_e5m2_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specialziations for: -+// Array <=> Array -+// Array <=> Array -+// using packed converter under the hood -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename T, -+ typename S, -+ int N, -+ FloatRoundStyle Round -+> -+struct PackedNumericArrayConverter { -+ using result_element = T; -+ using source_element = S; -+ -+ using result_type = Array; -+ using source_type = Array; -+ -+ static FloatRoundStyle const round_style = Round; -+ -+private: -+ using packed_result_type = Array; -+ using packed_source_type = Array; -+ -+public: -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ result_type result; -+ packed_result_type* packed_result = reinterpret_cast(&result); -+ const packed_source_type* packed_source = reinterpret_cast(&source); -+ -+ NumericArrayConverter packed_converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 4; ++i) { -+ packed_result[i] = packed_converter(packed_source[i]); -+ } -+ -+ // Handle leftovers -+ NumericConverter converter; -+ 
CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N % 4; ++i) { -+ int idx = ((N / 4) * 4) + i; -+ result[idx] = converter(source[idx]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ typename T, -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter : -+ public PackedNumericArrayConverter {}; -+ -+/// Partial specialization for Array <= Array -+template < -+ typename T, -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter : -+ public PackedNumericArrayConverter {}; -+ -+/// Partial specialization for Array <= Array -+template < -+ typename S, -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter : -+ public PackedNumericArrayConverter {}; -+ -+/// Partial specialization for Array <= Array -+template < -+ typename S, -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter : -+ public PackedNumericArrayConverter {}; -+ -+/// Partial specialization for Array <= Array -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter : -+ public PackedNumericArrayConverter {}; -+ -+/// Partial specialization for Array <= Array -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter : -+ public PackedNumericArrayConverter {}; -+ -+/// Partial specialization for Array <= Array -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter : -+ public PackedNumericArrayConverter {}; -+ -+/// Partial specialization for Array <= Array -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter : -+ public PackedNumericArrayConverter {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Array <= Array -+/// Conversion is performed with saturation regardless of setting of -+/// the `Round` template parameter. 
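The PackedNumericArrayConverter defined just above drives the 4-wide specializations over an arbitrary N: full groups of four go through the packed converter, and the N % 4 leftovers are converted element-wise. The following is a minimal standalone sketch of that split in plain host C++ (illustrative function name, with float/int8_t and a plain cast standing in for the CUTLASS element types and NumericConverter; it is not part of the patched header):

```cpp
#include <array>
#include <cstdint>

template <int N>
std::array<int8_t, N> convert_in_groups_of_4(std::array<float, N> const& src) {
  std::array<int8_t, N> dst{};
  constexpr int kFull = (N / 4) * 4;     // elements covered by whole packed groups
  for (int i = 0; i < kFull; i += 4) {   // stand-in for the 4-wide packed converter
    for (int j = 0; j < 4; ++j) {
      dst[i + j] = static_cast<int8_t>(src[i + j]);
    }
  }
  for (int i = kFull; i < N; ++i) {      // leftovers handled one element at a time
    dst[i] = static_cast<int8_t>(src[i]);
  }
  return dst;
}
```

For N = 6, for example, this runs one packed group of four followed by a two-element tail, mirroring the N / 4 loop plus the N % 4 loop in the converter above.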
-+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ // Convert float to int -+ Array temporary; -+ -+ NumericArrayConverter compute_converter; -+ temporary = compute_converter(source); -+ -+ // Convert to int to int8_t -+ NumericArrayConverter destination_converter; -+ return destination_converter(temporary); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && \ -+ ((__CUDACC_VER_MAJOR__ > 10) || \ -+ ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ unsigned out; -+ -+ asm volatile( -+ "{ .reg .u32 r4;" -+ "cvt.pack.sat.s4.s32.b32 r4, %8, %7, 0;" -+ "cvt.pack.sat.s4.s32.b32 r4, %6, %5, r4;" -+ "cvt.pack.sat.s4.s32.b32 r4, %4, %3, r4;" -+ "cvt.pack.sat.s4.s32.b32 %0, %2, %1, r4;" -+ "}" -+ : "=r"(out) -+ : "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3]), -+ "r"(source[4]), "r"(source[5]), "r"(source[6]), "r"(source[7])); -+ -+ return reinterpret_cast(out); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ static_assert(!(N % 8), "N must be multiple of 8."); -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ NumericArrayConverter convert_vector_; -+ -+ result_type result; -+ -+ Array *result_ptr = reinterpret_cast *>(&result); -+ Array const *source_ptr = reinterpret_cast const *>(&source); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 8; ++i) { -+ result_ptr[i] = convert_vector_(source_ptr[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ unsigned out; -+ -+ asm volatile( -+ "{ .reg .u32 r4;" -+ "cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;" -+ "cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;" -+ "cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;" -+ "cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;" -+ "}" -+ : "=r"(out) -+ : "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3]), -+ "r"(source[4]), "r"(source[5]), "r"(source[6]), "r"(source[7])); -+ -+ return reinterpret_cast(out); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial 
specialization for Array <= Array -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ static_assert(!(N % 8), "N must be multiple of 8."); -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ NumericArrayConverter convert_vector_; -+ -+ result_type result; -+ -+ Array *result_ptr = reinterpret_cast *>(&result); -+ Array const *source_ptr = reinterpret_cast const *>(&source); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 8; ++i) { -+ result_ptr[i] = convert_vector_(source_ptr[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+#endif // Conditional guards to enable partial specialization for packed integers -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// FastNumericArrayConverter only works when the source is within center range. -+/// Conversion operator for Array. See the comments before -+/// FastLinearCombinationClamp. -+template -+struct FastNumericArrayConverter { -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const &s) { -+ result_type result; -+ NumericArrayConverter convert_; -+ -+ return convert_(s); -+ } -+ -+ CUTLASS_DEVICE -+ result_type operator()(source_type const &s) { return convert(s); } -+}; -+ -+/// Partial specialization for Array <= Array -+template -+struct FastNumericArrayConverter { -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const &source) { -+ result_type result; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ int tmp = source[i] + 1262485504 /*0x4B400000*/; -+ result[i] = reinterpret_cast(tmp) - 12582912.0f; -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_DEVICE -+ result_type operator()(source_type const &s) { return convert(s); } -+}; -+ -+/// Partial specialization for Array <= Array -+template -+struct FastNumericArrayConverter { -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const &source) { -+ Array result; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ float tmp = source[i] + 12582912.0f; -+ result[i] = reinterpret_cast(tmp); -+ } -+ -+ result[0] = __byte_perm(result[0], result[1], 0x40); -+ result[2] = __byte_perm(result[2], result[3], 0x40); -+ result[0] = __byte_perm(result[0], result[2], 0x5410); -+ -+ return reinterpret_cast(result[0]); -+ } -+ -+ CUTLASS_DEVICE -+ result_type operator()(source_type const &s) { return convert(s); } -+}; -+ -+/// Partial specialization for Array <= Array -+template -+struct FastNumericArrayConverter { -+ static_assert(!(N % 4), "N must be multiple of 4."); -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const &source) { -+ FastNumericArrayConverter convert_vector_; -+ -+ result_type result; -+ -+ Array *result_ptr = -+ reinterpret_cast *>(&result); -+ Array const *source_ptr = -+ reinterpret_cast const *>(&source); -+ -+ 
CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 4; ++i) { -+ result_ptr[i] = convert_vector_(source_ptr[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_DEVICE -+ result_type operator()(source_type const &s) { return convert(s); } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines preferred rounding mode for a pair of types -+template -+struct PreferredRoundingMode { -+ static FloatRoundStyle const kRound = FloatRoundStyle::round_to_nearest; -+}; -+ -+/// Defines preferred rounding mode for a pair of types -+template <> -+struct PreferredRoundingMode { -+ static FloatRoundStyle const kRound = FloatRoundStyle::round_half_ulp_truncate; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Packs predicates into an array. -+template -+struct PackPredicates { -+ using result_type = Array; -+ -+ static_assert(!(N % 4), "Must pack predicates in a count that is a multiple of 4"); -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(bool const predicates[]) { -+ -+ result_type packed; -+ packed.clear(); -+ -+ int const kWordSize = 8; -+ uint8_t *bytes = reinterpret_cast(packed.data()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ int word_idx = (i / kWordSize); -+ int bit_idx = (i % kWordSize); -+ -+ uint8_t mask = ((predicates[i] ? 1u : 0u) << bit_idx); -+ bytes[word_idx] = (bytes[word_idx] | mask); -+ } -+ return packed; -+ } -+}; -+ -+/// Packs predicates into an array -+template -+struct UnpackPredicates { -+ using result_type = Array; -+ -+ static_assert(!(N % 4), "Must unpack predicates in a count that is a multiple of 4"); -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(bool predicates[], result_type const &packed) { -+ -+ int const kWordSize = 8; -+ uint8_t const *bytes = reinterpret_cast(packed.data()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ int word_idx = (i / kWordSize); -+ int bit_idx = (i % kWordSize); -+ -+ predicates[i] = bool((bytes[word_idx] >> bit_idx) & 0x1); -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/numeric_types.h b/3rdparty/cutlass/include/cutlass/numeric_types.h -new file mode 100644 -index 0000000..55555ec ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/numeric_types.h -@@ -0,0 +1,94 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
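PackPredicates and UnpackPredicates above store one predicate per bit, addressing the target byte with i / 8 and the bit within it with i % 8. Here is a standalone host-side sketch of the same round trip, using std::vector instead of Array and illustrative helper names (a sketch for clarity, not code taken from the patch):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<uint8_t> pack_predicates(const std::vector<bool>& preds) {
  std::vector<uint8_t> bytes((preds.size() + 7) / 8, 0);
  for (std::size_t i = 0; i < preds.size(); ++i) {
    std::size_t word_idx = i / 8;   // which byte holds this predicate
    std::size_t bit_idx  = i % 8;   // which bit inside that byte
    bytes[word_idx] |= static_cast<uint8_t>((preds[i] ? 1u : 0u) << bit_idx);
  }
  return bytes;
}

std::vector<bool> unpack_predicates(const std::vector<uint8_t>& bytes, std::size_t n) {
  std::vector<bool> preds(n);
  for (std::size_t i = 0; i < n; ++i) {
    preds[i] = (bytes[i / 8] >> (i % 8)) & 0x1;   // exact reverse of the packing step
  }
  return preds;
}
```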
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Top-level include for all CUTLASS numeric types. -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the size of an element in bits -+template -+struct sizeof_bits { -+ static int const value = int(sizeof(T) * 8); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Definitions for 1-bit binary and 4-bit integer types -+// -+ -+/// 1-bit binary type -+using bin1_t = bool; -+ -+/// Defines the size of an element in bits - specialized for bin1_t -+template <> -+struct sizeof_bits { -+ static int const value = 1; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct index_sequence; -+ -+template -+struct index_sequence_helper : index_sequence_helper {}; -+ -+template -+struct index_sequence_helper<0, 0, Next...> { -+ using type = index_sequence<0, Next...>; -+}; -+ -+template -+using make_index_sequence = typename index_sequence_helper::type; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "cutlass/integer_subbyte.h" -+ -+#include "cutlass/half.h" -+#include "cutlass/bfloat16.h" -+#include "cutlass/tfloat32.h" -+#include "cutlass/float8.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/pipeline.hpp b/3rdparty/cutlass/include/cutlass/pipeline.hpp -new file mode 100644 -index 0000000..67538ae ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/pipeline.hpp -@@ -0,0 +1,529 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2011-2019, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Redistribution and use in source and binary forms, with or without modification, are not permit- -+ * ted. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR -+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND -+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
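numeric_types.h above defines sizeof_bits as sizeof(T) * 8 by default and then specializes it for sub-byte types such as the 1-bit bin1_t. A simplified sketch of that trait pattern follows, with a hypothetical one_bit_tag standing in for bin1_t (it restates the idea only; it is not the patched header):

```cpp
#include <cstdint>

template <typename T>
struct bits_of {
  static constexpr int value = static_cast<int>(sizeof(T) * 8);  // default: full-width type
};

struct one_bit_tag {};  // hypothetical stand-in for a sub-byte type like bin1_t

template <>
struct bits_of<one_bit_tag> {
  static constexpr int value = 1;  // report the logical width, not the storage size
};

static_assert(bits_of<uint32_t>::value == 32, "full-width type");
static_assert(bits_of<one_bit_tag>::value == 1, "sub-byte specialization");
```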
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; -+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cute/numeric/integral_constant.hpp" -+#include "cute/arch/cluster_sm90.hpp" -+#include "cutlass/arch/barrier.h" -+#include "cutlass/gemm/dispatch_policy.hpp" -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+using namespace arch; -+using namespace cute; -+ -+// Circular Buffer Index + Associated Phase -+// Assumes only one operation possible - i.e., ++ -+template -+struct PipelineState { -+ -+ static constexpr uint32_t Stages = Stages_; -+ -+private: -+ int index_ = 0; -+ uint32_t phase_ = 0; -+ -+public: -+ CUTLASS_DEVICE -+ PipelineState(): index_{}, phase_{} {} -+ -+ CUTLASS_DEVICE -+ PipelineState(int index, uint32_t phase) -+ : index_(index) -+ , phase_(phase){} -+ -+ CUTLASS_DEVICE -+ int index() const { -+ return index_; -+ } -+ -+ CUTLASS_DEVICE -+ uint32_t phase() const { -+ return phase_; -+ } -+ -+ CUTLASS_DEVICE -+ void operator++() { -+ ++index_; -+ if (index_ == Stages) { -+ index_ = 0; -+ phase_ ^= 1; -+ } -+ } -+ -+ CUTLASS_DEVICE -+ PipelineState& operator=(const PipelineState& other) { -+ index_ = other.index(); -+ phase_ = other.phase(); -+ return *this; -+ } -+ -+ CUTLASS_DEVICE -+ PipelineState advance(uint32_t num_iterations) { -+ // Number of iterations cross over the stage boundary => flipped phase -+ if ((num_iterations < Stages) && (index_ + num_iterations) >= Stages ) { -+ phase_ ^= 1; -+ } -+ // How many times number of iterations cross over the stage boundary and -+ // end up on a odd number => flipped phase -+ if ((num_iterations >= Stages) && (((index_ + num_iterations) / Stages) % 2) == 1) { -+ phase_ ^= 1; -+ } -+ index_ = (index_ + num_iterations) % Stages; -+ return *this; -+ } -+ -+ CUTLASS_DEVICE -+ static PipelineState make_pipeline_state(PipelineState start_state, uint32_t num_iterations) { -+ return start_state.advance(num_iterations); -+ } -+}; -+ -+template -+CUTLASS_DEVICE -+PipelineState make_producer_start_state() -+{ -+ // Producer starts with an opposite phase as the buffer are initially empty -+ constexpr int InitialProducerStage = 0; -+ constexpr uint32_t InitialProducerPhase = 1; -+ return {InitialProducerStage, InitialProducerPhase}; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// TMA (producer) Async Pipeline class -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// Assumptions : Constructor is Visible Cluster-wide (as it needs a Cluster-Sync) -+// We have exactly one thread elected in the Producer as the "leader" -+// Currently, it is optional to elect a leader for the Consumers -+template -+class PipelineTmaAsync { -+public : -+ using ClusterShape = ClusterShape_; -+ using FullBarrier = ClusterTransactionBarrier; -+ using EmptyBarrier = ClusterBarrier; -+ using ValueType = 
FullBarrier::ValueType; -+ static constexpr uint32_t Stages = Stages_; -+ -+ struct SharedStorage { -+ FullBarrier full_barrier_[Stages]; -+ EmptyBarrier empty_barrier_[Stages]; -+ }; -+ -+ enum class ThreadCategory { -+ NonParticipant, -+ Producer, -+ Consumer, -+ ProducerConsumer -+ }; -+ -+ struct Params { -+ uint32_t transaction_bytes = 0; -+ ThreadCategory role = ThreadCategory::NonParticipant; -+ uint32_t is_leader = 0; -+ uint32_t num_consumers = 0; -+ }; -+ -+private : -+ // -+ // Data Members -+ // -+ uint32_t dst_blockid_ = 0; -+ uint32_t is_signalling_thread_ = 0; -+ FullBarrier *full_barrier_ptr_ = nullptr; -+ EmptyBarrier *empty_barrier_ptr_ = nullptr; -+ Params params_; -+ -+ // -+ // Methods -+ // -+ -+public: -+ // Constructor -+ CUTLASS_DEVICE -+ PipelineTmaAsync(SharedStorage& storage, Params params) -+ : params_(params) -+ , full_barrier_ptr_(&storage.full_barrier_[0]) -+ , empty_barrier_ptr_(&storage.empty_barrier_[0]) { -+ -+ int warp_idx = canonical_warp_idx(); -+ int lane_predicate = cute::elect_one_sync(); -+ auto cluster_shape = ClusterShape{}; -+ -+ if (warp_idx == 0 && lane_predicate == 1) { -+ // Barrier FULL init -+ for (int i = 0; i < Stages; ++i) { -+ full_barrier_ptr_[i].init(1); -+ } -+ -+ // Barrier EMPTY init -+ uint32_t const num_consumers = cute::size<0>(cluster_shape) + cute::size<1>(cluster_shape) - 1; -+ for (int i = 0; i < Stages; ++i) { -+ empty_barrier_ptr_[i].init(num_consumers); -+ } -+ } -+ -+ // Logic to optimally schedule Empty Arrives -+ // Goal : To divide SYNCS Empty Arrival duty equally amongst the Warp-Group (128 threads) -+ dim3 block_id = block_id_in_cluster(); -+ auto cluster_size = cute::size(cluster_shape); -+ static constexpr int MaxClusterSize = 16; -+ static_assert(cluster_size <= MaxClusterSize, "ERROR : Cluster size too large !" ); -+ -+ // STEP 1 : Use Cute Layout function to generate an optimal dst block-id (0-15) -+ if (params_.num_consumers == 128) { -+ int thread_idx = threadIdx.x % 128; -+ is_signalling_thread_ = (thread_idx % (128 / MaxClusterSize)) == 0; -+ auto layout = cute::composition(Swizzle<2,0,-2>{}, -+ Layout,Stride<_4, _1>>{}); -+ uint32_t thread_row = warp_idx % 4; -+ uint32_t thread_col = (thread_idx / 8) % 4; -+ dst_blockid_ = layout(thread_row, thread_col); -+ } -+ else if (params_.num_consumers == 32){ -+ int thread_idx = threadIdx.x % 32; -+ is_signalling_thread_ = (thread_idx % (32 / MaxClusterSize)) == 0; -+ auto layout = Layout,Stride<_4, _1>>{}; -+ uint32_t thread_row = thread_idx / 8; -+ uint32_t thread_col = (thread_idx % 8) / 2; -+ dst_blockid_ = layout(thread_row, thread_col); -+ } -+ else { -+ is_signalling_thread_ = 0; -+ } -+ -+ // STEP 2: Find if this dst block-id needs an arrival for this problem -+ is_signalling_thread_ &= dst_blockid_ < cluster_size; -+ is_signalling_thread_ &= is_same_row_or_col(dst_blockid_, block_id, cluster_shape); -+ -+ cutlass::arch::fence_barrier_init(); -+ } -+ -+ CUTLASS_DEVICE -+ void producer_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait = false) { -+ // 1. Wait for empty barrier to be ready -+ // 2. 
Set the transaction bytes set to occur on the Full barrier -+ uint32_t done = empty_barrier_ptr_[stage].test_wait(phase, (!skip_wait)); -+ if ((!done) && (!skip_wait)){ -+ empty_barrier_ptr_[stage].wait(phase); -+ } -+ -+ if (params_.is_leader) { -+ full_barrier_ptr_[stage].arrive_and_reset_bytes(params_.transaction_bytes); -+ } -+ -+ } -+ -+ CUTLASS_DEVICE -+ void producer_acquire(PipelineState state) { -+ producer_acquire(state.index(), state.phase()); -+ } -+ -+ // NOP for TMA based mainloop -+ CUTLASS_DEVICE -+ void producer_commit(uint32_t stage, uint32_t bytes) { -+ // Below code is used only for unit-testing (in the absennce of TMA commit) -+ #if CUTLASS_UNIT_TEST_PIPELINE -+ if (params_.is_leader) { -+ // STEP 1 : Commit to self -+ full_barrier_ptr_[stage].commit(bytes); -+ -+ // STEP 2 : Commit to other blocks in our cluster -+ auto cluster_shape = ClusterShape{}; -+ Layout block_layout_in_cluster = make_layout(cluster_shape); -+ dim3 local_block_id = cute::block_id_in_cluster(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int n = 0; n < size<1>(block_layout_in_cluster); ++n) { -+ uint32_t dst_block_id = block_layout_in_cluster(local_block_id.x,n,Int<0>{}); -+ full_barrier_ptr_[stage].commit(dst_block_id, bytes, n!=local_block_id.y); -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int m = 0; m < size<0>(block_layout_in_cluster); ++m) { -+ uint32_t dst_block_id = block_layout_in_cluster(m,local_block_id.y,Int<0>{}); -+ full_barrier_ptr_[stage].commit(dst_block_id, bytes, m!=local_block_id.x); -+ } -+ } -+ #endif -+ } -+ -+ CUTLASS_DEVICE -+ void producer_commit(PipelineState state, uint32_t bytes) { -+ producer_commit(state.index(), bytes); -+ } -+ -+ -+ // Wait for producer to commit transactions (done by TMA) -+ CUTLASS_DEVICE -+ void consumer_wait(uint32_t stage, uint32_t phase) { -+ uint32_t done = full_barrier_ptr_[stage].test_wait(phase); -+ if (!done){ -+ full_barrier_ptr_[stage].wait(phase); -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void consumer_wait(PipelineState state) { -+ consumer_wait(state.index(), state.phase()); -+ } -+ -+ // Consumer signalling Producer of completion -+ // Ensures all blocks in the Same Row and Column get notifed. 
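The producer/consumer handshake above keys every wait on a (stage index, phase) pair maintained by PipelineState: the index walks the circular buffer and the phase bit flips each time the index wraps past Stages. A tiny host-side model of that wrap-and-flip rule, with illustrative names and no CUDA annotations (a sketch, not the patched class):

```cpp
#include <cassert>
#include <cstdint>

template <int Stages>
struct TinyPipelineState {
  int index = 0;
  uint32_t phase = 0;

  void advance_one() {
    if (++index == Stages) {  // wrapping past the last stage flips the phase bit
      index = 0;
      phase ^= 1;
    }
  }
};

int main() {
  TinyPipelineState<3> s;
  for (int i = 0; i < 3; ++i) s.advance_one();
  assert(s.index == 0 && s.phase == 1);  // one full trip through the stages flips the phase
  for (int i = 0; i < 3; ++i) s.advance_one();
  assert(s.index == 0 && s.phase == 0);  // a second trip flips it back
  return 0;
}
```

This is why the producer starts with the opposite phase: the buffers begin empty, so its first acquire must not wait on a flip that has not happened yet.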
-+ CUTLASS_DEVICE -+ void consumer_release(uint32_t stage, uint32_t skip = false) { -+ empty_barrier_ptr_[stage].arrive(dst_blockid_, is_signalling_thread_ & (!skip)); -+ } -+ -+ CUTLASS_DEVICE -+ void consumer_release(PipelineState state) { -+ consumer_release(state.index()); -+ } -+ -+ CUTLASS_DEVICE -+ ValueType* producer_get_barrier(uint32_t stage) { -+ return reinterpret_cast(&full_barrier_ptr_[stage]); -+ } -+ -+ CUTLASS_DEVICE -+ bool is_same_row_or_col(int dst_block_id, dim3 block_id, ClusterShape cluster_shape) { -+ return ((dst_block_id % cute::size<0>(cluster_shape)) == block_id.x || -+ (dst_block_id / cute::size<0>(cluster_shape)) == block_id.y); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Simple producer-consumer async Pipeline class -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// *Count Signifies the number of producers / consumers who will announce their completion -+ -+template -+class PipelineAsync { -+public : -+ using FullBarrier = ClusterBarrier; -+ using EmptyBarrier = ClusterBarrier; -+ using ProducerBarrierType = FullBarrier::ValueType; -+ static constexpr uint32_t Stages = Stages_; -+ -+ struct SharedStorage { -+ FullBarrier full_barrier_[Stages]; -+ EmptyBarrier empty_barrier_[Stages]; -+ }; -+ -+ enum class ThreadCategory { -+ NonParticipant, -+ Producer, -+ Consumer, -+ ProducerConsumer -+ }; -+ -+ struct Params { -+ ThreadCategory role = ThreadCategory::NonParticipant; -+ uint32_t producer_arv_count = 1; -+ uint32_t consumer_arv_count = 1; -+ uint32_t dst_blockid = cute::block_rank_in_cluster(); -+ }; -+ -+private: -+ // -+ // Data Members -+ // -+ Params params_; -+ FullBarrier *full_barrier_ptr_; -+ EmptyBarrier *empty_barrier_ptr_; -+ -+public: -+ -+ // Default assumption when only storage is passed is : -+ // => single producer, single consumer & they are in the same block (within the Cluster) -+ CUTLASS_DEVICE -+ PipelineAsync(SharedStorage& storage) -+ : PipelineAsync(storage, {}) {} -+ -+ CUTLASS_DEVICE -+ PipelineAsync( -+ SharedStorage& storage, -+ Params const& params) : -+ params_(params), -+ full_barrier_ptr_(&storage.full_barrier_[0]), -+ empty_barrier_ptr_(&storage.empty_barrier_[0]) { -+ -+ int warp_idx = canonical_warp_idx(); -+ int lane_predicate = cute::elect_one_sync(); -+ -+ // Barrier FULL, EMPTY init -+ // Init is done only by thread 0 of the block -+ if (warp_idx == 0 && lane_predicate == 1) { -+ for (int i = 0; i < Stages; ++i) { -+ full_barrier_ptr_[i].init(params.producer_arv_count); -+ empty_barrier_ptr_[i].init(params.consumer_arv_count); -+ } -+ } -+ -+ cutlass::arch::fence_barrier_init(); -+ } -+ -+ CUTLASS_DEVICE -+ void producer_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait = false) { -+ uint32_t done = empty_barrier_ptr_[stage].test_wait(phase, (!skip_wait)); -+ if ((!done) && (!skip_wait)){ -+ empty_barrier_ptr_[stage].wait(phase); -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void producer_acquire(PipelineState state) { -+ producer_acquire(state.index(), state.phase()); -+ } -+ -+ CUTLASS_DEVICE -+ void producer_commit(uint32_t stage) { -+ full_barrier_ptr_[stage].arrive(); -+ } -+ -+ CUTLASS_DEVICE -+ void producer_commit(PipelineState state) { -+ producer_commit(state.index()); -+ } -+ -+ CUTLASS_DEVICE -+ void consumer_wait(uint32_t stage, uint32_t phase) { -+ uint32_t done = full_barrier_ptr_[stage].test_wait(phase); -+ if (!done){ -+ full_barrier_ptr_[stage].wait(phase); -+ } 
-+ } -+ -+ CUTLASS_DEVICE -+ void consumer_wait(PipelineState state) { -+ consumer_wait(state.index(), state.phase()); -+ } -+ -+ CUTLASS_DEVICE -+ void consumer_release(uint32_t stage, uint32_t skip = false) { -+ empty_barrier_ptr_[stage].arrive(params_.dst_blockid, (not skip)); -+ } -+ -+ CUTLASS_DEVICE -+ void consumer_release(PipelineState state) { -+ consumer_release(state.index()); -+ } -+ -+ CUTLASS_DEVICE -+ ProducerBarrierType* get_producer_barrier(uint32_t stage) { -+ return reinterpret_cast(&full_barrier_ptr_[stage]); -+ } -+}; -+ -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Barrier to ensure an Ordered Sequence between -+// SequenceLength number of groups (each with group_size participants) executing SequenceDepth Stages -+// i.e., for all i < j - only after id "i" arrives at a particular stage "m" -+// will the wait() for id "j" succeed for the same stage -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class OrderedSequenceBarrier { -+public : -+ using Barrier = ClusterBarrier; -+ -+ struct SharedStorage { -+ Barrier barrier_[SequenceDepth][SequenceLength]; -+ }; -+ -+ struct Params { -+ uint32_t group_id; -+ uint32_t group_size; -+ }; -+ -+private : -+ // -+ // Data Members -+ // -+ -+ // In future this Params object can be replaced easily with a CG object -+ Params params_; -+ Barrier *barrier_ptr_; -+ PipelineState stage_; -+ -+ static constexpr int Depth = SequenceDepth; -+ static constexpr int Length = SequenceLength; -+ -+public: -+ OrderedSequenceBarrier() = delete; -+ OrderedSequenceBarrier(const OrderedSequenceBarrier&) = delete; -+ OrderedSequenceBarrier(OrderedSequenceBarrier&&) = delete; -+ OrderedSequenceBarrier& operator=(const OrderedSequenceBarrier&) = delete; -+ OrderedSequenceBarrier& operator=(OrderedSequenceBarrier&&) = delete; -+ ~OrderedSequenceBarrier() = default; -+ -+ CUTLASS_DEVICE -+ OrderedSequenceBarrier(SharedStorage& storage, Params const& params) : -+ params_(params), -+ barrier_ptr_(&storage.barrier_[0][0]), -+ // Group 0 - starts with an opposite phase -+ stage_({0, params.group_id == 0}) { -+ -+ int warp_idx = canonical_warp_idx(); -+ int lane_predicate = cute::elect_one_sync(); -+ -+ // Barrier FULL, EMPTY init -+ // Init is done only by the one elected thread of the block -+ if (warp_idx == 0 && lane_predicate == 1) { -+ for (int d = 0; d < Depth; ++d) { -+ for (int l = 0; l < Length; ++l) { -+ barrier_ptr_[d * Length + l].init(params.group_size); -+ } -+ } -+ } -+ -+ cutlass::arch::fence_barrier_init(); -+ } -+ -+ // Wait on a stage to be unlocked -+ CUTLASS_DEVICE -+ void wait() { -+ get_barrier_for_current_stage(params_.group_id).wait(stage_.phase()); -+ } -+ -+ // Signal completion of Stage and move to the next stage -+ // (group_id) signals to (group_id+1) -+ CUTLASS_DEVICE -+ void arrive() { -+ int signalling_id = (params_.group_id + 1) % Length; -+ get_barrier_for_current_stage(signalling_id).arrive(); -+ ++stage_; -+ } -+ -+private: -+ -+ CUTLASS_DEVICE -+ Barrier& get_barrier_for_current_stage(int group_id) { -+ return barrier_ptr_[stage_.index() * Length + group_id]; -+ } -+}; -+ -+} // end namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/pitch_linear_coord.h b/3rdparty/cutlass/include/cutlass/pitch_linear_coord.h -new file mode 100644 -index 0000000..2cd7bfe ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/pitch_linear_coord.h -@@ -0,0 +1,181 @@ 
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines layout functions used by TensorRef and derived classes for pitch-linear memory. 
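PitchLinearCoord, defined in the header below, carries a (contiguous, strided) pair; layouts built on it map that pair to a flat offset as contiguous + strided * pitch. A standalone sketch of that mapping, assuming a plain element-count pitch and using illustrative struct and function names (not code from the header):

```cpp
#include <cassert>

struct PitchLinear2D {
  int contiguous;  // fastest-varying dimension
  int strided;     // slower dimension, advances by one pitch per step
};

// Flat offset given the pitch (elements per strided step).
constexpr long long flat_offset(PitchLinear2D c, long long pitch) {
  return c.contiguous + static_cast<long long>(c.strided) * pitch;
}

int main() {
  // With a pitch of 128, element (5, 2) lives two rows of 128 plus 5 elements in.
  assert(flat_offset({5, 2}, 128) == 2 * 128 + 5);
  return 0;
}
```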
-+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template defining a shape used by pitch-linear operators -+template < -+ int Contiguous, -+ int Strided -+> -+struct PitchLinearShape { -+ static int const kContiguous = Contiguous; -+ static int const kStrided = Strided; -+ static int const kCount = Contiguous * Strided; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Coordinate in pitch-linear space -+struct PitchLinearCoord : public Coord<2, int> { -+public: -+ -+ /// Integer-valued index -+ using Index = int; -+ -+ /// Base type is a Coord of rank=2 -+ using Base = Coord<2, Index>; -+ -+ /// Long integer type -+ using LongIndex = typename Base::LongIndex; -+ -+private: -+ -+ /// Rows dimension -+ static int const kContiguous = 0; -+ -+ /// Columns dimension -+ static int const kStrided = 1; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord() { } -+ -+ /// Constructs from Coord<2> -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord(Coord<2, Index> const &coord): Base(coord) { } -+ -+ /// Helper to construct from a row and column -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord(Index contiguous_, Index strided_): Base(make_Coord(contiguous_, strided_)) { } -+ -+ /// Helper to construct from a row and column based on LongIndex -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord(LongIndex contiguous_, LongIndex strided_) -+ : Base(make_Coord(Index(contiguous_), Index(strided_))) { } -+ -+ /// Returns the contiguous dimension -+ CUTLASS_HOST_DEVICE -+ Index const & contiguous() const { return this->at(kContiguous); } -+ -+ /// Returns the contiguous dimension -+ CUTLASS_HOST_DEVICE -+ Index & contiguous() { return this->at(kContiguous); } -+ -+ /// Returns the column of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & strided() const { return this->at(kStrided); } -+ -+ /// Returns the column of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & strided() { return this->at(kStrided); } -+ -+ // -+ // Coord operators -+ // -+ -+ /// Element-wise addition -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord operator+(Base const& b) const { -+ return PitchLinearCoord(Base::operator+(b)); -+ } -+ -+ /// Element-wise subtraction -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord operator-(Base const& b) const { -+ return PitchLinearCoord(Base::operator-(b)); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord operator-() const { -+ return PitchLinearCoord(-at(0), -at(1)); -+ } -+ -+ /// Element-wise multiplication -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord operator*(Base const& b) const { -+ return PitchLinearCoord(Base::operator*(b)); -+ } -+ -+ /// Element-wise division -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord operator/(Base const& b) const { -+ return PitchLinearCoord(Base::operator/(b)); -+ } -+ -+ /// In-place addition -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord& operator+=(Base const& b) { -+ Base::operator+=(b); -+ return *this; -+ } -+ -+ /// In-place subtraction -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord& operator-=(Base const& b) { -+ Base::operator-=(b); -+ return *this; -+ } -+ -+ /// In-place multiplication -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord& operator*=(Base const& b) { -+ Base::operator*=(b); -+ return *this; -+ } -+ -+ /// In-place division -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord& operator/=(Base const& b) { -+ 
Base::operator/=(b); -+ return *this; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/platform/platform.h b/3rdparty/cutlass/include/cutlass/platform/platform.h -new file mode 100644 -index 0000000..96bb8f6 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/platform/platform.h -@@ -0,0 +1,891 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+/** -+ * \file -+ * \brief C++ features that may be otherwise unimplemented for CUDA device functions. -+ * -+ * This file has three components: -+ * -+ * (1) Macros: -+ * - Empty macro defines for C++ keywords not supported by the current -+ * version of C++. These simply allow compilation to proceed (but do -+ * not provide the added semantics). -+ * - \p noexcept -+ * - \p constexpr -+ * - \p nullptr -+ * - \p static_assert -+ * -+ * - Macro functions that we need in constant expressions because the -+ * C++ equivalents require constexpr compiler support. These are -+ * prefixed with \p __NV_STD_* -+ * - \p __NV_STD_MAX -+ * - \p __NV_STD_MIN -+ * -+ * (2) Re-implementations of STL functions and types: -+ * - C++ features that need the \p __device__ annotation. These are -+ * placed into the \p platform namespace. -+ * - \p abs -+ * - \p plus -+ * - \p less -+ * - \p greater -+ * - \p min -+ * - \p max -+ * - \p methods on std::pair (==, !=, <, <=, >, >=, and make_pair()) -+ * -+ * (3) Stop-gap implementations of unsupported STL functions and types: -+ * - STL functions and types defined by C++ 11/14/17/etc. that are not -+ * provided by the current version of C++. 
These are placed into the -+ * \p platform namespace -+ * - \p integral_constant -+ * - \p nullptr_t -+ * - \p true_type -+ * - \p false_type -+ * - \p bool_constant -+ * - \p enable_if -+ * - \p conditional -+ * - \p is_same -+ * - \p is_base_of -+ * - \p remove_const -+ * - \p remove_volatile -+ * - \p remove_cv -+ * - \p is_volatile -+ * - \p is_pointer -+ * - \p is_void -+ * - \p is_integral -+ * - \p is_floating_point -+ * - \p is_arithmetic -+ * - \p is_fundamental -+ * - \p is_trivially_copyable -+ * - \p alignment_of -+ * - \p aligned_storage -+ * -+ * (4) Functions and types that are STL-like (but aren't in the STL): -+ * - \p TODO: min and max functors? -+ * -+ * The idea is that, as we drop support for older compilers, we can simply #define -+ * the \p __NV_STD_XYZ macros and \p platform namespace to alias their C++ -+ * counterparts (or trivially find-and-replace their occurrences in code text). -+ */ -+ -+//----------------------------------------------------------------------------- -+// Dependencies -+//----------------------------------------------------------------------------- -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#if !defined(__CUDACC_RTC__) -+//----------------------------------------------------------------------------- -+// Include STL files that platform provides functionality for -+//----------------------------------------------------------------------------- -+ -+#include // Minimum/maximum operations -+#include // nullptr_t -+#include // Arithmetic operations -+#include // For methods on std::pair -+#if (!defined(_MSC_VER) && (__cplusplus >= 201103L)) || (defined(_MSC_VER) && (_MS_VER >= 1500)) -+#include // For integral constants, conditional metaprogramming, and type traits -+#endif -+ -+#include "cutlass/cutlass.h" -+ -+#endif -+ -+//----------------------------------------------------------------------------- -+// OS -+//----------------------------------------------------------------------------- -+#if defined(WIN32) || defined(_WIN32) || defined(__WIN32) && !defined(__CYGWIN__) -+#define CUTLASS_OS_WINDOWS -+#endif -+ -+/****************************************************************************** -+ * Macros -+ ******************************************************************************/ -+//----------------------------------------------------------------------------- -+// Keywords -+//----------------------------------------------------------------------------- -+ -+/// noexcept, constexpr -+#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1900)) -+#ifndef noexcept -+#define noexcept -+#endif -+#ifndef constexpr -+#define constexpr -+#endif -+#endif -+ -+/// nullptr -+#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1310)) -+#ifndef nullptr -+#define nullptr 0 -+#endif -+#endif -+ -+/// static_assert -+#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1600)) -+#ifndef static_assert -+#define __platform_cat_(a, b) a##b -+#define __platform_cat(a, b) __platform_cat_(a, b) -+#define static_assert(__e, __m) typedef int __platform_cat(AsSeRt, __LINE__)[(__e) ? 1 : -1] -+#endif -+#endif -+ -+//----------------------------------------------------------------------------- -+// Functions -+//----------------------------------------------------------------------------- -+ -+/// Select maximum(a, b) -+#ifndef __NV_STD_MAX -+#define __NV_STD_MAX(a, b) (((b) > (a)) ? 
(b) : (a)) -+#endif -+ -+/// Select minimum(a, b) -+#ifndef __NV_STD_MIN -+#define __NV_STD_MIN(a, b) (((b) < (a)) ? (b) : (a)) -+#endif -+ -+/****************************************************************************** -+ * Re-implementations -+ ******************************************************************************/ -+namespace cutlass { -+namespace platform { -+ -+//----------------------------------------------------------------------------- -+// Abs operations -+//----------------------------------------------------------------------------- -+ -+#if defined(__CUDACC_RTC__) -+/// std::abs -+CUTLASS_HOST_DEVICE constexpr int abs(int a) { -+ return (a < 0) ? -a : a; -+} -+CUTLASS_HOST_DEVICE constexpr long long abs(long long a) { -+ return (a < 0) ? -a : a; -+} -+#else -+using std::abs; -+#endif -+ -+//----------------------------------------------------------------------------- -+// Minimum/maximum operations -+//----------------------------------------------------------------------------- -+ -+/// std::min -+template -+CUTLASS_HOST_DEVICE constexpr const T& min(const T& a, const T& b) { -+ return (b < a) ? b : a; -+} -+ -+/// std::max -+template -+CUTLASS_HOST_DEVICE constexpr const T& max(const T& a, const T& b) { -+ return (a < b) ? b : a; -+} -+ -+#if !defined(__CUDACC_RTC__) -+//----------------------------------------------------------------------------- -+// Methods on std::pair -+//----------------------------------------------------------------------------- -+ -+using std::pair; -+ -+template -+CUTLASS_HOST_DEVICE constexpr bool operator==(const pair& lhs, const pair& rhs) { -+ return (lhs.first == rhs.first) && (lhs.second == rhs.second); -+} -+ -+template -+CUTLASS_HOST_DEVICE constexpr bool operator!=(const pair& lhs, const pair& rhs) { -+ return (lhs.first != rhs.first) && (lhs.second != rhs.second); -+} -+ -+template -+CUTLASS_HOST_DEVICE constexpr bool operator<(const pair& lhs, const pair& rhs) { -+ return (lhs.first < rhs.first) ? true : (rhs.first < lhs.first) ? false -+ : (lhs.second < rhs.second); -+} -+ -+template -+CUTLASS_HOST_DEVICE constexpr bool operator<=(const pair& lhs, const pair& rhs) { -+ return !(rhs < lhs); -+} -+ -+template -+CUTLASS_HOST_DEVICE constexpr bool operator>(const pair& lhs, const pair& rhs) { -+ return (rhs < lhs); -+} -+ -+template -+CUTLASS_HOST_DEVICE constexpr bool operator>=(const pair& lhs, const pair& rhs) { -+ return !(lhs < rhs); -+} -+ -+template -+CUTLASS_HOST_DEVICE std::pair make_pair(T1 t, T2 u) { -+ std::pair retval; -+ retval.first = t; -+ retval.second = u; -+ return retval; -+} -+#endif -+ -+} // namespace platform -+ -+/****************************************************************************** -+ * Implementations of C++ 11/14/17/... 
STL features -+ ******************************************************************************/ -+ -+namespace platform { -+ -+//----------------------------------------------------------------------------- -+// Integral constant helper types -+//----------------------------------------------------------------------------- -+ -+#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500)) -+ -+/// std::integral_constant -+template -+struct integral_constant; -+ -+/// std::integral_constant -+template -+struct integral_constant { -+ static const value_t value = V; -+ -+ typedef value_t value_type; -+ typedef integral_constant type; -+ -+ CUTLASS_HOST_DEVICE operator value_type() const { return value; } -+ -+ CUTLASS_HOST_DEVICE const value_type operator()() const { return value; } -+}; -+ -+#else -+ -+using std::integral_constant; -+using std::pair; -+ -+#endif -+ -+/// The type used as a compile-time boolean with true value. -+typedef integral_constant true_type; -+ -+/// The type used as a compile-time boolean with false value. -+typedef integral_constant false_type; -+ -+#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus <= 201402L)) || (defined(_MSC_VER) && (_MSC_VER < 1900)) -+ -+/// std::bool_constant -+template -+struct bool_constant : platform::integral_constant {}; -+ -+#else -+ -+using std::bool_constant; -+ -+#endif -+ -+#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1700)) -+ -+/// std::nullptr_t -+struct nullptr_t {}; -+ -+#else -+ -+using std::nullptr_t; -+ -+#endif -+ -+//----------------------------------------------------------------------------- -+// Conditional metaprogramming -+//----------------------------------------------------------------------------- -+ -+#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1600)) -+ -+/// std::enable_if (true specialization) -+template -+struct enable_if { -+ typedef T type; -+}; -+ -+/// std::enable_if (false specialization) -+template -+struct enable_if {}; -+ -+/// std::conditional (true specialization) -+template -+struct conditional { -+ typedef T type; -+}; -+ -+/// std::conditional (false specialization) -+template -+struct conditional { -+ typedef F type; -+}; -+ -+#else -+ -+using std::enable_if; -+using std::conditional; -+ -+#endif -+ -+//----------------------------------------------------------------------------- -+// Const/volatility specifiers -+//----------------------------------------------------------------------------- -+ -+#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500)) -+ -+/// std::remove_const (non-const specialization) -+template -+struct remove_const { -+ typedef T type; -+}; -+ -+/// std::remove_const (const specialization) -+template -+struct remove_const { -+ typedef T type; -+}; -+ -+/// std::remove_volatile (non-volatile specialization) -+template -+struct remove_volatile { -+ typedef T type; -+}; -+ -+/// std::remove_volatile (volatile specialization) -+template -+struct remove_volatile { -+ typedef T type; -+}; -+ -+/// std::remove_cv -+template -+struct remove_cv { -+ typedef typename remove_volatile::type>::type type; -+}; -+ -+#else -+ -+using std::remove_const; -+using std::remove_volatile; -+using std::remove_cv; -+ -+#endif -+ -+//----------------------------------------------------------------------------- -+// 
Type relationships -+//----------------------------------------------------------------------------- -+ -+#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500)) -+ -+/// std::is_same (false specialization) -+template -+struct is_same : false_type {}; -+ -+/// std::is_same (true specialization) -+template -+struct is_same : true_type {}; -+ -+/// Helper for std::is_base_of -+template -+struct is_base_of_helper { -+ typedef char (&yes)[1]; -+ typedef char (&no)[2]; -+ -+ template -+ struct dummy { -+ CUTLASS_HOST_DEVICE operator B*() const; -+ CUTLASS_HOST_DEVICE operator D*(); -+ }; -+ -+ template -+ CUTLASS_HOST_DEVICE static yes check(DerivedT*, T); -+ -+ CUTLASS_HOST_DEVICE static no check(BaseT*, int); -+ -+ static const bool value = sizeof(check(dummy(), int())) == sizeof(yes); -+}; -+ -+/// std::is_base_of -+template -+struct is_base_of -+ : integral_constant::type, -+ typename remove_cv::type>::value) || -+ (is_same::type, -+ typename remove_cv::type>::value)> {}; -+ -+#else -+ -+using std::is_same; -+using std::is_base_of; -+ -+#endif -+ -+//----------------------------------------------------------------------------- -+// Type properties -+//----------------------------------------------------------------------------- -+ -+#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500)) -+ -+/// std::is_volatile -+template -+struct is_volatile : false_type {}; -+template -+struct is_volatile : true_type {}; -+ -+/// Helper for std::is_pointer (false specialization) -+template -+struct is_pointer_helper : false_type {}; -+ -+/// Helper for std::is_pointer (true specialization) -+template -+struct is_pointer_helper : true_type {}; -+ -+/// std::is_pointer -+template -+struct is_pointer : is_pointer_helper::type> {}; -+ -+/// std::is_void -+template -+struct is_void : is_same::type> {}; -+ -+/// std::is_integral -+template -+struct is_integral : false_type {}; -+template <> -+struct is_integral : true_type {}; -+template <> -+struct is_integral : true_type {}; -+template <> -+struct is_integral : true_type {}; -+template <> -+struct is_integral : true_type {}; -+template <> -+struct is_integral : true_type {}; -+template <> -+struct is_integral : true_type {}; -+template <> -+struct is_integral : true_type {}; -+template <> -+struct is_integral : true_type {}; -+template <> -+struct is_integral : true_type {}; -+template <> -+struct is_integral : true_type {}; -+template <> -+struct is_integral : true_type {}; -+template -+struct is_integral : is_integral {}; -+template -+struct is_integral : is_integral {}; -+template -+struct is_integral : is_integral {}; -+ -+/// std::is_floating_point -+template -+struct is_floating_point -+ : integral_constant::type>::value || -+ is_same::type>::value)> {}; -+ -+/// std::is_arithmetic -+template -+struct is_arithmetic -+ : integral_constant::value || is_floating_point::value)> {}; -+ -+/// std::is_fundamental -+template -+struct is_fundamental -+ : integral_constant::value || is_void::value || -+ is_same::type>::value)> {}; -+ -+#else -+ -+using std::is_volatile; -+using std::is_pointer; -+using std::is_void; -+using std::is_integral; -+using std::is_floating_point; -+using std::is_arithmetic; -+using std::is_fundamental; -+ -+#endif -+ -+#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1800)) || \ -+ (defined(__GNUG__) && (__GNUC__ < 5)) -+ -+/** -+ * 
std::is_trivially_copyable -+ * -+ * This implementation only evaluates true if T is fundamental or pointer -+ * -+ * Without help from partial template specializations provided by the user for -+ * a specific class or struct, this trait will never report that the specified -+ * class or struct is trivially-copyable ; this is always safe, -+ * if possibly sub-optimal. -+ */ -+template -+struct is_trivially_copyable -+ : integral_constant::value || is_pointer::value)> {}; -+ -+#else -+ -+using std::is_trivially_copyable; -+ -+#endif -+ -+//----------------------------------------------------------------------------- -+// bit_cast -+//----------------------------------------------------------------------------- -+ -+template< class To, class From > -+constexpr To CUTLASS_HOST_DEVICE bit_cast(const From& from ) noexcept; -+ -+template -+constexpr To CUTLASS_HOST_DEVICE bit_cast(const From& src) noexcept -+{ -+ static_assert(sizeof(To) == sizeof(From), "sizes must match"); -+ return reinterpret_cast(src); -+} -+ -+//----------------------------------------------------------------------------- -+// Alignment and layout utilities -+//----------------------------------------------------------------------------- -+ -+#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500)) -+ -+/// std::alignment_of -+template -+struct alignment_of { -+ struct pad { -+ value_t val; -+ char byte; -+ }; -+ -+ enum { value = sizeof(pad) - sizeof(value_t) }; -+}; -+ -+#else -+ -+template -+struct alignment_of : std::alignment_of {}; -+ -+#endif -+ -+/* 16B specializations where 32-bit Win32 host compiler disagrees with device compiler */ -+template <> -+struct alignment_of { -+ enum { value = 16 }; -+}; -+template <> -+struct alignment_of { -+ enum { value = 16 }; -+}; -+template <> -+struct alignment_of { -+ enum { value = 16 }; -+}; -+template <> -+struct alignment_of { -+ enum { value = 16 }; -+}; -+template <> -+struct alignment_of { -+ enum { value = 16 }; -+}; -+template <> -+struct alignment_of { -+ enum { value = 16 }; -+}; -+template <> -+struct alignment_of { -+ enum { value = 16 }; -+}; -+template <> -+struct alignment_of { -+ enum { value = 16 }; -+}; -+template <> -+struct alignment_of { -+ enum { value = 16 }; -+}; -+template <> -+struct alignment_of { -+ enum { value = 16 }; -+}; -+template <> -+struct alignment_of { -+ enum { value = 16 }; -+}; -+ -+// Specializations for volatile/const qualified types -+template -+struct alignment_of : alignment_of {}; -+template -+struct alignment_of : alignment_of {}; -+template -+struct alignment_of : alignment_of {}; -+ -+#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1800)) -+ -+template -+struct aligned_chunk; -+template <> -+struct __align__(1) aligned_chunk<1> { -+ uint8_t buff; -+}; -+template <> -+struct __align__(2) aligned_chunk<2> { -+ uint16_t buff; -+}; -+template <> -+struct __align__(4) aligned_chunk<4> { -+ uint32_t buff; -+}; -+template <> -+struct __align__(8) aligned_chunk<8> { -+ uint32_t buff[2]; -+}; -+template <> -+struct __align__(16) aligned_chunk<16> { -+ uint32_t buff[4]; -+}; -+template <> -+struct __align__(32) aligned_chunk<32> { -+ uint32_t buff[8]; -+}; -+template <> -+struct __align__(64) aligned_chunk<64> { -+ uint32_t buff[16]; -+}; -+template <> -+struct __align__(128) aligned_chunk<128> { -+ uint32_t buff[32]; -+}; -+template <> -+struct __align__(256) aligned_chunk<256> { -+ uint32_t buff[64]; 
-+}; -+template <> -+struct __align__(512) aligned_chunk<512> { -+ uint32_t buff[128]; -+}; -+template <> -+struct __align__(1024) aligned_chunk<1024> { -+ uint32_t buff[256]; -+}; -+template <> -+struct __align__(2048) aligned_chunk<2048> { -+ uint32_t buff[512]; -+}; -+template <> -+struct __align__(4096) aligned_chunk<4096> { -+ uint32_t buff[1024]; -+}; -+ -+/// std::aligned_storage -+template -+struct aligned_storage { -+ typedef aligned_chunk type[Len / sizeof(aligned_chunk)]; -+}; -+ -+#else -+ -+using std::aligned_storage; -+ -+#endif -+ -+#if !defined(__CUDACC_RTC__) -+/// Default deleter -+template -+struct default_delete { -+ void operator()(T* ptr) const { delete ptr; } -+}; -+ -+/// Partial specialization for deleting array types -+template -+struct default_delete { -+ void operator()(T* ptr) const { delete[] ptr; } -+}; -+ -+/// std::unique_ptr -+template > -+class unique_ptr { -+ public: -+ typedef T* pointer; -+ typedef T element_type; -+ typedef Deleter deleter_type; -+ -+ private: -+ /// Pointer to memory -+ pointer _ptr; -+ -+ /// Deleter -+ deleter_type _deleter; -+ -+ public: -+ unique_ptr() : _ptr(nullptr) {} -+ unique_ptr(pointer p) : _ptr(p) {} -+ -+ ~unique_ptr() { -+ if (_ptr) { -+ _deleter(_ptr); -+ } -+ } -+ /// Returns a pointer to the managed object or nullptr if no object is owned. -+ pointer get() const noexcept { return _ptr; } -+ -+ /// Releases ownership of the managed object, if any -+ pointer release() noexcept { -+ pointer p(_ptr); -+ _ptr = nullptr; -+ return p; -+ } -+ -+ /// Replaces the managed object, deleting the old object. -+ void reset(pointer p = pointer()) noexcept { -+ pointer old_ptr = _ptr; -+ _ptr = p; -+ if (old_ptr != nullptr) { -+ get_deleter()(old_ptr); -+ } -+ } -+ -+ /// Swaps the managed objects with *this and another unique_ptr -+ void swap(unique_ptr& other) noexcept { std::swap(_ptr, other._ptr); } -+ -+ /// Returns the deleter object -+ Deleter& get_deleter() noexcept { return _deleter; } -+ -+ /// Returns the deleter object -+ Deleter const& get_deleter() const noexcept { return _deleter; } -+ -+ /// Checks whether an object is owned -+ operator bool() const noexcept { return _ptr != nullptr; } -+ -+ /// Dereferences the unique_ptr -+ T& operator*() const { return *_ptr; } -+ -+ /// Returns a pointer to the managed object -+ pointer operator->() const noexcept { return _ptr; } -+ -+ /// Array access to managed object -+ T& operator[](size_t i) const { return _ptr[i]; } -+}; -+ -+/// Specializes the swap algorithm -+template -+void swap(unique_ptr& lhs, unique_ptr& rhs) noexcept { -+ lhs.swap(rhs); -+} -+#endif -+ -+/// std::numeric_limits -+template -+struct numeric_limits; -+ -+template <> -+struct numeric_limits { -+ CUTLASS_HOST_DEVICE -+ static constexpr int32_t lowest() noexcept { return -2147483647 - 1;} -+ CUTLASS_HOST_DEVICE -+ static constexpr int32_t max() noexcept { return 2147483647;} -+ static constexpr bool is_integer = true; -+}; -+ -+template <> -+struct numeric_limits { -+ CUTLASS_HOST_DEVICE -+ static constexpr int16_t lowest() noexcept { return -32768;} -+ CUTLASS_HOST_DEVICE -+ static constexpr int16_t max() noexcept { return 32767;} -+ static constexpr bool is_integer = true; -+}; -+ -+template <> -+struct numeric_limits { -+ CUTLASS_HOST_DEVICE -+ static constexpr int8_t lowest() noexcept { return -128;} -+ CUTLASS_HOST_DEVICE -+ static constexpr int8_t max() noexcept { return 127;} -+ static constexpr bool is_integer = true; -+}; -+ -+ -+template <> -+struct numeric_limits { -+ CUTLASS_HOST_DEVICE -+ 
static constexpr uint32_t lowest() noexcept { return 0;} -+ CUTLASS_HOST_DEVICE -+ static constexpr uint32_t max() noexcept { return 4294967295U;} -+ static constexpr bool is_integer = true; -+}; -+ -+template <> -+struct numeric_limits { -+ CUTLASS_HOST_DEVICE -+ static constexpr uint16_t lowest() noexcept { return 0;} -+ CUTLASS_HOST_DEVICE -+ static constexpr uint16_t max() noexcept { return 65535U;} -+ static constexpr bool is_integer = true; -+}; -+ -+template <> -+struct numeric_limits { -+ CUTLASS_HOST_DEVICE -+ static constexpr uint8_t lowest() noexcept { return 0;} -+ CUTLASS_HOST_DEVICE -+ static constexpr uint8_t max() noexcept { return 255U;} -+ static constexpr bool is_integer = true; -+}; -+ -+template <> -+struct numeric_limits { -+ CUTLASS_HOST_DEVICE -+ static constexpr float infinity() noexcept { return bit_cast(0x7f800000);} -+ static constexpr bool is_integer = false; -+ static constexpr bool has_infinity = true; -+}; -+ -+} // namespace platform -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/predicate_vector.h b/3rdparty/cutlass/include/cutlass/predicate_vector.h -new file mode 100644 -index 0000000..d158225 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/predicate_vector.h -@@ -0,0 +1,524 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines container classes and iterators for managing a statically sized vector -+ of boolean predicates. 
-+*/ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#include -+#else -+#include -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/platform/platform.h" -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*!@defgroup predicate_vector_concept Predicate Vector Concept -+@{ -+ -+Implementations of \ref predicate_vector_concept contain an ordered set of boolean predicates which -+may be used as conditionals in other device-side operations. Both random access and iterators -+offering sequential access are provided. -+ -+@par Predicate Vector -+ A \ref predicate_vector_concept satisfies the following expressions -+ - at(int idx) - returns the value of the indexed predicate -+ - set(int idx, bool value) - sets the value of the indexed predicate -+ - begin() - returns a \ref predicate_iterator_concept pointing to the first predicate -+ -+@} -+*/ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*!@defgroup predicate_iterator_concept Predicate Iterator Concept -+@{ -+ -+Implementations of \ref predicate_iterator_concept enables accessing and traversing elements of a -+bit vector. -+ -+@par Const Predicate Iterator -+ A const \ref predicate_iterator_concept satisfies the following expressions -+ - ++it increments the iterator to the next predicate -+ - *it returns the value of the currently pointed-to predicate -+ -+@par Mutable Predicate Iterator -+ A \ref predicate_iterator_concept that is non-const also satisfies the following expressions -+ - it.set(bool value) sets the value of the currently pointed-to predicate -+ -+@} -+*/ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*!@defgroup predicate_tile_adapter Predicate Tile Adapter Concept -+@{ -+ -+Implementations of \ref predicate_tile_adapter provide a mapping between a the elements of a \ref -+tile_traits_concept and a \ref predicate_vector_concept. -+ -+@par Predicate Tile Adapter -+ A \ref predicate_tile_adapter satisfies the following expressions -+ - at(int d, int h, int w, int c) - returns the value of a predicate corresponding to the -+ access (d, h, w, c) within the tile. -+ -+@} -+*/ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Statically sized array of bits implementing @concept{predicate_vector_concept}. -+template < -+ /// Number of predicates conatined in predicate vector -+ int kPredicates_, -+ /// Number of predicates contained in each byte of internal storage -+ int kPredicatesPerByte_ = 4, -+ /// Location of first predicate within byte of internal storage -+ int kPredicateStart_ = 0> -+struct PredicateVector { -+ /// Number of bits stored by the PredicateVector -+ static int const kPredicates = kPredicates_; -+ -+ /// Number of bits stored within each byte of the predicate bit vector -+ static int const kPredicatesPerByte = kPredicatesPerByte_; -+ -+ /// First bit withing each byte containing predicates -+ static int const kPredicateStart = kPredicateStart_; -+ -+ // Make sure no one tries to put more than 8 bits in a byte :) -+ static_assert(kPredicatesPerByte <= 8, "kPredicatesPerByte must fit within an actual byte"); -+ // Make sure the "offsetted" bits fit in one byte. 
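A minimal host-only usage sketch of the indexed predicate interface described above; the guarded-access scenario and the function name are illustrative, and the sketch relies only on the constructor, set(), at(), and operator[] defined further below in this class (begin()/end() are device-only, so they are not used here):

#include "cutlass/predicate_vector.h"

// Track which of 8 guarded accesses are in bounds, then query individual guards.
bool predicate_vector_sketch() {
  cutlass::PredicateVector<8> guards(false);   // construct with every predicate cleared
  for (int i = 0; i < 8; ++i) {
    guards.set(i, (i % 2) == 0);               // enable the even-indexed accesses
  }
  return guards.at(2) && !guards[3];           // true: bit 2 is set, bit 3 is clear
}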
-+ static_assert(kPredicateStart + kPredicatesPerByte <= 8, -+ "The offsetted predicates must fit within an actual byte."); -+ -+ /// Storage type of individual elements -+ typedef uint32_t Storage; -+ -+ /// Number of bytes needed -+ static int const kBytes = (kPredicates + kPredicatesPerByte - 1) / kPredicatesPerByte; -+ -+ /// Number of storage elements needed -+ static int const kWordCount = (kBytes + int(sizeof(Storage)) - 1) / int(sizeof(Storage)); -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Words of bit vector -+ Storage storageData[kWordCount]; -+ -+ // -+ // Methods -+ // -+ -+ /// Computes the word and bit corresponding to a logical predicate index -+ CUTLASS_HOST_DEVICE void computeStorageOffset(int &word, int &bit, int idx) const { -+ CUTLASS_ASSERT(idx < kPredicates); -+ -+ int byte = (idx / kPredicatesPerByte); -+ int bit_offset = (idx % kPredicatesPerByte); -+ -+ word = byte / sizeof(Storage); -+ int byte_offset = (byte % sizeof(Storage)); -+ -+ bit = byte_offset * 8 + bit_offset + kPredicateStart; -+ } -+ -+ /// Accesses a given word with optional assertions -+ CUTLASS_HOST_DEVICE Storage &storage(int word) { -+ CUTLASS_ASSERT(word < kWordCount); -+ return storageData[word]; -+ } -+ -+ /// Accesses a given word with optional assertions -+ CUTLASS_HOST_DEVICE Storage const &storage(int word) const { -+ CUTLASS_ASSERT(word < kWordCount); -+ return storageData[word]; -+ } -+ -+ public: -+ // -+ // Iterator -+ // -+ -+ /** -+ * @brief An iterator implementing \ref predicate_iterator_concept enabling sequential -+ * read and write access to predicates. -+ * @concept{predicate_iterator_concept} -+ */ -+ class Iterator { -+ /// Reference to PredicateVector instance -+ PredicateVector &vec_; -+ -+ /// Index into PredicateVector -+ int bit_; -+ -+ public: -+ /// Copy constructor -+ CUTLASS_HOST_DEVICE -+ Iterator(Iterator const &it) : vec_(it.vec_), bit_(it.bit_) {} -+ -+ /// Constructs an iterator from a PredicateVector -+ CUTLASS_HOST_DEVICE -+ Iterator(PredicateVector &vec, int _start = 0) : vec_(vec), bit_(_start) {} -+ -+ /// Pre-increment -+ CUTLASS_HOST_DEVICE -+ Iterator &operator++() { -+ ++bit_; -+ return *this; -+ } -+ -+ /// Increment -+ CUTLASS_HOST_DEVICE -+ Iterator &operator+=(int offset) { -+ bit_ += offset; -+ return *this; -+ } -+ -+ /// Pre-decrement -+ CUTLASS_HOST_DEVICE -+ Iterator &operator--() { -+ --bit_; -+ return *this; -+ } -+ -+ /// Decrement -+ CUTLASS_HOST_DEVICE -+ Iterator &operator-=(int offset) { -+ bit_ -= offset; -+ return *this; -+ } -+ -+ /// Post-increment -+ CUTLASS_HOST_DEVICE -+ Iterator operator++(int) { -+ Iterator ret(*this); -+ ret.bit_++; -+ return ret; -+ } -+ -+ /// Post-decrement -+ CUTLASS_HOST_DEVICE -+ Iterator operator--(int) { -+ Iterator ret(*this); -+ ret.bit_--; -+ return ret; -+ } -+ -+ /// Iterator advances by some amount -+ CUTLASS_HOST_DEVICE -+ Iterator operator+(int offset) { -+ Iterator ret(*this); -+ ret.bit_ += offset; -+ return ret; -+ } -+ -+ /// Iterator recedes by some amount -+ CUTLASS_HOST_DEVICE -+ Iterator operator-(int offset) { -+ ConstIterator ret(*this); -+ ret.bit_ -= offset; -+ return ret; -+ } -+ -+ /// Returns true if iterators point to the same bit -+ CUTLASS_HOST_DEVICE -+ bool operator==(Iterator const &it) const { return bit_ == it.bit_; } -+ -+ /// Returns false if iterators point to the same bit -+ CUTLASS_HOST_DEVICE -+ bool operator!=(Iterator const &it) const { return bit_ != it.bit_; } -+ -+ /// Gets the bit at the pointed to location -+ CUTLASS_HOST_DEVICE -+ bool get() { 
return vec_.at(bit_); } -+ -+ /// Gets the bit at the pointed to location -+ CUTLASS_HOST_DEVICE -+ bool at() const { return vec_.at(bit_); } -+ -+ /// Dereferences iterator -+ CUTLASS_HOST_DEVICE -+ bool operator*() const { return at(); } -+ -+ /// Sets the bit at the pointed to location -+ CUTLASS_HOST_DEVICE -+ void set(bool value = true) { vec_.set(bit_, value); } -+ }; -+ -+ /** -+ * @brief An iterator implementing \ref predicate_iterator_concept enabling sequential -+ * read and write access to predicates. -+ * @concept{predicate_iterator_concept} -+ */ -+ class ConstIterator { -+ /// Reference to PredicateVector instance -+ PredicateVector const &vec_; -+ -+ /// Index into PredicateVector -+ int bit_; -+ -+ public: -+ /// Copy constructor -+ CUTLASS_HOST_DEVICE -+ ConstIterator(ConstIterator const &it) : vec_(it.vec_), bit_(it.bit_) {} -+ -+ /// Constructs an iterator from a PredicateVector -+ CUTLASS_HOST_DEVICE -+ ConstIterator(PredicateVector const &vec, int _start = 0) : vec_(vec), bit_(_start) {} -+ -+ /// Pre-increment -+ CUTLASS_HOST_DEVICE -+ ConstIterator &operator++() { -+ ++bit_; -+ return *this; -+ } -+ -+ /// Increment -+ CUTLASS_HOST_DEVICE -+ ConstIterator &operator+=(int offset) { -+ bit_ += offset; -+ return *this; -+ } -+ -+ /// Pre-decrement -+ CUTLASS_HOST_DEVICE -+ ConstIterator &operator--() { -+ --bit_; -+ return *this; -+ } -+ -+ /// Decrement -+ CUTLASS_HOST_DEVICE -+ ConstIterator &operator-=(int offset) { -+ bit_ -= offset; -+ return *this; -+ } -+ -+ /// Post-increment -+ CUTLASS_HOST_DEVICE -+ ConstIterator operator++(int) { -+ ConstIterator ret(*this); -+ ret.bit_++; -+ return ret; -+ } -+ -+ /// Post-decrement -+ CUTLASS_HOST_DEVICE -+ ConstIterator operator--(int) { -+ ConstIterator ret(*this); -+ ret.bit_--; -+ return ret; -+ } -+ -+ /// Iterator advances by some amount -+ CUTLASS_HOST_DEVICE -+ ConstIterator operator+(int offset) { -+ ConstIterator ret(*this); -+ ret.bit_ += offset; -+ return ret; -+ } -+ -+ /// Iterator recedes by some amount -+ CUTLASS_HOST_DEVICE -+ ConstIterator operator-(int offset) { -+ ConstIterator ret(*this); -+ ret.bit_ -= offset; -+ return ret; -+ } -+ -+ /// Returns true if iterators point to the same bit -+ CUTLASS_HOST_DEVICE -+ bool operator==(ConstIterator const &it) const { return bit_ == it.bit_; } -+ -+ /// Returns false if iterators point to the same bit -+ CUTLASS_HOST_DEVICE -+ bool operator!=(ConstIterator const &it) const { return bit_ != it.bit_; } -+ -+ /// Gets the bit at the pointed to location -+ CUTLASS_HOST_DEVICE -+ bool get() { return vec_.at(bit_); } -+ -+ /// Gets the bit at the pointed to location -+ CUTLASS_HOST_DEVICE -+ bool at() const { return vec_.at(bit_); } -+ -+ /// Dereferences iterator -+ CUTLASS_HOST_DEVICE -+ bool operator*() const { return at(); } -+ }; -+ -+ /// Iterator that always returns true -+ struct TrivialIterator { -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TrivialIterator() {} -+ -+ /// Copy constructor -+ CUTLASS_HOST_DEVICE -+ TrivialIterator(Iterator const &it) {} -+ -+ /// Constructs an iterator from a PredicateVector -+ CUTLASS_HOST_DEVICE -+ TrivialIterator(PredicateVector const &_vec) {} -+ -+ /// Pre-increment -+ CUTLASS_HOST_DEVICE -+ TrivialIterator &operator++() { return *this; } -+ -+ /// Post-increment -+ CUTLASS_HOST_DEVICE -+ TrivialIterator operator++(int) { return *this; } -+ -+ /// Dereferences iterator -+ CUTLASS_HOST_DEVICE -+ bool operator*() const { return true; } -+ }; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Initialize the predicate vector 
-+ CUTLASS_HOST_DEVICE PredicateVector(bool value = true) { fill(value); } -+ -+ /// Fills all predicates with a given value -+ CUTLASS_HOST_DEVICE void fill(bool value = true) { -+ Storage item = (value ? ~Storage(0) : Storage(0)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kWordCount; ++i) { -+ storage(i) = item; -+ } -+ } -+ -+ /// Clears all predicates -+ CUTLASS_HOST_DEVICE void clear() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kWordCount; ++i) { -+ storage(i) = 0; -+ } -+ } -+ -+ /// Sets all predicates to true -+ CUTLASS_HOST_DEVICE void enable() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kWordCount; ++i) { -+ storage(i) = ~Storage(0); -+ } -+ } -+ -+ /// Accesses a bit within the predicate vector. -+ CUTLASS_HOST_DEVICE bool operator[](int idx) const { return at(idx); } -+ -+ /// Accesses a bit within the predicate vector. -+ CUTLASS_HOST_DEVICE bool at(int idx) const { -+ int bit, word; -+ computeStorageOffset(word, bit, idx); -+ -+ return ((storage(word) >> bit) & 1); -+ } -+ -+ /// Set a bit within the predicate vector. -+ CUTLASS_HOST_DEVICE void set(int idx, bool value = true) { -+ int bit, word; -+ computeStorageOffset(word, bit, idx); -+ -+ Storage disable_mask = (~(Storage(1) << bit)); -+ Storage enable_mask = (Storage(value) << bit); -+ -+ storage(word) = ((storage(word) & disable_mask) | enable_mask); -+ } -+ -+ /// Computes the intersection of two identical predicate vectors. -+ CUTLASS_HOST_DEVICE PredicateVector &operator&=(PredicateVector const &predicates) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kWordCount; ++i) { -+ storage(i) = (storage(i) & predicates.storage(i)); -+ } -+ return *this; -+ } -+ -+ /// Computes the union of two identical predicate vectors. -+ CUTLASS_HOST_DEVICE PredicateVector &operator|=(PredicateVector const &predicates) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kWordCount; ++i) { -+ storage(i) = (storage(i) | predicates.storage(i)); -+ } -+ return *this; -+ } -+ -+ /// Returns true if entire predicate array is zero. -+ CUTLASS_HOST_DEVICE bool is_zero() const { -+ Storage mask(0); -+ for (int byte = 0; byte < sizeof(Storage); ++byte) { -+ Storage byte_mask = (((1 << kPredicatesPerByte) - 1) << kPredicateStart); -+ mask |= (byte_mask << (byte * 8)); -+ } -+ uint32_t result = 0; -+ for (int word = 0; word < kWordCount; ++word) { -+ result |= storage(word); -+ } -+ return result == 0; -+ } -+ -+ /// Returns an iterator to the start of the bit vector -+ CUTLASS_DEVICE -+ Iterator begin() { return Iterator(*this); } -+ -+ /// Returns an iterator -+ CUTLASS_DEVICE -+ Iterator end() { return Iterator(*this, kPredicates); } -+ -+ /// Returns a ConstIterator -+ CUTLASS_DEVICE -+ ConstIterator const_begin() const { return ConstIterator(*this); } -+ -+ /// Returns a ConstIterator -+ CUTLASS_DEVICE -+ ConstIterator const_end() const { return ConstIterator(*this, kPredicates); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/quaternion.h b/3rdparty/cutlass/include/cutlass/quaternion.h -new file mode 100644 -index 0000000..1015be4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/quaternion.h -@@ -0,0 +1,753 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines a densely packed quaternion object intended for storing data in registers and -+ executing quaternion operations within a CUDA or host thread. 
-+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+#include "cutlass/array.h" -+#include "cutlass/real.h" -+#include "cutlass/coord.h" -+#include "cutlass/matrix.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/layout/vector.h" -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Quaternion: xi + yj + zk + w -+template < -+ typename Element_ = float ///< element type -+> -+class Quaternion : public Array { -+public: -+ -+ /// Logical rank of tensor index space -+ static int const kRank = 1; -+ -+ /// Number of elements -+ static int const kExtent = 4; -+ -+ /// Base class is a four-element array -+ using Base = Array; -+ -+ /// Element type -+ using Element = typename Base::Element; -+ -+ /// Reference type to an element -+ using Reference = typename Base::reference; -+ -+ /// Index type -+ using Index = int; -+ -+ /// Quaternion storage - imaginary part -+ static int const kX = 0; -+ -+ /// Quaternion storage - imaginary part -+ static int const kY = 1; -+ -+ /// Quaternion storage - imaginary part -+ static int const kZ = 2; -+ -+ /// Quaternion storage - real part -+ static int const kW = 3; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a quaternion q = 0 -+ CUTLASS_HOST_DEVICE -+ Quaternion() { -+ Base::at(kX) = Element(); -+ Base::at(kY) = Element(); -+ Base::at(kZ) = Element(); -+ Base::at(kW) = Element(); -+ } -+ -+ /// Constructs a quaternion q = w + 0*i + 0*j + 0*k -+ CUTLASS_HOST_DEVICE -+ Quaternion( -+ Element w_ -+ ) { -+ Base::at(kX) = Element(); -+ Base::at(kY) = Element(); -+ Base::at(kZ) = Element(); -+ Base::at(kW) = w_; -+ } -+ -+ /// Constructs a quaternion q = w + x*i + y*j + z*k -+ CUTLASS_HOST_DEVICE -+ Quaternion( -+ Element x_, -+ Element y_, -+ Element z_, -+ Element w_ -+ ) { -+ Base::at(kX) = x_; -+ Base::at(kY) = y_; -+ Base::at(kZ) = z_; -+ Base::at(kW) = w_; -+ } -+ -+ /// Constructs a quaternion from a vector representing the imaginary part and a real number -+ CUTLASS_HOST_DEVICE -+ Quaternion( -+ Matrix3x1 const &imag_, -+ Element w_ = Element() -+ ) { -+ Base::at(kX) = imag_[0]; -+ Base::at(kY) = imag_[1]; -+ Base::at(kZ) = imag_[2]; -+ Base::at(kW) = w_; -+ } -+ -+ /// Returns a reference to the element at a given Coord -+ CUTLASS_HOST_DEVICE -+ Reference at(Index idx) const { -+ return Base::at(idx); -+ } -+ -+ /// Returns a reference to the element at a given Coord -+ CUTLASS_HOST_DEVICE -+ Reference at(Index idx) { -+ return Base::at(idx); -+ } -+ -+ /// Accesses the x element of the imaginary part of the quaternion -+ CUTLASS_HOST_DEVICE -+ Element x() const { -+ return Base::at(kX); -+ } -+ -+ /// Accesses the x element of the imaginary part of the quaternion -+ CUTLASS_HOST_DEVICE -+ Reference x() { -+ return Base::at(kX); -+ } -+ -+ /// Accesses the y element of the imaginary part of the quaternion -+ CUTLASS_HOST_DEVICE -+ Element y() const { -+ return Base::at(kY); -+ } -+ -+ /// Accesses the y element of the imaginary part of the quaternion -+ CUTLASS_HOST_DEVICE -+ Reference y() { -+ return Base::at(kY); -+ } -+ -+ /// Accesses the z element of the imaginary part of the quaternion -+ CUTLASS_HOST_DEVICE -+ Element z() const { -+ return Base::at(kZ); -+ } -+ -+ /// Accesses the z element of the imaginary part of the quaternion -+ CUTLASS_HOST_DEVICE -+ Reference z() { -+ return Base::at(kZ); -+ } -+ -+ /// Accesses the real part of the quaternion -+ CUTLASS_HOST_DEVICE -+ Element w() const { -+ 
return Base::at(kW); -+ } -+ -+ /// Accesses the real part of the quaternion -+ CUTLASS_HOST_DEVICE -+ Reference w() { -+ return Base::at(kW); -+ } -+ -+ /// Returns the pure imaginary part of the quaternion as a 3-vector -+ CUTLASS_HOST_DEVICE -+ Matrix3x1 pure() const { -+ return Matrix3x1(x(), y(), z()); -+ } -+ -+ /// Returns a quaternion representation of a spatial rotation given a unit-length axis and -+ /// a rotation in radians. -+ CUTLASS_HOST_DEVICE -+ static Quaternion rotation( -+ Matrix3x1 const &axis_unit, ///< axis of rotation (assumed to be unit length) -+ Element theta) { ///< angular rotation in radians -+ -+ Element s = fast_sin(theta / Element(2)); -+ -+ return Quaternion( -+ s * axis_unit[0], -+ s * axis_unit[1], -+ s * axis_unit[2], -+ fast_cos(theta / Element(2)) -+ ); -+ } -+ -+ /// Returns a quaternion representation of a spatial rotation represented as a -+ /// unit-length rotation axis (r_x, r_y, r_z) and an angular rotation in radians -+ CUTLASS_HOST_DEVICE -+ static Quaternion rotation( -+ Element r_x, -+ Element r_y, -+ Element r_z, -+ Element theta) { ///< angular rotation in radians -+ -+ return rotation({r_x, r_y, r_z}, theta); -+ } -+ -+ /// Geometric rotation of a 3-element vector -+ CUTLASS_HOST_DEVICE -+ Matrix3x1 rotate(Matrix3x1 const &rhs) const { -+ return (*this * Quaternion(rhs, 0) * reciprocal(*this)).pure(); -+ } -+ -+ /// Inverse rotation operation -+ CUTLASS_HOST_DEVICE -+ Matrix3x1 rotate_inv(Matrix3x1 const &rhs) const { -+ return (reciprocal(*this) * Quaternion(rhs, 0) * *this).pure(); -+ } -+ -+ /// Rotates a 3-vector assuming this is a unit quaternion (a spinor) -+ CUTLASS_HOST_DEVICE -+ Matrix3x1 spinor(Matrix3x1 const &rhs) const { -+ return (*this * Quaternion(rhs, 0) * conj(*this)).pure(); -+ } -+ -+ /// Inverse rotation of 3-vector assuming this is a unit quaternion (a spinor) -+ CUTLASS_HOST_DEVICE -+ Matrix3x1 spinor_inv(Matrix3x1 const &rhs) const { -+ return (conj(*this) * Quaternion(rhs, 0) * *this).pure(); -+ } -+ -+ /// In-place addition -+ template -+ CUTLASS_HOST_DEVICE -+ Quaternion &operator+=(Quaternion const &rhs) { -+ *this = (*this + rhs); -+ return *this; -+ } -+ -+ /// In-place subtraction -+ template -+ CUTLASS_HOST_DEVICE -+ Quaternion &operator-=(Quaternion const &rhs) { -+ *this = (*this - rhs); -+ return *this; -+ } -+ -+ /// In-place multiplication -+ template -+ CUTLASS_HOST_DEVICE -+ Quaternion &operator*=(Quaternion const &rhs) { -+ *this = (*this * rhs); -+ return *this; -+ } -+ -+ /// Scalar multiplication -+ template -+ CUTLASS_HOST_DEVICE -+ Quaternion &operator*=(Element s) { -+ *this = (*this * s); -+ return *this; -+ } -+ -+ /// In-place Division -+ template -+ CUTLASS_HOST_DEVICE -+ Quaternion &operator/=(Quaternion const &rhs) { -+ *this = (*this / rhs); -+ return *this; -+ } -+ -+ /// In-place Division -+ template -+ CUTLASS_HOST_DEVICE -+ Quaternion &operator/=(Element s) { -+ *this = (*this / s); -+ return *this; -+ } -+ -+ /// Computes a 3x3 rotation matrix (row-major representation) -+ CUTLASS_HOST_DEVICE -+ Matrix3x3 as_rotation_matrix_3x3() const { -+ Matrix3x3 m( -+ w() * w() + x() * x() - y() * y() - z() * z(), -+ 2 * x() * y() - 2 * w() * z(), -+ 2 * x() * z() + 2 * w() * y(), -+ -+ 2 * x() * y() + 2 * w() * z(), -+ w() * w() - x() * x() + y() * y() - z() * z(), -+ 2 * y() * z() - 2 * w() * x(), -+ -+ 2 * x() * z() - 2 * w() * y(), -+ 2 * y() * z() + 2 * w() * x(), -+ w() * w() - x() * x() - y() * y() + z() * z() -+ ); -+ return m; -+ } -+ -+ /// Computes a 4x4 rotation matrix (row-major 
representation) -+ CUTLASS_HOST_DEVICE -+ Matrix4x4 as_rotation_matrix_4x4() const { -+ Matrix4x4 m = Matrix4x4::identity(); -+ m.set_slice_3x3(as_rotation_matrix_3x3()); -+ return m; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Constructs a quaternion that is non-zero only in its real element. -+template -+CUTLASS_HOST_DEVICE -+Quaternion make_Quaternion( -+ Element w) { ///< real part -+ -+ return Quaternion(w); -+} -+ -+/// Constructs a quaternion from a vector and real -+template -+CUTLASS_HOST_DEVICE -+Quaternion make_Quaternion( -+ Matrix3x1 const &imag, ///< imaginary party as a vector -+ Element w) { ///< real part -+ -+ return Quaternion(imag, w); -+} -+ -+/// Constructs a quaternion from a unit-length rotation axis and a rotation -+/// angle in radians -+template -+CUTLASS_HOST_DEVICE -+Quaternion make_QuaternionRotation( -+ Matrix3x1 const &axis_unit, ///< rotation axis (unit-length) -+ Element w) { ///< rotation angle in radians -+ -+ return Quaternion::rotation(axis_unit, w); -+} -+ -+/// Constructs a quaternion q = xi + yj + zk + w -+template -+CUTLASS_HOST_DEVICE -+Quaternion make_Quaternion(Element x, Element y, Element z, Element w) { -+ return Quaternion(x, y, z, w); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns the real part of the quaternion number -+template -+CUTLASS_HOST_DEVICE -+Element const &real(Quaternion const &q) { -+ return q.w(); -+} -+ -+/// Returns the real part of the quaternion number -+template -+CUTLASS_HOST_DEVICE -+Element &real(Quaternion &q) { -+ return q.w(); -+} -+ -+/// Returns the magnitude of the quaternion number -+template -+CUTLASS_HOST_DEVICE -+Element abs(Quaternion const &q) { -+ return fast_sqrt(norm(q)); -+} -+ -+/// Quaternion conjugate -+template -+CUTLASS_HOST_DEVICE -+Quaternion conj(Quaternion const &q) { -+ return make_Quaternion( -+ -q.x(), -+ -q.y(), -+ -q.z(), -+ q.w() -+ ); -+} -+ -+/// Computes the squared magnitude of the quaternion -+template -+CUTLASS_HOST_DEVICE -+Element norm(Quaternion const &q) { -+ return q.x() * q.x() + q.y() * q.y() + q.z() * q.z() + q.w() * q.w(); -+} -+ -+/// Quaternion reciprocal -+template -+CUTLASS_HOST_DEVICE -+Quaternion reciprocal(Quaternion const &q) { -+ -+ Element nsq = norm(q); -+ -+ return make_Quaternion( -+ -q.x() / nsq, -+ -q.y() / nsq, -+ -q.z() / nsq, -+ q.w() / nsq -+ ); -+} -+ -+/// Returns a unit-length quaternion -+template -+CUTLASS_HOST_DEVICE -+Quaternion unit(Quaternion const &q) { -+ -+ Element rcp_mag = Element(1) / abs(q); -+ -+ return make_Quaternion( -+ q.x() * rcp_mag, -+ q.y() * rcp_mag, -+ q.z() * rcp_mag, -+ q.w() * rcp_mag -+ ); -+} -+ -+/// Quaternion exponential -+template -+CUTLASS_HOST_DEVICE -+Quaternion exp(Quaternion const &q) { -+ -+ Element exp_ = fast_exp(q.w()); -+ Element imag_norm = fast_sqrt(q.x() * q.x() + q.y() * q.y() + q.z() * q.z()); -+ Element sin_norm = fast_sin(imag_norm); -+ -+ return make_Quaternion( -+ exp_ * q.x() * sin_norm / imag_norm, -+ exp_ * q.y() * sin_norm / imag_norm, -+ exp_ * q.z() * sin_norm / imag_norm, -+ exp_ * fast_cos(imag_norm) -+ ); -+} -+ -+/// Quaternion natural logarithm -+template -+CUTLASS_HOST_DEVICE -+Quaternion log(Quaternion const &q) { -+ -+ Element v = fast_sqrt(q.x() * q.x() + q.y() * q.y() + q.z() * q.z()); -+ Element s = fast_acos(q.w() / abs(q)) / v; -+ -+ return make_Quaternion( -+ q.x() * s, -+ q.y() * s, -+ q.z() * s, -+ fast_log(q.w()) -+ ); -+} -+ 
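A brief host-side sketch of how the pieces above fit together, assuming Quaternion<float> and the three-element Matrix3x1 constructor from cutlass/matrix.h (both assumptions, not taken from this hunk); it builds a rotation quaternion from an axis and angle, applies it to a vector, and checks that for a unit quaternion reciprocal() collapses to conj(), which is the property the spinor helpers later in this header exploit:

#include "cutlass/matrix.h"
#include "cutlass/quaternion.h"

// Rotate the x-axis 90 degrees about z; the result is approximately the y-axis.
cutlass::Matrix3x1<float> rotate_x_about_z() {
  cutlass::Matrix3x1<float> axis(0.0f, 0.0f, 1.0f);   // unit-length rotation axis
  float theta = 1.57079632679f;                       // pi/2 radians
  cutlass::Quaternion<float> q =
      cutlass::Quaternion<float>::rotation(axis, theta);  // sin(t/2)*axis + cos(t/2)
  cutlass::Matrix3x1<float> v(1.0f, 0.0f, 0.0f);
  return q.rotate(v);                                 // q * (v, 0) * q^-1, pure part
}

// For unit-length q, norm(q) == 1, so reciprocal(q) reduces to conj(q);
// the tolerance below is arbitrary.
bool unit_reciprocal_equals_conj(cutlass::Quaternion<float> const &q_any) {
  cutlass::Quaternion<float> q = cutlass::unit(q_any);
  return cutlass::norm(cutlass::reciprocal(q) - cutlass::conj(q)) < 1e-6f;
}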
-+/// Gets the rotation angle from a unit-length quaternion -+template -+CUTLASS_HOST_DEVICE -+Element get_rotation_angle(Quaternion const &q_unit) { -+ return fast_acos(q_unit.w()) * Element(2); -+} -+ -+/// Gets the rotation axis from a unit-length quaternion -+template -+CUTLASS_HOST_DEVICE -+Matrix3x1 get_rotation_axis(Quaternion const &q_unit) { -+ return q_unit.pure().unit(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Equality operator -+template -+CUTLASS_HOST_DEVICE -+bool operator==(Quaternion const &lhs, Quaternion const &rhs) { -+ return lhs.x() == rhs.x() && -+ lhs.y() == rhs.y() && -+ lhs.z() == rhs.z() && -+ lhs.w() == rhs.w(); -+} -+ -+/// Inequality operator -+template -+CUTLASS_HOST_DEVICE -+bool operator!=(Quaternion const &lhs, Quaternion const &rhs) { -+ return !(lhs == rhs); -+} -+ -+/// Quaternion scalar multiplication -+template -+CUTLASS_HOST_DEVICE -+Quaternion operator*(Quaternion q, Element s) { -+ return make_Quaternion( -+ q.x() * s, -+ q.y() * s, -+ q.z() * s, -+ q.w() * s -+ ); -+} -+ -+/// Quaternion scalar multiplication -+template -+CUTLASS_HOST_DEVICE -+Quaternion operator*(Element s, Quaternion const &q) { -+ return make_Quaternion( -+ s * q.x(), -+ s * q.y(), -+ s * q.z(), -+ s * q.w() -+ ); -+} -+ -+/// Quaternion scalar division -+template -+CUTLASS_HOST_DEVICE -+Quaternion operator/(Quaternion const &q, Element s) { -+ return make_Quaternion( -+ q.x() / s, -+ q.y() / s, -+ q.z() / s, -+ q.w() / s -+ ); -+} -+ -+/// Quaternion unary negation -+template -+CUTLASS_HOST_DEVICE -+Quaternion operator-(Quaternion const &q) { -+ return make_Quaternion( -+ -q.x(), -+ -q.y(), -+ -q.z(), -+ -q.w() -+ ); -+} -+ -+/// Quaternion addition -+template -+CUTLASS_HOST_DEVICE -+Quaternion operator+(Quaternion const &lhs, Quaternion const &rhs) { -+ return make_Quaternion( -+ lhs.x() + rhs.x(), -+ lhs.y() + rhs.y(), -+ lhs.z() + rhs.z(), -+ lhs.w() + rhs.w() -+ ); -+} -+ -+/// Quaternion subtraction -+template -+CUTLASS_HOST_DEVICE -+Quaternion operator-(Quaternion const &lhs, Quaternion const &rhs) { -+ return make_Quaternion( -+ lhs.x() - rhs.x(), -+ lhs.y() - rhs.y(), -+ lhs.z() - rhs.z(), -+ lhs.w() - rhs.w() -+ ); -+} -+ -+/// Quaternion product -+template -+CUTLASS_HOST_DEVICE -+Quaternion operator*(Quaternion const &lhs, Quaternion const &rhs) { -+ return make_Quaternion( -+ lhs.w() * rhs.x() + rhs.w() * lhs.x() + lhs.y() * rhs.z() - lhs.z() * rhs.y(), -+ lhs.w() * rhs.y() + rhs.w() * lhs.y() + lhs.z() * rhs.x() - lhs.x() * rhs.z(), -+ lhs.w() * rhs.z() + rhs.w() * lhs.z() + lhs.x() * rhs.y() - lhs.y() * rhs.x(), -+ lhs.w() * rhs.w() - lhs.x() * rhs.x() - lhs.y() * rhs.y() - lhs.z() * rhs.z() -+ ); -+} -+ -+/// Quaternion division -+template -+CUTLASS_HOST_DEVICE -+Quaternion operator/(Quaternion const &lhs, Quaternion const &rhs) { -+ return lhs * reciprocal(rhs); -+} -+ -+/// Quaternion scalar division -+template -+CUTLASS_HOST_DEVICE -+Quaternion operator/(Element s, Quaternion const &q) { -+ return s * reciprocal(q); -+} -+ -+/// Comparison -+template -+CUTLASS_HOST_DEVICE -+bool operator<(Quaternion const &lhs, Quaternion const &rhs) { -+ //TODO -+ return true; -+} -+ -+/// Rotates a 3-vector assuming this is a unit quaternion (a spinor). This avoids computing -+/// a reciprocal. 
-+template -+CUTLASS_HOST_DEVICE -+Matrix3x1 spinor_rotation( -+ Quaternion const &spinor, /// unit-length quaternion -+ Matrix3x1 const &rhs) { /// arbitrary 3-vector -+ -+ return (spinor * Quaternion(rhs, 0) * conj(spinor)).pure(); -+} -+ -+/// Inverse rotation of 3-vector assuming this is a unit quaternion (a spinor). This avoids computing -+/// a reciprocal. -+template -+CUTLASS_HOST_DEVICE -+Matrix3x1 spinor_rotation_inv( -+ Quaternion const &spinor, /// unit-length quaternion -+ Matrix3x1 const &rhs) { /// arbitrary 3-vector -+ -+ return (conj(spinor) * Quaternion(rhs, 0) * spinor).pure(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Quaternion-valued type. -+template -+struct RealType< Quaternion > { -+ using Type = T; -+ -+ /// Number of elements -+ static int const kExtent = Quaternion::kExtent; -+ -+CUTLASS_HOST_DEVICE -+ static Quaternion from_real(double x) { -+ return Quaternion(static_cast(x)); -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// Factories -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+CUTLASS_HOST_DEVICE -+cutlass::Quaternion from_real >(double r) { -+ return cutlass::Quaternion(half_t(r)); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+cutlass::Quaternion from_real >(double r) { -+ return cutlass::Quaternion(float(r)); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+cutlass::Quaternion from_real >(double r) { -+ return cutlass::Quaternion(r); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// functional.h numeric specializations -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct multiplies> { -+ CUTLASS_HOST_DEVICE -+ Quaternion operator()(Quaternion lhs, Quaternion const &rhs) const { -+ lhs = lhs * rhs; -+ return lhs; -+ } -+}; -+ -+/// Squares with optional conversion -+template -+struct magnitude_squared, Output> { -+ CUTLASS_HOST_DEVICE -+ Output operator()(Quaternion lhs) const { -+ multiplies mul_op; -+ -+ Output y_w = Output(lhs.w()); -+ Output y_x = Output(lhs.x()); -+ Output y_y = Output(lhs.y()); -+ Output y_z = Output(lhs.z()); -+ -+ return mul_op(y_w, y_w) + mul_op(y_x, y_x) + mul_op(y_y, y_y) + \ -+ mul_op(y_z, y_z); -+ } -+}; -+ -+template -+struct multiply_add, Quaternion, Quaternion> { -+ CUTLASS_HOST_DEVICE -+ Quaternion operator()( -+ Quaternion const &a, -+ Quaternion const &b, -+ Quaternion const &c) const { -+ -+ T x = c.x(); -+ T y = c.y(); -+ T z = c.z(); -+ T w = c.w(); -+ -+ x += a.w() * b.x(); -+ x += b.w() * a.x(); -+ x += a.y() * b.z(); -+ x += -a.z() * b.y(), -+ -+ y += a.w() * b.y(); -+ y += b.w() * a.y(); -+ y += a.z() * b.x(); -+ y += -a.x() * b.z(); -+ -+ z += a.w() * b.z(); -+ z += b.w() * a.z(); -+ z += a.x() * b.y(); -+ z += -a.y() * b.x(); -+ -+ w += a.w() * b.w(); -+ w += -a.x() * b.x(); -+ w += -a.y() * b.y(); -+ w += -a.z() * b.z(); -+ -+ return cutlass::make_Quaternion(x, y, z, w); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git 
a/3rdparty/cutlass/include/cutlass/real.h b/3rdparty/cutlass/include/cutlass/real.h -new file mode 100644 -index 0000000..ed9018a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/real.h -@@ -0,0 +1,61 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/** -+ \file -+ \brief This class provides helpers to support real<> and complex<> types in generic code. -+*/ -+ -+#pragma once -+ -+namespace cutlass { -+ -+/// Used to determine the real-valued underlying type of a numeric type T. -+template -+struct RealType { -+ using Type = T; -+ -+ /// Number of elements -+ static int const kExtent = 1; -+ -+CUTLASS_HOST_DEVICE -+ static T from_real(double x) { -+ return static_cast(x); -+ } -+}; -+ -+template -+CUTLASS_HOST_DEVICE -+static T from_real(double r) { -+ return T(r); -+} -+ -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/reduction/device/reduce_split_k.h b/3rdparty/cutlass/include/cutlass/reduction/device/reduce_split_k.h -new file mode 100644 -index 0000000..92e1f61 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/reduction/device/reduce_split_k.h -@@ -0,0 +1,223 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Kernel performing a reduction over densely packed tensors in global memory -+*/ -+ -+#pragma once -+ -+#include "cutlass/device_kernel.h" -+#include "cutlass/reduction/kernel/reduce_split_k.h" -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reduction { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ReductionKernel_ -+> -+class ReduceSplitK { -+public: -+ using ReductionKernel = ReductionKernel_; -+ -+ using Shape = typename ReductionKernel::Shape; -+ using ReductionOp = typename ReductionKernel::ReductionOp; -+ using OutputOp = typename ReductionKernel::OutputOp; -+ -+ using ElementWorkspace = typename ReductionKernel::ElementWorkspace; -+ using ElementAccumulator = typename ReductionKernel::ElementAccumulator; -+ using ElementOutput = typename ReductionKernel::ElementOutput; -+ -+ using WorkspaceTensorRef = typename ReductionKernel::WorkspaceTensorRef; -+ using OutputTensorRef = typename ReductionKernel::OutputTensorRef; -+ -+ using StrideIndex = typename ReductionKernel::StrideIndex; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ MatrixCoord problem_size; -+ int partitions; -+ size_t partition_stride; -+ WorkspaceTensorRef workspace; -+ OutputTensorRef destination; -+ OutputTensorRef source; -+ typename OutputOp::Params output; -+ typename ReductionOp::Params reduction; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() : -+ problem_size(0, 0), -+ partitions(1), -+ partition_stride(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ MatrixCoord const & problem_size -+ ): -+ problem_size(problem_size) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ MatrixCoord problem_size_, -+ int partitions_, -+ size_t partition_stride_, -+ WorkspaceTensorRef workspace_, -+ OutputTensorRef destination_, -+ OutputTensorRef source_, -+ typename OutputOp::Params output_ = typename OutputOp::Params(), -+ typename ReductionOp::Params reduction_ = typename ReductionOp::Params() -+ ): -+ 
problem_size(problem_size_), -+ partitions(partitions_), -+ partition_stride(partition_stride_), -+ workspace(workspace_), -+ destination(destination_), -+ source(source_), -+ output(output_), -+ reduction(reduction_) -+ { -+ -+ } -+ -+ }; -+ -+private: -+ /// Kernel parameters object -+ typename ReductionKernel::Params params_; -+ -+public: -+ /// Constructs Reduction SplitK -+ ReduceSplitK() { } -+ -+ /// Determines whether the ReduceSplitK can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ // needs no additional workspace -+ return 0; -+ } -+ -+ /// Initializes Reduction state from arguments. -+ Status initialize( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ // initialize the params structure from the arguments -+ params_ = typename ReductionKernel::Params( -+ args.problem_size, -+ args.partitions, -+ args.partition_stride, -+ args.workspace, -+ args.destination, -+ args.source, -+ args.output, -+ args.reduction -+ ); -+ -+ return Status::kSuccess; -+ -+ } -+ -+ /// Initializes Reduction kernel state from arguments. -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ // update the params structure from the arguments -+ params_.workspace.reset(args.workspace.non_const_ref().data()); -+ params_.destination.reset(args.destination.non_const_ref().data()); -+ params_.source.reset(args.source.non_const_ref().data()); -+ params_.output = args.output; -+ params_.reduction = args.reduction; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ // -+ // Launch reduction kernel -+ // -+ dim3 block = ReductionKernel::block_shape(); -+ dim3 grid = ReductionKernel::grid_shape(params_.problem_size); -+ -+ Kernel<<< grid, block, 0, stream >>>(params_); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace reduction -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/reduction/device/tensor_reduce.h b/3rdparty/cutlass/include/cutlass/reduction/device/tensor_reduce.h -new file mode 100644 -index 0000000..31d50f6 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/reduction/device/tensor_reduce.h -@@ -0,0 +1,264 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Kernel performing a reduction over one or more ranks of an affine tensor -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/reduction/device/tensor_reduce_affine_strided.h" -+#include "cutlass/reduction/device/tensor_reduce_affine_contiguous.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reduction { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tensor reduction operator on specific CUTLASS layouts over exactly one index -+template < -+ typename ElementOutput_, -+ typename ElementSource_, -+ typename Layout_, -+ typename ReductionOp_, -+ int VectorLength_ = 1, -+ typename ElementCompute_ = ElementOutput_ -+> -+struct TensorReduction { -+ -+ using ElementOutput = ElementOutput_; -+ using ElementSource = ElementSource_; -+ using Layout = Layout_; -+ using ReductionOp = ReductionOp_; -+ static int const kVectorLength = VectorLength_; -+ using ElementCompute = ElementCompute_; -+ -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ /// Reduction operator -+ using ReductionDeviceStridedOperator = TensorReductionAffineStrided< -+ 4, 3, ElementOutput, ElementSource, ReductionOp, kVectorLength, ElementCompute -+ >; -+ -+ using ReductionDeviceContiguousOperator = TensorReductionAffineContiguous< -+ 4, 3, ElementOutput, ElementSource, ReductionOp, kVectorLength, ElementCompute -+ >; -+ -+ // -+ // Data members -+ // -+ -+ ReductionDeviceStridedOperator reduction_strided; -+ ReductionDeviceContiguousOperator reduction_contiguous; -+ int reduction_index; -+ -+ // -+ // Methods -+ // -+ -+ /// -+ TensorReduction( -+ TensorCoord extent, -+ int reduction_index_ -+ ): -+ reduction_index(reduction_index_) { -+ -+ Coord<4> extent_affine; -+ 
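  // Annotation (not part of the upstream CUTLASS sources or of this patch): the switch
  // below permutes the 4-D extent so that the reduced mode lands in the inner strided
  // position expected by TensorReductionAffineStrided, while the last (contiguous) mode
  // stays in place. For example, with an NHWC extent (N, H, W, C) and reduction_index == 1,
  // extent_affine becomes (N, W, H, C): N and W form the outer index space, H is the inner
  // (reduced) space, and C remains contiguous for vectorized accesses. reduction_index == 3
  // reduces the contiguous mode itself, so the AffineContiguous operator is used instead.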
-+ switch (reduction_index) { -+ case 0: -+ extent_affine[0] = extent[1]; -+ extent_affine[1] = extent[2]; -+ extent_affine[2] = extent[0]; -+ extent_affine[3] = extent[3]; -+ break; -+ case 1: -+ extent_affine[0] = extent[0]; -+ extent_affine[1] = extent[2]; -+ extent_affine[2] = extent[1]; -+ extent_affine[3] = extent[3]; -+ break; -+ case 2: -+ extent_affine[0] = extent[0]; -+ extent_affine[1] = extent[1]; -+ extent_affine[2] = extent[2]; -+ extent_affine[3] = extent[3]; -+ break; -+ case 3: -+ extent_affine[0] = extent[0]; -+ extent_affine[1] = extent[1]; -+ extent_affine[2] = extent[2]; -+ extent_affine[3] = extent[3]; -+ break; -+ default: break; -+ } -+ -+ if (reduction_index == 3) { -+ reduction_contiguous = ReductionDeviceContiguousOperator(extent_affine); -+ } -+ else { -+ reduction_strided = ReductionDeviceStridedOperator(extent_affine); -+ } -+ } -+ -+ /// Simple check to verify the object is initialized correctly -+ bool good() const { -+ if (reduction_index == 3) { -+ return reduction_contiguous.good(); -+ } -+ return reduction_strided.good(); -+ } -+ -+ /// Size of one workspace -+ int64_t workspace_stride() const { -+ if (reduction_index == 3) { -+ return reduction_contiguous.workspace_stride(); -+ } -+ else { -+ return reduction_strided.workspace_stride(); -+ } -+ } -+ -+ /// Returns the size (in bytes) of a temporary workspace needed for reduction across CTAs -+ int64_t workspace_size() const { -+ if (reduction_index == 3) { -+ return reduction_contiguous.workspace_size(); -+ } -+ else { -+ return reduction_strided.workspace_size(); -+ } -+ } -+ -+ /// Helper to use overloaded function call operator -+ Status reduce( -+ TensorRef dst_ref, -+ TensorRef src_ref, -+ void *device_workspace_ptr = nullptr, -+ ElementCompute reduction_identity = ElementCompute(), -+ ReductionOp reduction_op = ReductionOp(), -+ cudaStream_t stream = nullptr) { -+ -+ int64_t src_stride[3]; -+ int64_t dst_stride[3]; -+ -+ switch (reduction_index) { -+ case 0: -+ src_stride[0] = src_ref.stride()[1]; -+ src_stride[1] = src_ref.stride()[0]; -+ src_stride[2] = src_ref.stride()[2]; -+ dst_stride[0] = dst_ref.stride()[1]; -+ dst_stride[1] = dst_ref.stride()[0]; -+ break; -+ case 1: -+ src_stride[0] = src_ref.stride()[2]; -+ src_stride[1] = src_ref.stride()[0]; -+ src_stride[2] = src_ref.stride()[1]; -+ dst_stride[0] = dst_ref.stride()[2]; -+ dst_stride[1] = dst_ref.stride()[0]; -+ break; -+ case 2: -+ src_stride[0] = src_ref.stride()[2]; -+ src_stride[1] = src_ref.stride()[1]; -+ src_stride[2] = src_ref.stride()[0]; -+ dst_stride[0] = dst_ref.stride()[2]; -+ dst_stride[1] = dst_ref.stride()[1]; -+ break; -+ case 3: -+ src_stride[0] = src_ref.stride()[2]; -+ src_stride[1] = src_ref.stride()[1]; -+ src_stride[2] = src_ref.stride()[0]; -+ -+ dst_stride[0] = dst_ref.stride()[2]; -+ dst_stride[1] = dst_ref.stride()[1]; -+ dst_stride[2] = dst_ref.stride()[0]; -+ -+ default: break; -+ } -+ -+ if (reduction_index == 3) { -+ return reduction_contiguous( -+ dst_ref.data(), -+ dst_stride, -+ src_ref.data(), -+ src_stride, -+ device_workspace_ptr, -+ reduction_identity, -+ reduction_op, -+ stream); -+ } -+ else { -+ return reduction_strided( -+ dst_ref.data(), -+ dst_stride, -+ src_ref.data(), -+ src_stride, -+ device_workspace_ptr, -+ reduction_identity, -+ reduction_op, -+ stream); -+ } -+ } -+ -+ Status operator()( -+ TensorRef dst_ref, -+ TensorRef src_ref, -+ void *device_workspace_ptr = nullptr, -+ ElementCompute reduction_identity = ElementCompute(), -+ ReductionOp reduction_op = ReductionOp(), -+ 
cudaStream_t stream = nullptr) { -+ -+ return reduce( -+ dst_ref, -+ src_ref, -+ device_workspace_ptr, -+ reduction_identity, -+ reduction_op, -+ stream); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reduction -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h b/3rdparty/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h -new file mode 100644 -index 0000000..234a1c4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h -@@ -0,0 +1,373 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Kernel performing a reduction over one or more ranks of an affine tensor -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reduction { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tensor reduction operator on layouts which are affine -+template < -+ int Rank, ///< Rank of source tensor (e.g. 
NDHWC => 5) -+ int ReducedRank, ///< Rank of reduced tensor (e.g. ND => 2) -+ typename ElementOutput_, -+ typename ElementSource_, -+ typename ReductionOp_, -+ int VectorLength = 1, -+ typename ElementCompute_ = ElementOutput_, -+ int Threads = 256, ///< Number of participating threads -+ int BatchSize = 4 ///< Number of elements to load per batch -+> -+struct TensorReductionAffineContiguous { -+ -+ static int const kRank = Rank; -+ static int const kReducedRank = ReducedRank; -+ static int const kVectorLength = VectorLength; -+ static int const kInnerRank = kRank - kReducedRank; -+ static int const kThreads = Threads; -+ static int const kBatchSize = BatchSize; -+ -+ using ElementOutput = ElementOutput_; -+ using ElementSource = ElementSource_; -+ using ReductionOp = ReductionOp_; -+ using ElementCompute = ElementCompute_; -+ -+ // -+ // Data members -+ // -+ -+ /// Internal status field -+ Status status; -+ -+ /// Extent of tensor in source layout -+ Coord extent; -+ -+ /// Number of points in the outer index space -+ int64_t outer_count; -+ -+ /// Number of elements in the inner index space -+ int64_t inner_count; -+ -+ /// Number of workspaces needed -+ int workspace_count; -+ -+ /// CUDA Grid shape (.x => contiguous, .y => outer, .z => inner) -+ dim3 grid_shape; -+ -+ /// CUDA Threadblock shape (.x => contiguous, .y => outer, .z => inner) -+ dim3 threadblock_shape; -+ -+ /// CUDA grid shape for the final reduction step if needed -+ dim3 grid_final; -+ -+ /// CUDA threadblock shape for the final reduction step if needed -+ dim3 threadblock_final; -+ -+private: -+ // -+ // Methods -+ // -+ -+ /// Helper to reshape 'count' such that it is less than 2 x 'ext' -+ static int reshape_pow2(int ext, int count) { -+ if (ext > count) { -+ return 1; -+ } -+ int x = 1; -+ for (; count >= ext * 2; ) { -+ count >>= 1; -+ x <<= 1; -+ } -+ return x; -+ } -+ -+public: -+ -+ /// Default ctor -+ TensorReductionAffineContiguous(): -+ status(Status::kErrorInvalidProblem), -+ extent(), -+ outer_count(0), -+ inner_count(0), -+ workspace_count(0), -+ grid_shape(0, 0, 0), -+ threadblock_shape(0, 0, 0) { } -+ -+ /// Constructor -+ TensorReductionAffineContiguous( -+ Coord extent_, -+ int target_threadblock_count = 128 -+ ): -+ status(Status::kSuccess), -+ extent(extent_), -+ outer_count(0), -+ inner_count(0), -+ workspace_count(0) { -+ -+ // -+ // Plan the parallel mapping strategy. -+ // -+ -+ outer_count = 1; -+ inner_count = 1; -+ -+ // Compute number of elements in strided ranks -+ for (int p = 0; p < kReducedRank; ++p) { -+ outer_count *= extent[p]; -+ } -+ -+ for (int p = 0; p < kInnerRank; ++p) { -+ inner_count *= extent[kReducedRank + p]; -+ } -+ -+ int cta_count_x = 1; -+ int cta_count_y = 1; -+ int cta_count_z = 1; -+ -+ int cta_threads_x = kThreads; -+ int cta_threads_y = 1; -+ int cta_threads_z = 1; -+ -+ // Determine CTA shape -+ int64_t inner_vector_count = inner_count / kVectorLength; -+ -+ // Priority 1. Assign threadblocks to outer indices if possible -+ if (outer_count > target_threadblock_count) { -+ cta_count_x = 1; -+ cta_count_y = target_threadblock_count; -+ cta_count_z = 1; -+ } -+ else { -+ -+ cta_count_y = int(outer_count); -+ int remaining_ctas = target_threadblock_count / cta_count_y; -+ -+ // Priority 2. 
Assign inner dimensions to one CTA -+ if (inner_vector_count > cta_threads_x) { -+ int64_t cta_z_bound = inner_vector_count / cta_threads_x; -+ if (cta_z_bound > remaining_ctas) { -+ cta_count_z = remaining_ctas; -+ } -+ else { -+ cta_count_z = int(cta_z_bound); -+ } -+ } -+ else { -+ cta_threads_x = reshape_pow2(int(inner_vector_count), cta_threads_x); -+ cta_count_z = 1; -+ } -+ } -+ -+ grid_shape = dim3(cta_count_x, cta_count_y, cta_count_z); -+ threadblock_shape = dim3(cta_threads_x, cta_threads_y, cta_threads_z); -+ -+ workspace_count = (cta_count_z > 1 ? cta_count_z : 0); -+ -+ // Determine shape of final reduction kernel if needed -+ if (workspace_count) { -+ -+ int final_threads = kThreads; -+ int final_ctas = 1; -+ -+ if (outer_count > kThreads) { -+ final_ctas = int(outer_count + kThreads - 1) / kThreads; -+ } -+ else { -+ final_threads = int(outer_count); -+ } -+ -+ grid_final = dim3(final_ctas, 1, 1); -+ threadblock_final = dim3(final_threads, 1, 1); -+ } -+ else { -+ grid_final = dim3(0, 0, 0); -+ threadblock_final = dim3(0, 0, 0); -+ } -+ } -+ -+ /// Simple check to verify the object is initialized correctly -+ bool good() const { -+ return status == Status::kSuccess; -+ } -+ -+ /// Size (in bytes) of workspace elements which are densely packed together -+ int64_t workspace_stride() const { -+ -+ // Error condition -+ if (!good()) { -+ return 0; -+ } -+ -+ return outer_count * sizeof_bits::value / 8; -+ } -+ -+ /// Returns the size (in bytes) of a temporary workspace needed for reduction across CTAs -+ int64_t workspace_size() const { -+ -+ // Error condition -+ if (!good()) { -+ return 0; -+ } -+ -+ // No reduction across CTAs -+ if (grid_shape.z == 1) { -+ return 0; -+ } -+ -+ return workspace_stride() * grid_shape.z; -+ } -+ -+ /// Performs a reduction -+ Status reduce( -+ ElementOutput *dst_ptr, ///< Pointer to destination tensor -+ int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1) -+ ElementSource const *src_ptr, ///< Pointer to source tensor -+ int64_t src_stride[], ///< Stride vector (of length kRank - 1) -+ void *device_workspace_ptr = nullptr, ///< Device workspace -+ ElementCompute reduction_identity = ElementCompute(), ///< Reduction identity element -+ ReductionOp reduction_op = ReductionOp(), ///< Reduction operator -+ cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched -+ -+ // Initial status check -+ if (!good()) { -+ return status; -+ } -+ -+ // Guard against null workspace -+ if (workspace_count > 1 && device_workspace_ptr == nullptr) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ // Define reduction kernel -+ using ReductionKernel = kernel::TensorReductionAffineContiguous< -+ kRank, -+ kReducedRank, -+ ElementOutput, -+ ElementSource, -+ ReductionOp, -+ kVectorLength, -+ ElementCompute, -+ kThreads>; -+ -+ using FinalReductionKernel = kernel::TensorReductionAffineContiguousFinal< -+ kRank, -+ kReducedRank, -+ ElementOutput, -+ ElementSource, -+ ReductionOp, -+ kVectorLength, -+ ElementCompute, -+ kThreads>; -+ -+ using Params = typename ReductionKernel::Params; -+ -+ // Construct the parameters -+ Params params( -+ extent, -+ dst_ptr, -+ dst_stride, -+ src_ptr, -+ src_stride, -+ static_cast(device_workspace_ptr), -+ workspace_stride(), -+ workspace_count, -+ reduction_op, -+ reduction_identity); -+ -+ // Shared memory size -+ int shared_mem_bytes = sizeof(typename ReductionKernel::SharedStorage); -+ -+ // Launch the kernel -+ Kernel<<< grid_shape, threadblock_shape, shared_mem_bytes, stream 
>>>(params); -+ -+ // Check error condition -+ if (cudaPeekAtLastError() == cudaSuccess) { -+ status = Status::kSuccess; -+ } -+ else { -+ status = Status::kErrorInternal; -+ } -+ -+ // Final reduction kernel -+ if (workspace_count) { -+ Kernel<<< grid_final, threadblock_final, 0, stream >>>(params); -+ } -+ -+ // Check error condition -+ if (cudaPeekAtLastError() == cudaSuccess) { -+ status = Status::kSuccess; -+ } -+ else { -+ status = Status::kErrorInternal; -+ } -+ -+ return status; -+ } -+ -+ /// Helper to use overloaded function call operator -+ Status operator()( -+ ElementOutput *dst_ptr, ///< Pointer to destination tensor -+ int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1) -+ ElementSource const *src_ptr, ///< Pointer to source tensor -+ int64_t src_stride[], ///< Stride vector (of length kRank - 1) -+ void *device_workspace_ptr = nullptr, ///< Pointer to device workspace -+ ElementCompute reduction_identity = ElementCompute(), ///< Reduction identity element -+ ReductionOp reduction_op = ReductionOp(), ///< Reduction operator -+ cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched -+ -+ return reduce(dst_ptr, dst_stride, src_ptr, src_stride, device_workspace_ptr, reduction_identity, reduction_op, stream); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reduction -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_strided.h b/3rdparty/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_strided.h -new file mode 100644 -index 0000000..e613934 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_strided.h -@@ -0,0 +1,361 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Kernel performing a reduction over one or more ranks of an affine tensor -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/reduction/kernel/tensor_reduce_affine_strided.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reduction { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tensor reduction operator on layouts which are affine -+template < -+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) -+ int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) -+ typename ElementOutput_, -+ typename ElementSource_, -+ typename ReductionOp_, -+ int VectorLength = 1, -+ typename ElementCompute_ = ElementOutput_, -+ int Threads = 256, ///< Number of participating threads -+ int BatchSize = 4 ///< Number of elements to load per batch -+> -+struct TensorReductionAffineStrided { -+ -+ static int const kRank = Rank; -+ static int const kReducedRank = ReducedRank; -+ static int const kVectorLength = VectorLength; -+ static int const kInnerRank = kRank - kReducedRank; -+ static int const kThreads = Threads; -+ static int const kBatchSize = BatchSize; -+ -+ using ElementOutput = ElementOutput_; -+ using ElementSource = ElementSource_; -+ using ReductionOp = ReductionOp_; -+ using ElementCompute = ElementCompute_; -+ -+ // -+ // Data members -+ // -+ -+ /// Internal status field -+ Status status; -+ -+ /// Extent of tensor in source layout -+ Coord extent; -+ -+ /// Number of points in the outer index space -+ int64_t outer_count; -+ -+ /// Number of elements in the inner index space -+ int64_t inner_count; -+ -+ /// Number of workspaces needed -+ int workspace_count; -+ -+ /// CUDA Grid shape (.x => contiguous, .y => outer, .z => inner) -+ dim3 grid_shape; -+ -+ /// CUDA Threadblock shape (.x => contiguous, .y => outer, .z => inner) -+ dim3 threadblock_shape; -+ -+ /// CUDA grid shape for the final reduction step if needed -+ dim3 grid_final; -+ -+ /// CUDA threadblock shape for the final reduction step if needed -+ dim3 threadblock_final; -+ -+private: -+ // -+ // Methods -+ // -+ -+ /// Helper to reshape 'count' such that it is less than 2 x 'ext' -+ static int reshape_pow2(int ext, int count) { -+ if (ext > count) { -+ return 1; -+ } -+ int x = 1; -+ for (; count >= ext * 2; ) { -+ count >>= 1; -+ x <<= 1; -+ } -+ return x; -+ } -+ -+public: -+ -+ /// Default ctor -+ TensorReductionAffineStrided(): -+ status(Status::kErrorInvalidProblem), -+ extent(), -+ outer_count(0), -+ inner_count(0), -+ workspace_count(0), -+ grid_shape(0, 0, 
0), -+ threadblock_shape(0, 0, 0) { } -+ -+ /// Constructor -+ TensorReductionAffineStrided( -+ Coord extent_, -+ int target_threadblock_count = 128 -+ ): -+ status(Status::kSuccess), -+ extent(extent_), -+ outer_count(0), -+ inner_count(0), -+ workspace_count(0) { -+ -+ // -+ // Plan the parallel mapping strategy. -+ // -+ -+ outer_count = 1; -+ inner_count = 1; -+ -+ // Compute number of elements in strided ranks -+ for (int p = 0; p < kReducedRank - 1; ++p) { -+ outer_count *= extent[p]; -+ } -+ -+ for (int p = 0; p < kInnerRank; ++p) { -+ inner_count *= extent[kReducedRank + p - 1]; -+ } -+ -+ // Compute plan for the reduction -+ int extent_c = extent[kRank - 1]; -+ int vectors_c = (extent_c -1 + kVectorLength) / kVectorLength; -+ -+ // Determine CTA shape -+ int cta_width = kThreads * kVectorLength; -+ int cta_ways = reshape_pow2(extent_c, cta_width); -+ int cta_threads_x = kThreads / cta_ways; -+ -+ threadblock_shape = dim3(cta_threads_x, 1, std::min(cta_ways, 64)); -+ -+ // This leads to an error. -+ if (threadblock_shape.z > 1) { -+ if (threadblock_shape.y != 1) { -+ status = Status::kErrorInternal; -+ return; -+ } -+ } -+ -+ // Determine grid shape -+ int cta_count_x = (vectors_c + cta_threads_x - 1) / cta_threads_x; -+ int cta_count_y = std::max(1, target_threadblock_count / cta_count_x); -+ -+ // Limit the number of CTAs assigned to outer dimension -+ if (int64_t(cta_count_y * threadblock_shape.y) > outer_count) { -+ cta_count_y = int(outer_count + threadblock_shape.y - 1) / threadblock_shape.y; -+ } -+ -+ // Limit the number of CTAs assigned to inner dimension -+ int cta_count_z = std::max(1, target_threadblock_count / cta_count_y); -+ if (int64_t(cta_count_z * threadblock_shape.z) > inner_count) { -+ cta_count_z = int(inner_count + threadblock_shape.z - 1) / threadblock_shape.z; -+ } -+ -+ grid_shape = dim3(cta_count_x, cta_count_y, cta_count_z); -+ workspace_count = (cta_count_z > 1 ? 
cta_count_z : 0); -+ -+ // Determine shape of final reduction kernel if needed -+ grid_final = dim3(cta_count_x, int(outer_count)); -+ threadblock_final = dim3(cta_threads_x, 1, 1); -+ } -+ -+ /// Simple check to verify the object is initialized correctly -+ bool good() const { -+ return status == Status::kSuccess; -+ } -+ -+ /// Size of one CTA's workspace -+ int64_t workspace_stride() const { -+ -+ // Error condition -+ if (!good()) { -+ return 0; -+ } -+ -+ int vector_size_bytes = kVectorLength * sizeof_bits::value / 8; -+ -+ return extent[kRank - 1] * vector_size_bytes; -+ } -+ -+ /// Returns the size (in bytes) of a temporary workspace needed for reduction across CTAs -+ int64_t workspace_size() const { -+ -+ // Error condition -+ if (!good()) { -+ return 0; -+ } -+ -+ // No reduction across CTAs -+ if (grid_shape.z == 1) { -+ return 0; -+ } -+ -+ return workspace_stride() * outer_count * grid_shape.z; -+ } -+ -+ /// Performs a reduction -+ Status reduce( -+ ElementOutput *dst_ptr, ///< Pointer to destination tensor -+ int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1) -+ ElementSource const *src_ptr, ///< Pointer to source tensor -+ int64_t src_stride[], ///< Stride vector (of length kRank - 1) -+ void *device_workspace_ptr = nullptr, ///< Device workspace -+ ElementCompute reduction_identity = ElementCompute(), ///< Reduciton identity -+ ReductionOp reduction_op = ReductionOp(), ///< Reduction operator -+ cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched -+ -+ // Initial status check -+ if (!good()) { -+ return status; -+ } -+ -+ // Guard against null workspace -+ if (workspace_count > 1 && device_workspace_ptr == nullptr) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ // Define reduction kernel -+ using ReductionKernel = kernel::TensorReductionAffineStrided< -+ kRank, -+ kReducedRank, -+ ElementOutput, -+ ElementSource, -+ ReductionOp, -+ kVectorLength, -+ ElementCompute, -+ kThreads>; -+ -+ using FinalReductionKernel = kernel::TensorReductionAffineStridedFinal< -+ kRank, -+ kReducedRank, -+ ElementOutput, -+ ElementSource, -+ ReductionOp, -+ kVectorLength, -+ ElementCompute, -+ kThreads>; -+ -+ using Params = typename ReductionKernel::Params; -+ -+ // Construct the parameters -+ Params params( -+ extent, -+ dst_ptr, -+ dst_stride, -+ src_ptr, -+ src_stride, -+ static_cast(device_workspace_ptr), -+ workspace_stride(), -+ workspace_count, -+ reduction_op, -+ reduction_identity); -+ -+ // Shared memory size -+ int shared_mem_bytes = sizeof(typename ReductionKernel::SharedStorage); -+ -+ // Launch the kernel -+ Kernel<<< grid_shape, threadblock_shape, shared_mem_bytes, stream >>>(params); -+ -+ // Check error condition -+ if (cudaPeekAtLastError() == cudaSuccess) { -+ status = Status::kSuccess; -+ } -+ else { -+ status = Status::kErrorInternal; -+ } -+ -+ // Final reduction kernel -+ if (workspace_count) { -+ -+ Kernel<<< grid_final, threadblock_final, 0, stream >>>(params); -+ -+ // Check error condition -+ if (cudaPeekAtLastError() == cudaSuccess) { -+ status = Status::kSuccess; -+ } -+ else { -+ status = Status::kErrorInternal; -+ } -+ } -+ -+ return status; -+ } -+ -+ /// Helper to use overloaded function call operator -+ Status operator()( -+ ElementOutput *dst_ptr, ///< Pointer to destination tensor -+ int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1) -+ ElementSource const *src_ptr, ///< Pointer to source tensor -+ int64_t src_stride[], ///< Stride vector (of length kRank - 1) -+ void 
*device_workspace_ptr = nullptr, ///< Pointer to device workspace -+ ElementCompute reduction_identity = ElementCompute(), ///< Reduciton identity -+ ReductionOp reduction_op = ReductionOp(), ///< Reduction operator -+ cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched -+ -+ return reduce( -+ dst_ptr, -+ dst_stride, -+ src_ptr, -+ src_stride, -+ device_workspace_ptr, -+ reduction_identity, -+ reduction_op, -+ stream); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reduction -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/reduction/kernel/reduce_softmax_final.h b/3rdparty/cutlass/include/cutlass/reduction/kernel/reduce_softmax_final.h -new file mode 100644 -index 0000000..99e8aed ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/reduction/kernel/reduce_softmax_final.h -@@ -0,0 +1,267 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Kernel performing a final reduction for softmax -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/arch/memory_sm75.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reduction { -+namespace kernel { -+ -+template < -+ typename ElementNorm_, -+ typename ElementSum_, -+ typename ElementSoftmaxCompute_, -+ typename ThreadblockShape_, -+ bool GroupedProblem = false -+> -+class ApplySoftmaxFinalReduction { -+public: -+ -+ using ElementNorm = ElementNorm_; -+ using ElementSum = ElementSum_; -+ using ElementSoftmaxCompute = ElementSoftmaxCompute_; -+ using ThreadblockShape = ThreadblockShape_; -+ static const bool isGroupedProblem = GroupedProblem; -+ -+ // -+ // Arguments -+ // -+ -+ struct Arguments { -+ -+ cutlass::gemm::GemmCoord* problem_sizes; -+ cutlass::gemm::GemmCoord problem_size; -+ ElementNorm* block_Norm; -+ ElementSum* block_Sum; -+ int64_t* offset_Norm_Device; -+ int64_t* offset_Sum_Device; -+ int64_t batch_stride_Max; -+ int64_t batch_stride_Sum; -+ -+ // -+ // Methods -+ // -+ Arguments() { } -+ -+ // Non-grouped constructor without batching -+ Arguments( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementNorm* block_Norm, -+ ElementSum* block_Sum -+ ): -+ problem_size(problem_size), -+ block_Norm(block_Norm), -+ block_Sum(block_Sum), -+ problem_sizes(nullptr), -+ offset_Norm_Device(nullptr), -+ offset_Sum_Device(nullptr), -+ batch_stride_Max(0), -+ batch_stride_Sum(0) -+ { -+ -+ } -+ -+ // Non-grouped constructor with batching -+ Arguments( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementNorm* block_Norm, -+ ElementSum* block_Sum, -+ int64_t batch_stride_Max, -+ int64_t batch_stride_Sum -+ ): -+ problem_size(problem_size), -+ block_Norm(block_Norm), -+ block_Sum(block_Sum), -+ batch_stride_Max(batch_stride_Max), -+ batch_stride_Sum(batch_stride_Sum), -+ problem_sizes(nullptr), -+ offset_Norm_Device(nullptr), -+ offset_Sum_Device(nullptr) -+ { -+ -+ } -+ -+ -+ // Grouped constructor -+ Arguments( -+ cutlass::gemm::GemmCoord *problem_sizes, -+ ElementNorm* block_Norm, -+ ElementSum* block_Sum, -+ int64_t* offset_Norm_Device, -+ int64_t* offset_Sum_Device -+ ): -+ problem_sizes(problem_sizes), -+ problem_size(cutlass::gemm::GemmCoord(0, 0, 0)), -+ block_Norm(block_Norm), -+ block_Sum(block_Sum), -+ offset_Norm_Device(offset_Norm_Device), -+ offset_Sum_Device(offset_Sum_Device) -+ { -+ -+ } -+ }; -+ -+ struct SharedStorage { -+ -+ -+ }; -+ -+ // -+ // Params struct -+ // -+ -+ struct Params { -+ Arguments args; -+ -+ // -+ // Methods -+ // -+ Params() { } -+ -+ Params(Arguments const &args_): args(args_) { } -+ }; -+ -+private: -+ -+public: -+ -+ CUTLASS_DEVICE -+ ApplySoftmaxFinalReduction() { } -+ -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ apply(params, shared_storage); -+ } -+ -+private: -+ -+ /// Full reduction -+ CUTLASS_DEVICE -+ void apply(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ int tid = threadIdx.x; -+ int bid = blockIdx.x; -+ int bdim = blockDim.x; -+ -+ int block_batch = blockIdx.z; -+ -+ // defining three vars for a general reduction module -+ cutlass::gemm::GemmCoord problem_size = isGroupedProblem ? 
params.args.problem_sizes[bid] : params.args.problem_size; -+ int m_dim_in_loop = isGroupedProblem ? problem_size.m() : tid + bdim; -+ int access_offset = isGroupedProblem ? 0 : bid * bdim; -+ -+ if (!isGroupedProblem && access_offset + tid >= problem_size.m()) return; -+ -+ ElementNorm *curr_ptr_Max = isGroupedProblem ? \ -+ params.args.block_Norm + params.args.offset_Norm_Device[bid] : \ -+ params.args.block_Norm + block_batch * params.args.batch_stride_Max; -+ ElementSum *curr_ptr_Sum = isGroupedProblem ? \ -+ params.args.block_Sum + params.args.offset_Sum_Device[bid] : \ -+ params.args.block_Sum + block_batch * params.args.batch_stride_Sum; -+ -+ int threadblock_num = (problem_size.n() + ThreadblockShape::kN - 1) / ThreadblockShape::kN; -+ -+ using ConvertSumOutput = cutlass::NumericConverter; -+ using ConvertNormOutput = cutlass::NumericConverter; -+ -+ using ConvertSum = cutlass::NumericConverter; -+ using ConvertNorm = cutlass::NumericConverter; -+ -+ ConvertSum convert_sum; -+ ConvertNorm convert_norm; -+ -+ ConvertSumOutput convert_sum_output; -+ ConvertNormOutput convert_norm_output; -+ -+ uint32_t float_max_bits = 0xff7fffff; -+ float min_float = reinterpret_cast(float_max_bits); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx_m = tid; idx_m < m_dim_in_loop; idx_m += bdim) { -+ ElementNorm *access_n = curr_ptr_Max + idx_m + access_offset; -+ ElementSum *access_s = curr_ptr_Sum + idx_m + access_offset; -+ ElementNorm *access_n_bak = access_n; -+ ElementSum *access_s_bak = access_s; -+ ElementSoftmaxCompute max_val = ElementSoftmaxCompute(min_float); -+ ElementSoftmaxCompute sum_val = ElementSoftmaxCompute(0); -+ ElementNorm fetch_n; -+ ElementSum fetch_s; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx_n = 0; idx_n < threadblock_num; idx_n++) { -+ cutlass::arch::global_load(fetch_n, access_n, true); -+ max_val = cutlass::fast_max(max_val, convert_norm(fetch_n)); -+ access_n += problem_size.m(); -+ } -+ -+ access_n = access_n_bak; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx_n = 0; idx_n < threadblock_num; idx_n++) { -+ cutlass::arch::global_load(fetch_n, access_n, true); -+ cutlass::arch::global_load(fetch_s, access_s, true); -+ sum_val += convert_sum(fetch_s) * cutlass::fast_exp(convert_norm(fetch_n) - max_val); -+ access_n += problem_size.m(); -+ access_s += problem_size.m(); -+ } -+ -+ ElementSoftmaxCompute inv_sum = cutlass::constants::one() / sum_val; -+ -+ access_n = access_n_bak; -+ access_s = access_s_bak; -+ -+ access_n[0] = convert_norm_output(max_val); -+ access_s[0] = convert_sum_output(inv_sum); -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace reduction -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/reduction/kernel/reduce_split_k.h b/3rdparty/cutlass/include/cutlass/reduction/kernel/reduce_split_k.h -new file mode 100644 -index 0000000..96847e7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/reduction/kernel/reduce_split_k.h -@@ -0,0 +1,248 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Kernel performing a reduction over densely packed tensors in global memory -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/layout/matrix.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reduction { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, ///< shape of CTA (concept: MatrixShape) -+ typename OutputOp_ , ///< output operator (concept: epilogue::thread operator) -+ typename ReductionOp_, ///< reduction operator (concept: ReductionOperator) -+ int PartitionsPerStage = 4 ///< number of partitions to issue -+> -+class ReduceSplitK { -+public: -+ -+ using Shape = Shape_; -+ using ReductionOp = ReductionOp_; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = OutputOp::kCount; -+ static int const kPartitionsPerStage = PartitionsPerStage; -+ -+ using ElementWorkspace = typename ReductionOp::Element; -+ using ElementAccumulator = typename ReductionOp::ElementAccumulator; -+ using ElementOutput = typename OutputOp::ElementOutput; -+ -+ using WorkspaceTensorRef = TensorRef; -+ using OutputTensorRef = TensorRef; -+ using StrideIndex = typename WorkspaceTensorRef::Layout::Stride::Index; -+ -+ using FragmentWorkspace = AlignedArray; -+ using FragmentAccumulator = Array; -+ using FragmentOutput = AlignedArray; -+ -+ // -+ // Types -+ // -+ -+ /// Params structure -+ struct Params { -+ -+ MatrixCoord problem_size; -+ int partitions; -+ size_t partition_stride; -+ WorkspaceTensorRef workspace; -+ OutputTensorRef destination; -+ OutputTensorRef source; -+ typename OutputOp::Params output; -+ typename ReductionOp::Params reduction; -+ -+ // -+ // Methods -+ // -+ 
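  // Annotation (not part of the upstream CUTLASS sources or of this patch): the
  // parameterized constructor below rescales partition_stride_ from a count of elements
  // to a count of bytes (sizeof(FragmentWorkspace) * partition_stride_ / kElementsPerAccess),
  // so that operator() can step a raw char pointer from one split-K partial to the next.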
-+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ MatrixCoord problem_size_, -+ int partitions_, -+ size_t partition_stride_, -+ WorkspaceTensorRef workspace_, -+ OutputTensorRef destination_, -+ OutputTensorRef source_, -+ typename OutputOp::Params output_ = typename OutputOp::Params(), -+ typename ReductionOp::Params reduction_ = typename ReductionOp::Params() -+ ): -+ problem_size(problem_size_), -+ partitions(partitions_), -+ partition_stride(sizeof(FragmentWorkspace) * partition_stride_ / kElementsPerAccess), -+ workspace(workspace_), -+ destination(destination_), -+ source(source_), -+ output(output_), -+ reduction(reduction_) { -+ -+ } -+ }; -+ -+ struct SharedStorage { }; -+ -+ -+public: -+ -+ /// Computes the grid size given a chosen threadblock shape -+ CUTLASS_HOST_DEVICE -+ static dim3 grid_shape( -+ cutlass::MatrixCoord problem_size) { -+ -+ return dim3( -+ (problem_size.row() + Shape::kRow - 1) / Shape::kRow, -+ (problem_size.column() + Shape::kColumn - 1) / Shape::kColumn); -+ } -+ -+ /// Determines the threadblock shape -+ CUTLASS_HOST_DEVICE -+ static dim3 block_shape() { -+ return dim3(Shape::kColumn / kElementsPerAccess, Shape::kRow); -+ } -+ -+ /// Perform a reduction -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &storage) { -+ -+ // Determine CTA position -+ MatrixCoord thread_offset( -+ MatrixCoord::Index(int(blockIdx.x) * Shape::kRow + threadIdx.y), -+ MatrixCoord::Index(int(blockIdx.y) * Shape::kColumn + threadIdx.x * kElementsPerAccess) -+ ); -+ -+ // One guard conditional -+ if (!(thread_offset.row() < params.problem_size.row() && -+ thread_offset.column() < params.problem_size.column())) { -+ -+ return; -+ } -+ -+ -+ ReductionOp reduction_op(params.reduction); -+ -+ FragmentAccumulator accumulator; -+ -+ accumulator.clear(); -+ -+ // -+ // Load the first slice -+ // -+ -+ char const *workspace_ptr = -+ reinterpret_cast( -+ params.workspace.data() + params.workspace.offset(thread_offset)); -+ -+ FragmentWorkspace workspace_frag[kPartitionsPerStage]; -+ -+ // -+ // Construct the output operator -+ // -+ -+ OutputOp output_op(params.output); -+ -+ // -+ // Load and accumulate with a simple batched loading sequence. 
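  // Annotation (not part of the upstream CUTLASS sources or of this patch): the outer loop
  // below walks the split-K partials in groups of kPartitionsPerStage. Each group first
  // issues all of its guarded fragment loads, so several loads are in flight before any
  // arithmetic, and then folds each loaded fragment into the accumulator with the reduction
  // operator; the accumulator stays in ElementAccumulator precision until the output
  // operator produces the final fragment.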
-+ // -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int k = 0; k < params.partitions; k += kPartitionsPerStage) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPartitionsPerStage; ++i) { -+ if (k + i < params.partitions) { -+ workspace_frag[i] = *reinterpret_cast(workspace_ptr); -+ workspace_ptr += params.partition_stride; -+ } -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPartitionsPerStage; ++i) { -+ if (k + i < params.partitions) { -+ accumulator = reduction_op(accumulator, workspace_frag[i]); -+ } -+ } -+ } -+ -+ // -+ // Conditionally load the source -+ // -+ -+ FragmentOutput source_frag; -+ -+ source_frag.clear(); -+ -+ FragmentOutput const *source_ptr = reinterpret_cast( -+ params.source.data() + params.source.offset(thread_offset)); -+ -+ if (output_op.is_source_needed()) { -+ reinterpret_cast(source_frag) = *source_ptr; -+ } -+ -+ // -+ // Compute the output -+ // -+ -+ typename OutputOp::FragmentOutput output_frag = output_op(accumulator, source_frag); -+ -+ // -+ // Store -+ // -+ -+ FragmentOutput *dest_ptr = reinterpret_cast( -+ params.destination.data() + params.destination.offset(thread_offset)); -+ -+ *dest_ptr = reinterpret_cast(output_frag); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace reduction -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h b/3rdparty/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h -new file mode 100644 -index 0000000..d139ed4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h -@@ -0,0 +1,606 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Kernel performing a reduction over one or more ranks of an affine tensor -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reduction { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters structure -+template < -+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) -+ int ReducedRank, ///< Rank of reduced tensor (i.e. number of outer ranks) -+ typename ElementOutput, ///< Data type of output tensor -+ typename ElementSource, ///< Data type of source tensor -+ typename ReductionOp, ///< Reduction operator -+ int VectorLength = 1, ///< Vector length for memory -+ typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation -+ int Threads = 256, ///< Number of participating threads -+ int BatchSize = 4 ///< Number of elements to load per batch -+> -+struct TensorReductionAffineContiguousParams { -+ -+ static int const kRank = Rank; -+ static int const kReducedRank = ReducedRank; -+ static int const kVectorLength = VectorLength; -+ static int const kInnerRank = kRank - kReducedRank; -+ static int const kThreads = Threads; -+ static int const kBatchSize = BatchSize; -+ -+ Coord extent; /// Extent of source tensor -+ FastDivmodU64 divmod[kRank - 1]; /// FastDivmod by each strided rank -+ int64_t dst_stride[kReducedRank]; /// stride (units of bytes) - I, J -+ int64_t src_stride[kRank - 1]; /// stride (units of bytes) - I, J, K -+ int64_t workspace_stride; /// stride (units of bytes) between workspace -+ int workspace_count; /// number of workspaces -+ -+ uint64_t inner_count; /// Number of elements in reduced index space -+ uint64_t outer_count; /// Number of elements in outer index space -+ -+ ElementOutput * destination; /// Pointer to output tensor of rank kReducedRank -+ ElementSource const * source; /// Poitner to source pointer of rank kRank -+ ReductionOp reduction_op; /// Reduction operator -+ ElementCompute reduction_identity; /// Identity element used by reduction operator -+ ElementCompute *device_workspace; /// Pointer to device workspace for inter-CTA reductions -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorReductionAffineContiguousParams() { -+ -+ } -+ -+ /// Ctor -+ TensorReductionAffineContiguousParams( -+ Coord extent_, ///< Extent of source tensor -+ ElementOutput * dst_ptr_, ///< Output tensor data -+ int64_t dst_stride_[], ///< Stride (units of elements) -+ ElementSource const * src_ptr_, ///< Source tensor data -+ int64_t src_stride_[], ///< Stride (units of elements) -+ ElementCompute *device_workspace_, ///< Pointer to device workspace for inter-CTA reductions -+ int64_t workspace_stride_, ///< Stride between workspaces -+ int workspace_count_, ///< Number of workspaces -+ ReductionOp reduction_op_, ///< Reduction operator -+ ElementCompute reduction_identity_ = ElementCompute() ///< Identity element used by reduction operator -+ ): -+ extent(extent_), -+ inner_count(1), -+ outer_count(1), -+ 
destination(dst_ptr_), -+ source(src_ptr_), -+ device_workspace(device_workspace_), -+ workspace_stride(workspace_stride_), -+ workspace_count(workspace_count_), -+ reduction_op(reduction_op_), -+ reduction_identity(reduction_identity_) { -+ -+ // Initialize divisors for fast div-mod -+ for (int p = 1; p < kRank; ++p) { -+ divmod[p - 1] = FastDivmodU64(uint64_t(extent[p])); -+ } -+ -+ int input_size_bits = sizeof_bits::value; -+ int output_size_bits = sizeof_bits::value; -+ -+ // Compute strides in units of bytes -+ for (int p = 0; p < kReducedRank; ++p) { -+ dst_stride[p] = dst_stride_[p] * output_size_bits / 8; -+ } -+ -+ for (int p = 0; p < kRank - 1; ++p) { -+ src_stride[p] = src_stride_[p] * input_size_bits / 8; -+ } -+ -+ // Compute number of elements in strided ranks -+ for (int p = 0; p < kReducedRank; ++p) { -+ outer_count *= uint64_t(extent[p]); -+ } -+ -+ for (int p = 0; p < kInnerRank; ++p) { -+ inner_count *= uint64_t(extent[kRank - 1 - p]); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Kernel to reduce a tensor with affine layout over a set of ranks *INCLUDING* the contiguous -+/// rank. This leads to favorable vectorized memory accesses over the contiguous rank. -+template < -+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) -+ int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) -+ typename ElementOutput, ///< Data type of output tensor -+ typename ElementSource, ///< Data type of source tensor -+ typename ReductionOp, ///< Reduction operator -+ int VectorLength = 1, ///< Vector length for memory -+ typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation -+ int Threads = 256, ///< Number of participating threads -+ int BatchSize = 4 ///< Number of elements to load per batch -+> -+class TensorReductionAffineContiguous { -+public: -+ -+ static int const kRank = Rank; -+ static int const kReducedRank = ReducedRank; -+ static int const kVectorLength = VectorLength; -+ static int const kInnerRank = kRank - kReducedRank; -+ static int const kThreads = Threads; -+ static int const kBatchSize = BatchSize; -+ using ComputeFragment = Array; -+ using SourceFragment = AlignedArray; -+ using OutputFragment = AlignedArray; -+ -+ /// Shared memory allocation used for reduction within the CTA -+ struct SharedStorage { -+ Array workspace; -+ }; -+ -+ /// Parameters structure -+ using Params = TensorReductionAffineContiguousParams< -+ Rank, -+ ReducedRank, -+ ElementOutput, -+ ElementSource, -+ ReductionOp, -+ VectorLength, -+ ElementCompute, -+ Threads, -+ BatchSize -+ >; -+ -+private: -+ -+ /// Computes the coordinate and offset of a given linear index -+ CUTLASS_DEVICE -+ void compute_inner_coord_and_offset_( -+ Params const ¶ms, -+ Coord & coord, -+ int64_t &src_offset, -+ uint64_t linear_idx) const { -+ -+ // Decompose into a coordinate of rank -+ coord = CoordinateDecomposition(linear_idx, ¶ms.divmod[kRank - kInnerRank]); -+ -+ // Compute an offset using the souce stride -+ src_offset = 0; -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kInnerRank - 1; ++i) { -+ src_offset += coord[i] * params.src_stride[kReducedRank + i]; -+ } -+ src_offset += coord[kInnerRank - 1] * sizeof_bits::value / 8; -+ } -+ -+ /// Computes the coordinate and offset of a given linear index -+ CUTLASS_DEVICE -+ void compute_outer_coord_and_offset_( -+ Params const ¶ms, -+ Coord & coord, -+ int64_t &dst_offset, -+ int64_t &src_offset, -+ uint64_t 
linear_idx) const { -+ -+ // Decompose into coordinate of rank -+ coord = CoordinateDecomposition(linear_idx, params.divmod); -+ -+ // Compute offsets using destination and source strides -+ dst_offset = 0; -+ src_offset = 0; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kReducedRank; ++i) { -+ dst_offset += params.dst_stride[i] * coord[i]; -+ src_offset += params.src_stride[i] * coord[i]; -+ } -+ } -+ -+ /// Reduces over the reduction indices yielding a single element -+ CUTLASS_DEVICE -+ ElementCompute reduce_indices_( -+ Params const ¶ms, -+ ElementCompute *threadblock_workspace, -+ char const *src_byte_ptr, -+ int coord_c) { -+ -+ NumericArrayConverter convert_source; -+ ReductionOp reduction_op(params.reduction_op); -+ -+ // -+ // Early exit or initialize to identity element -+ // -+ if (!params.inner_count) { -+ return params.reduction_identity; -+ } -+ -+ ComputeFragment accumulator; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < accumulator.size(); ++i) { -+ accumulator[i] = params.reduction_identity; -+ } -+ -+ // Compute the coordinate of the first access -+ int64_t src_byte_offset = 0; -+ Coord coord; -+ -+ uint64_t linear_idx = (threadIdx.x + blockDim.x * threadIdx.z + blockDim.x * blockIdx.z * blockDim.z) * kVectorLength; -+ compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); -+ -+ // Load the first vector -+ SourceFragment source_fragment[kBatchSize]; -+ -+ bool not_done = true; -+ -+ // Iterate over vectors in a linearized reduction index space -+ while (not_done) { -+ -+ bool guards[kBatchSize]; -+ -+ // Issue a batch of loads -+ CUTLASS_PRAGMA_UNROLL -+ for (int b = 0; b < kBatchSize; ++b) { -+ -+ if (linear_idx < params.inner_count) { -+ source_fragment[b] = *reinterpret_cast(src_byte_ptr + src_byte_offset); -+ guards[b] = true; -+ } -+ else { -+ guards[b] = false; -+ not_done = false; -+ } -+ -+ linear_idx += (blockDim.z * gridDim.z * blockDim.x) * kVectorLength; -+ compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); -+ } -+ -+ // Perform a batch of reduction operations -+ CUTLASS_PRAGMA_UNROLL -+ for (int b = 0; b < kBatchSize; ++b) { -+ if (guards[b]) { -+ auto cvt = convert_source(source_fragment[b]); -+ -+ accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator( -+ reduction_op, -+ accumulator, -+ cvt); -+ } -+ } -+ }; -+ -+ // -+ // Reduction of vectors to scalar -+ // -+ -+ ElementCompute reduced_accumulator = accumulator[0]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < kVectorLength; ++i) { -+ reduced_accumulator = reduction_op(reduced_accumulator, accumulator[i]); -+ } -+ -+ // -+ // Reduction within CTA across threadIdx.xz => threadIdx{.x = 0, .z = 0} -+ // -+ // This re-arranges data so threadIdx.y is effectively a row index and threadIdx.xz is a column -+ // -+ -+ int thread_count = blockDim.x * blockDim.z; -+ int thread_j = threadIdx.x + blockDim.x * threadIdx.z; -+ int thread_i = threadIdx.y; -+ -+ ElementCompute *frag_ptr = reinterpret_cast(threadblock_workspace) + thread_i * thread_count; -+ -+ frag_ptr[thread_j] = reduced_accumulator; -+ -+ // -+ // Reduce -+ // -+ CUTLASS_PRAGMA_NO_UNROLL -+ while (thread_count > 1) { -+ thread_count /= 2; -+ -+ __syncthreads(); -+ -+ if (thread_j < thread_count) { -+ ElementCompute other = frag_ptr[thread_j + thread_count]; -+ -+ reduced_accumulator = reduction_op(reduced_accumulator, other); -+ -+ frag_ptr[thread_j] = reduced_accumulator; -+ } -+ -+ __syncthreads(); -+ } -+ -+ -+ return reduced_accumulator; -+ } -+ -+public: -+ -+ /// Perform 
a reduction -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength; -+ -+ char const * src_byte_ptr = reinterpret_cast(params.source); -+ char * dst_byte_ptr = nullptr; -+ -+ // If performing a reduction across CTAs, redirect output to device workspace -+ if (gridDim.z == 1) { -+ dst_byte_ptr = reinterpret_cast(params.destination); -+ } -+ else { -+ dst_byte_ptr = reinterpret_cast(params.device_workspace); -+ } -+ -+ uint64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y; -+ -+ // Use modulo division to compute location -+ Coord outer_coord; -+ int64_t dst_byte_offset; -+ int64_t src_byte_offset; -+ -+ compute_outer_coord_and_offset_( -+ params, -+ outer_coord, -+ dst_byte_offset, -+ src_byte_offset, -+ idx_linear); -+ -+ if (gridDim.z == 1) { -+ -+ /// Complete the reduction with no workspace -+ while (idx_linear < params.outer_count) { -+ -+ ElementCompute result = reduce_indices_( -+ params, -+ shared_storage.workspace.data(), -+ src_byte_ptr + src_byte_offset, -+ coord_c); -+ -+ // Store the result after possible final reduction within the CTA -+ if (threadIdx.z == 0 && threadIdx.x == 0) { -+ -+ // Convert to output type and store -+ NumericConverter convert_output; -+ ElementOutput cvt = convert_output(result); -+ -+ *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = cvt; -+ } -+ -+ __syncthreads(); -+ -+ // Update indices and pointers -+ idx_linear += gridDim.y * blockDim.y; -+ -+ compute_outer_coord_and_offset_( -+ params, -+ outer_coord, -+ dst_byte_offset, -+ src_byte_offset, -+ idx_linear); -+ -+ } // while -+ } -+ else { -+ -+ /// Complete the reduction with workspace -+ while (idx_linear < params.outer_count) { -+ -+ ElementCompute result = reduce_indices_( -+ params, -+ shared_storage.workspace.data(), -+ src_byte_ptr + src_byte_offset, -+ coord_c); -+ -+ int64_t byte_offset = -+ blockIdx.z * params.workspace_stride + idx_linear * sizeof_bits::value / 8; -+ -+ // Store the result for final reduction -+ if (threadIdx.z == 0 && threadIdx.x == 0) { -+ *reinterpret_cast(dst_byte_ptr + byte_offset) = result; -+ } -+ -+ __syncthreads(); -+ -+ // Update indices and pointers -+ idx_linear += gridDim.y * blockDim.y; -+ -+ compute_outer_coord_and_offset_( -+ params, -+ outer_coord, -+ dst_byte_offset, -+ src_byte_offset, -+ idx_linear); -+ } // while -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Kernel to perform final reduction -+template < -+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) -+ int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. 
NC => 2) -+ typename ElementOutput, ///< Data type of output tensor -+ typename ElementSource, ///< Data type of source tensor -+ typename ReductionOp, ///< Reduction operator -+ int VectorLength = 1, ///< Vector length for memory -+ typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation -+ int Threads = 256, ///< Number of participating threads -+ int BatchSize = 4 ///< Number of elements to load per batch -+> -+class TensorReductionAffineContiguousFinal { -+public: -+ -+ static int const kRank = Rank; -+ static int const kReducedRank = ReducedRank; -+ static int const kVectorLength = VectorLength; -+ static int const kInnerRank = kRank - kReducedRank; -+ static int const kThreads = Threads; -+ static int const kBatchSize = BatchSize; -+ -+ /// Shared memory -+ struct SharedStorage { }; -+ -+ /// Parameters structure -+ using Params = TensorReductionAffineContiguousParams< -+ Rank, -+ ReducedRank, -+ ElementOutput, -+ ElementSource, -+ ReductionOp, -+ VectorLength, -+ ElementCompute, -+ Threads, -+ BatchSize -+ >; -+ -+private: -+ -+ /// Computes the coordinate and offset of a given linear index -+ CUTLASS_DEVICE -+ void compute_outer_coord_and_offset_( -+ Params const ¶ms, -+ Coord & coord, -+ int64_t &dst_offset, -+ uint64_t linear_idx) const { -+ -+ // Decompose into coordinate of rank -+ coord = CoordinateDecomposition(linear_idx, params.divmod); -+ -+ // Compute offsets using destination and source strides -+ dst_offset = 0; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kReducedRank; ++i) { -+ dst_offset += params.dst_stride[i] * coord[i]; -+ } -+ } -+ -+ /// Reduces over the reduction indices -+ CUTLASS_DEVICE -+ ElementCompute reduce_indices_( -+ Params const ¶ms, -+ ElementCompute const *device_workspace) { -+ -+ ReductionOp reduction_op(params.reduction_op); -+ char const *src_byte_ptr = reinterpret_cast(device_workspace); -+ -+ // Accumulated output -+ ElementCompute accumulator = params.reduction_identity; -+ -+ for (int iter = 0; iter < params.workspace_count; ++iter) { -+ ElementCompute workspace_item = *reinterpret_cast(src_byte_ptr); -+ -+ accumulator = reduction_op(accumulator, workspace_item); -+ -+ src_byte_ptr += params.workspace_stride; -+ } -+ -+ return accumulator; -+ } -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Perform a reduction -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ uint64_t idx_linear = blockIdx.x * blockDim.x + threadIdx.x; -+ -+ char * dst_byte_ptr = reinterpret_cast(params.destination); -+ -+ // Use modulo division to compute location -+ Coord outer_coord; -+ int64_t dst_byte_offset; -+ -+ compute_outer_coord_and_offset_( -+ params, -+ outer_coord, -+ dst_byte_offset, -+ idx_linear); -+ -+ /// Complete the reduction -+ while (idx_linear < params.outer_count) { -+ -+ ElementCompute result = reduce_indices_(params, params.device_workspace + idx_linear); -+ -+ // Convert to output type and store -+ NumericConverter convert_output; -+ -+ *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = convert_output(result); -+ -+ // Update indices and pointers -+ idx_linear += gridDim.x * blockDim.x; -+ -+ compute_outer_coord_and_offset_( -+ params, -+ outer_coord, -+ dst_byte_offset, -+ idx_linear); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace reduction -+} // namespace cutlass -+ 
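The two kernels above split large reductions into two passes: when gridDim.z > 1, TensorReductionAffineContiguous writes one partial result per z-block into device_workspace (spaced by workspace_stride), and TensorReductionAffineContiguousFinal then folds the workspace_count partials into the destination tensor. A minimal host-side sketch of that two-pass scheme, using plain C++ and illustrative names in place of the CUTLASS kernels, affine strides, and launch configuration:

    // Conceptual sketch only: source viewed as [outer_count x inner_count],
    // reducing the inner index with "+" as the reduction operator.
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      const uint64_t outer_count = 4, inner_count = 1024;
      std::vector<float> source(outer_count * inner_count, 1.0f);

      // Pass 1: each of workspace_count slices (analogous to gridDim.z blocks)
      // reduces its share of the inner index space into a workspace row.
      const int workspace_count = 8;
      std::vector<float> workspace(workspace_count * outer_count, 0.0f);
      for (int w = 0; w < workspace_count; ++w) {
        for (uint64_t o = 0; o < outer_count; ++o) {
          float partial = 0.0f;  // reduction identity for plus
          for (uint64_t i = w; i < inner_count; i += workspace_count) {
            partial += source[o * inner_count + i];
          }
          workspace[w * outer_count + o] = partial;
        }
      }

      // Pass 2 (the role of the "Final" kernel): fold the per-slice partials
      // into the destination tensor, one output element per outer index.
      std::vector<float> destination(outer_count, 0.0f);
      for (uint64_t o = 0; o < outer_count; ++o) {
        float acc = 0.0f;
        for (int w = 0; w < workspace_count; ++w) {
          acc += workspace[w * outer_count + o];
        }
        destination[o] = acc;
      }

      std::cout << destination[0] << "\n";  // expect 1024 for this input
      return 0;
    }

The actual kernels additionally vectorize loads over the contiguous rank (kVectorLength elements at a time) and reduce across threads of the CTA through shared memory before each partial is written, which the sketch omits.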
-+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h b/3rdparty/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h -new file mode 100644 -index 0000000..9d5b045 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h -@@ -0,0 +1,641 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Kernel performing a reduction over one or more ranks of an affine tensor -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reduction { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace kernel { -+ -+/// Parameters structure -+template < -+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) -+ int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. 
NC => 2) -+ typename ElementOutput, ///< Data type of output tensor -+ typename ElementSource, ///< Data type of source tensor -+ typename ReductionOp, ///< Reduction operator -+ int VectorLength = 1, ///< Vector length for memory -+ typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation -+ int Threads = 256, ///< Number of participating threads -+ int BatchSize = 4 ///< Number of elements to load per batch -+> -+struct TensorReductionAffineStridedParams { -+ -+ static int const kRank = Rank; -+ static int const kReducedRank = ReducedRank; -+ static int const kVectorLength = VectorLength; -+ static int const kInnerRank = kRank - kReducedRank; -+ static int const kThreads = Threads; -+ static int const kBatchSize = BatchSize; -+ -+ Coord extent; /// Extent of source tensor -+ FastDivmodU64 divmod[kRank - 1]; /// FastDivmod by each strided rank -+ int64_t dst_stride[kReducedRank - 1]; /// stride (units of bytes) - I, J -+ int64_t src_stride[kRank - 1]; /// stride (units of bytes) - I, J, K -+ int64_t workspace_stride; /// stride (units of bytes) between workspace -+ int64_t workspace_outer_stride; /// stride (units of bytes) between 'rows' of the workspace -+ int workspace_count; /// number of workspaces -+ -+ uint64_t inner_count; /// Number of elements in reduced index space -+ uint64_t outer_count; /// Number of elements in outer index space -+ -+ ElementOutput * destination; /// Pointer to output tensor of rank kReducedRank -+ ElementSource const * source; /// Poitner to source pointer of rank kRank -+ ReductionOp reduction_op; /// Reduction operator -+ ElementCompute reduction_identity; /// Identity element for reduction operator -+ ElementCompute *device_workspace; /// Pointer to device workspace for inter-CTA reductions -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorReductionAffineStridedParams() { -+ -+ } -+ -+ /// Ctor -+ TensorReductionAffineStridedParams( -+ Coord extent_, ///< Extent of source tensor -+ ElementOutput * dst_ptr_, ///< Output tensor data -+ int64_t dst_stride_[], ///< Stride (units of elements) -+ ElementSource const * src_ptr_, ///< Source tensor data -+ int64_t src_stride_[], ///< Stride (units of elements) -+ ElementCompute *device_workspace_, ///< Pointer to device workspace for inter-CTA reductions -+ int64_t workspace_stride_, ///< Stride between workspaces -+ int workspace_count_, ///< Number of workspaces -+ ReductionOp reduction_op_, ///< Reduction operator -+ ElementCompute reduction_identity_ = ElementCompute() ///< Identity element for reduction operator -+ ): -+ extent(extent_), -+ inner_count(1), -+ outer_count(1), -+ destination(dst_ptr_), -+ source(src_ptr_), -+ device_workspace(device_workspace_), -+ workspace_outer_stride(0), -+ workspace_stride(workspace_stride_), -+ workspace_count(workspace_count_), -+ reduction_op(reduction_op_), -+ reduction_identity(reduction_identity_) { -+ -+ // Initialize divisors for fast div-mod -+ for (int p = 1; p < kRank; ++p) { -+ divmod[p - 1] = FastDivmodU64(uint64_t(extent[p])); -+ } -+ -+ int input_size_bits = sizeof_bits::value; -+ int output_size_bits = sizeof_bits::value; -+ -+ workspace_outer_stride = workspace_stride * workspace_count; -+ -+ // Compute strides in units of bytes -+ for (int p = 0; p < kReducedRank - 1; ++p) { -+ dst_stride[p] = dst_stride_[p] * output_size_bits / 8; -+ } -+ -+ for (int p = 0; p < kRank - 1; ++p) { -+ src_stride[p] = src_stride_[p] * input_size_bits / 8; -+ } -+ -+ // Compute number of elements in 
strided ranks -+ for (int p = 0; p < kReducedRank - 1; ++p) { -+ outer_count *= uint64_t(extent[p]); -+ } -+ -+ for (int p = 0; p < kInnerRank; ++p) { -+ inner_count *= uint64_t(extent[kReducedRank + p - 1]); -+ } -+ } -+}; -+ -+/// Kernel to reduce a tensor with affine layout over a set of ranks *EXCLUDING* the contiguous -+/// rank. This leads to favorable vectorized memory accesses over the contiguous rank. -+template < -+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) -+ int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) -+ typename ElementOutput, ///< Data type of output tensor -+ typename ElementSource, ///< Data type of source tensor -+ typename ReductionOp, ///< Reduction operator -+ int VectorLength = 1, ///< Vector length for memory -+ typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation -+ int Threads = 256, ///< Number of participating threads -+ int BatchSize = 4 ///< Number of elements to load per batch -+> -+class TensorReductionAffineStrided { -+public: -+ -+ static int const kRank = Rank; -+ static int const kReducedRank = ReducedRank; -+ static int const kVectorLength = VectorLength; -+ static int const kInnerRank = kRank - kReducedRank; -+ static int const kThreads = Threads; -+ static int const kBatchSize = BatchSize; -+ using ComputeFragment = Array; -+ using SourceFragment = AlignedArray; -+ using OutputFragment = AlignedArray; -+ -+ /// Shared memory allocation used for reduction within the CTA -+ struct SharedStorage { -+ Array workspace; -+ }; -+ -+ /// Parameters structure -+ using Params = TensorReductionAffineStridedParams< -+ Rank, -+ ReducedRank, -+ ElementOutput, -+ ElementSource, -+ ReductionOp, -+ VectorLength, -+ ElementCompute, -+ Threads, -+ BatchSize -+ >; -+ -+private: -+ -+ /// Computes the coordinate and offset of a given linear index -+ CUTLASS_DEVICE -+ void compute_inner_coord_and_offset_( -+ Params const ¶ms, -+ Coord & coord, -+ int64_t &src_offset, -+ uint64_t linear_idx) const { -+ -+ // Decompose into coordinate -+ coord = CoordinateDecomposition(linear_idx, ¶ms.divmod[kReducedRank - 1]); -+ -+ // Compute linear offset -+ src_offset = 0; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kInnerRank; ++i) { -+ src_offset += params.src_stride[kReducedRank + i - 1] * coord[i]; -+ } -+ } -+ -+ /// Computes the coordinate and offset of a given linear index -+ CUTLASS_DEVICE -+ void compute_outer_coord_and_offset_( -+ Params const ¶ms, -+ Coord & coord, -+ int64_t &dst_offset, -+ int64_t &src_offset, -+ uint64_t linear_idx) const { -+ -+ // Decompose linear coordinate -+ coord = CoordinateDecomposition(linear_idx, params.divmod); -+ -+ // Compute offset into tensors -+ dst_offset = 0; -+ src_offset = 0; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kReducedRank - 1; ++i) { -+ dst_offset += params.dst_stride[i] * coord[i]; -+ src_offset += params.src_stride[i] * coord[i]; -+ } -+ } -+ -+ /// Reduces over the reduction indices -+ CUTLASS_DEVICE -+ ComputeFragment reduce_indices_( -+ Params const ¶ms, -+ ElementCompute *threadblock_workspace, -+ char const *src_byte_ptr) { -+ -+ NumericArrayConverter convert_source; -+ ReductionOp reduction_op(params.reduction_op); -+ -+ // Accumulated output -+ ComputeFragment identity_frag; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < identity_frag.size(); ++i) { -+ identity_frag[i] = params.reduction_identity; -+ } -+ -+ if (!params.inner_count) { -+ return identity_frag; -+ } -+ -+ ComputeFragment accumulator = 
identity_frag; -+ -+ // Compute the coordinate of the first access -+ int64_t src_byte_offset = 0; -+ Coord coord; -+ -+ uint64_t linear_idx = threadIdx.z + blockIdx.z * blockDim.z; -+ compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); -+ -+ // Load the first vector -+ SourceFragment source_fragment[kBatchSize]; -+ -+ bool not_done = true; -+ -+ // Iterate over vectors in a linearized reduction index space -+ while (not_done) { -+ -+ bool guards[kBatchSize]; -+ -+ // Issue a batch of loads -+ CUTLASS_PRAGMA_UNROLL -+ for (int b = 0; b < kBatchSize; ++b) { -+ -+ if (linear_idx < params.inner_count) { -+ source_fragment[b] = *reinterpret_cast(src_byte_ptr + src_byte_offset); -+ guards[b] = true; -+ } -+ else { -+ guards[b] = false; -+ not_done = false; -+ } -+ -+ linear_idx += blockDim.z * gridDim.z; -+ compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); -+ } -+ -+ // Perform a batch of reduction operations -+ CUTLASS_PRAGMA_UNROLL -+ for (int b = 0; b < kBatchSize; ++b) { -+ if (guards[b]) { -+ -+ auto cvt = convert_source(source_fragment[b]); -+ -+ accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator( -+ reduction_op, -+ accumulator, -+ cvt); -+ } -+ } -+ }; -+ -+ // Optional reduction within a CTA -+ if (blockDim.z > 1) { -+ -+ // Linearized thread ID -+ int thread_idx = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z); -+ -+ // all threads store to workspace -+ ComputeFragment *frag_ptr = reinterpret_cast(threadblock_workspace); -+ -+ frag_ptr[thread_idx] = accumulator; -+ -+ __syncthreads(); -+ -+ if (threadIdx.z == 0) { -+ // Load all additional block indices -+ for (int z = 1; z < blockDim.z; ++z) { -+ ComputeFragment frag = frag_ptr[thread_idx + z * blockDim.x * blockDim.y]; -+ -+ accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator( -+ reduction_op, -+ accumulator, -+ frag); -+ } -+ } -+ -+ __syncthreads(); -+ } -+ -+ return accumulator; -+ } -+ -+public: -+ -+ /// Perform a reduction -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength; -+ -+ char const * src_byte_ptr = reinterpret_cast(params.source + coord_c); -+ char * dst_byte_ptr = nullptr; -+ -+ // If performing a reduction across CTAs, redirect output to device workspace -+ if (gridDim.z == 1) { -+ dst_byte_ptr = reinterpret_cast(params.destination + coord_c); -+ } -+ else { -+ dst_byte_ptr = reinterpret_cast(params.device_workspace + coord_c); -+ } -+ -+ // If the C index is out of bounds, exit -+ if (coord_c >= params.extent[kRank - 1]) { -+ return; -+ } -+ -+ int64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y; -+ -+ // Use modulo division to compute location -+ Coord outer_coord; -+ int64_t dst_byte_offset; -+ int64_t src_byte_offset; -+ -+ compute_outer_coord_and_offset_( -+ params, -+ outer_coord, -+ dst_byte_offset, -+ src_byte_offset, -+ idx_linear); -+ -+ if (gridDim.z == 1) { -+ -+ /// Complete the reduction with no workspace -+ while (idx_linear < params.outer_count) { -+ -+ ComputeFragment result; -+ -+ result = reduce_indices_( -+ params, -+ shared_storage.workspace.data(), -+ src_byte_ptr + src_byte_offset); -+ -+ // Store the result after possible final reduction within the CTA -+ if (threadIdx.z == 0) { -+ -+ // Convert to output type and store -+ NumericArrayConverter convert_output; -+ auto cvt = convert_output(result); -+ -+ *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = -+ 
reinterpret_cast(cvt); -+ } -+ -+ // Update indices and pointers -+ idx_linear += gridDim.y * blockDim.y; -+ -+ compute_outer_coord_and_offset_( -+ params, -+ outer_coord, -+ dst_byte_offset, -+ src_byte_offset, -+ idx_linear); -+ -+ } // while -+ } -+ else { -+ -+ /// Complete the reduction with a device workspace -+ while (idx_linear < params.outer_count) { -+ -+ ComputeFragment result; -+ -+ result = reduce_indices_( -+ params, -+ shared_storage.workspace.data(), -+ src_byte_ptr + src_byte_offset); -+ -+ // Store the result after possible final reduction within the CTA -+ if (threadIdx.z == 0) { -+ -+ int64_t byte_offset = -+ blockIdx.z * params.workspace_stride + idx_linear * params.workspace_outer_stride; -+ -+ // No conversion - store in compute type -+ *reinterpret_cast(dst_byte_ptr + byte_offset) = -+ reinterpret_cast(result); -+ } -+ -+ // Update indices and pointers -+ idx_linear += gridDim.y * blockDim.y; -+ -+ compute_outer_coord_and_offset_( -+ params, -+ outer_coord, -+ dst_byte_offset, -+ src_byte_offset, -+ idx_linear); -+ -+ } // while (outer index) -+ } // if () -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Kernel to perform final reduction -+template < -+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) -+ int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) -+ typename ElementOutput, ///< Data type of output tensor -+ typename ElementSource, ///< Data type of source tensor -+ typename ReductionOp, ///< Reduction operator -+ int VectorLength = 1, ///< Vector length for memory -+ typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation -+ int Threads = 256, ///< Number of participating threads -+ int BatchSize = 4 ///< Number of elements to load per batch -+> -+class TensorReductionAffineStridedFinal { -+public: -+ -+ static int const kRank = Rank; -+ static int const kReducedRank = ReducedRank; -+ static int const kVectorLength = VectorLength; -+ static int const kInnerRank = kRank - kReducedRank; -+ static int const kThreads = Threads; -+ static int const kBatchSize = BatchSize; -+ using ComputeFragment = Array; -+ using SourceFragment = AlignedArray; -+ using OutputFragment = AlignedArray; -+ -+ /// Shared memory -+ struct SharedStorage { }; -+ -+ /// Parameters structure -+ using Params = TensorReductionAffineStridedParams< -+ Rank, -+ ReducedRank, -+ ElementOutput, -+ ElementSource, -+ ReductionOp, -+ VectorLength, -+ ElementCompute, -+ Threads, -+ BatchSize -+ >; -+ -+private: -+ -+ /// Computes the coordinate and offset of a given linear index -+ CUTLASS_DEVICE -+ void compute_outer_coord_and_offset_( -+ Params const ¶ms, -+ Coord & coord, -+ int64_t &dst_offset, -+ uint64_t linear_idx) const { -+ -+ // Decompose linear index -+ coord = CoordinateDecomposition(linear_idx, params.divmod); -+ -+ // Compute tensor offset -+ dst_offset = 0; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kReducedRank - 1; ++i) { -+ dst_offset += params.dst_stride[i] * coord[i]; -+ } -+ } -+ -+ /// Reduces over the reduction indices -+ CUTLASS_DEVICE -+ ComputeFragment reduce_indices_( -+ Params const ¶ms, -+ char *src_byte_ptr) { -+ -+ ReductionOp reduction_op(params.reduction_op); -+ -+ // Accumulated output -+ ComputeFragment identity_frag; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < identity_frag.size(); ++i) { -+ identity_frag[i] = params.reduction_identity; -+ } -+ -+ ComputeFragment accumulator = identity_frag; -+ 
ComputeFragment workspace_fragments[kBatchSize]; -+ -+ // Partially unrolled loop -+ for (int idx = 0; idx < params.workspace_count; idx += kBatchSize) { -+ -+ // Issue a batch of loads -+ CUTLASS_PRAGMA_UNROLL -+ for (int b = 0; b < kBatchSize; ++b) { -+ if (idx + b < params.workspace_count) { -+ workspace_fragments[b] = -+ *reinterpret_cast(src_byte_ptr); -+ } -+ else { -+ workspace_fragments[b] = identity_frag; -+ } -+ src_byte_ptr += + params.workspace_stride; -+ } -+ -+ // Perform a reduction -+ CUTLASS_PRAGMA_UNROLL -+ for (int b = 0; b < kBatchSize; ++b) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kVectorLength; ++i) { -+ accumulator[i] = reduction_op(accumulator[i], workspace_fragments[b][i]); -+ } -+ } -+ } -+ -+ return accumulator; -+ } -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Perform a reduction -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength; -+ -+ char * src_byte_ptr = reinterpret_cast(params.device_workspace + coord_c); -+ char * dst_byte_ptr = reinterpret_cast(params.destination + coord_c); -+ -+ // If the C index is out of bounds, exit -+ if (coord_c >= params.extent[kRank - 1]) { -+ return; -+ } -+ -+ int64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y; -+ -+ // Use modulo division to compute location -+ Coord outer_coord; -+ int64_t dst_byte_offset; -+ -+ compute_outer_coord_and_offset_( -+ params, -+ outer_coord, -+ dst_byte_offset, -+ idx_linear); -+ -+ /// Complete the reduction -+ while (idx_linear < params.outer_count) { -+ -+ int64_t src_byte_offset = idx_linear * params.workspace_outer_stride; -+ -+ ComputeFragment result = reduce_indices_( -+ params, -+ src_byte_ptr + src_byte_offset); -+ -+ // Convert to output type and store -+ NumericArrayConverter convert_output; -+ auto cvt = convert_output(result); -+ -+ *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = -+ reinterpret_cast(cvt); -+ -+ // Update indices and pointers -+ idx_linear += gridDim.y * blockDim.y; -+ -+ compute_outer_coord_and_offset_( -+ params, -+ outer_coord, -+ dst_byte_offset, -+ idx_linear); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace reduction -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/reduction/thread/reduce.h b/3rdparty/cutlass/include/cutlass/reduction/thread/reduce.h -new file mode 100644 -index 0000000..4f6e180 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/reduction/thread/reduce.h -@@ -0,0 +1,234 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic thread level reduction with specializations for Array. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/half.h" -+#include "cutlass/functional.h" -+ -+namespace cutlass { -+namespace reduction { -+namespace thread { -+ -+/// Structure to compute the thread level reduction -+template -+struct Reduce; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial Specialization of Reduce for "plus" (a functional operator) -+template -+struct Reduce< plus, T > { -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T lhs, T const &rhs) const { -+ plus _op; -+ return _op(lhs, rhs); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization of Reduce for Array -+template -+struct Reduce < plus, Array> { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &in) const { -+ -+ Array result; -+ Reduce< plus, T > scalar_reduce; -+ result.clear(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (auto i = 0; i < N; ++i) { -+ result[0] = scalar_reduce(result[0], in[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specializations of Reduce for Array -+template -+struct Reduce < plus, Array > { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &input) { -+ -+ Array result; -+ -+ // If there is only 1 element - there is nothing to reduce -+ if( N ==1 ){ -+ -+ result[0] = input.front(); -+ -+ } else { -+ -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600) -+ -+ __half result_d; -+ Array const *in_ptr_half = reinterpret_cast const *>(&input); -+ Array const *in_ptr_half2 = reinterpret_cast const *>(&input); -+ __half2 const *x_in_half2 = reinterpret_cast<__half2 const *>(in_ptr_half2); -+ -+ // Set initial result = first half2, in case N==2 -+ __half2 tmp_result = x_in_half2[0]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < N/2; ++i) { -+ -+ tmp_result = __hadd2(x_in_half2[i], tmp_result); -+ -+ } -+ -+ result_d = __hadd(__low2half(tmp_result), __high2half(tmp_result)); -+ -+ // One final step is needed for odd "N" (to add the (N-1)th element) -+ if( N%2 ){ -+ -+ __half last_element; -+ Array tmp_last; -+ Array *tmp_last_ptr = &tmp_last; 
-+ tmp_last_ptr[0] = in_ptr_half[N-1]; -+ last_element = reinterpret_cast<__half const &>(tmp_last); -+ -+ result_d = __hadd(result_d, last_element); -+ -+ } -+ -+ Array *result_ptr = &result; -+ *result_ptr = reinterpret_cast &>(result_d); -+ -+ #else -+ -+ Reduce< plus, half_t > scalar_reduce; -+ result.clear(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (auto i = 0; i < N; ++i) { -+ -+ result[0] = scalar_reduce(result[0], input[i]); -+ -+ } -+ -+ #endif -+ } -+ -+ return result; -+ -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specializations of Reduce for AlignedArray -+template -+struct Reduce < plus, AlignedArray > { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(AlignedArray const &input) { -+ -+ Array result; -+ -+ // If there is only 1 element - there is nothing to reduce -+ if( N ==1 ){ -+ -+ result[0] = input.front(); -+ -+ } else { -+ -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600) -+ -+ __half result_d; -+ AlignedArray const *in_ptr_half = reinterpret_cast const *>(&input); -+ AlignedArray const *in_ptr_half2 = reinterpret_cast const *>(&input); -+ __half2 const *x_in_half2 = reinterpret_cast<__half2 const *>(in_ptr_half2); -+ -+ // Set initial result = first half2, in case N==2 -+ __half2 tmp_result = x_in_half2[0]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < N/2; ++i) { -+ -+ tmp_result = __hadd2(x_in_half2[i], tmp_result); -+ -+ } -+ -+ result_d = __hadd(__low2half(tmp_result), __high2half(tmp_result)); -+ -+ // One final step is needed for odd "N" (to add the (N-1)th element) -+ if( N%2 ){ -+ -+ __half last_element; -+ AlignedArray tmp_last; -+ AlignedArray *tmp_last_ptr = &tmp_last; -+ tmp_last_ptr[0] = in_ptr_half[N-1]; -+ last_element = reinterpret_cast<__half const &>(tmp_last); -+ -+ result_d = __hadd(result_d, last_element); -+ -+ } -+ -+ Array *result_ptr = &result; -+ *result_ptr = reinterpret_cast &>(result_d); -+ -+ #else -+ -+ Reduce< plus, half_t > scalar_reduce; -+ result.clear(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (auto i = 0; i < N; ++i) { -+ -+ result[0] = scalar_reduce(result[0], input[i]); -+ -+ } -+ -+ #endif -+ } -+ -+ return result; -+ -+ } -+}; -+} -+} -+} -diff --git a/3rdparty/cutlass/include/cutlass/reduction/thread/reduction_operators.h b/3rdparty/cutlass/include/cutlass/reduction/thread/reduction_operators.h -new file mode 100644 -index 0000000..d54bcc0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/reduction/thread/reduction_operators.h -@@ -0,0 +1,235 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Kernel performing a reduction over densely packed tensors in global memory -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reduction { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Mixed-precision reduction -+template < -+ typename ElementAccumulator_, -+ typename Element_, -+ int Count = 1 -+> -+struct ReduceAdd { -+ -+ // -+ // Type definitions -+ // -+ -+ using ElementAccumulator = ElementAccumulator_; -+ using Element = Element_; -+ static int const kCount = Count; -+ -+ using FragmentAccumulator = cutlass::Array; -+ using FragmentElement = cutlass::Array; -+ -+ struct Params { }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ ReduceAdd(Params params_ = Params()): params(params_) { } -+ -+ /// Operator -+ CUTLASS_HOST_DEVICE -+ FragmentAccumulator operator()( -+ FragmentAccumulator accumulator, -+ FragmentElement element) const { -+ -+ plus op; -+ -+ NumericArrayConverter< -+ ElementAccumulator, -+ Element, -+ kCount, -+ PreferredRoundingMode::kRound> converter; -+ -+ return op(accumulator, converter(element)); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Special handling for binary operators -+template -+struct VectorizeArrayOperation { -+ -+ using ValueType = Array; -+ -+ CUTLASS_HOST_DEVICE -+ ValueType operator()( -+ ReductionOp const &reduction_op, -+ ValueType const &lhs, -+ ValueType const &rhs) const { -+ -+ ValueType result; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = reduction_op(lhs[i], rhs[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct ReduceArrayOperation { -+ -+ using ArrayType = Array; -+ -+ CUTLASS_HOST_DEVICE -+ Element operator()( -+ ReductionOp const &reduction_op, -+ ArrayType const &array) const { -+ -+ Element item = reduction_op(array[0], array[1]); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 2; i < N; ++i) { -+ item = reduction_op(item, array[i]); -+ } -+ -+ return item; -+ } 
-+}; -+ -+template -+struct ReduceArrayOperation, uint1b_t, N> { -+ -+ using ArrayType = Array; -+ -+ CUTLASS_HOST_DEVICE -+ uint1b_t operator()( -+ logical_and const &reduction_op, -+ ArrayType const &array) const { -+ -+ uint8_t const *ptr = reinterpret_cast(&array); -+ bool item = false; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int byte = 0; byte < (N + 7) / 8; ++byte) { -+ uint8_t bits = ptr[byte]; -+ item = (item || !bits); -+ } -+ -+ return uint1b_t(!item); -+ } -+}; -+ -+template -+struct ReduceArrayOperation, uint1b_t, N> { -+ -+ using ArrayType = Array; -+ -+ CUTLASS_HOST_DEVICE -+ uint1b_t operator()( -+ logical_and const &reduction_op, -+ ArrayType const &array) const { -+ -+ uint8_t const *ptr = reinterpret_cast(&array); -+ bool item = true; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int byte = 0; byte < (N + 7) / 8; ++byte) { -+ uint8_t bits = ptr[byte]; -+ item = (item || bits); -+ } -+ -+ return uint1b_t(item); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper function to infer template argument types -+template -+CUTLASS_HOST_DEVICE -+Array ApplyArrayOperator( -+ ReductionOp const &reduction_op, -+ Array const &lhs, -+ Array const &rhs) { -+ -+ VectorizeArrayOperation vectorize_op; -+ -+ return vectorize_op(reduction_op, lhs, rhs); -+} -+ -+/// Helper to reduce an array -+template -+Element ReduceArray(ReductionOp const &reduction_op, Array const &array) { -+ ReduceArrayOperation reduce_array_op; -+ -+ return reduce_array_op(reduction_op, array); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace reduction -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/reduction/threadblock_swizzle.h b/3rdparty/cutlass/include/cutlass/reduction/threadblock_swizzle.h -new file mode 100644 -index 0000000..5dd6e44 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/reduction/threadblock_swizzle.h -@@ -0,0 +1,67 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+* -+**************************************************************************************************/ -+/*! \file -+\brief Defies functors for mapping blockIdx to partitions of the batched reduction computation. -+*/ -+#pragma once -+#include "cutlass/coord.h" -+ -+namespace cutlass { -+namespace reduction { -+struct DefaultBlockSwizzle { -+ /// Ctor -+ CUTLASS_HOST_DEVICE DefaultBlockSwizzle() {} -+ -+ /// Swizzle the block index. -+ CUTLASS_DEVICE dim3 swizzle() { return blockIdx; } -+ -+ /// -+ CUTLASS_HOST_DEVICE dim3 get_grid_layout(Coord<3> const &problem_size, -+ Coord<3> const &OutputTile) { -+ assert(OutputTile[0] == 1 && OutputTile[1] == 1); -+ assert((problem_size[0] * problem_size[1] * problem_size[2]) % OutputTile[2] == 0); -+ dim3 grid; -+ grid.x = problem_size[0] * problem_size[1] * problem_size[2] -+ / OutputTile[2] ; -+ return grid; -+ } -+ -+ /// -+ CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &SubTile) { -+ assert(SubTile[0] == 1 && SubTile[1] == 1); -+ dim3 block = swizzle(); -+ Coord<3> threadblock_offset = -+ make_Coord(0, 0, block.x * SubTile[2]); -+ return threadblock_offset; -+ } -+}; -+} // namespace reduction -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/relatively_equal.h b/3rdparty/cutlass/include/cutlass/relatively_equal.h -new file mode 100644 -index 0000000..4736e28 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/relatively_equal.h -@@ -0,0 +1,203 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Performs comparison between two elements with support for floating-point comparisons. -+*/ -+ -+#pragma once -+ -+#include "numeric_types.h" -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+CUTLASS_HOST_DEVICE -+bool relatively_equal(T a, T b, T epsilon, T nonzero_floor); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+// This floating-point comparison function implements the method described in -+// -+// https://floating-point-gui.de/errors/comparison/ -+// -+template -+CUTLASS_HOST_DEVICE -+bool relatively_equal_float(T a, T b, T epsilon, T nonzero_floor) { -+ -+ using std::abs; -+ -+ T abs_A = abs(a); -+ T abs_B = abs(b); -+ T diff = abs(a - b); -+ T zero = T(0); -+ -+ if (a == b) { -+ return true; -+ } -+ else if (a == zero || b == zero || diff < nonzero_floor) { -+ return diff < epsilon * nonzero_floor; -+ } -+ -+ return diff < epsilon * (abs_A + abs_B); -+} -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(uint1b_t a, uint1b_t b, uint1b_t, uint1b_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(int2b_t a, int2b_t b, int2b_t, int2b_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(uint2b_t a, uint2b_t b, uint2b_t, uint2b_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(int4b_t a, int4b_t b, int4b_t, int4b_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(uint4b_t a, uint4b_t b, uint4b_t, uint4b_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(int8_t a, int8_t b, int8_t, int8_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(uint8_t a, uint8_t b, uint8_t, uint8_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(int16_t a, int16_t b, int16_t, int16_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(uint16_t a, uint16_t b, uint16_t, uint16_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(int32_t a, int32_t b, int32_t, int32_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(uint32_t a, uint32_t b, uint32_t, uint32_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(int64_t a, int64_t b, int64_t, int64_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(uint64_t a, uint64_t b, uint64_t, uint64_t) { -+ return (a == b); 
-+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(half_t a, half_t b, half_t epsilon, half_t nonzero_floor) { -+ return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal( -+ bfloat16_t a, -+ bfloat16_t b, -+ bfloat16_t epsilon, -+ bfloat16_t nonzero_floor) { -+ -+ return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal( -+ tfloat32_t a, -+ tfloat32_t b, -+ tfloat32_t epsilon, -+ tfloat32_t nonzero_floor) { -+ -+ return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(float a, float b, float epsilon, float nonzero_floor) { -+ return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); -+} -+ -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(double a, double b, double epsilon, double nonzero_floor) { -+ return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/semaphore.h b/3rdparty/cutlass/include/cutlass/semaphore.h -new file mode 100644 -index 0000000..ed8a179 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/semaphore.h -@@ -0,0 +1,122 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Implementation of a CTA-wide semaphore for inter-CTA synchronization. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// CTA-wide semaphore for inter-CTA synchronization. -+class Semaphore { -+public: -+ -+ int *lock; -+ bool wait_thread; -+ int state; -+ -+public: -+ -+ /// Implements a semaphore to wait for a flag to reach a given value -+ CUTLASS_HOST_DEVICE -+ Semaphore(int *lock_, int thread_id): -+ lock(lock_), -+ wait_thread(thread_id < 0 || thread_id == 0), -+ state(-1) { -+ -+ } -+ -+ /// Permit fetching the synchronization mechanism early -+ CUTLASS_DEVICE -+ void fetch() { -+ if (wait_thread) { -+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 -+ asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); -+ #else -+ asm volatile ("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); -+ #endif -+ } -+ } -+ -+ /// Gets the internal state -+ CUTLASS_DEVICE -+ int get_state() const { -+ return state; -+ } -+ -+ /// Waits until the semaphore is equal to the given value -+ CUTLASS_DEVICE -+ void wait(int status = 0) { -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) -+ while( __syncthreads_and(state != status) ) { -+ fetch(); -+ } -+ -+ __syncthreads(); -+#endif -+ } -+ -+ /// Updates the lock with the given result -+ CUTLASS_DEVICE -+ void release(int status = 0) { -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) -+ __syncthreads(); -+ -+ if (wait_thread) { -+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 -+ asm volatile ("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); -+ #else -+ asm volatile ("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); -+ #endif -+ } -+#endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/subbyte_reference.h b/3rdparty/cutlass/include/cutlass/subbyte_reference.h -new file mode 100644 -index 0000000..58c460a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/subbyte_reference.h -@@ -0,0 +1,637 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Provides a mechanism for packing and unpacking elements smaller than one byte -+*/ -+#pragma once -+ -+#include "cutlass/numeric_types.h" -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This class provides a mechanism for packing and unpacking elements smaller than one byte. It -+/// assumes these sub-byte elements are packed in a traditional C++ numeric type. -+/// -+/// The intended application is to provide a mechanism to indirectly reference elements in -+/// memory or Array<> objects whose addresses cannot otherwise be taken since they are smaller -+/// than one byte. -+/// -+/// Supports basic pointer arithmetic: -+/// -+/// Example: -+/// -+/// int4b_t *ptr = ...; -+/// -+/// SubbyteReference ref = ptr; -+/// ref += 15; -+/// -+/// int4b_t x = ref; // load an int4b_t -+/// ref = x + 2_s4; // perform arithmetic on int4b_t and then store -+/// -+template < -+ typename Element_, /// CUTLASS numeric element type. -+ typename Storage_ = uint8_t /// Underlying storage type. Must be able to hold an integer -+ /// number of objects of type Element. -+> -+class ConstSubbyteReference { -+public: -+ -+ using Element = Element_; -+ using Storage = Storage_; -+ using StoragePointer = Storage const *; -+ -+ static_assert(sizeof_bits::value <= sizeof_bits::value, -+ "Size of Element must not be greater than Storage."); -+ -+ static_assert(!(sizeof_bits::value % sizeof_bits::value), -+ "Storage must be divisible by Element"); -+ -+private: -+ -+ ///! Number of elements per storage vector -+ int const kElementsPerVector = sizeof_bits::value / sizeof_bits::value; -+ -+ ///! Bit mask -+ Storage const kMask = -+ ((sizeof_bits::value < sizeof_bits::value) ? -+ (Storage(1) << sizeof_bits::value) - Storage(1) : -+ ~Storage(0)); -+ -+private: -+ -+ /// Pointer to array containing element -+ StoragePointer ptr_; -+ -+ /// Offset (in units of elements) from pointer. 
-+ /// -+ /// Invariant: must always be in range [0, kElementsPerVector) -+ int offset_; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ ConstSubbyteReference(): ptr_(nullptr), offset_(0) { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ ConstSubbyteReference( -+ Element const *ptr, /// pointer to memory -+ int64_t offset /// logical offset in units of Element -+ ): -+ ptr_(reinterpret_cast(ptr)), -+ offset_(0) { -+ -+ int64_t offset_in_vectors = offset / kElementsPerVector; -+ int64_t offset_in_elements = offset % kElementsPerVector; -+ -+ ptr_ += offset_in_vectors; -+ offset_ = int(offset_in_elements); -+ } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ ConstSubbyteReference( -+ Element *ptr = nullptr -+ ): ConstSubbyteReference(ptr, 0) { } -+ -+ /// Gets storage pointer -+ CUTLASS_HOST_DEVICE -+ StoragePointer storage_pointer() const { -+ return ptr_; -+ } -+ -+ /// Gets element offset within storage vector -+ CUTLASS_HOST_DEVICE -+ int element_offset() const { -+ return offset_; -+ } -+ -+ /// Unpacks an element from memory -+ CUTLASS_HOST_DEVICE -+ Element get() const { -+ Storage item = Storage((*ptr_ >> (offset_ * sizeof_bits::value)) & kMask); -+ return reinterpret_cast(item); -+ } -+ -+ /// Unpacks an element from memory -+ CUTLASS_HOST_DEVICE -+ operator Element() const { -+ return get(); -+ } -+ -+ /// Adds an offset in units of elements to the reference -+ CUTLASS_HOST_DEVICE -+ ConstSubbyteReference &operator+=(int offset) { -+ -+ offset += offset_; -+ -+ int offset_in_vectors = offset / kElementsPerVector; -+ int offset_in_elements = offset % kElementsPerVector; -+ -+ ptr_ += offset_in_vectors; -+ offset_ = offset_in_elements; -+ -+ return *this; -+ } -+ -+ /// Adds an offset in units of elements to the reference -+ CUTLASS_HOST_DEVICE -+ ConstSubbyteReference &operator+=(long long offset) { -+ -+ offset += offset_; -+ -+ long long offset_in_vectors = offset / kElementsPerVector; -+ int offset_in_elements = int(offset % kElementsPerVector); -+ -+ ptr_ += offset_in_vectors; -+ offset_ = offset_in_elements; -+ -+ return *this; -+ } -+ -+ /// Adds an offset in units of elements to the reference -+ CUTLASS_HOST_DEVICE -+ ConstSubbyteReference &operator-=(int offset) { -+ -+ int offset_in_vectors = offset / kElementsPerVector; -+ int offset_in_elements = offset % kElementsPerVector; -+ -+ ptr_ -= offset_in_vectors; -+ offset_ -= offset_in_elements; -+ -+ if (offset_ < 0) { -+ offset_ += kElementsPerVector; -+ --ptr_; -+ } -+ -+ return *this; -+ } -+ -+ /// Adds an offset in units of elements to the reference -+ CUTLASS_HOST_DEVICE -+ ConstSubbyteReference &operator-=(long long offset) { -+ -+ long long offset_in_vectors = offset / kElementsPerVector; -+ int offset_in_elements = int(offset % kElementsPerVector); -+ -+ ptr_ -= offset_in_vectors; -+ offset_ -= offset_in_elements; -+ -+ if (offset_ < 0) { -+ offset_ += kElementsPerVector; -+ --ptr_; -+ } -+ -+ return *this; -+ } -+ -+ /// Returns a reference to an element with a given offset from the current reference -+ CUTLASS_HOST_DEVICE -+ ConstSubbyteReference operator+(int offset) const { -+ -+ ConstSubbyteReference ref(ptr_, offset_); -+ ref += offset; -+ -+ return ref; -+ } -+ -+ /// Returns a reference to an element with a given offset from the current reference -+ CUTLASS_HOST_DEVICE -+ ConstSubbyteReference operator+(long long offset) const { -+ -+ ConstSubbyteReference ref(ptr_, offset_); -+ ref += offset; -+ -+ return ref; -+ } -+ -+ /// Returns a reference to an element with a given offset from the current reference -+ 
CUTLASS_HOST_DEVICE -+ ConstSubbyteReference operator-(int offset) const { -+ -+ ConstSubbyteReference ref(ptr_, offset_); -+ ref -= offset; -+ -+ return ref; -+ } -+ -+ /// Returns a reference to an element with a given offset from the current reference -+ CUTLASS_HOST_DEVICE -+ ConstSubbyteReference operator-=(long long offset) const { -+ -+ ConstSubbyteReference ref(ptr_, offset_); -+ ref -= offset; -+ -+ return ref; -+ } -+ -+ /// Computes the difference in elements between references -+ CUTLASS_HOST_DEVICE -+ ptrdiff_t operator-(ConstSubbyteReference ref) const { -+ return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_); -+ } -+ -+ /// Explicit cast to int -+ CUTLASS_HOST_DEVICE -+ explicit operator int() const { -+ return int(get()); -+ } -+ -+ /// Explicit cast to signed 64-bit integer -+ CUTLASS_HOST_DEVICE -+ explicit operator int64_t() const { -+ return int64_t(get()); -+ } -+ -+ /// Explicit cast to unsigned 64-bit integer -+ CUTLASS_HOST_DEVICE -+ explicit operator uint64_t() const { -+ return uint64_t(get()); -+ } -+ -+ /// Explicit cast to float -+ CUTLASS_HOST_DEVICE -+ explicit operator float() const { -+ return float(get()); -+ } -+ -+ /// Explicit cast to double -+ CUTLASS_HOST_DEVICE -+ explicit operator double() const { -+ return double(get()); -+ } -+}; -+ -+template < -+ typename Element_, /// CUTLASS numeric element type. -+ typename Storage_ = /// Underlying storage type. Must be able to hold an integer -+ /// number of objects of type Element. -+ -+#if defined(__CUDA_ARCH__) /// Default size depends on width of atomicCas() overloads. -+ #if (__CUDA_ARCH__ >= 700) /// -+ uint16_t -+ #else -+ uint32_t -+ #endif -+#else -+ uint8_t -+#endif -+> -+class SubbyteReference { -+public: -+ -+ using Element = Element_; -+ using Storage = Storage_; -+ using StoragePointer = Storage *; -+ -+ static_assert(sizeof_bits::value <= sizeof_bits::value, -+ "Size of Element must not be greater than Storage."); -+ -+ static_assert(!(sizeof_bits::value % sizeof_bits::value), -+ "Storage must be divisible by Element"); -+ -+private: -+ -+ ///! Number of elements per storage vector -+ int const kElementsPerVector = sizeof_bits::value / sizeof_bits::value; -+ -+ ///! Bit mask -+ Storage const kMask = -+ ((sizeof_bits::value < sizeof_bits::value) ? -+ (Storage(1) << sizeof_bits::value) - Storage(1) : -+ ~Storage(0)); -+ -+private: -+ -+ /// Pointer to array containing element -+ StoragePointer ptr_; -+ -+ /// Offset (in units of elements) from pointer. 
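A minimal host-side sketch of reading packed 4-bit values through the ConstSubbyteReference above; the mutable SubbyteReference defined next adds set()/operator= on the same packing scheme (illustrative only, not part of the patch; the helper name and buffer convention are assumptions):

    #include "cutlass/numeric_types.h"
    #include "cutlass/subbyte_reference.h"

    // Sums `count` int4b_t values packed two-per-byte. ConstSubbyteReference
    // unpacks the addressed nibble on each read.
    int sum_packed_int4(void const *packed, int count) {
      cutlass::int4b_t const *ptr = reinterpret_cast<cutlass::int4b_t const *>(packed);
      int sum = 0;
      for (int i = 0; i < count; ++i) {
        cutlass::ConstSubbyteReference<cutlass::int4b_t> ref(ptr, i);
        sum += int(ref);   // explicit operator int() converts the unpacked element
      }
      return sum;
    }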
-+ /// -+ /// Invariant: must always be in range [0, kElementsPerVector) -+ int offset_; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ SubbyteReference(): ptr_(nullptr), offset_(0) { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ SubbyteReference( -+ Element *ptr, /// pointer to memory -+ int64_t offset /// logical offset in units of Element -+ ): -+ ptr_(reinterpret_cast(ptr)), -+ offset_(0) { -+ -+ int64_t offset_in_vectors = offset / kElementsPerVector; -+ int64_t offset_in_elements = offset % kElementsPerVector; -+ -+ ptr_ += offset_in_vectors; -+ offset_ = int(offset_in_elements); -+ } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ SubbyteReference( -+ Element *ptr = nullptr -+ ): SubbyteReference(ptr, 0) { } -+ -+ /// Gets storage pointer -+ CUTLASS_HOST_DEVICE -+ StoragePointer storage_pointer() const { -+ return ptr_; -+ } -+ -+ /// Gets storage pointer -+ CUTLASS_HOST_DEVICE -+ Element * operator&() const { -+ return reinterpret_cast(ptr_); -+ } -+ -+ /// Gets element offset within storage vector -+ CUTLASS_HOST_DEVICE -+ int element_offset() const { -+ return offset_; -+ } -+ -+ /// Unpacks an element from memory -+ CUTLASS_HOST_DEVICE -+ Element get() const { -+ Storage item = Storage((*ptr_ >> (offset_ * sizeof_bits::value)) & kMask); -+ return reinterpret_cast(item); -+ } -+ -+ /// Stores an element to memory -+ CUTLASS_HOST_DEVICE -+ SubbyteReference & set(Element const &x) { -+ -+ Storage item = (reinterpret_cast(x) & kMask); -+ Storage kUpdateMask = Storage(~(kMask << (offset_ * cutlass::sizeof_bits::value))); -+ Storage new_bits = Storage(item << (offset_ * cutlass::sizeof_bits::value)); -+ -+#if defined(__CUDA_ARCH__) -+ -+ // -+ // Homebrew read-modify-write -+ // -+ Storage original; -+ Storage updated; -+ -+ do { -+ -+ original = (*ptr_); -+ -+ updated = Storage((original & kUpdateMask) | new_bits); -+ -+ original = atomicCAS(ptr_, original, updated); -+ -+ } while (updated != original); -+ -+#else -+ -+ Storage original = (*ptr_); -+ Storage updated = Storage((original & kUpdateMask) | new_bits); -+ *ptr_ = updated; -+ -+#endif -+ -+ return *this; -+ } -+ -+ //// -+ -+ /// Unpacks an element from memory -+ CUTLASS_HOST_DEVICE -+ operator Element() const { -+ return get(); -+ } -+ -+ /// Stores an element to memory -+ CUTLASS_HOST_DEVICE -+ SubbyteReference &operator=(Element const & x) { -+ return set(x); -+ } -+ -+ /// Stores an element to memory -+ CUTLASS_HOST_DEVICE -+ SubbyteReference &operator=(SubbyteReference const & x) { -+ return set(x.get()); -+ } -+ -+ /// Stores an element to memory -+ CUTLASS_HOST_DEVICE -+ SubbyteReference &operator=( -+ ConstSubbyteReference const &x) { -+ return set(x.get()); -+ } -+ -+ /// Adds an offset in units of elements to the reference -+ CUTLASS_HOST_DEVICE -+ SubbyteReference &operator+=(int offset) { -+ -+ offset += offset_; -+ -+ int offset_in_vectors = offset / kElementsPerVector; -+ int offset_in_elements = offset % kElementsPerVector; -+ -+ ptr_ += offset_in_vectors; -+ offset_ = offset_in_elements; -+ -+ return *this; -+ } -+ -+ /// Adds an offset in units of elements to the reference -+ CUTLASS_HOST_DEVICE -+ SubbyteReference &operator+=(long long offset) { -+ -+ offset += offset_; -+ -+ long long offset_in_vectors = offset / kElementsPerVector; -+ int offset_in_elements = int(offset % kElementsPerVector); -+ -+ ptr_ += offset_in_vectors; -+ offset_ = offset_in_elements; -+ -+ return *this; -+ } -+ -+ /// Adds an offset in units of elements to the reference -+ CUTLASS_HOST_DEVICE -+ SubbyteReference &operator-=(int 
offset) { -+ -+ int offset_in_vectors = offset / kElementsPerVector; -+ int offset_in_elements = offset % kElementsPerVector; -+ -+ ptr_ -= offset_in_vectors; -+ offset_ -= offset_in_elements; -+ -+ if (offset_ < 0) { -+ offset_ += kElementsPerVector; -+ --ptr_; -+ } -+ -+ return *this; -+ } -+ -+ /// Adds an offset in units of elements to the reference -+ CUTLASS_HOST_DEVICE -+ SubbyteReference &operator-=(long long offset) { -+ -+ long long offset_in_vectors = offset / kElementsPerVector; -+ int offset_in_elements = int(offset % kElementsPerVector); -+ -+ ptr_ -= offset_in_vectors; -+ offset_ -= offset_in_elements; -+ -+ if (offset_ < 0) { -+ offset_ += kElementsPerVector; -+ --ptr_; -+ } -+ -+ return *this; -+ } -+ -+ /// Returns a reference to an element with a given offset from the current reference -+ CUTLASS_HOST_DEVICE -+ SubbyteReference operator+(int offset) const { -+ -+ SubbyteReference ref(ptr_, offset_); -+ ref += offset; -+ -+ return ref; -+ } -+ -+ /// Returns a reference to an element with a given offset from the current reference -+ CUTLASS_HOST_DEVICE -+ SubbyteReference operator+(long long offset) const { -+ -+ SubbyteReference ref(ptr_, offset_); -+ ref += offset; -+ -+ return ref; -+ } -+ -+ /// Returns a reference to an element with a given offset from the current reference -+ CUTLASS_HOST_DEVICE -+ SubbyteReference operator-(int offset) const { -+ -+ SubbyteReference ref(ptr_, offset_); -+ ref -= offset; -+ -+ return ref; -+ } -+ -+ /// Returns a reference to an element with a given offset from the current reference -+ CUTLASS_HOST_DEVICE -+ SubbyteReference operator-=(long long offset) const { -+ -+ SubbyteReference ref(ptr_, offset_); -+ ref -= offset; -+ -+ return ref; -+ } -+ -+ /// Computes the difference in elements between references -+ CUTLASS_HOST_DEVICE -+ ptrdiff_t operator-(SubbyteReference ref) const { -+ return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_); -+ } -+ -+ /// Explicit cast to int -+ CUTLASS_HOST_DEVICE -+ explicit operator int() const { -+ return int(get()); -+ } -+ -+ /// Explicit cast to signed 64-bit integer -+ CUTLASS_HOST_DEVICE -+ explicit operator int64_t() const { -+ return int64_t(get()); -+ } -+ -+ /// Explicit cast to unsigned 64-bit integer -+ CUTLASS_HOST_DEVICE -+ explicit operator uint64_t() const { -+ return uint64_t(get()); -+ } -+ -+ /// Explicit cast to float -+ CUTLASS_HOST_DEVICE -+ explicit operator float() const { -+ return float(get()); -+ } -+ -+ /// Explicit cast to double -+ CUTLASS_HOST_DEVICE -+ explicit operator double() const { -+ return double(get()); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template ::value < 8)> -+struct ReferenceFactory; -+ -+template -+struct ReferenceFactory { -+ CUTLASS_HOST_DEVICE -+ static Element &get(Element *ptr, int64_t offset) { -+ return ptr[offset]; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Element const &get(Element const *ptr, int64_t offset) { -+ return ptr[offset]; -+ } -+}; -+ -+template -+struct ReferenceFactory { -+ CUTLASS_HOST_DEVICE -+ static SubbyteReference get(Element *ptr, int64_t offset) { -+ return SubbyteReference(ptr, offset); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static ConstSubbyteReference get(Element const *ptr, -+ int64_t offset) { -+ return ConstSubbyteReference(ptr, offset); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git 
a/3rdparty/cutlass/include/cutlass/tensor_coord.h b/3rdparty/cutlass/include/cutlass/tensor_coord.h -new file mode 100644 -index 0000000..d3a7b32 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/tensor_coord.h -@@ -0,0 +1,326 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines a canonical coordinate for rank=4 tensors offering named indices. -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a canonical 4D coordinate used by tensor operations. -+struct Tensor4DCoord : public Coord<4> { -+ -+ /// Base class -+ using Base = Coord<4>; -+ -+ /// Index type -+ using Index = typename Base::Index; -+ -+ /// LongIndex type -+ using LongIndex = typename Base::LongIndex; -+ -+ /// Batch dimension -+ static int const kN = 0; -+ -+ /// Height dimension -+ static int const kH = 1; -+ -+ /// Width dimension -+ static int const kW = 2; -+ -+ /// Channels dimension -+ static int const kC = 3; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord() { } -+ -+ /// Constructs from Coord<4> -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord(Coord<4> const &coord): Base(coord) { } -+ -+ /// Helper to construct from N, H, W, and C. 
-+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord(Index n, Index h, Index w, Index c): Base(make_Coord(n, h, w, c)) { } -+ -+ /// Helper to construct from N, H, W, and C, which are LongIndex type -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord(LongIndex n, LongIndex h, LongIndex w, LongIndex c) -+ : Base(make_Coord(Index(n), Index(h), Index(w), Index(c))) { } -+ -+ /// Returns the batch of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & n() const { return this->at(kN); } -+ -+ /// Returns the batch of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & n() { return this->at(kN); } -+ -+ /// Returns the row of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & h() const { return this->at(kH); } -+ -+ /// Returns the row of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & h() { return this->at(kH); } -+ -+ /// Returns the column of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & w() const { return this->at(kW); } -+ -+ /// Returns the column of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & w() { return this->at(kW); } -+ -+ /// Returns the channel of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & c() const { return this->at(kC); } -+ -+ /// Returns the channel of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & c() { return this->at(kC); } -+ -+ // -+ // Coord operators -+ // -+ -+ /// Element-wise addition -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord operator+(Base const& b) const { -+ return Tensor4DCoord(Base::operator+(b)); -+ } -+ -+ /// Element-wise subtraction -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord operator-(Base const& b) const { -+ return Tensor4DCoord(Base::operator-(b)); -+ } -+ -+ /// Element-wise multiplication -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord operator*(Base const& b) const { -+ return Tensor4DCoord(Base::operator*(b)); -+ } -+ -+ /// Element-wise division -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord operator/(Base const& b) const { -+ return Tensor4DCoord(Base::operator/(b)); -+ } -+ -+ /// In-place addition -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord& operator+=(Base const& b) { -+ Base::operator+=(b); -+ return *this; -+ } -+ -+ /// In-place subtraction -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord& operator-=(Base const& b) { -+ Base::operator-=(b); -+ return *this; -+ } -+ -+ /// In-place multiplication -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord& operator*=(Base const& b) { -+ Base::operator*=(b); -+ return *this; -+ } -+ -+ /// In-place division -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord& operator/=(Base const& b) { -+ Base::operator/=(b); -+ return *this; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a canonical 5D coordinate used by tensor operations. -+struct Tensor5DCoord : public Coord<5> { -+ -+ /// Base class -+ using Base = Coord<5>; -+ -+ /// Index type -+ using Index = typename Base::Index; -+ -+ /// LongIndex type -+ using LongIndex = typename Base::LongIndex; -+ -+ /// Batch dimension -+ static int const kN = 0; -+ -+ /// Depth dimension -+ static int const kD = 1; -+ -+ /// Height dimension -+ static int const kH = 2; -+ -+ /// Width dimension -+ static int const kW = 3; -+ -+ /// Channels dimension -+ static int const kC = 4; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord() { } -+ -+ /// Constructs from Coord<5> -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord(Coord<5> const &coord): Base(coord) { } -+ -+ /// Helper to construct from N, D, H, W, and C. 
-+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord(Index n, Index d, Index h, Index w, Index c): Base(make_Coord(n, d, h, w, c)) { } -+ -+ /// Helper to construct from N, D, H, W, and C, which are LongIndex type -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord(LongIndex n, LongIndex d, LongIndex h, LongIndex w, LongIndex c) -+ : Base(make_Coord(Index(n), Index(d), Index(h), Index(w), Index(c))) { } -+ -+ /// Returns the batch of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & n() const { return this->at(kN); } -+ -+ /// Returns the batch of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & n() { return this->at(kN); } -+ -+ /// Returns the batch of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & d() const { return this->at(kD); } -+ -+ /// Returns the batch of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & d() { return this->at(kD); } -+ -+ /// Returns the row of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & h() const { return this->at(kH); } -+ -+ /// Returns the row of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & h() { return this->at(kH); } -+ -+ /// Returns the column of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & w() const { return this->at(kW); } -+ -+ /// Returns the column of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & w() { return this->at(kW); } -+ -+ /// Returns the channel of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & c() const { return this->at(kC); } -+ -+ /// Returns the channel of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & c() { return this->at(kC); } -+ -+ // -+ // Coord operators -+ // -+ -+ /// Element-wise addition -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord operator+(Base const& b) const { -+ return Tensor5DCoord(Base::operator+(b)); -+ } -+ -+ /// Element-wise subtraction -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord operator-(Base const& b) const { -+ return Tensor5DCoord(Base::operator-(b)); -+ } -+ -+ /// Element-wise multiplication -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord operator*(Base const& b) const { -+ return Tensor5DCoord(Base::operator*(b)); -+ } -+ -+ /// Element-wise division -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord operator/(Base const& b) const { -+ return Tensor5DCoord(Base::operator/(b)); -+ } -+ -+ /// In-place addition -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord& operator+=(Base const& b) { -+ Base::operator+=(b); -+ return *this; -+ } -+ -+ /// In-place subtraction -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord& operator-=(Base const& b) { -+ Base::operator-=(b); -+ return *this; -+ } -+ -+ /// In-place multiplication -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord& operator*=(Base const& b) { -+ Base::operator*=(b); -+ return *this; -+ } -+ -+ /// In-place division -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord& operator/=(Base const& b) { -+ Base::operator/=(b); -+ return *this; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/tensor_ref.h b/3rdparty/cutlass/include/cutlass/tensor_ref.h -new file mode 100644 -index 0000000..ce2505e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/tensor_ref.h -@@ -0,0 +1,418 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
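A small sketch of the Tensor4DCoord defined above; Tensor5DCoord is analogous with an extra depth index (illustrative only, not part of the patch; the NHWC extent values are arbitrary):

    #include "cutlass/tensor_coord.h"

    // Describes an NHWC activation extent and applies an element-wise offset via
    // the Coord<4> base; named accessors map to the kN/kH/kW/kC indices above.
    cutlass::Tensor4DCoord cropped_extent() {
      cutlass::Tensor4DCoord extent(8, 224, 224, 64);    // N, H, W, C
      cutlass::Tensor4DCoord border(0, 2, 2, 0);         // trim two rows and columns
      cutlass::Tensor4DCoord result = extent - border;   // element-wise subtraction
      int channels = result.c();                         // named accessor, == 64
      (void)channels;
      return result;                                     // (8, 222, 222, 64)
    }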
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines a structure containing strides, bounds, and a pointer to tensor data. -+*/ -+#pragma once -+ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+#include "cutlass/platform/platform.h" -+#include "cutlass/subbyte_reference.h" -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Default layout function from coordinates in a tensor's index space into the n-D array held -+/// in memory. -+/// -+/// All layout functions must define at least the members shown in IdentityTensorLayout<>. 
-+template -+class IdentityTensorLayout { -+public: -+ /// Logical rank of tensor -+ static int const kRank = Rank; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = Rank; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = Coord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ IdentityTensorLayout(Stride const &stride = Stride()): stride_(stride) { } -+ -+ /// Returns the offset of a coordinate in linear memory -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(Coord const &coord) const { -+ return coord.dot(stride_); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &size) const { -+ int idx = stride_.max_dim_index(); -+ return stride_[idx] * size[idx]; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/* \brief TensorRef is a template for objects pointing to the start of tensors of arbitrary rank -+ and layout within memory. A TensorRef combines a pointer and a Layout concept -+ -+ Examples: -+ -+ (These examples use helpers for matrix layouts defined in cutlass/layout/matrix.h) -+ -+ 1. Column-major matrix may be represented as a rank=2 tensor: -+ -+ TensorRef A(ptr_A, ldm); -+ -+ 2. Row-major matrix may be represented as a rank=2 tensor: -+ -+ TensorRef B(ptr_A, ldm); -+ -+ 3. An interleaved matrix may be represented as a rank=2 tensor: -+ -+ TensorRef > C; -+ -+ 4. A helper exists to define a TensorRef for a contiguous matrix whose layout -+ is not known at compile time. 
-+ -+ int ldm; // leading dimension -+ layout::Matrix kind; // Could be layout::Matrix::kRowMajor or layout::Matrix::kColumnMajor -+ -+ -+ TensorRef E(ptr_E, {ldm, kind}); -+ -+*/ -+template < -+ /// Data type of element stored within tensor (concept: NumericType) -+ typename Element_, -+ /// Defines a mapping from logical coordinate to linear memory (concept: Layout) -+ typename Layout_ -+> -+class TensorRef { -+ public: -+ /// Data type of individual access -+ using Element = Element_; -+ -+ /// Mapping function from logical coordinate to linear memory -+ using Layout = Layout_; -+ -+ /// Reference type to an element -+ using Reference = typename platform::conditional< -+ sizeof_bits::value >= 8, -+ Element &, -+ SubbyteReference -+ >::type; -+ -+ /// Logical rank of tensor index space -+ static int const kRank = Layout::kRank; -+ -+ /// Index type -+ using Index = typename Layout::Index; -+ -+ /// Long index used for pointer offsets -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// Coordinate in logical tensor space -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ /// Layout's stride vector -+ using Stride = typename Layout::Stride; -+ -+ /// TensorRef to constant data -+ using ConstTensorRef = TensorRef< -+ typename platform::remove_const::type const, -+ Layout>; -+ -+ /// TensorRef to non-constant data -+ using NonConstTensorRef = TensorRef< -+ typename platform::remove_const::type, -+ Layout>; -+ -+ /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a -+ /// scalar, but degenerate cases such as these are difficult to accommodate without -+ /// extensive C++ metaprogramming or support for zero-length arrays. -+ static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); -+ -+ private: -+ -+ /// Pointer -+ Element* ptr_; -+ -+ /// Layout object maps logical coordinates to linear offsets -+ Layout layout_; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a TensorRef with a pointer and layout object. -+ CUTLASS_HOST_DEVICE -+ TensorRef(): ptr_(nullptr) { -+ -+ } -+ -+ /// Constructs a TensorRef with a pointer and layout object. -+ CUTLASS_HOST_DEVICE -+ TensorRef( -+ Element *ptr, ///< pointer to start of tensor -+ Layout const &layout ///< layout object containing stride and mapping function -+ ): -+ ptr_(ptr), layout_(layout) { -+ -+ } -+ -+ /// Converting constructor from TensorRef to non-constant data. -+ template -+ CUTLASS_HOST_DEVICE -+ TensorRef( -+ NonConstTensorRef const &ref, ///< TensorRef to non-const data -+ ///SFINAE trick to avoid creating a copy-constructor when Element_ is already non-const -+ _Magic magic = (typename platform::enable_if< ! platform::is_same >::value, _Magic>::type)0 -+ ): -+ ptr_(ref.data()), layout_(ref.layout()) { } -+ -+ /// Returns a reference to constant-valued tensor. 
-+ CUTLASS_HOST_DEVICE -+ ConstTensorRef const_ref() const { -+ return ConstTensorRef(ptr_, layout_); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ NonConstTensorRef non_const_ref() const { -+ return NonConstTensorRef(const_cast::type *>(ptr_), layout_); -+ } -+ -+ /// Updates only the pointer -+ CUTLASS_HOST_DEVICE -+ void reset(Element* ptr = nullptr) { -+ ptr_ = ptr; -+ } -+ -+ /// Updates the pointer and layout object -+ CUTLASS_HOST_DEVICE -+ void reset(Element* ptr, Layout const &layout) { -+ ptr_ = ptr; -+ layout_ = layout; -+ } -+ -+ /// Returns true if the TensorRef is non-null -+ CUTLASS_HOST_DEVICE -+ bool good() const { -+ return ptr_ != nullptr; -+ } -+ -+ /// Returns the pointer to referenced data -+ CUTLASS_HOST_DEVICE -+ Element * data() const { return ptr_; } -+ -+ /// Returns a reference to the element at a given linear index -+ CUTLASS_HOST_DEVICE -+ Reference data(LongIndex idx) const { -+ return ReferenceFactory::type, -+ (sizeof_bits::value < 8)>::get(ptr_, idx); -+ } -+ -+ /// Returns the layout object -+ CUTLASS_HOST_DEVICE -+ Layout & layout() { -+ return layout_; -+ } -+ -+ /// Returns the layout object -+ CUTLASS_HOST_DEVICE -+ Layout layout() const { -+ return layout_; -+ } -+ -+ /// Returns the layout object's stride vector -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the layout object's stride vector -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Returns the layout object's stride in a given physical dimension -+ CUTLASS_HOST_DEVICE -+ typename Layout::Stride::Index stride(int dim) const { -+ return layout_.stride().at(dim); -+ } -+ -+ /// Returns the layout object's stride in a given physical dimension -+ CUTLASS_HOST_DEVICE -+ typename Layout::Stride::Index & stride(int dim) { -+ return layout_.stride().at(dim); -+ } -+ -+ /// Computes the offset of an index from the origin of the tensor -+ CUTLASS_HOST_DEVICE -+ LongIndex offset(TensorCoord const& coord) const { -+ return layout_(coord); -+ } -+ -+ /// Returns a reference to the element at a given Coord -+ CUTLASS_HOST_DEVICE -+ Reference at(TensorCoord const& coord) const { -+ return data(offset(coord)); -+ } -+ -+ /// Returns a reference to the element at a given Coord -+ CUTLASS_HOST_DEVICE -+ Reference operator[](TensorCoord const& coord) const { -+ return data(offset(coord)); -+ } -+ -+ /// Adds an offset to each pointer -+ CUTLASS_HOST_DEVICE -+ TensorRef & add_pointer_offset(LongIndex offset_) { -+ ptr_ += offset_; -+ return *this; -+ } -+ -+ /// Adds an offset to each pointer -+ CUTLASS_HOST_DEVICE -+ TensorRef & add_coord_offset(TensorCoord const &coord) { -+ add_pointer_offset(offset(coord)); -+ return *this; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorRef operator+(TensorCoord const& b) const { -+ TensorRef result(*this); -+ result.add_coord_offset(b); -+ return result; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorRef & operator+=(TensorCoord const& b) { -+ add_coord_offset(b); -+ return *this; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorRef operator-(TensorCoord const& b) const { -+ TensorRef result(*this); -+ result.add_pointer_offset(-offset(b)); -+ return result; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorRef & operator-=(TensorCoord const& b) { -+ add_pointer_offset(-offset(b)); -+ return *this; -+ } -+}; -+ -+/// 
Constructs a TensorRef, deducing types from arguments. -+template < -+ typename Element, -+ typename Layout -+> -+CUTLASS_HOST_DEVICE -+TensorRef make_TensorRef(Element *ptr, Layout const &layout) { -+ return TensorRef(ptr, layout); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations to handle degenerate and sub-byte cases. -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Element, -+ typename Layout -+> -+CUTLASS_HOST_DEVICE -+bool TensorRef_aligned(TensorRef const &ref, int alignment) { -+ -+ int const kStrideRank = Layout::kStrideRank; -+ -+ if (reinterpret_cast(ref.data()) % alignment) { -+ return false; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kStrideRank; ++i) { -+ if (ref.stride(i) % alignment) { -+ return false; -+ } -+ } -+ -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/tensor_ref_planar_complex.h b/3rdparty/cutlass/include/cutlass/tensor_ref_planar_complex.h -new file mode 100644 -index 0000000..a0131fd ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/tensor_ref_planar_complex.h -@@ -0,0 +1,374 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines a structure containing strides, bounds, and a pointer to tensor data. 
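A sketch of the rank-2 TensorRef usage described by the comment block above, assuming the matrix layout helpers in cutlass/layout/matrix.h (referenced there but not shown in this patch excerpt); illustrative only, not part of the patch:

    #include "cutlass/layout/matrix.h"   // ColumnMajor helper (assumed, per the comment above)
    #include "cutlass/tensor_ref.h"

    // Reads element (row, column) of a column-major matrix with leading dimension ldm.
    float read_element(float *ptr, int ldm, int row, int column) {
      cutlass::TensorRef<float, cutlass::layout::ColumnMajor> ref(
          ptr, cutlass::layout::ColumnMajor(ldm));
      // operator[] maps the logical coordinate through the layout to a linear
      // offset and returns a reference to the element.
      return ref[{row, column}];
    }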
-+*/ -+#pragma once -+ -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/tensor_ref.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct PlanarComplexReference { -+ -+ // -+ // Type definitions -+ // -+ -+ using Element = Element_; -+ using ComplexElement = complex; -+ -+ // -+ // Data members -+ // -+ -+ Element *real; -+ Element *imag; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ PlanarComplexReference( -+ Element *real_ = nullptr, -+ Element *imag_ = nullptr -+ ): -+ real(real_), imag(imag_) { } -+ -+ /// Loads the complex element -+ CUTLASS_HOST_DEVICE -+ operator complex() const { -+ return complex{*real, *imag}; -+ } -+ -+ /// Stores a complex element to the location pointed to by the reference -+ CUTLASS_HOST_DEVICE -+ PlanarComplexReference &operator=(complex const &rhs) { -+ *real = rhs.real(); -+ *imag = rhs.imag(); -+ return *this; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/* \brief TensorRef is a template for objects pointing to the start of tensors of arbitrary rank -+ and layout within memory. A TensorRef combines a pointer and a Layout concept -+ -+*/ -+template < -+ /// Data type of element stored within tensor (concept: NumericType) -+ typename Element_, -+ /// Defines a mapping from logical coordinate to linear memory (concept: Layout) -+ typename Layout_ -+> -+class TensorRefPlanarComplex { -+ public: -+ /// Data type of individual access -+ using Element = Element_; -+ -+ /// Complex element type -+ using ComplexElement = complex; -+ -+ /// Mapping function from logical coordinate to linear memory -+ using Layout = Layout_; -+ -+ static_assert(sizeof_bits::value >= 8, -+ "Planar complex not suitable for subbyte elements at this time"); -+ -+ /// Reference type to an element -+ using Reference = PlanarComplexReference; -+ -+ /// Logical rank of tensor index space -+ static int const kRank = Layout::kRank; -+ -+ /// Index type -+ using Index = typename Layout::Index; -+ -+ /// Long index used for pointer offsets -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// Coordinate in logical tensor space -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ /// Layout's stride vector -+ using Stride = typename Layout::Stride; -+ -+ /// TensorRef to constant data -+ using ConstTensorRef = TensorRefPlanarComplex< -+ typename platform::remove_const::type const, -+ Layout>; -+ -+ /// TensorRef to non-constant data -+ using NonConstTensorRef = TensorRefPlanarComplex< -+ typename platform::remove_const::type, -+ Layout>; -+ -+ /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a -+ /// scalar, but degenerate cases such as these are difficult to accommodate without -+ /// extensive C++ metaprogramming or support for zero-length arrays. -+ static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); -+ -+ private: -+ -+ /// Pointer -+ Element* ptr_; -+ -+ /// Layout object maps logical coordinates to linear offsets -+ Layout layout_; -+ -+ /// Offset to imaginary part -+ LongIndex imaginary_stride_; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a TensorRef with a pointer and layout object. 
-+ CUTLASS_HOST_DEVICE -+ TensorRefPlanarComplex( -+ Element *ptr = nullptr, ///< pointer to start of tensor -+ Layout const &layout = Layout(), ///< layout object containing stride and mapping function -+ LongIndex imaginary_stride = 0 -+ ): -+ ptr_(ptr), layout_(layout), imaginary_stride_(imaginary_stride) { -+ -+ } -+ -+ /// Converting constructor from TensorRef to non-constant data. -+ CUTLASS_HOST_DEVICE -+ TensorRefPlanarComplex( -+ NonConstTensorRef const &ref ///< TensorRef to non-const data -+ ): -+ ptr_(ref.data()), layout_(ref.layout()), imaginary_stride_(ref.imaginary_stride_) { } -+ -+ /// Returns a reference to constant-valued tensor. -+ CUTLASS_HOST_DEVICE -+ ConstTensorRef const_ref() const { -+ return ConstTensorRef(ptr_, layout_, imaginary_stride_); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ NonConstTensorRef non_const_ref() const { -+ return NonConstTensorRef( -+ const_cast::type *>(ptr_), -+ layout_, -+ imaginary_stride_); -+ } -+ -+ /// Updates only the pointer -+ CUTLASS_HOST_DEVICE -+ void reset(Element* ptr = nullptr, LongIndex imaginary_stride = 0) { -+ ptr_ = ptr; -+ imaginary_stride_ = imaginary_stride; -+ } -+ -+ /// Updates the pointer and layout object -+ CUTLASS_HOST_DEVICE -+ void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride) { -+ ptr_ = ptr; -+ layout_ = layout; -+ imaginary_stride_ = imaginary_stride; -+ } -+ -+ /// Returns true if the TensorRef is non-null -+ CUTLASS_HOST_DEVICE -+ bool good() const { -+ return ptr_ != nullptr; -+ } -+ -+ /// Returns the pointer to referenced data -+ CUTLASS_HOST_DEVICE -+ Element * data() const { return ptr_; } -+ -+ /// Returns the pointer to referenced data -+ CUTLASS_HOST_DEVICE -+ Element * imaginary_data() const { return ptr_ + imaginary_stride_; } -+ -+ /// Returns a reference to the element at a given linear index -+ CUTLASS_HOST_DEVICE -+ Reference data(LongIndex idx) const { -+ return Reference(ptr_ + idx, ptr_ + idx + imaginary_stride_); -+ } -+ -+ /// Returns the layout object -+ CUTLASS_HOST_DEVICE -+ Layout & layout() { -+ return layout_; -+ } -+ -+ /// Returns the layout object -+ CUTLASS_HOST_DEVICE -+ Layout layout() const { -+ return layout_; -+ } -+ -+ /// Gets the stride to an imaginary element -+ LongIndex imaginary_stride() const { -+ return imaginary_stride_; -+ } -+ -+ /// Gets the stride to an imaginary element -+ LongIndex &imaginary_stride() { -+ return imaginary_stride_; -+ } -+ -+ /// Returns the layout object's stride vector -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the layout object's stride vector -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Returns the layout object's stride in a given physical dimension -+ CUTLASS_HOST_DEVICE -+ Index stride(int dim) const { -+ return layout_.stride().at(dim); -+ } -+ -+ /// Returns the layout object's stride in a given physical dimension -+ CUTLASS_HOST_DEVICE -+ Index & stride(int dim) { -+ return layout_.stride().at(dim); -+ } -+ -+ /// Computes the offset of an index from the origin of the tensor -+ CUTLASS_HOST_DEVICE -+ LongIndex offset(TensorCoord const& coord) const { -+ return layout_(coord); -+ } -+ -+ /// Returns a reference to the element at a given Coord -+ CUTLASS_HOST_DEVICE -+ Reference at(TensorCoord const& coord) const { -+ return data(offset(coord)); -+ } -+ -+ /// Returns a reference to the element at a given Coord -+ CUTLASS_HOST_DEVICE -+ Reference operator[](TensorCoord const& coord) const { -+ return 
data(offset(coord)); -+ } -+ -+ /// Adds an offset to each pointer -+ CUTLASS_HOST_DEVICE -+ TensorRefPlanarComplex & add_pointer_offset(LongIndex offset_) { -+ ptr_ += offset_; -+ return *this; -+ } -+ -+ /// Adds an offset to each pointer -+ CUTLASS_HOST_DEVICE -+ TensorRefPlanarComplex & add_coord_offset(TensorCoord const &coord) { -+ add_pointer_offset(offset(coord)); -+ return *this; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorRefPlanarComplex operator+(TensorCoord const& b) const { -+ TensorRefPlanarComplex result(*this); -+ result.add_coord_offset(b); -+ return result; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorRefPlanarComplex & operator+=(TensorCoord const& b) { -+ add_coord_offset(b); -+ return *this; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorRefPlanarComplex operator-(TensorCoord const& b) const { -+ TensorRefPlanarComplex result(*this); -+ result.add_pointer_offset(-offset(b)); -+ return result; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorRefPlanarComplex & operator-=(TensorCoord const& b) { -+ add_pointer_offset(-offset(b)); -+ return *this; -+ } -+ -+ /// TensorRef to real-valued tensor -+ CUTLASS_HOST_DEVICE -+ cutlass::TensorRef ref_real() const { -+ return cutlass::TensorRef(data(), layout()); -+ } -+ -+ /// TensorRef to real-valued tensor -+ CUTLASS_HOST_DEVICE -+ cutlass::TensorRef ref_imag() const { -+ return cutlass::TensorRef(imaginary_data(), layout()); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Constructs a TensorRef, deducing types from arguments. -+template < -+ typename Element, -+ typename Layout -+> -+CUTLASS_HOST_DEVICE -+TensorRefPlanarComplex make_TensorRefPlanarComplex( -+ Element *ptr, -+ Layout const &layout, -+ int64_t imaginary_stride) { -+ -+ return TensorRefPlanarComplex(ptr, layout, imaginary_stride); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/tensor_view.h b/3rdparty/cutlass/include/cutlass/tensor_view.h -new file mode 100644 -index 0000000..9a4d238 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/tensor_view.h -@@ -0,0 +1,297 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
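A sketch of the planar-complex reference above, assuming the imaginary plane is stored one full plane after the real plane and assuming the row-major helper from cutlass/layout/matrix.h (illustrative only, not part of the patch):

    #include "cutlass/layout/matrix.h"
    #include "cutlass/tensor_ref_planar_complex.h"

    // Reads element (row, column) of a rows x columns planar-complex matrix whose
    // imaginary plane is stored imaginary_stride elements after the real plane.
    cutlass::complex<float> read_complex(float *base, int rows, int columns, int row, int column) {
      int64_t imaginary_stride = int64_t(rows) * columns;
      cutlass::TensorRefPlanarComplex<float, cutlass::layout::RowMajor> ref(
          base, cutlass::layout::RowMajor(columns), imaginary_stride);
      return ref.at({row, column});   // PlanarComplexReference loads *real and *imag
    }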
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines a structure containing strides and a pointer to tensor data. -+ -+ TensorView is derived from TensorRef and contributes bounds to the tensor's index space. Thus, -+ it is a complete mathematical object and may be used in tensor algorithms. It is decoupled from -+ data storage and is therefore lightweight and may be embedded in larger tensor objects or -+ memory structures. -+ -+ See cutlass/tensor_ref.h for more details about the mapping of the logical tensor index space to -+ linear memory. -+*/ -+ -+#pragma once -+ -+#if !defined(__CUDACC_RTC__) -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref.h" -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Data type of element stored within tensor -+ typename Element_, -+ /// Maps a Coord in the logical tensor index space to the internal n-D array -+ typename Layout_ -+> -+class TensorView : public TensorRef { -+ public: -+ -+ /// Base tensor reference -+ using Base = cutlass::TensorRef; -+ -+ /// Mapping function from logical coordinate to internal n-D array -+ using Layout = Layout_; -+ -+ /// TensorRef pointing to constant memory -+ using ConstTensorRef = typename Base::ConstTensorRef; -+ -+ /// Underlying TensorRef type -+ using TensorRef = Base; -+ -+ /// Data type of individual access -+ using Element = Element_; -+ -+ /// Reference type to an element -+ using Reference = Element &; -+ -+ /// Logical rank of tensor index space -+ static int const kRank = Layout::kRank; -+ -+ /// Index type -+ using Index = typename Layout::Index; -+ -+ /// Long index used for pointer offsets -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// Coordinate in logical tensor space -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ /// Coordinate in storage n-D array -+ using Stride = typename Layout::Stride; -+ -+ /// TensorView pointing to constant memory -+ using ConstTensorView = TensorView< -+ typename platform::remove_const::type const, -+ Layout>; -+ -+ /// TensorView pointing to non-constant memory -+ using NonConstTensorView = TensorView< -+ typename platform::remove_const::type, -+ Layout>; -+ -+ /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a -+ /// scalar, but degenerate cases such as these are difficult to accommodate without -+ /// extensive C++ metaprogramming or support for zero-length arrays. 
-+ static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); -+ -+ private: -+ -+ /// View extent -+ TensorCoord extent_; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a TensorView object -+ CUTLASS_HOST_DEVICE -+ TensorView() { } -+ -+ /// Constructs a TensorView object -+ CUTLASS_HOST_DEVICE -+ TensorView( -+ Element *ptr, ///< pointer to start of tensor -+ Layout const &layout, ///< layout object containing stride and mapping function -+ TensorCoord const &extent ///< size of the view in logical coordinates -+ ): -+ Base(ptr, layout), extent_(extent) { -+ -+ } -+ -+ /// Constructs a TensorView object -+ CUTLASS_HOST_DEVICE -+ TensorView( -+ TensorRef const &ref, ///< pointer and layout object referencing a tensor -+ TensorCoord const &extent ///< logical size of tensor -+ ): -+ Base(ref), extent_(extent) { -+ -+ } -+ -+ /// Converting constructor from TensorRef to non-constant data. -+ CUTLASS_HOST_DEVICE -+ TensorView( -+ NonConstTensorView const &view ///< TensorView to non-const data -+ ): -+ Base(view), extent_(view.extent_) { } -+ -+ /// Updates the pointer and layout object -+ CUTLASS_HOST_DEVICE -+ void reset(Element* ptr, Layout const &layout, TensorCoord const &extent) { -+ Base::reset(ptr, layout); -+ this->resize(extent); -+ } -+ -+ /// Updates the pointer -+ CUTLASS_HOST_DEVICE -+ void reset(Element* ptr) { -+ Base::reset(ptr); -+ } -+ -+ /// Changes the size of the view without affecting pointer or layout -+ CUTLASS_HOST_DEVICE -+ void resize(TensorCoord const &extent) { -+ this->extent_ = extent; -+ } -+ -+ /// Returns the extent of the view (the size along each logical dimension). -+ CUTLASS_HOST_DEVICE -+ TensorCoord const& extent() const { return extent_; } -+ -+ /// Returns the extent along a particular logical dimension. -+ CUTLASS_HOST_DEVICE -+ Index extent(int dim) const { return extent_.at(dim); } -+ -+ /// Returns the number of logical elements -+ CUTLASS_HOST_DEVICE -+ LongIndex size() const { -+ return extent_.product(); -+ } -+ -+ /// Determines whether a location is within a tensor -+ CUTLASS_HOST_DEVICE -+ bool contains(TensorCoord const& coord) const { -+ CUTLASS_PRAGMA_UNROLL -+ for (int dim = 0; dim < kRank; ++dim) { -+ if (!(coord[dim] >= 0 && coord[dim] < extent(dim))) { -+ return false; -+ } -+ } -+ return true; -+ } -+ -+ /// Returns a TensorRef pointing to the first element of the tensor. -+ CUTLASS_HOST_DEVICE -+ TensorRef ref() const { -+ return TensorRef(this->data(), this->layout()); -+ } -+ -+ /// Returns a TensorRef pointing to the first element of the tensor. -+ CUTLASS_HOST_DEVICE -+ ConstTensorRef const_ref() const { -+ return ConstTensorRef(this->data(), this->layout()); -+ } -+ -+ /// Returns a TensorView to const data -+ CUTLASS_HOST_DEVICE -+ ConstTensorView const_view() const { -+ return ConstTensorView(const_ref(), extent_); -+ } -+ -+ /// Returns a Tensor_view given location and size quantities -+ CUTLASS_HOST_DEVICE -+ TensorView subview( -+ TensorCoord extent, ///< extent of the resulting view -+ TensorCoord const& location = TensorCoord() ///< resulting view's origin within the old view -+ ) const { -+ -+ TensorView result(this->ref(), extent.clamp(extent_ - location)); -+ result.add_coord_offset(location); -+ return result; -+ } -+ -+ /// Returns the number of scalar elements needed to store tensor. 
-+ CUTLASS_HOST_DEVICE -+ size_t capacity() const { -+ return Base::layout().capacity(extent_); -+ } -+ -+ /// Returns a TensorView offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorView operator+( -+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor -+ ) const { -+ -+ TensorView result(*this); -+ result.add_pointer_offset(this->offset(b)); -+ return result; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorView& operator+=( -+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor -+ ) { -+ -+ this->add_pointer_offset(this->offset(b)); -+ return *this; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorView operator-( -+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor -+ ) const { -+ -+ TensorRef result(*this); -+ result.add_pointer_offset(-this->offset(b)); -+ return result; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorView& operator-=( -+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor -+ ) { -+ -+ this->add_pointer_offset(-this->offset(b)); -+ return *this; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Constructs a TensorRef, deducing types from arguments. -+template < -+ typename Element, -+ typename Layout -+> -+CUTLASS_HOST_DEVICE TensorView make_TensorView( -+ Element *ptr, -+ Layout const &layout, -+ typename Layout::TensorCoord const &extent) { -+ -+ return TensorView(ptr, layout, extent); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/tensor_view_planar_complex.h b/3rdparty/cutlass/include/cutlass/tensor_view_planar_complex.h -new file mode 100644 -index 0000000..6a66c6a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/tensor_view_planar_complex.h -@@ -0,0 +1,301 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines a structure containing strides and a pointer to tensor data. -+ -+ TensorView is derived from TensorRef and contributes bounds to the tensor's index space. Thus, -+ it is a complete mathematical object and may be used in tensor algorithms. It is decoupled from -+ data storage and is therefore lightweight and may be embedded in larger tensor objects or -+ memory structures. -+ -+ See cutlass/tensor_ref.h for more details about the mapping of the logical tensor index space to -+ linear memory. -+*/ -+ -+#pragma once -+ -+#if !defined(__CUDACC_RTC__) -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref_planar_complex.h" -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Data type of element stored within tensor -+ typename Element_, -+ /// Maps a Coord in the logical tensor index space to the internal n-D array -+ typename Layout_ -+> -+class TensorViewPlanarComplex : public TensorRefPlanarComplex { -+ public: -+ -+ /// Base tensor reference -+ using Base = cutlass::TensorRefPlanarComplex; -+ -+ /// Mapping function from logical coordinate to internal n-D array -+ using Layout = Layout_; -+ -+ /// TensorRef pointing to constant memory -+ using ConstTensorRef = typename Base::ConstTensorRef; -+ -+ /// Underlying TensorRef type -+ using TensorRef = Base; -+ -+ /// Data type of individual access -+ using Element = Element_; -+ -+ /// Reference type to an element -+ using Reference = Element &; -+ -+ /// Logical rank of tensor index space -+ static int const kRank = Layout::kRank; -+ -+ /// Index type -+ using Index = typename Layout::Index; -+ -+ /// Long index used for pointer offsets -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// Coordinate in logical tensor space -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ /// Coordinate in storage n-D array -+ using Stride = typename Layout::Stride; -+ -+ /// TensorView pointing to constant memory -+ using ConstTensorView = TensorViewPlanarComplex< -+ typename platform::remove_const::type const, -+ Layout>; -+ -+ /// TensorView pointing to non-constant memory -+ using NonConstTensorView = TensorViewPlanarComplex< -+ typename platform::remove_const::type, -+ Layout>; -+ -+ /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a -+ /// scalar, but degenerate cases such as these are difficult to accommodate without -+ /// extensive C++ metaprogramming or support for zero-length arrays. 
-+ static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); -+ -+ private: -+ -+ /// View extent -+ TensorCoord extent_; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a TensorView object -+ CUTLASS_HOST_DEVICE -+ TensorViewPlanarComplex(TensorCoord const &extent = TensorCoord()): extent_(extent) { -+ -+ } -+ -+ /// Constructs a TensorView object -+ CUTLASS_HOST_DEVICE -+ TensorViewPlanarComplex( -+ Element *ptr, ///< pointer to start of tensor -+ Layout const &layout, ///< layout object containing stride and mapping function -+ LongIndex imaginary_stride, ///< stride between real and imaginary part -+ TensorCoord const &extent ///< size of the view in logical coordinates -+ ): -+ Base(ptr, layout, imaginary_stride), extent_(extent) { -+ -+ } -+ -+ /// Constructs a TensorView object -+ CUTLASS_HOST_DEVICE -+ TensorViewPlanarComplex( -+ TensorRef const &ref, ///< pointer and layout object referencing a tensor -+ TensorCoord const &extent ///< logical size of tensor -+ ): -+ Base(ref), extent_(extent) { -+ -+ } -+ -+ /// Converting constructor from TensorRef to non-constant data. -+ CUTLASS_HOST_DEVICE -+ TensorViewPlanarComplex( -+ NonConstTensorView const &view ///< TensorView to non-const data -+ ): -+ Base(view), extent_(view.extent_) { } -+ -+ /// Updates the pointer and layout object -+ CUTLASS_HOST_DEVICE -+ void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride, TensorCoord size) { -+ Base::reset(ptr, layout, imaginary_stride); -+ this->resize(extent_); -+ } -+ -+ /// Changes the size of the view without affecting pointer or layout -+ CUTLASS_HOST_DEVICE -+ void resize(TensorCoord extent) { -+ this->extent_ = extent; -+ } -+ -+ /// Returns the extent of the view (the size along each logical dimension). -+ CUTLASS_HOST_DEVICE -+ TensorCoord const& extent() const { return extent_; } -+ -+ /// Returns the extent along a particular logical dimension. -+ CUTLASS_HOST_DEVICE -+ Index extent(int dim) const { return extent_.at(dim); } -+ -+ /// Determines whether a location is within a tensor -+ CUTLASS_HOST_DEVICE -+ bool contains(TensorCoord const& coord) const { -+ CUTLASS_PRAGMA_UNROLL -+ for (int dim = 0; dim < kRank; ++dim) { -+ if (!(coord[dim] >= 0 && coord[dim] < extent(dim))) { -+ return false; -+ } -+ } -+ return true; -+ } -+ -+ /// Returns a TensorRef pointing to the first element of the tensor. -+ CUTLASS_HOST_DEVICE -+ Base ref() const { -+ return Base(this->data(), this->layout(), this->imaginary_stride()); -+ } -+ -+ /// Returns a TensorRef pointing to the first element of the tensor. -+ CUTLASS_HOST_DEVICE -+ ConstTensorRef const_ref() const { -+ return ConstTensorRef(this->data(), this->layout()); -+ } -+ -+ /// Returns a TensorView to const data -+ CUTLASS_HOST_DEVICE -+ ConstTensorView const_view() const { -+ return ConstTensorView(const_ref(), extent_); -+ } -+ -+ /// Returns a Tensor_view given location and size quantities -+ CUTLASS_HOST_DEVICE -+ TensorViewPlanarComplex subview( -+ TensorCoord extent, ///< extent of the resulting view -+ TensorCoord const& location = TensorCoord() ///< resulting view's origin within the old view -+ ) const { -+ -+ TensorViewPlanarComplex result(this->ref(), extent.clamp(extent_ - location)); -+ result.add_coord_offset(location); -+ return result; -+ } -+ -+ /// Returns the number of scalar elements needed to store tensor. 
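
The planar-complex view above keeps the real and imaginary planes in one allocation, with the imaginary plane offset by imaginary_stride elements and view_real()/view_imag() exposing each plane separately. The PlanarComplex2D type below is a hypothetical standalone model of that storage scheme, not the CUTLASS class.

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Hypothetical model of planar-complex storage: the imaginary plane lives
    // `imaginary_stride` elements after the real plane.
    struct PlanarComplex2D {
      float*         data;              // start of the real plane
      std::ptrdiff_t imaginary_stride;  // element offset to the imaginary plane
      int            stride, rows, cols;  // row-major layout plus the view extent

      float& real(int r, int c) const { return data[r * stride + c]; }
      float& imag(int r, int c) const { return data[imaginary_stride + r * stride + c]; }
    };

    int main() {
      const int rows = 2, cols = 3, stride = 3;
      // One allocation holding both planes back to back.
      std::vector<float> storage(2 * rows * stride, 0.0f);
      PlanarComplex2D z{storage.data(), rows * stride, stride, rows, cols};

      z.real(1, 2) = 4.0f;  // element (1,2) = 4 + 5i
      z.imag(1, 2) = 5.0f;

      std::cout << z.real(1, 2) << " + " << z.imag(1, 2) << "i\n";  // "4 + 5i"
      return 0;
    }
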
-+ CUTLASS_HOST_DEVICE -+ size_t capacity() const { -+ return Base::layout().capacity(extent_); -+ } -+ -+ /// Returns a TensorView offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorViewPlanarComplex operator+( -+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor -+ ) const { -+ -+ TensorViewPlanarComplex result(*this); -+ result.add_pointer_offset(this->offset(b)); -+ return result; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorViewPlanarComplex& operator+=( -+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor -+ ) { -+ -+ this->add_pointer_offset(this->offset(b)); -+ return *this; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorViewPlanarComplex operator-( -+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor -+ ) const { -+ -+ TensorRef result(*this); -+ result.add_pointer_offset(-this->offset(b)); -+ return result; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorViewPlanarComplex& operator-=( -+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor -+ ) { -+ -+ this->add_pointer_offset(-this->offset(b)); -+ return *this; -+ } -+ -+ /// TensorRef to real-valued tensor -+ CUTLASS_HOST_DEVICE -+ cutlass::TensorView view_real() const { -+ return cutlass::TensorView(this->data(), this->layout(), extent_); -+ } -+ -+ /// TensorRef to real-valued tensor -+ CUTLASS_HOST_DEVICE -+ cutlass::TensorView view_imag() const { -+ return cutlass::TensorView(this->imaginary_data(), this->layout(), extent_); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Constructs a TensorRef, deducing types from arguments. -+template < -+ typename Element, -+ typename Layout -+> -+CUTLASS_HOST_DEVICE TensorViewPlanarComplex make_TensorViewPlanarComplex( -+ Element *ptr, -+ Layout const &layout, -+ typename Layout::LongIndex imaginary_stride, -+ typename Layout::TensorCoord const &extent) { -+ -+ return TensorViewPlanarComplex(ptr, layout, imaginary_stride, extent); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/tfloat32.h b/3rdparty/cutlass/include/cutlass/tfloat32.h -new file mode 100644 -index 0000000..76e2bf9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/tfloat32.h -@@ -0,0 +1,477 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Defines a proxy class for storing Tensor Float 32 data type. -+*/ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include "cutlass/floating_point_nvrtc.h" -+#else -+#include -+#include -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tensor Float 32 data type -+struct alignas(4) tfloat32_t { -+ -+ // -+ // Data members -+ // -+ -+ /// Storage type -+ uint32_t storage; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs from an unsigned int -+ CUTLASS_HOST_DEVICE -+ static tfloat32_t bitcast(uint32_t x) { -+ tfloat32_t h; -+ h.storage = x; -+ return h; -+ } -+ -+ /// Emulated rounding is fast in device code -+ CUTLASS_HOST_DEVICE -+ static tfloat32_t round_half_ulp_truncate(float const &s) { -+ uint32_t x = reinterpret_cast(s); -+ -+ #if defined(__CUDA_ARCH__) -+ if (::isfinite(s)) { -+ x += 0x1000u; -+ } -+ #else -+ if (std::isfinite(s)) { -+ x += 0x1000u; -+ } -+ #endif -+ -+ return tfloat32_t::bitcast(x); -+ } -+ -+ /// Default constructor -+ tfloat32_t() = default; -+ -+ /// Floating-point conversion - round toward nearest even -+ CUTLASS_HOST_DEVICE -+// explicit tfloat32_t(float x): storage(round_half_ulp_truncate(x).storage) { } -+ tfloat32_t(float x): storage(round_half_ulp_truncate(x).storage) { } -+ -+ /// Floating-point conversion - round toward nearest even -+ CUTLASS_HOST_DEVICE -+// explicit tfloat32_t(double x): tfloat32_t(float(x)) { -+ tfloat32_t(double x): tfloat32_t(float(x)) { -+ } -+ -+ /// Integer conversion - round toward zero -+ CUTLASS_HOST_DEVICE -+// explicit tfloat32_t(int x) { -+ tfloat32_t(int x) { -+ float flt = static_cast(x); -+ #if defined(__CUDA_ARCH__) -+ storage = reinterpret_cast(flt); -+ #else -+ std::memcpy(&storage, &flt, sizeof(storage)); -+ #endif -+ } -+ -+ /// Converts to float -+ CUTLASS_HOST_DEVICE -+ operator float() const { -+ -+ // Conversions to IEEE single-precision requires clearing dont-care bits -+ // of the mantissa. 
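
The conversion path above rounds a float to TF32 by adding half an ulp of the truncated format (0x1000) to its bit pattern when the value is finite, and converts back by clearing the 13 low mantissa bits (storage & ~0x1fffu). The standalone round-trip function below reproduces that bit manipulation outside of CUTLASS so the rounding behaviour can be checked on a couple of values; the helper names are made up for this example.

    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    static std::uint32_t float_bits(float f) {
      std::uint32_t u;
      std::memcpy(&u, &f, sizeof(u));
      return u;
    }

    static float bits_float(std::uint32_t u) {
      float f;
      std::memcpy(&f, &u, sizeof(f));
      return f;
    }

    // Mirrors round_half_ulp_truncate() followed by the conversion back to float:
    // add half of the 0x2000 ulp spacing, then drop the 13 low mantissa bits.
    static float round_trip_tf32(float s) {
      std::uint32_t x = float_bits(s);
      if (std::isfinite(s)) {
        x += 0x1000u;
      }
      return bits_float(x & ~0x1fffu);
    }

    int main() {
      float v = 1.0009765625f;   // 1 + 2^-10: representable, round-trips exactly
      float w = 1.00048828125f;  // 1 + 2^-11: rounds up to 1 + 2^-10
      std::printf("%.10f -> %.10f\n", v, round_trip_tf32(v));
      std::printf("%.10f -> %.10f\n", w, round_trip_tf32(w));
      return 0;
    }
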
-+ unsigned bits = (storage & ~0x1fffu); -+ -+ #if defined(__CUDA_ARCH__) -+ return reinterpret_cast(bits); -+ #else -+ float flt; -+ std::memcpy(&flt, &bits, sizeof(flt)); -+ return flt; -+ #endif -+ } -+ -+ /// Converts to float -+ CUTLASS_HOST_DEVICE -+ explicit operator double() const { -+ return double(float(*this)); -+ } -+ -+ /// Converts to int -+ CUTLASS_HOST_DEVICE -+ explicit operator int() const { -+ return int(float(*this)); -+ } -+ -+ /// Casts to bool -+ CUTLASS_HOST_DEVICE -+ explicit operator bool() const { -+ return (float(*this) != 0.0f); -+ } -+ -+ /// Obtains raw bits -+ CUTLASS_HOST_DEVICE -+ uint32_t raw() const { -+ return storage; -+ } -+ -+ /// Returns the sign bit -+ CUTLASS_HOST_DEVICE -+ bool signbit() const { -+ return ((raw() & 0x80000000) != 0); -+ } -+ -+ /// Returns the biased exponent -+ CUTLASS_HOST_DEVICE -+ int exponent_biased() const { -+ return int((raw() >> 23) & 0x0ff); -+ } -+ -+ /// Returns the unbiased exponent -+ CUTLASS_HOST_DEVICE -+ int exponent() const { -+ return exponent_biased() - 127; -+ } -+ -+ /// Returns the mantissa -+ CUTLASS_HOST_DEVICE -+ int mantissa() const { -+ return int(raw() & 0x7fffff); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_HOST_DEVICE -+bool signbit(cutlass::tfloat32_t const& h) { -+ return h.signbit(); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::tfloat32_t abs(cutlass::tfloat32_t const& h) { -+ return cutlass::tfloat32_t::bitcast(h.raw() & 0x7fffffff); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isnan(cutlass::tfloat32_t const& h) { -+ return (h.exponent_biased() == 0x0ff) && h.mantissa(); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isfinite(cutlass::tfloat32_t const& h) { -+ return (h.exponent_biased() != 0x0ff); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::tfloat32_t nan_tf32(const char*) { -+ // NVIDIA canonical NaN -+ return cutlass::tfloat32_t::bitcast(0x7fffffff); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isinf(cutlass::tfloat32_t const& h) { -+ return (h.exponent_biased() == 0x0ff) && !h.mantissa(); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isnormal(cutlass::tfloat32_t const& h) { -+ return h.exponent_biased() && h.exponent_biased() != 0x0ff; -+} -+ -+CUTLASS_HOST_DEVICE -+int fpclassify(cutlass::tfloat32_t const& h) { -+ int exp = h.exponent_biased(); -+ int mantissa = h.mantissa(); -+ if (exp == 0x0ff) { -+ if (mantissa) { -+ return FP_NAN; -+ } -+ else { -+ return FP_INFINITE; -+ } -+ } -+ else if (!exp) { -+ if (mantissa) { -+ return FP_SUBNORMAL; -+ } -+ else { -+ return FP_ZERO; -+ } -+ } -+ return FP_NORMAL; -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::tfloat32_t sqrt(cutlass::tfloat32_t const& h) { -+#if defined(__CUDACC_RTC__) -+ return cutlass::tfloat32_t(sqrtf(float(h))); -+#else -+ return cutlass::tfloat32_t(std::sqrt(float(h))); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t copysign(tfloat32_t const& a, tfloat32_t const& b) { -+ -+ uint32_t a_mag = (reinterpret_cast(a) & 0x7fffffff); -+ uint32_t b_sign = (reinterpret_cast(b) & 0x80000000); -+ uint32_t result = (a_mag | b_sign); -+ -+ return reinterpret_cast(result); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Standard Library operations and definitions -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace std { -+ -+#if 
!defined(__CUDACC_RTC__) -+/// Numeric limits -+template <> -+struct numeric_limits { -+ static bool const is_specialized = true; -+ static bool const is_signed = true; -+ static bool const is_integer = false; -+ static bool const is_exact = false; -+ static bool const has_infinity = true; -+ static bool const has_quiet_NaN = true; -+ static bool const has_signaling_NaN = false; -+ static std::float_denorm_style const has_denorm = std::denorm_present; -+ static bool const has_denorm_loss = true; -+ static std::float_round_style const round_style = std::round_to_nearest; -+ static bool const is_iec559 = false; -+ static bool const is_bounded = true; -+ static bool const is_modulo = false; -+ static int const digits = 19; -+ -+ /// Least positive value -+ static cutlass::tfloat32_t min() { return cutlass::tfloat32_t::bitcast(0x01); } -+ -+ /// Minimum finite value -+ static cutlass::tfloat32_t lowest() { return cutlass::tfloat32_t::bitcast(0xff7fffff); } -+ -+ /// Maximum finite value -+ static cutlass::tfloat32_t max() { return cutlass::tfloat32_t::bitcast(0x7f7fffff); } -+ -+ /// Returns smallest finite value -+ static cutlass::tfloat32_t epsilon() { return cutlass::tfloat32_t::bitcast(0x1000); } -+ -+ /// Returns smallest finite value -+ static cutlass::tfloat32_t round_error() { return cutlass::tfloat32_t(0.5f); } -+ -+ /// Returns smallest finite value -+ static cutlass::tfloat32_t infinity() { return cutlass::tfloat32_t::bitcast(0x7f800000); } -+ -+ /// Returns smallest finite value -+ static cutlass::tfloat32_t quiet_NaN() { return cutlass::tfloat32_t::bitcast(0x7fffffff); } -+ -+ /// Returns smallest finite value -+ static cutlass::tfloat32_t signaling_NaN() { return cutlass::tfloat32_t::bitcast(0x7fffffff); } -+ -+ /// Returns smallest finite value -+ static cutlass::tfloat32_t denorm_min() { return cutlass::tfloat32_t::bitcast(0x1); } -+}; -+#endif -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace std -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Arithmetic operators -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_HOST_DEVICE -+bool operator==(tfloat32_t const& lhs, tfloat32_t const& rhs) { -+ return float(lhs) == float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator!=(tfloat32_t const& lhs, tfloat32_t const& rhs) { -+ return float(lhs) != float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator<(tfloat32_t const& lhs, tfloat32_t const& rhs) { -+ return float(lhs) < float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator<=(tfloat32_t const& lhs, tfloat32_t const& rhs) { -+ return float(lhs) <= float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator>(tfloat32_t const& lhs, tfloat32_t const& rhs) { -+ return float(lhs) > float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator>=(tfloat32_t const& lhs, tfloat32_t const& rhs) { -+ return float(lhs) >= float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t operator+(tfloat32_t const& lhs, tfloat32_t const& rhs) { -+ return tfloat32_t(float(lhs) + float(rhs)); -+} -+ -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t operator-(tfloat32_t const& lhs) { -+ union u_tff32 { -+ float val_f32; -+ tfloat32_t val_tf; -+ CUTLASS_HOST_DEVICE u_tff32() : val_f32(0) { } -+ }; -+ union u_tff32 x; x.val_f32 = 
-reinterpret_cast(lhs); -+ return x.val_tf; -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t operator-(tfloat32_t const& lhs, tfloat32_t const& rhs) { -+ return tfloat32_t(float(lhs) - float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t operator*(tfloat32_t const& lhs, tfloat32_t const& rhs) { -+ return tfloat32_t(float(lhs) * float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t operator/(tfloat32_t const& lhs, tfloat32_t const& rhs) { -+ return tfloat32_t(float(lhs) / float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t& operator+=(tfloat32_t & lhs, tfloat32_t const& rhs) { -+ lhs = tfloat32_t(float(lhs) + float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t& operator-=(tfloat32_t & lhs, tfloat32_t const& rhs) { -+ lhs = tfloat32_t(float(lhs) - float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t& operator*=(tfloat32_t & lhs, tfloat32_t const& rhs) { -+ lhs = tfloat32_t(float(lhs) * float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t& operator/=(tfloat32_t & lhs, tfloat32_t const& rhs) { -+ lhs = tfloat32_t(float(lhs) / float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t& operator++(tfloat32_t & lhs) { -+ float tmp(lhs); -+ ++tmp; -+ lhs = tfloat32_t(tmp); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t& operator--(tfloat32_t & lhs) { -+ float tmp(lhs); -+ --tmp; -+ lhs = tfloat32_t(tmp); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t operator++(tfloat32_t & lhs, int) { -+ tfloat32_t ret(lhs); -+ float tmp(lhs); -+ tmp++; -+ lhs = tfloat32_t(tmp); -+ return ret; -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t operator--(tfloat32_t & lhs, int) { -+ tfloat32_t ret(lhs); -+ float tmp(lhs); -+ tmp--; -+ lhs = tfloat32_t(tmp); -+ return ret; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// User-defined literals -+// -+ -+CUTLASS_HOST_DEVICE -+cutlass::tfloat32_t operator "" _tf32(long double x) { -+ return cutlass::tfloat32_t(float(x)); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::tfloat32_t operator "" _tf32(unsigned long long int x) { -+ return cutlass::tfloat32_t(int(x)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/thread/matrix.h b/3rdparty/cutlass/include/cutlass/thread/matrix.h -new file mode 100644 -index 0000000..bc78cf8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/thread/matrix.h -@@ -0,0 +1,199 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines a matrix object intended for storing data in registers and operations within -+ a CUDA thread. -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/matrix_coord.h" -+ -+namespace cutlass { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Per-thread matrix object storing a packed matrix -+template < -+ typename Element, -+ int Rows, -+ int Columns, -+ typename Layout = layout::RowMajor -+> -+class Matrix : public Array { -+public: -+ -+ // Verify layout refers to a rank=2 matrix. -+ static_assert( -+ Layout::kRank == 2, -+ "Layout type must refer to a rank=2 matrix"); -+ -+ /// Base type -+ using Base = Array; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Number of rows -+ static int const kRows = Rows; -+ -+ /// Number of columns -+ static int const kColumns = Columns; -+ -+ /// Layout within the array -+ using Layout = Layout_; -+ -+ /// Reference type to an element -+ using Reference = Element &; -+ -+ /// Logical rank of tensor index space -+ static int const kRank = 2; -+ -+ /// Index type -+ using Index = typename Layout::Index; -+ -+ /// Long index used for pointer offsets -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// Coordinate in logical tensor space -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ /// Stride type -+ using Stride = typename Layout::Stride; -+ -+ /// TensorRef to matrix object -+ using TensorRef = TensorRef; -+ -+ /// TensorRef to constant matrix object -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ /// TensorRef to matrix object -+ using TensorView = TensorView; -+ -+ /// TensorRef to constant matrix object -+ using ConstTensorView = typename TensorView::ConstTensorView; -+ -+ /// Diagonal vector -+ using Diagonal = Vector; -+ -+private: -+ -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Returns the size of the object -+ CUTLASS_HOST_DEVICE -+ static MatrixCoord extent() { -+ return make_Coord(kRows, kColumns); -+ } -+ -+ /// Returns the layout object -+ CUTLASS_HOST_DEVICE -+ static Layout layout() { -+ return Layout::packed(extent()); -+ } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ Matrix() { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ Matrix(Diagonal const &diag) { -+ // Todo - construct from diagonal -+ } -+ -+ /// Returns a TensorRef pointing 
to the first element of the tensor. -+ CUTLASS_HOST_DEVICE -+ TensorRef ref() { -+ return TensorRef(this->data(), layout()); -+ } -+ -+ /// Returns a TensorRef pointing to the first element of the tensor. -+ CUTLASS_HOST_DEVICE -+ ConstTensorRef const_ref() const { -+ return ConstTensorRef(this->data(), layout()); -+ } -+ -+ /// Returns a TensorRef pointing to the first element of the tensor. -+ CUTLASS_HOST_DEVICE -+ TensorView view() { -+ return TensorView(ref(), extent()); -+ } -+ -+ /// Returns a TensorView to const data -+ CUTLASS_HOST_DEVICE -+ ConstTensorView const_view() const { -+ return ConstTensorView(const_ref(), extent()); -+ } -+ -+ /// Returns a reference to the element at a given Coord -+ CUTLASS_HOST_DEVICE -+ Reference at(MatrixCoord const& coord) const { -+ typename Base::size_type offset_(layout().offset(coord)); -+ return Base::at(offset_); -+ } -+ -+ /// Returns the number of scalar elements needed to store tensor. -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity() const { -+ return LongIndex(Base::size()); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Column vector defined as a matrix with exactly one column -+template < -+ typename Element, -+ int Rows, -+ typename Layout = layout::ColumnMajor -+> -+using ColumnVector = Matrix; -+ -+/// Row vector defined as a matrix with exactly one row -+template < -+ typename Element, -+ int Columns, -+ typename Layout = layout::RowMajor -+> -+using RowVector = Matrix; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/trace.h b/3rdparty/cutlass/include/cutlass/trace.h -new file mode 100644 -index 0000000..c77e7f4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/trace.h -@@ -0,0 +1,59 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
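
The per-thread Matrix container completed above is essentially a statically sized packed array plus a compile-time layout mapping (row, column) to an element index. ThreadMatrix below is a hypothetical simplification of that idea covering only the packed row-major case; it is not the CUTLASS class.

    #include <array>
    #include <cstdio>

    // Hypothetical register-resident matrix for a single thread: fixed size,
    // packed storage, compile-time row-major offset computation.
    template <typename Element, int Rows, int Columns>
    struct ThreadMatrix {
      std::array<Element, Rows * Columns> storage{};

      static constexpr int offset(int r, int c) {  // packed row-major layout
        return r * Columns + c;
      }

      Element&       at(int r, int c)       { return storage[offset(r, c)]; }
      Element const& at(int r, int c) const { return storage[offset(r, c)]; }
    };

    int main() {
      ThreadMatrix<float, 2, 3> m;  // 2x3 per-thread accumulator tile
      m.at(1, 2) = 7.0f;
      std::printf("%f\n", m.at(1, 2));
      return 0;
    }
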
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Helpers for optionally tracing through code when debugging. -+ -+ This file is to be included after all other headers. -+*/ -+ -+#pragma once -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Tracing options -+#ifndef CUTLASS_DEBUG_TRACE_LEVEL -+#define CUTLASS_DEBUG_TRACE_LEVEL 0 -+#endif -+ -+#if CUTLASS_DEBUG_TRACE_LEVEL -+#include -+#include "cutlass/core_io.h" -+#if defined(__CUDA_ARCH__) -+#define CUTLASS_TRACE_HOST(x) -+#else -+#define CUTLASS_TRACE_HOST(x) { std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; } -+#endif -+#else -+#define CUTLASS_TRACE_HOST(x) -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/transform/pitch_linear_thread_map.h b/3rdparty/cutlass/include/cutlass/transform/pitch_linear_thread_map.h -new file mode 100644 -index 0000000..c084dd4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/pitch_linear_thread_map.h -@@ -0,0 +1,926 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
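
The trace header just above defines a host-only, level-gated logging macro that compiles away unless CUTLASS_DEBUG_TRACE_LEVEL is set. The sketch below shows the same pattern with hypothetical names (MY_TRACE_LEVEL, MY_TRACE_HOST) so its behaviour can be tried without the CUTLASS headers.

    #include <iostream>

    // The macro produces no code unless a trace level is defined at build time
    // (e.g. -DMY_TRACE_LEVEL=1); otherwise it streams file/line plus the message.
    #ifndef MY_TRACE_LEVEL
    #define MY_TRACE_LEVEL 0
    #endif

    #if MY_TRACE_LEVEL
    #define MY_TRACE_HOST(x) \
      { std::cout << __FILE__ << ":" << __LINE__ << "  " << x << std::endl; }
    #else
    #define MY_TRACE_HOST(x)
    #endif

    int main() {
      int workspace_bytes = 4096;
      // With MY_TRACE_LEVEL=0 this line produces no code at all.
      MY_TRACE_HOST("allocating workspace: " << workspace_bytes << " bytes");
      (void)workspace_bytes;  // silence the unused warning when tracing is compiled out
      return 0;
    }
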
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing how threads are mapped to a given tile. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Strip-mines a pitch-linear tile among a given number of threads, first along -+/// the contiguous dimension then along the strided dimension. -+/// -+/// The tile must be divisible by the thread count such that all threads may -+/// execute the same number of iterations with the same delta to exhaustively -+/// cover the tile. -+/// -+/// This class satisfies the "RegularThreadMapping" concept. -+/// -+/// This ThreadMap is used by SIMT kernels and operand E of the sparse tensor -+/// kernels. -+template < -+ typename Shape_, -+ int Threads, -+ int ElementsPerAccess = 1 -+> -+struct PitchLinearStripminedThreadMap { -+ -+ /// Tensor coordinate -+ using TensorCoord = layout::PitchLinearCoord; -+ -+ /// Tile shape -+ using Shape = Shape_; -+ -+ /// Number of threads total -+ static int const kThreads = Threads; -+ -+ /// Extract vector length from Layout -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ /// Shape of access by each thread -+ using ThreadAccessShape = layout::PitchLinearShape; -+ -+ /// Internal implementation details -+ struct Detail { -+ -+ static_assert(!(Shape::kContiguous % kElementsPerAccess), ""); -+ -+ /// Shape of the tile in units of vectors -+ using ShapeVec = layout::PitchLinearShape< -+ Shape::kContiguous / kElementsPerAccess, -+ Shape::kStrided -+ >; -+ -+ static_assert((Threads < ShapeVec::kContiguous && !(ShapeVec::kContiguous % kThreads)) || -+ (!(kThreads % ShapeVec::kContiguous)), -+ "Shape must be divisible by number of iterations of each thread."); -+ }; -+ -+ /// Number of iterations by each thread -+ using Iterations = typename platform::conditional< -+ Threads >= Detail::ShapeVec::kContiguous, -+ layout::PitchLinearShape< -+ 1, -+ // Redo the comparison here to work around divide by zero compiler -+ // error. The compiler evaluates both path of platform::conditional. -+ (Threads >= Detail::ShapeVec::kContiguous -+ ? 
(Detail::ShapeVec::kStrided + (kThreads / Detail::ShapeVec::kContiguous - 1)) / -+ (kThreads / Detail::ShapeVec::kContiguous) -+ : 0)>, -+ layout::PitchLinearShape>::type; -+ -+ -+ /// Interval between accesses along each dimension of the tensor's logical coordinate space -+ /// (in units of Elements) -+ using Delta = typename platform::conditional< -+ Threads >= Detail::ShapeVec::kContiguous, -+ layout::PitchLinearShape< -+ 1, -+ kThreads / Detail::ShapeVec::kContiguous -+ >, -+ layout::PitchLinearShape< -+ kThreads * kElementsPerAccess, -+ 1 -+ > -+ >::type; -+ -+ /// Shape of the tile in units of vectors -+ using StorageShape = typename platform::conditional< -+ Threads >= Detail::ShapeVec::kContiguous, -+ layout::PitchLinearShape, -+ layout::PitchLinearShape>::type; -+ -+ /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space -+ /// (in units of Elements) -+ CUTLASS_HOST_DEVICE -+ static TensorCoord initial_offset(int thread_id) { -+ return TensorCoord( -+ (thread_id % Detail::ShapeVec::kContiguous) * kElementsPerAccess, -+ thread_id / Detail::ShapeVec::kContiguous); -+ } -+}; -+ -+/// This ThreadMap is used by GEMV -+template < -+ typename Shape, -+ int Threads, -+ int ElementsPerAccess = 1 -+> -+struct PitchLinearTilePolicyStripminedThreadContiguous -+{ -+ static_assert((Shape::kContiguous % (Threads * ElementsPerAccess)) == 0, -+ "Contiguous shape must divide number of threads"); -+ -+ using TensorCoord = layout::PitchLinearCoord; -+ -+ static int const kThreads = Threads; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using Iterations = layout::PitchLinearShape< -+ Shape::kContiguous / (kThreads * kElementsPerAccess), -+ Shape::kStrided>; -+ -+ using Delta = layout::PitchLinearShape<1, 1>; -+ -+ CUTLASS_HOST_DEVICE -+ static TensorCoord initial_offset(int thread_id) -+ { -+ return TensorCoord(thread_id * Iterations::kContiguous * kElementsPerAccess, 0); -+ } -+}; -+ -+template < -+ typename Shape, -+ int Threads, -+ int ElementsPerAccess = 1 -+> -+struct PitchLinearTilePolicyStripminedThreadStrided -+{ -+ static_assert((Shape::kStrided % Threads == 0), -+ "Strided shape must divide number of threads"); -+ -+ using TensorCoord = layout::PitchLinearCoord; -+ -+ static int const kThreads = Threads; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using Iterations = layout::PitchLinearShape< -+ Shape::kContiguous / kElementsPerAccess, -+ Shape::kStrided / kThreads>; -+ -+ using Delta = layout::PitchLinearShape<1, 1>; -+ -+ using ShapeVec = Shape; -+ -+ CUTLASS_HOST_DEVICE -+ static TensorCoord initial_offset(int thread_id) -+ { -+ -+ return TensorCoord(0, thread_id * Iterations::kStrided); -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Policy defining a warp-raked arrangement in which a shape is partitioned into contiguous -+/// elements. -+/// -+/// This ThreadMap is used by tensor core kernels. 
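
The strip-mined map above assigns threads along the contiguous dimension first (in units of kElementsPerAccess-wide vectors) and then steps them through the strided dimension by Delta until the tile is covered. The following worked example, with tile and thread counts chosen purely for illustration, enumerates that assignment for a 64 x 4 tile with 4-element accesses and 32 threads, and checks that every element is visited exactly once.

    #include <cstdio>
    #include <vector>

    int main() {
      const int elements_per_access = 4;
      const int shape_contiguous = 64, shape_strided = 4;
      const int threads = 32;

      const int vec_contiguous = shape_contiguous / elements_per_access;  // 16 vectors
      const int delta_strided  = threads / vec_contiguous;                // step of 2 rows
      const int iterations_strided =
          (shape_strided + delta_strided - 1) / delta_strided;            // 2 iterations

      std::vector<int> touched(shape_contiguous * shape_strided, 0);

      for (int tid = 0; tid < threads; ++tid) {
        // initial_offset(): contiguous coordinate in elements, strided in rows.
        int c0 = (tid % vec_contiguous) * elements_per_access;
        int s0 = tid / vec_contiguous;
        for (int it = 0; it < iterations_strided; ++it) {
          int s = s0 + it * delta_strided;
          for (int e = 0; e < elements_per_access; ++e) {
            ++touched[s * shape_contiguous + c0 + e];
          }
        }
      }

      int covered_once = 0;
      for (int v : touched) covered_once += (v == 1);
      std::printf("%d of %d elements covered exactly once\n",
                  covered_once, shape_contiguous * shape_strided);  // 256 of 256
      return 0;
    }
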
-+template < -+ typename Shape_, -+ int Threads, -+ typename WarpThreadArrangement_, -+ int ElementsPerAccess = 1 -+> -+struct PitchLinearWarpRakedThreadMap { -+ -+ /// Tensor coordinate -+ using TensorCoord = layout::PitchLinearCoord; -+ -+ /// Tile shape -+ using Shape = Shape_; -+ -+ /// Number of threads total -+ static int const kThreads = Threads; -+ -+ /// Extract vector length from Layout -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ /// Shape of access by each thread -+ using ThreadAccessShape = layout::PitchLinearShape; -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ -+ /// Fixed arrangement of threads within a warp (units of threads). -+ using WarpThreadArrangement = WarpThreadArrangement_; -+ -+ /// Number of threads per warp -+ static int const kWarpSize = WarpThreadArrangement::kCount; -+ -+ /// Number of participating warps -+ static int const kWarpCount = kThreads / kWarpSize; -+ -+ static_assert( -+ !(Shape::kContiguous % kElementsPerAccess), -+ "Shape must be divisible by vector length."); -+ -+ /// Compute the 'shape' of the overall tile in units of vectors -+ using ShapeInAccesses = layout::PitchLinearShape< -+ Shape::kContiguous / kElementsPerAccess, -+ Shape::kStrided -+ >; -+ -+ static_assert( -+ !(ShapeInAccesses::kContiguous % WarpThreadArrangement::kContiguous), -+ "ShapeInAccesses must be divisible by WarpThreadArrangement."); -+ -+ static_assert( -+ !(ShapeInAccesses::kStrided % WarpThreadArrangement::kStrided), -+ "ShapeInAccesses must be divisible by WarpThreadArrangement."); -+ -+ // compute number of warp-level accesses total -+ using WarpAccessIterations = layout::PitchLinearShape< -+ ShapeInAccesses::kContiguous / WarpThreadArrangement::kContiguous, -+ ShapeInAccesses::kStrided / WarpThreadArrangement::kStrided -+ >; -+ -+ // Divide it into the number of warps, first partitioning the strided dimension then the -+ // contiguous. -+ static int const kWarpsStrided = -+ (WarpAccessIterations::kStrided >= kWarpCount -+ ? kWarpCount -+ : WarpAccessIterations::kStrided); -+ -+ static int const kWarpsContiguous = -+ (kWarpCount > WarpAccessIterations::kStrided -+ ? 
kWarpCount / kWarpsStrided -+ : 1); -+ -+ /// Arrangement of warps within a threadblock-scoped tile -+ using WarpArrangement = layout::PitchLinearShape< -+ kWarpsContiguous, kWarpsStrided -+ >; -+ }; -+ -+ ///< Iterations along each dimension (concept: PitchLinearShape) -+ using Iterations = layout::PitchLinearShape< -+ Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous, -+ Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided -+ >; -+ -+ static_assert(Iterations::kCount, -+ "Number of iterations must be non-zero"); -+ -+ ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) -+ using Delta = layout::PitchLinearShape< -+ Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess, -+ Detail::WarpThreadArrangement::kStrided -+ >; -+ -+ /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space -+ CUTLASS_HOST_DEVICE -+ static TensorCoord initial_offset(int thread_id) { -+ -+ int warp_id = (thread_id / Detail::kWarpSize); -+ int lane_id = (thread_id % Detail::kWarpSize); -+ -+ // -+ // compute warp-level offset -+ // -+ -+ // This is the shape of the entire area covered by a warp's memory access (in units of vectors) -+ layout::PitchLinearCoord warp_footprint{ -+ Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous, -+ Detail::WarpThreadArrangement::kStrided * Iterations::kStrided -+ }; -+ -+ // This is the offset of a specific warp (in units of vectors) -+ layout::PitchLinearCoord warp_offset{ -+ (warp_id % Detail::kWarpsContiguous), -+ (warp_id / Detail::kWarpsContiguous) -+ }; -+ -+ // This is the offset of a specific thread within a warp (units of vectors) -+ layout::PitchLinearCoord thread_offset_in_warp{ -+ lane_id % Detail::WarpThreadArrangement::kContiguous, -+ lane_id / Detail::WarpThreadArrangement::kContiguous -+ }; -+ -+ // This is the offset of a thread within a threadblock tile (units of vectors) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = -+ warp_footprint * warp_offset + thread_offset_in_warp; -+ -+ // This is the offset of a thread within a threadblock tile (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{ -+ thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess, -+ thread_offset_in_threadblock_tile_vec.strided() -+ }; -+ -+ return thread_offset_in_threadblock_tile_base; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Policy defining a warp-raked arrangement in which a shape is partitioned into contiguous -+/// elements. Warps are arranged based on a stride. -+/// -+/// This ThreadMap is used by tensor core kernels for NCxHWx layout. 
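
The warp-raked initial_offset() above decomposes a thread index into a warp id and a lane id, places the warp by its footprint within the tile, and places the lane inside the warp's fixed arrangement, all in units of vector accesses before scaling back to elements. The example below repeats that arithmetic in plain integers for one hypothetical configuration (128 threads, a 4 x 8 warp arrangement, 8-element accesses over a 64 x 32 tile); the numbers are illustrative, not taken from a particular CUTLASS kernel.

    #include <cstdio>

    int main() {
      const int elements_per_access = 8;
      const int warp_threads_contig = 4, warp_threads_strided = 8;
      const int warp_size = warp_threads_contig * warp_threads_strided;  // 32

      // Shape in units of vector accesses: 64/8 = 8 contiguous, 32 strided.
      const int warp_iters_contig  = 8 / warp_threads_contig;    // 2
      const int warp_iters_strided = 32 / warp_threads_strided;  // 4
      const int warp_count = 128 / warp_size;                    // 4 warps
      const int warps_strided = (warp_iters_strided >= warp_count)
                                    ? warp_count : warp_iters_strided;  // 4
      const int warps_contig = (warp_count > warp_iters_strided)
                                   ? warp_count / warps_strided : 1;    // 1

      // Per-warp iteration counts and the footprint one warp covers (in vectors).
      const int iters_contig  = warp_iters_contig / warps_contig;         // 2
      const int iters_strided = warp_iters_strided / warps_strided;       // 1
      const int footprint_contig  = warp_threads_contig * iters_contig;   // 8
      const int footprint_strided = warp_threads_strided * iters_strided; // 8

      const int sample_tids[] = {0, 5, 37, 127};
      for (int tid : sample_tids) {
        int warp_id = tid / warp_size, lane_id = tid % warp_size;
        int warp_c = warp_id % warps_contig, warp_s = warp_id / warps_contig;
        int lane_c = lane_id % warp_threads_contig;
        int lane_s = lane_id / warp_threads_contig;
        int vec_c = footprint_contig * warp_c + lane_c;
        int vec_s = footprint_strided * warp_s + lane_s;
        std::printf("thread %3d -> element (%d, %d)\n",
                    tid, vec_c * elements_per_access, vec_s);
      }
      return 0;
    }
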
-+template < -+ typename Shape_, -+ int Threads, -+ typename WarpThreadArrangement_, -+ int ElementsPerAccess = 1 -+> -+struct PitchLinearStridedWarpRakedThreadMap { -+ -+ /// Tensor coordinate -+ using TensorCoord = layout::PitchLinearCoord; -+ -+ /// Tile shape -+ using Shape = Shape_; -+ -+ /// Number of threads total -+ static int const kThreads = Threads; -+ -+ using WarpThreadArrangement = WarpThreadArrangement_; -+ -+ /// Extract vector length from Layout -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ /// Base ThreadMap -+ using BaseThreadMap = PitchLinearWarpRakedThreadMap< -+ Shape, -+ kThreads, -+ WarpThreadArrangement, -+ kElementsPerAccess -+ >; -+ -+ /// Shape of access by each thread -+ using ThreadAccessShape = typename BaseThreadMap::ThreadAccessShape; -+ -+ -+ struct Detail { -+ -+ using WarpThreadArrangement = WarpThreadArrangement_; -+ -+ using WarpAccessIterations = typename BaseThreadMap::Detail::WarpAccessIterations; -+ -+ static int const kWarpSize = BaseThreadMap::Detail::kWarpSize; -+ -+ static int const kWarpCount = BaseThreadMap::Detail::kWarpCount; -+ -+ using ShapeInAccesses = typename BaseThreadMap::Detail::ShapeInAccesses; -+ -+ // Divide it into the number of warps, first partitioning the contiguous dimension then the -+ // stride. -+ static int const kWarpsContiguous = -+ (WarpAccessIterations::kContiguous >= kWarpCount -+ ? kWarpCount -+ : WarpAccessIterations::kContiguous); -+ -+ static int const kWarpsStrided = -+ (kWarpCount > WarpAccessIterations::kContiguous -+ ? kWarpCount / kWarpsContiguous -+ : 1); -+ -+ /// Arrangement of warps within a threadblock-scoped tile -+ using WarpArrangement = layout::PitchLinearShape< -+ kWarpsContiguous, kWarpsStrided -+ >; -+ -+ }; -+ -+ ///< Iterations along each dimension (concept: PitchLinearShape) -+ using Iterations = layout::PitchLinearShape< -+ Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous, -+ Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided -+ >; -+ -+ static_assert(Iterations::kCount, -+ "Number of iterations must be non-zero"); -+ -+ ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) -+ using Delta = typename BaseThreadMap::Delta; -+ -+ /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space -+ CUTLASS_HOST_DEVICE -+ static TensorCoord initial_offset(int thread_id) { -+ -+ int warp_id = (thread_id / Detail::kWarpSize); -+ int lane_id = (thread_id % Detail::kWarpSize); -+ -+ // -+ // compute warp-level offset -+ // -+ -+ // This is the shape of the entire area covered by a warp's memory access (in units of vectors) -+ layout::PitchLinearCoord warp_footprint{ -+ Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous, -+ Detail::WarpThreadArrangement::kStrided * Iterations::kStrided -+ }; -+ -+ // This is the offset of a specific warp (in units of vectors) -+ layout::PitchLinearCoord warp_offset{ -+ (warp_id % Detail::kWarpsContiguous), -+ (warp_id / Detail::kWarpsContiguous) -+ }; -+ -+ // This is the offset of a specific thread within a warp (units of vectors) -+ layout::PitchLinearCoord thread_offset_in_warp{ -+ lane_id % Detail::WarpThreadArrangement::kContiguous, -+ lane_id / Detail::WarpThreadArrangement::kContiguous -+ }; -+ -+ // This is the offset of a thread within a threadblock tile (units of vectors) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = -+ warp_footprint * warp_offset + thread_offset_in_warp; -+ -+ // This is the offset of a thread within a 
threadblock tile (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{ -+ thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess, -+ thread_offset_in_threadblock_tile_vec.strided() -+ }; -+ -+ return thread_offset_in_threadblock_tile_base; -+ } -+ -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Transpose the existing ThreadMap. For example, interleaved layout is like -+/// congruous in the global memory and crosswise in the shared memory. We need -+/// to transpose the coordinates between two. -+ -+template -+struct TransposePitchLinearThreadMap { -+ /// Underlying ThreadMap -+ using ThreadMap = ThreadMap_; -+ -+ /// Tensor coordinate -+ using TensorCoord = typename ThreadMap::TensorCoord; -+ -+ /// Tile shape -+ using Shape = typename ThreadMap::Shape; -+ -+ /// Number of threads total -+ static int const kThreads = ThreadMap::kThreads; -+ -+ /// Extract vector length from Layout -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ /// Shape of access by each thread -+ using ThreadAccessShape = layout::PitchLinearShape; -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ /// Fixed arrangement of threads within a warp (units of threads). -+ using WarpThreadArrangement = WarpThreadArrangement_; -+ -+ /// Number of threads per warp -+ static int const kWarpSize = WarpThreadArrangement::kCount; -+ -+ /// Number of participating warps -+ static int const kWarpCount = kThreads / kWarpSize; -+ -+ static_assert(!(Shape::kContiguous % kElementsPerAccess), -+ "Shape must be divisible by vector length."); -+ -+ /// Arrangement of warps within a threadblock-scoped tile -+ using WarpArrangement = -+ layout::PitchLinearShape; -+ }; -+ -+ ///< Iterations along each dimension (concept: PitchLinearShape) -+ using Iterations = -+ layout::PitchLinearShape; -+ -+ static_assert(Iterations::kContiguous == 1, -+ "Contiguous iteration has to be one to reuse the same shared store function with those that don't need transpose"); -+ -+ static_assert(Iterations::kCount, "Number of iterations must be non-zero"); -+ -+ ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) -+ using Delta = -+ layout::PitchLinearShape; -+ -+ /// Maps thread ID to a coordinate offset within the tensor's logical -+ /// coordinate space Note this is slightly different from the one of -+ /// PitchLinearWarpRakedThreadMap. -+ CUTLASS_HOST_DEVICE -+ static TensorCoord initial_offset(int thread_id) { -+ -+ int warp_id = (thread_id / Detail::kWarpSize); -+ int lane_id = (thread_id % Detail::kWarpSize); -+ -+ // -+ // compute warp-level offset -+ // -+ -+ // This is the shape of the entire area covered by a warp's memory access -+ // (in units of vectors) -+ layout::PitchLinearCoord warp_footprint{ -+ Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous, -+ Detail::WarpThreadArrangement::kStrided * Iterations::kStrided}; -+ -+ // This is the offset of a specific warp (in units of vectors) -+ // Note the order of / and %. Also the 2nd operand is kStrided. 
-+ layout::PitchLinearCoord warp_offset{ -+ (warp_id / Detail::WarpArrangement::kStrided), -+ (warp_id % Detail::WarpArrangement::kStrided)}; -+ -+ // This is the offset of a specific thread within a warp (units of vectors) -+ layout::PitchLinearCoord thread_offset_in_warp{ -+ lane_id % Detail::WarpThreadArrangement::kContiguous, -+ lane_id / Detail::WarpThreadArrangement::kContiguous}; -+ -+ // This is the offset of a thread within a threadblock tile (units of -+ // vectors) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = -+ warp_footprint * warp_offset + thread_offset_in_warp; -+ -+ // This is the offset of a thread within a threadblock tile (units of -+ // elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{ -+ thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess, -+ thread_offset_in_threadblock_tile_vec.strided()}; -+ -+ return thread_offset_in_threadblock_tile_base; -+ } -+}; -+ -+template -+struct TransposePitchLinearThreadMapSimt { -+ /// Underlying ThreadMap -+ using ThreadMap = ThreadMap_; -+ -+ /// Tensor coordinate -+ using TensorCoord = typename ThreadMap::TensorCoord; -+ -+ /// Tile shape -+ using Shape = typename ThreadMap::Shape; -+ -+ /// Number of threads total -+ static int const kThreads = ThreadMap::kThreads; -+ -+ /// Extract vector length from Layout -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ static_assert(kElementsPerAccess == 1 , "Simt transpose requires elements per access to be 1"); -+ ///< Iterations along each dimension (concept: PitchLinearShape) -+ using Iterations = -+ layout::PitchLinearShape; -+ -+ static_assert(Iterations::kCount, "Number of iterations must be non-zero"); -+ -+ static_assert(Iterations::kStrided == 1, -+ "Strided iteration has to be one to reuse the same shared store function with those that don't need transpose"); -+ -+ /// Shape of access by each thread -+ using ThreadAccessShape = typename ThreadMap::ThreadAccessShape; -+ -+ ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) -+ using Delta = -+ layout::PitchLinearShape; -+ -+ -+ /// Maps thread ID to a coordinate offset within the tensor's logical -+ /// coordinate space Note this is slightly different from the one of -+ /// PitchLinearWarpRakedThreadMap. -+ CUTLASS_HOST_DEVICE -+ static TensorCoord initial_offset(int thread_id) { -+ -+ TensorCoord coord = ThreadMap::initial_offset(thread_id); -+ -+ return TensorCoord( -+ coord.strided(), -+ coord.contiguous() -+ ); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Policy defining a warp-striped arrangement. This partitions a tile into vectorized memory -+/// accesses performed by each warp then distributes warps across them. Warps are striped in the -+/// strided dimension and raked across the contiguous dimension. -+template < -+ typename Shape_, /// Overall shape to partition in units of elements -+ int Threads, /// Number of partiticipation threads -+ typename WarpThreadArrangement_, /// Describes the shape of one memory access per warp -+ int ElementsPerAccess = 1 /// Number of elements accessed by each thread per memory operation (i.e. 
vector size) -+> -+struct PitchLinearWarpStripedThreadMap { -+ -+ /// Tensor coordinate -+ using TensorCoord = layout::PitchLinearCoord; -+ -+ /// Tile shape -+ using Shape = Shape_; -+ -+ /// Number of threads total -+ static int const kThreads = Threads; -+ -+ /// Extract vector length from Layout -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ /// Shape of access by each thread -+ using ThreadAccessShape = layout::PitchLinearShape; -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ -+ /// Fixed arrangement of threads within a warp (units of threads). -+ using WarpThreadArrangement = WarpThreadArrangement_; -+ -+ /// Number of threads per warp -+ static int const kWarpSize = WarpThreadArrangement::kCount; -+ -+ /// Number of participating warps -+ static int const kWarpCount = kThreads / kWarpSize; -+ -+ static_assert( -+ !(Shape::kContiguous % kElementsPerAccess), -+ "Shape must be divisible by vector length."); -+ -+ /// Compute the 'shape' of the overall tile in units of vectors -+ using ShapeInAccesses = layout::PitchLinearShape< -+ Shape::kContiguous / kElementsPerAccess, -+ Shape::kStrided -+ >; -+ -+ // compute number of warp-level accesses total -+ using WarpAccessIterations = layout::PitchLinearShape< -+ ShapeInAccesses::kContiguous / WarpThreadArrangement::kContiguous, -+ ShapeInAccesses::kStrided / WarpThreadArrangement::kStrided -+ >; -+ -+ // Divide it into the number of warps, first partitioning the strided dimension then the -+ // contiguous. -+ static int const kWarpsStrided = -+ (WarpAccessIterations::kStrided >= kWarpCount -+ ? kWarpCount : (kWarpCount / WarpAccessIterations::kStrided)); -+ -+ static int const kWarpsContiguous = -+ (kWarpCount > WarpAccessIterations::kStrided ? 
-+ WarpAccessIterations::kContiguous / kWarpsStrided : 1); -+ -+ /// Arrangement of warps within a threadblock-scoped tile -+ using WarpArrangement = layout::PitchLinearShape< -+ kWarpsContiguous, kWarpsStrided -+ >; -+ }; -+ -+ ///< Iterations along each dimension (concept: PitchLinearShape) -+ using Iterations = layout::PitchLinearShape< -+ Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous, -+ Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided -+ >; -+ -+ static_assert(Iterations::kCount, -+ "Number of iterations must be non-zero"); -+ -+ ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) -+ using Delta = layout::PitchLinearShape< -+ Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess, -+ Detail::WarpThreadArrangement::kStrided * Detail::WarpArrangement::kStrided -+ >; -+ -+ /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space -+ CUTLASS_HOST_DEVICE -+ static TensorCoord initial_offset(int thread_id) { -+ -+ int warp_id = (thread_id / Detail::kWarpSize); -+ int lane_id = (thread_id % Detail::kWarpSize); -+ -+ // -+ // compute warp-level offset -+ // -+ -+ // This is the shape of the entire area covered by a warp's memory access (in units of vectors) -+ layout::PitchLinearCoord warp_footprint{ -+ Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous, -+ Detail::WarpThreadArrangement::kStrided -+ }; -+ -+ // This is the offset of a specific warp (in units of vectors) -+ layout::PitchLinearCoord warp_offset{ -+ (warp_id % Detail::kWarpsContiguous), -+ (warp_id / Detail::kWarpsContiguous) -+ }; -+ -+ // This is the offset of a specific thread within a warp (units of vectors) -+ layout::PitchLinearCoord thread_offset_in_warp{ -+ lane_id % Detail::WarpThreadArrangement::kContiguous, -+ lane_id / Detail::WarpThreadArrangement::kContiguous -+ }; -+ -+ // This is the offset of a thread within a threadblock tile (units of vectors) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = -+ warp_footprint * warp_offset + thread_offset_in_warp; -+ -+ // This is the offset of a thread within a threadblock tile (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{ -+ thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess, -+ thread_offset_in_threadblock_tile_vec.strided() -+ }; -+ -+ return thread_offset_in_threadblock_tile_base; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Strip-mines a pitch-linear tile among a given number of threads, first along the contiguous -+/// dimension then along the strided dimension, while each thread access a 2D thread-tile. -+/// -+/// The tile must be divisible by the thread count such that all threads may execute the same -+/// number of iterations with the same delta to exhaustively cover the tile. -+/// -+/// This class satisfies the "RegularThreadMapping" concept. 
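/// Illustrative example (tile shape, thread count and the 4x4 thread tile are
/// assumed values, chosen only for exposition): a 32 x 16 pitch-linear tile
/// covered by 16 threads is first viewed as an 8 x 4 grid of 4x4 thread tiles.
/// Thread t then starts at ((t % 8) * 4, (t / 8) * 4), so thread 9 begins at
/// (4, 4); with two strided iterations and a strided delta of 8 elements, its
/// second 4x4 tile starts at (4, 12).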
-+template < -+ typename Shape_, -+ int Threads, -+ typename ThreadTileShape -+> -+struct PitchLinear2DThreadTileStripminedThreadMap; -+ -+ -+template < -+ typename Shape_, -+ int Threads -+> -+struct PitchLinear2DThreadTileStripminedThreadMap >{ -+ -+ /// Tensor coordinate -+ using TensorCoord = layout::PitchLinearCoord; -+ -+ /// Tile shape -+ using Shape = Shape_; -+ -+ /// Access Shape of each thread -+ using ThreadAccessShape = cutlass::layout::PitchLinearShape<4, 4>; -+ //using ThreadAccessShape = ThreadTileShape; -+ -+ /// Number of threads total -+ static int const kThreads = Threads; -+ -+ /// Extract length of each access from Layout -+ static int const kElementsPerAccess = ThreadAccessShape::kContiguous; -+ -+ static_assert(!(kElementsPerAccess % 4) , "kElementsPerAccess, needs to be multiple of 4 (32bits)"); -+ -+ /// Internal implementation details -+ struct Detail { -+ -+ static_assert(!(ThreadAccessShape::kContiguous % 4), "ThreadAccessShape, needs to be multiple of 4"); -+ -+ static_assert(!(Shape::kContiguous % ThreadAccessShape::kContiguous), ""); -+ -+ static_assert(!((Shape::kContiguous * Shape::kStrided) % (kThreads * ThreadAccessShape::kCount)), -+ "Shape must be divisible thread count * accesses per thread."); -+ -+ /// Shape of the tile in units of vectors -+ using ShapeVec = layout::PitchLinearShape< -+ Shape::kContiguous / ThreadAccessShape::kContiguous, -+ Shape::kStrided / ThreadAccessShape::kStrided -+ >; -+ -+ static_assert( -+ (Threads < ShapeVec::kContiguous && !(ShapeVec::kContiguous % kThreads)) || -+ (!(kThreads % ShapeVec::kContiguous) && !(ShapeVec::kStrided % (kThreads / ShapeVec::kContiguous))), -+ "Shape must be divisible by number of iterations of each thread." -+ ); -+ }; -+ -+ /// Number of iterations by each thread -+ using Iterations = typename platform::conditional< -+ Threads >= Detail::ShapeVec::kContiguous, -+ layout::PitchLinearShape< -+ 1, -+ // Redo the comparison here to work around divide by zero compiler -+ // error. The compiler evaluates both path of platform::conditional. -+ (Threads >= Detail::ShapeVec::kContiguous -+ ? 
Detail::ShapeVec::kStrided / -+ (kThreads / Detail::ShapeVec::kContiguous) -+ : 0)>, -+ layout::PitchLinearShape>::type; -+ -+ /// Interval between accesses along each dimension of the tensor's logical coordinate space -+ /// (in units of Elements) -+ using Delta = typename platform::conditional< -+ Threads >= Detail::ShapeVec::kContiguous, -+ layout::PitchLinearShape< -+ Shape::kContiguous, -+ kThreads * ThreadAccessShape::kStrided / Detail::ShapeVec::kContiguous -+ >, -+ layout::PitchLinearShape< -+ kThreads * ThreadAccessShape::kContiguous, -+ 1 -+ > -+ >::type; -+ -+ /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space -+ /// (in units of Elements) -+ CUTLASS_HOST_DEVICE -+ static TensorCoord initial_offset(int thread_id) { -+ -+ return TensorCoord( -+ (thread_id % Detail::ShapeVec::kContiguous) * ThreadAccessShape::kContiguous, -+ (thread_id / Detail::ShapeVec::kContiguous) * ThreadAccessShape::kStrided); -+ } -+}; -+ -+/// Thread Mapping a 2D threadtiled mapping as a tranposed Pitchlinear2DThreadTile mapping -+template -+struct TransposePitchLinearThreadMap2DThreadTile { -+ /// Underlying ThreadMap -+ using ThreadMap = ThreadMap_; -+ -+ /// Tensor coordinate -+ using TensorCoord = typename ThreadMap::TensorCoord; -+ -+ /// Tile shape -+ using Shape = typename ThreadMap::Shape; -+ -+ /// Number of threads total -+ static int const kThreads = ThreadMap::kThreads; -+ -+ /// Extract vector length from Layout -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ -+ static_assert(kElementsPerAccess > 1 , "Simt transpose requires elements per access to be 1"); -+ ///< Iterations along each dimension (concept: PitchLinearShape) -+ using Iterations = -+ layout::PitchLinearShape; -+ -+ static_assert(Iterations::kCount, "Number of iterations must be non-zero"); -+ -+ /// Shape of access by each thread -+ using ThreadAccessShape = typename ThreadMap::ThreadAccessShape; -+ -+ ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) -+ using Delta = -+ layout::PitchLinearShape; -+ -+ -+ /// Maps thread ID to a coordinate offset within the tensor's logical -+ /// coordinate space Note this is slightly different from the one of -+ /// PitchLinearWarpRakedThreadMap. -+ CUTLASS_HOST_DEVICE -+ static TensorCoord initial_offset(int thread_id) { -+ -+ TensorCoord coord = ThreadMap::initial_offset(thread_id); -+ return TensorCoord( -+ coord.strided(), -+ coord.contiguous() -+ ); -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/thread/transpose.h b/3rdparty/cutlass/include/cutlass/transform/thread/transpose.h -new file mode 100644 -index 0000000..b62b6bf ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/thread/transpose.h -@@ -0,0 +1,107 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Basic copy routines for tensor views -+*/ -+ -+#pragma once -+ -+namespace cutlass { -+namespace transform { -+namespace thread { -+ -+/// Transforms a fragment by doing a transpose -+template < -+ int ElementCount, -+ typename TransposeShape, -+ typename Element -+> struct Transpose; -+ -+/// Specialization for int8_t 4x4 transpose -+template -+struct Transpose , int8_t> { -+ -+ static const int kElementCount = ElementCount_; -+ using TransposeShape = layout::PitchLinearShape<4,4>; -+ using Element = int8_t; -+ using Fragment = cutlass::Array; -+ -+ static_assert(!(kElementCount % TransposeShape::kCount), "Shape needs to be multiple of 16 elements to do a 4x4 transpose"); -+ -+ CUTLASS_DEVICE -+ void transform(Fragment& dst, Fragment& src) { -+ -+ // Expose src/dst as int arrays. 
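    // How the permutes below assemble the transpose: __byte_perm(x, y, s)
    // builds a 32-bit word by picking one byte per selector nibble from the
    // eight-byte pool {x bytes 0-3, y bytes 4-7}. Selector 0x0040 pairs byte 0
    // of a0 with byte 0 of a1, and the follow-up selector 0x5410 appends byte 0
    // of a2 and a3, so b0 collects element 0 of all four source rows, i.e. row 0
    // of the transposed 4x4 tile. Selectors 0x0051, 0x0062 and 0x0073 gather
    // bytes 1, 2 and 3 in the same way to form the remaining rows.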
-+ int* src_int = reinterpret_cast(&src); -+ int* dst_int = reinterpret_cast(&dst); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kElementCount / TransposeShape::kCount; i++){ -+ -+ int const i0 = 4 * i + 0; -+ int const i1 = 4 * i + 1; -+ int const i2 = 4 * i + 2; -+ int const i3 = 4 * i + 3; -+ -+ int a0 = src_int[i0]; -+ int a1 = src_int[i1]; -+ int a2 = src_int[i2]; -+ int a3 = src_int[i3]; -+ -+ int b0, b1, b2, b3, c0; -+ b0 = __byte_perm(a0, a1, 0x0040); -+ c0 = __byte_perm(a2, a3, 0x0040); -+ b0 = __byte_perm(b0, c0, 0x5410); -+ -+ b1 = __byte_perm(a0, a1, 0x0051); -+ c0 = __byte_perm(a2, a3, 0x0051); -+ b1 = __byte_perm(b1, c0, 0x5410); -+ -+ b2 = __byte_perm(a0, a1, 0x0062); -+ c0 = __byte_perm(a2, a3, 0x0062); -+ b2 = __byte_perm(b2, c0, 0x5410); -+ -+ b3 = __byte_perm(a0, a1, 0x0073); -+ c0 = __byte_perm(a2, a3, 0x0073); -+ b3 = __byte_perm(b3, c0, 0x5410); -+ -+ dst_int[i0] = b0; -+ dst_int[i1] = b1; -+ dst_int[i2] = b2; -+ dst_int[i3] = b3; -+ } -+ } -+}; -+ -+} // namespace thread -+} // namespace layout -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/transform/thread/unary_op.h b/3rdparty/cutlass/include/cutlass/transform/thread/unary_op.h -new file mode 100644 -index 0000000..c50e75b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/thread/unary_op.h -@@ -0,0 +1,105 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+ -+namespace cutlass { -+namespace transform { -+namespace thread { -+ -+namespace UnaryTransform { -+ struct Identity; ///< None (i.e., identity) -+ struct Conjugate; ///< Complex conjugate -+} -+ -+/// Element-wise unary operator that transforms one element of a fragment at a time -+template< -+ typename FragmentIn, ///< Input Fragment -+ typename FragmentOut,///< Output Fragment -+ typename Transform> ///< Unary transform operator -+class UnaryOp -+{ -+ public: -+ CUTLASS_DEVICE -+ static FragmentOut execute(FragmentIn &in) -+ { -+ static_assert(FragmentIn::kElements == FragmentOut::kElements, "Number of elements must match."); -+ static_assert(platform::is_same::value || -+ platform::is_same::value, -+ "Unary Operator not supported."); -+ -+ FragmentOut out; -+ if (platform::is_same::value ) -+ { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i=0; i < FragmentIn::kElements; ++i){ -+ out[i] = static_cast(in[i]); -+ } -+ } -+ else if (platform::is_same::value ) -+ { -+ for (int i=0; i < FragmentIn::kElements; ++i){ -+ out[i] = conj(static_cast(in[i])); -+ } -+ } -+ return out; -+ } -+}; -+ -+template -+class UnaryOp -+{ -+ public: -+ CUTLASS_DEVICE -+ static FragmentIn execute(FragmentIn &in) -+ { -+ static_assert(platform::is_same::value || -+ platform::is_same::value, -+ "Unary Operator not supported."); -+ -+ if (platform::is_same::value ) -+ { -+ return in; -+ } -+ else if (platform::is_same::value ) -+ { -+ for(int i=0; i < FragmentIn::kElements; ++i){ -+ in[i] = conj(in[i]); -+ } -+ } -+ return in; -+ } -+ }; -+ } -+ } -+} -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/ell_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/ell_iterator.h -new file mode 100644 -index 0000000..0578123 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/ell_iterator.h -@@ -0,0 +1,199 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Ell iterator for matrix of indices (ellColInd matrix) -+*/ -+ -+#pragma once -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+namespace ell{ -+ -+constexpr unsigned int SmemPow = 8; -+constexpr unsigned int SmemStages = 2; -+constexpr unsigned int SmemSize = 1 << SmemPow; -+constexpr unsigned int SmemMask = (SmemSize*SmemStages-1); -+ -+class SharedStorage{ -+ public: -+ Array array; -+}; -+ -+class Iterator{ -+ public: -+ using Layout = layout::PitchLinear; -+ using LongIndex = typename Layout::LongIndex; -+ -+ private: -+ const int *gmem_col_idx_; -+ int *smem_col_idx_; -+ const int block_size_; -+ const int base_idx_; -+ const int k_shape_; -+ const int ell_increment_; -+ const int array_length_; -+ int col_idx_base_; -+ int residue_; -+ int counter_; -+ -+ int pow2_; -+ int residue_shape_; -+ -+ int smem_offset_; -+ int smem_stage_; -+ int gmem_offset_; -+ -+ int lane_; -+ -+ bool is_pow2_; -+ bool is_residue_tile_; -+ -+ public: -+ CUTLASS_DEVICE -+ void load_ell_indices(){ -+ for(int i=threadIdx.x; i= 0) ? gmem_col_idx : -1; -+ } -+ gmem_offset_ += SmemSize; -+ smem_stage_ ^= 1; -+ } -+ -+ CUTLASS_DEVICE -+ Iterator( -+ SharedStorage& shared_storage_base, -+ const int* col_idx, -+ const int& block_size, -+ const int& base_idx, -+ const int k_shape, -+ const int& problem_size_k, -+ const int& ell_stride, -+ const int& thread_idx) -+ : residue_(0), -+ counter_(0), -+ smem_offset_(0), -+ smem_stage_(0), -+ gmem_offset_(0), -+ block_size_(block_size), -+ base_idx_(base_idx), -+ k_shape_(k_shape), -+ ell_increment_(ell_stride * block_size), -+ array_length_((problem_size_k + block_size_ - 1) / block_size_), -+ residue_shape_(problem_size_k % k_shape_), -+ is_residue_tile_(residue_shape_ != 0), -+ smem_col_idx_(reinterpret_cast(&shared_storage_base.array)), -+ gmem_col_idx_(const_cast(col_idx)), -+ lane_(thread_idx % 32) { -+ -+ load_ell_indices(); -+ __syncthreads(); -+ -+ is_pow2_ = ((block_size_ & (block_size_ - 1)) == 0); -+ if( is_pow2_ && k_shape <= block_size_ ) lane_ = 0; -+ -+ col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_; -+ -+ pow2_ = 0; -+ while(block_size_ >> (pow2_ + 1)) ++pow2_; -+ } -+ -+ CUTLASS_DEVICE -+ int get_blocksize(){ -+ return block_size_; -+ } -+ -+ CUTLASS_DEVICE -+ Iterator &operator++(){ -+ if(is_residue_tile_){ -+ residue_ += residue_shape_; -+ is_residue_tile_ = false; -+ } else { -+ residue_ += k_shape_; -+ } -+ -+ if(residue_ < block_size_){ -+ return *this; -+ } -+ -+ if((array_length_ > SmemSize) && (((smem_offset_ >> SmemPow) & 1) != smem_stage_)) -+ load_ell_indices(); -+ -+ if(residue_ == block_size_){ -+ ++smem_offset_; -+ counter_ += ell_increment_; -+ residue_ = 0; -+ col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_ - counter_; -+ return *this; -+ } -+ -+ if(is_pow2_){ -+ smem_offset_ += 
residue_ >> pow2_; -+ counter_ += (residue_ >> pow2_) * ell_increment_; -+ residue_ = residue_ & ((1 << pow2_) - 1); -+ } -+ else { -+ smem_offset_ += residue_ / block_size_; -+ counter_ += (residue_ / block_size_) * ell_increment_; -+ residue_ %= block_size_; -+ } -+ -+ col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_ - counter_; -+ -+ return *this; -+ } -+ -+ CUTLASS_DEVICE -+ LongIndex get_offset(const int& idx) { -+ int num_jump_tiles; -+ if(is_pow2_) -+ num_jump_tiles = (idx + residue_) >> pow2_; -+ else -+ num_jump_tiles = (idx + residue_) / block_size_; -+ -+ int tmp = __shfl_sync(0xffffffff, col_idx_base_, num_jump_tiles); -+ return tmp - num_jump_tiles * ell_increment_; -+ } -+ -+ CUTLASS_DEVICE -+ LongIndex get_offset_fast() { -+ return col_idx_base_; -+ } -+}; -+ -+} -+} -+} -+} -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h -new file mode 100644 -index 0000000..9eec17e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h -@@ -0,0 +1,1350 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Ell iterator for Blocked-Ell matrix (ellValue matrix) used with EllMmaMultistage -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// EllPredicatedTileAccessIterator -+/// -+template -+class EllPredicatedTileAccessIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of EllPredicatedTileAccessIterator for pitch-linear data. -+/// -+template -+class EllPredicatedTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ static int const kPredicatesPerByte = 4; -+ static int const kPredicatesPerWord = 4 * kPredicatesPerByte; -+ -+ static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector; -+ -+ /// Number of 32b words containing predicates -+ static int const kPredicateByteCount = -+ (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte; -+ static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4; -+ -+ static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u; -+ -+ static_assert(kPredicateWordCount <= 4, "Too many predicates."); -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = Array; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ public: -+ friend EllPredicatedTileAccessIterator; -+ -+ private: -+ /// stride of pitch-linear layout (units of Element) -+ LongIndex stride_; -+ /// amount (in byte) to increment pointer to move to next access along -+ /// strided dimension -+ LongIndex inc_strided_; -+ /// amount (in byte) to increment pointer from last access to first access -+ /// of next tile -+ LongIndex inc_next_; -+ /// amount (in byte) to increment pointer from first access of current tile -+ /// to first access of next tile -+ LongIndex inc_advance_; -+ -+ public: -+ -+ // Default ctor -+ CUTLASS_HOST_DEVICE -+ Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { } -+ -+ /// Construct the Params object 
given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) : stride_(layout.stride(0)) { -+ inc_strided_ = (LongIndex(stride_) * ThreadMap::Delta::kStrided) * -+ sizeof_bits::value / 8; -+ -+ if (kAdvanceRank) { -+ // advance along strided dimension -+ inc_advance_ = -+ Shape::kStrided * LongIndex(stride_) * sizeof_bits::value / 8; -+ } else { -+ // advance along contiguous dimension -+ inc_advance_ = Shape::kContiguous * sizeof_bits::value / 8; -+ } -+ -+ inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kStrided - 1) * -+ ThreadMap::Delta::kStrided * LongIndex(stride_) * -+ sizeof_bits::value / 8; -+ }; -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Parameters object with precomputed internal state -+ Params const ¶ms_; -+ -+ /// Internal pointer to first access of tile -+ BytePointer pointer_; -+ -+ /// Guard predicates -+ uint32_t predicates_[kPredicateWordCount]; -+ -+ /// Size of tensor -+ TensorCoord extent_; -+ -+ /// Initial offset for each thread -+ TensorCoord thread_offset_; -+ -+ /// Offset to the first steady-state tile -+ TensorCoord residue_offset_; -+ -+ /// Initial offset to define ELL block -+ TensorCoord ell_offset_; -+ -+ /// Used for out-of-order visitation -+ bool is_residue_tile_; -+ -+ /// Iteration along vectors implied by the thread map -+ int iteration_vector_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ public: -+ /// Computes predicates based on internally tracked per-thread offset. -+ CUTLASS_DEVICE -+ void compute_predicates_( -+ /// Extent of the matrix window -+ TensorCoord extent, -+ /// optionally, simplify predicate calculation during 'steady state' phase -+ bool is_steady_state = false) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = 0u; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) { -+ -+ int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector); -+ -+ int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector); -+ -+ int c = access_residual / kAccessesPerVector; -+ int v = access_residual % kAccessesPerVector; -+ -+ TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements, -+ s * ThreadMap::Delta::kStrided); -+ -+ TensorCoord coord = thread_offset_ + iteration_coord; -+ -+ bool guard; -+ -+ if (is_steady_state) { -+ if (kAdvanceRank == 0) { -+ guard = (coord.strided() < extent.strided()); -+ } else { -+ guard = (coord.contiguous() < extent.contiguous()); -+ } -+ } else { -+ guard = (coord.strided() < extent.strided() && -+ coord.contiguous() < extent.contiguous()); -+ } -+ -+ int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); -+ -+ int word_idx = pred_idx / kPredicatesPerWord; -+ int residual = pred_idx % kPredicatesPerWord; -+ int byte_idx = residual / kPredicatesPerByte; -+ int bit_idx = residual % kPredicatesPerByte; -+ -+ predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); -+ -+ } -+ -+ } -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator( -+ /// Precomputed parameters 
object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : params_(params), -+ pointer_(reinterpret_cast( -+ const_cast(pointer))), -+ extent_(extent), -+ is_residue_tile_(true) { -+ -+ TensorCoord residue_extent; -+ if (kAdvanceRank) { -+ -+ typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.strided()) % Shape::kStrided; -+ if (!residue_size) { -+ residue_size = Shape::kStrided; -+ } -+ -+ residue_offset_ = make_Coord(0, residue_size); -+ residue_extent = make_Coord( -+ extent_.contiguous(), -+ min(threadblock_offset.strided() + residue_size, extent_.strided()) -+ ); -+ } else { -+ -+ typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.contiguous()) % Shape::kContiguous; -+ if (!residue_size) { -+ residue_size = Shape::kContiguous; -+ } -+ -+ residue_offset_ = make_Coord(residue_size, 0); -+ -+ residue_extent = make_Coord( -+ min(extent_.contiguous(), threadblock_offset.contiguous() + residue_size), -+ extent_.strided() -+ ); -+ } -+ -+ // Per-thread offset in logical coordinates of tensor -+ ell_offset_ = ThreadMap::initial_offset(thread_id); -+ thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id); -+ -+ // update internal pointers -+ Layout layout(params_.stride_); -+ add_pointer_offset(layout(thread_offset_)); -+ -+ compute_predicates_(residue_extent, false); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id) -+ : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += sizeof_bits::value * pointer_offset / 8; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_DEVICE -+ void add_tile_offset( -+ TensorCoord const &tile_offset) { -+ if (is_residue_tile_) { -+ -+ thread_offset_ += residue_offset_; -+ -+ Layout layout(params_.stride_); -+ add_pointer_offset(layout(residue_offset_)); -+ -+ compute_predicates_(extent_, true); -+ -+ if (kAdvanceRank) { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided() - 1); -+ pointer_ += Shape::kContiguous * tile_offset.contiguous(); -+ } else { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1); -+ pointer_ += Shape::kStrided * tile_offset.strided(); -+ } -+ } else { -+ if (kAdvanceRank) { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); -+ pointer_ += Shape::kContiguous * tile_offset.contiguous(); -+ } else { -+ pointer_ += 
params_.inc_advance_ * LongIndex(tile_offset.contiguous()); -+ pointer_ += Shape::kStrided * tile_offset.strided(); -+ } -+ } -+ is_residue_tile_ = false; -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast( -+ pointer_ + -+ iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits::value) / 8) + iteration_vector_; -+ } -+ -+ /// Returns a k_location -+ CUTLASS_HOST_DEVICE -+ int get_k() const { -+ if(kAdvanceRank){ //strided -+ return ell_offset_.strided() + iteration_strided_ * ThreadMap::Delta::kStrided; -+ }else{ -+ return ell_offset_.contiguous() + iteration_contiguous_ * ThreadMap::Delta::kContiguous + iteration_vector_ * AccessType::kElements; -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int get_stride() const { -+ if(kAdvanceRank) -+ return params_.stride_; -+ else -+ return 1; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator &operator++() { -+ -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ -+ iteration_vector_ = 0; -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ pointer_ += params_.inc_strided_; -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ iteration_strided_ = 0; -+ -+ // advance to next tile -+ pointer_ += params_.inc_next_; -+ -+ // now return to start tile - if the iterator is subsequently advanced, this -+ // subtraction as well as the subsequent integer addition are both elided by -+ // the compiler. -+ pointer_ -= params_.inc_advance_; -+ -+ return *this; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator operator++(int) { -+ EllPredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = enable ? 
0u : predicates_[i]; -+ } -+ -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = 0xffffffff; -+ } -+ -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = mask[i]; -+ } -+ -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ mask[i] = predicates_[i]; -+ } -+ } -+ -+ /// add mask for small tiles in ELL -+ CUTLASS_DEVICE -+ void ell_add_mask(int blocksize) { -+ -+ Mask mask; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ mask[i] = 0u; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) { -+ -+ int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector); -+ -+ int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector); -+ -+ int c = access_residual / kAccessesPerVector; -+ int v = access_residual % kAccessesPerVector; -+ -+ TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements, -+ s * ThreadMap::Delta::kStrided); -+ -+ TensorCoord coord = ell_offset_ + iteration_coord; -+ -+ bool guard; -+ -+ if (kAdvanceRank == 0) { -+ guard = (coord.strided() < blocksize); -+ } else { -+ guard = (coord.contiguous() < blocksize); -+ } -+ -+ int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); -+ -+ int word_idx = pred_idx / kPredicatesPerWord; -+ int residual = pred_idx % kPredicatesPerWord; -+ int byte_idx = residual / kPredicatesPerByte; -+ int bit_idx = residual % kPredicatesPerByte; -+ -+ mask[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); -+ -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ mask[i] &= predicates_[i]; -+ } -+ set_mask(mask); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ -+ int pred_idx = -+ iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous); -+ -+ int word_idx = pred_idx / kPredicatesPerWord; -+ int residual = pred_idx % kPredicatesPerWord; -+ int byte_idx = residual / kPredicatesPerByte; -+ int bit_idx = residual % kPredicatesPerByte; -+ -+ bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; -+ return pred; -+ -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of EllPredicatedTileAccessIterator for pitch-linear data. 
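/// (This specialization handles column-major data by delegating to the
/// pitch-linear iterator above: rows map to the contiguous dimension and
/// columns to the strided dimension.)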
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class EllPredicatedTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = EllPredicatedTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend EllPredicatedTileAccessIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))){}; -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.row(), -+ threadblock_offset.column())) {} -+ -+ /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void 
add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int get_k() const { -+ return iterator_.get_k(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int get_stride() const { -+ return iterator_.get_stride(); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator operator++(int) { -+ EllPredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// add mask for small tiles in ELL -+ CUTLASS_DEVICE -+ void ell_add_mask(int blocksize) { -+ iterator_.ell_add_mask(blocksize); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of EllPredicatedTileAccessIterator for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class EllPredicatedTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = EllPredicatedTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 
1 : 0), ThreadMap, AccessType>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend EllPredicatedTileAccessIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))){}; -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int get_k() const { -+ return iterator_.get_k(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int get_stride() const { -+ return iterator_.get_stride(); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. 
-+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator operator++(int) { -+ EllPredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// add mask for small tiles in ELL -+ CUTLASS_DEVICE -+ void ell_add_mask(int blocksize) { -+ iterator_.ell_add_mask(blocksize); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of EllPredicatedTileAccessIterator for column-major interleaved data. -+/// It is mapped to the congruous layout. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+ -+template -+class EllPredicatedTileAccessIterator, -+ AdvanceRank, ThreadMap_, AccessType_> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::ColumnMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = EllPredicatedTileAccessIterator< -+ layout::PitchLinearShape, -+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 
0 : 1), ThreadMap, -+ AccessType>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend EllPredicatedTileAccessIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.row() * kInterleavedK, -+ extent.column() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.row() * kInterleavedK, -+ threadblock_offset.column() / kInterleavedK)) {} -+ -+ /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int get_k() const { -+ return iterator_.get_k(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int get_stride() const { -+ return iterator_.get_stride(); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. 
-+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator operator++(int) { -+ EllPredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// add mask for small tiles in ELL -+ CUTLASS_DEVICE -+ void ell_add_mask(int blocksize) { -+ iterator_.ell_add_mask(blocksize); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { return iterator_.valid(); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of EllPredicatedTileAccessIterator for row-major interleaved data. -+/// It is mapped to the congruous layout. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class EllPredicatedTileAccessIterator, -+ AdvanceRank, ThreadMap_, AccessType_> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::RowMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = EllPredicatedTileAccessIterator< -+ layout::PitchLinearShape, -+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 
1 : 0), ThreadMap, -+ AccessType>; -+ -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend EllPredicatedTileAccessIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.column() * kInterleavedK, -+ extent.row() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.column() * kInterleavedK, -+ threadblock_offset.row() / kInterleavedK)) {} -+ -+ /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int get_k() const { -+ return iterator_.get_k(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int get_stride() const { -+ return iterator_.get_stride(); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. 
-+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator operator++(int) { -+ EllPredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// add mask for small tiles in ELL -+ CUTLASS_DEVICE -+ void ell_add_mask(int blocksize) { -+ iterator_.ell_add_mask(blocksize); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { return iterator_.valid(); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h -new file mode 100644 -index 0000000..f984733 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h -@@ -0,0 +1,1315 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Ell iterator for Blocked-Ell matrix (ellValue matrix) used with EllMmaPipelined -+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/memory.h" -+#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" -+ -+#include "cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h" -+#include "cutlass/transform/threadblock/ell_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// EllPredicatedTileIterator -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+/// Regular tile iterator using a precomputed control structure to minimize register liveness -+/// and integer arithmetic. -+/// -+/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed. -+/// -+/// Base pointer and tensor extents may be specified at the time the iterator is constructed. -+/// Subsequently, they are assumed to be immutable. -+/// -+/// Adding a logical coordinate offset may be performed at the time the iterator is constructed. -+/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive. -+/// -+/// Visitation order is intended to first visit a "residual" tile that may be partially full in -+/// both the advance dimension and the steady-state dimension. This is assumed to be the last -+/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to -+/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent -+/// accesses may be performed without updating internal predicates and are efficient in terms of -+/// live register state and pointer arithmetic instructions. -+/// -+/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once -+/// outside any looping structure to minimize integer arithmetic. -+/// -+/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing -+/// the iterator. -+/// -+/// -+/// Example: -+/// -+/// An efficient pipeline structure may be constructed as follows: -+/// -+// template -+// __global__ void kernel( -+// typename Iterator::Params params, -+// typename Iterator::Element *ptr, -+// TensorCoord extent) { -+// -+// typename Iterator::Fragment fragment; -+// -+// TensorCoord threadblock_offset(0, 0); -+// -+// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); -+// -+// -+// fragment = *iter; // load "residue" tile first -+// ++iter; // advance to first "steady state" tile and update internal masks -+// -+// -+// #pragma unroll -+// for (int i = Remaining - 1; i >= 0; --i) { -+// -+// f(fragment); -+// -+// if (!i) { -+// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs. 
-+// } -+// -+// fragment = *iter; // load tile during "steady state" phase -+// ++iter; // advance to next tile - lightweight due to steady-state masks -+// } -+// } -+// -+// void host(TensorView view) { -+// -+// using Iterator = transform::threadblock::EllPredicatedTileIterator; -+// -+// typename Iterator::Params params(view.layout()); -+// -+// kernel(params, view.data()); -+// } -+/// -+/// -+template < -+ typename Shape, -+ typename Element, -+ typename Layout, -+ int AdvanceRank, -+ typename ThreadMap, -+ int AccessSize = ThreadMap::kElementsPerAccess -+> -+class EllPredicatedTileIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of EllPredicatedTileIterator for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class EllPredicatedTileIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ /// Type used for internal memory accesses -+ using AccessType = AlignedArray::value / 8)>; -+ -+ /// Underlying iterator to compute the addresses -+ using TileAccessIterator = -+ EllPredicatedTileAccessIterator; -+ -+ static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename TileAccessIterator::Mask; -+ -+ /// Iterator for ELL storage -+ using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ public: -+ friend EllPredicatedTileIterator; -+ -+ private: -+ /// Parameters object -+ typename TileAccessIterator::Params params_; -+ -+ public: -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) : params_(layout) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Data member to the tile access iterator -+ TileAccessIterator address_iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : address_iterator_(params.params_, pointer, extent, thread_id, -+ threadblock_offset) {} -+ -+ /// 
Construct a EllPredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : EllPredicatedTileIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ address_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator &operator++() { -+ if (kAdvanceRank) -+ address_iterator_.add_tile_offset({0, 1}); -+ else -+ address_iterator_.add_tile_offset({1, 0}); -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator operator++(int) { -+ EllPredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Returns a stride -+ CUTLASS_HOST_DEVICE -+ int get_stride() const { return address_iterator_.get_stride(); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { address_iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } -+ -+ /// add mask for small tiles in ELL -+ CUTLASS_HOST_DEVICE -+ void ell_add_mask(int blocksize) { address_iterator_.ell_add_mask(blocksize); } -+ -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ address_iterator_.set_iteration_index(idx); -+ char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; -+ -+ AccessType const *access_ptr = reinterpret_cast(byte_ptr); -+ -+ cutlass::arch::global_load( -+ frag_ptr[idx], access_ptr, address_iterator_.valid()); -+ -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_byte_offset(frag, 0); } -+ -+ 
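// The guarded fragment load in load_with_byte_offset above flattens three loops
// (strided rows of the thread map, contiguous columns, and accesses per vector)
// into one fragment index and skips any access whose predicate is off. A minimal
// host-only sketch of that indexing and predication; the ThreadMap constants are
// illustrative assumptions, not values derived from a real CUTLASS ThreadMap.
#include <array>
#include <cstdio>

constexpr int kStrided = 2;           // assumed ThreadMap::Iterations::kStrided
constexpr int kContiguous = 2;        // assumed ThreadMap::Iterations::kContiguous
constexpr int kAccessesPerVector = 2; // assumed accesses per vectorized load

void load_fragment(std::array<float, kStrided * kContiguous * kAccessesPerVector> &frag,
                   const float *gmem, const bool *valid) {
  for (int s = 0; s < kStrided; ++s) {
    for (int c = 0; c < kContiguous; ++c) {
      for (int v = 0; v < kAccessesPerVector; ++v) {
        // Same linearization as above: idx enumerates fragment slots.
        int idx = v + kAccessesPerVector * (c + s * kContiguous);
        if (valid[idx]) {  // predicated load: out-of-bounds slots stay untouched
          frag[idx] = gmem[idx];
        }
      }
    }
  }
}

int main() {
  std::array<float, 8> frag{};
  float gmem[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  bool valid[8] = {true, true, true, false, true, true, false, true};
  load_fragment(frag, gmem, valid);
  std::printf("frag[3]=%f frag[4]=%f\n", frag[3], frag[4]);
  return 0;
}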
CUTLASS_DEVICE -+ void load_with_ell_index(Fragment &frag, EllIterator &ell_iter) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ address_iterator_.set_iteration_index(idx); -+ LongIndex ell_offset = 0; -+ -+ int k_offset = address_iterator_.get_k(); -+ ell_offset = ell_iter.get_offset(k_offset) * sizeof(Element); -+ -+ char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + ell_offset; -+ -+ AccessType const *access_ptr = reinterpret_cast(byte_ptr); -+ -+ bool is_valid = address_iterator_.valid(); -+ is_valid = is_valid && (ell_offset >= 0); -+ -+ cutlass::arch::global_load( -+ frag_ptr[idx], access_ptr, is_valid); -+ -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_ell_index_fast(Fragment &frag, EllIterator &ell_iter) { -+ -+ LongIndex ell_offset = ell_iter.get_offset_fast() * sizeof(Element); -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ address_iterator_.set_iteration_index(idx); -+ char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + ell_offset; -+ -+ AccessType const *access_ptr = reinterpret_cast(byte_ptr); -+ -+ bool is_valid = address_iterator_.valid(); -+ is_valid = is_valid && (ell_offset >= 0); -+ -+ cutlass::arch::global_load( -+ frag_ptr[idx], access_ptr, is_valid); -+ -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ address_iterator_.set_iteration_index(0); -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; -+ AccessType *access_ptr = reinterpret_cast(byte_ptr); -+ -+ if (address_iterator_.valid()) { -+ *access_ptr = frag_ptr[idx]; -+ } -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_byte_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of EllPredicatedTileIterator for pitch-linear data. 
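// load_with_ell_index above asks the ELL iterator for a per-access offset derived
// from the current k coordinate and folds "offset is valid" into the load
// predicate, so padded Blocked-ELL entries become no-op loads. A standalone sketch
// of that lookup under assumed field names; the real ell::Iterator interface and
// offset formula differ.
#include <cstdint>
#include <cstdio>
#include <vector>

struct EllIndexTable {
  std::vector<int> block_column;  // column-block index per (block row, slot); -1 = padding
  int blocks_per_row = 0;
  int blocksize = 0;              // edge length of one ELL block

  // Maps a k coordinate inside the tile to an element offset in the dense
  // operand, or -1 when the slot is padding and the access must be masked.
  std::int64_t get_offset(int block_row, int slot, int k_in_block) const {
    int col = block_column[block_row * blocks_per_row + slot];
    if (col < 0) return -1;
    return std::int64_t(col) * blocksize + k_in_block;
  }
};

int main() {
  EllIndexTable ell{{0, 2, -1, 1, 3, -1}, /*blocks_per_row=*/3, /*blocksize=*/16};
  for (int slot = 0; slot < ell.blocks_per_row; ++slot) {
    std::int64_t off = ell.get_offset(/*block_row=*/0, slot, /*k_in_block=*/4);
    bool is_valid = off >= 0;  // combined with the address predicate before loading
    std::printf("slot %d: offset %lld valid %d\n", slot, (long long)off, int(is_valid));
  }
  return 0;
}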
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize -+> -+class EllPredicatedTileIterator { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = EllPredicatedTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap, -+ AccessSize -+ >; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Iterator for ELL storage -+ using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ -+ friend EllPredicatedTileIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { -+ -+ } -+ }; -+ -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const &threadblock_offset ///< Initial offset of threadblock -+ ): -+ iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()) -+ ) { } -+ -+ /// Construct a EllPredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ): EllPredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. 
-+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator operator++(int) { -+ EllPredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Returns a stride -+ CUTLASS_HOST_DEVICE -+ int get_stride() const { return iterator_.get_stride(); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// add mask for small tiles in ELL -+ CUTLASS_HOST_DEVICE -+ void ell_add_mask(int blocksize) { -+ iterator_.ell_add_mask(blocksize); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) { -+ iterator_.load_with_ell_index(frag, ell_iter); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) { -+ iterator_.load_with_ell_index_fast(frag, ell_iter); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of EllPredicatedTileIterator for pitch-linear data. 
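// The operator++ contract documented above (visit the possibly-partial "residue"
// tile first, then recompute predicates once and fall into a cheap steady state)
// can be illustrated with a one-dimensional toy iterator. The tile size, extent,
// and member names are arbitrary assumptions made only for this sketch.
#include <cstdio>

struct ResidueFirstIterator {
  int extent;       // total number of valid elements
  int tile;         // tile size in elements
  int offset = 0;   // start of the tile currently being visited
  int valid_count;  // predicate summary: how many elements of this tile are in bounds
  bool in_residue = true;

  ResidueFirstIterator(int extent_, int tile_) : extent(extent_), tile(tile_) {
    int residue = extent % tile ? extent % tile : tile;
    valid_count = residue;  // the first visit covers only the residue
  }

  ResidueFirstIterator &operator++() {
    if (in_residue) {
      // First advance: leave the residue tile and recompute predicates once.
      offset += valid_count;
      valid_count = tile;
      in_residue = false;
    } else {
      offset += tile;  // steady state: a pointer bump only
    }
    return *this;
  }
};

int main() {
  ResidueFirstIterator it(/*extent=*/10, /*tile=*/4);
  for (int i = 0; i < 3; ++i, ++it)
    std::printf("tile %d: offset %d, valid %d\n", i, it.offset, it.valid_count);
  return 0;
}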
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize -+> -+class EllPredicatedTileIterator { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = EllPredicatedTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap, -+ AccessSize -+ >; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Iterator for ELL storage -+ using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ -+ friend EllPredicatedTileIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { -+ -+ }; -+ }; -+ -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const &threadblock_offset ///< Initial offset of threadblock -+ ): -+ iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()) -+ ) { } -+ -+ /// Construct a EllPredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ): EllPredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. 
-+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator operator++(int) { -+ EllPredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Returns a stride -+ CUTLASS_HOST_DEVICE -+ int get_stride() const { return iterator_.get_stride(); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// add mask for small tiles in ELL -+ CUTLASS_HOST_DEVICE -+ void ell_add_mask(int blocksize) { -+ iterator_.ell_add_mask(blocksize); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) { -+ iterator_.load_with_ell_index(frag, ell_iter); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) { -+ iterator_.load_with_ell_index_fast(frag, ell_iter); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of EllPredicatedTileIterator for interleaved data. It is mapped -+/// to the congruous layout. 
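// Both matrix-layout specializations above reuse the pitch-linear iterator by
// reordering coordinates: column-major keeps (row, column) as (contiguous,
// strided), while row-major swaps them. A small sketch of that adapter idea;
// PitchLinearCoord here is a stand-in struct, not the CUTLASS type.
#include <cstdio>

struct PitchLinearCoord {
  int contiguous;
  int strided;
};

// Column-major: rows are contiguous in memory.
PitchLinearCoord from_column_major(int row, int column) {
  return {row, column};
}

// Row-major: columns are contiguous in memory, so the roles swap.
PitchLinearCoord from_row_major(int row, int column) {
  return {column, row};
}

int main() {
  PitchLinearCoord a = from_column_major(3, 7);
  PitchLinearCoord b = from_row_major(3, 7);
  std::printf("column-major -> (%d, %d), row-major -> (%d, %d)\n",
              a.contiguous, a.strided, b.contiguous, b.strided);
  return 0;
}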
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+ -+template -+class EllPredicatedTileIterator, -+ AdvanceRank, ThreadMap_, AccessSize> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::ColumnMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = EllPredicatedTileIterator< -+ layout::PitchLinearShape, -+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessSize>; -+ -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Iterator for ELL storage -+ using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend EllPredicatedTileIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.row() * kInterleavedK, -+ extent.column() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.row() * kInterleavedK, -+ threadblock_offset.column() / kInterleavedK)) {} -+ -+ /// Construct a EllPredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : EllPredicatedTileIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the 
next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator operator++(int) { -+ EllPredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Returns a stride -+ CUTLASS_HOST_DEVICE -+ int get_stride() const { return iterator_.get_stride(); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// add mask for small tiles in ELL -+ CUTLASS_HOST_DEVICE -+ void ell_add_mask(int blocksize) { iterator_.ell_add_mask(blocksize); } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) { -+ iterator_.load_with_ell_index(frag, ell_iter); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) { -+ iterator_.load_with_ell_index_fast(frag, ell_iter); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of EllPredicatedTileIterator for interleaved-32 data. It is -+/// mapped to the congruous layout. 
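// The interleaved specializations fold the interleaving factor K into the
// pitch-linear extent: the contiguous dimension grows by K while the strided
// dimension shrinks by K, so the underlying iterator still sees a plain
// pitch-linear tensor, as in the constructor above. A sketch of that mapping
// for the column-major-interleaved case; K and the extents are assumed values.
#include <cstdio>

struct PitchLinearExtent {
  long long contiguous;
  long long strided;
};

// ColumnMajorInterleaved with factor K: (rows, columns) -> (rows * K, columns / K).
PitchLinearExtent map_column_major_interleaved(long long rows, long long columns, int K) {
  return {rows * K, columns / K};
}

int main() {
  const int K = 32;  // assumed interleaving factor
  PitchLinearExtent e = map_column_major_interleaved(/*rows=*/128, /*columns=*/256, K);
  std::printf("pitch-linear extent: contiguous=%lld strided=%lld\n",
              e.contiguous, e.strided);
  return 0;
}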
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class EllPredicatedTileIterator, -+ AdvanceRank, ThreadMap_, AccessSize> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::RowMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = EllPredicatedTileIterator< -+ layout::PitchLinearShape, -+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessSize>; -+ -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend EllPredicatedTileIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.column() * kInterleavedK, -+ extent.row() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.column() * kInterleavedK, -+ threadblock_offset.row() / kInterleavedK)) {} -+ -+ /// Construct a EllPredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : EllPredicatedTileIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. 
-+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator operator++(int) { -+ EllPredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Returns a stride -+ CUTLASS_HOST_DEVICE -+ int get_stride() const { return iterator_.get_stride(); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// add mask for small tiles in ELL -+ CUTLASS_HOST_DEVICE -+ void ell_add_mask(int blocksize) { iterator_.ell_add_mask(blocksize); } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h -new file mode 100644 -index 0000000..61bed18 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h -@@ -0,0 +1,375 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Templates calculating the address and predicates to the load of scale and bias vectors. -+ -+ This iterator uses masks to guard out-of-bounds accesses. -+ -+ It can be used to load the gamma and beta vectors of layernorm which is loop variant. -+ -+ A precomputed "Params" object minimizes the amount of state that must be -+ stored in registers, and integer addition is used to advance the pointer -+ through memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedScaleBiasVectorAccessIterator -+/// -+template -+class PredicatedScaleBiasVectorAccessIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for fprop pitch-linear data. 
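// The pitch-linear specialization that follows splits the thread block in two:
// the first group of threads walks the scale (gamma) vector and the second group
// walks the bias (beta) vector, while a per-thread guard masks elements past the
// problem size. A host-side sketch of that partitioning and guard; the vector
// width, thread count, and the simplified guard condition are assumptions made
// for illustration only.
#include <cstdio>

constexpr int kElementsPerAccess = 4;  // assumed: 128 bits / 32-bit element
constexpr int kThreads = 8;            // assumed: threads needed to cover one tile row

struct ScaleBiasAccess {
  const float *pointer;  // either the scale or the bias vector
  int element_offset;    // first element this thread touches
  bool guard;            // false once the thread runs past problem_size_k
};

ScaleBiasAccess partition(int thread_id, int problem_size_k,
                          const float *scale, const float *bias) {
  bool is_scale_group = thread_id < kThreads;
  int lane = is_scale_group ? thread_id : thread_id - kThreads;
  ScaleBiasAccess a;
  a.pointer = is_scale_group ? scale : bias;
  a.element_offset = lane * kElementsPerAccess;
  a.guard = a.element_offset < problem_size_k;  // mask out-of-bounds vector loads
  return a;
}

int main() {
  float scale[32] = {}, bias[32] = {};
  ScaleBiasAccess t0 = partition(/*thread_id=*/0, /*problem_size_k=*/20, scale, bias);
  ScaleBiasAccess t9 = partition(/*thread_id=*/9, /*problem_size_k=*/20, scale, bias);
  std::printf("t0: scale-group offset=%d guard=%d; t9: bias-group offset=%d guard=%d\n",
              t0.element_offset, int(t0.guard), t9.element_offset, int(t9.guard));
  return 0;
}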
-+/// -+template -+class PredicatedScaleBiasVectorAccessIterator { -+ public: -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ConstPointer = const Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ static int const kElementsPerAccess = 128 / sizeof_bits::value; -+ static int const kThreads = ThreadblockShape::kContiguous / kElementsPerAccess; -+ -+ using AccessType = AlignedArray; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to first access of tile -+ BytePointer pointer_; -+ -+ TensorCoord thread_offset_; -+ -+ int problem_size_k_; -+ -+ /// Used for out-of-order visitation -+ bool is_residue_tile_; -+ -+ bool guard_; -+ -+ TensorCoord::Index residue_size_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ /// Extent of tensor -+ int problem_size_k, -+ /// Pointer to the start of the scale vector -+ ConstPointer scale_pointer, -+ /// Pointer to the start of the bias vector -+ ConstPointer bias_pointer, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) { -+ pointer_ = (thread_id < kThreads) -+ ? reinterpret_cast( -+ const_cast(scale_pointer)) -+ : reinterpret_cast( -+ const_cast(bias_pointer)); -+ -+ // Per-thread offset in logical coordinates of tensor -+ int thread_base = (thread_id < kThreads) ? 0 : kThreads; -+ -+ problem_size_k_ = problem_size_k; -+ -+ is_residue_tile_ = true; -+ -+ residue_size_ = (problem_size_k_ - threadblock_offset.contiguous()) % ThreadblockShape::kContiguous; -+ -+ if (residue_size_ == 0) { -+ residue_size_ = ThreadblockShape::kContiguous; -+ } -+ -+ guard_ = ((thread_id - thread_base) * kElementsPerAccess) < residue_size_; -+ -+ thread_offset_ = -+ threadblock_offset + -+ TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ /// Extent of tensor -+ int problem_size_k, -+ /// Pointer to start of scale vector -+ ConstPointer scale_pointer, -+ /// Pointer to start of scale vector -+ ConstPointer bias_pointer, -+ ///< ID of each participating thread -+ int thread_id) -+ : PredicatedScaleBiasVectorAccessIterator(problem_size_k, -+ scale_pointer, bias_pointer, -+ thread_id, make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) {} -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole threadblock tiles -+ CUTLASS_DEVICE -+ void add_tile_offset( -+ TensorCoord const &tile_offset) { -+ -+ guard_ = threadIdx.x < kThreads * 2; -+ -+ TensorCoord offset = is_residue_tile_ ? 
-+ TensorCoord(residue_size_ + ThreadblockShape::kContiguous * (tile_offset.contiguous() - 1), 0) -+ : TensorCoord(ThreadblockShape::kContiguous * tile_offset.contiguous(), 0); -+ -+ thread_offset_ = -+ thread_offset_ + -+ offset; -+ -+ is_residue_tile_ = false; -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ -+ return reinterpret_cast( -+ pointer_ + -+ (thread_offset_.contiguous() * sizeof_bits::value / 8)); -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator &operator++() { -+ return *this; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_DEVICE -+ PredicatedScaleBiasVectorAccessIterator operator++(int) { -+ PredicatedScaleBiasVectorAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ guard_ &= (!enable); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return guard_; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for row-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedScaleBiasVectorAccessIterator { -+ public: -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ConstPointer = const Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedScaleBiasVectorAccessIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ ///< Extent of tensor -+ int problem_size_k, -+ ///< Pointer to the start of the scale vector -+ ConstPointer scale_pointer, -+ ///< Pointer to the start of the bias vector -+ ConstPointer bias_pointer, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(problem_size_k, scale_pointer, bias_pointer, -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ int problem_size_k, ///< Extent of tensor -+ ConstPointer scale_pointer, ///< Pointer to the start of the scale vector -+ ConstPointer bias_pointer, ///< Pointer to the start of the bias vector -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedScaleBiasVectorAccessIterator(problem_size_k, -+ scale_pointer, 
bias_pointer, -+ thread_id, make_Coord(0, 0)) {} -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// threadblock tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator operator++(int) { -+ PredicatedScaleBiasVectorAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h -new file mode 100644 -index 0000000..fb08930 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h -@@ -0,0 +1,328 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Templates calculating the address and predicates to the load of scale and bias vectors. -+ -+ This iterator uses masks to guard out-of-bounds accesses. -+ -+ This can be used to load var and mean vectors in layernorm which is loop invariant. -+ -+ A precomputed "Params" object minimizes the amount of state that must be -+ stored in registers, and integer addition is used to advance the pointer -+ through memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedScaleBiasVectorIterator -+/// -+template -+class PredicatedScaleBiasVectorIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for wgrad pitch-linear data. 
-+/// -+template -+class PredicatedScaleBiasVectorIterator { -+ public: -+ -+ using WarpShape = WarpShape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ConstPointer = const Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ static int const kElementsPerAccess = 1; -+ -+ using AccessType = AlignedArray; -+ -+ static int const kIterations = WarpShape::kContiguous / 8; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array<__half2, 2 * kIterations * kElementsPerAccess>; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to first access of tile -+ ConstPointer scale_pointer_; -+ ConstPointer bias_pointer_; -+ -+ /// Size of tensor -+ int problem_size_; -+ -+ int32_t thread_offset_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorIterator( -+ /// Extent of tensor -+ int problem_size, -+ /// Pointer to the start of the scale vector -+ ConstPointer scale_pointer, -+ /// Pointer to the start of the bias vector -+ ConstPointer bias_pointer, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : problem_size_(problem_size), -+ scale_pointer_(scale_pointer), -+ bias_pointer_(bias_pointer) { -+ -+ thread_offset_ = threadblock_offset.contiguous() + (thread_id % 32) / 4; -+ } -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorIterator( -+ /// Extent of tensor -+ int problem_size, -+ /// Pointer to start of scale vector -+ ConstPointer scale_pointer, -+ /// Pointer to start of scale vector -+ ConstPointer bias_pointer, -+ ///< ID of each participating thread -+ int thread_id) -+ : PredicatedScaleBiasVectorIterator(problem_size, -+ scale_pointer, bias_pointer, -+ thread_id, make_Coord(0, 0)) {} -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole warp tiles -+ CUTLASS_DEVICE -+ void add_tile_offset( -+ TensorCoord const &tile_offset) { -+ -+ thread_offset_ += (WarpShape::kContiguous * tile_offset.contiguous()); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ frag.fill(__float2half2_rn(0.0f)); -+ __half2 *frag_ptr = reinterpret_cast<__half2 *>(&frag); -+ -+ // load scale -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < kIterations; ++c) { -+ -+ cutlass::arch::global_load< -+ __half, -+ sizeof(AccessType) -+ >( -+ frag_ptr[c * 2].x, -+ scale_pointer_ + thread_offset_ + c * 8, -+ (thread_offset_ + c * 8) < problem_size_ -+ ); -+ } -+ -+ // load bias -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < kIterations; ++c) { -+ -+ cutlass::arch::global_load< -+ __half, -+ sizeof(AccessType) -+ >( -+ frag_ptr[c * 2 + 1].x, -+ bias_pointer_ + thread_offset_ + c * 8, -+ (thread_offset_ + c * 8) < problem_size_ -+ ); -+ } -+ -+ // duplicate scale -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < kIterations; ++c) { -+ frag_ptr[c * 2].y = frag_ptr[c * 2].x; -+ } -+ -+ // duplicate bias -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < kIterations; ++c) { -+ frag_ptr[c * 2 + 1].y = frag_ptr[c * 2 + 
1].x; -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for row-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedScaleBiasVectorIterator { -+ public: -+ -+ using WarpShape = WarpShape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ConstPointer = const Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedScaleBiasVectorIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; -+ using Fragment = typename UnderlyingIterator::Fragment; -+ -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorIterator( -+ ///< Extent of tensor -+ int problem_size, -+ ///< Pointer to the start of the scale vector -+ ConstPointer scale_pointer, -+ ///< Pointer to the start of the bias vector -+ ConstPointer bias_pointer, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(problem_size, scale_pointer, bias_pointer, -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorIterator( -+ int problem_size, ///< Extent of tensor -+ ConstPointer scale_pointer, ///< Pointer to the start of the scale vector -+ ConstPointer bias_pointer, ///< Pointer to the start of the bias vector -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedScaleBiasVectorIterator(problem_size, -+ scale_pointer, bias_pointer, -+ thread_id, make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// threadblock tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ iterator_.load(frag); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass 
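Editorial aside (not part of the removed patch): the wgrad scale/bias iterator above packs scale and bias values interleaved into a small per-thread fragment, zero-fills any access that falls past the problem size, and duplicates each loaded value into both halves of a pair. Below is a minimal host-side sketch of that packing, assuming plain float pairs instead of __half2 and a single thread in place of the warp-level thread mapping; the helper name load_scale_bias is illustrative only and is not CUTLASS API.

#include <cstdio>
#include <vector>

struct Pair { float x, y; };   // stand-in for __half2

// Gather `iterations` scale values and `iterations` bias values starting at
// thread_offset, stepping by 8 elements per iteration (mirroring the c * 8
// stride in the iterator above), guarding every load against problem_size and
// duplicating each loaded value into both lanes of the pair.
std::vector<Pair> load_scale_bias(const float* scale, const float* bias,
                                  int problem_size, int thread_offset,
                                  int iterations) {
  std::vector<Pair> frag(2 * iterations, Pair{0.0f, 0.0f});  // frag.fill(0)
  for (int c = 0; c < iterations; ++c) {
    int idx = thread_offset + c * 8;
    bool guard = idx < problem_size;         // predicate for this access
    float s = guard ? scale[idx] : 0.0f;     // guarded load of scale
    float b = guard ? bias[idx]  : 0.0f;     // guarded load of bias
    frag[2 * c]     = Pair{s, s};            // scale duplicated into .x/.y
    frag[2 * c + 1] = Pair{b, b};            // bias duplicated into .x/.y
  }
  return frag;
}

int main() {
  std::vector<float> scale{1.0f, 2.0f, 3.0f}, bias{10.0f, 20.0f, 30.0f};
  // problem_size = 3, thread_offset = 0, iterations = 2: the second access
  // (index 8) falls outside the problem size and stays zero-filled.
  auto frag = load_scale_bias(scale.data(), bias.data(), 3, 0, 2);
  for (const auto& p : frag) std::printf("(%g, %g) ", p.x, p.y);
  std::printf("\n");
  return 0;
}

Running this prints "(1, 1) (10, 10) (0, 0) (0, 0)", which matches the interleaved scale/bias layout and zero-fill behavior of the fragment produced by load_with_pointer_offset above.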
-+
-+////////////////////////////////////////////////////////////////////////////////
-diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h
-new file mode 100644
-index 0000000..29fa8af
---- /dev/null
-+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h
-@@ -0,0 +1,2085 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*! \file
-+    \brief Templates calculating the address and predicates to the load of tiles
-+    from pitch-linear rank=2 tensors.
-+
-+    This iterator uses masks to guard out-of-bounds accesses. The first tile this
-+    iterator visits maybe partial, then the remaining tiles are complete. So, we
-+    only need to compute the predicates twice, once before the first tile and
-+    once for the remaining full tiles which can share the same predicates.
-+
-+    A precomputed "Params" object minimizes the amount of state that must be
-+    stored in registers, and integer addition is used to advance the pointer
-+    through memory.
-+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedTileAccessIteratorPredicates -+/// -+template -+class PredicatedTileAccessIteratorPredicates { -+ public: -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ static int const kPredicatesPerByte = 4; -+ static int const kPredicatesPerWord = 4 * kPredicatesPerByte; -+ -+ static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector; -+ -+ /// Number of 32b words containing predicates -+ static int const kPredicateByteCount = -+ (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte; -+ static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4; -+ -+ static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u; -+ -+ static_assert(kPredicateWordCount <= 4, "Too many predicates."); -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = Array; -+ -+// private: -+ /// Guard predicates -+ uint32_t predicates_[kPredicateWordCount]; -+ -+ /// Size of tensor -+ TensorCoord extent_; -+ -+ /// Initial offset for each thread -+ TensorCoord thread_offset_; -+ -+ /// Offset to the first steady-state tile -+ TensorCoord residue_offset_; -+ -+ /// Iteration along vectors implied by the thread map -+ int iteration_vector_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ public: -+ /// Computes predicates based on internally tracked per-thread offset. 
-+ CUTLASS_DEVICE -+ void compute_predicates_( -+ /// Extent of the matrix window -+ TensorCoord extent, -+ /// optionally, simplify predicate calculation during 'steady state' phase -+ bool is_steady_state = false) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = 0u; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) { -+ -+ int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector); -+ -+ int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector); -+ -+ int c = access_residual / kAccessesPerVector; -+ int v = access_residual % kAccessesPerVector; -+ -+ TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements, -+ s * ThreadMap::Delta::kStrided); -+ -+ TensorCoord coord = thread_offset_ + iteration_coord; -+ -+ bool guard; -+ -+ if (is_steady_state) { -+ if (kAdvanceRank == 0) { -+ guard = (coord.strided() < extent.strided()); -+ } else { -+ guard = (coord.contiguous() < extent.contiguous()); -+ } -+ } else { -+ guard = (coord.strided() < extent.strided() && -+ coord.contiguous() < extent.contiguous()); -+ } -+ -+ int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); -+ -+ int word_idx = pred_idx / kPredicatesPerWord; -+ int residual = pred_idx % kPredicatesPerWord; -+ int byte_idx = residual / kPredicatesPerByte; -+ int bit_idx = residual % kPredicatesPerByte; -+ -+ predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); -+ -+ } -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_predicates(int thread_id, TensorCoord const &threadblock_offset) { -+ -+ TensorCoord residue_extent; -+ if (kAdvanceRank) { -+ -+ typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.strided()) % Shape::kStrided; -+ if (!residue_size) { -+ residue_size = Shape::kStrided; -+ } -+ -+ residue_offset_ = make_Coord(0, residue_size); -+ residue_extent = make_Coord( -+ extent_.contiguous(), -+ min(threadblock_offset.strided() + residue_size, extent_.strided()) -+ ); -+ } else { -+ -+ typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.contiguous()) % Shape::kContiguous; -+ if (!residue_size) { -+ residue_size = Shape::kContiguous; -+ } -+ -+ residue_offset_ = make_Coord(residue_size, 0); -+ -+ residue_extent = make_Coord( -+ min(extent_.contiguous(), threadblock_offset.contiguous() + residue_size), -+ extent_.strided() -+ ); -+ } -+ -+ // Per-thread offset in logical coordinates of tensor -+ thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id); -+ -+ compute_predicates_(residue_extent, false); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Default constructor -+ PredicatedTileAccessIteratorPredicates() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorPredicates( -+ /// Extent of tensor -+ TensorCoord extent) -+ : extent_(extent) { -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ -+ } -+ -+ /// Increment and return 
an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorPredicates &operator++() { -+ -+ return *this; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = enable ? 0u : predicates_[i]; -+ } -+ -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = 0xffffffff; -+ } -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = mask[i]; -+ } -+ -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ mask[i] = predicates_[i]; -+ } -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ -+ int pred_idx = -+ iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous); -+ -+ int word_idx = pred_idx / kPredicatesPerWord; -+ int residual = pred_idx % kPredicatesPerWord; -+ int byte_idx = residual / kPredicatesPerByte; -+ int bit_idx = residual % kPredicatesPerByte; -+ -+ bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; -+ return pred; -+ -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedTileAccessIterator -+/// -+template -+class PredicatedTileAccessIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for pitch-linear data. 
-+/// -+template -+class PredicatedTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< -+ Shape, Element, Layout, AdvanceRank, ThreadMap, AccessType>; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ using Mask = typename UnderlyingPredicates::Mask; -+ -+ /// Uses a non-template class -+ struct Params : PredicatedTileAccessIteratorParams { -+ -+ using Base = PredicatedTileAccessIteratorParams; -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) : -+ Base(layout.stride(0), -+ MakePredicatedTileAccessIteratorDesc()() -+ ) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Base const &base) : -+ Base(base) { } -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ UnderlyingPredicates the_predicates; -+ -+ /// Parameters object with precomputed internal state -+ Params params_; -+ -+ /// Internal pointer to first access of tile -+ BytePointer pointer_; -+ -+ /// Used for out-of-order visitation -+ bool is_residue_tile_; -+ -+ /// Below is used when Gather is turned on. We need to record strided_offset -+ /// and contiguous_offset seperated to compute the offset by using -+ /// -+ /// offset = contiguous_offset + indices[strided_offset] -+ /// -+ -+ /// Gather indices -+ int const *indices_; -+ -+ Index gather_offset_strided; -+ -+ private: -+ /// Computes predicates based on internally tracked per-thread offset. 
-+ CUTLASS_DEVICE -+ void compute_predicates_( -+ /// Extent of the matrix window -+ TensorCoord extent, -+ /// optionally, simplify predicate calculation during 'steady state' phase -+ bool is_steady_state = false) { -+ the_predicates.compute_predicates_(extent, is_steady_state); -+ } -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileAccessIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ /// Gather indices -+ int const *indices = nullptr) -+ : params_(params), -+ pointer_(reinterpret_cast( -+ const_cast(pointer))), -+ the_predicates(extent), -+ is_residue_tile_(true), -+ indices_(indices) { -+ -+ the_predicates.set_predicates(thread_id, threadblock_offset); -+ -+ // update internal pointers -+ Layout layout(params_.stride_); -+ -+ if (!Gather) { -+ add_pointer_offset(layout(the_predicates.thread_offset_)); -+ } else { -+ gather_offset_strided = the_predicates.thread_offset_.strided(); -+ add_pointer_offset(layout(make_Coord(the_predicates.thread_offset_.contiguous(), 0))); -+ } -+ } -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id) -+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ the_predicates.set_iteration_index(index); -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += sizeof_bits::value * pointer_offset / 8; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_DEVICE -+ void add_tile_offset( -+ TensorCoord const &tile_offset) { -+ if (is_residue_tile_) { -+ -+ the_predicates.thread_offset_ += the_predicates.residue_offset_; -+ -+ the_predicates.compute_predicates_(the_predicates.extent_, true); -+ -+ Layout layout(params_.stride_); -+ -+ if (!Gather) { -+ add_pointer_offset(layout(the_predicates.residue_offset_)); -+ -+ if (kAdvanceRank) { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided() - 1); -+ pointer_ += Shape::kContiguous * tile_offset.contiguous(); -+ } else { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1); -+ pointer_ += Shape::kStrided * tile_offset.strided(); -+ } -+ } else { -+ gather_offset_strided = the_predicates.thread_offset_.strided(); -+ add_pointer_offset(layout(make_Coord(the_predicates.residue_offset_.contiguous(), 0))); -+ -+ if (kAdvanceRank) { -+ gather_offset_strided += Shape::kStrided * (tile_offset.strided() - 1); -+ add_pointer_offset(Shape::kContiguous * tile_offset.contiguous()); -+ } else { -+ add_pointer_offset(Shape::kContiguous * (tile_offset.contiguous() - 1)); -+ gather_offset_strided += Shape::kStrided * tile_offset.strided(); -+ } -+ } -+ } else { -+ if 
(!Gather) { -+ if (kAdvanceRank) { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); -+ pointer_ += Shape::kContiguous * tile_offset.contiguous(); -+ } else { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); -+ pointer_ += Shape::kStrided * tile_offset.strided(); -+ } -+ } else { -+ add_pointer_offset(Shape::kContiguous * tile_offset.contiguous()); -+ gather_offset_strided += Shape::kStrided * tile_offset.strided(); -+ } -+ } -+ -+ is_residue_tile_ = false; -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ if (Gather) { -+ assert(indices_); -+ -+ if (!valid()) { -+ return nullptr; -+ } -+ -+ LongIndex contiguous_offset = the_predicates.iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits::value / 8) + the_predicates.iteration_vector_; -+ int strided_index = gather_offset_strided + the_predicates.iteration_strided_ * ThreadMap::Delta::kStrided; -+ -+ LongIndex strided_offset = indices_[strided_index] * LongIndex(params_.stride_) * sizeof_bits::value / 8; -+ -+ return reinterpret_cast(pointer_ + contiguous_offset + strided_offset); -+ } -+ -+ return reinterpret_cast( -+ pointer_ + -+ the_predicates.iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits::value) / 8) + the_predicates.iteration_vector_; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator &operator++() { -+ -+ the_predicates.operator++(); -+ -+ ++the_predicates.iteration_vector_; -+ if (the_predicates.iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ -+ the_predicates.iteration_vector_ = 0; -+ ++the_predicates.iteration_contiguous_; -+ -+ if (the_predicates.iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ the_predicates.iteration_contiguous_ = 0; -+ ++the_predicates.iteration_strided_; -+ -+ if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ if (!Gather) { -+ pointer_ += params_.inc_strided_; -+ } -+ -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ the_predicates.iteration_strided_ = 0; -+ -+ if (!Gather) { -+ // advance to next tile -+ pointer_ += params_.inc_next_; -+ -+ // now return to start tile - if the iterator is subsequently advanced, this -+ // subtraction as well as the subsequent integer addition are both elided by -+ // the compiler. -+ pointer_ -= params_.inc_advance_; -+ } -+ -+ return *this; -+ } -+ -+ /// Increment and return an instance to self. 
-+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator operator++(int) { -+ PredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ the_predicates.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ the_predicates.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ the_predicates.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ the_predicates.get_mask(mask); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ return the_predicates.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for column-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 
0 : 1), ThreadMap, AccessType, Gather>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))){}; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileAccessIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.row(), -+ threadblock_offset.column()), -+ indices) {} -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. 
-+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator operator++(int) { -+ PredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for row-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 
1 : 0), ThreadMap, AccessType, Gather>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))){}; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileAccessIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ /// Gather indices -+ int const *indices = nullptr) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), -+ threadblock_offset.row()), -+ indices) {} -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. 
-+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator operator++(int) { -+ PredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for affine rank 2 data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileAccessIterator, -+ AdvanceRank, ThreadMap_, AccessType_, false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRankN<2>; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< -+ Shape, Element, layout::PitchLinear, AdvanceRank, ThreadMap, AccessType>; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingPredicates::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ public: -+ friend PredicatedTileAccessIterator; -+ -+ private: -+ /// stride of pitch-linear layout (units of Element) -+ Coord stride_; -+ /// amount (in byte) to increment pointer to move to next access along -+ /// contiguous dimension -+ LongIndex inc_contiguous_; -+ /// amount (in byte) to increment pointer from first access of current -+ /// contiguous dimension to first access of next one. -+ LongIndex inc_strided_; -+ /// amount (in byte) to increment pointer from last access of current -+ /// contiguous dimension to first access of next one. 
-+ LongIndex inc_next_strided_; -+ /// amount (in byte) to increment pointer from last access to first access -+ /// of next tile -+ LongIndex inc_next_; -+ /// amount (in byte) to increment pointer from first access of current tile -+ /// to first access of next tile -+ LongIndex inc_advance_; -+ -+ public: -+ -+ // Default ctor -+ CUTLASS_HOST_DEVICE -+ Params(): stride_(0), inc_contiguous_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) : stride_({layout.stride(0), layout.stride(1)}) { -+ inc_contiguous_ = (LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) * -+ sizeof_bits::value / 8; -+ -+ inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) * -+ sizeof_bits::value / 8; -+ -+ inc_next_strided_ = inc_strided_ - LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_; -+ -+ if (kAdvanceRank) { -+ // advance along strided dimension -+ inc_advance_ = -+ Shape::kStrided * LongIndex(stride_[1]) * sizeof_bits::value / 8; -+ } else { -+ // advance along contiguous dimension -+ inc_advance_ = Shape::kContiguous * stride_[0] * sizeof_bits::value / 8; -+ } -+ -+ inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ - LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_; -+ }; -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object with precomputed internal state -+ Params params_; -+ -+ /// Internal pointer to first access of tile -+ BytePointer pointer_; -+ -+ UnderlyingPredicates the_predicates; -+ -+ /// Used for out-of-order visitation -+ bool is_residue_tile_; -+ -+ private: -+ /// Computes predicates based on internally tracked per-thread offset. 
-+ CUTLASS_DEVICE -+ void compute_predicates_( -+ /// Extent of the matrix window -+ TensorCoord extent, -+ /// optionally, simplify predicate calculation during 'steady state' phase -+ bool is_steady_state = false) { -+ the_predicates.compute_predicates_(extent, is_steady_state); -+ } -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileAccessIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ) -+ : params_(params), -+ pointer_(reinterpret_cast( -+ const_cast(pointer))), -+ the_predicates(extent), -+ is_residue_tile_(true) { -+ -+ the_predicates.set_predicates(thread_id, threadblock_offset); -+ -+ // update internal pointers -+ Layout layout(params_.stride_); -+ add_pointer_offset(layout(the_predicates.thread_offset_)); -+ } -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { the_predicates.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += sizeof_bits::value * pointer_offset / 8; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ if (is_residue_tile_) { -+ -+ the_predicates.thread_offset_ += the_predicates.residue_offset_; -+ -+ Layout layout(params_.stride_); -+ add_pointer_offset(layout(the_predicates.residue_offset_)); -+ -+ the_predicates.compute_predicates_(the_predicates.extent_, true); -+ -+ if (kAdvanceRank) { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1] - 1); -+ pointer_ += Shape::kContiguous * tile_offset[0]; -+ } else { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0] - 1); -+ pointer_ += Shape::kStrided * tile_offset[1]; -+ } -+ } else { -+ if (kAdvanceRank) { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1]); -+ pointer_ += Shape::kContiguous * tile_offset[0]; -+ } else { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0]); -+ pointer_ += Shape::kStrided * tile_offset[1]; -+ } -+ } -+ is_residue_tile_ = false; -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(pointer_) + the_predicates.iteration_vector_; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. 
-+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator &operator++() { -+ the_predicates.operator++(); -+ ++the_predicates.iteration_vector_; -+ if (the_predicates.iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ -+ the_predicates.iteration_vector_ = 0; -+ ++the_predicates.iteration_contiguous_; -+ -+ if (the_predicates.iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ pointer_ += params_.inc_contiguous_; -+ return *this; -+ } -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ the_predicates.iteration_contiguous_ = 0; -+ ++the_predicates.iteration_strided_; -+ -+ if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ pointer_ += params_.inc_next_strided_; -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ the_predicates.iteration_strided_ = 0; -+ -+ // advance to next tile -+ pointer_ += params_.inc_next_; -+ -+ // now return to start tile - if the iterator is subsequently advanced, this -+ // subtraction as well as the subsequent integer addition are both elided by -+ // the compiler. -+ pointer_ -= params_.inc_advance_; -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator operator++(int) { -+ PredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { the_predicates.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { the_predicates.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { the_predicates.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { the_predicates.get_mask(mask); } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return the_predicates.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for affine rank 2 column-major data. 
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRank2ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ // Map to the underlying AffineRankN<2> layout -+ using UnderlyingIterator = PredicatedTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::AffineRankN<2>, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given an AffineRankN<2> tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){}; -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying AffineRankN<2> tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileAccessIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.row(), -+ threadblock_offset.column())) {} -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element 
-+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset(make_Coord(tile_offset.row(), tile_offset.column())); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator operator++(int) { -+ PredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for affine rank-2 row-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRank2RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ // Map to the underlying AffineRankN<2> layout -+ using UnderlyingIterator = PredicatedTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::AffineRankN<2>, (kAdvanceRank == 0 ? 
1 : 0), ThreadMap, AccessType>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given an AffineRankN<2> tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){}; -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying AffineRankN<2> tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileAccessIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset(make_Coord(tile_offset.column(), tile_offset.row())); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. 
-+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator operator++(int) { -+ PredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for column-major interleaved data. -+/// It is mapped to the congruous layout. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+ -+template -+class PredicatedTileAccessIterator, -+ AdvanceRank, ThreadMap_, AccessType_, false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::ColumnMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileAccessIterator< -+ layout::PitchLinearShape, -+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 
0 : 1), ThreadMap, -+ AccessType>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileAccessIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.row() * kInterleavedK, -+ extent.column() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.row() * kInterleavedK, -+ threadblock_offset.column() / kInterleavedK)) {} -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. 
-+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator operator++(int) { -+ PredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { return iterator_.valid(); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for row-major interleaved data. -+// It is mapped to the congruous layout. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileAccessIterator, -+ AdvanceRank, ThreadMap_, AccessType_, false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::RowMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileAccessIterator< -+ layout::PitchLinearShape, -+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 
1 : 0), ThreadMap, -+ AccessType>; -+ -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileAccessIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.column() * kInterleavedK, -+ extent.row() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.column() * kInterleavedK, -+ threadblock_offset.row() / kInterleavedK)) {} -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. 
-+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator operator++(int) { -+ PredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { return iterator_.valid(); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h -new file mode 100644 -index 0000000..1ce5e39 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h -@@ -0,0 +1,834 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates calculating the address and predicates to the load of tiles -+ from pitch-linear rank=2 tensors. -+ -+ This iterator uses masks to guard out-of-bounds accesses and visits the last -+ "residue" tile first, with the objective of minimizing predicate mask updates -+ during steady-state operation. -+ -+ A precomputed "Params" object minimizes the amount of state that must be -+ stored in registers, and integer addition is used to advance the pointer -+ through memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedTileAccessIterator2dThreadTile -+/// -+template -+class PredicatedTileAccessIterator2dThreadTile; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator2dThreadTile for pitch-linear data. -+/// -+template -+class PredicatedTileAccessIterator2dThreadTile { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ static int const kPredicatesPerByte = 4; -+ static int const kPredicatesPerWord = 4 * kPredicatesPerByte; -+ -+ /// Number of 32b words containing predicates -+ static int const kPredicateByteCount = (ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kStrided + kPredicatesPerByte - 1) / kPredicatesPerByte; -+ static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4; -+ -+ static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u; -+ -+ static_assert(kPredicateWordCount <= 4, "Too many predicates."); -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = Array; -+ -+ /// Uses a non-template class -+ struct Params : PredicatedTileAccessIteratorParams { -+ -+ public: -+ friend PredicatedTileAccessIterator2dThreadTile; -+ -+ using Base = PredicatedTileAccessIteratorParams; -+ -+ // Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) : -+ 
Base(layout.stride(0), -+ MakePredicatedTileAccessIteratorDesc()() -+ ) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Base const &base) : -+ Base(base) { } -+ }; -+ -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Parameters object with precomputed internal state -+ Params const ¶ms_; -+ -+ /// Internal pointer to first access of tile -+ BytePointer pointer_; -+ -+ /// Guard predicates -+ uint32_t predicates_[kPredicateWordCount]; -+ -+ /// Size of tensor -+ TensorCoord extent_; -+ -+ /// Initial offset for each thread -+ TensorCoord thread_offset_; -+ -+ /// Index of residue tile -+ int residue_tile_idx_; -+ -+ /// Used for out-of-order visitation -+ bool is_residue_tile_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ /// Tracks iterations within the thread loop -+ int iteration_thread_; -+ -+ private: -+ /// Computes predicates based on internally tracked per-thread offset. -+ CUTLASS_HOST_DEVICE -+ void compute_predicates_( -+ /// optionally, simplify predicate calculation during 'steady state' phase -+ bool is_steady_state = false) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = 0u; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++) { -+ -+ TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous, -+ ts + s * ThreadMap::Delta::kStrided); -+ -+ TensorCoord coord = thread_offset_ + iteration_coord; -+ -+ bool guard; -+ -+ if (is_steady_state) { -+ if (kAdvanceRank == 0) { -+ guard = (coord.strided() < extent_.strided()); -+ } else { -+ guard = (coord.contiguous() < extent_.contiguous()); -+ } -+ } else { -+ guard = (coord.strided() < extent_.strided() && -+ coord.contiguous() < extent_.contiguous()); -+ } -+ -+ int pred_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided; -+ int word_idx = pred_idx / kPredicatesPerWord; -+ int residual = pred_idx % kPredicatesPerWord; -+ int byte_idx = residual / kPredicatesPerByte; -+ int bit_idx = residual % kPredicatesPerByte; -+ -+ predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); -+ -+ } -+ } -+ } -+ -+ } -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : params_(params), -+ pointer_(reinterpret_cast( -+ const_cast(pointer))), -+ extent_(extent), -+ is_residue_tile_(true) { -+ -+ -+ TensorCoord residue_offset; -+ if (kAdvanceRank) { -+ residue_tile_idx_ = -+ (extent_[kAdvanceRank] - threadblock_offset[kAdvanceRank] - 1) / -+ Shape::kStrided; -+ residue_offset = make_Coord(0, residue_tile_idx_ * Shape::kStrided); -+ } else { -+ residue_tile_idx_ = -+ (extent_[kAdvanceRank] - threadblock_offset[kAdvanceRank] - 1) / -+ 
Shape::kContiguous; -+ residue_offset = make_Coord(residue_tile_idx_ * Shape::kContiguous, 0); -+ } -+ -+ // Per-thread offset in logical coordinates of tensor -+ thread_offset_ = threadblock_offset + residue_offset + -+ ThreadMap::initial_offset(thread_id); -+ -+ // update internal pointers -+ Layout layout(params_.stride_); -+ add_pointer_offset(layout(thread_offset_)); -+ -+ compute_predicates_(false); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id) -+ : PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ -+ int residual = index % (ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided); -+ iteration_strided_ = index / (ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided); -+ -+ iteration_contiguous_ = residual / ThreadMap::ThreadAccessShape::kStrided; -+ iteration_thread_ = residual % ThreadMap::ThreadAccessShape::kStrided; -+ -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += int(sizeof(Element)) * pointer_offset; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_DEVICE -+ void add_tile_offset( -+ TensorCoord const &tile_offset) { -+ if (is_residue_tile_) { -+ TensorCoord residue_offset; -+ if (kAdvanceRank) { -+ residue_offset = TensorCoord(0, residue_tile_idx_ * Shape::kStrided); -+ } else { -+ residue_offset = TensorCoord(residue_tile_idx_ * Shape::kContiguous, 0); -+ } -+ -+ thread_offset_ -= residue_offset; -+ -+ Layout layout(params_.stride_); -+ add_pointer_offset(-layout(residue_offset)); -+ -+ compute_predicates_(true); -+ -+ if (kAdvanceRank) { -+ pointer_ += params_.inc_advance_ * (tile_offset.strided() - 1); -+ pointer_ += Shape::kContiguous * tile_offset.contiguous(); -+ } else { -+ pointer_ += params_.inc_advance_ * (tile_offset.contiguous() - 1); -+ pointer_ += Shape::kStrided * tile_offset.strided(); -+ } -+ } else { -+ if (kAdvanceRank) { -+ pointer_ += params_.inc_advance_ * tile_offset.strided(); -+ pointer_ += Shape::kContiguous * tile_offset.contiguous(); -+ } else { -+ pointer_ += params_.inc_advance_ * tile_offset.contiguous(); -+ pointer_ += Shape::kStrided * tile_offset.strided(); -+ } -+ } -+ is_residue_tile_ = false; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ -+ AccessType *ret_val = reinterpret_cast( -+ pointer_ + (iteration_thread_ * params_.stride_ + iteration_contiguous_ * ThreadMap::Delta::kContiguous) * int(sizeof(Element))); -+ -+ return ret_val; -+ } -+ -+ /// Increment and return an instance to self. 
-+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile &operator++() { -+ -+ iteration_thread_++; -+ -+ if (iteration_thread_ < ThreadMap::ThreadAccessShape::kStrided) -+ return *this; -+ -+ iteration_thread_ = 0; -+ -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) -+ return *this; -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ pointer_ += params_.inc_strided_; -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ iteration_strided_ = 0; -+ -+ // advance to next tile -+ pointer_ += params_.inc_next_; -+ -+ // now return to start tile - if the iterator is subsequently advanced, this -+ // subtraction as well as the subsequent integer addition are both elided by -+ // the compiler. -+ pointer_ -= params_.inc_advance_; -+ -+ return *this; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile operator++(int) { -+ PredicatedTileAccessIterator2dThreadTile self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = enable ? 0u : predicates_[i]; -+ } -+ -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = 0xffffffff; -+ } -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = mask[i]; -+ } -+ -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ mask[i] = predicates_[i]; -+ } -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ -+ int pred_idx = -+ iteration_thread_ + -+ iteration_contiguous_ * ThreadMap::ThreadAccessShape::kStrided + -+ iteration_strided_ * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided; -+ -+ int word_idx = pred_idx / kPredicatesPerWord; -+ int residual = pred_idx % kPredicatesPerWord; -+ int byte_idx = residual / kPredicatesPerByte; -+ int bit_idx = residual % kPredicatesPerByte; -+ -+ bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; -+ -+ return pred; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator2dThreadTile for pitch-linear data. 
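A minimal standalone sketch of the predicate packing used by compute_predicates_() and valid() above, assuming a hypothetical predicate index: with kPredicatesPerByte = 4 and kPredicatesPerWord = 16 as in the source, each guard bit lives in the low four bits of one byte of a 32-bit word, and the same word/byte/bit arithmetic is used both to set and to test a guard.

#include <cassert>
#include <cstdint>

const int kPredicatesPerByte = 4;
const int kPredicatesPerWord = 4 * kPredicatesPerByte;  // 16 predicates per 32-bit word

// Bit position of predicate number pred_idx within its 32-bit word,
// mirroring the arithmetic in compute_predicates_() and valid().
int bit_position(int pred_idx) {
  int residual = pred_idx % kPredicatesPerWord;
  int byte_idx = residual / kPredicatesPerByte;  // which byte of the word
  int bit_idx  = residual % kPredicatesPerByte;  // which of the low 4 bits of that byte
  return byte_idx * 8 + bit_idx;
}

int main() {
  std::uint32_t predicates[4] = {0, 0, 0, 0};  // kPredicateWordCount <= 4 in the source

  int pred_idx = 21;   // hypothetical flattened (thread, contiguous, strided) index
  bool guard = true;   // e.g. the coordinate lies within the tensor extent

  // Set the guard bit exactly as compute_predicates_() does.
  predicates[pred_idx / kPredicatesPerWord] |=
      (std::uint32_t(guard) << bit_position(pred_idx));

  // valid() re-derives the same word and bit to test the guard.
  bool pred = (predicates[pred_idx / kPredicatesPerWord] &
               (1u << bit_position(pred_idx))) != 0;
  assert(pred);
  return 0;
}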
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileAccessIterator2dThreadTile { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileAccessIterator2dThreadTile< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIterator2dThreadTile; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))){} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.row(), -+ threadblock_offset.column())) {} -+ -+ /// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ 
iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile operator++(int) { -+ PredicatedTileAccessIterator2dThreadTile self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator2dThreadTile for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileAccessIterator2dThreadTile { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileAccessIterator2dThreadTile< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 
1 : 0), ThreadMap, AccessType>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIterator2dThreadTile; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))){} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ /// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. 
-+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile operator++(int) { -+ PredicatedTileAccessIterator2dThreadTile self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h -new file mode 100755 -index 0000000..cbabc4e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h -@@ -0,0 +1,289 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Predicated tile access iterator descriptor object containing template dependent state -+struct PredicatedTileAccessIteratorDesc { -+ -+ int element_size_bits; -+ int advance_rank; -+ layout::PitchLinearCoord threadblock_shape; -+ layout::PitchLinearCoord threadmap_iterations; -+ layout::PitchLinearCoord threadmap_delta; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorDesc() { } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorDesc( -+ int element_size_bits_, -+ int advance_rank_, -+ layout::PitchLinearCoord threadblock_shape_, -+ layout::PitchLinearCoord threadmap_iterations_, -+ layout::PitchLinearCoord threadmap_delta_ -+ ): -+ element_size_bits(element_size_bits_), -+ advance_rank(advance_rank_), -+ threadblock_shape(threadblock_shape_), -+ threadmap_iterations(threadmap_iterations_), -+ threadmap_delta(threadmap_delta_) -+ { -+ #if 0 -+ printf("PredicatedTileAccessIteratorDesc(%d, %d, {%d, %d}, {%d, %d}, {%d, %d}})\n", -+ element_size_bits, -+ advance_rank, -+ threadblock_shape.contiguous(), threadblock_shape.strided(), -+ threadmap_iterations.contiguous(), threadmap_iterations.strided(), -+ threadmap_delta.contiguous(), threadmap_delta.strided()); -+ #endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Helper template to construct an PredicatedTileAccessIteratorDesc from a template -+// dependent state -+template < -+ typename Shape, typename Element, typename Layout, -+ int AdvanceRank, typename ThreadMap> -+ struct MakePredicatedTileAccessIteratorDesc; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for pitch-linear data. -+template < -+ typename Shape, typename Element, int AdvanceRank, -+ typename ThreadMap> -+struct MakePredicatedTileAccessIteratorDesc < -+ Shape, Element, layout::PitchLinear, AdvanceRank, ThreadMap> { -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorDesc operator()() { -+ -+ return PredicatedTileAccessIteratorDesc( -+ sizeof_bits::value, -+ AdvanceRank, -+ {Shape::kContiguous, Shape::kStrided}, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} -+ ); -+} -+ -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for column-major data. -+template < -+ typename Shape, typename Element, int AdvanceRank, -+ typename ThreadMap> -+struct MakePredicatedTileAccessIteratorDesc < -+ Shape, Element, layout::ColumnMajor, AdvanceRank, ThreadMap> { -+ -+ static int const kAdvanceRank = AdvanceRank; -+ -+ using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 
0 : 1), ThreadMap>; -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorDesc operator()() { -+ -+ return UnderlyingMakeOperator()(); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for row-major data. -+template < -+ typename Shape, typename Element, int AdvanceRank, -+ typename ThreadMap> -+struct MakePredicatedTileAccessIteratorDesc < -+ Shape, Element, layout::RowMajor, AdvanceRank, ThreadMap> { -+ -+ static int const kAdvanceRank = AdvanceRank; -+ -+ using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap>; -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorDesc operator()() { -+ -+ return UnderlyingMakeOperator()(); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for column-major interleaved data. -+template < -+ typename Shape, typename Element, int AdvanceRank, -+ typename ThreadMap, int InterleavedK> -+struct MakePredicatedTileAccessIteratorDesc < -+ Shape, Element, layout::ColumnMajorInterleaved, AdvanceRank, ThreadMap> { -+ -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kInterleavedK = InterleavedK; -+ -+ using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap>; -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorDesc operator()() { -+ -+ return UnderlyingMakeOperator()(); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for roww-major interleaved data. -+template < -+ typename Shape, typename Element, int AdvanceRank, -+ typename ThreadMap, int InterleavedK> -+struct MakePredicatedTileAccessIteratorDesc < -+ Shape, Element, layout::RowMajorInterleaved, AdvanceRank, ThreadMap> { -+ -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kInterleavedK = InterleavedK; -+ -+ using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 
1 : 0), ThreadMap>; -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorDesc operator()() { -+ -+ return UnderlyingMakeOperator()(); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Parameters struct -+// -+ -+struct PredicatedTileAccessIteratorParams { -+ -+ using Index = int32_t; -+ using LongIndex = int64_t; -+ -+ // -+ // Data members -+ // -+ /// stride of pitch-linear layout (units of Element) -+ LongIndex stride_; -+ /// amount (in byte) to increment pointer to move to next access along -+ /// strided dimension -+ LongIndex inc_strided_; -+ /// amount (in byte) to increment pointer from last access to first access -+ /// of next tile -+ LongIndex inc_next_; -+ /// amount (in byte) to increment pointer from first access of current tile -+ /// to first access of next tile -+ LongIndex inc_advance_; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Status initialize(LongIndex stride, PredicatedTileAccessIteratorDesc desc) { -+ -+ stride_ = stride; -+ -+ inc_strided_ = (LongIndex(stride_) * desc.threadmap_delta.strided()) * -+ desc.element_size_bits / 8; -+ -+ if (desc.advance_rank) { -+ // advance along strided dimension -+ inc_advance_ = -+ desc.threadblock_shape.strided() * LongIndex(stride_) * desc.element_size_bits / 8; -+ } else { -+ // advance along contiguous dimension -+ inc_advance_ = desc.threadblock_shape.contiguous() * desc.element_size_bits / 8; -+ } -+ -+ inc_next_ = inc_advance_ - LongIndex(desc.threadmap_iterations.strided() - 1) * -+ desc.threadmap_delta.strided() * LongIndex(stride_) * -+ desc.element_size_bits / 8; -+ -+ return Status::kSuccess; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Status initialize(Index stride, PredicatedTileAccessIteratorDesc desc) { -+ return initialize(LongIndex(stride), desc); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorParams() { -+ initialize(LongIndex(0), PredicatedTileAccessIteratorDesc()); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorParams(Index stride, PredicatedTileAccessIteratorDesc desc) { -+ initialize(stride, desc); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorParams(LongIndex stride, PredicatedTileAccessIteratorDesc desc) { -+ initialize(stride, desc); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h -new file mode 100644 -index 0000000..d304b99 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h -@@ -0,0 +1,892 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
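A minimal standalone worked example of the increment arithmetic in PredicatedTileAccessIteratorParams::initialize above, using assumed values for the stride, element width, threadblock shape, and thread map (none of these numbers come from the patch): all three increments are expressed in bytes and follow the same formulas as the source.

#include <cstdint>
#include <cstdio>

int main() {
  using LongIndex = std::int64_t;

  // Assumed inputs, for illustration only.
  LongIndex stride             = 128;  // pitch-linear stride in elements
  int element_size_bits        = 16;   // e.g. half-precision elements
  int advance_rank             = 1;    // advance along the strided dimension
  int threadblock_contiguous   = 64;   // threadblock shape, contiguous rank
  int threadblock_strided      = 8;    // threadblock shape, strided rank
  int threadmap_iter_strided   = 4;    // ThreadMap::Iterations::kStrided
  int threadmap_delta_strided  = 2;    // ThreadMap::Delta::kStrided

  // Same formulas as initialize(); every increment is a byte offset.
  LongIndex inc_strided = stride * threadmap_delta_strided * element_size_bits / 8;

  LongIndex inc_advance = advance_rank
      ? LongIndex(threadblock_strided) * stride * element_size_bits / 8    // advance along strided rank
      : LongIndex(threadblock_contiguous) * element_size_bits / 8;         // advance along contiguous rank

  LongIndex inc_next = inc_advance -
      LongIndex(threadmap_iter_strided - 1) * threadmap_delta_strided * stride *
          element_size_bits / 8;

  // With the assumed values: inc_strided = 512, inc_advance = 2048, inc_next = 512.
  std::printf("inc_strided=%lld inc_advance=%lld inc_next=%lld\n",
              (long long)inc_strided, (long long)inc_advance, (long long)inc_next);
  return 0;
}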
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates calculating the address and predicates to the load of tiles -+ from pitch-linear rank=2 tensors. -+ -+ This iterator uses masks to guard out-of-bounds accesses and visits the last -+ "residue" tile first, with the objective of minimizing predicate mask updates -+ during steady-state operation. -+ -+ A precomputed "Params" object minimizes the amount of state that must be -+ stored in registers, and integer addition is used to advance the pointer -+ through memory. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedTileAccessIteratorTriangularMatrix -+/// -+template -+class PredicatedTileAccessIteratorTriangularMatrix; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIteratorTriangularMatrix for pitch-linear data. 
-+/// -+template -+class PredicatedTileAccessIteratorTriangularMatrix { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ using CompareOp = typename TrMatrixCompareOp::Type; -+ -+ static_assert( kFillMode == FillMode::kFull || -+ ((kFillMode == FillMode::kLower || kFillMode == FillMode::kUpper) && AccessType::kElements == 1), -+ "BLAS3 iterator for the triangular/symmetric matrix must use AccessType::kElements as 1"); -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ static int const kPredicatesPerByte = 4; -+ static int const kPredicatesPerWord = 4 * kPredicatesPerByte; -+ -+ static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector; -+ -+ /// Number of 32b words containing predicates -+ static int const kPredicateByteCount = -+ (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte; -+ static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4; -+ -+ static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u; -+ -+ static_assert(kPredicateWordCount <= 4, "Too many predicates."); -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = Array; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ public: -+ friend PredicatedTileAccessIteratorTriangularMatrix; -+ -+ private: -+ /// stride of pitch-linear layout (units of Element) -+ StrideIndex stride_; -+ /// (true) pitch-linear layout is mapped to row-major matrix -+ /// (false) pitch-linear layout is mapped to column-major matrix -+ bool is_row_major_; -+ /// for vectorized access across the diagonal boundary guard condition is -+ /// checked for the element on the boundary -+ int access_diagonal_boundary_; -+ /// amount (in byte) to increment pointer to move to next access along -+ /// strided dimension -+ LongIndex inc_strided_; -+ /// amount (in byte) to increment pointer from last access to first access -+ /// of next tile -+ LongIndex inc_next_; -+ /// amount (in byte) to increment pointer from first access of current tile -+ /// to first access of next tile -+ LongIndex inc_advance_; -+ -+ public: -+ -+ // Default ctor -+ CUTLASS_HOST_DEVICE -+ Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0), is_row_major_(false), access_diagonal_boundary_(0) { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout, bool is_row_major, int access_diagonal_boundary) : -+ stride_(layout.stride(0)), is_row_major_(is_row_major), access_diagonal_boundary_(access_diagonal_boundary) { -+ -+ inc_strided_ = 
(LongIndex(stride_) * ThreadMap::Delta::kStrided) * -+ sizeof_bits::value / 8; -+ -+ if (kAdvanceRank) { -+ // advance along strided dimension -+ inc_advance_ = -+ Shape::kStrided * LongIndex(stride_) * sizeof_bits::value / 8; -+ } else { -+ // advance along contiguous dimension -+ inc_advance_ = Shape::kContiguous * sizeof_bits::value / 8; -+ } -+ -+ inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kStrided - 1) * -+ ThreadMap::Delta::kStrided * LongIndex(stride_) * -+ sizeof_bits::value / 8; -+ -+ }; -+ -+ -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Parameters object with precomputed internal state -+ Params const ¶ms_; -+ -+ /// Internal pointer to first access of tile -+ BytePointer pointer_; -+ -+ /// Guard predicates -+ uint32_t predicates_[kPredicateWordCount]; -+ -+ /// Track global memory addresses on the diagonal -+ /// To ignore imag part for diagonal elements of hermitian matrices -+ uint32_t predicates_onDiag_[kPredicateWordCount]; -+ -+ /// Size of tensor -+ TensorCoord extent_; -+ -+ /// Initial offset for each thread -+ TensorCoord thread_offset_; -+ -+ /// Iteration along vectors implied by the thread map -+ int iteration_vector_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ private: -+ /// Computes predicates based on internally tracked per-thread offset. -+ CUTLASS_DEVICE -+ void compute_predicates_( -+ /// Extent of the matrix window -+ TensorCoord extent) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = 0u; -+ predicates_onDiag_[i] = 0u; -+ } -+ -+ CompareOp compare_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) { -+ -+ int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector); -+ -+ int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector); -+ -+ int c = access_residual / kAccessesPerVector; -+ int v = access_residual % kAccessesPerVector; -+ -+ TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements, -+ s * ThreadMap::Delta::kStrided); -+ -+ TensorCoord coord = thread_offset_ + iteration_coord; -+ -+ bool guard; -+ bool onDiag = false; -+ -+ guard = ((coord.strided() < extent.strided()) && -+ (coord.contiguous() < extent.contiguous())); -+ -+ -+ // guard access on the wrong side of the triagular matrix diagonal -+ if (kFillMode == FillMode::kLower || kFillMode == FillMode::kUpper) { -+ coord += TensorCoord{params_.access_diagonal_boundary_, 0}; -+ -+ bool triagular_guard_row_major = compare_op(coord.strided(), coord.contiguous()) | !params_.is_row_major_; -+ bool triagular_guard_col_major = compare_op(coord.contiguous(), coord.strided()) | params_.is_row_major_; -+ -+ guard = guard && triagular_guard_row_major && triagular_guard_col_major; -+ -+ if (kDiagType == DiagType::kUnit) { -+ onDiag = (guard && coord.strided() == coord.contiguous()) ? 
true : false; -+ } -+ } -+ -+ int pred_idx_onDiag = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); -+ int word_idx_onDiag = pred_idx_onDiag / kPredicatesPerWord; -+ int residual_onDiag = pred_idx_onDiag % kPredicatesPerWord; -+ int byte_idx_onDiag = residual_onDiag / kPredicatesPerByte; -+ int bit_idx_onDiag = residual_onDiag % kPredicatesPerByte; -+ -+ predicates_onDiag_[word_idx_onDiag] |= (unsigned(onDiag) << (byte_idx_onDiag * 8 + bit_idx_onDiag)); -+ -+ int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); -+ -+ int word_idx = pred_idx / kPredicatesPerWord; -+ int residual = pred_idx % kPredicatesPerWord; -+ int byte_idx = residual / kPredicatesPerByte; -+ int bit_idx = residual % kPredicatesPerByte; -+ -+ predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); -+ -+ } -+ -+ } -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : params_(params), -+ pointer_(reinterpret_cast(const_cast(pointer))), -+ extent_(extent) { -+ -+ -+ // Per-thread offset in logical coordinates of tensor -+ thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id); -+ -+ // update internal pointers -+ Layout layout(params_.stride_); -+ add_pointer_offset(layout(thread_offset_)); -+ -+ compute_predicates_(extent_); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Construct a PredicatedTileAccessIteratorTriangularMatrix with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id) -+ : PredicatedTileAccessIteratorTriangularMatrix(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += sizeof_bits::value * pointer_offset / 8; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ -+ if (kAdvanceRank) { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); -+ pointer_ += Shape::kContiguous * tile_offset.contiguous(); -+ thread_offset_ += TensorCoord{0, Shape::kStrided * tile_offset.strided()}; -+ } else { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); -+ pointer_ += Shape::kStrided * tile_offset.strided(); -+ thread_offset_ += TensorCoord{Shape::kContiguous * tile_offset.contiguous(), 0}; -+ } -+ -+ compute_predicates_(extent_); -+ } -+ -+ /// Returns a pointer 
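The predicate bookkeeping above packs one guard bit per access into 32-bit words, four predicates per byte (kPredicatesPerByte = 4, hence 16 per word), and addresses each bit as byte_idx * 8 + bit_idx. The standalone C++ sketch below mirrors that index math from compute_predicates_() and valid(); the constants are copied from the iterator, while the helper names set_guard/test_guard are mine and purely illustrative.

#include <array>
#include <cassert>
#include <cstdint>

// Packing constants as declared by the iterator above.
constexpr int kPredicatesPerByte = 4;
constexpr int kPredicatesPerWord = 4 * kPredicatesPerByte;  // 16 predicates per 32-bit word

// Hypothetical helper mirroring the bit placement in compute_predicates_().
inline void set_guard(std::array<uint32_t, 4> &words, int pred_idx, bool guard) {
  int word_idx = pred_idx / kPredicatesPerWord;
  int residual = pred_idx % kPredicatesPerWord;
  int byte_idx = residual / kPredicatesPerByte;   // which byte inside the word
  int bit_idx  = residual % kPredicatesPerByte;   // which of the low 4 bits of that byte
  words[word_idx] |= (static_cast<uint32_t>(guard) << (byte_idx * 8 + bit_idx));
}

// Hypothetical helper mirroring the test in valid() / getOnDiag().
inline bool test_guard(std::array<uint32_t, 4> const &words, int pred_idx) {
  int word_idx = pred_idx / kPredicatesPerWord;
  int residual = pred_idx % kPredicatesPerWord;
  int byte_idx = residual / kPredicatesPerByte;
  int bit_idx  = residual % kPredicatesPerByte;
  return (words[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0;
}

int main() {
  std::array<uint32_t, 4> predicates{};  // all accesses invalid initially
  set_guard(predicates, 17, true);       // mark access #17 as in-bounds
  assert(test_guard(predicates, 17));
  assert(!test_guard(predicates, 16));
  return 0;
}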
-+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast( -+ pointer_ + -+ iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits::value) / 8) + iteration_vector_; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix &operator++() { -+ -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ -+ iteration_vector_ = 0; -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ pointer_ += params_.inc_strided_; -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ iteration_strided_ = 0; -+ -+ // advance to next tile -+ pointer_ += params_.inc_next_; -+ -+ // now return to start tile - if the iterator is subsequently advanced, this -+ // subtraction as well as the subsequent integer addition are both elided by -+ // the compiler. -+ pointer_ -= params_.inc_advance_; -+ -+ return *this; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix operator++(int) { -+ PredicatedTileAccessIteratorTriangularMatrix self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = enable ? 
0u : predicates_[i]; -+ } -+ -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = 0xffffffff; -+ } -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = mask[i]; -+ } -+ -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ mask[i] = predicates_[i]; -+ } -+ } -+ -+ /// Return if the address in on the diagonal -+ CUTLASS_HOST_DEVICE -+ bool getOnDiag() { -+ int pred_idx = -+ iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous); -+ -+ int word_idx = pred_idx / kPredicatesPerWord; -+ int residual = pred_idx % kPredicatesPerWord; -+ int byte_idx = residual / kPredicatesPerByte; -+ int bit_idx = residual % kPredicatesPerByte; -+ -+ bool pred = (predicates_onDiag_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; -+ return pred; -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ -+ -+ int pred_idx = -+ iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous); -+ -+ int word_idx = pred_idx / kPredicatesPerWord; -+ int residual = pred_idx % kPredicatesPerWord; -+ int byte_idx = residual / kPredicatesPerByte; -+ int bit_idx = residual % kPredicatesPerByte; -+ -+ bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; -+ return pred; -+ -+ -+ //return true; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIteratorTriangularMatrix for column-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileAccessIteratorTriangularMatrix { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileAccessIteratorTriangularMatrix< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, -+ kSideMode, kFillMode, kDiagType, AccessType>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ static int const kAccessDiagonalBoundary = -+ (kFillMode == FillMode::kLower) ? 
(AccessType::kElements - 1) : 0; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIteratorTriangularMatrix; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0)), false, kAccessDiagonalBoundary){}; -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.row(), -+ threadblock_offset.column())) {} -+ -+ /// Construct a PredicatedTileAccessIteratorTriangularMatrix with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIteratorTriangularMatrix(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. 
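For the triangular specializations, the extra guard in compute_predicates_() drops accesses that fall on the wrong side of the diagonal (via TrMatrixCompareOp), and kAccessDiagonalBoundary shifts the tested coordinate so a vectorized access is judged at its boundary element; note the iterator itself requires AccessType::kElements == 1 for kLower/kUpper, so the shift is zero in this patch. The sketch below is my own simplification of that decision, not the library's CompareOp machinery.

#include <cassert>

enum class FillMode { kFull, kLower, kUpper };

// True when element (row, col) lies in the stored triangle:
// kLower stores row >= col, kUpper stores row <= col.
inline bool in_stored_triangle(int row, int col, FillMode fill) {
  switch (fill) {
    case FillMode::kLower: return row >= col;
    case FillMode::kUpper: return row <= col;
    default:               return true;  // kFull: everything is stored
  }
}

// Illustrative guard for a vector of kElements consecutive elements along a row.
// Shifting the tested column by (kElements - 1) plays the role of a diagonal
// boundary adjustment: the vector passes only if its far end is still stored.
inline bool vector_guard_lower(int row, int first_col, int kElements) {
  return in_stored_triangle(row, first_col + (kElements - 1), FillMode::kLower);
}

int main() {
  assert(in_stored_triangle(3, 1, FillMode::kLower));
  assert(!in_stored_triangle(1, 3, FillMode::kLower));
  assert(vector_guard_lower(/*row=*/4, /*first_col=*/3, /*kElements=*/2));   // covers cols 3..4
  assert(!vector_guard_lower(/*row=*/4, /*first_col=*/4, /*kElements=*/2));  // col 5 is above the diagonal
  return 0;
}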
-+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix operator++(int) { -+ PredicatedTileAccessIteratorTriangularMatrix self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Return if the address in on the diagonal -+ CUTLASS_HOST_DEVICE -+ bool getOnDiag() { -+ return iterator_.getOnDiag(); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIteratorTriangularMatrix for row-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileAccessIteratorTriangularMatrix { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileAccessIteratorTriangularMatrix< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, -+ kSideMode, kFillMode, kDiagType, AccessType>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ static int const kAccessDiagonalBoundary = -+ (kFillMode == FillMode::kUpper) ? 
(AccessType::kElements - 1) : 0; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIteratorTriangularMatrix; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0)), true, kAccessDiagonalBoundary){}; -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ /// Construct a PredicatedTileAccessIteratorTriangularMatrix with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIteratorTriangularMatrix(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. 
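Both Params variants in this patch reduce steady-state pointer arithmetic to three precomputed byte increments: inc_strided_ steps between strided iterations inside a tile, inc_advance_ jumps from the first access of one tile to the first access of the next, and inc_next_ is the catch-up step applied after the last strided iteration. The host-side recomputation below reproduces the strided-advance branch of that arithmetic; the stride, element width, and thread-map numbers are made-up example values, not values taken from this patch.

#include <cstdint>
#include <cstdio>

int main() {
  // Example inputs (assumed, for illustration only).
  int64_t stride             = 1024;  // leading dimension, in elements
  int     element_bits       = 16;    // e.g. a 16-bit element type
  int     delta_strided      = 8;     // ThreadMap::Delta::kStrided
  int     iterations_strided = 4;     // ThreadMap::Iterations::kStrided
  int     shape_strided      = 32;    // Shape::kStrided

  // Same arithmetic as the Params initialization above (strided-advance case),
  // all results expressed in bytes.
  int64_t inc_strided = stride * delta_strided * element_bits / 8;
  int64_t inc_advance = int64_t(shape_strided) * stride * element_bits / 8;
  int64_t inc_next    = inc_advance -
                        int64_t(iterations_strided - 1) * delta_strided * stride * element_bits / 8;

  std::printf("inc_strided=%lld inc_advance=%lld inc_next=%lld\n",
              static_cast<long long>(inc_strided),
              static_cast<long long>(inc_advance),
              static_cast<long long>(inc_next));
  return 0;
}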
-+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix operator++(int) { -+ PredicatedTileAccessIteratorTriangularMatrix self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Return if the address in on the diagonal -+ CUTLASS_HOST_DEVICE -+ bool getOnDiag() { -+ return iterator_.getOnDiag(); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h -new file mode 100644 -index 0000000..839d8f5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h -@@ -0,0 +1,1880 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. -+ -+ This iterator uses masks to guard out-of-bounds accesses. The first tile this -+ iterator visits maybe partial, then the remaining tiles are complete. So, we -+ only need to compute the predicates twice, once before the first tile and -+ once for the remaining full tiles which can share the same predicates. -+ -+ A precomputed "Params" object minimizes the amount of state that must be stored in registers, -+ and integer addition is used to advance the pointer through memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/memory.h" -+#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedTileIterator -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+/// Regular tile iterator using a precomputed control structure to minimize register liveness -+/// and integer arithmetic. -+/// -+/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed. -+/// -+/// Base pointer and tensor extents may be specified at the time the iterator is constructed. -+/// Subsequently, they are assumed to be immutable. -+/// -+/// Adding a logical coordinate offset may be performed at the time the iterator is constructed. -+/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive. -+/// -+/// Visitation order is intended to first visit a "residual" tile that may be partially full in -+/// both the advance dimension and the steady-state dimension. This is assumed to be the last -+/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to -+/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent -+/// accesses may be performed without updating internal predicates and are efficient in terms of -+/// live register state and pointer arithmetic instructions. -+/// -+/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once -+/// outside any looping structure to minimize integer arithmetic. -+/// -+/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing -+/// the iterator. -+/// -+/// -+/// Example: -+/// -+/// An efficient pipeline structure may be constructed as follows: -+/// -+// template -+// __global__ void kernel( -+// typename Iterator::Params params, -+// typename Iterator::Element *ptr, -+// TensorCoord extent) { -+// -+// typename Iterator::Fragment fragment; -+// -+// TensorCoord threadblock_offset(0, 0); -+// -+// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); -+// -+// -+// fragment = *iter; // load "residue" tile first -+// ++iter; // advance to first "steady state" tile and update internal masks -+// -+// -+// #pragma unroll -+// for (int i = Remaining - 1; i >= 0; --i) { -+// -+// f(fragment); -+// -+// if (!i) { -+// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs. 
-+// } -+// -+// fragment = *iter; // load tile during "steady state" phase -+// ++iter; // advance to next tile - lightweight due to steady-state masks -+// } -+// } -+// -+// void host(TensorView view) { -+// -+// using Iterator = transform::threadblock::PredicatedTileIterator; -+// -+// typename Iterator::Params params(view.layout()); -+// -+// kernel(params, view.data()); -+// } -+/// -+/// -+template < -+ typename Shape, -+ typename Element, -+ typename Layout, -+ int AdvanceRank, -+ typename ThreadMap, -+ int AccessSize = ThreadMap::kElementsPerAccess, -+ bool Gather = false -+> -+class PredicatedTileIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ /// Type used for internal memory accesses -+ using AccessType = AlignedArray::value / 8)>; -+ -+ /// Underlying iterator to compute the addresses -+ using TileAccessIterator = -+ PredicatedTileAccessIterator; -+ -+ static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename TileAccessIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ public: -+ using Base = typename TileAccessIterator::Params::Base; -+ -+ friend PredicatedTileIterator; -+ -+ private: -+ /// Parameters object -+ typename TileAccessIterator::Params params_; -+ -+ public: -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) : params_(layout) {} -+ -+ /// Default constructor -+ Params() = default; -+ -+ CUTLASS_HOST_DEVICE -+ Params(Base const &base) -+ : params_(base) {} -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Data member to the tile access iterator -+ TileAccessIterator address_iterator_; -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset, 
-+ /// Gather indices -+ int const *indices = nullptr) -+ : address_iterator_(params.params_, pointer, extent, thread_id, -+ threadblock_offset, indices) {} -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ address_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator &operator++() { -+ if (kAdvanceRank) -+ address_iterator_.add_tile_offset({0, 1}); -+ else -+ address_iterator_.add_tile_offset({1, 0}); -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator operator++(int) { -+ PredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { address_iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } -+ -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ address_iterator_.set_iteration_index(idx); -+ char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; -+ -+ AccessType const *access_ptr = reinterpret_cast(byte_ptr); -+ -+ cutlass::arch::global_load( -+ frag_ptr[idx], access_ptr, address_iterator_.valid()); -+ -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_byte_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index 
pointer_offset) { -+ store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ address_iterator_.set_iteration_index(0); -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; -+ AccessType *access_ptr = reinterpret_cast(byte_ptr); -+ -+ if (address_iterator_.valid()) { -+ *access_ptr = frag_ptr[idx]; -+ } -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_byte_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize, -+ bool Gather -+> -+class PredicatedTileIterator { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 
0 : 1), -+ ThreadMap, -+ AccessSize, -+ Gather -+ >; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ -+ friend PredicatedTileIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) -+ {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ }; -+ -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Default constructor -+ PredicatedTileIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const &threadblock_offset, ///< Initial offset of threadblock -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ): -+ iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()), -+ indices) -+ { } -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. 
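Each matrix-layout specialization above is a thin adapter over the pitch-linear iterator: column-major maps (row, column) straight onto (contiguous, strided), row-major swaps the two, and the advance rank handed to the underlying iterator is flipped accordingly. A small sketch of that mapping, with helper names of my own choosing:

#include <cassert>
#include <utility>

// (contiguous, strided) coordinate pair of a pitch-linear view.
using PitchLinearCoord = std::pair<int, int>;

// Column-major: rows are contiguous in memory, columns are strided.
inline PitchLinearCoord from_column_major(int row, int column) {
  return {row, column};
}

// Row-major: columns are contiguous in memory, rows are strided.
inline PitchLinearCoord from_row_major(int row, int column) {
  return {column, row};
}

// Advance rank seen by the underlying pitch-linear iterator when the wrapper
// advances along matrix rank `advance_rank` (0 = rows, 1 = columns).
inline int pitch_linear_advance_rank(bool row_major, int advance_rank) {
  // Row-major flips the rank; column-major keeps it.
  return row_major ? (advance_rank == 0 ? 1 : 0) : advance_rank;
}

int main() {
  assert(from_column_major(2, 5) == PitchLinearCoord(2, 5));
  assert(from_row_major(2, 5) == PitchLinearCoord(5, 2));
  assert(pitch_linear_advance_rank(/*row_major=*/false, /*advance_rank=*/0) == 0);
  assert(pitch_linear_advance_rank(/*row_major=*/true,  /*advance_rank=*/0) == 1);
  return 0;
}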
-+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator operator++(int) { -+ PredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize, -+ bool Gather -+> -+class PredicatedTileIterator { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 
1 : 0), -+ ThreadMap, -+ AccessSize, -+ Gather -+ >; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ -+ friend PredicatedTileIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Default constructor -+ PredicatedTileIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const &threadblock_offset, ///< Initial offset of threadblock -+ int const *indices = nullptr ///< Gather indices -+ ): -+ iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()), -+ indices -+ ) { } -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. 
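The load and store loops in these iterators always walk accesses in the order strided, then contiguous, then vector, flattening the three indices into one predicate/fragment slot as idx = v + kAccessesPerVector * (c + s * Iterations::kContiguous). Below is a stripped-down host version of that traversal; the printf stands in for the predicated global load or store, and the loop bounds are example values rather than numbers from this patch.

#include <cstdio>

int main() {
  // Example thread-map sizes (assumed for illustration).
  constexpr int kIterationsStrided    = 2;
  constexpr int kIterationsContiguous = 2;
  constexpr int kAccessesPerVector    = 2;

  // Same linearization as the load_with_byte_offset / store_with_byte_offset loops.
  for (int s = 0; s < kIterationsStrided; ++s) {
    for (int c = 0; c < kIterationsContiguous; ++c) {
      for (int v = 0; v < kAccessesPerVector; ++v) {
        int idx = v + kAccessesPerVector * (c + s * kIterationsContiguous);
        // In the real iterator: set_iteration_index(idx), a guarded access into
        // fragment slot idx, then ++address_iterator_.
        std::printf("s=%d c=%d v=%d -> fragment slot %d\n", s, c, v, idx);
      }
    }
  }
  return 0;
}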
-+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator operator++(int) { -+ PredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for affine rank-2 data. 
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileIterator, AdvanceRank, -+ ThreadMap_, AccessSize, false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRankN<2>; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ /// Type used for internal memory accesses -+ using AccessType = AlignedArray::value / 8)>; -+ -+ /// Underlying iterator to compute the addresses -+ using TileAccessIterator = -+ PredicatedTileAccessIterator; -+ -+ static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename TileAccessIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ public: -+ -+ friend PredicatedTileIterator; -+ -+ private: -+ /// Parameters object -+ typename TileAccessIterator::Params params_; -+ -+ public: -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) : params_(layout) {} -+ -+ /// Default constructor -+ Params() = default; -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Data member to the tile access iterator -+ TileAccessIterator address_iterator_; -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ) -+ : address_iterator_(params.params_, pointer, extent, thread_id, -+ threadblock_offset) {} -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ address_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile 
in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator &operator++() { -+ if (kAdvanceRank) -+ address_iterator_.add_tile_offset(make_Coord(0, 1)); -+ else -+ address_iterator_.add_tile_offset(make_Coord(1, 0)); -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator operator++(int) { -+ PredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { address_iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } -+ -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ address_iterator_.set_iteration_index(idx); -+ char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; -+ -+ AccessType const *access_ptr = reinterpret_cast(byte_ptr); -+ -+ cutlass::arch::global_load( -+ frag_ptr[idx], access_ptr, address_iterator_.valid()); -+ -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_byte_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ address_iterator_.set_iteration_index(0); -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; -+ AccessType *access_ptr = 
reinterpret_cast(byte_ptr); -+ -+ if (address_iterator_.valid()) { -+ *access_ptr = frag_ptr[idx]; -+ } -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_byte_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for affine rank 2 column-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize -+> -+class PredicatedTileIterator { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRank2ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ // Map to the underlying AffineRankN<2> layout -+ using UnderlyingIterator = PredicatedTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::AffineRankN<2>, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap, -+ AccessSize -+ >; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ -+ friend PredicatedTileIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given an AffineRankN<2> tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) -+ {} -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying AffineRankN<2> tile iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Default constructor -+ PredicatedTileIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const &threadblock_offset, ///< Initial offset of threadblock -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ): -+ iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()) -+ ) { } -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ 
CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator operator++(int) { -+ PredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for affine rank 2 row-major data. 
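// --- Editorial sketch (not part of the patch above/below) ---------------------
// The two affine rank-2 specializations in this hunk differ only in how they
// feed the underlying AffineRankN<2> iterator: the column-major Params passes
// (stride(0), stride(1)) and forwards coordinates as (row, column), while the
// row-major one swaps both. The minimal stand-alone sketch below mirrors that
// wiring with plain structs to show why a single affine iterator can serve
// either storage order; the names MatrixCoord2/affine_offset are illustrative
// only and are not CUTLASS APIs, and the stride values are arbitrary.

#include <cstdint>
#include <cstdio>

struct MatrixCoord2 { int row, column; };

// Offset under a rank-2 affine layout: dot product of coordinate and strides.
inline std::int64_t affine_offset(int c0, int c1, std::int64_t s0, std::int64_t s1) {
  return std::int64_t(c0) * s0 + std::int64_t(c1) * s1;
}

int main() {
  MatrixCoord2 coord{3, 5};
  std::int64_t stride0 = 1, stride1 = 128;  // purely illustrative strides

  // "Column-major" wiring: pass (stride0, stride1) and (row, column).
  std::int64_t off_cm = affine_offset(coord.row, coord.column, stride0, stride1);

  // "Row-major" wiring: swap both the strides and the coordinates. The dot
  // product is unchanged, which is exactly why both specializations can
  // delegate to the same AffineRankN<2> machinery.
  std::int64_t off_rm = affine_offset(coord.column, coord.row, stride1, stride0);

  std::printf("offsets agree: %lld == %lld\n",
              static_cast<long long>(off_cm), static_cast<long long>(off_rm));
  return 0;
}
// --- End of editorial sketch ---------------------------------------------------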
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize -+> -+class PredicatedTileIterator { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRank2RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ // Map to the underlying AffineRankN<2> layout -+ using UnderlyingIterator = PredicatedTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::AffineRankN<2>, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap, -+ AccessSize -+ >; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ -+ friend PredicatedTileIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given an AffineRankN<2> tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {} -+ }; -+ -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying AffineRankN<2> tile iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Default constructor -+ PredicatedTileIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const &threadblock_offset, ///< Initial offset of threadblock -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ): -+ iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()) -+ ) { } -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex 
pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator operator++(int) { -+ PredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for interleaved data. It is mapped -+/// to the congruous layout. 
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+ -+template -+class PredicatedTileIterator, -+ AdvanceRank, ThreadMap_, AccessSize, false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::ColumnMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileIterator< -+ layout::PitchLinearShape, -+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessSize>; -+ -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.row() * kInterleavedK, -+ extent.column() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.row() * kInterleavedK, -+ threadblock_offset.column() / kInterleavedK)) {} -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds 
a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator operator++(int) { -+ PredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for interleaved-32 data. It is -+/// mapped to the congruous layout. 
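// --- Editorial sketch (not part of the patch) ---------------------------------
// Both interleaved specializations in this hunk reduce to the pitch-linear
// iterator by rescaling coordinates: ColumnMajorInterleaved<K> maps a matrix
// extent/offset (row, column) to pitch-linear (row * K, column / K), and the
// RowMajorInterleaved<K> specialization that follows does the same with row and
// column exchanged. The helper below just reproduces that arithmetic; the names
// interleaved_to_pitch_linear/PLCoord are illustrative, not CUTLASS APIs.

#include <cassert>
#include <cstdio>

struct PLCoord { int contiguous, strided; };

inline PLCoord interleaved_to_pitch_linear(int row, int column, int interleave_k,
                                           bool column_major_interleaved) {
  assert(interleave_k > 0);
  if (column_major_interleaved) {
    return {row * interleave_k, column / interleave_k};
  }
  // RowMajorInterleaved<K>: exchange the roles of row and column.
  return {column * interleave_k, row / interleave_k};
}

int main() {
  // A 128x64 matrix viewed as ColumnMajorInterleaved<32>: the contiguous
  // dimension grows by the interleave factor, the strided dimension shrinks.
  PLCoord extent = interleaved_to_pitch_linear(128, 64, 32, /*column_major=*/true);
  std::printf("pitch-linear extent: %d x %d\n", extent.contiguous, extent.strided);  // 4096 x 2
  return 0;
}
// --- End of editorial sketch ---------------------------------------------------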
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileIterator, -+ AdvanceRank, ThreadMap_, AccessSize, false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::RowMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileIterator< -+ layout::PitchLinearShape, -+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessSize>; -+ -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.column() * kInterleavedK, -+ extent.row() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.column() * kInterleavedK, -+ threadblock_offset.row() / kInterleavedK)) {} -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer 
offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator operator++(int) { -+ PredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h -new file mode 100644 -index 0000000..0a685fc ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h -@@ -0,0 +1,787 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. -+ -+ This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile -+ first, with the objective of minimizing predicate mask updates during steady-state operation. -+ -+ A precomputed "Params" object minimizes the amount of state that must be stored in registers, -+ and integer addition is used to advance the pointer through memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h" -+#include "cutlass/transform/thread/transpose.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedTileIterator2dThreadTile -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+/// Regular tile iterator using a precomputed control structure to minimize register liveness -+/// and integer arithmetic. -+/// -+/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed. -+/// -+/// Base pointer and tensor extents may be specified at the time the iterator is constructed. -+/// Subsequently, they are assumed to be immutable. -+/// -+/// Adding a logical coordinate offset may be performed at the time the iterator is constructed. -+/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive. -+/// -+/// Vistitation order is intended to first visit a "residual" tile that may be partially full in -+/// both the advance dimension and the steady-state dimension. This is assumed to be the last -+/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to -+/// the first tile that is full in the advance dimension and recomputes predicates. 
Subsequent -+/// accesses may be performed without updating internal predicates and are efficient in terms of -+/// live register state and pointer arithmetic instructions. -+/// -+/// To be efficient, this assumes the iteraor will be dereferenced and advanced at least once -+/// outside any looping structure to minimize integer arithmetic. -+/// -+/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing -+/// the iterator. -+/// -+/// -+/// Example: -+/// -+/// An efficient pipeline structure may be constructed as follows: -+/// -+// template -+// __global__ void kernel( -+// typename Iterator::Params params, -+// typename Iterator::Element *ptr, -+// TensorCoord extent) { -+// -+// typename Iterator::Fragment fragment; -+// -+// TensorCoord threadblock_offset(0, 0); -+// -+// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); -+// -+// -+// fragment = *iter; // load "residue" tile first -+// ++iter; // advance to first "steady state" tile and update internal masks -+// -+// -+// #pragma unroll -+// for (int i = Remaining - 1; i >= 0; --i) { -+// -+// f(fragment); -+// -+// if (!i) { -+// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs. -+// } -+// -+// fragment = *iter; // load tile during "steady state" phase -+// ++iter; // advance to next tile - lightweight due to steady-state masks -+// } -+// } -+// -+// void host(TensorView view) { -+// -+// using Iterator = transform::threadblock::PredicatedTileIterator2dThreadTile; -+// -+// typename Iterator::Params params(view.layout()); -+// -+// kernel(params, view.data()); -+// } -+/// -+/// -+template < -+ typename Shape, -+ typename Element, -+ typename Layout, -+ int AdvanceRank, -+ typename ThreadMap, -+ bool Transpose = false -+> -+class PredicatedTileIterator2dThreadTile; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator2dThreadTile for pitch-linear data. 
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileIterator2dThreadTile { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ /// Type used for internal memory accesses -+ /// extra set of parenthesis is needed for VS compiler -+ struct alignas((ThreadMap::kElementsPerAccess * sizeof_bits::value / -+ 8)) AccessType { -+ -+ Array storage; -+ -+ static int const kElements = ThreadMap::kElementsPerAccess; -+ }; -+ -+ /// Optinally this fragment can be 4x4 transposed -+ using Transform = thread::Transpose< ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount , layout::PitchLinearShape<4,4>, Element>; -+ static bool const transpose = Transpose_; -+ -+ /// Underlying iterator to compute the addresses -+ using TileAccessIterator = -+ PredicatedTileAccessIterator2dThreadTile; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename TileAccessIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ public: -+ using Base = typename TileAccessIterator::Params::Base; -+ -+ friend PredicatedTileIterator2dThreadTile; -+ -+ private: -+ /// Parameters object -+ typename TileAccessIterator::Params params_; -+ -+ public: -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) : params_(layout) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Base const &base) -+ : params_(base) {} -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Data member to the tile access iterator -+ TileAccessIterator address_iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ) -+ : address_iterator_(params.params_, pointer, extent, thread_id, -+ threadblock_offset) {} -+ -+ /// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< 
Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ address_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile &operator++() { -+ if (kAdvanceRank) -+ address_iterator_.add_tile_offset({0, 1}); -+ else -+ address_iterator_.add_tile_offset({1, 0}); -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile operator++(int) { -+ PredicatedTileIterator2dThreadTile self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { address_iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++){ -+ -+ int access_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + \ -+ s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided; -+ -+ address_iterator_.set_iteration_index(access_idx); -+ if (address_iterator_.valid()) { -+ -+ frag_ptr[access_idx] = -+ *(address_iterator_.get() + pointer_offset); -+ } -+ -+ ++address_iterator_; -+ } -+ } -+ } -+ -+ if (transpose) { -+ Transform t; -+ t.transform(frag, frag); -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++){ -+ -+ int access_idx = ts + c * 
ThreadMap::ThreadAccessShape::kStrided + \ -+ s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided; -+ -+ address_iterator_.set_iteration_index(access_idx); -+ if (address_iterator_.valid()) { -+ *(address_iterator_.get() + pointer_offset) = frag_ptr[access_idx]; -+ } -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator2dThreadTile for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ bool Transpose_ -+> -+class PredicatedTileIterator2dThreadTile { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ static bool const Transpose = Transpose_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileIterator2dThreadTile< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 
0 : 1), -+ ThreadMap, -+ Transpose -+ >; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ -+ friend PredicatedTileIterator2dThreadTile; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ }; -+ -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const &threadblock_offset, ///< Initial offset of threadblock -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ): -+ iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()) -+ ) { } -+ -+ /// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ): PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, make_Coord(0, 0)) { } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. 
-+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile operator++(int) { -+ PredicatedTileIterator2dThreadTile self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator2dThreadTile for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ bool Transpose_ -+> -+class PredicatedTileIterator2dThreadTile { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ static bool const Transpose = Transpose_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileIterator2dThreadTile< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 
1 : 0), -+ ThreadMap, -+ Transpose -+ >; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ -+ friend PredicatedTileIterator2dThreadTile; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ }; -+ -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const &threadblock_offset, ///< Initial offset of threadblock -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ): -+ iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()) -+ ) { } -+ -+ /// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ): PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, make_Coord(0, 0)) { } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. 
-+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile operator++(int) { -+ PredicatedTileIterator2dThreadTile self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h -new file mode 100644 -index 0000000..b849ee7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h -@@ -0,0 +1,818 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. -+ -+ This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile -+ first, with the objective of minimizing predicate mask updates during steady-state operation. -+ -+ A precomputed "Params" object minimizes the amount of state that must be stored in registers, -+ and integer addition is used to advance the pointer through memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/memory.h" -+#include "cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedTileIteratorTriangularMatrix -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+/// Regular tile iterator using a precomputed control structure to minimize register liveness -+/// and integer arithmetic. -+/// -+/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed. -+/// -+/// Base pointer and tensor extents may be specified at the time the iterator is constructed. -+/// Subsequently, they are assumed to be immutable. -+/// -+/// Adding a logical coordinate offset may be performed at the time the iterator is constructed. -+/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive. -+/// -+/// Vistitation order is intended to first visit a "residual" tile that may be partially full in -+/// both the advance dimension and the steady-state dimension. This is assumed to be the last -+/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to -+/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent -+/// accesses may be performed without updating internal predicates and are efficient in terms of -+/// live register state and pointer arithmetic instructions. -+/// -+/// To be efficient, this assumes the iteraor will be dereferenced and advanced at least once -+/// outside any looping structure to minimize integer arithmetic. -+/// -+/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing -+/// the iterator. 
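// --- Editorial sketch (not part of the patch) ---------------------------------
// Besides the usual extent bounds, the triangular-matrix iterator's mask has to
// encode which half of the matrix is stored (kFillMode) and whether the
// diagonal is materialized (kDiagType). The predicate below is only an
// illustration of that per-element test; the real mask layout lives in
// PredicatedTileAccessIteratorTriangularMatrix (not shown in this hunk), and
// the bools lower_fill/unit_diagonal merely stand in for the CUTLASS
// enumerators.

#include <cstdio>

inline bool triangular_guard(int row, int column, int rows, int columns,
                             bool lower_fill, bool unit_diagonal) {
  bool in_bounds = (row >= 0 && row < rows && column >= 0 && column < columns);
  // With a unit diagonal the diagonal elements are implicit and never loaded.
  bool on_diagonal = (row == column);
  bool in_triangle = lower_fill ? (row >= column) : (row <= column);
  return in_bounds && in_triangle && !(unit_diagonal && on_diagonal);
}

int main() {
  // Lower triangle, non-unit diagonal: (2,1) and (2,2) are valid, (1,2) is not.
  std::printf("%d %d %d\n",
              triangular_guard(2, 1, 4, 4, true, false),
              triangular_guard(2, 2, 4, 4, true, false),
              triangular_guard(1, 2, 4, 4, true, false));
  return 0;
}
// --- End of editorial sketch ---------------------------------------------------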
-+/// -+/// -+/// Example: -+/// -+/// An efficient pipeline structure may be constructed as follows: -+/// -+// template -+// __global__ void kernel( -+// typename Iterator::Params params, -+// typename Iterator::Element *ptr, -+// TensorCoord extent) { -+// -+// typename Iterator::Fragment fragment; -+// -+// TensorCoord threadblock_offset(0, 0); -+// -+// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); -+// -+// -+// fragment = *iter; // load "residue" tile first -+// ++iter; // advance to first "steady state" tile and update internal masks -+// -+// -+// #pragma unroll -+// for (int i = Remaining - 1; i >= 0; --i) { -+// -+// f(fragment); -+// -+// if (!i) { -+// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs. -+// } -+// -+// fragment = *iter; // load tile during "steady state" phase -+// ++iter; // advance to next tile - lightweight due to steady-state masks -+// } -+// } -+// -+// void host(TensorView view) { -+// -+// using Iterator = transform::threadblock::PredicatedTileIteratorTriangularMatrix; -+// -+// typename Iterator::Params params(view.layout()); -+// -+// kernel(params, view.data()); -+// } -+/// -+/// -+template < -+ typename Shape, -+ typename Element, -+ typename Layout, -+ int AdvanceRank, -+ typename ThreadMap, -+ SideMode kSideMode, -+ FillMode kFillMode, -+ DiagType kDiagType, -+ int AccessSize = ThreadMap::kElementsPerAccess -+> -+class PredicatedTileIteratorTriangularMatrix; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIteratorTriangularMatrix for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileIteratorTriangularMatrix { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ /// Type used for internal memory accesses -+ using AccessType = AlignedArray::value / 8)>; -+ -+ /// Underlying iterator to compute the addresses -+ using TileAccessIterator = -+ PredicatedTileAccessIteratorTriangularMatrix; -+ -+ static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename TileAccessIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ public: -+ friend PredicatedTileIteratorTriangularMatrix; -+ -+ private: -+ /// Parameters object -+ typename TileAccessIterator::Params params_; -+ -+ public: -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) : params_(layout) { } -+ -+ CUTLASS_HOST_DEVICE -+ 
Params() { } -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Data member to the tile access iterator -+ TileAccessIterator address_iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : address_iterator_(params.params_, pointer, extent, thread_id, -+ threadblock_offset) {} -+ -+ /// Construct a PredicatedTileIteratorTriangularMatrix with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIteratorTriangularMatrix(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ address_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix &operator++() { -+ if (kAdvanceRank) -+ address_iterator_.add_tile_offset({0, 1}); -+ else -+ address_iterator_.add_tile_offset({1, 0}); -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. 
-+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix operator++(int) { -+ PredicatedTileIteratorTriangularMatrix self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { address_iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } -+ -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ address_iterator_.set_iteration_index(idx); -+ char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; -+ -+ AccessType const *access_ptr = reinterpret_cast(byte_ptr); -+ -+ cutlass::arch::global_load( -+ frag_ptr[idx], access_ptr, address_iterator_.valid()); -+ -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_byte_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ address_iterator_.set_iteration_index(0); -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; -+ AccessType *access_ptr = reinterpret_cast(byte_ptr); -+ -+ if (address_iterator_.valid()) { -+ *access_ptr = frag_ptr[idx]; -+ } -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_byte_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIteratorTriangularMatrix for column-major data. 
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ SideMode kSideMode, -+ FillMode kFillMode, -+ DiagType kDiagType, -+ int AccessSize -+> -+class PredicatedTileIteratorTriangularMatrix { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileIteratorTriangularMatrix< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap, -+ kSideMode, -+ kFillMode, -+ kDiagType, -+ AccessSize -+ >; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ -+ friend PredicatedTileIteratorTriangularMatrix; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { -+ -+ } -+ }; -+ -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const &threadblock_offset ///< Initial offset of threadblock -+ ): -+ iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()) -+ ) { } -+ -+ /// Construct a PredicatedTileIteratorTriangularMatrix with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ): PredicatedTileIteratorTriangularMatrix(params, pointer, extent, thread_id, make_Coord(0, 0)) { } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ 
iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix operator++(int) { -+ PredicatedTileIteratorTriangularMatrix self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIteratorTriangularMatrix for row-major data. 
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ SideMode kSideMode, -+ FillMode kFillMode, -+ DiagType kDiagType, -+ int AccessSize -+> -+class PredicatedTileIteratorTriangularMatrix { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileIteratorTriangularMatrix< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap, -+ kSideMode, -+ kFillMode, -+ kDiagType, -+ AccessSize -+ >; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ -+ friend PredicatedTileIteratorTriangularMatrix; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { -+ -+ }; -+ }; -+ -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const &threadblock_offset ///< Initial offset of threadblock -+ ): -+ iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()) -+ ) { } -+ -+ /// Construct a PredicatedTileIteratorTriangularMatrix with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ): PredicatedTileIteratorTriangularMatrix(params, pointer, extent, thread_id, make_Coord(0, 0)) { } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ 
iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix operator++(int) { -+ PredicatedTileIteratorTriangularMatrix self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h -new file mode 100644 -index 0000000..4762175 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h -@@ -0,0 +1,417 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Templates implementing computing the addresses of loading small -+ vectors from the global memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedVectorAccessIterator -+/// -+template < -+ /// Shape of the vector accessed by the entire threadblock -+ typename Shape, -+ /// Shape of the vector accessed by the warp -+ typename WarpShape, -+ /// Type of Element -+ typename Element, -+ /// Layout of the vector -+ typename Layout, -+ /// Number of elements for each access -+ int ElementsPerAccess, -+ /// Support residual tile -+ bool EnableResidualAccess = false -+> -+class PredicatedVectorAccessIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Vector access iterator specialized for vectors, e.g. 
scale and bias -+/// Thread arrangements are for TensorOps -+/// -+template < -+ typename Shape_, -+ typename WarpShape_, -+ typename Element_, -+ int ElementsPerAccess, -+ bool EnableResidualAccess -+> -+class PredicatedVectorAccessIterator < -+ Shape_, -+ WarpShape_, -+ Element_, -+ layout::PitchLinear, -+ ElementsPerAccess, -+ EnableResidualAccess -+> { -+ public: -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ConstPointer = const Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+// static int const kElementsPerAccess = 128 / sizeof_bits::value; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kThreads = 32; -+ static int const kRowsPerIteration = 8; -+ static int const kThreadsPerRow = kThreads / kRowsPerIteration; -+ static int const kThreadsPerRowMask = 0x3; -+ static int const kIterations = WarpShape::kContiguous / (kThreadsPerRow * kElementsPerAccess); -+ static int const kWarpCountStrided = Shape::kStrided / WarpShape::kStrided; -+ -+ using AccessType = AlignedArray; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to first access of tile -+ BytePointer pointer_; -+ -+ /// Extent of tensor -+ TensorCoord extent_; -+ -+ /// pointer offset of each thread -+ TensorCoord thread_offset_; -+ -+ /// iteration index -+ LongIndex iteration_; -+ -+ /// residual access -+ bool is_residual_; -+ -+ /// residual offset of each thread -+ TensorCoord residual_offset_; -+ -+ public: -+ /// Constructs a vector access iterator -+ CUTLASS_HOST_DEVICE -+ PredicatedVectorAccessIterator( -+ /// Pointer to the start of the vector -+ ConstPointer pointer, -+ /// Extent of vector -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// ID of each participating warp -+ int warp_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : pointer_(reinterpret_cast( -+ const_cast(pointer))), -+ extent_(extent), -+ is_residual_(false) { -+ -+ -+ int warp_offset = (warp_id / kWarpCountStrided) * WarpShape::kContiguous; -+ -+ // Per-thread offset in logical coordinates of tensor -+ -+ thread_offset_ = threadblock_offset + TensorCoord(warp_offset, 0) + -+ TensorCoord((thread_id & kThreadsPerRowMask) * kElementsPerAccess, 0); -+ -+ set_iteration_index(0); -+ -+ if(EnableResidualAccess) { -+ // compute residual offset -+ typename TensorCoord::Index residual_size = extent_.contiguous() % WarpShape::kContiguous; -+ if (residual_size) { -+ is_residual_ = true; -+ residual_offset_ = make_Coord(residual_size, 0); -+ } -+ } -+ } -+ -+ /// Construct a PredicatedVectorAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedVectorAccessIterator( -+ /// Pointer to start of vector -+ ConstPointer pointer, -+ /// Extent of vector -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ /// ID of each participating warp -+ int warp_id) -+ : PredicatedVectorAccessIterator(pointer, extent, thread_id, warp_id, -+ make_Coord(0, 0)) {} -+ -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void 
set_iteration_index(int index) { -+ iteration_ = index; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_DEVICE -+ void add_tile_offset( -+ TensorCoord const &tile_offset) { -+ -+ thread_offset_ = -+ thread_offset_ + -+ TensorCoord(WarpShape::kContiguous * tile_offset.contiguous(), 0); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ -+ return reinterpret_cast( -+ pointer_ + -+ ((thread_offset_.contiguous() + iteration_ * kThreadsPerRow * kElementsPerAccess) -+ * sizeof_bits::value / 8)); -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedVectorAccessIterator &operator++() { -+ ++iteration_; -+ if(iteration_ >= kIterations) -+ iteration_ = 0; -+ -+ return *this; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ if(EnableResidualAccess && is_residual_) { -+ is_residual_ = false; -+ thread_offset_ += residual_offset_; -+ } -+ else -+ add_tile_offset(TensorCoord(1, 0)); -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedVectorAccessIterator operator++(int) { -+ PredicatedVectorAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return ((thread_offset_.contiguous() + -+ iteration_ * kThreadsPerRow * kElementsPerAccess) < extent_.contiguous()); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedVectorAccessIterator for row-major data. -+/// -+template < -+ typename Shape_, -+ typename WarpShape_, -+ typename Element_, -+ int ElementsPerAccess, -+ bool EnableResidualAccess -+> -+class PredicatedVectorAccessIterator< -+ Shape_, -+ WarpShape_, -+ Element_, -+ layout::RowMajor, -+ ElementsPerAccess, -+ EnableResidualAccess -+> { -+ public: -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ConstPointer = const Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedVectorAccessIterator< -+ layout::PitchLinearShape, -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ ElementsPerAccess, -+ EnableResidualAccess>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; -+ static int const kRowsPerIteration = UnderlyingIterator::kRowsPerIteration; -+ static int const kThreads = UnderlyingIterator::kThreads; -+ static int const kIterations = UnderlyingIterator::kIterations; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedVectorAccessIterator( -+ ///< Pointer to the start of the vector -+ ConstPointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< ID of each participating warp -+ int warp_id, -+ ///< Initial offset of 
threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(pointer, layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, warp_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ /// Construct a PredicatedVectorAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedVectorAccessIterator( -+ ConstPointer pointer, ///< Pointer to the start of the vector -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ int warp_id ///< ID of each participating warp -+ ) -+ : PredicatedVectorAccessIterator(pointer, extent, thread_id, warp_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedVectorAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedVectorAccessIterator operator++(int) { -+ PredicatedVectorAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ iterator_.advance(); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h -new file mode 100644 -index 0000000..1de3e65 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h -@@ -0,0 +1,253 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Templates implementing computing the addresses of storing of small -+ scale and bias vectors in the shared memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// RegularScaleBiasVectorAccessIterator -+/// -+template -+class RegularScaleBiasVectorAccessIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for congruous arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularScaleBiasVectorAccessIterator { -+ public: -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ /// Element type per access -+ static int const kElementsPerAccess = 128 / sizeof_bits::value; -+ static int const kThreads = Shape::kContiguous / kElementsPerAccess; -+ using AccessType = Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Internal pointer -+ AccessType *pointer_; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularScaleBiasVectorAccessIterator( -+ TensorRef scale_bias_ref, ///< Pointer to the start of the scale and bias -+ ///< vector -+ int thread_id ///< ID of each participating thread -+ ) -+ : byte_offset_(0) { -+ // Per-thread offset in logical coordinates of tensor -+ int thread_offset = thread_id * kElementsPerAccess; 
-+ -+ // initialize pointer -+ pointer_ = -+ reinterpret_cast(scale_bias_ref.data() + thread_offset); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_offset_ += pointer_offset * sizeof(Element); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_DEVICE -+ AccessType *get() const { -+ -+ char *access_byte_ptr = -+ reinterpret_cast(pointer_); -+ -+ return reinterpret_cast(access_byte_ptr + byte_offset_); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularScaleBiasVectorAccessIterator &operator++() { return *this; } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularScaleBiasVectorAccessIterator operator++(int) { -+ RegularScaleBiasVectorAccessIterator prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset in the unit of tile. -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ // Multiply by 2 because we store scale and bias belong to the same stage -+ // next to each other. -+ add_pointer_offset(coord.contiguous() * Shape::kContiguous * 2); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for row major layouts -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularScaleBiasVectorAccessIterator< -+ Shape_, Element_, -+ layout::RowMajor> { -+ public: -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularScaleBiasVectorAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularScaleBiasVectorAccessIterator( -+ TensorRef scale_bias_ref, ///< Pointer to the start of the scale and bias -+ ///< vector -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({scale_bias_ref.data(), scale_bias_ref.stride()}, thread_id) { -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularScaleBiasVectorAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. 
-+ CUTLASS_HOST_DEVICE -+ RegularScaleBiasVectorAccessIterator operator++(int) { -+ RegularScaleBiasVectorAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h -new file mode 100644 -index 0000000..a3e30c2 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h -@@ -0,0 +1,58 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing the address computation of storing of tiles -+ from pitch-linear rank=2 tensors. 
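For reference, every iterator declared in these headers ultimately reduces tile addressing to pitch-linear arithmetic: element (c, s) of a rank-2 pitch-linear tensor with pitch `stride` sits at linear offset c + s * stride, and the column-major and row-major specializations only map matrix coordinates onto (contiguous, strided), row first for column-major and column first for row-major, before delegating to the pitch-linear case. A small stand-alone sketch with illustrative numbers only:

#include <cstdio>
#include <cstdint>

// Pitch-linear rank-2 addressing: element (c, s) of a tensor whose contiguous
// dimension has pitch `stride` lives at linear offset c + s * stride.
// Column-major maps (row, col) -> (c = row, s = col); row-major maps
// (row, col) -> (c = col, s = row) before delegating to this computation.
inline std::int64_t pitch_linear_offset(std::int64_t contiguous, std::int64_t strided,
                                        std::int64_t stride) {
  return contiguous + strided * stride;
}

int main() {
  const std::int64_t stride = 128;  // pitch of the contiguous dimension, in elements
  // Element (row = 3, col = 5) of a row-major matrix with leading dimension 128:
  std::printf("row-major    offset = %lld\n",
              static_cast<long long>(pitch_linear_offset(/*c=*/5, /*s=*/3, stride)));
  // The same element if the matrix were column-major:
  std::printf("column-major offset = %lld\n",
              static_cast<long long>(pitch_linear_offset(/*c=*/3, /*s=*/5, stride)));
  return 0;
}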
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template ::value* ThreadMap::kElementsPerAccess / 8> -+class RegularTileAccessIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h -new file mode 100644 -index 0000000..bba9f66 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h -@@ -0,0 +1,408 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing computing the addresses of storing of tiles -+ from pitch-linear rank=2 tensors. 
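The pitch-linear specialization that follows builds each access address from a thread map: a per-thread base offset plus iteration_strided * Delta::kStrided * stride + iteration_contiguous * Delta::kContiguous / kElementsPerAccess, measured in vector-access units (its stride_ member is the tensor pitch divided by kElementsPerAccess). The sketch below enumerates the offsets a single thread would touch; all thread-map constants are made up for illustration and are not taken from a real CUTLASS thread map.

#include <cstdio>

// Illustrative thread-map constants (assumed for this sketch only).
constexpr int kElementsPerAccess = 8;   // elements moved per vector access
constexpr int kIterationsContig  = 2;   // accesses per thread, contiguous dimension
constexpr int kIterationsStrided = 4;   // accesses per thread, strided dimension
constexpr int kDeltaContig       = 32;  // element step between contiguous accesses
constexpr int kDeltaStrided      = 4;   // row step between strided accesses

// Offset of access (c, s) relative to the thread's base pointer, counted in
// AccessType units, mirroring the arithmetic in the iterator's get().
constexpr long access_offset(int c, int s, long stride_in_accesses) {
  return s * kDeltaStrided * stride_in_accesses + c * kDeltaContig / kElementsPerAccess;
}

int main() {
  const long stride = 128 / kElementsPerAccess;   // tensor pitch of 128 elements
  for (int s = 0; s < kIterationsStrided; ++s)
    for (int c = 0; c < kIterationsContig; ++c)
      std::printf("iteration (c=%d, s=%d) -> access offset %ld\n",
                  c, s, access_offset(c, s, stride));
  return 0;
}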
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for congruous arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::PitchLinear, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Stride value -+ StrideIndex stride_; -+ -+ /// Internal pointer to first access of tile -+ AccessType *pointer_; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : stride_(ref.stride(0) / ThreadMap::kElementsPerAccess), -+ byte_offset_(0) { -+ -+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); -+ -+ // initialize pointer -+ pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_base)); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_offset_ += pointer_offset * sizeof(Element); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_DEVICE -+ AccessType *get() const { -+ -+ AccessType *access_ptr = pointer_; -+ -+ int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + -+ iteration_contiguous_ * ThreadMap::Delta::kContiguous / -+ ThreadMap::kElementsPerAccess; -+ -+ char *access_byte_ptr = -+ reinterpret_cast(access_ptr + access_offset); -+ -+ return reinterpret_cast(access_byte_ptr + byte_offset_); -+ } -+ -+ /// Advances to the next tile in memory. 
-+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) -+ return *this; -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset in the unit of tile. -+ /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory. -+ /// Below layouts are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B. -+ /// For row major A operand, k dimension is contiguous dimension; -+ /// For col major A operand, k dimension is strided dimension; -+ /// For row major B operand, k dimension is strided dimension; -+ /// For col major B operand, k dimension is contiguous dimension. -+ /// Below two classes map col/row major to the pitch linear coordinates used -+ /// in this base class. -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ add_pointer_offset(coord.contiguous() * Shape::kContiguous + -+ coord.strided() * Shape::kStrided * stride_ * -+ ThreadMap::kElementsPerAccess); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for column major layouts -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::ColumnMajor, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 
0 : 1), -+ ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for row major layouts -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::RowMajor, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 
1 : 0), -+ ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h -new file mode 100644 -index 0000000..938b419 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h -@@ -0,0 +1,587 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing computing the addresses of storing of tiles -+ from pitch-linear rank=2 tensors. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template ::value* ThreadMap::kElementsPerAccess / 8 -+ > -+class RegularTileAccessIteratorDirectConv; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for congruous arrangements for TensorOps with dynamic_iterations OFF -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIteratorDirectConv< -+ Shape_, Element_, -+ layout::PitchLinear, -+ AdvanceRank, ThreadMap_, false, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Stride value -+ StrideIndex stride_; -+ -+ /// Internal pointer to first access of tile -+ AccessType *pointer_; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : stride_(ref.stride(0) / ThreadMap::kElementsPerAccess), -+ byte_offset_(0) { -+ -+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); -+ -+ // initialize pointer -+ pointer_ = 
reinterpret_cast(ref.data() + ref.offset(thread_offset_base)); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_num(int num) { -+ //Do nothing -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_offset_ += pointer_offset * sizeof(Element); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_DEVICE -+ AccessType *get() const { -+ -+ AccessType *access_ptr = pointer_; -+ -+ int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + -+ iteration_contiguous_ * ThreadMap::Delta::kContiguous / -+ ThreadMap::kElementsPerAccess; -+ -+ char *access_byte_ptr = -+ reinterpret_cast(access_ptr + access_offset); -+ -+ return reinterpret_cast(access_byte_ptr + byte_offset_); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv &operator++() { -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) -+ return *this; -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv operator++(int) { -+ RegularTileAccessIteratorDirectConv prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset in the unit of tile. 
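-+  /// A unit step in the contiguous dimension advances by Shape::kContiguous
-+  /// elements; a unit step in the strided dimension advances by a full tile of
-+  /// Iterations::kStrided * Delta::kStrided rows, each spanning
-+  /// stride_ * kElementsPerAccess elements (the offset is given in Elements).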
-+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ add_pointer_offset(coord.contiguous() * Shape::kContiguous + -+ coord.strided() * ThreadMap::Iterations::kStrided * -+ ThreadMap::Delta::kStrided * stride_ * ThreadMap::kElementsPerAccess); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for congruous arrangements for TensorOps with dynamic_iterations ON -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIteratorDirectConv< -+ Shape_, Element_, -+ layout::PitchLinear, -+ AdvanceRank, ThreadMap_,true, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Stride value -+ StrideIndex stride_; -+ -+ /// Internal pointer to first access of tile -+ AccessType *pointer_; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ /// Total iterattions in the strided dimension: Dynamic value -+ int total_iteration_strided_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : stride_(ref.stride(0) / ThreadMap::kElementsPerAccess), -+ byte_offset_(0) { -+ -+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); -+ -+ // initialize pointer -+ pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_base)); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_num(int num) { -+ total_iteration_strided_ = num; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_offset_ += pointer_offset * sizeof(Element); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_DEVICE -+ AccessType *get() const { -+ -+ AccessType *access_ptr = pointer_; -+ -+ int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + -+ iteration_contiguous_ * ThreadMap::Delta::kContiguous / -+ ThreadMap::kElementsPerAccess; -+ -+ char *access_byte_ptr = -+ reinterpret_cast(access_ptr + access_offset); -+ -+ return reinterpret_cast(access_byte_ptr + byte_offset_); -+ } 
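-+  // The access_offset computed in get() above is measured in AccessType units:
-+  // one strided step covers Delta::kStrided rows of stride_ accesses each, and
-+  // one contiguous step covers Delta::kContiguous / kElementsPerAccess accesses.
-+  // With hypothetical values Delta::kStrided = 4, Delta::kContiguous = 32,
-+  // kElementsPerAccess = 8 and stride_ = 16, the access for
-+  // (iteration_contiguous_, iteration_strided_) = (2, 1) lies
-+  // 1 * 4 * 16 + 2 * 32 / 8 = 72 AccessType accesses past pointer_.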
-+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv &operator++() { -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) -+ return *this; -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < total_iteration_strided_) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv operator++(int) { -+ RegularTileAccessIteratorDirectConv prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset in the unit of tile. -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ add_pointer_offset(coord.contiguous() * Shape::kContiguous + -+ coord.strided() * total_iteration_strided_ * ThreadMap::Delta::kStrided * stride_ * -+ ThreadMap::kElementsPerAccess); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for column major layouts -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIteratorDirectConv< -+ Shape_, Element_, -+ layout::ColumnMajor, -+ AdvanceRank, ThreadMap_, Dynamic_iterations , Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIteratorDirectConv< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 
0 : 1), -+ ThreadMap_, -+ Dynamic_iterations>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_num(int num) { -+ iterator_.set_iteration_num(num); -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv operator++(int) { -+ RegularTileAccessIteratorDirectConv prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for row major layouts -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIteratorDirectConv< -+ Shape_, Element_, -+ layout::RowMajor, -+ AdvanceRank, ThreadMap_, Dynamic_iterations, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIteratorDirectConv< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 
1 : 0), -+ ThreadMap_, -+ Dynamic_iterations>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_num(int num) { -+ iterator_.set_iteration_num(num); -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv operator++(int) { -+ RegularTileAccessIteratorDirectConv prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h -new file mode 100644 -index 0000000..c16daff ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h -@@ -0,0 +1,820 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing computing the addresses of storing of tiles -+ from pitch-linear rank=2 tensors. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for congruous arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::TensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element_))>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = -+ layout::TensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element_))>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ /// This iterator is specialized for an access size that is 128 bits in -+ /// length. -+ static int const kAccessSizeInBits = 128; -+ -+ static_assert(sizeof_bits::value * -+ ThreadMap::kElementsPerAccess == -+ kAccessSizeInBits, -+ "This iterator requires a policy whose access size is 128bs"); -+ -+ ///< Number of pointers -+ static int const kPointerCount = -+ (ThreadMap::Iterations::kStrided > 1 ? 
2 : 1); -+ }; -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Stride value -+ StrideIndex stride_; -+ -+ /// Internal pointer to first access of tile -+ AccessType *pointer_[Detail::kPointerCount]; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : stride_(ref.stride(0) / Layout::kElementsPerAccess), -+ byte_offset_(0) { -+ layout::PitchLinearCoord thread_offset_base = -+ ThreadMap::initial_offset(thread_id); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Detail::kPointerCount; ++i) { -+ // This is the offset of a thread within a threadblock tile for a specific -+ // pointer (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = -+ thread_offset_base + -+ layout::PitchLinearCoord{ -+ 0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; -+ -+ // initialize pointer -+ pointer_[i] = reinterpret_cast( -+ ref.data() + ref.offset(thread_offset_in_threadblock_tile)); -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_offset_ += pointer_offset * sizeof(Element); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ AccessType *access_ptr = pointer_[iteration_strided_ & 1]; -+ int stride_idx = (iteration_strided_ & ~1); -+ -+ int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + -+ iteration_contiguous_ * ThreadMap::Delta::kContiguous / -+ ThreadMap::kElementsPerAccess; -+ -+ char *access_byte_ptr = -+ reinterpret_cast(access_ptr + access_offset); -+ return reinterpret_cast(access_byte_ptr + byte_offset_); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) -+ return *this; -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_strided_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. 
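-+  /// (Post-increment: copies the current state, advances via operator++(), and
-+  /// returns the copy, so the caller observes the pre-increment position.)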
-+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ add_pointer_offset(coord.contiguous() * Shape::kContiguous + -+ coord.strided() * Shape::kStrided * stride_ * -+ Layout::kElementsPerAccess); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for column-major congruous TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(Element_))>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for column-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(Element_))>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element_))>, -+ (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for row-major congruous TensorOp formats. 
-+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::RowMajorTensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element_))>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for row-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(Element_))>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element_))>, -+ (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. 
-+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for crosswise arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator::value, Crosswise>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = -+ layout::TensorOpMultiplicandCrosswise::value, -+ Crosswise>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ static int const kCrosswise = Crosswise; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ static_assert(!(ThreadMap::Delta::kContiguous % kCrosswise), -+ "kCrosswise is the smallest unit in the contiguous dimension " -+ "for shared memory swizzling."); -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ /// This iterator is specialized for an access size that is 128 bits in -+ /// length. -+ static int const kAccessSizeInBits = 128; -+ -+ static_assert(sizeof_bits::value * -+ ThreadMap::kElementsPerAccess == -+ kAccessSizeInBits, -+ "This iterator requires a policy whose access size is 128bs"); -+ -+ /// Number of pointers -+ /// -+ /// Note:TN kblock32 layouts only needs 1 pointer, but strangely -+ /// reducing pointer count hurts perfomrnace -+ static int const kPointerCount = -+ (ThreadMap::Iterations::kStrided > 1 ? 2 : 1); -+ }; -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Total number of sections. The memory is divided into stages. One stage -+ /// can store one tile. Stage is divided into sections. Interleaved layout -+ /// can have multiple sections in a stage. The rest layout only has one section -+ /// in a stage. 
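-+  /// For example, with hypothetical values kCrosswise = 32, ref.stride(0) = 128
-+  /// and Shape::kContiguous = 64, sections_ is 128 / 32 = 4 and
-+  /// sections_per_stage_ is 64 / 32 = 2.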
-+ int sections_; -+ -+ /// Sections that a stage has -+ int sections_per_stage_; -+ -+ /// Stride value -+ StrideIndex stride_; -+ -+ /// Internal pointer to first access of tile -+ AccessType *pointer_[Detail::kPointerCount]; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : sections_(ref.stride(0) / kCrosswise), -+ sections_per_stage_(Shape::kContiguous / kCrosswise), -+ // stride_ = kCrosswise x sections_ x kFactor -+ stride_(ref.stride(0) * Layout::kFactor / Layout::kElementsPerAccess), -+ byte_offset_(0) { -+ layout::PitchLinearCoord thread_offset_base = -+ ThreadMap::initial_offset(thread_id); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Detail::kPointerCount; ++i) { -+ // This is the offset of a thread within a threadblock tile for a specific -+ // pointer (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = -+ thread_offset_base + -+ layout::PitchLinearCoord{ -+ 0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; -+ // initialize pointer -+ pointer_[i] = reinterpret_cast(ref.data()) + -+ ref.offset(thread_offset_in_threadblock_tile) / -+ Layout::kElementsPerAccess; -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_offset_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ AccessType *access_ptr = pointer_[iteration_strided_ & 1]; -+ int stride_idx = (iteration_strided_ & ~1); -+ -+ int access_offset = -+ stride_idx * ThreadMap::Delta::kStrided * stride_ / Layout::kFactor + -+ // kCrosswise elements in the contiguous dimension would span to a -+ // shared memory cache line. -+ iteration_contiguous_ * (ThreadMap::Delta::kContiguous / kCrosswise) * -+ Layout::TileShape::kContiguous; -+ char *access_byte_ptr = -+ reinterpret_cast(access_ptr + access_offset); -+ return reinterpret_cast(access_byte_ptr + byte_offset_); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) -+ return *this; -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_strided_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next section. -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. 
-+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ add_pointer_offset(coord.contiguous() * sections_per_stage_ * stride_ * -+ ThreadMap::kElementsPerAccess / sections_ + -+ coord.strided() * Shape::kStrided * stride_ * -+ Layout::kElementsPerAccess / Layout::kFactor); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for column-major crosswise TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Crosswise>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for column-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Crosswise>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCrosswise::value, -+ Crosswise>, -+ (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for row-major crosswise TensorOp formats. 
-+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator::value, Crosswise>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for row-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Crosswise>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCrosswise::value, -+ Crosswise>, -+ (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h -new file mode 100644 -index 0000000..2b116d0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h -@@ -0,0 +1,1532 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing computing the addresses of storing of tiles -+ from pitch-linear rank=2 tensors. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for congruous arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::TensorOpMultiplicandCongruous64b, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorOpMultiplicandCongruous64b; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ static_assert(ThreadMap::kThreads / 32 > 1, -+ "This tile iterator requires at least two warps."); -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ /// This iterator is specialized for an access size that is 128 bits in -+ /// length. 
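-+    /// (For this TensorOpMultiplicandCongruous64b specialization the access size
-+    /// is 64 bits, as kAccessSizeInBits below and its static_assert indicate.)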
-+ static int const kAccessSizeInBits = 64; -+ -+ static_assert(sizeof_bits::value * -+ ThreadMap::kElementsPerAccess == -+ kAccessSizeInBits, -+ "This iterator requires a policy whose access size is 64b"); -+ -+ ///< Number of pointers -+ static int const kPointerCount = 1; -+ }; -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Stride value -+ StrideIndex stride_; -+ -+ /// Internal pointer to first access of tile -+ AccessType *pointer_; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): -+ stride_(ref.stride(0) / Layout::kElementsPerAccess), -+ byte_offset_(0) { -+ -+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); -+ -+ // This is the offset of a thread within a threadblock tile for a specific -+ // pointer (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; -+ -+ // initialize pointer -+ pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ -+ byte_offset_ += pointer_offset * sizeof(Element); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ -+ AccessType *access_ptr = pointer_; -+ -+ int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + -+ iteration_contiguous_ * ThreadMap::Delta::kContiguous / -+ ThreadMap::kElementsPerAccess; -+ -+ char *access_byte_ptr = -+ reinterpret_cast(access_ptr + access_offset); -+ -+ return reinterpret_cast(access_byte_ptr + byte_offset_); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) -+ return *this; -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. 
-+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ -+ RegularTileAccessIterator prev(*this); -+ -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ -+ add_pointer_offset( -+ coord.contiguous() * Shape::kContiguous + -+ coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for column-major congruous TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::ColumnMajorTensorOpMultiplicandCongruous64b, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for column-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCongruous64b, -+ (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for row-major congruous TensorOp formats. 
-+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for row-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorTensorOpMultiplicandCongruous64b; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCongruous64b, -+ (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. 
-+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for crosswise arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::TensorOpMultiplicand64bCrosswise, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorOpMultiplicand64bCrosswise; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ static_assert(ThreadMap::kThreads / 32 > 1, -+ "This tile iterator requires at least two warps."); -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ /// This iterator is specialized for an access size that is 128 bits in -+ /// length. -+ static int const kAccessSizeInBits = 64; -+ -+ static_assert(sizeof_bits::value * -+ ThreadMap::kElementsPerAccess == -+ kAccessSizeInBits, -+ "This iterator requires a policy whose access size is 64b"); -+ -+ ///< Number of pointers - two pointers are needed if making more than 4 iterations along -+ ///< strided dimension -+ static int const kPointerCount = (ThreadMap::Iterations::kStrided > 4 ? 
2 : 1); -+ }; -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Stride value -+ StrideIndex stride_; -+ -+ /// Internal pointer to first access of tile -+ AccessType *pointer_; -+ -+ /// Internal byte offset -+ Index byte_offset_[Detail::kPointerCount]; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_DEVICE -+ RegularTileAccessIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): -+ stride_(ref.stride(0) / ThreadMap::kElementsPerAccess) { -+ -+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); -+ -+ // This is the offset of a thread within a threadblock tile for a specific -+ // pointer (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; -+ -+ // initialize pointer -+ pointer_ = reinterpret_cast(ref.data()); -+ -+ byte_offset_[0] = ref.offset(thread_offset_in_threadblock_tile) * sizeof(Element); -+ -+ if (Detail::kPointerCount == 2) { -+ byte_offset_[1] = byte_offset_[0] ^ 8; -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ -+ pointer_ += pointer_offset / ThreadMap::kElementsPerAccess; -+ } -+ -+ /// Returns a pointer -+ CUTLASS_DEVICE -+ AccessType *get() const { -+ -+ // Map the logical contiguous and strided access to the internal swizzled structure. -+ int uniform_offset = (iteration_strided_ & 0x3) * stride_ + (iteration_strided_ >> 3) * 16 + stride_ * ThreadMap::Delta::kContiguous * iteration_contiguous_; -+ -+ char *access_byte_ptr = reinterpret_cast(pointer_ + uniform_offset); -+ -+ int byte_offset; -+ -+ // This iterator may require two byte offsets if it must load more than 8 rows (or 2 iterations) -+ // in the strided dimension -+ if (Detail::kPointerCount == 2 && (iteration_strided_ & 0x4)) { -+ byte_offset = byte_offset_[1]; -+ } -+ else { -+ byte_offset = byte_offset_[0]; -+ } -+ -+ return reinterpret_cast(access_byte_ptr + byte_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) -+ return *this; -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. 
-+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ -+ RegularTileAccessIterator prev(*this); -+ -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ -+ add_pointer_offset(coord.strided() * Shape::kStrided + coord.contiguous() * Shape::kContiguous * stride_); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for column-major crosswise TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::ColumnMajorTensorOpMultiplicand64bCrosswise, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for column-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicand64bCrosswise, -+ (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for row-major crosswise TensorOp formats. 
-+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for row-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicand64bCrosswise, -+ (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. 
-+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for congruous arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::TensorOpMultiplicandCongruous128b, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorOpMultiplicandCongruous128b; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ static_assert(ThreadMap::kThreads / 32 > 1, -+ "This tile iterator requires at least two warps."); -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ /// This iterator is specialized for an access size that is 128 bits in -+ /// length. -+ static int const kAccessSizeInBits = 128; -+ -+ static_assert(sizeof_bits::value * -+ ThreadMap::kElementsPerAccess == -+ kAccessSizeInBits, -+ "This iterator requires a policy whose access size is 128b"); -+ -+ ///< Number of pointers -+ static int const kPointerCount = 1; -+ }; -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Stride value -+ StrideIndex stride_; -+ -+ /// Internal pointer to first access of tile -+ AccessType *pointer_; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): -+ stride_(ref.stride(0) / Layout::kElementsPerAccess), -+ byte_offset_(0) { -+ -+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); -+ -+ // This is the offset of a thread within a threadblock tile for a specific -+ // pointer (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; -+ -+ // initialize pointer -+ pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void 
add_pointer_offset(LongIndex pointer_offset) { -+ -+ byte_offset_ += pointer_offset * sizeof(Element); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ -+ AccessType *access_ptr = pointer_; -+ -+ int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + -+ iteration_contiguous_ * ThreadMap::Delta::kContiguous / -+ ThreadMap::kElementsPerAccess; -+ -+ char *access_byte_ptr = -+ reinterpret_cast(access_ptr + access_offset); -+ -+ return reinterpret_cast(access_byte_ptr + byte_offset_); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) -+ return *this; -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ -+ RegularTileAccessIterator prev(*this); -+ -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ -+ add_pointer_offset( -+ coord.contiguous() * Shape::kContiguous + -+ coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for column-major congruous TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::ColumnMajorTensorOpMultiplicandCongruous128b, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for column-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCongruous128b, -+ (kAdvanceRank == 0 ? 
0 : 1), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for row-major congruous TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for row-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorTensorOpMultiplicandCongruous128b; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCongruous128b, -+ (kAdvanceRank == 0 ? 
1 : 0), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): -+ iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for congruous arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::TensorOpMultiplicandCrosswise128x4, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorOpMultiplicandCrosswise128x4; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ static_assert(ThreadMap::kThreads / 32 > 1, -+ "This tile iterator requires at least two warps."); -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ /// This iterator is specialized for an access size that is 128 bits in -+ /// length. 
-+ static int const kAccessSizeInBits = 128; -+ -+ static_assert(sizeof_bits::value * -+ ThreadMap::kElementsPerAccess == -+ kAccessSizeInBits, -+ "This iterator requires a policy whose access size is 128b"); -+ -+ ///< Number of pointers -+ static int const kPointerCount = 1; -+ }; -+ -+ -+ static_assert(!(ThreadMap::Iterations::kStrided % 2), "This iterator requires at least two iterations along the strided dimension"); -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Stride value -+ StrideIndex stride_; -+ -+ /// Internal pointer to first access of tile -+ AccessType *pointer_; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_DEVICE -+ RegularTileAccessIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): -+ stride_(ref.stride(0) / Layout::kElementsPerAccess), -+ byte_offset_(0) { -+ -+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); -+ -+ // This is the offset of a thread within a threadblock tile for a specific -+ // pointer (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; -+ -+ // initialize pointer -+ pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ -+ byte_offset_ += pointer_offset * sizeof(Element); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ -+ AccessType *access_ptr = pointer_; -+ -+ int offset_c = (iteration_contiguous_ * ThreadMap::Delta::kContiguous + (iteration_strided_ & 1) * 2); -+ int offset_s = (iteration_strided_ / 2) * 8; -+ -+ int access_offset = offset_c * stride_ + offset_s; -+ -+ char *access_byte_ptr = -+ reinterpret_cast(access_ptr + access_offset); -+ -+ return reinterpret_cast(access_byte_ptr + byte_offset_); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) -+ return *this; -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. 
-+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ -+ RegularTileAccessIterator prev(*this); -+ -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ -+ add_pointer_offset( -+ coord.contiguous() * Shape::kContiguous * stride_ + -+ coord.strided() * Shape::kStrided * Layout::kElementsPerAccess); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for column-major congruous TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::ColumnMajorTensorOpMultiplicandCrosswise128x4, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for column-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCrosswise128x4, -+ (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for row-major congruous TensorOp formats. 
-+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for row-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorTensorOpMultiplicandCrosswise128x4; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCrosswise128x4, -+ (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): -+ iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator.h -new file mode 100644 -index 0000000..26d7da7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator.h -@@ -0,0 +1,62 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing storing of tiles from pitch-linear rank=2 tensors. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape, -+ typename Element, -+ typename Layout, -+ int AdvanceRank, -+ typename ThreadMap, -+ int Alignment = sizeof_bits::value * ThreadMap::kElementsPerAccess / 8 -+> -+class RegularTileIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h -new file mode 100644 -index 0000000..f761cdd ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h -@@ -0,0 +1,552 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. -+ -+ This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile -+ first, with the objective of minimizing predicate mask updates during steady-state operation. -+ -+ A precomputed "Params" object minimizes the amount of state that must be stored in registers, -+ and integer addition is used to advance the pointer through memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+#include "regular_tile_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Regular tile iterator specialized for pitch-linear. This one is used by 2-stage SIMT kernels -+/// and sparse tensor core meta data. 
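The pitch-linear iterator declared next reduces to a raw byte pointer plus two precomputed byte increments: one applied between strided accesses while loading or storing a fragment, and one applied by operator++ to advance a whole tile. The following minimal standalone sketch (not part of the patch) shows that arithmetic; the shape, thread-map, stride, and element-size numbers are hypothetical stand-ins for the CUTLASS template parameters.

    // Simplified sketch of the pointer arithmetic used by the pitch-linear
    // RegularTileIterator below. All concrete values are hypothetical examples,
    // not values taken from this patch.
    #include <cstdint>
    #include <cstdio>

    int main() {
      const int kShapeContiguous = 64;  // stands in for Shape::kContiguous (elements)
      const int kShapeStrided    = 8;   // stands in for Shape::kStrided (rows)
      const int kDeltaStrided    = 4;   // stands in for ThreadMap::Delta::kStrided
      const int kElementBytes    = 2;   // e.g. a 16-bit element
      const int kStride          = 128; // leading dimension of the tensor (elements)
      const int kAdvanceRank     = 1;   // 0: advance along contiguous dim, 1: along strided dim

      // Computed once in the constructor, in bytes.
      const std::int64_t increment_strided =
          std::int64_t(kStride) * kElementBytes * kDeltaStrided;
      const std::int64_t increment_advance =
          (kAdvanceRank == 0) ? std::int64_t(kShapeContiguous) * kElementBytes
                              : std::int64_t(kShapeStrided) * kStride * kElementBytes;

      // load()/store() step the byte pointer by increment_strided between strided
      // iterations; operator++ / operator-- step it by +/- increment_advance. The
      // row-major and column-major wrappers further below simply swap the two
      // tile-offset coordinates before delegating to this pitch-linear form.
      std::printf("bytes per strided step: %lld, bytes per tile advance: %lld\n",
                  static_cast<long long>(increment_strided),
                  static_cast<long long>(increment_advance));
      return 0;
    }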
-+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator { -+public: -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Fragment = Array; -+ -+ using AccessType = AlignedArray; -+ -+ static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, -+ "Advance rank may only be along the contiguous or strided dimensions."); -+ -+private: -+ -+ // -+ // Types -+ // -+ -+ // -+ // Data members -+ // -+ -+ /// Pointer to memory -+ uint8_t *pointer_; -+ -+ /// Stride quantity -+ StrideIndex stride_; -+ -+ /// Amount to increment pointer along strided dimension -+ Index increment_strided_; -+ -+ /// Amount to advance pointer between tiles -+ Index increment_advance_; -+ -+public: -+ -+ CUTLASS_DEVICE -+ RegularTileIterator(): pointer_(nullptr), increment_strided_(0), increment_advance_(0) { } -+ -+ CUTLASS_DEVICE -+ RegularTileIterator( -+ TensorRef const &ref, -+ int thread_idx -+ ): -+ pointer_(reinterpret_cast(ref.data()) + (ref.offset(ThreadMap::initial_offset(thread_idx)) * sizeof_bits::value / 8)) { -+ -+ stride_ = ref.stride()[0]; -+ increment_strided_ = (ref.stride()[0] * sizeof_bits::value) * ThreadMap::Delta::kStrided / 8; -+ -+ increment_advance_ = -+ (kAdvanceRank == 0 ? -+ Shape::kContiguous * sizeof_bits::value / 8 : -+ Shape::kStrided * (ref.stride()[0] * sizeof_bits::value / 8)); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ uint8_t const *byte_pointer = pointer_ + pointer_offset * sizeof_bits::value / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ AccessType const *access_ptr = reinterpret_cast(byte_pointer); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int idx = c + s * ThreadMap::Iterations::kContiguous; -+ frag_ptr[idx] = access_ptr[c * ThreadMap::Delta::kContiguous / -+ ThreadMap::kElementsPerAccess]; -+ } -+ -+ if (s + 1 < ThreadMap::Iterations::kStrided) { -+ byte_pointer += increment_strided_; -+ } -+ } -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, TensorCoord const & tile_offset) { -+ load_with_pointer_offset( -+ frag, -+ tile_offset.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + -+ tile_offset.strided() * Shape::kStrided * stride_ -+ ); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ uint8_t *byte_pointer = pointer_ + pointer_offset * sizeof_bits::value / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ AccessType *access_ptr = reinterpret_cast(byte_pointer); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int idx = 
c + s * ThreadMap::Iterations::kContiguous; -+ access_ptr[c * ThreadMap::Delta::kContiguous / -+ ThreadMap::kElementsPerAccess] = frag_ptr[idx]; -+ } -+ -+ if (s + 1 < ThreadMap::Iterations::kStrided) { -+ byte_pointer += increment_strided_; -+ } -+ } -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag, TensorCoord const & tile_offset) { -+ store_with_pointer_offset( -+ frag, -+ tile_offset.contiguous() * Shape::kContiguous + tile_offset.strided() * Shape::kStrided * stride_ -+ ); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ pointer_ += increment_advance_; -+ return *this; -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator--() { -+ pointer_ -= increment_advance_; -+ return *this; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset; -+ } -+ -+ /// Adds a tile offset in the unit of tile. -+ /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory. -+ /// Below layouts are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B. -+ /// For row major A operand, k dimension is contiguous dimension; -+ /// For col major A operand, k dimension is strided dimension; -+ /// For row major B operand, k dimension is strided dimension; -+ /// For col major B operand, k dimension is contiguous dimension. -+ /// Below two classes map col/row major to the pitch linear coordinates used -+ /// in this base class. -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ int offset = sizeof_bits::value * -+ (coord.contiguous() * Shape::kContiguous + coord.strided() * Shape::kStrided * stride_) / 8; -+ add_pointer_offset(offset); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+#if 0 -+ AccessType *access_ptr = pointer_[iteration_strided_ & 1]; -+ int stride_idx = (iteration_strided_ & ~1); -+ -+ int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + -+ iteration_contiguous_ * ThreadMap::Delta::kContiguous / -+ ThreadMap::kElementsPerAccess; -+ -+ char *access_byte_ptr = -+ reinterpret_cast(access_ptr + access_offset); -+ return reinterpret_cast(access_byte_ptr + byte_offset_); -+#endif -+ return reinterpret_cast(pointer_); -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Regular tile iterator specialized for pitch-linear -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator { -+public: -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Fragment = Array; -+ -+ using Underlying = RegularTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ 
layout::PitchLinear, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap, -+ kAlignment -+ >; -+ -+ using AccessType = typename Underlying::AccessType; -+ -+ static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, -+ "Advance rank may only be along the row or column dimensions."); -+ -+private: -+ -+ Underlying iterator_; -+ -+public: -+ -+ CUTLASS_DEVICE -+ RegularTileIterator() { } -+ -+ CUTLASS_DEVICE -+ RegularTileIterator( -+ TensorRef const &ref, -+ int thread_idx -+ ): -+ iterator_({ref.data(), ref.stride()}, thread_idx) { -+ -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, TensorCoord const & tile_offset) { -+ iterator_.load_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) { -+ iterator_.load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag, TensorCoord const & tile_offset) { -+ iterator_.store_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ iterator_.store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator--() { -+ --iterator_; -+ return *this; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return iterator_.get(); -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Regular tile iterator specialized for pitch-linear -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator { -+public: -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Fragment = Array; -+ -+ using Underlying = RegularTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 
0 : 1), -+ ThreadMap -+ >; -+ -+ using AccessType = typename Underlying::AccessType; -+ -+ static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, -+ "Advance rank may only be along the row or column dimensions."); -+ -+private: -+ -+ Underlying iterator_; -+ -+public: -+ -+ CUTLASS_DEVICE -+ RegularTileIterator() { } -+ -+ CUTLASS_DEVICE -+ RegularTileIterator( -+ TensorRef const &ref, -+ int thread_idx -+ ): -+ iterator_({ref.data(), ref.stride()}, thread_idx) { -+ -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, TensorCoord const & tile_offset) { -+ iterator_.load_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) { -+ iterator_.load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag, TensorCoord const & tile_offset) { -+ iterator_.store_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ iterator_.store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator--() { -+ --iterator_; -+ return *this; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return iterator_.get(); -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h -new file mode 100644 -index 0000000..a954eb4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h -@@ -0,0 +1,509 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. 
-+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. -+ -+ This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile -+ first, with the objective of minimizing predicate mask updates during steady-state operation. -+ -+ A precomputed "Params" object minimizes the amount of state that must be stored in registers, -+ and integer addition is used to advance the pointer through memory. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+#include "regular_tile_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ typename Shape, -+ typename Element, -+ typename Layout, -+ int AdvanceRank, -+ typename ThreadMap, -+ int Alignment = sizeof_bits::value * ThreadMap::kElementsPerAccess / 8 -+> -+class RegularTileIterator2dThreadTile; -+ -+ -+/// Regular tile iterator specialized for pitch-linear + 2d thread-tiled threadmapping -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator2dThreadTile { -+public: -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Fragment = Array; -+ -+ static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, -+ "Advance rank may only be along the contiguous or strided dimensions."); -+ -+private: -+ -+ // -+ // Types -+ // -+ -+ using AccessType = AlignedArray; -+ -+ // -+ // Data members -+ // -+ -+ /// Pointer to memory -+ uint8_t *pointer_; -+ -+ /// Stride quantity -+ StrideIndex stride_; -+ -+ /// Amount to increment pointer along strided dimension -+ LongIndex increment_strided_; -+ -+ /// Amount to advance pointer between tiles -+ LongIndex increment_advance_; -+ -+public: -+ -+ CUTLASS_DEVICE -+ RegularTileIterator2dThreadTile(): pointer_(nullptr), increment_strided_(0), increment_advance_(0) { } -+ -+ CUTLASS_DEVICE -+ RegularTileIterator2dThreadTile( -+ TensorRef const &ref, -+ int thread_idx, -+ int interleave -+ ){ -+ -+ TensorCoord t = ThreadMap::initial_offset(thread_idx); -+ long int offset = t[0] * interleave + t[1] * ref.stride()[0]/interleave; -+ pointer_ = reinterpret_cast(ref.data() + offset); -+ -+ stride_ = ref.stride()[0] / interleave; -+ increment_strided_ = (ref.stride()[0] * sizeof_bits::value / 8) * ThreadMap::Delta::kStrided / interleave; -+ -+ increment_advance_ = -+ (kAdvanceRank == 0 ? 
-+ Shape::kContiguous * sizeof_bits::value / 8 : -+ Shape::kStrided * (ref.stride()[0] * sizeof_bits::value / 8) / interleave); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ uint8_t const *byte_pointer = pointer_ + pointer_offset * sizeof_bits::value / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ AccessType const *access_ptr = reinterpret_cast(byte_pointer); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int idx = c + s * ThreadMap::Iterations::kContiguous; -+ frag_ptr[idx] = access_ptr[c * ThreadMap::Delta::kContiguous / ThreadMap::ThreadAccessShape::kStrided]; -+ } -+ -+ if (s + 1 < ThreadMap::Iterations::kStrided) { -+ byte_pointer += increment_strided_; -+ } -+ } -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, TensorCoord const & tile_offset) { -+ load_with_pointer_offset( -+ frag, -+ tile_offset.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + -+ tile_offset.strided() * Shape::kStrided * stride_ -+ ); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ uint8_t *byte_pointer = pointer_ + pointer_offset * sizeof_bits::value / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ AccessType *access_ptr = reinterpret_cast(byte_pointer); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int idx = c + s * ThreadMap::Iterations::kContiguous; -+ access_ptr[c * ThreadMap::Delta::kContiguous / ThreadMap::ThreadAccessShape::kStrided] = frag_ptr[idx]; -+ } -+ -+ if (s + 1 < ThreadMap::Iterations::kStrided) { -+ byte_pointer += increment_strided_; -+ } -+ } -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag, TensorCoord const & tile_offset) { -+ store_with_pointer_offset( -+ frag, -+ tile_offset.contiguous() * Shape::kContiguous + tile_offset.strided() * Shape::kStrided * stride_ -+ ); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator2dThreadTile &operator++() { -+ pointer_ += increment_advance_; -+ return *this; -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator2dThreadTile &operator--() { -+ pointer_ -= increment_advance_; -+ return *this; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ int offset = sizeof_bits::value * -+ (coord.contiguous() * Shape::kContiguous + coord.strided() * Shape::kStrided * stride_) / 8; -+ add_pointer_offset(offset); -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Regular tile iterator specialized for interleaved layout + 2d thread-tiled threadmapping -+template < -+ typename Shape_, -+ 
typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator2dThreadTile, AdvanceRank, ThreadMap_, Alignment> { -+public: -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorInterleaved<4>; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Fragment = Array; -+ -+ using Underlying = RegularTileIterator2dThreadTile< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap, -+ kAlignment -+ >; -+ -+ static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, -+ "Advance rank may only be along the row or column dimensions."); -+ -+private: -+ -+ Underlying iterator_; -+ -+public: -+ -+ CUTLASS_DEVICE -+ RegularTileIterator2dThreadTile() { } -+ -+ CUTLASS_DEVICE -+ RegularTileIterator2dThreadTile( -+ TensorRef const &ref, -+ int thread_idx -+ ): -+ iterator_({ref.data(), ref.stride()}, thread_idx, 4) { -+ -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, TensorCoord const & tile_offset) { -+ iterator_.load_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) { -+ iterator_.load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag, TensorCoord const & tile_offset) { -+ iterator_.store_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ iterator_.store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator2dThreadTile &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator2dThreadTile &operator--() { -+ --iterator_; -+ return *this; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Regular tile iterator specialized for interleaved layout + 2d thread-tiled threadmapping -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator2dThreadTile, AdvanceRank, ThreadMap_, Alignment> { -+public: -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorInterleaved<4>; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; 
-+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Fragment = Array; -+ using PitchLinearThreadMap = PitchLinearStripminedThreadMap< layout::PitchLinearShape, -+ ThreadMap::kThreads, ThreadMap::ThreadAccessShape::kCount >; -+ -+ -+ using Underlying = RegularTileIterator2dThreadTile< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap -+ >; -+ -+ static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, -+ "Advance rank may only be along the row or column dimensions."); -+ -+private: -+ -+ Underlying iterator_; -+ -+public: -+ -+ CUTLASS_DEVICE -+ RegularTileIterator2dThreadTile() { } -+ -+ CUTLASS_DEVICE -+ RegularTileIterator2dThreadTile( -+ TensorRef const &ref, -+ int thread_idx -+ ): -+ iterator_({ref.data(), ref.stride()}, thread_idx, 4) { -+ -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, TensorCoord const & tile_offset) { -+ iterator_.load_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) { -+ iterator_.load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag, TensorCoord const & tile_offset) { -+ iterator_.store_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ iterator_.store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator2dThreadTile &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator2dThreadTile &operator--() { -+ --iterator_; -+ return *this; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h -new file mode 100644 -index 0000000..8ea0efa ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h -@@ -0,0 +1,1107 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing storing of tiles from pitch-linear rank=2 tensors. -+*/ -+ -+#pragma once -+ -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for congruous arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileIterator< -+ Shape_, Element_, -+ layout::TensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element_))>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = -+ layout::TensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element))>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ -+ /// This iterator is specialized for an access size that is 128 bits in length. 
-+ static int const kAccessSizeInBits = 128; -+ -+ static_assert( -+ sizeof_bits::value * ThreadMap::kElementsPerAccess == kAccessSizeInBits, -+ "This iterator requires a policy whose access size is 128bs"); -+ }; -+ -+private: -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+public: -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+ /// Underlying iterator to compute the addresses -+ using TileAccessIterator = RegularTileAccessIterator; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Data member to the tile access iterator -+ TileAccessIterator address_iterator_; -+ -+public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : address_iterator_(ref, thread_id) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ address_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ address_iterator_.add_tile_offset({0, 1}); -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ RegularTileIterator prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ address_iterator_.add_tile_offset(coord); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, Index byte_offset) { -+ address_iterator_.set_iteration_index(0); -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ int access_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; -+ AccessType const *access_ptr = reinterpret_cast(byte_ptr); -+ -+ frag_ptr[access_idx] = *access_ptr; -+ ++address_iterator_; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, Index byte_offset) { -+ address_iterator_.set_iteration_index(0); -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ int access_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; -+ AccessType *access_ptr = reinterpret_cast(byte_ptr); -+ -+ *access_ptr = frag_ptr[access_idx]; -+ ++address_iterator_; -+ } -+ } -+ } -+ -+ /// Store a fragment to 
memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_byte_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for column-major congruous TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileIterator< -+ Shape_, Element_, -+ layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(Element_))>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for column-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(Element))>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element))>, -+ (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; -+ -+ public: -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): iterator_({ref.data(), ref.stride()}, thread_id) { -+ -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ RegularTileIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, -+ Index pointer_offset) { -+ -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for row-major congruous TensorOp formats. 
-+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileIterator< -+ Shape_, Element_, -+ layout::RowMajorTensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element_))>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for row-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(Element))>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element))>, -+ (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; -+ -+ public: -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): iterator_({ref.data(), ref.stride()}, thread_id) { -+ -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. 
-+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ -+ RegularTileIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, -+ Index pointer_offset) { -+ -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for crosswise arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileIterator::value, Crosswise>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = -+ layout::TensorOpMultiplicandCrosswise::value, -+ Crosswise>; -+ -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ /// This iterator is specialized for an access size that is 128 bits in -+ /// length. -+ static int const kAccessSizeInBits = 128; -+ -+ static_assert(sizeof_bits::value * ThreadMap::kElementsPerAccess == -+ kAccessSizeInBits, -+ "This iterator requires a policy whose access size is 128bs"); -+ }; -+ -+ private: -+ /// Element type per access -+ using AccessType = Array; -+ -+ public: -+ /// Fragment object to be loaded or stored -+ using Fragment = -+ Array; -+ -+ /// Underlying iterator to compute the addresses -+ using TileAccessIterator = RegularTileAccessIterator; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Data member to the tile access iterator -+ TileAccessIterator address_iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : address_iterator_(ref, thread_id) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ address_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ address_iterator_.add_tile_offset({1, 0}); -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. 
-+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ RegularTileIterator prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ address_iterator_.add_tile_offset(coord); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ address_iterator_.set_iteration_index(0); -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ int access_idx = c + s * ThreadMap::Iterations::kContiguous; -+ frag_ptr[access_idx] = *(address_iterator_.get() + pointer_offset); -+ ++address_iterator_; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, Index byte_offset) { -+ address_iterator_.set_iteration_index(0); -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ int access_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; -+ AccessType *access_ptr = reinterpret_cast(byte_ptr); -+ -+ *access_ptr = frag_ptr[access_idx]; -+ ++address_iterator_; -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for column-major crosswise TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileIterator::value, Crosswise>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for column-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Crosswise>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCrosswise::value, -+ Crosswise>, -+ (kAdvanceRank == 0 ? 
0 : 1), ThreadMap_>; -+ -+ public: -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ RegularTileIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for row-major crosswise TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileIterator::value, Crosswise>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for row-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Crosswise>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCrosswise::value, -+ Crosswise>, -+ (kAdvanceRank == 0 ? 
1 : 0), ThreadMap_>; -+ -+ public: -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ RegularTileIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for k interleaved arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileIterator< -+ Shape_, Element_, -+ layout::TensorOpMultiplicandRowMajorInterleaved::value, -+ InterleavedK>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = -+ layout::TensorOpMultiplicandRowMajorInterleaved::value, -+ InterleavedK>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ /// This iterator is specialized for an access size that is 128 bits in -+ /// length. 
-+ static int const kAccessSizeInBits = 128; -+ -+ static_assert(sizeof_bits::value * ThreadMap::kElementsPerAccess == -+ kAccessSizeInBits, -+ "This iterator requires a policy whose access size is 128bs"); -+ }; -+ -+ private: -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+ public: -+ /// Fragment object to be loaded or stored -+ using Fragment = -+ Array; -+ -+ /// Underlying iterator to compute the addresses -+ using TileAccessIterator = RegularTileAccessIterator; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Data member to the tile access iterator -+ TileAccessIterator address_iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : address_iterator_(ref, thread_id) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ address_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ address_iterator_.add_pointer_offset(Shape::kCount); -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ RegularTileIterator prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ address_iterator_.add_pointer_offset(coord.contiguous() * Shape::kCount); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ address_iterator_.set_iteration_index(0); -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ int access_idx = c + s * ThreadMap::Iterations::kContiguous; -+ frag_ptr[access_idx] = *(address_iterator_.get() + pointer_offset); -+ ++address_iterator_; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ int access_idx = c + s * ThreadMap::Iterations::kContiguous; -+ *(address_iterator_.get() + pointer_offset) = frag_ptr[access_idx]; -+ ++address_iterator_; -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for k interleaved arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+ -+template -+class RegularTileIterator< -+ Shape_, Element_, -+ layout::TensorOpMultiplicandColumnMajorInterleaved::value, -+ InterleavedK>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ 
-+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = -+ layout::TensorOpMultiplicandColumnMajorInterleaved::value, -+ InterleavedK>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileIterator< -+ cutlass::MatrixShape, -+ Element, -+ layout::TensorOpMultiplicandRowMajorInterleaved::value, InterleavedK>, -+ (kAdvanceRank == 1 ? 0 : 1), -+ ThreadMap -+ >; -+ -+ public: -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+ private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ RegularTileIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.strided(), coord.contiguous()}); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h -new file mode 100644 -index 0000000..883faa5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h -@@ -0,0 +1,1460 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. -+ -+ This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile -+ first, with the objective of minimizing predicate mask updates during steady-state operation. -+ -+ A precomputed "Params" object minimizes the amount of state that must be stored in registers, -+ and integer addition is used to advance the pointer through memory. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm70.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for congruous arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator< -+ Shape_, -+ Element_, -+ layout::VoltaTensorOpMultiplicandCongruous::value>, -+ AdvanceRank, -+ ThreadMap_, -+ Alignment> { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::VoltaTensorOpMultiplicandCongruous::value>; -+ static int const kAdvanceRank = AdvanceRank; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ -+ /// This iterator is specialized for an access size that is 128 bits in length. -+ static int const kAccessSizeInBits = 128; -+ -+ static_assert( -+ sizeof_bits::value * ThreadMap::kElementsPerAccess == kAccessSizeInBits, -+ "This iterator requires a policy whose access size is 128bs"); -+ -+ ///< Number of pointers -+ static int const kPointerCount = (ThreadMap::Iterations::kStrided > 1 ? 
2 : 1); -+ }; -+ -+ -+private: -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+public: -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Stride value -+ StrideIndex stride_; -+ -+ /// Internal pointer to first access of tile -+ AccessType * pointer_[Detail::kPointerCount]; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): stride_(ref.stride(0) / Layout::kElementsPerAccess), byte_offset_(0) { -+ -+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Detail::kPointerCount; ++i) { -+ -+ // This is the offset of a thread within a threadblock tile for a specific pointer -+ // (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = -+ thread_offset_base + layout::PitchLinearCoord{0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; -+ -+ // initialize pointer -+ pointer_[i] = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); -+ } -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ -+ byte_offset_ += pointer_offset * sizeof(Element); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ -+ add_pointer_offset((kAdvanceRank ? Shape::kStrided * stride_ * Layout::kElementsPerAccess : Shape::kContiguous)); -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. 
-+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ -+ RegularTileIterator prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ add_pointer_offset( -+ coord.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + -+ coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess -+ ); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ AccessType *access_ptr = pointer_[s & 1]; -+ int stride_idx = (s & ~1); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + -+ c * ThreadMap::Delta::kContiguous / ThreadMap::kElementsPerAccess + -+ vec_pointer_offset; -+ -+ int access_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ char const *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); -+ -+ frag_ptr[access_idx] = *reinterpret_cast(access_byte_ptr + byte_offset_); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, -+ Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ AccessType *access_ptr = pointer_[s & 1]; -+ int stride_idx = (s & ~1); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + -+ c * ThreadMap::Delta::kContiguous / ThreadMap::kElementsPerAccess + -+ vec_pointer_offset; -+ -+ int access_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ char *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); -+ -+ *reinterpret_cast(access_byte_ptr + byte_offset_) = frag_ptr[access_idx]; -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Tile Iterator specialized for column-major congruous TensorOp formats. 
-+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator< -+ Shape_, -+ Element_, -+ layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>, -+ AdvanceRank, -+ ThreadMap_, -+ Alignment> { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for column-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ static int const kAdvanceRank = AdvanceRank; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::VoltaTensorOpMultiplicandCongruous::value>, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap_>; -+ -+public: -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): iterator_({ref.data(), ref.stride()}, thread_id) { -+ -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ -+ RegularTileIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, -+ Index pointer_offset) { -+ -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for row-major congruous TensorOp formats. 
-+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator< -+ Shape_, -+ Element_, -+ layout::RowMajorVoltaTensorOpMultiplicandCongruous::value>, -+ AdvanceRank, -+ ThreadMap_, -+ Alignment> { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for row-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorVoltaTensorOpMultiplicandCongruous::value>; -+ static int const kAdvanceRank = AdvanceRank; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::VoltaTensorOpMultiplicandCongruous::value>, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap_>; -+ -+public: -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): iterator_({ref.data(), ref.stride()}, thread_id) { -+ -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. 
-+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ -+ RegularTileIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, -+ Index pointer_offset) { -+ -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+/// Tile iterator specialized for congruous arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator< -+ Shape_, -+ Element_, -+ layout::VoltaTensorOpMultiplicandBCongruous::value>, -+ AdvanceRank, -+ ThreadMap_, -+ Alignment> { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::VoltaTensorOpMultiplicandBCongruous::value>; -+ static int const kAdvanceRank = AdvanceRank; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ -+ /// This iterator is specialized for an access size that is 128 bits in length. -+ static int const kAccessSizeInBits = 128; -+ -+ static_assert( -+ sizeof_bits::value * ThreadMap::kElementsPerAccess == kAccessSizeInBits, -+ "This iterator requires a policy whose access size is 128bs"); -+ -+ ///< Number of pointers -+ static int const kPointerCount = (ThreadMap::Iterations::kStrided > 1 ? 
2 : 1); -+ }; -+ -+ -+private: -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+public: -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Stride value -+ StrideIndex stride_; -+ -+ /// Internal pointer to first access of tile -+ AccessType * pointer_[Detail::kPointerCount]; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): stride_(ref.stride(0) / Layout::kElementsPerAccess), byte_offset_(0) { -+ -+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Detail::kPointerCount; ++i) { -+ -+ // This is the offset of a thread within a threadblock tile for a specific pointer -+ // (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = -+ thread_offset_base + layout::PitchLinearCoord{0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; -+ -+ // initialize pointer -+ pointer_[i] = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); -+ } -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ -+ byte_offset_ += pointer_offset * sizeof(Element); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ -+ add_pointer_offset((kAdvanceRank ? Shape::kStrided * stride_ * Layout::kElementsPerAccess : Shape::kContiguous)); -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. 
-+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ -+ RegularTileIterator prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ add_pointer_offset( -+ coord.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + -+ coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess -+ ); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ AccessType *access_ptr = pointer_[s & 1]; -+ int stride_idx = (s & ~1); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + -+ c * ThreadMap::Delta::kContiguous / ThreadMap::kElementsPerAccess + -+ vec_pointer_offset; -+ -+ int access_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ char const *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); -+ -+ frag_ptr[access_idx] = *reinterpret_cast(access_byte_ptr + byte_offset_); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, -+ Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ AccessType *access_ptr = pointer_[s & 1]; -+ int stride_idx = (s & ~1); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + -+ c * ThreadMap::Delta::kContiguous / ThreadMap::kElementsPerAccess + -+ vec_pointer_offset; -+ -+ int access_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ char *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); -+ -+ *reinterpret_cast(access_byte_ptr + byte_offset_) = frag_ptr[access_idx]; -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for column-major congruous TensorOp formats. 
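The pitch-linear congruous iterator above keeps a small set of base pointers in units of a vectorized AccessType plus a running byte offset; per-access offsets are computed in AccessType units and the byte offset is only applied immediately before the 128-bit access. A minimal host-side sketch of that addressing idiom (illustrative, not the actual device code path):

    #include <cstdint>
    #include <cstring>

    // Offsets between accesses are in units of the vector type; tile advances are
    // accumulated separately in bytes and added just before the load.
    template <typename AccessType>
    AccessType load_access(AccessType const *base, int access_offset,
                           std::int64_t byte_offset) {
      char const *byte_ptr =
          reinterpret_cast<char const *>(base + access_offset) + byte_offset;
      AccessType value;
      std::memcpy(&value, byte_ptr, sizeof(AccessType));  // stands in for a 128-bit vector load
      return value;
    }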
-+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator< -+ Shape_, -+ Element_, -+ layout::ColumnMajorVoltaTensorOpMultiplicandBCongruous::value>, -+ AdvanceRank, -+ ThreadMap_, -+ Alignment> { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for column-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ static int const kAdvanceRank = AdvanceRank; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::VoltaTensorOpMultiplicandBCongruous::value>, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap_>; -+ -+public: -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): iterator_({ref.data(), ref.stride()}, thread_id) { -+ -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ -+ RegularTileIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, -+ Index pointer_offset) { -+ -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for row-major congruous TensorOp formats. 
-+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator< -+ Shape_, -+ Element_, -+ layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>, -+ AdvanceRank, -+ ThreadMap_, -+ Alignment> { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for row-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ static int const kAdvanceRank = AdvanceRank; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::VoltaTensorOpMultiplicandBCongruous::value>, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap_>; -+ -+public: -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): iterator_({ref.data(), ref.stride()}, thread_id) { -+ -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ -+ RegularTileIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, -+ Index pointer_offset) { -+ -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+ -+/// Tile iterator specialized for crosswise arrangements for TensorOps. -+/// -+/// Volta TN SMEM layout is a little diffrent: -+/// Crosseised elements will be stored in a line, while contiguous elements -+/// sre stored in line-by-line. -+/// Padding is used to reduce SMEM bank conflicts. 
-+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator< -+ Shape_, Element_, -+ layout::VoltaTensorOpMultiplicandCrosswise::value, -+ Shape_::kContiguous>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = -+ layout::VoltaTensorOpMultiplicandCrosswise::value, -+ Shape::kContiguous>; -+ static int const kAdvanceRank = AdvanceRank; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ -+ ///< Number of pointers -+ static int const kPointerCount = (ThreadMap::Iterations::kStrided > 1 ? 2 : 1); -+ -+ /// Iterations for the kElementsPerAccess of ThreadMap -+ static int const kIterarionsPerAccess = -+ ThreadMap::kElementsPerAccess / Layout::kElementsPerAccess; -+ -+ /// Contiguous elements per line -+ static int const kContiguousElementsPerLine = 4; -+ }; -+ -+ private: -+ /// Element type per access -+ using AccessType = Array; -+ -+ public: -+ /// Fragment object to be loaded or stored -+ using Fragment = -+ Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// The crosswised elements will be stored in a line. -+ /// line_size is size of crosswised dimention plus padding. -+ /// in units of AccessType -+ Index line_size; -+ -+ /// Internal pointer to first access of tile -+ AccessType *pointer_[Detail::kPointerCount]; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : line_size(ref.stride(0) * Detail::kContiguousElementsPerLine / Layout::kElementsPerAccess), -+ byte_offset_(0) { -+ -+ layout::PitchLinearCoord thread_offset_base = -+ ThreadMap::initial_offset(thread_id); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Detail::kPointerCount; ++i) { -+ // This is the offset of a thread within a threadblock tile for a specific -+ // pointer (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = -+ thread_offset_base + -+ layout::PitchLinearCoord{ -+ 0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; -+ -+ // initialize pointer -+ pointer_[i] = reinterpret_cast( -+ ref.data() + ref.offset(thread_offset_in_threadblock_tile)); -+ } -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_offset_ += pointer_offset * sizeof(Element); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ // (Shape::kContiguous/Layout::kElementsPerAccess)* -+ // line_size * Layout::kElementsPerAccess -+ add_pointer_offset(Shape::kContiguous * line_size); -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. 
-+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ RegularTileIterator prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ add_pointer_offset((coord.contiguous() * (Shape::kContiguous / Layout::kElementsPerAccess) * -+ line_size + coord.strided() * Shape::kStrided) * -+ Layout::kElementsPerAccess); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ Index vec_pointer_offset = pointer_offset / Layout::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ AccessType *access_ptr = pointer_[(s & 1) ^ (s / 2)]; -+ -+ access_ptr += 16 * (s / 2); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int i = 0; i < Detail::kIterarionsPerAccess; ++i) { -+ -+ int access_offset = -+ c * ThreadMap::Delta::kContiguous / Detail::kContiguousElementsPerLine * line_size + -+ vec_pointer_offset + i * line_size; -+ -+ int access_idx = (c + s * ThreadMap::Iterations::kContiguous) * -+ Detail::kIterarionsPerAccess + i; -+ -+ char const *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); -+ -+ frag_ptr[access_idx] = *reinterpret_cast( -+ access_byte_ptr + byte_offset_); -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ Index vec_pointer_offset = pointer_offset / Layout::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ AccessType *access_ptr = pointer_[(s & 1) ^ ((s >> 1) & 1)]; -+ -+ access_ptr += 16 * (s / 2) + vec_pointer_offset; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for(int i = 0; i < Detail::kIterarionsPerAccess; ++i) { -+ -+ int access_offset = -+ c * ThreadMap::Delta::kContiguous / Detail::kContiguousElementsPerLine * line_size + i * line_size; -+ -+ int access_idx = (c + s * ThreadMap::Iterations::kContiguous) * -+ Detail::kIterarionsPerAccess + i; -+ -+ char *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); -+ -+ *reinterpret_cast(access_byte_ptr + byte_offset_) = -+ frag_ptr[access_idx]; -+ } -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for column-major crosswise TensorOp formats. 
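The crosswise iterator above derives a padded line pitch (line_size) from the tensor stride so that consecutive shared-memory lines do not start in the same bank. A small sketch of why the padding helps, with bank and padding sizes chosen purely for illustration rather than taken from the Volta layout:

    // Illustrative only: with 32 banks of 4 bytes, a 128-byte line pitch maps every
    // line to the same starting bank; padding the pitch rotates the starting bank.
    constexpr int kBankCount = 32;
    constexpr int kBankBytes = 4;
    constexpr int kLineBytes = 128;    // useful data per line (assumed)
    constexpr int kPaddingBytes = 16;  // one 128-bit access of padding (assumed)

    constexpr int starting_bank(int line, int line_pitch_bytes) {
      return (line * line_pitch_bytes / kBankBytes) % kBankCount;
    }

    static_assert(starting_bank(1, kLineBytes) == starting_bank(0, kLineBytes),
                  "unpadded lines start in the same bank");
    static_assert(starting_bank(1, kLineBytes + kPaddingBytes) !=
                      starting_bank(0, kLineBytes + kPaddingBytes),
                  "a padded pitch staggers the starting bank");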
-+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator::value, Shape_::kRow>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for column-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kRow>; -+ static int const kAdvanceRank = AdvanceRank; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileIterator< -+ layout::PitchLinearShape, Element, -+ layout::VoltaTensorOpMultiplicandCrosswise::value, -+ Shape::kRow>, -+ (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; -+ -+ public: -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ RegularTileIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for row-major crosswise TensorOp formats. 
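Each crosswise adapter binds the layout's crosswise extent to the matrix dimension that runs crosswise for that ordering: Shape::kRow in the column-major specialization above, Shape::kColumn in the row-major one that follows. A trivial sketch of that binding (illustrative names only):

    template <int Rows, int Columns>
    struct CrosswiseExtent {
      static constexpr int kColumnMajor = Rows;     // crosswise extent for column-major
      static constexpr int kRowMajor    = Columns;  // crosswise extent for row-major
    };
    static_assert(CrosswiseExtent<64, 128>::kColumnMajor == 64,  "column-major crosses rows");
    static_assert(CrosswiseExtent<64, 128>::kRowMajor    == 128, "row-major crosses columns");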
-+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator::value, Shape_::kColumn>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for row-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorVoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kColumn>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileIterator< -+ layout::PitchLinearShape, Element, -+ layout::VoltaTensorOpMultiplicandCrosswise::value, -+ Shape::kColumn>, -+ (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; -+ -+ public: -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. 
-+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ RegularTileIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/vector_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/vector_iterator.h -new file mode 100644 -index 0000000..8536a32 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/vector_iterator.h -@@ -0,0 +1,149 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template wraps the vector access iterator concept to load whole vector from tensors in -+ memory. This is typically used for per-channel scale and bias in convolution kernels. 
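The vector iterator introduced below wraps a predicated vector access iterator so that a whole per-channel scale/bias vector can be pulled into registers, with out-of-range accesses suppressed. A host-side analogue of that behavior, purely for illustration (the real iterator issues guarded global loads per access):

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Load a whole scale/bias vector into a fragment, zero-filling past the end of
    // the vector so accesses beyond the extent are harmless.
    template <typename Element>
    std::vector<Element> load_vector_fragment(Element const *ptr, std::size_t extent,
                                              std::size_t fragment_size) {
      std::vector<Element> frag(fragment_size, Element(0));  // analogue of frag.clear()
      std::size_t n = std::min(extent, fragment_size);
      std::copy(ptr, ptr + n, frag.begin());                 // only the "valid" accesses
      return frag;
    }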
-+*/ -+ -+#pragma once -+ -+#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class VectorIterator { -+public: -+ using VectorAccessIterator = VectorAccessIterator_; -+ -+ using Shape = typename VectorAccessIterator::Shape; -+ using Element = typename VectorAccessIterator::Element; -+ using Layout = typename VectorAccessIterator::Layout; -+ using TensorCoord = typename Layout::TensorCoord; -+ using AccessType = typename VectorAccessIterator::AccessType; -+ using TensorRef = typename VectorAccessIterator::TensorRef; -+ using Index = typename VectorAccessIterator::Index; -+ using LongIndex = typename VectorAccessIterator::LongIndex; -+ -+ static int const kElementsPerAccess = VectorAccessIterator::kElementsPerAccess; -+ static int const kRowsPerIteration = VectorAccessIterator::kRowsPerIteration; -+ static int const kThreads = VectorAccessIterator::kThreads; -+ static int const kIterations = VectorAccessIterator::kIterations; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array< -+ Element, kElementsPerAccess * kIterations>; -+ -+private: -+ -+ /// Internal state -+ VectorAccessIterator vector_access_iterator_; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ VectorIterator( -+ Element const *ptr, -+ TensorCoord extent, -+ int thread_idx, -+ int warp_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ vector_access_iterator_(ptr, extent, thread_idx, warp_idx, threadblock_offset) { } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ VectorIterator &operator++() { -+ vector_access_iterator_.advance(); -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. 
-+ CUTLASS_HOST_DEVICE -+ VectorIterator operator++(int) { -+ VectorIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ frag.clear(); -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < kIterations; ++c) { -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ frag_ptr[c], -+ vector_access_iterator_.get() + pointer_offset, -+ vector_access_iterator_.valid() -+ ); -+ -+ ++vector_access_iterator_; -+ } -+// } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ vector_access_iterator_.set_iteration_index(0); -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ CUTLASS_DEVICE -+ void advance() { -+ vector_access_iterator_.advance(); -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/transform/warp/vector_fragment_iterator.h b/3rdparty/cutlass/include/cutlass/transform/warp/vector_fragment_iterator.h -new file mode 100644 -index 0000000..5b5baba ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/warp/vector_fragment_iterator.h -@@ -0,0 +1,283 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+ -+/*! 
\file -+ \brief This defines a "fragment" iterator for visiting the fragments of a warp vector -+ that participate in one warp-level mma operation. -+ -+ Typically, this is used to access the scale/bias fragement of a warp-level mma operation. -+ The scale/bias vector is then partitioned into smaller fragments that can be fed into -+ next warp-level mma operation. -+ -+ This iterator is necessary to accomplish warp-level mma fusion where the scale/bias vector is -+ applied to the multiplicand for the next mma. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/numeric_conversion.h" -+ -+namespace cutlass { -+namespace transform { -+namespace warp { -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the input fragment tile shape (concept: MatrixShape) -+ typename Shape_, -+ /// Element type -+ typename Element_, -+ /// Layout of operand in memory -+ typename Layout_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ //// Number of elements per access when loading fragment -+ int ElementsPerAccess> -+class VectorFragmentIterator; -+ -+ -+// Partial specialization for PitchLinear layout tile -+ -+template < -+ /// Size of the input fragment vector shape (concept: MatrixShape) -+ typename Shape_, -+ /// Element type -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ //// Number of elements per access when loading fragment -+ int ElementsPerAccess> -+class VectorFragmentIterator { -+ public: -+ -+ /// Size of the input threadblock tile shape (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::PitchLinear; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kRowsPerIteration = 8; -+ static int const kColumnsPerAccess = 8; -+ static int const kElementsPerIteration = kRowsPerIteration * InstructionShape::kK / kThreads; -+ static int const kAccessPerIteration = kElementsPerIteration / kElementsPerAccess; -+ -+ /// Number of iterations -+ using Iterations = MatrixShape; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ // All fragments have kElementsPerAccess scale followed by bias -+ -+ /// Fragment object holding a thread's part of a tile -+ /// This is the fragment size produced by one iteration of the iterator. 
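The derived quantities above combine the warp size with the instruction shape: each iteration covers kRowsPerIteration rows of the instruction's K extent, split across 32 threads and then into per-thread accesses. A worked example with assumed values (InstructionShape::kK and ElementsPerAccess chosen purely for illustration):

    constexpr int kThreads = 32;
    constexpr int kRowsPerIteration = 8;
    constexpr int kInstructionK = 16;      // assumed InstructionShape::kK
    constexpr int kElementsPerAccess = 2;  // assumed ElementsPerAccess

    constexpr int kElementsPerIteration = kRowsPerIteration * kInstructionK / kThreads;  // 4
    constexpr int kAccessPerIteration   = kElementsPerIteration / kElementsPerAccess;    // 2

    static_assert(kElementsPerIteration == 4 && kAccessPerIteration == 2,
                  "each iteration hands the warp-level mma 4 elements as 2 accesses per thread");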
-+ using Fragment = Array; -+ -+ /// Input threadblock fragment tile -+ using ThreadblockFragment = Array; -+ -+private: -+ -+ /// Internal access type -+ using AccessType = Array; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Input threadblock fragment tile -+ AccessType const *iterator_; -+ -+ /// Internal index -+ int index_; -+ -+public: -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ VectorFragmentIterator(ThreadblockFragment const &threadblock_frag) -+ : iterator_(reinterpret_cast(&threadblock_frag)), -+ index_(0) {} -+ -+ /// Add offset -+ CUTLASS_HOST_DEVICE -+ void add_offset(int index_offset) { -+ index_ += index_offset; -+ -+ if(index_ >= Iterations::kColumn) -+ index_ = 0; -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ VectorFragmentIterator &operator++() { -+ add_offset(1); -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_index(int idx) { -+ index_ = idx; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int r = 0; r < Iterations::kRow; r++) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kAccessPerIteration; i++) { -+ -+ frag_ptr[i * Iterations::kRow + r].clear(); -+ frag_ptr[i * Iterations::kRow + r] = iterator_[index_ * kAccessPerIteration + i]; -+ } -+ } -+ } -+ -+}; -+ -+// Partial specialization for Row-Major layout tile -+ -+template < -+ /// Size of the input fragment tile shape (concept: MatrixShape) -+ typename Shape_, -+ /// Element type -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ //// Number of elements per access when loading fragment -+ int ElementsPerAccess> -+class VectorFragmentIterator { -+ public: -+ -+ /// Size of the input threadblock tile shape (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Underlying iterator -+ using Base = VectorFragmentIterator< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, InstructionShape, ElementsPerAccess>; -+ -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ /// Fragment object holding a thread's part of a tile -+ /// This is the fragment size produced by one iteration of the iterator. 
-+ using Fragment = typename Base::Fragment; -+ -+ /// Input threadblock fragment tile -+ using ThreadblockFragment = typename Base::ThreadblockFragment; -+ -+ private: -+ /// Underlying iterator -+ Base iterator_; -+ -+public: -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ VectorFragmentIterator(ThreadblockFragment const &threadblock_frag) -+ : iterator_(threadblock_frag) {} -+ -+ /// Add offset -+ CUTLASS_HOST_DEVICE -+ void add_offset(int index_offset) { -+ iterator_.add_offset(index_offset); -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ VectorFragmentIterator &operator++() { -+ add_offset(1); -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_index(int idx) { -+ iterator_.set_index(idx); -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ iterator_.load(frag); -+ } -+ -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace conv -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/uint128.h b/3rdparty/cutlass/include/cutlass/uint128.h -new file mode 100644 -index 0000000..38d5b4d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/uint128.h -@@ -0,0 +1,272 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Defines an unsigned 128b integer with several operators to support 64-bit integer division. 
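When neither a native __int128 nor the MSVC intrinsics are available, the 128-bit type below has to propagate carries between its two 64-bit halves by hand. A minimal sketch of the standard carry detection for addition (a sketch under that assumption, not the exact expression used in the code below):

    #include <cstdint>
    #include <utility>

    // Portable 128-bit addition over two 64-bit halves. After adding the low halves,
    // a carry occurred exactly when the low result wrapped around, i.e. when it is
    // smaller than one of the original low operands.
    std::pair<std::uint64_t, std::uint64_t>  // {lo, hi}
    add_u128(std::uint64_t lo_a, std::uint64_t hi_a,
             std::uint64_t lo_b, std::uint64_t hi_b) {
      std::uint64_t lo = lo_a + lo_b;
      std::uint64_t carry = (lo < lo_a) ? 1u : 0u;  // unsigned wrap-around detects the carry
      std::uint64_t hi = hi_a + hi_b + carry;
      return {lo, hi};
    }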
-+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#include -+#include -+#include -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Optionally enable GCC's built-in type -+#if defined(__x86_64) && !defined(__CUDA_ARCH__) && defined(__GNUC__) -+#define CUTLASS_UINT128_NATIVE -+#elif defined(_MSC_VER) && defined(_M_AMD64) && !defined(__CUDA_ARCH__) -+#define CUTLASS_INT128_ARITHMETIC -+#include -+#if _MSC_VER >= 1920 -+#define CUTLASS_INT128_ARITHMETIC_DIV -+#include -+#endif -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///! Unsigned 128b integer type -+struct uint128_t { -+ -+ /// Size of one part of the uint's storage in bits -+ static constexpr int kPartSize = sizeof_bits::value; -+ -+ struct hilo { -+ uint64_t lo; -+ uint64_t hi; -+ -+ hilo() = default; -+ -+ CUTLASS_HOST_DEVICE hilo(uint64_t lo_, uint64_t hi_):lo(lo_), hi(hi_) {} -+ }; -+ -+ // Use a union to store either low and high parts or, if present, a built-in 128b integer type. -+ union { -+ struct hilo hilo_; -+ -+ #if defined(CUTLASS_UINT128_NATIVE) -+ unsigned __int128 native; -+ #endif // defined(CUTLASS_UINT128_NATIVE) -+ }; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ uint128_t() = default; -+ -+ /// Constructor from uint64 -+ CUTLASS_HOST_DEVICE -+ uint128_t(uint64_t lo_): hilo_(lo_, 0) { } -+ -+ /// Constructor from two 64b unsigned integers -+ CUTLASS_HOST_DEVICE -+ uint128_t(uint64_t lo_, uint64_t hi_): hilo_(lo_, hi_) { -+ -+ } -+ -+ /// Optional constructor from native value -+ #if defined(CUTLASS_UINT128_NATIVE) -+ uint128_t(unsigned __int128 value): native(value) { } -+ #endif -+ -+ /// Lossily cast to uint64 -+ CUTLASS_HOST_DEVICE -+ explicit operator uint64_t() const { -+ return hilo_.lo; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static void exception() { -+#if defined(__CUDA_ARCH__) -+ asm volatile (" brkpt;\n"); -+#else -+ // throw std::runtime_error("Not yet implemented."); -+ abort(); -+#endif -+ } -+ -+ /// Add -+ CUTLASS_HOST_DEVICE -+ uint128_t operator+(uint128_t const &rhs) const { -+ uint128_t y; -+#if defined(CUTLASS_UINT128_NATIVE) -+ y.native = native + rhs.native; -+#else -+ y.hilo_.lo = hilo_.lo + rhs.hilo_.lo; -+ y.hilo_.hi = hilo_.hi + rhs.hilo_.hi + (!y.hilo_.lo && (rhs.hilo_.lo)); -+#endif -+ return y; -+ } -+ -+ /// Subtract -+ CUTLASS_HOST_DEVICE -+ uint128_t operator-(uint128_t const &rhs) const { -+ uint128_t y; -+#if defined(CUTLASS_UINT128_NATIVE) -+ y.native = native - rhs.native; -+#else -+ y.hilo_.lo = hilo_.lo - rhs.hilo_.lo; -+ y.hilo_.hi = hilo_.hi - rhs.hilo_.hi - (rhs.hilo_.lo && y.hilo_.lo > hilo_.lo); -+#endif -+ return y; -+ } -+ -+ /// Multiply by unsigned 64b integer yielding 128b integer -+ CUTLASS_HOST_DEVICE -+ uint128_t operator*(uint64_t const &rhs) const { -+ uint128_t y{}; -+#if defined(CUTLASS_UINT128_NATIVE) -+ y.native = native * rhs; -+#elif defined(CUTLASS_INT128_ARITHMETIC) -+ // Multiply by the low part -+ y.hilo_.lo = _umul128(hilo_.lo, rhs, &y.hilo_.hi); -+ -+ // Add the high part and ignore the overflow -+ uint64_t overflow; -+ y.hilo_.hi += _umul128(hilo_.hi, rhs, &overflow); -+#else -+ // TODO - not implemented -+ CUTLASS_UNUSED(rhs); -+ exception(); -+#endif -+ return y; -+ 
} -+ -+ /// Divide 128b operation by 64b operation yielding a 64b quotient -+ CUTLASS_HOST_DEVICE -+ uint64_t operator/(uint64_t const &divisor) const { -+ uint64_t quotient = 0; -+#if defined(CUTLASS_UINT128_NATIVE) -+ quotient = uint64_t(native / divisor); -+#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) -+ // implemented using MSVC's arithmetic intrinsics -+ uint64_t remainder = 0; -+ quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); -+#else -+ // TODO - not implemented -+ CUTLASS_UNUSED(divisor); -+ exception(); -+#endif -+ return quotient; -+ } -+ -+ /// Divide 128b operation by 64b operation yielding a 64b quotient -+ CUTLASS_HOST_DEVICE -+ uint64_t operator%(uint64_t const &divisor) const { -+ uint64_t remainder = 0; -+#if defined(CUTLASS_UINT128_NATIVE) -+ remainder = uint64_t(native % divisor); -+#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) -+ // implemented using MSVC's arithmetic intrinsics -+ (void)_udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); -+#else -+ // TODO - not implemented -+ CUTLASS_UNUSED(divisor); -+ exception(); -+#endif -+ return remainder; -+ } -+ -+ /// Computes the quotient and remainder in a single method. -+ CUTLASS_HOST_DEVICE -+ uint64_t divmod(uint64_t &remainder, uint64_t divisor) const { -+ uint64_t quotient = 0; -+#if defined(CUTLASS_UINT128_NATIVE) -+ quotient = uint64_t(native / divisor); -+ remainder = uint64_t(native % divisor); -+#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) -+ // implemented using MSVC's arithmetic intrinsics -+ quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); -+#else -+ // TODO - not implemented -+ CUTLASS_UNUSED(remainder); -+ CUTLASS_UNUSED(divisor); -+ exception(); -+#endif -+ return quotient; -+ } -+ -+ /// Left-shifts a 128b unsigned integer -+ CUTLASS_HOST_DEVICE -+ uint128_t operator<<(int sh) const { -+ if (sh == 0) { -+ return *this; -+ } -+ else if (sh >= kPartSize) { -+ return uint128_t(0, hilo_.lo << (sh - kPartSize)); -+ } -+ else { -+ return uint128_t( -+ (hilo_.lo << sh), -+ (hilo_.hi << sh) | uint64_t(hilo_.lo >> (kPartSize - sh)) -+ ); -+ } -+ } -+ -+ /// Right-shifts a 128b unsigned integer -+ CUTLASS_HOST_DEVICE -+ uint128_t operator>>(int sh) const { -+ if (sh == 0) { -+ return *this; -+ } -+ else if (sh >= kPartSize) { -+ return uint128_t((hilo_.hi >> (sh - kPartSize)), 0); -+ } -+ else { -+ return uint128_t( -+ (hilo_.lo >> sh) | (hilo_.hi << (kPartSize - sh)), -+ (hilo_.hi >> sh) -+ ); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/wmma_array.h b/3rdparty/cutlass/include/cutlass/wmma_array.h -new file mode 100644 -index 0000000..4a074b6 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/wmma_array.h -@@ -0,0 +1,93 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types -+ and is safe to use in a union. -+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/wmma.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Wmma array type (WmmaFragmentArray holds elements of of type nvcuda::wmma::fragment) -+template < -+ /// Element type -+ typename T, -+ /// Number of elements in the array -+ int N -+> -+class WmmaFragmentArray: public Array { -+public: -+ -+ /// Efficient clear method (override Array::clear()) -+ CUTLASS_HOST_DEVICE -+ void clear() -+ { -+ for(int i = 0; i < Array::kElements; i++) -+ { -+ nvcuda::wmma::fill_fragment((*this)[i], (typename T::element_type)0); -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ WmmaFragmentArray& operator+=(const WmmaFragmentArray& rhs) -+ { -+ using element_type = typename T::element_type; -+ plus add; -+ -+ for (int i = 0; i < Array::kElements; i++) -+ { -+ (*this)[i] = add((*this)[i], rhs[i]); -+ } -+ -+ return *this; -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -diff --git a/3rdparty/cutlass/test/unit/common/cutlass_unit_test.h b/3rdparty/cutlass/test/unit/common/cutlass_unit_test.h -new file mode 100644 -index 0000000..8843e40 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/common/cutlass_unit_test.h -@@ -0,0 +1,102 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+#pragma warning (disable : 4068 ) /* disable unknown pragma warnings for vistual studio */ -+ -+#pragma nv_diag_suppress boolean_controlling_expr_is_constant -+#include -+#pragma nv_diag_warning boolean_controlling_expr_is_constant -+#pragma warning( disable : 4503) -+ -+#include -+#include -+ -+#include -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Gets a CUDA device -+cudaDeviceProp GetCudaDevice(); -+ -+/// Prints device properties -+std::ostream &operator<<(std::ostream &out, cudaDeviceProp const &device); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Sets flags for Unit test -+void FilterArchitecture(); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Reads environment variable `CUTLASS_UNIT_TEST_PROBLEM_COUNT` to control the number and order -+// of problem sizes run by CUTLASS unit tests -+int CutlassUnitTestProblemCount(); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// active test macro -+#define CUTLASS_TEST_LEVEL_ACTIVE(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \ -+ TEST(NAME_STATIC,L##LEVEL##_##NAME_DYNAMIC) __VA_ARGS__ -+ -+// disabled test macro -+#define CUTLASS_TEST_LEVEL_DISABLED(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \ -+ TEST(NAME_STATIC,DISABLED_L##LEVEL##_##NAME_DYNAMIC) {} -+ -+#if CUTLASS_TEST_LEVEL == 0 -+#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -+#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -+#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -+#elif CUTLASS_TEST_LEVEL == 1 -+#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) 
CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -+#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -+#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -+#else -+#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -+#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -+#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -+#endif -+ -+#if !defined(CUTLASS_TEST_UNIT_ENABLE_WARNINGS) -+#define CUTLASS_TEST_UNIT_ENABLE_WARNINGS false -+#endif -+ -+#if (__CUDACC_VER_MAJOR__ >= 12) -+ #define CUDA_12_0_SM90_FEATURES_SUPPORTED true -+#else -+ #define CUDA_12_0_SM90_FEATURES_SUPPORTED false -+#endif -+ -+#include -+#include -+#include -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/conv/device/cache_testbed_output.h b/3rdparty/cutlass/test/unit/conv/device/cache_testbed_output.h -new file mode 100644 -index 0000000..29be434 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/cache_testbed_output.h -@@ -0,0 +1,797 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Helper to construct cached name for -+*/ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/core_io.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#ifndef CUTLASS_TEST_ENABLE_CACHED_RESULTS -+#define CUTLASS_TEST_ENABLE_CACHED_RESULTS false -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace conv { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result of a test -+struct CachedTestKey { -+ -+ std::string op; ///< Concatenated string representation of operation performed -+ std::string problem; ///< Concatenated string representation of problem description -+ std::string types; ///< Concatenated string representation of operand types -+ uint32_t A; ///< Hashed result of tensor A -+ uint32_t B; ///< Hashed result of tensor B -+ uint32_t C; ///< Hashed result of tensor C -+ -+ // -+ // Methods -+ // -+ inline CachedTestKey(): A(), B(), C() { } -+ -+ inline CachedTestKey( -+ std::string op, ///< Concatenated string representation of operation performed -+ std::string problem, ///< Concatenated string representation of problem description -+ std::string types, ///< Concatenated string representation of operand types -+ uint32_t A, ///< Hashed result of tensor A -+ uint32_t B, ///< Hashed result of tensor B -+ uint32_t C ///< Hashed result of tensor C -+ ): -+ op(op), problem(problem), types(types), A(A), B(B), C(C) -+ { } -+ -+ /// Checks for equality of the problem -+ bool operator==(CachedTestKey const &rhs) const { -+ return op == rhs.op && problem == rhs.problem && types == rhs.types && A == rhs.A && B == rhs.B && C == rhs.C; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+inline std::istream &operator>>(std::istream &in, CachedTestKey &result) { -+ -+ in >> result.op; -+ in >> result.problem; -+ in >> result.types; -+ in >> result.A; -+ in >> result.B; -+ in >> result.C; -+ -+ return in; -+} -+ -+inline std::ostream &operator<<(std::ostream &out, CachedTestKey const &result) { -+ -+ out << result.op << " "; -+ out << result.problem << " "; -+ out << result.types << " "; -+ out << result.A << " "; -+ out << result.B << " "; -+ out << result.C << " "; -+ -+ return out; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct CachedTestResult { -+ uint32_t D; -+ -+ // -+ // Methods -+ // -+ -+ CachedTestResult(): D() { } -+ -+ CachedTestResult(uint32_t D): D(D) { } -+ -+ operator bool() const { -+ return bool(D); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+inline std::istream &operator>>(std::istream &in, CachedTestResult &result) { -+ in >> result.D; -+ return in; -+} -+ -+inline std::ostream &operator<<(std::ostream &out, CachedTestResult const &result) { -+ out << result.D; -+ return out; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct CachedTestResultListing { -+ -+ std::list> results; -+ -+ // -+ // Methods -+ // -+ -+ inline CachedTestResultListing(std::string const &path) 
{ -+ std::ifstream file(path); -+ -+ while (file.good()) { -+ CachedTestKey key; -+ file >> key; -+ -+ CachedTestResult result; -+ file >> result; -+ -+ if (result) { -+ results.push_back(std::make_pair(key, result)); -+ } -+ } -+ } -+ -+ /// Returns the cached result -+ std::pair find(CachedTestKey const &rhs) const { -+ for (auto const & result : results) { -+ if (result.first == rhs) { -+ return std::make_pair(true, result.second); -+ } -+ } -+ return std::make_pair(false, CachedTestResult()); -+ } -+ -+ /// Appends an entry -+ void append(CachedTestKey const &key, CachedTestResult const &result) { -+ if (result) { -+ results.push_back(std::make_pair(key, result)); -+ } -+ } -+ -+ /// Writes the entire listing to a file -+ bool write(std::string const &path) { -+ std::ofstream file(path); -+ if (!file.good()) { -+ return false; -+ } -+ -+ for (auto const &result : results) { -+ file << result.first << result.second << std::endl; -+ } -+ -+ return true; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct ScalarEncoder { -+ Element scalar; -+ -+ ScalarEncoder(Element s): scalar(s) { } -+ -+ std::string str() const { -+ std::stringstream ss; -+ Element s = scalar; -+ if (s < Element()) { -+ s = -s; -+ ss << "n"; -+ } -+ ss << s; -+ return ss.str(); -+ } -+}; -+ -+template -+ScalarEncoder EncodeScalar(Element a) { -+ return ScalarEncoder(a); -+} -+ -+template -+struct ScalarEncoder> { -+ cutlass::complex scalar; -+ -+ ScalarEncoder(cutlass::complex s): scalar(s) { } -+ -+ std::string str() const { -+ std::stringstream ss; -+ ss << EncodeScalar(scalar.real()) << "_" << EncodeScalar(scalar.imag()) << "i"; -+ return ss.str(); -+ } -+}; -+ -+template -+std::ostream &operator<<(std::ostream &out, ScalarEncoder const &scalar) { -+ out << scalar.str(); -+ return out; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+inline char const *EncodeOperator(cutlass::conv::Operator conv_op) { -+ switch (conv_op) { -+ case cutlass::conv::Operator::kFprop: return "fprop"; -+ case cutlass::conv::Operator::kDgrad: return "dgrad"; -+ case cutlass::conv::Operator::kWgrad: return "wgrad"; -+ } -+ return "conv_unknown"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Encode GemmCoord (Gemm problem size) -+inline std::ostream &EncodeProblemSize( -+ std::ostream &out, -+ cutlass::gemm::GemmCoord const &problem) { -+ -+ out << problem.m() << "x" << problem.n() << "x" << problem.k() << "_"; -+ -+ return out; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Encode Conv2dProblemSize -+inline std::ostream &EncodeProblemSize( -+ std::ostream &out, -+ cutlass::conv::Conv2dProblemSize const &problem) { -+ -+ out << problem.N << "x" << problem.H << "x" << problem.W << "x" << problem.C << "_" -+ << problem.P << "x" << problem.Q << "_" << problem.K << "x" << problem.R << "x" << problem.S << "_"; -+ -+ out << "pad_h" << problem.pad_h << "w" << problem.pad_w << "_"; -+ out << "stride_h" << problem.stride_h << "w" << problem.stride_w << "_"; -+ out << "dil_h" << problem.dilation_h << "w" << problem.dilation_w << "_"; -+ -+ switch (problem.mode) { -+ case cutlass::conv::Mode::kCrossCorrelation: -+ out << "corr"; -+ break; -+ case cutlass::conv::Mode::kConvolution: -+ out << "conv"; -+ break; -+ } -+ -+ return out; -+} -+ 
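// A minimal worked example of the Conv2d encoding above (hypothetical sizes,
// not drawn from any test in this file): a problem with input N=1, H=7, W=7,
// C=64, output P=7, Q=7, filter K=64, R=3, S=3, padding 1, stride 1 and
// dilation 1 in cross-correlation mode is rendered as
//   "1x7x7x64_7x7_64x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr"
// and the CreateCachedConv2d*TestKey helpers further below append the
// "_alpha..._beta..." scalars to this string to form CachedTestKey::problem.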
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Encode Conv3dProblemSize -+inline std::ostream &EncodeProblemSize( -+ std::ostream &out, -+ cutlass::conv::Conv3dProblemSize const &problem) { -+ -+ out << problem.N << "x" << problem.D << "x" << problem.H << "x" << problem.W << "x" << problem.C << "_" -+ << problem.Z << problem.P << "x" << problem.Q << "_" << problem.K << "x" << problem.R << "x" << problem.S << "_"; -+ -+ out << "pad_d" << problem.pad_h << "h" << problem.pad_h << "w" << problem.pad_w << "_"; -+ out << "stride_d" << problem.stride_d << "h" << problem.stride_h << "w" << problem.stride_w << "_"; -+ out << "dil_d" << problem.dilation_d << "h" << problem.dilation_h << "w" << problem.dilation_w << "_"; -+ -+ switch (problem.mode) { -+ case cutlass::conv::Mode::kCrossCorrelation: -+ out << "corr"; -+ break; -+ case cutlass::conv::Mode::kConvolution: -+ out << "conv"; -+ break; -+ } -+ -+ return out; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+inline std::string ElementTypeName() { -+ return std::string(typeid(Element).name()); -+} -+ -+template <> -+inline std::string ElementTypeName() { -+ return "h"; -+} -+ -+template <> -+inline std::string ElementTypeName>() { -+ return "ch"; -+} -+ -+template <> -+inline std::string ElementTypeName() { -+ return "bf16"; -+} -+ -+template <> -+inline std::string ElementTypeName>() { -+ return "cbf16"; -+} -+ -+template <> -+inline std::string ElementTypeName() { -+ return "tf32"; -+} -+ -+template <> -+inline std::string ElementTypeName>() { -+ return "ctf32"; -+} -+ -+template <> -+inline std::string ElementTypeName>() { -+ return "c"; -+} -+ -+template <> -+inline std::string ElementTypeName>() { -+ return "z"; -+} -+ -+template <> -+inline std::string ElementTypeName>() { -+ return "q"; -+} -+ -+template <> -+inline std::string ElementTypeName() { -+ return "s8"; -+} -+ -+template <> -+inline std::string ElementTypeName() { -+ return "u8"; -+} -+ -+template <> -+inline std::string ElementTypeName() { -+ return "s4"; -+} -+ -+template <> -+inline std::string ElementTypeName() { -+ return "u4"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+inline std::string LayoutTypeName() { -+ return std::string(typeid(Layout).name()); -+} -+ -+template <> -+inline std::string LayoutTypeName() { -+ return "n"; -+} -+ -+template <> -+inline std::string LayoutTypeName() { -+ return "t"; -+} -+ -+template <> -+inline std::string LayoutTypeName() { -+ return "nhwc"; -+} -+ -+template <> -+inline std::string LayoutTypeName>() { -+ return "nc32hw32"; -+} -+ -+template <> -+inline std::string LayoutTypeName>() { -+ return "nc64hw64"; -+} -+ -+template <> -+inline std::string LayoutTypeName>() { -+ return "c32rsk32"; -+} -+ -+template <> -+inline std::string LayoutTypeName>() { -+ return "c64rsk64"; -+} -+ -+template <> -+inline std::string LayoutTypeName() { -+ return "ndhwc"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+inline std::string TensorTypeName() { -+ std::stringstream ss; -+ ss << ElementTypeName() << LayoutTypeName(); -+ return ss.str(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Hash function on a byte array -+struct CRC32 { -+ -+ uint32_t table[256]; -+ -+ // -+ // Methods -+ // -+ -+ 
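// The constructor below builds the 256-entry lookup table for the reflected
// CRC-32 polynomial 0xEDB88320 (the parameterization also used by zlib and
// PNG when the optional crc argument is left at its default of zero), and
// operator() consumes the byte range one octet at a time, inverting the
// running value on entry and exit so that independent buffers can be chained
// by passing the previous digest back in as the crc argument, e.g.
//
//   CRC32 hash;
//   uint32_t d = hash(bufA, lenA);    // digest of bufA
//   d = hash(bufB, lenB, d);          // digest of bufA followed by bufB
//
// (bufA, bufB, lenA and lenB are placeholder names for illustration only.)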
CRC32() { -+ -+ uint32_t rem; -+ int i, j; -+ -+ for (i = 0; i < 256; i++) { -+ rem = i; -+ for (j = 0; j < 8; j++) { -+ if (rem & 1) { -+ rem >>= 1; -+ rem ^= 0xedb88320; -+ } else -+ rem >>= 1; -+ } -+ table[i] = rem; -+ } -+ } -+ -+ /// Computes the CRC of an array of bytes -+ uint32_t operator()(void const *start, size_t length, uint32_t crc = uint32_t()) const { -+ uint8_t const *p = static_cast(start); -+ uint8_t const *q = static_cast(start) + length; -+ -+ crc = ~crc; -+ -+ for (; p != q; ++p) { -+ uint8_t octet = *p; -+ crc = (crc >> 8) ^ table[(crc & 0xff) ^ octet]; -+ } -+ -+ return ~crc; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Element, typename Layout -+> -+uint32_t TensorHash( -+ cutlass::TensorView view, -+ CRC32 const &hash = CRC32(), -+ uint32_t crc = uint32_t() -+) { -+ -+ return hash(view.data(), view.capacity() * cutlass::sizeof_bits::value / 8, crc); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, typename LayoutA, -+ typename ElementB, typename LayoutB, -+ typename ElementC, typename LayoutC, -+ typename ElementAccumulator, -+ typename ElementCompute -+> -+inline std::ostream &EncodeTypes( -+ std::ostream &out -+) { -+ -+ out << TensorTypeName() << "_" -+ << TensorTypeName() << "_" -+ << TensorTypeName() << "_" -+ << ElementTypeName() << "_" -+ << ElementTypeName(); -+ -+ return out; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, typename LayoutA, -+ typename ElementB, typename LayoutB, -+ typename ElementC, typename LayoutC, -+ typename ElementAccumulator, -+ typename ElementCompute -+> -+inline CachedTestKey CreateCachedGemmTestKey( -+ cutlass::gemm::GemmCoord const &problem, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cutlass::TensorView A, -+ cutlass::TensorView B, -+ cutlass::TensorView C -+) { -+ -+ CachedTestKey key; -+ -+ // Encode gemm operator and problem sizes -+ key.op = "gemm"; -+ -+ std::stringstream ss_problem; -+ EncodeProblemSize(ss_problem, problem); -+ ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); -+ key.problem = ss_problem.str(); -+ -+ // Encode problem data types -+ std::stringstream ss_types; -+ EncodeTypes< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ ElementCompute>(ss_types); -+ key.types = ss_types.str(); -+ -+ // Encode hash for problem data -+ CRC32 crc_hash; -+ key.A = TensorHash(A, crc_hash); -+ key.B = TensorHash(B, crc_hash); -+ key.C = TensorHash(C, crc_hash); -+ -+ return key; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+template < -+ typename ElementA, typename LayoutA, -+ typename ElementB, typename LayoutB, -+ typename ElementC, typename LayoutC, -+ typename ElementAccumulator, -+ typename ElementCompute -+> -+inline CachedTestKey CreateCachedConv2dTestKey( -+ -+ cutlass::conv::Operator conv_operator, -+ cutlass::conv::Conv2dProblemSize const &problem, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cutlass::TensorView A, -+ cutlass::TensorView B, -+ cutlass::TensorView C -+) { -+ -+ CachedTestKey key; -+ -+ // Encode conv2d operator and problem sizes -+ key.op = "conv2d"; -+ -+ std::stringstream ss_problem; -+ ss_problem << EncodeOperator(conv_operator) << "_"; -+ 
EncodeProblemSize(ss_problem, problem); -+ ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); -+ -+ key.problem = ss_problem.str(); -+ -+ // Encode problem data types -+ std::stringstream ss_types; -+ EncodeTypes< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ ElementCompute>(ss_types); -+ key.types = ss_types.str(); -+ -+ // Encode hash for problem data -+ CRC32 crc_hash; -+ -+ key.A = TensorHash(A, crc_hash); -+ key.B = TensorHash(B, crc_hash); -+ key.C = TensorHash(C, crc_hash); -+ -+ return key; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, typename LayoutA, -+ typename ElementB, typename LayoutB, -+ typename ElementC, typename LayoutC, -+ typename ElementAccumulator, -+ typename ElementCompute -+> -+inline CachedTestKey CreateCachedConv2dWithBroadcastTestKey( -+ -+ cutlass::conv::Operator conv_operator, -+ cutlass::conv::Conv2dProblemSize const &problem, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cutlass::TensorView A, -+ cutlass::TensorView B, -+ cutlass::TensorView C -+) { -+ -+ CachedTestKey key; -+ -+ // Encode conv2d operator and problem sizes -+ key.op = "conv2d_with_broadcast"; -+ -+ std::stringstream ss_problem; -+ ss_problem << EncodeOperator(conv_operator) << "_"; -+ EncodeProblemSize(ss_problem, problem); -+ ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); -+ -+ key.problem = ss_problem.str(); -+ -+ // Encode problem data types -+ std::stringstream ss_types; -+ EncodeTypes< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ ElementCompute>(ss_types); -+ key.types = ss_types.str(); -+ -+ // Encode hash for problem data -+ CRC32 crc_hash; -+ -+ key.A = TensorHash(A, crc_hash); -+ key.B = TensorHash(B, crc_hash); -+ key.C = TensorHash(C, crc_hash); -+ -+ return key; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, typename LayoutA, -+ typename ElementB, typename LayoutB, -+ typename ElementC, typename LayoutC, -+ typename ElementAccumulator, -+ typename ElementCompute -+> -+inline CachedTestKey CreateCachedConv2dWithReductionTestKey( -+ -+ cutlass::conv::Operator conv_operator, -+ cutlass::conv::Conv2dProblemSize const &problem, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cutlass::TensorView A, -+ cutlass::TensorView B, -+ cutlass::TensorView C -+) { -+ -+ CachedTestKey key; -+ -+ // Encode conv2d operator and problem sizes -+ key.op = "conv2d_with_reduction"; -+ -+ std::stringstream ss_problem; -+ ss_problem << EncodeOperator(conv_operator) << "_"; -+ EncodeProblemSize(ss_problem, problem); -+ ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); -+ -+ key.problem = ss_problem.str(); -+ -+ // Encode problem data types -+ std::stringstream ss_types; -+ EncodeTypes< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ ElementCompute>(ss_types); -+ key.types = ss_types.str(); -+ -+ // Encode hash for problem data -+ CRC32 crc_hash; -+ -+ key.A = TensorHash(A, crc_hash); -+ key.B = TensorHash(B, crc_hash); -+ key.C = TensorHash(C, crc_hash); -+ -+ return key; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, typename LayoutA, -+ typename ElementB, typename LayoutB, -+ 
typename ElementC, typename LayoutC, -+ typename ElementAccumulator, -+ typename ElementCompute -+> -+inline CachedTestKey CreateCachedConv3dTestKey( -+ cutlass::conv::Operator conv_operator, -+ cutlass::conv::Conv3dProblemSize const &problem, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cutlass::TensorView A, -+ cutlass::TensorView B, -+ cutlass::TensorView C -+) { -+ -+ CachedTestKey key; -+ -+ // Encode conv3d operator and problem sizes -+ key.op = "conv3d"; -+ -+ std::stringstream ss_problem; -+ -+ ss_problem << EncodeOperator(conv_operator) << "_"; -+ EncodeProblemSize(ss_problem, problem); -+ ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); -+ -+ key.problem = ss_problem.str(); -+ -+ // Encode problem data types -+ std::stringstream ss_types; -+ EncodeTypes< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ ElementCompute>(ss_types); -+ key.types = ss_types.str(); -+ -+ // Encode problem data -+ CRC32 crc_hash; -+ key.A = TensorHash(A, crc_hash); -+ key.B = TensorHash(B, crc_hash); -+ key.C = TensorHash(C, crc_hash); -+ -+ return key; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // nammespace conv -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu -new file mode 100644 -index 0000000..cbabe42 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu -@@ -0,0 +1,136 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM50_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, -+ 64x64_8x2_32x64x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::complex; -+ using ElementB = cutlass::complex; -+ using ElementC = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM50_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, -+ 32x64_8x2_32x64x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::complex; -+ using ElementB = cutlass::complex; -+ using ElementC = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 
2, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu -new file mode 100644 -index 0000000..08e3abd ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu -@@ -0,0 +1,141 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, -+ 128x128_8x4_64x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::complex; -+ using ElementB = cutlass::complex; -+ using ElementC = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, -+ 128x128_8x4_64x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::complex; -+ using ElementB = cutlass::complex; -+ using ElementC = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git 
a/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu -new file mode 100644 -index 0000000..eaade32 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu -@@ -0,0 +1,298 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4, -+ 
128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 4, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kUnity, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 4, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance 
-+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2, -+ 128x64_64x3_64x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 2, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity, -+ 2, -+ 2 -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {35, 100, 50, 64}, // input size (NHWC) -+ {22, 1, 1, 64}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..55d9525 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu -@@ -0,0 +1,126 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ 
cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM70_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..62836ab ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu -@@ -0,0 +1,238 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride_align2, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = 
cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 4, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kUnity, -+ 2, -+ 2 -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride_align2, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 4, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity, -+ 2, -+ 2 -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu 
b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..0891f80 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,209 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride, -+ 128x128_32x4_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element 
types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride, -+ 128x128_64x4_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu -new file mode 100644 -index 0000000..845e86b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu -@@ -0,0 +1,141 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, -+ 128x128_8x4_32x64x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = float; -+ using ElementB = float; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, -+ 128x128_8x4_64x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = float; -+ using ElementB = float; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..3a7b380 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,130 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..42e85be ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,303 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file
-+  \brief Tests for device-wide Implicit GEMM interface
-+*/
-+
-+#include "../../common/cutlass_unit_test.h"
-+#include "cutlass/cutlass.h"
-+
-+
-+#include "cutlass/conv/kernel/default_conv2d_fprop.h"
-+#include "cutlass/conv/device/implicit_gemm_convolution.h"
-+
-+#include "conv2d_testbed.h"
-+
-+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+std::vector<cutlass::conv::Conv2dProblemSize> Conv2dFewChannelProblemSizes(int channels) {
-+
-+  std::vector<cutlass::conv::Conv2dProblemSize> problems;
-+
-+  problems.push_back(cutlass::conv::Conv2dProblemSize(
-+    {1, 8, 8, channels},    // input size (NHWC)
-+    {16, 3, 3, channels},   // filter size (KRSC)
-+    {1, 1, 1, 1},           // padding (pad_h, _, pad_w, _)
-+    {2, 2},                 // stride (stride_h, stride_w)
-+    {1, 1}                  // dilation (dilation_h, dilation_w)
-+  ));
-+
-+  problems.push_back(cutlass::conv::Conv2dProblemSize(
-+    {1, 16, 16, channels},  // input size (NHWC)
-+    {16, 3, 3, channels},   // filter size (KRSC)
-+    {1, 1, 1, 1},           // padding (pad_h, _, pad_w, _)
-+    {2, 2},                 // stride (stride_h, stride_w)
-+    {1, 1}                  // dilation (dilation_h, dilation_w)
-+  ));
-+
-+  problems.push_back(cutlass::conv::Conv2dProblemSize(
-+    {1, 16, 16, channels},  // input size (NHWC)
-+    {16, 7, 7, channels},   // filter size (KRSC)
-+    {1, 1, 1, 1},           // padding (pad_h, _, pad_w, _)
-+    {1, 1},                 // stride (stride_h, stride_w)
-+    {1, 1}                  // dilation (dilation_h, dilation_w)
-+  ));
-+
-+  problems.push_back(cutlass::conv::Conv2dProblemSize(
-+    {1, 224, 224, channels},  // input size (NHWC)
-+    {32, 7, 7, channels},     // filter size (KRSC)
-+    {1, 1, 1, 1},             // padding (pad_h, _, pad_w, _)
-+    {1, 1},                   // stride (stride_h, stride_w)
-+    {1, 1}                    // dilation (dilation_h, dilation_w)
-+  ));
-+
-+  problems.push_back(cutlass::conv::Conv2dProblemSize(
-+    {1, 224, 224, channels},  // input size (NHWC)
-+    {64, 7, 7, channels},     // filter size (KRSC)
-+    {1, 1, 1, 1},             // padding (pad_h, _, pad_w, _)
-+    {2, 2},                   // stride (stride_h, stride_w)
-+    {1, 1}                    // dilation (dilation_h, dilation_w)
-+  ));
-+
-+  problems.push_back(cutlass::conv::Conv2dProblemSize(
-+    {1, 224, 224, channels},  // input size (NHWC)
-+    {64, 5, 5, channels},     // filter size (KRSC)
-+    {1, 1, 1, 1},             // padding (pad_h, _, pad_w, _)
-+    {1, 1},                   // stride (stride_h, stride_w)
-+    {1, 1}                    // dilation (dilation_h, dilation_w)
-+  ));
-+
-+  problems.push_back(cutlass::conv::Conv2dProblemSize(
-+    {1, 224, 224, channels},  // input size (NHWC)
-+    {64, 5, 5, channels},     // filter size (KRSC)
-+    {1, 1, 1, 1},             // padding (pad_h, _, pad_w, _)
-+    {2, 2},                   // stride (stride_h, stride_w)
-+    {1, 1}                    // dilation (dilation_h, dilation_w)
-+  ));
-+
-+  return problems;
-+}
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_8,
-+  128x128_64x3_64x64x64) {
-+
-+  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
-+  using ElementA = cutlass::half_t;
-+  using ElementB = cutlass::half_t;
-+  using ElementC = cutlass::half_t;
-+  using ElementAccumulator = float;
-+  using ElementCompute = float;
-+
-+  int const kChannelCount = 8;
-+
-+  /// Device-level Conv2d instance
-+  using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
-+    ElementA, cutlass::layout::TensorNHWC,
-+    ElementB, cutlass::layout::TensorNHWC,
-+    ElementC, cutlass::layout::TensorNHWC,
-+    ElementAccumulator,
-+    cutlass::arch::OpClassTensorOp,
-+    cutlass::arch::Sm80,
-+    cutlass::gemm::GemmShape<128, 128, 64>,
-+
cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kFewChannels, -+ cutlass::conv::StrideSupport::kStrided, -+ kChannelCount, -+ kChannelCount -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d( -+ Conv2dFewChannelProblemSizes(kChannelCount))); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_4, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ int const kChannelCount = 4; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kFewChannels, -+ cutlass::conv::StrideSupport::kStrided, -+ kChannelCount, -+ kChannelCount -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d( -+ Conv2dFewChannelProblemSizes(kChannelCount))); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_2, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ int const kChannelCount = 2; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 
cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kFewChannels, -+ cutlass::conv::StrideSupport::kStrided, -+ kChannelCount, -+ kChannelCount -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d( -+ Conv2dFewChannelProblemSizes(2 * kChannelCount))); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_1, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ int const kChannelCount = 1; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kFewChannels, -+ cutlass::conv::StrideSupport::kStrided, -+ kChannelCount, -+ kChannelCount -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d( -+ Conv2dFewChannelProblemSizes(3 * kChannelCount))); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..e6f676b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,240 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+std::vector Conv2dFixedChannelProblemSizes(int channels) { -+ -+ std::vector problems; -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, channels}, // input size (NHWC) -+ {16, 3, 3, channels}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 224, 224, channels}, // input size (NHWC) -+ {32, 7, 7, channels}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 224, 224, channels}, // input size (NHWC) -+ {64, 7, 7, channels}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 224, 224, channels}, // input size (NHWC) -+ {64, 5, 5, channels}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 224, 224, channels}, // input size (NHWC) -+ {64, 5, 5, channels}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ return problems; -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_8, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = 
cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ int const kChannelCount = 8; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kFixedChannels, -+ cutlass::conv::StrideSupport::kStrided, -+ kChannelCount, -+ kChannelCount -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d( -+ Conv2dFixedChannelProblemSizes(kChannelCount))); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_4, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ int const kChannelCount = 4; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kFixedChannels, -+ cutlass::conv::StrideSupport::kStrided, -+ kChannelCount, -+ kChannelCount -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d( -+ Conv2dFixedChannelProblemSizes(kChannelCount))); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_2, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ int const kChannelCount = 2; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, 
-+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kFixedChannels, -+ cutlass::conv::StrideSupport::kStrided, -+ kChannelCount, -+ kChannelCount -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d( -+ Conv2dFixedChannelProblemSizes(kChannelCount))); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu -new file mode 100644 -index 0000000..f892d33 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu -@@ -0,0 +1,138 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file
-+  \brief Tests for device-wide Implicit GEMM interface
-+*/
-+
-+#include "../../common/cutlass_unit_test.h"
-+#include "cutlass/cutlass.h"
-+
-+
-+#include "cutlass/conv/kernel/default_conv2d_fprop.h"
-+#include "cutlass/conv/device/implicit_gemm_convolution.h"
-+
-+#include "conv2d_testbed.h"
-+
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM50_Device_Conv2d_Fprop_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32,
-+  32x64_8x2_32x32x8) {
-+
-+  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
-+  using ElementA = cutlass::complex<float>;
-+  using ElementB = cutlass::complex<float>;
-+  using ElementC = cutlass::complex<float>;
-+  using ElementAccumulator = cutlass::complex<float>;
-+  using ElementCompute = cutlass::complex<float>;
-+
-+
-+  /// Device-level Conv2d instance
-+  using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
-+    ElementA,
-+    cutlass::layout::TensorNHWC,
-+    ElementB,
-+    cutlass::layout::TensorNHWC,
-+    ElementC,
-+    cutlass::layout::TensorNHWC,
-+    ElementAccumulator,
-+    cutlass::arch::OpClassSimt,
-+    cutlass::arch::Sm50,
-+    cutlass::gemm::GemmShape<32, 64, 8>,
-+    cutlass::gemm::GemmShape<32, 32, 8>,
-+    cutlass::gemm::GemmShape<1, 1, 1>,
-+    cutlass::epilogue::thread::LinearCombination<
-+      ElementC,
-+      1,
-+      ElementAccumulator,
-+      ElementCompute
-+    >,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>,
-+    2,
-+    cutlass::arch::OpMultiplyAddComplex,
-+    cutlass::conv::IteratorAlgorithm::kAnalytic
-+  >::Kernel;
-+
-+  using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
-+
-+  /// Run all unit test sizes with device-level Conv2d instance
-+  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>());
-+
-+}
-+
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM50_Device_Conv2d_Fprop_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32,
-+  32x128_8x2_16x64x8) {
-+
-+  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
-+  using ElementA = cutlass::complex<float>;
-+  using ElementB = cutlass::complex<float>;
-+  using ElementC = cutlass::complex<float>;
-+  using ElementAccumulator = cutlass::complex<float>;
-+  using ElementCompute = cutlass::complex<float>;
-+
-+
-+  /// Device-level Conv2d instance
-+  using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
-+    ElementA,
-+    cutlass::layout::TensorNHWC,
-+    ElementB,
-+    cutlass::layout::TensorNHWC,
-+    ElementC,
-+    cutlass::layout::TensorNHWC,
-+    ElementAccumulator,
-+    cutlass::arch::OpClassSimt,
-+    cutlass::arch::Sm50,
-+    cutlass::gemm::GemmShape<32, 128, 8>,
-+    cutlass::gemm::GemmShape<16, 64, 8>,
-+    cutlass::gemm::GemmShape<1, 1, 1>,
-+    cutlass::epilogue::thread::LinearCombination<
-+      ElementC,
-+      1,
-+      ElementAccumulator,
-+      ElementCompute
-+    >,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>,
-+    2,
-+    cutlass::arch::OpMultiplyAddComplex,
-+    cutlass::conv::IteratorAlgorithm::kOptimized
-+  >::Kernel;
-+
-+  using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
-+
-+  /// Run all unit test sizes with device-level Conv2d instance
-+  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dFprop>());
-+
-+}
-+
-+////////////////////////////////////////////////////////////////////////////////
-diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu
-new file mode 100644
-index 0000000..e320c77
----
/dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu -@@ -0,0 +1,137 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, -+ 128x128_8x4_64x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::complex; -+ using ElementB = cutlass::complex; -+ using ElementC = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, -+ 128x128_8x5_64x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::complex; -+ using ElementB = cutlass::complex; -+ using ElementC = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 5, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu 
b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu -new file mode 100644 -index 0000000..64d40d8 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu -@@ -0,0 +1,134 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM60_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_simt_f16, -+ 128x128_8x2_64x64x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm60, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM60_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_simt_f16, -+ 128x128_8x2_64x64x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm60, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu -new file mode 100644 -index 0000000..af476db ---- /dev/null -+++ 
b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu -@@ -0,0 +1,350 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) 
-+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 8, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kStrided, -+ 2, -+ 2 -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 14}, // input size (NHWC) -+ {8, 3, 3, 14}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 23, 56, 98}, // input size (NHWC) -+ {128, 3, 3, 98}, // filter size (KRSC) -+ {4, 0, 5, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 8, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ 
cutlass::conv::StrideSupport::kStrided, -+ 2, -+ 2 -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 14}, // input size (NHWC) -+ {8, 3, 3, 14}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 23, 56, 98}, // input size (NHWC) -+ {128, 3, 3, 98}, // filter size (KRSC) -+ {4, 0, 5, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 8, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kStrided, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 28}, // input size (NHWC) -+ {8, 3, 3, 28}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ // run specific problem size in the 
unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 23, 56, 100}, // input size (NHWC) -+ {128, 3, 3, 100}, // filter size (KRSC) -+ {4, 0, 5, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..b681dd2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,132 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git 
a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..a848192 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu -@@ -0,0 +1,127 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+TEST(SM70_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM70_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..a3e96e2 ---- /dev/null -+++ 
b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu -@@ -0,0 +1,293 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align2, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using 
ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 8, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kStrided, -+ 2, -+ 2 -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 8, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kStrided, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv 
operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 8, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kStrided, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..7b68a68 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,130 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#if 0 -+TEST(SM80_Device_Conv2d_Fprop_Precomputed_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm50.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm50.cu -new file mode 100644 -index 0000000..8f8eb88 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm50.cu -@@ -0,0 +1,88 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, -+ 128x128_8x2_64x64x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = float; -+ using ElementB = float; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu -new file mode 100644 -index 0000000..6b4fe2e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu -@@ -0,0 +1,137 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, -+ 128x128_8x4_32x64x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = float; -+ using ElementB = float; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, -+ 128x128_8x4_64x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = float; -+ using ElementB = float; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ 
using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu -new file mode 100755 -index 0000000..2ac1dfd ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu -@@ -0,0 +1,227 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Conv2d_Fprop_Analytic_ImplicitGemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32, -+ 16x32_8x2_16x16x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::Quaternion; -+ using ElementB = cutlass::Quaternion; -+ using ElementC = cutlass::Quaternion; -+ using ElementAccumulator = cutlass::Quaternion; -+ using ElementCompute = cutlass::Quaternion; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<16, 32, 8>, -+ cutlass::gemm::GemmShape<16, 16, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Conv2d_Fprop_Analytic_ImplicitGemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32, -+ 16x64_8x2_8x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::Quaternion; -+ using ElementB = cutlass::Quaternion; -+ using ElementC = cutlass::Quaternion; -+ using ElementAccumulator = cutlass::Quaternion; -+ using ElementCompute = cutlass::Quaternion; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<16, 64, 8>, -+ cutlass::gemm::GemmShape<8, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM50_Device_Conv2d_Fprop_Analytic_ImplicitGemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32, -+ 32x32_8x2_16x16x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::Quaternion; -+ using ElementB = cutlass::Quaternion; -+ 
using ElementC = cutlass::Quaternion; -+ using ElementAccumulator = cutlass::Quaternion; -+ using ElementCompute = cutlass::Quaternion; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 16, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM50_Device_Conv2d_Fprop_Optimized_ImplicitGemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32, -+ 16x32_8x2_16x16x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::Quaternion; -+ using ElementB = cutlass::Quaternion; -+ using ElementC = cutlass::Quaternion; -+ using ElementAccumulator = cutlass::Quaternion; -+ using ElementCompute = cutlass::Quaternion; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<16, 32, 8>, -+ cutlass::gemm::GemmShape<16, 16, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..0f794f2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32_sm75.cu -@@ -0,0 +1,526 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed_interleaved.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 128x128_128x2_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 256x128_128x2_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = 
float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 128x256_128x2_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 256x64_128x2_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 64x256_128x2_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 64x128_128x2_32x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 128x128_128x2_64x64x128) { -+ -+ /// Conv operation element types for the Gemm 
equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 256x128_128x2_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 128x256_128x2_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, 
-+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 256x64_128x2_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 64x256_128x2_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ 
-+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 64x128_128x2_32x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..9af25ab ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32_sm80.cu -@@ -0,0 +1,527 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed_interleaved.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 128x128_128x3_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 256x128_128x3_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 128x256_128x3_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 256x64_128x3_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 64x256_128x3_64x64x128) { -+ -+ /// Conv operation element types for the Gemm 
equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 64x128_128x4_32x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 4, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 128x128_128x3_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ 
cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 256x128_128x3_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 128x256_128x3_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ 
-+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 256x64_128x3_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 64x256_128x3_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 64x128_128x4_32x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ 
cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 4, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..096e44f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm75.cu -@@ -0,0 +1,125 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..d285a2d ---- /dev/null -+++ 
b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm80.cu -@@ -0,0 +1,127 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm75.cu -new file mode 100644 
-index 0000000..9d77fb1 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm75.cu -@@ -0,0 +1,685 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed_interleaved.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x128_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 256x128_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x256_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = 
float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 256x64_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x256_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ 
cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x64_64x2_64x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x128_64x2_32x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x64_64x2_32x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ 
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x128_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 256x128_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ 
cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x256_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 256x64_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x256_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; 
-+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x64_64x2_64x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x128_64x2_32x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ 
cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x64_64x2_32x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..120ce06 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm80.cu -@@ -0,0 +1,686 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed_interleaved.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 256x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ 
cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x256_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 256x64_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x256_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = 
int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x64_64x4_64x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 4, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x128_64x4_32x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 4, -+ 
cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x64_64x6_32x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 6, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 256x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = 
int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x256_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 256x64_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ 
cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x256_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x64_64x4_64x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 4, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x128_64x4_32x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = 
int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 4, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x64_64x6_32x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 6, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..d15f5c9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32_sm75.cu -@@ -0,0 +1,125 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32, -+ 128x128_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32, -+ 128x128_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, 
cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..e192d65 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32_sm80.cu -@@ -0,0 +1,126 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..d15a435 ---- /dev/null -+++ 
b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,142 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 8, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align2, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 8, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kStrided, -+ 2, -+ 2 -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level 
Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_with_broadcast_sm70.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_with_broadcast_sm70.cu -new file mode 100644 -index 0000000..19dc3c9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_with_broadcast_sm70.cu -@@ -0,0 +1,127 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" -+#include "cutlass/epilogue/thread/linear_combination_residual_block.h" -+#include "cutlass/epilogue/thread/activation.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_with_broadcast_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+// Test residual block fusion: UnaryOp(BinaryOp(ActivationOp(Conv2d(X) + bias), residual)) -+// LinearCombinationResidualBlock does not support the split-k mode unless ActivationOp is Identity. -+// This is because the activation needs to be applied to the fully accumulated output of the Conv2d op, -+// which only the last thread block would have an access to, before applying BinaryOp. 
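The fused epilogue described above, Z = UnaryOp(BinaryOp(ActivationOp(Conv2d(X) + bias), residual)), is easy to sanity-check on the host. The sketch below mirrors the per-element composition applied by the testbed's ReferenceOp, specialized to concrete ReLU / plus / Identity functors; the functor names and file name are hypothetical stand-ins for illustration, not the CUTLASS device epilogue itself.

// residual_epilogue_reference.cpp -- illustrative host-side sketch (hypothetical file).
// Mirrors the per-element reference math Z = UnaryOp(BinaryOp(ActivationOp(accum), residual)),
// where `accum` is the Conv2d accumulator with the bias already folded in.
#include <algorithm>
#include <cassert>

struct ReLu     { float operator()(float x) const { return std::max(x, 0.0f); } };
struct Plus     { float operator()(float a, float b) const { return a + b; } };
struct Identity { float operator()(float x) const { return x; } };

template <typename ActivationOp, typename BinaryOp, typename UnaryOp>
float residual_epilogue(float conv2d_accum, float residual) {
  ActivationOp activation;
  BinaryOp     binary_op;
  UnaryOp      unary_op;
  // Same composition the testbed's ReferenceOp functor applies per output element.
  return unary_op(binary_op(activation(conv2d_accum), residual));
}

int main() {
  // ActivationOp = ReLU, BinaryOp = plus, UnaryOp = Identity:
  // a negative accumulator is clamped before the residual is added.
  assert(residual_epilogue<ReLu, Plus, Identity>(-2.0f, 3.0f) == 3.0f);
  assert(residual_epilogue<ReLu, Plus, Identity>( 2.0f, 3.0f) == 5.0f);
  return 0;
}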
-+// The epilogue functor in the last thread block would have to be given three inputs, namely -+// partial outputs, bias, and residual, but this is not supported in the current interface. -+// Set TestSplitK = false to skip split-k tests with non-trivial ActivationOp. -+template < -+ typename ElementAccumulator, -+ template class ActivationOp, -+ template class BinaryOp, -+ template class UnaryOp, -+ bool TestSplitK = false -+> -+void TestResidaulBlock() { -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementD = ElementC; -+ using ElementCompute = ElementAccumulator; -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationResidualBlock< -+ ElementD, -+ ElementAccumulator, -+ ElementCompute, -+ ElementC, -+ 8, -+ ActivationOp, -+ BinaryOp, -+ UnaryOp -+ >; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFpropWithBroadcast< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ struct ReferenceOp { -+ using OutputOp = typename Conv2dFprop::EpilogueOutputOp; -+ using ElementZ = typename OutputOp::ElementZ; -+ -+ ActivationOp activation; -+ BinaryOp binary_op; -+ UnaryOp unary_op; -+ -+ void operator()(ElementZ &Z, ElementZ&, ElementCompute conv2d, ElementCompute residual) { -+ Z = ElementZ(unary_op(binary_op(activation(conv2d), residual))); -+ } -+ }; -+ -+ bool passed = test::conv::device::TestAllConv2dWithBroadcast(); -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Device_Conv2d_Fprop_With_Residual_Block_Plus_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ // Resnet -+ TestResidaulBlock(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM70_SUPPORTED -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu -new file mode 100644 -index 0000000..17c77be ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu -@@ -0,0 +1,177 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" -+#include "cutlass/epilogue/thread/linear_combination_residual_block.h" -+#include "cutlass/epilogue/thread/activation.h" -+#include "cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_with_broadcast_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+TEST(SM75_Device_Conv2d_Fprop_With_Broadcast_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< -+ cutlass::half_t, -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8, -+ cutlass::epilogue::thread::ReLu -+ >; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFpropWithBroadcast< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2dWithBroadcast()); -+} -+ -+// Test residual block fusion: UnaryOp(BinaryOp(ActivationOp(Conv2d(X) + bias), residual)) -+// LinearCombinationResidualBlock does not support the split-k mode unless ActivationOp is Identity. 
-+// This is because the activation needs to be applied to the fully accumulated output of the Conv2d op, -+// which only the last thread block would have an access to, before applying BinaryOp. -+// The epilogue functor in the last thread block would have to be given three inputs, namely -+// partial outputs, bias, and residual, but this is not supported in the current interface. -+// Set TestSplitK = false to skip split-k tests with non-trivial ActivationOp. -+template < -+ typename ElementAccumulator, -+ template class ActivationOp, -+ template class BinaryOp, -+ template class UnaryOp, -+ bool TestSplitK = true -+> -+void TestResidaulBlock() { -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementD = ElementC; -+ using ElementCompute = ElementAccumulator; -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationResidualBlock< -+ ElementD, -+ ElementAccumulator, -+ ElementCompute, -+ ElementC, -+ 8, -+ ActivationOp, -+ BinaryOp, -+ UnaryOp -+ >; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFpropWithBroadcast< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ struct ReferenceOp { -+ using OutputOp = typename Conv2dFprop::EpilogueOutputOp; -+ using ElementZ = typename OutputOp::ElementZ; -+ -+ ActivationOp activation; -+ BinaryOp binary_op; -+ UnaryOp unary_op; -+ -+ void operator()(ElementZ &Z, ElementZ&, ElementCompute conv2d, ElementCompute residual) { -+ Z = ElementZ(unary_op(binary_op(activation(conv2d), residual))); -+ } -+ }; -+ -+ bool passed = test::conv::device::TestAllConv2dWithBroadcast(); -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Device_Conv2d_Fprop_With_Residual_Block_Plus_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ // Resnet -+ TestResidaulBlock(); -+} -+ -+TEST(SM75_Device_Conv2d_Fprop_With_Residual_Block_Multiply_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ // EfficientNet V2 -+ // Do not run split-K tests since the activation op is not Identity. -+ TestResidaulBlock(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_with_reduction_sm75.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_with_reduction_sm75.cu -new file mode 100644 -index 0000000..dc56278 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_with_reduction_sm75.cu -@@ -0,0 +1,94 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/epilogue/thread/linear_combination_with_elementwise.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_with_reduction_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+TEST(SM75_Device_Conv2d_Fprop_With_Reduction_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationWithElementwise< -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8 -+ >; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFpropWithReduction< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ EpilogueOutputOp, -+ cutlass::plus, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ 
EXPECT_TRUE(test::conv::device::TestAllConv2dWithReduction()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_problems.h b/3rdparty/cutlass/test/unit/conv/device/conv2d_problems.h -new file mode 100644 -index 0000000..5d1fbdc ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_problems.h -@@ -0,0 +1,860 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Implicit GEMM testbed sizes for Conv2d problem -+*/ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+namespace test { -+namespace conv { -+namespace device { -+ -+using Conv2dProblemVector = std::vector; -+ -+// -+// Structures to prune items from Conv2dProblemVector -+// -+// Specification template for pruning items for convolution problem lists -+template struct Specification -+{ -+ virtual ~Specification() = default; -+ virtual bool is_satisfied(T item) const = 0; -+}; -+ -+// input size (NHWC) specification -+struct InputSizeSpecification : Specification -+{ -+ cutlass::Tensor4DCoord input_size; -+ -+ InputSizeSpecification(cutlass::Tensor4DCoord input_size_) : input_size(input_size_) {} -+ -+ bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override { -+ return ((input_size.n() == item.N) && (input_size.h() == item.H) && (input_size.w() == item.W) && (input_size.c() == item.C)); -+ } -+}; -+ -+// stride (stride_h, stride_w) specification -+struct StrideSpecification : Specification -+{ -+ cutlass::MatrixCoord stride; -+ -+ StrideSpecification(cutlass::MatrixCoord stride_) : stride(stride_) {} -+ -+ bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override { -+ return ((stride.row() == item.stride_h) && (stride.column() == item.stride_h)); -+ } -+}; -+ -+// channel (C,K) specification, must be multiple of minimum channel -+struct ChannelDivisibilitySpecification : Specification -+{ -+ int channel_multiple; -+ -+ ChannelDivisibilitySpecification(int channel_multiple_) : channel_multiple(channel_multiple_) {} -+ -+ bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override { -+ return ((item.K % channel_multiple == 0) && (item.C % channel_multiple == 0)); -+ } -+}; -+ -+// -+// Pruning function for items from Conv2dProblemVector based on a Specification -+// -+inline Conv2dProblemVector prune(Conv2dProblemVector const &items, -+ Specification const &spec) -+{ -+ Conv2dProblemVector pruned_list; -+ -+ for (auto& p : items) -+ if (spec.is_satisfied(p)) -+ pruned_list.push_back(p); -+ return pruned_list; -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////// -+/// Structure TestbedConv2dProblemSizes initializes and holds conv default and -+/// important network sizes -+//////////////////////////////////////////////////////////////////////////// -+struct TestbedConv2dProblemSizes { -+ -+ // -+ // Data members -+ // -+ int minimum_channel_size; -+ -+ Conv2dProblemVector conv2d_default_sizes; -+ Conv2dProblemVector conv2d_rigorous_sizes; -+ Conv2dProblemVector conv2d_resnet50_sizes; -+ Conv2dProblemVector conv2d_resnet50_sizes_perf; -+ -+ // -+ // Methods -+ // -+ /// Default ctor -+ TestbedConv2dProblemSizes(int minimum_channel_size_ = 64): minimum_channel_size (minimum_channel_size_) { -+ initialize_conv2d_default_sizes(); -+ initialize_conv2d_rigorous_sizes(); -+ initialize_conv2d_resnet50_sizes(conv2d_resnet50_sizes, 1 /*batch-size*/); -+ -+ initialize_conv2d_resnet50_sizes(conv2d_resnet50_sizes_perf, 34 /*batch-size*/); -+ filter_all(); -+ } -+ -+ /// Eliminates some illegal cases -+ void filter_all() { -+ -+ Conv2dProblemVector *problems_vectors[] = { -+ &conv2d_default_sizes, -+ &conv2d_rigorous_sizes, -+ &conv2d_resnet50_sizes, -+ &conv2d_resnet50_sizes_perf -+ }; -+ -+ for (Conv2dProblemVector *problems : 
problems_vectors) { -+ Conv2dProblemVector filtered; -+ -+ for (cutlass::conv::Conv2dProblemSize const & problem : *problems) { -+ if (!(problem.C % minimum_channel_size)) { -+ filtered.push_back(problem); -+ } -+ } -+ -+ *problems = filtered; -+ } -+ } -+ -+ // Add a few standard convolution problem sizes -+ void initialize_conv2d_default_sizes() { -+ -+ //////////////////////////////////////////////////////////////////////////////////////////// -+ // Small input size x stride (1,1) -+ // C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} -+ //////////////////////////////////////////////////////////////////////////////////////////// -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 1, 1, minimum_channel_size}, // input size (NHWC) -+ {8, 1, 1, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 1, 8, minimum_channel_size}, // input size (NHWC) -+ {8, 1, 3, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 7, 8, minimum_channel_size}, // input size (NHWC) -+ {8, 3, 3, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 7, 9, minimum_channel_size}, // input size (NHWC) -+ {8, 4, 4, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {2, 7, 9, minimum_channel_size}, // input size (NHWC) -+ {8, 5, 5, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {3, 7, 9, minimum_channel_size}, // input size (NHWC) -+ {8, 6, 5, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {3, 7, 9, minimum_channel_size}, // input size (NHWC) -+ {8, 6, 6, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {3, 7, 9, minimum_channel_size}, // input size (NHWC) -+ {8, 7, 7, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ //////////////////////////////////////////////////////////////////////////////////////////// -+ // Small input size x stride (2,2) -+ // C < CTA::K and non-multiples of CTA::K. 
Typical CTA::K = {32, 64} -+ //////////////////////////////////////////////////////////////////////////////////////////// -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 11, 7, minimum_channel_size}, // input size (NHWC) -+ {8, 1, 1, minimum_channel_size}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 11, 7, minimum_channel_size}, // input size (NHWC) -+ {8, 3, 3, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 13, 11, minimum_channel_size}, // input size (NHWC) -+ {8, 1, 1, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 17, 19, minimum_channel_size}, // input size (NHWC) -+ {16, 2, 2, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 23, 5, minimum_channel_size}, // input size (NHWC) -+ {16, 3, 3, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 13, 17, 8}, // input size (NHWC) -+ {24, 3, 3, 8}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 23, 21, 8}, // input size (NHWC) -+ {24, 3, 3, 8}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 20, 24, 8}, // input size (NHWC) -+ {40, 3, 3, 8}, // filter size (KRSC) -+ {3, 3, 3, 3}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ //////////////////////////////////////////////////////////////////////////////////// -+ // Medium input size (1x16x16x128), filter size (1x1, 2x2, 3x3, 5x5), stride (1, 1) -+ //////////////////////////////////////////////////////////////////////////////////// -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 15, 19, 160}, // input size (NHWC) -+ {224, 1, 1, 160}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 19, 37, 160}, // input size (NHWC) -+ {224, 3, 3, 160}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ 
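Reading these problem sizes in implicit GEMM terms helps explain the tiling comments: for fprop, the equivalent GEMM has extents M = N*P*Q, N = K, and K = R*S*C, so the channel count C directly sets how full the k-dimension tiles are. The standalone sketch below maps the {1, 19, 37, 160} x {224, 3, 3, 160} problem just listed onto its GEMM extents; the helper names are hypothetical, and the output-size relation used is the textbook floor-based one, which coincides with the usual conventions when dilation is 1.

// implicit_gemm_extents.cpp -- illustrative host-side sketch (hypothetical file).
#include <cassert>

struct Conv2dExtents { int M, N, K; };  // equivalent GEMM extents

// Textbook floor-based output extent for one spatial dimension.
constexpr int out_dim(int in, int pad, int filt, int stride, int dilation) {
  return (in + 2 * pad - dilation * (filt - 1) - 1) / stride + 1;
}

// Implicit GEMM mapping for fprop: M = N*P*Q, N = K, K = R*S*C.
constexpr Conv2dExtents implicit_gemm_fprop(
    int N, int H, int W, int C,                      // input  (NHWC)
    int K, int R, int S,                             // filter (K, R, S); C shared with input
    int pad_h, int pad_w, int stride_h, int stride_w,
    int dil_h, int dil_w) {
  const int P = out_dim(H, pad_h, R, stride_h, dil_h);
  const int Q = out_dim(W, pad_w, S, stride_w, dil_w);
  return {N * P * Q, K, R * S * C};
}

int main() {
  // Input {1, 19, 37, 160}, filter {224, 3, 3, 160}, pad 1, stride 2, dilation 1.
  constexpr Conv2dExtents e =
      implicit_gemm_fprop(1, 19, 37, 160, 224, 3, 3, 1, 1, 2, 2, 1, 1);
  static_assert(e.M == 1 * 10 * 19, "GEMM M = N*P*Q");
  static_assert(e.N == 224,         "GEMM N = K");
  static_assert(e.K == 3 * 3 * 160, "GEMM K = R*S*C");
  return 0;
}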
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 16, 16, 160}, // input size (NHWC) -+ {224, 2, 3, 160}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 23, 21, 128}, // input size (NHWC) -+ {224, 3, 3, 128}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 29, 37, 160}, // input size (NHWC) -+ {224, 5, 5, 160}, // filter size (KRSC) -+ {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ //////////////////////////////////////////////////////////////////////////////////// -+ // C > CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} -+ //////////////////////////////////////////////////////////////////////////////////// -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 15, 19, 32 + minimum_channel_size}, // input size (NHWC) -+ {96, 3, 3, 32 + minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 16, 24, 64 + minimum_channel_size}, // input size (NHWC) -+ {96, 3, 3, 64 + minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ //////////////////////////////////////////////////////////////////////////////////// -+ // Medium input size, filter size (1x1, 3,x3, 5x5, 7x7), stride (2, 2) -+ //////////////////////////////////////////////////////////////////////////////////// -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 13, 16, 288}, // input size (NHWC) -+ {160, 5, 5, 288}, // filter size (KRSC) -+ {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 55, 51, 256}, // input size (NHWC) -+ {512, 1, 1, 256}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 71, 80, 32}, // input size (NHWC) -+ {64, 5, 5, 32}, // filter size (KRSC) -+ {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 224, 224, 8}, // input size (NHWC) -+ {64, 7, 7, 8}, // filter size (KRSC) -+ {3, 3, 3, 3}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ //////////////////////////////////////////////////////////////////////////////////// -+ // Medium input size stride (3, 3), filter (3, 3), non-default padding -+ //////////////////////////////////////////////////////////////////////////////////// -+ 
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 27, 23, 256}, // input size (NHWC) -+ {512, 3, 3, 256}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ //////////////////////////////////////////////////////////////////////////////////// -+ // Medium input size padding > stride, asymmetric filter, padding and striding -+ //////////////////////////////////////////////////////////////////////////////////// -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 27, 31, 256}, // input size (NHWC) -+ {512, 3, 3, 256}, // filter size (KRSC) -+ {5, 5, 7, 7}, // padding (pad_h, _, pad_w, _) -+ {3, 4}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 27, 35, 256}, // input size (NHWC) -+ {512, 7, 5, 256}, // filter size (KRSC) -+ {11, 11, 7, 7}, // padding (pad_h, _, pad_w, _) -+ {3, 5}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ //////////////////////////////////////////////////////////////////////////////////// -+ // Medium input size *mixed* stride (1, 2) and (2, 1), -+ // filter (3, 3), default padding -+ //////////////////////////////////////////////////////////////////////////////////// -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 27, 27, 256}, // input size (NHWC) -+ {512, 3, 3, 256}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 27, 27, 256}, // input size (NHWC) -+ {512, 3, 3, 256}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ ///////////////////////////////////////////////////////////////////////////// -+ // Additional input size -+ ///////////////////////////////////////////////////////////////////////////// -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {3, 28, 28, 256}, // input size (NHWC) -+ {256, 2, 2, 256}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 32, 32, 16}, // input size (NHWC) -+ {32, 3, 3, 16}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {6, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {32, 24, 32, 32}, // input size (NHWC) -+ {32, 1, 2, 32}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {4, 4, 5, 128}, // input size (NHWC) -+ {256, 3, 6, 128}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {4, 3, 3, 256} // output size (NPQK) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {4, 2, 3, 256}, // 
input size (NHWC) -+ {328, 3, 5, 256}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {4, 1, 1, 328} // output size (NPQK) -+ )); -+ } -+ -+ -+ // Add a few large and rigorous convolution problem sizes -+ void initialize_conv2d_rigorous_sizes() { -+ -+#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED -+ conv2d_rigorous_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 124, 224, 96}, // input size (NHWC) -+ {24, 7, 7, 96}, // filter size (KRSC) -+ {1, 229, 129, 32} // output size (NPQK) -+ )); -+ -+ conv2d_rigorous_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 233, 35, 48}, // input size (NHWC) -+ {24, 7, 5, 48}, // filter size (KRSC) -+ {1, 233, 35, 24} // output size (NPQK) -+ )); -+ -+#endif -+ -+ } -+ -+ -+ // Add resent50 layers to unit testing sizes -+ void initialize_conv2d_resnet50_sizes(Conv2dProblemVector &conv2d_problem_vector, int batch_size = 1){ -+ -+#if 0 // Resnet50 first layer (layer_id = 0) with channel = 3 is not supported in cutlass -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ [1, 224, 224, 3], // input size (NHWC) -+ [64, 7, 7, 3], // filter size (KRSC) -+ [3, 3, 3, 3], // padding (pad_h, _, pad_w, _) -+ [2, 2], // stride (stride_h, stride_w) -+ [1, 1], // dilation (dilation_h, dilation_w) -+ )); -+#endif -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 56, 56, 64}, // input size (NHWC) -+ {256, 1, 1, 64}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 56, 56, 64}, // input size (NHWC) -+ {64, 1, 1, 64}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 56, 56, 64}, // input size (NHWC) -+ {64, 3, 3, 64}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 56, 56, 256}, // input size (NHWC) -+ {64, 1, 1, 256}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 56, 56, 256}, // input size (NHWC) -+ {512, 1, 1, 256}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 56, 56, 256}, // input size (NHWC) -+ {128, 1, 1, 256}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 28, 28, 128}, // input size (NHWC) -+ {128, 3, 3, 128}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, 
dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 28, 28, 128}, // input size (NHWC) -+ {512, 1, 1, 128}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 28, 28, 512}, // input size (NHWC) -+ {128, 1, 1, 512}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 28, 28, 512}, // input size (NHWC) -+ {1024, 1, 1, 512}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 28, 28, 512}, // input size (NHWC) -+ {256, 1, 1, 512}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 14, 14, 256}, // input size (NHWC) -+ {256, 3, 3, 256}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 14, 14, 256}, // input size (NHWC) -+ {1024, 1, 1, 256}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 14, 14, 1024}, // input size (NHWC) -+ {256, 1, 1, 1024}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 14, 14, 1024}, // input size (NHWC) -+ {2048, 1, 1, 1024}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 14, 14, 1024}, // input size (NHWC) -+ {512, 1, 1, 1024}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 7, 7, 512}, // input size (NHWC) -+ {512, 3, 3, 512}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 7, 7, 512}, // input size (NHWC) -+ {2048, 1, 1, 512}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 7, 7, 2048}, // 
input size (NHWC) -+ {512, 1, 1, 2048}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ } -+ -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////// -+/// Structure TestbedGroupConv2dProblemSizes initializes and holds group conv default and -+/// important network sizes -+//////////////////////////////////////////////////////////////////////////// -+struct TestbedGroupConv2dProblemSizes { -+ -+ // -+ // Data members -+ // -+ int threadblock_n; -+ int threadblock_k; -+ int minimum_channel_size; -+ -+ Conv2dProblemVector default_single_group_sizes; -+ Conv2dProblemVector default_multiple_group_sizes; -+ -+ // -+ // Methods -+ // -+ /// Default ctor -+ TestbedGroupConv2dProblemSizes( -+ int threadblock_n_, -+ int threadblock_k_, -+ int minimum_channel_size_ = 64) -+ : threadblock_n (threadblock_n_), -+ threadblock_k (threadblock_k_), -+ minimum_channel_size (minimum_channel_size_) { -+ initialize_group_conv2d_default_sizes(); -+ filter_all(); -+ } -+ -+ /// Eliminates some illegal cases -+ void filter_all() { -+ -+ Conv2dProblemVector *problems_vectors[] = { -+ &default_single_group_sizes, -+ &default_multiple_group_sizes -+ }; -+ -+ for (Conv2dProblemVector *problems : problems_vectors) { -+ Conv2dProblemVector filtered; -+ -+ for (cutlass::conv::Conv2dProblemSize const & problem : *problems) { -+ if (!((problem.C / problem.groups) % minimum_channel_size)) { -+ filtered.push_back(problem); -+ } -+ } -+ -+ *problems = filtered; -+ } -+ } -+ -+ // Add a few standard convolution problem sizes -+ void initialize_group_conv2d_default_sizes() { -+ -+ //////////////////////////////////////////////////////////////////////////////////// -+ // One group calculated by one or multiple CTAs: k_per_group % CTA::N = 0 -+ // One CTA calculates a single group -+ //////////////////////////////////////////////////////////////////////////////////// -+ -+ for (int cta_per_group_k = 1; cta_per_group_k < 4; ++cta_per_group_k) { -+ // groups = 2, 3, 4 -+ for (int groups = 2; groups < 5; ++groups) { -+ -+ int conv_k = cta_per_group_k * threadblock_n * groups; -+ default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, threadblock_k * 2 * groups}, // input size (NHWC) -+ {conv_k, 3, 3, threadblock_k * 2}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, -+ 1, // split_k_slices -+ groups // groups -+ )); -+ -+ } // loop groups -+ } // loop cta_per_group_k -+ -+ // Partial gemm_k: k_per_group == CTA::N && channels_per_group < CTA::K -+ default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, threadblock_k}, // input size (NHWC) -+ {threadblock_n * 2, 3, 3, threadblock_k / 2}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, -+ 1, // split_k_slices -+ 2 // groups -+ )); -+ -+ // Larger problem sizes -+ -+ default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 56, 56, 696}, // input size (NHWC) -+ {768, 3, 3, 232}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ 
cutlass::conv::Mode::kCrossCorrelation, -+ 1, // split_k_slices -+ 3 // groups -+ )); -+ default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 14, 14, 1392}, // input size (NHWC) -+ {1536, 3, 3, 232}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, -+ 1, // split_k_slices -+ 3 // groups -+ )); -+ -+ //////////////////////////////////////////////////////////////////////////////////// -+ // One CTA calculate multiple groups: CTA::N % k_per_group = 0 -+ //////////////////////////////////////////////////////////////////////////////////// -+ -+ // 2 groups per CTA -+ default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, threadblock_k * 4}, // input size (NHWC) -+ {threadblock_n, 3, 3, threadblock_k * 2}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, -+ 1, // split_k_slices -+ 2 // groups -+ )); -+ -+ // 2 groups per CTA and partial gemm_k -+ default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, threadblock_k}, // input size (NHWC) -+ {threadblock_n, 3, 3, threadblock_k / 2}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, -+ 1, // split_k_slices -+ 2 // groups -+ )); -+ -+ // 4 groups per CTA -+ default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, threadblock_k * 8}, // input size (NHWC) -+ {threadblock_n / 2, 3, 3, threadblock_k * 2}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, -+ 1, // split_k_slices -+ 4 // groups -+ )); -+ -+ // 4 groups per CTA and partial gemm_k -+ default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, threadblock_k * 2}, // input size (NHWC) -+ {threadblock_n / 2, 3, 3, threadblock_k / 2}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, -+ 1, // split_k_slices -+ 4 // groups -+ )); -+ } -+ -+}; -+ -+ -+} // namespace device -+} // namespace conv -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..a910d61 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,370 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Strided Dgrad (Analytic) -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kStrided -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ -+// run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 8}, // input size (NHWC) -+ {8, 3, 3, 
8}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x256_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kStrided -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x256_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kStrided -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm 
equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 4, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kStrided, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Strided Dgrad (Optimized) -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kStrided -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 56, 56, 8}, // input size (NHWC) -+ {8, 1, 1, 8}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // 
dilation (dilation_h, dilation_w) -+ )); -+ -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 55, 55, 8}, // input size (NHWC) -+ {8, 1, 1, 8}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 4, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kStrided, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 56, 56, 12}, // input size (NHWC) -+ {8, 1, 1, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 55, 55, 12}, // input size (NHWC) -+ {8, 1, 1, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..b607a8a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,112 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align4, -+ 64x64_32x5_32x32x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 4, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<>, -+ 5, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kStrided, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ 
problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 1, 1, 16}, // input size (NHWC) -+ {8, 3, 3, 16}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 1, 1, 16}, // input size (NHWC) -+ {8, 3, 3, 16}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_testbed.h b/3rdparty/cutlass/test/unit/conv/device/conv2d_testbed.h -new file mode 100644 -index 0000000..582b433 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_testbed.h -@@ -0,0 +1,806 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Implicit GEMM testbed -+*/ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+#include "cutlass/reduction/device/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+#include "conv2d_problems.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+ -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/device/convolution.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cache_testbed_output.h" -+ -+namespace test { -+namespace conv { -+namespace device { -+ -+template -+class TestbedConv2d { -+public: -+ -+ using ElementA = typename Conv2d::ElementA; -+ using LayoutA = typename Conv2d::LayoutA; -+ using ElementB = typename Conv2d::ElementB; -+ using LayoutB = typename Conv2d::LayoutB; -+ using ElementC = typename Conv2d::ElementC; -+ using LayoutC = typename Conv2d::LayoutC; -+ using ElementAccumulator = typename Conv2d::ElementAccumulator; -+ using ElementCompute = typename Conv2d::ElementCompute; -+ using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; -+ -+ /// Reduction kernel -+ using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ typename EpilogueOutputOp::ElementAccumulator, -+ EpilogueOutputOp::kCount -+ >; -+ -+ using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< -+ cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, -+ EpilogueOutputOp, -+ ReductionOp -+ >; -+ -+ using ReductionDevice = cutlass::reduction::device::ReduceSplitK; -+ using ReductionStrideIndex = typename ReductionDevice::StrideIndex; -+ -+public: -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ -+ int tested_problem_count; -+ -+public: -+ -+ TestbedConv2d( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_), tested_problem_count(0) { -+ -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ int scope; -+ int bits = cutlass::sizeof_bits::value; -+ -+ if (bits <= 8) { -+ scope = 2; -+ } -+ else if (bits == 16) { -+ if (cutlass::sizeof_bits::value <= 16) { -+ scope = 3; -+ } -+ else { -+ scope = 5; -+ } -+ } -+ else { -+ scope = 8; -+ } -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope, -scope, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ 
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); -+ } -+ else { -+ } -+ } -+ -+ void initialize( -+ cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { -+ -+ tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); -+ tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); -+ tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ -+ initialize_tensor(tensor_A.host_view(), init_A, seed); -+ initialize_tensor(tensor_B.host_view(), init_B, seed * 17); -+ initialize_tensor(tensor_C.host_view(), init_C, seed * 39); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ tensor_D_reference.sync_device(); -+ } -+ -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Conv2d::UnderlyingKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ // increment tested problem count run by the testbed -+ tested_problem_count++; -+ -+#if 0 // display conv2d problem size for debugging -+ std::cout << problem_size << std::endl -+ << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl -+ << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? 
"(serial)" : "(parallel)") << std::endl -+ << std::endl; -+#endif -+ -+ initialize(problem_size); -+ -+ // configure the operator -+ Conv2d conv2d_op; -+ -+ typename Conv2d::Arguments conv2d_args( -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D_computed.device_ref(), -+ {alpha, beta}, -+ split_k_mode -+ ); -+ -+ // find workspace requirement for parallel split-k reduction -+ size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); -+ -+ if (status != cutlass::Status::kSuccess) { -+ cudaError_t error = cudaGetLastError(); -+ std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; -+ return true; -+ } -+ -+ // conv2d operation with parallel split-k-mode -+ if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { -+ -+ // conv2d output is written to workspace in global memory -+ conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); -+ // accumulate mma for each cta in k-dimension (1.0 * A * B) -+ conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; -+ // update conv2d operator arguments -+ status = conv2d_op.update(conv2d_args, workspace.get()); -+ } -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ // run conv2d operator -+ status = conv2d_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run." << std::endl; -+ return false; -+ } -+ -+ -+ if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { -+ -+ // configure parallel reduction operator -+ ReductionDevice reduction_op; -+ -+ typename ReductionDevice::Arguments reduction_args( -+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), -+ problem_size.split_k_slices, -+ cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), -+ { -+ reinterpret_cast (workspace.get()), -+ ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) -+ }, -+ { -+ tensor_D_computed.device_data(), -+ ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) -+ }, -+ { -+ tensor_C.device_data(), -+ ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) -+ }, -+ // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C -+ {alpha, beta} -+ ); -+ -+ status = reduction_op.initialize(reduction_args, nullptr); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ // run prallel reduction kernel -+ status = reduction_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ } -+ bool passed = false; -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) << " device reference error: " -+ << cudaGetErrorString(result); -+ -+ tensor_D_computed.sync_host(); -+ -+ // -+ // Reference check - support caching results -+ // -+ -+ CachedTestKey cached_test_key = CreateCachedConv2dTestKey< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ ElementCompute -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ alpha, -+ beta, -+ tensor_A.host_view(), -+ tensor_B.host_view(), 
-+ tensor_C.host_view() -+ ); -+ -+ // -+ // Look for the cached key -+ // -+ -+ bool cached_result_loaded = false; -+ CachedTestResult cached_test_result; -+ -+ std::string conv2d_result_cache_name = -+ std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ -+ CachedTestResultListing cached_results(conv2d_result_cache_name); -+ -+ auto cached = cached_results.find(cached_test_key); -+ -+ cached_result_loaded = cached.first; -+ if (cached_result_loaded) { -+ cached_test_result = cached.second; -+ } -+ } -+ -+ if (!cached_result_loaded) { -+ -+#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED -+ -+ cutlass::reference::device::Conv2d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D_reference.device_ref(), -+ alpha, -+ beta); -+ -+ // sync host (copy device data to host) for dumping error output in case of mismatches -+ tensor_D_reference.sync_host(); -+ -+#else -+ -+ cutlass::reference::host::Conv2d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ tensor_C.host_ref(), -+ tensor_D_reference.host_ref(), -+ alpha, -+ beta); -+ -+#endif -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ -+ cached_test_result.D = TensorHash(tensor_D_reference.host_view()); -+ -+ CachedTestResultListing cached_results(conv2d_result_cache_name); -+ -+ cached_results.append(cached_test_key, cached_test_result); -+ cached_results.write(conv2d_result_cache_name); -+ } -+ } // if (!cached_result_loaded) -+ -+ uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ passed = (tensor_D_hash == cached_test_result.D); -+ -+ EXPECT_EQ(tensor_D_hash, cached_test_result.D) -+ << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; -+ } -+ else { -+ -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view()); -+ } -+ -+ EXPECT_TRUE(passed); -+ -+ std::stringstream ss_problem_size_text; -+ ss_problem_size_text << "nhwc_" -+ << problem_size.N << "x" -+ << problem_size.H << "x" -+ << problem_size.W << "x" -+ << problem_size.C -+ << "_krsc_" -+ << problem_size.K << "x" -+ << problem_size.R << "x" -+ << problem_size.S << "x" -+ << problem_size.C -+ << "_padding_" -+ << problem_size.pad_h << "x" -+ << problem_size.pad_w -+ << "_stride_" -+ << problem_size.stride_h << "x" -+ << problem_size.stride_w -+ << "_dilation_" -+ << problem_size.dilation_h << "x" -+ << problem_size.dilation_w << "_" -+ << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_"); -+ -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_Conv2d_ImplicitGemm_device_" -+ << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") -+ << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : -+ (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? 
"dgrad_" : "wgrad_")) -+ << ss_problem_size_text.str() -+ << Conv2d::ThreadblockShape::kM << "x" -+ << Conv2d::ThreadblockShape::kN << "x" -+ << Conv2d::ThreadblockShape::kK << "_" -+ << Conv2d::WarpShape::kM << "x" -+ << Conv2d::WarpShape::kN << "x" -+ << Conv2d::WarpShape::kK << ".txt"; -+ -+ std::cout << fname.str() << std::endl; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size << std::endl; -+ -+ results -+ << "\nA:\n" << tensor_A.host_view() << "\n" -+ << "\nB:\n" << tensor_B.host_view() << "\n" -+ << "\nC:\n" << tensor_C.host_view() << "\n"; -+ -+ results << "\nD reference (hash: " << cached_test_result.D << ")\n"; -+ -+ if (!cached_result_loaded) { -+ results -+ << tensor_D_reference.host_view() << "\n"; -+ } -+ -+ results -+ << "\nD computed (hash: " << tensor_D_hash << ")\n" -+ << tensor_D_computed.host_view() << "\n"; -+ -+ } -+ -+ return passed; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+bool TestSpecificConv2d( -+ const Conv2dProblemVector & problem_sizes) { -+ -+ bool passed = true; -+ -+ // -+ // Testbed object -+ // -+ -+ TestbedConv2d testbed; -+ -+ // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) -+ for(auto conv_problem : problem_sizes) { -+ -+ // -+ // Test -+ // -+ -+ // test mode = xcross -+ passed = testbed.run( -+ conv_problem, -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ // test mode = convolution -+ passed = testbed.run( -+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////////////// -+// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference -+// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes -+// Additionaly, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes -+// (conv_blacklist_sizes) -+///////////////////////////////////////////////////////////////////////////////////////////////////////////// -+template -+bool TestAllConv2d( -+ const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), -+ const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { -+ -+ bool passed = true; -+ -+ // -+ // Testbed object -+ // -+ -+ TestbedConv2d testbed; -+ -+ // -+ // Get conv problem sizes to run conv operator -+ // -+ TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); -+ -+ // Vector of conv2d problem sizes to avoid duplicate runs -+ Conv2dProblemVector conv_tested_sizes; -+ -+ // Vectors of Conv2dProblemVector (lenient/easiest to rigorous problem sizes) -+ std::vector problem_vectors = { -+ conv_test_sizes, // run user specified sizes -+ conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes -+ //conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes -+#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED -+ conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled -+#endif -+ }; -+ -+ // Flatten 2D problem_vectors into a 1D problem_sizes -+ std::vector problem_sizes; -+ for (auto problem_vector : problem_vectors) { -+ for(auto conv_problem : problem_vector) { -+ problem_sizes.push_back(conv_problem); -+ } -+ } -+ 
-+ // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reverse the order (rigorous to lenient) -+ // run the most rigorous problem size first -+ if (CutlassUnitTestProblemCount()) { -+ std::reverse(problem_sizes.begin(), problem_sizes.end()); -+ } -+ -+ // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) -+ for(auto conv_problem : problem_sizes) { -+ -+ // Skip blacklist and avoid duplicate problem sizes -+ if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || -+ std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { -+ continue; -+ } -+ -+ // -+ // Procedurally disable certain cases -+ // -+ -+ // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} -+ if ((ImplicitGemm::kConvolutionalOperator == -+ cutlass::conv::Operator::kDgrad) && -+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == -+ cutlass::conv::StrideSupport::kUnity)) { -+ if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { -+ continue; -+ } -+ } -+ -+ // Fixed channels algorithm requires channel count to match access size -+ if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == -+ cutlass::conv::IteratorAlgorithm::kFixedChannels) { -+ if (conv_problem.C != ImplicitGemm::UnderlyingKernel::Mma::IteratorA::AccessType::kElements) { -+ continue; -+ } -+ } -+ -+ // Few channels algorithm requires channel count to match access size -+ if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == -+ cutlass::conv::IteratorAlgorithm::kFewChannels) { -+ if (conv_problem.C % ImplicitGemm::UnderlyingKernel::Mma::IteratorA::AccessType::kElements) { -+ continue; -+ } -+ } -+ -+ // CUTLASS DGRAD's *strided* stride specialization supports all stride {stride_h, stride_w} -+ // Although strided dgrad works for all stride combinations, we are only going -+ // to run strided dgrad for non-unity strides -+ if ((ImplicitGemm::kConvolutionalOperator == -+ cutlass::conv::Operator::kDgrad) && -+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == -+ cutlass::conv::StrideSupport::kStrided)) { -+ if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { -+ continue; -+ } -+ } -+ -+ // -+ // Test -+ // -+ // push back tested problem size to avoid re-running duplicates -+ conv_tested_sizes.push_back(conv_problem); -+ -+ // test mode = xcross -+ passed = testbed.run( -+ conv_problem, -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ // test mode = convolution -+ passed = testbed.run( -+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the the number of tested problem counts -+ if (CutlassUnitTestProblemCount() && -+ testbed.tested_problem_count > CutlassUnitTestProblemCount()) { -+ return true; -+ } -+ } -+ -+ // Small-channels convolution can't run here. -+ if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == -+ cutlass::conv::IteratorAlgorithm::kFixedChannels) { -+ -+ return true; -+ } -+ -+ // Small-channels convolution can't run here. 
-+ if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == -+ cutlass::conv::IteratorAlgorithm::kFewChannels) { -+ -+ return true; -+ } -+ -+ // CUTLASS DGRAD's *strided* specialization does not support split-k mode -+ if ((ImplicitGemm::kConvolutionalOperator == -+ cutlass::conv::Operator::kDgrad) && -+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == -+ cutlass::conv::StrideSupport::kStrided)) { -+ -+ passed = testbed.run( -+ cutlass::conv::Conv2dProblemSize( -+ {1, 56, 56, 8}, // input size (NHWC) -+ {8, 1, 1, 8}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1}), // dilation (dilation_h, dilation_w) -+ cutlass::conv::SplitKMode::kSerial, -+ cutlass::from_real(2.0), -+ cutlass::from_real(2.0)); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ return passed; -+ } -+ // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for -+ // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters -+ // which are abolutely neccessary to catch functional bugs. The below code does provide option to sweep -+ // alpha and beta for local testing, but only runs one value for alpha and beta. -+ cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( -+ {1, 17, 11, 288}, // input size (NHWC) -+ {160, 3, 3, 288}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ ); -+ -+ cutlass::conv::SplitKMode split_k_modes [] = { -+ cutlass::conv::SplitKMode::kSerial, -+ cutlass::conv::SplitKMode::kParallel, -+ }; -+ -+ int split_k_slices[] = { -+ 1, 2, 3, 4, 201 -+ }; -+ -+ double problem_alpha[] = { -+ 2.0 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ for (auto split_k_mode : split_k_modes) { -+ for (auto split_k_slice : split_k_slices) { -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ passed = testbed.run( -+ conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), -+ split_k_mode, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta)); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the the number of tested problem counts -+ if (CutlassUnitTestProblemCount() && -+ testbed.tested_problem_count > CutlassUnitTestProblemCount()) { -+ return true; -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace conv -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h b/3rdparty/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h -new file mode 100644 -index 0000000..79f00d1 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h -@@ -0,0 +1,665 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Implicit GEMM testbed -+*/ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+#include "cutlass/reduction/device/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+#include "conv2d_problems.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/host_reorder.h" -+ -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/device/convolution.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cache_testbed_output.h" -+ -+namespace test { -+namespace conv { -+namespace device { -+ -+template -+class InterleavedTestbedConv2d { -+public: -+ -+ using ElementA = typename Conv2d::ElementA; -+ using LayoutA = typename Conv2d::LayoutA; -+ using ElementB = typename Conv2d::ElementB; -+ using LayoutB = typename Conv2d::LayoutB; -+ using ElementC = typename Conv2d::ElementC; -+ using LayoutC = typename Conv2d::LayoutC; -+ using ElementAccumulator = typename Conv2d::ElementAccumulator; -+ using ElementCompute = typename Conv2d::ElementCompute; -+ using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; -+ -+ /// Reduction kernel -+ using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ typename EpilogueOutputOp::ElementAccumulator, -+ EpilogueOutputOp::kCount -+ >; -+ -+ using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< -+ cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, -+ EpilogueOutputOp, -+ ReductionOp -+ >; -+ -+ using ReductionDevice = cutlass::reduction::device::ReduceSplitK; -+ using ReductionStrideIndex = typename ReductionDevice::StrideIndex; -+ 
-+public: -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_B_reordered; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ -+public: -+ -+ InterleavedTestbedConv2d( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { -+ -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ int scope; -+ int bits = cutlass::sizeof_bits::value; -+ -+ if (bits <= 8) { -+ scope = 2; -+ } -+ else if (bits == 16) { -+ scope = 3; -+ } -+ else { -+ scope = 8; -+ } -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope, -scope, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); -+ } -+ else { -+ } -+ } -+ -+ void initialize( -+ cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { -+ -+ tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); -+ tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); -+ tensor_B_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); -+ tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ -+ initialize_tensor(tensor_A.host_view(), init_A, seed); -+ initialize_tensor(tensor_B.host_view(), init_B, seed * 17); -+ initialize_tensor(tensor_C.host_view(), init_C, seed * 39); -+ -+ cutlass::reorder_convK( -+ tensor_B_reordered.host_ref(), tensor_B.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size)); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_B_reordered.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ tensor_D_reference.sync_device(); -+ } -+ -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Conv2d::UnderlyingKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerMultiprocessor < smem_size) { -+ 
return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+#if 0 //display conv2d problem size for debugging -+ std::cout << problem_size << std::endl -+ << "alpha, beta: (" << float(alpha) << ", " << float(beta) << ")" << std::endl -+ << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl -+ << std::endl; -+#endif -+ -+ initialize(problem_size); -+ -+ // configure the operator -+ Conv2d conv2d_op; -+ -+ typename Conv2d::Arguments conv2d_args( -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B_reordered.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D_computed.device_ref(), -+ {alpha, beta}, -+ split_k_mode -+ ); -+ -+ // find workspace requirement for parallel split-k reduction -+ size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); -+ -+ // conv2d operation with parallel split-k-mode -+ if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { -+ -+ // conv2d output is written to workspace in global memory -+ conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); -+ // accumulate mma for each cta in k-dimension (1.0 * A * B) -+ conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; -+ // update conv2d operator arguments -+ status = conv2d_op.update(conv2d_args, workspace.get()); -+ } -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ // run conv2d operator -+ status = conv2d_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { -+ -+ // configure parallel reduction operator -+ ReductionDevice reduction_op; -+ -+ typename ReductionDevice::Arguments reduction_args( -+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), -+ problem_size.split_k_slices, -+ cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), -+ { -+ reinterpret_cast (workspace.get()), -+ ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) -+ }, -+ { -+ tensor_D_computed.device_data(), -+ ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) -+ }, -+ { -+ tensor_C.device_data(), -+ ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) -+ }, -+ // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C -+ {alpha, beta} -+ ); -+ -+ status = reduction_op.initialize(reduction_args, nullptr); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ // run prallel reduction kernel -+ status = reduction_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ } -+ 
bool passed = false; -+ -+ tensor_D_computed.sync_host(); -+ -+ // -+ // Reference check - support caching results -+ // -+ -+ CachedTestKey cached_test_key = CreateCachedConv2dTestKey< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ ElementCompute -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ alpha, -+ beta, -+ tensor_A.host_view(), -+ tensor_B.host_view(), -+ tensor_C.host_view() -+ ); -+ -+ // -+ // Look for the cached key -+ // -+ -+ bool cached_result_loaded = false; -+ CachedTestResult cached_test_result; -+ -+ std::string conv2d_result_cache_name = -+ std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ -+ CachedTestResultListing cached_results(conv2d_result_cache_name); -+ -+ auto cached = cached_results.find(cached_test_key); -+ -+ cached_result_loaded = cached.first; -+ if (cached_result_loaded) { -+ cached_test_result = cached.second; -+ } -+ } -+ -+ if (!cached_result_loaded) { -+ -+#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED -+ -+ cutlass::reference::device::Conv2d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ cutlass::NumericConverterClamp -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D_reference.device_ref(), -+ alpha, -+ beta); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) << " device reference error: " -+ << cudaGetErrorString(result); -+ -+ // sync host (copy device data to host) for dumping error output in case of mismatches -+ tensor_D_reference.sync_host(); -+ -+#else -+ -+ cutlass::reference::host::Conv2d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ cutlass::NumericConverterClamp -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ tensor_C.host_ref(), -+ tensor_D_reference.host_ref(), -+ alpha, -+ beta); -+ -+#endif -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ -+ cached_test_result.D = TensorHash(tensor_D_reference.host_view()); -+ -+ CachedTestResultListing cached_results(conv2d_result_cache_name); -+ -+ cached_results.append(cached_test_key, cached_test_result); -+ cached_results.write(conv2d_result_cache_name); -+ } -+ } // if (!cached_result_loaded) -+ -+ uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ passed = (tensor_D_hash == cached_test_result.D); -+ -+ EXPECT_EQ(tensor_D_hash, cached_test_result.D) -+ << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; -+ } -+ else { -+ -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view()); -+ } -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_Conv2d_ImplicitGemm_device_" -+ << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") -+ << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : -+ (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? 
"dgrad_" : "wgrad_")) -+ << "ncxhwx_" -+ << problem_size.N << "x" -+ << problem_size.H << "x" -+ << problem_size.W << "x" -+ << problem_size.C -+ << "_cxrskx_" -+ << problem_size.K << "x" -+ << problem_size.R << "x" -+ << problem_size.S << "x" -+ << problem_size.C -+ << "_padding_" -+ << problem_size.pad_h << "x" -+ << problem_size.pad_w -+ << "_stride_" -+ << problem_size.stride_h << "x" -+ << problem_size.stride_w -+ << "_dilation_" -+ << problem_size.dilation_h << "x" -+ << problem_size.dilation_w << "_" -+ << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") -+ << Conv2d::ThreadblockShape::kM << "x" -+ << Conv2d::ThreadblockShape::kN << "x" -+ << Conv2d::ThreadblockShape::kK << "_" -+ << Conv2d::WarpShape::kM << "x" -+ << Conv2d::WarpShape::kN << "x" -+ << Conv2d::WarpShape::kK << ".txt"; -+ -+ std::cout << fname.str() << std::endl; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size << std::endl; -+ -+ results -+ << "\nA:\n" << tensor_A.host_view() << "\n" -+ << "\nB:\n" << tensor_B.host_view() << "\n" -+ << "\nC:\n" << tensor_C.host_view() << "\n"; -+ -+ results << "\nD reference (hash: " << cached_test_result.D << ")\n"; -+ -+ if (!cached_result_loaded) { -+ results -+ << tensor_D_reference.host_view() << "\n"; -+ } -+ -+ results -+ << "\nD computed (hash: " << tensor_D_hash << ")\n" -+ << tensor_D_computed.host_view() << "\n"; -+ -+ } -+ -+ return passed; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////////////// -+// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference -+// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes -+// Additionaly, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes -+// (conv_blacklist_sizes) -+///////////////////////////////////////////////////////////////////////////////////////////////////////////// -+template -+bool TestAllInterleavedConv2d( -+ const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), -+ const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { -+ -+ bool passed = true; -+ -+ // -+ // Testbed object -+ // -+ -+ InterleavedTestbedConv2d testbed; -+ -+ // -+ // Get conv problem sizes to run conv operator -+ // -+ TestbedConv2dProblemSizes conv_problems(InterleavedK); // minimum channel size must be multiple of InterleavedK for interleaved layout -+ -+ // Vector of conv2d problem sizes to avoid duplicate runs -+ Conv2dProblemVector conv_tested_sizes; -+ -+ Conv2dProblemVector const *problem_vectors[] = { -+ &conv_test_sizes, // run user specified sizes -+ &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes -+ &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes -+#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED -+ &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled -+#endif -+ }; -+ -+ // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) -+ for (Conv2dProblemVector const * problem_vector : problem_vectors) { -+ -+ ChannelDivisibilitySpecification channel_spec(InterleavedK); //input and output channels must be multiple of InterleavedK -+ auto pruned_problem_vector = prune(*problem_vector, channel_spec); -+ -+ // Run conv testbed on default convolution sizes -+ for(auto conv_problem : pruned_problem_vector) { -+ -+ // Skip blacklist and avoid duplicate 
problem sizes -+ if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || -+ std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { -+ continue; -+ } -+ -+ // -+ // Procedurally disable certain cases -+ // -+ -+ // CUTLASS DGRAD's unity stride specialization only supports stride {1, 1} -+ if ((ImplicitGemm::kConvolutionalOperator == -+ cutlass::conv::Operator::kDgrad) && -+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == -+ cutlass::conv::StrideSupport::kUnity)) { -+ if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { -+ continue; -+ } -+ } -+ -+ // -+ // Test -+ // -+ // push back tested problem size to avoid re-running duplicates -+ conv_tested_sizes.push_back(conv_problem); -+ -+ // test mode = xcross -+ passed = testbed.run( -+ conv_problem, -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ // test mode = convolution -+ passed = testbed.run( -+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ -+#if 0 -+ // Sweep split-k-slice using serial and parallel reduction with non-unity alpha and non-zero beta for -+ // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters -+ // which are absolutely necessary to catch functional bugs. The below code does provide the option to sweep -+ // alpha and beta for local testing, but only runs one value for alpha and beta. -+ cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( -+ {1, 17, 11, 288}, // input size (NHWC) -+ {160, 3, 3, 288}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ ); -+ -+ cutlass::conv::SplitKMode split_k_modes [] = { -+ cutlass::conv::SplitKMode::kSerial, -+ cutlass::conv::SplitKMode::kParallel, -+ }; -+ -+ int split_k_slices[] = { -+ 1, 2, 3, 4, 201 -+ }; -+ -+ double problem_alpha[] = { -+ 2.0 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ for (auto split_k_mode : split_k_modes) { -+ for (auto split_k_slice : split_k_slices) { -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ passed = testbed.run( -+ conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), -+ split_k_mode, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta)); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+#endif -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace conv -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu -new file mode 100644 -index 0000000..4fbdf98 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu -@@ -0,0 +1,133 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM50_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, -+ 64x64_8x2_32x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::complex; -+ using ElementB = cutlass::complex; -+ using ElementC = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// 
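A hedged sketch of how a wgrad SIMT test of the form above is conventionally assembled once the template arguments are spelled out in full; the complex element types are assumed to be cutlass::complex<float>, and the testbed entry point is assumed to be test::conv::device::TestAllConv2d<Conv2dWgrad>(). These fill-ins are assumptions for illustration, not part of the patch itself.

#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv2d_wgrad.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"
#include "conv2d_testbed.h"

// Sketch: SM50 SIMT wgrad on NHWC tensors of complex<float> (assumed element type).
TEST(SM50_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32,
     sketch_64x64_8x2_32x32x8) {

  using Element = cutlass::complex<float>;  // assumed; the patch text omits the argument

  using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad<
    Element, cutlass::layout::TensorNHWC,   // A (output gradient)
    Element, cutlass::layout::TensorNHWC,   // B (activation)
    Element, cutlass::layout::TensorNHWC,   // C (filter gradient)
    Element,                                // accumulator
    cutlass::arch::OpClassSimt,
    cutlass::arch::Sm50,
    cutlass::gemm::GemmShape<64, 64, 8>,    // threadblock tile
    cutlass::gemm::GemmShape<32, 32, 8>,    // warp tile
    cutlass::gemm::GemmShape<1, 1, 1>,      // instruction shape (SIMT)
    cutlass::epilogue::thread::LinearCombination<Element, 1, Element, Element>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2,                                      // pipeline stages
    cutlass::arch::OpMultiplyAddComplex,
    cutlass::conv::IteratorAlgorithm::kAnalytic
  >::Kernel;

  using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dWgradKernel>;

  // Sweep the default problem sizes through the shared conv2d testbed.
  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dWgrad>());
}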
-+TEST(SM50_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, -+ 32x64_8x2_32x64x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::complex; -+ using ElementB = cutlass::complex; -+ using ElementC = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu -new file mode 100644 -index 0000000..c8d6bde ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu -@@ -0,0 +1,138 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, -+ 128x128_8x4_32x64x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::complex; -+ using ElementB = cutlass::complex; -+ using ElementC = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, -+ 128x128_8x4_64x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::complex; -+ using ElementB = cutlass::complex; -+ using ElementC = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu -new file mode 100644 -index 0000000..8932187 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu -@@ -0,0 +1,128 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+ -+TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu -new 
file mode 100644 -index 0000000..23c749a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu -@@ -0,0 +1,84 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+TEST(SM70_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM70_SUPPORTED -+ -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..a07c9b4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu -@@ -0,0 +1,195 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 4, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kStrided, -+ 4, -+ 4 -+ 
>::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 4, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kStrided, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -+ -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..9c81b48 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,274 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using 
ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 64x256_32x4_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32 >, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 4, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ 
cutlass::conv::StrideSupport::kStrided, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 4, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kStrided, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu -new file mode 100644 -index 0000000..3c6cbf4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu -@@ -0,0 +1,137 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, -+ 128x128_8x4_64x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = float; -+ using ElementB = float; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, -+ 128x128_8x4_64x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = float; -+ using ElementB = float; -+ using 
ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..991e1e5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,143 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align1, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 32 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity, -+ 1, -+ 1 -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, 1}, // input size (NHWC) -+ {1, 3, 3, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with 
device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h b/3rdparty/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h -new file mode 100644 -index 0000000..117fef0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h -@@ -0,0 +1,686 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Implicit GEMM for fused epilogue broadcast testbed -+ -+ Parallel split-k is not tested because we can just use regular conv kernel -+ when we need to use parallel-splitk. Broadcast can happen in the reduction -+ kernel. 
-+*/ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+#include "cutlass/reduction/device/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+#include "conv2d_problems.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+ -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/device/convolution.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cache_testbed_output.h" -+ -+namespace test { -+namespace conv { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct Conv2dWithBroadcastReferenceOp { -+ -+ using OutputOp = typename Conv2d::EpilogueOutputOp; -+ -+ using ElementCompute = typename OutputOp::ElementCompute; -+ using ElementZ = typename OutputOp::ElementZ; -+ using ElementT = typename OutputOp::ElementT; -+ -+ typename OutputOp::BinaryOp binary_op; -+ typename OutputOp::ElementwiseOp elementwise_op; -+ -+ Conv2dWithBroadcastReferenceOp() { } -+ -+ void operator()(ElementZ &Z, ElementT &T, ElementCompute conv2d, ElementCompute bias) { -+ ElementCompute t_full = binary_op(conv2d, bias); -+ T = ElementT(t_full); -+ -+ ElementCompute z_full = elementwise_op(t_full); -+ Z = ElementZ(z_full); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Fused testbed -+// -+// Y = CONV(AB, C) -+// -+// T[n, p, q, k] = ReductionOp(Y[n, p, q, k], Broadcast[k]) -+// -+// Z[n, p, q, k] = Elementwise(T[n, p, q, k]) -+// -+ -+template < -+ typename Conv2d, -+ typename ReferenceOp, -+ bool AddBroadcastFirst = false -+> -+class TestbedConv2dWithBroadcast { -+public: -+ -+ using ElementA = typename Conv2d::ElementA; -+ using LayoutA = typename Conv2d::LayoutA; -+ using ElementB = typename Conv2d::ElementB; -+ using LayoutB = typename Conv2d::LayoutB; -+ using ElementC = typename Conv2d::ElementC; -+ using LayoutC = typename Conv2d::LayoutC; -+ using ElementAccumulator = typename Conv2d::ElementAccumulator; -+ using ElementCompute = typename Conv2d::ElementCompute; -+ using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; -+ using ElementZ = typename EpilogueOutputOp::ElementZ; -+ using ElementT = typename EpilogueOutputOp::ElementT; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; -+ static const bool kAddBroadcastFirst = AddBroadcastFirst; -+ static const bool kStoreT = EpilogueOutputOp::kStoreT; -+ -+public: -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_C_reference; -+ cutlass::HostTensor tensor_Z_computed; -+ cutlass::HostTensor tensor_Z_reference; -+ cutlass::HostTensor tensor_T_computed; -+ cutlass::HostTensor tensor_T_reference; -+ cutlass::HostTensor tensor_Y_reference; -+ cutlass::HostTensor tensor_Broadcast; // Input Broadcast -+ -+public: -+ -+ TestbedConv2dWithBroadcast( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ 
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { -+ -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ int scope; -+ int bits = cutlass::sizeof_bits::value; -+ -+ if (bits <= 8) { -+ scope = 2; -+ } -+ else if (bits == 16) { -+ if (cutlass::sizeof_bits::value <= 16) { -+ scope = 3; -+ } -+ else { -+ scope = 5; -+ } -+ } -+ else { -+ scope = 8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope, -scope, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); -+ } -+ else { -+ } -+ } -+ -+ void initialize( -+ cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { -+ -+ tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); -+ tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); -+ tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_C_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_Z_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_Z_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_T_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_T_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_Y_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_Broadcast.resize({ -+ 1, -+ 1, -+ 1, -+ implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c(), -+ }); -+ -+ initialize_tensor(tensor_A.host_view(), init_A, seed); -+ initialize_tensor(tensor_B.host_view(), init_B, seed * 17); -+ initialize_tensor(tensor_C.host_view(), init_C, seed * 39); -+ initialize_tensor(tensor_Broadcast.host_view(), init_C, seed * 39); -+ -+ for (int n = 0; n < tensor_C_reference.extent().n(); ++n) { -+ for (int p = 0; p < tensor_C_reference.extent().h(); ++p) { -+ for (int q = 0; q < tensor_C_reference.extent().w(); ++q) { -+ for (int k = 0; k < tensor_C_reference.extent().c(); ++k) { -+ tensor_C_reference.at({n, p, q, k}) = ElementAccumulator(tensor_C.at({n, p, q, k})); -+ } -+ } -+ } -+ } -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_Broadcast.sync_device(); -+ tensor_C_reference.sync_device(); -+ tensor_Z_computed.sync_device(); -+ tensor_Z_reference.sync_device(); -+ tensor_T_computed.sync_device(); -+ tensor_T_reference.sync_device(); -+ tensor_Y_reference.sync_device(); -+ } -+ -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Conv2d::UnderlyingKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ 
int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(1)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+#if 0 //display conv2d problem size for debugging -+ std::cout << problem_size << std::endl -+ << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl -+ << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl -+ << std::endl; -+#endif -+ -+ initialize(problem_size); -+ -+ // configure the operator -+ Conv2d conv2d_op; -+ typename Conv2d::Arguments conv2d_args( -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_Z_computed.device_ref(), -+ {alpha, beta}, -+ split_k_mode, -+ tensor_Broadcast.device_data(), -+ kStoreT ? tensor_T_computed.device_data() : nullptr, -+ 0, // This must be zero -+ implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c() -+ ); -+ -+ // initialize the kernel -+ size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); -+ -+ if (status != cutlass::Status::kSuccess) { -+ cudaError_t error = cudaGetLastError(); -+ std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; -+ return true; -+ } -+ -+ // run conv2d operator -+ status = conv2d_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ bool passed = false; -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) << " device reference error: " -+ << cudaGetErrorString(result); -+ -+ tensor_T_computed.sync_host(); -+ tensor_Z_computed.sync_host(); -+ -+ // -+ // Reference check -+ // -+ -+ // When kAddBroadcastFirst is true, add bias on the host -+ ElementCompute beta_ref = kAddBroadcastFirst ? 
ElementCompute(0) : beta; -+ -+#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED -+ -+ cutlass::reference::device::Conv2d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementAccumulator, -+ LayoutC, -+ ElementAccumulator, -+ ElementAccumulator -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C_reference.device_ref(), -+ tensor_Y_reference.device_ref(), -+ alpha, -+ beta_ref); -+ -+ // sync host (copy device data to host) for dumping error output in case of mismatches -+ tensor_Y_reference.sync_host(); -+ -+#else -+ -+ cutlass::reference::host::Conv2d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementAccumulator, -+ LayoutC, -+ ElementAccumulator, -+ ElementAccumulator -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ tensor_C_reference.host_ref(), -+ tensor_Y_reference.host_ref(), -+ alpha, -+ beta_ref); -+ -+#endif -+ ReferenceOp reference_op; -+ -+ // compute tensor Z and tensor T -+ for (int n = 0; n < problem_size.N; ++n) { -+ for (int p = 0; p < problem_size.P; ++p) { -+ for (int q = 0; q < problem_size.Q; ++q) { -+ for (int k = 0; k < problem_size.K; ++k) { -+ -+ ElementZ z; -+ ElementT t; -+ -+ ElementCompute accum = tensor_Y_reference.at({n, p, q, k}); -+ ElementCompute bias = ElementCompute(tensor_Broadcast.at({0, 0, 0, k})); -+ -+ -+ if (kAddBroadcastFirst) { -+ reference_op(z, t, accum + bias, -+ beta * ElementCompute(tensor_C_reference.at({n, p, q, k}))); -+ } else { -+ reference_op(z, t, accum, bias); -+ } -+ -+ tensor_Z_reference.at({n, p, q, k}) = z; -+ tensor_T_reference.at({n, p, q, k}) = t; -+ } -+ } -+ } -+ } -+ -+ if (kStoreT) { -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_T_computed.host_view(), -+ tensor_T_reference.host_view()); -+ -+ EXPECT_TRUE(passed); -+ } -+ -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_Z_computed.host_view(), -+ tensor_Z_reference.host_view()); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_Conv2d_ImplicitGemm_device_" -+ << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") -+ << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : -+ (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) -+ << "nhwc_" -+ << problem_size.N << "x" -+ << problem_size.H << "x" -+ << problem_size.W << "x" -+ << problem_size.C -+ << "_krsc_" -+ << problem_size.K << "x" -+ << problem_size.R << "x" -+ << problem_size.S << "x" -+ << problem_size.C -+ << "_padding_" -+ << problem_size.pad_h << "x" -+ << problem_size.pad_w -+ << "_stride_" -+ << problem_size.stride_h << "x" -+ << problem_size.stride_w -+ << "_dilation_" -+ << problem_size.dilation_h << "x" -+ << problem_size.dilation_w << "_" -+ << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? 
"xcorr_" : "conv_") -+ << Conv2d::ThreadblockShape::kM << "x" -+ << Conv2d::ThreadblockShape::kN << "x" -+ << Conv2d::ThreadblockShape::kK << "_" -+ << Conv2d::WarpShape::kM << "x" -+ << Conv2d::WarpShape::kN << "x" -+ << Conv2d::WarpShape::kK << ".txt"; -+ -+ std::cout << fname.str() << std::endl; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size << std::endl; -+ -+ results -+ << "\nA:\n" << tensor_A.host_view() << "\n" -+ << "\nB:\n" << tensor_B.host_view() << "\n" -+ << "\nC:\n" << tensor_C.host_view() << "\n" -+ << "\nBroadcast:\n" << tensor_Broadcast.host_view() << "\n" -+ << "\nY reference:\n" << tensor_Y_reference.host_view() << "\n" -+ << "\nT reference:\n" << tensor_T_reference.host_view() << "\n" -+ << "\nT computed:\n" << tensor_T_computed.host_view() << "\n" -+ << "\nZ reference:\n" << tensor_Z_reference.host_view() << "\n" -+ << "\nZ computed:\n" << tensor_Z_computed.host_view() << "\n"; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////////////// -+// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference -+// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes -+// Additionaly, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes -+// (conv_blacklist_sizes) -+///////////////////////////////////////////////////////////////////////////////////////////////////////////// -+template , -+ bool AddBroadcastFirst = false, -+ bool TestSplitK = true -+> -+bool TestAllConv2dWithBroadcast( -+ const Conv2dProblemVector &conv_test_sizes = Conv2dProblemVector(), -+ const Conv2dProblemVector &conv_blacklist_sizes = Conv2dProblemVector()) { -+ -+ bool passed = true; -+ -+ // -+ // Testbed object -+ // -+ -+ TestbedConv2dWithBroadcast testbed; -+ -+ // -+ // Get conv problem sizes to run conv operator -+ // -+ TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); -+ -+ // Vector of conv2d problem sizes to avoid duplicate runs -+ Conv2dProblemVector conv_tested_sizes; -+ -+ Conv2dProblemVector const *problem_vectors[] = { -+ &conv_test_sizes, // run user specified sizes -+ &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes -+ &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes -+#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED -+ &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled -+#endif -+ }; -+ -+ // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) -+ for (Conv2dProblemVector const * problem_vector : problem_vectors) { -+ -+ // Run conv testbed on default convolution sizes -+ for(auto conv_problem : *problem_vector) { -+ -+ // Skip blacklist and avoid duplicate problem sizes -+ if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || -+ std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { -+ continue; -+ } -+ -+ // -+ // Procedurally disable certain cases -+ // -+ -+ // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} -+ if ((ImplicitGemm::kConvolutionalOperator == -+ cutlass::conv::Operator::kDgrad) && -+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == -+ cutlass::conv::StrideSupport::kUnity)) { -+ if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { 
-+ continue; -+ } -+ } -+ -+#if 0 // relax restrictions on analytic strided dgrad -+ // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} -+ if ((ImplicitGemm::kConvolutionalOperator == -+ cutlass::conv::Operator::kDgrad) && -+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == -+ cutlass::conv::StrideSupport::kStrided)) { -+ if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { -+ continue; -+ } -+ } -+#endif -+ -+ // -+ // Test -+ // -+ // push back tested problem size to avoid re-running duplicates -+ conv_tested_sizes.push_back(conv_problem); -+ -+ // test mode = xcross -+ passed = testbed.run( -+ conv_problem, -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ // test mode = convolution -+ passed = testbed.run( -+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ -+ // CUTLASS DGRAD's *strided* specialization does not support split-k mode -+ if ((ImplicitGemm::kConvolutionalOperator == -+ cutlass::conv::Operator::kDgrad) && -+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == -+ cutlass::conv::StrideSupport::kStrided)) { -+ -+ passed = testbed.run( -+ cutlass::conv::Conv2dProblemSize( -+ {1, 56, 56, 8}, // input size (NHWC) -+ {8, 1, 1, 8}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1}), // dilation (dilation_h, dilation_w) -+ cutlass::conv::SplitKMode::kSerial, -+ cutlass::from_real(2.0), -+ cutlass::from_real(2.0)); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ return passed; -+ } -+ -+ if (!TestSplitK) -+ return passed; -+ -+ // Sweep split-k-slice using serial and parallel reduction with non-unity alpha and non-zero beta for -+ // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters -+ // which are absolutely necessary to catch functional bugs. The below code does provide the option to sweep -+ // alpha and beta for local testing, but only runs one value for alpha and beta.
-+ cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( -+ {1, 17, 11, 288}, // input size (NHWC) -+ {160, 3, 3, 288}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ ); -+ -+ cutlass::conv::SplitKMode split_k_modes [] = { -+ cutlass::conv::SplitKMode::kSerial -+ }; -+ -+ int split_k_slices[] = { -+ 1, 2, 3, 4, 201 -+ }; -+ -+ double problem_alpha[] = { -+ 2.0 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ for (auto split_k_mode : split_k_modes) { -+ for (auto split_k_slice : split_k_slices) { -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ passed = testbed.run( -+ conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), -+ split_k_mode, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta)); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace conv -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h b/3rdparty/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h -new file mode 100644 -index 0000000..4064648 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h -@@ -0,0 +1,643 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Implicit GEMM testbed -+*/ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+#include "cutlass/reduction/device/tensor_reduce.h" -+#include "cutlass/reduction/device/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+#include "conv2d_problems.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+ -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/device/convolution.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cache_testbed_output.h" -+ -+namespace test { -+namespace conv { -+namespace device { -+ -+template -+class TestbedConv2dWithReduction { -+public: -+ -+ using ElementA = typename Conv2d::ElementA; -+ using LayoutA = typename Conv2d::LayoutA; -+ using ElementB = typename Conv2d::ElementB; -+ using LayoutB = typename Conv2d::LayoutB; -+ using ElementC = typename Conv2d::ElementC; -+ using LayoutC = typename Conv2d::LayoutC; -+ using ElementAccumulator = typename Conv2d::ElementAccumulator; -+ using ElementCompute = typename Conv2d::ElementCompute; -+ using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; -+ using ElementT = typename EpilogueOutputOp::ElementTensor; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; -+ -+public: -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ -+ cutlass::HostTensor tensor_Reduction; -+ cutlass::HostTensor tensor_Tensor; -+ cutlass::HostTensor tensor_Final_Reduction; -+ -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ -+public: -+ -+ TestbedConv2dWithReduction( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { -+ -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ int scope = 2; -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope, -scope, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); -+ } -+ else { -+ } -+ } -+ -+ void initialize( -+ cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { -+ -+ tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); -+ tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); 
-+ tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ -+ tensor_Reduction.resize({ -+ 1, -+ 1, -+ (problem_size.N * problem_size.P * problem_size.Q - 1 + Conv2d::ThreadblockShape::kM) / Conv2d::ThreadblockShape::kM, -+ (problem_size.K) -+ }); -+ -+ tensor_Final_Reduction.resize({ -+ 1, -+ 1, -+ 1, -+ (problem_size.K) -+ }); -+ -+ tensor_Tensor.resize({(problem_size.N * problem_size.P * problem_size.Q), problem_size.K}); -+ -+ tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ -+ initialize_tensor(tensor_A.host_view(), init_A, seed); -+ initialize_tensor(tensor_B.host_view(), init_B, seed * 17); -+ initialize_tensor(tensor_C.host_view(), init_C, seed * 39); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ tensor_D_reference.sync_device(); -+ } -+ -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Conv2d::UnderlyingKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+#if 0 //display conv2d problem size for debugging -+ std::cout << problem_size << std::endl -+ << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl -+ << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? 
"(serial)" : "(parallel)") << std::endl -+ << std::endl; -+#endif -+ -+ initialize(problem_size); -+ -+ // configure the operator -+ Conv2d conv2d_op; -+ -+ typename Conv2d::Arguments conv2d_args( -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D_computed.device_ref(), -+ {alpha, beta}, -+ split_k_mode, -+ tensor_Reduction.device_data(), -+ tensor_Tensor.device_data(), -+ static_cast(tensor_Reduction.stride()[0]), -+ static_cast(tensor_Tensor.stride()[0]) -+ ); -+ -+ // find workspace requirement for parallel split-k reduction -+ size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); -+ -+ if (status != cutlass::Status::kSuccess) { -+ cudaError_t error = cudaGetLastError(); -+ std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; -+ return true; -+ } -+ -+ // conv2d operation with parallel split-k-mode -+ if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { -+ -+ // conv2d output is written to workspace in global memory -+ conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); -+ // accumulate mma for each cta in k-dimension (1.0 * A * B) -+ conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; -+ // update conv2d operator arguments -+ status = conv2d_op.update(conv2d_args, workspace.get()); -+ } -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ // run conv2d operator -+ status = conv2d_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ bool passed = false; -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) << " device reference error: " -+ << cudaGetErrorString(result); -+ -+ // Final reduction over the partial reduction tensor -+ using Functor = cutlass::plus; -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementAccumulator, -+ ElementAccumulator, -+ LayoutC, -+ Functor, -+ 8, -+ ElementAccumulator -+ >; -+ -+ TensorReduction reduction(tensor_Reduction.extent(), 2); -+ -+ cutlass::DeviceAllocation reduction_device_workspace(reduction.workspace_size()); -+ -+ status = reduction.reduce( -+ tensor_Final_Reduction.device_ref(), -+ tensor_Reduction.device_ref(), -+ reduction_device_workspace.get(), -+ ElementAccumulator()); -+ -+ EXPECT_EQ(status, cutlass::Status::kSuccess); -+ EXPECT_EQ(cudaDeviceSynchronize(), cudaSuccess); -+ -+ // -+ // Reference check -+ // -+ -+ tensor_D_computed.sync_host(); -+ -+#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED -+ -+ cutlass::reference::device::Conv2d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D_reference.device_ref(), -+ alpha, -+ beta); -+ -+ // sync host (copy device data to host) for dumping error output in case of mismatches -+ tensor_D_reference.sync_host(); -+ -+#else -+ -+ cutlass::reference::host::Conv2d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ tensor_C.host_ref(), -+ 
tensor_D_reference.host_ref(), -+ alpha, -+ beta); -+ -+#endif -+ -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view()); -+ -+ EXPECT_TRUE(passed); -+ -+ // -+ // Reference check on reduction results -+ // -+ -+ tensor_Reduction.sync_host(); -+ tensor_Final_Reduction.sync_host(); -+ -+ // compute backwards for reduction results -+ cutlass::HostTensor reference_Reduction; -+ reference_Reduction.resize({ -+ 1, -+ 1, -+ 1, -+ (problem_size.K) -+ }); -+ -+ for (int k = 0; k < problem_size.K; ++k) { -+ ElementAccumulator reduced_value = ElementAccumulator(); -+ for (int n = 0; n < problem_size.N; ++n) { -+ for (int p = 0; p < problem_size.P; ++p) { -+ for (int q = 0; q < problem_size.Q; ++q) { -+ reduced_value += tensor_D_reference.at({n, p, q, k}); -+ } -+ } -+ } -+ reference_Reduction.at({0, 0, 0, k}) = reduced_value; -+ } -+ -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_Final_Reduction.host_view(), -+ reference_Reduction.host_view() -+ ); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_Conv2d_ImplicitGemm_device_" -+ << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") -+ << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : -+ (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) -+ << "nhwc_" -+ << problem_size.N << "x" -+ << problem_size.H << "x" -+ << problem_size.W << "x" -+ << problem_size.C -+ << "_krsc_" -+ << problem_size.K << "x" -+ << problem_size.R << "x" -+ << problem_size.S << "x" -+ << problem_size.C -+ << "_padding_" -+ << problem_size.pad_h << "x" -+ << problem_size.pad_w -+ << "_stride_" -+ << problem_size.stride_h << "x" -+ << problem_size.stride_w -+ << "_dilation_" -+ << problem_size.dilation_h << "x" -+ << problem_size.dilation_w << "_" -+ << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? 
"xcorr_" : "conv_") -+ << Conv2d::ThreadblockShape::kM << "x" -+ << Conv2d::ThreadblockShape::kN << "x" -+ << Conv2d::ThreadblockShape::kK << "_" -+ << Conv2d::WarpShape::kM << "x" -+ << Conv2d::WarpShape::kN << "x" -+ << Conv2d::WarpShape::kK << ".txt"; -+ -+ std::cout << fname.str() << std::endl; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size << std::endl; -+ -+ results -+ << "\nA:\n" << tensor_A.host_view() << "\n" -+ << "\nB:\n" << tensor_B.host_view() << "\n" -+ << "\nC:\n" << tensor_C.host_view() << "\n" -+ << "\nD reference:\n" << tensor_D_reference.host_view() << "\n" -+ << "\nD computed:\n" << tensor_D_computed.host_view() << "\n" -+ << "\nreduction reference:\n" << reference_Reduction.host_view() << "\n" -+ << "\nreduction computed:\n" << tensor_Reduction.host_view() << "\n"; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////////////// -+// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference -+// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes -+// Additionaly, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes -+// (conv_blacklist_sizes) -+///////////////////////////////////////////////////////////////////////////////////////////////////////////// -+template -+bool TestAllConv2dWithReduction( -+ const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), -+ const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { -+ -+ bool passed = true; -+ -+ // -+ // Testbed object -+ // -+ -+ TestbedConv2dWithReduction testbed; -+ -+ // -+ // Get conv problem sizes to run conv operator -+ // -+ TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); -+ -+ // Vector of conv2d problem sizes to avoid duplicate runs -+ Conv2dProblemVector conv_tested_sizes; -+ -+ Conv2dProblemVector const *problem_vectors[] = { -+ &conv_test_sizes, // run user specified sizes -+ &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes -+ &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes -+#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED -+ &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled -+#endif -+ }; -+ -+ // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) -+ for (Conv2dProblemVector const * problem_vector : problem_vectors) { -+ -+ // Run conv testbed on default convolution sizes -+ for(auto conv_problem : *problem_vector) { -+ -+ // Skip blacklist and avoid duplicate problem sizes -+ if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || -+ std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { -+ continue; -+ } -+ -+ // -+ // Procedurally disable certain cases -+ // -+ -+ // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} -+ if ((ImplicitGemm::kConvolutionalOperator == -+ cutlass::conv::Operator::kDgrad) && -+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == -+ cutlass::conv::StrideSupport::kUnity)) { -+ if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { -+ continue; -+ } -+ } -+ -+#if 0 // relax restrictions on analytic strided dgrad -+ // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} -+ if 
((ImplicitGemm::kConvolutionalOperator == -+ cutlass::conv::Operator::kDgrad) && -+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == -+ cutlass::conv::StrideSupport::kStrided)) { -+ if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { -+ continue; -+ } -+ } -+#endif -+ -+ // -+ // Test -+ // -+ // push back tested problem size to avoid re-running duplicates -+ conv_tested_sizes.push_back(conv_problem); -+ -+ // test mode = xcross -+ passed = testbed.run( -+ conv_problem, -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ // test mode = convolution -+ passed = testbed.run( -+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ -+ // CUTLASS DGRAD's *strided* specialization does not support split-k mode -+ if ((ImplicitGemm::kConvolutionalOperator == -+ cutlass::conv::Operator::kDgrad) && -+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == -+ cutlass::conv::StrideSupport::kStrided)) { -+ -+ passed = testbed.run( -+ cutlass::conv::Conv2dProblemSize( -+ {1, 56, 56, 8}, // input size (NHWC) -+ {8, 1, 1, 8}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1}), // dilation (dilation_h, dilation_w) -+ cutlass::conv::SplitKMode::kSerial, -+ cutlass::from_real(2.0), -+ cutlass::from_real(2.0)); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ return passed; -+ } -+ -+ // Sweep split-k-slice using serial and parallel reduction with non-unity alpha and non-zero beta for -+ // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters -+ // which are absolutely necessary to catch functional bugs. The below code does provide the option to sweep -+ // alpha and beta for local testing, but only runs one value for alpha and beta. -+ cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( -+ {1, 17, 11, 288}, // input size (NHWC) -+ {160, 3, 3, 288}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ ); -+ -+ // Parallel SplitK is not tested.
-+ cutlass::conv::SplitKMode split_k_modes [] = { -+ cutlass::conv::SplitKMode::kSerial, -+ }; -+ -+ int split_k_slices[] = { -+ 1, 2, 3, 4, 201 -+ }; -+ -+ double problem_alpha[] = { -+ 2.0 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ for (auto split_k_mode : split_k_modes) { -+ for (auto split_k_slice : split_k_slices) { -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ passed = testbed.run( -+ conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), -+ split_k_mode, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta)); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace conv -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv3d_dgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv3d_dgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..909a1df ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv3d_dgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,126 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv3d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv3d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+TEST(SM80_Device_Conv3d_Dgrad_Analytic_ImplicitGemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x4_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv3dDgradKernel = typename cutlass::conv::kernel::DefaultConv3dDgrad< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd -+ >::Kernel; -+ -+ using Conv3dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM80_Device_Conv3d_Dgrad_Optimized_ImplicitGemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x4_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv3dDgradKernel = typename cutlass::conv::kernel::DefaultConv3dDgrad< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv3dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -+ -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv3d_dgrad_implicit_gemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv3d_dgrad_implicit_gemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..6864bc4 ---- /dev/null -+++ 
b/3rdparty/cutlass/test/unit/conv/device/conv3d_dgrad_implicit_gemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,128 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv3d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv3d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv3d_Dgrad_Analytic_ImplicitGemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv3dDgradKernel = typename cutlass::conv::kernel::DefaultConv3dDgrad< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv3dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv3d_Dgrad_Optimized_ImplicitGemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv3dDgradKernel = typename cutlass::conv::kernel::DefaultConv3dDgrad< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv3dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm75.cu 
b/3rdparty/cutlass/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..7484e8d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm75.cu -@@ -0,0 +1,86 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv3d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv3d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv3d_Fprop_Analytic_ImplicitGemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv3dFpropKernel = typename cutlass::conv::kernel::DefaultConv3dFprop< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::Kernel; -+ -+ using Conv3dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..24990ff ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,165 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv3d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv3d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+TEST(SM80_Device_Conv3d_Fprop_Analytic_ImplicitGemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x4_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv3dFpropKernel = typename cutlass::conv::kernel::DefaultConv3dFprop< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd -+ >::Kernel; -+ -+ using Conv3dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM80_Device_Conv3d_Fprop_Optimized_ImplicitGemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x4_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv3dFpropKernel = typename cutlass::conv::kernel::DefaultConv3dFprop< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv3dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run 
all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv3d_Fprop_Optimized_ImplicitGemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, -+ 64x256_32x4_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv3dFpropKernel = typename cutlass::conv::kernel::DefaultConv3dFprop< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv3dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -+ -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv3d_fprop_implicit_gemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv3d_fprop_implicit_gemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..723e15e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv3d_fprop_implicit_gemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,127 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv3d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv3d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv3d_Fprop_Analytic_ImplicitGemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv3dFpropKernel = typename cutlass::conv::kernel::DefaultConv3dFprop< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+ >::Kernel; -+ -+ using Conv3dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv3d_Fprop_Optimized_ImplicitGemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv3dFpropKernel = typename cutlass::conv::kernel::DefaultConv3dFprop< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 
cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv3dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv3d_problems.h b/3rdparty/cutlass/test/unit/conv/device/conv3d_problems.h -new file mode 100644 -index 0000000..3c0512e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv3d_problems.h -@@ -0,0 +1,271 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Implicit GEMM testbed sizes for Conv2d problem -+*/ -+#pragma once -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+ -+namespace test { -+namespace conv { -+namespace device { -+ -+using Conv3dProblemVector = std::vector; -+ -+//////////////////////////////////////////////////////////////////////////// -+/// Structure TestbedConv3dProblemSizes initializes and holds conv default and -+/// important network sizes -+//////////////////////////////////////////////////////////////////////////// -+struct TestbedConv3dProblemSizes { -+ -+ // -+ // Data members -+ // -+ int minimum_channel_size; -+ Conv3dProblemVector conv3d_default_sizes; -+ Conv3dProblemVector conv3d_vnet_medical_sizes; -+ -+ // -+ // Methods -+ // -+ /// Default ctor -+ TestbedConv3dProblemSizes(int minimum_channel_size_ = 64): minimum_channel_size (minimum_channel_size_) { -+ -+ initialize_conv3d_default_sizes(); -+ initialize_conv3d_vnet_medical_sizes(conv3d_vnet_medical_sizes, 1 /*batch-size*/); -+ -+ filter_all(); -+ } -+ -+ /// Eliminates some illegal cases -+ void filter_all() { -+ -+ Conv3dProblemVector *problems_vectors[] = { -+ &conv3d_default_sizes, -+ &conv3d_vnet_medical_sizes -+ }; -+ -+ for (Conv3dProblemVector *problems : problems_vectors) { -+ Conv3dProblemVector filtered; -+ -+ for (cutlass::conv::Conv3dProblemSize const & problem : *problems) { -+ if (!(problem.C % minimum_channel_size)) { -+ filtered.push_back(problem); -+ } -+ } -+ -+ *problems = filtered; -+ } -+ } -+ -+ // Add a few standard convolution problem sizes -+ void initialize_conv3d_default_sizes() { -+ -+ conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( -+ {1, 1, 3, 3, minimum_channel_size}, // input size (NDHWC) -+ {8, 1, 1, 1, minimum_channel_size}, // filter size (KTRSC) -+ cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( -+ {1, 1, 1, 8, minimum_channel_size}, // input size (NDHWC) -+ {8, 1, 1, 3, minimum_channel_size}, // filter size (KTRSC) -+ cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( -+ {1, 8, 8, 8, minimum_channel_size}, // input size (NDHWC) -+ {8, 3, 3, 3, minimum_channel_size}, // filter size (KTRSC) -+ cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( -+ {1, 16, 16, 16, minimum_channel_size}, // input size (NDHWC) -+ {8, 3, 3, 3, minimum_channel_size}, // filter size (KTRSC) -+ cutlass::Coord<3>({1, 
1, 1}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( -+ {1, 1, 15, 19, 160}, // input size (NDHWC) -+ {224, 1, 3, 6, 160}, // filter size (KTRSC) -+ cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( -+ {1, 2, 1, 1, minimum_channel_size}, // input size (NDHWC) -+ {8, 2, 1, 1, minimum_channel_size}, // filter size (KTRSC) -+ cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( -+ {1, 1, 7, 7, minimum_channel_size}, // input size (NDHWC) -+ {16, 1, 3, 3, minimum_channel_size}, // filter size (KTRSC) -+ cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ -+ conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( -+ {1, 11, 15, 19, 64}, // input size (NDHWC) -+ {32, 4, 3, 6, 64}, // filter size (KTRSC) -+ cutlass::Coord<3>({2, 1, 3}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ } -+ -+ // Add vnet layers to unit testing sizes -+ void initialize_conv3d_vnet_medical_sizes(Conv3dProblemVector &conv3d_problem_vector, int batch_size = 1) { -+ -+ conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( -+ {batch_size, 32, 32, 32, 16}, // input size (NDHWC) -+ {32, 2, 2, 2, 16}, // filter size (KTRSC) -+ cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ -+ conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( -+ {batch_size, 16, 16, 16, 32}, // input size (NDHWC) -+ {32, 3, 3, 3, 32}, // filter size (KTRSC) -+ cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ -+ conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( -+ {batch_size, 16, 16, 16, 32}, // input size (NDHWC) -+ {64, 2, 2, 2, 32}, // filter size (KTRSC) -+ cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ -+ conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( -+ {batch_size, 8, 8, 8, 64}, // input size (NDHWC) -+ {64, 3, 3, 3, 64}, // filter size (KTRSC) -+ cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // 
dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ -+ conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( -+ {batch_size, 8, 8, 8, 64}, // input size (NDHWC) -+ {128, 2, 2, 2, 64}, // filter size (KTRSC) -+ cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ -+ conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( -+ {batch_size, 4, 4, 4, 128}, // input size (NDHWC) -+ {128, 3, 3, 3, 128}, // filter size (KTRSC) -+ cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ -+ conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( -+ {batch_size, 8, 8, 8, 128}, // input size (NDHWC) -+ {128, 3, 3, 3, 128}, // filter size (KTRSC) -+ cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ -+ conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( -+ {batch_size, 16, 16, 16, 64}, // input size (NDHWC) -+ {64, 3, 3, 3, 64}, // filter size (KTRSC) -+ cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ -+ conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( -+ {batch_size, 32, 32, 32, 16}, // input size (NDHWC) -+ {64, 2, 2, 2, 16}, // filter size (KTRSC) -+ cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ -+ conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( -+ {batch_size, 16, 16, 16, 32}, // input size (NDHWC) -+ {128, 2, 2, 2, 32}, // filter size (KTRSC) -+ cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ } -+ -+}; -+ -+} // namespace device -+} // namespace conv -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv3d_testbed.h b/3rdparty/cutlass/test/unit/conv/device/conv3d_testbed.h -new file mode 100644 -index 0000000..a5fa186 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv3d_testbed.h -@@ -0,0 +1,669 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Implicit GEMM testbed -+*/ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+#include "cutlass/reduction/device/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "cutlass/util/reference/host/convolution.h" -+ -+#include "cutlass/util/reference/host/tensor_compare.h" -+ -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+ -+#include "conv3d_problems.h" -+#include "cutlass/core_io.h" -+ -+#include "cache_testbed_output.h" -+ -+namespace test { -+namespace conv { -+namespace device { -+ -+template -+class TestbedConv3d { -+public: -+ -+ using ElementA = typename Conv3d::ElementA; -+ using LayoutA = typename Conv3d::LayoutA; -+ using ElementB = typename Conv3d::ElementB; -+ using LayoutB = typename Conv3d::LayoutB; -+ using ElementC = typename Conv3d::ElementC; -+ using LayoutC = typename Conv3d::LayoutC; -+ using ElementAccumulator = typename Conv3d::ElementAccumulator; -+ using ElementCompute = typename Conv3d::ElementCompute; -+ using EpilogueOutputOp = typename Conv3d::EpilogueOutputOp; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = Conv3d::kConvolutionalOperator; -+ -+ /// Reduction kernel -+ using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ typename EpilogueOutputOp::ElementAccumulator, -+ EpilogueOutputOp::kCount -+ >; -+ -+ using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< -+ cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, -+ EpilogueOutputOp, -+ ReductionOp -+ >; -+ -+ using ReductionDevice = cutlass::reduction::device::ReduceSplitK; -+ using ReductionStrideIndex = typename ReductionDevice::StrideIndex; -+ -+public: -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ 
cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;
-+  cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;
-+
-+public:
-+
-+  TestbedConv3d(
-+    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
-+    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
-+    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
-+    uint64_t seed_ = 2080
-+  ):
-+    init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {
-+
-+  }
-+
-+  /// Helper to initialize a tensor view
-+  template <typename Element, typename Layout>
-+  void initialize_tensor(
-+    cutlass::TensorView<Element, Layout> view,
-+    cutlass::Distribution::Kind dist_kind,
-+    uint64_t seed) {
-+
-+    if (dist_kind == cutlass::Distribution::Uniform) {
-+
-+      int scope;
-+      int bits = cutlass::sizeof_bits<Element>::value;
-+
-+      if (bits <= 8) {
-+        scope = 2;
-+      }
-+      else if (bits == 16) {
-+        scope = 4;
-+      }
-+      else {
-+        scope = 8;
-+      }
-+      cutlass::reference::host::TensorFillRandomUniform(
-+        view, seed, scope, -scope, 0);
-+    }
-+    else if (dist_kind == cutlass::Distribution::Identity) {
-+
-+      cutlass::reference::host::TensorFillIdentity(view);
-+    }
-+    else if (dist_kind == cutlass::Distribution::Gaussian) {
-+
-+      cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
-+    }
-+    else if (dist_kind == cutlass::Distribution::Sequential) {
-+
-+      cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
-+    }
-+    else {
-+    }
-+  }
-+
-+  void initialize(
-+    cutlass::conv::Conv3dProblemSize const &problem_size, uint64_t seed = 2019) {
-+
-+    tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size));
-+    tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
-+    tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
-+    tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
-+    tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
-+
-+    initialize_tensor(tensor_A.host_view(), init_A, seed);
-+    initialize_tensor(tensor_B.host_view(), init_B, seed * 17);
-+    initialize_tensor(tensor_C.host_view(), init_C, seed * 39);
-+
-+    tensor_A.sync_device();
-+    tensor_B.sync_device();
-+    tensor_C.sync_device();
-+    tensor_D_computed.sync_device();
-+    tensor_D_reference.sync_device();
-+  }
-+
-+  bool sufficient() const {
-+    //
-+    // Determine SMEM requirements and waive if not satisfied
-+    //
-+
-+    int smem_size = int(sizeof(typename Conv3d::UnderlyingKernel::SharedStorage));
-+
-+    cudaDeviceProp properties;
-+    int device_idx;
-+    cudaError_t result = cudaGetDevice(&device_idx);
-+
-+    if (result != cudaSuccess) {
-+      throw std::runtime_error("cudaGetDevice() API call failed.");
-+    }
-+
-+    result = cudaGetDeviceProperties(&properties, device_idx);
-+
-+    if (result != cudaSuccess) {
-+      throw std::runtime_error("cudaGetDeviceProperties() failed");
-+    }
-+
-+    if (properties.sharedMemPerBlockOptin < smem_size) {
-+      return false;
-+    }
-+
-+    return true;
-+  }
-+
-+
-+  /// Executes one test
-+  bool run(
-+    cutlass::conv::Conv3dProblemSize const &problem_size,
-+    cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
-+    ElementCompute alpha = ElementCompute(1),
-+    ElementCompute beta = ElementCompute()) {
-+
-+
-+    // Waive test if insufficient CUDA device
-+    if (!sufficient()) {
-+      if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) {
-+        std::cerr << "Test waived due to insufficient CUDA device." << std::endl;
-+      }
-+      return true;
-+    }
-+
-+#if 0 //display conv2d problem size for debugging
-+    std::cout << problem_size << std::endl
-+              << "alpha, beta: (" << float(alpha) << ", " << float(beta) << ")" << std::endl
-+              << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl
-+              << std::endl;
-+#endif
-+
-+    initialize(problem_size);
-+
-+    // configure the operator
-+    Conv3d conv3d_op;
-+
-+    typename Conv3d::Arguments conv3d_args(
-+      problem_size,
-+      tensor_A.device_ref(),
-+      tensor_B.device_ref(),
-+      tensor_C.device_ref(),
-+      tensor_D_computed.device_ref(),
-+      {alpha, beta},
-+      split_k_mode
-+    );
-+
-+    // find workspace requirement for parallel split-k reduction
-+    size_t workspace_size = Conv3d::get_workspace_size(conv3d_args);
-+
-+    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
-+
-+    cutlass::Status status = conv3d_op.initialize(conv3d_args, workspace.get());
-+
-+    if (status != cutlass::Status::kSuccess) {
-+      cudaError_t error = cudaGetLastError();
-+      std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
-+      return true;
-+    }
-+
-+    // conv3d operation with parallel split-k-mode
-+    if (split_k_mode == cutlass::conv::SplitKMode::kParallel) {
-+
-+      // conv3d output is written to workspace in global memory
-+      conv3d_args.ref_D.reset(reinterpret_cast(workspace.get()));
-+      // accumulate mma for each cta in k-dimension (1.0 * A * B)
-+      conv3d_args.output_op = {1.0, 0.0};
-+      // update conv3d operator arguments
-+      status = conv3d_op.update(conv3d_args, workspace.get());
-+    }
-+
-+    EXPECT_TRUE(status == cutlass::Status::kSuccess);
-+    if (status != cutlass::Status::kSuccess) {
-+      return false;
-+    }
-+
-+    // run conv3d operator
-+    status = conv3d_op();
-+
-+    EXPECT_TRUE(status == cutlass::Status::kSuccess);
-+    if (status != cutlass::Status::kSuccess) {
-+      return false;
-+    }
-+
-+    if (split_k_mode == cutlass::conv::SplitKMode::kParallel) {
-+
-+      // configure parallel reduction operator
-+      ReductionDevice reduction_op;
-+
-+      typename ReductionDevice::Arguments reduction_args(
-+        cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(),
-+        problem_size.split_k_slices,
-+        cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size),
-+        {
-+          reinterpret_cast (workspace.get()),
-+          ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx])
-+        },
-+        {
-+          tensor_D_computed.device_data(),
-+          ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx])
-+        },
-+        {
-+          tensor_C.device_data(),
-+          ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx])
-+        },
-+        // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C
-+        {alpha, beta}
-+      );
-+
-+      status = reduction_op.initialize(reduction_args, nullptr);
-+
-+      EXPECT_TRUE(status == cutlass::Status::kSuccess);
-+      if (status != cutlass::Status::kSuccess) {
-+        return false;
-+      }
-+
-+      // run parallel reduction kernel
-+      status = reduction_op();
-+
-+      EXPECT_TRUE(status == cutlass::Status::kSuccess);
-+      if (status != cutlass::Status::kSuccess) {
-+        return false;
-+      }
-+    }
-+    bool passed = false;
-+
-+    cudaError_t result = cudaDeviceSynchronize();
-+    EXPECT_EQ(result, cudaSuccess) << " device reference error: "
-+                                   << cudaGetErrorString(result);
-+
-+    tensor_D_computed.sync_host();
-+
-+    //
-+    // Reference check - support caching results
-+    //
-+
-+    CachedTestKey cached_test_key = CreateCachedConv3dTestKey<
ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ ElementCompute -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ alpha, -+ beta, -+ tensor_A.host_view(), -+ tensor_B.host_view(), -+ tensor_C.host_view() -+ ); -+ -+ // -+ // Look for the cached key -+ // -+ -+ bool cached_result_loaded = false; -+ CachedTestResult cached_test_result; -+ -+ std::string conv2d_result_cache_name = -+ std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ -+ CachedTestResultListing cached_results(conv2d_result_cache_name); -+ -+ auto cached = cached_results.find(cached_test_key); -+ -+ cached_result_loaded = cached.first; -+ if (cached_result_loaded) { -+ cached_test_result = cached.second; -+ } -+ } -+ -+ if (!cached_result_loaded) { -+ -+#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED -+ -+ cutlass::reference::device::Conv3d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ ElementCompute -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D_reference.device_ref(), -+ alpha, -+ beta -+ ); -+ -+ // sync host (copy device data to host) for dumping error output in case of mismatches -+ tensor_D_reference.sync_host(); -+ -+#else -+ cutlass::reference::host::Conv3d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ ElementCompute -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ tensor_C.host_ref(), -+ tensor_D_reference.host_ref(), -+ alpha, -+ beta -+ ); -+#endif -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ -+ cached_test_result.D = TensorHash(tensor_D_reference.host_view()); -+ -+ CachedTestResultListing cached_results(conv2d_result_cache_name); -+ -+ cached_results.append(cached_test_key, cached_test_result); -+ cached_results.write(conv2d_result_cache_name); -+ } -+ } // if (!cached_result_loaded) -+ -+ uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ passed = (tensor_D_hash == cached_test_result.D); -+ -+ EXPECT_EQ(tensor_D_hash, cached_test_result.D) -+ << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; -+ } -+ else { -+ -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view()); -+ } -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_Conv3d_ImplicitGemm_device_" -+ << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") -+ << (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : -+ (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? 
"dgrad_" : "wgrad_")) -+ << "ndhwc_" -+ << problem_size.N << "x" -+ << problem_size.D << "x" -+ << problem_size.H << "x" -+ << problem_size.W << "x" -+ << problem_size.C -+ << "_ktrsc_" -+ << problem_size.K << "x" -+ << problem_size.T << "x" -+ << problem_size.R << "x" -+ << problem_size.S << "x" -+ << problem_size.C -+ << "_padding_" -+ << problem_size.pad_d << "x" -+ << problem_size.pad_h << "x" -+ << problem_size.pad_w -+ << "_stride_" -+ << problem_size.stride_d << "x" -+ << problem_size.stride_h << "x" -+ << problem_size.stride_w -+ << "_dilation_" -+ << problem_size.dilation_d << "x" -+ << problem_size.dilation_h << "x" -+ << problem_size.dilation_w << "_" -+ << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") -+ << Conv3d::ThreadblockShape::kM << "x" -+ << Conv3d::ThreadblockShape::kN << "x" -+ << Conv3d::ThreadblockShape::kK << "_" -+ << Conv3d::WarpShape::kM << "x" -+ << Conv3d::WarpShape::kN << "x" -+ << Conv3d::WarpShape::kK << ".txt"; -+ -+ std::cout << fname.str() << std::endl; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size << std::endl; -+ -+ results -+ << "\nA:\n" << tensor_A.host_view() << "\n" -+ << "\nB:\n" << tensor_B.host_view() << "\n" -+ << "\nC:\n" << tensor_C.host_view() << "\n"; -+ -+ -+ results << "\nD reference (hash: " << cached_test_result.D << ")\n"; -+ -+ if (!cached_result_loaded) { -+ results -+ << tensor_D_reference.host_view() << "\n"; -+ } -+ -+ results -+ << "\nD computed (hash: " << tensor_D_hash << ")\n" -+ << tensor_D_computed.host_view() << "\n"; -+ -+ } -+ -+ return passed; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////////////// -+// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference -+// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes -+// Additionaly, each conv3d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes -+// (conv_blacklist_sizes) -+///////////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+bool TestAllConv3d( -+ const Conv3dProblemVector & conv_test_sizes = Conv3dProblemVector(), -+ const Conv3dProblemVector & conv_blacklist_sizes = Conv3dProblemVector()) { -+ -+ bool passed = true; -+ -+ // -+ // Testbed object -+ // -+ -+ //TestbedConv3d testbed(cutlass::Distribution::Sequential, cutlass::Distribution::Sequential, cutlass::Distribution::Sequential); -+ TestbedConv3d testbed; -+ -+ // -+ // Get conv problem sizes to run conv operator -+ // -+ TestbedConv3dProblemSizes conv3d_problems(128/cutlass::sizeof_bits::value); -+ -+ // Vector of conv3d problem sizes to avoid duplicate runs -+ Conv3dProblemVector conv_tested_sizes; -+ -+ Conv3dProblemVector const *problem_vectors[] = { -+ &conv3d_problems.conv3d_default_sizes, -+ &conv3d_problems.conv3d_vnet_medical_sizes, -+ &conv_test_sizes -+ }; -+ -+ // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) -+ for (Conv3dProblemVector const * problem_vector : problem_vectors) { -+ -+ // Run conv testbed on default convolution sizes -+ for(auto conv_problem : *problem_vector) { -+ -+ // Skip blacklist and avoid duplicate problem sizes -+ if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || -+ std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), 
conv_problem) != conv_tested_sizes.end()) { -+ continue; -+ } -+ -+ // -+ // Procedurally disable certain cases -+ // -+ -+ // CUTLASS DGRAD's unity stride specialization only support stride {1, 1, 1} -+ if ((ImplicitGemm::kConvolutionalOperator == -+ cutlass::conv::Operator::kDgrad) && -+ ((ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == -+ cutlass::conv::StrideSupport::kUnity) || -+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorB::kStrideSupport == -+ cutlass::conv::StrideSupport::kUnity))) { -+ if (!((conv_problem.stride_d == 1) && -+ (conv_problem.stride_h == 1) && -+ (conv_problem.stride_w == 1)) -+ ) { -+ continue; -+ } -+ } -+ -+ // -+ // Test -+ // -+ // push back tested problem size to avoid re-running duplicates -+ conv_tested_sizes.push_back(conv_problem); -+ -+ // test mode = xcross -+ passed = testbed.run( -+ conv_problem, -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ // test mode = convolution -+ passed = testbed.run( -+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ -+ // Sweep split-k-slice using serial reduction with non-unity alpha and non-zero beta for -+ // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters -+ // which are abolutely neccessary to catch functional bugs. The below code does provide option to sweep -+ // alpha and beta for local testing, but only runs one value for alpha and beta. -+ cutlass::conv::Conv3dProblemSize conv3d_split_k_test_size ( -+ {1, 8, 8, 8, 32}, // input size (NDHWC) -+ {32, 3, 3, 3, 32}, // filter size (KTRSC) -+ cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ ); -+ -+ cutlass::conv::SplitKMode split_k_modes [] = { -+ cutlass::conv::SplitKMode::kSerial, -+ cutlass::conv::SplitKMode::kParallel -+ }; -+ -+ int split_k_slices[] = { -+ 1, 2, 3, 4, 201 -+ }; -+ -+ double problem_alpha[] = { -+ 2.0 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ for (auto split_k_mode : split_k_modes) { -+ for (auto split_k_slice : split_k_slices) { -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ passed = testbed.run( -+ conv3d_split_k_test_size.reset_split_k_slices(split_k_slice), -+ split_k_mode, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta)); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace conv -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv3d_wgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/conv/device/conv3d_wgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..4da6f71 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv3d_wgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm75.cu -@@ -0,0 +1,84 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv3d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv3d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+TEST(SM75_Device_Conv3d_Wgrad_Analytic_ImplicitGemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv3dWgradKernel = typename cutlass::conv::kernel::DefaultConv3dWgrad< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::Kernel; -+ -+ using Conv3dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -+ -diff --git 
a/3rdparty/cutlass/test/unit/conv/device/conv3d_wgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv3d_wgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..9d4f228 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv3d_wgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,165 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv3d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv3d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+TEST(SM80_Device_Conv3d_Wgrad_Analytic_ImplicitGemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x4_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv3dWgradKernel = typename cutlass::conv::kernel::DefaultConv3dWgrad< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd -+ >::Kernel; -+ -+ using Conv3dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM80_Device_Conv3d_Wgrad_Optimized_ImplicitGemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x4_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv3dWgradKernel = typename cutlass::conv::kernel::DefaultConv3dWgrad< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv3dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv3d_Wgrad_Optimized_ImplicitGemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, -+ 64x256_32x4_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv3dWgradKernel = typename cutlass::conv::kernel::DefaultConv3dWgrad< -+ 
ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv3dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -+ -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv3d_wgrad_implicit_gemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv3d_wgrad_implicit_gemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..abcb58b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv3d_wgrad_implicit_gemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,126 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv3d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv3d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv3d_Wgrad_Analytic_ImplicitGemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv3dWgradKernel = typename cutlass::conv::kernel::DefaultConv3dWgrad< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+ >::Kernel; -+ -+ using Conv3dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv3d_Wgrad_Optimized_ImplicitGemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv3dWgradKernel = typename cutlass::conv::kernel::DefaultConv3dWgrad< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv3dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h b/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h -new file mode 100644 -index 0000000..1c2506c ---- /dev/null -+++ 
b/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h -@@ -0,0 +1,473 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Depthwise Direct Conv testbed -+*/ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cache_testbed_output.h" -+#include "conv2d_problems.h" -+#include "cutlass/conv/device/direct_convolution.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+namespace test { -+namespace conv { -+namespace device { -+ -+template -+class TestbedDepthwiseDirectConv2d { -+ public: -+ -+ using ElementA = typename Conv2d::ElementA; -+ using LayoutA = typename Conv2d::LayoutA; -+ using ElementB = typename Conv2d::ElementB; -+ using LayoutB = typename Conv2d::LayoutB; -+ using ElementC = typename Conv2d::ElementC; -+ using LayoutC = typename Conv2d::LayoutC; -+ using ElementAccumulator = typename Conv2d::ElementAccumulator; -+ using ElementCompute = typename Conv2d::ElementCompute; -+ using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; -+ -+ public: -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_reordered_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ -+ int tested_problem_count; -+ -+ public: -+ TestbedDepthwiseDirectConv2d(cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080) -+ : init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_), tested_problem_count(0) {} -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor(cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ int scope; -+ int bits = cutlass::sizeof_bits::value; -+ -+ if (bits <= 8) { -+ scope = 2; -+ } else if (bits == 16) { -+ if (cutlass::sizeof_bits::value <= 16) { -+ scope = 3; -+ } else { -+ scope = 5; -+ } -+ } else { -+ scope = 8; -+ } -+ cutlass::reference::host::TensorFillRandomUniform(view, seed, scope, -scope, 0); -+ } else if (dist_kind == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(view); -+ -+ } else if (dist_kind == cutlass::Distribution::Gaussian) { -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } else if (dist_kind == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); -+ } else { -+ } -+ } -+ -+ void initialize(cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { -+ tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); -+ tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); -+ tensor_reordered_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); -+ 
tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ -+ initialize_tensor(tensor_A.host_view(), init_A, seed); -+ initialize_tensor(tensor_B.host_view(), init_B, seed * 17); -+ initialize_tensor(tensor_reordered_B.host_view(), init_B, seed * 17); -+ initialize_tensor(tensor_C.host_view(), init_C, seed * 39); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_reordered_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ tensor_D_reference.sync_device(); -+ } -+ -+ bool sufficient(int smem_size) const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run(cutlass::conv::Conv2dProblemSize const &problem_size, -+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, -+ ElementCompute alpha = ElementCompute(1.5), -+ ElementCompute beta = ElementCompute(1)) { -+ // increment tested problem count run by the testbed -+ tested_problem_count++; -+ -+#if 0 // display conv2d problem size for debugging -+ std::cout << problem_size << std::endl -+ << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl -+ << "split_k_mode: " -+ << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") -+ << std::endl -+ << std::endl; -+#endif -+ -+ initialize(problem_size); -+ -+ // configure the operator -+ Conv2d conv2d_op; -+ -+ typename Conv2d::Arguments conv2d_args(problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D_computed.device_ref(), -+ {alpha, beta}, -+ tensor_reordered_B.device_ref(), -+ split_k_mode); -+ -+ // find workspace requirement for parallel split-k reduction -+ size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = conv2d_op.can_implement(problem_size); -+ -+ if (status != cutlass::Status::kSuccess) { -+ cudaError_t error = cudaGetLastError(); -+ std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; -+ return true; -+ } -+ -+ status = conv2d_op.initialize(conv2d_args, workspace.get()); -+ -+ if (status != cutlass::Status::kSuccess) { -+ cudaError_t error = cudaGetLastError(); -+ std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; -+ return true; -+ } -+ -+ if (!sufficient(conv2d_op.get_smem_size())) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ // run conv2d operator -+ status = conv2d_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run." 
<< std::endl; -+ return false; -+ } -+ -+ bool passed = false; -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) << " device reference error: " << cudaGetErrorString(result); -+ -+ tensor_D_computed.sync_host(); -+ -+ // -+ // Reference check - support caching results -+ // -+ -+ CachedTestKey cached_test_key = -+ CreateCachedConv2dTestKey(kConvolutionalOperator, -+ problem_size, -+ alpha, -+ beta, -+ tensor_A.host_view(), -+ tensor_B.host_view(), -+ tensor_C.host_view()); -+ -+ // -+ // Look for the cached key -+ // -+ -+ bool cached_result_loaded = false; -+ CachedTestResult cached_test_result; -+ -+ std::string conv2d_result_cache_name = -+ std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ -+ CachedTestResultListing cached_results(conv2d_result_cache_name); -+ -+ auto cached = cached_results.find(cached_test_key); -+ -+ cached_result_loaded = cached.first; -+ if (cached_result_loaded) { -+ cached_test_result = cached.second; -+ } -+ } -+ -+ if (!cached_result_loaded) { -+#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED -+ -+ cutlass::reference::device::Conv2d(kConvolutionalOperator, -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D_reference.device_ref(), -+ alpha, -+ beta); -+ -+ // sync host (copy device data to host) for dumping error output in case of mismatches -+ tensor_D_reference.sync_host(); -+ -+#else -+ -+ cutlass::reference::host::Conv2d(kConvolutionalOperator, -+ problem_size, -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ tensor_C.host_ref(), -+ tensor_D_reference.host_ref(), -+ alpha, -+ beta); -+ -+#endif -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ -+ cached_test_result.D = TensorHash(tensor_D_reference.host_view()); -+ -+ CachedTestResultListing cached_results(conv2d_result_cache_name); -+ -+ cached_results.append(cached_test_key, cached_test_result); -+ cached_results.write(conv2d_result_cache_name); -+ } -+ } // if (!cached_result_loaded) -+ -+ uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ passed = (tensor_D_hash == cached_test_result.D); -+ -+ EXPECT_EQ(tensor_D_hash, cached_test_result.D) -+ << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; -+ } -+ else { -+ -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view()); -+ } -+ -+ EXPECT_TRUE(passed); -+ -+ std::stringstream ss_problem_size_text; -+ ss_problem_size_text << "nhwc_" -+ << problem_size.N << "x" -+ << problem_size.H << "x" -+ << problem_size.W << "x" -+ << problem_size.C -+ << "_krsc_" -+ << problem_size.K << "x" -+ << problem_size.R << "x" -+ << problem_size.S << "x" -+ << problem_size.C -+ << "_padding_" -+ << problem_size.pad_h << "x" -+ << problem_size.pad_w -+ << "_stride_" -+ << problem_size.stride_h << "x" -+ << problem_size.stride_w -+ << "_dilation_" -+ << problem_size.dilation_h << "x" -+ << problem_size.dilation_w << "_" -+ << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_"); -+ -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_Conv2d_DirectConv_device_" -+ << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") -+ << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : -+ (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? 
"dgrad_" : "wgrad_")) -+ << ss_problem_size_text.str() -+ << Conv2d::ThreadblockShape::kM << "x" -+ << Conv2d::ThreadblockShape::kN << "x" -+ << Conv2d::ThreadblockShape::kK << "_" -+ << Conv2d::WarpShape::kM << "x" -+ << Conv2d::WarpShape::kN << "x" -+ << Conv2d::WarpShape::kK << ".txt"; -+ -+ std::cout << fname.str() << std::endl; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size << std::endl; -+ -+ results -+ << "\nA:\n" << tensor_A.host_view() << "\n" -+ << "\nB:\n" << tensor_B.host_view() << "\n" -+ << "\nC:\n" << tensor_C.host_view() << "\n"; -+ -+ results << "\nD reference (hash: " << cached_test_result.D << ")\n"; -+ -+ if (!cached_result_loaded) { -+ results -+ << tensor_D_reference.host_view() << "\n"; -+ } -+ -+ results -+ << "\nD computed (hash: " << tensor_D_hash << ")\n" -+ << tensor_D_computed.host_view() << "\n"; -+ -+ } -+ -+ return passed; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+bool TestSpecificDepthwiseDirectConv2d(const Conv2dProblemVector &problem_sizes) { -+ bool passed = true; -+ -+ // -+ // Testbed object -+ // -+ TestbedDepthwiseDirectConv2d testbed; -+ -+ // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) -+ for (auto conv_problem : problem_sizes) { -+ // -+ // Test -+ // -+ -+ // test mode = xcross -+ passed = testbed.run( -+ conv_problem, -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ // test mode = convolution -+ passed = testbed.run( -+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace conv -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu b/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu -new file mode 100644 -index 0000000..8efc73e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu -@@ -0,0 +1,429 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Depthwise Direct Conv interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_depthwise_fprop.h" -+#include "cutlass/conv/device/direct_convolution.h" -+ -+#include "conv2d_testbed.h" -+#include "depthwise_conv2d_direct_conv_testbed.h" -+ -+std::vector DepthwiseFpropProblemSizes_filter3x3() { -+ std::vector problems; -+ -+ for (int channels = 16; channels <= 512; channels += 16) { -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, channels}, // input size (NHWC) -+ {channels, 3, 3, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 16, // split_k_slices -+ channels // groups -+ )); -+ -+ // if(channels == 512 || channels == 16*14) -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 16, 16, channels}, // input size (NHWC) -+ {channels, 3, 3, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {2, 2}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 16, // split_k_slices -+ channels // groups -+ )); -+ } -+ -+ return problems; -+} -+ -+std::vector DepthwiseFpropProblemSizes_filter5x5() { -+ std::vector problems; -+ -+ for (int channels = 16; channels < 256; channels += 16) { -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 16, 16, channels}, // input size (NHWC) -+ {channels, 5, 5, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 16, // split_k_slices -+ channels // groups -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 112, 112, channels}, // input size (NHWC) -+ {channels, 5, 5, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 16, // split_k_slices -+ channels // groups -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 112, 112, channels}, // input size (NHWC) -+ {channels, 5, 5, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {2, 2}, // dilation 
(dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 16, // split_k_slices -+ channels // groups -+ )); -+ } -+ -+ return problems; -+} -+ -+std::vector DepthwiseFpropProblemSizes_filter5x37() { -+ std::vector problems; -+ -+ for (int channels = 16; channels < 256; channels += 16) { -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 128, 128, channels}, // input size (NHWC) -+ {channels, 5, 37, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 108, // split_k_slices -+ channels // groups -+ )); -+ } -+ -+ return problems; -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST( -+ SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_Optimized_f16nhwc_f16nhwc_f16nhwc_simt_f16, -+ 64x32_4_8x32_3x3) { -+ -+ using ElementInputA = cutlass::half_t; -+ using ElementInputB = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementComputeEpilogue = cutlass::half_t; -+ -+ using LayoutInputA = cutlass::layout::TensorNHWC; -+ using LayoutInputB = cutlass::layout::TensorNHWC; -+ using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+ // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU -+ // SM -+ using MMAOp = cutlass::arch::OpClassSimt; -+ -+ // This code section describes CUDA SM architecture number -+ using SmArch = cutlass::arch::Sm60; -+ -+ // This code section describes the groups a thread block will compute -+ constexpr int groups_per_cta = 32; -+ -+ // This code section describes the output tile a thread block will compute -+ using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; -+ -+ // This code section describes the filter shape -+ using FilterShape = cutlass::MatrixShape<3, 3>; -+ -+ // Threadblock tile shape -+ using ThreadblockShape = -+ cutlass::gemm::GemmShape; -+ -+ // This code section describes tile size a warp will computes -+ using WarpShape = cutlass::gemm::GemmShape<8, groups_per_cta, FilterShape::kCount>; -+ -+ // This code section describes the size of MMA op -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ -+ // This code section describes how threadblocks are scheduled on GPU -+ using SwizzleThreadBlock = -+ cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< -+ 1, -+ ThreadBlockOutputShape::kN, -+ ThreadBlockOutputShape::kH, -+ ThreadBlockOutputShape::kW>; -+ -+ // Number of pipelines you want to use -+ constexpr int NumStages = 4; -+ -+ // This code section describe iterator algorithm selected is Analytic or Optimized -+ static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = -+ cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+ constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ // This code section describes the epilogue part of the kernel, we use default value -+ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ kEpilogueElementsPerAccess, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. 
-+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue, // Data type for alpha/beta in linear combination -+ cutlass::epilogue::thread::ScaleType::Default>; -+ -+ using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm, -+ cutlass::conv::StrideSupport::kStrided>::Kernel; -+ -+ using Direct2dConv = cutlass::conv::device::DirectConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d( -+ DepthwiseFpropProblemSizes_filter3x3())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST( -+ SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_Optimized_f16nhwc_f16nhwc_f16nhwc_simt_f16, -+ 64x64_3_16x64_5x5) { -+ -+ using ElementInputA = cutlass::half_t; -+ using ElementInputB = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementComputeEpilogue = cutlass::half_t; -+ -+ using LayoutInputA = cutlass::layout::TensorNHWC; -+ using LayoutInputB = cutlass::layout::TensorNHWC; -+ using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+ // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU -+ // SM -+ using MMAOp = cutlass::arch::OpClassSimt; -+ -+ // This code section describes CUDA SM architecture number -+ using SmArch = cutlass::arch::Sm60; -+ -+ // This code section describes the groups a thread block will compute -+ constexpr int groups_per_cta = 64; -+ -+ // This code section describes the output tile a thread block will compute -+ using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; -+ -+ // This code section describes the filter shape -+ using FilterShape = cutlass::MatrixShape<5, 5>; -+ -+ // Threadblock tile shape -+ using ThreadblockShape = -+ cutlass::gemm::GemmShape; -+ -+ // This code section describes tile size a warp will computes -+ using WarpShape = cutlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>; -+ -+ // This code section describes the size of MMA op -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ -+ // This code section describes how threadblocks are scheduled on GPU -+ using SwizzleThreadBlock = -+ cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< -+ 1, -+ ThreadBlockOutputShape::kN, -+ ThreadBlockOutputShape::kH, -+ ThreadBlockOutputShape::kW>; -+ -+ // Number of pipelines you want to use -+ constexpr int NumStages = 3; -+ -+ // This code section describe iterator algorithm selected is Analytic or Optimized -+ static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = -+ cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+ constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ // This code section describes the epilogue part of the kernel, we use default value -+ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ kEpilogueElementsPerAccess, // The number of elements per vectorized. -+ // memory access. 
This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue, // Data type for alpha/beta in linear combination -+ cutlass::epilogue::thread::ScaleType::Default>; -+ -+ using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm, -+ cutlass::conv::StrideSupport::kStrided>::Kernel; -+ -+ using Direct2dConv = cutlass::conv::device::DirectConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d( -+ DepthwiseFpropProblemSizes_filter5x5())); -+} -+ -+#if 0 -+//////////////////////////////////////////////////////////////////////////////// -+TEST( -+ SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_Optimized_f16nhwc_f16nhwc_f16nhwc_simt_f16, -+ 64x32_3_16x32_5x37) { -+ -+ using ElementInputA = cutlass::half_t; -+ using ElementInputB = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementComputeEpilogue = cutlass::half_t; -+ -+ using LayoutInputA = cutlass::layout::TensorNHWC; -+ using LayoutInputB = cutlass::layout::TensorNHWC; -+ using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+ // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU -+ // SM -+ using MMAOp = cutlass::arch::OpClassSimt; -+ -+ // This code section describes CUDA SM architecture number -+ using SmArch = cutlass::arch::Sm60; -+ -+ // This code section describes the groups a thread block will compute -+ constexpr int groups_per_cta = 32; -+ -+ // This code section describes the output tile a thread block will compute -+ using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; -+ -+ // This code section describes the filter shape -+ using FilterShape = cutlass::MatrixShape<5, 37>; -+ -+ // Threadblock tile shape -+ using ThreadblockShape = -+ cutlass::gemm::GemmShape; -+ -+ // This code section describes tile size a warp will computes -+ using WarpShape = cutlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>; -+ -+ // This code section describes the size of MMA op -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ -+ // This code section describes how threadblocks are scheduled on GPU -+ using SwizzleThreadBlock = -+ cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< -+ 1, -+ ThreadBlockOutputShape::kN, -+ ThreadBlockOutputShape::kH, -+ ThreadBlockOutputShape::kW>; -+ -+ // Number of pipelines you want to use -+ constexpr int NumStages = 2; -+ -+ // This code section describe iterator algorithm selected is Analytic or Optimized -+ static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = -+ cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+ constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ // This code section describes the epilogue part of the kernel, we use default value -+ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. 
-+ kEpilogueElementsPerAccess, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue, // Data type for alpha/beta in linear combination -+ cutlass::epilogue::thread::ScaleType::Default>; -+ -+ using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm, -+ cutlass::conv::StrideSupport::kStrided>::Kernel; -+ -+ using Direct2dConv = cutlass::conv::device::DirectConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d( -+ DepthwiseFpropProblemSizes_filter5x37())); -+} -+#endif -+ -diff --git a/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu b/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu -new file mode 100644 -index 0000000..00bbafa ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu -@@ -0,0 +1,522 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Depthwise Direct Conv interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_depthwise_fprop.h" -+#include "cutlass/conv/device/direct_convolution.h" -+ -+#include "conv2d_testbed.h" -+#include "depthwise_conv2d_direct_conv_testbed.h" -+ -+std::vector DepthwiseFpropProblemSizes_filter3x3_stride1x1_dilation1x1() { -+ std::vector problems; -+ -+ for (int channels = 16; channels <= 512; channels += 16) { -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, channels}, // input size (NHWC) -+ {channels, 3, 3, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 16, // split_k_slices -+ channels // groups -+ )); -+ } -+ return problems; -+} -+ -+std::vector DepthwiseFpropProblemSizes_filter3x3_stride2x2_dilation2x2() { -+ std::vector problems; -+ for (int channels = 16; channels <= 512; channels += 16) { -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 16, 16, channels}, // input size (NHWC) -+ {channels, 3, 3, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {2, 2}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 16, // split_k_slices -+ channels // groups -+ )); -+ } -+ -+ return problems; -+} -+ -+std::vector DepthwiseFpropProblemSizes_filter5x5_stride1x1_dilation1x1() { -+ std::vector problems; -+ -+ for (int channels = 16; channels < 256; channels += 16) { -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 16, 16, channels}, // input size (NHWC) -+ {channels, 5, 5, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 16, // split_k_slices -+ channels // groups -+ )); -+ } -+ -+ return problems; -+ -+} -+ -+std::vector DepthwiseFpropProblemSizes_filter5x5_stride2x2_dilation2x2() { -+ std::vector problems; -+ for (int channels = 16; channels < 256; channels += 16) { -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 112, 112, channels}, // input size (NHWC) -+ {channels, 5, 5, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {2, 2}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 16, // split_k_slices -+ channels // groups -+ )); -+ } -+ -+ return problems; -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST( -+ SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_FixedStrideDilation_f16nhwc_f16nhwc_f16nhwc_simt_f16, -+ 64x32_4_8x32_Filter3x3_Stride1x1_Dilation1x1) { -+ -+ using ElementInputA = cutlass::half_t; -+ using ElementInputB = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementComputeEpilogue = cutlass::half_t; -+ -+ using LayoutInputA = cutlass::layout::TensorNHWC; -+ using LayoutInputB = cutlass::layout::TensorNHWC; -+ using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+ // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU -+ // SM -+ using MMAOp = 
cutlass::arch::OpClassSimt; -+ -+ // This code section describes CUDA SM architecture number -+ using SmArch = cutlass::arch::Sm60; -+ -+ // This code section describes the groups a thread block will compute -+ constexpr int groups_per_cta = 32; -+ -+ // This code section describes the output tile a thread block will compute -+ using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; -+ -+ // This code section describes the filter shape -+ using FilterShape = cutlass::MatrixShape<3, 3>; -+ -+ // Threadblock tile shape -+ using ThreadblockShape = -+ cutlass::gemm::GemmShape; -+ -+ // This code section describes tile size a warp will computes -+ using WarpShape = cutlass::gemm::GemmShape<8, groups_per_cta, FilterShape::kCount>; -+ -+ // This code section describes the size of MMA op -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ -+ // This code section describes how threadblocks are scheduled on GPU -+ using SwizzleThreadBlock = -+ cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< -+ 1, -+ ThreadBlockOutputShape::kN, -+ ThreadBlockOutputShape::kH, -+ ThreadBlockOutputShape::kW>; -+ -+ // Number of pipelines you want to use -+ constexpr int NumStages = 4; -+ -+ // This code section describe iterator algorithm selected is Analytic or Optimized -+ static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = -+ cutlass::conv::IteratorAlgorithm::kFixedStrideDilation; -+ using StrideShape = cutlass::MatrixShape<1, 1>; -+ using DilationShape = cutlass::MatrixShape<1, 1>; -+ -+ constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ // This code section describes the epilogue part of the kernel, we use default value -+ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ kEpilogueElementsPerAccess, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. 
-+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue, // Data type for alpha/beta in linear combination -+ cutlass::epilogue::thread::ScaleType::Default>; -+ -+ using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm, -+ cutlass::conv::StrideSupport::kFixed, -+ StrideShape, -+ DilationShape>::Kernel; -+ -+ using Direct2dConv = cutlass::conv::device::DirectConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d( -+ DepthwiseFpropProblemSizes_filter3x3_stride1x1_dilation1x1())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST( -+ SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_FixedStrideDilation_f16nhwc_f16nhwc_f16nhwc_simt_f16, -+ 64x32_4_8x32_Filter3x3_Stride2x2_Dilation2x2) { -+ -+ using ElementInputA = cutlass::half_t; -+ using ElementInputB = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementComputeEpilogue = cutlass::half_t; -+ -+ using LayoutInputA = cutlass::layout::TensorNHWC; -+ using LayoutInputB = cutlass::layout::TensorNHWC; -+ using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+ // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU -+ // SM -+ using MMAOp = cutlass::arch::OpClassSimt; -+ -+ // This code section describes CUDA SM architecture number -+ using SmArch = cutlass::arch::Sm60; -+ -+ // This code section describes the groups a thread block will compute -+ constexpr int groups_per_cta = 32; -+ -+ // This code section describes the output tile a thread block will compute -+ using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; -+ -+ // This code section describes the filter shape -+ using FilterShape = cutlass::MatrixShape<3, 3>; -+ -+ // Threadblock tile shape -+ using ThreadblockShape = -+ cutlass::gemm::GemmShape; -+ -+ // This code section describes tile size a warp will computes -+ using WarpShape = cutlass::gemm::GemmShape<8, groups_per_cta, FilterShape::kCount>; -+ -+ // This code section describes the size of MMA op -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ -+ // This code section describes how threadblocks are scheduled on GPU -+ using SwizzleThreadBlock = -+ cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< -+ 1, -+ ThreadBlockOutputShape::kN, -+ ThreadBlockOutputShape::kH, -+ ThreadBlockOutputShape::kW>; -+ -+ // Number of pipelines you want to use -+ constexpr int NumStages = 4; -+ -+ // This code section describe iterator algorithm selected is Analytic or Optimized -+ static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = -+ cutlass::conv::IteratorAlgorithm::kFixedStrideDilation; -+ using StrideShape = cutlass::MatrixShape<2, 2>; -+ using DilationShape = cutlass::MatrixShape<2, 2>; -+ -+ constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ // This code section describes the epilogue part of the kernel, we use default value -+ using EpilogueOp = 
cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ kEpilogueElementsPerAccess, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue, // Data type for alpha/beta in linear combination -+ cutlass::epilogue::thread::ScaleType::Default>; -+ -+ using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm, -+ cutlass::conv::StrideSupport::kFixed, -+ StrideShape, -+ DilationShape>::Kernel; -+ -+ using Direct2dConv = cutlass::conv::device::DirectConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d( -+ DepthwiseFpropProblemSizes_filter3x3_stride2x2_dilation2x2())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST( -+ SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_FixedStrideDilation_f16nhwc_f16nhwc_f16nhwc_simt_f16, -+ 64x64_3_16x64_Filter5x5_Stride1x1_Dilation1x1) { -+ -+ using ElementInputA = cutlass::half_t; -+ using ElementInputB = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementComputeEpilogue = cutlass::half_t; -+ -+ using LayoutInputA = cutlass::layout::TensorNHWC; -+ using LayoutInputB = cutlass::layout::TensorNHWC; -+ using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+ // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU -+ // SM -+ using MMAOp = cutlass::arch::OpClassSimt; -+ -+ // This code section describes CUDA SM architecture number -+ using SmArch = cutlass::arch::Sm60; -+ -+ // This code section describes the groups a thread block will compute -+ constexpr int groups_per_cta = 64; -+ -+ // This code section describes the output tile a thread block will compute -+ using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; -+ -+ // This code section describes the filter shape -+ using FilterShape = cutlass::MatrixShape<5, 5>; -+ -+ // Threadblock tile shape -+ using ThreadblockShape = -+ cutlass::gemm::GemmShape; -+ -+ // This code section describes tile size a warp will computes -+ using WarpShape = cutlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>; -+ -+ // This code section describes the size of MMA op -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ -+ // This code section describes how threadblocks are scheduled on GPU -+ using SwizzleThreadBlock = -+ cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< -+ 1, -+ ThreadBlockOutputShape::kN, -+ ThreadBlockOutputShape::kH, -+ ThreadBlockOutputShape::kW>; -+ -+ // Number of pipelines you want to use -+ constexpr int NumStages = 3; -+ -+ // This code section describe iterator algorithm selected is Analytic or Optimized -+ static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = -+ cutlass::conv::IteratorAlgorithm::kFixedStrideDilation; -+ using StrideShape = cutlass::MatrixShape<1, 1>; -+ 
using DilationShape = cutlass::MatrixShape<1, 1>; -+ -+ constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ // This code section describes the epilogue part of the kernel, we use default value -+ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ kEpilogueElementsPerAccess, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue, // Data type for alpha/beta in linear combination -+ cutlass::epilogue::thread::ScaleType::Default>; -+ -+ using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm, -+ cutlass::conv::StrideSupport::kFixed, -+ StrideShape, -+ DilationShape>::Kernel; -+ -+ using Direct2dConv = cutlass::conv::device::DirectConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d( -+ DepthwiseFpropProblemSizes_filter5x5_stride1x1_dilation1x1())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST( -+ SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_FixedStrideDilation_f16nhwc_f16nhwc_f16nhwc_simt_f16, -+ 64x64_3_16x64_Filter5x5_Stride2x2_Dilation2x2) { -+ -+ using ElementInputA = cutlass::half_t; -+ using ElementInputB = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementComputeEpilogue = cutlass::half_t; -+ -+ using LayoutInputA = cutlass::layout::TensorNHWC; -+ using LayoutInputB = cutlass::layout::TensorNHWC; -+ using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+ // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU -+ // SM -+ using MMAOp = cutlass::arch::OpClassSimt; -+ -+ // This code section describes CUDA SM architecture number -+ using SmArch = cutlass::arch::Sm60; -+ -+ // This code section describes the groups a thread block will compute -+ constexpr int groups_per_cta = 32; -+ -+ // This code section describes the output tile a thread block will compute -+ using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; -+ -+ // This code section describes the filter shape -+ using FilterShape = cutlass::MatrixShape<5, 5>; -+ -+ // Threadblock tile shape -+ using ThreadblockShape = -+ cutlass::gemm::GemmShape; -+ -+ // This code section describes tile size a warp will computes -+ using WarpShape = cutlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>; -+ -+ // This code section describes the size of MMA op -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ -+ // This code section describes how threadblocks are scheduled on GPU -+ using SwizzleThreadBlock = -+ cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< -+ 1, -+ ThreadBlockOutputShape::kN, -+ ThreadBlockOutputShape::kH, -+ ThreadBlockOutputShape::kW>; -+ -+ // Number of pipelines you want to use -+ constexpr int NumStages = 3; -+ -+ // This code 
section describe iterator algorithm selected is Analytic or Optimized -+ static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = -+ cutlass::conv::IteratorAlgorithm::kFixedStrideDilation; -+ using StrideShape = cutlass::MatrixShape<2, 2>; -+ using DilationShape = cutlass::MatrixShape<2, 2>; -+ -+ constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ // This code section describes the epilogue part of the kernel, we use default value -+ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ kEpilogueElementsPerAccess, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue, // Data type for alpha/beta in linear combination -+ cutlass::epilogue::thread::ScaleType::Default>; -+ -+ using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm, -+ cutlass::conv::StrideSupport::kFixed, -+ StrideShape, -+ DilationShape>::Kernel; -+ -+ using Direct2dConv = cutlass::conv::device::DirectConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d( -+ DepthwiseFpropProblemSizes_filter5x5_stride2x2_dilation2x2())); -+} -diff --git a/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu b/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu -new file mode 100644 -index 0000000..3c9cf10 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu -@@ -0,0 +1,221 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for Depthwise Direct Conv interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_depthwise_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+ -+std::vector DepthwiseFpropProblemSizes() { -+ -+std::vector problems; -+ -+for ( int channels = 16; channels < 256 ; channels+=16){ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, channels}, // input size (NHWC) -+ {channels, 3, 3, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 1, // split_k_slices -+ channels // groups -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 16, 16, channels}, // input size (NHWC) -+ {channels, 3, 3, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {2, 2}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 1, // split_k_slices -+ channels // groups -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 16, 16, channels}, // input size (NHWC) -+ {channels, 7, 7, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 1, // split_k_slices -+ channels // groups -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 112, 112, channels}, // input size (NHWC) -+ {channels, 7, 7, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 1, // split_k_slices -+ channels // groups -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 112, 112, channels}, // input size (NHWC) -+ {channels, 7, 7, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {2, 2} , // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 1, // split_k_slices -+ channels // groups -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 112, 112, channels}, // input size (NHWC) -+ {channels, 5, 5, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 1, // split_k_slices -+ channels // groups -+ )); -+ -+ 
problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 112, 112, channels}, // input size (NHWC) -+ {channels, 5, 5, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {2, 2} , // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 1, // split_k_slices -+ channels // groups -+ )); -+} -+ -+return problems; -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM60_Device_Depthwise_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_simt_f16, -+ 128x128_8x2_64x64x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ -+ /// Device-level depthwiseFpropKernel instance -+ using depthwiseFpropKernel = typename cutlass::conv::kernel::DefaultDepthwiseFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm60, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using DepthwiseFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d( -+ DepthwiseFpropProblemSizes())); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM60_Device_Depthwise_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_simt_f16, -+ 64x64_8x2_32x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ -+ /// Device-level depthwiseFpropKernel instance -+ using depthwiseFpropKernel = typename cutlass::conv::kernel::DefaultDepthwiseFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm60, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using DepthwiseFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d( -+ DepthwiseFpropProblemSizes())); -+ -+} -diff --git a/3rdparty/cutlass/test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu 
b/3rdparty/cutlass/test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..acf073f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,395 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_group_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Group_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, -+ SingleGroupPerCTA_128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ /// Device-level Conv2d instance -+ using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::GroupMode::kSingleGroup, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run group conv unit test sizes with device-level Conv2d instance -+ test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( -+ ThreadblockShape::kN, ThreadblockShape::kK, -+ 128/cutlass::sizeof_bits::value -+ ); -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_single_group_sizes)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Group_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, -+ SingleGroupPerCTA_64x64_64x3_32x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ /// Device-level Conv2d instance -+ using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ 
cutlass::conv::GroupMode::kSingleGroup, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run group conv unit test sizes with device-level Conv2d instance -+ test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( -+ ThreadblockShape::kN, ThreadblockShape::kK, -+ 128/cutlass::sizeof_bits::value -+ ); -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_single_group_sizes)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Group_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, -+ MultipleGroupPerCTA_128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ /// Device-level Conv2d instance -+ using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::GroupMode::kMultipleGroup, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run group conv unit test sizes with device-level Conv2d instance -+ test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( -+ ThreadblockShape::kN, ThreadblockShape::kK, -+ 128/cutlass::sizeof_bits::value -+ ); -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_multiple_group_sizes)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Group_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, -+ MutipleGroupPerCTA_64x64_64x3_32x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ /// Device-level Conv2d instance -+ using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ 
ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::GroupMode::kMultipleGroup, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run group conv unit test sizes with device-level Conv2d instance -+ test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( -+ ThreadblockShape::kN, ThreadblockShape::kK, -+ 128/cutlass::sizeof_bits::value -+ ); -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_multiple_group_sizes)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Group_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, -+ SingleGroupPerCTA_128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ /// Device-level Conv2d instance -+ using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::GroupMode::kSingleGroup, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run group conv unit test sizes with device-level Conv2d instance -+ test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( -+ ThreadblockShape::kN, ThreadblockShape::kK, -+ 128/cutlass::sizeof_bits::value -+ ); -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_single_group_sizes)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+// Optimized multistage singleGroup kernel -+TEST(SM80_Device_Conv2d_Group_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, -+ SingleGroupPerCTA_64x64_64x3_32x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ /// Device-level Conv2d instance -+ using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ 
cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::GroupMode::kSingleGroup, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run group conv unit test sizes with device-level Conv2d instance -+ test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( -+ ThreadblockShape::kN, ThreadblockShape::kK, -+ 128/cutlass::sizeof_bits::value -+ ); -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_single_group_sizes)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+// Optimized 2 stage singleGroup kernel -+TEST(SM80_Device_Conv2d_Group_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, -+ SingleGroupPerCTA_64x64_64x2_32x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ /// Device-level Conv2d instance -+ using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::GroupMode::kSingleGroup, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run group conv unit test sizes with device-level Conv2d instance -+ test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( -+ ThreadblockShape::kN, ThreadblockShape::kK, -+ 128/cutlass::sizeof_bits::value -+ ); -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_single_group_sizes)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/core/array.cu b/3rdparty/cutlass/test/unit/core/array.cu -new file mode 100644 -index 0000000..910d1af ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/array.cu -@@ -0,0 +1,261 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
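The group-convolution tests above are split between `GroupMode::kSingleGroup` and `GroupMode::kMultipleGroup`, and the testbed derives its problem sizes from `ThreadblockShape::kN` and `ThreadblockShape::kK`. The sketch below only illustrates that split, assuming a CTA owns a whole group when the group's channel count fills at least one K-tile; it is not CUTLASS's actual dispatch logic.

```cpp
#include <cstdio>

// Illustrative classification only; NOT the real CUTLASS group-mode selection.
enum class GroupMode { SingleGroup, MultipleGroup };

GroupMode classify(int input_channels, int groups, int threadblock_k) {
  int channels_per_group = input_channels / groups;
  // If one group's channels fill at least one K-tile, a CTA can work on a single
  // group; otherwise several groups have to share a CTA tile.
  return (channels_per_group >= threadblock_k) ? GroupMode::SingleGroup
                                               : GroupMode::MultipleGroup;
}

int main() {
  // The tests above use ThreadblockShape::kK == 64.
  std::printf("%d\n", static_cast<int>(classify(/*C=*/256, /*groups=*/4, /*kK=*/64)));  // SingleGroup
  std::printf("%d\n", static_cast<int>(classify(/*C=*/128, /*groups=*/8, /*kK=*/64)));  // MultipleGroup
  return 0;
}
```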
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types -+ and is safe to use in a union. -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/util/device_memory.h" -+#pragma warning( disable : 4800) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace core { -+ -+/// Each thread clears its array and writes to global memory. No PRMT instructions should -+/// be generated if Array is a multiple of 32 bits. -+template -+__global__ void test_array_clear(cutlass::Array *ptr) { -+ -+ cutlass::Array storage; -+ -+ storage.clear(); -+ -+ ptr[threadIdx.x] = storage; -+} -+ -+/// Each thread writes its thread index into the elements of its array and then writes the result -+/// to global memory. -+template -+__global__ void test_array_threadid(cutlass::Array *ptr) { -+ -+ cutlass::Array storage; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ storage.at(i) = T(int(threadIdx.x)); -+ } -+ -+ ptr[threadIdx.x] = storage; -+} -+ -+/// Each thread writes its thread index into the elements of its array and then writes the result -+/// to global memory. 
-+template -+__global__ void test_array_sequence(cutlass::Array *ptr) { -+ -+ cutlass::Array storage; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ storage.at(i) = T(i); -+ } -+ -+ ptr[threadIdx.x] = storage; -+} -+ -+} // namespace core -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class TestArray { -+public: -+ -+ // -+ // Data members -+ // -+ -+ /// Number of threads -+ int const kThreads = 32; -+ -+ typedef cutlass::Array ArrayTy; -+ -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ TestArray() { -+ -+ } -+ -+ /// Runs the test -+ void run() { -+ -+ /// Device memory containing output -+ cutlass::device_memory::allocation< ArrayTy > output(kThreads); -+ std::vector< ArrayTy > output_host(kThreads); -+ -+ dim3 grid(1,1); -+ dim3 block(kThreads, 1, 1); -+ -+ test::core::test_array_clear<<< grid, block >>>(output.get()); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << "CUDA error: " << cudaGetErrorString(result); -+ -+ // -+ // Verify contains all zeros -+ // -+ -+ cutlass::device_memory::copy_to_host(output_host.data(), output.get(), kThreads); -+ -+ result = cudaGetLastError(); -+ ASSERT_EQ(result, cudaSuccess) << "CUDA error: " << cudaGetErrorString(result); -+ -+ char const *ptr_host = reinterpret_cast(output_host.data()); -+ for (int i = 0; i < sizeof(ArrayTy) * kThreads; ++i) { -+ EXPECT_FALSE(ptr_host[i]); -+ } -+ -+ // -+ // Verify each element contains the low bits of the thread Id -+ // -+ -+ test::core::test_array_threadid<<< grid, block >>>(output.get()); -+ -+ result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << "CUDA error: " << cudaGetErrorString(result); -+ -+ cutlass::device_memory::copy_to_host(output_host.data(), output.get(), kThreads); -+ -+ result = cudaGetLastError(); -+ ASSERT_EQ(result, cudaSuccess) << "CUDA error: " << cudaGetErrorString(result); -+ -+ for (int i = 0; i < kThreads; ++i) { -+ T tid = T(i); -+ -+ ArrayTy thread = output_host.at(i); -+ -+ // Element-wise access -+ for (int j = 0; j < N; ++j) { -+ EXPECT_TRUE(tid == thread[j]); -+ } -+ -+ // Iterator access -+ for (auto it = thread.begin(); it != thread.end(); ++it) { -+ EXPECT_TRUE(tid == *it); -+ } -+ -+ // Range-based for -+ for (auto const & x : thread) { -+ EXPECT_TRUE(tid == x); -+ } -+ } -+ -+ // -+ // Verify each element -+ // -+ -+ test::core::test_array_sequence<<< grid, block >>>(output.get()); -+ -+ result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << "CUDA error: " << cudaGetErrorString(result); -+ -+ cutlass::device_memory::copy_to_host(output_host.data(), output.get(), kThreads); -+ -+ result = cudaGetLastError(); -+ ASSERT_EQ(result, cudaSuccess) << "CUDA error: " << cudaGetErrorString(result); -+ -+ for (int i = 0; i < kThreads; ++i) { -+ -+ ArrayTy thread = output_host.at(i); -+ -+ // Element-wise access -+ for (int j = 0; j < N; ++j) { -+ T got = T(j); -+ EXPECT_TRUE(got == thread[j]); -+ } -+ -+ // Iterator access -+ int j = 0; -+ for (auto it = thread.begin(); it != thread.end(); ++it, ++j) { -+ T got = T(j); -+ EXPECT_TRUE(got == *it); -+ } -+ -+ // Range-based for -+ j = 0; -+ for (auto const & x : thread) { -+ T got = T(j); -+ EXPECT_TRUE(got == x); -+ ++j; -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(Array, Int8x16) { -+ TestArray().run(); -+} -+ -+TEST(Array, Int32x4) { -+ TestArray().run(); -+} -+ -+#if 
__CUDA_ARCH__ >= 520 -+TEST(Array, Float16x8) { -+ TestArray().run(); -+} -+#endif -+ -+TEST(Array, FloatBF16x8) { -+ TestArray().run(); -+} -+ -+TEST(Array, FloatTF32x4) { -+ TestArray().run(); -+} -+ -+TEST(Array, Float32x4) { -+ TestArray().run(); -+} -+ -+TEST(Array, Int4x32) { -+ TestArray().run(); -+} -+ -+TEST(Array, Uint4x32) { -+ TestArray().run(); -+} -+ -+TEST(Array, Bin1x128) { -+ TestArray().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/core/bfloat16.cu b/3rdparty/cutlass/test/unit/core/bfloat16.cu -new file mode 100644 -index 0000000..6227250 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/bfloat16.cu -@@ -0,0 +1,218 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types -+ and is safe to use in a union. 
-+*/ -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/core_io.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/util/device_memory.h" -+#include "cutlass/util/host_tensor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+__global__ void convert_bf16_f32(cutlass::bfloat16_t *output, float const *input, int N) { -+ int tid = threadIdx.x + blockIdx.x * blockDim.x; -+ if (tid < N) { -+ output[tid] = static_cast(input[tid]); -+ } -+} -+ -+__global__ void convert_and_pack_bf16(cutlass::bfloat16_t *output, float const *input, int N) { -+ int tid = threadIdx.x + blockIdx.x * blockDim.x; -+ if (tid * 2 < N) { -+ -+ cutlass::NumericArrayConverter convert; -+ -+ cutlass::Array *dst_ptr = -+ reinterpret_cast *>(output + tid * 2); -+ -+ cutlass::Array const *src_ptr = -+ reinterpret_cast const *>(input + tid * 2); -+ -+ *dst_ptr = convert(*src_ptr); -+ } -+} -+ -+TEST(bfloat16_t, device_conversion) { -+ using T = cutlass::bfloat16_t; -+ using S = float; -+ -+ int const N = 256; -+ -+ cutlass::HostTensor destination({N, 1}); -+ cutlass::HostTensor source({N, 1}); -+ -+ for (int i = 0; i < N; ++i) { -+ source.at({i, 0}) = float(i - 128); -+ destination.at({i, 0}) = T(0); -+ } -+ -+ source.sync_device(); -+ destination.sync_device(); -+ -+ convert_bf16_f32<<< dim3(1,1), dim3(N, 1) >>>(destination.device_data(), source.device_data(), N); -+ -+ ASSERT_EQ(cudaGetLastError(), cudaSuccess) << "Kernel launch error."; -+ -+ destination.sync_host(); -+ -+ int errors = 0; -+ for (int i = 0; i < N; ++i) { -+ T got = destination.at({i, 0}); -+ S expected = source.at({i, 0}); -+ -+ if (S(got) != expected) { -+ ++errors; -+ if (errors < 10) { -+ std::cerr << "Basic conversion error - [" << i << "] - got " << got << ", expected " << expected << "\n"; -+ } -+ } -+ -+ destination.at({i, 0}) = T(0); -+ } -+ -+ destination.sync_device(); -+ -+ convert_and_pack_bf16<<< dim3(1,1), dim3(N, 1) >>>(destination.device_data(), source.device_data(), N); -+ -+ ASSERT_EQ(cudaGetLastError(), cudaSuccess) << "Kernel launch error."; -+ -+ destination.sync_host(); -+ -+ for (int i = 0; i < N; ++i) { -+ T got = destination.at({i, 0}); -+ S expected = source.at({i, 0}); -+ -+ if (S(got) != expected) { -+ ++errors; -+ if (errors < 10) { -+ std::cerr << "Convert and pack error - [" << i << "] - got " << got << ", expected " << expected << "\n"; -+ } -+ } -+ } -+ -+ EXPECT_EQ(errors, 0); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Host -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(bfloat16_t, host_conversion) { -+ for (int i = -128; i < 128; ++i) { -+ float f = static_cast(i); -+ -+ cutlass::bfloat16_t x = static_cast(i); -+ cutlass::bfloat16_t y = static_cast(f); -+ -+ EXPECT_TRUE(static_cast(x) == i); -+ EXPECT_TRUE(static_cast(y) == f); -+ } -+ -+ // Try out default-ctor (zero initialization of primitive proxy type) -+ EXPECT_TRUE(cutlass::bfloat16_t() == 0.0_bf16); -+ -+ // Try out user-defined literals -+ EXPECT_TRUE(cutlass::bfloat16_t(7) == 7_bf16); -+ EXPECT_TRUE(7 == static_cast(7_bf16)); -+} -+ -+TEST(bfloat16_t, host_arithmetic) { -+ -+ for (int i = -100; i < 100; ++i) { -+ for (int j = -100; j < 100; ++j) { -+ -+ cutlass::bfloat16_t x = static_cast(i); -+ cutlass::bfloat16_t y = 
static_cast(j); -+ -+ EXPECT_TRUE(static_cast(x + y) == (i + j)); -+ } -+ } -+} -+ -+TEST(bfloat16_t, host_round) { -+ -+ struct { -+ uint32_t f32_bits; -+ uint16_t expected; -+ } tests[] = { -+ {0x40040000, 0x4004}, // M=0, R=0, S=0 => rtz -+ {0x40048000, 0x4004}, // M=0, R=1, S=0 => rtz -+ {0x40040001, 0x4004}, // M=0, R=1, S=1 => +inf -+ {0x4004c000, 0x4005}, // M=0, R=1, S=1 => +inf -+ {0x4004a000, 0x4005}, // M=0, R=1, S=1 => +inf -+ {0x40050000, 0x4005}, // M=1, R=0, S=0 => rtz -+ {0x40054000, 0x4005}, // M=1, R=0, S=1 => rtz -+ {0x40058000, 0x4006}, // M=1, R=1, S=0 => +inf -+ {0x40058001, 0x4006}, // M=1, R=1, S=1 => +inf -+ {0x7f800000, 0x7f80}, // +inf -+ {0xff800000, 0xff80}, // -inf -+ {0x7fffffff, 0x7fff}, // canonical NaN -+ {0x7ff00001, 0x7fff}, // NaN -> canonical NaN -+ {0xfff00010, 0x7fff}, // Nan -> canonical NaN -+ {0, 0} -+ }; -+ -+ bool running = true; -+ for (int i = 0; running; ++i) { -+ -+ float f32 = reinterpret_cast(tests[i].f32_bits); -+ -+ cutlass::bfloat16_t bf16 = cutlass::bfloat16_t(f32); -+ -+ bool passed = (tests[i].expected == bf16.raw()); -+ -+ EXPECT_TRUE(passed) -+ << "Error - convert(f32: 0x" << std::hex << tests[i].f32_bits -+ << ") -> 0x" << std::hex << tests[i].expected << "\ngot: 0x" << std::hex << bf16.raw(); -+ -+ if (!tests[i].f32_bits) { -+ running = false; -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Device -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/core/complex.cu b/3rdparty/cutlass/test/unit/core/complex.cu -new file mode 100644 -index 0000000..2962f5a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/complex.cu -@@ -0,0 +1,195 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief CUTLASS host-device template for complex numbers supporting all CUTLASS numeric types. -+*/ -+ -+// Standard Library's std::complex used for reference checking -+#include -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/constants.h" -+#include "cutlass/numeric_conversion.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(complex, f64_to_f32_conversion) { -+ -+ cutlass::complex source = {1.5, -1.25}; -+ -+ cutlass::complex dest = cutlass::complex(source); // explicit conversion -+ -+ EXPECT_TRUE(source.real() == 1.5 && source.imag() == -1.25 && -+ dest.real() == 1.5f && dest.imag() == -1.25f); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(complex, f32_to_f64_conversion) { -+ -+ cutlass::complex source = {-1.5f, 1.25f}; -+ -+ cutlass::complex dest = source; // implicit conversion -+ -+ EXPECT_TRUE(source.real() == -1.5f && source.imag() == 1.25f && -+ dest.real() == -1.5 && dest.imag() == 1.25); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(complex, s32_to_f64_conversion) { -+ -+ cutlass::complex source = {-2, 1}; -+ -+ cutlass::complex dest = source; // implicit conversion -+ -+ EXPECT_TRUE(source.real() == -2 && source.imag() == 1 && -+ dest.real() == -2 && dest.imag() == 1); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(complex, f16_to_f32_conversion) { -+ -+ cutlass::complex source = {1.5_hf, -1.25_hf}; -+ -+ cutlass::complex dest = cutlass::complex(source); // explicit conversion -+ -+ EXPECT_TRUE(source.real() == 1.5_hf && source.imag() == -1.25_hf && -+ dest.real() == 1.5f && dest.imag() == -1.25f); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(complex, exp_f32) { -+ -+ cutlass::complex Z[] = { -+ {1, 1}, -+ {2 , cutlass::constants::pi()/2.0f }, -+ {0.5f, cutlass::constants::pi() }, -+ {0.25f, cutlass::constants::pi()*3/4.0f }, -+ {0, 0}, -+ }; -+ -+ cutlass::complex Expected[] = { -+ {1.4686939399158851, 2.2873552871788423}, -+ {4.524491950137825e-16, 7.38905609893065}, -+ {-1.6487212707001282, 2.019101226849069e-16}, -+ {-0.9079430793557842, 0.9079430793557843}, -+ {1, 0} -+ }; -+ -+ double tolerance = 0.00001; -+ -+ for (int i = 0; cutlass::real(Z[i]); ++i) { -+ double e_r = cutlass::real(Expected[i]); -+ double e_i = cutlass::real(Expected[i]); -+ -+ cutlass::complex got = cutlass::exp(Z[i]); -+ float g_r = cutlass::real(got); -+ float g_i = cutlass::real(got); -+ -+ EXPECT_TRUE( -+ std::abs(g_r - e_r) < tolerance && std::abs(g_i - e_i) < tolerance -+ ) << "Expected(" << Expected[i] << "), Got(" << got << ")"; -+ } -+} 
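The expected values in the `exp_f32` test above follow from the identity exp(x + iy) = e^x (cos y + i sin y); `std::exp` over `std::complex<double>` reproduces them, as in this standalone sketch:

```cpp
#include <cmath>
#include <complex>
#include <cstdio>

int main() {
  const double pi = std::acos(-1.0);

  // Same inputs as the Z[] array in the test above.
  const std::complex<double> z[] = {
    {1.0, 1.0},
    {2.0, pi / 2.0},
    {0.5, pi},
    {0.25, pi * 3.0 / 4.0},
  };

  for (const auto &v : z) {
    std::complex<double> e = std::exp(v);  // e^x * (cos y + i sin y)
    std::printf("exp(%g%+gi) = %.16g %+.16gi\n", v.real(), v.imag(), e.real(), e.imag());
  }
  return 0;
}
```

For the first entry this prints roughly 1.4686939399 + 2.2873552872i, matching the tabulated expectation.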
-+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+ -+ /// Thorough testing for basic complex math operators. Uses std::complex as a reference. -+ template -+ struct ComplexOperators { -+ ComplexOperators() { -+ for (int ar = -N; ar <= N; ++ar) { -+ for (int ai = -N; ai <= N; ++ai) { -+ for (int br = -N; br <= N; ++br) { -+ for (int bi = -N; bi <= N; ++bi) { -+ -+ cutlass::complex Ae(T(ar) / T(M), T(ai) / T(M)); -+ cutlass::complex Be(T(br) / T(M), T(bi) / T(M)); -+ -+ std::complex Ar(T(ar) / T(M), T(ai) / T(M)); -+ std::complex Br(T(br) / T(M), T(bi) / T(M)); -+ -+ cutlass::complex add_e = Ae + Be; -+ cutlass::complex sub_e = Ae - Be; -+ cutlass::complex mul_e = Ae * Be; -+ -+ std::complex add_r = (Ar + Br); -+ std::complex sub_r = (Ar - Br); -+ std::complex mul_r = (Ar * Br); -+ -+ EXPECT_EQ(real(add_e), real(add_r)); -+ EXPECT_EQ(imag(add_e), imag(add_r)); -+ -+ EXPECT_EQ(real(sub_e), real(sub_r)); -+ EXPECT_EQ(imag(sub_e), imag(sub_r)); -+ -+ EXPECT_EQ(real(mul_e), real(mul_r)); -+ EXPECT_EQ(imag(mul_e), imag(mul_r)); -+ -+ if (!(br == 0 && bi == 0)) { -+ -+ cutlass::complex div_e = Ae / Be; -+ std::complex div_r = Ar / Br; -+ -+ T const kRange = T(0.001); -+ -+ EXPECT_NEAR(real(div_e), real(div_r), kRange); -+ EXPECT_NEAR(imag(div_e), imag(div_r), kRange); -+ } -+ } -+ } -+ } -+ } -+ } -+ }; -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(complex, host_float) { -+ test::ComplexOperators test; -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(complex, host_double) { -+ test::ComplexOperators test; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/core/float8.cu b/3rdparty/cutlass/test/unit/core/float8.cu -new file mode 100644 -index 0000000..b685838 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/float8.cu -@@ -0,0 +1,103 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for basic float8 functionality -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/numeric_types.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(float_e4m3_t, host_conversion) { -+ for (int i = -8; i < 8; ++i) { -+ float f = static_cast(i); -+ -+ cutlass::float_e4m3_t x = static_cast(i); -+ cutlass::float_e4m3_t y = static_cast(f); -+ -+ EXPECT_TRUE(static_cast(x) == i); -+ EXPECT_TRUE(static_cast(y) == f); -+ } -+ -+ // Try out default-ctor (zero initialization of primitive proxy type) -+ EXPECT_TRUE(cutlass::float_e4m3_t() == 0.0_fe4m3); -+ -+ // Try out user-defined literals -+ EXPECT_TRUE(cutlass::float_e4m3_t(7) == 7_fe4m3); -+ EXPECT_TRUE(7 == static_cast(7_fe4m3)); -+} -+ -+TEST(float_e5m2_t, host_conversion) { -+ for (int i = -8; i < 8; ++i) { -+ float f = static_cast(i); -+ -+ cutlass::float_e5m2_t x = static_cast(i); -+ cutlass::float_e5m2_t y = static_cast(f); -+ -+ EXPECT_TRUE(static_cast(x) == i); -+ EXPECT_TRUE(static_cast(y) == f); -+ } -+ -+ // Try out default-ctor (zero initialization of primitive proxy type) -+ EXPECT_TRUE(cutlass::float_e5m2_t() == 0.0_fe5m2); -+ -+ // Try out user-defined literals -+ EXPECT_TRUE(cutlass::float_e5m2_t(7) == 7_fe5m2); -+ EXPECT_TRUE(7 == static_cast(7_fe5m2)); -+} -+ -+TEST(float_e4m3_t, host_arithmetic) { -+ for (int i = -4; i < 4; ++i) { -+ for (int j = -4; j < 4; ++j) { -+ -+ cutlass::float_e4m3_t x = static_cast(i); -+ cutlass::float_e4m3_t y = static_cast(j); -+ -+ EXPECT_TRUE(static_cast(x + y) == (i + j)); -+ } -+ } -+} -+ -+TEST(float_e5m2_t, host_arithmetic) { -+ for (int i = -4; i < 4; ++i) { -+ for (int j = -4; j < 4; ++j) { -+ -+ cutlass::float_e5m2_t x = static_cast(i); -+ cutlass::float_e5m2_t y = static_cast(j); -+ -+ EXPECT_TRUE(static_cast(x + y) == (i + j)); -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/core/functional.cu b/3rdparty/cutlass/test/unit/core/functional.cu -new file mode 100644 -index 0000000..bd76bc0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/functional.cu -@@ -0,0 +1,494 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for functional operators. -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/functional.h" -+#include "cutlass/core_io.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/util/host_tensor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace core { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Conversion template -+template -+__global__ void unary_operator(Element *d, Element const *a) { -+ -+ Operator op; -+ -+ *d = op(*a); -+} -+ -+/// Conversion template -+template -+__global__ void binary_operator(Element *d, Element const *a, Element const *b, int Iterations = 1) { -+ -+ Operator op; -+ -+ Element a_x = *a; -+ Element b_x = *b; -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int i = 0; i < Iterations; ++i) { -+ b_x = op(a_x, b_x); -+ } -+ -+ *d = b_x; -+} -+ -+/// Conversion template -+template -+__global__ void trinary_operator( -+ Element *d, -+ Element const *a, -+ Element const *b, -+ Element const *c, -+ int Iterations = 1) { -+ -+ Operator op; -+ -+ Element a_x = a[blockIdx.x]; -+ Element b_x = b[blockIdx.x]; -+ Element c_x = c[blockIdx.x]; -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int i = 0; i < Iterations; ++i) { -+ c_x = op(a_x, b_x, c_x); -+ } -+ -+ d[blockIdx.x] = c_x; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace core -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+void Functional_plus_f16xN() { -+ -+ using Element = cutlass::Array; -+ using Operator = cutlass::plus; -+ -+ using Tensor = cutlass::HostTensor; -+ -+ Tensor D({1, kN}); -+ Tensor A({1, kN}); -+ Tensor B({1, kN}); -+ Tensor C({1, kN}); -+ -+ for (int i = 0; i < kN; ++i) { -+ A.host_data()[i] = cutlass::half_t((i * 2 + 1) % 5); -+ B.host_data()[i] = cutlass::half_t((i * 4 + 8) % 7); -+ D.host_data()[i] = cutlass::half_t(0); -+ } -+ -+ D.sync_device(); -+ A.sync_device(); -+ 
B.sync_device(); -+ -+ test::core::kernel::binary_operator<<< dim3(1,1), dim3(1,1) >>>( -+ reinterpret_cast(D.device_data()), -+ reinterpret_cast(A.device_data()), -+ reinterpret_cast(B.device_data()) -+ ); -+ -+ D.sync_host(); -+ -+ bool some_d_nonzero = false; -+ -+ for (int i = 0; i < kN; ++i) { -+ float a = float(A.host_data()[i]); -+ float b = float(B.host_data()[i]); -+ float d = float(D.host_data()[i]); -+ -+ EXPECT_TRUE(d == (a + b)); -+ -+ if (d != 0) { -+ some_d_nonzero = true; -+ } -+ } -+ -+ EXPECT_TRUE(some_d_nonzero); -+} -+ -+TEST(Functional, plus_f16x16) { -+ Functional_plus_f16xN<16>(); -+} -+ -+TEST(Functional, plus_f16x17) { -+ Functional_plus_f16xN<17>(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+void Functional_minus_f16xN() { -+ -+ using Element = cutlass::Array; -+ using Operator = cutlass::minus; -+ -+ using Tensor = cutlass::HostTensor; -+ -+ Tensor D({1, kN}); -+ Tensor A({1, kN}); -+ Tensor B({1, kN}); -+ Tensor C({1, kN}); -+ -+ for (int i = 0; i < kN; ++i) { -+ A.host_data()[i] = cutlass::half_t((i * 2 + 1) % 5); -+ B.host_data()[i] = cutlass::half_t((i * 4 + 8) % 7); -+ D.host_data()[i] = cutlass::half_t(0); -+ } -+ -+ D.sync_device(); -+ A.sync_device(); -+ B.sync_device(); -+ -+ test::core::kernel::binary_operator<<< dim3(1,1), dim3(1,1) >>>( -+ reinterpret_cast(D.device_data()), -+ reinterpret_cast(A.device_data()), -+ reinterpret_cast(B.device_data()) -+ ); -+ -+ D.sync_host(); -+ -+ bool some_d_nonzero = false; -+ -+ for (int i = 0; i < kN; ++i) { -+ float a = float(A.host_data()[i]); -+ float b = float(B.host_data()[i]); -+ float d = float(D.host_data()[i]); -+ -+ EXPECT_TRUE(d == (a - b)); -+ -+ if (d != 0) { -+ some_d_nonzero = true; -+ } -+ } -+ -+ EXPECT_TRUE(some_d_nonzero); -+} -+ -+TEST(Functional, minus_f16x16) { -+ Functional_minus_f16xN<16>(); -+} -+ -+TEST(Functional, minus_f16x17) { -+ Functional_minus_f16xN<17>(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+void Functional_multiplies_f16xN() { -+ -+ using Element = cutlass::Array; -+ using Operator = cutlass::multiplies; -+ -+ using Tensor = cutlass::HostTensor; -+ -+ Tensor D({1, kN}); -+ Tensor A({1, kN}); -+ Tensor B({1, kN}); -+ Tensor C({1, kN}); -+ -+ for (int i = 0; i < kN; ++i) { -+ A.host_data()[i] = cutlass::half_t((i * 2 + 1) % 5); -+ B.host_data()[i] = cutlass::half_t((i * 4 + 8) % 7); -+ D.host_data()[i] = cutlass::half_t(0); -+ } -+ -+ D.sync_device(); -+ A.sync_device(); -+ B.sync_device(); -+ -+ test::core::kernel::binary_operator<<< dim3(1,1), dim3(1,1) >>>( -+ reinterpret_cast(D.device_data()), -+ reinterpret_cast(A.device_data()), -+ reinterpret_cast(B.device_data()) -+ ); -+ -+ D.sync_host(); -+ -+ bool some_d_nonzero = false; -+ -+ for (int i = 0; i < kN; ++i) { -+ float a = float(A.host_data()[i]); -+ float b = float(B.host_data()[i]); -+ float d = float(D.host_data()[i]); -+ -+ EXPECT_TRUE(d == (a * b)); -+ -+ if (d != 0) { -+ some_d_nonzero = true; -+ } -+ } -+ -+ EXPECT_TRUE(some_d_nonzero); -+} -+ -+TEST(Functional, multiplies_f16x16) { -+ -+ Functional_multiplies_f16xN<16>(); -+} -+ -+TEST(Functional, multiplies_f16x17) { -+ -+ Functional_multiplies_f16xN<17>(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+void Functional_divides_f16xN() { -+ -+ using Element = cutlass::Array; -+ using Operator = cutlass::divides; -+ -+ using 
Tensor = cutlass::HostTensor; -+ -+ Tensor D({1, kN}); -+ Tensor A({1, kN}); -+ Tensor B({1, kN}); -+ Tensor C({1, kN}); -+ -+ for (int i = 0; i < kN; ++i) { -+ A.host_data()[i] = cutlass::half_t((i * 2 + 1) % 5); -+ B.host_data()[i] = cutlass::half_t((i * 4 + 8) % 7); -+ D.host_data()[i] = cutlass::half_t(0); -+ } -+ -+ D.sync_device(); -+ A.sync_device(); -+ B.sync_device(); -+ -+ test::core::kernel::binary_operator<<< dim3(1,1), dim3(1,1) >>>( -+ reinterpret_cast(D.device_data()), -+ reinterpret_cast(A.device_data()), -+ reinterpret_cast(B.device_data()) -+ ); -+ -+ D.sync_host(); -+ -+ bool some_d_nonzero = false; -+ -+ for (int i = 0; i < kN; ++i) { -+ float a = float(A.host_data()[i]); -+ float b = float(B.host_data()[i]); -+ float d = float(D.host_data()[i]); -+ -+ float expected = a / b; -+ -+ float const kThreshold = 0.0005f; -+ -+ if (std::isnan(expected)) { -+ EXPECT_TRUE(std::isnan(d)); -+ } -+ else if (std::isinf(expected)) { -+ EXPECT_TRUE(std::isinf(d)); -+ } -+ else { -+ EXPECT_TRUE(std::abs(d - expected) < kThreshold) -+ << "Got: " << d << " = " << a << " / " << b << ", expected: " << (a / b); -+ } -+ -+ if (d != 0) { -+ some_d_nonzero = true; -+ } -+ } -+ -+ EXPECT_TRUE(some_d_nonzero); -+} -+ -+TEST(Functional, divides_f16x16) { -+ -+ Functional_divides_f16xN<16>(); -+} -+ -+TEST(Functional, divides_f16x17) { -+ -+ Functional_divides_f16xN<17>(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+void Functional_multiply_add_TxN() { -+ -+ using Element = cutlass::Array; -+ using Operator = cutlass::multiply_add; -+ -+ using Tensor = cutlass::HostTensor; -+ -+ Tensor D({1, kN}); -+ Tensor A({1, kN}); -+ Tensor B({1, kN}); -+ Tensor C({1, kN}); -+ -+ for (int i = 0; i < kN; ++i) { -+ A.host_data()[i] = T((i * 2 + 1) % 5); -+ B.host_data()[i] = T((i * 4 + 8) % 7); -+ C.host_data()[i] = T((i * 3 + 11) % 11); -+ D.host_data()[i] = T(0); -+ } -+ -+ D.sync_device(); -+ A.sync_device(); -+ B.sync_device(); -+ C.sync_device(); -+ -+ test::core::kernel::trinary_operator<<< dim3(1,1), dim3(1,1) >>>( -+ reinterpret_cast(D.device_data()), -+ reinterpret_cast(A.device_data()), -+ reinterpret_cast(B.device_data()), -+ reinterpret_cast(C.device_data()) -+ ); -+ -+ D.sync_host(); -+ -+ bool some_d_nonzero = false; -+ -+ for (int i = 0; i < kN; ++i) { -+ float a = float(A.host_data()[i]); -+ float b = float(B.host_data()[i]); -+ float c = float(C.host_data()[i]); -+ float d = float(D.host_data()[i]); -+ -+ EXPECT_TRUE(d == (a * b + c)); -+ -+ if (d != 0) { -+ some_d_nonzero = true; -+ } -+ } -+ -+ EXPECT_TRUE(some_d_nonzero); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Functional, multiply_add_f16x16) { -+ Functional_multiply_add_TxN(); -+} -+ -+TEST(Functional, multiply_add_f16x17) { -+ Functional_multiply_add_TxN(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Functional, multiply_add_bf16x16) { -+ Functional_multiply_add_TxN(); -+} -+ -+TEST(Functional, multiply_add_bf16x17) { -+ Functional_multiply_add_TxN(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+cutlass::Quaternion random_quaternion(int range) { -+ return cutlass::Quaternion{ -+ T((rand() % range * 2) - range), -+ T((rand() % range * 2) - range), -+ T((rand() % range * 2) - range), -+ T((rand() % range * 2) - range) -+ }; -+} -+ 
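The multiply_add tests above, and the quaternion variant that follows, all compute d = a * b + c; for quaternions the product is the Hamilton product. A plain C++ reference for that step (the `Quat` struct here is a local stand-in for illustration, not `cutlass::Quaternion`):

```cpp
#include <cstdio>

// Local stand-in for a quaternion w + x*i + y*j + z*k.
struct Quat { float w, x, y, z; };

// Hamilton product followed by elementwise add: d = a * b + c.
Quat multiply_add(Quat a, Quat b, Quat c) {
  return {
    a.w * b.w - a.x * b.x - a.y * b.y - a.z * b.z + c.w,
    a.w * b.x + a.x * b.w + a.y * b.z - a.z * b.y + c.x,
    a.w * b.y - a.x * b.z + a.y * b.w + a.z * b.x + c.y,
    a.w * b.z + a.x * b.y - a.y * b.x + a.z * b.w + c.z,
  };
}

int main() {
  Quat a{1, 2, 3, 4}, b{0, 1, 0, 0}, c{0, 0, 0, 0};
  Quat d = multiply_add(a, b, c);
  std::printf("%g %g %g %g\n", d.w, d.x, d.y, d.z);  // -2 1 4 -3
  return 0;
}
```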
-+template -+void Functional_multiply_add_QuaternionT() { -+ -+ using Element = cutlass::Quaternion; -+ using Operator = cutlass::multiply_add; -+ using HostTensor = cutlass::HostTensor; -+ -+ int const kM = 128; -+ int const kRange = 8; -+ -+ HostTensor A({kM, 1}); -+ HostTensor B({kM, 1}); -+ HostTensor C({kM, 1}); -+ HostTensor D({kM, 1}); -+ -+ srand(2021); -+ -+ for (int m = 0; m < kM; ++m) { -+ A.at({m, 0}) = random_quaternion(kRange); -+ B.at({m, 0}) = random_quaternion(kRange); -+ C.at({m, 0}) = random_quaternion(kRange); -+ } -+ -+ A.sync_device(); -+ B.sync_device(); -+ C.sync_device(); -+ D.sync_device(); -+ -+ test::core::kernel::trinary_operator<<< dim3(kM,1), dim3(1,1) >>>( -+ D.device_data(), -+ A.device_data(), -+ B.device_data(), -+ C.device_data() -+ ); -+ -+ D.sync_host(); -+ -+ for (int m = 0; m < kM; ++m) { -+ -+ Element a = A.at({m, 0}); -+ Element b = B.at({m, 0}); -+ Element c = C.at({m, 0}); -+ Element got = D.at({m, 0}); -+ Element expected = a * b + c; -+ -+ EXPECT_TRUE(got == expected); -+ } -+} -+ -+TEST(Functional, multiply_add_quaternion_f32) { -+ Functional_multiply_add_QuaternionT(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/core/half.cu b/3rdparty/cutlass/test/unit/core/half.cu -new file mode 100644 -index 0000000..27d0872 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/half.cu -@@ -0,0 +1,90 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types -+ and is safe to use in a union. 
-+*/ -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/util/device_memory.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Host -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(half_t, host_conversion) { -+ for (int i = -1024; i < 1024; ++i) { -+ float f = static_cast(i); -+ -+ cutlass::half_t x = static_cast(i); -+ cutlass::half_t y = static_cast(f); -+ -+ EXPECT_TRUE(static_cast(x) == i); -+ EXPECT_TRUE(static_cast(y) == f); -+ } -+ -+ // Try out default-ctor (zero initialization of primitive proxy type) -+ EXPECT_TRUE(cutlass::half_t() == 0.0_hf); -+ -+ // Try out user-defined literals -+ EXPECT_TRUE(cutlass::half_t(7) == 7_hf); -+ EXPECT_TRUE(7 == static_cast(7_hf)); -+} -+ -+TEST(half_t, host_arithmetic) { -+ -+ for (int i = -100; i < 100; ++i) { -+ for (int j = -100; j < 100; ++j) { -+ -+ cutlass::half_t x = static_cast(i); -+ cutlass::half_t y = static_cast(j); -+ -+ EXPECT_TRUE(static_cast(x + y) == (i + j)); -+ } -+ } -+ -+ for (int i = -6; i < 6; ++i) { -+ for (int j = -6; j < 6; ++j) { -+ -+ cutlass::half_t x = static_cast(i); -+ cutlass::half_t y = static_cast(j); -+ -+ EXPECT_TRUE(static_cast(x * y) == (i * j)); -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/core/matrix.cu b/3rdparty/cutlass/test/unit/core/matrix.cu -new file mode 100644 -index 0000000..334521c ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/matrix.cu -@@ -0,0 +1,205 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
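The `half_t` host tests above keep every value inside the range where binary16 is exact: round-trips use |i| <= 1024, sums stay below 200 in magnitude, and products stay within 36. With 10 stored mantissa bits plus the implicit leading bit, every integer of magnitude up to 2^11 = 2048 is exactly representable, so bit-exact comparisons are valid there. A small sketch of that bound:

```cpp
#include <cstdio>

int main() {
  // binary16: 10 explicit mantissa bits + 1 implicit leading bit.
  const int mantissa_bits = 10;
  const long exact_limit = 1L << (mantissa_bits + 1);  // 2048
  std::printf("consecutive integers exact in binary16: |n| <= %ld\n", exact_limit);
  return 0;
}
```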
-+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Unit tests for the small matrix class. -+*/ -+ -+#include -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/matrix.h" -+#include "cutlass/core_io.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Matrix, elementwise_add) { -+ -+ using Matrix4x4 = cutlass::Matrix4x4; -+ -+ Matrix4x4 A = { -+ 1, 2, 3, 4, -+ 5, 6, 7, 8, -+ 9, 10, 11, 12, -+ 13, 14, 15, 16 -+ }; -+ -+ Matrix4x4 B = A.transpose(); -+ -+ Matrix4x4 C = A.add(B * 2.125f); -+ -+ bool passed = true; -+ -+ for (int i = 0; i < 4; ++i) { -+ for (int j = 0; j < 4; ++j) { -+ float got = C.at(i, j); -+ float expected = A.at(i, j) + A.at(j, i) * 2.125f; -+ if (got != expected) { -+ passed = false; -+ } -+ } -+ } -+ EXPECT_TRUE(passed); -+ if (!passed) { -+ std::cout << "A:\n" << A << "\n\nB:\n" << B << "\n\nC:\n" << C << std::endl; -+ } -+} -+ -+TEST(Matrix, elementwise_multiply) { -+ -+ using Matrix4x4 = cutlass::Matrix4x4; -+ -+ Matrix4x4 A = { -+ 1, 2, 3, 4, -+ 5, 6, 7, 8, -+ 9, 10, 11, 12, -+ 13, 14, 15, 16 -+ }; -+ -+ Matrix4x4 B = A.transpose(); -+ -+ Matrix4x4 C = A.multiply(B); -+ -+ bool passed = true; -+ -+ for (int i = 0; i < 4; ++i) { -+ for (int j = 0; j < 4; ++j) { -+ float got = C.at(i, j); -+ float expected = A.at(i, j) * A.at(j, i); -+ if (got != expected) { -+ passed = false; -+ } -+ } -+ } -+ EXPECT_TRUE(passed); -+ if (!passed) { -+ std::cout << "A:\n" << A << "\n\nB:\n" << B << "\n\nC:\n" << C << std::endl; -+ } -+} -+ -+TEST(Matrix, product_4x4_overloads) { -+ -+ using Matrix4x4 = cutlass::Matrix4x4; -+ -+ Matrix4x4 A = { -+ 1, 2, 3, 4, -+ 5, 6, 7, 8, -+ 9, 10, 11, 12, -+ 13, 14, 15, 16 -+ }; -+ -+ Matrix4x4 B = { -+ -1, -2, 0, 4, -+ 1, 2, 1, 1, -+ 3, 2, 1, 1, -+ 1, 0, 8, 2 -+ }; -+ -+ Matrix4x4 C = Matrix4x4::identity(); -+ -+ Matrix4x4 D = A * B + C; -+ -+ bool passed = true; -+ -+ for (int i = 0; i < 4; ++i) { -+ for (int j = 0; j < 4; ++j) { -+ float got = D.at(i, j); -+ float expected = (i == j ? 1.0f : 0); -+ for (int k = 0; k < 4; ++k) { -+ expected += A.at(i, k) * B.at(k, j); -+ } -+ if (got != expected) { -+ passed = false; -+ } -+ } -+ } -+ -+ EXPECT_TRUE(passed); -+ if (!passed) { -+ std::cout << "A:\n" << A << "\n\nB:\n" << B << "\n\nC:\n" << C << "\n\nD:\n" << D << std::endl; -+ } -+} -+ -+ -+TEST(Matrix, product_4x4) { -+ -+ using Matrix4x4 = cutlass::Matrix4x4; -+ -+ Matrix4x4 A = { -+ 1, 2, 3, 4, -+ 5, 6, 7, 8, -+ 9, 10, 11, 12, -+ 13, 14, 15, 16 -+ }; -+ -+ Matrix4x4 B = { -+ -1, -2, 0, 4, -+ 1, 2, 1, 1, -+ 3, 2, 1, 1, -+ 1, 0, 8, 2 -+ }; -+ -+ Matrix4x4 C = Matrix4x4::identity(); -+ -+ // Compute product with optional source accumulator -+ Matrix4x4 D = A.product(B, C); -+ -+ bool passed = true; -+ -+ for (int i = 0; i < 4; ++i) { -+ for (int j = 0; j < 4; ++j) { -+ float got = D.at(i, j); -+ float expected = (i == j ? 1.0f : 0.0f); -+ for (int k = 0; k < 4; ++k) { -+ expected += A.at(i, k) * B.at(k, j); -+ } -+ if (got != expected) { -+ passed = false; -+ } -+ } -+ } -+ -+ EXPECT_TRUE(passed); -+ if (!passed) { -+ std::cout << "A:\n" << A << "\n\nB:\n" << B << "\n\nC:\n" << C << "\n\nD:\n" << D << std::endl; -+ } -+ -+ for (int i = 0; i < 4; ++i) { -+ for (int j = 0; j < 4; ++j) { -+ float c = (i == j ? 
1.0f : 0.0f); -+ EXPECT_TRUE(A.row(i).dot(B.column(j)) + c == D.at(i, j)); -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/core/matrix_coord.cu b/3rdparty/cutlass/test/unit/core/matrix_coord.cu -new file mode 100644 -index 0000000..c703769 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/matrix_coord.cu -@@ -0,0 +1,227 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+* -+**************************************************************************************************/ -+/*! 
\file -+\brief unit tests for matrix_coord -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/matrix_coord.h" -+///////////////////////////////////////////////////////////////////////////////////////////////// -+namespace test { -+namespace core { -+ -+ void test_matrix_coord(cutlass::MatrixCoord::Index row, cutlass::MatrixCoord::Index column) { -+ cutlass::MatrixCoord matrix_coord(row, column); -+ -+ EXPECT_EQ(matrix_coord.row(), row); -+ EXPECT_EQ(matrix_coord.column(), column); -+ } -+ -+ void test_matrix_coord_operator_addition() { -+ cutlass::MatrixCoord::Index row_a = 13; -+ cutlass::MatrixCoord::Index column_a = 42; -+ cutlass::MatrixCoord::Index row_b = 20; -+ cutlass::MatrixCoord::Index column_b = 15; -+ -+ cutlass::MatrixCoord matrix_coord_a(row_a, column_a); -+ cutlass::MatrixCoord matrix_coord_b(row_b, column_b); -+ -+ auto matrix_coord_c = matrix_coord_a + matrix_coord_b; -+ -+ EXPECT_EQ(matrix_coord_c.row(), row_a + row_b); -+ EXPECT_EQ(matrix_coord_c.column(), column_a + column_b); -+ } -+ -+ void test_matrix_coord_operator_subtraction() { -+ cutlass::MatrixCoord::Index row_a = 13; -+ cutlass::MatrixCoord::Index column_a = 42; -+ cutlass::MatrixCoord::Index row_b = 20; -+ cutlass::MatrixCoord::Index column_b = 15; -+ -+ cutlass::MatrixCoord matrix_coord_a(row_a, column_a); -+ cutlass::MatrixCoord matrix_coord_b(row_b, column_b); -+ -+ auto matrix_coord_c = matrix_coord_a - matrix_coord_b; -+ -+ EXPECT_EQ(matrix_coord_c.row(), row_a - row_b); -+ EXPECT_EQ(matrix_coord_c.column(), column_a - column_b); -+ } -+ -+ void test_matrix_coord_operator_multiply() { -+ cutlass::MatrixCoord::Index row_a = 13; -+ cutlass::MatrixCoord::Index column_a = 42; -+ cutlass::MatrixCoord::Index row_b = 20; -+ cutlass::MatrixCoord::Index column_b = 15; -+ -+ cutlass::MatrixCoord matrix_coord_a(row_a, column_a); -+ cutlass::MatrixCoord matrix_coord_b(row_b, column_b); -+ -+ auto matrix_coord_c = matrix_coord_a * matrix_coord_b; -+ -+ EXPECT_EQ(matrix_coord_c.row(), row_a * row_b); -+ EXPECT_EQ(matrix_coord_c.column(), column_a * column_b); -+ } -+ -+ void test_matrix_coord_operator_division() { -+ cutlass::MatrixCoord::Index row_a = 13; -+ cutlass::MatrixCoord::Index column_a = 42; -+ cutlass::MatrixCoord::Index row_b = 20; -+ cutlass::MatrixCoord::Index column_b = 15; -+ -+ cutlass::MatrixCoord matrix_coord_a(row_a, column_a); -+ cutlass::MatrixCoord matrix_coord_b(row_b, column_b); -+ -+ auto matrix_coord_c = matrix_coord_a / matrix_coord_b; -+ -+ EXPECT_EQ(matrix_coord_c.row(), row_a / row_b); -+ EXPECT_EQ(matrix_coord_c.column(), column_a / column_b); -+ } -+ -+ void test_matrix_coord_operator_addition_assignment() { -+ cutlass::MatrixCoord::Index row_a = 13; -+ cutlass::MatrixCoord::Index column_a = 42; -+ cutlass::MatrixCoord::Index row_b = 20; -+ cutlass::MatrixCoord::Index column_b = 15; -+ -+ cutlass::MatrixCoord matrix_coord_a(row_a, column_a); -+ cutlass::MatrixCoord matrix_coord_b(row_b, column_b); -+ -+ matrix_coord_a += matrix_coord_b; -+ -+ EXPECT_EQ(matrix_coord_a.row(), row_a + row_b); -+ EXPECT_EQ(matrix_coord_a.column(), column_a + column_b); -+ } -+ -+ void test_matrix_coord_operator_subtraction_assignment() { -+ cutlass::MatrixCoord::Index row_a = 13; -+ cutlass::MatrixCoord::Index column_a = 42; -+ cutlass::MatrixCoord::Index row_b = 20; -+ cutlass::MatrixCoord::Index column_b = 15; -+ -+ cutlass::MatrixCoord matrix_coord_a(row_a, column_a); -+ cutlass::MatrixCoord matrix_coord_b(row_b, column_b); -+ -+ matrix_coord_a -= matrix_coord_b; -+ -+ 
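-+    // operator-= mutates matrix_coord_a in place; the checks below compare each
-+    // component against plain integer arithmetic on the original row/column values.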
EXPECT_EQ(matrix_coord_a.row(), row_a - row_b); -+ EXPECT_EQ(matrix_coord_a.column(), column_a - column_b); -+ } -+ -+ void test_matrix_coord_operator_multiply_assignment() { -+ cutlass::MatrixCoord::Index row_a = 13; -+ cutlass::MatrixCoord::Index column_a = 42; -+ cutlass::MatrixCoord::Index row_b = 20; -+ cutlass::MatrixCoord::Index column_b = 15; -+ -+ cutlass::MatrixCoord matrix_coord_a(row_a, column_a); -+ cutlass::MatrixCoord matrix_coord_b(row_b, column_b); -+ -+ matrix_coord_a *= matrix_coord_b; -+ -+ EXPECT_EQ(matrix_coord_a.row(), row_a * row_b); -+ EXPECT_EQ(matrix_coord_a.column(), column_a * column_b); -+ } -+ -+ void test_matrix_coord_operator_division_assignment() { -+ cutlass::MatrixCoord::Index row_a = 13; -+ cutlass::MatrixCoord::Index column_a = 42; -+ cutlass::MatrixCoord::Index row_b = 20; -+ cutlass::MatrixCoord::Index column_b = 15; -+ -+ cutlass::MatrixCoord matrix_coord_a(row_a, column_a); -+ cutlass::MatrixCoord matrix_coord_b(row_b, column_b); -+ -+ matrix_coord_a /= matrix_coord_b; -+ -+ EXPECT_EQ(matrix_coord_a.row(), row_a / row_b); -+ EXPECT_EQ(matrix_coord_a.column(), column_a / column_b); -+ } -+} -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Matrix_Coord, basic_row12_column24) { -+ cutlass::MatrixCoord::Index row = 12; -+ cutlass::MatrixCoord::Index column = 24; -+ test::core::test_matrix_coord(row, column); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Matrix_Coord, basic_operator_addition) { -+ test::core::test_matrix_coord_operator_addition(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Matrix_Coord, basic_operator_subtraction) { -+ test::core::test_matrix_coord_operator_subtraction(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Matrix_Coord, basic_operator_multiply) { -+ test::core::test_matrix_coord_operator_multiply(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Matrix_Coord, basic_operator_division) { -+ test::core::test_matrix_coord_operator_division(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Matrix_Coord, basic_operator_addition_assignment) { -+ test::core::test_matrix_coord_operator_addition_assignment(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Matrix_Coord, basic_operator_subtraction_assignment) { -+ test::core::test_matrix_coord_operator_subtraction_assignment(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Matrix_Coord, basic_operator_multiply_assignment) { -+ test::core::test_matrix_coord_operator_multiply_assignment(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Matrix_Coord, basic_operator_division_assignment) { -+ test::core::test_matrix_coord_operator_division_assignment(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/core/numeric_conversion.cu b/3rdparty/cutlass/test/unit/core/numeric_conversion.cu -new file mode 100644 -index 0000000..8d7a296 ---- /dev/null -+++ 
b/3rdparty/cutlass/test/unit/core/numeric_conversion.cu -@@ -0,0 +1,331 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for conversion operators. 
-+*/ -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/util/host_tensor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace core { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Simple conversion function -+template -+__global__ void convert( -+ cutlass::Array *destination, -+ cutlass::Array const *source) { -+ -+ cutlass::NumericArrayConverter convert; -+ -+ *destination = convert(*source); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+void run_test() { -+ const int kN = Count; -+ -+ dim3 grid(1, 1); -+ dim3 block(1, 1); -+ -+ cutlass::HostTensor destination({1, kN}); -+ cutlass::HostTensor source({1, kN}); -+ -+ for (int i = 0; i < kN; ++i) { -+ source.host_data()[i] = Source(i % 4); -+ } -+ -+ source.sync_device(); -+ -+ convert<<< grid, block >>>( -+ reinterpret_cast *>(destination.device_data()), -+ reinterpret_cast const *>(source.device_data()) -+ ); -+ -+ destination.sync_host(); -+ -+ for (int i = 0; i < kN; ++i) { -+ EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i])); -+ } -+} -+ -+} // namespace kernel -+} // namespace core -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(NumericConversion, f32_to_f16_rn) { -+ int const kN = 1; -+ using Source = float; -+ using Destination = cutlass::half_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, f32x8_to_f16x8_rn) { -+ int const kN = 8; -+ using Source = float; -+ using Destination = cutlass::half_t; -+ test::core::kernel::run_test(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(NumericConversion, f16_to_f32_rn) { -+ int const kN = 1; -+ using Source = cutlass::half_t; -+ using Destination = float; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, f16x8_to_f32x8_rn) { -+ int const kN = 8; -+ using Source = cutlass::half_t; -+ using Destination = float; -+ test::core::kernel::run_test(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(NumericConversion, f32_to_fe4m3_rn) { -+ int const kN = 1; -+ using Source = float; -+ using Destination = cutlass::float_e4m3_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, f32_to_fe4m3_rn_array) { -+ int const kN = 27; -+ using Source = float; -+ using Destination = cutlass::float_e4m3_t; -+ -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, f32_to_fe5m2_rn) { -+ int const kN = 1; -+ using Source = float; -+ using Destination = cutlass::float_e5m2_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, f32_to_fe5m2_rn_array) { -+ int const kN = 27; -+ using Source = float; -+ using Destination = cutlass::float_e5m2_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, f16_to_fe4m3_rn) { -+ int const kN = 1; -+ using Source = cutlass::half_t; -+ using Destination = cutlass::float_e4m3_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, f16_to_fe4m3_rn_array) { -+ int const kN = 27; -+ using Source = cutlass::half_t; -+ using Destination = cutlass::float_e4m3_t; -+ 
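-+  // run_test() stages kN source values on the device, converts them in a single-thread
-+  // kernel through cutlass::NumericArrayConverter, and then compares each converted
-+  // element with its source after casting both back to float on the host.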
test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, f16_to_fe5m2_rn) { -+ int const kN = 1; -+ using Source = cutlass::half_t; -+ using Destination = cutlass::float_e5m2_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, f16_to_fe5m2_rn_array) { -+ int const kN = 27; -+ using Source = cutlass::half_t; -+ using Destination = cutlass::float_e5m2_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, bf16_to_fe4m3_rn) { -+ int const kN = 1; -+ using Source = cutlass::bfloat16_t; -+ using Destination = cutlass::float_e4m3_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, bf16_to_fe4m3_rn_array) { -+ int const kN = 27; -+ using Source = cutlass::bfloat16_t; -+ using Destination = cutlass::float_e4m3_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, bf16_to_fe5m2_rn) { -+ int const kN = 1; -+ using Source = cutlass::bfloat16_t; -+ using Destination = cutlass::float_e5m2_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, bf16_to_fe5m2_rn_array) { -+ int const kN = 27; -+ using Source = cutlass::bfloat16_t; -+ using Destination = cutlass::float_e5m2_t; -+ test::core::kernel::run_test(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(NumericConversion, fe4m3_to_fe5m2_rn) { -+ int const kN = 1; -+ using Source = cutlass::float_e4m3_t; -+ using Destination = cutlass::float_e5m2_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe4m3_to_fe5m2_array) { -+ int const kN = 27; -+ using Source = cutlass::float_e4m3_t; -+ using Destination = cutlass::float_e5m2_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe5m2_to_fe4m3_rn) { -+ int const kN = 1; -+ using Source = cutlass::float_e5m2_t; -+ using Destination = cutlass::float_e4m3_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe5m2_to_fe4m3_array) { -+ int const kN = 27; -+ using Source = cutlass::float_e5m2_t; -+ using Destination = cutlass::float_e4m3_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe4m3_to_f32_rn) { -+ int const kN = 1; -+ using Source = cutlass::float_e4m3_t; -+ using Destination = float; -+ test::core::kernel::run_test(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(NumericConversion, f32x8_to_s8x8_rn) { -+ -+ int const kN = 8; -+ using Source = float; -+ using Destination = int8_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe4m3_to_f32_array) { -+ int const kN = 27; -+ using Source = cutlass::float_e4m3_t; -+ using Destination = float; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe5m2_to_f32_array) { -+ int const kN = 27; -+ using Source = cutlass::float_e5m2_t; -+ using Destination = float; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe4m3_to_f16_rn) { -+ int const kN = 1; -+ using Source = cutlass::float_e4m3_t; -+ using Destination = cutlass::half_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe4m3_to_f16_array) { -+ int const kN = 27; -+ using Source = cutlass::float_e4m3_t; -+ using Destination = cutlass::half_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe5m2_to_f16_rn) { -+ int const kN = 1; -+ using Source = cutlass::float_e5m2_t; -+ using Destination = cutlass::half_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe5m2_to_f16_array) { -+ int 
const kN = 27; -+ using Source = cutlass::float_e5m2_t; -+ using Destination = cutlass::half_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe4m3_to_bf16_rn) { -+ int const kN = 1; -+ using Source = cutlass::float_e4m3_t; -+ using Destination = cutlass::bfloat16_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe4m3_to_bf16_array) { -+ int const kN = 27; -+ using Source = cutlass::float_e4m3_t; -+ using Destination = cutlass::bfloat16_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe5m2_to_bf16_rn) { -+ int const kN = 1; -+ using Source = cutlass::float_e5m2_t; -+ using Destination = cutlass::bfloat16_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe5m2_to_bf16_array) { -+ int const kN = 27; -+ using Source = cutlass::float_e5m2_t; -+ using Destination = cutlass::bfloat16_t; -+ test::core::kernel::run_test(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/core/predicate_vector.cu b/3rdparty/cutlass/test/unit/core/predicate_vector.cu -new file mode 100644 -index 0000000..5db96c9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/predicate_vector.cu -@@ -0,0 +1,249 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+#include -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/predicate_vector.h" -+#include "cutlass/util/host_tensor.h" -+ -+namespace test { -+ -+template -+__global__ void load_predicates(unsigned *output, unsigned const *input) { -+ -+ PredicateVector predicates; -+ -+ int const word_count = (PredicateVector::kPredicates + 31) / 32; -+ -+ int i = 0; -+ for (int word_idx = 0; word_idx < word_count; ++word_idx) { -+ unsigned word = input[word_idx]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int bit = 0; bit < sizeof(unsigned) * 8; ++bit) { -+ bool pred = ((word >> bit) & 1); -+ predicates.set(i, pred); -+ -+ if (predicates.at(i) != pred) { -+ printf("ERROR - cannot read back predicate\n"); -+ } -+ ++i; -+ } -+ } -+ -+ -+ __syncthreads(); -+ -+ i = 0; -+ for (int word_idx = 0; word_idx < word_count; ++word_idx) { -+ -+ unsigned result = 0; -+ for (int bit = 0; bit < sizeof(unsigned) * 8; ++bit) { -+ bool pred = predicates.at(i ++); -+ result |= (unsigned(pred) << bit); -+ } -+ output[word_idx] = result; -+ } -+} -+} -+ -+TEST(PredicateVector, Basic) { -+ -+ static int const Bits = 32; -+ static int const Words = (Bits + 31) / 32; -+ -+ typedef cutlass::PredicateVector PredicateVector; -+ -+ cutlass::HostTensor > output; -+ cutlass::HostTensor> input; -+ -+ output.reserve(Words); -+ input.reserve(Words); -+ -+ // some arbitrary test bits -+ unsigned values[] = { -+ 0xdeadbeef, -+ 0xa0070032, -+ 0x9076d001, -+ 0x00000000, -+ 0xabdfc0ad -+ }; -+ -+ for (int test = 0; test < 5; ++test) { -+ -+ input.host_data(0) = values[test]; -+ output.host_data(0) = 0; -+ -+ input.sync_device(); -+ output.sync_device(); -+ -+ test::load_predicates<<< -+ dim3(1,1,1), dim3(1,1,1) -+ >>>( -+ output.device_data(), -+ input.device_data() -+ ); -+ -+ output.sync_host(); -+ -+ for (int word = 0; word < Words; ++word) { -+ EXPECT_EQ(input.host_data(word), output.host_data(word)) -+ << "Expected: 0x" << std::hex << input.host_data(word) -+ << ", got: 0x" << output.host_data(word) -+ << std::dec; -+ } -+ } -+} -+ -+TEST(PredicateVector, Count) { -+ -+ { -+ typedef cutlass::PredicateVector<4, 8> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 1) -+ << "PredicateVector<4, 8> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<4, 4> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 1) -+ << "PredicateVector<4, 4> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<4, 2> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 1) -+ << "PredicateVector<4, 2> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<4, 1> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 1) -+ << "PredicateVector<4, 1> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<8, 8> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 1) -+ << "PredicateVector<8, 8> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<8, 4> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 1) -+ << "PredicateVector<8, 4> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<8, 2> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 1) -+ << "PredicateVector<8, 2> word count: " << 
int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<8, 1> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 2) -+ << "PredicateVector<8, 1> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<16, 8> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 1) -+ << "PredicateVector<16, 8> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<16, 4> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 1) -+ << "PredicateVector<16, 4> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<16, 2> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 2) -+ << "PredicateVector<16, 2> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<16, 1> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 4) -+ << "PredicateVector<16, 1> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<32, 8> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 1) -+ << "PredicateVector<32, 8> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<32, 4> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 2) -+ << "PredicateVector<32, 4> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<32, 2> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 4) -+ << "PredicateVector<32, 2> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<32, 1> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 8) -+ << "PredicateVector<32, 1> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<64, 8> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 2) -+ << "PredicateVector<64, 8> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<64, 4> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 4) -+ << "PredicateVector<64, 4> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<64, 2> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 8) -+ << "PredicateVector<64, 2> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<64, 1> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 16) -+ << "PredicateVector<64, 1> word count: " << int(PredicateVector::kWordCount); -+ } -+} -diff --git a/3rdparty/cutlass/test/unit/core/quaternion.cu b/3rdparty/cutlass/test/unit/core/quaternion.cu -new file mode 100644 -index 0000000..400ea6a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/quaternion.cu -@@ -0,0 +1,168 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for the CUTLASS Quaternion template class. -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/core_io.h" -+#include "cutlass/quaternion.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/constants.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static float const half_pi = cutlass::constants::half_pi(); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Quaternion, add_f32) { -+ -+ cutlass::Quaternion q0(1, 1, 1, 1); -+ cutlass::Quaternion q1(0, 0, 0, 2); -+ -+ cutlass::Quaternion q2 = q0 + q1; -+ -+ EXPECT_TRUE( -+ q2.x() == 1 && -+ q2.y() == 1 && -+ q2.z() == 1 && -+ q2.w() == 3 -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Quaternion, rotation) { -+ -+ cutlass::Matrix3x1 x(1.0f, 0.0f, 0.0f); -+ cutlass::Quaternion q = cutlass::Quaternion::rotation(0, 0, 1, half_pi) * 2.0f; -+ cutlass::Matrix3x1 v = q.rotate(x); -+ -+ float epsilon = 0.001f; -+ -+ EXPECT_TRUE( -+ std::abs(v.at(0)) < epsilon && -+ std::abs(v.at(1)) > (1 - epsilon) && -+ std::abs(v.at(2)) < epsilon -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Quaternion, rotation_inv) { -+ -+ cutlass::Matrix3x1 x(1.0f, 0.0f, 0.0f); -+ cutlass::Quaternion q = cutlass::Quaternion::rotation(0, 0, 1, half_pi) * 2.0f; -+ cutlass::Matrix3x1 v = q.rotate(x); -+ -+ float epsilon = 0.001f; -+ -+ EXPECT_TRUE( -+ std::abs(v.at(0)) < epsilon && -+ std::abs(-v.at(1)) > (1 - epsilon) && -+ std::abs(v.at(2)) < epsilon -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Quaternion, spinor_rotation) { -+ -+ cutlass::Matrix3x1 x(1.0f, 0.0f, 0.0f); -+ cutlass::Quaternion q = cutlass::Quaternion::rotation(0, 0, 1, half_pi); -+ cutlass::Matrix3x1 v = cutlass::spinor_rotation(q, x); -+ -+ float epsilon = 0.001f; -+ -+ EXPECT_TRUE( -+ std::abs(v.at(0)) < epsilon && -+ std::abs(v.at(1)) > (1 - epsilon) && 
-+ std::abs(v.at(2)) < epsilon -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Quaternion, spinor_rotation_inv) { -+ -+ cutlass::Matrix3x1 x(1.0f, 0.0f, 0.0f); -+ cutlass::Quaternion q = cutlass::Quaternion::rotation(0, 0, 1, half_pi); -+ cutlass::Matrix3x1 v = cutlass::spinor_rotation_inv(q, x); -+ -+ float epsilon = 0.001f; -+ -+ EXPECT_TRUE( -+ std::abs(v.at(0)) < epsilon && -+ std::abs(-v.at(1)) > (1 - epsilon) && -+ std::abs(v.at(2)) < epsilon -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Quaternion, as_rotation_matrix3x3) { -+ -+ cutlass::Matrix3x1 x(1.0f, 0.0f, 0.0f); -+ cutlass::Quaternion q = cutlass::Quaternion::rotation(0, 0, 1, half_pi); -+ cutlass::Matrix3x1 v = q.as_rotation_matrix_3x3().product(x); -+ -+ float epsilon = 0.001f; -+ -+ EXPECT_TRUE( -+ std::abs(v.at(0)) < epsilon && -+ std::abs(v.at(1)) > (1 - epsilon) && -+ std::abs(v.at(2)) < epsilon -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Quaternion, as_rotation_matrix4x4) { -+ -+ cutlass::Matrix4x1 x(1.0f, 0.0f, 0.0f, 1.0f); -+ cutlass::Quaternion q = cutlass::Quaternion::rotation(0, 0, 1, half_pi); -+ cutlass::Matrix4x1 v = q.as_rotation_matrix_4x4().product(x); -+ -+ float epsilon = 0.001f; -+ -+ EXPECT_TRUE( -+ std::abs(v.at(0)) < epsilon && -+ std::abs(v.at(1)) > (1 - epsilon) && -+ std::abs(v.at(2)) < epsilon && -+ std::abs(v.at(3)) > (1 - epsilon) -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/core/tensor_ref.cu b/3rdparty/cutlass/test/unit/core/tensor_ref.cu -new file mode 100644 -index 0000000..a4c46fa ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/tensor_ref.cu -@@ -0,0 +1,224 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/matrix.h" -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(TensorRef, basic_rank2) { -+ int const M = 8; -+ int const N = 16; -+ -+ int matrix_data[M * N] = {0}; -+ -+ cutlass::TensorRef< -+ int, -+ cutlass::IdentityTensorLayout<2> > matrix_ref(matrix_data, cutlass::make_Coord(N, 1)); -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ matrix_ref.at(cutlass::make_Coord(m, n)) = m * N + n; -+ } -+ } -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ EXPECT_EQ(matrix_data[m * N + n], int(m * N + n)); -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(TensorRef, rank2_column_major) { -+ int const M = 8; -+ int const N = 8; -+ -+ int matrix_data[M * N]; -+ -+ cutlass::TensorRef ref(matrix_data, M); -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ ref.at(cutlass::make_Coord(m, n)) = m * N + n; -+ } -+ } -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ EXPECT_EQ(matrix_data[m + n * M], int(m * N + n)); -+ } -+ } -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(TensorRef, rank2_row_major) { -+ int const M = 8; -+ int const N = 16; -+ -+ int matrix_data[M * N] = { 0 }; -+ -+ cutlass::TensorRef ref(matrix_data, N); -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ ref.at(cutlass::make_Coord(m, n)) = m * N + n; -+ } -+ } -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ EXPECT_EQ(matrix_data[m * N + n], int(m * N + n)); -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(TensorRef, rank2_contiguous_dynamic) { -+ int const M = 8; -+ int const N = 16; -+ -+ typedef cutlass::TensorRef ContiguousTensorRef; -+ -+ cutlass::layout::Matrix layouts[] = { -+ cutlass::layout::Matrix::kColumnMajor, -+ cutlass::layout::Matrix::kRowMajor -+ }; -+ -+ for (int i = 0; i < 2; ++i) { -+ -+ int matrix_data[M * N] = { 0 }; -+ -+ int row_stride; -+ int col_stride; -+ -+ if (layouts[i] == cutlass::layout::Matrix::kColumnMajor) { -+ row_stride = 1; -+ col_stride = M; -+ } -+ else { -+ row_stride = N; -+ col_stride = 1; -+ } -+ -+ // Use helper to determine stride vector from leading dimension -+ ContiguousTensorRef ref( -+ matrix_data, -+ cutlass::layout::ContiguousMatrix::packed(cutlass::make_Coord(M, N), layouts[i])); -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ ref.at(cutlass::make_Coord(m, n)) = m * N + n; -+ } -+ } -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ EXPECT_EQ(matrix_data[m * row_stride + n 
* col_stride], int(m * N + n)); -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(TensorRef, rank2_column_major_interleaved) { -+ int const M = 16; -+ int const N = 16; -+ int const kInterleave = 4; -+ -+ int matrix_data[M * N] = {0}; -+ -+ // Define the Layout for a column-major interleaved matrix format -+ using Layout = cutlass::layout::ColumnMajorInterleaved; -+ -+ // Construct a TensorRef -+ cutlass::TensorRef< -+ int, -+ Layout> ref(matrix_data, Layout::packed(cutlass::make_Coord(M, N))); -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ ref.at(cutlass::make_Coord(m, n)) = m + n * M; -+ } -+ } -+ -+ // Verify -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; n += kInterleave) { -+ for (int i = 0; i < kInterleave; ++i) { -+ EXPECT_EQ(matrix_data[m * kInterleave + n * M + i], int(m + (n + i) * M)); -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(TensorRef, rank2_row_major_interleaved) { -+ int const M = 16; -+ int const N = 16; -+ int const kInterleave = 4; -+ -+ int matrix_data[M * N] = {0}; -+ -+ // Define the Layout for a row-major interleaved matrix format -+ using Layout = cutlass::layout::RowMajorInterleaved; -+ -+ // Construct a TensorRef -+ cutlass::TensorRef< -+ int, -+ Layout> ref(matrix_data, Layout::packed(cutlass::make_Coord(M, N))); -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ ref.at(cutlass::make_Coord(m, n)) = m + n * M; -+ } -+ } -+ -+ // Verify -+ for (int m = 0; m < M; m += kInterleave) { -+ for (int n = 0; n < N; ++n) { -+ for (int i = 0; i < kInterleave; ++i) { -+ EXPECT_EQ(matrix_data[m * N + i + n * kInterleave], int((m + i) + n * M)); -+ } -+ } -+ } -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/core/tensor_view.cu b/3rdparty/cutlass/test/unit/core/tensor_view.cu -new file mode 100644 -index 0000000..26f1a70 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/tensor_view.cu -@@ -0,0 +1,289 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor.h" -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(TensorView, rank2_contiguous_dynamic) { -+ int const M = 8; -+ int const N = 16; -+ -+ typedef cutlass::TensorView ContiguousTensorView; -+ -+ cutlass::layout::Matrix layouts[] = { -+ cutlass::layout::Matrix::kColumnMajor, -+ cutlass::layout::Matrix::kRowMajor -+ }; -+ -+ cutlass::Coord<2> bounds = cutlass::make_Coord(M - 2, N - 2); -+ -+ for (int i = 0; i < 2; ++i) { -+ -+ int matrix_data[M * N] = { 0 }; -+ -+ int row_stride; -+ int col_stride; -+ -+ if (layouts[i] == cutlass::layout::Matrix::kColumnMajor) { -+ row_stride = 1; -+ col_stride = M; -+ } -+ else { -+ row_stride = N; -+ col_stride = 1; -+ } -+ -+ // Use helper to determine stride vector from leading dimension -+ ContiguousTensorView view( -+ matrix_data, -+ cutlass::layout::ContiguousMatrix::packed(cutlass::make_Coord(M, N), layouts[i]), -+ bounds); -+ -+ ASSERT_TRUE(view.good()); -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ cutlass::Coord<2> coord = cutlass::make_Coord(m, n); -+ if (view.contains(coord)) { -+ view.at(coord) = m * N + n; -+ } -+ } -+ } -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ int expected = 0; -+ if (m < bounds[0] && n < bounds[1]) { -+ expected = int(m * N + n); -+ } -+ EXPECT_EQ(matrix_data[m * row_stride + n * col_stride], expected); -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Uncomment the following line to observe output from printing TensorView objects -+// -+ -+// #define OBSERVE_TENSORVIEW_IO // uncomment to enable printing -+ -+#ifdef OBSERVE_TENSORVIEW_IO -+ -+// This test construct a TensorView of rank=2 with matrix layouts known at runtime. This -+// uses TensorRefMapFunc classes defined in cutlass/matrix_traits.h to define the mapping -+// from logical tensor indices to storage in memory. -+// -+// Helpers in tools/util/tensor_view_io.h print both the logical TensorView and the -+// linear memory of the tensor. 
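-+//
-+// As a concrete reference for the mapping: with a column-major layout a logical
-+// coordinate (m, n) resolves to linear offset m + n * ldm, and with a row-major layout
-+// to m * ldm + n, which is how row_stride and col_stride are chosen in the loop below.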
-+TEST(TensorView, contiguous) { -+ -+ int const M = 8; -+ int const N = 16; -+ -+ typedef cutlass::TensorView< -+ int32_t, -+ cutlass::layout::ContiguousLayout> ContiguousTensorView; -+ -+ cutlass::layout::Matrix layouts[] = { -+ cutlass::layout::Matrix::kColumnMajor, -+ cutlass::layout::Matrix::kRowMajor -+ }; -+ -+ cutlass::Coord<2> bounds = cutlass::make_Coord(M, N); -+ -+ for (int i = 0; i < 2; ++i) { -+ -+ int matrix_data[M * N] = { 0 }; -+ -+ int ldm; -+ int row_stride; -+ int col_stride; -+ -+ if (layouts[i] == cutlass::layout::Matrix::kColumnMajor) { -+ row_stride = 1; -+ col_stride = M; -+ ldm = col_stride; -+ } -+ else { -+ row_stride = N; -+ col_stride = 1; -+ ldm = row_stride; -+ } -+ -+ // Use helper to determine stride vector from leading dimension -+ ContiguousTensorView view( -+ matrix_data, -+ cutlass::layout::ContiguousLayout::stride(layouts[i], ldm), -+ bounds); -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ cutlass::Coord<2> coord = cutlass::make_Coord(m, n); -+ if (view.contains(coord)) { -+ view.at(coord) = m * N + n; -+ } -+ } -+ } -+ -+ std::cout << "---------\n"; -+ std::cout << (layouts[i] == cutlass::layout::Matrix::kColumnMajor ? -+ "Column-major:" : "Row-major:") << "\n\n"; -+ -+ std::cout << "Logical view:\n"; -+ std::cout.width(4); -+ std::cout << view << "\n" << std::endl; // Print TensorView object. -+ -+ std::cout << "Linear memory:"; -+ for (int idx = 0; idx < view.capacity(); ++idx) { -+ if (!(idx % (layouts[i] == cutlass::layout::Matrix::kColumnMajor ? M : N))) { -+ std::cout << std::endl; -+ } -+ std::cout << std::setw(4) << view.at(idx) << " "; -+ } -+ -+ std::cout << "\n" << std::endl; -+ } -+} -+ -+// This test is similar to the previous except it uses a column-major, interleaved data -+// layout. The test prints both the logical representation (a typical column-major matrix) -+// and a representation of linear memory. -+// -+// Note, the interleave=4 structure implies that every four consecutive elements in the -+// same row shall be adjacent in memory followed by the next row. 
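-+//
-+// In index terms, a logical coordinate (m, n) maps to the linear offset
-+//   m * kInterleave + (n / kInterleave) * M * kInterleave + (n % kInterleave),
-+// the same relationship the TensorRef interleaved test verifies element by element.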
-+TEST(TensorView, rank2_column_major_interleaved) { -+ int const M = 16; -+ int const N = 16; -+ int const kInterleave = 4; -+ -+ int matrix_data[M * N] = {0}; -+ -+ cutlass::Coord<2> bounds = cutlass::make_Coord(M, N); -+ -+ // Define the TensorRefMapFunc for a column-major interleaved matrix format -+ typedef cutlass::layout::ColumnMajorInterleaved TensorRefMapFunc; -+ -+ // Define a TensorView of rank=2 using the column-major interleaved mapping function -+ typedef cutlass::TensorView< -+ int, -+ TensorRefMapFunc> InterleavedTensorView; -+ -+ InterleavedTensorView view( -+ matrix_data, -+ TensorRefMapFunc::stride(M), -+ bounds); -+ -+ // Initialize -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ view.at(cutlass::make_Coord(m, n)) = m + n * M; -+ } -+ } -+ -+ // Print logical view -+ std::cout << "Column-major, interleave=" << kInterleave << " (logical view):\n"; -+ -+ std::cout << std::setw(4) << view << "\n" << std::endl; -+ -+ // Now define a linear view of the same data in memory -+ typedef cutlass::TensorView LinearTensorView; -+ -+ LinearTensorView linear_view(matrix_data, cutlass::make_Coord(N), bounds); -+ -+ std::cout << "Linear view in memory:\n"; -+ std::cout << std::setw(4) << linear_view << std::endl; -+} -+ -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(TensorView, int4) { -+ -+ int const M = 4; -+ int const N = 8; -+ -+ using T = cutlass::int4b_t; -+ -+ cutlass::HostTensor tensor({M, N}); -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ T x = T(n ^ m); // some simple hash -+ tensor.host_view().at({m, n}) = x; -+ } -+ } -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ int x = (n ^ m); // some simple hash -+ EXPECT_TRUE(int(tensor.host_view().at({m, n})) == x); -+ } -+ } -+ -+ EXPECT_EQ(tensor.size(), M * N); -+} -+ -+TEST(TensorView, uint4) { -+ -+ int const M = 4; -+ int const N = 8; -+ -+ using T = cutlass::uint4b_t; -+ -+ cutlass::HostTensor tensor({M, N}); -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ T x = T(n ^ m); // some simple hash -+ tensor.host_view().at({m, n}) = x; -+ } -+ } -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ int x = (n ^ m); // some simple hash -+ EXPECT_TRUE(int(tensor.host_view().at({m, n})) == x); -+ } -+ } -+ -+ EXPECT_EQ(tensor.size(), M * N); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/core/tfloat32.cu b/3rdparty/cutlass/test/unit/core/tfloat32.cu -new file mode 100644 -index 0000000..aff50cf ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/tfloat32.cu -@@ -0,0 +1,206 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types -+ and is safe to use in a union. -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/util/device_memory.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Host -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(tfloat32_t, host_conversion) { -+ for (int i = -1024; i < 1024; ++i) { -+ float f = static_cast(i); -+ -+ cutlass::tfloat32_t x = static_cast(i); -+ cutlass::tfloat32_t y = static_cast(f); -+ -+ EXPECT_TRUE(static_cast(x) == i); -+ EXPECT_TRUE(static_cast(y) == f); -+ } -+ -+ // Try out default-ctor (zero initialization of primitive proxy type) -+ EXPECT_TRUE(cutlass::tfloat32_t() == 0.0_tf32); -+ -+ // Try out user-defined literals -+ EXPECT_TRUE(cutlass::tfloat32_t(7) == 7_tf32); -+ EXPECT_TRUE(7 == static_cast(7_tf32)); -+} -+ -+TEST(tfloat32_t, host_arithmetic) { -+ -+ for (int i = -100; i < 100; ++i) { -+ for (int j = -100; j < 100; ++j) { -+ -+ cutlass::tfloat32_t x = static_cast(i); -+ cutlass::tfloat32_t y = static_cast(j); -+ -+ EXPECT_TRUE(static_cast(x + y) == (i + j)); -+ } -+ } -+} -+ -+TEST(tfloat32_t, host_round_nearest) { -+ -+ struct { -+ uint32_t f32_bits; -+ uint32_t expected; -+ } tests[] = { -+ {0x40000000, 0x40000000}, // M=0, R=0, S=0 => rtz -+ {0x40001000, 0x40000000}, // M=0, R=1, S=0 => rtz -+ {0x40000001, 0x40000000}, // M=0, R=0, S=1 => rtz -+ {0x40001001, 0x40002000}, // M=0, R=1, S=1 => +inf -+ {0x40002000, 0x40002000}, // M=1, R=0, S=0 => rtz -+ {0x40002001, 0x40002000}, // M=1, R=0, S=1 => rtz -+ {0x40003000, 0x40004000}, // M=1, R=1, S=0 => +inf -+ {0x40003001, 0x40004000}, // M=1, R=1, S=1 => +inf -+ {0x7f800000, 0x7f800000}, // +inf -+ {0xff800000, 0xff800000}, // -inf -+ {0x7fffffff, 0x7fffffff}, // canonical NaN to canonical NaN -+ {0x7f800001, 0x7fffffff}, // NaN to canonical NaN -+ {0xff800001, 0x7fffffff}, // NaN to canonical NaN -+ {0, 0} -+ }; -+ -+ bool running = true; -+ for (int i = 0; running; ++i) { -+ -+ float f32 = reinterpret_cast(tests[i].f32_bits); -+ -+ cutlass::NumericConverter< -+ cutlass::tfloat32_t, -+ float, -+ cutlass::FloatRoundStyle::round_to_nearest> converter; 
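-+    // In the table above, M is the last mantissa bit TF32 keeps (bit 13), R is the first
-+    // discarded bit (bit 12), and S is the OR of the remaining discarded bits (bits 0-11).
-+    // Round-to-nearest-even rounds up exactly when R is set and (M or S) is set, which
-+    // matches every expected value listed.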
-+ -+ cutlass::tfloat32_t tf32 = converter(f32); -+ -+ // note, we must explicitly truncate the low-order bits since they are not defined in TF32. -+ if (cutlass::isfinite(tf32)) { -+ tf32.storage &= 0xffffe000; -+ } -+ -+ bool passed = (tests[i].expected == tf32.raw()); -+ -+ EXPECT_TRUE(passed) -+ << "Error - convert(f32: 0x" << std::hex << tests[i].f32_bits -+ << ") -> 0x" << std::hex << tests[i].expected << "\ngot: 0x" << std::hex << tf32.raw(); -+ -+ if (!tests[i].f32_bits) { -+ running = false; -+ } -+ } -+} -+ -+namespace test { -+namespace core { -+ -+__global__ void convert_tf32_half_ulp(cutlass::tfloat32_t *out, float const *in) { -+ -+ cutlass::NumericConverter< -+ cutlass::tfloat32_t, -+ float, -+ cutlass::FloatRoundStyle::round_half_ulp_truncate> convert; -+ -+ *out = convert(*in); -+} -+ -+} -+} -+ -+ -+TEST(tfloat32_t, host_round_half_ulp) { -+ -+ struct { -+ uint32_t f32_bits; -+ uint32_t expected; -+ } tests[] = { -+ {0x40001fff, 0x40002000}, -+ {0x40000000, 0x40000000}, // M=0, R=0, S=0 => rtz -+ {0x40001000, 0x40002000}, // M=0, R=1, S=0 => rtz - this difers from RNE -+ {0x40000001, 0x40000000}, // M=0, R=0, S=1 => rtz -+ {0x40001001, 0x40002000}, // M=0, R=1, S=1 => +inf -+ {0x40002000, 0x40002000}, // M=1, R=0, S=0 => rtz -+ {0x40002001, 0x40002000}, // M=1, R=0, S=1 => rtz -+ {0x40003000, 0x40004000}, // M=1, R=1, S=0 => +inf -+ {0x40003001, 0x40004000}, // M=1, R=1, S=1 => +inf -+ {0x7f800000, 0x7f800000}, // +inf -+ {0xff800000, 0xff800000}, // -inf -+ {0x7fffffff, 0x7fffffff}, // canonical NaN to canonical NaN -+ {0x7f800001, 0x7f800001}, // NaN to NaN -+ {0xff800001, 0xff800001}, // NaN to NaN -+ {0, 0} -+ }; -+ -+ cutlass::NumericConverter< -+ cutlass::tfloat32_t, -+ float, -+ cutlass::FloatRoundStyle::round_half_ulp_truncate> convert; -+ -+ bool running = true; -+ for (int i = 0; running; ++i) { -+ -+ float f32 = reinterpret_cast(tests[i].f32_bits); -+ -+ cutlass::tfloat32_t tf32 = convert(f32); -+ -+ // note, for this test, we must explicitly truncate the low-order bits since they are not -+ // defined in TF32. -+ if (cutlass::isfinite(tf32)) { -+ tf32.storage &= 0xffffe000; -+ } -+ -+ bool passed = (tests[i].expected == tf32.raw()); -+ -+ EXPECT_TRUE(passed) -+ << "Error - convert(f32: 0x" << std::hex << tests[i].f32_bits -+ << ") -> 0x" << std::hex << tests[i].expected << "\ngot: 0x" << std::hex << tf32.raw(); -+ -+ if (!tests[i].f32_bits) { -+ running = false; -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Device -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/cute/ampere/cp_async.cu b/3rdparty/cutlass/test/unit/cute/ampere/cp_async.cu -new file mode 100644 -index 0000000..7a80a51 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/cute/ampere/cp_async.cu -@@ -0,0 +1,104 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include "cutlass_unit_test.h" -+ -+#include -+#include -+#include -+#include -+#include -+#include -+ -+#include -+#include -+ -+#include -+ -+using namespace cute; -+ -+__global__ void -+test(double const* g_in, double* g_out) -+{ -+ extern __shared__ double smem[]; -+ -+ smem[threadIdx.x] = g_in[threadIdx.x]; -+ -+ __syncthreads(); -+ -+ g_out[threadIdx.x] = 2 * smem[threadIdx.x]; -+} -+ -+__global__ void -+test2(double const* g_in, double* g_out) -+{ -+ using namespace cute; -+ -+ extern __shared__ double smem[]; -+ -+ auto s_tensor = make_tensor(make_smem_ptr(smem + threadIdx.x), Int<1>{}); -+ auto g_tensor = make_tensor(make_gmem_ptr(g_in + threadIdx.x), Int<1>{}); -+ -+ copy(g_tensor, s_tensor); -+ -+ cp_async_fence(); -+ cp_async_wait<0>(); -+ __syncthreads(); -+ -+ g_out[threadIdx.x] = 2 * smem[threadIdx.x]; -+} -+ -+TEST(SM80_CuTe_Ampere, CpAsync) -+{ -+ constexpr int count = 32; -+ thrust::host_vector h_in(count); -+ for (int i = 0; i < count; ++i) { -+ h_in[i] = double(i); -+ } -+ -+ thrust::device_vector d_in(h_in); -+ -+ thrust::device_vector d_out(count, -1); -+ test<<<1, count, sizeof(double) * count>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data())); -+ thrust::host_vector h_result = d_out; -+ -+ thrust::device_vector d_out_cp_async(count, -2); -+ test2<<<1, count, sizeof(double) * count>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out_cp_async.data())); -+ thrust::host_vector h_result_cp_async = d_out_cp_async; -+ -+ for (int i = 0; i < count; ++i) { -+ EXPECT_EQ(h_result[i], h_result_cp_async[i]); -+ } -+} -diff --git a/3rdparty/cutlass/test/unit/cute/ampere/ldsm.cu b/3rdparty/cutlass/test/unit/cute/ampere/ldsm.cu -new file mode 100644 -index 0000000..15ec44b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/cute/ampere/ldsm.cu -@@ -0,0 +1,431 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
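[Reviewer note] The test2 kernel in cp_async.cu above drives cp.async through cute::copy plus cp_async_fence / cp_async_wait. For readers unfamiliar with the mechanism, the gmem->smem staging it performs is roughly equivalent to the CUDA pipeline intrinsics shown in this sketch. This is an illustration only, not code from the patch; the kernel name is made up and it assumes CUDA 11+ with an sm_80+ target.

#include <cuda_pipeline.h>

__global__ void cp_async_plain(double const* g_in, double* g_out)
{
  extern __shared__ double smem[];

  // Asynchronously stage one 8-byte element per thread from global into
  // shared memory, commit the async-copy group, then wait for it to land.
  __pipeline_memcpy_async(&smem[threadIdx.x], &g_in[threadIdx.x], sizeof(double));
  __pipeline_commit();
  __pipeline_wait_prior(0);
  __syncthreads();

  g_out[threadIdx.x] = 2 * smem[threadIdx.x];
}

Launched as cp_async_plain<<<1, count, sizeof(double) * count>>>(...), it should produce the same output the test compares against the synchronous reference kernel.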
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include "cutlass_unit_test.h" -+ -+#include -+ -+#include -+#include -+ -+#include -+ -+#include -+ -+ -+using namespace cute; -+ -+template -+__global__ void -+ldsm_test_device(uint16_t* g_in, uint16_t* g_out) -+{ -+ constexpr int count = sizeof(T) / 4; -+ int tid = threadIdx.x; -+ int stride = blockDim.x; -+ -+ // load input gmem -> smem -+ __shared__ uint32_t smem[32 * count]; -+ for (int i = 0; i < count; ++i) { -+ smem[tid + (stride * i)] = reinterpret_cast(g_in)[tid + (stride * i)]; -+ } -+ -+ __syncthreads(); -+ -+ uint32_t reg[count]; -+ for (int i = 0; i < count; ++i) { -+ reg[i] = 0; -+ } -+ -+ // load smem -> rmem using LDSM -+ uint128_t* smem_ptr = reinterpret_cast(smem) + tid; -+ T* rmem_ptr = reinterpret_cast(reg); -+ cute::copy_ldsm(smem_ptr, rmem_ptr); -+ -+ // store output rmem -> gmem -+ for (int i = 0; i < count; ++i) { -+ reinterpret_cast(g_out)[tid + (stride * i)] = reg[i]; -+ } -+} -+ -+template -+__global__ void -+ldsm_test_device_cute(uint16_t* g_in, uint16_t* g_out, -+ TiledCopy tiled_copy, SmemLayout smem_layout) -+{ -+ using namespace cute; -+ -+ __shared__ uint16_t smem[size(smem_layout)]; -+ -+ auto t_g_in = make_tensor(make_gmem_ptr(g_in), smem_layout); -+ auto t_g_out = make_tensor(make_gmem_ptr(g_out), smem_layout); -+ auto t_smem = make_tensor(make_smem_ptr(smem), smem_layout); -+ -+ int tid = threadIdx.x; -+ -+ // Load input gmem -> smem -+ for (int i = tid; i < size(t_smem); i += size(tiled_copy)) { -+ t_smem(i) = t_g_in(i); -+ } -+ -+ __syncthreads(); -+ -+ auto thr_copy = tiled_copy.get_thread_slice(tid); -+ -+ auto tXsX = thr_copy.partition_S(t_smem); // (V,M,N) -+ auto tXgX = thr_copy.partition_D(t_g_out); // (V,M,N) -+ -+ auto tXrX = make_tensor(shape(tXgX)); // (V,M,N) -+ clear(tXrX); // Just to make sure -+ -+/* -+ if (thread0()) { -+ print("tXsX: " ); print(tXsX.layout()); 
print("\n"); -+ print("tXgX: " ); print(tXgX.layout()); print("\n"); -+ print("tXrX: " ); print(tXrX.layout()); print("\n"); -+ } -+*/ -+ -+ // Copy smem -> rmem via tiled_copy (LDSM, LDS) -+ copy(tiled_copy, tXsX, tXrX); -+ -+ // Output rmem -> gmem -+ copy(tXrX, tXgX); -+} -+ -+ -+TEST(SM80_CuTe_Ampere, Ldsm) -+{ -+ constexpr int count = 1024; -+ -+ thrust::host_vector h_in(count); -+ for (int i = 0; i < count; ++i) { -+ h_in[i] = uint16_t(i); -+ } -+ thrust::device_vector d_in = h_in; -+ -+ // -+ // LDSM 1x (32b) -+ // -+ -+ { -+ thrust::device_vector d_out(count); -+ ldsm_test_device<<<1, 32>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data())); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < 32; ++i) { -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("LDSM 1x ldsm_test_device SUCCESS\n"); -+ } -+ -+ // -+ // LDSM 2x (64b) -+ // -+ -+ { -+ thrust::device_vector d_out(count); -+ ldsm_test_device<<<1, 32>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data())); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < 64; ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("LDSM 2x ldsm_test_device SUCCESS\n"); -+ } -+ -+ // -+ // LDSM 4x (128b) -+ // -+ -+ { -+ thrust::device_vector d_out(count); -+ ldsm_test_device<<<1, 32>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data())); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < 128; ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("LDSM 4x ldsm_test_device SUCCESS\n"); -+ } -+ -+ // -+ // CuTe LDSM -+ // -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout>, -+ Stride< _2,Stride<_1,_64>>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x8 interleaved U32x1_LDSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout>, -+ Stride< _2,Stride<_1,_64>>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x8 interleaved U32x2_LDSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout>, -+ Stride< _2,Stride<_1,_64>>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ 
} -+ CUTLASS_TRACE_HOST("CuTe 32x8 interleaved U32x4_LDSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout>, -+ Stride< _2,Stride<_1,_64>>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom, uint16_t>{}, -+ Layout>{}, -+ Layout>{}); -+ -+ ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i] , h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x8 interleaved LDS.U16 SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride< _1,_32>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U32x1_LDSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride< _1,_32>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U32x2_LDSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride< _1,_32>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U32x4_LDSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride< _1,_32>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom, uint16_t>{}, -+ Layout>{}, -+ Layout>{}); -+ -+ ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 LDS.U16 SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride<_32, _1>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < 
size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U16x2_LDSM_T SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride<_32, _1>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U16x4_LDSM_T SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride<_32, _1>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U16x8_LDSM_T SUCCESS\n"); -+ } -+ -+ CUTLASS_TRACE_HOST("PASS"); -+} -diff --git a/3rdparty/cutlass/test/unit/cute/hopper/stsm.cu b/3rdparty/cutlass/test/unit/cute/hopper/stsm.cu -new file mode 100644 -index 0000000..ffc8aa7 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/cute/hopper/stsm.cu -@@ -0,0 +1,426 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
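[Reviewer note] The LDSM round-trip tests above can compare buffers element-for-element only because ldmatrix distributes an 8x8 tile of 16-bit values across the warp in a fixed pattern. The sketch below spells that pattern out for the x1, non-transposed case. The helper is hypothetical and written for this review; the mapping is an assumption based on the fragment layout documented for ldmatrix/mma, not something stated in the patch.

// Which elements of an 8x8 tile of 16-bit values a warp lane receives from a
// single non-transposed ldmatrix: one 32-bit register holding two adjacent
// elements of one row (assumed layout, see note above).
struct LdsmFragment { int row; int col0; int col1; };

inline LdsmFragment ldsm_x1_fragment(int lane) {
  LdsmFragment f;
  f.row  = lane / 4;        // lanes 0-3 -> row 0, lanes 4-7 -> row 1, ...
  f.col0 = 2 * (lane % 4);  // each lane covers two consecutive columns
  f.col1 = f.col0 + 1;
  return f;
}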
-+ * -+ **************************************************************************************************/ -+ -+#include "cutlass_unit_test.h" -+ -+#include -+ -+#include -+#include -+ -+#include -+#include -+ -+using namespace cute; -+ -+template -+__global__ void -+stsm_test_device(uint16_t* g_in, uint16_t* g_out) -+{ -+ constexpr int count = sizeof(T) / 4; -+ int tid = threadIdx.x; -+ int stride = blockDim.x; -+ -+ // load input gmem -> rmem -+ uint32_t reg[count]; -+ for (int i = 0; i < (sizeof(T) / 4); i++) { -+ reg[i] = reinterpret_cast(g_in)[tid + (stride * i)]; -+ } -+ -+ __shared__ uint32_t smem[32 * count]; -+ -+ // load rmem -> smem using STSM -+ uint128_t* smem_ptr = reinterpret_cast(smem) + tid; -+ T* rmem_ptr = reinterpret_cast(reg); -+ cute::copy_stsm(rmem_ptr, smem_ptr); -+ -+ __syncthreads(); -+ -+ // store output smem -> gmem -+ for (int i = 0; i < (sizeof(T) / 4); i++) { -+ reinterpret_cast(g_out)[tid + (stride * i)] = smem[tid + (stride * i)]; -+ } -+} -+ -+template -+__global__ void -+stsm_test_device_cute(uint16_t* g_in, uint16_t* g_out, -+ TiledCopy tiled_copy, SmemLayout smem_layout) -+{ -+ using namespace cute; -+ -+ __shared__ uint16_t smem[size(smem_layout)]; -+ -+ Tensor t_g_in = make_tensor(make_gmem_ptr(g_in), smem_layout); -+ Tensor t_g_out = make_tensor(make_gmem_ptr(g_out), smem_layout); -+ Tensor t_smem = make_tensor(make_smem_ptr(smem), smem_layout); -+ -+ int tid = threadIdx.x; -+ -+ auto thr_copy = tiled_copy.get_thread_slice(tid); -+ -+ Tensor tXgX = thr_copy.partition_S(t_g_in); // (V,M,N) -+ Tensor tXsX = thr_copy.partition_D(t_smem); // (V,M,N) -+ -+ Tensor tXrX = make_tensor(shape(tXgX)); // (V,M,N) -+ clear(tXrX); // Just to make sure -+ -+/* -+ if (thread0()) { -+ print("tXsX: " ); print(tXsX.layout()); print("\n"); -+ print("tXgX: " ); print(tXgX.layout()); print("\n"); -+ print("tXrX: " ); print(tXrX.layout()); print("\n"); -+ } -+*/ -+ -+ // Load input gmem -> rmem -+ copy(tXgX, tXrX); -+ -+ // Copy rmem -> smem via tiled_copy (STSM, STS) -+ copy(tiled_copy, tXrX, tXsX); -+ -+ // Output smem -> gmem -+ for (int i = tid; i < size(t_smem); i += size(tiled_copy)) { -+ t_g_out(i) = t_smem(i); -+ } -+} -+ -+#if CUDA_12_0_SM90_FEATURES_SUPPORTED -+TEST(SM90_CuTe_Hopper, Stsm) -+{ -+ constexpr int count = 1024; -+ -+ thrust::host_vector h_in(count); -+ for (int i = 0; i < count; ++i) { -+ h_in[i] = uint16_t(i); -+ } -+ thrust::device_vector d_in = h_in; -+ -+ // -+ // STSM 1x (32b) -+ // -+ -+ { -+ thrust::device_vector d_out(count); -+ stsm_test_device<<<1, 32>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data())); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < 32; ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("STSM 1x stsm_test_device SUCCESS\n"); -+ } -+ -+ // -+ // STSM 2x (64b) -+ // -+ -+ { -+ thrust::device_vector d_out(count); -+ stsm_test_device<<<1, 32>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data())); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < 64; ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("STSM 2x stsm_test_device SUCCESS\n"); -+ } -+ -+ // -+ // STSM 4x (128b) -+ // -+ -+ { -+ thrust::device_vector d_out(count); -+ stsm_test_device<<<1, 32>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data())); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i 
< 128; ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("STSM 4x stsm_test_device SUCCESS\n"); -+ } -+ -+ // -+ // CuTe STSM -+ // -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout>, -+ Stride< _2,Stride<_1,_64>>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x8 interleaved U32x1_STSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout>, -+ Stride< _2,Stride<_1,_64>>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x8 interleaved U32x2_STSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout>, -+ Stride< _2,Stride<_1,_64>>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x8 interleaved U32x4_STSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout>, -+ Stride< _2,Stride<_1,_64>>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom, uint16_t>{}, -+ Layout>{}, -+ Layout>{}); -+ -+ stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x8 interleaved STS.U16 SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride< _1,_32>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U32x1_STSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride< _1,_32>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ 
thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U32x2_STSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride< _1,_32>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U32x4_STSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride< _1,_32>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom, uint16_t>{}, -+ Layout>{}, -+ Layout>{}); -+ -+ stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 STS.U16 SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride<_32, _1>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U16x2_STSM_T SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride<_32, _1>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U16x4_STSM_T SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride<_32, _1>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U16x8_STSM_T SUCCESS\n"); -+ } -+ -+ CUTLASS_TRACE_HOST("PASS"); -+} -+#endif -diff --git a/3rdparty/cutlass/test/unit/cute/hopper/tma_load.cu b/3rdparty/cutlass/test/unit/cute/hopper/tma_load.cu -new file mode 100644 
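[Reviewer note] In the tma_load.cu file that begins here, the kernels take the TiledCopy by value as a CUTE_GRID_CONSTANT parameter; the macro is defined a few lines below. The snippet that follows is a free-standing usage sketch of what __grid_constant__ provides on CUDA 11.7+ / SM70+: a large, read-only, const-qualified kernel argument whose address can be taken without forcing a per-thread local copy. Both 'Descriptor' and 'kernel' are hypothetical names, not from the patch.

#include <cstdio>

struct Descriptor { int table[64]; };

__global__ void kernel(__grid_constant__ Descriptor const desc)
{
  // desc is read-only; taking its address is allowed and does not spill it
  // to local memory, which is what the TMA kernels below rely on.
  Descriptor const* p = &desc;
  if (threadIdx.x == 0) {
    printf("%d\n", p->table[0]);
  }
}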
-index 0000000..24f17fc ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/cute/hopper/tma_load.cu -@@ -0,0 +1,495 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include "cutlass_unit_test.h" -+ -+#include -+ -+#include -+#include -+ -+#include -+ -+using namespace cute; -+ -+template -+struct SharedStorage -+{ -+ cute::array_aligned> smem; -+ cute::uint64_t tma_load_mbar[1]; -+}; -+ -+// __grid_constant__ was introduced in CUDA 11.7. -+#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) -+# define CUTE_GRID_CONSTANT_SUPPORTED -+#endif -+ -+// __grid_constant__ can be enabled only on SM70+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) -+# define CUTE_GRID_CONSTANT_ENABLED -+#endif -+ -+#if ! 
defined(CUTE_GRID_CONSTANT) -+# if defined(CUTE_GRID_CONSTANT_SUPPORTED) && defined(CUTE_GRID_CONSTANT_ENABLED) -+# define CUTE_GRID_CONSTANT __grid_constant__ -+# else -+# define CUTE_GRID_CONSTANT -+# endif -+#endif -+ -+#if CUDA_12_0_SM90_FEATURES_SUPPORTED -+template -+__global__ void -+tma_test_device_cute(T const* g_in, T* g_out, -+ CUTE_GRID_CONSTANT TiledCopy const tma, -+ GmemLayout gmem_layout, SmemLayout smem_layout) -+{ -+ assert(product_each(shape(gmem_layout)) == product_each(smem_layout.shape())); -+ -+ // Use Shared Storage structure to allocate and distribute aligned SMEM addresses -+ extern __shared__ char shared_memory[]; -+ using SharedStorage = SharedStorage; -+ SharedStorage& shared_storage = *reinterpret_cast(shared_memory); -+ -+ // Shared memory barriers use 64bits in SMEM for synchronization -+ uint64_t* tma_load_mbar = shared_storage.tma_load_mbar; -+ // Construct SMEM tensor -+ Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.data()), smem_layout); -+ -+#if 0 -+ -+ // -+ // Read in trivially -+ // -+ -+ Tensor gA_in = make_tensor(make_gmem_ptr(g_in), gmem_layout); -+ -+ // Input gmem -> smem -+ for (int i = threadIdx.x; i < size(sA); i += blockDim.x) { -+ sA(i) = gA_in(i); -+ } -+ __syncthreads(); -+ -+#else -+ -+ // TMA requires special handling of strides to deal with coord codomain mapping -+ // Represent the full tensors -- get these from TMA -+ Tensor gA = tma.get_tma_tensor(shape(gmem_layout)); -+ -+ // -+ // Prepare the TMA_LOAD -+ // -+ -+ auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice -+ -+ Tensor tAgA = cta_tma.partition_S(gA); // (TMA,TMA_M,TMA_N) -+ Tensor tAsA = cta_tma.partition_D(sA); // (TMA,TMA_M,TMA_N) -+ -+#if 0 -+ if (thread0()) { -+ print(" gA: "); print(gA.data()); print(" o "); print(gA.layout()); print("\n"); -+ print("tAgA: "); print(tAgA.data()); print(" o "); print(tAgA.layout()); print("\n"); -+ print(" sA: "); print(sA.data()); print(" o "); print(sA.layout()); print("\n"); -+ print("tAsA: "); print(tAsA.data()); print(" o "); print(tAsA.layout()); print("\n"); -+ } -+#endif -+ -+ // -+ // Perform the TMA_LOAD -+ // -+ -+ // Group the TMA_M and TMA_N modes -+ Tensor tAgA_2 = group_modes<1,rank(tAgA)>(tAgA); // (TMA,Rest) -+ Tensor tAsA_TR = group_modes<1,rank(tAsA)>(tAsA); // (TMA,Rest) -+ static_assert(size<1>(tAsA_TR) == 1); -+ Tensor tAsA_2 = tAsA_TR(_,0); -+ -+ // Loop over the TMA stages, using smem as our buffer -+ for (int stage = 0; stage < size<1>(tAgA_2); ++stage) -+ { -+ // Set the bytes transferred in this TMA transaction (may involve multiple issues) -+ constexpr int kTmaTransactionBytes = size(sA) * sizeof(T); -+ -+ if (threadIdx.x == 0) -+ { -+ /// Initialize shared memory barrier -+ tma_load_mbar[0] = 0; -+ cute::initialize_barrier(tma_load_mbar[0], 1 /*numThreads*/); -+ cute::set_barrier_transaction_bytes(tma_load_mbar[0], kTmaTransactionBytes); -+ -+ copy(tma.with(tma_load_mbar[0]), tAgA_2(_,stage), tAsA_2); -+ } -+ __syncthreads(); -+ -+ /// Wait on the shared memory barrier until the phase bit flips from kPhaseBit value -+ constexpr int kPhaseBit = 0; -+ cute::wait_barrier(tma_load_mbar[0], kPhaseBit); -+ -+ #endif -+ -+ // -+ // Write out trivially -+ // -+ -+ Tensor gA_out = make_tensor(make_gmem_ptr(g_out), gmem_layout); -+ // Do the same slicing and grouping as sA -+ Tensor tAgA_out = cta_tma.partition_D(gA_out); // (TMA,TMA_M,TMA_N) -+ Tensor tAgA_2_out = group_modes<1,rank(tAgA_out)>(tAgA_out); // (TMA,Rest) -+ -+ // Output smem -> gmem -+ for (int i = threadIdx.x; i < size(tAsA_2); i += 
blockDim.x) { -+ tAgA_2_out(i,stage) = tAsA_2(i); -+ } -+ __syncthreads(); -+ } -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_load_32x32_Col) -+{ -+ using T = half_t; -+ Layout smem_layout = Layout, Stride<_1,_32>>{}; -+ Layout gmem_layout = smem_layout; -+ -+ thrust::host_vector h_in(size(gmem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_in.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_LOAD 32x32 ColMajor SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_load_32x32_Row) -+{ -+ using T = half_t; -+ Layout smem_layout = Layout, Stride<_32,_1>>{}; -+ Layout gmem_layout = smem_layout; -+ -+ thrust::host_vector h_in(size(gmem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_in.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_LOAD 32x32 RowMajor SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_load_GMMA_SW128_MN) -+{ -+ using T = half_t; -+ auto smem_layout = GMMA::Layout_MN_SW128_Atom{}; -+ Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); -+ -+ thrust::host_vector h_in(size(gmem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_in.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_MN_SW128_Atom SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_load_GMMA_SW128_K) -+{ -+ using T = half_t; -+ auto smem_layout = GMMA::Layout_K_SW128_Atom{}; -+ Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenRowMajor{}); 
-+ -+ thrust::host_vector h_in(size(gmem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_in.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_K_SW128_Atom SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_load_GMMA_SW128_MN_Multi) -+{ -+ using T = half_t; -+ auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}); -+ Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); -+ -+ thrust::host_vector h_in(size(gmem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_in.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_load_GMMA_SW128_MN_Multi2) -+{ -+ using T = half_t; -+ // Tile the GMMA::Layout atom in the K-mode first, then the M-mode to get a bigger box size -+ auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}, Step<_2,_1>{}); -+ Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); -+ -+ thrust::host_vector h_in(size(gmem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_in.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_load_GMMA_SW128_MN_Multi_Dyn) -+{ -+ using T = half_t; -+ auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}, 
Step<_2,_1>{}); -+ Layout gmem_layout = make_layout(make_shape(128, 128), GenColMajor{}); -+ -+ thrust::host_vector h_in(size(gmem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_in.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_load_32x32_Multimode) -+{ -+ using T = half_t; -+ auto smem_layout = Layout, Stride<_32,_1>>{}; -+ Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenRowMajor{}); -+ -+ //auto smem_layout = Layout>{}; -+ //Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenColMajor{}); -+ -+ thrust::host_vector h_in(size(gmem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_in.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_load_Tensor_blocking) -+{ -+ using T = half_t; -+ auto gmem_layout = make_shape(make_shape(336,40),make_shape(32,656)); // GMEM -+ auto cta_tile = make_shape(make_shape(_16{},_8{}),make_shape(_32{},_2{})); // GMEM Tiling: -+ // Take 16-elem from m0, 8-elem from m1, -+ // Take 32-elem from k0, 2-elem from k1 -+ auto smem_layout = make_layout(cta_tile); // Col-Major SMEM -+ -+ thrust::host_vector h_in(size(gmem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_in.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout, cta_tile, Int<1>{}); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_LOAD Tensor blocking SUCCESS\n"); -+} 
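[Reviewer note] The Tma_load_Tensor_blocking test above builds its CTA tile from hierarchical modes. Because the angle-bracket contents of this patch did not survive extraction, the sketch below restates those shapes with explicit CuTe types as an assumed reconstruction, not a verbatim copy; the arithmetic is the point: the tile covers (16*8) x (32*2) = 128 x 64 = 8192 half_t elements, i.e. 16 KB of shared memory per stage.

#include <cstdio>
#include <cute/tensor.hpp>
using namespace cute;

// Assumed reconstruction for illustration only.
void blocking_shapes_sketch()
{
  // GMEM shape ((336,40), (32,656)): hierarchical (m0,m1) x (k0,k1) modes.
  auto gmem_shape = make_shape(make_shape(336, 40), make_shape(32, 656));
  // CTA tile: take 16 from m0, 8 from m1, 32 from k0, 2 from k1.
  auto cta_tile   = make_shape(make_shape(_16{}, _8{}), make_shape(_32{}, _2{}));
  // Column-major SMEM buffer with exactly the CTA-tile shape.
  auto smem_layout = make_layout(cta_tile);

  // 16*8 * 32*2 = 8192 elements; with half_t that is 16 KB per stage.
  printf("tile elements: %d (gmem elements: %d)\n",
         int(size(smem_layout)), int(size(gmem_shape)));
}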
-+ -+TEST(SM90_CuTe_Hopper, Tma_load_Tensor_blocking_2) -+{ -+ using T = half_t; -+ auto gmem_layout = make_shape(make_shape(32,40),make_shape(make_shape(8,8),656)); // GMEM -+ auto cta_tile = make_shape(_128{},make_shape(_32{},_2{})); // GMEM Tiling: -+ // Take 128-elem from m: m0 must divide 128, -+ // m-last may be predicated -+ // Take 32-elem from k0, 2-elem from k1 -+ auto smem_layout = make_layout(cta_tile); // Col-Major SMEM -+ -+ thrust::host_vector h_in(size(gmem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_in.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout, cta_tile, Int<1>{}); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_LOAD Tensor blocking 2 SUCCESS\n"); -+} -+#endif -diff --git a/3rdparty/cutlass/test/unit/cute/hopper/tma_store.cu b/3rdparty/cutlass/test/unit/cute/hopper/tma_store.cu -new file mode 100644 -index 0000000..448b7f9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/cute/hopper/tma_store.cu -@@ -0,0 +1,384 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+#include "cutlass_unit_test.h" -+ -+#include -+ -+#include -+#include -+ -+#include -+ -+using namespace cute; -+ -+template -+struct SharedStorage -+{ -+ cute::array_aligned> smem; -+}; -+ -+// __grid_constant__ was introduced in CUDA 11.7. -+#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) -+# define CUTE_GRID_CONSTANT_SUPPORTED -+#endif -+ -+// __grid_constant__ can be enabled only on SM70+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) -+# define CUTE_GRID_CONSTANT_ENABLED -+#endif -+ -+#if ! defined(CUTE_GRID_CONSTANT) -+# if defined(CUTE_GRID_CONSTANT_SUPPORTED) && defined(CUTE_GRID_CONSTANT_ENABLED) -+# define CUTE_GRID_CONSTANT __grid_constant__ -+# else -+# define CUTE_GRID_CONSTANT -+# endif -+#endif -+ -+#if CUDA_12_0_SM90_FEATURES_SUPPORTED -+template -+__global__ void -+tma_test_device_cute(T const* g_in, T* g_out, -+ CUTE_GRID_CONSTANT TiledCopy const tma, -+ GmemLayout gmem_layout, SmemLayout smem_layout) -+{ -+ // Use Shared Storage structure to allocate and distribute aligned SMEM addresses -+ extern __shared__ char shared_memory[]; -+ using SharedStorage = SharedStorage; -+ SharedStorage& shared_storage = *reinterpret_cast(shared_memory); -+ // Construct SMEM tensor -+ Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.data()), smem_layout); -+ -+ // -+ // Read in trivially -+ // -+ -+ Tensor gA_in = make_tensor(make_gmem_ptr(g_in), gmem_layout); -+ -+ // Input gmem -> smem -+ for (int i = threadIdx.x; i < size(sA); i += blockDim.x) { -+ sA(i) = gA_in(i); -+ } -+ -+ __syncthreads(); -+ -+#if 0 -+ -+ // -+ // Write out trivially -+ // -+ -+ Tensor gA_out = make_tensor(make_gmem_ptr(g_out), gmem_layout); -+ -+ // Output smem -> gmem -+ for (int i = threadIdx.x; i < size(sA); i += blockDim.x) { -+ gA_out(i) = sA(i); -+ } -+ -+#else -+ -+ // TMA requires special handling of strides to deal with coord codomain mapping -+ // Represent the full tensors -- get these from TMA -+ Tensor gA = tma.get_tma_tensor(shape(gmem_layout)); -+ -+ // -+ // Prepare the TMA_STORE -+ // -+ -+ auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice -+ -+ Tensor tAsA = cta_tma.partition_S(sA); -+ Tensor tAgA = cta_tma.partition_D(gA); -+ -+ // -+ // Perform the TMA_STORE -+ // -+ -+ if (threadIdx.x == 0) { -+ copy(tma, tAsA, tAgA); -+ } -+ -+#endif -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_Store_32x32_Col) -+{ -+ using T = half_t; -+ Layout smem_layout = Layout, Stride<_1,_32>>{}; -+ Layout gmem_layout = smem_layout; -+ -+ thrust::host_vector h_in(size(smem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_out.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_STORE 32x32 ColMajor SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, 
Tma_Store_32x32_Row) -+{ -+ using T = half_t; -+ Layout smem_layout = Layout, Stride<_32,_1>>{}; -+ Layout gmem_layout = smem_layout; -+ -+ thrust::host_vector h_in(size(smem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_out.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_STORE 32x32 RowMajor SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_Store_GMMA_SW128_MN) -+{ -+ using T = half_t; -+ auto smem_layout = GMMA::Layout_MN_SW128_Atom{}; -+ Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); -+ -+ thrust::host_vector h_in(size(smem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_out.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_MN_SW128_Atom SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_Store_GMMA_SW128_K) -+{ -+ using T = half_t; -+ auto smem_layout = GMMA::Layout_K_SW128_Atom{}; -+ Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenRowMajor{}); -+ -+ thrust::host_vector h_in(size(smem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_out.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_K_SW128_Atom SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_Store_GMMA_SW128_MN_Multi) -+{ -+ using T = half_t; -+ auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}); -+ Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), 
size<1>(smem_layout)), GenColMajor{}); -+ -+ thrust::host_vector h_in(size(smem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_out.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_Store_GMMA_SW128_MN_Multi2) -+{ -+ using T = half_t; -+ // Tile the GMMA::Layout atom in the K-mode first, then the M-mode to get a bigger box size -+ auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}, Step<_2,_1>{}); -+ Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); -+ -+ thrust::host_vector h_in(size(smem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_out.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_Store_GMMA_SW128_MN_Multi_Dyn) -+{ -+ using T = half_t; -+ auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}, Step<_2,_1>{}); -+ Layout gmem_layout = make_layout(make_shape(128, 128), GenColMajor{}); -+ -+ thrust::host_vector h_in(size(smem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_out.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_Store_32x32_Multimode) -+{ -+ using T = half_t; -+ auto smem_layout = Layout, Stride<_32,_1>>{}; -+ 
Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenRowMajor{}); -+ -+ //auto smem_layout = Layout>{}; -+ //Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenColMajor{}); -+ -+ thrust::host_vector h_in(size(smem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_out.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); -+} -+#endif -diff --git a/3rdparty/cutlass/test/unit/cute/layout/layout_operator.cu b/3rdparty/cutlass/test/unit/cute/layout/layout_operator.cu -new file mode 100644 -index 0000000..6c44f5a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/cute/layout/layout_operator.cu -@@ -0,0 +1,136 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ \brief Unit tests Generic CuTe Layouts -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/layout.h" -+#include "cutlass/matrix_coord.h" -+ -+// Cute includes -+#include -+#include -+ -+using namespace cutlass; -+using namespace cute; -+ -+namespace test { -+namespace layout { -+ -+template -+ struct Testbed { -+ -+ -+ Testbed() {} -+ -+ bool run() { -+ GenericLayout generic_layout; -+ Layout layout = Layout::packed({size<0>(generic_layout), size<1>(generic_layout)}); -+ -+ for (int m = 0; m < size<0>(generic_layout); m++) { -+ for (int n = 0; n < size<1>(generic_layout); n++) { -+ if (generic_layout(m, n) != layout({m, n})) return false; -+ } -+ } -+ -+ return true; -+ } -+ }; -+ -+} -+} -+ -+////////////////////////////////////////////////////////////////////////// -+// Test Generic CuTe Layouts -+////////////////////////////////////////////////////////////////////////// -+ -+/// Canonical Layouts -+ -+TEST(GenericLayout, ColumnMajor) { -+ using GenericLayout = cute::Layout, Stride<_1, _8>>; -+ using Layout = cutlass::layout::ColumnMajor; -+ -+ test::layout::Testbed testbed; -+ -+ EXPECT_TRUE(testbed.run()); -+} -+////////////////////////////////////////////////////////////////////////// -+ -+TEST(GenericLayout, RowMajor) { -+ using GenericLayout = cute::Layout, Stride<_4, _1>>; -+ using Layout = cutlass::layout::RowMajor; -+ -+ test::layout::Testbed testbed; -+ -+ EXPECT_TRUE(testbed.run()); -+} -+////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Swizzle Shared Memory layouts -+ -+TEST(GenericLayout, RowMajorTensorOpMultiplicandCrosswise) { -+ -+ using GenericLayout = decltype( -+ composition( -+ Swizzle<3,3,3>{}, -+ Layout, Stride<_64, _1>>{}) -+ ); -+ -+ using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ test::layout::Testbed testbed; -+ -+ EXPECT_TRUE(testbed.run()); -+} -+////////////////////////////////////////////////////////////////////////// -+ -+TEST(GenericLayout, ColumnMajorTensorOpMultiplicandCongruous) { -+ -+ using GenericLayout = decltype( -+ composition( -+ Swizzle<3,3,4>{}, -+ Layout>{}) -+ ); -+ -+ using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ -+ test::layout::Testbed testbed; -+ -+ EXPECT_TRUE(testbed.run()); -+} -+////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/thread/activation.cu b/3rdparty/cutlass/test/unit/epilogue/thread/activation.cu -new file mode 100644 -index 0000000..9241ea2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/thread/activation.cu -@@ -0,0 +1,321 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/layout/layout.h" -+#include "cutlass/epilogue/thread/activation.h" -+ -+#include "cutlass/util/host_tensor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void test_Epilogue_thread_activation(T *out, T *in) { -+ -+ cutlass::Array *vec_out = reinterpret_cast *>(out); -+ cutlass::Array *vec_in = reinterpret_cast *>(in); -+ -+ Func func; -+ vec_out[threadIdx.x] = func(vec_in[threadIdx.x]); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Reference -+// -+ -+static double GELU_golden_input[] = { -+ 1.587425827980, 1.157652974129, 0.750432848930, -0.965980410576, -+ -0.388184845448, 0.014422321692, 0.353164494038, 1.354383468628, -+ 0.167588576674, 0.272798538208, -0.377032428980, 1.923444747925, -+ 0.308164477348, -0.341318070889, 0.278338819742, -0.292668998241, -+ -1.051743745804, -0.814175724983, 0.112737402320, 1.262938618660, -+ -1.582363605499, 0.722016870975, 1.053453564644, -0.659764587879, -+ 0.734917521477, 0.091274201870, 0.604461073875, -0.219043627381, -+ -0.136795744300, 0.960650205612, -1.805408835411, 0.091029644012, -+ -1.023343324661, 0.147713735700, -0.499895423651, 1.351878166199, -+ -1.631091356277, -0.336171895266, -1.612408638000, 0.090832948685, -+ -0.658132910728, -0.326727777719, -1.986387014389, 0.787685871124, -+ -1.015677452087, -0.225094825029, 0.876752018929, 0.744826257229, -+ 0.870290279388, -0.757595360279, 1.510331749916, 0.750012576580, -+ 0.906444966793, -0.915759027004, 1.260277032852, -0.158465340734, -+ -0.109191477299, -0.817102134228, 0.391305118799, -0.524910449982, -+ 0.351349592209, 0.801979541779, 0.446691334248, -0.741077482700, -+ 1.205966711044, -0.910210072994, 0.945986449718, 0.784096539021, -+ 1.670521497726, 0.344931513071, -0.301411420107, 0.309870749712, -+ -0.879704594612, -1.951189517975, -0.805817663670, -0.661812782288, -+ -0.505914270878, -1.836273789406, -0.381845980883, -0.554707705975, -+ -0.375447630882, -0.516645610332, 0.509586095810, 1.087131023407, -+ 2.664817094803, -1.558295488358, -0.076461032033, -0.504621028900, -+ 1.327111959457, -1.819981694221, 1.350415468216, -2.074112653732, -+ 1.501431345940, -1.339013576508, 0.162817999721, 
-1.473457217216, -+ 0.357770472765, 0.188413277268, 1.601302266121, -0.653882205486, -+ 0.856162548065, 0.763102591038, -0.526283502579, 0.581961452961, -+ 0.089969776571, 1.968745589256, 0.545802056789, -1.168786048889, -+ 1.206663012505, -0.109096683562, -1.223938226700, 0.744599223137, -+ -1.779406785965, 0.766436159611, -0.579044401646, -1.002057313919, -+ -0.715845823288, -0.562508940697, 0.886768460274, 2.327786445618, -+ -0.148763969541, -0.918884515762, -0.367678701878, -1.105021238327, -+ -0.461237311363, 0.158228352666, -0.254040330648, 1.427477598190, -+ 0.277530491352, 0.046293262392, -0.535557329655, -1.486695051193, -+ -0.953706681728, -1.040495038033, -0.314667612314, 0.348172843456, -+ 0.522773325443, 0.025960063562, -0.482472360134, 1.993084549904, -+ -0.253064930439, -0.012146313675, -2.166327714920, 0.398040622473, -+ -0.022238900885, -0.443580865860, -0.898376941681, -0.571689844131, -+ 1.666979670525, -0.831176340580, -0.671057403088, 0.481970995665, -+ -1.096243023872, -1.493894338608, 0.596651911736, -0.229505166411, -+ 1.165976166725, 0.905094027519, 0.049716457725, -1.362933635712, -+ -0.366948783398, 1.461613893509, -0.718411505222, 0.895385026932, -+ -0.763122260571, 1.329716682434, 1.366570711136, -0.086544901133, -+ 0.059739742428, 0.940766513348, -0.272854357958, -1.738811373711, -+ -0.361239165068, 0.696977972984, 1.288442254066, 1.264815807343, -+ -0.573566436768, -1.141678214073, 0.081865988672, -0.886228799820, -+ -0.236933603883, 1.050115466118, -0.538952171803, 0.651773929596, -+ -0.220034509897, -1.198960781097, 1.247478365898, -0.053529661149, -+ 0.639809548855, 1.672434806824, 0.511088073254, -1.179364681244, -+ -0.730427742004, 0.157630980015, 0.389369845390, -0.925578773022, -+ -0.093250080943, -0.391062080860, 0.852983593941, 1.868778109550, -+ -1.198786258698, 0.604997038841, -1.482687234879, -2.469333171844, -+ 0.718807697296, -0.559609353542, 2.187228441238, -2.927527904510, -+ 0.148535788059, -0.097280368209, 0.674131810665, -1.137645959854, -+ 0.792729616165, -1.166317462921, -0.498791724443, 1.675866723061, -+ -0.137909621000, -0.653263568878, -2.281216144562, 0.296096831560, -+ 2.002410173416, 1.083609819412, 0.933580815792, -1.504760265350, -+ 2.185185909271, 0.286121010780, -1.035485863686, -0.216372340918, -+ -0.274334043264, -0.849510788918, -1.397169828415, -0.407644748688, -+ 0.159476816654, -0.170650705695, 0.335193097591, -0.156852483749, -+ 0.036168430001, 0.858105242252, -1.086121797562, 0.404813349247, -+ -0.481496721506, -0.389882832766, 0.020690204576, -0.772020936012, -+ -0.758921504021, 0.323482036591, 0.115715265274, -0.811228036880, -+ -0.882436633110, 0.176811277866, 1.678015947342, 0.379081040621, -+ -0.842976212502, 0.346952259541, -0.545828759670, 1.632800459862 -+}; -+ -+static double GELU_golden_output[] = { -+ 1.498199582100, 1.014679551125, 0.580462038517, -0.161344811320, -+ -0.135453075171, 0.007294139825, 0.225325092673, 1.235459089279, -+ 0.094946734607, 0.165724009275, -0.133120641112, 1.871103763580, -+ 0.191376730800, -0.125069886446, 0.169681981206, -0.112644664943, -+ -0.154036879539, -0.169163048267, 0.061428427696, 1.132469892502, -+ -0.089851818979, 0.552240371704, 0.899579226971, -0.168043658137, -+ 0.565008401871, 0.048956073821, 0.439583092928, -0.090532489121, -+ -0.060955654830, 0.798911273479, -0.064101703465, 0.048816055059, -+ -0.156645998359, 0.082529976964, -0.154254898429, 1.232632875443, -+ -0.083896033466, -0.123835846782, -0.086161509156, 0.048703473061, -+ -0.167972877622, -0.121522113681, 
-0.046670529991, 0.617986679077, -+ -0.157319813967, -0.092503339052, 0.709896743298, 0.574865520000, -+ 0.703132867813, -0.169963955879, 1.411436080933, 0.580042064190, -+ 0.741154611111, -0.164741978049, 1.129479527473, -0.069256491959, -+ -0.049848672003, -0.169087052345, 0.255214750767, -0.157380074263, -+ 0.223928079009, 0.632535398006, 0.300378054380, -0.169946283102, -+ 1.068588852882, -0.165071934462, 0.783203184605, 0.614346146584, -+ 1.591325283051, 0.219006344676, -0.115003645420, 0.192637458444, -+ -0.166712537408, -0.049788996577, -0.169361919165, -0.168130636215, -+ -0.155041679740, -0.060888241976, -0.134137839079, -0.160614117980, -+ -0.132782235742, -0.156389534473, 0.354075312614, 0.936574816704, -+ 2.654553413391, -0.092845752835, -0.035900454968, -0.154874503613, -+ 1.204704761505, -0.062572605908, 1.230982899666, -0.039479542524, -+ 1.401402950287, -0.120890334249, 0.091938301921, -0.103604510427, -+ 0.228880971670, 0.108285568655, 1.513783097267, -0.167782157660, -+ 0.688394129276, 0.593158841133, -0.157540664077, 0.418839782476, -+ 0.048209801316, 1.920528769493, 0.386099845171, -0.141709372401, -+ 1.069367766380, -0.049809500575, -0.135230198503, 0.574639260769, -+ -0.066881760955, 0.596510827541, -0.162873372436, -0.158483341336, -+ -0.169686436653, -0.161375194788, 0.720409095287, 2.304597616196, -+ -0.065585561097, -0.164551988244, -0.131098195910, -0.148708447814, -+ -0.148663327098, 0.089060656726, -0.101548098028, 1.317959904671, -+ 0.169103100896, 0.024001283571, -0.158595800400, -0.101909510791, -+ -0.162240833044, -0.155090972781, -0.118474565446, 0.221488356590, -+ 0.365645468235, 0.013248858973, -0.151851043105, 1.946992278099, -+ -0.101253561676, -0.006014300976, -0.032804865390, 0.260597169399, -+ -0.010922161862, -0.145792976022, -0.165743649006, -0.162226170301, -+ 1.587365984917, -0.168676435947, -0.168497130275, 0.330191940069, -+ -0.149622067809, -0.100989677012, 0.432351946831, -0.093922272325, -+ 1.023946166039, 0.739726305008, 0.025843897834, -0.117827951908, -+ -0.130937814713, 1.356489539146, -0.169726014137, 0.729478538036, -+ -0.169943705201, 1.207641005516, 1.249209761620, -0.040288090706, -+ 0.031292784959, 0.777626037598, -0.107090584934, -0.071350336075, -+ -0.129670530558, 0.527676224709, 1.161149263382, 1.134579420090, -+ -0.162394225597, -0.144757837057, 0.043603736907, -0.166386902332, -+ -0.096278958023, 0.895924389362, -0.158969298005, 0.484089732170, -+ -0.090857118368, -0.138206124306, 1.115107178688, -0.025622237474, -+ 0.472724437714, 1.593463659286, 0.355387806892, -0.140493586659, -+ -0.169871479273, 0.088687323034, 0.253673940897, -0.164135158062, -+ -0.043161027133, -0.136040985584, 0.685087263584, 1.811169505119, -+ -0.138226687908, 0.440080583096, -0.102422207594, -0.016713079065, -+ 0.549075841904, -0.161096408963, 2.155813455582, -0.005001218989, -+ 0.083037458360, -0.044870752841, 0.505522191525, -0.145202502608, -+ 0.623111069202, -0.141991063952, -0.154108211398, 1.597298502922, -+ -0.061391282827, -0.167753636837, -0.025704355910, 0.182520583272, -+ 1.957115054131, 0.932696640491, 0.769961357117, -0.099604383111, -+ 2.153636932373, 0.175279796124, -0.155551761389, -0.089653611183, -+ -0.107515335083, -0.168032020330, -0.113423995674, -0.139319628477, -+ 0.089841812849, -0.073763631284, 0.211594089866, -0.068651281297, -+ 0.018605981022, 0.690416753292, -0.150658726692, 0.266040354967, -+ -0.151710823178, -0.135800719261, 0.010515870526, -0.169883996248, -+ -0.169960290194, 0.202769815922, 0.063187584281, 
-0.169236257672, -+ -0.166577890515, 0.100812792778, 1.599699616432, 0.245525524020, -+ -0.168275654316, 0.220552831888, -0.159705042839, 1.549110531807 -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_thread_gelu_taylor, device_f32) { -+ -+ int const kN = 256; -+ int const kV = 4; -+ -+ using Element = float; -+ using Func = cutlass::epilogue::thread::GELU_taylor>; -+ -+ double tolerance = 0.005; -+ -+ // -+ // Construct workspace -+ // -+ cutlass::HostTensor tensor_Destination({1, kN}); -+ cutlass::HostTensor tensor_Source({1, kN}); -+ -+ for (int i = 0; i < kN; ++i) { -+ tensor_Source.host_data(i) = Element(GELU_golden_input[i]); -+ } -+ -+ tensor_Destination.sync_device(); -+ tensor_Source.sync_device(); -+ -+ // -+ // Launch the kernel -+ // -+ dim3 grid(1,1,1); -+ dim3 block(kN / kV, 1, 1); -+ -+ test_Epilogue_thread_activation<<< grid, block >>>( -+ tensor_Destination.device_data(), -+ tensor_Source.device_data()); -+ -+ tensor_Destination.sync_host(); -+ -+ // -+ // Verify -+ // -+ -+ for (int i = 0; i < kN; ++i) { -+ Element input = Element(GELU_golden_input[i]); -+ Element got = tensor_Destination.host_data(i); -+ Element expected = Element(GELU_golden_output[i]); -+ -+ double rel_error = (double(got) - double(expected)) / double(expected); -+ -+ double tolerance_override = tolerance; -+ -+ switch (i) { -+ case 142: tolerance_override = 0.008; break; -+ case 203: tolerance_override = 0.03; break; -+ case 207: tolerance_override = 0.09; break; -+ case 218: tolerance_override = 0.013; break; -+ } -+ -+ EXPECT_LT(std::abs(rel_error), tolerance_override) -+ << "Input[" << i << "]: " << input << ", Got: " << got << ", expected: " << expected; -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_thread_gelu_taylor, device_f16) { -+ -+ int const kN = 256; -+ int const kV = 8; -+ -+ using Element = cutlass::half_t; -+ using Func = cutlass::epilogue::thread::GELU_taylor>; -+ -+ double tolerance = 0.005; -+ -+ // -+ // Construct workspace -+ // -+ cutlass::HostTensor tensor_Destination({1, kN}); -+ cutlass::HostTensor tensor_Source({1, kN}); -+ -+ for (int i = 0; i < kN; ++i) { -+ tensor_Source.host_data(i) = Element(GELU_golden_input[i]); -+ } -+ -+ tensor_Destination.sync_device(); -+ tensor_Source.sync_device(); -+ -+ // -+ // Launch the kernel -+ // -+ dim3 grid(1,1,1); -+ dim3 block(kN / kV, 1, 1); -+ -+ test_Epilogue_thread_activation<<< grid, block >>>( -+ tensor_Destination.device_data(), -+ tensor_Source.device_data()); -+ -+ tensor_Destination.sync_host(); -+ -+ // -+ // Verify -+ // -+ -+ for (int i = 0; i < kN; ++i) { -+ Element input = Element(GELU_golden_input[i]); -+ Element got = tensor_Destination.host_data(i); -+ Element expected = Element(GELU_golden_output[i]); -+ -+ double rel_error = (double(got) - double(expected)) / double(expected); -+ -+ double tolerance_override = tolerance; -+ -+ switch (i) { -+ case 36: tolerance_override = 0.006; break; -+ case 77: tolerance_override = 0.009; break; -+ case 95: tolerance_override = 0.008; break; -+ case 112: tolerance_override = 0.007; break; -+ case 171: tolerance_override = 0.006; break; -+ case 203: tolerance_override = 0.03; break; -+ case 207: tolerance_override = 0.15; break; -+ } -+ -+ EXPECT_LT(std::abs(rel_error), tolerance_override) -+ << "Input[" << i << "]: " << input << ", Got: " << got << ", expected: " << expected; -+ } -+} -+ 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/thread/linear_combination.cu b/3rdparty/cutlass/test/unit/epilogue/thread/linear_combination.cu -new file mode 100644 -index 0000000..548924e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/thread/linear_combination.cu -@@ -0,0 +1,205 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_gelu.h" -+#include "cutlass/epilogue/thread/activation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_thread_linear_combination, device_side_f16_f32_value) { -+ -+ using Element = float; -+ using ElementOutput = cutlass::half_t; -+ int const kCount = 8; -+ -+ using LinearCombination = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kCount, -+ Element, -+ Element>; -+ -+ Element alpha = Element(2); -+ Element beta = Element(1); -+ -+ typename LinearCombination::Params params(alpha, beta); -+ -+ LinearCombination linear_combination_op(params); -+ -+ cutlass::Array source; -+ cutlass::Array accum; -+ -+ for (int i = 0; i < kCount; ++i) { -+ accum[i] = Element(i * 2); -+ source[i] = ElementOutput((i * 7 % 9) - 4); -+ } -+ -+ cutlass::Array destination = linear_combination_op(accum, source); -+ -+ for (int i = 0; i < kCount; ++i) { -+ -+ ElementOutput expected = ElementOutput( -+ alpha * accum[i] + -+ beta * Element(ElementOutput(source[i])) -+ ); -+ -+ ElementOutput got = destination[i]; -+ -+ EXPECT_TRUE(expected == got); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_thread_linear_combination, device_side_f16_f32_ptr) { -+ -+ using Element = float; -+ using ElementOutput = cutlass::half_t; -+ int const kCount = 8; -+ -+ using LinearCombination = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kCount, -+ Element, -+ Element>; -+ -+ Element alpha = Element(2); -+ Element beta = Element(1); -+ -+ typename LinearCombination::Params params(&alpha, &beta); -+ -+ LinearCombination linear_combination_op(params); -+ -+ cutlass::Array source; -+ cutlass::Array accum; -+ -+ for (int i = 0; i < kCount; ++i) { -+ accum[i] = Element(i * 2); -+ source[i] = ElementOutput((i * 7 % 9) - 4); -+ } -+ -+ cutlass::Array destination = linear_combination_op(accum, source); -+ -+ for (int i = 0; i < kCount; ++i) { -+ -+ ElementOutput expected = ElementOutput( -+ alpha * accum[i] + -+ beta * Element(ElementOutput(source[i])) -+ ); -+ -+ ElementOutput got = destination[i]; -+ -+ EXPECT_TRUE(expected == got); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(Epilogue_thread_linear_combination_gelu, device_side_f16_f16_ptr) { -+ -+ using Element = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ int const kCount = 8; -+ -+ using LinearCombinationGELU = cutlass::epilogue::thread::LinearCombinationGELU< -+ ElementOutput, -+ kCount, -+ Element, -+ Element>; -+ -+ Element alpha = Element(1); -+ Element beta = Element(0); -+ -+ typename LinearCombinationGELU::Params params(&alpha, &beta); -+ -+ LinearCombinationGELU linear_combination_op(params); -+ -+ cutlass::Array accum; -+ -+ for (int i = 0; i < kCount; ++i) { -+ accum[i] = Element((float)i * 0.3f); -+ } -+ -+ cutlass::Array destination = linear_combination_op(accum, accum); -+ cutlass::epilogue::thread::GELU gelu_func; -+ -+ for (int i = 0; i < kCount; ++i) { -+ ElementOutput expected = gelu_func(accum[i]); -+ ElementOutput got = destination[i]; -+ EXPECT_TRUE(expected == got); -+ } -+} -+ 
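Per element, the LinearCombination functors exercised above compute D = alpha * accum + beta * source before converting to the output type, and the GELU variants then apply the activation to that sum; the tests recompute exactly this on the host. A minimal plain-C++ sketch of that reference computation, using float throughout instead of cutlass::Array and half_t (so the half-precision rounding step is omitted):

    #include <cstdio>

    int main() {
      const int kCount = 8;
      const float alpha = 2.0f;
      const float beta  = 1.0f;

      float accum[kCount];
      float source[kCount];
      float destination[kCount];

      // Same input pattern as the tests above.
      for (int i = 0; i < kCount; ++i) {
        accum[i]  = float(i * 2);
        source[i] = float((i * 7 % 9) - 4);
      }

      // Per-element epilogue: D = alpha * accum + beta * source.
      for (int i = 0; i < kCount; ++i) {
        destination[i] = alpha * accum[i] + beta * source[i];
        std::printf("D[%d] = %f\n", i, destination[i]);
      }
      return 0;
    }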
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_thread_linear_combination_gelu_taylor, device_side_f16_f16_ptr) { -+ -+ using Element = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ int const kCount = 8; -+ -+ using LinearCombinationGELU = cutlass::epilogue::thread::LinearCombinationGELU< -+ ElementOutput, -+ kCount, -+ Element, -+ Element>; -+ -+ Element alpha = Element(1); -+ Element beta = Element(0); -+ -+ typename LinearCombinationGELU::Params params(&alpha, &beta); -+ -+ LinearCombinationGELU linear_combination_op(params); -+ -+ cutlass::Array accum; -+ -+ for (int i = 0; i < kCount; ++i) { -+ accum[i] = Element((float)i * 0.3f); -+ } -+ -+ cutlass::Array destination = linear_combination_op(accum, accum); -+ cutlass::epilogue::thread::GELU gelu_func; -+ -+ for (int i = 0; i < kCount; ++i) { -+ ElementOutput expected = gelu_func(accum[i]); -+ ElementOutput got = destination[i]; -+ EXPECT_TRUE(expected == got); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/epilogue/thread/linear_combination_planar_complex.cu b/3rdparty/cutlass/test/unit/epilogue/thread/linear_combination_planar_complex.cu -new file mode 100644 -index 0000000..cc027e0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/thread/linear_combination_planar_complex.cu -@@ -0,0 +1,286 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace epilogue { -+namespace thread { -+ -+using FunctorPlanarComplexF32F32 = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float>; -+ -+__global__ void epilogue_thread_functor_planar_complex_f32_f32( -+ float *output_ptr, -+ float const *accum_ptr, -+ float const *source_ptr, -+ typename FunctorPlanarComplexF32F32::Params params) { -+ -+ FunctorPlanarComplexF32F32 linear_combination_op(params); -+ -+ auto accum = *reinterpret_cast const *>(accum_ptr); -+ auto source = *reinterpret_cast const *>(source_ptr); -+ -+ *reinterpret_cast*>(output_ptr) = linear_combination_op(accum, source); -+} -+ -+} -+} -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_thread_linear_combination_planar_complex, f32) { -+ -+ using Element = float; -+ using ElementOutput = float; -+ int const kCount = 4; -+ -+ using Functor = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ ElementOutput, -+ kCount, -+ Element, -+ Element>; -+ -+ cutlass::complex alpha(Element(2), Element(1)); -+ cutlass::complex beta(Element(1), Element(-1)); -+ -+ typename Functor::Params params(alpha, beta); -+ -+ Functor linear_combination_op(params); -+ -+ cutlass::ArrayPlanarComplex source; -+ cutlass::ArrayPlanarComplex accum; -+ -+ // Define arbitrary inputs -+ for (int i = 0; i < kCount; ++i) { -+ accum.real[i] = Element(i * 2); -+ accum.imag[i] = Element((i * 3 % 6) - 3); -+ source.real[i] = ElementOutput((i * 7 % 9) - 4); -+ source.imag[i] = ElementOutput(((i * 5 + 2) % 9) - 4); -+ } -+ -+ cutlass::ArrayPlanarComplex destination = linear_combination_op(accum, source); -+ -+ // Verify each result -+ for (int i = 0; i < kCount; ++i) { -+ -+ cutlass::complex expected = alpha * cutlass::complex(accum.real[i], accum.imag[i]) + -+ beta * cutlass::complex(Element(source.real[i]), Element(source.imag[i])); -+ -+ cutlass::complex got(destination.real[i], destination.imag[i]); -+ -+ EXPECT_TRUE(ElementOutput(expected.real()) == got.real()); -+ EXPECT_TRUE(ElementOutput(expected.imag()) == got.imag()); -+ EXPECT_TRUE(expected.real() != Element(0) || expected.imag() != Element(0)); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace epilogue { -+namespace thread { -+ -+using FunctorPlanarComplexF16F32 = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ cutlass::half_t, -+ 4, -+ float, -+ float>; -+ -+__global__ void epilogue_thread_functor_planar_complex_f16_f32( -+ cutlass::half_t *output_ptr, -+ float const *accum_ptr, -+ cutlass::half_t const *source_ptr, -+ typename FunctorPlanarComplexF16F32::Params params, -+ int N) { -+ -+ FunctorPlanarComplexF16F32 linear_combination_op(params); -+ -+ -+ auto accum = *reinterpret_cast const *>(accum_ptr); -+ auto source = *reinterpret_cast const *>(source_ptr); -+ -+ #pragma unroll 1 -+ for (int n = 0; n < N; ++n) { -+ source = linear_combination_op(accum, source); -+ } -+ -+ *reinterpret_cast*>(output_ptr) = source; -+} -+ -+} -+} -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ 
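The planar-complex functors in this file keep real and imaginary parts in separate planes (cutlass::ArrayPlanarComplex) and apply complex-valued alpha and beta per element, D = alpha * accum + beta * source, which is what the verification loops in these tests recompute. A small host-side sketch of that reference math, with std::complex standing in for cutlass::complex and plain arrays standing in for the planar storage:

    #include <complex>
    #include <cstdio>

    int main() {
      const int kCount = 4;
      const std::complex<float> alpha(2.0f, 1.0f);
      const std::complex<float> beta(1.0f, -1.0f);

      // Planar storage: separate real and imaginary planes, as in ArrayPlanarComplex.
      float accum_real[kCount],  accum_imag[kCount];
      float source_real[kCount], source_imag[kCount];

      // Same arbitrary input pattern as the tests.
      for (int i = 0; i < kCount; ++i) {
        accum_real[i]  = float(i * 2);
        accum_imag[i]  = float((i * 3 % 6) - 3);
        source_real[i] = float((i * 7 % 9) - 4);
        source_imag[i] = float(((i * 5 + 2) % 9) - 4);
      }

      for (int i = 0; i < kCount; ++i) {
        std::complex<float> accum(accum_real[i], accum_imag[i]);
        std::complex<float> source(source_real[i], source_imag[i]);
        std::complex<float> d = alpha * accum + beta * source;  // epilogue reference
        std::printf("D[%d] = (%f, %f)\n", i, d.real(), d.imag());
      }
      return 0;
    }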
-+TEST(Epilogue_thread_linear_combination_planar_complex, f16_f32) { -+ -+ using Element = float; -+ using ElementOutput = cutlass::half_t; -+ int const kCount = 4; -+ -+ using Functor = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ ElementOutput, -+ kCount, -+ Element, -+ Element>; -+ -+ cutlass::complex alpha(Element(2), Element(1)); -+ cutlass::complex beta(Element(1), Element(-1)); -+ -+ typename Functor::Params params(alpha, beta); -+ -+ Functor linear_combination_op(params); -+ -+ cutlass::ArrayPlanarComplex source; -+ cutlass::ArrayPlanarComplex accum; -+ -+ // Define arbitrary inputs -+ for (int i = 0; i < kCount; ++i) { -+ accum.real[i] = Element(i * 2); -+ accum.imag[i] = Element((i * 3 % 6) - 3); -+ source.real[i] = ElementOutput((i * 7 % 9) - 4); -+ source.imag[i] = ElementOutput(((i * 5 + 2) % 9) - 4); -+ } -+ -+ cutlass::ArrayPlanarComplex destination = linear_combination_op(accum, source); -+ -+ // Verify each result -+ for (int i = 0; i < kCount; ++i) { -+ -+ cutlass::complex expected = alpha * cutlass::complex(accum.real[i], accum.imag[i]) + -+ beta * cutlass::complex(Element(source.real[i]), Element(source.imag[i])); -+ -+ cutlass::complex got(destination.real[i], destination.imag[i]); -+ -+ EXPECT_TRUE(ElementOutput(expected.real()) == got.real()); -+ EXPECT_TRUE(ElementOutput(expected.imag()) == got.imag()); -+ EXPECT_TRUE(expected.real() != Element(0) || expected.imag() != Element(0)); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace epilogue { -+namespace thread { -+ -+using FunctorPlanarComplexF16F16 = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ cutlass::half_t, -+ 4, -+ cutlass::half_t, -+ cutlass::half_t>; -+ -+__global__ void epilogue_thread_functor_planar_complex_f16_f16( -+ cutlass::half_t *output_ptr, -+ cutlass::half_t const *accum_ptr, -+ cutlass::half_t const *source_ptr, -+ typename FunctorPlanarComplexF16F16::Params params, -+ int N) { -+ -+ FunctorPlanarComplexF16F16 linear_combination_op(params); -+ -+ auto accum = *reinterpret_cast const *>(accum_ptr); -+ auto source = *reinterpret_cast const *>(source_ptr); -+ -+ #pragma unroll 1 -+ for (int n = 0; n < N; ++n) { -+ source = linear_combination_op(accum, source); -+ } -+ -+ *reinterpret_cast*>(output_ptr) = source; -+} -+ -+} -+} -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_thread_linear_combination_planar_complex, f16_f16) { -+ -+ using Element = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ int const kCount = 8; -+ -+ using Functor = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ ElementOutput, -+ kCount, -+ Element, -+ Element>; -+ -+ cutlass::complex alpha(Element(2), Element(1)); -+ cutlass::complex beta(Element(1), Element(-1)); -+ -+ typename Functor::Params params(alpha, beta); -+ -+ Functor linear_combination_op(params); -+ -+ cutlass::ArrayPlanarComplex source; -+ cutlass::ArrayPlanarComplex accum; -+ -+ // Define arbitrary inputs -+ for (int i = 0; i < kCount; ++i) { -+ accum.real[i] = Element(i * 2); -+ accum.imag[i] = Element((i * 3 % 6) - 3); -+ source.real[i] = ElementOutput((i * 7 % 9) - 4); -+ source.imag[i] = ElementOutput(((i * 5 + 2) % 9) - 4); -+ } -+ -+ cutlass::ArrayPlanarComplex destination = linear_combination_op(accum, source); -+ -+ // Verify each result -+ for (int i = 0; i < kCount; ++i) { -+ -+ cutlass::complex expected = alpha * 
cutlass::complex(accum.real[i], accum.imag[i]) + -+ beta * cutlass::complex(Element(source.real[i]), Element(source.imag[i])); -+ -+ cutlass::complex got(destination.real[i], destination.imag[i]); -+ -+ EXPECT_TRUE(ElementOutput(expected.real()) == got.real()); -+ EXPECT_TRUE(ElementOutput(expected.imag()) == got.imag()); -+ EXPECT_TRUE(expected.real() != Element(0) || expected.imag() != Element(0)); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_planar_complex.cu b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_planar_complex.cu -new file mode 100644 -index 0000000..341e009 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_planar_complex.cu -@@ -0,0 +1,510 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -+ -+// Tensor Op -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+ -+// Volta Tensor Op -+#include "cutlass/gemm/warp/mma_tensor_op_sm70.h" -+#include "cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h" -+ -+// Simt -+#include "cutlass/gemm/warp/mma_simt.h" -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+ -+// Epilogue components -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_planar_complex.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "testbed_planar_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_threadblock_epilogue, planar_complex_f32_f32_tensor_op_64x64_32x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, -+ InstructionShape, -+ Element, LayoutA, -+ Element, LayoutB, -+ ElementAccumulator, cutlass::layout::RowMajor -+ >::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< -+ Shape, -+ WarpMmaTensorOp, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpiloguePlanarComplexTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_threadblock_epilogue, planar_complex_f16_f32_tensor_op_64x64_32x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 
64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, -+ InstructionShape, -+ Element, LayoutA, -+ Element, LayoutB, -+ ElementAccumulator, cutlass::layout::RowMajor -+ >::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< -+ Shape, -+ WarpMmaTensorOp, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpiloguePlanarComplexTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_threadblock_epilogue, planar_complex_f16_f16_tensor_op_64x64_32x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, -+ InstructionShape, -+ Element, LayoutA, -+ Element, LayoutB, -+ ElementAccumulator, cutlass::layout::RowMajor -+ >::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< -+ Shape, -+ WarpMmaTensorOp, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpiloguePlanarComplexTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_threadblock_epilogue, planar_complex_f32_f32_volta_tensor_op_64x64_32x32x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using Element = cutlass::half_t; -+ -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ 
cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementAccumulator, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< -+ Shape, -+ WarpMmaTensorOp, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpiloguePlanarComplexTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_threadblock_epilogue, planar_complex_simt_f32_64x64_32x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 1; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using Element = float; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< -+ Shape, -+ WarpMmaSimt, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpiloguePlanarComplexTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_threadblock_epilogue, planar_complex_simt_f64_64x64_16x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ int const kElementsPerAccess = 1; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = 
cutlass::gemm::GemmShape<16, 32, 8>; -+ using Element = double; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< -+ Shape, -+ WarpMmaSimt, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpiloguePlanarComplexTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_simt.cu b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_simt.cu -new file mode 100644 -index 0000000..5bd1ddf ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_simt.cu -@@ -0,0 +1,1172 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/complex.h" -+#include "cutlass/quaternion.h" -+ -+#include "cutlass/gemm/warp/mma_simt.h" -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued single precision tests -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_f32_32x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using Element = float; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_f32_32x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using Element = float; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ 
cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_f32_64x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using Element = float; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_f32_128x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using Element = float; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ 
ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued double precision tests -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_f64_32x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = double; -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_f64_32x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = double; -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, 
-+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_f64_64x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = double; -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_f64_128x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = double; -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Complex-valued single-precision -+// 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_complex_f32_32x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::complex<float>; -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed<Epilogue> testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_complex_f32_32x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::complex<float>; -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed<Epilogue> testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_complex_f32_128x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::complex<float>; -+ using 
ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed<Epilogue> testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Complex-valued double-precision -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_complex_f64_32x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::complex<double>; -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed<Epilogue> testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_complex_f64_32x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::complex<double>; -+ using ElementOutput = Element; -+ using 
ElementAccumulator = Element; -+ using ElementCompute = Element; -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed<Epilogue> testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_complex_f64_128x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::complex<double>; -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed<Epilogue> testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Quaternion-valued single-precision -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_quaternion_f32_32x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::Quaternion<float>; -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ int const kElementsPerAccess = 1; -+ -+ using Shape = 
cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_simt_sm60.cu b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_simt_sm60.cu -new file mode 100644 -index 0000000..36ed25b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_simt_sm60.cu -@@ -0,0 +1,489 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+ -+#include "cutlass/gemm/warp/mma_simt.h" -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued half precision tests -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM60_Epilogue_threadblock_epilogue, simt_f16_32x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM60_Epilogue_threadblock_epilogue, simt_f16_64x64_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<8, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 
-+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM60_Epilogue_threadblock_epilogue, simt_f16_64x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<8, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM60_Epilogue_threadblock_epilogue, simt_f16_128x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<8, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool 
passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM60_Epilogue_threadblock_epilogue, simt_f16_128x256_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<8, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM60_Epilogue_threadblock_epilogue, simt_f16_256x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<8, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_simt_sm61.cu b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_simt_sm61.cu -new file mode 100644 -index 
0000000..ff17915 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_simt_sm61.cu -@@ -0,0 +1,1120 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+ -+#include "cutlass/gemm/warp/mma_simt.h" -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_clamp.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued Integer tests -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_i32_32x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = int; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_i32_32x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = int; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue 
= typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_i32_64x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = int; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_i32_128x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = int; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_i32_128x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = int; -+ using 
ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued Integer - single-precision float output -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_f32_i32_32x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = float; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_f32_i32_32x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = float; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA 
= cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_f32_i32_64x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = float; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_f32_i32_128x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = float; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ 
ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_f32_i32_128x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = float; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued Integer tests - mixed-precision with clamping -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_i8_i32_32x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ 
kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_i8_i32_32x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_i8_i32_64x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_i8_i32_128x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ 
-+ using Shape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_i8_i32_128x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_tensor_op.cu b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_tensor_op.cu -new file mode 100644 -index 0000000..cdeb188 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_tensor_op.cu -@@ -0,0 +1,3076 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_clamp.h" -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_64x64_64x64x32) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, 
-+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_64x64_32x32x32) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_128x128_64x64x32) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 64 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_128x64_64x32x32) { -+ -+ // -+ // Define the warp-level matrix 
multiply -+ // -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_64x128_32x64x32) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 64 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_32x128_32x64x32) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 64 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ 
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_128x32_64x32x32) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 32, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_256x128_64x64x32) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 64 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<256, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using 
WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_128x256_64x64x32) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_64x64_64x64x16) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output 
operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_64x64_32x3216) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 64 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_128x128_64x64x16) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ 
kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_64x128_64x64x16) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_128x64_64x32x16) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 64 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_64x128_32x64x16) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = int8_t; -+ using 
ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_32x128_32x64x16) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_128x32_64x32x16) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 64 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = 
ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, tensor_op_64x64_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, tensor_op_128x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ 
cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, tensor_op_128x256_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, tensor_op_256x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ 
cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, tensor_op_32x32_32x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, tensor_op_64x64_32x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename 
cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, tensor_op_64x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, tensor_op_128x64_64x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ 
LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Mixed precision tests -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, mixed_f16_f32_tensor_op_64x64_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, mixed_f16_f32_tensor_op_128x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, 
InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, mixed_f16_f32_tensor_op_128x256_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, mixed_f16_f32_tensor_op_256x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ 
LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, mixed_f16_f32_tensor_op_32x32_32x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, mixed_f16_f32_tensor_op_64x64_32x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = 
cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, mixed_f16_f32_tensor_op_64x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, mixed_f16_f32_tensor_op_128x64_64x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 
kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// F16 acumulation -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM75_Epilogue_threadblock_epilogue, f16_tensor_op_64x64_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, f16_tensor_op_128x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using 
OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, f16_tensor_op_128x256_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, f16_tensor_op_256x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< 
-+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, f16_tensor_op_32x32_32x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, f16_tensor_op_64x64_32x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ 
ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, f16_tensor_op_64x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, f16_tensor_op_128x64_64x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // 
-+
-+  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
-+    Shape,
-+    WarpMmaTensorOp,
-+    kPartitionsK,
-+    OutputOp,
-+    kElementsPerAccess
-+  >::Epilogue;
-+
-+  //
-+  // Instantiate epilogue
-+  //
-+
-+  EpilogueTestbed<Epilogue> testbed;
-+
-+  bool passed = testbed.run_all();
-+
-+  EXPECT_TRUE(passed);
-+}
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM80_Epilogue_threadblock_epilogue, f64_tensor_op_64x64_32x32x4) {
-+
-+  //
-+  // Define the warp-level matrix multiply
-+  //
-+
-+  using ElementOutput = double;
-+  using ElementAccumulator = double;
-+  using ElementCompute = double;
-+  int const kElementsPerAccess = 1;
-+  int const kPartitionsK = 1;
-+
-+  using Shape = cutlass::gemm::GemmShape<64, 64, 16>;
-+  using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>;
-+  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
-+  using Element = double;
-+  using ElementC = ElementAccumulator;
-+  using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b;
-+  using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b;
-+  using LayoutC = cutlass::layout::RowMajor;
-+
-+  using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
-+      WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
-+      LayoutC>::Type;
-+
-+  //
-+  // Output operator
-+  //
-+
-+  using OutputOp = cutlass::epilogue::thread::LinearCombination<
-+    ElementOutput,
-+    kElementsPerAccess,
-+    ElementAccumulator,
-+    ElementCompute
-+  >;
-+
-+  //
-+  // Define the epilogue
-+  //
-+
-+  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
-+    Shape,
-+    WarpMmaTensorOp,
-+    kPartitionsK,
-+    OutputOp,
-+    kElementsPerAccess
-+  >::Epilogue;
-+
-+  //
-+  // Instantiate epilogue
-+  //
-+
-+  EpilogueTestbed<Epilogue> testbed;
-+
-+  bool passed = testbed.run_all();
-+
-+  EXPECT_TRUE(passed);
-+}
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM80_Epilogue_threadblock_epilogue, f64_tensor_op_128x64_64x32x4) {
-+
-+  //
-+  // Define the warp-level matrix multiply
-+  //
-+
-+  using ElementOutput = double;
-+  using ElementAccumulator = double;
-+  using ElementCompute = double;
-+  int const kElementsPerAccess = 1;
-+  int const kPartitionsK = 1;
-+
-+  using Shape = cutlass::gemm::GemmShape<64, 64, 16>;
-+  using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>;
-+  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
-+  using Element = double;
-+  using ElementC = ElementAccumulator;
-+  using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b;
-+  using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b;
-+  using LayoutC = cutlass::layout::RowMajor;
-+
-+  using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
-+      WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
-+      LayoutC>::Type;
-+
-+  //
-+  // Output operator
-+  //
-+
-+  using OutputOp = cutlass::epilogue::thread::LinearCombination<
-+    ElementOutput,
-+    kElementsPerAccess,
-+    ElementAccumulator,
-+    ElementCompute
-+  >;
-+
-+  //
-+  // Define the epilogue
-+  //
-+
-+  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
-+    Shape,
-+    WarpMmaTensorOp,
-+    kPartitionsK,
-+    OutputOp,
-+    kElementsPerAccess
-+  >::Epilogue;
-+
-+  //
-+  // Instantiate epilogue
-+  //
-+
-+  EpilogueTestbed<Epilogue> testbed;
-+
-+  bool passed = testbed.run_all();
-+
-+  EXPECT_TRUE(passed);
-+}
-+
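All of the SM75/SM80 test cases in this patch share a single skeleton. The following is a minimal, self-contained sketch of that skeleton for the mixed f16/f32, 64x64 threadblock case, assuming the template arguments that the extracted patch text drops (the <ElementOutput>/<Element> argument to cutlass::sizeof_bits and the <Epilogue> argument to EpilogueTestbed) and an assumed include list (the exact set for this file is not visible in this part of the patch); treat it as a reconstruction under those assumptions rather than a verbatim excerpt.

#include "../../common/cutlass_unit_test.h"

#include "cutlass/half.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"

#include "testbed.h"

// Sketch of the common test skeleton (assumed reconstruction, not part of the patch).
TEST(SM75_Epilogue_threadblock_epilogue, sketch_mixed_f16_f32_tensor_op_64x64_64x64x8) {

  // f16 operands and output, f32 accumulation and epilogue compute.
  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = float;
  using ElementCompute = float;

  // One 128-bit access per thread; <ElementOutput> is the assumed sizeof_bits argument.
  int const kElementsPerAccess = 128 / cutlass::sizeof_bits<ElementOutput>::value;
  int const kPartitionsK = 1;

  // Threadblock, warp, and mma.sync instruction tile shapes.
  using Shape = cutlass::gemm::GemmShape<64, 64, 8>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

  using Element = cutlass::half_t;
  using ElementC = ElementAccumulator;
  using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous<
      cutlass::sizeof_bits<Element>::value, 64>;
  using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous<
      cutlass::sizeof_bits<Element>::value, 64>;
  using LayoutC = cutlass::layout::RowMajor;

  // Warp-level tensor-op MMA built from the default policy.
  using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
      LayoutC>::Type;

  // Output operator: D = alpha * accumulator + beta * C.
  using OutputOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput, kElementsPerAccess, ElementAccumulator, ElementCompute>;

  // Threadblock-scoped epilogue assembled from the defaults.
  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
      Shape, WarpMmaTensorOp, kPartitionsK, OutputOp, kElementsPerAccess>::Epilogue;

  // Exercise the epilogue against the host reference provided by testbed.h.
  EpilogueTestbed<Epilogue> testbed;
  EXPECT_TRUE(testbed.run_all());
}

Every other case in this section varies only Shape, WarpShape, InstructionShape, and the element/layout types around this structure; the Volta (SM70) variants below additionally spell out the MmaTensorOpPolicy, MmaVoltaTensorOp, and DefaultThreadMapVoltaTensorOp pieces explicitly instead of relying on DefaultMmaTensorOp.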
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Epilogue_threadblock_epilogue, f64_tensor_op_64x128_32x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ int const kElementsPerAccess = 1; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Epilogue_threadblock_epilogue, f64_tensor_op_128x128_32x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ int const kElementsPerAccess = 1; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, vec1_mixed_f16_f32_tensor_op_128x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = 
float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 1; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, vec1_mixed_f16_f32_tensor_op_128x256_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 1; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+ -+TEST(SM75_Epilogue_threadblock_epilogue, vec1_tensor_op_128x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 1; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = 
cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, vec1_tensor_op_128x256_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 1; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_volta_tensor_op.cu b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_volta_tensor_op.cu -new file mode 100644 -index 0000000..62c86c8 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_volta_tensor_op.cu -@@ -0,0 +1,2893 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_sm70.h" -+#include "cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h" -+ -+#include "cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_volta_tensor_op_64x64_32x32x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = 
cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_volta_tensor_op_128x64_64x32x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_volta_tensor_op_64x128_32x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 4>; -+ using 
WarpShape = cutlass::gemm::GemmShape<32, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_volta_tensor_op_64x64_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using 
OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_volta_tensor_op_64x128_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_volta_tensor_op_128x64_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ using Policy = 
cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_volta_tensor_op_128x128_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ 
-+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_volta_tensor_op_128x256_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_volta_tensor_op_256x128_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<256, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int 
const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Mixed: F32 accumulation -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_f32_volta_tensor_op_64x64_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_f32_volta_tensor_op_128x256_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; 
-+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_f32_volta_tensor_op_256x128_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<256, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ 
Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_f32_volta_tensor_op_128x128_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_f32_volta_tensor_op_64x64_32x32x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int 
const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_f32_volta_tensor_op_64x128_32x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_f32_volta_tensor_op_128x64_64x32x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ 
using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// F32 accumulation, F32 output -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f32_volta_tensor_op_64x64_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = float; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ 
OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f32_volta_tensor_op_64x128_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = float; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f32_volta_tensor_op_128x64_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = float; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = 
typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f32_volta_tensor_op_128x128_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = float; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f32_volta_tensor_op_128x256_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = float; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ 
cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f32_volta_tensor_op_256x128_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<256, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = float; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f32_volta_tensor_op_64x64_32x32x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = 
cutlass::gemm::GemmShape<64, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = float; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f32_volta_tensor_op_128x64_64x32x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = float; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, 
-+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f32_volta_tensor_op_64x128_32x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = float; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// This works -+TEST(SM70_Epilogue_threadblock_epilogue, vec8_f16_f32_volta_tensor_op_64x64_32x32x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ 
cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 8; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+// This works -+TEST(SM70_Epilogue_threadblock_epilogue, vec2_f16_f32_volta_tensor_op_64x64_32x32x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 2; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// This fails -+TEST(SM70_Epilogue_threadblock_epilogue, vec1_f16_f32_volta_tensor_op_64x64_32x32x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using ElementA = 
cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 1; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, vec1_f32_volta_tensor_op_128x128_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = float; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 1; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue 
-+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, vec1_f16_f32_volta_tensor_op_128x128_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 1; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, vec1_f16_f32_volta_tensor_op_128x256_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ 
cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 1; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_tensor_op.cu b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_tensor_op.cu -new file mode 100644 -index 0000000..1932765 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_tensor_op.cu -@@ -0,0 +1,879 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/epilogue/thread/linear_combination_drelu.h" -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_with_reduction.h" -+#include "cutlass/epilogue/threadblock/epilogue_with_reduction.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "epilogue_with_reduction_testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Disable selected tests on CUDA 11.1 -+// -+// -+#define ENABLE_BLOCKED_TESTS (!(__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ == 1)) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f16_tensor_op_64x64_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f32_tensor_op_64x64_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ 
cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f32_tensor_op_128x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f16_tensor_op_128x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 
8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f32_tensor_op_128x64_64x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if ENABLE_BLOCKED_TESTS -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f16_tensor_op_128x64_64x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const 
kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f32_tensor_op_64x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f16_tensor_op_64x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = 
cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f32_tensor_op_128x256_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ 
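Every TEST in this file composes the same chain of types and only varies the threadblock tile, the warp tile, and the output element. A minimal sketch of that pattern follows, written against the f16 128x256 case below; the operands of cutlass::sizeof_bits<>, cutlass::plus<>, and EpilogueWithReductionTestbed<> are assumptions filled in for illustration, not taken verbatim from the patch.

    // Sketch of the per-test composition (SM75 tensor-op epilogue with reduction).
    // Template operands marked "assumed" are illustrative; the rest mirrors the patch above.
    using Element            = cutlass::half_t;          // A/B operand type
    using ElementOutput      = cutlass::half_t;
    using ElementAccumulator = float;
    int const kElementsPerAccess = 128 / cutlass::sizeof_bits<ElementOutput>::value;  // 128-bit accesses: 8 for f16, 4 for f32
    int const kPartitionsK       = 1;

    using Shape            = cutlass::gemm::GemmShape<128, 256, 8>;  // threadblock tile
    using WarpShape        = cutlass::gemm::GemmShape<64, 64, 8>;    // warp tile
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

    using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous<
        cutlass::sizeof_bits<Element>::value, 64>;                    // sizeof_bits operand assumed
    using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous<
        cutlass::sizeof_bits<Element>::value, 64>;
    using LayoutC = cutlass::layout::RowMajor;

    using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
        WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB,
        ElementAccumulator, LayoutC>::Type;

    // Output operator (linear combination fused with dReLU) and the elementwise
    // reduction functor consumed by the epilogue (operand type assumed).
    using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu<
        ElementAccumulator, ElementAccumulator, ElementOutput, ElementOutput, kElementsPerAccess>;
    using ReductionOp = cutlass::plus<
        cutlass::Array<ElementAccumulator, kElementsPerAccess>>;

    using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp<
        Shape, WarpMmaTensorOp, kPartitionsK, ElementOutput, OutputOp, ReductionOp,
        kElementsPerAccess>::Epilogue;

    EpilogueWithReductionTestbed<Epilogue> testbed;                   // template argument assumed
    EXPECT_TRUE(testbed.run_all());

With 128-bit vector accesses, kElementsPerAccess works out to 8 for half_t output and 4 for float output, which, besides the output type itself, is the only constant separating the f16 and f32 variants of each tile size in this file.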
-+TEST(SM75_Epilogue_with_reduction_threadblock, f16_tensor_op_128x256_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f32_tensor_op_256x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool 
passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f16_tensor_op_256x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h -new file mode 100644 -index 0000000..c0e6fcc ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h -@@ -0,0 +1,435 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Unit tests for epilogues -+*/ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+#include "cutlass/complex.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace kernel { -+ -+template -+__global__ void epilogue_with_reduction_threadblock( -+ typename Epilogue::ElementVector *ptr_Reduction, -+ typename Epilogue::OutputTileIterator::Params params_D, -+ typename Epilogue::OutputTileIterator::Element *ptr_D, -+ typename Epilogue::OutputTileIterator::Params params_C, -+ typename Epilogue::OutputTileIterator::Element *ptr_C, -+ typename Epilogue::TensorTileIterator::Params params_Tensor, -+ typename Epilogue::TensorTileIterator::Element *ptr_Tensor, -+ typename Epilogue::OutputOp::Params params_output_op, -+ cutlass::MatrixCoord problem_size, -+ cutlass::TensorRef< -+ typename Epilogue::WarpMmaOperator::ElementC, -+ typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, -+ int epilogue_count = 1) { -+ -+ __shared__ typename Epilogue::SharedStorage shared_storage; -+ -+ int thread_idx = threadIdx.x; -+ int warp_idx = threadIdx.x / 32; -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Construct the epilogue -+ // -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::OutputTileIterator iterator_D( -+ params_D, -+ ptr_D, -+ problem_size, -+ thread_idx -+ ); -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::OutputTileIterator iterator_C( -+ params_C, -+ ptr_C, -+ problem_size, -+ thread_idx -+ ); -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::TensorTileIterator iterator_T( -+ params_Tensor, -+ ptr_Tensor, -+ problem_size, -+ thread_idx -+ ); -+ -+ // Epilogue operator -+ Epilogue epilogue( -+ shared_storage, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // -+ // Initialize the accumulators -+ // -+ -+ int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); -+ int warp_m = warp_mn % Epilogue::WarpCount::kM; -+ int warp_n = warp_mn / Epilogue::WarpCount::kM; -+ -+ accumulator_ref.add_coord_offset({ -+ warp_m * Epilogue::WarpMmaOperator::Shape::kM, -+ warp_n * Epilogue::WarpMmaOperator::Shape::kN}); -+ -+ typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); -+ -+ typename 
Epilogue::AccumulatorTile accumulators; -+ -+ accumulators.clear(); -+ accumulator_iterator.load(accumulators); -+ -+#if 0 -+ // For debugging, enable this block of code to fill each accumulator element with its -+ // source thread ID. -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < accumulators.size(); ++i) { -+ typename Epilogue::WarpMmaOperator::ElementC x(threadIdx.x); -+ //typename Epilogue::WarpMmaOperator::ElementC x(i); -+ accumulators[i] = x; -+ } -+ -+ /* -+ #pragma unroll 1 -+ for (int tid = 0; tid < 32; ++tid) { -+ if (tid == thread_idx) { -+ printf("\nT%d: ", thread_idx); -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < accumulators.size(); ++i) { -+ printf("%d ", int(accumulators[i])); -+ } -+ } -+ } -+ -+ if (thread_idx == 0) { -+ printf("\n\n"); -+ } -+ */ -+ -+ __syncthreads(); -+ -+#endif -+ -+ // -+ // Perform the epilogue operation -+ // -+ -+ typename Epilogue::OutputOp output_op(params_output_op); -+ -+ // Place the epilogue in a loop -+ for (int iter = 0; iter < epilogue_count; ++iter) { -+ epilogue(output_op, ptr_Reduction, iterator_D, accumulators, iterator_C, iterator_T); -+ } -+} -+ -+} // namespace kernel -+} // namespace test -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Epilogue_ -+> -+class EpilogueWithReductionTestbed { -+public: -+ -+ using Epilogue = Epilogue_; -+ using ElementAccumulator = typename Epilogue::ElementAccumulator; -+ using ElementCompute = typename Epilogue::OutputOp::ElementCompute; -+ using ElementTensor = typename Epilogue::TensorTileIterator::Element; -+ using ElementOutput = typename Epilogue::ElementOutput; -+ using OutputOpParams = typename Epilogue::OutputOp::Params; -+ -+public: -+ -+ // -+ // Data members -+ // -+ -+ cutlass::MatrixCoord quantized_size; -+ cutlass::HostTensor accumulator_tensor; -+ cutlass::HostTensor source_tensor; -+ cutlass::HostTensor output_tensor; -+ cutlass::HostTensor additional_tensor; -+ cutlass::HostTensor reduction_tensor; -+ -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ EpilogueWithReductionTestbed(): -+ quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), -+ accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), -+ source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), -+ output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), -+ additional_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), -+ reduction_tensor({1, Epilogue::Shape::kN}) { -+ -+ // -+ // Initialize problem space -+ // -+ -+ uint64_t seed = 2019; -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ accumulator_tensor.host_view(), -+ seed, -+ 20, -+ -20, -+ 0); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ source_tensor.host_view(), -+ seed + 2018, -+ 20, -+ -20, -+ 0); -+ -+ cutlass::reference::host::TensorFill(additional_tensor.host_view(), ElementTensor(1)); -+ } -+ -+ bool run_all() { -+ -+ /* -+ double alpha_values[] = {1, 0, 2.25}; -+ double beta_values[] = {0, 1, -1.25}; -+ -+ // Test runtime explodes if we tried to test every case exhaustively. This tests the full -+ // output tile and several smaller sizes to stress predication. 
-+ for (int m_idx = 0; m_idx < 3; ++m_idx) { -+ for (int n_idx = 0; n_idx < 3; ++n_idx) { -+ -+ int m = quantized_size.row() - m_idx * 3; -+ int n = quantized_size.column() - n_idx * Epilogue::kElementsPerAccess; -+ -+ for (double const &alpha : alpha_values) { -+ for (double const &beta : beta_values) { -+ -+ bool passed = run({m, n}, {cutlass::from_real(alpha), cutlass::from_real(beta)}); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ return true; -+ */ -+ -+ double alpha = 1; -+ double beta = 0; -+ -+ return run( -+ {quantized_size.row(), quantized_size.column()}, -+ {cutlass::from_real(alpha), cutlass::from_real(beta)}); -+ } -+ -+ /// Runs the test -+ bool run( -+ cutlass::MatrixCoord problem_size, -+ OutputOpParams output_params) { -+ -+ // -+ // Initialize problem space -+ // -+ -+ ElementOutput default_output = ElementOutput(-127); -+ ElementAccumulator default_reduction = ElementAccumulator(); -+ -+ cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); -+ cutlass::reference::host::TensorFill(reduction_tensor.host_view(), default_reduction); -+ -+ accumulator_tensor.sync_device(); -+ output_tensor.sync_device(); -+ source_tensor.sync_device(); -+ additional_tensor.sync_device(); -+ reduction_tensor.sync_device(); -+ -+ // -+ // Initialize epilogue parameters -+ // -+ -+ typename Epilogue::OutputTileIterator::Params params_D(output_tensor.device_ref().layout()); -+ typename Epilogue::OutputTileIterator::Params params_C(source_tensor.device_ref().layout()); -+ typename Epilogue::TensorTileIterator::Params params_T(additional_tensor.device_ref().layout()); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1, 1); -+ dim3 block(Epilogue::WarpCount::kCount * 32, 1); -+ -+ test::kernel::epilogue_with_reduction_threadblock<<< grid, block >>>( -+ reduction_tensor.device_data(), -+ params_D, -+ output_tensor.device_data(), -+ params_C, -+ source_tensor.device_data(), -+ params_T, -+ additional_tensor.device_data(), -+ output_params, -+ problem_size, -+ accumulator_tensor.device_view()); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ // -+ // Verify results -+ // -+ output_tensor.sync_host(); -+ reduction_tensor.sync_host(); -+ -+ int errors = 0; -+ int const kMaxErrors = 5; -+ -+ // -+ // The output has two parts: -+ // - GEMM tensor epilogue in canonical layout -+ // - partial reduction in canonical row-major layout -+ // -+ -+ // Verify the GEMM tensor output -+ for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { -+ for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { -+ -+ cutlass::MatrixCoord coord{r, c}; -+ ElementOutput got = output_tensor.at(coord); -+ -+ ElementOutput expected; -+ if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { -+ -+ expected = ElementOutput(output_params.alpha * ElementCompute(accumulator_tensor.at(coord)) + -+ output_params.beta * ElementCompute(source_tensor.at(coord))); -+ } -+ else { -+ expected = default_output; -+ } -+ -+ if (expected != got) { -+ -+ using OutputIO = cutlass::ScalarIO; -+ -+ EXPECT_TRUE(false) -+ << "-------\n" -+ << "Error - output element (" << coord << ") - expected: " -+ << OutputIO(expected) -+ << ", got: " << OutputIO(got) << std::endl; -+ -+ ++errors; -+ } -+ } -+ } -+ -+ // Verify the partial reduction -+ for (int c = 0; c < quantized_size.column(); ++c) { -+ -+ 
ElementAccumulator reduction_acc = ElementAccumulator(); -+ -+ for (int r = 0; r < quantized_size.row(); ++r) { -+ reduction_acc += accumulator_tensor.at({r, c}); -+ } -+ -+ ElementAccumulator expected = default_reduction; -+ ElementAccumulator got = reduction_tensor.at({0, c}); -+ -+ if (c < problem_size.column()) { -+ expected = reduction_acc; -+ } -+ else { -+ expected = default_reduction; -+ } -+ -+ if (expected != got) { -+ -+ using OutputIO = cutlass::ScalarIO; -+ -+ EXPECT_TRUE(false) -+ << "-------\n" -+ << "Error - reduction element (" << c << ") - expected: " -+ << OutputIO(expected) -+ << ", got: " << OutputIO(got) << std::endl; -+ } -+ } -+ -+ // -+ // Report results on error -+ // -+ -+ if (errors) { -+ std::stringstream ss; -+ ss -+ << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" -+ << Epilogue::WarpTileIterator::WarpShape::kM << "x" -+ << Epilogue::WarpTileIterator::WarpShape::kN -+ << "_slice_" << Epilogue::WarpCount::kK << ".csv"; -+ -+ std::ofstream output_file(ss.str()); -+ output_file << output_tensor.host_view(); -+ } -+ -+ return !errors; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_wmma_tensor_op_sm70.cu b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_wmma_tensor_op_sm70.cu -new file mode 100644 -index 0000000..bc835f2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_wmma_tensor_op_sm70.cu -@@ -0,0 +1,264 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Unit tests for thread-level GEMM -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/gemm/warp/default_mma_wmma_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "testbed.h" -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// F16 acumulation -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Epilogue_threadblock_epilogue, f16_wmma_tensor_op_64x64_64x64x16) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+ -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_wmma_tensor_op_64x128_64x64x16) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ 
ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// F32 acumulation and F32 output -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Epilogue_threadblock_epilogue, f32_wmma_tensor_op_64x64_64x64x16) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+ -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+#endif //CUTLASS_ARCH_WMMA_ENABLED -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/output_tile_threadmap.cu b/3rdparty/cutlass/test/unit/epilogue/threadblock/output_tile_threadmap.cu -new file mode 100644 -index 0000000..7874363 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/output_tile_threadmap.cu -@@ -0,0 +1,549 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" -+ -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Prototype algorithm for partitioning a 4D space across warps to achieve several performance -+/// objectives: -+/// -+/// - coalesced memory accesses in units of 128 Byte lines -+/// - minimal address arithmetic -+/// - minimal predicate calculations -+/// -+struct OutputTileThreadMapExpr { -+ -+ struct Shape { -+ int column; -+ int row; -+ int group; -+ int cluster; -+ -+ Shape(int col = 1, int r = 1, int g = 1, int c = 1): -+ column(col), row(r), group(g), cluster(c) { } -+ }; -+ -+ int const kWarpSize = 32; -+ int const kMemoryAccessSize = 256; // size in bytes of the preferred memory access size -+ -+ // -+ // Data members -+ // -+ -+ Shape shape; -+ Shape count; -+ int threads; -+ int warp_count; -+ int elements_per_access; -+ int element_size; -+ -+ Shape iterations; -+ Shape delta; -+ Shape warp_partitions; -+ -+ int access_width_in_vectors; -+ int access_rows; -+ -+ // -+ // Methods -+ // -+ -+ OutputTileThreadMapExpr( -+ Shape shape_, -+ Shape count_, -+ int threads_, -+ int elements_per_access_, -+ int element_size_ -+ ): -+ shape(shape_), -+ count(count_), -+ threads(threads_), -+ warp_count(threads_ / kWarpSize), -+ elements_per_access(elements_per_access_), -+ element_size(element_size_) { -+ -+ int warps_remaining = warp_count; -+ -+ // clusters -+ if (shape.cluster > warp_count) { -+ iterations.cluster = shape.cluster / warp_count; -+ delta.cluster = shape.row * count.row * shape.group * count.group * shape.cluster / iterations.cluster; -+ warps_remaining = 1; -+ warp_partitions.cluster = warp_count; -+ } -+ else { -+ iterations.cluster = 1; -+ delta.cluster = 1; -+ warps_remaining = warp_count / shape.cluster; -+ warp_partitions.cluster = warps_remaining; -+ } -+ -+ // group size -+ if (shape.group > warps_remaining) { -+ iterations.group = shape.group / warps_remaining; 
-+ delta.group = shape.row * count.row * shape.group / iterations.group; -+ warps_remaining = 1; -+ warp_partitions.group = warps_remaining; -+ } -+ else { -+ iterations.group = 1; -+ delta.group = 1; -+ warps_remaining = warps_remaining / shape.group; -+ warp_partitions.group = warps_remaining; -+ } -+ -+ // Number of rows in a group -+ if (shape.row > warps_remaining) { -+ -+ // We must cover this shape within a warp -+ int shape_row = shape.row / warps_remaining; -+ int shape_width_vectors = shape.column / elements_per_access; -+ -+ // We would still like to minimize the number of strided increments. We can accomplish this -+ // by arranging the memory instructions as 2D, 128B wide accesses. -+ -+ int target_memory_access_width = kMemoryAccessSize / (elements_per_access * element_size / 8); -+ int target_rows_per_access = kWarpSize / target_memory_access_width; -+ -+ if (target_rows_per_access > shape_row) { -+ access_rows = shape_row; -+ access_width_in_vectors = kWarpSize / access_rows; -+ } -+ else { -+ -+ access_width_in_vectors = cutlass::platform::min( -+ shape_width_vectors, -+ cutlass::platform::min(kWarpSize, kMemoryAccessSize / (elements_per_access * element_size / 8))); -+ -+ access_rows = cutlass::platform::min(shape_row, kWarpSize / access_width_in_vectors); -+ } -+ -+ iterations.row = shape_row / access_rows; -+ delta.row = access_rows; -+ -+ iterations.column = shape_width_vectors / access_width_in_vectors; -+ delta.column = access_width_in_vectors * elements_per_access; -+ -+ warp_partitions.column = 1; -+ warp_partitions.row = 1; -+ } -+ else { -+ iterations.row = 1; -+ delta.row = 1; -+ iterations.column = (shape.column / elements_per_access) / kWarpSize; -+ delta.column = kWarpSize * elements_per_access; -+ -+ access_width_in_vectors = kWarpSize; -+ access_rows = 1; -+ -+ warp_partitions.row = 1; -+ warp_partitions.column = warps_remaining; -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+std::ostream & operator<<(std::ostream &out, OutputTileThreadMapExpr::Shape const &shape) { -+ out << "col: " << shape.column << ", r: " << shape.row << ", g: " << shape.group << ", c: " << shape.cluster; -+ return out; -+} -+ -+std::ostream & operator<<(std::ostream &out, OutputTileThreadMapExpr const &map) { -+ out -+ << " shape(" << map.shape << ")\n" -+ << " count(" << map.count << ")\n" -+ << " iterations(" << map.iterations << ")\n" -+ << " delta(" << map.delta << ")\n" -+ << " warps(" << map.warp_partitions << ")\n" -+ << " access(width: " << map.access_width_in_vectors -+ << ", rows: " << map.access_rows -+ << ") x v" << map.elements_per_access -+ << ".b" << map.element_size << "\n"; -+ -+ return out; -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape, -+ typename Count, -+ int Threads, -+ int ElementsPerAccess, -+ int ElementSize -+> -+struct ThreadMapTestbed { -+ ThreadMapTestbed() { -+ OutputTileThreadMapExpr map( -+ { Shape::kColumn, Shape::kRow, Shape::kGroup, Shape::kCluster }, -+ { Count::kColumn, Count::kRow, Count::kGroup, Count::kCluster }, -+ Threads, -+ ElementsPerAccess, -+ ElementSize -+ ); -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileOptimalThreadMap< -+ Shape, -+ Count, -+ Threads, -+ ElementsPerAccess, -+ ElementSize -+ >; -+ -+ using CompactThreadmap = typename ThreadMap::CompactedThreadMap; -+ -+ bool const kVerbose = false; -+ -+ if (kVerbose) { -+ -+ std::cout << map << 
std::endl; -+ -+ std::cout << "ThreadMap::warps remaining:\n" -+ << " for groups: " << ThreadMap::Detail::kWarpsRemainingForGroups << "\n" -+ << " for rows: " << ThreadMap::Detail::kWarpsRemainingForRows << "\n"; -+ -+ std::cout << "ThreadMap::Access:\n" -+ << " width: " << ThreadMap::Detail::kAccessWidth << "\n" -+ << " rows: " << ThreadMap::Detail::kAccessRows << "\n"; -+ -+ std::cout << "ThreadMap::RowArrangement::Iterations:\n" -+ << " row: " << int(ThreadMap::Detail::RowArrangement::kIterationsRow) << "\n"; -+ } -+ -+ EXPECT_EQ(int(ThreadMap::Delta::kCluster), map.delta.cluster); -+ EXPECT_EQ(int(ThreadMap::Delta::kGroup), map.delta.group); -+ EXPECT_EQ(int(ThreadMap::Delta::kRow), map.delta.row); -+ EXPECT_EQ(int(ThreadMap::Delta::kColumn), map.delta.column); -+ -+ EXPECT_EQ(int(ThreadMap::Iterations::kCluster), map.iterations.cluster); -+ EXPECT_EQ(int(ThreadMap::Iterations::kGroup), map.iterations.group); -+ EXPECT_EQ(int(ThreadMap::Iterations::kRow), map.iterations.row); -+ EXPECT_EQ(int(ThreadMap::Iterations::kColumn), map.iterations.column); -+ -+ if (kVerbose) { -+ std::cout << "Iterations(col: " << ThreadMap::Iterations::kColumn -+ << ", r: " << ThreadMap::Iterations::kRow -+ << ", g: " << ThreadMap::Iterations::kGroup -+ << ", c: " << ThreadMap::Iterations::kCluster << ")\n"; -+ -+ std::cout << "Delta(col: " << ThreadMap::Delta::kColumn -+ << ", r: " << ThreadMap::Delta::kRow -+ << ", g: " << ThreadMap::Delta::kGroup -+ << ", c: " << ThreadMap::Delta::kCluster << ")\n"; -+ -+ for (int tid = 0; tid < Threads; ++tid) { -+ auto output_coord = ThreadMap::initial_offset(tid); -+ auto source_coord = CompactThreadmap::initial_offset(tid); -+ -+ std::cout << "T" << tid << " - output: " << output_coord << ", source: " << source_coord << "\n"; -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(ThreadMap, f16_tensor_op_64x64_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<64, 8, 1, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1, 1>; -+ int const kThreads = 32; -+ int const kElementsPerAccess = 8; -+ int const kElementSize = 16; -+ -+ ThreadMapTestbed(); -+} -+ -+ -+TEST(ThreadMap, f16_tensor_op_128x128_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 8, 2, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1, 1>; -+ int const kThreads = 128; -+ int const kElementsPerAccess = 8; -+ int const kElementSize = 16; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f16_tensor_op_256x128_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 8, 4, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1, 1>; -+ int const kThreads = 256; -+ int const kElementsPerAccess = 8; -+ int const kElementSize = 16; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f16_tensor_op_128x256_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<256, 8, 2, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1, 1>; -+ int const kThreads = 256; -+ int const kElementsPerAccess = 8; -+ int const kElementSize = 16; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f16_tensor_op_128x64_64x32x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<64, 8, 2, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1, 1>; -+ int const kThreads = 128; -+ int 
const kElementsPerAccess = 8; -+ int const kElementSize = 16; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f16_tensor_op_64x128_128x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 8, 1, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1, 1>; -+ int const kThreads = 128; -+ int const kElementsPerAccess = 8; -+ int const kElementSize = 16; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_tensor_op_64x64_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<64, 8, 1, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1, 1>; -+ int const kThreads = 32; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_tensor_op_128x128_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 8, 2, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1, 1>; -+ int const kThreads = 128; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_tensor_op_256x128_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 8, 4, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1, 1>; -+ int const kThreads = 256; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_tensor_op_128x256_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<256, 8, 2, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1, 1>; -+ int const kThreads = 256; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_tensor_op_128x64_64x32x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<64, 8, 2, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1, 1>; -+ int const kThreads = 128; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_tensor_op_64x128_128x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 8, 1, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1, 1>; -+ int const kThreads = 128; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(ThreadMap, f32_volta_tensor_op_64x64_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<64, 2, 4, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 1>; -+ int const kThreads = 32; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_volta_tensor_op_64x128_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 2, 4, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 1>; -+ int const kThreads = 64; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_volta_tensor_op_128x64_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<64, 2, 4, 2, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 
1, 1>; -+ int const kThreads = 64; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_volta_tensor_op_128x64_64x32x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<64, 2, 4, 2, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 1>; -+ int const kThreads = 128; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_volta_tensor_op_128x128_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 2, 4, 2, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 1>; -+ int const kThreads = 128; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_volta_tensor_op_128x256_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<256, 2, 4, 2, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 1>; -+ int const kThreads = 256; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_volta_tensor_op_256x128_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 2, 4, 4, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 1>; -+ int const kThreads = 256; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(ThreadMap, simt_32x64_32x64x1) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<64, 1, 4, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 1>; -+ int const kThreads = 32; -+ int const kElementsPerAccess = 1; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, simt_32x128_32x64x1) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 1, 4, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 1>; -+ int const kThreads = 64; -+ int const kElementsPerAccess = 1; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, simt_64x128_32x64x1) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 1, 4, 2, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 1>; -+ int const kThreads = 128; -+ int const kElementsPerAccess = 1; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, simt_128x128_32x64x1) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 1, 4, 4, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 1>; -+ int const kThreads = 256; -+ int const kElementsPerAccess = 1; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/predicated_tile_iterator.cu b/3rdparty/cutlass/test/unit/epilogue/threadblock/predicated_tile_iterator.cu -new file mode 100644 -index 0000000..287b51a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/predicated_tile_iterator.cu -@@ -0,0 +1,1125 @@ -+/*************************************************************************************************** -+ * 
Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" -+ -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void kernel_store_iterator( -+ typename TileIterator::Params params, -+ typename TileIterator::TensorRef ref, -+ cutlass::MatrixCoord extent) { -+ -+ TileIterator iterator(params, ref.data(), extent, threadIdx.x, {0, 0}); -+ -+ typename TileIterator::Fragment fragment; -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int iter = 0; iter < TileIterator::ThreadMap::Count::kTile; ++iter) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < TileIterator::Fragment::kElements; ++i) { -+ typename TileIterator::Element tidx(iter + 1); -+ fragment[i] = tidx; -+ } -+ -+ iterator.store(fragment); -+ -+ ++iterator; -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} -+} -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+static bool verify_footprint(cutlass::TensorView view, cutlass::MatrixCoord extent) { -+ -+ for (int r = 0; r < view.extent().row(); ++r) { -+ for (int c = 0; c < view.extent().column(); ++c) { -+ -+ cutlass::MatrixCoord coord{r, c}; -+ bool within = coord < extent; -+ if (within) { -+ if (view.at(coord) == T(0)) { -+ return false; -+ } -+ } -+ else { -+ if (view.at(coord) != T(0)) { -+ return false; -+ } -+ } -+ } -+ } -+ -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(PredicatedTileIterator, tensor_op_64x64x32_64x64x8) { -+ -+ using Layout = cutlass::layout::RowMajor; -+ using Element = int; -+ -+ static int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ static int const kThreads = 32; -+ -+ // -+ // The following tests were used to develop the OutputTileOptimalThreadMap -+ // metaprogram. The definitions in the disabled blocks of code in this and -+ // the following tests are hand-written quantities. They are expected to -+ // match what is defined in the ThreadMap. 
-+ // -+ -+ #if 1 -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileOptimalThreadMap < -+ cutlass::epilogue::threadblock::OutputTileShape<64, 8, 1, 1, 1>, -+ cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1, 8>, -+ kThreads, -+ kElementsPerAccess, -+ cutlass::sizeof_bits::value -+ >; -+ -+ #else -+ using InternalThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< -+ cutlass::layout::PitchLinearShape<64, 64>, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape< -+ 64, // column -+ 8, // row -+ 1, // group -+ 1, // cluster -+ 1 // iterations -+ >; -+ -+ using Iterations = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 4, // row -+ 1, // group -+ 1, // cluster -+ 1 // iterations -+ >; -+ -+ using Delta = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 2, // row -+ 1, // group -+ 1, // cluster -+ 1 // iterations -+ >; -+ -+ using Count = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 8, // row -+ 1, // group -+ 1, // cluster -+ 8 // iterations -+ >; -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap< -+ InternalThreadMap, -+ Shape, -+ Iterations, -+ Delta, -+ Count -+ >; -+ #endif -+ -+ using PredicatedTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ ThreadMap, -+ Element -+ >; -+ -+ // -+ // Initialize workspace -+ // -+ cutlass::MatrixCoord tensor_extent{64, 64}; -+ cutlass::MatrixCoord output_extent{62, 56}; -+ -+ // -+ // Configure parameters -+ // -+ -+ cutlass::HostTensor host_tensor(tensor_extent); -+ -+ typename PredicatedTileIterator::Params iterator_params(host_tensor.layout()); -+ -+ host_tensor.sync_device(); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1,1); -+ dim3 block(kThreads, 1); -+ -+ test::epilogue::threadblock::kernel_store_iterator<<< grid, block >>>( -+ iterator_params, host_tensor.device_ref(), output_extent); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result); -+ -+ // -+ // Verify results -+ // -+ -+ host_tensor.sync_host(); -+ -+ bool passed = verify_footprint(host_tensor.host_view(), output_extent); -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::ofstream output("tensor_op_64x64x32_64x64x8.csv"); -+ output << host_tensor.host_view(); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(PredicatedTileIterator, tensor_op_128x64x32_64x64x8) { -+ -+ using Layout = cutlass::layout::RowMajor; -+ using Element = int; -+ -+ static int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ static int const kThreads = 64; -+ -+ #if 1 -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileOptimalThreadMap < -+ cutlass::epilogue::threadblock::OutputTileShape<128, 8, 2, 1, 1>, -+ cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1, 8>, -+ kThreads, -+ kElementsPerAccess, -+ cutlass::sizeof_bits::value -+ >; -+ -+ #else -+ using InternalThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< -+ cutlass::layout::PitchLinearShape<64, 128>, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape< -+ 64, // column -+ 8, // row -+ 2, // group -+ 1, // cluster -+ 8 // tile -+ >; -+ -+ using Iterations = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 2, // row -+ 2, // group -+ 1, // cluster -+ 1 // iterations -+ >; -+ -+ using Delta = 
cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 4, // row -+ 64, // group -+ 1, // cluster -+ 1 // tile -+ >; -+ -+ using Count = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 8, // row -+ 1, // group -+ 1, // cluster -+ 8 // iterations -+ >; -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap< -+ InternalThreadMap, -+ Shape, -+ Iterations, -+ Delta, -+ Count -+ >; -+ #endif -+ -+ using PredicatedTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ ThreadMap, -+ Element -+ >; -+ -+ // -+ // Initialize workspace -+ // -+ cutlass::MatrixCoord tensor_extent{128, 64}; -+ cutlass::MatrixCoord output_extent{125, 56}; -+ -+ // -+ // Configure parameters -+ // -+ -+ cutlass::HostTensor host_tensor(tensor_extent); -+ -+ typename PredicatedTileIterator::Params iterator_params(host_tensor.layout()); -+ -+ host_tensor.sync_device(); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1,1); -+ dim3 block(kThreads, 1); -+ -+ test::epilogue::threadblock::kernel_store_iterator<<< grid, block >>>( -+ iterator_params, host_tensor.device_ref(), output_extent); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result); -+ -+ // -+ // Verify results -+ // -+ -+ host_tensor.sync_host(); -+ -+ bool passed = verify_footprint(host_tensor.host_view(), output_extent); -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::ofstream output("tensor_op_128x64x32_64x64x8.csv"); -+ output << host_tensor.host_view(); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(PredicatedTileIterator, tensor_op_128x256x32_64x64x8) { -+ -+ using Layout = cutlass::layout::RowMajor; -+ using Element = int; -+ -+ static int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ static int const kThreads = 256; -+ -+ #if 1 -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileOptimalThreadMap < -+ cutlass::epilogue::threadblock::OutputTileShape<256, 8, 2, 1, 1>, -+ cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1, 8>, -+ kThreads, -+ kElementsPerAccess, -+ cutlass::sizeof_bits::value -+ >; -+ -+ #else -+ using InternalThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< -+ cutlass::layout::PitchLinearShape<256, 128>, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape< -+ 256, // column -+ 8, // row -+ 2, // group -+ 1, // cluster -+ 8 // tile -+ >; -+ -+ using Iterations = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 2, // row -+ 2, // group -+ 1, // cluster -+ 1 // iterations -+ >; -+ -+ using Delta = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 4, // row -+ 64, // group -+ 1, // cluster -+ 1 // tile -+ >; -+ -+ using Count = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 8, // row -+ 1, // group -+ 1, // cluster -+ 8 // iterations -+ >; -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap< -+ InternalThreadMap, -+ Shape, -+ Iterations, -+ Delta, -+ Count -+ >; -+ #endif -+ -+ using PredicatedTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ ThreadMap, -+ Element -+ >; -+ -+ // -+ // Initialize workspace -+ // -+ cutlass::MatrixCoord tensor_extent{128, 256}; -+ cutlass::MatrixCoord output_extent{123, 252}; -+ -+ 
// -+ // Configure parameters -+ // -+ -+ cutlass::HostTensor host_tensor(tensor_extent); -+ -+ typename PredicatedTileIterator::Params iterator_params(host_tensor.layout()); -+ -+ host_tensor.sync_device(); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1,1); -+ dim3 block(kThreads, 1); -+ -+ test::epilogue::threadblock::kernel_store_iterator<<< grid, block >>>( -+ iterator_params, host_tensor.device_ref(), output_extent); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result); -+ -+ // -+ // Verify results -+ // -+ -+ host_tensor.sync_host(); -+ -+ bool passed = verify_footprint(host_tensor.host_view(), output_extent); -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::ofstream output("tensor_op_128x256x32_64x64x8.csv"); -+ output << host_tensor.host_view(); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(PredicatedTileIterator, volta_tensor_op_64x64x32_64x64x4) { -+ -+ using Layout = cutlass::layout::RowMajor; -+ using Element = int; -+ -+ static int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ static int const kThreads = 32; -+ -+#if 1 -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileOptimalThreadMap < -+ cutlass::epilogue::threadblock::OutputTileShape<64, 2, 4, 1, 1>, -+ cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 8>, -+ kThreads, -+ kElementsPerAccess, -+ cutlass::sizeof_bits::value -+ >; -+ -+#else -+ using InternalThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< -+ cutlass::layout::PitchLinearShape<64, 8>, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape< -+ 64, // column -+ 2, // row -+ 4, // group -+ 1, // cluster -+ 8 // iterations -+ >; -+ -+ using Iterations = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 1, // row -+ 4, // group -+ 1, // cluster -+ 1 // iterations -+ >; -+ -+ using Delta = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 1, // row -+ 8, // group -+ 1, // cluster -+ 1 // iterations -+ >; -+ -+ using Count = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 4, // row -+ 2, // group -+ 1, // cluster -+ 8 // iterations -+ >; -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap< -+ InternalThreadMap, -+ Shape, -+ Iterations, -+ Delta, -+ Count -+ >; -+#endif -+ -+ using PredicatedTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ ThreadMap, -+ Element -+ >; -+ -+ // -+ // Initialize workspace -+ // -+ cutlass::MatrixCoord tensor_extent{64, 64}; -+ cutlass::MatrixCoord output_extent{62, 56}; -+ -+ // -+ // Configure parameters -+ // -+ -+ cutlass::HostTensor host_tensor(tensor_extent); -+ -+ typename PredicatedTileIterator::Params iterator_params(host_tensor.layout()); -+ -+ host_tensor.sync_device(); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1,1); -+ dim3 block(kThreads, 1); -+ -+ test::epilogue::threadblock::kernel_store_iterator<<< grid, block >>>( -+ iterator_params, host_tensor.device_ref(), output_extent); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result); -+ -+ // -+ // Verify results -+ // -+ host_tensor.sync_host(); -+ -+ bool passed = verify_footprint(host_tensor.host_view(), output_extent); -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::ofstream output("volta_tensor_op_64x64x32_64x64x4.csv"); -+ output << 
host_tensor.host_view(); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(PredicatedTileIterator, volta_tensor_op_64x128x32_32x64x4) { -+ -+ using Layout = cutlass::layout::RowMajor; -+ using Element = int; -+ -+ static int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ static int const kThreads = 128; -+ -+#if 1 -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileOptimalThreadMap < -+ cutlass::epilogue::threadblock::OutputTileShape<128, 2, 4, 1, 1>, -+ cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 8>, -+ kThreads, -+ kElementsPerAccess, -+ cutlass::sizeof_bits::value -+ >; -+ -+#else -+ using InternalThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< -+ cutlass::layout::PitchLinearShape<128, 8>, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape< -+ 128, // column -+ 2, // row -+ 2, // group -+ 2, // cluster -+ 8 // iterations -+ >; -+ -+ using Iterations = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 1, // row -+ 1, // group -+ 2, // cluster -+ 1 // iterations -+ >; -+ -+ using Delta = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 1, // row -+ 8, // group -+ 32, // cluster -+ 1 // iterations -+ >; -+ -+ using Count = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 4, // row -+ 4, // group -+ 1, // cluster -+ 8 // iterations -+ >; -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap< -+ InternalThreadMap, -+ Shape, -+ Iterations, -+ Delta, -+ Count -+ >; -+#endif -+ -+ using PredicatedTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ ThreadMap, -+ Element -+ >; -+ -+ // -+ // Initialize workspace -+ // -+ cutlass::MatrixCoord tensor_extent{64, 128}; -+ cutlass::MatrixCoord output_extent{57, 124}; -+ -+ // -+ // Configure parameters -+ // -+ -+ cutlass::HostTensor host_tensor(tensor_extent); -+ -+ typename PredicatedTileIterator::Params iterator_params(host_tensor.layout()); -+ -+ host_tensor.sync_device(); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1,1); -+ dim3 block(kThreads, 1); -+ -+ test::epilogue::threadblock::kernel_store_iterator<<< grid, block >>>( -+ iterator_params, host_tensor.device_ref(), output_extent); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result); -+ -+ // -+ // Verify results -+ // -+ -+ host_tensor.sync_host(); -+ -+ bool passed = verify_footprint(host_tensor.host_view(), output_extent); -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::ofstream output("volta_tensor_op_64x128x32_32x64x4.csv"); -+ output << host_tensor.host_view(); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(PredicatedTileIterator, volta_tensor_op_128x256x32_64x64x4) { -+ -+ using Layout = cutlass::layout::RowMajor; -+ using Element = int; -+ -+ static int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ static int const kThreads = 256; -+ -+#if 1 -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileOptimalThreadMap < -+ cutlass::epilogue::threadblock::OutputTileShape<256, 2, 4, 2, 1>, -+ cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 8>, -+ kThreads, -+ kElementsPerAccess, -+ cutlass::sizeof_bits::value -+ >; -+ -+#else -+ using InternalThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< -+ 
cutlass::layout::PitchLinearShape<256, 16>, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape< -+ 256, // column -+ 2, // row -+ 4, // group -+ 2, // cluster -+ 8 // iterations -+ >; -+ -+ using Iterations = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 1, // row -+ 2, // group -+ 2, // cluster -+ 1 // iterations -+ >; -+ -+ using Delta = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 1, // row -+ 16, // group -+ 64, // cluster -+ 1 // iterations -+ >; -+ -+ using Count = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 4, // row -+ 2, // group -+ 1, // cluster -+ 8 // iterations -+ >; -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap< -+ InternalThreadMap, -+ Shape, -+ Iterations, -+ Delta, -+ Count -+ >; -+#endif -+ -+ using PredicatedTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ ThreadMap, -+ Element -+ >; -+ -+ // -+ // Initialize workspace -+ // -+ cutlass::MatrixCoord tensor_extent{128, 256}; -+ cutlass::MatrixCoord output_extent{128, 256}; -+ -+ // -+ // Configure parameters -+ // -+ -+ cutlass::HostTensor host_tensor(tensor_extent); -+ -+ typename PredicatedTileIterator::Params iterator_params(host_tensor.layout()); -+ -+ host_tensor.sync_device(); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1,1); -+ dim3 block(kThreads, 1); -+ -+ test::epilogue::threadblock::kernel_store_iterator<<< grid, block >>>( -+ iterator_params, host_tensor.device_ref(), output_extent); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result); -+ -+ // -+ // Verify results -+ // -+ -+ host_tensor.sync_host(); -+ -+ bool passed = verify_footprint(host_tensor.host_view(), output_extent); -+ EXPECT_TRUE(passed); -+ -+ if (!passed || true) { -+ std::ofstream output("volta_tensor_op_128x256x32_64x64x4.csv"); -+ output << host_tensor.host_view(); -+ } -+} -+ -+ -+TEST(PredicatedTileIterator, volta_tensor_op_256x128x32_64x64x4) { -+ -+ using Layout = cutlass::layout::RowMajor; -+ using Element = int; -+ -+ static int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ static int const kThreads = 256; -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileOptimalThreadMap < -+ cutlass::epilogue::threadblock::OutputTileShape<128, 2, 4, 4, 1>, -+ cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 8>, -+ kThreads, -+ kElementsPerAccess, -+ cutlass::sizeof_bits::value -+ >; -+ -+ using PredicatedTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ ThreadMap, -+ Element -+ >; -+ -+ // -+ // Initialize workspace -+ // -+ cutlass::MatrixCoord tensor_extent{ 256, 128 }; -+ cutlass::MatrixCoord output_extent{ 256, 128 }; -+ -+ // -+ // Configure parameters -+ // -+ -+ cutlass::HostTensor host_tensor(tensor_extent); -+ -+ typename PredicatedTileIterator::Params iterator_params(host_tensor.layout()); -+ -+ host_tensor.sync_device(); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1, 1); -+ dim3 block(kThreads, 1); -+ -+ test::epilogue::threadblock::kernel_store_iterator <<< grid, block >>>( -+ iterator_params, host_tensor.device_ref(), output_extent); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result); -+ -+ // -+ // Verify results -+ // -+ -+ host_tensor.sync_host(); -+ -+ bool passed = verify_footprint(host_tensor.host_view(), output_extent); -+ EXPECT_TRUE(passed); 
-+ -+ if (!passed || true) { -+ std::ofstream output("volta_tensor_op_256x128x32_64x64x4.csv"); -+ output << host_tensor.host_view(); -+ } -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(PredicatedTileIterator, simt_32x64x8_32x64x1) { -+ -+ using Layout = cutlass::layout::RowMajor; -+ using Element = int; -+ -+ static int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; -+ static int const kThreads = 32; -+ -+#if 1 -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileOptimalThreadMap < -+ cutlass::epilogue::threadblock::OutputTileShape<64, 1, 4, 1, 1>, -+ cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 8>, -+ kThreads, -+ kElementsPerAccess, -+ cutlass::sizeof_bits::value -+ >; -+ -+#else -+ using InternalThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< -+ cutlass::layout::PitchLinearShape<64, 4>, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape< -+ 64, // column -+ 1, // row -+ 4, // group -+ 1, // cluster -+ 1 // iterations -+ >; -+ -+ using Iterations = cutlass::epilogue::threadblock::OutputTileShape< -+ 2, // column -+ 1, // row -+ 4, // group -+ 1, // cluster -+ 1 // iterations -+ >; -+ -+ using Delta = cutlass::epilogue::threadblock::OutputTileShape< -+ 32, // column -+ 1, // row -+ 4, // group -+ 16, // cluster -+ 1 // iterations -+ >; -+ -+ using Count = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 4, // row -+ 2, // group -+ 1, // cluster -+ 8 // iterations -+ >; -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap< -+ InternalThreadMap, -+ Shape, -+ Iterations, -+ Delta, -+ Count -+ >; -+#endif -+ -+ using PredicatedTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ ThreadMap, -+ Element -+ >; -+ -+ // -+ // Initialize workspace -+ // -+ cutlass::MatrixCoord tensor_extent{32, 64}; -+ cutlass::MatrixCoord output_extent{27, 63}; -+ -+ // -+ // Configure parameters -+ // -+ -+ cutlass::HostTensor host_tensor(tensor_extent); -+ -+ typename PredicatedTileIterator::Params iterator_params(host_tensor.layout()); -+ -+ host_tensor.sync_device(); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1,1); -+ dim3 block(kThreads, 1); -+ -+ test::epilogue::threadblock::kernel_store_iterator<<< grid, block >>>( -+ iterator_params, host_tensor.device_ref(), output_extent); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result); -+ -+ // -+ // Verify results -+ // -+ -+ host_tensor.sync_host(); -+ -+ bool passed = verify_footprint(host_tensor.host_view(), output_extent); -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::ofstream output("simt_32x64x8_32x64x1.csv"); -+ output << host_tensor.host_view(); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(PredicatedTileIterator, simt_128x128x8_32x64x1) { -+ -+ using Layout = cutlass::layout::RowMajor; -+ using Element = int; -+ -+ static int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; -+ static int const kThreads = 256; -+ -+#if 1 -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileOptimalThreadMap < -+ cutlass::epilogue::threadblock::OutputTileShape<128, 1, 4, 4, 1>, -+ cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 8>, -+ kThreads, -+ kElementsPerAccess, -+ cutlass::sizeof_bits::value -+ >; -+ -+#else -+ using 
InternalThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< -+ cutlass::layout::PitchLinearShape<128, 16>, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape< -+ 128, // column -+ 1, // row -+ 4, // group -+ 4, // cluster -+ 1 // iterations -+ >; -+ -+ using Iterations = cutlass::epilogue::threadblock::OutputTileShape< -+ 2, // column -+ 1, // row -+ 2, // group -+ 4, // cluster -+ 1 // iterations -+ >; -+ -+ using Delta = cutlass::epilogue::threadblock::OutputTileShape< -+ 32, // column -+ 1, // row -+ 8, // group -+ 32, // cluster -+ 1 // iterations -+ >; -+ -+ using Count = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 4, // row -+ 2, // group -+ 1, // cluster -+ 8 // iterations -+ >; -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap< -+ InternalThreadMap, -+ Shape, -+ Iterations, -+ Delta, -+ Count -+ >; -+#endif -+ -+ using PredicatedTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ ThreadMap, -+ Element -+ >; -+ -+ // -+ // Initialize workspace -+ // -+ cutlass::MatrixCoord tensor_extent{128, 128}; -+ cutlass::MatrixCoord output_extent{123, 121}; -+ -+ // -+ // Configure parameters -+ // -+ -+ cutlass::HostTensor host_tensor(tensor_extent); -+ -+ typename PredicatedTileIterator::Params iterator_params(host_tensor.layout()); -+ -+ host_tensor.sync_device(); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1,1); -+ dim3 block(kThreads, 1); -+ -+ test::epilogue::threadblock::kernel_store_iterator<<< grid, block >>>( -+ iterator_params, host_tensor.device_ref(), output_extent); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result); -+ -+ // -+ // Verify results -+ // -+ -+ host_tensor.sync_host(); -+ -+ bool passed = verify_footprint(host_tensor.host_view(), output_extent); -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::ofstream output("simt_128x128x8_32x64x1.csv"); -+ output << host_tensor.host_view(); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/testbed.h b/3rdparty/cutlass/test/unit/epilogue/threadblock/testbed.h -new file mode 100644 -index 0000000..c2982c3 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/testbed.h -@@ -0,0 +1,371 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for epilogues -+*/ -+#pragma once -+ -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+#include "cutlass/complex.h" -+#include "cutlass/quaternion.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace kernel { -+ -+template -+__global__ void epilogue_threadblock( -+ typename Epilogue::OutputTileIterator::Params params_D, -+ typename Epilogue::OutputTileIterator::Element *ptr_D, -+ typename Epilogue::OutputTileIterator::Params params_C, -+ typename Epilogue::OutputTileIterator::Element *ptr_C, -+ typename Epilogue::OutputOp::Params params_output_op, -+ cutlass::MatrixCoord problem_size, -+ cutlass::TensorRef< -+ typename Epilogue::WarpMmaOperator::ElementC, -+ typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, -+ int epilogue_count = 1) { -+ -+ __shared__ typename Epilogue::SharedStorage shared_storage; -+ -+ int thread_idx = threadIdx.x; -+ int warp_idx = threadIdx.x / 32; -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Construct the epilogue -+ // -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::OutputTileIterator iterator_D( -+ params_D, -+ ptr_D, -+ problem_size, -+ thread_idx -+ ); -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::OutputTileIterator iterator_C( -+ params_C, -+ ptr_C, -+ problem_size, -+ thread_idx -+ ); -+ -+ // Epilogue operator -+ Epilogue epilogue( -+ shared_storage, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // -+ // Initialize the accumulators -+ // -+ -+ int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); -+ int warp_m = warp_mn % Epilogue::WarpCount::kM; -+ int warp_n = warp_mn / Epilogue::WarpCount::kM; -+ -+ accumulator_ref.add_coord_offset({ -+ warp_m * Epilogue::WarpMmaOperator::Shape::kM, -+ warp_n * Epilogue::WarpMmaOperator::Shape::kN}); -+ -+ typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); -+ -+ typename Epilogue::AccumulatorTile accumulators; -+ -+ accumulators.clear(); -+ accumulator_iterator.load(accumulators); -+ -+#if 0 -+ // For debugging, enable this block of code to fill each accumulator element with its -+ // source thread ID. 
-+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < accumulators.size(); ++i) { -+ typename Epilogue::WarpMmaOperator::ElementC x(threadIdx.x); -+ //typename Epilogue::WarpMmaOperator::ElementC x(i); -+ accumulators[i] = x; -+ } -+ -+ /* -+ #pragma unroll 1 -+ for (int tid = 0; tid < 32; ++tid) { -+ if (tid == thread_idx) { -+ printf("\nT%d: ", thread_idx); -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < accumulators.size(); ++i) { -+ printf("%d ", int(accumulators[i])); -+ } -+ } -+ } -+ -+ if (thread_idx == 0) { -+ printf("\n\n"); -+ } -+ */ -+ -+ __syncthreads(); -+ -+#endif -+ -+ // -+ // Perform the epilogue operation -+ // -+ -+ typename Epilogue::OutputOp output_op(params_output_op); -+ -+ // Place the epilogue in a loop -+ for (int iter = 0; iter < epilogue_count; ++iter) { -+ epilogue(output_op, iterator_D, accumulators, iterator_C); -+ } -+} -+ -+} // namespace kernel -+} // namespace test -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Epilogue_ -+> -+class EpilogueTestbed { -+public: -+ -+ using Epilogue = Epilogue_; -+ using ElementAccumulator = typename Epilogue::ElementAccumulator; -+ using ElementCompute = typename Epilogue::OutputOp::ElementCompute; -+ using ElementOutput = typename Epilogue::ElementOutput; -+ using OutputOpParams = typename Epilogue::OutputOp::Params; -+ -+public: -+ -+ // -+ // Data members -+ // -+ -+ cutlass::MatrixCoord quantized_size; -+ cutlass::HostTensor accumulator_tensor; -+ cutlass::HostTensor source_tensor; -+ cutlass::HostTensor output_tensor; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ EpilogueTestbed(): -+ quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), -+ accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), -+ source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), -+ output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}) { -+ -+ // -+ // Initialize problem space -+ // -+ -+ uint64_t seed = 2019; -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ accumulator_tensor.host_view(), -+ seed, -+ 20, -+ -20, -+ 0); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ source_tensor.host_view(), -+ seed + 2018, -+ 20, -+ -20, -+ 0); -+ } -+ -+ bool run_all() { -+ -+ double alpha_values[] = {1, 0, 2.25}; -+ double beta_values[] = {0, 1, -1.25}; -+ -+ // Test runtime explodes if we tried to test every case exhaustively. This tests the full -+ // output tile and several smaller sizes to stress predication. 
-+ for (int m_idx = 0; m_idx < 3; ++m_idx) { -+ for (int n_idx = 0; n_idx < 3; ++n_idx) { -+ -+ int m = quantized_size.row() - m_idx * 3; -+ int n = quantized_size.column() - n_idx * Epilogue::kElementsPerAccess; -+ -+ for (double const &alpha : alpha_values) { -+ for (double const &beta : beta_values) { -+ -+ bool passed = run({m, n}, {cutlass::from_real(alpha), cutlass::from_real(beta)}); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ -+ return true; -+ } -+ -+ /// Runs the test -+ bool run( -+ cutlass::MatrixCoord problem_size, -+ OutputOpParams output_params) { -+ -+ // -+ // Initialize problem space -+ // -+ -+ ElementOutput default_output = ElementOutput(-127); -+ cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); -+ -+ accumulator_tensor.sync_device(); -+ output_tensor.sync_device(); -+ source_tensor.sync_device(); -+ -+ // -+ // Initialize epilogue parameters -+ // -+ -+ typename Epilogue::OutputTileIterator::Params params_D(output_tensor.device_ref().layout()); -+ typename Epilogue::OutputTileIterator::Params params_C(source_tensor.device_ref().layout()); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1, 1); -+ dim3 block(Epilogue::WarpCount::kCount * 32, 1); -+ -+ test::kernel::epilogue_threadblock<<< grid, block >>>( -+ params_D, -+ output_tensor.device_data(), -+ params_C, -+ source_tensor.device_data(), -+ output_params, -+ problem_size, -+ accumulator_tensor.device_view()); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ // -+ // Verify results -+ // -+ output_tensor.sync_host(); -+ -+ int errors = 0; -+ int const kMaxErrors = 5; -+ -+ for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { -+ for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { -+ -+ cutlass::MatrixCoord coord{r, c}; -+ ElementOutput got = output_tensor.at(coord); -+ -+ ElementOutput expected; -+ if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { -+ ElementCompute intermediate = -+ output_params.alpha * ElementCompute(accumulator_tensor.at(coord)) + -+ output_params.beta * ElementCompute(source_tensor.at(coord)); -+ -+ if (std::numeric_limits::is_integer -+ && !std::numeric_limits::is_integer) { -+ std::fesetround(FE_TONEAREST); -+ expected = ElementOutput(std::nearbyint(float(cutlass::real(intermediate)))); -+ } else { -+ expected = ElementOutput(intermediate); -+ } -+ } else { -+ expected = default_output; -+ } -+ -+ if (expected != got) { -+ -+ using OutputIO = cutlass::ScalarIO; -+ -+ EXPECT_TRUE(false) -+ << "-------\n" -+ << "Error - output element (" << coord << ") - expected: " -+ << OutputIO(expected) -+ << ", got: " << OutputIO(got) -+ << ", accum: " << (accumulator_tensor.at(coord)) -+ << ", source: " << OutputIO(source_tensor.at(coord)) -+ << ", alpha: " << (output_params.alpha) -+ << ", beta: " << (output_params.beta) << "\n"; -+ -+ ++errors; -+ } -+ } -+ } -+ -+ // -+ // Report results on error -+ // -+ -+ if (errors) { -+ std::stringstream ss; -+ ss -+ << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" -+ << Epilogue::WarpTileIterator::WarpShape::kM << "x" -+ << Epilogue::WarpTileIterator::WarpShape::kN -+ << "_slice_" << Epilogue::WarpCount::kK << ".csv"; -+ -+ std::ofstream output_file(ss.str()); -+ output_file << output_tensor.host_view(); -+ } -+ -+ return !errors; -+ } -+}; -+ 
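The per-element host check condensed in the testbed above is: expected = alpha * accumulator + beta * source, converted with round-to-nearest when the output element type is integral but the compute type is not. The sketch below is not part of the patch; it assumes a float compute type and hypothetical names, and only mirrors that reference computation.

    // Sketch only: mirrors the reference check in EpilogueTestbed::run(),
    // assuming ElementCompute = float. Names are illustrative, not from the patch.
    #include <cfenv>
    #include <cmath>
    #include <limits>

    template <typename ElementOutput>
    ElementOutput reference_epilogue(float accum, float source, float alpha, float beta) {
      float intermediate = alpha * accum + beta * source;
      if (std::numeric_limits<ElementOutput>::is_integer) {
        // Integer outputs round to nearest, matching the
        // fesetround(FE_TONEAREST) / nearbyint path in the testbed.
        std::fesetround(FE_TONEAREST);
        return static_cast<ElementOutput>(std::nearbyint(intermediate));
      }
      return static_cast<ElementOutput>(intermediate);
    }

For example, reference_epilogue<int>(3.6f, 1.0f, 2.25f, -1.25f) evaluates 3.6 * 2.25 - 1.25 = 6.85 and rounds to 7; elements outside the problem extent are expected to keep the fill value (-127 here), which is how the testbed also exercises the output iterator's predication.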
-+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/testbed_planar_complex.h b/3rdparty/cutlass/test/unit/epilogue/threadblock/testbed_planar_complex.h -new file mode 100644 -index 0000000..68da6f4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/testbed_planar_complex.h -@@ -0,0 +1,394 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Unit tests for epilogues -+*/ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+#include "cutlass/complex.h" -+ -+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -+ -+#include "cutlass/util/host_tensor_planar_complex.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace kernel { -+ -+template -+__global__ void epilogue_planar_complex_threadblock( -+ typename Epilogue::OutputTileIterator::Params params_D, -+ typename Epilogue::OutputTileIterator::Element *ptr_D, -+ int64_t imaginary_stride_D, -+ typename Epilogue::OutputTileIterator::Params params_C, -+ typename Epilogue::OutputTileIterator::Element *ptr_C, -+ int64_t imaginary_stride_C, -+ typename Epilogue::OutputOp::Params params_output_op, -+ cutlass::MatrixCoord problem_size, -+ cutlass::TensorRef< -+ typename Epilogue::WarpMmaOperator::ElementC, -+ typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, -+ int64_t imaginary_stride_accum, -+ int epilogue_count = 1) { -+ -+ __shared__ typename Epilogue::SharedStorage shared_storage; -+ -+ int thread_idx = threadIdx.x; -+ int warp_idx = threadIdx.x / 32; -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Construct the epilogue -+ // -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::OutputTileIterator iterator_D_real( -+ params_D, -+ ptr_D, -+ problem_size, -+ thread_idx -+ ); -+ -+ typename Epilogue::OutputTileIterator iterator_D_imag( -+ params_D, -+ ptr_D + imaginary_stride_D, -+ problem_size, -+ thread_idx -+ ); -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::OutputTileIterator iterator_C_real( -+ params_C, -+ ptr_C, -+ problem_size, -+ thread_idx -+ ); -+ -+ typename Epilogue::OutputTileIterator iterator_C_imag( -+ params_C, -+ ptr_C + imaginary_stride_C, -+ problem_size, -+ thread_idx -+ ); -+ -+ // Epilogue operator -+ Epilogue epilogue( -+ shared_storage, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // -+ // Initialize the accumulators -+ // -+ -+ int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); -+ int warp_m = warp_mn % Epilogue::WarpCount::kM; -+ int warp_n = warp_mn / Epilogue::WarpCount::kM; -+ -+ accumulator_ref.add_coord_offset({ -+ warp_m * Epilogue::WarpMmaOperator::Shape::kM, -+ warp_n * Epilogue::WarpMmaOperator::Shape::kN}); -+ -+ // -+ // Load accumulators -+ // -+ -+ typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); -+ -+ typename Epilogue::AccumulatorTile accumulators; -+ -+ accumulators.clear(); -+ -+ accumulator_iterator.load(accumulators.real); -+ accumulator_iterator.load_with_pointer_offset(accumulators.imag, imaginary_stride_accum); -+ -+ // -+ // Perform the epilogue operation -+ // -+ -+ typename Epilogue::OutputOp output_op(params_output_op); -+ -+ // Place the epilogue in a loop so assembly is clearly visible -+ for (int iter = 0; iter < epilogue_count; ++iter) { -+ epilogue( -+ output_op, -+ iterator_D_real, -+ iterator_D_imag, -+ accumulators, -+ iterator_C_real, -+ iterator_C_imag); -+ } -+} -+ -+} // namespace kernel -+} // namespace test -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Epilogue_ -+> -+class 
EpiloguePlanarComplexTestbed { -+public: -+ -+ using Epilogue = Epilogue_; -+ using ElementAccumulator = typename Epilogue::ElementAccumulator; -+ using ElementCompute = typename Epilogue::OutputOp::ElementCompute; -+ using ElementOutput = typename Epilogue::ElementOutput; -+ using OutputOpParams = typename Epilogue::OutputOp::Params; -+ -+ using ComplexElementOutput = cutlass::complex; -+ using ComplexElementAccumulator = cutlass::complex; -+ using ComplexElementCompute = cutlass::complex; -+ -+public: -+ -+ // -+ // Data members -+ // -+ -+ cutlass::MatrixCoord quantized_size; -+ cutlass::HostTensorPlanarComplex accumulator_tensor; -+ cutlass::HostTensorPlanarComplex source_tensor; -+ cutlass::HostTensorPlanarComplex output_tensor; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ EpiloguePlanarComplexTestbed(): -+ quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), -+ accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), -+ source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), -+ output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}) { -+ -+ // -+ // Initialize problem space -+ // -+ -+ #if 1 -+ uint64_t seed = 2019; -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ accumulator_tensor.host_view(), -+ seed, -+ 20, -+ -20, -+ 0); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ source_tensor.host_view(), -+ seed + 2018, -+ 20, -+ -20, -+ 0); -+ #else -+ -+ cutlass::reference::host::BlockFillSequential(accumulator_tensor.host_data(), accumulator_tensor.capacity()); -+ -+ #endif -+ } -+ -+ bool run_all() { -+ -+ cutlass::complex alpha_values[3]; -+ -+ alpha_values[0] = cutlass::complex(1, 0); -+ alpha_values[1] = cutlass::complex(0, 0); -+ alpha_values[2] = cutlass::complex(2.25f, -0.5f); -+ -+ cutlass::complex beta_values[3]; -+ -+ beta_values[0] = cutlass::complex(0, 0); -+ beta_values[1] = cutlass::complex(1, 0); -+ beta_values[2] = cutlass::complex(0.5f, -2.25f); -+ -+ // Test runtime explodes if we tried to test every case exhaustively. This tests the full -+ // output tile and several smaller sizes to stress predication. 
-+ for (int m_idx = 0; m_idx < 3; ++m_idx) { -+ for (int n_idx = 0; n_idx < 3; ++n_idx) { -+ -+ cutlass::MatrixCoord problem_size( -+ quantized_size.row() - m_idx * 3, -+ quantized_size.column() - n_idx * Epilogue::kElementsPerAccess -+ ); -+ -+ for (auto const &alpha : alpha_values) { -+ for (auto const &beta : beta_values) { -+ -+ bool passed = run(problem_size, {alpha, beta}); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ -+ return true; -+ } -+ -+ /// Runs the test -+ bool run( -+ cutlass::MatrixCoord problem_size, -+ OutputOpParams output_params) { -+ -+ // -+ // Initialize problem space -+ // -+ -+ ComplexElementOutput default_output = ComplexElementOutput(ElementOutput(-127), ElementOutput(-101)); -+ -+ cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); -+ -+ accumulator_tensor.sync_device(); -+ output_tensor.sync_device(); -+ source_tensor.sync_device(); -+ -+ // -+ // Initialize epilogue parameters -+ // -+ -+ typename Epilogue::OutputTileIterator::Params params_D(output_tensor.layout()); -+ typename Epilogue::OutputTileIterator::Params params_C(source_tensor.layout()); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1, 1); -+ dim3 block(Epilogue::WarpCount::kCount * 32, 1); -+ -+ test::kernel::epilogue_planar_complex_threadblock<<< grid, block >>>( -+ params_D, -+ output_tensor.device_data(), -+ output_tensor.imaginary_stride(), -+ params_C, -+ source_tensor.device_data(), -+ source_tensor.imaginary_stride(), -+ output_params, -+ problem_size, -+ accumulator_tensor.device_view_real(), -+ accumulator_tensor.imaginary_stride() -+ ); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ // -+ // Verify results -+ // -+ output_tensor.sync_host(); -+ -+ int errors = 0; -+ int const kMaxErrors = 5; -+ -+ for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { -+ for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { -+ -+ cutlass::MatrixCoord coord{r, c}; -+ ComplexElementOutput got = output_tensor.at(coord); -+ -+ ComplexElementOutput expected = default_output; -+ -+ if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { -+ -+ ComplexElementOutput src = source_tensor.at(coord); -+ -+ ComplexElementCompute tmp = -+ output_params.alpha * ComplexElementCompute(accumulator_tensor.at(coord)) + -+ output_params.beta * ComplexElementCompute(src.real(), src.imag()); -+ -+ expected = ComplexElementOutput(ElementOutput(tmp.real()), ElementOutput(tmp.imag())); -+ } -+ -+ if (expected != got) { -+ -+ using OutputIO = cutlass::ScalarIO; -+ -+ EXPECT_TRUE(false) -+ << "-------\n" -+ << "Error - output element (" << coord << ") - expected: " -+ << OutputIO(expected) -+ << ", got: " << OutputIO(got) << std::endl; -+ -+ ++errors; -+ } -+ } -+ } -+ -+ // -+ // Report results on error -+ // -+ -+ if (errors) { -+ -+ -+ std::cout << "Incorrect result for problem(" -+ << problem_size.row() << ", " -+ << problem_size.column() << ") for alpha: " << output_params.alpha << ", beta: " << output_params.beta << std::endl; -+ -+ std::stringstream ss; -+ ss -+ << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" -+ << Epilogue::WarpTileIterator::WarpShape::kM << "x" -+ << Epilogue::WarpTileIterator::WarpShape::kN -+ << "_slice_" << Epilogue::WarpCount::kK << ".csv"; -+ -+ std::ofstream output_file(ss.str()); -+ output_file << 
output_tensor.host_view(); -+ -+ std::cout << "Wrote workspace to '" << ss.str() << "'" << std::endl; -+ } -+ -+ return !errors; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/warp/fragment_iterator_tensor_op.cu b/3rdparty/cutlass/test/unit/epilogue/warp/fragment_iterator_tensor_op.cu -new file mode 100644 -index 0000000..e7e15ce ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/warp/fragment_iterator_tensor_op.cu -@@ -0,0 +1,194 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+ -+#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_warp_FragmentIterator, mma_f32_64x64x8) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ Shape, -+ typename MmaTensorOp::Policy::Operator::Shape, -+ typename MmaTensorOp::Policy::Operator::ElementC, -+ typename MmaTensorOp::Policy::Operator::FragmentC, -+ cutlass::layout::RowMajor -+ >; -+ -+ // This test just prints things. -+ #if 0 -+ typename MmaTensorOp::FragmentC accum; -+ -+ std::cout << "Native accumulators:\n"; -+ -+ for (int i = 0; i < MmaTensorOp::FragmentC::kElements; ++i) { -+ accum[i] = ElementC(i); -+ -+ std::cout << accum[i] << " "; -+ if (i && !((i + 1) % 4)) { -+ std::cout << "\n"; -+ } -+ } -+ -+ std::cout << std::endl; -+ -+ std::cout << "FragmentIterator::Policy = { \n" -+ << " kAccessesPerInstruction: " << FragmentIterator::Policy::kIterationsPerInstruction << "\n" -+ << " kAccumulatorRowStride: " << FragmentIterator::Policy::kAccumulatorRowStride << "\n" -+ << " kAccumulatorColumnStride: " << FragmentIterator::Policy::kAccumulatorColumnStride << "\n" -+ << " kIterations: " << FragmentIterator::Policy::kIterations << "\n" -+ << " }" << std::endl; -+ -+ FragmentIterator fragment_iterator(accum); -+ -+ for (int iter = 0; iter < FragmentIterator::kIterations; ++iter) { -+ -+ typename FragmentIterator::Fragment frag; -+ -+ fragment_iterator.load(frag); -+ -+ std::cout << "Iteration " << iter << ":\n"; -+ -+ for (int i = 0; i < FragmentIterator::Fragment::kElements; ++i) { -+ std::cout << frag[i] << " "; -+ } -+ -+ std::cout << std::endl; -+ -+ ++fragment_iterator; -+ } -+ #endif -+} -+ -+TEST(SM75_Epilogue_warp_FragmentIterator, mma_f16_64x64x8) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ Shape, -+ typename MmaTensorOp::Policy::Operator::Shape, -+ typename MmaTensorOp::Policy::Operator::ElementC, -+ typename MmaTensorOp::Policy::Operator::FragmentC, -+ 
cutlass::layout::RowMajor -+ >; -+ -+ // This test just prints things. -+ #if 0 -+ typename MmaTensorOp::FragmentC accum; -+ -+ std::cout << "Native accumulators:\n"; -+ -+ for (int i = 0; i < MmaTensorOp::FragmentC::kElements; ++i) { -+ accum[i] = ElementC(i); -+ -+ std::cout << (float)accum[i] << " "; -+ if (i && !((i + 1) % 4)) { -+ std::cout << "\n"; -+ } -+ } -+ -+ std::cout << std::endl; -+ -+ std::cout << "FragmentIterator::Policy = { \n" -+ << " kAccessesPerInstruction: " << FragmentIterator::Policy::kIterationsPerInstruction << "\n" -+ << " kAccumulatorRowStride: " << FragmentIterator::Policy::kAccumulatorRowStride << "\n" -+ << " kAccumulatorColumnStride: " << FragmentIterator::Policy::kAccumulatorColumnStride << "\n" -+ << " kIterations: " << FragmentIterator::Policy::kIterations << "\n" -+ << " }" << std::endl; -+ -+ FragmentIterator fragment_iterator(accum); -+ -+ for (int iter = 0; iter < FragmentIterator::kIterations; ++iter) { -+ -+ typename FragmentIterator::Fragment frag; -+ -+ fragment_iterator.load(frag); -+ -+ std::cout << "Iteration " << iter << ":\n"; -+ -+ for (int i = 0; i < FragmentIterator::Fragment::kElements; ++i) { -+ std::cout << (float)frag[i] << " "; -+ } -+ -+ std::cout << std::endl; -+ -+ ++fragment_iterator; -+ } -+ #endif -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/warp/fragment_iterator_volta_tensor_op.cu b/3rdparty/cutlass/test/unit/epilogue/warp/fragment_iterator_volta_tensor_op.cu -new file mode 100644 -index 0000000..ffb670f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/warp/fragment_iterator_volta_tensor_op.cu -@@ -0,0 +1,216 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+#include "cutlass/gemm/warp/mma_tensor_op_sm70.h" -+#include "cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_warp_FragmentIterator, mma_f16_64x64x4) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ cutlass::HostTensor accumulator_tensor({Shape::kM, Shape::kN}); -+ -+ cutlass::reference::host::TensorFill(accumulator_tensor.host_view(), ElementC(-1)); -+ -+ for (int tid = 0; tid < 1; ++tid) { -+ typename MmaTensorOp::IteratorC::Fragment accumulator_tile; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < accumulator_tile.size(); ++i) { -+ accumulator_tile[i] = ElementC(i); -+ } -+ -+ using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp< -+ cutlass::gemm::GemmShape<64, 64, 4>, -+ cutlass::gemm::GemmShape<32, 32, 4>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >; -+ -+ FragmentIterator frag_iterator(accumulator_tile); -+ -+ typename FragmentIterator::Fragment frag; -+ -+ for (int iter = 0; iter < FragmentIterator::kIterations; ++iter) { -+ frag_iterator.load(frag); -+ ++frag_iterator; -+ -+ #if 0 -+ std::cout << "T" << tid << ": "; -+ for (int i = 0; i < frag.size(); ++i) { -+ std::cout << " " << frag[i]; -+ } -+ std::cout << std::endl; -+ #endif -+ } -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_warp_FragmentIterator, mma_f32_64x64x4) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ 
cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ cutlass::HostTensor accumulator_tensor({Shape::kM, Shape::kN}); -+ -+ cutlass::reference::host::TensorFill(accumulator_tensor.host_view(), ElementC(-1)); -+ -+ for (int tid = 0; tid < 1; ++tid) { -+ typename MmaTensorOp::IteratorC::Fragment accumulator_tile; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < accumulator_tile.size(); ++i) { -+ accumulator_tile[i] = ElementC(i); -+ } -+ -+ typename MmaTensorOp::IteratorC iterator_C(accumulator_tensor.host_ref(), tid); -+ iterator_C.store(accumulator_tile); -+ } -+ -+ /* -+ std::ofstream output("volta_mma_f32_64x64x4.csv"); -+ output << accumulator_tensor.host_view() << std::endl; -+ */ -+ -+ for (int tid = 0; tid < 1; ++tid) { -+ typename MmaTensorOp::IteratorC::Fragment accumulator_tile; -+ -+ using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp< -+ cutlass::gemm::GemmShape<64, 64, 4>, -+ cutlass::gemm::GemmShape<32, 32, 4>, -+ ElementC, -+ LayoutC -+ >; -+ -+ FragmentIterator frag_iterator(accumulator_tile); -+ -+ for (int iter = 0; iter < FragmentIterator::kIterations; ++iter) { -+ -+ typename FragmentIterator::Fragment frag; -+ frag_iterator.load(frag); -+ ++frag_iterator; -+ -+ #if 0 -+ std::cout << "Iteration: " << iter << " - T" << tid << ": "; -+ -+ for (int i = 0; i < frag.size(); ++i) { -+ std::cout << " " << frag[i]; -+ } -+ -+ std::cout << std::endl; -+ #endif -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/warp/fragment_iterator_wmma_tensor_op.cu b/3rdparty/cutlass/test/unit/epilogue/warp/fragment_iterator_wmma_tensor_op.cu -new file mode 100644 -index 0000000..fe3f47a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/warp/fragment_iterator_wmma_tensor_op.cu -@@ -0,0 +1,186 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_wmma.h" -+ -+#include "cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_warp_FragmentIterator, wmma_f16_64x64x16) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Wmma< -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorWmmaTensorOp< -+ Shape, -+ typename MmaTensorOp::Policy::Operator::Shape, -+ typename MmaTensorOp::Policy::Operator::ElementC, -+ typename MmaTensorOp::Policy::Operator::FragmentC, -+ cutlass::layout::RowMajor -+ >; -+ -+ #if 0 -+ // -+ // Enable this code block to print comments for debugging. 
-+ // -+ -+ std::cout << "FragmentIterator::Policy = { \n" -+ << " OperatorCount: (" << FragmentIterator::Policy::OperatorCount::kRow <<", "<; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Wmma< -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorWmmaTensorOp< -+ Shape, -+ typename MmaTensorOp::Policy::Operator::Shape, -+ typename MmaTensorOp::Policy::Operator::ElementC, -+ typename MmaTensorOp::Policy::Operator::FragmentC, -+ cutlass::layout::RowMajor -+ >; -+ -+ #if 0 -+ // -+ // Enable this code block to print comments for debugging. -+ // -+ std::cout << "FragmentIterator::Policy = { \n" -+ << " OperatorCount: (" << FragmentIterator::Policy::OperatorCount::kRow <<", "< -+struct DefaultGemmConfigurationToCutlass3Types { -+ static_assert(sizeof(ElementA) == 0, "No valid DefaultGemmConfigurationToCutlass3Types configuration exists."); -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template -+struct DefaultGemm_TensorOpSm80_OperandA; -+ -+template -+struct DefaultGemm_TensorOpSm80_OperandB; -+ -+// -+// F16: 128-by-128-by-64 -+// -+ -+/// Operand A - Row-major (K-Major) -+template <> -+struct DefaultGemm_TensorOpSm80_OperandA -+{ -+ // Smem -+ using SmemLayoutAtom = decltype( -+ composition(Swizzle<3,3,3>{}, -+ Layout, -+ Stride<_64, _1>>{})); -+ using SmemCopyAtom = Copy_Atom; -+ -+ // Gmem -+ using GmemTiledCopy = decltype( -+ make_tiled_copy(Copy_Atom, half_t>{}, -+ Layout, -+ Stride< _8,_1>>{}, -+ Layout>{})); -+}; -+ -+/// Operand A - Column-major (M-major) -+template -+struct DefaultGemm_TensorOpSm80_OperandA -+{ -+ // Smem -+ using SmemLayoutAtom = decltype( -+ composition(Swizzle<3,3,3>{}, -+ Layout, -+ Stride< _1,_64>>{})); -+ using SmemCopyAtom = Copy_Atom; -+ -+ // Gmem -+ using GmemTiledCopy = decltype( -+ make_tiled_copy(Copy_Atom, half_t>{}, -+ Layout, -+ Stride< _1,_16>>{}, -+ Layout>{})); -+}; -+ -+// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands -+ -+// Operand B - Column-Major (K-major) -+template -+struct DefaultGemm_TensorOpSm80_OperandB -+ : DefaultGemm_TensorOpSm80_OperandA -+{}; -+ -+// Operand B - Row-Major (N-major) -+template -+struct DefaultGemm_TensorOpSm80_OperandB -+ : DefaultGemm_TensorOpSm80_OperandA -+{}; -+ -+// -+// F16: 128-by-128-by-32 (small k-block) -+// -+ -+/// Operand A - Row-major (K-Major) -+template <> -+struct DefaultGemm_TensorOpSm80_OperandA -+{ -+ // Smem -+ using SmemLayoutAtom = decltype( -+ composition(Swizzle<2,3,3>{}, -+ Layout, -+ Stride<_32, _1>>{})); -+ using SmemCopyAtom = Copy_Atom; -+ -+ // Gmem -+ using GmemTiledCopy = decltype( -+ make_tiled_copy(Copy_Atom, half_t>{}, -+ Layout, -+ Stride< _4,_1>>{}, -+ Layout>{})); -+}; -+ -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// Ampere MMA F32F16 
-+template -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassTensorOp, arch::Sm80, -+ half_t, LayoutA, -+ half_t, LayoutB, -+ float, LayoutC, -+ float> -+{ -+ using TileShape = Shape<_128, _128, _32>; -+ static constexpr int ThreadCount = 128; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom, -+ Layout>, // 2x2x1 thread group -+ Layout>>; // 1x2x1 value group for 16x16x16 MMA and LDSM -+ -+ // A -+ static constexpr int kAlignmentA = 8; -+ using DefaultOperandA = detail::DefaultGemm_TensorOpSm80_OperandA< -+ half_t, LayoutA, kAlignmentA, 32>; -+ using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K -+ using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; -+ using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; -+ -+ // B -+ static constexpr int kAlignmentB = 8; -+ using DefaultOperandB = detail::DefaultGemm_TensorOpSm80_OperandB< -+ half_t, LayoutB, kAlignmentB, 32>; -+ using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K -+ using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; -+ using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ half_t, TagToStrideA_t, -+ half_t, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+// -+// TF32: 128-by-128-by-kblock (kBlock = 16, 32) -+// -+ -+/// Operand A - Row-major (K-major) (kBlock = 32) -+template <> -+struct DefaultGemm_TensorOpSm80_OperandA -+{ -+ // Smem -+ using SmemLayoutAtom = decltype( -+ composition(Swizzle<3,2,3>{}, -+ Layout, -+ Stride<_32, _1>>{})); -+ using SmemCopyAtom = Copy_Atom; -+ -+ // Gmem -+ using GmemTiledCopy = decltype( -+ make_tiled_copy(Copy_Atom, tfloat32_t>{}, -+ Layout, -+ Stride< _8,_1>>{}, -+ Layout>{})); -+}; -+ -+/// Operand A - Row-major (K-major) (kBlock = 16) -+template <> -+struct DefaultGemm_TensorOpSm80_OperandA -+{ -+ // Smem -+ using SmemLayoutAtom = decltype( -+ composition(Swizzle<2,2,3>{}, -+ Layout, -+ Stride<_16, _1>>{})); -+ using SmemCopyAtom = Copy_Atom; -+ // Gmem -+ using GmemTiledCopy = decltype( -+ make_tiled_copy(Copy_Atom, tfloat32_t>{}, -+ Layout, -+ Stride< _4,_1>>{}, -+ Layout>{})); -+}; -+ -+/// Operand A - Column-major (M-major) -+template -+struct DefaultGemm_TensorOpSm80_OperandA -+{ -+ // Smem -+ using SmemLayoutAtom = decltype( -+ composition(Swizzle<3,2,3>{}, -+ Layout, -+ Stride< _1,_32>>{})); -+ using SmemCopyAtom = Copy_Atom, tfloat32_t>; -+ // Gmem -+ using GmemTiledCopy = decltype( -+ make_tiled_copy(Copy_Atom, tfloat32_t>{}, -+ Layout, -+ Stride< _1,_16>>{}, -+ Layout>{})); -+}; -+ -+// Because the TF32 TiledMMA is A-B symmetric, we can reuse the DefaultOperands -+ -+// Operand B - Column-Major (K-major) -+template -+struct DefaultGemm_TensorOpSm80_OperandB -+ : DefaultGemm_TensorOpSm80_OperandA -+{}; -+ -+// Operand B - Row-Major (N-major) -+template -+struct DefaultGemm_TensorOpSm80_OperandB -+ : DefaultGemm_TensorOpSm80_OperandA -+{}; -+ -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// 
Ampere MMA F32TF32 -+template -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassTensorOp, arch::Sm80, -+ tfloat32_t, LayoutA, -+ tfloat32_t, LayoutB, -+ float, LayoutC, -+ float> -+{ -+ using TileShape = Shape<_128, _128, _32>; -+ static constexpr int ThreadCount = 128; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom, -+ Layout, Stride<_2, _1, _1>>, // 2x2x1 thread group -+ Layout>>; // 1x2x1 value group for 16x16x8 and LDSM -+ -+ // A -+ static constexpr int kAlignmentA = 4; -+ using DefaultOperandA = detail::DefaultGemm_TensorOpSm80_OperandA< -+ tfloat32_t, LayoutA, kAlignmentA, 32>; -+ using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K -+ using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; -+ using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; -+ -+ // B -+ static constexpr int kAlignmentB = 4; -+ using DefaultOperandB = detail::DefaultGemm_TensorOpSm80_OperandB< -+ tfloat32_t, LayoutB, kAlignmentB, 32>; -+ using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K -+ using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; -+ using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ tfloat32_t, TagToStrideA_t, -+ tfloat32_t, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+template -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassTensorOp, arch::Sm80, -+ int8_t, cutlass::layout::RowMajor, -+ int8_t, cutlass::layout::ColumnMajor, -+ int32_t, LayoutC, -+ int32_t> -+{ -+ using TileShape = Shape<_128, _128, _64>; -+ static constexpr int ThreadCount = 128; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom, -+ Layout>, // 2x2x1 thread group -+ Layout>>; // 1x2x1 value group for 16x16x32 and LDSM -+ -+ // A (M,K) K-major -+ using SmemLayoutAtomA = decltype( -+ composition( -+ Swizzle<2,4,3>{}, -+ Layout, -+ Stride<_64, _1>>{})); -+ static constexpr int kAlignmentA = 16; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, int8_t>{}, -+ Layout, -+ Stride< _4,_1>>{}, -+ Layout>>{})); -+ // LDS.32- or LDSM-based copy atom -+ // using SmemCopyAtomA = Copy_Atom; -+ using SmemCopyAtomA = Copy_Atom; // LDSM works -+ -+ // B (N,K) K-major -+ using SmemLayoutAtomB = decltype( -+ composition( -+ Swizzle<2,4,3>{}, -+ Layout, -+ Stride<_64, _1>>{})); -+ static constexpr int kAlignmentB = 16; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, int8_t>{}, -+ Layout, -+ Stride< _4,_1>>{}, -+ Layout>>{})); -+ -+ // LDS.32- or LDSM-based copy atom -+ // using SmemCopyAtomB = Copy_Atom; -+ using SmemCopyAtomB = Copy_Atom; // LDSM works -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ int8_t, TagToStrideA_t, -+ int8_t, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ using CollectiveEpilogue = 
epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+//////////////////////////// SIMT TWO STAGE /////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template -+struct DefaultGemm_Simt_OperandA; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct DefaultGemm_Simt_OperandA -+{ -+ using SmemLayoutAtom = Layout, -+ Stride< _1,_128>>; -+ -+ using SmemCopyAtom = Copy_Atom; -+ -+ using GmemTiledCopy = decltype( -+ make_tiled_copy(Copy_Atom, Element>{}, -+ Layout, -+ Stride< _1,_32>>{}, -+ Layout>{})); -+}; -+ -+template -+struct DefaultGemm_Simt_OperandA -+{ -+ using SmemLayoutAtom = Layout, -+ Stride< _1,Int<128 + 4>>>; // Padded -+ -+ using SmemCopyAtom = Copy_Atom; -+ -+ using GmemTiledCopy = decltype( -+ make_tiled_copy(Copy_Atom, Element>{}, -+ Layout, -+ Stride< _8, _1>>{}, -+ Layout>{})); -+ -+}; -+ -+template -+struct DefaultGemm_Simt_OperandB; -+ -+template -+struct DefaultGemm_Simt_OperandB -+ : DefaultGemm_Simt_OperandA {}; -+ -+template -+struct DefaultGemm_Simt_OperandB -+ : DefaultGemm_Simt_OperandA {}; -+ -+} // end namespace detail -+ -+// SIMT Two Stage -+template < -+ class ArchTag, -+ class ElementA, class LayoutA, -+ class ElementB, class LayoutB, -+ class ElementC, class LayoutC, -+ class ElementAccumulator> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassSimt, ArchTag, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator> -+{ -+ using TileShape = Shape<_128, _128, _8>; -+ static constexpr int ThreadCount = 256; -+ using DispatchPolicy = MainloopSm70TwoStage; -+ using TiledMma = TiledMMA< -+ MMA_Atom>, -+ Layout>>; -+ -+ // A -+ static constexpr int kAlignmentA = 1; -+ using DefaultOperandA = detail::DefaultGemm_Simt_OperandA; -+ using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; -+ using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; -+ using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; -+ -+ // B -+ static constexpr int kAlignmentB = 1; -+ using DefaultOperandB = detail::DefaultGemm_Simt_OperandB; -+ using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; -+ using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; -+ using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ ElementA, TagToStrideA_t, -+ ElementB, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+ -+// -+// DP4A - int8 Proof-of-concept -+// -+ -+// SIMT Two Stage TN - idp4a -+template < -+ class ArchTag, -+ class ElementC, class LayoutC> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassSimt, ArchTag, -+ int8_t, cutlass::layout::RowMajor, -+ int8_t, cutlass::layout::ColumnMajor, -+ ElementC, LayoutC, -+ int32_t> -+{ -+ using TileShape = Shape<_128, _128, _32>; -+ static constexpr int ThreadCount = 256; -+ using DispatchPolicy = MainloopSm70TwoStage; -+ // NOTE: permuting MMA M 
mode lets us generate 128b smem loads (LDS.128) but has worst-case bank conflicts -+ using TiledMma = TiledMMA< -+ MMA_Atom, -+ Layout>>; // Tile of atoms (threads) -+ -+ // A (M,K) K-major -+ using ElementA = int8_t; -+ // 40% from regular M and N major layout -+ // using SmemLayoutAtomA = Layout, -+ // Stride< _1,_128>>; -+ // 80% from interleaved layouts -+ using SmemLayoutAtomA = Layout>, -+ Stride< _4, Stride<_1,_512>>>; -+ -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 4; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, ElementA>{}, -+ Layout, -+ Stride< _8,_1>>{}, -+ Layout>{})); -+ -+ // B (N,K) K-major -+ using ElementB = int8_t; -+ // 40% from regular M and N major layout -+ // using SmemLayoutAtomB = Layout, -+ // Stride< _1,_128>>; -+ // 80% from interleaved layouts -+ using SmemLayoutAtomB = Layout>, -+ Stride< _4, Stride<_1,_512>>>; -+ -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 4; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, ElementB>{}, -+ Layout, -+ Stride< _8,_1>>{}, -+ Layout>{})); -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ ElementA, TagToStrideA_t, -+ ElementB, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// SIMT Two Stage NN - idp4a -+template < -+ class ArchTag, -+ class ElementC, class LayoutC> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassSimt, ArchTag, -+ int8_t, cutlass::layout::ColumnMajor, -+ int8_t, cutlass::layout::ColumnMajor, -+ ElementC, LayoutC, -+ int32_t> -+{ -+ using TileShape = Shape<_128, _128, _32>; -+ static constexpr int ThreadCount = 256; -+ -+ using DispatchPolicy = MainloopSm70TwoStage; -+ -+ using TiledMma = TiledMMA< -+ MMA_Atom, -+ Layout>>; -+ -+ // A (M,K) M-major -+ using ElementA = int8_t; -+ using SmemLayoutAtomA = Layout>, -+ Stride< _4, Stride<_1,_512>>>; -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 1; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, ElementA>{}, -+ Layout, -+ Stride< _1,_32>>{}, -+ Layout>{})); -+ -+ // B (N,K) K-major -+ using ElementB = int8_t; -+ using SmemLayoutAtomB = Layout>, -+ Stride< _4, Stride<_1,_512>>>; -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 4; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, ElementB>{}, -+ Layout, -+ Stride< _8,_1>>{}, -+ Layout>{})); -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ ElementA, TagToStrideA_t, -+ ElementB, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// SIMT Two Stage NT - idp4a -+template < -+ class ArchTag, -+ class ElementC, class LayoutC> -+struct 
DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassSimt, ArchTag, -+ int8_t, cutlass::layout::ColumnMajor, -+ int8_t, cutlass::layout::RowMajor, -+ ElementC, LayoutC, -+ int32_t> -+{ -+ using TileShape = Shape<_128, _128, _32>; -+ static constexpr int ThreadCount = 256; -+ using DispatchPolicy = MainloopSm70TwoStage; -+ using TiledMma = TiledMMA< -+ MMA_Atom, -+ Layout>>; -+ -+ // A (M,K) M-major -+ using ElementA = int8_t; -+ using SmemLayoutAtomA = Layout>, -+ Stride< _4, Stride<_1,_512>>>; -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 1; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, ElementA>{}, -+ Layout, -+ Stride< _1,_32>>{}, -+ Layout>{})); -+ -+ // B (N,K) N-major -+ using ElementB = int8_t; -+ using SmemLayoutAtomB = Layout>, -+ Stride< _4, Stride<_1,_512>>>; -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 1; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, ElementB>{}, -+ Layout, -+ Stride< _1,_32>>{}, -+ Layout>{})); -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ ElementA, TagToStrideA_t, -+ ElementB, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// SIMT Two Stage TT - idp4a -+template < -+ class ArchTag, -+ class ElementC, class LayoutC> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassSimt, ArchTag, -+ int8_t, cutlass::layout::RowMajor, -+ int8_t, cutlass::layout::RowMajor, -+ ElementC, LayoutC, -+ int32_t> -+{ -+ using TileShape = Shape<_128, _128, _32>; -+ static constexpr int ThreadCount = 256; -+ using DispatchPolicy = MainloopSm70TwoStage; -+ using TiledMma = TiledMMA< -+ MMA_Atom, -+ Layout>>; -+ -+ // A (M,K) K-major -+ using ElementA = int8_t; -+ using SmemLayoutAtomA = Layout>, -+ Stride< _4, Stride<_1,_512>>>; -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 4; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, ElementA>{}, -+ Layout, -+ Stride< _8,_1>>{}, -+ Layout>{})); -+ -+ // B (N,K) N-major -+ using ElementB = int8_t; -+ using SmemLayoutAtomB = Layout>, -+ Stride< _4, Stride<_1,_512>>>; -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 1; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, ElementB>{}, -+ Layout, -+ Stride< _1,_32>>{}, -+ Layout>{})); -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ ElementA, TagToStrideA_t, -+ ElementB, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////// SIMT MULTI STAGE ////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+// SIMT Multi Stage NT -+template < -+ 
class ElementA, -+ class ElementB, -+ class ElementC, class LayoutC, -+ class ElementAccumulator> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassSimt, arch::Sm80, -+ ElementA, cutlass::layout::ColumnMajor, -+ ElementB, cutlass::layout::RowMajor, -+ ElementC, LayoutC, -+ ElementAccumulator> -+{ -+ using TileShape = Shape<_128, _128, _16>; -+ static constexpr int ThreadCount = 256; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom>, -+ Layout>, -+ Layout>, -+ Tile,Layout<_2,_16>,Underscore>>; -+ -+ // A (M,K) M-major -+ using SmemLayoutAtomA = Layout>; -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 2; -+ using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * kAlignmentA>; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, ElementA>{}, -+ Layout>{}, -+ Layout>{})); -+ -+ // B (N,K) N-major -+ using SmemLayoutAtomB = Layout>; -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 2; -+ using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * kAlignmentB>; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, ElementB>{}, -+ Layout>{}, -+ Layout>{})); -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ ElementA, TagToStrideA_t, -+ ElementB, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// SIMT Multi Stage TN -+template < -+ class ElementA, -+ class ElementB, -+ class ElementC, class LayoutC, -+ class ElementAccumulator> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassSimt, arch::Sm80, -+ ElementA, cutlass::layout::RowMajor, -+ ElementB, cutlass::layout::ColumnMajor, -+ ElementC, LayoutC, -+ ElementAccumulator> -+{ -+ using TileShape = Shape<_128, _128, _16>; -+ static constexpr int ThreadCount = 256; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom>, -+ Layout>>; -+ -+ // A (M,K) K-major -+ using SmemLayoutAtomA = Layout, -+ Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentA -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 1; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, ElementA>{}, -+ Layout, -+ Stride<_16, _1>>{})); -+ -+ // B (N,K) K-major -+ using SmemLayoutAtomB = Layout, -+ Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentB -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 1; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, ElementB>{}, -+ Layout, -+ Stride<_16, _1>>{})); -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ ElementA, TagToStrideA_t, -+ ElementB, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// SIMT 
Multi Stage NN -+template < -+ class ElementA, -+ class ElementB, -+ class ElementC, class LayoutC, -+ class ElementAccumulator> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassSimt, arch::Sm80, -+ ElementA, cutlass::layout::ColumnMajor, -+ ElementB, cutlass::layout::ColumnMajor, -+ ElementC, LayoutC, -+ ElementAccumulator> -+{ -+ using TileShape = Shape<_128, _128, _16>; -+ static constexpr int ThreadCount = 256; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom>, -+ Layout>, -+ Layout>, -+ Tile,Underscore,Underscore>>; -+ -+ // A (M,K) M-major -+ using SmemLayoutAtomA = Layout>; -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 2; -+ using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * kAlignmentA>; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, ElementA>{}, -+ Layout>{}, -+ Layout>{})); -+ -+ // B (N,K) K-major -+ using SmemLayoutAtomB = Layout, -+ Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentB -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 1; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, ElementB>{}, -+ Layout, -+ Stride<_16, _1>>{})); -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ ElementA, TagToStrideA_t, -+ ElementB, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// SIMT Multi Stage TT -+template < -+ class ElementA, -+ class ElementB, -+ class ElementC, class LayoutC, -+ class ElementAccumulator> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassSimt, arch::Sm80, -+ ElementA, cutlass::layout::RowMajor, -+ ElementB, cutlass::layout::RowMajor, -+ ElementC, LayoutC, -+ ElementAccumulator> -+{ -+ using TileShape = Shape<_128, _128, _16>; -+ static constexpr int ThreadCount = 256; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom>, -+ Layout>, -+ Layout>, -+ Tile,Underscore>>; -+ -+ // A (M,K) K-major -+ using SmemLayoutAtomA = Layout, -+ Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentA -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 1; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, ElementA>{}, -+ Layout, -+ Stride<_16, _1>>{})); -+ -+ // B (N,K) N-major -+ using SmemLayoutAtomB = Layout>; -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 2; -+ using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * kAlignmentB>; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, ElementB>{}, -+ Layout>{}, -+ Layout>{})); -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ ElementA, TagToStrideA_t, -+ ElementB, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ 
-+/////////////////////////////////////////////////////////////////////////////// -+ -+// Ampere fp64 MMA TN (K-Major A and K-Major B) -+template <> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassTensorOp, arch::Sm80, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double> -+{ -+ using TileShape = Shape<_128, _64, _16>; -+ static constexpr int ThreadCount = 128; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom, // Atom -+ Layout>, // Atom layout -+ Layout>, // Val layout -+ Tile,Layout<_2,_16>,Underscore>>; // Mode permutations -+ -+ // A (M,K) K-Major -+ using SmemLayoutAtomA = decltype( -+ composition(SwizzleXor<2,0,2>{}, -+ Layout, -+ Stride<_1, _4>>{})); // M, K -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 1; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, double>{}, // CopyAtom -+ Layout, -+ Stride<_16, _1>>{}, // ThrLayout for CopyAtom -+ Layout>{})); // Value layout: 1x1 doubles -+ -+ // B (N,K) K-Major -+ using SmemLayoutAtomB = decltype( -+ composition(SwizzleXor<2,0,2>{}, -+ Layout, -+ Stride<_1, _4>>{})); // N, K -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 1; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, double>{}, // CopyAtom -+ Layout, -+ Stride<_16, _1>>{}, // ThrLayout for CopyAtom -+ Layout>{})); // Value layout: 1x1 doubles -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ double, TagToStrideA_t, -+ double, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+ -+/* -+ using EpilogueOutputOp = epilogue::collective::Epilogue< -+ epilogue::thread::LinearCombination, -+ Layout, -+ Stride< _1,_64>>, // SMEM layout -+ Copy_Atom,double>, // R2S with tiled_mma layout -+ decltype(make_tiled_copy(Copy_Atom,double>{},// S2R -+ Layout, -+ Stride< _1,_16>>{}, // Thread layout -+ Layout>{})), // Value layout -+ Copy_Atom,double> // R2G with S2R_dst layout -+ >; -+*/ -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// Ampere fp64 MMA NN (M-Major A and K-Major B) -+template <> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassTensorOp, arch::Sm80, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double> -+{ -+ using TileShape = Shape<_128, _64, _16>; -+ static constexpr int ThreadCount = 128; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom, // Atom -+ Layout>, // Atom layout -+ Layout>, // Val layout -+ Tile,Layout<_2,_16>,Underscore>>; // Mode permutations -+ -+ // A (M,K) M-Major -+ using SmemLayoutAtomA = decltype( -+ composition(SwizzleXor<2,2,0>{}, -+ Layout, -+ Stride< _1,_16>>{})); // M, K -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 2; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, double>{}, // CopyAtom -+ Layout, -+ Stride< _1,_16>>{}, // ThrLayout for CopyAtom -+ Layout>{})); // Value layout: 2x1 doubles -+ -+ // B (N,K) K-Major -+ using SmemLayoutAtomB = decltype( -+ 
composition(SwizzleXor<2,0,2>{}, -+ Layout, -+ Stride<_1, _4>>{}));// N, K -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 1; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, double>{}, // CopyAtom -+ Layout, -+ Stride<_16, _1>>{}, // ThrLayout for CopyAtom -+ Layout>{})); // Value layout: 1x1 doubles -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ double, TagToStrideA_t, -+ double, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// Ampere fp64 MMA NT (M-Major A and N-Major B) -+template <> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassTensorOp, arch::Sm80, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::ColumnMajor, -+ double> -+{ -+ using TileShape = Shape<_128, _64, _16>; -+ static constexpr int ThreadCount = 128; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom, // Atom -+ Layout>, // Atom layout -+ Layout>, // Val layout -+ Tile,Layout<_2,_16>,Underscore>>; // Mode permutations -+ -+ // A (M,K) M-Major -+ using SmemLayoutAtomA = decltype( -+ composition(SwizzleXor<2,2,0>{}, -+ Layout, -+ Stride< _1,_16>>{})); // M, K -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 2; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, double>{}, // CopyAtom -+ Layout, -+ Stride< _1,_16>>{}, // ThrLayout for CopyAtom -+ Layout>{})); // Value layout: 2x1 doubles -+ -+ // B (N,K) N-Major -+ using SmemLayoutAtomB = decltype( -+ composition(SwizzleXor<2,2,0>{}, -+ Layout, -+ Stride< _1,_16>>{})); // N, K -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 2; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, double>{}, // CopyAtom -+ Layout, -+ Stride< _1,_16>>{}, // ThrLayout for CopyAtom -+ Layout>{})); // Value layout: 2x1 doubles -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ double, TagToStrideA_t, -+ double, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// Ampere fp64 MMA TT (K-Major A and N-Major B) -+template <> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassTensorOp, arch::Sm80, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::ColumnMajor, -+ double> -+{ -+ using TileShape = Shape<_128, _64, _16>; -+ static constexpr int ThreadCount = 128; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom, // Atom -+ Layout>, // Atom layout -+ Layout>, // Val layout -+ Tile,Layout<_2,_16>,Underscore>>; // Mode permutations -+ -+ // A (M,K) K-Major -+ using SmemLayoutAtomA = 
decltype( -+ composition(SwizzleXor<2,0,2>{}, -+ Layout, -+ Stride<_1, _4>>{})); // M, K -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 1; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, double>{}, // CopyAtom -+ Layout, -+ Stride<_16, _1>>{}, // ThrLayout for CopyAtom -+ Layout>{})); // Value layout: 1x1 doubles -+ -+ // B (N,K) N-Major -+ using SmemLayoutAtomB = decltype( -+ composition(SwizzleXor<2,2,0>{}, -+ Layout, -+ Stride< _1,_16>>{})); // N, K -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 2; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, double>{}, // CopyAtom -+ Layout, -+ Stride< _1,_16>>{}, // ThrLayout for CopyAtom -+ Layout>{})); // Value layout: 2x1 doubles -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ double, TagToStrideA_t, -+ double, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// Hopper fp64 MMA TN -+template <> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassTensorOp, arch::Sm90, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double> -+{ -+ using TileShape = Shape<_128, _64, _16>; -+ static constexpr int ThreadCount = 128; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom, -+ Layout>>; -+ -+ // A (M,K) K-major -+ using SmemLayoutAtomA = decltype( -+ make_ordered_layout(Shape<_128,_16>{}, -+ Step < _2, _1>{})); // M, K -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 2; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, double>{}, -+ Layout, -+ Stride< _8,_1>>{}, -+ Layout>{})); -+ -+ // B (N,K) K-major -+ using SmemLayoutAtomB = decltype( -+ make_ordered_layout(Shape<_64,_16>{}, -+ Step < _2, _1>{})); // N, K -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 2; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, double>{}, -+ Layout, -+ Stride< _8,_1>>{}, -+ Layout>{})); -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ double, TagToStrideA_t, -+ double, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..45c1d80 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu -@@ -0,0 +1,233 @@ 
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x256x512_64x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 256x128x512_64x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x128x512_64x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x256x512_64x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ 
using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 256, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 256x64x512_64x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 64, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x128x512_32x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 512>, -+ cutlass::gemm::GemmShape<32, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x64x512_64x32x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 512>, -+ cutlass::gemm::GemmShape<64, 32, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x64x512_32x32x512) { -+ -+ using 
ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<32, 32, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..67bcb85 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu -@@ -0,0 +1,379 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x256x1024_64x64x1024) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 1024>, -+ cutlass::gemm::GemmShape<64, 64, 1024>, -+ cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x128x1024_64x64x1024) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 1024>, -+ cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x128x1024_64x64x1024) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 1024>, -+ cutlass::gemm::GemmShape<64, 64, 1024>, -+ cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 
256x64x1024_64x64x1024) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 1024>, -+ cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x256x1024_64x64x1024) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 1024>, -+ cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x128x1024_32x64x1024) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 1024>, -+ cutlass::gemm::GemmShape<32, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x64x1024_64x32x1024) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 1024>, -+ cutlass::gemm::GemmShape<64, 32, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x64x1024_32x32x1024) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 1024>, -+ cutlass::gemm::GemmShape<32, 32, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x256x512_64x64x512) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x128x512_64x64x512) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x128x512_64x64x512) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x64x512_64x64x512) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x256x512_64x64x512) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x128x512_32x64x512) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 512>, -+ cutlass::gemm::GemmShape<32, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x64x512_64x32x512) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 512>, -+ cutlass::gemm::GemmShape<64, 32, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / 
cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x64x512_32x32x512) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<32, 32, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32n_wmma_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32n_wmma_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..6c8ab54 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32n_wmma_tensor_op_s32_sm75.cu -@@ -0,0 +1,243 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED -+#include <iostream> -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+////// WMMA Instruction Shape = 8x8x128, DataType/Instruction = b1 ^ b1 + s32 => s32 ///////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x256x512_64x64x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 256x128x512_64x64x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, -+ ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x128x512_64x64x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, -+
ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 64x128x512_32x64x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 512>, -+ cutlass::gemm::GemmShape<32, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x64x512_64x32x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 512>, -+ cutlass::gemm::GemmShape<64, 32, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 64x64x512_32x32x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<32, 32, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+#endif //CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..445fa88 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu -@@ -0,0 +1,232 @@ -+/*************************************************************************************************** -+ * 
Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include <iostream> -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x256x512_64x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 256x128x512_64x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x128x512_64x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x256x512_64x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using 
ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 256, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 256x64x512_64x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 64, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x128x512_32x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 512>, -+ cutlass::gemm::GemmShape<32, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x64x512_64x32x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 512>, -+ cutlass::gemm::GemmShape<64, 32, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x64x512_32x32x512) { -+ -+ using ElementOutput = 
int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<32, 32, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..c819148 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm80.cu -@@ -0,0 +1,380 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+ -+*/ -+ -+#include <iostream> -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x256x1024_64x64x1024, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 1024>, -+ cutlass::gemm::GemmShape<64, 64, 1024>, -+ cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 256x128x1024_64x64x1024, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 1024>, -+ cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x128x1024_64x64x1024, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 1024>, -+ cutlass::gemm::GemmShape<64, 64, 1024>, -+ cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} ) -+ 
-+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 256x64x1024_64x64x1024, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 1024>, -+ cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x256x1024_64x64x1024, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 1024>, -+ cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x128x1024_32x64x1024, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 1024>, -+ cutlass::gemm::GemmShape<32, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x64x1024_64x32x1024, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 1024>, -+ cutlass::gemm::GemmShape<64, 32, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ 
false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x64x1024_32x32x1024, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 1024>, -+ cutlass::gemm::GemmShape<32, 32, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x256x512_64x64x512, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 256x128x512_64x64x512, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x128x512_64x64x512, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, 
ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 256x64x512_64x64x512, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x256x512_64x64x512, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x128x512_32x64x512, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 512>, -+ cutlass::gemm::GemmShape<32, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x64x512_64x32x512, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 512>, -+ cutlass::gemm::GemmShape<64, 32, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ 
cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x64x512_32x32x512, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<32, 32, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32t_wmma_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32t_wmma_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..755661f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32t_wmma_tensor_op_s32_sm75.cu -@@ -0,0 +1,243 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED -+#include <iostream> -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+////// WMMA Instruction Shape = 8x8x128, DataType/Instruction = b1 ^ b1 + s32 => s32 ///////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x256x512_64x64x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 256x128x512_64x64x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, -+ ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic<Gemm>()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x128x512_64x64x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+
cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 64x128x512_32x64x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 512>, -+ cutlass::gemm::GemmShape<32, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x64x512_64x32x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 512>, -+ cutlass::gemm::GemmShape<64, 32, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 64x64x512_32x32x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<32, 32, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+#endif //CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_bf16n_bf16n_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_bf16n_bf16n_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..a25d8aa ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_bf16n_bf16n_f32t_tensor_op_f32_sm80.cu 
-@@ -0,0 +1,359 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include <iostream> -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 256x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 256x64x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ 
cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 256x64x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x256x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = 
float;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+      cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
-+      cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
-+      cutlass::layout::RowMajor, ElementAccumulator,
-+      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+      cutlass::gemm::GemmShape<64, 256, 32>,
-+      cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
-+      cutlass::epilogue::thread::LinearCombination<
-+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+          ElementAccumulator, ElementAccumulator>,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+}
-+
-+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x128x32_32x64x32) {
-+  using ElementOutput = float;
-+  using ElementAccumulator = float;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+      cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
-+      cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
-+      cutlass::layout::RowMajor, ElementAccumulator,
-+      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+      cutlass::gemm::GemmShape<64, 128, 32>,
-+      cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
-+      cutlass::epilogue::thread::LinearCombination<
-+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+          ElementAccumulator, ElementAccumulator>,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+}
-+
-+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x64x32_64x32x32) {
-+  using ElementOutput = float;
-+  using ElementAccumulator = float;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+      cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
-+      cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
-+      cutlass::layout::RowMajor, ElementAccumulator,
-+      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+      cutlass::gemm::GemmShape<128, 64, 32>,
-+      cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
-+      cutlass::epilogue::thread::LinearCombination<
-+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+          ElementAccumulator, ElementAccumulator>,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+}
-+
-+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x64x32_32x32x32) {
-+  using ElementOutput = float;
-+  using ElementAccumulator = float;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+      cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
-+      cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput,
-+      cutlass::layout::RowMajor, ElementAccumulator,
-+      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+      cutlass::gemm::GemmShape<64, 64, 32>,
-+      cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
-+      cutlass::epilogue::thread::LinearCombination<
-+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+          ElementAccumulator, ElementAccumulator>,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+}
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
-diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_bf16t_bf16t_bf16t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_bf16t_bf16t_bf16t_tensor_op_f32_sm80.cu
-new file mode 100644
-index 0000000..3dfd4f1
---- /dev/null
-+++
b/3rdparty/cutlass/test/unit/gemm/device/gemm_bf16t_bf16t_bf16t_tensor_op_f32_sm80.cu -@@ -0,0 +1,343 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x256x64_64x64x64) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x128x64_64x64x64) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x128x64_64x64x64) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x64x64_64x64x64) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, 
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ 
ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x256x32_64x64x32) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x128x32_64x64x32) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x128x32_64x64x32) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x64x32_64x64x32) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); 
-+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x256x32_64x64x32) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x128x32_32x64x32) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x64x32_64x32x32) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x64x32_32x32x32) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git 
a/3rdparty/cutlass/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu -new file mode 100644 -index 0000000..00c64b7 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu -@@ -0,0 +1,259 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Operands data type: complex -+// Rounding: float -> tfloat32_t (half_ulp_truncate) -+// Instruction operand data type: tfloat32_t (real part) and tfloat32_t (imaginary part) -+// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -+// Instruction output/accumulation data type: f32 (real part) and f32 (imaginary part) -+// Output data type: complex -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 32x32x16_16x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 64x64x16_16x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 64x64x16_32x32x16) { -+ -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 128x64x16_64x32x16) { -+ -+ using Element = cutlass::complex;; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 64x128x16_32x64x16) { -+ -+ using Element = cutlass::complex;; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 128x128x16_32x64x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu -new file mode 100644 -index 0000000..146e2ec ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu -@@ -0,0 +1,258 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Operands data type: complex -+// Rounding: float -> tfloat32_t (round to nearest) -+// Instruction operand data type: tfloat32_t (real part) and tfloat32_t (imaginary part) -+// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -+// Instruction output/accumulation data type: f32 (real part) and f32 (imaginary part) -+// Output data type: complex -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 32x32x16_16x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ 
Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 64x64x16_16x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 64x64x16_32x32x16) { -+ -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 128x64x16_64x32x16) { -+ -+ using Element = cutlass::complex;; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 64x128x16_32x64x16) { -+ -+ using Element = cutlass::complex;; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ 
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 128x128x16_32x64x16) {
-+
-+  using Element = cutlass::complex<float>;;
-+
-+  using Gemm = cutlass::gemm::device::GemmComplex<
-+    Element,
-+    cutlass::layout::RowMajor,
-+    Element,
-+    cutlass::layout::ColumnMajor,
-+    Element,
-+    cutlass::layout::RowMajor,
-+    Element,
-+    cutlass::arch::OpClassTensorOp,
-+    cutlass::arch::Sm80,
-+    cutlass::gemm::GemmShape<128, 128, 16>,
-+    cutlass::gemm::GemmShape<32, 64, 16>,
-+    cutlass::gemm::GemmShape<16, 8, 8>,
-+    cutlass::epilogue::thread::LinearCombination<
-+      Element,
-+      1,
-+      Element,
-+      Element
-+    >,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+    3
-+  >;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
-+}
-+
-+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu
-new file mode 100644
-index 0000000..9164326
---- /dev/null
-+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu
-@@ -0,0 +1,198 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*!
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 32x32x8_16x16x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 16, 8>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 64x64x16_16x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 64x64x8_16x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<16, 32, 8>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu -new file mode 100644 -index 0000000..cc94303 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu -@@ -0,0 +1,198 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface with Hopper FP64 -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 32x32x8_16x16x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 16, 8>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 64x64x16_16x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 
cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 64x64x8_16x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<16, 32, 8>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..d93f3fb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu -@@ -0,0 +1,252 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 32x32x8_16x16x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 16, 8>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x16_16x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ 
cutlass::gemm::GemmShape<16, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x8_16x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<16, 32, 8>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x8_32x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..e2931b0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu -@@ -0,0 +1,252 @@ -+/*************************************************************************************************** -+ * Copyright 
(c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface with Hopper FP64 -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 32x32x8_16x16x8) { -+ -+ using Element = cutlass::complex; -+ 
-+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 16, 8>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x16_16x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x8_16x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<16, 32, 8>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x8_32x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ 
cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu -new file mode 100644 -index 0000000..60ec7a8 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu -@@ -0,0 +1,197 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 32x32x8_16x16x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 16, 8>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 64x64x8_32x16x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 
64x64x16_32x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu -new file mode 100644 -index 0000000..eb011e4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu -@@ -0,0 +1,197 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface with Hopper FP64 -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 32x32x8_16x16x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 16, 8>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 64x64x8_32x16x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ 
-+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 64x64x16_32x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..d9d171b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm80.cu -@@ -0,0 +1,305 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 32x32x8_16x16x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 16, 8>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x64x8_32x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x128x8_32x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 128x64x8_32x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ 
Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x128x16_32x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 128x64x16_32x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git 
a/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..c0333e7 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu -@@ -0,0 +1,305 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface with Hopper FP64 -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 32x32x8_16x16x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 16, 8>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x64x8_32x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x128x8_32x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 128, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 128x64x8_32x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 128, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< 
-+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x128x16_32x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 128x64x16_32x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git 
a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16n_direct_store_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16n_direct_store_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..e98764e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16n_direct_store_tensor_op_f32_sm80.cu -@@ -0,0 +1,114 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file
-+    \brief Tests for device-wide GEMM interface
-+*/
-+
-+#include <iostream>
-+
-+#include "../../common/cutlass_unit_test.h"
-+#include "cutlass/cutlass.h"
-+
-+#include "cutlass/gemm/kernel/gemm_universal.h"
-+#include "cutlass/gemm/device/gemm_universal.h"
-+#include "cutlass/gemm/device/gemm_universal_adapter.h"
-+
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/tensor_view_io.h"
-+
-+#include "testbed_universal.h"
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+#include "cutlass/epilogue/threadblock/epilogue_direct_store.h"
-+#include "cutlass/epilogue/threadblock/default_epilogue_direct_store.h"
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM80_Device_GemmUniversal_DirectStore_f16n_f16t_f32n_tensor_op_f32, 128x128x32_64x64x32) {
-+
-+  using ElementOutput = float;
-+  using ElementAccumulator = float;
-+
-+  // Define the GEMM kernel
-+  using GemmBase = cutlass::gemm::device::GemmUniversal<
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor,
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor,
-+    ElementOutput, cutlass::layout::ColumnMajor,
-+    ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+    cutlass::gemm::GemmShape<128, 128, 32>,
-+    cutlass::gemm::GemmShape<64, 64, 32>,
-+    cutlass::gemm::GemmShape<16, 8, 16>,
-+    cutlass::epilogue::thread::LinearCombination<
-+      ElementOutput,
-+      4,   // This is the vector size of the epilogue.
-+      ElementAccumulator,
-+      ElementAccumulator>,
-+    cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
-+    3,
-+    8,
-+    8
-+  >;
-+
-+  // Define the direct store epilogue
-+  using EpilogueDirectStore = typename cutlass::epilogue::threadblock::DefaultEpilogueDirectStore<
-+    typename GemmBase::GemmKernel::Epilogue
-+  >::Epilogue;
-+
-+  // Define a new kernel
-+  using Kernel = cutlass::gemm::kernel::GemmUniversal<
-+    typename GemmBase::GemmKernel::Mma,
-+    EpilogueDirectStore,
-+    typename GemmBase::GemmKernel::ThreadblockSwizzle
-+  >;
-+
-+  // Define the adaptor
-+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<Kernel>;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
-+}
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16n_wmma_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16n_wmma_tensor_op_f16_sm70.cu
-new file mode 100644
-index 0000000..4fc49a0
---- /dev/null
-+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16n_wmma_tensor_op_f16_sm70.cu
-@@ -0,0 +1,157 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ *    list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ *    this list of conditions and the following disclaimer in the documentation
-+ *    and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ *    contributors may be used to endorse or promote products derived from
-+ *    this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*! \file
-+    \brief Tests for device-wide GEMM interface
-+*/
-+#include "cutlass/arch/wmma.h"
-+
-+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED
-+#include <iostream>
-+
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm.h"
-+
-+#include "../../common/cutlass_unit_test.h"
-+
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/tensor_view_io.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+
-+#include "testbed.h"
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+/////////  WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F16=>F16  //////////
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_16x16x16) {
-+  // single cta, two warps horizontally two warps vertically
-+  using ElementOutput = cutlass::half_t;
-+  using ElementAccumulator = cutlass::half_t;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor,
-+    ElementOutput,
-+    cutlass::layout::ColumnMajor,
-+    ElementAccumulator,
-+    cutlass::arch::OpClassWmmaTensorOp,
-+    cutlass::arch::Sm70,
-+    cutlass::gemm::GemmShape<128, 128, 32>,
-+    cutlass::gemm::GemmShape<64, 64, 32>,
-+    cutlass::gemm::GemmShape<16, 16, 16>,
-+    cutlass::epilogue::thread::LinearCombination<
-+      ElementOutput,
-+      128 / cutlass::sizeof_bits<ElementOutput>::value,
-+      ElementAccumulator,
-+      ElementAccumulator
-+    >,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+    2
-+  >;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+}
-+
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+/////////  WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16  //////////
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16n_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16n_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..f4912ee ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16n_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,154 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 
64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..2808f9d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sm75.cu -@@ -0,0 +1,307 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x256x32_64x64x32_brief) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32> -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; 
-+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x32_64x64x32_brief) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32> -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x128x32_32x64x32_brief) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32> -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ 
ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..274d41f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sm80.cu -@@ -0,0 +1,344 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x128x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x64x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ 
cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x256x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x128x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x64x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x256x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = 
float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x128x32_32x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x32_64x32x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x32_32x32x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sparse_sm80.cu -new file mode 100644 -index 0000000..c7c894b ---- /dev/null -+++ 
b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sparse_sm80.cu -@@ -0,0 +1,272 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_sparse.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x128x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x64x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ 
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ 
cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x128_64x64x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x64x128_64x64x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x128_64x32x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x128_32x32x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_volta_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_volta_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..9bdf56d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_volta_tensor_op_f32_sm70.cu -@@ -0,0 +1,274 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ 
cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 64x64x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_wmma_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_wmma_tensor_op_f16_sm70.cu -new file mode 100644 -index 0000000..411e95b ---- /dev/null -+++ 
b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_wmma_tensor_op_f16_sm70.cu -@@ -0,0 +1,404 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 64x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x64x32_64x64x32_16x16x16) { -+ // single cta, two warps vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two warps vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 /
cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 64x128x32_32x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 
128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..c99f75a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,403 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 128x64x32_64x64x32_16x16x16) { -+ // single cta, two warps vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x128x32_32x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = 
cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32n_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32n_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..6bf8ae5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32n_tensor_op_f32_sm75.cu -@@ -0,0 +1,307 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x256x32_64x64x32_brief) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32> -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ 
ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x128x32_64x64x32_brief) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32> -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x128x32_32x64x32_brief) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32> -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 
32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32n_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32n_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..5c398c3 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32n_tensor_op_f32_sm80.cu -@@ -0,0 +1,343 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x64x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ 
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x256x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x128x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x128x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x64x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x256x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, 
cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x128x32_32x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x64x32_64x32x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x64x32_32x32x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32n_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32n_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..4ff878b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32n_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,159 @@ 
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32n_wmma_tensor_op_f32, 256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ 
cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..a123d29 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sm75.cu -@@ -0,0 +1,307 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32_brief) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32> -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 
/ cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x32_64x64x32_brief) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32> -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x128x32_32x64x32_brief) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32> -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 
32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..4811c9d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,346 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x256x64_64x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x128x64_64x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x64_64x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x64x64_64x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ 
cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x256x64_64x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x128x64_32x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x64_64x32x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x64_32x32x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x128x32_64x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x32_64x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x64x32_64x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x256x32_64x64x32, { -+ using ElementOutput = float; -+ using 
ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x128x32_32x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x32_32x32x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu -new file mode 100644 -index 0000000..4c7ce79 ---- /dev/null -+++ 
b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu -@@ -0,0 +1,273 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_sparse.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x64x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, 
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / 
cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x128_64x64x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x64x128_64x64x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x128_64x32x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x128_32x32x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_volta_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_volta_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..bba2b6d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_volta_tensor_op_f32_sm70.cu -@@ -0,0 +1,274 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ 
ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 64x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..350d7e9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_wmma_tensor_op_f32_sm70.cu -@@ 
-0,0 +1,344 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 
256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 64x128x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction 
= F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16n_wmma_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16n_wmma_tensor_op_f16_sm70.cu -new file mode 100644 -index 0000000..eff38fd ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16n_wmma_tensor_op_f16_sm70.cu -@@ -0,0 +1,157 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*! \file
-+    \brief Tests for device-wide GEMM interface
-+*/
-+#include "cutlass/arch/wmma.h"
-+
-+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED
-+#include <iostream>
-+
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm.h"
-+
-+#include "../../common/cutlass_unit_test.h"
-+
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/tensor_view_io.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+
-+#include "testbed.h"
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F16=>F16 //////////
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_16x16x16) {
-+  // single cta, two warps horizontally two warps vertically
-+  using ElementOutput = cutlass::half_t;
-+  using ElementAccumulator = cutlass::half_t;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor,
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor,
-+    ElementOutput,
-+    cutlass::layout::ColumnMajor,
-+    ElementAccumulator,
-+    cutlass::arch::OpClassWmmaTensorOp,
-+    cutlass::arch::Sm70,
-+    cutlass::gemm::GemmShape<128, 128, 32>,
-+    cutlass::gemm::GemmShape<64, 64, 32>,
-+    cutlass::gemm::GemmShape<16, 16, 16>,
-+    cutlass::epilogue::thread::LinearCombination<
-+      ElementOutput,
-+      128 / cutlass::sizeof_bits<ElementOutput>::value,
-+      ElementAccumulator,
-+      ElementAccumulator
-+    >,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+    2
-+  >;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+}
-+
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 //////////
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_32x8x16) {
-+
-+  using ElementOutput = cutlass::half_t;
-+  using ElementAccumulator = cutlass::half_t;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor,
-+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16n_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16n_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..ddd426d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16n_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,155 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*! \file
-+    \brief Tests for device-wide GEMM interface
-+*/
-+#include "cutlass/arch/wmma.h"
-+
-+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED
-+#include <iostream>
-+
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm.h"
-+
-+#include "../../common/cutlass_unit_test.h"
-+
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/tensor_view_io.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+
-+#include "testbed.h"
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F16 //////////
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) {
-+  // single cta, two warps horizontally, two warps vertically
-+  using ElementOutput = cutlass::half_t;
-+  using ElementAccumulator = float;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor,
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor,
-+    ElementOutput,
-+    cutlass::layout::ColumnMajor,
-+    ElementAccumulator,
-+    cutlass::arch::OpClassWmmaTensorOp,
-+    cutlass::arch::Sm70,
-+    cutlass::gemm::GemmShape<128, 128, 32>,
-+    cutlass::gemm::GemmShape<64, 64, 32>,
-+    cutlass::gemm::GemmShape<16, 16, 16>,
-+    cutlass::epilogue::thread::LinearCombination<
-+      ElementOutput,
-+      128 / cutlass::sizeof_bits<ElementOutput>::value,
-+      ElementAccumulator,
-+      ElementAccumulator
-+    >,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+    2
-+  >;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+}
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 //////////
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x16) {
-+
-+  using ElementOutput = cutlass::half_t;
-+  using ElementAccumulator = float;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor,
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor,
-+    ElementOutput,
-+    cutlass::layout::ColumnMajor,
-+    ElementAccumulator,
-+    cutlass::arch::OpClassWmmaTensorOp,
-+    cutlass::arch::Sm70,
-+    cutlass::gemm::GemmShape<64, 64, 32>,
-+    cutlass::gemm::GemmShape<64, 64, 32>,
-+    cutlass::gemm::GemmShape<32, 8, 16>,
-+    cutlass::epilogue::thread::LinearCombination<
-+      ElementOutput,
-+      128 / cutlass::sizeof_bits<ElementOutput>::value,
-+      ElementAccumulator,
-+      ElementAccumulator
-+    >,
-+
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm75.cu -new file mode 100644 -index 0000000..7392cf9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm75.cu -@@ -0,0 +1,88 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16_sliced_k, 64x64x64_64x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if (CUTLASS_ARCH_MMA_SM75_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm80.cu -new file mode 100644 -index 0000000..468e698 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm80.cu -@@ -0,0 +1,88 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16_sliced_k, 128x64x64_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm75.cu -new file mode 100644 -index 0000000..5fd2fb7 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm75.cu -@@ -0,0 +1,243 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x256x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ 
-+TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x32_64x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x32_32x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm80.cu -new file mode 100644 -index 0000000..90fd6d0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm80.cu -@@ -0,0 +1,344 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x128x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x64x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, 
cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64> , -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x128x64_32x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x64_64x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x64_32x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ 
ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x256x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x128x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x64x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x256x32_64x64x32) { -+ 
using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x128x32_32x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x32_64x32x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x32_32x32x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sparse_sm80.cu 
-new file mode 100644 -index 0000000..ebe3acb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sparse_sm80.cu -@@ -0,0 +1,271 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_sparse.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x128x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x64x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, 
cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64> , -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x128x64_32x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x64_64x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x64_32x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 
64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x128_64x64x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x64x128_64x64x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x128_64x32x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x128_32x32x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ 
ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..a94f4ac ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu -@@ -0,0 +1,81 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file
-+    \brief Tests for device-wide GEMM interface
-+*/
-+
-+#include <iostream>
-+
-+#include "../../common/cutlass_unit_test.h"
-+#include "cutlass/cutlass.h"
-+
-+#include "cutlass/gemm/device/gemm_universal.h"
-+
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/tensor_view_io.h"
-+
-+#include "testbed.h"
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM80_Device_GemmUniversal_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32) {
-+
-+  using ElementOutput = cutlass::half_t;
-+  using ElementAccumulator = cutlass::half_t;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+      cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
-+      cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
-+      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+      cutlass::gemm::GemmShape<64, 64, 32>,
-+      cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
-+      cutlass::epilogue::thread::LinearCombination<
-+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+          ElementAccumulator, ElementAccumulator>,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+}
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_volta_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_volta_tensor_op_f16_sm70.cu
-new file mode 100644
-index 0000000..969f54b
---- /dev/null
-+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_volta_tensor_op_f16_sm70.cu
-@@ -0,0 +1,267 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 128x256x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 256x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 128x64x32_64x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 64x64x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 64x64x32_32x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_wmma_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_wmma_tensor_op_f16_sm70.cu -new file mode 100644 -index 0000000..ca8eeeb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_wmma_tensor_op_f16_sm70.cu -@@ -0,0 +1,405 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 64x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x64x32_64x64x32_16x16x16) { -+ // single cta, two warps vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 
2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 64x128x32_32x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = 
cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..8b15b0d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,87 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include <iostream> -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32n_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32n_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..fe7e1b0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32n_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,159 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..3b3d293 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm75.cu -@@ -0,0 +1,243 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ 
cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..7387b99 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,384 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+#include "testbed_universal.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x256x64_64x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x128x64_64x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x64_64x64x64, { -+ using ElementOutput 
= float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x64_64x64x64_sk, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ cutlass::half_t, cutlass::layout::ColumnMajor, -+ cutlass::half_t, cutlass::layout::RowMajor, -+ ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32n_tensor_op_f32, 128x128x64_64x64x64_sk, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ cutlass::half_t, cutlass::layout::ColumnMajor, -+ cutlass::half_t, cutlass::layout::RowMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x64_64x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x256x64_64x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ 
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x128x64_32x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x64_64x32x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x64_32x32x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x256x32_64x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / 
cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x128x32_64x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x32_64x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x32_64x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x256x32_64x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 
64x128x32_32x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x32_64x32x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu -new file mode 100644 -index 0000000..60a4760 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu -@@ -0,0 +1,272 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_sparse.h" -+ -+#if (CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = float; -+ using 
ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x128_64x64x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x128_64x64x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x128_64x32x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, 
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x128_32x32x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_volta_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_volta_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..02c4ddf ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_volta_tensor_op_f32_sm70.cu -@@ -0,0 +1,267 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 64x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ 
-+#endif // if (CUTLASS_ENABLE_TENSOR_CORE_MMA) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..ef37420 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,344 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 
256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 64x128x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, 
DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16_sm70.cu -new file mode 100644 -index 0000000..cf6d5dd ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16_sm70.cu -@@ -0,0 +1,321 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 
32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 64x128x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 64x64x64_32x32x64_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x128x64_64x32x64_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ 
ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x128x32_64x32x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x128x32_64x32x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16n_wmma_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16n_wmma_tensor_op_f16_sm70.cu -new file mode 100644 -index 0000000..9156d8e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16n_wmma_tensor_op_f16_sm70.cu -@@ -0,0 +1,157 @@ -+/*************************************************************************************************** -+ * Copyright 
(c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*! \file
-+    \brief Tests for device-wide GEMM interface
-+*/
-+#include "cutlass/arch/wmma.h"
-+
-+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED
-+#include <iostream>
-+
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm.h"
-+
-+#include "../../common/cutlass_unit_test.h"
-+
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/tensor_view_io.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+
-+#include "testbed.h"
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F16=>F16 //////////
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_16x16x16) {
-+  // single cta, two warps horizontally, two warps vertically
-+  using ElementOutput = cutlass::half_t;
-+  using ElementAccumulator = cutlass::half_t;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor,
-+    ElementOutput,
-+    cutlass::layout::ColumnMajor,
-+    ElementAccumulator,
-+    cutlass::arch::OpClassWmmaTensorOp,
-+    cutlass::arch::Sm70,
-+    cutlass::gemm::GemmShape<128, 128, 32>,
-+    cutlass::gemm::GemmShape<64, 64, 32>,
-+    cutlass::gemm::GemmShape<16, 16, 16>,
-+    cutlass::epilogue::thread::LinearCombination<
-+      ElementOutput,
-+      128 / cutlass::sizeof_bits<ElementOutput>::value,
-+      ElementAccumulator,
-+      ElementAccumulator
-+    >,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+
2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16n_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16n_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..8bfed3f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16n_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,155 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*! \file
-+    \brief Tests for device-wide GEMM interface
-+*/
-+#include "cutlass/arch/wmma.h"
-+
-+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED
-+#include <iostream>
-+
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm.h"
-+
-+#include "../../common/cutlass_unit_test.h"
-+
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/tensor_view_io.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+
-+#include "testbed.h"
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F16 //////////
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f32, 128x128x32_64x64x16_16x16x16) {
-+  // single cta, two warps horizontally, two warps vertically
-+  using ElementOutput = cutlass::half_t;
-+  using ElementAccumulator = float;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor,
-+    ElementOutput,
-+    cutlass::layout::ColumnMajor,
-+    ElementAccumulator,
-+    cutlass::arch::OpClassWmmaTensorOp,
-+    cutlass::arch::Sm70,
-+    cutlass::gemm::GemmShape<128, 128, 32>,
-+    cutlass::gemm::GemmShape<64, 64, 32>,
-+    cutlass::gemm::GemmShape<16, 16, 16>,
-+    cutlass::epilogue::thread::LinearCombination<
-+      ElementOutput,
-+      128 / cutlass::sizeof_bits<ElementOutput>::value,
-+      ElementAccumulator,
-+      ElementAccumulator
-+    >,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+    2
-+  >;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+}
-+
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 //////////
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f32,
64x64x32_64x64x16_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f32, 64x64x32_64x64x16_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16_sm70.cu -new file mode 100644 -index 0000000..9e5973a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16_sm70.cu -@@ -0,0 +1,321 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, 
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 64x128x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 64x64x64_32x32x64_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x128x64_64x32x64_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ 
cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x128x32_64x32x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x128x32_64x32x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu -new file mode 100644 -index 0000000..b69a304 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu -@@ -0,0 +1,440 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for GEMM + broadcast interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h" -+#include "cutlass/gemm/device/gemm_universal.h" -+#include "cutlass/gemm/device/gemm_universal_with_broadcast.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+ -+#include "cutlass/epilogue/thread/activation.h" -+#include "cutlass/epilogue/thread/linear_combination_bias_relu.h" -+#include "cutlass/epilogue/thread/linear_combination_residual_block.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_elementwise.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+template -+struct TestbedUtils { -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; // Input A -+ cutlass::HostTensor tensor_B; // Input B -+ cutlass::HostTensor tensor_C; // Input C -+ cutlass::HostTensor tensor_D1; // Input D -+ cutlass::HostTensor tensor_D2; // Input D -+ cutlass::HostTensor tensor_Y1; // Input Y -+ cutlass::HostTensor tensor_Y2; // Input Y -+ cutlass::HostTensor tensor_Y_ref; -+ -+ // -+ // Methods -+ // -+ -+ TestbedUtils( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = 
cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::AllZeros) { -+ cutlass::reference::host::TensorFill(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Initializes data structures -+ void initialize(cutlass::gemm::GemmCoord problem_size) { -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ tensor_A.resize(problem_size.mk()); -+ tensor_B.resize(problem_size.kn()); -+ tensor_C.resize({1, problem_size.n()}); -+ tensor_D1.resize(problem_size.mn()); -+ tensor_D2.resize(problem_size.mn()); -+ tensor_Y1.resize(problem_size.mn()); -+ tensor_Y2.resize(problem_size.mn()); -+ tensor_Y_ref.resize(problem_size.mn()); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); -+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); -+ -+ // Initialize D data to smaller data range. This helps avoid large roundoff errors. -+ int d_scope_min = -2; -+ int d_scope_max = 2; -+ cutlass::reference::host::TensorFillRandomUniform(tensor_D1.host_view(), seed + 2016, d_scope_max, d_scope_min, 0); -+ cutlass::reference::host::TensorFillRandomUniform(tensor_D2.host_view(), seed + 2015, d_scope_max, d_scope_min, 0); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_Y1.host_view(), cutlass::Distribution::AllZeros, 0)); -+ EXPECT_TRUE(initialize_tensor(tensor_Y2.host_view(), cutlass::Distribution::AllZeros, 0)); -+ EXPECT_TRUE(initialize_tensor(tensor_Y_ref.host_view(), cutlass::Distribution::AllZeros, 0)); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. 
-+ tensor_A.host_view().at({0, 0}) = GemmElement(1); -+ tensor_B.host_view().at({0, 0}) = GemmElement(1); -+ tensor_C.host_view().at({0, 0}) = GemmElement(1); -+ tensor_D1.host_view().at({0, 0}) = GemmElement(1); -+ tensor_D2.host_view().at({0, 0}) = GemmElement(1); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D1.sync_device(); -+ tensor_D2.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cutlass::gemm::GemmCoord problem_size, cutlass::HostTensor& tensor_Y_ref, cutlass::HostTensor& tensor_Y) { -+ -+ tensor_Y_ref.sync_host(); -+ tensor_Y.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D2.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Y_ref.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Y.host_view()), 0); -+ -+ bool passed = true; -+ float norm_diff = 0; -+ -+ norm_diff = cutlass::reference::host::TensorNormDiff(tensor_Y_ref.host_view(), tensor_Y.host_view(), float()); -+ passed = (norm_diff <= 0.1f); -+ EXPECT_LT(norm_diff, 0.1f) << " tensor_Y is incorrect"; -+ -+ -+ if (!passed) { -+ std::ofstream file("errors_testbed_gemm_broadcast_new.txt"); -+ -+ -+ file -+ << "problem: " << problem_size << "\n\n"; -+ -+ file -+ << "capacity: \n" -+ << "A: " << tensor_A.capacity() -+ << "\nB: " << tensor_B.capacity() -+ << "\nC: " << tensor_C.capacity() -+ << "\nD1: " << tensor_D1.capacity() -+ << "\nD2: " << tensor_D2.capacity() -+ << "\nY: " << tensor_Y.capacity() -+ << "\n\n" -+ << "\nY_ref: " << tensor_Y_ref.capacity() -+ << "\n\n"; -+ file -+ << "A =\n" << tensor_A.host_view() -+ << "\n\nB =\n" << tensor_B.host_view() -+ << "\n\nC =\n" << tensor_C.host_view() -+ << "\n\nD1 =\n" << tensor_D1.host_view() -+ << "\n\nD2 =\n" << tensor_D2.host_view() -+ << "\n\nY =\n" << tensor_Y.host_view() -+ << "\n\nY_ref =\n" << tensor_Y_ref.host_view(); -+ } -+ -+ return passed; -+ } -+}; -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+TEST(SM80_Device_GemmWithBroadcast_f16t_f16n_f16t_tensor_op_f16, 128x128_32x3_64x64x32_16x8x16) { -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using OpClass = cutlass::arch::OpClassTensorOp; -+ using ArchTag = cutlass::arch::Sm80; -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; -+ const int kStages = 3; -+ -+ const int batch_count = 1; -+ const cutlass::half_t alpha(1); -+ const cutlass::half_t beta(1); -+ -+ const int M = 1024; -+ const int K = 10240; -+ const int N = 512; -+ cutlass::gemm::GemmCoord problem{M, N, K}; -+ -+ const int batch_stride_A = 0; -+ const int batch_stride_B = 0; -+ const int batch_stride_C1 = 0; -+ const int 
batch_stride_C2 = 0; -+ const int batch_stride_D = 0; -+ const int batch_stride_Vector = 0; -+ const int batch_stride_Tensor = 0; -+ -+ const int64_t lda = LayoutA::packed({problem.m(), problem.k()}).stride(0); -+ const int64_t ldb = LayoutB::packed({problem.k(), problem.n()}).stride(0); -+ const int64_t ldc1 = LayoutC::packed({problem.m(), problem.n()}).stride(0); -+ const int64_t ldc2 = LayoutC::packed({problem.m(), problem.n()}).stride(0); -+ const int64_t ldd = LayoutC::packed({problem.m(), problem.n()}).stride(0); -+ const int64_t ldv = 0; -+ const int64_t ldt = 0; -+ -+ TestbedUtils utils; -+ utils.initialize(problem); -+ -+ // -+ // Create reference Gemm -+ // -+ using GemmRef = cutlass::gemm::device::GemmUniversal< -+ ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, -+ OpClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ ThreadblockSwizzle, kStages>; -+ -+ typename GemmRef::Arguments args_ref{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem, -+ batch_count, -+ {alpha, beta}, -+ utils.tensor_A.device_data(), -+ utils.tensor_B.device_data(), -+ utils.tensor_C.device_data(), -+ utils.tensor_Y_ref.device_data(), -+ batch_stride_A, -+ batch_stride_B, -+ batch_stride_C1, -+ batch_stride_D, -+ lda, -+ ldb, -+ ldv, -+ ldd, -+ }; -+ -+ GemmRef gemm_op_ref; -+ size_t workspace_size_ref = GemmRef::get_workspace_size(args_ref); -+ cutlass::device_memory::allocation workspace_ref(workspace_size_ref); -+ cutlass::Status status = gemm_op_ref.initialize(args_ref, workspace_ref.get()); -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status); -+ -+ status = gemm_op_ref(); -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status); -+ -+ // -+ // Create GemmWithBroadcast from single source -+ // -+ using GemmSingle = cutlass::gemm::device::GemmUniversalWithBroadcast< -+ ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, -+ OpClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationResidualBlock< -+ ElementOutput, ElementAccumulator, ElementAccumulator, -+ ElementAccumulator, 128 / cutlass::sizeof_bits::value, -+ cutlass::epilogue::thread::Identity, cutlass::multiplies, cutlass::epilogue::thread::Identity>, -+ ThreadblockSwizzle, kStages>; -+ -+ typename GemmSingle::Arguments args_single{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem, -+ batch_count, -+ {alpha, beta}, -+ utils.tensor_A.device_data(), -+ utils.tensor_B.device_data(), -+ utils.tensor_D1.device_data(), -+ utils.tensor_Y1.device_data(), -+ utils.tensor_C.device_data(), -+ /* ptr_Tensor = */ nullptr, -+ batch_stride_A, -+ batch_stride_B, -+ batch_stride_C1, -+ batch_stride_D, -+ batch_stride_Vector, -+ batch_stride_Tensor, -+ lda, -+ ldb, -+ ldc1, -+ ldd, -+ ldv, -+ ldt -+ }; -+ -+ GemmSingle gemm_op_single; -+ size_t workspace_size_single = GemmSingle::get_workspace_size(args_single); -+ cutlass::device_memory::allocation workspace_single(workspace_size_single); -+ status = gemm_op_single.initialize(args_single, workspace_single.get()); -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status); -+ -+ status = gemm_op_single(); -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status); -+ -+ // Compute the broadcast on the reference 
previously computed and compare results -+ utils.tensor_Y_ref.sync_host(); -+ cutlass::reference::host::TensorMul(utils.tensor_Y_ref.host_view(), utils.tensor_D1.host_view()); -+ utils.tensor_Y_ref.sync_device(); -+ utils.compare_reference(problem, utils.tensor_Y_ref, utils.tensor_Y1); -+ -+ // -+ // Create GemmWithBroadcast from two sources -+ // -+ using GemmDouble = cutlass::gemm::device::GemmUniversalWithBroadcast< -+ ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, -+ OpClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationResidualBlock< -+ ElementOutput, ElementAccumulator, ElementAccumulator, -+ ElementAccumulator, 128 / cutlass::sizeof_bits::value, -+ cutlass::epilogue::thread::Identity, cutlass::multiplies, cutlass::epilogue::thread::Identity, cutlass::plus>, -+ ThreadblockSwizzle, kStages>; -+ -+ typename GemmDouble::Arguments args_double{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem, -+ batch_count, -+ {alpha, beta}, -+ utils.tensor_A.device_data(), -+ utils.tensor_B.device_data(), -+ utils.tensor_D1.device_data(), -+ utils.tensor_D2.device_data(), -+ utils.tensor_Y2.device_data(), -+ utils.tensor_C.device_data(), -+ /* ptr_Tensor = */ nullptr, -+ batch_stride_A, -+ batch_stride_B, -+ batch_stride_C1, -+ batch_stride_C2, -+ batch_stride_D, -+ batch_stride_Vector, -+ batch_stride_Tensor, -+ lda, -+ ldb, -+ ldc1, -+ ldc2, -+ ldd, -+ ldv, -+ ldt -+ }; -+ -+ GemmDouble gemm_op_double; -+ size_t workspace_size_double = GemmDouble::get_workspace_size(args_double); -+ cutlass::device_memory::allocation workspace_double(workspace_size_double); -+ status = gemm_op_double.initialize(args_double, workspace_double.get()); -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status); -+ -+ status = gemm_op_double(); -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status); -+ -+ // Compute the broadcast on the reference previously computed and compare results -+ utils.tensor_Y_ref.sync_host(); -+ cutlass::reference::host::TensorAdd(utils.tensor_Y_ref.host_view(), utils.tensor_D2.host_view()); -+ utils.tensor_Y_ref.sync_device(); -+ utils.compare_reference(problem, utils.tensor_Y_ref, utils.tensor_Y2); -+} -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm75.cu -new file mode 100644 -index 0000000..f595fd6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm75.cu -@@ -0,0 +1,88 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*! \file
-+    \brief Tests for device-wide GEMM interface
-+*/
-+
-+#include <iostream>
-+
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm.h"
-+
-+#include "../../common/cutlass_unit_test.h"
-+
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/tensor_view_io.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+
-+#include "testbed.h"
-+
-+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16_sliced_k, 64x64x64_64x32x32) {
-+
-+  using ElementOutput = cutlass::half_t;
-+  using ElementAccumulator = cutlass::half_t;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor,
-+    ElementOutput,
-+    cutlass::layout::RowMajor,
-+    ElementAccumulator,
-+    cutlass::arch::OpClassTensorOp,
-+    cutlass::arch::Sm75,
-+    cutlass::gemm::GemmShape<64, 64, 64>,
-+    cutlass::gemm::GemmShape<64, 32, 32>,
-+    cutlass::gemm::GemmShape<16, 8, 8>,
-+    cutlass::epilogue::thread::LinearCombination<
-+      ElementOutput,
-+      64 / cutlass::sizeof_bits<ElementOutput>::value,
-+      ElementAccumulator,
-+      ElementAccumulator
-+    >,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+    2
-+  >;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+}
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+#endif // if (CUTLASS_ARCH_MMA_SM75_SUPPORTED)
-diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm80.cu
-new file mode 100644
-index 0000000..0881964
---- /dev/null
-+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm80.cu
-@@ -0,0 +1,89 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+ -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16_sliced_k, 128x64x64_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm75.cu 
b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm75.cu -new file mode 100644 -index 0000000..c3b3094 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm75.cu -@@ -0,0 +1,242 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x256x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ 
cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x32_64x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x64x32_32x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm80.cu -new file mode 100644 -index 0000000..0343f0b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm80.cu -@@ -0,0 +1,345 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x128x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ 
using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x64x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x128x64_32x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x64_64x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, 
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x64x64_32x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x256x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x128x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / 
cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x64x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x256x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x128x32_32x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x32_64x32x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 
64x64x32_32x32x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sparse_sm80.cu -new file mode 100644 -index 0000000..2c9402f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sparse_sm80.cu -@@ -0,0 +1,273 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file
-+    \brief Tests for device-wide GEMM interface
-+*/
-+
-+#include <iostream>
-+
-+#include "../../common/cutlass_unit_test.h"
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm_sparse.h"
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/tensor_view_io.h"
-+
-+#include "testbed_sparse.h"
-+
-+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED)
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x256x64_64x64x64) {
-+  using ElementOutput = cutlass::half_t;
-+  using ElementAccumulator = cutlass::half_t;
-+
-+  using Gemm = cutlass::gemm::device::SparseGemm<
-+      cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
-+      cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
-+      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+      cutlass::gemm::GemmShape<128, 256, 64>,
-+      cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
-+      cutlass::epilogue::thread::LinearCombination<
-+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+          ElementAccumulator, ElementAccumulator>,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
-+}
-+
-+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x128x64_64x64x64) {
-+  using ElementOutput = cutlass::half_t;
-+  using ElementAccumulator = cutlass::half_t;
-+
-+  using Gemm = cutlass::gemm::device::SparseGemm<
-+      cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
-+      cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
-+      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+      cutlass::gemm::GemmShape<256, 128, 64>,
-+      cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
-+      cutlass::epilogue::thread::LinearCombination<
-+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+          ElementAccumulator, ElementAccumulator>,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
-+}
-+
-+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x64_64x64x64) {
-+  using ElementOutput = cutlass::half_t;
-+  using ElementAccumulator = cutlass::half_t;
-+
-+  using Gemm = cutlass::gemm::device::SparseGemm<
-+      cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
-+      cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
-+      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+      cutlass::gemm::GemmShape<128, 128, 64>,
-+      cutlass::gemm::GemmShape<64, 64, 64>,
-+      cutlass::gemm::GemmShape<16, 8, 32>,
-+      cutlass::epilogue::thread::LinearCombination<
-+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+          ElementAccumulator, ElementAccumulator>,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
-+}
-+
-+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x64x64_64x64x64) {
-+  using ElementOutput = cutlass::half_t;
-+  using ElementAccumulator = cutlass::half_t;
-+
-+  using Gemm = cutlass::gemm::device::SparseGemm<
-+      cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
-+      cutlass::layout::ColumnMajor, ElementOutput,
cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x128x64_32x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x64_64x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x64x64_32x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 
64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x128_64x64x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x64x128_64x64x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x128_64x32x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x64x128_32x32x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ 
ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_volta_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_volta_tensor_op_f16_sm70.cu -new file mode 100644 -index 0000000..f986113 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_volta_tensor_op_f16_sm70.cu -@@ -0,0 +1,274 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 128x256x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 256x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 128x64x32_64x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ 
cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 64x64x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 64x64x32_32x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if (CUTLASS_ENABLE_TENSOR_CORE_MMA) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_wmma_tensor_op_f16_sm70.cu 
b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_wmma_tensor_op_f16_sm70.cu -new file mode 100644 -index 0000000..be966e2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_wmma_tensor_op_f16_sm70.cu -@@ -0,0 +1,405 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 64x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x64x32_64x64x32_16x16x16) { -+ // single cta, two warps vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 
2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 64x128x32_32x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = 
cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..afd4fbf ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,402 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 128x64x32_64x64x32_16x16x16) { -+ // single cta, two warps vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 
-+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x128x32_32x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; 
-+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32n_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32n_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..5f6b77d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32n_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,158 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ 
ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..fab5576 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,226 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 64x128x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 128x128x32_64x32x32_32x8x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 128x128x32_64x32x32_8x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..5aa81be ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sm75.cu -@@ -0,0 +1,243 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ 
cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ 
ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..708b4df ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,344 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file
-+    \brief Tests for device-wide GEMM interface
-+*/
-+
-+#include <iostream>
-+
-+#include "../../common/cutlass_unit_test.h"
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm.h"
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/tensor_view_io.h"
-+
-+#include "testbed.h"
-+
-+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x256x64_64x64x64) {
-+  using ElementOutput = float;
-+  using ElementAccumulator = float;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+      cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
-+      cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
-+      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+      cutlass::gemm::GemmShape<128, 256, 64>,
-+      cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
-+      cutlass::epilogue::thread::LinearCombination<
-+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+          ElementAccumulator, ElementAccumulator>,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+}
-+
-+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x128x64_64x64x64) {
-+  using ElementOutput = float;
-+  using ElementAccumulator = float;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+      cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
-+      cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
-+      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+      cutlass::gemm::GemmShape<256, 128, 64>,
-+      cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
-+      cutlass::epilogue::thread::LinearCombination<
-+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+          ElementAccumulator, ElementAccumulator>,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+}
-+
-+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x64_64x64x64) {
-+  using ElementOutput = float;
-+  using ElementAccumulator = float;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+      cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
-+      cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
-+      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+      cutlass::gemm::GemmShape<128, 128, 64>,
-+      cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
-+      cutlass::epilogue::thread::LinearCombination<
-+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+          ElementAccumulator, ElementAccumulator>,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+}
-+
-+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x64x64_64x64x64) {
-+  using ElementOutput = float;
-+  using ElementAccumulator = float;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+      cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
-+      cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
-+      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+      cutlass::gemm::GemmShape<256, 64, 64>,
-+
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x64x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x256x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ 
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu -new file mode 100644 -index 0000000..e1d9381 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu -@@ -0,0 +1,271 @@ 
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file
-+    \brief Tests for device-wide GEMM interface
-+*/
-+
-+#include <iostream>
-+
-+#include "../../common/cutlass_unit_test.h"
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm_sparse.h"
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/tensor_view_io.h"
-+
-+#include "testbed_sparse.h"
-+
-+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED)
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x256x64_64x64x64) {
-+  using ElementOutput = float;
-+  using ElementAccumulator = float;
-+
-+  using Gemm = cutlass::gemm::device::SparseGemm<
-+      cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
-+      cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
-+      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+      cutlass::gemm::GemmShape<128, 256, 64>,
-+      cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
-+      cutlass::epilogue::thread::LinearCombination<
-+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+          ElementAccumulator, ElementAccumulator>,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
-+}
-+
-+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x128x64_64x64x64) {
-+  using ElementOutput = float;
-+  using ElementAccumulator = float;
-+
-+  using Gemm = cutlass::gemm::device::SparseGemm<
-+      cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
-+      cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
-+      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+      cutlass::gemm::GemmShape<256, 128, 64>,
-+      cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
-+      cutlass::epilogue::thread::LinearCombination<
-+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+          ElementAccumulator, ElementAccumulator>,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
-+}
-+
-+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x64_64x64x64) {
-+  using ElementOutput = float;
-+  using ElementAccumulator = float;
-+
-+  using Gemm = cutlass::gemm::device::SparseGemm<
-+      cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
-+      cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
-+      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+      cutlass::gemm::GemmShape<128, 128, 64>,
-+      cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
-+      cutlass::epilogue::thread::LinearCombination<
-+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+          ElementAccumulator, ElementAccumulator>,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
-+}
-+
-+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x64x64_64x64x64) {
-+  using ElementOutput = float;
-+  using ElementAccumulator = float;
-+
-+  using Gemm = cutlass::gemm::device::SparseGemm<
-+      cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
-+      cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
-+      ElementAccumulator, cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ 
ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x128_64x64x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x64x128_64x64x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x128_64x32x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x128_32x32x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_volta_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_volta_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..32f4923 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_volta_tensor_op_f32_sm70.cu -@@ -0,0 +1,274 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, 
-+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 64x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..70298cf ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,344 @@ 
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 
256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 64x128x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, 
DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f16_sm70.cu -new file mode 100644 -index 0000000..129d4a0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f16_sm70.cu -@@ -0,0 +1,157 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ 
cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..6fcdcee ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,155 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f16_sm70.cu -new file mode 100644 -index 0000000..71fea92 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f16_sm70.cu -@@ -0,0 +1,405 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 64x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 128x64x32_64x64x32_16x16x16) { -+ // single cta, two warps vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ 
ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 
/ cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 64x128x32_32x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_8x32x16) { -+ -+ 
using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..fc980ac ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,403 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x64x32_64x64x32_16x16x16) { -+ // single cta, two warps vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x128x32_32x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ 
cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32n_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32n_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..9612d76 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32n_tensor_op_f32_sm75.cu -@@ -0,0 +1,243 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ 
cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32n_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32n_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..f89b076 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32n_tensor_op_f32_sm80.cu -@@ -0,0 +1,344 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 256x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm 
= cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 256x64x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, 
cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x256x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 256x128x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x128x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ 
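Each of the tests above only defines a Gemm specialization and hands it to the shared testbed. As a point of reference, the following is a minimal host-side sketch, assuming the standard CUTLASS 2.x device API; the pointer names, leading dimensions, and run_sm80_gemm itself are illustrative placeholders and not part of the test suite.

// Hypothetical helper: launch one of the SM80 f16 x f16 -> f32 (column-major D)
// configurations exercised above directly on user-provided device buffers.
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"

cutlass::Status run_sm80_gemm(int M, int N, int K,
                              cutlass::half_t const *ptr_A, int lda,
                              cutlass::half_t const *ptr_B, int ldb,
                              float const *ptr_C, int ldc,
                              float *ptr_D, int ldd,
                              float alpha, float beta) {
  using ElementOutput = float;
  using ElementAccumulator = float;

  // Same tile configuration as the 128x128x64_64x64x64 test above.
  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
      cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor,
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 64>,
      cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<
          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
          ElementAccumulator, ElementAccumulator>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;

  // Arguments: problem size, TensorRefs for A, B, C (source) and D (destination),
  // followed by the linear-combination scalars alpha and beta.
  typename Gemm::Arguments args{{M, N, K},
                                {ptr_A, lda},
                                {ptr_B, ldb},
                                {ptr_C, ldc},
                                {ptr_D, ldd},
                                {alpha, beta}};

  Gemm gemm_op;
  cutlass::Status status = gemm_op.can_implement(args);
  if (status != cutlass::Status::kSuccess) {
    return status;  // alignment or problem-size constraints not satisfied
  }
  return gemm_op(args);  // launches the kernel on the default CUDA stream
}

can_implement() rejects argument sets whose alignment or problem-size constraints are not met, which is one reason the testbed sweeps several problem sizes for every tile shape instantiated here.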
-+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 256x64x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x256x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x128x32_32x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x64x32_64x32x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x64x32_32x32x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ 
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32n_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32n_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..6a32a33 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32n_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,156 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ 
cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..d21ee87 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sm75.cu -@@ -0,0 +1,243 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ 
ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..5a3d6d0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,344 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = 
cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x64x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, 
cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ 
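The SM80 row-major-output tests above likewise delegate all data handling to TestAllGemm in testbed.h. The sketch below approximates that host-side plumbing under stated assumptions: run_once, problem_size, seed, and the fill bounds are illustrative choices rather than the testbed's actual values, and the host-reference comparison the real testbed performs is only indicated in a comment.

// Rough sketch of the host-side plumbing behind a single test run: allocate
// host tensors, fill them with a bounded random distribution, run the device
// kernel, and copy the result back.
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"

template <typename Gemm>
bool run_once(cutlass::gemm::GemmCoord problem_size, float alpha, float beta) {
  // A is M x K, B is K x N, C and D are M x N; layouts come from the Gemm type.
  cutlass::HostTensor<typename Gemm::ElementA, typename Gemm::LayoutA> tensor_A(problem_size.mk());
  cutlass::HostTensor<typename Gemm::ElementB, typename Gemm::LayoutB> tensor_B(problem_size.kn());
  cutlass::HostTensor<typename Gemm::ElementC, typename Gemm::LayoutC> tensor_C(problem_size.mn());
  cutlass::HostTensor<typename Gemm::ElementC, typename Gemm::LayoutC> tensor_D(problem_size.mn());

  uint64_t seed = 2023;  // illustrative seed and bounds
  cutlass::reference::host::TensorFillRandomUniform(tensor_A.host_view(), seed, 4, -4, 0);
  cutlass::reference::host::TensorFillRandomUniform(tensor_B.host_view(), seed + 1, 4, -4, 0);
  cutlass::reference::host::TensorFillRandomUniform(tensor_C.host_view(), seed + 2, 4, -4, 0);

  tensor_A.sync_device();
  tensor_B.sync_device();
  tensor_C.sync_device();
  tensor_D.sync_device();

  Gemm gemm_op;
  cutlass::Status status = gemm_op({problem_size,
                                    tensor_A.device_ref(),
                                    tensor_B.device_ref(),
                                    tensor_C.device_ref(),
                                    tensor_D.device_ref(),
                                    {alpha, beta}});
  if (status != cutlass::Status::kSuccess) {
    return false;
  }

  tensor_D.sync_host();
  // The shared testbed then compares tensor_D against a host reference GEMM
  // (cutlass/util/reference/host/gemm.h) before reporting pass or fail.
  return true;
}

// Example invocation for one of the tile shapes above:
//   bool ok = run_once<Gemm>({128, 128, 64}, 1.0f, 0.0f);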
-+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x64x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x256x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, 
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sparse_sm80.cu -new file mode 100644 -index 0000000..3bafd4d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sparse_sm80.cu -@@ -0,0 +1,199 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_sparse.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x64x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, 
cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, 
ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_volta_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_volta_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..b2a2397 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_volta_tensor_op_f32_sm70.cu -@@ -0,0 +1,243 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_volta_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_volta_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_volta_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_volta_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ 
cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_volta_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_volta_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if (CUTLASS_ENABLE_TENSOR_CORE_MMA) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..4572356 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,344 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / 
cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 64x128x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_bf16_f32_sm80.cu 
b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_bf16_f32_sm80.cu -new file mode 100644 -index 0000000..c83f7a7 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_bf16_f32_sm80.cu -@@ -0,0 +1,93 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface using BF16. 
-+*/ -+ -+#include <iostream> -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/arch/mma.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 4, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastBF16 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..bd71b19 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,88 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include <iostream> -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f32n_f32n_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sparse_sm80.cu -new file mode 100644 -index 0000000..895f175 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sparse_sm80.cu -@@ -0,0 +1,429 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3.
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_sparse.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 
128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 64x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ 
cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 128x128x64_64x64x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 256x64x64_64x64x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 128x64x64_64x32x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ 
cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 64x64x64_32x32x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu -new file mode 100644 -index 0000000..c37d48c ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu -@@ -0,0 +1,429 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_sparse.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ 
cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 64x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ 
ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 128x128x64_64x64x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 256x64x64_64x64x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 128x64x64_64x32x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 64x64x64_32x32x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu -new file mode 100644 -index 0000000..43bb129 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu -@@ -0,0 +1,428 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_sparse.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ 
cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 64x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ 
ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x128x64_64x64x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x64x64_64x64x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x64x64_64x32x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 64x64x64_32x32x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} 
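Each SparseGemm instantiation in the tests above fixes three nested GemmShape parameters: a threadblock tile, a warp tile, and the 16x8x16 instruction tile, with the warp tile dividing the threadblock tile and the instruction tile dividing the warp tile. A minimal compile-time sketch of that nesting constraint (TileNestingCheck is an illustrative helper, not a CUTLASS type, and is not part of the patched test file):

// Illustrative sketch only -- not part of the patched test file.
#include "cutlass/gemm/gemm.h"

template <typename ThreadblockShape, typename WarpShape, typename InstructionShape>
struct TileNestingCheck {
  // The warp tile must evenly divide the threadblock tile in M, N and K.
  static_assert(ThreadblockShape::kM % WarpShape::kM == 0 &&
                ThreadblockShape::kN % WarpShape::kN == 0 &&
                ThreadblockShape::kK % WarpShape::kK == 0,
                "warp tile must divide threadblock tile");
  // The instruction tile must evenly divide the warp tile in M, N and K.
  static_assert(WarpShape::kM % InstructionShape::kM == 0 &&
                WarpShape::kN % InstructionShape::kN == 0 &&
                WarpShape::kK % InstructionShape::kK == 0,
                "instruction tile must divide warp tile");
  static bool const value = true;
};

// The 256x128x32_64x64x32 configuration exercised above satisfies the nesting.
static_assert(TileNestingCheck<cutlass::gemm::GemmShape<256, 128, 32>,
                               cutlass::gemm::GemmShape<64, 64, 32>,
                               cutlass::gemm::GemmShape<16, 8, 16>>::value,
              "tile nesting holds");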
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f32t_f32t_f32t_tensor_op_f32_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32t_f32t_f32t_tensor_op_f32_sparse_sm80.cu -new file mode 100644 -index 0000000..4500835 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32t_f32t_f32t_tensor_op_f32_sparse_sm80.cu -@@ -0,0 +1,429 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_sparse.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ 
cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 64x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ 
ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 128x128x64_64x64x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 256x64x64_64x64x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 128x64x64_64x32x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 64x64x64_32x32x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ 
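The dense device::Gemm instantiations in the f64 test files that follow are launched the way CUTLASS's basic examples launch them; the TestAllGemm testbed performs equivalent setup over a sweep of problem sizes and alpha/beta values. A minimal host-side sketch, assuming A, B and C already reside in device memory, that the unspecified template parameters fall back to CUTLASS's default SM80 f64 tensor-op configuration, and using illustrative pointer and leading-dimension names:

// Illustrative sketch only -- not part of the patched test files.
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"

// Computes D = alpha * A * B + beta * C with column-major A and row-major B,
// mirroring the f64n_f64t layouts tested below.
cutlass::Status run_one_f64_gemm(int M, int N, int K,
                                 double alpha, double const *A, int lda,
                                 double const *B, int ldb,
                                 double beta, double *C, int ldc) {
  using Gemm = cutlass::gemm::device::Gemm<
      double, cutlass::layout::ColumnMajor,   // A
      double, cutlass::layout::RowMajor,      // B
      double, cutlass::layout::RowMajor,      // C / D
      double,                                 // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80>;

  Gemm gemm_op;

  // C serves both as the source of the beta term and as the destination D.
  typename Gemm::Arguments args({M, N, K},
                                {A, lda},
                                {B, ldb},
                                {C, ldc},
                                {C, ldc},
                                {alpha, beta});

  return gemm_op(args);   // can_implement + initialize + kernel launch
}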
-+#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..d733d8d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu -@@ -0,0 +1,258 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Gemm = 
cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64an_f64at_f64at_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using LayoutA = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutB = cutlass::layout::AffineRank2RowMajor; -+ using LayoutC = cutlass::layout::AffineRankN<2>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ LayoutA, -+ double, -+ LayoutB, -+ ElementOutput, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..62cb15d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu -@@ -0,0 +1,223 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface with Hopper FP64 -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 32x32x16_16x16x16_16x8x4) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 64x64x16_32x32x16_16x8x4) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ using Gemm = 
cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 128x64x16_64x32x16_16x8x4) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 64x128x16_32x64x16_16x8x4) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 128x128x16_32x64x16_16x8x4) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if 
defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..8961ab7 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm80.cu -@@ -0,0 +1,259 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Gemm = 
cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64at_f64an_f64at_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using LayoutA = cutlass::layout::AffineRank2RowMajor; -+ using LayoutB = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutC = cutlass::layout::AffineRankN<2>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ LayoutA, -+ double, -+ LayoutB, -+ ElementOutput, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..881d81c ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu -@@ -0,0 +1,223 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface with Hopper FP64 -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 32x32x16_16x16x16_16x8x4) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 64x64x16_32x32x16_16x8x4) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ using Gemm = 
cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 64x128x16_32x64x16_16x8x4) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 128x64x16_64x32x16_16x8x4) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 128x128x16_32x64x16_16x8x4) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if 
(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_grouped_scheduler_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_grouped_scheduler_sm80.cu -new file mode 100644 -index 0000000..4fed1dc ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_grouped_scheduler_sm80.cu -@@ -0,0 +1,222 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for grouped GEMM problem visitors -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_grouped.h" -+#include "cutlass/gemm/kernel/default_gemm_grouped.h" -+#include "cutlass/gemm/device/gemm_grouped.h" -+ -+#include "testbed_grouped_scheduler.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Run a series of tests on the testbed -+template -+void run_tests() { -+ for (int scale_factor : {8, 16, 32, 64}) { -+ for (int threadblock_count : {54, 108, 216, 324, 432}) { -+ for (int problems : {1, 27, 180, 300}) { -+ Testbed testbed; -+ testbed.run(problems, threadblock_count, scale_factor); -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGroupedScheduler_p128_t128, 64x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ static int const kNumPrefetch = 128; -+ static int const kThreadCount = 128; -+ static bool const kTranspose = false; -+ -+ using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kTranspose, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGroupedScheduler_p128_t128_transpose, 64x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ static int const kNumPrefetch = 128; -+ static int const kThreadCount = 128; -+ static bool const kTranspose = true; -+ -+ using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kTranspose, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGroupedScheduler_p256_t256, 64x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 256; -+ static bool const kTranspose = false; -+ -+ using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kTranspose, -+ // List of GroupScheduleModes to compare. List must contain at least two. 
-+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGroupedScheduler_p256_t128, 64x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 128; -+ static bool const kTranspose = false; -+ -+ using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kTranspose, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGroupedScheduler_p256_t256, 64x32x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 256; -+ static bool const kTranspose = false; -+ -+ using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kTranspose, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGroupedScheduler_p256_t256_transpose, 64x32x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 256; -+ static bool const kTranspose = true; -+ -+ using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kTranspose, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGroupedScheduler_p256_t256, 32x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 256; -+ static bool const kTranspose = false; -+ -+ using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kTranspose, -+ // List of GroupScheduleModes to compare. List must contain at least two. 
-+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGroupedScheduler_p256_t256_transpose, 32x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 256; -+ static bool const kTranspose = true; -+ -+ using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kTranspose, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_grouped_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_grouped_sm80.cu -new file mode 100644 -index 0000000..3fa3519 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_grouped_sm80.cu -@@ -0,0 +1,859 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_grouped.h" -+#include "cutlass/gemm/kernel/default_gemm_grouped.h" -+#include "cutlass/gemm/device/gemm_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_grouped.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Visitor class to abstract away the algorithm for iterating over tiles. -+// -+// This is the prototype. We will delete this when the efficient kernel is -+// available. -+struct GemmGroupedProblemVisitor { -+ -+ struct Params { -+ cutlass::gemm::GemmCoord const *problem_sizes; -+ int32_t problem_count; -+ int64_t const *tile_count; -+ }; -+ -+ struct SharedStorage { -+ // -+ // Nothing for now. As an optimization step, we could consider parallel -+ // argmin or prefix sums across the block. -+ // -+ }; -+ -+ // -+ // Data members -+ // -+ -+ SharedStorage &shared_storage; -+ Params const ¶ms; -+ cutlass::MatrixCoord threadblock_shape; -+ -+ int64_t tile_idx; -+ int64_t tile_count_sum; -+ int64_t problem_tile_start; -+ int32_t problem_idx; -+ -+ // -+ // Methods -+ // -+ CUTLASS_DEVICE -+ GemmGroupedProblemVisitor( -+ SharedStorage &shared_storage_, -+ Params const ¶ms_, -+ cutlass::MatrixCoord threadblock_shape_, -+ int32_t block_idx -+ ): -+ shared_storage(shared_storage_), -+ params(params_), -+ threadblock_shape(threadblock_shape_), -+ tile_idx(block_idx), -+ tile_count_sum(0), -+ problem_idx(0) -+ { -+ -+ cutlass::gemm::GemmCoord problem = params.problem_sizes[problem_idx]; -+ -+ cutlass::gemm::GemmCoord grid = grid_shape(problem); -+ -+ problem_tile_start = 0; -+ tile_count_sum = grid.m() * grid.n(); -+ } -+ -+ /// Get the grid shape -+ CUTLASS_HOST_DEVICE -+ static cutlass::gemm::GemmCoord grid_shape( -+ cutlass::gemm::GemmCoord const &problem, -+ cutlass::MatrixCoord const & block_shape) { -+ -+ return cutlass::gemm::GemmCoord( -+ ((problem.m() - 1 + block_shape.row()) / block_shape.row()), -+ ((problem.n() - 1 + block_shape.column()) / block_shape.column()), -+ 1); -+ } -+ -+ /// Get the grid shape -+ CUTLASS_DEVICE -+ cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const &problem) const { -+ return grid_shape(problem, threadblock_shape); -+ } -+ -+ /// Returns true if there is a tile to compute -+ CUTLASS_DEVICE -+ bool next_tile() { -+ -+ if (tile_idx < tile_count_sum) { -+ return true; -+ } -+ -+ do { -+ ++problem_idx; -+ -+ if (problem_idx >= params.problem_count) { -+ return false; -+ } -+ -+ cutlass::gemm::GemmCoord problem = params.problem_sizes[problem_idx]; -+ cutlass::gemm::GemmCoord grid = grid_shape(problem); -+ -+ int64_t tile_count = grid.m() * grid.n(); -+ -+ problem_tile_start = tile_count_sum; -+ tile_count_sum += tile_count; -+ -+ } while (tile_count_sum <= tile_idx); -+ -+ return true; -+ } -+ -+ /// Gets the global tile index -+ CUTLASS_HOST_DEVICE -+ int64_t tile_index() const { -+ return tile_idx; -+ } -+ -+ /// Gets the 
index of the problem -+ CUTLASS_HOST_DEVICE -+ int32_t problem_index() const { -+ return problem_idx; -+ } -+ -+ /// Returns the problem size for the current problem -+ CUTLASS_HOST_DEVICE -+ cutlass::gemm::GemmCoord problem_size() const { -+ return params.problem_sizes[problem_idx]; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int64_t threadblock_idx() const { -+ return tile_idx - problem_tile_start; -+ } -+ -+ CUTLASS_DEVICE -+ void advance(int32_t grid_size) { -+ tile_idx += grid_size; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void GroupedBatchedKernel(GemmGroupedProblemVisitor::Params params) { -+ -+ __shared__ GemmGroupedProblemVisitor::SharedStorage shared_storage; -+ -+ GemmGroupedProblemVisitor problem_visitor( -+ shared_storage, -+ params, -+ {ThreadblockShapeM, ThreadblockShapeN}, -+ blockIdx.x); -+ -+ while (problem_visitor.next_tile()) { -+ -+ cutlass::gemm::GemmCoord problem_size = problem_visitor.problem_size(); -+ int64_t threadblock_idx = problem_visitor.threadblock_idx(); -+ -+ cutlass::gemm::GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); -+ -+ int threadblock_tile_m_idx = int(threadblock_idx / grid_shape.n()); -+ int threadblock_tile_n_idx = int(threadblock_idx % grid_shape.n()); -+ -+ // -+ // Do the MMA -+ // -+ -+ if (threadIdx.x == 0) { -+ #if 0 -+ printf("Block %d - tile: %lld, problem %d, threadblock_idx: %lld, threadblock(m: %d, n: %d)\n", -+ blockIdx.x, -+ problem_visitor.tile_index(), -+ problem_visitor.problem_index(), -+ threadblock_idx, -+ threadblock_tile_m_idx, -+ threadblock_tile_n_idx); -+ #endif -+ } -+ -+ // Next tile -+ problem_visitor.advance(gridDim.x); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_scheduler, 64x64x32_32x32x32) { -+ -+ int32_t problem_count = 16; -+ -+ int const kThreadblockShapeM = 64; -+ int const kThreadblockShapeN = 64; -+ -+ std::vector problem_sizes(problem_count); -+ std::vector tile_counts(problem_count); -+ -+ // construct a few problems of random sizes -+ srand(1921); -+ for (int32_t i = 0; i < problem_count; ++i) { -+ problem_sizes.at(i) = cutlass::gemm::GemmCoord( -+ 8 * (rand() % 48) + 64, -+ 8 * (rand() % 48) + 64, -+ 8 * (rand() % 48) + 64); -+ } -+ -+ // compute prefix sum -+ int64_t tile_count = 0; -+ -+ for (int32_t i = 0; i < problem_count; ++i) { -+ -+ cutlass::gemm::GemmCoord grid_shape = GemmGroupedProblemVisitor::grid_shape( -+ problem_sizes.at(i), {kThreadblockShapeM, kThreadblockShapeN}); -+ -+ int32_t problem_tile_count = (grid_shape.m() * grid_shape.n()); -+ -+ int64_t tile_start = tile_count; -+ -+ tile_count += problem_tile_count; -+ tile_counts.at(i) = tile_count; -+ -+ if (false) { -+ std::cout << "Problem " << i << " size(" -+ << problem_sizes.at(i).m() << "-by-" << problem_sizes.at(i).n() -+ << ") - tiles: " << problem_tile_count << ", grid(" << grid_shape.m() << ", " << grid_shape.n() -+ << "), tiles[" << tile_start << ", " << tile_count << ")" << std::endl; -+ } -+ } -+ -+ // Copy to device memory -+ cutlass::DeviceAllocation problem_sizes_device(problem_count); -+ cutlass::DeviceAllocation tile_counts_device(problem_count); -+ -+ problem_sizes_device.copy_from_host(problem_sizes.data()); -+ tile_counts_device.copy_from_host(tile_counts.data()); -+ -+ GemmGroupedProblemVisitor::Params params; -+ params.problem_sizes = problem_sizes_device.get(); -+ params.problem_count = problem_count; -+ 
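// Illustrative note: tile_counts holds the inclusive prefix sum of per-problem tile counts.
// With the 64x64 threadblock used here, a 100x200 problem covers ceil(100/64) * ceil(200/64) = 2 * 4 = 8 tiles;
// if earlier problems account for 20 tiles, this problem owns global tile indices [20, 28).
// The device-side visitor above rebuilds the same running sum in next_tile() rather than reading this
// array, so params.tile_count is carried but not consulted by this prototype.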
params.tile_count = tile_counts_device.get(); -+ -+ // Launch the kernel -+ dim3 grid(108, 1, 1); -+ dim3 block(128, 1, 1); -+ -+ GroupedBatchedKernel<<< grid, block >>>(params); -+ -+ // wait -+ cudaDeviceSynchronize(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_f16n_f16t_f32n_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_f16n_f16t_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ ElementOutput, cutlass::layout::RowMajor, // row major -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_f16t_f16n_f32n_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / 
cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 4>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(27); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_f16t_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 4>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(27); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_f64t_f64t_f64n_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementInput = double; -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 1, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 4>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(27); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_f32t_f32t_f32n_simt_f32, 128x128x8_64x32x1) { -+ -+ using ElementInput = float; -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ 
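// For OpClassSimt kernels the instruction shape below is 1x1x1: each thread issues scalar FMAs
// rather than a tensor-core MMA, so only the threadblock and warp tile shapes are meaningful tuning knobs.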
cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 1, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(27); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_f32t_f32t_f32t_simt_f32, 128x128x8_64x32x1) { -+ -+ using ElementInput = float; -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 1, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(27); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_f32t_f32t_f32n_simt_f32, 128x64x8_64x32x1) { -+ -+ using ElementInput = float; -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 1, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(27); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_f32t_f32t_f32t_simt_f32, 128x64x8_64x32x1) { -+ -+ using ElementInput = float; -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ 
cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 1, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(27); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_cf32n_cf32n_cf32n_tensorop_f32, 64x64x16_32x32x16) { -+ -+ using ElementInput = cutlass::complex; -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ ElementInput, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementInput, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 1, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::arch::OpMultiplyAddComplex>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(27); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_cf32c_cf32t_cf32n_tensorop_f32, 64x64x16_32x32x16) { -+ -+ using ElementInput = cutlass::complex; -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ ElementInput, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 1, -+ ElementInput, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 1, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 1, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::arch::OpMultiplyAddComplex>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(27); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_cf32c_cf32t_cf32t_tensorop_f32, 64x64x16_32x32x16) { -+ -+ using ElementInput = cutlass::complex; -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ ElementInput, -+ 
cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 1, -+ ElementInput, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 1, -+ ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 1, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::arch::OpMultiplyAddComplex>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(27); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_cf32t_cf32h_cf32n_tensorop_f32, 64x64x16_16x16x16) { -+ -+ using ElementInput = cutlass::complex; -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 1, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 1, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::arch::OpMultiplyAddComplex>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(27); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu -new file mode 100644 -index 0000000..83e7cfa ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu -@@ -0,0 +1,353 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-level GEMM API for Planar Complex. -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+ -+#include "testbed_planar_complex.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_s884_tn_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_s884_tn : gemm_planar_complex_s884_tn_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmPlanarComplex_f16t_f16n_f32n_tensor_op_f32_884, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_s884_nt_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 
32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_s884_nt : gemm_planar_complex_s884_nt_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmPlanarComplex_f16n_f16t_f32n_tensor_op_f32_884, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_s884_nn_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_s884_nn : gemm_planar_complex_s884_nn_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmPlanarComplex_f16n_f16n_f32n_tensor_op_f32_884, 128x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_f16_s884_f16_nn_128x64_32x2_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ cutlass::half_t, -+ 8, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_f16_s884_f16_nn_128x64_32x2 : gemm_planar_complex_f16_s884_f16_nn_128x64_32x2_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmPlanarComplex_f16n_f16n_f16n_tensor_op_f32_884, 128x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_f16_s884_f16_nn_64x128_32x2_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, 
-+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ cutlass::half_t, -+ 8, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_f16_s884_f16_nn_64x128_32x2 : gemm_planar_complex_f16_s884_f16_nn_64x128_32x2_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmPlanarComplex_f16n_f16n_f16n_tensor_op_f32_884, 64x128x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_f16_s884_f16_tt_128x64_32x2_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ cutlass::half_t, -+ 8, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_f16_s884_f16_tt_128x64_32x2 : gemm_planar_complex_f16_s884_f16_tt_128x64_32x2_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmPlanarComplex_f16t_f16t_f16n_tensor_op_f32_884, 128x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_f16_s884_f16_tt_64x128_32x2_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ cutlass::half_t, -+ 8, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_f16_s884_f16_tt_64x128_32x2 : gemm_planar_complex_f16_s884_f16_tt_64x128_32x2_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmPlanarComplex_f16t_f16t_f16n_tensor_op_f32_884, 64x128x32_32x32x32) { -+ -+ using Gemm = 
cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm75.cu -new file mode 100644 -index 0000000..1f702ab ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm75.cu -@@ -0,0 +1,223 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-level GEMM API for Planar Complex. 
-+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h" -+#include "cutlass/gemm/device/gemm_universal_base.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+ -+#include "testbed_planar_complex.h" -+ -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_s1688_tn_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_s1688_tn : gemm_planar_complex_s1688_tn_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmPlanarComplex_f16t_f16n_f32n_tensor_op_f32_1688, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_s1688_hc_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_s1688_hc : gemm_planar_complex_s1688_hc_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmPlanarComplex_f16h_f16c_f32n_tensor_op_f32_1688, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_s1688_nt_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ 
cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_s1688_nt : gemm_planar_complex_s1688_nt_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmPlanarComplex_f16n_f16t_f32n_tensor_op_f32_1688, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_s1688_ch_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_s1688_ch : gemm_planar_complex_s1688_ch_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmPlanarComplex_f16c_f16h_f32n_tensor_op_f32_1688, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu -new file mode 100644 -index 0000000..beed868 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu -@@ -0,0 +1,393 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-level GEMM API for Planar Complex. -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+ -+#include "testbed_planar_complex.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_s16816_tn_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_s16816_tn : gemm_planar_complex_s16816_tn_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmPlanarComplex_f16t_f16n_f32n_tensor_op_f32_16816, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_f16_s16816_tn_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_f16_s16816_tn : gemm_planar_complex_f16_s16816_tn_base { -+ -+}; -+ 
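// For context: "planar complex" storage keeps the real and imaginary parts of each operand in two
// separate real-valued planes, and each complex product reduces to four real multiply-accumulates.
// The host-side sketch below illustrates only that decomposition; it is independent of the CUTLASS
// kernels under test, and the function name, row-major indexing, and float element type are
// assumptions made for the example.
void planar_complex_gemm_reference(
    int M, int N, int K,
    float const *A_real, float const *A_imag,   // M-by-K planes
    float const *B_real, float const *B_imag,   // K-by-N planes
    float *C_real, float *C_imag) {             // M-by-N planes
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc_r = 0.f;
      float acc_i = 0.f;
      for (int k = 0; k < K; ++k) {
        float ar = A_real[m * K + k], ai = A_imag[m * K + k];
        float br = B_real[k * N + n], bi = B_imag[k * N + n];
        // (ar + i*ai) * (br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br)
        acc_r += ar * br - ai * bi;
        acc_i += ar * bi + ai * br;
      }
      C_real[m * N + n] = acc_r;
      C_imag[m * N + n] = acc_i;
    }
  }
}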
-+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmPlanarComplex_f16t_f16n_f16n_tensor_op_f32_16816, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_s16816_hc_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_s16816_hc : gemm_planar_complex_s16816_hc_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmPlanarComplex_f16h_f16c_f32n_tensor_op_f32_16816, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_f16_s16816_hc_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_f16_s16816_hc : gemm_planar_complex_f16_s16816_hc_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmPlanarComplex_f16h_f16c_f16n_tensor_op_f32_16816, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_s16816_nt_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 
4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_s16816_nt : gemm_planar_complex_s16816_nt_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmPlanarComplex_f16n_f16t_f32n_tensor_op_f32_16816, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_f16_s16816_nt_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_f16_s16816_nt : gemm_planar_complex_f16_s16816_nt_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmPlanarComplex_f16n_f16t_f16n_tensor_op_f32_16816, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_s16816_ch_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_s16816_ch : gemm_planar_complex_s16816_ch_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmPlanarComplex_f16c_f16h_f32n_tensor_op_f32_16816, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_cf16_s16816_ch_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ 
float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_cf16_s16816_ch : gemm_planar_complex_cf16_s16816_ch_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmPlanarComplex_f16c_f16h_f16n_tensor_op_f32_16816, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4n_s4t_s4n_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4n_s4t_s4n_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..f8505ed ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4n_s4t_s4n_tensor_op_s32_sm75.cu -@@ -0,0 +1,199 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_interleaved.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 64x128x128_32x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ cutlass::int4b_t, -+ cutlass::layout::RowMajorInterleaved<64>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ test::gemm::device::InterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 128x128x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ cutlass::int4b_t, -+ cutlass::layout::RowMajorInterleaved<64>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ test::gemm::device::InterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 256x128x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ cutlass::int4b_t, -+ cutlass::layout::RowMajorInterleaved<64>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / 
cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ test::gemm::device::InterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 128x256x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ cutlass::int4b_t, -+ cutlass::layout::RowMajorInterleaved<64>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ test::gemm::device::InterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4n_s4t_s4n_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4n_s4t_s4n_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..0d95c50 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4n_s4t_s4n_tensor_op_s32_sm80.cu -@@ -0,0 +1,215 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "multistage_testbed_interleaved.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 64x128x128_32x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ cutlass::int4b_t, -+ cutlass::layout::RowMajorInterleaved<64>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 32, -+ 32, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 128x128x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ cutlass::int4b_t, -+ cutlass::layout::RowMajorInterleaved<64>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 32, -+ 32, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 256x128x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ cutlass::int4b_t, -+ cutlass::layout::RowMajorInterleaved<64>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ ElementAccumulator, -+ 
cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 32, -+ 32, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 128x256x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ cutlass::int4b_t, -+ cutlass::layout::RowMajorInterleaved<64>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 32, -+ 32, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..fbb576f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu -@@ -0,0 +1,249 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x256x128_64x64x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 256x128x128_64x64x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x128x128_64x64x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ 
cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x128x128_32x64x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x64x128_64x32x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x64x128_32x32x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..0c028e0 ---- /dev/null -+++ 
b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm80.cu -@@ -0,0 +1,360 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x256x256_64x64x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 256x128x256_64x64x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x128x256_64x64x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 256x64x256_64x64x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, 
ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x256x256_64x64x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x128x256_32x64x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 256>, -+ cutlass::gemm::GemmShape<32, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x64x256_64x32x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 256>, -+ cutlass::gemm::GemmShape<64, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x64x256_32x32x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ 
cutlass::gemm::GemmShape<64, 64, 256>, -+ cutlass::gemm::GemmShape<32, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x256x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 256x128x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x128x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 256x64x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ 
cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x256x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x128x128_32x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x64x128_64x32x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x64x128_32x32x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif //#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32n_wmma_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32n_wmma_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..23dc8eb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32n_wmma_tensor_op_s32_sm75.cu -@@ -0,0 +1,248 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x8x32, DataType/Instruction = s4 * s4 + s32 => s32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 128x256x128_64x64x128_8x8x32) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 256x128x128_64x64x128_8x8x32) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 128x128x128_64x64x128_8x8x32) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 64x128x128_32x64x128_8x8x32) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 128x64x128_64x32x128_8x8x32) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 64x64x128_32x32x128_8x8x32) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif //CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..4016558 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm75.cu -@@ -0,0 +1,249 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x256x128_64x64x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x128x128_64x64x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ 
cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x128x128_64x64x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x128x128_32x64x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x64x128_64x32x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x64x128_32x32x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ 
cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..d962249 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu -@@ -0,0 +1,363 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file
-+    \brief Tests for device-wide GEMM interface
-+*/
-+
-+#include <iostream>
-+
-+#include "../../common/cutlass_unit_test.h"
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm.h"
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/tensor_view_io.h"
-+
-+#include "testbed.h"
-+
-+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED)
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x256x256_64x64x256, {
-+  using ElementOutput = int32_t;
-+  using ElementAccumulator = int32_t;
-+  using ElementCompute = int32_t;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+      cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t,
-+      cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
-+      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+      cutlass::gemm::GemmShape<128, 256, 256>,
-+      cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>,
-+      cutlass::epilogue::thread::LinearCombinationClamp<
-+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+          ElementAccumulator, ElementCompute>,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+} )
-+
-+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x128x256_64x64x256, {
-+  using ElementOutput = int32_t;
-+  using ElementAccumulator = int32_t;
-+  using ElementCompute = int32_t;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+      cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t,
-+      cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
-+      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+      cutlass::gemm::GemmShape<256, 128, 256>,
-+      cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>,
-+      cutlass::epilogue::thread::LinearCombinationClamp<
-+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+          ElementAccumulator, ElementCompute>,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+} )
-+
-+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x128x256_64x64x256, {
-+  using ElementOutput = int32_t;
-+  using ElementAccumulator = int32_t;
-+  using ElementCompute = int32_t;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+      cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t,
-+      cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
-+      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+      cutlass::gemm::GemmShape<128, 128, 256>,
-+      cutlass::gemm::GemmShape<64, 64, 256>,
-+      cutlass::gemm::GemmShape<16, 8, 64>,
-+      cutlass::epilogue::thread::LinearCombinationClamp<
-+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+          ElementAccumulator, ElementCompute>,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+} )
-+
-+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x64x256_64x64x256, {
-+  using ElementOutput = int32_t;
-+  using ElementAccumulator = int32_t;
-+  using ElementCompute = int32_t;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+      cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t,
-+
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x256x256_64x64x256, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x128x256_32x64x256, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 256>, -+ cutlass::gemm::GemmShape<32, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x64x256_64x32x256, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 256>, -+ cutlass::gemm::GemmShape<64, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x64x256_32x32x256, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, 
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 256>, -+ cutlass::gemm::GemmShape<32, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x256x128_64x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x128x128_64x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x128x128_64x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x64x128_64x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ 
cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x256x128_64x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x256x128_32x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x64x128_64x32x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x64x128_32x32x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ 
cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu -new file mode 100644 -index 0000000..f9903f3 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu -@@ -0,0 +1,267 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file
-+    \brief Tests for device-wide GEMM interface
-+*/
-+
-+#include <iostream>
-+
-+#include "../../common/cutlass_unit_test.h"
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm_sparse.h"
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/tensor_view_io.h"
-+
-+#include "testbed_sparse.h"
-+
-+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED)
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x256x256_64x64x256) {
-+  using ElementOutput = int32_t;
-+  using ElementAccumulator = int32_t;
-+  using ElementCompute = int32_t;
-+
-+  using Gemm = cutlass::gemm::device::SparseGemm<
-+      cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t,
-+      cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
-+      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+      cutlass::gemm::GemmShape<128, 256, 256>,
-+      cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 128>,
-+      cutlass::epilogue::thread::LinearCombinationClamp<
-+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+          ElementAccumulator, ElementCompute>,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
-+}
-+
-+TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x128x256_64x64x256) {
-+  using ElementOutput = int32_t;
-+  using ElementAccumulator = int32_t;
-+  using ElementCompute = int32_t;
-+
-+  using Gemm = cutlass::gemm::device::SparseGemm<
-+      cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t,
-+      cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
-+      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+      cutlass::gemm::GemmShape<256, 128, 256>,
-+      cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 128>,
-+      cutlass::epilogue::thread::LinearCombinationClamp<
-+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+          ElementAccumulator, ElementCompute>,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
-+}
-+
-+TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x128x256_64x64x256) {
-+  using ElementOutput = int32_t;
-+  using ElementAccumulator = int32_t;
-+  using ElementCompute = int32_t;
-+
-+  using Gemm = cutlass::gemm::device::SparseGemm<
-+      cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t,
-+      cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
-+      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+      cutlass::gemm::GemmShape<128, 128, 256>,
-+      cutlass::gemm::GemmShape<64, 64, 256>,
-+      cutlass::gemm::GemmShape<16, 8, 128>,
-+      cutlass::epilogue::thread::LinearCombinationClamp<
-+          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+          ElementAccumulator, ElementCompute>,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
-+}
-+
-+TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x64x256_64x64x256) {
-+  using ElementOutput = int32_t;
-+  using ElementAccumulator = int32_t;
-+  using ElementCompute = int32_t;
-+
-+  using Gemm = cutlass::gemm::device::SparseGemm<
-+
cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 128>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x256x256_64x64x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 128>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x128x256_32x64x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 256>, -+ cutlass::gemm::GemmShape<32, 64, 256>, cutlass::gemm::GemmShape<16, 8, 128>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x64x256_64x32x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 256>, -+ cutlass::gemm::GemmShape<64, 32, 256>, cutlass::gemm::GemmShape<16, 8, 128>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x64x256_32x32x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ 
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 256>, -+ cutlass::gemm::GemmShape<32, 32, 256>, cutlass::gemm::GemmShape<16, 8, 128>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x128x512_64x64x512) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<16, 8, 128>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x64x512_64x32x512) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 512>, -+ cutlass::gemm::GemmShape<64, 32, 512>, cutlass::gemm::GemmShape<16, 8, 128>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x64x512_32x32x512) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<32, 32, 512>, cutlass::gemm::GemmShape<16, 8, 128>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_wmma_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_wmma_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..3df1180 ---- /dev/null 
-+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_wmma_tensor_op_s32_sm75.cu -@@ -0,0 +1,247 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file
-+    \brief Tests for device-wide GEMM interface
-+*/
-+#include "cutlass/arch/wmma.h"
-+
-+#ifdef CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED
-+#include <iostream>
-+
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm.h"
-+
-+#include "../../common/cutlass_unit_test.h"
-+
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/tensor_view_io.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+
-+#include "testbed.h"
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+///////// WMMA Instruction Shape = 8x8x32, DataType/Instruction = s4 * s4 + s32 => s32 //////////
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 128x256x128_64x64x128_8x8x32) {
-+
-+  using ElementOutput = int32_t;
-+  using ElementAccumulator = int32_t;
-+  using ElementCompute = int32_t;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+    cutlass::int4b_t,
-+    cutlass::layout::RowMajor,
-+    cutlass::int4b_t,
-+    cutlass::layout::ColumnMajor,
-+    ElementOutput,
-+    cutlass::layout::RowMajor,
-+    ElementAccumulator,
-+    cutlass::arch::OpClassWmmaTensorOp,
-+    cutlass::arch::Sm75,
-+    cutlass::gemm::GemmShape<128, 256, 128>,
-+    cutlass::gemm::GemmShape<64, 64, 128>,
-+    cutlass::gemm::GemmShape<8, 8, 32>,
-+    cutlass::epilogue::thread::LinearCombinationClamp<
-+      ElementOutput,
-+      128 / cutlass::sizeof_bits<ElementOutput>::value,
-+      ElementAccumulator,
-+      ElementCompute
-+    >,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+    2
-+  >;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+}
-+
-+TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 256x128x128_64x64x128_8x8x32) {
-+
-+  using ElementOutput = int32_t;
-+  using ElementAccumulator = int32_t;
-+  using ElementCompute = int32_t;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+    cutlass::int4b_t,
-+    cutlass::layout::RowMajor,
-+    cutlass::int4b_t,
-+    cutlass::layout::ColumnMajor,
-+    ElementOutput,
-+    cutlass::layout::RowMajor,
-+    ElementAccumulator,
-+    cutlass::arch::OpClassWmmaTensorOp,
-+    cutlass::arch::Sm75,
-+    cutlass::gemm::GemmShape<256, 128, 128>,
-+    cutlass::gemm::GemmShape<64, 64, 128>,
-+    cutlass::gemm::GemmShape<8, 8, 32>,
-+    cutlass::epilogue::thread::LinearCombinationClamp<
-+      ElementOutput,
-+      128 / cutlass::sizeof_bits<ElementOutput>::value,
-+      ElementAccumulator,
-+      ElementCompute
-+    >,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+    2
-+  >;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+}
-+
-+TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 128x128x128_64x64x128_8x8x32) {
-+
-+  using ElementOutput = int32_t;
-+  using ElementAccumulator = int32_t;
-+  using ElementCompute = int32_t;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+    cutlass::int4b_t,
-+    cutlass::layout::RowMajor,
-+    cutlass::int4b_t,
-+    cutlass::layout::ColumnMajor,
-+    ElementOutput,
-+    cutlass::layout::RowMajor,
-+    ElementAccumulator,
-+    cutlass::arch::OpClassWmmaTensorOp,
-+    cutlass::arch::Sm75,
-+    cutlass::gemm::GemmShape<128, 128, 128>,
-+    cutlass::gemm::GemmShape<64, 64, 128>,
-+    cutlass::gemm::GemmShape<8, 8, 32>,
-+    cutlass::epilogue::thread::LinearCombinationClamp<
-+      ElementOutput,
-+      128 / cutlass::sizeof_bits<ElementOutput>::value,
-+      ElementAccumulator,
-+      ElementCompute
-+    >,
-+
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 64x128x128_32x64x128_8x8x32) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 128x64x128_64x32x128_8x8x32) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 64x64x128_32x32x128_8x8x32) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..09a502b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu -@@ -0,0 +1,312 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x256x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x128x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ 
cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x128x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x256x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x64x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 32 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x128x128_32x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = 
cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x64x128_64x32x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 32 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x64x128_32x32x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 32 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..7b002d5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm80.cu -@@ -0,0 +1,374 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "multistage_testbed.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x256x256_64x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x128x256_64x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ 
cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x128x256_64x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x64x256_64x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x256x256_64x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x128x256_32x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 256>, -+ 
cutlass::gemm::GemmShape<32, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x64x256_64x32x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 256>, -+ cutlass::gemm::GemmShape<64, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x64x256_32x32x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 256>, -+ cutlass::gemm::GemmShape<32, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x256x128_64x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x128x128_64x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, 
cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x128x128_64x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x64x128_64x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x256x128_64x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x128x128_32x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, 
cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x64x128_64x32x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x64x128_32x32x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // #if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..525677a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu -@@ -0,0 +1,312 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x256x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x128x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ 
cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x128x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x256x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x64x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 32 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x128x128_32x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ 
cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x64x128_64x32x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 32 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x64x128_32x32x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 32 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..ccaaabf ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu -@@ -0,0 +1,374 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "multistage_testbed.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x256x256_64x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x128x256_64x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x128x256_64x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = 
int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x64x256_64x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x256x256_64x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x128x256_32x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 256>, -+ cutlass::gemm::GemmShape<32, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x64x256_64x32x256, { -+ using 
ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 256>, -+ cutlass::gemm::GemmShape<64, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x64x256_32x32x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 256>, -+ cutlass::gemm::GemmShape<32, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x256x128_64x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x128x128_64x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ 
-+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x128x128_64x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x64x128_64x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x256x128_64x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x128x128_32x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ 
test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x64x128_64x32x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x64x128_32x32x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // #if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..98b74e1 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm75.cu -@@ -0,0 +1,293 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_interleaved.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 32x64x64_16x32x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<16, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ test::gemm::device::InterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 64x64x64_32x32x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ test::gemm::device::InterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 128x64x64_64x32x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ test::gemm::device::InterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 64x128x64_32x64x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ test::gemm::device::InterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 128x128x64_64x64x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ test::gemm::device::InterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 256x128x64_64x64x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ 
cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ test::gemm::device::InterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 128x256x64_64x64x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ test::gemm::device::InterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..5dbb51e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm80.cu -@@ -0,0 +1,358 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "multistage_testbed_interleaved.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 64x64x64_32x32x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 128x64x64_64x32x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 64x128x64_32x64x64) { -+ -+ using 
ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 128x128x64_64x64x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 256x128x64_64x64x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 64x256x64_64x64x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ 
ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 256x64x64_64x64x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 128x256x64_64x64x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..7a7b9df ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm75.cu -@@ -0,0 +1,249 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x256x64_64x64x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 256x128x64_64x64x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ 
cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x128x64_64x64x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x128x64_32x64x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x64x64_64x32x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x64x64_32x32x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ 
ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..4ce7d70 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm80.cu -@@ -0,0 +1,361 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x256x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 256x128x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x128x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 256x64x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, 
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x256x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x128x128_32x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x64x128_64x32x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x64x128_32x32x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ 
ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x256x64_64x64x64) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 256x128x64_64x64x64) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x128x64_64x64x64) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 256x64x64_64x64x64) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ 
-+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x256x64_64x64x64) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x128x64_32x64x64) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x64x64_64x32x64) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x64x64_32x32x64) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32n_wmma_tensor_op_s32_sm72.cu 
b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32n_wmma_tensor_op_s32_sm72.cu -new file mode 100644 -index 0000000..251e138 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32n_wmma_tensor_op_s32_sm72.cu -@@ -0,0 +1,151 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM72_ENABLED -+#include <iostream> -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 16x16x16, DataType/Instruction = s8*s8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s8t_s8n_s32n_wmma_tensor_op_s32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32n_wmma_tensor_op_s32, 64x128x64_32x32x64_16x16x16) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 8x32x16, DataType/Instruction = s8*s8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s8t_s8n_s32n_wmma_tensor_op_s32, 64x128x64_32x64x64_8x32x16) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / 
cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM72_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..2c4ca98 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm75.cu -@@ -0,0 +1,249 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*!
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x256x64_64x64x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x128x64_64x64x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x128x64_64x64x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x128x64_32x64x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ 
int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x64x64_64x32x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x64x64_32x32x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..9ff2bcc ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm80.cu -@@ -0,0 +1,361 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x256x128_64x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x128x128_64x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x128x128_64x64x128, { 
-+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x64x128_64x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x256x128_64x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x128x128_32x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x64x128_64x32x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, 
cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x64x128_32x32x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x256x64_64x64x64, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x128x64_64x64x64, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x128x64_64x64x64, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ 
cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x64x64_64x64x64, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x256x64_64x64x64, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x128x64_32x64x64, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x64x64_64x32x64, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / 
cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x64x64_32x32x64, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu -new file mode 100644 -index 0000000..c2f46a2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu -@@ -0,0 +1,269 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*!
\file
-+ \brief Tests for device-wide GEMM interface
-+*/
-+
-+#include <iostream>
-+
-+#include "../../common/cutlass_unit_test.h"
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm_sparse.h"
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/tensor_view_io.h"
-+
-+#include "testbed_sparse.h"
-+
-+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED)
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x256x128_64x64x128) {
-+ using ElementOutput = int32_t;
-+ using ElementAccumulator = int32_t;
-+ using ElementCompute = int32_t;
-+
-+ using Gemm = cutlass::gemm::device::SparseGemm<
-+ int8_t, cutlass::layout::RowMajor, int8_t,
-+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
-+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+ cutlass::gemm::GemmShape<128, 256, 128>,
-+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>,
-+ cutlass::epilogue::thread::LinearCombinationClamp<
-+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+ ElementAccumulator, ElementCompute>,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
-+
-+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
-+}
-+
-+TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x128x128_64x64x128) {
-+ using ElementOutput = int32_t;
-+ using ElementAccumulator = int32_t;
-+ using ElementCompute = int32_t;
-+
-+ using Gemm = cutlass::gemm::device::SparseGemm<
-+ int8_t, cutlass::layout::RowMajor, int8_t,
-+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
-+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+ cutlass::gemm::GemmShape<256, 128, 128>,
-+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>,
-+ cutlass::epilogue::thread::LinearCombinationClamp<
-+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+ ElementAccumulator, ElementCompute>,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
-+
-+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
-+}
-+
-+TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x128x128_64x64x128) {
-+ using ElementOutput = int32_t;
-+ using ElementAccumulator = int32_t;
-+ using ElementCompute = int32_t;
-+
-+ using Gemm = cutlass::gemm::device::SparseGemm<
-+ int8_t, cutlass::layout::RowMajor, int8_t,
-+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
-+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+ cutlass::gemm::GemmShape<128, 128, 128>,
-+ cutlass::gemm::GemmShape<64, 64, 128>,
-+ cutlass::gemm::GemmShape<16, 8, 64>,
-+ cutlass::epilogue::thread::LinearCombinationClamp<
-+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-+ ElementAccumulator, ElementCompute>,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
-+
-+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
-+}
-+
-+TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x64x128_64x64x128) {
-+ using ElementOutput = int32_t;
-+ using ElementAccumulator = int32_t;
-+ using ElementCompute = int32_t;
-+
-+ using Gemm = cutlass::gemm::device::SparseGemm<
-+ int8_t, cutlass::layout::RowMajor, int8_t,
-+ cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x256x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x128x128_32x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x64x128_64x32x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x64x128_32x32x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ 
cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x128x256_64x64x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x64x256_64x32x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 256>, -+ cutlass::gemm::GemmShape<64, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x64x256_32x32x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 256>, -+ cutlass::gemm::GemmShape<32, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_wmma_tensor_op_s32_sm72.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_wmma_tensor_op_s32_sm72.cu -new file mode 100644 -index 0000000..08ae460 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_wmma_tensor_op_s32_sm72.cu -@@ -0,0 +1,186 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM72_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 16x16x16, DataType/Instruction = s8*s8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s8t_s8n_s32t_wmma_tensor_op_s32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ 
-+TEST(SM75_Device_Gemm_s8t_s8n_s32t_wmma_tensor_op_s32, 64x128x64_32x32x64_16x16x16) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 32x8x16, DataType/Instruction = s8*s8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s8t_s8n_s32t_wmma_tensor_op_s32, 64x128x64_32x64x64_32x8x16) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 8x32x16, DataType/Instruction = s8*s8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s8t_s8n_s32t_wmma_tensor_op_s32, 64x128x64_32x64x64_8x32x16) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM72_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..cb7401f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm75.cu -@@ -0,0 +1,198 @@ 
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file
-+ \brief Tests for device-wide GEMM interface
-+*/
-+
-+#include <iostream>
-+
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm.h"
-+
-+#include "../../common/cutlass_unit_test.h"
-+
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/tensor_view_io.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+
-+#include "testbed.h"
-+
-+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x256x64_64x64x64, {
-+ using ElementOutput = int8_t;
-+ using ElementAccumulator = int32_t;
-+ using ElementCompute = float;
-+
-+ using Gemm = cutlass::gemm::device::Gemm<
-+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
-+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator,
-+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
-+ cutlass::gemm::GemmShape<128, 256, 64>,
-+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>,
-+ cutlass::epilogue::thread::FastLinearCombinationClamp<
-+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value>,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>;
-+
-+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+
-+} )
-+
-+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 256x128x64_64x64x64, {
-+ using ElementOutput = int8_t;
-+ using ElementAccumulator = int32_t;
-+ using ElementCompute = float;
-+
-+ using Gemm = cutlass::gemm::device::Gemm<
-+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
-+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator,
-+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
-+ cutlass::gemm::GemmShape<256, 128, 64>,
-+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>,
-+ cutlass::epilogue::thread::FastLinearCombinationClamp<
-+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value>,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>;
-+
-+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+} )
-+
-+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x128x64_64x64x64, {
-+ using ElementOutput = int8_t;
-+ using ElementAccumulator = int32_t;
-+ using ElementCompute = float;
-+
-+ using Gemm = cutlass::gemm::device::Gemm<
-+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
-+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator,
-+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
-+ cutlass::gemm::GemmShape<128, 128, 64>,
-+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>,
-+ cutlass::epilogue::thread::FastLinearCombinationClamp<
-+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value>,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>;
-+
-+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+
-+} )
-+
-+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x128x64_32x64x64, {
-+ using ElementOutput = int8_t;
-+ using ElementAccumulator = int32_t;
-+ using ElementCompute = float;
-+
-+ using Gemm = cutlass::gemm::device::Gemm<
-+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
-+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator,
-+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
-+
cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+ -+} ) -+ -+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x64x64_64x32x64, { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+ -+} ) -+ -+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x64x64_32x32x64, { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+ -+} ) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..18daa2b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm80.cu -@@ -0,0 +1,374 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "multistage_testbed.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x256x128_64x64x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 256x128x128_64x64x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x128x128_64x64x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, 
int8_t, cutlass::layout::ColumnMajor,
-+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator,
-+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+ cutlass::gemm::GemmShape<128, 128, 128>,
-+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>,
-+ cutlass::epilogue::thread::FastLinearCombinationClamp<
-+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value>,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
-+
-+ test::gemm::device::MultistageTestbed<Gemm> testbed;
-+
-+ EXPECT_TRUE(testbed.run_all());
-+} )
-+
-+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 256x64x128_64x64x128, {
-+ using ElementOutput = int8_t;
-+ using ElementAccumulator = int32_t;
-+ using ElementCompute = float;
-+
-+ using Gemm = cutlass::gemm::device::Gemm<
-+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
-+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator,
-+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+ cutlass::gemm::GemmShape<256, 64, 128>,
-+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>,
-+ cutlass::epilogue::thread::FastLinearCombinationClamp<
-+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value>,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
-+
-+ test::gemm::device::MultistageTestbed<Gemm> testbed;
-+
-+ EXPECT_TRUE(testbed.run_all());
-+} )
-+
-+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x256x128_64x64x128, {
-+ using ElementOutput = int8_t;
-+ using ElementAccumulator = int32_t;
-+ using ElementCompute = float;
-+
-+ using Gemm = cutlass::gemm::device::Gemm<
-+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
-+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator,
-+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+ cutlass::gemm::GemmShape<64, 256, 128>,
-+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>,
-+ cutlass::epilogue::thread::FastLinearCombinationClamp<
-+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value>,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
-+
-+ test::gemm::device::MultistageTestbed<Gemm> testbed;
-+
-+ EXPECT_TRUE(testbed.run_all());
-+} )
-+
-+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x128x128_32x64x128, {
-+ using ElementOutput = int8_t;
-+ using ElementAccumulator = int32_t;
-+ using ElementCompute = float;
-+
-+ using Gemm = cutlass::gemm::device::Gemm<
-+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
-+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator,
-+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+ cutlass::gemm::GemmShape<64, 128, 128>,
-+ cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>,
-+ cutlass::epilogue::thread::FastLinearCombinationClamp<
-+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value>,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
-+
-+ test::gemm::device::MultistageTestbed<Gemm> testbed;
-+
-+ EXPECT_TRUE(testbed.run_all());
-+} )
-+
-+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x64x128_64x32x128, {
-+ using ElementOutput = int8_t;
-+ using ElementAccumulator = int32_t;
-+ using ElementCompute = float;
-+
-+ using Gemm = cutlass::gemm::device::Gemm<
-+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
-+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator,
-+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
-+
cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x64x128_32x32x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x256x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 256x128x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x128x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 
128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 256x64x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x256x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x128x64_32x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x64x64_64x32x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) 
-+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x64x64_32x32x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8n_wmma_tensor_op_s32_sm72.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8n_wmma_tensor_op_s32_sm72.cu -new file mode 100644 -index 0000000..e0ffc83 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8n_wmma_tensor_op_s32_sm72.cu -@@ -0,0 +1,177 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM72_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 16x16x16, DataType/Instruction = s8*s8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s8t_s8n_s8n_wmma_tensor_op_s32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s8n_wmma_tensor_op_s32, 64x128x64_32x32x64_16x16x16) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 32x8x16, DataType/Instruction = s8*s8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s8t_s8n_s8n_wmma_tensor_op_s32, 64x128x64_32x64x64_32x8x16) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value -+ >, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 8x32x16, DataType/Instruction = s8*s8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s8t_s8n_s8n_wmma_tensor_op_s32, 64x128x64_32x64x64_8x32x16) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM72_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..a13a6eb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu -@@ -0,0 +1,172 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file
-+ \brief Tests for device-wide GEMM interface
-+*/
-+
-+#include <iostream>
-+
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm.h"
-+
-+#include "../../common/cutlass_unit_test.h"
-+
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/tensor_view_io.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+
-+#include "testbed.h"
-+
-+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x256x64_64x64x64, {
-+ using ElementOutput = int8_t;
-+ using ElementAccumulator = int32_t;
-+ using ElementCompute = float;
-+
-+ using Gemm = cutlass::gemm::device::Gemm<
-+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
-+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
-+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
-+ cutlass::gemm::GemmShape<128, 256, 64>,
-+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>,
-+ cutlass::epilogue::thread::FastLinearCombinationClamp<
-+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value>,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>;
-+
-+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+} )
-+
-+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 256x128x64_64x64x64, {
-+ using ElementOutput = int8_t;
-+ using ElementAccumulator = int32_t;
-+ using ElementCompute = float;
-+
-+ using Gemm = cutlass::gemm::device::Gemm<
-+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
-+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
-+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
-+ cutlass::gemm::GemmShape<256, 128, 64>,
-+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>,
-+ cutlass::epilogue::thread::FastLinearCombinationClamp<
-+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value>,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>;
-+
-+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+} )
-+
-+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x128x64_64x64x64, {
-+ using ElementOutput = int8_t;
-+ using ElementAccumulator = int32_t;
-+ using ElementCompute = float;
-+
-+ using Gemm = cutlass::gemm::device::Gemm<
-+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
-+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
-+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
-+ cutlass::gemm::GemmShape<128, 128, 64>,
-+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>,
-+ cutlass::epilogue::thread::FastLinearCombinationClamp<
-+ ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value>,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>;
-+
-+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+
-+} )
-+
-+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x128x64_32x64x64, {
-+ using ElementOutput = int8_t;
-+ using ElementAccumulator = int32_t;
-+ using ElementCompute = float;
-+
-+ using Gemm = cutlass::gemm::device::Gemm<
-+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
-+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
-+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
-+ cutlass::gemm::GemmShape<64, 128, 64>,
-+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x64x64_64x32x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; -+ -+ test::gemm::device::Testbed testbed; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x64x64_32x32x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; -+ -+ test::gemm::device::Testbed testbed; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..9af1d01 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu -@@ -0,0 +1,374 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "multistage_testbed.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x256x128_64x64x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 256x128x128_64x64x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x128x128_64x64x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, 
cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 256x64x128_64x64x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x256x128_64x64x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x128x128_32x64x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x64x128_64x32x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ 
cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x64x128_32x32x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x256x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 256x128x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x128x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 256x64x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x256x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x128x64_32x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x64x64_64x32x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ 
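// ---------------------------------------------------------------------------------------------
// Editor's sketch (not part of the patch above): the hunk only exercises these kernels through
// test::gemm::device::MultistageTestbed, so for orientation this shows roughly how one of the
// SM80 int8 kernel types defined here would be launched directly. The Arguments / initialize() /
// operator() sequence is the standard cutlass::gemm::device::Gemm interface; the helper name
// run_one_gemm, the problem size, the fill seeds, and the alpha/beta values are illustrative
// assumptions, not part of the test file.
// ---------------------------------------------------------------------------------------------

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"

namespace sketch {

using ElementOutput = int8_t;
using ElementAccumulator = int32_t;

// Mirrors the 128x128x128_64x64x128 configuration from the hunk above.
using Gemm = cutlass::gemm::device::Gemm<
    int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
    ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 128>,
    cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>,
    cutlass::epilogue::thread::FastLinearCombinationClamp<
        ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;

// Runs D = clamp(alpha * A @ B + beta * C) once on the default CUDA stream.
inline cutlass::Status run_one_gemm(int M, int N, int K, float alpha, float beta) {
  cutlass::HostTensor<int8_t, cutlass::layout::RowMajor>        A({M, K});
  cutlass::HostTensor<int8_t, cutlass::layout::ColumnMajor>     B({K, N});
  cutlass::HostTensor<ElementOutput, cutlass::layout::RowMajor> C({M, N});
  cutlass::HostTensor<ElementOutput, cutlass::layout::RowMajor> D({M, N});

  // Small random int8 operands and a zero-filled C (seeds are arbitrary).
  cutlass::reference::host::TensorFillRandomUniform(A.host_view(), 2023, 2, -2, 0);
  cutlass::reference::host::TensorFillRandomUniform(B.host_view(), 2024, 2, -2, 0);
  cutlass::reference::host::TensorFill(C.host_view(), ElementOutput(0));
  A.sync_device(); B.sync_device(); C.sync_device(); D.sync_device();

  typename Gemm::Arguments args(
      {M, N, K},
      A.device_ref(), B.device_ref(), C.device_ref(), D.device_ref(),
      {alpha, beta});

  Gemm gemm_op;
  cutlass::Status status = gemm_op.initialize(args);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }
  return gemm_op();  // launches the kernel
}

}  // namespace sketch
// ---------------------------------------------------------------------------------------------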
-+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x64x64_32x32x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // #if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8t_wmma_tensor_op_s32_sm72.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8t_wmma_tensor_op_s32_sm72.cu -new file mode 100644 -index 0000000..12ab891 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8t_wmma_tensor_op_s32_sm72.cu -@@ -0,0 +1,178 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM72_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 16x16x16, DataType/Instruction = s8*s8+s32=>s8 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s8t_s8n_s8t_wmma_tensor_op_s32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s8t_wmma_tensor_op_s32, 64x128x64_32x32x64_16x16x16) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 32x8x16, DataType/Instruction = s8*s8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s8t_s8n_s8t_wmma_tensor_op_s32, 64x128x64_32x64x64_32x8x16) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 
2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 8x32x16, DataType/Instruction = s8*s8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s8t_s8n_s8t_wmma_tensor_op_s32, 64x128x64_32x64x64_8x32x16) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM72_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_serial_tensor_op_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_serial_tensor_op_sm75.cu -new file mode 100644 -index 0000000..30f0bb7 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_serial_tensor_op_sm75.cu -@@ -0,0 +1,114 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmSplitKSerial_f16n_f16n_f16t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ static const int kStages = 2; -+ -+ static const int kAlignmentA = cutlass::gemm::device::DefaultGemmConfiguration< -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ElementA, -+ ElementB, -+ ElementOutput, -+ ElementAccumulator>::kAlignmentA; -+ -+ static const int kAlignmentB = cutlass::gemm::device::DefaultGemmConfiguration< -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ElementA, -+ ElementB, -+ ElementOutput, -+ ElementAccumulator>::kAlignmentB; -+ -+ static const bool kSplitKSerial = true; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages, -+ kAlignmentA, -+ kAlignmentB, -+ kSplitKSerial -+ >; -+ -+ bool result = test::gemm::device::TestAllGemm(); -+ EXPECT_TRUE(result); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_simt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_simt_sm50.cu -new file mode 100644 -index 0000000..1b4eba2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_simt_sm50.cu -@@ -0,0 +1,146 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_splitk_parallel.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_splitk.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_GemmSplitKParallel_f32n_f32t_f32t_simt_f32, 128x128x8) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM50_Device_GemmSplitKParallel_f32n_f32n_f32n_simt_f32, 128x128x8) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_GemmSplitKParallel_f64n_f64n_f64t_simt_f64, 64x128x8) { -+ -+ using Element = double; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<64, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ 
cutlass::gemm::GemmShape<1, 1, 1> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM50_Device_GemmSplitKParallel_f64t_f64t_f64n_simt_f64, 64x64x8) { -+ -+ using Element = double; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_tensor_op_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_tensor_op_sm70.cu -new file mode 100644 -index 0000000..5f15749 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_tensor_op_sm70.cu -@@ -0,0 +1,199 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_splitk_parallel.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_splitk.h" -+ -+// These operators are assert(0) unless extended PTX is used. 
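// ---------------------------------------------------------------------------------------------
// Editor's sketch (not part of the patch above): the split-K tests that follow only instantiate
// GemmSplitKParallel types and hand them to test::gemm::device::TestAllGemmSplitK. For
// orientation, this shows roughly how such a kernel is driven directly, following the usual
// CUTLASS split-K flow (partition K, accumulate partials in a workspace, then reduce). The
// helper name run_splitk, split_k_slices = 8, and the alpha/beta values are illustrative
// assumptions; the type parameters mirror the first SM70 test below.
// ---------------------------------------------------------------------------------------------

#include <cstdint>

#include "cutlass/cutlass.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/device/gemm_splitk_parallel.h"
#include "cutlass/util/device_memory.h"

namespace sketch_splitk {

using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;

using GemmSplitK = cutlass::gemm::device::GemmSplitKParallel<
    cutlass::half_t, LayoutA,
    cutlass::half_t, LayoutB,
    float, LayoutC, float,
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm70,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<8, 8, 4>>;

// Computes D = alpha * A @ B + beta * C with the K dimension split across partitions.
inline cutlass::Status run_splitk(
    cutlass::gemm::GemmCoord problem_size,
    cutlass::TensorRef<cutlass::half_t const, LayoutA> ref_A,
    cutlass::TensorRef<cutlass::half_t const, LayoutB> ref_B,
    cutlass::TensorRef<float const, LayoutC> ref_C,
    cutlass::TensorRef<float, LayoutC> ref_D,
    float alpha, float beta) {

  int const split_k_slices = 8;  // number of K partitions reduced by the second kernel

  typename GemmSplitK::Arguments args(
      problem_size, ref_A, ref_B, ref_C, ref_D, {alpha, beta}, split_k_slices);

  // Partial accumulators live in a device workspace between the GEMM and reduction kernels.
  size_t workspace_size = GemmSplitK::get_workspace_size(args);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  GemmSplitK gemm_op;
  cutlass::Status status = gemm_op.initialize(args, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    return status;
  }
  return gemm_op();  // runs the partitioned GEMM followed by the reduction
}

}  // namespace sketch_splitk
// ---------------------------------------------------------------------------------------------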
-+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmSplitK_f16n_f16t_f32t_tensor_op_f32, 64x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM70_Device_GemmSplitK_f16n_f16t_f16t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM70_Device_GemmSplitK_f16n_f16t_f16t_tensor_op_f16, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmSplitK_f16t_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM70_Device_GemmSplitK_f16t_f16n_f16t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+ -+TEST(SM70_Device_GemmSplitK_f16t_f16n_f16t_tensor_op_f16, 64x128x32_32x64x32) { -+ -+ using 
ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_tensor_op_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_tensor_op_sm75.cu -new file mode 100644 -index 0000000..3d8e8db ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_tensor_op_sm75.cu -@@ -0,0 +1,336 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_splitk_parallel.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_splitk.h" -+ -+// These operators are assert(0) unless extended PTX is used. 
-+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmSplitKParallel_f16n_f16t_f32t_tensor_op_f32, 64x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM75_Device_GemmSplitKParallel_f16n_f16t_f32n_tensor_op_f32, 64x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM75_Device_GemmSplitKParallel_f16n_f16t_f16t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM75_Device_GemmSplitKParallel_f16n_f16t_f16n_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM75_Device_GemmSplitKParallel_f16n_f16t_f16t_tensor_op_f16, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM75_Device_GemmSplitKParallel_f16n_f16t_f16n_tensor_op_f16, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = 
cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmSplitKParallel_f16t_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM75_Device_GemmSplitKParallel_f16t_f16n_f32n_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM75_Device_GemmSplitKParallel_f16t_f16n_f16t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM75_Device_GemmSplitKParallel_f16t_f16n_f16n_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM75_Device_GemmSplitKParallel_f16t_f16n_f16t_tensor_op_f16, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ 
cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM75_Device_GemmSplitKParallel_f16t_f16n_f16n_tensor_op_f16, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_testbed_3x.hpp b/3rdparty/cutlass/test/unit/gemm/device/gemm_testbed_3x.hpp -new file mode 100644 -index 0000000..24a9e24 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_testbed_3x.hpp -@@ -0,0 +1,717 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file
-+    \brief Tests for device-wide GEMM interface
-+*/
-+
-+#pragma once
-+
-+#include <iostream>
-+#include <fstream>
-+#include <sstream>
-+
-+#include "../../common/cutlass_unit_test.h"
-+
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/tensor_view_io.h"
-+#include "cutlass/util/distribution.h"
-+#include "cutlass/util/packed_stride.hpp"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/tensor_norm.h"
-+#include "cutlass/util/reference/host/gett.hpp"
-+
-+#include "testbed_utils.h"
-+
-+#include "cutlass/kernel_hardware_info.hpp"
-+#include "cutlass/layout/matrix.h"
-+#include "cutlass/matrix_coord.h"
-+#include "cutlass/gemm/gemm.h"
-+
-+#include "cute/int_tuple.hpp"
-+
-+namespace test {
-+namespace gemm {
-+namespace device {
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+namespace detail{
-+
-+template <typename Gemm>
-+struct TestbedImpl {
-+  // Kernel data types
-+  using ElementA = typename Gemm::GemmKernel::ElementA;
-+  using StrideA = typename Gemm::GemmKernel::StrideA;
-+  using ElementB = typename Gemm::GemmKernel::ElementB;
-+  using StrideB = typename Gemm::GemmKernel::StrideB;
-+  using ElementC = typename Gemm::GemmKernel::ElementC;
-+  using StrideC = typename Gemm::GemmKernel::StrideC;
-+  using ElementD = typename Gemm::GemmKernel::ElementD;
-+  using StrideD = typename Gemm::GemmKernel::StrideD;
-+  using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator;
-+  using ElementCompute = typename Gemm::GemmKernel::CollectiveEpilogue::ElementCompute;
-+  using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar;
-+  using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
-+
-+  static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
-+  static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
-+
-+  // Looks at Cute Stride to check Row / Column Major
-+  template <class Stride>
-+  static constexpr bool is_row_or_col_major(){
-+    int stride_0 = int(cute::size<0>(Stride{}));
-+    int stride_1 = int(cute::size<1>(Stride{}));
-+    int depth = cute::depth(Stride{});
-+    return ((stride_0 == 1) || (stride_1 == 1)) && (depth == 1);
-+  }
-+
-+  // Note: this limitation comes from testbed / not the library
-+  static_assert(is_row_or_col_major<StrideA>(),
-+    "ERROR : A Layout is neither Row / Column Major)");
-+  static_assert(is_row_or_col_major<StrideB>(),
-+    "ERROR : B Layout is neither Row / Column Major)");
-+  static_assert(is_row_or_col_major<StrideC>(),
-+    "ERROR : C Layout is neither Row / Column Major)");
-+  static_assert(is_row_or_col_major<StrideD>(),
-+    "ERROR : D Layout is neither Row / Column Major)");
-+
-+  // Deduce Cutlass Layouts (RowMajor & ColumnMajor)
-+  using LayoutTagA = decltype(cutlass::gemm::detail::stride_to_layout_tag_A<StrideA>());
-+  using LayoutTagB = decltype(cutlass::gemm::detail::stride_to_layout_tag_B<StrideB>());
-+  using LayoutTagC = decltype(cutlass::gemm::detail::stride_to_layout_tag_A<StrideC>());
-+  using LayoutTagD = decltype(cutlass::gemm::detail::stride_to_layout_tag_A<StrideD>());
-+  using LayoutTagPackedVector = cutlass::layout::PackedVectorLayout;
-+
-+  /// Initialization
-+  StrideA stride_a;
-+  StrideB stride_b;
-+  StrideC stride_c;
-+  StrideD stride_d;
-+  typename LayoutTagA::Stride stride_factor_A;
-+  typename LayoutTagB::Stride stride_factor_B;
-+  typename LayoutTagC::Stride stride_factor_C;
-+  typename LayoutTagD::Stride stride_factor_D;
-+  cutlass::Distribution::Kind init_A;
-+  cutlass::Distribution::Kind init_B;
-+  cutlass::Distribution::Kind init_C;
-+  uint64_t seed;
-+  static constexpr uint64_t kDefaultSeed = 4096;
-+
-+  cutlass::HostTensor<ElementA, LayoutTagA> tensor_A;
-+  cutlass::HostTensor<ElementB, LayoutTagB> tensor_B;
-+  cutlass::HostTensor<ElementC, LayoutTagC> tensor_C;
-+  cutlass::HostTensor<ElementD, LayoutTagD> tensor_D;
-+  cutlass::HostTensor<ElementD, LayoutTagD> reference_D;
-+  uint32_t sm_count;
-+
-+  // Used to force multi-wave tests for persistent kernel schedules
-+  constexpr static int MaxSmCount = 16;
-+
-+  //
-+  // Methods
-+  //
-+
-+  TestbedImpl(
-+    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
-+    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
-+    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
-+    uint64_t seed_ = kDefaultSeed
-+  ):
-+    stride_factor_A(typename LayoutTagA::Stride()),
-+    stride_factor_B(typename LayoutTagB::Stride()),
-+    stride_factor_C(typename LayoutTagC::Stride()),
-+    stride_factor_D(typename LayoutTagD::Stride()),
-+    init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
-+
-+  TestbedImpl(
-+    typename LayoutTagA::Stride stride_factor_A_,
-+    typename LayoutTagB::Stride stride_factor_B_,
-+    typename LayoutTagC::Stride stride_factor_C_,
-+    typename LayoutTagD::Stride stride_factor_D_,
-+    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
-+    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
-+    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
-+    uint64_t seed_ = kDefaultSeed
-+  ):
-+    stride_factor_A(stride_factor_A_),
-+    stride_factor_B(stride_factor_B_),
-+    stride_factor_C(stride_factor_C_),
-+    stride_factor_D(stride_factor_D_),
-+    init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
-+
-+  /// Helper to initialize a tensor view
-+  template <typename Element, typename Layout>
-+  bool initialize_tensor(
-+    cutlass::TensorView<Element, Layout> view,
-+    cutlass::Distribution::Kind dist_kind,
-+    uint64_t seed) {
-+
-+    if (dist_kind == cutlass::Distribution::Uniform) {
-+      double scope_max, scope_min;
-+      int bits_input = cutlass::sizeof_bits<Element>::value;
-+      int bits_output = cutlass::sizeof_bits<ElementD>::value;
-+
-+      if (bits_input == 1) {
-+        scope_max = 2;
-+        scope_min = 0;
-+      }
-+      else if (bits_input <= 8) {
-+        scope_max = 2;
-+        scope_min = -2;
-+      }
-+      else if (bits_output == 16) {
-+        scope_max = 5;
-+        scope_min = -5;
-+      }
-+      else {
-+        scope_max = 8;
-+        scope_min = -8;
-+      }
-+      cutlass::reference::host::TensorFillRandomUniform(
-+        view, seed, scope_max, scope_min, 0);
-+    }
-+
-+    else if (dist_kind == cutlass::Distribution::Identity) {
-+      cutlass::reference::host::TensorFillIdentity(view);
-+    }
-+
-+    else if (dist_kind == cutlass::Distribution::Gaussian) {
-+      cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
-+    }
-+
-+    else if (dist_kind == cutlass::Distribution::Sequential) {
-+      cutlass::reference::host::BlockFillSequential(
-+        view.data(), view.capacity());
-+    }
-+
-+    else {
-+      EXPECT_TRUE(false) << "Not implemented";
-+      return false;
-+    }
-+
-+    return true;
-+  }
-+
-+  /// Initializes data structures
-+  void initialize(ProblemShapeType problem_size) {
-+    //
-+    // Allocate the GEMM workspace
-+    //
-+    auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
-+    auto M = cute::size<0>(problem_shape_MNKL);
-+    auto N = cute::size<1>(problem_shape_MNKL);
-+    auto K = cute::size<2>(problem_shape_MNKL);
-+    auto L = cute::size<3>(problem_shape_MNKL);
-+
-+    stride_a = make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
-+    stride_b = make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
-+    stride_c = make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
-+    stride_d = make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
-+
-+    // 2.x host tensors do not natively carry a batch stride or coord, so we spoof it by
-+    // folding the batch into the outer mode
-+    auto a_coord = cutlass::make_Coord(M * L, K);
-+    auto c_coord = cutlass::make_Coord(M * L, N);
-+    // CUTLASS Row/Col major refers to the MxK times KxN matrix product,
-+    // so HostTensor B should be treated as KxN in "coord"'s view
-+    auto b_coord = cutlass::make_Coord(K, N * L);
-+
-+    tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagA>::layout_factory(a_coord, stride_factor_A));
-+    tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagB>::layout_factory(b_coord, stride_factor_B));
-+    tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagC>::layout_factory(c_coord, stride_factor_C));
-+    tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(c_coord, stride_factor_D));
-+    reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(c_coord, stride_factor_D), false);
-+
-+    EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022));
-+    EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021));
-+    EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2020));
-+
-+    // It is possible to randomly initialize to all zeros, so override this with non-zeros
-+    // in the upper left corner of each operand.
-+    tensor_A.host_view().at({0, 0}) = ElementA(1);
-+    tensor_B.host_view().at({0, 0}) = ElementB(1);
-+    tensor_C.host_view().at(cutlass::make_Coord(0, 0)) = ElementC(1);
-+
-+    cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
-+
-+    tensor_A.sync_device();
-+    tensor_B.sync_device();
-+    tensor_C.sync_device();
-+    tensor_D.sync_device();
-+  }
-+
-+  /// Compares computed reference with device reference and outputs to a file if incorrect
-+  bool compare_reference(
-+    cute::Shape<int, int, int, int> problem_shape_MNKL,
-+    ElementScalar alpha,
-+    ElementScalar beta
-+  ) {
-+    auto [M, N, K, L] = problem_shape_MNKL;
-+
-+    tensor_D.sync_host();
-+    EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0);
-+    EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0);
-+    EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0);
-+
-+    if (tensor_D.size() > 1) {
-+      EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0);
-+    }
-+
-+    if (reference_D.size() > 1) {
-+      EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0);
-+    }
-+
-+    bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view());
-+
-+    EXPECT_TRUE(passed);
-+    if (!passed) {
-+      std::stringstream fname;
-+      fname << "error_Gemm_device_"
-+            << M << "x" << N << "x" << K << "x" << L << "_"
-+            << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_"
-+            << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_"
-+            << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt";
-+
-+      std::ofstream file(fname.str());
-+      file
-+        << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L
-+        << ", alpha: " << float(alpha) << ", beta: " << float(beta) << "\n\n";
-+
-+      file
-+        << "A =\n" << tensor_A.host_view()
-+        << "\nB =\n" << tensor_B.host_view()
-+        << "\nC =\n" << tensor_C.host_view()
-+        << "\n\nReference =\n" << reference_D.host_view()
-+        << "\n\nComputed =\n" << tensor_D.host_view();
-+    }
-+
-+ return passed; -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify( -+ ProblemShapeType problem_size, -+ ElementScalar alpha, -+ ElementScalar beta -+ ) { -+ auto problem_shape_MNKL = cute::append<4>(problem_size, 1); -+ auto M = cute::size<0>(problem_shape_MNKL); -+ auto N = cute::size<1>(problem_shape_MNKL); -+ auto K = cute::size<2>(problem_shape_MNKL); -+ auto L = cute::size<3>(problem_shape_MNKL); -+ -+ auto A = cute::make_tensor(tensor_A.host_data(), -+ cute::make_layout(cute::make_shape(M, K, L), stride_a)); -+ auto B = cute::make_tensor(tensor_B.host_data(), -+ cute::make_layout(cute::make_shape(N, K, L), stride_b)); -+ auto C = cute::make_tensor(tensor_C.host_data(), -+ cute::make_layout(cute::make_shape(M, N, L), stride_c)); -+ auto D = cute::make_tensor(reference_D.host_data(), -+ cute::make_layout(cute::make_shape(M, N, L), stride_d)); -+ cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; -+ -+ cutlass::reference::host::GettEpilogueParams< -+ ElementScalar, -+ ElementAccumulator, -+ ElementCompute, -+ decltype(C), -+ decltype(D) -+ > -+ epilogue_params{ -+ alpha, beta, -+ C, D -+ }; -+ -+ cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); -+ -+ return compare_reference( -+ problem_shape_MNKL, alpha, beta -+ ); -+ } -+ -+ /// Determine if the CUDA device is sufficient to run the kernel -+ bool sufficient() { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = Gemm::GemmKernel::SharedStorageSize; -+ -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ cudaDeviceProp properties; -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ this->sm_count = properties.multiProcessorCount; -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ bool profile( -+ ProblemShapeType problem_size, -+ int iterations, -+ Gemm& gemm_op, -+ typename Gemm::Arguments& arguments, -+ cutlass::device_memory::allocation& workspace) { -+ int M = cute::size<0>(problem_size); -+ int N = cute::size<1>(problem_size); -+ int K = cute::size<2>(problem_size); -+ int L = 1; -+ if constexpr(cute::rank(ProblemShapeType{}) == 4) { -+ L = cute::size<3>(problem_size); -+ } -+ -+ -+ cutlass::Status status; -+ // -+ // Run the GEMM -+ // -+ cudaError_t result; -+ -+ for (int iter = 0; iter < iterations; ++iter) { -+ status = gemm_op(arguments, workspace.get()); -+ if (status != cutlass::Status::kSuccess) { -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ return false; -+ } -+ } -+ -+ result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ ProblemShapeType problem_size, -+ ElementScalar alpha = ElementScalar(1), -+ ElementScalar beta = ElementScalar(0), -+ bool profiling = false, -+ int iterations = 20 -+ ) { -+ // Fail test if insufficient CUDA device -+ if (!sufficient()) { -+ std::cout << "Test failed due to insufficient CUDA device." 
<< std::endl; -+ return false; -+ } -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm::Arguments arguments; -+ cutlass::KernelHardwareInfo hw_info; -+ hw_info.device_id = 0; -+ if (not profiling) { -+ this->sm_count = min(MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); -+ hw_info.sm_count = this->sm_count; -+ } -+ else { -+ this->sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); -+ hw_info.sm_count = this->sm_count; -+ } -+ -+ // DefaultEpilogue -+ arguments = typename Gemm::Arguments{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem_size, -+ tensor_A.device_data(), -+ stride_a, -+ tensor_B.device_data(), -+ stride_b, -+ {tensor_C.device_data(), stride_c, tensor_D.device_data(), stride_d, {alpha, beta}}, -+ hw_info -+ }; -+ Gemm gemm_op; -+ -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = gemm_op.can_implement(arguments); -+ -+ if (status != cutlass::Status::kSuccess) { -+ cudaError_t error = cudaGetLastError(); -+ std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; -+ return true; -+ } -+ -+ // -+ // Run the GEMM -+ // -+ -+ if (profiling) { -+ return profile(problem_size, iterations, gemm_op, arguments, workspace); -+ } -+ else { -+ cudaError_t result; -+ status = gemm_op.initialize(arguments, workspace.get()); -+ status = gemm_op.run(); -+ result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; -+ return false; -+ } -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ bool passed = this->verify( -+ problem_size, alpha, beta -+ ); -+ if (!passed) { -+ std::cout << "Error : Failed : with alpha: " << float(alpha) << ", beta: " << float(beta) -+ << "\n"; -+ } -+ -+ return passed; -+ } -+ } -+}; -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct Testbed { -+ -+ using TestBedImplementation = typename detail::TestbedImpl; -+ -+ using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; -+ using ElementCompute = typename Gemm::GemmKernel::CollectiveEpilogue::ElementCompute; -+ using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar; -+ using LayoutTagA = typename TestBedImplementation::LayoutTagA; -+ using LayoutTagB = typename TestBedImplementation::LayoutTagB; -+ using LayoutTagC = typename TestBedImplementation::LayoutTagC; -+ using LayoutTagD = typename TestBedImplementation::LayoutTagD; -+ -+ // Detail Implementation -+ TestBedImplementation impl_; -+ -+ // -+ // Methods -+ // -+ Testbed( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = TestBedImplementation::kDefaultSeed) -+ : impl_(init_A_, init_B_, init_C_, seed_) {} -+ -+ Testbed( -+ typename LayoutTagA::Stride stride_factor_A_, -+ typename LayoutTagB::Stride stride_factor_B_, -+ typename LayoutTagC::Stride stride_factor_C_, -+ typename LayoutTagD::Stride stride_factor_D_, -+ 
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = TestBedImplementation::kDefaultSeed) -+ : impl_(stride_factor_A_, -+ stride_factor_B_, -+ stride_factor_C_, -+ stride_factor_D_, -+ init_A_, -+ init_B_, -+ init_C_, -+ seed_) {} -+ -+ /// Executes one test -+ bool run( -+ typename TestBedImplementation::ProblemShapeType problem_size, -+ ElementScalar alpha = ElementScalar(1), -+ ElementScalar beta = ElementScalar(0), -+ bool profiling = false, -+ int iterations = 20 -+ ) { -+ return impl_.run( -+ problem_size, alpha, beta, profiling, iterations -+ ); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+bool TestAll() { -+ using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar; -+ using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; -+ -+ int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); -+ std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; -+ std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; -+ -+ if constexpr (std::is_same_v) { -+ problem_size_m.push_back(768); -+ problem_size_n.push_back(768); -+ } -+ -+ constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; -+ constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); -+ -+ std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; -+ -+ Testbed testbed; -+ bool passed = true; -+ -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ ProblemShapeType problem_size; -+ if constexpr (cute::rank(ProblemShapeType{}) == 4) { -+ problem_size = ProblemShapeType{m, n, k, /* l */ 1}; -+ } -+ else { -+ problem_size = ProblemShapeType{m, n, k}; -+ } -+ -+ passed = testbed.run( -+ problem_size, -+ cutlass::from_real(1), -+ cutlass::from_real(0) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ -+ // if we do support batched GEMM, just run one test on it to save on test time -+ if constexpr (cute::rank(ProblemShapeType{}) == 4) { -+ auto problem_size = ProblemShapeType{256 + max_alignment, 256 + max_alignment, 160 + max_alignment, /* l */ 3}; -+ passed = testbed.run( -+ problem_size, -+ cutlass::from_real(1), -+ cutlass::from_real(0) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+bool TestGemmPerf(int iterations = 20) { -+ using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; -+ using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; -+ using ElementScalar = ElementAccumulator; -+ bool passed = true; -+ -+ std::vector problem_size_m = { 4608 }; -+ std::vector problem_size_n = { 4608 }; -+ std::vector problem_size_k = { 8192 }; -+ -+ Testbed testbed; -+ -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ ProblemShapeType problem_size; -+ if constexpr (cute::rank(ProblemShapeType{}) == 4) { -+ problem_size = ProblemShapeType{m, n, k, /* l */ 1}; -+ } -+ else { -+ problem_size = ProblemShapeType{m, n, k}; -+ } -+ -+ passed = testbed.run( -+ problem_size, -+ cutlass::from_real(1), -+ cutlass::from_real(0), -+ true, -+ iterations -+ ); -+ -+ if 
(!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ -+ -+ // if we do support batched GEMM, just run it once -+ if constexpr (cute::rank(ProblemShapeType{}) == 4) { -+ auto problem_size = ProblemShapeType{problem_size_m[0], problem_size_n[0], problem_size_k[0], /* l */ 4}; -+ passed = testbed.run( -+ problem_size, -+ cutlass::from_real(1), -+ cutlass::from_real(0), -+ true, -+ iterations -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ -+ return passed; -+} -+ -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32n_tf32n_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32n_tf32n_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..1c33d50 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32n_tf32n_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,555 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
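The TestAll() sweep that closes the testbed header above derives its problem sizes from the kernel's alignment, stage count, and K tile extent, so that both minimum-size and non-tile-multiple shapes get exercised. The standalone sketch below reproduces just that enumeration with assumed values (max_alignment = 4, Stages = 3, TileShapeK = 32); it is an illustrative paraphrase of the loop structure, not CUTLASS code, and in the real harness these values come from the Gemm type that TestAll() is templated on.

    #include <cstdio>
    #include <vector>

    int main() {
      // Assumed values for illustration; TestAll() reads these from the Gemm type.
      int max_alignment = 4;   // std::max(kAlignmentA, kAlignmentB)
      int Stages        = 3;   // Gemm::GemmKernel::DispatchPolicy::Stages
      int TileShapeK    = 32;  // K extent of the threadblock tile shape

      std::vector<int> m = {max_alignment, 512 - 3 * max_alignment};
      std::vector<int> n = {max_alignment, 512 - 2 * max_alignment};
      std::vector<int> k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment};

      for (int mm : m)
        for (int nn : n)
          for (int kk : k)
            std::printf("GEMM problem: %d x %d x %d (L = 1)\n", mm, nn, kk);
      return 0;
    }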
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 64x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ 
cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ 
cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 128x256x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 64x256x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 256x64x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ 
cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 64x128x16_32x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 128x64x16_64x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ 
cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32n_tf32t_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32n_tf32t_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..375e7b9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32n_tf32t_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,555 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
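Every test in the file above fixes the epilogue vector width to 128 divided by the bit width of the f32 output, i.e. four elements per 128-bit access, and runs the 16x8x8 TF32 tensor-op instruction under the listed warp and threadblock shapes. The sketch below only spells out that arithmetic for the 128x256x32_64x64x32 case; the tile numbers are copied from the template parameters above, and the rest is plain C++ written for illustration rather than anything taken from CUTLASS.

    #include <cstdio>

    int main() {
      // Tile sizes copied from the 128x256x32_64x64x32 test case above.
      const int tb_m = 128, tb_n = 256;    // threadblock tile (M, N)
      const int warp_m = 64, warp_n = 64;  // warp tile (M, N)
      const int f32_bits = 32;             // bit width of the f32 output element

      int epilogue_vector   = 128 / f32_bits;                    // elements per 128-bit store
      int warps_per_block   = (tb_m / warp_m) * (tb_n / warp_n);
      int threads_per_block = warps_per_block * 32;

      std::printf("epilogue vector width : %d f32 elements\n", epilogue_vector);
      std::printf("warps per threadblock : %d (%d threads)\n", warps_per_block, threads_per_block);
      return 0;
    }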
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 64x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ 
ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ 
cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 128x256x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 64x256x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 256x64x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 
64, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 64x128x16_32x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 128x64x16_64x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< 
-+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32t_tf32n_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32t_tf32n_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..3353368 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32t_tf32n_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,493 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
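The test and file names encode the operand layouts: in a suffix such as tf32n_tf32t_f32t, 'n' stands for column-major (the BLAS non-transposed convention) and 't' for row-major, applied to A, B, and C/D in that order, which matches the ColumnMajor/RowMajor/RowMajor template arguments used above. The tiny helper below is hypothetical, written only to make that convention explicit; it is not part of CUTLASS.

    #include <cstdio>

    // Decode one layout letter from a test-name suffix such as "tf32n_tf32t_f32t".
    const char* decode_layout(char c) {
      switch (c) {
        case 'n': return "ColumnMajor";  // "not transposed" in BLAS terms
        case 't': return "RowMajor";     // "transposed"
        default:  return "unknown";
      }
    }

    int main() {
      std::printf("A: %s, B: %s, C/D: %s\n",
                  decode_layout('n'), decode_layout('t'), decode_layout('t'));
      return 0;
    }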
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ 
ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 128x256x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ 
cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 256x64x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 64x128x16_32x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ 
cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 128x64x16_64x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32t_tf32t_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32t_tf32t_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..3b243b2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32t_tf32t_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,556 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ 
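The tf32 Tensor Core cases above and below differ only in the tile sizes and stage count handed to cutlass::gemm::device::Gemm; its template arguments are, in order: A element and layout, B element and layout, C element and layout, accumulator element, math operator class, target architecture, threadblock tile, warp tile, instruction (MMA) tile, epilogue functor, threadblock swizzle, and number of pipeline stages. The patch drives every instantiation through the bundled TestAllGemm<Gemm>() testbed; the sketch below only illustrates running one of these configurations directly through the CUTLASS 2.x device-level API, with a problem size, tensor names, and fill values that are assumptions rather than part of the patch.

    // Minimal usage sketch (not part of the patch): one tf32 Tensor Core configuration
    // from the tests above, driven through the CUTLASS 2.x device-level API instead of
    // the unit-test harness. Problem size, tensor names, and fill values are assumptions.
    #include <cutlass/cutlass.h>
    #include <cutlass/gemm/device/gemm.h>
    #include <cutlass/util/host_tensor.h>
    #include <cutlass/util/reference/host/tensor_fill.h>

    int run_tf32_tensorop_gemm() {
      using ElementOutput      = float;
      using ElementAccumulator = float;

      using Gemm = cutlass::gemm::device::Gemm<
          cutlass::tfloat32_t, cutlass::layout::RowMajor,     // A
          cutlass::tfloat32_t, cutlass::layout::ColumnMajor,  // B
          ElementOutput,       cutlass::layout::RowMajor,     // C / D
          ElementAccumulator,
          cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
          cutlass::gemm::GemmShape<128, 128, 32>,             // threadblock tile
          cutlass::gemm::GemmShape<64, 64, 32>,               // warp tile
          cutlass::gemm::GemmShape<16, 8, 8>,                 // tf32 MMA instruction shape
          cutlass::epilogue::thread::LinearCombination<
              ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
              ElementAccumulator, ElementAccumulator>,
          cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
          3>;                                                 // pipeline stages

      int M = 512, N = 512, K = 256;  // assumed problem size

      cutlass::HostTensor<cutlass::tfloat32_t, cutlass::layout::RowMajor>    A({M, K});
      cutlass::HostTensor<cutlass::tfloat32_t, cutlass::layout::ColumnMajor> B({K, N});
      cutlass::HostTensor<ElementOutput, cutlass::layout::RowMajor>          C({M, N});

      cutlass::reference::host::TensorFill(A.host_view(), cutlass::tfloat32_t(1.0f));
      cutlass::reference::host::TensorFill(B.host_view(), cutlass::tfloat32_t(1.0f));
      cutlass::reference::host::TensorFill(C.host_view(), ElementOutput(0.0f));
      A.sync_device();
      B.sync_device();
      C.sync_device();

      Gemm gemm_op;
      cutlass::Status status = gemm_op({
          {M, N, K},
          A.device_ref(), B.device_ref(),
          C.device_ref(), C.device_ref(),  // C is read as the source and written as D
          {1.0f, 0.0f}});                  // alpha, beta

      return (status == cutlass::Status::kSuccess) ? 0 : -1;
    }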
-+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 64x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = 
float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 128x256x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 64x256x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = 
cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 256x64x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 64x128x16_32x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 128x64x16_64x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ 
ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_u8t_u8n_s32t_wmma_tensor_op_s32_sm72.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_u8t_u8n_s32t_wmma_tensor_op_s32_sm72.cu -new file mode 100644 -index 0000000..a39f29d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_u8t_u8n_s32t_wmma_tensor_op_s32_sm72.cu -@@ -0,0 +1,185 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM72_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 16x16x16, DataType/Instruction = u8*u8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_u8t_u8n_s32t_wmma_tensor_op_s32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ uint8_t, -+ cutlass::layout::RowMajor, -+ uint8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_u8t_u8n_s32t_wmma_tensor_op_s32, 64x128x64_32x32x64_16x16x16) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ uint8_t, -+ cutlass::layout::RowMajor, -+ uint8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 32x8x16, DataType/Instruction = u8*u8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// 
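The integer WMMA tests in this file pin the instruction-level GemmShape to the fragment sizes the CUDA WMMA API exposes for u8 inputs with s32 accumulation: 16x16x16, 32x8x16, and 8x32x16. Before the 32x8x16 case that follows, here is a raw-WMMA sketch of the warp-level operation that shape names; the kernel, its pointer arguments, and the leading dimensions are illustrative assumptions, while the patch itself reaches the instruction only through cutlass::arch::OpClassWmmaTensorOp.

    // Raw CUDA WMMA sketch (illustrative, not from the patch): one warp computes a single
    // 32x8 s32 tile from u8 operands using the m32n8k16 fragment shape that the
    // GemmShape<32, 8, 16> instruction tile selects. Requires sm_72 or newer.
    #include <mma.h>

    __global__ void wmma_u8_m32n8k16_tile(unsigned char const *A,  // 32x16, row-major
                                          unsigned char const *B,  // 16x8,  column-major
                                          int *C,                  // 32x8,  row-major
                                          int lda, int ldb, int ldc) {
      using namespace nvcuda;

      wmma::fragment<wmma::matrix_a, 32, 8, 16, unsigned char, wmma::row_major> a_frag;
      wmma::fragment<wmma::matrix_b, 32, 8, 16, unsigned char, wmma::col_major> b_frag;
      wmma::fragment<wmma::accumulator, 32, 8, 16, int> c_frag;

      wmma::fill_fragment(c_frag, 0);                  // start from a zero accumulator
      wmma::load_matrix_sync(a_frag, A, lda);          // the whole warp cooperates in the load
      wmma::load_matrix_sync(b_frag, B, ldb);
      wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);  // C += A * B, accumulated in int32
      wmma::store_matrix_sync(C, c_frag, ldc, wmma::mem_row_major);
    }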
-+TEST(SM75_Device_Gemm_u8t_u8n_s32t_wmma_tensor_op_s32, 64x128x64_32x64x64_32x8x16) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ uint8_t, -+ cutlass::layout::RowMajor, -+ uint8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 8x32x16, DataType/Instruction = u8*u8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_u8t_u8n_s32t_wmma_tensor_op_s32, 64x128x64_32x64x64_8x32x16) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ uint8_t, -+ cutlass::layout::RowMajor, -+ uint8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM72_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_cf32n_cf32n_cf32n_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_cf32n_cf32n_cf32n_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..96981a2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_cf32n_cf32n_cf32n_tensor_op_f32_sm80.cu -@@ -0,0 +1,199 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/device/gemm_universal.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_universal.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf32n_cf32t_cf32n_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf32n_cf32h_cf32n_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ 
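In these GemmUniversal instantiations the per-operand test-name suffixes pair a layout with a ComplexTransform: n is ColumnMajor with kNone, t is RowMajor with kNone, h is RowMajor with kConjugate (a conjugate transpose), and c is ColumnMajor with kConjugate. The testbed checks the results against CUTLASS's own host reference; the standalone sketch below, using plain std::complex<float> and illustrative names that are not part of the patch, only spells out what a conjugated operand contributes to D = alpha * op(A) * op(B) + beta * C.

    // Host-side sketch (illustrative, not the patch's reference path) of the product the
    // n/t/h/c variants request: the layout fixes how an operand is stored, and
    // ComplexTransform::kConjugate asks for elementwise conjugation of that operand.
    #include <complex>
    #include <vector>

    enum class Transform { kNone, kConjugate };

    static std::complex<float> apply(std::complex<float> x, Transform t) {
      return (t == Transform::kConjugate) ? std::conj(x) : x;
    }

    // D = alpha * op(A) * op(B) + beta * C, with every matrix stored row-major here
    // for brevity (the tests themselves mix column- and row-major operands).
    void reference_complex_gemm(int M, int N, int K, std::complex<float> alpha,
                                std::vector<std::complex<float>> const &A, Transform tA,
                                std::vector<std::complex<float>> const &B, Transform tB,
                                std::complex<float> beta,
                                std::vector<std::complex<float>> const &C,
                                std::vector<std::complex<float>> &D) {
      for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
          std::complex<float> acc(0.0f, 0.0f);
          for (int k = 0; k < K; ++k) {
            acc += apply(A[i * K + k], tA) * apply(B[k * N + j], tB);
          }
          D[i * N + j] = alpha * acc + beta * C[i * N + j];
        }
      }
    }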
-+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf32h_cf32t_cf32n_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf32h_cf32c_cf32n_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu -new file mode 100644 -index 0000000..29103d0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu -@@ -0,0 +1,200 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/device/gemm_universal.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_universal.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf64n_cf64t_cf64n_tensor_op_f64_gaussian, 64x64x32_32x32x32) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf64n_cf64h_cf64n_tensor_op_f64_gaussian, 64x64x32_32x32x32) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf64h_cf64t_cf64n_tensor_op_f64_gaussian, 64x32x32_32x16x32) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<32, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf64h_cf64c_cf64n_tensor_op_f64_gaussian, 64x64x32_32x16x32) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..b82c5e5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu -@@ -0,0 +1,200 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/device/gemm_universal.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_universal.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf64n_cf64t_cf64n_tensor_op_f64, 64x64x32_32x32x32) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf64n_cf64h_cf64n_tensor_op_f64, 64x64x32_32x32x32) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ 
cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf64h_cf64t_cf64n_tensor_op_f64, 64x64x32_32x32x32) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf64h_cf64c_cf64n_tensor_op_f64, 64x64x32_32x32x32) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_f16n_f16t_f32n_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_f16n_f16t_f32n_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..72c3d5d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_f16n_f16t_f32n_tensor_op_f32_sm75.cu -@@ -0,0 +1,117 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/device/gemm_universal.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_universal.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmUniversal_f16n_f16t_f32n_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 2>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+ -+TEST(SM75_Device_GemmUniversal_f16n_f16t_f32n_tensor_op_f32, 64x64x32_32x32x32_updated_batch_count) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 2, 
-+ 1, -+ 1>; -+ -+ EXPECT_TRUE(test::gemm::device::TestGemmUniversal( -+ {128, 128, 2}, -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ 15)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_f16n_f16t_f32t_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_f16n_f16t_f32t_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..771573b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_f16n_f16t_f32t_tensor_op_f32_sm75.cu -@@ -0,0 +1,115 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/device/gemm_universal.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_universal.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmUniversal_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 2>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+TEST(SM75_Device_GemmUniversal_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32_updated_batch_count) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 2, -+ 1, -+ 1>; -+ -+ EXPECT_TRUE(test::gemm::device::TestGemmUniversal( -+ {128, 128, 2}, -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ 15)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu -new file mode 100644 -index 0000000..8a1884d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu -@@ -0,0 +1,464 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+ -+#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" -+#include "cutlass/epilogue/thread/linear_combination_bias_relu.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_gemm_with_broadcast.h" -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes: -+/// -+/// Z = GEMM+Bias+ReLu -+/// T = Relu conditional -+/// -+template -+struct GemmWithBiasReluReferenceOp { -+ -+ using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; -+ -+ using ElementCompute = typename OutputOp::ElementCompute; -+ using ElementZ = typename OutputOp::ElementZ; -+ using ElementT = typename OutputOp::ElementT; -+ -+ typename OutputOp::BinaryOp binary_op; -+ typename OutputOp::ElementwiseOp elementwise_op; -+ -+ GemmWithBiasReluReferenceOp() { } -+ -+ void operator()(ElementZ &Z, ElementT &T, ElementCompute gemm, ElementCompute bias) { -+ -+ ElementCompute kThreshold = ElementCompute(); -+ -+ ElementCompute z_full = binary_op(gemm, bias); -+ -+ bool conditional = (z_full >= kThreshold); -+ -+ if (!conditional) { -+ z_full = kThreshold; -+ } -+ -+ Z = ElementZ(z_full); -+ T = ElementT(conditional); -+ } -+}; -+ 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmWithBroadcast_GELU_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< -+ cutlass::half_t, -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8, -+ cutlass::epilogue::thread::GELU_taylor -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ test::gemm::device::TestAllGemmWithBroadcast(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmWithBroadcast_GELU_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< -+ cutlass::half_t, -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8, -+ cutlass::epilogue::thread::GELU_taylor -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ test::gemm::device::TestAllGemmWithBroadcast(); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmWithBroadcast_RELU_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasRelu< -+ cutlass::half_t, -+ float, -+ float, -+ cutlass::half_t, -+ 8, -+ true -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ test::gemm::device::TestAllGemmWithBroadcast >(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmWithBroadcast_RELU_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasRelu< -+ cutlass::half_t, -+ float, -+ float, -+ cutlass::half_t, -+ 8, -+ true -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ test::gemm::device::TestAllGemmWithBroadcast >(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if defiend(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmWithBroadcast_GELU_f16n_f16n_f16n_tensor_op_f32, 128x128_32x5_64x64x32_16x8x16) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< -+ cutlass::half_t, -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8, -+ cutlass::epilogue::thread::GELU_taylor -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 5, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ test::gemm::device::TestAllGemmWithBroadcast(); -+} -+ -+TEST(SM80_Device_GemmWithBroadcast_RELU_f16n_f16n_f16n_tensor_op_f32, 128x128_32x5_64x64x32_16x8x16) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasRelu< -+ cutlass::half_t, -+ float, -+ float, -+ cutlass::half_t, -+ 8, -+ true -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ 
cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 5, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ test::gemm::device::TestAllGemmWithBroadcast>(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmWithBroadcast_GELU_f16n_f16n_f16n_tensor_op_f32, 128x128_32x4_64x64x32_16x8x16) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< -+ cutlass::half_t, -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8, -+ cutlass::epilogue::thread::GELU_taylor -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 4, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ test::gemm::device::TestAllGemmWithBroadcast(); -+} -+ -+TEST(SM80_Device_GemmWithBroadcast_RELU_f16n_f16n_f16n_tensor_op_f32, 128x128_32x4_64x64x32_16x8x16) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasRelu< -+ cutlass::half_t, -+ float, -+ float, -+ cutlass::half_t, -+ 8, -+ true -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 4, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ test::gemm::device::TestAllGemmWithBroadcast>(); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmWithBroadcast_GELU_f16n_f16n_f16n_tensor_op_f32, 128x128_32x3_64x64x32_16x8x16) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< -+ cutlass::half_t, -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8, -+ cutlass::epilogue::thread::GELU_taylor -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // 
transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ test::gemm::device::TestAllGemmWithBroadcast(); -+} -+ -+TEST(SM80_Device_GemmWithBroadcast_RELU_f16n_f16n_f16n_tensor_op_f32, 128x128_32x3_64x64x32_16x8x16) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasRelu< -+ cutlass::half_t, -+ float, -+ float, -+ cutlass::half_t, -+ 8, -+ true -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ test::gemm::device::TestAllGemmWithBroadcast >(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu -new file mode 100644 -index 0000000..15eca4b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu -@@ -0,0 +1,384 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_with_reduction.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+ -+#include "cutlass/epilogue/thread/linear_combination_drelu.h" -+#include "cutlass/epilogue/thread/linear_combination_dgelu.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_gemm_with_reduction.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct dReluLambda { -+ float operator()(float d_y, float t) { -+ if (t <= 0) { -+ d_y = 0; -+ } -+ return d_y; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmWithReduction_dReLU_bGrad_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8 -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithReduction< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ EpilogueOutputOp, -+ cutlass::plus, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ using ReferenceOp = test::gemm::device::GemmWithReductionReference< -+ Gemm, -+ dReluLambda -+ >; -+ -+ test::gemm::device::TestGemmWithReduction( -+ {520, 264, 96}, -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ 2, -+ float(1.25), -+ float(2.25) -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmWithReduction_dReLU_bGrad_f16n_f16n_f16n_tensor_op_f32, 256x128x32_64x64x8) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ float, -+ float, -+ cutlass::half_t, 
-+ cutlass::half_t, -+ 8 -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithReduction< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ EpilogueOutputOp, -+ cutlass::plus, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ using ReferenceOp = test::gemm::device::GemmWithReductionReference< -+ Gemm, -+ dReluLambda -+ >; -+ -+ test::gemm::device::TestGemmWithReduction( -+ {520, 264, 96}, -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ 1, -+ float(1.25), -+ float(2.25) -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmWithReduction_dReLU_bGrad_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8 -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithReduction< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ EpilogueOutputOp, -+ cutlass::plus, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ using ReferenceOp = test::gemm::device::GemmWithReductionReference< -+ Gemm, -+ dReluLambda -+ >; -+ -+ test::gemm::device::TestGemmWithReduction( -+ {520, 264, 96}, -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ 2, -+ float(1.25), -+ float(2.25) -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmWithReduction_dReLU_bGrad_f16n_f16n_f16n_tensor_op_f32, 256x128x32_64x64x8) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8 -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithReduction< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ EpilogueOutputOp, -+ cutlass::plus, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = 
cutlass::gemm::device::GemmUniversalAdapter; -+ -+ using ReferenceOp = test::gemm::device::GemmWithReductionReference< -+ Gemm, -+ dReluLambda -+ >; -+ -+ test::gemm::device::TestGemmWithReduction( -+ {520, 264, 96}, -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ 1, -+ float(1.25), -+ float(2.25) -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+template -+struct Gemm_dReLU_packed_bits_reference_op { -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ using ElementCompute = typename Gemm::GemmKernel::Epilogue::ElementCompute; -+ using ElementC = typename Gemm::ElementC; -+ using ElementT = typename Gemm::GemmKernel::Epilogue::ElementTensor; -+ -+ // -+ // Methods -+ // -+ -+ Gemm_dReLU_packed_bits_reference_op() { } -+ -+ ElementCompute operator()( -+ ElementAccumulator d_y, -+ ElementT t) const { -+ -+ ElementCompute result = ElementCompute(d_y); -+ -+ bool cond = bool(t); -+ if (!cond) { -+ result = ElementCompute(); -+ } -+ -+ return result; -+ } -+}; -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmWithReduction_dReLU_conditional_bits_bGrad_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationDReluConditionalBits< -+ float, -+ float, -+ cutlass::half_t, -+ 8 -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithReduction< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ EpilogueOutputOp, -+ cutlass::plus, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ using ReferenceOp = test::gemm::device::Gemm_dReLU_packed_bits_reference_op; -+ -+ test::gemm::device::TestGemmWithReduction( -+ {520, 264, 96}, -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ 2, -+ float(1.25), -+ float(2.25) -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmWithReduction_dReLU_conditional_bits_bGrad_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationDReluConditionalBits< -+ float, -+ float, -+ cutlass::half_t, -+ 8 -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithReduction< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ EpilogueOutputOp, -+ cutlass::plus, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ using ReferenceOp = test::gemm::device::Gemm_dReLU_packed_bits_reference_op; -+ -+ test::gemm::device::TestGemmWithReduction( -+ {520, 264, 96}, -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ 2, -+ float(1.25), -+ float(2.25) -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if defiend(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_with_reduction_f16t_f16n_f16n_tensorop_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_with_reduction_f16t_f16n_f16n_tensorop_f32_sm80.cu -new file mode 100644 -index 0000000..3e04929 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_with_reduction_f16t_f16n_f16n_tensorop_f32_sm80.cu -@@ -0,0 +1,118 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_with_reduction.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+ -+#include "cutlass/epilogue/thread/linear_combination_drelu.h" -+#include "cutlass/epilogue/thread/linear_combination_dgelu.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_gemm_with_reduction.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct dReluLambda { -+ float operator()(float d_y, float t) { -+ if (t <= 0) { -+ d_y = 0; -+ } -+ return d_y; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmWithReduction_dReLU_bGrad_f16t_f16n_f16n_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8 -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithReduction< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ EpilogueOutputOp, -+ cutlass::plus, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 5, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ using ReferenceOp = test::gemm::device::GemmWithReductionReference< -+ Gemm, -+ dReluLambda -+ >; -+ -+ test::gemm::device::TestGemmWithReduction( -+ {136, 6920, 512}, -+ cutlass::gemm::GemmUniversalMode::kGemm -+ ); -+} -+ -+#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemv.cu b/3rdparty/cutlass/test/unit/gemm/device/gemv.cu -new file mode 100644 -index 0000000..fe68e0e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemv.cu -@@ -0,0 +1,444 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Tests for device-wide GEMV interface -+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/kernel/gemv.h" -+#include "cutlass/gemm/device/gemv.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+ -+#include "testbed_utils.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+ -+template -+class TestbedGemv { -+public: -+ -+ using ElementA = typename Gemv::ElementA; -+ using LayoutA = typename Gemv::LayoutA; -+ using ElementB = typename Gemv::ElementB; -+ using ElementC = typename Gemv::ElementC; -+ -+ using ElementAccumulator = typename Gemv::ElementAccumulator; -+ using ElementCompute = typename Gemv::EpilogueOutputOp::ElementCompute; -+ -+ using LayoutV = cutlass::layout::RowMajor; -+ -+private: -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D; -+ cutlass::HostTensor reference_D; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ TestbedGemv( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t 
seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Initializes data structures -+ void initialize( -+ cutlass::MatrixCoord problem_size -+ ) { -+ -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ tensor_A.resize(problem_size); -+ tensor_B.resize({problem_size.column(), 1}); -+ tensor_C.resize({problem_size.row(), 1}); -+ tensor_D.resize({problem_size.row(), 1}); -+ reference_D.resize({problem_size.row(), 1}, false); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); -+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. 
-+ tensor_A.host_view().at({0, 0}) = typename Gemv::ElementA(1); -+ tensor_B.host_view().at({0, 0}) = typename Gemv::ElementB(1); -+ tensor_C.host_view().at({0, 0}) = typename Gemv::ElementC(1); -+ -+ cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cutlass::MatrixCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ -+ bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); -+ -+ EXPECT_TRUE(passed) << " mismatched reference"; -+ -+ if (!passed) { -+ -+ std::ofstream file("testbed_universal_errors.txt"); -+ -+ file -+ << "problem: " << problem_size -+ << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; -+ -+ file -+ << "A =\n" << tensor_A.host_view() -+ << "\nB =\n" << tensor_B.host_view() -+ << "\nC =\n" << tensor_C.host_view() -+ << "\n\nReference =\n" << reference_D.host_view() -+ << "\nComputed =\n" << tensor_D.host_view(); -+ } -+ -+ return passed; -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify( -+ cutlass::MatrixCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::host::GemmComplex< -+ typename Gemv::ElementA, typename Gemv::LayoutA, -+ typename Gemv::ElementB, LayoutV, -+ typename Gemv::ElementC, LayoutV, -+ ElementCompute, ElementAccumulator -+ >( -+ {problem_size.row(), 1, problem_size.column()}, -+ alpha, -+ tensor_A.host_ref(), -+ Gemv::kTransformA, -+ tensor_B.host_ref(), -+ Gemv::kTransformB, -+ beta, -+ tensor_C.host_ref(), -+ reference_D.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ return compare_reference(problem_size, alpha, beta); -+ } -+ -+ /// Runs one problem size -+ bool run( -+ cutlass::MatrixCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemv::Arguments arguments{ -+ problem_size, -+ {alpha, beta}, -+ tensor_A.device_ref(), -+ tensor_B.device_data(), -+ tensor_C.device_data(), -+ tensor_D.device_data(), -+ tensor_B.layout().stride(0), -+ tensor_C.layout().stride(0), -+ tensor_D.layout().stride(0) -+ }; -+ -+ Gemv gemm_op; -+ -+ size_t workspace_size = Gemv::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Run the GEMM -+ // -+ -+ status = gemm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ -+ bool passed = this->verify(problem_size, alpha, beta); -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+bool TestAllGemv() { -+ 
-+ using ElementCompute = typename Gemv::EpilogueOutputOp::ElementCompute; -+ -+ int M[] = { -+ 8, 48, 192, 520 -+ }; -+ -+ int K[] = { -+ 8, 192, 528 -+ }; -+ -+ double Alpha[] = { -+ 1, 1.25 -+ }; -+ -+ double Beta[] = { -+ 0, 1, 1.25 -+ }; -+ -+ for (int m : M) { -+ for (int k : K) { -+ for (double alpha : Alpha) { -+ for (double beta : Beta) { -+ -+ TestbedGemv testbed; -+ -+ if (!testbed.run({m, k}, ElementCompute(alpha), ElementCompute(beta))) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ -+ return true; -+} -+ -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Gemv_f32n_f32_f32_simt_f32, Simple) { -+ -+ using ElementOutput = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ -+ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator>; -+ -+ using Gemv = cutlass::gemm::device::Gemv< -+ cutlass::gemm::kernel::Gemv< -+ ElementOutput, // Element A -+ LayoutA, // Layout A -+ ElementOutput, // Element B -+ ElementOutput, // Element C -+ ElementAccumulator, // Element Accumulator -+ EpilogueOp // Output operator -+ > -+ >; -+ -+ -+ EXPECT_TRUE(test::gemm::TestAllGemv()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Gemv_f16n_f16_f32_simt_f32, Simple) { -+ -+ using ElementInput = cutlass::half_t; -+ using ElementOutput = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ -+ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator>; -+ -+ using Gemv = cutlass::gemm::device::Gemv< -+ cutlass::gemm::kernel::Gemv< -+ ElementInput, // Element A -+ LayoutA, // Layout A -+ ElementInput, // Element B -+ ElementOutput, // Element C -+ ElementAccumulator, // Element Accumulator -+ EpilogueOp // Output operator -+ > -+ >; -+ -+ -+ EXPECT_TRUE(test::gemm::TestAllGemv()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Gemv_f16n_f16_f16_simt_f32, Simple) { -+ -+ using ElementInput = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ -+ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator>; -+ -+ using Gemv = cutlass::gemm::device::Gemv< -+ cutlass::gemm::kernel::Gemv< -+ ElementInput, // Element A -+ LayoutA, // Layout A -+ ElementInput, // Element B -+ ElementOutput, // Element C -+ ElementAccumulator, // Element Accumulator -+ EpilogueOp // Output operator -+ > -+ >; -+ -+ -+ EXPECT_TRUE(test::gemm::TestAllGemv()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_f32_ls_sm80.cu -new file mode 100644 -index 0000000..e09bf17 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_f32_ls_sm80.cu -@@ -0,0 +1,175 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HEMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_ls_l_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_ls_u_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_ls_u_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_f32_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_f32_rs_sm80.cu -new file mode 100644 -index 0000000..45fbebd ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_f32_rs_sm80.cu -@@ -0,0 +1,175 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HEMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_rs_l_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_rs_u_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ 
cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_rs_u_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_fast_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_fast_f32_ls_sm80.cu -new file mode 100644 -index 0000000..def5edd ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_fast_f32_ls_sm80.cu -@@ -0,0 +1,175 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HEMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_ls_l_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_ls_u_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_ls_u_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_fast_f32_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_fast_f32_rs_sm80.cu -new file mode 100644 -index 0000000..ebf9055 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_fast_f32_rs_sm80.cu -@@ -0,0 +1,175 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide HEMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_rs_l_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_rs_u_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_rs_u_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ 
cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..9a11549 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu -@@ -0,0 +1,135 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide HEMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Hemm_cf64h_cf64n_ls_l_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Hemm_cf64h_cf64n_rs_u_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64h_cf64n_cf64n_tensor_op_ls_f64_gaussian_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64h_cf64n_cf64n_tensor_op_ls_f64_gaussian_sm80.cu -new file mode 100644 -index 0000000..882bbf2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64h_cf64n_cf64n_tensor_op_ls_f64_gaussian_sm80.cu -@@ -0,0 +1,175 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HEMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf64h_cf64n_ls_l_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf64h_cf64n_ls_u_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf64h_cf64n_ls_u_tensor_op_f64_gaussian, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64h_cf64n_cf64n_tensor_op_ls_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64h_cf64n_cf64n_tensor_op_ls_f64_sm80.cu -new file mode 100644 -index 0000000..4b4b166 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64h_cf64n_cf64n_tensor_op_ls_f64_sm80.cu -@@ -0,0 +1,175 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HEMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf64h_cf64n_ls_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf64h_cf64n_ls_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ 
cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf64h_cf64n_ls_u_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64h_cf64n_cf64n_tensor_op_rs_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64h_cf64n_cf64n_tensor_op_rs_f64_sm80.cu -new file mode 100644 -index 0000000..d6d1690 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64h_cf64n_cf64n_tensor_op_rs_f64_sm80.cu -@@ -0,0 +1,175 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HEMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf64h_cf64n_rs_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf64h_cf64n_rs_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf64h_cf64n_rs_u_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..1764e32 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_f32_sm80.cu -@@ -0,0 +1,149 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide HER2K interface -+*/ -+ -+#include <iostream> -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2k_cf32n_cf32n_l_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex<float>; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex<float>; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex<float>; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex<float>; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal<Rank2K>()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2k_cf32h_cf32n_l_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex<float>; -+ using LayoutA = cutlass::layout::RowMajor; -+ -+ using ElementB = cutlass::complex<float>; -+ using LayoutB = cutlass::layout::RowMajor; -+ -+ using ElementC = cutlass::complex<float>; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex<float>; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal<Rank2K>()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git 
a/3rdparty/cutlass/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu -new file mode 100644 -index 0000000..926ed0a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu -@@ -0,0 +1,149 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide HER2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2k_cf32n_cf32n_l_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2k_cf32h_cf32n_l_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git 
a/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64_cf64_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64_cf64_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..fbc4efd ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64_cf64_tensor_op_f64_sm90.cu -@@ -0,0 +1,149 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide HER2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Her2k_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Her2k_cf64c_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git 
a/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64h_cf64n_tensor_op_f64_grouped_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64h_cf64n_tensor_op_f64_grouped_sm80.cu -new file mode 100644 -index 0000000..4697598 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64h_cf64n_tensor_op_f64_grouped_sm80.cu -@@ -0,0 +1,310 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for grouped Rank2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_grouped_rank_2k.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+// NOTE: HER2K requires that LayoutA == LayoutB, and that LayoutC == ColumnMajor -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64h_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kConjugate, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kConjugate, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64h_cf64n_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kConjugate, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kConjugate, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ 
test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64h_cf64n_l_tensor_op_f64, 32x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kConjugate, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kConjugate, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64h_cf64n_l_tensor_op_f64, 64x32x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kConjugate, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kConjugate, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64h_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kConjugate, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kConjugate, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ 
ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64h_cf64n_u_tensor_op_f64, 32x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kConjugate, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kConjugate, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64h_cf64n_u_tensor_op_f64, 64x32x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kConjugate, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kConjugate, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff 
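// The grouped HER2K hunks above show the same extraction damage: the element
// type after cutlass::complex, the epilogue's template arguments, and the
// template parameters of Rank2KGrouped and TestbedGrouped have been stripped.
// A sketch of the conjugate, row-major case, assuming cutlass::complex<double>
// elements:

#include "cutlass/blas3.h"
#include "cutlass/gemm/kernel/default_rank_2k_grouped.h"
#include "cutlass/gemm/device/rank_2k_grouped.h"

using Element = cutlass::complex<double>;   // assumed: "cf64" element type

using Rank2Kkernel_Sketch = typename cutlass::gemm::kernel::DefaultRank2KGrouped<
    Element, cutlass::layout::RowMajor, cutlass::ComplexTransform::kConjugate, 1,
    Element, cutlass::layout::RowMajor, cutlass::ComplexTransform::kConjugate, 1,
    Element, cutlass::layout::ColumnMajor, cutlass::FillMode::kLower,
    Element,                                 // accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<32, 32, 16>,    // threadblock tile
    cutlass::gemm::GemmShape<16, 16, 16>,    // warp tile
    cutlass::gemm::GemmShape<8, 8, 4>,       // instruction shape
    cutlass::epilogue::thread::LinearCombination<Element, 1, Element, Element>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,                                       // kStages
    cutlass::arch::OpMultiplyAddComplex,
    cutlass::BlasMode::kHermitian>::Rank2Kkernel;

// The device-level wrapper takes the kernel type, and the testbed is presumably
// parameterized on that wrapper; inside a TEST body:
//   using Rank2K = cutlass::gemm::device::Rank2KGrouped<Rank2Kkernel_Sketch>;
//   test::gemm::device::TestbedGrouped<Rank2K> testbed;
//   bool passed = testbed.run(24);          // 24 grouped problem sizes, as above
//   EXPECT_TRUE(passed);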
--git a/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64n_cf64n_tensor_op_f64_grouped_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64n_cf64n_tensor_op_f64_grouped_sm80.cu -new file mode 100644 -index 0000000..c7dca8c ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64n_cf64n_tensor_op_f64_grouped_sm80.cu -@@ -0,0 +1,310 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for grouped Rank2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_grouped_rank_2k.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+// NOTE: HER2K requires that LayoutA == LayoutB, and that LayoutC == ColumnMajor -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64n_cf64n_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ 
test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64n_cf64n_l_tensor_op_f64, 64x32x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64n_cf64n_l_tensor_op_f64, 32x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64n_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ 
ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64n_cf64n_u_tensor_op_f64, 64x32x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64n_cf64n_u_tensor_op_f64, 32x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git 
a/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64n_cf64n_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64n_cf64n_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..717d3f9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64n_cf64n_tensor_op_f64_sm80.cu -@@ -0,0 +1,149 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide HER2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2k_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2k_cf64h_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git 
a/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64n_cf64t_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64n_cf64t_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..3a65931 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64n_cf64t_tensor_op_f64_sm80.cu -@@ -0,0 +1,201 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide HER2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#if 0 // HER2K with RowMajor output is not supported -+TEST(SM80_Device_Her2k_cf64n_cf64t_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ false, // IsBetaZero -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2k_cf64c_cf64t_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ false, // IsBetaZero -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ 
-+TEST(SM80_Device_Her2k_cf64h_cf64t_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ false, // IsBetaZero -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal()); -+} -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/herk_cf32h_cf32n_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/herk_cf32h_cf32n_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..a6503d1 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/herk_cf32h_cf32n_tensor_op_f32_sm80.cu -@@ -0,0 +1,219 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HERK interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// HERK operator on CUBLAS_OP_N (column-major) input layouts -+TEST(SM80_Device_Herk_cf32n_cf32n_l_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// HERK operator on CUBLAS_OP_N (column-major) input layouts -+TEST(SM80_Device_Herk_cf32n_cf32n_u_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ 
cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// HERK operator on CUBLAS_OP_C (row-major + conj) input layouts -+TEST(SM80_Device_Herk_cf32h_cf32n_l_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// HERK operator on CUBLAS_OP_C (row-major + conj) input layouts -+TEST(SM80_Device_Herk_cf32h_cf32n_u_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/herk_cf32h_cf32n_tensor_op_fast_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/herk_cf32h_cf32n_tensor_op_fast_f32_sm80.cu -new file mode 100644 -index 0000000..56ef601 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/herk_cf32h_cf32n_tensor_op_fast_f32_sm80.cu -@@ -0,0 +1,219 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HERK interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// HERK operator on CUBLAS_OP_N (column-major) input layouts -+TEST(SM80_Device_Herk_cf32n_cf32n_l_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ 
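// The HERK rank-k update tests above have likewise lost their element-type
// arguments. A sketch of the column-major, lower-triangular fast-F32 case,
// assuming the "cf32" naming implies cutlass::complex<float> elements:

#include "cutlass/blas3.h"
#include "cutlass/gemm/device/rank_k.h"

using ElementHerk = cutlass::complex<float>;    // assumed: "cf32" element type

using RankK_Sketch = cutlass::gemm::device::RankK<
    ElementHerk, cutlass::layout::ColumnMajor,  // A
    ElementHerk, cutlass::layout::ColumnMajor,  // C
    cutlass::FillMode::kLower,                  // update the lower triangle
    ElementHerk,                                // accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<64, 64, 16>,       // threadblock tile
    cutlass::gemm::GemmShape<32, 32, 16>,       // warp tile
    cutlass::gemm::GemmShape<16, 8, 8>,         // instruction shape
    cutlass::epilogue::thread::LinearCombination<ElementHerk, 1, ElementHerk, ElementHerk>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4,                                          // kStages
    1,                                          // AlignmentA
    false,                                      // SplitKSerial
    cutlass::arch::OpMultiplyAddComplexFastF32, // fast-F32 math, as named above
    cutlass::ComplexTransform::kNone,
    cutlass::BlasMode::kHermitian>;

// Inside the TEST body, presumably:
//   EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal<RankK_Sketch>());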
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// HERK operator on CUBLAS_OP_N (column-major) input layouts -+TEST(SM80_Device_Herk_cf32n_cf32n_u_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// HERK operator on CUBLAS_OP_C (row-major + conj) input layouts -+TEST(SM80_Device_Herk_cf32h_cf32n_l_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// HERK operator on CUBLAS_OP_C (row-major + conj) input layouts -+TEST(SM80_Device_Herk_cf32h_cf32n_u_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ 
cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/herk_cf64_cf64_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/herk_cf64_cf64_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..114a20c ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/herk_cf64_cf64_tensor_op_f64_sm90.cu -@@ -0,0 +1,93 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide HERK interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// HERK operator on CUBLAS_OP_C (row-major + conj) input layouts -+TEST(SM90_Device_Herk_cf64h_cf64n_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/herk_cf64h_cf64n_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/herk_cf64h_cf64n_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..71d8a9c ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/herk_cf64h_cf64n_tensor_op_f64_sm80.cu -@@ -0,0 +1,175 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HERK interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// HERK operator on CUBLAS_OP_N (column-major) input layouts -+TEST(SM80_Device_Herk_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// HERK operator on CUBLAS_OP_N (column-major) input layouts -+TEST(SM80_Device_Herk_cf64n_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ 
cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// HERK operator on CUBLAS_OP_C (row-major + conj) input layouts -+TEST(SM80_Device_Herk_cf64h_cf64n_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/multistage_testbed.h b/3rdparty/cutlass/test/unit/gemm/device/multistage_testbed.h -new file mode 100644 -index 0000000..681e051 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/multistage_testbed.h -@@ -0,0 +1,301 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_utils.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MultistageTestbed { -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementC = typename Gemm::ElementC; -+ -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ using ElementCompute = -+ typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ // -+ // Methods -+ // -+ -+ MultistageTestbed( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080) -+ : init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {} -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor(cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, uint64_t seed) { -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ int scope = (cutlass::sizeof_bits::value == 8) ? 
2 : 8;
-+      cutlass::reference::host::TensorFillRandomUniform(view, seed, scope,
-+                                                        -scope, 0);
-+    } else if (dist_kind == cutlass::Distribution::Gaussian) {
-+      cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, -1);
-+    } else if (dist_kind == cutlass::Distribution::Identity) {
-+      cutlass::reference::host::TensorFillIdentity(view);
-+    } else if (dist_kind == cutlass::Distribution::Sequential) {
-+      cutlass::reference::host::BlockFillSequential(view.data(),
-+                                                    view.capacity());
-+    } else {
-+      // TODO: Implement the rest
-+      EXPECT_TRUE(false) << "Not implemented";
-+      return false;
-+    }
-+
-+    return true;
-+  }
-+
-+  /// Waives test if CUDA device is insufficient
-+  bool sufficient() const {
-+    //
-+    // Determine SMEM requirements and waive if not satisfied
-+    //
-+
-+    int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage));
-+
-+    cudaDeviceProp properties;
-+    int device_idx;
-+    cudaError_t result = cudaGetDevice(&device_idx);
-+
-+    if (result != cudaSuccess) {
-+      throw std::runtime_error("cudaGetDevice() API call failed.");
-+    }
-+
-+    result = cudaGetDeviceProperties(&properties, device_idx);
-+
-+    if (result != cudaSuccess) {
-+      throw std::runtime_error("cudaGetDeviceProperties() failed");
-+    }
-+
-+    if (properties.sharedMemPerBlockOptin < smem_size) {
-+      return false;
-+    }
-+
-+    return true;
-+  }
-+
-+  /// Executes one test
-+  bool run(cutlass::gemm::GemmCoord problem_size,
-+           ElementCompute alpha = ElementCompute(1),
-+           ElementCompute beta = ElementCompute(0)) {
-+
-+    // Waives test if CUDA device is insufficient
-+    if (!sufficient()) {
-+      return true;
-+    }
-+
-+    //
-+    // Allocate the GEMM workspace
-+    //
-+
-+    cutlass::HostTensor<typename Gemm::ElementA, typename Gemm::LayoutA>
-+        tensor_A(problem_size.mk());
-+
-+    cutlass::HostTensor<typename Gemm::ElementB, typename Gemm::LayoutB>
-+        tensor_B(problem_size.kn());
-+
-+    cutlass::HostTensor<typename Gemm::ElementC, typename Gemm::LayoutC>
-+        tensor_C(problem_size.mn());
-+
-+    cutlass::HostTensor<typename Gemm::ElementC, typename Gemm::LayoutC>
-+        tensor_D(problem_size.mn());
-+
-+    cutlass::HostTensor<typename Gemm::ElementC, typename Gemm::LayoutC>
-+        reference_D(problem_size.mn(), false);
-+
-+    EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019));
-+    EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018));
-+    EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017));
-+
-+    cutlass::reference::host::TensorCopy(reference_D.host_view(),
-+                                         tensor_C.host_view());
-+
-+    tensor_A.sync_device();
-+    tensor_B.sync_device();
-+    tensor_C.sync_device();
-+    tensor_D.sync_device();
-+
-+    //
-+    // Initialize the GEMM operator
-+    //
-+
-+    typename Gemm::Arguments arguments{
-+        problem_size, tensor_A.device_ref(), tensor_B.device_ref(),
-+        tensor_C.device_ref(), tensor_D.device_ref(), {alpha, beta}};
-+
-+    Gemm gemm_op;
-+
-+    cutlass::Status status = gemm_op.initialize(arguments);
-+
-+    if (status != cutlass::Status::kSuccess) {
-+      cudaError_t error = cudaGetLastError();
-+      std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
-+      return true;
-+    }
-+
-+    //
-+    // Run the GEMM
-+    //
-+
-+    status = gemm_op();
-+
-+    EXPECT_TRUE(status == cutlass::Status::kSuccess);
-+
-+    //
-+    // Verify
-+    //
-+
-+    cutlass::reference::host::Gemm<
-+        typename Gemm::ElementA, typename Gemm::LayoutA,
-+        typename Gemm::ElementB, typename Gemm::LayoutB,
-+        typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute,
-+        ElementAccumulator, typename Gemm::Operator>
-+        reference_gemm;
-+
-+    reference_gemm(
-+        problem_size, alpha, tensor_A.host_ref(), tensor_B.host_ref(), beta,
-+        reference_D.host_ref(), ElementAccumulator(0));
-+
-+    tensor_D.sync_host();
-+
-+    EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0);
-+    EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0);
-+
-+    bool passed = cutlass::reference::host::TensorEquals(
-+        reference_D.host_view(), tensor_D.host_view());
-+
-+    EXPECT_TRUE(passed);
-+    if (!passed) {
-+      std::stringstream fname;
-+
-+      fname << "error_Gemm_device_" << problem_size.m() << "x"
-+            << problem_size.n() << "x" << problem_size.k() << "_"
-+            << Gemm::ThreadblockShape::kM << "x" << Gemm::ThreadblockShape::kN
-+            << "x" << Gemm::ThreadblockShape::kK << "_" << Gemm::WarpShape::kM
-+            << "x" << Gemm::WarpShape::kN << "x" << Gemm::WarpShape::kK
-+            << ".txt";
-+
-+      std::ofstream file(fname.str());
-+
-+      file << "problem: " << problem_size << ", alpha: " << alpha
-+           << ", beta: " << beta << "\n\n";
-+
-+      file << "A =\n"
-+           << tensor_A.host_view() << "\nB =\n"
-+           << tensor_B.host_view() << "\nC =\n"
-+           << tensor_C.host_view() << "\n\nReference =\n"
-+           << reference_D.host_view() << "\nComputed =\n"
-+           << tensor_D.host_view();
-+    }
-+
-+    return passed;
-+  }
-+
-+  /// Runs a set of problem sizes
-+  bool run_all() {
-+    bool passed = true;
-+
-+    int problem_size_m[] = {16, 528};
-+
-+    int problem_size_n[] = {16, 528};
-+
-+    int problem_size_k[] = {Gemm::InstructionShape::kK,
-+                            Gemm::ThreadblockShape::kK * Gemm::kStages +
-+                                Gemm::InstructionShape::kK};
-+
-+    double problem_alpha[] = {1.0};
-+
-+    // TODO Try non zero beta value after multistaged epilogue is implemented
-+    double problem_beta[] = {0.0};
-+
-+    for (int m : problem_size_m) {
-+      for (int n : problem_size_n) {
-+        for (int k : problem_size_k) {
-+          for (double alpha : problem_alpha) {
-+            for (double beta : problem_beta) {
-+              passed =
-+                  run({m, n, k}, ElementCompute(alpha), ElementCompute(beta));
-+
-+              if (!passed) {
-+                return false;
-+              }
-+            }
-+          }
-+        }
-+      }
-+    }
-+
-+    return true;
-+  }
-+};
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+}  // namespace device
-+}  // namespace gemm
-+}  // namespace test
-+
-+////////////////////////////////////////////////////////////////////////////////
-diff --git a/3rdparty/cutlass/test/unit/gemm/device/multistage_testbed_interleaved.h b/3rdparty/cutlass/test/unit/gemm/device/multistage_testbed_interleaved.h
-new file mode 100644
-index 0000000..5f33206
---- /dev/null
-+++ b/3rdparty/cutlass/test/unit/gemm/device/multistage_testbed_interleaved.h
-@@ -0,0 +1,349 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/host_reorder.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MultistageInterleavedTestbed { -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementC = typename Gemm::ElementC; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ // -+ // Methods -+ // -+ -+ MultistageInterleavedTestbed( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, 2, -2, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. 
-+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerMultiprocessor < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementA, -+ typename Gemm::LayoutA> tensor_A(problem_size.mk()); -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementB, -+ typename Gemm::LayoutB> tensor_B(problem_size.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementB, -+ typename Gemm::LayoutB> tensor_B_reordered(problem_size.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementC, -+ typename Gemm::LayoutC> tensor_C(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementC, -+ typename Gemm::LayoutC> tensor_D(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementC, -+ typename Gemm::LayoutC> reference_D(problem_size.mn(), false); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); -+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); -+ -+ cutlass::reorder_column( -+ tensor_B_reordered.host_ref(), tensor_B.host_ref(), problem_size); -+ -+ cutlass::reference::host::TensorCopy( -+ reference_D.host_view(), -+ tensor_C.host_view()); -+ -+ tensor_A.sync_device(); -+ tensor_B_reordered.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm::Arguments arguments{ -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B_reordered.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D.device_ref(), -+ {alpha, beta} -+ }; -+ -+ Gemm gemm_op; -+ -+ cutlass::Status status = gemm_op.initialize(arguments); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ -+ // -+ // Run the GEMM -+ // -+ -+ status = gemm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::host::Gemm< -+ typename Gemm::ElementA, typename Gemm::LayoutA, -+ typename Gemm::ElementB, typename Gemm::LayoutB, -+ typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, -+ ElementAccumulator, typename Gemm::Operator> -+ reference_gemm; -+ -+ reference_gemm( -+ problem_size, -+ alpha, -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ beta, -+ reference_D.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ -+ bool passed = 
cutlass::reference::host::TensorEquals( -+ reference_D.host_view(), -+ tensor_D.host_view()); -+ -+ EXPECT_TRUE(passed); -+ if (!passed) { -+ -+ std::stringstream fname; -+ -+ fname << "error_Gemm_device_" -+ << problem_size.m() << "x" -+ << problem_size.n() << "x" -+ << problem_size.k() << "_" -+ << Gemm::ThreadblockShape::kM << "x" -+ << Gemm::ThreadblockShape::kN << "x" -+ << Gemm::ThreadblockShape::kK << "_" -+ << Gemm::WarpShape::kM << "x" -+ << Gemm::WarpShape::kN << "x" -+ << Gemm::WarpShape::kK << ".txt"; -+ -+ std::ofstream file(fname.str()); -+ -+ file -+ << "problem: " << problem_size -+ << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; -+ -+ file -+ << "A =\n" << tensor_A.host_view() -+ << "\nB =\n" << tensor_B.host_view() -+ << "\nB_reordered =\n" << tensor_B_reordered.host_view() -+ << "\nC =\n" << tensor_C.host_view() -+ << "\n\nReference =\n" << reference_D.host_view() -+ << "\nComputed =\n" << tensor_D.host_view(); -+ } -+ -+ return passed; -+ } -+ -+ /// Runs a set of problem sizes -+ bool run_all() { -+ bool passed = true; -+ -+ int problem_size_m[] = { -+ InterleavedK, 512 + InterleavedK -+ }; -+ -+ int problem_size_n[] = { -+ InterleavedK, 512 + InterleavedK -+ }; -+ -+ int problem_size_k[] = { -+ InterleavedK, Gemm::ThreadblockShape::kK * Gemm::kStages + InterleavedK -+ }; -+ -+ double problem_alpha[] = { -+ 1.0 -+ }; -+ -+ double problem_beta[] = { -+ 0.0 -+ }; -+ -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ for (double alpha : problem_alpha) { -+ for (double beta : problem_beta) { -+ -+ passed = run( -+ {m, n, k}, -+ ElementCompute(alpha), -+ ElementCompute(beta) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return true; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/rank_2k_grouped_scheduler_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/rank_2k_grouped_scheduler_sm80.cu -new file mode 100644 -index 0000000..021182e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/rank_2k_grouped_scheduler_sm80.cu -@@ -0,0 +1,234 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for grouped Rank2K problem visitors -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_grouped.h" -+#include "cutlass/gemm/kernel/default_gemm_grouped.h" -+#include "cutlass/gemm/device/gemm_grouped.h" -+ -+#include "testbed_grouped_rank_2k_scheduler.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Run a series of tests on the testbed -+template -+void run_tests(bool skip_tile_check=false) { -+ for (int scale_factor : {8, 16, 32, 64}) { -+ for (int threadblock_count : {54, 108, 216, 324, 432}) { -+ for (int problems : {1, 27, 180, 300}) { -+ Testbed testbed(skip_tile_check); -+ testbed.run(problems, threadblock_count, scale_factor); -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Rank2KGroupedScheduler_p128_t128_l, 64x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ static int const kNumPrefetch = 128; -+ static int const kThreadCount = 128; -+ static cutlass::FillMode const kFillModeC = cutlass::FillMode::kLower; -+ -+ using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kFillModeC, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Rank2KGroupedScheduler_p128_t128_u, 64x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ static int const kNumPrefetch = 128; -+ static int const kThreadCount = 128; -+ static cutlass::FillMode const kFillModeC = cutlass::FillMode::kUpper; -+ -+ using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kFillModeC, -+ // List of GroupScheduleModes to compare. List must contain at least two. 
-+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Rank2KGroupedScheduler_p256_t256_l, 64x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 256; -+ static cutlass::FillMode const kFillModeC = cutlass::FillMode::kLower; -+ -+ using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kFillModeC, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Rank2KGroupedScheduler_p256_t128_l, 64x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 128; -+ static cutlass::FillMode const kFillModeC = cutlass::FillMode::kLower; -+ -+ using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kFillModeC, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Rank2KGroupedScheduler_p256_t256_l, 64x32x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 256; -+ static cutlass::FillMode const kFillModeC = cutlass::FillMode::kLower; -+ -+ using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kFillModeC, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ -+ // Skip individual tile check for the non-square SYR2K versions. We still -+ // compare the problem visitors with one another -+ run_tests(/*skip_tile_check=*/true); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Rank2KGroupedScheduler_p256_t256_u, 64x32x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 256; -+ static cutlass::FillMode const kFillModeC = cutlass::FillMode::kUpper; -+ -+ using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kFillModeC, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ -+ // Skip individual tile check for the non-square SYR2K versions. 
We still -+ // compare the problem visitors with one another -+ run_tests(/*skip_tile_check=*/true); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Rank2KGroupedScheduler_p256_t256_l, 32x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 256; -+ static cutlass::FillMode const kFillModeC = cutlass::FillMode::kLower; -+ -+ using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kFillModeC, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ -+ // Skip individual tile check for the non-square SYR2K versions. We still -+ // compare the problem visitors with one another -+ run_tests(/*skip_tile_check=*/true); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Rank2KGroupedScheduler_p256_t256_u, 32x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 256; -+ static cutlass::FillMode const kFillModeC = cutlass::FillMode::kUpper; -+ -+ using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kFillModeC, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ -+ // Skip individual tile check for the non-square SYR2K versions. We still -+ // compare the problem visitors with one another -+ run_tests(/*skip_tile_check=*/true); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_nn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_nn_sm50.cu -new file mode 100644 -index 0000000..51632bb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_nn_sm50.cu -@@ -0,0 +1,1131 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ 
cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_cgemm_nn, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 
x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 
-+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = 
cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = 
cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_cgemm_nn, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using 
InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, 
precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, 
cutlass::layout::ColumnMajor,
-+        precision, cutlass::layout::ColumnMajor,
-+        precision, cutlass::layout::RowMajor,
-+        precision,
-+        cutlass::arch::OpClassSimt,
-+        cutlass::arch::Sm50,
-+        ThreadblockShape, WarpShape, InstructionShape,
-+        EpilogueOutputOp,
-+        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+        2 // Stages
-+    >;
-+    EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+} )
-+
-diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_nt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_nt_sm50.cu
-new file mode 100644
-index 0000000..512fcbc
---- /dev/null
-+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_nt_sm50.cu
-@@ -0,0 +1,1311 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*! \file
-+    \brief Tests for device-wide GEMM interface
-+*/
-+
-+#include <iostream>
-+
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm.h"
-+#include "cutlass/numeric_types.h"
-+
-+#include "../../common/cutlass_unit_test.h"
-+
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/tensor_view_io.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+
-+#include "testbed.h"
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 2 x 4
-+// Threads / Warp: 4 x 8
-+// Warps / Block: 1 x 1
-+// Threadblock: 8 x 32 x 8
-+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, {
-+    using precision = cutlass::complex<float>;
-+    using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>;
-+    using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>;
-+
-+    static int const kEpilogueElementsPerAccess = 1;
-+    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+    using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+        precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+    using Gemm = cutlass::gemm::device::Gemm<
-+        precision, cutlass::layout::ColumnMajor,
-+        precision, cutlass::layout::RowMajor,
-+        precision, cutlass::layout::RowMajor,
-+        precision,
-+        cutlass::arch::OpClassSimt,
-+        cutlass::arch::Sm50,
-+        ThreadblockShape, WarpShape, InstructionShape,
-+        EpilogueOutputOp,
-+        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+        2 // Stages
-+    >;
-+    EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+} )
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 4 x 4
-+// Threads / Warp: 4 x 8
-+// Warps / Block: 1 x 1
-+// Threadblock: 16 x 32 x 8
-+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, {
-+    using precision = cutlass::complex<float>;
-+    using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>;
-+    using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>;
-+
-+    static int const kEpilogueElementsPerAccess = 1;
-+    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+    using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+        precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+    using Gemm = cutlass::gemm::device::Gemm<
-+        precision, cutlass::layout::ColumnMajor,
-+        precision, cutlass::layout::RowMajor,
-+        precision, cutlass::layout::RowMajor,
-+        precision,
-+        cutlass::arch::OpClassSimt,
-+        cutlass::arch::Sm50,
-+        ThreadblockShape, WarpShape, InstructionShape,
-+        EpilogueOutputOp,
-+        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+        2 // Stages
-+    >;
-+    EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+} )
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 4 x 8
-+// Threads / Warp: 4 x 8
-+// Warps / Block: 1 x 1
-+// Threadblock: 16 x 64 x 8
-+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 16x64x8_16x64x1_4x8_4x8_1x1, {
-+    using precision = cutlass::complex<float>;
-+    using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>;
-+    using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>;
-+
-+    static int const kEpilogueElementsPerAccess = 1;
-+    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+    using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+        precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+    using Gemm =
cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ 
precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ 
ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_cgemm_nt, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 
-+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 16x128x16_8x32x1_2x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 
-+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::complex; 
-+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_cgemm_nt, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 
16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = 
cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x128x16_8x32x1_2x4_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; 
-+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 128x32x16_32x8x1_4x2_8x4_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_nt_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_nt_sm80.cu -new file mode 100644 -index 0000000..805937a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_nt_sm80.cu -@@ -0,0 +1,265 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*! \file
-+ \brief Tests for device-wide GEMM interface
-+*/
-+
-+#include <iostream>
-+
-+#include "../../common/cutlass_unit_test.h"
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm_complex.h"
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/tensor_view_io.h"
-+
-+#include "testbed_complex.h"
-+
-+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
-+////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 32x64x8_32x64x1) {
-+
-+ using Element = cutlass::complex<float>;
-+
-+ using Gemm = cutlass::gemm::device::GemmComplex<
-+ Element,
-+ cutlass::layout::ColumnMajor,
-+ Element,
-+ cutlass::layout::RowMajor,
-+ Element,
-+ cutlass::layout::RowMajor,
-+ Element,
-+ cutlass::arch::OpClassSimt,
-+ cutlass::arch::Sm80,
-+ cutlass::gemm::GemmShape<32, 64, 8>,
-+ cutlass::gemm::GemmShape<32, 32, 8>,
-+ cutlass::gemm::GemmShape<1, 1, 1>,
-+ cutlass::epilogue::thread::LinearCombination<
-+ Element,
-+ 1,
-+ Element,
-+ Element>,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+ 4
-+ >;
-+
-+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
-+}
-+
-+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 64x64x8_32x64x1) {
-+
-+ using Element = cutlass::complex<float>;
-+
-+ using Gemm = cutlass::gemm::device::GemmComplex<
-+ Element,
-+ cutlass::layout::ColumnMajor,
-+ Element,
-+ cutlass::layout::RowMajor,
-+ Element,
-+ cutlass::layout::RowMajor,
-+ Element,
-+ cutlass::arch::OpClassSimt,
-+ cutlass::arch::Sm80,
-+ cutlass::gemm::GemmShape<64, 64, 8>,
-+ cutlass::gemm::GemmShape<32, 64, 8>,
-+ cutlass::gemm::GemmShape<1, 1, 1>,
-+ cutlass::epilogue::thread::LinearCombination<
-+ Element,
-+ 1,
-+ Element,
-+ Element>,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+ 3
-+ >;
-+
-+ 
EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 128x128x8_32x64x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 64x128x8_32x64x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 128x64x8_32x64x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 128x128x8_64x64x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 128x256x8_64x64x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 8>, -+ cutlass::gemm::GemmShape<64, 
64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_tn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_tn_sm50.cu -new file mode 100644 -index 0000000..7405802 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_tn_sm50.cu -@@ -0,0 +1,1131 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file
-+ \brief Tests for device-wide GEMM interface
-+*/
-+
-+#include <iostream>
-+
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm.h"
-+#include "cutlass/numeric_types.h"
-+
-+#include "../../common/cutlass_unit_test.h"
-+
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/tensor_view_io.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+
-+#include "testbed.h"
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 2 x 4
-+// Threads / Warp: 4 x 8
-+// Warps / Block: 1 x 1
-+// Threadblock: 8 x 32 x 8
-+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, {
-+ using precision = cutlass::complex<float>;
-+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>;
-+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>;
-+
-+ static int const kEpilogueElementsPerAccess = 1;
-+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+ precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+ using Gemm = cutlass::gemm::device::Gemm<
-+ precision, cutlass::layout::RowMajor,
-+ precision, cutlass::layout::ColumnMajor,
-+ precision, cutlass::layout::RowMajor,
-+ precision,
-+ cutlass::arch::OpClassSimt,
-+ cutlass::arch::Sm50,
-+ ThreadblockShape, WarpShape, InstructionShape,
-+ EpilogueOutputOp,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+ 2 // Stages
-+ >;
-+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+} )
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 4 x 4
-+// Threads / Warp: 4 x 8
-+// Warps / Block: 1 x 1
-+// Threadblock: 16 x 32 x 8
-+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, {
-+ using precision = cutlass::complex<float>;
-+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>;
-+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>;
-+
-+ static int const kEpilogueElementsPerAccess = 1;
-+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+ precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+ using Gemm = cutlass::gemm::device::Gemm<
-+ precision, cutlass::layout::RowMajor,
-+ precision, cutlass::layout::ColumnMajor,
-+ precision, cutlass::layout::RowMajor,
-+ precision,
-+ cutlass::arch::OpClassSimt,
-+ cutlass::arch::Sm50,
-+ ThreadblockShape, WarpShape, InstructionShape,
-+ EpilogueOutputOp,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+ 2 // Stages
-+ >;
-+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+} )
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 4 x 8
-+// Threads / Warp: 4 x 8
-+// Warps / Block: 1 x 1
-+// Threadblock: 16 x 64 x 8
-+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 16x64x8_16x64x1_4x8_4x8_1x1, {
-+ using precision = cutlass::complex<float>;
-+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>;
-+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>;
-+
-+ static int const kEpilogueElementsPerAccess = 1;
-+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+ precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+ using Gemm =
cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ 
precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ 
ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_cgemm_tn, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 
-+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 
-+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_cgemm_tn, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ 
using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 
16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = 
cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_tn_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_tn_sm80.cu -new file mode 100644 -index 0000000..cfb3764 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_tn_sm80.cu -@@ -0,0 +1,269 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*! \file
-+    \brief Tests for device-wide GEMM interface
-+
-+*/
-+
-+#include <iostream>
-+
-+#include "../../common/cutlass_unit_test.h"
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm_complex.h"
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/tensor_view_io.h"
-+
-+#include "testbed_complex.h"
-+
-+
-+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
-+////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 32x64x8_32x64x1) {
-+
-+  using Element = cutlass::complex<float>;
-+
-+  using Gemm = cutlass::gemm::device::GemmComplex<
-+    Element,
-+    cutlass::layout::RowMajor,
-+    Element,
-+    cutlass::layout::ColumnMajor,
-+    Element,
-+    cutlass::layout::RowMajor,
-+    Element,
-+    cutlass::arch::OpClassSimt,
-+    cutlass::arch::Sm80,
-+    cutlass::gemm::GemmShape<32, 64, 8>,
-+    cutlass::gemm::GemmShape<32, 32, 8>,
-+    cutlass::gemm::GemmShape<1, 1, 1>,
-+    cutlass::epilogue::thread::LinearCombination<
-+      Element,
-+      1,
-+      Element,
-+      Element>,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+    4
-+  >;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
-+}
-+
-+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 64x64x8_32x64x1) {
-+
-+  using Element = cutlass::complex<float>;
-+
-+  using Gemm = cutlass::gemm::device::GemmComplex<
-+    Element,
-+    cutlass::layout::RowMajor,
-+    Element,
-+    cutlass::layout::ColumnMajor,
-+    Element,
-+    cutlass::layout::RowMajor,
-+    Element,
-+    cutlass::arch::OpClassSimt,
-+    cutlass::arch::Sm80,
-+    cutlass::gemm::GemmShape<64, 64, 8>,
-+    cutlass::gemm::GemmShape<32, 64, 8>,
-+
cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 128x128x8_32x64x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 64x128x8_32x64x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 128x64x8_64x32x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 128x128x8_64x64x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 128x256x8_64x64x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = 
cutlass::gemm::device::GemmComplex<
-+    Element,
-+    cutlass::layout::RowMajor,
-+    Element,
-+    cutlass::layout::ColumnMajor,
-+    Element,
-+    cutlass::layout::RowMajor,
-+    Element,
-+    cutlass::arch::OpClassSimt,
-+    cutlass::arch::Sm80,
-+    cutlass::gemm::GemmShape<128, 256, 8>,
-+    cutlass::gemm::GemmShape<64, 64, 8>,
-+    cutlass::gemm::GemmShape<1, 1, 1>,
-+    cutlass::epilogue::thread::LinearCombination<
-+      Element,
-+      1,
-+      Element,
-+      Element>,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+    3,
-+    cutlass::ComplexTransform::kConjugate,
-+    cutlass::ComplexTransform::kConjugate
-+  >;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
-+}
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_tt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_tt_sm50.cu
-new file mode 100644
-index 0000000..3c232f1
---- /dev/null
-+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_tt_sm50.cu
-@@ -0,0 +1,1130 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*! \file
-+    \brief Tests for device-wide GEMM interface
-+*/
-+
-+#include <iostream>
-+
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm.h"
-+#include "cutlass/numeric_types.h"
-+
-+#include "../../common/cutlass_unit_test.h"
-+
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/tensor_view_io.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+
-+#include "testbed.h"
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 2 x 4
-+// Threads / Warp: 4 x 8
-+// Warps / Block: 1 x 1
-+// Threadblock: 8 x 32 x 8
-+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, {
-+    using precision = cutlass::complex<float>;
-+    using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>;
-+    using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>;
-+
-+    static int const kEpilogueElementsPerAccess = 1;
-+    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+    using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+        precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+    using Gemm = cutlass::gemm::device::Gemm<
-+        precision, cutlass::layout::RowMajor,
-+        precision, cutlass::layout::RowMajor,
-+        precision, cutlass::layout::RowMajor,
-+        precision,
-+        cutlass::arch::OpClassSimt,
-+        cutlass::arch::Sm50,
-+        ThreadblockShape, WarpShape, InstructionShape,
-+        EpilogueOutputOp,
-+        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+        2 // Stages
-+    >;
-+    EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+} )
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 4 x 4
-+// Threads / Warp: 4 x 8
-+// Warps / Block: 1 x 1
-+// Threadblock: 16 x 32 x 8
-+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, {
-+    using precision = cutlass::complex<float>;
-+    using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>;
-+    using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>;
-+
-+    static int const kEpilogueElementsPerAccess = 1;
-+    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+    using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+        precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+    using Gemm = cutlass::gemm::device::Gemm<
-+        precision, cutlass::layout::RowMajor,
-+        precision, cutlass::layout::RowMajor,
-+        precision, cutlass::layout::RowMajor,
-+        precision,
-+        cutlass::arch::OpClassSimt,
-+        cutlass::arch::Sm50,
-+        ThreadblockShape, WarpShape, InstructionShape,
-+        EpilogueOutputOp,
-+        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+        2 // Stages
-+    >;
-+    EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+} )
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 4 x 8
-+// Threads / Warp: 4 x 8
-+// Warps / Block: 1 x 1
-+// Threadblock: 16 x 64 x 8
-+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 16x64x8_16x64x1_4x8_4x8_1x1, {
-+    using precision = cutlass::complex<float>;
-+    using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>;
-+    using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>;
-+
-+    static int const kEpilogueElementsPerAccess = 1;
-+    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+    using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+        precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+    using Gemm =
cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, 
cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, 
WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // 
Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_cgemm_tt, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
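(Editor's aside, not part of the deleted patch: the digits in these SM50 SIMT test names and the comment block above each CUTLASS_TEST_* macro describe the same tile decomposition — the warp tile is elements-per-thread times threads-per-warp in each dimension, and the threadblock tile is the warp tile times warps-per-block. A minimal standalone C++ sketch of that arithmetic for the 16x64x8_8x32x1_2x4_4x8_2x2 case above; the constant names are illustrative only and the naming convention is inferred from the comments.)

// tile_arithmetic_check.cpp -- illustrative only, compiles with any C++11 compiler.
#include <cstdio>

int main() {
  // Taken from the comment block: Elements / Thread, Threads / Warp, Warps / Block.
  constexpr int kElementsPerThreadM = 2, kElementsPerThreadN = 4;
  constexpr int kThreadsPerWarpM = 4, kThreadsPerWarpN = 8;
  constexpr int kWarpsPerBlockM = 2, kWarpsPerBlockN = 2;

  // Warp tile = elements per thread x threads per warp (per dimension).
  constexpr int kWarpM = kElementsPerThreadM * kThreadsPerWarpM;   // 8
  constexpr int kWarpN = kElementsPerThreadN * kThreadsPerWarpN;   // 32

  // Threadblock tile = warp tile x warps per block (per dimension).
  constexpr int kBlockM = kWarpM * kWarpsPerBlockM;                // 16
  constexpr int kBlockN = kWarpN * kWarpsPerBlockN;                // 64

  static_assert(kWarpM == 8 && kWarpN == 32, "matches WarpShape<8, 32, 8>");
  static_assert(kBlockM == 16 && kBlockN == 64, "matches ThreadblockShape<16, 64, 8>");

  std::printf("warp tile %dx%d, threadblock tile %dx%d\n", kWarpM, kWarpN, kBlockM, kBlockN);
  return 0;
}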
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// 
Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 
32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = 
cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_cgemm_tt, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const 
kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ 
using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_nn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_nn_sm50.cu -new file mode 100644 -index 0000000..f65fd01 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_nn_sm50.cu -@@ -0,0 +1,991 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*! \file
-+    \brief Tests for device-wide GEMM interface
-+*/
-+
-+#include <iostream>
-+
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm.h"
-+#include "cutlass/numeric_types.h"
-+
-+#include "../../common/cutlass_unit_test.h"
-+
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/tensor_view_io.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+
-+#include "testbed.h"
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 2 x 4
-+// Threads / Warp: 4 x 8
-+// Warps / Block: 1 x 1
-+// Threadblock: 8 x 32 x 8
-+CUTLASS_TEST_L1(SM50_device_dgemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, {
-+    using precision = double;
-+    using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>;
-+    using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>;
-+
-+    static int const kEpilogueElementsPerAccess = 1;
-+    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+    using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+      precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+    using Gemm = cutlass::gemm::device::Gemm<
-+      precision, cutlass::layout::ColumnMajor,
-+      precision, cutlass::layout::ColumnMajor,
-+      precision, cutlass::layout::RowMajor,
-+      precision,
-+      cutlass::arch::OpClassSimt,
-+      cutlass::arch::Sm50,
-+      ThreadblockShape, WarpShape, InstructionShape,
-+      EpilogueOutputOp,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+      2 // Stages
-+    >;
-+    EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+} )
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 4 x 4
-+// Threads / Warp: 4 x 8
-+// Warps / Block: 1 x 1
-+// Threadblock: 16 x 32 x 8
-+CUTLASS_TEST_L1(SM50_device_dgemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, {
-+    using precision = double;
-+    using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>;
-+
using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nn, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_nn, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = 
cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, 
kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nn, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nn, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, 
cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nn, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ 
cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nn, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_nn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); 
-+} )
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 8 x 4
-+// Threads / Warp: 4 x 8
-+// Warps / Block: 2 x 2
-+// Threadblock: 64 x 64 x 8
-+CUTLASS_TEST_L0(SM50_device_dgemm_affine2_nn, 64x64x8_32x32x1_8x4_4x8_2x2, {
-+    using precision = double;
-+    using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>;
-+    using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>;
-+
-+    static int const kEpilogueElementsPerAccess = 1;
-+    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+    using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+      precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+    using LayoutA = cutlass::layout::AffineRank2ColumnMajor;
-+    using LayoutB = cutlass::layout::AffineRank2ColumnMajor;
-+    using LayoutC = cutlass::layout::AffineRankN<2>;
-+
-+    using Gemm = cutlass::gemm::device::Gemm<
-+      precision, LayoutA,
-+      precision, LayoutB,
-+      precision, LayoutC,
-+      precision,
-+      cutlass::arch::OpClassSimt,
-+      cutlass::arch::Sm50,
-+      ThreadblockShape, WarpShape, InstructionShape,
-+      EpilogueOutputOp,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+      2 // Stages
-+    >;
-+
-+    typename LayoutA::Stride::Index stride_factor_A[] = {3, 4};
-+    typename LayoutB::Stride::Index stride_factor_B[] = {5, 6};
-+    typename LayoutC::Stride::Index stride_factor_C[] = {7, 8};
-+
-+    EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>(stride_factor_A, stride_factor_B, stride_factor_C));
-+} )
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 8 x 4
-+// Threads / Warp: 8 x 4
-+// Warps / Block: 2 x 2
-+// Threadblock: 128 x 32 x 8
-+CUTLASS_TEST_L1(SM50_device_dgemm_nn, 128x32x8_64x16x1_8x4_8x4_2x2, {
-+    using precision = double;
-+    using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>;
-+    using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>;
-+
-+    static int const kEpilogueElementsPerAccess = 1;
-+    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+    using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+      precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+    using Gemm = cutlass::gemm::device::Gemm<
-+      precision, cutlass::layout::ColumnMajor,
-+      precision, cutlass::layout::ColumnMajor,
-+      precision, cutlass::layout::RowMajor,
-+      precision,
-+      cutlass::arch::OpClassSimt,
-+      cutlass::arch::Sm50,
-+      ThreadblockShape, WarpShape, InstructionShape,
-+      EpilogueOutputOp,
-+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+      2 // Stages
-+    >;
-+    EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+} )
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 2 x 2
-+// Threads / Warp: 8 x 4
-+// Warps / Block: 2 x 4
-+// Threadblock: 32 x 32 x 8
-+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, {
-+    using precision = double;
-+    using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>;
-+    using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>;
-+
-+    static int const kEpilogueElementsPerAccess = 1;
-+    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+    using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+      precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+    using Gemm = cutlass::gemm::device::Gemm<
-+      precision, cutlass::layout::ColumnMajor,
-+      precision, cutlass::layout::ColumnMajor,
-+      precision, cutlass::layout::RowMajor,
-+      precision,
-+      cutlass::arch::OpClassSimt,
-+
cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) 
-+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_nt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_nt_sm50.cu -new file mode 100644 -index 0000000..5ffbbfa ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_nt_sm50.cu -@@ -0,0 +1,1170 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, 
cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nt, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_nt, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ 
cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nt, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nt, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nt, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// 
Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nt, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = 
double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_nt, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_affine2_nt, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ 
static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using LayoutA = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutB = cutlass::layout::AffineRank2RowMajor; -+ using LayoutC = cutlass::layout::AffineRankN<2>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, LayoutA, -+ precision, LayoutB, -+ precision, LayoutC, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ -+ typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nt, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 128 x 16 
-+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 16x128x16_8x32x1_2x4_4x8_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = double; -+ using ThreadblockShape = 
cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 
1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< 
-+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x128x16_8x32x1_2x4_4x8_4x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, 
cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 128x32x16_32x8x1_4x2_8x4_4x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_tn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_tn_sm50.cu -new file mode 100644 -index 0000000..9205761 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_tn_sm50.cu -@@ -0,0 +1,991 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
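The two GemmShape aliases in each of these tests are not independent: the comment block (and the test-name suffix) encodes a tiling hierarchy in which the per-thread tile, scaled by the warp's thread arrangement, gives the warp tile, and the warp tile, scaled by the warps-per-block grid, gives the threadblock tile. A minimal sketch of that arithmetic for the 16x32x8_16x32x1_4x4_4x8_1x1 configuration just above, using only the shapes listed in its comments (the constant names are illustrative, not CUTLASS identifiers):

// Tile arithmetic implied by SM50_device_dgemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1.
int main() {
  constexpr int thread_m = 4, thread_n = 4;              // Elements / Thread: 4 x 4
  constexpr int warp_threads_m = 4, warp_threads_n = 8;  // Threads / Warp:   4 x 8
  constexpr int block_warps_m = 1, block_warps_n = 1;    // Warps / Block:    1 x 1

  // Warp tile = per-thread tile scaled by the warp's thread arrangement.
  constexpr int warp_m = thread_m * warp_threads_m;      // 4 * 4  = 16
  constexpr int warp_n = thread_n * warp_threads_n;      // 4 * 8  = 32
  // Threadblock tile = warp tile scaled by the warp grid.
  constexpr int block_m = warp_m * block_warps_m;        // 16 * 1 = 16
  constexpr int block_n = warp_n * block_warps_n;        // 32 * 1 = 32

  static_assert(warp_m == 16 && warp_n == 32, "matches WarpShape<16, 32, 8>");
  static_assert(block_m == 16 && block_n == 32, "matches ThreadblockShape<16, 32, 8>");
  return 0;
}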
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tn, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_tn, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps 
/ Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tn, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = double; -+ 
using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tn, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const 
kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tn, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tn, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = 
cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_tn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_affine2_tn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using LayoutA = cutlass::layout::AffineRank2RowMajor; -+ using LayoutB = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutC = 
cutlass::layout::AffineRankN<2>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, LayoutA, -+ precision, LayoutB, -+ precision, LayoutC, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ -+ typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tn, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = 
cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, 
cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_tt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_tt_sm50.cu -new file mode 100644 -index 0000000..b635978 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_tt_sm50.cu -@@ -0,0 +1,991 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tt, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_tt, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_affine2_tt, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using LayoutA = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutB = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutC = cutlass::layout::AffineRankN<2>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, LayoutA, -+ precision, LayoutB, -+ precision, LayoutC, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ -+ typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename 
LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; 
-+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tt, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// 
Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tt, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tt, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 
16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = 
cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tt, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_tt, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 
1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tt, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using 
Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, 
cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ 
EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_f8gemm_tn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_f8gemm_tn_sm50.cu -new file mode 100644 -index 0000000..a10a604 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_f8gemm_tn_sm50.cu -@@ -0,0 +1,87 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if (__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) -+ -+TEST(SM50_Device_Gemm_fe4m3t_fe4m3n_fe4m3t_simt_f32, 32x64x8_32x64x1) { -+ -+ using ElementA = cutlass::float_e4m3_t; -+ using ElementB = cutlass::float_e4m3_t; -+ using ElementC = cutlass::float_e4m3_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ ElementA, -+ cutlass::layout::RowMajor, -+ ElementB, -+ cutlass::layout::ColumnMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementC>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<> -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_nn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_nn_sm50.cu -new file mode 100644 -index 0000000..b399303 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_nn_sm50.cu -@@ -0,0 +1,2181 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 
16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 32x128x8_32x128x1_8x16_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_nn, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 64x64x8_64x64x1_16x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 128 x 32 x 8 
-+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 128x32x8_128x32x1_16x8_8x4_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::half_t; -+ 
using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = 
cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 32x256x8_32x128x1_8x16_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using 
InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_nn, 64x128x8_64x64x1_16x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, 
precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 128x64x8_64x64x1_16x8_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 256x32x8_128x32x1_16x8_8x4_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, 
cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, 
cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ 
ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 64x256x8_32x128x1_8x16_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 128x128x8_64x64x1_16x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 256x64x8_128x32x1_16x8_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 16x128x16_8x32x1_2x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
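Editorial note on the deleted tests above and below: each CUTLASS_TEST_L0/L1/L2 body instantiates cutlass::gemm::device::Gemm for half-precision SIMT GEMM on SM50 with one particular threadblock/warp tiling and hands it to the shared test::gemm::device::TestAllGemm() harness. For reference only, here is a minimal standalone sketch of driving one of these configurations (the 64x32x8_64x32x1_8x8_8x4_1x1 tiling) outside that harness. The Gemm template arguments are copied from the deleted test; the problem size, the cutlass::HostTensor buffers, and the run_sm50_simt_hgemm name are illustrative assumptions, and the launch follows the usual device-level Gemm::Arguments / operator() pattern as I understand it, not anything defined in this patch.

// Sketch (assumptions noted above): exercise one deleted SM50 SIMT hgemm configuration
// directly, without the CUTLASS_TEST_* / TestAllGemm() machinery.
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/epilogue/thread/linear_combination.h>
#include <cutlass/util/host_tensor.h>
#include <cutlass/util/reference/host/tensor_fill.h>

int run_sm50_simt_hgemm() {
  using precision = cutlass::half_t;

  // Tile sizes copied from the 64x32x8_64x32x1_8x8_8x4_1x1 instantiation above.
  using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<64, 32, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
  using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
      precision, 1, precision, precision>;

  using Gemm = cutlass::gemm::device::Gemm<
      precision, cutlass::layout::ColumnMajor,   // A
      precision, cutlass::layout::ColumnMajor,   // B
      precision, cutlass::layout::RowMajor,      // C / D
      precision,                                 // accumulator
      cutlass::arch::OpClassSimt, cutlass::arch::Sm50,
      ThreadblockShape, WarpShape, InstructionShape,
      EpilogueOutputOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      2>;                                        // Stages

  // Illustrative problem size; TestAllGemm() sweeps many sizes instead.
  int M = 256, N = 128, K = 64;
  cutlass::HostTensor<precision, cutlass::layout::ColumnMajor> A({M, K});
  cutlass::HostTensor<precision, cutlass::layout::ColumnMajor> B({K, N});
  cutlass::HostTensor<precision, cutlass::layout::RowMajor>    C({M, N});
  cutlass::HostTensor<precision, cutlass::layout::RowMajor>    D({M, N});

  cutlass::reference::host::TensorFill(A.host_view(), precision(1.0f));
  cutlass::reference::host::TensorFill(B.host_view(), precision(1.0f));
  cutlass::reference::host::TensorFill(C.host_view(), precision(0.0f));
  A.sync_device(); B.sync_device(); C.sync_device(); D.sync_device();

  // D = alpha * A * B + beta * C, with alpha = 1 and beta = 0.
  Gemm::Arguments args({M, N, K},
                       A.device_ref(), B.device_ref(),
                       C.device_ref(), D.device_ref(),
                       {precision(1.0f), precision(0.0f)});

  Gemm gemm_op;
  cutlass::Status status = gemm_op(args);
  return status == cutlass::Status::kSuccess ? 0 : -1;
}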
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / 
Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 
-+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 128x256x8_64x64x1_16x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = 
cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = 
cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ 
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_nn, 256x128x8_64x64x1_16x8_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x128x16_8x32x1_2x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, 
precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x32x16_32x8x1_4x2_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, 
cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x256x8_32x64x1_8x8_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, 
cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 256x128x8_64x32x1_8x8_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_nt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_nt_sm50.cu -new file mode 100644 -index 0000000..d414a7b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_nt_sm50.cu -@@ -0,0 +1,2181 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 32x128x8_32x128x1_8x16_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; 
-+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_nt, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 64x64x8_64x64x1_16x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 128x32x8_128x32x1_16x8_8x4_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, 
cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ 
cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 32x256x8_32x128x1_8x16_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_nt, 64x128x8_64x64x1_16x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 
-+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 128x64x8_64x64x1_16x8_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 256x32x8_128x32x1_16x8_8x4_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 
-+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using 
ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; 
-+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 64x256x8_32x128x1_8x16_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = 
cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 128x128x8_64x64x1_16x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 256x64x8_128x32x1_16x8_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 16x128x16_8x32x1_2x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ 
using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, 
cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ 
cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 128x256x8_64x64x1_16x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ 
EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_nt, 256x128x8_64x64x1_16x8_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x128x16_8x32x1_2x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / 
Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 
-+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x32x16_32x8x1_4x2_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = 
cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x256x8_32x64x1_8x8_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 256x128x8_64x32x1_8x8_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = 
cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_tn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_tn_sm50.cu -new file mode 100644 -index 0000000..2891c97 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_tn_sm50.cu -@@ -0,0 +1,2181 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include <iostream> -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = 
cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 32x128x8_32x128x1_8x16_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, 
cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_tn, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 64x64x8_64x64x1_16x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 128x32x8_128x32x1_16x8_8x4_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ 
cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 32x256x8_32x128x1_8x16_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_tn, 64x128x8_64x64x1_16x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 
-+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 64 x 8 
-+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 128x64x8_64x64x1_16x8_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 256x32x8_128x32x1_16x8_8x4_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ 
using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 
8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 
1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 64x256x8_32x128x1_8x16_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, 
kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 128x128x8_64x64x1_16x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 256x64x8_128x32x1_16x8_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ 
precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 16x128x16_8x32x1_2x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, 
cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ 
ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 128x256x8_64x64x1_16x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
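Every test in this file only defines the kernel type and delegates execution to test::gemm::device::TestAllGemm(), which sweeps problem sizes and verifies results against a host reference. For orientation, here is a standalone sketch of how one of these typedefs would be launched directly; it is illustrative only and not part of the patch: run_hgemm_tn and the pointer names are invented, the buffers are assumed to be pre-allocated device memory, and it assumes the CUTLASS 2.x device-level Gemm::Arguments interface.

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/epilogue/thread/linear_combination.h"

// Same configuration as the 64x128x8_64x64x1_16x8_4x8_1x2 test above:
// half precision, A row-major, B column-major ("tn"), SIMT math on SM50.
using precision = cutlass::half_t;
using GemmTN = cutlass::gemm::device::Gemm<
    precision, cutlass::layout::RowMajor,
    precision, cutlass::layout::ColumnMajor,
    precision, cutlass::layout::RowMajor,
    precision,
    cutlass::arch::OpClassSimt,
    cutlass::arch::Sm50,
    cutlass::gemm::GemmShape<64, 128, 8>,  // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 8>,   // warp tile
    cutlass::gemm::GemmShape<1, 1, 1>,     // instruction shape (SIMT)
    cutlass::epilogue::thread::LinearCombination<precision, 1, precision, precision>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2>;                                    // stages

cutlass::Status run_hgemm_tn(int M, int N, int K,
                             precision const *d_A, int lda,
                             precision const *d_B, int ldb,
                             precision *d_C, int ldc) {
  GemmTN gemm_op;
  GemmTN::Arguments args({M, N, K},
                         {d_A, lda},   // A with its row-major leading dimension
                         {d_B, ldb},   // B with its column-major leading dimension
                         {d_C, ldc},   // C read as the accumulator source
                         {d_C, ldc},   // D written in place over C
                         {precision(1.0f), precision(0.0f)});  // alpha = 1, beta = 0
  return gemm_op(args);  // configures and launches the kernel on the default stream
}

TestAllGemm() in the tests above does essentially this across a range of problem sizes, adding reference verification of the output tensor.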
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 
16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_tn, 256x128x8_64x64x1_16x8_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x128x16_8x32x1_2x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 
-+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = cutlass::half_t; -+ 
using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x32x16_32x8x1_4x2_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = 
cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x256x8_32x64x1_8x8_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 256x128x8_64x32x1_8x8_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using 
InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_tt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_tt_sm50.cu -new file mode 100644 -index 0000000..c9eb576 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_tt_sm50.cu -@@ -0,0 +1,2181 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include <iostream> -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = 
cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 32x128x8_32x128x1_8x16_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ 
precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_tt, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 64x64x8_64x64x1_16x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 128x32x8_128x32x1_16x8_8x4_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ 
ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 
// Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 32x256x8_32x128x1_8x16_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// 
Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_tt, 64x128x8_64x64x1_16x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 
64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 128x64x8_64x64x1_16x8_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = 
cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 256x32x8_128x32x1_16x8_8x4_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const 
kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using 
Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 64x256x8_32x128x1_8x16_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, 
cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 128x128x8_64x64x1_16x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 256x64x8_128x32x1_16x8_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ 
cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 16x128x16_8x32x1_2x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 128x256x8_64x64x1_16x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// 
Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 
128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_tt, 256x128x8_64x64x1_16x8_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = 
cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x128x16_8x32x1_2x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const 
kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x32x16_32x8x1_4x2_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ 
-+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x256x8_32x64x1_8x8_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 256x128x8_64x32x1_8x8_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, 
cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_nn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_nn_sm50.cu -new file mode 100644 -index 0000000..5292e59 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_nn_sm50.cu -@@ -0,0 +1,1701 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include <iostream> -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_igemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ 
precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ 
cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 
-+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using 
precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ 
static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = 
cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_igemm_nn, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, 
cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, 
InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// 
-+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 
-+CUTLASS_TEST_L2(SM50_device_igemm_nn, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = 
cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const 
kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 128x32x16_32x8x1_4x2_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_nt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_nt_sm50.cu -new file mode 100644 -index 0000000..64391a4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_nt_sm50.cu -@@ -0,0 +1,1761 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include <iostream> -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_igemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using
WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = 
cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, 
precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ 
precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, 
InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements 
/ Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 
32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = 
cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_igemm_nt, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 
1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 16x128x16_8x32x1_2x4_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ 
using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, 
cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, 
InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// 
-+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 128 x 16 
-+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x128x16_8x32x1_2x4_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = 
cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 128x32x16_32x8x1_4x2_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const 
kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_tn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_tn_sm50.cu -new file mode 100644 -index 0000000..9e6c841 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_tn_sm50.cu -@@ -0,0 +1,1671 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include <iostream> -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / 
Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_igemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = int; -+ using 
ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 
1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ 
precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, 
-+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, 
-+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ 
>; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// 
Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 
64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_igemm_tn, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = 
cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; 
-+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ 
using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, 
cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ 
EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// 
Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 
128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_tt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_tt_sm50.cu -new file mode 100644 -index 0000000..87c7976 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_tt_sm50.cu -@@ -0,0 +1,1731 @@ 
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*! \file
-+ \brief Tests for device-wide GEMM interface
-+*/
-+
-+#include <iostream>
-+
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm.h"
-+#include "cutlass/numeric_types.h"
-+
-+#include "../../common/cutlass_unit_test.h"
-+
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/tensor_view_io.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+
-+#include "testbed.h"
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 2 x 4
-+// Threads / Warp: 4 x 8
-+// Warps / Block: 1 x 1
-+// Threadblock: 8 x 32 x 8
-+CUTLASS_TEST_L1(SM50_device_igemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, {
-+ using precision = int;
-+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>;
-+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>;
-+
-+ static int const kEpilogueElementsPerAccess = 1;
-+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+ precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+ using Gemm = cutlass::gemm::device::Gemm<
-+ precision, cutlass::layout::RowMajor,
-+ precision, cutlass::layout::RowMajor,
-+ precision, cutlass::layout::RowMajor,
-+ precision,
-+ cutlass::arch::OpClassSimt,
-+ cutlass::arch::Sm50,
-+ ThreadblockShape, WarpShape, InstructionShape,
-+ EpilogueOutputOp,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+ 2 // Stages
-+ >;
-+ EXPECT_TRUE(test::gemm::device::TestAllGemm());
-+} )
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 4 x 4
-+// Threads / Warp: 4 x 8
-+// Warps / Block: 1 x 1
-+// Threadblock: 16 x 32 x 8
-+CUTLASS_TEST_L0(SM50_device_igemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, {
-+ using precision = int;
-+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>;
-+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>;
-+
-+ static int const kEpilogueElementsPerAccess = 1;
-+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+ precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+ using Gemm = cutlass::gemm::device::Gemm<
-+ precision, cutlass::layout::RowMajor,
-+ precision, cutlass::layout::RowMajor,
-+ precision, cutlass::layout::RowMajor,
-+ precision,
-+ cutlass::arch::OpClassSimt,
-+ cutlass::arch::Sm50,
-+ ThreadblockShape, WarpShape, InstructionShape,
-+ EpilogueOutputOp,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+ 2 // Stages
-+ >;
-+ EXPECT_TRUE(test::gemm::device::TestAllGemm());
-+} )
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 4 x 8
-+// Threads / Warp: 4 x 8
-+// Warps / Block: 1 x 1
-+// Threadblock: 16 x 64 x 8
-+CUTLASS_TEST_L1(SM50_device_igemm_tt, 16x64x8_16x64x1_4x8_4x8_1x1, {
-+ using precision = int;
-+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>;
-+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>;
-+
-+ static int const kEpilogueElementsPerAccess = 1;
-+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+ precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+ using Gemm = cutlass::gemm::device::Gemm<
-+ precision,
cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ 
cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ 
>; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / 
Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 
32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = 
cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using 
EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = 
cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ 
precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_igemm_tt, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 16x128x16_8x32x1_2x4_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
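Aside on the tile arithmetic encoded in the recurring "Elements / Thread", "Threads / Warp", "Warps / Block", "Threadblock" comments and in the test names (assumed convention, inferred from those comments only: ThreadblockShape_WarpShapeXk_ElementsPerThread_ThreadsPerWarp_WarpsPerBlock): the warp tile is the per-thread tile scaled by the warp's thread arrangement, and the threadblock tile is the warp tile scaled by the warp arrangement. A minimal standalone sketch of that arithmetic, not taken from the patch itself:

    // tile_hierarchy_check.cpp -- illustrative only; names and convention are assumptions.
    #include <cassert>

    int main() {
      // Example configuration: test name 16x64x8_16x32x1_4x4_4x8_1x2
      const int elem_per_thread_m = 4, elem_per_thread_n = 4;    // Elements / Thread
      const int threads_per_warp_m = 4, threads_per_warp_n = 8;  // Threads / Warp
      const int warps_per_block_m = 1, warps_per_block_n = 2;    // Warps / Block

      // Warp tile = per-thread tile scaled by the warp's thread arrangement.
      const int warp_m = elem_per_thread_m * threads_per_warp_m;  // 16
      const int warp_n = elem_per_thread_n * threads_per_warp_n;  // 32

      // Threadblock tile = warp tile scaled by the warp arrangement.
      const int block_m = warp_m * warps_per_block_m;  // 16
      const int block_n = warp_n * warps_per_block_n;  // 64

      assert(warp_m == 16 && warp_n == 32);    // matches WarpShape<16, 32, 8>
      assert(block_m == 16 && block_n == 64);  // matches ThreadblockShape<16, 64, 8>
      return 0;
    }

The same arithmetic applies to every configuration in this file; only the per-thread tile, warp arrangement, and warp count vary from test to test.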
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / 
Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = int; -+ using 
ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const 
kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x128x16_8x32x1_2x4_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = 
cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ 
precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_int8_igemm_sm61.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_int8_igemm_sm61.cu -new file mode 100644 -index 0000000..22729f4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_int8_igemm_sm61.cu -@@ -0,0 +1,161 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#define N cutlass::layout::ColumnMajor -+#define T cutlass::layout::RowMajor -+ -+#define RUN_GEMM(X, Y) \ -+ using ElementOutput = int8_t; \ -+ using ElementAccumulator = int32_t; \ -+ using ElementCompute = float; \ -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ -+ using Gemm = cutlass::gemm::device::Gemm< \ -+ int8_t, \ -+ X, \ -+ int8_t, \ -+ Y, \ -+ ElementOutput, \ -+ cutlass::layout::RowMajor, \ -+ int32_t, \ -+ cutlass::arch::OpClassSimt, \ -+ cutlass::arch::Sm61, \ -+ ThreadBlockShape, \ -+ WarpShape, \ -+ InstructionShape, \ -+ cutlass::epilogue::thread::LinearCombinationClamp< \ -+ ElementOutput, \ -+ 1, \ -+ ElementAccumulator, \ -+ ElementCompute \ -+ >, \ -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, \ -+ 2 \ -+ >; \ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8n_s8t_simt_op_dp4a, 64x64x16_64x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ RUN_GEMM(N, T) -+} -+ -+TEST(SM61_Device_Gemm_s8n_s8t_simt_op_dp4a, 256x128x64_64x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ RUN_GEMM(N, T) -+} -+ -+TEST(SM61_Device_Gemm_s8n_s8t_simt_op_dp4a, 256x256x16_128x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 256, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 64, 16>; -+ RUN_GEMM(N, T) -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8t_s8n_simt_op_dp4a, 64x64x16_64x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ RUN_GEMM(T, N) -+} -+ -+TEST(SM61_Device_Gemm_s8t_s8n_simt_op_dp4a, 256x128x64_64x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ RUN_GEMM(T, N) -+} -+ -+TEST(SM61_Device_Gemm_s8t_s8n_simt_op_dp4a, 256x256x16_128x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 256, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 64, 16>; -+ RUN_GEMM(T, N) -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8n_s8n_simt_op_dp4a, 64x64x16_64x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ RUN_GEMM(N, N) -+} -+ -+TEST(SM61_Device_Gemm_s8n_s8n_simt_op_dp4a, 256x128x64_64x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ RUN_GEMM(N, N) -+} -+ -+TEST(SM61_Device_Gemm_s8n_s8n_simt_op_dp4a, 256x256x16_128x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 256, 16>; -+ using WarpShape = 
cutlass::gemm::GemmShape<128, 64, 16>; -+ RUN_GEMM(N, N) -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8t_s8t_simt_op_dp4a, 64x64x16_64x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ RUN_GEMM(T, T) -+} -+ -+TEST(SM61_Device_Gemm_s8t_s8t_simt_op_dp4a, 256x128x64_64x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ RUN_GEMM(T, T) -+} -+ -+TEST(SM61_Device_Gemm_s8t_s8t_simt_op_dp4a, 256x256x16_128x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 256, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 64, 16>; -+ RUN_GEMM(T, T) -+} -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_int8_igemm_sm61_perf.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_int8_igemm_sm61_perf.cu -new file mode 100644 -index 0000000..10a61fc ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_int8_igemm_sm61_perf.cu -@@ -0,0 +1,195 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////// -+// NT -+///////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8n_s8t_simt_op_dp4a_perf, 128x256x32_64x64x8) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ int8_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestGemmPerf()); -+} -+ -+///////////////////////////////////// -+// TT -+///////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8t_s8t_simt_op_dp4a_perf, 128x256x32_64x64x8) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestGemmPerf()); -+} -+ -+///////////////////////////////////// -+// NN -+///////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8n_s8n_simt_op_dp4a_perf, 128x256x32_64x64x8) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestGemmPerf()); -+} -+ -+///////////////////////////////////// -+// TN -+///////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8t_s8n_simt_op_dp4a_perf, 128x256x32_64x64x8) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator 
= int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestGemmPerf()); -+} -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_int8_igemm_sm61_sliced_k.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_int8_igemm_sm61_sliced_k.cu -new file mode 100644 -index 0000000..bd0a1f8 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_int8_igemm_sm61_sliced_k.cu -@@ -0,0 +1,307 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8n_s8t_simt_op_dp4a_sliced_k, 32x32x128_32x32x4) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ int8_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<32, 32, 128>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM61_Device_Gemm_s8n_s8t_simt_op_dp4a_sliced_k, 32x64x128_32x32x4) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ int8_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM61_Device_Gemm_s8t_s8n_simt_op_dp4a_sliced_k, 32x32x128_32x32x4) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<32, 32, 128>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM61_Device_Gemm_s8t_s8n_simt_op_dp4a_sliced_k, 32x64x128_32x32x4) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ 
cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM61_Device_Gemm_s8t_s8t_simt_op_dp4a_sliced_k, 32x32x128_32x32x4) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<32, 32, 128>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM61_Device_Gemm_s8t_s8t_simt_op_dp4a_sliced_k, 32x64x128_32x32x4) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM61_Device_Gemm_s8n_s8n_simt_op_dp4a_sliced_k, 32x32x128_32x32x4) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<32, 32, 128>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM61_Device_Gemm_s8n_s8n_simt_op_dp4a_sliced_k, 32x64x128_32x32x4) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_nn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_nn_sm50.cu -new file mode 100644 -index 0000000..10889bb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_nn_sm50.cu -@@ -0,0 +1,861 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_qgemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, 
precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, 
cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ 
precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_qgemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ 
cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ 
EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_nt_sm50.cu 
b/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_nt_sm50.cu -new file mode 100644 -index 0000000..f3d0a78 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_nt_sm50.cu -@@ -0,0 +1,861 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_qgemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ 
-+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ 
precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, 
-+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_qgemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, 
InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// 
-+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_tn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_tn_sm50.cu -new file mode 100644 -index 
0000000..ed0f74d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_tn_sm50.cu -@@ -0,0 +1,861 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include <iostream> -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_qgemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ 
-+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ 
precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ 
precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_qgemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, 
WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// 
-+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_tt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_tt_sm50.cu -new file mode 100644 -index 
0000000..c8127c5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_tt_sm50.cu -@@ -0,0 +1,861 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include <iostream> -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_qgemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ 
using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, 
cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ 
cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_qgemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, 
-+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages 
-+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / 
Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 
8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_nn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_nn_sm50.cu -new file mode 100644 -index 0000000..f48e9e1 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_nn_sm50.cu -@@ -0,0 +1,1740 @@ 
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include <iostream> -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = 
cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, 
cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ 
EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_nn, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 
-+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using 
precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 
8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using 
EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ 
using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ 
precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, 
WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_nn, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_affine2_nn, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using LayoutA = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutB = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutC = cutlass::layout::AffineRankN<2>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, LayoutA, -+ precision, LayoutB, -+ precision, LayoutC, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ -+ typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = 
cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, 
cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, 
InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
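Aside: the shape comments in these SIMT tests compose multiplicatively — the per-thread tile times the warp's thread arrangement gives WarpShape, and WarpShape times the warps-per-block count gives ThreadblockShape. A minimal standalone check for the 64x256x8_16x64x1_4x8_4x8_4x4 configuration just above; the constexpr names are illustrative, not CUTLASS identifiers:

// Naming convention: <block>_<warp>x1_<elems/thread>_<threads/warp>_<warps/block>
constexpr int kThreadTileM = 4,  kThreadTileN = 8;   // elements per thread: 4 x 8
constexpr int kWarpThreadsM = 4, kWarpThreadsN = 8;  // threads per warp:    4 x 8
constexpr int kWarpsM = 4,       kWarpsN = 4;        // warps per block:     4 x 4

constexpr int kWarpTileM  = kThreadTileM * kWarpThreadsM;   // 16
constexpr int kWarpTileN  = kThreadTileN * kWarpThreadsN;   // 64
constexpr int kBlockTileM = kWarpTileM * kWarpsM;           // 64
constexpr int kBlockTileN = kWarpTileN * kWarpsN;           // 256

static_assert(kWarpThreadsM * kWarpThreadsN == 32, "exactly one warp of threads");
static_assert(kWarpTileM == 16 && kWarpTileN == 64, "matches WarpShape<16, 64, 8>");
static_assert(kBlockTileM == 64 && kBlockTileN == 256, "matches ThreadblockShape<64, 256, 8>");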
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 128x32x16_32x8x1_4x2_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads 
/ Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_nt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_nt_sm50.cu -new file mode 100644 -index 0000000..69058bb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_nt_sm50.cu -@@ -0,0 +1,1800 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ 
precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ 
cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
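Aside: for orientation, a host-side sketch of how one of the "nt" specializations above (column-major A, row-major B, row-major C) is typically driven, assuming the CUTLASS 2.x device-level Gemm::Arguments interface used by the library's basic_gemm example. The run_sgemm_nt name and the M/N/K, lda/ldb/ldc parameters are illustrative, and the defaulted GemmNT alias below omits the explicit threadblock/warp shapes that the tests pin down:

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"

// Column-major A, row-major B ("nt"), row-major C/D, float throughout,
// leaving the tile shapes and architecture tag at their defaults.
using GemmNT = cutlass::gemm::device::Gemm<
    float, cutlass::layout::ColumnMajor,   // A
    float, cutlass::layout::RowMajor,      // B
    float, cutlass::layout::RowMajor>;     // C / D

cutlass::Status run_sgemm_nt(int M, int N, int K,
                             float alpha, float const *A, int lda,
                             float const *B, int ldb,
                             float beta, float *C, int ldc) {
  GemmNT gemm_op;
  // Problem size, then A/B/C/D tensor refs, then the epilogue scalars.
  // D aliases C here, i.e. C = alpha * (A * B) + beta * C.
  GemmNT::Arguments args({M, N, K},
                         {A, lda}, {B, ldb},
                         {C, ldc}, {C, ldc},
                         {alpha, beta});
  return gemm_op(args);   // launches the kernel on the default CUDA stream
}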
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_nt, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// 
Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = 
float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const 
kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = 
cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 16x128x16_8x32x1_2x4_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, 
cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ 
EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_nt, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_affine2_nt, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using LayoutA = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutB = cutlass::layout::AffineRank2RowMajor; -+ using LayoutC = cutlass::layout::AffineRankN<2>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, LayoutA, -+ precision, LayoutB, -+ precision, LayoutC, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ -+ typename 
LayoutA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} 
) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 
x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x128x16_8x32x1_2x4_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using 
precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 128x32x16_32x8x1_4x2_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 
8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_nt_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_nt_sm80.cu -new file mode 100644 -index 0000000..fda68e5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_nt_sm80.cu -@@ -0,0 +1,296 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 32x64x8_32x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 64x64x8_32x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< 
-+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 128x128x8_32x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f32an_f32at_f32at_simt_f32, 128x128x8_32x64x1) { -+ -+ using Element = float; -+ using LayoutA = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutB = cutlass::layout::AffineRank2RowMajor; -+ using LayoutC = cutlass::layout::AffineRankN<2>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C )); -+ -+} -+ -+TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 64x128x8_32x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 128x64x8_32x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 128x128x8_64x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, 
-+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 128x256x8_64x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 8>, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_tn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_tn_sm50.cu -new file mode 100644 -index 0000000..b67aa23 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_tn_sm50.cu -@@ -0,0 +1,1710 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ 
precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, 
cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ 
cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_tn, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// 
Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = float; 
-+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const 
kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = 
cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, 
cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, 
InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_tn, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_affine2_tn, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using LayoutA = cutlass::layout::AffineRank2RowMajor; -+ using LayoutB = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutC = cutlass::layout::AffineRankN<2>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, LayoutA, -+ precision, LayoutB, -+ precision, LayoutC, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ -+ typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< 
-+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ 
cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 
-+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_tn_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_tn_sm80.cu -new file mode 100644 -index 0000000..202c5a1 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_tn_sm80.cu -@@ -0,0 
+1,296 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file
-+    \brief Tests for device-wide GEMM interface
-+
-+*/
-+
-+#include <iostream>
-+
-+#include "../../common/cutlass_unit_test.h"
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm.h"
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/tensor_view_io.h"
-+
-+#include "testbed.h"
-+
-+
-+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
-+////////////////////////////////////////////////////////////////////////////////
-+TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 32x64x8_32x64x1) {
-+
-+  using Element = float;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+    Element,
-+    cutlass::layout::RowMajor,
-+    Element,
-+    cutlass::layout::ColumnMajor,
-+    Element,
-+    cutlass::layout::RowMajor,
-+    Element,
-+    cutlass::arch::OpClassSimt,
-+    cutlass::arch::Sm80,
-+    cutlass::gemm::GemmShape<32, 64, 8>,
-+    cutlass::gemm::GemmShape<32, 64, 8>,
-+    cutlass::gemm::GemmShape<1, 1, 1>,
-+    cutlass::epilogue::thread::LinearCombination<
-+      Element,
-+      1,
-+      Element,
-+      Element>,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+    4
-+  >;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+}
-+
-+TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 64x64x8_32x64x1) {
-+
-+  using Element = float;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+    Element,
-+    cutlass::layout::RowMajor,
-+    Element,
-+    cutlass::layout::ColumnMajor,
-+    Element,
-+    cutlass::layout::RowMajor,
-+    Element,
-+    cutlass::arch::OpClassSimt,
-+    cutlass::arch::Sm80,
-+    cutlass::gemm::GemmShape<64, 64, 8>,
-+    cutlass::gemm::GemmShape<32, 64, 8>,
-+    cutlass::gemm::GemmShape<1, 1, 1>,
-+    cutlass::epilogue::thread::LinearCombination<
-+      Element,
-+      1,
-+      Element,
-+      Element>,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+    3
-+  >;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+}
-+
-+TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 128x128x8_32x64x1) {
-+
-+  using Element = float;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+    Element,
-+    cutlass::layout::RowMajor,
-+    Element,
-+    cutlass::layout::ColumnMajor,
-+    Element,
-+    cutlass::layout::RowMajor,
-+    Element,
-+    cutlass::arch::OpClassSimt,
-+    cutlass::arch::Sm80,
-+    cutlass::gemm::GemmShape<128, 128, 8>,
-+    cutlass::gemm::GemmShape<32, 64, 8>,
-+    cutlass::gemm::GemmShape<1, 1, 1>,
-+    cutlass::epilogue::thread::LinearCombination<
-+      Element,
-+      1,
-+      Element,
-+      Element>,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+    3
-+  >;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+}
-+
-+TEST(SM80_Device_Gemm_f32at_f32an_f32t_simt_f32, 128x128x8_32x64x1) {
-+
-+  using Element = float;
-+  using LayoutA = cutlass::layout::AffineRank2RowMajor;
-+  using LayoutB = cutlass::layout::AffineRank2ColumnMajor;
-+  using LayoutC = cutlass::layout::RowMajor;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+    Element,
-+    LayoutA,
-+    Element,
-+    LayoutB,
-+    Element,
-+    LayoutC,
-+    Element,
-+    cutlass::arch::OpClassSimt,
-+    cutlass::arch::Sm80,
-+    cutlass::gemm::GemmShape<128, 128, 8>,
-+    cutlass::gemm::GemmShape<32, 64, 8>,
-+    cutlass::gemm::GemmShape<1, 1, 1>,
-+    cutlass::epilogue::thread::LinearCombination<
-+      Element,
-+      1,
-+      Element,
-+      Element>,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+    3
-+  >;
-+
-+  typename LayoutA::Stride::Index stride_factor_A[] = {3, 4};
-+  typename
LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {1}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm( stride_factor_A, stride_factor_B, stride_factor_C )); -+} -+ -+TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 64x128x8_32x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 128x64x8_64x32x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 128x128x8_64x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 128x256x8_64x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 8>, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_tt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_tt_sm50.cu -new file mode 100644 -index 0000000..82b0773 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_tt_sm50.cu -@@ -0,0 +1,1770 @@ 
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file
-+    \brief Tests for device-wide GEMM interface
-+*/
-+
-+#include <iostream>
-+
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm.h"
-+#include "cutlass/numeric_types.h"
-+
-+#include "../../common/cutlass_unit_test.h"
-+
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/tensor_view_io.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+
-+#include "testbed.h"
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 2 x 4
-+// Threads / Warp: 4 x 8
-+// Warps / Block: 1 x 1
-+// Threadblock: 8 x 32 x 8
-+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, {
-+  using precision = float;
-+  using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>;
-+  using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>;
-+
-+  static int const kEpilogueElementsPerAccess = 1;
-+  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+  using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+    precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+    precision, cutlass::layout::RowMajor,
-+    precision, cutlass::layout::RowMajor,
-+    precision, cutlass::layout::RowMajor,
-+    precision,
-+    cutlass::arch::OpClassSimt,
-+    cutlass::arch::Sm50,
-+    ThreadblockShape, WarpShape, InstructionShape,
-+    EpilogueOutputOp,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+    2 // Stages
-+  >;
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+} )
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 4 x 4
-+// Threads / Warp: 4 x 8
-+// Warps / Block: 1 x 1
-+// Threadblock: 16 x 32 x 8
-+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, {
-+  using precision = float;
-+  using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>;
-+  using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>;
-+
-+  static int const kEpilogueElementsPerAccess = 1;
-+  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+  using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+    precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
-+    precision, cutlass::layout::RowMajor,
-+    precision, cutlass::layout::RowMajor,
-+    precision, cutlass::layout::RowMajor,
-+    precision,
-+    cutlass::arch::OpClassSimt,
-+    cutlass::arch::Sm50,
-+    ThreadblockShape, WarpShape, InstructionShape,
-+    EpilogueOutputOp,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+    2 // Stages
-+  >;
-+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+} )
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 4 x 8
-+// Threads / Warp: 4 x 8
-+// Warps / Block: 1 x 1
-+// Threadblock: 16 x 64 x 8
-+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 16x64x8_16x64x1_4x8_4x8_1x1, {
-+  using precision = float;
-+  using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>;
-+  using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>;
-+
-+  static int const kEpilogueElementsPerAccess = 1;
-+  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+  using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+    precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+  using Gemm = cutlass::gemm::device::Gemm<
precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ 
cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
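-+// The shapes in each configuration above decompose consistently: the per-thread tile
-+// times the thread arrangement gives the warp tile, and the warp tile times the warp
-+// arrangement gives the threadblock tile. For the 32x32x8_32x16x1_4x4_8x4_1x2 case above:
-+//   warp M = 4 * 8 = 32, warp N = 4 * 4 = 16; threadblock M = 32 * 1 = 32, N = 16 * 2 = 32.
-+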
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_tt, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / 
Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = float; -+ using 
ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const 
kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = 
cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 16x128x16_8x32x1_2x4_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, 
cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ 
EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_tt, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_affine2_tt, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using LayoutA = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutB = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutC = cutlass::layout::AffineRankN<2>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, LayoutA, -+ precision, LayoutB, -+ precision, LayoutC, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ -+ typename 
LayoutA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps 
/ Block: 4 x 4 -+// Threadblock: 32 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x128x16_8x32x1_2x4_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = float; -+ 
using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const 
kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_nn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_nn_sm50.cu -new file mode 100644 -index 0000000..fb268af ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_nn_sm50.cu -@@ -0,0 +1,801 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include <iostream> -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::complex<double>; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_zgemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::complex<double>; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} ) -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// 
Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 
-+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_zgemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::complex; 
-+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = 
cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using 
InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, 
precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_nt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_nt_sm50.cu -new file mode 100644 -index 0000000..0b1312a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_nt_sm50.cu -@@ -0,0 +1,801 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include <iostream> -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::complex<double>; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_zgemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::complex<double>; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); -+} ) -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads 
/ Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 
16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_zgemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = 
cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int 
const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using 
EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, 
precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ 
precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_tn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_tn_sm50.cu -new file mode 100644 -index 0000000..28dbb9b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_tn_sm50.cu -@@ -0,0 +1,801 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*! \file
-+ \brief Tests for device-wide GEMM interface
-+*/
-+
-+#include <iostream>
-+
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm.h"
-+#include "cutlass/numeric_types.h"
-+
-+#include "../../common/cutlass_unit_test.h"
-+
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/tensor_view_io.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+
-+#include "testbed.h"
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 2 x 4
-+// Threads / Warp: 4 x 8
-+// Warps / Block: 1 x 1
-+// Threadblock: 8 x 32 x 8
-+CUTLASS_TEST_L1(SM50_device_zgemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, {
-+ using precision = cutlass::complex<double>;
-+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>;
-+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>;
-+
-+ static int const kEpilogueElementsPerAccess = 1;
-+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+ precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+ using Gemm = cutlass::gemm::device::Gemm<
-+ precision, cutlass::layout::RowMajor,
-+ precision, cutlass::layout::ColumnMajor,
-+ precision, cutlass::layout::RowMajor,
-+ precision,
-+ cutlass::arch::OpClassSimt,
-+ cutlass::arch::Sm50,
-+ ThreadblockShape, WarpShape, InstructionShape,
-+ EpilogueOutputOp,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+ 2 // Stages
-+ >;
-+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+} )
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 4 x 4
-+// Threads / Warp: 4 x 8
-+// Warps / Block: 1 x 1
-+// Threadblock: 16 x 32 x 8
-+CUTLASS_TEST_L0(SM50_device_zgemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, {
-+ using precision = cutlass::complex<double>;
-+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>;
-+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>;
-+
-+ static int const kEpilogueElementsPerAccess = 1;
-+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+ precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+ using Gemm = cutlass::gemm::device::Gemm<
-+ precision, cutlass::layout::RowMajor,
-+ precision, cutlass::layout::ColumnMajor,
-+ precision, cutlass::layout::RowMajor,
-+ precision,
-+ cutlass::arch::OpClassSimt,
-+ cutlass::arch::Sm50,
-+ ThreadblockShape, WarpShape, InstructionShape,
-+ EpilogueOutputOp,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+ 2 // Stages
-+ >;
-+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+} )
-+
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads 
/ Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 
16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_zgemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = 
cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int 
const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using 
EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, 
precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ 
precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_tt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_tt_sm50.cu -new file mode 100644 -index 0000000..079e756 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_tt_sm50.cu -@@ -0,0 +1,801 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*! \file
-+ \brief Tests for device-wide GEMM interface
-+*/
-+
-+#include <iostream>
-+
-+#include "cutlass/cutlass.h"
-+#include "cutlass/gemm/device/gemm.h"
-+#include "cutlass/numeric_types.h"
-+
-+#include "../../common/cutlass_unit_test.h"
-+
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/tensor_view_io.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+
-+#include "testbed.h"
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 2 x 4
-+// Threads / Warp: 4 x 8
-+// Warps / Block: 1 x 1
-+// Threadblock: 8 x 32 x 8
-+CUTLASS_TEST_L1(SM50_device_zgemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, {
-+ using precision = cutlass::complex<double>;
-+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>;
-+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>;
-+
-+ static int const kEpilogueElementsPerAccess = 1;
-+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+ precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+ using Gemm = cutlass::gemm::device::Gemm<
-+ precision, cutlass::layout::RowMajor,
-+ precision, cutlass::layout::RowMajor,
-+ precision, cutlass::layout::RowMajor,
-+ precision,
-+ cutlass::arch::OpClassSimt,
-+ cutlass::arch::Sm50,
-+ ThreadblockShape, WarpShape, InstructionShape,
-+ EpilogueOutputOp,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+ 2 // Stages
-+ >;
-+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+} )
-+
-+////////////////////////////////////////////////////////////////////////////////
-+// Elements / Thread: 4 x 4
-+// Threads / Warp: 4 x 8
-+// Warps / Block: 1 x 1
-+// Threadblock: 16 x 32 x 8
-+CUTLASS_TEST_L0(SM50_device_zgemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, {
-+ using precision = cutlass::complex<double>;
-+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>;
-+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>;
-+
-+ static int const kEpilogueElementsPerAccess = 1;
-+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
-+ precision, kEpilogueElementsPerAccess, precision, precision>;
-+
-+ using Gemm = cutlass::gemm::device::Gemm<
-+ precision, cutlass::layout::RowMajor,
-+ precision, cutlass::layout::RowMajor,
-+ precision, cutlass::layout::RowMajor,
-+ precision,
-+ cutlass::arch::OpClassSimt,
-+ cutlass::arch::Sm50,
-+ ThreadblockShape, WarpShape, InstructionShape,
-+ EpilogueOutputOp,
-+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+ 2 // Stages
-+ >;
-+ EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
-+} )
-+
-+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 
x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 
16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_zgemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 
64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ 
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm 
= cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ 
precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm50_gemm_f32_f32_f32_simt.cu b/3rdparty/cutlass/test/unit/gemm/device/sm50_gemm_f32_f32_f32_simt.cu -new file mode 100644 -index 0000000..f7a18bc ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm50_gemm_f32_f32_f32_simt.cu -@@ -0,0 +1,135 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*! \file
-+    \brief Tests for device-wide GEMM interface
-+*/
-+
-+#include <iostream>
-+
-+#include "cutlass/cutlass.h"
-+#include "cute/tensor.hpp"
-+#include "cute/atom/mma_atom.hpp"
-+
-+#include "cutlass/numeric_types.h"
-+
-+#include "cutlass/gemm/device/gemm_universal_adapter.h"
-+#include "default_gemm_configuration.hpp"
-+
-+#include "../../common/cutlass_unit_test.h"
-+
-+#include "gemm_testbed_3x.hpp"
-+
-+using namespace cute;
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM50_Device_Gemm_f32n_f32n_f32n_simt_f32, 128x128x64_64x64x64) {
-+  using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types<
-+      cutlass::arch::OpClassSimt, cutlass::arch::Sm50,
-+      float, cutlass::layout::ColumnMajor,
-+      float, cutlass::layout::ColumnMajor,
-+      float, cutlass::layout::ColumnMajor,
-+      float>;
-+
-+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
-+      Shape,
-+      Config::CollectiveMainloop,
-+      Config::CollectiveEpilogue
-+  >;
-+
-+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
-+  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
-+}
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM50_Device_Gemm_f32n_f32t_f32n_simt_f32, 128x128x64_64x64x64) {
-+
-+  using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types<
-+      cutlass::arch::OpClassSimt, cutlass::arch::Sm50,
-+      float, cutlass::layout::ColumnMajor,
-+      float, cutlass::layout::RowMajor,
-+      float, cutlass::layout::ColumnMajor,
-+      float>;
-+
-+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
-+      Shape,
-+      Config::CollectiveMainloop,
-+      Config::CollectiveEpilogue
-+  >;
-+
-+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
-+  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
-+}
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM50_Device_Gemm_f32t_f32n_f32n_simt_f32, 128x128x64_64x64x64) {
-+
-+  using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types<
-+      cutlass::arch::OpClassSimt, cutlass::arch::Sm50,
-+      float, cutlass::layout::RowMajor,
-+      float, cutlass::layout::ColumnMajor,
-+      float, cutlass::layout::ColumnMajor,
-+      float>;
-+
-+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
-+      Shape,
-+      Config::CollectiveMainloop,
-+      Config::CollectiveEpilogue
-+  >;
-+
-+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
-+  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
-+}
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM50_Device_Gemm_f32t_f32t_f32n_simt_f32, 128x128x64_64x64x64) {
-+
-+  using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types<
-+      cutlass::arch::OpClassSimt, cutlass::arch::Sm50,
-+      float, cutlass::layout::RowMajor,
-+      float,
cutlass::layout::RowMajor, -+ float, cutlass::layout::ColumnMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm50_gemm_f64_f64_f64_simt.cu b/3rdparty/cutlass/test/unit/gemm/device/sm50_gemm_f64_f64_f64_simt.cu -new file mode 100644 -index 0000000..421072f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm50_gemm_f64_f64_f64_simt.cu -@@ -0,0 +1,134 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "default_gemm_configuration.hpp" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+using namespace cute; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Gemm_f64n_f64n_f64n_simt_f64, 128x128x64_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm50, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM50_Device_Gemm_f64n_f64t_f64n_simt_f64, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm50, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::ColumnMajor, -+ double>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Gemm_f64t_f64n_f64n_simt_f64, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm50, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Gemm_f64t_f64t_f64n_simt_f64, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm50, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::ColumnMajor, -+ double>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm61_gemm_s8_s8_s32_simt.cu b/3rdparty/cutlass/test/unit/gemm/device/sm61_gemm_s8_s8_s32_simt.cu -new file mode 100644 -index 0000000..ba6456b ---- /dev/null -+++ 
b/3rdparty/cutlass/test/unit/gemm/device/sm61_gemm_s8_s8_s32_simt.cu -@@ -0,0 +1,136 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "default_gemm_configuration.hpp" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+using namespace cute; -+ -+//#if defined(CUTLASS_ARCH_MMA_SM61_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8n_s8n_s32n_simt_s32, 128x128x64_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm50, -+ int8_t, cutlass::layout::ColumnMajor, -+ int8_t, cutlass::layout::ColumnMajor, -+ int32_t, cutlass::layout::ColumnMajor, -+ int32_t>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8n_s8t_s32n_simt_s32, 128x128x64_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm50, -+ int8_t, cutlass::layout::ColumnMajor, -+ int8_t, cutlass::layout::RowMajor, -+ int32_t, cutlass::layout::ColumnMajor, -+ int32_t>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8t_s8n_s32n_simt_s32, 128x128x64_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm50, -+ int8_t, cutlass::layout::RowMajor, -+ int8_t, cutlass::layout::ColumnMajor, -+ int32_t, cutlass::layout::ColumnMajor, -+ int32_t>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8t_s8t_s32n_simt_s32, 128x128x64_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm50, -+ int8_t, cutlass::layout::RowMajor, -+ int8_t, cutlass::layout::RowMajor, -+ int32_t, cutlass::layout::ColumnMajor, -+ int32_t>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//#endif // #if defined(CUTLASS_ARCH_MMA_SM61_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f16_f16_f32_tensor_op_f32.cu 
b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f16_f16_f32_tensor_op_f32.cu -new file mode 100644 -index 0000000..40f7cdb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f16_f16_f32_tensor_op_f32.cu -@@ -0,0 +1,136 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "default_gemm_configuration.hpp" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+using namespace cute; -+ -+//#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#if 1 -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32_3x, 128x128x32_64x64x32) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::half_t, cutlass::layout::RowMajor, -+ cutlass::half_t, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+#endif -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#if 1 -+TEST(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32_3x, 128x128x32_64x64x32) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::half_t, cutlass::layout::ColumnMajor, -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32_3x, 128x128x32_64x64x32) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::half_t, cutlass::layout::ColumnMajor, -+ cutlass::half_t, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32_3x, 128x128x32_64x64x32) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::half_t, cutlass::layout::RowMajor, -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+#endif -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//#endif // #if 
defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f32_f32_f32_simt.cu b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f32_f32_f32_simt.cu -new file mode 100644 -index 0000000..a7c6b52 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f32_f32_f32_simt.cu -@@ -0,0 +1,135 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "default_gemm_configuration.hpp" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+using namespace cute; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f32n_f32n_f32n_simt_f32, 128x128x64_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm80, -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::ColumnMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f32n_f32t_f32n_simt_f32, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm80, -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::ColumnMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f32t_f32n_f32n_simt_f32, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm80, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::ColumnMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f32t_f32t_f32n_simt_f32, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm80, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::ColumnMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f64_f64_f64_simt.cu b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f64_f64_f64_simt.cu -new file mode 100644 -index 0000000..274b30c ---- /dev/null -+++ 
b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f64_f64_f64_simt.cu -@@ -0,0 +1,134 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "default_gemm_configuration.hpp" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+using namespace cute; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64n_f64n_f64n_simt_f64, 128x128x64_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm80, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Gemm_f64n_f64t_f64n_simt_f64, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm80, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::ColumnMajor, -+ double>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64t_f64n_f64n_simt_f64, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm80, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64t_f64t_f64n_simt_f64, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm80, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::ColumnMajor, -+ double>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f64_f64_f64_tensor_op_f64.cu b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f64_f64_f64_tensor_op_f64.cu -new file mode 100644 -index 0000000..e53a8e8 
---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f64_f64_f64_tensor_op_f64.cu -@@ -0,0 +1,98 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "default_gemm_configuration.hpp" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+using namespace cute; -+ -+//#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64n_f64t_f64n_tensor_op_f64, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64t_f64n_f64n_tensor_op_f64, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// #endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_s8_s8_s32_tensor_op.cu b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_s8_s8_s32_tensor_op.cu -new file mode 100644 -index 0000000..d53cf54 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_s8_s8_s32_tensor_op.cu -@@ -0,0 +1,94 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "default_gemm_configuration.hpp" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+using namespace cute; -+ -+//#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(DISABLED_SM80_Device_Gemm_s8n_s8n_s32n_tensor_op_s32, 128x128x32_64x64x64) { -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(DISABLED_SM80_Device_Gemm_s8n_s8t_s32n_tensor_op_s32, 128x128x32_64x64x64) { -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x128x32_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ int8_t, cutlass::layout::RowMajor, -+ int8_t, cutlass::layout::ColumnMajor, -+ int32_t, cutlass::layout::ColumnMajor, -+ int32_t>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(DISABLED_SM80_Device_Gemm_s8t_s8t_s32n_tensor_op_s32, 128x128x32_64x64x64) { -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_tf32_tf32_f32_tensor_op_f32.cu b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_tf32_tf32_f32_tensor_op_f32.cu -new file mode 100644 -index 0000000..14654c7 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_tf32_tf32_f32_tensor_op_f32.cu -@@ -0,0 +1,135 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
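Of the four int8 layout combinations added above, only s8 row-major A with s8 column-major B (the TN case) gets a real test body; the others are checked in as DISABLED_ stubs, which is consistent with SM80 integer tensor-op MMAs natively consuming TN operands. A hedged reconstruction of the enabled case, with the stripped template arguments filled in by assumption (it reuses the includes and using namespace cute; of the file above):

TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x128x32_64x64x64) {
  using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types<
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      int8_t,  cutlass::layout::RowMajor,       // A: s8, row-major (T)
      int8_t,  cutlass::layout::ColumnMajor,    // B: s8, column-major (N)
      int32_t, cutlass::layout::ColumnMajor,    // C/D: s32, column-major
      int32_t>;                                 // accumulate in s32

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,                // runtime (M, N, K, L); rank assumed
      Config::CollectiveMainloop,
      Config::CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}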
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "default_gemm_configuration.hpp" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+using namespace cute; -+ -+ -+//#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32n_tensor_op_f32, 128x128x32_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::tfloat32_t, cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32n_tensor_op_f32, 128x128x32_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::tfloat32_t, cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using 
Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32n_tensor_op_f32, 128x128x32_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::tfloat32_t, cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32n_tensor_op_f32, 128x128x32_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::tfloat32_t, cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32.cu -new file mode 100644 -index 0000000..9fbbd86 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32.cu -@@ -0,0 +1,188 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_bf16t_bf16t_bf16n_align8_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::bfloat16_t, LayoutA, 8, -+ cutlass::bfloat16_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::bfloat16_t, LayoutA, 4, -+ cutlass::bfloat16_t, LayoutB, 4, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ 
EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_bf16n_bf16t_bf16n_align2_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::bfloat16_t, LayoutA, 2, -+ cutlass::bfloat16_t, LayoutB, 2, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_bf16n_bf16n_bf16n_align8_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::bfloat16_t, LayoutA, 8, -+ cutlass::bfloat16_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_tensor_op_f32.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_tensor_op_f32.cu -new file mode 100644 -index 0000000..d3983e4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_tensor_op_f32.cu -@@ -0,0 +1,187 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
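The bf16 align-x file above drives everything through the SM90 CollectiveBuilder, varying only the A/B alignment (8, 4, 2 elements). One case is sketched below with the missing template arguments restored; the LinearCombination element/count parameters, the TagToStrideC_t<LayoutC> strides, and the rank-4 problem shape follow the usual 3.x builder-test pattern and are assumptions (the sketch also assumes the file's includes and using namespace cute;).

TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32, 64x128x64) {
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::ColumnMajor;

  // Alignment 4 (elements) on A and B: operands are no longer 16-byte aligned,
  // so the auto schedule has to fall back to a cp.async mainloop rather than TMA.
  using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      cutlass::bfloat16_t, LayoutA, 4,
      cutlass::bfloat16_t, LayoutB, 4,
      float,                                    // accumulator
      Shape<_64,_128,_64>, Shape<_1,_1,_1>,     // tile shape, cluster shape
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::collective::KernelScheduleAuto
    >::CollectiveOp;

  // Stride-only epilogue applying alpha * acc + beta * C elementwise.
  using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue<
      cutlass::gemm::TagToStrideC_t<LayoutC>,
      cutlass::gemm::TagToStrideC_t<LayoutC>,
      cutlass::epilogue::thread::LinearCombination<cutlass::bfloat16_t, 1, float, float>>;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,
      CollectiveOp,
      EpilogueOp>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
}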
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_bf16t_bf16t_bf16n_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::bfloat16_t, LayoutA, 8, -+ cutlass::bfloat16_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::bfloat16_t, LayoutA, 8, -+ cutlass::bfloat16_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ 
>::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_bf16n_bf16t_bf16n_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::bfloat16_t, LayoutA, 8, -+ cutlass::bfloat16_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_bf16n_bf16n_bf16n_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::bfloat16_t, LayoutA, 8, -+ cutlass::bfloat16_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op.cu -new file mode 100644 -index 0000000..0ee526b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op.cu -@@ -0,0 +1,449 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
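Every test in these files ends the same way: wrap the composed kernel in GemmUniversalAdapter and hand it to test::gemm::device::TestAll, which sweeps problem sizes and compares against a reference. Outside the testbed, the adapter is driven through a small public API; a minimal sketch follows, assuming a fully populated typename Gemm::Arguments (its exact field layout differs between CUTLASS 3.x minor versions, so it is taken as given here) and the CUTLASS utility header for device allocations.

#include <cstdint>
#include <cuda_runtime.h>

#include "cutlass/cutlass.h"                 // cutlass::Status
#include "cutlass/util/device_memory.h"      // cutlass::device_memory::allocation

// Generic driver for any device::GemmUniversalAdapter instantiation.
template <class Gemm>
cutlass::Status run_gemm(typename Gemm::Arguments const& args, cudaStream_t stream = nullptr) {
  Gemm op;

  // Reject shapes/alignments/SM targets this instantiation cannot handle.
  cutlass::Status status = op.can_implement(args);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  // Some schedules need device scratch space (e.g. tile-scheduler state).
  size_t workspace_bytes = Gemm::get_workspace_size(args);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_bytes);

  // Bake arguments and workspace into launch parameters, then launch on the stream.
  status = op.initialize(args, workspace.get(), stream);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }
  return op.run(stream);
}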
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+/////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////// TT ////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelMultistage -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = 
cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 4, -+ cutlass::half_t, LayoutB, 4, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 2, -+ cutlass::half_t, LayoutB, 2, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////// TN ////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelMultistage -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< 
-+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 4, -+ cutlass::half_t, LayoutB, 4, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 2, -+ cutlass::half_t, LayoutB, 2, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////// NT ////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelMultistage -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 4, -+ cutlass::half_t, LayoutB, 4, 
-+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 2, -+ cutlass::half_t, LayoutB, 2, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////// NN ////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelMultistage -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 4, -+ cutlass::half_t, LayoutB, 4, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ 
cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 2, -+ cutlass::half_t, LayoutB, 2, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op.cu -new file mode 100644 -index 0000000..4fea99a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op.cu -@@ -0,0 +1,1077 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
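In the f16 align-x file above, the 8-element cases pass cutlass::gemm::KernelMultistage to the builder while the 4- and 2-element cases leave it at KernelScheduleAuto. The pinning looks deliberate: with 16-byte-aligned f16 operands the auto schedule would presumably select a TMA-based kernel, so the explicit tag keeps the cp.async multistage mainloop covered at full alignment. The two requests differ only in the final template parameter (sketch assumes the file's using namespace cute;):

// Forced cp.async multistage mainloop at alignment 8.
using MultistageMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    cutlass::half_t, cutlass::layout::RowMajor, 8,
    cutlass::half_t, cutlass::layout::RowMajor, 8,
    float,
    Shape<_64,_128,_64>, Shape<_1,_1,_1>,
    cutlass::gemm::collective::StageCountAuto,
    cutlass::gemm::KernelMultistage
  >::CollectiveOp;

// Builder-selected schedule at alignment 4; auto resolves to a mainloop that is
// legal for sub-16-byte operand alignment.
using AutoMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    cutlass::half_t, cutlass::layout::RowMajor, 4,
    cutlass::half_t, cutlass::layout::RowMajor, 4,
    float,
    Shape<_64,_128,_64>, Shape<_1,_1,_1>,
    cutlass::gemm::collective::StageCountAuto,
    cutlass::gemm::collective::KernelScheduleAuto
  >::CollectiveOp;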
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/epilogue.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f32, 128x128x32) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_128,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f32, 
64x64x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32, 128x128x32) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_128,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32, 64x64x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ 
cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f32, 128x128x32) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_128,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f32, 64x64x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ 
EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f32, 128x128x32) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_128,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f32, 64x64x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f16, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = 
cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f16, 128x128x32) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_128,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f16, 64x64x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_64,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ 
>::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16, 128x128x32) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_128,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16, 64x64x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_64,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f16, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); 
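Each of these TEST bodies delegates correctness checking to the shared test::gemm::device::TestAll harness from gemm_testbed_3x.hpp, which sweeps problem sizes and alpha/beta values against a reference GEMM. Underneath, the device-level adapter is driven through its usual lifecycle; the sketch below shows that sequence only (it is not the testbed's actual code, and argument construction is omitted because the Arguments field layout differs across CUTLASS versions).

#include <cstdint>

#include "cutlass/cutlass.h"
#include "cutlass/util/device_memory.h"

// Hypothetical helper: drive any of the Gemm adapters defined in these tests
// through can_implement / get_workspace_size / initialize / run.
template <class Gemm>
bool run_gemm(typename Gemm::Arguments const& args) {
  Gemm gemm_op;

  // Reject configurations this kernel cannot support (alignment, problem shape, ...).
  if (gemm_op.can_implement(args) != cutlass::Status::kSuccess) {
    return false;
  }

  // Allocate whatever device-side workspace the kernel requires.
  size_t workspace_size = Gemm::get_workspace_size(args);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Initialize kernel parameters, then launch on the default stream.
  if (gemm_op.initialize(args, workspace.get()) != cutlass::Status::kSuccess) {
    return false;
  }
  return gemm_op.run() == cutlass::Status::kSuccess;
}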
-+} -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f16, 128x128x32) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_128,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f16, 64x64x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_64,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f16, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f16, 128x128x32) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 
8, -+ cutlass::half_t, -+ Shape<_128,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f16, 64x64x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_64,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16_Epilogue, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits::value>, Layout,Stride<_1,_64>>>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,Shape<_64,_16>>, -+ Copy_Atom>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16_Epilogue, 128x64x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_128,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::Epilogue< -+ 
cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits::value>, Layout,_64>,Stride,_64>>>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,Shape<_128,_8>>, -+ Copy_Atom>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f16_Epilogue, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits::value>, Layout>,Stride<_64,Stride<_1,_4096>>>>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,Shape<_8,_128>>, -+ Copy_Atom>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f16_Epilogue, 128x64x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_128,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits::value>, Layout,Stride<_64,_1>>>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,Shape<_16,_64>>, -+ Copy_Atom>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_Epilogue, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ 
cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits::value>, Layout,Stride<_1,_64>>>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,Shape<_64,_16>>, -+ Copy_Atom>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_Epilogue, 128x64x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_128,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits::value>, Layout,_64>,Stride,_64>>>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,Shape<_128,_8>>, -+ Copy_Atom>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_Epilogue, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits::value>, Layout>,Stride<_64,Stride<_1,_4096>>>>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,Shape<_8,_128>>, -+ Copy_Atom>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_Epilogue, 128x64x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ 
Shape<_128,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits::value>, Layout,Stride<_64,_1>>>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,Shape<_16,_64>>, -+ Copy_Atom>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_unspecialized.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_unspecialized.cu -new file mode 100644 -index 0000000..1646632 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_unspecialized.cu -@@ -0,0 +1,582 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////// Cluster 2x2x1 //////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x2x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x2x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x2x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 
8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x2x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////// Cluster 4x1x1 //////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_4x1x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_4,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_4x1x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ 
float, -+ Shape<_64,_128,_64>, Shape<_4,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_4x1x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_4,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_4x1x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_4,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////// Cluster 1x4x1 //////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_1x4x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, 
Shape<_1,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_1x4x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_1x4x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_1x4x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ 
cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////// Cluster 2x4x1 //////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x4x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x4x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x4x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ 
cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x4x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu -new file mode 100644 -index 0000000..378315d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu -@@ -0,0 +1,582 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////// Cluster 2x2x1 //////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x2x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x2x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = 
cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x2x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x2x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////// Cluster 4x1x1 //////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_4x1x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_4,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using 
GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_4x1x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_4,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_4x1x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_4,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_4x1x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_4,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = 
cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////// Cluster 1x4x1 //////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_1x4x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_1x4x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_1x4x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using 
Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_1x4x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////// Cluster 2x4x1 //////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x4x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x4x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; 
-+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x4x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x4x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_persistent.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_persistent.cu -new file mode 100644 -index 0000000..c7d814b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_persistent.cu -@@ -0,0 +1,1018 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/epilogue.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_1x1x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_1,_1,_1>; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_2x1x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = 
cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_2,_1,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_1x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_1,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_2x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_2,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ 
TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_4x1x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_4,_1,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_1x4x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_1,_4,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_2x4x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = 
Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_2,_4,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_4x4x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_4,_4,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_1x1x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_128,_128,_64>; -+ using ClusterShape_MNK = Shape<_1,_1,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ 
>::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_2x1x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_128,_128,_64>; -+ using ClusterShape_MNK = Shape<_2,_1,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_1x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_128,_128,_64>; -+ using ClusterShape_MNK = Shape<_1,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_2x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_128,_128,_64>; -+ using ClusterShape_MNK = Shape<_2,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ 
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_4x1x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_128,_128,_64>; -+ using ClusterShape_MNK = Shape<_4,_1,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_1x4x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_128,_128,_64>; -+ using ClusterShape_MNK = Shape<_1,_4,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ 
>; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_2x4x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_128,_128,_64>; -+ using ClusterShape_MNK = Shape<_2,_4,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_4x4x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_128,_128,_64>; -+ using ClusterShape_MNK = Shape<_4,_4,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16_persistent_Epilogue, 64x128x64_2x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_2,_2,_1>; -+ using StageCountType = 
cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using PreSwizzleLayout = Layout,Stride<_1,_64>>; -+ using TileShapeS2R = Shape<_64,_16>; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, -+ Copy_Atom>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16_persistent_Epilogue, 128x64x64_2x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_128,_64,_64>; -+ using ClusterShape_MNK = Shape<_2,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using PreSwizzleLayout = Layout,_64>,Stride,_64>>; -+ using TileShapeS2R = Shape<_128,_8>; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, -+ Copy_Atom>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f16_persistent_Epilogue, 64x128x64_2x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_2,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = 
cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using PreSwizzleLayout = Layout>,Stride<_64,Stride<_1,_4096>>>; -+ using TileShapeS2R = Shape<_8,_128>; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, -+ Copy_Atom>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f16_persistent_Epilogue, 128x64x64_2x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ using TileShape_MNK = Shape<_128,_64,_64>; -+ using ClusterShape_MNK = Shape<_2,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using PreSwizzleLayout = Layout,Stride<_64,_1>>; -+ using TileShapeS2R = Shape<_16,_64>; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, -+ Copy_Atom>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_persistent_Epilogue, 64x128x64_2x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_2,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using PreSwizzleLayout = 
Layout,Stride<_1,_64>>; -+ using TileShapeS2R = Shape<_64,_16>; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, -+ Copy_Atom>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_persistent_Epilogue, 128x64x64_2x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_128,_64,_64>; -+ using ClusterShape_MNK = Shape<_2,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using PreSwizzleLayout = Layout,_64>,Stride,_64>>; -+ using TileShapeS2R = Shape<_128,_8>; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, -+ Copy_Atom>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_persistent_Epilogue, 64x128x64_2x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_2,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using PreSwizzleLayout = Layout>,Stride<_64,Stride<_1,_4096>>>; -+ using TileShapeS2R = Shape<_8,_128>; -+ -+ using CollectiveEpilogue = 
cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, -+ Copy_Atom>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_persistent_Epilogue, 128x64x64_2x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ using TileShape_MNK = Shape<_128,_64,_64>; -+ using ClusterShape_MNK = Shape<_2,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using PreSwizzleLayout = Layout,Stride<_64,_1>>; -+ using TileShapeS2R = Shape<_16,_64>; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, -+ Copy_Atom>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu -new file mode 100644 -index 0000000..b4edaf6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu -@@ -0,0 +1,86 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Redistribution and use in source and binary forms, with or without modification, are permitted -+ * provided that the following conditions are met: -+ * * Redistributions of source code must retain the above copyright notice, this list of -+ * conditions and the following disclaimer. 
-+ * * Redistributions in binary form must reproduce the above copyright notice, this list of -+ * conditions and the following disclaimer in the documentation and/or other materials -+ * provided with the distribution. -+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used -+ * to endorse or promote products derived from this software without specific prior written -+ * permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR -+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND -+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; -+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/collective/default_transposed_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f32t_f32n_f32n_tensor_op_gmma_f32, 64x128x32_1x2x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ float, LayoutA, 4, -+ float, LayoutB, 4, -+ float, -+ Shape<_64,_128,_128>, Shape<_1,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32.cu -new file mode 100644 -index 0000000..5d30e96 ---- /dev/null -+++ 
b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32.cu -@@ -0,0 +1,152 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32, 64x128x128) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ int8_t, LayoutA, 8, -+ int8_t, LayoutB, 8, -+ int32_t, -+ Shape<_64,_128,_128>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_s8t_s8n_s8n_align16_tensor_op_gmma_s32, 128x128x128) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ int8_t, LayoutA, 16, -+ int8_t, LayoutB, 16, -+ int32_t, -+ Shape<_128,_128,_128>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelMultistage -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_s8t_s8n_s8n_align4_tensor_op_gmma_s32, 128x64x128) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ int8_t, LayoutA, 4, -+ int8_t, LayoutB, 4, -+ int32_t, -+ Shape<_128,_64,_128>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = 
cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32.cu -new file mode 100644 -index 0000000..f0762a9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32.cu -@@ -0,0 +1,243 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include <iostream> -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 64x128x128) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ int8_t, LayoutA, 16, -+ int8_t, LayoutB, 16, -+ int32_t, -+ Shape<_64,_128,_128>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t<LayoutC>, -+ cutlass::gemm::TagToStrideC_t<LayoutC>, -+ cutlass::epilogue::thread::LinearCombination<int8_t, 1, int32_t, int32_t>>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape<int,int,int,int>, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; -+ EXPECT_TRUE(test::gemm::device::TestAll<Gemm>()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 64x128x128_1x2x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ int8_t, LayoutA, 16, -+ int8_t, LayoutB, 16, -+ int32_t, -+ Shape<_64,_128,_128>, Shape<_1,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t<LayoutC>, -+ cutlass::gemm::TagToStrideC_t<LayoutC>, -+ cutlass::epilogue::thread::LinearCombination<int8_t, 1, int32_t, int32_t>>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape<int,int,int,int>, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; -+ EXPECT_TRUE(test::gemm::device::TestAll<Gemm>()); -+} -+ -+TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 128x128x128) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ int8_t, LayoutA, 16, -+ int8_t, LayoutB, 16, -+ int32_t, -+ Shape<_128,_128,_128>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t<LayoutC>, -+ cutlass::gemm::TagToStrideC_t<LayoutC>, -+ 
cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 128x128x128_1x2x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ int8_t, LayoutA, 16, -+ int8_t, LayoutB, 16, -+ int32_t, -+ Shape<_128,_128,_128>, Shape<_1,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 128x128x128_2x1x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ int8_t, LayoutA, 16, -+ int8_t, LayoutB, 16, -+ int32_t, -+ Shape<_128,_128,_128>, Shape<_2,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 128x128x128_2x2x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ int8_t, LayoutA, 16, -+ int8_t, LayoutB, 16, -+ int32_t, -+ Shape<_128,_128,_128>, Shape<_2,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu 
b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu -new file mode 100644 -index 0000000..e95772f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu -@@ -0,0 +1,151 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include <iostream> -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align4_tensor_op_gmma_f32, 64x128x32) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ tfloat32_t, LayoutA, 4, -+ tfloat32_t, LayoutB, 4, -+ float, -+ Shape<_64,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelMultistage -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t<LayoutC>, -+ cutlass::gemm::TagToStrideC_t<LayoutC>, -+ cutlass::epilogue::thread::LinearCombination<float, 1, float, float>>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape<int,int,int,int>, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; -+ EXPECT_TRUE(test::gemm::device::TestAll<Gemm>()); -+} -+ -+TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align2_tensor_op_gmma_f32, 64x64x32) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::tfloat32_t, LayoutA, 2, -+ cutlass::tfloat32_t, LayoutB, 2, -+ float, -+ Shape<_64,_64,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t<LayoutC>, -+ cutlass::gemm::TagToStrideC_t<LayoutC>, -+ cutlass::epilogue::thread::LinearCombination<float, 1, float, float>>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape<int,int,int,int>, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; -+ EXPECT_TRUE(test::gemm::device::TestAll<Gemm>()); -+} -+ -+TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align1_tensor_op_gmma_f32, 128x64x32) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::tfloat32_t, LayoutA, 1, -+ cutlass::tfloat32_t, LayoutB, 1, -+ float, -+ Shape<_128,_64,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t<LayoutC>, -+ cutlass::gemm::TagToStrideC_t<LayoutC>, -+ cutlass::epilogue::thread::LinearCombination<float, 1, float, float>>; -+ 
-+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu -new file mode 100644 -index 0000000..ce570a2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu -@@ -0,0 +1,185 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_tensor_op_gmma_f32, 64x128x32) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::tfloat32_t, LayoutA, 4, -+ cutlass::tfloat32_t, LayoutB, 4, -+ float, -+ Shape<_64,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_tf32n_tf32n_f32n_tensor_op_gmma_f32, 64x128x32) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::tfloat32_t, LayoutA, 1, -+ cutlass::tfloat32_t, LayoutB, 4, -+ float, -+ Shape<_64,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_tf32n_tf32t_f32n_tensor_op_gmma_f32, 64x128x32) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::tfloat32_t, LayoutA, 1, -+ cutlass::tfloat32_t, LayoutB, 1, -+ float, -+ Shape<_64,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ 
cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_tf32t_tf32t_f32n_tensor_op_gmma_f32, 64x128x32) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::tfloat32_t, LayoutA, 4, -+ cutlass::tfloat32_t, LayoutB, 1, -+ float, -+ Shape<_64,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_f32_ls_sm80.cu -new file mode 100644 -index 0000000..d386a7e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_f32_ls_sm80.cu -@@ -0,0 +1,172 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_ls_l_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_ls_u_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ 
-+TEST(SM80_Device_Symm_cf32n_cf32n_ls_u_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_f32_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_f32_rs_sm80.cu -new file mode 100644 -index 0000000..07f8564 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_f32_rs_sm80.cu -@@ -0,0 +1,172 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include <iostream> -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_rs_l_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex<float>; -+ using ElementAccumulator = cutlass::complex<float>; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex<float>, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::complex<float>, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal<Symm>()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_rs_u_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex<float>; -+ using ElementAccumulator = cutlass::complex<float>; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex<float>, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex<float>, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal<Symm>()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_rs_u_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex<float>; -+ using ElementAccumulator = cutlass::complex<float>; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex<float>, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex<float>, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ 
ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_fast_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_fast_f32_ls_sm80.cu -new file mode 100644 -index 0000000..3ad96e4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_fast_f32_ls_sm80.cu -@@ -0,0 +1,172 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include <iostream> -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_ls_l_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex<float>; -+ using ElementAccumulator = cutlass::complex<float>; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex<float>, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::complex<float>, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal<Symm>()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_ls_u_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex<float>; -+ using ElementAccumulator = cutlass::complex<float>; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex<float>, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex<float>, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal<Symm>()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_ls_u_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex<float>; -+ using ElementAccumulator = cutlass::complex<float>; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex<float>, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex<float>, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ 
cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_fast_f32_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_fast_f32_rs_sm80.cu -new file mode 100644 -index 0000000..4eb8b7a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_fast_f32_rs_sm80.cu -@@ -0,0 +1,172 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_rs_l_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_rs_u_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_rs_u_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ 
cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_cf64_cf64_cf64_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_cf64_cf64_cf64_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..a13f744 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_cf64_cf64_cf64_tensor_op_f64_sm90.cu -@@ -0,0 +1,133 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Symm_cf64n_cf64n_ls_l_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Symm_cf64n_cf64n_rs_u_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_cf64n_cf64n_cf64n_tensor_op_ls_f64_gaussian_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_cf64n_cf64n_cf64n_tensor_op_ls_f64_gaussian_sm80.cu -new file mode 100644 -index 0000000..fa2f574 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_cf64n_cf64n_cf64n_tensor_op_ls_f64_gaussian_sm80.cu -@@ -0,0 +1,172 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf64n_cf64n_ls_l_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ 
-+TEST(SM80_Device_Symm_cf64n_cf64n_ls_u_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf64n_cf64n_ls_u_tensor_op_f64_gaussian, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_cf64n_cf64n_cf64n_tensor_op_ls_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_cf64n_cf64n_cf64n_tensor_op_ls_f64_sm80.cu -new file mode 100644 -index 0000000..3dd4edd ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_cf64n_cf64n_cf64n_tensor_op_ls_f64_sm80.cu -@@ -0,0 +1,172 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf64n_cf64n_ls_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf64n_cf64n_ls_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ 
cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf64n_cf64n_ls_u_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_cf64n_cf64n_cf64n_tensor_op_rs_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_cf64n_cf64n_cf64n_tensor_op_rs_f64_sm80.cu -new file mode 100644 -index 0000000..af810d4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_cf64n_cf64n_cf64n_tensor_op_rs_f64_sm80.cu -@@ -0,0 +1,172 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf64n_cf64n_rs_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf64n_cf64n_rs_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ 
-+TEST(SM80_Device_Symm_cf64n_cf64n_rs_u_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu -new file mode 100644 -index 0000000..6cdc04d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu -@@ -0,0 +1,489 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Symm_{ElementA/B}{LayoutA/B}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_fast_f32_align1_align1, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_fast_f32_align1_align4, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_fast_f32_align1_align4, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ 
cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_fast_f32_align1_align4, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_fast_f32_align1_align4, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_fast_f32_align1_align4, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ 
cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_u_tensor_op_fast_f32_align1_align4, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_u_tensor_op_fast_f32_align1_align4, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_u_tensor_op_fast_f32_align1_align4, 256x128x16_128x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_u_tensor_op_fast_f32_align1_align4, 128x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_u_tensor_op_fast_f32_align1_align4, 256x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_rs_sm80.cu -new file mode 100644 -index 0000000..1ae9cf9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_rs_sm80.cu -@@ -0,0 +1,276 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Symm_{ElementA/B}{LayoutA/B}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_u_tensor_op_fast_f32_align1_align1, 64x128x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_u_tensor_op_fast_f32_align1_align1, 128x64x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 
64, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_l_tensor_op_fast_f32_align1_align1, 64x128x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_u_tensor_op_fast_f32_align1_align4, 64x128x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_u_tensor_op_fast_f32_align1_align4, 128x64x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_l_tensor_op_fast_f32_align1_align4, 64x128x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ 
float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu -new file mode 100644 -index 0000000..28be9a6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu -@@ -0,0 +1,489 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Symm_{ElementA/B}{LayoutA/B}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_fast_f32_align1_align1, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_fast_f32_align1_align4, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_fast_f32_align1_align4, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ 
ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_fast_f32_align1_align4, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_fast_f32_align1_align4, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_fast_f32_align1_align4, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / 
cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_u_tensor_op_fast_f32_align1_align4, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_u_tensor_op_fast_f32_align1_align4, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_u_tensor_op_fast_f32_align1_align4, 256x128x16_128x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_u_tensor_op_fast_f32_align1_align4, 128x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_u_tensor_op_fast_f32_align1_align4, 256x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f64_f64_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f64_f64_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..1feb2d6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f64_f64_tensor_op_f64_sm90.cu -@@ -0,0 +1,135 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*! \file
-+    \brief Tests for device-wide SYMM interface
-+
-+*/
-+
-+#include <iostream>
-+
-+#include "../../common/cutlass_unit_test.h"
-+#include "cutlass/blas3.h"
-+#include "cutlass/gemm/device/symm.h"
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/reference/host/symm.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/tensor_view_io.h"
-+
-+#include "testbed_symm_universal.h"
-+
-+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED)
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM90_Device_Symm_f64n_f64n_rs_l_tensor_op_f64, 32x32x16_16x16x16) {
-+
-+  using ElementA = double;
-+  using LayoutA = cutlass::layout::ColumnMajor;
-+  using ElementB = double;
-+  using LayoutB = cutlass::layout::ColumnMajor;
-+  using ElementC = double;
-+  using LayoutC = cutlass::layout::ColumnMajor;
-+  using ElementAccumulator = double;
-+
-+  using Symm = cutlass::gemm::device::Symm<
-+    ElementA,
-+    LayoutA,
-+    cutlass::SideMode::kRight,
-+    cutlass::FillMode::kLower,
-+    ElementB,
-+    LayoutB,
-+    ElementC,
-+    LayoutC,
-+    ElementAccumulator,
-+    cutlass::arch::OpClassTensorOp,
-+    cutlass::arch::Sm90,
-+    cutlass::gemm::GemmShape<32, 32, 16>,
-+    cutlass::gemm::GemmShape<16, 16, 16>,
-+    cutlass::gemm::GemmShape<16, 8, 4>,
-+    cutlass::epilogue::thread::LinearCombination<
-+      ElementC,
-+      1,
-+      ElementAccumulator,
-+      ElementAccumulator
-+    >,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+    4
-+  >;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal<Symm>());
-+
-+}
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM90_Device_Symm_f64t_f64t_ls_l_tensor_op_f64, 128x128x16_32x64x16) {
-+
-+  using ElementA = double;
-+  using LayoutA = cutlass::layout::RowMajor;
-+  using ElementB = double;
-+  using LayoutB = cutlass::layout::RowMajor;
-+  using ElementC = double;
-+  using LayoutC = cutlass::layout::RowMajor;
-+  using ElementAccumulator = double;
-+
-+  using Symm = cutlass::gemm::device::Symm<
-+    ElementA,
-+    LayoutA,
-+    cutlass::SideMode::kLeft,
-+    cutlass::FillMode::kLower,
-+    ElementB,
-+    LayoutB,
-+    ElementC,
-+    LayoutC,
-+    ElementAccumulator,
-+    cutlass::arch::OpClassTensorOp,
-+    cutlass::arch::Sm90,
-+    cutlass::gemm::GemmShape<128, 128, 16>,
-+    cutlass::gemm::GemmShape<32, 64, 16>,
-+    cutlass::gemm::GemmShape<16, 8, 4>,
-+    cutlass::epilogue::thread::LinearCombination<
-+      ElementC,
-+      1,
-+      ElementAccumulator,
-+      ElementAccumulator
-+    >,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+    3
-+  >;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal<Symm>());
-+
-+}
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED)
-diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64n_tensor_op_f64_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64n_tensor_op_f64_ls_sm80.cu
-new file mode 100644
-index 0000000..abb5020
---- /dev/null
-+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64n_tensor_op_f64_ls_sm80.cu
-@@ -0,0 +1,258 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*!
\file -+ \brief Tests for device-wide SYMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64n_ls_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64n_ls_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64n_ls_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ 
cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64n_ls_u_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64n_ls_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64n_tensor_op_f64_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64n_tensor_op_f64_rs_sm80.cu -new file mode 100644 -index 0000000..57c4f79 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64n_tensor_op_f64_rs_sm80.cu -@@ -0,0 +1,258 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64n_rs_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64n_rs_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using 
LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64n_rs_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64n_rs_u_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64n_rs_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ 
ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64t_tensor_op_f64_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64t_tensor_op_f64_ls_sm80.cu -new file mode 100644 -index 0000000..6c82d76 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64t_tensor_op_f64_ls_sm80.cu -@@ -0,0 +1,258 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64t_ls_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64t_ls_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64t_ls_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ 
cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64t_ls_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64t_ls_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64t_tensor_op_f64_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64t_tensor_op_f64_rs_sm80.cu -new file mode 100644 -index 0000000..a7a44f6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64t_tensor_op_f64_rs_sm80.cu -@@ -0,0 +1,258 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64t_rs_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64t_rs_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using 
LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64t_rs_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64t_rs_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64t_rs_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ 
LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64n_tensor_op_f64_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64n_tensor_op_f64_ls_sm80.cu -new file mode 100644 -index 0000000..64f4078 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64n_tensor_op_f64_ls_sm80.cu -@@ -0,0 +1,258 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64n_ls_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64n_ls_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64n_ls_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ 
cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64n_ls_u_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64n_ls_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64n_tensor_op_f64_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64n_tensor_op_f64_rs_sm80.cu -new file mode 100644 -index 0000000..21cf9fd ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64n_tensor_op_f64_rs_sm80.cu -@@ -0,0 +1,258 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64n_rs_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64n_rs_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = 
cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64n_rs_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64n_rs_u_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64n_rs_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ 
ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64t_tensor_op_f64_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64t_tensor_op_f64_ls_sm80.cu -new file mode 100644 -index 0000000..9fdd1a0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64t_tensor_op_f64_ls_sm80.cu -@@ -0,0 +1,258 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64t_ls_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64t_ls_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64t_ls_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 
32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64t_ls_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64t_ls_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64t_tensor_op_f64_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64t_tensor_op_f64_rs_sm80.cu -new file mode 100644 -index 0000000..fce589d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64t_tensor_op_f64_rs_sm80.cu -@@ -0,0 +1,258 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64t_rs_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64t_rs_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ 
cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64t_rs_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64t_rs_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64t_rs_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 
16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_tf32n_f32n_tensor_op_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_tf32n_f32n_tensor_op_f32_ls_sm80.cu -new file mode 100644 -index 0000000..7e6e4b4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_tf32n_f32n_tensor_op_f32_ls_sm80.cu -@@ -0,0 +1,489 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Symm_{ElementA/B}{LayoutA/B}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_f32_align1_align1, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_f32_align1_align4, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_f32_align1_align4, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, 
-+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_f32_align1_align4, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_f32_align1_align4, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_f32_align1_align4, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / 
cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_u_tensor_op_f32_align1_align4, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_u_tensor_op_f32_align1_align4, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_u_tensor_op_f32_align1_align4, 256x128x16_128x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ 
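The instantiations above fix the tile shapes, layouts, side/fill modes, alignments, and stage counts; the numerical check itself is delegated to TestAllSymmUniversal, which sweeps problem sizes and compares the device result against the host reference pulled in through cutlass/util/reference/host/symm.h. As a reminder of the math being verified, below is a minimal host-side sketch for the SideMode::kLeft / FillMode::kLower case with column-major operands (leading dimension m). It is an illustration only, not the reference the testbed uses, and the name naive_symm_left_lower is invented for this sketch; for SideMode::kRight the symmetric operand instead multiplies from the right, i.e. C = alpha * B * A + beta * C with A being n-by-n.

#include <vector>

// Minimal SYMM sketch: C = alpha * A * B + beta * C, where A is m-by-m symmetric
// with only its lower triangle stored, and B, C are m-by-n. All matrices are
// column-major with leading dimension m. Illustrative only.
template <typename Element>
void naive_symm_left_lower(int m, int n, Element alpha,
                           std::vector<Element> const &A,   // m-by-m, lower triangle valid
                           std::vector<Element> const &B,   // m-by-n
                           Element beta,
                           std::vector<Element> &C) {       // m-by-n, updated in place
  for (int col = 0; col < n; ++col) {
    for (int row = 0; row < m; ++row) {
      Element acc = Element();
      for (int k = 0; k < m; ++k) {
        // Reconstruct the full symmetric operand from the stored lower triangle;
        // FillMode::kUpper simply flips this predicate.
        Element a = (row >= k) ? A[row + k * m] : A[k + row * m];
        acc += a * B[k + col * m];
      }
      C[row + col * m] = alpha * acc + beta * C[row + col * m];
    }
  }
}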
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_u_tensor_op_f32_align1_align4, 128x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_u_tensor_op_f32_align1_align4, 256x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_tf32n_f32n_tensor_op_f32_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_tf32n_f32n_tensor_op_f32_rs_sm80.cu -new file mode 100644 -index 0000000..cb9bf2e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_tf32n_f32n_tensor_op_f32_rs_sm80.cu -@@ -0,0 +1,276 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Symm_{ElementA/B}{LayoutA/B}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_u_tensor_op_f32_align1_align1, 64x128x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_u_tensor_op_f32_align1_align1, 128x64x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ 
cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_l_tensor_op_f32_align1_align1, 64x128x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_u_tensor_op_f32_align1_align4, 64x128x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_u_tensor_op_f32_align1_align4, 128x64x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_l_tensor_op_f32_align1_align4, 64x128x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ 
cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_tf32t_f32t_tensor_op_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_tf32t_f32t_tensor_op_f32_ls_sm80.cu -new file mode 100644 -index 0000000..a80084c ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_tf32t_f32t_tensor_op_f32_ls_sm80.cu -@@ -0,0 +1,489 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Symm_{ElementA/B}{LayoutA/B}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_f32_align1_align1, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_f32_align1_align4, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_f32_align1_align4, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ 
cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_f32_align1_align4, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_f32_align1_align4, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_f32_align1_align4, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ 
ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_u_tensor_op_f32_align1_align4, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_u_tensor_op_f32_align1_align4, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_u_tensor_op_f32_align1_align4, 256x128x16_128x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ 
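For readability, one of the names above can be decoded against the naming convention declared at the top of these files; the mapping below is inferred from the template arguments and the {SideMode}/{FillMode}/align fields, so treat it as an informal gloss rather than generated documentation.

// SM80_Device_Symm_f32t_f32t_ls_u_tensor_op_f32_align1_align4, 128x128x16_64x64x16
//   f32t / f32t    -> ElementA/B = float, LayoutA/B = RowMajor ("t"; "n" denotes ColumnMajor)
//   ls             -> SideMode::kLeft   (the symmetric operand A multiplies from the left)
//   u              -> FillMode::kUpper  (only the upper triangle of A is referenced)
//   tensor_op_f32  -> OpClassTensorOp with ElementAccumulator = float (tf32 Tensor Core path)
//   align1_align4  -> AlignmentA = 1, AlignmentB = 4
//   128x128x16     -> threadblock tile; 64x64x16 -> warp tile (the instruction tile is 16x8x8)

Each test body then hands the instantiated kernel to the shared testbed, which in the CUTLASS unit tests is normally invoked with the kernel type as an explicit template argument, e.g. EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal<Symm>());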
-+TEST(SM80_Device_Symm_f32t_f32t_ls_u_tensor_op_f32_align1_align4, 128x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_u_tensor_op_f32_align1_align4, 256x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32n_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32n_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..218fcd6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32n_tensor_op_f32_sm80.cu -@@ -0,0 +1,150 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf32n_cf32n_l_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf32n_cf32n_u_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ 
cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32n_tensor_op_fast_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32n_tensor_op_fast_f32_sm80.cu -new file mode 100644 -index 0000000..b559945 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32n_tensor_op_fast_f32_sm80.cu -@@ -0,0 +1,150 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf32n_cf32n_l_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf32n_cf32n_u_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git 
a/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..7090a0a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,150 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf32n_cf32t_l_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf32n_cf32t_u_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32t_tensor_op_fast_f32_sm80.cu 
b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32t_tensor_op_fast_f32_sm80.cu -new file mode 100644 -index 0000000..0c6efb1 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32t_tensor_op_fast_f32_sm80.cu -@@ -0,0 +1,150 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf32n_cf32t_l_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf32n_cf32t_u_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git 
a/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64_cf64_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64_cf64_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..76d19f6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64_cf64_tensor_op_f64_sm90.cu -@@ -0,0 +1,150 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Syr2k_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Syr2k_cf64n_cf64t_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git 
a/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64n_tensor_op_f64_grouped_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64n_tensor_op_f64_grouped_sm80.cu -new file mode 100644 -index 0000000..cea1691 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64n_tensor_op_f64_grouped_sm80.cu -@@ -0,0 +1,308 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for grouped Rank2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_grouped_rank_2k.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_l_tensor_op_cf64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_l_tensor_op_cf64, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); 
-+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_l_tensor_op_cf64, 64x32x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_l_tensor_op_cf64, 32x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_u_tensor_op_cf64, 32x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 
16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_u_tensor_op_cf64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_u_tensor_op_cf64, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64n_tensor_op_f64_sm80.cu 
b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64n_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..3f7b03a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64n_tensor_op_f64_sm80.cu -@@ -0,0 +1,150 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf64n_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64t_tensor_op_f64_grouped_sm80.cu 
b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64t_tensor_op_f64_grouped_sm80.cu -new file mode 100644 -index 0000000..b3e2e27 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64t_tensor_op_f64_grouped_sm80.cu -@@ -0,0 +1,168 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for grouped Rank2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_grouped_rank_2k.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64t_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64t_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64t_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64t_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64t_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..30dc4ba ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64t_tensor_op_f64_sm80.cu -@@ -0,0 +1,150 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf64n_cf64t_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf64n_cf64t_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ 
ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64t_cf64n_tensor_op_f64_grouped_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64t_cf64n_tensor_op_f64_grouped_sm80.cu -new file mode 100644 -index 0000000..75ade1f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64t_cf64n_tensor_op_f64_grouped_sm80.cu -@@ -0,0 +1,168 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for grouped Rank2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_grouped_rank_2k.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64t_cf64t_tensor_op_f64_grouped_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64t_cf64t_tensor_op_f64_grouped_sm80.cu -new file mode 100644 -index 0000000..71e794d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64t_cf64t_tensor_op_f64_grouped_sm80.cu -@@ -0,0 +1,168 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for grouped Rank2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_grouped_rank_2k.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64t_cf64t_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64t_cf64t_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, 
-+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64t_cf64t_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_f32n_f32n_tensor_op_fast_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f32n_f32n_tensor_op_fast_f32_sm80.cu -new file mode 100644 -index 0000000..e310ac8 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f32n_f32n_tensor_op_fast_f32_sm80.cu -@@ -0,0 +1,132 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f32n_f32n_l_tensor_op_fast_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f32n_f32n_u_tensor_op_fast_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ 
cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_f32t_f32n_tensor_op_fast_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f32t_f32n_tensor_op_fast_f32_sm80.cu -new file mode 100644 -index 0000000..e24a150 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f32t_f32n_tensor_op_fast_f32_sm80.cu -@@ -0,0 +1,133 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f32t_f32n_l_tensor_op_fast_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f32t_f32n_u_tensor_op_fast_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64_f64_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64_f64_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..f7aa84d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64_f64_tensor_op_f64_sm90.cu -@@ -0,0 +1,134 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Syr2k_f64n_f64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Syr2k_f64t_f64n_l_tensor_op_f64, 
128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64n_tensor_op_f64_grouped_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64n_tensor_op_f64_grouped_sm80.cu -new file mode 100644 -index 0000000..24f832d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64n_tensor_op_f64_grouped_sm80.cu -@@ -0,0 +1,483 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for grouped Rank2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_grouped_rank_2k.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ 
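Every case in this patch, grouped or not, exercises the same BLAS-style symmetric rank-2k update; stating it once makes the FillMode and BlasMode parameters easier to interpret. This is the standard SYR2K definition, given here as background rather than taken from the patch itself:

    C \leftarrow \alpha \, (A B^{T} + B A^{T}) + \beta \, C

Only the triangle of the symmetric output C selected by cutlass::FillMode (kLower or kUpper) is stored and updated, and cutlass::BlasMode::kSymmetric distinguishes this from the Hermitian rank-2k variant, which uses conjugate transposes instead.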
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_l_tensor_op_f64, 64x32x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_l_tensor_op_f64, 32x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_u_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, 
cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_u_tensor_op_f64, 64x32x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_u_tensor_op_f64, 32x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_u_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ 
using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_u_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64n_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64n_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..9cf9173 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64n_tensor_op_f64_sm80.cu -@@ -0,0 +1,253 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64n_f64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64n_f64n_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ 
ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64n_f64n_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64n_f64n_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64n_f64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ 
cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64t_tensor_op_f64_grouped_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64t_tensor_op_f64_grouped_sm80.cu -new file mode 100644 -index 0000000..e7b165f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64t_tensor_op_f64_grouped_sm80.cu -@@ -0,0 +1,273 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for grouped Rank2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_grouped_rank_2k.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64t_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64t_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// 
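The grouped double-precision cases in this and the preceding file differ from one another only in tile shapes, output layout, and fill mode, and from the complex cases only in element type and multiply-add operator. A hypothetical helper alias (not part of CUTLASS, shown purely for illustration) makes that parameterization explicit; every argument not exposed as a parameter is an assumption carried over from the tests above.

#include "cutlass/blas3.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/default_rank_2k_grouped.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"

// Hypothetical alias: everything not listed as a parameter is fixed to the
// values used by the SM80 grouped SYR2K tests in this patch.
template <
    typename Element,          // double or cutlass::complex<double>
    typename LayoutC,          // ColumnMajor or RowMajor output
    cutlass::FillMode Fill,    // kLower or kUpper
    typename ThreadblockShape,
    typename WarpShape,
    typename MathOp>           // OpMultiplyAdd or OpMultiplyAddComplex
using GroupedSyr2kKernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped<
    Element, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 1,
    Element, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 1,
    Element, LayoutC, Fill,
    Element,                                  // accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape,
    WarpShape,
    cutlass::gemm::GemmShape<8, 8, 4>,        // f64 tensor-op instruction shape
    cutlass::epilogue::thread::LinearCombination<Element, 1, Element, Element>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,                                        // kStages
    MathOp,
    cutlass::BlasMode::kSymmetric>::Rank2Kkernel;

// For example, the 64x64x16_32x32x16 lower-triangular f64n_f64t case above
// corresponds to:
using Kernel_f64_64x64 = GroupedSyr2kKernel<
    double, cutlass::layout::RowMajor, cutlass::FillMode::kLower,
    cutlass::gemm::GemmShape<64, 64, 16>, cutlass::gemm::GemmShape<32, 32, 16>,
    cutlass::arch::OpMultiplyAdd>;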
-+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64t_l_tensor_op_f64, 64x32x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64t_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64t_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ 
cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64t_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64t_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64t_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..e3fb6ee ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64t_tensor_op_f64_sm80.cu -@@ -0,0 +1,253 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64n_f64t_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64n_f64t_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ 
-+TEST(SM80_Device_Syr2k_f64n_f64t_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64n_f64t_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64n_f64t_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64t_f64n_tensor_op_f64_grouped_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64t_f64n_tensor_op_f64_grouped_sm80.cu -new file mode 100644 -index 0000000..b53b710 ---- /dev/null -+++ 
b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64t_f64n_tensor_op_f64_grouped_sm80.cu -@@ -0,0 +1,308 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for grouped Rank2K interface -+*/ -+ -+#include <iostream> -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_grouped_rank_2k.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped<Rank2Kkernel>; -+ -+ test::gemm::device::TestbedGrouped<Rank2K> testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64n_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped<Rank2Kkernel>; -+ -+ test::gemm::device::TestbedGrouped<Rank2K> testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+
-+TEST(SM80_Device_Syr2kGrouped_f64t_f64n_l_tensor_op_f64, 64x32x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64n_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64n_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ 
cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64n_u_tensor_op_f64, 64x32x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64t_f64n_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64t_f64n_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..f720f88 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64t_f64n_tensor_op_f64_sm80.cu -@@ -0,0 +1,253 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64t_f64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ 
-+TEST(SM80_Device_Syr2k_f64t_f64n_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64t_f64n_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64t_f64n_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64t_f64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = 
double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64t_f64t_tensor_op_f64_grouped_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64t_f64t_tensor_op_f64_grouped_sm80.cu -new file mode 100644 -index 0000000..f9292e7 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64t_f64t_tensor_op_f64_grouped_sm80.cu -@@ -0,0 +1,308 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for grouped Rank2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_grouped_rank_2k.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64t_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64t_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ 
-+TEST(SM80_Device_Syr2kGrouped_f64t_f64t_l_tensor_op_f64, 32x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64t_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64t_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ 
cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64t_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64t_u_tensor_op_f64, 32x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_tf32n_f32n_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_tf32n_f32n_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..c6bb3b1 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_tf32n_f32n_tensor_op_f32_sm80.cu -@@ -0,0 +1,132 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_tf32n_f32n_l_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ 
-+TEST(SM80_Device_Syr2k_tf32n_f32n_u_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_tf32t_f32n_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_tf32t_f32n_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..25e62fe ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_tf32t_f32n_tensor_op_f32_sm80.cu -@@ -0,0 +1,133 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_tf32t_f32n_l_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_tf32t_f32n_u_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32n_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32n_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..dcd963b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32n_tensor_op_f32_sm80.cu -@@ -0,0 +1,137 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+ -+*/ -+ -+#include <iostream> -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf32n_cf32n_l_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex<float>; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex<float>; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex<float>; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal<RankK>()); -+} -+ 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf32n_cf32n_u_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32n_tensor_op_fast_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32n_tensor_op_fast_f32_sm80.cu -new file mode 100644 -index 0000000..007faad ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32n_tensor_op_fast_f32_sm80.cu -@@ -0,0 +1,137 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf32n_cf32n_l_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf32n_cf32n_u_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..5d90211 ---- /dev/null -+++ 
b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,137 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYRK interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf32n_cf32t_l_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf32n_cf32t_u_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32t_tensor_op_fast_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32t_tensor_op_fast_f32_sm80.cu -new file mode 100644 -index 0000000..85301bc ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32t_tensor_op_fast_f32_sm80.cu -@@ -0,0 +1,137 @@ 
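// ---------------------------------------------------------------------------
// Illustrative note (not part of the patch): the cf32n_cf32t files repeat the
// cf32n_cf32n tests with only the C layout changed; the layout letters in
// these file names follow the usual BLAS convention, "n" for column-major and
// "t" for row-major, matching the LayoutA/LayoutC aliases in each TEST body.

#include "cutlass/layout/matrix.h"

using Layout_n = cutlass::layout::ColumnMajor;  // *_cf32n / *_f32n operands
using Layout_t = cutlass::layout::RowMajor;     // *_cf32t / *_f32t operands
// ---------------------------------------------------------------------------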
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYRK interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf32n_cf32t_l_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf32n_cf32t_u_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64_cf64_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64_cf64_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..98da67d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64_cf64_tensor_op_f64_sm90.cu -@@ -0,0 +1,136 @@ 
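// ---------------------------------------------------------------------------
// Illustrative note (not part of the patch): comparing the two preceding file
// pairs, the *_tensor_op_fast_f32 tests differ from the *_tensor_op_f32 tests
// only in the math-operator tag passed to RankK (and in the test names); the
// fast variant presumably uses the tf32 "big + small" decomposition of each
// f32 operand to recover near-fp32 accuracy on the tensor cores.

#include "cutlass/arch/mma.h"

using MathOp        = cutlass::arch::OpMultiplyAddComplex;         // *_f32 tests
using MathOpFastF32 = cutlass::arch::OpMultiplyAddComplexFastF32;  // *_fast_f32 tests
// ---------------------------------------------------------------------------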
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Syrk_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Syrk_cf64n_cf64t_l_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64n_cf64n_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64n_cf64n_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..3888116 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64n_cf64n_tensor_op_f64_sm80.cu -@@ -0,0 +1,136 @@ 
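// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): for the cf64 SM90 file above,
// the stripped element arguments are presumably cutlass::complex<double>, per
// the "cf64" name. Relative to the SM80 cf64 files that follow, these tests
// target cutlass::arch::Sm90 with a 16x8x4 instruction shape (instead of
// 8x8x4) and are guarded by CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED rather than
// CUTLASS_ARCH_MMA_SM80_SUPPORTED; the rest of the instantiation is the same.

#include "cutlass/complex.h"
#include "cutlass/layout/matrix.h"

using ElementA           = cutlass::complex<double>;  // assumed from "cf64"
using LayoutA            = cutlass::layout::ColumnMajor;
using ElementC           = cutlass::complex<double>;
using ElementAccumulator = cutlass::complex<double>;
// ---------------------------------------------------------------------------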
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf64n_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu -new file mode 100644 -index 0000000..b826f05 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu -@@ -0,0 +1,95 @@ 
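// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the *_gaussian file that
// follows (and the SM90 gaussian test above) swaps the math operator for
// cutlass::arch::OpMultiplyAddGaussianComplex, which presumably selects the
// three-real-multiply ("Gaussian") form of the complex product rather than
// the usual four-multiply form. The underlying identity, shown on plain
// doubles purely for illustration:

#include <cassert>
#include <complex>

std::complex<double> gauss_mul(std::complex<double> x, std::complex<double> y) {
  double a = x.real(), b = x.imag(), c = y.real(), d = y.imag();
  double k1 = c * (a + b);    // three real multiplies ...
  double k2 = a * (d - c);
  double k3 = b * (c + d);
  return {k1 - k3, k1 + k2};  // ... versus four in (ac - bd, ad + bc)
}

int main() {
  std::complex<double> x{1.5, -2.0}, y{0.25, 3.0};
  assert(std::abs(gauss_mul(x, y) - x * y) < 1e-12);
  return 0;
}
// ---------------------------------------------------------------------------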
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf64n_cf64t_l_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64n_cf64t_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64n_cf64t_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..2c455e7 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64n_cf64t_tensor_op_f64_sm80.cu -@@ -0,0 +1,136 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf64n_cf64t_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf64n_cf64t_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu -new file mode 100644 -index 0000000..8f4e9f9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu -@@ -0,0 +1,541 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_l_tensor_op_fast_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_l_tensor_op_fast_f32, 256x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_l_tensor_op_fast_f32, 64x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ 
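// ---------------------------------------------------------------------------
// Illustrative note (not part of the patch): in the remaining real-valued f32
// SYRK tests the numeric suffix encodes <threadblock tile>_<warp tile>, e.g.
// 128x256x32_64x64x32 pairs the two shapes below, and the trailing integer
// template argument (3, 4 or 6 here) is the number of pipeline stages. The
// helper behind each EXPECT_TRUE is presumably templated on the kernel type,
// i.e. test::gemm::device::TestAllRankKUniversal<RankK>(), with the
// angle-bracket argument dropped from this listing.

#include "cutlass/gemm/gemm.h"

using ThreadblockTile = cutlass::gemm::GemmShape<128, 256, 32>;
using WarpTile        = cutlass::gemm::GemmShape<64, 64, 32>;
// ---------------------------------------------------------------------------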
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_l_tensor_op_fast_f32, 256x64x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_l_tensor_op_fast_f32, 128x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+TEST(SM80_Device_Syrk_f32n_f32t_l_tensor_op_fast_f32, 64x128x32_32x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+TEST(SM80_Device_Syrk_f32n_f32t_l_tensor_op_fast_f32, 128x64x32_64x32x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_l_tensor_op_fast_f32, 128x128x16_64x64x16) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_l_tensor_op_fast_f32, 64x128x16_32x64x16) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_u_tensor_op_fast_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_u_tensor_op_fast_f32, 256x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ 
cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_u_tensor_op_fast_f32, 64x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_u_tensor_op_fast_f32, 256x64x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_u_tensor_op_fast_f32, 128x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if 
(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_f32t_f32t_tensor_op_fast_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_f32t_f32t_tensor_op_fast_f32_sm80.cu -new file mode 100644 -index 0000000..4dbd5b0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_f32t_f32t_tensor_op_fast_f32_sm80.cu -@@ -0,0 +1,541 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_l_tensor_op_fast_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_l_tensor_op_fast_f32, 256x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_l_tensor_op_fast_f32, 64x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_l_tensor_op_fast_f32, 256x64x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_l_tensor_op_fast_f32, 128x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+TEST(SM80_Device_Syrk_f32t_f32t_l_tensor_op_fast_f32, 64x128x32_32x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+TEST(SM80_Device_Syrk_f32t_f32t_l_tensor_op_fast_f32, 128x64x32_64x32x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_l_tensor_op_fast_f32, 128x128x16_64x64x16) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_l_tensor_op_fast_f32, 64x128x16_32x64x16) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_u_tensor_op_fast_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_u_tensor_op_fast_f32, 256x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ 
ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_u_tensor_op_fast_f32, 64x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_u_tensor_op_fast_f32, 256x64x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_u_tensor_op_fast_f32, 128x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git 
a/3rdparty/cutlass/test/unit/gemm/device/syrk_f64_f64_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_f64_f64_tensor_op_f64_sm90.cu
-new file mode 100644
-index 0000000..8fe7627
---- /dev/null
-+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_f64_f64_tensor_op_f64_sm90.cu
-@@ -0,0 +1,126 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*! \file
-+    \brief Tests for device-wide SYRK interface
-+
-+*/
-+
-+#include <iostream>
-+
-+#include "../../common/cutlass_unit_test.h"
-+#include "cutlass/blas3.h"
-+#include "cutlass/gemm/device/rank_k.h"
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/reference/host/rank_k_complex.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/tensor_copy.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/tensor_view_io.h"
-+
-+#include "testbed_rank_k_universal.h"
-+
-+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED)
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM90_Device_Syrk_f64n_f64t_l_tensor_op_f64, 128x64x16_64x32x16) {
-+
-+  using ElementA = double;
-+  using LayoutA = cutlass::layout::ColumnMajor;
-+
-+  using ElementC = double;
-+  using LayoutC = cutlass::layout::RowMajor;
-+  using ElementAccumulator = double;
-+
-+  using RankK = cutlass::gemm::device::RankK<
-+    ElementA,
-+    LayoutA,
-+    ElementC,
-+    LayoutC,
-+    cutlass::FillMode::kLower,
-+    ElementAccumulator,
-+    cutlass::arch::OpClassTensorOp,
-+    cutlass::arch::Sm90,
-+    cutlass::gemm::GemmShape<128, 64, 16>,
-+    cutlass::gemm::GemmShape<64, 32, 16>,
-+    cutlass::gemm::GemmShape<16, 8, 4>,
-+    cutlass::epilogue::thread::LinearCombination<
-+      ElementC,
-+      1,
-+      ElementAccumulator,
-+      ElementAccumulator
-+    >,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+    4
-+  >;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal<RankK>());
-+
-+}
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM90_Device_Syrk_f64t_f64n_l_tensor_op_f64, 32x32x16_16x16x16) {
-+
-+  using ElementA = double;
-+  using LayoutA = cutlass::layout::RowMajor;
-+  using ElementC = double;
-+  using LayoutC = cutlass::layout::ColumnMajor;
-+  using ElementAccumulator = double;
-+
-+  using RankK = cutlass::gemm::device::RankK<
-+    ElementA,
-+    LayoutA,
-+    ElementC,
-+    LayoutC,
-+    cutlass::FillMode::kLower,
-+    ElementAccumulator,
-+    cutlass::arch::OpClassTensorOp,
-+    cutlass::arch::Sm90,
-+    cutlass::gemm::GemmShape<32, 32, 16>,
-+    cutlass::gemm::GemmShape<16, 16, 16>,
-+    cutlass::gemm::GemmShape<16, 8, 4>,
-+    cutlass::epilogue::thread::LinearCombination<
-+      ElementC,
-+      1,
-+      ElementAccumulator,
-+      ElementAccumulator
-+    >,
-+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-+    4
-+  >;
-+
-+  EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal<RankK>());
-+}
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED)
-diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_f64n_f64t_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_f64n_f64t_tensor_op_f64_sm80.cu
-new file mode 100644
-index 0000000..62d29af
---- /dev/null
-+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_f64n_f64t_tensor_op_f64_sm80.cu
-@@ -0,0 +1,237 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1.
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64n_f64t_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64n_f64t_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ 
cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64n_f64t_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64n_f64t_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64n_f64t_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+ -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if 
defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_f64t_f64n_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_f64t_f64n_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..0ad9dbb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_f64t_f64n_tensor_op_f64_sm80.cu -@@ -0,0 +1,301 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64t_f64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64t_f64n_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64t_f64n_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64t_f64n_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64t_f64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64t_f64n_u_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64t_f64n_u_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ 
cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_tf32n_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_tf32n_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..ba96ad5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_tf32n_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,541 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_l_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_l_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_l_tensor_op_f32, 64x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_l_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_l_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_l_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_l_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_l_tensor_op_f32, 128x128x16_64x64x16) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_l_tensor_op_f32, 64x128x16_32x64x16) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_u_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_u_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ 
ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_u_tensor_op_f32, 64x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_u_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_u_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git 
a/3rdparty/cutlass/test/unit/gemm/device/syrk_tf32t_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_tf32t_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..a1466d6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_tf32t_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,541 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_l_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_l_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_l_tensor_op_f32, 64x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_l_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_l_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_l_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_l_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_l_tensor_op_f32, 128x128x16_64x64x16) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_l_tensor_op_f32, 64x128x16_32x64x16) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_u_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_u_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ 
ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_u_tensor_op_f32, 64x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_u_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_u_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git 
a/3rdparty/cutlass/test/unit/gemm/device/testbed.h b/3rdparty/cutlass/test/unit/gemm/device/testbed.h -new file mode 100644 -index 0000000..dc21f41 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed.h -@@ -0,0 +1,600 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_utils.h" -+#include "testbed_universal.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct Testbed { -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementC = typename Gemm::ElementC; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ typename Gemm::LayoutA::Stride stride_factor_A; -+ typename Gemm::LayoutB::Stride stride_factor_B; -+ typename Gemm::LayoutC::Stride stride_factor_C; -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D; -+ cutlass::HostTensor reference_D; -+ -+ // -+ // Methods -+ // -+ -+ Testbed( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ stride_factor_A(typename Gemm::LayoutA::Stride()), -+ stride_factor_B(typename Gemm::LayoutB::Stride()), -+ stride_factor_C(typename Gemm::LayoutC::Stride()), -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ Testbed( -+ typename Gemm::LayoutA::Stride stride_factor_A_, -+ typename Gemm::LayoutB::Stride stride_factor_B_, -+ typename Gemm::LayoutC::Stride stride_factor_C_, -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ stride_factor_A(stride_factor_A_), -+ stride_factor_B(stride_factor_B_), -+ stride_factor_C(stride_factor_C_), -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ 
view, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Initializes data structures -+ void initialize(cutlass::gemm::GemmCoord problem_size) { -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ tensor_A.resize(problem_size.mk(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mk(), stride_factor_A)); -+ tensor_B.resize(problem_size.kn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.kn(), stride_factor_B)); -+ tensor_C.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C)); -+ tensor_D.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C)); -+ reference_D.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C), false); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); -+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. 
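    // (The distinct seed offsets above keep A, B, and C decorrelated, and the
    // explicit non-zero writes below keep the TensorNorm() > 0 sanity checks in
    // compare_reference() from tripping on an unlucky all-zero fill.)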
-+ tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); -+ tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); -+ tensor_C.host_view().at(cutlass::make_Coord(0, 0)) = typename Gemm::ElementC(1); -+ -+ cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); -+ -+ if (tensor_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ -+ if (reference_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ -+ bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ -+ std::stringstream fname; -+ -+ fname << "error_Gemm_device_" -+ << problem_size.m() << "x" -+ << problem_size.n() << "x" -+ << problem_size.k() << "_" -+ << Gemm::ThreadblockShape::kM << "x" -+ << Gemm::ThreadblockShape::kN << "x" -+ << Gemm::ThreadblockShape::kK << "_" -+ << Gemm::WarpShape::kM << "x" -+ << Gemm::WarpShape::kN << "x" -+ << Gemm::WarpShape::kK << ".txt"; -+ -+ std::ofstream file(fname.str()); -+ -+ file -+ << "problem: " << problem_size -+ << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; -+ -+ file -+ << "A =\n" << tensor_A.host_view() -+ << "\nB =\n" << tensor_B.host_view() -+ << "\nC =\n" << tensor_C.host_view() -+ << "\n\nReference =\n" << reference_D.host_view() -+ << "\nComputed =\n" << tensor_D.host_view(); -+ } -+ -+ return passed; -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::host::Gemm< -+ typename Gemm::ElementA, typename Gemm::LayoutA, -+ typename Gemm::ElementB, typename Gemm::LayoutB, -+ typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, -+ ElementAccumulator, typename Gemm::Operator> -+ reference_gemm; -+ -+ reference_gemm( -+ problem_size, -+ alpha, -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ beta, -+ reference_D.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ if (Relu) { -+ for (int i = 0; i < problem_size.m(); ++i) { -+ for (int j = 0; j < problem_size.n(); ++j) { -+ reference_D.at(cutlass::MatrixCoord(i, j)) = -+ ((ElementCompute)reference_D.at(cutlass::MatrixCoord(i, j)) < (ElementCompute)0) -+ ? 
(typename Gemm::ElementC)0 -+ : reference_D.at(cutlass::MatrixCoord(i, j)); -+ } -+ } -+ } -+ -+ return compare_reference(problem_size, alpha, beta); -+ } -+ -+ /// Determine if the CUDA device is sufficient to run the kernel -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size, -+ int split_k_slices = 1, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) -+ { -+/* -+ std::cout << "\n-----------------------\n"; -+ std::cout << "problem size: " << problem_size << "\n"; -+ std::cout << "split_k_slices: " << split_k_slices << "\n"; -+ std::cout << "alpha: " << alpha << "\n"; -+ std::cout << "beta: " << beta << "\n"; -+ std::cout << "-----------------------\n\n"; -+*/ -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm::Arguments arguments{ -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D.device_ref(), -+ {alpha, beta}, -+ split_k_slices -+ }; -+ -+ Gemm gemm_op; -+ -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); -+ -+ if (status != cutlass::Status::kSuccess) { -+ cudaError_t error = cudaGetLastError(); -+ std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; -+ return true; -+ } -+ -+ // -+ // Run the GEMM -+ // -+ -+ status = gemm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ -+ bool passed = this->verify(problem_size, alpha, beta); -+ -+ if (!passed) { -+ std::cout << "Error with split_k_slices = " << split_k_slices << ", alpha: " << alpha << std::endl; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+bool TestAllGemmBasic( -+ const typename Gemm::LayoutA::Stride& stride_factor_A = typename Gemm::LayoutA::Stride(), -+ const typename Gemm::LayoutB::Stride& stride_factor_B = typename Gemm::LayoutB::Stride(), -+ const typename Gemm::LayoutC::Stride& stride_factor_C = typename Gemm::LayoutC::Stride()) { -+ bool passed = true; -+ -+ int const kMinimumOperandElementSize = -+ std::min( -+ int(cutlass::sizeof_bits::value), -+ int(cutlass::sizeof_bits::value)); -+ -+ int const kAlignment = cutlass::platform::is_same< -+ typename Gemm::OperatorClass, -+ cutlass::arch::OpClassSimt>::value ? 
1 : 128 / kMinimumOperandElementSize; -+ -+ // int8_t gemm alignment constraints -+ int const kAlignmentM = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value ? 4 : kAlignment; -+ -+ int const kAlignmentN = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value ? 4 : kAlignment; -+ -+ int const kAlignmentK = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ (cutlass::platform::is_same::value || -+ cutlass::platform::is_same::value) ? 4 : kAlignment; -+ -+ int problem_size_m[] = {kAlignmentM, 512 - 3 * kAlignmentM}; -+ -+ int problem_size_n[] = {kAlignmentN, 512 - 2 * kAlignmentN}; -+ -+ int problem_size_k[] = { -+ kAlignmentK, Gemm::ThreadblockShape::kK * (Gemm::kStages + 1) - kAlignmentK}; -+ -+ int split_k_slices[] = { -+ 1, 2, 3 -+ }; -+ -+ double problem_alpha[] = { -+ 1 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ Testbed testbed(stride_factor_A, stride_factor_B, stride_factor_C); -+ -+ using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; -+ -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ for (int split_k : split_k_slices) { -+ -+ if (!Gemm::kSplitKSerial && split_k > 1) { -+ continue; -+ } -+ -+ if (split_k > 1 && k / Gemm::ThreadblockShape::kK < split_k) { -+ continue; -+ } -+ -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ cutlass::gemm::GemmCoord problem_size(m, n, k); -+ passed = testbed.run( -+ problem_size, -+ split_k, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+bool TestAllGemm( -+ const typename Gemm::LayoutA::Stride& stride_factor_A, -+ const typename Gemm::LayoutB::Stride& stride_factor_B = typename Gemm::LayoutB::Stride(), -+ const typename Gemm::LayoutC::Stride& stride_factor_C = typename Gemm::LayoutC::Stride()) -+{ -+ // Test basic GEMM with non-default stride factors -+ return TestAllGemmBasic(stride_factor_A, stride_factor_B, stride_factor_C); -+} -+ -+template -+bool TestAllGemm() -+{ -+#ifdef NDEBUG -+ // Non-debug builds also test basic GEMM with default stride factors -+ if (!TestAllGemmBasic()) { -+ return false; -+ } -+#endif // NDEBUG -+ -+ // Test universal GEMM -+#if 0 -+ // Define the universal kernel -+ using UniversalKernel = cutlass::gemm::kernel::GemmUniversal< -+ typename Gemm::GemmKernel::Mma, // Mma -+ typename Gemm::GemmKernel::Epilogue, // Epilogue -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<> // ThreadblockSwizzle -+ >; -+#else -+ // Define the streamk universal kernel -+ using UniversalKernel = cutlass::gemm::kernel::GemmUniversalStreamk< -+ typename Gemm::GemmKernel::Mma, // Mma -+ typename Gemm::GemmKernel::Epilogue, // Epilogue -+ cutlass::gemm::threadblock::ThreadblockSwizzleStreamK // ThreadblockSwizzle -+ >; -+#endif -+ -+ // Define the universal adaptor -+ using UniversalGemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ // Test universal GEMM -+ return TestAllGemmUniversal(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+bool TestGemmPerf(int iterations = 1) { -+ bool passed = true; -+ -+ int 
problem_size_m[] = { 2048 }; -+ -+ int problem_size_n[] = { 4352 }; -+ -+ int problem_size_k[] = { 4096 }; -+ -+ int split_k_slices[] = { 1 }; -+ double problem_alpha[] = { 1 }; -+ double problem_beta[] = { 0.0 }; -+ -+ Testbed testbed; -+ -+ using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; -+ -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ for (int split_k : split_k_slices) { -+ -+ if (!Gemm::kSplitKSerial && split_k > 1) { -+ continue; -+ } -+ -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ cutlass::gemm::GemmCoord problem_size(m, n, k); -+ -+ for (int i = 0; i < iterations; i++){ -+ passed = testbed.run( -+ problem_size, -+ split_k, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ } -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_complex.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_complex.h -new file mode 100644 -index 0000000..244bc06 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_complex.h -@@ -0,0 +1,294 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct TestbedComplex : public Testbed { -+ -+ using Base = Testbed; -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementC = typename Gemm::ElementC; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ -+ // -+ // Methods -+ // -+ -+ TestbedComplex( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ Base(init_A_, init_B_, init_C_, seed_) { } -+ -+ -+ /// Verifies the result is a GEMM -+ bool verify( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::host::GemmComplex( -+ problem_size, -+ alpha, -+ this->tensor_A.host_ref(), -+ Gemm::kTransformA, -+ this->tensor_B.host_ref(), -+ Gemm::kTransformB, -+ beta, -+ this->tensor_C.host_ref(), -+ this->reference_D.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ return this->compare_reference(problem_size, alpha, beta); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size, -+ int split_k_slices = 1, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; -+ } -+ return true; -+ } -+ -+ // -+ // Initialize workspace -+ // -+ -+ this->initialize(problem_size); -+ -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm::Arguments arguments{ -+ problem_size, -+ this->tensor_A.device_ref(), -+ this->tensor_B.device_ref(), -+ this->tensor_C.device_ref(), -+ this->tensor_D.device_ref(), -+ {alpha, beta}, -+ split_k_slices -+ }; -+ -+ Gemm gemm_op; -+ -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Run the GEMM -+ // -+ -+ status = gemm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ -+ bool passed = this->verify(problem_size, alpha, beta); -+ -+ if (!passed) { -+ std::cout << "Error with split_k_slices = " << split_k_slices << ", alpha: " << alpha << std::endl; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+bool TestAllGemmComplex() { -+ bool passed = true; -+ -+ using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; -+ -+ int const kMinimumOperandElementSize = -+ std::min( -+ int(cutlass::sizeof_bits::value), -+ int(cutlass::sizeof_bits::value)); -+ -+ int const kAlignment = -+ cutlass::platform::is_same< -+ typename Gemm::OperatorClass, -+ cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; -+ -+ int problem_size_m[] = { -+ kAlignment, 512 - 3*kAlignment -+ }; -+ -+ int problem_size_n[] = { -+ kAlignment, 512 - 2*kAlignment -+ }; -+ -+ int problem_size_k[] = { -+ kAlignment, 128 - kAlignment -+ }; -+ -+ int split_k_slices[] = { -+ 1, 2, 3 -+ }; -+ -+ double problem_alpha[] = { -+ 1 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ TestbedComplex testbed; -+ -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ for (int split_k : split_k_slices) { -+ -+ if (!Gemm::kSplitKSerial && split_k > 1) { -+ continue; -+ } -+ -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ cutlass::gemm::GemmCoord problem_size(m, n, k); -+ -+ passed = testbed.run( -+ problem_size, -+ split_k, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_gemm_with_broadcast.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_gemm_with_broadcast.h -new file mode 100644 -index 0000000..10d5d3f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_gemm_with_broadcast.h -@@ -0,0 +1,657 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+ -+#include "testbed_utils.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct GemmWithBroadcastReferenceOp { -+ -+ using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; -+ -+ using ElementCompute = typename OutputOp::ElementCompute; -+ using ElementZ = typename OutputOp::ElementZ; -+ using ElementT = typename OutputOp::ElementT; -+ -+ typename OutputOp::BinaryOp binary_op; -+ typename OutputOp::ElementwiseOp elementwise_op; -+ -+ GemmWithBroadcastReferenceOp() { } -+ -+ void operator()(ElementZ &Z, ElementT &T, ElementCompute gemm, ElementCompute bias) { -+ -+ ElementCompute t_full = binary_op(gemm, bias); -+ T = ElementT(t_full); -+ -+ ElementCompute z_full = elementwise_op(t_full); -+ Z = ElementZ(z_full); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Fused testbed -+// -+// Y = GEMM(AB, C) -+// -+// T[i, j] = BinaryOp(Y[i, j], Broadcast[i]) -+// -+// Z[i, j] = Elementwise(T[i, j]) -+// -+ -+template < -+ typename Gemm, -+ typename ReferenceOp = GemmWithBroadcastReferenceOp -+> -+struct 
TestbedGemmWithBroadcast { -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; -+ using ElementC = typename Gemm::ElementC; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ using ElementCOmpute = typename OutputOp::ElementCompute; -+ using ElementZ = typename OutputOp::ElementZ; -+ using ElementT = typename OutputOp::ElementT; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; // Input A -+ cutlass::HostTensor tensor_B; // Input B -+ cutlass::HostTensor tensor_C; // Input C -+ cutlass::HostTensor tensor_Broadcast; // Input Broadcast -+ -+ cutlass::HostTensor tensor_Z; -+ cutlass::HostTensor tensor_T; -+ -+ cutlass::HostTensor tensor_C_ref; -+ cutlass::HostTensor tensor_Y_ref; -+ cutlass::HostTensor tensor_Z_ref; -+ cutlass::HostTensor tensor_T_ref; -+ -+ -+ // -+ // Methods -+ // -+ -+ TestbedGemmWithBroadcast( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Initializes data structures -+ void initialize(cutlass::gemm::GemmCoord problem_size) { -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ tensor_A.resize(problem_size.mk()); -+ tensor_B.resize(problem_size.kn()); -+ tensor_C.resize(problem_size.mn()); -+ tensor_Z.resize(problem_size.mn()); -+ tensor_T.resize(problem_size.mn()); -+ tensor_Broadcast.resize({ -+ problem_size.m(), -+ 1 -+ }); -+ -+ tensor_C_ref.resize(problem_size.mn()); -+ tensor_Y_ref.resize(problem_size.mn()); -+ tensor_Z_ref.resize(problem_size.mn()); -+ tensor_T_ref.resize(problem_size.mn()); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); -+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); -+ 
EXPECT_TRUE(initialize_tensor(tensor_Broadcast.host_view(), init_C, seed + 2020)); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. -+ tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); -+ tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); -+ tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); -+ -+ for (int m = 0; m < tensor_C_ref.extent().row(); ++m) { -+ for (int n = 0; n < tensor_C_ref.extent().column(); ++n) { -+ tensor_C_ref.at({m, n}) = ElementAccumulator(tensor_C.at({m, n})); -+ } -+ } -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_Broadcast.sync_device(); -+ -+ tensor_Z.sync_device(); -+ tensor_T.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementAccumulator alpha, -+ ElementAccumulator beta) { -+ -+ tensor_Z.sync_host(); -+ tensor_T.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Z.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_T.host_view()), 0); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Z_ref.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_T_ref.host_view()), 0); -+ -+ bool passed = true; -+ float norm_diff = 0; -+ -+ if (OutputOp::kStoreZ) { -+ norm_diff = cutlass::reference::host::TensorNormDiff(tensor_Z_ref.host_view(), tensor_Z.host_view(), float()); -+ passed = (norm_diff <= 0.1f); -+ EXPECT_LT(norm_diff, 0.1f) << " tensor_Z is incorrect"; -+ } -+ -+ if (OutputOp::kStoreT) { -+ -+ norm_diff = cutlass::reference::host::TensorNormDiff(tensor_T_ref.host_view(), tensor_T.host_view(), float()); -+ passed = (passed && (norm_diff <= 0.1f)); -+ -+ EXPECT_LT(norm_diff, 0.1f) << " tensor_T is incorrect"; -+ } -+ -+ -+ if (!passed) { -+ -+ /* -+ std::stringstream fname; -+ -+ fname << "error_Gemm_device_" -+ << problem_size.m() << "x" -+ << problem_size.n() << "x" -+ << problem_size.k() << "_" -+ << Gemm::ThreadblockShape::kM << "x" -+ << Gemm::ThreadblockShape::kN << "x" -+ << Gemm::ThreadblockShape::kK << "_" -+ << Gemm::WarpShape::kM << "x" -+ << Gemm::WarpShape::kN << "x" -+ << Gemm::WarpShape::kK << ".txt"; -+ -+ std::ofstream file(fname.str()); -+ */ -+ -+ std::ofstream file("errors_testbed_gemm_with_broadcast.txt"); -+ -+ -+ file -+ << "problem: " << problem_size -+ << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; -+ -+ file -+ << "A =\n" << tensor_A.host_view() -+ << "\nB =\n" << tensor_B.host_view() -+ << "\nC =\n" << tensor_C.host_view() -+ << "\nZ =\n" << tensor_Z.host_view() -+ << "\nT =\n" << tensor_T.host_view() -+ << "\n\n" -+ << "\nY_ref =\n" << tensor_Y_ref.host_view() -+ << "\nZ_ref =\n" << tensor_Z_ref.host_view() -+ << "\nT_ref =\n" << tensor_T_ref.host_view(); -+ } -+ -+ return passed; -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementAccumulator alpha, -+ ElementAccumulator beta) { -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::host::GemmComplex< -+ typename Gemm::ElementA, typename Gemm::LayoutA, -+ typename Gemm::ElementB, typename Gemm::LayoutB, -+ 
ElementAccumulator, typename Gemm::LayoutC, -+ ElementAccumulator, ElementAccumulator -+ >( -+ problem_size, -+ alpha, -+ tensor_A.host_ref(), -+ Gemm::kTransformA, -+ tensor_B.host_ref(), -+ Gemm::kTransformB, -+ beta, -+ tensor_C_ref.host_ref(), -+ tensor_Y_ref.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ using ElementC = typename Gemm::ElementC; -+ -+ ReferenceOp reference_op; -+ -+ // compute tensor Z and tensor T -+ for (int m = 0; m < problem_size.m(); ++m) { -+ for (int n = 0; n < problem_size.n(); ++n) { -+ -+ ElementZ z; -+ ElementT t; -+ -+ reference_op(z, t, tensor_Y_ref.at({m, n}), tensor_Broadcast.at({m, 0})); -+ -+ tensor_Z_ref.at({m, n}) = z; -+ tensor_T_ref.at({m, n}) = t; -+ } -+ } -+ -+ return compare_reference(problem_size, alpha, beta); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmUniversalMode mode, -+ cutlass::gemm::GemmCoord problem_size, -+ int batch_count = 1, -+ ElementAccumulator alpha = ElementAccumulator(1), -+ ElementAccumulator beta = ElementAccumulator(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm::Arguments arguments{ -+ mode, -+ problem_size, -+ batch_count, -+ {alpha, beta}, -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_C.device_data(), -+ tensor_Z.device_data(), -+ tensor_Broadcast.device_data(), -+ tensor_T.device_data(), -+ problem_size.m() * problem_size.k(), -+ problem_size.n() * problem_size.k(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m(), -+ problem_size.m() * problem_size.n(), -+ tensor_A.layout().stride(0), -+ tensor_B.layout().stride(0), -+ tensor_C.layout().stride(0), -+ tensor_Z.layout().stride(0), -+ 0, // This must be zero -+ tensor_T.layout().stride(0), -+ }; -+ -+ Gemm gemm_op; -+ -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); -+ -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Run the GEMM -+ // -+ -+ status = gemm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ -+ bool passed = true; -+ -+ passed = this->verify(problem_size, alpha, beta); -+ -+ if (!passed) { -+ std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; -+ } -+ -+ // -+ // Profile -+ // -+ -+ #if 0 // profiling disabled for now. 
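    // When re-enabled, this block rotates the operands across kWorkspaces
    // device-side copies so successive launches touch different memory, runs
    // kWarmupIterations warm-up launches, and then reports the average latency
    // over kProfilingIterations launches measured with CUDA events.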
-+ -+ int const kWorkspaces = 100; -+ -+ cutlass::DeviceAllocation profiling_tensor_A(tensor_A.capacity() * kWorkspaces); -+ cutlass::DeviceAllocation profiling_tensor_B(tensor_B.capacity() * kWorkspaces); -+ cutlass::DeviceAllocation profiling_tensor_C(tensor_C.capacity() * kWorkspaces); -+ cutlass::DeviceAllocation profiling_tensor_Broadcast(tensor_Broadcast.capacity() * kWorkspaces); -+ cutlass::DeviceAllocation profiling_tensor_Z(tensor_Z.capacity() * kWorkspaces); -+ cutlass::DeviceAllocation profiling_tensor_T(tensor_T.capacity() * kWorkspaces); -+ -+ cudaEvent_t events[2]; -+ for (auto & event : events) { -+ cudaError_t result = cudaEventCreate(&event); -+ if (result != cudaSuccess) { -+ EXPECT_EQ(result, cudaSuccess) << " cudaEventCreate() failed with error " << cudaGetErrorString(result); -+ return false; -+ break; -+ } -+ } -+ -+ int const kWarmupIterations = 5; -+ int const kProfilingIterations = 100; -+ -+ for (int i = 0; i < kWarmupIterations; ++i) { -+ status = gemm_op(); -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ } -+ -+ -+ cudaError_t result = cudaEventRecord(events[0]); -+ EXPECT_EQ(result, cudaSuccess); -+ -+ for (int i = 0; i < kProfilingIterations; ++i) { -+ -+ typename Gemm::Arguments arguments{ -+ mode, -+ problem_size, -+ batch_count, -+ {alpha, beta}, -+ profiling_tensor_A.get() + tensor_A.capacity() * (i % kWorkspaces), -+ profiling_tensor_B.get() + tensor_B.capacity() * (i % kWorkspaces), -+ profiling_tensor_C.get() + tensor_C.capacity() * (i % kWorkspaces), -+ profiling_tensor_Z.get() + tensor_Z.capacity() * (i % kWorkspaces), -+ profiling_tensor_Broadcast.get() + tensor_Broadcast.capacity() * (i % kWorkspaces), -+ profiling_tensor_T.get() + tensor_T.capacity() * (i % kWorkspaces), -+ problem_size.m() * problem_size.k(), -+ problem_size.n() * problem_size.k(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m(), -+ problem_size.m() * problem_size.n(), -+ tensor_A.layout().stride(0), -+ tensor_B.layout().stride(0), -+ tensor_C.layout().stride(0), -+ tensor_Z.layout().stride(0), -+ 0, // This must be zero -+ tensor_T.layout().stride(0), -+ }; -+ -+ gemm_op.initialize(arguments, workspace.get()); -+ status = gemm_op(); -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ } -+ -+ result = cudaEventRecord(events[1]); -+ EXPECT_EQ(result, cudaSuccess); -+ -+ result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess); -+ -+ float elapsed_time = 0; -+ result = cudaEventElapsedTime(&elapsed_time, events[0], events[1]); -+ EXPECT_EQ(result, cudaSuccess); -+ -+ double average_time = double(elapsed_time) / double(kProfilingIterations); -+ -+ std::cout << problem_size << ": " << average_time << " ms" << std::endl; -+ -+ for (auto & event : events) { -+ cudaEventDestroy(event); -+ } -+ #endif -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Gemm, -+ typename ReferenceOp = GemmWithBroadcastReferenceOp -+> -+bool TestGemmWithBroadcast( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmUniversalMode mode, -+ int batch_count, -+ double alpha = 1.0, -+ double beta = 2.0) { -+ -+ bool passed = true; -+ -+ TestbedGemmWithBroadcast testbed; -+ -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); 
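  // TestGemmWithBroadcast exercises a single (mode, problem size, batch count,
  // alpha, beta) combination; the sweep over shapes and scalars is
  // TestAllGemmWithBroadcast below.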
-+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Gemm, -+ typename ReferenceOp = GemmWithBroadcastReferenceOp -+> -+bool TestAllGemmWithBroadcast() { -+ -+ int M_problems[] = {8, 136, 264, 520}; -+ int N_problems[] = {8, 136, 264, 520}; -+ int K_problems[] = {8, 136, 264, 520}; -+ double alpha_problems[] = {1.25, 2.25}; -+ double beta_problems[] = {0, 1, 2.0}; -+ -+ bool passed = true; -+ -+ for (int M : M_problems) { -+ for (int N : N_problems) { -+ for (int K : K_problems) { -+ for (double alpha : alpha_problems) { -+ for (double beta : beta_problems) { -+ -+ TestbedGemmWithBroadcast testbed; -+ -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ -+ passed = testbed.run( -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ {M, N, K}, -+ 1, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ EXPECT_TRUE(passed) -+ << "M: " << M << ", N: " << N << ", K: " << K << ", alpha: " << alpha << ", beta: " << beta; -+ -+ if (!passed) { -+ -+ return passed; -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_gemm_with_reduction.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_gemm_with_reduction.h -new file mode 100644 -index 0000000..6f220b1 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_gemm_with_reduction.h -@@ -0,0 +1,589 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+ -+#include "testbed_utils.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct GemmWithReductionReference { -+ -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ using ElementCompute = typename Gemm::GemmKernel::Epilogue::ElementCompute; -+ using ElementC = typename Gemm::ElementC; -+ using ElementT = typename Gemm::GemmKernel::Epilogue::ElementTensor; -+ // -+ // Data members -+ // -+ -+ BinaryOp binary_op; -+ -+ // -+ // Methods -+ // -+ -+ GemmWithReductionReference() { } -+ -+ ElementCompute operator()( -+ ElementAccumulator d_y, -+ ElementT t) { -+ -+ return binary_op(ElementCompute(d_y), ElementCompute(t)); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Gemm, -+ typename ReferenceOp -+> -+struct TestbedGemmWithReduction { -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementC = typename Gemm::ElementC; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ using ElementT = typename Gemm::GemmKernel::Epilogue::ElementTensor; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D; -+ cutlass::HostTensor tensor_Reduction; -+ cutlass::HostTensor tensor_Tensor; -+ cutlass::HostTensor tensor_C_ref; -+ cutlass::HostTensor reference_d_Y; -+ cutlass::HostTensor reference_D; -+ cutlass::HostTensor reference_Reduction; -+ -+ // -+ // Methods -+ // -+ -+ TestbedGemmWithReduction( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ 
cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ for (int m = 0; m < view.extent().row(); ++m) { -+ for (int n = 0; n < view.extent().column(); ++n) { -+ //view.at({m, n}) = Element(float(((idx ++) % 17) - 8)); -+ view.at({m, n}) = (n == 0 ? Element(m) : Element()); -+ -+ } -+ } -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Initializes data structures -+ void initialize(cutlass::gemm::GemmCoord problem_size) { -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ tensor_A.resize(problem_size.mk()); -+ tensor_B.resize(problem_size.kn()); -+ tensor_C.resize(problem_size.mn()); -+ tensor_D.resize(problem_size.mn()); -+ -+ tensor_Reduction.resize({ -+ problem_size.m(), -+ (problem_size.n() - 1 + Gemm::ThreadblockShape::kN) / Gemm::ThreadblockShape::kN -+ }); -+ -+ tensor_Tensor.resize(problem_size.mn()); -+ reference_D.resize(problem_size.mn(), false); -+ reference_d_Y.resize(problem_size.mn(), false); -+ tensor_C_ref.resize(problem_size.mn(), false); -+ reference_Reduction.resize({problem_size.m(), 1}, false); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); -+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); -+ EXPECT_TRUE(initialize_tensor(tensor_Tensor.host_view(), init_C, seed + 2020)); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. 
-+ tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); -+ tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); -+ tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); -+ -+ for (int m = 0; m < tensor_C_ref.extent().row(); ++m) { -+ for (int n = 0; n < tensor_C_ref.extent().column(); ++n) { -+ tensor_C_ref.at({m, n}) = ElementAccumulator(tensor_C.at({m, n})); -+ } -+ } -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ tensor_Reduction.sync_device(); -+ tensor_Tensor.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementAccumulator alpha, -+ ElementAccumulator beta) { -+ -+ tensor_Reduction.sync_host(); -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Reduction.host_view()), 0); -+ -+ bool passed = true; -+ for (int m = 0; m < tensor_Reduction.extent().row(); ++m) { -+ -+ ElementAccumulator reduced_value = ElementAccumulator(); -+ for (int j = 0; j < tensor_Reduction.extent().column(); ++j) { -+ reduced_value += tensor_Reduction.at({m, j}); -+ } -+ -+ if (reduced_value != reference_Reduction.at({m, 0})) { -+ std::cout << "Error in bias[" << m << "] - Expected: " << reference_Reduction.at({m, 0}) << ", got: " << reduced_value << std::endl; -+ passed = false; -+ break; -+ } -+ } -+ EXPECT_TRUE(passed) << "Reduction is incorect."; -+ -+ if (!cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view())) { -+ EXPECT_TRUE(false) << " mismatched reference"; -+ passed = false; -+ } -+ -+ if (!passed) { -+ -+ /* -+ std::stringstream fname; -+ -+ fname << "error_Gemm_device_" -+ << problem_size.m() << "x" -+ << problem_size.n() << "x" -+ << problem_size.k() << "_" -+ << Gemm::ThreadblockShape::kM << "x" -+ << Gemm::ThreadblockShape::kN << "x" -+ << Gemm::ThreadblockShape::kK << "_" -+ << Gemm::WarpShape::kM << "x" -+ << Gemm::WarpShape::kN << "x" -+ << Gemm::WarpShape::kK << ".txt"; -+ -+ std::ofstream file(fname.str()); -+ */ -+ -+ std::ofstream file("testbed_universal_errors_sm70.txt"); -+ -+ file -+ << "problem: " << problem_size -+ << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; -+ -+ file -+ << "A =\n" << tensor_A.host_view() -+ << "\nB =\n" << tensor_B.host_view() -+ << "\nC =\n" << tensor_C.host_view() -+ << "\nT = \n" << tensor_Tensor.host_view() -+ << "\n\nReference =\n" << reference_D.host_view() -+ << "\nComputed =\n" << tensor_D.host_view() -+ << "\n\nReduction =\n" << tensor_Reduction.host_view() << "\n" -+ << "\nReference reduction =\n" << reference_Reduction.host_view() << "\n"; -+ } -+ -+ return passed; -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementAccumulator alpha, -+ ElementAccumulator beta) { -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::host::GemmComplex< -+ typename Gemm::ElementA, typename Gemm::LayoutA, -+ typename Gemm::ElementB, typename Gemm::LayoutB, -+ ElementAccumulator, typename Gemm::LayoutC, -+ ElementAccumulator, 
ElementAccumulator -+ >( -+ problem_size, -+ alpha, -+ tensor_A.host_ref(), -+ Gemm::kTransformA, -+ tensor_B.host_ref(), -+ Gemm::kTransformB, -+ beta, -+ tensor_C_ref.host_ref(), -+ reference_d_Y.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ using ElementC = typename Gemm::ElementC; -+ -+ ReferenceOp reference_op; -+ -+ // compute backwards -+ for (int m = 0; m < problem_size.m(); ++m) { -+ ElementAccumulator reduced_value = ElementAccumulator(); -+ for (int n = 0; n < problem_size.n(); ++n) { -+ ElementAccumulator d_full = reference_op(reference_d_Y.at({m, n}), tensor_Tensor.at({m, n})); -+ reduced_value += d_full; -+ reference_D.at({m, n}) = ElementC(d_full); -+ } -+ reference_Reduction.at({m, 0}) = reduced_value; -+ } -+ -+ return compare_reference(problem_size, alpha, beta); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmUniversalMode mode, -+ cutlass::gemm::GemmCoord problem_size, -+ int batch_count = 1, -+ ElementAccumulator alpha = ElementAccumulator(1), -+ ElementAccumulator beta = ElementAccumulator(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm::Arguments arguments{ -+ mode, -+ problem_size, -+ batch_count, -+ {alpha, beta}, -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_C.device_data(), -+ tensor_D.device_data(), -+ tensor_Reduction.device_data(), -+ tensor_Tensor.device_data(), -+ problem_size.m() * problem_size.k(), -+ problem_size.n() * problem_size.k(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m(), -+ problem_size.m() * problem_size.n(), -+ tensor_A.layout().stride(0), -+ tensor_B.layout().stride(0), -+ tensor_C.layout().stride(0), -+ tensor_D.layout().stride(0), -+ tensor_Reduction.layout().stride(0), -+ tensor_Tensor.layout().stride(0), -+ }; -+ -+ Gemm gemm_op; -+ -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); -+ -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Run the GEMM -+ // -+ -+ status = gemm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ -+ bool passed = this->verify(problem_size, alpha, beta); -+ -+ if (!passed) { -+ std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; -+ } -+ -+ // -+ // Profile -+ // -+ -+ #if 0 // profiling disabled for now. 
-+ -+ int const kWorkspaces = 100; -+ -+ cutlass::DeviceAllocation profiling_tensor_A(tensor_A.capacity() * kWorkspaces); -+ cutlass::DeviceAllocation profiling_tensor_B(tensor_B.capacity() * kWorkspaces); -+ cutlass::DeviceAllocation profiling_tensor_C(tensor_C.capacity() * kWorkspaces); -+ cutlass::DeviceAllocation profiling_tensor_D(tensor_D.capacity() * kWorkspaces); -+ cutlass::DeviceAllocation profiling_tensor_Reduction(tensor_Reduction.capacity() * kWorkspaces); -+ cutlass::DeviceAllocation profiling_tensor_Tensor(tensor_Tensor.capacity() * kWorkspaces); -+ -+ cudaEvent_t events[2]; -+ for (auto & event : events) { -+ cudaError_t result = cudaEventCreate(&event); -+ if (result != cudaSuccess) { -+ EXPECT_EQ(result, cudaSuccess) << " cudaEventCreate() failed with error " << cudaGetErrorString(result); -+ return false; -+ break; -+ } -+ } -+ -+ int const kWarmupIterations = 5; -+ int const kProfilingIterations = 100; -+ -+ for (int i = 0; i < kWarmupIterations; ++i) { -+ status = gemm_op(); -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ } -+ -+ -+ cudaError_t result = cudaEventRecord(events[0]); -+ EXPECT_EQ(result, cudaSuccess); -+ -+ for (int i = 0; i < kProfilingIterations; ++i) { -+ -+ typename Gemm::Arguments arguments{ -+ mode, -+ problem_size, -+ batch_count, -+ {alpha, beta}, -+ profiling_tensor_A.get() + tensor_A.capacity() * (i % kWorkspaces), -+ profiling_tensor_B.get() + tensor_B.capacity() * (i % kWorkspaces), -+ profiling_tensor_C.get() + tensor_C.capacity() * (i % kWorkspaces), -+ profiling_tensor_D.get() + tensor_D.capacity() * (i % kWorkspaces), -+ profiling_tensor_Reduction.get() + tensor_Reduction.capacity() * (i % kWorkspaces), -+ profiling_tensor_Tensor.get() + tensor_Tensor.capacity() * (i % kWorkspaces), -+ problem_size.m() * problem_size.k(), -+ problem_size.n() * problem_size.k(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m(), -+ problem_size.m() * problem_size.n(), -+ tensor_A.layout().stride(0), -+ tensor_B.layout().stride(0), -+ tensor_C.layout().stride(0), -+ tensor_D.layout().stride(0), -+ tensor_Reduction.layout().stride(0), -+ tensor_Tensor.layout().stride(0), -+ }; -+ -+ gemm_op.initialize(arguments, workspace.get()); -+ status = gemm_op(); -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ } -+ -+ result = cudaEventRecord(events[1]); -+ EXPECT_EQ(result, cudaSuccess); -+ -+ result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess); -+ -+ float elapsed_time = 0; -+ result = cudaEventElapsedTime(&elapsed_time, events[0], events[1]); -+ EXPECT_EQ(result, cudaSuccess); -+ -+ double average_time = double(elapsed_time) / double(kProfilingIterations); -+ -+ std::cout << problem_size << ": " << average_time << " ms" << std::endl; -+ -+ for (auto & event : events) { -+ cudaEventDestroy(event); -+ } -+ #endif -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+bool TestGemmWithReduction( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmUniversalMode mode, -+ int batch_count = 1, -+ double alpha = 1.0, -+ double beta = 2.0) { -+ -+ bool passed = true; -+ -+ TestbedGemmWithReduction testbed; -+ -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ return passed; -+} -+ 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped.h -new file mode 100644 -index 0000000..c5ee3ce ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped.h -@@ -0,0 +1,501 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_grouped.h" -+#include "cutlass/gemm/kernel/default_gemm_grouped.h" -+#include "cutlass/gemm/device/gemm_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct TestbedGrouped { -+ -+ // -+ // Type definitions -+ // -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementC = typename Gemm::ElementC; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ -+ using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; -+ using ElementCompute = typename EpilogueOutputOp::ElementCompute; -+ -+ using LayoutA = typename Gemm::LayoutA; -+ using LayoutB = typename Gemm::LayoutB; -+ using LayoutC = typename Gemm::LayoutC; -+ -+ using MatrixCoord = typename LayoutC::TensorCoord; -+ -+ // -+ // Data members -+ // -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint32_t seed; -+ -+ int problem_count; -+ -+ std::vector problem_sizes_host; -+ cutlass::DeviceAllocation problem_sizes_device; -+ -+ std::vector offset_A; -+ std::vector offset_B; -+ std::vector offset_C; -+ std::vector offset_D; -+ -+ std::vector lda_host; -+ std::vector ldb_host; -+ std::vector ldc_host; -+ std::vector ldd_host; -+ -+ cutlass::DeviceAllocation lda; -+ cutlass::DeviceAllocation ldb; -+ cutlass::DeviceAllocation ldc; -+ cutlass::DeviceAllocation ldd; -+ -+ cutlass::DeviceAllocation block_A; -+ cutlass::DeviceAllocation block_B; -+ cutlass::DeviceAllocation block_C; -+ cutlass::DeviceAllocation block_D; -+ -+ cutlass::DeviceAllocation ptr_A; -+ cutlass::DeviceAllocation ptr_B; -+ cutlass::DeviceAllocation ptr_C; -+ cutlass::DeviceAllocation ptr_D; -+ -+ // -+ // Methods -+ // -+ -+ TestbedGrouped( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint32_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ if (cutlass::sizeof_bits::value <= 16) { -+ scope_max = 5; -+ 
scope_min = -5; -+ } -+ else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ // no fill - remain zero -+ } -+ -+ return true; -+ } -+ -+ /// Initializes data structures -+ void initialize() { -+ -+ // -+ // Choose random problem sizes -+ // -+ -+ // construct a few problems of random sizes -+ srand(seed); -+ -+ int64_t total_elements_A = 0; -+ int64_t total_elements_B = 0; -+ int64_t total_elements_C = 0; -+ int64_t total_elements_D = 0; -+ -+ -+ lda_host.resize(problem_count); -+ ldb_host.resize(problem_count); -+ ldc_host.resize(problem_count); -+ ldd_host.resize(problem_count); -+ -+ problem_sizes_host.clear(); -+ problem_sizes_host.resize(problem_count); -+ -+ for (int32_t i = 0; i < problem_count; ++i) { -+ -+ cutlass::gemm::GemmCoord problem( -+ 8 * (rand() % 64) + 24, -+ 8 * (rand() % 64) + 24, -+ 8 * (rand() % 64) + 24); -+ -+ if (!i) { -+ problem = cutlass::gemm::GemmCoord(48, 16, 8); -+ } -+ -+ problem_sizes_host.at(i) = problem; -+ -+ // std::cout << "Problem[" << i << "]: " << problem << std::endl; -+ -+ lda_host.at(i) = LayoutA::packed({problem.m(), problem.k()}).stride(0); -+ ldb_host.at(i) = LayoutB::packed({problem.k(), problem.n()}).stride(0); -+ ldc_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); -+ ldd_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); -+ -+ offset_A.push_back(total_elements_A); -+ offset_B.push_back(total_elements_B); -+ offset_C.push_back(total_elements_C); -+ offset_D.push_back(total_elements_D); -+ -+ int64_t elements_A = problem.m() * problem.k(); -+ int64_t elements_B = problem.k() * problem.n(); -+ int64_t elements_C = problem.m() * problem.n(); -+ int64_t elements_D = problem.m() * problem.n(); -+ -+ total_elements_A += elements_A; -+ total_elements_B += elements_B; -+ total_elements_C += elements_C; -+ total_elements_D += elements_D; -+ -+ // Random strides between problems? 
-+ } -+ -+ problem_sizes_device.reset(problem_count); -+ problem_sizes_device.copy_from_host(problem_sizes_host.data()); -+ -+ lda.reset(problem_count); -+ ldb.reset(problem_count); -+ ldc.reset(problem_count); -+ ldd.reset(problem_count); -+ -+ lda.copy_from_host(lda_host.data()); -+ ldb.copy_from_host(ldb_host.data()); -+ ldc.copy_from_host(ldc_host.data()); -+ ldd.copy_from_host(ldd_host.data()); -+ -+ // -+ // Assign pointers -+ // -+ -+ block_A.reset(total_elements_A); -+ block_B.reset(total_elements_B); -+ block_C.reset(total_elements_C); -+ block_D.reset(total_elements_D); -+ -+ std::vector ptr_A_host(problem_count); -+ std::vector ptr_B_host(problem_count); -+ std::vector ptr_C_host(problem_count); -+ std::vector ptr_D_host(problem_count); -+ -+ for (int32_t i = 0; i < problem_count; ++i) { -+ ptr_A_host.at(i) = block_A.get() + offset_A.at(i); -+ ptr_B_host.at(i) = block_B.get() + offset_B.at(i); -+ ptr_C_host.at(i) = block_C.get() + offset_C.at(i); -+ ptr_D_host.at(i) = block_D.get() + offset_D.at(i); -+ } -+ -+ ptr_A.reset(problem_count); -+ ptr_A.copy_from_host(ptr_A_host.data()); -+ -+ ptr_B.reset(problem_count); -+ ptr_B.copy_from_host(ptr_B_host.data()); -+ -+ ptr_C.reset(problem_count); -+ ptr_C.copy_from_host(ptr_C_host.data()); -+ -+ ptr_D.reset(problem_count); -+ ptr_D.copy_from_host(ptr_D_host.data()); -+ -+ // -+ // Initialize the problems of the workspace -+ // -+ -+ for (int32_t i = 0; i < problem_count; ++i) { -+ cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); -+ -+ LayoutA layout_A(lda_host.at(i)); -+ LayoutB layout_B(ldb_host.at(i)); -+ LayoutC layout_C(ldc_host.at(i)); -+ LayoutC layout_D(ldd_host.at(i)); -+ -+ MatrixCoord extent_A{problem.m(), problem.k()}; -+ MatrixCoord extent_B{problem.k(), problem.n()}; -+ MatrixCoord extent_C{problem.m(), problem.n()}; -+ -+ std::vector matrix_A(layout_A.capacity(extent_A)); -+ std::vector matrix_B(layout_B.capacity(extent_B)); -+ std::vector matrix_C(layout_C.capacity(extent_C)); -+ std::vector matrix_D(layout_D.capacity(extent_C)); -+ -+ initialize_tensor(cutlass::TensorView(matrix_A.data(), layout_A, extent_A), init_A, seed * 2021); -+ initialize_tensor(cutlass::TensorView(matrix_B.data(), layout_B, extent_B), init_B, seed * 2022); -+ initialize_tensor(cutlass::TensorView(matrix_C.data(), layout_C, extent_C), init_C, seed * 2023); -+ -+ cutlass::device_memory::copy_to_device(ptr_A_host.at(i), matrix_A.data(), matrix_A.size()); -+ cutlass::device_memory::copy_to_device(ptr_B_host.at(i), matrix_B.data(), matrix_B.size()); -+ cutlass::device_memory::copy_to_device(ptr_C_host.at(i), matrix_C.data(), matrix_C.size()); -+ cutlass::device_memory::copy_to_device(ptr_D_host.at(i), matrix_D.data(), matrix_D.size()); -+ } -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify( -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ bool passed = true; -+ -+ for (int32_t i = 0; i < problem_count; ++i) { -+ cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); -+ -+ LayoutA layout_A(lda_host.at(i)); -+ LayoutB layout_B(ldb_host.at(i)); -+ LayoutC layout_C(ldc_host.at(i)); -+ LayoutC layout_D(ldd_host.at(i)); -+ -+ MatrixCoord extent_A{problem.m(), problem.k()}; -+ MatrixCoord extent_B{problem.k(), problem.n()}; -+ MatrixCoord extent_C{problem.m(), problem.n()}; -+ -+ std::vector matrix_A(layout_A.capacity(extent_A)); -+ std::vector matrix_B(layout_B.capacity(extent_B)); -+ std::vector matrix_C(layout_C.capacity(extent_C)); -+ std::vector matrix_D(layout_D.capacity(extent_C)); -+ std::vector 
matrix_Ref(layout_D.capacity(extent_C)); -+ -+ cutlass::device_memory::copy_to_host(matrix_A.data(), block_A.get() + offset_A.at(i), matrix_A.size()); -+ cutlass::device_memory::copy_to_host(matrix_B.data(), block_B.get() + offset_B.at(i), matrix_B.size()); -+ cutlass::device_memory::copy_to_host(matrix_C.data(), block_C.get() + offset_C.at(i), matrix_C.size()); -+ cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); -+ -+ cutlass::TensorView view_A(matrix_A.data(), layout_A, extent_A); -+ cutlass::TensorView view_B(matrix_B.data(), layout_B, extent_B); -+ cutlass::TensorView view_C(matrix_C.data(), layout_C, extent_C); -+ cutlass::TensorView view_D(matrix_D.data(), layout_D, extent_C); -+ cutlass::TensorView view_Ref(matrix_Ref.data(), layout_D, extent_C); -+ -+ // Reference GEMM -+ cutlass::reference::host::GemmComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, ElementAccumulator -+ >( -+ problem, -+ alpha, -+ view_A, -+ Gemm::kTransformA, -+ view_B, -+ Gemm::kTransformB, -+ beta, -+ view_C, -+ view_Ref, -+ ElementAccumulator(0) -+ ); -+ -+ // Ensure that no input or output is entirely zero -+ EXPECT_GT(cutlass::reference::host::TensorNorm(view_A), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(view_B), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(view_C), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(view_D), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(view_Ref), 0); -+ -+ // Compare against reference -+ passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); -+ -+ if (!passed) { -+ std::ofstream file("testbed_grouped_errors.txt"); -+ -+ file -+ << "problem: " << problem << " [group: " << i << "]\n" -+ << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; -+ -+ file -+ << "A =\n" << view_A -+ << "\nB =\n" << view_B -+ << "\nC =\n" << view_C -+ << "\n\nReference =\n" << view_Ref -+ << "\nComputed =\n" << view_D; -+ -+ return passed; -+ } -+ } -+ -+ return passed; -+ } -+ -+ /// Executes one test -+ bool run( -+ int problem_count, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ this->problem_count = problem_count; -+ -+ // Initialize the problem -+ initialize(); -+ -+ int threadblock_count = Gemm::sufficient(problem_sizes_host.data(), problem_count); -+ -+ // Early exit -+ if (!threadblock_count) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device resources." 
<< std::endl; -+ } -+ return true; -+ } -+ -+ // Configure the GEMM arguments -+ typename EpilogueOutputOp::Params epilogue_op(alpha, beta); -+ -+ // Configure GEMM arguments -+ typename Gemm::Arguments args( -+ problem_sizes_device.get(), -+ problem_count, -+ threadblock_count, -+ epilogue_op, -+ ptr_A.get(), -+ ptr_B.get(), -+ ptr_C.get(), -+ ptr_D.get(), -+ lda.get(), -+ ldb.get(), -+ ldc.get(), -+ ldd.get(), -+ problem_sizes_host.data() -+ ); -+ -+ // Initialize the GEMM object -+ Gemm gemm; -+ -+ size_t workspace_size = gemm.get_workspace_size(args); -+ cutlass::DeviceAllocation workspace(workspace_size); -+ -+ cutlass::Status status = gemm.initialize(args, workspace.get()); -+ -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ // Run the GEMM object -+ status = gemm.run(); -+ -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ // Wait for completion -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ EXPECT_EQ(result, cudaSuccess) -+ << "Kernel execution error: " << cudaGetErrorString(result); -+ -+ if (result != cudaSuccess) { -+ return false; -+ } -+ -+ // Verify correctness -+ return verify(alpha, beta); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // device -+} // gemm -+} // test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k.h -new file mode 100644 -index 0000000..7b212ae ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k.h -@@ -0,0 +1,502 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for grouped Rank2K interface -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct TestbedGrouped { -+ -+ // -+ // Type definitions -+ // -+ -+ using ElementA = typename Rank2K::ElementA; -+ using ElementB = typename Rank2K::ElementB; -+ using ElementC = typename Rank2K::ElementC; -+ using ElementAccumulator = typename Rank2K::ElementAccumulator; -+ -+ using EpilogueOutputOp = typename Rank2K::EpilogueOutputOp; -+ using ElementCompute = typename EpilogueOutputOp::ElementCompute; -+ -+ using LayoutA = typename Rank2K::LayoutA; -+ using LayoutB = typename Rank2K::LayoutB; -+ using LayoutC = typename Rank2K::LayoutC; -+ -+ using MatrixCoord = typename LayoutC::TensorCoord; -+ -+ // -+ // Data members -+ // -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint32_t seed; -+ -+ int problem_count; -+ -+ std::vector problem_sizes_host; -+ cutlass::DeviceAllocation problem_sizes_device; -+ -+ std::vector offset_A; -+ std::vector offset_B; -+ std::vector offset_C; -+ std::vector offset_D; -+ -+ std::vector lda_host; -+ std::vector ldb_host; -+ std::vector ldc_host; -+ std::vector ldd_host; -+ -+ cutlass::DeviceAllocation lda; -+ cutlass::DeviceAllocation ldb; -+ cutlass::DeviceAllocation ldc; -+ cutlass::DeviceAllocation ldd; -+ -+ cutlass::DeviceAllocation block_A; -+ cutlass::DeviceAllocation block_B; -+ cutlass::DeviceAllocation block_C; -+ cutlass::DeviceAllocation block_D; -+ -+ cutlass::DeviceAllocation ptr_A; -+ cutlass::DeviceAllocation ptr_B; -+ cutlass::DeviceAllocation ptr_C; -+ cutlass::DeviceAllocation ptr_D; -+ -+ // -+ // Methods -+ // -+ -+ TestbedGrouped( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint32_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if 
(bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ if (cutlass::sizeof_bits::value <= 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } -+ else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ // no fill - remain zero -+ } -+ -+ return true; -+ } -+ -+ /// Initializes data structures -+ void initialize() { -+ -+ // -+ // Choose random problem sizes -+ // -+ -+ // construct a few problems of random sizes -+ srand(seed); -+ -+ int64_t total_elements_A = 0; -+ int64_t total_elements_B = 0; -+ int64_t total_elements_C = 0; -+ int64_t total_elements_D = 0; -+ -+ -+ lda_host.resize(problem_count); -+ ldb_host.resize(problem_count); -+ ldc_host.resize(problem_count); -+ ldd_host.resize(problem_count); -+ -+ problem_sizes_host.clear(); -+ problem_sizes_host.resize(problem_count); -+ -+ for (int32_t i = 0; i < problem_count; ++i) { -+ -+ auto N = 8 * (rand() % 64) + 24; -+ auto K = 8 * (rand() % 64) + 24; -+ cutlass::gemm::GemmCoord problem(N, N, K); -+ -+ if (!i) { -+ problem = cutlass::gemm::GemmCoord(16, 16, 8); -+ } -+ -+ problem_sizes_host.at(i) = problem; -+ -+ lda_host.at(i) = LayoutA::packed({problem.n(), problem.k()}).stride(0); -+ ldb_host.at(i) = LayoutB::packed({problem.n(), problem.k()}).stride(0); -+ ldc_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); -+ ldd_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); -+ -+ offset_A.push_back(total_elements_A); -+ offset_B.push_back(total_elements_B); -+ offset_C.push_back(total_elements_C); -+ offset_D.push_back(total_elements_D); -+ -+ int64_t elements_A = problem.n() * problem.k(); -+ int64_t elements_B = problem.n() * problem.k(); -+ int64_t elements_C = problem.n() * problem.n(); -+ int64_t elements_D = problem.n() * problem.n(); -+ -+ total_elements_A += elements_A; -+ total_elements_B += elements_B; -+ total_elements_C += elements_C; -+ total_elements_D += elements_D; -+ -+ // Random strides between problems? 
-+ } -+ -+ problem_sizes_device.reset(problem_count); -+ problem_sizes_device.copy_from_host(problem_sizes_host.data()); -+ -+ lda.reset(problem_count); -+ ldb.reset(problem_count); -+ ldc.reset(problem_count); -+ ldd.reset(problem_count); -+ -+ lda.copy_from_host(lda_host.data()); -+ ldb.copy_from_host(ldb_host.data()); -+ ldc.copy_from_host(ldc_host.data()); -+ ldd.copy_from_host(ldd_host.data()); -+ -+ // -+ // Assign pointers -+ // -+ -+ block_A.reset(total_elements_A); -+ block_B.reset(total_elements_B); -+ block_C.reset(total_elements_C); -+ block_D.reset(total_elements_D); -+ -+ std::vector ptr_A_host(problem_count); -+ std::vector ptr_B_host(problem_count); -+ std::vector ptr_C_host(problem_count); -+ std::vector ptr_D_host(problem_count); -+ -+ for (int32_t i = 0; i < problem_count; ++i) { -+ ptr_A_host.at(i) = block_A.get() + offset_A.at(i); -+ ptr_B_host.at(i) = block_B.get() + offset_B.at(i); -+ ptr_C_host.at(i) = block_C.get() + offset_C.at(i); -+ ptr_D_host.at(i) = block_D.get() + offset_D.at(i); -+ } -+ -+ ptr_A.reset(problem_count); -+ ptr_A.copy_from_host(ptr_A_host.data()); -+ -+ ptr_B.reset(problem_count); -+ ptr_B.copy_from_host(ptr_B_host.data()); -+ -+ ptr_C.reset(problem_count); -+ ptr_C.copy_from_host(ptr_C_host.data()); -+ -+ ptr_D.reset(problem_count); -+ ptr_D.copy_from_host(ptr_D_host.data()); -+ -+ // -+ // Initialize the problems of the workspace -+ // -+ -+ for (int32_t i = 0; i < problem_count; ++i) { -+ cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); -+ -+ LayoutA layout_A(lda_host.at(i)); -+ LayoutB layout_B(ldb_host.at(i)); -+ LayoutC layout_C(ldc_host.at(i)); -+ LayoutC layout_D(ldd_host.at(i)); -+ -+ MatrixCoord extent_A{problem.n(), problem.k()}; -+ MatrixCoord extent_B{problem.n(), problem.k()}; -+ MatrixCoord extent_C{problem.n(), problem.n()}; -+ -+ std::vector matrix_A(layout_A.capacity(extent_A)); -+ std::vector matrix_B(layout_B.capacity(extent_B)); -+ std::vector matrix_C(layout_C.capacity(extent_C)); -+ std::vector matrix_D(layout_D.capacity(extent_C)); -+ -+ initialize_tensor(cutlass::TensorView(matrix_A.data(), layout_A, extent_A), init_A, seed * 2021); -+ initialize_tensor(cutlass::TensorView(matrix_B.data(), layout_B, extent_B), init_B, seed * 2022); -+ initialize_tensor(cutlass::TensorView(matrix_C.data(), layout_C, extent_C), init_C, seed * 2023); -+ -+ cutlass::device_memory::copy_to_device(ptr_A_host.at(i), matrix_A.data(), matrix_A.size()); -+ cutlass::device_memory::copy_to_device(ptr_B_host.at(i), matrix_B.data(), matrix_B.size()); -+ cutlass::device_memory::copy_to_device(ptr_C_host.at(i), matrix_C.data(), matrix_C.size()); -+ cutlass::device_memory::copy_to_device(ptr_D_host.at(i), matrix_D.data(), matrix_D.size()); -+ } -+ } -+ -+ /// Verifies the result is a Rank2K -+ bool verify( -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ bool passed = true; -+ -+ for (int32_t i = 0; i < problem_count; ++i) { -+ cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); -+ -+ LayoutA layout_A(lda_host.at(i)); -+ LayoutB layout_B(ldb_host.at(i)); -+ LayoutC layout_C(ldc_host.at(i)); -+ LayoutC layout_D(ldd_host.at(i)); -+ -+ MatrixCoord extent_A{problem.n(), problem.k()}; -+ MatrixCoord extent_B{problem.n(), problem.k()}; -+ MatrixCoord extent_C{problem.n(), problem.n()}; -+ -+ std::vector matrix_A(layout_A.capacity(extent_A)); -+ std::vector matrix_B(layout_B.capacity(extent_B)); -+ std::vector matrix_C(layout_C.capacity(extent_C)); -+ std::vector matrix_D(layout_D.capacity(extent_C)); -+ std::vector 
matrix_Ref(layout_D.capacity(extent_C)); -+ -+ cutlass::device_memory::copy_to_host(matrix_A.data(), block_A.get() + offset_A.at(i), matrix_A.size()); -+ cutlass::device_memory::copy_to_host(matrix_B.data(), block_B.get() + offset_B.at(i), matrix_B.size()); -+ cutlass::device_memory::copy_to_host(matrix_C.data(), block_C.get() + offset_C.at(i), matrix_C.size()); -+ cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); -+ -+ cutlass::TensorView view_A(matrix_A.data(), layout_A, extent_A); -+ cutlass::TensorView view_B(matrix_B.data(), layout_B, extent_B); -+ cutlass::TensorView view_C(matrix_C.data(), layout_C, extent_C); -+ cutlass::TensorView view_D(matrix_D.data(), layout_D, extent_C); -+ cutlass::TensorView view_Ref(matrix_Ref.data(), layout_D, extent_C); -+ -+ // Reference Rank2K -+ cutlass::reference::host::Rank2KComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, ElementAccumulator -+ >( -+ problem, -+ alpha, -+ view_A, -+ Rank2K::kTransformA, -+ view_B, -+ Rank2K::kTransformB, -+ beta, -+ view_C, -+ view_Ref, -+ ElementAccumulator(0), -+ Rank2K::kFillModeC, -+ Rank2K::kBlasMode -+ ); -+ -+ // Ensure that no input or output is entirely zero -+ EXPECT_GT(cutlass::reference::host::TensorNorm(view_A), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(view_B), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(view_C), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(view_D), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(view_Ref), 0); -+ -+ // Compare against reference -+ passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); -+ -+ if (!passed) { -+ std::ofstream file("testbed_grouped_errors.txt"); -+ -+ file -+ << "problem: " << problem << " [group: " << i << "]\n" -+ << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; -+ -+ file -+ << "A =\n" << view_A -+ << "\nB =\n" << view_B -+ << "\nC =\n" << view_C -+ << "\n\nReference =\n" << view_Ref -+ << "\nComputed =\n" << view_D; -+ -+ return passed; -+ } -+ } -+ -+ return passed; -+ } -+ -+ /// Executes one test -+ bool run( -+ int problem_count, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ this->problem_count = problem_count; -+ -+ // Initialize the problem -+ initialize(); -+ -+ int threadblock_count = Rank2K::sufficient(problem_sizes_host.data(), problem_count); -+ -+ // Early exit -+ if (!threadblock_count) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device resources." 
<< std::endl; -+ } -+ return true; -+ } -+ -+ // Configure the Rank2K arguments -+ typename EpilogueOutputOp::Params epilogue_op(alpha, beta); -+ -+ // Configure Rank2K arguments -+ typename Rank2K::Arguments args( -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem_sizes_device.get(), -+ problem_count, -+ threadblock_count, -+ epilogue_op, -+ ptr_A.get(), -+ ptr_B.get(), -+ ptr_C.get(), -+ ptr_D.get(), -+ lda.get(), -+ ldb.get(), -+ ldc.get(), -+ ldd.get(), -+ problem_sizes_host.data() -+ ); -+ -+ // Initialize the Rank2K object -+ Rank2K rank2k; -+ -+ size_t workspace_size = rank2k.get_workspace_size(args); -+ cutlass::DeviceAllocation workspace(workspace_size); -+ -+ cutlass::Status status = rank2k.initialize(args, workspace.get()); -+ -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ // Run the Rank2K object -+ status = rank2k.run(); -+ -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ // Wait for completion -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ EXPECT_EQ(result, cudaSuccess) -+ << "Kernel execution error: " << cudaGetErrorString(result); -+ -+ if (result != cudaSuccess) { -+ return false; -+ } -+ -+ // Verify correctness -+ return verify(alpha, beta); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // device -+} // gemm -+} // test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h -new file mode 100644 -index 0000000..af588d3 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h -@@ -0,0 +1,461 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for grouped Rank2K problem visitors -+*/ -+ -+#pragma once -+ -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h" -+#include "cutlass/util/device_memory.h" -+#include "cutlass/device_kernel.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Use simple problem visitor as a baseline -+template -+struct BaselineProblemVisitor : public cutlass::gemm::kernel::BaseGroupedProblemVisitor { -+ using Base = cutlass::gemm::kernel::BaseGroupedProblemVisitor; -+ using Params = typename Base::Params; -+ static int const kThreadCount = ThreadCount; -+ static cutlass::FillMode const kFillModeC = FillModeC; -+ -+ struct SharedStorage {}; -+ -+ int32_t tile_count_sum; -+ SharedStorage &shared_storage; -+ -+ // -+ // Methods -+ // -+ CUTLASS_DEVICE -+ BaselineProblemVisitor( -+ Params const ¶ms_, -+ SharedStorage &shared_storage_, -+ int32_t block_idx -+ ): Base(params_, block_idx), -+ shared_storage(shared_storage_) -+ { -+ cutlass::gemm::GemmCoord problem = this->problem_size(); -+ cutlass::gemm::GemmCoord grid = this->grid_shape(problem); -+ tile_count_sum = this->tile_count(grid); -+ } -+ -+ CUTLASS_DEVICE -+ bool next_tile() { -+ if (this->tile_idx < tile_count_sum) { -+ return true; -+ } -+ -+ do { -+ ++this->problem_idx; -+ -+ if (this->problem_idx >= this->params.problem_count) { -+ return false; -+ } -+ -+ cutlass::gemm::GemmCoord problem = this->problem_size(); -+ cutlass::gemm::GemmCoord grid = this->grid_shape(problem); -+ -+ this->problem_tile_start = tile_count_sum; -+ tile_count_sum += this->tile_count(grid); -+ -+ } while (tile_count_sum <= this->tile_idx); -+ -+ return true; -+ } -+ -+ static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, -+ int32_t problem_count, -+ int32_t block_count) { -+ return 0; -+ } -+ -+ static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, -+ int32_t problem_count, -+ int32_t block_count, -+ void* host_workspace_ptr) {} -+ -+ CUTLASS_DEVICE -+ cutlass::gemm::GemmCoord threadblock_offset(int32_t threadblock_id) const { -+ int32_t macro_id = threadblock_id / ProblemSizeHelper::OffsetHelper::kThreadblockSkewRatio; -+ int32_t macro_row = ceil(cutlass::fast_sqrt((2*macro_id) + 2.25) - 0.5) - 1; -+ int32_t macro_col = macro_id - (((macro_row+1) * macro_row)/2); -+ -+ if (FillModeC == cutlass::FillMode::kUpper) { -+ cutlass::swap(macro_row, macro_col); -+ } -+ -+ int32_t row = ProblemSizeHelper::OffsetHelper::macro_row_to_row(macro_row, threadblock_id); -+ int32_t col = 
ProblemSizeHelper::OffsetHelper::macro_col_to_col(macro_col, threadblock_id); -+ -+ return cutlass::gemm::GemmCoord(row, col, 0); -+ } -+}; -+ -+template -+struct ProblemVisitorKernel { -+ struct SharedStorage { -+ typename ProblemVisitor::SharedStorage problem_visitor; -+ }; -+ -+ struct Params { -+ typename ProblemVisitor::Params problem_visitor_params; -+ int32_t* visited_problems_ptr; -+ int32_t* visited_tiles_ptr; -+ int32_t visits_per_block; -+ -+ Params(): -+ visited_problems_ptr(nullptr), -+ visited_tiles_ptr(nullptr), -+ visits_per_block(0) {} -+ -+ Params(typename ProblemVisitor::Params problem_visitor_params_, -+ int32_t* visited_problems_ptr_, -+ int32_t* visited_tiles_ptr_, -+ int32_t visits_per_block_): -+ problem_visitor_params(problem_visitor_params_), -+ visited_problems_ptr(visited_problems_ptr_), -+ visited_tiles_ptr(visited_tiles_ptr_), -+ visits_per_block(visits_per_block_) {} -+ }; -+ -+ CUTLASS_DEVICE -+ void operator()(const Params& params, SharedStorage &shared_storage) { -+ int32_t store_offset = params.visits_per_block * blockIdx.x; -+ ProblemVisitor problem_visitor(params.problem_visitor_params, -+ shared_storage.problem_visitor, -+ blockIdx.x); -+ -+ while (problem_visitor.next_tile()) { -+ cutlass::gemm::GemmCoord problem_size = problem_visitor.problem_size(); -+ int32_t problem_idx = problem_visitor.problem_index(); -+ int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); -+ -+ cutlass::gemm::GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); -+ cutlass::gemm::GemmCoord tile_offset = problem_visitor.threadblock_offset(threadblock_idx); -+ -+ problem_visitor.advance(gridDim.x); -+ -+ // -+ // Early exit conditions -+ // 1) Out of range -+ // 2) Upper-triangular block in lower-triangular problem -+ // 3) Lower-triangular block in upper-triangular problem -+ // -+ -+ if (grid_shape.m() <= tile_offset.m() || -+ grid_shape.n() <= tile_offset.n()) { -+ continue; -+ } -+ -+ if (ProblemVisitor::kFillModeC == cutlass::FillMode::kLower && -+ (tile_offset.m() + 1) * ProblemVisitor::ThreadblockShape::kM <= tile_offset.n() * ProblemVisitor::ThreadblockShape::kN) { -+ continue; -+ } -+ -+ if (ProblemVisitor::kFillModeC == cutlass::FillMode::kUpper && -+ tile_offset.m() * ProblemVisitor::ThreadblockShape::kM >= (tile_offset.n() + 1) * ProblemVisitor::ThreadblockShape::kN) { -+ continue; -+ } -+ -+ if (threadIdx.x == 0) { -+ params.visited_problems_ptr[store_offset] = problem_idx; -+ params.visited_tiles_ptr[store_offset] = threadblock_idx; -+ ++store_offset; -+ } -+ } -+ } -+}; -+ -+template -+struct ProblemVisitorRunner { -+ using BaseKernel = ProblemVisitorKernel; -+ using Params = typename BaseKernel::Params; -+ -+ Params params; -+ std::vector host_problem_sizes; -+ int32_t problem_count; -+ int32_t threadblock_count; -+ int32_t visits_per_block; -+ cutlass::DeviceAllocation visited_problems; -+ cutlass::DeviceAllocation visited_tiles; -+ cutlass::DeviceAllocation device_problem_sizes; -+ cutlass::DeviceAllocation workspace; -+ std::vector host_visited_problems; -+ std::vector host_visited_tiles; -+ -+ ProblemVisitorRunner(const std::vector& host_problem_sizes_, -+ int32_t threadblock_count_): -+ host_problem_sizes(host_problem_sizes_), -+ problem_count(int32_t(host_problem_sizes_.size())), -+ threadblock_count(threadblock_count_) {} -+ -+ /// Initializes GEMM state from arguments. 
-+ cutlass::Status initialize() { -+ size_t workspace_bytes = ProblemVisitor::get_workspace_size( -+ host_problem_sizes.data(), -+ problem_count, -+ threadblock_count); -+ -+ workspace.reset(workspace_bytes); -+ std::vector host_workspace(workspace_bytes); -+ -+ int32_t tile_count = ProblemVisitor::group_tile_count(host_problem_sizes.data(), problem_count); -+ -+ ProblemVisitor::host_precompute(host_problem_sizes.data(), problem_count, -+ threadblock_count, host_workspace.data()); -+ -+ workspace.copy_from_host(host_workspace.data(), workspace_bytes); -+ -+ device_problem_sizes.reset(problem_count); -+ device_problem_sizes.copy_from_host(host_problem_sizes.data(), problem_count); -+ -+ visits_per_block = (tile_count - 1 + threadblock_count) / threadblock_count; -+ int32_t total_visits = visits_per_block * threadblock_count; -+ -+ visited_problems.reset(total_visits); -+ visited_tiles.reset(total_visits); -+ host_visited_problems.resize(total_visits); -+ host_visited_tiles.resize(total_visits); -+ -+ cudaError_t result = cudaMemset(visited_problems.get(), -1, sizeof(int32_t) * total_visits); -+ if (result != cudaSuccess) { -+ return cutlass::Status::kErrorInternal; -+ } -+ -+ result = cudaMemset(visited_tiles.get(), -1, sizeof(int32_t) * total_visits); -+ if (result != cudaSuccess) { -+ return cutlass::Status::kErrorInternal; -+ } -+ -+ typename ProblemVisitor::Params pv_params(device_problem_sizes.get(), problem_count, workspace.get(), tile_count); -+ params = Params(pv_params, visited_problems.get(), visited_tiles.get(), visits_per_block); -+ -+ return cutlass::Status::kSuccess; -+ } -+ -+ bool verify() { -+ // Sort by problem size and then by threadblock_idx -+ std::vector indices(host_visited_problems.size()); -+ std::iota(indices.begin(), indices.end(), 0); -+ -+ std::stable_sort(indices.begin(), indices.end(), -+ [&](int32_t i1, int32_t i2) { -+ if (host_visited_problems[i1] == host_visited_problems[i2]) { -+ return host_visited_tiles[i1] < host_visited_tiles[i2]; -+ } -+ return host_visited_problems[i1] < host_visited_problems[i2]; -+ }); -+ -+ int32_t idx = 0; -+ -+ // Skip any entries that were not visited -+ while (host_visited_problems[indices[idx]] == -1) { -+ ++idx; -+ } -+ -+ // Check that each problem visited has the tiles we expect -+ for (int32_t problem_idx = 0; problem_idx < problem_count; ++problem_idx) { -+ auto problem = host_problem_sizes[problem_idx]; -+ ProblemVisitor::possibly_transpose_problem(problem); -+ int32_t problem_tiles = ProblemVisitor::tile_count(ProblemVisitor::grid_shape(problem)); -+ for (int i = 0; i < problem_tiles; ++i) { -+ EXPECT_EQ(problem_idx, host_visited_problems[indices[idx]]); -+ EXPECT_EQ(i, host_visited_tiles[indices[idx]]); -+ ++idx; -+ } -+ } -+ -+ return true; -+ } -+ -+ bool run(bool skip_tile_check=false, cudaStream_t stream = nullptr) { -+ cutlass::Status status = initialize(); -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "Initialization failed" << std::endl; -+ return false; -+ } -+ -+ dim3 grid(threadblock_count, 1, 1); -+ dim3 block(ProblemVisitor::kThreadCount, 1, 1); -+ int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); -+ -+ cutlass::Kernel<<>>(params); -+ -+ cudaError_t result = cudaGetLastError(); -+ if (result != cudaSuccess) { -+ std::cerr << "grid launch failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaDeviceSynchronize failed with error " << 
cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ visited_problems.copy_to_host(host_visited_problems.data()); -+ visited_tiles.copy_to_host(host_visited_tiles.data()); -+ -+ if (skip_tile_check) { -+ return true; -+ } -+ -+ return verify(); -+ } -+}; -+ -+template -+struct TestbedGroupedRank2KScheduler { -+ -+ using BaselinePV = BaselineProblemVisitor, -+ ThreadblockShape, -+ PrefetchTileCount, -+ ThreadCount, -+ FillModeC>; -+ -+ // -+ // Data members -+ // -+ -+ // Whether to skip checking that the tiles are visited as expected. This is useful -+ // in cases where ThreadblockShape::kM != ThreadblockShape::kN, for which the grouped -+ // Rank2K scheduler may assign out-of-bounds tiles that will cause a threadblock to -+ // exit early, but which are difficult to detect in tests without reimplementing -+ // this functionality. -+ bool skip_tile_check; -+ uint32_t seed; -+ int problem_count; -+ int threadblock_count; -+ std::vector problem_sizes_host; -+ -+ // -+ // Methods -+ // -+ -+ TestbedGroupedRank2KScheduler(bool skip_tile_check_=false, uint32_t seed_ = 3080): -+ skip_tile_check(skip_tile_check_), seed(seed_) { srand(seed); } -+ -+ /// Initializes data structures -+ void initialize(int32_t scale_factor) { -+ -+ // -+ // Choose random problem sizes -+ // -+ -+ problem_sizes_host.clear(); -+ problem_sizes_host.resize(problem_count); -+ -+ for (int32_t i = 0; i < problem_count; ++i) { -+ int n = scale_factor * (rand() % 64) + 24; -+ -+ cutlass::gemm::GemmCoord problem( -+ n, -+ n, -+ scale_factor * (rand() % 64) + 24); -+ -+ problem_sizes_host.at(i) = problem; -+ } -+ } -+ -+ template -+ void compare_visitors(const ProblemVisitorRunner& baseline_runner) { -+ using PV = cutlass::gemm::kernel::Rank2KGroupedProblemVisitor< -+ ThreadblockShape, -+ GroupScheduleMode_, -+ PrefetchTileCount, -+ ThreadCount, -+ FillModeC>; -+ ProblemVisitorRunner runner(problem_sizes_host, threadblock_count); -+ EXPECT_TRUE(runner.run(skip_tile_check)); -+ -+ // Check that this problem visitor visits the same problems and tiles as the baseline -+ EXPECT_EQ(baseline_runner.host_visited_problems, runner.host_visited_problems); -+ EXPECT_EQ(baseline_runner.host_visited_tiles, runner.host_visited_tiles); -+ } -+ -+ template -+ void compare_visitors(const ProblemVisitorRunner& baseline_runner) { -+ // Compare the next visitor with the baseline visitor -+ compare_visitors(baseline_runner); -+ -+ // Recurse to compare the next visitors -+ compare_visitors(baseline_runner); -+ } -+ -+ /// Executes the test on all scheduler modes -+ void run(int problem_count, int threadblock_count, int scale_factor=8) { -+ -+ this->problem_count = problem_count; -+ this->threadblock_count = threadblock_count; -+ -+ // Initialize the problem -+ initialize(scale_factor); -+ -+ // Run the baseline visitor to which we will compare all other visitors -+ ProblemVisitorRunner baseline_runner(problem_sizes_host, threadblock_count); -+ EXPECT_TRUE(baseline_runner.run(skip_tile_check)); -+ -+ compare_visitors(baseline_runner); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // device -+} // gemm -+} // test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped_scheduler.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped_scheduler.h -new file mode 100644 -index 0000000..00d83b6 ---- /dev/null -+++ 
b/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped_scheduler.h -@@ -0,0 +1,407 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for grouped GEMM problem visitors -+*/ -+ -+#pragma once -+ -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" -+#include "cutlass/gemm/kernel/grouped_problem_visitor.h" -+#include "cutlass/util/device_memory.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Use simple problem visitor as a baseline -+template -+struct BaselineProblemVisitor : public cutlass::gemm::kernel::BaseGroupedProblemVisitor { -+ using Base = cutlass::gemm::kernel::BaseGroupedProblemVisitor; -+ using Params = typename Base::Params; -+ static int const kThreadCount = ThreadCount; -+ -+ struct SharedStorage {}; -+ -+ int32_t tile_count_sum; -+ SharedStorage &shared_storage; -+ -+ // -+ // Methods -+ // -+ CUTLASS_DEVICE -+ BaselineProblemVisitor( -+ Params const ¶ms_, -+ SharedStorage &shared_storage_, -+ int32_t block_idx -+ ): Base(params_, block_idx), -+ shared_storage(shared_storage_) -+ { -+ cutlass::gemm::GemmCoord problem = this->problem_size(); -+ cutlass::gemm::GemmCoord grid = this->grid_shape(problem); -+ tile_count_sum = this->tile_count(grid); -+ } -+ -+ CUTLASS_DEVICE -+ bool next_tile() { -+ if (this->tile_idx < tile_count_sum) { -+ return true; -+ } -+ -+ do { -+ ++this->problem_idx; -+ -+ if (this->problem_idx >= this->params.problem_count) { -+ return false; -+ } -+ -+ cutlass::gemm::GemmCoord problem = this->problem_size(); -+ cutlass::gemm::GemmCoord grid = this->grid_shape(problem); -+ -+ this->problem_tile_start = tile_count_sum; -+ tile_count_sum += this->tile_count(grid); -+ -+ } while (tile_count_sum <= this->tile_idx); -+ -+ return true; -+ } -+ -+ static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, -+ int32_t problem_count, -+ int32_t block_count) { -+ return 0; -+ } -+ -+ static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, -+ int32_t problem_count, -+ int32_t block_count, -+ void* host_workspace_ptr) {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct ProblemVisitorKernel { -+ struct SharedStorage { -+ typename ProblemVisitor::SharedStorage problem_visitor; -+ }; -+ -+ struct Params { -+ typename ProblemVisitor::Params problem_visitor_params; -+ int32_t* visited_problems_ptr; -+ int32_t* visited_tiles_ptr; -+ int32_t visits_per_block; -+ -+ Params(): -+ visited_problems_ptr(nullptr), -+ visited_tiles_ptr(nullptr), -+ visits_per_block(0) {} -+ -+ Params(typename ProblemVisitor::Params problem_visitor_params_, -+ int32_t* visited_problems_ptr_, -+ int32_t* visited_tiles_ptr_, -+ int32_t visits_per_block_): -+ problem_visitor_params(problem_visitor_params_), -+ visited_problems_ptr(visited_problems_ptr_), -+ visited_tiles_ptr(visited_tiles_ptr_), -+ visits_per_block(visits_per_block_) {} -+ }; -+ -+ CUTLASS_DEVICE -+ void operator()(const Params& params, SharedStorage &shared_storage) { -+ int32_t store_offset = params.visits_per_block * blockIdx.x; -+ ProblemVisitor problem_visitor(params.problem_visitor_params, -+ shared_storage.problem_visitor, -+ blockIdx.x); -+ -+ while (problem_visitor.next_tile()) { -+ int32_t problem_idx = 
problem_visitor.problem_index(); -+ int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); -+ -+ if (threadIdx.x == 0) { -+ params.visited_problems_ptr[store_offset] = problem_idx; -+ params.visited_tiles_ptr[store_offset] = threadblock_idx; -+ ++store_offset; -+ } -+ problem_visitor.advance(gridDim.x); -+ } -+ } -+}; -+ -+template -+struct ProblemVisitorRunner { -+ using BaseKernel = ProblemVisitorKernel; -+ using Params = typename BaseKernel::Params; -+ -+ Params params; -+ std::vector host_problem_sizes; -+ int32_t problem_count; -+ int32_t threadblock_count; -+ int32_t visits_per_block; -+ cutlass::DeviceAllocation visited_problems; -+ cutlass::DeviceAllocation visited_tiles; -+ cutlass::DeviceAllocation device_problem_sizes; -+ cutlass::DeviceAllocation workspace; -+ std::vector host_visited_problems; -+ std::vector host_visited_tiles; -+ -+ ProblemVisitorRunner(const std::vector& host_problem_sizes_, -+ int32_t threadblock_count_): -+ host_problem_sizes(host_problem_sizes_), -+ problem_count(int32_t(host_problem_sizes_.size())), -+ threadblock_count(threadblock_count_) {} -+ -+ /// Initializes GEMM state from arguments. -+ cutlass::Status initialize() { -+ size_t workspace_bytes = ProblemVisitor::get_workspace_size( -+ host_problem_sizes.data(), -+ problem_count, -+ threadblock_count); -+ -+ workspace.reset(workspace_bytes); -+ std::vector host_workspace(workspace_bytes); -+ -+ int32_t tile_count = ProblemVisitor::group_tile_count(host_problem_sizes.data(), problem_count); -+ -+ ProblemVisitor::host_precompute(host_problem_sizes.data(), problem_count, -+ threadblock_count, host_workspace.data()); -+ -+ workspace.copy_from_host(host_workspace.data(), workspace_bytes); -+ -+ device_problem_sizes.reset(problem_count); -+ device_problem_sizes.copy_from_host(host_problem_sizes.data(), problem_count); -+ -+ visits_per_block = (tile_count - 1 + threadblock_count) / threadblock_count; -+ int32_t total_visits = visits_per_block * threadblock_count; -+ -+ visited_problems.reset(total_visits); -+ visited_tiles.reset(total_visits); -+ host_visited_problems.resize(total_visits); -+ host_visited_tiles.resize(total_visits); -+ -+ cudaError_t result = cudaMemset(visited_problems.get(), -1, sizeof(int32_t) * total_visits); -+ if (result != cudaSuccess) { -+ return cutlass::Status::kErrorInternal; -+ } -+ -+ result = cudaMemset(visited_tiles.get(), -1, sizeof(int32_t) * total_visits); -+ if (result != cudaSuccess) { -+ return cutlass::Status::kErrorInternal; -+ } -+ -+ typename ProblemVisitor::Params pv_params(device_problem_sizes.get(), problem_count, workspace.get(), tile_count); -+ params = Params(pv_params, visited_problems.get(), visited_tiles.get(), visits_per_block); -+ -+ return cutlass::Status::kSuccess; -+ } -+ -+ bool verify() { -+ // Sort by problem size and then by threadblock_idx -+ std::vector indices(host_visited_problems.size()); -+ std::iota(indices.begin(), indices.end(), 0); -+ -+ std::stable_sort(indices.begin(), indices.end(), -+ [&](int32_t i1, int32_t i2) { -+ if (host_visited_problems[i1] == host_visited_problems[i2]) { -+ return host_visited_tiles[i1] < host_visited_tiles[i2]; -+ } -+ return host_visited_problems[i1] < host_visited_problems[i2]; -+ }); -+ -+ int32_t idx = 0; -+ -+ // Skip any entries that were not visited -+ while (host_visited_problems[indices[idx]] == -1) { -+ ++idx; -+ } -+ -+ // Check that each problem visited has the tiles we expect -+ for (int32_t problem_idx = 0; problem_idx < problem_count; ++problem_idx) { -+ auto problem = 
host_problem_sizes[problem_idx]; -+ ProblemVisitor::possibly_transpose_problem(problem); -+ int32_t problem_tiles = ProblemVisitor::tile_count(ProblemVisitor::grid_shape(problem)); -+ for (int i = 0; i < problem_tiles; ++i) { -+ EXPECT_EQ(problem_idx, host_visited_problems[indices[idx]]); -+ EXPECT_EQ(i, host_visited_tiles[indices[idx]]); -+ ++idx; -+ } -+ } -+ -+ return true; -+ } -+ -+ bool run(cudaStream_t stream = nullptr) { -+ cutlass::Status status = initialize(); -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "Initialization failed" << std::endl; -+ return false; -+ } -+ -+ dim3 grid(threadblock_count, 1, 1); -+ dim3 block(ProblemVisitor::kThreadCount, 1, 1); -+ int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); -+ -+ cutlass::Kernel<<>>(params); -+ -+ cudaError_t result = cudaGetLastError(); -+ if (result != cudaSuccess) { -+ std::cerr << "grid launch failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaDeviceSynchronize failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ visited_problems.copy_to_host(host_visited_problems.data()); -+ visited_tiles.copy_to_host(host_visited_tiles.data()); -+ -+ return verify(); -+ } -+}; -+ -+template -+struct TestbedGroupedGemmScheduler { -+ -+ using PSHelper = cutlass::gemm::kernel::detail::GemmGroupedProblemSizeHelper; -+ using BaselinePV = BaselineProblemVisitor; -+ -+ // -+ // Data members -+ // -+ uint32_t seed; -+ int problem_count; -+ int threadblock_count; -+ std::vector problem_sizes_host; -+ -+ // -+ // Methods -+ // -+ -+ TestbedGroupedGemmScheduler(uint32_t seed_ = 3080): -+ seed(seed_) { srand(seed); } -+ -+ /// Initializes data structures -+ void initialize(int32_t scale_factor) { -+ -+ // -+ // Choose random problem sizes -+ // -+ -+ problem_sizes_host.clear(); -+ problem_sizes_host.resize(problem_count); -+ -+ for (int32_t i = 0; i < problem_count; ++i) { -+ -+ cutlass::gemm::GemmCoord problem( -+ scale_factor * (rand() % 64) + 24, -+ scale_factor * (rand() % 64) + 24, -+ scale_factor * (rand() % 64) + 24); -+ -+ problem_sizes_host.at(i) = problem; -+ } -+ } -+ -+ template -+ void compare_visitors(const ProblemVisitorRunner& baseline_runner) { -+ using PV = cutlass::gemm::kernel::GemmGroupedProblemVisitor< -+ ThreadblockShape, -+ GroupScheduleMode_, -+ PrefetchTileCount, -+ ThreadCount, -+ Transpose>; -+ ProblemVisitorRunner runner(problem_sizes_host, threadblock_count); -+ EXPECT_TRUE(runner.run()); -+ -+ // Check that this problem visitor visits the same problems and tiles as the baseline -+ EXPECT_EQ(baseline_runner.host_visited_problems, runner.host_visited_problems); -+ EXPECT_EQ(baseline_runner.host_visited_tiles, runner.host_visited_tiles); -+ } -+ -+ template -+ void compare_visitors(const ProblemVisitorRunner& baseline_runner) { -+ // Compare the next visitor with the baseline visitor -+ compare_visitors(baseline_runner); -+ -+ // Recurse to compare the next visitors -+ compare_visitors(baseline_runner); -+ } -+ -+ /// Executes the test on all scheduler modes -+ void run(int problem_count, int threadblock_count, int scale_factor=8) { -+ -+ this->problem_count = problem_count; -+ this->threadblock_count = threadblock_count; -+ -+ // Initialize the problem -+ initialize(scale_factor); -+ -+ // Run the baseline visitor to which we will compare all other visitors -+ ProblemVisitorRunner baseline_runner(problem_sizes_host, 
threadblock_count); -+ EXPECT_TRUE(baseline_runner.run()); -+ -+ compare_visitors(baseline_runner); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // device -+} // gemm -+} // test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_interleaved.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_interleaved.h -new file mode 100644 -index 0000000..b54a4b6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_interleaved.h -@@ -0,0 +1,347 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/host_reorder.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct InterleavedTestbed { -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementC = typename Gemm::ElementC; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ // -+ // Methods -+ // -+ -+ InterleavedTestbed( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, 2, -2, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Waives test if CUDA device is insufficient -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; -+ } -+ return true; -+ } -+ -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementA, -+ typename Gemm::LayoutA> tensor_A(problem_size.mk()); -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementB, -+ typename Gemm::LayoutB> tensor_B(problem_size.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementB, -+ typename Gemm::LayoutB> tensor_B_reordered(problem_size.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementC, -+ typename Gemm::LayoutC> tensor_C(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementC, -+ typename Gemm::LayoutC> tensor_D(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementC, -+ typename Gemm::LayoutC> reference_D(problem_size.mn(), false); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); -+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); -+ -+ cutlass::reorder_column( -+ tensor_B_reordered.host_ref(), tensor_B.host_ref(), problem_size); -+ -+ cutlass::reference::host::TensorCopy( -+ reference_D.host_view(), -+ tensor_C.host_view()); -+ -+ tensor_A.sync_device(); -+ tensor_B_reordered.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm::Arguments arguments{ -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B_reordered.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D.device_ref(), -+ {alpha, beta} -+ }; -+ -+ Gemm gemm_op; -+ -+ cutlass::Status status = gemm_op.initialize(arguments); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ -+ // -+ // Run the GEMM -+ // -+ -+ status = gemm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::host::Gemm< -+ typename Gemm::ElementA, typename Gemm::LayoutA, -+ typename Gemm::ElementB, typename Gemm::LayoutB, -+ typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, -+ ElementAccumulator, typename Gemm::Operator> -+ reference_gemm; -+ -+ reference_gemm( -+ problem_size, -+ alpha, -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ beta, -+ reference_D.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ reference_D.host_view(), -+ tensor_D.host_view()); -+ -+ EXPECT_TRUE(passed); -+ if (!passed) { -+ -+ std::stringstream fname; -+ -+ fname << "error_Gemm_device_" -+ << problem_size.m() << "x" -+ << problem_size.n() << "x" -+ << problem_size.k() << "_" -+ << Gemm::ThreadblockShape::kM << "x" -+ << Gemm::ThreadblockShape::kN << "x" -+ << Gemm::ThreadblockShape::kK << "_" -+ << Gemm::WarpShape::kM << "x" -+ << Gemm::WarpShape::kN << "x" -+ << Gemm::WarpShape::kK << ".txt"; -+ -+ std::ofstream file(fname.str()); -+ -+ file -+ << "problem: " << problem_size -+ << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; -+ -+ file -+ << "A =\n" << tensor_A.host_view() -+ << "\nB =\n" << tensor_B.host_view() -+ << "\nB_reordered =\n" << tensor_B_reordered.host_view() -+ << "\nC =\n" << tensor_C.host_view() -+ << "\n\nReference =\n" << reference_D.host_view() -+ << "\nComputed =\n" << tensor_D.host_view(); -+ } -+ -+ return passed; -+ } -+ -+ /// Runs a set of problem sizes -+ bool run_all() { -+ 
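// --- Editor's note: illustrative sketch, not part of the patch --------------------------------
// run_all() below sweeps m, n and k over {InterleavedK, 256 + InterleavedK, 512 + InterleavedK},
// which (for the power-of-two interleave factors these tests use, all of which divide 256) keeps
// every GEMM extent a multiple of the interleave granularity required by the interleaved layout.
// A minimal standalone sketch of that sweep and its constraint, assuming a hypothetical
// interleave factor of 32:

#include <cassert>
#include <vector>

inline std::vector<int> interleaved_extent_sweep(int interleave) {
  // Same shape of sweep as InterleavedTestbed::run_all().
  std::vector<int> extents = {interleave, 256 + interleave, 512 + interleave};
  for (int extent : extents) {
    // Holds whenever `interleave` divides 256 (true for the usual 16/32/64 factors).
    assert(extent % interleave == 0 && "GEMM extent must be a multiple of the interleave factor");
  }
  return extents;
}

// Example: interleaved_extent_sweep(32) yields {32, 288, 544}, all multiples of 32.
// -----------------------------------------------------------------------------------------------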
bool passed = true; -+ -+ int problem_size_m[] = { -+ InterleavedK, 256 + InterleavedK, 512 + InterleavedK -+ }; -+ -+ int problem_size_n[] = { -+ InterleavedK, 256 + InterleavedK, 512 + InterleavedK -+ }; -+ -+ int problem_size_k[] = { -+ InterleavedK, 256 + InterleavedK, 512 + InterleavedK -+ }; -+ -+ double problem_alpha[] = { -+ 1.0 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ for (double alpha : problem_alpha) { -+ for (double beta : problem_beta) { -+ -+ passed = run( -+ {m, n, k}, -+ ElementCompute(alpha), -+ ElementCompute(beta) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return true; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_planar_complex.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_planar_complex.h -new file mode 100644 -index 0000000..a721cc8 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_planar_complex.h -@@ -0,0 +1,326 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/gemm_planar_complex.h" -+#include "cutlass/util/host_tensor_planar_complex.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+class TestbedPlanarComplex { -+public: -+ -+ using ElementA = typename Gemm::ElementA; -+ using LayoutA = typename Gemm::LayoutA; -+ using ElementB = typename Gemm::ElementB; -+ using LayoutB = typename Gemm::LayoutB; -+ using ElementC = typename Gemm::ElementC; -+ using LayoutC = typename Gemm::LayoutC; -+ using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::HostTensorPlanarComplex tensor_A; -+ cutlass::HostTensorPlanarComplex tensor_B; -+ cutlass::HostTensorPlanarComplex tensor_C; -+ cutlass::HostTensorPlanarComplex tensor_D; -+ cutlass::HostTensorPlanarComplex tensor_D_ref; -+ -+ // -+ // Methods -+ // -+ -+ TestbedPlanarComplex(cutlass::gemm::GemmCoord const & problem_size): problem_size(problem_size) { -+ -+ tensor_A.reset({problem_size.m(), problem_size.k()}); -+ tensor_B.reset({problem_size.k(), problem_size.n()}); -+ tensor_C.reset({problem_size.m(), problem_size.n()}); -+ tensor_D.reset({problem_size.m(), problem_size.n()}); -+ tensor_D_ref.reset({problem_size.m(), problem_size.n()}, false); -+ } -+ -+ void initialize() { -+ -+ uint64_t seed = 1073; -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_A.host_view(), seed, scope_max, scope_min, 0); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_B.host_view(), seed * 2019, scope_max, scope_min, 0); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_C.host_view(), seed * 2020, scope_max, scope_min, 0); -+ -+ cutlass::reference::host::TensorFill(tensor_D.host_view(), cutlass::complex()); -+ cutlass::reference::host::TensorFill(tensor_D_ref.host_view(), cutlass::complex()); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. 
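// --- Editor's note: illustrative sketch, not part of the patch --------------------------------
// Each testbed in this patch implements sufficient() the same way: it treats
// sizeof(Kernel::SharedStorage) as the kernel's shared-memory footprint and waives the test when
// the device cannot opt in to that much shared memory per block. A minimal standalone version of
// the check, taking the requirement as a plain byte count instead of a SharedStorage type:

#include <cstddef>
#include <cuda_runtime.h>

// Returns true if the current CUDA device can provide `smem_bytes` of opt-in shared memory per
// block; returns false (i.e. the caller should waive the test) otherwise or on any API failure.
inline bool device_has_sufficient_smem(std::size_t smem_bytes) {
  int device_idx = 0;
  if (cudaGetDevice(&device_idx) != cudaSuccess) {
    return false;
  }
  cudaDeviceProp properties;
  if (cudaGetDeviceProperties(&properties, device_idx) != cudaSuccess) {
    return false;
  }
  return properties.sharedMemPerBlockOptin >= smem_bytes;
}
// -----------------------------------------------------------------------------------------------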
-+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ bool run( -+ cutlass::complex alpha = {1, 0}, -+ cutlass::complex beta = {0, 0}) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ initialize(); -+ -+ int batch_count = 1; -+ -+ ElementA *ptr_A = tensor_A.device_data(); -+ ElementB *ptr_B = tensor_B.device_data(); -+ ElementC *ptr_C = tensor_C.device_data(); -+ ElementC *ptr_D = tensor_D.device_data(); -+ -+ typename LayoutA::Stride::Index lda = tensor_A.layout().stride(0); -+ typename LayoutB::Stride::Index ldb = tensor_B.layout().stride(0); -+ typename LayoutC::Stride::Index ldc = tensor_C.layout().stride(0); -+ typename LayoutC::Stride::Index ldd = tensor_D.layout().stride(0); -+ -+ int64_t imag_stride_A = tensor_A.imaginary_stride(); -+ int64_t imag_stride_B = tensor_B.imaginary_stride(); -+ int64_t imag_stride_C = tensor_C.imaginary_stride(); -+ int64_t imag_stride_D = tensor_D.imaginary_stride(); -+ -+ // -+ // Launch device kernel -+ // -+ -+ Gemm gemm_op; -+ -+ typename Gemm::Arguments args{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem_size, -+ batch_count, -+ {alpha, beta}, -+ ptr_A, -+ ptr_A + imag_stride_A, -+ ptr_B, -+ ptr_B + imag_stride_B, -+ ptr_C, -+ ptr_C + imag_stride_C, -+ ptr_D, -+ ptr_D + imag_stride_D, -+ lda, -+ lda, -+ ldb, -+ ldb, -+ ldc, -+ ldc, -+ ldd, -+ ldd -+ }; -+ -+ cutlass::Status status = gemm_op(args); -+ -+ EXPECT_EQ(status, cutlass::Status::kSuccess); -+ -+ cudaError_t error = cudaDeviceSynchronize(); -+ -+ tensor_D.sync_host(); -+ -+ // -+ // Compute reference -+ // -+ -+ cutlass::reference::host::GemmPlanarComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator -+ >( -+ problem_size, -+ alpha, -+ tensor_A.host_ref(), -+ Gemm::kTransformA, -+ tensor_B.host_ref(), -+ Gemm::kTransformB, -+ beta, -+ tensor_C.host_ref(), -+ tensor_D_ref.host_ref() -+ ); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_D.host_view(), -+ tensor_D_ref.host_view() -+ ); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::ofstream output("gemm_planar_complex.txt"); -+ -+ output -+ << "A:\n" << tensor_A.host_view() << "\n" -+ << "B:\n" << tensor_B.host_view() << "\n" -+ << "C:\n" << tensor_C.host_view() << "\n" -+ << "Reference:\n" -+ << tensor_D_ref.host_view() << "\n" -+ << "Computed:\n" -+ << tensor_D.host_view() << "\n"; -+ } -+ -+ return passed; -+ } -+}; -+ -+template -+bool TestOneGemmPlanarComplex(cutlass::gemm::GemmCoord problem_size) { -+ -+ TestbedPlanarComplex testbed(problem_size); -+ -+ return testbed.run(); -+} -+ -+template -+bool TestAllGemmPlanarComplex() { -+ -+ int M[] = { -+ 16, 64, 72, 144, 264, 520, -+ }; -+ -+ int N[] = { -+ 16, 64, 72, 144, 248, 264, 520 -+ }; -+ -+ int K[] = { -+ 8, 64, 72, 96, 264, 
520 -+ }; -+ -+ using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; -+ -+ cutlass::complex alpha_values[] = { -+ {ElementCompute(1.25), ElementCompute(-0.5)} -+ }; -+ -+ cutlass::complex beta_values[] = { -+ {ElementCompute(-2.25), ElementCompute(1.5)} -+ }; -+ -+ for (int m : M) { -+ for (int n : N) { -+ for (int k : K) { -+ -+ test::gemm::device::TestbedPlanarComplex testbed({m, n, k}); -+ -+ for (auto const &alpha : alpha_values) { -+ for (auto const &beta : beta_values) { -+ -+ bool passed = testbed.run(alpha, beta); -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return true; -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_rank2k_universal.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_rank2k_universal.h -new file mode 100644 -index 0000000..29f3989 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_rank2k_universal.h -@@ -0,0 +1,641 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide Rank 2k update interface -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/error_metrics.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/rank_2k_complex.h" -+ -+#include "testbed_utils.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct TestbedRank2KUniversal { -+ -+ using ElementA = typename Rank2K::ElementA; -+ using ElementB = typename Rank2K::ElementB; -+ using ElementC = typename Rank2K::ElementC; -+ using ElementAccumulator = typename Rank2K::ElementAccumulator; -+ using ElementCompute = typename Rank2K::Rank2Kkernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D; -+ cutlass::HostTensor reference_D; -+ -+ // -+ // Methods -+ // -+ -+ TestbedRank2KUniversal( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed, -+ int mantissa_in_bits) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ -+ EXPECT_TRUE(false) << "Input distribution not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_symmetric_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed, -+ int mantissa_in_bits) { -+ -+ if (dist_kind == 
cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillSymmetricRandomUniform( -+ view, seed, Rank2K::kFillModeC, scope_max, scope_min, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillSymmetricRandomGaussian( -+ view, seed, Rank2K::kFillModeC, 0, 0.5, mantissa_in_bits); -+ } -+ else { -+ -+ EXPECT_TRUE(false) << "Input distribution (symmetric tensor) not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ /// Initializes data structures -+ void initialize(cutlass::gemm::GemmCoord problem_size) { -+ // -+ // Allocate the Rank2K workspace -+ // -+ -+ tensor_A.resize(problem_size.mk()); -+ tensor_B.resize(problem_size.mk()); -+ tensor_C.resize(problem_size.mn()); -+ tensor_D.resize(problem_size.mn()); -+ reference_D.resize(problem_size.mn(), false); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019, cutlass::MantissaInBits::bits)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018, cutlass::MantissaInBits::bits)); -+ EXPECT_TRUE(initialize_symmetric_tensor(tensor_C.host_view(), init_C, seed + 2017, cutlass::MantissaInBits::bits)); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. -+ tensor_A.host_view().at({0, 0}) = typename Rank2K::ElementA(1); -+ tensor_B.host_view().at({0, 0}) = typename Rank2K::ElementB(1); -+ tensor_C.host_view().at({0, 0}) = typename Rank2K::ElementC(1); -+ -+ cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); -+ -+ if (tensor_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ -+ if (reference_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ -+ double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); -+ -+ bool passed = l2_norm < cutlass::MantissaInBits::error; -+ -+ return passed; -+ } -+ -+ /// Verifies the result is a Rank2K -+ bool verify( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ // -+ // Verify -+ // -+ cutlass::reference::host::Rank2KComplex< -+ typename Rank2K::ElementA, typename Rank2K::LayoutA, -+ typename Rank2K::ElementB, typename Rank2K::LayoutB, -+ typename Rank2K::ElementC, typename Rank2K::LayoutC, -+ ElementCompute, ElementAccumulator -+ >( -+ problem_size, -+ alpha, -+ tensor_A.host_ref(), -+ 
Rank2K::kTransformA, -+ tensor_B.host_ref(), -+ Rank2K::kTransformB, -+ beta, -+ tensor_C.host_ref(), -+ reference_D.host_ref(), -+ ElementAccumulator(0), -+ Rank2K::kFillModeC, -+ Rank2K::kBlasMode -+ ); -+ -+ return compare_reference(problem_size, alpha, beta); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Rank2K::Rank2Kkernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmUniversalMode mode, -+ cutlass::gemm::GemmCoord problem_size, -+ int batch_count = 1, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+#if 0 -+ std::cout << "[TestbedRank2KUniversal::run()] problem(m, n, k): " << problem_size -+ << " alpha: " << ElementCompute(alpha) -+ << " beta: " << ElementCompute(beta) << std::endl; -+#endif -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the Rank2K operator -+ // -+ -+ typename Rank2K::Arguments arguments{ -+ mode, -+ problem_size, -+ batch_count, -+ {alpha, beta}, -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_C.device_data(), -+ tensor_D.device_data(), -+ problem_size.n() * problem_size.k(), -+ problem_size.n() * problem_size.k(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ tensor_A.layout().stride(0), -+ tensor_B.layout().stride(0), -+ tensor_C.layout().stride(0), -+ tensor_D.layout().stride(0) -+ }; -+ -+ Rank2K rank2k_op; -+ -+ size_t workspace_size = Rank2K::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = rank2k_op.initialize(arguments, workspace.get()); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Run the Rank2K -+ // -+ -+ status = rank2k_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ -+ bool passed = this->verify(problem_size, alpha, beta); -+ -+ //if (true) { -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_Rank2k_device_" -+ << "fill_mode_c_" -+ << (Rank2K::kFillModeC == cutlass::FillMode::kLower ? "lower_" : -+ (Rank2K::kFillModeC == cutlass::FillMode::kUpper ? 
"upper_" : "invalid_")) -+ << "mnk_" -+ << problem_size.m() << "x" -+ << problem_size.n() << "x" -+ << problem_size.k() << "_" -+ << Rank2K::ThreadblockShape::kM << "x" -+ << Rank2K::ThreadblockShape::kN << "x" -+ << Rank2K::ThreadblockShape::kK << "_" -+ << Rank2K::WarpShape::kM << "x" -+ << Rank2K::WarpShape::kN << "x" -+ << Rank2K::WarpShape::kK << ".txt"; -+ -+ std::cout << fname.str() << std::endl; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size << std::endl; -+ -+ results -+ << "\nA:\n" << tensor_A.host_view() << "\n" -+ << "\nB:\n" << tensor_B.host_view() << "\n" -+ << "\nC:\n" << tensor_C.host_view() << "\n" -+ << "\nD reference:\n" << reference_D.host_view() << "\n" -+ << "\nD computed:\n" << tensor_D.host_view() << "\n"; -+ -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+bool TestRank2kUniversal( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmUniversalMode mode, -+ int batch_count, -+ double alpha = 1.0, -+ double beta = 2.0) { -+ -+ bool passed = true; -+ -+ TestbedRank2KUniversal testbed; -+ -+ using ElementCompute = typename Rank2K::EpilogueOutputOp::ElementCompute; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ return passed; -+} -+ -+template -+bool TestAllRank2KUniversal() { -+ bool passed = true; -+ -+ -+ int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); -+ -+ int const kAlignment = cutlass::platform::is_same< -+ typename Rank2K::OperatorClass, -+ cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; -+ -+ // int8_t gemm alignment constraints -+ int const kAlignmentM = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value ? 4 : kAlignment; -+ -+ int const kAlignmentN = kAlignmentM; -+ -+ int const kAlignmentK = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value -+ ? 
4 : kAlignment; -+ -+ cutlass::gemm::GemmUniversalMode modes[] = { -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ }; -+ -+ int problem_size_n[] = { -+ kAlignmentN, 512 - 2*kAlignmentN -+ }; -+ -+ int problem_size_k[] = { -+ kAlignmentK, -+ Rank2K::ThreadblockShape::kK * Rank2K::kStages - kAlignmentK, -+ Rank2K::ThreadblockShape::kK * Rank2K::kStages * 3 - kAlignmentK -+ }; -+ -+ int batch_counts[] = { // may be interpretted as batch count or split-K slices -+ 1 // Just running one batch for now (removing 2, 3, 5, 7) -+ }; -+ -+ double problem_alpha[] = { -+ 1.0, 3.25 -+ }; -+ -+ double problem_beta[] = { -+ 0.0, 2.15 -+ }; -+ -+ using ElementCompute = typename Rank2K::EpilogueOutputOp::ElementCompute; -+ -+ for (cutlass::gemm::GemmUniversalMode mode : modes) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ for (int batch_count : batch_counts) { -+ -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ if (mode == cutlass::gemm::GemmUniversalMode::kGemm || -+ mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { -+ -+ // skip very small K problems -+ //if (k / batch_count < 2 * Rank2K::ThreadblockShape::kK) { -+ // continue; -+ //} -+ } -+ -+ cutlass::gemm::GemmCoord problem_size(n, n, k); -+ -+ TestbedRank2KUniversal testbed; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+template -+bool TestAllRank2KHermitianUniversal() { -+ bool passed = true; -+ -+ using ElementCompute = typename Rank2K::EpilogueOutputOp::ElementCompute; -+ using ElementAccumulator = typename Rank2K::ElementAccumulator; -+ -+ int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); -+ -+ int const kAlignment = cutlass::platform::is_same< -+ typename Rank2K::OperatorClass, -+ cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; -+ -+ // int8_t gemm alignment constraints -+ int const kAlignmentM = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value ? 4 : kAlignment; -+ -+ int const kAlignmentN = kAlignmentM; -+ -+ int const kAlignmentK = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value -+ ? 
4 : kAlignment; -+ -+ cutlass::gemm::GemmUniversalMode modes[] = { -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ }; -+ -+ int problem_size_n[] = { -+ kAlignmentN, 512 - 2*kAlignmentN -+ }; -+ -+ int problem_size_k[] = { -+ kAlignmentK, -+ Rank2K::ThreadblockShape::kK * Rank2K::kStages - kAlignmentK, -+ Rank2K::ThreadblockShape::kK * Rank2K::kStages * 3 - kAlignmentK -+ }; -+ -+ int batch_counts[] = { // may be interpretted as batch count or split-K slices -+ 1 // Just running one batch for now (removing 2, 3, 5, 7) -+ }; -+ -+ /* Complex alpha for HER2K */ -+ ElementAccumulator problem_alpha[] = { -+ {1.0}, -+ {1.25, 3.25}, -+ {-0.25, -2.25} -+ }; -+ -+ ElementAccumulator problem_beta[] = { -+ 0.0, -2.25 -+ }; -+ -+ for (cutlass::gemm::GemmUniversalMode mode : modes) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ for (int batch_count : batch_counts) { -+ -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ if (mode == cutlass::gemm::GemmUniversalMode::kGemm || -+ mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { -+ -+ // skip very small K problems -+ //if (k / batch_count < 2 * Rank2K::ThreadblockShape::kK) { -+ // continue; -+ //} -+ } -+ -+ cutlass::gemm::GemmCoord problem_size(n, n, k); -+ -+ TestbedRank2KUniversal testbed; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ alpha, -+ beta -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_rank_k_universal.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_rank_k_universal.h -new file mode 100644 -index 0000000..7c403ad ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_rank_k_universal.h -@@ -0,0 +1,511 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Rank 2k update interface -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/error_metrics.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+ -+#include "testbed_utils.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct TestbedRank2KUniversal { -+ -+ using ElementA = typename RankK::ElementA; -+ using ElementC = typename RankK::ElementC; -+ using ElementAccumulator = typename RankK::ElementAccumulator; -+ using ElementCompute = typename RankK::RankKkernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D; -+ cutlass::HostTensor reference_D; -+ -+ // -+ // Methods -+ // -+ -+ TestbedRank2KUniversal( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed, -+ int mantissa_in_bits) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ 
view.data(), view.capacity()); -+ } -+ else { -+ -+ EXPECT_TRUE(false) << "Input distribution not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_symmetric_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed, -+ int mantissa_in_bits) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillSymmetricRandomUniform( -+ view, seed, RankK::kFillModeC, scope_max, scope_min, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillSymmetricRandomGaussian( -+ view, seed, RankK::kFillModeC, 0, 0.5, mantissa_in_bits); -+ } -+ else { -+ -+ EXPECT_TRUE(false) << "Input distribution (symmetric tensor) not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ /// Initializes data structures -+ void initialize(cutlass::gemm::GemmCoord problem_size) { -+ // -+ // Allocate the RankK workspace -+ // -+ -+ tensor_A.resize(problem_size.mk()); -+ tensor_C.resize(problem_size.mn()); -+ tensor_D.resize(problem_size.mn()); -+ reference_D.resize(problem_size.mn(), false); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019, cutlass::MantissaInBits::bits)); -+ EXPECT_TRUE(initialize_symmetric_tensor(tensor_C.host_view(), init_C, seed + 2017, cutlass::MantissaInBits::bits)); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. 
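// --- Editor's note: illustrative sketch, not part of the patch --------------------------------
// Rationale for the override just below: compare_reference() in this testbed asserts
// TensorNorm(...) > 0 for each operand, so an all-zero random fill (unlikely but possible) would
// fail those checks and make the device-vs-reference comparison vacuous. Pinning element (0, 0)
// to 1 guarantees a non-zero operand without meaningfully changing the test. A generic sketch of
// the same guard over a hypothetical host-side buffer (the testbed simply pins unconditionally):

#include <cstddef>

// Ensure a host buffer of `count` elements is not identically zero by pinning its first element.
template <typename Element>
void ensure_nonzero(Element* data, std::size_t count) {
  if (count == 0) {
    return;
  }
  for (std::size_t i = 0; i < count; ++i) {
    if (data[i] != Element(0)) {
      return;  // already non-zero somewhere
    }
  }
  data[0] = Element(1);  // same effect as tensor.host_view().at({0, 0}) = Element(1)
}
// -----------------------------------------------------------------------------------------------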
-+ tensor_A.host_view().at({0, 0}) = typename RankK::ElementA(1); -+ tensor_C.host_view().at({0, 0}) = typename RankK::ElementC(1); -+ -+ cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); -+ -+ tensor_A.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); -+ -+ if (tensor_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ -+ if (reference_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ -+ double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); -+ -+ bool passed = l2_norm < cutlass::MantissaInBits::error; -+ -+ return passed; -+ } -+ -+ /// Verifies the result is a RankK -+ bool verify( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ // -+ // Verify -+ // -+ cutlass::reference::host::Rank2KComplex< -+ typename RankK::ElementA, typename RankK::LayoutA, -+ typename RankK::ElementC, typename RankK::LayoutC, -+ ElementCompute, ElementAccumulator -+ >( -+ problem_size, -+ alpha, -+ tensor_A.host_ref(), -+ RankK::kTransformA, -+ beta, -+ tensor_C.host_ref(), -+ reference_D.host_ref(), -+ ElementAccumulator(0), -+ RankK::kFillModeC, -+ RankK::kBlasMode -+ ); -+ -+ return compare_reference(problem_size, alpha, beta); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename RankK::RankKkernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmUniversalMode mode, -+ cutlass::gemm::GemmCoord problem_size, -+ int batch_count = 1, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; -+ } -+ return true; -+ } -+ -+#if 0 -+ std::cout << "[TestbedRankKUniversal::run()] problem(m, n, k): " << problem_size -+ << " alpha: " << ElementCompute(alpha) -+ << " beta: " << ElementCompute(beta) << std::endl; -+#endif -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the RankK operator -+ // -+ -+ typename RankK::Arguments arguments{ -+ mode, -+ problem_size, -+ batch_count, -+ {alpha, beta}, -+ tensor_A.device_data(), -+ tensor_C.device_data(), -+ tensor_D.device_data(), -+ problem_size.n() * problem_size.k(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ tensor_A.layout().stride(0), -+ tensor_C.layout().stride(0), -+ tensor_D.layout().stride(0) -+ }; -+ -+ RankK rank2k_op; -+ -+ size_t workspace_size = RankK::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = rank2k_op.initialize(arguments, workspace.get()); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Run the RankK -+ // -+ -+ status = rank2k_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ -+ bool passed = this->verify(problem_size, alpha, beta); -+ -+ //if (true) { -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_RankK_device_" -+ << "fill_mode_c_" -+ << (RankK::kFillModeC == cutlass::FillMode::kLower ? "lower_" : -+ (RankK::kFillModeC == cutlass::FillMode::kUpper ? "upper_" : "invalid_")) -+ << "mnk_" -+ << problem_size.m() << "x" -+ << problem_size.n() << "x" -+ << problem_size.k() << "_" -+ << RankK::ThreadblockShape::kM << "x" -+ << RankK::ThreadblockShape::kN << "x" -+ << RankK::ThreadblockShape::kK << "_" -+ << RankK::WarpShape::kM << "x" -+ << RankK::WarpShape::kN << "x" -+ << RankK::WarpShape::kK << ".txt"; -+ -+ std::cout << fname.str() << std::endl; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size << std::endl; -+ -+ results -+ << "\nA:\n" << tensor_A.host_view() << "\n" -+ << "\nC:\n" << tensor_C.host_view() << "\n" -+ << "\nD reference:\n" << reference_D.host_view() << "\n" -+ << "\nD computed:\n" << tensor_D.host_view() << "\n"; -+ -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+bool TestRank2kUniversal( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmUniversalMode mode, -+ int batch_count, -+ double alpha = 1.0, -+ double beta = 2.0) { -+ -+ bool passed = true; -+ -+ TestbedRank2KUniversal testbed; -+ -+ using ElementCompute = typename RankK::EpilogueOutputOp::ElementCompute; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ return passed; -+} -+ -+template -+bool TestAllRankKUniversal() { -+ bool passed = true; -+ -+ -+ int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); -+ int const kAlignmentN = 128 / kMinimumOperandElementSize; -+ int const kAlignmentK = 128 / kMinimumOperandElementSize; -+ -+ cutlass::gemm::GemmUniversalMode modes[] = { -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ }; -+ -+ int problem_size_n[] = { -+ kAlignmentN, 512 - 2*kAlignmentN -+ }; -+ -+ int problem_size_k[] = { -+ kAlignmentK, -+ RankK::ThreadblockShape::kK * RankK::kStages - kAlignmentK, -+ RankK::ThreadblockShape::kK * RankK::kStages * 3 - kAlignmentK -+ }; -+ -+ int batch_counts[] = { // may be 
interpretted as batch count or split-K slices -+ 1 // Just running one batch for now (removing 2, 3, 5, 7) -+ }; -+ -+ double problem_alpha[] = { -+ 1.0 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ -+ using ElementCompute = typename RankK::EpilogueOutputOp::ElementCompute; -+ -+ for (cutlass::gemm::GemmUniversalMode mode : modes) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ for (int batch_count : batch_counts) { -+ -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ if (mode == cutlass::gemm::GemmUniversalMode::kGemm || -+ mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { -+ } -+ -+ cutlass::gemm::GemmCoord problem_size(n, n, k); -+ -+ TestbedRank2KUniversal testbed; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_sanity.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_sanity.h -new file mode 100644 -index 0000000..73c0c5c ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_sanity.h -@@ -0,0 +1,238 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/core_io.h" -+ -+#include "testbed.h" -+ -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// List of Gemm internal parameters this testbed supports for user verification -+// -+enum class ParameterID { -+ -+ // Threadblock-level parameters -+ kSmemASize, -+ kSmemBSize, -+ -+ // Warp-level parameters -+ kWarpFragmentASize, -+ kWarpFragmentBSize, -+ kWarpFragmentCSize, -+ kInvalid -+}; -+ -+struct Reference { -+ ParameterID parameter_id; -+ -+ union { -+ int value; -+ -+ struct { -+ int m, n, k; -+ } gemm_shape; -+ -+ struct { -+ int row, column; -+ } matrix_shape; -+ }; -+ -+ std::string error_msg; -+ -+ Reference( -+ ParameterID parameter_id_, -+ int value_=-1, -+ std::string const &error_msg_="") : parameter_id(parameter_id_), value(value_), error_msg(error_msg_) {} -+}; -+ -+ -+template <typename Gemm> -+struct TestbedSanity { -+ -+ // -+ // Type definitions (All Gemm types top down) -+ // -+ -+ // Unpacking Gemm types in the following order -+ // Kernel-level > Threadblock-level > Warp-level > Instruction-level -+ -+ // kernel-level cutlass Gemm -+ using GemmKernel = typename Gemm::GemmKernel; -+ -+ // -+ // Threadblock-level gemm types -+ // -+ using MmaThreadBlock = typename GemmKernel::Mma; -+ -+ // Threadblock-level gemm shape covering one stage -+ using ThreadblockShape = typename MmaThreadBlock::Shape; -+ -+ // Shared memory size covering all stages -+ using SmemShapeA = typename MmaThreadBlock::Base::SharedStorage::ShapeA; -+ using SmemPaddingA = typename MmaThreadBlock::Policy::SmemPaddingA; -+ using SmemShapeB = typename MmaThreadBlock::Base::SharedStorage::ShapeB; -+ using SmemPaddingB = typename MmaThreadBlock::Policy::SmemPaddingB; -+ -+ -+ /// Number of stages -+ static int const kStages = MmaThreadBlock::Base::kStages; -+ -+ /// Number of warp-level GEMM operations -+ static int const kWarpGemmIterations = MmaThreadBlock::kWarpGemmIterations; -+ -+ -+ // -+ // Warp-level gemm types -+ // -+ -+ // Warp-level gemm operator -+ using MmaWarp = typename MmaThreadBlock::Operator; -+ -+ // Warp-level gemm shape covering all kgroups -+ using WarpShape = typename MmaWarp::Shape; -+ -+ // Warp-level fragments holding operands A & B and destination C -+ using WarpFragmentA = typename MmaWarp::FragmentA; -+ using WarpFragmentB = typename MmaWarp::FragmentB; -+ using WarpFragmentC = typename MmaWarp::FragmentC; -+ -+ // -+ // Instruction-level gemm types -+ // -+ -+ // Instruction-level gemm operator -+ using MmaInstruction = typename MmaWarp::Policy::Operator; -+ -+ // Instruction shape -+ using InstructionShape = typename MmaInstruction::Shape; -+ -+ // Instruction-level fragments holding operands A & B and destination C -+ using InstructionFragmentA = typename MmaInstruction::FragmentA; -+ using InstructionFragmentB = typename MmaInstruction::FragmentB; -+ using InstructionFragmentC = typename MmaInstruction::FragmentC; -+ -+ 
// -+ // Testbed types -+ // -+ -+ // Vector of values holding user-provided references -+ using ReferenceVector = std::vector<Reference>; -+ -+ // -+ // Data members -+ // -+ ReferenceVector references; -+ -+ // -+ // Methods -+ // -+ -+ TestbedSanity(ReferenceVector const &references_ = ReferenceVector()) : references(references_){ } -+ -+ // verify all parameters in the ReferenceVector -+ bool verify() { -+ for(auto ref : references) -+ verify_parameter(ref); -+ return true; -+ } -+ -+ // verify a parameter of type Reference -+ void verify_parameter(Reference const& ref) { -+ switch(ref.parameter_id) { -+ case ParameterID::kWarpFragmentASize : EXPECT_TRUE(WarpFragmentA::kElements == ref.value) << *this; break; -+ case ParameterID::kWarpFragmentBSize : EXPECT_TRUE(WarpFragmentB::kElements == ref.value) << *this; break; -+ case ParameterID::kWarpFragmentCSize : EXPECT_TRUE(WarpFragmentC::kElements == ref.value) << *this; break; -+ } -+ } -+ -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -+// Overload output operators for TestbedSanity -+/////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -+template <typename Gemm> -+std::ostream & operator<<(std::ostream &out, TestbedSanity<Gemm> const &test) { -+ -+ -+ out << "Gemm internal parameters" << std::endl -+ << " Threadblock-level parameters:" << std::endl -+ << " ThreadblockShape = " << typename TestbedSanity<Gemm>::ThreadblockShape() << std::endl -+ << " kStages = " << TestbedSanity<Gemm>::kStages << std::endl -+ << " kWarpGemmIterations = "<< TestbedSanity<Gemm>::kWarpGemmIterations << std::endl -+ <<" Shared memory sizes:" << std::endl -+ <<" SmemPaddingA = " << typename TestbedSanity<Gemm>::SmemPaddingA() << std::endl -+ <<" SmemPaddingB = " << typename TestbedSanity<Gemm>::SmemPaddingB() << std::endl -+ <<" SmemShapeA = " << typename TestbedSanity<Gemm>::SmemShapeA() << std::endl -+ <<" SmemShapeB = " << typename TestbedSanity<Gemm>::SmemShapeB() << std::endl -+ <<" Warp-level parameters" << std::endl -+ <<" WarpShape = " << typename TestbedSanity<Gemm>::WarpShape() << std::endl -+ <<" Fragment sizes:" << std::endl -+ <<" WarpFragmentA::kElements = " << TestbedSanity<Gemm>::WarpFragmentA::kElements << std::endl -+ <<" WarpFragmentB::kElements = " << TestbedSanity<Gemm>::WarpFragmentB::kElements << std::endl -+ <<" WarpFragmentC::kElements = " << TestbedSanity<Gemm>::WarpFragmentC::kElements << std::endl -+ <<" Instruction-level parameters" << std::endl -+ <<" InstructionShape = " << typename TestbedSanity<Gemm>::InstructionShape() << std::endl -+ <<" Fragment sizes:" << std::endl -+ <<" InstructionFragmentA::kElements = " << TestbedSanity<Gemm>::InstructionFragmentA::kElements << std::endl -+ <<" InstructionFragmentB::kElements = " << TestbedSanity<Gemm>::InstructionFragmentB::kElements << std::endl -+ <<" InstructionFragmentC::kElements = " << TestbedSanity<Gemm>::InstructionFragmentC::kElements << std::endl; -+ -+ return out; -+} -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_sparse.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_sparse.h -new file mode 100644 -index 0000000..56f3e5e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_sparse.h -@@ -0,0 +1,488 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+ -+ Testbed for sparse operations not to be released for CUDA 11.0 GA. Expected release is 11.1. 
-+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/host_reorder.h" -+#include "cutlass/util/host_uncompress.h" -+ -+#include "testbed_utils.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct SparseTestbed { -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementC = typename Gemm::ElementC; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ static int const kSparse = Gemm::GemmKernel::kSparse; -+ static int const kMetaSizeInBits = Gemm::GemmKernel::kMetaSizeInBits; -+ static int const kMaxID2 = Gemm::GemmKernel::kMaxID2; -+ static int const kElementsPerElementE = Gemm::GemmKernel::kElementsPerElementE; -+ -+ using ElementE = typename Gemm::GemmKernel::ElementE; -+ using LayoutE = cutlass::layout::RowMajor; -+ using ReorderedLayoutE = typename Gemm::GemmKernel::LayoutE; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_E; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_A_uncompressed; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D; -+ cutlass::HostTensor reference_D; -+ cutlass::HostTensor tensor_E; -+ cutlass::HostTensor tensor_E_reordered; -+ -+ // -+ // Methods -+ // -+ -+ SparseTestbed( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_E_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080) -+ : init_A(init_A_), -+ init_B(init_B_), -+ init_C(init_C_), -+ init_E(init_E_), -+ seed(seed_) {} -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == 
cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Initializes data structures -+ void initialize(cutlass::gemm::GemmCoord problem_size) { -+ // -+ // Allocate the GEMM workspace -+ // -+ tensor_A.resize(cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); -+ tensor_A_uncompressed.resize(problem_size.mk()); -+ tensor_B.resize(problem_size.kn()); -+ tensor_C.resize(problem_size.mn()); -+ tensor_D.resize(problem_size.mn()); -+ reference_D.resize(problem_size.mn(), false); -+ tensor_E.resize(cutlass::make_Coord( -+ problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); -+ tensor_E_reordered.resize(cutlass::make_Coord( -+ problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); -+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); -+ -+ if (init_E == cutlass::Distribution::Uniform) { -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomSparseMeta( -+ tensor_E.host_view(), seed, kMetaSizeInBits); -+ } else if (init_E == cutlass::Distribution::Identity) { -+ uint32_t content = (kMaxID2 == 1) ? 0x44444444 : 0x4444; -+ cutlass::reference::host::TensorFill(tensor_E.host_view(), -+ (ElementE)(content)); -+ } else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false); -+ } -+ -+ cutlass::reorder_meta(tensor_E_reordered.host_ref(), tensor_E.host_ref(), -+ {problem_size.m(), problem_size.n(), -+ problem_size.k() / kSparse / kElementsPerElementE}); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. 
-+ tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); -+ tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); -+ tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); -+ -+ cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ tensor_E_reordered.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); -+ -+ if (tensor_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ -+ if (reference_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ -+ bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ -+ std::stringstream fname; -+ -+ fname << "error_Gemm_device_" -+ << problem_size.m() << "x" -+ << problem_size.n() << "x" -+ << problem_size.k() << "_" -+ << Gemm::ThreadblockShape::kM << "x" -+ << Gemm::ThreadblockShape::kN << "x" -+ << Gemm::ThreadblockShape::kK << "_" -+ << Gemm::WarpShape::kM << "x" -+ << Gemm::WarpShape::kN << "x" -+ << Gemm::WarpShape::kK << ".txt"; -+ -+ std::ofstream file(fname.str()); -+ -+ file -+ << "problem: " << problem_size -+ << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; -+ -+ file -+ << "A =\n" << tensor_A.host_view() -+ << "\nB =\n" << tensor_B.host_view() -+ << "\nC =\n" << tensor_C.host_view() -+ << "\nE =\n" << tensor_E.host_view() -+ << "\n\nReference =\n" << reference_D.host_view() -+ << "\nComputed =\n" << tensor_D.host_view(); -+ } -+ -+ return passed; -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ // -+ // Verify -+ // -+ -+ cutlass::uncompress(tensor_A_uncompressed.host_ref(), tensor_A.host_ref(), -+ tensor_E.host_ref(), problem_size.m(), problem_size.k()); -+ -+ cutlass::reference::host::Gemm< -+ typename Gemm::ElementA, typename Gemm::LayoutA, -+ typename Gemm::ElementB, typename Gemm::LayoutB, -+ typename Gemm::ElementC, typename Gemm::LayoutC, -+ ElementCompute, -+ ElementAccumulator, typename Gemm::Operator> -+ reference_gemm; -+ -+ reference_gemm( -+ problem_size, -+ alpha, -+ tensor_A_uncompressed.host_ref(), -+ tensor_B.host_ref(), -+ beta, -+ reference_D.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ return compare_reference(problem_size, alpha, beta); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. 
-+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size, -+ int split_k_slices = 1, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm::Arguments arguments{ -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D.device_ref(), -+ tensor_E_reordered.device_ref(), -+ {alpha, beta}, -+ split_k_slices -+ }; -+ -+ Gemm gemm_op; -+ -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); -+ -+ // This failure is likely due to insufficient device capabilities. Waive the test. -+ if (status != cutlass::Status::kSuccess) { -+ return true; -+ } -+ -+ // -+ // Run the GEMM -+ // -+ -+ status = gemm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ -+ bool passed = this->verify(problem_size, alpha, beta); -+ -+ if (!passed) { -+ std::cout << "Error with split_k_slices = " << split_k_slices << ", alpha: " << alpha << std::endl; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+bool TestAllSparseGemm() { -+ bool passed = true; -+ -+ int const kMinimumOperandElementSize = -+ std::min( -+ int(cutlass::sizeof_bits::value), -+ int(cutlass::sizeof_bits::value)); -+ -+ // M dimension has to be multiple of 32 (sparse float) or 16 (sparse int) -+ // because of the reordering of operand E -+ int const kAlignmentM = std::max(((sizeof(typename Gemm::ElementE) == 2) ? 
32 : 16), -+ kMinimumOperandElementSize); -+ -+ int const kAlignmentN = 128 / kMinimumOperandElementSize; -+ -+ int problem_size_m[] = {kAlignmentM, 512 - 3 * kAlignmentM}; -+ -+ int problem_size_n[] = {kAlignmentN, 512 - 2 * kAlignmentN}; -+ -+ int problem_size_k[] = {Gemm::ThreadblockShape::kK, -+ Gemm::ThreadblockShape::kK * (Gemm::kStages + 1)}; -+ -+ int split_k_slices[] = { -+ 1, 2, 3 -+ }; -+ -+ double problem_alpha[] = { -+ 1 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ SparseTestbed testbed; -+ -+ using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; -+ -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ for (int split_k : split_k_slices) { -+ -+ if (!Gemm::kSplitKSerial && split_k > 1) { -+ continue; -+ } -+ -+ if (split_k > 1 && k / Gemm::ThreadblockShape::kK < split_k) { -+ continue; -+ } -+ -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ cutlass::gemm::GemmCoord problem_size(m, n, k); -+ -+ passed = testbed.run( -+ problem_size, -+ split_k, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_splitk.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_splitk.h -new file mode 100644 -index 0000000..73dda7e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_splitk.h -@@ -0,0 +1,218 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#pragma once -+ -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "testbed.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct TestbedSplitK : public Testbed { -+ -+ using Base = Testbed; -+ -+ using ElementCompute = typename Base::ElementCompute; -+ -+ // -+ // Methods -+ // -+ -+ TestbedSplitK( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ Base(init_A_, init_B_, init_C_, seed_) { } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size, -+ int split_k_slices, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; -+ } -+ return true; -+ } -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm::Arguments arguments{ -+ problem_size, -+ this->tensor_A.device_ref(), -+ this->tensor_B.device_ref(), -+ this->tensor_C.device_ref(), -+ this->tensor_D.device_ref(), -+ {alpha, beta}, -+ split_k_slices -+ }; -+ -+ Gemm gemm_op; -+ -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ -+ // -+ // Run the GEMM -+ // -+ -+ status = gemm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ -+ // -+ // Verify -+ // -+ -+ return this->verify(problem_size, alpha, beta); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+bool TestAllGemmSplitK() { -+ bool passed = true; -+ -+ cutlass::gemm::GemmCoord problem_sizes[] = { -+ {8, 8, 2048}, -+ {8, 8, 2056}, -+ {264, 72, 520}, -+ {264, 520, 120}, -+ {264, 520, 264} -+ }; -+ -+ int split_k_slices[] = { -+ 1, 2, 4, 5, 7 -+ }; -+ -+ double problem_alpha[] = { -+ 0.5 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ using Testbed = TestbedSplitK; -+ using ElementCompute = typename Testbed::ElementCompute; -+ -+ Testbed testbed; -+ -+ for (auto problem_size : problem_sizes) { -+ for (int split_k_count : split_k_slices) { -+ for (double alpha : problem_alpha) { -+ for (double beta : problem_beta) { -+ -+ passed = testbed.run( -+ problem_size, -+ split_k_count, -+ ElementCompute(alpha), -+ ElementCompute(beta) -+ ); -+ -+ if (!passed) { -+ std::cout << "Failed on size " << problem_size << " with split_k_count " << split_k_count << std::endl; -+ return false; -+ } -+ } -+ } -+ } -+ } -+ -+ EXPECT_TRUE(passed); -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_symm_universal.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_symm_universal.h -new file mode 100644 -index 0000000..1050a2e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_symm_universal.h -@@ -0,0 +1,592 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Symm update interface -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/error_metrics.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+ -+#include "testbed_utils.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct TestbedSymmUniversal { -+ -+ using ElementA = typename Symm::ElementA; -+ using ElementB = typename Symm::ElementB; -+ using ElementC = typename Symm::ElementC; -+ using ElementAccumulator = typename Symm::ElementAccumulator; -+ using ElementCompute = typename Symm::SymmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D; -+ cutlass::HostTensor reference_D; -+ -+ // -+ // Methods -+ // -+ -+ TestbedSymmUniversal( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed, -+ int mantissa_in_bits) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ 
cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ -+ EXPECT_TRUE(false) << "Input distribution not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_symmetric_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed, -+ int mantissa_in_bits) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillSymmetricRandomUniform( -+ view, seed, Symm::kFillModeA, scope_max, scope_min, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillSymmetricRandomGaussian( -+ view, seed, Symm::kFillModeA, 0, 0.5, mantissa_in_bits); -+ } -+ else { -+ -+ EXPECT_TRUE(false) << "Input distribution (symmetric tensor) not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ /// Initializes data structures -+ void initialize(cutlass::gemm::GemmCoord problem_size) { -+ // -+ // Allocate the Symm workspace -+ // -+ -+ if (Symm::kSideModeA == cutlass::SideMode::kLeft) { -+ tensor_A.resize(cutlass::make_Coord(problem_size.m(),problem_size.m())); -+ } -+ else if (Symm::kSideModeA == cutlass::SideMode::kRight) { -+ tensor_A.resize(cutlass::make_Coord(problem_size.n(),problem_size.n())); -+ } -+ -+ tensor_B.resize(problem_size.mn()); -+ tensor_C.resize(problem_size.mn()); -+ tensor_D.resize(problem_size.mn()); -+ reference_D.resize(problem_size.mn(), false); -+ -+ EXPECT_TRUE(initialize_symmetric_tensor(tensor_A.host_view(), init_A, seed + 2019, cutlass::MantissaInBits::bits)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018, cutlass::MantissaInBits::bits)); -+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017, cutlass::MantissaInBits::bits)); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. 
-+ tensor_A.host_view().at({0, 0}) = typename Symm::ElementA(1); -+ tensor_B.host_view().at({0, 0}) = typename Symm::ElementB(1); -+ tensor_C.host_view().at({0, 0}) = typename Symm::ElementC(1); -+ -+ cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); -+ -+ if (tensor_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ -+ if (reference_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ -+ double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); -+ -+ bool passed = l2_norm < cutlass::MantissaInBits::error; -+ -+ return passed; -+ } -+ -+ /// Verifies the result is a Symm -+ bool verify( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ // -+ // Verify -+ // -+ -+ using HostReference = typename cutlass::platform::conditional< -+ (cutlass::platform::is_same -+ >::value || -+ cutlass::platform::is_same -+ >::value -+ ), -+ cutlass::reference::host::SymmComplex< -+ typename Symm::ElementA, typename Symm::LayoutA, -+ Symm::kSideModeA, Symm::kFillModeA, -+ typename Symm::ElementB, typename Symm::LayoutB, -+ typename Symm::ElementC, typename Symm::LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ Symm::kBlasMode>, -+ cutlass::reference::host::Symm< -+ typename Symm::ElementA, typename Symm::LayoutA, -+ Symm::kSideModeA, Symm::kFillModeA, -+ typename Symm::ElementB, typename Symm::LayoutB, -+ typename Symm::ElementC, typename Symm::LayoutC, -+ ElementCompute, -+ ElementAccumulator> -+ >::type; -+ -+ -+ HostReference reference_symm; -+ -+ reference_symm( -+ problem_size, -+ alpha, -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ beta, -+ tensor_C.host_ref(), -+ reference_D.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ return compare_reference(problem_size, alpha, beta); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. 
-+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Symm::SymmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmUniversalMode mode, -+ cutlass::gemm::GemmCoord problem_size, -+ int batch_count = 1, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+#if 0 -+ std::cout << "[TestbedSymmUniversal::run()] problem(m, n, k): " << problem_size -+ << " alpha: " << ElementCompute(alpha) -+ << " beta: " << ElementCompute(beta) << std::endl; -+#endif -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the Symm operator -+ // -+ -+ int batch_stride_A; -+ if (Symm::kSideModeA == cutlass::SideMode::kLeft) -+ batch_stride_A = problem_size.m()*problem_size.m(); -+ if (Symm::kSideModeA == cutlass::SideMode::kRight) -+ batch_stride_A = problem_size.n()*problem_size.n(); -+ -+ typename Symm::Arguments arguments{ -+ mode, -+ problem_size, -+ batch_count, -+ {alpha, beta}, -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_C.device_data(), -+ tensor_D.device_data(), -+ batch_stride_A, -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ tensor_A.layout().stride(0), -+ tensor_B.layout().stride(0), -+ tensor_C.layout().stride(0), -+ tensor_D.layout().stride(0) -+ }; -+ -+ Symm symm_op; -+ -+ size_t workspace_size = Symm::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = symm_op.initialize(arguments, workspace.get()); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Run the Symm -+ // -+ -+ status = symm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ -+ bool passed = this->verify(problem_size, alpha, beta); -+ -+ //if (true) { -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_" -+ << (Symm::kBlasMode == cutlass::BlasMode::kSymmetric ? "symm_" : "hemm_" ) -+ << "device_" -+ << "fill_mode_a_" -+ << (Symm::kSideModeA == cutlass::SideMode::kLeft ? "leftside_" : -+ (Symm::kSideModeA == cutlass::SideMode::kRight ? "rightside_" : "invalid_")) -+ << (Symm::kFillModeA == cutlass::FillMode::kLower ? "lower_" : -+ (Symm::kFillModeA == cutlass::FillMode::kUpper ? 
"upper_" : "invalid_")) -+ << "mnk_" -+ << problem_size.m() << "x" -+ << problem_size.n() << "x" -+ << problem_size.k() << "_" -+ << Symm::ThreadblockShape::kM << "x" -+ << Symm::ThreadblockShape::kN << "x" -+ << Symm::ThreadblockShape::kK << "_" -+ << Symm::WarpShape::kM << "x" -+ << Symm::WarpShape::kN << "x" -+ << Symm::WarpShape::kK << ".txt"; -+ -+ std::cout << fname.str() << std::endl; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size << std::endl; -+ -+ results -+ << "alpha: " << ElementCompute(alpha) << "\n" -+ << "beta: " << ElementCompute(beta) << "\n" -+ << "\nA:\n" << tensor_A.host_view() << "\n" -+ << "\nB:\n" << tensor_B.host_view() << "\n" -+ << "\nC:\n" << tensor_C.host_view() << "\n" -+ << "\nD reference:\n" << reference_D.host_view() << "\n" -+ << "\nD computed:\n" << tensor_D.host_view() << "\n"; -+ -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+bool TestsymmUniversal( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmUniversalMode mode, -+ int batch_count, -+ double alpha = 1.0, -+ double beta = 2.0) { -+ -+ bool passed = true; -+ -+ TestbedSymmUniversal testbed; -+ -+ using ElementCompute = typename Symm::EpilogueOutputOp::ElementCompute; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ return passed; -+} -+ -+template -+bool TestAllSymmUniversal() { -+ bool passed = true; -+ -+ -+ int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); -+ -+ int const kAlignment = cutlass::platform::is_same< -+ typename Symm::OperatorClass, -+ cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; -+ -+ // int8_t gemm alignment constraints -+ int const kAlignmentM = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value ? 4 : kAlignment; -+ -+ int const kAlignmentN = kAlignmentM; -+ -+ int const kAlignmentK = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value -+ ? 
4 : kAlignment; -+ -+ cutlass::gemm::GemmUniversalMode modes[] = { -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ }; -+ -+ int problem_size_m[] = { -+ kAlignmentK, -+ Symm::ThreadblockShape::kK * Symm::kStages - kAlignmentK, -+ Symm::ThreadblockShape::kK * Symm::kStages * 3 - kAlignmentK -+ }; -+ -+ int problem_size_n[] = { -+ kAlignmentN, 512 - 2*kAlignmentN -+ }; -+ -+ int batch_counts[] = { // may be interpretted as batch count or split-K slices -+ 1 // Just running one batch for now (removing 2, 3, 5, 7) -+ }; -+ -+ double problem_alpha[] = { -+ 1.0, 3.0 -+ }; -+ -+ double problem_beta[] = { -+ 0, 2.0 -+ }; -+ -+ -+ using ElementCompute = typename Symm::EpilogueOutputOp::ElementCompute; -+ -+ for (cutlass::gemm::GemmUniversalMode mode : modes) { -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int batch_count : batch_counts) { -+ -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ int k = 0; -+ if (Symm::kSideModeA == cutlass::SideMode::kLeft) -+ k = m; -+ else if (Symm::kSideModeA == cutlass::SideMode::kRight) -+ k = n; -+ -+ if (mode == cutlass::gemm::GemmUniversalMode::kGemm || -+ mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { -+ -+ #if 0 -+ // skip very small K problems -+ if (k / batch_count < 2 * Symm::ThreadblockShape::kK) { -+ continue; -+ } -+ #endif -+ } -+ -+ cutlass::gemm::GemmCoord problem_size(m, n, k); -+ -+ TestbedSymmUniversal testbed; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_trmm_universal.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_trmm_universal.h -new file mode 100644 -index 0000000..db40eff ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_trmm_universal.h -@@ -0,0 +1,609 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/error_metrics.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/trmm_complex.h" -+#include "cutlass/core_io.h" -+ -+#include "testbed_utils.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct TestbedTrmmUniversal { -+ -+ using ElementA = typename Trmm::ElementA; -+ using ElementB = typename Trmm::ElementB; -+ using ElementC = typename Trmm::ElementC; -+ using ElementAccumulator = typename Trmm::ElementAccumulator; -+ using ElementCompute = typename Trmm::TrmmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_D; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_D; -+ cutlass::HostTensor reference_D; -+ -+ // -+ // Methods -+ // -+ -+ TestbedTrmmUniversal( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_D_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_D(init_D_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed, -+ int mantissa_in_bits) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == 
cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_symmetric_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed, -+ int mantissa_in_bits) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillSymmetricRandomUniform( -+ view, seed, Trmm::kFillMode, scope_max, scope_min, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillSymmetricRandomGaussian( -+ view, seed, Trmm::kFillMode, 0, 0.5, mantissa_in_bits); -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Helper to initialize a tensor view (pad diagonal fill with zeros for up to alignment on wrong side of diagonal) -+ template -+ bool initialize_pad_diagonal_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed, -+ int alignment) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillPadDiagonalRandomUniform( -+ view, seed, Trmm::kFillMode, scope_max, scope_min, 0, alignment); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ EXPECT_TRUE(false) << "Gaussian distribution for pad diagonal not implemented"; -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Initializes data structures -+ void initialize(cutlass::gemm::GemmCoord problem_size) { -+ // -+ // Allocate the TRMM workspace -+ // -+ -+ if (Trmm::kSideMode == cutlass::SideMode::kLeft) { -+ tensor_A.resize(cutlass::make_Coord(problem_size.m(),problem_size.m())); -+ } -+ else if (Trmm::kSideMode == cutlass::SideMode::kRight) { -+ tensor_A.resize(cutlass::make_Coord(problem_size.n(),problem_size.n())); -+ } -+ -+ tensor_B.resize(problem_size.mn()); -+ tensor_D.resize(problem_size.mn()); -+ reference_D.resize(problem_size.mn(), false); -+ -+ //EXPECT_TRUE(initialize_symmetric_tensor(tensor_A.host_view(), init_A, seed + 2017)); -+ //EXPECT_TRUE(initialize_pad_diagonal_tensor(tensor_A.host_view(), init_A, seed + 2017, Trmm::kAlignmentA)); -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2017, 
cutlass::MantissaInBits::bits)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2019, cutlass::MantissaInBits::bits)); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. -+ tensor_A.host_view().at({0, 0}) = typename Trmm::ElementA(1); -+ tensor_B.host_view().at({0, 0}) = typename Trmm::ElementB(1); -+ -+ cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_D.host_view()); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_D.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha) { -+ -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); -+ -+ if (tensor_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ -+ if (reference_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ -+ double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); -+ -+ bool passed = l2_norm < cutlass::MantissaInBits::error; -+ -+ return passed; -+ } -+ -+ /// Verifies the result is a TRMM -+ bool verify( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha) { -+ -+ // -+ // Verify -+ // -+ -+ using HostReference = typename cutlass::platform::conditional< -+ (cutlass::platform::is_same -+ >::value || -+ cutlass::platform::is_same -+ >::value -+ ), -+ cutlass::reference::host::TrmmComplex< -+ typename Trmm::ElementA, typename Trmm::LayoutA, -+ Trmm::kTransformA, -+ Trmm::kSideMode, Trmm::kFillMode, Trmm::kDiagType, -+ typename Trmm::ElementB, typename Trmm::LayoutB, -+ Trmm::kTransformB, -+ typename Trmm::ElementC, typename Trmm::LayoutC, -+ ElementCompute, -+ ElementAccumulator>, -+ cutlass::reference::host::Trmm< -+ typename Trmm::ElementA, typename Trmm::LayoutA, -+ Trmm::kSideMode, Trmm::kFillMode, Trmm::kDiagType, -+ typename Trmm::ElementB, typename Trmm::LayoutB, -+ typename Trmm::ElementC, typename Trmm::LayoutC, -+ ElementCompute, -+ ElementAccumulator> -+ >::type; -+ -+ -+ HostReference reference_trmm; -+ -+ reference_trmm( -+ problem_size, -+ alpha, -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ reference_D.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ return compare_reference(problem_size, alpha); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. 
-+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Trmm::TrmmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmUniversalMode mode, -+ cutlass::gemm::GemmCoord problem_size, -+ int batch_count = 1, -+ ElementCompute alpha = ElementCompute(1)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+#if 0 -+ std::cout << "[TestbedTrmmUniversal::run()] problem(m, n, k): " << problem_size -+ << " alpha: " << ElementCompute(alpha) << std::endl; -+#endif -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the TRMM operator -+ // -+ -+ int batch_stride_A; -+ if (Trmm::kSideMode == cutlass::SideMode::kLeft) -+ batch_stride_A = problem_size.m()*problem_size.m(); -+ if (Trmm::kSideMode == cutlass::SideMode::kRight) -+ batch_stride_A = problem_size.n()*problem_size.n(); -+ -+ typename Trmm::Arguments arguments{ -+ mode, -+ problem_size, -+ batch_count, -+ {alpha}, -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_D.device_data(), -+ batch_stride_A, -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ tensor_A.layout().stride(0), -+ tensor_B.layout().stride(0), -+ tensor_D.layout().stride(0) -+ }; -+ -+ Trmm trmm_op; -+ -+ size_t workspace_size = Trmm::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = trmm_op.initialize(arguments, workspace.get()); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Run the TRMM -+ // -+ -+ status = trmm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ bool passed = this->verify(problem_size, alpha); -+ -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_Trmm_device_" -+ << "fill_mode_" -+ << (Trmm::kFillMode == cutlass::FillMode::kLower ? "lower_" : -+ (Trmm::kFillMode == cutlass::FillMode::kUpper ? "upper_" : "invalid_")) -+ << "side_mode_" -+ << (Trmm::kSideMode == cutlass::SideMode::kLeft ? "left_" : -+ (Trmm::kSideMode == cutlass::SideMode::kRight ? 
"right_" : "invalid_")) -+ << "mnk_" -+ << problem_size.m() << "x" -+ << problem_size.n() << "x" -+ << problem_size.k() << "_" -+ << Trmm::ThreadblockShape::kM << "x" -+ << Trmm::ThreadblockShape::kN << "x" -+ << Trmm::ThreadblockShape::kK << "_" -+ << Trmm::WarpShape::kM << "x" -+ << Trmm::WarpShape::kN << "x" -+ << Trmm::WarpShape::kK << ".txt"; -+ -+ std::cout << fname.str() << std::endl; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size << std::endl; -+ -+ results -+ << "\nA:\n" << tensor_A.host_view() << "\n" -+ << "\nB:\n" << tensor_B.host_view() << "\n" -+ << "\nD reference:\n" << reference_D.host_view() << "\n" -+ << "\nD computed:\n" << tensor_D.host_view() << "\n"; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+bool TestTrmmUniversal( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmUniversalMode mode, -+ int batch_count, -+ double alpha = 1.0) { -+ -+ bool passed = true; -+ -+ TestbedTrmmUniversal testbed; -+ -+ using ElementCompute = typename Trmm::EpilogueOutputOp::ElementCompute; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha) -+ ); -+ -+ return passed; -+} -+ -+template -+bool TestAllTrmmUniversal() { -+ bool passed = true; -+ -+ int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); -+ -+ int const kAlignment = cutlass::platform::is_same< -+ typename Trmm::OperatorClass, -+ cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; -+ -+ // int8_t gemm alignment constraints -+ int const kAlignmentM = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value ? 4 : kAlignment; -+ -+ int const kAlignmentN = kAlignmentM; -+ -+ int const kAlignmentK = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value -+ ? 
4 : kAlignment; -+ -+ cutlass::gemm::GemmUniversalMode modes[] = { -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ }; -+ -+ int problem_size_m[] = { -+ kAlignmentK, -+ Trmm::ThreadblockShape::kK * Trmm::kStages - kAlignmentK, -+ Trmm::ThreadblockShape::kK * Trmm::kStages * 3 - kAlignmentK -+ }; -+ -+ int problem_size_n[] = { -+ kAlignmentN, 512 - 2*kAlignmentN -+ }; -+ -+ int batch_counts[] = { // may be interpretted as batch count or split-K slices -+ 1 // Just running one batch for now (removing 2, 3, 5, 7) -+ }; -+ -+ double problem_alpha[] = { -+ 1.0, 2.0 -+ }; -+ -+ using ElementCompute = typename Trmm::EpilogueOutputOp::ElementCompute; -+ -+ for (cutlass::gemm::GemmUniversalMode mode : modes) { -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int batch_count : batch_counts) { -+ for (auto alpha : problem_alpha) { -+ -+ int k = 0; -+ if (Trmm::kSideMode == cutlass::SideMode::kLeft) -+ k = m; -+ else if (Trmm::kSideMode == cutlass::SideMode::kRight) -+ k = n; -+ -+ if (mode == cutlass::gemm::GemmUniversalMode::kGemm || -+ mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { -+ -+#if 0 -+ // skip very small K problems -+ if (k / batch_count < 2 * Trmm::ThreadblockShape::kK) { -+ continue; -+ } -+#endif -+ } -+ -+ cutlass::gemm::GemmCoord problem_size(m, n, k); -+ -+ TestbedTrmmUniversal testbed; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_universal.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_universal.h -new file mode 100644 -index 0000000..615e9c5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_universal.h -@@ -0,0 +1,547 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+ -+#include "testbed_utils.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct TestbedUniversal { -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementC = typename Gemm::ElementC; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D; -+ cutlass::HostTensor reference_D; -+ -+ // -+ // Methods -+ // -+ -+ TestbedUniversal( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == 
cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Initializes data structures -+ void initialize(cutlass::gemm::GemmCoord problem_size) { -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ tensor_A.resize(problem_size.mk()); -+ tensor_B.resize(problem_size.kn()); -+ tensor_C.resize(problem_size.mn()); -+ tensor_D.resize(problem_size.mn()); -+ reference_D.resize(problem_size.mn(), false); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); -+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. -+ cutlass::Coord<2> origin(0); -+ tensor_A.host_view().at(origin) = typename Gemm::ElementA(1); -+ tensor_B.host_view().at(origin) = typename Gemm::ElementB(1); -+ tensor_C.host_view().at(origin) = typename Gemm::ElementC(1); -+ -+ cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ -+ bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); -+ -+ EXPECT_TRUE(passed) << " mismatched reference"; -+ -+ if (!passed) { -+ -+ /* -+ std::stringstream fname; -+ -+ fname << "error_Gemm_device_" -+ << problem_size.m() << "x" -+ << problem_size.n() << "x" -+ << problem_size.k() << "_" -+ << Gemm::ThreadblockShape::kM << "x" -+ << Gemm::ThreadblockShape::kN << "x" -+ << Gemm::ThreadblockShape::kK << "_" -+ << Gemm::WarpShape::kM << "x" -+ << Gemm::WarpShape::kN << "x" -+ << Gemm::WarpShape::kK << ".txt"; -+ -+ std::ofstream file(fname.str()); -+ */ -+ -+ std::ofstream file("testbed_universal_errors.txt"); -+ -+ file -+ << "problem: " << problem_size -+ << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; -+ -+ file -+ << "A =\n" << tensor_A.host_view() -+ << "\nB =\n" << tensor_B.host_view() -+ << "\nC =\n" << tensor_C.host_view() -+ << "\n\nReference =\n" << reference_D.host_view() -+ << "\nComputed =\n" << tensor_D.host_view(); -+ } -+ -+ return passed; -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::host::GemmComplex< -+ typename Gemm::ElementA, typename Gemm::LayoutA, -+ typename Gemm::ElementB, typename Gemm::LayoutB, -+ typename Gemm::ElementC, typename Gemm::LayoutC, -+ ElementCompute, ElementAccumulator -+ >( -+ problem_size, -+ alpha, -+ 
tensor_A.host_ref(), -+ Gemm::kTransformA, -+ tensor_B.host_ref(), -+ Gemm::kTransformB, -+ beta, -+ tensor_C.host_ref(), -+ reference_D.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ if (Relu) { -+ for (int i = 0; i < problem_size.m(); ++i) { -+ for (int j = 0; j < problem_size.n(); ++j) { -+ reference_D.at(cutlass::MatrixCoord(i, j)) = -+ ((ElementCompute)reference_D.at(cutlass::MatrixCoord(i, j)) < (ElementCompute)0) -+ ? (typename Gemm::ElementC)0 -+ : reference_D.at(cutlass::MatrixCoord(i, j)); -+ } -+ } -+ } -+ -+ return compare_reference(problem_size, alpha, beta); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmUniversalMode mode, -+ cutlass::gemm::GemmCoord problem_size, -+ int batch_count = 1, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) -+ { -+/* -+ std::cout << "\n-----------------------\n"; -+ std::cout << "mode: " << (int) mode << "\n"; -+ std::cout << "problem size: " << problem_size << "\n"; -+ std::cout << "batch_count: " << batch_count << "\n"; -+ std::cout << "alpha: " << alpha << "\n"; -+ std::cout << "beta: " << beta << "\n"; -+ std::cout << "-----------------------\n\n"; -+*/ -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; -+ } -+ return true; -+ } -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm::Arguments arguments{ -+ mode, -+ problem_size, -+ batch_count, -+ {alpha, beta}, -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_C.device_data(), -+ tensor_D.device_data(), -+ problem_size.m() * problem_size.k(), -+ problem_size.n() * problem_size.k(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ tensor_A.layout().stride(0), -+ tensor_B.layout().stride(0), -+ tensor_C.layout().stride(0), -+ tensor_D.layout().stride(0) -+ }; -+ -+ Gemm gemm_op; -+ -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Run the GEMM -+ // -+ -+ status = gemm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ -+ bool passed = this->verify(problem_size, alpha, beta); -+ -+ if (!passed) { -+ std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+bool TestGemmUniversal( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmUniversalMode mode, -+ int batch_count, -+ double alpha = 1.0, -+ double beta = 2.0) { -+ -+ bool passed = true; -+ -+ TestbedUniversal testbed; -+ -+ using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ return passed; -+} -+ -+template -+bool TestAllGemmUniversal() { -+ bool passed = true; -+ -+ -+ int const kMinimumOperandElementSize = -+ std::min( -+ int(cutlass::sizeof_bits::value), -+ int(cutlass::sizeof_bits::value)); -+ -+ int const kAlignment = cutlass::platform::is_same< -+ typename Gemm::OperatorClass, -+ cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; -+ -+ // int8_t gemm alignment constraints -+ int const kAlignmentM = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value ? 4 : kAlignment; -+ -+ int const kAlignmentN = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value ? 4 : kAlignment; -+ -+ int const kAlignmentK = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ (cutlass::platform::is_same::value || -+ cutlass::platform::is_same::value) ? 
4 : kAlignment; -+ -+ -+ -+ cutlass::gemm::GemmUniversalMode modes[] = { -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ }; -+ -+ int problem_size_m[] = { -+ kAlignmentM, 512 - 3*kAlignmentM -+ }; -+ -+ int problem_size_n[] = { -+ kAlignmentN, 512 - 2*kAlignmentN -+ }; -+ -+ int problem_size_k[] = { -+ kAlignmentK, -+ Gemm::ThreadblockShape::kK * Gemm::kStages - kAlignmentK, -+ Gemm::ThreadblockShape::kK * Gemm::kStages * 3 - kAlignmentK -+ }; -+ -+ int batch_counts[] = { // may be interpretted as batch count or split-K slices -+ 1, 2, 3, 5, 7 -+ }; -+ -+ double problem_alpha[] = { -+ 1 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ -+ using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; -+ -+ for (cutlass::gemm::GemmUniversalMode mode : modes) { -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ for (int batch_count : batch_counts) { -+ -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ if (mode == cutlass::gemm::GemmUniversalMode::kGemm || -+ mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { -+ -+ // skip very small K problems -+ if (k / batch_count < 2 * Gemm::ThreadblockShape::kK) { -+ continue; -+ } -+ } -+ -+ cutlass::gemm::GemmCoord problem_size(m, n, k); -+ -+ TestbedUniversal testbed; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ /* -+ // large problem with high coverage -+ for (int split_k_slices = 1; split_k_slices <= 3; ++split_k_slices) { -+ TestbedUniversal testbed; -+ -+ cutlass::gemm::GemmCoord problem_size(72, 56, 8192); -+ -+ passed = testbed.run( -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem_size, -+ split_k_slices, -+ cutlass::from_real(1.0), -+ cutlass::from_real(2.0) -+ ); -+ -+ if (!passed) { -+ break; -+ } -+ } -+ */ -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_utils.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_utils.h -new file mode 100644 -index 0000000..e47ecda ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_utils.h -@@ -0,0 +1,53 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*! \file
-+    \brief Tests for device-wide GEMM interface
-+*/
-+
-+#pragma once
-+
-+#include "cutlass/cutlass.h"
-+
-+inline char const *to_string(cutlass::Status status) {
-+
-+  switch (status) {
-+    case cutlass::Status::kSuccess: return "kSuccess";
-+    case cutlass::Status::kErrorMisalignedOperand: return "kErrorMisalignedOperand";
-+    case cutlass::Status::kErrorInvalidLayout: return "kErrorInvalidLayout";
-+    case cutlass::Status::kErrorInvalidProblem: return "kErrorInvalidProblem";
-+    case cutlass::Status::kErrorNotSupported: return "kErrorNotSupported";
-+    case cutlass::Status::kErrorWorkspaceNull: return "kErrorWorkspaceNull";
-+    case cutlass::Status::kErrorInternal: return "kErrorInternal";
-+    case cutlass::Status::kInvalid: return "kInvalid";
-+    default: break;
-+  }
-+  return "invalid";
-+}
-diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_cf32n_cf32n_cf32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_cf32n_cf32n_cf32t_tensor_op_f32_sm80.cu
-new file mode 100644
-index 0000000..1e31402
---- /dev/null
-+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_cf32n_cf32n_cf32t_tensor_op_f32_sm80.cu
-@@ -0,0 +1,305 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_ls_l_nu_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_ls_u_nu_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ 
cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_rs_u_nu_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_ls_l_un_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_ls_u_un_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ 
cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_rs_u_un_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_cf32n_cf32n_cf32t_tensor_op_fast_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_cf32n_cf32n_cf32t_tensor_op_fast_f32_sm80.cu -new file mode 100644 -index 0000000..8dc41a4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_cf32n_cf32n_cf32t_tensor_op_fast_f32_sm80.cu -@@ -0,0 +1,305 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_ls_l_nu_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_ls_u_nu_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ 
cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_rs_u_nu_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_ls_l_un_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_ls_u_un_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 
1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_rs_u_un_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_cf64_cf64_cf64_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_cf64_cf64_cf64_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..437bed5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_cf64_cf64_cf64_tensor_op_f64_sm90.cu -@@ -0,0 +1,137 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Trmm_cf64n_cf64n_cf64t_ls_u_nu_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Trmm_cf64h_cf64n_cf64t_ls_u_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ 
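For reference, the configuration this cf64h SM90 case exercises is spelled out below with explicit element types and per-parameter comments. This is a sketch only: it assumes (from the cf64 / _f64 tags in the test name) complex<double> operands and accumulators, and a testbed helper templated on the Trmm type.

    using ElementOutput      = cutlass::complex<double>;   // assumed: cf64 = complex<double>
    using ElementAccumulator = cutlass::complex<double>;   // assumed: f64-class accumulation

    using Trmm = cutlass::gemm::device::Trmm<
        cutlass::complex<double>, cutlass::layout::ColumnMajor,  // A element / layout
        cutlass::SideMode::kLeft,                                // ls: triangular A applied from the left
        cutlass::FillMode::kUpper,                               // u : upper-triangular fill
        cutlass::DiagType::kNonUnit,                             // nu: non-unit diagonal
        cutlass::complex<double>, cutlass::layout::ColumnMajor,  // B element / layout
        ElementOutput, cutlass::layout::RowMajor,                // output element / layout
        ElementAccumulator,
        cutlass::arch::OpClassTensorOp, cutlass::arch::Sm90,
        cutlass::gemm::GemmShape<64, 64, 16>,                    // threadblock tile
        cutlass::gemm::GemmShape<32, 32, 16>,                    // warp tile
        cutlass::gemm::GemmShape<16, 8, 4>,                      // tensor-op instruction shape
        cutlass::epilogue::thread::LinearCombination<
            ElementOutput, 1, ElementAccumulator, ElementAccumulator>,
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
        4, 1, 1, false,                                          // stages, alignment A / B, split-K serial
        cutlass::arch::OpMultiplyAddComplex,
        cutlass::ComplexTransform::kConjugate>;                  // cf64h: conjugated A operand

    // The shared testbed fills the operands, runs the device kernel over a sweep of
    // problem sizes, and compares the result against the host TRMM reference.
    EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal<Trmm>());
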
EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_cf64n_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_cf64n_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu -new file mode 100644 -index 0000000..b26a8d2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_cf64n_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu -@@ -0,0 +1,137 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf64n_cf64n_cf64t_ls_u_nu_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf64h_cf64n_cf64t_ls_u_nu_tensor_op_f64_gaussian, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_cf64n_cf64n_cf64t_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_cf64n_cf64n_cf64t_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..2db3d2c ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_cf64n_cf64n_cf64t_tensor_op_f64_sm80.cu -@@ -0,0 +1,301 @@ -+/*************************************************************************************************** -+ * 
Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf64n_cf64n_cf64t_ls_l_nu_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf64n_cf64n_cf64t_ls_u_nu_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf64h_cf64n_cf64t_ls_u_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ 
cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf64n_cf64n_cf64t_ls_l_un_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf64n_cf64n_cf64t_ls_u_un_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kUnit, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf64h_cf64n_cf64t_ls_u_un_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kUnit, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 
8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu -new file mode 100644 -index 0000000..d8ad244 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu -@@ -0,0 +1,500 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Trmm_{ElementA}{LayoutA}_{ElementB}{LayoutB}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _{DiagType}_tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_ls_l_nu_tensor_op_fast_f32_align1_align1, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_ls_l_nu_tensor_op_fast_f32_align1_align4, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_ls_l_nu_tensor_op_fast_f32_align1_align4, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ 
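Reading the naming convention above against the first case in this file gives a concrete decode (the element-type and layout shorthand follows the usual CUTLASS test conventions, assumed here):

    // SM80_Device_Trmm_f32n_f32t_f32t_ls_l_nu_tensor_op_fast_f32_align1_align1, 64x64x32_32x32x32
    //   f32n              -> ElementA = float, LayoutA = ColumnMajor ("n")
    //   f32t              -> ElementB = float, LayoutB = RowMajor    ("t")
    //   f32t              -> ElementC = float, LayoutC = RowMajor
    //   ls                -> SideMode::kLeft  (triangular A multiplies from the left)
    //   l                 -> FillMode::kLower
    //   nu                -> DiagType::kNonUnit
    //   fast_f32          -> float accumulation via cutlass::arch::OpMultiplyAddFastF32
    //   align1_align1     -> AlignmentA = 1, AlignmentB = 1
    //   64x64x32_32x32x32 -> threadblock tile 64x64x32, warp tile 32x32x32
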
cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_ls_l_nu_tensor_op_fast_f32_align1_align4, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_ls_l_nu_tensor_op_fast_f32_align1_align4, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_ls_l_nu_tensor_op_fast_f32_align1_align4, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ 
cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_ls_u_nu_tensor_op_fast_f32_align1_align4, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_ls_u_nu_tensor_op_fast_f32_align1_align4, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_ls_u_nu_tensor_op_fast_f32_align1_align4, 256x128x16_128x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ 
cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_ls_u_nu_tensor_op_fast_f32_align1_align4, 128x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_ls_u_nu_tensor_op_fast_f32_align1_align4, 256x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_rs_sm80.cu -new file mode 100644 -index 0000000..a9ed921 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_rs_sm80.cu -@@ -0,0 +1,252 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Trmm_{ElementA}{LayoutA}_{ElementB}{LayoutB}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _{DiagType}_tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_rs_u_nu_tensor_op_fast_f32_align1_align1, 64x128x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kUpper, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ 
cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_rs_u_nu_tensor_op_fast_f32_align1_align1, 128x64x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kUpper, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_rs_l_nu_tensor_op_fast_f32_align1_align1, 64x128x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kLower, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_rs_u_nu_tensor_op_fast_f32_align1_align4, 64x128x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kUpper, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_rs_u_nu_tensor_op_fast_f32_align1_align4, 128x64x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kUpper, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ 
cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_rs_l_nu_tensor_op_fast_f32_align1_align4, 64x128x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kLower, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_f32t_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_f32t_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu -new file mode 100644 -index 0000000..56c6396 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_f32t_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu -@@ -0,0 +1,449 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_l_nu_tensor_op_fast_f32_align1_align1, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, cutlass::FillMode::kLower, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_l_nu_tensor_op_fast_f32_align1_align4, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_l_nu_tensor_op_fast_f32_align1_align4, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_l_nu_tensor_op_fast_f32_align1_align4, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_l_nu_tensor_op_fast_f32_align1_align4, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_u_nu_tensor_op_fast_f32_align1_align4, 
64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_u_nu_tensor_op_fast_f32_align1_align4, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_u_nu_tensor_op_fast_f32_align1_align4, 256x128x16_128x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_u_nu_tensor_op_fast_f32_align1_align4, 128x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ 
cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_u_nu_tensor_op_fast_f32_align1_align4, 256x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_f32t_f32n_f32t_tensor_op_fast_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_f32t_f32n_f32t_tensor_op_fast_f32_ls_sm80.cu -new file mode 100644 -index 0000000..9217ebd ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_f32t_f32n_f32t_tensor_op_fast_f32_ls_sm80.cu -@@ -0,0 +1,458 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Trmm_{ElementA}{LayoutA}_{ElementB}{LayoutB}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _{DiagType}_tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Trmm_f32t_f32n_f32t_ls_l_un_tensor_op_fast_f32_align1_align1, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32t_ls_l_un_tensor_op_fast_f32_align1_align4, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ 
ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Trmm_f32t_f32n_f32t_ls_l_un_tensor_op_fast_f32_align1_align4, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32t_ls_l_un_tensor_op_fast_f32_align1_align4, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32t_ls_l_nu_tensor_op_fast_f32_align1_align4, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ 
cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32t_ls_u_nu_tensor_op_fast_f32_align1_align4, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32t_ls_u_un_tensor_op_fast_f32_align1_align4, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32t_ls_u_nu_tensor_op_fast_f32_align1_align4, 256x128x16_128x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32t_ls_u_nu_tensor_op_fast_f32_align1_align4, 128x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32t_ls_u_nu_tensor_op_fast_f32_align1_align4, 256x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_f64_f64_f64_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64_f64_f64_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..5339bc5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64_f64_f64_tensor_op_f64_sm90.cu -@@ -0,0 +1,127 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Trmm_f64n_f64n_f64t_rs_l_nu_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Trmm_f64t_f64t_f64n_rs_l_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ 
cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_f64n_f64n_f64t_tensor_op_f64_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64n_f64n_f64t_tensor_op_f64_ls_sm80.cu -new file mode 100644 -index 0000000..0dd9064 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64n_f64n_f64t_tensor_op_f64_ls_sm80.cu -@@ -0,0 +1,414 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_ls_l_nu_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_ls_l_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_ls_l_nu_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_ls_l_nu_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_ls_l_nu_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_ls_u_nu_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_ls_u_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ 
cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_ls_u_nu_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_ls_u_nu_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_ls_u_nu_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_f64n_f64n_f64t_tensor_op_f64_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64n_f64n_f64t_tensor_op_f64_rs_sm80.cu -new file mode 100644 -index 0000000..f00e50c ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64n_f64n_f64t_tensor_op_f64_rs_sm80.cu -@@ -0,0 +1,415 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_rs_l_nu_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_rs_l_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_rs_l_nu_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_rs_l_nu_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_rs_l_nu_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_rs_u_nu_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_rs_u_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ 
cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_rs_u_nu_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_rs_u_nu_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_rs_u_nu_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_f64n_f64t_f64t_tensor_op_f64_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64n_f64t_f64t_tensor_op_f64_rs_sm80.cu -new file mode 100644 -index 0000000..98f2a57 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64n_f64t_f64t_tensor_op_f64_rs_sm80.cu -@@ -0,0 +1,415 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64t_f64t_rs_l_un_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64t_f64t_rs_l_un_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64t_f64t_rs_l_un_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64t_f64t_rs_l_un_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64t_f64t_rs_l_nu_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64t_f64t_rs_u_un_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64t_f64t_rs_u_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ 
cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64t_f64t_rs_u_nu_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64t_f64t_rs_u_nu_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64t_f64t_rs_u_nu_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_f64t_f64t_f64n_tensor_op_f64_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64t_f64t_f64n_tensor_op_f64_ls_sm80.cu -new file mode 100644 -index 0000000..bb4443d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64t_f64t_f64n_tensor_op_f64_ls_sm80.cu -@@ -0,0 +1,414 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include <iostream> -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_ls_l_nu_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal<Trmm>()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_ls_l_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal<Trmm>()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_ls_l_nu_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_ls_l_nu_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_ls_l_nu_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_ls_u_nu_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_ls_u_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ 
cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_ls_u_nu_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_ls_u_nu_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_ls_u_nu_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_f64t_f64t_f64n_tensor_op_f64_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64t_f64t_f64n_tensor_op_f64_rs_sm80.cu -new file mode 100644 -index 0000000..dd07d78 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64t_f64t_f64n_tensor_op_f64_rs_sm80.cu -@@ -0,0 +1,415 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include <iostream> -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_rs_l_nu_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal<Trmm>()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_rs_l_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal<Trmm>()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_rs_l_nu_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_rs_l_nu_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_rs_l_nu_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_rs_u_nu_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_rs_u_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ 
cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_rs_u_nu_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_rs_u_nu_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_rs_u_nu_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ 
EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32n_tf32t_f32t_tensor_op_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32n_tf32t_f32t_tensor_op_f32_ls_sm80.cu -new file mode 100644 -index 0000000..97106d8 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32n_tf32t_f32t_tensor_op_f32_ls_sm80.cu -@@ -0,0 +1,500 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include <iostream> -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Trmm_{ElementA}{LayoutA}_{ElementB}{LayoutB}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _{DiagType}_tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_ls_l_nu_tensor_op_f32_align1_align1, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal<Trmm>()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_ls_l_nu_tensor_op_f32_align1_align4, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal<Trmm>()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_ls_l_nu_tensor_op_f32_align1_align4, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ 
cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_ls_l_nu_tensor_op_f32_align1_align4, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_ls_l_nu_tensor_op_f32_align1_align4, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_ls_l_nu_tensor_op_f32_align1_align4, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ 
cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_ls_u_nu_tensor_op_f32_align1_align4, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_ls_u_nu_tensor_op_f32_align1_align4, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_ls_u_nu_tensor_op_f32_align1_align4, 256x128x16_128x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ 
ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_ls_u_nu_tensor_op_f32_align1_align4, 128x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_ls_u_nu_tensor_op_f32_align1_align4, 256x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32n_tf32t_f32t_tensor_op_f32_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32n_tf32t_f32t_tensor_op_f32_rs_sm80.cu -new file mode 100644 -index 0000000..723af64 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32n_tf32t_f32t_tensor_op_f32_rs_sm80.cu -@@ -0,0 +1,252 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Trmm_{ElementA}{LayoutA}_{ElementB}{LayoutB}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _{DiagType}_tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_rs_u_nu_tensor_op_f32_align1_align1, 64x128x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kUpper, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ 
-+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_rs_u_nu_tensor_op_f32_align1_align1, 128x64x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kUpper, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_rs_l_nu_tensor_op_f32_align1_align1, 64x128x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kLower, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_rs_u_nu_tensor_op_f32_align1_align4, 64x128x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kUpper, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_rs_u_nu_tensor_op_f32_align1_align4, 128x64x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kUpper, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; 
-+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_rs_l_nu_tensor_op_f32_align1_align4, 64x128x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kLower, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32t_tf32n_f32n_tensor_op_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32t_tf32n_f32n_tensor_op_f32_ls_sm80.cu -new file mode 100644 -index 0000000..ebb427b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32t_tf32n_f32n_tensor_op_f32_ls_sm80.cu -@@ -0,0 +1,449 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include <iostream> -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32n_ls_l_nu_tensor_op_f32_align1_align1, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, cutlass::FillMode::kLower, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal<Trmm>()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32n_ls_l_nu_tensor_op_f32_align1_align4, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits<ElementOutput>::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal<Trmm>()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32n_ls_l_nu_tensor_op_f32_align1_align4, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 
8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32n_ls_l_nu_tensor_op_f32_align1_align4, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32n_ls_l_nu_tensor_op_f32_align1_align4, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32n_ls_u_nu_tensor_op_f32_align1_align4, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 
-+ 10, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32n_ls_u_nu_tensor_op_f32_align1_align4, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32n_ls_u_nu_tensor_op_f32_align1_align4, 256x128x16_128x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32n_ls_u_nu_tensor_op_f32_align1_align4, 128x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ 
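Each of these TRMM configurations uses the same 128-bit vectorized LinearCombination epilogue; its element count is derived from the output type. A minimal sketch of the intended instantiation, assuming ElementOutput = float and ElementAccumulator = float as declared in the tests above (the angle-bracketed arguments are spelled out here since they are the standard CUTLASS pattern, not something new to this patch):

    // 128 / cutlass::sizeof_bits<float>::value == 128 / 32 == 4 output elements per store.
    using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
        ElementOutput,                                     // output element (float)
        128 / cutlass::sizeof_bits<ElementOutput>::value,  // vector width: 4 for float
        ElementAccumulator,                                 // accumulator element (float)
        ElementAccumulator>;                                // scalar type for alpha/beta

    // The trailing integer arguments after the threadblock swizzle (e.g. 3, 1, 4) are the
    // pipeline stage count and the A/B operand alignments named in the align1_align4 suffix,
    // and the testbed call takes the assembled kernel type as its template argument:
    //   EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal<Trmm>());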
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32n_ls_u_nu_tensor_op_f32_align1_align4, 256x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32t_tf32n_f32t_tensor_op_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32t_tf32n_f32t_tensor_op_f32_ls_sm80.cu -new file mode 100644 -index 0000000..ba3f7f3 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32t_tf32n_f32t_tensor_op_f32_ls_sm80.cu -@@ -0,0 +1,458 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Trmm_{ElementA}{LayoutA}_{ElementB}{LayoutB}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _{DiagType}_tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_l_un_tensor_op_f32_align1_align1, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_l_un_tensor_op_f32_align1_align4, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_l_un_tensor_op_f32_align1_align4, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ 
cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_l_un_tensor_op_f32_align1_align4, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_l_nu_tensor_op_f32_align1_align4, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_u_nu_tensor_op_f32_align1_align4, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, 
-+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_u_un_tensor_op_f32_align1_align4, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_u_nu_tensor_op_f32_align1_align4, 256x128x16_128x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_u_nu_tensor_op_f32_align1_align4, 128x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / 
cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_u_nu_tensor_op_f32_align1_align4, 256x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/kernel/batched_gemv.cu b/3rdparty/cutlass/test/unit/gemm/kernel/batched_gemv.cu -new file mode 100755 -index 0000000..4e06485 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/kernel/batched_gemv.cu -@@ -0,0 +1,1082 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include "testbed_gemv.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_rcr_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_rcr_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_rcr_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_rcr_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x27x4096_1x8x1x64_1x1x1x64_rcr_alpha_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 27, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 1>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size, -0.5f); -+} 
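Each batched GEMV test name packs the whole configuration into its suffix. A decoding of one case above, inferred from the shapes and types used in the test bodies:

    // SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_rcr_fp32_fp32
    //   1x256x256x64 : problem size M x N x K x batch  -> BatchedGemmCoord(1, 256, 256, 64)
    //   1x64x4x8     : threadblock tile M x N x K plus the batch tile -> kBatchTileSize = 8
    //   1x4x4x8      : per-thread tile M x N x K plus the same batch tile
    //   rcr          : operand layouts in A/B/C order (r = RowMajor, c = ColumnMajor)
    //   fp32_fp32    : A/B element type and C/D element type
    cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64);
    using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>;
    using ThreadShape      = cutlass::gemm::GemmShape<1, 4, 4>;
    static int const kBatchTileSize = 8;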
-+ -+TEST(SM50_batched_gemv, 1x64x27x4096_1x8x1x64_1x1x1x64_rcr_alpha_beta_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 27, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 1>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size, 4.5f, -0.5f); -+} -+ -+TEST(SM50_batched_gemv, 1x64x24x4096_1x8x4x64_1x1x4x64_rcr_alpha_beta_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 24, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size, cutlass::half_t(4.5f), cutlass::half_t(-0.5f)); -+} -+ -+/// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_rcr_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_rcr_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_rcr_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_rcr_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+/// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_rcr_fp16_fp16) -+{ -+ 
cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_rcr_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_rcr_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_rcr_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+/// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_rcr_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_rcr_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_rcr_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = 
cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_rcr_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+///////////// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_crc_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_crc_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_crc_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_crc_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+/// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_crc_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ 
cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_crc_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_crc_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_crc_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+/// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_crc_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_crc_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_crc_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ 
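The same grid of problem sizes is repeated for each layout combination and element-type triplet. A hedged annotation of one of the fp16_fp16 cases above, reading the three element types against the parameter names used in testbed_gemv.h later in this patch (the middle type parameter is simply labeled "intermediate" here, since its role is not spelled out in the tests themselves):

    test::gemm::kernel::batched_gemv_kernel_test<
        ThreadBlockShape, ThreadShape,
        cutlass::half_t,                // ElementAB: A and B stored as fp16
        float,                          // intermediate element type parameter
        cutlass::half_t,                // ElementCD: C/D and the alpha/beta scalars are fp16
        cutlass::layout::ColumnMajor,   // LayoutA  (the first "c" in crc)
        cutlass::layout::RowMajor,      // LayoutB  (the "r")
        cutlass::layout::ColumnMajor,   // LayoutCD (the final "c")
        kBatchTileSize>(problem_size);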
-+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_crc_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+/// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_crc_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_crc_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_crc_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_crc_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x27x4096_1x8x1x64_1x1x1x64_crc_alpha_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 27, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 1>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size, -0.5f); -+} -+ -+TEST(SM50_batched_gemv, 1x64x27x4096_1x8x1x64_1x1x1x64_crc_alpha_beta_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 27, 4096); -+ -+ using 
ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 1>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size, 4.5f, -0.5f); -+} -+ -+TEST(SM50_batched_gemv, 1x64x24x4096_1x8x4x64_1x1x4x64_crc_alpha_beta_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 24, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size, cutlass::half_t(4.5f), cutlass::half_t(-0.5f)); -+} -+ -+///////////// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_rcc_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_rcc_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_rcc_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_rcc_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+/// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_rcc_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ 
static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_rcc_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_rcc_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_rcc_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+/// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_rcc_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_rcc_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_rcc_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, 
float, cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_rcc_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+/// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_rcc_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_rcc_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_rcc_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_rcc_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x27x4096_1x8x1x64_1x1x1x64_rcc_alpha_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 27, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 1>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size, -0.5f); -+} -+ 
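The _alpha and _alpha_beta variants pass non-default scalars (here alpha = -0.5, or alpha = 4.5 with beta = -0.5) so the scaled epilogue path is exercised as well. Per batch, the testbed's reference check amounts to the following scalar sketch, with A, B, C, D as hypothetical element accessors over batch b and output element (m, n):

    // D(b, m, n) = alpha * sum_k A(b, m, k) * B(b, k, n) + beta * C(b, m, n)
    // beta defaults to 0 in the testbed, so the plain and _alpha tests reduce to D = alpha * A * B.
    float acc = 0.0f;
    for (int k = 0; k < K; ++k) {                     // K: reduction extent of the problem
      acc += float(A(b, m, k)) * float(B(b, k, n));
    }
    D(b, m, n) = alpha * acc + beta * C(b, m, n);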
-+TEST(SM50_batched_gemv, 1x64x27x4096_1x8x1x64_1x1x1x64_rcc_alpha_beta_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 27, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 1>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size, 4.5f, -0.5f); -+} -+ -+TEST(SM50_batched_gemv, 1x64x24x4096_1x8x4x64_1x1x4x64_rcc_alpha_beta_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 24, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size, cutlass::half_t(4.5f), cutlass::half_t(-0.5f)); -+} -diff --git a/3rdparty/cutlass/test/unit/gemm/kernel/testbed_gemv.h b/3rdparty/cutlass/test/unit/gemm/kernel/testbed_gemv.h -new file mode 100755 -index 0000000..dc551ef ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/kernel/testbed_gemv.h -@@ -0,0 +1,358 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/tensor_ref.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "cutlass/gemm/kernel/default_gemv.h" -+#include "cutlass/gemm/kernel/gemv_batched_strided.h" -+ -+namespace test { -+namespace gemm { -+namespace kernel { -+ -+template -+void batched_gemv_kernel_test(cutlass::gemm::BatchedGemmCoord problem_size, -+ ElementCD_ alpha = ElementCD_(1), -+ ElementCD_ beta = ElementCD_(0), -+ bool perf_test = false, -+ int perf_test_iter = 1) -+{ -+ using ThreadBlockShape = ThreadBlockShape_; -+ using ThreadShape = ThreadShape_; -+ using ElementA = ElementAB_; -+ using LayoutA = LayoutA_; -+ using ElementB = ElementAB_; -+ using LayoutB = LayoutB_; -+ using ElementAccumulator = ElementCD_; -+ using ElementCD = ElementCD_; -+ using LayoutCD = LayoutCD_; -+ -+ using GemvKernel = cutlass::gemm::kernel::DefaultGemv; -+ -+ using ThreadBlockGemv = typename GemvKernel::ThreadBlockGemv; -+ using ThreadBlockSwizzle = typename GemvKernel::ThreadBlockSwizzle; -+ -+ if (DEBUG) -+ { -+ problem_size = cutlass::gemm::BatchedGemmCoord( -+ problem_size.m(), problem_size.n(), problem_size.k(), 1); -+ } -+ -+ // Create host tensors that will be the backing store for the batches -+ // Note that no device memory is initially allocated -+ cutlass::HostTensor matrix_A({problem_size.m(), problem_size.k()}, false); -+ cutlass::HostTensor matrix_B({problem_size.k(), problem_size.n()}, false); -+ cutlass::HostTensor matrix_C_computed({problem_size.m(), problem_size.n()}, false); -+ cutlass::HostTensor matrix_C_reference({problem_size.m(), problem_size.n()}, false); -+ -+ // Reserve memory for the batch of tensors -+ matrix_A.reserve(problem_size.m()*problem_size.k()*problem_size.batch()); -+ matrix_B.reserve(problem_size.n()*problem_size.k()*problem_size.batch()); -+ matrix_C_computed.reserve(problem_size.m()*problem_size.n()*problem_size.batch()); -+ matrix_C_reference.reserve(problem_size.m()*problem_size.n()*problem_size.batch(), false); -+ -+ // Fill eatch tensor batch -+ const int seed = 9876; -+ for (int b = 0; b < problem_size.batch(); b++) -+ { -+ if(DEBUG) -+ { -+ cutlass::reference::host::BlockFillSequential( -+ matrix_A.host_data_ptr_offset(b*matrix_A.capacity()), matrix_A.capacity()); -+ cutlass::reference::host::BlockFillSequential( -+ matrix_B.host_data_ptr_offset(b*matrix_B.capacity()), matrix_B.capacity()); -+ } -+ else -+ { -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_A.host_view(b*matrix_A.capacity()), -+ seed + 1660, -+ 8, -+ -8, -+ 0 -+ ); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_B.host_view(b*matrix_B.capacity()), -+ seed + 1880, -+ 8, -+ -8, -+ 0 -+ ); -+ } -+ -+ cutlass::reference::host::TensorFill(matrix_C_computed.host_view(b*matrix_C_computed.capacity())); -+ cutlass::reference::host::TensorFill(matrix_C_reference.host_view(b*matrix_C_reference.capacity())); -+ } -+ -+ matrix_A.sync_device(); -+ matrix_B.sync_device(); -+ matrix_C_computed.sync_device(); -+ -+ ThreadBlockSwizzle 
swizzle; -+ -+ cutlass::gemm::BatchedGemmCoord tiled_size{ThreadBlockShape::kM, -+ ThreadBlockShape::kN, -+ problem_size.k(), // no split-k -+ DEBUG ? 1 : THREAD_B }; -+ -+ cutlass::gemm::BatchedGemmCoord tiled_shape = swizzle.get_tiled_shape(problem_size, tiled_size); -+ -+ #if 0 -+ printf("tiled_size = %d %d %d %d\n", tiled_size.m(), tiled_size.n(), tiled_size.k(), tiled_size.batch()); -+ printf("tiled_shape = %d %d %d %d\n", tiled_shape.m(), tiled_shape.n(), tiled_shape.k(), tiled_shape.batch()); -+ #endif -+ -+ // No split-k -+ EXPECT_EQ(tiled_size.k(), problem_size.k()); -+ -+ dim3 grid = swizzle.get_grid_shape(tiled_shape); -+ dim3 block(tiled_size.n() / ThreadShape::kN, tiled_size.batch(), tiled_size.k() / problem_size.k()); -+ -+ // Some sanity checks -+ EXPECT_TRUE( block.x*block.y*block.z <= 1024 ); -+ EXPECT_TRUE( block.x <= 1024 ); -+ EXPECT_TRUE( block.y <= 1024 ); -+ EXPECT_TRUE( block.z <= 64 ); -+ -+ #if 0 -+ printf("grid dim = %d, %d, %d\n", grid.x, grid.y, grid.z); -+ printf("block dim = %d, %d, %d\n", block.x, block.y, block.z); -+ #endif -+ -+ cudaError_t result; -+ cudaEvent_t start_event, end_event; -+ -+ for (int iter = 0; iter < (perf_test ? (perf_test_iter+1) : 1); ++iter) -+ { -+ if (perf_test && iter == 1) -+ { -+ result = cudaEventCreate(&start_event); -+ EXPECT_EQ(result, cudaSuccess); -+ -+ result = cudaEventCreate(&end_event); -+ EXPECT_EQ(result, cudaSuccess); -+ -+ result = cudaEventRecord(start_event); -+ EXPECT_EQ(result, cudaSuccess); -+ } -+ -+ if (beta == ElementCD(0)) -+ { -+ if (alpha == ElementCD(1)) -+ { -+ cutlass::gemm::kernel::GemvBatchedStrided<<< grid, block >>>( -+ problem_size, -+ matrix_A.device_ref(), -+ matrix_A.capacity(), -+ matrix_B.device_ref(), -+ matrix_B.capacity(), -+ matrix_C_computed.device_ref(), -+ matrix_C_computed.capacity() -+ ); -+ } -+ else -+ { -+ cutlass::gemm::kernel::GemvBatchedStrided<<< grid, block >>>( -+ problem_size, -+ alpha, -+ matrix_A.device_ref(), -+ matrix_A.capacity(), -+ matrix_B.device_ref(), -+ matrix_B.capacity(), -+ matrix_C_computed.device_ref(), -+ matrix_C_computed.capacity() -+ ); -+ } -+ } -+ else -+ { -+ cutlass::gemm::kernel::GemvBatchedStrided<<< grid, block >>>( -+ problem_size, -+ alpha, -+ beta, -+ matrix_A.device_ref(), -+ matrix_A.capacity(), -+ matrix_B.device_ref(), -+ matrix_B.capacity(), -+ matrix_C_computed.device_ref(), -+ matrix_C_computed.capacity(), -+ matrix_C_computed.device_ref(), -+ matrix_C_computed.capacity() -+ ); -+ } -+ -+ if (iter == 0) -+ { -+ result = cudaGetLastError(); -+ EXPECT_EQ(result, cudaSuccess) << " kernel error: " << cudaGetErrorString(result); -+ } -+ } -+ -+ if (perf_test) -+ { -+ result = cudaEventRecord(end_event); -+ EXPECT_EQ(result, cudaSuccess); -+ } -+ -+ result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) << " kernel error: " << cudaGetErrorString(result); -+ -+ if (perf_test) -+ { -+ float ms; -+ result = cudaEventElapsedTime(&ms, start_event, end_event); -+ EXPECT_EQ(result, cudaSuccess); -+ -+ double flops = (double(problem_size.m()) * -+ double(problem_size.n()) * -+ double(problem_size.k()) * -+ double(problem_size.batch()) * 2); // 2 for MAC -+ -+ double read_bytes = double(problem_size.batch()) * (sizeof(ElementA)*double(problem_size.m())*double(problem_size.k()) + -+ sizeof(ElementB)*double(problem_size.k())*double(problem_size.n())); -+ -+ double write_bytes = double(problem_size.batch()) * (sizeof(ElementCD)*double(problem_size.m())*double(problem_size.n())); -+ -+ double avg_runtime = double(ms) / perf_test_iter; -+ 
double gflops_per_sec = flops / 1.0e6 / avg_runtime; -+ double read_bandwidth = read_bytes / 1.0e6 / avg_runtime; -+ double write_bandwidth = write_bytes / 1.0e6 / avg_runtime; -+ -+ std::cout << "\n\nProblem size: " -+ << problem_size.m() -+ << " x " << problem_size.n() -+ << " x " << problem_size.k() -+ << " x " << problem_size.batch() -+ << std::endl; -+ -+ std::cout << " GFLOPs: " << gflops_per_sec << std::endl; -+ std::cout << "BW (R/W): " << read_bandwidth << " / " << write_bandwidth << " GB/sec" << std::endl; -+ std::cout << " Runtime: " << avg_runtime << " ms" << std::endl; -+ } -+ else -+ { -+ matrix_C_computed.sync_host(); -+ -+ // Compute the batched gemms -+ for (int b = 0; b < problem_size.batch(); b++) -+ { -+ cutlass::reference::host::Gemm -+ reference_gemm; -+ -+ reference_gemm( -+ problem_size.mnk(), alpha, -+ matrix_A.host_ref(b * matrix_A.capacity()), -+ matrix_B.host_ref(b * matrix_B.capacity()), beta, -+ matrix_C_reference.host_ref(b * matrix_C_computed.capacity())); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ matrix_C_computed.host_view(b * matrix_C_computed.capacity()), -+ matrix_C_reference.host_view(b * matrix_C_reference.capacity())); -+ -+ EXPECT_TRUE(passed) -+ //<< "A:\n" << matrix_A.host_view() << "\n" -+ //<< "B:\n" << matrix_B.host_view() << "\n" -+ << "Batch: " << b << "\n" -+ << "Reference:\n" -+ << matrix_C_reference.host_view(b * matrix_C_reference.capacity()) -+ << "\n" -+ << "Computed:\n" -+ << matrix_C_computed.host_view(b * matrix_C_computed.capacity()) -+ << "\n"; -+ } -+ } -+} -+ -+template -+void batched_gemv_kernel_perf_test(cutlass::gemm::BatchedGemmCoord problem_size, -+ ElementCD_ alpha = ElementCD_(1), -+ ElementCD_ beta = ElementCD_(0), -+ int iter = 50) -+{ -+ batched_gemv_kernel_test(problem_size, alpha, beta, true, iter); -+} -+ -+} // namespace threadblock -+} // namespace kernel -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/gemm/thread/gemm_sm50.cu b/3rdparty/cutlass/test/unit/gemm/thread/gemm_sm50.cu -new file mode 100644 -index 0000000..1ac6ea5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/thread/gemm_sm50.cu -@@ -0,0 +1,175 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/gemm/thread/mma.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Sgemm_thread, col_row_3x4x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<3, 4, 2>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, col_row_4x4x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<4, 4, 2>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, row_col_4x4x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<4, 4, 2>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, col_row_4x5x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<4, 5, 3>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, col_row) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, row_col) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, col_col) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, row_row) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Dgemm_thread, col_row) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Dgemm_thread, row_col) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ 
cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/thread/gemm_sm60.cu b/3rdparty/cutlass/test/unit/gemm/thread/gemm_sm60.cu -new file mode 100644 -index 0000000..23099b2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/thread/gemm_sm60.cu -@@ -0,0 +1,499 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/gemm/thread/mma.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Compute capability SM60 -+// -+ -+TEST(SM60_Hgemm_thread, col_row_col_1x1x16) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<1, 1, 16>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_col_row_1x1x16) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<1, 1, 16>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_col_col_1x3x8) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<1, 3, 8>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_row_row_7x8x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<7, 8, 3>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_col_row_7x8x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<7, 8, 3>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_row_row_7x8x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<7, 8, 3>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_col_row_7x8x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<7, 8, 3>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_row_row_7x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<7, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_col_row_7x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<7, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_row_row_7x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<7, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_col_row_7x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<7, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ 
-+
-+TEST(SM60_Hgemm_thread, row_row_col_16x3x3) {
-+
-+  test::gemm::thread::Testbed<
-+    cutlass::gemm::GemmShape<16, 3, 3>,
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor,
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor
-+  >().run();
-+}
-+
-+TEST(SM60_Hgemm_thread, row_col_col_16x3x3) {
-+
-+  test::gemm::thread::Testbed<
-+    cutlass::gemm::GemmShape<16, 3, 3>,
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor
-+  >().run();
-+}
-+
-+TEST(SM60_Hgemm_thread, col_row_col_16x3x3) {
-+
-+  test::gemm::thread::Testbed<
-+    cutlass::gemm::GemmShape<16, 3, 3>,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor,
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor
-+  >().run();
-+}
-+
-+TEST(SM60_Hgemm_thread, col_col_col_16x3x3) {
-+
-+  test::gemm::thread::Testbed<
-+    cutlass::gemm::GemmShape<16, 3, 3>,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor
-+  >().run();
-+}
-+
-+TEST(SM60_Hgemm_thread, row_row_col_16x3x4) {
-+
-+  test::gemm::thread::Testbed<
-+    cutlass::gemm::GemmShape<16, 3, 4>,
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor,
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor
-+  >().run();
-+}
-+
-+TEST(SM60_Hgemm_thread, row_col_col_16x3x4) {
-+
-+  test::gemm::thread::Testbed<
-+    cutlass::gemm::GemmShape<16, 3, 4>,
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor
-+  >().run();
-+}
-+
-+TEST(SM60_Hgemm_thread, col_row_col_16x3x4) {
-+
-+  test::gemm::thread::Testbed<
-+    cutlass::gemm::GemmShape<16, 3, 4>,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor,
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor
-+  >().run();
-+}
-+
-+TEST(SM60_Hgemm_thread, col_col_col_16x3x4) {
-+
-+  test::gemm::thread::Testbed<
-+    cutlass::gemm::GemmShape<16, 3, 4>,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor
-+  >().run();
-+}
-+
-+TEST(SM60_Hgemm_thread, row_row_row_16x8x3) {
-+
-+  test::gemm::thread::Testbed<
-+    cutlass::gemm::GemmShape<16, 8, 3>,
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor,
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor,
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor
-+  >().run();
-+}
-+
-+TEST(SM60_Hgemm_thread, row_row_col_16x8x3) {
-+
-+  test::gemm::thread::Testbed<
-+    cutlass::gemm::GemmShape<16, 8, 3>,
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor,
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor
-+  >().run();
-+}
-+
-+TEST(SM60_Hgemm_thread, row_col_row_16x8x3) {
-+
-+  test::gemm::thread::Testbed<
-+    cutlass::gemm::GemmShape<16, 8, 3>,
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor,
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor
-+  >().run();
-+}
-+
-+TEST(SM60_Hgemm_thread, row_col_col_16x8x3) {
-+
-+  test::gemm::thread::Testbed<
-+    cutlass::gemm::GemmShape<16, 8, 3>,
-+    cutlass::half_t,
-+    cutlass::layout::RowMajor,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor,
-+    cutlass::half_t,
-+    cutlass::layout::ColumnMajor
-+  >().run();
-+}
-+ -+TEST(SM60_Hgemm_thread, col_row_row_16x8x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 3>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_row_col_16x8x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 3>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_col_row_16x8x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 3>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_col_col_16x8x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 3>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_row_row_16x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_row_col_16x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_col_row_16x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_col_col_16x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_row_row_16x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_row_col_16x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_col_row_16x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_col_col_16x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ 
>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/thread/gemm_sm61.cu b/3rdparty/cutlass/test/unit/gemm/thread/gemm_sm61.cu -new file mode 100644 -index 0000000..68f9110 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/thread/gemm_sm61.cu -@@ -0,0 +1,87 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/gemm/thread/mma.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Compute capability SM61 -+// -+ -+TEST(SM61_Igemm_thread, col_row_1x1x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ int32_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM61_Igemm_thread, col_row_2x3x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<2, 3, 4>, -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ int32_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM61_Igemm_thread, col_row_8x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ int32_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/thread/host/gemm_sm60_host.cu b/3rdparty/cutlass/test/unit/gemm/thread/host/gemm_sm60_host.cu -new file mode 100644 -index 0000000..5b1b5da ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/thread/host/gemm_sm60_host.cu -@@ -0,0 +1,176 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../../common/cutlass_unit_test.h" -+ -+#include "cutlass/gemm/thread/mma.h" -+ -+#include "testbed_host.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Compute capability SM60 -+// -+ -+TEST(SM60_host_Hgemm_thread, col_row_col_1x1x16) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<1, 1, 16>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_host_Hgemm_thread, row_col_row_1x1x16) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<1, 1, 16>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_host_Hgemm_thread, row_row_row_2x2x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<2, 2, 2>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_host_Hgemm_thread, row_row_col_2x2x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<2, 2, 2>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_host_Hgemm_thread, row_col_row_2x2x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<2, 2, 2>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_host_Hgemm_thread, row_col_col_2x2x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<2, 2, 2>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_host_Hgemm_thread, col_row_row_2x2x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<2, 2, 2>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_host_Hgemm_thread, col_row_col_2x2x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<2, 2, 2>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_host_Hgemm_thread, col_col_row_2x2x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<2, 2, 2>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_host_Hgemm_thread, col_col_col_2x2x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<2, 2, 2>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/thread/host/testbed_host.h 
b/3rdparty/cutlass/test/unit/gemm/thread/host/testbed_host.h -new file mode 100644 -index 0000000..bd78947 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/thread/host/testbed_host.h -@@ -0,0 +1,232 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#pragma once -+ -+#include "cutlass/gemm/thread/mma.h" -+#include "cutlass/layout/vector.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+namespace test { -+namespace gemm { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Thread-level matrix multiply-accumulate -+template -+void kernel( -+ typename Mma::ElementC *D, -+ typename Mma::ElementA const *A, -+ typename Mma::ElementB const *B, -+ typename Mma::ElementC const *C) { -+ -+ auto ptr_D = reinterpret_cast *>(D); -+ auto ptr_A = reinterpret_cast const *>(A); -+ auto ptr_B = reinterpret_cast const *>(B); -+ auto ptr_C = reinterpret_cast const *>(C); -+ -+ Mma mma; -+ -+ auto a = *ptr_A; -+ auto b = *ptr_B; -+ auto c = *ptr_C; -+ -+ using Btype = typename Mma::ElementB; -+ cutlass::Array d; -+ -+ mma(d, a, b, c); -+ -+ *ptr_D = d; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape, -+ /// Data type of A elements -+ typename ElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Element type of C matrix -+ typename ElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC -+> -+struct Testbed { -+ -+ /// Thread-level matrix multiply-accumulate operator -+ using Mma = cutlass::gemm::thread::Mma< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC -+ >; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ Testbed() { -+ -+ tensor_A.reset(cutlass::make_Coord(Shape::kM, Shape::kK), false); -+ tensor_B.reset(cutlass::make_Coord(Shape::kK, Shape::kN), false); -+ tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); -+ tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); -+ tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); -+ } -+ -+ /// Runs the test -+ bool run() { -+ -+ // -+ // initialize device memory -+ // -+ -+ cutlass::reference::host::detail::RandomUniformFunc< ElementA > tfill_rand_func( -+ 0, // seed -+ 10, // max -+ 0, // min -+ 0); // bits after decimal -+ -+ cutlass::reference::host::detail::TensorFillRandomUniformFunc< ElementA, LayoutA > tfill_rand( -+ tensor_A.host_view(), -+ tfill_rand_func); -+ -+ for (auto i=0; i< Shape::kM; i++) -+ for (auto j=0; j< Shape::kK; j++) -+ tfill_rand(cutlass::make_Coord(i,j)); -+ -+ cutlass::reference::host::BlockFillSequential( -+ tensor_B.host_data(), -+ tensor_B.capacity(), -+ ElementB(1), -+ ElementB(2) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_C.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_computed.host_view(), -+ ElementC(0) -+ 
); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_reference.host_view(), -+ ElementC(0) -+ ); -+ -+ -+ // Host side call -+ kernel( -+ tensor_D_computed.host_data(), -+ tensor_A.host_data(), -+ tensor_B.host_data(), -+ tensor_C.host_data()); -+ -+ // -+ // Reference implementation -+ // -+ -+ cutlass::reference::host::Gemm -+ reference_gemm; -+ -+ reference_gemm( -+ {Shape::kM, Shape::kN, Shape::kK}, -+ ElementC(1), -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ ElementC(0), -+ tensor_D_reference.host_ref() -+ ); -+ -+ // -+ // Verify equivalence -+ // -+ -+ // compare -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view() -+ ); -+ -+ EXPECT_TRUE(passed) -+ << "A:\n" << tensor_A.host_view() << "\n\n" -+ << "B:\n" << tensor_B.host_view() << "\n\n" -+ << "C:\n" << tensor_C.host_view() << "\n\n" -+ << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" -+ << "Computed:\n" << tensor_D_computed.host_view() << std::endl; -+ -+ -+ return passed; -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace gemm -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/gemm/thread/testbed.h b/3rdparty/cutlass/test/unit/gemm/thread/testbed.h -new file mode 100644 -index 0000000..c5ad60f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/thread/testbed.h -@@ -0,0 +1,236 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#pragma once -+ -+#include "cutlass/gemm/thread/mma.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+namespace test { -+namespace gemm { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Thread-level matrix multiply-accumulate -+template -+__global__ void kernel( -+ typename Mma::ElementC *D, -+ typename Mma::ElementA const *A, -+ typename Mma::ElementB const *B, -+ typename Mma::ElementC const *C) { -+ -+ auto ptr_D = reinterpret_cast *>(D); -+ auto ptr_A = reinterpret_cast const *>(A); -+ auto ptr_B = reinterpret_cast const *>(B); -+ auto ptr_C = reinterpret_cast const *>(C); -+ -+ Mma mma; -+ -+ auto a = *ptr_A; -+ auto b = *ptr_B; -+ auto c = *ptr_C; -+ -+ cutlass::Array d; -+ -+ mma(d, a, b, c); -+ -+ *ptr_D = d; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape, -+ /// Data type of A elements -+ typename ElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Element type of C matrix -+ typename ElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC -+> -+struct Testbed { -+ -+ /// Thread-level matrix multiply-accumulate operator -+ using Mma = cutlass::gemm::thread::Mma< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC -+ >; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ Testbed() { -+ -+ tensor_A.reset(cutlass::make_Coord(Shape::kM, Shape::kK)); -+ tensor_B.reset(cutlass::make_Coord(Shape::kK, Shape::kN)); -+ tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); -+ } -+ -+ /// Runs the test -+ bool run() { -+ -+ // -+ // initialize device memory -+ // -+ -+ cutlass::reference::host::BlockFillSequential( -+ tensor_A.host_data(), -+ tensor_A.capacity() -+ ); -+ -+ cutlass::reference::host::BlockFillSequential( -+ tensor_B.host_data(), -+ tensor_B.capacity(), -+ ElementB(1), -+ ElementB(2) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_C.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_computed.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_reference.host_view(), -+ ElementC(0) -+ ); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ -+ // launch kernel -+ kernel<<< dim3(1, 1), dim3(1, 1, 1) >>>( -+ tensor_D_computed.device_data(), -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_C.device_data()); -+ 
-+ // verify no errors -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); -+ if (result != cudaSuccess) { -+ return false; -+ } -+ -+ tensor_D_computed.sync_host(); -+ -+ // -+ // Reference implementation -+ // -+ -+ //tensor_D_reference.fill(tensor_C.host_view()); -+ -+ cutlass::reference::host::Gemm -+ reference_gemm; -+ -+ reference_gemm( -+ {Shape::kM, Shape::kN, Shape::kK}, -+ ElementC(1), -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ ElementC(0), -+ tensor_D_reference.host_ref() -+ ); -+ -+ // -+ // Verify equivalence -+ // -+ -+ // compare -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view() -+ ); -+ -+ EXPECT_TRUE(passed) -+ << "A:\n" << tensor_A.host_view() << "\n\n" -+ << "B:\n" << tensor_B.host_view() << "\n\n" -+ << "C:\n" << tensor_C.host_view() << "\n\n" -+ << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" -+ << "Computed:\n" << tensor_D_computed.host_view() << std::endl; -+ -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace gemm -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/batched_gemv.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/batched_gemv.cu -new file mode 100644 -index 0000000..28b49f4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/batched_gemv.cu -@@ -0,0 +1,646 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Unit tests for threadblock level GEMV -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/tensor_ref.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "cutlass/gemm/threadblock/gemv.h" -+#include "cutlass/gemm/threadblock/default_gemv_core.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void batched_gemv_threadblock_test_kernel( -+ cutlass::gemm::GemmCoord problem_size, -+ LongIndex stride_a, -+ LongIndex stride_b, -+ LongIndex stride_c, -+ RefA ref_A, -+ RefB ref_B, -+ RefC ref_C -+ ) { -+ -+ typename Gemv::IteratorA::TensorCoord threadblock_offset_A(0, 0); -+ typename Gemv::IteratorB::TensorCoord threadblock_offset_B(0, 0); -+ typename Gemv::IteratorB::TensorCoord threadblock_offset_C(0, 0); -+ -+ // Move to the right batches for these threads -+ ref_A.add_pointer_offset(threadIdx.y * stride_a); -+ ref_B.add_pointer_offset(threadIdx.y * stride_b); -+ ref_C.add_pointer_offset(threadIdx.y * stride_c); -+ -+ // Construct iterators to A and B operands -+ typename Gemv::IteratorA::Params params_A(ref_A.layout()); -+ typename Gemv::IteratorA iterator_A(params_A, ref_A.data(), { problem_size.m(), problem_size.k() }, 0, threadblock_offset_A); -+ typename Gemv::IteratorB::Params params_B(ref_B.layout()); -+ typename Gemv::IteratorB iterator_B(params_B, ref_B.data(), { problem_size.k(), problem_size.n() }, threadIdx.x, threadblock_offset_B); -+ -+ Gemv gemv; -+ -+ typename Gemv::FragmentC accum; -+ accum.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ gemv(problem_size, accum, iterator_A, iterator_B, accum); -+ -+ // IteratorC is PitchLinear<> assumes n() contiguous -+ typename Gemv::IteratorC::Params params_C(ref_C.layout()); -+ typename Gemv::IteratorC iterator_C(params_C, ref_C.data(), { problem_size.m(), problem_size.n() }, threadIdx.x, threadblock_offset_C); -+ iterator_C.store(accum); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+void batched_gemv_threadblock_test(cutlass::gemm::GemmCoord problem_size, int num_batch) -+{ -+ using Shape = Shape_; -+ using ElementA = ElementAB_; -+ using LayoutA = LayoutA_; -+ using ElementB = ElementAB_; -+ using LayoutB = LayoutB_; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using ThreadShape = cutlass::gemm::GemmShape<1, THREAD_N, THREAD_K>; -+ -+ using Core = typename cutlass::gemm::threadblock::DefaultGemvCore< -+ Shape, -+ ThreadShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC -+ >; -+ -+ if (DEBUG) -+ { -+ num_batch = 1; -+ } -+ -+ using Mma = cutlass::gemm::threadblock::Gemv; -+ -+ // Create host tensors that will be the backing store for the batches -+ // Note that no device memory is initially allocated -+ cutlass::HostTensor matrix_A({problem_size.m(), problem_size.k()}, false); -+ cutlass::HostTensor matrix_B({problem_size.k(), 
problem_size.n()}, false);
-+  cutlass::HostTensor matrix_C_computed({problem_size.m(), problem_size.n()}, false);
-+  cutlass::HostTensor matrix_C_reference({problem_size.m(), problem_size.n()}, false);
-+
-+  // Reserve memory for the batch of tensors
-+  matrix_A.reserve(problem_size.m()*problem_size.k()*num_batch);
-+  matrix_B.reserve(problem_size.n()*problem_size.k()*num_batch);
-+  matrix_C_computed.reserve(problem_size.m()*problem_size.n()*num_batch);
-+  matrix_C_reference.reserve(problem_size.m()*problem_size.n()*num_batch, false);
-+
-+  // Fill each tensor batch
-+  const int seed = 6834;
-+  for (int b = 0; b < num_batch; b++)
-+  {
-+    if(DEBUG)
-+    {
-+      cutlass::reference::host::BlockFillSequential(
-+        matrix_A.host_data_ptr_offset(b*matrix_A.capacity()), matrix_A.capacity());
-+      cutlass::reference::host::BlockFillSequential(
-+        matrix_B.host_data_ptr_offset(b*matrix_B.capacity()), matrix_B.capacity());
-+    }
-+    else
-+    {
-+      cutlass::reference::host::TensorFillRandomUniform(
-+        matrix_A.host_view(b*matrix_A.capacity()),
-+        seed + 1660,
-+        8,
-+        -8,
-+        0
-+      );
-+
-+      cutlass::reference::host::TensorFillRandomUniform(
-+        matrix_B.host_view(b*matrix_B.capacity()),
-+        seed + 1880,
-+        8,
-+        -8,
-+        0
-+      );
-+    }
-+
-+    cutlass::reference::host::TensorFill(matrix_C_computed.host_view(b*matrix_C_computed.capacity()));
-+    cutlass::reference::host::TensorFill(matrix_C_reference.host_view(b*matrix_C_reference.capacity()));
-+  }
-+
-+  matrix_A.sync_device();
-+  matrix_B.sync_device();
-+  matrix_C_computed.sync_device();
-+
-+  dim3 grid(1, 1); // only 1 CTA is used
-+  dim3 block(Shape::kN / THREAD_N, num_batch, 1);
-+
-+  #if 0
-+  printf("block dim = %d x %d\n", block.x, block.y);
-+  #endif
-+
-+  // Some sanity checks
-+  EXPECT_TRUE( problem_size.n() % THREAD_N == 0 );
-+  EXPECT_TRUE( block.x*block.y <= MAX_THREADS_PER_BLOCK );
-+
-+  test::gemm::threadblock::batched_gemv_threadblock_test_kernel<<< grid, block >>>(
-+    problem_size,
-+    matrix_A.capacity(),
-+    matrix_B.capacity(),
-+    matrix_C_computed.capacity(),
-+    matrix_A.device_ref(),
-+    matrix_B.device_ref(),
-+    matrix_C_computed.device_ref()
-+  );
-+
-+  cudaError_t result = cudaDeviceSynchronize();
-+  EXPECT_EQ(result, cudaSuccess) << " kernel error: " << cudaGetErrorString(result);
-+
-+  matrix_C_computed.sync_host();
-+
-+  // Compute the batched gemms
-+  for (int b = 0; b < num_batch; b++)
-+  {
-+
-+    cutlass::reference::host::Gemm reference_gemm;
-+
-+    reference_gemm(
-+      problem_size.mnk(),
-+      ElementC(1),
-+      matrix_A.host_ref(b*matrix_A.capacity()),
-+      matrix_B.host_ref(b*matrix_B.capacity()),
-+      ElementC(0),
-+      matrix_C_reference.host_ref(b*matrix_C_computed.capacity())
-+    );
-+
-+    bool passed = cutlass::reference::host::TensorEquals(
-+      matrix_C_computed.host_view(b*matrix_C_computed.capacity()),
-+      matrix_C_reference.host_view(b*matrix_C_reference.capacity()));
-+
-+    EXPECT_TRUE(passed)
-+      //<< "A:\n" << matrix_A.host_view() << "\n"
-+      //<< "B:\n" << matrix_B.host_view() << "\n"
-+      << "Batch: " << b << "\n"
-+      << "Reference:\n" << matrix_C_reference.host_view(b*matrix_C_reference.capacity()) << "\n"
-+      << "Computed:\n" << matrix_C_computed.host_view(b*matrix_C_computed.capacity()) << "\n";
-+  }
-+}
-+
-+} // namespace threadblock
-+} // namespace gemm
-+} // namespace test
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+// A: ColumnMajor
-+// B: RowMajor
-+// C: ColumnMajor
-+
-+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_crc_fp32_fp32_2N_2K) {
-+
-+  using namespace
test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 2; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 64, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 5x1x128x128_crc_fp32_fp32_4N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 128, 128); -+ const int num_batch = 5; -+ const int THREAD_N = 4; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 128, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 16x1x17x64_crc_fp32_fp32_1N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 17, 64); -+ const int num_batch = 16; -+ const int THREAD_N = 1; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 32, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_crc_fp16_fp32_2N_2K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 2; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 64, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_crc_fp16_fp32_2N_8K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 8; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 64, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 16x1x17x64_crc_fp16_fp32_1N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 17, 64); -+ const int num_batch = 16; -+ const int THREAD_N = 1; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 32, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_crc_i8_i32_2N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 128, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 16x1x17x64_crc_i8_i32_1N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 17, 64); -+ const int num_batch = 16; -+ const int THREAD_N = 1; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 32, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+// A: RowMajor -+// B: ColumnMajor -+// C: RowMajor -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcr_fp32_fp32_2N_2K) { -+ -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 2; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 64, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 5x1x128x128_rcr_fp32_fp32_4N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 128, 
128); -+ const int num_batch = 5; -+ const int THREAD_N = 4; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 128, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcr_fp32_fp32_1N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 17, 64); -+ const int num_batch = 16; -+ const int THREAD_N = 1; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 32, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcr_fp16_fp32_2N_2K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 2; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 64, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcr_fp16_fp32_2N_8K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 8; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 64, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcr_fp16_fp32_1N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 17, 64); -+ const int num_batch = 16; -+ const int THREAD_N = 1; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 32, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcr_i8_i32_2N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 128, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcr_i8_i32_1N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 17, 64); -+ const int num_batch = 16; -+ const int THREAD_N = 1; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 32, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+// A: RowMajor -+// B: ColumnMajor -+// C: ColumnMajor -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcc_fp32_fp32_2N_2K) { -+ -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 2; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 64, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 5x1x128x128_rcc_fp32_fp32_4N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 128, 128); -+ const int num_batch = 5; -+ const int THREAD_N = 4; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 128, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcc_fp32_fp32_1N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 17, 64); -+ const int num_batch = 16; -+ const int THREAD_N = 1; -+ const 
int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 32, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcc_fp16_fp32_2N_2K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 2; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 64, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcc_fp16_fp32_2N_8K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 8; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 64, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcc_fp16_fp32_1N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 17, 64); -+ const int num_batch = 16; -+ const int THREAD_N = 1; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 32, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcc_i8_i32_2N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 128, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcc_i8_i32_1N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 17, 64); -+ const int num_batch = 16; -+ const int THREAD_N = 1; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 32, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/epilogue_workspace.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/epilogue_workspace.cu -new file mode 100644 -index 0000000..7e08723 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/epilogue_workspace.cu -@@ -0,0 +1,130 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/epilogue/epilogue_workspace.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Kernel computes accumulator data and stores it out -+template -+__global__ void kernel_epilogue_workspace(typename Epilogue::Params params) { -+ -+ __shared__ typename Epilogue::SharedStorage shared_storage; -+ -+ int warp_id = threadIdx.y; -+ int lane_id = threadIdx.x; -+ -+ Epilogue epilogue(params, shared_storage, warp_id, lane_id); -+ -+ // -+ // Initialize accumulator tile -+ // -+ typename Epilogue::FragmentC accum; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Epilogue::FragmentC::kElements; ++i) { -+ accum[i] = Element(warp_id * blockDim.x + lane_id); -+ } -+ -+ // -+ // Efficient epilogue -+ // -+ -+ cutlass::GemmCoord tb_tile_coord{blockIdx.x, blockIdx.y, 0}; -+ -+ cutlass::GemmCoord problem_size = -+ tb_tile_coord * -+ cutlass::GemmCoord{Epilogue::Shape::kM, Epilogue::Shape::kN, 1}; -+ -+ // Store accumulators -+ epilogue( -+ problem_size, -+ tb_tile_coord, -+ accum); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_epilogue_workspace, tensor_op_128x128_64x64) { -+ -+ // -+ // Define an instance of the epilogue and see if it works -+ // -+ static int const kWarpCount = 4; -+ static int const kWarpSize = 32; -+ -+ using Shape = cutlass::MatrixShape<128, 128>; -+ using FragmentC = cutlass::Array; -+ -+ using Epilogue = cutlass::gemm::threadblock::EpilogueWorkspace< -+ Shape, -+ kWarpCount, -+ FragmentC -+ >; -+ -+ typename Epilogue::Params params( -+ -+ ); -+ -+ // Launch the kernel -+ dim3 grid(1,1); -+ dim3 block(kWarpSize, kWarpCount); -+ -+ test::gemm::threadblock::kernel_epilogue_workspace<<< grid, block >>>( -+ params -+ ); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) << "Kernel launch error - " << cudaGetErrorString(result); -+ -+ // -+ // -+ // -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage.cu 
b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage.cu -new file mode 100644 -index 0000000..8025637 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage.cu -@@ -0,0 +1,3835 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ \brief Unit tests for threadblock-level GEMM -+*/ -+ -+#include "mma_multistage_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_64x64x64_64x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_128x64x64_64x32x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_64x128x64_32x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; 
-+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_128x128x64_64x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ multicta_256x256x384_128x128x64_64x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ multicta_512x256x384_256x128x64_64x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, 
LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_64x64x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_128x64x32_64x32x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_64x128x32_32x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, 
InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_128x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ multicta_256x256x384_128x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ multicta_512x256x768_256x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename 
cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_64x64x32_64x64x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_128x64x32_64x32x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_64x128x32_32x64x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define 
the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_128x128x32_64x64x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ multicta_256x256x192_128x128x32_64x64x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 192); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ multicta_512x256x384_256x128x32_64x64x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 192); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ 
-+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_64x64x16_64x64x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_128x64x16_64x32x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_64x128x16_32x64x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using 
InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_128x128x16_64x64x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ multicta_256x256x192_128x128x16_64x64x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 192); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ multicta_512x256x384_256x128x16_64x64x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 384); -+ -+ using ThreadblockShape = 
cutlass::gemm::GemmShape<256, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x64_64x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x64_32x32x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x64x64_64x32x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); 
-+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x128x64_32x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x128x64_64x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_256x256x384_128x128x64_64x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ 
cutlass::gemm::GemmCoord problem_size(256, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_512x256x768_256x128x64_64x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x32_32x32x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ 
using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x64x32_64x32x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x128x32_32x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = 
cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_256x256x384_128x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_512x256x768_256x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x32_64x64x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = 
cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x32_32x32x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x64x32_64x32x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x128x32_32x64x32_16x8x8_3stage) { -+ using 
ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x128x32_64x64x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_256x256x192_128x128x32_64x64x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 192); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ 
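Every test in this patched file follows the same shape: pick element types and layouts, choose a threadblock/warp/instruction GemmShape and a stage count, instantiate cutlass::gemm::threadblock::DefaultMmaCore, then drive the unit-test Testbed with a grid/block launch. Purely as an illustration (not part of the patch text), that pattern could be expressed as one helper; the name run_crosswise_mma_test is mine, and passing MmaCore as the Testbed template argument is an assumption about the upstream CUTLASS test harness.

// Illustrative sketch only, not part of the deleted patch. Assumes the
// CUTLASS headers and test::gemm::threadblock::Testbed harness already
// included by the original mma_multistage test file.
template <typename ElementA, typename ElementB, typename ElementC,
          typename ThreadblockShape, typename WarpShape,
          typename InstructionShape, int Stages>
void run_crosswise_mma_test(cutlass::gemm::GemmCoord problem_size, dim3 grid,
                            float alpha = 1.f, float beta = 0.f) {
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::ColumnMajor;

  // Same threadblock-level MmaCore definition as in each TEST above.
  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
      ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp,
      Stages>;

  // One 32-thread warp per warp tile of the threadblock tile.
  int const warps = (ThreadblockShape::kM / WarpShape::kM) *
                    (ThreadblockShape::kN / WarpShape::kN);
  dim3 block(32, warps, 1);

  test::gemm::threadblock::Testbed<MmaCore>(problem_size.m(), problem_size.n(),
                                            problem_size.k(), alpha, beta)
      .run(grid, block);
}

// e.g. the 128x128x32 / 64x64x32 / 16x8x16 half_t 4-stage case above becomes:
//   run_crosswise_mma_test<cutlass::half_t, cutlass::half_t, float,
//                          cutlass::gemm::GemmShape<128, 128, 32>,
//                          cutlass::gemm::GemmShape<64, 64, 32>,
//                          cutlass::gemm::GemmShape<16, 8, 16>, 4>(
//       {128, 128, 384}, dim3(1, 1));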
-+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_512x256x192_256x128x32_64x64x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 192); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x16_64x64x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x16_32x32x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ 
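One convention worth spelling out in the launch configurations above: block.x is always the warp size, block.z is 1, and block.y is the number of warp tiles per threadblock tile, i.e. (ThreadblockShape::kM / WarpShape::kM) * (ThreadblockShape::kN / WarpShape::kN). A minimal compile-time check of that relationship is sketched below; the helper name is mine, and it assumes GemmShape comes from the usual cutlass/gemm/gemm.h header.

// Illustrative only: checks the block-dimension convention used by the
// tests above, where block = dim3(32, warps_per_threadblock, 1).
#include "cutlass/gemm/gemm.h"

template <typename ThreadblockShape, typename WarpShape>
constexpr int warps_per_threadblock() {
  return (ThreadblockShape::kM / WarpShape::kM) *
         (ThreadblockShape::kN / WarpShape::kN);
}

static_assert(warps_per_threadblock<cutlass::gemm::GemmShape<128, 128, 32>,
                                    cutlass::gemm::GemmShape<64, 64, 32>>() == 4,
              "128x128 threadblock with 64x64 warps -> block(32, 4, 1)");
static_assert(warps_per_threadblock<cutlass::gemm::GemmShape<256, 128, 32>,
                                    cutlass::gemm::GemmShape<64, 64, 32>>() == 8,
              "256x128 threadblock with 64x64 warps -> block(32, 8, 1)");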
-+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x64x16_64x32x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x128x16_32x64x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x128x16_64x64x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ 
problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_256x256x192_128x128x16_64x64x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 192); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_512x256x192_256x128x16_64x64x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 192); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x128_64x64x128_16x8x32_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ 
test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x128_32x32x128_16x8x32_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x64x128_64x32x128_16x8x32_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x128x128_32x64x128_16x8x32_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ 
test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x128x128_64x64x128_16x8x32_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_256x256x768_128x128x128_64x64x128_16x8x32_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_512x256x768_256x128x128_64x64x128_16x8x32_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 
block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x64_64x64x64_16x8x32_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x64_32x32x64_16x8x32_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x64x64_64x32x64_16x8x32_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ 
test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x128x64_32x64x64_16x8x32_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x128x64_64x64x64_16x8x32_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_256x256x768_128x128x64_64x64x64_16x8x32_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ 
test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_512x256x768_256x128x64_64x64x64_16x8x32_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x256_64x64x256_16x8x64_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x256_32x32x256_16x8x64_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 
1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x64x256_64x32x256_16x8x64_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x256x256_32x64x256_16x8x64_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x256x256_64x64x256_16x8x64_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, 
cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_256x256x1536_128x256x256_64x64x256_16x8x64_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 1536); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_512x256x1536_256x256x256_64x64x256_16x8x64_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 1536); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x128_64x64x128_16x8x64_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, 
InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x128_32x32x128_16x8x64_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x64x128_64x32x128_16x8x64_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x256x128_32x64x128_16x8x64_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename 
cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x256x128_64x64x128_16x8x64_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_256x256x1536_128x256x128_64x64x128_16x8x64_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 1536); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_512x256x1536_256x256x128_64x64x128_16x8x64_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 1536); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int 
const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x1024_64x64x1024_16x8x256_3stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 4096); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 1024>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 1024>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x1024_32x32x1024_16x8x256_3stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 4096); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 1024>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 1024>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x64x1024_64x32x1024_16x8x256_3stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 4096); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 1024>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 1024>; -+ using InstructionShape = 
cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x1024x1024_32x64x1024_16x8x256_3stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 4096); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 1024>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 1024>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x1024x1024_64x64x1024_16x8x256_3stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 4096); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 1024>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 1024>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_256x256x6144_128x1024x1024_64x64x1024_16x8x256_3stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 6144); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 
1024>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 1024>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_512x256x6144_256x1024x1024_64x64x1024_16x8x256_3stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 6144); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 1024>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 1024>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x512_64x64x512_16x8x256_4stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 4096); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x512_32x32x512_16x8x256_4stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord 
problem_size(64, 64, 4096); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x64x512_64x32x512_16x8x256_4stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 4096); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x128x512_32x64x512_16x8x256_4stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 4096); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x128x512_64x64x512_16x8x256_4stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC 
= cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 4096); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_256x256x6144_128x128x512_64x64x512_16x8x256_4stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 6144); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_512x256x6144_256x128x512_64x64x512_16x8x256_4stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 6144); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_64x64x16_32x64x16_8x8x4_3stage) { -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = 
cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 16); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 2, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k()) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_128x128x16_32x64x16_8x8x4_3stage) { -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k()) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_interleaved, -+ tensor_op_64x128x64_32x64x64_16x8x32_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_interleaved, -+ tensor_op_128x128x64_64x64x64_16x8x32_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int; -+ using LayoutC = 
cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_interleaved, -+ multicta_256x256x384_128x128x64_64x64x64_16x8x32_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_interleaved, -+ multicta_512x256x384_256x128x64_64x64x64_16x8x32_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_interleaved, -+ tensor_op_64x128x128_32x64x128_16x8x64_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB 
= cutlass::int4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_interleaved, -+ tensor_op_128x128x128_64x64x128_16x8x64_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_interleaved, -+ multicta_256x256x768_128x128x128_64x64x128_16x8x64_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_interleaved, -+ 
multicta_512x256x1536_256x128x128_64x64x128_16x8x64_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 1536); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise_f64, -+ tensor_op_32x32x16_16x16x16_8x8x4_4stage) { -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 32, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k()) -+ .run(grid, block); -+} -+ -+TEST(SM80_gemm_threadblock_crosswise_f64, -+ tensor_op_64x64x16_32x32x16_8x8x4_4stage) { -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k()) -+ .run(grid, block); -+} -+ -+TEST(SM80_gemm_threadblock_crosswise_f64, -+ tensor_op_64x128x16_32x64x16_8x8x4_4stage) { -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC 
= cutlass::layout::RowMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k()) -+ .run(grid, block); -+} -+ -+TEST(SM80_gemm_threadblock_crosswise_f64, -+ tensor_op_128x64x16_64x32x16_8x8x4_4stage) { -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k()) -+ .run(grid, block); -+} -+ -+TEST(SM80_gemm_threadblock_crosswise_f64, -+ tensor_op_128x128x16_32x64x16_8x8x4_3stage) { -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k()) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_slicedk.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_slicedk.cu -new file mode 100644 -index 0000000..7418732 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_slicedk.cu -@@ -0,0 +1,111 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Unit tests for CTA-level GEMM specifically for sliced-k kernels (SM_61 and SM_75) -+*/ -+ -+#include "mma_multistage_testbed_slicedk.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Tensor Op GEMM for SM_80 -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous_sliced, tensor_op_128x64x256_tb128x64x64_warp64x64x32_16x8x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM80_gemm_threadblock_crosswise_sliced, tensor_op_128x64x256_tb128x64x64_warp64x64x32_16x8x16) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ 
using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse.cu -new file mode 100644 -index 0000000..4bb98cd ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse.cu -@@ -0,0 +1,2703 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ \brief Unit tests for threadblock-level GEMM -+*/ -+ -+#include "mma_multistage_sparse_testbed.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_64x64x64_64x64x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_64x64x64_32x32x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_128x64x64_64x32x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, 
LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_64x128x64_32x64x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_128x128x64_64x64x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ multicta_256x256x768_128x128x64_64x64x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename 
cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ multicta_512x256x768_256x128x64_64x64x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x64_64x64x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x64_32x32x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float 
beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x64x64_64x32x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x128x64_32x64x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x128x64_64x64x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; 
-+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_256x256x768_128x128x64_64x64x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_512x256x768_256x128x64_64x64x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_64x64x128_64x64x128_16x8x32_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ 
-+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_128x64x128_64x32x128_16x8x32_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_64x128x128_32x64x128_16x8x32_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_128x128x128_64x32x128_16x8x32_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ 
using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ multicta_256x256x768_128x128x128_64x32x128_16x8x32_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x128_64x64x128_16x8x32_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x128_32x32x128_16x8x32_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = 
cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x64x128_64x32x128_16x8x32_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x128x128_32x64x128_16x8x32_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, 
-+ tensor_op_128x128x128_64x32x128_16x8x32_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_256x256x768_128x128x128_64x32x128_16x8x32_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_64x64x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_64x64x32_32x32x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_128x64x32_64x32x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_64x128x32_32x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ 
test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_128x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ multicta_256x256x384_128x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ multicta_512x256x384_256x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, 
LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x32_32x32x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x64x32_64x32x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = 
typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x128x32_32x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_256x256x384_128x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 
16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_512x256x384_256x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_64x64x64_64x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_128x64x64_64x32x64_16x8x16_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadblockShape = 
cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_64x128x64_32x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_128x128x64_64x32x64_16x8x16_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ multicta_256x256x384_128x128x64_64x32x64_16x8x16_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ 
using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x64_64x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x64_32x32x64_16x8x16_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x64x64_64x32x64_16x8x16_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ 
using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x128x64_32x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x128x64_64x32x64_16x8x16_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ 
multicta_256x256x384_128x128x64_64x32x64_16x8x16_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x128_64x64x128_16x8x64_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x128_32x32x128_16x8x64_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x64x128_64x32x128_16x8x64_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x128x128_32x64x128_16x8x64_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x128x128_64x64x128_16x8x64_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), 
problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_256x256x1536_128x128x128_64x64x128_16x8x64_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 1536); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_512x256x1536_256x128x128_64x64x128_16x8x64_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 1536); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x256_64x64x256_16x8x64_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); 
-+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x256_32x32x256_16x8x64_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x64x256_64x32x256_16x8x64_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x128x256_32x64x256_16x8x64_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, 
cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x128x256_64x32x256_16x8x64_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_256x256x1536_128x128x256_64x32x256_16x8x64_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 1536); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x256_64x64x256_16x8x128_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 2048); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ 
ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x256_32x32x256_16x8x128_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 2048); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x64x256_64x32x256_16x8x128_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 2048); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x128x256_32x64x256_16x8x128_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 2048); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const 
Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x128x256_64x64x256_16x8x128_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 2048); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_256x256x3072_128x128x256_64x64x256_16x8x128_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 3072); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_512x256x3072_256x128x256_64x64x256_16x8x128_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 3072); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 256>; -+ using WarpShape = 
cutlass::gemm::GemmShape<64, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x512_64x64x512_16x8x128_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 2048); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x512_32x32x512_16x8x128_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 2048); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x64x512_64x32x512_16x8x128_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ 
cutlass::gemm::GemmCoord problem_size(128, 64, 2048); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x128x512_32x64x512_16x8x128_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 2048); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x128x512_64x32x512_16x8x128_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 2048); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_256x256x3072_128x128x512_64x32x512_16x8x128_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = 
cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 3072); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h -new file mode 100644 -index 0000000..6e14745 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h -@@ -0,0 +1,438 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file
-+    \brief Unit testbed for kernel-level GEMM
-+*/
-+
-+#pragma once
-+
-+#include "../../common/cutlass_unit_test.h"
-+#include "cutlass/aligned_buffer.h"
-+#include "cutlass/array.h"
-+#include "cutlass/core_io.h"
-+#include "cutlass/gemm/gemm.h"
-+#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h"
-+#include "cutlass/layout/matrix.h"
-+#include "cutlass/numeric_types.h"
-+#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h"
-+#include "cutlass/util/distribution.h"
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/tensor_norm.h"
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/tensor_view_io.h"
-+#include "cutlass/util/host_reorder.h"
-+#include "cutlass/util/host_uncompress.h"
-+
-+namespace test {
-+namespace gemm {
-+namespace threadblock {
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+template <typename Mma>
-+__global__ void kernel_multistage_mma_sparse(cutlass::gemm::GemmCoord problem_size,
-+                                             typename Mma::IteratorA::Params params_A,
-+                                             typename Mma::IteratorA::TensorRef ref_A,
-+                                             typename Mma::IteratorB::Params params_B,
-+                                             typename Mma::IteratorB::TensorRef ref_B,
-+                                             typename Mma::ElementC *ptr_C,
-+                                             typename Mma::LayoutC::Stride::Index ldc,
-+                                             typename Mma::IteratorE::Params params_E,
-+                                             typename Mma::IteratorE::TensorRef ref_E) {
-+  // Shared storage needed by threadblock-scoped matrix multiply-accumulate
-+  // Dynamic shared memory base pointer
-+  extern __shared__ int GemmSharedStorageBase[];
-+
-+  // Declare pointer to dynamic shared memory.
-+  typename Mma::SharedStorage *shared_storage =
-+      reinterpret_cast<typename Mma::SharedStorage *>(GemmSharedStorageBase);
-+
-+  // Compute threadblock location
-+  cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y),
-+                                             0};
-+
-+  cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM,
-+                                   tb_tile_offset.k() / Mma::kSparse};
-+
-+  cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(),
-+                                   tb_tile_offset.n() * Mma::Shape::kN};
-+
-+  cutlass::MatrixCoord tb_offset_E{tb_tile_offset.m() * Mma::Shape::kM,
-+                                   tb_tile_offset.k() / Mma::kSparse};
-+
-+  // Compute position within threadblock
-+  int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x;
-+
-+  // Construct iterators to A and B operands
-+  typename Mma::IteratorA iterator_A(params_A, ref_A.data(),
-+                                     {problem_size.m(), problem_size.k() / Mma::kSparse},
-+                                     tb_thread_id, tb_offset_A);
-+
-+  typename Mma::IteratorB iterator_B(params_B, ref_B.data(),
-+                                     {problem_size.k(), problem_size.n()},
-+                                     tb_thread_id, tb_offset_B);
-+
-+  typename Mma::IteratorE iterator_E(
-+      params_E, ref_E.data(),
-+      {problem_size.m(),
-+       problem_size.k() / Mma::kSparse / Mma::kElementsPerElementE},
-+      tb_thread_id, tb_offset_E);
-+
-+  int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0);
-+
-+  // Construct thread-scoped matrix multiply
-+  Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x);
-+
-+  typename Mma::FragmentC accum;
-+
-+  accum.clear();
-+
-+  int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
-+
-+  // Compute threadblock-scoped matrix multiply-add
-+  mma(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_E, accum);
-+
-+  // Output results
-+  typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, threadIdx.x);
-+
-+  iterator_C.add_tile_offset(
-+      {(tb_tile_offset.m() * Mma::WarpCount::kM) +
-+           (warp_id % 
Mma::WarpCount::kM), -+ (tb_tile_offset.n() * Mma::WarpCount::kN) + -+ (warp_id / Mma::WarpCount::kM)}); -+ -+ iterator_C.store(accum); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Threadblock-level matrix multiply-accumulate -+ typename MmaCore_> -+struct SparseTestbed { -+ /// Threadblock-level GEMM implementation -+ using MmaCore = MmaCore_; -+ using ThreadblockShape = typename MmaCore::Shape; -+ using WarpShape = typename MmaCore::WarpShape; -+ using InstructionShape = typename MmaCore::InstructionShape; -+ using ElementA = typename MmaCore::ElementA; -+ using LayoutA = typename MmaCore::LayoutA; -+ using ElementB = typename MmaCore::ElementB; -+ using LayoutB = typename MmaCore::LayoutB; -+ using ElementC = typename MmaCore::ElementC; -+ using LayoutC = typename MmaCore::LayoutC; -+ using ElementE = typename MmaCore::ElementE; -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using ThreadMapE = typename MmaCore::IteratorThreadMapE; -+ using AccessTypeA = cutlass::Array; -+ using AccessTypeB = cutlass::Array; -+ using AccessTypeE = cutlass::Array; -+ static int const Stages = MmaCore::kStages; -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ MmaCore::kCacheOpA; -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ MmaCore::kCacheOpB; -+ static cutlass::arch::CacheOperation::Kind const CacheOpE = -+ MmaCore::kCacheOpE; -+ -+ static int const Sparse = MmaCore::kSparse; -+ static int const MetaSizeInBits = MmaCore::kMetaSizeInBits; -+ static int const MaxID2 = MmaCore::kMaxID2; -+ -+ using LayoutE = cutlass::layout::RowMajor; -+ using ReorderedLayoutE = typename MmaCore::GmemLayoutE; -+ -+ static int const ElementsPerElementE = MmaCore::kElementsPerElementE; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ // Define iterators over tiles from the E operand -+ using IteratorE = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementE, ReorderedLayoutE, 1, ThreadMapE, AccessTypeE>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using Mma = cutlass::gemm::threadblock::SparseMmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ CacheOpA, IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, -+ LayoutC, IteratorE, typename MmaCore::SmemIteratorE, CacheOpE, -+ typename MmaCore::MmaPolicy, Stages>; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor matrix_A; -+ cutlass::HostTensor matrix_A_uncompressed; -+ cutlass::HostTensor matrix_B; -+ cutlass::HostTensor matrix_C_computed; -+ cutlass::HostTensor matrix_C_reference; -+ cutlass::HostTensor matrix_E; -+ cutlass::HostTensor matrix_E_reordered; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ float alpha, beta; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ SparseTestbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) -+ : problem_size(m, n, k), alpha(alpha_), beta(beta_) { -+ 
matrix_A.reset(cutlass::make_Coord(m, k / Sparse)); -+ matrix_A_uncompressed.reset(cutlass::make_Coord(m, k)); -+ matrix_B.reset(cutlass::make_Coord(k, n)); -+ matrix_C_computed.reset(cutlass::make_Coord(m, n)); -+ matrix_C_reference.reset(cutlass::make_Coord(m, n), false); -+ matrix_E.reset(cutlass::make_Coord(m, k / Sparse / ElementsPerElementE)); -+ matrix_E_reordered.reset( -+ cutlass::make_Coord(m, k / Sparse / ElementsPerElementE)); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ return true; -+ } -+ -+ /// Runs the test -+ bool run( -+ dim3 grid, dim3 block, -+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_E = cutlass::Distribution::Uniform) { -+ -+ // Waive the test -+ if (!sufficient()) { -+ return true; -+ } -+ -+ // -+ // initialize device memory -+ // -+ -+ if (init_A == cutlass::Distribution::Uniform) { -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_A.host_view(), seed, scope_max, scope_min, 0); -+ } else if (init_A == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), -+ matrix_A.capacity()); -+ } else if (init_A == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ if (init_B == cutlass::Distribution::Uniform) { -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); -+ } else if (init_B == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), -+ matrix_B.capacity()); -+ } else if (init_B == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); -+ -+ cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); -+ -+ if (init_E == cutlass::Distribution::Uniform) { -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomSparseMeta( -+ matrix_E.host_view(), seed, MetaSizeInBits); -+ } else if (init_E == cutlass::Distribution::Identity) { -+ uint32_t content = (MaxID2 == 1) ? 
0x44444444 : 0x4444; -+ cutlass::reference::host::TensorFill(matrix_E.host_view(), -+ (ElementE)(content)); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ cutlass::reorder_meta(matrix_E_reordered.host_ref(), matrix_E.host_ref(), -+ {problem_size.m(), problem_size.n(), -+ problem_size.k() / Sparse / ElementsPerElementE}); -+ -+ matrix_A.sync_device(); -+ matrix_B.sync_device(); -+ matrix_C_computed.sync_device(); -+ matrix_E_reordered.sync_device(); -+ -+ typename IteratorA::Params params_A(matrix_A.layout()); -+ typename IteratorB::Params params_B(matrix_B.layout()); -+ typename IteratorE::Params params_E(matrix_E_reordered.layout()); -+ -+ cudaError_t result; -+ -+ int smem_size = int(sizeof(typename Mma::SharedStorage)); -+ if (smem_size >= (48 << 10)) { -+ result = cudaFuncSetAttribute( -+ test::gemm::threadblock::kernel_multistage_mma_sparse, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); -+ -+ if (result != cudaSuccess) { -+ return true; -+ } -+ -+ result = cudaFuncSetAttribute( -+ test::gemm::threadblock::kernel_multistage_mma_sparse, -+ cudaFuncAttributePreferredSharedMemoryCarveout, 100); -+ -+ if (result != cudaSuccess) { -+ return true; -+ } -+ } -+ -+ test::gemm::threadblock::kernel_multistage_mma_sparse -+ <<>>( -+ problem_size, params_A, matrix_A.device_ref(), params_B, -+ matrix_B.device_ref(), matrix_C_computed.device_data(), -+ matrix_C_computed.layout().stride(0), params_E, -+ matrix_E_reordered.device_ref()); -+ -+ // -+ // Check error code -+ // -+ -+ result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) -+ << " kernel error: " << cudaGetErrorString(result); -+ -+ matrix_C_computed.sync_host(); -+ -+ cutlass::uncompress(matrix_A_uncompressed.host_ref(), matrix_A.host_ref(), -+ matrix_E.host_ref(), problem_size.m(), -+ problem_size.k()); -+ -+ cutlass::reference::host::Gemm -+ reference_gemm; -+ -+ reference_gemm(problem_size, ElementC(alpha), -+ matrix_A_uncompressed.host_view(), matrix_B.host_view(), -+ ElementC(beta), matrix_C_reference.host_view()); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ matrix_C_computed.host_view(), matrix_C_reference.host_view()); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed && CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ -+ std::cout -+ << __FILE__ << ":" << __LINE__ << " " -+ << "A:\n" << matrix_A.host_view() << "\n" -+ << "B:\n" << matrix_B.host_view() << "\n" -+ << "E:\n" << matrix_E.host_view() << "\n" -+ << "Reference:\n" -+ << matrix_C_reference.host_view() << "\n" -+ << "Computed:\n" -+ << matrix_C_computed.host_view() << "\n"; -+ } -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_reference.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_computed.host_view()), 0); -+ -+ return passed; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed.h b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed.h -new file mode 100644 -index 0000000..1e859b6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed.h -@@ -0,0 +1,374 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit testbed for kernel-level GEMM -+*/ -+ -+#pragma once -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/core_io.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+namespace test { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void kernel_multistage_mma(cutlass::gemm::GemmCoord problem_size, -+ typename Mma::IteratorA::Params params_A, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::Params params_B, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Mma::ElementC *ptr_C, -+ typename Mma::LayoutC::Stride::Index ldc) { -+ // Shared storage needed by threadblock-scoped matrix multiply-accumulate -+ -+ // Dynamic shared memory base pointer -+ extern __shared__ int GemmSharedStorageBase[]; -+ -+ // Declare pointer to dynamic shared memory. 
-+ typename Mma::SharedStorage *shared_storage = -+ reinterpret_cast(GemmSharedStorageBase); -+ -+ // Compute threadblock location -+ cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), -+ 0}; -+ -+ cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, -+ tb_tile_offset.k()}; -+ -+ cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), -+ tb_tile_offset.n() * Mma::Shape::kN}; -+ -+ // Compute position within threadblock -+ int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A(params_A, ref_A.data(), -+ {problem_size.m(), problem_size.k()}, -+ tb_thread_id, tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B(params_B, ref_B.data(), -+ {problem_size.k(), problem_size.n()}, -+ tb_thread_id, tb_offset_B); -+ -+ int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0); -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x); -+ -+ typename Mma::FragmentC accum; -+ -+ accum.clear(); -+ -+ int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); -+ -+ // Output results -+ typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, threadIdx.x); -+ -+ iterator_C.add_tile_offset( -+ {(tb_tile_offset.m() * Mma::WarpCount::kM) + -+ (warp_id % Mma::WarpCount::kM), -+ (tb_tile_offset.n() * Mma::WarpCount::kN) + -+ (warp_id / Mma::WarpCount::kM)}); -+ -+ iterator_C.store(accum); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Threadblock-level matrix multiply-accumulate -+ typename MmaCore_> -+struct Testbed { -+ /// Threadblock-level GEMM implementation -+ using MmaCore = MmaCore_; -+ using ThreadblockShape = typename MmaCore::Shape; -+ using WarpShape = typename MmaCore::WarpShape; -+ using InstructionShape = typename MmaCore::InstructionShape; -+ using ElementA = typename MmaCore::ElementA; -+ using LayoutA = typename MmaCore::LayoutA; -+ using ElementB = typename MmaCore::ElementB; -+ using LayoutB = typename MmaCore::LayoutB; -+ using ElementC = typename MmaCore::ElementC; -+ using LayoutC = typename MmaCore::LayoutC; -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeA = cutlass::Array; -+ using AccessTypeB = cutlass::Array; -+ static int const Stages = MmaCore::kStages; -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ MmaCore::kCacheOpA; -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ MmaCore::kCacheOpB; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using Mma = cutlass::gemm::threadblock::MmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ CacheOpA, IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, -+ LayoutC, typename MmaCore::MmaPolicy, 
Stages>; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor matrix_A; -+ cutlass::HostTensor matrix_B; -+ cutlass::HostTensor matrix_C_computed; -+ cutlass::HostTensor matrix_C_reference; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ float alpha, beta; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ Testbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) -+ : problem_size(m, n, k), alpha(alpha_), beta(beta_) { -+ matrix_A.reset(cutlass::make_Coord(m, k)); -+ matrix_B.reset(cutlass::make_Coord(k, n)); -+ matrix_C_computed.reset(cutlass::make_Coord(m, n)); -+ matrix_C_reference.reset(cutlass::make_Coord(m, n), false); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ return true; -+ } -+ -+ /// Runs the test -+ bool run( -+ dim3 grid, dim3 block, -+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { -+ -+ if (!sufficient()) { -+ return true; -+ } -+ -+ // -+ // initialize device memory -+ // -+ -+ if (init_A == cutlass::Distribution::Uniform) { -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_A.host_view(), seed, scope_max, scope_min, 0); -+ } else if (init_A == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), -+ matrix_A.capacity()); -+ } else if (init_A == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ if (init_B == cutlass::Distribution::Uniform) { -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); -+ } else if (init_B == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), -+ matrix_B.capacity()); -+ } else if (init_B == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); -+ -+ cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); -+ -+ matrix_A.sync_device(); -+ matrix_B.sync_device(); -+ matrix_C_computed.sync_device(); -+ -+ typename IteratorA::Params params_A(matrix_A.layout()); -+ typename IteratorB::Params params_B(matrix_B.layout()); -+ -+ cudaError_t result; -+ -+ int smem_size = int(sizeof(typename 
Mma::SharedStorage)); -+ if (smem_size >= (48 << 10)) { -+ result = cudaFuncSetAttribute( -+ test::gemm::threadblock::kernel_multistage_mma, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); -+ -+ if (result != cudaSuccess) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ result = cudaFuncSetAttribute( -+ test::gemm::threadblock::kernel_multistage_mma, -+ cudaFuncAttributePreferredSharedMemoryCarveout, 100); -+ -+ if (result != cudaSuccess) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ } -+ -+ test::gemm::threadblock::kernel_multistage_mma -+ <<>>( -+ problem_size, params_A, matrix_A.device_ref(), params_B, -+ matrix_B.device_ref(), matrix_C_computed.device_data(), -+ matrix_C_computed.layout().stride(0)); -+ -+ // -+ // Check error code -+ // -+ -+ result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) -+ << " kernel error: " << cudaGetErrorString(result); -+ -+ matrix_C_computed.sync_host(); -+ -+ cutlass::reference::host::Gemm reference_gemm; -+ -+ reference_gemm( -+ problem_size, ElementC(alpha), matrix_A.host_view(), -+ matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ matrix_C_computed.host_view(), matrix_C_reference.host_view()); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed && CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cout -+ << __FILE__ << ":" << __LINE__ << " " -+ << "A:\n" << matrix_A.host_view() << "\n" -+ << "B:\n" << matrix_B.host_view() << "\n" -+ << "Reference:\n" -+ << matrix_C_reference.host_view() << "\n" -+ << "Computed:\n" -+ << matrix_C_computed.host_view() << "\n"; -+ } -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_reference.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_computed.host_view()), 0); -+ -+ return passed; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h -new file mode 100644 -index 0000000..a47a300 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h -@@ -0,0 +1,389 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Unit testbed for kernel-level GEMM -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/platform/platform.h" -+ -+namespace test { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void kernel_multistage_mma(cutlass::gemm::GemmCoord problem_size, -+ typename Mma::IteratorA::Params params_A, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::Params params_B, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Mma::ElementC **ptr_C, -+ typename Mma::LayoutC::Stride::Index ldc) { -+ // Shared storage needed by threadblock-scoped matrix multiply-accumulate -+ -+ // Dynamic shared memory base pointer -+ extern __shared__ int GemmSharedStorageBase[]; -+ -+ // Declare pointer to dynamic shared memory. 
-+ typename Mma::SharedStorage *shared_storage = -+ reinterpret_cast(GemmSharedStorageBase); -+ -+ // Compute threadblock location -+ cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), -+ 0}; -+ -+ cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, -+ tb_tile_offset.k()}; -+ -+ cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), -+ tb_tile_offset.n() * Mma::Shape::kN}; -+ -+ // Compute position within threadblock -+ int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A(params_A, ref_A.data(), -+ {problem_size.m(), problem_size.k()}, -+ tb_thread_id, tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B(params_B, ref_B.data(), -+ {problem_size.k(), problem_size.n()}, -+ tb_thread_id, tb_offset_B); -+ -+ int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0); -+ int lane_id = threadIdx.x; -+ -+ int partitionsK_idx = warp_id / (Mma::WarpCount::kM * Mma::WarpCount::kN); -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x); -+ -+ typename Mma::FragmentC accum; -+ -+ accum.clear(); -+ -+ int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); -+ -+ // Output results -+ typename Mma::Operator::IteratorC iterator_C({ptr_C[partitionsK_idx], ldc}, lane_id); -+ -+ int warp_idx_mn = warp_id % (Mma::WarpCount::kM * Mma::WarpCount::kN); -+ iterator_C.add_tile_offset( -+ {(tb_tile_offset.m() * Mma::WarpCount::kM) + -+ (warp_idx_mn % Mma::WarpCount::kM), -+ (tb_tile_offset.n() * Mma::WarpCount::kN) + -+ (warp_idx_mn / Mma::WarpCount::kM)}); -+ -+ iterator_C.store(accum); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Threadblock-level matrix multiply-accumulate -+ typename MmaCore_> -+struct Testbed { -+ /// Threadblock-level GEMM implementation -+ using MmaCore = MmaCore_; -+ using ThreadblockShape = typename MmaCore::Shape; -+ using WarpShape = typename MmaCore::WarpShape; -+ using InstructionShape = typename MmaCore::InstructionShape; -+ using ElementA = typename MmaCore::ElementA; -+ using LayoutA = typename MmaCore::LayoutA; -+ using ElementB = typename MmaCore::ElementB; -+ using LayoutB = typename MmaCore::LayoutB; -+ using ElementC = typename MmaCore::ElementC; -+ using LayoutC = typename MmaCore::LayoutC; -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeA = cutlass::Array; -+ using AccessTypeB = cutlass::Array; -+ static int const Stages = MmaCore::kStages; -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ MmaCore::kCacheOpA; -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ MmaCore::kCacheOpB; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using Mma = 
cutlass::gemm::threadblock::MmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, CacheOpA, -+ IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, LayoutC, -+ typename MmaCore::MmaPolicy, Stages>; -+ -+ static int const kPartitionsK = MmaCore::MmaPolicy::kPartitionsK; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor matrix_A; -+ cutlass::HostTensor matrix_B; -+ cutlass::HostTensor matrix_C_computed[kPartitionsK]; -+ cutlass::HostTensor matrix_C_reference; -+ cutlass::HostTensor matrix_C_pointers; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ float alpha, beta; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ Testbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) -+ : problem_size(m, n, k), alpha(alpha_), beta(beta_) { -+ matrix_A.reset(cutlass::make_Coord(m, k)); -+ matrix_B.reset(cutlass::make_Coord(k, n)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 0; k < kPartitionsK; k++) -+ matrix_C_computed[k].reset(cutlass::make_Coord(m, n)); -+ -+ matrix_C_reference.reset(cutlass::make_Coord(m, n), false); -+ matrix_C_pointers.reset(cutlass::Coord<1>(kPartitionsK)); -+ } -+ -+ /// Runs the test -+ bool run( -+ dim3 grid, dim3 block, -+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { -+ // -+ // initialize device memory -+ // -+ -+ if (init_A == cutlass::Distribution::Uniform) { -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_A.host_view(), seed, scope_max, scope_min, 0); -+ } else if (init_A == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), -+ matrix_A.capacity()); -+ } else if (init_A == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ if (init_B == cutlass::Distribution::Uniform) { -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); -+ } else if (init_B == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), -+ matrix_B.capacity()); -+ } else if (init_B == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 0; k < kPartitionsK; k++) -+ cutlass::reference::host::TensorFill(matrix_C_computed[k].host_view()); -+ -+ cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); -+ -+ matrix_A.sync_device(); -+ matrix_B.sync_device(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 0; k < kPartitionsK; k++) -+ matrix_C_computed[k].sync_device(); -+ -+ typename IteratorA::Params params_A(matrix_A.layout()); -+ typename IteratorB::Params params_B(matrix_B.layout()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 0; k < kPartitionsK; k++) -+ 
matrix_C_pointers.at(cutlass::Coord<1>(k)) = matrix_C_computed[k].device_data(); -+ -+ matrix_C_pointers.sync_device(); -+ -+ cudaError_t result; -+ -+ int smem_size = int(sizeof(typename Mma::SharedStorage)); -+ if (smem_size >= (48 << 10)) { -+ result = cudaFuncSetAttribute( -+ test::gemm::threadblock::kernel_multistage_mma, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); -+ -+ EXPECT_EQ(result, cudaSuccess) -+ << " cudaFuncSetAttribute " -+ "cudaFuncAttributeMaxDynamicSharedMemorySize error: " -+ << cudaGetErrorString(result); -+ -+ result = cudaFuncSetAttribute( -+ test::gemm::threadblock::kernel_multistage_mma, -+ cudaFuncAttributePreferredSharedMemoryCarveout, 100); -+ -+ EXPECT_EQ(result, cudaSuccess) -+ << " cudaFuncSetAttribute " -+ "cudaFuncAttributePreferredSharedMemoryCarveout error: " -+ << cudaGetErrorString(result); -+ } -+ -+ test::gemm::threadblock::kernel_multistage_mma<<>>( -+ problem_size, params_A, matrix_A.device_ref(), params_B, -+ matrix_B.device_ref(), matrix_C_pointers.device_data(), -+ matrix_C_computed[0].layout().stride(0)); -+ -+ // -+ // Check error code -+ // -+ -+ result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) -+ << " kernel error: " << cudaGetErrorString(result); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 0; k < kPartitionsK; k++) -+ matrix_C_computed[k].sync_host(); -+ -+ // TODO: this is temporary. it will be removed after slicing can de -+ // reduction -+ // -+ // Reduce matrix_C_computed -+ // -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 1; k < kPartitionsK; k++) { -+ CUTLASS_PRAGMA_UNROLL -+ for(int m = 0; m < matrix_C_computed[0].extent().row(); m++){ -+ CUTLASS_PRAGMA_UNROLL -+ for(int n = 0; n < matrix_C_computed[0].extent().column(); n++){ -+ matrix_C_computed[0].at({m, n}) += matrix_C_computed[k].at({m, n}); -+ } -+ } -+ } -+ -+ cutlass::reference::host::Gemm -+ reference_gemm; -+ -+ reference_gemm( -+ problem_size, ElementC(alpha), matrix_A.host_view(), -+ matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ matrix_C_computed[0].host_view(), matrix_C_reference.host_view()); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::ofstream output("mma_multistage_testbed_errors.txt"); -+ -+ output -+ << "A:\n" << matrix_A.host_view() << "\n" -+ << "B:\n" << matrix_B.host_view() << "\n" -+ << "Reference:\n" -+ << matrix_C_reference.host_view() << "\n" -+ << "Computed:\n" -+ << matrix_C_computed[0].host_view() << "\n"; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_simt.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_simt.cu -new file mode 100644 -index 0000000..506ca09 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_simt.cu -@@ -0,0 +1,1022 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "mma_pipelined_testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// sgemm_NT -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_sgemm, sgemm_nt_32x64x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<32, 64, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ float, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ float, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ float, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass, -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_sgemm, sgemm_nt_64x64x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 64, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ float, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ float, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ float, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 2, 1); -+ test::gemm::threadblock::Testbed( -+ 
problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_sgemm, sgemm_nt_32x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<32, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ float, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ float, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ float, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 128, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 2, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_sgemm, sgemm_nt_64x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ float, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ float, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ float, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 16); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_sgemm, sgemm_nt_128x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<128, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ float, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ float, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ float, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// dgemm_NN -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_dgemm, dgemm_nt_32x64x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<32, 64, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ double, // ElementA, -+ 
cutlass::layout::ColumnMajor, // LayoutA, -+ double, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ double, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_dgemm, dgemm_nt_64x64x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 64, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ double, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ double, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ double, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 2, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_dgemm, dgemm_nt_32x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<32, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ double, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ double, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ double, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 128, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 2, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_dgemm, dgemm_nt_64x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ double, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ double, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ double, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 16); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_dgemm, dgemm_nt_128x128x8_32x64x1) { -+ using MmaCore = 
typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<128, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ double, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ double, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ double, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// igemm_NN -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_igemm, igemm_nt_32x64x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<32, 64, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ int, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_igemm, igemm_nt_64x64x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 64, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ int, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 2, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_igemm, igemm_nt_32x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<32, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ int, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ 
cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 128, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 2, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_igemm, igemm_nt_64x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ int, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 16); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_igemm, igemm_nt_128x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<128, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ int, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// hgemm_NN -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_hgemm, hgemm_nt_32x64x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<32, 64, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ cutlass::half_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ cutlass::half_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ cutlass::half_t, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_hgemm, hgemm_nt_64x64x8_32x64x1) { -+ using MmaCore = 
typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 64, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ cutlass::half_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ cutlass::half_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ cutlass::half_t, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 2, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_hgemm, hgemm_nt_32x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<32, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ cutlass::half_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ cutlass::half_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ cutlass::half_t, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 128, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 2, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_hgemm, hgemm_nt_64x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ cutlass::half_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ cutlass::half_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ cutlass::half_t, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 16); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_hgemm, hgemm_nt_128x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<128, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ cutlass::half_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ cutlass::half_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ cutlass::half_t, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 
128, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// igemm_NT DP4A -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_igemm, igemm_int8_nt_64x64x16_64x64x4) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 64, 16>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 16>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_nt_64x64x32_64x64x4) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 64, 32>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 32>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 4096); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_nt_64x64x16_64x64x8) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 64, 16>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 16>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_nt_128x64x16_64x64x8) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<128, 
64, 16>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 16>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 32); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 2, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_nt_128x128x16_64x64x8) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<128, 128, 16>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 16>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 32); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_nt_256x128x16_64x64x8) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<256, 256, 16>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<128, 64, 16>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 32); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_nt_128x256x64_64x64x16) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<128, 256, 64>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 64>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 256, 64); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), 
problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_nt_256x128x64_64x64x16) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<256, 128, 64>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 64>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 128, 64); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_tn_64x64x16_64x64x4) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 64, 16>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 16>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::RowMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::ColumnMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_tn_64x64x32_64x64x4) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 64, 32>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 32>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::RowMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::ColumnMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 4096); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_tn_64x64x16_64x64x8) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 64, 16>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 16>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::RowMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::ColumnMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // 
OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+TEST(SM61_igemm, igemm_int8_tn_128x64x16_64x64x8) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<128, 64, 16>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 16>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::RowMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::ColumnMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 32); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 2, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_tn_128x128x16_64x64x8) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<128, 128, 16>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 16>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::RowMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::ColumnMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 32); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_tn_256x128x16_64x64x8) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<256, 256, 16>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<128, 64, 16>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::RowMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::ColumnMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 32); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_tn_128x256x64_64x64x16) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<128, 256, 64>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 64>, // WarpShape, -+ 
cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::RowMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::ColumnMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 256, 64); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_tn_256x128x64_64x64x16) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<256, 128, 64>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 64>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::RowMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::ColumnMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 128, 64); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_nn_64x64x16_64x64x4) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 64, 16>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 16>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::ColumnMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_slicedk.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_slicedk.cu -new file mode 100644 -index 0000000..af1d61d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_slicedk.cu -@@ -0,0 +1,186 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Unit tests for CTA-level GEMM specifically for sliced-k kernels (SM_61 and SM_75) -+*/ -+ -+#include "mma_pipelined_testbed_slicedk.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// igemm_NT DP4A -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_igemm_sliced_k, igemm_int8_nt_32x32x128_32x32x4) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<32, 32, 128>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 32, 32>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2>; // Stages, -+ -+ cutlass::gemm::GemmCoord problem_size(32, 32, 128); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm_sliced_k_big, igemm_int8_nt_32x32x128_32x32x4_bigk) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<32, 32, 128>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 32, 32>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2>; // Stages, -+ -+ cutlass::gemm::GemmCoord problem_size(32, 32, 1024); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, 
cutlass::Distribution::Uniform); -+} -+ -+ -+TEST(SM61_igemm_sliced_k, igemm_int8_nt_32x64x128_32x32x4) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<32, 64, 128>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 32, 64>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2>; // Stages, -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 256); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Tensor Op GEMM for SM_75 -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_congruous_sliced, tensor_op_64x64x256_tb64x64x64_warp64x32x32_16x8x8) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpMultiplyAdd>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM75_gemm_threadblock_crosswise_sliced, tensor_op_64x64x256_tb64x64x64_warp64x32x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpMultiplyAdd>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_sm70.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_sm70.cu -new file mode 100644 -index 0000000..3374263 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_sm70.cu -@@ -0,0 +1,498 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "mma_pipelined_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_congruous, tensor_op_64x64x32_64x64x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_congruous, tensor_op_128x128x32_64x64x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_congruous, tensor_op_64x64x32_32x32x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ 
test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_congruous, tensor_op_128x64x32_64x32x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_congruous, tensor_op_128x64x64_64x32x64_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using OperatorShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, OperatorShape, ElementA, LayoutA, ElementB, -+ LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_congruous, tensor_op_64x128x32_32x64x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), 
problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_congruous, tensor_op_256x128x32_32x64x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_crosswise, tensor_op_64x64x32_64x64x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_crosswise, tensor_op_128x128x32_64x64x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ 
problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_crosswise, tensor_op_256x128x32_64x64x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_crosswise, tensor_op_64x64x32_32x32x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_crosswise, tensor_op_128x64x32_64x32x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), 
alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_crosswise, tensor_op_128x64x64_64x32x64_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using OperatorShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, OperatorShape, ElementA, LayoutA, ElementB, -+ LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_crosswise, tensor_op_64x128x32_32x64x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+#endif // CUTLASS_ARCH_MMA_SM70_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_sm75.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_sm75.cu -new file mode 100644 -index 0000000..3f17387 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_sm75.cu -@@ -0,0 +1,2129 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for threadblock-level GEMM -+*/ -+ -+#include "mma_pipelined_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_congruous, tensor_op_64x64x32_64x64x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_congruous, tensor_op_128x64x32_64x32x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ 
.run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_congruous, tensor_op_64x128x32_32x64x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_congruous, tensor_op_128x128x32_64x64x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_congruous, -+ multicta_256x256x96_128x128x32_64x64x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 96); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_congruous, -+ multicta_512x256x384_256x128x32_64x64x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x64x32_64x64x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_32x32x32_16x16x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 32, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_32x64x32_16x32x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x32x32_32x16x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 32, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x64x32_32x32x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ 
-+TEST(SM75_gemm_threadblock_crosswise, tensor_op_128x64x32_64x32x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x128x32_32x64x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_128x128x32_64x64x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 96); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, -+ multicta_256x256x96_128x128x32_64x64x32_16x8x8) { -+ 
using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 96); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, -+ multicta_512x256x384_256x128x32_64x64x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_32x32x64_16x16x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 32, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_64x32x64_32x16x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = 
cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 32, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_32x64x64_16x32x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_64x64x64_32x32x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_128x64x64_64x32x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = 
uint8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore component -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_64x128x64_32x64x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_128x128x64_64x64x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, -+ multicta_256x256x192_128x128x64_64x64x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = uint8_t; -+ using LayoutB = 
cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 192); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, -+ multicta_512x256x768_256x128x64_64x64x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 768); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x64x64_64x64x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_32x32x64_16x16x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = 
cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 32, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x32x64_32x16x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 32, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_32x64x64_16x32x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x64x64_32x32x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ 
using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_128x64x64_64x32x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore component -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x128x64_32x64x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_128x128x64_64x64x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ 
float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, -+ multicta_256x256x192_128x128x64_64x64x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 192); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, -+ multicta_512x256x768_256x128x64_64x64x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 768); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x64x128_64x64x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename 
cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_32x32x128_16x16x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 32, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x32x128_32x16x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 32, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_32x64x128_16x32x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, 
InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x64x128_32x32x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_128x64x128_64x32x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore component -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x128x128_32x64x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, 
cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_128x128x128_64x64x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, -+ multicta_256x256x384_128x128x128_64x64x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 384); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, -+ multicta_512x256x1536_256x128x128_64x64x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 1536); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ 
dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_32x32x128_16x16x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 32, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_64x32x128_32x16x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 32, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_32x64x128_16x32x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; 
-+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_64x64x128_32x32x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_128x64x128_64x32x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore component -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_64x128x128_32x64x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, 
cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_128x128x128_64x64x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, -+ multicta_256x256x384_128x128x128_64x64x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 384); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, -+ multicta_512x256x1536_256x128x128_64x64x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 1536); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, 
ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x64x512_64x64x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_32x32x512_16x16x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 32, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<32, 32, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x32x512_32x16x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 32, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 32, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, 
LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_32x64x512_16x32x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<32, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x64x512_32x32x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_128x64x512_64x32x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore component -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, 
InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x128x512_32x64x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 128, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_128x128x512_64x64x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, -+ multicta_256x256x1536_128x128x512_64x64x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 1536); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename 
cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, -+ multicta_512x256x6144_256x128x512_64x64x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 6144); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_sm80.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_sm80.cu -new file mode 100644 -index 0000000..6e91c01 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_sm80.cu -@@ -0,0 +1,569 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for threadblock-level GEMM -+*/ -+ -+#include "mma_pipelined_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, tensor_op_64x64x16_64x64x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, tensor_op_128x64x16_64x32x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, tensor_op_64x128x16_32x64x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ 
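-+  // Every test in this file follows the same recipe: fix the element types
-+  // and layouts, choose a ThreadblockShape/WarpShape/InstructionShape tiling,
-+  // derive a DefaultMmaCore from them, and hand the problem to the shared
-+  // Testbed in mma_pipelined_testbed.h. Roughly (a sketch, template arguments
-+  // not verbatim from this patch):
-+  //
-+  //   using MmaCore = cutlass::gemm::threadblock::DefaultMmaCore<
-+  //       ThreadblockShape, WarpShape, InstructionShape,
-+  //       ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
-+  //       cutlass::arch::OpClassTensorOp>;
-+  //   test::gemm::threadblock::Testbed<MmaCore>(m, n, k, alpha, beta)
-+  //       .run(grid, block);  // block = dim3(32, warps_per_cta, 1)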
cutlass::gemm::GemmCoord problem_size(64, 128, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, tensor_op_128x128x16_64x64x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ multicta_256x256x96_128x128x16_64x64x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 96); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ multicta_512x256x192_256x128x16_64x64x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 
256, 192); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, tensor_op_64x64x16_64x64x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, tensor_op_32x32x16_16x16x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 32, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, tensor_op_32x64x16_16x32x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ 
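-+  // Test names encode ThreadblockShape_WarpShape_InstructionShape: here a
-+  // 32x64x16 threadblock tile is covered by 16x32x16 warp tiles, i.e. a 2x2
-+  // grid of warps, which is why the launch below uses dim3 block(32, 4, 1).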
using WarpShape = cutlass::gemm::GemmShape<16, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, tensor_op_64x32x16_32x16x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 32, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, tensor_op_64x64x16_32x32x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, tensor_op_128x64x16_64x32x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using InstructionShape = 
cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, tensor_op_64x128x16_32x64x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, tensor_op_128x128x16_64x64x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 48); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_256x256x48_128x128x16_64x64x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 48); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ 
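-+  // The threadblock kernel in the testbed stores the raw accumulator (there
-+  // is no epilogue), so alpha and beta only parameterize the host-side
-+  // reference GEMM; the check relies on alpha == 1 and beta == 0 here.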
float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_512x256x192_256x128x16_64x64x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 192); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed.h b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed.h -new file mode 100644 -index 0000000..6f36b53 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed.h -@@ -0,0 +1,355 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit testbed for kernel-level GEMM -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/platform/platform.h" -+ -+namespace test { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void kernel_mma(cutlass::gemm::GemmCoord problem_size, -+ typename Mma::IteratorA::Params params_A, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::Params params_B, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Mma::ElementC *ptr_C, -+ typename Mma::LayoutC::Stride::Index ldc) { -+ // Shared storage needed by threadblock-scoped matrix multiply-accumulate -+ __shared__ typename Mma::SharedStorage shared_storage; -+ -+ // Compute threadblock location -+ cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), -+ 0}; -+ -+ cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, -+ tb_tile_offset.k()}; -+ -+ cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), -+ tb_tile_offset.n() * Mma::Shape::kN}; -+ -+ // Compute position within threadblock -+ int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A(params_A, ref_A.data(), -+ {problem_size.m(), problem_size.k()}, -+ tb_thread_id, tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B(params_B, ref_B.data(), -+ {problem_size.k(), problem_size.n()}, -+ tb_thread_id, tb_offset_B); -+ -+ int warp_id = threadIdx.y; -+ int lane_id = threadIdx.x; -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); -+ -+ typename Mma::FragmentC accum; -+ -+ accum.clear(); -+ -+ int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, accum, iterator_A, iterator_B, 
accum); -+ -+ // Output results -+ typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, lane_id); -+ -+ iterator_C.add_tile_offset( -+ {(tb_tile_offset.m() * Mma::WarpCount::kM) + -+ (warp_id % Mma::WarpCount::kM), -+ (tb_tile_offset.n() * Mma::WarpCount::kN) + -+ (warp_id / Mma::WarpCount::kM)}); -+ -+ iterator_C.store(accum); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Threadblock-level matrix multiply-accumulate -+ typename MmaCore_, -+ /// Number of stages -+ int Stages = 2> -+struct Testbed { -+ /// Threadblock-level GEMM implementation -+ using MmaCore = MmaCore_; -+ using ThreadblockShape = typename MmaCore::Shape; -+ using WarpShape = typename MmaCore::WarpShape; -+ using InstructionShape = typename MmaCore::InstructionShape; -+ using ElementA = typename MmaCore::ElementA; -+ using LayoutA = typename MmaCore::LayoutA; -+ using ElementB = typename MmaCore::ElementB; -+ using LayoutB = typename MmaCore::LayoutB; -+ using ElementC = typename MmaCore::ElementC; -+ using LayoutC = typename MmaCore::LayoutC; -+ static const int kStages = Stages; -+ -+ // Define iterators over tiles from the A operand -+ static const bool use_idp4a = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value; -+ -+ static const bool transposeA = cutlass::platform::is_same< LayoutA, cutlass::layout::ColumnMajor >::value; -+ static const bool transposeB = cutlass::platform::is_same< LayoutB, cutlass::layout::RowMajor >::value; -+ -+ using IteratorA = typename cutlass::platform::conditional< use_idp4a, -+ cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA> , -+ -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA> -+ >::type; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = typename cutlass::platform::conditional< use_idp4a, -+ cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB> , -+ -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB> -+ >::type; -+ -+ // Define MmaPipeline Single Stage -+ using MmaPipelineSingleStage = cutlass::gemm::threadblock::MmaSingleStage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementC, LayoutC, -+ typename MmaCore::MmaPolicy>; -+ -+ // Define MmaPipeline Two Stages -+ using MmaPipelineTwoStages = cutlass::gemm::threadblock::MmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementC, LayoutC, -+ typename MmaCore::MmaPolicy>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply (Select between Single vs. 
Two stages) -+ using Mma = typename cutlass::platform::conditional<(kStages==1), MmaPipelineSingleStage, MmaPipelineTwoStages>::type; -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor matrix_A; -+ cutlass::HostTensor matrix_B; -+ cutlass::HostTensor matrix_C_computed; -+ cutlass::HostTensor matrix_C_reference; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ float alpha, beta; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ Testbed(int m, int n, int k, float alpha_, float beta_) -+ : problem_size(m, n, k), alpha(alpha_), beta(beta_) { -+ matrix_A.reset(cutlass::make_Coord(m, k)); -+ matrix_B.reset(cutlass::make_Coord(k, n)); -+ matrix_C_computed.reset(cutlass::make_Coord(m, n)); -+ matrix_C_reference.reset(cutlass::make_Coord(m, n), false); -+ } -+ -+ bool sufficient() { -+ return true; -+ } -+ -+ /// Runs the test -+ bool run( -+ dim3 grid, dim3 block, -+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ -+ // -+ // initialize device memory -+ // -+ -+ if (init_A == cutlass::Distribution::Uniform) { -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_A.host_view(), seed, scope_max, scope_min, 0); -+ } else if (init_A == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), -+ matrix_A.capacity()); -+ } else if (init_A == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ if (init_B == cutlass::Distribution::Uniform) { -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); -+ } else if (init_B == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), -+ matrix_B.capacity()); -+ } else if (init_B == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); -+ -+ cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); -+ -+ matrix_A.sync_device(); -+ matrix_B.sync_device(); -+ matrix_C_computed.sync_device(); -+ -+ typename IteratorA::Params params_A(matrix_A.layout()); -+ typename IteratorB::Params params_B(matrix_B.layout()); -+ -+ test::gemm::threadblock::kernel_mma<<>>( -+ problem_size, params_A, matrix_A.device_ref(), params_B, -+ matrix_B.device_ref(), matrix_C_computed.device_data(), -+ matrix_C_computed.layout().stride(0)); -+ -+ // -+ // Check error code -+ // -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) -+ << " 
kernel error: " << cudaGetErrorString(result) << " on device " << GetCudaDevice(); -+ -+ matrix_C_computed.sync_host(); -+ -+ cutlass::reference::host::Gemm -+ reference_gemm; -+ -+ reference_gemm( -+ problem_size, ElementC(alpha), matrix_A.host_view(), -+ matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ matrix_C_computed.host_view(), matrix_C_reference.host_view()); -+ -+ EXPECT_TRUE(passed) << "Failed on device " << GetCudaDevice(); -+ -+ if (!passed) { -+ std::ofstream output("mma_pipelined_testbed_errors.txt"); -+ -+ output -+ << "A:\n" << matrix_A.host_view() << "\n" -+ << "B:\n" << matrix_B.host_view() << "\n" -+ << "Reference:\n" -+ << matrix_C_reference.host_view() << "\n" -+ << "Computed:\n" -+ << matrix_C_computed.host_view() << "\n"; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h -new file mode 100644 -index 0000000..9e8d351 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h -@@ -0,0 +1,372 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ \brief Unit testbed for kernel-level GEMM -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/platform/platform.h" -+ -+namespace test { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void kernel_mma(cutlass::gemm::GemmCoord problem_size, -+ typename Mma::IteratorA::Params params_A, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::Params params_B, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Mma::ElementC **ptr_C, -+ typename Mma::LayoutC::Stride::Index ldc) { -+ // Shared storage needed by threadblock-scoped matrix multiply-accumulate -+ __shared__ typename Mma::SharedStorage shared_storage; -+ -+ // Compute threadblock location -+ cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), -+ 0}; -+ -+ cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, -+ tb_tile_offset.k()}; -+ -+ cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), -+ tb_tile_offset.n() * Mma::Shape::kN}; -+ -+ // Compute position within threadblock -+ int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A(params_A, ref_A.data(), -+ {problem_size.m(), problem_size.k()}, -+ tb_thread_id, tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B(params_B, ref_B.data(), -+ {problem_size.k(), problem_size.n()}, -+ tb_thread_id, tb_offset_B); -+ -+ int warp_id = threadIdx.y; -+ int lane_id = threadIdx.x; -+ -+ int partitionsK_idx = warp_id / (Mma::WarpCount::kM * Mma::WarpCount::kN); -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); -+ -+ typename Mma::FragmentC accum; -+ -+ accum.clear(); -+ -+ int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); -+ -+ // Output results -+ typename Mma::Operator::IteratorC iterator_C({ptr_C[partitionsK_idx], ldc}, lane_id); -+ -+ -+ int warp_idx_mn = warp_id % (Mma::WarpCount::kM * Mma::WarpCount::kN); -+ iterator_C.add_tile_offset( -+ {(tb_tile_offset.m() * Mma::WarpCount::kM) + -+ (warp_idx_mn % Mma::WarpCount::kM), -+ (tb_tile_offset.n() * Mma::WarpCount::kN) + -+ (warp_idx_mn / Mma::WarpCount::kM)}); -+ -+ iterator_C.store(accum); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the 
matrix product -+template < -+ /// Threadblock-level matrix multiply-accumulate -+ typename MmaCore_> -+struct Testbed { -+ /// Threadblock-level GEMM implementation -+ using MmaCore = MmaCore_; -+ using ThreadblockShape = typename MmaCore::Shape; -+ using WarpShape = typename MmaCore::WarpShape; -+ using InstructionShape = typename MmaCore::InstructionShape; -+ using ElementA = typename MmaCore::ElementA; -+ using LayoutA = typename MmaCore::LayoutA; -+ using ElementB = typename MmaCore::ElementB; -+ using LayoutB = typename MmaCore::LayoutB; -+ using ElementC = typename MmaCore::ElementC; -+ using LayoutC = typename MmaCore::LayoutC; -+ -+ // Define iterators over tiles from the A operand -+ static const bool use_idp4a = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value; -+ -+ static const bool transposeA = cutlass::platform::is_same< LayoutA, cutlass::layout::ColumnMajor >::value; -+ static const bool transposeB = cutlass::platform::is_same< LayoutB, cutlass::layout::RowMajor >::value; -+ -+ using IteratorA = typename cutlass::platform::conditional< use_idp4a, -+ cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA> , -+ -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA> -+ >::type; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = typename cutlass::platform::conditional< use_idp4a, -+ cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB> , -+ -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB> -+ >::type; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using Mma = cutlass::gemm::threadblock::MmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementC, LayoutC, -+ typename MmaCore::MmaPolicy>; -+ -+ static int const kPartitionsK = MmaCore::MmaPolicy::kPartitionsK; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor matrix_A; -+ cutlass::HostTensor matrix_B; -+ cutlass::HostTensor matrix_C_computed[kPartitionsK]; -+ cutlass::HostTensor matrix_C_reference; -+ cutlass::HostTensor matrix_C_pointers; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ float alpha, beta; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ Testbed(int m, int n, int k, float alpha_, float beta_) -+ : problem_size(m, n, k), alpha(alpha_), beta(beta_) { -+ matrix_A.reset(cutlass::make_Coord(m, k)); -+ matrix_B.reset(cutlass::make_Coord(k, n)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 0; k < kPartitionsK; k++) -+ matrix_C_computed[k].reset(cutlass::make_Coord(m, n)); -+ -+ matrix_C_reference.reset(cutlass::make_Coord(m, n), false); -+ matrix_C_pointers.reset(cutlass::Coord<1>(kPartitionsK)); -+ } -+ -+ /// Runs the test -+ bool run( -+ dim3 grid, dim3 block, -+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { -+ // -+ // initialize device memory -+ // -+ -+ if (init_A == cutlass::Distribution::Uniform) { -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if 
(cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_A.host_view(), seed, scope_max, scope_min, 0); -+ } else if (init_A == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), -+ matrix_A.capacity()); -+ } else if (init_A == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ if (init_B == cutlass::Distribution::Uniform) { -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); -+ } else if (init_B == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), -+ matrix_B.capacity()); -+ } else if (init_B == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 0; k < kPartitionsK; k++) -+ cutlass::reference::host::TensorFill(matrix_C_computed[k].host_view()); -+ -+ cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); -+ -+ matrix_A.sync_device(); -+ matrix_B.sync_device(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 0; k < kPartitionsK; k++) -+ matrix_C_computed[k].sync_device(); -+ -+ typename IteratorA::Params params_A(matrix_A.layout()); -+ typename IteratorB::Params params_B(matrix_B.layout()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 0; k < kPartitionsK; k++) -+ matrix_C_pointers.at(cutlass::Coord<1>(k)) = matrix_C_computed[k].device_data(); -+ -+ matrix_C_pointers.sync_device(); -+ -+ test::gemm::threadblock::kernel_mma<<>>( -+ problem_size, params_A, matrix_A.device_ref(), params_B, -+ matrix_B.device_ref(), matrix_C_pointers.device_data(), -+ matrix_C_computed[0].layout().stride(0)); -+ -+ // -+ // Check error code -+ // -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) -+ << " kernel error: " << cudaGetErrorString(result); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 0; k < kPartitionsK; k++) -+ matrix_C_computed[k].sync_host(); -+ -+ // TODO: this is temporary. 
it will be removed after slicing can de -+ // reduction -+ // -+ // Reduce matrix_C_computed -+ // -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 1; k < kPartitionsK; k++) { -+ CUTLASS_PRAGMA_UNROLL -+ for(int m = 0; m < matrix_C_computed[0].extent().row(); m++){ -+ CUTLASS_PRAGMA_UNROLL -+ for(int n = 0; n < matrix_C_computed[0].extent().column(); n++){ -+ matrix_C_computed[0].at({m, n}) += matrix_C_computed[k].at({m, n}); -+ } -+ } -+ } -+ -+ cutlass::reference::host::Gemm -+ reference_gemm; -+ -+ reference_gemm( -+ problem_size, ElementC(alpha), matrix_A.host_view(), -+ matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ matrix_C_computed[0].host_view(), matrix_C_reference.host_view()); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::ofstream output("mma_pipelined_testbed_errors.txt"); -+ -+ output -+ << "A:\n" << matrix_A.host_view() << "\n" -+ << "B:\n" << matrix_B.host_view() << "\n" -+ << "Reference:\n" -+ << matrix_C_reference.host_view() << "\n" -+ << "Computed:\n" -+ << matrix_C_computed[0].host_view() << "\n"; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_wmma_sm70.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_wmma_sm70.cu -new file mode 100644 -index 0000000..28dc1c8 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_wmma_sm70.cu -@@ -0,0 +1,766 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Unit tests for thread-level GEMM -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include "mma_pipelined_testbed.h" -+#include "cutlass/gemm/threadblock/default_mma_core_wmma.h" -+ -+/// All tests use double-buffered (kStages=2) mma pipeline for the gemm mainloop -+/// Test name format: SM[arch]_gemm_threadblock_wmma_tensor_op_[alayout]_[blayout]_[clayout]_[dtype].[threadblock_shape]_[warp_shape] -+ -+//////////////// [START] Verifying all layouts {N,T}x{N,T}=>{N,T} for WMMA 16x16x16 [START] ////////////////////// -+ -+/////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16 (wmma native size 16x16x16) -+//////////////////////////////////////////////////////////// -+ -+// tests for {N,T}x{N,T}=>{T} -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_col_row_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.col.row.m16n16k16.f16.f16 (wmma native size 16x16x16) -+/////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_col_row_row_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM70_gemm_threadblock_wmma_tensor_op_col_row_row_f16, 128x128x32_64x64x32_16x16x16) { -+ -+ 
using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.row.m16n16k16.f16.f16 (wmma native size 16x16x16) -+/////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_row_row_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_row_row_f16, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 96); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), 
-+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.col.col.m16n16k16.f16.f16 (wmma native size 16x16x16) -+/////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_col_col_row_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM70_gemm_threadblock_wmma_tensor_op_col_col_row_f16, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 96); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+// tests for {N,T}x{N,T}=>{N} -+/////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16 (wmma native size 16x16x16) -+//////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_col_col_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = 
cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.col.row.m16n16k16.f16.f16 (wmma native size 16x16x16) -+/////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_col_row_col_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.row.m16n16k16.f16.f16 (wmma native size 16x16x16) -+/////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_row_col_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////// -+/// 
wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.col.col.m16n16k16.f16.f16 (wmma native size 16x16x16) -+/////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_col_col_col_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////// [END] Verifying all layouts {N,T}x{N,T}=>{N,T} for WMMA 16x16x16 [END] ////////////////////// -+ -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_col_row_f16, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+ -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_col_row_f16, multicta_256x256x96_128x128x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 96); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ 
cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m32n8k16.f16.f16 (wmma native size 32x8x16) -+/////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_col_row_f16, 64x64x32_64x64x32_32x8x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m8n32k16.f16.f16 (wmma native size 8x32x16) -+////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_col_row_f16, 64x64x32_64x64x32_8x32x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+////////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m16n16k16.f32.f32 (wmma native size 16x16x16) -+////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_col_row_f32, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ 
using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_col_row_f32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_col_row_f32, multicta_256x256x96_128x128x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 96); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+/////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m32n8k16.f32.f32 (wmma native size 32x8x16) -+//////////////////////////////////////////////////////////// 
-+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_col_row_f32, 64x64x32_64x64x32_32x8x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m8n32k16.f32.f32 (wmma native size 8x32x16) -+///////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_col_row_f32, 64x64x32_64x64x32_8x32x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_wmma_sm75.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_wmma_sm75.cu -new file mode 100644 -index 0000000..12fae1f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_wmma_sm75.cu -@@ -0,0 +1,337 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM75_ENABLED -+#include "mma_pipelined_testbed.h" -+#include "cutlass/gemm/threadblock/default_mma_core_wmma.h" -+ -+/// All tests use double-buffered (kStages=2) mma pipeline for the gemm mainloop -+/// Test name format: SM[arch]_gemm_threadblock_wmma_tensor_op_[alayout]_[blayout]_[clayout]_[atype].[threadblock_shape]_[warp_shape]_[instruction_shape] -+ -+///////////////////////////////////////////////////////////////////////// -+/// Integer (s8 and u8) WMMA threadblock level tests ///// -+///////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_INTEGER_MATRIX_MULTIPLY_ENABLED) -+TEST(SM75_gemm_threadblock_wmma_tensor_op_row_col_row_s8, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM75_gemm_threadblock_wmma_tensor_op_row_col_row_s8, 64x64x64_64x64x64_16x16x16) { -+ -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int 
kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+ -+TEST(SM75_gemm_threadblock_wmma_tensor_op_col_row_row_s8, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM75_gemm_threadblock_wmma_tensor_op_col_row_row_s8, 64x64x64_64x64x64_16x16x16) { -+ -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+#endif //CUTLASS_ARCH_INTEGER_MATRIX_MULTIPLY_ENABLED -+ -+ -+//////////////////////////////////////////////////////////////////////// -+/// SUBBYTE (s4 and b1) WMMA threadblock level tests //// -+/////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED) -+ -+TEST(SM75_gemm_threadblock_wmma_tensor_op_row_col_row_s4, 64x64x128_64x64x128_8x8x32) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ 
using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+ -+TEST(SM75_gemm_threadblock_wmma_tensor_op_row_col_col_s4, 64x64x64_64x64x64_8x8x32) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM75_gemm_threadblock_wmma_tensor_op_row_col_row_b1, 64x64x512_64x64x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM75_gemm_threadblock_wmma_tensor_op_row_col_col_b1, 64x64x512_64x64x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ 
static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+#endif //CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED -+ -+#endif //CUTLASS_ARCH_WMMA_SM75_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_planar_complex_sm80.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_planar_complex_sm80.cu -new file mode 100644 -index 0000000..0b6dc11 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_planar_complex_sm80.cu -@@ -0,0 +1,79 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Unit tests for threadblock-level GEMM -+*/ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/threadblock/default_mma_planar_complex_multistage.h" -+ -+#include "mma_planar_complex_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_planar_complex_congruous, tensor_op_64x64x32_64x64x32_16x8x16_3stage) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 8); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using Mma = typename cutlass::gemm::threadblock::DefaultMmaPlanarComplexMultistage< -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages>::ThreadblockMma; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, Mma::WarpCount::kCount, 1); -+ -+ test::gemm::threadblock::TestbedPlanarComplex(problem_size.m(), problem_size.n(), -+ problem_size.k()) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_planar_complex_testbed.h b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_planar_complex_testbed.h -new file mode 100644 -index 0000000..b33abdb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_planar_complex_testbed.h -@@ -0,0 +1,352 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit testbed for kernel-level GEMM -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor_planar_complex.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/gemm_planar_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void kernel_mma_planar_complex( -+ cutlass::gemm::GemmCoord problem_size, -+ typename Mma::IteratorA::Params params_A, -+ typename Mma::IteratorA::Element *ptr_A, -+ int64_t imaginary_stride_A, -+ typename Mma::IteratorB::Params params_B, -+ typename Mma::IteratorB::Element *ptr_B, -+ int64_t imaginary_stride_B, -+ typename Mma::ElementC *ptr_C, -+ typename Mma::LayoutC::Stride::Index ldc, int64_t imaginary_stride_C) { -+ -+ // Shared storage needed by threadblock-scoped matrix multiply-accumulate -+ __shared__ typename Mma::SharedStorage shared_storage; -+ -+ // Compute threadblock location -+ cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), -+ 0}; -+ -+ cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, -+ tb_tile_offset.k()}; -+ -+ cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), -+ tb_tile_offset.n() * Mma::Shape::kN}; -+ -+ // Compute position within threadblock -+ int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; -+ -+ // Construct iterators to A operand -+ typename Mma::IteratorA iterator_A_real(params_A, ptr_A, -+ {problem_size.m(), problem_size.k()}, -+ tb_thread_id, tb_offset_A); -+ -+ typename Mma::IteratorA iterator_A_imag(params_A, ptr_A + imaginary_stride_A, -+ {problem_size.m(), problem_size.k()}, -+ tb_thread_id, tb_offset_A); -+ -+ // Construct iterators to B operand -+ typename Mma::IteratorB iterator_B_real(params_B, ptr_B, -+ {problem_size.k(), problem_size.n()}, -+ tb_thread_id, tb_offset_B); -+ -+ typename Mma::IteratorB iterator_B_imag(params_B, ptr_B + imaginary_stride_B, -+ {problem_size.k(), problem_size.n()}, -+ tb_thread_id, tb_offset_B); -+ -+ int warp_id = threadIdx.y; -+ int lane_id = threadIdx.x; -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); -+ -+ typename 
Mma::FragmentC accum; -+ -+ accum.clear(); -+ -+ int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, accum, iterator_A_real, iterator_A_imag, iterator_B_real, iterator_B_imag, accum); -+ -+ // Output results -+ typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, lane_id); -+ -+ iterator_C.add_tile_offset( -+ {(tb_tile_offset.m() * Mma::WarpCount::kM) + -+ (warp_id % Mma::WarpCount::kM), -+ (tb_tile_offset.n() * Mma::WarpCount::kN) + -+ (warp_id / Mma::WarpCount::kM)}); -+ -+ iterator_C.store(accum.real); -+ -+ iterator_C.store_with_pointer_offset(accum.imag, imaginary_stride_C); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Threadblock-level matrix multiply-accumulate -+ typename Mma_> -+struct TestbedPlanarComplex { -+ -+ using Mma = Mma_; -+ using ThreadblockShape = typename Mma::Shape; -+ using IteratorA = typename Mma::IteratorA; -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using IteratorB = typename Mma::IteratorB; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Mma::ElementC; -+ using ElementAccumulator = typename Mma::ElementC; -+ using LayoutC = typename Mma::LayoutC; -+ using ThreadMapA = typename Mma::IteratorA::ThreadMap; -+ using ThreadMapB = typename Mma::IteratorB::ThreadMap; -+ using AccessTypeA = cutlass::Array; -+ using AccessTypeB = cutlass::Array; -+ static int const Stages = Mma::kStages; -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ Mma::kCacheOpA; -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ Mma::kCacheOpB; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensorPlanarComplex matrix_A; -+ cutlass::HostTensorPlanarComplex matrix_B; -+ cutlass::HostTensorPlanarComplex matrix_C_computed; -+ cutlass::HostTensorPlanarComplex matrix_C_reference; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ TestbedPlanarComplex(int m, int n, int k) -+ : problem_size(m, n, k) { -+ -+ matrix_A.reset(cutlass::make_Coord(m, k)); -+ matrix_B.reset(cutlass::make_Coord(k, n)); -+ matrix_C_computed.reset(cutlass::make_Coord(m, n)); -+ matrix_C_reference.reset(cutlass::make_Coord(m, n), false); -+ } -+ -+ /// Runs the test -+ bool run( -+ dim3 grid, dim3 block, -+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { -+ -+ // -+ // initialize device memory -+ // -+ -+ if (init_A == cutlass::Distribution::Uniform) { -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_A.host_view(), seed, scope_max, scope_min, 0); -+ -+ } else if (init_A == cutlass::Distribution::Sequential) { -+ -+ for (int i = 0; i < matrix_A.capacity() * 2; ++i) { -+ matrix_A.host_data()[i] = cutlass::half_t(float(i % 5) - 2); -+ } -+ /* -+ cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), -+ matrix_A.capacity() * 2); -+ */ -+ } else if (init_A == 
cutlass::Distribution::Identity) { -+ //cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ if (init_B == cutlass::Distribution::Uniform) { -+ -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); -+ -+ -+ } else if (init_B == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), -+ matrix_B.capacity() * 2); -+ -+ for (int i = 0; i < matrix_B.capacity() * 2; ++i) { -+ matrix_B.host_data()[i] = cutlass::half_t(float((i + 3) % 5) - 2); -+ } -+ -+ -+ } else if (init_B == cutlass::Distribution::Identity) { -+ -+ //cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); -+ -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ matrix_A.sync_device(); -+ matrix_B.sync_device(); -+ matrix_C_computed.sync_device(); -+ -+ typename IteratorA::Params params_A(matrix_A.layout()); -+ typename IteratorB::Params params_B(matrix_B.layout()); -+ -+ test::gemm::threadblock::kernel_mma_planar_complex<<>>( -+ problem_size, -+ params_A, -+ matrix_A.device_data(), -+ matrix_A.imaginary_stride(), -+ params_B, -+ matrix_B.device_data(), -+ matrix_B.imaginary_stride(), -+ matrix_C_computed.device_data(), -+ matrix_C_computed.layout().stride(0), -+ matrix_C_computed.imaginary_stride() -+ ); -+ -+ -+ // -+ // Check error code -+ // -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) -+ << " kernel error: " << cudaGetErrorString(result); -+ -+ matrix_C_computed.sync_host(); -+ -+ cutlass::reference::host::GemmPlanarComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator -+ >( -+ problem_size, -+ cutlass::complex(ElementAccumulator(1)), -+ matrix_A.host_ref(), -+ Mma::kTransformA, -+ matrix_B.host_ref(), -+ Mma::kTransformB, -+ cutlass::complex(ElementAccumulator(0)), -+ matrix_C_reference.host_ref(), -+ matrix_C_reference.host_ref() -+ ); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ matrix_C_computed.host_view(), -+ matrix_C_reference.host_view() -+ ); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::ofstream output("mma_pipelined_testbed_errors.txt"); -+ -+ output -+ << "A:\n" << matrix_A.host_view() << "\n" -+ << "B:\n" << matrix_B.host_view() << "\n" -+ << "Reference:\n" -+ << matrix_C_reference.host_view() << "\n" -+ << "Computed:\n" -+ << matrix_C_computed.host_view() << "\n"; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_singlestage_wmma_sm70.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_singlestage_wmma_sm70.cu -new file mode 100644 -index 0000000..06a3ebb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_singlestage_wmma_sm70.cu -@@ -0,0 +1,417 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include "mma_pipelined_testbed.h" -+#include "cutlass/gemm/threadblock/default_mma_core_wmma.h" -+ -+/// All tests use single staged (kStages=1) mma pipeline for the gemm mainloop -+/// Test name format: SM[arch]_gemm_threadblock_singlestage_wmma_[alayout]_[blayout]_[clayout]_[dtype].[threadblock_shape]_[warp_shape] -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+/// WMMA Floating point (f16 accumulation) - Single stage - Threadblock level tests //// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), 
-+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM70_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_f16, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+ -+TEST(SM70_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_f16, multicta_256x256x96_128x128x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 96); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m32n8k16.f16.f16 (wmma native size 32x8x16) -+/////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_f16, 64x64x32_64x64x32_32x8x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, 
-+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m8n32k16.f16.f16 (wmma native size 8x32x16) -+////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_f16, 64x64x32_64x64x32_8x32x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+/// WMMA Floating point (f32 accumulation) - Single stage - Threadblock level tests //// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+////////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m16n16k16.f32.f32 (wmma native size 16x16x16) -+////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_f32, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ 
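// A minimal sketch of the presumed testbed invocation used by each TEST in this
// file. The single-line template argument lists were lost in this listing (the
// calls read "Testbed(problem_size.m(), ...)" with no template parameters), so
// the exact spelling below is an assumption inferred from the MmaCore and
// kStages aliases defined inside every test and from the multi-line warp-level
// testbed instantiations that survive later in this patch:
//
//   test::gemm::threadblock::Testbed<MmaCore, kStages>(
//       problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta)
//       .run(grid, block);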
-+TEST(SM70_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_f32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM70_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_f32, multicta_256x256x96_128x128x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 96); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+/////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m32n8k16.f32.f32 (wmma native size 32x8x16) -+//////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_f32, 64x64x32_64x64x32_32x8x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 
block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m8n32k16.f32.f32 (wmma native size 8x32x16) -+///////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_f32, 64x64x32_64x64x32_8x32x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_singlestage_wmma_sm75.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_singlestage_wmma_sm75.cu -new file mode 100644 -index 0000000..1b24ebe ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_singlestage_wmma_sm75.cu -@@ -0,0 +1,337 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM75_ENABLED -+#include "mma_pipelined_testbed.h" -+#include "cutlass/gemm/threadblock/default_mma_core_wmma.h" -+ -+/// All tests use single staged (kStages=1) mma pipeline for the gemm mainloop -+/// Test name format: SM[arch]_gemm_threadblock_singlestage_wmma_tensor_op_[alayout]_[blayout]_[clayout]_[atype].[threadblock_shape]_[warp_shape]_[instruction_shape] -+ -+///////////////////////////////////////////////////////////////////////// -+/// Integer (s8 and u8) WMMA threadblock level tests //// -+///////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_INTEGER_MATRIX_MULTIPLY_ENABLED) -+TEST(SM75_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_s8, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM75_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_s8, 64x64x64_64x64x64_16x16x16) { -+ -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ 
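// The launch shape follows the warp layout of the threadblock: block.x is the
// 32-thread warp size and block.y is the number of warps per threadblock. Here
// WarpShape equals ThreadblockShape, so a single warp, block(32, 1, 1), covers
// the tile; the 128x128x32 threadblock tests elsewhere in this patch split the
// tile across four 64x64x32 warps and launch block(32, 4, 1).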
test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+ -+TEST(SM75_gemm_threadblock_singlestage_wmma_tensor_op_col_row_row_s8, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM75_gemm_threadblock_singlestage_wmma_tensor_op_col_row_row_s8, 64x64x64_64x64x64_16x16x16) { -+ -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+#endif //CUTLASS_ARCH_INTEGER_MATRIX_MULTIPLY_ENABLED -+ -+ -+//////////////////////////////////////////////////////////////////////// -+/// SUBBYTE (s4 and b1) WMMA threadblock level tests //// -+/////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED) -+ -+TEST(SM75_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_s4, 64x64x128_64x64x128_8x8x32) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, 
InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+ -+TEST(SM75_gemm_threadblock_singlestage_wmma_tensor_op_row_col_col_s4, 64x64x64_64x64x64_8x8x32) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM75_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_b1, 64x64x512_64x64x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM75_gemm_threadblock_singlestage_wmma_tensor_op_row_col_col_b1, 64x64x512_64x64x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ 
cutlass::arch::OpClassWmmaTensorOp, kStages, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+#endif //CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED -+ -+#endif //CUTLASS_ARCH_WMMA_SM75_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/gemm_complex_sm80.cu b/3rdparty/cutlass/test/unit/gemm/warp/gemm_complex_sm80.cu -new file mode 100644 -index 0000000..28410ad ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/gemm_complex_sm80.cu -@@ -0,0 +1,698 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "cutlass/cutlass.h" -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/gemm/warp/default_mma_complex_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// complex * complex => complex -+// Input data type: complex -+// Math instruction: mma.sync.aligned.m8n8k4.f64.f64.f64.f64 -+// Output data type: complex -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_warp_gemm_complex_tensor_op_f64, 8x8x4_8x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f64, 16x16x4_8x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f64, 16x32x4_8x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f64, 32x16x4_8x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename 
cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f64, 32x32x4_8x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f64, 32x32x4_8x8x4_nh) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kConjugate -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f64, 32x32x4_8x8x4_ct) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kNone -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f64, 8x8x4_8x8x4_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f64, 16x16x4_8x8x4_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; -+ using 
LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// complex * complex => complex -+// Input data type: complex -+// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -+// Output data type: complex -+// Shared memory layout: Congrous -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 16x16x8_16x8x8_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 16, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<16, 16, 8> >() -+ .run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 16x16x16_16x8x8_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<16, 16, 16> >() -+ .run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 16x32x8_16x8x8_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<16, 32, 8> >() -+ .run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 32x16x8_16x16x8_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 16, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = 
cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<32, 16, 8> >() -+ .run(); -+} -+ -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 32x32x8_16x8x8_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<32, 32, 8> >() -+ .run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 32x32x8_16x8x8_nh) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kConjugate -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<32, 32, 8> >() -+ .run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 32x32x8_16x8x8_ct) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kNone -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<32, 32, 8> >() -+ .run(); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// complex * complex => complex -+// Input data type: complex -+// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -+// Output data type: complex -+// Shared memory layout: Crosswise -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 16x16x8_16x8x8_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 16, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = 
cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<16, 16, 8> >() -+ .run(); -+} -+ -+// TEST FAILS crosswise complex TN mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 test fails for k = 2*8 = 16 -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 16x16x16_16x8x8_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<16, 16, 16> >() -+ .run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 32x32x8_16x8x8_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<32, 32, 8> >() -+ .run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 32x64x8_16x8x8_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<32, 64, 8> >() -+ .run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 64x32x8_16x8x8_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ 
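// Note on element types: the bare "cutlass::complex" spellings in this listing
// lost their template arguments. Given the m16n8k8 f32.tf32.tf32.f32 instruction
// named in the section header and the *_f32 test suite name, Element here is
// presumably cutlass::complex<float>, while the *_f64 tests above pair
// cutlass::complex<double> with the m8n8k4 f64 instruction.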
test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<64, 32, 8> >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f64, 32x32x8_8x8x4_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<32, 32, 8> >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f64, 32x32x8_8x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<32, 32, 8> >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/gemm_complex_sm90.cu b/3rdparty/cutlass/test/unit/gemm/warp/gemm_complex_sm90.cu -new file mode 100644 -index 0000000..38bdfa6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/gemm_complex_sm90.cu -@@ -0,0 +1,334 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Unit tests for thread-level GEMM with Hopper FP64 -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/gemm/warp/default_mma_complex_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+TEST(SM90_warp_gemm_complex_tensor_op_f64, 16x8x4_16x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 8, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM90_warp_gemm_complex_tensor_op_f64, 16x16x4_16x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM90_warp_gemm_complex_tensor_op_f64, 16x32x4_16x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM90_warp_gemm_complex_tensor_op_f64, 32x16x4_16x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 16, 4>; -+ using InstructionShape = 
cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM90_warp_gemm_complex_tensor_op_f64, 32x32x4_16x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM90_warp_gemm_complex_tensor_op_f64, 32x32x4_16x8x4_nh) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kConjugate -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM90_warp_gemm_complex_tensor_op_f64, 32x32x4_16x8x4_ct) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kNone -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM90_warp_gemm_complex_tensor_op_f64, 16x8x4_16x8x4_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 8, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ 
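// For reference, the complex tensor op decomposes each complex multiply-add over
// planar real/imaginary parts. With a = ar + i*ai and b = br + i*bi, the default
// multiply-add path accumulates
//     c.real() += ar*br - ai*bi;
//     c.imag() += ar*bi + ai*br;
// which costs four real products per complex product. The Gaussian-complex tests
// later in this patch (cutlass::arch::OpMultiplyAddGaussianComplex) presumably
// apply Gauss's three-multiplication identity, up to the implementation's exact
// partial-term bookkeeping:
//     k1 = br*(ar + ai);  k2 = ar*(bi - br);  k3 = ai*(br + bi);
//     c.real() += k1 - k3;  c.imag() += k1 + k2;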
-+TEST(SM90_warp_gemm_complex_tensor_op_f64, 16x16x4_16x8x4_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM90_warp_gemm_complex_tensor_op_f64, 32x32x16_16x8x4_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex().run(); -+} -+ -+TEST(SM90_warp_gemm_complex_tensor_op_f64, 64x64x4_16x8x4_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex().run(); -+} -+ -+#endif // if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/gemm_gaussian_complex_sm80.cu b/3rdparty/cutlass/test/unit/gemm/warp/gemm_gaussian_complex_sm80.cu -new file mode 100644 -index 0000000..e6f71ce ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/gemm_gaussian_complex_sm80.cu -@@ -0,0 +1,287 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "cutlass/cutlass.h" -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/gemm/warp/default_mma_complex_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 8x8x4_8x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 16x16x4_8x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+ -+TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 16x32x4_8x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using 
ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 32x16x4_8x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 32x32x4_8x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 32x32x4_8x8x4_nh) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 32x32x4_8x8x4_ct) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename 
cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 16x16x4_8x8x4_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm50.cu b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm50.cu -new file mode 100644 -index 0000000..5a9e2e2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm50.cu -@@ -0,0 +1,654 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/quaternion.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma_simt.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// NT SMEM layout -+TEST(SM50_warp_gemm_f32_col_row_col, 32x16x1_4x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+// TN SMEM layout -+TEST(SM50_warp_gemm_f32_row_col_col, 32x16x1_4x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+// TT SMEM layout -+TEST(SM50_warp_gemm_f32_row_row_col, 32x16x1_4x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+// NN SMEM layout -+TEST(SM50_warp_gemm_f32_col_col_col, 32x16x1_4x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// NT SMEM layout -+TEST(SM50_warp_gemm_f32_col_row_row, 16x32x1_4x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<16, 32, 8>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+// TN SMEM layout -+TEST(SM50_warp_gemm_f32_row_col_row, 16x32x1_4x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< 
-+ cutlass::gemm::GemmShape<16, 32, 8>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// NT SMEM layout -+TEST(SM50_warp_gemm_f32_col_row_col, 32x16x1_2x2x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+TEST(SM50_warp_gemm_f32_col_row_row, 32x16x1_2x2x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+// TN SMEM layout -+TEST(SM50_warp_gemm_f32_row_col_col, 32x16x1_2x2x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+TEST(SM50_warp_gemm_f32_row_col_row, 32x16x1_2x2x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// NT SMEM layout -+TEST(SM50_warp_gemm_f32_col_row_col, 32x64x1_4x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+TEST(SM50_warp_gemm_f32_col_row_row, 32x64x1_4x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+// TN 
SMEM layout -+TEST(SM50_warp_gemm_f32_row_col_col, 32x64x1_4x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+TEST(SM50_warp_gemm_f32_row_col_row, 32x64x1_4x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_warp_gemm_complex_f32_col_row_col, 64x32x1_2x2x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ >; -+ -+ using complex_f32_t = cutlass::complex; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ complex_f32_t, -+ cutlass::layout::ColumnMajor, -+ complex_f32_t, -+ cutlass::layout::RowMajor, -+ complex_f32_t, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+TEST(SM50_warp_gemm_complex_f32_col_row_row, 64x32x1_2x2x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ >; -+ -+ using complex_f32_t = cutlass::complex; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ complex_f32_t, -+ cutlass::layout::ColumnMajor, -+ complex_f32_t, -+ cutlass::layout::RowMajor, -+ complex_f32_t, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_warp_gemm_f64_col_row_col, 8x4x1_1x1x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<8, 4, 8>, -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+TEST(SM50_warp_gemm_f64_col_row_row, 8x4x1_1x1x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<8, 4, 8>, -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ 
-+TEST(SM50_warp_gemm_f64_col_row_col, 32x16x1_2x2x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+TEST(SM50_warp_gemm_f64_col_row_row, 32x16x1_2x2x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_warp_gemm_f64_col_row_col, 64x32x1_2x2x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+TEST(SM50_warp_gemm_f64_col_row_row, 64x32x1_2x2x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_warp_gemm_complex_f64_col_row_col, 32x16x1_1x1x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ >; -+ -+ using complex_f64_t = cutlass::complex; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<64, 16, 8>, -+ complex_f64_t, -+ cutlass::layout::ColumnMajor, -+ complex_f64_t, -+ cutlass::layout::RowMajor, -+ complex_f64_t, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+TEST(SM50_warp_gemm_complex_f64_col_row_row, 32x16x1_1x1x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ >; -+ -+ using complex_f64_t = cutlass::complex; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<64, 16, 8>, -+ complex_f64_t, -+ cutlass::layout::ColumnMajor, -+ complex_f64_t, -+ cutlass::layout::RowMajor, -+ complex_f64_t, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ 
-+TEST(SM50_warp_gemm_quaternion_f32_col_row_col, 16x8x8_1x1x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ >; -+ -+ using quaternion_f32_t = cutlass::Quaternion; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ quaternion_f32_t, -+ cutlass::layout::ColumnMajor, -+ quaternion_f32_t, -+ cutlass::layout::RowMajor, -+ quaternion_f32_t, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_warp_gemm_quaternion_f32_col_row_row, 16x8x8_1x1x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ >; -+ -+ using quaternion_f32_t = cutlass::Quaternion; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ quaternion_f32_t, -+ cutlass::layout::ColumnMajor, -+ quaternion_f32_t, -+ cutlass::layout::RowMajor, -+ quaternion_f32_t, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm60.cu b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm60.cu -new file mode 100644 -index 0000000..03ba3ea ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm60.cu -@@ -0,0 +1,140 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma_simt.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM60_warp_gemm_f16_col_row, 8x4x1_1x1x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<8, 4, 8>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM60_warp_gemm_f16_col_row, 16x8x1_2x2x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM60_warp_gemm_f16_col_row, 32x16x1_4x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM60_warp_gemm_f16_col_row, 64x16x1_8x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<8, 8, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm61.cu b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm61.cu -new file mode 100644 -index 0000000..c042b5b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm61.cu -@@ -0,0 +1,198 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma_simt.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM61_warp_gemm_int8_col_row, col_row_8x4x8_1x1x4) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 4> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<8, 4, 8>, -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<4>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<4>, -+ int, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM61_warp_gemm_int8_col_row, col_row_8x4x4_1x1x4) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 4> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<8, 4, 8>, -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<4>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<4>, -+ int, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM61_warp_gemm_int8_col_row, col_row_16x4x4_2x1x4) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 1, 4> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<16, 4, 4>, -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<4>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<4>, -+ int, -+ 
cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM61_warp_gemm_int8_col_row, col_row_16x4x4_2x2x4) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 4> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<4>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<4>, -+ int, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM61_warp_gemm_int8_col_row, col_row_32x16x4_4x4x4) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 4> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 16>, -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<4>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<4>, -+ int, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+ -+TEST(SM61_warp_gemm_int8_col_row, col_row_128x64x4_16x16x4) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<16, 16, 4> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<128, 64, 4>, -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<4>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<4>, -+ int, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM61_warp_gemm_int8_col_row, col_row_64x64x4_4x4x4) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 4> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<4>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<4>, -+ int, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm70.cu b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm70.cu -new file mode 100644 -index 0000000..6785ddb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm70.cu -@@ -0,0 +1,295 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_sm70.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_warp_gemm_tensor_op_congruous, 128x128x16_64x64x16_16x16x4) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM70_warp_gemm_tensor_op_congruous, 128x64x4_64x64x4_16x16x4) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 
1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM70_warp_gemm_tensor_op_congruous, 128x128x4_32x32x4_16x16x4) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM70_warp_gemm_tensor_op_crosswise, 64x64x32_64x64x32_16x16x4) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::RowMajor, -+ ElementB, -+ cutlass::layout::ColumnMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_warp_gemm_volta_tensor_op_canonical_f32_row_col, 64x64x16_64x64x4_8x8x4) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::RowMajor, -+ ElementB, -+ cutlass::layout::ColumnMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+TEST(SM70_warp_gemm_volta_tensor_op_canonical_f32_col_row, 64x64x16_64x64x4_8x8x4) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ 
using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM70_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm75.cu b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm75.cu -new file mode 100644 -index 0000000..43f185d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm75.cu -@@ -0,0 +1,860 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_congruous_f16, 128x128x8_32x128x8_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_congruous_f16, 128x128x32_64x64x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_congruous_f16, 128x128x32_32x32x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_f16, 128x128x32_64x64x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = 
cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_f16, 128x128x32_64x32x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_f16, 128x128x32_32x32x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_f16, 128x128x32_32x16x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_f16, 128x128x32_16x16x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ 
test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_f16, 128x128x64_64x64x64_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_f16, 128x128x64_64x32x64_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_f16, 128x128x64_32x32x64_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_f16, 128x128x64_32x16x64_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_f16, 128x128x64_16x16x64_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 64>; -+ using InstructionShape = 
cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_i8, 128x128x64_64x64x64_8x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_i8, 128x128x64_64x32x64_8x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_i8, 128x128x64_32x32x64_8x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_i8, 128x128x64_32x16x64_8x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = 
cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_i8, 128x128x64_16x16x64_8x8x16) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_interleaved_i8, 128x128x64_64x64x64_8x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_interleaved_i8, 128x128x64_64x32x64_8x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_interleaved_i8, 128x128x64_32x32x64_8x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, 
InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_interleaved_i8, 128x128x64_32x16x64_8x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_interleaved_i8, 128x128x64_16x16x64_8x8x16) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_i4, 128x128x128_64x64x128_8x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_i4, 128x128x128_64x32x128_8x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ 
test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_i4, 128x128x128_32x32x128_8x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_i4, 128x128x128_32x16x128_8x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_i4, 128x128x128_16x16x128_8x8x32) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_interleaved_i4, 128x128x128_64x64x128_8x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ 
-+TEST(SM75_warp_gemm_tensor_op_interleaved_i4, 128x128x128_64x32x128_8x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_interleaved_i4, 128x128x128_32x32x128_8x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_interleaved_i4, 128x128x128_32x16x128_8x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_interleaved_i4, 128x128x128_16x16x128_8x8x32) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_b1, 128x128x512_64x64x512_8x8x128) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 512>; -+ 
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpXorPopc>::Type; -+ -+ test::gemm::warp::Testbed, -+ cutlass::arch::OpXorPopc>() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_b1, 128x128x512_64x32x512_8x8x128) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpXorPopc>::Type; -+ -+ test::gemm::warp::Testbed, -+ cutlass::arch::OpXorPopc>() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_b1, 128x128x512_32x32x512_8x8x128) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpXorPopc>::Type; -+ -+ test::gemm::warp::Testbed, -+ cutlass::arch::OpXorPopc>() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_b1, 128x128x512_32x16x512_8x8x128) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpXorPopc>::Type; -+ -+ test::gemm::warp::Testbed, -+ cutlass::arch::OpXorPopc>() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_b1, 128x128x512_16x16x512_8x8x128) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; 
-+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpXorPopc>::Type; -+ -+ test::gemm::warp::Testbed, -+ cutlass::arch::OpXorPopc>() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm80.cu b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm80.cu -new file mode 100644 -index 0000000..54a0248 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm80.cu -@@ -0,0 +1,1865 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file
-+
-+  \brief Unit tests for thread-level GEMM
-+*/
-+
-+#include "../../common/cutlass_unit_test.h"
-+
-+#include "cutlass/aligned_buffer.h"
-+#include "cutlass/half.h"
-+
-+#include "cutlass/gemm/warp/default_mma_tensor_op.h"
-+
-+#include "cutlass/core_io.h"
-+#include "cutlass/util/host_tensor.h"
-+#include "cutlass/util/tensor_view_io.h"
-+
-+#include "cutlass/util/reference/host/tensor_fill.h"
-+#include "cutlass/util/reference/host/tensor_compare.h"
-+#include "cutlass/util/reference/host/gemm.h"
-+
-+#include "testbed.h"
-+
-+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x32_64x64x32_16x8x16) {
-+  using Shape = cutlass::gemm::GemmShape<64, 64, 32>;
-+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
-+  using Element = cutlass::half_t;
-+  using ElementC = float;
-+  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
-+      cutlass::sizeof_bits<Element>::value, 32>;
-+  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
-+      cutlass::sizeof_bits<Element>::value, 32>;
-+
-+  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
-+      Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
-+      cutlass::layout::RowMajor>::Type;
-+
-+  test::gemm::warp::Testbed<MmaTensorOp, cutlass::gemm::GemmShape<128, 128, 32> >()
-+      .run();
-+}
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x32_64x32x32_16x8x16) {
-+  using Shape = cutlass::gemm::GemmShape<64, 32, 32>;
-+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
-+  using Element = cutlass::half_t;
-+  using ElementC = float;
-+  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
-+      cutlass::sizeof_bits<Element>::value, 32>;
-+  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
-+      cutlass::sizeof_bits<Element>::value, 32>;
-+
-+  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
-+      Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
-+      cutlass::layout::RowMajor>::Type;
-+
-+  test::gemm::warp::Testbed<MmaTensorOp, cutlass::gemm::GemmShape<128, 128, 32> >()
-+      .run();
-+}
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x32_32x32x32_16x8x16) {
-+  using Shape = cutlass::gemm::GemmShape<32, 32, 32>;
-+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
-+  using Element = cutlass::half_t;
-+  using ElementC = float;
-+  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
-+      cutlass::sizeof_bits<Element>::value, 32>;
-+  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
-+      cutlass::sizeof_bits<Element>::value, 32>;
-+
-+  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
-+      Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
-+      cutlass::layout::RowMajor>::Type;
-+
-+  test::gemm::warp::Testbed<MmaTensorOp, cutlass::gemm::GemmShape<128, 128, 32> >()
-+      .run();
-+}
-+
-+////////////////////////////////////////////////////////////////////////////////
-+
-+TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x32_32x16x32_16x8x16) {
-+  using Shape = cutlass::gemm::GemmShape<32, 16, 32>;
-+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
-+  using Element = cutlass::half_t;
-+  using ElementC = float;
-+  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
-+      cutlass::sizeof_bits<Element>::value, 32>;
-+  using LayoutB =
cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x32_16x16x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x64_64x64x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x64_64x32x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x64_32x32x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ 
test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x64_32x16x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x64_16x16x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x16_64x64x16_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x16_64x32x16_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x16_32x32x16_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using 
InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x16_32x16x16_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x16_16x16x16_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x32_64x64x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x32_64x32x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = 
cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x32_32x32x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x32_32x16x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x32_16x16x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_f16, 128x128x32_64x64x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ 
-+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_f16, 128x128x32_32x32x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_f16, 128x128x64_64x64x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_f16, 128x128x64_32x32x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_tf32, 128x128x16_64x64x16_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_tf32, 128x128x16_32x32x16_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using 
InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_tf32, 128x128x32_64x64x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_tf32, 128x128x32_32x32x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_tn, tf32_round_128x128x32_64x64x32_16x8x8) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = float; -+ using ElementC = float; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::TransformTestbed >() -+ .run(); -+} -+ -+TEST(SM80_warp_gemm_tensor_op_nt, tf32_round_128x128x32_64x64x32_16x8x8) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = float; -+ using ElementC = float; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 
32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::TransformTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_16x16x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_32x16x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_32x32x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_64x32x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, 
cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_64x64x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_16x16x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_32x16x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_32x32x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ 
-+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_64x32x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_64x64x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x64_64x64x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x64_64x32x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x64_32x32x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 
32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x64_32x16x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x64_16x16x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x128_64x64x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x128_64x32x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using 
LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x128_32x32x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x128_32x16x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x128_16x16x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x128_64x64x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename 
cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x128_64x32x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x128_32x32x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x128_32x16x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x128_16x16x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ 
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x256_64x64x256_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x256_64x32x256_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x256_32x32x256_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x256_32x16x256_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ 
-+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x256_16x16x256_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x512_64x64x512_16x8x256) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x512_64x32x512_16x8x256) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x512_32x32x512_16x8x256) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 
128x128x512_32x16x512_16x8x256) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x512_16x16x512_16x8x256) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x1024_64x64x1024_16x8x256) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 1024>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 1024>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 1024>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x1024_64x32x1024_16x8x256) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 1024>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 1024>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 1024>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x1024_32x32x1024_16x8x256) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 1024>; -+ using InstructionShape = 
cutlass::gemm::GemmShape<16, 8, 256>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 1024>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 1024>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x1024_32x16x1024_16x8x256) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 1024>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 1024>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 1024>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x1024_16x16x1024_16x8x256) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 1024>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 1024>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 1024>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_f64, 16x16x4_16x16x4_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_f64, 32x16x4_32x16x4_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = 
cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_f64, 32x32x4_32x32x4_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_f64, 32x64x4_32x64x4_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f64, 16x16x16_16x16x16_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f64, 32x32x16_32x32x16_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f64, 64x32x16_64x32x16_8x8x4) { -+ using Shape = 
cutlass::gemm::GemmShape<64, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f64, 32x64x16_32x64x16_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x128_16x16x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x128_32x16x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x128_32x32x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = 
cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x128_64x32x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x128_64x64x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_canonical_f64_row_col, 32x32x8_64x32x8_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_canonical_f64_col_row, 32x32x8_64x32x8_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ 
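One pattern the stripped sizeof_bits calls obscure: within each test group above, the second (crosswise) argument of the *TensorOpMultiplicandCrosswise layouts grows as the element type narrows, so the crosswise extent times the element width stays constant (256 bits for the interleaved tests, 512 or 1024 bits for the crosswise_i8 / _i4 / _b1 tests). Assuming the element widths CUTLASS defines in cutlass/numeric_types.h, this can be checked at compile time; the sketch below is illustrative and not part of the patch.

#include "cutlass/numeric_types.h"

// Element widths assumed for the stripped sizeof_bits calls in the quoted tests.
static_assert(cutlass::sizeof_bits<int8_t>::value == 8, "int8 operands are 8 bits wide");
static_assert(cutlass::sizeof_bits<cutlass::int4b_t>::value == 4, "int4b_t operands are 4 bits wide");
static_assert(cutlass::sizeof_bits<cutlass::uint1b_t>::value == 1, "uint1b_t operands are 1 bit wide");

// Interleaved groups: crosswise 32 for int8 and 64 for int4b_t both cover 256 bits of K.
static_assert(cutlass::sizeof_bits<int8_t>::value * 32 ==
                  cutlass::sizeof_bits<cutlass::int4b_t>::value * 64,
              "interleaved int8 and int4 tests use the same crosswise footprint");

// crosswise_i8 / _i4 / _b1 groups: crosswise 64 / 128 / 512 cover 512 bits of K,
// and crosswise 128 / 256 / 1024 cover 1024 bits of K.
static_assert(cutlass::sizeof_bits<int8_t>::value * 64 == 512 &&
                  cutlass::sizeof_bits<cutlass::int4b_t>::value * 128 == 512 &&
                  cutlass::sizeof_bits<cutlass::uint1b_t>::value * 512 == 512,
              "the smaller crosswise_* variants span 512 bits of K");
static_assert(cutlass::sizeof_bits<int8_t>::value * 128 == 1024 &&
                  cutlass::sizeof_bits<cutlass::int4b_t>::value * 256 == 1024 &&
                  cutlass::sizeof_bits<cutlass::uint1b_t>::value * 1024 == 1024,
              "the larger crosswise_* variants span 1024 bits of K");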
-+TEST(SM80_warp_gemm_tensor_op_canonical_tf32_row_col, 32x32x8_64x32x8_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_canonical_tf32_col_row, 32x32x8_64x32x8_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm90.cu b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm90.cu -new file mode 100644 -index 0000000..f417a41 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm90.cu -@@ -0,0 +1,206 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Unit tests for thread-level GEMM with Hopper FP64 -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+TEST(SM90_warp_gemm_tensor_op_congruous_f64, 16x16x4_16x16x4_16x8x4) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_warp_gemm_tensor_op_congruous_f64, 32x16x4_32x16x4_16x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_warp_gemm_tensor_op_congruous_f64, 32x32x4_32x32x4_16x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_warp_gemm_tensor_op_congruous_f64, 32x64x4_32x64x4_16x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, 
LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_warp_gemm_tensor_op_crosswise_f64, 16x16x16_16x16x16_16x8x4) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_warp_gemm_tensor_op_crosswise_f64, 32x32x16_32x32x16_16x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_warp_gemm_tensor_op_crosswise_f64, 64x32x16_64x32x16_16x8x4) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_warp_gemm_tensor_op_crosswise_f64, 32x64x16_32x64x16_16x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/gemm_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sparse_sm80.cu -new file mode 100644 -index 0000000..af87ee7 ---- /dev/null -+++ 
b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sparse_sm80.cu -@@ -0,0 +1,1107 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/gemm/warp/default_mma_sparse_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 128x128x64_64x64x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 128x128x64_64x32x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 128x128x64_32x64x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 128x128x64_32x32x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ 
cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 128x128x64_32x16x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 128x64x128_64x32x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 64x128x128_32x64x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 64x64x128_32x32x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename 
cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 64x32x128_32x16x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 128x128x64_64x64x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 128x128x64_64x32x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 128x128x64_32x64x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ 
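The test hunks above lost their angle-bracket template arguments when the patch text was flattened (for example, `cutlass::sizeof_bits::value` and `test::gemm::warp::SparseTestbed >()`). As a hedged point of reference only, not a restoration of the deleted patch, a fully specified case of this pattern conventionally reads as follows; the threadblock shape passed to SparseTestbed is an assumption taken from the first triple of the test name (128x128x64), and the includes mirror the ones already listed at the top of this file.

#include "../../common/cutlass_unit_test.h"
#include "cutlass/half.h"
#include "cutlass/gemm/warp/default_mma_sparse_tensor_op.h"
#include "testbed.h"  // testbed header included by the test file above

TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 128x128x64_64x64x64_16x8x32) {
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;            // warp tile
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;  // sparse mma instruction
  using Element = cutlass::half_t;
  using ElementC = float;

  // Shared-memory layouts; the argument stripped by flattening is
  // cutlass::sizeof_bits<Element>::value.
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<Element>::value, 32>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<Element>::value, 64>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp<
      Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
      cutlass::layout::RowMajor>::Type;

  // Threadblock shape 128x128x64 assumed from the test name, not recovered
  // from the deleted patch text.
  test::gemm::warp::SparseTestbed<MmaTensorOp,
                                  cutlass::gemm::GemmShape<128, 128, 64> >()
      .run();
}

The remaining f16 cases above differ only in the warp tile (64x32, 32x64, 32x32, 32x16) and, for the congruous variants, in swapping the crosswise layouts for their congruous counterparts; the instruction shape stays fixed at 16x8x32.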
-+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 128x128x64_32x32x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 128x64x128_64x32x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 64x128x128_32x64x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 64x64x128_32x32x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x128x128_64x64x128_16x8x64) { -+ using Shape = 
cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x128x128_64x32x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x128x128_32x64x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x128x128_32x32x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x128x128_32x16x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< 
-+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x64x256_64x32x256_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 64x128x256_32x64x256_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 64x64x256_32x32x256_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 64x32x256_32x16x256_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< 
-+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 128x128x32_64x64x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 128x128x32_64x32x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 128x128x32_32x64x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 128x128x32_32x32x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ 
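A pattern that holds across the f16, s8, tf32, and (further down in this hunk) s4 cases: each sparse instruction shape covers twice the K extent of the corresponding dense SM80 Tensor Core instruction, since operand A is 2:4 structured-sparse. A minimal sketch capturing the pairings used by these tests; the trait name and the dense shapes in the comments are illustrative, not part of the deleted patch.

#include <cstdint>
#include "cutlass/numeric_types.h"  // half_t, tfloat32_t, int4b_t
#include "cutlass/gemm/gemm.h"      // GemmShape

// Instruction shapes used by the sparse warp-level tests in this file, keyed
// by operand type.  K is double the dense Tensor Core K for the same type.
template <typename Element> struct SparseSm80InstructionShape;

template <> struct SparseSm80InstructionShape<cutlass::half_t> {
  using Shape = cutlass::gemm::GemmShape<16, 8, 32>;   // dense f16: 16x8x16
};
template <> struct SparseSm80InstructionShape<int8_t> {
  using Shape = cutlass::gemm::GemmShape<16, 8, 64>;   // dense s8: 16x8x32
};
template <> struct SparseSm80InstructionShape<cutlass::tfloat32_t> {
  using Shape = cutlass::gemm::GemmShape<16, 8, 16>;   // dense tf32: 16x8x8
};
template <> struct SparseSm80InstructionShape<cutlass::int4b_t> {
  using Shape = cutlass::gemm::GemmShape<16, 8, 128>;  // dense s4: 16x8x64
};

// Spot check: static_assert(SparseSm80InstructionShape<cutlass::half_t>::Shape::kK == 32, "");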
-+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 128x128x32_32x16x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 128x64x256_64x32x256_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 64x128x64_32x64x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 64x64x64_32x32x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 64x32x64_32x16x64_16x8x16) { -+ using Shape = 
cutlass::gemm::GemmShape<32, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_tf32, 128x128x32_64x64x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_tf32, 128x128x32_64x32x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_tf32, 128x128x32_32x64x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_tf32, 128x128x32_32x32x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = 
cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_tf32, 128x64x64_64x32x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_tf32, 64x128x64_32x64x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_tf32, 64x64x64_32x32x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x128x256_64x64x256_16x8x128) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ 
cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x128x256_64x32x256_16x8x128) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x128x256_32x64x256_16x8x128) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x128x256_32x32x256_16x8x128) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x128x256_32x16x256_16x8x128) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, 
ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x64x512_64x32x512_16x8x128) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 64x128x512_32x64x512_16x8x128) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 64x64x512_32x32x512_16x8x128) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 64x32x512_32x16x512_16x8x128) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ 
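Before the testbed header that follows, one regularity in the crosswise cases above is worth making explicit: operand A's crosswise size is always half the threadblock K, while operand B's is the full K capped at 1024 bits of element data, which is consistent with A being stored in 2:4-compressed form (only K/2 elements deep in shared memory). A minimal sketch under those assumptions; the trait name, the kSparse constant, and the 1024-bit cap are illustrative, not taken from the patch.

#include "cutlass/numeric_types.h"

// Illustrative derivation of the crosswise sizes hard-coded in the tests
// above.  Assumptions: A is 2:4 structured-sparse, so its shared-memory tile
// is ThreadblockK / kSparse elements deep, and a crosswise stretch spans at
// most 1024 bits of contiguous data.
template <typename Element, int ThreadblockK, int kSparse = 2>
struct SparseCrosswise {
  static int const kMaxCrosswise = 1024 / cutlass::sizeof_bits<Element>::value;

  static int const kA = (ThreadblockK / kSparse < kMaxCrosswise)
                            ? ThreadblockK / kSparse
                            : kMaxCrosswise;  // e.g. f16, K = 64  ->  32
  static int const kB = (ThreadblockK < kMaxCrosswise)
                            ? ThreadblockK
                            : kMaxCrosswise;  // e.g. f16, K = 64  ->  64
};

// Spot checks against the tests above:
//   SparseCrosswise<cutlass::half_t, 64>    -> kA = 32,  kB = 64
//   SparseCrosswise<int8_t, 128>            -> kA = 64,  kB = 128
//   SparseCrosswise<cutlass::int4b_t, 512>  -> kA = 256, kB = 256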
-+////////////////////////////////////////////////////////////////////////////////
-+
-+#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED)
-diff --git a/3rdparty/cutlass/test/unit/gemm/warp/testbed.h b/3rdparty/cutlass/test/unit/gemm/warp/testbed.h
-new file mode 100644
-index 0000000..3487aa0
---- /dev/null
-+++ b/3rdparty/cutlass/test/unit/gemm/warp/testbed.h
-@@ -0,0 +1,1554 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/*!
\file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/subbyte_reference.h" -+#include "cutlass/platform/platform.h" -+#include "cutlass/arch/arch.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/host_reorder.h" -+#include "cutlass/util/host_uncompress.h" -+ -+namespace test { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test kernel -+template -+__global__ void kernel( -+ typename Mma::ElementC *output_C, -+ typename Mma::ElementA const *input_A, -+ typename Mma::ElementB const *input_B, -+ typename Mma::ElementC const *input_C, -+ int iterations = 1) { -+ -+ // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. -+ __shared__ cutlass::AlignedBuffer< -+ typename Mma::ElementA, ThreadblockShape::kM * ThreadblockShape::kK> smem_buffer_A; -+ -+ __shared__ cutlass::AlignedBuffer< -+ typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; -+ -+ if (threadIdx.x == 0) { -+ typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); -+ #pragma unroll 1 -+ for (int i = 0; i < smem_buffer_A.size(); ++i) { -+ cutlass::ReferenceFactory::get(smem_ptr_A, i) = -+ cutlass::ReferenceFactory::type>::get(input_A, i); -+ } -+ -+ typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); -+ #pragma unroll 1 -+ for (int i = 0; i < smem_buffer_B.size(); ++i) { -+ cutlass::ReferenceFactory::get(smem_ptr_B, i) = -+ cutlass::ReferenceFactory::type>::get(input_B, i); -+ } -+ } -+ -+ __syncthreads(); -+ -+ // -+ // Construct warp-level matrix product -+ // -+ -+ using FragmentA = typename Mma::FragmentA; -+ using FragmentB = typename Mma::FragmentB; -+ using FragmentC = typename Mma::FragmentC; -+ -+ typename Mma::LayoutA layout_A = Mma::LayoutA::packed({ThreadblockShape::kM, ThreadblockShape::kK}); -+ typename Mma::LayoutB layout_B = Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); -+ typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); -+ -+ typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); -+ -+ typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); -+ -+ FragmentA frag_A; -+ FragmentB frag_B; -+ -+ FragmentC accum; -+ -+ Mma mma; -+ -+ accum.clear(); -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < ThreadblockShape::kK; -+ k += Mma::Policy::MmaShape::kK) { -+ iter_A.load(frag_A); -+ iter_B.load(frag_B); -+ -+ ++iter_A; -+ ++iter_B; -+ -+ mma(accum, frag_A, frag_B, accum); -+ } -+ } -+ -+ typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); -+ -+ iter_C.store(accum); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Warp-level matrix 
multiply-accumulate -+ typename Mma_, -+ /// Size of threadblock-scoped shape used to store SMEM -+ typename ThreadblockShape_, -+ /// The inner product operation performed by GEMM -+ typename Operator_ = cutlass::arch::OpMultiplyAdd -+> -+struct Testbed { -+ -+ /// Thread-level matrix multiply-accumulate operator -+ using Mma = Mma_; -+ using ThreadblockShape = ThreadblockShape_; -+ using Operator = Operator_; -+ -+ using Shape = typename Mma::Shape; -+ using ElementA = typename Mma::ElementA; -+ using LayoutA = typename Mma::LayoutA; -+ using ElementB = typename Mma::ElementB; -+ using LayoutB = typename Mma::LayoutB; -+ using ElementC = typename Mma::ElementC; -+ using LayoutC = typename Mma::LayoutC; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ Testbed() { -+ -+ tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); -+ tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); -+ tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.major == 9) { -+ // NVIDIA Hopper drops support for several data types -+ if ( -+ cutlass::sizeof_bits::value < 8 || -+ cutlass::sizeof_bits::value < 8 || -+ cutlass::sizeof_bits::value < 8) { -+ -+ return false; -+ } -+ } -+ -+ return true; -+ } -+ -+ -+ /// Runs the test -+ bool run( -+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { -+ -+ if (!sufficient()) { -+ return true; -+ } -+ -+ // -+ // initialize device memory -+ // -+ -+ if (init_A == cutlass::Distribution::Uniform) { -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ -+ cutlass::reference::host::BlockFillRandomUniform(tensor_A.host_data(), -+ tensor_A.capacity(), seed, scope_max, scope_min, 0); -+ -+ } else if (init_A == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), -+ tensor_A.capacity()); -+ } else if (init_A == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ if (init_B == cutlass::Distribution::Uniform) { -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ -+ 
cutlass::reference::host::BlockFillRandomUniform(tensor_B.host_data(), -+ tensor_B.capacity(), seed, scope_max, scope_min, 0); -+ -+ } else if (init_B == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), -+ tensor_B.capacity()); -+ } else if (init_B == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ cutlass::reference::host::TensorFill( -+ tensor_C.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_computed.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_reference.host_view(), -+ ElementC(0) -+ ); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ -+ // launch kernel -+ kernel<<< dim3(1, 1), dim3(32, 1, 1) >>>( -+ tensor_D_computed.device_data(), -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_C.device_data()); -+ -+ // verify no errors -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); -+ if (result != cudaSuccess) { -+ return false; -+ } -+ -+ tensor_D_computed.sync_host(); -+ -+ // -+ // Reference implementation -+ // -+ -+ cutlass::reference::host::Gemm -+ reference_gemm; -+ -+ reference_gemm( -+ {Shape::kM, Shape::kN, ThreadblockShape::kK}, -+ ElementC(1), -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ ElementC(0), -+ tensor_D_reference.host_ref() -+ ); -+ -+ // -+ // Verify equivalence -+ // -+ -+ // compare -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view() -+ ); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ -+ cutlass::TensorView tensor_A_physical( -+ tensor_A.host_data(), -+ tensor_A.stride()[0], -+ tensor_A.extent()); -+ -+ cutlass::TensorView tensor_B_physical( -+ tensor_B.host_data(), -+ tensor_B.stride()[0], -+ tensor_B.extent()); -+ -+ std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; -+ std::cout -+ << "A:\n" << tensor_A.host_view() << "\n\n" -+ << "A(physical - stride: " << tensor_A.stride()[0] -+ << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; -+ -+ std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; -+ std::cout -+ << "B:\n" << tensor_B.host_view() << "\n\n" -+ << "B(physical - stride: " << tensor_B.stride()[0] -+ << ", extent: " << tensor_B.extent() << "):\n" << tensor_B_physical << "\n\n"; -+ -+ std::cout -+ << "C:\n" << tensor_C.host_view() << "\n\n" -+ << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" -+ << "Computed:\n" << tensor_D_computed.host_view() << std::endl; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Warp-level matrix multiply-accumulate -+ typename Mma_, -+ /// Size of threadblock-scoped shape used to store SMEM -+ typename ThreadblockShape_ -+> -+struct TestbedComplex { -+ -+ /// Thread-level matrix multiply-accumulate operator -+ using Mma = Mma_; -+ using ThreadblockShape = ThreadblockShape_; -+ -+ using Shape = typename Mma::Shape; -+ using ElementA = typename Mma::ElementA; -+ using LayoutA = typename Mma::LayoutA; -+ using ElementB = typename Mma::ElementB; -+ using LayoutB = typename Mma::LayoutB; -+ using ElementC = 
typename Mma::ElementC; -+ using LayoutC = typename Mma::LayoutC; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ TestbedComplex() { -+ -+ tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); -+ tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); -+ tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.major == 9) { -+ // NVIDIA Hopper drops support for several data types -+ if ( -+ cutlass::sizeof_bits::value < 8 || -+ cutlass::sizeof_bits::value < 8 || -+ cutlass::sizeof_bits::value < 8) { -+ -+ return false; -+ } -+ } -+ -+ return true; -+ } -+ -+ /// Runs the test -+ bool run( -+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { -+ -+ if (!sufficient()) { -+ return true; -+ } -+ -+ // -+ // initialize device memory -+ // -+ -+ if (init_A == cutlass::Distribution::Uniform) { -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform(tensor_A.host_view(), -+ seed, 8, -8, 0); -+ } else if (init_A == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), -+ tensor_A.capacity()); -+ } else if (init_A == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ if (init_B == cutlass::Distribution::Uniform) { -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform(tensor_B.host_view(), -+ seed + 16, 8, -8, 0); -+ } else if (init_B == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), -+ tensor_B.capacity()); -+ } else if (init_B == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ cutlass::reference::host::TensorFill( -+ tensor_C.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_computed.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_reference.host_view(), -+ ElementC(0) -+ ); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ -+ // launch kernel -+ kernel<<< dim3(1, 1), dim3(32, 1, 1) >>>( -+ tensor_D_computed.device_data(), -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_C.device_data()); -+ -+ // verify no errors -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ EXPECT_EQ(result, 
cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); -+ if (result != cudaSuccess) { -+ return false; -+ } -+ -+ tensor_D_computed.sync_host(); -+ -+ // -+ // Reference implementation -+ // -+ -+ cutlass::reference::host::GemmComplex( -+ {Shape::kM, Shape::kN, ThreadblockShape::kK}, -+ ElementC(1), -+ tensor_A.host_ref(), -+ Mma::kTransformA, -+ tensor_B.host_ref(), -+ Mma::kTransformB, -+ ElementC(0), -+ tensor_C.host_ref(), -+ tensor_D_reference.host_ref() -+ ); -+ -+ // -+ // Verify equivalence -+ // -+ -+ // compare -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view() -+ ); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ -+ cutlass::TensorView tensor_A_physical( -+ tensor_A.host_data(), -+ tensor_A.stride()[0], -+ tensor_A.extent()); -+ -+ cutlass::TensorView tensor_B_physical( -+ tensor_B.host_data(), -+ tensor_B.stride()[0], -+ tensor_B.extent()); -+ -+ std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; -+ std::cout -+ << "A:\n" << tensor_A.host_view() << "\n\n" -+ << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; -+ -+ std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; -+ std::cout -+ << "B:\n" << tensor_B.host_view() << "\n\n" -+ << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() <<"):\n" << tensor_B_physical << "\n\n"; -+ -+ std::cout -+ << "C:\n" << tensor_C.host_view() << "\n\n" -+ << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" -+ << "Computed:\n" << tensor_D_computed.host_view() << std::endl; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test kernel -+template -+__global__ void kernel_transform( -+ typename Mma::ElementC *output_C, -+ typename Mma::ElementA const *input_A, -+ typename Mma::ElementB const *input_B, -+ typename Mma::ElementC const *input_C, -+ int iterations = 1) { -+ -+ // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. 
-+ __shared__ cutlass::AlignedBuffer< -+ typename Mma::ElementA, ThreadblockShape::kM * ThreadblockShape::kK> smem_buffer_A; -+ -+ __shared__ cutlass::AlignedBuffer< -+ typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; -+ -+ if (threadIdx.x == 0) { -+ typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); -+ #pragma unroll 1 -+ for (int i = 0; i < smem_buffer_A.size(); ++i) { -+ cutlass::ReferenceFactory::get(smem_ptr_A, i) = -+ cutlass::ReferenceFactory::type>::get(input_A, i); -+ } -+ -+ typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); -+ #pragma unroll 1 -+ for (int i = 0; i < smem_buffer_B.size(); ++i) { -+ cutlass::ReferenceFactory::get(smem_ptr_B, i) = -+ cutlass::ReferenceFactory::type>::get(input_B, i); -+ } -+ } -+ -+ __syncthreads(); -+ -+ // -+ // Construct warp-level matrix product -+ // -+ -+ using FragmentA = typename Mma::FragmentA; -+ using FragmentB = typename Mma::FragmentB; -+ using FragmentC = typename Mma::FragmentC; -+ -+ using TransformedFragmentA = typename Mma::TransformedFragmentA; -+ using TransformedFragmentB = typename Mma::TransformedFragmentB; -+ -+ typename Mma::LayoutA layout_A = Mma::LayoutA::packed({ThreadblockShape::kM, ThreadblockShape::kK}); -+ typename Mma::LayoutB layout_B = Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); -+ typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); -+ -+ typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); -+ -+ typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); -+ -+ FragmentA loaded_frag_A; -+ FragmentB loaded_frag_B; -+ TransformedFragmentA transformed_frag_A; -+ TransformedFragmentB transformed_frag_B; -+ -+ FragmentC accum; -+ -+ Mma mma; -+ -+ accum.clear(); -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < ThreadblockShape::kK; -+ k += Mma::Policy::MmaShape::kK) { -+ iter_A.load(loaded_frag_A); -+ iter_B.load(loaded_frag_B); -+ -+ ++iter_A; -+ ++iter_B; -+ -+ mma.transform(transformed_frag_A, transformed_frag_B, loaded_frag_A, -+ loaded_frag_B); -+ -+ mma(accum, transformed_frag_A, transformed_frag_B, accum); -+ } -+ } -+ -+ typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); -+ -+ iter_C.store(accum); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Warp-level matrix multiply-accumulate -+ typename Mma_, -+ /// Size of threadblock-scoped shape used to store SMEM -+ typename ThreadblockShape_, -+ /// The innter product operation performed by GEMM -+ typename Operator_ = cutlass::arch::OpMultiplyAdd -+> -+struct TransformTestbed { -+ -+ /// Thread-level matrix multiply-accumulate operator -+ using Mma = Mma_; -+ using ThreadblockShape = ThreadblockShape_; -+ using Operator = Operator_; -+ -+ using Shape = typename Mma::Shape; -+ using ElementA = typename Mma::ElementA; -+ using LayoutA = typename Mma::LayoutA; -+ using ElementB = typename Mma::ElementB; -+ using LayoutB = typename Mma::LayoutB; -+ using ElementC = typename Mma::ElementC; -+ using LayoutC = typename Mma::LayoutC; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ 
cutlass::HostTensor tensor_D_reference; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ TransformTestbed() { -+ -+ tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); -+ tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); -+ tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.major == 9) { -+ // NVIDIA Hopper drops support for several data types -+ if ( -+ cutlass::sizeof_bits::value < 8 || -+ cutlass::sizeof_bits::value < 8 || -+ cutlass::sizeof_bits::value < 8) { -+ -+ return false; -+ } -+ } -+ -+ return true; -+ } -+ -+ /// Runs the test -+ bool run( -+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { -+ -+ if (!sufficient()) { -+ return true; -+ } -+ -+ // -+ // initialize device memory -+ // -+ -+ if (init_A == cutlass::Distribution::Uniform) { -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_A.host_view(), seed, scope_max, scope_min, 0); -+ } else if (init_A == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), -+ tensor_A.capacity()); -+ } else if (init_A == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ if (init_B == cutlass::Distribution::Uniform) { -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_B.host_view(), seed + 16, scope_max, scope_min, 0); -+ } else if (init_B == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), -+ tensor_B.capacity()); -+ } else if (init_B == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ cutlass::reference::host::TensorFill( -+ tensor_C.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_computed.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_reference.host_view(), -+ ElementC(0) -+ ); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ -+ // launch kernel -+ kernel_transform<<>>( -+ 
tensor_D_computed.device_data(), tensor_A.device_data(), -+ tensor_B.device_data(), tensor_C.device_data()); -+ -+ // verify no errors -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); -+ if (result != cudaSuccess) { -+ return false; -+ } -+ -+ tensor_D_computed.sync_host(); -+ -+ // -+ // Reference implementation -+ // -+ -+ cutlass::reference::host::Gemm -+ reference_gemm; -+ -+ reference_gemm( -+ {Shape::kM, Shape::kN, ThreadblockShape::kK}, -+ ElementC(1), -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ ElementC(0), -+ tensor_D_reference.host_ref() -+ ); -+ -+ // -+ // Verify equivalence -+ // -+ -+ // compare -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view() -+ ); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ -+ cutlass::TensorView tensor_A_physical( -+ tensor_A.host_data(), -+ tensor_A.stride()[0], -+ tensor_A.extent()); -+ -+ cutlass::TensorView tensor_B_physical( -+ tensor_B.host_data(), -+ tensor_B.stride()[0], -+ tensor_B.extent()); -+ -+ std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; -+ std::cout -+ << "A:\n" << tensor_A.host_view() << "\n\n" -+ << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; -+ -+ std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; -+ std::cout -+ << "B:\n" << tensor_B.host_view() << "\n\n" -+ << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() << "):\n" << tensor_B_physical << "\n\n"; -+ -+ std::cout -+ << "C:\n" << tensor_C.host_view() << "\n\n" -+ << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" -+ << "Computed:\n" << tensor_D_computed.host_view() << std::endl; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Warp-level matrix multiply-accumulate -+ typename Mma_, -+ /// Size of threadblock-scoped shape used to store SMEM -+ typename ThreadblockShape_ -+> -+struct TransformedTestbedComplex { -+ -+ /// Thread-level matrix multiply-accumulate operator -+ using Mma = Mma_; -+ using ThreadblockShape = ThreadblockShape_; -+ -+ using Shape = typename Mma::Shape; -+ using ElementA = typename Mma::ElementA; -+ using LayoutA = typename Mma::LayoutA; -+ using ElementB = typename Mma::ElementB; -+ using LayoutB = typename Mma::LayoutB; -+ using ElementC = typename Mma::ElementC; -+ using LayoutC = typename Mma::LayoutC; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ TransformedTestbedComplex() { -+ -+ tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); -+ tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); -+ tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. 
-+ bool sufficient() const { -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.major == 9) { -+ // NVIDIA Hopper drops support for several data types -+ if ( -+ cutlass::sizeof_bits::value < 8 || -+ cutlass::sizeof_bits::value < 8 || -+ cutlass::sizeof_bits::value < 8) { -+ -+ return false; -+ } -+ } -+ -+ return true; -+ } -+ -+ /// Runs the test -+ bool run( -+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { -+ -+ if (!sufficient()) { -+ return true; -+ } -+ -+ // -+ // initialize device memory -+ // -+ -+ if (init_A == cutlass::Distribution::Uniform) { -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform(tensor_A.host_view(), -+ seed, 8, -8, 0); -+ } else if (init_A == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), -+ tensor_A.capacity()); -+ } else if (init_A == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ if (init_B == cutlass::Distribution::Uniform) { -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform(tensor_B.host_view(), -+ seed + 16, 8, -8, 0); -+ } else if (init_B == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), -+ tensor_B.capacity()); -+ } else if (init_B == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ cutlass::reference::host::TensorFill( -+ tensor_C.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_computed.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_reference.host_view(), -+ ElementC(0) -+ ); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ -+ // launch kernel -+ kernel_transform<<< dim3(1, 1), dim3(32, 1, 1) >>>( -+ tensor_D_computed.device_data(), -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_C.device_data()); -+ -+ // verify no errors -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); -+ if (result != cudaSuccess) { -+ return false; -+ } -+ -+ tensor_D_computed.sync_host(); -+ -+ // -+ // Reference implementation -+ // -+ -+ cutlass::reference::host::GemmComplex( -+ {Shape::kM, Shape::kN, ThreadblockShape::kK}, -+ ElementC(1), -+ tensor_A.host_ref(), -+ Mma::kTransformA, -+ tensor_B.host_ref(), -+ Mma::kTransformB, -+ ElementC(0), -+ tensor_C.host_ref(), -+ tensor_D_reference.host_ref() -+ ); -+ -+ // -+ // Verify equivalence -+ // -+ -+ // compare -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view() -+ ); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ -+ cutlass::TensorView tensor_A_physical( -+ tensor_A.host_data(), -+ tensor_A.stride()[0], -+ tensor_A.extent()); -+ -+ 
cutlass::TensorView tensor_B_physical( -+ tensor_B.host_data(), -+ tensor_B.stride()[0], -+ tensor_B.extent()); -+ -+ std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; -+ std::cout -+ << "A:\n" << tensor_A.host_view() << "\n\n" -+ << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; -+ -+ std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; -+ std::cout -+ << "B:\n" << tensor_B.host_view() << "\n\n" -+ << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() <<"):\n" << tensor_B_physical << "\n\n"; -+ -+ std::cout -+ << "C:\n" << tensor_C.host_view() << "\n\n" -+ << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" -+ << "Computed:\n" << tensor_D_computed.host_view() << std::endl; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test kernel -+template -+__global__ void sparse_kernel( -+ typename Mma::ElementC *output_C, -+ typename Mma::ElementA const *input_A, -+ typename Mma::ElementB const *input_B, -+ typename Mma::ElementC const *input_C, -+ typename Mma::ElementE const *input_E, -+ int iterations = 1) { -+ -+ // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. -+ __shared__ cutlass::AlignedBuffer -+ smem_buffer_A; -+ -+ __shared__ cutlass::AlignedBuffer< -+ typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; -+ -+ __shared__ cutlass::AlignedBuffer< -+ typename Mma::ElementE, Mma::Shape::kM * Mma::Shape::kK / -+ Mma::kSparse / Mma::kElementsPerElementE> -+ smem_buffer_E; -+ -+ __syncthreads(); -+ -+ if (threadIdx.x == 0) { -+ typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); -+ #pragma unroll 1 -+ for (int i = 0; i < smem_buffer_A.size(); ++i) { -+ cutlass::ReferenceFactory::get(smem_ptr_A, i) = -+ cutlass::ReferenceFactory::type>::get(input_A, i); -+ } -+ -+ typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); -+ #pragma unroll 1 -+ for (int i = 0; i < smem_buffer_B.size(); ++i) { -+ cutlass::ReferenceFactory::get(smem_ptr_B, i) = -+ cutlass::ReferenceFactory::type>::get(input_B, i); -+ } -+ -+ typename Mma::ElementE *smem_ptr_E = smem_buffer_E.data(); -+ #pragma unroll 1 -+ for (int i = 0; i < smem_buffer_E.size(); ++i) { -+ cutlass::ReferenceFactory::get(smem_ptr_E, i) = -+ cutlass::ReferenceFactory::type>::get(input_E, i); -+ } -+ } -+ -+ __syncthreads(); -+ -+ // -+ // Construct warp-level matrix product -+ // -+ -+ using FragmentA = typename Mma::FragmentA; -+ using FragmentB = typename Mma::FragmentB; -+ using FragmentC = typename Mma::FragmentC; -+ using FragmentE = typename Mma::FragmentE; -+ -+ typename Mma::LayoutA layout_A = Mma::LayoutA::packed( -+ {ThreadblockShape::kM, ThreadblockShape::kK / Mma::kSparse}); -+ typename Mma::LayoutB layout_B = -+ Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); -+ typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); -+ typename Mma::LayoutE layout_E = -+ Mma::LayoutE::packed({Mma::Shape::kM * Mma::kInterleaved, -+ Mma::Shape::kK / Mma::kSparse / -+ Mma::kElementsPerElementE / Mma::kInterleaved}); -+ -+ typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); -+ -+ typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); -+ -+ typename Mma::IteratorE iter_E({smem_buffer_E.data(), layout_E}, 
cutlass::arch::LaneId()); -+ -+ FragmentA frag_A; -+ FragmentB frag_B; -+ -+ FragmentC accum; -+ -+ FragmentE frag_E; -+ -+ Mma mma; -+ -+ accum.clear(); -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < ThreadblockShape::kK; -+ k += Mma::Policy::MmaShape::kK) { -+ iter_A.load(frag_A); -+ iter_B.load(frag_B); -+ iter_E.load(frag_E); -+ -+ ++iter_A; -+ ++iter_B; -+ ++iter_E; -+ -+ mma(accum, frag_A, frag_B, accum, frag_E); -+ } -+ } -+ -+ typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); -+ -+ iter_C.store(accum); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Warp-level matrix multiply-accumulate -+ typename Mma_, -+ /// Size of threadblock-scoped shape used to store SMEM -+ typename ThreadblockShape_, -+ /// The innter product operation performed by GEMM -+ typename Operator_ = cutlass::arch::OpMultiplyAdd -+> -+struct SparseTestbed { -+ -+ /// Thread-level matrix multiply-accumulate operator -+ using Mma = Mma_; -+ using ThreadblockShape = ThreadblockShape_; -+ using Operator = Operator_; -+ -+ using Shape = typename Mma::Shape; -+ using ElementA = typename Mma::ElementA; -+ using LayoutA = typename Mma::LayoutA; -+ using ElementB = typename Mma::ElementB; -+ using LayoutB = typename Mma::LayoutB; -+ using ElementC = typename Mma::ElementC; -+ using LayoutC = typename Mma::LayoutC; -+ -+ static int const Sparse = Mma::kSparse; -+ static int const MetaSizeInBits = Mma::kMetaSizeInBits; -+ static int const MaxID2 = Mma::kMaxID2; -+ static int const Interleaved = Mma::kInterleaved; -+ -+ using ElementE = typename Mma::ElementE; -+ -+ static int const ElementsPerElementE = Mma::kElementsPerElementE; -+ -+ using LayoutE = cutlass::layout::RowMajor; -+ using ReorderedLayoutE = -+ cutlass::layout::ColumnMajorInterleaved; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_A_uncompressed; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ cutlass::HostTensor tensor_E; -+ cutlass::HostTensor tensor_E_reordered; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ SparseTestbed() { -+ -+ tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, -+ ThreadblockShape::kK / Sparse)); -+ tensor_A_uncompressed.reset( -+ cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); -+ tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); -+ tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); -+ tensor_E.reset(cutlass::make_Coord( -+ Shape::kM, Shape::kK / Sparse / ElementsPerElementE)); -+ tensor_E_reordered.reset(cutlass::make_Coord( -+ Shape::kM, Shape::kK / Sparse / ElementsPerElementE)); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. 
-+ bool sufficient() const { -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.major == 9) { -+ // NVIDIA Hopper drops support for several data types -+ if ( -+ cutlass::sizeof_bits::value < 8 || -+ cutlass::sizeof_bits::value < 8 || -+ cutlass::sizeof_bits::value < 8) { -+ -+ return false; -+ } -+ } -+ -+ return true; -+ } -+ -+ /// Runs the test -+ bool run( -+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_E = cutlass::Distribution::Uniform) { -+ -+ if (!sufficient()) { -+ return true; -+ } -+ -+ // -+ // initialize device memory -+ // -+ -+ if (init_A == cutlass::Distribution::Uniform) { -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_A.host_view(), seed, scope_max, scope_min, 0); -+ } else if (init_A == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), -+ tensor_A.capacity()); -+ } else if (init_A == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ if (init_B == cutlass::Distribution::Uniform) { -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_B.host_view(), seed + 16, scope_max, scope_min, 0); -+ } else if (init_B == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), -+ tensor_B.capacity()); -+ } else if (init_B == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ cutlass::reference::host::TensorFill( -+ tensor_C.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_computed.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_reference.host_view(), -+ ElementC(0) -+ ); -+ -+ if (init_E == cutlass::Distribution::Uniform) { -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomSparseMeta( -+ tensor_E.host_view(), seed, MetaSizeInBits); -+ } else if (init_E == cutlass::Distribution::Identity) { -+ uint32_t content = (MaxID2 == 1) ? 
0x44444444 : 0x4444; -+ cutlass::reference::host::TensorFill(tensor_E.host_view(), -+ (ElementE)(content)); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ cutlass::reorder_meta( -+ tensor_E_reordered.host_ref(), tensor_E.host_ref(), -+ {Shape::kM, Shape::kN, Shape::kK / Sparse / ElementsPerElementE}); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ tensor_E_reordered.sync_device(); -+ -+ // launch kernel -+ sparse_kernel<<< dim3(1, 1), dim3(32, 1, 1) >>>( -+ tensor_D_computed.device_data(), -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_C.device_data(), -+ tensor_E_reordered.device_data()); -+ -+ // verify no errors -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); -+ if (result != cudaSuccess) { -+ return false; -+ } -+ -+ tensor_D_computed.sync_host(); -+ -+ // -+ // Reference implementation -+ // -+ cutlass::uncompress(tensor_A_uncompressed.host_ref(), tensor_A.host_ref(), -+ tensor_E.host_ref(), Shape::kM, Shape::kK); -+ -+ cutlass::reference::host::Gemm -+ reference_gemm; -+ -+ reference_gemm( -+ {Shape::kM, Shape::kN, ThreadblockShape::kK}, -+ ElementC(1), -+ tensor_A_uncompressed.host_ref(), -+ tensor_B.host_ref(), -+ ElementC(0), -+ tensor_D_reference.host_ref() -+ ); -+ -+ // -+ // Verify equivalence -+ // -+ -+ // compare -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view() -+ ); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; -+ std::cout << "A:\n" << tensor_A.host_view() << "\n\n"; -+ -+ std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; -+ std::cout << "B:\n" << tensor_B.host_view() << "\n\n"; -+ -+ std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; -+ std::cout << "E:\n" << tensor_E.host_view() << "\n\n"; -+ -+ std::cout -+ << "C:\n" << tensor_C.host_view() << "\n\n" -+ << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" -+ << "Computed:\n" << tensor_D_computed.host_view() << "\n"; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/wmma_sm70.cu b/3rdparty/cutlass/test/unit/gemm/warp/wmma_sm70.cu -new file mode 100644 -index 0000000..f2d6762 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/wmma_sm70.cu -@@ -0,0 +1,688 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Unit tests for warp-level wmma gemm -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_SM70_ENABLED) -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/gemm/warp/default_mma_wmma_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+/// Test name format: SM[arch]_warp_wmma_[alayout]_[blayout]_[clayout]_[dtype].[threadblock_shape]_[warp_shape] -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////// f16 accumulation point wmma.mma ////////////////////////////////// -+//////////////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//////////////// [START] Verifying all layouts {N,T}x{N,T}=>{N,T} for WMMA 16x16x16 [START] ////////////////////// -+ -+//////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16 -+//////////////////////////////////////////////////////////// -+ -+// 4 tests for {N,T}x{N,T}=>{T} -+TEST(SM70_warp_wmma_row_col_row_f16, 16x16x16_16x16x16_16x16x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+//////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// 
wmma.mma.sync.aligned.col.row.m16n16k16.f16.f16 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_col_row_row_f16, 16x16x16_16x16x16_16x16x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+//////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.row.m16n16k16.f16.f16 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_row_row_row_f16, 16x16x16_16x16x16_16x16x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+ -+//////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.col.row.m16n16k16.f16.f16 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_col_col_row_f16, 16x16x16_16x16x16_16x16x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+// 4 tests for {N,T}x{N,T}=>{N} -+TEST(SM70_warp_wmma_row_col_col_f16, 16x16x16_16x16x16_16x16x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using WmmaTensorOp = typename 
cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+//////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.col.row.m16n16k16.f16.f16 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_col_row_col_f16, 16x16x16_16x16x16_16x16x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+//////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.row.m16n16k16.f16.f16 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_row_row_col_f16, 16x16x16_16x16x16_16x16x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+ -+//////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.col.row.m16n16k16.f16.f16 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_col_col_col_f16, 16x16x16_16x16x16_16x16x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+/////////// [END] Verifying all layouts {N,T}x{N,T}=>{N,T} for WMMA 16x16x16 [END] /////////////////////////// -+ -+ -+ -+TEST(SM70_warp_wmma_row_col_row_f16, 64x64x16_64x64x16_16x16x16) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape 
= cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed<WmmaTensorOp, cutlass::gemm::GemmShape<64, 64, 16> >().run(); -+} -+ -+ -+TEST(SM70_warp_wmma_row_col_row_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ -+ test::gemm::warp::Testbed<WmmaTensorOp, cutlass::gemm::GemmShape<64, 64, 32> >().run(); -+} -+ -+ -+TEST(SM70_warp_wmma_row_col_row_f16, 64x64x32_64x32x32_16x16x16) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed<WmmaTensorOp, cutlass::gemm::GemmShape<64, 64, 32> >().run(); -+} -+ -+TEST(SM70_warp_wmma_row_col_row_f16, 64x64x32_32x64x32_16x16x16) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed<WmmaTensorOp, cutlass::gemm::GemmShape<64, 64, 32> >().run(); -+} -+ -+TEST(SM70_warp_wmma_row_col_row_f16, 64x64x32_32x32x32_16x16x16) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed<WmmaTensorOp, cutlass::gemm::GemmShape<64, 64, 32> >().run(); -+} -+ -+TEST(SM70_warp_wmma_row_col_row_f16, 128x128x16_64x64x16_16x16x16) { -+ // Even though the test launches a 128x128x16 CTA tile, this test only verifies one warp, -+ // i.e., warp_0 of size 64x64x16 out of the four warps required to
cover the CTA -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+ -+//////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m32n8k16.f16.f16 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_row_col_row_f16, 32x8x16_32x8x16_32x8x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+//////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m8n32k16.f16.f16 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_row_col_row_f16, 8x32x16_8x32x16_32x8x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+ -+//////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.col.row.m8n32k16.f16.f16 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_col_row_row_f16, 8x32x16_8x32x16_8x32x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ 
ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM70_warp_wmma_col_row_row_f16, 32x8x16_32x8x16_32x8x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////// f32 accumulation point wmma.mma ////////////////////////////////// -+//////////////////////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m16n16k16.f32.f32 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_row_col_row_f32, 16x16x16_16x16x16_16x16x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM70_warp_wmma_row_col_row_f32, 64x64x16_64x64x16_16x16x16) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM70_warp_wmma_row_col_row_f32, 64x64x32_64x64x32_16x16x16) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ 
LayoutC>::Type; -+ -+ test::gemm::warp::Testbed<WmmaTensorOp, cutlass::gemm::GemmShape<64, 64, 32> >().run(); -+} -+ -+ -+TEST(SM70_warp_wmma_row_col_row_f32, 128x128x16_64x64x16_16x16x16) { -+ // Even though the test launches a 128x128x16 CTA tile, this test only verifies one warp, -+ // i.e., warp_0 of size 64x64x16 out of the four warps required to cover the CTA -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed<WmmaTensorOp, cutlass::gemm::GemmShape<128, 128, 16> >().run(); -+} -+ -+ -+///////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m32n8k16.f32.f32 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_row_col_row_f32, 32x8x16_32x8x16_32x8x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed<WmmaTensorOp, cutlass::gemm::GemmShape<32, 8, 16> >().run(); -+} -+ -+ -+///////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m8n32k16.f32.f32 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_row_col_row_f32, 8x32x16_8x32x16_8x32x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed<WmmaTensorOp, cutlass::gemm::GemmShape<8, 32, 16> >().run(); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/wmma_sm72.cu b/3rdparty/cutlass/test/unit/gemm/warp/wmma_sm72.cu -new file mode 100644 -index 0000000..8b56220 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/wmma_sm72.cu -@@ -0,0 +1,185 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Unit tests for thread-level GEMM -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_SM72_ENABLED) -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/gemm/warp/default_mma_wmma_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////// Integer wmma.mma //////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////////////////// -+// TODO: FIXME SM75 should SM72, but the compilation breaks as SM72 shows up and runs on VOLTA -+TEST(SM75_warp_wmma_row_col_s8, 16x16x16_16x16x16_16x16x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM75_warp_wmma_row_col_s8, 
32x8x16_32x8x16_32x8x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM75_warp_wmma_row_col_s8, 8x32x16_8x32x16_8x32x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM75_warp_wmma_row_col_u8, 16x16x16_16x16x16_16x16x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = uint8_t; -+ using ElementB = uint8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM75_warp_wmma_row_col_u8, 32x8x16_32x8x16_32x8x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ using ElementA = uint8_t; -+ using ElementB = uint8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM75_warp_wmma_row_col_u8, 8x32x16_8x32x16_8x32x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ using ElementA = uint8_t; -+ using ElementB = uint8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, LayoutA, -+ 
ElementB, LayoutB, -+ ElementC, LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+#endif //CUTLASS_ARCH_WMMA_SM72_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/wmma_sm75.cu b/3rdparty/cutlass/test/unit/gemm/warp/wmma_sm75.cu -new file mode 100644 -index 0000000..ebc0f3b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/wmma_sm75.cu -@@ -0,0 +1,170 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ -+ \brief Unit tests for thread-level GEMM -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED) -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/gemm/warp/default_mma_wmma_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////// SUBBYTE wmma.mma //////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_warp_wmma_row_col_s4, 64x64x32_8x8x32_8x8x32) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+ -+} -+ -+TEST(SM75_warp_wmma_row_col_s4, 64x64x32_64x64x32_8x8x32) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+ -+} -+ -+TEST(SM75_warp_wmma_row_col_s4, 64x64x64_8x8x64_8x8x32) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<8, 8, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+ -+} -+ -+TEST(SM75_warp_wmma_row_col_b1, 64x64x128_8x8x128_8x8x128) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ using ElementA = cutlass::uint1b_t; -+ using ElementB = cutlass::uint1b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ 
ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpXorPopc>::Type; -+ -+ test::gemm::warp::Testbed, cutlass::arch::OpXorPopc>().run(); -+ -+} -+ -+TEST(SM75_warp_wmma_row_col_b1, 64x64x128_64x64x128_8x8x128) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ using ElementA = cutlass::uint1b_t; -+ using ElementB = cutlass::uint1b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpXorPopc>::Type; -+ -+ test::gemm::warp::Testbed, cutlass::arch::OpXorPopc>().run(); -+ -+} -+#endif //CUTLASS_ARCH_WMMA_SM75_ENABLED -diff --git a/3rdparty/cutlass/test/unit/layout/matrix.cu b/3rdparty/cutlass/test/unit/layout/matrix.cu -new file mode 100644 -index 0000000..c603ced ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/layout/matrix.cu -@@ -0,0 +1,151 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+* -+**************************************************************************************************/ -+/*! 
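// In the two SM75_warp_wmma_row_col_b1 tests above, the operands are 1-bit (cutlass::uint1b_t)
// and the multiply-accumulate is replaced by XOR plus population count, which is why
// cutlass::arch::OpXorPopc is threaded through both DefaultMmaTensorOpWmma and the Testbed.
// Conceptually, per packed 32-bit word the accumulation is (illustrative, not the exact WMMA
// code path):
//
//   acc += __popc(a_word ^ b_word);   // count of bit positions where the 1-bit operands differ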
\file -+\brief unit tests for matrix layout -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_coord.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+namespace test { -+namespace layout { -+ void test_row_major_layout(int row_size, int column_size, int ldm) { -+ cutlass::layout::RowMajor row_major(ldm); -+ -+ // test pointer offset -+ for (int row_idx = 0; row_idx < row_size; row_idx++) { -+ for (int column_idx = 0; column_idx < column_size; column_idx++) { -+ cutlass::MatrixCoord matrix_coord(row_idx, column_idx); -+ auto ptr_offset = row_major(matrix_coord); -+ decltype(ptr_offset) reference_offset = row_idx * ldm + column_idx; -+ EXPECT_EQ(ptr_offset, reference_offset); -+ } -+ } -+ -+ // test stride -+ EXPECT_EQ(row_major.stride()[0], ldm); -+ -+ // test capacity -+ auto capacity = row_major.capacity(cutlass::MatrixCoord(row_size, column_size)); -+ decltype(capacity) reference_capacity = row_size * ldm; -+ EXPECT_EQ(capacity, reference_capacity); -+ -+ // test packed -+ auto packed = row_major.packed(cutlass::MatrixCoord(row_size, column_size)); -+ // the packed matrix's stride is the same with column size -+ EXPECT_EQ(packed.stride()[0], column_size); -+ } -+ -+ void test_column_major_layout(int row_size, int column_size, int ldm) { -+ cutlass::layout::ColumnMajor column_major(ldm); -+ -+ // test pointer offset -+ for (int row_idx = 0; row_idx < row_size; row_idx++) { -+ for (int column_idx = 0; column_idx < column_size; column_idx++) { -+ cutlass::MatrixCoord matrix_coord(row_idx, column_idx); -+ auto ptr_offset = column_major(matrix_coord); -+ decltype(ptr_offset) reference_offset = row_idx + column_idx * ldm; -+ EXPECT_EQ(ptr_offset, reference_offset); -+ } -+ } -+ -+ // test stride -+ EXPECT_EQ(column_major.stride()[0], ldm); -+ -+ // test capacity -+ auto capacity = column_major.capacity(cutlass::MatrixCoord(row_size, column_size)); -+ decltype(capacity) reference_capacity = column_size * ldm; -+ EXPECT_EQ(capacity, reference_capacity); -+ -+ // test packed -+ auto packed = column_major.packed(cutlass::MatrixCoord(row_size, column_size)); -+ // the packed matrix's stride is the same with row size -+ EXPECT_EQ(packed.stride()[0], row_size); -+ } -+ -+} // namespace layout -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Layout_Matrix, row_major_32_53) { -+ int const row_size = 32; -+ int const column_size = 53; -+ int const ldm = 55; -+ test::layout::test_row_major_layout(row_size, column_size, ldm); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Layout_Matrix, column_major_32_53) { -+ int const row_size = 32; -+ int const column_size = 53; -+ int const ldm = 55; -+ test::layout::test_column_major_layout(row_size, column_size, ldm); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Layout_Matrix, general_matrix) { -+ -+ int M = 16; -+ int N = 16; -+ int interleave = 4; -+ -+ cutlass::layout::GeneralMatrix::TensorCoord extent = {M, N}; -+ -+ cutlass::layout::GeneralMatrix layout = -+ cutlass::layout::GeneralMatrix::packed( -+ extent, cutlass::layout::Matrix::kColumnMajor, interleave); -+ -+ cutlass::HostTensor tensor(extent); -+ -+ for (int m = 0; m < M; 
++m) { -+ for (int n = 0; n < N; ++n) { -+ tensor.host_data(m * N + n) = m * N + n; -+ } -+ } -+ -+ cutlass::TensorView canonical({tensor.host_data(), layout}, extent); -+ -+ // Uncomment this to view -+ // -+ //std::cout << canonical << std::endl; -+ // -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/layout/tensor.cu b/3rdparty/cutlass/test/unit/layout/tensor.cu -new file mode 100644 -index 0000000..253f0c0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/layout/tensor.cu -@@ -0,0 +1,153 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+* -+**************************************************************************************************/ -+/*! 
\file -+\brief unit tests for tensor layout -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/layout/tensor.h" -+#include "cutlass/tensor_coord.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+namespace test { -+namespace layout { -+ void test_NHWC_layout(int n_size, int h_size, int w_size, int c_size) { -+ int ldc = c_size + 1; -+ int ldw = ldc * (w_size + 2); -+ int ldh = ldw * (h_size + 3); -+ -+ cutlass::layout::TensorNHWC::Stride tensor_stride({ ldc, ldw, ldh }); -+ -+ cutlass::layout::TensorNHWC tensor_nhwc(tensor_stride); -+ -+ // test pointer offset -+ for (int n_idx = 0; n_idx < n_size; n_idx++) { -+ for (int h_idx = 0; h_idx < h_size; h_idx++) { -+ for (int w_idx = 0; w_idx < w_size; w_idx++) { -+ for (int c_idx = 0; c_idx < c_size; c_idx++) { -+ cutlass::Tensor4DCoord tensor_coord(n_idx, h_idx, w_idx, c_idx); -+ auto ptr_offset = tensor_nhwc(tensor_coord); -+ decltype(ptr_offset) reference_offset = c_idx + -+ w_idx * ldc + -+ h_idx * ldw + -+ n_idx * ldh; -+ EXPECT_EQ(ptr_offset, reference_offset); -+ } -+ } -+ } -+ } -+ -+ // test stride -+ auto stride = tensor_nhwc.stride(); -+ EXPECT_EQ(stride, tensor_stride); -+ -+ // test capacity -+ auto capacity = tensor_nhwc.capacity(cutlass::Tensor4DCoord(n_size, h_size, w_size, c_size)); -+ decltype(capacity) referece_capacity = ldh * n_size; -+ EXPECT_EQ(capacity, referece_capacity); -+ -+ // test packed -+ auto packed_tensor_layout = tensor_nhwc.packed(cutlass::Tensor4DCoord(n_size, h_size, w_size, c_size)); -+ auto packed_stride = packed_tensor_layout.stride(); -+ EXPECT_EQ(packed_stride, cutlass::layout::TensorNHWC::Stride({ c_size, w_size * c_size, h_size * w_size * c_size })); -+ } -+ -+ -+ void test_NCHW_layout(int n_size, int c_size, int h_size, int w_size) { -+ int ldw = w_size + 1; -+ int ldh = ldw * (h_size + 2); -+ int ldc = ldh * (c_size + 1); -+ -+ cutlass::layout::TensorNCHW::Stride tensor_stride({ ldw, ldh, ldc }); -+ -+ cutlass::layout::TensorNCHW tensor_nchw(tensor_stride); -+ -+ // test pointer offset -+ for (int n_idx = 0; n_idx < n_size; n_idx++) { -+ for (int c_idx = 0; c_idx < c_size; c_idx++) { -+ for (int h_idx = 0; h_idx < w_size; h_idx++) { -+ for (int w_idx = 0; w_idx < c_size; w_idx++) { -+ // tensor4DCoord is always created in nhwc order -+ cutlass::Tensor4DCoord tensor_coord(n_idx, h_idx, w_idx, c_idx); -+ auto ptr_offset = tensor_nchw(tensor_coord); -+ decltype(ptr_offset) reference_offset = w_idx + -+ h_idx * ldw + -+ c_idx * ldh + -+ n_idx * ldc; -+ EXPECT_EQ(ptr_offset, reference_offset); -+ } -+ } -+ } -+ } -+ -+ // test stride -+ auto stride = tensor_nchw.stride(); -+ EXPECT_EQ(stride, tensor_stride); -+ -+ // test capacity -+ auto capacity = tensor_nchw.capacity(cutlass::Tensor4DCoord(n_size, h_size, w_size, c_size)); -+ decltype(capacity) referece_capacity = ldc * n_size; -+ EXPECT_EQ(capacity, referece_capacity); -+ -+ // test packed -+ auto packed_tensor_layout = tensor_nchw.packed(cutlass::Tensor4DCoord(n_size, h_size, w_size, c_size)); -+ auto packed_stride = packed_tensor_layout.stride(); -+ EXPECT_EQ(packed_stride, cutlass::layout::TensorNHWC::Stride({ w_size, w_size * h_size, w_size * h_size * c_size })); -+ } -+} // namespace layout -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Layout_Tensor, NHWC_32_12_10_14) { -+ int n_size = 32; -+ int h_size = 12; -+ int w_size = 10; -+ int c_size = 14; -+ 
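// For this extent, the stride arithmetic inside test_NHWC_layout (called just below) works out to:
//   ldc = c_size + 1         = 15
//   ldw = ldc * (w_size + 2) = 15 * 12  = 180
//   ldh = ldw * (h_size + 3) = 180 * 15 = 2700
// so offset(n, h, w, c) = c + w * 15 + h * 180 + n * 2700, and
// capacity = ldh * n_size = 2700 * 32 = 86400 elements.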
test::layout::test_NHWC_layout(n_size, h_size, w_size, c_size); -+ -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Layout_Tensor, NCHW_32_12_10_14) { -+ int n_size = 32; -+ int c_size = 12; -+ int h_size = 10; -+ int w_size = 14; -+ test::layout::test_NCHW_layout(n_size, c_size, h_size, w_size); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/layout/tensor_nhwc.cu b/3rdparty/cutlass/test/unit/layout/tensor_nhwc.cu -new file mode 100644 -index 0000000..e0f6b5b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/layout/tensor_nhwc.cu -@@ -0,0 +1,214 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+* -+**************************************************************************************************/ -+/*! 
\file -+\brief unit tests for NHWC tensor layout -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/util/device_memory.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+namespace test { -+namespace layout { -+ -+ void test_nhwc_layout(int n_size, int h_size, int w_size, int c_size) { -+ int ldc = c_size + 1; -+ int ldw = ldc * (w_size + 2); -+ int ldh = ldw * (h_size + 3); -+ -+ typedef cutlass::layout::TensorNHWC Tensor; -+ -+ Tensor::Stride tensor_stride({ ldc, ldw, ldh }); -+ Tensor tensor_nhw_packed_c(tensor_stride); -+ -+ // test pointer offset -+ for (int n_idx = 0; n_idx < n_size; n_idx++) { -+ for (int p_idx = 0; p_idx < h_size; p_idx++) { -+ for (int q_idx = 0; q_idx < w_size; q_idx++) { -+ for (int c_idx = 0; c_idx < c_size; c_idx++) { -+ cutlass::Tensor4DCoord tensor_coord(n_idx, p_idx, q_idx, c_idx); -+ auto ptr_offset = tensor_nhw_packed_c(tensor_coord); -+ decltype(ptr_offset) reference_offset = c_idx + -+ q_idx * ldc + -+ p_idx * ldw + -+ n_idx * ldh; -+ EXPECT_EQ(ptr_offset, reference_offset); -+ } -+ } -+ } -+ } -+ -+ // test stride -+ auto stride = tensor_nhw_packed_c.stride(); -+ EXPECT_EQ(stride, tensor_stride); -+ -+ // test capacity -+ auto capacity = tensor_nhw_packed_c.capacity( -+ cutlass::Tensor4DCoord(n_size, h_size, w_size, c_size)); -+ decltype(capacity) referece_capacity = ldh * n_size; -+ EXPECT_EQ(capacity, referece_capacity); -+ -+ } -+ -+ __global__ void test_nhwc_inverse( -+ int *output, int n_size, int h_size, int w_size, int c_size) { -+ int ldc = c_size; -+ int ldw = ldc * w_size; -+ int ldh = ldw * h_size; -+ -+ typedef cutlass::layout::TensorNHWC Tensor; -+ -+ Tensor::Stride tensor_stride({ ldc, ldw, ldh }); -+ Tensor tensor_nhw_packed_c(tensor_stride); -+ -+ for (int n_idx = 0; n_idx < n_size; n_idx++) { -+ for (int p_idx = 0; p_idx < h_size; p_idx++) { -+ for (int q_idx = 0; q_idx < w_size; q_idx++) { -+ cutlass::Tensor4DCoord tensor_coord(n_idx, p_idx, q_idx, threadIdx.x); -+ int ptr_offset = tensor_nhw_packed_c(tensor_coord); -+ cutlass::Tensor4DCoord inv_coord = tensor_nhw_packed_c.inverse(ptr_offset); -+ output[ptr_offset] = tensor_nhw_packed_c(inv_coord); -+ } -+ } -+ } -+ } -+ -+ class TestTensorNHWC { -+ public: -+ -+ // -+ // Data members -+ // -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ TestTensorNHWC() { -+ -+ } -+ -+ /// Runs the test -+ void run(int n_size, int h_size, int w_size, int c_size) { -+ -+ size_t size = n_size * h_size * w_size * c_size; -+ -+ /// Device memory containing output -+ cutlass::device_memory::allocation< int > output(size); -+ int *output_host = (int *)malloc(sizeof(int) * size); -+ -+ dim3 grid(1,1); -+ dim3 block(c_size, 1, 1); -+ -+ test::layout::test_nhwc_inverse<<< grid, block >>>(output.get(), -+ n_size, h_size, w_size, c_size); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << "CUDA error: " << cudaGetErrorString(result); -+ -+ // -+ // Verify output -+ // -+ -+ cutlass::device_memory::copy_to_host(output_host, output.get(), size); -+ -+ result = cudaGetLastError(); -+ ASSERT_EQ(result, cudaSuccess) << "CUDA error: " << cudaGetErrorString(result); -+ -+ for (int n_idx = 0; n_idx < n_size; n_idx++) { -+ for (int p_idx = 0; p_idx < h_size; p_idx++) { -+ for (int q_idx = 0; q_idx < w_size; q_idx++) { -+ for (int c_idx = 0; c_idx < c_size; c_idx++) { -+ int reference_offset = c_idx + -+ q_idx * c_size + -+ p_idx * (c_size * w_size) + -+ n_idx * (c_size * 
w_size * h_size); -+ EXPECT_EQ(output_host[reference_offset], reference_offset); -+ } -+ } -+ } -+ } -+ } -+}; -+ -+ -+} // namespace layout -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Layout_TensorNHWC, NHWC_1_16_8_32) { -+ int n_size = 1; -+ int h_size = 16; -+ int w_size = 8; -+ int c_size = 32; -+ test::layout::test_nhwc_layout(n_size, h_size, w_size, c_size); -+ test::layout::TestTensorNHWC test_nhwc; -+ test_nhwc.run(n_size, h_size, w_size, c_size); -+ -+} -+ -+TEST(Layout_TensorNHWC, NHWC_2_16_8_32) { -+ int n_size = 2; -+ int h_size = 16; -+ int w_size = 8; -+ int c_size = 32; -+ test::layout::test_nhwc_layout(n_size, h_size, w_size, c_size); -+ test::layout::TestTensorNHWC test_nhwc; -+ test_nhwc.run(n_size, h_size, w_size, c_size); -+} -+ -+TEST(Layout_TensorNHWC, NHWC_2_16_8_128) { -+ int n_size = 2; -+ int h_size = 16; -+ int w_size = 8; -+ int c_size = 128; -+ test::layout::test_nhwc_layout(n_size, h_size, w_size, c_size); -+ test::layout::TestTensorNHWC test_nhwc; -+ test_nhwc.run(n_size, h_size, w_size, c_size); -+ -+} -+ -+TEST(Layout_TensorNHWC, NHWC_4_8_16_128) { -+ int n_size = 4; -+ int h_size = 8; -+ int w_size = 16; -+ int c_size = 128; -+ test::layout::test_nhwc_layout(n_size, h_size, w_size, c_size); -+ test::layout::TestTensorNHWC test_nhwc; -+ test_nhwc.run(n_size, h_size, w_size, c_size); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/nvrtc/cutlass/nvrtc/environment.h b/3rdparty/cutlass/test/unit/nvrtc/cutlass/nvrtc/environment.h -new file mode 100644 -index 0000000..94f3c78 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/nvrtc/cutlass/nvrtc/environment.h -@@ -0,0 +1,43 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+namespace nvrtc { -+ -+extern char const *kCutlassHeaders[]; -+extern char const *kCutlassHeaderNames[]; -+extern size_t const kCutlassHeaderCount; -+} // namespace nvrtc -+} // namespace cutlass -diff --git a/3rdparty/cutlass/test/unit/nvrtc/kernel/thread/testbed_kernel.h b/3rdparty/cutlass/test/unit/nvrtc/kernel/thread/testbed_kernel.h -new file mode 100644 -index 0000000..c2d9cde ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/nvrtc/kernel/thread/testbed_kernel.h -@@ -0,0 +1,76 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+ -+namespace test { -+namespace nvrtc { -+namespace kernel { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Thread-level matrix multiply-accumulate -+template -+__global__ void testbed_kernel( -+ typename Mma::ElementC *D, -+ typename Mma::ElementA const *A, -+ typename Mma::ElementB const *B, -+ typename Mma::ElementC const *C) { -+ -+ auto ptr_D = reinterpret_cast *>(D); -+ auto ptr_A = reinterpret_cast const *>(A); -+ auto ptr_B = reinterpret_cast const *>(B); -+ auto ptr_C = reinterpret_cast const *>(C); -+ -+ Mma mma; -+ -+ auto a = *ptr_A; -+ auto b = *ptr_B; -+ auto c = *ptr_C; -+ -+ cutlass::Array d; -+ -+ mma(d, a, b, c); -+ -+ *ptr_D = d; -+} -+ -+} -+} -+} -+} -+ -diff --git a/3rdparty/cutlass/test/unit/nvrtc/stdlib/assert.h b/3rdparty/cutlass/test/unit/nvrtc/stdlib/assert.h -new file mode 100644 -index 0000000..e69de29 -diff --git a/3rdparty/cutlass/test/unit/nvrtc/stdlib/stdint.h b/3rdparty/cutlass/test/unit/nvrtc/stdlib/stdint.h -new file mode 100644 -index 0000000..f6033de ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/nvrtc/stdlib/stdint.h -@@ -0,0 +1,129 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
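// The testbed_kernel above lost its cutlass::Array template arguments in this copy. The usual
// pattern for a thread-level GEMM testbed kernel is to reinterpret each raw pointer as a fragment
// sized by the GemmShape, roughly as follows (a sketch assuming the conventional Mma::Shape::kMK /
// kKN / kMN fragment sizes; not the verbatim original):
//
//   auto ptr_A = reinterpret_cast<cutlass::Array<typename Mma::ElementA, Mma::Shape::kMK> const *>(A);
//   auto ptr_B = reinterpret_cast<cutlass::Array<typename Mma::ElementB, Mma::Shape::kKN> const *>(B);
//   auto ptr_C = reinterpret_cast<cutlass::Array<typename Mma::ElementC, Mma::Shape::kMN> const *>(C);
//   cutlass::Array<typename Mma::ElementC, Mma::Shape::kMN> d;
//   mma(d, *ptr_A, *ptr_B, *ptr_C);
//
// In other words, a single thread owns the whole (M x K), (K x N) and (M x N) fragments and
// performs the entire small GEMM in registers.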
-+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+typedef char int8_t; -+typedef unsigned char uint8_t; -+typedef short int16_t; -+typedef unsigned short uint16_t; -+typedef int int32_t; -+typedef unsigned int uint32_t; -+typedef long long int int64_t; -+typedef unsigned long long int uint64_t; -+ -+#if defined __x86_64__ && !defined __ILP32__ -+# define __WORDSIZE 64 -+#else -+# define __WORDSIZE 32 -+#endif -+ -+ -+/* Small types. */ -+ -+/* Signed. */ -+typedef signed char int_least8_t; -+typedef short int int_least16_t; -+typedef int int_least32_t; -+#if __WORDSIZE == 64 -+typedef long int int_least64_t; -+#else -+__extension__ -+typedef long long int int_least64_t; -+#endif -+ -+/* Unsigned. */ -+typedef unsigned char uint_least8_t; -+typedef unsigned short int uint_least16_t; -+typedef unsigned int uint_least32_t; -+#if __WORDSIZE == 64 -+typedef unsigned long int uint_least64_t; -+#else -+__extension__ -+typedef unsigned long long int uint_least64_t; -+#endif -+ -+ -+/* Fast types. */ -+ -+/* Signed. */ -+typedef signed char int_fast8_t; -+#if __WORDSIZE == 64 -+typedef long int int_fast16_t; -+typedef long int int_fast32_t; -+typedef long int int_fast64_t; -+#else -+typedef int int_fast16_t; -+typedef int int_fast32_t; -+__extension__ -+typedef long long int int_fast64_t; -+#endif -+ -+/* Unsigned. */ -+typedef unsigned char uint_fast8_t; -+#if __WORDSIZE == 64 -+typedef unsigned long int uint_fast16_t; -+typedef unsigned long int uint_fast32_t; -+typedef unsigned long int uint_fast64_t; -+#else -+typedef unsigned int uint_fast16_t; -+typedef unsigned int uint_fast32_t; -+__extension__ -+typedef unsigned long long int uint_fast64_t; -+#endif -+ -+/* Types for `void *' pointers. */ -+#if __WORDSIZE == 64 -+# ifndef __intptr_t_defined -+typedef long int intptr_t; -+# define __intptr_t_defined -+# endif -+typedef unsigned long int uintptr_t; -+#else -+# ifndef __intptr_t_defined -+typedef int intptr_t; -+# define __intptr_t_defined -+# endif -+typedef unsigned int uintptr_t; -+#endif -+ -+ -+/* Largest integral types. */ -+#if __WORDSIZE == 64 -+typedef long int intmax_t; -+typedef unsigned long int uintmax_t; -+#else -+__extension__ -+typedef long long int intmax_t; -+__extension__ -+typedef unsigned long long int uintmax_t; -+#endif -+ -diff --git a/3rdparty/cutlass/test/unit/nvrtc/thread/gemm_nvrtc.cu b/3rdparty/cutlass/test/unit/nvrtc/thread/gemm_nvrtc.cu -new file mode 100644 -index 0000000..8b9b8bb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/nvrtc/thread/gemm_nvrtc.cu -@@ -0,0 +1,203 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/gemm/thread/mma.h" -+ -+#include "testbed.h" -+ -+#if 0 -+int main() { -+ nvrtc::thread::Testbed< -+ cutlass::gemm::GemmShape<3, 4, 2>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run("cutlass::gemm::thread::Mma, float, cutlass::layout::ColumnMajor, float, cutlass::layout::RowMajor, float, cutlass::layout::ColumnMajor >"); -+ return 0; -+} -+#endif -+ -+TEST(SM50_Sgemm_thread_nvrtc, DISABLED_col_row_3x4x2) { -+ -+ test::nvrtc::thread::Testbed< -+ cutlass::gemm::GemmShape<3, 4, 2>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run("cutlass::gemm::thread::Mma, float, cutlass::layout::ColumnMajor, float, cutlass::layout::RowMajor, float, cutlass::layout::ColumnMajor >"); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#if 0 -+TEST(SM50_Sgemm_thread, col_row_3x4x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<3, 4, 2>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, col_row_4x4x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<4, 4, 2>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, row_col_4x4x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<4, 4, 2>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, col_row_4x5x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<4, 5, 3>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, col_row) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, row_col) 
{ -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, col_col) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, row_row) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Dgemm_thread, col_row) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Dgemm_thread, row_col) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+#endif -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/nvrtc/thread/testbed.h b/3rdparty/cutlass/test/unit/nvrtc/thread/testbed.h -new file mode 100644 -index 0000000..378be81 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/nvrtc/thread/testbed.h -@@ -0,0 +1,323 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/gemm/thread/mma.h" -+#include "../kernel/thread/testbed_kernel.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include -+#include -+#include "../cutlass/nvrtc/environment.h" -+#include -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace nvrtc { -+namespace thread { -+ -+/// Structure to compute the matrix product -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape, -+ /// Data type of A elements -+ typename ElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Element type of C matrix -+ typename ElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC -+> -+struct Testbed { -+ -+ /// Thread-level matrix multiply-accumulate operator -+ using Mma = cutlass::gemm::thread::Mma< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC -+ >; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ Testbed() { -+ -+ tensor_A.reset(cutlass::make_Coord(Shape::kM, Shape::kK)); -+ tensor_B.reset(cutlass::make_Coord(Shape::kK, Shape::kN)); -+ tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); -+ } -+ -+ static inline bool check_nvrtc_error(nvrtcResult error) { -+ if (error != NVRTC_SUCCESS) { -+ std::cerr << "failed to compile "; -+ return false; -+ } -+ return true; -+ } -+ -+ /// Runs the test -+ bool run(std::string const &gemm_traits) { -+ -+ // -+ // initialize device memory -+ // -+ -+ cutlass::reference::host::BlockFillSequential( -+ tensor_A.host_data(), -+ tensor_A.capacity() -+ ); -+ -+ cutlass::reference::host::BlockFillSequential( -+ tensor_B.host_data(), -+ tensor_B.capacity(), -+ ElementB(1), -+ ElementB(2) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_C.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_computed.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_reference.host_view(), -+ ElementC(0) -+ ); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ -+#if 0 -+ // launch kernel -+ cutlass::gemm::kernel::testbed_kernel<<< dim3(1, 1), dim3(1, 1, 1) >>>( -+ tensor_D_computed.device_data(), -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_C.device_data()); -+ -+#else -+ // Instantiate gemm_kernel -+ nvrtcResult result_nvrtc; -+ nvrtcProgram program; -+ static char const *src = -+ "#include 
\"cutlass/gemm/thread/mma.h\"\n" -+ "#include \"cutlass/gemm/gemm.h\"\n" -+ "#include \"cutlass/layout/matrix.h\"\n" -+ "#include \"unit/nvrtc/kernel/thread/testbed_kernel.h\"\n" -+ ; -+ -+ std::string type_name; -+#if 0 -+ // TODO Ideally we'd use nvrtcGetTypeName to determine the type, but it cannot resolve enum symbol names -+ // As altername solution we might want to implement to_string() to get the traits string. -+ nvrtcGetTypeName(&type_name); -+#else -+ type_name = gemm_traits; -+#endif -+ -+ result_nvrtc = nvrtcCreateProgram(&program, -+ src, -+ NULL, -+ (int)cutlass::nvrtc::kCutlassHeaderCount, -+ cutlass::nvrtc::kCutlassHeaders, -+ cutlass::nvrtc::kCutlassHeaderNames); -+ check_nvrtc_error(result_nvrtc); -+ -+ std::string gemm_kernel_instantiation = -+ "test::nvrtc::kernel::thread::testbed_kernel< " + type_name + " >"; -+ nvrtcAddNameExpression(program, gemm_kernel_instantiation.c_str()); -+ -+ const char *opts[] = {"--gpu-architecture=compute_75", -+ "--std=c++11", -+ "--include-path=/usr/local/cuda-10.1/include"}; -+ -+ result_nvrtc = nvrtcCompileProgram(program, 3, opts); -+ if (result_nvrtc != NVRTC_SUCCESS) { -+ size_t logSize; -+ nvrtcGetProgramLogSize(program, &logSize); -+ std::vector log(logSize); -+ nvrtcGetProgramLog(program, log.data()); -+ std::cout << "Compile log:" << std::endl << log.data() << std::endl; -+ } -+ if (!check_nvrtc_error(result_nvrtc)) { -+ assert(0); -+ } -+ -+ // The lowered name is the name of the template instantiation in the generated PTX code. -+ char const *gemm_kernel_lowered_name; -+ nvrtcGetLoweredName(program, gemm_kernel_instantiation.c_str(), &gemm_kernel_lowered_name); -+ if (!check_nvrtc_error(result_nvrtc)) { -+ assert(0); -+ } -+ -+ // Query the size of the genereated PTX so that we can allocate storage and retrieve it afterwards -+ size_t ptx_size; -+ result_nvrtc = nvrtcGetPTXSize(program, &ptx_size); -+ if (!check_nvrtc_error(result_nvrtc)) { -+ assert(0); -+ } -+ -+ std::vector ptx(ptx_size); -+ result_nvrtc = nvrtcGetPTX(program, ptx.data()); -+ if (!check_nvrtc_error(result_nvrtc)) { -+ assert(0); -+ } -+ -+ // we do not need the nvrtc program anymore -+ //nvrtcDestroyProgram(&program); -+ -+ CUmodule module; -+ CUresult result_cuda; -+ result_cuda = cuModuleLoadDataEx(&module, ptx.data(), 0, 0, 0); -+ if (result_cuda != CUDA_SUCCESS) { -+ assert(0); -+ } -+ -+ CUfunction kernel; -+ result_cuda = cuModuleGetFunction(&kernel, module, gemm_kernel_lowered_name); -+ if (result_cuda != CUDA_SUCCESS) { -+ assert(0); -+ } -+ -+ void* d_a = (void*)tensor_A.device_data(); -+ void* d_b = (void*)tensor_B.device_data(); -+ void* d_c = (void*)tensor_C.device_data(); -+ void* d_d = (void*)tensor_D_computed.device_data(); -+ void* args[] = { &d_d, &d_a, &d_b, &d_c }; -+ -+ // CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra -+ result_cuda = cuLaunchKernel(kernel, 1, 1, 1, 1, 1, 1, 0, 0 /*cudaStreamDefault*/, args, 0); -+ if (result_cuda != CUDA_SUCCESS) { -+ assert(0); -+ } else { -+} -+#endif -+ -+ // verify no errors -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cout << "CUDA ERROR: " << cudaGetErrorString(result); -+ return false; -+ } -+ -+ tensor_D_computed.sync_host(); -+ -+ // -+ // Reference implementation -+ // -+ -+ //tensor_D_reference.fill(tensor_C.host_view()); -+ -+ cutlass::reference::host::Gemm 
reference_gemm; -+ -+ reference_gemm( -+ {Shape::kM, Shape::kN, Shape::kK}, -+ ElementC(1), -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ ElementC(0), -+ tensor_D_reference.host_ref() -+ ); -+ -+ // -+ // Verify equivalence -+ // -+ -+ // compare -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view() -+ ); -+ -+ if(!passed) std::cout -+ << "A:\n" << tensor_A.host_view() << "\n\n" -+ << "B:\n" << tensor_B.host_view() << "\n\n" -+ << "C:\n" << tensor_C.host_view() << "\n\n" -+ << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" -+ << "Computed:\n" << tensor_D_computed.host_view() << std::endl; -+ -+ std::cout << "passed " << passed << std::endl; -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace nvrtc -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/pipeline/pipeline_async.cu b/3rdparty/cutlass/test/unit/pipeline/pipeline_async.cu -new file mode 100644 -index 0000000..d2adad6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/pipeline/pipeline_async.cu -@@ -0,0 +1,468 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ \brief Unit test for the PipelineAsync class -+*/ -+ -+#define KERNEL_DBG_TRACE false -+ -+#include "../common/cutlass_unit_test.h" -+#include -+#include -+ -+#include -+#include -+ -+#include -+#include -+ -+#include "cutlass/core_io.h" -+ -+#include "cutlass/util/print_error.hpp" -+#include "cutlass/util/GPU_Clock.hpp" -+ -+#include "testbed.h" -+#include "cutlass/pipeline.hpp" -+#include "cutlass/arch/barrier.h" -+#include "cute/arch/cluster_sm90.hpp" -+ -+using namespace cute; -+ -+//////////////////// KERNEL ///////////////////////// -+ -+template -+struct SharedStorage -+{ -+ typename cutlass::PipelineAsync::SharedStorage storage; -+}; -+ -+// Goal of this kernel is to complete deadlock-free -+// Simple 1 producer warp, one consumer warp scenario -+template -+__global__ static -+void pipeline_async_basic_device(uint32_t const num_iterations) -+{ -+ -+ extern __shared__ char shared_memory[]; -+ using MainloopPipeline = typename cutlass::PipelineAsync; -+ using PipelineState = typename cutlass::PipelineState; -+ -+ using SharedStorage = SharedStorage; -+ SharedStorage& shared_storage = *reinterpret_cast(shared_memory); -+ -+ -+ auto cta_layout = Layout{}; // (m,n) -> cta_id -+ -+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ int lane_predicate = cute::elect_one_sync(); -+ dim3 block_id_in_cluster = cute::block_id_in_cluster(); -+ auto cluster_shape = ClusterShape{}; -+ -+ // This example showcases 2 producer 1 consumer example -+ typename MainloopPipeline::Params params; -+ params.producer_arv_count = 2; -+ params.consumer_arv_count = 1; -+ MainloopPipeline pipeline(shared_storage.storage, params); -+ -+ // Ensure All CTAs in Cluster have completed init before issuing commits -+ cute::cluster_arrive_relaxed(); -+ cute::cluster_wait(); -+ __syncthreads(); -+ -+ if (lane_predicate) { -+ // Producer Warps -+ if (warp_idx==0 || warp_idx==1) { -+ -+ int prologue_iterations = min(NumStages, num_iterations); -+ for ( int i = 0; i < prologue_iterations; ++i) { -+ // Can also specify stage to commit directly -+ pipeline.producer_commit(i); -+ } -+ -+ int mainloop_iterations = num_iterations - prologue_iterations; -+ -+ // Only the mainloop needs a PipelineState because this is where we start "waiting" (acquiring) -+ PipelineState smem_pipe_write; -+ -+ for ( ; mainloop_iterations > 0; --mainloop_iterations) { -+ pipeline.producer_acquire(smem_pipe_write); -+ pipeline.producer_commit(smem_pipe_write); -+ ++smem_pipe_write; -+ } -+ } -+ else { -+ PipelineState smem_pipe_read; -+ for (int iter=0 ; iter < num_iterations; ++iter) { -+ pipeline.consumer_wait(smem_pipe_read); -+ pipeline.consumer_release(smem_pipe_read.index()); -+ ++smem_pipe_read; -+ } -+ } -+ } -+ -+ // To make sure remote SMEM doesn't get destroyed -+ cute::cluster_arrive(); -+ cute::cluster_wait(); -+} -+///////////////////////////////////////////////////// -+ -+template -+struct PipelineTest { -+ -+ // -+ // Data members -+ // -+ static constexpr uint32_t Stages = Stages_; -+ static constexpr uint32_t kBlockSize = 96; -+ using ClusterShape = ClusterShape_; -+ -+ // -+ // Methods -+ // -+ -+ // Ctor -+ PipelineTest() = default; -+ -+ -+ // Run CuTe GEMM kernel -+ cudaError_t run(uint32_t const kNumIters, -+ cudaStream_t stream = nullptr) { -+ -+ // Pipeline (multistage pipeline) -+ auto num_stages = Int{}; -+ -+ auto cluster_shape = Shape, Int, _1>{}; -+ -+ // -+ // Configure and launch -+ // -+ int iterations = 2; -+ cudaError_t result; -+ -+ for (int iter = 0; iter < iterations; ++iter) { -+ -+ // 
Define the tiled MMA layout (static, 4warps) -+ using MainloopPipeline = typename cutlass::PipelineAsync; -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ -+ result = cudaFuncSetAttribute( -+ pipeline_async_basic_device, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ // Launch a single Cluster, with 128 thread per CTA -+ dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), 1); -+ dim3 dimGrid(size<0>(cluster_shape), size<1>(cluster_shape), 1); -+ dim3 dimBlock(kBlockSize,1,1); -+ -+ const void* kernel = (const void*)pipeline_async_basic_device; -+ int iters = kNumIters; -+ void* kernel_params[] = {reinterpret_cast(&iters)}; -+ cutlass::ClusterLauncher::launch(dimGrid, dimCluster, dimBlock, smem_size, stream, kernel, kernel_params); -+ -+ } // profiling loop ends -+ -+ result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Error: cudaDeviceSynchronize() failed" << std::endl; -+ return result; -+ } -+ -+ return cudaSuccess; -+ } -+ -+}; -+ -+#if CUDA_12_0_SM90_FEATURES_SUPPORTED -+TEST(SM90_Verify_PipelineAsync, Cluster1x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster1x1_Stage5) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 5; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster1x1_Stage10) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 10; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster2x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster2x2_Stage5) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 5; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster2x2_Stage10) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 10; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster1x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster1x2_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster1x2_Stage10) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; -+ static constexpr uint32_t Stages = 10; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ 
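// Every SM90_Verify_PipelineAsync test in this file follows the same template: pick a cluster
// shape and a stage count, alias them into PipelineTest, and let the Options/Testbed harness
// (declared in the pipeline testbed.h, not reproduced in this hunk) drive verification. Driving
// the test type directly would look roughly like this (the iteration count is an arbitrary
// illustrative value, and the Stages-first parameter order is assumed from the struct above):
//
//   PipelineTest<10, cutlass::gemm::GemmShape<1, 2, 1>> test;
//   cudaError_t status = test.run(/*kNumIters=*/1024);   // cudaSuccess when the kernel exits cleanly
//
// The kernel only checks that producer and consumer warps make forward progress without
// deadlocking; it does not validate any numerical results.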
EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster2x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster2x1_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x1_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster1x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster1x4_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster2x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster2x4_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x2_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage3) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 3; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage4) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ 
static constexpr uint32_t Stages = 4; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage5) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 5; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage6) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 6; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage8) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 8; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage9) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 9; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage10) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 10; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage11) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 11; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+#endif -diff --git a/3rdparty/cutlass/test/unit/pipeline/pipeline_tma_async.cu b/3rdparty/cutlass/test/unit/pipeline/pipeline_tma_async.cu -new file mode 100644 -index 0000000..90e0ca3 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/pipeline/pipeline_tma_async.cu -@@ -0,0 +1,469 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Unit test for the PipelineTmaAsync class -+*/ -+ -+ -+#define KERNEL_DBG_TRACE false -+ -+#include "../common/cutlass_unit_test.h" -+#include -+#include -+ -+#include -+#include -+ -+#include -+#include -+ -+#include "cutlass/core_io.h" -+ -+#include "cutlass/util/print_error.hpp" -+#include "cutlass/util/GPU_Clock.hpp" -+ -+#include "testbed.h" -+#include "cutlass/pipeline.hpp" -+#include "cutlass/arch/barrier.h" -+#include "cute/arch/cluster_sm90.hpp" -+ -+using namespace cute; -+ -+//////////////////// KERNEL ///////////////////////// -+ -+template -+struct SharedStorage -+{ -+ typename cutlass::PipelineTmaAsync::SharedStorage storage; -+}; -+ -+// Goal of this kernel is to complete deadlock-free -+template -+__global__ static -+void pipeline_device(uint32_t const NumIterations) -+{ -+ -+ extern __shared__ char shared_memory[]; -+ using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmma; -+ using MainloopPipeline = cutlass::PipelineTmaAsync; -+ using PipelineState = cutlass::PipelineState; -+ -+ using SharedStorage = SharedStorage; -+ SharedStorage& shared_storage = *reinterpret_cast(shared_memory); -+ -+ auto cta_layout = Layout{}; // (m,n) -> cta_id -+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ int warp_group_thread_idx = threadIdx.x % 128; -+ dim3 block_id_in_cluster = cute::block_id_in_cluster(); -+ -+ auto cluster_shape = ClusterShape{}; -+ -+ // #Producers = #RowsInCluster + #ColsInCluster - 1 -+ uint32_t const NumProducers = cute::size<0>(cluster_shape) + cute::size<1>(cluster_shape) - 1; -+ uint32_t const TmaTransactionBytes = sizeof(uint32_t) * NumProducers; -+ uint32_t const per_cta_bytes = sizeof(uint32_t); -+ -+ // mbarrier.init -+ typename MainloopPipeline::Params params; -+ params.transaction_bytes = TmaTransactionBytes; -+ params.role = MainloopPipeline::ThreadCategory::ProducerConsumer; -+ params.is_leader = warp_group_thread_idx == 0; -+ params.num_consumers = 128; -+ -+ MainloopPipeline pipeline(shared_storage.storage, params); -+ -+ __syncthreads(); -+ -+ // Ensure All CTAs in Cluster have completed init before issuing commits -+ cute::cluster_arrive_relaxed(); -+ cute::cluster_wait(); -+ -+ // Total number of gemm_k_iterations -+ auto mma_k_iterations = NumIterations; -+ auto tma_k_iterations = NumIterations; -+ -+ PipelineState smem_pipe_read; -+ // For the DMA (prologue) - we start with an opposite phase - since we skip all waits -+ // i.e., we know that the buffer is indeed empty -+ PipelineState smem_pipe_write = cutlass::make_producer_start_state(); -+ PipelineState smem_pipe_release; -+ int K_TILE_MMAS = 1; -+ -+ 
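  // A rough sketch of the hand-off that the prologue and mainloop below model,
  // using only the names already declared above (an outline, not extra steps):
  //
  //   producer:  pipeline.producer_acquire(smem_pipe_write);             // wait for a free stage
  //              pipeline.producer_commit(smem_pipe_write.index(),
  //                                       per_cta_bytes);                // mark the stage as filled
  //              ++smem_pipe_write;                                      // advance; phase flips on wrap
  //
  //   consumer:  pipeline.consumer_wait(smem_pipe_read);                 // block until the stage is filled
  //              pipeline.consumer_release(smem_pipe_release);           // hand the stage back to the producer
  //              ++smem_pipe_read; ++smem_pipe_release;
  //
  // Roughly speaking, smem_pipe_write comes from make_producer_start_state() so its
  // phase bit starts inverted; that is why the prologue's producer_acquire calls
  // complete immediately -- the buffers are known to be empty, so no wait is needed.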
int lane_predicate = cute::elect_one_sync(); -+ int k_pipe_tma_prologue = min(NumStages, tma_k_iterations); -+ -+ // DMA Prologue (Loads) -+ CUTLASS_PRAGMA_UNROLL -+ for(int i = 0; i < k_pipe_tma_prologue; ++i) { -+ pipeline.producer_acquire(smem_pipe_write); -+ // cp.async.bulk.tensor would typically happen here -+ pipeline.producer_commit(smem_pipe_write.index(), per_cta_bytes); -+ ++smem_pipe_write; -+ } -+ tma_k_iterations -= k_pipe_tma_prologue; -+ -+ // MMA Prologue (Compute) - modeling inflight MMAs -+ for (int iter = 0; iter < K_TILE_MMAS; ++iter) -+ { -+ pipeline.consumer_wait(smem_pipe_read); -+ warpgroup_arrive(); -+ // GMMA would typically happen here -+ -+ ++smem_pipe_read; -+ } -+ -+ mma_k_iterations -= K_TILE_MMAS; -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int iter = 0; iter < mma_k_iterations; ++iter) -+ { -+ pipeline.consumer_wait(smem_pipe_read); -+ -+ warpgroup_arrive(); -+ // GMMA would typically happen here -+ -+ pipeline.consumer_release(smem_pipe_release); -+ -+ if (lane_predicate && (warp_idx == 0) && (tma_k_iterations > 0)) { -+ pipeline.producer_acquire(smem_pipe_write); -+ // cp.async.bulk.tensor would typically happen here -+ pipeline.producer_commit(smem_pipe_write.index(), per_cta_bytes); -+ ++smem_pipe_write; -+ --tma_k_iterations; -+ } -+ -+ // next read stage -+ ++smem_pipe_read; -+ ++smem_pipe_release; -+ } -+ -+ // To make sure remote SMEM doesn't get destoryed -+ cute::cluster_arrive(); -+ cute::cluster_wait(); -+} -+///////////////////////////////////////////////////// -+ -+/// Device NT GMMA + TMA specialized -+template -+struct PipelineTest { -+ -+ // -+ // Data members -+ // -+ static constexpr uint32_t Stages = Stages_; -+ static constexpr uint32_t kBlockSize = 128; -+ using ClusterShape = ClusterShape_; -+ -+ // -+ // Methods -+ // -+ -+ // Ctor -+ PipelineTest(){}; -+ -+ -+ // Run CuTe GEMM kernel -+ cudaError_t run(uint32_t const kNumIters, -+ cudaStream_t stream = 0) { -+ -+ float elapsed_ms = 0.0f; -+ // Pipeline (multistage pipeline) -+ auto num_stages = Int{}; -+ -+ auto cluster_shape = Shape, Int, _1>{}; -+ -+ // -+ // Configure and launch -+ // -+ int iterations = 1; -+ cudaEvent_t events[2]; -+ cudaError_t result; -+ -+ for (cudaEvent_t & event : events) { -+ result = cudaEventCreate(&event); -+ if (result != cudaSuccess) { -+ std::cerr << "Error: Failed to create event."; -+ return result; -+ } -+ } -+ -+ result = cudaEventRecord(events[0]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Error: Failed to record start event."; -+ return result; -+ } -+ -+ for (int iter = 0; iter < iterations; ++iter) { -+ -+ // Define the tiled MMA layout (static, 4warps) -+ using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmma; -+ using MainloopPipeline = typename cutlass::PipelineTmaAsync; -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ -+ result = cudaFuncSetAttribute( -+ pipeline_device, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ // Launch a single Cluster, with 128 thread per CTA -+ dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), 1); -+ dim3 dimGrid(size<0>(cluster_shape), size<1>(cluster_shape), 1); -+ dim3 dimBlock(kBlockSize,1,1); -+ -+ const void* kernel = (const void*)pipeline_device; -+ int iters = kNumIters; -+ void* kernel_params[] = {reinterpret_cast(&iters)}; -+ cutlass::ClusterLauncher::launch(dimGrid, dimCluster, dimBlock, smem_size, stream, kernel, kernel_params); -+ -+ } // profiling loop ends -+ -+ result = cudaEventRecord(events[1]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << 
"Error: Failed to record stop event."; -+ return result; -+ } -+ -+ result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Error: cudaDeviceSynchronize() failed" << std::endl; -+ return result; -+ } -+ -+ result = cudaEventElapsedTime(&elapsed_ms, events[0], events[1]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to create event."; -+ return result; -+ } -+ -+ for (cudaEvent_t & event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ return cudaSuccess; -+ } -+}; -+ -+#if CUDA_12_0_SM90_FEATURES_SUPPORTED -+TEST(SM90_Verify_PipelineTmaAsync, Cluster1x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster1x1_Stage5) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 5; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster1x1_Stage10) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 10; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster2x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster2x2_Stage5) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 5; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster2x2_Stage10) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 10; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster4x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster4x4_Stage10) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 10; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster1x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster1x2_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster1x2_Stage10) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; -+ 
static constexpr uint32_t Stages = 10; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster2x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster2x1_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster4x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster4x1_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster1x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster1x4_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster2x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster2x4_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster4x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster4x2_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+#endif -diff --git a/3rdparty/cutlass/test/unit/pipeline/pipeline_tma_async_warp_specialized.cu b/3rdparty/cutlass/test/unit/pipeline/pipeline_tma_async_warp_specialized.cu -new file mode 100644 -index 0000000..f0d6a79 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/pipeline/pipeline_tma_async_warp_specialized.cu -@@ -0,0 +1,525 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Unit test for the PipelineTmaAsync class as it would be used in a Warp specialized loop -+*/ -+ -+#define KERNEL_DBG_TRACE false -+ -+#include "../common/cutlass_unit_test.h" -+#include -+#include -+ -+#include -+#include -+ -+#include -+#include -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/print_error.hpp" -+#include "cutlass/util/GPU_Clock.hpp" -+ -+#include "testbed.h" -+#include "cutlass/pipeline.hpp" -+#include "cutlass/arch/barrier.h" -+#include "cute/arch/cluster_sm90.hpp" -+#include "cutlass/arch/barrier.h" -+#include "cutlass/arch/reg_reconfig.h" -+ -+ -+using namespace cute; -+using namespace cutlass; -+ -+//////////////////// KERNEL ///////////////////////// -+ -+template -+struct SharedStorage -+{ -+ typename cutlass::PipelineTmaAsync::SharedStorage storage ; -+}; -+ -+struct KernelParams -+{ -+ uint32_t num_iterations; -+ int* data_ptr; -+}; -+ -+// Goal of this kernel is to complete deadlock-free -+template -+__launch_bounds__(384, 1) -+__global__ static -+void pipeline_device(KernelParams const kernel_params) -+{ -+ extern __shared__ char shared_memory[]; -+ using MainloopPipeline = typename cutlass::PipelineTmaAsync; -+ using PipelineState = typename cutlass::PipelineState; -+ -+ using SharedStorage = SharedStorage; -+ SharedStorage& shared_storage = *reinterpret_cast(shared_memory); -+ -+ auto cta_layout = Layout{}; // (m,n) -> cta_id -+ int warp_group_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); -+ int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); -+ int warp_group_thread_idx = threadIdx.x % 128; -+ dim3 block_id_in_cluster = cute::block_id_in_cluster(); -+ -+ auto cluster_shape = ClusterShape{}; -+ -+ // #Producers = #RowsInCluster + #ColsInCluster - 1 -+ uint32_t const NumProducers = cute::size<0>(cluster_shape) + 
cute::size<1>(cluster_shape) - 1; -+ uint32_t const TmaTransactionBytes = static_cast(sizeof(uint32_t) * NumProducers); -+ uint32_t const per_cta_bytes = sizeof(uint32_t); -+ -+ // mbarrier.init -+ typename MainloopPipeline::Params params; -+ params.transaction_bytes = TmaTransactionBytes; -+ if (warp_group_idx == 0) { -+ params.role = MainloopPipeline::ThreadCategory::Producer; -+ } -+ else { -+ params.role = MainloopPipeline::ThreadCategory::Consumer; -+ } -+ params.is_leader = warp_group_thread_idx == 0; -+ params.num_consumers = 128; -+ -+ MainloopPipeline pipeline(shared_storage.storage, params); -+ -+ __syncthreads(); -+ -+ // Ensure All CTAs in Cluster have completed init before issuing commits -+ cute::cluster_arrive_relaxed(); -+ cute::cluster_wait(); -+ -+ -+ // Producer WarpGroup -+ if (warp_group_idx == 0) { -+ cutlass::arch::warpgroup_reg_alloc<232>(); -+ -+ int lane_predicate = cute::elect_one_sync(); -+ if (warp_idx_in_warpgroup == 0 && lane_predicate) { -+ -+ int tma_k_prologue = min(Stages, kernel_params.num_iterations); -+ -+ // Simulating Prologue TMA Loads -+ // For the DMA (prologue) - we start with an opposite phase - since we skip all waits -+ // i.e., we know that the buffer is indeed empty -+ PipelineState smem_pipe_write = make_producer_start_state(); -+ CUTLASS_PRAGMA_UNROLL -+ for(int i = 0; i < tma_k_prologue; ++i) { -+ pipeline.producer_acquire(smem_pipe_write); -+ // Simulating cp.async.bulk.tensor behavior -+ pipeline.producer_commit(smem_pipe_write.index(), per_cta_bytes); -+ ++smem_pipe_write; -+ } -+ int tma_k_iter = kernel_params.num_iterations - tma_k_prologue; -+ -+ // Simulating Mainloop TMA Loads -+ CUTE_NO_UNROLL -+ for ( ; tma_k_iter > 0; --tma_k_iter) { -+ -+ pipeline.producer_acquire(smem_pipe_write); -+ -+ // Simulating cp.async.bulk.tensor behavior -+ pipeline.producer_commit(smem_pipe_write.index(), per_cta_bytes); -+ -+ // Advance write stage -+ ++smem_pipe_write; -+ } -+ -+ // Tail Loop -+ // Handles the case where we never enter the mainloop -+ PipelineState tail = tma_k_prologue == Stages ? smem_pipe_write : PipelineState{}; -+ for ( int i = 0; i < tma_k_prologue; ++i) { -+ pipeline.producer_acquire(tail); -+ ++tail; -+ } -+ } -+ // Consumer WarpGroup -+ } else if(warp_group_idx == 1) { -+ cutlass::arch::warpgroup_reg_alloc<232>(); -+ -+ PipelineState smem_pipe_read; -+ PipelineState smem_pipe_release; -+ -+ // simulates accumulators + extra reg. pressure -+ int arr[168]; -+ -+ // Init Shared Memory read stages & PhaseBit -+ static constexpr uint32_t K_PIPE_MMAS = 1; -+ static_assert( K_PIPE_MMAS < Stages, "ERROR : Too many MMAs in flight"); -+ -+ // Total number of gemm iterations -+ auto gemm_k_iterations = kernel_params.num_iterations; -+ -+ // Simulating Prologue MMAs -+ int mma_k_prologue = min(K_PIPE_MMAS, gemm_k_iterations); -+ CUTLASS_PRAGMA_UNROLL -+ for (int iter = 0; iter < mma_k_prologue; ++iter) { -+ pipeline.consumer_wait(smem_pipe_read); -+ -+ warpgroup_arrive(); -+ // GMMA would typically happen here -+ -+ ++smem_pipe_read; -+ } -+ gemm_k_iterations -= mma_k_prologue; -+ -+ // Simulating Mainloop MMAs -+ CUTLASS_PRAGMA_NO_UNROLL -+ for ( ; gemm_k_iterations > 0; --gemm_k_iterations) { -+ -+ /// Wait on the smem_pipe_read stage / phase -+ pipeline.consumer_wait(smem_pipe_read); -+ -+ warpgroup_arrive(); -+ // GMMA would typically happen here -+ -+ // Dummy op - which will never happen -+ // But simulates high register usage. 
-+ CUTE_UNROLL -+ for(int i = 0; i < 168; ++i){ -+ if (threadIdx.x > 256){ -+ arr[i] += kernel_params.data_ptr[i]; -+ } -+ } -+ -+ pipeline.consumer_release(smem_pipe_release); -+ -+ // Advance stages -+ ++smem_pipe_read; -+ ++smem_pipe_release; -+ } -+ -+ // Dummy op - which will never happen -+ CUTE_UNROLL -+ for(int i = 0; i < 168; ++i){ -+ if (threadIdx.x > 256){ -+ kernel_params.data_ptr[i] = arr[i]; -+ } -+ } -+ -+ // Tail Loop -+ for (int i = 0; i < K_PIPE_MMAS; ++i){ -+ pipeline.consumer_release(smem_pipe_release); -+ ++smem_pipe_release; -+ } -+ -+ // Warp-Group #2 -+ } else { -+ cutlass::arch::warpgroup_reg_dealloc<40>(); -+ } -+} -+///////////////////////////////////////////////////// -+ -+/// Device NT GMMA + TMA specialized -+template -+struct PipelineTest { -+ -+ // -+ // Data members -+ // -+ static constexpr uint32_t Stages = Stages_; -+ static constexpr uint32_t kBlockSize = 128 * 3; -+ using ClusterShape = ClusterShape_; -+ -+ // -+ // Methods -+ // -+ -+ // Ctor -+ PipelineTest(){}; -+ -+ // Run CuTe GEMM kernel -+ cudaError_t run(uint32_t const kNumIters, -+ cudaStream_t stream = 0) { -+ -+ float elapsed_ms = 0.0f; -+ // Pipeline (multistage pipeline) -+ auto num_stages = Int{}; -+ auto cluster_shape = Shape, Int, _1>{}; -+ -+ // -+ // Configure and launch -+ // -+ int iterations = 1; -+ cudaEvent_t events[2]; -+ cudaError_t result; -+ -+ for (cudaEvent_t & event : events) { -+ result = cudaEventCreate(&event); -+ if (result != cudaSuccess) { -+ std::cerr << "Error: Failed to create event."; -+ return result; -+ } -+ } -+ -+ result = cudaEventRecord(events[0]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Error: Failed to record start event."; -+ return result; -+ } -+ -+ for (int iter = 0; iter < iterations; ++iter) { -+ -+ using MainloopPipeline = typename cutlass::PipelineTmaAsync; -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ -+ result = cudaFuncSetAttribute( -+ pipeline_device, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ // Launch a single Cluster, with kBlockSize threads per CTA -+ dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), 1); -+ dim3 dimGrid(size<0>(cluster_shape), size<1>(cluster_shape), 1); -+ dim3 dimBlock(kBlockSize,1,1); -+ -+ const void* kernel = (const void*)pipeline_device; -+ KernelParams params{kNumIters, nullptr}; -+ void* kernel_params[] = {reinterpret_cast(¶ms)}; -+ cutlass::ClusterLauncher::launch(dimGrid, dimCluster, dimBlock, smem_size, stream, kernel, kernel_params); -+ -+ } -+ -+ result = cudaEventRecord(events[1]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Error: Failed to record stop event."; -+ return result; -+ } -+ -+ result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Error: cudaDeviceSynchronize() failed" << std::endl; -+ return result; -+ } -+ -+ result = cudaEventElapsedTime(&elapsed_ms, events[0], events[1]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to create event."; -+ return result; -+ } -+ -+ for (cudaEvent_t & event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ return cudaSuccess; -+ } -+}; -+ -+#if CUDA_12_0_SM90_FEATURES_SUPPORTED -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster1x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster1x1_Stage5) { -+ Options options; -+ 
using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 5; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster1x1_Stage10) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 10; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster2x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster2x2_Stage5) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 5; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster2x2_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster4x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster4x4_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster2x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster2x1_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster1x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster1x2_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster4x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster4x1_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; -+ static constexpr uint32_t Stages = 7; -+ using 
Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster1x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster1x4_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster2x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster2x4_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster4x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster4x2_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+#endif -diff --git a/3rdparty/cutlass/test/unit/pipeline/pipeline_tma_async_warp_specialized_persistent.cu b/3rdparty/cutlass/test/unit/pipeline/pipeline_tma_async_warp_specialized_persistent.cu -new file mode 100644 -index 0000000..4b6a3b1 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/pipeline/pipeline_tma_async_warp_specialized_persistent.cu -@@ -0,0 +1,585 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Unit test for the PipelineTmaAsync class used in a WarpSpecialized Persistent loop -+*/ -+ -+#define KERNEL_DBG_TRACE false -+ -+#include "../common/cutlass_unit_test.h" -+#include -+#include -+ -+#include -+#include -+ -+#include -+#include -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/print_error.hpp" -+#include "cutlass/util/GPU_Clock.hpp" -+ -+#include "testbed.h" -+#include "cutlass/pipeline.hpp" -+#include "cutlass/arch/barrier.h" -+#include "cute/arch/cluster_sm90.hpp" -+#include "cutlass/arch/barrier.h" -+#include "cutlass/arch/reg_reconfig.h" -+ -+ -+using namespace cute; -+using namespace cutlass; -+ -+//////////////////// KERNEL ///////////////////////// -+ -+template -+struct SharedStorage -+{ -+ typename cutlass::PipelineTmaAsync::SharedStorage pipeline_storage; -+ typename PingPongBarrier::SharedStorage pingpong_storage; -+}; -+ -+template -+struct CollectiveSimulation { -+ using MainloopPipeline = typename cutlass::PipelineTmaAsync; -+ using PipelineState = typename cutlass::PipelineState; -+ -+ CUTLASS_DEVICE -+ static void -+ dma_wg_simulation(MainloopPipeline pipeline, PipelineState tile_start_state_pipe, -+ uint32_t const num_iterations) { -+ uint32_t const per_cta_bytes = sizeof(uint32_t); -+ int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); -+ int lane_predicate = cute::elect_one_sync(); -+ if (warp_idx_in_warpgroup==0 && lane_predicate) { -+ -+ int tma_k_prologue = min(Stages, num_iterations); -+ -+ // Simulating Prologue TMA Loads -+ CUTLASS_PRAGMA_UNROLL -+ for(int i = 0; i < tma_k_prologue; ++i) { -+ pipeline.producer_acquire(tile_start_state_pipe); -+ // Simulating cp.async.bulk.tensor behavior -+ pipeline.producer_commit(tile_start_state_pipe.index(), per_cta_bytes); -+ ++tile_start_state_pipe; -+ } -+ int tma_k_iter = num_iterations - tma_k_prologue; -+ -+ PipelineState wr_pipe = tile_start_state_pipe; -+ // Simulating Mainloop TMA Loads -+ CUTE_NO_UNROLL -+ for ( ; tma_k_iter > 0; --tma_k_iter){ -+ -+ pipeline.producer_acquire(wr_pipe); -+ -+ // Simulating cp.async.bulk.tensor behavior -+ pipeline.producer_commit(wr_pipe.index(), per_cta_bytes); -+ -+ // Advance write stage -+ ++wr_pipe; -+ } -+ } -+ } -+ -+ CUTLASS_DEVICE -+ static void -+ math_wg_simulation(MainloopPipeline pipeline, PipelineState tile_start_state_pipe, -+ uint32_t const num_iterations, int* data_ptr) { -+ PipelineState rd_pipe = tile_start_state_pipe; -+ PipelineState release_pipe = rd_pipe; -+ -+ // simulates accumulators + extra reg. 
pressure -+ int arr[168]; -+ -+ // Init Shared Memory read stages & PhaseBit -+ static constexpr uint32_t K_PIPE_MMAS = 1; -+ static_assert( K_PIPE_MMAS < Stages, "ERROR : Too many MMAs in flight"); -+ -+ // Total number of gemm iterations -+ auto gemm_k_iterations = num_iterations; -+ -+ // Simulating Prologue MMAs -+ int mma_k_prologue = min(K_PIPE_MMAS, gemm_k_iterations); -+ CUTLASS_PRAGMA_UNROLL -+ for (int iter = 0; iter < mma_k_prologue; ++iter) { -+ pipeline.consumer_wait(rd_pipe); -+ -+ warpgroup_arrive(); -+ // GMMA would typically happen here -+ -+ ++rd_pipe; -+ } -+ gemm_k_iterations -= mma_k_prologue; -+ -+ // Simulating Mainloop MMAs -+ CUTLASS_PRAGMA_NO_UNROLL -+ for ( ; gemm_k_iterations > 0; --gemm_k_iterations) { -+ -+ /// Wait on the rd_pipe stage / phase -+ pipeline.consumer_wait(rd_pipe); -+ -+ warpgroup_arrive(); -+ // GMMA would typically happen here -+ -+ // Dummy op - which will never happen -+ // But simulates high register usage. -+ CUTE_UNROLL -+ for(int i = 0; i < 168; ++i){ -+ if (threadIdx.x > 384){ -+ arr[i] += data_ptr[i]; -+ } -+ } -+ -+ pipeline.consumer_release(release_pipe); -+ -+ // Advance stages -+ ++rd_pipe; -+ ++release_pipe; -+ } -+ -+ // Dummy op - which will never happen -+ CUTE_UNROLL -+ for(int i = 0; i < 168; ++i){ -+ if (threadIdx.x > 384){ -+ data_ptr[i] = arr[i]; -+ } -+ } -+ -+ // Tail Loop -+ for (int i = 0; i < K_PIPE_MMAS; ++i){ -+ pipeline.consumer_release(release_pipe); -+ ++release_pipe; -+ } -+ -+ } -+}; -+ -+struct KernelParams -+{ -+ uint32_t num_iterations; -+ int tiles_per_cluster; -+ int* data_ptr; -+}; -+ -+// Goal of this kernel is to complete deadlock-free -+template -+__launch_bounds__(384, 1) -+__global__ static -+void pipeline_device(KernelParams params) -+{ -+ extern __shared__ char shared_memory[]; -+ using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized; -+ using MainloopPipeline = typename cutlass::PipelineTmaAsync; -+ using PipelineState = typename cutlass::PipelineState; -+ -+ /* One for Mainloop and one for Epilogue */ -+ constexpr int StagesPerMathWarpGroup = 2; -+ constexpr int MathWarpGroupCountPersistent = 2; -+ using PingPongBarrier = typename cutlass::OrderedSequenceBarrier; -+ -+ using SharedStorage = SharedStorage; -+ SharedStorage& shared_storage = *reinterpret_cast(shared_memory); -+ -+ auto cta_layout = Layout{}; // (m,n) -> cta_id -+ int warp_group_idx = __shfl_sync(0xffffffff, threadIdx.x / NumThreadsPerWarpGroup, 0); -+ int warp_group_thread_idx = threadIdx.x % NumThreadsPerWarpGroup; -+ dim3 block_id_in_cluster = cute::block_id_in_cluster(); -+ -+ auto cluster_shape = ClusterShape{}; -+ -+ // #Producers = #RowsInCluster + #ColsInCluster - 1 -+ uint32_t const NumProducers = cute::size<0>(cluster_shape) + cute::size<1>(cluster_shape) - 1; -+ uint32_t const TmaTransactionBytes = static_cast(sizeof(uint32_t) * NumProducers); -+ -+ // mbarrier.init -+ typename MainloopPipeline::Params pipeline_params; -+ pipeline_params.transaction_bytes = TmaTransactionBytes; -+ if (warp_group_idx == 0) { -+ pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; -+ } -+ else { -+ pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; -+ } -+ pipeline_params.is_leader = warp_group_thread_idx == 0; -+ pipeline_params.num_consumers = NumThreadsPerWarpGroup; -+ -+ MainloopPipeline pipeline(shared_storage.pipeline_storage, pipeline_params); -+ PipelineState tile_start_state_pipe; -+ -+ int tiles_per_cluster = params.tiles_per_cluster; -+ -+ /* Offset pipeline start state for Math WG 
2 */ -+ if (warp_group_idx == 2) { -+ // Update pipeline state for next persistent tile -+ tile_start_state_pipe.advance(params.num_iterations); -+ tiles_per_cluster--; -+ } -+ -+ typename PingPongBarrier::Params pingpong_params; -+ pingpong_params.group_id = warp_group_idx - 1; // Since DMA Warp Group Idx 0 will not participate -+ pingpong_params.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group -+ PingPongBarrier math_wg_barrier(shared_storage.pingpong_storage, pingpong_params); -+ -+ __syncthreads(); -+ -+ // Ensure All CTAs in Cluster have completed init before issuing commits -+ cute::cluster_arrive_relaxed(); -+ cute::cluster_wait(); -+ -+ // Producer/DMA WarpGroup -+ if (warp_group_idx == 0) { -+ cutlass::arch::warpgroup_reg_dealloc<40>(); -+ // For the DMA (prologue) - we start with an opposite phase - since we skip all waits -+ // i.e., we know that the buffer is indeed empty -+ PipelineState tile_prologue_state_pipe = make_producer_start_state(); -+ while (tiles_per_cluster > 0) { -+ CollectiveSimulation::dma_wg_simulation(pipeline, tile_prologue_state_pipe, params.num_iterations); -+ // Update pipeline state for next persistent tile -+ tile_prologue_state_pipe.advance(params.num_iterations); -+ tiles_per_cluster--; -+ } -+ } -+ // Math WarpGropups -+ if(warp_group_idx == 1 || warp_group_idx == 2) { -+ cutlass::arch::warpgroup_reg_alloc<232>(); -+ while (tiles_per_cluster > 0) { -+ // MMA -+ math_wg_barrier.wait(); -+ CollectiveSimulation::math_wg_simulation(pipeline, tile_start_state_pipe, params.num_iterations, params.data_ptr); -+ math_wg_barrier.arrive(); -+ // Epilogue -+ math_wg_barrier.wait(); -+ // Simulates long running stage -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) -+ __nanosleep(100000); -+ #endif -+ math_wg_barrier.arrive(); -+ // Update pipeline state for next persistent tile -+ tile_start_state_pipe.advance(params.num_iterations * 2); -+ tiles_per_cluster -= 2; -+ } -+ } -+ -+ // Makes sure remote SMEM doesn't get destroyed -+ cute::cluster_arrive_relaxed(); -+ cute::cluster_wait(); -+} -+///////////////////////////////////////////////////// -+ -+/// Device NT GMMA + TMA specialized -+template -+struct PipelineTest { -+ -+ // -+ // Data members -+ // -+ static constexpr uint32_t Stages = Stages_; -+ static constexpr uint32_t kBlockSize = 128 * 3; -+ using ClusterShape = ClusterShape_; -+ -+ // -+ // Methods -+ // -+ -+ // Run CuTe GEMM kernel -+ cudaError_t run(uint32_t const kNumIters, -+ cudaStream_t stream = 0) { -+ -+ float elapsed_ms = 0.0f; -+ // Pipeline (multistage pipeline) -+ auto num_stages = Int{}; -+ auto cluster_shape = Shape, Int, _1>{}; -+ -+ // -+ // Configure and launch -+ // -+ int iterations = 1; -+ cudaEvent_t events[2]; -+ cudaError_t result; -+ -+ for (cudaEvent_t & event : events) { -+ result = cudaEventCreate(&event); -+ if (result != cudaSuccess) { -+ std::cerr << "Error: Failed to create event."; -+ return result; -+ } -+ } -+ -+ result = cudaEventRecord(events[0]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Error: Failed to record start event."; -+ return result; -+ } -+ -+ for (int iter = 0; iter < iterations; ++iter) { -+ -+ using MainloopPipeline = typename cutlass::PipelineTmaAsync; -+ -+ constexpr int StagesPerMathWarpGroup = 2; -+ constexpr int MathWarpGroupCountPersistent = 2; -+ int smem_size = int(sizeof(SharedStorage>)); -+ -+ result = cudaFuncSetAttribute( -+ pipeline_device, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ // Launch a single Cluster, 
with kBlockSize threads per CTA -+ dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), 1); -+ dim3 dimGrid(size<0>(cluster_shape), size<1>(cluster_shape), 1); -+ dim3 dimBlock(kBlockSize,1,1); -+ -+ int tiles_per_cluster = (kNumIters % 10) + 1; -+ printf("Persistent version: Tiles per Cluster = %d\n", tiles_per_cluster); -+ -+ const void* kernel = (const void*)pipeline_device; -+ KernelParams params{kNumIters, tiles_per_cluster, nullptr}; -+ void *kernel_params[] = {¶ms}; -+ cutlass::ClusterLauncher::launch(dimGrid, dimCluster, dimBlock, smem_size, stream, kernel, kernel_params); -+ -+ } -+ -+ result = cudaEventRecord(events[1]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Error: Failed to record stop event."; -+ return result; -+ } -+ -+ result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Error: cudaDeviceSynchronize() failed" << std::endl; -+ return result; -+ } -+ -+ result = cudaEventElapsedTime(&elapsed_ms, events[0], events[1]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to create event."; -+ return result; -+ } -+ -+ for (cudaEvent_t & event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ return cudaSuccess; -+ } -+}; -+ -+#if CUDA_12_0_SM90_FEATURES_SUPPORTED -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster1x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster1x1_Stage5) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 5; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster1x1_Stage10) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 10; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster2x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster2x2_Stage5) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 5; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster2x2_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster4x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster4x4_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 7; -+ using 
Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster2x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster2x1_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster1x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster1x2_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster4x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster4x1_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster1x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster1x4_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster2x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster2x4_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster4x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster4x2_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; -+ static constexpr uint32_t 
Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+#endif -diff --git a/3rdparty/cutlass/test/unit/pipeline/sequence_barrier.cu b/3rdparty/cutlass/test/unit/pipeline/sequence_barrier.cu -new file mode 100644 -index 0000000..f426ca0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/pipeline/sequence_barrier.cu -@@ -0,0 +1,226 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ \brief Unit test for the OrderedSequenceBarrier class -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+#include -+#include -+ -+#include -+#include -+ -+#include -+#include -+ -+#include "cutlass/core_io.h" -+ -+#include "cutlass/util/print_error.hpp" -+#include "cutlass/util/GPU_Clock.hpp" -+ -+#include "testbed.h" -+#include "cutlass/pipeline.hpp" -+#include "cutlass/arch/barrier.h" -+#include "cute/arch/cluster_sm90.hpp" -+ -+using namespace cute; -+ -+//////////////////// KERNEL ///////////////////////// -+ -+template -+struct SharedStorage -+{ -+ typename OrderedSequencer::SharedStorage storage; -+}; -+ -+// Goal of this kernel is to complete deadlock-free -+template -+__global__ static -+void ordered_sequence_device(uint32_t const num_iterations) -+{ -+ -+ extern __shared__ char shared_memory[]; -+ using SequenceBarrier = typename cutlass::OrderedSequenceBarrier; -+ using SmemStorage = SharedStorage; -+ -+ SmemStorage& shared_storage = *reinterpret_cast(shared_memory); -+ -+ int group_idx = threadIdx.x / ThreadsPerGroup; -+ -+ typename SequenceBarrier::Params params; -+ params.group_id = group_idx; // sequence ID -+ params.group_size = ThreadsPerGroup; // Number of threads / participants in a group -+ -+ SequenceBarrier barrier(shared_storage.storage, params); -+ -+ // Ensure All CTAs in Cluster have completed init before issuing commits -+ __syncthreads(); -+ cute::cluster_arrive_relaxed(); -+ cute::cluster_wait(); -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int i = 0; i < num_iterations; ++i){ -+ -+ barrier.wait(); -+ // STAGE 1 CODE... -+ #ifndef NDEBUG -+ int thread_idx_in_group = threadIdx.x % ThreadsPerGroup; -+ if (thread_idx_in_group == 0) { -+ printf("STAGE 0 : Group_IDX : %d, id = %d, iter = %d, tidx = %d\n", group_idx, params.id, i, threadIdx.x); -+ } -+ #endif -+ // Simulates long running stage -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) -+ __nanosleep(100000); -+ #endif -+ barrier.arrive(); -+ -+ barrier.wait(); -+ // STAGE 2 CODE... 
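-+ // Descriptive note (added for clarity): group_idx doubles as the position in
-+ // the ordered sequence, so each barrier.wait() is expected to release the
-+ // groups one at a time in ascending group_id order, and barrier.arrive()
-+ // passes the turn on to the next group. The two wait()/arrive() pairs per
-+ // iteration correspond to the two sequenced stages tracked by the barrier,
-+ // which is what lets this kernel complete deadlock-free.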
-+ #ifndef NDEBUG -+ if (thread_idx_in_group == 0) { -+ printf("STAGE 1 : Group_IDX : %d, id = %d, iter = %d, tidx = %d\n", group_idx, params.id, i, threadIdx.x); -+ } -+ #endif -+ // Simulates long running stage -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) -+ __nanosleep(100000); -+ #endif -+ barrier.arrive(); -+ } -+ -+ // To make sure remote SMEM doesn't get destroyed -+ cute::cluster_arrive(); -+ cute::cluster_wait(); -+} -+///////////////////////////////////////////////////// -+ -+template -+struct PipelineTest { -+ -+ // -+ // Data members -+ // -+ static constexpr uint32_t ThreadsPerGroup = 128; -+ static constexpr uint32_t BlockSize = GroupCount_ * ThreadsPerGroup; -+ static constexpr uint32_t Stages = Stages_; -+ static constexpr uint32_t GroupCount = GroupCount_; -+ using SequenceBarrier = typename cutlass::OrderedSequenceBarrier; -+ using SmemStorage = SharedStorage; -+ -+ // -+ // Methods -+ // -+ -+ // Run CuTe GEMM kernel -+ cudaError_t run(uint32_t const kNumIters, -+ cudaStream_t stream = nullptr) { -+ -+ // Pipeline (multistage pipeline) -+ auto cluster_shape = Shape<_1, _1, _1>{}; -+ -+ // -+ // Configure and launch -+ // -+ int iterations = 1; -+ cudaError_t result; -+ -+ for (int iter = 0; iter < iterations; ++iter) { -+ -+ int smem_size = int(sizeof(SmemStorage)); -+ -+ result = cudaFuncSetAttribute( -+ ordered_sequence_device, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ // Launch a single Cluster, with 128 thread per CTA -+ dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); -+ dim3 dimGrid(size<0>(cluster_shape), size<1>(cluster_shape), 1); -+ dim3 dimBlock(BlockSize,1,1); -+ -+ const void* kernel = (const void*)ordered_sequence_device; -+ int iters = kNumIters; -+ void* kernel_params[] = {reinterpret_cast(&iters)}; -+ cutlass::ClusterLauncher::launch(dimGrid, dimCluster, dimBlock, smem_size, stream, kernel, kernel_params); -+ -+ } // profiling loop ends -+ -+ result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Error: cudaDeviceSynchronize() failed" << std::endl; -+ return result; -+ } -+ -+ return cudaSuccess; -+ } -+}; -+ -+#if CUDA_12_0_SM90_FEATURES_SUPPORTED -+TEST(SM90_Verify_OrderedSequence, Depth_2_Length_2) { -+ Options options; -+ static constexpr uint32_t GroupCount = 2; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_OrderedSequence, Depth_2_Length_3) { -+ Options options; -+ static constexpr uint32_t GroupCount = 3; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_OrderedSequence, Depth_2_Length_4) { -+ Options options; -+ static constexpr uint32_t GroupCount = 4; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_OrderedSequence, Depth_2_Length_5) { -+ Options options; -+ static constexpr uint32_t GroupCount = 5; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+#endif -diff --git a/3rdparty/cutlass/test/unit/pipeline/testbed.h b/3rdparty/cutlass/test/unit/pipeline/testbed.h -new file mode 100644 -index 0000000..b809e74 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/pipeline/testbed.h -@@ -0,0 +1,145 @@ 
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Common Testbed file shared by Pipeline unit tests -+*/ -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/util/command_line.h" -+#include "../common/cutlass_unit_test.h" -+ -+#if CUDA_12_0_SM90_FEATURES_SUPPORTED -+ #define CUTLASS_UNIT_TEST_PIPELINE true -+#else -+ #define CUTLASS_UNIT_TEST_PIPELINE false -+#endif -+ -+// Command line test options -+struct Options { -+ // -+ // Data Members -+ // -+ bool help; -+ bool verification_enabled; -+ int SM_count; -+ int clock_MHz; -+ -+ // -+ // Methods -+ // -+ Options(): -+ help(false), -+ verification_enabled(true), -+ SM_count(116), -+ clock_MHz(1477) -+ { } -+ -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("verification-enabled", verification_enabled, true); -+ cmd.get_cmd_line_argument("sm-count", SM_count, 116); -+ cmd.get_cmd_line_argument("clock", clock_MHz, 1477); -+ } -+ -+ /// Prints the usage statement. 
-+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --verification-enabled= Enable/Disable verification\n" -+ << " --sm-count= Number of SMs on the chip\n" -+ << " --clock= Locked clock value in Mhz\n"; -+ -+ return out; -+ } -+}; -+ -+// -+// Testbed -+// -+ -+template -+struct Testbed { -+private: -+ // Commandline options -+ Options options; -+ -+ void run_test(uint32_t const kNumIters) { -+ -+ // Run CuTe Gemm -+ Pipeline pipeline; -+ -+ cudaError_t result = pipeline.run(kNumIters); -+ -+ CUTE_CHECK_LAST(); -+ } -+ -+ -+public: -+ Testbed(Options const &options_) : options(options_) { -+ int device_id = 0; -+ cudaDeviceProp device_prop; -+ CUTE_CHECK_ERROR(cudaSetDevice(device_id)); -+ CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id)); -+ -+ if (device_prop.major < 1) { -+ fprintf(stderr, "Device does not support CUDA.\n"); -+ exit(1); -+ } -+ } -+ -+ /// Run verification Gemm problem sizes -+ bool verification() { -+ -+ std::array kNumIters; -+ -+ for (int i = 0; i < kNumIters.size(); ++i) { -+ kNumIters[i] = (rand() % 1000) + 1; -+ } -+ -+ for (int n : kNumIters) { -+ std::cout << "Stages = " << Pipeline::Stages << " kNumIters = " << n << "\n"; -+ run_test(n); -+ } -+ -+ return true; -+ } -+}; -diff --git a/3rdparty/cutlass/test/unit/reduction/device/tensor_reduce_contiguous.cu b/3rdparty/cutlass/test/unit/reduction/device/tensor_reduce_contiguous.cu -new file mode 100644 -index 0000000..c582eb5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/reduction/device/tensor_reduce_contiguous.cu -@@ -0,0 +1,476 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! 
\file -+ \brief Tests for TensorReduce family of device-wide operators -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+#include "cutlass/reduction/device/tensor_reduce.h" -+ -+#include "cutlass/functional.h" -+#include "cutlass/layout/tensor.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This reduces the C dimension, transforming an NHWC tensor into NHWC with C=1. -+template -+bool TestAllReduction_NHWC_reduce_c(ElementCompute reduction_identity = ElementCompute()) { -+ -+ using Layout = typename TensorReduction::Layout; -+ using ElementOutput = typename TensorReduction::ElementOutput; -+ using ElementSource = typename TensorReduction::ElementSource; -+ -+ int const kV = TensorReduction::kVectorLength; -+ -+ int const N_indices[] = {3, 13}; -+ int const H_indices[] = {5, 17}; -+ int const W_indices[] = {7, 19}; -+ int const C_indices[] = {2049, 2048, 2047, 384, 64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 1}; -+ -+ for (int N : N_indices) { -+ for (int H : H_indices) { -+ for (int W : W_indices) { -+ for (int Cx : C_indices) { -+ -+ int C = Cx * kV; -+ -+ cutlass::HostTensor src_tensor({N, H, W, C}); -+ cutlass::HostTensor dst_tensor({N, H, W, 1}); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ src_tensor.host_view(), 17, 10, -10, 0); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ // Execute a tensor reduction over rank 3 (the 'C' dimension is reduced; NHWC => NHW) -+ TensorReduction reduction(src_tensor.extent(), 3); -+ -+ cutlass::DeviceAllocation device_workspace(reduction.workspace_size()); -+ -+ cutlass::Status status = reduction.reduce( -+ dst_tensor.device_ref(), -+ src_tensor.device_ref(), -+ device_workspace.get(), -+ reduction_identity -+ ); -+ -+ EXPECT_EQ(status, cutlass::Status::kSuccess); -+ EXPECT_EQ(cudaDeviceSynchronize(), cudaSuccess); -+ -+ dst_tensor.sync_host(); -+ -+ typename TensorReduction::ReductionOp reduction_op; -+ -+ // -+ // Reference check -+ // -+ for (int n = 0; n < src_tensor.extent().n(); ++n) { -+ for (int h = 0; h < src_tensor.extent().h(); ++h) { -+ for (int w = 0; w < src_tensor.extent().w(); ++w) { -+ -+ ElementCompute c_accum = reduction_identity; -+ -+ for (int c = 0; c < src_tensor.extent().c(); ++c) { -+ c_accum = reduction_op(c_accum, ElementCompute(src_tensor.at({n, h, w, c}))); -+ } -+ -+ ElementCompute got = ElementCompute(dst_tensor.at({n, h, w, 0})); -+ -+ bool equal = (c_accum == got); -+ -+ EXPECT_TRUE(equal); -+ if (!equal) { -+ -+ std::cerr -+ << "Error at location (" << n << ", " << h << ", " << w << ", 0)" << std::endl; -+ -+ std::cerr -+ << " expected: " << c_accum << std::endl -+ << " got: " << got << std::endl; -+ -+ std::cerr -+ << "Problem: " << src_tensor.extent() << " -> " -+ << dst_tensor.extent() << std::endl; -+ -+ std::cerr -+ << " Grid: " << reduction.reduction_strided.grid_shape -+ << "\n Block: " << reduction.reduction_strided.threadblock_shape << std::endl -+ << " FInal: " << 
reduction.reduction_strided.grid_final -+ << "\n Block: " << reduction.reduction_strided.threadblock_final << "\n"; -+ -+ return false; -+ } -+ -+ } //w -+ } // h -+ } // n -+ -+ // -+ // Next problem -+ // -+ -+ } // C -+ } // W -+ } // H -+ } // N -+ -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_reduce_c_f32x1) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = float; -+ using ElementSource = float; -+ using ElementCompute = float; -+ int const kV = 1; -+ -+ // Define the functor -+ using Functor = cutlass::plus; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_reduce_c_f32x1_f16x1) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = float; -+ using ElementSource = cutlass::half_t; -+ using ElementCompute = float; -+ int const kV = 1; -+ -+ // Define the functor -+ using Functor = cutlass::plus; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_reduce_c_f32x2) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = float; -+ using ElementSource = float; -+ using ElementCompute = float; -+ int const kV = 2; -+ -+ // Define the functor -+ using Functor = cutlass::plus; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_reduce_c_f32x2_f16x2) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = float; -+ using ElementSource = cutlass::half_t; -+ using ElementCompute = float; -+ int const kV = 2; -+ -+ // Define the functor -+ using Functor = cutlass::plus; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_reduce_c_f32x4) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = float; -+ using ElementSource = float; -+ using ElementCompute = float; -+ int const kV = 4; -+ -+ // Define the functor -+ using Functor = cutlass::plus; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ 
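-+ // Note (added): kV becomes TensorReduction::kVectorLength, and the harness
-+ // above only generates channel extents C = Cx * kV, so with kV = 4 the
-+ // reduced C dimension is always a whole number of vectors.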
EXPECT_TRUE(TestAllReduction_NHWC_reduce_c()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_reduce_c_f32x4_f16x4) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = float; -+ using ElementSource = cutlass::half_t; -+ using ElementCompute = float; -+ int const kV = 4; -+ -+ // Define the functor -+ using Functor = cutlass::plus; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_maximum_c_f32x4) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = float; -+ using ElementSource = float; -+ using ElementCompute = float; -+ int const kV = 4; -+ -+ // Define the functor -+ using Functor = cutlass::maximum; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c( -std::numeric_limits::max() )); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_minimum_c_f32x4) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = float; -+ using ElementSource = float; -+ using ElementCompute = float; -+ int const kV = 4; -+ -+ // Define the functor -+ using Functor = cutlass::minimum; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c( std::numeric_limits::max() )); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_ANY_c_s32) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = int; -+ using ElementSource = int; -+ using ElementCompute = int; -+ int const kV = 1; -+ -+ // Define the functor -+ using Functor = cutlass::logical_or; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c( ElementCompute(0) )); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_ALL_c_s32) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = int; -+ using ElementSource = int; -+ using ElementCompute = int; -+ int const kV = 1; -+ -+ // Define the functor -+ using Functor = cutlass::logical_and; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c( ElementCompute(1) )); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ 
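Aside (not part of the patch): the min/max and ANY/ALL tests in this file differ only in the functor and in the identity element handed to reduction.reduce(). The sketch below, with hypothetical names, restates the host-side reference fold those tests rely on.

    #include <vector>

    // Fold one channel row the way the reference loop above does: start from the
    // identity passed to reduction.reduce() and apply the functor element-wise.
    template <typename T, typename Op>
    T reduce_channels(std::vector<T> const &row, Op op, T identity) {
      T accum = identity;
      for (T const &x : row) {
        accum = op(accum, x);
      }
      return accum;
    }

    // Identities used by the tests in this file:
    //   plus        ->  T(0)
    //   maximum     -> -std::numeric_limits<T>::max()
    //   minimum     ->  std::numeric_limits<T>::max()
    //   logical_or  ->  T(0)   ("ANY": nonzero if any element is nonzero)
    //   logical_and ->  T(1)   ("ALL": nonzero only if every element is nonzero)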
-+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_ANY_c_f32) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = float; -+ using ElementSource = float; -+ using ElementCompute = float; -+ int const kV = 1; -+ -+ // Define the functor -+ using Functor = cutlass::logical_or; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c( ElementCompute(0) )); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_ALL_c_f32) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = float; -+ using ElementSource = float; -+ using ElementCompute = float; -+ int const kV = 1; -+ -+ // Define the functor -+ using Functor = cutlass::logical_and; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c( ElementCompute(1) )); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/reduction/device/tensor_reduce_strided.cu b/3rdparty/cutlass/test/unit/reduction/device/tensor_reduce_strided.cu -new file mode 100644 -index 0000000..7e9ccc3 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/reduction/device/tensor_reduce_strided.cu -@@ -0,0 +1,523 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for TensorReduce family of device-wide operators -+*/ -+ -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+#include "cutlass/reduction/device/tensor_reduce.h" -+ -+#include "cutlass/functional.h" -+#include "cutlass/layout/tensor.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This reduces the W dimension, transforming an NHWC tensor into NHWC with W=1. -+template < -+ typename TensorReduction, -+ typename ElementCompute = typename TensorReduction::ElementCompute -+> -+bool TestAllReduction_NHWC_reduce_w(ElementCompute reduction_identity = ElementCompute()) { -+ -+ using Layout = typename TensorReduction::Layout; -+ using ElementOutput = typename TensorReduction::ElementOutput; -+ using ElementSource = typename TensorReduction::ElementSource; -+ -+ int const kV = TensorReduction::kVectorLength; -+ -+ int const N_indices[] = {1, 2, 5, 10}; -+ int const H_indices[] = {1, 3, 9 }; -+ int const W_indices[] = {1, 5, 19, 40, 224}; -+ int const C_indices[] = { -+ kV, -+ 2 * kV, -+ 5 * kV, -+ 9 * kV, -+ 17 * kV, -+ 39 * kV, -+ 257 * kV, -+ kV * 760 -+ }; -+ -+ using Element = int; -+ -+ for (int N : N_indices) { -+ for (int H : H_indices) { -+ for (int W : W_indices) { -+ for (int C : C_indices) { -+ -+ cutlass::HostTensor src_tensor({N, H, W, C}); -+ cutlass::HostTensor dst_tensor({N, H, 1, C}); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ src_tensor.host_view(), 17, 10, -10, 0); -+ -+ cutlass::reference::host::BlockFillSequential( -+ dst_tensor.host_data(), dst_tensor.capacity()); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ // Execute a tensor reduction over rank 2 (the 'W' dimension is reduced; NHWC => NHC) -+ TensorReduction reduction(src_tensor.extent(), 2); -+ -+ cutlass::DeviceAllocation device_workspace(reduction.workspace_size()); -+ -+ cutlass::Status status = reduction.reduce( -+ dst_tensor.device_ref(), -+ src_tensor.device_ref(), -+ device_workspace.get(), -+ reduction_identity -+ ); -+ -+ EXPECT_EQ(status, cutlass::Status::kSuccess); -+ EXPECT_EQ(cudaDeviceSynchronize(), cudaSuccess); -+ // Reference check -+ dst_tensor.sync_host(); -+ -+ typename TensorReduction::ReductionOp reduction_op; -+ -+ for (int n = 0; n < src_tensor.extent().n(); ++n) { -+ for (int h = 0; h < src_tensor.extent().h(); ++h) { -+ for (int c = 0; c < src_tensor.extent().c(); ++c) { -+ -+ ElementCompute w_accum = reduction_identity; -+ -+ for (int w = 0; w < src_tensor.extent().w(); ++w) { -+ w_accum = reduction_op(w_accum, ElementCompute(src_tensor.at({n, h, w, c}))); -+ } -+ -+ ElementCompute got = ElementCompute(dst_tensor.at({n, h, 0, c})); -+ -+ bool equal = (w_accum == got); -+ -+ EXPECT_TRUE(equal); -+ if (!equal) { -+ -+ std::cerr -+ << "Error at location (" << n << ", " << h << ", 0, " << c << ")" << std::endl; -+ -+ 
std::cerr -+ << " expected: " << w_accum << std::endl -+ << " got: " << got << std::endl; -+ -+ std::cerr -+ << "Problem: " << src_tensor.extent() << " -> " -+ << dst_tensor.extent() << std::endl; -+ -+ std::cerr -+ << " Grid: " << reduction.reduction_strided.grid_shape -+ << "\n Block: " << reduction.reduction_strided.threadblock_shape << std::endl -+ << " Final: " << reduction.reduction_strided.grid_final -+ << "\n Block: " << reduction.reduction_strided.threadblock_final << "\n"; -+ -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_reduce_w_f32x8_f16x8) { -+ -+ int const kV = 8; -+ using ElementOutput = float; -+ using ElementSource = cutlass::half_t; -+ using ElementCompute = float; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::plus; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w()); -+} -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_reduce_w_f32x2_f16x2) { -+ -+ int const kV = 2; -+ using ElementOutput = float; -+ using ElementSource = cutlass::half_t; -+ using ElementCompute = float; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::plus; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w()); -+} -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_reduce_w_f32x1_f16x1) { -+ -+ int const kV = 1; -+ using ElementOutput = float; -+ using ElementSource = cutlass::half_t; -+ using ElementCompute = float; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::plus; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w()); -+} -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_reduce_w_s32x4) { -+ -+ int const kV = 4; -+ using Element = int; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::plus; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ Element, -+ Element, -+ Layout, -+ Functor, -+ kV, -+ Element -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w()); -+} -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_reduce_w_cf32) { -+ -+ int const kV = 1; -+ using ElementOutput = cutlass::complex; -+ using ElementSource = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::plus; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test 
tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_maximum_w_cf32) { -+ -+ int const kV = 1; -+ using ElementOutput = float; -+ using ElementSource = float; -+ using ElementCompute = float; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::maximum; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w( -std::numeric_limits::max() )); -+} -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_minimum_w_cf32) { -+ -+ int const kV = 1; -+ using ElementOutput = float; -+ using ElementSource = float; -+ using ElementCompute = float; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::minimum; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w(std::numeric_limits::max())); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_XOR_w_u32) { -+ -+ int const kV = 1; -+ using ElementOutput = int; -+ using ElementSource = int; -+ using ElementCompute = int; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::bit_xor; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_AND_w_s32) { -+ -+ int const kV = 1; -+ using ElementOutput = unsigned; -+ using ElementSource = unsigned; -+ using ElementCompute = unsigned; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::bit_and; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w(0xffffffff)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_OR_w_u32) { -+ -+ int const kV = 1; -+ using ElementOutput = int; -+ using ElementSource = int; -+ using ElementCompute = int; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::bit_or; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_ANY_w_s32) { -+ -+ int const kV = 1; -+ using ElementOutput = int; -+ using ElementSource = int; -+ using ElementCompute = int; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::logical_or; 
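-+ // Note (added): logical_or with the identity ElementCompute(0) passed below
-+ // reduces each W row to a nonzero value exactly when any element is nonzero
-+ // ("ANY"); the nhwc_ALL_w_s32 test further down pairs logical_and with
-+ // identity 1 for the complementary "ALL" reduction.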
-+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w(ElementCompute(0))); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_ALL_w_s32) { -+ -+ int const kV = 1; -+ using ElementOutput = int; -+ using ElementSource = int; -+ using ElementCompute = int; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::logical_and; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w(ElementCompute(1))); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_ANY_w_f32) { -+ -+ int const kV = 1; -+ using ElementOutput = float; -+ using ElementSource = float; -+ using ElementCompute = float; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::logical_or; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w(ElementCompute(0))); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_ALL_w_f32) { -+ -+ int const kV = 1; -+ using ElementOutput = float; -+ using ElementSource = float; -+ using ElementCompute = float; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::logical_and; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w(ElementCompute(1))); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/reduction/kernel/reduce_splitk.cu b/3rdparty/cutlass/test/unit/reduction/kernel/reduce_splitk.cu -new file mode 100644 -index 0000000..6a990f4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/reduction/kernel/reduce_splitk.cu -@@ -0,0 +1,389 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/reduction/kernel/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace reduction { -+ -+template -+__global__ void kernel_reduce_splitk(typename ReductionKernel::Params params) { -+ -+ __shared__ typename ReductionKernel::SharedStorage shared_storage; -+ -+ ReductionKernel reduction_op; -+ -+ reduction_op(params, shared_storage); -+} -+ -+template -+class ReduceSplitKTestbed { -+public: -+ -+ using ElementAccumulator = typename ReductionKernel::ElementAccumulator; -+ using ElementWorkspace = typename ReductionKernel::ElementWorkspace; -+ using ElementOutput = typename ReductionKernel::ElementOutput; -+ using Layout = cutlass::layout::RowMajor; -+ -+public: -+ -+ cutlass::Distribution::Kind distribution_workspace; -+ cutlass::Distribution::Kind distribution_source; -+ uint64_t seed; -+ -+public: -+ -+ /// Ctor -+ ReduceSplitKTestbed( -+ cutlass::Distribution::Kind distribution_workspace = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind distribution_source = cutlass::Distribution::Uniform, -+ uint64_t seed = 2019 -+ ): -+ distribution_workspace(distribution_workspace), -+ distribution_source(distribution_source), -+ seed(seed) { -+ -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor(cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ cutlass::reference::host::TensorFillRandomUniform(view, seed, 8, -8, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, -1); -+ } else if (dist_kind == cutlass::Distribution::Identity) { -+ 
cutlass::reference::host::TensorFillIdentity(view); -+ } else if (dist_kind == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(view.data(), -+ view.capacity()); -+ } else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Runs a single problem size -+ bool run( -+ cutlass::MatrixCoord problem_size, -+ int partitions, -+ ElementAccumulator alpha = 1, -+ ElementAccumulator beta = 0) { -+ -+ cutlass::HostTensor workspace({ -+ problem_size.row() * partitions, -+ problem_size.column() -+ }); -+ -+ cutlass::HostTensor source(problem_size); -+ cutlass::HostTensor destination(problem_size); -+ cutlass::HostTensor destination_reference(problem_size, false); -+ -+ // -+ // Initialize -+ // -+ initialize_tensor(workspace.host_view(), distribution_workspace, seed); -+ initialize_tensor(source.host_view(), distribution_source, seed + 23); -+ -+ cutlass::reference::host::TensorFill(destination.host_view()); -+ -+ workspace.sync_device(); -+ source.sync_device(); -+ destination.sync_device(); -+ -+ // -+ // Launch reduction kernel -+ // -+ -+ dim3 block = ReductionKernel::block_shape(); -+ dim3 grid = ReductionKernel::grid_shape(problem_size); -+ -+ typename ReductionKernel::Params params( -+ problem_size, -+ partitions, -+ problem_size.row() * problem_size.column(), -+ workspace.device_ref(), -+ destination.device_ref(), -+ source.device_ref(), -+ {alpha, beta} -+ ); -+ -+ test::reduction::kernel_reduce_splitk<<< grid, block >>>(params); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ EXPECT_EQ(result, cudaSuccess) -+ << "CUDA error: " << cudaGetErrorString(result); -+ -+ destination.sync_host(); -+ -+ // -+ // Compute reference -+ // -+ -+ for (int m = 0; m < problem_size.row(); ++m) { -+ for (int n = 0; n < problem_size.column(); ++n) { -+ -+ ElementAccumulator accum = 0; -+ -+ for (int k = 0; k < partitions; ++k) { -+ accum += ElementAccumulator(workspace.at({m + k * problem_size.row(), n})); -+ } -+ -+ ElementAccumulator c = ElementAccumulator(source.at({m, n})); -+ -+ destination_reference.at({m, n}) = ElementOutput(accum * alpha + beta * c); -+ } -+ } -+ -+ // -+ // Compare -+ // -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(destination.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(destination_reference.host_view()), 0); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ destination.host_view(), destination_reference.host_view()); -+ -+ EXPECT_TRUE(passed) -+ << "Workspace =\n" << workspace.host_view() << "\n\n" -+ << "\n" -+ << "Reference =\n" << destination_reference.host_view() << "\n\n" -+ << "Computed =\n" << destination.host_view() << "\n"; -+ -+ return passed; -+ } -+ -+ /// Runs through a variety of test cases -+ bool run_all() { -+ -+ cutlass::MatrixCoord problem_sizes[] = { -+ {8, 8}, -+ {136, 72}, -+ {248, 232}, -+ }; -+ -+ int partition_counts[] = { -+ 1,3,4,5,11 -+ }; -+ -+ bool passed = false; -+ -+ for (cutlass::MatrixCoord problem : problem_sizes) { -+ for (int partitions : partition_counts) { -+ passed = run(problem, partitions); -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ -+ return passed; -+ } -+}; -+ -+} // namespace reduction -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Strictly F32 data -+// -+TEST(Reduction_ReduceSplitK, f32_f32_f32_1_1x32) { -+ -+ using ElementWorkspace = float; -+ using ElementAccumulator = 
float; -+ using ElementOutput = float; -+ int const kN = 1; -+ using Shape = cutlass::MatrixShape<1, 32>; -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kN, -+ ElementAccumulator, -+ ElementAccumulator -+ >; -+ -+ using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ ElementWorkspace, -+ kN -+ >; -+ -+ using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< -+ Shape, -+ OutputOp, -+ ReductionOp -+ >; -+ -+ test::reduction::ReduceSplitKTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Vectorized access -+// -+TEST(Reduction_ReduceSplitK, f32_f32_f32_2_4x64) { -+ -+ using ElementWorkspace = float; -+ using ElementAccumulator = float; -+ using ElementOutput = float; -+ int const kN = 2; -+ using Shape = cutlass::MatrixShape<4, 64>; -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kN, -+ ElementAccumulator, -+ ElementAccumulator -+ >; -+ -+ using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ ElementWorkspace, -+ kN -+ >; -+ -+ using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< -+ Shape, -+ OutputOp, -+ ReductionOp -+ >; -+ -+ test::reduction::ReduceSplitKTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Vectorized access -+// -+TEST(Reduction_ReduceSplitK, f32_f32_f16_2_4x64) { -+ -+ using ElementWorkspace = float; -+ using ElementAccumulator = float; -+ using ElementOutput = cutlass::half_t; -+ int const kN = 2; -+ using Shape = cutlass::MatrixShape<4, 64>; -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kN, -+ ElementAccumulator, -+ ElementAccumulator -+ >; -+ -+ using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ ElementWorkspace, -+ kN -+ >; -+ -+ using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< -+ Shape, -+ OutputOp, -+ ReductionOp -+ >; -+ -+ test::reduction::ReduceSplitKTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Vectorized access -+// -+TEST(Reduction_ReduceSplitK, f32_f32_f16_8_4x64) { -+ -+ using ElementWorkspace = float; -+ using ElementAccumulator = float; -+ using ElementOutput = cutlass::half_t; -+ int const kN = 8; -+ using Shape = cutlass::MatrixShape<4, 64>; -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kN, -+ ElementAccumulator, -+ ElementAccumulator -+ >; -+ -+ using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ ElementWorkspace, -+ kN -+ >; -+ -+ using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< -+ Shape, -+ OutputOp, -+ ReductionOp -+ >; -+ -+ test::reduction::ReduceSplitKTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/reduction/kernel/reduce_splitk_testbed.h b/3rdparty/cutlass/test/unit/reduction/kernel/reduce_splitk_testbed.h -new file mode 100644 -index 0000000..78c720a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/reduction/kernel/reduce_splitk_testbed.h -@@ -0,0 +1,45 @@ 
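Aside (not part of the patch): the split-K reduction exercised above folds the k partition partials that ReduceSplitKTestbed::run() stacks row-wise in the workspace tensor and then applies the linear-combination epilogue. A minimal host-side restatement, with hypothetical names, is:

    // D(m, n) = alpha * sum_k workspace(m + k * rows, n) + beta * C(m, n),
    // mirroring the reference loop in ReduceSplitKTestbed::run() above.
    float reduce_splitk_reference(float const *workspace, float const *source,
                                  int rows, int cols, int partitions,
                                  int m, int n, float alpha, float beta) {
      float accum = 0.0f;
      for (int k = 0; k < partitions; ++k) {
        accum += workspace[(m + k * rows) * cols + n];  // partition k, row-major
      }
      return alpha * accum + beta * source[m * cols + n];
    }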
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level Reduction -+*/ -+ -+#pragma once -+ -+#include "cutlass/reduction/thread/reduce.h" -+ -+#include "cutlass/layout/vector.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -diff --git a/3rdparty/cutlass/test/unit/reduction/thread/reduction_thread.cu b/3rdparty/cutlass/test/unit/reduction/thread/reduction_thread.cu -new file mode 100644 -index 0000000..be92fea ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/reduction/thread/reduction_thread.cu -@@ -0,0 +1,100 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level Reduction -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+ -+TEST(Reduce_thread_device, Reduce_half_t_1) { -+ -+ test::reduction::thread::Testbed_reduce_device< -+ cutlass::half_t, -+ 1 -+ >().run(); -+} -+ -+TEST(Reduce_thread_device, Reduce_half_t_16) { -+ -+ test::reduction::thread::Testbed_reduce_device< -+ cutlass::half_t, -+ 16 -+ >().run(); -+} -+ -+TEST(Reduce_thread_device, Reduce_half_t_31) { -+ -+ test::reduction::thread::Testbed_reduce_device< -+ cutlass::half_t, -+ 31 -+ >().run(); -+} -+ -+ -+TEST(Reduce_thread_host, Reduce_float_1) { -+ -+ test::reduction::thread::Testbed_reduce_host< -+ float, -+ 1 -+ >().run(); -+} -+ -+TEST(Reduce_thread_host, Reduce_float_16) { -+ -+ test::reduction::thread::Testbed_reduce_host< -+ float, -+ 16 -+ >().run(); -+ -+} -+ -+TEST(Reduce_thread_host, Reduce_half_t_1) { -+ -+ test::reduction::thread::Testbed_reduce_host< -+ cutlass::half_t, -+ 1 -+ >().run(); -+} -+ -+TEST(Reduce_thread_host, Reduce_half_t_16) { -+ -+ test::reduction::thread::Testbed_reduce_host< -+ cutlass::half_t, -+ 16 -+ >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/reduction/thread/testbed.h b/3rdparty/cutlass/test/unit/reduction/thread/testbed.h -new file mode 100644 -index 0000000..e0e38ed ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/reduction/thread/testbed.h -@@ -0,0 +1,242 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level Reduction -+*/ -+ -+#pragma once -+ -+#include "cutlass/reduction/thread/reduce.h" -+ -+#include "cutlass/layout/vector.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+ -+namespace test { -+namespace reduction { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the reduction -+template < -+ /// Data type of elements -+ typename Element, -+ /// Number of elements -+ int N -+> -+struct Testbed_reduce_host { -+ -+ /// Thread-level reduction operator -+ using Reduce = cutlass::reduction::thread::Reduce< -+ cutlass::plus, -+ cutlass::Array -+ >; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::Array tensor_in; -+ cutlass::Array reduced_tensor_computed; -+ cutlass::Array reduced_tensor_reference; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ Testbed_reduce_host() { -+ tensor_in.clear(); -+ reduced_tensor_computed.clear(); -+ reduced_tensor_reference.clear(); -+ } -+ -+ /// Runs the test -+ bool run() { -+ -+ // -+ // initialize memory -+ // -+ -+ for(int i = 0; i < N; i++) -+ tensor_in.at(i) = Element(i); -+ -+ -+ Reduce reduce; -+ -+ cutlass::Array *out_ptr = &reduced_tensor_computed; -+ out_ptr[0] = reduce(tensor_in); -+ -+ // -+ // Reference implementation -+ // -+ Element e(0); -+ for (int i = 0; i < N; i++) -+ e = e + Element(i); -+ -+ reduced_tensor_reference.at(0) = e; -+ -+ // -+ // Verify equivalence -+ // -+ -+ // compare -+ bool passed = reduced_tensor_reference[0] == reduced_tensor_computed[0]; -+ -+ EXPECT_TRUE(passed) -+ << "Expected = " << float(reduced_tensor_reference.at(0)) << "\n\n" -+ << "Actual = " << float(reduced_tensor_computed.at(0)) << "\n\n" -+ << std::endl; -+ -+ return passed; -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Thread-level reduction kernel -+template -+__global__ void kernel_reduce(Element const *array_in, Element *result) { -+ -+ /// Thread-level reduction operator -+ using Reduce = cutlass::reduction::thread::Reduce< -+ cutlass::plus, -+ cutlass::Array -+ >; -+ -+ Reduce reduce; -+ -+ auto 
ptr_in = reinterpret_cast const *>(array_in); -+ auto result_ptr = reinterpret_cast *>(result); -+ auto in = *ptr_in; -+ result_ptr[0] = reduce(in); -+} -+ -+ -+/// Structure to compute the reduction -+template < -+ /// Data type of elements -+ typename Element, -+ /// Number of elements -+ int N -+> -+struct Testbed_reduce_device { -+ -+ using Layout = cutlass::layout::PackedVectorLayout; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor tensor_in; -+ cutlass::HostTensor reduced_tensor_computed; -+ cutlass::HostTensor reduced_tensor_reference; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ Testbed_reduce_device() { -+ -+ tensor_in.reset(cutlass::make_Coord(N), true); -+ reduced_tensor_computed.reset(cutlass::make_Coord(1), true); -+ reduced_tensor_reference.reset(cutlass::make_Coord(1), true); -+ } -+ -+ -+ /// Runs the test -+ bool run() { -+ -+ // -+ // initialize memory -+ // -+ -+ cutlass::reference::host::TensorFill( -+ tensor_in.host_view(), -+ Element(1) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ reduced_tensor_computed.host_view(), -+ Element(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ reduced_tensor_reference.host_view(), -+ Element(N) -+ ); -+ -+ tensor_in.sync_device(); -+ reduced_tensor_computed.sync_device(); -+ reduced_tensor_reference.sync_device(); -+ -+ /// call the kernel -+ kernel_reduce<<< dim3(1, 1), dim3(1, 1, 1) >>> ( -+ tensor_in.device_data(), -+ reduced_tensor_computed.device_data() -+ ); -+ -+ // verify no errors -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); -+ if (result != cudaSuccess) { -+ return false; -+ } -+ -+ // Copy back results -+ reduced_tensor_computed.sync_host(); -+ -+ // Verify equivalence -+ bool passed = cutlass::reference::host::TensorEquals( -+ reduced_tensor_computed.host_view(), -+ reduced_tensor_reference.host_view() -+ ); -+ -+ EXPECT_TRUE(passed) -+ << "Expected = " << reduced_tensor_reference.host_view() << "\n\n" -+ << "Actual = " << reduced_tensor_computed.host_view() << "\n\n" -+ << std::endl; -+ -+ return passed; -+ } -+}; -+ -+} // namespace thread -+} // namespace reduction -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/transform/threadblock/predicated_tile_iterator.cu b/3rdparty/cutlass/test/unit/transform/threadblock/predicated_tile_iterator.cu -new file mode 100644 -index 0000000..e30986b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/transform/threadblock/predicated_tile_iterator.cu -@@ -0,0 +1,798 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests cutlass::transform::threadblock::PredicatedTileIterator -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -+ -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace transform { -+namespace threadblock { -+namespace kernel { -+ -+/// Copy with an iterator -+template -+__global__ void copy( -+ typename Iterator::Params dst_params, -+ typename Iterator::Element *dst_pointer, -+ typename Iterator::Params src_params, -+ typename Iterator::Element *src_pointer, -+ cutlass::Coord<2> extent) { -+ -+ Iterator dst_iterator(dst_params, dst_pointer, extent, threadIdx.x); -+ Iterator src_iterator(src_params, src_pointer, extent, threadIdx.x); -+ -+ int iterations = (extent[1] + Iterator::Shape::kStrided - 1) / Iterator::Shape::kStrided; -+ -+ typename Iterator::Fragment frag; -+ -+ for(int i = 0; i < frag.size(); i++) -+ frag[i] = 0; -+ -+ src_iterator.load(frag); -+ dst_iterator.store(frag); -+ -+ ++dst_iterator; -+ ++src_iterator; -+ -+ for (; iterations > 1; --iterations) { -+ -+ src_iterator.load(frag); -+ dst_iterator.store(frag); -+ -+ ++dst_iterator; -+ ++src_iterator; -+ } -+} -+ -+} // namespace kernel -+} // namespace threadblock -+} // namespace transform -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Transform_threadblock_PredicatedTileIterator, PitchLinear_Stripmined) { -+ -+ using Shape = cutlass::layout::PitchLinearShape<64, 4>; -+ using Layout = cutlass::layout::PitchLinear; -+ using Element = int; -+ static int const kThreads = 32; -+ -+ using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap; -+ -+ using Iterator = cutlass::transform::threadblock::PredicatedTileIterator< -+ Shape, Element, Layout, 1, ThreadMap -+ >; -+ -+ cutlass::Coord<2> copy_extent = cutlass::make_Coord(57, 35); -+ cutlass::Coord<2> alloc_extent = cutlass::make_Coord(64, 35); -+ -+ cutlass::HostTensor src_tensor(alloc_extent); -+ cutlass::HostTensor dst_tensor(alloc_extent); -+ -+ Element oob_value = Element(-1); -+ cutlass::reference::host::TensorFill(dst_tensor.host_view(), oob_value); -+ 
cutlass::reference::host::BlockFillSequential(src_tensor.host_data(), src_tensor.capacity()); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ typename Iterator::Params dst_params(dst_tensor.layout()); -+ typename Iterator::Params src_params(src_tensor.layout()); -+ -+ dim3 block(kThreads, 1); -+ dim3 grid(1, 1); -+ -+ test::transform::threadblock::kernel::copy<<< grid, block >>>( -+ dst_params, -+ dst_tensor.device_data(), -+ src_params, -+ src_tensor.device_data(), -+ copy_extent -+ ); -+ -+ cudaError_t result = cudaGetLastError(); -+ EXPECT_EQ(result, cudaSuccess) << " - CUDA error: " << cudaGetErrorString(result); -+ -+ dst_tensor.sync_host(); -+ -+ for (int s = 0; s < alloc_extent[1]; ++s) { -+ for (int c = 0; c < alloc_extent[0]; ++c) { -+ -+ Element expected = Element(0); -+ -+ if (c < copy_extent[0] && s < copy_extent[1]) { -+ expected = src_tensor.at({c, s}); -+ } -+ else { -+ expected = oob_value; -+ } -+ -+ Element got = dst_tensor.at({c, s}); -+ bool equal = (expected == got); -+ -+ EXPECT_EQ(expected, got) -+ << "Source:\n" << src_tensor.host_view() << "\n\n" -+ << "Destination:\n" << dst_tensor.host_view() << "\n"; -+ -+ if (!equal) { -+ return; -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(Transform_threadblock_PredicatedTileIterator, PitchLinear_Stripmined_2dtile_128x4) { -+ -+ using Shape = cutlass::layout::PitchLinearShape<128, 4>; -+ using ThreadTileShape = cutlass::layout::PitchLinearShape<4, 4>; -+ using Layout = cutlass::layout::PitchLinear; -+ using Element = int8_t; -+ static int const kThreads = 32; -+ -+ using ThreadMap = cutlass::transform::PitchLinear2DThreadTileStripminedThreadMap; -+ -+ using Iterator = cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ Shape, Element, Layout, 1, ThreadMap, false -+ >; -+ -+ cutlass::Coord<2> copy_extent = cutlass::make_Coord(128, 4); -+ cutlass::Coord<2> alloc_extent = cutlass::make_Coord(128, 4); -+ -+ cutlass::HostTensor src_tensor(alloc_extent); -+ cutlass::HostTensor dst_tensor(alloc_extent); -+ -+ Element oob_value = Element(-1); -+ cutlass::reference::host::TensorFill(dst_tensor.host_view(), oob_value); -+ cutlass::reference::host::BlockFillSequential(src_tensor.host_data(), src_tensor.capacity()); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ typename Iterator::Params dst_params(dst_tensor.layout()); -+ typename Iterator::Params src_params(src_tensor.layout()); -+ -+ dim3 block(kThreads, 1); -+ dim3 grid(1, 1); -+ -+ test::transform::threadblock::kernel::copy<<< grid, block >>>( -+ dst_params, -+ dst_tensor.device_data(), -+ src_params, -+ src_tensor.device_data(), -+ copy_extent -+ ); -+ -+ cudaError_t result = cudaGetLastError(); -+ EXPECT_EQ(result, cudaSuccess) << " - CUDA error: " << cudaGetErrorString(result); -+ -+ dst_tensor.sync_host(); -+ -+ for (int s = 0; s < alloc_extent[1]; ++s) { -+ for (int c = 0; c < alloc_extent[0]; ++c) { -+ -+ Element expected = Element(0); -+ -+ if (c < copy_extent[0] && s < copy_extent[1]) { -+ expected = src_tensor.at({c, s}); -+ } -+ else { -+ expected = oob_value; -+ } -+ -+ Element got = dst_tensor.at({c, s}); -+ bool equal = (expected == got); -+ -+ EXPECT_EQ(expected, got) -+ << "Source:\n" << src_tensor.host_view() << "\n\n" -+ << "Destination:\n" << dst_tensor.host_view() << "\n"; -+ -+ if (!equal) { -+ return; -+ } -+ } -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// 
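
Each verification loop in these predicated-iterator tests applies the same rule: a coordinate inside copy_extent must hold the value copied from the source tensor, while a coordinate that is only allocated (inside alloc_extent but outside copy_extent) must still hold the out-of-bounds fill value, which shows that the iterator's predicates suppressed the guarded accesses. The helper below is a minimal sketch of that rule only, not part of the patch; the name expected_element and the generic SourceTensor parameter are illustrative.

#include "cutlass/coord.h"

// Illustrative helper: the element-wise expectation checked by the loops in
// these tests. Coordinates are given as {contiguous, strided}.
template <typename Element, typename SourceTensor>
Element expected_element(
    SourceTensor &src,                // host-side source tensor
    cutlass::Coord<2> coord,          // coordinate being checked
    cutlass::Coord<2> copy_extent,    // extent actually copied by the kernel
    Element oob_value) {              // fill value the destination started with

  if (coord[0] < copy_extent[0] && coord[1] < copy_extent[1]) {
    return src.at({coord[0], coord[1]});  // predicates allowed this access
  }
  return oob_value;                       // predicates masked this access
}
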
-+ -+TEST(Transform_threadblock_PredicatedTileIterator, PitchLinear_Stripmined_2dtile_128x64) { -+ -+ using Shape = cutlass::layout::PitchLinearShape<128, 64>; -+ using ThreadTileShape = cutlass::layout::PitchLinearShape<4, 4>; -+ using Layout = cutlass::layout::PitchLinear; -+ using Element = int8_t; -+ static int const kThreads = 32; -+ -+ using ThreadMap = cutlass::transform::PitchLinear2DThreadTileStripminedThreadMap; -+ -+ using Iterator = cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ Shape, Element, Layout, 1, ThreadMap -+ >; -+ -+ cutlass::Coord<2> copy_extent = cutlass::make_Coord(128, 64); -+ cutlass::Coord<2> alloc_extent = cutlass::make_Coord(128, 64); -+ -+ cutlass::HostTensor src_tensor(alloc_extent); -+ cutlass::HostTensor dst_tensor(alloc_extent); -+ -+ Element oob_value = Element(-1); -+ cutlass::reference::host::TensorFill(dst_tensor.host_view(), oob_value); -+ cutlass::reference::host::BlockFillSequential(src_tensor.host_data(), src_tensor.capacity()); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ typename Iterator::Params dst_params(dst_tensor.layout()); -+ typename Iterator::Params src_params(src_tensor.layout()); -+ -+ dim3 block(kThreads, 1); -+ dim3 grid(1, 1); -+ -+ test::transform::threadblock::kernel::copy<<< grid, block >>>( -+ dst_params, -+ dst_tensor.device_data(), -+ src_params, -+ src_tensor.device_data(), -+ copy_extent -+ ); -+ -+ cudaError_t result = cudaGetLastError(); -+ EXPECT_EQ(result, cudaSuccess) << " - CUDA error: " << cudaGetErrorString(result); -+ -+ dst_tensor.sync_host(); -+ -+ for (int s = 0; s < alloc_extent[1]; ++s) { -+ for (int c = 0; c < alloc_extent[0]; ++c) { -+ -+ Element expected = Element(0); -+ -+ if (c < copy_extent[0] && s < copy_extent[1]) { -+ expected = src_tensor.at({c, s}); -+ } -+ else { -+ expected = oob_value; -+ } -+ -+ Element got = dst_tensor.at({c, s}); -+ bool equal = (expected == got); -+ -+ EXPECT_EQ(expected, got) -+ << "Source:\n" << src_tensor.host_view() << "\n\n" -+ << "Destination:\n" << dst_tensor.host_view() << "\n"; -+ -+ if (!equal) { -+ return; -+ } -+ } -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Transform_threadblock_PredicatedTileIterator, PitchLinear_Stripmined_2dtile_64x64) { -+ -+ using Shape = cutlass::layout::PitchLinearShape<64, 64>; -+ using ThreadTileShape = cutlass::layout::PitchLinearShape<4, 4>; -+ using Layout = cutlass::layout::PitchLinear; -+ using Element = int8_t; -+ static int const kThreads = 32; -+ -+ using ThreadMap = cutlass::transform::PitchLinear2DThreadTileStripminedThreadMap; -+ -+ using Iterator = cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ Shape, Element, Layout, 1, ThreadMap -+ >; -+ -+ cutlass::Coord<2> copy_extent = cutlass::make_Coord(64, 64); -+ cutlass::Coord<2> alloc_extent = cutlass::make_Coord(64, 64); -+ -+ cutlass::HostTensor src_tensor(alloc_extent); -+ cutlass::HostTensor dst_tensor(alloc_extent); -+ -+ Element oob_value = Element(-1); -+ cutlass::reference::host::TensorFill(dst_tensor.host_view(), oob_value); -+ cutlass::reference::host::BlockFillSequential(src_tensor.host_data(), src_tensor.capacity()); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ typename Iterator::Params dst_params(dst_tensor.layout()); -+ typename Iterator::Params src_params(src_tensor.layout()); -+ -+ dim3 block(kThreads, 1); -+ dim3 grid(1, 1); -+ -+ test::transform::threadblock::kernel::copy<<< grid, block >>>( -+ 
dst_params, -+ dst_tensor.device_data(), -+ src_params, -+ src_tensor.device_data(), -+ copy_extent -+ ); -+ -+ cudaError_t result = cudaGetLastError(); -+ EXPECT_EQ(result, cudaSuccess) << " - CUDA error: " << cudaGetErrorString(result); -+ -+ dst_tensor.sync_host(); -+ -+ for (int s = 0; s < alloc_extent[1]; ++s) { -+ for (int c = 0; c < alloc_extent[0]; ++c) { -+ -+ Element expected = Element(0); -+ -+ if (c < copy_extent[0] && s < copy_extent[1]) { -+ expected = src_tensor.at({c, s}); -+ } -+ else { -+ expected = oob_value; -+ } -+ -+ Element got = dst_tensor.at({c, s}); -+ bool equal = (expected == got); -+ -+ EXPECT_EQ(expected, got) -+ << "Source:\n" << src_tensor.host_view() << "\n\n" -+ << "Destination:\n" << dst_tensor.host_view() << "\n"; -+ -+ if (!equal) { -+ return; -+ } -+ } -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Transform_threadblock_PredicatedTileIterator, PitchLinear_Stripmined_2dtile_64x8) { -+ -+ using Shape = cutlass::layout::PitchLinearShape<64, 8>; -+ using ThreadTileShape = cutlass::layout::PitchLinearShape<4, 4>; -+ using Layout = cutlass::layout::PitchLinear; -+ using Element = int8_t; -+ static int const kThreads = 32; -+ -+ using ThreadMap = cutlass::transform::PitchLinear2DThreadTileStripminedThreadMap; -+ -+ using Iterator = cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ Shape, Element, Layout, 1, ThreadMap -+ >; -+ -+ cutlass::Coord<2> copy_extent = cutlass::make_Coord(32, 8); -+ cutlass::Coord<2> alloc_extent = cutlass::make_Coord(64, 8); -+ -+ cutlass::HostTensor src_tensor(alloc_extent); -+ cutlass::HostTensor dst_tensor(alloc_extent); -+ -+ Element oob_value = Element(-1); -+ cutlass::reference::host::TensorFill(dst_tensor.host_view(), oob_value); -+ cutlass::reference::host::BlockFillSequential(src_tensor.host_data(), src_tensor.capacity()); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ typename Iterator::Params dst_params(dst_tensor.layout()); -+ typename Iterator::Params src_params(src_tensor.layout()); -+ -+ dim3 block(kThreads, 1); -+ dim3 grid(1, 1); -+ -+ test::transform::threadblock::kernel::copy<<< grid, block >>>( -+ dst_params, -+ dst_tensor.device_data(), -+ src_params, -+ src_tensor.device_data(), -+ copy_extent -+ ); -+ -+ cudaError_t result = cudaGetLastError(); -+ EXPECT_EQ(result, cudaSuccess) << " - CUDA error: " << cudaGetErrorString(result); -+ -+ dst_tensor.sync_host(); -+ -+ for (int s = 0; s < alloc_extent[1]; ++s) { -+ for (int c = 0; c < alloc_extent[0]; ++c) { -+ -+ Element expected = Element(0); -+ -+ if (c < copy_extent[0] && s < copy_extent[1]) { -+ expected = src_tensor.at({c, s}); -+ } -+ else { -+ expected = oob_value; -+ } -+ -+ Element got = dst_tensor.at({c, s}); -+ bool equal = (expected == got); -+ -+ EXPECT_EQ(expected, got) -+ << "Source:\n" << src_tensor.host_view() << "\n\n" -+ << "Destination:\n" << dst_tensor.host_view() << "\n"; -+ -+ if (!equal) { -+ return; -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Transform_threadblock_PredicatedTileIterator, PitchLinear_Stripmined_2dtile_64x32_transpose4x4) { -+ -+ using Shape = cutlass::layout::PitchLinearShape<64, 8>; -+ using ThreadTileShape = cutlass::layout::PitchLinearShape<4, 4>; -+ using Layout = cutlass::layout::PitchLinear; -+ using Element = int8_t; -+ static int const kThreads = 32; -+ -+ using ThreadMap = 
cutlass::transform::PitchLinear2DThreadTileStripminedThreadMap; -+ -+ using Iterator = cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ Shape, Element, Layout, 1, ThreadMap, true -+ >; -+ -+ cutlass::Coord<2> copy_extent = cutlass::make_Coord(64, 32); -+ cutlass::Coord<2> alloc_extent = cutlass::make_Coord(64, 32); -+ -+ cutlass::HostTensor src_tensor(alloc_extent); -+ cutlass::HostTensor dst_tensor(alloc_extent); -+ -+ Element oob_value = Element(-1); -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFill(dst_tensor.host_view(), oob_value); -+ cutlass::reference::host::TensorFillRandomUniform(src_tensor.host_view(), seed, 8, -8, 0); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ typename Iterator::Params dst_params(dst_tensor.layout()); -+ typename Iterator::Params src_params(src_tensor.layout()); -+ -+ dim3 block(kThreads, 1); -+ dim3 grid(1, 1); -+ -+ test::transform::threadblock::kernel::copy<<< grid, block >>>( -+ dst_params, -+ dst_tensor.device_data(), -+ src_params, -+ src_tensor.device_data(), -+ copy_extent -+ ); -+ -+ cudaError_t result = cudaGetLastError(); -+ EXPECT_EQ(result, cudaSuccess) << " - CUDA error: " << cudaGetErrorString(result); -+ -+ dst_tensor.sync_host(); -+ -+ for (int s = 0; s < alloc_extent[1]/4; ++s) { -+ for (int c = 0; c < alloc_extent[0]/4; ++c) { -+ for (int s1 = 0; s1 < 4; s1++){ -+ for(int c1 = 0; c1 < 4; c1++){ -+ Element expected = Element(0); -+ -+ int l_c = c * 4 + c1; -+ int l_s = s * 4 + s1; -+ -+ int l_tc = c * 4 + s1; -+ int l_ts = s * 4 + c1; -+ -+ if (l_c < copy_extent[0] && l_s < copy_extent[1]) { -+ expected = src_tensor.at({l_c, l_s}); -+ } -+ else { -+ expected = oob_value; -+ } -+ -+ Element got = dst_tensor.at({l_tc, l_ts}); -+ bool equal = (expected == got); -+ -+ EXPECT_EQ(expected, got) -+ << "Source:\n" << src_tensor.host_view() << "\n\n" -+ << "Destination:\n" << dst_tensor.host_view() << "\n"; -+ -+ if (!equal) { -+ return; -+ } -+ } -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Transform_threadblock_PredicatedTileIterator, PitchLinear_Stripmined_2dtile_64x29_transpose4x4) { -+ -+ using Shape = cutlass::layout::PitchLinearShape<64, 8>; -+ using ThreadTileShape = cutlass::layout::PitchLinearShape<4, 4>; -+ using Layout = cutlass::layout::PitchLinear; -+ using Element = int8_t; -+ static int const kThreads = 32; -+ -+ using ThreadMap = cutlass::transform::PitchLinear2DThreadTileStripminedThreadMap; -+ -+ using Iterator = cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ Shape, Element, Layout, 1, ThreadMap, true -+ >; -+ -+ cutlass::Coord<2> copy_extent = cutlass::make_Coord(64, 29); -+ cutlass::Coord<2> alloc_extent = cutlass::make_Coord(64, 29); -+ -+ cutlass::HostTensor src_tensor(alloc_extent); -+ cutlass::HostTensor dst_tensor(alloc_extent); -+ -+ Element oob_value = Element(-1); -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFill(dst_tensor.host_view(), oob_value); -+ cutlass::reference::host::TensorFillRandomUniform(src_tensor.host_view(), seed, 8, -8, 0); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ typename Iterator::Params dst_params(dst_tensor.layout()); -+ typename Iterator::Params src_params(src_tensor.layout()); -+ -+ dim3 block(kThreads, 1); -+ dim3 grid(1, 1); -+ -+ test::transform::threadblock::kernel::copy<<< grid, block >>>( -+ dst_params, -+ dst_tensor.device_data(), -+ src_params, -+ src_tensor.device_data(), -+ copy_extent 
-+ ); -+ -+ cudaError_t result = cudaGetLastError(); -+ EXPECT_EQ(result, cudaSuccess) << " - CUDA error: " << cudaGetErrorString(result); -+ -+ dst_tensor.sync_host(); -+ -+ for (int s = 0; s < alloc_extent[1]/4; ++s) { -+ for (int c = 0; c < alloc_extent[0]/4; ++c) { -+ for (int s1 = 0; s1 < 4; s1++){ -+ for(int c1 = 0; c1 < 4; c1++){ -+ Element expected = Element(0); -+ -+ int l_c = c * 4 + c1; -+ int l_s = s * 4 + s1; -+ -+ int l_tc = c * 4 + s1; -+ int l_ts = s * 4 + c1; -+ -+ if (l_c < copy_extent[0] && l_s < copy_extent[1]) { -+ expected = src_tensor.at({l_c, l_s}); -+ } -+ else { -+ expected = oob_value; -+ } -+ -+ Element got = dst_tensor.at({l_tc, l_ts}); -+ bool equal = (expected == got); -+ -+ EXPECT_EQ(expected, got) -+ << "Source:\n" << src_tensor.host_view() << "\n\n" -+ << "Destination:\n" << dst_tensor.host_view() << "\n"; -+ -+ if (!equal) { -+ return; -+ } -+ } -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Transform_threadblock_PredicatedTileIterator, PitchLinear_Stripmined_2dtile_120x4_transpose4x4) { -+ -+ using Shape = cutlass::layout::PitchLinearShape<128, 4>; -+ using ThreadTileShape = cutlass::layout::PitchLinearShape<4, 4>; -+ using Layout = cutlass::layout::PitchLinear; -+ using Element = int8_t; -+ static int const kThreads = 32; -+ -+ using ThreadMap = cutlass::transform::PitchLinear2DThreadTileStripminedThreadMap; -+ -+ using Iterator = cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ Shape, Element, Layout, 1, ThreadMap, true -+ >; -+ -+ cutlass::Coord<2> copy_extent = cutlass::make_Coord(120, 4); -+ cutlass::Coord<2> alloc_extent = cutlass::make_Coord(120, 4); -+ -+ cutlass::HostTensor src_tensor(alloc_extent); -+ cutlass::HostTensor dst_tensor(alloc_extent); -+ -+ Element oob_value = Element(-1); -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFill(dst_tensor.host_view(), oob_value); -+ cutlass::reference::host::TensorFillRandomUniform(src_tensor.host_view(), seed, 8, -8, 0); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ typename Iterator::Params dst_params(dst_tensor.layout()); -+ typename Iterator::Params src_params(src_tensor.layout()); -+ -+ dim3 block(kThreads, 1); -+ dim3 grid(1, 1); -+ -+ test::transform::threadblock::kernel::copy<<< grid, block >>>( -+ dst_params, -+ dst_tensor.device_data(), -+ src_params, -+ src_tensor.device_data(), -+ copy_extent -+ ); -+ -+ cudaError_t result = cudaGetLastError(); -+ EXPECT_EQ(result, cudaSuccess) << " - CUDA error: " << cudaGetErrorString(result); -+ -+ dst_tensor.sync_host(); -+ -+ for (int s = 0; s < alloc_extent[1]/4; ++s) { -+ for (int c = 0; c < alloc_extent[0]/4; ++c) { -+ for (int s1 = 0; s1 < 4; s1++){ -+ for(int c1 = 0; c1 < 4; c1++){ -+ Element expected = Element(0); -+ -+ int l_c = c * 4 + c1; -+ int l_s = s * 4 + s1; -+ -+ int l_tc = c * 4 + s1; -+ int l_ts = s * 4 + c1; -+ -+ if (l_c < copy_extent[0] && l_s < copy_extent[1]) { -+ expected = src_tensor.at({l_c, l_s}); -+ } -+ else { -+ expected = oob_value; -+ } -+ -+ Element got = dst_tensor.at({l_tc, l_ts}); -+ bool equal = (expected == got); -+ -+ EXPECT_EQ(expected, got) -+ << "Source:\n" << src_tensor.host_view() << "\n\n" -+ << "Destination:\n" << dst_tensor.host_view() << "\n"; -+ -+ if (!equal) { -+ return; -+ } -+ } -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Transform_threadblock_PredicatedTileIterator, 
PitchLinear_Stripmined_2dtile_48x29_transpose4x4) { -+ -+ using Shape = cutlass::layout::PitchLinearShape<64, 8>; -+ using ThreadTileShape = cutlass::layout::PitchLinearShape<4, 4>; -+ using Layout = cutlass::layout::PitchLinear; -+ using Element = int8_t; -+ static int const kThreads = 32; -+ -+ using ThreadMap = cutlass::transform::PitchLinear2DThreadTileStripminedThreadMap; -+ -+ using Iterator = cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ Shape, Element, Layout, 1, ThreadMap, true -+ >; -+ -+ cutlass::Coord<2> copy_extent = cutlass::make_Coord(48, 29); -+ cutlass::Coord<2> alloc_extent = cutlass::make_Coord(48, 29); -+ -+ cutlass::HostTensor src_tensor(alloc_extent); -+ cutlass::HostTensor dst_tensor(alloc_extent); -+ -+ Element oob_value = Element(-1); -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFill(dst_tensor.host_view(), oob_value); -+ cutlass::reference::host::TensorFillRandomUniform(src_tensor.host_view(), seed, 8, -8, 0); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ typename Iterator::Params dst_params(dst_tensor.layout()); -+ typename Iterator::Params src_params(src_tensor.layout()); -+ -+ dim3 block(kThreads, 1); -+ dim3 grid(1, 1); -+ -+ test::transform::threadblock::kernel::copy<<< grid, block >>>( -+ dst_params, -+ dst_tensor.device_data(), -+ src_params, -+ src_tensor.device_data(), -+ copy_extent -+ ); -+ -+ cudaError_t result = cudaGetLastError(); -+ EXPECT_EQ(result, cudaSuccess) << " - CUDA error: " << cudaGetErrorString(result); -+ -+ dst_tensor.sync_host(); -+ -+ for (int s = 0; s < alloc_extent[1]/4; ++s) { -+ for (int c = 0; c < alloc_extent[0]/4; ++c) { -+ for (int s1 = 0; s1 < 4; s1++){ -+ for(int c1 = 0; c1 < 4; c1++){ -+ Element expected = Element(0); -+ -+ int l_c = c * 4 + c1; -+ int l_s = s * 4 + s1; -+ -+ int l_tc = c * 4 + s1; -+ int l_ts = s * 4 + c1; -+ -+ if (l_c < copy_extent[0] && l_s < copy_extent[1]) { -+ expected = src_tensor.at({l_c, l_s}); -+ } -+ else { -+ expected = oob_value; -+ } -+ -+ Element got = dst_tensor.at({l_tc, l_ts}); -+ bool equal = (expected == got); -+ -+ EXPECT_EQ(expected, got) -+ << "Source:\n" << src_tensor.host_view() << "\n\n" -+ << "Destination:\n" << dst_tensor.host_view() << "\n"; -+ if (!equal) { -+ return; -+ } -+ } -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/transform/threadblock/regular_tile_iterator_tensor_op.cu b/3rdparty/cutlass/test/unit/transform/threadblock/regular_tile_iterator_tensor_op.cu -new file mode 100644 -index 0000000..c5ad3e9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/transform/threadblock/regular_tile_iterator_tensor_op.cu -@@ -0,0 +1,289 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/core_io.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+namespace threadblock { -+ -+/// -+template -+__global__ void kernel_gemm_threadblock_tensor_op_multiplicand_store( -+ typename Iterator::TensorRef ref_output, -+ typename Iterator::Element *input) { -+ -+ // Construct fragment -+ typename Iterator::Fragment frag; -+ -+ frag.clear(); -+ -+ // each thread loads a fragment -+ using AccessType = cutlass::Array; -+ -+ int const kElementsPerAccess = Iterator::ThreadMap::kElementsPerAccess; -+ int stride = Iterator::Shape::kContiguous; -+ -+ int warp_id = (threadIdx.x / 32); -+ int lane_id = (threadIdx.x % 32); -+ -+ input += (lane_id % 8) * kElementsPerAccess + (lane_id / 8) * stride; -+ -+ input += (warp_id * Iterator::Shape::kStrided / Iterator::ThreadMap::Detail::kWarpCount) * stride; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Iterator::ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Iterator::ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < Iterator::ThreadMap::kElementsPerAccess; ++v) { -+ frag[v + Iterator::ThreadMap::kElementsPerAccess * (c + s * Iterator::ThreadMap::Iterations::kContiguous)] = -+ input[v + c * 64 + s * Iterator::ThreadMap::Delta::kStrided * stride]; -+ } -+ } -+ } -+ -+ // Use iterator to store results -+ Iterator iter(ref_output, threadIdx.x); -+ iter.store(frag); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Simple test environment -+template < -+ typename Shape_, -+ int WarpCount -+> -+class MultiplicandTileIteratorTestbed { -+public: -+ -+ // -+ // Define iterator -+ // -+ -+ using Shape = Shape_; -+ using Element = cutlass::half_t; -+ using Layout = cutlass::layout::TensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ static int const kAdvanceRank = 
1; -+ static int const kThreads = 32 * WarpCount; -+ -+ using ThreadMap = cutlass::transform::PitchLinearWarpRakedThreadMap< -+ Shape, -+ kThreads, -+ cutlass::layout::PitchLinearShape<8, 4>, -+ 128 / cutlass::sizeof_bits::value -+ >; -+ -+ using Iterator = cutlass::transform::threadblock::RegularTileIterator< -+ Shape, Element, Layout, kAdvanceRank, ThreadMap -+ >; -+ -+public: -+ -+ // -+ // Members -+ // -+ -+ cutlass::HostTensor destination_tensor; -+ cutlass::HostTensor source_tensor; -+ -+ -+public: -+ -+ MultiplicandTileIteratorTestbed(): -+ destination_tensor({Shape::kContiguous, Shape::kStrided}), -+ source_tensor({Shape::kContiguous, Shape::kStrided}) { -+ -+ } -+ -+ bool run() { -+ -+ cutlass::reference::host::BlockFillSequential( -+ source_tensor.host_data(), -+ source_tensor.capacity() -+ ); -+ -+ cutlass::reference::host::BlockFillSequential( -+ destination_tensor.host_data(), -+ destination_tensor.capacity(), -+ Element(0), -+ Element(0) -+ ); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1,1); -+ dim3 block(kThreads, 1); -+ -+ destination_tensor.sync_device(); -+ source_tensor.sync_device(); -+ -+ test::gemm::threadblock::kernel_gemm_threadblock_tensor_op_multiplicand_store<<< -+ grid, block -+ >>>( -+ destination_tensor.device_ref(), -+ source_tensor.device_data() -+ ); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) << " - CUDA ERROR: " << cudaGetErrorString(result); -+ -+ destination_tensor.sync_host(); -+ -+ // -+ // Verify -+ // -+ -+ // Verify that its contents match the destination -+ int errors = 0; -+ for (int s = 0; s < Shape::kStrided; ++s) { -+ for (int c = 0; c < Shape::kContiguous; ++c) { -+ -+ if (errors >= 10) { -+ break; -+ } -+ -+ Element expected = source_tensor.at({c, s}); -+ Element got = destination_tensor.at({c, s}); -+ -+ bool passed = (expected == got); -+ if (!passed) { -+ ++errors; -+ } -+ } -+ } -+ -+ EXPECT_EQ(errors, 0) -+ << source_tensor.host_view() << "\n\n" << destination_tensor.host_view() << std::endl; -+ -+ return !errors; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_tensor_op_multplicand_iterator_congruous_16b, 64x8_w1) { -+ -+ test::gemm::threadblock::MultiplicandTileIteratorTestbed< -+ cutlass::layout::PitchLinearShape<64, 8>, 1>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_tensor_op_multplicand_iterator_congruous_16b, 64x16_w1) { -+ -+ test::gemm::threadblock::MultiplicandTileIteratorTestbed< -+ cutlass::layout::PitchLinearShape<64, 16>, 1>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_tensor_op_multplicand_iterator_congruous_16b, 64x16_w2) { -+ -+ test::gemm::threadblock::MultiplicandTileIteratorTestbed< -+ cutlass::layout::PitchLinearShape<64, 16>, 2>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_tensor_op_multplicand_iterator_congruous_16b, 128x8_w1) { -+ -+ test::gemm::threadblock::MultiplicandTileIteratorTestbed< -+ cutlass::layout::PitchLinearShape<128, 8>, 1>().run(); -+} -+ 
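
The testbed above wires a PitchLinearWarpRakedThreadMap to a RegularTileIterator over the TensorOpMultiplicandCongruous layout and checks that a store through the iterator reproduces the source tile in canonical {contiguous, strided} coordinates. Covering a further tile shape or warp count only takes another instantiation of the same testbed; the case below is an illustrative sketch, not part of the patch, and the shape/warp-count pair is assumed to satisfy the same divisibility constraints as the configurations already listed.

// Illustrative only: one more instantiation in the same style as the tests
// above -- a 64x64 tile driven by two warps (64 threads).
TEST(SM75_gemm_threadblock_tensor_op_multplicand_iterator_congruous_16b, 64x64_w2_sketch) {

  test::gemm::threadblock::MultiplicandTileIteratorTestbed<
    cutlass::layout::PitchLinearShape<64, 64>, 2>().run();
}
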
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_tensor_op_multplicand_iterator_congruous_16b, 64x32_w4) { -+ -+ test::gemm::threadblock::MultiplicandTileIteratorTestbed< -+ cutlass::layout::PitchLinearShape<64, 32>, 4>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_tensor_op_multplicand_iterator_congruous_16b, 128x32_w1) { -+ -+ test::gemm::threadblock::MultiplicandTileIteratorTestbed< -+ cutlass::layout::PitchLinearShape<128, 32>, 1>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_tensor_op_multplicand_iterator_congruous_16b, 128x32_w4) { -+ -+ test::gemm::threadblock::MultiplicandTileIteratorTestbed< -+ cutlass::layout::PitchLinearShape<128, 32>, 4>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_tensor_op_multplicand_iterator_congruous_16b, 256x32_w4) { -+ -+ test::gemm::threadblock::MultiplicandTileIteratorTestbed< -+ cutlass::layout::PitchLinearShape<256, 32>, 4>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_tensor_op_multplicand_iterator_congruous_16b, 256x32_w8) { -+ -+ test::gemm::threadblock::MultiplicandTileIteratorTestbed< -+ cutlass::layout::PitchLinearShape<256, 32>, 8>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/util/cutlass_test_levels.cu b/3rdparty/cutlass/test/unit/util/cutlass_test_levels.cu -new file mode 100644 -index 0000000..3879783 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/util/cutlass_test_levels.cu -@@ -0,0 +1,77 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include -+ -+#include "../common/cutlass_unit_test.h" -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_CUTLASS_TEST, level_not_specified) { -+ -+ EXPECT_TRUE(true); -+} -+ -+TEST(SM80_CUTLASS_TEST, level_not_specified) { -+ -+ EXPECT_TRUE(true); -+} -+ -+CUTLASS_TEST_L0(SM75_CUTLASS_TEST, level0, { -+ -+ EXPECT_TRUE(true); -+}) -+ -+CUTLASS_TEST_L1(SM75_CUTLASS_TEST, level1, { -+ -+ EXPECT_TRUE(true); -+}) -+ -+CUTLASS_TEST_L2(SM75_CUTLASS_TEST, level2, { -+ -+ EXPECT_TRUE(true); -+}) -+ -+CUTLASS_TEST_L0(SM80_CUTLASS_TEST, level0, { -+ -+ EXPECT_TRUE(true); -+}) -+ -+CUTLASS_TEST_L1(SM80_CUTLASS_TEST, level1, { -+ -+ EXPECT_TRUE(true); -+}) -+ -+CUTLASS_TEST_L2(SM80_CUTLASS_TEST, level2, { -+ -+ EXPECT_TRUE(true); -+}) -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/util/tensor_reduce.cu b/3rdparty/cutlass/test/unit/util/tensor_reduce.cu -new file mode 100644 -index 0000000..c71d080 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/util/tensor_reduce.cu -@@ -0,0 +1,244 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+ -+#include "cutlass/util/reference/device/tensor_reduce.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/host_tensor.h" -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(TensorReduce, norm_rowmajor_f32) { -+ -+ int const kM = 129; -+ int const kN = 91; -+ -+ cutlass::HostTensor tensor({kM, kN}); -+ -+ for (int m = 0; m < kM; ++m) { -+ for (int n = 0; n < kN; ++n) { -+ -+ float x = float(((m * kN + m + 7) % 8) - 4); -+ -+ tensor.at({m, n}) = x; -+ } -+ } -+ -+ tensor.sync_device(); -+ -+ double device_norm = cutlass::reference::device::TensorNorm(tensor.device_view(), double()); -+ double host_norm = cutlass::reference::host::TensorNorm(tensor.host_view(), double()); -+ -+ EXPECT_TRUE(std::abs(host_norm - device_norm) < 0.001); -+} -+ -+TEST(TensorReduce, norm_nhwc_f32) { -+ -+ int const kN = 19; -+ int const kH = 18; -+ int const kW = 17; -+ int const kC = 16; -+ -+ cutlass::HostTensor tensor({kN, kH, kW, kC}); -+ -+ int idx = 0; -+ -+ double computed_norm = double(); -+ -+ for (int n = 0; n < kN; ++n) { -+ for (int h = 0; h < kH; ++h) { -+ for (int w = 0; w < kW; ++w) { -+ for (int c = 0; c < kC; ++c, ++idx) { -+ -+ float x = float(((idx + 7) % 8) - 4); -+ -+ computed_norm += double(x) * double(x); -+ -+ tensor.at({n, h, w, c}) = x; -+ } -+ } -+ } -+ } -+ -+ computed_norm = std::sqrt(computed_norm); -+ -+ tensor.sync_device(); -+ -+ double device_norm = cutlass::reference::device::TensorNorm(tensor.device_view(), double()); -+ double host_norm = cutlass::reference::host::TensorNorm(tensor.host_view(), double()); -+ -+ EXPECT_TRUE(std::abs(host_norm - device_norm) < 0.001 && std::abs(computed_norm - host_norm) < 0.001) -+ << "computed norm: " << computed_norm << "\n" -+ << " host norm: " << host_norm << "\n" -+ << "device norm: " << device_norm << "\n"; -+} -+ -+TEST(TensorReduce, norm_nhwc_f16) { -+ -+ int const kN = 69; -+ int const kH = 68; -+ int const kW = 67; -+ int const kC = 66; -+ -+ cutlass::HostTensor tensor({kN, kH, kW, kC}); -+ -+ int idx = 0; -+ -+ double computed_norm = double(); -+ -+ for (int n = 0; n < kN; ++n) { -+ for (int h = 0; h < kH; ++h) { -+ for (int w = 0; w < kW; ++w) { -+ for (int c = 0; c < kC; ++c, ++idx) { -+ -+ float x = float(((idx + 7) % 8) - 4); -+ computed_norm += double(x) * double(x); -+ -+ tensor.at({n, h, w, c}) = cutlass::half_t(x); -+ } -+ } -+ } -+ } -+ -+ computed_norm = std::sqrt(computed_norm); -+ -+ tensor.sync_device(); -+ -+ double device_norm = cutlass::reference::device::TensorNorm(tensor.device_view(), double()); -+ double host_norm = cutlass::reference::host::TensorNorm(tensor.host_view(), double()); -+ -+ EXPECT_TRUE(std::abs(host_norm - device_norm) < 0.001 && 
std::abs(computed_norm - host_norm) < 0.001) -+ << "computed norm: " << computed_norm << "\n" -+ << " host norm: " << host_norm << "\n" -+ << "device norm: " << device_norm << "\n"; -+} -+ -+TEST(TensorReduce, norm_diff_nhwc_f32) { -+ -+ int const kN = 59; -+ int const kH = 24; -+ int const kW = 57; -+ int const kC = 78; -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ cutlass::HostTensor tensor_A({kN, kH, kW, kC}); -+ cutlass::HostTensor tensor_B({kN, kH, kW, kC}); -+ -+ -+ int idx = 0; -+ -+ double sum_sq_diff = 0; -+ -+ for (int n = 0; n < kN; ++n) { -+ for (int h = 0; h < kH; ++h) { -+ for (int w = 0; w < kW; ++w) { -+ for (int c = 0; c < kC; ++c, ++idx) { -+ -+ float a = float(((idx * 5 + 7) % 8) - 4); -+ float b = float(((idx * 3 + 7) % 8) - 4); -+ -+ sum_sq_diff += double(a - b) * double(a - b); -+ -+ tensor_A.at({n, h, w, c}) = a; -+ tensor_B.at({n, h, w, c}) = b; -+ } -+ } -+ } -+ } -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ -+ double device_norm = cutlass::reference::device::TensorNormDiff( -+ tensor_A.device_view(), tensor_B.device_view(), double()); -+ -+ double host_norm = std::sqrt(sum_sq_diff); -+ -+ EXPECT_TRUE(std::abs(host_norm - device_norm) < 0.001f) -+ << " host norm: " << host_norm << "\n" -+ << "device norm: " << device_norm; -+} -+ -+ -+TEST(TensorReduce, norm_diff_nhwc_f16) { -+ -+ int const kN = 59; -+ int const kH = 24; -+ int const kW = 57; -+ int const kC = 78; -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ cutlass::HostTensor tensor_A({kN, kH, kW, kC}); -+ cutlass::HostTensor tensor_B({kN, kH, kW, kC}); -+ -+ int idx = 0; -+ -+ double sum_sq_diff = 0; -+ -+ for (int n = 0; n < kN; ++n) { -+ for (int h = 0; h < kH; ++h) { -+ for (int w = 0; w < kW; ++w) { -+ for (int c = 0; c < kC; ++c, ++idx) { -+ -+ float a = float(((idx * 5 + 7) % 8) - 4); -+ float b = float(((idx * 3 + 7) % 8) - 4); -+ -+ sum_sq_diff += double(a - b) * double(a - b); -+ -+ tensor_A.at({n, h, w, c}) = cutlass::half_t(a); -+ tensor_B.at({n, h, w, c}) = cutlass::half_t(b); -+ } -+ } -+ } -+ } -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ -+ double device_norm = cutlass::reference::device::TensorNormDiff( -+ tensor_A.device_view(), tensor_B.device_view(), double()); -+ -+ double host_norm = std::sqrt(sum_sq_diff); -+ -+ EXPECT_TRUE(std::abs(host_norm - device_norm) < 0.001f) -+ << " host norm: " << host_norm << "\n" -+ << "device norm: " << device_norm; -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/library/include/cutlass/library/arch_mappings.h b/3rdparty/cutlass/tools/library/include/cutlass/library/arch_mappings.h -new file mode 100644 -index 0000000..0d6790e ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/include/cutlass/library/arch_mappings.h -@@ -0,0 +1,110 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ -+ \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. -+ -+ Generally, -+ -+ description - compile-time constant parameters used to instantiate an operation -+ -+ configuration - runtime parameters with computationally expensive initialization -+ -+ arguments - runtime parameters that may be passed to an initialized operation with low -+ computational overhead -+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/mma.h" -+#include "cutlass/arch/arch.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template struct ArchMap; -+ -+template <> struct ArchMap { -+ static int const kMin = 50; -+ static int const kMax = 1024; -+}; -+ -+template <> struct ArchMap { -+ static int const kMin = 60; -+ static int const kMax = 1024; -+}; -+ -+template <> struct ArchMap { -+ static int const kMin = 61; -+ static int const kMax = 1024; -+}; -+ -+template <> struct ArchMap { -+ static int const kMin = 70; -+ static int const kMax = 1024; -+}; -+ -+template <> struct ArchMap { -+ static int const kMin = 70; -+ static int const kMax = 75; -+}; -+ -+template struct ArchMap { -+ static int const kMin = 75; -+ static int const kMax = 1024; -+}; -+ -+template struct ArchMap { -+ static int const kMin = 80; -+ static int const kMax = 1024; -+}; -+ -+template struct ArchMap { -+ static int const kMin = 86; -+ static int const kMax = 1024; -+}; -+ -+template struct ArchMap { -+ static int const kMin = 90; -+ static int const kMax = 1024; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/include/cutlass/library/handle.h b/3rdparty/cutlass/tools/library/include/cutlass/library/handle.h -new file mode 100644 -index 0000000..8125989 ---- /dev/null -+++ 
b/3rdparty/cutlass/tools/library/include/cutlass/library/handle.h -@@ -0,0 +1,355 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief BLAS-like handle used to launch operations on the CUDA device. 
-+*/ -+ -+#pragma once -+ -+#include -+#include "cutlass/library/library.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Handle object -+class Handle { -+private: -+ -+ /// Host workspace -+ static int const kHostWorkspaceSize = (4 << 10); -+ -+ /// Provider of operations -+ Provider provider_; -+ -+ /// CUDA device properties -+ cudaDeviceProp device_; -+ -+ /// CUDA stream -+ cudaStream_t stream_; -+ -+ /// Device workspace -+ void *workspace_; -+ -+ /// Size of device workspace in bytes -+ size_t workspace_size_; -+ -+ /// Indicates whether scalars are host or device pointers -+ ScalarPointerMode scalar_pointer_mode_; -+ -+ /// Pointer to the most recently executed operation -+ Operation const *last_operation_; -+ -+public: -+ -+ /// Constructor -+ Handle(cudaStream_t stream = nullptr, size_t workspace_size = (4<<20)); -+ -+ /// Destructor -+ ~Handle(); -+ -+ /// Move constructor -+ Handle(Handle && handle); -+ -+ /// Move assignment operator -+ Handle &operator=(Handle && handle); -+ -+ // -+ // Persistent state accessors -+ // -+ -+ /// Returns compute capability of the selected device -+ int compute_capability() const; -+ -+ /// Sets the current CUDA stream -+ void set_stream(cudaStream_t stream); -+ -+ /// Gets the current CUDA stream -+ cudaStream_t get_stream() const; -+ -+ /// Gets the current provider -+ Provider get_provider() const; -+ -+ /// Sets the provider of operations -+ void set_provider(Provider provider); -+ -+ /// Gets the device workspace size -+ size_t get_workspace_size() const; -+ -+ /// Gets a pointer to the device workspace allocation in Global Memory -+ void *get_workspace() const; -+ -+ /// Sets the size of device workspace, invalidating calls to get_device_workspace() -+ void set_workspace_size(size_t bytes); -+ -+ /// Gets the scalar pointer mode -+ ScalarPointerMode get_scalar_pointer_mode() const; -+ -+ /// Sets the scalar pointer mode -+ void set_scalar_pointer_mode(ScalarPointerMode mode); -+ -+ /// Gets the most recently executed operation -+ Operation const *get_last_operation() const; -+ -+ // -+ // Computations -+ // -+ -+ /// Executes a GEMM computation: D <= alpha * A*B + beta * C -+ Status gemm( -+ -+ int M, /// GEMM M dimension -+ int N, /// GEMM N dimension -+ int K, /// GEMM K dimension -+ -+ NumericTypeID element_compute, /// Data type of internal accumulation -+ -+ NumericTypeID element_scalar, /// Data type of alpha/beta scalars -+ -+ void const *alpha, /// Pointer to alpha scalar -+ -+ NumericTypeID element_A, /// Data type of A matrix elements -+ LayoutTypeID layout_A, /// Layout of A matrix -+ ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices -+ -+ void const * ptr_A, /// Pointer to A matrix in Global Memory -+ int64_t lda, /// Leading dimension of A matrix -+ -+ NumericTypeID element_B, /// Data type of B matrix elements -+ LayoutTypeID layout_B, /// Layout of B matrix -+ ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices -+ -+ void const * ptr_B, /// Pointer to B matrix in Global Memory -+ int64_t ldb, /// Leading dimension of B matrix -+ -+ void const * beta, /// Pointer to beta scalar -+ -+ NumericTypeID element_C, /// Data type of C and D matrices -+ -+ void const * ptr_C, /// Pointer to C matrix -+ 
int64_t ldc, /// Leading dimension of C matrix -+ -+ void * ptr_D, /// Pointer to D matrix -+ int64_t ldd /// Leading dimension of D matrix -+ ); -+ -+ /// Executes a GEMM computation: D <= alpha * A*B + beta * C. -+ // -+ // Supports batched-strided, batched array or split-K serial or split-K parallel. -+ // -+ Status gemm_universal( -+ -+ GemmUniversalMode mode, /// indicates the mode in which the kUniversal GEMM is launched -+ -+ int M, /// GEMM M dimension -+ int N, /// GEMM N dimension -+ int K, /// GEMM K dimension -+ -+ NumericTypeID element_compute, /// Data type of internal accumulation -+ -+ NumericTypeID element_scalar, /// Data type of alpha/beta scalars -+ -+ void const *alpha, /// Pointer to alpha scalar -+ -+ NumericTypeID element_A, /// Data type of A matrix elements -+ LayoutTypeID layout_A, /// Layout of A matrix -+ ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices -+ -+ void const * ptr_A, /// Pointer to A matrix in Global Memory -+ int64_t lda, /// Leading dimension of A matrix -+ -+ NumericTypeID element_B, /// Data type of B matrix elements -+ LayoutTypeID layout_B, /// Layout of B matrix -+ ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices -+ -+ void const * ptr_B, /// Pointer to B matrix in Global Memory -+ int64_t ldb, /// Leading dimension of B matrix -+ -+ void const * beta, /// Pointer to beta scalar -+ -+ NumericTypeID element_C, /// Data type of C and D matrices -+ -+ void const * ptr_C, /// Pointer to C matrix -+ int64_t ldc, /// Leading dimension of C matrix -+ -+ void * ptr_D, /// Pointer to D matrix -+ int64_t ldd, /// Leading dimension of D matrix -+ -+ int batch_count = 1, /// Batch count or number of split-K slices -+ -+ int64_t batch_stride_A = 0, /// Batch stride of A operand -+ int64_t batch_stride_B = 0, /// Batch stride of B operand -+ int64_t batch_stride_C = 0, /// Batch stride of C operand -+ int64_t batch_stride_D = 0 /// Batch stride of D operand -+ ); -+ -+ /// Planar complex GEMM -+ /// -+ /// Note, all data types are the real-valued base types used by the planar-complex GEMM kernel. 
-+ /// -+ Status gemm_planar_complex( -+ -+ int M, /// GEMM M dimension -+ int N, /// GEMM N dimension -+ int K, /// GEMM K dimension -+ -+ NumericTypeID element_compute, /// Data type of internal accumulation -+ -+ NumericTypeID element_scalar, /// Data type of alpha/beta scalars -+ -+ void const *alpha, /// Pointer to alpha scalar -+ -+ NumericTypeID element_A, /// Data type of A matrix elements -+ LayoutTypeID layout_A, /// Layout of A matrix -+ ComplexTransform transform_A, /// Complex transformation applied to A matrix -+ -+ void const * ptr_A_real, /// Pointer to real part of A matrix -+ void const * ptr_A_imag, /// Pointer to imaginary part of A matrix -+ int64_t lda_real, /// Leading dimension of real part of A matrix -+ int64_t lda_imag, /// Leading dimension of imaginary part of A matrix -+ -+ NumericTypeID element_B, /// Data type of B matrix elements -+ LayoutTypeID layout_B, /// Layout of B matrix -+ ComplexTransform transform_B, /// Complex transformation applied to B matrix -+ -+ void const * ptr_B_real, /// Pointer to real part of B matrix -+ void const * ptr_B_imag, /// Pointer to imaginary part of B matrix -+ int64_t ldb_real, /// Leading dimension of real part of B matrix -+ int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix -+ -+ void const * beta, /// Pointer to beta scalar -+ -+ NumericTypeID element_C, /// Data type of C and D matrix -+ -+ void const * ptr_C_real, /// Pointer to real part of C matrix -+ void const * ptr_C_imag, /// Pointer to imaginary part of C matrix -+ int64_t ldc_real, /// Leading dimension of real part of C matrix -+ int64_t ldc_imag, /// Leading dimension of imaginary part of C matrix -+ -+ void * ptr_D_real, /// Pointer to real part of D matrix -+ void * ptr_D_imag, /// Pointer to imaginary part of D matrix -+ int64_t ldd_real, /// Leading dimension of real part of D matrix -+ int64_t ldd_imag, /// Leading dimension of imaginary part of D matrix -+ -+ int batch_count = 1, /// Number of batched GEMMs to execute -+ -+ int64_t batch_stride_A_real = 0, -+ int64_t batch_stride_A_imag = 0, -+ -+ int64_t batch_stride_B_real = 0, -+ int64_t batch_stride_B_imag = 0, -+ -+ int64_t batch_stride_C_real = 0, -+ int64_t batch_stride_C_imag = 0, -+ -+ int64_t batch_stride_D_real = 0, -+ int64_t batch_stride_D_imag = 0 -+ ); -+ -+ /// Planar complex GEMM loading pointers from arrays in global memory -+ Status gemm_planar_complex_array( -+ -+ int expected_M, /// Expected GEMM M dimension (used for sizing CUDA grid) -+ int expected_N, /// Expected GEMM N dimension (used for sizing CUDA grid) -+ int expected_K, /// Expected GEMM K dimension -+ int batch_count, /// Number of independent GEMM computations to execute -+ -+ int const *M, /// Array containing the GEMM M dimension for each batch index -+ int const *N, /// Array containing the GEMM N dimension for each batch index -+ int const *K, /// Array containing the GEMM K dimension for each batch index -+ -+ NumericTypeID element_compute, /// Data type of internal accumulation -+ -+ NumericTypeID element_scalar, /// Data type of alpha/beta scalars -+ -+ void const *alpha, /// Pointer to alpha scalar -+ -+ NumericTypeID element_A, /// Data type of A matrix elements -+ LayoutTypeID layout_A, /// Layout of A matrix -+ ComplexTransform transform_A, /// Complex transformation applied to A matrix -+ -+ void const * const * ptr_A_real, /// Pointer to array containing pointers to real part of A matrices -+ void const * const * ptr_A_imag, /// Pointer to array containing pointers to imaginary part of A 
matrices -+ -+ int64_t lda_real, /// Leading dimension of real part of A matrix -+ int64_t lda_imag, /// Leading dimension of imaginary part of A matrix -+ -+ NumericTypeID element_B, /// Data type of B matrix elements -+ LayoutTypeID layout_B, /// Layout of B matrix -+ ComplexTransform transform_B, /// Complex transformation applied to B matrix -+ -+ void const * const * ptr_B_real, /// Pointer to array containing pointers to real part of B matrices -+ void const * const * ptr_B_imag, /// Pointer to array containing pointers to imaginary part of B matrices -+ -+ int64_t ldb_real, /// Leading dimension of real part of B matrix -+ int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix -+ -+ void const * beta, /// Pointer to beta scalar -+ -+ NumericTypeID element_C, /// Data type of C and D matrix -+ -+ void const * const * ptr_C_real, /// Pointer to array containing pointers to real part of C matrices -+ void const * const * ptr_C_imag, /// Pointer to array containing poitners to imaginary part of C matrices -+ -+ int64_t ldc_real, /// Leading dimension of real part of C matrix -+ int64_t ldc_imag, /// Leading dimension of imaginary part of C matrix -+ -+ void * const * ptr_D_real, /// Pointer to array containing pointers to real part of D matrices -+ void * const * ptr_D_imag, /// Pointer to array containing poitners to imaginary part of D matrices -+ -+ int64_t ldd_real, /// Leading dimension of real part of D matrix -+ int64_t ldd_imag /// Leading dimension of imaginary part of D matrix -+ ); -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Unique pointer storing the handle -+using HandlePtr = std::unique_ptr; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Finds conv2d operation instances with Conv2d::ElementC = Reduction::ElementWorkspace -+Operation const* find_conv_operation_for_parallel_reduction(Operation const *operation); -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Finds gemm operation instances with ElementC = Reduction::ElementWorkspace -+Operation const* find_gemm_operation_for_parallel_reduction(Operation const *operation); -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/library/include/cutlass/library/library.h b/3rdparty/cutlass/tools/library/include/cutlass/library/library.h -new file mode 100644 -index 0000000..6bb3f79 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/include/cutlass/library/library.h -@@ -0,0 +1,1537 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ -+ \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. -+ -+ Generally, -+ -+ description - compile-time constant parameters used to instantiate an operation -+ -+ configuration - runtime parameters with computationally expensive initialization -+ -+ arguments - runtime parameters that may be passed to an initialized operation with low -+ computational overhead -+*/ -+ -+#pragma once -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+#include -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/blas3.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Layout type identifier -+enum class LayoutTypeID { -+ kUnknown, -+ kColumnMajor, -+ kRowMajor, -+ kColumnMajorInterleavedK2, -+ kRowMajorInterleavedK2, -+ kColumnMajorInterleavedK4, -+ kRowMajorInterleavedK4, -+ kColumnMajorInterleavedK16, -+ kRowMajorInterleavedK16, -+ kColumnMajorInterleavedK32, -+ kRowMajorInterleavedK32, -+ kColumnMajorInterleavedK64, -+ kRowMajorInterleavedK64, -+ kTensorNCHW, -+ kTensorNCDHW, -+ kTensorNHWC, -+ kTensorNDHWC, -+ kTensorNC32HW32, -+ kTensorC32RSK32, -+ kTensorNC64HW64, -+ kTensorC64RSK64, -+ kInvalid -+}; -+ -+/// Numeric data type -+enum class NumericTypeID { -+ kUnknown, -+ kVoid, -+ kB1, -+ kU2, -+ kU4, -+ kU8, -+ kU16, -+ kU32, -+ kU64, -+ kS2, -+ kS4, -+ kS8, -+ kS16, -+ kS32, -+ kS64, -+ kF16, -+ kBF16, -+ kTF32, -+ kF32, -+ kF64, -+ kCF16, -+ kCBF16, -+ kCF32, -+ kCTF32, -+ kCF64, -+ kCS2, -+ kCS4, -+ kCS8, -+ kCS16, -+ kCS32, -+ kCS64, -+ kCU2, -+ kCU4, -+ kCU8, -+ kCU16, -+ kCU32, -+ kCU64, -+ kInvalid -+}; -+ -+/// Enumerated 
type describing a transformation on a complex value. -+enum class ComplexTransform { -+ kNone, -+ kConjugate, -+ kInvalid -+}; -+ -+/// Providers -+enum class Provider { -+ kNone, -+ kCUTLASS, -+ kReferenceHost, -+ kReferenceDevice, -+ kCUBLAS, -+ kCUDNN, -+ kInvalid -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Enumeration indicating the kind of operation -+enum class OperationKind { -+ kGemm, -+ kRankK, -+ kRank2K, -+ kTrmm, -+ kSymm, -+ kConv2d, -+ kConv3d, -+ kEqGemm, -+ kSparseGemm, -+ kReduction, -+ kInvalid -+}; -+ -+/// Enumeration indicating whether scalars are in host or device memory -+enum class ScalarPointerMode { -+ kHost, -+ kDevice, -+ kInvalid -+}; -+ -+/// Describes how reductions are performed across threadblocks -+enum class SplitKMode { -+ kNone, -+ kSerial, -+ kParallel, -+ kParallelSerial, -+ kInvalid -+}; -+ -+/// Indicates the classificaition of the math instruction -+enum class OpcodeClassID { -+ kSimt, -+ kTensorOp, -+ kWmmaTensorOp, -+ kSparseTensorOp, -+ kInvalid -+}; -+ -+enum class MathOperationID { -+ kAdd, -+ kMultiplyAdd, -+ kMultiplyAddSaturate, -+ kMultiplyAddFastBF16, -+ kMultiplyAddFastF16, -+ kMultiplyAddFastF32, -+ kMultiplyAddComplex, -+ kMultiplyAddComplexFastF32, -+ kMultiplyAddGaussianComplex, -+ kXorPopc, -+ kInvalid -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Enumeration indicating what kind of GEMM operation to perform -+enum class GemmKind { -+ kGemm, -+ kSparse, -+ kUniversal, -+ kPlanarComplex, -+ kPlanarComplexArray, -+ kGrouped, -+ kInvalid -+}; -+ -+/// Mode of Universal GEMM -+using GemmUniversalMode = cutlass::gemm::GemmUniversalMode; -+ -+/// Enumeration indicating what kind of RankK update operation to perform -+enum class RankKKind { -+ kUniversal, -+ kInvalid -+}; -+ -+/// Enumeration indicating what kind of TRMM operation to perform -+enum class TrmmKind { -+ kUniversal, -+ kInvalid -+}; -+ -+/// Enumeration indicating what kind of SYMM/HEMM operation to perform -+enum class SymmKind { -+ kUniversal, -+ kInvalid -+}; -+ -+/// Enumeration indicating what kind of Conv2d operation to perform -+enum class ConvKind { -+ kUnknown, -+ kFprop, -+ kDgrad, -+ kWgrad, -+ kInvalid -+}; -+ -+enum class ConvModeID { -+ kCrossCorrelation, -+ kConvolution, -+ kInvalid -+}; -+ -+// Iterator algorithm enum in order of general performance-efficiency -+enum class IteratorAlgorithmID { -+ kNone, -+ kAnalytic, -+ kOptimized, -+ kFixedChannels, -+ kFewChannels, -+ kInvalid -+}; -+ -+ -+enum class EpilogueKind { -+ kUnknown, -+ kConversion, -+ kLinearCombination, -+ kLinearCombinationClamp, -+ kLinearCombinationPlanarComplex, -+ kLinearCombinationRelu, -+ kLinearCombinationSigmoid, -+ kInvalid -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct MathInstructionDescription { -+ -+ /// Shape of the target math instruction -+ cutlass::gemm::GemmCoord instruction_shape; -+ -+ /// Describes the data type of the internal accumulator -+ NumericTypeID element_accumulator; -+ -+ /// Classification of math instruction -+ OpcodeClassID opcode_class; -+ -+ /// Type of math operation performed -+ MathOperationID math_operation; -+ -+ // -+ // Methods -+ // -+ -+ MathInstructionDescription( -+ cutlass::gemm::GemmCoord instruction_shape = cutlass::gemm::GemmCoord(), -+ NumericTypeID element_accumulator = NumericTypeID::kInvalid, -+ OpcodeClassID opcode_class = 
OpcodeClassID::kInvalid, -+ MathOperationID math_operation = MathOperationID::kMultiplyAdd -+ ): -+ instruction_shape(instruction_shape), -+ element_accumulator(element_accumulator), -+ opcode_class(opcode_class), -+ math_operation(math_operation) {} -+ -+ // Equality operator -+ inline -+ bool operator==(MathInstructionDescription const& rhs) const{ -+ return ( -+ (instruction_shape == rhs.instruction_shape) && -+ (element_accumulator == rhs.element_accumulator) && -+ (opcode_class == rhs.opcode_class) && -+ (math_operation == rhs.math_operation)); -+ } -+ -+ // Inequality operator -+ inline -+ bool operator!=(MathInstructionDescription const& rhs) const { -+ return !(*this == rhs); -+ } -+ -+}; -+ -+/// Structure describing the tiled structure of a GEMM-like computation -+struct TileDescription { -+ -+ /// Describes the shape of a threadblock (in elements) -+ cutlass::gemm::GemmCoord threadblock_shape; -+ -+ /// Describes the number of pipeline stages in the threadblock-scoped mainloop -+ int threadblock_stages; -+ -+ /// Number of warps in each logical dimension -+ cutlass::gemm::GemmCoord warp_count; -+ -+ /// Core math instruction -+ MathInstructionDescription math_instruction; -+ -+ /// Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation. -+ int minimum_compute_capability; -+ -+ /// Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation. -+ int maximum_compute_capability; -+ -+ /// Describes the shape of a cluster (in blocks) -+ cutlass::gemm::GemmCoord cluster_shape; -+ -+ // -+ // Methods -+ // -+ -+ TileDescription( -+ cutlass::gemm::GemmCoord threadblock_shape = cutlass::gemm::GemmCoord(), -+ int threadblock_stages = 0, -+ cutlass::gemm::GemmCoord warp_count = cutlass::gemm::GemmCoord(), -+ MathInstructionDescription math_instruction = MathInstructionDescription(), -+ int minimum_compute_capability = 0, -+ int maximum_compute_capability = 0, -+ cutlass::gemm::GemmCoord cluster_shape = cutlass::gemm::GemmCoord(1,1,1) -+ ): -+ threadblock_shape(threadblock_shape), -+ threadblock_stages(threadblock_stages), -+ warp_count(warp_count), -+ math_instruction(math_instruction), -+ minimum_compute_capability(minimum_compute_capability), -+ maximum_compute_capability(maximum_compute_capability), -+ cluster_shape(cluster_shape) { } -+ -+ // Equality operator -+ inline -+ bool operator==(TileDescription const& rhs) const{ -+ return ( -+ (threadblock_shape == rhs.threadblock_shape) && -+ (threadblock_stages == rhs.threadblock_stages) && -+ (warp_count == rhs.warp_count) && -+ (math_instruction == rhs.math_instruction) && -+ (minimum_compute_capability == rhs.minimum_compute_capability) && -+ (maximum_compute_capability == rhs.maximum_compute_capability)); -+ } -+ -+ // Inequality operator -+ inline -+ bool operator!=(TileDescription const& rhs) const { -+ return !(*this == rhs); -+ } -+}; -+ -+/// High-level description of an operation -+struct OperationDescription { -+ -+ /// Unique identifier describing the operation -+ char const * name; -+ -+ /// Operation provider -+ Provider provider; -+ -+ /// Kind of operation -+ OperationKind kind; -+ -+ /// Describes the tiled structure of a GEMM-like computation -+ TileDescription tile_description; -+ -+ // -+ // Methods -+ // -+ OperationDescription( -+ char const * name = "unknown", -+ Provider Provider = Provider::kInvalid, -+ OperationKind kind = OperationKind::kInvalid, -+ TileDescription const & tile_description = TileDescription() -+ ): -+ name(name), kind(kind), 
tile_description(tile_description) { } -+}; -+ -+/// Structure describing the properties of a tensor -+struct TensorDescription { -+ -+ /// Numeric type of an individual element -+ NumericTypeID element; -+ -+ /// Enumerant identifying the layout function for the tensor -+ LayoutTypeID layout; -+ -+ /// Alignment restriction on pointers, strides, and extents -+ int alignment; -+ -+ /// log2() of the maximum extent of each dimension -+ int log_extent_range; -+ -+ /// log2() of the maximum value each relevant stride may have -+ int log_stride_range; -+ -+ // -+ // Methods -+ // -+ -+ TensorDescription( -+ NumericTypeID element = NumericTypeID::kInvalid, -+ LayoutTypeID layout = LayoutTypeID::kInvalid, -+ int alignment = 1, -+ int log_extent_range = 24, -+ int log_stride_range = 24 -+ ): -+ element(element), -+ layout(layout), -+ alignment(alignment), -+ log_extent_range(log_extent_range), -+ log_stride_range(log_stride_range) { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Description of all GEMM computations -+struct GemmDescription : public OperationDescription { -+ -+ /// Indicates the kind of GEMM performed -+ GemmKind gemm_kind; -+ -+ /// Describes the A operand -+ TensorDescription A; -+ -+ /// Describes the B operand -+ TensorDescription B; -+ -+ /// Describes the source and destination matrices -+ TensorDescription C; -+ -+ /// Describes the sparse meta matrices -+ TensorDescription E; -+ -+ /// Describes the data type of the scalars passed to the epilogue -+ NumericTypeID element_epilogue; -+ -+ /// Describes the structure of parallel reductions -+ SplitKMode split_k_mode; -+ -+ /// Transformation on A operand -+ ComplexTransform transform_A; -+ -+ /// Transformation on B operand -+ ComplexTransform transform_B; -+ -+ // -+ // Methods -+ // -+ -+ GemmDescription( -+ GemmKind gemm_kind = GemmKind::kGemm, -+ TensorDescription const &A = TensorDescription(), -+ TensorDescription const &B = TensorDescription(), -+ TensorDescription const &C = TensorDescription(), -+ NumericTypeID element_epilogue = NumericTypeID::kInvalid, -+ SplitKMode split_k_mode = SplitKMode::kNone, -+ ComplexTransform transform_A = ComplexTransform::kNone, -+ ComplexTransform transform_B = ComplexTransform::kNone -+ ): -+ gemm_kind(gemm_kind), -+ A(A), -+ B(B), -+ C(C), -+ element_epilogue(element_epilogue), -+ split_k_mode(split_k_mode), -+ transform_A(transform_A), -+ transform_B(transform_B) {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Desciprion for structured sparse GEMMs. 
-+struct SparseGemmDescription : public GemmDescription { -+ -+ /// Description structure for structured sparse GEMM -+ SparseGemmDescription( -+ GemmKind gemm_kind = GemmKind::kGemm, -+ TensorDescription const &A = TensorDescription(), -+ TensorDescription const &B = TensorDescription(), -+ TensorDescription const &C = TensorDescription(), -+ TensorDescription const &E = TensorDescription(), -+ NumericTypeID element_epilogue = NumericTypeID::kInvalid, -+ SplitKMode split_k_mode = SplitKMode::kNone, -+ ComplexTransform transform_A = ComplexTransform::kNone, -+ ComplexTransform transform_B = ComplexTransform::kNone -+ ): -+ GemmDescription(gemm_kind, A, B, C, element_epilogue, split_k_mode, transform_A, transform_B) -+ {this->E = E;} -+}; -+ -+/// Description of all Reduction operations -+struct ReductionDescription : public OperationDescription { -+ -+ /// Describes the data type of workspace -+ NumericTypeID element_workspace; -+ -+ /// Describes the data type of final output -+ NumericTypeID element_output; -+ -+ /// Describes the data type of the scalars passed to the epilogue -+ NumericTypeID element_epilogue; -+}; -+ -+/// Description of all Rank K update computations (SYRK, HERK, SYR2K, HER2K) -+struct RankKDescription : public OperationDescription { -+ -+ /// Indicates which device template is used (universal or regular) -+ RankKKind rank_k_kind; -+ -+ /// Number of rank update (rank k or rank 2k) -+ int num_ranks; -+ -+ /// Describes the A operand -+ TensorDescription A; -+ -+ /// Describes the B operand (used only for SYR2K and HER2K) -+ TensorDescription B; -+ -+ /// Describes the source and destination matrices -+ TensorDescription C; -+ -+ /// Describes the fill mode for matrix C -+ FillMode fill_mode; -+ -+ /// Describes the blas mode (symmetric/hermitian) -+ BlasMode blas_mode; -+ -+ /// Describes the data type of the scalars passed to the epilogue -+ NumericTypeID element_epilogue; -+ -+ /// Describes the structure of parallel reductions -+ SplitKMode split_k_mode; -+ -+ /// Transformation on A operand -+ ComplexTransform transform_A; -+ -+ /// Transformation on B operand -+ ComplexTransform transform_B; -+ -+ // -+ // Methods -+ // -+ -+ RankKDescription( -+ RankKKind rank_k_kind = RankKKind::kUniversal, -+ int num_ranks = 1, -+ TensorDescription const &A = TensorDescription(), -+ TensorDescription const &B = TensorDescription(), -+ TensorDescription const &C = TensorDescription(), -+ FillMode fill_mode = FillMode::kInvalid, -+ BlasMode blas_mode = BlasMode::kInvalid, -+ NumericTypeID element_epilogue = NumericTypeID::kInvalid, -+ SplitKMode split_k_mode = SplitKMode::kNone, -+ ComplexTransform transform_A = ComplexTransform::kNone, -+ ComplexTransform transform_B = ComplexTransform::kNone -+ ): -+ rank_k_kind(rank_k_kind), -+ num_ranks(num_ranks), -+ A(A), -+ B(B), -+ C(C), -+ fill_mode(fill_mode), -+ blas_mode(blas_mode), -+ element_epilogue(element_epilogue), -+ split_k_mode(split_k_mode), -+ transform_A(transform_A), -+ transform_B(transform_B) {} -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Description of all TRMM computations -+struct TrmmDescription : public OperationDescription { -+ -+ /// Indicates the kind of TRMM performed -+ TrmmKind trmm_kind; -+ -+ /// Describes the A operand -+ TensorDescription A; -+ -+ /// Describes the side mode for matrix A -+ SideMode side_mode; -+ -+ /// Describes the fill mode for matrix A -+ FillMode fill_mode; -+ -+ /// Describes the diag type for matrix A -+ 
DiagType diag_type; -+ -+ /// Describes the B operand -+ TensorDescription B; -+ -+ /// Describes the source and destination matrices -+ TensorDescription D; -+ -+ /// Describes the data type of the scalars passed to the epilogue -+ NumericTypeID element_epilogue; -+ -+ /// Describes the structure of parallel reductions -+ SplitKMode split_k_mode; -+ -+ /// Transformation on A operand -+ ComplexTransform transform_A; -+ -+ // -+ // Methods -+ // -+ -+ TrmmDescription( -+ TrmmKind trmm_kind = TrmmKind::kUniversal, -+ TensorDescription const &A = TensorDescription(), -+ SideMode side_mode = SideMode::kInvalid, -+ FillMode fill_mode = FillMode::kInvalid, -+ DiagType diag_type = DiagType::kInvalid, -+ TensorDescription const &B = TensorDescription(), -+ TensorDescription const &D = TensorDescription(), -+ NumericTypeID element_epilogue = NumericTypeID::kInvalid, -+ SplitKMode split_k_mode = SplitKMode::kNone, -+ ComplexTransform transform_A = ComplexTransform::kNone -+ ): -+ trmm_kind(trmm_kind), -+ A(A), -+ side_mode(side_mode), -+ fill_mode(fill_mode), -+ diag_type(diag_type), -+ B(B), -+ D(D), -+ element_epilogue(element_epilogue), -+ split_k_mode(split_k_mode), -+ transform_A(transform_A) {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Description of all SYMM/HEMM update computations -+struct SymmDescription : public OperationDescription { -+ -+ /// Indicates which device template is used (universal or regular) -+ SymmKind symm_kind; -+ -+ /// Describes the A operand -+ TensorDescription A; -+ -+ /// Describes the B operand -+ TensorDescription B; -+ -+ /// Describes the source and destination matrices -+ TensorDescription C; -+ -+ /// Describes the side mode for matrix A -+ SideMode side_mode; -+ -+ /// Describes the fill mode for matrix A -+ FillMode fill_mode; -+ -+ /// Describes the blas mode (symmetric/hermitian) -+ BlasMode blas_mode; -+ -+ /// Describes the data type of the scalars passed to the epilogue -+ NumericTypeID element_epilogue; -+ -+ /// Describes the structure of parallel reductions -+ SplitKMode split_k_mode; -+ -+ /// Transformation on A operand -+ ComplexTransform transform_A; -+ -+ /// Transformation on B operand -+ ComplexTransform transform_B; -+ -+ // -+ // Methods -+ // -+ -+ SymmDescription( -+ SymmKind symm_kind = SymmKind::kUniversal, -+ TensorDescription const &A = TensorDescription(), -+ TensorDescription const &B = TensorDescription(), -+ TensorDescription const &C = TensorDescription(), -+ SideMode side_mode = SideMode::kInvalid, -+ FillMode fill_mode = FillMode::kInvalid, -+ BlasMode blas_mode = BlasMode::kInvalid, -+ NumericTypeID element_epilogue = NumericTypeID::kInvalid, -+ SplitKMode split_k_mode = SplitKMode::kNone, -+ ComplexTransform transform_A = ComplexTransform::kNone, -+ ComplexTransform transform_B = ComplexTransform::kNone -+ ): -+ symm_kind(symm_kind), -+ A(A), -+ B(B), -+ C(C), -+ side_mode(side_mode), -+ fill_mode(fill_mode), -+ blas_mode(blas_mode), -+ element_epilogue(element_epilogue), -+ split_k_mode(split_k_mode), -+ transform_A(transform_A), -+ transform_B(transform_B) {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Description of all Conv2d operations -+struct ConvDescription : public OperationDescription { -+ /// Describes the convolution dimension support (2D or 3D) -+ int conv_dim; -+ 
-+ /// Describes the kind of convolution -+ ConvKind conv_kind; -+ -+ /// Describes the type of iterator algorithm (analytic or precomputed) -+ IteratorAlgorithmID iterator_algorithm; -+ -+ /// Describes the A operand -+ TensorDescription A; -+ -+ /// Describes the B operand -+ TensorDescription B; -+ -+ /// Describes the C operand -+ TensorDescription C; -+ -+ /// Describes the data type of the scalars passed to the epilogue -+ NumericTypeID element_epilogue; -+ -+ // -+ // Methods -+ // -+ // Returns Activation TensorDescription -+ TensorDescription activation() const { -+ switch(conv_kind) { -+ case library::ConvKind::kFprop : return A; -+ case library::ConvKind::kDgrad : return C; -+ case library::ConvKind::kWgrad : return B; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns Filter TensorDescription -+ TensorDescription filter() const { -+ switch(conv_kind) { -+ case library::ConvKind::kFprop : return B; -+ case library::ConvKind::kDgrad : return B; -+ case library::ConvKind::kWgrad : return C; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns Output TensorDescription -+ TensorDescription output() const { -+ switch(conv_kind) { -+ case library::ConvKind::kFprop : return C; -+ case library::ConvKind::kDgrad : return A; -+ case library::ConvKind::kWgrad : return A; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Base class for all operations -+class Operation { -+public: -+ -+ virtual ~Operation() { } -+ -+ virtual OperationDescription const & description() const = 0; -+ -+ virtual Status can_implement( -+ void const *configuration, -+ void const *arguments) const = 0; -+ -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const = 0; -+ -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration, -+ void const *arguments = nullptr) const = 0; -+ -+ virtual Status initialize( -+ void const *configuration, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const = 0; -+ -+ virtual Status run( -+ void const *arguments, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const = 0; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Configuration for basic GEMM operations -+// -+// OperationKind: Gemm -+// GemmKind: Gemm -+// -+struct GemmConfiguration { -+ -+ /// GEMM problem size -+ gemm::GemmCoord problem_size; -+ -+ /// Leading dimension of A matrix -+ int64_t lda; -+ -+ /// Leading dimension of B matrix -+ int64_t ldb; -+ -+ /// Leading dimension of C matrix -+ int64_t ldc; -+ -+ /// Leading dimension of D matrix -+ int64_t ldd; -+ -+ /// Number of partitions of K dimension -+ int split_k_slices; -+}; -+ -+/// Arguments for GEMM -+struct GemmArguments { -+ -+ /// Pointer to A matrix -+ void const *A; -+ -+ /// Pointer to B matrix -+ void const *B; -+ -+ /// Pointer to C matrix -+ void const *C; -+ -+ /// Pointer to D matrix -+ void *D; -+ -+ /// Host or device pointer to alpha scalar -+ void const *alpha; -+ -+ /// Host or device pointer to beta scalar -+ void const *beta; -+ -+ /// Enumerant indicating whether alpha/beta point to host or device memory -+ ScalarPointerMode pointer_mode; -+}; -+ 
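The abstract Operation interface above, together with GemmConfiguration and GemmArguments, fixes the host-side call sequence for every kernel in the library: check can_implement, query the host and device workspace sizes, allocate scratch space, initialize, then run. The C++ sketch below walks that sequence under stated assumptions: the library::Operation pointer is obtained elsewhere (for example from the manifest declared later in this patch), alpha and beta live in host memory, A/B/C/D are device buffers whose element types and layouts match the selected operation, and the helper name run_gemm with its parameter list is purely illustrative.

#include <cstdint>
#include <vector>
#include <cuda_runtime.h>
#include "cutlass/library/library.h"

// Illustrative driver for a basic GEMM operation (OperationKind: Gemm, GemmKind: Gemm).
cutlass::Status run_gemm(cutlass::library::Operation const *op,
                         int M, int N, int K,
                         void const *A, int64_t lda,
                         void const *B, int64_t ldb,
                         void const *C, int64_t ldc,
                         void *D, int64_t ldd,
                         float alpha, float beta,
                         cudaStream_t stream) {
  using namespace cutlass::library;

  // Configuration: runtime parameters with computationally expensive initialization.
  GemmConfiguration config;
  config.problem_size  = cutlass::gemm::GemmCoord(M, N, K);
  config.lda = lda;
  config.ldb = ldb;
  config.ldc = ldc;
  config.ldd = ldd;
  config.split_k_slices = 1;

  // Arguments: cheap per-call parameters; scalars are host pointers here.
  GemmArguments args;
  args.A = A;
  args.B = B;
  args.C = C;
  args.D = D;
  args.alpha = &alpha;
  args.beta = &beta;
  args.pointer_mode = ScalarPointerMode::kHost;

  // The operation reports how much scratch space it needs on host and device.
  uint64_t host_bytes = op->get_host_workspace_size(&config);
  uint64_t device_bytes = op->get_device_workspace_size(&config, &args);

  std::vector<uint8_t> host_workspace(host_bytes);
  void *device_workspace = nullptr;
  if (device_bytes && cudaMalloc(&device_workspace, device_bytes) != cudaSuccess) {
    return cutlass::Status::kErrorInternal;
  }

  // can_implement -> initialize -> run is the intended call order.
  cutlass::Status status = op->can_implement(&config, &args);
  if (status == cutlass::Status::kSuccess) {
    status = op->initialize(&config, host_workspace.data(), device_workspace, stream);
  }
  if (status == cutlass::Status::kSuccess) {
    status = op->run(&args, host_workspace.data(), device_workspace, stream);
  }

  cudaFree(device_workspace);
  return status;
}

As the file-level comment in this header notes, initialize() absorbs the expensive configuration work, so an initialized operation may be re-run repeatedly with different GemmArguments at low overhead.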
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Configuration for batched GEMM in which multiple matrix products are computed -+// -+// OperationKind: Gemm -+// GemmKind: Batched -+ -+struct GemmBatchedConfiguration { -+ -+ /// GEMM problem size -+ gemm::GemmCoord problem_size; -+ -+ /// Leading dimension of A matrix -+ int64_t lda; -+ -+ /// Leading dimension of B matrix -+ int64_t ldb; -+ -+ /// Leading dimension of C matrix -+ int64_t ldc; -+ -+ /// Leading dimension of D matrix -+ int64_t ldd; -+ -+ /// Stride between instances of the A matrix in memory -+ int64_t batch_stride_A; -+ -+ /// Stride between instances of the B matrix in memory -+ int64_t batch_stride_B; -+ -+ /// Stride between instances of the C matrix in memory -+ int64_t batch_stride_C; -+ -+ /// Stride between instances of the D matrix in memory -+ int64_t batch_stride_D; -+ -+ /// Number of GEMMs in batch -+ int batch_count; -+}; -+ -+/// Arguments to batched GEMM -+using GemmBatchedArguments = GemmArguments; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Configuration for batched GEMM in which multiple matrix products are computed -+// -+// OperationKind: Gemm -+// GemmKind: Array -+ -+struct GemmArrayConfiguration { -+ -+ gemm::GemmCoord problem_size; -+ -+ /// Leading dimension of A matrix -+ int64_t lda; -+ -+ /// Leading dimension of B matrix -+ int64_t ldb; -+ -+ /// Leading dimension of C matrix -+ int64_t ldc; -+ -+ /// Leading dimension of D matrix -+ int64_t ldd; -+ -+ int batch_count; -+}; -+ -+/// Arguments for GEMM - used by all the GEMM operations -+struct GemmArrayArguments { -+ void const * const *A; -+ void const * const *B; -+ void const * const *C; -+ void * const *D; -+ void const *alpha; -+ void const *beta; -+ ScalarPointerMode pointer_mode; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Universal GEMM supporting multiple split-K modes, multiple batched modes, real and complex -+// -+// OperationKind: Gemm -+// GemmKind: Universal -+ -+struct GemmUniversalConfiguration { -+ -+ GemmUniversalMode mode; -+ gemm::GemmCoord problem_size; -+ int batch_count; -+ -+ int64_t lda; -+ int64_t ldb; -+ int64_t ldc; -+ int64_t ldd; -+}; -+ -+struct GemmUniversalArguments { -+ // NOTE: these are replicated for 3.0 interfaces -+ gemm::GemmCoord problem_size; -+ int batch_count; -+ -+ void const *A; -+ void const *B; -+ void const *C; -+ void *D; -+ -+ void const *alpha; -+ void const *beta; -+ ScalarPointerMode pointer_mode; -+ -+ // NOTE: these are replicated for 3.0 interfaces -+ int64_t lda; -+ int64_t ldb; -+ int64_t ldc; -+ int64_t ldd; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Complex valued GEMM in which real and imaginary parts are separated by a stride -+// -+// OperationKind: Gemm -+// GemmKind: Planar complex -+ -+struct GemmPlanarComplexConfiguration { -+ -+ GemmUniversalMode mode; -+ gemm::GemmCoord problem_size; -+ int batch_count; -+ -+ int64_t lda_real; -+ int64_t lda_imag; -+ -+ int64_t ldb_real; -+ int64_t ldb_imag; -+ -+ int64_t ldc_real; -+ int64_t ldc_imag; -+ -+ int64_t ldd_real; -+ int64_t ldd_imag; -+}; -+ -+/// Arguments for planar complex GEMMs -+struct GemmPlanarComplexArguments { -+ -+ void const *A_real; -+ 
void const *A_imag; -+ -+ void const *B_real; -+ void const *B_imag; -+ -+ void const *C_real; -+ void const *C_imag; -+ -+ void *D_real; -+ void *D_imag; -+ -+ void const *alpha; -+ void const *beta; -+ ScalarPointerMode pointer_mode; -+ -+ int64_t batch_stride_A_real; -+ int64_t batch_stride_A_imag; -+ -+ int64_t batch_stride_B_real; -+ int64_t batch_stride_B_imag; -+ -+ int64_t batch_stride_C_real; -+ int64_t batch_stride_C_imag; -+ -+ int64_t batch_stride_D_real; -+ int64_t batch_stride_D_imag; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This is a special form of planar complex which loads pointers and problem size -+/// from memory. -+struct GemmPlanarComplexArrayConfiguration { -+ -+ gemm::GemmCoord problem_size; -+ int batch_count; -+ -+ int64_t lda_real; -+ int64_t lda_imag; -+ -+ int64_t ldb_real; -+ int64_t ldb_imag; -+ -+ int64_t ldc_real; -+ int64_t ldc_imag; -+ -+ int64_t ldd_real; -+ int64_t ldd_imag; -+}; -+ -+/// Arguments for planar complex GEMMs -+struct GemmPlanarComplexArrayArguments { -+ -+ int const *M; -+ int const *N; -+ int const *K; -+ -+ void const * const * A_real; -+ void const * const * A_imag; -+ void const * const * B_real; -+ void const * const * B_imag; -+ void const * const * C_real; -+ void const * const * C_imag; -+ void * const * D_real; -+ void * const * D_imag; -+ -+ void const * alpha; -+ void const * beta; -+ ScalarPointerMode pointer_mode; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Grouped GEMM supporting -+// -+// OperationKind: Gemm -+// GemmKind: Grouped -+ -+struct GemmGroupedConfiguration { -+ -+ int problem_count; -+ int threadblock_count; -+ -+}; -+ -+struct GemmGroupedArguments { -+ -+ gemm::GemmCoord *problem_sizes; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_C; -+ void * ptr_D; -+ -+ int64_t *lda; -+ int64_t *ldb; -+ int64_t *ldc; -+ int64_t *ldd; -+ -+ void const *alpha; -+ void const *beta; -+ ScalarPointerMode pointer_mode; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// OperationKind: kSparseGemm -+// -+ -+/// Computes GEMM assumine one of the inputs has 2:4 structured sparsity. -+struct SparseGemmConfiguration { -+ -+ GemmUniversalMode mode; -+ gemm::GemmCoord problem_size; -+ int batch_count; /// number of sparse matrix products in batch -+ -+ int64_t lda; /// leading dimension of A operand -+ int64_t ldb; /// leading dimension of B operand -+ int64_t ldc; /// leading dimension of C operand -+ int64_t ldd; /// leading dimension of D operand -+ int64_t lde; /// leading dimension of E operand (metadata matrix) -+ -+ int64_t batch_stride_A; // stride between matrices -+ int64_t batch_stride_B; // stride between matrices -+ int64_t batch_stride_C; // stride between matrices -+ int64_t batch_stride_D; // stride between matrices -+ int64_t batch_stride_E; // stride between matrices -+}; -+ -+/// Arguments for sparse GEMMs -+struct SparseGemmArguments { -+ -+ void const *A; /// pointer to A matrix -+ void const *B; /// pointer to B matrix -+ void const *C; /// pointer to C matrix -+ void *D; /// pointer to D matrix -+ void const *E; /// pointer to E matric (metadata) -+ -+ void const *alpha; /// pointer to alpha scalar -+ void const *beta; /// pointer to beta scalar -+ ScalarPointerMode pointer_mode; /// enumerant indicating whether alpha/beta pointers are host -+ /// or device pointers. 
-+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Configuration for basic Rank K update operations -+// -+// OperationKind: (Syrk, Herk, Syr2k, Her2k) -+// RankKKind: Universal -+// -+struct RankKConfiguration { -+ -+ /// SYRK problem size -+ gemm::GemmCoord problem_size; -+ -+ /// Leading dimension of A matrix -+ int64_t lda; -+ -+ /// Leading dimension of B matrix -+ int64_t ldb; -+ -+ /// Leading dimension of C matrix -+ int64_t ldc; -+ -+ /// Leading dimension of D matrix -+ int64_t ldd; -+ -+ /// Batch Count -+ int batch_count; -+}; -+ -+/// Arguments for (Syrk, Herk, Syr2k, Her2k) -+struct RankKArguments { -+ -+ /// Pointer to A matrix -+ void const *A; -+ -+ /// Pointer to B matrix (used only for Syr2k and Her2k) -+ void const *B; -+ -+ /// Pointer to C matrix -+ void const *C; -+ -+ /// Pointer to D matrix -+ void *D; -+ -+ /// Host or device pointer to alpha scalar -+ void const *alpha; -+ -+ /// Host or device pointer to beta scalar -+ void const *beta; -+ -+ /// Enumerant indicating whether alpha/beta point to host or device memory -+ ScalarPointerMode pointer_mode; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Configuration for basic TRMM operations -+// -+// OperationKind: Trmm -+// TrmmKind: Universal -+// -+struct TrmmConfiguration { -+ -+ /// TRMM problem size -+ gemm::GemmCoord problem_size; -+ -+ /// Leading dimension of A matrix -+ int64_t lda; -+ -+ /// Leading dimension of B matrix -+ int64_t ldb; -+ -+ /// Leading dimension of D matrix -+ int64_t ldd; -+ -+ /// Batch Count -+ int batch_count; -+}; -+ -+/// Arguments for TRMM -+struct TrmmArguments { -+ -+ /// Pointer to A matrix -+ void const *A; -+ -+ /// Pointer to B matrix -+ void const *B; -+ -+ /// Pointer to D matrix -+ void *D; -+ -+ /// Host or device pointer to alpha scalar -+ void const *alpha; -+ -+ /// Host or device pointer to beta scalar -+ void const *beta; -+ -+ /// Enumerant indicating whether alpha/beta point to host or device memory -+ ScalarPointerMode pointer_mode; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_D; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Configuration for basic SYMM/HEMM update operations -+// -+// OperationKind: (Symm, Hemm) -+// SymmKind: Universal -+// -+struct SymmConfiguration { -+ -+ /// SYMM/HEMM problem size -+ gemm::GemmCoord problem_size; -+ -+ /// Leading dimension of A matrix -+ int64_t lda; -+ -+ /// Leading dimension of B matrix -+ int64_t ldb; -+ -+ /// Leading dimension of C matrix -+ int64_t ldc; -+ -+ /// Leading dimension of D matrix -+ int64_t ldd; -+ -+ /// Batch Count -+ int batch_count; -+}; -+ -+/// Arguments for (Symm, Hemm) -+struct SymmArguments { -+ -+ /// Pointer to A matrix -+ void const *A; -+ -+ /// Pointer to B matrix -+ void const *B; -+ -+ /// Pointer to C matrix -+ void const *C; -+ -+ /// Pointer to D matrix -+ void *D; -+ -+ /// Host or device pointer to alpha scalar -+ void const *alpha; -+ -+ /// Host or device pointer to beta scalar -+ void const *beta; -+ -+ /// Enumerant indicating whether alpha/beta point to host or device memory -+ ScalarPointerMode pointer_mode; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+}; -+ 
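For callers that prefer a cuBLAS-style entry point, the Handle declared in handle.h forwards to the same operation set; its gemm_universal method shown earlier covers batched-strided, batched-array and split-K launches selected through GemmUniversalMode. The sketch below drives a batched-strided half-precision GEMM with single-precision accumulation, assuming column-major operands, host-resident scalars, caller-supplied device pointers and batch strides, and an illustrative wrapper name batched_hgemm.

#include <cstdint>
#include <cuda_runtime.h>
#include "cutlass/library/handle.h"

// Batched F16 GEMM with F32 accumulation through the BLAS-like Handle.
cutlass::Status batched_hgemm(cutlass::library::Handle &handle,
                              int M, int N, int K, int batch_count,
                              void const *A, int64_t lda, int64_t stride_A,
                              void const *B, int64_t ldb, int64_t stride_B,
                              void const *C, int64_t ldc, int64_t stride_C,
                              void *D, int64_t ldd, int64_t stride_D,
                              cudaStream_t stream) {
  using namespace cutlass::library;

  float alpha = 1.0f;   // host-resident scalars, matching kHost pointer mode below
  float beta  = 0.0f;

  handle.set_stream(stream);
  handle.set_scalar_pointer_mode(ScalarPointerMode::kHost);

  return handle.gemm_universal(
      GemmUniversalMode::kBatched,
      M, N, K,
      NumericTypeID::kF32,            // element_compute: accumulate in F32
      NumericTypeID::kF32,            // element_scalar
      &alpha,
      NumericTypeID::kF16, LayoutTypeID::kColumnMajor, ComplexTransform::kNone,
      A, lda,
      NumericTypeID::kF16, LayoutTypeID::kColumnMajor, ComplexTransform::kNone,
      B, ldb,
      &beta,
      NumericTypeID::kF16,            // element_C: C and D matrices
      C, ldc,
      D, ldd,
      batch_count,
      stride_A, stride_B, stride_C, stride_D);
}

Because the scalar pointer mode is set to kHost, alpha and beta are passed as host pointers; selecting ScalarPointerMode::kDevice would instead require device-resident scalars, as the pointer_mode comments in the argument structures above indicate.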
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Two dimensional convolution -+// -+// OperationKind: Conv2d -+// -+struct Conv2dConfiguration { -+ -+ conv::SplitKMode split_k_mode; -+ -+ /// Conv2d problem size -+ // contains strictly conv2d size (N,H,W,C,K,R,S,P,Q,padding,stride,dilation,mode) -+ // also includes (split_k_slices, groups) -+ conv::Conv2dProblemSize problem_size; -+ -+ // stride of operand A -+ std::vector stride_a; -+ -+ // stride of operand B -+ std::vector stride_b; -+ -+ // stride of operand C -+ std::vector stride_c; -+}; -+ -+ -+/// Three dimensional convolution -+// -+// OperationKind: Conv3d -+// -+struct Conv3dConfiguration { -+ -+ conv::SplitKMode split_k_mode; -+ -+ /// Conv2d problem size -+ // contains strictly conv2d size (N,D,H,W,C,K,T,R,S,Z,P,Q,padding,stride,dilation,mode) -+ // also includes (split_k_slices, groups) -+ conv::Conv3dProblemSize problem_size; -+ -+ /// Layout object for activations tensor -+ layout::TensorNDHWC layout_activations; -+ -+ /// Layout object for filters tensor -+ layout::TensorNDHWC layout_filters; -+ -+ /// Layout object for source tensor -+ layout::TensorNDHWC layout_source; -+ -+ /// Layout object for output tensor -+ layout::TensorNDHWC layout_output; -+ -+ // -+ // Methods -+ // -+ -+ // Mapping functions (A,B,C -> activation,filter,output) -+ layout::TensorNDHWC layout_a(library::ConvKind const &conv_kind) const { -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return layout_activations; -+ case library::ConvKind::kDgrad: return layout_output; -+ case library::ConvKind::kWgrad: return layout_output; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ layout::TensorNDHWC layout_b(library::ConvKind const &conv_kind) const { -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return layout_filters; -+ case library::ConvKind::kDgrad: return layout_filters; -+ case library::ConvKind::kWgrad: return layout_activations; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ layout::TensorNDHWC layout_c(library::ConvKind const &conv_kind) const { -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return layout_output; -+ case library::ConvKind::kDgrad: return layout_activations; -+ case library::ConvKind::kWgrad: return layout_filters; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+}; -+ -+/// Arguments for CONV -+struct ConvArguments { -+ -+ ///////////////////////////////////////////////////////// -+ /// ImplicitGemm matrices A, B, C, D -+ ///////////////////////////////////////////////////////// -+ /// pointer to implicit gemm matrix A -+ void const *A; -+ -+ /// pointer to implicit gemm matrix B -+ void const *B; -+ -+ /// pointer to reordered matrix B -+ void const *reordered_B; -+ -+ /// pointer to implicit gemm matrix C -+ void const *C; -+ -+ /// pointer to implicit gemm desitination matrix D -+ void *D; -+ -+ /// Host or device pointer to alpha scalar -+ void const *alpha; -+ -+ /// Host or device pointer to beta scalar -+ void const *beta; -+ -+ /// Enumerant indicating whether alpha/beta point to host or device memory -+ ScalarPointerMode pointer_mode; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Configuration for 
Reduction operations -+// -+// OperationKind: Reduction -+// -+struct ReductionConfiguration { -+ -+ /// Redcution problem size -+ MatrixCoord problem_size; -+ -+ /// Number of partitions to reduce -+ int partitions; -+ -+ /// Number of lements between each partition -+ int64_t partition_stride; -+ -+ /// leading dimension of 'w'orksace operand -+ int64_t ldw; -+ -+ /// leading dimension of 's'ource operand -+ int64_t lds; -+ -+ /// leading dimension of 'd'estination operand -+ int64_t ldd; -+}; -+ -+/// Arguments for Reduction -+struct ReductionArguments { -+ -+ /// Pointer to workspace matrix -+ void const *workspace; -+ -+ /// Pointer to source matrix -+ void const *source; -+ -+ /// Pointer to destination matrix -+ void *destination; -+ -+ /// pointer to reference matrix -+ void *reference; -+ -+ /// Host or device pointer to alpha scalar -+ void const *alpha; -+ -+ /// Host or device pointer to beta scalar -+ void const *beta; -+ -+ /// Enumerant indicating whether alpha/beta point to host or device memory -+ ScalarPointerMode pointer_mode; -+}; -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/include/cutlass/library/manifest.h b/3rdparty/cutlass/tools/library/include/cutlass/library/manifest.h -new file mode 100644 -index 0000000..abce958 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/include/cutlass/library/manifest.h -@@ -0,0 +1,110 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ \brief Manifest of CUTLASS Library -+ -+ This is the root of the data structure containing CUTLASS objects -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "library.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// Forward declaration -+class Manifest; -+ -+// init and insert all cutlass gemm operations in manifest object (procedurally generated using generator.py) -+void initialize_all(Manifest &manifest); -+ -+// init and insert all reduction op in manifest object (manually instantiated in library/reduction) -+void initialize_all_reduction_op(Manifest &manifest); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// List of operations -+using OperationVector = std::vector>; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Manifest of CUTLASS Library -+class Manifest { -+private: -+ -+ /// Operation provider -+ Provider provider_; -+ -+ /// Global list of operations -+ OperationVector operations_; -+ -+public: -+ Manifest (Provider provider = library::Provider::kCUTLASS) : provider_(provider) { } -+ -+ /// Top-level initialization -+ Status initialize(); -+ -+ /// Used for initialization -+ void reserve(size_t operation_count); -+ -+ /// Graceful shutdown -+ Status release(); -+ -+ /// Appends an operation and takes ownership -+ void append(Operation *operation_ptr); -+ -+ /// Returns an iterator to the first operation -+ OperationVector const &operations() const; -+ -+ /// Returns a const iterator -+ OperationVector::const_iterator begin() const; -+ -+ /// Returns a const iterator -+ OperationVector::const_iterator end() const; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/include/cutlass/library/operation_table.h b/3rdparty/cutlass/tools/library/include/cutlass/library/operation_table.h -new file mode 100644 -index 0000000..037703f ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/include/cutlass/library/operation_table.h -@@ -0,0 +1,508 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* -+ \file -+ \brief Defines a data structure in which a set of functionally equivalent library::Operation -+ instances may be queried. -+*/ -+ -+#pragma once -+#include -+#include -+#include -+#include -+ -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+#include "cutlass/library/util.h" -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Data Structures for Gemm Functional Maps -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tuple uniquely identifying Gemm functional behavior -+struct GemmFunctionalKey { -+ -+ Provider provider; -+ GemmKind gemm_kind; -+ NumericTypeID element_compute; -+ NumericTypeID element_scalar; -+ NumericTypeID element_A; -+ LayoutTypeID layout_A; -+ ComplexTransform transform_A; -+ NumericTypeID element_B; -+ LayoutTypeID layout_B; -+ ComplexTransform transform_B; -+ NumericTypeID element_C; -+ -+ // -+ // Methods -+ // -+ -+ inline -+ GemmFunctionalKey( -+ Provider provider, -+ GemmKind gemm_kind = GemmKind::kGemm, -+ NumericTypeID element_compute = NumericTypeID::kF32, -+ NumericTypeID element_scalar = NumericTypeID::kF32, -+ NumericTypeID element_A = NumericTypeID::kF16, -+ LayoutTypeID layout_A = LayoutTypeID::kColumnMajor, -+ ComplexTransform transform_A = ComplexTransform::kNone, -+ NumericTypeID element_B = NumericTypeID::kF16, -+ LayoutTypeID layout_B = LayoutTypeID::kColumnMajor, -+ ComplexTransform transform_B = ComplexTransform::kNone, -+ NumericTypeID element_C = NumericTypeID::kF16 -+ ): -+ provider(provider), -+ gemm_kind(gemm_kind), -+ element_compute(element_compute), -+ element_scalar(element_scalar), -+ element_A(element_A), -+ layout_A(layout_A), -+ transform_A(transform_A), -+ element_B(element_B), -+ layout_B(layout_B), -+ transform_B(transform_B), -+ element_C(element_C) -+ { } -+ -+ inline -+ bool operator==(GemmFunctionalKey const &rhs) const { -+ return -+ (provider == rhs.provider) && -+ (gemm_kind == rhs.gemm_kind) && -+ (element_compute == rhs.element_compute) && -+ (element_scalar == rhs.element_scalar) && -+ (element_A == rhs.element_A) && -+ (layout_A == rhs.layout_A) && -+ (transform_A == rhs.transform_A) && -+ (element_B == rhs.element_B) && -+ 
(layout_B == rhs.layout_B) &&
-+      (transform_B == rhs.transform_B) &&
-+      (element_C == rhs.element_C);
-+  }
-+
-+  inline
-+  bool operator!=(GemmFunctionalKey const &rhs) const {
-+    return !(*this == rhs);
-+  }
-+};
-+
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+inline
-+std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey const &k) {
-+
-+  out << "{\n"
-+    << " provider: " << to_string(k.provider) << "\n"
-+    << " gemm_kind: " << to_string(k.gemm_kind) << "\n"
-+    << " element_compute: " << to_string(k.element_compute) << "\n"
-+    << " element_scalar: " << to_string(k.element_scalar) << "\n"
-+    << " element_A: " << to_string(k.element_A) << "\n"
-+    << " layout_A: " << to_string(k.layout_A) << "\n"
-+    << " transform_A: " << to_string(k.transform_A) << "\n"
-+    << " element_B: " << to_string(k.element_B) << "\n"
-+    << " layout_B: " << to_string(k.layout_B) << "\n"
-+    << " transform_B: " << to_string(k.transform_B) << "\n"
-+    << " element_C: " << to_string(k.element_C) << "\n"
-+    << "}";
-+
-+  return out;
-+}
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+/// Hash function for GemmFunctionalKey
-+struct GemmFunctionalKeyHasher {
-+  using IntHash = std::hash<int>;
-+
-+  inline
-+  static size_t rotl(size_t key, int shl) {
-+    return (key << shl) | (key >> (sizeof(key)*8 - shl));
-+  }
-+
-+  inline
-+  size_t operator()(GemmFunctionalKey const &key) const {
-+    IntHash hash;
-+
-+    return
-+      rotl(hash(int(key.provider)), 1) ^
-+      rotl(hash(int(key.gemm_kind)), 2) ^
-+      rotl(hash(int(key.element_compute)), 3) ^
-+      rotl(hash(int(key.element_scalar)), 4) ^
-+      rotl(hash(int(key.element_A)), 5) ^
-+      rotl(hash(int(key.layout_A)), 6) ^
-+      rotl(hash(int(key.transform_A)), 7) ^
-+      rotl(hash(int(key.element_B)), 8) ^
-+      rotl(hash(int(key.layout_B)), 9) ^
-+      rotl(hash(int(key.transform_B)), 10) ^
-+      rotl(hash(int(key.element_C)), 11);
-+  }
-+};
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+/// Establishes a partial ordering to search for GEMM operators
-+struct GemmPreferenceKey {
-+
-+  int compute_capability;
-+  int alignment;
-+
-+  //
-+  // Methods
-+  //
-+
-+  GemmPreferenceKey(): compute_capability(), alignment() { }
-+
-+  GemmPreferenceKey(int cc, int alignment): compute_capability(cc), alignment(alignment) { }
-+
-+  bool operator<(GemmPreferenceKey const &rhs) const {
-+    return (compute_capability < rhs.compute_capability) ||
-+      ((compute_capability == rhs.compute_capability) && (alignment < rhs.alignment));
-+  }
-+
-+  bool operator==(GemmPreferenceKey const &rhs) const {
-+    return compute_capability == rhs.compute_capability;
-+  }
-+};
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+/// Maps minimum compute capability onto a vector of possible operations
-+using GemmOperationVectorMap = std::map<
-+  GemmPreferenceKey,
-+  std::vector<Operation const *>
-+>;
-+
-+/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm
-+using GemmOperationFunctionalMap = std::unordered_map<
-+  GemmFunctionalKey,
-+  GemmOperationVectorMap,
-+  GemmFunctionalKeyHasher
-+>;
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+// Data Structures for Conv Functional Maps
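To make the intended use of the GEMM maps above concrete, here is a minimal, self-contained sketch of the same two-level lookup pattern: an unordered functional-key map whose hasher combines field hashes with rotate-and-XOR, nested over an ordered preference-key map that is searched for the best entry at or below a target compute capability. The names MiniFunctionalKey, MiniPreferenceKey, MiniOperationMap and find_best are stand-ins invented for this sketch and are not part of the CUTLASS library API; real code keys on GemmFunctionalKey/GemmPreferenceKey and stores Operation const * pointers rather than strings.

// Illustrative only: a scaled-down analogue of GemmOperationFunctionalMap above.
#include <cstddef>
#include <functional>
#include <iostream>
#include <iterator>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

struct MiniFunctionalKey {
  int element;   // stands in for the NumericTypeID fields
  int layout;    // stands in for the LayoutTypeID fields
  bool operator==(MiniFunctionalKey const &rhs) const {
    return element == rhs.element && layout == rhs.layout;
  }
};

// Same rotate-and-XOR combiner used by GemmFunctionalKeyHasher.
struct MiniKeyHasher {
  static std::size_t rotl(std::size_t key, int shl) {
    return (key << shl) | (key >> (sizeof(key) * 8 - shl));
  }
  std::size_t operator()(MiniFunctionalKey const &key) const {
    std::hash<int> hash;
    return rotl(hash(key.element), 1) ^ rotl(hash(key.layout), 2);
  }
};

// Preference key orders candidates by minimum compute capability, then alignment.
struct MiniPreferenceKey {
  int compute_capability;
  int alignment;
  bool operator<(MiniPreferenceKey const &rhs) const {
    return compute_capability < rhs.compute_capability ||
           (compute_capability == rhs.compute_capability && alignment < rhs.alignment);
  }
};

using MiniOperationMap = std::unordered_map<
    MiniFunctionalKey,
    std::map<MiniPreferenceKey, std::vector<std::string>>,
    MiniKeyHasher>;

// Pick the best-fitting candidate list whose compute capability does not exceed the target.
std::vector<std::string> const *find_best(MiniOperationMap const &table,
                                          MiniFunctionalKey const &fkey,
                                          int target_cc) {
  auto functional_it = table.find(fkey);
  if (functional_it == table.end()) return nullptr;
  auto const &by_preference = functional_it->second;
  // Sentinel alignment lands just past every entry with compute_capability <= target_cc.
  auto it = by_preference.upper_bound({target_cc, 1 << 30});
  if (it == by_preference.begin()) return nullptr;
  return &std::prev(it)->second;
}

int main() {
  MiniOperationMap table;
  table[{0, 0}][{70, 8}] = {"sm70_aligned8_kernel"};
  table[{0, 0}][{80, 8}] = {"sm80_aligned8_kernel"};

  if (auto const *ops = find_best(table, {0, 0}, 75)) {
    std::cout << ops->front() << "\n";  // prints the sm70 candidate
  }
  return 0;
}

The design point being illustrated is that functional equivalence (types, layouts, transforms) is resolved with a hash lookup, while the remaining choice among equivalent kernels is an ordered best-fit search on compute capability and alignment.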
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tuple uniquely identifying conv2d functional behavior -+struct ConvFunctionalKey { -+ library::Provider provider; -+ library::ConvKind conv_kind; -+ library::NumericTypeID element_A; -+ library::LayoutTypeID layout_A; -+ library::NumericTypeID element_B; -+ library::LayoutTypeID layout_B; -+ library::NumericTypeID element_C; -+ library::LayoutTypeID layout_C; -+ library::NumericTypeID element_accumulator; -+ library::NumericTypeID element_compute; -+ -+ -+ // -+ // Methods -+ // -+ -+ inline -+ ConvFunctionalKey( -+ library::Provider provider = library::Provider::kInvalid, -+ library::ConvKind conv_kind = library::ConvKind::kFprop, -+ library::NumericTypeID element_A = library::NumericTypeID::kF16, -+ library::LayoutTypeID layout_A = library::LayoutTypeID::kTensorNHWC, -+ library::NumericTypeID element_B = library::NumericTypeID::kF16, -+ library::LayoutTypeID layout_B = library::LayoutTypeID::kTensorNHWC, -+ library::NumericTypeID element_C = library::NumericTypeID::kF16, -+ library::LayoutTypeID layout_C = library::LayoutTypeID::kTensorNHWC, -+ library::NumericTypeID element_accumulator = library::NumericTypeID::kF32, -+ library::NumericTypeID element_compute = library::NumericTypeID::kF32 -+ ): -+ provider(provider), -+ conv_kind(conv_kind), -+ element_A(element_A), -+ layout_A(layout_A), -+ element_B(element_B), -+ layout_B(layout_B), -+ element_C(element_C), -+ layout_C(layout_C), -+ element_accumulator(element_accumulator), -+ element_compute(element_compute) -+ { } -+ -+ inline -+ bool operator==(ConvFunctionalKey const &rhs) const { -+ return -+ (provider == rhs.provider) && -+ (conv_kind == rhs.conv_kind) && -+ (element_A == rhs.element_A) && -+ (layout_A == rhs.layout_A) && -+ (element_B == rhs.element_B) && -+ (layout_B == rhs.layout_B) && -+ (element_C == rhs.element_C) && -+ (layout_C == rhs.layout_C) && -+ (element_accumulator == rhs.element_accumulator) && -+ (element_compute == rhs.element_compute); -+ } -+ -+ inline -+ bool operator!=(ConvFunctionalKey const &rhs) const { -+ return !(*this == rhs); -+ } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+inline -+std::ostream& operator<< (std::ostream& out, const cutlass::library::ConvFunctionalKey& key) { -+ out << "{\n" -+ << "provider: " << to_string(key.provider) << std::endl -+ << "conv_kind: " << to_string(key.conv_kind) << std::endl -+ << "element_A: " << to_string(key.element_A) << std::endl -+ << "layout_A: " << to_string(key.layout_A) << std::endl -+ << "element_B: " << to_string(key.element_B) << std::endl -+ << "layout_B: " << to_string(key.layout_B) << std::endl -+ << "element_C: " << to_string(key.element_C) << std::endl -+ << "layout_C: " << to_string(key.layout_C) << std::endl -+ << "element_accumulator: " << to_string(key.element_accumulator) << std::endl -+ << "element_compute: " << to_string(key.element_compute) << std::endl -+ << "}"; -+ -+ return out; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+struct ConvFunctionalKeyHasher { -+ using IntHash = std::hash; -+ -+ inline -+ static size_t rotl(size_t key, int shl) { -+ return (key << shl) | (key >> (sizeof(key)*8 - shl)); -+ } -+ -+ inline -+ size_t operator()(ConvFunctionalKey const &key) const { -+ IntHash hash; -+ -+ return -+ rotl(hash(int(key.provider)), 1) ^ -+ rotl(hash(int(key.conv_kind)), 2) ^ -+ 
rotl(hash(int(key.element_A)), 3) ^
-+      rotl(hash(int(key.layout_A)), 4) ^
-+      rotl(hash(int(key.element_B)), 5) ^
-+      rotl(hash(int(key.layout_B)), 6) ^
-+      rotl(hash(int(key.element_C)), 7) ^
-+      rotl(hash(int(key.layout_C)), 8) ^
-+      rotl(hash(int(key.element_accumulator)), 9) ^
-+      rotl(hash(int(key.element_compute)), 10);
-+  }
-+};
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+/// Establishes a partial ordering to search for Conv2d operators
-+struct ConvPreferenceKey {
-+
-+  int compute_capability;
-+  IteratorAlgorithmID iterator_algorithm;
-+
-+
-+  //
-+  // Methods
-+  //
-+
-+  ConvPreferenceKey(): compute_capability(), iterator_algorithm() { }
-+
-+  ConvPreferenceKey(int cc, IteratorAlgorithmID iterator_algorithm):
-+    compute_capability(cc), iterator_algorithm(iterator_algorithm) { }
-+
-+  bool operator<(ConvPreferenceKey const &rhs) const {
-+    return (compute_capability < rhs.compute_capability) ||
-+      ((compute_capability == rhs.compute_capability) && (iterator_algorithm < rhs.iterator_algorithm));
-+  }
-+
-+  bool operator==(ConvPreferenceKey const &rhs) const {
-+    return (compute_capability == rhs.compute_capability) &&
-+      (iterator_algorithm == rhs.iterator_algorithm);
-+  }
-+};
-+
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+/// Maps minimum compute capability onto a vector of possible operations
-+using ConvOperationVectorMap = std::map<
-+  ConvPreferenceKey,
-+  std::vector<Operation const *>
-+>;
-+
-+/// Maps a ConvFunctionalKey onto a vector of Operation * objects expected to be of kind kConv2d or kConv3d
-+using ConvOperationFunctionalMap = std::unordered_map<
-+  ConvFunctionalKey,
-+  ConvOperationVectorMap,
-+  ConvFunctionalKeyHasher
-+>;
-+/////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+
-+/// Tuple uniquely identifying reduction functional behavior
-+struct ReductionFunctionalKey {
-+  library::Provider provider;
-+  library::NumericTypeID element_workspace;
-+  library::NumericTypeID element_accumulator;
-+  library::NumericTypeID element_output;
-+  library::NumericTypeID element_compute;
-+  library::MathOperationID reduce_math_op;
-+  library::EpilogueKind epilogue_math_op;
-+
-+
-+  //
-+  // Methods
-+  //
-+
-+  inline
-+  ReductionFunctionalKey(
-+    library::Provider provider = library::Provider::kInvalid,
-+    library::NumericTypeID element_workspace = library::NumericTypeID::kF16,
-+    library::NumericTypeID element_accumulator = library::NumericTypeID::kF32,
-+    library::NumericTypeID element_output = library::NumericTypeID::kF16,
-+    library::NumericTypeID element_compute = library::NumericTypeID::kF32,
-+    library::MathOperationID reduce_math_op = library::MathOperationID::kAdd,
-+    library::EpilogueKind epilogue_math_op = library::EpilogueKind::kLinearCombination
-+  ):
-+    provider(provider),
-+    element_workspace(element_workspace),
-+    element_accumulator(element_accumulator),
-+    element_output(element_output),
-+    element_compute(element_compute),
-+    reduce_math_op(reduce_math_op),
-+    epilogue_math_op(epilogue_math_op)
-+  { }
-+
-+  inline
-+  bool operator==(ReductionFunctionalKey const &rhs) const {
-+    return
-+      (provider == rhs.provider) &&
-+      (element_workspace == rhs.element_workspace) &&
-+      (element_accumulator == rhs.element_accumulator) &&
-+      (element_output == rhs.element_output) &&
-+      (element_compute == rhs.element_compute) &&
-+      (reduce_math_op == rhs.reduce_math_op) &&
-+      (epilogue_math_op == rhs.epilogue_math_op);
-+  }
-+
-+  inline
-+ 
bool operator!=(ReductionFunctionalKey const &rhs) const { -+ return !(*this == rhs); -+ } -+}; -+ -+ -+struct ReductionFunctionalKeyHasher { -+ using IntHash = std::hash; -+ -+ inline -+ static size_t rotl(size_t key, int shl) { -+ return (key << shl) | (key >> (sizeof(key)*8 - shl)); -+ } -+ -+ inline -+ size_t operator()(ReductionFunctionalKey const &key) const { -+ IntHash hash; -+ -+ return -+ rotl(hash(int(key.provider)), 1) ^ -+ rotl(hash(int(key.element_workspace)), 2) ^ -+ rotl(hash(int(key.element_accumulator)), 3) ^ -+ rotl(hash(int(key.element_output)), 4) ^ -+ rotl(hash(int(key.element_compute)), 5) ^ -+ rotl(hash(int(key.reduce_math_op)), 6) ^ -+ rotl(hash(int(key.epilogue_math_op)), 7); -+ } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+inline -+std::ostream& operator<< (std::ostream& out, const ReductionFunctionalKey& key) { -+ out << "{\n" -+ << "provider: " << library::to_string(key.provider) << std::endl -+ << "element_workspace : " << library::to_string(key.element_workspace) << std::endl -+ << "element_accumulator : " << library::to_string(key.element_accumulator) << std::endl -+ << "element_output : " << library::to_string(key.element_output) << std::endl -+ << "element_compute : " << library::to_string(key.element_compute) << std::endl -+ << "}"; -+ return out; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// ReductionOperationFunctionalMap has NO preference key and a single instance per functional key -+// i.e. only one tile size configuration per functional key -+using ReductionOperationFunctionalMap = std::unordered_map< -+ ReductionFunctionalKey, -+ library::Operation const *, -+ ReductionFunctionalKeyHasher -+>; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Table of cutlass::library::Operation instances -+class OperationTable { -+public: -+ -+ /// Map of all operations of type kGemm -+ // provider (kCUTLASS) -+ GemmOperationFunctionalMap gemm_operations; -+ -+ /// Map of all operations of type kConv2d -+ // provider (kCUTLASS, kReferenceHost, kReferenceDevice) -+ ConvOperationFunctionalMap conv2d_operations; -+ -+ /// Map of all operations of type kConv3d -+ // provider (kCUTLASS, kReferenceHost, kReferenceDevice) -+ ConvOperationFunctionalMap conv3d_operations; -+ -+ /// Map of all operations of type kConv2d -+ // provider (kCUTLASS) -+ ReductionOperationFunctionalMap reduction_operations; -+ -+public: -+ -+ void append(Manifest const &manifest); -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey const &k); -diff --git a/3rdparty/cutlass/tools/library/include/cutlass/library/singleton.h b/3rdparty/cutlass/tools/library/include/cutlass/library/singleton.h -new file mode 100644 -index 0000000..e0bd959 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/include/cutlass/library/singleton.h -@@ -0,0 +1,68 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+#include "cutlass/library/operation_table.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Singleton instance stores a Manifest and Operation table -+class Singleton { -+public: -+ -+ /// Manifest object -+ Manifest manifest; -+ -+ /// Operation table referencing the Manifest -+ OperationTable operation_table; -+ -+public: -+ -+ Singleton(); -+ -+ static Singleton const &get(); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/library/include/cutlass/library/util.h b/3rdparty/cutlass/tools/library/include/cutlass/library/util.h -new file mode 100644 -index 0000000..517c6e9 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/include/cutlass/library/util.h -@@ -0,0 +1,197 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ -+ \brief Utilities accompanying the CUTLASS library for interacting with Library types. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/library/library.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Lexical cast from string -+template T from_string(std::string const &); -+ -+/// Converts a Provider enumerant to a string -+char const *to_string(Provider provider, bool pretty = false); -+ -+/// Parses a Provider enumerant from a string -+template <> Provider from_string(std::string const &str); -+ -+/// Converts a GemmKind enumerant to a string -+char const *to_string(GemmKind type, bool pretty = false); -+ -+/// Converts a RankKKind enumerant to a string -+char const *to_string(RankKKind type, bool pretty = false); -+ -+/// Converts a TrmmKind enumerant to a string -+char const *to_string(TrmmKind type, bool pretty = false); -+ -+/// Converts a SymmKind enumerant to a string -+char const *to_string(SymmKind type, bool pretty = false); -+ -+/// Converts a SideMode enumerant to a string -+char const *to_string(SideMode type, bool pretty = false); -+ -+/// Converts a FillMode enumerant to a string -+char const *to_string(FillMode type, bool pretty = false); -+ -+/// Converts a BlasMode enumerant to a string -+char const *to_string(BlasMode type, bool pretty = false); -+ -+/// Converts a DiagType enumerant to a string -+char const *to_string(DiagType type, bool pretty = false); -+ -+/// Converts a NumericType enumerant to a string -+char const *to_string(OperationKind type, bool pretty = false); -+ -+/// Parses a NumericType enumerant from a string -+template <> OperationKind from_string(std::string const &str); -+ -+/// Converts a NumericType enumerant to a string -+char const *to_string(NumericTypeID type, bool pretty = false); -+ -+/// Parses a NumericType enumerant from a string -+template <> NumericTypeID from_string(std::string const &str); -+ -+/// Returns the size of a data type in bits -+int 
sizeof_bits(NumericTypeID type); -+ -+/// Returns true if the numeric type is a complex data type or false if real-valued. -+bool is_complex_type(NumericTypeID type); -+ -+/// Returns the real-valued type underlying a type (only different from 'type' if complex) -+NumericTypeID get_real_type(NumericTypeID type); -+ -+/// Returns true if numeric type is integer -+bool is_integer_type(NumericTypeID type); -+ -+/// Returns true if numeric type is signed -+bool is_signed_type(NumericTypeID type); -+ -+/// Returns true if numeric type is a signed integer -+bool is_signed_integer(NumericTypeID type); -+ -+/// returns true if numeric type is an unsigned integer -+bool is_unsigned_integer(NumericTypeID type); -+ -+/// Returns true if numeric type is floating-point type -+bool is_float_type(NumericTypeID type); -+ -+/// To string method for cutlass::Status -+char const *to_string(Status status, bool pretty = false); -+ -+/// Converts a LayoutTypeID enumerant to a string -+char const *to_string(LayoutTypeID layout, bool pretty = false); -+ -+/// Parses a LayoutType enumerant from a string -+template <> LayoutTypeID from_string(std::string const &str); -+ -+/// Returns the rank of a layout's stride base on the LayoutTypeID -+int get_layout_stride_rank(LayoutTypeID layout_id); -+ -+/// Converts a OpcodeClassID enumerant to a string -+char const *to_string(OpcodeClassID type, bool pretty = false); -+ -+/// Converts a OpcodeClassID enumerant from a string -+template <> -+OpcodeClassID from_string(std::string const &str); -+ -+/// Converts a ComplexTransform enumerant to a string -+char const *to_string(ComplexTransform type, bool pretty = false); -+ -+/// Converts a ComplexTransform enumerant from a string -+template <> -+ComplexTransform from_string(std::string const &str); -+ -+ -+/// Converts a SplitKMode enumerant to a string -+char const *to_string(SplitKMode split_k_mode, bool pretty = false); -+ -+/// Converts a SplitKMode enumerant from a string -+template <> -+SplitKMode from_string(std::string const &str); -+ -+/// Converts a ConvModeID enumerant to a string -+char const *to_string(ConvModeID type, bool pretty = false); -+ -+/// Converts a ConvModeID enumerant from a string -+template <> -+ConvModeID from_string(std::string const &str); -+ -+/// Converts a IteratorAlgorithmID enumerant to a string -+char const *to_string(IteratorAlgorithmID type, bool pretty = false); -+ -+/// Converts a IteratorAlgorithmID enumerant from a string -+template <> -+IteratorAlgorithmID from_string(std::string const &str); -+ -+/// Converts a ConvKind enumerant to a string -+char const *to_string(ConvKind type, bool pretty = false); -+ -+/// Converts a ConvKind enumerant from a string -+template <> -+ConvKind from_string(std::string const &str); -+ -+/// Lexical cast from int64_t to string -+std::string lexical_cast(int64_t int_value); -+ -+/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid. -+bool lexical_cast(std::vector &bytes, NumericTypeID type, std::string const &str); -+ -+/// Lexical cast TO a string FROM a byte array. Returns true if cast is successful or false if invalid. -+std::string lexical_cast(std::vector &bytes, NumericTypeID type); -+ -+/// Casts from a signed int64 to the destination type. Returns true if successful. -+bool cast_from_int64(std::vector &bytes, NumericTypeID type, int64_t src); -+ -+/// Casts from an unsigned int64 to the destination type. Returns true if successful. 
-+bool cast_from_uint64(std::vector &bytes, NumericTypeID type, uint64_t src); -+ -+/// Casts from a real value represented as a double to the destination type. Returns true if successful. -+bool cast_from_double(std::vector &bytes, NumericTypeID type, double src); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/compiler.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/compiler.h -new file mode 100644 -index 0000000..b8e60bc ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/compiler.h -@@ -0,0 +1,75 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief In-memory compiled artifact cache -+*/ -+ -+#include -+#include -+#include -+ -+ -+namespace py = pybind11; -+ -+namespace cutlass { -+ -+struct CompileCache { -+public: -+ CompileCache() = default; -+ ~CompileCache() = default; -+ -+ using Cache = std::unordered_map; -+ -+ /// Check if the kernel has already been compiled -+ py::object at(const std::string &kernel) { -+ auto item = cache_.find(kernel); -+ -+ if (item != cache_.end()) { -+ return item->second; -+ } -+ return py::none(); -+ } -+ -+ /// Insert a new compiled kernel for new configuration -+ void insert(const std::string &kernel, const py::object &compiled_kernel){ -+ cache_.emplace(kernel, compiled_kernel); -+ } -+ -+ const int64_t size() const { return cache_.size(); } -+ -+ /// Clear the cache -+ void clear() { cache_.clear(); } -+ -+private: -+ Cache cache_; -+}; -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/arch.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/arch.h -new file mode 100644 -index 0000000..21f9771 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/arch.h -@@ -0,0 +1,59 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind opcode classes to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "cutlass/arch/mma.h" -+ -+namespace py = pybind11; -+ -+namespace cutlass { -+enum class OpcodeClass { -+ kSimt, kTensorOp, kWmmaTensorOp, kSparseTensorOp -+}; -+} -+ -+void bind_opcode(py::module &m) { -+ py::enum_(m, "OpClass", -+ R"pbdoc(classification of math operators)pbdoc") -+ .value("Simt", cutlass::OpcodeClass::kSimt, -+ R"pbdoc(Tag classifying math operators as thread-level operations)pbdoc") -+ .value("TensorOp", cutlass::OpcodeClass::kTensorOp, -+ R"pbdoc(Tag classifing operators as Tensor Core operations)pbdoc") -+ .value("WmmaTensorOp", cutlass::OpcodeClass::kWmmaTensorOp, -+ R"pbdoc(Tag classifing operators as WMMA Tensor Core operations)pbdoc") -+ .value("SparseTensorOp", cutlass::OpcodeClass::kSparseTensorOp, -+ R"pbdoc(Tag classifing operators as sparseTensor Core operations)pbdoc"); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/conv/conv_problem_size.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/conv/conv_problem_size.h -new file mode 100644 -index 0000000..ab4a067 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/conv/conv_problem_size.h -@@ -0,0 +1,102 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind Convolution problem sizes to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+namespace py = pybind11; -+ -+void bind_conv_problem_size(py::module &m) { -+ // -+ // Conv2d Problem Size: -+ // include/cutlass/conv/conv2d_problem_sizd.h -+ // -+ py::class_(m, "Conv2dProblemSize") -+ // constructors -+ .def(py::init()) -+ .def(py::init()) -+ // attribute accessors -+ .def_readwrite("N", &cutlass::conv::Conv2dProblemSize::N) -+ .def_readwrite("H", &cutlass::conv::Conv2dProblemSize::H) -+ .def_readwrite("W", &cutlass::conv::Conv2dProblemSize::W) -+ .def_readwrite("C", &cutlass::conv::Conv2dProblemSize::C) -+ .def_readwrite("P", &cutlass::conv::Conv2dProblemSize::P) -+ .def_readwrite("Q", &cutlass::conv::Conv2dProblemSize::Q) -+ .def_readwrite("K", &cutlass::conv::Conv2dProblemSize::K) -+ .def_readwrite("R", &cutlass::conv::Conv2dProblemSize::R) -+ .def_readwrite("S", &cutlass::conv::Conv2dProblemSize::S) -+ .def_readwrite("pad_h", &cutlass::conv::Conv2dProblemSize::pad_h) -+ .def_readwrite("pad_w", &cutlass::conv::Conv2dProblemSize::pad_w) -+ .def_readwrite("stride_h", &cutlass::conv::Conv2dProblemSize::stride_h) -+ .def_readwrite("stride_w", &cutlass::conv::Conv2dProblemSize::stride_w) -+ .def_readwrite("dilation_h", &cutlass::conv::Conv2dProblemSize::dilation_h) -+ .def_readwrite("dilation_w", &cutlass::conv::Conv2dProblemSize::dilation_w) -+ .def_readwrite("mode", &cutlass::conv::Conv2dProblemSize::mode) -+ .def_readwrite("split_k_slices", &cutlass::conv::Conv2dProblemSize::split_k_slices) -+ .def_readwrite("groups", &cutlass::conv::Conv2dProblemSize::groups) -+ // functions -+ .def("reset_split_k_slices", &cutlass::conv::Conv2dProblemSize::reset_split_k_slices) -+ .def("activation_extent", &cutlass::conv::Conv2dProblemSize::activation_extent) -+ .def("filter_extent", &cutlass::conv::Conv2dProblemSize::filter_extent) -+ .def("output_extent", &cutlass::conv::Conv2dProblemSize::output_extent) -+ .def("activation_size", &cutlass::conv::Conv2dProblemSize::activation_size) -+ .def("filter_size", &cutlass::conv::Conv2dProblemSize::filter_size) -+ .def("output_size", &cutlass::conv::Conv2dProblemSize::output_size); -+ -+ // Get tensor size -+ m.def("implicit_gemm_tensor_a_size", py::overload_cast(&cutlass::conv::implicit_gemm_tensor_a_size)); -+ m.def("implicit_gemm_tensor_b_size", py::overload_cast(&cutlass::conv::implicit_gemm_tensor_b_size)); -+ m.def("implicit_gemm_tensor_c_size", py::overload_cast(&cutlass::conv::implicit_gemm_tensor_c_size)); -+ -+ // Get tensor extent -+ m.def("implicit_gemm_tensor_a_extent", -+ py::overload_cast< -+ cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize& -+ >(&cutlass::conv::implicit_gemm_tensor_a_extent)); -+ -+ m.def("implicit_gemm_tensor_b_extent", -+ py::overload_cast< -+ cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize& -+ >(&cutlass::conv::implicit_gemm_tensor_b_extent)); -+ -+ m.def("implicit_gemm_tensor_c_extent", -+ py::overload_cast< -+ cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize& -+ >(&cutlass::conv::implicit_gemm_tensor_c_extent)); -+ -+ m.def("implicit_gemm_problem_size", py::overload_cast(&cutlass::conv::implicit_gemm_problem_size)); -+ -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/conv/convolution.h 
b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/conv/convolution.h -new file mode 100644 -index 0000000..36126ec ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/conv/convolution.h -@@ -0,0 +1,91 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind convolution related enum types to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "conv_problem_size.h" -+#include "host.h" -+#include "cutlass/conv/convolution.h" -+ -+namespace py = pybind11; -+ -+void bind_convolution(py::module &m) { -+ // -+ // Enumerate types -+ // cutlass/include/cutlass/conv/convolution.h -+ // -+ -+ /// Convolutional operator -+ py::enum_(m, "Operator", R"pbdoc(Convolutional operator)pbdoc") -+ .value("fprop", cutlass::conv::Operator::kFprop, "Forward propagation") -+ .value("dgrad", cutlass::conv::Operator::kDgrad, "Activation grad") -+ .value("wgrad", cutlass::conv::Operator::kWgrad, "Weight grad"); -+ -+ /// Distinguishes convolution from cross correlation -+ py::enum_(m, "Mode") -+ .value("cross_correlation", cutlass::conv::Mode::kCrossCorrelation) -+ .value("convolution", cutlass::conv::Mode::kConvolution); -+ -+ /// Selects among several implementation variants trading off performance with simplicity -+ py::enum_(m, "IteratorAlgorithm", -+ R"pbdoc(Selects among several implementation variants trading off performance with simplicity)pbdoc") -+ .value("analytic", cutlass::conv::IteratorAlgorithm::kAnalytic, R"pbdoc(functionally correct in all cases but lower performance)pbdoc") -+ .value("optimized", cutlass::conv::IteratorAlgorithm::kOptimized, R"pbdoc(optimized for R <= 32, S <= 32 and unity-stride dgrad)pbdoc") -+ .value("fixed_channels", cutlass::conv::IteratorAlgorithm::kFixedChannels, R"pbdoc(Analytic algorithm optimized for fixed channel count (C == AccessSize))pbdoc") -+ .value("few_channels", cutlass::conv::IteratorAlgorithm::kFewChannels, R"pbdoc(Analytic algorithm optimized for few channels (C divisible by AccessSize))pbdoc"); -+ -+ /// Distinguishes among partial specializations that accelerate certain problems where convolution -+ /// stride is unit. -+ py::enum_(m, "StrideSupport", -+ R"pbdoc(Distinguishes among partial specializations that accelerate certain problems where convolution -+ stride is unit.)pbdoc") -+ .value("strided", cutlass::conv::StrideSupport::kStrided, R"pbdoc(arbitrary convolution stride)pbdoc") -+ .value("unity", cutlass::conv::StrideSupport::kUnity, R"pbdoc(unit convolution stride)pbdoc"); -+ -+ /// Identifies split-K mode -+ py::enum_(m, "SplitKMode") -+ .value("None", cutlass::conv::SplitKMode::kNone) -+ .value("Serial", cutlass::conv::SplitKMode::kSerial) -+ .value("Parallel", cutlass::conv::SplitKMode::kParallel); -+ -+ // Conv problem sizes -+ bind_conv_problem_size(m); -+ -+ // -+ // host helper functions -+ // -+ py::module_ host_submodule = m.def_submodule("host"); -+ bind_conv_host_helper(host_submodule); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/conv/host.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/conv/host.h -new file mode 100644 -index 0000000..7a33251 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/conv/host.h -@@ -0,0 +1,54 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind conv host helpers to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "cutlass/util/host_reorder.h" -+#include "cutlass/layout/tensor.h" -+ -+namespace py = pybind11; -+ -+ -+void bind_conv_host_helper(py::module &m) { -+ -+ /// reorder operand B for interleaved layout -+ m.def("reorder_convK", []( -+ cutlass::TensorRef> dest, -+ cutlass::TensorRef> src, -+ cutlass::conv::Operator conv_op, const cutlass::conv::Conv2dProblemSize & problem_size) { -+ cutlass::gemm::GemmCoord implicit_problem_size = cutlass::conv::implicit_gemm_problem_size(conv_op, problem_size); -+ cutlass::reorder_convK<32>(dest, src, implicit_problem_size); -+ }); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_generic.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_generic.h -new file mode 100644 -index 0000000..6b33f9a ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_generic.h -@@ -0,0 +1,222 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ -+ \brief A generic wrapper around an epilogue visitor operation -+*/ -+ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/kernel/default_gemm_complex.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" -+ -+#include "epilogue_visitor_op/visitor_op_linear_combination.h" -+#include "epilogue_visitor_op/visitor_op_tensor_input.h" -+#include "epilogue_visitor_op/visitor_op_accumulator.h" -+#include "epilogue_visitor_op/visitor_op_row_broadcast.h" -+#include "epilogue_visitor_op/visitor_op_tensor_output.h" -+#include "epilogue_visitor_op/visitor_op_column_reduction.h" -+#include "epilogue_visitor_op/visitor_op_row_reduction.h" -+#include "epilogue_visitor_op/visitor_op_column_broadcast.h" -+#include "epilogue_visitor_op/visitor_op_unary.h" -+#include "epilogue_visitor_op/visitor_op_binary.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Generic Epilogue Visitor. 
-+template < -+ typename OutputOp_ -+> -+class EpilogueVisitorGeneric { -+public: -+ -+ using OutputOp = OutputOp_; -+ using AccumulatorAccessType = typename OutputOp::AccumulatorAccessType; -+ static int const kElementsPerAccess = OutputOp::kElementsPerAccess; -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using OutputTileIterator = typename OutputOp::OutputTileIterator; -+ -+ static int const kIterations = OutputTileIterator::kIterations; -+ -+ /// -+ /// End Epilogue Tree -+ /// -+ -+ /// Additional SMEM bufer is not required in the broadcast epilogue visitor -+ struct SharedStorage { -+ -+ typename OutputOp::SharedStorage output_smem; -+ CUTLASS_HOST_DEVICE -+ SharedStorage() { } -+ }; -+ -+public: -+ -+ /// Argument structure -+ struct Arguments { -+ typename OutputOp::Arguments output_op_args; -+ // -+ // Methods -+ // -+ Arguments() { } -+ -+ Arguments( -+ typename OutputOp::Arguments output_op_args -+ ): -+ output_op_args(output_op_args) -+ { -+ -+ } -+ }; -+ -+ struct Params { -+ typename OutputOp::Params output_op_params; -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ output_op_params(args.output_op_args) -+ { -+ -+ } -+ }; -+ -+ -+ -+private: -+ -+ OutputOp output_op; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueVisitorGeneric( -+ Params const ¶ms, ///< Parameters routed to the epilogue -+ SharedStorage &shared_storage, ///< Shared storage needed by the functors here -+ MatrixCoord threadblock_offset, -+ gemm::GemmCoord threadblock_tile_offset, -+ int thread_idx, -+ MatrixCoord problem_size -+ ): -+ output_op(params.output_op_params, shared_storage.output_smem, thread_idx, threadblock_offset, problem_size) -+ { } -+ -+ /// Helper to indicate split-K behavior -+ CUTLASS_DEVICE -+ void set_k_partition( -+ int split_k_index, ///< Index of this threadblock within split-K partitioned scheme -+ int split_k_slices) { ///< Total number of split-K slices -+ -+ } -+ -+ /// Called to set the batch index -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ output_op.set_batch_index(batch_idx); -+ } -+ -+ /// Called at the start of the epilogue just before iterating over accumulator slices -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ output_op.begin_epilogue(); -+ } -+ -+ /// Called at the start of one step before starting accumulator exchange -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ output_op.begin_step(step_idx); -+ } -+ -+ /// Called at the start of a row -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { -+ output_op.begin_row(row_idx); -+ } -+ -+ /// Called after accumulators have been exchanged for each accumulator vector -+ CUTLASS_DEVICE -+ void visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorAccessType const &accum) { -+ output_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum); -+ } -+ -+ /// Called at the start of a row -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { -+ output_op.end_row(row_idx); -+ -+ } -+ -+ /// Called after all accumulator elements have been visited -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ output_op.end_step(step_idx); -+ } -+ -+ /// Called after all steps have been completed -+ CUTLASS_DEVICE -+ void end_epilogue() { -+ output_op.end_epilogue(); -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ 
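The wrapper above only forwards each epilogue callback to the wrapped OutputOp. As a reference for the order of those callbacks, here is a minimal host-side sketch; ToyOutputOp and drive_epilogue are names invented for this illustration, the tile extents are arbitrary, and the real visitors run on the device inside EpilogueWithVisitor rather than in a plain host loop like this.

// Illustrative only: mimics the callback sequence EpilogueVisitorGeneric forwards
// to its OutputOp (begin_epilogue / begin_step / begin_row / visit / end_row /
// end_step / end_epilogue).
#include <array>
#include <iostream>

struct ToyOutputOp {
  void begin_epilogue()     { std::cout << "begin_epilogue\n"; }
  void begin_step(int step) { std::cout << " begin_step " << step << "\n"; }
  void begin_row(int row)   { std::cout << "  begin_row " << row << "\n"; }
  void visit(int iter, int row, int col, int frag, std::array<float, 4> const &accum) {
    std::cout << "   visit iter=" << iter << " row=" << row
              << " col=" << col << " frag=" << frag
              << " accum[0]=" << accum[0] << "\n";
  }
  void end_row(int row)     { std::cout << "  end_row " << row << "\n"; }
  void end_step(int step)   { std::cout << " end_step " << step << "\n"; }
  void end_epilogue()       { std::cout << "end_epilogue\n"; }
};

// Drives the visitor over a tiny 2-step x 2-row x 2-column tile of accumulator fragments.
template <typename Visitor>
void drive_epilogue(Visitor &visitor) {
  visitor.begin_epilogue();
  int iter = 0;
  for (int step = 0; step < 2; ++step) {
    visitor.begin_step(step);
    for (int row = 0; row < 2; ++row) {
      visitor.begin_row(row);
      for (int col = 0; col < 2; ++col, ++iter) {
        std::array<float, 4> accum{1.f, 2.f, 3.f, 4.f};  // stand-in accumulator fragment
        visitor.visit(iter, row, col, /*frag_idx=*/col, accum);
      }
      visitor.end_row(row);
    }
    visitor.end_step(step);
  }
  visitor.end_epilogue();
}

int main() {
  ToyOutputOp op;
  drive_epilogue(op);
  return 0;
}

Any OutputOp composed from the visitor_op_* headers included above is expected to implement this same set of callbacks, which is what lets EpilogueVisitorGeneric wrap an arbitrary epilogue tree without knowing its structure.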
-+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/binary_ops.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/binary_ops.h -new file mode 100644 -index 0000000..f64066a ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/binary_ops.h -@@ -0,0 +1,84 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ -+ \brief A file contains the binary ops -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Scalar multiplication -+template -+struct VectorAdd { -+ -+ struct Arguments { -+ int tmp; -+ -+ CUTLASS_HOST_DEVICE -+ Arguments():tmp(0){ } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments(int tmp): tmp(tmp) { } -+ }; -+ -+ struct Params { -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args) { } -+ }; -+ -+ CUTLASS_HOST_DEVICE -+ VectorAdd( -+ Params const ¶ms -+ ) { } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, Array const &rhs) const { -+ cutlass::plus> add_op; -+ return add_op(lhs, rhs); -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h -new file mode 100644 -index 0000000..9952a52 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h -@@ -0,0 +1,233 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ -+ \brief A file contains the unary ops -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/thread/activation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Scalar multiplication -+template -+struct Mult { -+ -+ struct Arguments { -+ T alpha; -+ -+ CUTLASS_HOST_DEVICE -+ Arguments():alpha(T(1.0)){ } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments(T alpha): alpha(alpha) { } -+ }; -+ -+ struct Params { -+ T alpha; ///< scales accumulators -+ -+ CUTLASS_HOST_DEVICE -+ Params():alpha(T(1.0)){ } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): alpha(args.alpha) { } -+ }; -+ -+ T alpha_; -+ -+ CUTLASS_HOST_DEVICE -+ Mult( -+ Params const ¶ms -+ ): -+ alpha_(params.alpha) -+ { } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &source) const { -+ cutlass::multiplies> multiply_op; -+ return multiply_op(source, alpha_); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool guard() { -+ return alpha_ != T(0); -+ } -+ -+}; -+ -+ -+/// ReLU -+template -+struct ReLUVisitor { -+ struct Arguments { -+ T threshold; -+ -+ CUTLASS_HOST_DEVICE -+ Arguments():threshold(T(0.0)) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments(T threshold): threshold(threshold) { } -+ }; -+ -+ struct Params { -+ T threshold; -+ -+ CUTLASS_HOST_DEVICE -+ Params():threshold(T(0.0)) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): threshold(args.threshold) { } -+ }; -+ -+ T threshold_; -+ -+ CUTLASS_HOST_DEVICE -+ ReLUVisitor(Params const ¶ms): -+ threshold_(params.threshold) { } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &frag) const { -+ maximum> mx; -+ return mx(frag, threshold_); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool guard() { -+ return true; -+ } -+}; -+ -+/// leakyReLU -+template -+struct LeakyReLUVisitor { -+ struct Arguments { -+ T leaky_alpha; -+ -+ CUTLASS_HOST_DEVICE -+ Arguments():leaky_alpha(T(0.0)) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments(T leaky_alpha): leaky_alpha(leaky_alpha) { } -+ }; -+ -+ struct Params { -+ T leaky_alpha; -+ -+ CUTLASS_HOST_DEVICE -+ Params():leaky_alpha(T(0.0)) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): leaky_alpha(args.leaky_alpha) { } -+ }; -+ -+ T leaky_alpha_; -+ -+ CUTLASS_HOST_DEVICE -+ LeakyReLUVisitor(Params const ¶ms): -+ leaky_alpha_(params.leaky_alpha) { } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &frag) const { -+ cutlass::epilogue::thread::LeakyReLU> leaky_op; -+ return leaky_op(frag, leaky_alpha_); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool guard() { -+ return true; -+ } -+ -+}; -+ -+/// Tanh -+template -+struct TanhVisitor { -+ /// Argument -+ struct Arguments { -+ // a placeholder argument to ensure correctness of ctypes -+ int tmp; -+ -+ CUTLASS_HOST_DEVICE -+ Arguments(): tmp(0) { }; -+ -+ CUTLASS_HOST_DEVICE -+ Arguments(int tmp): tmp(tmp) { }; -+ }; -+ -+ /// Param -+ struct Params { -+ CUTLASS_HOST_DEVICE -+ Params(){ }; -+ Params(Arguments const &args) { } -+ }; -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TanhVisitor(Params const ¶ms) { } -+ -+ // scalar operator -+ CUTLASS_HOST_DEVICE -+ T tanh_op(T const &scalar) const { -+ return fast_tanh(scalar); -+ } -+ -+ /// vector operator -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &frag) const { -+ Array y; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i=0; i < N; ++i) { -+ y[i] = tanh_op(frag[i]); -+ } -+ -+ return y; -+ } -+ -+ 
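-+    // Note: like Mult::guard() above (which returns false when alpha == 0), guard() appears to
-+    // signal whether this op can be skipped entirely; tanh always does work, so it returns true.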
CUTLASS_HOST_DEVICE -+ bool guard() { -+ return true; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_accumulator.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_accumulator.h -new file mode 100644 -index 0000000..2072cfa ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_accumulator.h -@@ -0,0 +1,148 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ -+ \brief A file contains the epilogue visitor Op with accumulator -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue Visitor operator for the following Computation -+/// -+/// ElementAccumulator accum; -+/// return accum; -+/// -+/// It can only be the leaf node of the epilogue tree -+ -+template < -+ typename ElementAccumulator_, ///< Data type of the Accumulator -+ int kElementsPerAccess_ ///< Number of elements computed per operation -+> -+class VisitorOpAccumulator{ -+public: -+ using ElementAccumulator = ElementAccumulator_; -+ static int const kElementsPerAccess = kElementsPerAccess_; -+ -+ /// Fragment type for Accumulator -+ using AccumulatorAccessType = Array; -+ -+ /// Fragment type returned by this visitor -+ using VisitAccessType = AccumulatorAccessType; -+ -+ /// SMEM buffer class required in the epilogue visitor -+ struct SharedStorage { -+ CUTLASS_HOST_DEVICE -+ SharedStorage() {} -+ }; -+ -+ /// Host-constructable Arguments structure -+ struct Arguments { -+ // Note: it is strange that ctypes will return issue with empty arguments -+ int tmp; -+ -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments(int tmp): tmp(tmp) { } -+ }; -+ -+ /// Parameter structure -+ struct Params { -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args) { } -+ }; -+ -+public: -+ /// Constructs the function object -+ CUTLASS_HOST_DEVICE -+ VisitorOpAccumulator( -+ Params const ¶ms, -+ SharedStorage &shared_storage, -+ int thread_idx, -+ MatrixCoord threadblock_offset, -+ MatrixCoord problem_size -+ ) { } -+ -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { } -+ -+ CUTLASS_DEVICE -+ void begin_epilogue() { } -+ -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { } -+ -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { } -+ -+ CUTLASS_DEVICE -+ VisitAccessType visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorAccessType const &accum -+ ) { -+ return accum; -+ } -+ -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { } -+ -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { } -+ -+ CUTLASS_DEVICE -+ void end_epilogue() { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_binary.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_binary.h -new file mode 100644 -index 0000000..d9fa445 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_binary.h -@@ -0,0 +1,245 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ -+ \brief A file contains the epilogue visitor Op with Binary op -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+#include "binary_ops.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue Visitor operator for the following computation: -+/// -+/// ElementCompute alpha; -+/// ElementCompute beta; -+/// ElementCompute C = BinaryOp(alpha * ElementCompute(Visitor_A), beta * ElementCompute(Visitor_B) -+/// Return C; -+/// -+template < -+ typename ElementAccumulator_, ///< Data type of the Accumulator -+ typename ElementCompute_, ///< Data type used to compute linear combination -+ int kElementsPerAccess_, ///< Number of elements computed per operation -+ typename VisitorA_, ///< Child node A -+ typename VisitorB_, ///< Child node B -+ template typename BinaryOp_ -+> -+class VisitorOpBinary{ -+public: -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ static int const kElementsPerAccess = kElementsPerAccess_; -+ -+ using VisitorA = VisitorA_; -+ using VisitorB = VisitorB_; -+ -+ /// Fragment type returned from VisitorA.visit -+ using VisitAccessTypeA = typename VisitorA::VisitAccessType; -+ using ElementA = typename VisitAccessTypeA::Element; -+ -+ /// Fragment type returned from VisitorB.visit -+ using VisitAccessTypeB = typename VisitorB::VisitAccessType; -+ using ElementB = typename VisitAccessTypeB::Element; -+ -+ /// Fragment type returned by this visitor -+ using VisitAccessType = Array; -+ -+ /// Fragment type of accumulator -+ using AccumulatorAccessType = Array; -+ -+ using BinaryOp = BinaryOp_; -+ -+ 
static_assert(kElementsPerAccess==VisitAccessTypeA::kElements, "kElementsPerAccess mismatches with Visitor A"); -+ static_assert(kElementsPerAccess==VisitAccessTypeB::kElements, "kElementsPerAccess misnatches with Visitor B"); -+ -+ /// SMEM buffer class required in the epilogue visitor -+ struct SharedStorage { -+ typename VisitorA::SharedStorage storage_a; -+ typename VisitorB::SharedStorage storage_b; -+ -+ CUTLASS_HOST_DEVICE -+ SharedStorage() {} -+ }; -+ -+ -+ /// Host-constructable Arguments structure -+ struct Arguments { -+ typename BinaryOp::Arguments binary_arg; -+ typename VisitorA::Arguments visitor_a_arg; ///< Argument type for visitor_a -+ typename VisitorB::Arguments visitor_b_arg; ///< Argument type for visitor_b -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Arguments():binary_arg() { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ typename BinaryOp::Arguments binary_arg, -+ typename VisitorA::Arguments visitor_a_arg, -+ typename VisitorB::Arguments visitor_b_arg -+ ): -+ binary_arg(binary_arg), -+ visitor_a_arg(visitor_a_arg), -+ visitor_b_arg(visitor_b_arg) -+ { } -+ }; -+ -+ /// Parameter structure -+ struct Params { -+ typename BinaryOp::Params binary_param; -+ typename VisitorA::Params visitor_a_param; ///< Argument type for visitor_a -+ typename VisitorB::Params visitor_b_param; ///< Argument type for visitor_b -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ binary_param(args.binary_arg), -+ visitor_a_param(args.visitor_a_arg), -+ visitor_b_param(args.visitor_b_arg) -+ { } -+ }; -+ -+private: -+ // -+ // Data members -+ // -+ -+ BinaryOp binary_op; -+ -+ VisitorA visitor_a_op; -+ VisitorB visitor_b_op; -+ -+public: -+ -+ /// Constructs the function object -+ CUTLASS_HOST_DEVICE -+ VisitorOpBinary( -+ Params const ¶ms, -+ SharedStorage &shared_storage, -+ int thread_idx, -+ MatrixCoord threadblock_offset, -+ MatrixCoord problem_size -+ ): -+ binary_op(params.binary_param), -+ visitor_a_op(params.visitor_a_param, shared_storage.storage_a, thread_idx, threadblock_offset, problem_size), -+ visitor_b_op(params.visitor_b_param, shared_storage.storage_b, thread_idx, threadblock_offset, problem_size) -+ { } -+ -+ -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ visitor_a_op.begin_epilogue(); -+ visitor_b_op.begin_epilogue(); -+ } -+ -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ visitor_a_op.set_batch_index(batch_idx); -+ visitor_b_op.set_batch_index(batch_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ visitor_a_op.begin_step(step_idx); -+ visitor_b_op.begin_step(step_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { -+ visitor_a_op.begin_row(row_idx); -+ visitor_b_op.begin_row(row_idx); -+ } -+ -+ CUTLASS_DEVICE -+ VisitAccessType visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorAccessType const &accum -+ ) { -+ /// Get result from visitor A and visitor B -+ VisitAccessTypeA result_A = visitor_a_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum); -+ VisitAccessTypeB result_B = visitor_b_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum); -+ -+ /// Type conversion -+ NumericArrayConverter source_converter_A; -+ NumericArrayConverter source_converter_B; -+ -+ return binary_op( -+ source_converter_A(result_A), -+ source_converter_B(result_B) -+ ); -+ } -+ -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { -+ visitor_a_op.end_row(row_idx); -+ visitor_b_op.end_row(row_idx); -+ } -+ -+ 
CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ visitor_a_op.end_step(step_idx); -+ visitor_b_op.end_step(step_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void end_epilogue() { -+ visitor_a_op.end_epilogue(); -+ visitor_b_op.end_epilogue(); -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_broadcast.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_broadcast.h -new file mode 100644 -index 0000000..6dcb32b ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_broadcast.h -@@ -0,0 +1,250 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ -+ \brief A file contains the epilogue visitor Op with broadcasting vector to all columns -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Epilogue Visitor operator for the following computation: -+/// -+/// ElementVector T[i][j] <- device-memory Td[i] -+/// -+/// It can only be a leaf node in the epilogue tree -+template < -+ typename ElementAccumulator_, ///< Data type of the Accumulator -+ typename ElementFragment_, ///< Data type used to cache vector in register -+ typename InputTileIterator_ ///< Tile iterator type to read the broadcasted tensor -+> -+class VisitorOpColumnBroadcast { -+public: -+ using InputTileIterator = InputTileIterator_; -+ -+ static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementVector = typename InputTileIterator::Element; -+ using ElementFragment = ElementFragment_; -+ -+ using VisitAccessType = Array; -+ -+ /// Thread map used by input tile iterators -+ using ThreadMap = typename InputTileIterator::ThreadMap; -+ -+ /// Fragment object used to store the broadcast values -+ using BroadcastFragment = Array< -+ ElementFragment, kElementsPerAccess>; -+ -+ /// Fragment type of accumulator -+ using AccumulatorAccessType = Array; -+ -+ /// Used for the broadcast -+ struct BroadcastDetail { -+ /// Number of threads per warp -+ static int const kWarpSize = 32; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ /// Number of distinct scalar column indices handled by each thread -+ static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; -+ -+ /// Number of distinct scalar row indices handled by each thread -+ static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; -+ -+ /// Number of threads per threadblock -+ static int const kThreadCount = ThreadMap::kThreads; -+ -+ /// Number of distinct threads per row of output tile -+ static int const kThreadsPerRow = (InputTileIterator::Shape::kN / kColumnsPerThread); -+ -+ /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock. 
-+ static int const kThreadRows = kThreadCount / kThreadsPerRow; -+ -+ // /// Number of iterations (accesses) the threadblock takes to reduce a row -+ // static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); -+ }; -+ -+ // using ComputeFragmentType = Array; -+ -+ struct SharedStorage { -+ CUTLASS_HOST_DEVICE -+ SharedStorage() { } -+ }; -+ -+ /// Host-constructable Argument structure -+ struct Arguments { -+ ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand -+ int64_t batch_stride; -+ -+ /// Methods -+ CUTLASS_HOST_DEVICE -+ Arguments(): -+ broadcast_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ElementVector *broadcast_ptr, -+ int64_t batch_stride -+ ): -+ broadcast_ptr(broadcast_ptr), -+ batch_stride(batch_stride) { } -+ }; -+ -+ /// Param structure -+ struct Params { -+ ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand -+ int64_t batch_stride; -+ -+ /// Method -+ CUTLASS_HOST_DEVICE -+ Params(): -+ broadcast_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ broadcast_ptr(args.broadcast_ptr), -+ batch_stride(args.batch_stride) { } -+ }; -+ -+private: -+ ElementVector *broadcast_ptr; -+ BroadcastFragment broadcast_fragment; ///< Array holds the loaded broadcast fragment -+ MatrixCoord threadblock_offset_; -+ int thread_idx_; -+ MatrixCoord problem_size; -+ -+ int thread_start_row_; -+ int state_[3]; -+ int thread_offset_row_; -+ -+ int64_t batch_stride_; -+ -+public: -+ /// Constructs the function object -+ CUTLASS_HOST_DEVICE -+ VisitorOpColumnBroadcast( -+ Params const ¶ms, -+ SharedStorage &shared_storage, -+ int thread_idx, -+ MatrixCoord threadblock_offset, -+ MatrixCoord problem_size -+ ): -+ broadcast_ptr(params.broadcast_ptr), -+ threadblock_offset_(threadblock_offset), -+ thread_idx_(thread_idx), -+ problem_size(problem_size), -+ thread_start_row_(ThreadMap::initial_offset(thread_idx).row() + threadblock_offset.row()), -+ batch_stride_(params.batch_stride) -+ { -+ state_[0] = state_[1] = state_[2] = 0; -+ } -+ -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ broadcast_ptr += batch_idx * batch_stride_; -+ } -+ -+ CUTLASS_DEVICE -+ void begin_epilogue() { } -+ -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) {} -+ -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) {} -+ -+ CUTLASS_DEVICE -+ VisitAccessType visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorAccessType const &accum -+ ) { -+ // get pointer -+ thread_offset_row_ = thread_start_row_ + ThreadMap::iteration_offset(frag_idx).row(); -+ -+ ElementFragment broadcast_data = ElementFragment(*(broadcast_ptr + thread_offset_row_)); -+ -+ broadcast_fragment.fill(broadcast_data); -+ -+ return broadcast_fragment; -+ } -+ -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { } -+ -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ // run operator ++ -+ ++state_[0]; -+ -+ thread_start_row_ += ThreadMap::Shape::kRow; -+ if (state_[0] == ThreadMap::Count::kRow) { -+ state_[0] = 0; -+ ++state_[1]; -+ thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * -+ ThreadMap::Shape::kRow * ThreadMap::Count::kRow; -+ -+ if (state_[1] == ThreadMap::Count::kGroup) { -+ state_[1] = 0; -+ ++state_[2]; -+ thread_start_row_ += ThreadMap::Count::kGroup * -+ ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; -+ -+ if (state_[2] == ThreadMap::Count::kCluster) { -+ state_[2] = 0; -+ } -+ } -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void end_epilogue() 
{ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_reduction.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_reduction.h -new file mode 100644 -index 0000000..624d7e6 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_reduction.h -@@ -0,0 +1,341 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ -+ \brief A file contains the epilogue visitor Op with reduction over columns in CTA -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Epilogue Visitor operator for the following computation: -+/// -+/// ElementReductionAccumulator R[j] = \sum_i ElementReductionAccumulator(T[i][j]) -+/// device memory <- ElementReduction(R[j]) -+/// -+template < -+ typename ThreadblockShape_, /// Threadblock shape -+ typename ElementAccumulator_, ///< Data type of the Accumulator -+ typename ElementReduction_, ///< Data type of the output reduction in device memory -+ typename ElementReductionAccumulator_ , ///< Data type to accumulate reduction in smem and register -+ typename OutputTileIterator_, ///< Tile Iterator type -+ typename Visitor_ ///< preceeding visitor op -+> -+class VisitorOpColumnReduction { -+public: -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementReductionAccumulator = ElementReductionAccumulator_; -+ using ElementReduction = ElementReduction_; -+ using OutputTileIterator = OutputTileIterator_; -+ using ThreadblockShape = ThreadblockShape_; -+ using Visitor = Visitor_; -+ -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ using ReductionOp = cutlass::plus>; -+ using ReductionOpScalar = cutlass::plus; -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ -+ -+ /// Fragment type returned from Visitor -+ using VisitAccessTypeVisitor = typename Visitor::VisitAccessType; -+ using ElementVisitor = typename VisitAccessTypeVisitor::Element; -+ -+ using VisitAccessType = VisitAccessTypeVisitor; -+ -+ /// Fragment type of accumulator -+ using AccumulatorAccessType = Array; -+ -+ /// Fragment type of redcution -+ using ReductionAccumulatorAccessType = Array; -+ -+ /// Thread map used by output tile iterators -+ using ThreadMap = typename OutputTileIterator::ThreadMap; -+ /// Used for the reduction -+ struct ReductionDetail { -+ -+ /// Number of threads per warp -+ static int const kWarpSize = 32; -+ -+ /// Number of distinct scalar column indices handled by each thread -+ static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; -+ -+ /// Number of distinct scalar row indices handled by each thread -+ static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; -+ -+ /// Number of threads per threadblock -+ static int const kThreadCount = ThreadMap::kThreads; -+ -+ /// Number of distinct threads per row of output tile -+ static int const kThreadsPerRow = ThreadblockShape::kN / kColumnsPerThread; -+ -+ /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock -+ static int const kThreadRows = kThreadCount / kThreadsPerRow; -+ -+ /// Number of iterations (accesses) the threadblock takes to reduce a row -+ static int const kThreadAccessesPerRow = const_max(1, (ThreadblockShape::kN + kThreadCount - 1) / kThreadCount); -+ -+ using StorageShape = MatrixShape< -+ kThreadRows, -+ ThreadblockShape::kN -+ >; -+ }; -+ -+ using ReductionFragment = Array; -+ -+ /// Shared storage -+ struct SharedStorage { -+ typename Visitor::SharedStorage storage_visitor; -+ AlignedArray reduction; -+ -+ CUTLASS_HOST_DEVICE -+ 
SharedStorage() {} -+ }; -+ -+ /// Host-constructable Argument structure -+ struct Arguments { -+ ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory -+ int64_t batch_stride; -+ typename Visitor::Arguments visitor_arg; ///< Argument type of visitor -+ -+ /// Method -+ CUTLASS_HOST_DEVICE -+ Arguments(): reduction_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ElementReduction *reduction_ptr, -+ int64_t batch_stride, -+ typename Visitor::Arguments visitor_arg -+ ): -+ reduction_ptr(reduction_ptr), -+ batch_stride(batch_stride), -+ visitor_arg(visitor_arg) -+ { } -+ }; -+ -+ /// Param structure -+ struct Params { -+ ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory -+ int64_t batch_stride; -+ typename Visitor::Params visitor_param; ///< Argument type of visitor -+ -+ /// Method -+ CUTLASS_HOST_DEVICE -+ Params(): reduction_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ reduction_ptr(args.reduction_ptr), -+ batch_stride(args.batch_stride), -+ visitor_param(args.visitor_arg) -+ { } -+ }; -+ -+private: -+ ElementReduction *reduction_output_ptr_; ///< Pointer to the reduction tensor in device memory -+ ElementReductionAccumulator *reduction_smem_ptr_; ///< Pointer to the partial reductions in shared memory -+ ReductionFragment reduction_fragment; ///< register fragments that hold the partial reduction -+ Visitor visitor_; ///< visitor -+ int thread_idx_; -+ MatrixCoord threadblock_offset; -+ MatrixCoord problem_size_; -+ int64_t batch_stride_; -+ -+public: -+ -+ /// Constructs the function object -+ CUTLASS_HOST_DEVICE -+ VisitorOpColumnReduction( -+ Params const ¶ms, -+ SharedStorage &shared_storage, -+ int thread_idx, -+ MatrixCoord threadblock_offset, -+ MatrixCoord problem_size -+ ): -+ visitor_(params.visitor_param, shared_storage.storage_visitor, -+ thread_idx, threadblock_offset, problem_size), -+ reduction_smem_ptr_(shared_storage.reduction.data()), -+ reduction_output_ptr_(params.reduction_ptr), -+ thread_idx_(thread_idx), -+ threadblock_offset(threadblock_offset), -+ problem_size_(problem_size), -+ batch_stride_(params.batch_stride) -+ { } -+ -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ reduction_output_ptr_ += batch_idx * batch_stride_; -+ visitor_.set_batch_index(batch_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ visitor_.begin_epilogue(); -+ -+ // clear the reduction fragment -+ reduction_fragment.clear(); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ visitor_.begin_step(step_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { -+ visitor_.begin_row(row_idx); -+ } -+ -+ CUTLASS_DEVICE -+ VisitAccessType visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorAccessType const &accum -+ ) { -+ /// Get result from visitor -+ VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum); -+ -+ NumericArrayConverter reduction_converter; -+ ReductionOp reduction_op; -+ ReductionAccumulatorAccessType* reduction_fragment_ = reinterpret_cast(&reduction_fragment); -+ reduction_fragment_[column_idx] = reduction_op(reduction_fragment_[column_idx], reduction_converter(result)); -+ -+ return result; -+ } -+ -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { -+ visitor_.end_row(row_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ visitor_.end_step(step_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void end_epilogue() { -+ visitor_.end_epilogue(); 
-+ // -+ // Store the partially reduced value to SMEM -+ // -+ -+ // Guard against uses of the existing SMEM tile -+ __syncthreads(); -+ -+ using AccessType = AlignedArray; -+ -+ // -+ // Determine a compact thread arrangement to store to SMEM -+ // -+ -+ MatrixCoord thread_offset( -+ thread_idx_ / ReductionDetail::kThreadsPerRow, -+ (thread_idx_ % ReductionDetail::kThreadsPerRow) * ThreadMap::kElementsPerAccess -+ ); -+ -+ // -+ // Each thread store its fragment to a SMEM -+ // -+ AccessType *aligned_reduction_ptr = reinterpret_cast( -+ &reduction_smem_ptr_[thread_offset.row() * ThreadblockShape::kN + thread_offset.column()] -+ ); -+ -+ AccessType const *frag_ptr = reinterpret_cast( -+ &reduction_fragment -+ ); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ int col_idx = column * ThreadMap::Delta::kColumn / ThreadMap::kElementsPerAccess; -+ -+ aligned_reduction_ptr[col_idx] = frag_ptr[column]; -+ } -+ -+ __syncthreads(); -+ -+ // -+ // Now, threads are assigned several columns of the output. The fetch over all rows from -+ // the compacted SMEM tile and perform a reduction. -+ // -+ -+ NumericConverter output_converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < ReductionDetail::kThreadAccessesPerRow; ++j) { -+ int column_idx = thread_idx_ + j * ReductionDetail::kThreadCount; -+ -+ ReductionOpScalar reduction_op; -+ ElementReductionAccumulator reduction_element = ElementReductionAccumulator(); -+ -+ int output_column_idx = threadblock_offset.column() + column_idx; -+ -+ if (column_idx < ThreadblockShape::kN && output_column_idx < problem_size_.column()) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ReductionDetail::kThreadRows; ++row) { -+ if (row) { -+ auto frag = reduction_smem_ptr_[row * ThreadblockShape::kN + column_idx]; -+ reduction_element = reduction_op(reduction_element, frag); -+ } -+ else { -+ -+ reduction_element = reduction_smem_ptr_[column_idx]; -+ } -+ } -+ -+ // Store -+ reduction_output_ptr_[column_idx + threadblock_offset.column() + threadblock_offset.row() / ThreadblockShape::kM * problem_size_.column()] = output_converter(reduction_element); -+ } -+ } -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_linear_combination.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_linear_combination.h -new file mode 100644 -index 0000000..1e2b8e6 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_linear_combination.h -@@ -0,0 +1,266 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ -+ \brief A file contains the epilogue visitor Op with Linear Combination -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue Visitor operator for the following computation: -+/// -+/// ElementCompute alpha; -+/// ElementCompute beta; -+/// ElementCompute C = BinaryOp(alpha * ElementCompute(Visitor_A), beta * ElementCompute(Visitor_B) -+/// Return C; -+/// -+template < -+ typename ElementAccumulator_, ///< Data type of the Accumulator -+ typename ElementCompute_, ///< Data type used to compute linear combination -+ int kElementsPerAccess_, ///< Number of elements computed per operation -+ typename VisitorA_, ///< Child node A -+ typename VisitorB_ ///< Child node B -+> -+class VisitorOpLinearCombination{ -+public: -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ static int const kElementsPerAccess = kElementsPerAccess_; -+ -+ using VisitorA = VisitorA_; -+ using VisitorB = VisitorB_; -+ -+ /// Fragment type returned from VisitorA.visit -+ using VisitAccessTypeA = typename VisitorA::VisitAccessType; -+ using ElementA = typename VisitAccessTypeA::Element; -+ -+ /// Fragment type returned from VisitorB.visit -+ using VisitAccessTypeB = typename VisitorB::VisitAccessType; -+ using ElementB = typename VisitAccessTypeB::Element; -+ -+ /// Fragment type returned by this visitor -+ using VisitAccessType = Array; -+ -+ /// Fragment type of accumulator -+ using AccumulatorAccessType = Array; -+ -+ /// Combination Op -+ using CombinationOp = cutlass::plus; -+ -+ static_assert(kElementsPerAccess==VisitAccessTypeA::kElements, "kElementsPerAccess mismatches with Visitor A"); -+ static_assert(kElementsPerAccess==VisitAccessTypeB::kElements, "kElementsPerAccess misnatches with Visitor B"); -+ -+ /// SMEM buffer class required in the epilogue visitor -+ struct SharedStorage { -+ typename VisitorA::SharedStorage storage_a; -+ 
typename VisitorB::SharedStorage storage_b; -+ -+ CUTLASS_HOST_DEVICE -+ SharedStorage() {} -+ }; -+ -+ -+ /// Host-constructable Arguments structure -+ struct Arguments { -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ typename VisitorA::Arguments visitor_a_arg; ///< Argument type for visitor_a -+ typename VisitorB::Arguments visitor_b_arg; ///< Argument type for visitor_b -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Arguments(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)) -+ { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ElementCompute alpha, -+ ElementCompute beta, -+ typename VisitorA::Arguments visitor_a_arg, -+ typename VisitorB::Arguments visitor_b_arg -+ ): -+ alpha(alpha), -+ beta(beta), -+ visitor_a_arg(visitor_a_arg), -+ visitor_b_arg(visitor_b_arg) -+ { } -+ }; -+ -+ /// Parameter structure -+ struct Params { -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ typename VisitorA::Params visitor_a_param; ///< Argument type for visitor_a -+ typename VisitorB::Params visitor_b_param; ///< Argument type for visitor_b -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ alpha(args.alpha), -+ beta(args.beta), -+ visitor_a_param(args.visitor_a_arg), -+ visitor_b_param(args.visitor_b_arg) -+ { } -+ }; -+ -+private: -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ -+ VisitorA visitor_a_op; -+ VisitorB visitor_b_op; -+ -+public: -+ -+ /// Constructs the function object -+ CUTLASS_HOST_DEVICE -+ VisitorOpLinearCombination( -+ Params const ¶ms, -+ SharedStorage &shared_storage, -+ int thread_idx, -+ MatrixCoord threadblock_offset, -+ MatrixCoord problem_size -+ ): -+ alpha_(params.alpha), -+ beta_(params.beta), -+ visitor_a_op(params.visitor_a_param, shared_storage.storage_a, thread_idx, threadblock_offset, problem_size), -+ visitor_b_op(params.visitor_b_param, shared_storage.storage_b, thread_idx, threadblock_offset, problem_size) -+ { } -+ -+ -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ if (alpha_ != ElementCompute(0)) visitor_a_op.begin_epilogue(); -+ if (beta_ != ElementCompute(0)) visitor_b_op.begin_epilogue(); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ if (alpha_ != ElementCompute(0)) visitor_a_op.begin_step(step_idx); -+ if (beta_ != ElementCompute(0)) visitor_b_op.begin_step(step_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { -+ if (alpha_ != ElementCompute(0)) visitor_a_op.begin_row(row_idx); -+ if (beta_ != ElementCompute(0)) visitor_b_op.begin_row(row_idx); -+ } -+ -+ CUTLASS_DEVICE -+ VisitAccessType visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorAccessType const &accum -+ ) { -+ /// Get result from visitor A and visitor B -+ VisitAccessTypeA result_A; -+ VisitAccessTypeB result_B; -+ -+ if (alpha_ != ElementCompute(0)) { -+ result_A = visitor_a_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum); -+ } else { -+ // Fill the result A with zeros -+ result_A.clear(); -+ } -+ -+ if (beta_ != ElementCompute(0)) { -+ result_B = visitor_b_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum); -+ } else { -+ // Fill the result B with zeros -+ result_B.clear(); -+ } -+ -+ /// Type conversion -+ NumericArrayConverter source_converter_A; -+ NumericArrayConverter source_converter_B; -+ -+ CombinationOp combination_op; -+ -+ cutlass::multiplies multiply_op; -+ 
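-+        // The return below computes alpha * A + beta * B elementwise; a child visitor whose
-+        // scale factor is zero was skipped above, so its cleared fragment contributes nothing.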
-+ return combination_op( -+ multiply_op(alpha_, source_converter_A(result_A)), -+ multiply_op(beta_, source_converter_B(result_B)) -+ ); -+ } -+ -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { -+ if (alpha_ != ElementCompute(0)) visitor_a_op.end_row(row_idx); -+ if (beta_ != ElementCompute(0)) visitor_b_op.end_row(row_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ if (alpha_ != ElementCompute(0)) visitor_a_op.end_step(step_idx); -+ if (beta_ != ElementCompute(0)) visitor_b_op.end_step(step_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void end_epilogue() { -+ if (alpha_ != ElementCompute(0)) visitor_a_op.end_epilogue(); -+ if (beta_ != ElementCompute(0)) visitor_b_op.end_epilogue(); -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_broadcast.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_broadcast.h -new file mode 100644 -index 0000000..dc7bfa2 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_broadcast.h -@@ -0,0 +1,258 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ -+ \brief A file contains the epilogue visitor Op with broadcasting vector to all rows -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Epilogue Visitor operator for the following computation: -+/// -+/// ElementVector T[i][j] <- device-memory Td[j] -+/// -+/// It can only be a leaf node in the epilogue tree -+template < -+ typename ElementAccumulator_, ///< Data type of the Accumulator -+ typename ElementFragment_, ///< Data type used to cache vector in register -+ typename InputTileIterator_ ///< Tile iterator type to read the broadcasted tensor -+> -+class VisitorOpRowBroadcast { -+public: -+ using InputTileIterator = InputTileIterator_; -+ -+ static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementVector = typename InputTileIterator::Element; -+ using ElementFragment = ElementFragment_; -+ -+ using VisitAccessType = Array; -+ -+ /// Thread map used by input tile iterators -+ using ThreadMap = typename InputTileIterator::ThreadMap; -+ -+ /// Fragment object used to store the broadcast values -+ using BroadcastFragment = Array< -+ ElementFragment, -+ ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>; -+ -+ /// Fragment type of accumulator -+ using AccumulatorAccessType = Array; -+ -+ /// Used for the broadcast -+ struct BroadcastDetail { -+ /// Number of threads per warp -+ static int const kWarpSize = 32; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ /// Number of distinct scalar column indices handled by each thread -+ static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; -+ -+ /// Number of distinct scalar row indices handled by each thread -+ static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; -+ -+ /// Number of threads per threadblock -+ static int const kThreadCount = ThreadMap::kThreads; -+ -+ /// Number of distinct threads per row of output tile -+ static int const kThreadsPerRow = (InputTileIterator::Shape::kN / kColumnsPerThread); -+ -+ /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock. 
-+ static int const kThreadRows = kThreadCount / kThreadsPerRow; -+ -+ // /// Number of iterations (accesses) the threadblock takes to reduce a row -+ // static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); -+ }; -+ -+ // using ComputeFragmentType = Array; -+ -+ struct SharedStorage { -+ CUTLASS_HOST_DEVICE -+ SharedStorage() { } -+ }; -+ -+ /// Host-constructable Argument structure -+ struct Arguments { -+ ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand -+ int64_t batch_stride; -+ -+ /// Methods -+ CUTLASS_HOST_DEVICE -+ Arguments(): -+ broadcast_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ElementVector *broadcast_ptr, -+ int64_t batch_stride -+ ): -+ broadcast_ptr(broadcast_ptr), -+ batch_stride(batch_stride) { } -+ }; -+ -+ /// Param structure -+ struct Params { -+ ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand -+ int64_t batch_stride; -+ -+ /// Method -+ CUTLASS_HOST_DEVICE -+ Params(): -+ broadcast_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ broadcast_ptr(args.broadcast_ptr), -+ batch_stride(args.batch_stride) { } -+ }; -+ -+private: -+ ElementVector *broadcast_ptr; -+ BroadcastFragment broadcast_fragment; ///< Array holds the loaded broadcast fragment -+ MatrixCoord threadblock_offset_; -+ int thread_idx_; -+ MatrixCoord problem_size; -+ int64_t batch_stride_; -+ -+public: -+ /// Constructs the function object -+ CUTLASS_HOST_DEVICE -+ VisitorOpRowBroadcast( -+ Params const ¶ms, -+ SharedStorage &shared_storage, -+ int thread_idx, -+ MatrixCoord threadblock_offset, -+ MatrixCoord problem_size -+ ): -+ broadcast_ptr(params.broadcast_ptr + threadblock_offset.column()), -+ threadblock_offset_(threadblock_offset), -+ thread_idx_(thread_idx), -+ problem_size(problem_size), -+ batch_stride_(params.batch_stride) { } -+ -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ broadcast_ptr += batch_idx * batch_stride_; -+ } -+ -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ // load broadcast fragment -+ load_broadcast_fragment_(); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) {} -+ -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) {} -+ -+ CUTLASS_DEVICE -+ VisitAccessType visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorAccessType const &accum -+ ) { -+ VisitAccessType* broadcast_fragment_ = reinterpret_cast(&broadcast_fragment); -+ return broadcast_fragment_[column_idx]; -+ } -+ -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { } -+ -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { } -+ -+ CUTLASS_DEVICE -+ void end_epilogue() { } -+ -+private: -+ -+ CUTLASS_DEVICE -+ void load_broadcast_fragment_() { -+ -+ broadcast_fragment.clear(); -+ -+ // If no pointer is supplied, set with all zeros and avoid memory accesses -+ if (!broadcast_ptr) { -+ return; -+ } -+ -+ int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column(); -+ -+ int thread_column_idx = threadblock_offset_.column() + thread_initial_column; -+ broadcast_ptr += thread_initial_column; -+ -+ NumericArrayConverter converter; -+ using AccessType = AlignedArray; -+ using AccessFragmentType = Array; -+ -+ AccessFragmentType *frag_ptr = reinterpret_cast(&broadcast_fragment); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) { -+ -+ AccessType loaded; -+ -+ loaded.clear(); -+ -+ if (thread_column_idx < problem_size.column()) { -+ loaded = *reinterpret_cast(broadcast_ptr); 
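-+            // Guarded load: threads whose columns fall past problem_size.column() skip this read
-+            // and keep the zero-cleared fragment, so out-of-range broadcast elements read as zero.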
-+ } -+ -+ AccessFragmentType cvt = converter(loaded); -+ frag_ptr[j] = cvt; -+ -+ thread_column_idx += ThreadMap::Delta::kColumn; -+ broadcast_ptr += ThreadMap::Delta::kColumn; -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_reduction.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_reduction.h -new file mode 100644 -index 0000000..27b03f8 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_reduction.h -@@ -0,0 +1,319 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ -+ \brief A file contains the epilogue visitor Op with reduction over rows in CTA -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+#include "stdio.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Epilogue Visitor operator for the following computation: -+/// -+/// ElementReductionAccumulator R[i] = \sum_i ElementReductionAccumulator(T[i][j]) -+/// device memory <- ElementReduction(R[i]) -+/// -+template < -+ typename ThreadblockShape_, /// Threadblock shape -+ typename ElementAccumulator_, ///< Data type of the Accumulator -+ typename ElementReduction_, ///< Data type of the output reduction in device memory -+ typename ElementReductionAccumulator_ , ///< Data type to accumulate reduction in smem and register -+ typename OutputTileIterator_, ///< Tile Iterator type -+ typename Visitor_ ///< preceeding visitor op -+> -+class VisitorOpRowReduction { -+public: -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementReductionAccumulator = ElementReductionAccumulator_; -+ using ElementReduction = ElementReduction_; -+ using OutputTileIterator = OutputTileIterator_; -+ using ThreadblockShape = ThreadblockShape_; -+ using Visitor = Visitor_; -+ -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ using ReductionOp = cutlass::plus>; -+ using ReductionOpScalar = cutlass::plus; -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ /// Fragment type returned from Visitor -+ using VisitAccessTypeVisitor = typename Visitor::VisitAccessType; -+ using ElementVisitor = typename VisitAccessTypeVisitor::Element; -+ -+ using VisitAccessType = VisitAccessTypeVisitor; -+ -+ /// Fragment type of accumulator -+ using AccumulatorAccessType = Array; -+ -+ /// Fragment type of redcution -+ using ReductionAccumulatorAccessType = Array; -+ -+ /// Thread map used by output tile iterators -+ using ThreadMap = typename OutputTileIterator::ThreadMap; -+ /// Used for the reduction -+ struct ReductionDetail { -+ -+ /// Number of threads per warp -+ static int const kWarpSize = 32; -+ -+ /// Number of distinct scalar column indices handled by each thread -+ static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; -+ -+ /// Number of distinct scalar row indices handled by each thread -+ static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; -+ -+ /// Number of threads per threadblock -+ static int const kThreadCount = ThreadMap::kThreads; -+ -+ /// Number of distinct threads per row of output tile -+ static int const kThreadsPerRow = ThreadblockShape::kN / kColumnsPerThread; -+ -+ /// Half number of threads per row used for cross-thread reduction -+ static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1); -+ -+ /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock -+ static int const kThreadRows = kThreadCount / kThreadsPerRow; -+ }; -+ -+ /// Shared storage -+ struct SharedStorage { -+ typename Visitor::SharedStorage storage_visitor; -+ CUTLASS_HOST_DEVICE -+ SharedStorage() { } -+ }; -+ -+ /// Host-constructable Argument structure -+ struct Arguments { -+ ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory -+ int64_t batch_stride; -+ 
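// A minimal host-side reference of the reduction this visitor computes, per the class
// comment above: R[i] = sum_j T[i][j], accumulated in ElementReductionAccumulator and
// stored as ElementReduction. Illustrative sketch only (the names below are not part of
// this patch); assumes a dense row-major M x N source:
template <typename TIn, typename TAcc, typename TOut>
void reference_row_reduction(TIn const *src, TOut *dst, int M, int N) {
  for (int i = 0; i < M; ++i) {
    TAcc sum = TAcc(0);                  // plays the role of ElementReductionAccumulator
    for (int j = 0; j < N; ++j) {
      sum += TAcc(src[i * N + j]);       // in-thread accumulation over one row
    }
    dst[i] = TOut(sum);                  // converted and written out, as end_row() does
  }
}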
typename Visitor::Arguments visitor_arg; ///< Argument type of visitor -+ -+ /// Method -+ CUTLASS_HOST_DEVICE -+ Arguments(): reduction_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ElementReduction *reduction_ptr, -+ int64_t batch_stride, -+ typename Visitor::Arguments visitor_arg -+ ): -+ reduction_ptr(reduction_ptr), -+ batch_stride(batch_stride), -+ visitor_arg(visitor_arg) -+ { } -+ }; -+ -+ /// Param structure -+ struct Params { -+ ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory -+ int64_t batch_stride; -+ typename Visitor::Params visitor_param; ///< Argument type of visitor -+ -+ /// Method -+ CUTLASS_HOST_DEVICE -+ Params(): reduction_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ reduction_ptr(args.reduction_ptr), -+ batch_stride(args.batch_stride), -+ visitor_param(args.visitor_arg) -+ { } -+ }; -+ -+private: -+ ElementReduction *reduction_output_ptr_; ///< Pointer to the reduction tensor in device memory -+ ElementReductionAccumulator reduction_accum; -+ Visitor visitor_; ///< visitor -+ int thread_idx_; -+ MatrixCoord threadblock_offset; -+ MatrixCoord problem_size_; -+ -+ int thread_start_row_; /// used to identify -+ int state_[3]; /// used to track row iterator -+ int thread_offset_row_; -+ int64_t batch_stride_; -+public: -+ /// Constructs the function object -+ CUTLASS_HOST_DEVICE -+ VisitorOpRowReduction( -+ Params const ¶ms, -+ SharedStorage &shared_storage, -+ int thread_idx, -+ MatrixCoord threadblock_offset, -+ MatrixCoord problem_size -+ ): -+ visitor_(params.visitor_param, shared_storage.storage_visitor, -+ thread_idx, threadblock_offset, problem_size), -+ reduction_output_ptr_(params.reduction_ptr), -+ thread_idx_(thread_idx), -+ threadblock_offset(threadblock_offset), -+ problem_size_(problem_size), -+ thread_start_row_(ThreadMap::initial_offset(thread_idx).row() + threadblock_offset.row()), -+ batch_stride_(params.batch_stride) -+ { -+ state_[0] = state_[1] = state_[2] = 0; -+ } -+ -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ reduction_output_ptr_ += batch_idx * batch_stride_; -+ visitor_.set_batch_index(batch_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ visitor_.begin_epilogue(); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ visitor_.begin_step(step_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { -+ visitor_.begin_row(row_idx); -+ -+ reduction_accum = ElementReductionAccumulator(0); -+ } -+ -+ CUTLASS_DEVICE -+ VisitAccessType visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorAccessType const &accum -+ ) { -+ /// Get result from visitor -+ VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum); -+ -+ thread_offset_row_ = thread_start_row_ + ThreadMap::iteration_offset(frag_idx).row(); -+ -+ ReductionOpScalar reduction_op; -+ -+ ElementReductionAccumulator reduction_accum_ = reduction(result); -+ -+ // After performing the in-thread reduction, we then perform cross-thread / in-warp reduction -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = ReductionDetail::kHalfThreadsPerRow; i > 0; i >>= 1) { -+ reduction_accum_ = reduction_op(reduction_accum_, __shfl_xor_sync(0xFFFFFFFF, reduction_accum_, i)); -+ } -+ reduction_accum = reduction_op(reduction_accum, reduction_accum_); -+ -+ return result; -+ } -+ -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { -+ visitor_.end_row(row_idx); -+ NumericConverter output_converter; -+ -+ bool is_write_thread 
= (thread_offset_row_ < problem_size_.row() && (thread_idx_ % ReductionDetail::kThreadsPerRow) == 0); -+ int row_offset = thread_offset_row_ + threadblock_offset.column() / ThreadblockShape::kN * problem_size_.row(); -+ -+ ElementReduction *curr_ptr_reduction = reduction_output_ptr_ + row_offset; -+ -+ arch::global_store( -+ output_converter(reduction_accum), -+ (void *)curr_ptr_reduction, -+ is_write_thread); -+ } -+ -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ visitor_.end_step(step_idx); -+ -+ // run operator ++ -+ ++state_[0]; -+ -+ thread_start_row_ += ThreadMap::Shape::kRow; -+ if (state_[0] == ThreadMap::Count::kRow) { -+ state_[0] = 0; -+ ++state_[1]; -+ thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * -+ ThreadMap::Shape::kRow * ThreadMap::Count::kRow; -+ -+ if (state_[1] == ThreadMap::Count::kGroup) { -+ state_[1] = 0; -+ ++state_[2]; -+ thread_start_row_ += ThreadMap::Count::kGroup * -+ ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; -+ -+ if (state_[2] == ThreadMap::Count::kCluster) { -+ state_[2] = 0; -+ } -+ } -+ } -+ -+ } -+ -+ CUTLASS_DEVICE -+ void end_epilogue() { -+ visitor_.end_epilogue(); -+ } -+ -+private: -+ -+ CUTLASS_DEVICE -+ ElementReductionAccumulator reduction(VisitAccessTypeVisitor const& result) { -+ ElementReductionAccumulator sum_ = ElementReductionAccumulator(0); -+ -+ ReductionOpScalar reduction_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < VisitAccessTypeVisitor::kElements; ++i) { -+ sum_ = reduction_op(sum_, result[i]); -+ } -+ -+ return sum_; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_input.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_input.h -new file mode 100644 -index 0000000..d2eac4f ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_input.h -@@ -0,0 +1,188 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ -+ \brief A file contains the epilogue visitor Op with Tensor Output -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue Visitor operator for the following computation: -+/// -+/// ElementInput C <- device memory -+/// -+/// It can only be a leaf node in the epilogue tree -+template < -+ typename ElementAccumulator_, ///< Data type of the Accumulator -+ typename InputTileIterator_ ///< Tile iterator type to read the tensor -+> -+class VisitorOpTensorInput { -+public: -+ using ElementAccumulator = ElementAccumulator_; -+ using InputTileIterator = InputTileIterator_; -+ -+ static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess; -+ using ElementInput = typename InputTileIterator::Element; -+ -+ using VisitAccessType = Array; -+ -+ /// Fragment type of accumulator -+ using AccumulatorAccessType = Array; -+ -+ struct SharedStorage { -+ CUTLASS_HOST_DEVICE -+ SharedStorage() { } -+ }; -+ -+ /// Host-constructable Argument structure -+ struct Arguments { -+ ElementInput *input_ptr; ///< Pointer to the input tensor in device memory -+ int ldt; ///< Leading dimension of the input tensor operand -+ int64_t batch_stride; ///< batch stride for batched GEMM -+ -+ /// Methods -+ CUTLASS_HOST_DEVICE -+ Arguments(): input_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ElementInput *input_ptr, -+ int ldt, int64_t batch_stride -+ ): -+ input_ptr(input_ptr), -+ ldt(ldt), -+ batch_stride(batch_stride) -+ { } -+ }; -+ -+ /// Param structure -+ struct Params { -+ typename InputTileIterator::Params params_input; -+ ElementInput *input_ptr; -+ int64_t batch_stride; -+ -+ /// Method -+ CUTLASS_HOST_DEVICE -+ Params(): -+ input_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ params_input(args.ldt), -+ input_ptr(args.input_ptr), -+ batch_stride(args.batch_stride) -+ { } -+ }; -+ -+private: -+ InputTileIterator iterator_T_; -+ typename InputTileIterator::Fragment fragment_T_; -+ MatrixCoord problem_size; -+ int64_t batch_stride_; -+ -+public: -+ /// Constructs the function object -+ CUTLASS_HOST_DEVICE -+ VisitorOpTensorInput( -+ Params const ¶ms, -+ SharedStorage &shared_storage, -+ int thread_idx, -+ MatrixCoord threadblock_offset, -+ MatrixCoord problem_size -+ ): -+ iterator_T_( -+ InputTileIterator( -+ params.params_input, -+ params.input_ptr, -+ problem_size, -+ thread_idx, -+ threadblock_offset -+ ) -+ ), -+ problem_size(problem_size), -+ batch_stride_(params.batch_stride) { } -+ -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ iterator_T_.add_pointer_offset(batch_idx * 
batch_stride_); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_epilogue() { } -+ -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ fragment_T_.clear(); -+ iterator_T_.load(fragment_T_); -+ ++iterator_T_; -+ } -+ -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { } -+ -+ CUTLASS_DEVICE -+ VisitAccessType visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorAccessType const &accum -+ ) { -+ VisitAccessType source = reinterpret_cast(&fragment_T_)[frag_idx]; -+ return source; -+ } -+ -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { } -+ -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { } -+ -+ CUTLASS_DEVICE -+ void end_epilogue() { } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_output.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_output.h -new file mode 100644 -index 0000000..407611a ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_output.h -@@ -0,0 +1,240 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! 
\file -+ -+ \brief A file contains the epilogue visitor Op with Tensor Output -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+#include "stdio.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue Visitor operator for the following computation: -+/// -+/// ElementOutput T = ElementOutput(Visitor) -+/// T-> device memory -+/// -+template < -+ typename ElementAccumulator_, ///< Data type of the Accumulator -+ typename OutputTileIterator_, ///< Tile iterator type to write the tensor -+ typename Visitor_ ///< Child visitor that produces the output tensor -+> -+class VisitorOpTensorOutput { -+public: -+ using ElementAccumulator = ElementAccumulator_; -+ using OutputTileIterator = OutputTileIterator_; -+ -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ using Visitor = Visitor_; -+ -+ /// Fragment type returned from Visitor -+ using VisitAccessTypeVisitor = typename Visitor::VisitAccessType; -+ using ElementVisitor = typename VisitAccessTypeVisitor::Element; -+ -+ using VisitAccessType = VisitAccessTypeVisitor; -+ -+ /// Fragment type of accumulator -+ using AccumulatorAccessType = Array; -+ -+ /// Fragment type of output -+ using OutputAccessType = Array; -+ -+ static_assert(kElementsPerAccess==VisitAccessTypeVisitor::kElements, "kElementsPerAccess mismatches with Visitor"); -+ -+ struct SharedStorage { -+ typename Visitor::SharedStorage storage_visitor; -+ -+ CUTLASS_HOST_DEVICE -+ SharedStorage() { } -+ }; -+ -+ /// Host-constructable Argument structure -+ struct Arguments { -+ ElementOutput *output_ptr; ///< Pointer to the output tensor in device memory -+ int ldt; ///< Leading dimension of the output tensor operand -+ int64_t batch_stride; ///< batch stride -+ typename Visitor::Arguments visitor_arg; ///< Argument type of visitor -+ -+ /// Methods -+ CUTLASS_HOST_DEVICE -+ Arguments(): output_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ElementOutput *output_ptr, -+ int ldt, -+ int64_t batch_stride, -+ typename Visitor::Arguments visitor_arg -+ ): -+ output_ptr(output_ptr), -+ ldt(ldt), -+ batch_stride(batch_stride), -+ visitor_arg(visitor_arg) -+ { } -+ }; -+ -+ /// Param structure -+ struct Params { -+ typename OutputTileIterator::Params params_output; -+ ElementOutput *output_ptr; -+ int64_t batch_stride; -+ typename Visitor::Params visitor_param; -+ -+ /// Method -+ CUTLASS_HOST_DEVICE -+ Params(): -+ output_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ params_output(args.ldt), -+ output_ptr(args.output_ptr), -+ batch_stride(args.batch_stride), -+ visitor_param(args.visitor_arg) -+ { } -+ }; -+ -+private: -+ OutputTileIterator iterator_T_; -+ typename OutputTileIterator::Fragment fragment_T_; -+ MatrixCoord problem_size; -+ Visitor visitor_; -+ int64_t batch_stride_; -+ -+public: -+ -+ /// Constructs the function object -+ CUTLASS_HOST_DEVICE -+ VisitorOpTensorOutput( -+ Params const ¶ms, -+ SharedStorage &shared_storage, -+ int thread_idx, -+ MatrixCoord threadblock_offset, -+ MatrixCoord problem_size -+ ): -+ visitor_(params.visitor_param, shared_storage.storage_visitor, thread_idx, threadblock_offset, problem_size), -+ iterator_T_( -+ OutputTileIterator( -+ params.params_output, -+ 
params.output_ptr, -+ problem_size, -+ thread_idx, -+ threadblock_offset -+ ) -+ ), -+ problem_size(problem_size), -+ batch_stride_(params.batch_stride) { } -+ -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ iterator_T_.add_pointer_offset(batch_idx * batch_stride_); -+ visitor_.set_batch_index(batch_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ visitor_.begin_epilogue(); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ fragment_T_.clear(); -+ visitor_.begin_step(step_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { -+ visitor_.begin_row(row_idx); -+ } -+ -+ CUTLASS_DEVICE -+ VisitAccessType visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorAccessType const &accum -+ ) { -+ /// Get result from visitor -+ VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum); -+ -+ // Column guard -+ MatrixCoord thread_offset_ = iterator_T_.thread_start() + OutputTileIterator::ThreadMap::iteration_offset(frag_idx); -+ bool column_guard = (thread_offset_.column() < problem_size.column()); -+ -+ if (column_guard) { -+ NumericArrayConverter output_converter; -+ OutputAccessType &output = reinterpret_cast(&fragment_T_)[frag_idx]; -+ output = output_converter(result); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { -+ visitor_.end_row(row_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ visitor_.end_step(step_idx); -+ iterator_T_.store(fragment_T_); -+ ++iterator_T_; -+ } -+ -+ CUTLASS_DEVICE -+ void end_epilogue() { -+ visitor_.end_epilogue(); -+ } -+ -+}; -+ -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h -new file mode 100644 -index 0000000..c80543e ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h -@@ -0,0 +1,226 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ -+ \brief A file contains the epilogue visitor Op with Unary operation -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+#include "unary_ops.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Epilogue Visitor operator for the following computation: -+/// -+/// ElementCompute alpha; -+/// ElementCompute beta; -+/// ElementCompute C = UnaryOp(ElementCompute(Visitor)) -+/// Return C; -+/// -+template < -+ typename ElementAccumulator_, ///< Data type of the Accumulator -+ typename ElementCompute_, ///< Data type used to compute linear combination -+ int kElementsPerAccess_, ///< Number of elements computed per operation -+ typename Visitor_, ///< Child node -+ template typename UnaryOp_ -+> -+class VisitorOpUnary{ -+public: -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ static int const kElementsPerAccess = kElementsPerAccess_; -+ -+ using Visitor = Visitor_; -+ -+ /// Fragment type returned from Visitor.visit -+ using VisitAccessTypeVisitor = typename Visitor::VisitAccessType; -+ using ElementVisit = typename VisitAccessTypeVisitor::Element; -+ -+ /// Fragment type returned by this visitor -+ using VisitAccessType = Array; -+ -+ /// Fragment type of accumulator -+ using AccumulatorAccessType = Array; -+ -+ /// Combination Op -+ using UnaryOp = UnaryOp_; -+ -+ static_assert(kElementsPerAccess==VisitAccessTypeVisitor::kElements, "kElementsPerAccess mismatches with Visitor"); -+ -+ /// SMEM buffer class required in the epilogue visitor -+ struct SharedStorage { -+ typename Visitor::SharedStorage storage_visitor; -+ -+ CUTLASS_HOST_DEVICE -+ SharedStorage() {} -+ }; -+ -+ -+ /// Host-constructable Arguments structure -+ struct Arguments { -+ typename UnaryOp::Arguments unary_arg; -+ typename Visitor::Arguments visitor_arg; ///< Argument type for visitor -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Arguments():unary_arg() { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ typename UnaryOp::Arguments unary_arg, -+ typename Visitor::Arguments visitor_arg -+ ): -+ unary_arg(unary_arg), -+ visitor_arg(visitor_arg) -+ { } -+ }; -+ -+ /// Parameter structure -+ struct Params { -+ typename UnaryOp::Params unary_param; -+ typename Visitor::Params visitor_param; ///< Argument type for visitor -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params():unary_param() { } -+ -+ 
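// VisitorOpUnary is parameterized on a UnaryOp_ that, as used further below, exposes
// nested Arguments/Params types, a guard() predicate (when it returns false the child
// subtree is skipped and a cleared fragment is used), and a call operator over an Array
// of ElementCompute. A hypothetical elementwise scale op matching that shape
// (illustrative only; not the contents of unary_ops.h):
template <typename T, int N>
struct IllustrativeScaleOp {
  struct Arguments {
    T alpha;
    CUTLASS_HOST_DEVICE Arguments(): alpha(T(1)) { }
    CUTLASS_HOST_DEVICE Arguments(T alpha_): alpha(alpha_) { }
  };
  struct Params {
    T alpha;
    CUTLASS_HOST_DEVICE Params(): alpha(T(1)) { }
    CUTLASS_HOST_DEVICE Params(Arguments const &args): alpha(args.alpha) { }
  };
  T alpha_;
  CUTLASS_HOST_DEVICE IllustrativeScaleOp(Params const &params): alpha_(params.alpha) { }
  // One plausible guard policy: skip the child visitor when the result is zero anyway.
  CUTLASS_HOST_DEVICE bool guard() { return alpha_ != T(0); }
  CUTLASS_HOST_DEVICE Array<T, N> operator()(Array<T, N> const &frag) const {
    Array<T, N> result;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = alpha_ * frag[i];      // elementwise transform of the visited fragment
    }
    return result;
  }
};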
CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ unary_param(args.unary_arg), -+ visitor_param(args.visitor_arg) -+ { } -+ }; -+ -+private: -+ // -+ // Data members -+ // -+ UnaryOp unary_op; -+ -+ Visitor visitor_op; -+ -+public: -+ -+ /// Constructs the function object -+ CUTLASS_HOST_DEVICE -+ VisitorOpUnary( -+ Params const ¶ms, -+ SharedStorage &shared_storage, -+ int thread_idx, -+ MatrixCoord threadblock_offset, -+ MatrixCoord problem_size -+ ): -+ unary_op(params.unary_param), -+ visitor_op(params.visitor_param, shared_storage.storage_visitor, thread_idx, threadblock_offset, problem_size) -+ { } -+ -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ visitor_op.set_batch_index(batch_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ if (unary_op.guard()) visitor_op.begin_epilogue(); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ if (unary_op.guard()) visitor_op.begin_step(step_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { -+ if (unary_op.guard()) visitor_op.begin_row(row_idx); -+ } -+ -+ CUTLASS_DEVICE -+ VisitAccessType visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorAccessType const &accum -+ ) { -+ /// Get result from visitor A and visitor B -+ VisitAccessTypeVisitor result; -+ -+ if (unary_op.guard()){ -+ result = visitor_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum); -+ } else { -+ result.clear(); -+ } -+ -+ /// Type conversion -+ NumericArrayConverter source_converter; -+ -+ cutlass::multiplies multiply_op; -+ -+ return unary_op(source_converter(result)); -+ } -+ -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { -+ if (unary_op.guard()) visitor_op.end_row(row_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ if (unary_op.guard()) visitor_op.end_step(step_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void end_epilogue() { -+ if (unary_op.guard()) visitor_op.end_epilogue(); -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_with_layernorm.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_with_layernorm.h -new file mode 100644 -index 0000000..54936ff ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_with_layernorm.h -@@ -0,0 +1,480 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Epilogue visitor type used for partial computation of a layernorm operation -+ -+ GemmLayernorm example = GEMM0 with partial reduction fused in epilogue (EpilogueVisitorLayerNorm) -+ + lightweight full reduction kernel (ApplyFinalReduction) -+ + GEMM1 with elementwise operations fused in mainloop (GemmLayernormMainloopFusion) -+*/ -+ -+#pragma once -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/kernel/default_gemm_complex.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ThreadblockShape_, -+ int ThreadCount, -+ typename OutputTileIterator_, -+ typename AccumulatorTile_, -+ typename ElementAccumulator_, -+ typename ElementVariance_, -+ typename ElementMean_, -+ typename ElementLayernormCompute_, -+ typename ElementwiseFunctor_, -+ bool IsShiftedVariance_ = false -+> -+class EpilogueVisitorLayerNorm { -+public: -+ -+ using ElementVariance = ElementVariance_; -+ using ElementMean = ElementMean_; -+ using ElementLayernormCompute = ElementLayernormCompute_; -+ -+ using AccumulatorTile = AccumulatorTile_; -+ -+ using ThreadblockShape = ThreadblockShape_; -+ static int const kThreadCount = ThreadCount; -+ -+ using OutputTileIterator = OutputTileIterator_; -+ using ElementwiseFunctor = ElementwiseFunctor_; -+ -+ static int const kIterations = OutputTileIterator::kIterations; -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ static int const kRowIterations = OutputTileIterator::ThreadMap::Iterations::kRow; -+ -+ static int const kThreads = OutputTileIterator::ThreadMap::kThreads; -+ -+ static bool const kIsShiftedVariance = IsShiftedVariance_; -+ -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ static int const kDeltaRow = 
OutputTileIterator::ThreadMap::Delta::kRow; -+ -+ /// Array type used in Shift-K Layernorm -+ static int const kRowAccessCount = kIterations * kRowIterations; -+ -+ using ConvertedShiftFragment = Array; -+ -+ // Conducts manual transpose externally (already supported) for column major -+ using LayoutOutput = cutlass::layout::RowMajor; -+ -+ using ElementAccumulator = ElementAccumulator_; -+ -+ using AccumulatorFragment = Array; -+ using LayernormFragment = Array; -+ using OutputVector = Array; -+ using TensorRefD = TensorRef; -+ -+ static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth; -+ static int const kThreadsInColumn = kThreads / kThreadsPerRow; -+ static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1); -+ -+ /// Argument structure -+ struct Arguments { -+ -+ typename ElementwiseFunctor::Params elementwise; -+ ElementVariance *ptr_Variance; -+ ElementMean *ptr_Mean; -+ ElementOutput *ptr_Shifted_K; -+ MatrixCoord extent; -+ -+ // -+ // Methods -+ // -+ Arguments(): -+ ptr_Variance(nullptr), -+ ptr_Mean(nullptr), -+ ptr_Shifted_K(nullptr) -+ { -+ -+ } -+ -+ Arguments( -+ typename ElementwiseFunctor::Params elementwise_, -+ ElementVariance *ptr_Variance, -+ ElementMean *ptr_Mean_, -+ ElementOutput *ptr_Shifted_K_ = nullptr, -+ MatrixCoord extent = MatrixCoord(0, 0) -+ ): -+ elementwise(elementwise_), -+ ptr_Variance(ptr_Variance), -+ ptr_Mean(ptr_Mean_), -+ ptr_Shifted_K(ptr_Shifted_K_), -+ extent(extent) -+ { -+ -+ } -+ }; -+ -+ struct Params { -+ -+ typename ElementwiseFunctor::Params elementwise; -+ ElementVariance *ptr_Variance; -+ ElementMean *ptr_Mean; -+ ElementOutput *ptr_Shifted_K; -+ MatrixCoord extent; -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params(): -+ ptr_Variance(nullptr), -+ ptr_Mean(nullptr) -+ { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ elementwise(args.elementwise), -+ ptr_Variance(args.ptr_Variance), -+ ptr_Mean(args.ptr_Mean), -+ ptr_Shifted_K(args.ptr_Shifted_K), -+ extent(args.extent) -+ { -+ -+ } -+ }; -+ -+ /// Shared storage -+ struct SharedStorage { -+ -+ }; -+ -+private: -+ -+ Params const & params_; -+ SharedStorage & shared_storage_; -+ MatrixCoord extent_; -+ ElementwiseFunctor elementwise_; -+ -+ OutputTileIterator iterator_C_; -+ OutputTileIterator iterator_D_; -+ typename OutputTileIterator::Fragment fragment_C_; -+ typename OutputTileIterator::Fragment fragment_D_; -+ -+ ElementAccumulator alpha_; -+ ElementAccumulator beta_; -+ ConvertedShiftFragment shift_k_frag_; -+ -+ ElementLayernormCompute accum_sum_square_; -+ ElementLayernormCompute accum_sum_element_; -+ int thread_idx_; -+ -+ MatrixCoord thread_offset_; -+ -+ gemm::GemmCoord threadblock_tile_offset_; -+ -+public: -+ -+ CUTLASS_DEVICE -+ EpilogueVisitorLayerNorm( -+ Params const ¶ms, ///< Parameters routed to the epilogue -+ SharedStorage &shared_storage, ///< Shared storage needed by the functors here -+ MatrixCoord threadblock_offset, -+ gemm::GemmCoord threadblock_tile_offset, -+ int thread_idx, -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ OutputTileIterator source_iterator ///< Threadblock tile coordinate in GEMMM -+ ): -+ params_(params), -+ shared_storage_(shared_storage), -+ elementwise_(params.elementwise), -+ extent_(params.extent), -+ iterator_C_(source_iterator), -+ iterator_D_(destination_iterator), -+ threadblock_tile_offset_(threadblock_tile_offset), -+ thread_idx_(thread_idx) -+ { -+ alpha_ = (params.elementwise.alpha_ptr ? 
*params.elementwise.alpha_ptr : params.elementwise.alpha); -+ beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta); -+ -+ if (beta_ == ElementAccumulator()) { -+ iterator_C_.clear_mask(); -+ } -+ } -+ -+ /// Helper to indicate split-K behavior -+ CUTLASS_DEVICE -+ void set_k_partition( -+ int split_k_index, ///< Index of this threadblock within split-K partitioned scheme -+ int split_k_slices) { ///< Total number of split-K slices -+ -+ } -+ -+ /// Called to set the batch index -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ -+ } -+ -+ /// Called at the start of the epilogue just before iterating over accumulator slices -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ -+ // If shift-K feature is enabled, we load shift-k fragment -+ // at the very beginning of an epilogue -+ if (kIsShiftedVariance && params_.ptr_Shifted_K != nullptr) { -+ shift_k_frag_.clear(); -+ int thread_offset_row_base = iterator_D_.thread_start_row(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int iter_idx = 0; iter_idx < kIterations; ++iter_idx) { -+ int step_offset = iter_idx * OutputTileIterator::Shape::kRow; -+ CUTLASS_PRAGMA_UNROLL -+ for (int rid = 0; rid < kRowIterations; ++rid) { -+ int row_step_offset = rid * kDeltaRow; -+ int row_offset = thread_offset_row_base + step_offset + row_step_offset; -+ bool is_load = (row_offset < extent_.row()); -+ shift_k_frag_[iter_idx * kRowIterations + rid] = load_shift_k_(row_offset, is_load); -+ } -+ -+ } -+ -+ } -+ -+ } -+ -+ /// Called at the start of one step before starting accumulator exchange -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ fragment_D_.clear(); -+ -+ if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { -+ fragment_C_.clear(); -+ iterator_C_.load(fragment_C_); -+ ++iterator_C_; -+ } -+ } -+ -+ /// Called at the start of a row -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { -+ /// set the accumulator to 0 -+ accum_sum_element_ = ElementLayernormCompute(0); -+ accum_sum_square_ = ElementLayernormCompute(0); -+ } -+ -+ /// Called after accumulators have been exchanged for each accumulator vector -+ CUTLASS_DEVICE -+ void visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorFragment const &accum) { -+ -+ using Mul = cutlass::multiplies; -+ using Minus = cutlass::minus; -+ using Exp = cutlass::fast_exp_op; -+ -+ Minus minus; -+ Mul mul; -+ Exp exponential; -+ -+ LayernormFragment result; -+ -+ thread_offset_ = -+ iterator_D_.thread_start() + -+ OutputTileIterator::ThreadMap::iteration_offset(frag_idx); -+ -+ NumericArrayConverter source_converter; -+ OutputVector &source_vector = reinterpret_cast(&fragment_C_)[frag_idx]; -+ -+ bool column_guard = (thread_offset_.column() < extent_.column()); -+ -+ if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { -+ result = source_converter(elementwise_(accum)); -+ }else{ -+ result = source_converter(elementwise_(accum, source_vector)); -+ } -+ -+ -+ ElementLayernormCompute inv_scalar = cutlass::constants::one() / ElementLayernormCompute(extent_.column()); -+ -+ // Fragment is cleared for non-reachable columns so no need to check against column guard -+ ElementLayernormCompute accum_sum_element_tmp = element_sum_accumulator_(result); -+ -+ // Square sum is different. Non-reachable columns should've been computed for shift-k -+ // Otherwise we will incorrectly have some extra k^2 added into square sum. 
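// The loop a few lines below is a butterfly (XOR) warp reduction: after
// log2(kThreadsPerRow) exchanges, every lane in a row group holds the row's partial sum.
// A standalone sketch of the same pattern (illustrative helper; assumes threads_per_row
// is a power of two, at most the warp size, and that each row group occupies contiguous
// lanes of one warp):
__device__ inline float butterfly_row_sum(float value, int threads_per_row) {
  for (int i = threads_per_row / 2; i > 0; i >>= 1) {
    value += __shfl_xor_sync(0xFFFFFFFF, value, i);
  }
  return value;  // identical on every lane of the row group
}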
-+ ElementLayernormCompute accum_sum_square_tmp = ElementLayernormCompute(0); -+ -+ if (column_guard) { -+ accum_sum_square_tmp = (kIsShiftedVariance) ? \ -+ square_sum_accumulator_(result, shift_k_frag_[iter_idx * kRowIterations + row_idx]) : \ -+ square_sum_accumulator_(result); -+ } -+ -+ accum_sum_element_tmp *= inv_scalar; -+ accum_sum_square_tmp *= inv_scalar; -+ -+ // After performing the in-thread reduction, we then perform cross-thread / in-warp reduction -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = kHalfThreadsPerRow; i > 0; i >>= 1) { -+ accum_sum_element_tmp += __shfl_xor_sync(0xFFFFFFFF, accum_sum_element_tmp, i); -+ accum_sum_square_tmp += __shfl_xor_sync(0xFFFFFFFF, accum_sum_square_tmp, i); -+ } -+ accum_sum_element_ += accum_sum_element_tmp; -+ accum_sum_square_ += accum_sum_square_tmp; -+ -+ // Convert to the output -+ NumericArrayConverter output_converter; -+ OutputVector &output = reinterpret_cast(&fragment_D_)[frag_idx]; -+ output = output_converter(result); -+ } -+ -+ /// Called at the start of a row -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { -+ -+ using ConvertVarianceOutput = cutlass::NumericConverter; -+ using ConvertMeanOutput = cutlass::NumericConverter; -+ -+ ConvertVarianceOutput convert_variance_output; -+ ConvertMeanOutput convert_mean_output; -+ -+ bool is_write_thread = (thread_offset_.row() < extent_.row() && (threadIdx.x % kThreadsPerRow) == 0); -+ int row_offset = thread_offset_.row() + threadblock_tile_offset_.n() * extent_.row(); -+ -+ ElementVariance *curr_ptr_sum_square = params_.ptr_Variance + row_offset; -+ ElementMean *curr_ptr_element_sum = params_.ptr_Mean + row_offset; -+ -+ arch::global_store( -+ convert_variance_output(accum_sum_square_), -+ (void *)curr_ptr_sum_square, -+ is_write_thread); -+ -+ arch::global_store( -+ convert_mean_output(accum_sum_element_), -+ (void *)curr_ptr_element_sum, -+ is_write_thread); -+ } -+ -+ /// Called after all accumulator elements have been visited -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ -+ iterator_D_.store(fragment_D_); -+ ++iterator_D_; -+ } -+ -+ /// Called after all steps have been completed -+ CUTLASS_DEVICE -+ void end_epilogue() { -+ -+ } -+ -+private: -+ -+ CUTLASS_DEVICE -+ ElementLayernormCompute load_shift_k_(int row_offset, bool is_load) { -+ using ConvertShiftK = cutlass::NumericConverter; -+ ConvertShiftK convert_shift_k; -+ ElementOutput shift_k_val; -+ -+ // Computes the address to load shift_k element -+ ElementOutput *curr_ptr_shift_k = params_.ptr_Shifted_K + row_offset; -+ // Conditionally loads from global memory -+ arch::global_load(shift_k_val, (void *)curr_ptr_shift_k, is_load); -+ // Converts data type to return -+ ElementLayernormCompute converted_shift_k_val = convert_shift_k(shift_k_val); -+ -+ return converted_shift_k_val; -+ } -+ -+ CUTLASS_DEVICE -+ ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum) { -+ ElementLayernormCompute sum_ = ElementLayernormCompute(0); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < LayernormFragment::kElements; ++i) { -+ auto accum_ = accum[i]; -+ sum_ += accum_ * accum_; -+ } -+ -+ return sum_; -+ } -+ -+ CUTLASS_DEVICE -+ ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum, ElementLayernormCompute shift_k_val) { -+ ElementLayernormCompute sum_ = ElementLayernormCompute(0); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < LayernormFragment::kElements; ++i) { -+ auto accum_ = accum[i] - shift_k_val; -+ sum_ += accum_ * accum_; -+ } -+ -+ return sum_; -+ } -+ -+ 
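// Why the shifted overload above exists: for any constant k,
//   Var(x) = E[(x - k)^2] - (E[x] - k)^2,
// so accumulating (x - k)^2 with k close to the row mean (the value loaded by
// load_shift_k_) keeps the squared terms small, whereas the plain E[x^2] - E[x]^2 form
// can cancel badly for rows of large magnitude. A sketch of how such partials can be
// combined once they have been summed across threadblocks (illustrative helper, not the
// companion final-reduction kernel itself):
CUTLASS_HOST_DEVICE
float finalize_shifted_variance(float mean_of_shifted_squares,  // E[(x - k)^2]
                                float mean,                      // E[x]
                                float shift_k) {                 // k
  float centered = mean - shift_k;                               // E[x] - k
  return mean_of_shifted_squares - centered * centered;          // Var(x)
}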
CUTLASS_DEVICE -+ ElementLayernormCompute element_sum_accumulator_(LayernormFragment const &accum) { -+ ElementLayernormCompute sum_ = ElementLayernormCompute(0); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < LayernormFragment::kElements; ++i) { -+ sum_ += accum[i]; -+ } -+ -+ return sum_; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm.h -new file mode 100644 -index 0000000..36987b5 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm.h -@@ -0,0 +1,77 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
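// Host-side reference for the per-row quantities EpilogueVisitorLayerNorm above produces
// when shift-K is disabled: ptr_Mean receives sum(x)/N and ptr_Variance receives
// sum(x^2)/N, each as a partial per threadblock column that the final-reduction kernel
// mentioned in that file's comment later combines. A plain single-pass sketch over a
// row-major M x N matrix, names illustrative:
template <typename TIn, typename TCompute, typename TOut>
void reference_row_moments(TIn const *src, TOut *mean, TOut *mean_sq, int M, int N) {
  for (int i = 0; i < M; ++i) {
    TCompute sum = TCompute(0);
    TCompute sum_sq = TCompute(0);
    for (int j = 0; j < N; ++j) {
      TCompute x = TCompute(src[i * N + j]);
      sum += x;
      sum_sq += x * x;
    }
    mean[i] = TOut(sum / TCompute(N));        // corresponds to what end_row() writes to ptr_Mean
    mean_sq[i] = TOut(sum_sq / TCompute(N));  // corresponds to what end_row() writes to ptr_Variance
  }
}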
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind gemm related enum types to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "cutlass/gemm/gemm.h" -+#include "host.h" -+ -+namespace py = pybind11; -+ -+void bind_gemm(py::module &m) { -+ // -+ // Enumerate types -+ // cutlass/gemm/gemm.h -+ -+ py::enum_(m, "Mode") -+ .value("Gemm", cutlass::gemm::GemmUniversalMode::kGemm, "Ordinary GEMM & GEMM Split-K serial") -+ .value("GemmSplitKParallel", cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, "GEMM Split-K parallel") -+ .value("Batched", cutlass::gemm::GemmUniversalMode::kBatched, "Batched GEMM") -+ .value("Array", cutlass::gemm::GemmUniversalMode::kArray) -+ .value("Invalid", cutlass::gemm::GemmUniversalMode::kInvalid); -+ -+ /// GemmCoord is a structure that specifies a location within the coordiate space of a GEMM problem -+ py::class_(m, "GemmCoord") -+ .def(py::init()) -+ .def("m", py::overload_cast<>(&cutlass::gemm::GemmCoord::m)) -+ .def("n", py::overload_cast<>(&cutlass::gemm::GemmCoord::n)) -+ .def("k", py::overload_cast<>(&cutlass::gemm::GemmCoord::k)) -+ // get tensor coords -+ .def("mk", -+ [](const cutlass::gemm::GemmCoord & problem_size) { -+ return cutlass::MatrixCoord(problem_size.mk()); -+ }) -+ .def("kn", -+ [](const cutlass::gemm::GemmCoord & problem_size) { -+ return cutlass::MatrixCoord(problem_size.kn()); -+ }) -+ .def("mn", -+ [](const cutlass::gemm::GemmCoord & problem_size) { -+ return cutlass::MatrixCoord(problem_size.mn()); -+ }); -+ -+ py::module_ host_submodule = m.def_submodule("host"); -+ bind_gemm_host_helper(host_submodule); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm_universal_with_visitor.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm_universal_with_visitor.h -new file mode 100644 -index 0000000..64b65a0 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm_universal_with_visitor.h -@@ -0,0 +1,628 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/params_universal_base.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+ -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/trace.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function -+> -+struct GemmUniversalwithEpilogueVisitor { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueVisitor = typename Epilogue::Visitor; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename EpilogueVisitor::ElementOutput; -+ using LayoutC = typename EpilogueVisitor::OutputTileIterator::Layout; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ using Operator = typename Mma::Operator; -+ -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = const_max( -+ 128 / sizeof_bits::value, -+ 128 / sizeof_bits::value -+ ); -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments : UniversalArgumentsBase { -+ -+ // -+ // Data members -+ // -+ -+ typename EpilogueVisitor::Arguments epilogue_visitor; -+ -+ void const * ptr_A; -+ void const * ptr_B; -+ void const * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ -+ typename LayoutA::Stride stride_a; -+ typename LayoutB::Stride 
stride_b; -+ typename LayoutC::Stride stride_c; -+ typename LayoutC::Stride stride_d; -+ -+ typename LayoutA::Stride::LongIndex lda; -+ typename LayoutB::Stride::LongIndex ldb; -+ typename LayoutC::Stride::LongIndex ldc; -+ typename LayoutC::Stride::LongIndex ldd; -+ -+ int const * ptr_gather_A_indices; -+ int const * ptr_gather_B_indices; -+ int const * ptr_scatter_D_indices; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), -+ ptr_gather_A_indices(nullptr), -+ ptr_gather_B_indices(nullptr), -+ ptr_scatter_D_indices(nullptr) {} -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueVisitor::Arguments epilogue_visitor, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride stride_a, -+ typename LayoutB::Stride stride_b, -+ typename LayoutC::Stride stride_c, -+ typename LayoutC::Stride stride_d, -+ int const *ptr_gather_A_indices = nullptr, -+ int const *ptr_gather_B_indices = nullptr, -+ int const *ptr_scatter_D_indices = nullptr -+ ): -+ UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), -+ epilogue_visitor(epilogue_visitor), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), -+ batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), -+ stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), -+ ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), -+ ptr_scatter_D_indices(ptr_scatter_D_indices) { -+ lda = 0; -+ ldb = 0; -+ ldc = 0; -+ ldd = 0; -+ CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); -+ } -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueVisitor::Arguments epilogue_visitor, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride::LongIndex lda, -+ typename LayoutB::Stride::LongIndex ldb, -+ typename LayoutC::Stride::LongIndex ldc, -+ typename LayoutC::Stride::LongIndex ldd, -+ int const *ptr_gather_A_indices = nullptr, -+ int const *ptr_gather_B_indices = nullptr, -+ int const *ptr_scatter_D_indices = nullptr -+ ): -+ UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), -+ epilogue_visitor(epilogue_visitor), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), -+ batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), -+ lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), -+ ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), -+ ptr_scatter_D_indices(ptr_scatter_D_indices) { -+ stride_a = make_Coord(lda); -+ stride_b = make_Coord(ldb); -+ stride_c = make_Coord(ldc); -+ stride_d = make_Coord(ldd); -+ CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); -+ } -+ -+ /// Returns arguments for the transposed problem -+ Arguments transposed_problem() const { -+ Arguments args(*this); -+ -+ std::swap(args.problem_size.m(), args.problem_size.n()); -+ std::swap(args.ptr_A, args.ptr_B); -+ 
std::swap(args.lda, args.ldb); -+ std::swap(args.stride_a, args.stride_b); -+ std::swap(args.batch_stride_A, args.batch_stride_B); -+ std::swap(args.ptr_gather_A_indices, args.ptr_gather_B_indices); -+ -+ return args; -+ } -+ }; -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params : UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC> { -+ -+ using ParamsBase = UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC>; -+ -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorB::Params params_B; -+ typename EpilogueVisitor::OutputTileIterator::Params params_C; -+ typename EpilogueVisitor::OutputTileIterator::Params params_D; -+ -+ typename EpilogueVisitor::Params epilogue_visitor; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ -+ int * ptr_gather_A_indices; -+ int * ptr_gather_B_indices; -+ int * ptr_scatter_D_indices; -+ -+ int *semaphore; -+ -+ // -+ // Methods -+ // -+ -+ /// Default constructor -+ Params() = default; -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Arguments const &args, -+ int device_sms, -+ int sm_occupancy -+ ): -+ ParamsBase(args, device_sms, sm_occupancy), -+ params_A(args.lda ? make_Coord_with_padding(args.lda) : args.stride_a), -+ params_B(args.ldb ? make_Coord_with_padding(args.ldb) : args.stride_b), -+ params_C(args.ldc ? make_Coord_with_padding(args.ldc) : args.stride_c), -+ params_D(args.ldd ? make_Coord_with_padding(args.ldd) : args.stride_d), -+ epilogue_visitor(args.epilogue_visitor), -+ ptr_A(const_cast(args.ptr_A)), -+ ptr_B(const_cast(args.ptr_B)), -+ ptr_C(const_cast(args.ptr_C)), -+ ptr_D(args.ptr_D), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ batch_stride_C(args.batch_stride_C), -+ ptr_gather_A_indices(const_cast(args.ptr_gather_A_indices)), -+ ptr_gather_B_indices(const_cast(args.ptr_gather_B_indices)), -+ ptr_scatter_D_indices(const_cast(args.ptr_scatter_D_indices)) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void update( -+ Arguments const &args, -+ void *workspace = nullptr) { -+ -+ ptr_A = const_cast(args.ptr_A); -+ ptr_B = const_cast(args.ptr_B); -+ ptr_C = const_cast(args.ptr_C); -+ ptr_D = args.ptr_D; -+ -+ ptr_gather_A_indices = const_cast(args.ptr_gather_A_indices); -+ ptr_gather_B_indices = const_cast(args.ptr_gather_B_indices); -+ ptr_scatter_D_indices = const_cast(args.ptr_scatter_D_indices); -+ -+ batch_stride_A = args.batch_stride_A; -+ batch_stride_B = args.batch_stride_B; -+ batch_stride_C = args.batch_stride_C; -+ -+ epilogue_visitor = args.epilogue_visitor; -+ -+ semaphore = static_cast(workspace); -+ CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ typename EpilogueVisitor::SharedStorage visitor; -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ GemmUniversalwithEpilogueVisitor() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) { -+ -+ CUTLASS_TRACE_HOST("GemmUniversalwithEpilogueVisitor::can_implement()"); -+ -+ static int const kAlignmentA = (platform::is_same>::value) -+ ? 
32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ bool isAMisaligned = false; -+ bool isBMisaligned = false; -+ bool isCMisaligned = false; -+ -+ if (platform::is_same::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } else if (platform::is_same::value) { -+ isAMisaligned = problem_size.m() % kAlignmentA; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } -+ -+ if (platform::is_same::value) { -+ isBMisaligned = problem_size.n() % kAlignmentB; -+ } else if (platform::is_same::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } -+ -+ if (platform::is_same::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } else if (platform::is_same::value) { -+ isCMisaligned = problem_size.m() % kAlignmentC; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } -+ -+ if (isAMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isBMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isCMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ CUTLASS_TRACE_HOST(" returning kSuccess"); -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ // -+ // Fetch pointers based on mode. 
-+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; -+ ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; -+ ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; -+ } -+ -+ __syncthreads(); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A, -+ params.ptr_gather_A_indices); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B, -+ params.ptr_gather_B_indices); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ // EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C); -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // Tile iterator loading from source tensor. 
-+ -+ EpilogueVisitor epilogue_visitor( -+ params.epilogue_visitor, -+ shared_storage.visitor, -+ threadblock_offset, -+ threadblock_tile_offset, -+ thread_idx, -+ params.problem_size.mn() -+ ); -+ -+ if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) { -+ epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); -+ } -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ semaphore.wait(threadblock_tile_offset.k()); -+ } -+ -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(epilogue_visitor, accumulators); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/gemm/host.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/gemm/host.h -new file mode 100644 -index 0000000..3a6a587 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/gemm/host.h -@@ -0,0 +1,47 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind gemm host helpers to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "cutlass/util/host_reorder.h" -+#include "cutlass/layout/tensor.h" -+ -+namespace py = pybind11; -+ -+ -+void bind_gemm_host_helper(py::module &m) { -+ m.def("reorder_column", &cutlass::reorder_column<32, int8_t, cutlass::layout::RowMajorInterleaved<32>>); -+ m.def("reorder_column", &cutlass::reorder_column<32, int8_t, cutlass::layout::ColumnMajorInterleaved<32>>); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/layout/layout.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/layout/layout.h -new file mode 100644 -index 0000000..5968bc0 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/layout/layout.h -@@ -0,0 +1,47 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind CUTLASS layouts to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "tensor.h" -+#include "matrix.h" -+ -+ -+namespace py = pybind11; -+ -+void bind_layout(py::module &m) { -+ bind_tensor_layout(m); -+ bind_matrix_layout(m); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/layout/matrix.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/layout/matrix.h -new file mode 100644 -index 0000000..f19e04e ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/layout/matrix.h -@@ -0,0 +1,87 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind Matrix layouts to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "cutlass/layout/matrix.h" -+ -+namespace py = pybind11; -+ -+void bind_matrix_layout(py::module &m) { -+ // -+ // Matrix layouts -+ // cutlass/layout/matrix.h -+ // -+ -+ py::class_(m, "RowMajor", R"pbdoc( -+ Mapping function for row-major matrices. -+ )pbdoc") -+ .def_static("packed", &cutlass::layout::RowMajor::packed, -+ py::arg("extent"), -+ R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc") -+ .def("stride", [](const cutlass::layout::RowMajor & layout){ -+ return layout.stride().at(0); -+ }, R"pbdoc(Returns the stride of the layout)pbdoc"); -+ -+ py::class_(m, "ColumnMajor", R"pbdoc( -+ Mapping function for column-major matrices. 
-+ )pbdoc") -+ .def_static("packed", &cutlass::layout::ColumnMajor::packed, -+ py::arg("extent"), -+ R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc" ) -+ .def("stride", [](const cutlass::layout::ColumnMajor & layout){ -+ return layout.stride().at(0); -+ }, R"pbdoc(Returns the stride of the layout)pbdoc"); -+ -+ py::class_>(m, "RowMajorInterleaved32", -+ R"pbdoc(Mapping function for interleaved matrices. Matrix is structured -+ as row-major arrangement of fixed-size columns 32)pbdoc") -+ .def_static("packed", &cutlass::layout::RowMajorInterleaved<32>::packed, -+ py::arg("extent"), -+ R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc") -+ .def("stride", [](const cutlass::layout::RowMajorInterleaved<32> & layout){ -+ return layout.stride().at(0); -+ }, R"pbdoc(Returns the stride of the layout)pbdoc"); -+ -+ py::class_>(m, "ColumnMajorInterleaved32", -+ R"pbdoc(Mapping function for interleaved matrices. Matrix is structured -+ as column-major arrangement of fixed-size rows 32)pbdoc") -+ .def_static("packed", &cutlass::layout::ColumnMajorInterleaved<32>::packed, -+ py::arg("extent"), -+ R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc") -+ .def("stride", [](const cutlass::layout::ColumnMajorInterleaved<32> & layout){ -+ return layout.stride().at(0); -+ }, R"pbdoc(Returns the stride of the layout)pbdoc"); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/layout/tensor.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/layout/tensor.h -new file mode 100644 -index 0000000..5edb100 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/layout/tensor.h -@@ -0,0 +1,74 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind Tensor layouts to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "cutlass/layout/tensor.h" -+ -+namespace py = pybind11; -+ -+void bind_tensor_layout(py::module &m) { -+ // -+ // Tensor layouts -+ // cutlass/include/cutlass/layout/tensor.h -+ // -+ -+ /// Mapping function for 4-D NHWC tensors. -+ py::class_(m, "TensorNHWC", -+ R"pbdoc(Mapping function for 4-D NHWC tensors)pbdoc") -+ .def_static("packed", &cutlass::layout::TensorNHWC::packed, -+ py::arg("extent"), -+ R"pbdoc(Helper returns a layout to a tightly packed NHWC tensor)pbdoc") -+ .def("stride", py::overload_cast<>(&cutlass::layout::TensorNHWC::stride), -+ R"pbdoc(Returns the stride of the layout)pbdoc"); -+ -+ /// Mapping function for 4-D NC/xHWx tensors. -+ py::class_>(m, "TensorNC32HW32", -+ R"pbdoc(Mapping function for 4-D NC/32HW32 tensors)pbdoc") -+ .def_static("packed", &cutlass::layout::TensorNCxHWx<32>::packed, -+ py::arg("extent"), -+ R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc") -+ .def("stride", py::overload_cast<>(&cutlass::layout::TensorNCxHWx<32>::stride), -+ R"pbdoc(Returns the stride of the layout)pbdoc"); -+ -+ /// Mapping function for 4-D CxRSKx tensors. -+ py::class_>(m, "TensorC32RSK32", -+ R"pbdoc(Mapping function for 4-D C32RSK32 tensors)pbdoc") -+ .def_static("packed", &cutlass::layout::TensorCxRSKx<32>::packed, -+ py::arg("extent"), -+ R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc") -+ .def("stride", py::overload_cast<>(&cutlass::layout::TensorCxRSKx<32>::stride), -+ R"pbdoc(Returns the stride of the layout)pbdoc"); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/swizzling.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/swizzling.h -new file mode 100644 -index 0000000..43991e4 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/swizzling.h -@@ -0,0 +1,159 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind threadblock swizzling to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/conv/threadblock/threadblock_swizzle.h" -+ -+#include -+#include -+ -+namespace py = pybind11; -+ -+std::string demangle(const char* mangled_name) { -+ std::size_t len = 0; -+ int status = 0; -+ std::unique_ptr ptr( -+ __cxxabiv1::__cxa_demangle(mangled_name, nullptr, &len, &status)); -+ return ptr.get(); -+} -+ -+template -+void bind_identity_swizzle(py::module & m, std::string name) { -+ py::class_(m, name.c_str(), -+ R"pbdoc(Threadblock swizzling function for GEMMs)pbdoc") -+ .def(py::init<>()) -+ .def("get_tiled_shape", -+ py::overload_cast( -+ &T::get_tiled_shape, py::const_ -+ ), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"), -+ R"pbdoc(Returns the shape of the problem in units of logical tiles -+ -+ :param problem_size: gemm(M, N, K) -+ :type problem_size: :class:`cutlass.gemm.GemmCoord` -+ )pbdoc") -+ .def("get_tiled_shape", -+ py::overload_cast( -+ &T::get_tiled_shape, py::const_ -+ ), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"), -+ R"pbdoc(Returns the shape of the problem in units of logical tiles -+ -+ :param problem_size: Implicit gemm problem size conv_operator(NPQK, NHWC, KRSC) -+ :type problem_size: :class:`cutlass.gemm.GemmCoord`) -+ )pbdoc") -+ .def("get_tiled_shape", -+ py::overload_cast( -+ &T::get_tiled_shape, py::const_ -+ ), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"), -+ R"pbdoc(Returns the shape of the problem in units of logical tiles -+ -+ :param problem_size: Implicit gemm problem size conv_operator(NZPQK, NDHWC, KTRSC) -+ :type problem_size: :class:`cutlass.gemm.GemmCoord`) -+ )pbdoc") -+ .def("get_grid_shape", &T::get_grid_shape, -+ py::arg("tiled_shape"), -+ R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc") -+ .def("tag", [](const T & swizzle){ -+ return demangle(typeid(T).name()); -+ }, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc"); -+} -+ -+template -+void bind_swizzle(py::module & m, std::string name, std::string doc) { -+ py::class_(m, name.c_str(), doc.c_str()) -+ .def(py::init<>()) -+ .def("get_tiled_shape", -+ py::overload_cast( -+ &T::get_tiled_shape, py::const_ -+ ), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"), -+ R"pbdoc(Returns the shape of the problem in units of logical tiles -+ -+ :param problem_size: gemm(M, N, K) -+ :type problem_size: :class:`cutlass.gemm.GemmCoord` -+ )pbdoc") -+ .def("get_grid_shape", &T::get_grid_shape, -+ py::arg("tiled_shape"), -+ R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc") -+ .def("tag", [](const T & swizzle){ -+ return 
demangle(typeid(T).name()); -+ }, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc"); -+} -+ -+template -+void bind_dgrad_swizzle(py::module & m, std::string name) { -+ py::class_(m, name.c_str(), -+ R"pbdoc(Threadblock swizzling function for strided dgrad convolution)pbdoc") -+ .def(py::init<>()) -+ .def("get_tiled_shape", -+ py::overload_cast( -+ &T::get_tiled_shape, py::const_ -+ ), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"), -+ R"pbdoc(Returns the shape of the problem in units of logical tiles -+ -+ :param problem_size: Implicit gemm problem size conv_operator(NPQK, NHWC, KRSC) -+ :type problem_size: :class:`cutlass.gemm.GemmCoord`) -+ )pbdoc") -+ .def("get_grid_shape", [](const T & swizzle, cutlass::gemm::GemmCoord tiled_shape) { -+ return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k()); -+ }, py::arg("tiled_shape"), -+ R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc") -+ .def("tag", [](const T & swizzle){ -+ return demangle(typeid(T).name()); -+ }, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc"); -+} -+ -+void bind_threadblock_swizzle(py::module &m) { -+ -+ py::class_(m, "dim3", -+ R"pbdoc(A int3 type xyz contains three integers)pbdoc") -+ .def(py::init(), -+ py::arg("x"), py::arg("y"), py::arg("z")) -+ .def_readwrite("x", &dim3::x, R"pbdoc(get value x)pbdoc") -+ .def_readwrite("y", &dim3::y, R"pbdoc(get value y)pbdoc") -+ .def_readwrite("z", &dim3::z, R"pbdoc(get value z)pbdoc"); -+ -+ bind_identity_swizzle>(m, "IdentitySwizzle1"); -+ bind_identity_swizzle>(m, "IdentitySwizzle2"); -+ bind_identity_swizzle>(m, "IdentitySwizzle4"); -+ bind_identity_swizzle>(m, "IdentitySwizzle8"); -+ -+ bind_swizzle(m, "HorizontalSwizzle", R"pbdoc(Threadblock swizzling function for GEMMs)pbdoc"); -+ bind_swizzle(m, "BatchedIdentitySwizzle", R"pbdoc(Threadblock swizzling function for batched GEMMs)pbdoc"); -+ -+ bind_dgrad_swizzle>(m, "StridedDgradIdentitySwizzle1"); -+ bind_dgrad_swizzle>(m, "StridedDgradIdentitySwizzle4"); -+ bind_dgrad_swizzle(m, "StridedDgradHorizontalSwizzle"); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/tensor_coord.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/tensor_coord.h -new file mode 100644 -index 0000000..547df07 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/tensor_coord.h -@@ -0,0 +1,78 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind Tensor Coord to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "cutlass/tensor_coord.h" -+ -+namespace py = pybind11; -+ -+void bind_tensor_coord(py::module &m) { -+ // -+ // Tensor Coords -+ // cutlass/include/cutlass/tensor_coord.h -+ // -+ -+ /// Defines a canonical 4D coordinate used by tensor operations. -+ py::class_(m, "Tensor4DCoord", -+ R"pbdoc(Defines a canonical 4D coordinate used by tensor operations)pbdoc") -+ .def(py::init(), -+ py::arg("n"), py::arg("h"), py::arg("w"), py::arg("c"), -+ R"pbdoc(Helper to construct from N, H, W, and C)pbdoc") -+ .def("at", py::overload_cast(&cutlass::Tensor4DCoord::at), -+ py::arg("dim"), -+ R"pbdoc(Gets the index of a given Coord element)pbdoc") -+ .def("size", [](const cutlass::Tensor4DCoord & coord) { -+ return coord.at(0) * coord.at(1) * coord.at(2) * coord.at(3);}, -+ R"pbdoc(The size of the tensor coord)pbdoc"); -+ -+ py::class_>(m, "Tensor3DCoord", -+ R"pbdoc(Defines a canonical 3D coordinate used by tensor operations)pbdoc") -+ .def("at", py::overload_cast(&cutlass::Coord<3>::at), -+ py::arg("dim"), -+ R"pbdoc(Gets the index of a given Coord element)pbdoc"); -+ -+ // Matrix Size -+ py::class_(m, "MatrixCoord", -+ R"pbdoc(MatrixCoord wraps Coord<2, int> to provide a helper for accessing named dimensions. Classes -+ expecting a coordinate in the rank=2 index space of a matrix should use MatrixCoord.)pbdoc") -+ .def(py::init(), -+ py::arg("row"), py::arg("column"), R"pbdoc(Helper to construct from a row and column)pbdoc") -+ .def("row", py::overload_cast<>(&cutlass::MatrixCoord::row), -+ R"pbdoc(Returns the row of the coordinate)pbdoc") -+ .def("column", py::overload_cast<>(&cutlass::MatrixCoord::column), -+ R"pbdoc(Returns the column of the coordinate)pbdoc"); -+ -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/tensor_ref_view.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/tensor_ref_view.h -new file mode 100644 -index 0000000..09a4add ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/tensor_ref_view.h -@@ -0,0 +1,102 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSE -+#include -+ -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "types.h" -+ -+ -+template -+void bind_tensor_ref_view(py::module &m, std::string name) { -+ py::class_>(m, ("TensorRef" + name).c_str()) -+ .def("__init__", [](cutlass::TensorRef& tensor_ref, int64_t address, const L& layout_ ) { -+ T* ptr = reinterpret_cast< T*>(address); -+ new (&tensor_ref) cutlass::TensorRef(ptr, layout_); -+ }) -+ .def("data", [](cutlass::TensorRef& tensor_ref) { -+ T* ptr = tensor_ref.data(); -+ return int64_t(ptr); -+ }) -+ .def("layout", py::overload_cast<>(&cutlass::TensorRef::layout)); -+ -+ m.def("get_tensor_ref", [](int64_t address, TF data, const L& layout_) { -+ T* ptr = reinterpret_cast(address); -+ cutlass::TensorRef tensor_ref = cutlass::TensorRef(ptr, layout_); -+ return tensor_ref; -+ }); -+ -+ py::class_>(m, ("TensorView" + name).c_str()) -+ .def(py::init&, const typename L::TensorCoord &>()); -+} -+ -+ -+void bind_tensor_refs_and_views(py::module &m) { -+ -+ /// float -+ bind_tensor_ref_view(m, "F32RowMajor"); -+ bind_tensor_ref_view(m, "F32ColumnMajor"); -+ bind_tensor_ref_view(m, "F32NHWC"); -+ -+ /// double -+ bind_tensor_ref_view(m, "F64RowMajor"); -+ bind_tensor_ref_view(m, "F64ColumnMajor"); -+ bind_tensor_ref_view(m, "F64NHWC"); -+ -+ // half_t -+ bind_tensor_ref_view(m, "F16RowMajor"); -+ bind_tensor_ref_view(m, "F16ColumnMajor"); -+ bind_tensor_ref_view(m, "F16NHWC"); -+ -+ // bfloat16 -+ bind_tensor_ref_view(m, "BF16RowMajor"); -+ bind_tensor_ref_view(m, "BF16ColumnMajor"); -+ bind_tensor_ref_view(m, "BF16NHWC"); -+ -+ // int8_t -+ bind_tensor_ref_view, cutlass::int8>(m, "S8RowMajorInterleaved32"); -+ bind_tensor_ref_view, cutlass::int8>(m, "S8ColumnMajorInterleaved32"); -+ bind_tensor_ref_view(m, "S8RowMajor"); -+ bind_tensor_ref_view(m, "S8ColumnMajor"); -+ bind_tensor_ref_view(m, "S8NHWC"); -+ bind_tensor_ref_view, cutlass::int8>(m, "S8NC32HW32"); -+ bind_tensor_ref_view, cutlass::int8>(m, "S8C32RSK32"); -+ -+ // int32_t -+ bind_tensor_ref_view(m, "S32RowMajor"); -+ bind_tensor_ref_view(m, "S32ColumnMajor"); -+ bind_tensor_ref_view(m, "S32NHWC"); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/types.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/types.h -new file mode 100644 -index 0000000..da16696 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/types.h -@@ -0,0 +1,146 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 
2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind CUTLASS types to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "cutlass/half.h" -+ -+ -+namespace py = pybind11; -+ -+namespace cutlass { -+ -+/// IEEE 32-bit signed integer -+struct alignas(1) int8 { -+ int8_t storage; -+ explicit int8(int x) { -+ storage = int8_t(x); -+ } -+ explicit int8(float x) { -+ storage = int8_t(x); -+ } -+ -+ int8_t c_value(){return storage;} -+}; -+ -+/// IEEE 32-bit signed integer -+struct alignas(4) int32 { -+ int storage; -+ explicit int32(int x) { -+ storage = x; -+ } -+ explicit int32(float x) { -+ storage = int(x); -+ } -+ -+ int c_value(){return storage;} -+}; -+/// IEEE single-precision floating-point type -+struct alignas(4) float32 { -+ float storage; -+ explicit float32(float x) { -+ storage = x; -+ } -+ explicit float32(int x) { -+ storage = float(x); -+ } -+ float c_value(){return storage;} -+}; -+/// IEEE double-precision floating-point type -+struct alignas(4) float64 { -+ double storage; -+ explicit float64(float x) { -+ storage = double(x); -+ } -+ explicit float64(int x) { -+ storage = double(x); -+ } -+ double c_value(){return storage;} -+}; -+} -+ -+void bind_cutlass_types(py::module &m) { -+ -+ // s8 -+ py::class_(m, "int8") -+ .def(py::init()) -+ .def(py::init()) -+ .def_readwrite("storage", &cutlass::int8::storage) -+ .def("value", &cutlass::int8::c_value); -+ -+ // s32 -+ py::class_(m, "int32") -+ .def(py::init()) -+ .def(py::init()) -+ .def_readwrite("storage", &cutlass::int32::storage) -+ .def("value", &cutlass::int32::c_value); -+ -+ // f16 -+ py::class_(m, "float16") -+ .def(py::init()) -+ .def(py::init()) -+ .def(py::init()) -+ .def(py::init()) -+ .def_readwrite("storage", &cutlass::half_t::storage) -+ .def("value", [](const cutlass::half_t& value) {return value;}); 
-+ -+ // bf16 -+ py::class_(m, "bfloat16") -+ .def(py::init()) -+ .def(py::init()) -+ .def_readwrite("storage", &cutlass::bfloat16_t::storage) -+ .def("value", [](const cutlass::bfloat16_t& value) {return value;}); -+ -+ // f32 -+ py::class_(m, "float32") -+ .def(py::init()) -+ .def(py::init()) -+ .def_readwrite("storage", &cutlass::float32::storage) -+ .def("value", &cutlass::float32::c_value); -+ -+ // tf32 -+ py::class_(m, "tfloat32") -+ .def(py::init()) -+ .def(py::init()) -+ .def_readwrite("storage", &cutlass::tfloat32_t::storage) -+ .def("value", [](const cutlass::tfloat32_t& value) {return value;}); -+ -+ // f64 -+ py::class_(m, "float64") -+ .def(py::init()) -+ .def(py::init()) -+ .def_readwrite("storage", &cutlass::float64::storage) -+ .def("value", &cutlass::float64::c_value); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/library.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/library.h -new file mode 100644 -index 0000000..5d46f69 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/library.h -@@ -0,0 +1,32 @@ -+#include -+ -+namespace cutlass { -+ -+/// ENUM class for datatypes -+enum class DataType { -+ kB1, kU2, kU4, kU8, -+ kU16, kU32, kU64, kS2, -+ kS4, kS8, kS16, kS32, -+ kS64, kF16, kBF16, kF32, -+ kTF32, kF64, kCF16, kCBF16, -+ kCF32, kCTF32, kCF64, kCS2, -+ kCS4, kCS8, kCS16, kCS32, -+ kCS64, kCU2, kCU4, kCU8, -+ kCU16, kCU32, kCU64, kInvalid -+}; -+ -+/// ENUM class for LayoutTypes -+enum class LayoutType { -+ kColumnMajor, kRowMajor, -+ kColumnMajorInterleaved2, kRowMajorInterleaved2, -+ kColumnMajorInterleaved32, kRowMajorInterleaved32, -+ kColumnMajorInterleaved64, kRowMajorInterleaved64, -+ kTensorNHWC, kTensorNDHWC, kTensorNCHW, kTensorNGHWC, -+ kTensorNC32HW32, kTensorNC64HW64, kTensorC32RSK32, -+ kTensorC64RSK64 -+}; -+ -+/// ENUM class for opcode class -+ -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/conv/conv_problems.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/conv/conv_problems.h -new file mode 100644 -index 0000000..f2c8ec8 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/conv/conv_problems.h -@@ -0,0 +1,54 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind convolution problems to python -+*/ -+#pragma once -+#include -+#include -+ -+ -+#include "unit/conv/device/conv2d_problems.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+namespace py = pybind11; -+ -+PYBIND11_MAKE_OPAQUE(std::vector); -+ -+void bind_conv_problem_size_test(py::module &m) { -+ -+ py::bind_vector>(m, "Conv2dProblemVector") -+ .def("size", &std::vector::size); -+ // Get Conv2d problem sizes -+ py::class_(m, "TestbedConv2dProblemSizes") -+ .def(py::init()) -+ .def_readonly("conv2d_default_sizes", &test::conv::device::TestbedConv2dProblemSizes::conv2d_default_sizes); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/conv/convolution.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/conv/convolution.h -new file mode 100644 -index 0000000..dd97d28 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/conv/convolution.h -@@ -0,0 +1,49 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind convolution related types to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "conv_problems.h" -+#include "host.h" -+ -+namespace py = pybind11; -+ -+void bind_convolution_test(py::module &m) { -+ // Conv problem sizes -+ bind_conv_problem_size_test(m); -+ -+ py::module_ host_submodule = m.def_submodule("host"); -+ bind_conv_host_references(host_submodule); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/conv/host.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/conv/host.h -new file mode 100644 -index 0000000..ca15ce6 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/conv/host.h -@@ -0,0 +1,180 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind Convolution host test helpers to python -+*/ -+#pragma once -+#include -+#include -+#include "unit/conv/device/cache_testbed_output.h" -+ -+ -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+ -+namespace py = pybind11; -+ -+ -+template -+void bind_conv2d_host(py::module &m) { -+ m.def("conv2d", \ -+ &cutlass::reference::host::Conv2d< \ -+ Ta, La, Tb, Lb, Tc, Lc, Te, Tacc>); -+ -+ m.def("CreateCachedConv2dTestKey", &test::conv::device::CreateCachedConv2dTestKey); -+} -+ -+template -+void bind_conv2d_host_sat(py::module &m) { -+ m.def("conv2d", \ -+ &cutlass::reference::host::Conv2d< \ -+ Ta, La, Tb, Lb, Tc, Lc, Te, Tacc, cutlass::NumericConverterClamp>); -+ -+ m.def("CreateCachedConv2dTestKey", &test::conv::device::CreateCachedConv2dTestKey); -+} -+ -+template -+void bind_conv2d_host_nhwc(py::module &m) { -+ bind_conv2d_host< -+ Ta, cutlass::layout::TensorNHWC, -+ Tb, cutlass::layout::TensorNHWC, -+ Tc, cutlass::layout::TensorNHWC, -+ Tacc, Te>(m); -+} -+ -+template -+void bind_conv2d_host_nc32hw32(py::module &m) { -+ bind_conv2d_host_sat< -+ Ta, cutlass::layout::TensorNCxHWx<32>, -+ Tb, cutlass::layout::TensorCxRSKx<32>, -+ Tc, cutlass::layout::TensorNCxHWx<32>, -+ Tacc, Te>(m); -+} -+ -+ -+template -+void bind_tensor_equals(py::module &m) { -+ m.def("equals", py::overload_cast< -+ const cutlass::TensorView&, const cutlass::TensorView&>( -+ &cutlass::reference::host::TensorEquals -+ )); -+} -+ -+#define BIND_TENSOR_HASH(Element, Layout) { \ -+ m.def("TensorHash", &test::conv::device::TensorHash, py::arg("view"), py::arg("hash") = test::conv::device::CRC32(), py::arg("crc")=uint32_t()); \ -+} -+ -+void bind_conv_host_references(py::module &m) { -+ // -+ // Conv2d reference on host -+ // tools/util/include/cutlass/util/reference/host/convolution.h -+ -+ /// double -+ bind_conv2d_host_nhwc(m); -+ /// float -+ bind_conv2d_host_nhwc(m); -+ /// half -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ /// bfloat16 -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ /// s8 -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ -+ bind_conv2d_host_nc32hw32(m); -+ bind_conv2d_host_nc32hw32(m); -+ bind_conv2d_host_nc32hw32(m); -+ bind_conv2d_host_nc32hw32(m); -+ bind_conv2d_host_nc32hw32(m); -+ bind_conv2d_host_nc32hw32(m); -+ bind_conv2d_host_nc32hw32(m); -+ bind_conv2d_host_nc32hw32(m); -+ -+ // -+ // Compare whether two tensors are equal -+ // -+ /// double -+ bind_tensor_equals(m); -+ /// float -+ bind_tensor_equals(m); -+ /// half -+ bind_tensor_equals(m); -+ /// bfloat16 -+ bind_tensor_equals(m); -+ /// s32 -+ bind_tensor_equals(m); -+ bind_tensor_equals>(m); -+ /// s8 -+ bind_tensor_equals(m); -+ bind_tensor_equals>(m); -+ -+ /// Cache -+ py::class_(m, "CachedTestKey") -+ .def(py::init<>()) -+ .def(py::init()); -+ -+ py::class_(m, "CachedTestResult") -+ .def(py::init<>()) -+ .def(py::init()) -+ .def_readwrite("D", &test::conv::device::CachedTestResult::D); -+ -+ py::class_(m, 
"CachedTestResultListing") -+ .def(py::init()) -+ .def("find", &test::conv::device::CachedTestResultListing::find) -+ .def("append", &test::conv::device::CachedTestResultListing::append) -+ .def("write", &test::conv::device::CachedTestResultListing::write); -+ -+ py::class_(m, "CRC32") -+ .def(py::init<>()); -+ -+ BIND_TENSOR_HASH(double, cutlass::layout::TensorNHWC) -+ BIND_TENSOR_HASH(float, cutlass::layout::TensorNHWC); -+ BIND_TENSOR_HASH(cutlass::half_t, cutlass::layout::TensorNHWC); -+ BIND_TENSOR_HASH(cutlass::bfloat16_t, cutlass::layout::TensorNHWC); -+ BIND_TENSOR_HASH(int32_t, cutlass::layout::TensorNHWC); -+ BIND_TENSOR_HASH(int8_t, cutlass::layout::TensorNCxHWx<32>); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/gemm/gemm.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/gemm/gemm.h -new file mode 100644 -index 0000000..749d8d9 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/gemm/gemm.h -@@ -0,0 +1,45 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind gemm test to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "host.h" -+ -+namespace py = pybind11; -+ -+void bind_gemm_test(py::module &m) { -+ py::module_ host_submodule = m.def_submodule("host"); -+ bind_gemm_host_reference(host_submodule); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/gemm/host.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/gemm/host.h -new file mode 100644 -index 0000000..c6aeee8 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/gemm/host.h -@@ -0,0 +1,431 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind gemm test host functions to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/host_reorder.h" -+ -+#include "cutlass/functional.h" -+ -+namespace py = pybind11; -+ -+ -+template< -+ typename ElementA, typename LayoutA, -+ typename ElementB, typename LayoutB, -+ typename ElementC, typename LayoutC, -+ typename AccumulatorType, typename ComputeType, -+ typename InnerProductOp> -+void bind_host_gemm_saturate(py::module &m) { -+ m.def("gemm_saturate", py::overload_cast< -+ cutlass::gemm::GemmCoord, ComputeType, -+ cutlass::TensorRef, -+ cutlass::TensorRef, -+ ComputeType, -+ cutlass::TensorRef, -+ cutlass::TensorRef, -+ AccumulatorType>( -+ &cutlass::reference::host::compute_gemm< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ComputeType, -+ AccumulatorType, -+ InnerProductOp, -+ cutlass::NumericConverterClamp> -+ )); -+} -+ -+template< -+ typename ElementA, typename LayoutA, -+ typename ElementB, typename LayoutB, -+ typename ElementC, typename LayoutC, -+ typename AccumulatorType, typename ComputeType, -+ typename InnerProductOp> -+void bind_host_gemm(py::module &m) { -+ m.def("gemm", py::overload_cast< -+ cutlass::gemm::GemmCoord, ComputeType, -+ cutlass::TensorRef, -+ cutlass::TensorRef, -+ ComputeType, -+ cutlass::TensorRef, -+ cutlass::TensorRef, -+ AccumulatorType>( -+ &cutlass::reference::host::compute_gemm< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ComputeType, -+ AccumulatorType, -+ InnerProductOp, -+ cutlass::NumericConverter> -+ )); -+} -+ -+ -+template< -+ typename ElementA, typename ElementB, typename ElementC, -+ typename AccumulatorType, typename ComputeType> -+void bind_host_gemm_multiply_add(py::module &m) { -+ bind_host_gemm< -+ ElementA, cutlass::layout::RowMajor, -+ ElementB, cutlass::layout::RowMajor, -+ ElementC, cutlass::layout::RowMajor, -+ ComputeType, AccumulatorType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::ColumnMajor, -+ ElementB, cutlass::layout::RowMajor, -+ ElementC, cutlass::layout::RowMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::RowMajor, -+ ElementB, cutlass::layout::ColumnMajor, -+ ElementC, cutlass::layout::RowMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::RowMajor, -+ ElementB, cutlass::layout::RowMajor, -+ ElementC, cutlass::layout::ColumnMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::RowMajor, -+ ElementB, cutlass::layout::ColumnMajor, -+ ElementC, cutlass::layout::ColumnMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::ColumnMajor, -+ ElementB, cutlass::layout::RowMajor, -+ ElementC, cutlass::layout::ColumnMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::ColumnMajor, -+ ElementB, cutlass::layout::ColumnMajor, -+ ElementC, cutlass::layout::RowMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::ColumnMajor, -+ ElementB, 
cutlass::layout::ColumnMajor, -+ ElementC, cutlass::layout::ColumnMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+} -+ -+template< -+ typename ElementA, typename ElementB, typename ElementC, -+ typename AccumulatorType, typename ComputeType> -+void bind_host_gemm_multiply_add_saturate(py::module &m) { -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::RowMajor, -+ ElementB, cutlass::layout::RowMajor, -+ ElementC, cutlass::layout::RowMajor, -+ ComputeType, AccumulatorType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::ColumnMajor, -+ ElementB, cutlass::layout::RowMajor, -+ ElementC, cutlass::layout::RowMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::RowMajor, -+ ElementB, cutlass::layout::ColumnMajor, -+ ElementC, cutlass::layout::RowMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::RowMajor, -+ ElementB, cutlass::layout::RowMajor, -+ ElementC, cutlass::layout::ColumnMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::RowMajor, -+ ElementB, cutlass::layout::ColumnMajor, -+ ElementC, cutlass::layout::ColumnMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::ColumnMajor, -+ ElementB, cutlass::layout::RowMajor, -+ ElementC, cutlass::layout::ColumnMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::ColumnMajor, -+ ElementB, cutlass::layout::ColumnMajor, -+ ElementC, cutlass::layout::RowMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::ColumnMajor, -+ ElementB, cutlass::layout::ColumnMajor, -+ ElementC, cutlass::layout::ColumnMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+} -+ -+ -+template< -+ typename ElementA, typename ElementB, typename ElementC, -+ typename AccumulatorType, typename ComputeType> -+void bind_host_gemm_multiply_add_interleaved(py::module &m) { -+ bind_host_gemm< -+ ElementA, cutlass::layout::RowMajorInterleaved<32>, -+ ElementB, cutlass::layout::RowMajorInterleaved<32>, -+ ElementC, cutlass::layout::RowMajorInterleaved<32>, -+ ComputeType, AccumulatorType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementB, cutlass::layout::RowMajorInterleaved<32>, -+ ElementC, cutlass::layout::RowMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::RowMajorInterleaved<32>, -+ ElementB, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementC, cutlass::layout::RowMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::RowMajorInterleaved<32>, -+ ElementB, cutlass::layout::RowMajorInterleaved<32>, -+ ElementC, cutlass::layout::ColumnMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::RowMajorInterleaved<32>, -+ ElementB, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementC, cutlass::layout::ColumnMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< 
-+ ElementA, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementB, cutlass::layout::RowMajorInterleaved<32>, -+ ElementC, cutlass::layout::ColumnMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementB, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementC, cutlass::layout::RowMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementB, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementC, cutlass::layout::ColumnMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+} -+ -+template< -+ typename ElementA, typename ElementB, typename ElementC, -+ typename AccumulatorType, typename ComputeType> -+void bind_host_gemm_multiply_add_saturate_interleaved(py::module &m) { -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::RowMajorInterleaved<32>, -+ ElementB, cutlass::layout::RowMajorInterleaved<32>, -+ ElementC, cutlass::layout::RowMajorInterleaved<32>, -+ ComputeType, AccumulatorType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementB, cutlass::layout::RowMajorInterleaved<32>, -+ ElementC, cutlass::layout::RowMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::RowMajorInterleaved<32>, -+ ElementB, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementC, cutlass::layout::RowMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::RowMajorInterleaved<32>, -+ ElementB, cutlass::layout::RowMajorInterleaved<32>, -+ ElementC, cutlass::layout::ColumnMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::RowMajorInterleaved<32>, -+ ElementB, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementC, cutlass::layout::ColumnMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementB, cutlass::layout::RowMajorInterleaved<32>, -+ ElementC, cutlass::layout::ColumnMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementB, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementC, cutlass::layout::RowMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementB, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementC, cutlass::layout::ColumnMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+} -+ -+#define BIND_TENSOR_EQUAL(Element, Layout) { \ -+ m.def("equals", py::overload_cast< \ -+ const cutlass::TensorView&, const cutlass::TensorView&>( \ -+ &cutlass::reference::host::TensorEquals)); \ -+} -+ -+void bind_gemm_host_reference(py::module &m) { -+ -+ /// double -+ bind_host_gemm_multiply_add(m); -+ /// float -+ bind_host_gemm_multiply_add(m); -+ /// half_t -+ bind_host_gemm_multiply_add(m); -+ bind_host_gemm_multiply_add(m); -+ bind_host_gemm_multiply_add(m); -+ 
bind_host_gemm_multiply_add(m); -+ /// bfloat16 -+ bind_host_gemm_multiply_add(m); -+ bind_host_gemm_multiply_add(m); -+ -+ /// s8 -+ bind_host_gemm_multiply_add(m); -+ bind_host_gemm_multiply_add(m); -+ bind_host_gemm_multiply_add(m); -+ bind_host_gemm_multiply_add(m); -+ bind_host_gemm_multiply_add(m); -+ bind_host_gemm_multiply_add(m); -+ bind_host_gemm_multiply_add(m); -+ bind_host_gemm_multiply_add(m); -+ -+ bind_host_gemm_multiply_add_interleaved(m); -+ bind_host_gemm_multiply_add_interleaved(m); -+ bind_host_gemm_multiply_add_interleaved(m); -+ bind_host_gemm_multiply_add_interleaved(m); -+ bind_host_gemm_multiply_add_interleaved(m); -+ bind_host_gemm_multiply_add_interleaved(m); -+ bind_host_gemm_multiply_add_interleaved(m); -+ bind_host_gemm_multiply_add_interleaved(m); -+ -+ bind_host_gemm_multiply_add_saturate(m); -+ bind_host_gemm_multiply_add_saturate(m); -+ bind_host_gemm_multiply_add_saturate(m); -+ bind_host_gemm_multiply_add_saturate(m); -+ bind_host_gemm_multiply_add_saturate(m); -+ bind_host_gemm_multiply_add_saturate(m); -+ bind_host_gemm_multiply_add_saturate(m); -+ bind_host_gemm_multiply_add_saturate(m); -+ -+ bind_host_gemm_multiply_add_saturate_interleaved(m); -+ bind_host_gemm_multiply_add_saturate_interleaved(m); -+ bind_host_gemm_multiply_add_saturate_interleaved(m); -+ bind_host_gemm_multiply_add_saturate_interleaved(m); -+ bind_host_gemm_multiply_add_saturate_interleaved(m); -+ bind_host_gemm_multiply_add_saturate_interleaved(m); -+ bind_host_gemm_multiply_add_saturate_interleaved(m); -+ bind_host_gemm_multiply_add_saturate_interleaved(m); -+ -+ // float -+ BIND_TENSOR_EQUAL(float, cutlass::layout::RowMajor); -+ BIND_TENSOR_EQUAL(float, cutlass::layout::ColumnMajor); -+ -+ // double -+ BIND_TENSOR_EQUAL(double, cutlass::layout::RowMajor); -+ BIND_TENSOR_EQUAL(double, cutlass::layout::ColumnMajor); -+ -+ // half_t -+ BIND_TENSOR_EQUAL(cutlass::half_t, cutlass::layout::RowMajor); -+ BIND_TENSOR_EQUAL(cutlass::half_t, cutlass::layout::ColumnMajor); -+ -+ // bfloat16 -+ BIND_TENSOR_EQUAL(cutlass::bfloat16_t, cutlass::layout::RowMajor); -+ BIND_TENSOR_EQUAL(cutlass::bfloat16_t, cutlass::layout::ColumnMajor); -+ -+ // int32_t -+ BIND_TENSOR_EQUAL(int32_t, cutlass::layout::RowMajor); -+ BIND_TENSOR_EQUAL(int32_t, cutlass::layout::ColumnMajor); -+ -+ // int8_t -+ BIND_TENSOR_EQUAL(int8_t, cutlass::layout::RowMajor); -+ BIND_TENSOR_EQUAL(int8_t, cutlass::layout::ColumnMajor); -+ BIND_TENSOR_EQUAL(int8_t, cutlass::layout::RowMajorInterleaved<32>); -+ BIND_TENSOR_EQUAL(int8_t, cutlass::layout::ColumnMajorInterleaved<32>); -+ -+ -+} -diff --git a/3rdparty/cutlass/tools/library/src/conv2d_operation.h b/3rdparty/cutlass/tools/library/src/conv2d_operation.h -new file mode 100644 -index 0000000..5d06e72 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/conv2d_operation.h -@@ -0,0 +1,642 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines operations for all CONV operation kinds in CUTLASS Library. -+*/ -+ -+#pragma once -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/kernel/default_conv2d_group_fprop.h" -+#include "cutlass/conv/kernel/default_depthwise_fprop.h" -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/kernel/default_conv2d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+#include "cutlass/conv/device/direct_convolution.h" -+ -+#include "cutlass/library/library.h" -+#include "library_internal.h" -+#include "cutlass/util/host_tensor.h" -+ -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/core_io.h" -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class Conv2dOperationBase : public Operation { -+public: -+ -+ using Operator = Operator_; -+ -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = Operator::kIteratorAlgorithm; -+ static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+protected: -+ -+ /// -+ ConvDescription description_; -+ -+public: -+ -+ /// Constructor -+ Conv2dOperationBase(char const *name = "unknown_conv2d") { -+ -+ description_.name = name; -+ description_.provider = Provider::kCUTLASS; -+ description_.kind = OperationKind::kConv2d; -+ description_.conv_dim = Operator::kConvDim; -+ 
-+ description_.iterator_algorithm = IteratorAlgorithmMap::kId; -+ -+ description_.tile_description.threadblock_shape = make_Coord( -+ Operator::ThreadblockShape::kM, -+ Operator::ThreadblockShape::kN, -+ Operator::ThreadblockShape::kK); -+ -+ description_.tile_description.threadblock_stages = Operator::kStages; -+ -+ description_.tile_description.warp_count = make_Coord( -+ Operator::UnderlyingKernel::WarpCount::kM, -+ Operator::UnderlyingKernel::WarpCount::kN, -+ Operator::UnderlyingKernel::WarpCount::kK); -+ -+ description_.tile_description.math_instruction.instruction_shape = make_Coord( -+ Operator::InstructionShape::kM, -+ Operator::InstructionShape::kN, -+ Operator::InstructionShape::kK); -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.opcode_class = -+ OpcodeClassMap::kId; -+ -+ description_.tile_description.math_instruction.math_operation = -+ MathOperationMap::kId; -+ -+ description_.tile_description.minimum_compute_capability = -+ ArchMap::kMin; -+ -+ description_.tile_description.maximum_compute_capability = -+ ArchMap::kMax; -+ -+ description_.A = make_TensorDescription(); -+ description_.B = make_TensorDescription(); -+ description_.C = make_TensorDescription(); -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ // TODO: Add split k mode Serial and parallel to convolutions -+ // description_.split_k_mode = Operator::kSplitK ? SplitKMode::kSerial : SplitKMode::kNone; -+ -+ } -+ -+ /// Returns the description of the GEMM operation -+ virtual OperationDescription const & description() const { -+ return description_; -+ } -+}; -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Conv2d library operation class for cutlass profiler -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+template -+class Conv2dOperation : public Conv2dOperationBase { -+public: -+ -+ using Operator = Operator_; -+ -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ /// Constructor -+ Conv2dOperation(char const *name = "unknown_conv2d_fprop") : Conv2dOperationBase(name) { -+ this->description_.conv_kind = ConvKindMap::kId; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ Conv2dConfiguration const *configuration) { -+ -+ -+ operator_args.problem_size = configuration->problem_size; -+ -+ operator_args.ref_A = -+ { -+ nullptr, -+ LayoutA::packed(implicit_gemm_tensor_a_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.ref_B = -+ { -+ nullptr, -+ LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.ref_C = -+ { -+ nullptr, -+ 
LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.ref_D = -+ { -+ nullptr, -+ LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.split_k_mode = configuration->split_k_mode; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ ConvArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ operator_args.output_op = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ operator_args.output_op = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ operator_args.ref_A.reset(static_cast(const_cast(arguments->A))); -+ operator_args.ref_B.reset(static_cast(const_cast(arguments->B))); -+ operator_args.ref_C.reset(static_cast(const_cast(arguments->C))); -+ operator_args.ref_D.reset(static_cast(const_cast(arguments->D))); -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ Conv2dConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ ConvArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ return Operator::get_workspace_size(args); -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ //std::cout << "initialize library::Conv2dOperation" << std::endl; -+ //print_operator_args(args); -+ return op->initialize(args, device_workspace, stream); -+ -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; 
-+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args, device_workspace); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ //std::cout << "run library::Conv2dOperation" << std::endl; -+ //print_operator_args(args); -+ return op->run(stream); -+ } -+ -+ /// Call print_operator_args from the Conv2dOperation::initialize() -+ // to dump arguments passed on to cutlass operator for debugging -+ void print_operator_args(OperatorArguments &operator_args) const { -+ std::cout << "Conv2dOperation::OperatorArguments" << std::endl -+ << " problem_size:" << std::endl -+ << operator_args.problem_size << std::endl -+ << " split_k_mode: " -+ << (operator_args.split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial" : "parallel") << std::endl -+ << " epilouge (alpha, beta): " -+ << operator_args.output_op.alpha << ", " -+ << operator_args.output_op.beta << std::endl -+ << " ref_A (ptr, {stride}): " -+ << operator_args.ref_A.data() << ", {" -+ << operator_args.ref_A.stride(0) << ", " -+ << operator_args.ref_A.stride(1) << ", " -+ << operator_args.ref_A.stride(2) << "}" << std::endl -+ << " ref_B (ptr, {stride}): " -+ << operator_args.ref_B.data() << ", {" -+ << operator_args.ref_B.stride(0) << ", " -+ << operator_args.ref_B.stride(1) << ", " -+ << operator_args.ref_B.stride(2) << "}" << std::endl -+ << " ref_C (ptr, {stride}): " -+ << operator_args.ref_C.data() << ", {" -+ << operator_args.ref_C.stride(0) << ", " -+ << operator_args.ref_C.stride(1) << ", " -+ << operator_args.ref_C.stride(2) << "}" << std::endl -+ << " ref_D (ptr, {stride}): " -+ << operator_args.ref_D.data() << ", {" -+ << operator_args.ref_D.stride(0) << ", " -+ << operator_args.ref_D.stride(1) << ", " -+ << operator_args.ref_D.stride(2) << "}" << std::endl; -+ } -+}; -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// DirectConv2d library operation class for cutlass profiler -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class DirectConv2dOperation : public Conv2dOperation { -+public: -+ -+ using Operator = Operator_; -+ using Base = Conv2dOperation; -+ -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ /// Constructor -+ DirectConv2dOperation(char const *name = "unknown_direct)conv2d_fprop") : Conv2dOperation(name) { -+ this->description_.conv_kind = ConvKindMap::kId; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ Conv2dConfiguration const *configuration) { -+ -+ -+ operator_args.problem_size = configuration->problem_size; -+ -+ operator_args.ref_A = -+ { -+ nullptr, -+ LayoutA::packed(implicit_gemm_tensor_a_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.ref_B = -+ { -+ nullptr, -+ 
LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.ref_reordered_B = -+ { -+ nullptr, -+ LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.ref_C = -+ { -+ nullptr, -+ LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.ref_D = -+ { -+ nullptr, -+ LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.split_k_mode = configuration->split_k_mode; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ ConvArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ operator_args.output_op = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ operator_args.output_op = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ operator_args.ref_A.reset(static_cast(const_cast(arguments->A))); -+ operator_args.ref_B.reset(static_cast(const_cast(arguments->B))); -+ operator_args.ref_C.reset(static_cast(const_cast(arguments->C))); -+ operator_args.ref_D.reset(static_cast(const_cast(arguments->D))); -+ operator_args.ref_reordered_B.reset(static_cast(const_cast(arguments->reordered_B))); -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ Conv2dConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ ConvArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ return Operator::get_workspace_size(args); -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ //std::cout << "initialize library::Conv2dOperation" << std::endl; -+ //print_operator_args(args); -+ return 
op->initialize(args, device_workspace, stream); -+ -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args, device_workspace); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ //std::cout << "run library::Conv2dOperation" << std::endl; -+ //print_operator_args(args); -+ return op->run(stream); -+ } -+ -+ /// Call print_operator_args from the Conv2dOperation::initialize() -+ // to dump arguments passed on to cutlass operator for debugging -+ void print_operator_args(OperatorArguments &operator_args) const { -+ std::cout << "Conv2dOperation::OperatorArguments" << std::endl -+ << " problem_size:" << std::endl -+ << operator_args.problem_size << std::endl -+ << " split_k_mode: " -+ << (operator_args.split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial" : "parallel") << std::endl -+ << " epilouge (alpha, beta): " -+ << operator_args.output_op.alpha << ", " -+ << operator_args.output_op.beta << std::endl -+ << " ref_A (ptr, {stride}): " -+ << operator_args.ref_A.data() << ", {" -+ << operator_args.ref_A.stride(0) << ", " -+ << operator_args.ref_A.stride(1) << ", " -+ << operator_args.ref_A.stride(2) << "}" << std::endl -+ << " ref_B (ptr, {stride}): " -+ << operator_args.ref_B.data() << ", {" -+ << operator_args.ref_B.stride(0) << ", " -+ << operator_args.ref_B.stride(1) << ", " -+ << operator_args.ref_B.stride(2) << "}" << std::endl -+ << " ref_C (ptr, {stride}): " -+ << operator_args.ref_C.data() << ", {" -+ << operator_args.ref_C.stride(0) << ", " -+ << operator_args.ref_C.stride(1) << ", " -+ << operator_args.ref_C.stride(2) << "}" << std::endl -+ << " ref_D (ptr, {stride}): " -+ << operator_args.ref_D.data() << ", {" -+ << operator_args.ref_D.stride(0) << ", " -+ << operator_args.ref_D.stride(1) << ", " -+ << operator_args.ref_D.stride(2) << "}" << std::endl; -+ } -+}; -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/conv3d_operation.h b/3rdparty/cutlass/tools/library/src/conv3d_operation.h -new file mode 100644 -index 0000000..0e2a1c6 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/conv3d_operation.h -@@ -0,0 +1,385 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines operations for all CONV operation kinds in CUTLASS Library. -+*/ -+ -+#pragma once -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv3d_fprop.h" -+#include "cutlass/conv/kernel/default_conv3d_dgrad.h" -+#include "cutlass/conv/kernel/default_conv3d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "cutlass/library/library.h" -+#include "library_internal.h" -+#include "cutlass/util/host_tensor.h" -+ -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/core_io.h" -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class Conv3dOperationBase : public Operation { -+public: -+ -+ using Operator = Operator_; -+ -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = Operator::kIteratorAlgorithm; -+ static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+protected: -+ -+ /// -+ ConvDescription description_; -+ -+public: -+ -+ /// Constructor -+ Conv3dOperationBase(char const *name = "unknown_conv3d") { -+ -+ description_.name = name; -+ description_.provider = Provider::kCUTLASS; -+ description_.kind = OperationKind::kConv3d; -+ description_.conv_dim = Operator::kConvDim; -+ -+ description_.iterator_algorithm = IteratorAlgorithmMap::kId; -+ -+ description_.tile_description.threadblock_shape = make_Coord( -+ Operator::ThreadblockShape::kM, -+ Operator::ThreadblockShape::kN, -+ Operator::ThreadblockShape::kK); -+ -+ description_.tile_description.threadblock_stages = Operator::kStages; -+ -+ description_.tile_description.warp_count = make_Coord( -+ 
Operator::UnderlyingKernel::WarpCount::kM, -+ Operator::UnderlyingKernel::WarpCount::kN, -+ Operator::UnderlyingKernel::WarpCount::kK); -+ -+ description_.tile_description.math_instruction.instruction_shape = make_Coord( -+ Operator::InstructionShape::kM, -+ Operator::InstructionShape::kN, -+ Operator::InstructionShape::kK); -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.opcode_class = -+ OpcodeClassMap::kId; -+ -+ description_.tile_description.minimum_compute_capability = -+ ArchMap::kMin; -+ -+ description_.tile_description.maximum_compute_capability = -+ ArchMap::kMax; -+ -+ description_.A = make_TensorDescription(); -+ description_.B = make_TensorDescription(); -+ description_.C = make_TensorDescription(); -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ } -+ -+ /// Returns the description of the GEMM operation -+ virtual OperationDescription const & description() const { -+ return description_; -+ } -+}; -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Conv2d library operation class for cutlass profiler -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+template -+class Conv3dOperation : public Conv3dOperationBase { -+public: -+ -+ using Operator = Operator_; -+ -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ /// Constructor -+ Conv3dOperation(char const *name = "unknown_conv3d_fprop") : Conv3dOperationBase(name) { -+ this->description_.conv_kind = ConvKindMap::kId; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ Conv3dConfiguration const *configuration) { -+ -+ -+ operator_args.problem_size = configuration->problem_size; -+ -+ operator_args.ref_A = -+ { -+ nullptr, -+ LayoutA::packed(implicit_gemm_tensor_a_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.ref_B = -+ { -+ nullptr, -+ LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.ref_C = -+ { -+ nullptr, -+ LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.ref_D = -+ { -+ nullptr, -+ LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.split_k_mode = configuration->split_k_mode; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ ConvArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ 
*static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ operator_args.output_op = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ operator_args.output_op = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ operator_args.ref_A.reset(static_cast(const_cast(arguments->A))); -+ operator_args.ref_B.reset(static_cast(const_cast(arguments->B))); -+ operator_args.ref_C.reset(static_cast(const_cast(arguments->C))); -+ operator_args.ref_D.reset(static_cast(const_cast(arguments->D))); -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ Conv3dConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ ConvArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ return Operator::get_workspace_size(args); -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ //std::cout << "initialize library::Conv3dOperation" << std::endl; -+ //print_operator_args(args); -+ return op->initialize(args, device_workspace, stream); -+ -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args, device_workspace); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ //std::cout << "run library::Conv3dOperation" << std::endl; -+ //print_operator_args(args); -+ return op->run(stream); -+ } -+ -+ /// Call print_operator_args from the Conv3dOperation::initialize() -+ // to dump arguments passed on to cutlass operator for debugging -+ void print_operator_args(OperatorArguments &operator_args) const { -+ std::cout << "Conv3dOperation::OperatorArguments" << std::endl -+ << " problem_size: " -+ << operator_args.problem_size << std::endl -+ << " split_k_mode: " -+ << 
(operator_args.split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial" : "parallel") << std::endl -+ << " epilouge (alpha, beta): " -+ << operator_args.output_op.alpha << ", " -+ << operator_args.output_op.beta << std::endl -+ << " ref_A (ptr, {stride}): " -+ << operator_args.ref_A.data() << ", {" -+ << operator_args.ref_A.stride(0) << ", " -+ << operator_args.ref_A.stride(1) << ", " -+ << operator_args.ref_A.stride(2) << ", " -+ << operator_args.ref_A.stride(3) << "}" << std::endl -+ << " ref_B (ptr, {stride}): " -+ << operator_args.ref_B.data() << ", {" -+ << operator_args.ref_B.stride(0) << ", " -+ << operator_args.ref_B.stride(1) << ", " -+ << operator_args.ref_B.stride(2) << ", " -+ << operator_args.ref_B.stride(3) << "}" << std::endl -+ << " ref_C (ptr, {stride}): " -+ << operator_args.ref_C.data() << ", {" -+ << operator_args.ref_C.stride(0) << ", " -+ << operator_args.ref_C.stride(1) << ", " -+ << operator_args.ref_C.stride(2) << ", " -+ << operator_args.ref_C.stride(3) << "}" << std::endl -+ << " ref_D (ptr, {stride}): " -+ << operator_args.ref_D.data() << ", {" -+ << operator_args.ref_D.stride(0) << ", " -+ << operator_args.ref_D.stride(1) << ", " -+ << operator_args.ref_D.stride(2) << ", " -+ << operator_args.ref_D.stride(3) << "}" << std::endl; -+ } -+}; -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/gemm_operation.h b/3rdparty/cutlass/tools/library/src/gemm_operation.h -new file mode 100644 -index 0000000..ab5704b ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/gemm_operation.h -@@ -0,0 +1,1356 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines operations for all GEMM operation kinds in CUTLASS Library. -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+#include "cutlass/gemm/device/gemm_batched.h" -+#include "cutlass/gemm/device/gemm_array.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/default_gemm_universal.h" -+#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h" -+ -+#include "cutlass/library/library.h" -+#include "library_internal.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class GemmOperationBase : public Operation { -+public: -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ // assuming all tensors use same type for StrideIndex -+ using StrideIndex = typename Operator::LayoutA::Index; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+protected: -+ -+ /// -+ GemmDescription description_; -+ -+public: -+ -+ /// Constructor -+ GemmOperationBase(char const *name = "unknown_gemm") { -+ -+ description_.name = name; -+ description_.provider = Provider::kCUTLASS; -+ description_.kind = OperationKind::kGemm; -+ description_.gemm_kind = GemmKind::kGemm; -+ -+ description_.tile_description.threadblock_shape = make_Coord( -+ Operator::ThreadblockShape::kM, -+ Operator::ThreadblockShape::kN, -+ Operator::ThreadblockShape::kK); -+ -+ description_.tile_description.threadblock_stages = Operator::kStages; -+ -+ description_.tile_description.warp_count = make_Coord( -+ Operator::GemmKernel::WarpCount::kM, -+ Operator::GemmKernel::WarpCount::kN, -+ Operator::GemmKernel::WarpCount::kK); -+ -+ description_.tile_description.math_instruction.instruction_shape = make_Coord( -+ Operator::InstructionShape::kM, -+ Operator::InstructionShape::kN, -+ Operator::InstructionShape::kK); -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.opcode_class = -+ OpcodeClassMap::kId; -+ -+ description_.tile_description.math_instruction.math_operation = -+ MathOperationMap::kId; -+ -+ description_.tile_description.minimum_compute_capability = -+ ArchMap::kMin; -+ -+ description_.tile_description.maximum_compute_capability = -+ ArchMap::kMax; -+ -+ description_.A = make_TensorDescription(Operator::kAlignmentA); -+ description_.B = make_TensorDescription(Operator::kAlignmentB); -+ description_.C = make_TensorDescription(Operator::kAlignmentC); -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ description_.split_k_mode = SplitKMode::kNone; -+ description_.transform_A = ComplexTransformMap::kId; -+ description_.transform_B = ComplexTransformMap::kId; -+ } -+ -+ /// 
Returns the description of the GEMM operation -+ virtual OperationDescription const & description() const { -+ return description_; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class GemmOperation : public GemmOperationBase { -+public: -+ -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ -+ /// Constructor -+ GemmOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { -+ -+ this->description_.gemm_kind = GemmKind::kGemm; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ GemmConfiguration const *configuration) { -+ -+ operator_args.problem_size = configuration->problem_size; -+ -+ operator_args.ref_A = {nullptr, configuration->lda}; -+ operator_args.ref_B = {nullptr, configuration->ldb}; -+ operator_args.ref_C = {nullptr, configuration->ldc}; -+ operator_args.ref_D = {nullptr, configuration->ldd}; -+ -+ operator_args.split_k_slices = configuration->split_k_slices; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ GemmArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ operator_args.ref_A.reset(static_cast(arguments->A)); -+ operator_args.ref_B.reset(static_cast(arguments->B)); -+ operator_args.ref_C.reset(static_cast(arguments->C)); -+ operator_args.ref_D.reset(static_cast(arguments->D)); -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ GemmConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ GemmArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ 
void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ return Operator::get_workspace_size(args); -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ -+ return op->initialize(args, device_workspace, stream); -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return op->run(stream); -+ } -+ -+ void print_operator_args(OperatorArguments &operator_args) const { -+#if 0 -+ std::cout << "GemmOperation::OperatorArguments" << std::endl; -+ std::cout << " problem_size: " << operator_args.problem_size.m() << ", "<< operator_args.problem_size.n() << "," << operator_args.problem_size.k() << std::endl; -+ std::cout << " alpha: " << operator_args.epilogue.alpha << std::endl; -+ std::cout << " alpha_ptr: " << operator_args.epilogue.alpha_ptr << std::endl; -+ std::cout << " beta: " << operator_args.epilogue.beta << std::endl; -+ std::cout << " beta_ptr: " << operator_args.epilogue.beta_ptr << std::endl; -+ std::cout << " ref_A.data(): " << operator_args.ref_A.data() << std::endl; -+ std::cout << " ref_A.stride: " << operator_args.ref_A.stride(0) << std::endl; -+ std::cout << " ref_B.data(): " << operator_args.ref_B.data() << std::endl; -+ std::cout << " ref_B.stride: " << operator_args.ref_B.stride(0) << std::endl; -+ std::cout << " ref_C.data(): " << operator_args.ref_C.data() << std::endl; -+ std::cout << " ref_C.stride: " << operator_args.ref_C.stride(0) << std::endl; -+#endif -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class GemmSparseOperation : public GemmOperationBase { -+public: -+ -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementE = typename Operator::ElementE; -+ using LayoutE = typename Operator::LayoutE; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ -+ /// Constructor -+ GemmSparseOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { -+ -+ this->description_.kind = OperationKind::kSparseGemm; -+ this->description_.gemm_kind = GemmKind::kSparse; -+ this->description_.E = make_TensorDescription(Operator::kAlignmentE); -+ } -+ 
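The wrapper classes in this file all expose the same type-erased protocol: can_implement, host/device workspace queries, initialize, then run. A minimal sketch of how a caller is expected to drive it, assuming a resolved Operation pointer and a filled GemmConfiguration/GemmArguments; the helper name and the buffer handling are illustrative, not part of this patch:

    #include <cstdint>
    #include <vector>
    #include <cuda_runtime.h>
    #include "cutlass/library/library.h"

    // Hypothetical driver: walks the sequence implemented by the operation classes above.
    cutlass::Status launch_gemm(cutlass::library::Operation const *op,
                                cutlass::library::GemmConfiguration const &config,
                                cutlass::library::GemmArguments const &args,
                                cudaStream_t stream) {
      cutlass::Status status = op->can_implement(&config, &args);
      if (status != cutlass::Status::kSuccess) {
        return status;
      }
      // Host workspace holds the concrete Operator object (sizeof(Operator) above).
      std::vector<char> host_workspace(op->get_host_workspace_size(&config));
      // Device workspace size depends on the configuration (e.g. split-K scratch).
      void *device_workspace = nullptr;
      uint64_t device_bytes = op->get_device_workspace_size(&config, &args);
      if (device_bytes) {
        cudaMalloc(&device_workspace, device_bytes);
      }
      status = op->initialize(&config, host_workspace.data(), device_workspace, stream);
      if (status == cutlass::Status::kSuccess) {
        status = op->run(&args, host_workspace.data(), device_workspace, stream);
      }
      if (device_workspace) {
        cudaFree(device_workspace);
      }
      return status;
    }

This is the same sequence Handle::gemm() follows later in this patch, so the sketch is only a restatement of that flow in caller-side form.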
-+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ SparseGemmConfiguration const *configuration) { -+ -+ operator_args.problem_size = configuration->problem_size; -+ operator_args.ref_A = {nullptr, configuration->lda}; -+ operator_args.ref_B = {nullptr, configuration->ldb}; -+ operator_args.ref_C = {nullptr, configuration->ldc}; -+ operator_args.ref_D = {nullptr, configuration->ldd}; -+ operator_args.ref_E = {nullptr, configuration->lde}; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ SparseGemmArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ operator_args.ref_A.reset(static_cast(arguments->A)); -+ operator_args.ref_B.reset(static_cast(arguments->B)); -+ operator_args.ref_C.reset(static_cast(arguments->C)); -+ operator_args.ref_D.reset(static_cast(arguments->D)); -+ operator_args.ref_E.reset(static_cast(arguments->E)); -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ SparseGemmConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ SparseGemmArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ return Operator::get_workspace_size(args); -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ -+ return op->initialize(args, device_workspace, stream); -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ 
OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return op->run(stream); -+ } -+ -+ void print_operator_args(OperatorArguments &operator_args) const { -+#if 0 -+ std::cout << "GemmOperation::OperatorArguments" << std::endl; -+ std::cout << " problem_size: " << operator_args.problem_size.m() << ", "<< operator_args.problem_size.n() << "," << operator_args.problem_size.k() << std::endl; -+ std::cout << " alpha: " << operator_args.epilogue.alpha << std::endl; -+ std::cout << " alpha_ptr: " << operator_args.epilogue.alpha_ptr << std::endl; -+ std::cout << " beta: " << operator_args.epilogue.beta << std::endl; -+ std::cout << " beta_ptr: " << operator_args.epilogue.beta_ptr << std::endl; -+ std::cout << " ref_A.data(): " << operator_args.ref_A.data() << std::endl; -+ std::cout << " ref_A.stride: " << operator_args.ref_A.stride(0) << std::endl; -+ std::cout << " ref_B.data(): " << operator_args.ref_B.data() << std::endl; -+ std::cout << " ref_B.stride: " << operator_args.ref_B.stride(0) << std::endl; -+ std::cout << " ref_C.data(): " << operator_args.ref_C.data() << std::endl; -+ std::cout << " ref_C.stride: " << operator_args.ref_C.stride(0) << std::endl; -+#endif -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class GemmUniversalOperation : public GemmOperationBase { -+public: -+ -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ -+ /// Constructor -+ GemmUniversalOperation(char const *name = "unknown_gemm"): -+ GemmOperationBase(name) { -+ -+ this->description_.gemm_kind = GemmKind::kUniversal; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ GemmUniversalConfiguration const *configuration) { -+ -+ operator_args.mode = configuration->mode; -+ -+ operator_args.problem_size = configuration->problem_size; -+ operator_args.batch_count = configuration->batch_count; -+ -+ operator_args.lda = (configuration->lda); -+ operator_args.ldb = (configuration->ldb); -+ operator_args.ldc = (configuration->ldc); -+ operator_args.ldd = (configuration->ldd); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ GemmUniversalArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ 
static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // update arguments -+ operator_args.ptr_A = arguments->A; -+ operator_args.ptr_B = arguments->B; -+ operator_args.ptr_C = arguments->C; -+ operator_args.ptr_D = arguments->D; -+ -+ operator_args.batch_stride_A = arguments->batch_stride_A; -+ operator_args.batch_stride_B = arguments->batch_stride_B; -+ operator_args.batch_stride_C = arguments->batch_stride_C; -+ operator_args.batch_stride_D = arguments->batch_stride_D; -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ GemmUniversalConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ GemmUniversalArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ uint64_t size = Operator::get_workspace_size(args); -+ -+ return size; -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ -+ status = op->initialize(args, device_workspace, stream); -+ -+ return status; -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = op->run(stream); -+ -+ return status; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class GemmPlanarComplexOperation : public GemmOperationBase { -+public: -+ -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename 
Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ -+ /// Constructor -+ GemmPlanarComplexOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { -+ -+ this->description_.gemm_kind = GemmKind::kPlanarComplex; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ GemmPlanarComplexConfiguration const *configuration) { -+ -+ operator_args.mode = cutlass::gemm::GemmUniversalMode::kBatched; -+ operator_args.problem_size = configuration->problem_size; -+ operator_args.batch_count = configuration->batch_count; -+ -+ -+ operator_args.lda_real = configuration->lda_real; -+ operator_args.lda_imag = configuration->lda_imag; -+ operator_args.ldb_real = configuration->ldb_real; -+ operator_args.ldb_imag = configuration->ldb_imag; -+ operator_args.ldc_real = configuration->ldc_real; -+ operator_args.ldc_imag = configuration->ldc_imag; -+ operator_args.ldd_real = configuration->ldd_real; -+ operator_args.ldd_imag = configuration->ldd_imag; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ GemmPlanarComplexArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast const *>(arguments->alpha), -+ *static_cast const *>(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast const *>(arguments->alpha), -+ static_cast const *>(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // update arguments -+ operator_args.ptr_A_real = arguments->A_real; -+ operator_args.ptr_A_imag = arguments->A_imag; -+ operator_args.ptr_B_real = arguments->B_real; -+ operator_args.ptr_B_imag = arguments->B_imag; -+ operator_args.ptr_C_real = arguments->C_real; -+ operator_args.ptr_C_imag = arguments->C_imag; -+ operator_args.ptr_D_real = arguments->D_real; -+ operator_args.ptr_D_imag = arguments->D_imag; -+ -+ operator_args.batch_stride_A = arguments->batch_stride_A_real; -+ operator_args.batch_stride_A_imag = arguments->batch_stride_A_imag; -+ operator_args.batch_stride_B = arguments->batch_stride_B_real; -+ operator_args.batch_stride_B_imag = arguments->batch_stride_B_imag; -+ operator_args.batch_stride_C = arguments->batch_stride_C_real; -+ operator_args.batch_stride_C_imag = arguments->batch_stride_C_imag; -+ operator_args.batch_stride_D = arguments->batch_stride_D_real; -+ operator_args.batch_stride_D_imag = arguments->batch_stride_D_imag; -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ GemmPlanarComplexConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ GemmPlanarComplexArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = 
construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ uint64_t size = Operator::get_workspace_size(args); -+ -+ return size; -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ -+ status = op->initialize(args, device_workspace, stream); -+ -+ return status; -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = op->run(stream); -+ -+ return status; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class GemmPlanarComplexArrayOperation : public GemmOperationBase { -+public: -+ -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ -+ /// Constructor -+ GemmPlanarComplexArrayOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { -+ -+ this->description_.gemm_kind = GemmKind::kPlanarComplexArray; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ GemmPlanarComplexArrayConfiguration const *configuration) { -+ -+ operator_args.mode = cutlass::gemm::GemmUniversalMode::kArray; -+ operator_args.problem_size = configuration->problem_size; -+ operator_args.batch_count = configuration->batch_count; -+ -+ operator_args.lda_real = configuration->lda_real; -+ operator_args.lda_imag = configuration->lda_imag; -+ operator_args.ldb_real = configuration->ldb_real; -+ operator_args.ldb_imag = configuration->ldb_imag; -+ operator_args.ldc_real = 
configuration->ldc_real; -+ operator_args.ldc_imag = configuration->ldc_imag; -+ operator_args.ldd_real = configuration->ldd_real; -+ operator_args.ldd_imag = configuration->ldd_imag; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ GemmPlanarComplexArrayArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast const *>(arguments->alpha), -+ *static_cast const *>(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast const *>(arguments->alpha), -+ static_cast const *>(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // update arguments -+ operator_args.ptr_A_real = arguments->A_real; -+ operator_args.ptr_A_imag = arguments->A_imag; -+ operator_args.ptr_B_real = arguments->B_real; -+ operator_args.ptr_B_imag = arguments->B_imag; -+ operator_args.ptr_C_real = arguments->C_real; -+ operator_args.ptr_C_imag = arguments->C_imag; -+ operator_args.ptr_D_real = arguments->D_real; -+ operator_args.ptr_D_imag = arguments->D_imag; -+ -+ operator_args.ptr_M = arguments->M; -+ operator_args.ptr_N = arguments->N; -+ operator_args.ptr_K = arguments->K; -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ GemmPlanarComplexArrayConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ GemmPlanarComplexArrayArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ uint64_t size = Operator::get_workspace_size(args); -+ -+ return size; -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ -+ status = op->initialize(args, device_workspace, stream); -+ -+ return status; -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ 
OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = op->run(stream); -+ -+ return status; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class GemmGroupedOperation : public GemmOperationBase { -+public: -+ -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ -+ /// Constructor -+ GemmGroupedOperation(char const *name = "unknown_gemm"): -+ GemmOperationBase(name) { -+ -+ this->description_.gemm_kind = GemmKind::kGrouped; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &op_args, -+ GemmGroupedConfiguration const *config) { -+ -+ op_args.problem_count = config->problem_count; -+ op_args.threadblock_count = config->threadblock_count; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &op_args, -+ GemmGroupedArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ -+ op_args.output_op = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice) { -+ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ -+ op_args.output_op = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ op_args.problem_sizes = arguments->problem_sizes; -+ -+ op_args.ptr_A = static_cast(arguments->ptr_A); -+ op_args.ptr_B = static_cast(arguments->ptr_B); -+ op_args.ptr_C = static_cast(arguments->ptr_C); -+ op_args.ptr_D = static_cast(arguments->ptr_D); -+ -+ op_args.lda = arguments->lda; -+ op_args.ldb = arguments->ldb; -+ op_args.ldc = arguments->ldc; -+ op_args.ldd = arguments->ldd; -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ GemmGroupedConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ GemmGroupedArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const 
*configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ uint64_t size = Operator::get_workspace_size(args); -+ -+ return size; -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ -+ status = op->initialize(args, device_workspace, stream); -+ -+ return status; -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = op->run(stream); -+ -+ return status; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/gemm_operation_3x.hpp b/3rdparty/cutlass/tools/library/src/gemm_operation_3x.hpp -new file mode 100644 -index 0000000..895de5b ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/gemm_operation_3x.hpp -@@ -0,0 +1,292 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines operations for all GEMM operation kinds in CUTLASS Library. -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+#include "cutlass/kernel_hardware_info.hpp" -+#include "cutlass/library/library.h" -+#include "library_internal.h" -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class GemmOperation3xBase : public Operation { -+public: -+ using Operator = Operator_; -+ using OperatorArguments = typename Operator::Arguments; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ // assuming all tensors use same type for StrideIndex -+ using StrideIndex = typename Operator::LayoutA::Index; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::CollectiveEpilogue::ElementCompute; -+ -+private: -+ -+ GemmDescription description_; -+ -+public: -+ -+ /// Constructor -+ GemmOperation3xBase(char const *name = "unknown_gemm", GemmKind gemm_kind_ = GemmKind::kGemm) { -+ -+ description_.name = name; -+ description_.provider = Provider::kCUTLASS; -+ description_.kind = OperationKind::kGemm; -+ description_.gemm_kind = gemm_kind_; -+ -+ description_.tile_description.threadblock_shape = make_Coord( -+ Operator::ThreadblockShape::kM, -+ Operator::ThreadblockShape::kN, -+ Operator::ThreadblockShape::kK); -+ -+ if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { -+ description_.tile_description.cluster_shape = make_Coord( -+ Operator::ClusterShape::kM, -+ Operator::ClusterShape::kN, -+ Operator::ClusterShape::kK); -+ } -+ -+ description_.tile_description.threadblock_stages = Operator::kStages; -+ -+ description_.tile_description.warp_count = make_Coord( -+ Operator::WarpCount::kM, -+ Operator::WarpCount::kN, -+ Operator::WarpCount::kK); -+ -+ description_.tile_description.math_instruction.instruction_shape = make_Coord( -+ Operator::InstructionShape::kM, -+ Operator::InstructionShape::kN, -+ Operator::InstructionShape::kK); -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.opcode_class = -+ OpcodeClassMap::kId; -+ -+ description_.tile_description.math_instruction.math_operation = -+ MathOperationMap::kId; -+ -+ description_.tile_description.minimum_compute_capability = -+ ArchMap::kMin; -+ -+ description_.tile_description.maximum_compute_capability = -+ ArchMap::kMax; -+ -+ description_.A = 
make_TensorDescription(Operator::kAlignmentA); -+ description_.B = make_TensorDescription(Operator::kAlignmentB); -+ description_.C = make_TensorDescription(Operator::kAlignmentC); -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ description_.split_k_mode = SplitKMode::kNone; -+ description_.transform_A = ComplexTransformMap::kId; -+ description_.transform_B = ComplexTransformMap::kId; -+ } -+ -+ /// Returns the description of the GEMM operation -+ virtual OperationDescription const & description() const { -+ return description_; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class GemmUniversal3xOperation : public GemmOperation3xBase { -+public: -+ -+ using Operator = Operator_; -+ using OperatorArguments = typename Operator::Arguments; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ using CollectiveMainloop = typename Operator::CollectiveMainloop; -+ using CollectiveEpilogue = typename Operator::CollectiveEpilogue; -+ using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; -+ -+public: -+ -+ /// Constructor -+ GemmUniversal3xOperation(char const *name = "unknown_gemm"): -+ GemmOperation3xBase(name, GemmKind::kUniversal) { -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { -+ // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides -+ // Do nothing here and construct kernel arguments in update_arguments_ instead -+ // We also cannot construct TMA descriptors without all the arguments available -+ -+ if (operator_args.hw_info.sm_count <= 0) { -+ operator_args.hw_info.sm_count = KernelHardwareInfo::query_device_multiprocessor_count(); -+ } -+ operator_args.mode = configuration->mode; -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, GemmUniversalArguments const *arguments) { -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename ThreadEpilogueOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta)); -+ operator_args.epilogue_params.thread_params = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice) { -+ typename ThreadEpilogueOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta)); -+ operator_args.epilogue_params.thread_params = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // TODO: type erase Arguments structure in 3.0 GEMM -+ operator_args.problem_shape = cute::make_shape( -+ arguments->problem_size.m(), -+ arguments->problem_size.n(), -+ arguments->problem_size.k(), -+ arguments->batch_count); -+ -+ // update arguments -+ operator_args.ptr_A = static_cast(arguments->A); -+ operator_args.ptr_B = static_cast(arguments->B); -+ operator_args.epilogue_params.ptr_C = static_cast(arguments->C); -+ 
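In the 3.x path above, the type-erased (M, N, K, batch_count) problem is re-expressed as a rank-4 cute shape whose last mode L is the batch count. A minimal fragment of that convention (the concrete sizes and the cute header are assumptions for illustration; a non-batched GEMM simply passes L = 1):

    #include <cute/tensor.hpp>

    // (M, N, K, L): L is the batch count; 1 means a single, non-batched GEMM.
    auto problem_shape = cute::make_shape(1024, 512, 256, 1);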
operator_args.epilogue_params.ptr_D = static_cast(arguments->D); -+ -+ operator_args.dA = cute::make_int_tuple_from( -+ arguments->lda, arguments->batch_stride_A); -+ operator_args.dB = cute::make_int_tuple_from( -+ arguments->ldb, arguments->batch_stride_B); -+ operator_args.epilogue_params.dC = cute::make_int_tuple_from( -+ arguments->ldc, arguments->batch_stride_C); -+ operator_args.epilogue_params.dD = operator_args.epilogue_params.dC; -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ Status can_implement( -+ void const *configuration_ptr, void const *arguments_ptr) const override { -+ -+ GemmUniversalArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ auto status = update_arguments_(args, arguments); -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ uint64_t get_host_workspace_size(void const *configuration) const override { -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ uint64_t get_device_workspace_size( -+ void const *configuration_ptr,void const *arguments_ptr) const override { -+ -+ OperatorArguments args; -+ auto status = update_arguments_( -+ args, static_cast(arguments_ptr)); -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ uint64_t size = Operator::get_workspace_size(args); -+ return size; -+ } -+ -+ /// Initializes the workspace -+ Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const override { -+ Operator *op = new (host_workspace) Operator; -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel -+ Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const override { -+ -+ OperatorArguments args; -+ Status status = update_arguments_(args, static_cast(arguments_ptr)); -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ // We need to call initialize() since we have to rebuild TMA desc for every new set of args -+ status = op->run(args, device_workspace, stream); -+ return status; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::library -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/handle.cu b/3rdparty/cutlass/tools/library/src/handle.cu -new file mode 100644 -index 0000000..fdfe251 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/handle.cu -@@ -0,0 +1,1172 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief CUTLASS Library handle. -+*/ -+#include -+#include -+#include -+ -+#include "cutlass/library/handle.h" -+#include "cutlass/library/singleton.h" -+#include "cutlass/library/util.h" -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Constructor -+Handle::Handle( -+ cudaStream_t stream, -+ size_t workspace_size -+): -+ provider_(Provider::kCUTLASS), -+ stream_(stream), -+ workspace_(nullptr), -+ workspace_size_(0), -+ scalar_pointer_mode_(ScalarPointerMode::kHost), -+ last_operation_(nullptr) { -+ -+ int device_idx = -1; -+ -+ cudaError_t error = cudaGetDevice(&device_idx); -+ if (error != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() failed"); -+ } -+ -+ error = cudaGetDeviceProperties(&device_, device_idx); -+ if (error != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ set_workspace_size(workspace_size); -+ -+ Singleton::get(); -+} -+ -+/// Destructor -+Handle::~Handle() { -+ if (workspace_) { -+ -+ if (workspace_) { -+ cudaFree(workspace_); -+ } -+ -+ workspace_ = nullptr; -+ workspace_size_ = 0; -+ } -+} -+ -+/// Move constructor -+Handle::Handle(Handle && handle) { -+ device_ = handle.device_; -+ workspace_size_ = handle.workspace_size_; -+ workspace_ = handle.workspace_; -+ stream_ = handle.stream_; -+ scalar_pointer_mode_ = handle.scalar_pointer_mode_; -+ -+ handle.workspace_ = nullptr; -+ handle.workspace_size_ = 0; -+} -+ -+/// Move assignment operator -+Handle & Handle::operator=(Handle && handle) { -+ -+ provider_ = handle.provider_; -+ device_ = handle.device_; -+ workspace_size_ = handle.workspace_size_; -+ workspace_ = handle.workspace_; -+ stream_ = handle.stream_; -+ scalar_pointer_mode_ = handle.scalar_pointer_mode_; -+ -+ handle.workspace_ = nullptr; -+ handle.workspace_size_ = 0; -+ -+ return *this; -+} -+ -+int Handle::compute_capability() const { -+ return device_.major * 10 + device_.minor; -+} -+ -+/// Sets the current CUDA stream -+void Handle::set_stream(cudaStream_t stream) { -+ stream_ = stream; -+} -+ -+/// Gets the current CUDA stream -+cudaStream_t 
Handle::get_stream() const { -+ return stream_; -+} -+ -+/// Gets the current provider -+Provider Handle::get_provider() const { -+ return provider_; -+} -+ -+/// Sets the provider of operations -+void Handle::set_provider(Provider provider) { -+ provider_ = provider; -+} -+ -+/// Gets the device workspace size -+size_t Handle::get_workspace_size() const { -+ return workspace_size_; -+} -+ -+/// Gets a pointer to the device workspace allocation in Global Memory -+void *Handle::get_workspace() const { -+ return workspace_; -+} -+ -+/// Sets the size of device workspace, invalidating previous calls to get_device_workspace() -+void Handle::set_workspace_size(size_t bytes) { -+ if (bytes != workspace_size_) { -+ -+ if (workspace_) { -+ cudaFree(workspace_); -+ } -+ -+ workspace_ = nullptr; -+ workspace_size_ = bytes; -+ -+ if (workspace_size_) { -+ -+ cudaError_t error = cudaMalloc((void **)&workspace_, workspace_size_); -+ -+ if (error != cudaSuccess) { -+ throw std::runtime_error("Failed to allocate workspace"); -+ } -+ } -+ } -+ -+ if (workspace_) { -+ cudaError_t error = cudaMemset(workspace_, 0, workspace_size_); -+ -+ if (error != cudaSuccess) { -+ throw std::runtime_error("Failed to clear workspace"); -+ } -+ } -+} -+ -+/// Gets the scalar pointer mode -+ScalarPointerMode Handle::get_scalar_pointer_mode() const { -+ return scalar_pointer_mode_; -+} -+ -+/// Sets the scalar pointer mode -+void Handle::set_scalar_pointer_mode(ScalarPointerMode mode) { -+ scalar_pointer_mode_ = mode; -+} -+ -+/// Gets the last operation -+Operation const *Handle::get_last_operation() const { -+ return last_operation_; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns the maximum required alignment for each operator -+static int maximum_alignment_requirement(GemmDescription const &desc) { -+ return std::max( -+ std::max(desc.A.alignment, desc.B.alignment), desc.C.alignment); -+} -+ -+/// Returns the largest alignment (in units of elements) the problem satisfies, starting from a -+/// given upper limit. -+static int gemm_problem_alignment( -+ int M, -+ int N, -+ int K, -+ NumericTypeID element_A, -+ void const *ptr_A, -+ int64_t lda, -+ int64_t batch_stride_A, -+ NumericTypeID element_B, -+ void const *ptr_B, -+ int64_t ldb, -+ int64_t batch_stride_B, -+ NumericTypeID element_C, -+ void const * ptr_C, -+ int64_t ldc, -+ int64_t batch_stride_C, -+ void const * ptr_D, -+ int64_t ldd, -+ int64_t batch_stride_D, -+ int max_alignment_in_bytes = 16 -+) { -+ -+ void const *pointers[] = { -+ ptr_A, ptr_B, ptr_C, ptr_D -+ }; -+ -+ int64_t extents[] = { -+ M, N, K, lda, ldb, ldc, ldd, batch_stride_A, batch_stride_B, batch_stride_C, batch_stride_D -+ }; -+ -+ NumericTypeID elements[] = { -+ element_A, element_B, element_C -+ }; -+ -+ for (; max_alignment_in_bytes > 0; max_alignment_in_bytes /= 2) { -+ -+ bool satisfied = true; -+ -+ // Can pointers satisfy this? 
-+ for (void const *ptr : pointers) { -+ std::uintptr_t int_ptr = reinterpret_cast(ptr); -+ -+ if (int_ptr % max_alignment_in_bytes) { -+ satisfied = false; -+ break; -+ } -+ } -+ -+ if (!satisfied) { -+ continue; -+ } -+ -+ // Compute the maximum alignment based on element data types -+ int max_element_alignment = 0; -+ -+ for (NumericTypeID type_id : elements) { -+ int element_alignment = max_alignment_in_bytes * 8 / library::sizeof_bits(type_id); -+ max_element_alignment = std::max(max_element_alignment, element_alignment); -+ } -+ -+ // Can the problem size and leading dimensions satisfy this? -+ for (int64_t extent : extents) { -+ if (extent % max_element_alignment) { -+ satisfied = false; -+ break; -+ } -+ } -+ -+ if (!satisfied) { -+ continue; -+ } -+ -+ // Yes -+ return max_element_alignment; -+ } -+ -+ // No alignment satisfies this problem -+ return 0; -+} -+ -+/// Find the best kernel in descending order of preference. -+static Operation const * find_gemm_operation( -+ GemmOperationFunctionalMap::const_iterator operators_it, -+ GemmPreferenceKey const preference_key) { -+ -+ auto cc_it = operators_it->second.upper_bound(preference_key); -+ -+ if (cc_it == operators_it->second.begin()) { -+ return nullptr; -+ } -+ -+ Operation const *operation = nullptr; -+ -+ // Search in descending order of compute capability -+ do { -+ --cc_it; -+ -+ // Search tile sizes in order, for now. -+ for (auto const * op : cc_it->second) { -+ -+ GemmDescription const &desc = static_cast(op->description()); -+ -+ int min_cc = desc.tile_description.minimum_compute_capability; -+ int max_cc = desc.tile_description.maximum_compute_capability; -+ -+ int op_alignment = maximum_alignment_requirement(desc); -+ -+ if ((min_cc <= preference_key.compute_capability) && -+ (preference_key.compute_capability <= max_cc) && -+ (op_alignment <= preference_key.alignment)) { -+ -+ operation = op; -+ break; -+ } -+ } -+ } while (!operation && cc_it != operators_it->second.begin()); -+ -+ return operation; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Executes a GEMM computation: D <= alpha * A*B + beta * C -+Status Handle::gemm( -+ -+ int M, /// GEMM M dimension -+ int N, /// GEMM N dimension -+ int K, /// GEMM K dimension -+ -+ NumericTypeID element_compute, /// Data type of internal accumulation -+ -+ NumericTypeID element_scalar, /// Data type of alpha/beta scalars -+ -+ void const *alpha, /// Pointer to alpha scalar -+ -+ NumericTypeID element_A, /// Data type of A matrix elements -+ LayoutTypeID layout_A, /// Layout of A matrix -+ ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices -+ -+ void const * ptr_A, /// Pointer to A matrix in Global Memory -+ int64_t lda, /// Leading dimension of A matrix -+ -+ NumericTypeID element_B, /// Data type of B matrix elements -+ LayoutTypeID layout_B, /// Layout of B matrix -+ ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices -+ -+ void const * ptr_B, /// Pointer to B matrix in Global Memory -+ int64_t ldb, /// Leading dimension of B matrix -+ -+ void const * beta, /// Pointer to beta scalar -+ -+ NumericTypeID element_C, /// Data type of C and D matrices -+ -+ void const * ptr_C, /// Pointer to C matrix -+ int64_t ldc, /// Leading dimension of C matrix -+ -+ void * ptr_D, /// Pointer to D matrix -+ int64_t ldd /// Leading dimension of D matrix -+) { -+ -+ // -+ // Find the operation -+ // 
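Before the handle looks a kernel up, gemm_problem_alignment() above reduces a byte-level bound to an alignment in elements. A worked example with assumed values (half-precision operands, 16-byte-aligned pointers):

    // Illustrative arithmetic only.
    int max_alignment_in_bytes = 16;                                   // starting bound
    int bits_per_element = 16;                                         // e.g. NumericTypeID::kF16
    int element_alignment = max_alignment_in_bytes * 8 / bits_per_element;  // = 8 elements
    // If every pointer is 16-byte aligned and M, N, K and all leading dimensions are
    // multiples of 8, the problem reports alignment 8; otherwise the byte bound is halved
    // and the test repeats, returning 0 if no alignment can be satisfied.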
-+ -+ GemmFunctionalKey key( -+ provider_, -+ GemmKind::kGemm, -+ element_compute, -+ element_scalar, -+ element_A, -+ layout_A, -+ transform_A, -+ element_B, -+ layout_B, -+ transform_B, -+ element_C -+ ); -+ -+ auto operators_it = Singleton::get().operation_table.gemm_operations.find(key); -+ -+ if (operators_it == Singleton::get().operation_table.gemm_operations.end()) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ if (operators_it->second.empty()) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ // -+ // Compute the largest alignment restriction the kernel can satisfy. -+ // -+ -+ // Maximum alignment expectation among all kernels (in units of bytes) -+ int const kMaximumAlignmentSize = 16; -+ -+ int alignment = gemm_problem_alignment( -+ M, N, K, -+ element_A, ptr_A, lda, 0, -+ element_B, ptr_B, ldb, 0, -+ element_C, ptr_C, ldc, 0, -+ ptr_D, ldd, 0, kMaximumAlignmentSize -+ ); -+ -+ // -+ // Find the best kernel in descending order of preference. -+ // -+ -+ GemmPreferenceKey preference_key(compute_capability(), alignment); -+ -+ Operation const *operation = find_gemm_operation(operators_it, preference_key); -+ -+ if (!operation) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ last_operation_ = operation; -+ -+ // -+ // Configure operation -+ // -+ -+ GemmConfiguration configuration{ -+ {M, N, K}, -+ lda, -+ ldb, -+ ldc, -+ ldd, -+ 1 -+ }; -+ -+ // Query host work space size -+ uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration); -+ -+ if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ char host_workspace[kHostWorkspaceSize]; -+ -+ // Query device workspace size -+ uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration); -+ -+ if (uint64_t(workspace_size_) < device_workspace_size_needed) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ // Initialize host and device workspaces -+ Status status = operation->initialize( -+ &configuration, -+ host_workspace, -+ workspace_, -+ stream_); -+ -+ if (status != cutlass::Status::kSuccess) { -+ return status; -+ } -+ -+ // Run the operator -+ GemmArguments arguments{ -+ ptr_A, -+ ptr_B, -+ ptr_C, -+ ptr_D, -+ alpha, -+ beta, -+ scalar_pointer_mode_ -+ }; -+ -+ return operation->run(&arguments, host_workspace, workspace_, stream_); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Executes a GEMM computation: D <= alpha * A*B + beta * C. -+// -+// Supports batched-strided, batched array or split-K serial or split-K parallel. 
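For the batched-strided mode mentioned in the comment above (handled by gemm_universal(), which follows), the batch strides are simply the element distance between consecutive matrices of a batch. A sketch under the assumption of packed, column-major operands, using the parameter names from the surrounding signatures:

    // Packed, column-major batches (assumed layout): distance between batch i and i+1.
    int64_t batch_stride_A = lda * static_cast<int64_t>(K);   // A is M x K
    int64_t batch_stride_B = ldb * static_cast<int64_t>(N);   // B is K x N
    int64_t batch_stride_C = ldc * static_cast<int64_t>(N);   // C is M x N
    int64_t batch_stride_D = ldd * static_cast<int64_t>(N);   // D is M x N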
-+// -+Status Handle::gemm_universal( -+ -+ GemmUniversalMode mode, /// indicates the mode in which the kUniversal GEMM is launched -+ -+ int M, /// GEMM M dimension -+ int N, /// GEMM N dimension -+ int K, /// GEMM K dimension -+ -+ NumericTypeID element_compute, /// Data type of internal accumulation -+ -+ NumericTypeID element_scalar, /// Data type of alpha/beta scalars -+ -+ void const *alpha, /// Pointer to alpha scalar -+ -+ NumericTypeID element_A, /// Data type of A matrix elements -+ LayoutTypeID layout_A, /// Layout of A matrix -+ ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices -+ -+ void const * ptr_A, /// Pointer to A matrix in Global Memory -+ int64_t lda, /// Leading dimension of A matrix -+ -+ NumericTypeID element_B, /// Data type of B matrix elements -+ LayoutTypeID layout_B, /// Layout of B matrix -+ ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices -+ -+ void const * ptr_B, /// Pointer to B matrix in Global Memory -+ int64_t ldb, /// Leading dimension of B matrix -+ -+ void const * beta, /// Pointer to beta scalar -+ -+ NumericTypeID element_C, /// Data type of C and D matrices -+ -+ void const * ptr_C, /// Pointer to C matrix -+ int64_t ldc, /// Leading dimension of C matrix -+ -+ void * ptr_D, /// Pointer to D matrix -+ int64_t ldd, /// Leading dimension of D matrix -+ -+ int batch_count, /// Batch count or number of split-K slices -+ -+ int64_t batch_stride_A, /// Batch stride of A operand -+ int64_t batch_stride_B, /// Batch stride of B operand -+ int64_t batch_stride_C, /// Batch stride of C operand -+ int64_t batch_stride_D /// Batch stride of D operand -+) { -+ -+ // -+ // Find the operation -+ // -+ -+ GemmFunctionalKey key( -+ provider_, -+ GemmKind::kUniversal, -+ element_compute, -+ element_scalar, -+ element_A, -+ layout_A, -+ transform_A, -+ element_B, -+ layout_B, -+ transform_B, -+ element_C -+ ); -+ -+ auto operators_it = Singleton::get().operation_table.gemm_operations.find(key); -+ -+ if (operators_it == Singleton::get().operation_table.gemm_operations.end()) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ if (operators_it->second.empty()) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ // -+ // Compute the largest alignment restriction the kernel can satisfy. -+ // -+ -+ // Maximum alignment expectation among all kernels (in units of bytes) -+ int const kMaximumAlignmentSize = 16; -+ -+ void const *ptr_A_check = ptr_A; -+ void const *ptr_B_check = ptr_B; -+ void const *ptr_C_check = ptr_C; -+ void * ptr_D_check = ptr_D; -+ -+ // Ignore alignment of pointers to pointers. We can't check this from the host, -+ // as each batch index has its own pointer in device memory. -+ if (mode == GemmUniversalMode::kArray) { -+ ptr_A_check = nullptr; -+ ptr_B_check = nullptr; -+ ptr_C_check = nullptr; -+ ptr_D_check = nullptr; -+ } -+ -+ int alignment = gemm_problem_alignment( -+ M, N, K, -+ element_A, ptr_A_check, lda, 0, -+ element_B, ptr_B_check, ldb, 0, -+ element_C, ptr_C_check, ldc, 0, -+ ptr_D_check, ldd, 0, kMaximumAlignmentSize -+ ); -+ -+ // -+ // Find the best kernel in descending order of preference. 
-+ // -+ -+ GemmPreferenceKey preference_key(compute_capability(), alignment); -+ -+ Operation const *operation = find_gemm_operation(operators_it, preference_key); -+ -+ if (!operation) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ last_operation_ = operation; -+ -+ // -+ // Configure operation -+ // -+ -+ GemmUniversalConfiguration configuration{ -+ mode, -+ {M, N, K}, -+ batch_count, -+ lda, -+ ldb, -+ ldc, -+ ldd -+ }; -+ -+ // Query host work space size -+ uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration); -+ -+ if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ char host_workspace[kHostWorkspaceSize]; -+ -+ GemmUniversalArguments arguments{ -+ {M, N, K}, -+ batch_count, -+ ptr_A, -+ ptr_B, -+ ptr_C, -+ ptr_D, -+ alpha, -+ beta, -+ scalar_pointer_mode_, -+ lda, -+ ldb, -+ ldc, -+ ldd, -+ batch_stride_A, -+ batch_stride_B, -+ batch_stride_C, -+ batch_stride_D -+ }; -+ -+ // Query device workspace size -+ uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration, &arguments); -+ -+ if (uint64_t(workspace_size_) < device_workspace_size_needed) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ // Initialize host and device workspaces -+ Status status = operation->initialize( -+ &configuration, -+ host_workspace, -+ workspace_, -+ stream_); -+ -+ if (status != cutlass::Status::kSuccess) { -+ return status; -+ } -+ -+ // Run the operator -+ -+ return operation->run(&arguments, host_workspace, workspace_, stream_); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Planar complex GEMM -+Status Handle::gemm_planar_complex( -+ -+ int M, /// GEMM M dimension -+ int N, /// GEMM N dimension -+ int K, /// GEMM K dimension -+ -+ NumericTypeID element_compute, /// Data type of internal accumulation -+ -+ NumericTypeID element_scalar, /// Data type of alpha/beta scalars -+ -+ void const *alpha, /// Pointer to alpha scalar -+ -+ NumericTypeID element_A, /// Data type of A matrix elements -+ LayoutTypeID layout_A, /// Layout of A matrix -+ ComplexTransform transform_A, /// Complex transformation applied to A matrix -+ -+ void const * ptr_A_real, /// Pointer to real part of A matrix -+ void const * ptr_A_imag, /// Pointer to imaginary part of A matrix -+ int64_t lda_real, /// Leading dimension of real part of A matrix -+ int64_t lda_imag, /// Leading dimension of imaginary part of A matrix -+ -+ NumericTypeID element_B, /// Data type of B matrix elements -+ LayoutTypeID layout_B, /// Layout of B matrix -+ ComplexTransform transform_B, /// Complex transformation applied to B matrix -+ -+ void const * ptr_B_real, /// Pointer to real part of B matrix -+ void const * ptr_B_imag, /// Pointer to imaginary part of B matrix -+ int64_t ldb_real, /// Leading dimension of real part of B matrix -+ int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix -+ -+ void const * beta, /// Pointer to beta scalar -+ -+ NumericTypeID element_C, /// Data type of C and D matrix -+ -+ void const * ptr_C_real, /// Pointer to real part of C matrix -+ void const * ptr_C_imag, /// Pointer to imaginary part of C matrix -+ int64_t ldc_real, /// Leading dimension of real part of C matrix -+ int64_t ldc_imag, /// Leading dimension of imaginary part of C matrix -+ -+ void * ptr_D_real, /// Pointer to real part of D matrix -+ void * ptr_D_imag, /// Pointer to imaginary part of D matrix -+ 
int64_t ldd_real, /// Leading dimension of real part of D matrix -+ int64_t ldd_imag, /// Leading dimension of imaginary part of D matrix -+ -+ int batch_count, /// Number of batched GEMMs to execute -+ -+ int64_t batch_stride_A_real, -+ int64_t batch_stride_A_imag, -+ -+ int64_t batch_stride_B_real, -+ int64_t batch_stride_B_imag, -+ -+ int64_t batch_stride_C_real, -+ int64_t batch_stride_C_imag, -+ -+ int64_t batch_stride_D_real, -+ int64_t batch_stride_D_imag -+) { -+ -+ // -+ // Find the operation -+ // -+ -+ GemmFunctionalKey key( -+ provider_, -+ GemmKind::kPlanarComplex, -+ element_compute, -+ element_scalar, -+ element_A, -+ layout_A, -+ transform_A, -+ element_B, -+ layout_B, -+ transform_B, -+ element_C -+ ); -+ -+ auto operators_it = Singleton::get().operation_table.gemm_operations.find(key); -+ -+ if (operators_it == Singleton::get().operation_table.gemm_operations.end()) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ if (operators_it->second.empty()) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ // -+ // Compute the largest alignment restriction the kernel can satisfy. -+ // -+ -+ // Maximum alignment expectation among all kernels (in units of bytes) -+ int const kMaximumAlignmentSize = 16; -+ -+ int alignment = std::max( -+ gemm_problem_alignment( -+ M, N, K, -+ element_A, ptr_A_real, lda_real, batch_stride_A_real, -+ element_B, ptr_B_real, ldb_real, batch_stride_B_real, -+ element_C, ptr_C_real, ldc_real, batch_stride_C_real, -+ ptr_D_real, ldd_real, batch_stride_D_real, kMaximumAlignmentSize -+ ), -+ gemm_problem_alignment( -+ M, N, K, -+ element_A, ptr_A_imag, lda_imag, batch_stride_A_imag, -+ element_B, ptr_B_imag, ldb_imag, batch_stride_B_imag, -+ element_C, ptr_C_imag, ldc_imag, batch_stride_C_imag, -+ ptr_D_imag, ldd_imag, batch_stride_D_imag, kMaximumAlignmentSize -+ ) -+ ); -+ -+ // -+ // Find the best kernel in descending order of preference. 
-+ // -+ -+ GemmPreferenceKey preference_key(compute_capability(), alignment); -+ -+ Operation const *operation = find_gemm_operation(operators_it, preference_key); -+ -+ if (!operation) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ last_operation_ = operation; -+ -+ // -+ // Configure operation -+ // -+ -+ GemmPlanarComplexConfiguration configuration{ -+ GemmUniversalMode::kBatched, -+ {M, N, K}, -+ batch_count, -+ lda_real, -+ lda_imag, -+ ldb_real, -+ ldb_imag, -+ ldc_real, -+ ldc_imag, -+ ldd_real, -+ ldd_imag -+ }; -+ -+ // Query host work space size -+ uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration); -+ -+ if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ char host_workspace[kHostWorkspaceSize]; -+ -+ // Query device workspace size -+ uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration); -+ -+ if (uint64_t(workspace_size_) < device_workspace_size_needed) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ // Initialize host and device workspaces -+ Status status = operation->initialize( -+ &configuration, -+ host_workspace, -+ workspace_, -+ stream_); -+ -+ if (status != cutlass::Status::kSuccess) { -+ return status; -+ } -+ -+ // Run the operator -+ GemmPlanarComplexArguments arguments{ -+ ptr_A_real, -+ ptr_A_imag, -+ ptr_B_real, -+ ptr_B_imag, -+ ptr_C_real, -+ ptr_C_imag, -+ ptr_D_real, -+ ptr_D_imag, -+ alpha, -+ beta, -+ scalar_pointer_mode_, -+ batch_stride_A_real, -+ batch_stride_A_imag, -+ batch_stride_B_real, -+ batch_stride_B_imag, -+ batch_stride_C_real, -+ batch_stride_C_imag, -+ batch_stride_D_real, -+ batch_stride_D_imag -+ }; -+ -+ return operation->run(&arguments, host_workspace, workspace_, stream_); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Planar complex batched GEMM loading pointers from arrays in global memory -+Status Handle::gemm_planar_complex_array( -+ -+ int expected_M, /// Expected GEMM M dimension (used for sizing CUDA grid) -+ int expected_N, /// Expected GEMM N dimension (used for sizing CUDA grid) -+ int expected_K, /// Expected GEMM K dimension -+ int batch_count, /// Number of independent GEMM computations to execute -+ -+ int const *M, /// Array containing the GEMM M dimension for each batch index -+ int const *N, /// Array containing the GEMM N dimension for each batch index -+ int const *K, /// Array containing the GEMM K dimension for each batch index -+ -+ NumericTypeID element_compute, /// Data type of internal accumulation -+ -+ NumericTypeID element_scalar, /// Data type of alpha/beta scalars -+ -+ void const *alpha, /// Pointer to alpha scalar -+ -+ NumericTypeID element_A, /// Data type of A matrix elements -+ LayoutTypeID layout_A, /// Layout of A matrix -+ ComplexTransform transform_A, /// Complex transformation applied to A matrix -+ -+ void const * const * ptr_A_real, /// Pointer to array containing pointers to real part of A matrices -+ void const * const * ptr_A_imag, /// Pointer to array containing pointers to imaginary part of A matrices -+ -+ int64_t lda_real, /// Leading dimension of real part of A matrix -+ int64_t lda_imag, /// Leading dimension of imaginary part of A matrix -+ -+ NumericTypeID element_B, /// Data type of B matrix elements -+ LayoutTypeID layout_B, /// Layout of B matrix -+ ComplexTransform transform_B, /// Complex transformation applied to B matrix -+ -+ void const 
* const * ptr_B_real, /// Pointer to array containing pointers to real part of B matrices -+ void const * const * ptr_B_imag, /// Pointer to array containing pointers to imaginary part of B matrices -+ -+ int64_t ldb_real, /// Leading dimension of real part of B matrix -+ int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix -+ -+ void const * beta, /// Pointer to beta scalar -+ -+ NumericTypeID element_C, /// Data type of C and D matrix -+ -+ void const * const * ptr_C_real, /// Pointer to array containing pointers to real part of C matrices -+ void const * const * ptr_C_imag, /// Pointer to array containing poitners to imaginary part of C matrices -+ -+ int64_t ldc_real, /// Leading dimension of real part of C matrix -+ int64_t ldc_imag, /// Leading dimension of imaginary part of C matrix -+ -+ void * const * ptr_D_real, /// Pointer to array containing pointers to real part of D matrices -+ void * const * ptr_D_imag, /// Pointer to array containing poitners to imaginary part of D matrices -+ -+ int64_t ldd_real, /// Leading dimension of real part of D matrix -+ int64_t ldd_imag /// Leading dimension of imaginary part of D matrix -+) { -+ -+ // -+ // Find the operation -+ // -+ -+ GemmFunctionalKey key( -+ provider_, -+ GemmKind::kPlanarComplexArray, -+ element_compute, -+ element_scalar, -+ element_A, -+ layout_A, -+ transform_A, -+ element_B, -+ layout_B, -+ transform_B, -+ element_C -+ ); -+ -+ auto operators_it = Singleton::get().operation_table.gemm_operations.find(key); -+ -+ if (operators_it == Singleton::get().operation_table.gemm_operations.end()) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ if (operators_it->second.empty()) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ // -+ // Compute the largest alignment restriction the kernel can satisfy. -+ // -+ -+ // Maximum alignment expectation among all kernels (in units of bytes) -+ int const kMaximumAlignmentSize = 16; -+ -+ int alignment = std::max( -+ gemm_problem_alignment( -+ expected_M, expected_N, expected_K, -+ element_A, nullptr, lda_real, 0, -+ element_B, nullptr, ldb_real, 0, -+ element_C, nullptr, ldc_real, 0, -+ nullptr, ldd_real, 0, kMaximumAlignmentSize -+ ), -+ gemm_problem_alignment( -+ expected_M, expected_N, expected_K, -+ element_A, nullptr, lda_imag, 0, -+ element_B, nullptr, ldb_imag, 0, -+ element_C, nullptr, ldc_imag, 0, -+ nullptr, ldd_imag, 0, kMaximumAlignmentSize -+ ) -+ ); -+ -+ // -+ // Find the best kernel in descending order of preference. 
-+ // -+ -+ GemmPreferenceKey preference_key(compute_capability(), alignment); -+ -+ Operation const *operation = find_gemm_operation(operators_it, preference_key); -+ -+ if (!operation) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ last_operation_ = operation; -+ -+ // -+ // Configure operation -+ // -+ -+ GemmPlanarComplexArrayConfiguration configuration{ -+ {expected_M, expected_N, expected_K}, -+ batch_count, -+ lda_real, -+ lda_imag, -+ ldb_real, -+ ldb_imag, -+ ldc_real, -+ ldc_imag, -+ ldd_real, -+ ldd_imag -+ }; -+ -+ // Query host work space size -+ uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration); -+ -+ if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ char host_workspace[kHostWorkspaceSize]; -+ -+ // Query device workspace size -+ uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration); -+ -+ if (uint64_t(workspace_size_) < device_workspace_size_needed) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ // Initialize host and device workspaces -+ Status status = operation->initialize( -+ &configuration, -+ host_workspace, -+ workspace_, -+ stream_); -+ -+ if (status != cutlass::Status::kSuccess) { -+ return status; -+ } -+ -+ // Run the operator -+ GemmPlanarComplexArrayArguments arguments{ -+ M, N, K, -+ ptr_A_real, -+ ptr_A_imag, -+ ptr_B_real, -+ ptr_B_imag, -+ ptr_C_real, -+ ptr_C_imag, -+ ptr_D_real, -+ ptr_D_imag, -+ alpha, -+ beta, -+ scalar_pointer_mode_ -+ }; -+ -+ return operation->run(&arguments, host_workspace, workspace_, stream_); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Finds conv operation instances with Conv::ElementC = Reduction::ElementWorkspace -+Operation const* find_conv_operation_for_parallel_reduction(Operation const *operation) { -+ -+ ConvDescription const &conv_desc = -+ static_cast(operation->description()); -+ -+ // if the curren conv operation accumulator and output data type match return operation -+ if(conv_desc.tile_description.math_instruction.element_accumulator == conv_desc.C.element) { -+ return operation; -+ } -+ -+ // find conv operation to match conv output and reduction workspace data type -+ ConvFunctionalKey key( -+ library::Provider::kCUTLASS, -+ conv_desc.conv_kind, -+ conv_desc.A.element, -+ conv_desc.A.layout, -+ conv_desc.B.element, -+ conv_desc.B.layout, -+ conv_desc.tile_description.math_instruction.element_accumulator, -+ conv_desc.C.layout, -+ conv_desc.tile_description.math_instruction.element_accumulator, -+ conv_desc.element_epilogue); -+ -+ // conv operation table for conv2d or conv3d -+ auto conv_operations = (conv_desc.kind == OperationKind::kConv2d) ? 
-+ Singleton::get().operation_table.conv2d_operations : -+ Singleton::get().operation_table.conv3d_operations; -+ -+ // find ConvFunctionalKey in convolution operation table -+ auto operators_it = conv_operations.find(key); -+ -+ if (operators_it == conv_operations.end()) { -+ return nullptr; -+ } -+ -+ if (operators_it->second.empty()) { -+ return nullptr; -+ } -+ -+ // conv operation for same compute capability and iterator algorithm -+ ConvPreferenceKey preference_key( -+ conv_desc.tile_description.minimum_compute_capability, -+ conv_desc.iterator_algorithm); -+ -+ auto it = operators_it->second.find(preference_key); -+ -+ if(it == operators_it->second.end()) { -+ return nullptr; -+ } -+ -+ // return matching conv opertion (same tile sizes and instruction) -+ for (auto op : it->second) { -+ if (op->description().tile_description == operation->description().tile_description) { -+ return op; -+ } -+ } -+ -+ return nullptr; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Finds gemm operation instances with Gemm::ElementC = Reduction::ElementWorkspace -+Operation const* find_gemm_operation_for_parallel_reduction(Operation const *operation) { -+ -+ GemmDescription const &gemm_desc = -+ static_cast(operation->description()); -+ -+ // if the curren gemm operation accumulator and output data type match return operation -+ if(gemm_desc.tile_description.math_instruction.element_accumulator == gemm_desc.C.element) { -+ return operation; -+ } -+ -+ // find gemm operation to match gemm output and reduction workspace data type -+ GemmFunctionalKey key( -+ library::Provider::kCUTLASS, -+ gemm_desc.gemm_kind, -+ gemm_desc.tile_description.math_instruction.element_accumulator, -+ gemm_desc.element_epilogue, -+ gemm_desc.A.element, -+ gemm_desc.A.layout, -+ gemm_desc.transform_A, -+ gemm_desc.B.element, -+ gemm_desc.B.layout, -+ gemm_desc.transform_B, -+ gemm_desc.tile_description.math_instruction.element_accumulator); -+ -+ // gemm operation table -+ auto gemm_operations = Singleton::get().operation_table.gemm_operations; -+ -+ // find ConvFunctionalKey in gemm operation table -+ auto operators_it = gemm_operations.find(key); -+ -+ if (operators_it == gemm_operations.end()) { -+ return nullptr; -+ } -+ -+ if (operators_it->second.empty()) { -+ return nullptr; -+ } -+ -+ // A and B uses the same alignment in the generator.py -+ int alignment = gemm_desc.A.alignment; -+ -+ // gemm operation for same compute capability and iterator algorithm -+ GemmPreferenceKey preference_key( -+ gemm_desc.tile_description.minimum_compute_capability, -+ alignment); -+ -+ return find_gemm_operation(operators_it, preference_key); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/library_internal.h b/3rdparty/cutlass/tools/library/src/library_internal.h -new file mode 100644 -index 0000000..e9739e3 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/library_internal.h -@@ -0,0 +1,356 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ -+ \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. 
-+ -+ Generally, -+ -+ description - compile-time constant parameters used to instantiate an operation -+ -+ configuration - runtime parameters with computationally expensive initialization -+ -+ arguments - runtime parameters that may be passed to an initialized operation with low -+ computational overhead -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/mma.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/library/library.h" -+#include "cutlass/library/arch_mappings.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template struct NumericTypeMap; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kB1; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kS4; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kS8; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kS16; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kS32; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kS64; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kU4; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kU8; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kU16; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kU32; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kU64; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kF16; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kF32; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kF64; -+}; -+ -+template <> struct NumericTypeMap > { -+ static NumericTypeID const kId = NumericTypeID::kCF16; -+}; -+ -+template <> struct NumericTypeMap > { -+ static NumericTypeID const kId = NumericTypeID::kCF32; -+}; -+ -+template <> struct NumericTypeMap > { -+ static NumericTypeID const kId = NumericTypeID::kCF64; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kBF16; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kTF32; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template struct MathOperationMap { -+ static MathOperationID const kId = MathOperationID::kInvalid; -+}; -+ -+template <> struct MathOperationMap { -+ static MathOperationID const kId = MathOperationID::kMultiplyAdd; -+}; -+ -+template <> struct MathOperationMap { -+ static MathOperationID const kId = MathOperationID::kMultiplyAddFastBF16; -+}; -+ -+template <> struct MathOperationMap { -+ static MathOperationID const kId = MathOperationID::kMultiplyAddFastF16; -+}; -+ -+template <> struct MathOperationMap { -+ static MathOperationID const kId = 
MathOperationID::kMultiplyAddSaturate; -+}; -+ -+template <> struct MathOperationMap { -+ static MathOperationID const kId = MathOperationID::kMultiplyAddComplex; -+}; -+ -+template <> struct MathOperationMap { -+ static MathOperationID const kId = MathOperationID::kMultiplyAddGaussianComplex; -+}; -+ -+template <> struct MathOperationMap { -+ static MathOperationID const kId = MathOperationID::kXorPopc; -+}; -+ -+ -+template <> struct MathOperationMap { -+ static MathOperationID const kId = MathOperationID::kMultiplyAddFastF32; -+}; -+ -+template <> struct MathOperationMap { -+ static MathOperationID const kId = MathOperationID::kMultiplyAddComplexFastF32; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template struct LayoutMap; -+ -+template <> struct LayoutMap { -+ static LayoutTypeID const kId = LayoutTypeID::kColumnMajor; -+}; -+ -+template <> struct LayoutMap { -+ static LayoutTypeID const kId = LayoutTypeID::kRowMajor; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK2; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK2; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK4; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK4; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK16; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK16; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK32; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK32; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK64; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK64; -+}; -+ -+template <> struct LayoutMap { -+ static LayoutTypeID const kId = LayoutTypeID::kTensorNHWC; -+}; -+ -+template <> struct LayoutMap { -+ static LayoutTypeID const kId = LayoutTypeID::kTensorNDHWC; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kTensorNC32HW32; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kTensorNC64HW64; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kTensorC32RSK32; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kTensorC64RSK64; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template struct OpcodeClassMap; -+ -+template <> struct OpcodeClassMap { -+ static OpcodeClassID const kId = OpcodeClassID::kSimt; -+}; -+ -+template <> struct OpcodeClassMap { -+ static OpcodeClassID const kId = OpcodeClassID::kTensorOp; -+}; -+ -+template <> struct OpcodeClassMap { -+ static OpcodeClassID const kId = OpcodeClassID::kWmmaTensorOp; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template struct ComplexTransformMap; -+ -+template <> struct ComplexTransformMap { -+ static cutlass::library::ComplexTransform const kId = 
cutlass::library::ComplexTransform::kNone; -+}; -+ -+template <> struct ComplexTransformMap { -+ static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kConjugate; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template struct ConvModeMap; -+ -+template <> struct ConvModeMap { -+ static ConvModeID const kId = ConvModeID::kCrossCorrelation; -+}; -+ -+template <> struct ConvModeMap { -+ static ConvModeID const kId = ConvModeID::kConvolution; -+}; -+ -+ -+template struct ConvKindMap; -+ -+template <> struct ConvKindMap { -+ static ConvKind const kId = ConvKind::kFprop; -+}; -+ -+template <> struct ConvKindMap { -+ static ConvKind const kId = ConvKind::kDgrad; -+}; -+ -+template <> struct ConvKindMap { -+ static ConvKind const kId = ConvKind::kWgrad; -+}; -+ -+ -+template struct IteratorAlgorithmMap; -+ -+template <> struct IteratorAlgorithmMap { -+ static IteratorAlgorithmID const kId = IteratorAlgorithmID::kAnalytic; -+}; -+ -+template <> struct IteratorAlgorithmMap { -+ static IteratorAlgorithmID const kId = IteratorAlgorithmID::kOptimized; -+}; -+ -+template <> struct IteratorAlgorithmMap { -+ static IteratorAlgorithmID const kId = IteratorAlgorithmID::kFixedChannels; -+}; -+ -+template <> struct IteratorAlgorithmMap { -+ static IteratorAlgorithmID const kId = IteratorAlgorithmID::kFewChannels; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+TensorDescription make_TensorDescription(int alignment = 1) { -+ TensorDescription desc; -+ -+ desc.element = NumericTypeMap::kId; -+ desc.layout = LayoutMap::kId; -+ desc.alignment = alignment; -+ desc.log_extent_range = int(sizeof(typename Layout::TensorCoord::Index) - 1) * 8; -+ desc.log_stride_range = int(sizeof(typename Layout::Stride::Index) - 1) * 8; -+ -+ return desc; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/operation_table.cu b/3rdparty/cutlass/tools/library/src/operation_table.cu -new file mode 100644 -index 0000000..d3799c3 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/operation_table.cu -@@ -0,0 +1,143 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* -+ \file -+ \brief Defines a data structure in which a set of functionally equivalent library::Operation -+ instances may be queried. -+*/ -+ -+#include "cutlass/library/operation_table.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+void OperationTable::append(Manifest const &manifest) { -+ -+ // Insert operations into appropriate data structure -+ for (auto const & operation : manifest) { -+ -+ OperationDescription const &desc = operation->description(); -+ -+ // insert all gemm operation into operation table -+ if (desc.kind == OperationKind::kGemm) { -+ GemmDescription const &gemm_desc = static_cast(desc); -+ -+ -+ GemmFunctionalKey functional_key( -+ gemm_desc.provider, -+ gemm_desc.gemm_kind, -+ gemm_desc.tile_description.math_instruction.element_accumulator, -+ gemm_desc.element_epilogue, -+ gemm_desc.A.element, -+ gemm_desc.A.layout, -+ gemm_desc.transform_A, -+ gemm_desc.B.element, -+ gemm_desc.B.layout, -+ gemm_desc.transform_B, -+ gemm_desc.C.element -+ ); -+ -+ Operation const *op = operation.get(); -+ -+ int cc = gemm_desc.tile_description.minimum_compute_capability; -+ -+ int alignment = std::max(std::max( -+ gemm_desc.A.alignment, gemm_desc.B.alignment), gemm_desc.C.alignment); -+ -+ GemmPreferenceKey preference_key(cc, alignment); -+ -+ gemm_operations[functional_key][preference_key].push_back(op); -+ } -+ -+ // insert all conv2d or conv3d operation into operation table -+ if (desc.kind == OperationKind::kConv2d || desc.kind == OperationKind::kConv3d) { -+ auto &conv_desc = static_cast(desc); -+ -+ ConvFunctionalKey functional_key( -+ conv_desc.provider, -+ conv_desc.conv_kind, -+ conv_desc.A.element, -+ conv_desc.A.layout, -+ conv_desc.B.element, -+ conv_desc.B.layout, -+ conv_desc.C.element, -+ conv_desc.C.layout, -+ conv_desc.tile_description.math_instruction.element_accumulator, -+ conv_desc.element_epilogue -+ ); -+ -+ Operation const *op = operation.get(); -+ -+ int cc = conv_desc.tile_description.minimum_compute_capability; -+ -+ ConvPreferenceKey preference_key(cc, conv_desc.iterator_algorithm); -+ -+ // insert conv operation to conv2d_operations or conv3d_operations map -+ (desc.kind == OperationKind::kConv2d) ? 
-+ conv2d_operations[functional_key][preference_key].push_back(op) : -+ conv3d_operations[functional_key][preference_key].push_back(op); -+ } -+ -+ // insert all reduction operation into operation table -+ if (desc.kind == OperationKind::kReduction) { -+ auto &reduce_desc = static_cast(desc); -+ -+ ReductionFunctionalKey functional_key( -+ reduce_desc.provider, -+ reduce_desc.element_workspace, -+ reduce_desc.tile_description.math_instruction.element_accumulator, -+ reduce_desc.element_output, -+ reduce_desc.element_epilogue, -+ library::MathOperationID::kAdd, -+ library::EpilogueKind::kLinearCombination -+ ); -+ -+ Operation const *op = operation.get(); -+ -+ reduction_operations[functional_key] = op; -+ -+ } -+ -+ } -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/library/src/rank_2k_operation.h b/3rdparty/cutlass/tools/library/src/rank_2k_operation.h -new file mode 100644 -index 0000000..d6e0dca ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/rank_2k_operation.h -@@ -0,0 +1,373 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines operations for all Rank 2K operation kinds (Syr2k, Her2k) -+ in CUTLASS Library. 
-+ -+ -+*/ -+ -+#pragma once -+#include -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/gemm/kernel/default_rank_2k_universal.h" -+ -+#include "cutlass/library/library.h" -+#include "library_internal.h" -+#include "cutlass/core_io.h" -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class Rank2KOperationBase : public Operation { -+public: -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ static BlasMode const kBlasMode = Operator::kBlasMode; -+ static int const kUpdateRank = Operator::kUpdateRank; -+ static FillMode const kFillModeC = Operator::kFillModeC; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+protected: -+ -+ /// -+ RankKDescription description_; -+ -+public: -+ -+ /// Constructor -+ Rank2KOperationBase(char const *name = "unknown_rank_k") { -+ -+ description_.name = name; -+ description_.provider = Provider::kCUTLASS; -+ description_.rank_k_kind = RankKKind::kUniversal; -+ description_.fill_mode = kFillModeC; -+ description_.blas_mode = kBlasMode; -+ description_.num_ranks = kUpdateRank; -+ -+ description_.kind = OperationKind::kRank2K; -+ -+ description_.tile_description.threadblock_shape = make_Coord( -+ Operator::ThreadblockShape::kM, -+ Operator::ThreadblockShape::kN, -+ Operator::ThreadblockShape::kK); -+ -+ description_.tile_description.threadblock_stages = Operator::kStages; -+ -+ description_.tile_description.warp_count = make_Coord( -+ Operator::Rank2Kkernel::WarpCount::kM, -+ Operator::Rank2Kkernel::WarpCount::kN, -+ Operator::Rank2Kkernel::WarpCount::kK); -+ -+ description_.tile_description.math_instruction.instruction_shape = make_Coord( -+ Operator::InstructionShape::kM, -+ Operator::InstructionShape::kN, -+ Operator::InstructionShape::kK); -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.opcode_class = -+ OpcodeClassMap::kId; -+ -+ description_.tile_description.math_instruction.math_operation = -+ MathOperationMap::kId; -+ -+ description_.tile_description.minimum_compute_capability = -+ ArchMap::kMin; -+ -+ description_.tile_description.maximum_compute_capability = -+ ArchMap::kMax; -+ -+ description_.A = make_TensorDescription(Operator::kAlignmentA); -+ description_.B = make_TensorDescription(Operator::kAlignmentB); -+ description_.C = make_TensorDescription(Operator::kAlignmentC); -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ description_.split_k_mode = SplitKMode::kNone; -+ description_.transform_A = ComplexTransformMap::kId; -+ description_.transform_B = ComplexTransformMap::kId; -+ } -+ -+ /// Returns the description of the SYRK operation -+ virtual OperationDescription const & description() const { -+ return description_; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ 
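
A minimal caller-side sketch of the lifecycle implemented by Rank2KOperation below (and by the other operation wrappers added in this patch): get_host_workspace_size() returns sizeof(Operator) so the caller can placement-new the device-level operator into host memory, initialize() constructs it against a caller-provided device workspace, and run() refreshes problem pointers and alpha/beta before launching on a stream, exactly as handle.cu drives these wrappers above. The helper name run_rank_k_operation, the vector-backed host workspace, and the way the device workspace is provisioned are illustrative assumptions, not code from this patch.

    #include <cstdint>
    #include <vector>
    #include <cuda_runtime.h>

    #include "cutlass/library/library.h"   // Operation, RankKConfiguration, RankKArguments

    // Sketch only: drive one library::Operation through its initialize()/run() lifecycle.
    // `op` is assumed to have been selected from the manifest/operation table as in handle.cu.
    inline cutlass::Status run_rank_k_operation(
        cutlass::library::Operation const *op,
        cutlass::library::RankKConfiguration const &config,
        cutlass::library::RankKArguments const &args,
        void *device_workspace,   // caller-allocated, at least get_device_workspace_size() bytes
        cudaStream_t stream) {

      // Host workspace holds the device-level Operator object itself
      // (the wrappers in this file return sizeof(Operator) here).
      std::vector<uint8_t> host_workspace(op->get_host_workspace_size(&config));

      // initialize(): placement-new of the Operator into the host workspace,
      // then Operator::initialize() against the device workspace.
      cutlass::Status status =
          op->initialize(&config, host_workspace.data(), device_workspace, stream);
      if (status != cutlass::Status::kSuccess) {
        return status;
      }

      // run(): update_arguments_() refreshes pointers and scalars,
      // then the underlying kernel is launched on `stream`.
      return op->run(&args, host_workspace.data(), device_workspace, stream);
    }
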
-+template -+class Rank2KOperation : public Rank2KOperationBase { -+public: -+ -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ static BlasMode const kBlasMode = Operator::kBlasMode; -+ static int const kUpdateRank = Operator::kUpdateRank; -+ static FillMode const kFillModeC = Operator::kFillModeC; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ -+ /// Constructor -+ Rank2KOperation(char const *name = "unknown_rank_2k"): -+ Rank2KOperationBase(name) { -+ -+ this->description_.rank_k_kind = RankKKind::kUniversal; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ RankKConfiguration const *configuration) { -+ -+ //operator_args.mode = configuration->mode; -+ -+ operator_args.problem_size = configuration->problem_size; -+ operator_args.batch_count = configuration->batch_count; -+ -+ operator_args.lda = int(configuration->lda); -+ operator_args.ldb = int(configuration->ldb); -+ operator_args.ldc = int(configuration->ldc); -+ operator_args.ldd = int(configuration->ldd); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ RankKArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // update arguments -+ operator_args.ptr_A = arguments->A; -+ operator_args.ptr_B = arguments->B; -+ operator_args.ptr_C = arguments->C; -+ operator_args.ptr_D = arguments->D; -+ -+ operator_args.batch_stride_A = arguments->batch_stride_A; -+ operator_args.batch_stride_B = arguments->batch_stride_B; -+ operator_args.batch_stride_C = arguments->batch_stride_C; -+ operator_args.batch_stride_D = arguments->batch_stride_D; -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ RankKConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ RankKArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) 
const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ uint64_t size = Operator::get_workspace_size(args); -+ -+ return size; -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ -+ //std::cout << "initialize() library::Rank2KOperation" << std::endl; -+ //print_operator_args(args); -+ status = op->initialize(args, device_workspace, stream); -+ -+ return status; -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args, device_workspace); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ //std::cout << "run() library::Rank2KOperation" << std::endl; -+ //print_operator_args(args); -+ status = op->run(stream); -+ -+ return status; -+ } -+ -+ /// Call print_operator_args from the Conv2dOperation::initialize() -+ // to dump arguments passed on to cutlass operator for debugging -+ void print_operator_args(OperatorArguments &operator_args) const { -+ std::cout << "Rank2KOperation::OperatorArguments" << std::endl -+ << " problem_size:" << std::endl -+ << operator_args.problem_size << std::endl -+ << " epilouge (alpha, beta): " -+ << operator_args.epilogue.alpha << ", " -+ << operator_args.epilogue.beta << std::endl -+ << " ref_A (ptr, {stride}): " -+ << operator_args.ptr_A << ", {" -+ << operator_args.lda << "}" << std::endl -+ << " ref_B (ptr, {stride}): " -+ << operator_args.ptr_B << ", {" -+ << operator_args.ldb << "}" << std::endl -+ << " ref_C (ptr, {stride}): " -+ << operator_args.ptr_C << ", {" -+ << operator_args.ldc << "}" << std::endl -+ << " ref_D (ptr, {stride}): " -+ << operator_args.ptr_D << ", {" -+ << operator_args.ldd << "}" << std::endl; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/rank_k_operation.h b/3rdparty/cutlass/tools/library/src/rank_k_operation.h -new file mode 100644 -index 0000000..2eb7a2d ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/rank_k_operation.h -@@ -0,0 +1,344 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines operations for all Rank K operation kinds (Syrk, Herk) -+ in CUTLASS Library. 
-+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/gemm/kernel/default_rank_k_universal.h" -+ -+#include "cutlass/library/library.h" -+#include "library_internal.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class RankKOperationBase : public Operation { -+public: -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementA; -+ using LayoutB = typename Operator::LayoutA; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ static BlasMode const kBlasMode = Operator::kBlasMode; -+ static int const kUpdateRank = Operator::kUpdateRank; -+ static FillMode const kFillModeC = Operator::kFillModeC; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+protected: -+ -+ /// -+ RankKDescription description_; -+ -+public: -+ -+ /// Constructor -+ RankKOperationBase(char const *name = "unknown_rank_k") { -+ -+ description_.name = name; -+ description_.provider = Provider::kCUTLASS; -+ description_.rank_k_kind = RankKKind::kUniversal; -+ description_.fill_mode = kFillModeC; -+ description_.blas_mode = kBlasMode; -+ description_.num_ranks = kUpdateRank; -+ -+ description_.kind = OperationKind::kRankK; -+ -+ description_.tile_description.threadblock_shape = make_Coord( -+ Operator::ThreadblockShape::kM, -+ Operator::ThreadblockShape::kN, -+ Operator::ThreadblockShape::kK); -+ -+ description_.tile_description.threadblock_stages = Operator::kStages; -+ -+ description_.tile_description.warp_count = make_Coord( -+ Operator::RankKkernel::WarpCount::kM, -+ Operator::RankKkernel::WarpCount::kN, -+ Operator::RankKkernel::WarpCount::kK); -+ -+ description_.tile_description.math_instruction.instruction_shape = make_Coord( -+ Operator::InstructionShape::kM, -+ Operator::InstructionShape::kN, -+ Operator::InstructionShape::kK); -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.opcode_class = -+ OpcodeClassMap::kId; -+ -+ description_.tile_description.math_instruction.math_operation = -+ MathOperationMap::kId; -+ -+ description_.tile_description.minimum_compute_capability = -+ ArchMap::kMin; -+ -+ description_.tile_description.maximum_compute_capability = -+ ArchMap::kMax; -+ -+ description_.A = make_TensorDescription(Operator::kAlignmentA); -+ description_.B = make_TensorDescription(Operator::kAlignmentA); -+ description_.C = make_TensorDescription(Operator::kAlignmentC); -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ description_.split_k_mode = SplitKMode::kNone; -+ description_.transform_A = ComplexTransformMap::kId; -+ description_.transform_B = ComplexTransformMap::kId; -+ } -+ -+ /// Returns the description of the SYRK operation -+ virtual OperationDescription const & description() const { -+ return description_; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class RankKOperation : public 
RankKOperationBase { -+public: -+ -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementA; -+ using LayoutB = typename Operator::LayoutA; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ static BlasMode const kBlasMode = Operator::kBlasMode; -+ static int const kUpdateRank = Operator::kUpdateRank; -+ static FillMode const kFillModeC = Operator::kFillModeC; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ -+ /// Constructor -+ RankKOperation(char const *name = "unknown_rank_k"): -+ RankKOperationBase(name) { -+ -+ this->description_.rank_k_kind = RankKKind::kUniversal; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ RankKConfiguration const *configuration) { -+ -+ //operator_args.mode = configuration->mode; -+ -+ operator_args.problem_size = configuration->problem_size; -+ operator_args.batch_count = configuration->batch_count; -+ -+ operator_args.lda = int(configuration->lda); -+ operator_args.ldb = int(configuration->lda); -+ operator_args.ldc = int(configuration->ldc); -+ operator_args.ldd = int(configuration->ldd); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ RankKArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // update arguments -+ operator_args.ptr_A = arguments->A; -+ operator_args.ptr_C = arguments->C; -+ operator_args.ptr_D = arguments->D; -+ -+ operator_args.batch_stride_A = arguments->batch_stride_A; -+ operator_args.batch_stride_C = arguments->batch_stride_C; -+ operator_args.batch_stride_D = arguments->batch_stride_D; -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ RankKConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ RankKArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const 
*configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ uint64_t size = Operator::get_workspace_size(args); -+ -+ return size; -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ -+ status = op->initialize(args, device_workspace, stream); -+ -+ return status; -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args, device_workspace); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = op->run(stream); -+ -+ return status; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/reduction/init_reduction_operations.cu b/3rdparty/cutlass/tools/library/src/reduction/init_reduction_operations.cu -new file mode 100644 -index 0000000..b0f1695 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/reduction/init_reduction_operations.cu -@@ -0,0 +1,65 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Initialize operations for reduction operation in CUTLASS Library. -+ -+*/ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+ -+namespace cutlass { -+namespace library { -+/////////////////////////////////////////////////////////////////////////////////////////////// -+// CUTLASS Reduction Instances // -+/////////////////////////////////////////////////////////////////////////////////////////////// -+void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest); -+void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest); -+void initialize_reduce_add_linear_combination_f64_f64_f64(Manifest &manifest); -+void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest); -+ -+// -+// Entry point to construct operations -+// -+void initialize_all_reduction_op(Manifest &manifest) { -+ -+ initialize_reduce_add_linear_combination_f32_f32_f16(manifest); -+ initialize_reduce_add_linear_combination_f32_f32_f32(manifest); -+ initialize_reduce_add_linear_combination_f64_f64_f64(manifest); -+ initialize_reduce_add_linear_combination_cf32_cf32_cf32(manifest); -+ -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/library/src/reduction/reduction_device.cu b/3rdparty/cutlass/tools/library/src/reduction/reduction_device.cu -new file mode 100644 -index 0000000..2eb6ab7 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/reduction/reduction_device.cu -@@ -0,0 +1,184 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines operations for reduction operation in CUTLASS Library. -+*/ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+ -+#include "reduction_operation.h" -+ -+namespace cutlass { -+namespace library { -+ -+// naming convention initialize_reduce_[ReductionOp]_[EpilogueOp]_[ElementWorkspace]_[ElementAccumulator]_[ElementOutput] -+ -+void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest) { -+ -+ using ElementWorkspace = float; -+ using ElementAccumulator = float; -+ using ElementOutput = cutlass::half_t; -+ using ElementCompute = float; -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ typename EpilogueOutputOp::ElementAccumulator, -+ EpilogueOutputOp::kCount -+ >; -+ -+ using Operation_reduce_add_linear_combination_f32_f32_f16 = cutlass::reduction::device::ReduceSplitK< -+ cutlass::reduction::kernel::ReduceSplitK< -+ cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, -+ EpilogueOutputOp, -+ ReductionOp -+ > -+ >; -+ -+ manifest.append(new ReductionOperation< -+ Operation_reduce_add_linear_combination_f32_f32_f16>( -+ "reduce_add_linear_combination_f32_f32_f16" -+ )); -+} -+ -+ -+void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest) { -+ -+ using ElementWorkspace = float; -+ using ElementAccumulator = float; -+ using ElementOutput = float; -+ using ElementCompute = float; -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ typename EpilogueOutputOp::ElementAccumulator, -+ EpilogueOutputOp::kCount -+ >; -+ -+ using Operation_reduce_add_linear_combination_f32_f32_f32 = cutlass::reduction::device::ReduceSplitK< -+ cutlass::reduction::kernel::ReduceSplitK< -+ cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, -+ EpilogueOutputOp, -+ ReductionOp -+ > -+ >; -+ -+ manifest.append(new ReductionOperation< -+ Operation_reduce_add_linear_combination_f32_f32_f32>( -+ "reduce_add_linear_combination_f32_f32_f32" -+ )); -+} -+ -+void initialize_reduce_add_linear_combination_f64_f64_f64(Manifest &manifest) { -+ -+ using ElementWorkspace = double; -+ using ElementAccumulator = double; -+ using ElementOutput = double; -+ using ElementCompute = double; -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ 
ElementAccumulator, -+ typename EpilogueOutputOp::ElementAccumulator, -+ EpilogueOutputOp::kCount -+ >; -+ -+ using Operation_reduce_add_linear_combination_f64_f64_f64 = cutlass::reduction::device::ReduceSplitK< -+ cutlass::reduction::kernel::ReduceSplitK< -+ cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, -+ EpilogueOutputOp, -+ ReductionOp -+ > -+ >; -+ -+ manifest.append(new ReductionOperation< -+ Operation_reduce_add_linear_combination_f64_f64_f64>( -+ "reduce_add_linear_combination_f64_f64_f64" -+ )); -+} -+ -+void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest) { -+ -+ using ElementWorkspace = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementOutput = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ typename EpilogueOutputOp::ElementAccumulator, -+ EpilogueOutputOp::kCount -+ >; -+ -+ using Operation_reduce_add_linear_combination_cf32_cf32_cf32 = cutlass::reduction::device::ReduceSplitK< -+ cutlass::reduction::kernel::ReduceSplitK< -+ cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, -+ EpilogueOutputOp, -+ ReductionOp -+ > -+ >; -+ -+ manifest.append(new ReductionOperation< -+ Operation_reduce_add_linear_combination_cf32_cf32_cf32>( -+ "reduce_add_linear_combination_cf32_cf32_cf32" -+ )); -+} -+ -+} -+} -diff --git a/3rdparty/cutlass/tools/library/src/reduction/reduction_operation.h b/3rdparty/cutlass/tools/library/src/reduction/reduction_operation.h -new file mode 100644 -index 0000000..846ca02 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/reduction/reduction_operation.h -@@ -0,0 +1,290 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines operations for reduction operation in CUTLASS Library. -+*/ -+ -+#pragma once -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_clamp.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+#include "cutlass/reduction/device/reduce_split_k.h" -+ -+#include "cutlass/library/library.h" -+#include "library_internal.h" -+#include "cutlass/core_io.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class ReductionOperation : public Operation { -+public: -+ using Operator = Operator_; -+ -+ using ElementWorkspace = typename Operator::ElementWorkspace; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementOutput = typename Operator::ElementOutput; -+ -+ using ElementCompute = typename Operator::OutputOp::ElementCompute; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+protected: -+ -+ /// -+ ReductionDescription description_; -+ -+public: -+ -+ /// Constructor -+ ReductionOperation(char const *name = "unknown_reduction") { -+ -+ description_.name = name; -+ description_.provider = Provider::kCUTLASS; -+ description_.kind = OperationKind::kReduction; -+ -+ description_.tile_description.threadblock_shape = make_Coord(Operator::Shape::kRow, Operator::Shape::kColumn, 1); -+ -+ description_.tile_description.math_instruction.instruction_shape = make_Coord(1, 1, 1); -+ description_.tile_description.math_instruction.element_accumulator = NumericTypeMap::kId; -+ description_.tile_description.math_instruction.opcode_class = OpcodeClassID::kSimt; -+ description_.tile_description.math_instruction.math_operation = MathOperationID::kAdd; -+ -+ description_.tile_description.minimum_compute_capability = 50; -+ description_.tile_description.maximum_compute_capability = 1024; -+ -+ description_.element_workspace = NumericTypeMap::kId; -+ description_.element_output = NumericTypeMap::kId; -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ } -+ -+ /// Returns the description of the Reduction operation -+ virtual OperationDescription const & description() const { -+ return description_; -+ } -+ -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ ReductionConfiguration const *configuration) { -+ -+ operator_args.problem_size = configuration->problem_size; -+ operator_args.partitions = configuration->partitions; -+ operator_args.partition_stride = configuration->partition_stride; -+ -+ operator_args.workspace = {nullptr, 
int(configuration->ldw)}; -+ operator_args.source = {nullptr, int(configuration->lds)}; -+ operator_args.destination = {nullptr, int(configuration->ldd)}; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ ReductionArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::OutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ operator_args.output = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::OutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ operator_args.output = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ operator_args.workspace.reset(static_cast(const_cast(arguments->workspace))); -+ operator_args.source.reset(static_cast(const_cast(arguments->source))); -+ operator_args.destination.reset(static_cast(const_cast(arguments->destination))); -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ ReductionConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ ReductionArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ return Operator::get_workspace_size(args); -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ //std::cout << "initialize library::Reduction" << std::endl; -+ //print_operator_args(args); -+ return op->initialize(args, device_workspace, stream); -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args, device_workspace); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ //std::cout << "run library::Reduction" << 
std::endl; -+ //print_operator_args(args); -+ return op->run(stream); -+ } -+ -+ /// Call print_operator_args from the Reduction::initialize() -+ // to dump arguments passed on to cutlass operator for debugging -+ void print_operator_args(OperatorArguments &operator_args) const { -+ std::cout << "Reduction::OperatorArguments" << std::endl -+ << " problem_size: " -+ << operator_args.problem_size << std::endl -+ << " partitions: " -+ << operator_args.partitions << std::endl -+ << " partition_stride: " -+ << operator_args.partition_stride << std::endl -+ << " epilouge (alpha, beta): " -+ << operator_args.output.alpha << ", " -+ << operator_args.output.beta << std::endl -+ << " workspace (ptr, stride): " -+ << operator_args.workspace.data() << ", " -+ << operator_args.workspace.stride(0) << std::endl -+ << " source (ptr, stride): " -+ << operator_args.source.data() << ", " -+ << operator_args.source.stride(0) << std::endl -+ << " destination (ptr, stride): " -+ << operator_args.destination.data() << ", " -+ << operator_args.destination.stride(0) << std::endl; -+ } -+}; -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/reference/conv2d.cu b/3rdparty/cutlass/tools/library/src/reference/conv2d.cu -new file mode 100644 -index 0000000..715e3b0 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/reference/conv2d.cu -@@ -0,0 +1,229 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief -+ -+*/ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+ -+#include "conv_reference_operation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+void initialize_conv2d_reference_operations(Manifest &manifest) { -+ -+ make_conv_all< -+ 2, -+ cutlass::half_t, cutlass::layout::TensorNHWC, -+ cutlass::half_t, cutlass::layout::TensorNHWC, -+ cutlass::half_t, cutlass::layout::TensorNHWC, -+ cutlass::half_t, -+ cutlass::half_t -+ >(manifest); -+ -+ make_conv_all< -+ 2, -+ cutlass::half_t, cutlass::layout::TensorNHWC, -+ cutlass::half_t, cutlass::layout::TensorNHWC, -+ cutlass::half_t, cutlass::layout::TensorNHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 2, -+ cutlass::half_t, cutlass::layout::TensorNHWC, -+ cutlass::half_t, cutlass::layout::TensorNHWC, -+ float, cutlass::layout::TensorNHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 2, -+ cutlass::bfloat16_t, cutlass::layout::TensorNHWC, -+ cutlass::bfloat16_t, cutlass::layout::TensorNHWC, -+ cutlass::bfloat16_t, cutlass::layout::TensorNHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 2, -+ cutlass::bfloat16_t, cutlass::layout::TensorNHWC, -+ cutlass::bfloat16_t, cutlass::layout::TensorNHWC, -+ float, cutlass::layout::TensorNHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 2, -+ cutlass::tfloat32_t, cutlass::layout::TensorNHWC, -+ cutlass::tfloat32_t, cutlass::layout::TensorNHWC, -+ cutlass::tfloat32_t, cutlass::layout::TensorNHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 2, -+ cutlass::tfloat32_t, cutlass::layout::TensorNHWC, -+ cutlass::tfloat32_t, cutlass::layout::TensorNHWC, -+ float, cutlass::layout::TensorNHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 2, -+ float, cutlass::layout::TensorNHWC, -+ float, cutlass::layout::TensorNHWC, -+ float, cutlass::layout::TensorNHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 2, -+ cutlass::complex, cutlass::layout::TensorNHWC, -+ cutlass::complex, cutlass::layout::TensorNHWC, -+ cutlass::complex, cutlass::layout::TensorNHWC, -+ cutlass::complex, -+ cutlass::complex -+ >(manifest); -+ -+ make_conv_fprop< -+ 2, -+ int8_t, cutlass::layout::TensorNHWC, -+ int8_t, cutlass::layout::TensorNHWC, -+ int32_t, cutlass::layout::TensorNHWC, -+ int32_t, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 2, -+ int8_t, cutlass::layout::TensorNHWC, -+ int8_t, cutlass::layout::TensorNHWC, -+ int8_t, cutlass::layout::TensorNHWC, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 2, -+ uint8_t, cutlass::layout::TensorNHWC, -+ uint8_t, cutlass::layout::TensorNHWC, -+ uint8_t, cutlass::layout::TensorNHWC, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 2, -+ uint8_t, cutlass::layout::TensorNHWC, -+ uint8_t, cutlass::layout::TensorNHWC, -+ int32_t, cutlass::layout::TensorNHWC, -+ int32_t, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 2, -+ uint8_t, cutlass::layout::TensorNHWC, -+ uint8_t, cutlass::layout::TensorNHWC, -+ int8_t, cutlass::layout::TensorNHWC, -+ float, -+ 
int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 2, -+ cutlass::int4b_t, cutlass::layout::TensorNHWC, -+ cutlass::int4b_t, cutlass::layout::TensorNHWC, -+ int32_t, cutlass::layout::TensorNHWC, -+ int32_t, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 2, -+ cutlass::int4b_t, cutlass::layout::TensorNHWC, -+ cutlass::int4b_t, cutlass::layout::TensorNHWC, -+ cutlass::int4b_t, cutlass::layout::TensorNHWC, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 2, -+ cutlass::uint4b_t, cutlass::layout::TensorNHWC, -+ cutlass::uint4b_t, cutlass::layout::TensorNHWC, -+ int32_t, cutlass::layout::TensorNHWC, -+ int32_t, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 2, -+ cutlass::uint4b_t, cutlass::layout::TensorNHWC, -+ cutlass::uint4b_t, cutlass::layout::TensorNHWC, -+ cutlass::uint4b_t, cutlass::layout::TensorNHWC, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/library/src/reference/conv3d.cu b/3rdparty/cutlass/tools/library/src/reference/conv3d.cu -new file mode 100644 -index 0000000..a0f9069 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/reference/conv3d.cu -@@ -0,0 +1,209 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief -+*/ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+ -+#include "conv_reference_operation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+void initialize_conv3d_reference_operations(Manifest &manifest) { -+ -+ make_conv_all< -+ 3, -+ cutlass::half_t, cutlass::layout::TensorNDHWC, -+ cutlass::half_t, cutlass::layout::TensorNDHWC, -+ cutlass::half_t, cutlass::layout::TensorNDHWC, -+ cutlass::half_t, -+ cutlass::half_t -+ >(manifest); -+ -+ make_conv_all< -+ 3, -+ cutlass::half_t, cutlass::layout::TensorNDHWC, -+ cutlass::half_t, cutlass::layout::TensorNDHWC, -+ cutlass::half_t, cutlass::layout::TensorNDHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 3, -+ cutlass::half_t, cutlass::layout::TensorNDHWC, -+ cutlass::half_t, cutlass::layout::TensorNDHWC, -+ float, cutlass::layout::TensorNDHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 3, -+ cutlass::bfloat16_t, cutlass::layout::TensorNDHWC, -+ cutlass::bfloat16_t, cutlass::layout::TensorNDHWC, -+ cutlass::bfloat16_t, cutlass::layout::TensorNDHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 3, -+ cutlass::bfloat16_t, cutlass::layout::TensorNDHWC, -+ cutlass::bfloat16_t, cutlass::layout::TensorNDHWC, -+ float, cutlass::layout::TensorNDHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 3, -+ cutlass::tfloat32_t, cutlass::layout::TensorNDHWC, -+ cutlass::tfloat32_t, cutlass::layout::TensorNDHWC, -+ cutlass::tfloat32_t, cutlass::layout::TensorNDHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 3, -+ cutlass::tfloat32_t, cutlass::layout::TensorNDHWC, -+ cutlass::tfloat32_t, cutlass::layout::TensorNDHWC, -+ float, cutlass::layout::TensorNDHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 3, -+ float, cutlass::layout::TensorNDHWC, -+ float, cutlass::layout::TensorNDHWC, -+ float, cutlass::layout::TensorNDHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_fprop< -+ 3, -+ int8_t, cutlass::layout::TensorNDHWC, -+ int8_t, cutlass::layout::TensorNDHWC, -+ int32_t, cutlass::layout::TensorNDHWC, -+ int32_t, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 3, -+ int8_t, cutlass::layout::TensorNDHWC, -+ int8_t, cutlass::layout::TensorNDHWC, -+ int8_t, cutlass::layout::TensorNDHWC, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 3, -+ uint8_t, cutlass::layout::TensorNDHWC, -+ uint8_t, cutlass::layout::TensorNDHWC, -+ int32_t, cutlass::layout::TensorNDHWC, -+ int32_t, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 3, -+ uint8_t, cutlass::layout::TensorNDHWC, -+ uint8_t, cutlass::layout::TensorNDHWC, -+ int8_t, cutlass::layout::TensorNDHWC, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 3, -+ cutlass::int4b_t, cutlass::layout::TensorNDHWC, -+ cutlass::int4b_t, cutlass::layout::TensorNDHWC, -+ int32_t, cutlass::layout::TensorNDHWC, -+ int32_t, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 3, -+ cutlass::int4b_t, cutlass::layout::TensorNDHWC, -+ cutlass::int4b_t, cutlass::layout::TensorNDHWC, -+ 
cutlass::int4b_t, cutlass::layout::TensorNDHWC, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 3, -+ cutlass::uint4b_t, cutlass::layout::TensorNDHWC, -+ cutlass::uint4b_t, cutlass::layout::TensorNDHWC, -+ int32_t, cutlass::layout::TensorNDHWC, -+ int32_t, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 3, -+ cutlass::uint4b_t, cutlass::layout::TensorNDHWC, -+ cutlass::uint4b_t, cutlass::layout::TensorNDHWC, -+ cutlass::uint4b_t, cutlass::layout::TensorNDHWC, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/library/src/reference/conv_reference_operation.h b/3rdparty/cutlass/tools/library/src/reference/conv_reference_operation.h -new file mode 100644 -index 0000000..3a294a2 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/reference/conv_reference_operation.h -@@ -0,0 +1,632 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines operations for all CONV operation kinds in CUTLASS Library -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+#include "cutlass/library/util.h" -+#include "library_internal.h" -+ -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/device/convolution.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template < -+ Provider kProvider, -+ conv::Operator ConvolutionalOperator, -+ int ConvDim, -+ typename ElementA_, -+ typename LayoutA_, -+ typename ElementB_, -+ typename LayoutB_, -+ typename ElementC_, -+ typename LayoutC_, -+ typename ElementCompute_, -+ typename ElementAccumulator_ = ElementCompute_, -+ typename ConvertOp_ = NumericConverter, -+ typename InnerProductOp_ = multiply_add -+> -+struct ConvReferenceDispatcher; -+ -+/// Dispatcher for Conv2d (partially specialied for kConvDim == 2) -+template < -+ Provider kProvider, -+ conv::Operator kConvolutionalOperator, -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator, -+ typename ConvertOp, -+ typename InnerProductOp -+> -+struct ConvReferenceDispatcher< -+ kProvider, -+ kConvolutionalOperator, -+ 2, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp> { -+ -+ static Status dispatch( -+ void const *configuration, -+ ElementA *ptr_A, -+ ElementB *ptr_B, -+ ElementC *ptr_C, -+ ElementC *ptr_D, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cudaStream_t stream = nullptr -+ ) { -+ -+ Conv2dConfiguration const &config = -+ *static_cast(configuration); -+ -+ // TODO: make below code more general. It is fixed for NHWC now. 
-+ layout::TensorNHWC layout_a; -+ layout::TensorNHWC layout_b; -+ layout::TensorNHWC layout_c; -+ -+ layout_a.stride() = -+ make_Coord(int32_t(config.stride_a[0]), -+ int32_t(config.stride_a[1]), -+ int32_t(config.stride_a[2])); -+ -+ layout_b.stride() = -+ make_Coord(int32_t(config.stride_b[0]), -+ int32_t(config.stride_b[1]), -+ int32_t(config.stride_b[2])); -+ -+ layout_c.stride() = -+ make_Coord(int32_t(config.stride_c[0]), -+ int32_t(config.stride_c[1]), -+ int32_t(config.stride_c[2])); -+ -+ if (kProvider == Provider::kReferenceHost) { -+ -+ cutlass::reference::host::Conv2d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC , -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp -+ >( -+ kConvolutionalOperator, -+ config.problem_size, -+ {ptr_A, layout_a}, -+ {ptr_B, layout_b}, -+ {ptr_C, layout_c}, -+ {ptr_D, layout_c}, -+ alpha, -+ beta -+ ); -+ -+ return Status::kSuccess; -+ } -+ else if (kProvider == Provider::kReferenceDevice) { -+ return cutlass::reference::device::Conv2d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp -+ >( -+ kConvolutionalOperator, -+ config.problem_size, -+ {ptr_A, layout_a}, -+ {ptr_B, layout_b}, -+ {ptr_C, layout_c}, -+ {ptr_D, layout_c}, -+ alpha, -+ beta, -+ stream -+ ); -+ } -+ return Status::kErrorNotSupported; -+ } -+}; -+ -+/// Dispatcher for Conv3d (partially specialized for kConvDim == 3) -+template < -+ Provider kProvider, -+ conv::Operator kConvolutionalOperator, -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator, -+ typename ConvertOp, -+ typename InnerProductOp -+> -+struct ConvReferenceDispatcher< -+ kProvider, -+ kConvolutionalOperator, -+ 3, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp> { -+ -+ static Status dispatch( -+ void const *configuration, -+ ElementA *ptr_A, -+ ElementB *ptr_B, -+ ElementC *ptr_C, -+ ElementC *ptr_D, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cudaStream_t stream = nullptr -+ ) { -+ -+ Conv3dConfiguration const &config = -+ *static_cast(configuration); -+ -+ ConvKind const conv_kind = ConvKindMap::kId; -+ -+ if (kProvider == Provider::kReferenceHost) { -+ cutlass::reference::host::Conv3d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC , -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp -+ >( -+ kConvolutionalOperator, -+ config.problem_size, -+ {ptr_A, config.layout_a(conv_kind)}, -+ {ptr_B, config.layout_b(conv_kind)}, -+ {ptr_C, config.layout_c(conv_kind)}, -+ {ptr_D, config.layout_c(conv_kind)}, -+ alpha, -+ beta -+ ); -+ -+ return Status::kSuccess; -+ } -+ else if (kProvider == Provider::kReferenceDevice) { -+ return cutlass::reference::device::Conv3d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp -+ >( -+ kConvolutionalOperator, -+ config.problem_size, -+ {ptr_A, config.layout_a(conv_kind)}, -+ {ptr_B, config.layout_b(conv_kind)}, -+ {ptr_C, config.layout_c(conv_kind)}, -+ {ptr_D, config.layout_c(conv_kind)}, -+ alpha, -+ beta, -+ stream -+ ); -+ } -+ return Status::kErrorNotSupported; -+ } -+}; -+ -+} // namespace detail -+ 
-+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ Provider Provider_, -+ conv::Operator ConvolutionalOperator, -+ int ConvDim, -+ typename ElementA_, -+ typename LayoutA_, -+ typename ElementB_, -+ typename LayoutB_, -+ typename ElementC_, -+ typename LayoutC_, -+ typename ElementCompute_, -+ typename ElementAccumulator_ = ElementCompute_, -+ typename ConvertOp_ = NumericConverter, -+ typename InnerProductOp_ = multiply_add -+> -+class ConvReferenceOperation : public Operation { -+public: -+ static Provider const kProvider = Provider_; -+ static conv::Operator const kConvolutionalOperator = ConvolutionalOperator; -+ static int const kConvDim = ConvDim; -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using ElementCompute = ElementCompute_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ConvertOp = ConvertOp_; -+ using InnerProductOp = InnerProductOp_; -+ -+protected: -+ -+ /// Storage for the name string -+ std::string name_; -+ -+ /// -+ ConvDescription description_; -+ -+public: -+ -+ /// Constructor -+ ConvReferenceOperation() { -+ -+ // Basic information -+ description_.provider = kProvider; -+ description_.kind = (kConvDim == 2 ? OperationKind::kConv2d : OperationKind::kConv3d); -+ description_.conv_kind = ConvKindMap::kId; -+ description_.conv_dim = kConvDim; -+ -+ // Tensor description -+ description_.A = make_TensorDescription(); -+ description_.B = make_TensorDescription(); -+ description_.C = make_TensorDescription(); -+ -+ // Epilogue compute and accumulator type description -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ -+ // Iterator algorithm for convolution reference -+ description_.iterator_algorithm = IteratorAlgorithmID::kNone; -+ -+ // Compute capability for convolution reference -+ description_.tile_description.minimum_compute_capability = -+ (kProvider == Provider::kReferenceDevice ? 
50 : 0); -+ -+ description_.tile_description.maximum_compute_capability = 1024; -+ -+ // Procedural name -+ std::stringstream ss; -+ -+ ss << "conv" << kConvDim << "d_" << to_string(description_.conv_kind) -+ << "_reference_" << to_string(description_.provider) -+ << "_" << to_string(description_.A.element) << to_string(description_.A.layout) -+ << "_" << to_string(description_.B.element) << to_string(description_.B.layout) -+ << "_" << to_string(description_.C.element) << to_string(description_.C.layout) -+ << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); -+ -+ name_ = ss.str(); -+ -+ description_.name = name_.c_str(); -+ -+ // Epilogue compute and accumulator type description -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ } -+ -+ /// Returns the description of the GEMM operation -+ virtual OperationDescription const & description() const { -+ return description_; -+ } -+ -+ virtual Status can_implement( -+ void const *configuration, -+ void const *arguments) const { -+ -+ return Status::kSuccess; -+ } -+ -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ switch (kConvDim) { -+ case 2: -+ return sizeof(Conv2dConfiguration); -+ case 3: -+ return sizeof(Conv3dConfiguration); -+ default: -+ break; -+ } -+ -+ return 0; -+ } -+ -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration, -+ void const *arguments = nullptr) const { -+ -+ return 0; -+ } -+ -+ virtual Status initialize( -+ void const *configuration, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ std::memcpy(host_workspace, configuration, get_host_workspace_size(configuration)); -+ -+ return Status::kSuccess; -+ } -+ -+ virtual Status run( -+ void const *arguments, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ ConvArguments const &args = *static_cast(arguments); -+ -+ ElementCompute alpha; -+ ElementCompute beta; -+ -+ alpha = *static_cast(args.alpha); -+ beta = *static_cast(args.beta); -+ -+ // TODO - respect pointer mode -+ -+ // Invoke 2D or 3D convolution -+ return detail::ConvReferenceDispatcher< -+ kProvider, -+ kConvolutionalOperator, -+ kConvDim, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp -+ >::dispatch( -+ host_workspace, -+ static_cast(const_cast(args.A)), -+ static_cast(const_cast(args.B)), -+ static_cast(const_cast(args.C)), -+ static_cast(args.D), -+ alpha, -+ beta, -+ stream -+ ); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Constructs Fprop reference operators. 
-+template < -+ int kConvDim, -+ typename ElementA_, -+ typename LayoutA_, -+ typename ElementB_, -+ typename LayoutB_, -+ typename ElementC_, -+ typename LayoutC_, -+ typename ElementCompute_, -+ typename ElementAccumulator_ = ElementCompute_, -+ typename ConvertOp_ = NumericConverter, -+ typename InnerProductOp_ = multiply_add -+> -+void make_conv_fprop(Manifest &manifest) { -+ -+ manifest.append(new ConvReferenceOperation< -+ Provider::kReferenceHost, -+ conv::Operator::kFprop, -+ kConvDim, -+ ElementA_, LayoutA_, -+ ElementB_, LayoutB_, -+ ElementC_, LayoutC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >); -+ -+ manifest.append(new ConvReferenceOperation< -+ Provider::kReferenceDevice, -+ conv::Operator::kFprop, -+ kConvDim, -+ ElementA_, LayoutA_, -+ ElementB_, LayoutB_, -+ ElementC_, LayoutC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >); -+} -+ -+/// Constructs Dgrad and Wgrad reference operators. -+template < -+ int kConvDim, -+ typename ElementA_, -+ typename LayoutA_, -+ typename ElementB_, -+ typename LayoutB_, -+ typename ElementC_, -+ typename LayoutC_, -+ typename ElementCompute_, -+ typename ElementAccumulator_ = ElementCompute_, -+ typename ConvertOp_ = NumericConverter, -+ typename InnerProductOp_ = multiply_add -+> -+void make_conv_backwards(Manifest &manifest) { -+ -+ manifest.append(new ConvReferenceOperation< -+ Provider::kReferenceHost, -+ conv::Operator::kDgrad, -+ kConvDim, -+ ElementA_, LayoutA_, -+ ElementB_, LayoutB_, -+ ElementC_, LayoutC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >); -+ -+ manifest.append(new ConvReferenceOperation< -+ Provider::kReferenceDevice, -+ conv::Operator::kDgrad, -+ kConvDim, -+ ElementA_, LayoutA_, -+ ElementB_, LayoutB_, -+ ElementC_, LayoutC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >); -+ -+ manifest.append(new ConvReferenceOperation< -+ Provider::kReferenceHost, -+ conv::Operator::kWgrad, -+ kConvDim, -+ ElementA_, LayoutA_, -+ ElementB_, LayoutB_, -+ ElementC_, LayoutC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >); -+ -+ manifest.append(new ConvReferenceOperation< -+ Provider::kReferenceDevice, -+ conv::Operator::kWgrad, -+ kConvDim, -+ ElementA_, LayoutA_, -+ ElementB_, LayoutB_, -+ ElementC_, LayoutC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >); -+} -+ -+/// Six operators for the price of one. 
-+template < -+ int kConvDim, -+ typename ElementA_, -+ typename LayoutA_, -+ typename ElementB_, -+ typename LayoutB_, -+ typename ElementC_, -+ typename LayoutC_, -+ typename ElementCompute_, -+ typename ElementAccumulator_ = ElementCompute_, -+ typename ConvertOp_ = NumericConverter, -+ typename InnerProductOp_ = multiply_add -+> -+void make_conv_all(Manifest &manifest) { -+ -+ make_conv_fprop< -+ kConvDim, -+ ElementA_, LayoutA_, -+ ElementB_, LayoutB_, -+ ElementC_, LayoutC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+ -+ make_conv_backwards< -+ kConvDim, -+ ElementA_, LayoutA_, -+ ElementB_, LayoutB_, -+ ElementC_, LayoutC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/library/src/reference/gemm.cu b/3rdparty/cutlass/tools/library/src/reference/gemm.cu -new file mode 100644 -index 0000000..890772e ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/reference/gemm.cu -@@ -0,0 +1,341 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Instantiates GEMM reference implementations. 
-+*/ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+ -+#include "gemm_reference_operation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+void initialize_gemm_reference_operations(Manifest &manifest) { -+ -+ make_gemm_real_canonical_layouts< -+ float, // ElementA -+ float, // ElementB -+ float, // ElementC -+ float, // ElementScalar -+ float // ElementAccumulator -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ tfloat32_t, -+ tfloat32_t, -+ float, -+ float, -+ float -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ tfloat32_t, -+ tfloat32_t, -+ tfloat32_t, -+ float, -+ float -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ half_t, -+ half_t, -+ half_t, -+ float, -+ float -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ half_t, -+ half_t, -+ half_t, -+ half_t, -+ half_t -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ half_t, -+ half_t, -+ float, -+ float, -+ float -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ bfloat16_t, -+ bfloat16_t, -+ bfloat16_t, -+ float, -+ float -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ bfloat16_t, -+ bfloat16_t, -+ float, -+ float, -+ float -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ double, -+ double, -+ double, -+ double, -+ double -+ >(manifest); -+ -+ // -+ // Integer-valued GEMMs -+ // -+ -+ make_gemm_real_canonical_layouts< -+ int8_t, -+ int8_t, -+ int32_t, -+ int32_t, -+ int32_t -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ int8_t, -+ int8_t, -+ int8_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ int8_t, -+ int8_t, -+ int32_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ uint8_t, -+ uint8_t, -+ int32_t, -+ int32_t, -+ int32_t -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ uint8_t, -+ uint8_t, -+ int8_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ uint8_t, -+ uint8_t, -+ int32_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 32, -+ int8_t, -+ int8_t, -+ int32_t, -+ int32_t, -+ int32_t -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 32, -+ int8_t, -+ int8_t, -+ int32_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 32, -+ int8_t, -+ int8_t, -+ int8_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 32, -+ uint8_t, -+ uint8_t, -+ int32_t, -+ int32_t, -+ int32_t -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 32, -+ uint8_t, -+ uint8_t, -+ int32_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 32, -+ uint8_t, -+ uint8_t, -+ uint8_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 32, -+ uint8_t, -+ uint8_t, -+ int8_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 64, -+ int4b_t, -+ int4b_t, -+ int32_t, -+ int32_t, -+ int32_t -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 64, -+ int4b_t, -+ int4b_t, -+ int32_t, -+ float, -+ 
int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 64, -+ int4b_t, -+ int4b_t, -+ int4b_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 64, -+ uint4b_t, -+ uint4b_t, -+ int32_t, -+ int32_t, -+ int32_t -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 64, -+ uint4b_t, -+ uint4b_t, -+ int32_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 64, -+ uint4b_t, -+ uint4b_t, -+ uint4b_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 64, -+ uint4b_t, -+ uint4b_t, -+ int4b_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ // -+ // Complex-valued GEMMs -+ // -+ -+ make_gemm_complex_canonical_layouts< -+ complex, -+ complex, -+ complex, -+ complex, -+ complex -+ >(manifest); -+ -+ make_gemm_complex_canonical_layouts< -+ complex, -+ complex, -+ complex, -+ complex, -+ complex -+ >(manifest); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/library/src/reference/gemm_reference_operation.h b/3rdparty/cutlass/tools/library/src/reference/gemm_reference_operation.h -new file mode 100644 -index 0000000..5d4d150 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/reference/gemm_reference_operation.h -@@ -0,0 +1,473 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines reference operations for GEMM operation kinds in CUTLASS Library -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+#include "cutlass/library/util.h" -+#include "library_internal.h" -+ -+#include "cutlass/util/reference/host/gemm_complex.h" -+#include "cutlass/util/reference/device/gemm_complex.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ Provider Provider_, -+ typename ElementA_, -+ typename LayoutA_, -+ cutlass::ComplexTransform TransformA, -+ typename ElementB_, -+ typename LayoutB_, -+ cutlass::ComplexTransform TransformB, -+ typename ElementC_, -+ typename LayoutC_, -+ typename ElementCompute_, -+ typename ElementAccumulator_ = ElementCompute_, -+ typename ConvertOp_ = NumericConverter, -+ typename InnerProductOp_ = multiply_add -+> -+class GemmReferenceOperation : public Operation { -+public: -+ static Provider const kProvider = Provider_; -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ static cutlass::ComplexTransform const kTransformA = TransformA; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ static cutlass::ComplexTransform const kTransformB = TransformB; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using TensorRefC = TensorRef; -+ using ElementCompute = ElementCompute_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ConvertOp = ConvertOp_; -+ using InnerProductOp = InnerProductOp_; -+ -+protected: -+ -+ /// Storage for the name string -+ std::string name_; -+ -+ /// -+ GemmDescription description_; -+ -+public: -+ -+ /// Constructor -+ GemmReferenceOperation() { -+ -+ // Basic information -+ description_.provider = kProvider; -+ description_.kind = OperationKind::kGemm; -+ description_.gemm_kind = GemmKind::kUniversal; -+ -+ // Tensor description -+ description_.A = make_TensorDescription(); -+ description_.transform_A = ComplexTransformMap::kId; -+ description_.B = make_TensorDescription(); -+ description_.transform_B = ComplexTransformMap::kId; -+ description_.C = make_TensorDescription(); -+ -+ // Epilogue compute and accumulator type description -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ -+ // Compute capability for gemm reference -+ description_.tile_description.minimum_compute_capability = -+ (kProvider == Provider::kReferenceDevice ? 
50 : 0); -+ -+ description_.tile_description.maximum_compute_capability = 1024; -+ -+ // Procedural name -+ std::stringstream ss; -+ -+ ss << "gemm" -+ << "_reference_" << to_string(description_.provider) -+ << "_" << to_string(description_.A.element) << to_string(description_.A.layout) -+ << "_" << to_string(description_.B.element) << to_string(description_.B.layout) -+ << "_" << to_string(description_.C.element) << to_string(description_.C.layout) -+ << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); -+ -+ name_ = ss.str(); -+ -+ description_.name = name_.c_str(); -+ -+ // Epilogue compute and accumulator type description -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ } -+ -+ /// Returns the description of the GEMM operation -+ virtual OperationDescription const & description() const { -+ return description_; -+ } -+ -+ virtual Status can_implement( -+ void const *configuration, -+ void const *arguments) const { -+ -+ return Status::kSuccess; -+ } -+ -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(GemmUniversalConfiguration); -+ } -+ -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration, -+ void const *arguments = nullptr) const { -+ -+ return 0; -+ } -+ -+ virtual Status initialize( -+ void const *configuration, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ std::memcpy(host_workspace, configuration, get_host_workspace_size(configuration)); -+ -+ return Status::kSuccess; -+ } -+ -+ virtual Status run( -+ void const *arguments, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ GemmUniversalConfiguration const &config = *static_cast(host_workspace); -+ GemmUniversalArguments const &args = *static_cast(arguments); -+ -+ TensorRefA ref_A{static_cast(const_cast(args.A)), LayoutA(int(config.lda))}; -+ TensorRefB ref_B{static_cast(const_cast(args.B)), LayoutB(int(config.ldb))}; -+ TensorRefC ref_C{static_cast(const_cast(args.C)), LayoutC(int(config.ldc))}; -+ TensorRefC ref_D{static_cast(args.D), LayoutC(int(config.ldd))}; -+ -+ if (kProvider == Provider::kReferenceHost) { -+ -+ cutlass::reference::host::GemmComplex< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp -+ >( -+ config.problem_size, -+ *static_cast(args.alpha), -+ ref_A, -+ kTransformA, -+ ref_B, -+ kTransformB, -+ *static_cast(args.beta), -+ ref_C, -+ ref_D, -+ ElementAccumulator(), -+ ((config.mode == library::GemmUniversalMode::kBatched) ? config.batch_count : 1), -+ args.batch_stride_A, -+ args.batch_stride_B, -+ args.batch_stride_C, -+ args.batch_stride_D -+ ); -+ -+ return Status::kSuccess; -+ } -+ else if (kProvider == Provider::kReferenceDevice) { -+ -+ cutlass::reference::device::GemmComplex< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp -+ >( -+ config.problem_size, -+ *static_cast(args.alpha), -+ ref_A, -+ kTransformA, -+ ref_B, -+ kTransformB, -+ *static_cast(args.beta), -+ ref_C, -+ ref_D, -+ ElementAccumulator(), -+ ((config.mode == library::GemmUniversalMode::kBatched) ? 
config.batch_count : 1), -+ args.batch_stride_A, -+ args.batch_stride_B, -+ args.batch_stride_C, -+ args.batch_stride_D -+ ); -+ -+ return Status::kSuccess; -+ } -+ -+ return Status::kErrorNotSupported; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA_, -+ typename LayoutA_, -+ cutlass::ComplexTransform TransformA, -+ typename ElementB_, -+ typename LayoutB_, -+ cutlass::ComplexTransform TransformB, -+ typename ElementC_, -+ typename LayoutC_, -+ typename ElementCompute_, -+ typename ElementAccumulator_ = ElementCompute_, -+ typename ConvertOp_ = NumericConverter, -+ typename InnerProductOp_ = multiply_add -+> -+void make_gemm(Manifest &manifest) { -+ -+ manifest.append(new GemmReferenceOperation< -+ Provider::kReferenceHost, -+ ElementA_, LayoutA_, TransformA, -+ ElementB_, LayoutB_, TransformB, -+ ElementC_, LayoutC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >); -+ -+ manifest.append(new GemmReferenceOperation< -+ Provider::kReferenceDevice, -+ ElementA_, LayoutA_, TransformA, -+ ElementB_, LayoutB_, TransformB, -+ ElementC_, LayoutC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >); -+} -+ -+/// Helper to create NN, NT, TN, and TT GEMM layouts. -+template < -+ typename ElementA_, cutlass::ComplexTransform TransformA, -+ typename ElementB_, cutlass::ComplexTransform TransformB, -+ typename ElementC_, -+ typename ElementCompute_, -+ typename ElementAccumulator_ = ElementCompute_, -+ typename ConvertOp_ = NumericConverter, -+ typename InnerProductOp_ = multiply_add -+> -+void make_gemm_canonical_layouts(Manifest &manifest) { -+ -+ make_gemm< -+ ElementA_, cutlass::layout::ColumnMajor, TransformA, -+ ElementB_, cutlass::layout::ColumnMajor, TransformB, -+ ElementC_, cutlass::layout::ColumnMajor, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+ -+ make_gemm< -+ ElementA_, cutlass::layout::ColumnMajor, TransformA, -+ ElementB_, cutlass::layout::RowMajor, TransformB, -+ ElementC_, cutlass::layout::ColumnMajor, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+ -+ make_gemm< -+ ElementA_, cutlass::layout::RowMajor, TransformA, -+ ElementB_, cutlass::layout::ColumnMajor, TransformB, -+ ElementC_, cutlass::layout::ColumnMajor, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+ -+ make_gemm< -+ ElementA_, cutlass::layout::RowMajor, TransformA, -+ ElementB_, cutlass::layout::RowMajor, TransformB, -+ ElementC_, cutlass::layout::ColumnMajor, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+} -+ -+ -+/// Helper to create TN and interleaved layouts GEMM layouts. 
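
For orientation, the helpers in this header are what initialize_gemm_reference_operations() (in gemm.cu above) uses to populate the manifest. The following is an illustrative sketch only, not part of the patch: the function name register_example_reference_gemms and the chosen type combinations are assumptions, and it presumes this header has been included.

// Illustrative sketch: populate a Manifest with a few reference GEMMs
// using the helpers defined in gemm_reference_operation.h.
void register_example_reference_gemms(cutlass::library::Manifest &manifest) {
  using namespace cutlass;
  using namespace cutlass::library;

  // FP32 GEMM registered in all four canonical layout combinations (NN, NT, TN, TT),
  // each appended once for the host reference and once for the device reference.
  make_gemm_real_canonical_layouts<
    float, float, float, float, float
  >(manifest);

  // Interleaved int8 GEMM (TN, K interleaved by 32) with int32 accumulation and
  // the epilogue result clamped back to int8.
  make_gemm_interleaved_layouts<
    32, int8_t, int8_t, int8_t, float, int32_t,
    NumericConverterClamp<int8_t, float>
  >(manifest);
}
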
-+template < -+ int InterleaveK, -+ typename ElementA_, -+ typename ElementB_, -+ typename ElementC_, -+ typename ElementCompute_, -+ typename ElementAccumulator_ = ElementCompute_, -+ typename ConvertOp_ = NumericConverter, -+ typename InnerProductOp_ = multiply_add -+> -+void make_gemm_interleaved_layouts(Manifest &manifest) { -+ -+ make_gemm< -+ ElementA_, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, -+ ElementB_, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, -+ ElementC_, cutlass::layout::ColumnMajor, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+ -+} -+ -+/// Helper to real-valued GEMM with canonical layouts -+template < -+ typename ElementA_, -+ typename ElementB_, -+ typename ElementC_, -+ typename ElementCompute_, -+ typename ElementAccumulator_ = ElementCompute_, -+ typename ConvertOp_ = NumericConverter, -+ typename InnerProductOp_ = multiply_add -+> -+void make_gemm_real_canonical_layouts(Manifest &manifest) { -+ make_gemm_canonical_layouts< -+ ElementA_, cutlass::ComplexTransform::kNone, -+ ElementB_, cutlass::ComplexTransform::kNone, -+ ElementC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+} -+ -+// Helper to create all complex transformation permutations -+template < -+ typename ElementA_, -+ typename ElementB_, -+ typename ElementC_, -+ typename ElementCompute_, -+ typename ElementAccumulator_ = ElementCompute_, -+ typename ConvertOp_ = NumericConverter, -+ typename InnerProductOp_ = multiply_add -+> -+void make_gemm_complex_canonical_layouts(Manifest &manifest) { -+ -+ make_gemm_canonical_layouts< -+ ElementA_, cutlass::ComplexTransform::kNone, -+ ElementB_, cutlass::ComplexTransform::kNone, -+ ElementC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+ -+ make_gemm_canonical_layouts< -+ ElementA_, cutlass::ComplexTransform::kConjugate, -+ ElementB_, cutlass::ComplexTransform::kConjugate, -+ ElementC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+ -+ make_gemm_canonical_layouts< -+ ElementA_, cutlass::ComplexTransform::kNone, -+ ElementB_, cutlass::ComplexTransform::kConjugate, -+ ElementC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+ -+ make_gemm_canonical_layouts< -+ ElementA_, cutlass::ComplexTransform::kConjugate, -+ ElementB_, cutlass::ComplexTransform::kNone, -+ ElementC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/library/src/reference/initialize_reference_operations.cu b/3rdparty/cutlass/tools/library/src/reference/initialize_reference_operations.cu -new file mode 100644 -index 0000000..b63367e ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/reference/initialize_reference_operations.cu -@@ -0,0 +1,63 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief -+ -+*/ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+void initialize_gemm_reference_operations(Manifest &manifest); -+void initialize_conv2d_reference_operations(Manifest &manifest); -+void initialize_conv3d_reference_operations(Manifest &manifest); -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+void initialize_reference_operations(Manifest &manifest) { -+ initialize_conv2d_reference_operations(manifest); -+ initialize_conv3d_reference_operations(manifest); -+ initialize_gemm_reference_operations(manifest); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/library/src/singleton.cu b/3rdparty/cutlass/tools/library/src/singleton.cu -new file mode 100644 -index 0000000..2315448 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/singleton.cu -@@ -0,0 +1,62 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+#include "cutlass/library/operation_table.h" -+#include "cutlass/library/singleton.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+Singleton::Singleton() { -+ -+ manifest.initialize(); -+ -+ operation_table.append(manifest); -+} -+ -+Singleton const & Singleton::get() { -+ static Singleton instance; -+ return instance; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/symm_operation.h b/3rdparty/cutlass/tools/library/src/symm_operation.h -new file mode 100644 -index 0000000..d7554ed ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/symm_operation.h -@@ -0,0 +1,379 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines operations for all Symm operation kinds (Symm, Hemm) -+ in CUTLASS Library. -+ -+ -+*/ -+ -+#pragma once -+#include -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/gemm/kernel/default_symm_universal.h" -+ -+#include "cutlass/library/library.h" -+#include "library_internal.h" -+#include "cutlass/core_io.h" -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class SymmOperationBase : public Operation { -+public: -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ static BlasMode const kBlasMode = Operator::kBlasMode; -+ static SideMode const kSideModeA = Operator::kSideModeA; -+ static FillMode const kFillModeA = Operator::kFillModeA; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+protected: -+ -+ /// -+ SymmDescription description_; -+ -+public: -+ -+ /// Constructor -+ SymmOperationBase(char const *name = "unknown_symm") { -+ -+ description_.name = name; -+ description_.provider = Provider::kCUTLASS; -+ description_.symm_kind = SymmKind::kUniversal; -+ description_.side_mode = kSideModeA; -+ description_.fill_mode = kFillModeA; -+ description_.blas_mode = kBlasMode; -+ -+ description_.kind = OperationKind::kSymm; -+ -+ description_.tile_description.threadblock_shape = make_Coord( -+ Operator::ThreadblockShape::kM, -+ Operator::ThreadblockShape::kN, -+ Operator::ThreadblockShape::kK); -+ -+ description_.tile_description.threadblock_stages = Operator::kStages; -+ -+ description_.tile_description.warp_count = make_Coord( -+ Operator::SymmKernel::WarpCount::kM, -+ Operator::SymmKernel::WarpCount::kN, -+ Operator::SymmKernel::WarpCount::kK); -+ -+ description_.tile_description.math_instruction.instruction_shape = make_Coord( -+ Operator::InstructionShape::kM, -+ 
Operator::InstructionShape::kN, -+ Operator::InstructionShape::kK); -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.opcode_class = -+ OpcodeClassMap::kId; -+ -+ description_.tile_description.math_instruction.math_operation = -+ MathOperationMap::kId; -+ -+ description_.tile_description.minimum_compute_capability = -+ ArchMap::kMin; -+ -+ description_.tile_description.maximum_compute_capability = -+ ArchMap::kMax; -+ -+ description_.A = make_TensorDescription(Operator::kAlignmentA); -+ description_.B = make_TensorDescription(Operator::kAlignmentB); -+ description_.C = make_TensorDescription(Operator::kAlignmentC); -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ description_.split_k_mode = SplitKMode::kNone; -+ } -+ -+ /// Returns the description of the SYMM operation -+ virtual OperationDescription const & description() const { -+ return description_; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class SymmOperation : public SymmOperationBase { -+public: -+ -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ static BlasMode const kBlasMode = Operator::kBlasMode; -+ static SideMode const kSideModeA = Operator::kSideModeA; -+ static FillMode const kFillModeA = Operator::kFillModeA; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ -+ /// Constructor -+ SymmOperation(char const *name = "unknown_symm"): -+ SymmOperationBase(name) { -+ -+ this->description_.symm_kind = SymmKind::kUniversal; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ SymmConfiguration const *configuration) { -+ -+ //operator_args.mode = configuration->mode; -+ -+ operator_args.problem_size = configuration->problem_size; -+ operator_args.batch_count = configuration->batch_count; -+ -+ operator_args.lda = int(configuration->lda); -+ operator_args.ldb = int(configuration->ldb); -+ operator_args.ldc = int(configuration->ldc); -+ operator_args.ldd = int(configuration->ldd); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ SymmArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // update arguments -+ operator_args.ptr_A = arguments->A; -+ operator_args.ptr_B = arguments->B; -+ operator_args.ptr_C = arguments->C; 
-+ operator_args.ptr_D = arguments->D; -+ -+ operator_args.batch_stride_A = arguments->batch_stride_A; -+ operator_args.batch_stride_B = arguments->batch_stride_B; -+ operator_args.batch_stride_C = arguments->batch_stride_C; -+ operator_args.batch_stride_D = arguments->batch_stride_D; -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ SymmConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ SymmArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ uint64_t size = Operator::get_workspace_size(args); -+ -+ return size; -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ -+ //std::cout << "initialize() library::SymmOperation" << std::endl; -+ //print_operator_args(args); -+ status = op->initialize(args, device_workspace, stream); -+ -+ return status; -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ bool need_swapped_matrices = (kSideModeA == SideMode::kLeft && -+ std::is_same::value) || -+ (kSideModeA == SideMode::kRight && -+ std::is_same::value); -+ if (need_swapped_matrices) { -+ status = op->update(args.swapped_matrices(), device_workspace); -+ } else { -+ status = op->update(args, device_workspace); -+ } -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ //std::cout << "run() library::SymmOperation" << std::endl; -+ //print_operator_args(args); -+ status = op->run(stream); -+ -+ return status; -+ } -+ -+ /// Call print_operator_args from the Conv2dOperation::initialize() -+ // to dump arguments passed on to cutlass operator for debugging -+ void print_operator_args(OperatorArguments &operator_args) const { -+ std::cout << "SymmOperation::OperatorArguments" << std::endl -+ << " problem_size:" << std::endl -+ << operator_args.problem_size << std::endl -+ << " epilouge (alpha, beta): " -+ << operator_args.epilogue.alpha << ", 
" -+ << operator_args.epilogue.beta << std::endl -+ << " ref_A (ptr, {stride}): " -+ << operator_args.ptr_A << ", {" -+ << operator_args.lda << "}" << std::endl -+ << " ref_B (ptr, {stride}): " -+ << operator_args.ptr_B << ", {" -+ << operator_args.ldb << "}" << std::endl -+ << " ref_C (ptr, {stride}): " -+ << operator_args.ptr_C << ", {" -+ << operator_args.ldc << "}" << std::endl -+ << " ref_D (ptr, {stride}): " -+ << operator_args.ptr_D << ", {" -+ << operator_args.ldd << "}" << std::endl; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/trmm_operation.h b/3rdparty/cutlass/tools/library/src/trmm_operation.h -new file mode 100644 -index 0000000..55f4fa6 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/trmm_operation.h -@@ -0,0 +1,346 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines operations for all TRMM operation kinds in CUTLASS Library. 
-+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/gemm/kernel/default_trmm_universal.h" -+#include "cutlass/gemm/kernel/trmm_universal.h" -+ -+#include "cutlass/library/library.h" -+#include "library_internal.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class TrmmOperationBase : public Operation { -+public: -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ static SideMode const kSideMode = Operator::kSideMode; -+ static FillMode const kFillMode = Operator::kFillMode; -+ static DiagType const kDiagType = Operator::kDiagType; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+protected: -+ -+ /// -+ TrmmDescription description_; -+ -+public: -+ -+ /// Constructor -+ TrmmOperationBase(char const *name = "unknown_trmm") { -+ -+ description_.name = name; -+ description_.provider = Provider::kCUTLASS; -+ description_.kind = OperationKind::kTrmm; -+ description_.trmm_kind = TrmmKind::kUniversal; -+ description_.side_mode = kSideMode; -+ description_.fill_mode = kFillMode; -+ description_.diag_type = kDiagType; -+ -+ description_.tile_description.threadblock_shape = make_Coord( -+ Operator::ThreadblockShape::kM, -+ Operator::ThreadblockShape::kN, -+ Operator::ThreadblockShape::kK); -+ -+ description_.tile_description.threadblock_stages = Operator::kStages; -+ -+ description_.tile_description.warp_count = make_Coord( -+ Operator::TrmmKernel::WarpCount::kM, -+ Operator::TrmmKernel::WarpCount::kN, -+ Operator::TrmmKernel::WarpCount::kK); -+ -+ description_.tile_description.math_instruction.instruction_shape = make_Coord( -+ Operator::InstructionShape::kM, -+ Operator::InstructionShape::kN, -+ Operator::InstructionShape::kK); -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.opcode_class = -+ OpcodeClassMap::kId; -+ -+ description_.tile_description.math_instruction.math_operation = -+ MathOperationMap::kId; -+ -+ description_.tile_description.minimum_compute_capability = -+ ArchMap::kMin; -+ -+ description_.tile_description.maximum_compute_capability = -+ ArchMap::kMax; -+ -+ description_.A = make_TensorDescription(Operator::kAlignmentA); -+ description_.B = make_TensorDescription(Operator::kAlignmentB); -+ description_.D = make_TensorDescription(Operator::kAlignmentC); -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ description_.split_k_mode = SplitKMode::kNone; -+ description_.transform_A = ComplexTransformMap::kId; -+ } -+ -+ /// Returns the description of the TRMM operation -+ virtual OperationDescription const & description() const { -+ return description_; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class TrmmOperation : public TrmmOperationBase { -+public: -+ -+ 
using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ static SideMode const kSideMode = Operator::kSideMode; -+ static FillMode const kFillMode = Operator::kFillMode; -+ static DiagType const kDiagType = Operator::kDiagType; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ -+ /// Constructor -+ TrmmOperation(char const *name = "unknown_trmm"): -+ TrmmOperationBase(name) { -+ -+ this->description_.trmm_kind = TrmmKind::kUniversal; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ TrmmConfiguration const *configuration) { -+ -+ //operator_args.mode = configuration->mode; -+ -+ operator_args.problem_size = configuration->problem_size; -+ operator_args.batch_count = configuration->batch_count; -+ -+ operator_args.lda = int(configuration->lda); -+ operator_args.ldb = int(configuration->ldb); -+ operator_args.ldd = int(configuration->ldd); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ TrmmArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // update arguments -+ operator_args.ptr_A = arguments->A; -+ operator_args.ptr_B = arguments->B; -+ operator_args.batch_stride_A = arguments->batch_stride_A; -+ operator_args.batch_stride_B = arguments->batch_stride_B; -+ operator_args.ptr_D = arguments->D; -+ operator_args.batch_stride_D = arguments->batch_stride_D; -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ TrmmConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ TrmmArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status 
status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ uint64_t size = Operator::get_workspace_size(args); -+ -+ return size; -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ -+ status = op->initialize(args, device_workspace, stream); -+ -+ return status; -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ bool need_swapped_matrices = (kSideMode == SideMode::kLeft && -+ std::is_same::value) || -+ (kSideMode == SideMode::kRight && -+ std::is_same::value); -+ if (need_swapped_matrices) { -+ status = op->update(args.swapped_matrices(), device_workspace); -+ } else { -+ status = op->update(args, device_workspace); -+ } -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = op->run(stream); -+ -+ return status; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/util.cu b/3rdparty/cutlass/tools/library/src/util.cu -new file mode 100644 -index 0000000..a4e234a ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/util.cu -@@ -0,0 +1,1599 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/complex.h" -+#include "cutlass/blas3.h" -+ -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+ -+namespace cutlass { -+namespace library { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ Provider enumerant; -+} -+Provider_enumerants[] = { -+ {"none", "None", Provider::kNone}, -+ {"cutlass", "CUTLASS", Provider::kCUTLASS}, -+ {"host", "reference_host", Provider::kReferenceHost}, -+ {"device", "reference_device", Provider::kReferenceDevice}, -+ {"cublas", "cuBLAS", Provider::kCUBLAS}, -+ {"cudnn", "cuDNN", Provider::kCUDNN}, -+}; -+ -+/// Converts a Provider enumerant to a string -+char const *to_string(Provider provider, bool pretty) { -+ -+ for (auto const & possible : Provider_enumerants) { -+ if (provider == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/// Parses a Provider enumerant from a string -+template <> -+Provider from_string(std::string const &str) { -+ -+ for (auto const & possible : Provider_enumerants) { -+ if ((str.compare(possible.text) == 0) || -+ (str.compare(possible.pretty) == 0)) { -+ return possible.enumerant; -+ } -+ } -+ -+ return Provider::kInvalid; -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ GemmKind enumerant; -+} -+GemmKind_enumerants[] = { -+ {"gemm", "", GemmKind::kGemm}, -+ {"spgemm", "", GemmKind::kSparse}, -+ {"universal", "", GemmKind::kUniversal}, -+ {"planar_complex", "", GemmKind::kPlanarComplex}, -+ {"planar_complex_array", "", GemmKind::kPlanarComplexArray}, -+ {"grouped", "", GemmKind::kGrouped}, -+}; -+ -+/// Converts a GemmKind enumerant to a string -+char const *to_string(GemmKind type, bool pretty) { -+ -+ for (auto const & possible : GemmKind_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? 
"Invalid" : "invalid"; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ RankKKind enumerant; -+} -+RankKKind_enumerants[] = { -+ {"universal", "", RankKKind::kUniversal}, -+}; -+ -+/// Converts a SyrkKind enumerant to a string -+char const *to_string(RankKKind type, bool pretty) { -+ -+ for (auto const & possible :RankKKind_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ TrmmKind enumerant; -+} -+TrmmKind_enumerants[] = { -+ {"universal", "", TrmmKind::kUniversal}, -+}; -+ -+/// Converts a TrmmKind enumerant to a string -+char const *to_string(TrmmKind type, bool pretty) { -+ -+ for (auto const & possible :TrmmKind_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ SymmKind enumerant; -+} -+SymmKind_enumerants[] = { -+ {"universal", "", SymmKind::kUniversal}, -+}; -+ -+/// Converts a SymmKind enumerant to a string -+char const *to_string(SymmKind type, bool pretty) { -+ -+ for (auto const & possible :SymmKind_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ SideMode enumerant; -+} -+SideMode_enumerants[] = { -+ {"left", "Left", SideMode::kLeft}, -+ {"right", "Right", SideMode::kRight} -+}; -+ -+/// Converts a SideMode enumerant to a string -+char const *to_string(SideMode type, bool pretty) { -+ -+ for (auto const & possible :SideMode_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ FillMode enumerant; -+} -+FillMode_enumerants[] = { -+ {"lower", "Lower", FillMode::kLower}, -+ {"upper", "Upper", FillMode::kUpper} -+}; -+ -+/// Converts a FillMode enumerant to a string -+char const *to_string(FillMode type, bool pretty) { -+ -+ for (auto const & possible :FillMode_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? 
"Invalid" : "invalid"; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ BlasMode enumerant; -+} -+BlasMode_enumerants[] = { -+ {"symmetric", "Symmetric", BlasMode::kSymmetric}, -+ {"hermitian", "Hermitian", BlasMode::kHermitian} -+}; -+ -+/// Converts a BlasMode enumerant to a string -+char const *to_string(BlasMode type, bool pretty) { -+ -+ for (auto const & possible :BlasMode_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ DiagType enumerant; -+} -+DiagType_enumerants[] = { -+ {"nonunit", "NonUnit", DiagType::kNonUnit}, -+ {"unit", "Unit", DiagType::kUnit} -+}; -+ -+/// Converts a DiagType enumerant to a string -+char const *to_string(DiagType type, bool pretty) { -+ -+ for (auto const & possible :DiagType_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ OperationKind enumerant; -+} -+OperationKind_enumerants[] = { -+ {"eq_gemm", "EqGemm", OperationKind::kEqGemm}, -+ {"gemm", "Gemm", OperationKind::kGemm}, -+ {"rank_k", "RankK", OperationKind::kRankK}, -+ {"rank_2k", "Rank2K", OperationKind::kRank2K}, -+ {"trmm", "Trmm", OperationKind::kTrmm}, -+ {"symm", "Symm", OperationKind::kSymm}, -+ {"conv2d", "Conv2d", OperationKind::kConv2d}, -+ {"conv3d", "Conv3d", OperationKind::kConv3d}, -+ {"spgemm", "SparseGemm", OperationKind::kSparseGemm}, -+}; -+ -+/// Converts a Status enumerant to a string -+char const *to_string(OperationKind enumerant, bool pretty) { -+ -+ for (auto const & possible : OperationKind_enumerants) { -+ if (enumerant == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? 
"Invalid" : "invalid"; -+} -+ -+/// Converts a Status enumerant from a string -+template <> -+OperationKind from_string(std::string const &str) { -+ -+ for (auto const & possible : OperationKind_enumerants) { -+ if ((str.compare(possible.text) == 0) || -+ (str.compare(possible.pretty) == 0)) { -+ return possible.enumerant; -+ } -+ } -+ -+ return OperationKind::kInvalid; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ Status enumerant; -+} -+Status_enumerants[] = { -+ {"success", "Success", Status::kSuccess}, -+ {"misaligned_operand", "Error: misaligned operand", Status::kErrorMisalignedOperand}, -+ {"invalid_problem", "Error: invalid problem", Status::kErrorInvalidProblem}, -+ {"not_supported", "Error: not supported", Status::kErrorNotSupported}, -+ {"internal", "Error: internal", Status::kErrorInternal} -+}; -+ -+/// Converts a Status enumerant to a string -+char const *to_string(Status status, bool pretty) { -+ -+ for (auto const & possible : Status_enumerants) { -+ if (status == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/// Converts a Status enumerant from a string -+template <> -+Status from_string(std::string const &str) { -+ -+ for (auto const & possible : Status_enumerants) { -+ if ((str.compare(possible.text) == 0) || -+ (str.compare(possible.pretty) == 0)) { -+ return possible.enumerant; -+ } -+ } -+ -+ return Status::kInvalid; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ NumericTypeID enumerant; -+} -+NumericTypeID_enumerants[] = { -+ {"unknown", "", NumericTypeID::kUnknown}, -+ {"void", "Void", NumericTypeID::kVoid}, -+ {"b1", "B1", NumericTypeID::kB1}, -+ {"u2", "U2", NumericTypeID::kU2}, -+ {"u4", "U4", NumericTypeID::kU4}, -+ {"u8", "U8", NumericTypeID::kU8}, -+ {"u16", "U16", NumericTypeID::kU16}, -+ {"u32", "U32", NumericTypeID::kU32}, -+ {"u64", "U64", NumericTypeID::kU64}, -+ {"s2", "S2", NumericTypeID::kS2}, -+ {"s4", "S4", NumericTypeID::kS4}, -+ {"s8", "S8", NumericTypeID::kS8}, -+ {"s16", "S16", NumericTypeID::kS16}, -+ {"s32", "S32", NumericTypeID::kS32}, -+ {"s64", "S64", NumericTypeID::kS64}, -+ {"f16", "F16", NumericTypeID::kF16}, -+ {"bf16", "BF16", NumericTypeID::kBF16}, -+ {"f32", "F32", NumericTypeID::kF32}, -+ {"tf32", "TF32", NumericTypeID::kTF32}, -+ {"f64", "F64", NumericTypeID::kF64}, -+ {"cf16", "CF16", NumericTypeID::kCF16}, -+ {"cbf16", "CBF16", NumericTypeID::kCBF16}, -+ {"cf32", "CF32", NumericTypeID::kCF32}, -+ {"ctf32", "CTF32", NumericTypeID::kCTF32}, -+ {"cf64", "CF64", NumericTypeID::kCF64}, -+ {"cu2", "CU2", NumericTypeID::kCU2}, -+ {"cu4", "CU4", NumericTypeID::kCU4}, -+ {"cu8", "CU8", NumericTypeID::kCU8}, -+ {"cu16", "CU16", NumericTypeID::kCU16}, -+ {"cu32", "CU32", NumericTypeID::kCU32}, -+ {"cu64", "CU64", NumericTypeID::kCU64}, -+ {"cs2", "CS2", NumericTypeID::kCS2}, -+ {"cs4", "CS4", NumericTypeID::kCS4}, -+ {"cs8", "CS8", NumericTypeID::kCS8}, -+ {"cs16", "CS16", NumericTypeID::kCS16}, -+ {"cs32", "CS32", NumericTypeID::kCS32}, -+ {"cs64", "CS64", NumericTypeID::kCS64}, -+ {"*", "", NumericTypeID::kUnknown} -+}; -+ -+/// Converts a NumericTypeID enumerant to a string -+char const *to_string(NumericTypeID type, bool pretty) { -+ -+ for (auto const & 
possible : NumericTypeID_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/// Parses a NumericTypeID enumerant from a string -+template <> -+NumericTypeID from_string(std::string const &str) { -+ -+ for (auto const & possible : NumericTypeID_enumerants) { -+ if ((str.compare(possible.text) == 0) || -+ (str.compare(possible.pretty) == 0)) { -+ return possible.enumerant; -+ } -+ } -+ -+ return NumericTypeID::kInvalid; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns the size of a data type in bits -+int sizeof_bits(NumericTypeID type) { -+ switch (type) { -+ case NumericTypeID::kF16: return 16; -+ case NumericTypeID::kBF16: return 16; -+ case NumericTypeID::kTF32: return 32; -+ case NumericTypeID::kF32: return 32; -+ case NumericTypeID::kF64: return 64; -+ case NumericTypeID::kCF16: return 32; -+ case NumericTypeID::kCBF16: return 32; -+ case NumericTypeID::kCF32: return 64; -+ case NumericTypeID::kCTF32: return 64; -+ case NumericTypeID::kCF64: return 128; -+ case NumericTypeID::kS2: return 2; -+ case NumericTypeID::kS4: return 4; -+ case NumericTypeID::kS8: return 8; -+ case NumericTypeID::kS16: return 16; -+ case NumericTypeID::kS32: return 32; -+ case NumericTypeID::kS64: return 64; -+ case NumericTypeID::kU2: return 2; -+ case NumericTypeID::kU4: return 4; -+ case NumericTypeID::kU8: return 8; -+ case NumericTypeID::kU16: return 16; -+ case NumericTypeID::kU32: return 32; -+ case NumericTypeID::kU64: return 64; -+ case NumericTypeID::kB1: return 1; -+ default: break; -+ } -+ return 0; -+} -+ -+/// Returns true if the numeric type is a complex data type or false if real-valued. 
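A minimal sketch of how the helpers above and below compose for complex element types, assuming these declarations are exposed through "cutlass/library/util.h" (the header path and the function name sketch_complex_widths are assumptions for illustration):

#include <cassert>
#include "cutlass/library/util.h"

// A complex element occupies twice the bits of its underlying real field:
// for example kCF32 is 64 bits wide while its real component type kF32 is 32.
void sketch_complex_widths() {
  using namespace cutlass::library;
  NumericTypeID id = from_string<NumericTypeID>("cf32");   // parses to kCF32
  assert(is_complex_type(id));
  assert(sizeof_bits(id) == 2 * sizeof_bits(get_real_type(id)));
}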
-+bool is_complex_type(NumericTypeID type) { -+ switch (type) { -+ case NumericTypeID::kCF16: return true; -+ case NumericTypeID::kCF32: return true; -+ case NumericTypeID::kCF64: return true; -+ case NumericTypeID::kCBF16: return true; -+ case NumericTypeID::kCTF32: return true; -+ default: break; -+ } -+ return false; -+} -+ -+/// Returns the field underlying a complex valued type -+NumericTypeID get_real_type(NumericTypeID type) { -+ switch (type) { -+ case NumericTypeID::kCF16: return NumericTypeID::kF16; -+ case NumericTypeID::kCF32: return NumericTypeID::kF32; -+ case NumericTypeID::kCF64: return NumericTypeID::kF64; -+ case NumericTypeID::kCBF16: return NumericTypeID::kBF16; -+ case NumericTypeID::kCTF32: return NumericTypeID::kTF32; -+ default: break; -+ } -+ return type; -+} -+ -+/// Returns true if numeric type is integer -+bool is_integer_type(NumericTypeID type) { -+ switch (type) { -+ case NumericTypeID::kS2: return true; -+ case NumericTypeID::kS4: return true; -+ case NumericTypeID::kS8: return true; -+ case NumericTypeID::kS16: return true; -+ case NumericTypeID::kS32: return true; -+ case NumericTypeID::kS64: return true; -+ case NumericTypeID::kU2: return true; -+ case NumericTypeID::kU4: return true; -+ case NumericTypeID::kU8: return true; -+ case NumericTypeID::kU16: return true; -+ case NumericTypeID::kU32: return true; -+ case NumericTypeID::kU64: return true; -+ default: break; -+ } -+ return false; -+} -+ -+/// Returns true if numeric type is signed -+bool is_signed_type(NumericTypeID type) { -+ switch (type) { -+ case NumericTypeID::kF16: return true; -+ case NumericTypeID::kBF16: return true; -+ case NumericTypeID::kTF32: return true; -+ case NumericTypeID::kF32: return true; -+ case NumericTypeID::kF64: return true; -+ case NumericTypeID::kS2: return true; -+ case NumericTypeID::kS4: return true; -+ case NumericTypeID::kS8: return true; -+ case NumericTypeID::kS16: return true; -+ case NumericTypeID::kS32: return true; -+ case NumericTypeID::kS64: return true; -+ default: break; -+ } -+ return false; -+} -+ -+/// Returns true if numeric type is a signed integer -+bool is_signed_integer(NumericTypeID type) { -+ return is_integer_type(type) && is_signed_type(type); -+} -+ -+/// returns true if numeric type is an unsigned integer -+bool is_unsigned_integer(NumericTypeID type) { -+ return is_integer_type(type) && !is_signed_type(type); -+} -+ -+/// Returns true if numeric type is floating-point type -+bool is_float_type(NumericTypeID type) { -+ switch (type) { -+ case NumericTypeID::kF16: return true; -+ case NumericTypeID::kBF16: return true; -+ case NumericTypeID::kTF32: return true; -+ case NumericTypeID::kF32: return true; -+ case NumericTypeID::kF64: return true; -+ case NumericTypeID::kCF16: return true; -+ case NumericTypeID::kCBF16: return true; -+ case NumericTypeID::kCTF32: return true; -+ case NumericTypeID::kCF32: return true; -+ case NumericTypeID::kCF64: return true; -+ default: break; -+ } -+ return false; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ LayoutTypeID layout; -+ char const *alias; -+} -+layout_aliases[] = { -+ {LayoutTypeID::kUnknown, "unknown"}, -+ {LayoutTypeID::kRowMajor, "row"}, -+ {LayoutTypeID::kRowMajor, "t"}, -+ {LayoutTypeID::kColumnMajor, "column"}, -+ {LayoutTypeID::kColumnMajor, "col"}, -+ {LayoutTypeID::kColumnMajor, "n"}, -+ -+ {LayoutTypeID::kColumnMajorInterleavedK2, "nk2"}, -+ {LayoutTypeID::kRowMajorInterleavedK2, "tk2"}, -+ -+ 
{LayoutTypeID::kColumnMajorInterleavedK4, "nk4"}, -+ {LayoutTypeID::kRowMajorInterleavedK4, "tk4"}, -+ -+ {LayoutTypeID::kColumnMajorInterleavedK16, "nk16"}, -+ {LayoutTypeID::kRowMajorInterleavedK16, "tk16"}, -+ -+ {LayoutTypeID::kColumnMajorInterleavedK32, "nk32"}, -+ {LayoutTypeID::kRowMajorInterleavedK32, "tk32"}, -+ -+ {LayoutTypeID::kColumnMajorInterleavedK64, "nk64"}, -+ {LayoutTypeID::kRowMajorInterleavedK64, "tk64"}, -+ -+ {LayoutTypeID::kTensorNCHW, "nchw"}, -+ {LayoutTypeID::kTensorNCDHW, "ncdhw"}, -+ {LayoutTypeID::kTensorNHWC, "nhwc"}, -+ {LayoutTypeID::kTensorNDHWC, "ndhwc"}, -+ {LayoutTypeID::kTensorNC32HW32, "nc32hw32"}, -+ {LayoutTypeID::kTensorNC64HW64, "nc64hw64"}, -+ {LayoutTypeID::kTensorC32RSK32, "c32rsk32"}, -+ {LayoutTypeID::kTensorC64RSK64, "c64rsk64"}, -+ -+ {LayoutTypeID::kUnknown, "*"}, -+ {LayoutTypeID::kInvalid, nullptr} -+}; -+ -+/// Converts a LayoutTypeID enumerant to a string -+char const *to_string(LayoutTypeID layout, bool pretty) { -+ for (auto const & alias : layout_aliases) { -+ if (alias.layout == layout) { -+ return alias.alias; -+ } -+ } -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/// Parses a LayoutTypeID enumerant from a string -+template <> -+LayoutTypeID from_string(std::string const &str) { -+ for (auto const & alias : layout_aliases) { -+ if (str.compare(alias.alias) == 0) { -+ return alias.layout; -+ } -+ } -+ return LayoutTypeID::kInvalid; -+} -+ -+/// Gets stride rank for the layout_id (static function) -+int get_layout_stride_rank(LayoutTypeID layout_id) { -+ switch (layout_id) { -+ case LayoutTypeID::kColumnMajor: -+ return cutlass::layout::ColumnMajor::kStrideRank; -+ case LayoutTypeID::kRowMajor: -+ return cutlass::layout::RowMajor::kStrideRank; -+ case LayoutTypeID::kColumnMajorInterleavedK2: -+ return cutlass::layout::ColumnMajorInterleaved<2>::kStrideRank; -+ case LayoutTypeID::kRowMajorInterleavedK2: -+ return cutlass::layout::RowMajorInterleaved<2>::kStrideRank; -+ case LayoutTypeID::kColumnMajorInterleavedK4: -+ return cutlass::layout::ColumnMajorInterleaved<4>::kStrideRank; -+ case LayoutTypeID::kRowMajorInterleavedK4: -+ return cutlass::layout::RowMajorInterleaved<4>::kStrideRank; -+ case LayoutTypeID::kColumnMajorInterleavedK16: -+ return cutlass::layout::ColumnMajorInterleaved<16>::kStrideRank; -+ case LayoutTypeID::kRowMajorInterleavedK16: -+ return cutlass::layout::RowMajorInterleaved<16>::kStrideRank; -+ case LayoutTypeID::kColumnMajorInterleavedK32: -+ return cutlass::layout::ColumnMajorInterleaved<32>::kStrideRank; -+ case LayoutTypeID::kRowMajorInterleavedK32: -+ return cutlass::layout::RowMajorInterleaved<32>::kStrideRank; -+ case LayoutTypeID::kColumnMajorInterleavedK64: -+ return cutlass::layout::ColumnMajorInterleaved<64>::kStrideRank; -+ case LayoutTypeID::kRowMajorInterleavedK64: -+ return cutlass::layout::RowMajorInterleaved<64>::kStrideRank; -+ case LayoutTypeID::kTensorNCHW: -+ return cutlass::layout::TensorNCHW::kStrideRank; -+ case LayoutTypeID::kTensorNHWC: -+ return cutlass::layout::TensorNHWC::kStrideRank; -+ case LayoutTypeID::kTensorNDHWC: -+ return cutlass::layout::TensorNDHWC::kStrideRank; -+ case LayoutTypeID::kTensorNC32HW32: -+ return cutlass::layout::TensorNCxHWx<32>::kStrideRank; -+ case LayoutTypeID::kTensorNC64HW64: -+ return cutlass::layout::TensorNCxHWx<64>::kStrideRank; -+ case LayoutTypeID::kTensorC32RSK32: -+ return cutlass::layout::TensorCxRSKx<32>::kStrideRank; -+ case LayoutTypeID::kTensorC64RSK64: -+ return cutlass::layout::TensorCxRSKx<64>::kStrideRank; -+ default: -+ throw 
std::runtime_error("Unsupported LayoutTypeID in LayoutType::get_stride_rank"); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ OpcodeClassID enumerant; -+} -+OpcodeClassID_enumerants[] = { -+ {"simt", "", OpcodeClassID::kSimt}, -+ {"tensorop", "", OpcodeClassID::kTensorOp}, -+ {"wmmatensorop", "", OpcodeClassID::kWmmaTensorOp}, -+ {"wmma", "", OpcodeClassID::kWmmaTensorOp}, -+}; -+ -+/// Converts a OpcodeClassID enumerant to a string -+char const *to_string(OpcodeClassID type, bool pretty) { -+ -+ for (auto const & possible : OpcodeClassID_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/// Converts a OpcodeClassID enumerant from a string -+template <> -+OpcodeClassID from_string(std::string const &str) { -+ -+ for (auto const & possible : OpcodeClassID_enumerants) { -+ if ((str.compare(possible.text) == 0) || -+ (str.compare(possible.pretty) == 0)) { -+ return possible.enumerant; -+ } -+ } -+ -+ return OpcodeClassID::kInvalid; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ ComplexTransform enumerant; -+} -+ComplexTransform_enumerants[] = { -+ {"n", "none", ComplexTransform::kNone}, -+ {"c", "conj", ComplexTransform::kConjugate} -+}; -+ -+/// Converts a ComplexTransform enumerant to a string -+char const *to_string(ComplexTransform type, bool pretty) { -+ -+ for (auto const & possible : ComplexTransform_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/// Converts a ComplexTransform enumerant from a string -+template <> -+ComplexTransform from_string(std::string const &str) { -+ -+ for (auto const & possible : ComplexTransform_enumerants) { -+ if ((str.compare(possible.text) == 0) || -+ (str.compare(possible.pretty) == 0)) { -+ return possible.enumerant; -+ } -+ } -+ -+ return ComplexTransform::kInvalid; -+} -+ -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ SplitKMode enumerant; -+} -+SplitKMode_enumerants[] = { -+ {"serial", "", SplitKMode::kSerial}, -+ {"parallel", "", SplitKMode::kParallel}, -+}; -+ -+/// Converts a SplitKMode enumerant to a string -+char const *to_string(SplitKMode type, bool pretty) { -+ -+ for (auto const & possible : SplitKMode_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? 
"Invalid" : "invalid"; -+} -+ -+/// Converts a SplitKMode enumerant from a string -+template <> -+SplitKMode from_string(std::string const &str) { -+ -+ for (auto const & possible : SplitKMode_enumerants) { -+ if ((str.compare(possible.text) == 0) || -+ (str.compare(possible.pretty) == 0)) { -+ return possible.enumerant; -+ } -+ } -+ -+ return SplitKMode::kInvalid; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+static struct { -+ char const *text; -+ char const *pretty; -+ ConvModeID enumerant; -+} -+ConvModeID_enumerants[] = { -+ {"cross", "", ConvModeID::kCrossCorrelation}, -+ {"conv", "", ConvModeID::kConvolution}, -+}; -+ -+/// Converts a ConvModeID enumerant to a string -+char const *to_string(ConvModeID type, bool pretty) { -+ -+ for (auto const & possible : ConvModeID_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/// Converts a ConvModeID enumerant from a string -+template <> -+ConvModeID from_string(std::string const &str) { -+ -+ for (auto const & possible : ConvModeID_enumerants) { -+ if ((str.compare(possible.text) == 0) || -+ (str.compare(possible.pretty) == 0)) { -+ return possible.enumerant; -+ } -+ } -+ -+ return ConvModeID::kInvalid; -+} -+ -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ IteratorAlgorithmID enumerant; -+} -+IteratorAlgorithmID_enumerants[] = { -+ {"none", "", IteratorAlgorithmID::kNone}, -+ {"analytic", "", IteratorAlgorithmID::kAnalytic}, -+ {"optimized", "", IteratorAlgorithmID::kOptimized}, -+ {"fixed_channels", "", IteratorAlgorithmID::kFixedChannels}, -+ {"few_channels", "", IteratorAlgorithmID::kFewChannels}, -+}; -+ -+/// Converts a ConvModeID enumerant to a string -+char const *to_string(IteratorAlgorithmID type, bool pretty) { -+ -+ for (auto const & possible : IteratorAlgorithmID_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/// Converts a ConvModeID enumerant from a string -+template <> -+IteratorAlgorithmID from_string(std::string const &str) { -+ -+ for (auto const & possible : IteratorAlgorithmID_enumerants) { -+ if ((str.compare(possible.text) == 0) || -+ (str.compare(possible.pretty) == 0)) { -+ return possible.enumerant; -+ } -+ } -+ -+ return IteratorAlgorithmID::kInvalid; -+} -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ ConvKind enumerant; -+} -+ConvKind_enumerants[] = { -+ {"unknown", "", ConvKind::kUnknown}, -+ {"fprop", "", ConvKind::kFprop}, -+ {"dgrad", "", ConvKind::kDgrad}, -+ {"wgrad", "", ConvKind::kWgrad}, -+}; -+ -+/// Converts a ConvKind enumerant to a string -+char const *to_string(ConvKind type, bool pretty) { -+ -+ for (auto const & possible : ConvKind_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? 
"Invalid" : "invalid"; -+} -+ -+ -+/// Converts a ConvKind enumerant from a string -+template <> -+ConvKind from_string(std::string const &str) { -+ -+ for (auto const & possible : ConvKind_enumerants) { -+ if ((str.compare(possible.text) == 0) || -+ (str.compare(possible.pretty) == 0)) { -+ return possible.enumerant; -+ } -+ } -+ -+ return ConvKind::kInvalid; -+} -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid. -+bool lexical_cast(std::vector &bytes, NumericTypeID type, std::string const &str) { -+ int size_bytes = sizeof_bits(type) / 8; -+ if (!size_bytes) { -+ return false; -+ } -+ -+ bytes.resize(size_bytes, 0); -+ -+ std::stringstream ss; -+ ss << str; -+ -+ switch (type) { -+ case NumericTypeID::kU8: -+ { -+ ss >> *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kU16: -+ { -+ ss >> *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kU32: -+ { -+ ss >> *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kU64: -+ { -+ ss >> *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kS8: -+ { -+ ss >> *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kS16: -+ { -+ ss >> *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kS32: -+ { -+ ss >> *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kS64: -+ { -+ ss >> *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kF16: -+ { -+ float tmp; -+ ss >> tmp; -+ *reinterpret_cast(bytes.data()) = static_cast(tmp); -+ } -+ break; -+ case NumericTypeID::kBF16: -+ { -+ float tmp; -+ ss >> tmp; -+ *reinterpret_cast(bytes.data()) = static_cast(tmp); -+ } -+ break; -+ case NumericTypeID::kTF32: -+ { -+ float tmp; -+ ss >> tmp; -+ *reinterpret_cast(bytes.data()) = static_cast(tmp); -+ } -+ break; -+ case NumericTypeID::kF32: -+ { -+ ss >> *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kF64: -+ { -+ ss >> *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kCF16: -+ { -+ std::complex tmp; -+ ss >> tmp; -+ cutlass::complex *x = reinterpret_cast *>(bytes.data()); -+ x->real() = static_cast(std::real(tmp)); -+ x->imag() = static_cast(std::imag(tmp)); -+ } -+ break; -+ case NumericTypeID::kCBF16: -+ { -+ std::complex tmp; -+ ss >> tmp; -+ cutlass::complex *x = reinterpret_cast *>(bytes.data()); -+ x->real() = static_cast(std::real(tmp)); -+ x->imag() = static_cast(std::imag(tmp)); -+ } -+ break; -+ case NumericTypeID::kCF32: -+ { -+ ss >> *reinterpret_cast*>(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kCTF32: -+ { -+ std::complex tmp; -+ ss >> tmp; -+ cutlass::complex *x = reinterpret_cast *>(bytes.data()); -+ x->real() = static_cast(std::real(tmp)); -+ x->imag() = static_cast(std::imag(tmp)); -+ } -+ break; -+ case NumericTypeID::kCF64: -+ { -+ ss >> *reinterpret_cast*>(bytes.data()); -+ } -+ break; -+ default: -+ return false; -+ } -+ -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+std::string lexical_cast(int64_t int_value) { -+ std::stringstream ss; -+ ss << int_value; -+ return ss.str(); -+} -+ -+/// Lexical cast TO a string FROM a byte array. Returns true if cast is successful or false if invalid. 
-+std::string lexical_cast(std::vector &bytes, NumericTypeID type) { -+ -+ int size_bytes = sizeof_bits(type) / 8; -+ -+ if (!size_bytes || size_bytes != bytes.size()) { -+ return ""; -+ } -+ -+ bytes.resize(size_bytes, 0); -+ -+ std::stringstream ss; -+ -+ switch (type) { -+ case NumericTypeID::kU8: -+ { -+ ss << *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kU16: -+ { -+ ss << *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kU32: -+ { -+ ss << *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kU64: -+ { -+ ss << *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kS8: -+ { -+ ss << *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kS16: -+ { -+ ss << *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kS32: -+ { -+ ss << *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kS64: -+ { -+ ss << *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kF16: -+ { -+ float tmp = *reinterpret_cast(bytes.data()); -+ ss << tmp; -+ } -+ break; -+ case NumericTypeID::kBF16: -+ { -+ float tmp = *reinterpret_cast(bytes.data()); -+ ss << tmp; -+ } -+ break; -+ case NumericTypeID::kTF32: -+ { -+ float tmp = *reinterpret_cast(bytes.data()); -+ ss << tmp; -+ } -+ break; -+ case NumericTypeID::kF32: -+ { -+ ss << *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kF64: -+ { -+ ss << *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kCF16: -+ { -+ cutlass::complex const *x = -+ reinterpret_cast const *>(bytes.data()); -+ -+ ss << float(x->real()); -+ -+ if (x->imag() != cutlass::half_t()) { -+ ss << "+i" << float(x->imag()); -+ } -+ } -+ break; -+ case NumericTypeID::kCBF16: -+ { -+ cutlass::complex const *x = -+ reinterpret_cast const *>(bytes.data()); -+ -+ ss << float(x->real()); -+ -+ if (x->imag() != cutlass::bfloat16_t()) { -+ ss << "+i" << float(x->imag()); -+ } -+ } -+ break; -+ case NumericTypeID::kCF32: -+ { -+ cutlass::complex const * x = reinterpret_cast const *>(bytes.data()); -+ -+ ss << x->real(); -+ -+ if (x->imag() != float()) { -+ ss << "+i" << x->imag(); -+ } -+ } -+ break; -+ case NumericTypeID::kCTF32: -+ { -+ cutlass::complex const * x = reinterpret_cast const *>(bytes.data()); -+ -+ ss << float(x->real()); -+ -+ if (x->imag() != tfloat32_t()) { -+ ss << "+i" << float(x->imag()); -+ } -+ } -+ break; -+ case NumericTypeID::kCF64: -+ { -+ cutlass::complex const * x = reinterpret_cast const *>(bytes.data()); -+ -+ ss << x->real(); -+ -+ if (x->imag() != double()) { -+ ss << "+i" << x->imag(); -+ } -+ } -+ break; -+ default: -+ return ""; -+ } -+ -+ return ss.str(); -+} -+ -+/// Casts from a signed int64 to the destination type. Returns true if successful. 
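One behavioral detail of the overload above: complex scalars are rendered as the real part followed by "+i" and the imaginary part, and the imaginary term is dropped when it is zero. A small sketch under the same header assumption (the sample value is illustrative):

#include <cstdint>
#include <cstring>
#include <string>
#include <vector>
#include "cutlass/complex.h"
#include "cutlass/library/util.h"

std::string sketch_format_cf32() {
  using namespace cutlass::library;
  cutlass::complex<float> value(1.5f, -2.0f);

  // Copy the scalar into the untyped byte form the library routines expect.
  std::vector<uint8_t> bytes(sizeof(value));
  std::memcpy(bytes.data(), &value, sizeof(value));

  return lexical_cast(bytes, NumericTypeID::kCF32);   // yields "1.5+i-2"
}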
-+bool cast_from_int64(std::vector &bytes, NumericTypeID type, int64_t src) { -+ int size_bytes = sizeof_bits(type) / 8; -+ if (!size_bytes) { -+ return false; -+ } -+ -+ bytes.resize(size_bytes, 0); -+ -+ switch (type) { -+ case NumericTypeID::kU8: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kU16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kU32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kU64: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS8: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS64: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kF16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(float(src)); -+ } -+ break; -+ case NumericTypeID::kBF16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(float(src)); -+ } -+ break; -+ case NumericTypeID::kTF32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(float(src)); -+ } -+ break; -+ case NumericTypeID::kF32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kF64: -+ { -+ *reinterpret_cast(bytes.data()) = double(src); -+ } -+ break; -+ case NumericTypeID::kCF16: -+ { -+ cutlass::complex *x = reinterpret_cast *>(bytes.data()); -+ x->real() = static_cast(float(src)); -+ x->imag() = static_cast(float(0)); -+ } -+ break; -+ case NumericTypeID::kCF32: -+ { -+ *reinterpret_cast*>(bytes.data()) = cutlass::complex(float(src), float(0)); -+ } -+ break; -+ case NumericTypeID::kCF64: -+ { -+ *reinterpret_cast*>(bytes.data()) = cutlass::complex(double(src), double(0)); -+ } -+ break; -+ default: -+ return false; -+ } -+ -+ return true; -+ -+} -+ -+/// Casts from an unsigned int64 to the destination type. Returns true if successful. 
-+bool cast_from_uint64(std::vector &bytes, NumericTypeID type, uint64_t src) { -+ int size_bytes = sizeof_bits(type) / 8; -+ if (!size_bytes) { -+ return false; -+ } -+ -+ bytes.resize(size_bytes, 0); -+ -+ switch (type) { -+ case NumericTypeID::kU8: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kU16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kU32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kU64: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS8: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS64: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kF16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(float(src)); -+ } -+ break; -+ case NumericTypeID::kBF16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(float(src)); -+ } -+ break; -+ case NumericTypeID::kTF32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(float(src)); -+ } -+ break; -+ case NumericTypeID::kF32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kF64: -+ { -+ *reinterpret_cast(bytes.data()) = double(src); -+ } -+ break; -+ case NumericTypeID::kCF16: -+ { -+ cutlass::complex *x = reinterpret_cast *>(bytes.data()); -+ x->real() = static_cast(float(src)); -+ x->imag() = static_cast(float(0)); -+ } -+ break; -+ case NumericTypeID::kCF32: -+ { -+ *reinterpret_cast*>(bytes.data()) = std::complex(float(src), float(0)); -+ } -+ break; -+ case NumericTypeID::kCF64: -+ { -+ *reinterpret_cast*>(bytes.data()) = std::complex(double(src), double(0)); -+ } -+ break; -+ default: -+ return false; -+ } -+ -+ return true; -+ -+} -+ -+/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid. 
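The double-precision cast that follows is what the conv2d profiler later in this patch uses to default the epilogue scalars when --alpha and --beta are not given on the command line; a minimal sketch, assuming the declarations are exposed through "cutlass/library/util.h" (the function name is illustrative):

#include <cstdint>
#include <vector>
#include "cutlass/library/util.h"

// Fill alpha with 1 and beta with 0 in the element type the epilogue expects.
bool sketch_default_epilogue_scalars(cutlass::library::NumericTypeID element_epilogue,
                                     std::vector<uint8_t> &alpha,
                                     std::vector<uint8_t> &beta) {
  using cutlass::library::cast_from_double;
  return cast_from_double(alpha, element_epilogue, 1.0) &&
         cast_from_double(beta, element_epilogue, 0.0);
}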
-+bool cast_from_double(std::vector &bytes, NumericTypeID type, double src) { -+ -+ int size_bytes = sizeof_bits(type) / 8; -+ if (!size_bytes) { -+ return false; -+ } -+ -+ bytes.resize(size_bytes, 0); -+ -+ switch (type) { -+ case NumericTypeID::kU8: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kU16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kU32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kU64: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS8: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS64: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kF16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(float(src)); -+ } -+ break; -+ case NumericTypeID::kBF16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(float(src)); -+ } -+ break; -+ case NumericTypeID::kTF32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(float(src)); -+ } -+ break; -+ case NumericTypeID::kF32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kF64: -+ { -+ *reinterpret_cast(bytes.data()) = src; -+ } -+ break; -+ case NumericTypeID::kCF16: -+ { -+ cutlass::complex *x = reinterpret_cast *>(bytes.data()); -+ x->real() = static_cast(float(src)); -+ x->imag() = static_cast(float(0)); -+ } -+ break; -+ case NumericTypeID::kCBF16: -+ { -+ cutlass::complex *x = reinterpret_cast *>(bytes.data()); -+ x->real() = static_cast(bfloat16_t(src)); -+ x->imag() = static_cast(bfloat16_t(0)); -+ } -+ break; -+ case NumericTypeID::kCF32: -+ { -+ *reinterpret_cast*>(bytes.data()) = cutlass::complex(float(src), float()); -+ } -+ break; -+ case NumericTypeID::kCTF32: -+ { -+ *reinterpret_cast*>(bytes.data()) = cutlass::complex(tfloat32_t(src), tfloat32_t()); -+ } -+ break; -+ case NumericTypeID::kCF64: -+ { -+ *reinterpret_cast*>(bytes.data()) = cutlass::complex(src, double()); -+ } -+ break; -+ default: -+ return false; -+ } -+ -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/conv2d_operation_profiler.cu b/3rdparty/cutlass/tools/profiler/src/conv2d_operation_profiler.cu -new file mode 100644 -index 0000000..0693058 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/conv2d_operation_profiler.cu -@@ -0,0 +1,1488 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Convolution 2D profiling -+*/ -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/core_io.h" -+ -+#include "conv2d_operation_profiler.h" -+#include "gpu_timer.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+using namespace cutlass::library; -+ -+namespace cutlass { -+namespace profiler { -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Ctor -+Conv2dOperationProfiler::Conv2dOperationProfiler(Options const &options): -+ OperationProfiler( -+ options, -+ library::OperationKind::kConv2d, -+ { -+ {ArgumentTypeID::kEnumerated, {"conv_kind"}, "Convolutional operator (fprop, dgrad, wgrad)"}, -+ {ArgumentTypeID::kInteger, {"n", "input_n"}, "Input N dimension of the Conv2d problem space"}, -+ {ArgumentTypeID::kInteger, {"h", "input_h"}, "Input H dimension of the Conv2d problem space"}, -+ {ArgumentTypeID::kInteger, {"w", "input_w"}, "Input W dimension of the Conv2d problem space"}, -+ {ArgumentTypeID::kInteger, {"c", "input_c"}, "Input C dimension of the Conv2d problem space"}, -+ {ArgumentTypeID::kInteger, {"k", "filter_k"}, "Filter K dimension of the Conv2d problem space"}, -+ {ArgumentTypeID::kInteger, {"r", "filter_r"}, "Filter R dimension of the Conv2d problem space"}, -+ {ArgumentTypeID::kInteger, {"s", "filter_s"}, "Filter S dimension of the Conv2d problem space"}, -+ {ArgumentTypeID::kInteger, {"p", "output_p"}, "Output P dimension of the Conv2d problem space"}, -+ {ArgumentTypeID::kInteger, {"q", "output_q"}, "Output Q dimension of the Conv2d problem space"}, -+ {ArgumentTypeID::kInteger, {"g", "groups"}, "Number of convolution groups"}, -+ {ArgumentTypeID::kInteger, {"pad_h"}, "Padding in H direction"}, -+ {ArgumentTypeID::kInteger, {"pad_w"}, "Padding in W direction"}, -+ {ArgumentTypeID::kInteger, {"stride_h"}, "Stride in H direction"}, -+ {ArgumentTypeID::kInteger, {"stride_w"}, "Stride in W direction"}, -+ {ArgumentTypeID::kInteger, {"dilation_h"}, "Dilation in H direction"}, -+ {ArgumentTypeID::kInteger, {"dilation_w"}, "Dilation in W direction"}, -+ {ArgumentTypeID::kTensor, {"Activation"}, "Tensor 
storing the Activation operand"}, -+ {ArgumentTypeID::kTensor, {"Filter"}, "Tensor storing the Filter operand"}, -+ {ArgumentTypeID::kTensor, {"Output"}, "Tensor storing the Output operand"}, -+ {ArgumentTypeID::kEnumerated, {"conv_mode"}, "Convolution filter mode (conv, cross)"}, -+ {ArgumentTypeID::kEnumerated, {"iterator_algorithm", "iterator_algo"}, "Convolution iterator algorithm (analytic, optimized)"}, -+ {ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"}, -+ {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"}, -+ {ArgumentTypeID::kEnumerated, {"split_k_mode", "split-k-mode"}, "SplitK mode for serial or parallel reduction (serial, parallel)"}, -+ {ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"}, -+ {ArgumentTypeID::kEnumerated, {"eq_gemm_provider", "eq-gemm-provider"}, "Enable profiling equivalent gemm by the following providers (cutlass)"}, -+ }, -+ { library::Provider::kReferenceDevice, library::Provider::kReferenceHost, library::Provider::kCUDNN } -+ ) { -+ -+ description_ = " Conv2d operation. Output(Tensor4D) = alpha * Input(Tensor4D) * Filter(Tensor4D) + beta * Input(Tensor4D)"; -+ -+} -+ -+/// Destructor -+Conv2dOperationProfiler::~Conv2dOperationProfiler() { -+ -+} -+ -+ -+/// Prints usage statement for the math function -+void Conv2dOperationProfiler::print_usage(std::ostream &out) const { -+ out << "Conv2d" << "\n\n"; -+ -+ OperationProfiler::print_usage(out); -+} -+ -+/// Prints examples -+void Conv2dOperationProfiler::print_examples(std::ostream &out) const { -+ -+ out << "\nExamples:\n\n" -+ << "Profile a particular convolution (specify all the convolution parameters):\n" -+ << " $ cutlass_profiler --operation=Conv2d" -+ " --Activation=f16:nhwc --Filter=f16:nhwc --Output=f16 --accumulator-type=f32" -+ " --n=32 --h=14 --w=14 --c=8 --k=64 --r=3 --s=3" -+ " --pad_h=1 --pad_w=1" -+ " --stride_h=1 --stride_w=1" -+ " --dilation_h=1 --dilation_w=1\n\n"; -+} -+ -+#if 0 -+// used this for debugging -+static std::string byte_string(std::vector const &bytes) { -+ std::stringstream ss; -+ -+ ss << "0x"; -+ -+ for (size_t idx = bytes.size(); idx > 0; --idx) { -+ ss << std::hex << std::setw(2) << std::setfill('0') << uint32_t(bytes.at(idx - 1)); -+ } -+ -+ return ss.str(); -+} -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Total number of bytes loaded -+int64_t Conv2dOperationProfiler::Conv2dProblem::bytes( -+ library::ConvDescription const &operation_desc) const { -+ -+ cutlass::gemm::GemmCoord mnk = eq_gemm_size(operation_desc.conv_kind); -+ -+ // Input bytes read and Output bytes written for the gemm problem -+ int64_t bytes_ = -+ int64_t(library::sizeof_bits(operation_desc.A.element) * mnk.m() / 8) * mnk.k() + -+ int64_t(library::sizeof_bits(operation_desc.B.element) * mnk.n() / 8) * mnk.k() + -+ int64_t(library::sizeof_bits(operation_desc.C.element) * mnk.m() / 8) * mnk.n(); -+ -+ // Set is_beta_zero true if beta is zero -+ bool is_beta_zero = std::all_of(beta.begin(), beta.end(), [](uint8_t i) { return i==0; }); -+ -+ // Output bytes read for the gemm problem for non-zero beta values -+ if (!is_beta_zero) { -+ bytes_ += int64_t(library::sizeof_bits(operation_desc.C.element) * mnk.m() / 8) * mnk.n(); -+ } -+ -+ return bytes_; -+} -+ -+/// Total number of flops computed -+int64_t Conv2dOperationProfiler::Conv2dProblem::flops( -+ library::ConvDescription const &operation_desc) const { -+ -+ 
cutlass::gemm::GemmCoord mnk = eq_gemm_size(operation_desc.conv_kind); -+ -+ int64_t flops_mainloop_ = int64_t(mnk.m()) * mnk.n() * mnk.k() * 2; -+ int64_t flops_epilogue_ = int64_t(mnk.m()) * int64_t(mnk.n()) * 2; -+ -+ // Adjust mainloop flop for dgrad strided -+ if (operation_desc.conv_kind == library::ConvKind::kDgrad) { -+ flops_mainloop_ = flops_mainloop_ / (stride_h * stride_w); -+ } -+ int64_t flops_total_ = flops_mainloop_ + flops_epilogue_; -+ -+ //complex-valued support -+ switch (operation_desc.tile_description.math_instruction.math_operation) { -+ case library::MathOperationID::kMultiplyAddComplex: -+ flops_total_ *=4; -+ break; -+ -+ default: break; -+ } -+ -+ return flops_total_; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Extracts the problem dimensions -+Status Conv2dOperationProfiler::initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::ConvDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (!arg_as_int(problem_.n, "n", problem_space, problem)) { -+ // default value -+ problem_.n = 1; -+ } -+ -+ if (!arg_as_int(problem_.h, "h", problem_space, problem)) { -+ // default value -+ problem_.h = 16; -+ } -+ -+ if (!arg_as_int(problem_.w, "w", problem_space, problem)) { -+ // default value -+ problem_.w = 16; -+ } -+ -+ if (!arg_as_int(problem_.c, "c", problem_space, problem)) { -+ // default value -+ problem_.c = 64; -+ } -+ -+ if (!arg_as_int(problem_.k, "k", problem_space, problem)) { -+ // default value -+ problem_.k = 64; -+ } -+ -+ if (!arg_as_int(problem_.r, "r", problem_space, problem)) { -+ // default value -+ problem_.r = 3; -+ } -+ -+ if (!arg_as_int(problem_.s, "s", problem_space, problem)) { -+ // default value -+ problem_.s = 3; -+ } -+ -+ if (!arg_as_int(problem_.groups, "g", problem_space, problem)) { -+ // default value -+ problem_.groups = 1; -+ } -+ -+ if (!arg_as_int(problem_.pad_h, "pad_h", problem_space, problem)) { -+ // default value -+ problem_.pad_h = 1; -+ } -+ -+ if (!arg_as_int(problem_.pad_w, "pad_w", problem_space, problem)) { -+ // default value -+ problem_.pad_w = 1; -+ } -+ -+ if (!arg_as_int(problem_.stride_h, "stride_h", problem_space, problem)) { -+ // default value -+ problem_.stride_h = 1; -+ } -+ -+ if (!arg_as_int(problem_.stride_w, "stride_w", problem_space, problem)) { -+ // default value -+ problem_.stride_w = 1; -+ } -+ -+ if (!arg_as_int(problem_.dilation_h, "dilation_h", problem_space, problem)) { -+ // default value -+ problem_.dilation_h = 1; -+ } -+ -+ if (!arg_as_int(problem_.dilation_w, "dilation_w", problem_space, problem)) { -+ // default value -+ problem_.dilation_w = 1; -+ } -+ -+ //////////////////////// Convolution output dimensions p and q //////////////////////// -+ // Cutlass convolutions support arbitrary output sizes and not constriant by // -+ // input, filter, padding, striding, dilation sizes. // -+ // cuDNN sets the output dimensions (p, q) using following equations: // -+ // // -+ // output = div_up(input + 2 * pad - ((filter - 1) * dilation + 1) + 1, stride) // -+ // where; div_up(a, b) : (a - 1)/b + 1 // -+ // // -+ // Thus, when output p and q dimensions are unspecified by the user // -+ // cutlass profiler sets p and q which are cuDNN compliant. 
// -+ // // -+ //////////////////////////////////////////////////////////////////////////////////////// -+ // set convolution output p -+ if (!arg_as_int(problem_.p, "p", problem_space, problem)) { -+ // default value (set using cudnn formula for output height, when p is not provided) -+ problem_.p = ( -+ problem_.h + -+ 2 * problem_.pad_h - -+ ((problem_.r - 1) * problem_.dilation_h + 1) -+ ) / (problem_.stride_h) -+ + 1; -+ } -+ -+ // set convolution output q -+ if (!arg_as_int(problem_.q, "q", problem_space, problem)) { -+ // default value (set using cudnn formula for output width, when q is not provided) -+ problem_.q = ( -+ problem_.w + -+ 2 * problem_.pad_w - -+ ((problem_.s - 1) * problem_.dilation_w + 1) -+ ) / (problem_.stride_w) -+ + 1; -+ } -+ ///////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+ if (!arg_as_SplitKModeID(problem_.split_k_mode, "split_k_mode", problem_space, problem)) { -+ // default value -+ problem_.split_k_mode = library::SplitKMode::kSerial; -+ } -+ -+ if (!arg_as_int(problem_.split_k_slices, "split_k_slices", problem_space, problem)) { -+ // default value -+ problem_.split_k_slices = 1; -+ } -+ -+ if (!arg_as_ConvModeID(problem_.conv_mode, "conv_mode", problem_space, problem)) { -+ // default value -+ problem_.conv_mode = library::ConvModeID::kCrossCorrelation; -+ } -+ -+ if (!arg_as_ProviderID(problem_.eq_gemm_provider, "eq_gemm_provider", problem_space, problem)) { -+ // default value -+ problem_.eq_gemm_provider = library::Provider::kNone; -+ } -+ -+ if (!conv_kind_satisfies(operation_desc.conv_kind, "conv_kind", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!iterator_algorithm_satisfies(operation_desc.iterator_algorithm, "iterator_algorithm", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.activation(), "Activation", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.filter(), "Filter", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.output(), "Output", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!arg_as_scalar( -+ problem_.alpha, -+ operation_desc.element_epilogue, -+ "alpha", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(problem_.alpha, operation_desc.element_epilogue, 1)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ if (!arg_as_scalar( -+ problem_.beta, -+ operation_desc.element_epilogue, -+ "beta", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(problem_.beta, operation_desc.element_epilogue, 0)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ // initialize library::Conv2dConfiguration -+ conv_workspace_.configuration.problem_size = conv::Conv2dProblemSize( -+ int(problem_.n), -+ int(problem_.h), -+ int(problem_.w), -+ int(problem_.c), -+ int(problem_.k), -+ int(problem_.r), -+ int(problem_.s), -+ int(problem_.p), -+ int(problem_.q), -+ int(problem_.pad_h), -+ int(problem_.pad_w), -+ int(problem_.stride_h), -+ int(problem_.stride_w), -+ int(problem_.dilation_h), -+ int(problem_.dilation_w), -+ static_cast(static_cast(problem_.conv_mode)), -+ int(problem_.split_k_slices), -+ int(problem_.groups) -+ ); -+ -+ conv_workspace_.configuration.split_k_mode = static_cast(static_cast(problem_.split_k_mode)); -+ -+ conv_workspace_.set_stride_vector( -+ problem_, 
operation_desc.conv_kind, operation_desc.A.layout, -+ operation_desc.B.layout, operation_desc.C.layout); -+ -+ // initialize library::ConvArguments -+ conv_workspace_.arguments.A = nullptr; -+ conv_workspace_.arguments.B = nullptr; -+ conv_workspace_.arguments.C = nullptr; -+ conv_workspace_.arguments.D = nullptr; -+ conv_workspace_.arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.arguments.beta = problem_.beta.data(); -+ conv_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // initialize reduction operation for parallel splitKMode -+ if(conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ if(!initialize_reduction_configuration_(options, report, device_context, operation, problem_space, problem)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ initialize_result_(this->model_result_, options, operation_desc, problem_space); -+ -+ return operation->can_implement(&conv_workspace_.configuration, &conv_workspace_.arguments); -+} -+ -+/// Initializes the performance result -+void Conv2dOperationProfiler::initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::ConvDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.provider = library::Provider::kCUTLASS; -+ result.disposition = Disposition::kNotRun; -+ result.status = Status::kSuccess; -+ result.operation_name = operation_desc.name; -+ -+ result.arguments.resize(problem_space.rank()); -+ -+ set_argument(result, "Activation", problem_space, -+ std::string(library::to_string(operation_desc.activation().element)) -+ + ":" + library::to_string(operation_desc.activation().layout)); -+ -+ set_argument(result, "Filter", problem_space, -+ std::string(library::to_string(operation_desc.filter().element)) -+ + ":" + library::to_string(operation_desc.filter().layout)); -+ -+ set_argument(result, "Output", problem_space, -+ std::string(library::to_string(operation_desc.output().element)) -+ + ":" + library::to_string(operation_desc.output().layout)); -+ -+ set_argument(result, "conv_kind", problem_space, library::to_string(operation_desc.conv_kind)); -+ -+ set_argument(result, "iterator_algorithm", problem_space, std::string(library::to_string(operation_desc.iterator_algorithm))); -+ -+ set_argument(result, "n", problem_space, problem_.n); -+ set_argument(result, "h", problem_space, problem_.h); -+ set_argument(result, "w", problem_space, problem_.w); -+ set_argument(result, "c", problem_space, problem_.c); -+ -+ set_argument(result, "k", problem_space, problem_.k); -+ set_argument(result, "r", problem_space, problem_.r); -+ set_argument(result, "s", problem_space, problem_.s); -+ -+ set_argument(result, "p", problem_space, problem_.p); -+ set_argument(result, "q", problem_space, problem_.q); -+ -+ set_argument(result, "g", problem_space, problem_.groups); -+ -+ set_argument(result, "pad_h", problem_space, problem_.pad_h); -+ set_argument(result, "pad_w", problem_space, problem_.pad_w); -+ -+ set_argument(result, "stride_h", problem_space, problem_.stride_h); -+ set_argument(result, "stride_w", problem_space, problem_.stride_w); -+ -+ set_argument(result, "dilation_h", problem_space, problem_.dilation_h); -+ set_argument(result, "dilation_w", problem_space, problem_.dilation_w); -+ -+ set_argument(result, "split_k_mode", problem_space, -+ std::string(library::to_string(problem_.split_k_mode))); -+ set_argument(result, "split_k_slices", problem_space, problem_.split_k_slices); -+ -+ set_argument(result, "conv_mode", 
problem_space, -+ std::string(library::to_string(problem_.conv_mode))); -+ -+ set_argument(result, "alpha", problem_space, -+ library::lexical_cast(problem_.alpha, operation_desc.element_epilogue)); -+ -+ set_argument(result, "beta", problem_space, -+ library::lexical_cast(problem_.beta, operation_desc.element_epilogue)); -+ -+ set_argument(result, "eq_gemm_provider", problem_space, -+ std::string(library::to_string(problem_.eq_gemm_provider))); -+ -+ OperationProfiler::initialize_result_(result, operation_desc, problem_space); -+ -+ // Bytes of activation, filter, and output tensors -+ int64_t activation_bytes = int64_t(library::sizeof_bits(operation_desc.activation().element) / 8) * -+ conv_workspace_.configuration.problem_size.activation_size(); -+ -+ int64_t filter_bytes = int64_t(library::sizeof_bits(operation_desc.filter().element) / 8) * -+ conv_workspace_.configuration.problem_size.filter_size(); -+ -+ int64_t output_bytes = int64_t(library::sizeof_bits(operation_desc.output().element) / 8) * -+ conv_workspace_.configuration.problem_size.output_size(); -+ -+ // Bytes of activation, filter, and output tensors -+ result.bytes = problem_.bytes(operation_desc); -+ -+ // Theoritical flops required for the computation -+ result.flops = problem_.flops(operation_desc); -+ -+ // Measured runtime -+ result.runtime = 0; -+ -+} -+ -+/// Initialize reduction problem dimenstions and library::Operation -+bool Conv2dOperationProfiler::initialize_reduction_configuration_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::ConvDescription const &conv_desc = -+ static_cast(operation->description()); -+ -+ library::ConvKind const &conv_kind = conv_desc.conv_kind; -+ -+ if (!cast_from_double(problem_.alpha_one, conv_desc.element_epilogue, 1)) { -+ return false; -+ } -+ -+ if (!cast_from_double(problem_.beta_zero, conv_desc.element_epilogue, 0)) { -+ return false; -+ } -+ -+ /// This chooses the appropriate stride element of the row-major C tensor. -+ int const & tensor_c_stride_idx = (conv_kind == library::ConvKind::kWgrad ? 
2 : 0); -+ -+ /// intialize library::ReductionConfiguration -+ conv_workspace_.reduction_configuration.problem_size = problem_.eq_gemm_size(conv_kind).mn(); -+ conv_workspace_.reduction_configuration.partitions = int(problem_.split_k_slices); -+ conv_workspace_.reduction_configuration.partition_stride = problem_.eq_gemm_size(conv_kind).mn().product(); -+ conv_workspace_.reduction_configuration.ldw = -+ conv_workspace_.configuration.stride_c[tensor_c_stride_idx]; -+ conv_workspace_.reduction_configuration.lds = -+ conv_workspace_.configuration.stride_c[tensor_c_stride_idx]; -+ conv_workspace_.reduction_configuration.ldd = -+ conv_workspace_.configuration.stride_c[tensor_c_stride_idx]; -+ -+ // find reduction operation -+ library::ReductionFunctionalKey reduction_key( -+ library::Provider::kCUTLASS, -+ conv_desc.tile_description.math_instruction.element_accumulator, // element workspace -+ conv_desc.tile_description.math_instruction.element_accumulator, // element accumulator -+ conv_desc.C.element, // element output -+ conv_desc.element_epilogue // element compute -+ ); -+ -+#if 0// debug print to check which reduction instance is selected -+ std::cout << reduction_key << "\n"; -+#endif -+ auto reduction_it = Singleton::get().operation_table.reduction_operations.find(reduction_key); -+ -+ if(reduction_it == Singleton::get().operation_table.reduction_operations.end()) { -+ -+ return false; -+ } -+ -+ // initialize reduction operation required for parallel split-k conv2d operator -+ reduction_op_ = reduction_it->second; -+ -+ // reduction operation found and initialized -+ return true; -+} -+ -+ -+/// Initializes workspace -+Status Conv2dOperationProfiler::initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ // initialize conv2d underlying operation to handle parallel reduction -+ library::Operation const* underlying_operation = operation; -+ -+ if(conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ if (!(underlying_operation = library::find_conv_operation_for_parallel_reduction(operation))) { -+ return Status::kErrorNotSupported; -+ } -+ } -+ -+ library::ConvDescription const &operation_desc = -+ static_cast(underlying_operation->description()); -+ -+ // Compute the number of copies of the problem to avoid L2 camping. 
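The workspace sizing that follows keeps multiple copies of the operand tensors so that back-to-back profiling iterations do not simply re-read data already resident in L2; a sketch of the formula with illustrative numbers (the 40 MB cache and 16 MB problem are assumptions, not values from the patch):

#include <cstdint>

// problem_count = 1 + floor(3 * L2 / bytes) when the problem is smaller than
// three times the L2 capacity; otherwise a single copy is enough.
int64_t sketch_workspace_count(int64_t problem_bytes, int64_t l2_cache_bytes) {
  if (problem_bytes < 3 * l2_cache_bytes) {
    return 1 + (3 * l2_cache_bytes) / problem_bytes;
  }
  return 1;
}
// For example, sketch_workspace_count(16 << 20, 40 << 20) == 8, so eight
// copies of the tensors are allocated and rotated through during profiling.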
-+ if (!options.profiling.workspace_count) { -+ int64_t bytes = problem_.bytes(operation_desc); -+ if (bytes < 3 * int64_t(options.device.properties.l2CacheSize)) { -+ conv_workspace_.problem_count = -+ 1 + int((3 * int64_t(options.device.properties.l2CacheSize)) / bytes); -+ } -+ else { -+ conv_workspace_.problem_count = 1; -+ } -+ } -+ else { -+ conv_workspace_.problem_count = options.profiling.workspace_count; -+ } -+ -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ conv_workspace_.A = device_context.allocate_tensor( -+ options, -+ "A", -+ operation_desc.A.element, -+ operation_desc.A.layout, -+ problem_.extent_a(operation_desc.conv_kind), -+ conv_workspace_.configuration.stride_a, -+ conv_workspace_.problem_count -+ ); -+ -+ conv_workspace_.B = device_context.allocate_tensor( -+ options, -+ "B", -+ operation_desc.B.element, -+ operation_desc.B.layout, -+ problem_.extent_b(operation_desc.conv_kind), -+ conv_workspace_.configuration.stride_b, -+ conv_workspace_.problem_count -+ ); -+ -+ if(problem_.groups == problem_.c && problem_.groups == problem_.k){ -+ // Depthwise direct conv kernel needs reorder the filter. -+ conv_workspace_.reordered_B = device_context.allocate_tensor( -+ options, -+ "B", -+ operation_desc.B.element, -+ operation_desc.B.layout, -+ problem_.extent_b(operation_desc.conv_kind), -+ conv_workspace_.configuration.stride_b, -+ conv_workspace_.problem_count -+ ); -+ } -+ -+ conv_workspace_.C = device_context.allocate_tensor( -+ options, -+ "C", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ problem_.extent_c(operation_desc.conv_kind), -+ conv_workspace_.configuration.stride_c, -+ conv_workspace_.problem_count -+ ); -+ -+ conv_workspace_.Computed = device_context.allocate_tensor( -+ "D", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ problem_.extent_c(operation_desc.conv_kind), -+ conv_workspace_.configuration.stride_c, -+ conv_workspace_.problem_count -+ ); -+ -+ conv_workspace_.Reference = device_context.allocate_tensor( -+ "Reference", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ problem_.extent_c(operation_desc.conv_kind), -+ conv_workspace_.configuration.stride_c, -+ conv_workspace_.problem_count -+ ); -+ } -+ -+ // -+ // Initialize the CUTLASS operation -+ // -+ Status status = Status::kSuccess; -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ uint64_t workspace_size = underlying_operation->get_host_workspace_size(&conv_workspace_.configuration); -+ conv_workspace_.host_workspace.resize(workspace_size, 0); -+ -+ workspace_size = underlying_operation->get_device_workspace_size(&conv_workspace_.configuration); -+ conv_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); -+ -+ status = underlying_operation->initialize( -+ &conv_workspace_.configuration, -+ conv_workspace_.host_workspace.data(), -+ conv_workspace_.device_workspace.data()); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ workspace_size = reduction_op_->get_host_workspace_size(&conv_workspace_.reduction_configuration); -+ conv_workspace_.reduction_host_workspace.resize(workspace_size, 0); -+ -+ status = reduction_op_->initialize( -+ &conv_workspace_.reduction_configuration, -+ conv_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ } -+ } -+ -+ // -+ 
// If CUTLASS is enabled, generate a result for it -+ // -+ results_.push_back(model_result_); -+ results_.back().provider = library::Provider::kCUTLASS; -+ results_.back().op_kind = library::OperationKind::kConv2d; -+ results_.back().disposition = Disposition::kNotRun; -+ -+ for(auto provider : verification_providers_) { -+ results_.back().verification_map[provider] = Disposition::kNotRun; -+ } -+ } -+ -+ return status; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool Conv2dOperationProfiler::verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ return true; -+ } -+ -+ if (options.execution_mode == ExecutionMode::kDryRun) { -+ return true; -+ } -+ -+ cudaError_t result; -+ -+ // Initialize structure containing Conv2d arguments -+ conv_workspace_.arguments.A = conv_workspace_.A->data(); -+ conv_workspace_.arguments.B = conv_workspace_.B->data(); -+ conv_workspace_.arguments.C = conv_workspace_.C->data(); -+ conv_workspace_.arguments.D = conv_workspace_.Computed->data(); -+ conv_workspace_.arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.arguments.beta = problem_.beta.data(); -+ conv_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ if (conv_workspace_.reordered_B != nullptr){ -+ conv_workspace_.arguments.reordered_B = conv_workspace_.reordered_B->data(); -+ }else{ -+ conv_workspace_.arguments.reordered_B = nullptr; -+ } -+ -+ conv_workspace_.Computed->copy_from_device(conv_workspace_.C->data()); -+ -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ // update library::ConvArguments for parallel split-k reduction -+ conv_workspace_.arguments.D = conv_workspace_.device_workspace.data(); -+ conv_workspace_.arguments.alpha = problem_.alpha_one.data(); -+ conv_workspace_.arguments.beta = problem_.beta_zero.data(); -+ -+ /// intialize library::ReductionArguments -+ conv_workspace_.reduction_arguments.workspace = conv_workspace_.device_workspace.data(); -+ conv_workspace_.reduction_arguments.source = conv_workspace_.C->data(); -+ conv_workspace_.reduction_arguments.destination = conv_workspace_.Computed->data(); -+ conv_workspace_.reduction_arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.reduction_arguments.beta = problem_.beta.data(); -+ conv_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ } -+ -+ // -+ // Run the CUTLASS operation -+ // -+ // initialize conv2d underlying operation to handle parallel reduction -+ library::Operation const* underlying_operation = operation; -+ -+ if(conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ if (!(underlying_operation = library::find_conv_operation_for_parallel_reduction(operation))) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ } -+ -+#if 0 -+ std::cout << "profiling : " << std::endl -+ << "conv2d : " << operation->description().name << std::endl -+ << "underlying conv2d : " << underlying_operation->description().name << std::endl -+ << "reduction : " << reduction_op_->description().name << std::endl; -+#endif -+ -+ // run cutlass conv2d operation -+ results_.back().status = underlying_operation->run( -+ 
&conv_workspace_.arguments, -+ conv_workspace_.host_workspace.data(), -+ conv_workspace_.device_workspace.data()); -+ -+ if (results_.back().status != Status::kSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ // Run parallel reduction kernel for parallel split_k_mode -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ -+ results_.back().status = reduction_op_->run( -+ &conv_workspace_.reduction_arguments, -+ conv_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ -+ if (results_.back().status != Status::kSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ } -+ -+ // Synchronize before running device reference -+ result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ // CUTLASS op ran the but not yet verified against any verification provider -+ results_.back().disposition = Disposition::kNotVerified; -+ -+ // -+ // Run verification providers -+ // -+ -+ if (options.verification.enabled) { -+ -+#if CUTLASS_ENABLE_CUDNN -+ // Run verification cudnn reference -+ if (options.verification.provider_enabled(library::Provider::kCUDNN)) { -+ -+ // Guard against unsupported cases -+ auto const & conv_desc = static_cast(operation->description()); -+ -+ Status status = cudnn_satisfies(conv_desc, conv_workspace_.configuration); -+ -+ // Initialize reference data to the source data -+ conv_workspace_.Reference->copy_from_device(conv_workspace_.C->data()); -+ -+ if (status == Status::kSuccess) { -+ // call cudnn verification if supported -+ verify_with_cudnn_( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ else if (status == Status::kErrorInvalidProblem) { -+ results_.back().verification_map[library::Provider::kCUDNN] = Disposition::kInvalidProblem; -+ } -+ -+ else { -+ // set verification map for cudnn to not supported -+ results_.back().verification_map[library::Provider::kCUDNN] = Disposition::kNotSupported; -+ } -+ } -+#endif // #if CUTLASS_ENABLE_CUDNN -+ -+ // Run verification device reference -+ if (options.verification.provider_enabled(library::Provider::kReferenceDevice)) { -+ -+ // Restore reference data back to initial source data -+ conv_workspace_.Reference->copy_from_device(conv_workspace_.C->data()); -+ -+ verify_with_device_reference_( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ // Run verification host reference -+ if (options.verification.provider_enabled(library::Provider::kReferenceHost)) { -+ -+ // Restore reference data back to initial source data -+ conv_workspace_.Reference->copy_from_device(conv_workspace_.C->data()); -+ -+ verify_with_host_reference_( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ // Update disposition to worst case verification outcome among all -+ // verification providers which are supported -+ bool is_any_verification_run_passed = false; -+ for(auto &m : results_.back().verification_map) { -+ if(m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { -+ results_.back().disposition = m.second; -+ return true; -+ } -+ if(!is_any_verification_run_passed && m.second == Disposition::kPassed) { -+ is_any_verification_run_passed = true; -+ } -+ } -+ -+ if(is_any_verification_run_passed) { -+ results_.back().disposition = Disposition::kPassed; -+ } -+ } -+ -+ // Return true 
means continue profiling -+ return true; -+} -+ -+ -+/// Verifies CUTLASS against host reference -+bool Conv2dOperationProfiler::verify_with_host_reference_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ Status status; -+ -+ // -+ // Find host reference operation using conv2d functional description key -+ // -+ library::OperationDescription const &desc = operation->description(); -+ -+ auto &conv_desc = static_cast(desc); -+ -+ library::ConvFunctionalKey conv2d_key( -+ library::Provider::kReferenceHost, -+ conv_desc.conv_kind, -+ conv_desc.A.element, -+ conv_desc.A.layout, -+ conv_desc.B.element, -+ conv_desc.B.layout, -+ conv_desc.C.element, -+ conv_desc.C.layout, -+ conv_desc.tile_description.math_instruction.element_accumulator, -+ conv_desc.element_epilogue); -+ -+#if 0 // debug print to check which host refererence instance is selected -+ std::cout << conv2d_key << "\n"; -+#endif -+ -+ auto operators_it = Singleton::get().operation_table.conv2d_operations.find(conv2d_key); -+ -+ if(operators_it == Singleton::get().operation_table.conv2d_operations.end()) { -+ -+ results_.back().verification_map[library::Provider::kReferenceHost] = Disposition::kNotRun; -+ return true; -+ } -+ -+ // conv2d host reference minimum cc is 0 (CPU) and no iterator algorithm -+ library::ConvPreferenceKey preference_key(0, library::IteratorAlgorithmID::kNone); -+ auto cc_it = operators_it->second.find(preference_key); -+ -+ if(cc_it == operators_it->second.end()) { -+ results_.back().verification_map[library::Provider::kReferenceHost] = Disposition::kNotRun; -+ return true; -+ } -+ -+ // host refernce has only one instances in Conv2dOperationVectorMap -+ library::Operation const *reference_op = cc_it->second[0]; -+ -+ // -+ // Copy input tensors A, B, and C from device to host buffers -+ // -+ conv_workspace_.host_tensor_a.resize(conv_workspace_.A->bytes()); -+ conv_workspace_.host_tensor_b.resize(conv_workspace_.B->bytes()); -+ conv_workspace_.host_tensor_c.resize(conv_workspace_.C->bytes()); -+ -+ conv_workspace_.A->copy_to_host(conv_workspace_.host_tensor_a.data()); -+ conv_workspace_.B->copy_to_host(conv_workspace_.host_tensor_b.data()); -+ conv_workspace_.C->copy_to_host(conv_workspace_.host_tensor_c.data()); -+ -+ // -+ // Initialize structure containing Conv2d arguments -+ // -+ conv_workspace_.arguments.A = conv_workspace_.host_tensor_a.data(); -+ conv_workspace_.arguments.B = conv_workspace_.host_tensor_b.data(); -+ conv_workspace_.arguments.C = conv_workspace_.host_tensor_c.data(); -+ conv_workspace_.arguments.D = conv_workspace_.host_tensor_c.data(); -+ -+ conv_workspace_.arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.arguments.beta = problem_.beta.data(); -+ conv_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // -+ // Intialize host reference operation -+ // -+ std::vector host_workspace_reference_op; -+ -+ uint64_t workspace_size = reference_op->get_host_workspace_size(&conv_workspace_.configuration); -+ host_workspace_reference_op.resize(workspace_size, 0); -+ -+ reference_op->initialize( -+ &conv_workspace_.configuration, -+ host_workspace_reference_op.data()); -+ -+ // -+ // Run host reference operation -+ // -+ status = reference_op->run( -+ &conv_workspace_.arguments, -+ host_workspace_reference_op.data()); -+ -+ // Handle errors -+ if (status != Status::kSuccess) { -+ 
results_.back().verification_map[library::Provider::kReferenceHost] = Disposition::kNotVerified; -+ return true; -+ } -+ -+ // -+ // Copy host reference output to device memory for equality check on device -+ // -+ conv_workspace_.Reference->copy_from_host(conv_workspace_.arguments.D); -+ -+ // -+ // Verify results -+ // -+ results_.back().verification_map[library::Provider::kReferenceHost] = compare_tensors( -+ options, -+ *conv_workspace_.Computed, -+ *conv_workspace_.Reference, -+ conv_workspace_.Computed->batch_stride() -+ ); -+ -+ // Save workspace if incorrect -+ if (options.verification.save_workspace == SaveWorkspace::kIncorrect && -+ results_.back().verification_map[library::Provider::kReferenceHost] == Disposition::kIncorrect) { -+ -+ save_workspace( -+ device_context, -+ options, -+ static_cast(operation->description()), -+ library::Provider::kCUTLASS, -+ library::Provider::kReferenceHost); -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+ -+/// Verifies CUTLASS against host reference -+bool Conv2dOperationProfiler::verify_with_device_reference_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ Status status; -+ -+ // -+ // Find device reference operation using conv2d functional description key -+ // -+ library::OperationDescription const &desc = operation->description(); -+ -+ auto &conv_desc = static_cast(desc); -+ -+ library::ConvFunctionalKey conv2d_key( -+ library::Provider::kReferenceDevice, -+ conv_desc.conv_kind, -+ conv_desc.A.element, -+ conv_desc.A.layout, -+ conv_desc.B.element, -+ conv_desc.B.layout, -+ conv_desc.C.element, -+ conv_desc.C.layout, -+ conv_desc.tile_description.math_instruction.element_accumulator, -+ conv_desc.element_epilogue); -+ -+ auto operators_it = Singleton::get().operation_table.conv2d_operations.find(conv2d_key); -+ -+ if(operators_it == Singleton::get().operation_table.conv2d_operations.end()) { -+ -+ results_.back().verification_map[library::Provider::kReferenceDevice] = Disposition::kNotRun; -+ -+ return true; -+ } -+ -+ // conv2d device reference minimum cc is 50 and no iterator algorithm -+ library::ConvPreferenceKey preference_key(50, library::IteratorAlgorithmID::kNone); -+ auto cc_it = operators_it->second.find(preference_key); -+ -+ if(cc_it == operators_it->second.end()) { -+ results_.back().verification_map[library::Provider::kReferenceDevice] = Disposition::kNotRun; -+ -+ return true; -+ } -+ -+ // device refernce has only one instances in Conv2dOperationVectorMap -+ library::Operation const *reference_op = cc_it->second[0]; -+ -+ // -+ // Intialize device reference operation -+ // -+ std::vector host_workspace_reference_op; -+ -+ uint64_t workspace_size = reference_op->get_host_workspace_size(&conv_workspace_.configuration); -+ host_workspace_reference_op.resize(workspace_size, 0); -+ -+ reference_op->initialize( -+ &conv_workspace_.configuration, -+ host_workspace_reference_op.data()); -+ -+ // Initialize structure containing Conv2d arguments -+ conv_workspace_.arguments.A = conv_workspace_.A->data(); -+ conv_workspace_.arguments.B = conv_workspace_.B->data(); -+ conv_workspace_.arguments.C = conv_workspace_.C->data(); -+ conv_workspace_.arguments.D = conv_workspace_.Reference->data(); -+ conv_workspace_.arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.arguments.beta = problem_.beta.data(); -+ 
conv_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // -+ // Run device reference operation -+ // -+ status = reference_op->run( -+ &conv_workspace_.arguments, -+ host_workspace_reference_op.data()); -+ -+ -+ // Handle errors -+ if (status != Status::kSuccess) { -+ results_.back().verification_map[library::Provider::kReferenceDevice] = Disposition::kNotVerified; -+ return true; -+ } -+ -+ // -+ // Verify results -+ // -+ results_.back().verification_map[library::Provider::kReferenceDevice] = compare_tensors( -+ options, -+ *conv_workspace_.Computed, -+ *conv_workspace_.Reference, -+ conv_workspace_.Computed->batch_stride() -+ ); -+ -+ // Save workspace if incorrect -+ if (options.verification.save_workspace == SaveWorkspace::kIncorrect && -+ results_.back().verification_map[library::Provider::kReferenceDevice] == Disposition::kIncorrect) { -+ -+ save_workspace( -+ device_context, -+ options, -+ static_cast(operation->description()), -+ library::Provider::kCUTLASS, -+ library::Provider::kReferenceDevice); -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+/// Measures performance results -+bool Conv2dOperationProfiler::profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ // Initialize structure containing Conv2d arguments -+ conv_workspace_.arguments.A = conv_workspace_.A->data(); -+ conv_workspace_.arguments.B = conv_workspace_.B->data(); -+ conv_workspace_.arguments.C = conv_workspace_.C->data(); -+ conv_workspace_.arguments.D = conv_workspace_.Computed->data(); -+ conv_workspace_.arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.arguments.beta = problem_.beta.data(); -+ conv_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ // update library::ConvArguments for parallel split-k reduction -+ conv_workspace_.arguments.D = conv_workspace_.device_workspace.data(); -+ conv_workspace_.arguments.alpha = problem_.alpha_one.data(); -+ conv_workspace_.arguments.beta = problem_.beta_zero.data(); -+ -+ /// intialize library::ReductionArguments -+ conv_workspace_.reduction_arguments.workspace = conv_workspace_.device_workspace.data(); -+ conv_workspace_.reduction_arguments.source = conv_workspace_.C->data(); -+ conv_workspace_.reduction_arguments.destination = conv_workspace_.Computed->data(); -+ conv_workspace_.reduction_arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.reduction_arguments.beta = problem_.beta.data(); -+ conv_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ } -+ -+ results_.back().status = profile_cutlass_( -+ results_.back().runtime, -+ options, -+ operation, -+ &conv_workspace_.arguments, -+ conv_workspace_.host_workspace.data(), -+ conv_workspace_.device_workspace.data() -+ ); -+ } -+ return true; -+ -+} -+ -+/// Method to profile a CUTLASS Operation -+Status Conv2dOperationProfiler::profile_cutlass_( -+ double &runtime, -+ Options const &options, -+ library::Operation const *operation, -+ void *arguments, -+ void *host_workspace, -+ void *device_workspace) { -+ -+ GpuTimer timer; -+ -+ // initialize conv2d underlying operation to handle parallel reduction -+ library::Operation const* underlying_operation = 
operation; -+ -+ library::ConvArguments *conv_arguments = static_cast(arguments); -+ -+ if(conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ if (!(underlying_operation = library::find_conv_operation_for_parallel_reduction(operation))) { -+ return Status::kErrorNotSupported; -+ } -+ } -+ -+ // -+ // Optional sleep to limit power consumption and thermals -+ // -+ -+ sleep(options.profiling.sleep_duration); -+ -+ // -+ // Warmup loop -+ // -+ -+ Status status; -+ -+ for (int iteration = 0; iteration < options.profiling.warmup_iterations; ++iteration) { -+ -+ // Setup rotating workspace -+ int workspace_idx = options.profiling.warmup_iterations + iteration; -+ int problem_idx = (workspace_idx % conv_workspace_.problem_count); -+ -+ conv_arguments->A = conv_workspace_.A->batch_data(problem_idx); -+ conv_arguments->B = conv_workspace_.B->batch_data(problem_idx); -+ conv_arguments->C = conv_workspace_.C->batch_data(problem_idx); -+ conv_arguments->D = conv_workspace_.Computed->batch_data(problem_idx); -+ -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ // update library::ConvArguments for parallel split-k reduction -+ conv_arguments->D = conv_workspace_.device_workspace.data(); -+ -+ /// intialize library::ReductionArguments -+ conv_workspace_.reduction_arguments.workspace = conv_workspace_.device_workspace.data(); -+ conv_workspace_.reduction_arguments.source = conv_workspace_.C->batch_data(problem_idx); -+ conv_workspace_.reduction_arguments.destination = conv_workspace_.Computed->batch_data(problem_idx); -+ } -+ -+ // Run underlying conv2d operation -+ status = underlying_operation->run( -+ arguments, -+ host_workspace, -+ device_workspace); -+ -+ // Run parallel reduction kernel for parallel split_k_mode -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ -+ status = reduction_op_->run( -+ &conv_workspace_.reduction_arguments, -+ conv_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ } -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ } -+ -+ // -+ // Initialize GPU timer -+ // -+ -+ timer.start(); -+ -+ // -+ // Profiling loop -+ // -+ -+ int Iterations = options.profiling.iterations; -+ -+ int iteration = 0; -+ for (; iteration < Iterations; ++iteration) { -+ -+ // Setup rotating workspace -+ int problem_idx = (iteration % conv_workspace_.problem_count); -+ -+ conv_arguments->A = conv_workspace_.A->batch_data(problem_idx); -+ conv_arguments->B = conv_workspace_.B->batch_data(problem_idx); -+ conv_arguments->C = conv_workspace_.C->batch_data(problem_idx); -+ conv_arguments->D = conv_workspace_.Computed->batch_data(problem_idx); -+ -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ // update library::ConvArguments for parallel split-k reduction -+ conv_arguments->D = conv_workspace_.device_workspace.data(); -+ -+ /// intialize library::ReductionArguments -+ conv_workspace_.reduction_arguments.workspace = conv_workspace_.device_workspace.data(); -+ conv_workspace_.reduction_arguments.source = conv_workspace_.C->batch_data(problem_idx); -+ conv_workspace_.reduction_arguments.destination = conv_workspace_.Computed->batch_data(problem_idx); -+ } -+ -+ // Run underlying conv2d operation -+ status = underlying_operation->run( -+ arguments, -+ host_workspace, -+ device_workspace); -+ -+ // Run parallel reduction kernel for parallel split_k_mode -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ -+ 
status = reduction_op_->run( -+ &conv_workspace_.reduction_arguments, -+ conv_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ } -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ } -+ -+ // -+ // Wait for completion -+ // -+ -+ timer.stop_and_wait(); -+ -+ // -+ // Update performance result -+ // -+ -+ runtime = timer.duration(iteration); -+ -+ return status; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#if CUTLASS_ENABLE_CUDNN -+ -+/// Verifies CUTLASS against cudnn reference -+bool Conv2dOperationProfiler::verify_with_cudnn_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ auto &conv_desc = static_cast(operation->description()); -+ -+ // -+ // Construct cudnn operators -+ // -+ -+ CudnnCreate handle; -+ cudnnStatus_t status = handle.get_cudnn_create_status(); -+ -+ if (status != CUDNN_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUDNN] = get_cutlass_disposition(status); -+ return true; -+ } -+ -+ // -+ // Initialize state -+ // -+ -+ // Initialize structure containing Conv2d arguments -+ conv_workspace_.arguments.A = conv_workspace_.A->data(); -+ conv_workspace_.arguments.B = conv_workspace_.B->data(); -+ conv_workspace_.arguments.D = conv_workspace_.Reference->data(); -+ conv_workspace_.arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.arguments.beta = problem_.beta.data(); -+ conv_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // cuDNN does not support four tensor arguments, so we copy the tensor C data into -+ // tensor D. -+ conv_workspace_.Reference->copy_from_device(conv_workspace_.C->data()); -+ conv_workspace_.arguments.C = conv_workspace_.arguments.D; -+ -+ try { -+ -+ // -+ // Construct dispatcher to cudnn operator -+ // -+ -+ detail::cudnnConvDispatcher conv_op( -+ conv_desc, -+ conv_workspace_.configuration, -+ conv_workspace_.arguments, -+ handle -+ ); -+ -+ if (conv_op.status != Status::kSuccess) { -+ if (conv_op.status == Status::kErrorNotSupported) { -+ results_.back().verification_map[library::Provider::kCUDNN] = Disposition::kNotSupported; -+ -+ } else { -+ results_.back().verification_map[library::Provider::kCUDNN] = Disposition::kFailed; -+ } -+ return true; -+ } -+ -+ -+ status = conv_op(handle); -+ -+ // Handle errors -+ if (status != CUDNN_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUDNN] = get_cutlass_disposition(status); -+ return true; -+ } -+ -+ // -+ // Verify results -+ // -+ -+ results_.back().verification_map[library::Provider::kCUDNN] = compare_tensors( -+ options, -+ *conv_workspace_.Computed, -+ *conv_workspace_.Reference, -+ conv_workspace_.Computed->batch_stride() -+ ); -+ -+ // Save workspace if incorrect -+ if (options.verification.save_workspace == SaveWorkspace::kIncorrect && -+ results_.back().verification_map[library::Provider::kCUDNN] == Disposition::kIncorrect) { -+ -+ save_workspace( -+ device_context, -+ options, -+ conv_desc, -+ library::Provider::kCUTLASS, -+ library::Provider::kCUDNN); -+ } -+ } -+ catch (...) 
{ -+ results_.back().verification_map[library::Provider::kCUDNN] = Disposition::kFailed; -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+#endif // #if CUTLASS_ENABLE_CUDNN -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/conv2d_operation_profiler.h b/3rdparty/cutlass/tools/profiler/src/conv2d_operation_profiler.h -new file mode 100644 -index 0000000..f432c7e ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/conv2d_operation_profiler.h -@@ -0,0 +1,493 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines profiling functionality for convolution -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include -+ -+// CUTLASS Library includes -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+#include "cutlass/library/handle.h" -+#include "cutlass/library/manifest.h" -+#include "cutlass/library/singleton.h" -+ -+// Profiler includes -+#include "options.h" -+#include "device_context.h" -+#include "operation_profiler.h" -+#include "performance_result.h" -+#include "problem_space.h" -+#include "reduction_operation_profiler.h" -+#if CUTLASS_ENABLE_CUDNN -+#include "cudnn_helpers.h" -+#endif //#if CUTLASS_ENABLE_CUDNN -+#include "debug.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Abstract base class for each math function -+class Conv2dOperationProfiler : public OperationProfiler { -+public: -+ -+ /// Problem structure obtained from problem space -+ struct Conv2dProblem { -+ -+ int64_t n, h, w, c, p, q, k, r, s; -+ int64_t groups; -+ int64_t pad_h, pad_w; -+ int64_t stride_h, stride_w; -+ int64_t dilation_h, dilation_w; -+ -+ std::vector alpha; -+ std::vector beta; -+ -+ library::SplitKMode split_k_mode; -+ int64_t split_k_slices; -+ -+ library::ConvModeID conv_mode; -+ -+ library::Provider eq_gemm_provider; -+ -+ // convolution with parallel interleaved reduction -+ // convolution epilogue (alpha, beta) = (1.0, 0.0) -+ // reduction epilogue (alpha, beta) = (Conv2dProblem::alpha, Conv2dProblem::beta) -+ std::vector alpha_one; -+ std::vector beta_zero; -+ -+ // -+ // Methods -+ // -+ -+ /// Total number of bytes loaded -+ int64_t bytes(library::ConvDescription const &operation_desc) const; -+ -+ /// Total number of flops computed -+ int64_t flops(library::ConvDescription const &operation_desc) const; -+ -+ void set_default_output_size() { -+ p = ((h + pad_h - r * dilation_h) / stride_h) + 1; -+ q = ((w + pad_w - s * dilation_w) / stride_w) + 1; -+ } -+ -+ // Returns equivalent gemm problem size for convolution -+ cutlass::gemm::GemmCoord eq_gemm_size(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return cutlass::gemm::GemmCoord(int(n * p * q), int(k), int(r * s * c / groups)); -+ case library::ConvKind::kDgrad: return cutlass::gemm::GemmCoord(int(n * h * w), int(c), int(k * r * s)); -+ case library::ConvKind::kWgrad: return cutlass::gemm::GemmCoord(int(k), int(r * s * c), int(n * p * q)); -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns extent for tensor A -+ std::vector extent_a(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return {int(n), int(h), int(w), int(c)}; -+ case library::ConvKind::kDgrad: return {int(n), int(p), int(q), int(k)}; -+ case library::ConvKind::kWgrad: return {int(n), int(p), int(q), int(k)}; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns extent for tensor B -+ std::vector extent_b(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return {int(k), int(r), int(s), int(c / groups)}; -+ 
case library::ConvKind::kDgrad: return {int(k), int(r), int(s), int(c)}; -+ case library::ConvKind::kWgrad: return {int(n), int(h), int(w), int(c)}; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns extent for tensor C -+ std::vector extent_c(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return {int(n), int(p), int(q), int(k)}; -+ case library::ConvKind::kDgrad: return {int(n), int(h), int(w), int(c)}; -+ case library::ConvKind::kWgrad: return {int(k), int(r), int(s), int(c)}; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns layout for equivalent gemm matrix A -+ library::LayoutTypeID eq_gemm_layout_a(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return library::LayoutTypeID::kRowMajor; // TN Gemm -+ case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm -+ case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; // NT Gemm -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns layout for equivalent gemm matrix B -+ library::LayoutTypeID eq_gemm_layout_b(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return library::LayoutTypeID::kColumnMajor; // TN Gemm -+ case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm -+ case library::ConvKind::kWgrad: return library::LayoutTypeID::kRowMajor; // NT Gemm -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns layout for equivalent gemm matrix C -+ library::LayoutTypeID eq_gemm_layout_c(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ // Gemm operator assumes column-major output -+ case library::ConvKind::kFprop: -+ case library::ConvKind::kDgrad: -+ case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns leading dimenstion for equivalent gemm matrix A -+ int64_t eq_gemm_lda(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); -+ case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).k(); -+ case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns leading dimenstion for equivalent gemm matrix B -+ int64_t eq_gemm_ldb(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); -+ case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).n(); -+ case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).n(); -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns leading dimenstion for equivalent gemm matrix C -+ int64_t eq_gemm_ldc(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: -+ case library::ConvKind::kDgrad: -+ case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ }; -+ 
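(For reference, the Conv2dProblem helpers above reduce a convolution to an implied GEMM and derive default output extents. The following is a minimal, standalone sketch of the fprop mapping, using the cuDNN-style output formula quoted later in this patch; all names and values are illustrative and are not part of the deleted file.)

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical fprop problem (NHWC activations, KRSC filters).
  int64_t n = 32, h = 14, w = 14, c = 8;
  int64_t k = 64, r = 3, s = 3, groups = 1;
  int64_t pad_h = 1, pad_w = 1;
  int64_t stride_h = 1, stride_w = 1;
  int64_t dilation_h = 1, dilation_w = 1;

  // cuDNN-compliant default output extents (the formula the profiler applies
  // when p/q are not specified on the command line).
  int64_t p = (h + 2 * pad_h - ((r - 1) * dilation_h + 1)) / stride_h + 1;
  int64_t q = (w + 2 * pad_w - ((s - 1) * dilation_w + 1)) / stride_w + 1;

  // Implied GEMM for fprop, mirroring Conv2dProblem::eq_gemm_size:
  //   M = N*P*Q, N = K, K = R*S*C/groups.
  int64_t gemm_m = n * p * q;
  int64_t gemm_n = k;
  int64_t gemm_k = r * s * c / groups;

  std::printf("fprop implied GEMM: %lld x %lld x %lld\n",
              static_cast<long long>(gemm_m),
              static_cast<long long>(gemm_n),
              static_cast<long long>(gemm_k));
  return 0;
}

(The bytes() and flops() estimates reported by the profiler are computed on this implied GEMM shape.)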
-+ /// Workspace used -+ struct Conv2dWorkspace { -+ -+ /// Conv device allocations -+ DeviceAllocation *A; -+ DeviceAllocation *B; -+ DeviceAllocation *reordered_B; -+ DeviceAllocation *C; -+ DeviceAllocation *Computed; -+ DeviceAllocation *Reference; -+ -+ /// Library configuration and arguments for convolution operator -+ library::Conv2dConfiguration configuration; -+ library::ConvArguments arguments; -+ -+ /// Number of copies of the problem workspace which are visited sequentially during -+ /// profiling to avoid camping in the last level cache. -+ int problem_count; -+ -+ /// Buffer used for the cutlass conv2d operations' host workspace -+ std::vector host_workspace; -+ -+ /// Buffer used for the cutlass operations' device workspace -+ DeviceAllocation device_workspace; -+ -+ /// Library configuration and arguments for reduction operator -+ library::ReductionConfiguration reduction_configuration; -+ library::ReductionArguments reduction_arguments; -+ -+ /// Buffer used for the cutlass reduction operations' host workspace -+ std::vector reduction_host_workspace; -+ -+ /// Host data buffers for host reference operation -+ /// host buffer for tensor -+ std::vector host_tensor_a; -+ -+ /// host buffer for tensor b -+ std::vector host_tensor_b; -+ -+ /// host buffer for tensor c -+ std::vector host_tensor_c; -+ -+ // -+ // Methods -+ // -+ -+ Conv2dWorkspace() -+ : A(nullptr), -+ B(nullptr), -+ reordered_B(nullptr), -+ C(nullptr), -+ Computed(nullptr), -+ Reference(nullptr) {} -+ -+ // Set stride vector for tensor activations, filters, output -+ void set_stride_vector(Conv2dProblem const &problem, -+ library::ConvKind const &conv_kind, -+ library::LayoutTypeID const &layout_a, -+ library::LayoutTypeID const &layout_b, -+ library::LayoutTypeID const &layout_c) { -+ std::vector stride_activations; -+ std::vector stride_filters; -+ std::vector stride_output; -+ -+ // Strides for interleaved fprop -+ if (conv_kind == library::ConvKind::kFprop && -+ ((layout_a == library::LayoutTypeID::kTensorNC32HW32 && -+ layout_b == library::LayoutTypeID::kTensorC32RSK32 && -+ layout_c == library::LayoutTypeID::kTensorNC32HW32) || -+ (layout_a == library::LayoutTypeID::kTensorNC64HW64 && -+ layout_b == library::LayoutTypeID::kTensorC64RSK64 && -+ layout_c == library::LayoutTypeID::kTensorNC64HW64))) { -+ int interleave = -+ (layout_a == library::LayoutTypeID::kTensorNC32HW32) ? 
32 : 64; -+ -+ stride_activations.push_back(int(problem.w) * interleave); -+ stride_activations.push_back(int(problem.w) * int(problem.h) * -+ interleave); -+ stride_activations.push_back(int(problem.h) * int(problem.w) * -+ int(problem.c)); -+ -+ stride_filters.push_back(int(problem.k) * interleave); -+ stride_filters.push_back(int(problem.k) * int(problem.s) * interleave); -+ stride_filters.push_back(int(problem.k) * int(problem.s) * -+ int(problem.r) * interleave); -+ -+ stride_output.push_back(int(problem.q) * interleave); -+ stride_output.push_back(int(problem.q) * int(problem.p) * interleave); -+ stride_output.push_back(int(problem.q) * int(problem.p) * -+ int(problem.k)); -+ } else { -+ // Strides for the rest cases -+ stride_activations.push_back(int(problem.c)); -+ stride_activations.push_back(int(problem.w) * int(problem.c)); -+ stride_activations.push_back(int(problem.h) * int(problem.w) * -+ int(problem.c)); -+ -+ stride_filters.push_back(int(problem.c / problem.groups)); -+ stride_filters.push_back(int(problem.s) * int(problem.c / problem.groups)); -+ stride_filters.push_back(int(problem.r) * int(problem.s) * -+ int(problem.c / problem.groups)); -+ -+ stride_output.push_back(int(problem.k)); -+ stride_output.push_back(int(problem.q) * int(problem.k)); -+ stride_output.push_back(int(problem.q) * int(problem.p) * -+ int(problem.k)); -+ } -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: -+ configuration.stride_a = stride_activations; -+ configuration.stride_b = stride_filters; -+ configuration.stride_c = stride_output; -+ -+ break; -+ case library::ConvKind::kDgrad: -+ configuration.stride_a = stride_output; -+ configuration.stride_b = stride_filters; -+ configuration.stride_c = stride_activations; -+ -+ break; -+ case library::ConvKind::kWgrad: -+ configuration.stride_a = stride_output; -+ configuration.stride_b = stride_activations; -+ configuration.stride_c = stride_filters; -+ -+ break; -+ default: -+ throw std::runtime_error( -+ "Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ }; -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ /// CONV problem obtained from problem space -+ Conv2dProblem problem_; -+ -+ /// Device memory allocations -+ Conv2dWorkspace conv_workspace_; -+ -+ /// CUTLASS parallel reduction operation to follow this* conv2d operation -+ library::Operation const *reduction_op_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ Conv2dOperationProfiler(Options const &options); -+ -+ /// Destructor -+ virtual ~Conv2dOperationProfiler(); -+ -+ /// Prints usage statement for the math function -+ virtual void print_usage(std::ostream &out) const; -+ -+ /// Prints examples -+ virtual void print_examples(std::ostream &out) const; -+ -+ /// Extracts the problem dimensions -+ virtual Status initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes workspace -+ virtual Status initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against references -+ virtual bool verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const 
&problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Measures performance results -+ virtual bool profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+protected: -+ /// Method to profile an initialized CUTLASS operation -+ virtual Status profile_cutlass_( -+ double &runtime, -+ Options const &options, -+ library::Operation const *operation, -+ void *arguments, -+ void *host_workspace, -+ void *device_workspace); -+ -+ -+ /// Initialize reduction problem dimenstions and library::Operation -+ bool initialize_reduction_configuration_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes the performance result -+ void initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::ConvDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ -+ /// Verifies CUTLASS against host reference -+ bool verify_with_host_reference_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against device reference -+ bool verify_with_device_reference_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+#if CUTLASS_ENABLE_CUDNN -+ -+ /// Verifies CUTLASS against cudnn reference -+ bool verify_with_cudnn_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+#endif //#if CUTLASS_ENABLE_CUDNN -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/conv3d_operation_profiler.cu b/3rdparty/cutlass/tools/profiler/src/conv3d_operation_profiler.cu -new file mode 100644 -index 0000000..34fee85 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/conv3d_operation_profiler.cu -@@ -0,0 +1,1351 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Convolution 3D profiling -+ -+*/ -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/core_io.h" -+ -+#include "conv3d_operation_profiler.h" -+#include "gpu_timer.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+using namespace cutlass::library; -+ -+namespace cutlass { -+namespace profiler { -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Ctor -+Conv3dOperationProfiler::Conv3dOperationProfiler(Options const &options): -+ OperationProfiler( -+ options, -+ library::OperationKind::kConv3d, -+ { -+ {ArgumentTypeID::kEnumerated, {"conv_kind"}, "Convolutional operator (fprop, dgrad, wgrad)"}, -+ {ArgumentTypeID::kInteger, {"n", "input_n"}, "Input N dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"d", "input_d"}, "Input D dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"h", "input_h"}, "Input H dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"w", "input_w"}, "Input W dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"c", "input_c"}, "Input C dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"k", "filter_k"}, "Filter K dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"t", "filter_t"}, "Filter T dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"r", "filter_r"}, "Filter R dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"s", "filter_s"}, "Filter S dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"z", "output_z"}, "Output Z dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"p", "output_p"}, "Output P dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"q", "output_q"}, "Output Q dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"pad_d"}, "Padding in D direction"}, -+ {ArgumentTypeID::kInteger, {"pad_h"}, "Padding in H direction"}, -+ {ArgumentTypeID::kInteger, {"pad_w"}, "Padding in W direction"}, -+ {ArgumentTypeID::kInteger, {"stride_d"}, "Stride in D direction"}, -+ {ArgumentTypeID::kInteger, {"stride_h"}, "Stride in H direction"}, -+ {ArgumentTypeID::kInteger, {"stride_w"}, "Stride in W direction"}, -+ {ArgumentTypeID::kInteger, {"dilation_d"}, "Dilation in D direction"}, 
-+ {ArgumentTypeID::kInteger, {"dilation_h"}, "Dilation in H direction"}, -+ {ArgumentTypeID::kInteger, {"dilation_w"}, "Dilation in W direction"}, -+ {ArgumentTypeID::kTensor, {"Activation"}, "Tensor storing the Activation operand"}, -+ {ArgumentTypeID::kTensor, {"Filter"}, "Tensor storing the Filter operand"}, -+ {ArgumentTypeID::kTensor, {"Output"}, "Tensor storing the Output operand"}, -+ {ArgumentTypeID::kEnumerated, {"conv_mode"}, "Convolution filter mode (conv, cross)"}, -+ {ArgumentTypeID::kEnumerated, {"iterator_algorithm", "iterator_algo"}, "Convolution iterator algorithm (analytic, optimized)"}, -+ {ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"}, -+ {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"}, -+ {ArgumentTypeID::kEnumerated, {"split_k_mode", "split-k-mode"}, "SplitK mode for serial or parallel reduction (serial, parallel)"}, -+ {ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"}, -+ {ArgumentTypeID::kEnumerated, {"eq_gemm_provider", "eq-gemm-provider"}, "Enable profiling equivalent gemm by the following providers (cutlass)"}, -+ }, -+ { library::Provider::kReferenceDevice, library::Provider::kReferenceHost, library::Provider::kCUDNN } -+ ) { -+ -+ description_ = " Conv3d operation. Output(Tensor5D) = alpha * Input(Tensor5D) * Filter(Tensor5D) + beta * Input(Tensor5D)"; -+ -+} -+ -+/// Destructor -+Conv3dOperationProfiler::~Conv3dOperationProfiler() { -+ -+} -+ -+ -+/// Prints usage statement for the math function -+void Conv3dOperationProfiler::print_usage(std::ostream &out) const { -+ out << "Conv3d" << "\n\n"; -+ -+ OperationProfiler::print_usage(out); -+} -+ -+/// Prints examples -+void Conv3dOperationProfiler::print_examples(std::ostream &out) const { -+ -+ out << "\nExamples:\n\n" -+ << "Profile a particular convolution (specify all the convolution parameters):\n" -+ << " $ cutlass_profiler --operation=Conv3d" -+ " --Activation=f16:ndhwc --Filter=f16:ndhwc --Output=f16 --accumulator-type=f32" -+ " --n=32 --d=16 --h=14 --w=14 --c=8 --k=64 --t=3 --r=3 --s=3" -+ " --pad_d=1 --pad_h=1 --pad_w=1" -+ " --stride_d=1 --stride::h=1 --stride::w=1" -+ " --dilation_d=1 --dilation::h=1 --dilation::w=1\n\n"; -+} -+ -+#if 0 -+// used this for debugging -+static std::string byte_string(std::vector const &bytes) { -+ std::stringstream ss; -+ -+ ss << "0x"; -+ -+ for (size_t idx = bytes.size(); idx > 0; --idx) { -+ ss << std::hex << std::setw(2) << std::setfill('0') << uint32_t(bytes.at(idx - 1)); -+ } -+ -+ return ss.str(); -+} -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Total number of bytes loaded -+int64_t Conv3dOperationProfiler::Conv3dProblem::bytes(library::ConvDescription const &operation_desc) const { -+ cutlass::gemm::GemmCoord mnk = eq_gemm_size(operation_desc.conv_kind); -+ -+ // Input bytes read and Output bytes written for the gemm problem -+ int64_t bytes_ = -+ int64_t(library::sizeof_bits(operation_desc.A.element) * mnk.m() / 8) * mnk.k() + -+ int64_t(library::sizeof_bits(operation_desc.B.element) * mnk.n() / 8) * mnk.k() + -+ int64_t(library::sizeof_bits(operation_desc.C.element) * mnk.m() / 8) * mnk.n(); -+ -+ // Set is_beta_zero true if beta is zero -+ bool is_beta_zero = std::all_of(beta.begin(), beta.end(), [](uint8_t i) { return i==0; }); -+ -+ // Output bytes read for the gemm problem for non-zero beta values -+ if (!is_beta_zero) { -+ bytes_ += 
int64_t(library::sizeof_bits(operation_desc.C.element) * mnk.m() / 8) * mnk.n(); -+ } -+ -+ return bytes_; -+} -+ -+/// Total number of flops computed -+int64_t Conv3dOperationProfiler::Conv3dProblem::flops( -+ library::ConvDescription const &operation_desc) const { -+ -+ cutlass::gemm::GemmCoord mnk = eq_gemm_size(operation_desc.conv_kind); -+ -+ int64_t flops_mainloop_ = int64_t(mnk.m()) * mnk.n() * mnk.k() * 2; -+ int64_t flops_epilogue_ = int64_t(mnk.m()) * int64_t(mnk.n()) * 2; -+ -+ // Adjust mainloop flop for dgrad strided -+ if (operation_desc.conv_kind == library::ConvKind::kDgrad) { -+ flops_mainloop_ = flops_mainloop_ / ( stride_d * stride_h * stride_w); -+ } -+ -+ return (flops_mainloop_ + flops_epilogue_); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Extracts the problem dimensions -+Status Conv3dOperationProfiler::initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::ConvDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (!arg_as_int(problem_.n, "n", problem_space, problem)) { -+ // default value -+ problem_.n = 1; -+ } -+ -+ if (!arg_as_int(problem_.d, "d", problem_space, problem)) { -+ // default value -+ problem_.d = 8; -+ } -+ -+ if (!arg_as_int(problem_.h, "h", problem_space, problem)) { -+ // default value -+ problem_.h = 14; -+ } -+ -+ if (!arg_as_int(problem_.w, "w", problem_space, problem)) { -+ // default value -+ problem_.w = 14; -+ } -+ -+ if (!arg_as_int(problem_.c, "c", problem_space, problem)) { -+ // default value -+ problem_.c = 32; -+ } -+ -+ if (!arg_as_int(problem_.k, "k", problem_space, problem)) { -+ // default value -+ problem_.k = 32; -+ } -+ -+ if (!arg_as_int(problem_.t, "t", problem_space, problem)) { -+ // default value -+ problem_.t = 3; -+ } -+ -+ if (!arg_as_int(problem_.r, "r", problem_space, problem)) { -+ // default value -+ problem_.r = 3; -+ } -+ -+ if (!arg_as_int(problem_.s, "s", problem_space, problem)) { -+ // default value -+ problem_.s = 3; -+ } -+ -+ if (!arg_as_int(problem_.pad_d, "pad_d", problem_space, problem)) { -+ // default value -+ problem_.pad_d = 1; -+ } -+ -+ if (!arg_as_int(problem_.pad_w, "pad_w", problem_space, problem)) { -+ // default value -+ problem_.pad_w = 1; -+ } -+ if (!arg_as_int(problem_.pad_h, "pad_h", problem_space, problem)) { -+ // default value -+ problem_.pad_h = 1; -+ } -+ -+ if (!arg_as_int(problem_.stride_d, "stride_d", problem_space, problem)) { -+ // default value -+ problem_.stride_d = 1; -+ } -+ -+ if (!arg_as_int(problem_.stride_h, "stride_h", problem_space, problem)) { -+ // default value -+ problem_.stride_h = 1; -+ } -+ -+ if (!arg_as_int(problem_.stride_w, "stride_w", problem_space, problem)) { -+ // default value -+ problem_.stride_w = 1; -+ } -+ -+ if (!arg_as_int(problem_.dilation_d, "dilation_d", problem_space, problem)) { -+ // default value -+ problem_.dilation_d = 1; -+ } -+ -+ if (!arg_as_int(problem_.dilation_h, "dilation_h", problem_space, problem)) { -+ // default value -+ problem_.dilation_h = 1; -+ } -+ -+ if (!arg_as_int(problem_.dilation_w, "dilation_w", problem_space, problem)) { -+ // default value -+ problem_.dilation_w = 1; -+ } -+ -+ //////////////////////// Convolution output dimensions p and q //////////////////////// -+ // Cutlass convolutions support arbitrary output 
sizes and not constriant by // -+ // input, filter, padding, striding, dilation sizes. // -+ // cuDNN sets the output dimensions (p, q) using following equations: // -+ // // -+ // output = div_up(input + 2 * pad - ((filter - 1) * dilation + 1) + 1, stride) // -+ // where; div_up(a, b) : (a - 1)/b + 1 // -+ // // -+ // Thus, when output p and q dimensions are unspecified by the user // -+ // cutlass profiler sets p and q which are cuDNN compliant. // -+ // // -+ //////////////////////////////////////////////////////////////////////////////////////// -+ // set convolution output z -+ if (!arg_as_int(problem_.z, "z", problem_space, problem)) { -+ // default value (set using cudnn formula for output height, when p is not provided) -+ problem_.z = ( -+ problem_.d + -+ 2 * problem_.pad_d - -+ ((problem_.t - 1) * problem_.dilation_d + 1) -+ ) / (problem_.stride_d) -+ + 1; -+ } -+ -+ // set convolution output p -+ if (!arg_as_int(problem_.p, "p", problem_space, problem)) { -+ // default value (set using cudnn formula for output height, when p is not provided) -+ problem_.p = ( -+ problem_.h + -+ 2 * problem_.pad_h - -+ ((problem_.r - 1) * problem_.dilation_h + 1) -+ ) / (problem_.stride_h) -+ + 1; -+ } -+ -+ // set convolution output q -+ if (!arg_as_int(problem_.q, "q", problem_space, problem)) { -+ // default value (set using cudnn formula for output width, when q is not provided) -+ problem_.q = ( -+ problem_.w + -+ 2 * problem_.pad_w - -+ ((problem_.s - 1) * problem_.dilation_w + 1) -+ ) / (problem_.stride_w) -+ + 1; -+ } -+ ///////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+ if (!arg_as_SplitKModeID(problem_.split_k_mode, "split_k_mode", problem_space, problem)) { -+ // default value -+ problem_.split_k_mode = library::SplitKMode::kSerial; -+ } -+ -+ if (!arg_as_int(problem_.split_k_slices, "split_k_slices", problem_space, problem)) { -+ // default value -+ problem_.split_k_slices = 1; -+ } -+ -+ if (!arg_as_ConvModeID(problem_.conv_mode, "conv_mode", problem_space, problem)) { -+ // default value -+ problem_.conv_mode = library::ConvModeID::kCrossCorrelation; -+ } -+ -+ if (!arg_as_ProviderID(problem_.eq_gemm_provider, "eq_gemm_provider", problem_space, problem)) { -+ // default value -+ problem_.eq_gemm_provider = library::Provider::kNone; -+ } -+ -+ if (!conv_kind_satisfies(operation_desc.conv_kind, "conv_kind", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!iterator_algorithm_satisfies(operation_desc.iterator_algorithm, "iterator_algorithm", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.activation(), "Activation", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.filter(), "Filter", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.output(), "Output", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!arg_as_scalar( -+ problem_.alpha, -+ operation_desc.element_epilogue, -+ "alpha", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(problem_.alpha, operation_desc.element_epilogue, 1)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ if (!arg_as_scalar( -+ problem_.beta, -+ operation_desc.element_epilogue, -+ "beta", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(problem_.beta, operation_desc.element_epilogue, 0)) { -+ 
return Status::kErrorInternal; -+ } -+ } -+ -+ // initialize library::ConvConfiguration -+ conv_workspace_.configuration.problem_size = conv::Conv3dProblemSize( -+ int(problem_.n), -+ int(problem_.d), -+ int(problem_.h), -+ int(problem_.w), -+ int(problem_.c), -+ int(problem_.k), -+ int(problem_.t), -+ int(problem_.r), -+ int(problem_.s), -+ int(problem_.z), -+ int(problem_.p), -+ int(problem_.q), -+ int(problem_.pad_d), -+ int(problem_.pad_h), -+ int(problem_.pad_w), -+ int(problem_.stride_d), -+ int(problem_.stride_h), -+ int(problem_.stride_w), -+ int(problem_.dilation_d), -+ int(problem_.dilation_h), -+ int(problem_.dilation_w), -+ static_cast(static_cast(problem_.conv_mode)), -+ int(problem_.split_k_slices), -+ 1 // groups -+ ); -+ -+ conv_workspace_.configuration.split_k_mode = static_cast(static_cast(problem_.split_k_mode)); -+ -+ conv_workspace_.configuration.layout_activations.stride() = make_Coord( -+ int(problem_.c), -+ int(problem_.w) * int(problem_.c), -+ int(problem_.h) * int(problem_.w) * int(problem_.c), -+ int(problem_.d) * int(problem_.h) * int(problem_.w) * int(problem_.c) -+ ); -+ -+ conv_workspace_.configuration.layout_filters.stride() = make_Coord( -+ int(problem_.c), -+ int(problem_.s) * int(problem_.c), -+ int(problem_.r) * int(problem_.s) * int(problem_.c), -+ int(problem_.t) * int(problem_.r) * int(problem_.s) * int(problem_.c) -+ ); -+ -+ conv_workspace_.configuration.layout_output.stride() = make_Coord( -+ int(problem_.k), -+ int(problem_.q) * int(problem_.k), -+ int(problem_.q) * int(problem_.p) * int(problem_.k), -+ int(problem_.z) * int(problem_.q) * int(problem_.p) * int(problem_.k) -+ ); -+ -+ -+ // initialize library::ConvArguments -+ conv_workspace_.arguments.A = nullptr; -+ conv_workspace_.arguments.B = nullptr; -+ conv_workspace_.arguments.C = nullptr; -+ conv_workspace_.arguments.D = nullptr; -+ conv_workspace_.arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.arguments.beta = problem_.beta.data(); -+ conv_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // initialize reduction operation for parallel splitKMode not supported for conv3d -+ if(conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ if(!initialize_reduction_configuration_(options, report, device_context, operation, problem_space, problem)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ initialize_result_(this->model_result_, options, operation_desc, problem_space); -+ -+ return operation->can_implement(&conv_workspace_.configuration, &conv_workspace_.arguments); -+} -+ -+/// Initializes the performance result -+void Conv3dOperationProfiler::initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::ConvDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.provider = library::Provider::kCUTLASS; -+ result.disposition = Disposition::kNotRun; -+ result.status = Status::kSuccess; -+ result.operation_name = operation_desc.name; -+ -+ result.arguments.resize(problem_space.rank()); -+ -+ set_argument(result, "Activation", problem_space, -+ std::string(library::to_string(operation_desc.activation().element)) -+ + ":" + library::to_string(operation_desc.activation().layout)); -+ -+ set_argument(result, "Filter", problem_space, -+ std::string(library::to_string(operation_desc.filter().element)) -+ + ":" + library::to_string(operation_desc.filter().layout)); -+ -+ set_argument(result, "Output", problem_space, -+ 
std::string(library::to_string(operation_desc.output().element)) -+ + ":" + library::to_string(operation_desc.output().layout)); -+ -+ set_argument(result, "conv_kind", problem_space, library::to_string(operation_desc.conv_kind)); -+ -+ set_argument(result, "iterator_algorithm", problem_space, std::string(library::to_string(operation_desc.iterator_algorithm))); -+ -+ set_argument(result, "n", problem_space, problem_.n); -+ set_argument(result, "d", problem_space, problem_.d); -+ set_argument(result, "h", problem_space, problem_.h); -+ set_argument(result, "w", problem_space, problem_.w); -+ set_argument(result, "c", problem_space, problem_.c); -+ -+ set_argument(result, "k", problem_space, problem_.k); -+ set_argument(result, "t", problem_space, problem_.t); -+ set_argument(result, "r", problem_space, problem_.r); -+ set_argument(result, "s", problem_space, problem_.s); -+ -+ set_argument(result, "z", problem_space, problem_.z); -+ set_argument(result, "p", problem_space, problem_.p); -+ set_argument(result, "q", problem_space, problem_.q); -+ -+ set_argument(result, "pad_d", problem_space, problem_.pad_d); -+ set_argument(result, "pad_h", problem_space, problem_.pad_h); -+ set_argument(result, "pad_w", problem_space, problem_.pad_w); -+ -+ set_argument(result, "stride_d", problem_space, problem_.stride_d); -+ set_argument(result, "stride_h", problem_space, problem_.stride_h); -+ set_argument(result, "stride_w", problem_space, problem_.stride_w); -+ -+ set_argument(result, "dilation_d", problem_space, problem_.dilation_d); -+ set_argument(result, "dilation_h", problem_space, problem_.dilation_h); -+ set_argument(result, "dilation_w", problem_space, problem_.dilation_w); -+ -+ set_argument(result, "split_k_mode", problem_space, -+ std::string(library::to_string(problem_.split_k_mode))); -+ set_argument(result, "split_k_slices", problem_space, problem_.split_k_slices); -+ -+ set_argument(result, "conv_mode", problem_space, -+ std::string(library::to_string(problem_.conv_mode))); -+ -+ set_argument(result, "alpha", problem_space, -+ library::lexical_cast(problem_.alpha, operation_desc.element_epilogue)); -+ -+ set_argument(result, "beta", problem_space, -+ library::lexical_cast(problem_.beta, operation_desc.element_epilogue)); -+ -+ set_argument(result, "eq_gemm_provider", problem_space, -+ std::string(library::to_string(problem_.eq_gemm_provider))); -+ -+ OperationProfiler::initialize_result_(result, operation_desc, problem_space); -+ -+ // Bytes of activation, filter, and output tensors -+ result.bytes = problem_.bytes(operation_desc); -+ -+ // Theoritical flops required for the computation -+ result.flops = problem_.flops(operation_desc); -+ -+ // Measured runtime -+ result.runtime = 0; -+ -+} -+ -+/// Initialize reduction problem dimenstions and library::Operation -+bool Conv3dOperationProfiler::initialize_reduction_configuration_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::ConvDescription const &conv_desc = -+ static_cast(operation->description()); -+ -+ library::ConvKind const &conv_kind = conv_desc.conv_kind; -+ -+ if (!cast_from_double(problem_.alpha_one, conv_desc.element_epilogue, 1)) { -+ return false; -+ } -+ -+ if (!cast_from_double(problem_.beta_zero, conv_desc.element_epilogue, 0)) { -+ return false; -+ } -+ -+ /// This chooses the appropriate stride element of the row-major C tensor. 
-+ int const & tensor_c_stride_idx = (conv_kind == library::ConvKind::kWgrad ? 3 : 0); -+ -+ /// intialize library::ReductionConfiguration -+ conv_workspace_.reduction_configuration.problem_size = problem_.eq_gemm_size(conv_kind).mn(); -+ conv_workspace_.reduction_configuration.partitions = int(problem_.split_k_slices); -+ conv_workspace_.reduction_configuration.partition_stride = problem_.eq_gemm_size(conv_kind).mn().product(); -+ conv_workspace_.reduction_configuration.ldw = conv_workspace_.configuration.layout_c(conv_kind).stride()[tensor_c_stride_idx]; -+ conv_workspace_.reduction_configuration.lds = conv_workspace_.configuration.layout_c(conv_kind).stride()[tensor_c_stride_idx]; -+ conv_workspace_.reduction_configuration.ldd = conv_workspace_.configuration.layout_c(conv_kind).stride()[tensor_c_stride_idx]; -+ -+ // find reduction operation -+ library::ReductionFunctionalKey reduction_key( -+ library::Provider::kCUTLASS, -+ conv_desc.tile_description.math_instruction.element_accumulator, // element workspace -+ conv_desc.tile_description.math_instruction.element_accumulator, // element accumulator -+ conv_desc.C.element, // element output -+ conv_desc.element_epilogue // element compute -+ ); -+ -+#if 0// debug print to check which reduction instance is selected -+ std::cout << reduction_key << "\n"; -+#endif -+ auto reduction_it = Singleton::get().operation_table.reduction_operations.find(reduction_key); -+ -+ if(reduction_it == Singleton::get().operation_table.reduction_operations.end()) { -+ -+ return false; -+ } -+ -+ // initialize reduction operation required for parallel split-k conv2d operator -+ reduction_op_ = reduction_it->second; -+ -+ // reduction operation found and initialized -+ return true; -+} -+ -+ -+/// Initializes workspace -+Status Conv3dOperationProfiler::initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ // initialize conv2d underlying operation to handle parallel reduction -+ library::Operation const* underlying_operation = operation; -+ -+ if(conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ if (!(underlying_operation = library::find_conv_operation_for_parallel_reduction(operation))) { -+ return Status::kErrorNotSupported; -+ } -+ } -+ -+ library::ConvDescription const &operation_desc = -+ static_cast(underlying_operation->description()); -+ -+ // Compute the number of copies of the problem to avoid L2 camping. 
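// [Editor's note: illustrative sketch, not part of the deleted patch text.] The count
// computed just below sizes a ring of problem copies so that timed iterations cannot
// simply re-hit the previous copy in L2. The same rule as a standalone helper, with
// hypothetical names:
#include <cstdint>

static int rotating_workspace_count(int64_t problem_bytes, int64_t l2_cache_bytes) {
  if (problem_bytes < 3 * l2_cache_bytes) {
    // enough copies that the combined A/B/C footprint exceeds roughly 3x the L2 cache
    return 1 + static_cast<int>((3 * l2_cache_bytes) / problem_bytes);
  }
  return 1;  // a single copy already spills out of L2
}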
-+ if (!options.profiling.workspace_count) { -+ int64_t bytes = problem_.bytes(operation_desc); -+ if (bytes < 3 * int64_t(options.device.properties.l2CacheSize)) { -+ conv_workspace_.problem_count = -+ 1 + int((3 * int64_t(options.device.properties.l2CacheSize)) / bytes); -+ } -+ else { -+ conv_workspace_.problem_count = 1; -+ } -+ } -+ else { -+ conv_workspace_.problem_count = options.profiling.workspace_count; -+ } -+ -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ conv_workspace_.A = device_context.allocate_tensor( -+ options, -+ "A", -+ operation_desc.A.element, -+ operation_desc.A.layout, -+ problem_.extent_a(operation_desc.conv_kind), -+ conv_workspace_.stride_a(operation_desc.conv_kind), -+ conv_workspace_.problem_count -+ ); -+ -+ conv_workspace_.B = device_context.allocate_tensor( -+ options, -+ "B", -+ operation_desc.B.element, -+ operation_desc.B.layout, -+ problem_.extent_b(operation_desc.conv_kind), -+ conv_workspace_.stride_b(operation_desc.conv_kind), -+ conv_workspace_.problem_count -+ ); -+ -+ conv_workspace_.C = device_context.allocate_tensor( -+ options, -+ "C", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ problem_.extent_c(operation_desc.conv_kind), -+ conv_workspace_.stride_c(operation_desc.conv_kind), -+ conv_workspace_.problem_count -+ ); -+ -+ conv_workspace_.Computed = device_context.allocate_tensor( -+ "D", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ problem_.extent_c(operation_desc.conv_kind), -+ conv_workspace_.stride_c(operation_desc.conv_kind), -+ conv_workspace_.problem_count -+ ); -+ -+ conv_workspace_.Reference = device_context.allocate_tensor( -+ "Reference", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ problem_.extent_c(operation_desc.conv_kind), -+ conv_workspace_.stride_c(operation_desc.conv_kind), -+ conv_workspace_.problem_count -+ ); -+ -+ } -+ -+ // -+ // Initialize the CUTLASS operation -+ // -+ Status status = Status::kSuccess; -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ uint64_t workspace_size = underlying_operation->get_host_workspace_size(&conv_workspace_.configuration); -+ conv_workspace_.host_workspace.resize(workspace_size, 0); -+ -+ workspace_size = underlying_operation->get_device_workspace_size(&conv_workspace_.configuration); -+ conv_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); -+ -+ status = underlying_operation->initialize( -+ &conv_workspace_.configuration, -+ conv_workspace_.host_workspace.data(), -+ conv_workspace_.device_workspace.data()); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ workspace_size = reduction_op_->get_host_workspace_size(&conv_workspace_.reduction_configuration); -+ conv_workspace_.reduction_host_workspace.resize(workspace_size, 0); -+ -+ status = reduction_op_->initialize( -+ &conv_workspace_.reduction_configuration, -+ conv_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ } -+ } -+ -+ // -+ // If CUTLASS is enabled, generate a result for it -+ // -+ results_.push_back(model_result_); -+ results_.back().provider = library::Provider::kCUTLASS; -+ results_.back().op_kind = library::OperationKind::kConv3d; -+ results_.back().disposition = Disposition::kNotRun; -+ -+ for(auto provider : verification_providers_) { -+ 
results_.back().verification_map[provider] = Disposition::kNotRun; -+ } -+ } -+ -+ return status; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool Conv3dOperationProfiler::verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ return true; -+ } -+ -+ if (options.execution_mode == ExecutionMode::kDryRun) { -+ return true; -+ } -+ -+ cudaError_t result; -+ -+ // Initialize structure containing Conv arguments -+ set_cutlass_operator_arguments_(); -+ -+ conv_workspace_.Computed->copy_from_device(conv_workspace_.C->data()); -+ -+ // -+ // Run the CUTLASS operation -+ // -+ // initialize conv2d underlying operation to handle parallel reduction -+ library::Operation const* underlying_operation = operation; -+ -+ if(conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ if (!(underlying_operation = library::find_conv_operation_for_parallel_reduction(operation))) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ } -+ -+#if 0 -+ std::cout << "profiling : " << std::endl -+ << "conv2d : " << operation->description().name << std::endl -+ << "underlying conv2d : " << underlying_operation->description().name << std::endl -+ << "reduction : " << reduction_op_->description().name << std::endl; -+#endif -+ -+ // run cutlass conv2d operation -+ results_.back().status = underlying_operation->run( -+ &conv_workspace_.arguments, -+ conv_workspace_.host_workspace.data(), -+ conv_workspace_.device_workspace.data()); -+ -+ if (results_.back().status != Status::kSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ // Run parallel reduction kernel for parallel split_k_mode -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ -+ results_.back().status = reduction_op_->run( -+ &conv_workspace_.reduction_arguments, -+ conv_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ -+ if (results_.back().status != Status::kSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ } -+ -+ // Synchronize before running device reference -+ result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ // CUTLASS op ran the but not yet verified against any verification provider -+ results_.back().disposition = Disposition::kNotVerified; -+ -+ // -+ // Run verification providers -+ // -+ -+ if (options.verification.enabled) { -+ -+#if CUTLASS_ENABLE_CUDNN -+ // Run verification cudnn reference -+ if (options.verification.provider_enabled(library::Provider::kCUDNN)) { -+ -+ // Guard against unsupported cases -+ auto const & conv_desc = static_cast(operation->description()); -+ -+ Status status = cudnn_satisfies(conv_desc, conv_workspace_.configuration); -+ -+ // Initialize reference data to the source data -+ conv_workspace_.Reference->copy_from_device(conv_workspace_.C->data()); -+ -+ if (status == Status::kSuccess) { -+ // call cudnn verification if supported -+ verify_with_cudnn_( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ else if (status == 
Status::kErrorInvalidProblem) { -+ results_.back().verification_map[library::Provider::kCUDNN] = Disposition::kInvalidProblem; -+ } -+ -+ else { -+ // set verification map for cudnn to not supported -+ results_.back().verification_map[library::Provider::kCUDNN] = Disposition::kNotSupported; -+ } -+ } -+#endif // #if CUTLASS_ENABLE_CUDNN -+ -+ // Run verification host reference -+ if (options.verification.provider_enabled(library::Provider::kReferenceHost)) { -+ -+ // Restore reference data back to initial source data -+ conv_workspace_.Reference->copy_from_device(conv_workspace_.C->data()); -+ -+ verify_with_host_reference_( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ // Update disposition to worst case verification outcome among all -+ // verification providers which are supported -+ bool is_any_verification_run_passed = false; -+ for(auto &m : results_.back().verification_map) { -+ if(m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { -+ results_.back().disposition = m.second; -+ return true; -+ } -+ if(!is_any_verification_run_passed && m.second == Disposition::kPassed) { -+ is_any_verification_run_passed = true; -+ } -+ } -+ -+ if(is_any_verification_run_passed) { -+ results_.back().disposition = Disposition::kPassed; -+ } -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+ -+/// Verifies CUTLASS against host reference -+bool Conv3dOperationProfiler::verify_with_host_reference_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ Status status; -+ -+ // -+ // Find host reference operation using conv functional description key -+ // -+ library::OperationDescription const &desc = operation->description(); -+ -+ auto &conv_desc = static_cast(desc); -+ -+ library::ConvFunctionalKey conv_key( -+ library::Provider::kReferenceHost, -+ conv_desc.conv_kind, -+ conv_desc.A.element, -+ conv_desc.A.layout, -+ conv_desc.B.element, -+ conv_desc.B.layout, -+ conv_desc.C.element, -+ conv_desc.C.layout, -+ conv_desc.tile_description.math_instruction.element_accumulator, -+ conv_desc.element_epilogue); -+ -+#if 0 // debug print to check which host refererence instance is selected -+ std::cout << conv_key << "\n"; -+#endif -+ -+ auto operators_it = Singleton::get().operation_table.conv3d_operations.find(conv_key); -+ -+ if(operators_it == Singleton::get().operation_table.conv3d_operations.end()) { -+ -+ results_.back().verification_map[library::Provider::kReferenceHost] = Disposition::kNotRun; -+ return true; -+ } -+ -+ // conv3d host reference minimum cc is 0 (CPU) and no iterator algorithm -+ library::ConvPreferenceKey preference_key(0, library::IteratorAlgorithmID::kNone); -+ auto cc_it = operators_it->second.find(preference_key); -+ -+ if(cc_it == operators_it->second.end()) { -+ results_.back().verification_map[library::Provider::kReferenceHost] = Disposition::kNotRun; -+ return true; -+ } -+ -+ // host refernce has only one instances in ConvOperationVectorMap -+ library::Operation const *reference_op = cc_it->second[0]; -+ -+ // -+ // Copy input tensors A, B, and C from device to host buffers -+ // -+ conv_workspace_.host_tensor_a.resize(conv_workspace_.A->bytes()); -+ conv_workspace_.host_tensor_b.resize(conv_workspace_.B->bytes()); -+ conv_workspace_.host_tensor_c.resize(conv_workspace_.C->bytes()); -+ 
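// [Editor's note: illustrative sketch, not part of the deleted patch text.] The
// "worst outcome wins" merge in verify_cutlass() above can be read as a small fold over
// the per-provider verification results. Self-contained restatement with hypothetical names:
#include <vector>

enum class DispositionSketch { kNotRun, kNotVerified, kPassed, kIncorrect, kFailed };

static DispositionSketch merge_dispositions(std::vector<DispositionSketch> const &outcomes) {
  bool any_passed = false;
  for (DispositionSketch d : outcomes) {
    if (d == DispositionSketch::kFailed || d == DispositionSketch::kIncorrect) {
      return d;  // any failure or numeric mismatch dominates the final verdict
    }
    any_passed |= (d == DispositionSketch::kPassed);
  }
  return any_passed ? DispositionSketch::kPassed : DispositionSketch::kNotVerified;
}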
conv_workspace_.A->copy_to_host(conv_workspace_.host_tensor_a.data()); -+ conv_workspace_.B->copy_to_host(conv_workspace_.host_tensor_b.data()); -+ conv_workspace_.C->copy_to_host(conv_workspace_.host_tensor_c.data()); -+ -+ // -+ // Initialize structure containing Conv3d arguments -+ // -+ conv_workspace_.arguments.A = conv_workspace_.host_tensor_a.data(); -+ conv_workspace_.arguments.B = conv_workspace_.host_tensor_b.data(); -+ conv_workspace_.arguments.C = conv_workspace_.host_tensor_c.data(); -+ conv_workspace_.arguments.D = conv_workspace_.host_tensor_c.data(); -+ conv_workspace_.arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.arguments.beta = problem_.beta.data(); -+ conv_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // -+ // Intialize host reference operation -+ // -+ std::vector host_workspace_reference_op; -+ -+ uint64_t workspace_size = reference_op->get_host_workspace_size(&conv_workspace_.configuration); -+ host_workspace_reference_op.resize(workspace_size, 0); -+ -+ reference_op->initialize( -+ &conv_workspace_.configuration, -+ host_workspace_reference_op.data()); -+ -+ // -+ // Run host reference operation -+ // -+ status = reference_op->run( -+ &conv_workspace_.arguments, -+ host_workspace_reference_op.data()); -+ -+ // Handle errors -+ if (status != Status::kSuccess) { -+ results_.back().verification_map[library::Provider::kReferenceHost] = Disposition::kNotVerified; -+ return true; -+ } -+ -+ // -+ // Copy host reference output to device memory for equality check on device -+ // -+ conv_workspace_.Reference->copy_from_host(conv_workspace_.arguments.D); -+ -+ // -+ // Verify results -+ // -+ results_.back().verification_map[library::Provider::kReferenceHost] = compare_tensors( -+ options, -+ *conv_workspace_.Computed, -+ *conv_workspace_.Reference, -+ conv_workspace_.Computed->batch_stride() -+ ); -+ -+ // Save workspace if incorrect -+ if (options.verification.save_workspace == SaveWorkspace::kIncorrect && -+ results_.back().verification_map[library::Provider::kReferenceHost] == Disposition::kIncorrect) { -+ -+ save_workspace( -+ device_context, -+ options, -+ static_cast(operation->description()), -+ library::Provider::kCUTLASS, -+ library::Provider::kReferenceHost); -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+ -+/// Verifies CUTLASS against host reference -+bool Conv3dOperationProfiler::verify_with_device_reference_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ // TODO: verify cutlass conv3d against device reference -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+/// Measures performance results -+bool Conv3dOperationProfiler::profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ set_cutlass_operator_arguments_(); -+ -+ results_.back().status = profile_cutlass_( -+ results_.back().runtime, -+ options, -+ operation, -+ &conv_workspace_.arguments, -+ conv_workspace_.host_workspace.data(), -+ conv_workspace_.device_workspace.data() -+ ); -+ } -+ return true; -+ -+} -+ -+/// Updates the arguments structure for the CUTLASS operator based on -+/// the problem 
index. -+void Conv3dOperationProfiler::set_cutlass_operator_arguments_(int problem_idx) { -+ // Initialize structure containing Conv3d arguments -+ conv_workspace_.arguments.A = conv_workspace_.A->batch_data(problem_idx); -+ conv_workspace_.arguments.B = conv_workspace_.B->batch_data(problem_idx); -+ conv_workspace_.arguments.C = conv_workspace_.C->batch_data(problem_idx); -+ conv_workspace_.arguments.D = conv_workspace_.Computed->batch_data(problem_idx); -+ conv_workspace_.arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.arguments.beta = problem_.beta.data(); -+ conv_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ // update library::ConvArguments for parallel split-k reduction -+ conv_workspace_.arguments.D = conv_workspace_.device_workspace.data(); -+ conv_workspace_.arguments.alpha = problem_.alpha_one.data(); -+ conv_workspace_.arguments.beta = problem_.beta_zero.data(); -+ -+ /// intialize library::ReductionArguments -+ conv_workspace_.reduction_arguments.workspace = conv_workspace_.device_workspace.data(); -+ conv_workspace_.reduction_arguments.source = conv_workspace_.C->batch_data(problem_idx); -+ conv_workspace_.reduction_arguments.destination = conv_workspace_.Computed->batch_data(problem_idx); -+ conv_workspace_.reduction_arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.reduction_arguments.beta = problem_.beta.data(); -+ conv_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ } -+} -+ -+/// Method to profile a CUTLASS Operation -+Status Conv3dOperationProfiler::profile_cutlass_( -+ double &runtime, -+ Options const &options, -+ library::Operation const *operation, -+ void *arguments, -+ void *host_workspace, -+ void *device_workspace) { -+ -+ GpuTimer timer; -+ -+ // initialize conv2d underlying operation to handle parallel reduction -+ library::Operation const* underlying_operation = operation; -+ -+ if(conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ if (!(underlying_operation = library::find_conv_operation_for_parallel_reduction(operation))) { -+ return Status::kErrorNotSupported; -+ } -+ } -+ -+ // -+ // Optional sleep to limit power consumption and thermals -+ // -+ -+ sleep(options.profiling.sleep_duration); -+ -+ // -+ // Warmup loop -+ // -+ -+ Status status; -+ -+ for (int iteration = 0; iteration < options.profiling.warmup_iterations; ++iteration) { -+ -+ // Setup rotating workspace -+ int workspace_idx = options.profiling.warmup_iterations + iteration; -+ int problem_idx = (workspace_idx % conv_workspace_.problem_count); -+ -+ set_cutlass_operator_arguments_(problem_idx); -+ -+ // Run underlying conv2d operation -+ status = underlying_operation->run( -+ arguments, -+ host_workspace, -+ device_workspace); -+ -+ // Run parallel reduction kernel for parallel split_k_mode -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ -+ status = reduction_op_->run( -+ &conv_workspace_.reduction_arguments, -+ conv_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ } -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ } -+ -+ // -+ // Initialize GPU timer -+ // -+ -+ timer.start(); -+ -+ // -+ // Profiling loop -+ // -+ -+ int Iterations = options.profiling.iterations; -+ -+ int iteration = 0; -+ for (; iteration < Iterations; ++iteration) { -+ -+ // Setup rotating workspace -+ int problem_idx = (iteration % 
conv_workspace_.problem_count); -+ -+ set_cutlass_operator_arguments_(problem_idx); -+ -+ // Run underlying conv2d operation -+ status = underlying_operation->run( -+ arguments, -+ host_workspace, -+ device_workspace); -+ -+ // Run parallel reduction kernel for parallel split_k_mode -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ status = reduction_op_->run( -+ &conv_workspace_.reduction_arguments, -+ conv_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ } -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ } -+ -+ // -+ // Wait for completion -+ // -+ -+ timer.stop_and_wait(); -+ -+ // -+ // Update performance result -+ // -+ -+ runtime = timer.duration(iteration); -+ -+ return status; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#if CUTLASS_ENABLE_CUDNN -+ -+/// Verifies CUTLASS against cudnn reference -+bool Conv3dOperationProfiler::verify_with_cudnn_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ auto &conv_desc = static_cast(operation->description()); -+ -+ // -+ // Construct cudnn operators -+ // -+ -+ CudnnCreate handle; -+ cudnnStatus_t status = handle.get_cudnn_create_status(); -+ -+ if (status != CUDNN_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUDNN] = get_cutlass_disposition(status); -+ return true; -+ } -+ -+ // -+ // Initialize state -+ // -+ -+ // Initialize structure containing Conv2d arguments -+ conv_workspace_.arguments.A = conv_workspace_.A->data(); -+ conv_workspace_.arguments.B = conv_workspace_.B->data(); -+ conv_workspace_.arguments.D = conv_workspace_.Reference->data(); -+ conv_workspace_.arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.arguments.beta = problem_.beta.data(); -+ conv_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // cuDNN does not support four tensor arguments, so we copy the tensor C data into -+ // tensor D. 
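// [Editor's note: illustrative sketch, not part of the deleted patch text.] The copy and
// aliasing performed just below preserve the epilogue D = alpha * conv(A, B) + beta * C
// even though the cuDNN reference takes a single output tensor: the output buffer is
// first seeded with C, then C is pointed at that same buffer. Simplified shape of the
// pattern, with hypothetical names:
struct ConvArgsSketch { void const *A; void const *B; void const *C; void *D; };

static void alias_c_to_d(ConvArgsSketch &args, void *output_buffer_seeded_with_c) {
  args.D = output_buffer_seeded_with_c;  // reference result is written here
  args.C = args.D;                       // beta * C is read from the same buffer
}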
-+ conv_workspace_.Reference->copy_from_device(conv_workspace_.C->data()); -+ conv_workspace_.arguments.C = conv_workspace_.arguments.D; -+ -+ try { -+ -+ // -+ // Construct dispatcher to cudnn operator -+ // -+ -+ detail::cudnnConvDispatcher conv_op( -+ conv_desc, -+ conv_workspace_.configuration, -+ conv_workspace_.arguments, -+ handle -+ ); -+ -+ if (conv_op.status != Status::kSuccess) { -+ if (conv_op.status == Status::kErrorNotSupported) { -+ results_.back().verification_map[library::Provider::kCUDNN] = Disposition::kNotSupported; -+ -+ } else { -+ results_.back().verification_map[library::Provider::kCUDNN] = Disposition::kFailed; -+ } -+ return true; -+ } -+ -+ -+ status = conv_op(handle); -+ -+ // Handle errors -+ if (status != CUDNN_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUDNN] = get_cutlass_disposition(status); -+ return true; -+ } -+ -+ // -+ // Verify results -+ // -+ -+ results_.back().verification_map[library::Provider::kCUDNN] = compare_tensors( -+ options, -+ *conv_workspace_.Computed, -+ *conv_workspace_.Reference -+ ); -+ -+ // Save workspace if incorrect -+ if (options.verification.save_workspace == SaveWorkspace::kIncorrect && -+ results_.back().verification_map[library::Provider::kCUDNN] == Disposition::kIncorrect) { -+ -+ save_workspace( -+ device_context, -+ options, -+ conv_desc, -+ library::Provider::kCUTLASS, -+ library::Provider::kCUDNN); -+ } -+ } -+ catch (...) { -+ results_.back().verification_map[library::Provider::kCUDNN] = Disposition::kFailed; -+ } -+ -+ // Return true means continue profiling -+ return true; -+ -+} -+ -+#endif // #if CUTLASS_ENABLE_CUDNN -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/conv3d_operation_profiler.h b/3rdparty/cutlass/tools/profiler/src/conv3d_operation_profiler.h -new file mode 100644 -index 0000000..aba832e ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/conv3d_operation_profiler.h -@@ -0,0 +1,447 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines profiling functionality for convolution -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include -+ -+// CUTLASS Library includes -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+#include "cutlass/library/handle.h" -+#include "cutlass/library/manifest.h" -+#include "cutlass/library/singleton.h" -+ -+// Profiler includes -+#include "options.h" -+#include "device_context.h" -+#include "operation_profiler.h" -+#include "performance_result.h" -+#include "problem_space.h" -+#include "reduction_operation_profiler.h" -+#if CUTLASS_ENABLE_CUDNN -+#include "cudnn_helpers.h" -+#endif //#if CUTLASS_ENABLE_CUDNN -+#include "debug.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Abstract base class for each math function -+class Conv3dOperationProfiler : public OperationProfiler { -+public: -+ -+ /// Problem structure obtained from problem space -+ struct Conv3dProblem { -+ -+ int64_t n, d, h, w, c, z, p, q, k, t, r, s; -+ int64_t pad_d, pad_h, pad_w; -+ int64_t stride_d, stride_h, stride_w; -+ int64_t dilation_d, dilation_h, dilation_w; -+ -+ std::vector alpha; -+ std::vector beta; -+ -+ library::SplitKMode split_k_mode; -+ int64_t split_k_slices; -+ -+ library::ConvModeID conv_mode; -+ -+ library::Provider eq_gemm_provider; -+ -+ // convolution with parallel interleaved reduction -+ // convolution epilogue (alpha, beta) = (1.0, 0.0) -+ // reduction epilogue (alpha, beta) = (Conv3dProblem::alpha, Conv3dProblem::beta) -+ std::vector alpha_one; -+ std::vector beta_zero; -+ -+ // -+ // Methods -+ // -+ -+ /// Total number of bytes loaded -+ int64_t bytes(library::ConvDescription const &operation_desc) const; -+ -+ /// Total number of flops computed -+ int64_t flops(library::ConvDescription const &operation_desc) const; -+ -+ /// Infers output size from theinput size, padding, stride, and dilation -+ void set_default_output_size() { -+ z = ((d + pad_d - t * dilation_d) / stride_d) + 1; -+ p = ((h + pad_h - r * dilation_h) / stride_h) + 1; -+ q = ((w + pad_w - s * dilation_w) / stride_w) + 1; -+ } -+ -+ // Returns equivalent gemm problem size for convolution -+ cutlass::gemm::GemmCoord eq_gemm_size(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return cutlass::gemm::GemmCoord(int(n * z * p * q), int(k), int(t * r * s * c)); -+ case library::ConvKind::kDgrad: return cutlass::gemm::GemmCoord(int(n * d * h * w), int(c), int(t * r * s * k)); -+ case library::ConvKind::kWgrad: return cutlass::gemm::GemmCoord(int(k), int(t * r * s * c), int(n * z * p * q)); -+ default : throw std::runtime_error("Invalid Conv Operator 
(fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns extent for tensor A -+ std::vector extent_a(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return {int(n), int(d), int(h), int(w), int(c)}; -+ case library::ConvKind::kDgrad: return {int(n), int(z), int(p), int(q), int(k)}; -+ case library::ConvKind::kWgrad: return {int(n), int(z), int(p), int(q), int(k)}; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns extent for tensor B -+ std::vector extent_b(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return {int(k), int(t), int(r), int(s), int(c)}; -+ case library::ConvKind::kDgrad: return {int(k), int(t), int(r), int(s), int(c)}; -+ case library::ConvKind::kWgrad: return {int(n), int(d), int(h), int(w), int(c)}; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns extent for tensor C -+ std::vector extent_c(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return {int(n), int(z), int(p), int(q), int(k)}; -+ case library::ConvKind::kDgrad: return {int(n), int(d), int(h), int(w), int(c)}; -+ case library::ConvKind::kWgrad: return {int(k), int(t), int(r), int(s), int(c)}; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns layout for equivalent gemm matrix A -+ library::LayoutTypeID eq_gemm_layout_a(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return library::LayoutTypeID::kRowMajor; // TN Gemm -+ case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm -+ case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; // NT Gemm -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns layout for equivalent gemm matrix B -+ library::LayoutTypeID eq_gemm_layout_b(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return library::LayoutTypeID::kColumnMajor; // TN Gemm -+ case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm -+ case library::ConvKind::kWgrad: return library::LayoutTypeID::kRowMajor; // NT Gemm -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns layout for equivalent gemm matrix C -+ library::LayoutTypeID eq_gemm_layout_c(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ // Gemm operator assumes column-major output -+ case library::ConvKind::kFprop: -+ case library::ConvKind::kDgrad: -+ case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns leading dimenstion for equivalent gemm matrix A -+ int64_t eq_gemm_lda(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); -+ case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).k(); -+ case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns leading dimenstion for equivalent gemm matrix B -+ int64_t 
eq_gemm_ldb(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); -+ case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).n(); -+ case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).n(); -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns leading dimenstion for equivalent gemm matrix C -+ int64_t eq_gemm_ldc(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: -+ case library::ConvKind::kDgrad: -+ case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ }; -+ -+ /// Workspace used -+ struct Conv2dWorkspace { -+ -+ /// Conv device allocations -+ DeviceAllocation *A; -+ DeviceAllocation *B; -+ DeviceAllocation *C; -+ DeviceAllocation *Computed; -+ DeviceAllocation *Reference; -+ -+ /// Library configuration and arguments for convolution operator -+ library::Conv3dConfiguration configuration; -+ library::ConvArguments arguments; -+ -+ /// Number of copies of the problem workspace which are visited sequentially during -+ /// profiling to avoid camping in the last level cache. -+ int problem_count; -+ -+ /// Buffer used for the cutlass conv2d operations' host workspace -+ std::vector host_workspace; -+ -+ /// Buffer used for the cutlass operations' device workspace -+ DeviceAllocation device_workspace; -+ -+ /// Library configuration and arguments for reduction operator -+ library::ReductionConfiguration reduction_configuration; -+ library::ReductionArguments reduction_arguments; -+ -+ /// Buffer used for the cutlass reduction operations' host workspace -+ std::vector reduction_host_workspace; -+ -+ /// Host data buffers for host reference operation -+ /// host buffer for tensor -+ std::vector host_tensor_a; -+ -+ /// host buffer for tensor b -+ std::vector host_tensor_b; -+ -+ /// host buffer for tensor c -+ std::vector host_tensor_c; -+ -+ -+ // -+ // Methods -+ // -+ -+ Conv2dWorkspace(): -+ A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } -+ -+ // Returns stride vector for tensor A -+ std::vector stride_a(library::ConvKind const &conv_kind) { -+ return { -+ configuration.layout_a(conv_kind).stride()[0], -+ configuration.layout_a(conv_kind).stride()[1], -+ configuration.layout_a(conv_kind).stride()[2], -+ configuration.layout_a(conv_kind).stride()[3] -+ }; -+ } -+ -+ // Returns stride vector for tensor B -+ std::vector stride_b(library::ConvKind const &conv_kind) { -+ -+ return { -+ configuration.layout_b(conv_kind).stride()[0], -+ configuration.layout_b(conv_kind).stride()[1], -+ configuration.layout_b(conv_kind).stride()[2], -+ configuration.layout_b(conv_kind).stride()[3] -+ }; -+ } -+ -+ // Returns stride vector for tensor C -+ std::vector stride_c(library::ConvKind const &conv_kind) { -+ -+ return { -+ configuration.layout_c(conv_kind).stride()[0], -+ configuration.layout_c(conv_kind).stride()[1], -+ configuration.layout_c(conv_kind).stride()[2], -+ configuration.layout_c(conv_kind).stride()[3] -+ }; -+ } -+ }; -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ /// CONV problem obtained from problem space -+ Conv3dProblem problem_; -+ -+ /// Device memory allocations -+ Conv2dWorkspace conv_workspace_; -+ -+ /// CUTLASS parallel reduction operation to follow this* conv2d operation -+ library::Operation const *reduction_op_; -+ 
-+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ Conv3dOperationProfiler(Options const &options); -+ -+ /// Destructor -+ virtual ~Conv3dOperationProfiler(); -+ -+ /// Prints usage statement for the math function -+ virtual void print_usage(std::ostream &out) const; -+ -+ /// Prints examples -+ virtual void print_examples(std::ostream &out) const; -+ -+ /// Extracts the problem dimensions -+ virtual Status initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes workspace -+ virtual Status initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against references -+ virtual bool verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Measures performance results -+ virtual bool profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+protected: -+ -+ /// Updates the arguments structure for the CUTLASS operator based on -+ /// the problem index. -+ void set_cutlass_operator_arguments_(int problem_idx = 0); -+ -+ /// Method to profile an initialized CUTLASS operation -+ virtual Status profile_cutlass_( -+ double &runtime, -+ Options const &options, -+ library::Operation const *operation, -+ void *arguments, -+ void *host_workspace, -+ void *device_workspace); -+ -+ /// Initialize reduction problem dimenstions and library::Operation -+ bool initialize_reduction_configuration_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes the performance result -+ void initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::ConvDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ -+ /// Verifies CUTLASS against host reference -+ bool verify_with_host_reference_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against device reference -+ bool verify_with_device_reference_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+#if CUTLASS_ENABLE_CUDNN -+ -+ /// Verifies CUTLASS against cudnn reference -+ bool verify_with_cudnn_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+#endif //#if CUTLASS_ENABLE_CUDNN -+ -+}; -+ 
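// [Editor's note: illustrative sketch, not part of the deleted patch text.] The
// Conv3dProblem::eq_gemm_size() mapping declared in the class above reduces each conv
// kind to an equivalent GEMM shape. Standalone restatement, with hypothetical names:
#include <cstdint>
#include <stdexcept>

struct GemmShapeSketch { int64_t m, n, k; };
enum class ConvKindSketch { kFprop, kDgrad, kWgrad };

static GemmShapeSketch equivalent_gemm_shape(
    ConvKindSketch kind,
    int64_t n, int64_t d, int64_t h, int64_t w, int64_t c,    // activation extents
    int64_t k, int64_t t, int64_t r, int64_t s,               // filter extents
    int64_t z, int64_t p, int64_t q) {                        // output extents
  switch (kind) {
    case ConvKindSketch::kFprop: return {n * z * p * q, k, t * r * s * c};
    case ConvKindSketch::kDgrad: return {n * d * h * w, c, t * r * s * k};
    case ConvKindSketch::kWgrad: return {k, t * r * s * c, n * z * p * q};
  }
  throw std::invalid_argument("invalid conv kind");
}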
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/profiler/src/cublas_helpers.cu b/3rdparty/cutlass/tools/profiler/src/cublas_helpers.cu -new file mode 100644 -index 0000000..5f7354c ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/cublas_helpers.cu -@@ -0,0 +1,1159 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Helper functions for mapping CUTLASS concepts to cuBLAS. 
-+*/ -+ -+#include -+ -+#if CUTLASS_ENABLE_CUBLAS -+#include "cublas_helpers.h" -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Converts a cuBLAS status to cutlass::Status -+Status get_cutlass_status(cublasStatus_t cublas) { -+ -+ switch (cublas) { -+ case CUBLAS_STATUS_SUCCESS: -+ return Status::kSuccess; -+ case CUBLAS_STATUS_INVALID_VALUE: -+ return Status::kErrorInvalidProblem; -+ case CUBLAS_STATUS_NOT_SUPPORTED: -+ return Status::kErrorNotSupported; -+ default: break; -+ } -+ return Status::kErrorInternal; -+} -+ -+/// Converts a cuBLASS status to cutlass::profiler::Disposition -+Disposition get_cutlass_disposition(cublasStatus_t cublas_status) { -+ -+ if (cublas_status == CUBLAS_STATUS_INVALID_VALUE) { -+ return Disposition::kInvalidProblem; -+ } -+ else if (cublas_status == CUBLAS_STATUS_NOT_SUPPORTED) { -+ return Disposition::kNotSupported; -+ } -+ return Disposition::kFailed; -+} -+ -+/// Maps a CUTLASS tensor layout to a cuBLAS transpose operation -+bool get_cublas_transpose_operation( -+ cublasOperation_t &operation, -+ library::LayoutTypeID layout, -+ library::ComplexTransform transform) { -+ -+ switch (layout) { -+ case library::LayoutTypeID::kColumnMajor: -+ if (transform == library::ComplexTransform::kNone) { -+ operation = CUBLAS_OP_N; -+ return true; -+ } -+ else { -+ return false; -+ } -+ break; -+ case library::LayoutTypeID::kRowMajor: -+ if (transform == library::ComplexTransform::kNone) { -+ operation = CUBLAS_OP_T; -+ return true; -+ } -+ else if (transform == library::ComplexTransform::kConjugate) { -+ operation = CUBLAS_OP_C; -+ return true; -+ } -+ break; -+ default: break; -+ } -+ -+ return false; -+} -+ -+/// Maps a CUTLASS numeric type to a cuBLAS data type enumeration -+bool get_cublas_datatype(cublasDataType_t &data_type, library::NumericTypeID element_type) { -+ switch (element_type) { -+ case library::NumericTypeID::kF16: -+ data_type = CUDA_R_16F; -+ return true; -+ -+ case library::NumericTypeID::kBF16: -+ break; -+ -+ case library::NumericTypeID::kTF32: -+ break; -+ -+ case library::NumericTypeID::kF32: -+ data_type = CUDA_R_32F; -+ return true; -+ -+ case library::NumericTypeID::kF64: -+ data_type = CUDA_R_64F; -+ return true; -+ -+ case library::NumericTypeID::kS4: -+ break; -+ -+ case library::NumericTypeID::kS8: -+ data_type = CUDA_R_8I; -+ return true; -+ -+ case library::NumericTypeID::kS16: -+ break; -+ -+ case library::NumericTypeID::kS32: -+ data_type = CUDA_R_32I; -+ return true; -+ -+ case library::NumericTypeID::kS64: -+ break; -+ -+ case library::NumericTypeID::kU4: -+ break; -+ -+ case library::NumericTypeID::kU8: -+ data_type = CUDA_R_8U; -+ return true; -+ -+ case library::NumericTypeID::kU16: -+ break; -+ -+ case library::NumericTypeID::kU32: -+ data_type = CUDA_R_32U; -+ return true; -+ -+ case library::NumericTypeID::kU64: -+ break; -+ -+ case library::NumericTypeID::kB1: -+ break; -+ -+ case library::NumericTypeID::kCF32: -+ data_type = CUDA_C_32F; -+ return true; -+ -+ case library::NumericTypeID::kCF64: -+ data_type = CUDA_C_64F; -+ return true; -+ -+ case library::NumericTypeID::kInvalid: -+ -+ default: -+ break; -+ } -+ -+ return false; -+} -+ -+/// Maps a cutlass::SideMode to cuBLAS side mode -+bool get_cublas_side_mode(cublasSideMode_t& side, SideMode side_mode) { -+ -+ switch (side_mode) { -+ case SideMode::kLeft: -+ side = CUBLAS_SIDE_LEFT; -+ return true; -+ case SideMode::kRight: -+ side = CUBLAS_SIDE_RIGHT; -+ 
return true; -+ default: break; -+ } -+ -+ return false; -+} -+ -+/// Maps a cutlass::FillMode to cuBLAS fill mode -+bool get_cublas_fill_mode(cublasFillMode_t& uplo, FillMode fill_mode) { -+ -+ switch (fill_mode) { -+ case FillMode::kLower: -+ uplo = CUBLAS_FILL_MODE_LOWER; -+ return true; -+ case FillMode::kUpper: -+ uplo = CUBLAS_FILL_MODE_UPPER; -+ return true; -+ default: break; -+ } -+ -+ return false; -+} -+ -+/// Maps a cutlass::DiagType to cuBLAS diag type -+bool get_cublas_diag_type(cublasDiagType_t& diag, DiagType diag_type) { -+ -+ switch (diag_type) { -+ case DiagType::kNonUnit: -+ diag = CUBLAS_DIAG_NON_UNIT; -+ return true; -+ case DiagType::kUnit: -+ diag = CUBLAS_DIAG_UNIT; -+ return true; -+ default: break; -+ } -+ -+ return false; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Gets the cublas algorithm given threadblock tile dimensions and math opcode class -+cublasGemmAlgo_t get_cublas_gemm_algo(int cta_m, int cta_n, int cta_k, library::OpcodeClassID opcode_class) { -+ return (opcode_class == library::OpcodeClassID::kSimt ? -+ CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns a status if cuBLAS can satisfy a particular GEMM description -+Status cublas_satisfies(library::GemmDescription const &desc) { -+ auto const &math_instruction = desc.tile_description.math_instruction; -+ -+ if (math_instruction.element_accumulator == library::NumericTypeID::kS32 && -+ math_instruction.opcode_class == library::OpcodeClassID::kTensorOp) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ // output type S4 and S8 not supported in cuBLAS -+ if (desc.C.element == library::NumericTypeID::kS4 || -+ desc.C.element == library::NumericTypeID::kS8) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ return Status::kSuccess; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+cublasGemmExDispatcher::cublasGemmExDispatcher( -+ library::GemmDescription const &op_desc, -+ library::GemmUniversalConfiguration configuration_, -+ library::GemmUniversalArguments arguments_, -+ cublasGemmAlgo_t algorithm -+): -+ configuration(configuration_), arguments(arguments_), algo(algorithm), status(Status::kSuccess) { -+ -+ bool good = true; -+ -+ good = (good && get_cublas_transpose_operation(trans_A, op_desc.A.layout, op_desc.transform_A)); -+ good = (good && get_cublas_transpose_operation(trans_B, op_desc.B.layout, op_desc.transform_B)); -+ good = (good && get_cublas_datatype(data_type_A, op_desc.A.element)); -+ good = (good && get_cublas_datatype(data_type_B, op_desc.B.element)); -+ good = (good && get_cublas_datatype(data_type_C, op_desc.C.element)); -+ -+ good = (good && get_cublas_datatype( -+ compute_data_type, -+ op_desc.tile_description.math_instruction.element_accumulator)); -+ -+ // cuBLAS introduces a separate cublasComputeType enumerant to more precisely describe -+ // internal numerical data types used in the computation. 
-+#if (__CUDACC_VER_MAJOR__ >= 11) -+ library::OpcodeClassID const & opcode_class = -+ op_desc.tile_description.math_instruction.opcode_class; -+ -+ if (good && -+ op_desc.A.element == library::NumericTypeID::kF32 && -+ op_desc.B.element == library::NumericTypeID::kF32 && -+ opcode_class == library::OpcodeClassID::kTensorOp) { -+ -+ compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; -+ } -+ else if (good) { -+ bool const isPedantic = false; -+ switch (compute_data_type) { -+ case CUDA_R_32F: -+ case CUDA_C_32F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_32F_PEDANTIC : CUBLAS_COMPUTE_32F; -+ break; -+ case CUDA_R_64F: -+ case CUDA_C_64F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_64F_PEDANTIC : CUBLAS_COMPUTE_64F; -+ break; -+ case CUDA_R_16F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_16F_PEDANTIC : CUBLAS_COMPUTE_16F; -+ break; -+ case CUDA_R_32I: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_32I_PEDANTIC : CUBLAS_COMPUTE_32I; -+ break; -+ default: -+ good = false; -+ break; -+ } -+ } -+#endif // __CUDACC_VER_MAJOR__ >= 11 -+ -+ if (!good) { -+ status = Status::kErrorNotSupported; -+ } -+} -+ -+/// Executes GEMM using these arguments -+cublasStatus_t cublasGemmExDispatcher::operator()(cublasHandle_t handle) { -+ -+ if (configuration.mode == library::GemmUniversalMode::kBatched) { -+ return cublasGemmStridedBatchedEx( -+ handle, -+ trans_A, -+ trans_B, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ arguments.alpha, -+ arguments.A, -+ data_type_A, -+ int(configuration.lda), -+ arguments.batch_stride_A, -+ arguments.B, -+ data_type_B, -+ int(configuration.ldb), -+ arguments.batch_stride_B, -+ arguments.beta, -+ arguments.D, -+ data_type_C, -+ int(configuration.ldc), -+ arguments.batch_stride_C, -+ configuration.batch_count, -+ #if (__CUDACC_VER_MAJOR__ >= 11) -+ compute_type, -+ #else -+ compute_data_type, -+ #endif -+ algo -+ ); -+ } -+ else { -+ return cublasGemmEx( -+ handle, -+ trans_A, -+ trans_B, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ arguments.alpha, -+ arguments.A, -+ data_type_A, -+ int(configuration.lda), -+ arguments.B, -+ data_type_B, -+ int(configuration.ldb), -+ arguments.beta, -+ arguments.D, -+ data_type_C, -+ int(configuration.ldc), -+ #if (__CUDACC_VER_MAJOR__ >= 11) -+ compute_type, -+ #else -+ compute_data_type, -+ #endif -+ algo -+ ); -+ } -+} -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns a status if cuBLAS can satisfy a particular RankK description -+Status cublas_satisfies(library::RankKDescription const &desc) { -+ auto const &math_instruction = desc.tile_description.math_instruction; -+ -+ if (math_instruction.element_accumulator == library::NumericTypeID::kS32 && -+ math_instruction.opcode_class == library::OpcodeClassID::kTensorOp) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ // output type S4 and S8 not supported in cuBLAS -+ if (desc.C.element == library::NumericTypeID::kS4 || -+ desc.C.element == library::NumericTypeID::kS8) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ // input type BF16 and TF32 not supported in cuBLAS -+ if (desc.A.element == library::NumericTypeID::kBF16 || -+ desc.A.element == library::NumericTypeID::kTF32) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ return Status::kSuccess; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ 
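// [Editor's note: illustrative sketch, not part of the deleted patch text.] A minimal,
// self-contained cublasGemmEx call using the enums the helpers above produce
// (CUBLAS_OP_N, CUDA_R_32F, CUBLAS_GEMM_DEFAULT). Column-major FP32 C = A * B with
// hypothetical device pointers; error checking is omitted for brevity.
#include <cublas_v2.h>

static void sgemm_ex_sketch(float const *dA, float const *dB, float *dC,
                            int m, int n, int k) {
  cublasHandle_t handle;
  cublasCreate(&handle);
  float alpha = 1.0f, beta = 0.0f;   // host scalars (default CUBLAS_POINTER_MODE_HOST)
  cublasGemmEx(handle,
               CUBLAS_OP_N, CUBLAS_OP_N,
               m, n, k,
               &alpha,
               dA, CUDA_R_32F, m,    // lda = m for a column-major m x k matrix
               dB, CUDA_R_32F, k,    // ldb = k for a column-major k x n matrix
               &beta,
               dC, CUDA_R_32F, m,    // ldc = m
#if (__CUDACC_VER_MAJOR__ >= 11)
               CUBLAS_COMPUTE_32F,   // CUDA 11+: cublasComputeType_t
#else
               CUDA_R_32F,           // older toolkits: cudaDataType_t
#endif
               CUBLAS_GEMM_DEFAULT);
  cublasDestroy(handle);
}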
-+namespace detail { -+ -+cublasRankKDispatcher::cublasRankKDispatcher( -+ library::RankKDescription const &op_desc, -+ library::RankKConfiguration configuration_, -+ library::RankKArguments arguments_ -+): -+ configuration(configuration_), arguments(arguments_), status(Status::kSuccess) { -+ -+ blas_mode = op_desc.blas_mode; -+ num_ranks = op_desc.num_ranks; -+ -+ bool good = true; -+ -+ good = (good && get_cublas_transpose_operation(trans_A, op_desc.A.layout, op_desc.transform_A)); -+ good = (good && get_cublas_fill_mode(uplo, op_desc.fill_mode)); -+ good = (good && get_cublas_datatype(data_type_A, op_desc.A.element)); -+ good = (good && get_cublas_datatype(data_type_C, op_desc.C.element)); -+ -+ good = (good && get_cublas_datatype( -+ compute_data_type, -+ op_desc.tile_description.math_instruction.element_accumulator)); -+ -+ // cuBLAS introduces a separate cublasComputeType enumerant to more precisely describe -+ // internal numerical data types used in the computation. -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ library::OpcodeClassID const & opcode_class = -+ op_desc.tile_description.math_instruction.opcode_class; -+ -+ if (good && -+ op_desc.A.element == library::NumericTypeID::kF32 && -+ opcode_class == library::OpcodeClassID::kTensorOp) { -+ -+ compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; -+ } -+ else if (good) { -+ bool const isPedantic = false; -+ switch (compute_data_type) { -+ case CUDA_R_32F: -+ case CUDA_C_32F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_32F_PEDANTIC : CUBLAS_COMPUTE_32F; -+ break; -+ case CUDA_R_64F: -+ case CUDA_C_64F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_64F_PEDANTIC : CUBLAS_COMPUTE_64F; -+ break; -+ case CUDA_R_16F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_16F_PEDANTIC : CUBLAS_COMPUTE_16F; -+ break; -+ case CUDA_R_32I: -+ compute_type = isPedantic ? 
CUBLAS_COMPUTE_32I_PEDANTIC : CUBLAS_COMPUTE_32I; -+ break; -+ default: -+ good = false; -+ break; -+ } -+ } -+#endif // __CUDACC_VER_MAJOR__ >= 11 -+ -+ if (!good) { -+ status = Status::kErrorNotSupported; -+ } -+} -+ -+/// Executes RankK using these arguments -+cublasStatus_t cublasRankKDispatcher::operator()(cublasHandle_t handle) { -+ -+ // SYRK and HERK -+ if (num_ranks == 1) { -+ if (data_type_A == data_type_C && data_type_A == CUDA_R_64F) { -+ return cublasDsyrk( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } else if (data_type_A == data_type_C && data_type_A == CUDA_R_32F) { -+ -+ #if (__CUDACC_VER_MAJOR__ >= 11) -+ if (cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH) != CUBLAS_STATUS_SUCCESS) -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+ #endif -+ -+ return cublasSsyrk( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } else if (data_type_A == data_type_C && data_type_A == CUDA_C_64F) { -+ -+ if (blas_mode == BlasMode::kHermitian) { -+ return cublasZherk( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ else { -+ return cublasZsyrk( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ -+ } else if (data_type_A == data_type_C && data_type_A == CUDA_C_32F) { -+ -+ #if (__CUDACC_VER_MAJOR__ >= 11) -+ if (cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH) != CUBLAS_STATUS_SUCCESS) -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+ #endif -+ -+ if (blas_mode == BlasMode::kHermitian) { -+ return cublasCherk( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ else { -+ return cublasCsyrk( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ } else { -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+ } -+ } -+ -+ // SYR2K and HER2K -+ else if (num_ranks == 2) { -+ if (data_type_A == data_type_C && data_type_A == CUDA_R_64F) { -+ return cublasDsyr2k( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } 
else if (data_type_A == data_type_C && data_type_A == CUDA_R_32F) { -+ -+ #if (__CUDACC_VER_MAJOR__ >= 11) -+ if (cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH) != CUBLAS_STATUS_SUCCESS) -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+ #endif -+ -+ return cublasSsyr2k( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } else if (data_type_A == data_type_C && data_type_A == CUDA_C_64F) { -+ -+ if (blas_mode == BlasMode::kHermitian) { -+ return cublasZher2k( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ else { -+ return cublasZsyr2k( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ -+ } else if (data_type_A == data_type_C && data_type_A == CUDA_C_32F) { -+ -+ #if (__CUDACC_VER_MAJOR__ >= 11) -+ if (cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH) != CUBLAS_STATUS_SUCCESS) -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+ #endif -+ -+ if (blas_mode == BlasMode::kHermitian) { -+ return cublasCher2k( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ else { -+ return cublasCsyr2k( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ } else { -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+ } -+ } -+ else { -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+ } -+} -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns a status if cuBLAS can satisfy a particular TRMM description -+Status cublas_satisfies(library::TrmmDescription const &desc) { -+ auto const &math_instruction = desc.tile_description.math_instruction; -+ -+ if (math_instruction.element_accumulator == library::NumericTypeID::kS32 && -+ math_instruction.opcode_class == library::OpcodeClassID::kTensorOp) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ // output type S4 and S8 not supported in cuBLAS -+ if (desc.D.element == library::NumericTypeID::kS4 || -+ desc.D.element == library::NumericTypeID::kS8) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ // input type BF16 and TF32 not supported in cuBLAS -+ if (desc.A.element == library::NumericTypeID::kBF16 || -+ desc.A.element == 
library::NumericTypeID::kTF32) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ return Status::kSuccess; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+cublasTrmmDispatcher::cublasTrmmDispatcher( -+ library::TrmmDescription const &op_desc, -+ library::TrmmConfiguration configuration_, -+ library::TrmmArguments arguments_ -+): -+ configuration(configuration_), arguments(arguments_), status(Status::kSuccess) { -+ -+ bool good = true; -+ -+ good = (good && get_cublas_transpose_operation(trans_A, op_desc.A.layout, op_desc.transform_A)); -+ good = (good && get_cublas_side_mode(side, op_desc.side_mode)); -+ good = (good && get_cublas_fill_mode(uplo, op_desc.fill_mode)); -+ good = (good && get_cublas_diag_type(diag, op_desc.diag_type)); -+ good = (good && get_cublas_datatype(data_type_A, op_desc.A.element)); -+ good = (good && get_cublas_datatype(data_type_B, op_desc.B.element)); -+ good = (good && get_cublas_datatype(data_type_D, op_desc.D.element)); -+ -+ // if A is Transposed, then for cuBLAS that is inverted Fill Mode. -+ if (trans_A == CUBLAS_OP_T || trans_A == CUBLAS_OP_C) { -+ if (uplo == CUBLAS_FILL_MODE_LOWER) -+ uplo = CUBLAS_FILL_MODE_UPPER; -+ else -+ uplo = CUBLAS_FILL_MODE_LOWER; -+ } -+ -+ good = (good && get_cublas_datatype( -+ compute_data_type, -+ op_desc.tile_description.math_instruction.element_accumulator)); -+ -+ // cuBLAS introduces a separate cublasComputeType enumerant to more precisely describe -+ // internal numerical data types used in the computation. -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ library::OpcodeClassID const & opcode_class = -+ op_desc.tile_description.math_instruction.opcode_class; -+ -+ if (good && -+ op_desc.A.element == library::NumericTypeID::kF32 && -+ opcode_class == library::OpcodeClassID::kTensorOp) { -+ -+ compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; -+ } -+ else if (good) { -+ bool const isPedantic = false; -+ switch (compute_data_type) { -+ case CUDA_R_32F: -+ case CUDA_C_32F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_32F_PEDANTIC : CUBLAS_COMPUTE_32F; -+ break; -+ case CUDA_R_64F: -+ case CUDA_C_64F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_64F_PEDANTIC : CUBLAS_COMPUTE_64F; -+ break; -+ case CUDA_R_16F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_16F_PEDANTIC : CUBLAS_COMPUTE_16F; -+ break; -+ case CUDA_R_32I: -+ compute_type = isPedantic ? 
CUBLAS_COMPUTE_32I_PEDANTIC : CUBLAS_COMPUTE_32I; -+ break; -+ default: -+ good = false; -+ break; -+ } -+ } -+#endif // __CUDACC_VER_MAJOR__ >= 11 -+ -+ if (!good) { -+ status = Status::kErrorNotSupported; -+ } -+} -+ -+/// Executes TRMM using these arguments -+cublasStatus_t cublasTrmmDispatcher::operator()(cublasHandle_t handle) { -+ -+ if (data_type_A == data_type_D && data_type_A == CUDA_R_64F) { -+ return cublasDtrmm( -+ handle, -+ side, -+ uplo, -+ trans_A, -+ diag, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.D), -+ int(configuration.ldd) -+ ); -+ } else if (data_type_A == data_type_D && data_type_A == CUDA_R_32F) { -+ -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ if (cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH) != CUBLAS_STATUS_SUCCESS) -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+#endif -+ -+ return cublasStrmm( -+ handle, -+ side, -+ uplo, -+ trans_A, -+ diag, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.D), -+ int(configuration.ldd) -+ ); -+ } else if (data_type_A == data_type_D && data_type_A == CUDA_C_64F) { -+ return cublasZtrmm( -+ handle, -+ side, -+ uplo, -+ trans_A, -+ diag, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.D), -+ int(configuration.ldd) -+ ); -+ } else if (data_type_A == data_type_D && data_type_A == CUDA_C_32F) { -+ -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ if (cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH) != CUBLAS_STATUS_SUCCESS) -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+#endif -+ -+ return cublasCtrmm( -+ handle, -+ side, -+ uplo, -+ trans_A, -+ diag, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.D), -+ int(configuration.ldd) -+ ); -+ } else { -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+ } -+} -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns a status if cuBLAS can satisfy a particular Symm description -+Status cublas_satisfies(library::SymmDescription const &desc) { -+ auto const &math_instruction = desc.tile_description.math_instruction; -+ -+ if (math_instruction.element_accumulator == library::NumericTypeID::kS32 && -+ math_instruction.opcode_class == library::OpcodeClassID::kTensorOp) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ // output type S4 and S8 not supported in cuBLAS -+ if (desc.C.element == library::NumericTypeID::kS4 || -+ desc.C.element == library::NumericTypeID::kS8) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ // input type BF16 and TF32 not supported in cuBLAS -+ if (desc.A.element == library::NumericTypeID::kBF16 || -+ desc.A.element == library::NumericTypeID::kTF32) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ // input type BF16 and TF32 not supported in cuBLAS -+ if (desc.B.element == library::NumericTypeID::kBF16 || -+ desc.B.element == 
library::NumericTypeID::kTF32) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ // only column major layout is supported in cuBLAS -+ if (desc.A.layout != library::LayoutTypeID::kColumnMajor || -+ desc.transform_A != library::ComplexTransform::kNone) { -+ -+ return Status::kErrorNotSupported; -+} -+ -+ return Status::kSuccess; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+cublasSymmDispatcher::cublasSymmDispatcher( -+ library::SymmDescription const &op_desc, -+ library::SymmConfiguration configuration_, -+ library::SymmArguments arguments_ -+): -+ configuration(configuration_), arguments(arguments_), status(Status::kSuccess) { -+ -+ blas_mode = op_desc.blas_mode; -+ -+ bool good = true; -+ -+ good = (good && get_cublas_side_mode(side, op_desc.side_mode)); -+ good = (good && get_cublas_fill_mode(uplo, op_desc.fill_mode)); -+ good = (good && get_cublas_datatype(data_type_A, op_desc.A.element)); -+ good = (good && get_cublas_datatype(data_type_C, op_desc.C.element)); -+ -+ good = (good && get_cublas_datatype( -+ compute_data_type, -+ op_desc.tile_description.math_instruction.element_accumulator)); -+ -+ // cuBLAS introduces a separate cublasComputeType enumerant to more precisely describe -+ // internal numerical data types used in the computation. -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ library::OpcodeClassID const & opcode_class = -+ op_desc.tile_description.math_instruction.opcode_class; -+ -+ if (good && -+ op_desc.A.element == library::NumericTypeID::kF32 && -+ opcode_class == library::OpcodeClassID::kTensorOp) { -+ -+ compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; -+ } -+ else if (good) { -+ bool const isPedantic = false; -+ switch (compute_data_type) { -+ case CUDA_R_32F: -+ case CUDA_C_32F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_32F_PEDANTIC : CUBLAS_COMPUTE_32F; -+ break; -+ case CUDA_R_64F: -+ case CUDA_C_64F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_64F_PEDANTIC : CUBLAS_COMPUTE_64F; -+ break; -+ case CUDA_R_16F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_16F_PEDANTIC : CUBLAS_COMPUTE_16F; -+ break; -+ case CUDA_R_32I: -+ compute_type = isPedantic ? 
CUBLAS_COMPUTE_32I_PEDANTIC : CUBLAS_COMPUTE_32I; -+ break; -+ default: -+ good = false; -+ break; -+ } -+ } -+#endif // __CUDACC_VER_MAJOR__ >= 11 -+ -+ if (!good) { -+ status = Status::kErrorNotSupported; -+ } -+} -+ -+/// Executes Symm using these arguments -+cublasStatus_t cublasSymmDispatcher::operator()(cublasHandle_t handle) { -+ -+ // SYMM and HEMM -+ if (data_type_A == data_type_C && data_type_A == CUDA_R_64F) { -+ return cublasDsymm( -+ handle, -+ side, -+ uplo, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } else if (data_type_A == data_type_C && data_type_A == CUDA_R_32F) { -+ -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ if (cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH) != CUBLAS_STATUS_SUCCESS) -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+#endif -+ -+ return cublasSsymm( -+ handle, -+ side, -+ uplo, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } else if (data_type_A == data_type_C && data_type_A == CUDA_C_64F) { -+ -+ if (blas_mode == BlasMode::kHermitian) { -+ return cublasZhemm( -+ handle, -+ side, -+ uplo, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ else { -+ return cublasZsymm( -+ handle, -+ side, -+ uplo, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ -+ } else if (data_type_A == data_type_C && data_type_A == CUDA_C_32F) { -+ -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ if (cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH) != CUBLAS_STATUS_SUCCESS) -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+#endif -+ -+ if (blas_mode == BlasMode::kHermitian) { -+ return cublasChemm( -+ handle, -+ side, -+ uplo, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ else { -+ return cublasCsymm( -+ handle, -+ side, -+ uplo, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ } else { -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+ } -+} -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+#endif // #if 
CUTLASS_ENABLE_CUBLAS
-diff --git a/3rdparty/cutlass/tools/profiler/src/cublas_helpers.h b/3rdparty/cutlass/tools/profiler/src/cublas_helpers.h
-new file mode 100644
-index 0000000..8c36fb7
---- /dev/null
-+++ b/3rdparty/cutlass/tools/profiler/src/cublas_helpers.h
-@@ -0,0 +1,358 @@
-+/***************************************************************************************************
-+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+ * SPDX-License-Identifier: BSD-3-Clause
-+ *
-+ * Redistribution and use in source and binary forms, with or without
-+ * modification, are permitted provided that the following conditions are met:
-+ *
-+ * 1. Redistributions of source code must retain the above copyright notice, this
-+ * list of conditions and the following disclaimer.
-+ *
-+ * 2. Redistributions in binary form must reproduce the above copyright notice,
-+ * this list of conditions and the following disclaimer in the documentation
-+ * and/or other materials provided with the distribution.
-+ *
-+ * 3. Neither the name of the copyright holder nor the names of its
-+ * contributors may be used to endorse or promote products derived from
-+ * this software without specific prior written permission.
-+ *
-+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-+ *
-+ **************************************************************************************************/
-+/* \file
-+   \brief Helper functions for mapping CUTLASS concepts to cuBLAS.
-+*/ -+ -+#pragma once -+ -+#if CUTLASS_ENABLE_CUBLAS -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+#include "cutlass/blas3.h" -+ -+#include "options.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Converts a cuBLAS status to cutlass::Status -+Status get_cutlass_status(cublasStatus_t cublas); -+ -+/// Converts a cuBLASS status to cutlass::profiler::Disposition -+Disposition get_cutlass_disposition(cublasStatus_t cublas_status); -+ -+/// Maps a CUTLASS tensor layout to a cuBLAS transpose operation -+bool get_cublas_transpose_operation( -+ cublasOperation_t &operation, -+ library::LayoutTypeID layout, -+ library::ComplexTransform transform = library::ComplexTransform::kNone); -+ -+/// Maps a CUTLASS numeric type to a cuBLAS data type enumeration -+bool get_cublas_datatype(cublasDataType_t &data_type, library::NumericTypeID element_type); -+ -+/// Gets the cublas algorithm given threadblock tile dimensions and math opcode class -+cublasGemmAlgo_t get_cublas_gemm_algo( -+ int cta_m, -+ int cta_n, -+ int cta_k, -+ library::OpcodeClassID opcode_class); -+ -+/// Returns a status if cuBLAS can satisfy a particular GEMM description -+Status cublas_satisfies(library::GemmDescription const &desc); -+ -+/// Returns a status if cuBLAS can satisfy a particular RankK description -+Status cublas_satisfies(library::RankKDescription const &desc); -+ -+/// Returns a status if cuBLAS can satisfy a particular TRMM description -+Status cublas_satisfies(library::TrmmDescription const &desc); -+ -+/// Returns a status if cuBLAS can satisfy a particular SYMM/HEMM description -+Status cublas_satisfies(library::SymmDescription const &desc); -+ -+/// This is a helper class to create cublasHandle_t automatically on CublasCreate object creation and -+/// to destroy cublasHandle_t on CublasCreate object destruction. -+/// Additionaly, it provides implicit cast from CublasCreate's object to cublasHandle_t's object -+class CublasCreate { -+private: -+ cublasHandle_t handle; -+ cublasStatus_t status; -+ -+public: -+ CublasCreate() { -+ status = cublasCreate(&handle); -+ } -+ -+ ~CublasCreate() { -+ cublasDestroy(handle); -+ } -+ -+ /// Implicit cast CublasCreate object to cublasHandle_t -+ operator cublasHandle_t() const { return handle; } -+ -+ /// returns cublasStatus_t for handle creation -+ cublasStatus_t get_cublas_create_status() { return status; } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Selects one or more cuBLAS algorithms. -+static void select_cublas_algorithms( -+ std::vector &algorithms, -+ Options const &options, -+ library::GemmDescription const &op_desc) { -+ -+ library::OpcodeClassID const & opcode_class = -+ op_desc.tile_description.math_instruction.opcode_class; -+ -+ switch (options.library.algorithm_mode) { -+ case AlgorithmMode::kMatching: -+ { -+ algorithms.push_back(get_cublas_gemm_algo( -+ op_desc.tile_description.threadblock_shape.m(), -+ op_desc.tile_description.threadblock_shape.n(), -+ op_desc.tile_description.threadblock_shape.k(), -+ opcode_class)); -+ break; -+ } -+ -+ case AlgorithmMode::kBest: -+ { -+ // Choose first enumerated mode. 
If none are enumerated, choose based on opcode class -+ // and evaluate all of them. -+ -+ if (options.library.algorithms.empty()) { -+ // Enumerate all algorithms -+ if (opcode_class == library::OpcodeClassID::kSimt) { -+ -+ for (int algo = CUBLAS_GEMM_DEFAULT; -+ algo <= CUBLAS_GEMM_ALGO23; -+ ++algo) { -+ -+ algorithms.push_back(cublasGemmAlgo_t(algo)); -+ } -+ } -+ else { -+ -+ for (int algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; -+ algo <= CUBLAS_GEMM_ALGO15_TENSOR_OP; -+ ++algo) { -+ -+ algorithms.push_back(cublasGemmAlgo_t(algo)); -+ } -+ } -+ } -+ else { -+ // Use the listed algorithms -+ algorithms.reserve(options.library.algorithms.size()); -+ -+ for (int algo : options.library.algorithms) { -+ algorithms.push_back(reinterpret_cast(algo)); -+ } -+ } -+ -+ break; -+ } -+ -+ case AlgorithmMode::kDefault: -+ { -+ -+ // Use the library's default algorithm -+ algorithms.push_back((opcode_class == library::OpcodeClassID::kSimt ? -+ CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP)); -+ -+ break; -+ } -+ default: -+ { -+ break; -+ } -+ } -+} -+ -+/// Dispatcher to cublasGemmEx() -+struct cublasGemmExDispatcher { -+ -+ // -+ // Data members -+ // -+ library::GemmUniversalConfiguration configuration; -+ library::GemmUniversalArguments arguments; -+ -+ // cublass-specific data structures to fill cublas API call arguments -+ cublasOperation_t trans_A; -+ cublasOperation_t trans_B; -+ cudaDataType_t data_type_A; -+ cudaDataType_t data_type_B; -+ cudaDataType_t data_type_C; -+ cudaDataType_t compute_data_type; -+ -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ cublasComputeType_t compute_type; -+#endif -+ -+ cublasGemmAlgo_t algo; -+ Status status; -+ -+ // -+ // Methods -+ // -+ -+ cublasGemmExDispatcher( -+ library::GemmDescription const &op_desc, -+ library::GemmUniversalConfiguration configuration_, -+ library::GemmUniversalArguments arguments_, -+ cublasGemmAlgo_t algorithm = CUBLAS_GEMM_DFALT -+ ); -+ -+ /// Executes GEMM using these arguments -+ cublasStatus_t operator()(cublasHandle_t handle); -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Dispatcher to cublas rank k update kernels -+struct cublasRankKDispatcher { -+ -+ // -+ // Data members -+ // -+ library::RankKConfiguration configuration; -+ library::RankKArguments arguments; -+ -+ // cublass-specific data structures to fill cublas API call arguments -+ cublasOperation_t trans_A; -+ cublasFillMode_t uplo; -+ cudaDataType_t data_type_A; -+ cudaDataType_t data_type_C; -+ cudaDataType_t compute_data_type; -+ -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ cublasComputeType_t compute_type; -+#endif -+ -+ int num_ranks; //(rank-k or rank-2k) -+ BlasMode blas_mode; //(symmetric or hermitian) -+ Status status; -+ -+ // -+ // Methods -+ // -+ -+ cublasRankKDispatcher( -+ library::RankKDescription const &op_desc, -+ library::RankKConfiguration configuration_, -+ library::RankKArguments arguments_ -+ ); -+ -+ /// Executes RankK using these arguments -+ cublasStatus_t operator()(cublasHandle_t handle); -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Dispatcher to cublasTrmm() -+struct cublasTrmmDispatcher { -+ -+ // -+ // Data members -+ // -+ library::TrmmConfiguration configuration; -+ library::TrmmArguments arguments; -+ -+ // cublass-specific data structures to fill cublas API call arguments -+ cublasOperation_t trans_A; -+ cublasSideMode_t side; -+ cublasFillMode_t uplo; -+ cublasDiagType_t diag; -+ cudaDataType_t 
data_type_A; -+ cudaDataType_t data_type_B; -+ cudaDataType_t data_type_D; -+ cudaDataType_t compute_data_type; -+ -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ cublasComputeType_t compute_type; -+#endif -+ -+ Status status; -+ -+ // -+ // Methods -+ // -+ -+ cublasTrmmDispatcher( -+ library::TrmmDescription const &op_desc, -+ library::TrmmConfiguration configuration_, -+ library::TrmmArguments arguments_ -+ ); -+ -+ /// Executes TRMM using these arguments -+ cublasStatus_t operator()(cublasHandle_t handle); -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Dispatcher to cublas symm/hemm update kernels -+struct cublasSymmDispatcher { -+ -+ // -+ // Data members -+ // -+ library::SymmConfiguration configuration; -+ library::SymmArguments arguments; -+ -+ // cublass-specific data structures to fill cublas API call arguments -+ cublasSideMode_t side; -+ cublasFillMode_t uplo; -+ cudaDataType_t data_type_A; -+ cudaDataType_t data_type_B; -+ cudaDataType_t data_type_C; -+ cudaDataType_t compute_data_type; -+ -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ cublasComputeType_t compute_type; -+#endif -+ -+ BlasMode blas_mode; //(symmetric or hermitian) -+ Status status; -+ -+ // -+ // Methods -+ // -+ -+ cublasSymmDispatcher( -+ library::SymmDescription const &op_desc, -+ library::SymmConfiguration configuration_, -+ library::SymmArguments arguments_ -+ ); -+ -+ /// Executes Symm using these arguments -+ cublasStatus_t operator()(cublasHandle_t handle); -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace detail -+ -+} // namespace profiler -+} // namespace cutlass -+ -+ -+#endif // #if CUTLASS_ENABLE_CUBLAS -diff --git a/3rdparty/cutlass/tools/profiler/src/cudnn_helpers.h b/3rdparty/cutlass/tools/profiler/src/cudnn_helpers.h -new file mode 100644 -index 0000000..2f02382 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/cudnn_helpers.h -@@ -0,0 +1,590 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Helper functions for mapping CUTLASS concepts to cuDNN. -+ -+*/ -+ -+#pragma once -+#if CUTLASS_ENABLE_CUDNN -+#include -+#include -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/util/device_memory.h" -+#include "cutlass/library/library.h" -+#include "enumerated_types.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Converts a cuDNN status to cutlass::Status -+Status get_cutlass_status(cudnnStatus_t cudnn_status); -+ -+/// Converts a cuDNN status to cutlass::profiler::Disposition -+Disposition get_cutlass_disposition(cudnnStatus_t cudnn_status); -+ -+/// Checks cudnnStatus_t converts to cutlas status and returns if Status::kSuccess o.w. throws exception -+Status checkCudnnErr(cudnnStatus_t cudnn_status); -+ -+/// Maps a CUTLASS conv mode to a cuDNN conv mode enumeration -+bool get_cudnn_conv_mode(cudnnConvolutionMode_t &cudnn_conv_mode, conv::Mode conv_mode); -+ -+/// Maps a CUTLASS layout type to a cuDNN data type enumeration -+bool get_cudnn_layout(cudnnTensorFormat_t &cudnn_layout, library::LayoutTypeID layout); -+ -+/// Maps a CUTLASS numeric type to a cuDNN data type enumeration -+bool get_cudnn_datatype(cudnnDataType_t &cudnn_element_type, library::NumericTypeID element_type); -+ -+/// Maps CUTLASS math OpcodeClassID and MathOperationID to cuDNN math_type -+bool get_cudnn_mathtype(cudnnMathType_t &cudnn_math_type, library::ConvDescription const &conv_desc); -+ -+/// Returns a status if cudnn can satisfy a particular Conv2d description -+Status cudnn_satisfies(library::ConvDescription const &desc, library::Conv2dConfiguration const &configuration); -+ -+/// Returns a status if cudnn can satisfy a particular Conv3d description -+Status cudnn_satisfies(library::ConvDescription const &desc, library::Conv3dConfiguration const &configuration); -+ -+/// Cudnn compute type seems to be hardcoded to float (To handle a possible cudnn issue) -+float cast_cudnn_compute_type_to_float(library::NumericTypeID type, void const * src); -+ -+ -+/// This is a helper class to create cudnnHandle_t automatically on CudnnCreate object creation and -+/// to destroy cudnnHandle_t on CudnnCreate object destruction. 
-+/// Additionaly, it provides implicit cast from CudnnCreate's object to cudnnHandle_t's object -+class CudnnCreate { -+private: -+ cudnnHandle_t handle; -+ cudnnStatus_t status; -+ -+public: -+ CudnnCreate() { -+ status = cudnnCreate(&handle); -+ } -+ -+ ~CudnnCreate() { -+ cudnnDestroy(handle); -+ } -+ -+ /// Implicit cast CudnnCreate object to cudnnHandle_t -+ operator cudnnHandle_t() const { return handle; } -+ -+ /// returns cudnnStatus_t for handle creation -+ cudnnStatus_t get_cudnn_create_status() { return status; } -+}; -+ -+ -+namespace detail { -+ -+/// Dispatcher to cudnn convolution operators -+struct cudnnConvDispatcher { -+ -+ // -+ // Data members -+ // -+ //library::Conv2dConfiguration configuration; -+ library::ConvArguments arguments; -+ library::ConvKind conv_kind; -+ -+ // cudnn-specific data structures to fill cudnn API call arguments -+ // cudnn activation, filter, and output descriptors -+ cudnnTensorDescriptor_t activation_desc; -+ cudnnFilterDescriptor_t filter_desc; -+ cudnnTensorDescriptor_t output_desc; -+ cudnnConvolutionDescriptor_t conv_desc; -+ -+ // cudnn datatypes -+ cudnnDataType_t data_type_activation; -+ cudnnDataType_t data_type_filter; -+ cudnnDataType_t data_type_output; -+ -+ // cudnn layouts -+ cudnnTensorFormat_t layout_activation; -+ cudnnTensorFormat_t layout_filter; -+ cudnnTensorFormat_t layout_output; -+ -+ // cudnn convolution mode -+ cudnnConvolutionMode_t conv_mode; -+ -+ // cudnn math type (tensorop, tensorop with conversion, simt) -+ cudnnMathType_t math_type; -+ -+ // cudnn compute data type -+ cudnnDataType_t compute_type; -+ -+ // cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue) -+ float alpha; -+ float beta; -+ -+ // cudnn workspace -+ size_t workspace_size_in_bytes = 0; -+ cutlass::device_memory::allocation workspace; -+ -+ // select cudnn's implicit gemm precomputed algorithm with tensor operations -+ static cudnnConvolutionFwdAlgo_t const fprop_algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; -+ static cudnnConvolutionBwdDataAlgo_t const dgrad_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; -+ static cudnnConvolutionBwdFilterAlgo_t const wgrad_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; -+ -+ Status status; -+ -+ // -+ // Methods -+ // -+ -+ // TODO: unify ctor cudnnConvDispatcher for conv2d and conv3d by unifying Conv2dConfigration -+ -+ // ctor for conv2d -+ cudnnConvDispatcher( -+ library::ConvDescription const &op_desc, -+ library::Conv2dConfiguration configuration, -+ library::ConvArguments arguments_, -+ cudnnHandle_t handle -+ ): -+ //configuration(configuration_), -+ arguments(arguments_), -+ conv_kind(op_desc.conv_kind), -+ status(Status::kSuccess) { -+ -+ bool good = true; -+ -+ // Get cudnn datatype, layout, and convolution mode from library::ConvDescription -+ good = (good && get_cudnn_datatype(data_type_activation, op_desc.A.element)); -+ good = (good && get_cudnn_datatype(data_type_filter, op_desc.B.element)); -+ good = (good && get_cudnn_datatype(data_type_output, op_desc.C.element)); -+ good = (good && get_cudnn_layout(layout_activation, op_desc.A.layout)); -+ good = (good && get_cudnn_layout(layout_filter, op_desc.B.layout)); -+ good = (good && get_cudnn_layout(layout_output, op_desc.C.layout)); -+ good = (good && get_cudnn_conv_mode(conv_mode, configuration.problem_size.mode)); -+ // Get cudnn mathtype (cudnnMathType_t) -+ good = (good && get_cudnn_mathtype(math_type, op_desc)); -+ good = (good && get_cudnn_datatype( -+ compute_type, -+ 
op_desc.tile_description.math_instruction.element_accumulator)); -+ // Check cutlass Conv2d description has equivalent operator in cudnn -+ if (!good) { -+ status = Status::kErrorNotSupported; -+ return; -+ } -+ // cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue) -+ alpha = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.alpha); -+ beta = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.beta); -+ -+ // Create convolution descriptor object -+ status = get_cutlass_status(cudnnCreateConvolutionDescriptor(&conv_desc)); -+ -+ // Configure convolution operator -+ std::vector padding {configuration.problem_size.pad_h, configuration.problem_size.pad_w}; -+ std::vector stride {configuration.problem_size.stride_h, configuration.problem_size.stride_w}; -+ std::vector dilation {configuration.problem_size.dilation_h, configuration.problem_size.dilation_w}; -+ -+ status = get_cutlass_status( -+ cudnnSetConvolutionNdDescriptor( -+ conv_desc, -+ op_desc.conv_dim, -+ padding.data(), -+ stride.data(), -+ dilation.data(), -+ conv_mode, -+ compute_type -+ )); -+ -+ // Set groups -+ status = get_cutlass_status(cudnnSetConvolutionGroupCount(conv_desc, configuration.problem_size.groups)); -+ -+ // Create activation, filter, and output descriptor objects -+ status = get_cutlass_status(cudnnCreateTensorDescriptor(&activation_desc)); -+ status = get_cutlass_status(cudnnCreateFilterDescriptor(&filter_desc)); -+ status = get_cutlass_status(cudnnCreateTensorDescriptor(&output_desc)); -+ -+ // Set activation, filter, and output descriptor -+ status = get_cutlass_status( -+ cudnnSetTensor4dDescriptor( -+ activation_desc, -+ layout_activation, -+ data_type_activation, -+ configuration.problem_size.N, -+ configuration.problem_size.C, -+ configuration.problem_size.H, -+ configuration.problem_size.W -+ )); -+ -+ status = get_cutlass_status( -+ cudnnSetFilter4dDescriptor( -+ filter_desc, -+ data_type_filter, -+ layout_filter, -+ configuration.problem_size.K, -+ configuration.problem_size.C / configuration.problem_size.groups, -+ configuration.problem_size.R, -+ configuration.problem_size.S -+ )); -+ -+ status = get_cutlass_status( -+ cudnnSetTensor4dDescriptor( -+ output_desc, -+ layout_output, -+ data_type_output, -+ configuration.problem_size.N, -+ configuration.problem_size.K, -+ configuration.problem_size.P, -+ configuration.problem_size.Q -+ )); -+ -+ // Set math instruction to tensor op -+ status = get_cutlass_status( -+ cudnnSetConvolutionMathType(conv_desc, math_type)); -+ -+ // Initialize workspace -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: -+ status = get_cutlass_status( -+ cudnnGetConvolutionForwardWorkspaceSize( -+ handle, -+ activation_desc, -+ filter_desc, -+ conv_desc, -+ output_desc, -+ fprop_algo, -+ &workspace_size_in_bytes -+ )); break; -+ case library::ConvKind::kDgrad: -+ status = get_cutlass_status( -+ cudnnGetConvolutionBackwardDataWorkspaceSize( -+ handle, -+ filter_desc, -+ output_desc, -+ conv_desc, -+ activation_desc, -+ dgrad_algo, -+ &workspace_size_in_bytes -+ )); break; -+ case library::ConvKind::kWgrad: -+ status = get_cutlass_status( -+ cudnnGetConvolutionBackwardFilterWorkspaceSize( -+ handle, -+ activation_desc, -+ output_desc, -+ conv_desc, -+ filter_desc, -+ wgrad_algo, -+ &workspace_size_in_bytes -+ )); break; -+ -+ } -+ -+ workspace = cutlass::device_memory::allocation(workspace_size_in_bytes); -+ } -+ -+ -+ // ctor for conv3d -+ cudnnConvDispatcher( -+ library::ConvDescription const &op_desc, 
-+ library::Conv3dConfiguration configuration, -+ library::ConvArguments arguments_, -+ cudnnHandle_t handle -+ ): -+ //configuration(configuration_), -+ arguments(arguments_), -+ conv_kind(op_desc.conv_kind), -+ status(Status::kSuccess) { -+ -+ bool good = true; -+ -+ // Get cudnn datatype, layout, and convolution mode from library::ConvDescription -+ good = (good && get_cudnn_datatype(data_type_activation, op_desc.A.element)); -+ good = (good && get_cudnn_datatype(data_type_filter, op_desc.B.element)); -+ good = (good && get_cudnn_datatype(data_type_output, op_desc.C.element)); -+ -+ good = (good && get_cudnn_layout(layout_activation, op_desc.A.layout)); -+ good = (good && get_cudnn_layout(layout_filter, op_desc.B.layout)); -+ good = (good && get_cudnn_layout(layout_output, op_desc.C.layout)); -+ -+ good = (good && get_cudnn_conv_mode(conv_mode, configuration.problem_size.mode)); -+ -+ // cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue) -+ alpha = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.alpha); -+ beta = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.beta); -+ -+ good = (good && get_cudnn_datatype( -+ compute_type, -+ op_desc.tile_description.math_instruction.element_accumulator)); -+ -+ // Check cutlass Conv2d description has equivalent operator in cudnn -+ if (!good) { -+ status = Status::kErrorNotSupported; -+ } -+ -+ // Create convolution descriptor object -+ status = get_cutlass_status(cudnnCreateConvolutionDescriptor(&conv_desc)); -+ -+ // Configure convolution operator -+ std::vector padding {configuration.problem_size.pad_d, configuration.problem_size.pad_h, configuration.problem_size.pad_w}; -+ std::vector stride {configuration.problem_size.stride_d, configuration.problem_size.stride_h, configuration.problem_size.stride_w}; -+ std::vector dilation {configuration.problem_size.dilation_d, configuration.problem_size.dilation_h, configuration.problem_size.dilation_w}; -+ -+ status = get_cutlass_status( -+ cudnnSetConvolutionNdDescriptor( -+ conv_desc, -+ op_desc.conv_dim, -+ padding.data(), -+ stride.data(), -+ dilation.data(), -+ conv_mode, -+ compute_type -+ )); -+ -+ // Set groups -+ status = get_cutlass_status(cudnnSetConvolutionGroupCount(conv_desc, configuration.problem_size.groups)); -+ -+ // Create activation, filter, and output descriptor objects -+ status = get_cutlass_status(cudnnCreateTensorDescriptor(&activation_desc)); -+ status = get_cutlass_status(cudnnCreateFilterDescriptor(&filter_desc)); -+ status = get_cutlass_status(cudnnCreateTensorDescriptor(&output_desc)); -+ -+ // Set activation descriptor -+ std::vector activation_extent { -+ configuration.problem_size.N, -+ configuration.problem_size.C, -+ configuration.problem_size.D, -+ configuration.problem_size.H, -+ configuration.problem_size.W -+ }; -+ -+ std::vector activation_stride { -+ configuration.layout_activations.stride()[3], -+ 1, -+ configuration.layout_activations.stride()[2], -+ configuration.layout_activations.stride()[1], -+ configuration.layout_activations.stride()[0] -+ }; -+ -+ status = get_cutlass_status( -+ cudnnSetTensorNdDescriptor( -+ activation_desc, -+ data_type_activation, -+ op_desc.conv_dim + 2, -+ activation_extent.data(), -+ activation_stride.data() -+ )); -+ -+ // Set filter descriptor -+ std::vector filter_extent { -+ configuration.problem_size.K, -+ configuration.problem_size.C, -+ configuration.problem_size.T, -+ configuration.problem_size.R, -+ configuration.problem_size.S -+ }; -+ -+ 
std::vector filter_stride { -+ configuration.layout_filters.stride()[3], -+ 1, -+ configuration.layout_filters.stride()[2], -+ configuration.layout_filters.stride()[1], -+ configuration.layout_filters.stride()[0] -+ }; -+ -+ status = get_cutlass_status( -+ cudnnSetFilterNdDescriptor( -+ filter_desc, -+ data_type_filter, -+ layout_filter, -+ op_desc.conv_dim + 2, -+ filter_extent.data() -+ )); -+ -+ -+ // Set output descriptor -+ std::vector output_extent { -+ configuration.problem_size.N, -+ configuration.problem_size.K, -+ configuration.problem_size.Z, -+ configuration.problem_size.P, -+ configuration.problem_size.Q -+ }; -+ -+ std::vector output_stride { -+ configuration.layout_output.stride()[3], -+ 1, -+ configuration.layout_output.stride()[2], -+ configuration.layout_output.stride()[1], -+ configuration.layout_output.stride()[0] -+ }; -+ -+ status = get_cutlass_status( -+ cudnnSetTensorNdDescriptor( -+ output_desc, -+ data_type_output, -+ op_desc.conv_dim + 2, -+ output_extent.data(), -+ output_stride.data() -+ )); -+ -+ // Set math instruction to tensor op -+ status = get_cutlass_status( -+ cudnnSetConvolutionMathType(conv_desc, math_type)); -+ -+ // Initialize workspace -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: -+ status = get_cutlass_status( -+ cudnnGetConvolutionForwardWorkspaceSize( -+ handle, -+ activation_desc, -+ filter_desc, -+ conv_desc, -+ output_desc, -+ fprop_algo, -+ &workspace_size_in_bytes -+ )); break; -+ case library::ConvKind::kDgrad: -+ status = get_cutlass_status( -+ cudnnGetConvolutionBackwardDataWorkspaceSize( -+ handle, -+ filter_desc, -+ output_desc, -+ conv_desc, -+ activation_desc, -+ dgrad_algo, -+ &workspace_size_in_bytes -+ )); break; -+ case library::ConvKind::kWgrad: -+ status = get_cutlass_status( -+ cudnnGetConvolutionBackwardFilterWorkspaceSize( -+ handle, -+ activation_desc, -+ output_desc, -+ conv_desc, -+ filter_desc, -+ wgrad_algo, -+ &workspace_size_in_bytes -+ )); break; -+ -+ } -+ -+ workspace = cutlass::device_memory::allocation(workspace_size_in_bytes); -+ } -+ -+ /// Executes Conv2d operater from cudnn library -+ cudnnStatus_t operator()(cudnnHandle_t handle) { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: -+ return cudnnConvolutionForward( -+ handle, -+ &alpha, -+ activation_desc, -+ activation(), -+ filter_desc, -+ filter(), -+ conv_desc, -+ fprop_algo, -+ workspace.get(), -+ workspace_size_in_bytes, -+ &beta, -+ output_desc, -+ arguments.D -+ ); -+ case library::ConvKind::kDgrad: -+ return cudnnConvolutionBackwardData( -+ handle, -+ &alpha, -+ filter_desc, -+ filter(), -+ output_desc, -+ output(), -+ conv_desc, -+ dgrad_algo, -+ workspace.get(), -+ workspace_size_in_bytes, -+ &beta, -+ activation_desc, -+ arguments.D -+ ); -+ case library::ConvKind::kWgrad: -+ return cudnnConvolutionBackwardFilter( -+ handle, -+ &alpha, -+ activation_desc, -+ activation(), -+ output_desc, -+ output(), -+ conv_desc, -+ wgrad_algo, -+ workspace.get(), -+ workspace_size_in_bytes, -+ &beta, -+ filter_desc, -+ arguments.D -+ ); -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns Actviation Tensor -+ void const * activation() const { -+ switch(conv_kind) { -+ case library::ConvKind::kFprop : return arguments.A; -+ case library::ConvKind::kDgrad : return arguments.C; -+ case library::ConvKind::kWgrad : return arguments.B; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns Filter Tensor -+ void const 
*filter() const { -+ switch(conv_kind) { -+ case library::ConvKind::kFprop : return arguments.B; -+ case library::ConvKind::kDgrad : return arguments.B; -+ case library::ConvKind::kWgrad : return arguments.C; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns Output Tensor -+ void const *output() const { -+ switch(conv_kind) { -+ case library::ConvKind::kFprop : return arguments.C; -+ case library::ConvKind::kDgrad : return arguments.A; -+ case library::ConvKind::kWgrad : return arguments.A; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+}; -+ -+} // namespace detail -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif //#if CUTLASS_ENABLE_CUDNN -+} // namespace profiler -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/profiler/src/cutlass_profiler.cu b/3rdparty/cutlass/tools/profiler/src/cutlass_profiler.cu -new file mode 100644 -index 0000000..026ffdf ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/cutlass_profiler.cu -@@ -0,0 +1,226 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Execution environment -+*/ -+ -+#include -+#include -+ -+// Profiler includes -+#include "cutlass_profiler.h" -+#include "gemm_operation_profiler.h" -+#include "rank_k_operation_profiler.h" -+#include "rank_2k_operation_profiler.h" -+#include "trmm_operation_profiler.h" -+#include "symm_operation_profiler.h" -+#include "conv2d_operation_profiler.h" -+#include "conv3d_operation_profiler.h" -+#include "sparse_gemm_operation_profiler.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+CutlassProfiler::CutlassProfiler( -+ Options const &options -+): -+ options_(options) { -+ -+ operation_profilers_.emplace_back(new GemmOperationProfiler(options)); -+ -+ operation_profilers_.emplace_back(new SparseGemmOperationProfiler(options)); -+ -+ operation_profilers_.emplace_back(new Conv2dOperationProfiler(options)); -+ -+ operation_profilers_.emplace_back(new Conv3dOperationProfiler(options)); -+ -+ operation_profilers_.emplace_back(new RankKOperationProfiler(options)); -+ -+ operation_profilers_.emplace_back(new Rank2KOperationProfiler(options)); -+ -+ operation_profilers_.emplace_back(new TrmmOperationProfiler(options)); -+ -+ operation_profilers_.emplace_back(new SymmOperationProfiler(options)); -+} -+ -+CutlassProfiler::~CutlassProfiler() { -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Execute the program -+int CutlassProfiler::operator()() { -+ -+ if (options_.cmdline.num_naked_args() > 0) { -+ std::cerr << "Unknown args: \n"; -+ options_.cmdline.print_naked_args(std::cerr); -+ std::cerr << "\n\n\n"; -+ -+ print_usage_(std::cout); -+ return 1; -+ } -+ -+ if (options_.about.help) { -+ if (options_.operation_kind == library::OperationKind::kInvalid) { -+ print_usage_(std::cout); -+ } -+ else { -+ for (auto & profiler : operation_profilers_) { -+ if (profiler->kind() == options_.operation_kind) { -+ profiler->print_usage(std::cout); -+ profiler->print_examples(std::cout); -+ return 0; -+ } -+ } -+ } -+ return 0; -+ } -+ else if (options_.about.version) { -+ options_.about.print_version(std::cout); -+ -+ std::cout << std::endl; -+ return 0; -+ } -+ else if (options_.about.device_info) { -+ options_.device.print_device_info(std::cout); -+ return 0; -+ } -+ -+ if (options_.execution_mode == ExecutionMode::kProfile || -+ options_.execution_mode == ExecutionMode::kDryRun || -+ options_.execution_mode == ExecutionMode::kTrace) { -+ -+ // Profiles all operations -+ profile_(); -+ } -+ else if (options_.execution_mode == ExecutionMode::kEnumerate) { -+ // Enumerates all operations -+ enumerate_(); -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Enumerates all operations -+void CutlassProfiler::enumerate_() { -+ -+} -+ -+/// Profiles all operations -+int CutlassProfiler::profile_() { -+ -+ int result = 0; -+ DeviceContext device_context; -+ -+ // For all profilers -+ for (auto & profiler : operation_profilers_) { -+ -+ if (options_.operation_kind == library::OperationKind::kInvalid || -+ options_.operation_kind == profiler->kind()) { -+ -+ result = profiler->profile_all(options_, library::Singleton::get().manifest, 
device_context); -+ -+ if (result) { -+ return result; -+ } -+ } -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Prints all options -+void CutlassProfiler::print_usage_(std::ostream &out) { -+ options_.print_usage(out); -+ -+ out << "\nOperations:\n\n"; -+ -+ // For all profilers -+ for (auto & profiler : operation_profilers_) { -+ -+ -+ std::string kind_str = library::to_string(profiler->kind()); -+ -+ size_t kAlignment = 40; -+ size_t columns = 0; -+ -+ if (kind_str.size() < kAlignment) { -+ columns = kAlignment - kind_str.size(); -+ } -+ -+ out << " " << kind_str << std::string(columns, ' ') << profiler->description() << "\n"; -+ -+ } -+ -+ out << "\n\nFor details about a particular function, specify the function name with --help.\n\nExample:\n\n" -+ << " $ cutlass_profiler --operation=Gemm --help\n\n" -+ << " $ cutlass_profiler --operation=RankK --help\n\n" -+ << " $ cutlass_profiler --operation=Trmm --help\n\n" -+ << " $ cutlass_profiler --operation=Symm --help\n\n" -+ << " $ cutlass_profiler --operation=Conv3d --help\n\n" -+ << " $ cutlass_profiler --operation=Conv2d --help\n\n" -+ << " $ cutlass_profiler --operation=SparseGemm --help\n\n" -+ ; -+} -+ -+/// Prints usage -+void CutlassProfiler::print_options_(std::ostream &out) { -+ options_.print_options(out); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Initializes the CUDA device -+void CutlassProfiler::initialize_device_() { -+ -+ cudaError_t result = cudaSetDevice(options_.device.device); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to set device."; -+ throw std::runtime_error("Failed to set device"); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/cutlass_profiler.h b/3rdparty/cutlass/tools/profiler/src/cutlass_profiler.h -new file mode 100644 -index 0000000..a3b0640 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/cutlass_profiler.h -@@ -0,0 +1,96 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Execution environment -+*/ -+ -+#pragma once -+// CUTLASS Library includes -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+#include "cutlass/library/singleton.h" -+ -+#include "options.h" -+#include "operation_profiler.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// CUTLASS Profiler application -+class CutlassProfiler { -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Performance testbench options -+ Options options_; -+ -+ /// Entry points for each operation -+ OperationProfilerVector operation_profilers_; -+ -+private: -+ -+ /// Prints usage -+ void print_usage_(std::ostream &); -+ -+ /// Prints usage -+ void print_options_(std::ostream &); -+ -+ /// Initializes the device -+ void initialize_device_(); -+ -+ /// Enumerates all operations -+ void enumerate_(); -+ -+ /// Profiles all operations -+ int profile_(); -+ -+public: -+ -+ CutlassProfiler(Options const &options); -+ ~CutlassProfiler(); -+ -+ /// Invokes profiling operations -+ int operator()(); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/debug.h b/3rdparty/cutlass/tools/profiler/src/debug.h -new file mode 100644 -index 0000000..83e2c33 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/debug.h -@@ -0,0 +1,56 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
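// Sketch of a driver for the functor-style application class declared above.
// The entry point is not part of this hunk; the Options and CommandLine
// constructors used here are assumptions modeled on CUTLASS's own main.cpp,
// not something this patch defines.
#include "cutlass/util/command_line.h"
#include "cutlass_profiler.h"
#include "options.h"

int main(int argc, char const *argv[]) {
  cutlass::CommandLine cmdline(argc, argv);        // parse argv once
  cutlass::profiler::Options options(cmdline);     // assumed constructor
  cutlass::profiler::CutlassProfiler profiler(options);
  return profiler();                               // operator() runs the app
}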
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include -+ -+//#define report(x) { std::cout << "\033[31m" << __FILE__ << ":" << __LINE__ << " " << x << "\033[0m" << std::endl; } -+//#define report(x) {} -+ -+// Enable/Disble Profiler debug prints -+//#define DEBUG_PROFILER -+ -+//RED 31m // profiler prints debug messages in red -+//YELLOW 33m // ir prints debug messages in yellow -+ -+#ifndef DEBUG_PROFILER -+#define debugprof(...) -+#else -+#define debugprof(...) do { \ -+ printf("\033[33m[DEBUG PROF] %s:%d | ", __FILE__, __LINE__); \ -+ printf(__VA_ARGS__); \ -+ printf("\033[0m\n"); \ -+ } while (0) -+#endif -diff --git a/3rdparty/cutlass/tools/profiler/src/device_allocation.cu b/3rdparty/cutlass/tools/profiler/src/device_allocation.cu -new file mode 100644 -index 0000000..e59c344 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/device_allocation.cu -@@ -0,0 +1,1681 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
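// The debugprof() macro above compiles away entirely unless DEBUG_PROFILER is
// defined. A self-contained sketch of the same pattern with illustrative
// names; the do { ... } while (0) wrapper lets the macro be used like a single
// statement, e.g. inside an unbraced if/else.
#include <cstdio>

//#define MY_TRACE_ENABLED   // uncomment to enable trace output

#ifndef MY_TRACE_ENABLED
#define my_trace(...)
#else
#define my_trace(...)                                      \
  do {                                                     \
    std::printf("[TRACE] %s:%d | ", __FILE__, __LINE__);   \
    std::printf(__VA_ARGS__);                              \
    std::printf("\n");                                     \
  } while (0)
#endif

int run_kernel(int problem_size) {
  my_trace("profiling problem of size %d", problem_size);  // no-op when disabled
  return 0;
}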
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Execution environment -+*/ -+ -+#include -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+ -+#include "cutlass/util/reference/device/tensor_compare.h" -+#include "cutlass/util/reference/device/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/library/util.h" -+ -+#include "device_allocation.h" -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+size_t DeviceAllocation::bytes(library::NumericTypeID type, size_t capacity) { -+ return size_t(cutlass::library::sizeof_bits(type)) * capacity / 8; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+static std::vector get_packed_layout_stride(std::vector const &extent) { -+ -+ typename Layout::TensorCoord extent_coord; -+ typename Layout::Stride stride_coord; -+ -+ if (extent.size() != size_t(Layout::kRank)) { -+ throw std::runtime_error("Layout does not have same rank as extent vector."); -+ } -+ -+ for (int i = 0; i < Layout::kRank; ++i) { -+ extent_coord[i] = extent.at(i); -+ } -+ -+ std::vector stride; -+ stride.resize(Layout::kStrideRank, 0); -+ -+ Layout layout = Layout::packed(extent_coord); -+ stride_coord = layout.stride(); -+ -+ for (int i = 0; i < Layout::kStrideRank; ++i) { -+ stride.at(i) = (int64_t)stride_coord[i]; -+ } -+ -+ return stride; -+} -+ -+/// Returns the stride of a packed layout -+std::vector DeviceAllocation::get_packed_layout( -+ library::LayoutTypeID layout_id, -+ std::vector const &extent) { -+ -+ std::vector stride; -+ -+ switch (layout_id) { -+ case library::LayoutTypeID::kColumnMajor: -+ stride = get_packed_layout_stride(extent); -+ break; -+ case library::LayoutTypeID::kRowMajor: -+ stride = get_packed_layout_stride(extent); -+ break; -+ case library::LayoutTypeID::kColumnMajorInterleavedK2: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kRowMajorInterleavedK2: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kColumnMajorInterleavedK4: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kRowMajorInterleavedK4: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kColumnMajorInterleavedK16: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kRowMajorInterleavedK16: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kColumnMajorInterleavedK32: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kRowMajorInterleavedK32: -+ stride = 
get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kColumnMajorInterleavedK64: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kRowMajorInterleavedK64: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kTensorNCHW: -+ stride = get_packed_layout_stride(extent); -+ break; -+ case library::LayoutTypeID::kTensorNHWC: -+ stride = get_packed_layout_stride(extent); -+ break; -+ case library::LayoutTypeID::kTensorNDHWC: -+ stride = get_packed_layout_stride(extent); -+ break; -+ case library::LayoutTypeID::kTensorNC32HW32: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kTensorNC64HW64: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kTensorC32RSK32: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kTensorC64RSK64: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ default: break; -+ } -+ -+ return stride; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template to use CUTLASS Layout functions to -+template -+static size_t construct_layout_( -+ void *bytes, -+ library::LayoutTypeID layout_id, -+ std::vector const &extent, -+ std::vector &stride) { -+ -+ if (extent.size() != Layout::kRank) { -+ throw std::runtime_error( -+ "Layout must have same rank as extent vector."); -+ } -+ -+ if (Layout::kStrideRank && stride.empty()) { -+ -+ stride = get_packed_layout_stride(extent); -+ -+ return construct_layout_( -+ bytes, -+ layout_id, -+ extent, -+ stride); -+ } -+ else if (Layout::kStrideRank && stride.size() != Layout::kStrideRank) { -+ throw std::runtime_error( -+ "Layout requires either empty stride or stride vector matching Layout::kStrideRank"); -+ } -+ -+ typename Layout::Stride stride_coord; -+ for (int i = 0; i < Layout::kStrideRank; ++i) { -+ stride_coord[i] = (int)stride.at(i); -+ } -+ -+ typename Layout::TensorCoord extent_coord; -+ for (int i = 0; i < Layout::kRank; ++i) { -+ extent_coord[i] = extent.at(i); -+ } -+ -+ // Construct the CUTLASS layout object from the stride object -+ Layout layout(stride_coord); -+ -+ // Pack it into bytes -+ if (bytes) { -+ *reinterpret_cast(bytes) = layout; -+ } -+ -+ // Return capacity -+ size_t capacity_ = layout.capacity(extent_coord); -+ -+ return capacity_; -+} -+ -+/// returns the capacity needed -+size_t DeviceAllocation::construct_layout( -+ void *bytes, -+ library::LayoutTypeID layout_id, -+ std::vector const &extent, -+ std::vector &stride) { -+ -+ switch (layout_id) { -+ case library::LayoutTypeID::kColumnMajor: -+ return construct_layout_(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kRowMajor: -+ return construct_layout_(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kColumnMajorInterleavedK2: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kRowMajorInterleavedK2: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kColumnMajorInterleavedK4: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kRowMajorInterleavedK4: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kColumnMajorInterleavedK16: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kRowMajorInterleavedK16: -+ return 
construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kColumnMajorInterleavedK32: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kRowMajorInterleavedK32: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kColumnMajorInterleavedK64: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kRowMajorInterleavedK64: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kTensorNCHW: -+ return construct_layout_(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kTensorNHWC: -+ return construct_layout_(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kTensorNDHWC: -+ return construct_layout_(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kTensorNC32HW32: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kTensorNC64HW64: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kTensorC32RSK32: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kTensorC64RSK64: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ default: break; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+DeviceAllocation::DeviceAllocation(): -+ type_(library::NumericTypeID::kInvalid), -+ batch_stride_(0), -+ capacity_(0), -+ pointer_(nullptr), -+ layout_(library::LayoutTypeID::kUnknown), -+ batch_count_(1) { -+ -+} -+ -+DeviceAllocation::DeviceAllocation( -+ library::NumericTypeID type, -+ size_t capacity -+): -+ type_(type), batch_stride_(capacity), capacity_(capacity), pointer_(nullptr), -+ layout_(library::LayoutTypeID::kUnknown), batch_count_(1) { -+ -+ cudaError_t result = cudaMalloc((void **)&pointer_, bytes(type, capacity)); -+ -+ if (result != cudaSuccess) { -+ type_ = library::NumericTypeID::kInvalid; -+ capacity_ = 0; -+ pointer_ = nullptr; -+ throw std::bad_alloc(); -+ } -+} -+ -+DeviceAllocation::DeviceAllocation( -+ library::NumericTypeID type, -+ library::LayoutTypeID layout_id, -+ std::vector const &extent, -+ std::vector const &stride, -+ int batch_count -+): -+ type_(type), batch_stride_(size_t(0)), capacity_(size_t(0)), pointer_(nullptr), batch_count_(1) { -+ -+ reset(type, layout_id, extent, stride, batch_count); -+} -+ -+DeviceAllocation::~DeviceAllocation() { -+ if (pointer_) { -+ cudaFree(pointer_); -+ } -+} -+ -+DeviceAllocation &DeviceAllocation::reset() { -+ if (pointer_) { -+ cudaFree(pointer_); -+ } -+ -+ type_ = library::NumericTypeID::kInvalid; -+ batch_stride_ = 0; -+ capacity_ = 0; -+ pointer_ = nullptr; -+ layout_ = library::LayoutTypeID::kUnknown; -+ stride_.clear(); -+ extent_.clear(); -+ tensor_ref_buffer_.clear(); -+ batch_count_ = 1; -+ -+ return *this; -+} -+ -+DeviceAllocation &DeviceAllocation::reset(library::NumericTypeID type, size_t capacity) { -+ -+ reset(); -+ -+ type_ = type; -+ batch_stride_ = capacity; -+ capacity_ = capacity; -+ -+ cudaError_t result = cudaMalloc((void **)&pointer_, bytes(type_, capacity_)); -+ if (result != cudaSuccess) { -+ throw std::bad_alloc(); -+ } -+ -+ layout_ = library::LayoutTypeID::kUnknown; -+ stride_.clear(); -+ extent_.clear(); -+ batch_count_ = 1; -+ -+ tensor_ref_buffer_.resize(sizeof(pointer_), 0); -+ std::memcpy(tensor_ref_buffer_.data(), 
&pointer_, sizeof(pointer_)); -+ -+ return *this; -+} -+ -+/// Allocates memory for a given layout and tensor -+DeviceAllocation &DeviceAllocation::reset( -+ library::NumericTypeID type, -+ library::LayoutTypeID layout_id, -+ std::vector const &extent, -+ std::vector const &stride, -+ int batch_count) { -+ -+ reset(); -+ -+ tensor_ref_buffer_.resize(sizeof(pointer_) + (sizeof(int64_t) * library::get_layout_stride_rank(layout_id)), 0); -+ -+ type_ = type; -+ -+ layout_ = layout_id; -+ stride_ = stride; -+ extent_ = extent; -+ batch_count_ = batch_count; -+ -+ batch_stride_ = construct_layout( -+ tensor_ref_buffer_.data() + sizeof(pointer_), -+ layout_id, -+ extent, -+ stride_); -+ -+ capacity_ = batch_stride_ * batch_count_; -+ -+ cudaError_t result = cudaMalloc((void **)&pointer_, bytes(type, capacity_)); -+ if (result != cudaSuccess) { -+ throw std::bad_alloc(); -+ } -+ -+ std::memcpy(tensor_ref_buffer_.data(), &pointer_, sizeof(pointer_)); -+ -+ return *this; -+} -+ -+bool DeviceAllocation::good() const { -+ return (capacity_ && pointer_); -+} -+ -+library::NumericTypeID DeviceAllocation::type() const { -+ return type_; -+} -+ -+void *DeviceAllocation::data() const { -+ return pointer_; -+} -+ -+void *DeviceAllocation::batch_data(int batch_idx) const { -+ return static_cast(data()) + batch_stride_bytes() * batch_idx; -+} -+ -+library::LayoutTypeID DeviceAllocation::layout() const { -+ return layout_; -+} -+ -+std::vector const & DeviceAllocation::stride() const { -+ return stride_; -+} -+ -+/// Gets the extent vector -+std::vector const & DeviceAllocation::extent() const { -+ return extent_; -+} -+ -+/// Gets the number of adjacent tensors in memory -+int DeviceAllocation::batch_count() const { -+ return batch_count_; -+} -+ -+/// Gets the stride (in units of elements) beteween items -+int64_t DeviceAllocation::batch_stride() const { -+ return batch_stride_; -+} -+ -+/// Gets the stride (in units of bytes) beteween items -+int64_t DeviceAllocation::batch_stride_bytes() const { -+ return bytes(type_, batch_stride_); -+} -+ -+size_t DeviceAllocation::capacity() const { -+ return capacity_; -+} -+ -+size_t DeviceAllocation::bytes() const { -+ return bytes(type_, capacity_); -+} -+ -+/// Copies from an equivalent-sized tensor in device memory -+void DeviceAllocation::copy_from_device(void const *ptr) { -+ cudaError_t result = cudaMemcpy(data(), ptr, bytes(), cudaMemcpyDeviceToDevice); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("Failed device-to-device copy"); -+ } -+} -+ -+/// Copies from an equivalent-sized tensor in device memory -+void DeviceAllocation::copy_from_host(void const *ptr) { -+ cudaError_t result = cudaMemcpy(data(), ptr, bytes(), cudaMemcpyHostToDevice); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("Failed device-to-device copy"); -+ } -+} -+ -+/// Copies from an equivalent-sized tensor in device memory -+void DeviceAllocation::copy_to_host(void *ptr) { -+ cudaError_t result = cudaMemcpy(ptr, data(), bytes(), cudaMemcpyDeviceToHost); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("Failed device-to-device copy"); -+ } -+} -+ -+void DeviceAllocation::initialize_random_device(int seed, Distribution dist) { -+ if (!good()) { -+ throw std::runtime_error("Attempting to initialize invalid allocation."); -+ } -+ -+ // Instantiate calls to CURAND here. This file takes a long time to compile for -+ // this reason. 
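// DeviceAllocation sizes its buffers as sizeof_bits(type) * capacity / 8, so
// sub-byte element types (int4, uint2, 1-bit) pack several elements per byte.
// Standalone sketch of that accounting with an illustrative bit-width table
// (not the CUTLASS NumericTypeID enum):
#include <cstddef>

enum class ElemType { kB1 = 1, kS4 = 4, kU8 = 8, kF16 = 16, kF32 = 32 };

constexpr std::size_t sizeof_bits(ElemType t) {
  return static_cast<std::size_t>(t);
}

// Bytes needed to store `capacity` elements of type `t`.
constexpr std::size_t bytes(ElemType t, std::size_t capacity) {
  return sizeof_bits(t) * capacity / 8;
}

static_assert(bytes(ElemType::kS4, 1024) == 512, "two int4 values per byte");
static_assert(bytes(ElemType::kB1, 1024) == 128, "eight 1-bit values per byte");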
-+ -+ switch (type_) { -+ case library::NumericTypeID::kF16: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kBF16: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kTF32: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kF32: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kCBF16: -+ cutlass::reference::device::BlockFillRandom>( -+ reinterpret_cast *>(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kCTF32: -+ cutlass::reference::device::BlockFillRandom>( -+ reinterpret_cast *>(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kCF32: -+ cutlass::reference::device::BlockFillRandom>( -+ reinterpret_cast *>(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kF64: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kCF64: -+ cutlass::reference::device::BlockFillRandom>( -+ reinterpret_cast *>(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS2: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS4: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS8: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS16: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS32: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS64: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kB1: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU2: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU4: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU8: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU16: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU32: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU64: -+ 
cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ default: break; -+ } -+} -+ -+void DeviceAllocation::initialize_random_host(int seed, Distribution dist) { -+ if (!good()) { -+ throw std::runtime_error("Attempting to initialize invalid allocation."); -+ } -+ -+ std::vector host_data(bytes()); -+ -+ switch (type_) { -+ case library::NumericTypeID::kF16: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kBF16: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kTF32: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kF32: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kCF16: -+ cutlass::reference::host::BlockFillRandom>( -+ reinterpret_cast *>(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kCBF16: -+ cutlass::reference::host::BlockFillRandom>( -+ reinterpret_cast *>(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kCTF32: -+ cutlass::reference::host::BlockFillRandom>( -+ reinterpret_cast *>(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kCF32: -+ cutlass::reference::host::BlockFillRandom>( -+ reinterpret_cast *>(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kF64: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kCF64: -+ cutlass::reference::host::BlockFillRandom>( -+ reinterpret_cast *>(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS2: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS4: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS8: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS16: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS32: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS64: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kB1: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU2: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU4: -+ 
cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU8: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU16: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU32: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU64: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ default: break; -+ } -+ -+ copy_from_host(host_data.data()); -+} -+ -+void DeviceAllocation::initialize_random_sparsemeta_device(int seed, int MetaSizeInBits) { -+ if (!good()) { -+ throw std::runtime_error("Attempting to initialize invalid allocation."); -+ } -+ -+ // Instantiate calls to CURAND here. This file takes a long time to compile for -+ // this reason. -+ -+ switch (type_) { -+ case library::NumericTypeID::kU16: -+ cutlass::reference::device::BlockFillRandomSparseMeta( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ MetaSizeInBits -+ ); -+ break; -+ case library::NumericTypeID::kU32: -+ cutlass::reference::device::BlockFillRandomSparseMeta( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ MetaSizeInBits -+ ); -+ break; -+ default: -+ break; -+ } -+} -+ -+void DeviceAllocation::initialize_random_sparsemeta_host(int seed, int MetaSizeInBits) { -+ if (!good()) { -+ throw std::runtime_error("Attempting to initialize invalid allocation."); -+ } -+ -+ std::vector host_data(bytes()); -+ -+ switch (type_) { -+ case library::NumericTypeID::kS16: -+ cutlass::reference::host::BlockFillRandomSparseMeta( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ MetaSizeInBits -+ ); -+ break; -+ case library::NumericTypeID::kS32: -+ cutlass::reference::host::BlockFillRandomSparseMeta( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ MetaSizeInBits -+ ); -+ break; -+ default: -+ break; -+ } -+ -+ copy_from_host(host_data.data()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns true if two blocks have exactly the same value -+bool DeviceAllocation::block_compare_equal( -+ library::NumericTypeID numeric_type, -+ void const *ptr_A, -+ void const *ptr_B, -+ size_t capacity) { -+ -+ switch (numeric_type) { -+ case library::NumericTypeID::kF16: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kBF16: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kTF32: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kF32: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kCF32: -+ return reference::device::BlockCompareEqual >( -+ reinterpret_cast const *>(ptr_A), -+ reinterpret_cast const *>(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kCF16: -+ return 
reference::device::BlockCompareEqual>( -+ reinterpret_cast const *>(ptr_A), -+ reinterpret_cast const *>(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kCBF16: -+ return reference::device::BlockCompareEqual>( -+ reinterpret_cast const *>(ptr_A), -+ reinterpret_cast const *>(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kCTF32: -+ return reference::device::BlockCompareEqual>( -+ reinterpret_cast const *>(ptr_A), -+ reinterpret_cast const *>(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kF64: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kCF64: -+ return reference::device::BlockCompareEqual>( -+ reinterpret_cast const *>(ptr_A), -+ reinterpret_cast const *>(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kS2: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kS4: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kS8: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kS16: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kS32: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kS64: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kB1: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kU2: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kU4: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kU8: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kU16: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kU32: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kU64: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ default: -+ throw std::runtime_error("Unsupported numeric type"); -+ } -+} -+ -+/// Returns true if two blocks have approximately the same value -+bool DeviceAllocation::block_compare_relatively_equal( -+ library::NumericTypeID numeric_type, -+ void const *ptr_A, -+ void const *ptr_B, -+ size_t capacity, -+ double epsilon, -+ double nonzero_floor) { -+ -+ switch (numeric_type) { -+ case library::NumericTypeID::kF16: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kBF16: -+ return 
reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kTF32: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kF32: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kF64: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kS2: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kS4: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kS8: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kS16: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kS32: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kS64: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kB1: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kU2: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kU4: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kU8: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kU16: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kU32: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case 
library::NumericTypeID::kU64: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ // No relatively equal comparison for complex numbers. -+ // -+ // As a simplification, we can require bitwise equality. This avoids false positives. -+ // (i.e. "pass" really means passing. "Fail" may not actually mean failure given appropriate epsilon.) -+ // -+ case library::NumericTypeID::kCF16: -+ return reference::device::BlockCompareEqual >( -+ reinterpret_cast const *>(ptr_A), -+ reinterpret_cast const *>(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kCF32: -+ return reference::device::BlockCompareEqual >( -+ reinterpret_cast const *>(ptr_A), -+ reinterpret_cast const *>(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kCF64: -+ return reference::device::BlockCompareEqual >( -+ reinterpret_cast const *>(ptr_A), -+ reinterpret_cast const *>(ptr_B), -+ capacity); -+ -+ default: -+ { -+ throw std::runtime_error(std::string("Unsupported numeric type: ") + to_string(numeric_type)); -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Permits copying dynamic vectors into static-length vectors -+template -+struct vector_to_coord { -+ -+ vector_to_coord(TensorCoord &coord, std::vector const &vec) { -+ -+ coord[Rank - 1] = vec.at(Rank - 1); -+ -+ if (Rank > 1) { -+ vector_to_coord(coord, vec); -+ } -+ } -+ -+ vector_to_coord(TensorCoord &coord, std::vector const &vec) { -+ -+ coord[Rank - 1] = (int)vec.at(Rank - 1); -+ -+ if (Rank > 1) { -+ vector_to_coord(coord, vec); -+ } -+ } -+}; -+ -+/// Permits copying dynamic vectors into static-length vectors -+template -+struct vector_to_coord { -+ -+ vector_to_coord(TensorCoord &coord, std::vector const &vec) { -+ -+ coord[0] = vec.at(0); -+ } -+ -+ vector_to_coord(TensorCoord &coord, std::vector const &vec) { -+ -+ coord[0] = (int)vec.at(0); -+ } -+}; -+ -+/// Permits copying dynamic vectors into static-length vectors -+template -+struct vector_to_coord { -+ -+ vector_to_coord(TensorCoord &coord, std::vector const &vec) { -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+static void write_tensor_csv_static_tensor_view( -+ std::ostream &out, -+ DeviceAllocation &allocation) { -+ -+ Coord extent; -+ Coord stride; -+ -+ if (allocation.extent().size() != Layout::kRank) { -+ throw std::runtime_error("Allocation extent has invalid rank"); -+ } -+ -+ if (allocation.stride().size() != Layout::kStrideRank) { -+ throw std::runtime_error("Allocation stride has invalid rank"); -+ } -+ -+ vector_to_coord, Layout::kRank>(extent, allocation.extent()); -+ vector_to_coord, -+ Layout::kStrideRank>(stride, allocation.stride()); -+ -+ Layout layout(stride); -+ HostTensor host_tensor(extent, layout, false); -+ -+ if (host_tensor.capacity() != allocation.batch_stride()) { -+ throw std::runtime_error("Unexpected capacity to equal."); -+ } -+ -+ host_tensor.copy_in_device_to_host( -+ static_cast(allocation.data()), -+ allocation.batch_stride()); -+ -+ TensorViewWrite(out, host_tensor.host_view()); -+ -+ out << "\n\n"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+static void write_tensor_csv_static_type( -+ std::ostream &out, -+ DeviceAllocation &allocation) { -+ -+ switch (allocation.layout()) { -+ case 
library::LayoutTypeID::kRowMajor: -+ write_tensor_csv_static_tensor_view(out, allocation); -+ break; -+ case library::LayoutTypeID::kColumnMajor: -+ write_tensor_csv_static_tensor_view(out, allocation); -+ break; -+ case library::LayoutTypeID::kRowMajorInterleavedK2: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kColumnMajorInterleavedK2: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kRowMajorInterleavedK4: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kColumnMajorInterleavedK4: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kRowMajorInterleavedK16: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kColumnMajorInterleavedK16: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kRowMajorInterleavedK32: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kColumnMajorInterleavedK32: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kRowMajorInterleavedK64: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kColumnMajorInterleavedK64: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kTensorNHWC: -+ write_tensor_csv_static_tensor_view(out, allocation); -+ break; -+ case library::LayoutTypeID::kTensorNDHWC: -+ write_tensor_csv_static_tensor_view(out, allocation); -+ break; -+ case library::LayoutTypeID::kTensorNC32HW32: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kTensorNC64HW64: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kTensorC32RSK32: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kTensorC64RSK64: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ default: -+ throw std::runtime_error("Unhandled layout"); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Writes a tensor to csv -+void DeviceAllocation::write_tensor_csv( -+ std::ostream &out) { -+ -+ switch (this->type()) { -+ case library::NumericTypeID::kF16: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kBF16: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kTF32: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kF32: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kF64: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kS2: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kS4: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kS8: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kS16: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kS32: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kS64: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kB1: -+ 
write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kU2: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kU4: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kU8: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kU16: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kU32: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kU64: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kCF16: -+ write_tensor_csv_static_type >(out, *this); -+ break; -+ -+ case library::NumericTypeID::kCF32: -+ write_tensor_csv_static_type >(out, *this); -+ break; -+ -+ case library::NumericTypeID::kCF64: -+ write_tensor_csv_static_type >(out, *this); -+ break; -+ -+ default: -+ throw std::runtime_error("Unsupported numeric type"); -+ } -+} -+ -+template -+static void tensor_fill_tensor_view(DeviceAllocation &allocation, Element val = Element()) { -+ Coord extent; -+ Coord stride; -+ -+ if (allocation.extent().size() != Layout::kRank) { -+ throw std::runtime_error("Allocation extent has invalid rank"); -+ } -+ -+ if (allocation.stride().size() != Layout::kStrideRank) { -+ throw std::runtime_error("Allocation stride has invalid rank"); -+ } -+ -+ vector_to_coord, Layout::kRank>(extent, allocation.extent()); -+ vector_to_coord, -+ Layout::kStrideRank>(stride, allocation.stride()); -+ -+ TensorView view( -+ static_cast(allocation.data()), -+ Layout(stride), -+ extent -+ ); -+ -+ -+ cutlass::reference::device::TensorFill( -+ view, -+ val -+ ); -+} -+ -+template -+static void tensor_fill(DeviceAllocation &allocation, Element val = Element()) { -+ switch (allocation.layout()) { -+ case library::LayoutTypeID::kRowMajor: -+ tensor_fill_tensor_view(allocation, val); -+ break; -+ case library::LayoutTypeID::kColumnMajor: -+ tensor_fill_tensor_view(allocation, val); -+ break; -+ case library::LayoutTypeID::kTensorNHWC: -+ tensor_fill_tensor_view(allocation, val); -+ break; -+ case library::LayoutTypeID::kTensorNDHWC: -+ tensor_fill_tensor_view(allocation, val); -+ break; -+ case library::LayoutTypeID::kTensorNC32HW32: -+ tensor_fill_tensor_view>(allocation, val); -+ break; -+ case library::LayoutTypeID::kTensorNC64HW64: -+ tensor_fill_tensor_view>(allocation, val); -+ break; -+ case library::LayoutTypeID::kTensorC32RSK32: -+ tensor_fill_tensor_view>(allocation, val); -+ break; -+ case library::LayoutTypeID::kTensorC64RSK64: -+ tensor_fill_tensor_view>(allocation, val); -+ break; -+ default: -+ throw std::runtime_error("Unsupported layout"); -+ break; -+ } -+} -+ -+/// Fills a tensor uniformly with a value (most frequently used to clear the tensor) -+void DeviceAllocation::fill(double val = 0.0) { -+ -+ switch (this->type()) { -+ case library::NumericTypeID::kF16: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kBF16: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kTF32: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kF32: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kF64: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kS2: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kS4: -+ tensor_fill(*this, static_cast(val)); -+ 
break; -+ -+ case library::NumericTypeID::kS8: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kS16: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kS32: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kS64: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kB1: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kU2: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kU4: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kU8: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kU16: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kU32: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kU64: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kCF16: -+ tensor_fill >(*this, from_real(val)); -+ break; -+ -+ case library::NumericTypeID::kCF32: -+ tensor_fill >(*this, from_real(val)); -+ break; -+ -+ case library::NumericTypeID::kCF64: -+ tensor_fill >(*this, from_real(val)); -+ break; -+ -+ default: -+ throw std::runtime_error("Unsupported numeric type"); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/profiler/src/device_allocation.h b/3rdparty/cutlass/tools/profiler/src/device_allocation.h -new file mode 100644 -index 0000000..d0bdfd4 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/device_allocation.h -@@ -0,0 +1,226 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
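// device_allocation.cu repeatedly maps a runtime NumericTypeID onto a
// compile-time element type via a switch and then calls a function template.
// A minimal standalone analogue of that type-erasure dispatch (illustrative
// names, standard C++ only):
#include <cstddef>
#include <cstdint>
#include <stdexcept>

enum class TypeId { kF32, kS8, kU32 };

template <typename Element>
void fill_typed(void *ptr, std::size_t count, double val) {
  Element *p = static_cast<Element *>(ptr);
  for (std::size_t i = 0; i < count; ++i) {
    p[i] = static_cast<Element>(val);
  }
}

void fill(TypeId type, void *ptr, std::size_t count, double val) {
  switch (type) {
    case TypeId::kF32: fill_typed<float>(ptr, count, val); break;
    case TypeId::kS8:  fill_typed<std::int8_t>(ptr, count, val); break;
    case TypeId::kU32: fill_typed<std::uint32_t>(ptr, count, val); break;
    default: throw std::runtime_error("Unsupported numeric type");
  }
}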
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Execution environment -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "cutlass/library/library.h" -+#include "cutlass/util/distribution.h" -+ -+#include "enumerated_types.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Device memory allocation -+class DeviceAllocation { -+private: -+ -+ /// Data type of contained elements -+ library::NumericTypeID type_; -+ -+ /// Gets the stride between elements -+ size_t batch_stride_; -+ -+ /// Capacity in elements of device allocation -+ size_t capacity_; -+ -+ /// Pointer to device memory -+ void *pointer_; -+ -+ /// Layout type ID -+ library::LayoutTypeID layout_; -+ -+ /// Stride vector -+ std::vector stride_; -+ -+ /// Extent vector -+ std::vector extent_; -+ -+ /// Support allocating a 'batch' of non-overlapping tensors in contiguous memory -+ int batch_count_; -+ -+ /// Buffer holding TensorRef instance to recently allocated memory -+ std::vector tensor_ref_buffer_; -+ -+public: -+ // -+ // Static member functions -+ // -+ -+ /// Determines the number of bytes needed to represent this numeric type -+ static size_t bytes(library::NumericTypeID type, size_t capacity); -+ -+ /// Returns the stride of a packed layout -+ static std::vector get_packed_layout( -+ library::LayoutTypeID layout_id, -+ std::vector const &extent); -+ -+ /// returns the capacity needed -+ static size_t construct_layout( -+ void *bytes, -+ library::LayoutTypeID layout_id, -+ std::vector const &extent, -+ std::vector &stride); -+ -+ /// Returns true if two blocks have exactly the same value -+ static bool block_compare_equal( -+ library::NumericTypeID numeric_type, -+ void const *ptr_A, -+ void const *ptr_B, -+ size_t capacity); -+ -+ /// Returns true if two blocks have approximately the same value -+ static bool block_compare_relatively_equal( -+ library::NumericTypeID numeric_type, -+ void const *ptr_A, -+ void const *ptr_B, -+ size_t capacity, -+ double epsilon, -+ double nonzero_floor); -+ -+public: -+ // -+ // Methods -+ // -+ -+ DeviceAllocation(); -+ -+ DeviceAllocation(library::NumericTypeID type, size_t capacity); -+ -+ DeviceAllocation( -+ library::NumericTypeID type, -+ library::LayoutTypeID layout_id, -+ std::vector const &extent, -+ std::vector const &stride = std::vector(), -+ int batch_count = 1); -+ -+ ~DeviceAllocation(); -+ -+ DeviceAllocation &reset(); -+ -+ /// Allocates device memory of a given type and capacity -+ DeviceAllocation &reset(library::NumericTypeID type, size_t capacity); -+ -+ /// Allocates memory for a given layout and tensor -+ DeviceAllocation &reset( -+ library::NumericTypeID type, -+ library::LayoutTypeID layout_id, -+ std::vector const 
&extent, -+ std::vector const &stride = std::vector(), -+ int batch_count = 1); -+ -+ /// Returns a buffer owning the tensor reference -+ std::vector &tensor_ref() { -+ return tensor_ref_buffer_; -+ } -+ -+ bool good() const; -+ -+ /// Data type of contained elements -+ library::NumericTypeID type() const; -+ -+ /// Pointer to start of device memory allocation -+ void *data() const; -+ -+ /// Pointer to the first element of a batch -+ void *batch_data(int batch_idx) const; -+ -+ /// Gets the layout type -+ library::LayoutTypeID layout() const; -+ -+ /// Gets the stride vector -+ std::vector const & stride() const; -+ -+ /// Gets the extent vector -+ std::vector const & extent() const; -+ -+ /// Gets the number of adjacent tensors in memory -+ int batch_count() const; -+ -+ /// Gets the stride (in units of elements) beteween items -+ int64_t batch_stride() const; -+ -+ /// Gets the stride (in units of bytes) beteween items -+ int64_t batch_stride_bytes() const; -+ -+ /// Capacity of allocation in number of elements -+ size_t capacity() const; -+ -+ /// Capacity of allocation in bytes -+ size_t bytes() const; -+ -+ /// Initializes a device allocation to a random distribution using cuRAND -+ void initialize_random_device(int seed, Distribution dist); -+ -+ /// Initializes a host allocation to a random distribution using std::cout -+ void initialize_random_host(int seed, Distribution dist); -+ -+ /// Initializes a device allocation to a random distribution using cuRAND -+ void initialize_random_sparsemeta_device(int seed, int MetaSizeInBits); -+ -+ /// Initializes a host allocation to a random distribution using std::cout -+ void initialize_random_sparsemeta_host(int seed, int MetaSizeInBits); -+ -+ /// Uniformly fills a tensor with a value when provided o.w. zero -+ void fill(double value); -+ -+ /// Copies from an equivalent-sized tensor in device memory -+ void copy_from_device(void const *ptr); -+ -+ /// Copies from an equivalent-sized tensor in device memory -+ void copy_from_host(void const *ptr); -+ -+ /// Copies from an equivalent-sized tensor in device memory -+ void copy_to_host(void *ptr); -+ -+ /// Writes a tensor to csv -+ void write_tensor_csv(std::ostream &out); -+}; -+ -+using DeviceAllocationList = std::list; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/device_context.cu b/3rdparty/cutlass/tools/profiler/src/device_context.cu -new file mode 100644 -index 0000000..117f78b ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/device_context.cu -@@ -0,0 +1,197 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. 
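// Hedged usage sketch for the DeviceAllocation interface declared above. It
// assumes the extent/stride parameters are std::vector<int> / std::vector<int64_t>
// (their template arguments are not visible in this hunk) and that the CUTLASS
// library headers pulled in by device_allocation.h are available.
#include <vector>
#include "device_allocation.h"

void example() {
  using namespace cutlass;

  // 128x64 row-major float tensor; the packed stride is derived automatically.
  profiler::DeviceAllocation A(
      library::NumericTypeID::kF32,
      library::LayoutTypeID::kRowMajor,
      {128, 64});

  A.fill(0.0);                            // clear the tensor on the device

  std::vector<float> host(A.capacity());  // capacity() is in elements
  A.copy_to_host(host.data());            // copies bytes() from device to host
}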
-+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief -+*/ -+ -+#include "device_context.h" -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Allocates memory of a given type, capacity (elements), and name -+DeviceAllocation *DeviceContext::allocate_block( -+ std::string const &name, -+ library::NumericTypeID type, -+ size_t capacity) { -+ -+ device_memory_.emplace_back(type, capacity); -+ DeviceAllocation *allocation = &device_memory_.back(); -+ -+ allocations_[name] = allocation; -+ return allocation; -+} -+ -+/// Allocates memory of a given type, capacity (elements), and name -+DeviceAllocation *DeviceContext::allocate_tensor( -+ std::string const &name, -+ library::NumericTypeID type, -+ library::LayoutTypeID layout_id, -+ std::vector const &extent, -+ std::vector const &stride, -+ int batch_count) { -+ -+ device_memory_.emplace_back(type, layout_id, extent, stride, batch_count); -+ DeviceAllocation *allocation = &device_memory_.back(); -+ -+ allocations_[name] = allocation; -+ return allocation; -+} -+ -+/// Allocates memory of a given type, capacity (elements), and name -+DeviceAllocation *DeviceContext::allocate_tensor( -+ Options const &options, -+ std::string const &name, -+ library::NumericTypeID type, -+ library::LayoutTypeID layout_id, -+ std::vector const &extent, -+ std::vector const &stride, -+ int batch_count) { -+ -+ DeviceAllocation *allocation = -+ allocate_tensor(name, type, layout_id, extent, stride, batch_count); -+ -+ if (options.initialization.enabled) { -+ Distribution data_distribution = options.initialization.data_distribution; -+ -+ // check if data distribution is allowed to change -+ if(!options.initialization.fix_data_distribution) { -+ // change data distribution based on bit width -+ switch(type) { -+ case library::NumericTypeID::kF16: -+ data_distribution.set_uniform(-3, 3, 0); -+ break; -+ case library::NumericTypeID::kB1: -+ data_distribution.set_uniform(0, 1, 0); -+ break; -+ case library::NumericTypeID::kS2: -+ data_distribution.set_uniform(-1, 1, 0); -+ break; -+ case library::NumericTypeID::kS4: -+ data_distribution.set_uniform(-2, 2, 0); -+ break; -+ case library::NumericTypeID::kU2: -+ data_distribution.set_uniform(0, 2, 0); -+ break; -+ case library::NumericTypeID::kU4: -+ data_distribution.set_uniform(0, 2, 0); -+ break; -+ case library::NumericTypeID::kS8: -+ 
data_distribution.set_uniform(-3, 3, 0); -+ break; -+ case library::NumericTypeID::kU8: -+ data_distribution.set_uniform(0, 4, 0); -+ break; -+ default: break; -+ } -+ } -+ -+ if (options.initialization.provider == library::Provider::kReferenceDevice) { -+ allocation->initialize_random_device( -+ options.initialization.seed, -+ data_distribution); -+ } -+ else if (options.initialization.provider == library::Provider::kReferenceHost) { -+ allocation->initialize_random_host( -+ options.initialization.seed, -+ data_distribution); -+ } -+ } -+ -+ return allocation; -+} -+ -+/// Allocates memory for sparse meta data -+DeviceAllocation *DeviceContext::allocate_sparsemeta_tensor( -+ Options const &options, -+ std::string const &name, -+ library::NumericTypeID type, -+ library::LayoutTypeID layout_id, -+ library::NumericTypeID type_a, -+ std::vector const &extent, -+ std::vector const &stride, -+ int batch_count) { -+ -+ DeviceAllocation *allocation = -+ allocate_tensor(name, type, layout_id, extent, stride, batch_count); -+ -+ if (options.initialization.enabled) { -+ // TF32 has 4bit meta data. The rest has 2bit. -+ int MetaSizeInBits = (cutlass::library::sizeof_bits(type_a) == 32) ? 4 : 2; -+ -+ if (options.initialization.provider == library::Provider::kReferenceDevice) { -+ allocation->initialize_random_sparsemeta_device( -+ options.initialization.seed, -+ MetaSizeInBits); -+ } -+ else if (options.initialization.provider == library::Provider::kReferenceHost) { -+ allocation->initialize_random_sparsemeta_host( -+ options.initialization.seed, -+ MetaSizeInBits); -+ } -+ } -+ -+ return allocation; -+} -+/// Clears named allocations (but does not necessarily free memory) -+void DeviceContext::clear() { -+ allocations_.clear(); -+} -+ -+/// Frees all device memory allocations -+void DeviceContext::free() { -+ allocations_.clear(); -+ device_memory_.clear(); -+} -+ -+/// Gets the allocation by name -+DeviceAllocation &DeviceContext::at(std::string const &name) { -+ return *allocations_.at(name); -+} -+ -+size_t DeviceContext::size() const { -+ return allocations_.size(); -+} -+ -+DeviceContext::AllocationMap::iterator DeviceContext::begin() { -+ return allocations_.begin(); -+} -+ -+DeviceContext::AllocationMap::iterator DeviceContext::end() { -+ return allocations_.end(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/profiler/src/device_context.h b/3rdparty/cutlass/tools/profiler/src/device_context.h -new file mode 100644 -index 0000000..16a72f9 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/device_context.h -@@ -0,0 +1,128 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include -+#include -+ -+ -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+ -+#include "options.h" -+#include "device_allocation.h" -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Collection of allocations on the device -+class DeviceContext { -+public: -+ -+ // -+ // Type definitions -+ // -+ using AllocationMap = std::map; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Memory allocations that exist (owning) -+ DeviceAllocationList device_memory_; -+ -+ /// Non-owning set of named allocations -+ AllocationMap allocations_; -+ -+public: -+ -+ /// Allocates memory of a given type, capacity (elements), and name -+ DeviceAllocation *allocate_block( -+ std::string const &name, -+ library::NumericTypeID type, -+ size_t capacity); -+ -+ /// Allocates memory of a given type, capacity (elements), and name -+ DeviceAllocation *allocate_tensor( -+ std::string const &name, -+ library::NumericTypeID type, -+ library::LayoutTypeID layout_id, -+ std::vector const &extent, -+ std::vector const &stride = std::vector(), -+ int batch_count = 1); -+ -+ /// Allocates memory of a given type, capacity (elements), and name -+ DeviceAllocation *allocate_tensor( -+ Options const &options, -+ std::string const &name, -+ library::NumericTypeID type, -+ library::LayoutTypeID layout_id, -+ std::vector const &extent, -+ std::vector const &stride = std::vector(), -+ int batch_count = 1); -+ -+ /// Allocates memory for sparse meta data -+ DeviceAllocation *allocate_sparsemeta_tensor( -+ Options const &options, -+ std::string const &name, -+ library::NumericTypeID type, -+ library::LayoutTypeID layout_id, -+ library::NumericTypeID type_a, -+ std::vector const &extent, -+ std::vector const &stride = std::vector(), -+ int batch_count = 1); -+ -+ /// Clears named allocations (but does not necessarily free memory) -+ void clear(); -+ -+ /// Frees all device memory allocations -+ void free(); -+ -+ /// Gets the allocation by name -+ DeviceAllocation &at(std::string const &name); -+ -+ size_t size() const; -+ -+ AllocationMap::iterator begin(); -+ AllocationMap::iterator end(); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass 
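The two headers added above (device_allocation.h and device_context.h) declare the profiler's device-memory helpers. The following is an illustrative usage sketch only, not part of the patch: it assumes the upstream CUTLASS profiler headers, and it assumes the extent/stride vectors are std::vector<int> (their template arguments are not visible in this hunk).

// Sketch: allocating and inspecting profiler workspaces via DeviceContext.
// Assumes a CUDA-capable device and the CUTLASS library/profiler headers.
#include <cstdint>
#include <iostream>
#include <vector>

#include "cutlass/library/library.h"
#include "device_context.h"

int main() {
  using namespace cutlass;
  profiler::DeviceContext ctx;

  // Allocate a named 1024x512 column-major FP16 tensor in device memory.
  profiler::DeviceAllocation *A = ctx.allocate_tensor(
      "A",
      library::NumericTypeID::kF16,
      library::LayoutTypeID::kColumnMajor,
      /*extent=*/{1024, 512});

  // Allocate an unstructured block of 4096 FP32 elements.
  profiler::DeviceAllocation *scratch = ctx.allocate_block(
      "scratch", library::NumericTypeID::kF32, 4096);

  std::cout << "named allocations:     " << ctx.size() << "\n"
            << "A capacity (elements): " << ctx.at("A").capacity() << "\n"
            << "A size (bytes):        " << ctx.at("A").bytes() << "\n"
            << "scratch size (bytes):  " << scratch->bytes() << "\n";

  // Fill A with zeros, then copy it back to the host for inspection.
  A->fill(0.0);
  std::vector<uint8_t> host(A->bytes());
  A->copy_to_host(host.data());

  // clear() drops the name->allocation map without releasing memory;
  // free() releases all device allocations owned by the context.
  ctx.clear();
  ctx.free();
  return 0;
}

This mirrors how the GEMM profiler below uses the same API: it calls allocate_tensor() for the A/B/C/D operands, keeps them addressable by name, and relies on free() to tear the workspace down between problems.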
-diff --git a/3rdparty/cutlass/tools/profiler/src/enumerated_types.h b/3rdparty/cutlass/tools/profiler/src/enumerated_types.h -new file mode 100644 -index 0000000..4d91324 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/enumerated_types.h -@@ -0,0 +1,169 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Provides several functions for filling tensors with data. -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include "cutlass/library/library.h" -+ -+#define TRACE(x) { std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; } -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+T from_string(std::string const &); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Enumerated type describing how the performance testbench evaluates kernels. 
-+enum class ExecutionMode { -+ kProfile, ///< regular verification and profiling -+ kDryRun, ///< no kernels are launched or workspaces allocated; used to assess what operators might be launched -+ kEnumerate, ///< no kernels launched or workspaces allocated; lists all operation kind and operations -+ kTrace, ///< executes a single device-side computation with no other kernel launches -+ kInvalid -+}; -+ -+/// Converts a ExecutionMode enumerant to a string -+char const *to_string(ExecutionMode mode, bool pretty = false); -+ -+/// Parses a ExecutionMode enumerant from a string -+template <> -+ExecutionMode from_string(std::string const &str); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Library algorithm mode -+enum class AlgorithmMode { -+ kMatching, ///< compare against best matching algorithm -+ kBest, ///< evaluate all library algorithms and report best -+ kDefault, ///< use the library's default algorithm option -+ kInvalid -+}; -+ -+/// Converts a ExecutionMode enumerant to a string -+char const *to_string(AlgorithmMode mode, bool pretty = false); -+ -+/// Parses a ExecutionMode enumerant from a string -+template <> -+AlgorithmMode from_string(std::string const &str); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Outcome of a performance test -+enum class Disposition { -+ kPassed, -+ kFailed, -+ kNotRun, -+ kIncorrect, -+ kNotVerified, -+ kInvalidProblem, -+ kNotSupported, -+ kInvalid -+}; -+ -+/// Converts a Disposition enumerant to a string -+char const *to_string(Disposition disposition, bool pretty = false); -+ -+/// Parses a Disposition enumerant from a string -+template <> -+Disposition from_string(std::string const &str); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Indicates when to save -+enum class SaveWorkspace { -+ kNever, -+ kIncorrect, -+ kAlways, -+ kInvalid -+}; -+ -+/// Converts a SaveWorkspace enumerant to a string -+char const *to_string(SaveWorkspace save_option, bool pretty = false); -+ -+/// Parses a SaveWorkspace enumerant from a string -+template <> -+SaveWorkspace from_string(std::string const &str); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Indicates the type of kernel argument -+// ArgumentType can be both ScalarType or NumericType. Thus, enums kScalar and kNumeric -+// 1) kScalar: e.g. of a Scalar ArgumentType is u32 is a Scalar type. -+// Its c++ equivalent as "type name = initializer" is "u32 m = 32" -+// 2) kNumeric: e.g. of a Numeric ArgumentType is NumericTypeID is a Numeric type. 
-+// Its c++ equivalent as "type name = initializer" is "NumericTypeID numeric_type = u32" -+enum class ArgumentTypeID { -+ kScalar, -+ kInteger, -+ kTensor, -+ kBatchedTensor, -+ kStructure, -+ kEnumerated, -+ kInvalid -+}; -+ -+/// Converts a ArgumentTypeID enumerant to a string -+char const *to_string(ArgumentTypeID type, bool pretty = false); -+ -+/// Parses a ArgumentTypeID enumerant from a string -+template <> -+ArgumentTypeID from_string(std::string const &str); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Profiler typedefs -+using ProviderVector = std::vector; -+using DispositionMap = std::map; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Print vector for the report -+template -+std::ostream& operator<< (std::ostream& out, const std::vector& v) { -+ for(int i = 0; i < v.size(); ++i) { -+ out << to_string(v[i], true) << (i+1 != v.size() ? "," : ""); -+ } -+ return out; -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/profiler/src/gemm_operation_profiler.cu b/3rdparty/cutlass/tools/profiler/src/gemm_operation_profiler.cu -new file mode 100644 -index 0000000..4b15fda ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/gemm_operation_profiler.cu -@@ -0,0 +1,1219 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Execution environment -+*/ -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/core_io.h" -+ -+#include "cublas_helpers.h" -+#include "gemm_operation_profiler.h" -+#include "gpu_timer.h" -+ -+#include "cutlass/library/singleton.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/handle.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Ctor -+GemmOperationProfiler::GemmOperationProfiler(Options const &options): -+ OperationProfiler( -+ options, -+ library::OperationKind::kGemm, -+ { -+ {ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (gemm, batched, array, universal, planar_complex, planar_complex_array)"}, -+ {ArgumentTypeID::kEnumerated, {"split_k_mode"}, "Variant of split K mode(serial, parallel)"}, -+ {ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the GEMM problem space"}, -+ {ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the GEMM problem space"}, -+ {ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the GEMM problem space"}, -+ {ArgumentTypeID::kTensor, {"A"}, "Tensor storing the A operand"}, -+ {ArgumentTypeID::kTensor, {"B"}, "Tensor storing the B operand"}, -+ {ArgumentTypeID::kTensor, {"C"}, "Tensor storing the C operand"}, -+ {ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"}, -+ {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"}, -+ {ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"}, -+ {ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of GEMMs computed in one batch"}, -+ }, -+ { library::Provider::kCUBLAS} -+ ) { -+ -+ description_ = " General matrix-matrix product. D = alpha * A*B + beta * C"; -+} -+ -+/// Destructor -+GemmOperationProfiler::~GemmOperationProfiler() { -+ -+} -+ -+/// Prints usage statement for the math function -+void GemmOperationProfiler::print_usage(std::ostream &out) const { -+ out << "GEMM" << "\n\n"; -+ -+ OperationProfiler::print_usage(out); -+} -+ -+/// Prints examples -+void GemmOperationProfiler::print_examples(std::ostream &out) const { -+ -+ out << "\nExamples:\n\n" -+ << "Profile a particular problem size:\n" -+ << " $ cutlass_profiler --operation=Gemm --m=1024 --n=1024 --k=128\n\n" -+ -+ << "Schmoo over problem size and beta:\n" -+ << " $ cutlass_profiler --operation=Gemm --m=1024:4096:256 --n=1024:4096:256 --k=128:8192:128 --beta=0,1,2.5\n\n" -+ -+ << "Schmoo over accumulator types:\n" -+ << " $ cutlass_profiler --operation=Gemm --accumulator-type=f16,f32\n\n" -+ -+ << "Run when A is f16 with column-major and B is any datatype with row-major (For column major, use column, col, or n. 
For row major use, row or t):\n" -+ << " $ cutlass_profiler --operation=Gemm --A=f16:column --B=*:row\n\n" -+ -+ << "Profile a particular problem size with split K and paralell reduction:\n" -+ << " $ cutlass_profiler --operation=Gemm --split_k_mode=parallel --split_k_slices=2 --m=1024 --n=1024 --k=128\n\n" -+ -+ << "Using various input value distribution:\n" -+ << " $ cutlass_profiler --operation=Gemm --dist=uniform,min:0,max:3\n" -+ << " $ cutlass_profiler --operation=Gemm --dist=gaussian,mean:0,stddev:3\n" -+ << " $ cutlass_profiler --operation=Gemm --dist=sequential,start:0,delta:1\n\n" -+ -+ << "Run a kernel with cta tile size of 256x128x32 and save workspace if results are incorrect (note that --cta-tile::k=32 is default cta-tile size):\n" -+ << " $ cutlass_profiler --operation=Gemm --cta_m=256 --cta_n=128 --cta_k=32 --save-workspace=incorrect\n\n" -+ -+ << "Test your changes to gemm kernels with a quick functional test and save results in functional-test.csv:\n" -+ << " $ cutlass_profiler --operation=Gemm \\ \n" -+ << " --m=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" -+ << " --n=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" -+ << " --k=8,16,32,64,128,256,288,384,504,512,520 \\ \n" -+ << " --beta=0,1,2 --profiling-iterations=1 \\ \n" -+ << " --providers=cutlass --output=functional-test.csv\n\n"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if 0 -+// used this for debugging -+static std::string byte_string(std::vector const &bytes) { -+ std::stringstream ss; -+ -+ ss << "0x"; -+ -+ for (size_t idx = bytes.size(); idx > 0; --idx) { -+ ss << std::hex << std::setw(2) << std::setfill('0') << uint32_t(bytes.at(idx - 1)); -+ } -+ -+ return ss.str(); -+} -+#endif -+ -+Status GemmOperationProfiler::GemmProblem::parse( -+ library::GemmDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ this->mode = library::GemmUniversalMode::kGemm; -+ -+ if (!arg_as_int(this->m, "m", problem_space, problem)) { -+ // default value -+ this->m = 1024; -+ } -+ -+ if (!arg_as_int(this->n, "n", problem_space, problem)) { -+ // default value -+ this->n = 1024; -+ } -+ -+ if (!arg_as_int(this->k, "k", problem_space, problem)) { -+ // default value -+ this->k = 1024; -+ } -+ -+ if (!arg_as_SplitKModeID(this->split_k_mode, "split_k_mode", problem_space, problem)) { -+ // defualt value -+ this->split_k_mode = library::SplitKMode::kSerial; -+ } -+ -+ this->mode = library::GemmUniversalMode::kGemm; -+ if(this->split_k_mode == library::SplitKMode::kParallel) { -+ this->mode = library::GemmUniversalMode::kGemmSplitKParallel; -+ } -+ -+ if (!arg_as_int(this->split_k_slices, "split_k_slices", problem_space, problem)) { -+ // default value -+ this->split_k_slices = 1; -+ } -+ -+ if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) { -+ // default value -+ this->batch_count = 1; -+ } else if (this->batch_count > 1) { -+ this->mode = library::GemmUniversalMode::kBatched; -+ } -+ -+ if (this->split_k_slices > 1 && this->batch_count > 1) { -+ // At least one of these must be one -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.A, "A", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.B, "B", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if 
(!tensor_description_satisfies(operation_desc.C, "C", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!arg_as_scalar( -+ this->alpha, -+ operation_desc.element_epilogue, -+ "alpha", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->alpha, operation_desc.element_epilogue, 1)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ if (!arg_as_scalar( -+ this->beta, -+ operation_desc.element_epilogue, -+ "beta", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->beta, operation_desc.element_epilogue, 0)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ this->lda = DeviceAllocation::get_packed_layout( -+ operation_desc.A.layout, {int(this->m), int(this->k)}).front(); -+ -+ this->ldb = DeviceAllocation::get_packed_layout( -+ operation_desc.B.layout, {int(this->k), int(this->n)}).front(); -+ -+ this->ldc = DeviceAllocation::get_packed_layout( -+ operation_desc.C.layout, {int(this->m), int(this->n)}).front(); -+ -+ return Status::kSuccess; -+} -+ -+/// Total number of bytes loaded -+int64_t GemmOperationProfiler::GemmProblem::bytes(library::GemmDescription const &operation_desc) const { -+ // Input bytes read and Output bytes written for the gemm problem -+ int64_t bytes = -+ int64_t(library::sizeof_bits(operation_desc.A.element) * m / 8) * k + -+ int64_t(library::sizeof_bits(operation_desc.B.element) * n / 8) * k + -+ int64_t(library::sizeof_bits(operation_desc.C.element) * m / 8) * n; -+ -+ // Set is_beta_zero true if beta is zero -+ bool is_beta_zero = std::all_of(beta.begin(), beta.end(), [](uint8_t i) { return i==0; }); -+ -+ // Output bytes read for the gemm problem for non-zero beta values -+ if (!is_beta_zero) { -+ bytes += int64_t(library::sizeof_bits(operation_desc.C.element) * m / 8) * n; -+ } -+ -+ bytes *= batch_count; -+ -+ return bytes; -+} -+ -+/// Total number of flops computed -+int64_t GemmOperationProfiler::GemmProblem::flops(library::GemmDescription const &operation_desc) const { -+ int64_t flops_ = (int64_t(m) * n * k + m * n) * 2 * batch_count; -+ -+ // complex-valued support -+ switch (operation_desc.tile_description.math_instruction.math_operation) { -+ case library::MathOperationID::kMultiplyAddComplex: -+ flops_ *= 4; -+ break; -+ -+ case library::MathOperationID::kMultiplyAddComplexFastF32: -+ flops_ *= 4; -+ break; -+ -+ case library::MathOperationID::kMultiplyAddGaussianComplex: -+ flops_ *= 3; -+ break; -+ -+ default: break; -+ } -+ -+ return flops_; -+} -+ -+ -+/// Initializes a performance result -+void GemmOperationProfiler::GemmProblem::initialize_result( -+ PerformanceResult &result, -+ library::GemmDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.arguments.resize(problem_space.rank()); -+ -+ set_argument(result, "gemm_kind", problem_space, library::to_string(operation_desc.gemm_kind)); -+ -+ set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode)); -+ -+ set_argument(result, "A", problem_space, -+ std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout)); -+ -+ set_argument(result, "B", problem_space, -+ std::string(library::to_string(operation_desc.B.element)) + ":" + library::to_string(operation_desc.B.layout)); -+ -+ set_argument(result, "C", problem_space, -+ std::string(library::to_string(operation_desc.C.element)) + ":" + library::to_string(operation_desc.C.layout)); -+ -+ set_argument(result, "m", problem_space, m); -+ set_argument(result, "n", problem_space, n); -+ 
set_argument(result, "k", problem_space, k); -+ -+ set_argument(result, "split_k_slices", problem_space, split_k_slices); -+ set_argument(result, "batch_count", problem_space, batch_count); -+ -+ set_argument(result, "alpha", problem_space, -+ library::lexical_cast(alpha, operation_desc.element_epilogue)); -+ -+ set_argument(result, "beta", problem_space, -+ library::lexical_cast(beta, operation_desc.element_epilogue)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Extracts the problem dimensions -+Status GemmOperationProfiler::initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::GemmDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (operation_desc.gemm_kind != library::GemmKind::kUniversal) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = problem_.parse(operation_desc, problem_space, problem); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ gemm_workspace_.configuration.mode = problem_.mode; -+ gemm_workspace_.configuration.problem_size.m() = int(problem_.m); -+ gemm_workspace_.configuration.problem_size.n() = int(problem_.n); -+ gemm_workspace_.configuration.problem_size.k() = int(problem_.k); -+ gemm_workspace_.configuration.lda = problem_.lda; -+ gemm_workspace_.configuration.ldb = problem_.ldb; -+ gemm_workspace_.configuration.ldc = problem_.ldc; -+ gemm_workspace_.configuration.ldd = problem_.ldc; -+ -+ if (problem_.mode == library::GemmUniversalMode::kBatched) { -+ gemm_workspace_.configuration.batch_count = problem_.batch_count; -+ } -+ else { -+ gemm_workspace_.configuration.batch_count = problem_.split_k_slices; -+ } -+ -+ gemm_workspace_.arguments.A = nullptr; -+ gemm_workspace_.arguments.B = nullptr; -+ gemm_workspace_.arguments.C = nullptr; -+ gemm_workspace_.arguments.D = nullptr; -+ gemm_workspace_.arguments.alpha = problem_.alpha.data(); -+ gemm_workspace_.arguments.beta = problem_.beta.data(); -+ gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // initialize reduction operation for parallel splitKMode -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ if (!initialize_reduction_configuration_(operation, problem)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ initialize_result_(this->model_result_, options, operation_desc, problem_space); -+ -+ return operation->can_implement(&gemm_workspace_.configuration, &gemm_workspace_.arguments); -+} -+ -+/// Initializes the performance result -+void GemmOperationProfiler::initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::GemmDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.provider = library::Provider::kCUTLASS; -+ result.disposition = Disposition::kNotRun; -+ result.status = Status::kSuccess; -+ result.operation_name = operation_desc.name; -+ -+ problem_.initialize_result(result, operation_desc, problem_space); -+ -+ OperationProfiler::initialize_result_(result, operation_desc, problem_space); -+ -+ result.bytes = problem_.bytes(operation_desc); -+ result.flops = problem_.flops(operation_desc); -+ result.runtime = 0; -+ -+} -+ -+/// Initialize redution problem dimentions and library::Operation -+bool GemmOperationProfiler::initialize_reduction_configuration_( 
-+ library::Operation const *operation, -+ ProblemSpace::Problem const &problem) { -+ library::GemmDescription const &gemm_desc = -+ static_cast(operation->description()); -+ -+ if (!cast_from_double(problem_.alpha_one, gemm_desc.element_epilogue, 1)) { -+ return false; -+ } -+ -+ if (!cast_from_double(problem_.beta_zero, gemm_desc.element_epilogue, 0)) { -+ return false; -+ } -+ -+ /// initialize library::ReductionConfiguration -+ gemm_workspace_.reduction_configuration.problem_size = gemm::GemmCoord(int(problem_.n), int(problem_.m), int(problem_.k)).mn(); -+ gemm_workspace_.reduction_configuration.partitions = int(problem_.split_k_slices); -+ gemm_workspace_.reduction_configuration.partition_stride = gemm::GemmCoord(int(problem_.n), int(problem_.m), int(problem_.k)).mn().product(); -+ gemm_workspace_.reduction_configuration.ldw = problem_.ldc; -+ gemm_workspace_.reduction_configuration.lds = problem_.ldc; -+ gemm_workspace_.reduction_configuration.ldd = problem_.ldc; -+ -+ // find reduction operation -+ library::ReductionFunctionalKey reduction_key( -+ library::Provider::kCUTLASS, -+ gemm_desc.tile_description.math_instruction.element_accumulator, // element workspace -+ gemm_desc.tile_description.math_instruction.element_accumulator, // element accumulator -+ gemm_desc.C.element, // element output -+ gemm_desc.element_epilogue // element coumpute -+ ); -+ -+ auto reduction_it = library::Singleton::get().operation_table.reduction_operations.find(reduction_key); -+ -+ if (reduction_it == library::Singleton::get().operation_table.reduction_operations.end()) { -+ return false; -+ } -+ -+ // initialize reduction operation required for parallel split-k operator -+ reduction_op_ = reduction_it->second; -+ -+ // reduction operation found and initialized -+ return true; -+} -+ -+/// Initializes workspace -+Status GemmOperationProfiler::initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::Operation const* underlying_operation = operation; -+ -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ if (!(underlying_operation = library::find_gemm_operation_for_parallel_reduction(operation))) { -+ return Status::kErrorNotSupported; -+ } -+ } -+ -+ library::GemmDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ // Compute the number of copies of the problem to avoid L2 camping. 
-+ if (!options.profiling.workspace_count) { -+ int64_t bytes = problem_.bytes(operation_desc); -+ if (bytes < 3 * int64_t(options.device.properties.l2CacheSize)) { -+ gemm_workspace_.problem_count = -+ 1 + int((3 * int64_t(options.device.properties.l2CacheSize)) / bytes); -+ } -+ else { -+ gemm_workspace_.problem_count = 1; -+ } -+ } -+ else { -+ gemm_workspace_.problem_count = options.profiling.workspace_count; -+ } -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ gemm_workspace_.A = device_context.allocate_tensor( -+ options, -+ "A", -+ operation_desc.A.element, -+ operation_desc.A.layout, -+ {int(problem_.m), int(problem_.k)}, -+ {int(problem_.lda)}, -+ problem_.batch_count * gemm_workspace_.problem_count -+ ); -+ -+ gemm_workspace_.B = device_context.allocate_tensor( -+ options, -+ "B", -+ operation_desc.B.element, -+ operation_desc.B.layout, -+ {int(problem_.k), int(problem_.n)}, -+ {int(problem_.ldb)}, -+ problem_.batch_count * gemm_workspace_.problem_count -+ ); -+ -+ gemm_workspace_.C = device_context.allocate_tensor( -+ options, -+ "C", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldc)}, -+ problem_.batch_count * gemm_workspace_.problem_count -+ ); -+ -+ gemm_workspace_.Computed = device_context.allocate_tensor( -+ "D", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldc)}, -+ problem_.batch_count * gemm_workspace_.problem_count -+ ); -+ -+ gemm_workspace_.Reference = device_context.allocate_tensor( -+ "Reference", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldc)}, -+ problem_.batch_count * gemm_workspace_.problem_count -+ ); -+ -+ gemm_workspace_.Reference->copy_from_device(gemm_workspace_.C->data()); -+ -+ // NOTE: the leading non-batch strides are duplicated here for 3.0 API kernels -+ gemm_workspace_.arguments.problem_size = {int(problem_.m), int(problem_.n), int(problem_.k)}; -+ gemm_workspace_.arguments.batch_count = problem_.batch_count; -+ gemm_workspace_.arguments.lda = problem_.lda; -+ gemm_workspace_.arguments.ldb = problem_.ldb; -+ gemm_workspace_.arguments.ldc = problem_.ldc; -+ gemm_workspace_.arguments.ldd = problem_.ldc; -+ gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); -+ gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); -+ gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); -+ gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride(); -+ } -+ -+ // -+ // Initialize the CUTLASS operation -+ // -+ Status status = Status::kSuccess; -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ uint64_t workspace_size = underlying_operation->get_host_workspace_size(&gemm_workspace_.configuration); -+ gemm_workspace_.host_workspace.resize(workspace_size, 0); -+ -+ workspace_size = underlying_operation->get_device_workspace_size(&gemm_workspace_.configuration, -+ &gemm_workspace_.arguments); -+ gemm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); -+ -+ status = underlying_operation->initialize( -+ &gemm_workspace_.configuration, -+ gemm_workspace_.host_workspace.data(), -+ gemm_workspace_.device_workspace.data()); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ if (problem_.split_k_mode == 
library::SplitKMode::kParallel) { -+ workspace_size = reduction_op_->get_host_workspace_size(&gemm_workspace_.reduction_configuration); -+ gemm_workspace_.reduction_host_workspace.resize(workspace_size, 0); -+ -+ status = reduction_op_->initialize( -+ &gemm_workspace_.reduction_configuration, -+ gemm_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ } -+ } -+ -+ // -+ // If CUTLASS is enabled, generate a result for it -+ // -+ results_.push_back(model_result_); -+ results_.back().provider = library::Provider::kCUTLASS; -+ results_.back().op_kind = library::OperationKind::kGemm; -+ results_.back().disposition = Disposition::kNotRun; -+ -+ for(auto provider : verification_providers_) { -+ results_.back().verification_map[provider] = Disposition::kNotRun; -+ } -+ } -+ -+ return status; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool GemmOperationProfiler::verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ return true; -+ } -+ -+ if (options.execution_mode == ExecutionMode::kDryRun) { -+ return true; -+ } -+ -+ // Initialize structure containing GEMM arguments -+ gemm_workspace_.arguments.A = gemm_workspace_.A->data(); -+ gemm_workspace_.arguments.B = gemm_workspace_.B->data(); -+ gemm_workspace_.arguments.C = gemm_workspace_.C->data(); -+ gemm_workspace_.arguments.D = gemm_workspace_.Computed->data(); -+ gemm_workspace_.arguments.alpha = problem_.alpha.data(); -+ gemm_workspace_.arguments.beta = problem_.beta.data(); -+ gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); -+ gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); -+ gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); -+ gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride(); -+ -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data(); -+ gemm_workspace_.arguments.alpha = problem_.alpha_one.data(); -+ gemm_workspace_.arguments.beta = problem_.beta_zero.data(); -+ -+ gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data(); -+ gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data(); -+ gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data(); -+ gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data(); -+ gemm_workspace_.reduction_arguments.beta = problem_.beta.data(); -+ gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ } -+ -+ // -+ // Run the CUTLASS operation -+ // -+ -+ // initialize gemm underlying operation to handle parallel reduction -+ library::Operation const * underlying_operation = operation; -+ -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ if (!(underlying_operation = library::find_gemm_operation_for_parallel_reduction(operation))) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ } -+ -+ results_.back().status = underlying_operation->run( -+ 
&gemm_workspace_.arguments, -+ gemm_workspace_.host_workspace.data(), -+ gemm_workspace_.device_workspace.data()); -+ -+ if (results_.back().status != Status::kSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ // Run parallel reduction kernel for parallel split_k_mode -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ results_.back().status = reduction_op_->run( -+ &gemm_workspace_.reduction_arguments, -+ gemm_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ -+ if (results_.back().status != Status::kSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ // CUTLASS op ran the but not yet verified against any verification provider -+ results_.back().disposition = Disposition::kNotVerified; -+ -+ // -+ // Run verification providers -+ // -+ -+ if (options.verification.enabled) { -+ -+#if CUTLASS_ENABLE_CUBLAS -+ if (options.verification.provider_enabled(library::Provider::kCUBLAS)) { -+ -+ // Guard against unsupported cases -+ auto const & gemm_desc = static_cast(operation->description()); -+ -+ if (cublas_satisfies(gemm_desc) == Status::kSuccess) { -+ -+ // call cublas verification if supported -+ verify_with_cublas_( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ else { -+ // set verification map for cublas to not supported -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotSupported; -+ } -+ } -+#endif // #if CUTLASS_ENABLE_CUBLAS -+ -+ verify_with_reference_(options, report, device_context, operation, problem_space, problem); -+ -+ // Update disposition to worst case verification outcome among all -+ // verification providers which are supported -+ bool is_any_verification_run_passed = false; -+ for(auto &m : results_.back().verification_map) { -+ if(m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { -+ results_.back().disposition = m.second; -+ return true; -+ } -+ if(!is_any_verification_run_passed && m.second == Disposition::kPassed) { -+ is_any_verification_run_passed = true; -+ } -+ } -+ -+ if(is_any_verification_run_passed) { -+ results_.back().disposition = Disposition::kPassed; -+ } -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool GemmOperationProfiler::verify_with_cublas_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ -+#if CUTLASS_ENABLE_CUBLAS -+ -+ library::GemmDescription const &gemm_desc = -+ static_cast(operation->description()); -+ -+ // -+ // Construct cuBLAS operators -+ // -+ -+ CublasCreate handle; -+ cublasStatus_t status = handle.get_cublas_create_status(); -+ -+ if (status != CUBLAS_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = get_cutlass_disposition(status); -+ return true; -+ } -+ -+ std::vector algorithms; -+ -+ detail::select_cublas_algorithms( -+ algorithms, -+ options, -+ gemm_desc); -+ -+ if (algorithms.empty()) { -+ // no algorithm selected -+ return true; -+ } -+ -+ // -+ // Initialize state -+ 
// -+ -+ try { -+ -+ // -+ // Construct dispatcher to cublasGemmEx() -+ // -+ -+ // Initialize structure containing GEMM arguments -+ gemm_workspace_.arguments.A = gemm_workspace_.A->data(); -+ gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); -+ gemm_workspace_.arguments.B = gemm_workspace_.B->data(); -+ gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); -+ gemm_workspace_.arguments.C = gemm_workspace_.Reference->data(); -+ gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.Reference->batch_stride(); -+ gemm_workspace_.arguments.D = gemm_workspace_.Reference->data(); -+ gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Reference->batch_stride(); -+ gemm_workspace_.arguments.alpha = problem_.alpha.data(); -+ gemm_workspace_.arguments.beta = problem_.beta.data(); -+ gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ detail::cublasGemmExDispatcher gemm_op( -+ gemm_desc, -+ gemm_workspace_.configuration, -+ gemm_workspace_.arguments, -+ algorithms.front() -+ ); -+ -+ if (gemm_op.status != Status::kSuccess) { -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotRun; -+ return true; -+ } -+ -+ results_.back().status = Status::kSuccess; -+ -+ status = gemm_op(handle); -+ -+ // Handle errors -+ if (status != CUBLAS_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = get_cutlass_disposition(status); -+ return true; -+ } -+ -+ // -+ // Verify results -+ // -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = compare_tensors( -+ options, -+ *gemm_workspace_.Computed, -+ *gemm_workspace_.Reference, -+ gemm_workspace_.Computed->batch_stride() -+ ); -+ -+ // Save workspace if incorrect -+ if (options.verification.save_workspace == SaveWorkspace::kIncorrect && -+ results_.back().verification_map[library::Provider::kCUBLAS] == Disposition::kIncorrect) { -+ -+ save_workspace( -+ device_context, -+ options, -+ gemm_desc, -+ library::Provider::kCUTLASS, -+ library::Provider::kCUBLAS); -+ } -+ } -+ catch (...) { -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ } -+ -+#endif -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against host and device references -+bool GemmOperationProfiler::verify_with_reference_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::GemmDescription const &gemm_desc = -+ static_cast(operation->description()); -+ -+ // -+ // Initialize state -+ // -+ -+ library::Provider references[] = { -+ library::Provider::kReferenceDevice, -+ library::Provider::kReferenceHost -+ }; -+ -+ for (auto provider : references) { -+ -+ // Skip providers that are not enabled -+ if (!options.verification.provider_enabled(provider)) { -+ continue; -+ } -+ -+ void *ptr_A = gemm_workspace_.A->data(); -+ void *ptr_B = gemm_workspace_.B->data(); -+ void *ptr_C = gemm_workspace_.C->data(); -+ void *ptr_D = gemm_workspace_.Reference->data(); -+ -+ // To support the host-side reference, conditionally allocate and -+ // copy tensors to host memory. 
-+ std::vector host_data_A; -+ std::vector host_data_B; -+ std::vector host_data_C; -+ std::vector host_data_D; -+ -+ if (provider == library::Provider::kReferenceHost) { -+ -+ host_data_A.resize(gemm_workspace_.A->bytes()); -+ ptr_A = host_data_A.data(); -+ gemm_workspace_.A->copy_to_host(ptr_A); -+ -+ host_data_B.resize(gemm_workspace_.B->bytes()); -+ ptr_B = host_data_B.data(); -+ gemm_workspace_.B->copy_to_host(ptr_B); -+ -+ host_data_C.resize(gemm_workspace_.C->bytes()); -+ ptr_C = host_data_C.data(); -+ gemm_workspace_.C->copy_to_host(ptr_C); -+ -+ host_data_D.resize(gemm_workspace_.Reference->bytes()); -+ ptr_D = host_data_D.data(); -+ } -+ -+ // -+ // Launch -+ // -+ -+ library::Handle handle; -+ -+ handle.set_provider(provider); -+ -+ Status status = handle.gemm_universal( -+ problem_.mode, -+ gemm_workspace_.configuration.problem_size.m(), -+ gemm_workspace_.configuration.problem_size.n(), -+ gemm_workspace_.configuration.problem_size.k(), -+ gemm_desc.tile_description.math_instruction.element_accumulator, -+ gemm_desc.element_epilogue, -+ -+ problem_.alpha.data(), -+ -+ gemm_desc.A.element, -+ gemm_desc.A.layout, -+ gemm_desc.transform_A, -+ ptr_A, -+ int(gemm_workspace_.configuration.lda), -+ -+ gemm_desc.B.element, -+ gemm_desc.B.layout, -+ gemm_desc.transform_B, -+ ptr_B, -+ int(gemm_workspace_.configuration.ldb), -+ -+ problem_.beta.data(), -+ -+ gemm_desc.C.element, -+ ptr_C, -+ int(gemm_workspace_.configuration.ldc), -+ -+ ptr_D, -+ int(gemm_workspace_.configuration.ldd), -+ -+ gemm_workspace_.configuration.batch_count, -+ gemm_workspace_.A->batch_stride(), -+ gemm_workspace_.B->batch_stride(), -+ gemm_workspace_.C->batch_stride(), -+ gemm_workspace_.Reference->batch_stride() -+ ); -+ -+ if (status != Status::kSuccess) { -+ results_.back().verification_map[provider] = Disposition::kNotRun; -+ return true; -+ } -+ -+ results_.back().status = status; -+ -+ if (provider == library::Provider::kReferenceHost) { -+ gemm_workspace_.Reference->copy_from_host(ptr_D); -+ } -+ -+ // -+ // Verify results -+ // -+ -+ results_.back().verification_map[provider] = compare_tensors( -+ options, -+ *gemm_workspace_.Computed, -+ *gemm_workspace_.Reference, -+ gemm_workspace_.Computed->batch_stride() -+ ); -+ -+ // Save workspace if incorrect -+ if (options.verification.save_workspace == SaveWorkspace::kIncorrect && -+ results_.back().verification_map[provider] == Disposition::kIncorrect) { -+ -+ save_workspace( -+ device_context, -+ options, -+ gemm_desc, -+ library::Provider::kCUTLASS, -+ provider); -+ } -+ } -+ -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Measures performance results -+bool GemmOperationProfiler::profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ // Initialize structure containing GEMM arguments -+ gemm_workspace_.arguments.A = gemm_workspace_.A->data(); -+ gemm_workspace_.arguments.B = gemm_workspace_.B->data(); -+ gemm_workspace_.arguments.C = gemm_workspace_.C->data(); -+ gemm_workspace_.arguments.D = gemm_workspace_.Computed->data(); -+ gemm_workspace_.arguments.alpha = problem_.alpha.data(); -+ gemm_workspace_.arguments.beta = problem_.beta.data(); -+ gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ 
gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); -+ gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); -+ gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); -+ gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride(); -+ -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data(); -+ gemm_workspace_.arguments.alpha = problem_.alpha_one.data(); -+ gemm_workspace_.arguments.beta = problem_.beta_zero.data(); -+ -+ gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data(); -+ gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data(); -+ gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data(); -+ gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data(); -+ gemm_workspace_.reduction_arguments.beta = problem_.beta.data(); -+ gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ } -+ -+ results_.back().status = profile_cutlass_( -+ results_.back().runtime, -+ options, -+ operation, -+ &gemm_workspace_.arguments, -+ gemm_workspace_.host_workspace.data(), -+ gemm_workspace_.device_workspace.data() -+ ); -+ } -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Method to profile a CUTLASS Operation -+Status GemmOperationProfiler::profile_cutlass_( -+ double &runtime, -+ Options const &options, -+ library::Operation const *operation, -+ void *arguments, -+ void *host_workspace, -+ void *device_workspace) { -+ -+ GpuTimer timer; -+ -+ // initialize gemm underlying operation to handle parallel reduction -+ library::Operation const * underlying_operation = operation; -+ -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ if (!(underlying_operation = library::find_gemm_operation_for_parallel_reduction(operation))) { -+ return Status::kErrorNotSupported; -+ } -+ } -+ -+ // -+ // Optional sleep to limit power consumption and thermals -+ // -+ -+ sleep(options.profiling.sleep_duration); -+ -+ // -+ // Warmup loop -+ // -+ -+ Status status; -+ -+ for (int iteration = 0; iteration < options.profiling.warmup_iterations; ++iteration) { -+ -+ int problem_idx = (iteration % gemm_workspace_.problem_count) * problem_.batch_count; -+ -+ gemm_workspace_.arguments.A = gemm_workspace_.A->batch_data(problem_idx); -+ gemm_workspace_.arguments.B = gemm_workspace_.B->batch_data(problem_idx); -+ gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx); -+ gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx); -+ -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data(); -+ -+ gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data(); -+ gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->batch_data(problem_idx); -+ gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->batch_data(problem_idx); -+ } -+ -+ // Execute the CUTLASS operation -+ status = underlying_operation->run( -+ &gemm_workspace_.arguments, -+ host_workspace, -+ device_workspace); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ // Run parallel reduction kernel for parallel split_k_mode -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ status = 
reduction_op_->run( -+ &gemm_workspace_.reduction_arguments, -+ gemm_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ } -+ } -+ -+ // -+ // Initialize GPU timer -+ // -+ -+ timer.start(); -+ -+ // -+ // Profiling loop -+ // -+ -+ int Iterations = options.profiling.iterations; -+ -+ int iteration = 0; -+ for (; iteration < Iterations; ++iteration) { -+ -+ // Iterate over copies of the problem in memory -+ int workspace_idx = options.profiling.warmup_iterations + iteration; -+ int problem_idx = (workspace_idx % gemm_workspace_.problem_count) * problem_.batch_count; -+ -+ gemm_workspace_.arguments.A = gemm_workspace_.A->batch_data(problem_idx); -+ gemm_workspace_.arguments.B = gemm_workspace_.B->batch_data(problem_idx); -+ gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx); -+ gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx); -+ -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data(); -+ -+ gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data(); -+ gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->batch_data(problem_idx); -+ gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->batch_data(problem_idx); -+ } -+ -+ status = underlying_operation->run( -+ arguments, -+ host_workspace, -+ device_workspace); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ // Run parallel reduction kernel for parallel split_k_mode -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ status = reduction_op_->run( -+ &gemm_workspace_.reduction_arguments, -+ gemm_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ } -+ } -+ -+ // -+ // Wait for completion -+ // -+ -+ timer.stop_and_wait(); -+ -+ // -+ // Update performance result -+ // -+ -+ runtime = timer.duration(iteration); -+ -+ return status; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/gemm_operation_profiler.h b/3rdparty/cutlass/tools/profiler/src/gemm_operation_profiler.h -new file mode 100644 -index 0000000..efee650 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/gemm_operation_profiler.h -@@ -0,0 +1,269 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines a math function -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include -+ -+// CUTLASS Library includes -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+#include "cutlass/library/manifest.h" -+ -+// Profiler includes -+#include "options.h" -+#include "device_context.h" -+#include "operation_profiler.h" -+#include "performance_result.h" -+#include "problem_space.h" -+#include "reduction_operation_profiler.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Abstract base class for each math function -+class GemmOperationProfiler : public OperationProfiler { -+public: -+ -+ /// Problem structure obtained from problem space -+ struct GemmProblem { -+ -+ cutlass::library::GemmUniversalMode mode; -+ cutlass::library::SplitKMode split_k_mode; -+ int64_t m; -+ int64_t n; -+ int64_t k; -+ int64_t lda; -+ int64_t ldb; -+ int64_t ldc; -+ std::vector alpha; -+ std::vector beta; -+ int split_k_slices; -+ int batch_count; -+ -+ // gemm with parallel interleaved reduction -+ // gemm epilogue (alpha, beta) = (1.0, 0.0) -+ // reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta) -+ std::vector alpha_one; -+ std::vector beta_zero; -+ -+ // -+ // Methods -+ // -+ -+ GemmProblem(): -+ mode(library::GemmUniversalMode::kGemm), -+ m(16), n(16), k(16), lda(0), ldb(0), ldc(0), split_k_slices(1), batch_count(1) { } -+ -+ /// Parses the problem -+ Status parse( -+ library::GemmDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Total number of bytes loaded -+ int64_t bytes(library::GemmDescription const &operation_desc) const; -+ -+ /// Total number of flops computed -+ int64_t flops(library::GemmDescription const &operation_desc) const; -+ -+ /// Initializes a performance result -+ void initialize_result( -+ PerformanceResult &result, -+ library::GemmDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ }; -+ -+ /// Workspace used -+ struct GemmWorkspace { -+ -+ DeviceAllocation *A; -+ DeviceAllocation *B; -+ DeviceAllocation *C; -+ DeviceAllocation *Computed; -+ DeviceAllocation *Reference; -+ -+ /// Number of 
copies of the problem workspace which are visited sequentially during -+ /// profiling to avoid camping in the last level cache. -+ int problem_count; -+ -+ library::GemmUniversalConfiguration configuration; -+ library::GemmUniversalArguments arguments; -+ -+ /// Buffer used for the operation's host workspace -+ std::vector host_workspace; -+ -+ /// Buffer used for the operations' device workspace -+ DeviceAllocation device_workspace; -+ -+ /// Library configuration and arguments for reduction operator -+ library::ReductionConfiguration reduction_configuration; -+ library::ReductionArguments reduction_arguments; -+ -+ /// Buffer used for the cutlass reduction operations' host workspace -+ std::vector reduction_host_workspace; -+ -+ // -+ // Methods -+ // -+ -+ GemmWorkspace(): -+ A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr), problem_count(1) { } -+ }; -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ /// GEMM problem obtained from problem space -+ GemmProblem problem_; -+ -+ /// Device memory allocations -+ GemmWorkspace gemm_workspace_; -+ -+ /// CUTLASS parallel reduction operation to follow this* gemm operation -+ library::Operation const *reduction_op_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ GemmOperationProfiler(Options const &options); -+ -+ /// Destructor -+ virtual ~GemmOperationProfiler(); -+ -+ /// Prints usage statement for the math function -+ virtual void print_usage(std::ostream &out) const; -+ -+ /// Prints examples -+ virtual void print_examples(std::ostream &out) const; -+ -+ /// Extracts the problem dimensions -+ virtual Status initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes workspace -+ virtual Status initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against references -+ virtual bool verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Measures performance results -+ virtual bool profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+protected: -+ -+ /// Initializes the performance result -+ void initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::GemmDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ -+ /// Verifies CUTLASS against references -+ bool verify_with_cublas_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against host and device references -+ bool verify_with_reference_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Method 
to profile a CUTLASS Operation -+ Status profile_cutlass_( -+ double &runtime, -+ Options const &options, -+ library::Operation const *operation, -+ void *arguments, -+ void *host_workspace, -+ void *device_workspace); -+ -+ /// Initialize reduction problem dimensions and library::Operation -+ bool initialize_reduction_configuration_( -+ library::Operation const *operation, -+ ProblemSpace::Problem const &problem); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/profiler/src/gpu_timer.h b/3rdparty/cutlass/tools/profiler/src/gpu_timer.h -new file mode 100644 -index 0000000..d8bce95 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/gpu_timer.h -@@ -0,0 +1,72 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines a math function -+*/ -+ -+#pragma once -+ -+#include -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct GpuTimer { -+ -+ cudaEvent_t events[2]; -+ -+ // -+ // Methods -+ // -+ -+ GpuTimer(); -+ ~GpuTimer(); -+ -+ /// Records a start event in the stream -+ void start(cudaStream_t stream = nullptr); -+ -+ /// Records a stop event in the stream -+ void stop(cudaStream_t stream = nullptr); -+ -+ /// Records a stop event in the stream and synchronizes on the stream -+ void stop_and_wait(cudaStream_t stream = nullptr); -+ -+ /// Returns the duration in miliseconds -+ double duration(int iterations = 1) const; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/profiler/src/operation_profiler.cu b/3rdparty/cutlass/tools/profiler/src/operation_profiler.cu -new file mode 100644 -index 0000000..b2e8f9b ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/operation_profiler.cu -@@ -0,0 +1,691 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
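For orientation, the GpuTimer declared above is backed by a pair of CUDA events, and the profiling code earlier in this patch uses it in a warm-up-then-timed-loop pattern. The following is a minimal, self-contained sketch of that pattern using the standard CUDA runtime event API; SimpleGpuTimer, dummy_kernel, and the iteration counts are illustrative placeholders, not part of the profiler.

#include <cuda_runtime.h>
#include <cstdio>

// Minimal CUDA-event timer mirroring the GpuTimer interface declared above.
struct SimpleGpuTimer {
  cudaEvent_t events[2];
  SimpleGpuTimer()  { cudaEventCreate(&events[0]); cudaEventCreate(&events[1]); }
  ~SimpleGpuTimer() { cudaEventDestroy(events[0]); cudaEventDestroy(events[1]); }
  void start(cudaStream_t stream = nullptr) { cudaEventRecord(events[0], stream); }
  void stop_and_wait(cudaStream_t stream = nullptr) {
    cudaEventRecord(events[1], stream);
    cudaEventSynchronize(events[1]);
  }
  // Average duration in milliseconds over 'iterations' launches.
  double duration(int iterations = 1) const {
    float ms = 0.0f;
    cudaEventElapsedTime(&ms, events[0], events[1]);
    return double(ms) / double(iterations);
  }
};

__global__ void dummy_kernel() {}   // placeholder workload

int main() {
  SimpleGpuTimer timer;
  int const warmup = 10, iterations = 100;
  for (int i = 0; i < warmup; ++i) {       // warm-up launches are not timed
    dummy_kernel<<<1, 32>>>();
  }
  timer.start();
  for (int i = 0; i < iterations; ++i) {   // timed region spans all iterations
    dummy_kernel<<<1, 32>>>();
  }
  timer.stop_and_wait();                   // single stop event after the loop
  std::printf("avg runtime: %f ms\n", timer.duration(iterations));
  return 0;
}

This mirrors the structure of profile_cutlass_ shown earlier: warm-up launches run before the start event, the stop event is recorded once after the timed loop, and the elapsed time is divided by the iteration count to report a per-iteration runtime.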
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines a math function -+*/ -+ -+#include -+#include -+#include -+#include -+#include -+#include -+ -+#ifdef __unix__ -+#include -+#elif defined(_WIN32) || defined(WIN32) -+#include -+#else -+// sleep not supported -+#endif -+ -+#include "options.h" -+#include "operation_profiler.h" -+#include "gpu_timer.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+OperationProfiler::OperationProfiler(): kind_(library::OperationKind::kInvalid) { } -+ -+/// Ctor -+OperationProfiler::OperationProfiler( -+ Options const &options, -+ library::OperationKind kind, -+ ArgumentDescriptionVector const &arguments, -+ ProviderVector const & verification_providers -+): -+ kind_(kind), arguments_(arguments) { -+ -+ ArgumentDescriptionVector tile_description_arguments{ -+ {ArgumentTypeID::kEnumerated, {"op_class", "opcode-class"}, "Class of math instruction (simt, tensorop, wmmatensorop, wmma)"}, -+ {ArgumentTypeID::kEnumerated, {"accum", "accumulator-type"}, "Math instruction accumulator data type"}, -+ {ArgumentTypeID::kInteger, {"cta_m", "threadblock-shape::m"}, "Threadblock shape in the M dimension"}, -+ {ArgumentTypeID::kInteger, {"cta_n", "threadblock-shape::n"}, "Threadblock shape in the N dimension"}, -+ {ArgumentTypeID::kInteger, {"cta_k", "threadblock-shape::k"}, "Threadblock shape in the K dimension"}, -+ {ArgumentTypeID::kInteger, {"cluster_m", "cluster-shape::m"}, "Cluster shape in the M dimension"}, -+ {ArgumentTypeID::kInteger, {"cluster_n", "cluster-shape::n"}, "Cluster shape in the N dimension"}, -+ {ArgumentTypeID::kInteger, {"cluster_k", "cluster-shape::k"}, "Cluster shape in the K dimension"}, -+ {ArgumentTypeID::kInteger, {"stages", "threadblock-stages"}, "Number of stages of threadblock-scoped matrix multiply"}, -+ {ArgumentTypeID::kInteger, {"warps_m", "warp-count::m"}, "Number of warps within threadblock along the M dimension"}, -+ {ArgumentTypeID::kInteger, {"warps_n", "warp-count::n"}, "Number of warps within threadblock along the N dimension"}, -+ {ArgumentTypeID::kInteger, {"warps_k", "warp-count::k"}, "Number of warps within threadblock along the K dimension"}, -+ {ArgumentTypeID::kInteger, {"inst_m", "instruction-shape::m"}, "Math instruction shape in the M dimension"}, -+ {ArgumentTypeID::kInteger, {"inst_n", "instruction-shape::n"}, "Math instruction shape in the N dimension"}, -+ {ArgumentTypeID::kInteger, {"inst_k", "instruction-shape::k"}, "Math instruction shape in the K dimension"}, -+ {ArgumentTypeID::kInteger, {"min_cc", "minimum-compute-capability"}, "Minimum device compute capability"}, -+ {ArgumentTypeID::kInteger, {"max_cc", "maximum-compute-capability"}, "Maximum device compute capability"} -+ }; -+ -+ arguments_.insert(arguments_.end(), tile_description_arguments.begin(), tile_description_arguments.end()); -+ -+ for (auto provider : verification_providers) { -+ if (std::find( -+ options.verification.providers.begin(), -+ options.verification.providers.end(), -+ provider) != options.verification.providers.end()) { -+ -+ verification_providers_.push_back(provider); -+ } -+ } -+} -+ -+/// Destructor -+OperationProfiler::~OperationProfiler() { -+ -+} -+ -+/// Gets the schema description -+std::string const & 
OperationProfiler::description() const { -+ return description_; -+} -+ -+/// Prints usage statement for the math function -+void OperationProfiler::print_usage(std::ostream &out) const { -+ for (auto const & desc : arguments_) { -+ -+ size_t const kAliasStart = 10; -+ -+ size_t columns = 0; -+ -+ std::string type_str = to_string(desc.type); -+ columns += type_str.size(); -+ -+ out << " [" << type_str << "]"; -+ -+ if (columns < kAliasStart) { -+ out << std::string(kAliasStart - columns, ' '); -+ } -+ -+ columns = 0; -+ -+ int j = 0; -+ for (auto const & alias : desc.aliases) { -+ columns += alias.size() + (j ? 1 : 0) + 2; -+ -+ out << (j++ ? "," : "") << "--" << alias; -+ } -+ -+ size_t const kTotalColumns = 50; -+ -+ if (columns < kTotalColumns) { -+ out << std::string(kTotalColumns - columns, ' '); -+ } -+ -+ out << desc.description << "\n"; -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns true if the current operation description satisfies the problem space -+bool OperationProfiler::satisfies( -+ library::OperationDescription const &op_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::OpcodeClassID opcode_class; -+ if (arg_as_OpcodeClassID(opcode_class, "op_class", problem_space, problem)) { -+ if (opcode_class != op_desc.tile_description.math_instruction.opcode_class) { -+ return false; -+ } -+ } -+ -+ int64_t int_value; -+ -+ if (arg_as_int(int_value, "inst_m", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.math_instruction.instruction_shape.m()) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "inst_n", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.math_instruction.instruction_shape.n()) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "inst_k", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.math_instruction.instruction_shape.k()) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "cta_m", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.threadblock_shape.m()) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "cta_n", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.threadblock_shape.n()) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "cta_k", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.threadblock_shape.k()) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "cluster_m", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.cluster_shape.m()) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "cluster_n", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.cluster_shape.n()) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "cluster_k", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.cluster_shape.k()) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "stages", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.threadblock_stages) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "warps_m", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.warp_count.m()) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "warps_n", problem_space, problem)) { -+ if 
(int64_t(op_desc.tile_description.warp_count.n()) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "warps_k", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.warp_count.k()) != int_value) { -+ return false; -+ } -+ } -+ -+ library::NumericTypeID numeric_type; -+ if (arg_as_NumericTypeID(numeric_type, "accum", problem_space, problem)) { -+ if (numeric_type != op_desc.tile_description.math_instruction.element_accumulator) { -+ return false; -+ } -+ } -+ -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Entry point to profile all operations in the manifest -+int OperationProfiler::profile_all( -+ Options const &options, -+ library::Manifest const &manifest, -+ DeviceContext &device_context) { -+ -+ ProblemSpace problem_space(arguments_, options.cmdline); -+ -+ // 1. Construct performance report -+ PerformanceReport report(options, problem_space.argument_names(), kind_); -+ -+ // 2. For each problem in problem space -+ ProblemSpace::Iterator problem_it = problem_space.begin(); -+ ProblemSpace::Iterator problem_end = problem_space.end(); -+ -+ bool continue_profiling = true, internal_error = false; -+ -+ // For each problem in problem space -+ for (; continue_profiling && problem_it != problem_end; ++problem_it) { -+ -+ ProblemSpace::Problem problem = problem_it.at(); -+ -+ report.next_problem(); -+ -+ // For each operation in manifest -+ for (auto const & operation_ptr : manifest) { -+ -+ library::Operation const *operation = operation_ptr.get(); -+ -+ auto min_cc = operation->description().tile_description.minimum_compute_capability; -+ auto max_cc = operation->description().tile_description.maximum_compute_capability; -+ -+ // Clear named allocations -+ device_context.free(); -+ -+ // Execute compatible cutlass operations if they satisfy the current device's compute capability -+ if (operation->description().kind == kind_ && -+ operation->description().provider == library::Provider::kCUTLASS && -+ options.device.compute_capability() >= min_cc && -+ options.device.compute_capability() <= max_cc) { -+ -+ std::string operation_name(operation->description().name); -+ -+ // Filter kernels by name -+ bool filtered_by_name = options.operation_names.empty(); -+ if (!filtered_by_name) { -+ -+ for (auto const & op_name : options.operation_names) { -+ if (find_string_matches_(op_name, operation_name)) { -+ filtered_by_name = true; -+ break; -+ } -+ } -+ } -+ -+ for (auto const & op_name : options.excluded_operation_names) { -+ if (find_string_matches_(op_name, operation_name)) { -+ filtered_by_name = false; -+ break; -+ } -+ } -+ -+ if (!filtered_by_name || !satisfies(operation->description(), problem_space, problem)) { -+ continue; -+ } -+ -+ // A. Initialize configuration -+ Status status = this->initialize_configuration( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ -+ if (status == Status::kErrorInternal) { -+ -+ // If there was an internal error, consume the CUDA error and move to the next operation. -+ (void)cudaGetLastError(); -+ -+ report.append_results(results_); -+ continue; -+ } -+ else if (status != Status::kSuccess) { -+ // If the workspace could not be initialized for any other reason, continue to -+ // the next operation. 
-+ continue; -+ } -+ -+ if (continue_profiling) { -+ -+ status = this->initialize_workspace( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ -+ if (status == Status::kErrorInternal) { -+ -+ // If there was an internal error, consume the CUDA error and move to the next operation. -+ (void)cudaGetLastError(); -+ -+ report.append_results(results_); -+ continue; -+ } -+ else if (status != Status::kSuccess) { -+ // If the workspace could not be initialized for any other reason, continue to -+ // the next operation. -+ continue; -+ } -+ } -+ -+ // -+ // Profile CUTLASS if it is enabled -+ // -+ -+ // B. Verify CUTLASS -+ -+ if (continue_profiling && options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ continue_profiling = this->verify_cutlass( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ if (options.execution_mode == ExecutionMode::kDryRun) { -+ report.append_results(results_); -+ results_.clear(); -+ continue; -+ } -+ -+ // -+ // C. Optionally save workspace -+ // -+ -+ if (options.verification.save_workspace == SaveWorkspace::kAlways) { -+ save_workspace( -+ device_context, -+ options, -+ operation->description(), -+ library::Provider::kCUTLASS); -+ } -+ -+ // -+ // D. Profile -+ // -+ -+ if (continue_profiling && options.profiling.enabled) { -+ -+ continue_profiling = this->profile( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ report.append_results(results_); -+ results_.clear(); -+ } -+ -+ if (!continue_profiling) { -+ break; -+ } -+ } -+ } -+ -+ return internal_error ? 1 : 0; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Sleep for a given duration in ms -+void OperationProfiler::sleep(int sleep_duration) { -+ if (sleep_duration) { -+ #ifdef __unix__ -+ usleep(sleep_duration * 1000); -+ #elif defined(_WIN32) || defined(WIN32) -+ SleepEx(sleep_duration, false); -+ #else -+ // sleep not supported -+ #endif -+ } -+} -+ -+ -+/// Compares tensors for equality -+Disposition OperationProfiler::compare_tensors( -+ Options const &options, -+ DeviceAllocation &experimental, -+ DeviceAllocation &reference, -+ int64_t count) { -+ -+ if (experimental.type() != reference.type()) { -+ return Disposition::kIncorrect; -+ } -+ -+ bool passed = false; -+ -+ if (count == 0) { -+ count = reference.capacity(); -+ } -+ -+ if (options.verification.epsilon == 0) { -+ -+ // bit-level equality -+ passed = DeviceAllocation::block_compare_equal( -+ experimental.type(), -+ experimental.data(), -+ reference.data(), -+ count); -+ } -+ else { -+ -+ // relative error function -+ passed = DeviceAllocation::block_compare_relatively_equal( -+ experimental.type(), -+ experimental.data(), -+ reference.data(), -+ count, -+ options.verification.epsilon, -+ options.verification.nonzero_floor); -+ } -+ -+ return passed ? 
Disposition::kPassed : Disposition::kIncorrect; -+} -+ -+/// Saves the workspace -+void OperationProfiler::save_workspace( -+ DeviceContext &device_context, -+ Options const &options, -+ library::OperationDescription const &desc, -+ library::Provider provider, -+ library::Provider verification_provider) { -+ -+ for (auto const & named_allocation : device_context) { -+ -+ DeviceAllocation *allocation = named_allocation.second; -+ -+ std::stringstream filename; -+ -+ filename << desc.name << "_" << library::to_string(provider) << "_"; -+ -+ if (verification_provider != library::Provider::kInvalid) { -+ filename << "verified_by_" << library::to_string(verification_provider) << "_"; -+ } -+ -+ filename << named_allocation.first + ".mat"; -+ -+ std::ofstream out(filename.str()); -+ -+ allocation->write_tensor_csv(out); -+ out << "\n"; -+ -+ if (options.report.verbose) { -+ std::cout << "wrote '" << filename.str() << "'" << std::endl; -+ } -+ } -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Method to profile a CUTLASS Operation -+Status OperationProfiler::profile_cutlass_( -+ double &runtime, -+ Options const &options, -+ library::Operation const *operation, -+ void *arguments, -+ void *host_workspace, -+ void *device_workspace) { -+ -+ GpuTimer timer; -+ -+ // -+ // Optional sleep to limit power consumption and thermals -+ // -+ -+ sleep(options.profiling.sleep_duration); -+ -+ // -+ // Warmup loop -+ // -+ -+ Status status; -+ -+ for (int iteration = 0; iteration < options.profiling.warmup_iterations; ++iteration) { -+ -+ status = operation->run( -+ arguments, -+ host_workspace, -+ device_workspace); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ } -+ -+ // -+ // Initialize GPU timer -+ // -+ -+ timer.start(); -+ -+ // -+ // Profiling loop -+ // -+ -+ int Iterations = options.profiling.iterations; -+ -+ int iteration = 0; -+ for (; iteration < Iterations; ++iteration) { -+ -+ status = operation->run( -+ arguments, -+ host_workspace, -+ device_workspace); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ } -+ -+ // -+ // Wait for completion -+ // -+ -+ timer.stop_and_wait(); -+ -+ // -+ // Update performance result -+ // -+ -+ runtime = timer.duration(iteration); -+ -+ return status; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Sets operation description -+void OperationProfiler::initialize_result_( -+ PerformanceResult &result, -+ library::OperationDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ set_argument(result, "op_class", problem_space, -+ library::to_string(operation_desc.tile_description.math_instruction.opcode_class)); -+ -+ set_argument(result, "accum", problem_space, -+ library::to_string(operation_desc.tile_description.math_instruction.element_accumulator)); -+ -+ set_argument(result, "cta_m", problem_space, operation_desc.tile_description.threadblock_shape.m()); -+ set_argument(result, "cta_n", problem_space, operation_desc.tile_description.threadblock_shape.n()); -+ set_argument(result, "cta_k", problem_space, operation_desc.tile_description.threadblock_shape.k()); -+ set_argument(result, "cluster_m", problem_space, operation_desc.tile_description.cluster_shape.m()); -+ set_argument(result, "cluster_n", problem_space, operation_desc.tile_description.cluster_shape.n()); -+ set_argument(result, "cluster_k", problem_space, 
operation_desc.tile_description.cluster_shape.k()); -+ set_argument(result, "stages", problem_space, operation_desc.tile_description.threadblock_stages); -+ set_argument(result, "warps_m", problem_space, operation_desc.tile_description.warp_count.m()); -+ set_argument(result, "warps_n", problem_space, operation_desc.tile_description.warp_count.n()); -+ set_argument(result, "warps_k", problem_space, operation_desc.tile_description.warp_count.k()); -+ set_argument(result, "inst_m", problem_space, operation_desc.tile_description.math_instruction.instruction_shape.m()); -+ set_argument(result, "inst_n", problem_space, operation_desc.tile_description.math_instruction.instruction_shape.n()); -+ set_argument(result, "inst_k", problem_space, operation_desc.tile_description.math_instruction.instruction_shape.k()); -+ set_argument(result, "min_cc", problem_space, operation_desc.tile_description.minimum_compute_capability); -+ set_argument(result, "max_cc", problem_space, operation_desc.tile_description.maximum_compute_capability); -+} -+ -+/// Helper -+void OperationProfiler::set_argument( -+ PerformanceResult &result, -+ char const *name, -+ ProblemSpace const &problem_space, -+ std::string const &value) { -+ -+ result.arguments.at(problem_space.argument_index(name)) = make_pair(std::string(name), value); -+} -+ -+void OperationProfiler::set_argument( -+ PerformanceResult &result, -+ char const *name, -+ ProblemSpace const &problem_space, -+ int64_t value) { -+ -+ result.arguments.at(problem_space.argument_index(name)) = make_pair(std::string(name), library::lexical_cast(value)); -+} -+ -+ -+/// finds string matches filter_string in operation_name -+bool OperationProfiler::find_string_matches_( -+ std::string const &filter_string, -+ std::string const &operation_name) { -+ // Returns true if all substrings appear in the operation_name in order -+ -+ // Split filter_string of the format "gemm*f32*nt" to tokens ["gemm", "f32", "nt"] -+ std::string item; -+ std::istringstream iss(filter_string); -+ std::vector filter_tokens; -+ while (std::getline(iss, item, '*')) { -+ filter_tokens.push_back(item); -+ } -+ -+ // Search filter_tokens in operation_name in order -+ size_t start = 0, idx = 0; -+ for(auto & token : filter_tokens) { -+ // Check if characters left to be parsed in operation_name -+ if (start < operation_name.length()) { -+ // Find token in operation_name[start:] -+ idx = operation_name.substr(start).find(token); -+ if (idx == std::string::npos) { -+ return false; -+ } -+ } -+ start += (idx + token.length()); -+ } -+ -+ // All tokens in filter_string found in operation_name -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/operation_profiler.h b/3rdparty/cutlass/tools/profiler/src/operation_profiler.h -new file mode 100644 -index 0000000..a2b0bdd ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/operation_profiler.h -@@ -0,0 +1,256 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
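As a quick illustration of the token matching implemented by find_string_matches_ above, the sketch below applies the same split-on-'*', match-in-order rule to two hypothetical operation names; the matches helper and the kernel names are illustrative only and are not part of the profiler.

#include <iostream>
#include <sstream>
#include <string>

// Same rule as find_string_matches_ above: split the filter on '*' and require
// every token to occur in the operation name, in order.
static bool matches(std::string const &filter, std::string const &name) {
  std::istringstream iss(filter);
  std::string token;
  size_t start = 0;
  while (std::getline(iss, token, '*')) {
    size_t idx = name.find(token, start);
    if (idx == std::string::npos) {
      return false;
    }
    start = idx + token.length();
  }
  return true;
}

int main() {
  // Hypothetical kernel names used only to demonstrate the filter behavior.
  std::cout << matches("gemm*f32*nt", "cutlass_simt_sgemm_f32_128x128_nt") << "\n";  // prints 1
  std::cout << matches("gemm*f32*nt", "cutlass_simt_sgemm_f32_128x128_tn") << "\n";  // prints 0
  return 0;
}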
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines a math function -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+ -+// CUTLASS Library includes -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+#include "cutlass/library/manifest.h" -+ -+// Profiler includes -+#include "options.h" -+#include "device_context.h" -+#include "performance_result.h" -+#include "performance_report.h" -+#include "problem_space.h" -+#include "debug.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Abstract base class for each math function -+class OperationProfiler { -+public: -+ -+ -+protected: -+ // -+ // Data members -+ // -+ -+ /// Top-level operation kind -+ library::OperationKind kind_; -+ -+ /// Human readable description -+ std::string description_; -+ -+ /// Arguments parsed from command line -+ ArgumentDescriptionVector arguments_; -+ -+ /// List of providers used to verify and compare each result -+ ProviderVector verification_providers_; -+ -+ /// Model performance result initailized by the operation profiler with workload statistics -+ /// and reasonable default state. 
-+ PerformanceResult model_result_; -+ -+ /// Performance result vector constructed by profiling the operation -+ PerformanceResultVector results_; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ OperationProfiler(); -+ -+ OperationProfiler( -+ Options const &options, -+ library::OperationKind kind, -+ ArgumentDescriptionVector const &arguments = ArgumentDescriptionVector(), -+ ProviderVector const & verification_providers = ProviderVector()); -+ -+ /// Destructor -+ virtual ~OperationProfiler(); -+ -+ /// Obtains the operation kind -+ library::OperationKind kind() const { return kind_; } -+ -+ /// Gets the schema description -+ std::string const &description() const; -+ -+ /// Returns a reference to the arguments -+ ArgumentDescriptionVector const &arguments() const { return arguments_; } -+ -+public: -+ -+ // -+ // Basic overrides -+ // -+ -+ -+ /// Prints usage statement for the math function -+ virtual void print_usage(std::ostream &out) const; -+ -+ /// Prints examples -+ virtual void print_examples(std::ostream &out) const =0; -+ -+ /// Entry point to profile all operations in the manifest -+ virtual int profile_all( -+ Options const &options, -+ library::Manifest const &manifest, -+ DeviceContext &device_context); -+ -+public: -+ -+ // -+ // Operation-specific phases of verification and profiling -+ // -+ -+ /// Extracts the problem dimensions -+ virtual Status initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) = 0; -+ -+ /// Initializes workspace -+ virtual Status initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) = 0; -+ -+ /// Verifies CUTLASS against references -+ virtual bool verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) = 0; -+ -+ /// Measures performance results -+ virtual bool profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) = 0; -+ -+public: -+ -+ // -+ // Static helpers -+ // -+ -+ /// Sleep for a given duration in ms -+ static void sleep(int sleep_duration); -+ -+ /// Returns true if the current operation description satisfies the problem space -+ static bool satisfies( -+ library::OperationDescription const &op_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Compares tensors for equality -+ static Disposition compare_tensors( -+ Options const &options, -+ DeviceAllocation &experimental, -+ DeviceAllocation &reference, -+ int64_t count = 0); -+ -+ static void save_workspace( -+ DeviceContext &device_context, -+ Options const &options, -+ library::OperationDescription const &desc, -+ library::Provider provider, -+ library::Provider verification_provider = library::Provider::kInvalid); -+ -+ /// Helper to set a performance result member -+ static void set_argument( -+ PerformanceResult &result, -+ char const *name, -+ ProblemSpace const &problem_space, -+ std::string const &value); -+ -+ /// Helper to set a 
performance result member -+ static void set_argument( -+ PerformanceResult &result, -+ char const *name, -+ ProblemSpace const &problem_space, -+ int64_t value); -+ -+protected: -+ -+ /// Sets operation description -+ static void initialize_result_( -+ PerformanceResult &result, -+ library::OperationDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ -+ /// Method to profile an initialized CUTLASS operation -+ virtual Status profile_cutlass_( -+ double &runtime, -+ Options const &options, -+ library::Operation const *operation, -+ void *arguments, -+ void *host_workspace, -+ void *device_workspace); -+ -+private: -+ /// finds string matches filter_string in operation_name -+ bool find_string_matches_( -+ std::string const &filter_string, -+ std::string const &operation_name); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Vector of owning operation profilers -+using OperationProfilerVector = std::vector>; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/options.cu b/3rdparty/cutlass/tools/profiler/src/options.cu -new file mode 100644 -index 0000000..ea79a9d ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/options.cu -@@ -0,0 +1,815 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Command line options for performance test program -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/version.h" -+ -+#include "cutlass/library/util.h" -+ -+#include "options.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Newline and indent for help strings -+static char const *end_of_line = "\n "; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+Options::Device::Device(cutlass::CommandLine const &cmdline) { -+ -+ cmdline.get_cmd_line_argument("device", device, 0); -+ -+ cudaError_t result; -+ result = cudaGetDeviceProperties(&properties, device); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed for given device"); -+ } -+ -+ result = cudaSetDevice(device); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaSetDevice() failed for given device."); -+ } -+ -+ // Permit overriding the compute capability -+ if (cmdline.check_cmd_line_flag("compute-capability")) { -+ int cc = compute_capability(); -+ cmdline.get_cmd_line_argument("compute-capability", cc, cc); -+ properties.major = cc / 10; -+ properties.minor = cc % 10; -+ } -+ -+ // Permit overriding the L2 cache capacity -+ if (cmdline.check_cmd_line_flag("llc-capacity")) { -+ int llc_capacity = 0; -+ cmdline.get_cmd_line_argument("llc-capacity", llc_capacity, 0); -+ -+ if (llc_capacity >= 0) { -+ properties.l2CacheSize = (llc_capacity << 10); -+ } -+ } -+ -+} -+ -+void Options::Device::print_usage(std::ostream &out) const { -+ -+ out << "Device:\n" -+ << " --device= " -+ << " CUDA Device ID\n\n"; -+ -+ int device_count = 0; -+ cudaError_t result = cudaGetDeviceCount(&device_count); -+ -+ if (result != cudaSuccess) { -+ out << " \n"; -+ } -+ else { -+ -+ for (int idx = 0; idx < device_count; ++idx) { -+ cudaDeviceProp prop; -+ result = cudaGetDeviceProperties(&prop, idx); -+ if (result != cudaSuccess) { -+ out << " " << std::endl; -+ break; -+ } -+ else { -+ out << " [" << idx << "] - " -+ << prop.name << " - SM " << prop.major << "." << prop.minor << ", " -+ << prop.multiProcessorCount << " SMs @ " << (prop.clockRate / 1000.0) << " MHz, " -+ << "L2 cache: " << (prop.l2CacheSize >> 20) << " MB, Global Memory: " << (prop.totalGlobalMem >> 30) << " GB" -+ << std::endl; -+ } -+ } -+ out << "\n"; -+ } -+ -+ out -+ << " --compute-capability= " -+ << " Override the compute capability.\n\n" -+ -+ << " --llc-capacity= " -+ << " Capacity of last-level cache in kilobytes. 
If this is non-zero," << end_of_line -+ << " profiling phases cycle through different input tensors to induce" << end_of_line -+ << " capacity misses in the L2.\n\n"; -+ -+} -+ -+void Options::Device::print_device_info(std::ostream &out) const { -+ int num_devices; -+ cudaDeviceProp props; -+ -+ cudaError_t result; -+ result = cudaGetDeviceCount(&num_devices); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetNumDevices() failed"); -+ } -+ -+ out << "Device Name,SM,CUDA Device ID,Phy Device ID" << std::endl; -+ -+ for(int device = 0; device < num_devices; device++) { -+ result = cudaSetDevice(device); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaSetDevice() failed for device"); -+ } -+ -+ result = cudaGetDeviceProperties(&props, device); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties failed for device"); -+ } -+ -+ out << props.name << "," << props.major << props.minor << "," -+ << device << "," << props.multiGpuBoardGroupID << std::endl; -+ -+ } -+} -+ -+void Options::Device::print_options(std::ostream &out, int indent) const { -+ -+ out -+ << indent_str(indent) << "device: " << device << "\n" -+ << indent_str(indent) << "clock: " << int(double(properties.clockRate) / 1000.0) << "\n" -+ << indent_str(indent) << "compute-capability: " << compute_capability() << "\n"; -+} -+ -+/// Returns the compute capability of the listed device (e.g. 61, 60, 70, 75) -+int Options::Device::compute_capability() const { -+ return properties.major * 10 + properties.minor; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+Options::Initialization::Initialization(cutlass::CommandLine const &cmdline) { -+ -+ cmdline.get_cmd_line_argument("initialization-enabled", enabled, true); -+ -+ if (cmdline.check_cmd_line_flag("initialization-provider")) { -+ std::string str; -+ cmdline.get_cmd_line_argument("initialization-provider", str); -+ provider = library::from_string(str); -+ if (provider == library::Provider::kInvalid) { -+ enabled = false; -+ } -+ else if (provider != library::Provider::kReferenceHost && provider != library::Provider::kReferenceDevice) { -+ throw std::runtime_error("Unsupported intialization provider specified."); -+ } -+ } -+ else { -+ provider = library::Provider::kReferenceDevice; -+ } -+ -+ cmdline.get_cmd_line_argument("seed", seed, 2019); -+ -+ if (cmdline.check_cmd_line_flag("dist")) { -+ // user has set the data distribution (fix data distribution once set) -+ fix_data_distribution = true; -+ // set user provided data distribution -+ get_distribution(cmdline, "dist", data_distribution); -+ } -+ else { -+ // profiler choosen data distribution (allowed to change based on numeric types) -+ fix_data_distribution = false; -+ // set uniform data distribution with range [-4, 4] -+ data_distribution.set_uniform(-4, 4, 0); -+ } -+ -+ -+} -+ -+/// Gets the initial distribution -+void Options::Initialization::get_distribution( -+ cutlass::CommandLine const &args, -+ std::string const &arg, -+ cutlass::Distribution &dist) { -+ -+ struct { -+ const char *label; -+ cutlass::Distribution::Kind kind; -+ } distribution_kinds[] = { -+ {"uniform", cutlass::Distribution::Uniform}, -+ {"gaussian", cutlass::Distribution::Gaussian}, -+ {"identity", cutlass::Distribution::Identity}, -+ {"sequential", cutlass::Distribution::Sequential}, -+ {0, cutlass::Distribution::Invalid} -+ }; -+ -+ struct { -+ char const *label; -+ double *member; -+ } members[] = { -+ {"min", 
&dist.uniform.min}, -+ {"max", &dist.uniform.max}, -+ {"mean", &dist.gaussian.mean}, -+ {"stddev", &dist.gaussian.stddev}, -+ {"start", &dist.sequential.start}, -+ {"delta", &dist.sequential.delta}, -+ {0, 0} -+ }; -+ -+ using KeyValueVector = std::vector >; -+ -+ KeyValueVector values; -+ args.get_cmd_line_argument_pairs(arg.c_str(), values); -+ -+ // The parser expects the first token to be a string identifying the distribution type. -+ auto it = values.begin(); -+ if (it != values.end()) { -+ for (int i = 0; distribution_kinds[i].label; ++i) { -+ if (it->first == distribution_kinds[i].label) { -+ dist.kind = distribution_kinds[i].kind; -+ break; -+ } -+ } -+ ++it; -+ } -+ -+ // Subsequent key-value pairs update the named field of the distribution struct. -+ for (; it != values.end(); ++it) { -+ // Integer scaling factor - if < 0, no integer rounding is performed. -+ if ((it->first.compare("scale") == 0) && !it->second.empty()) { -+ std::stringstream ss; -+ ss << it->second; -+ ss >> dist.int_scale; -+ continue; // next token -+ } -+ -+ // Casts as integer without scaling -+ if (it->first.compare("integer") == 0) { -+ dist.int_scale = 0; -+ continue; // next token -+ } -+ -+ // initialize other members -+ for (int m = 0; members[m].label; ++m) { -+ if (it->first == members[m].label && !it->second.empty()) { -+ std::stringstream ss; -+ ss << it->second; -+ ss >> *(members[m].member); -+ } -+ } -+ } -+} -+ -+void Options::Initialization::print_usage(std::ostream &out) const { -+ -+ out << "Initialization:\n" -+ -+ << " --initialization= " -+ << " Enables initialization (default: true). If false, device memory is" << end_of_line -+ << " not initialized after allocation.\n\n" -+ -+ << " --initialization-provider= " -+ << " Selects initialization provider {host, device*}. (default: '*')\n\n" -+ -+ << " --dist= " -+ << " Data distribution of input tensors {uniform*, gaussian, identity, sequential}" << end_of_line -+ << " --dist=uniform,min:,max:,scale:" << end_of_line -+ << " --dist=gaussian,mean:,stddev:,scale:" << end_of_line -+ << " --dist=sequential,start:,delta:,scale:" << end_of_line -+ << " --dist=identity\n\n" -+ -+ << " --seed= " -+ << " Random number generator seed. Used to enforce deterministic" << end_of_line -+ << " initialization.\n\n"; -+ -+} -+ -+void Options::Initialization::print_options(std::ostream &out, int indent) const { -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+Options::Library::Library(cutlass::CommandLine const &cmdline) { -+ -+ algorithm_mode = AlgorithmMode::kDefault; -+ -+ if (cmdline.check_cmd_line_flag("library-algo-mode")) { -+ std::string mode = "default"; -+ cmdline.get_cmd_line_argument("library-algo-mode", mode); -+ algorithm_mode = from_string(mode); -+ } -+ -+ if (cmdline.check_cmd_line_flag("library-algos")) { -+ -+ // If algorithms are specified, override as kBest. 
-+ algorithm_mode = AlgorithmMode::kBest; -+ -+ std::vector tokens; -+ cmdline.get_cmd_line_arguments("library-algos", tokens); -+ -+ algorithms.reserve(tokens.size()); -+ -+ for (auto const & token : tokens) { -+ if (token.find(":")) { -+ // todo - tokenized range -+ } -+ else { -+ int algo; -+ std::stringstream ss; -+ -+ ss << token; -+ ss >> algo; -+ -+ algorithms.push_back(algo); -+ } -+ } -+ } -+} -+ -+void Options::Library::print_usage(std::ostream &out) const { -+ -+ out << "Library:\n" -+ -+ << " --library-algo-mode= " -+ << " Indicates algorithm mode used to call libraries such as cuBLAS and cuDNN.\n" -+ << " " -+ << " mode={default*,matching,best}\n\n" -+ -+ << " --library-algos= " -+ << " If --algorithm-mode=best, permits specifying a selection of algorithms.\n\n"; -+ -+} -+ -+void Options::Library::print_options(std::ostream &out, int indent) const { -+ -+ out -+ << indent_str(indent) << "library-algo-mode: " << to_string(algorithm_mode) << "\n" -+ << indent_str(indent) << "library-algos: "; -+ -+ int j = 0; -+ for (int x : algorithms) { -+ out << (j++ ? "," : "") << x; -+ } -+ -+ out << "\n\n"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+Options::Profiling::Profiling(cutlass::CommandLine const &cmdline) { -+ -+ cmdline.get_cmd_line_argument("workspace-count", workspace_count, 0); -+ cmdline.get_cmd_line_argument("warmup-iterations", warmup_iterations, 10); -+ cmdline.get_cmd_line_argument("profiling-iterations", iterations, 100); -+ cmdline.get_cmd_line_argument("sleep-duration", sleep_duration, 50); -+ cmdline.get_cmd_line_argument("profiling-enabled", enabled, true); -+ -+ if (cmdline.check_cmd_line_flag("providers")) { -+ -+ std::vector tokens; -+ cmdline.get_cmd_line_arguments("providers", tokens); -+ -+ providers.clear(); -+ -+ for (auto const &token : tokens) { -+ providers.push_back(library::from_string(token)); -+ } -+ } -+ else { -+ providers.push_back(library::Provider::kCUTLASS); -+ providers.push_back(library::Provider::kCUBLAS); -+ providers.push_back(library::Provider::kCUDNN); -+ } -+} -+ -+void Options::Profiling::print_usage(std::ostream &out) const { -+ -+ out << "Profiling:\n" -+ -+ << " --workspace-count= " -+ << " Number of discrete workspaces maintained to avoid cache-resident " << end_of_line -+ << " If zero (default), the amount is chosen for each workload based on " << end_of_line -+ << " capacity of the last-level cache.\n\n" -+ -+ << " --profiling-iterations= " -+ << " Number of iterations to profile each kernel. If zero, kernels" << end_of_line -+ << " are launched up to the profiling duration.\n\n" -+ -+ << " --warmup-iterations= " -+ << " Number of iterations to execute each kernel prior to profiling.\n\n" -+ -+ << " --sleep-duration= " -+ << " Number of ms to sleep between profiling periods (ms).\n\n" -+ -+ << " --profiling-enabled= " -+ << " If true, profiling is actually conducted.\n\n" -+ -+ ; -+} -+ -+void Options::Profiling::print_options(std::ostream &out, int indent) const { -+ -+ out -+ << indent_str(indent) << "profiling_iterations: " << iterations << "\n" -+ << indent_str(indent) << "sleep_duration: " << sleep_duration << "\n" -+ << indent_str(indent) << "profiling_enabled: " << enabled << "\n" -+ << indent_str(indent) << "providers: ["; -+ -+ int j = 0; -+ for (auto const & provider : providers) { -+ out << (j++ ? 
", " : "") << library::to_string(provider); -+ } -+ out << "]\n"; -+} -+ -+/// Returns true if a provider is enabled -+bool Options::Profiling::provider_enabled(library::Provider provider) const { -+ return std::find(providers.begin(), providers.end(), provider) != providers.end(); -+} -+ -+/// Returns the index of a provider if its enabled -+size_t Options::Profiling::index(library::Provider provider) const { -+ size_t idx = 0; -+ for (auto const & x : providers) { -+ if (x == provider) { -+ return idx; -+ } -+ ++idx; -+ } -+ return idx; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+Options::Verification::Verification(cutlass::CommandLine const &cmdline) { -+ -+ cmdline.get_cmd_line_argument("verification-enabled", enabled, true); -+ -+ cmdline.get_cmd_line_argument("epsilon", epsilon, 0.05); -+ -+ cmdline.get_cmd_line_argument("nonzero-floor", nonzero_floor, 1.0 / 256.0); -+ -+ if (cmdline.check_cmd_line_flag("save-workspace")) { -+ std::string value; -+ cmdline.get_cmd_line_argument("save-workspace", value); -+ save_workspace = from_string(value); -+ } -+ else { -+ save_workspace = SaveWorkspace::kNever; -+ } -+ -+ if (cmdline.check_cmd_line_flag("verification-providers")) { -+ -+ std::vector tokens; -+ cmdline.get_cmd_line_arguments("verification-providers", tokens); -+ -+ providers.clear(); -+ -+ for (auto const &token : tokens) { -+ library::Provider provider = library::from_string(token); -+ if (provider != library::Provider::kInvalid) { -+ providers.push_back(provider); -+ } -+ } -+ } -+ else { -+ providers.push_back(library::Provider::kCUBLAS); -+ providers.push_back(library::Provider::kReferenceDevice); -+ providers.push_back(library::Provider::kCUDNN); -+ } -+} -+ -+void Options::Verification::print_usage(std::ostream &out) const { -+ -+ out << "Verification:\n" -+ -+ << " --verification-enabled= " -+ << " Whether to perform verification checks.\n\n" -+ -+ << " --epsilon= " -+ << " Error threshold. Setting to zero (default) requires" << end_of_line -+ << " bit-level equivalence.\n\n" -+ -+ << " --nonzero-floor= " -+ << " Results whose absolute value is less than this quantity" << end_of_line -+ << " are treated as zero for comparisons.\n\n" -+ -+ << " --save-workspace= " -+ << " Specifies when to save the GEMM inputs and results to the filesystem." << end_of_line -+ << " --save-workspace=never never save workspace (default)" << end_of_line -+ << " --save-workspace=incorrect save workspace for incorrect results" << end_of_line -+ << " --save-workspace=always always save workspace\n\n" -+ -+ << " --verification-providers= " -+ << " List of providers used to verify result. (default: '*')" << end_of_line -+ << " Gemm verification-providers {cublas*}" << end_of_line -+ << " Conv2d verification-providers {cudnn*, device*, host}" -+ << "\n\n"; -+} -+ -+void Options::Verification::print_options(std::ostream &out, int indent) const { -+ -+ out -+ << indent_str(indent) << "verification_enabled: " << enabled << "\n" -+ << indent_str(indent) << "epsilon: " << epsilon << "\n" -+ << indent_str(indent) << "save_workspace: " << to_string(save_workspace) << "\n" -+ << indent_str(indent) << "verification_providers: ["; -+ -+ int j = 0; -+ for (auto const & provider : providers) { -+ out << (j++ ? 
", " : "") << library::to_string(provider); -+ } -+ out << "]\n"; -+} -+ -+/// Returns true if a provider is enabled -+bool Options::Verification::provider_enabled(library::Provider provider) const { -+ return std::find(providers.begin(), providers.end(), provider) != providers.end(); -+} -+ -+/// Returns the index of a provider if its enabled -+size_t Options::Verification::index(library::Provider provider) const { -+ size_t idx = 0; -+ for (auto const & x : providers) { -+ if (x == provider) { -+ return idx; -+ } -+ ++idx; -+ } -+ return idx; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+Options::Report::Report(cutlass::CommandLine const &cmdline) { -+ -+ cmdline.get_cmd_line_argument("append", append, false); -+ cmdline.get_cmd_line_argument("output", output_path); -+ cmdline.get_cmd_line_argument("junit-output", junit_output_path); -+ -+ if (cmdline.check_cmd_line_flag("tags")) { -+ cmdline.get_cmd_line_argument_pairs("tags", pivot_tags); -+ } -+ -+ cmdline.get_cmd_line_argument("report-not-run", report_not_run, false); -+ -+ cmdline.get_cmd_line_argument("verbose", verbose, true); -+ -+ cmdline.get_cmd_line_argument("sort-results", sort_results, false); -+} -+ -+void Options::Report::print_usage(std::ostream &out) const { -+ -+ out << "Report:\n" -+ -+ << " --append= " -+ << " If true, result is appended to possibly existing file. Otherwise, " << end_of_line -+ << " any existing file is overwritten.\n\n" -+ -+ << " --output= " -+ << " Path to output file for machine readable results. Operation kind and '.csv' is appended.\n\n" -+ -+ << " --junit-output= " -+ << " Path to junit output file for result reporting. Operation kind and '.junit.xml' is appended.\n\n" -+ -+ << " --report-not-run= " -+ << " If true, reports the status of all kernels including those that" << end_of_line -+ << " do not satisfy the given arguments.\n\n" -+ -+ << " --tags= " -+ << " Inserts leading columns in output table and uniform values for each" << end_of_line -+ << " column. Useful for generating pivot tables.\n\n" -+ -+ << " --verbose= " -+ << " Prints human-readable text to stdout. 
If false, nothing is written to stdout.\n\n" -+ -+ << " --sort-results= " -+ << " Sorts results (by flops-per-byte).\n\n"; -+} -+ -+void Options::Report::print_options(std::ostream &out, int indent) const { -+ -+ out -+ << indent_str(indent) << "append: " << append << "\n" -+ << indent_str(indent) << "output: " << output_path << "\n" -+ << indent_str(indent) << "junit-output: " << junit_output_path << "\n" -+ << indent_str(indent) << "report_not_run: " << report_not_run << "\n" -+ << indent_str(indent) << "tags:\n"; -+ -+ for (auto const & tag : pivot_tags) { -+ out << indent_str(indent + 1) << tag.first << ": " << tag.second << "\n"; -+ } -+ -+ out -+ << indent_str(indent) << "verbose: " << verbose << "\n"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+Options::About::About(cutlass::CommandLine const &cmdline) { -+ help = cmdline.check_cmd_line_flag("help"); -+ version = cmdline.check_cmd_line_flag("version"); -+ device_info = cmdline.check_cmd_line_flag("device-info"); -+} -+ -+void Options::About::print_usage(std::ostream &out) const { -+ -+ out << "About:\n" -+ << " --version "; -+ -+ print_version(out); -+ -+ out << "\n"; -+} -+ -+void Options::About::print_version(std::ostream &out) { -+ out << "CUTLASS " << cutlass::getVersionString() -+ << " built on " << __DATE__ << " at " << __TIME__; -+ if (!cutlass::getGitRevision().empty()) out << " with commit " << cutlass::getGitRevision() << ""; -+} -+ -+void Options::About::print_options(std::ostream &out, int indent) const { -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+Options::Options(cutlass::CommandLine const &cmdline): -+ cmdline(cmdline), -+ device(cmdline), -+ initialization(cmdline), -+ library(cmdline), -+ profiling(cmdline), -+ verification(cmdline), -+ report(cmdline), -+ about(cmdline) { -+ -+ if (cmdline.check_cmd_line_flag("mode")) { -+ std::string token; -+ cmdline.get_cmd_line_argument("mode", token); -+ execution_mode = from_string(token); -+ } -+ else { -+ execution_mode = ExecutionMode::kProfile; -+ } -+ -+ // Enumerating kernels is equivalent to a dry run. 
-+ if (execution_mode == ExecutionMode::kEnumerate) { -+ execution_mode = ExecutionMode::kDryRun; -+ } -+ -+ if (cmdline.check_cmd_line_flag("operation")) { -+ std::string str; -+ cmdline.get_cmd_line_argument("operation", str); -+ operation_kind = library::from_string(str); -+ } -+ else if (cmdline.check_cmd_line_flag("function")) { -+ std::string str; -+ cmdline.get_cmd_line_argument("function", str); -+ operation_kind = library::from_string(str); -+ } -+ else { -+ operation_kind = library::OperationKind::kInvalid; -+ } -+ -+ if (cmdline.check_cmd_line_flag("operation_names")) { -+ cmdline.get_cmd_line_arguments("operation_names", operation_names); -+ } -+ else if (cmdline.check_cmd_line_flag("kernels")) { -+ cmdline.get_cmd_line_arguments("kernels", operation_names); -+ } -+ -+ if (cmdline.check_cmd_line_flag("ignore-kernels")) { -+ cmdline.get_cmd_line_arguments("ignore-kernels", excluded_operation_names); -+ } -+ -+ // Prevent launches on the device for anything other than CUTLASS operation -+ if (execution_mode == ExecutionMode::kTrace) { -+ initialization.provider = library::Provider::kReferenceHost; -+ verification.enabled = false; -+ profiling.enabled = false; -+ } -+} -+ -+void Options::print_usage(std::ostream &out) const { -+ -+ out -+ << "CUTLASS Profiler\n" -+ << "usage:\n\n" -+ << " cutlass_profiler [options]\n\n" -+ << " --help\n\n" -+ -+ << " --mode= " -+ << " Cutlass profiler execution mode." << end_of_line -+ << " --mode=profile regular verification and profiling (default)" << end_of_line -+ << " --mode=dry_run no kernels are launched or workspaces allocated" << end_of_line -+ << " --mode=enumerate lists all operation kind and operations" << end_of_line -+ << " --mode=trace executes a single device-side computation with" << end_of_line -+ << " no other kernel launches\n\n" -+ -+ << " --device-info " -+ << " Prints information on all GPUs present in the system\n\n" -+ -+ << " --operation= " -+ << " CUTLASS operation to profile.\n\n" -+ -+ << " --kernels= " -+ << " Filter operations by kernel names. 
For example, call all kernels with" << end_of_line -+ << " (\"s1688\" and \"nt\") or (\"s844\" and \"tn\" and \"align8\") in their" << end_of_line -+ << " operation name using --kernels=\"s1688*nt, s884*tn*align8\"\n\n" -+ -+ << " --ignore-kernels= " -+ << " Excludes kernels whose names match anything in this list.\n\n" -+ ; -+ -+ // -+ // Detailed options -+ // -+ -+ device.print_usage(out); -+ out << "\n"; -+ -+ initialization.print_usage(out); -+ out << "\n"; -+ -+ library.print_usage(out); -+ out << "\n"; -+ -+ profiling.print_usage(out); -+ out << "\n"; -+ -+ verification.print_usage(out); -+ out << "\n"; -+ -+ report.print_usage(out); -+ out << "\n"; -+ -+ about.print_usage(out); -+ out << "\n"; -+} -+ -+void Options::print_options(std::ostream &out) const { -+ -+ out -+ << "options:\n" -+ << " help: " << about.help << "\n" -+ << " mode: " << to_string(execution_mode) << "\n"; -+ -+ out -+ << " device:\n"; -+ device.print_options(out, 2); -+ -+ out -+ << " initialization:\n"; -+ initialization.print_options(out, 2); -+ -+ out -+ << " profiling:\n"; -+ profiling.print_options(out, 2); -+ -+ out -+ << " verification:\n"; -+ verification.print_options(out, 2); -+ -+ out -+ << " report:\n"; -+ report.print_options(out, 2); -+} -+ -+std::string Options::indent_str(int indent) { -+ return std::string(indent * 2, ' '); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/profiler/src/options.h b/3rdparty/cutlass/tools/profiler/src/options.h -new file mode 100644 -index 0000000..02edd9a ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/options.h -@@ -0,0 +1,323 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Command line options for performance test program -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/library/library.h" -+ -+#include "enumerated_types.h" -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Global options -+class Options { -+public: -+ -+ /// Cublas and cuDNN options -+ struct Library { -+ -+ // -+ // Data members -+ // -+ -+ /// Algorithm mode -+ AlgorithmMode algorithm_mode; -+ -+ /// Algorithm enumerants -+ std::vector algorithms; -+ -+ // -+ // Methods -+ // -+ -+ Library(CommandLine const &cmdline); -+ -+ void print_usage(std::ostream &out) const; -+ void print_options(std::ostream &out, int indent = 0) const; -+ }; -+ -+ /// Options related to the selected device -+ struct Device { -+ -+ /// Device ID -+ int device; -+ -+ /// CUDA Device properties -+ cudaDeviceProp properties; -+ -+ /// Total memory allocation on device -+ size_t maximum_capacity; -+ -+ // -+ // Methods -+ // -+ -+ Device(CommandLine const &cmdline); -+ -+ void print_usage(std::ostream &out) const; -+ void print_options(std::ostream &out, int indent = 0) const; -+ void print_device_info(std::ostream &out) const; -+ -+ /// Returns the compute capability of the listed device (e.g. 61, 60, 70, 75) -+ int compute_capability() const; -+ }; -+ -+ /// Options related to initializing input tensors -+ struct Initialization { -+ -+ /// If true, data is initialized randomly. If false, no initialization is performed after -+ /// allocating tensors. -+ bool enabled; -+ -+ /// If true, data distribution is set by the user and is not allowed to change -+ /// If false, data distribution is allowed to change based on element_type (library::NumericTypeID) -+ bool fix_data_distribution; -+ -+ /// Data distribution for input tensors -+ Distribution data_distribution; -+ -+ /// Source of random tensor elements -+ library::Provider provider; -+ -+ /// Random number generator seed. 
-+ int seed; -+ -+ // -+ // Methods -+ // -+ -+ Initialization(CommandLine const &cmdline); -+ -+ void print_usage(std::ostream &out) const; -+ void print_options(std::ostream &out, int indent = 0) const; -+ -+ /// Helper to parse a Distribution object from the command line parser -+ static void get_distribution( -+ cutlass::CommandLine const &args, -+ std::string const &arg, -+ cutlass::Distribution &dist); -+ }; -+ -+ /// Options related to verification of the result -+ struct Verification { -+ -+ // -+ // Data members -+ // -+ -+ /// If true, kernels are verified before they are profiled -+ bool enabled; -+ -+ /// Relative error threshold - zero to require bit-level consistency -+ double epsilon; -+ -+ /// Values smaller than this are assumed to be zero -+ double nonzero_floor; -+ -+ /// List of providers used to verify each result -+ ProviderVector providers; -+ -+ /// Indicates when to save the workspace -+ SaveWorkspace save_workspace; -+ -+ // -+ // Methods -+ // -+ -+ Verification(CommandLine const &cmdline); -+ -+ void print_usage(std::ostream &out) const; -+ void print_options(std::ostream &out, int indent = 0) const; -+ -+ /// Returns true if a provider is enabled -+ bool provider_enabled(library::Provider provider) const; -+ -+ /// Returns the index of a provider if its enabled -+ size_t index(library::Provider provider) const; -+ }; -+ -+ /// Options related to profiling -+ struct Profiling { -+ -+ /// Number of workspaces to rotate through to avoid cache-resident working sets -+ int workspace_count; -+ -+ /// Number of iterations to warmup each kernel prior to profiling -+ int warmup_iterations; -+ -+ /// Number of iterations to profile each kernel - if 0, kernels are launched up to the profiling duration -+ int iterations; -+ -+ /// Number of ms to sleep between profiling periods (ms) -+ int sleep_duration; -+ -+ /// If true, profiling is actually conducted. -+ bool enabled; -+ -+ /// List of providers of each functionality to be profiled -+ ProviderVector providers; -+ -+ // -+ // Methods -+ // -+ -+ Profiling(CommandLine const &cmdline); -+ -+ void print_usage(std::ostream &out) const; -+ void print_options(std::ostream &out, int indent = 0) const; -+ -+ /// Returns true if a provider is enabled -+ bool provider_enabled(library::Provider provider) const; -+ -+ /// Returns the index of a provider if its enabled -+ size_t index(library::Provider provider) const; -+ }; -+ -+ /// Options related to reporting -+ struct Report { -+ -+ /// If true, result is appended to possibly existing file -+ bool append; -+ -+ /// Path to a file containing results -+ std::string output_path; -+ -+ /// Path to a file containing junit xml results -+ std::string junit_output_path; -+ -+ /// Sequence of tags to attach to each result -+ std::vector> pivot_tags; -+ -+ /// If true, reports status of all kernels including those that were -+ /// not run for the given argumetns -+ bool report_not_run; -+ -+ /// Prints human-readable text to stdout. If false, nothing is written to stdout -+ bool verbose; -+ -+ /// Sort results by (currently by flops-per-byte) -+ bool sort_results; -+ -+ // -+ // Methods -+ // -+ -+ Report(CommandLine const &cmdline); -+ -+ void print_usage(std::ostream &out) const; -+ void print_options(std::ostream &out, int indent = 0) const; -+ }; -+ -+ /// Options related to printing usage and version information -+ struct About { -+ -+ /// If true, usage is printed and the program ends. 
-+ bool help; -+ -+ /// Prints version string -+ bool version; -+ -+ /// Print information about devices -+ bool device_info; -+ -+ // -+ // Methods -+ // -+ -+ About(CommandLine const &cmdline); -+ -+ void print_usage(std::ostream &out) const; -+ void print_options(std::ostream &out, int indent = 0) const; -+ -+ static void print_version(std::ostream &out); -+ }; -+ -+public: -+ -+ // -+ // Data members -+ // -+ -+ /// Top-level execution mode -+ ExecutionMode execution_mode; -+ -+ /// Name of math function to profile -+ library::OperationKind operation_kind; -+ -+ /// Vector of operation name substrings -+ std::vector operation_names; -+ -+ /// Vector of operation name substrings -+ std::vector excluded_operation_names; -+ -+ -+ // -+ // Detailed configuration options -+ // -+ -+ /// Configuration -+ CommandLine cmdline; -+ Device device; -+ Initialization initialization; -+ Library library; -+ Verification verification; -+ Profiling profiling; -+ Report report; -+ About about; -+ -+public: -+ -+ Options(CommandLine const &cmdline); -+ -+ void print_usage(std::ostream &out) const; -+ void print_options(std::ostream &out) const; -+ -+ static std::string indent_str(int indent); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/profiler/src/performance_report.h b/3rdparty/cutlass/tools/profiler/src/performance_report.h -new file mode 100644 -index 0000000..b74d069 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/performance_report.h -@@ -0,0 +1,127 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Class performing output during profiling -+*/ -+ -+#pragma once -+ -+#include -+#include -+ -+// CUTLASS Profiler includes -+#include "options.h" -+#include "enumerated_types.h" -+#include "performance_result.h" -+ -+// CUTLASS Library includes -+#include "cutlass/library/library.h" -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+class PerformanceReport { -+private: -+ -+ /// Reference to options -+ Options const &options_; -+ -+ /// Operation kind -+ library::OperationKind op_kind_; -+ -+ /// Operation file name containing performance report of op_kind -+ std::string op_file_name_; -+ -+ /// Output file containing results -+ std::ofstream output_file_; -+ -+ /// Operation file name containing junit performance report of op_kind -+ std::string op_junit_file_name_; -+ -+ /// Output file containing junit results -+ std::ofstream junit_output_file_; -+ -+ /// Flag indicating the performance report is valid -+ bool good_; -+ -+ /// Vector of argument names -+ std::vector argument_names_; -+ -+ /// Counter uniquely identifying problem within the report -+ size_t problem_index_; -+ -+ /// Collection of all results -+ PerformanceResultVector concatenated_results_; -+ -+public: -+ -+ PerformanceReport(Options const &options, std::vector const &argument_names, library::OperationKind const &op_kind); -+ ~PerformanceReport(); -+ -+ bool good() const { return good_; } -+ -+ void next_problem(); -+ void append_result(PerformanceResult result); -+ void sort_results(PerformanceResultVector &results); -+ void append_results(PerformanceResultVector const &results); -+ -+public: -+ -+ /// Prints the CSV header -+ std::ostream & print_csv_header_(std::ostream &out); -+ -+ /// Prints the CSV -+ std::ostream & print_result_csv_(std::ostream &out, PerformanceResult const &result); -+ -+ /// @defgroup jUnit Result Generation -+ /// Functions related to generation of the jUnit results -+ /// @{ -+ -+ std::ostream & print_junit_header_(std::ostream &out); -+ std::ostream & print_junit_result_(std::ostream &out, PerformanceResult const &result); -+ std::ostream & print_junit_footer_(std::ostream &out); -+ -+ /// @} -+ -+ /// Prints the result in human readable form -+ std::ostream & print_result_pretty_( -+ std::ostream &out, -+ PerformanceResult const &result, -+ bool use_shell_coloring = true); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/tools/profiler/src/performance_result.cu b/3rdparty/cutlass/tools/profiler/src/performance_result.cu -new file mode 100644 -index 0000000..810e261 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/performance_result.cu -@@ -0,0 +1,61 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+// CUTLASS Profiler includes -+#include "enumerated_types.h" -+#include "performance_result.h" -+ -+// CUTLASS Library includes -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/profiler/src/performance_result.h b/3rdparty/cutlass/tools/profiler/src/performance_result.h -new file mode 100644 -index 0000000..c714e02 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/performance_result.h -@@ -0,0 +1,128 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines a math function -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+// CUTLASS Profiler includes -+#include "enumerated_types.h" -+ -+// CUTLASS Library includes -+#include "cutlass/library/library.h" -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Performance result object -+struct PerformanceResult { -+ -+ /// Index of problem -+ size_t problem_index; -+ -+ /// library::Provider -+ library::Provider provider; -+ -+ /// Operation kind -+ library::OperationKind op_kind; -+ -+ /// CUTLASS status result from kernels (success or failure) -+ // Status does information on verification -+ Status status; -+ -+ /// Outcome of verification (worst case verification result) -+ Disposition disposition; -+ -+ /// Outcome of verification (all verification results) -+ DispositionMap verification_map; -+ -+ /// Operation name -+ std::string operation_name; -+ -+ /// Stringified vector of argument values -+ std::vector > arguments; -+ -+ /// Number of bytes read or written -+ int64_t bytes; -+ -+ /// Number of DL flops performed by the math function -+ int64_t flops; -+ -+ /// Average runtime in ms -+ double runtime; -+ -+ // -+ // Members -+ // -+ -+ /// Ctor -+ PerformanceResult(): -+ problem_index(0), -+ op_kind(library::OperationKind::kInvalid), -+ provider(library::Provider::kInvalid), -+ disposition(Disposition::kNotRun), -+ status(Status::kInvalid), -+ bytes(0), -+ flops(0), -+ runtime(0) -+ { } -+ -+ /// Returns true if the runtime is valid -+ bool good() const { -+ return runtime > 0; -+ } -+ -+ /// Math throughput in units of GFLOP/s -+ double gflops_per_sec() const { -+ return double(flops) / runtime / 1.0e6; -+ } -+ -+ /// memory bandwidth in units of GiB/s -+ double gbytes_per_sec() const { -+ return double(bytes) / double(1 << 30) / runtime * 1000.0; -+ } -+ -+}; -+ -+using PerformanceResultVector = std::vector; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/tools/profiler/src/problem_space.h b/3rdparty/cutlass/tools/profiler/src/problem_space.h -new file mode 100644 -index 0000000..4e102e6 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/problem_space.h -@@ -0,0 +1,1005 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief -+ -+ "Any sufficiently complicated C or Fortran program contains an ad-hoc, informally-specified, -+ bug-ridden, slow implementation of half of Common Lisp." -+ -+ - Greenspun's Tenth Rule of Programming -+ -+ -+ cutlass::profiler::ProblemSpace defines a set of data structures which represent the Cartesian -+ product of sequences defined by integer ranges, lists of scalars, and sets of enumerated types. -+ -+ These permit a single invocation of the CUTLASS Profiler to iterate over a large set of problems, -+ verify and profile various operations when they are compatible with the command line, and -+ construct data tables of results that are convenient inputs to post processing in Excel or Pandas. -+ -+ By executing multiple problems per invocation, startup overheads may be amortized across many -+ kernel launches. 
-+*/ -+ -+#pragma once -+ -+// Standard Library includes -+#include -+#include -+#include -+#include -+#include -+ -+// CUTLASS Utility includes -+#include "cutlass/util/command_line.h" -+ -+// CUTLASS Library includes -+#include "cutlass/library/library.h" -+ -+// Profiler includes -+#include "enumerated_types.h" -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the argument schema -+struct ArgumentDescription { -+ -+ /// Type of argument -+ ArgumentTypeID type; -+ -+ /// Prioritized array of aliases used in command line parsing -+ std::vector aliases; -+ -+ /// Description of argument -+ std::string description; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ ArgumentDescription(): -+ type(ArgumentTypeID::kInvalid) { } -+ -+ /// Constructor with aliases -+ ArgumentDescription( -+ ArgumentTypeID type_, -+ std::vector const &aliases_, -+ std::string const &description_ -+ ): -+ type(type_), aliases(aliases_), description(description_) { } -+}; -+ -+/// Vector of arguments -+using ArgumentDescriptionVector = std::vector; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Base class for kernel arguments -+struct KernelArgument { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Value base class -+ struct Value { -+ -+ KernelArgument const *argument; -+ bool not_null; -+ -+ // -+ // Methods -+ // -+ -+ Value( -+ KernelArgument const *argument_ = nullptr, -+ bool not_null_ = true -+ ): argument(argument_), not_null(not_null_) { } -+ -+ virtual ~Value() { } -+ -+ virtual std::ostream &print(std::ostream &out) const =0; -+ }; -+ -+ /// Abstract base class to iterate over values within arguments -+ struct ValueIterator { -+ -+ /// Indicates type of kernel argument -+ KernelArgument const *argument; -+ -+ /// If the iterator points to an argument that is null, it needs to be distinguished -+ /// from end. -+ bool null_argument; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a value iterator - no methods are valid if argument_ == nullptr -+ ValueIterator( -+ KernelArgument const *argument_ = nullptr, -+ bool null_argument_ = false): -+ argument(argument_), null_argument(null_argument_) { -+ -+ if (!argument_->not_null()) { -+ null_argument = true; -+ } -+ } -+ -+ virtual ~ValueIterator() { } -+ -+ /// Advances to next point in range -+ virtual void operator++() = 0; -+ -+ /// Compares against another value iterator - must be of the same KernelArgument type -+ virtual bool operator==(ValueIterator const &it) const = 0; -+ -+ /// Returns a unique_ptr object pointing to a newly created value object -+ virtual std::unique_ptr at() const = 0; -+ -+ /// Gets the type of the iterator -+ ArgumentTypeID type() const { -+ return argument->description->type; -+ } -+ -+ /// Helper to compute inequality -+ bool operator!=(ValueIterator const &it) const { -+ return !(*this == it); -+ } -+ -+ std::ostream &print(std::ostream &out) const; -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Describes the argument -+ ArgumentDescription const *description; -+ -+ /// Parent node -+ KernelArgument *parent; -+ -+ /// Sequence in which the kernel argument is to be iterated over. -+ /// Smaller means faster changing. 
-1 is don't care -+ int ordinal; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ KernelArgument( -+ ArgumentDescription const *description_ = nullptr, -+ KernelArgument *parent_ = nullptr, -+ int ordinal_ = -1 -+ ): description(description_), parent(parent_), ordinal(ordinal_) { } -+ -+ virtual ~KernelArgument(); -+ -+ /// Returns true if the kernel argument iself is empty -+ virtual bool not_null() const =0; -+ -+ /// Returns a string name for debugging -+ std::string qualified_name() const { -+ if (description) { -+ if (description->aliases.empty()) { -+ return ""; -+ } -+ return description->aliases.front(); -+ } -+ return ""; -+ } -+ -+ virtual std::unique_ptr begin() const =0; -+ virtual std::unique_ptr end() const =0; -+}; -+ -+using KernelArgumentVector = std::vector>; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a scalar argument type as a string that is lexically cast to the appropriate kernel -+/// type. -+struct ScalarArgument : public KernelArgument { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Value type -+ struct ScalarValue : public KernelArgument::Value { -+ -+ std::string value; -+ -+ // -+ // Methods -+ // -+ -+ ScalarValue( -+ std::string const &value_ = "", -+ ScalarArgument const *argument = nullptr, -+ bool not_null_ = true -+ ); -+ -+ virtual std::ostream &print(std::ostream &out) const; -+ }; -+ -+ using ValueCollection = std::vector; -+ -+ /// Abstract base class to iterate over values within arguments -+ struct ScalarValueIterator : public KernelArgument::ValueIterator { -+ -+ // -+ // Data members -+ // -+ -+ ValueCollection::const_iterator value_it; -+ -+ // -+ // Methods -+ // -+ -+ ScalarValueIterator(ScalarArgument const *argument = nullptr); -+ -+ virtual void operator++(); -+ virtual bool operator==(ValueIterator const &it) const; -+ -+ /// Gets the value pointed to -+ virtual std::unique_ptr at() const; -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Set of posible values -+ ValueCollection values; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ ScalarArgument( -+ ArgumentDescription const *description -+ ): -+ KernelArgument(description) { } -+ -+ virtual bool not_null() const { -+ return !values.empty(); -+ } -+ -+ virtual std::unique_ptr begin() const; -+ virtual std::unique_ptr end() const; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Closed range supporting additive increment -+struct Range { -+ -+ // -+ // Type definitions -+ // -+ -+ enum class Mode { -+ kSequence, -+ kRandom, -+ kRandomLog2, -+ kInvalid -+ }; -+ -+ struct Iterator { -+ -+ int64_t value; -+ int64_t increment; -+ Range const *range; -+ -+ // -+ // Methods -+ // -+ -+ Iterator( -+ int64_t value_ = 0, -+ int64_t increment_ = 1, -+ Range const *range_ = nullptr -+ ): -+ value(value_), increment(increment_), range(range_) { } -+ -+ Iterator & operator++() { -+ value += increment; -+ return *this; -+ } -+ -+ Iterator operator++(int) { -+ Iterator self(*this); -+ ++(*this); -+ return self; -+ } -+ -+ bool operator==(Iterator const &it) const { -+ return value == it.value; -+ } -+ -+ bool operator!=(Iterator const &it) const { -+ return !(*this == it); -+ } -+ -+ static int64_t round(int64_t value, int64_t divisible) { -+ int64_t rem = (value % divisible); -+ -+ // Round either up or down -+ if (rem > divisible / 2) { -+ value += (divisible - rem); -+ } -+ else { -+ value -= rem; -+ } -+ -+ return value; -+ } -+ -+ int64_t 
at() const { -+ if (!range) { -+ return value; -+ } -+ -+ switch (range->mode) { -+ case Mode::kSequence: return value; -+ -+ case Mode::kRandom: { -+ double rnd = double(range->minimum) + -+ double(std::rand()) / double(RAND_MAX) * (double(range->maximum) - double(range->minimum)); -+ -+ int64_t value = int64_t(rnd); -+ -+ return round(value, range->divisible); -+ } -+ break; -+ -+ case Mode::kRandomLog2: { -+ double lg2_minimum = std::log(double(range->minimum)) / std::log(2.0); -+ double lg2_maximum = std::log(double(range->maximum)) / std::log(2.0); -+ double rnd = lg2_minimum + double(std::rand()) / double(RAND_MAX) * (lg2_maximum - lg2_minimum); -+ -+ int64_t value = int64_t(std::pow(2.0, rnd)); -+ -+ return round(value, range->divisible); -+ } -+ break; -+ default: break; -+ } -+ return value; -+ } -+ -+ int64_t operator*() const { -+ return at(); -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ int64_t first; ///< first element in range -+ int64_t last; ///< last element in range -+ int64_t increment; ///< additive increment between values -+ -+ Mode mode; ///< mode selection enables alternative values -+ int64_t minimum; ///< minimum value to return -+ int64_t maximum; ///< maximum value to return -+ int64_t divisible; ///< rounds value down to an integer multiple of this value -+ -+ // -+ // Methods -+ // -+ -+ /// Default constructor - range acts as a scalar -+ Range(int64_t first_ = 0): first(first_), last(first_), increment(1), mode(Mode::kSequence), minimum(0), maximum(0), divisible(1) { } -+ -+ /// Range acts as a range -+ Range( -+ int64_t first_, -+ int64_t last_, -+ int64_t increment_ = 1, -+ Mode mode_ = Mode::kSequence, -+ int64_t minimum_ = 0, -+ int64_t maximum_ = 0, -+ int64_t divisible_ = 1 -+ ): first(first_), last(last_), increment(increment_), mode(mode_), minimum(minimum_), maximum(maximum_), divisible(divisible_) { -+ -+ // Helpers to avoid constructing invalid ranges -+ if (increment > 0) { -+ if (last < first) { -+ std::swap(last, first); -+ } -+ } -+ else if (increment < 0) { -+ if (first < last) { -+ std::swap(last, first); -+ } -+ } -+ else if (last != first) { -+ last = first; -+ increment = 1; -+ } -+ } -+ -+ /// Helper to construct a sequence range -+ static Range Sequence(int64_t first_, int64_t last_, int64_t increment_ = 1) { -+ return Range(first_, last_, increment_, Mode::kSequence); -+ } -+ -+ /// Helper to construct a range that is a random distribution -+ static Range Random(int64_t minimum_, int64_t maximum_, int64_t count_, int64_t divisible_ = 1) { -+ return Range(1, count_, 1, Mode::kRandom, minimum_, maximum_, divisible_); -+ } -+ -+ /// Helper to construct a range that is a random distribution over a log scale -+ static Range RandomLog2(int64_t minimum_, int64_t maximum_, int64_t count_, int64_t divisible_ = 1) { -+ return Range(1, count_, 1, Mode::kRandomLog2, minimum_, maximum_, divisible_); -+ } -+ -+ /// Returns an iterator to the first element within the range -+ Iterator begin() const { -+ return Iterator(first, increment, this); -+ } -+ -+ /// Returns an iterator to the first element *after* the range -+ Iterator end() const { -+ return Iterator(first + ((last - first)/increment + 1) * increment, increment, this); -+ } -+}; -+ -+/// Integer-valued argument - represented as a list of integer-valued ranges -+struct IntegerArgument : public KernelArgument { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Value type -+ struct IntegerValue : public KernelArgument::Value { -+ -+ int64_t value; -+ -+ // -+ // Methods -+ // -+ -+ 
IntegerValue( -+ int64_t value_ = 0, -+ IntegerArgument const *argument_ = nullptr, -+ bool not_null_ = true -+ ); -+ -+ /// Pretty printer for debugging -+ virtual std::ostream &print(std::ostream &out) const; -+ }; -+ -+ /// Collection of ranges represent the IntegerArgument's state -+ using RangeCollection = std::vector; -+ -+ /// Abstract base class to iterate over values within arguments -+ struct IntegerValueIterator : public KernelArgument::ValueIterator { -+ -+ // -+ // Data members -+ // -+ -+ RangeCollection::const_iterator range_it; -+ Range::Iterator value_it; -+ -+ // -+ // Methods -+ // -+ -+ IntegerValueIterator(); -+ IntegerValueIterator(IntegerArgument const *argument); -+ -+ virtual void operator++(); -+ virtual bool operator==(ValueIterator const &it) const; -+ -+ /// Gets the value pointed to -+ virtual std::unique_ptr at() const; -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Set of posible values -+ RangeCollection ranges; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ IntegerArgument( -+ ArgumentDescription const *description -+ ): -+ KernelArgument(description) { } -+ -+ virtual bool not_null() const { -+ bool _not_null = !ranges.empty(); -+ return _not_null; -+ } -+ -+ virtual std::unique_ptr begin() const; -+ virtual std::unique_ptr end() const; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure defining the data type of tensors -+struct TensorArgument : public KernelArgument { -+ -+ // -+ // Type definitions -+ // -+ -+ struct TensorDescription { -+ -+ /// Data type of elements -+ library::NumericTypeID element; -+ -+ /// Layout definition -+ library::LayoutTypeID layout; -+ -+ /// Computed extent -+ std::vector extent; -+ -+ /// Enables directly specifying stride value used to size tensor -+ std::vector stride; -+ -+ // -+ // Methods -+ // -+ -+ TensorDescription( -+ library::NumericTypeID element_ = library::NumericTypeID::kUnknown, -+ library::LayoutTypeID layout_ = library::LayoutTypeID::kUnknown, -+ std::vector extent_ = std::vector(), -+ std::vector stride_ = std::vector() -+ ): -+ element(element_), layout(layout_), extent(extent_), stride(stride_) {} -+ }; -+ -+ using ValueCollection = std::vector; -+ -+ /// Value structure -+ struct TensorValue : public KernelArgument::Value { -+ -+ TensorDescription desc; -+ -+ // -+ // Methods -+ // -+ -+ TensorValue( -+ TensorDescription const &desc_ = TensorDescription(), -+ TensorArgument const *argument_ = nullptr, -+ bool not_null_ = true -+ ); -+ -+ /// Pretty printer for debugging -+ virtual std::ostream &print(std::ostream &out) const; -+ }; -+ -+ /// Abstract base class to iterate over values within arguments -+ struct TensorValueIterator : public KernelArgument::ValueIterator { -+ -+ // -+ // Data members -+ // -+ -+ ValueCollection::const_iterator value_it; -+ -+ // -+ // Methods -+ // -+ -+ TensorValueIterator(TensorArgument const *argument_); -+ -+ virtual void operator++(); -+ virtual bool operator==(ValueIterator const &it) const; -+ -+ /// Gets the value pointed to -+ virtual std::unique_ptr at() const; -+ }; -+ -+ /// Set of possible values -+ ValueCollection values; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ TensorArgument( -+ ArgumentDescription const *description -+ ): -+ KernelArgument(description) { } -+ -+ virtual bool not_null() const { -+ return !values.empty(); -+ } -+ -+ virtual std::unique_ptr begin() const; -+ virtual std::unique_ptr end() const; -+}; -+ 
-+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Numeric data type -+struct EnumeratedTypeArgument : public KernelArgument { -+ -+ // -+ // Type definitions -+ // -+ -+ struct EnumeratedTypeValue : public KernelArgument::Value { -+ -+ /// Data type of element -+ std::string element; -+ -+ // -+ // Methods -+ // -+ -+ EnumeratedTypeValue( -+ std::string const &element_ = std::string(), -+ EnumeratedTypeArgument const *argument_ = nullptr, -+ bool not_null_ = true -+ ); -+ -+ /// Pretty printer for debugging -+ virtual std::ostream &print(std::ostream &out) const; -+ }; -+ -+ using ValueCollection = std::vector; -+ -+ /// Abstract base class to iterate over values within arguments -+ struct EnumeratedTypeValueIterator : public KernelArgument::ValueIterator { -+ -+ // -+ // Data members -+ // -+ -+ ValueCollection::const_iterator value_it; -+ -+ // -+ // Methods -+ // -+ -+ EnumeratedTypeValueIterator(EnumeratedTypeArgument const *argument_ = nullptr); -+ -+ virtual void operator++(); -+ virtual bool operator==(ValueIterator const &it) const; -+ -+ /// Gets the value pointed to -+ virtual std::unique_ptr at() const; -+ }; -+ -+ // -+ // Data members -+ // -+ -+ ValueCollection values; -+ -+ // -+ // Members -+ // -+ -+ /// Default ctor -+ EnumeratedTypeArgument(ArgumentDescription const *description): -+ KernelArgument(description) {} -+ -+ virtual bool not_null() const { -+ return !values.empty(); -+ } -+ -+ virtual std::unique_ptr begin() const; -+ virtual std::unique_ptr end() const; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Object storing the space argument values -+class ProblemSpace { -+public: -+ -+ /// Tuple of arguments -+ using Problem = std::vector>; -+ -+ /// Type used to iterator over things -+ using IteratorVector = std::vector>; -+ -+ /// Iterates over points in the design space -+ class Iterator { -+ private: -+ -+ /// One iterator per argument -+ IteratorVector iterators; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ explicit Iterator(); -+ Iterator(ProblemSpace const &problem_space); -+ Iterator(Iterator &&it); -+ -+ // Rule of three -+ Iterator(Iterator const &) = delete; -+ Iterator &operator=(Iterator const &it) = delete; -+ ~Iterator() = default; -+ -+ /// Pre-increment - advances to next point in argument range -+ void operator++(); -+ -+ /// Gets the current argument value -+ Problem at() const; -+ -+ /// Moves iterator to end -+ void move_to_end(); -+ -+ /// Equality operator -+ bool operator==(Iterator const &it) const; -+ -+ /// Inequality operator -+ bool operator!=(Iterator const &it) const { -+ return !(*this == it); -+ } -+ -+ /// Helper to call at() method -+ Problem operator*() const { -+ return at(); -+ } -+ -+ /// Helper to print iterator state -+ std::ostream & print(std::ostream &out) const; -+ -+ private: -+ -+ /// Helper for recursively constructing iterators -+ void construct_(KernelArgument const *argument); -+ }; -+ -+public: -+ -+ // -+ // Data members -+ // -+ -+ KernelArgumentVector arguments; -+ -+ /// Map of argument names to their position within the argument vector -+ std::unordered_map argument_index_map; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ ProblemSpace() {} -+ -+ /// Constructs a problem space from a vector of arguments. This vector must outlive -+ /// the ProblemSpace object, which stores pointers to objects within the -+ /// ArgumentDescriptionVector. 
-+ ProblemSpace(ArgumentDescriptionVector const &schema, CommandLine const &cmdline); -+ -+ Iterator begin() const; // returns an iterator to the first point in the range -+ Iterator end() const; // returns an iterator to the first point after the range -+ -+ /// Returns the index of an argument by name -+ size_t argument_index(char const *name) const; -+ -+ /// Gets all argument names as an ordered vector -+ std::vector argument_names() const; -+ -+ /// Returns the number of dimensions of the problem space -+ size_t rank() const { return arguments.size(); } -+ -+private: -+ -+ /// Helper for recursively cloning -+ void clone_( -+ KernelArgumentVector &kernel_args, -+ ArgumentDescription const *arg_desc); -+ -+ /// Parses command line argument -+ void parse_( -+ KernelArgument *arg, -+ CommandLine const &cmdline); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Lexically casts an argument to an int if it is defined. Returns true if not null. -+bool arg_as_int(int &int_value, KernelArgument::Value const *value_ptr); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_int(int64_t &int_value, KernelArgument::Value const *value_ptr); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_int( -+ int &int_value, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_int( -+ int64_t &int_value, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_NumericTypeID(library::NumericTypeID &numeric_type, KernelArgument::Value const *value_ptr); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_NumericTypeID( -+ library::NumericTypeID &numeric_type, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_LayoutTypeID(library::LayoutTypeID &layout_type, KernelArgument::Value const *value_ptr); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_LayoutTypeID( -+ library::LayoutTypeID &layout_type, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_OpcodeClassID(library::OpcodeClassID &opcode_class, KernelArgument::Value const *value_ptr); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_OpcodeClassID( -+ library::OpcodeClassID &opcode_class, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_SplitKModeID(library::SplitKMode &split_k_mode, KernelArgument::Value const *value_ptr); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. 
-+bool arg_as_SplitKModeID( -+ library::SplitKMode &split_k_mode, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_ConvModeID(library::ConvModeID &conv_mode, KernelArgument::Value const *value_ptr); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_ConvModeID( -+ library::ConvModeID &conv_mode, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_IteratorAlgorithmID(library::IteratorAlgorithmID &iterator_algorithm, KernelArgument::Value const *value_ptr); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_IteratorAlgorithmID( -+ library::IteratorAlgorithmID &iterator_algorithm, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_ProviderID(library::Provider &provider, KernelArgument::Value const *value_ptr); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_ProviderID( -+ library::Provider &provider, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+/// Lexically casts an argument to a given type stored in a byte array. Returns true if not null. -+bool arg_as_scalar( -+ std::vector &bytes, -+ library::NumericTypeID numeric_type, -+ KernelArgument::Value const *value_ptr); -+ -+/// Lexically casts an argument to a given type stored in a byte array. Returns true if not null. 
-+bool arg_as_scalar( -+ std::vector &bytes, -+ library::NumericTypeID numeric_type, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+/// Returns true if a tensor description satisfies a `tensor` value -+bool tensor_description_satisfies( -+ library::TensorDescription const &tensor_desc, -+ TensorArgument::TensorValue const *value_ptr); -+ -+/// Returns true if a tensor description satisfies a `tensor` value -+bool tensor_description_satisfies( -+ library::TensorDescription const &tensor_desc, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ -+/// Returns true if a conv kind satisfies the value -+bool conv_kind_satisfies( -+ library::ConvKind const &conv_kind, -+ EnumeratedTypeArgument::EnumeratedTypeValue const *value_ptr); -+ -+/// Returns true if a conv kind satisfies the value -+bool conv_kind_satisfies( -+ library::ConvKind const &conv_kind, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+/// Returns true if a iterator algorithm satisfies the value -+bool iterator_algorithm_satisfies( -+ library::IteratorAlgorithmID const &iterator_algorithm, -+ EnumeratedTypeArgument::EnumeratedTypeValue const *value_ptr); -+ -+/// Returns true if a iterator algorithm satisfies the value -+bool iterator_algorithm_satisfies( -+ library::IteratorAlgorithmID const &iterator_algorithm, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/rank_2k_operation_profiler.cu b/3rdparty/cutlass/tools/profiler/src/rank_2k_operation_profiler.cu -new file mode 100644 -index 0000000..2c2f236 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/rank_2k_operation_profiler.cu -@@ -0,0 +1,727 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Execution environment -+ -+ -+*/ -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/core_io.h" -+ -+#include "cublas_helpers.h" -+#include "rank_2k_operation_profiler.h" -+#include "gpu_timer.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Ctor -+Rank2KOperationProfiler::Rank2KOperationProfiler(Options const &options): -+ OperationProfiler( -+ options, -+ library::OperationKind::kRank2K, -+ { -+ {ArgumentTypeID::kEnumerated, {"rank_k_kind"}, "Variant of RankK (universal)"}, -+ {ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the RankK problem space"}, -+ {ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the RankK problem space"}, -+ {ArgumentTypeID::kTensor, {"A"}, "Tensor storing the A operand"}, -+ {ArgumentTypeID::kTensor, {"B"}, "Tensor storing the B operand"}, -+ {ArgumentTypeID::kTensor, {"C"}, "Tensor storing the C operand"}, -+ {ArgumentTypeID::kEnumerated, {"fill_mode"}, "Fill Mode for RankK kernel (lower or upper)"}, -+ {ArgumentTypeID::kEnumerated, {"blas_mode"}, "Blas Mode for RankK kernel (symmetric or hermitian)"}, -+ {ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"}, -+ {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"}, -+ {ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"}, -+ {ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of RankK computed in one batch"}, -+ }, -+ { library::Provider::kCUBLAS} -+ ) { -+ description_ = " Rank 2k Update. 
D = alpha * (A*B^T + B*A^T) + beta * C (symmetric) or D = alpha * (A*B^H+B*A^H) + beta * C (hermitian)"; -+} -+ -+/// Destructor -+Rank2KOperationProfiler::~Rank2KOperationProfiler() { -+ -+} -+ -+/// Prints usage statement for the math function -+void Rank2KOperationProfiler::print_usage(std::ostream &out) const { -+ out << "RankK" << "\n\n"; -+ -+ OperationProfiler::print_usage(out); -+} -+ -+/// Prints examples -+void Rank2KOperationProfiler::print_examples(std::ostream &out) const { -+ -+ out << "\nExamples:\n\n" -+ << "Profile a particular problem size Syrk kernel:\n" -+ << " $ cutlass_profiler --operation=rank_2k --blas_mode=symmetric --n=1024 --k=128\n\n" -+ -+ << "Profile a particular problem size Herk kernel:\n" -+ << " $ cutlass_profiler --operation=rank_2k --blas_mode=hermitian --n=1024 --k=128\n\n" -+ -+ << "Schmoo over problem size and beta:\n" -+ << " $ cutlass_profiler --operation=rank_2k --n=1024:4096:256 --k=128:8192:128 --beta=0,1,2.5\n\n" -+ -+ << "Schmoo over accumulator types:\n" -+ << " $ cutlass_profiler --operation=rank_2k --accumulator-type=f16,f32\n\n" -+ -+ << "Schmoo over fill modees:\n" -+ << " $ cutlass_profiler --operation=rank_2k --fill_mode=lower/upper\n\n" -+ -+ << "Run when A is f16 with column-major or A is any datatype with row-major (For column major, use column, col, or n. For row major use, row or t):\n" -+ << " $ cutlass_profiler --operation=rank_2k --A=f16:column or --A=*:row\n\n" -+ -+ << "Using various input value distribution:\n" -+ << " $ cutlass_profiler --operation=rank_2k --dist=uniform,min:0,max:3\n" -+ << " $ cutlass_profiler --operation=rank_2k --dist=gaussian,mean:0,stddev:3\n" -+ << " $ cutlass_profiler --operation=rank_2k --dist=sequential,start:0,delta:1\n\n" -+ -+ << "Run a kernel with cta tile size of 256x128x32 and save workspace if results are incorrect (note that --cta-tile::k=32 is default cta-tile size):\n" -+ << " $ cutlass_profiler --operation=rank_2k --cta_m=256 --cta_n=128 --cta_k=32 --save-workspace=incorrect\n\n" -+ -+ << "Test your changes to rank_2k kernels with a quick functional test and save results in functional-test.csv:\n" -+ << " $ cutlass_profiler --operation=rank_2k \\ \n" -+ << " --n=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" -+ << " --k=8,16,32,64,128,256,288,384,504,512,520 \\ \n" -+ << " --beta=0,1,2 --profiling-iterations=1 \\ \n" -+ << " --providers=cutlass --output=functional-test.csv\n\n"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if 0 -+// used this for debugging -+static std::string byte_string(std::vector const &bytes) { -+ std::stringstream ss; -+ -+ ss << "0x"; -+ -+ for (size_t idx = bytes.size(); idx > 0; --idx) { -+ ss << std::hex << std::setw(2) << std::setfill('0') << uint32_t(bytes.at(idx - 1)); -+ } -+ -+ return ss.str(); -+} -+#endif -+ -+Status Rank2KOperationProfiler::RankKProblem::parse( -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!arg_as_int(this->n, "n", problem_space, problem)) { -+ // default value -+ this->n = 1024; -+ } -+ -+ if (!arg_as_int(this->k, "k", problem_space, problem)) { -+ // default value -+ this->k = 1024; -+ } -+ -+ if (!arg_as_int(this->split_k_slices, "split_k_slices", problem_space, problem)) { -+ // default value -+ this->split_k_slices = 1; -+ } -+ -+ if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) { -+ // default value -+ this->batch_count = 1; -+ } 
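-+
-+  // The defaults above give n = 1024, k = 1024, split_k_slices = 1 and
-+  // batch_count = 1; with those values the performance model below reports
-+  // n * (n + 1) * (2k + 1) = 1024 * 1025 * 2049 ~= 2.15 GFLOP for the problem.
-+  // Split-K is implemented by reusing the batched configuration
-+  // (initialize_configuration() assigns split_k_slices to
-+  // configuration.batch_count), which is why split_k_slices and batch_count
-+  // cannot both be greater than one.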
-+ -+ if (this->split_k_slices > 1 && this->batch_count > 1) { -+ // At least one of these must be one -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.A, "A", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.B, "B", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.C, "C", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!arg_as_scalar( -+ this->alpha, -+ operation_desc.element_epilogue, -+ "alpha", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->alpha, operation_desc.element_epilogue, 1)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ if (!arg_as_scalar( -+ this->beta, -+ operation_desc.element_epilogue, -+ "beta", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->beta, operation_desc.element_epilogue, 0)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ this->lda = DeviceAllocation::get_packed_layout( -+ operation_desc.A.layout, {int(this->n), int(this->k)}).front(); -+ -+ this->ldb = DeviceAllocation::get_packed_layout( -+ operation_desc.B.layout, {int(this->n), int(this->k)}).front(); -+ -+ this->ldc = DeviceAllocation::get_packed_layout( -+ operation_desc.C.layout, {int(this->n), int(this->n)}).front(); -+ -+ return Status::kSuccess; -+} -+ -+/// Total number of bytes loaded -+int64_t Rank2KOperationProfiler::RankKProblem::bytes(library::RankKDescription const &operation_desc) const { -+ // Input bytes read and Output bytes written for the gemm problem -+ int64_t bytes = -+ 2 * int64_t(library::sizeof_bits(operation_desc.A.element) * n / 8) * k + -+ 2 * int64_t(library::sizeof_bits(operation_desc.B.element) * n / 8) * k + -+ // Half matrix including the diagonal will have (N*(N+1))/2 elements -+ int64_t(library::sizeof_bits(operation_desc.C.element) * n / 8) * (n+1) / 2; -+ -+ // Set is_beta_zero true if beta is zero -+ bool is_beta_zero = std::all_of(beta.begin(), beta.end(), [](uint8_t i) { return i==0; }); -+ -+ // Output bytes read for the gemm problem for non-zero beta values -+ if (!is_beta_zero) { -+ bytes += int64_t(library::sizeof_bits(operation_desc.C.element) * n / 8) * (n+1) / 2; -+ } -+ -+ bytes *= batch_count; -+ -+ return bytes; -+} -+ -+/// Total number of flops computed -+int64_t Rank2KOperationProfiler::RankKProblem::flops(library::RankKDescription const &operation_desc) const { -+ -+ // FLOPs = 2 * n(n+1)k/2 [mma1] + 2 * n(n+1)k/2 [mma2] + 2 * n(n+1)/2 [epilogue] -+ // FLOPs = n(n+1)(2k + 1) -+ int64_t flops_ = n * (n + 1) * (2*k + 1); -+ -+ // complex-valued support -+ switch (operation_desc.tile_description.math_instruction.math_operation) { -+ case library::MathOperationID::kMultiplyAddComplex: -+ flops_ *= 4; -+ break; -+ -+ case library::MathOperationID::kMultiplyAddComplexFastF32: -+ flops_ *= 4; -+ break; -+ -+ case library::MathOperationID::kMultiplyAddGaussianComplex: -+ flops_ *= 3; -+ break; -+ -+ default: break; -+ } -+ -+ return flops_; -+} -+ -+/// Initializes a performance result -+void Rank2KOperationProfiler::RankKProblem::initialize_result( -+ PerformanceResult &result, -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.arguments.resize(problem_space.rank()); -+ -+ set_argument(result, "rank_k_kind", problem_space, library::to_string(operation_desc.rank_k_kind)); -+ -+ set_argument(result, "A", problem_space, -+ 
std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout)); -+ -+ set_argument(result, "B", problem_space, -+ std::string(library::to_string(operation_desc.B.element)) + ":" + library::to_string(operation_desc.B.layout)); -+ -+ set_argument(result, "C", problem_space, -+ std::string(library::to_string(operation_desc.C.element)) + ":" + library::to_string(operation_desc.C.layout)); -+ -+ set_argument(result, "fill_mode", problem_space, library::to_string(operation_desc.fill_mode)); -+ -+ set_argument(result, "blas_mode", problem_space, library::to_string(operation_desc.blas_mode)); -+ -+ set_argument(result, "n", problem_space, n); -+ set_argument(result, "k", problem_space, k); -+ -+ set_argument(result, "split_k_slices", problem_space, split_k_slices); -+ set_argument(result, "batch_count", problem_space, batch_count); -+ -+ set_argument(result, "alpha", problem_space, -+ library::lexical_cast(alpha, operation_desc.element_epilogue)); -+ -+ set_argument(result, "beta", problem_space, -+ library::lexical_cast(beta, operation_desc.element_epilogue)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Extracts the problem dimensions -+Status Rank2KOperationProfiler::initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::RankKDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (operation_desc.rank_k_kind != library::RankKKind::kUniversal) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = problem_.parse(operation_desc, problem_space, problem); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ rank_k_workspace_.configuration.problem_size.m() = int(problem_.n); -+ rank_k_workspace_.configuration.problem_size.n() = int(problem_.n); -+ rank_k_workspace_.configuration.problem_size.k() = int(problem_.k); -+ rank_k_workspace_.configuration.lda = problem_.lda; -+ rank_k_workspace_.configuration.ldb = problem_.ldb; -+ rank_k_workspace_.configuration.ldc = problem_.ldc; -+ rank_k_workspace_.configuration.ldd = problem_.ldc; -+ //rank_k_workspace_.configuration.split_k_slices = int(problem_.split_k_slices); -+ rank_k_workspace_.configuration.batch_count = int(problem_.split_k_slices); -+ -+ rank_k_workspace_.arguments.A = nullptr; -+ rank_k_workspace_.arguments.B = nullptr; -+ rank_k_workspace_.arguments.C = nullptr; -+ rank_k_workspace_.arguments.D = nullptr; -+ rank_k_workspace_.arguments.alpha = problem_.alpha.data(); -+ rank_k_workspace_.arguments.beta = problem_.beta.data(); -+ rank_k_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ initialize_result_(this->model_result_, options, operation_desc, problem_space); -+ -+ return operation->can_implement(&rank_k_workspace_.configuration, &rank_k_workspace_.arguments); -+} -+ -+/// Initializes the performance result -+void Rank2KOperationProfiler::initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.provider = library::Provider::kCUTLASS; -+ result.disposition = Disposition::kNotRun; -+ result.status = Status::kSuccess; -+ result.operation_name = operation_desc.name; -+ -+ problem_.initialize_result(result, 
operation_desc, problem_space); -+ -+ OperationProfiler::initialize_result_(result, operation_desc, problem_space); -+ -+ -+ result.bytes = problem_.bytes(operation_desc); -+ result.flops = problem_.flops(operation_desc); -+ result.runtime = 0; -+ -+ -+} -+ -+/// Initializes workspace -+Status Rank2KOperationProfiler::initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::RankKDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ rank_k_workspace_.A = device_context.allocate_tensor( -+ options, -+ "A", -+ operation_desc.A.element, -+ operation_desc.A.layout, -+ {int(problem_.n), int(problem_.k)}, -+ {int(problem_.lda)} -+ ); -+ -+ rank_k_workspace_.B = device_context.allocate_tensor( -+ options, -+ "B", -+ operation_desc.B.element, -+ operation_desc.B.layout, -+ {int(problem_.n), int(problem_.k)}, -+ {int(problem_.ldb)} -+ ); -+ -+ rank_k_workspace_.C = device_context.allocate_tensor( -+ options, -+ "C", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.n), int(problem_.n)}, -+ {int(problem_.ldc)}, -+ 1 // batch_count = 1, default -+ ); -+ -+ rank_k_workspace_.Computed = device_context.allocate_tensor( -+ "D", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.n), int(problem_.n)}, -+ {int(problem_.ldc)} -+ ); -+ -+ rank_k_workspace_.Reference = device_context.allocate_tensor( -+ "Reference", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.n), int(problem_.n)}, -+ {int(problem_.ldc)} -+ ); -+ -+ rank_k_workspace_.Computed->copy_from_device(rank_k_workspace_.C->data()); -+ rank_k_workspace_.Reference->copy_from_device(rank_k_workspace_.C->data()); -+ } -+ -+ -+ // -+ // Initialize the CUTLASS operation -+ // -+ Status status = Status::kSuccess; -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ uint64_t workspace_size = operation->get_host_workspace_size(&rank_k_workspace_.configuration); -+ rank_k_workspace_.host_workspace.resize(workspace_size, 0); -+ -+ workspace_size = operation->get_device_workspace_size(&rank_k_workspace_.configuration); -+ rank_k_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); -+ -+ status = operation->initialize( -+ &rank_k_workspace_.configuration, -+ rank_k_workspace_.host_workspace.data(), -+ rank_k_workspace_.device_workspace.data()); -+ } -+ -+ // -+ // If CUTLASS is enabled, generate a result for it -+ // -+ results_.push_back(model_result_); -+ results_.back().provider = library::Provider::kCUTLASS; -+ results_.back().op_kind = library::OperationKind::kRank2K; -+ results_.back().disposition = Disposition::kNotRun; -+ -+ for(auto provider : verification_providers_) { -+ results_.back().verification_map[provider] = Disposition::kNotRun; -+ } -+ } -+ -+ return status; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool Rank2KOperationProfiler::verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if 
(!options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ return true; -+ } -+ -+ if (options.execution_mode == ExecutionMode::kDryRun) { -+ return true; -+ } -+ -+ // Initialize structure containing RankK arguments -+ rank_k_workspace_.arguments.A = rank_k_workspace_.A->data(); -+ rank_k_workspace_.arguments.B = rank_k_workspace_.B->data(); -+ rank_k_workspace_.arguments.C = rank_k_workspace_.C->data(); -+ rank_k_workspace_.arguments.D = rank_k_workspace_.Computed->data(); -+ rank_k_workspace_.arguments.alpha = problem_.alpha.data(); -+ rank_k_workspace_.arguments.beta = problem_.beta.data(); -+ rank_k_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // -+ // Run the CUTLASS operation -+ // -+ -+ results_.back().status = operation->run( -+ &rank_k_workspace_.arguments, -+ rank_k_workspace_.host_workspace.data(), -+ rank_k_workspace_.device_workspace.data()); -+ -+ if (results_.back().status != Status::kSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ // CUTLASS op ran the but not yet verified against any verification provider -+ results_.back().disposition = Disposition::kNotVerified; -+ -+ // -+ // Run verification providers -+ // -+ -+ if (options.verification.enabled) { -+ -+#if CUTLASS_ENABLE_CUBLAS -+ if (options.verification.provider_enabled(library::Provider::kCUBLAS)) { -+ -+ // Guard against unsupported cases -+ auto const & rank_k_desc = static_cast(operation->description()); -+ -+ if (cublas_satisfies(rank_k_desc) == Status::kSuccess) { -+ -+ // call cublas verification if supported -+ verify_with_cublas_( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ else { -+ // set verification map for cublas to not supported -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotSupported; -+ } -+ } -+#endif // #if CUTLASS_ENABLE_CUBLAS -+ -+ // Update disposition to worst case verification outcome among all -+ // verification providers which are supported -+ bool is_any_verification_run_passed = false; -+ for(auto &m : results_.back().verification_map) { -+ if(m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { -+ results_.back().disposition = m.second; -+ return true; -+ } -+ if(!is_any_verification_run_passed && m.second == Disposition::kPassed) { -+ is_any_verification_run_passed = true; -+ } -+ } -+ -+ if(is_any_verification_run_passed) { -+ results_.back().disposition = Disposition::kPassed; -+ } -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool Rank2KOperationProfiler::verify_with_cublas_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ -+#if CUTLASS_ENABLE_CUBLAS -+ -+ library::RankKDescription const &rank_k_desc = -+ static_cast(operation->description()); -+ -+ // -+ // Construct cuBLAS operators -+ // -+ -+ CublasCreate handle; -+ cublasStatus_t status = handle.get_cublas_create_status(); -+ -+ if (status != CUBLAS_STATUS_SUCCESS) { -+ -+ 
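-+    // cuBLAS handle creation failed: record a failed disposition for the
-+    // cuBLAS verification provider, but return true so that profiling of the
-+    // CUTLASS result still continues.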
results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ return true; -+ } -+ -+ // -+ // Initialize state -+ // -+ -+ try { -+ -+ // -+ // Construct dispatcher to cublasSyr2k() -+ // -+ -+ // Initialize structure containing RankK arguments -+ rank_k_workspace_.arguments.A = rank_k_workspace_.A->data(); -+ rank_k_workspace_.arguments.B = rank_k_workspace_.B->data(); -+ rank_k_workspace_.arguments.C = rank_k_workspace_.Reference->data(); -+ rank_k_workspace_.arguments.D = rank_k_workspace_.Reference->data(); -+ rank_k_workspace_.arguments.alpha = problem_.alpha.data(); -+ rank_k_workspace_.arguments.beta = problem_.beta.data(); -+ rank_k_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ detail::cublasRankKDispatcher rank_k_op( -+ rank_k_desc, -+ rank_k_workspace_.configuration, -+ rank_k_workspace_.arguments -+ ); -+ -+ if (rank_k_op.status != Status::kSuccess) { -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotRun; -+ return true; -+ } -+ -+ results_.back().status = Status::kSuccess; -+ -+ status = rank_k_op(handle); -+ -+ // Handle errors -+ if (status != CUBLAS_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ return true; -+ } -+ -+ // -+ // Verify results -+ // -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = compare_tensors( -+ options, -+ *rank_k_workspace_.Computed, -+ *rank_k_workspace_.Reference -+ ); -+ -+ // Save workspace if incorrect -+ if (options.verification.save_workspace == SaveWorkspace::kIncorrect && -+ results_.back().verification_map[library::Provider::kCUBLAS] == Disposition::kIncorrect) { -+ -+ save_workspace( -+ device_context, -+ options, -+ rank_k_desc, -+ library::Provider::kCUTLASS, -+ library::Provider::kCUBLAS); -+ } -+ } -+ catch (...) 
{ -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ } -+ -+#endif -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Measures performance results -+bool Rank2KOperationProfiler::profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ // Initialize structure containing RankK arguments -+ rank_k_workspace_.arguments.A = rank_k_workspace_.A->data(); -+ rank_k_workspace_.arguments.B = rank_k_workspace_.B->data(); -+ rank_k_workspace_.arguments.C = rank_k_workspace_.C->data(); -+ rank_k_workspace_.arguments.D = rank_k_workspace_.Computed->data(); -+ rank_k_workspace_.arguments.alpha = problem_.alpha.data(); -+ rank_k_workspace_.arguments.beta = problem_.beta.data(); -+ rank_k_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ results_.back().status = profile_cutlass_( -+ results_.back().runtime, -+ options, -+ operation, -+ &rank_k_workspace_.arguments, -+ rank_k_workspace_.host_workspace.data(), -+ rank_k_workspace_.device_workspace.data() -+ ); -+ } -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/rank_2k_operation_profiler.h b/3rdparty/cutlass/tools/profiler/src/rank_2k_operation_profiler.h -new file mode 100644 -index 0000000..6dbfc3f ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/rank_2k_operation_profiler.h -@@ -0,0 +1,229 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines a math function -+ -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include -+ -+// CUTLASS Library includes -+#include "cutlass/blas3.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+#include "cutlass/library/manifest.h" -+ -+// Profiler includes -+#include "options.h" -+#include "device_context.h" -+#include "operation_profiler.h" -+#include "performance_result.h" -+#include "problem_space.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Abstract base class for each math function -+class Rank2KOperationProfiler : public OperationProfiler { -+public: -+ -+ /// Problem structure obtained from problem space -+ struct RankKProblem { -+ int64_t n; -+ int64_t k; -+ int64_t lda; -+ int64_t ldb; -+ int64_t ldc; -+ FillMode fill_mode; -+ BlasMode blas_mode; -+ std::vector alpha; -+ std::vector beta; -+ int64_t split_k_slices; -+ int64_t batch_count; -+ -+ // -+ // Methods -+ // -+ -+ RankKProblem(): -+ n(16), k(16), lda(0), ldc(0), -+ fill_mode(FillMode::kInvalid), blas_mode(BlasMode::kInvalid), -+ split_k_slices(1), batch_count(1) { } -+ -+ /// Parses the problem -+ Status parse( -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Total number of bytes loaded -+ int64_t bytes(library::RankKDescription const &operation_desc) const; -+ -+ /// Total number of flops computed -+ int64_t flops(library::RankKDescription const &operation_desc) const; -+ -+ /// Initializes a performance result -+ void initialize_result( -+ PerformanceResult &result, -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ }; -+ -+ /// Workspace used -+ struct RankKWorkspace { -+ -+ DeviceAllocation *A; -+ DeviceAllocation *B; -+ DeviceAllocation *C; -+ DeviceAllocation *Computed; -+ DeviceAllocation *Reference; -+ -+ library::RankKConfiguration configuration; -+ library::RankKArguments arguments; -+ -+ /// Buffer used for the operation's host workspace -+ std::vector host_workspace; -+ -+ /// Buffer used for the operations' device workspace -+ DeviceAllocation device_workspace; -+ -+ // -+ // Methods -+ // -+ -+ RankKWorkspace(): -+ A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } -+ }; -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ /// GEMM problem obtained from problem space -+ RankKProblem problem_; -+ -+ /// Device memory allocations -+ RankKWorkspace rank_k_workspace_; -+ -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ Rank2KOperationProfiler(Options const &options); -+ -+ /// Destructor -+ virtual 
~Rank2KOperationProfiler(); -+ -+ /// Prints usage statement for the math function -+ virtual void print_usage(std::ostream &out) const; -+ -+ /// Prints examples -+ virtual void print_examples(std::ostream &out) const; -+ -+ /// Extracts the problem dimensions -+ virtual Status initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes workspace -+ virtual Status initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against references -+ virtual bool verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Measures performance results -+ virtual bool profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+protected: -+ -+ /// Initializes the performance result -+ void initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ -+ /// Verifies CUTLASS against references -+ bool verify_with_cublas_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/rank_k_operation_profiler.cu b/3rdparty/cutlass/tools/profiler/src/rank_k_operation_profiler.cu -new file mode 100644 -index 0000000..7e452e7 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/rank_k_operation_profiler.cu -@@ -0,0 +1,715 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Execution environment -+ -+ -+*/ -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/core_io.h" -+ -+#include "cublas_helpers.h" -+#include "rank_k_operation_profiler.h" -+#include "gpu_timer.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Ctor -+RankKOperationProfiler::RankKOperationProfiler(Options const &options): -+ OperationProfiler( -+ options, -+ library::OperationKind::kRankK, -+ { -+ {ArgumentTypeID::kEnumerated, {"rank_k_kind"}, "Variant of RankK (universal)"}, -+ {ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the RankK problem space"}, -+ {ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the RankK problem space"}, -+ {ArgumentTypeID::kTensor, {"A"}, "Tensor storing the A operand"}, -+ {ArgumentTypeID::kTensor, {"C"}, "Tensor storing the C operand"}, -+ {ArgumentTypeID::kEnumerated, {"fill_mode"}, "Fill Mode for RankK kernel (lower or upper)"}, -+ {ArgumentTypeID::kEnumerated, {"blas_mode"}, "Blas Mode for RankK kernel (symmetric or hermitian)"}, -+ {ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"}, -+ {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"}, -+ {ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"}, -+ {ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of RankK computed in one batch"}, -+ }, -+ { library::Provider::kCUBLAS} -+ ) { -+ description_ = " Rank-k Update. 
D = alpha * A*A^T + beta * C (symmetric) or D = alpha * A*A^H + beta * C (hermitian)"; -+} -+ -+/// Destructor -+RankKOperationProfiler::~RankKOperationProfiler() { -+ -+} -+ -+/// Prints usage statement for the math function -+void RankKOperationProfiler::print_usage(std::ostream &out) const { -+ out << "RankK" << "\n\n"; -+ -+ OperationProfiler::print_usage(out); -+} -+ -+/// Prints examples -+void RankKOperationProfiler::print_examples(std::ostream &out) const { -+ -+ out << "\nExamples:\n\n" -+ << "Profile a particular problem size Syrk kernel:\n" -+ << " $ cutlass_profiler --operation=rank_k --blas_mode=symmetric --n=1024 --k=128\n\n" -+ -+ << "Profile a particular problem size Herk kernel:\n" -+ << " $ cutlass_profiler --operation=rank_k --blas_mode=hermitian --n=1024 --k=128\n\n" -+ -+ << "Schmoo over problem size and beta:\n" -+ << " $ cutlass_profiler --operation=rank_k --n=1024:4096:256 --k=128:8192:128 --beta=0,1,2.5\n\n" -+ -+ << "Schmoo over accumulator types:\n" -+ << " $ cutlass_profiler --operation=rank_k --accumulator-type=f16,f32\n\n" -+ -+ << "Schmoo over fill modees:\n" -+ << " $ cutlass_profiler --operation=rank_k --fill_mode=lower/upper\n\n" -+ -+ << "Run when A is f16 with column-major or A is any datatype with row-major (For column major, use column, col, or n. For row major use, row or t):\n" -+ << " $ cutlass_profiler --operation=rank_k --A=f16:column or --A=*:row\n\n" -+ -+ << "Using various input value distribution:\n" -+ << " $ cutlass_profiler --operation=rank_k --dist=uniform,min:0,max:3\n" -+ << " $ cutlass_profiler --operation=rank_k --dist=gaussian,mean:0,stddev:3\n" -+ << " $ cutlass_profiler --operation=rank_k --dist=sequential,start:0,delta:1\n\n" -+ -+ << "Run a kernel with cta tile size of 256x128x32 and save workspace if results are incorrect (note that --cta-tile::k=32 is default cta-tile size):\n" -+ << " $ cutlass_profiler --operation=rank_k --cta_m=256 --cta_n=128 --cta_k=32 --save-workspace=incorrect\n\n" -+ -+ << "Test your changes to rank_k kernels with a quick functional test and save results in functional-test.csv:\n" -+ << " $ cutlass_profiler --operation=rank_k \\ \n" -+ << " --n=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" -+ << " --k=8,16,32,64,128,256,288,384,504,512,520 \\ \n" -+ << " --beta=0,1,2 --profiling-iterations=1 \\ \n" -+ << " --providers=cutlass --output=functional-test.csv\n\n"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if 0 -+// used this for debugging -+static std::string byte_string(std::vector const &bytes) { -+ std::stringstream ss; -+ -+ ss << "0x"; -+ -+ for (size_t idx = bytes.size(); idx > 0; --idx) { -+ ss << std::hex << std::setw(2) << std::setfill('0') << uint32_t(bytes.at(idx - 1)); -+ } -+ -+ return ss.str(); -+} -+#endif -+ -+Status RankKOperationProfiler::RankKProblem::parse( -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!arg_as_int(this->n, "n", problem_space, problem)) { -+ // default value -+ this->n = 1024; -+ } -+ -+ if (!arg_as_int(this->k, "k", problem_space, problem)) { -+ // default value -+ this->k = 1024; -+ } -+ -+ if (!arg_as_int(this->split_k_slices, "split_k_slices", problem_space, problem)) { -+ // default value -+ this->split_k_slices = 1; -+ } -+ -+ if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) { -+ // default value -+ this->batch_count = 1; -+ } -+ -+ if (this->split_k_slices > 1 
&& this->batch_count > 1) { -+ // At least one of these must be one -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.A, "A", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.C, "C", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!arg_as_scalar( -+ this->alpha, -+ operation_desc.element_epilogue, -+ "alpha", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->alpha, operation_desc.element_epilogue, 1)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ if (!arg_as_scalar( -+ this->beta, -+ operation_desc.element_epilogue, -+ "beta", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->beta, operation_desc.element_epilogue, 0)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ this->lda = DeviceAllocation::get_packed_layout( -+ operation_desc.A.layout, {int(this->n), int(this->k)}).front(); -+ -+ this->ldc = DeviceAllocation::get_packed_layout( -+ operation_desc.C.layout, {int(this->n), int(this->n)}).front(); -+ -+ return Status::kSuccess; -+} -+ -+/// Total number of bytes loaded -+int64_t RankKOperationProfiler::RankKProblem::bytes(library::RankKDescription const &operation_desc) const { -+ // Input bytes read and Output bytes written for the gemm problem -+ int64_t bytes = -+ int64_t(library::sizeof_bits(operation_desc.A.element) * n / 8) * k + -+ int64_t(library::sizeof_bits(operation_desc.A.element) * n / 8) * k + -+ // Half matrix including the diagonal will have (N*(N+1))/2 elements -+ int64_t(library::sizeof_bits(operation_desc.C.element) * n / 8) * (n+1) / 2; -+ -+ // Set is_beta_zero true if beta is zero -+ bool is_beta_zero = std::all_of(beta.begin(), beta.end(), [](uint8_t i) { return i==0; }); -+ -+ // Output bytes read for the gemm problem for non-zero beta values -+ if (!is_beta_zero) { -+ bytes += int64_t(library::sizeof_bits(operation_desc.C.element) * n / 8) * (n+1) / 2; -+ } -+ -+ bytes *= batch_count; -+ -+ return bytes; -+} -+ -+/// Total number of flops computed -+int64_t RankKOperationProfiler::RankKProblem::flops(library::RankKDescription const &operation_desc) const { -+ -+ // FLOPs = 2 * n(n+1)k/2 [mma] + 2 * n(n+1)/2 [epilogue] -+ // FLOPs = n(n+1)(k + 1) -+ int64_t flops_ = n * (n + 1) * (k + 1); -+ -+ // complex-valued support -+ switch (operation_desc.tile_description.math_instruction.math_operation) { -+ case library::MathOperationID::kMultiplyAddComplex: -+ flops_ *= 4; -+ break; -+ -+ case library::MathOperationID::kMultiplyAddComplexFastF32: -+ flops_ *= 4; -+ break; -+ -+ case library::MathOperationID::kMultiplyAddGaussianComplex: -+ flops_ *= 3; -+ break; -+ -+ default: break; -+ } -+ -+ return flops_; -+} -+ -+/// Initializes a performance result -+void RankKOperationProfiler::RankKProblem::initialize_result( -+ PerformanceResult &result, -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.arguments.resize(problem_space.rank()); -+ -+ set_argument(result, "rank_k_kind", problem_space, library::to_string(operation_desc.rank_k_kind)); -+ -+ set_argument(result, "A", problem_space, -+ std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout)); -+ -+ set_argument(result, "C", problem_space, -+ std::string(library::to_string(operation_desc.C.element)) + ":" + library::to_string(operation_desc.C.layout)); -+ -+ set_argument(result, "fill_mode", problem_space, 
library::to_string(operation_desc.fill_mode)); -+ -+ set_argument(result, "blas_mode", problem_space, library::to_string(operation_desc.blas_mode)); -+ -+ set_argument(result, "n", problem_space, n); -+ set_argument(result, "k", problem_space, k); -+ -+ set_argument(result, "split_k_slices", problem_space, split_k_slices); -+ set_argument(result, "batch_count", problem_space, batch_count); -+ -+ set_argument(result, "alpha", problem_space, -+ library::lexical_cast(alpha, operation_desc.element_epilogue)); -+ -+ set_argument(result, "beta", problem_space, -+ library::lexical_cast(beta, operation_desc.element_epilogue)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Extracts the problem dimensions -+Status RankKOperationProfiler::initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::RankKDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (operation_desc.rank_k_kind != library::RankKKind::kUniversal) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = problem_.parse(operation_desc, problem_space, problem); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ rank_k_workspace_.configuration.problem_size.m() = int(problem_.n); -+ rank_k_workspace_.configuration.problem_size.n() = int(problem_.n); -+ rank_k_workspace_.configuration.problem_size.k() = int(problem_.k); -+ rank_k_workspace_.configuration.lda = problem_.lda; -+ rank_k_workspace_.configuration.ldc = problem_.ldc; -+ rank_k_workspace_.configuration.ldd = problem_.ldc; -+ //rank_k_workspace_.configuration.split_k_slices = int(problem_.split_k_slices); -+ rank_k_workspace_.configuration.batch_count = int(problem_.split_k_slices); -+ -+ rank_k_workspace_.arguments.A = nullptr; -+ rank_k_workspace_.arguments.C = nullptr; -+ rank_k_workspace_.arguments.D = nullptr; -+ rank_k_workspace_.arguments.alpha = problem_.alpha.data(); -+ rank_k_workspace_.arguments.beta = problem_.beta.data(); -+ rank_k_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ initialize_result_(this->model_result_, options, operation_desc, problem_space); -+ -+ return operation->can_implement(&rank_k_workspace_.configuration, &rank_k_workspace_.arguments); -+} -+ -+/// Initializes the performance result -+void RankKOperationProfiler::initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.provider = library::Provider::kCUTLASS; -+ result.disposition = Disposition::kNotRun; -+ result.status = Status::kSuccess; -+ result.operation_name = operation_desc.name; -+ -+ problem_.initialize_result(result, operation_desc, problem_space); -+ -+ OperationProfiler::initialize_result_(result, operation_desc, problem_space); -+ -+ -+ result.bytes = problem_.bytes(operation_desc); -+ result.flops = problem_.flops(operation_desc); -+ -+ result.runtime = 0; -+ -+ // complex-valued support -+ switch (operation_desc.tile_description.math_instruction.math_operation) { -+ case library::MathOperationID::kMultiplyAddComplex: -+ result.flops *= 4; -+ break; -+ -+ case library::MathOperationID::kMultiplyAddComplexFastF32: -+ result.flops *= 4; -+ break; -+ -+ default: break; -+ } -+ -+} -+ -+/// Initializes 
workspace -+Status RankKOperationProfiler::initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::RankKDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ rank_k_workspace_.A = device_context.allocate_tensor( -+ options, -+ "A", -+ operation_desc.A.element, -+ operation_desc.A.layout, -+ {int(problem_.n), int(problem_.k)}, -+ {int(problem_.lda)} -+ ); -+ -+ rank_k_workspace_.C = device_context.allocate_tensor( -+ options, -+ "C", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.n), int(problem_.n)}, -+ {int(problem_.ldc)}, -+ 1 // batch_count = 1, default -+ ); -+ -+ rank_k_workspace_.Computed = device_context.allocate_tensor( -+ "D", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.n), int(problem_.n)}, -+ {int(problem_.ldc)} -+ ); -+ -+ rank_k_workspace_.Reference = device_context.allocate_tensor( -+ "Reference", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.n), int(problem_.n)}, -+ {int(problem_.ldc)} -+ ); -+ -+ rank_k_workspace_.Computed->copy_from_device(rank_k_workspace_.C->data()); -+ rank_k_workspace_.Reference->copy_from_device(rank_k_workspace_.C->data()); -+ } -+ -+ -+ // -+ // Initialize the CUTLASS operation -+ // -+ Status status = Status::kSuccess; -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ uint64_t workspace_size = operation->get_host_workspace_size(&rank_k_workspace_.configuration); -+ rank_k_workspace_.host_workspace.resize(workspace_size, 0); -+ -+ workspace_size = operation->get_device_workspace_size(&rank_k_workspace_.configuration); -+ rank_k_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); -+ -+ status = operation->initialize( -+ &rank_k_workspace_.configuration, -+ rank_k_workspace_.host_workspace.data(), -+ rank_k_workspace_.device_workspace.data()); -+ } -+ -+ // -+ // If CUTLASS is enabled, generate a result for it -+ // -+ results_.push_back(model_result_); -+ results_.back().provider = library::Provider::kCUTLASS; -+ results_.back().op_kind = library::OperationKind::kRankK; -+ results_.back().disposition = Disposition::kNotRun; -+ -+ for(auto provider : verification_providers_) { -+ results_.back().verification_map[provider] = Disposition::kNotRun; -+ } -+ } -+ -+ return status; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool RankKOperationProfiler::verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ return true; -+ } -+ -+ if (options.execution_mode == ExecutionMode::kDryRun) { -+ return true; -+ } -+ -+ // Initialize structure containing RankK arguments -+ rank_k_workspace_.arguments.A = rank_k_workspace_.A->data(); -+ rank_k_workspace_.arguments.C = rank_k_workspace_.C->data(); -+ rank_k_workspace_.arguments.D = rank_k_workspace_.Computed->data(); -+ rank_k_workspace_.arguments.alpha = 
problem_.alpha.data(); -+ rank_k_workspace_.arguments.beta = problem_.beta.data(); -+ rank_k_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // -+ // Run the CUTLASS operation -+ // -+ -+ results_.back().status = operation->run( -+ &rank_k_workspace_.arguments, -+ rank_k_workspace_.host_workspace.data(), -+ rank_k_workspace_.device_workspace.data()); -+ -+ if (results_.back().status != Status::kSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ // CUTLASS op ran the but not yet verified against any verification provider -+ results_.back().disposition = Disposition::kNotVerified; -+ -+ // -+ // Run verification providers -+ // -+ -+ if (options.verification.enabled) { -+ -+#if CUTLASS_ENABLE_CUBLAS -+ if (options.verification.provider_enabled(library::Provider::kCUBLAS)) { -+ -+ // Guard against unsupported cases -+ auto const & rank_k_desc = static_cast(operation->description()); -+ -+ if (cublas_satisfies(rank_k_desc) == Status::kSuccess) { -+ -+ // call cublas verification if supported -+ verify_with_cublas_( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ else { -+ // set verification map for cublas to not supported -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotSupported; -+ } -+ } -+#endif // #if CUTLASS_ENABLE_CUBLAS -+ -+ // Update disposition to worst case verification outcome among all -+ // verification providers which are supported -+ bool is_any_verification_run_passed = false; -+ for(auto &m : results_.back().verification_map) { -+ if(m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { -+ results_.back().disposition = m.second; -+ return true; -+ } -+ if(!is_any_verification_run_passed && m.second == Disposition::kPassed) { -+ is_any_verification_run_passed = true; -+ } -+ } -+ -+ if(is_any_verification_run_passed) { -+ results_.back().disposition = Disposition::kPassed; -+ } -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool RankKOperationProfiler::verify_with_cublas_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ -+#if CUTLASS_ENABLE_CUBLAS -+ -+ library::RankKDescription const &rank_k_desc = -+ static_cast(operation->description()); -+ -+ // -+ // Construct cuBLAS operators -+ // -+ -+ CublasCreate handle; -+ cublasStatus_t status = handle.get_cublas_create_status(); -+ -+ if (status != CUBLAS_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ return true; -+ } -+ -+ // -+ // Initialize state -+ // -+ -+ try { -+ -+ // -+ // Construct dispatcher to cublasSyrk() -+ // -+ -+ // Initialize structure containing RankK arguments -+ rank_k_workspace_.arguments.A = rank_k_workspace_.A->data(); -+ rank_k_workspace_.arguments.C = rank_k_workspace_.Reference->data(); -+ rank_k_workspace_.arguments.D = rank_k_workspace_.Reference->data(); -+ rank_k_workspace_.arguments.alpha = problem_.alpha.data(); -+ rank_k_workspace_.arguments.beta = 
problem_.beta.data(); -+ rank_k_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ detail::cublasRankKDispatcher rank_k_op( -+ rank_k_desc, -+ rank_k_workspace_.configuration, -+ rank_k_workspace_.arguments -+ ); -+ -+ if (rank_k_op.status != Status::kSuccess) { -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotRun; -+ return true; -+ } -+ -+ results_.back().status = Status::kSuccess; -+ -+ status = rank_k_op(handle); -+ -+ // Handle errors -+ if (status != CUBLAS_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ return true; -+ } -+ -+ // -+ // Verify results -+ // -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = compare_tensors( -+ options, -+ *rank_k_workspace_.Computed, -+ *rank_k_workspace_.Reference -+ ); -+ -+ // Save workspace if incorrect -+ if (options.verification.save_workspace == SaveWorkspace::kIncorrect && -+ results_.back().verification_map[library::Provider::kCUBLAS] == Disposition::kIncorrect) { -+ -+ save_workspace( -+ device_context, -+ options, -+ rank_k_desc, -+ library::Provider::kCUTLASS, -+ library::Provider::kCUBLAS); -+ } -+ } -+ catch (...) { -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ } -+ -+#endif -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Measures performance results -+bool RankKOperationProfiler::profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ // Initialize structure containing RankK arguments -+ rank_k_workspace_.arguments.A = rank_k_workspace_.A->data(); -+ rank_k_workspace_.arguments.C = rank_k_workspace_.C->data(); -+ rank_k_workspace_.arguments.D = rank_k_workspace_.Computed->data(); -+ rank_k_workspace_.arguments.alpha = problem_.alpha.data(); -+ rank_k_workspace_.arguments.beta = problem_.beta.data(); -+ rank_k_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ results_.back().status = profile_cutlass_( -+ results_.back().runtime, -+ options, -+ operation, -+ &rank_k_workspace_.arguments, -+ rank_k_workspace_.host_workspace.data(), -+ rank_k_workspace_.device_workspace.data() -+ ); -+ } -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/rank_k_operation_profiler.h b/3rdparty/cutlass/tools/profiler/src/rank_k_operation_profiler.h -new file mode 100644 -index 0000000..779509a ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/rank_k_operation_profiler.h -@@ -0,0 +1,227 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines a math function -+ -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include -+ -+// CUTLASS Library includes -+#include "cutlass/blas3.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+#include "cutlass/library/manifest.h" -+ -+// Profiler includes -+#include "options.h" -+#include "device_context.h" -+#include "operation_profiler.h" -+#include "performance_result.h" -+#include "problem_space.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Abstract base class for each math function -+class RankKOperationProfiler : public OperationProfiler { -+public: -+ -+ /// Problem structure obtained from problem space -+ struct RankKProblem { -+ int64_t n; -+ int64_t k; -+ int64_t lda; -+ int64_t ldc; -+ FillMode fill_mode; -+ BlasMode blas_mode; -+ std::vector alpha; -+ std::vector beta; -+ int64_t split_k_slices; -+ int64_t batch_count; -+ -+ // -+ // Methods -+ // -+ -+ RankKProblem(): -+ n(16), k(16), lda(0), ldc(0), -+ fill_mode(FillMode::kInvalid), blas_mode(BlasMode::kInvalid), -+ split_k_slices(1), batch_count(1) { } -+ -+ /// Parses the problem -+ Status parse( -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Total number of bytes loaded -+ int64_t bytes(library::RankKDescription const &operation_desc) const; -+ -+ /// Total number of flops computed -+ int64_t flops(library::RankKDescription const &operation_desc) const; -+ -+ /// Initializes a performance result -+ void initialize_result( -+ PerformanceResult &result, -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ }; -+ -+ /// Workspace used -+ struct RankKWorkspace { -+ -+ DeviceAllocation *A; -+ DeviceAllocation *C; -+ DeviceAllocation 
*Computed; -+ DeviceAllocation *Reference; -+ -+ library::RankKConfiguration configuration; -+ library::RankKArguments arguments; -+ -+ /// Buffer used for the operation's host workspace -+ std::vector host_workspace; -+ -+ /// Buffer used for the operations' device workspace -+ DeviceAllocation device_workspace; -+ -+ // -+ // Methods -+ // -+ -+ RankKWorkspace(): -+ A(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } -+ }; -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ /// GEMM problem obtained from problem space -+ RankKProblem problem_; -+ -+ /// Device memory allocations -+ RankKWorkspace rank_k_workspace_; -+ -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ RankKOperationProfiler(Options const &options); -+ -+ /// Destructor -+ virtual ~RankKOperationProfiler(); -+ -+ /// Prints usage statement for the math function -+ virtual void print_usage(std::ostream &out) const; -+ -+ /// Prints examples -+ virtual void print_examples(std::ostream &out) const; -+ -+ /// Extracts the problem dimensions -+ virtual Status initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes workspace -+ virtual Status initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against references -+ virtual bool verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Measures performance results -+ virtual bool profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+protected: -+ -+ /// Initializes the performance result -+ void initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ -+ /// Verifies CUTLASS against references -+ bool verify_with_cublas_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/reduction_operation_profiler.h b/3rdparty/cutlass/tools/profiler/src/reduction_operation_profiler.h -new file mode 100644 -index 0000000..eef7350 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/reduction_operation_profiler.h -@@ -0,0 +1,173 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines profiling functionality for reduction operation -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include -+ -+// CUTLASS Library includes -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+#include "cutlass/library/manifest.h" -+ -+// Profiler includes -+#include "options.h" -+#include "device_context.h" -+#include "operation_profiler.h" -+#include "performance_result.h" -+#include "problem_space.h" -+#if CUTLASS_ENABLE_CUDNN -+#include "cudnn_helpers.h" -+#endif //#if CUTLASS_ENABLE_CUDNN -+#include "debug.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Abstract base class for each math function -+class ReductionOperationProfiler : public OperationProfiler { -+public: -+ -+ -+ /// Workspace used -+ struct ReductionWorkspace { -+ -+ /// Conv device allocations -+ DeviceAllocation *Workspace; -+ DeviceAllocation *Source; -+ DeviceAllocation *Destination; -+ DeviceAllocation *Reference; -+ -+ /// Library configuration and arguments -+ library::ReductionConfiguration configuration; -+ library::ReductionArguments arguments; -+ -+ /// Buffer used for the cutlass operations' host workspace -+ std::vector host_workspace; -+ -+ /// Buffer used for the cutlass operations' device workspace -+ DeviceAllocation device_workspace; -+ -+ // -+ // Methods -+ // -+ -+ ReductionWorkspace(): -+ Workspace(nullptr), Source(nullptr), Destination(nullptr), Reference(nullptr) { } -+ }; -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ /// Reduction problem obtained from problem space -+ MatrixCoord problem_; -+ -+ /// Device memory allocations -+ 
ReductionWorkspace conv_workspace_; -+ -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ ReductionOperationProfiler(Options const &options); -+ -+ /// Destructor -+ virtual ~ReductionOperationProfiler(); -+ -+ /// Prints usage statement for the math function -+ virtual void print_usage(std::ostream &out) const; -+ -+ /// Prints examples -+ virtual void print_examples(std::ostream &out) const; -+ -+ /// Extracts the problem dimensions -+ virtual Status initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes workspace -+ virtual Status initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against references -+ virtual bool verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Measures performance results -+ virtual bool profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/profiler/src/sparse_gemm_operation_profiler.cu b/3rdparty/cutlass/tools/profiler/src/sparse_gemm_operation_profiler.cu -new file mode 100644 -index 0000000..2caf5f0 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/sparse_gemm_operation_profiler.cu -@@ -0,0 +1,569 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Execution environment -+ -+*/ -+ -+#include -+#include -+#include -+#include -+ -+#include "cublas_helpers.h" -+#include "sparse_gemm_operation_profiler.h" -+#include "gpu_timer.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Ctor -+SparseGemmOperationProfiler::SparseGemmOperationProfiler(Options const &options): -+ OperationProfiler( -+ options, -+ library::OperationKind::kSparseGemm, -+ { -+ {ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (e.g. gemm, planar complex, batched, ...)"}, -+ {ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the GEMM problem space"}, -+ {ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the GEMM problem space"}, -+ {ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the GEMM problem space"}, -+ {ArgumentTypeID::kTensor, {"A"}, "Tensor storing the A operand"}, -+ {ArgumentTypeID::kTensor, {"B"}, "Tensor storing the B operand"}, -+ {ArgumentTypeID::kTensor, {"C"}, "Tensor storing the C operand"}, -+ {ArgumentTypeID::kTensor, {"E"}, "Tensor storing the E operand"}, -+ {ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"}, -+ {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"}, -+ {ArgumentTypeID::kInteger, {"split_k_slices"}, "Number of partitions of K dimension"}, -+ {ArgumentTypeID::kInteger, {"batch_count"}, "Number of GEMMs computed in one batch"}, -+ } -+ ) { -+ -+ description_ = " Structured sparse GEMM. D = alpha * A*B + beta * C"; -+} -+ -+/// Destructor -+SparseGemmOperationProfiler::~SparseGemmOperationProfiler() { -+ -+} -+ -+/// Prints usage statement for the math function -+void SparseGemmOperationProfiler::print_usage(std::ostream &out) const { -+ out << "Sparse GEMM" << "\n\n"; -+ -+ OperationProfiler::print_usage(out); -+} -+ -+/// Prints examples -+void SparseGemmOperationProfiler::print_examples(std::ostream &out) const { -+ -+ out << "\nExamples:\n\n" -+ << "Profile a particular problem size:\n" -+ << " $ cutlass_profiler --operation=SparseGemm --m=1024 --n=1024 --k=128\n\n" -+ -+ << "Schmoo over problem size and beta:\n" -+ << " $ cutlass_profiler --operation=SparseGemm --m=1024:4096:256 --n=1024:4096:256 --k=128:8192:128 --beta=0,1,2.5\n\n" -+ -+ << "Schmoo over accumulator types:\n" -+ << " $ cutlass_profiler --operation=SparseGemm --accumulator-type=f16,f32\n\n" -+ -+ << "Run when A is f16 with column-major and B is any datatype with row-major (For column major, use column, col, or n. 
For row major use, row or t):\n" -+ << " $ cutlass_profiler --operation=SparseGemm --A=f16:column --B=*:row\n\n" -+ -+ << "Using various input value distribution:\n" -+ << " $ cutlass_profiler --operation=SparseGemm --dist=uniform,min:0,max:3\n" -+ << " $ cutlass_profiler --operation=SparseGemm --dist=gaussian,mean:0,stddev:3\n" -+ << " $ cutlass_profiler --operation=SparseGemm --dist=sequential,start:0,delta:1\n\n" -+ -+ << "Run a kernel with cta tile size of 256x128x32 and save workspace if results are incorrect (note that --cta-tile::k=32 is default cta-tile size):\n" -+ << " $ cutlass_profiler --operation=SparseGemm --cta_m=256 --cta_n=128 --cta_k=32 --save-workspace=incorrect\n\n" -+ -+ << "Test your changes to gemm kernels with a quick functional test and save results in functional-test.csv:\n" -+ << " $ cutlass_profiler --operation=SparseGemm \\ \n" -+ << " --m=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" -+ << " --n=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" -+ << " --k=8,16,32,64,128,256,288,384,504,512,520 \\ \n" -+ << " --beta=0,1,2 --profiling-iterations=1 \\ \n" -+ << " --providers=cutlass --output=functional-test.csv\n\n"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+Status SparseGemmOperationProfiler::SparseGemmProblem::parse( -+ library::SparseGemmDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!arg_as_int(this->m, "m", problem_space, problem)) { -+ // default value -+ this->m = 1024; -+ } -+ -+ if (!arg_as_int(this->n, "n", problem_space, problem)) { -+ // default value -+ this->n = 1024; -+ } -+ -+ if (!arg_as_int(this->k, "k", problem_space, problem)) { -+ // default value -+ this->k = 1024; -+ } -+ -+ if (!arg_as_int(this->split_k_slices, "split_k_slices", problem_space, problem)) { -+ // default value -+ this->split_k_slices = 1; -+ } -+ -+ if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) { -+ // default value -+ this->batch_count = 1; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.A, "A", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.B, "B", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.C, "C", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.E, "E", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!arg_as_scalar( -+ this->alpha, -+ operation_desc.element_epilogue, -+ "alpha", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->alpha, operation_desc.element_epilogue, 1)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ if (!arg_as_scalar( -+ this->beta, -+ operation_desc.element_epilogue, -+ "beta", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->beta, operation_desc.element_epilogue, 0)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ this->elements_per_128b = -+ 128 / library::sizeof_bits(operation_desc.A.element); -+ -+ this->lda = DeviceAllocation::get_packed_layout( -+ operation_desc.A.layout, -+ {int(this->m), int(this->k) / int(this->sparse)}) -+ .front(); -+ -+ this->ldb = DeviceAllocation::get_packed_layout( -+ operation_desc.B.layout, {int(this->k), int(this->n)}).front(); -+ -+ this->ldc = DeviceAllocation::get_packed_layout( -+ 
operation_desc.C.layout, {int(this->m), int(this->n)}).front(); -+ -+ this->lde = -+ DeviceAllocation::get_packed_layout( -+ operation_desc.E.layout, -+ {int(this->m), int(this->k / this->sparse / this->elements_per_128b)}) -+ .front(); -+ -+ return Status::kSuccess; -+} -+ -+/// Initializes a performance result -+void SparseGemmOperationProfiler::SparseGemmProblem::initialize_result( -+ PerformanceResult &result, -+ library::SparseGemmDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.arguments.resize(problem_space.rank()); -+ -+ set_argument(result, "gemm_kind", problem_space, library::to_string(operation_desc.gemm_kind)); -+ -+ set_argument(result, "A", problem_space, -+ std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout)); -+ -+ set_argument(result, "B", problem_space, -+ std::string(library::to_string(operation_desc.B.element)) + ":" + library::to_string(operation_desc.B.layout)); -+ -+ set_argument(result, "C", problem_space, -+ std::string(library::to_string(operation_desc.C.element)) + ":" + library::to_string(operation_desc.C.layout)); -+ -+ set_argument(result, "E", problem_space, -+ std::string(library::to_string(operation_desc.E.element)) + ":" + library::to_string(operation_desc.E.layout)); -+ -+ set_argument(result, "m", problem_space, m); -+ set_argument(result, "n", problem_space, n); -+ set_argument(result, "k", problem_space, k); -+ -+ set_argument(result, "split_k_slices", problem_space, split_k_slices); -+ set_argument(result, "batch_count", problem_space, batch_count); -+ -+ set_argument(result, "alpha", problem_space, -+ library::lexical_cast(alpha, operation_desc.element_epilogue)); -+ -+ set_argument(result, "beta", problem_space, -+ library::lexical_cast(beta, operation_desc.element_epilogue)); -+} -+ -+/// Extracts the problem dimensions -+Status SparseGemmOperationProfiler::initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::SparseGemmDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (operation_desc.gemm_kind != library::GemmKind::kSparse) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = problem_.parse(operation_desc, problem_space, problem); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ gemm_workspace_.configuration.problem_size.m() = int(problem_.m); -+ gemm_workspace_.configuration.problem_size.n() = int(problem_.n); -+ gemm_workspace_.configuration.problem_size.k() = int(problem_.k); -+ gemm_workspace_.configuration.lda = problem_.lda; -+ gemm_workspace_.configuration.ldb = problem_.ldb; -+ gemm_workspace_.configuration.ldc = problem_.ldc; -+ gemm_workspace_.configuration.ldd = problem_.ldc; -+ gemm_workspace_.configuration.lde = problem_.lde; -+ -+ gemm_workspace_.arguments.A = nullptr; -+ gemm_workspace_.arguments.B = nullptr; -+ gemm_workspace_.arguments.C = nullptr; -+ gemm_workspace_.arguments.D = nullptr; -+ gemm_workspace_.arguments.E = nullptr; -+ gemm_workspace_.arguments.alpha = problem_.alpha.data(); -+ gemm_workspace_.arguments.beta = problem_.beta.data(); -+ gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ initialize_result_(this->model_result_, options, operation_desc, problem_space); -+ -+ return 
operation->can_implement(&gemm_workspace_.configuration, &gemm_workspace_.arguments); -+} -+ -+/// Initializes the performance result -+void SparseGemmOperationProfiler::initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::SparseGemmDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.provider = library::Provider::kCUTLASS; -+ result.disposition = Disposition::kNotRun; -+ result.status = Status::kSuccess; -+ result.operation_name = operation_desc.name; -+ -+ problem_.initialize_result(result, operation_desc, problem_space); -+ -+ OperationProfiler::initialize_result_(result, operation_desc, problem_space); -+ -+ // Input bytes read and Output bytes written for the gemm problem -+ result.bytes = -+ int64_t(library::sizeof_bits(operation_desc.A.element) * problem_.m / 8) * -+ problem_.k / problem_.sparse + -+ int64_t(library::sizeof_bits(operation_desc.B.element) * problem_.n / 8) * -+ problem_.k + -+ int64_t(library::sizeof_bits(operation_desc.C.element) * problem_.m / 8) * -+ problem_.n + -+ int64_t(library::sizeof_bits(operation_desc.E.element) * problem_.m / 8) * -+ problem_.k / problem_.sparse / problem_.elements_per_128b; -+ -+ // Set is_beta_zero true if beta is zero -+ bool is_beta_zero = std::all_of(problem_.beta.begin(), problem_.beta.end(), [](uint8_t i) { return i==0; }); -+ -+ // Output bytes read for the gemm problem for non-zero beta values -+ if (!is_beta_zero) { -+ result.bytes += int64_t(library::sizeof_bits(operation_desc.C.element) * problem_.m / 8) * problem_.n; -+ } -+ -+ result.flops = 2 * (problem_.m * problem_.n * problem_.k + problem_.m * problem_.n); -+ result.runtime = 0; -+ -+} -+ -+/// Initializes workspace -+Status SparseGemmOperationProfiler::initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::SparseGemmDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ gemm_workspace_.A = device_context.allocate_tensor( -+ options, -+ "A", -+ operation_desc.A.element, -+ operation_desc.A.layout, -+ {int(problem_.m), int(problem_.k) / int(problem_.sparse)}, -+ {int(problem_.lda)} -+ ); -+ -+ gemm_workspace_.B = device_context.allocate_tensor( -+ options, -+ "B", -+ operation_desc.B.element, -+ operation_desc.B.layout, -+ {int(problem_.k), int(problem_.n)}, -+ {int(problem_.ldb)} -+ ); -+ -+ gemm_workspace_.C = device_context.allocate_tensor( -+ options, -+ "C", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldc)} -+ ); -+ -+ gemm_workspace_.Computed = device_context.allocate_tensor( -+ "D", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldc)} -+ ); -+ -+ gemm_workspace_.E = device_context.allocate_sparsemeta_tensor( -+ options, -+ "E", -+ operation_desc.E.element, -+ operation_desc.E.layout, -+ operation_desc.A.element, -+ {int(problem_.m), int(problem_.k) / int(problem_.sparse) / int(problem_.elements_per_128b)}, -+ {int(problem_.lde)} -+ ); -+ -+ gemm_workspace_.Reference = device_context.allocate_tensor( -+ "Reference", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldc)} -+ ); -+ -+ 
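-+ // Seed the reference output with the contents of C so that the beta * C contribution to the reference computation starts from the same data as the CUTLASS result.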
gemm_workspace_.Reference->copy_from_device(gemm_workspace_.C->data()); -+ } -+ -+ // -+ // Initialize the CUTLASS operation -+ // -+ -+ Status status = Status::kSuccess; -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ uint64_t workspace_size = operation->get_host_workspace_size(&gemm_workspace_.configuration); -+ gemm_workspace_.host_workspace.resize(workspace_size, 0); -+ -+ workspace_size = operation->get_device_workspace_size(&gemm_workspace_.configuration); -+ gemm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); -+ -+ status = operation->initialize( -+ &gemm_workspace_.configuration, -+ gemm_workspace_.host_workspace.data(), -+ gemm_workspace_.device_workspace.data()); -+ } -+ -+ // -+ // If CUTLASS is enabled, generate a result for it -+ // -+ -+ results_.push_back(model_result_); -+ results_.back().provider = library::Provider::kCUTLASS; -+ results_.back().op_kind = library::OperationKind::kSparseGemm; -+ results_.back().disposition = Disposition::kNotRun; -+ -+ for(auto &verification_provider : options.verification.providers) { -+ results_.back().verification_map[verification_provider] = Disposition::kNotRun; -+ } -+ } -+ -+ return status; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool SparseGemmOperationProfiler::verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ return true; -+ } -+ -+ if (options.execution_mode == ExecutionMode::kDryRun) { -+ return true; -+ } -+ -+ // Initialize structure containing GEMM arguments -+ gemm_workspace_.arguments.A = gemm_workspace_.A->data(); -+ gemm_workspace_.arguments.B = gemm_workspace_.B->data(); -+ gemm_workspace_.arguments.C = gemm_workspace_.C->data(); -+ gemm_workspace_.arguments.D = gemm_workspace_.Computed->data(); -+ gemm_workspace_.arguments.E = gemm_workspace_.E->data(); -+ gemm_workspace_.arguments.alpha = problem_.alpha.data(); -+ gemm_workspace_.arguments.beta = problem_.beta.data(); -+ gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // -+ // Run the CUTLASS operation -+ // -+ -+ results_.back().status = operation->run( -+ &gemm_workspace_.arguments, -+ gemm_workspace_.host_workspace.data(), -+ gemm_workspace_.device_workspace.data()); -+ -+ if (results_.back().status != Status::kSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ // CUTLASS op ran the but not yet verified against any verification provider -+ results_.back().disposition = Disposition::kNotVerified; -+ -+ // -+ // Run verification providers -+ // -+ -+ if (options.verification.enabled) { -+ -+ // Update disposition to worst case verification outcome among all -+ // verification providers which are supported -+ bool is_any_verification_run_passed = false; -+ -+ for(auto &m : results_.back().verification_map) { -+ if(m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { -+ results_.back().disposition = m.second; -+ return 
true; -+ } -+ if(!is_any_verification_run_passed && m.second == Disposition::kPassed) { -+ is_any_verification_run_passed = true; -+ } -+ } -+ -+ if(is_any_verification_run_passed) { -+ results_.back().disposition = Disposition::kPassed; -+ } -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Measures performance results -+bool SparseGemmOperationProfiler::profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ // Initialize structure containing GEMM arguments -+ gemm_workspace_.arguments.A = gemm_workspace_.A->data(); -+ gemm_workspace_.arguments.B = gemm_workspace_.B->data(); -+ gemm_workspace_.arguments.C = gemm_workspace_.C->data(); -+ gemm_workspace_.arguments.D = gemm_workspace_.Computed->data(); -+ gemm_workspace_.arguments.E = gemm_workspace_.E->data(); -+ gemm_workspace_.arguments.alpha = problem_.alpha.data(); -+ gemm_workspace_.arguments.beta = problem_.beta.data(); -+ gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ results_.back().status = profile_cutlass_( -+ results_.back().runtime, -+ options, -+ operation, -+ &gemm_workspace_.arguments, -+ gemm_workspace_.host_workspace.data(), -+ gemm_workspace_.device_workspace.data() -+ ); -+ } -+ -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/sparse_gemm_operation_profiler.h b/3rdparty/cutlass/tools/profiler/src/sparse_gemm_operation_profiler.h -new file mode 100644 -index 0000000..c1f11c9 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/sparse_gemm_operation_profiler.h -@@ -0,0 +1,214 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief -+ -+*/ -+ -+#pragma once -+ -+#include <vector> -+#include <string> -+#include <memory> -+#include <algorithm> -+#include <unordered_map> -+ -+// CUTLASS Library includes -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+#include "cutlass/library/manifest.h" -+ -+// Profiler includes -+#include "options.h" -+#include "device_context.h" -+#include "operation_profiler.h" -+#include "performance_result.h" -+#include "problem_space.h" -+#include "gemm_operation_profiler.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Abstract base class for each math function -+class SparseGemmOperationProfiler : public OperationProfiler { -+public: -+ -+ /// Problem structure obtained from problem space -+ struct SparseGemmProblem { -+ int64_t m; -+ int64_t n; -+ int64_t k; -+ int64_t lda; -+ int64_t ldb; -+ int64_t ldc; -+ int64_t lde; -+ std::vector<uint8_t> alpha; -+ std::vector<uint8_t> beta; -+ int64_t split_k_slices; -+ int64_t batch_count; -+ static int const sparse = 2; -+ // every 128b ElementA uses one elementE -+ int elements_per_128b; -+ -+ // -+ // Methods -+ // -+ -+ SparseGemmProblem(): -+ m(16), n(16), k(16), lda(0), ldb(0), ldc(0), lde(0), split_k_slices(1), batch_count(1) { } -+ -+ /// Parses the problem -+ Status parse( -+ library::SparseGemmDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes a performance result -+ void initialize_result( -+ PerformanceResult &result, -+ library::SparseGemmDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ }; -+ -+ /// Workspace used -+ struct SparseGemmWorkspace { -+ -+ DeviceAllocation *A; -+ DeviceAllocation *B; -+ DeviceAllocation *C; -+ DeviceAllocation *E; -+ DeviceAllocation *Computed; -+ DeviceAllocation *Reference; -+ -+ library::SparseGemmConfiguration configuration; -+ library::SparseGemmArguments arguments; -+ -+ /// Buffer used for the operation's host workspace -+ std::vector<uint8_t> host_workspace; -+ -+ /// Buffer used for the operations' device workspace -+ DeviceAllocation device_workspace; -+ -+ // -+ // Methods -+ // -+ -+ SparseGemmWorkspace(): -+ A(nullptr), B(nullptr), C(nullptr), E(nullptr), Computed(nullptr), Reference(nullptr) { } -+ }; -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ // GEMM problem -+ SparseGemmProblem problem_; -+ -+ /// Device memory allocations -+ SparseGemmWorkspace gemm_workspace_; -+ -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ SparseGemmOperationProfiler(Options const &options); -+ -+ /// Destructor -+ virtual ~SparseGemmOperationProfiler(); -+ -+ /// Prints usage statement for the math function -+ virtual void print_usage(std::ostream &out) const; -+
/// Prints examples -+ virtual void print_examples(std::ostream &out) const; -+ -+ /// Extracts the problem dimensions -+ virtual Status initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes workspace -+ virtual Status initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against references -+ virtual bool verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Measures performance results -+ virtual bool profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+protected: -+ -+ /// Initializes the performance result -+ void initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::SparseGemmDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/profiler/src/symm_operation_profiler.cu b/3rdparty/cutlass/tools/profiler/src/symm_operation_profiler.cu -new file mode 100644 -index 0000000..97cb34a ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/symm_operation_profiler.cu -@@ -0,0 +1,764 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Execution environment -+ -+ -+*/ -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/core_io.h" -+ -+#include "cublas_helpers.h" -+#include "symm_operation_profiler.h" -+#include "gpu_timer.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Ctor -+SymmOperationProfiler::SymmOperationProfiler(Options const &options): -+ OperationProfiler( -+ options, -+ library::OperationKind::kSymm, -+ { -+ {ArgumentTypeID::kEnumerated, {"symm_kind"}, "Variant of Symm (universal)"}, -+ {ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the Symm problem space"}, -+ {ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the Symm problem space"}, -+ {ArgumentTypeID::kTensor, {"A"}, "Tensor storing the A operand"}, -+ {ArgumentTypeID::kTensor, {"B"}, "Tensor storing the B operand"}, -+ {ArgumentTypeID::kTensor, {"C"}, "Tensor storing the C operand"}, -+ {ArgumentTypeID::kEnumerated, {"side_mode"}, "Side Mode for Symm kernel (left or right)"}, -+ {ArgumentTypeID::kEnumerated, {"fill_mode"}, "Fill Mode for Symm kernel (lower or upper)"}, -+ {ArgumentTypeID::kEnumerated, {"blas_mode"}, "Blas Mode for Symm kernel (symmetric or hermitian)"}, -+ {ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"}, -+ {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"}, -+ {ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"}, -+ {ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of Symm computed in one batch"}, -+ }, -+ { library::Provider::kCUBLAS } -+ ) { -+ description_ = " Symmetric Matrix-Matrix Multiplication. 
D = alpha * A * B OR alpha * B * A + beta * C (where A is symmetric/hermitian)"; -+} -+ -+/// Destructor -+SymmOperationProfiler::~SymmOperationProfiler() { -+ -+} -+ -+/// Prints usage statement for the math function -+void SymmOperationProfiler::print_usage(std::ostream &out) const { -+ out << "Symm" << "\n\n"; -+ -+ OperationProfiler::print_usage(out); -+} -+ -+/// Prints examples -+void SymmOperationProfiler::print_examples(std::ostream &out) const { -+ -+ out << "\nExamples:\n\n" -+ << "Profile a particular problem size SYMM kernel:\n" -+ << " $ cutlass_profiler --operation=Symm --blas_mode=symmetric --m=1024 --n=128\n\n" -+ -+ << "Profile a particular problem size HEMM kernel:\n" -+ << " $ cutlass_profiler --operation=Symm --blas_mode=hermitian --m=1024 --n=128\n\n" -+ -+ << "Schmoo over problem size and beta:\n" -+ << " $ cutlass_profiler --operation=Symm --m=1024:4096:256 --n=128:8192:128 --beta=0,1,2.5\n\n" -+ -+ << "Schmoo over accumulator types:\n" -+ << " $ cutlass_profiler --operation=Symm --accumulator-type=f16,f32\n\n" -+ -+ << "Schmoo over side modees:\n" -+ << " $ cutlass_profiler --operation=Symm --side_mode=left/right\n\n" -+ -+ << "Schmoo over fill modees:\n" -+ << " $ cutlass_profiler --operation=Symm --fill_mode=lower/upper\n\n" -+ -+ << "Run when A is f16 with column-major or A is any datatype with row-major (For column major, use column, col, or n. For row major use, row or t):\n" -+ << " $ cutlass_profiler --operation=Symm --A=f16:column or --A=*:row\n\n" -+ -+ << "Using various input value distribution:\n" -+ << " $ cutlass_profiler --operation=Symm --dist=uniform,min:0,max:3\n" -+ << " $ cutlass_profiler --operation=Symm --dist=gaussian,mean:0,stddev:3\n" -+ << " $ cutlass_profiler --operation=Symm --dist=sequential,start:0,delta:1\n\n" -+ -+ << "Run a kernel with cta tile size of 256x128x32 and save workspace if results are incorrect (note that --cta-tile::k=32 is default cta-tile size):\n" -+ << " $ cutlass_profiler --operation=Symm --cta_m=256 --cta_n=128 --cta_k=32 --save-workspace=incorrect\n\n" -+ -+ << "Test your changes to symm kernels with a quick functional test and save results in functional-test.csv:\n" -+ << " $ cutlass_profiler --operation=Symm \\ \n" -+ << " --m=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" -+ << " --n=8,16,32,64,128,256,288,384,504,512,520 \\ \n" -+ << " --beta=0,1,2 --profiling-iterations=1 \\ \n" -+ << " --providers=cutlass --output=functional-test.csv\n\n"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if 0 -+// used this for debugging -+static std::string byte_string(std::vector const &bytes) { -+ std::stringstream ss; -+ -+ ss << "0x"; -+ -+ for (size_t idx = bytes.size(); idx > 0; --idx) { -+ ss << std::hex << std::setw(2) << std::setfill('0') << uint32_t(bytes.at(idx - 1)); -+ } -+ -+ return ss.str(); -+} -+#endif -+ -+Status SymmOperationProfiler::SymmProblem::parse( -+ library::SymmDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!arg_as_int(this->m, "m", problem_space, problem)) { -+ // default value -+ this->m = 1024; -+ } -+ -+ if (!arg_as_int(this->n, "n", problem_space, problem)) { -+ // default value -+ this->n = 1024; -+ } -+ -+ if (!arg_as_int(this->split_k_slices, "split_k_slices", problem_space, problem)) { -+ // default value -+ this->split_k_slices = 1; -+ } -+ -+ if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) { -+ // default 
value -+ this->batch_count = 1; -+ } -+ -+ if (this->split_k_slices > 1 && this->batch_count > 1) { -+ // At least one of these must be one -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.A, "A", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.B, "B", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.C, "C", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!arg_as_scalar( -+ this->alpha, -+ operation_desc.element_epilogue, -+ "alpha", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->alpha, operation_desc.element_epilogue, 1)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ if (!arg_as_scalar( -+ this->beta, -+ operation_desc.element_epilogue, -+ "beta", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->beta, operation_desc.element_epilogue, 0)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ if (operation_desc.side_mode == SideMode::kLeft) { -+ this->lda = DeviceAllocation::get_packed_layout( -+ operation_desc.A.layout, {int(this->m), int(this->m)}).front(); -+ } -+ else if (operation_desc.side_mode == SideMode::kRight) { -+ this->lda = DeviceAllocation::get_packed_layout( -+ operation_desc.A.layout, {int(this->n), int(this->n)}).front(); -+ } -+ -+ this->ldb = DeviceAllocation::get_packed_layout( -+ operation_desc.B.layout, {int(this->m), int(this->n)}).front(); -+ -+ this->ldc = DeviceAllocation::get_packed_layout( -+ operation_desc.C.layout, {int(this->m), int(this->n)}).front(); -+ -+ return Status::kSuccess; -+} -+ -+/// Total number of bytes loaded -+int64_t SymmOperationProfiler::SymmProblem::bytes(library::SymmDescription const &operation_desc) const { -+ int64_t bytes; -+ // Input bytes read and Output bytes written for the gemm problem -+ // Half matrix including the diagonal will have (X*(X+1))/2 elements -+ if (operation_desc.side_mode == SideMode::kLeft) { -+ bytes = -+ int64_t(library::sizeof_bits(operation_desc.A.element) * m / 8) * (m + 1) / 2 + -+ int64_t(library::sizeof_bits(operation_desc.B.element) * m / 8) * n + -+ int64_t(library::sizeof_bits(operation_desc.C.element) * m / 8) * n; -+ } else if (operation_desc.side_mode == SideMode::kRight) { -+ bytes = -+ int64_t(library::sizeof_bits(operation_desc.A.element) * n / 8) * (n + 1) / 2 + -+ int64_t(library::sizeof_bits(operation_desc.B.element) * m / 8) * n + -+ int64_t(library::sizeof_bits(operation_desc.C.element) * m / 8) * n; -+ } -+ // Set is_beta_zero true if beta is zero -+ bool is_beta_zero = std::all_of(beta.begin(), beta.end(), [](uint8_t i) { return i==0; }); -+ -+ // Output bytes read for the gemm problem for non-zero beta values -+ if (!is_beta_zero) { -+ bytes += int64_t(library::sizeof_bits(operation_desc.C.element) * m / 8) * n; -+ } -+ -+ bytes *= batch_count; -+ -+ return bytes; -+} -+ -+/// Total number of flops computed -+int64_t SymmOperationProfiler::SymmProblem::flops(library::SymmDescription const &operation_desc) const { -+ -+ // FLOPs for first TRMM kernel (with diagonal) = 2 * [ ( M * (M+1)/2 * N ) ] // Beta is zero -+ // FLOPs for second TRMM kernel (with diagonal) = 2 * [ ( M * (M-1)/2 * N ) ] // Beta is zero -+ // FLOPs = m*(m+1)*n [mma1] + m*(m-1)*n [mma2] + 2*m*n [epilogue] -+ // FLOPs = 2*m*n(m+1) for left side mode -+ // FLOPs can also be calculated to be same as GEMM with correct value for 'k' as below. 
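-+ // Illustrative check (example values, not from the original source): with side_mode = left, m = 1024, n = 128, k below equals m, so flops_ = 2 * (1024 * 128 * 1024 + 1024 * 128) = 268,697,600 before the complex-valued adjustment that follows.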
-+ int64_t k = (operation_desc.side_mode == SideMode::kLeft) ? int64_t(m) : int64_t(n); -+ int64_t flops_ = (int64_t(m) * n * k + m * n) * 2; -+ -+ // complex-valued support -+ switch (operation_desc.tile_description.math_instruction.math_operation) { -+ case library::MathOperationID::kMultiplyAddComplex: -+ flops_ *= 4; -+ break; -+ -+ case library::MathOperationID::kMultiplyAddComplexFastF32: -+ flops_ *= 4; -+ break; -+ -+ case library::MathOperationID::kMultiplyAddGaussianComplex: -+ flops_ *= 3; -+ break; -+ -+ default: break; -+ } -+ -+ return flops_; -+} -+ -+/// Initializes a performance result -+void SymmOperationProfiler::SymmProblem::initialize_result( -+ PerformanceResult &result, -+ library::SymmDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.arguments.resize(problem_space.rank()); -+ -+ set_argument(result, "symm_kind", problem_space, library::to_string(operation_desc.symm_kind)); -+ -+ set_argument(result, "A", problem_space, -+ std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout)); -+ -+ set_argument(result, "B", problem_space, -+ std::string(library::to_string(operation_desc.B.element)) + ":" + library::to_string(operation_desc.B.layout)); -+ -+ set_argument(result, "C", problem_space, -+ std::string(library::to_string(operation_desc.C.element)) + ":" + library::to_string(operation_desc.C.layout)); -+ -+ set_argument(result, "side_mode", problem_space, library::to_string(operation_desc.side_mode)); -+ -+ set_argument(result, "fill_mode", problem_space, library::to_string(operation_desc.fill_mode)); -+ -+ set_argument(result, "blas_mode", problem_space, library::to_string(operation_desc.blas_mode)); -+ -+ set_argument(result, "m", problem_space, m); -+ set_argument(result, "n", problem_space, n); -+ -+ set_argument(result, "split_k_slices", problem_space, split_k_slices); -+ set_argument(result, "batch_count", problem_space, batch_count); -+ -+ set_argument(result, "alpha", problem_space, -+ library::lexical_cast(alpha, operation_desc.element_epilogue)); -+ -+ set_argument(result, "beta", problem_space, -+ library::lexical_cast(beta, operation_desc.element_epilogue)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Extracts the problem dimensions -+Status SymmOperationProfiler::initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::SymmDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (operation_desc.symm_kind != library::SymmKind::kUniversal) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = problem_.parse(operation_desc, problem_space, problem); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ symm_workspace_.configuration.problem_size.m() = int(problem_.m); -+ symm_workspace_.configuration.problem_size.n() = int(problem_.n); -+ symm_workspace_.configuration.problem_size.k() = (operation_desc.side_mode == SideMode::kLeft) -+ ? 
int(problem_.m) : int(problem_.n); -+ symm_workspace_.configuration.lda = problem_.lda; -+ symm_workspace_.configuration.ldb = problem_.ldb; -+ symm_workspace_.configuration.ldc = problem_.ldc; -+ symm_workspace_.configuration.ldd = problem_.ldc; -+ //symm_workspace_.configuration.split_k_slices = int(problem_.split_k_slices); -+ symm_workspace_.configuration.batch_count = int(problem_.split_k_slices); -+ -+ symm_workspace_.arguments.A = nullptr; -+ symm_workspace_.arguments.B = nullptr; -+ symm_workspace_.arguments.C = nullptr; -+ symm_workspace_.arguments.D = nullptr; -+ symm_workspace_.arguments.alpha = problem_.alpha.data(); -+ symm_workspace_.arguments.beta = problem_.beta.data(); -+ symm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ initialize_result_(this->model_result_, options, operation_desc, problem_space); -+ -+ return operation->can_implement(&symm_workspace_.configuration, &symm_workspace_.arguments); -+} -+ -+/// Initializes the performance result -+void SymmOperationProfiler::initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::SymmDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.provider = library::Provider::kCUTLASS; -+ result.disposition = Disposition::kNotRun; -+ result.status = Status::kSuccess; -+ result.operation_name = operation_desc.name; -+ -+ problem_.initialize_result(result, operation_desc, problem_space); -+ -+ OperationProfiler::initialize_result_(result, operation_desc, problem_space); -+ -+ -+ result.bytes = problem_.bytes(operation_desc); -+ result.flops = problem_.flops(operation_desc); -+ result.runtime = 0; -+ -+ -+} -+ -+/// Initializes workspace -+Status SymmOperationProfiler::initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::SymmDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ if (operation_desc.side_mode == SideMode::kLeft) { -+ symm_workspace_.A = device_context.allocate_tensor( -+ options, -+ "A", -+ operation_desc.A.element, -+ operation_desc.A.layout, -+ {int(problem_.m), int(problem_.m)}, -+ {int(problem_.lda)}, -+ 1 // batch_count = 1, default -+ ); -+ } else if (operation_desc.side_mode == SideMode::kRight) { -+ symm_workspace_.A = device_context.allocate_tensor( -+ options, -+ "A", -+ operation_desc.A.element, -+ operation_desc.A.layout, -+ {int(problem_.n), int(problem_.n)}, -+ {int(problem_.lda)}, -+ 1 // batch_count = 1, default -+ ); -+ } -+ -+ symm_workspace_.B = device_context.allocate_tensor( -+ options, -+ "B", -+ operation_desc.B.element, -+ operation_desc.B.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldb)} -+ ); -+ -+ symm_workspace_.C = device_context.allocate_tensor( -+ options, -+ "C", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldc)}, -+ 1 // batch_count = 1, default -+ ); -+ -+ symm_workspace_.Computed = device_context.allocate_tensor( -+ "D", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldc)} -+ ); -+ -+ symm_workspace_.Reference = device_context.allocate_tensor( -+ "Reference", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.m), int(problem_.n)}, -+ 
{int(problem_.ldc)} -+ ); -+ -+ symm_workspace_.Computed->copy_from_device(symm_workspace_.C->data()); -+ symm_workspace_.Reference->copy_from_device(symm_workspace_.C->data()); -+ } -+ -+ -+ // -+ // Initialize the CUTLASS operation -+ // -+ Status status = Status::kSuccess; -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ uint64_t workspace_size = operation->get_host_workspace_size(&symm_workspace_.configuration); -+ symm_workspace_.host_workspace.resize(workspace_size, 0); -+ -+ workspace_size = operation->get_device_workspace_size(&symm_workspace_.configuration); -+ symm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); -+ -+ status = operation->initialize( -+ &symm_workspace_.configuration, -+ symm_workspace_.host_workspace.data(), -+ symm_workspace_.device_workspace.data()); -+ } -+ -+ // -+ // If CUTLASS is enabled, generate a result for it -+ // -+ results_.push_back(model_result_); -+ results_.back().provider = library::Provider::kCUTLASS; -+ results_.back().op_kind = library::OperationKind::kSymm; -+ results_.back().disposition = Disposition::kNotRun; -+ -+ for(auto provider : verification_providers_) { -+ results_.back().verification_map[provider] = Disposition::kNotRun; -+ } -+ } -+ -+ return status; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool SymmOperationProfiler::verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ return true; -+ } -+ -+ if (options.execution_mode == ExecutionMode::kDryRun) { -+ return true; -+ } -+ -+ // Initialize structure containing Symm arguments -+ symm_workspace_.arguments.A = symm_workspace_.A->data(); -+ symm_workspace_.arguments.B = symm_workspace_.B->data(); -+ symm_workspace_.arguments.C = symm_workspace_.C->data(); -+ symm_workspace_.arguments.D = symm_workspace_.Computed->data(); -+ symm_workspace_.arguments.alpha = problem_.alpha.data(); -+ symm_workspace_.arguments.beta = problem_.beta.data(); -+ symm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // -+ // Run the CUTLASS operation -+ // -+ -+ results_.back().status = operation->run( -+ &symm_workspace_.arguments, -+ symm_workspace_.host_workspace.data(), -+ symm_workspace_.device_workspace.data()); -+ -+ if (results_.back().status != Status::kSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ // CUTLASS op ran the but not yet verified against any verification provider -+ results_.back().disposition = Disposition::kNotVerified; -+ -+ // -+ // Run verification providers -+ // -+ -+ if (options.verification.enabled) { -+ -+#if CUTLASS_ENABLE_CUBLAS -+ if (options.verification.provider_enabled(library::Provider::kCUBLAS)) { -+ -+ // Guard against unsupported cases -+ auto const & symm_desc = static_cast(operation->description()); -+ -+ if (cublas_satisfies(symm_desc) == Status::kSuccess) { -+ -+ // call cublas verification if supported -+ verify_with_cublas_( -+ options, -+ 
report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ else { -+ // set verification map for cublas to not supported -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotSupported; -+ } -+ } -+#endif // #if CUTLASS_ENABLE_CUBLAS -+ -+ // Update disposition to worst case verification outcome among all -+ // verification providers which are supported -+ bool is_any_verification_run_passed = false; -+ for(auto &m : results_.back().verification_map) { -+ if(m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { -+ results_.back().disposition = m.second; -+ return true; -+ } -+ if(!is_any_verification_run_passed && m.second == Disposition::kPassed) { -+ is_any_verification_run_passed = true; -+ } -+ } -+ -+ if(is_any_verification_run_passed) { -+ results_.back().disposition = Disposition::kPassed; -+ } -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool SymmOperationProfiler::verify_with_cublas_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ -+#if CUTLASS_ENABLE_CUBLAS -+ -+ library::SymmDescription const &symm_desc = -+ static_cast(operation->description()); -+ -+ // -+ // Construct cuBLAS operators -+ // -+ -+ CublasCreate handle; -+ cublasStatus_t status = handle.get_cublas_create_status(); -+ -+ if (status != CUBLAS_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ return true; -+ } -+ -+ // -+ // Initialize state -+ // -+ -+ try { -+ -+ // -+ // Construct dispatcher to cublasSymm() -+ // -+ -+ // Initialize structure containing Symm arguments -+ symm_workspace_.arguments.A = symm_workspace_.A->data(); -+ symm_workspace_.arguments.B = symm_workspace_.B->data(); -+ symm_workspace_.arguments.C = symm_workspace_.Reference->data(); -+ symm_workspace_.arguments.D = symm_workspace_.Reference->data(); -+ symm_workspace_.arguments.alpha = problem_.alpha.data(); -+ symm_workspace_.arguments.beta = problem_.beta.data(); -+ symm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ detail::cublasSymmDispatcher symm_op( -+ symm_desc, -+ symm_workspace_.configuration, -+ symm_workspace_.arguments -+ ); -+ -+ if (symm_op.status != Status::kSuccess) { -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotRun; -+ return true; -+ } -+ -+ results_.back().status = Status::kSuccess; -+ -+ status = symm_op(handle); -+ -+ // Handle errors -+ if (status != CUBLAS_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ return true; -+ } -+ -+ // -+ // Verify results -+ // -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = compare_tensors( -+ options, -+ *symm_workspace_.Computed, -+ *symm_workspace_.Reference -+ ); -+ -+ // Save workspace if incorrect -+ if (options.verification.save_workspace == SaveWorkspace::kIncorrect && -+ results_.back().verification_map[library::Provider::kCUBLAS] == Disposition::kIncorrect) { -+ -+ save_workspace( -+ device_context, -+ options, -+ symm_desc, -+ library::Provider::kCUTLASS, -+ library::Provider::kCUBLAS); -+ } -+ } -+ catch (...) 
{ -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ } -+ -+#endif -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Measures performance results -+bool SymmOperationProfiler::profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ // Initialize structure containing Symm arguments -+ symm_workspace_.arguments.A = symm_workspace_.A->data(); -+ symm_workspace_.arguments.B = symm_workspace_.B->data(); -+ symm_workspace_.arguments.C = symm_workspace_.C->data(); -+ symm_workspace_.arguments.D = symm_workspace_.Computed->data(); -+ symm_workspace_.arguments.alpha = problem_.alpha.data(); -+ symm_workspace_.arguments.beta = problem_.beta.data(); -+ symm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ results_.back().status = profile_cutlass_( -+ results_.back().runtime, -+ options, -+ operation, -+ &symm_workspace_.arguments, -+ symm_workspace_.host_workspace.data(), -+ symm_workspace_.device_workspace.data() -+ ); -+ } -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/symm_operation_profiler.h b/3rdparty/cutlass/tools/profiler/src/symm_operation_profiler.h -new file mode 100644 -index 0000000..a0162b4 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/symm_operation_profiler.h -@@ -0,0 +1,230 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines a math function -+ -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include -+ -+// CUTLASS Library includes -+#include "cutlass/blas3.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+#include "cutlass/library/manifest.h" -+ -+// Profiler includes -+#include "options.h" -+#include "device_context.h" -+#include "operation_profiler.h" -+#include "performance_result.h" -+#include "problem_space.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Abstract base class for each math function -+class SymmOperationProfiler : public OperationProfiler { -+public: -+ -+ /// Problem structure obtained from problem space -+ struct SymmProblem { -+ int64_t m; -+ int64_t n; -+ int64_t lda; -+ int64_t ldb; -+ int64_t ldc; -+ SideMode side_mode; -+ FillMode fill_mode; -+ BlasMode blas_mode; -+ std::vector alpha; -+ std::vector beta; -+ int64_t split_k_slices; -+ int64_t batch_count; -+ -+ // -+ // Methods -+ // -+ -+ SymmProblem(): -+ m(16), n(16), lda(0), ldb(0), ldc(0), -+ side_mode(SideMode::kInvalid), fill_mode(FillMode::kInvalid), blas_mode(BlasMode::kInvalid), -+ split_k_slices(1), batch_count(1) { } -+ -+ /// Parses the problem -+ Status parse( -+ library::SymmDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Total number of bytes loaded -+ int64_t bytes(library::SymmDescription const &operation_desc) const; -+ -+ /// Total number of flops computed -+ int64_t flops(library::SymmDescription const &operation_desc) const; -+ -+ /// Initializes a performance result -+ void initialize_result( -+ PerformanceResult &result, -+ library::SymmDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ }; -+ -+ /// Workspace used -+ struct SymmWorkspace { -+ -+ DeviceAllocation *A; -+ DeviceAllocation *B; -+ DeviceAllocation *C; -+ DeviceAllocation *Computed; -+ DeviceAllocation *Reference; -+ -+ library::SymmConfiguration configuration; -+ library::SymmArguments arguments; -+ -+ /// Buffer used for the operation's host workspace -+ std::vector host_workspace; -+ -+ /// Buffer used for the operations' device workspace -+ DeviceAllocation device_workspace; -+ -+ // -+ // Methods -+ // -+ -+ SymmWorkspace(): -+ A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } -+ }; -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ /// GEMM problem obtained from problem space -+ SymmProblem problem_; -+ -+ /// Device memory allocations -+ SymmWorkspace symm_workspace_; -+ -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ SymmOperationProfiler(Options const 
&options); -+ -+ /// Destructor -+ virtual ~SymmOperationProfiler(); -+ -+ /// Prints usage statement for the math function -+ virtual void print_usage(std::ostream &out) const; -+ -+ /// Prints examples -+ virtual void print_examples(std::ostream &out) const; -+ -+ /// Extracts the problem dimensions -+ virtual Status initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes workspace -+ virtual Status initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against references -+ virtual bool verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Measures performance results -+ virtual bool profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+protected: -+ -+ /// Initializes the performance result -+ void initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::SymmDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ -+ /// Verifies CUTLASS against references -+ bool verify_with_cublas_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/trmm_operation_profiler.cu b/3rdparty/cutlass/tools/profiler/src/trmm_operation_profiler.cu -new file mode 100644 -index 0000000..19014d0 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/trmm_operation_profiler.cu -@@ -0,0 +1,704 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Execution environment -+ -+ -+*/ -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/core_io.h" -+ -+#include "cublas_helpers.h" -+#include "trmm_operation_profiler.h" -+#include "gpu_timer.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Ctor -+TrmmOperationProfiler::TrmmOperationProfiler(Options const &options): -+ OperationProfiler( -+ options, -+ library::OperationKind::kTrmm, -+ { -+ {ArgumentTypeID::kEnumerated, {"trmm_kind"}, "Variant of TRMM (universal)"}, -+ {ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the TRMM problem space"}, -+ {ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the TRMM problem space"}, -+ {ArgumentTypeID::kTensor, {"A"}, "Tensor storing the A operand"}, -+ {ArgumentTypeID::kEnumerated, {"side_mode"}, "Side Mode for TRMM (left, right)"}, -+ {ArgumentTypeID::kEnumerated, {"fill_mode"}, "Fill Mode for TRMM (lower, upper)"}, -+ {ArgumentTypeID::kEnumerated, {"diag_type"}, "Diag Type for TRMM (nonunit, unit)"}, -+ {ArgumentTypeID::kTensor, {"B"}, "Tensor storing the B operand"}, -+ {ArgumentTypeID::kTensor, {"D"}, "Tensor storing the D operand"}, -+ {ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"}, -+ {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"}, -+ {ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"}, -+ {ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of TRMMs computed in one batch"}, -+ }, -+ { library::Provider::kCUBLAS} -+ ) { -+ description_ = " Triangular Matrix-Multiplication. 
D = alpha * A * B or alpha * B * A"; -+} -+ -+/// Destructor -+TrmmOperationProfiler::~TrmmOperationProfiler() { -+ -+} -+ -+/// Prints usage statement for the math function -+void TrmmOperationProfiler::print_usage(std::ostream &out) const { -+ out << "TRMM" << "\n\n"; -+ -+ OperationProfiler::print_usage(out); -+} -+ -+/// Prints examples -+void TrmmOperationProfiler::print_examples(std::ostream &out) const { -+ -+ out << "\nExamples:\n\n" -+ << "Profile a particular problem size:\n" -+ << " $ cutlass_profiler --operation=Trmm --n=1024 --m=128\n\n" -+ -+ << "Schmoo over problem size and beta:\n" -+ << " $ cutlass_profiler --operation=Trmm --n=1024:4096:256 --m=128:8192:128 --beta=0,1,2.5\n\n" -+ -+ << "Schmoo over accumulator types:\n" -+ << " $ cutlass_profiler --operation=Trmm --accumulator-type=f16,f32\n\n" -+ -+ << "Run when A is f16 with column-major or A is any datatype with row-major (For column major, use column, col, or n. For row major use, row or t):\n" -+ << " $ cutlass_profiler --operation=Trmm --A=f16:column or --A=*:row\n\n" -+ -+ << "Using various input value distribution:\n" -+ << " $ cutlass_profiler --operation=Trmm --dist=uniform,min:0,max:3\n" -+ << " $ cutlass_profiler --operation=Trmm --dist=gaussian,mean:0,stddev:3\n" -+ << " $ cutlass_profiler --operation=Trmm --dist=sequential,start:0,delta:1\n\n" -+ -+ << "Run a kernel with cta tile size of 256x128x32 and save workspace if results are incorrect (note that --cta-tile::k=32 is default cta-tile size):\n" -+ << " $ cutlass_profiler --operation=Trmm --cta_m=256 --cta_n=128 --cta_k=32 --save-workspace=incorrect\n\n" -+ -+ << "Test your changes to trmm kernels with a quick functional test and save results in functional-test.csv:\n" -+ << " $ cutlass_profiler --operation=Trmm \\ \n" -+ << " --n=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" -+ << " --k=8,16,32,64,128,256,288,384,504,512,520 \\ \n" -+ << " --beta=0,1,2 --profiling-iterations=1 \\ \n" -+ << " --providers=cutlass --output=functional-test.csv\n\n"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if 0 -+// used this for debugging -+static std::string byte_string(std::vector const &bytes) { -+ std::stringstream ss; -+ -+ ss << "0x"; -+ -+ for (size_t idx = bytes.size(); idx > 0; --idx) { -+ ss << std::hex << std::setw(2) << std::setfill('0') << uint32_t(bytes.at(idx - 1)); -+ } -+ -+ return ss.str(); -+} -+#endif -+ -+Status TrmmOperationProfiler::TrmmProblem::parse( -+ library::TrmmDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!arg_as_int(this->m, "m", problem_space, problem)) { -+ // default value -+ this->m = 1024; -+ } -+ -+ if (!arg_as_int(this->n, "n", problem_space, problem)) { -+ // default value -+ this->n = 1024; -+ } -+ -+ if (!arg_as_int(this->split_k_slices, "split_k_slices", problem_space, problem)) { -+ // default value -+ this->split_k_slices = 1; -+ } -+ -+ if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) { -+ // default value -+ this->batch_count = 1; -+ } -+ -+ if (this->split_k_slices > 1 && this->batch_count > 1) { -+ // At least one of these must be one -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.A, "A", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.B, "B", problem_space, problem)) { -+ return 
Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.D, "D", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!arg_as_scalar( -+ this->alpha, -+ operation_desc.element_epilogue, -+ "alpha", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->alpha, operation_desc.element_epilogue, 1)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ if (!arg_as_scalar( -+ this->beta, -+ operation_desc.element_epilogue, -+ "beta", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->beta, operation_desc.element_epilogue, 0)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ if (operation_desc.side_mode == SideMode::kLeft) { -+ this->lda = DeviceAllocation::get_packed_layout( -+ operation_desc.A.layout, {int(this->m), int(this->m)}).front(); -+ } -+ else if (operation_desc.side_mode == SideMode::kRight) { -+ this->lda = DeviceAllocation::get_packed_layout( -+ operation_desc.A.layout, {int(this->n), int(this->n)}).front(); -+ } -+ -+ this->ldb = DeviceAllocation::get_packed_layout( -+ operation_desc.B.layout, {int(this->m), int(this->n)}).front(); -+ -+ this->ldd = DeviceAllocation::get_packed_layout( -+ operation_desc.D.layout, {int(this->m), int(this->n)}).front(); -+ -+ return Status::kSuccess; -+} -+ -+/// Initializes a performance result -+void TrmmOperationProfiler::TrmmProblem::initialize_result( -+ PerformanceResult &result, -+ library::TrmmDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.arguments.resize(problem_space.rank()); -+ -+ set_argument(result, "trmm_kind", problem_space, library::to_string(operation_desc.trmm_kind)); -+ -+ set_argument(result, "A", problem_space, -+ std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout)); -+ -+ set_argument(result, "side_mode", problem_space, library::to_string(operation_desc.side_mode)); -+ -+ set_argument(result, "fill_mode", problem_space, library::to_string(operation_desc.fill_mode)); -+ -+ set_argument(result, "diag_type", problem_space, library::to_string(operation_desc.diag_type)); -+ -+ set_argument(result, "B", problem_space, -+ std::string(library::to_string(operation_desc.B.element)) + ":" + library::to_string(operation_desc.B.layout)); -+ -+ set_argument(result, "D", problem_space, -+ std::string(library::to_string(operation_desc.D.element)) + ":" + library::to_string(operation_desc.D.layout)); -+ -+ set_argument(result, "m", problem_space, m); -+ set_argument(result, "n", problem_space, n); -+ -+ set_argument(result, "split_k_slices", problem_space, split_k_slices); -+ set_argument(result, "batch_count", problem_space, batch_count); -+ -+ set_argument(result, "alpha", problem_space, -+ library::lexical_cast(alpha, operation_desc.element_epilogue)); -+ -+ set_argument(result, "beta", problem_space, -+ library::lexical_cast(beta, operation_desc.element_epilogue)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Extracts the problem dimensions -+Status TrmmOperationProfiler::initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::TrmmDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (operation_desc.trmm_kind != library::TrmmKind::kUniversal) { -+ return 
Status::kErrorInvalidProblem; -+ } -+ -+ Status status = problem_.parse(operation_desc, problem_space, problem); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ trmm_workspace_.configuration.problem_size.m() = int(problem_.m); -+ trmm_workspace_.configuration.problem_size.n() = int(problem_.n); -+ trmm_workspace_.configuration.problem_size.k() = (operation_desc.side_mode == SideMode::kLeft) -+ ? int(problem_.m) : int(problem_.n); -+ trmm_workspace_.configuration.lda = problem_.lda; -+ trmm_workspace_.configuration.ldb = problem_.ldb; -+ trmm_workspace_.configuration.ldd = problem_.ldd; -+ //trmm_workspace_.configuration.split_k_slices = int(problem_.split_k_slices); -+ trmm_workspace_.configuration.batch_count = int(problem_.split_k_slices); -+ -+ trmm_workspace_.arguments.A = nullptr; -+ trmm_workspace_.arguments.B = nullptr; -+ trmm_workspace_.arguments.D = nullptr; -+ trmm_workspace_.arguments.alpha = problem_.alpha.data(); -+ trmm_workspace_.arguments.beta = problem_.beta.data(); -+ trmm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ initialize_result_(this->model_result_, options, operation_desc, problem_space); -+ -+ return operation->can_implement(&trmm_workspace_.configuration, &trmm_workspace_.arguments); -+} -+ -+/// Initializes the performance result -+void TrmmOperationProfiler::initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::TrmmDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.provider = library::Provider::kCUTLASS; -+ result.disposition = Disposition::kNotRun; -+ result.status = Status::kSuccess; -+ result.operation_name = operation_desc.name; -+ -+ problem_.initialize_result(result, operation_desc, problem_space); -+ -+ OperationProfiler::initialize_result_(result, operation_desc, problem_space); -+ -+ if (operation_desc.side_mode == SideMode::kLeft) { -+ // Input bytes read and Output bytes written for the trmm problem -+ result.bytes = -+ // Half matrix including the diagonal will have (M*(M+1))/2 elements -+ int64_t(library::sizeof_bits(operation_desc.A.element) * problem_.m / 8) * (problem_.m + 1) / 2 + -+ int64_t(library::sizeof_bits(operation_desc.B.element) * problem_.m / 8) * problem_.n + -+ int64_t(library::sizeof_bits(operation_desc.D.element) * problem_.m / 8) * problem_.n; -+ } else if (operation_desc.side_mode == SideMode::kRight) { -+ // Input bytes read and Output bytes written for the trmm problem -+ result.bytes = -+ // Half matrix including the diagonal will have (N*(N+1))/2 elements -+ int64_t(library::sizeof_bits(operation_desc.A.element) * problem_.n / 8) * (problem_.n + 1) / 2 + -+ int64_t(library::sizeof_bits(operation_desc.B.element) * problem_.m / 8) * problem_.n + -+ int64_t(library::sizeof_bits(operation_desc.D.element) * problem_.m / 8) * problem_.n; -+ } -+ -+ // FLOPs = 2 * [ ( M * (M+1)/2 * N ) ] // Beta is zero -+ result.flops = problem_.m * (problem_.m + 1) * problem_.n; -+ -+ result.runtime = 0; -+ -+ // complex-valued support -+ switch (operation_desc.tile_description.math_instruction.math_operation) { -+ case library::MathOperationID::kMultiplyAddComplex: -+ result.flops *= 4; -+ break; -+ -+ case library::MathOperationID::kMultiplyAddComplexFastF32: -+ result.flops *= 4; -+ break; -+ -+ default: break; -+ } -+ -+} -+ -+/// Initializes workspace -+Status TrmmOperationProfiler::initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation 
const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::TrmmDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ if (operation_desc.side_mode == SideMode::kLeft) { -+ trmm_workspace_.A = device_context.allocate_tensor( -+ options, -+ "A", -+ operation_desc.A.element, -+ operation_desc.A.layout, -+ {int(problem_.m), int(problem_.m)}, -+ {int(problem_.lda)}, -+ 1 // batch_count = 1, default -+ ); -+ } else if (operation_desc.side_mode == SideMode::kRight) { -+ trmm_workspace_.A = device_context.allocate_tensor( -+ options, -+ "A", -+ operation_desc.A.element, -+ operation_desc.A.layout, -+ {int(problem_.n), int(problem_.n)}, -+ {int(problem_.lda)}, -+ 1 // batch_count = 1, default -+ ); -+ } -+ -+ trmm_workspace_.B = device_context.allocate_tensor( -+ options, -+ "B", -+ operation_desc.B.element, -+ operation_desc.B.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldb)} -+ ); -+ -+ trmm_workspace_.Computed = device_context.allocate_tensor( -+ "D", -+ operation_desc.D.element, -+ operation_desc.D.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldd)} -+ ); -+ -+ trmm_workspace_.Reference = device_context.allocate_tensor( -+ "Reference", -+ operation_desc.D.element, -+ operation_desc.D.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldd)} -+ ); -+ -+ } -+ -+ // -+ // Initialize the CUTLASS operation -+ // -+ Status status = Status::kSuccess; -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ uint64_t workspace_size = operation->get_host_workspace_size(&trmm_workspace_.configuration); -+ trmm_workspace_.host_workspace.resize(workspace_size, 0); -+ -+ workspace_size = operation->get_device_workspace_size(&trmm_workspace_.configuration); -+ trmm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); -+ -+ status = operation->initialize( -+ &trmm_workspace_.configuration, -+ trmm_workspace_.host_workspace.data(), -+ trmm_workspace_.device_workspace.data()); -+ } -+ -+ // -+ // If CUTLASS is enabled, generate a result for it -+ // -+ results_.push_back(model_result_); -+ results_.back().provider = library::Provider::kCUTLASS; -+ results_.back().op_kind = library::OperationKind::kTrmm; -+ results_.back().disposition = Disposition::kNotRun; -+ -+ for(auto provider : verification_providers_) { -+ results_.back().verification_map[provider] = Disposition::kNotRun; -+ } -+ } -+ -+ return status; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool TrmmOperationProfiler::verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ return true; -+ } -+ -+ if (options.execution_mode == ExecutionMode::kDryRun) { -+ return true; -+ } -+ -+ // Initialize structure containing TRMM arguments -+ trmm_workspace_.arguments.A = trmm_workspace_.A->data(); -+ trmm_workspace_.arguments.B = trmm_workspace_.B->data(); -+ trmm_workspace_.arguments.D = trmm_workspace_.Computed->data(); -+ trmm_workspace_.arguments.alpha = problem_.alpha.data(); -+ 
trmm_workspace_.arguments.beta = problem_.beta.data(); -+ trmm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // -+ // Run the CUTLASS operation -+ // -+ -+ results_.back().status = operation->run( -+ &trmm_workspace_.arguments, -+ trmm_workspace_.host_workspace.data(), -+ trmm_workspace_.device_workspace.data()); -+ -+ if (results_.back().status != Status::kSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ // CUTLASS op ran the but not yet verified against any verification provider -+ results_.back().disposition = Disposition::kNotVerified; -+ -+ // -+ // Run verification providers -+ // -+ -+ if (options.verification.enabled) { -+ -+#if CUTLASS_ENABLE_CUBLAS -+ if (options.verification.provider_enabled(library::Provider::kCUBLAS)) { -+ -+ // Guard against unsupported cases -+ auto const & trmm_desc = static_cast(operation->description()); -+ -+ if (cublas_satisfies(trmm_desc) == Status::kSuccess) { -+ -+ // call cublas verification if supported -+ verify_with_cublas_( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ else { -+ // set verification map for cublas to not supported -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotSupported; -+ } -+ } -+#endif // #if CUTLASS_ENABLE_CUBLAS -+ -+ // Update disposition to worst case verification outcome among all -+ // verification providers which are supported -+ bool is_any_verification_run_passed = false; -+ for(auto &m : results_.back().verification_map) { -+ if(m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { -+ results_.back().disposition = m.second; -+ return true; -+ } -+ if(!is_any_verification_run_passed && m.second == Disposition::kPassed) { -+ is_any_verification_run_passed = true; -+ } -+ } -+ -+ if(is_any_verification_run_passed) { -+ results_.back().disposition = Disposition::kPassed; -+ } -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool TrmmOperationProfiler::verify_with_cublas_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ -+#if CUTLASS_ENABLE_CUBLAS -+ -+ library::TrmmDescription const &trmm_desc = -+ static_cast(operation->description()); -+ -+ // -+ // Construct cuBLAS operators -+ // -+ -+ CublasCreate handle; -+ cublasStatus_t status = handle.get_cublas_create_status(); -+ -+ if (status != CUBLAS_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ return true; -+ } -+ -+ // -+ // Initialize state -+ // -+ -+ try { -+ -+ // -+ // Construct dispatcher to cublasTrmm() -+ // -+ -+ // Initialize structure containing TRMM arguments -+ trmm_workspace_.arguments.A = trmm_workspace_.A->data(); -+ trmm_workspace_.arguments.B = trmm_workspace_.B->data(); -+ trmm_workspace_.arguments.D = trmm_workspace_.Reference->data(); -+ trmm_workspace_.arguments.alpha = problem_.alpha.data(); -+ trmm_workspace_.arguments.beta = problem_.beta.data(); -+ trmm_workspace_.arguments.pointer_mode = 
library::ScalarPointerMode::kHost; -+ -+ detail::cublasTrmmDispatcher trmm_op( -+ trmm_desc, -+ trmm_workspace_.configuration, -+ trmm_workspace_.arguments -+ ); -+ -+ if (trmm_op.status != Status::kSuccess) { -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotRun; -+ return true; -+ } -+ -+ results_.back().status = Status::kSuccess; -+ -+ status = trmm_op(handle); -+ -+ // Handle errors -+ if (status != CUBLAS_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ return true; -+ } -+ -+ // -+ // Verify results -+ // -+ results_.back().verification_map[library::Provider::kCUBLAS] = compare_tensors( -+ options, -+ *trmm_workspace_.Computed, -+ *trmm_workspace_.Reference -+ ); -+ -+ // Save workspace if incorrect -+ if (options.verification.save_workspace == SaveWorkspace::kIncorrect && -+ results_.back().verification_map[library::Provider::kCUBLAS] == Disposition::kIncorrect) { -+ -+ save_workspace( -+ device_context, -+ options, -+ trmm_desc, -+ library::Provider::kCUTLASS, -+ library::Provider::kCUBLAS); -+ } -+ } -+ catch (...) { -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ } -+ -+#endif -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Measures performance results -+bool TrmmOperationProfiler::profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ // Initialize structure containing TRMM arguments -+ trmm_workspace_.arguments.A = trmm_workspace_.A->data(); -+ trmm_workspace_.arguments.B = trmm_workspace_.B->data(); -+ trmm_workspace_.arguments.D = trmm_workspace_.Computed->data(); -+ trmm_workspace_.arguments.alpha = problem_.alpha.data(); -+ trmm_workspace_.arguments.beta = problem_.beta.data(); -+ trmm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ results_.back().status = profile_cutlass_( -+ results_.back().runtime, -+ options, -+ operation, -+ &trmm_workspace_.arguments, -+ trmm_workspace_.host_workspace.data(), -+ trmm_workspace_.device_workspace.data() -+ ); -+ } -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/trmm_operation_profiler.h b/3rdparty/cutlass/tools/profiler/src/trmm_operation_profiler.h -new file mode 100644 -index 0000000..32ebcda ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/trmm_operation_profiler.h -@@ -0,0 +1,222 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. 
-+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines a math function -+ -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include -+ -+// CUTLASS Library includes -+#include "cutlass/blas3.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+#include "cutlass/library/manifest.h" -+ -+// Profiler includes -+#include "options.h" -+#include "device_context.h" -+#include "operation_profiler.h" -+#include "performance_result.h" -+#include "problem_space.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Abstract base class for each math function -+class TrmmOperationProfiler : public OperationProfiler { -+public: -+ -+ /// Problem structure obtained from problem space -+ struct TrmmProblem { -+ int64_t m; -+ int64_t n; -+ int64_t lda; -+ int64_t ldb; -+ int64_t ldd; -+ SideMode side_mode; -+ FillMode fill_mode; -+ DiagType diag_type; -+ std::vector alpha; -+ std::vector beta; -+ int64_t split_k_slices; -+ int64_t batch_count; -+ -+ // -+ // Methods -+ // -+ -+ TrmmProblem(): -+ m(16), n(16), lda(0), ldb(0), ldd(0), split_k_slices(1), batch_count(1) { } -+ -+ /// Parses the problem -+ Status parse( -+ library::TrmmDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes a performance result -+ void initialize_result( -+ PerformanceResult &result, -+ library::TrmmDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ }; -+ -+ /// Workspace used -+ struct TrmmWorkspace { -+ -+ DeviceAllocation *A; -+ DeviceAllocation *B; -+ DeviceAllocation *D; -+ DeviceAllocation *Computed; -+ DeviceAllocation *Reference; -+ -+ library::TrmmConfiguration configuration; -+ library::TrmmArguments arguments; -+ -+ /// Buffer used for the operation's host workspace -+ std::vector host_workspace; -+ -+ /// Buffer used for the operations' device workspace -+ DeviceAllocation device_workspace; -+ -+ // -+ // Methods -+ // -+ -+ TrmmWorkspace(): -+ 
A(nullptr), B(nullptr), D(nullptr), Computed(nullptr), Reference(nullptr) { } -+ }; -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ /// GEMM problem obtained from problem space -+ TrmmProblem problem_; -+ -+ /// Device memory allocations -+ TrmmWorkspace trmm_workspace_; -+ -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ TrmmOperationProfiler(Options const &options); -+ -+ /// Destructor -+ virtual ~TrmmOperationProfiler(); -+ -+ /// Prints usage statement for the math function -+ virtual void print_usage(std::ostream &out) const; -+ -+ /// Prints examples -+ virtual void print_examples(std::ostream &out) const; -+ -+ /// Extracts the problem dimensions -+ virtual Status initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes workspace -+ virtual Status initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against references -+ virtual bool verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Measures performance results -+ virtual bool profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+protected: -+ -+ /// Initializes the performance result -+ void initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::TrmmDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ -+ /// Verifies CUTLASS against references -+ bool verify_with_cublas_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp b/3rdparty/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp -new file mode 100644 -index 0000000..5f2dd4b ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp -@@ -0,0 +1,67 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+ -+struct GPU_Clock -+{ -+ GPU_Clock() { -+ cudaEventCreate(&start_); -+ cudaEventCreate(&stop_); -+ cudaEventRecord(start_); -+ } -+ -+ ~GPU_Clock() { -+ cudaEventDestroy(start_); -+ cudaEventDestroy(stop_); -+ } -+ -+ void start() { -+ cudaEventRecord(start_); -+ } -+ -+ float milliseconds() { -+ cudaEventRecord(stop_); -+ cudaEventSynchronize(stop_); -+ float time; -+ cudaEventElapsedTime(&time, start_, stop_); -+ return time; -+ } -+ -+ float seconds() { -+ return milliseconds() * float(1e-3); -+ } -+ -+ private: -+ cudaEvent_t start_, stop_; -+}; -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/command_line.h b/3rdparty/cutlass/tools/util/include/cutlass/util/command_line.h -new file mode 100644 -index 0000000..65cf9a1 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/command_line.h -@@ -0,0 +1,313 @@ -+/****************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ ******************************************************************************/ -+ -+#pragma once -+ -+/** -+ * \file -+ * Utility for parsing command line arguments -+ */ -+ -+#include -+#include -+#include -+#include -+#include -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+ -+/****************************************************************************** -+ * command_line -+ ******************************************************************************/ -+ -+/** -+ * Utility for parsing command line arguments -+ */ -+struct CommandLine { -+ std::vector keys; -+ std::vector values; -+ std::vector args; -+ -+ /** -+ * Constructor -+ */ -+ CommandLine(int argc, const char** argv) { -+ using namespace std; -+ -+ for (int i = 1; i < argc; i++) { -+ string arg = argv[i]; -+ -+ if ((arg[0] != '-') || (arg[1] != '-')) { -+ args.push_back(arg); -+ continue; -+ } -+ -+ string::size_type pos; -+ string key, val; -+ if ((pos = arg.find('=')) == string::npos) { -+ key = string(arg, 2, arg.length() - 2); -+ val = ""; -+ } else { -+ key = string(arg, 2, pos - 2); -+ val = string(arg, pos + 1, arg.length() - 1); -+ } -+ -+ keys.push_back(key); -+ values.push_back(val); -+ } -+ } -+ -+ /** -+ * Checks whether a flag "--" is present in the commandline -+ */ -+ bool check_cmd_line_flag(const char* arg_name) const { -+ using namespace std; -+ -+ for (int i = 0; i < int(keys.size()); ++i) { -+ if (keys[i] == string(arg_name)) return true; -+ } -+ return false; -+ } -+ -+ /** -+ * Returns number of naked (non-flag and non-key-value) commandline parameters -+ */ -+ size_t num_naked_args() const { -+ return args.size(); -+ } -+ -+ /** -+ * Print naked (non-flag and non-key-value) commandline parameters -+ */ -+ void print_naked_args(std::ostream &out) const { -+ for (auto arg : args) { -+ out << " " << arg <<"\n"; -+ } -+ } -+ -+ /** -+ * Returns the commandline parameter for a given index (not including flags) -+ */ -+ template -+ void get_cmd_line_argument(int index, value_t& val) const { -+ using namespace std; -+ if (index < args.size()) { -+ istringstream str_stream(args[index]); -+ str_stream >> val; -+ } -+ } -+ -+ /** -+ * Obtains the boolean value specified for a given commandline parameter --= -+ */ -+ void get_cmd_line_argument(const char* arg_name, bool& val, bool _default) const { -+ val = _default; -+ if (check_cmd_line_flag(arg_name)) { -+ std::string value; -+ get_cmd_line_argument(arg_name, value); -+ -+ val = !(value == "0" || value == "false"); -+ } -+ } -+ -+ /** -+ * Obtains the value specified for a given commandline parameter --= -+ */ -+ template -+ void get_cmd_line_argument(const char* arg_name, -+ value_t& val) const { -+ -+ get_cmd_line_argument(arg_name, val, val); -+ } -+ -+ /** -+ * Obtains the value specified for a given commandline parameter --= -+ */ -+ template -+ void get_cmd_line_argument(const char* arg_name, -+ value_t& val, -+ value_t const& _default) const { -+ using namespace std; -+ -+ val = 
_default; -+ -+ for (int i = 0; i < int(keys.size()); ++i) { -+ if (keys[i] == string(arg_name)) { -+ istringstream str_stream(values[i]); -+ str_stream >> val; -+ } -+ } -+ } -+ -+ /** -+ * Returns the values specified for a given commandline parameter --=,* -+ */ -+ template -+ void get_cmd_line_arguments(const char* arg_name, -+ std::vector& vals, -+ char sep = ',') const { -+ using namespace std; -+ -+ if (check_cmd_line_flag(arg_name)) { -+ // Clear any default values -+ vals.clear(); -+ -+ // Recover from multi-value string -+ for (int i = 0; i < keys.size(); ++i) { -+ if (keys[i] == string(arg_name)) { -+ string val_string(values[i]); -+ seperate_string(val_string, vals, sep); -+ } -+ } -+ } -+ } -+ -+ /** -+ * Returns the values specified for a given commandline parameter -+ * --=,* -+ */ -+ void get_cmd_line_argument_pairs(const char* arg_name, -+ std::vector >& tokens, -+ char delim = ',', -+ char sep = ':') const { -+ if (check_cmd_line_flag(arg_name)) { -+ std::string value; -+ get_cmd_line_argument(arg_name, value); -+ -+ tokenize(tokens, value, delim, sep); -+ } -+ } -+ -+ /** -+ * Returns a list of ranges specified for a given commandline parameter -+ * --=,* -+ */ -+ void get_cmd_line_argument_ranges(const char* arg_name, -+ std::vector >& vals, -+ char delim = ',', -+ char sep = ':') const { -+ std::vector ranges; -+ get_cmd_line_arguments(arg_name, ranges, delim); -+ -+ for (std::vector::const_iterator range = ranges.begin(); -+ range != ranges.end(); ++range) { -+ -+ std::vector range_vals; -+ seperate_string(*range, range_vals, sep); -+ vals.push_back(range_vals); -+ } -+ } -+ -+ /** -+ * The number of pairs parsed -+ */ -+ int parsed_argc() const { return (int)keys.size(); } -+ -+ //------------------------------------------------------------------------- -+ // Utility functions -+ //------------------------------------------------------------------------- -+ -+ /// Tokenizes a comma-delimited list of string pairs delimited by ':' -+ static void tokenize(std::vector >& tokens, -+ std::string const& str, -+ char delim = ',', -+ char sep = ':') { -+ // Home-built to avoid Boost dependency -+ size_t s_idx = 0; -+ size_t d_idx = std::string::npos; -+ while (s_idx < str.size()) { -+ d_idx = str.find_first_of(delim, s_idx); -+ -+ size_t end_idx = (d_idx != std::string::npos ? 
d_idx : str.size()); -+ size_t sep_idx = str.find_first_of(sep, s_idx); -+ size_t offset = 1; -+ if (sep_idx == std::string::npos || sep_idx >= end_idx) { -+ sep_idx = end_idx; -+ offset = 0; -+ } -+ -+ std::pair item( -+ str.substr(s_idx, sep_idx - s_idx), -+ str.substr(sep_idx + offset, end_idx - sep_idx - offset)); -+ -+ tokens.push_back(item); -+ s_idx = end_idx + 1; -+ } -+ } -+ -+ /// Tokenizes a comma-delimited list of string pairs delimited by ':' -+ static void tokenize(std::vector& tokens, -+ std::string const& str, -+ char delim = ',', -+ char sep = ':') { -+ typedef std::vector > TokenVector; -+ typedef TokenVector::const_iterator token_iterator; -+ -+ std::vector > token_pairs; -+ tokenize(token_pairs, str, delim, sep); -+ for (token_iterator tok = token_pairs.begin(); tok != token_pairs.end(); ++tok) { -+ tokens.push_back(tok->first); -+ } -+ } -+ -+ template -+ static void seperate_string(std::string const& str, -+ std::vector& vals, -+ char sep = ',') { -+ std::istringstream str_stream(str); -+ std::string::size_type old_pos = 0; -+ std::string::size_type new_pos = 0; -+ -+ // Iterate -delimited values -+ value_t val; -+ while ((new_pos = str.find(sep, old_pos)) != std::string::npos) { -+ if (new_pos != old_pos) { -+ str_stream.width(new_pos - old_pos); -+ str_stream >> val; -+ vals.push_back(val); -+ } -+ -+ // skip over delimiter -+ str_stream.ignore(1); -+ old_pos = new_pos + 1; -+ } -+ -+ // Read last value -+ str_stream >> val; -+ vals.push_back(val); -+ } -+}; -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/cublas_wrappers.hpp b/3rdparty/cutlass/tools/util/include/cutlass/util/cublas_wrappers.hpp -new file mode 100644 -index 0000000..82d56fa ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/cublas_wrappers.hpp -@@ -0,0 +1,526 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+#include -+ -+//-- BLAM_DEBUG_OUT --------------------------------------------------------- -+#ifdef BLAM_DEBUG -+# include -+# ifndef BLAM_DEBUG_OUT -+# define BLAM_DEBUG_OUT(msg) std::cerr << "BLAM: " << msg << std::endl -+# define BLAM_DEBUG_OUT_2(msg) std::cerr << msg << std::endl -+# endif // BLAM_DEBUG_OUT -+#else -+# ifndef BLAM_DEBUG_OUT -+# define BLAM_DEBUG_OUT(msg) -+# define BLAM_DEBUG_OUT_2(msg) -+# endif // BLAM_DEBUG_OUT -+#endif // BLAM_DEBUG -+ -+// User could potentially define ComplexFloat/ComplexDouble instead of std:: -+#ifndef BLAM_COMPLEX_TYPES -+#define BLAM_COMPLEX_TYPES 1 -+#include -+namespace blam { -+template -+using Complex = cuda::std::complex; -+using ComplexFloat = cuda::std::complex; -+using ComplexDouble = cuda::std::complex; -+} -+#endif // BLAM_COMPLEX_TYPES -+ -+// User could potentially define Half instead of cute:: -+#ifndef BLAM_HALF_TYPE -+#define BLAM_HALF_TYPE 1 -+#include -+namespace blam { -+using Half = cute::half_t; -+} -+#endif // BLAM_HALF_TYPE -+ -+namespace blam -+{ -+namespace cublas -+{ -+ -+inline const char* -+cublas_get_error(cublasStatus_t status) -+{ -+ switch (status) { -+ case CUBLAS_STATUS_SUCCESS: -+ return "CUBLAS_STATUS_SUCCESS"; -+ case CUBLAS_STATUS_NOT_INITIALIZED: -+ return "CUBLAS_STATUS_NOT_INITIALIZED -- The cuBLAS library was not initialized."; -+ case CUBLAS_STATUS_ALLOC_FAILED: -+ return "CUBLAS_STATUS_ALLOC_FAILED -- Resource allocation failed inside the cuBLAS library."; -+ case CUBLAS_STATUS_INVALID_VALUE: -+ return "CUBLAS_STATUS_INVALID_VALUE -- An unsupported value or parameter was passed to the function."; -+ case CUBLAS_STATUS_ARCH_MISMATCH: -+ return "CUBLAS_STATUS_ARCH_MISMATCH -- The function requires a feature absent from the device architecture."; -+ case CUBLAS_STATUS_MAPPING_ERROR: -+ return "CUBLAS_STATUS_MAPPING_ERROR -- An access to GPU memory space failed."; -+ case CUBLAS_STATUS_EXECUTION_FAILED: -+ return "CUBLAS_STATUS_EXECUTION_FAILED -- The GPU program failed to execute."; -+ case CUBLAS_STATUS_INTERNAL_ERROR: -+ return "CUBLAS_STATUS_INTERNAL_ERROR -- An internal cuBLAS operation failed."; -+ case CUBLAS_STATUS_NOT_SUPPORTED: -+ return "CUBLAS_STATUS_NOT_SUPPORTED -- The functionality requested is not supported."; -+ case CUBLAS_STATUS_LICENSE_ERROR: -+ return "CUBLAS_STATUS_LICENSE_ERROR -- An error was detected when checking the current licensing."; -+ default: -+ return "CUBLAS_ERROR -- "; -+ } -+} -+ -+inline bool -+cublas_is_error(cublasStatus_t status) -+{ -+ return status != CUBLAS_STATUS_SUCCESS; -+} -+ -+ -+// hgemm -+inline cublasStatus_t -+gemm(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const Half* alpha, -+ const Half* A, int ldA, -+ const Half* B, int ldB, -+ const Half* beta, -+ Half* C, int ldC) -+{ -+ BLAM_DEBUG_OUT("cublasHgemm"); -+ -+ 
return cublasGemmEx(handle, transA, transB, -+ m, n, k, -+ reinterpret_cast(alpha), -+ reinterpret_cast(A), CUDA_R_16F, ldA, -+ reinterpret_cast(B), CUDA_R_16F, ldB, -+ reinterpret_cast(beta), -+ reinterpret_cast< __half*>(C), CUDA_R_16F, ldC, -+ CUDA_R_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); -+} -+ -+// mixed hf gemm -+inline cublasStatus_t -+gemm(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const float* alpha, -+ const Half* A, int ldA, -+ const Half* B, int ldB, -+ const float* beta, -+ float* C, int ldC) -+{ -+ BLAM_DEBUG_OUT("cublasGemmEx mixed half-float"); -+ -+ return cublasGemmEx(handle, transA, transB, -+ m, n, k, -+ alpha, -+ reinterpret_cast(A), CUDA_R_16F, ldA, -+ reinterpret_cast(B), CUDA_R_16F, ldB, -+ beta, -+ C, CUDA_R_32F, ldC, -+ CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); -+} -+ -+// igemm -+inline cublasStatus_t -+gemm(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const int32_t* alpha, -+ const int8_t* A, int ldA, -+ const int8_t* B, int ldB, -+ const int32_t* beta, -+ int32_t* C, int ldC) -+{ -+ BLAM_DEBUG_OUT("cublasIgemm"); -+ -+ return cublasGemmEx(handle, transA, transB, -+ m, n, k, -+ alpha, -+ A, CUDA_R_8I, ldA, -+ B, CUDA_R_8I, ldB, -+ beta, -+ C, CUDA_R_32I, ldC, -+ CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP); -+} -+ -+// sgemm -+inline cublasStatus_t -+gemm(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const float* alpha, -+ const float* A, int ldA, -+ const float* B, int ldB, -+ const float* beta, -+ float* C, int ldC) -+{ -+ BLAM_DEBUG_OUT("cublasSgemm"); -+ -+ return cublasSgemm(handle, transA, transB, -+ m, n, k, -+ alpha, -+ A, ldA, -+ B, ldB, -+ beta, -+ C, ldC); -+} -+ -+// dgemm -+inline cublasStatus_t -+gemm(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const double* alpha, -+ const double* A, int ldA, -+ const double* B, int ldB, -+ const double* beta, -+ double* C, int ldC) -+{ -+ BLAM_DEBUG_OUT("cublasDgemm"); -+ -+ return cublasDgemm(handle, transA, transB, -+ m, n, k, -+ alpha, -+ A, ldA, -+ B, ldB, -+ beta, -+ C, ldC); -+} -+ -+// cgemm -+inline cublasStatus_t -+gemm(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const ComplexFloat* alpha, -+ const ComplexFloat* A, int ldA, -+ const ComplexFloat* B, int ldB, -+ const ComplexFloat* beta, -+ ComplexFloat* C, int ldC) -+{ -+ BLAM_DEBUG_OUT("cublasCgemm"); -+ -+ return cublasCgemm(handle, transA, transB, -+ m, n, k, -+ reinterpret_cast(alpha), -+ reinterpret_cast(A), ldA, -+ reinterpret_cast(B), ldB, -+ reinterpret_cast(beta), -+ reinterpret_cast(C), ldC); -+} -+ -+// zgemm -+inline cublasStatus_t -+gemm(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const ComplexDouble* alpha, -+ const ComplexDouble* A, int ldA, -+ const ComplexDouble* B, int ldB, -+ const ComplexDouble* beta, -+ ComplexDouble* C, int ldC) -+{ -+ BLAM_DEBUG_OUT("cublasZgemm"); -+ -+ return cublasZgemm(handle, transA, transB, -+ m, n, k, -+ reinterpret_cast(alpha), -+ reinterpret_cast(A), ldA, -+ reinterpret_cast(B), ldB, -+ reinterpret_cast(beta), -+ reinterpret_cast(C), ldC); -+} -+ -+// hgemm -+inline cublasStatus_t -+gemm_batch(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const Half* alpha, -+ const Half* A, int ldA, int loA, 
-+ const Half* B, int ldB, int loB, -+ const Half* beta, -+ Half* C, int ldC, int loC, -+ int batch_size) -+{ -+ BLAM_DEBUG_OUT("cublasHgemmStridedBatched"); -+ -+ return cublasHgemmStridedBatched(handle, transA, transB, -+ m, n, k, -+ reinterpret_cast(alpha), -+ reinterpret_cast(A), ldA, loA, -+ reinterpret_cast(B), ldB, loB, -+ reinterpret_cast(beta), -+ reinterpret_cast<__half*>(C), ldC, loC, -+ batch_size); -+} -+ -+// sgemm -+inline cublasStatus_t -+gemm_batch(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const float* alpha, -+ const float* A, int ldA, int loA, -+ const float* B, int ldB, int loB, -+ const float* beta, -+ float* C, int ldC, int loC, -+ int batch_size) -+{ -+ BLAM_DEBUG_OUT("cublasSgemmStridedBatched"); -+ -+ return cublasSgemmStridedBatched(handle, transA, transB, -+ m, n, k, -+ alpha, -+ A, ldA, loA, -+ B, ldB, loB, -+ beta, -+ C, ldC, loC, -+ batch_size); -+} -+ -+// dgemm -+inline cublasStatus_t -+gemm_batch(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const double* alpha, -+ const double* A, int ldA, int loA, -+ const double* B, int ldB, int loB, -+ const double* beta, -+ double* C, int ldC, int loC, -+ int batch_size) -+{ -+ BLAM_DEBUG_OUT("cublasDgemmStridedBatched"); -+ -+ return cublasDgemmStridedBatched(handle, transA, transB, -+ m, n, k, -+ alpha, -+ A, ldA, loA, -+ B, ldB, loB, -+ beta, -+ C, ldC, loC, -+ batch_size); -+} -+ -+// cgemm -+inline cublasStatus_t -+gemm_batch(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const ComplexFloat* alpha, -+ const ComplexFloat* A, int ldA, int loA, -+ const ComplexFloat* B, int ldB, int loB, -+ const ComplexFloat* beta, -+ ComplexFloat* C, int ldC, int loC, -+ int batch_size) -+{ -+ BLAM_DEBUG_OUT("cublasCgemmStridedBatched"); -+ -+ return cublasCgemmStridedBatched(handle, transA, transB, -+ m, n, k, -+ reinterpret_cast(alpha), -+ reinterpret_cast(A), ldA, loA, -+ reinterpret_cast(B), ldB, loB, -+ reinterpret_cast(beta), -+ reinterpret_cast(C), ldC, loC, -+ batch_size); -+} -+ -+// zgemm -+inline cublasStatus_t -+gemm_batch(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const ComplexDouble* alpha, -+ const ComplexDouble* A, int ldA, int loA, -+ const ComplexDouble* B, int ldB, int loB, -+ const ComplexDouble* beta, -+ ComplexDouble* C, int ldC, int loC, -+ int batch_size) -+{ -+ BLAM_DEBUG_OUT("cublasZgemmStridedBatched"); -+ -+ return cublasZgemmStridedBatched(handle, transA, transB, -+ m, n, k, -+ reinterpret_cast(alpha), -+ reinterpret_cast(A), ldA, loA, -+ reinterpret_cast(B), ldB, loB, -+ reinterpret_cast(beta), -+ reinterpret_cast(C), ldC, loC, -+ batch_size); -+} -+ -+// hgemm -+inline cublasStatus_t -+gemm_batch(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const Half* alpha, -+ const Half* const A[], int ldA, -+ const Half* const B[], int ldB, -+ const Half* beta, -+ Half* const C[], int ldC, -+ int batch_size) -+{ -+ BLAM_DEBUG_OUT("cublasHgemmBatched"); -+ -+ return cublasHgemmBatched(handle, transA, transB, -+ m, n, k, -+ reinterpret_cast(alpha), -+ reinterpret_cast(const_cast(A)), ldA, -+ // A, ldA, // cuBLAS 9.2 -+ reinterpret_cast(const_cast(B)), ldB, -+ // B, ldB, // cuBLAS 9.2 -+ reinterpret_cast(beta), -+ reinterpret_cast<__half**>(const_cast(C)), ldC, -+ // C, ldC, // cuBLAS 9.2 -+ batch_size); -+} -+ -+// 
sgemm -+inline cublasStatus_t -+gemm_batch(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const float* alpha, -+ const float* const A[], int ldA, -+ const float* const B[], int ldB, -+ const float* beta, -+ float* const C[], int ldC, -+ int batch_size) -+{ -+ BLAM_DEBUG_OUT("cublasSgemmBatched"); -+ -+ return cublasSgemmBatched(handle, transA, transB, -+ m, n, k, -+ alpha, -+ const_cast(A), ldA, -+ // A, ldA, // cuBLAS 9.2 -+ const_cast(B), ldB, -+ // B, ldB, // cuBLAS 9.2 -+ beta, -+ const_cast(C), ldC, -+ // C, ldC, // cuBLAS 9.2 -+ batch_size); -+} -+ -+// dgemm -+inline cublasStatus_t -+gemm_batch(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const double* alpha, -+ const double* const A[], int ldA, -+ const double* const B[], int ldB, -+ const double* beta, -+ double* const C[], int ldC, -+ int batch_size) -+{ -+ BLAM_DEBUG_OUT("cublasDgemmBatched"); -+ -+ return cublasDgemmBatched(handle, transA, transB, -+ m, n, k, -+ alpha, -+ const_cast(A), ldA, -+ // A, ldA, // cuBLAS 9.2 -+ const_cast(B), ldB, -+ // B, ldB, // cuBLAS 9.2 -+ beta, -+ const_cast(C), ldC, -+ // C, ldC, // cuBLAS 9.2 -+ batch_size); -+} -+ -+// cgemm -+inline cublasStatus_t -+gemm_batch(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const ComplexFloat* alpha, -+ const ComplexFloat* const A[], int ldA, -+ const ComplexFloat* const B[], int ldB, -+ const ComplexFloat* beta, -+ ComplexFloat* const C[], int ldC, -+ int batch_size) -+{ -+ BLAM_DEBUG_OUT("cublasCgemmBatched"); -+ -+ return cublasCgemmBatched(handle, transA, transB, -+ m, n, k, -+ reinterpret_cast(alpha), -+ const_cast(reinterpret_cast(A)), ldA, -+ //reinterpret_cast(A), ldA, // cuBLAS 9.2 -+ const_cast(reinterpret_cast(B)), ldB, -+ //reinterpret_cast(B), ldB, // cuBLAS 9.2 -+ reinterpret_cast(beta), -+ const_cast(reinterpret_cast(C)), ldC, -+ //reinterpret_cast(C), ldC, // cuBLAS 9.2 -+ batch_size); -+} -+ -+// zgemm -+inline cublasStatus_t -+gemm_batch(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const ComplexDouble* alpha, -+ const ComplexDouble* const A[], int ldA, -+ const ComplexDouble* const B[], int ldB, -+ const ComplexDouble* beta, -+ ComplexDouble* const C[], int ldC, -+ int batch_size) -+{ -+ BLAM_DEBUG_OUT("cublasZgemmBatched"); -+ -+ return cublasZgemmBatched(handle, transA, transB, -+ m, n, k, -+ reinterpret_cast(alpha), -+ const_cast(reinterpret_cast(A)), ldA, -+ //reinterpret_cast(A), ldA, // cuBLAS 9.2 -+ const_cast(reinterpret_cast(B)), ldB, -+ //reinterpret_cast(B), ldB, // cuBLAS 9.2 -+ reinterpret_cast(beta), -+ const_cast(reinterpret_cast(C)), ldC, -+ //reinterpret_cast(C), ldC, // cuBLAS 9.2 -+ batch_size); -+} -+ -+} // end namespace cublas -+} // end namespace blam -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/debug.h b/3rdparty/cutlass/tools/util/include/cutlass/util/debug.h -new file mode 100644 -index 0000000..3a2480c ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/debug.h -@@ -0,0 +1,143 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
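A short host-side sketch of how the single-precision overload above might be driven; it forwards to cublasSgemm, so all matrices are column-major. The sizes and error handling are illustrative only.

// Minimal sketch (not part of the patch): calls the float overload of
// blam::cublas::gemm, which forwards to cublasSgemm (column-major).
#include <cstdio>
#include <cublas_v2.h>
#include <cuda_runtime.h>

#include "cutlass/util/cublas_wrappers.hpp"

void sgemm_example(int m, int n, int k) {
  cublasHandle_t handle;
  cublasCreate(&handle);

  float *A = nullptr, *B = nullptr, *C = nullptr;
  cudaMalloc(&A, sizeof(float) * m * k);   // m x k, column-major
  cudaMalloc(&B, sizeof(float) * k * n);   // k x n, column-major
  cudaMalloc(&C, sizeof(float) * m * n);   // m x n, column-major

  float alpha = 1.0f, beta = 0.0f;

  // C = alpha * A * B + beta * C
  cublasStatus_t status =
      blam::cublas::gemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                         m, n, k,
                         &alpha, A, /*ldA=*/m, B, /*ldB=*/k,
                         &beta, C, /*ldC=*/m);
  if (blam::cublas::cublas_is_error(status)) {
    std::fprintf(stderr, "%s\n", blam::cublas::cublas_get_error(status));
  }

  cudaFree(A); cudaFree(B); cudaFree(C);
  cublasDestroy(handle);
}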
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Contains code for debugging cutlass code -+*/ -+ -+#pragma once -+ -+#include "device_dump.h" -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/****************************************************************************** -+ * Debug and logging macros -+ ******************************************************************************/ -+ -+/** -+ * Formats and prints the given message to stdout -+ */ -+#if !defined(CUDA_LOG) -+#if !defined(__CUDA_ARCH__) -+#define CUDA_LOG(format, ...) printf(format, __VA_ARGS__) -+#else -+#define CUDA_LOG(format, ...) \ -+ printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \ -+ blockIdx.x, \ -+ blockIdx.y, \ -+ blockIdx.z, \ -+ threadIdx.x, \ -+ threadIdx.y, \ -+ threadIdx.z, \ -+ __VA_ARGS__); -+#endif -+#endif -+ -+/** -+ * Formats and prints the given message to stdout only if DEBUG is defined -+ */ -+#if !defined(CUDA_LOG_DEBUG) -+#ifdef DEBUG -+#define CUDA_LOG_DEBUG(format, ...) CUDA_LOG(format, __VA_ARGS__) -+#else -+#define CUDA_LOG_DEBUG(format, ...) -+#endif -+#endif -+ -+/** -+ * \brief The corresponding error message is printed to \p stderr (or \p stdout in device code) -+ * along with the supplied source context. -+ * -+ * \return The CUDA error. 
-+ */ -+__host__ CUTLASS_DEVICE cudaError_t cuda_perror_impl(cudaError_t error, -+ const char* expression, -+ const char* filename, -+ int line) { -+ (void)filename; -+ (void)line; -+ if (error) { -+#if !defined(__CUDA_ARCH__) -+ fprintf( -+ stderr, "CUDA error %d [%s, %d] in expression '%s': %s\n", error, filename, line, expression, cudaGetErrorString(error)); -+ fflush(stderr); -+#else -+ printf("CUDA error %d [%s, %d] in expression '%s'\n", error, filename, line, expression); -+#endif -+ } -+ return error; -+} -+ -+/** -+ * \brief Perror macro -+ */ -+#ifndef CUDA_PERROR -+#define CUDA_PERROR(e) cuda_perror_impl((cudaError_t)(e), #e, __FILE__, __LINE__) -+#endif -+ -+/** -+ * \brief Perror macro with exit -+ */ -+#ifndef CUDA_PERROR_EXIT -+#define CUDA_PERROR_EXIT(e) \ -+ do { if (cuda_perror_impl((cudaError_t)(e), #e, __FILE__, __LINE__)) { \ -+ exit(1); \ -+ } } while (0) -+#endif -+ -+/** -+ * \brief Perror macro only if DEBUG is defined -+ */ -+#ifndef CUDA_PERROR_DEBUG -+#ifdef DEBUG -+#define CUDA_PERROR_DEBUG(e) CUDA_PERROR(e) -+#else -+#define CUDA_PERROR_DEBUG(e) (e) -+#endif -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// A small helper class to dump a type at compile time -+// Usage:: DumpType::Class -+template -+struct DebugType {}; -+ -+template -+void DebugTypeFunc(T const& t) { -+ T::t; -+} -+ -+// A small helper class to dump a compile time constant at compile time -+// Usage: DumpValue::kConstant -+template -+struct DebugValue {}; -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/device_dump.h b/3rdparty/cutlass/tools/util/include/cutlass/util/device_dump.h -new file mode 100644 -index 0000000..7a3270d ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/device_dump.h -@@ -0,0 +1,187 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
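A brief sketch of how the error-checking macros above are typically used around CUDA runtime calls; the allocation size is arbitrary and the function name is illustrative.

// Minimal sketch (not part of the patch): wrapping runtime calls with the
// macros defined above. Compile with nvcc; the buffer size is arbitrary.
#include <cuda_runtime.h>

#include "cutlass/util/debug.h"

int checked_alloc_example() {
  void* buffer = nullptr;

  // CUDA_PERROR prints file/line/expression plus the CUDA error string on
  // failure and returns the error code so the caller can still react.
  cudaError_t error = CUDA_PERROR(cudaMalloc(&buffer, 1 << 20));
  if (error != cudaSuccess) {
    return static_cast<int>(error);
  }

  // CUDA_PERROR_EXIT performs the same check but terminates on failure.
  CUDA_PERROR_EXIT(cudaMemset(buffer, 0, 1 << 20));
  CUDA_PERROR_EXIT(cudaFree(buffer));
  return 0;
}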
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+#include "cutlass/cutlass.h" -+ -+/** -+ * \file -+ * \brief C++ interface to dump fragments and shared memory contents for -+ * debugging. -+ */ -+ -+namespace cutlass { -+namespace debug { -+ -+/****************************************************************************** -+ * Dump the fragments -+ ******************************************************************************/ -+ -+/// The first N threads dump the first M elements from their fragments with a -+/// stride of S elements. If N is not specified, dump the data of all the -+/// threads. If M is not specified, dump all the elements of the fragment. -+template -+CUTLASS_DEVICE void dump_fragment(Fragment const& frag, int N = 0, int M = 0, -+ int S = 1) { -+ int total_threads = blockDim.x * blockDim.y * blockDim.z; -+ int block_id = -+ blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; -+ int thread_id = (threadIdx.z * (blockDim.x * blockDim.y)) + -+ (threadIdx.y * blockDim.x) + threadIdx.x; -+ -+ if (N < 0 || N > total_threads) { -+ if (thread_id == 0 && block_id == 0) -+ printf("Thread number N = %d should between [1, %d].\n", N, -+ total_threads); -+ -+ __syncthreads(); -+ -+ return; -+ } -+ -+ int total_elements = frag.size(); -+ -+ if (M < 0 || M > total_elements) { -+ if (thread_id == 0 && block_id == 0) -+ printf("Element number M = %d should between [1, %d].\n", M, -+ total_elements); -+ -+ __syncthreads(); -+ -+ return; -+ } -+ -+ if (N == 0) N = total_threads; -+ -+ if (M == 0) M = total_elements; -+ -+ if (S < 1 || S > M) { -+ if (thread_id == 0 && block_id == 0) -+ printf("Stride S = %d should between [1, %d].\n", S, M); -+ -+ __syncthreads(); -+ -+ return; -+ } -+ -+ if (thread_id == 0 && block_id == 0) -+ printf("\n*******************Dumping the fragments*******************\n\n"); -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int tid = 0; tid < N; ++tid) { -+ if (tid == thread_id) { -+ printf("TB%d W%d T%d: ", block_id, tid / 32, tid & 31); -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int i = 0; i < M; i += S) { -+ printf("%.0f ", float(typename Fragment::value_type(frag[i]))); -+ } -+ printf("\n"); -+ } -+ -+ __syncthreads(); -+ } -+ -+ if (thread_id == 0 && block_id == 0) -+ printf("\n***********************************************************\n\n"); -+ -+ __syncthreads(); -+ -+ return; -+} -+ -+/****************************************************************************** -+ * Dump the shared memory -+ ******************************************************************************/ -+ -+#define SHMEM_ROW_SIZE 128 -+ -+/// Dump the shared memory contents. ptr is the begin address, size specifies -+/// the number of elements that need to be dumped, and S specifies the stride. 
-+template -+CUTLASS_DEVICE void dump_shmem(Element const* ptr, size_t size, int S = 1) { -+ int block_id = -+ blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; -+ int thread_id = (threadIdx.z * (blockDim.x * blockDim.y)) + -+ (threadIdx.y * blockDim.x) + threadIdx.x; -+ -+ if (ptr == nullptr) { -+ if (thread_id == 0 && block_id == 0) printf("ptr is null.\n"); -+ -+ __syncthreads(); -+ return; -+ } -+ -+ if (size < 1) { -+ if (thread_id == 0 && block_id == 0) -+ printf("Element size is less than 1\n"); -+ -+ __syncthreads(); -+ -+ return; -+ } -+ -+ int row_elements = SHMEM_ROW_SIZE / sizeof(Element); -+ -+ if (S < 1 || S > row_elements) { -+ if (thread_id == 0 && block_id == 0) -+ printf("Stride S = %d should between [1, %d].\n", S, row_elements); -+ -+ __syncthreads(); -+ -+ return; -+ } -+ -+ __syncthreads(); -+ -+ if (thread_id == 0) -+ printf("\n********Dumping the shared memory of TB %d*******\n\n", block_id); -+ -+ if (thread_id == 0) { -+ for (int i = 0; i < size; i += row_elements) { -+ for (int j = 0; j < row_elements; j += S) { -+ printf("%.0f ", float(ptr[i + j])); -+ } -+ -+ printf("\n"); -+ } -+ } -+ -+ if (thread_id == 0) -+ printf("\n***********************************************************\n\n"); -+ -+ __syncthreads(); -+ -+ return; -+} -+} // namespace debug -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/device_groupnorm.h b/3rdparty/cutlass/tools/util/include/cutlass/util/device_groupnorm.h -new file mode 100644 -index 0000000..aaa19b2 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/device_groupnorm.h -@@ -0,0 +1,402 @@ -+/****************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
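A toy kernel sketch showing how the dump helpers above can be called from device code; the fragment and shared-memory contents are arbitrary and only serve to produce output.

// Minimal sketch (not part of the patch): a toy kernel that exercises
// dump_fragment and dump_shmem. Contents are arbitrary; compile with nvcc.
#include "cutlass/array.h"
#include "cutlass/util/device_dump.h"

__global__ void dump_example_kernel() {
  // Per-thread register fragment.
  cutlass::Array<float, 4> frag;
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < 4; ++i) {
    frag[i] = float(threadIdx.x * 4 + i);
  }

  // Print the first two elements of the fragments held by the first 8 threads.
  cutlass::debug::dump_fragment(frag, /*N=*/8, /*M=*/2);

  // Stage data in shared memory and dump it (printed by thread 0 per block).
  __shared__ float smem[64];
  if (threadIdx.x < 64) {
    smem[threadIdx.x] = float(threadIdx.x);
  }
  __syncthreads();
  cutlass::debug::dump_shmem(smem, 64);
}

// Host side: dump_example_kernel<<<1, 64>>>(); cudaDeviceSynchronize();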
-+ * -+ ******************************************************************************/ -+ -+#pragma once -+ -+/** -+ * \file -+ * \brief cuda kernels to do group norm on a device memory tensor with NHWC layout. The tensor will be divided into [N, H, W, G, C'] and then we do normalization on [H, W, C']. -+ */ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/tensor_ref.h" -+#include "device_utils.h" -+#include -+ -+namespace cutlass { -+ -+/** \brief interface to do group norm on a device memory tensor with NHWC layout. -+ * \tparam T: data type -+ */ -+template -+void groupnorm(cutlass::Tensor4DCoord input_size, -+ const int num_groups, -+ const float eps, -+ TensorRef ref_output, -+ TensorRef ref_input, -+ TensorRef ref_gamma, -+ TensorRef ref_beta, -+ cudaStream_t stream); -+ -+extern __shared__ char groupnorm_shm[]; -+ -+// For small prod_dim1_to_last_dim/num_groups, to avoid multiple loads from global memory, -+// we store the input in the shared memory. -+// grid(num_groups, dim0) -+// block(BLOCKSIZE) -+// BLOCKSIZE * TVecs_PER_THREAD <= prod_dim1_to_last_dim/num_group -+template -+__global__ void groupnorm_twopass_store_locally(T* output, -+ const T* input, -+ const T* gamma, -+ const T* beta, -+ int num_groups, -+ int prod_dim1_to_last_dim, -+ int last_dim, -+ const float eps, -+ const int TVecs_PER_THREAD) -+{ -+ const int bid = blockIdx.y; // index of batch -+ const int gid = blockIdx.x; // index of group -+ const int tid = threadIdx.x; // index of thread -+ const int bdimx = blockDim.x; -+ const int s_reduce_elements = prod_dim1_to_last_dim / num_groups; -+ const int v_reduce_elements = s_reduce_elements / T_PER_TVec; -+ const int s_group_stride = last_dim / num_groups; -+ const int v_group_stride = s_group_stride / T_PER_TVec; -+ const int offset_of_group = (bid * prod_dim1_to_last_dim + gid * s_group_stride) / T_PER_TVec; -+ const TVec* input_TVec_ptr = (const TVec*)(input) + offset_of_group; -+ TVec* output_TVec_ptr = (TVec*)(output) + offset_of_group; -+ T* local_val = ((T*)groupnorm_shm) + TVecs_PER_THREAD * T_PER_TVec * tid; -+ float local_sum[1] = {0.0f}; -+ -+// load from global memory into shared memory -+#pragma unroll -+ for (int i = 0; i < TVecs_PER_THREAD; i += 1) { -+ const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; -+ const int offset_in_group = -+ ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) -+ / T_PER_TVec; -+ if (current_load_start_idx < s_reduce_elements) { -+ TVec tmp_vec = input_TVec_ptr[offset_in_group]; -+ T* tmp_vec_ptr = (T*)(&tmp_vec); -+ const int local_val_offset = i * T_PER_TVec; -+#pragma unroll -+ for (int j = 0; j < T_PER_TVec; j++) { -+ float tmp = static_cast(tmp_vec_ptr[j]); -+ local_sum[0] += tmp; -+ local_val[local_val_offset + j] = tmp_vec_ptr[j]; -+ } -+ } -+ } -+ __shared__ float s_mean, s_variance; -+ -+ // reduction for mean -+ if (bdimx <= 32) { -+ warpReduceSum(local_sum); -+ } -+ else { -+ blockReduceSum(local_sum); -+ } -+ if (tid == 0) { -+ s_mean = local_sum[0] / s_reduce_elements; -+ } -+ __syncthreads(); -+ -+ // reduction for std -+ local_sum[0] = 0.0f; -+#pragma unroll -+ for (int i = 0; i < TVecs_PER_THREAD; i += 1) { -+ const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; -+ if (current_load_start_idx < s_reduce_elements) { -+ const int local_val_offset = i * T_PER_TVec; -+#pragma unroll -+ for (int j = 0; j < T_PER_TVec; j++) { 
-+ float tmp = static_cast(local_val[local_val_offset + j]); -+ tmp -= s_mean; -+ local_sum[0] += tmp * tmp; -+ } -+ } -+ } -+ if (bdimx <= 32) { -+ warpReduceSum(local_sum); -+ } -+ else { -+ blockReduceSum(local_sum); -+ } -+ if (tid == 0) { -+ s_variance = rsqrtf(local_sum[0] / s_reduce_elements + eps); -+ } -+ __syncthreads(); -+ -+ // normalize -+ const int gamma_offset_of_group = gid * v_group_stride; -+ const TVec* gamma_TVec_ptr = (const TVec*)gamma + gamma_offset_of_group; -+ const TVec* beta_TVec_ptr = (const TVec*)beta + gamma_offset_of_group; -+#pragma unroll -+ for (int i = 0; i < TVecs_PER_THREAD; i += 1) { -+ const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; -+ const int offset_in_group = -+ ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) -+ / T_PER_TVec; -+ const int gamma_offset_in_group = (current_load_start_idx % s_group_stride) / T_PER_TVec; -+ const int local_val_offset = i * T_PER_TVec; -+ if (current_load_start_idx < s_reduce_elements) { -+ TVec gamma_val = gamma_TVec_ptr[gamma_offset_in_group]; -+ TVec beta_val = beta_TVec_ptr[gamma_offset_in_group]; -+ T* gamma_val_ptr = (T*)(&gamma_val); -+ T* beta_val_ptr = (T*)(&beta_val); -+ TVec tmp_vec; -+ T* tmp_vec_ptr = (T*)(&tmp_vec); -+#pragma unroll -+ for (int j = 0; j < T_PER_TVec; j++) { -+ float tmp = (static_cast(local_val[local_val_offset + j]) - s_mean) * s_variance -+ * static_cast(gamma_val_ptr[j]) -+ + static_cast(beta_val_ptr[j]); -+ if (sizeof(T) == sizeof(half)) { -+ tmp_vec_ptr[j] = T(__float2half_rn(tmp)); -+ } -+ else { -+ tmp_vec_ptr[j] = T(tmp); -+ } -+ } -+ output_TVec_ptr[offset_in_group] = tmp_vec; -+ } -+ } -+} -+ -+// For large prod_dim1_to_last_dim/num_groups, -+// in which the data cannot be stored locally, -+// we will load from global memory multiple times, -+// grid(num_groups, dim0) -+// block(BLOCKSIZE) -+// BLOCKSIZE * TVecs_PER_THREAD <= prod_dim1_to_last_dim/num_group -+template -+__global__ void groupnorm_twopass_multiple_load(T* output, -+ const T* input, -+ const T* gamma, -+ const T* beta, -+ int num_groups, -+ int prod_dim1_to_last_dim, -+ int last_dim, -+ const float eps, -+ const int TVecs_PER_THREAD) -+{ -+ const int bid = blockIdx.y; // index of batch -+ const int gid = blockIdx.x; // index of group -+ const int tid = threadIdx.x; // index of thread -+ const int bdimx = blockDim.x; -+ const int s_reduce_elements = prod_dim1_to_last_dim / num_groups; -+ const int v_reduce_elements = s_reduce_elements / T_PER_TVec; -+ const int s_group_stride = last_dim / num_groups; -+ const int v_group_stride = s_group_stride / T_PER_TVec; -+ const int offset_of_group = (bid * prod_dim1_to_last_dim + gid * s_group_stride) / T_PER_TVec; -+ const TVec* input_TVec_ptr = (const TVec*)(input) + offset_of_group; -+ TVec* output_TVec_ptr = (TVec*)(output) + offset_of_group; -+ float local_sum[1] = {0.0f}; -+ -+#pragma unroll -+ for (int i = 0; i < TVecs_PER_THREAD; i += 1) { -+ const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; -+ if (current_load_start_idx < s_reduce_elements) { -+ const int offset_in_group = -+ ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) -+ / T_PER_TVec; -+ TVec tmp_vec = input_TVec_ptr[offset_in_group]; -+ T* tmp_vec_ptr = (T*)(&tmp_vec); -+#pragma unroll -+ for (int j = 0; j < T_PER_TVec; j++) { -+ float tmp = static_cast(tmp_vec_ptr[j]); -+ local_sum[0] += tmp; -+ } -+ } -+ } -+ __shared__ float s_mean, s_variance; -+ -+ // reduction for mean 
-+ if (bdimx <= 32) { -+ warpReduceSum(local_sum); -+ } -+ else { -+ blockReduceSum(local_sum); -+ } -+ if (tid == 0) { -+ s_mean = local_sum[0] / s_reduce_elements; -+ } -+ __syncthreads(); -+ -+ // reduction for std -+ local_sum[0] = 0.0f; -+#pragma unroll -+ for (int i = 0; i < TVecs_PER_THREAD; i += 1) { -+ const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; -+ if (current_load_start_idx < s_reduce_elements) { -+ const int offset_in_group = -+ ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) -+ / T_PER_TVec; -+ TVec tmp_vec = input_TVec_ptr[offset_in_group]; -+ T* tmp_vec_ptr = (T*)(&tmp_vec); -+#pragma unroll -+ for (int j = 0; j < T_PER_TVec; j++) { -+ float tmp = static_cast(tmp_vec_ptr[j]); -+ tmp -= s_mean; -+ local_sum[0] += tmp * tmp; -+ } -+ } -+ } -+ if (bdimx <= 32) { -+ warpReduceSum(local_sum); -+ } -+ else { -+ blockReduceSum(local_sum); -+ } -+ if (tid == 0) { -+ s_variance = rsqrtf(local_sum[0] / s_reduce_elements + eps); -+ } -+ __syncthreads(); -+ -+ // normalize -+ const int gamma_offset_of_group = gid * v_group_stride; -+ const TVec* gamma_TVec_ptr = (const TVec*)gamma + gamma_offset_of_group; -+ const TVec* beta_TVec_ptr = (const TVec*)beta + gamma_offset_of_group; -+#pragma unroll -+ for (int i = 0; i < TVecs_PER_THREAD; i += 1) { -+ const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; -+ if (current_load_start_idx < s_reduce_elements) { -+ const int offset_in_group = -+ ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) -+ / T_PER_TVec; -+ const int gamma_offset_in_group = (current_load_start_idx % s_group_stride) / T_PER_TVec; -+ TVec gamma_val = gamma_TVec_ptr[gamma_offset_in_group]; -+ TVec beta_val = beta_TVec_ptr[gamma_offset_in_group]; -+ T* gamma_val_ptr = (T*)(&gamma_val); -+ T* beta_val_ptr = (T*)(&beta_val); -+ TVec tmp_vec = input_TVec_ptr[offset_in_group]; -+ T* tmp_vec_ptr = (T*)(&tmp_vec); -+ TVec output_tmp_vec; -+ T* output_tmp_vec_ptr = (T*)(&output_tmp_vec); -+#pragma unroll -+ for (int j = 0; j < T_PER_TVec; j++) { -+ float tmp = -+ (static_cast(tmp_vec_ptr[j]) - s_mean) * s_variance * static_cast(gamma_val_ptr[j]) -+ + static_cast(beta_val_ptr[j]); -+ if (sizeof(T) == sizeof(half)) { -+ output_tmp_vec_ptr[j] = T(__float2half_rn(tmp)); -+ } -+ else { -+ output_tmp_vec_ptr[j] = T(tmp); -+ } -+ } -+ output_TVec_ptr[offset_in_group] = output_tmp_vec; -+ } -+ } -+} -+ -+//ref_input & ref_output should be [N, H, W, C] -+//ref_gamma & ref_beta shoud be [1, 1, 1, C] -+template -+void groupnorm(cutlass::Tensor4DCoord input_size, -+ const int num_groups, -+ const float eps, -+ TensorRef ref_output, -+ TensorRef ref_input, -+ TensorRef ref_gamma, -+ TensorRef ref_beta, -+ cudaStream_t stream){ -+ const int N = input_size.n(); -+ const int H = input_size.h(); -+ const int W = input_size.w(); -+ const int C = input_size.c(); -+ if (C % num_groups != 0){ -+ printf("[ERROR] C should be a multiple of num_groups.\n"); -+ } -+ T* output = ref_output.data(); -+ const T* input = ref_input.data(); -+ const T* gamma = ref_gamma.data(); -+ const T* beta = ref_beta.data(); -+ -+ const int dim0 = N; -+ const int last_dim = C; -+ const int prod_dim1_to_last_dim = H*W*C; -+ const int s_reduce_elements = prod_dim1_to_last_dim / num_groups; -+ const int s_group_stride = last_dim / num_groups; -+ dim3 grid(num_groups, dim0); -+ int threadblock_size = 32; -+ if (s_group_stride % 2 == 0) { -+ const int T_PER_TVec = 2; -+ while (threadblock_size < 
1024) { -+ if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8) -+ break; -+ threadblock_size *= 2; -+ } -+ dim3 block(threadblock_size); -+ const int TVec_PER_THREAD = (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size; -+ const int shm_size = T_PER_TVec * TVec_PER_THREAD * threadblock_size * sizeof(T); -+ // for small s_reduce_elements, specific case for H=W=22, C=1280, num_groups=32; -+ // the size of grid & block may have better choice for different cases. -+ // ensure shared memory is smaller than 48KB -+ if (std::is_same::value){ -+ if (shm_size < 48 * 1024) { -+ groupnorm_twopass_store_locally<<>>( -+ output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); -+ } -+ else { -+ groupnorm_twopass_multiple_load<<>>( -+ output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); -+ } -+ } -+ else{ -+ if (shm_size < 48 * 1024) { -+ groupnorm_twopass_store_locally<<>>( -+ output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); -+ } -+ else { -+ groupnorm_twopass_multiple_load<<>>( -+ output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); -+ } -+ } -+ } -+ else { -+ const int T_PER_TVec = 1; -+ while (threadblock_size < 1024) { -+ if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8) -+ break; -+ threadblock_size *= 2; -+ } -+ dim3 block(threadblock_size); -+ const int TVec_PER_THREAD = (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size; -+ const int shm_size = T_PER_TVec * TVec_PER_THREAD * threadblock_size * sizeof(T); -+ if (shm_size < 48 * 1024) { -+ groupnorm_twopass_store_locally<<>>( -+ output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); -+ } -+ else { -+ groupnorm_twopass_multiple_load<<>>( -+ output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); -+ } -+ } -+ -+} -+ -+} //namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/device_layernorm.h b/3rdparty/cutlass/tools/util/include/cutlass/util/device_layernorm.h -new file mode 100644 -index 0000000..c4ec925 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/device_layernorm.h -@@ -0,0 +1,644 @@ -+/****************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
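A host-side sketch of calling the group-norm entry point above on an NHWC tensor. The TensorRef template arguments are not visible in this hunk; they are assumed to be TensorRef<T, layout::TensorNHWC> as in upstream CUTLASS, and the shapes and element type are illustrative.

// Minimal sketch (not part of the patch): invoking cutlass::groupnorm on an
// NHWC tensor. TensorRef arguments are assumed to be
// TensorRef<T, layout::TensorNHWC> as in upstream CUTLASS; shapes are
// illustrative. Note C must be a multiple of num_groups.
#include "cutlass/util/device_groupnorm.h"
#include "cutlass/util/host_tensor.h"

void groupnorm_example(cudaStream_t stream) {
  using Element = cutlass::half_t;
  using Layout  = cutlass::layout::TensorNHWC;

  // [N, H, W, C] activations; gamma/beta are broadcast over N, H, W.
  cutlass::Tensor4DCoord size(2, 22, 22, 1280);
  int const num_groups = 32;
  float const eps = 1e-5f;

  cutlass::HostTensor<Element, Layout> input(size);
  cutlass::HostTensor<Element, Layout> output(size);
  cutlass::HostTensor<Element, Layout> gamma({1, 1, 1, size.c()});
  cutlass::HostTensor<Element, Layout> beta({1, 1, 1, size.c()});

  // ... fill input/gamma/beta on the host here ...
  input.sync_device();
  gamma.sync_device();
  beta.sync_device();

  cutlass::groupnorm(size, num_groups, eps,
                     output.device_ref(), input.device_ref(),
                     gamma.device_ref(), beta.device_ref(), stream);
}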
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ ******************************************************************************/ -+ -+#pragma once -+ -+/** -+ * \file -+ * \brief cuda kernels to do layernorm on a device memory tensor with RowMajor layout. -+ */ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/tensor_ref.h" -+#include "device_utils.h" -+#include -+ -+namespace cutlass { -+ -+/** \brief interface to do layernorm on a device memory tensor with RowMajor layout. -+ * \tparam T: data type -+ */ -+template -+void layernorm(cutlass::MatrixCoord tensor_size, -+ TensorRef ref_output, -+ TensorRef ref_input, -+ TensorRef ref_gamma, -+ TensorRef ref_beta, -+ cudaStream_t stream); -+ -+/** -+ * output [m, n] row-major -+ * input [m, n] row-major -+ * gamma [n] -+ * beta [n] -+ * grid(m) -+ * block(block_size) -- each block deals with n elements ; each thread deals with ITEM_PER_THREAD elements -+*/ -+template -+__global__ void layernorm_twoPassAlgo_stored_locally_e1(T* output, -+ const T* input, -+ const T* gamma, -+ const T* beta, -+ const int m, -+ const int n) -+{ -+ const int m_idx = blockIdx.x; -+ const int tid = threadIdx.x; -+ const int bdimx = blockDim.x; -+ __shared__ float s_mean, s_variance; -+ T local_val[ITEM_PER_THREAD]; -+ float local_sums[1] = {0.0f}; -+ int offset = m_idx * n; -+ input += offset; -+ output += offset; -+ -+ const T zero = T(0.0f); -+ #pragma unroll -+ for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ -+ int index = tid + i*bdimx; -+ local_val[i] = index < n ? 
input[index] : zero; -+ local_sums[0] += static_cast(local_val[i]); -+ } -+ if (blockDim.x <= 32) { -+ warpReduceSum(local_sums); -+ } -+ else { -+ blockReduceSum(local_sums); -+ } -+ if (threadIdx.x == 0) { -+ s_mean = local_sums[0] / n; -+ } -+ __syncthreads(); -+ -+ local_sums[0] = 0.0f; -+ #pragma unroll -+ for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ -+ int index = tid + i*bdimx; -+ if (index < n){ -+ const float tmp = static_cast(local_val[i]) - s_mean; -+ local_sums[0] += tmp * tmp; -+ } -+ } -+ -+ if (blockDim.x <= 32) { -+ warpReduceSum(local_sums); -+ } -+ else { -+ blockReduceSum(local_sums); -+ } -+ if (threadIdx.x == 0) { -+ s_variance = rsqrtf(local_sums[0] / n + 1e-5); -+ } -+ __syncthreads(); -+ -+ #pragma unroll -+ for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ -+ int index = tid + i*bdimx; -+ if (index < n) { -+ const T gamma_val = gamma[index]; -+ const T beta_val = beta[index]; -+ output[index] = T((static_cast(local_val[i]) - s_mean) * s_variance * static_cast(gamma_val) + static_cast(beta_val)); -+ } -+ } -+} -+ -+/** -+ * output [m, n] row-major -+ * input [m, n] row-major -+ * gamma [n] -+ * beta [n] -+ * grid(m) -+ * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*2 elements; -+*/ -+template -+__global__ void layernorm_twoPassAlgo_stored_locally_e2(T2* output, -+ const T2* input, -+ const T2* gamma, -+ const T2* beta, -+ const int m, -+ const int n) -+{ -+ const int m_idx = blockIdx.x; -+ const int tid = threadIdx.x; -+ const int bdimx = blockDim.x; -+ __shared__ float s_mean, s_variance; -+ float local_sums[1] = {0.0f}; -+ T2 local_val[ITEM_PER_THREAD]; -+ const int n_2 = n / 2; -+ int offset = m_idx * n_2; -+ input += offset; -+ output += offset; -+ -+ const T2 zero = {T(0.0f), T(0.0f)}; -+ #pragma UNROLL -+ for (int i = 0; i < ITEM_PER_THREAD; i += 1) { -+ const int index = i*bdimx + tid; -+ local_val[i] = index < n_2 ? 
input[index] : zero; -+ local_sums[0] += static_cast(local_val[i].x) + static_cast(local_val[i].y); -+ } -+ -+ if (blockDim.x <= 32) { -+ warpReduceSum(local_sums); -+ } -+ else { -+ blockReduceSum(local_sums); -+ } -+ if (threadIdx.x == 0) { -+ s_mean = local_sums[0] / n; -+ } -+ __syncthreads(); -+ -+ local_sums[0] = 0.0f; -+ #pragma UNROLL -+ for (int i = 0; i < ITEM_PER_THREAD; i += 1) { -+ const int index = i*bdimx + tid; -+ if (index < n_2){ -+ const float2 tmp = {static_cast(local_val[i].x) - s_mean, -+ static_cast(local_val[i].y) - s_mean}; -+ local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y; -+ } -+ } -+ if (blockDim.x <= 32) { -+ warpReduceSum(local_sums); -+ } -+ else { -+ blockReduceSum(local_sums); -+ } -+ if (threadIdx.x == 0) { -+ s_variance = rsqrtf(local_sums[0] / n + 1e-5); -+ } -+ __syncthreads(); -+ -+ #pragma UNROLL -+ for (int i = 0; i < ITEM_PER_THREAD; i += 1) { -+ const int index = i*bdimx + tid; -+ if (index < n_2){ -+ const T2 gamma_val = gamma[index]; -+ const T2 beta_val = beta[index]; -+ T2 tmp; -+ tmp.x = T((static_cast(local_val[i].x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); -+ tmp.y = T((static_cast(local_val[i].y) - s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); -+ output[index] = tmp; -+ } -+ } -+} -+ -+/** -+ * output [m, n] row-major -+ * input [m, n] row-major -+ * gamma [n] -+ * beta [n] -+ * grid(m) -+ * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*4 elements; -+*/ -+template -+__global__ void layernorm_twoPassAlgo_stored_locally_e4(T4* output, -+ const T4* input, -+ const T4* gamma, -+ const T4* beta, -+ const int m, -+ const int n) -+{ -+ const int m_idx = blockIdx.x; -+ const int tid = threadIdx.x; -+ const int bdimx = blockDim.x; -+ __shared__ float s_mean, s_variance; -+ float local_sums[1] = {0.0f}; -+ T4 local_val[ITEM_PER_THREAD]; -+ const int n_4 = n / 4; -+ int offset = m_idx * n_4; -+ input += offset; -+ output += offset; -+ -+ const T4 zero = {T(0.0f), T(0.0f), T(0.0f), T(0.0f)}; -+ #pragma UNROLL -+ for (int i = 0; i < ITEM_PER_THREAD; i += 1) { -+ const int index = i*bdimx + tid; -+ local_val[i] = index < n_4 ? 
input[index] : zero; -+ local_sums[0] += static_cast(local_val[i].x) + static_cast(local_val[i].y) + -+ static_cast(local_val[i].z) + static_cast(local_val[i].w); -+ } -+ -+ if (blockDim.x <= 32) { -+ warpReduceSum(local_sums); -+ } -+ else { -+ blockReduceSum(local_sums); -+ } -+ if (threadIdx.x == 0) { -+ s_mean = local_sums[0] / n; -+ } -+ __syncthreads(); -+ -+ local_sums[0] = 0.0f; -+ #pragma UNROLL -+ for (int i = 0; i < ITEM_PER_THREAD; i += 1) { -+ const int index = i*bdimx + tid; -+ if (index < n_4){ -+ const float4 tmp = {static_cast(local_val[i].x) - s_mean, -+ static_cast(local_val[i].y) - s_mean, -+ static_cast(local_val[i].z) - s_mean, -+ static_cast(local_val[i].w) - s_mean}; -+ local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y + tmp.z * tmp.z + tmp.w * tmp.w; -+ } -+ } -+ if (blockDim.x <= 32) { -+ warpReduceSum(local_sums); -+ } -+ else { -+ blockReduceSum(local_sums); -+ } -+ if (threadIdx.x == 0) { -+ s_variance = rsqrtf(local_sums[0] / n + 1e-5); -+ } -+ __syncthreads(); -+ -+ #pragma UNROLL -+ for (int i = 0; i < ITEM_PER_THREAD; i += 1) { -+ const int index = i*bdimx + tid; -+ if (index < n_4){ -+ const T4 gamma_val = gamma[index]; -+ const T4 beta_val = beta[index]; -+ T4 tmp; -+ tmp.x = T((static_cast(local_val[i].x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); -+ tmp.y = T((static_cast(local_val[i].y) - s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); -+ tmp.z = T((static_cast(local_val[i].z) - s_mean)*s_variance*static_cast(gamma_val.z) + static_cast(beta_val.z)); -+ tmp.w = T((static_cast(local_val[i].w) - s_mean)*s_variance*static_cast(gamma_val.w) + static_cast(beta_val.w)); -+ output[index] = tmp; -+ } -+ } -+} -+ -+/** -+ * output [m, n] row-major -+ * input [m, n] row-major -+ * gamma [n] -+ * beta [n] -+ * grid(m) -+ * block(block_size) -- each block deals with n elements ; each thread deals with ITEM_PER_THREAD elements -+*/ -+template -+__global__ void layernorm_twoPassAlgo_e1(T* output, -+ const T* input, -+ const T* gamma, -+ const T* beta, -+ const int m, -+ const int n) -+{ -+ const int m_idx = blockIdx.x; -+ const int tid = threadIdx.x; -+ const int bdimx = blockDim.x; -+ __shared__ float s_mean, s_variance; -+ float local_sums[1] = {0.0f}; -+ int offset = m_idx * n; -+ input += offset; -+ output += offset; -+ -+ for (int index = tid ; index < n ; index += bdimx){ -+ float local_val = static_cast(input[index]); -+ local_sums[0] += local_val; -+ } -+ if (blockDim.x <= 32) { -+ warpReduceSum(local_sums); -+ } -+ else { -+ blockReduceSum(local_sums); -+ } -+ if (threadIdx.x == 0) { -+ s_mean = local_sums[0] / n; -+ } -+ __syncthreads(); -+ -+ local_sums[0] = 0.0f; -+ for (int index = tid ; index < n ; index += bdimx){ -+ float local_val = static_cast(input[index]); -+ local_val = local_val - s_mean; -+ local_sums[0] += local_val * local_val; -+ } -+ -+ if (blockDim.x <= 32) { -+ warpReduceSum(local_sums); -+ } -+ else { -+ blockReduceSum(local_sums); -+ } -+ if (threadIdx.x == 0) { -+ s_variance = rsqrtf(local_sums[0] / n + 1e-5); -+ } -+ __syncthreads(); -+ -+ for (int index = tid ; index < n ; index += bdimx){ -+ const T gamma_val = gamma[index]; -+ const T beta_val = beta[index]; -+ const T local_val = input[index]; -+ output[index] = T((static_cast(local_val) - s_mean) * s_variance * static_cast(gamma_val) + static_cast(beta_val)); -+ } -+} -+ -+/** -+ * output [m, n] row-major -+ * input [m, n] row-major -+ * gamma [n] -+ * beta [n] -+ * grid(m) -+ * block(block_size) -- each block deals with 
block_size*ITEM_PER_THREAD*2 elements; -+*/ -+template -+__global__ void layernorm_twoPassAlgo_e2(T2* output, -+ const T2* input, -+ const T2* gamma, -+ const T2* beta, -+ const int m, -+ const int n) -+{ -+ const int m_idx = blockIdx.x; -+ const int tid = threadIdx.x; -+ const int bdimx = blockDim.x; -+ __shared__ float s_mean, s_variance; -+ float local_sums[1] = {0.0f}; -+ const int n_2 = n / 2; -+ int offset = m_idx * n_2; -+ input += offset; -+ output += offset; -+ -+ for (int index = tid; index < n_2; index += bdimx) { -+ const T2 local_val = input[index]; -+ local_sums[0] += static_cast(local_val.x) + static_cast(local_val.y); -+ } -+ -+ if (blockDim.x <= 32) { -+ warpReduceSum(local_sums); -+ } -+ else { -+ blockReduceSum(local_sums); -+ } -+ if (threadIdx.x == 0) { -+ s_mean = local_sums[0] / n; -+ } -+ __syncthreads(); -+ -+ local_sums[0] = 0.0f; -+ for (int index = tid; index < n_2; index += bdimx) { -+ const T2 local_val = input[index]; -+ const float2 tmp = {static_cast(local_val.x) - s_mean, -+ static_cast(local_val.y) - s_mean}; -+ local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y; -+ } -+ if (blockDim.x <= 32) { -+ warpReduceSum(local_sums); -+ } -+ else { -+ blockReduceSum(local_sums); -+ } -+ if (threadIdx.x == 0) { -+ s_variance = rsqrtf(local_sums[0] / n + 1e-5); -+ } -+ __syncthreads(); -+ -+ for (int index = tid; index < n_2; index += bdimx) { -+ const T2 local_val = input[index]; -+ const T2 gamma_val = gamma[index]; -+ const T2 beta_val = beta[index]; -+ T2 tmp; -+ tmp.x = T((static_cast(local_val.x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); -+ tmp.y = T((static_cast(local_val.y) - s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); -+ output[index] = tmp; -+ } -+} -+ -+template -+void layernorm(cutlass::MatrixCoord tensor_size, -+ TensorRef ref_output, -+ TensorRef ref_input, -+ TensorRef ref_gamma, -+ TensorRef ref_beta, -+ cudaStream_t stream){ -+ const int m = tensor_size.row(); -+ const int n = tensor_size.column(); -+ T* output = ref_output.data(); -+ const T* input = ref_input.data(); -+ const T* gamma = ref_gamma.data(); -+ const T* beta = ref_beta.data(); -+ dim3 grid(m); -+ dim3 block((n + 31)/32*32); -+ if (block.x > 1024){ -+ block.x = 1024; -+ } -+ // TODO : There should be better configs for different cases, we only use several samples to show how to use here -+ // TODO : using registers to store values locally can reduce the loads from global memory and speedup the kernels. 
-+ if ((n % 4 == 0) && (n >= 128) && (n <= 4096)) { -+ block.x = (n/4 + 31)/32*32; -+ if (std::is_same::value) { -+ layernorm_twoPassAlgo_stored_locally_e4<<>>( -+ (float4*)output, -+ (const float4*)input, -+ (const float4*)gamma, -+ (const float4*)beta, -+ m, -+ n); -+ } // if (std::is_same::value) -+ else { -+ layernorm_twoPassAlgo_stored_locally_e4<<>>( -+ (half4*)output, -+ (const half4*)input, -+ (const half4*)gamma, -+ (const half4*)beta, -+ m, -+ n); -+ } -+ } //if ((n % 4 == 0) && (n >= 128) && (n <= 4096)) -+ else if (n % 2 == 0) { -+ if (n / 2 <= 1024) { -+ block.x = (n/2 + 31)/32*32; -+ if (std::is_same::value) { -+ layernorm_twoPassAlgo_stored_locally_e2<<>>( -+ (float2*)output, -+ (const float2*)input, -+ (const float2*)gamma, -+ (const float2*)beta, -+ m, -+ n); -+ } //if (std::is_same::value) -+ else { -+ layernorm_twoPassAlgo_stored_locally_e2<<>>( -+ (half2*)output, -+ (const half2*)input, -+ (const half2*)gamma, -+ (const half2*)beta, -+ m, -+ n); -+ } -+ } // if (n / 2 <= 1024) -+ else if (n <= 8192) { -+ block.x = ((n + 7)/8 + 31)/32*32; -+ if (std::is_same::value) { -+ layernorm_twoPassAlgo_stored_locally_e2<<>>( -+ (float2*)output, -+ (const float2*)input, -+ (const float2*)gamma, -+ (const float2*)beta, -+ m, -+ n); -+ } // if (std::is_same::value) -+ else { -+ layernorm_twoPassAlgo_stored_locally_e2<<>>( -+ (half2*)output, -+ (const half2*)input, -+ (const half2*)gamma, -+ (const half2*)beta, -+ m, -+ n); -+ } -+ } // if (n <= 8192) -+ else if (n <= 16384) { -+ block.x = ((n + 15)/ 16 + 31)/32*32; -+ if (std::is_same::value) { -+ layernorm_twoPassAlgo_stored_locally_e2<<>>( -+ (float2*)output, -+ (const float2*)input, -+ (const float2*)gamma, -+ (const float2*)beta, -+ m, -+ n); -+ } // if (std::is_same::value) -+ else { -+ layernorm_twoPassAlgo_stored_locally_e2<<>>( -+ (half2*)output, -+ (const half2*)input, -+ (const half2*)gamma, -+ (const half2*)beta, -+ m, -+ n); -+ } -+ } // if (n <= 16384) -+ else if (n <= 32768) { -+ block.x = ((n + 31)/32 + 31)/32*32; -+ if (std::is_same::value) { -+ layernorm_twoPassAlgo_stored_locally_e2<<>>( -+ (float2*)output, -+ (const float2*)input, -+ (const float2*)gamma, -+ (const float2*)beta, -+ m, -+ n); -+ } // if (std::is_same::value) -+ else { -+ layernorm_twoPassAlgo_stored_locally_e2<<>>( -+ (half2*)output, -+ (const half2*)input, -+ (const half2*)gamma, -+ (const half2*)beta, -+ m, -+ n); -+ } -+ } // if (n <= 32768) -+ else { -+ if (block.x > 512) -+ block.x = 512; -+ if (std::is_same::value) { -+ layernorm_twoPassAlgo_e2<<>>( -+ (float2 *)output, -+ (const float2 *)input, -+ (const float2 *)gamma, -+ (const float2 *)beta, -+ m, -+ n); -+ } // if (std::is_same::value) -+ else { -+ layernorm_twoPassAlgo_e2<<>>( -+ (half2 *)output, -+ (const half2 *)input, -+ (const half2 *)gamma, -+ (const half2 *)beta, -+ m, -+ n); -+ } -+ } -+ } // if (n % 2 == 0) -+ else { -+ if (n <= 1024) { -+ layernorm_twoPassAlgo_stored_locally_e1<<>>( -+ output, -+ input, -+ gamma, -+ beta, -+ m, -+ n); -+ } // if (n <= 1024) -+ else if (n <= 8192) { -+ block.x = ((n + 7)/8 + 31)/32*32; -+ layernorm_twoPassAlgo_stored_locally_e1<<>>( -+ output, -+ input, -+ gamma, -+ beta, -+ m, -+ n); -+ } // if (n <= 8192) -+ else if (n <= 16384) { -+ block.x = ((n + 15)/16 + 32)/32*32; -+ layernorm_twoPassAlgo_stored_locally_e1<<>>( -+ output, -+ input, -+ gamma, -+ beta, -+ m, -+ n); -+ } // if (n <= 16384) -+ else if (n <= 32768) { -+ block.x = ((n + 31)/32 + 31)/32*32; -+ layernorm_twoPassAlgo_stored_locally_e1<<>>( -+ output, -+ input, -+ gamma, -+ beta, -+ 
m, -+ n); -+ } // if (n <= 32768) -+ else{ -+ if (block.x > 512) { -+ block.x = 512; -+ } -+ layernorm_twoPassAlgo_e1<<>>( -+ output, -+ input, -+ gamma, -+ beta, -+ m, -+ n); -+ } -+ } -+} -+ -+} //namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/device_memory.h b/3rdparty/cutlass/tools/util/include/cutlass/util/device_memory.h -new file mode 100644 -index 0000000..67dfff5 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/device_memory.h -@@ -0,0 +1,338 @@ -+/****************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ ******************************************************************************/ -+ -+#pragma once -+ -+/** -+ * \file -+ * \brief C++ interface to CUDA device memory management functions. 
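The layernorm launcher dispatched above picks its CUDA block size by dividing the row length n by the vector width of the selected kernel and rounding up to a multiple of the warp size (32), e.g. block.x = (n/4 + 31)/32*32 on the float4/half4 path. A minimal host-side sketch of that rounding rule, with hypothetical values of n, is:

    #include <cstdio>

    // Round the per-thread work estimate up to a multiple of the warp size (32),
    // mirroring expressions such as (n/4 + 31)/32*32 in the dispatch above.
    static unsigned block_dim_for(unsigned n, unsigned vector_width) {
      return (n / vector_width + 31) / 32 * 32;
    }

    int main() {
      printf("n=1024, 4-wide path: block.x = %u\n", block_dim_for(1024, 4));  // 256
      printf("n=1536, 2-wide path: block.x = %u\n", block_dim_for(1536, 2));  // 768
      return 0;
    }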
-+ */ -+ -+#include -+ -+#include "cutlass/platform/platform.h" -+#include "cutlass/numeric_types.h" -+#include "exceptions.h" -+ -+namespace cutlass { -+namespace device_memory { -+ -+/****************************************************************************** -+ * Allocation lifetime -+ ******************************************************************************/ -+ -+/// Allocate a buffer of \p count elements of type \p T on the current CUDA device -+template -+T* allocate(size_t count = 1) { -+ -+ T* ptr = 0; -+ size_t bytes = 0; -+ -+ bytes = count * sizeof(T); -+ -+ cudaError_t cuda_error = cudaMalloc((void**)&ptr, bytes); -+ -+ if (cuda_error != cudaSuccess) { -+ throw cuda_exception("Failed to allocate memory", cuda_error); -+ } -+ -+ return ptr; -+} -+ -+/// Free the buffer pointed to by \p ptr -+template -+void free(T* ptr) { -+ if (ptr) { -+ cudaError_t cuda_error = (cudaFree(ptr)); -+ if (cuda_error != cudaSuccess) { -+ throw cuda_exception("Failed to free device memory", cuda_error); -+ } -+ } -+} -+ -+/****************************************************************************** -+ * Data movement -+ ******************************************************************************/ -+ -+template -+void copy(T* dst, T const* src, size_t count, cudaMemcpyKind kind) { -+ size_t bytes = count * sizeof_bits::value / 8; -+ if (bytes == 0 && count > 0) -+ bytes = 1; -+ cudaError_t cuda_error = (cudaMemcpy(dst, src, bytes, kind)); -+ if (cuda_error != cudaSuccess) { -+ throw cuda_exception("cudaMemcpy() failed", cuda_error); -+ } -+} -+ -+template -+void copy_to_device(T* dst, T const* src, size_t count = 1) { -+ copy(dst, src, count, cudaMemcpyHostToDevice); -+} -+ -+template -+void copy_to_host(T* dst, T const* src, size_t count = 1) { -+ copy(dst, src, count, cudaMemcpyDeviceToHost); -+} -+ -+template -+void copy_device_to_device(T* dst, T const* src, size_t count = 1) { -+ copy(dst, src, count, cudaMemcpyDeviceToDevice); -+} -+ -+template -+void copy_host_to_host(T* dst, T const* src, size_t count = 1) { -+ copy(dst, src, count, cudaMemcpyHostToHost); -+} -+ -+/// Copies elements from device memory to host-side range -+template -+void insert_to_host(OutputIterator begin, OutputIterator end, T const* device_begin) { -+ size_t elements = end - begin; -+ copy_to_host(&*begin, device_begin, elements); -+} -+ -+/// Copies elements to device memory from host-side range -+template -+void insert_to_device(T* device_begin, InputIterator begin, InputIterator end) { -+ size_t elements = end - begin; -+ copy_to_device(device_begin, &*begin, elements); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device_memory -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class DeviceAllocation { -+public: -+ -+ /// Delete functor for CUDA device memory -+ struct deleter { -+ void operator()(T* ptr) { -+ cudaError_t cuda_error = (cudaFree(ptr)); -+ if (cuda_error != cudaSuccess) { -+ // noexcept -+ // throw cuda_exception("cudaFree() failed", cuda_error); -+ return; -+ } -+ } -+ }; -+ -+public: -+ // -+ // Data members -+ // -+ -+ /// Number of elements of T allocated on the current CUDA device -+ size_t capacity; -+ -+ /// Smart pointer -+ platform::unique_ptr smart_ptr; -+ -+public: -+ -+ // -+ // Static methods -+ // -+ -+ /// Static member to compute the number of bytes needed for a given number of elements -+ static size_t bytes(size_t elements) { 
-+ if (sizeof_bits::value < 8) { -+ size_t const kElementsPerByte = 8 / sizeof_bits::value; -+ return elements / kElementsPerByte; -+ } -+ else { -+ size_t const kBytesPerElement = sizeof_bits::value / 8; -+ return elements * kBytesPerElement; -+ } -+ } -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor: allocates no memory -+ DeviceAllocation() : capacity(0) {} -+ -+ /// Constructor: allocates \p capacity elements on the current CUDA device -+ DeviceAllocation(size_t _capacity) : -+ smart_ptr(device_memory::allocate(_capacity)), capacity(_capacity) {} -+ -+ /// Constructor: allocates \p capacity elements on the current CUDA device taking ownership of the allocation -+ DeviceAllocation(T *ptr, size_t _capacity) : smart_ptr(ptr), capacity(_capacity) {} -+ -+ /// Copy constructor -+ DeviceAllocation(DeviceAllocation const &p): -+ smart_ptr(device_memory::allocate(p.capacity)), capacity(p.capacity) { -+ -+ device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity); -+ } -+ -+ /// Move constructor -+ DeviceAllocation(DeviceAllocation &&p): capacity(0) { -+ std::swap(smart_ptr, p.smart_ptr); -+ std::swap(capacity, p.capacity); -+ } -+ -+ /// Destructor -+ ~DeviceAllocation() { reset(); } -+ -+ /// Returns a pointer to the managed object -+ T* get() const { return smart_ptr.get(); } -+ -+ /// Releases the ownership of the managed object (without deleting) and resets capacity to zero -+ T* release() { -+ capacity = 0; -+ return smart_ptr.release(); -+ } -+ -+ /// Deletes the managed object and resets capacity to zero -+ void reset() { -+ capacity = 0; -+ smart_ptr.reset(); -+ } -+ -+ /// Deletes managed object, if owned, and allocates a new object -+ void reset(size_t _capacity) { -+ reset(device_memory::allocate(_capacity), _capacity); -+ } -+ -+ /// Deletes managed object, if owned, and replaces its reference with a given pointer and capacity -+ void reset(T* _ptr, size_t _capacity) { -+ smart_ptr.reset(_ptr); -+ capacity = _capacity; -+ } -+ -+ /// Allocates a new buffer and copies the old buffer into it. The old buffer is then released. -+ void reallocate(size_t new_capacity) { -+ -+ platform::unique_ptr new_allocation(device_memory::allocate(new_capacity)); -+ -+ device_memory::copy_device_to_device( -+ new_allocation.get(), -+ smart_ptr.get(), -+ std::min(new_capacity, capacity)); -+ -+ std::swap(smart_ptr, new_allocation); -+ std::swap(new_capacity, capacity); -+ } -+ -+ /// Returns the number of elements -+ size_t size() const { -+ return capacity; -+ } -+ -+ /// Returns the number of bytes needed to store the allocation -+ size_t bytes() const { -+ return bytes(capacity); -+ } -+ -+ /// Returns a pointer to the object owned by *this -+ T* operator->() const { return smart_ptr.get(); } -+ -+ /// Returns the deleter object which would be used for destruction of the managed object. 
-+ deleter& get_deleter() { return smart_ptr.get_deleter(); } -+ -+ /// Returns the deleter object which would be used for destruction of the managed object (const) -+ const deleter& get_deleter() const { return smart_ptr.get_deleter(); } -+ -+ /// Copies a device-side memory allocation -+ DeviceAllocation & operator=(DeviceAllocation const &p) { -+ if (capacity != p.capacity) { -+ smart_ptr.reset(device_memory::allocate(p.capacity)); -+ capacity = p.capacity; -+ } -+ device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity); -+ return *this; -+ } -+ -+ /// Move assignment -+ DeviceAllocation & operator=(DeviceAllocation && p) { -+ std::swap(smart_ptr, p.smart_ptr); -+ std::swap(capacity, p.capacity); -+ return *this; -+ } -+ -+ /// Copies the entire allocation from another location in device memory. -+ void copy_from_device(T const *ptr) const { -+ copy_from_device(ptr, capacity); -+ } -+ -+ /// Copies a given number of elements from device memory -+ void copy_from_device(T const *ptr, size_t elements) const { -+ device_memory::copy_device_to_device(get(), ptr, elements); -+ } -+ -+ void copy_to_device(T *ptr) const { -+ copy_to_device(ptr, capacity); -+ } -+ -+ void copy_to_device(T *ptr, size_t elements) const { -+ device_memory::copy_device_to_device(ptr, get(), elements); -+ } -+ -+ void copy_from_host(T const *ptr) const { -+ copy_from_host(ptr, capacity); -+ } -+ -+ void copy_from_host(T const *ptr, size_t elements) const { -+ device_memory::copy_to_device(get(), ptr, elements); -+ } -+ -+ void copy_to_host(T *ptr) const { -+ copy_to_host(ptr, capacity); -+ } -+ -+ void copy_to_host(T *ptr, size_t elements) const { -+ device_memory::copy_to_host(ptr, get(), elements); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace device_memory { -+ -+/// Device allocation abstraction that tracks size and capacity -+template -+using allocation = cutlass::DeviceAllocation; -+ -+} // namespace device_memory -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h b/3rdparty/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h -new file mode 100644 -index 0000000..8628c7a ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h -@@ -0,0 +1,141 @@ -+/****************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
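A minimal usage sketch for the device_memory helpers and the DeviceAllocation wrapper defined above, assuming the element type is the single template parameter (as in upstream CUTLASS): the buffer is allocated on construction, copies go through device_memory::copy_to_device / copy_to_host, and the memory is released automatically when the wrapper goes out of scope.

    #include <vector>
    #include "cutlass/util/device_memory.h"

    void roundtrip_example() {
      std::vector<float> host(1024, 1.0f);

      // RAII device buffer of 1024 floats; freed in the destructor.
      cutlass::DeviceAllocation<float> device(host.size());

      // Both copies throw cutlass::cuda_exception if the underlying cudaMemcpy fails.
      cutlass::device_memory::copy_to_device(device.get(), host.data(), host.size());
      cutlass::device_memory::copy_to_host(host.data(), device.get(), host.size());
    }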
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ ******************************************************************************/ -+ -+#pragma once -+ -+/** -+ * \file -+ * \brief cuda kernels to transform a device memory tensor from NCHW layout to NHWC layout. -+ */ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/tensor_ref.h" -+ -+namespace cutlass { -+ -+/** \brief interface to transform a device memory tensor from NCHW layout to NHWC layout. -+ * \tparam T: data type -+ */ -+template -+void nchw_to_nhwc(cutlass::Tensor4DCoord input_tensor_size, -+ cutlass::Tensor4DCoord output_tensor_size, -+ TensorRef ref_input, -+ TensorRef ref_output, -+ cudaStream_t stream); -+ -+template -+__global__ void nchw_to_nhwc_kernel(T *output, -+ const T *input, -+ const int n, -+ const int h, -+ const int w, -+ const int c) { -+ const int hw = h*w; -+ const int chw = c*hw; -+ __shared__ T shbuf[32 * (32 + 1)]; -+ const int32_t tid = threadIdx.y*blockDim.x + threadIdx.x; -+ const int32_t wid = tid / 32; -+ const int32_t lid = tid % 32; -+ const int32_t ni = blockIdx.z; -+ const int32_t ci0 = blockIdx.y * 32; -+ const int32_t hwi0 = blockIdx.x * 32; -+ -+ const size_t input_idx = ni * chw + (ci0 + wid) * hw + hwi0; -+ const T *A = input + input_idx; -+ if (hwi0 + lid < hw) { -+ const int lid_x_33 = lid * 33; -+ if ((ci0 + 32) <= c) { -+ int ci = wid; // between 0 and 7 -+ CUTLASS_PRAGMA_UNROLL -+ for (int cLoopIdx = 0; cLoopIdx < 4; cLoopIdx++) { -+ shbuf[lid_x_33 + ci] = A[lid]; -+ A = &A[8 * hw]; -+ ci += 8; -+ } -+ } else { -+ for (int ci = wid; ci < 32; ci += 8) { -+ if ((ci + ci0) < c) { -+ shbuf[lid_x_33 + ci] = A[lid]; -+ } -+ A = &A[8 * hw]; -+ } -+ } -+ } -+ __syncthreads(); -+ -+ const int32_t ciOut = ci0 + lid; -+ output = &output[ni * chw + ciOut]; -+ if (ciOut < c) { -+ if (hwi0 + 32 < hw) { -+ int hwI = wid; -+ CUTLASS_PRAGMA_UNROLL -+ for (int hwLoopIdx = 0; hwLoopIdx < 4; ++hwLoopIdx) { -+ output[(hwi0 + hwI) * c] = shbuf[(hwI)*33 + lid]; -+ hwI += 8; -+ } -+ } else { -+ for (int hwI = wid; hwI < 32; hwI += 8) { -+ if (hwi0 + hwI < hw) { -+ output[(hwi0 + hwI) * c] = shbuf[(hwI)*33 + lid]; -+ } -+ } -+ } -+ } -+} -+ -+template -+void nchw_to_nhwc(cutlass::Tensor4DCoord input_tensor_size, -+ cutlass::Tensor4DCoord output_tensor_size, -+ TensorRef ref_input, -+ TensorRef ref_output, -+ cudaStream_t stream) { -+ -+ assert( -+ input_tensor_size.n() == output_tensor_size.n() && -+ input_tensor_size.c() == output_tensor_size.h() && -+ input_tensor_size.h() == output_tensor_size.w() && -+ input_tensor_size.w() == output_tensor_size.c()); -+ -+ int n = output_tensor_size.n(); -+ int h = output_tensor_size.h(); -+ int w = 
output_tensor_size.w(); -+ int c = output_tensor_size.c(); -+ -+ dim3 grid((h*w + 31)/32, (c + 31)/32, n); -+ dim3 block(32, 8); -+ nchw_to_nhwc_kernel<<>>(ref_output.data(), ref_input.data(), -+ n, h, w, c); -+} -+ -+} //namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h b/3rdparty/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h -new file mode 100644 -index 0000000..86e5fa7 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h -@@ -0,0 +1,276 @@ -+/****************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ ******************************************************************************/ -+ -+#pragma once -+ -+/** -+ * \file -+ * \brief cuda kernels for padding in device memory with NHWC layout. 
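The NCHW→NHWC kernel above stages a 32×32 (channel × pixel) tile in shared memory padded to 33 columns so the transposed accesses do not hit the same bank, with each 32×8 block looping four times over the tile. A simplified, self-contained variant of the same padded-tile transpose (one thread per element, block(32,32); the matrix shape is hypothetical):

    // Transpose a rows x cols row-major matrix using a 32x33 shared-memory tile.
    __global__ void transpose_tile32(float *out, const float *in, int rows, int cols) {
      __shared__ float tile[32][33];                 // extra column avoids bank conflicts

      int x = blockIdx.x * 32 + threadIdx.x;         // source column
      int y = blockIdx.y * 32 + threadIdx.y;         // source row
      if (x < cols && y < rows) {
        tile[threadIdx.y][threadIdx.x] = in[y * cols + x];
      }
      __syncthreads();

      x = blockIdx.y * 32 + threadIdx.x;             // destination column (source row block)
      y = blockIdx.x * 32 + threadIdx.y;             // destination row (source column block)
      if (x < rows && y < cols) {
        out[y * rows + x] = tile[threadIdx.x][threadIdx.y];
      }
    }
    // Launch: transpose_tile32<<<dim3((cols+31)/32, (rows+31)/32), dim3(32, 32)>>>(out, in, rows, cols);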
-+ */ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/tensor_ref.h" -+ -+namespace cutlass { -+ -+/** \brief interface for padding in a device memory tensor with NHWC layout -+ * \tparam T: data type -+ */ -+template -+void nhwc_padding(cutlass::Tensor4DCoord input_tensor_size, -+ cutlass::Tensor4DCoord output_tensor_size, -+ TensorRef ref_input, -+ TensorRef ref_output, -+ cudaStream_t stream); -+ -+ -+template -+__global__ void nhwc_padding_kernel(const int32_t n, -+ const int32_t h, -+ const int32_t w, -+ const int32_t c_in, -+ const int32_t c_out, -+ const T zero, -+ const T *input, -+ T *output){ -+ -+ const int32_t idx_jump = blockDim.x * gridDim.x; -+ const int32_t total_elements = n * h * w * c_out; -+ -+ int32_t c_idx, w_idx, h_idx, n_idx, resudial; -+ -+ T value; -+ for (int32_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total_elements; idx += idx_jump) { -+ -+ c_idx = idx%c_out; -+ if (c_idx >= c_in){ -+ value = zero; -+ } -+ else{ -+ resudial = idx/c_out; -+ w_idx = resudial%w; -+ resudial = resudial/w; -+ h_idx = resudial%h; -+ n_idx = resudial/h; -+ resudial = ((n_idx * h + h_idx) * w + w_idx) * c_in + c_idx; -+ value = input[resudial]; -+ } -+ output[idx] = value; -+ } -+} -+ -+ -+// fast kernel for c_in = 3 & c_out = 4 -+template -+__global__ void nhwc_padding_channel_3To4_kernel(const int32_t n, -+ const int32_t h, -+ const int32_t w, -+ const Tio *input, -+ Tio *output, -+ const int32_t max_output_element, -+ const int32_t max_input_element, -+ const Tio zero_io, -+ const Telement zero_element){ -+ __shared__ Tio shm[192]; -+ const int tidx = blockIdx.x * 192 + threadIdx.x; -+ const int threadidx = threadIdx.x; -+ -+ shm[threadIdx.x] = tidx >= max_input_element ? zero_io : input[tidx]; -+ __syncthreads(); -+ -+ const int ouput_offset = blockIdx.x * 256; -+ const int lower_bound = max_output_element < ouput_offset + 256 ? max_output_element : ouput_offset + 256; -+ for (int i = ouput_offset + threadidx, j = threadidx ; i < lower_bound ; i+=192, j+=192) -+ { -+ const Telement* shm_element = (const Telement*)shm + j*3*element_in_Tio/4; -+ Telement array[element_in_Tio]; -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0 ; k < element_in_Tio ; k++) -+ array[k] = ((k+1)%4 == 0) ? zero_element : shm_element[(k > 3) ? (k - 1) : k]; -+ output[i] = *((const Tio *)array); -+ } -+} -+ -+// fast kernel for c_in = 3 & c_out = 8 -+template -+__global__ void nhwc_padding_channel_3To8_kernel(const int32_t n, -+ const int32_t h, -+ const int32_t w, -+ const Tio *input, -+ Tio *output, -+ const int32_t max_output_element, -+ const int32_t max_input_element, -+ const Tio zero_io, -+ const Telement zero_element){ -+ __shared__ Tio shm[192]; -+ const int tidx = blockIdx.x * 192 + threadIdx.x; -+ const int threadidx = threadIdx.x; -+ -+ shm[threadIdx.x] = tidx >= max_input_element ? zero_io : input[tidx]; -+ __syncthreads(); -+ -+ const int ouput_offset = blockIdx.x * 512; -+ const int lower_bound = max_output_element < ouput_offset + 512 ? max_output_element : ouput_offset + 512; -+ for (int i = ouput_offset + threadidx, j = threadidx ; i < lower_bound ; i+=192, j+=192) -+ { -+ const Telement* shm_element = (const Telement*)shm + (element_in_Tio == 4 ? j/2 : j)*3; -+ Telement array[element_in_Tio]; -+ //float -+ if (element_in_Tio == 4){ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0 ; k < element_in_Tio ; k++) -+ array[k] = ((j % 2) == 1) ? zero_element : ((k >= 3) ? 
zero_element : shm_element[k]); -+ } -+ //half -+ else{ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0 ; k < element_in_Tio ; k++) -+ array[k] = (k >= 3) ? zero_element : shm_element[k]; -+ } -+ output[i] = *((const Tio *)array); -+ } -+} -+ -+template -+void nhwc_padding(cutlass::Tensor4DCoord input_tensor_size, -+ cutlass::Tensor4DCoord output_tensor_size, -+ TensorRef ref_input, -+ TensorRef ref_output, -+ cudaStream_t stream){ -+ assert( -+ input_tensor_size.n() == output_tensor_size.n() && -+ input_tensor_size.h() == output_tensor_size.h() && -+ input_tensor_size.w() == output_tensor_size.w() && -+ input_tensor_size.c() <= output_tensor_size.c()); -+ -+ int n = input_tensor_size.n(); -+ int h = input_tensor_size.h(); -+ int w = input_tensor_size.w(); -+ int c_in = input_tensor_size.c(); -+ int c_out = output_tensor_size.c(); -+ -+ //case 1 : channel == 3 padding to 4 or 8 -+ if ((c_out == 4 || c_out == 8) && c_in == 3 && (n*h*w % 8 == 0)){ -+ dim3 block(192); -+ const int nhw = n*h*w; -+ const int nhwc = nhw*c_in; -+ //for half_t -+ if (cutlass::sizeof_bits::value == 16){ -+ const int element_in_Tio = 8; -+ const int max_input_element = nhwc/element_in_Tio; -+ const int max_output_element = nhw*c_out/element_in_Tio; -+ const int4 zero_io = {0, 0, 0, 0}; -+ const half_t zero_element = static_cast(0.0f); -+ dim3 grid((nhwc + 192*element_in_Tio - 1)/(192*element_in_Tio)); -+ if (c_out == 4){ -+ nhwc_padding_channel_3To4_kernel<<>> -+ (n, h, w, -+ (const int4 *)ref_input.data(), -+ (int4 *)ref_output.data(), -+ max_output_element, -+ max_input_element, -+ zero_io, -+ zero_element); -+ } -+ else if (c_out == 8){ -+ nhwc_padding_channel_3To8_kernel<<>> -+ (n, h, w, -+ (const int4 *)ref_input.data(), -+ (int4 *)ref_output.data(), -+ max_output_element, -+ max_input_element, -+ zero_io, -+ zero_element); -+ } -+ } -+ //for float -+ else{ -+ const int element_in_Tio = 4; -+ const int max_input_element = nhwc/element_in_Tio; -+ const int max_output_element = nhw*c_out/element_in_Tio; -+ const float4 zero_io = {0.0f, 0.0f, 0.0f, 0.0f}; -+ const float zero_element = 0.0f; -+ dim3 grid((nhwc + 192*element_in_Tio - 1)/(192*element_in_Tio)); -+ if (c_out == 4){ -+ nhwc_padding_channel_3To4_kernel<<>> -+ (n, h, w, -+ (const float4 *)ref_input.data(), -+ (float4 *)ref_output.data(), -+ max_output_element, -+ max_input_element, -+ zero_io, -+ zero_element); -+ } -+ else if (c_out == 8){ -+ nhwc_padding_channel_3To8_kernel<<>> -+ (n, h, w, -+ (const float4 *)ref_input.data(), -+ (float4 *)ref_output.data(), -+ max_output_element, -+ max_input_element, -+ zero_io, -+ zero_element); -+ } -+ } -+ } -+ //case 2 : even channel -+ else if ((c_out % 2) == 0 && (c_in % 2) == 0){ -+ int32_t total_elements = n * h * w * c_out / 2; -+ int block_size = 256; -+ dim3 grid((total_elements + 255)/256); -+ dim3 block(block_size); -+ //for half_t -+ if (cutlass::sizeof_bits::value == 16){ -+ const __half2 zero = {0.0f, 0.0f}; -+ nhwc_padding_kernel<<>>(n, h, w, c_in/2, c_out/2, zero, (const __half2*)ref_input.data(), (__half2*)ref_output.data()); -+ } -+ //for float -+ else{ -+ const float2 zero = {0.0f, 0.0f}; -+ nhwc_padding_kernel<<>>(n, h, w, c_in/2, c_out/2, zero, (const float2*)ref_input.data(), (float2*)ref_output.data()); -+ } -+ } -+ //case 3 : odd channel -+ else{ -+ int32_t total_elements = n * h * w * c_out; -+ int block_size = 256; -+ dim3 grid((total_elements + 255)/256); -+ dim3 block(block_size); -+ const T zero = static_cast(0.0f); -+ nhwc_padding_kernel<<>>(n, h, w, c_in, c_out, zero, ref_input.data(), 
ref_output.data()); -+ } -+} -+ -+ -+} //namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h b/3rdparty/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h -new file mode 100644 -index 0000000..6bdf866 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h -@@ -0,0 +1,576 @@ -+/****************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ ******************************************************************************/ -+ -+#pragma once -+ -+/** -+ * \file -+ * \brief cuda kernels to do avg/max pooling on a device memory tensor with NHWC layout. -+ */ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/tensor_ref.h" -+#include "device_utils.h" -+#include -+ -+namespace cutlass { -+ -+/** \brief interface to do avg/max pooling on a device memory tensor with NHWC layout. 
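The general nhwc_padding_kernel above maps output element (n, h, w, c) to the input element with the same coordinates when c < c_in and writes zero otherwise; the channel 3→4 and 3→8 kernels are vectorized fast paths for that same mapping. A host-side reference of the mapping (useful for validating the fast paths) might look like:

    #include <vector>

    // Reference NHWC channel padding: copy c_in channels, zero-fill up to c_out.
    std::vector<float> nhwc_pad_ref(const std::vector<float> &in,
                                    int N, int H, int W, int c_in, int c_out) {
      std::vector<float> out(static_cast<size_t>(N) * H * W * c_out, 0.0f);
      for (int n = 0; n < N; ++n)
        for (int h = 0; h < H; ++h)
          for (int w = 0; w < W; ++w)
            for (int c = 0; c < c_in; ++c)
              out[((static_cast<size_t>(n) * H + h) * W + w) * c_out + c] =
                  in[((static_cast<size_t>(n) * H + h) * W + w) * c_in + c];
      return out;
    }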
-+ * \tparam T: data type -+ */ -+template -+void pooling_nhwc(cutlass::Tensor4DCoord input_tensor_size, -+ cutlass::Tensor4DCoord filter_tensor_size, -+ cutlass::Tensor4DCoord output_tensor_size, -+ cutlass::MatrixCoord padding, -+ cutlass::MatrixCoord stride, -+ TensorRef ref_input, -+ TensorRef ref_output, -+ int poolingType, //0 for avg pooling ; 1 for max pooling -+ cudaStream_t stream); -+ -+/** get the output size of pooling -+ */ -+inline int getOutputSize(int H_W, int padding, int kernel_size, int stride) -+{ -+ return (H_W + 2 * padding - kernel_size) / stride + 1; -+} -+ -+/** -+ * input is [N, H, W, C] -+ * assume stride == kernel_size -+ * output_h = (H + 2*padding_H - kernel_H)/stride_H -+ * output_w = (W + 2*padding_W - kernel_W)/stride_W -+ * output is [N, output_h, output_w, C] -+ * grid(N, output_h, output_w) -+ * block(min(C, 256)) : -+ * each block deals with C elements of output when each thread deals with ((C + 255)/256 element of output) -+*/ -+template -+__global__ void pooling_nhwc_element1_kernel(T* output, -+ const T* input, -+ const int N, -+ const int H, -+ const int W, -+ const int C, -+ const int output_H, -+ const int output_W, -+ const int kernel_H, -+ const int kernel_W, -+ const int stride_H, -+ const int stride_W, -+ const int padding_H, -+ const int padding_W) -+{ -+ const int tid = threadIdx.x; -+ const int n_idx = blockIdx.x; -+ const int output_h_idx = blockIdx.y; -+ const int output_w_idx = blockIdx.z; -+ -+ int h_start_idx = output_h_idx * stride_H - padding_H; -+ int h_end_idx = h_start_idx + kernel_H; -+ h_start_idx = (h_start_idx < 0) ? 0 : h_start_idx; -+ h_end_idx = h_end_idx > H ? H : h_end_idx; -+ -+ int w_start_idx = output_w_idx * stride_W - padding_W; -+ int w_end_idx = w_start_idx + kernel_W; -+ w_start_idx = (w_start_idx < 0) ? 0 : w_start_idx; -+ w_end_idx = w_end_idx > W ? W : w_end_idx; -+ -+ input += n_idx * H * W * C; -+ output += ((n_idx * output_H + output_h_idx) * output_W + output_w_idx) * C; -+ const int kernel_size2 = kernel_H * kernel_W; -+ for (int c_idx = tid; c_idx < C; c_idx += blockDim.x) { -+ float pooling; -+ if (IS_AVG_POOLING){ -+ pooling = 0.0f; -+ } -+ else{ -+ pooling = -FLT_MAX; -+ } -+ for (int h = h_start_idx; h < h_end_idx; h++) { -+ for (int w = w_start_idx; w < w_end_idx; w++) { -+ const int idx = (h * W + w) * C; -+ const float tmp = static_cast(input[idx + c_idx]); -+ if (IS_AVG_POOLING){ -+ pooling = pooling + tmp; -+ } -+ else{ -+ pooling = pooling > tmp ? pooling : tmp; -+ } -+ } -+ } -+ -+ T output_val; -+ if (IS_AVG_POOLING){ -+ output_val = T(pooling/kernel_size2); -+ } -+ else{ -+ output_val = T(pooling); -+ } -+ output[c_idx] = output_val; -+ } -+} -+ -+template -+__global__ void pooling_nhwc_element2_kernel(T2* output, -+ const T2* input, -+ const int N, -+ const int H, -+ const int W, -+ const int C, -+ const int output_H, -+ const int output_W, -+ const int kernel_H, -+ const int kernel_W, -+ const int stride_H, -+ const int stride_W, -+ const int padding_H, -+ const int padding_W) -+{ -+ const int tid = threadIdx.x; -+ const int n_idx = blockIdx.x; -+ const int output_h_idx = blockIdx.y; -+ const int output_w_idx = blockIdx.z; -+ -+ int h_start_idx = output_h_idx * stride_H - padding_H; -+ int h_end_idx = h_start_idx + kernel_H; -+ h_start_idx = (h_start_idx < 0) ? 0 : h_start_idx; -+ h_end_idx = h_end_idx > H ? H : h_end_idx; -+ -+ int w_start_idx = output_w_idx * stride_W - padding_W; -+ int w_end_idx = w_start_idx + kernel_W; -+ w_start_idx = (w_start_idx < 0) ? 
0 : w_start_idx; -+ w_end_idx = w_end_idx > W ? W : w_end_idx; -+ -+ input += n_idx * H * W * C; -+ output += ((n_idx * output_H + output_h_idx) * output_W + output_w_idx) * C; -+ const int kernel_size2 = kernel_H * kernel_W; -+ for (int c_idx = tid; c_idx < C; c_idx += blockDim.x) { -+ float2 pooling; -+ if (IS_AVG_POOLING) { -+ pooling = {0.0f, 0.0f}; -+ } -+ else { -+ pooling = {-FLT_MAX, -FLT_MAX}; -+ } -+ for (int h = h_start_idx; h < h_end_idx; h++) { -+ for (int w = w_start_idx; w < w_end_idx; w++) { -+ const int idx = (h * W + w) * C; -+ const T2 tmp = input[idx + c_idx]; -+ const float2 tmp_flt2 = {static_cast(tmp.x), static_cast(tmp.y)}; -+ if (IS_AVG_POOLING) { -+ pooling.x += tmp_flt2.x; -+ pooling.y += tmp_flt2.y; -+ } -+ else { -+ pooling.x = pooling.x > tmp_flt2.x ? pooling.x : tmp_flt2.x; -+ pooling.y = pooling.y > tmp_flt2.y ? pooling.y : tmp_flt2.y; -+ } -+ } -+ } -+ -+ T2 output_val; -+ if (IS_AVG_POOLING) { -+ output_val.x = T(pooling.x/kernel_size2); -+ output_val.y = T(pooling.y/kernel_size2); -+ } -+ else { -+ output_val.x = T(pooling.x); -+ output_val.y = T(pooling.y); -+ } -+ output[c_idx] = output_val; -+ } -+} -+ -+/** -+ * output [N, 1, 1, C] -+ * input [N, H, W, C] -+ * grid(C, N) -+ * block(block_size) -- each block deals with H*W/block_size elements; -+*/ -+template -+__global__ void pooling_nxhTo1x1_element1_kernel( -+ T* output, const T* input, const int N, const int HW, const int C) -+{ -+ const int c_idx = blockIdx.x; -+ const int n_idx = blockIdx.y; -+ float pooling[1]; -+ if (IS_AVG_POOLING) { -+ pooling[0] = 0.0f; -+ } -+ else { -+ pooling[0] = -FLT_MAX; -+ } -+ const size_t input_offset = n_idx * HW * C + c_idx; -+ input += input_offset; -+ const size_t output_offset = n_idx * C + c_idx; -+ output += output_offset; -+ int tid = threadIdx.x; -+ -+ for (int index = tid; index < HW; index += blockDim.x) { -+ float val = static_cast(input[index * C]); -+ if (IS_AVG_POOLING) { -+ pooling[0] += val; -+ } -+ else { -+ pooling[0] = pooling[0] > val ? pooling[0] : val; -+ } -+ } -+ if (blockDim.x <= 32) { -+ if (IS_AVG_POOLING) { -+ warpReduceSum(pooling); -+ } -+ else { -+ warpReduceMax(pooling); -+ } -+ } -+ else { -+ if (IS_AVG_POOLING) { -+ blockReduceSum(pooling); -+ } -+ else { -+ blockReduceMax(pooling); -+ } -+ } -+ __syncthreads(); -+ if (threadIdx.x == 0) { -+ T output_val; -+ if (IS_AVG_POOLING) { -+ output_val = T(pooling[0] / HW); -+ } -+ else { -+ output_val = T(pooling[0]); -+ } -+ output[0] = output_val; -+ } -+} -+ -+ -+/** -+ * output [N, 1, 1, C] -+ * input [N, H, W, C] -+ * grid(C/2, N) -+ * block(block_size) -- each thread deals with H*W/block_size * 2 elements; -+*/ -+template -+__global__ void pooling_nxhTo1x1_element2_kernel( -+ T2* output, const T2* input, const int N, const int HW, const int C) -+{ -+ const int c_idx = blockIdx.x; -+ const int n_idx = blockIdx.y; -+ float pooling[2]; -+ if (IS_AVG_POOLING) { -+ pooling[0] = pooling[1] = 0.0f; -+ } -+ else { -+ pooling[0] = pooling[1] = -FLT_MAX; -+ } -+ const int C_2 = C / 2; -+ const size_t input_offset = n_idx * HW * C_2 + c_idx; -+ input += input_offset; -+ const size_t output_offset = n_idx * C_2 + c_idx; -+ output += output_offset; -+ int tid = threadIdx.x; -+ -+ for (int index = tid; index < HW; index += blockDim.x) { -+ T2 val = input[index * C_2]; -+ float2 val_flt2 = {static_cast(val.x), static_cast(val.y)}; -+ if (IS_AVG_POOLING) { -+ pooling[0] += val_flt2.x; -+ pooling[1] += val_flt2.y; -+ } -+ else { -+ pooling[0] = pooling[0] > val_flt2.x ? 
pooling[0] : val_flt2.x; -+ pooling[1] = pooling[1] > val_flt2.y ? pooling[1] : val_flt2.y; -+ } -+ } -+ if (blockDim.x <= 32) { -+ if (IS_AVG_POOLING) { -+ warpReduceSum(pooling); -+ } -+ else { -+ warpReduceMax(pooling); -+ } -+ } -+ else { -+ if (IS_AVG_POOLING) { -+ blockReduceSum(pooling); -+ } -+ else { -+ blockReduceMax(pooling); -+ } -+ } -+ __syncthreads(); -+ if (threadIdx.x == 0) { -+ T2 output_val; -+ if (IS_AVG_POOLING) { -+ output_val.x = T(pooling[0] / HW); -+ output_val.y = T(pooling[1] / HW); -+ } -+ else { -+ output_val.x = T(pooling[0]); -+ output_val.y = T(pooling[1]); -+ } -+ output[0] = output_val; -+ } -+} -+ -+template -+void pooling_nhwc(cutlass::Tensor4DCoord input_tensor_size, -+ cutlass::Tensor4DCoord filter_tensor_size, -+ cutlass::Tensor4DCoord output_tensor_size, -+ cutlass::Tensor4DCoord padding, -+ cutlass::MatrixCoord stride, -+ TensorRef ref_input, -+ TensorRef ref_output, -+ int poolingType, //0 for avg pooling ; 1 for max pooling -+ cudaStream_t stream) { -+ -+ assert(input_tensor_size.n() == output_tensor_size.n() && -+ input_tensor_size.c() == output_tensor_size.c()); -+ -+ assert(filter_tensor_size.h() == stride.row() && -+ filter_tensor_size.w() == stride.column()); -+ -+ const int N = input_tensor_size.n(); -+ const int H = input_tensor_size.h(); -+ const int W = input_tensor_size.w(); -+ const int C = input_tensor_size.c(); -+ const int padding_H = padding.h(); -+ const int padding_W = padding.w(); -+ const int kernel_H = filter_tensor_size.h(); -+ const int kernel_W = filter_tensor_size.w(); -+ const int stride_H = stride.row(); -+ const int stride_W = stride.column(); -+ -+ const int output_H = getOutputSize(H, padding_H, kernel_H, stride_H); -+ const int output_W = getOutputSize(W, padding_W, kernel_W, stride_W); -+ -+ assert(output_tensor_size.h() == output_H && -+ output_tensor_size.w() == output_W); -+ -+ if (C % 2 != 0) { -+ if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) { -+ dim3 grid(C, N); -+ dim3 block(256); -+ if (H*W < block.x){ -+ block.x = (H*W + 31)/32*32; -+ } -+ if (poolingType == 0) { -+ pooling_nxhTo1x1_element1_kernel<<>>( -+ ref_output.data(), -+ ref_input.data(), -+ N, -+ H*W, -+ C); -+ } // if (poolingType == 0) -+ else { -+ pooling_nxhTo1x1_element1_kernel<<>>( -+ ref_output.data(), -+ ref_input.data(), -+ N, -+ H*W, -+ C); -+ } -+ } // if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) -+ else { -+ dim3 grid(N, output_H, output_W); -+ dim3 block(256); -+ if (C < block.x) { -+ block.x = C; -+ } -+ if (poolingType == 0) { -+ pooling_nhwc_element1_kernel<<>>( -+ ref_output.data(), -+ ref_input.data(), -+ N, -+ H, -+ W, -+ C, -+ output_H, -+ output_W, -+ kernel_H, -+ kernel_W, -+ stride_H, -+ stride_W, -+ padding_H, -+ padding_W); -+ } // if (poolingType == 0) -+ else { -+ pooling_nhwc_element1_kernel<<>>( -+ ref_output.data(), -+ ref_input.data(), -+ N, -+ H, -+ W, -+ C, -+ output_H, -+ output_W, -+ kernel_H, -+ kernel_W, -+ stride_H, -+ stride_W, -+ padding_H, -+ padding_W); -+ } -+ } -+ } // if (C % 2 != 0)) -+ else { -+ if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) { -+ dim3 grid(C/2, N); -+ dim3 block(256); -+ if (H*W < block.x){ -+ block.x = (H*W + 31)/32*32; -+ } -+ if (poolingType == 0) { -+ if (std::is_same::value) { -+ pooling_nxhTo1x1_element2_kernel<<>>( -+ (float2*)(ref_output.data()), -+ (const float2*)(ref_input.data()), -+ N, -+ H*W, -+ C); -+ } // if (std::is_same::value) -+ else { -+ pooling_nxhTo1x1_element2_kernel<<>>( -+ 
(half2*)(ref_output.data()), -+ (const half2*)(ref_input.data()), -+ N, -+ H*W, -+ C); -+ } -+ } // if (poolingType == 0) -+ else { -+ if (std::is_same::value) { -+ pooling_nxhTo1x1_element2_kernel<<>>( -+ (float2*)(ref_output.data()), -+ (const float2*)(ref_input.data()), -+ N, -+ H*W, -+ C); -+ } // if (std::is_same::value) -+ else { -+ pooling_nxhTo1x1_element2_kernel<<>>( -+ (half2*)(ref_output.data()), -+ (const half2*)(ref_input.data()), -+ N, -+ H*W, -+ C); -+ } -+ } -+ } // if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) -+ else { -+ dim3 grid(N, output_H, output_W); -+ dim3 block(256); -+ if (C/2 < block.x) { -+ block.x = C/2; -+ } -+ if (poolingType == 0) { -+ if (std::is_same::value) { -+ pooling_nhwc_element2_kernel<<>>( -+ (float2*)(ref_output.data()), -+ (const float2*)(ref_input.data()), -+ N, -+ H, -+ W, -+ C/2, -+ output_H, -+ output_W, -+ kernel_H, -+ kernel_W, -+ stride_H, -+ stride_W, -+ padding_H, -+ padding_W); -+ } // if (std::is_same::value) -+ else { -+ pooling_nhwc_element2_kernel<<>>( -+ (half2*)(ref_output.data()), -+ (const half2*)(ref_input.data()), -+ N, -+ H, -+ W, -+ C/2, -+ output_H, -+ output_W, -+ kernel_H, -+ kernel_W, -+ stride_H, -+ stride_W, -+ padding_H, -+ padding_W); -+ } -+ } // if (poolingType == 0) -+ else { -+ if (std::is_same::value) { -+ pooling_nhwc_element2_kernel<<>>( -+ (float2*)(ref_output.data()), -+ (const float2*)(ref_input.data()), -+ N, -+ H, -+ W, -+ C/2, -+ output_H, -+ output_W, -+ kernel_H, -+ kernel_W, -+ stride_H, -+ stride_W, -+ padding_H, -+ padding_W); -+ } // if (std::is_same::value) -+ else { -+ pooling_nhwc_element2_kernel<<>>( -+ (half2*)(ref_output.data()), -+ (const half2*)(ref_input.data()), -+ N, -+ H, -+ W, -+ C/2, -+ output_H, -+ output_W, -+ kernel_H, -+ kernel_W, -+ stride_H, -+ stride_W, -+ padding_H, -+ padding_W); -+ } -+ } -+ } -+ } -+} -+ -+} //namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h b/3rdparty/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h -new file mode 100644 -index 0000000..d71fd1e ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h -@@ -0,0 +1,144 @@ -+/****************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
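getOutputSize above is the standard pooling shape rule, (H + 2*padding - kernel)/stride + 1, and pooling_nhwc additionally asserts kernel == stride (non-overlapping windows). For example, a 112×112 input with a 2×2 window, stride 2 and no padding produces a 56×56 output:

    #include <cassert>

    inline int getOutputSizeRef(int H_W, int padding, int kernel_size, int stride) {
      return (H_W + 2 * padding - kernel_size) / stride + 1;
    }

    int main() {
      // 112x112 input, 2x2 window, stride 2, no padding -> 56x56 output.
      assert(getOutputSizeRef(112, 0, 2, 2) == 56);
      return 0;
    }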
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ ******************************************************************************/ -+ -+#pragma once -+ -+/** -+ * \file -+ * \brief cuda kernels to transform a device memory tensor from NHWC layout to NCHW layout. -+ */ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/tensor_ref.h" -+ -+namespace cutlass { -+ -+/** \brief interface to transform a device memory tensor from NHWC layout to NCHW layout. -+ * \tparam T: data type -+ */ -+template -+void nhwc_to_nchw(cutlass::Tensor4DCoord input_tensor_size, -+ cutlass::Tensor4DCoord output_tensor_size, -+ TensorRef ref_input, -+ TensorRef ref_output, -+ cudaStream_t stream); -+ -+ -+template -+__global__ void nhwc_to_nchw_kernel(T *output, -+ const T *input, -+ const int n, -+ const int h, -+ const int w, -+ const int c) { -+ -+ const int hw = h*w; -+ const int hwc = hw*c; -+ __shared__ T shbuf[32 * (32 + 1)]; -+ const int32_t tid = threadIdx.y*blockDim.x + threadIdx.x; -+ const int32_t wid = tid / 32; -+ const int32_t lid = tid % 32; -+ const int32_t ni = blockIdx.z; -+ const int32_t hwi0 = blockIdx.y * 32; -+ const int32_t ci0 = blockIdx.x * 32; -+ -+ const size_t input_idx = ni * hwc + (hwi0 + wid) * c + ci0; -+ const T *A = input + input_idx; -+ if (ci0 + lid < c) { -+ const int lid_x_33 = lid * 33; -+ if ((hwi0 + 32) <= hw) { -+ int hwi = wid; // between 0 and 7 -+ CUTLASS_PRAGMA_UNROLL -+ for (int cLoopIdx = 0; cLoopIdx < 4; cLoopIdx++) { -+ shbuf[lid_x_33 + hwi] = A[lid]; -+ A = &A[8 * c]; -+ hwi += 8; -+ } -+ } else { -+ for (int hwi = wid; hwi < 32; hwi += 8) { -+ if ((hwi + hwi0) < hw) { -+ shbuf[lid_x_33 + hwi] = A[lid]; -+ } -+ A = &A[8 * c]; -+ } -+ } -+ } -+ __syncthreads(); -+ -+ const int32_t hwiOut = hwi0 + lid; -+ output = &output[ni * hwc + hwiOut]; -+ if (hwiOut < hw) { -+ if (ci0 + 32 < c) { -+ int cI = wid; -+ CUTLASS_PRAGMA_UNROLL -+ for (int hwLoopIdx = 0; hwLoopIdx < 4; ++hwLoopIdx) { -+ output[(ci0 + cI) * hw] = shbuf[(cI)*33 + lid]; -+ cI += 8; -+ } -+ } else { -+ for (int cI = wid; cI < 32; cI += 8) { -+ if (ci0 + cI < c) { -+ output[(ci0 + cI) * hw] = shbuf[(cI)*33 + lid]; -+ } -+ } -+ } -+ } -+} -+ -+template -+void nhwc_to_nchw(cutlass::Tensor4DCoord input_tensor_size, -+ cutlass::Tensor4DCoord output_tensor_size, -+ TensorRef ref_input, -+ TensorRef ref_output, -+ cudaStream_t stream) { -+ -+ assert( -+ input_tensor_size.n() == output_tensor_size.n() && -+ input_tensor_size.h() == output_tensor_size.c() && -+ input_tensor_size.w() == output_tensor_size.h() && -+ input_tensor_size.c() == output_tensor_size.w()); -+ -+ int n = input_tensor_size.n(); -+ int h = input_tensor_size.h(); -+ int w = input_tensor_size.w(); -+ int c = input_tensor_size.c(); -+ -+ dim3 grid((c + 31)/32, (h*w + 31)/32, n); -+ dim3 block(32, 8); -+ nhwc_to_nchw_kernel<<>>(ref_output.data(), ref_input.data(), -+ n, h, w, c); -+ -+} -+ -+} //namespace cutlass -diff --git 
a/3rdparty/cutlass/tools/util/include/cutlass/util/device_utils.h b/3rdparty/cutlass/tools/util/include/cutlass/util/device_utils.h -new file mode 100644 -index 0000000..00414a5 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/device_utils.h -@@ -0,0 +1,127 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief utils code for device cutlass code -+*/ -+ -+#pragma once -+ -+#include -+#include -+#define FINAL_MASK 0xffffffff -+ -+struct half4 { -+ half x, y, z, w; -+}; -+ -+template -+__inline__ __device__ T warpReduceSum(T* val) -+{ -+#pragma unroll -+ for (int i = 0; i < NUM; i++) { -+#pragma unroll -+ for (int mask = 16; mask > 0; mask >>= 1) -+ val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); -+ } -+ return (T)(0.0f); -+} -+ -+template -+__inline__ __device__ T blockReduceSum(T* val) -+{ -+ __shared__ T shared[NUM][33]; -+ int lane = threadIdx.x & 0x1f; -+ int wid = threadIdx.x >> 5; -+ -+ warpReduceSum(val); -+ -+ if (lane == 0) { -+#pragma unroll -+ for (int i = 0; i < NUM; i++) { -+ shared[i][wid] = val[i]; -+ } -+ } -+ -+ __syncthreads(); -+ -+ bool is_mask = threadIdx.x < (blockDim.x / 32.f); -+#pragma unroll -+ for (int i = 0; i < NUM; i++) { -+ val[i] = is_mask ? 
shared[i][lane] : (T)(0.0f); -+ } -+ warpReduceSum(val); -+ return (T)0.0f; -+} -+ -+template -+__inline__ __device__ T warpReduceMax(T* val) -+{ -+#pragma unroll -+ for (int i = 0; i < NUM; i++) { -+#pragma unroll -+ for (int mask = 16; mask > 0; mask >>= 1) -+ val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32)); -+ } -+ return (T)(0.0f); -+} -+ -+template -+__inline__ __device__ T blockReduceMax(T* val) -+{ -+ static __shared__ T shared[32][NUM]; -+ int lane = threadIdx.x & 0x1f; // in-warp idx -+ int wid = threadIdx.x >> 5; // warp idx -+ -+ warpReduceMax(val); // get maxx in each warp -+ -+ if (lane == 0) // record in-warp maxx by warp Idx -+ { -+#pragma unroll -+ for (int i = 0; i < NUM; i++) { -+ shared[wid][i] = val[i]; -+ } -+ } -+ -+ __syncthreads(); -+ -+ // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent -+ // blockDim.x is not divided by 32 -+ bool is_mask = threadIdx.x < (blockDim.x / 32.f); -+#pragma unroll -+ for (int i = 0; i < NUM; i++) { -+ val[i] = is_mask ? shared[lane][i] : (T)(-FLT_MAX); -+ } -+ warpReduceMax(val); -+ -+ return (T)0.0f; -+} -+ -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/distribution.h b/3rdparty/cutlass/tools/util/include/cutlass/util/distribution.h -new file mode 100644 -index 0000000..7fee888 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/distribution.h -@@ -0,0 +1,143 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+/*! \file -+ \brief This header contains a class to parametrize a statistical distribution function. 
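warpReduceSum / warpReduceMax above are the usual xor-shuffle butterfly: five exchange steps with strides 16, 8, 4, 2, 1 leave every lane holding the reduction over all 32 lanes, and blockReduce* then combines one partial per warp through shared memory. A self-contained single-value version of the warp-level step:

    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void warp_sum(float *out) {
      float v = static_cast<float>(threadIdx.x);        // lane values 0..31
      // Butterfly reduction: xor-shuffle with strides 16, 8, 4, 2, 1.
      for (int mask = 16; mask > 0; mask >>= 1) {
        v += __shfl_xor_sync(0xffffffff, v, mask, 32);
      }
      if (threadIdx.x == 0) {
        *out = v;                                       // 0 + 1 + ... + 31 = 496
      }
    }

    int main() {
      float *d = nullptr, h = 0.0f;
      cudaMalloc(&d, sizeof(float));
      warp_sum<<<1, 32>>>(d);
      cudaMemcpy(&h, d, sizeof(float), cudaMemcpyDeviceToHost);
      printf("warp sum = %f\n", h);                     // expected 496
      cudaFree(d);
      return 0;
    }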
-+*/ -+ -+#include -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Distribution type -+struct Distribution { -+ /// Variant types -+ enum Kind { Invalid, Uniform, Gaussian, Identity, Sequential, AllZeros, AllOnes }; -+ -+ /// Distribution state -+ union { -+ /// Uniform distribution -+ struct { -+ double min; -+ double max; -+ } uniform; -+ -+ /// Gaussian distribution -+ struct { -+ double mean; -+ double stddev; -+ } gaussian; -+ -+ /// Elements are linear combination of row and column index -+ struct { -+ double start; -+ double delta; -+ } sequential; -+ }; -+ -+ /// Active variant kind -+ Kind kind; -+ -+ /// Random values are cast to integer after scaling by this power of two -+ int int_scale; -+ -+ // -+ // Methods -+ // -+ -+ Distribution() : kind(Invalid), int_scale(0) {} -+ -+ /// Configures distribution as uniform random -+ Distribution &set_uniform(double _min, double _max, int _int_scale = 0) { -+ kind = Uniform; -+ uniform.min = _min; -+ uniform.max = _max; -+ int_scale = _int_scale; -+ return *this; -+ } -+ -+ /// Configures distribution as Gaussian distribution -+ Distribution &set_gaussian(double _mean, double _stddev, int _int_scale = 0) { -+ kind = Gaussian; -+ gaussian.mean = _mean; -+ gaussian.stddev = _stddev; -+ int_scale = _int_scale; -+ return *this; -+ } -+ -+ /// Sets identity -+ Distribution &set_identity() { -+ kind = Identity; -+ return *this; -+ } -+ -+ /// Sets sequential -+ Distribution &set_sequential(double start, double delta, int _int_scale = 0) { -+ kind = Sequential; -+ sequential.start = start; -+ sequential.delta = delta; -+ int_scale = _int_scale; -+ return *this; -+ } -+}; -+ -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Prints a Distribution to ostream -+inline std::ostream &operator<<(std::ostream &out, cutlass::Distribution const &dist) { -+ switch (dist.kind) { -+ case cutlass::Distribution::Uniform: -+ out << "uniform, min: " << dist.uniform.min << ", max: " << dist.uniform.max; -+ break; -+ case cutlass::Distribution::Gaussian: -+ out << "gaussian, mean: " << dist.gaussian.mean << ", stddev: " << dist.gaussian.stddev; -+ break; -+ case cutlass::Distribution::Identity: -+ out << "identity"; -+ break; -+ case cutlass::Distribution::Sequential: -+ out << "sequential"; -+ break; -+ default: -+ out << "unknown"; -+ } -+ -+ out << ", int_scale: " << dist.int_scale; -+ -+ return out; -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/exceptions.h b/3rdparty/cutlass/tools/util/include/cutlass/util/exceptions.h -new file mode 100644 -index 0000000..a349d49 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/exceptions.h -@@ -0,0 +1,69 @@ -+/****************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ ******************************************************************************/ -+ -+#pragma once -+ -+/** -+ * \file -+ * \brief C++ exception semantics for CUDA error codes -+ */ -+ -+#include -+#include -+#include -+ -+#include "cutlass/platform/platform.h" -+ -+namespace cutlass { -+ -+/// C++ exception wrapper for CUDA \p cudaError_t -+class cuda_exception : public std::exception { -+ public: -+ /// Constructor -+ cuda_exception(const char* msg = "", cudaError_t err = cudaErrorUnknown) : msg(msg), err(err) {} -+ -+ /// Returns the underlying CUDA \p cudaError_t -+ cudaError_t cudaError() const { return err; } -+ -+ protected: -+ /// Explanatory string -+ const char* msg; -+ -+ /// Underlying CUDA \p cudaError_t -+ cudaError_t err; -+}; -+ -+/// Writes a cuda_exception instance to an output stream -+inline std::ostream& operator<<(std::ostream& out, cuda_exception const& e) { -+ return out << e.what() << ": " << cudaGetErrorString(e.cudaError()); -+} -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp b/3rdparty/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp -new file mode 100644 -index 0000000..15e0bc8 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp -@@ -0,0 +1,116 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
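A short sketch of how the Distribution parametrization above is typically consumed: configure one variant, then branch on kind or print it via the provided stream operator. The values are hypothetical.

    #include <iostream>
    #include "cutlass/util/distribution.h"

    int main() {
      cutlass::Distribution dist;
      dist.set_uniform(-1.0, 1.0, /*int_scale=*/0);   // uniform in [-1, 1]

      if (dist.kind == cutlass::Distribution::Uniform) {
        std::cout << dist << std::endl;               // "uniform, min: -1, max: 1, int_scale: 0"
      }
      return 0;
    }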
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+void -+device_init(int device_id, bool quiet = false) -+{ -+ cudaDeviceProp device_prop; -+ std::size_t device_free_physmem; -+ std::size_t device_total_physmem; -+ -+ CUTE_CHECK_ERROR(cudaSetDevice(device_id)); -+ CUTE_CHECK_ERROR(cudaMemGetInfo(&device_free_physmem, &device_total_physmem)); -+ CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id)); -+ -+ if (device_prop.major < 1) { -+ fprintf(stderr, "Device does not support CUDA.\n"); -+ exit(1); -+ } -+ -+ //float device_giga_bandwidth = float(device_prop.memoryBusWidth) * device_prop.memoryClockRate * 2 / 8 / 1000 / 1000; -+ -+ if (!quiet) { -+ printf("Using device %d: %s (SM%d, %d SMs)\n", -+ device_id, device_prop.name, -+ device_prop.major * 10 + device_prop.minor, -+ device_prop.multiProcessorCount); -+ fflush(stdout); -+ } -+} -+ -+/** -+ * Convert the SM version (e.g. v7.0, v7.5) to the physical number of cores. -+ */ -+inline int -+_ConvertSMVer2Cores(int major, int minor) -+{ -+ // Defines for GPU Architecture types (using the SM version to determine -+ // the # of cores per SM -+ typedef struct { -+ int SM; // 0xMm (hexidecimal notation), M = SM Major version, -+ // and m = SM minor version -+ int Cores; -+ } sSMtoCores; -+ -+ sSMtoCores nGpuArchCoresPerSM[] = { -+ {0x30, 192}, -+ {0x32, 192}, -+ {0x35, 192}, -+ {0x37, 192}, -+ {0x50, 128}, -+ {0x52, 128}, -+ {0x53, 128}, -+ {0x60, 64}, -+ {0x61, 128}, -+ {0x62, 128}, -+ {0x70, 64}, -+ {0x72, 64}, -+ {0x75, 64}, -+ {-1, -1}}; -+ -+ int index = 0; -+ -+ while (nGpuArchCoresPerSM[index].SM != -1) { -+ if (nGpuArchCoresPerSM[index].SM == ((major << 4) + minor)) { -+ return nGpuArchCoresPerSM[index].Cores; -+ } -+ index++; -+ } -+ -+ // If we don't find the values, we default use the previous one -+ // to run properly -+ printf("MapSMtoCores for SM %d.%d is undefined." -+ " Default to use %d Cores/SM\n", -+ major, minor, nGpuArchCoresPerSM[index - 1].Cores); -+ -+ return nGpuArchCoresPerSM[index - 1].Cores; -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/host_reorder.h b/3rdparty/cutlass/tools/util/include/cutlass/util/host_reorder.h -new file mode 100644 -index 0000000..c17c0a2 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/host_reorder.h -@@ -0,0 +1,111 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief reorder data from the host side -+*/ -+ -+#pragma once -+ -+#include "cutlass/coord.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+namespace cutlass { -+ -+/// This is needed for the interleaved integer tensor core kernels. The purpose -+/// is to use skip the shared memory part in the epilogue. -+template -+void reorder_column(TensorRef dest, -+ TensorRef src, -+ cutlass::gemm::GemmCoord problem_size) { -+ const int InstructionShapeCol = 8; -+ // 4 threads per Quad -+ const int ElementsPerThread = InstructionShapeCol / 4; -+ // 4 threads per Quad -+ const int ReorderedElementsPerThread = -+ Interleaved / 4; -+ -+ for (int n = 0; n < problem_size.n(); n++) { -+ for (int k = 0; k < problem_size.k(); k++) { -+ dest.at({k, (n / Interleaved) * Interleaved + -+ ((n % ReorderedElementsPerThread) / ElementsPerThread) * -+ InstructionShapeCol + -+ ((n % Interleaved) / ReorderedElementsPerThread) * -+ ElementsPerThread + -+ (n % ElementsPerThread)}) = src.at({k, n}); -+ } -+ } -+} -+ -+template -+void reorder_convK(TensorRef dest, -+ TensorRef src, -+ cutlass::gemm::GemmCoord problem_size) { -+ -+ TensorRef> mappedDest(dest.data(), dest.stride(0)); -+ TensorRef> mappedSrc(src.data(), src.stride(0)); -+ -+ reorder_column( -+ mappedDest, mappedSrc, problem_size); -+} -+ -+/// This is needed for the sparse tensor core kernels. The purpose -+/// is to use ldmatrix to load from shared memory to the register file. -+template -+void reorder_meta(TensorRef dest, -+ TensorRef src, -+ cutlass::gemm::GemmCoord problem_size) { -+ for (int m = 0; m < problem_size.m(); m++) { -+ for (int k = 0; k < problem_size.k(); k++) { -+ // First reorder the rows. -+ int group = (sizeof(Element) == 2) ? 
32 : 16; -+ int interweave = (sizeof(Element) == 2) ? 4 : 2; -+ -+ int dest_row = m / group * group + (m % 8) * interweave + (m % group) / 8; -+ int dest_col = k; -+ -+ // Next swizzle the 2x2 blocks from Z to N. -+ if (((dest_row % 2) == 0) && ((dest_col % 2) == 1)) { -+ ++dest_row; -+ --dest_col; -+ } else if (((dest_row % 2) == 1) && ((dest_col % 2) == 0)) { -+ --dest_row; -+ ++dest_col; -+ } -+ -+ dest.at({dest_row, dest_col}) = src.at({m, k}); -+ } -+ } -+} -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/host_tensor.h b/3rdparty/cutlass/tools/util/include/cutlass/util/host_tensor.h -new file mode 100644 -index 0000000..9909ee9 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/host_tensor.h -@@ -0,0 +1,507 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+/*! \file -+ \brief HostTensor contributes management for both host and device memory. -+ -+ HostTensor allocates host and device memory upon construction. Basic element-wise operations on -+ host memory synchronize device memory automatically. Explicit copy operations provide abstractions -+ for CUDA memcpy operations. -+ -+ Call {host, device}_{data, ref, view}() for accessing host or device memory. -+ -+ See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details. 
-+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+ -+#include "device_memory.h" -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Host tensor -+template < -+ /// Data type of element stored within tensor (concept: NumericType) -+ typename Element_, -+ /// Defines a mapping from logical coordinate to linear memory (concept: Layout) -+ typename Layout_ -+> -+class HostTensor { -+public: -+ -+ /// Data type of individual access -+ using Element = Element_; -+ -+ /// Mapping function from logical coordinate to linear memory -+ using Layout = Layout_; -+ -+ /// Logical rank of tensor index space -+ static int const kRank = Layout::kRank; -+ -+ /// Index type -+ using Index = typename Layout::Index; -+ -+ /// Long index used for pointer offsets -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// Coordinate in logical tensor space -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ /// Layout's stride vector -+ using Stride = typename Layout::Stride; -+ -+ /// Tensor reference to device memory -+ using TensorRef = TensorRef; -+ -+ /// Tensor reference to constant device memory -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ /// Tensor reference to device memory -+ using TensorView = TensorView; -+ -+ /// Tensor reference to constant device memory -+ using ConstTensorView = typename TensorView::ConstTensorView; -+ -+ /// Reference to element in tensor -+ using Reference = typename TensorRef::Reference; -+ -+ /// Constant reference to element in tensor -+ using ConstReference = typename ConstTensorRef::Reference; -+ -+ /// Used to handle packing of subbyte elements -+ static int const kElementsPerStoredItem = (sizeof_bits::value < 8 ? (8 / sizeof_bits::value) : 1); -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Extent of tensor in logical dimensions -+ TensorCoord extent_; -+ -+ /// Layout object -+ Layout layout_; -+ -+ /// Host-side memory allocation -+ std::vector host_; -+ -+ /// Device-side memory -+ device_memory::allocation device_; -+ -+ public: -+ // -+ // Device and Host Methods -+ // -+ -+ /// Default constructor -+ HostTensor() {} -+ -+ /// Constructs a tensor given an extent. Assumes a packed layout -+ HostTensor( -+ TensorCoord const &extent, -+ bool device_backed = true -+ ) { -+ -+ this->reset(extent, Layout::packed(extent), device_backed); -+ } -+ -+ /// Constructs a tensor given an extent and layout -+ HostTensor( -+ TensorCoord const &extent, -+ Layout const &layout, -+ bool device_backed = true -+ ) { -+ -+ this->reset(extent, layout, device_backed); -+ } -+ -+ ~HostTensor() { } -+ -+ /// Clears the HostTensor allocation to size/capacity = 0 -+ void reset() { -+ extent_ = TensorCoord(); -+ layout_ = Layout::packed(extent_); -+ -+ host_.clear(); -+ device_.reset(); -+ } -+ -+ /// Resizes internal memory allocations without affecting layout or extent -+ void reserve( -+ size_t count, ///< size of tensor in elements -+ bool device_backed_ = true) { ///< if true, device memory is also allocated -+ -+ device_.reset(); -+ host_.clear(); -+ -+ count /= kElementsPerStoredItem; -+ -+ host_.resize(count); -+ -+ // Allocate memory -+ Element* device_memory = nullptr; -+ if (device_backed_) { -+ device_memory = device_memory::allocate(count); -+ } -+ device_.reset(device_memory, device_backed_ ? count : 0); -+ } -+ -+ /// Updates the extent and layout of the HostTensor. 
Allocates memory according to the new -+ /// extent and layout. -+ void reset( -+ TensorCoord const &extent, ///< extent of logical tensor -+ Layout const &layout, ///< layout object of tensor -+ bool device_backed_ = true) { ///< if true, device memory is also allocated. -+ -+ extent_ = extent; -+ layout_ = layout; -+ -+ reserve(size_t(layout_.capacity(extent_)), device_backed_); -+ } -+ -+ /// Updates the extent and layout of the HostTensor. Allocates memory according to the new -+ /// extent and layout. Assumes a packed tensor configuration. -+ void reset( -+ TensorCoord const &extent, ///< extent of logical tensor -+ bool device_backed_ = true) { ///< if true, device memory is also allocated. -+ -+ reset(extent, Layout::packed(extent), device_backed_); -+ } -+ -+ /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. -+ /// To force allocation, call reset(). -+ void resize( -+ TensorCoord const &extent, ///< extent of logical tensor -+ Layout const &layout, ///< layout object of tensor -+ bool device_backed_ = true) { ///< if true, device memory is also allocated. -+ -+ extent_ = extent; -+ layout_ = layout; -+ -+ LongIndex new_size = size_t(layout_.capacity(extent_)); -+ -+ if (static_cast(new_size) > host_.size()) { -+ reserve(new_size); -+ } -+ } -+ -+ /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. -+ /// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration. -+ void resize( -+ TensorCoord const &extent, ///< extent of logical tensor -+ bool device_backed_ = true) { ///< if true, device memory is also allocated. -+ -+ resize(extent, Layout::packed(extent), device_backed_); -+ } -+ -+ /// Returns the number of elements stored in the host tensor -+ size_t size() const { -+ return host_.size() * kElementsPerStoredItem; -+ } -+ -+ /// Returns the logical capacity based on extent and layout. May differ from size(). 
-+ LongIndex capacity() const { -+ return layout_.capacity(extent_); -+ } -+ -+ /// Gets pointer to host data -+ Element * host_data() { return host_.data(); } -+ -+ /// Gets pointer to host data with a pointer offset -+ Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory::get(host_.data(), ptr_element_offset); } -+ -+ /// Gets a reference to an element in host memory -+ Reference host_data(LongIndex idx) { -+ return ReferenceFactory::get(host_data(), idx); -+ } -+ -+ /// Gets pointer to host data -+ Element const * host_data() const { return host_.data(); } -+ -+ /// Gets a constant reference to an element in host memory -+ ConstReference host_data(LongIndex idx) const { -+ return ReferenceFactory::get(host_data(), idx); -+ } -+ -+ /// Gets pointer to device data -+ Element * device_data() { return device_.get(); } -+ -+ /// Gets pointer to device data with a pointer offset -+ Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory::get(device_data(), ptr_element_offset); } -+ -+ /// Gets pointer to device data -+ Element const * device_data() const { return device_.get(); } -+ -+ /// Accesses the tensor reference pointing to data -+ TensorRef host_ref(LongIndex ptr_element_offset=0) { return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_); } -+ -+ /// Accesses the tensor reference pointing to data -+ ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const { return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_); } -+ -+ /// Accesses the tensor reference pointing to data -+ TensorRef device_ref(LongIndex ptr_element_offset=0) { -+ return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const { -+ return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ TensorView host_view(LongIndex ptr_element_offset=0) { -+ return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ ConstTensorView host_view(LongIndex ptr_element_offset=0) const { -+ return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ TensorView device_view(LongIndex ptr_element_offset=0) { -+ return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ ConstTensorView device_view(LongIndex ptr_element_offset=0) const { -+ return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_); -+ } -+ -+ /// Returns true if device memory is allocated -+ bool device_backed() const { -+ return (device_.get() == nullptr) ? 
false : true; -+ } -+ -+ -+ /// Returns the layout object -+ Layout & layout() { -+ return layout_; -+ } -+ -+ /// Returns the layout object -+ Layout layout() const { -+ return layout_; -+ } -+ -+ /// Returns the layout object's stride vector -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the layout object's stride vector -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Returns the layout object's stride in a given physical dimension -+ LongIndex stride(int dim) const { -+ return layout_.stride().at(dim); -+ } -+ -+ /// Returns the layout object's stride in a given physical dimension -+ LongIndex & stride(int dim) { -+ return layout_.stride().at(dim); -+ } -+ -+ /// Computes the offset of an index from the origin of the tensor -+ LongIndex offset(TensorCoord const& coord) const { -+ return layout_(coord); -+ } -+ -+ /// Returns a reference to the element at the logical Coord in host memory -+ Reference at(TensorCoord const& coord) { -+ return host_data(offset(coord)); -+ } -+ -+ /// Returns a const reference to the element at the logical Coord in host memory -+ ConstReference at(TensorCoord const& coord) const { -+ return host_data(offset(coord)); -+ } -+ -+ /// Returns the extent of the tensor -+ TensorCoord extent() const { -+ return extent_; -+ } -+ -+ /// Returns the extent of the tensor -+ TensorCoord & extent() { -+ return extent_; -+ } -+ -+ /// Copies data from device to host -+ void sync_host() { -+ if (device_backed()) { -+ device_memory::copy_to_host( -+ host_data(), device_data(), size()); -+ } -+ } -+ -+ /// Copies data from host to device -+ void sync_device() { -+ if (device_backed()) { -+ device_memory::copy_to_device( -+ device_data(), host_data(), size()); -+ } -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_in_device_to_host( -+ Element const* ptr_device, ///< source device memory -+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ device_memory::copy_to_host( -+ host_data(), ptr_device, count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_in_device_to_device( -+ Element const* ptr_device, ///< source device memory -+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ device_memory::copy_device_to_device( -+ device_data(), ptr_device, count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_in_host_to_device( -+ Element const* ptr_host, ///< source host memory -+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ device_memory::copy_to_device( -+ device_data(), ptr_host, count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_in_host_to_host( -+ Element const* ptr_host, ///< source host memory -+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. 
-+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ device_memory::copy_host_to_host( -+ host_data(), ptr_host, count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_out_device_to_host( -+ Element * ptr_host, ///< source device memory -+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ device_memory::copy_to_host( -+ ptr_host, device_data(), count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_out_device_to_device( -+ Element * ptr_device, ///< source device memory -+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ device_memory::copy_device_to_device( -+ ptr_device, device_data(), count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_out_host_to_device( -+ Element * ptr_device, ///< source host memory -+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ device_memory::copy_to_device( -+ ptr_device, host_data(), count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_out_host_to_host( -+ Element * ptr_host, ///< source host memory -+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ device_memory::copy_host_to_host( -+ ptr_host, host_data(), count); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h b/3rdparty/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h -new file mode 100644 -index 0000000..c548d9c ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h -@@ -0,0 +1,591 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+/*! \file -+ \brief HostTensor contributes management for both host and device memory. -+ -+ HostTensor allocates host and device memory upon construction. Basic element-wise operations on -+ host memory synchronize device memory automatically. Explicit copy operations provide abstractions -+ for CUDA memcpy operations. -+ -+ Call {host, device}_{data, ref, view}() for accessing host or device memory. -+ -+ See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details. -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/tensor_ref_planar_complex.h" -+#include "cutlass/tensor_view_planar_complex.h" -+ -+#include "device_memory.h" -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Host tensor -+template < -+ /// Data type of element stored within tensor (concept: NumericType) -+ typename Element_, -+ /// Defines a mapping from logical coordinate to linear memory (concept: Layout) -+ typename Layout_ -+> -+class HostTensorPlanarComplex { -+public: -+ -+ /// Data type of individual access -+ using Element = Element_; -+ -+ /// Mapping function from logical coordinate to linear memory -+ using Layout = Layout_; -+ -+ /// Logical rank of tensor index space -+ static int const kRank = Layout::kRank; -+ -+ /// Index type -+ using Index = typename Layout::Index; -+ -+ /// Long index used for pointer offsets -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// Coordinate in logical tensor space -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ /// Layout's stride vector -+ using Stride = typename Layout::Stride; -+ -+ /// Tensor reference to device memory -+ using TensorRef = TensorRefPlanarComplex; -+ -+ /// Tensor reference to constant device memory -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ /// Tensor reference to device memory -+ using TensorView = TensorViewPlanarComplex; -+ -+ /// Tensor reference to constant device memory -+ using ConstTensorView = typename TensorView::ConstTensorView; -+ -+ /// Reference to element in tensor -+ using Reference = typename TensorRef::Reference; -+ -+ /// Constant reference to element in tensor -+ using ConstReference = typename ConstTensorRef::Reference; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Extent of tensor in logical dimensions -+ TensorCoord extent_; -+ -+ /// Layout object -+ Layout layout_; -+ -+ /// Host-side memory allocation -+ std::vector host_; -+ -+ /// Device-side memory -+ device_memory::allocation device_; -+ -+ public: -+ // -+ // Device and Host 
Methods -+ // -+ -+ /// Default constructor -+ HostTensorPlanarComplex() {} -+ -+ /// Constructs a tensor given an extent. Assumes a packed layout -+ HostTensorPlanarComplex( -+ TensorCoord const &extent, -+ bool device_backed = true -+ ) { -+ -+ this->reset(extent, Layout::packed(extent), device_backed); -+ } -+ -+ /// Constructs a tensor given an extent and layout -+ HostTensorPlanarComplex( -+ TensorCoord const &extent, -+ Layout const &layout, -+ bool device_backed = true -+ ) { -+ -+ this->reset(extent, layout, device_backed); -+ } -+ -+ ~HostTensorPlanarComplex() { } -+ -+ /// Clears the HostTensor allocation to size/capacity = 0 -+ void reset() { -+ extent_ = TensorCoord(); -+ layout_ = Layout::packed(extent_); -+ -+ host_.clear(); -+ device_.reset(); -+ } -+ -+ /// Resizes internal memory allocations without affecting layout or extent -+ void reserve( -+ size_t count, ///< size of tensor in elements -+ bool device_backed_ = true) { ///< if true, device memory is also allocated -+ -+ device_.reset(); -+ host_.clear(); -+ -+ host_.resize(count * 2); -+ -+ // Allocate memory -+ Element* device_memory = nullptr; -+ if (device_backed_) { -+ device_memory = device_memory::allocate(count * 2); -+ } -+ device_.reset(device_memory, device_backed_ ? count * 2 : 0); -+ } -+ -+ /// Updates the extent and layout of the HostTensor. Allocates memory according to the new -+ /// extent and layout. -+ void reset( -+ TensorCoord const &extent, ///< extent of logical tensor -+ Layout const &layout, ///< layout object of tensor -+ bool device_backed_ = true) { ///< if true, device memory is also allocated. -+ -+ extent_ = extent; -+ layout_ = layout; -+ -+ reserve(size_t(layout_.capacity(extent_)), device_backed_); -+ } -+ -+ /// Updates the extent and layout of the HostTensor. Allocates memory according to the new -+ /// extent and layout. Assumes a packed tensor configuration. -+ void reset( -+ TensorCoord const &extent, ///< extent of logical tensor -+ bool device_backed_ = true) { ///< if true, device memory is also allocated. -+ -+ reset(extent, Layout::packed(extent), device_backed_); -+ } -+ -+ /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. -+ /// To force allocation, call reset(). -+ void resize( -+ TensorCoord const &extent, ///< extent of logical tensor -+ Layout const &layout, ///< layout object of tensor -+ bool device_backed_ = true) { ///< if true, device memory is also allocated. -+ -+ extent_ = extent; -+ layout_ = layout; -+ -+ LongIndex new_size = size_t(layout_.capacity(extent_)); -+ -+ if (static_cast(new_size * 2) > host_.size()) { -+ reserve(new_size); -+ } -+ } -+ -+ /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. -+ /// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration. -+ void resize( -+ TensorCoord const &extent, ///< extent of logical tensor -+ bool device_backed_ = true) { ///< if true, device memory is also allocated. -+ -+ resize(extent, Layout::packed(extent), device_backed_); -+ } -+ -+ /// Returns the number of elements stored in the host tensor -+ size_t size() const { -+ return host_.size() / 2; -+ } -+ -+ /// Returns the logical capacity based on extent and layout. May differ from size(). 
-+ LongIndex capacity() const { -+ return layout_.capacity(extent_); -+ } -+ -+ /// Stride between real and imaginary parts -+ LongIndex imaginary_stride() const { -+ return host_.size() / 2; -+ } -+ -+ /// Gets pointer to host data -+ Element * host_data() { return host_.data(); } -+ -+ /// Gets pointer to host data imaginary part -+ Element * host_data_imag() { return host_.data() + imaginary_stride(); } -+ -+ /// Gets pointer to host data with a pointer offset -+ Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return host_data() + ptr_element_offset; } -+ -+ /// Gets pointer to host data with a pointer offset -+ Element * host_data_imag_ptr_offset(LongIndex ptr_element_offset) { return host_data_imag() + ptr_element_offset; } -+ -+ /// Gets a reference to an element in host memory -+ Reference host_data(LongIndex idx) { -+ return PlanarComplexReference(host_data() + idx, host_data_imag() + idx); -+ } -+ -+ /// Gets pointer to host data -+ Element const * host_data() const { return host_.data(); } -+ -+ /// Gets pointer to host data imaginary part -+ Element const * host_data_imag() const { return host_.data() + imaginary_stride(); } -+ -+ /// Gets a constant reference to an element in host memory -+ ConstReference host_data(LongIndex idx) const { -+ return PlanarComplexReference(host_data() + idx, host_data_imag() + idx); -+ } -+ -+ /// Gets pointer to device data -+ Element * device_data() { return device_.get(); } -+ -+ /// Gets pointer to device data with a pointer offset -+ Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return device_.get() + ptr_element_offset; } -+ -+ /// Gets pointer to device data -+ Element const * device_data() const { return device_.get(); } -+ -+ /// Gets pointer to device data with a pointer offset -+ Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return device_.get() + ptr_element_offset; } -+ -+ /// Gets a pointer to the device data imaginary part -+ Element * device_data_imag() { return device_.get() + imaginary_stride(); } -+ -+ /// Accesses the tensor reference pointing to data -+ TensorRef host_ref(LongIndex ptr_element_offset=0) { -+ return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); -+ } -+ -+ /// Returns a tensor reference to the real part of the tensor -+ cutlass::TensorRef host_ref_real() { -+ return cutlass::TensorRef(host_data(), layout_); -+ } -+ -+ /// Returns a tensor reference to the real part of the tensor -+ cutlass::TensorRef host_ref_imag() { -+ return cutlass::TensorRef(host_data_ptr_offset(imaginary_stride()), layout_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const { -+ return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ TensorRef device_ref(LongIndex ptr_element_offset=0) { -+ return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const { -+ return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); -+ } -+ -+ /// Returns a tensor reference to the real part of the tensor -+ cutlass::TensorRef device_ref_real() { -+ return cutlass::TensorRef(device_data(), layout_); -+ } -+ -+ /// Returns a tensor reference to the real part of the tensor -+ cutlass::TensorRef 
device_ref_imag() { -+ return cutlass::TensorRef(device_data_ptr_offset(imaginary_stride()), layout_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ TensorView host_view(LongIndex ptr_element_offset=0) { -+ return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ ConstTensorView host_view(LongIndex ptr_element_offset=0) const { -+ return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ cutlass::TensorView host_view_real() { -+ return cutlass::TensorView(host_data(), layout_, extent_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ cutlass::TensorView host_view_imag() { -+ return cutlass::TensorView(host_data_ptr_offset(imaginary_stride()), layout_, extent_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ TensorView device_view(LongIndex ptr_element_offset=0) { -+ return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ ConstTensorView device_view(LongIndex ptr_element_offset=0) const { -+ return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ cutlass::TensorView device_view_real() { -+ return cutlass::TensorView(device_data(), layout_, extent_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ cutlass::TensorView device_view_imag() { -+ return cutlass::TensorView(device_data_ptr_offset(imaginary_stride()), layout_, extent_); -+ } -+ -+ /// Returns true if device memory is allocated -+ bool device_backed() const { -+ return (device_.get() == nullptr) ? false : true; -+ } -+ -+ /// Returns the layout object -+ Layout layout() const { -+ return layout_; -+ } -+ -+ /// Returns the layout object's stride vector -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the layout object's stride in a given physical dimension -+ Index stride(int dim) const { -+ return layout_.stride().at(dim); -+ } -+ -+ /// Computes the offset of an index from the origin of the tensor -+ LongIndex offset(TensorCoord const& coord) const { -+ return layout_(coord); -+ } -+ -+ /// Returns a reference to the element at the logical Coord in host memory -+ Reference at(TensorCoord const& coord) { -+ return host_data(offset(coord)); -+ } -+ -+ /// Returns a const reference to the element at the logical Coord in host memory -+ ConstReference at(TensorCoord const& coord) const { -+ return host_data(offset(coord)); -+ } -+ -+ /// Returns the extent of the tensor -+ TensorCoord extent() const { -+ return extent_; -+ } -+ -+ /// Returns the extent of the tensor -+ TensorCoord & extent() { -+ return extent_; -+ } -+ -+ /// Copies data from device to host -+ void sync_host() { -+ if (device_backed()) { -+ device_memory::copy_to_host( -+ host_data(), device_data(), imaginary_stride() * 2); -+ } -+ } -+ -+ /// Copies data from host to device -+ void sync_device() { -+ if (device_backed()) { -+ device_memory::copy_to_device( -+ device_data(), host_data(), imaginary_stride() * 2); -+ } -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. 
-+ void copy_in_device_to_host( -+ Element const* ptr_device_real, ///< source device memory -+ Element const* ptr_device_imag, ///< source device memory -+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ -+ device_memory::copy_to_host( -+ host_data(), ptr_device_real, count); -+ -+ device_memory::copy_to_host( -+ host_data_imag(), ptr_device_imag, count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_in_device_to_device( -+ Element const* ptr_device_real, ///< source device memory -+ Element const* ptr_device_imag, ///< source device memory -+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ -+ device_memory::copy_device_to_device( -+ device_data(), ptr_device_real, count); -+ -+ device_memory::copy_device_to_device( -+ device_data_imag(), ptr_device_imag, count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_in_host_to_device( -+ Element const* ptr_host_real, ///< source host memory -+ Element const* ptr_host_imag, ///< source host memory -+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ -+ device_memory::copy_to_device( -+ device_data(), ptr_host_real, count); -+ -+ device_memory::copy_to_device( -+ device_data_imag(), ptr_host_imag, count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_in_host_to_host( -+ Element const* ptr_host_real, ///< source host memory -+ Element const* ptr_host_imag, ///< source host memory -+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ -+ device_memory::copy_host_to_host( -+ host_data(), ptr_host_real, count); -+ -+ device_memory::copy_host_to_host( -+ host_data_imag(), ptr_host_imag, count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_out_device_to_host( -+ Element * ptr_host_real, ///< source device memory -+ Element * ptr_host_imag, ///< source device memory -+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ -+ device_memory::copy_to_host( -+ ptr_host_real, device_data(), count); -+ -+ device_memory::copy_to_host( -+ ptr_host_imag, device_data_imag(), count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_out_device_to_device( -+ Element * ptr_device_real, ///< source device memory -+ Element * ptr_device_imag, ///< source device memory -+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. 
-+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ -+ device_memory::copy_device_to_device( -+ ptr_device_real, device_data(), count); -+ -+ device_memory::copy_device_to_device( -+ ptr_device_imag, device_data_imag(), count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_out_host_to_device( -+ Element * ptr_device_real, ///< source device memory -+ Element * ptr_device_imag, ///< source device memory -+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ -+ device_memory::copy_to_device( -+ ptr_device_real, host_data(), count); -+ -+ device_memory::copy_to_device( -+ ptr_device_imag, host_data_imag(), count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_out_host_to_host( -+ Element * ptr_host_real, ///< source host memory -+ Element * ptr_host_imag, ///< source host memory -+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ -+ device_memory::copy_host_to_host( -+ ptr_host_real, host_data(), count); -+ -+ device_memory::copy_host_to_host( -+ ptr_host_imag, host_data_imag(), count); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/host_uncompress.h b/3rdparty/cutlass/tools/util/include/cutlass/util/host_uncompress.h -new file mode 100644 -index 0000000..7028bf7 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/host_uncompress.h -@@ -0,0 +1,157 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief uncompress sparse matrix from the host side -+*/ -+#pragma once -+ -+#include "cutlass/coord.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+namespace cutlass { -+ -+// uncompress sparse tensor core A matrix -+template -+void uncompress(TensorRef uncompressed_tensor_a, -+ TensorRef tensor_a, -+ TensorRef tensor_e, int row, int col) { -+ // How many uncompressed data we can get with ElementE meta data -+ int DecompressedElementsPerElementE = -+ 256 / cutlass::sizeof_bits::value; -+ -+ // Process 4bit meta data a time -+ int step; -+ -+ // 1:2 or 2:4 or 4:8 -+ int a, b; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ step = 8; -+ a = 4; -+ b = 8; -+ } else if (cutlass::sizeof_bits::value == 8) { -+ step = 4; -+ a = 2; -+ b = 4; -+ } else if (cutlass::sizeof_bits::value == 16) { -+ step = 4; -+ a = 2; -+ b = 4; -+ } else if (cutlass::sizeof_bits::value == 32) { -+ step = 2; -+ a = 1; -+ b = 2; -+ } -+ -+ int ElementsPerE = (cutlass::sizeof_bits::value == 4) ? 2 : 1; -+ -+ for (int r = 0; r < row; ++r) { -+ for (int c = 0; c < (col / DecompressedElementsPerElementE); ++c) { -+ -+ ElementE meta = tensor_e.at(MatrixCoord(r, c)); -+ -+ for (int i = 0; i < DecompressedElementsPerElementE; i += step) { -+ int e = (meta >> (i / step * 4)) & 0xf; -+ int idx0 = e & 0x3; -+ int idx1 = e >> 2; -+ -+ if (a == 1) idx0 = idx0 / 2; -+ -+ for (int ii = 0; ii < step; ii += ElementsPerE) { -+ int real_col = -+ c * DecompressedElementsPerElementE + i + ii; -+ int compressed_col = (real_col / b) * a; -+ -+ if (ii == (idx0 * ElementsPerE)) { -+ uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = -+ tensor_a.at(MatrixCoord(r, compressed_col)); -+ if (ElementsPerE == 2) -+ uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = -+ tensor_a.at(MatrixCoord(r, compressed_col + 1)); -+ } else if ((ii == (idx1 * ElementsPerE)) && (a != 1)) { -+ uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = -+ tensor_a.at(MatrixCoord(r, compressed_col + ElementsPerE)); -+ if (ElementsPerE == 2) -+ uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = -+ tensor_a.at( -+ MatrixCoord(r, compressed_col + ElementsPerE + 1)); -+ } else { -+ uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = -+ ElementA(0); -+ if (ElementsPerE == 2) -+ uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = -+ ElementA(0); -+ } -+ } -+ } -+ } -+ } -+} -+ -+// uncompress ELL block sparse matrix -+template -+void uncompress_ell_block_sparse( -+ TensorRef uncompressed_tensor_a, -+ TensorRef tensor_a, -+ TensorRef ell_idx, -+ int rows, int cols, -+ int ell_num_cols, int ell_blocksize) { -+ -+ for (int r = 0; r < rows / ell_blocksize; ++r) { -+ for (int c = 0; c < ell_num_cols / ell_blocksize; ++c) { -+ -+ ElementE idx = ell_idx.at(MatrixCoord(r, c)); -+ -+ if 
(idx != -1) { -+ int row_begin = r * ell_blocksize; -+ int col_begin_real = idx * ell_blocksize; -+ int col_begin = c * ell_blocksize; -+ -+ for (int i = 0; i < ell_blocksize; ++i) { -+ for (int j = 0; j < ell_blocksize; ++j) { -+ uncompressed_tensor_a.at(MatrixCoord(row_begin + i, col_begin_real + j)) = -+ tensor_a.at( -+ MatrixCoord(row_begin + i, col_begin +j)); -+ } -+ } -+ } -+ } -+ } -+} -+ -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/index_sequence.h b/3rdparty/cutlass/tools/util/include/cutlass/util/index_sequence.h -new file mode 100644 -index 0000000..846e02c ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/index_sequence.h -@@ -0,0 +1,38 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+ -+// integer_sequence moved to cutlass/numeric_types.h -+ -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/packed_stride.hpp b/3rdparty/cutlass/tools/util/include/cutlass/util/packed_stride.hpp -new file mode 100644 -index 0000000..7ecffaf ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/packed_stride.hpp -@@ -0,0 +1,101 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Utilities for packing a rank-X shape into a rank-(X-1) stride in CuTe. -+*/ -+ -+#pragma once -+ -+#include "cute/stride.hpp" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Strides without batch mode -+ -+template -+cute::Stride> -+make_cute_packed_stride(cute::Stride> s, cute::Shape shape_MKL) { -+ static_assert(std::is_integral_v, -+ "Stride must have an integral type so it can be set dynamically. Static strides not supported."); -+ auto s_copy = s; -+ cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); -+ return s_copy; -+} -+ -+template -+cute::Stride, StrideIntT> -+make_cute_packed_stride(cute::Stride, StrideIntT> s, cute::Shape shape_MKL) { -+ static_assert(std::is_integral_v, -+ "Stride must have an integral type so it can be set dynamically. Static strides not supported."); -+ auto s_copy = s; -+ cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); -+ return s_copy; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Strides with batch mode -+ -+template -+cute::Stride, int64_t> -+make_cute_packed_stride(cute::Stride, int64_t> s, cute::Shape shape_MKL) { -+ static_assert(std::is_integral_v, -+ "Stride must have an integral type so it can be set dynamically. Static strides not supported."); -+ auto s_copy = s; -+ cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); -+ int batch_count = cute::get<2>(shape_MKL); -+ if (batch_count > 1) { -+ cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); -+ } -+ else { -+ cute::get<2>(s_copy) = static_cast(0); -+ } -+ return s_copy; -+} -+ -+template -+cute::Stride, StrideIntT, int64_t> -+make_cute_packed_stride(cute::Stride, StrideIntT, int64_t> s, cute::Shape shape_MKL) { -+ static_assert(std::is_integral_v, -+ "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); -+ auto s_copy = s; -+ cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); -+ int batch_count = cute::get<2>(shape_MKL); -+ if (batch_count > 1) { -+ cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); -+ } -+ else { -+ cute::get<2>(s_copy) = static_cast(0); -+ } -+ return s_copy; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/print_error.hpp b/3rdparty/cutlass/tools/util/include/cutlass/util/print_error.hpp -new file mode 100644 -index 0000000..f867f88 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/print_error.hpp -@@ -0,0 +1,235 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include -+ -+#include -+#include -+ -+#include -+#include -+ -+#include -+ -+// The computed infinity norm does not include -+// any NaN column absolute-value sums. -+struct matrix_inf_norm_result { -+ // Accumulate errors in double, as this is generally -+ // the highest precision that the examples use. -+ double inf_norm = 0.0; -+ bool found_nan = false; -+}; -+ -+// In theory, cute::Tensor, T> could be treated as a view type, -+// and thus passed by value (as std::span or std::string_view would be). -+// However, generic cute::Tensor are more like containers -+// and thus are best passed by reference or const reference. 
-+template -+matrix_inf_norm_result -+matrix_inf_norm(const cute::Tensor& host_matrix) -+{ -+ using std::abs; -+ using error_type = decltype(std::declval().inf_norm); -+ -+ error_type inf_norm = 0.0; -+ bool found_nan = false; -+ -+ const auto shape = host_matrix.shape(); -+ using index_type = std::decay_t(shape))>; -+ // Computing the infinity norm requires that we be able -+ // to treat the input as a matrix, with rows and columns. -+ static_assert(std::is_integral_v); -+ const index_type num_rows = cute::get<0>(shape); -+ const index_type num_cols = cute::get<1>(shape); -+ -+ for(index_type i = 0; i < num_rows; ++i) { -+ error_type row_abs_sum = 0.0; -+ for(index_type j = 0; j < num_cols; ++j) { -+ row_abs_sum += abs(host_matrix(i, j)); -+ } -+ if(std::isnan(row_abs_sum)) { -+ found_nan = true; -+ } else { -+ inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm; -+ } -+ } -+ -+ return {inf_norm, found_nan}; -+} -+ -+// Infinity norm of (X - Y). -+template -+matrix_inf_norm_result -+matrix_diff_inf_norm(const cute::Tensor& X, -+ const cute::Tensor& Y) -+{ -+ using std::abs; -+ using error_type = decltype(std::declval().inf_norm); -+ -+ const auto X_shape = X.shape(); -+ const auto Y_shape = Y.shape(); -+ -+ using index_type = std::decay_t(X_shape))>; -+ // Computing the infinity norm requires that we be able -+ // to treat the input as a matrix, with rows and columns. -+ static_assert(std::is_integral_v); -+ const index_type num_rows = cute::get<0>(X_shape); -+ const index_type num_cols = cute::get<1>(X_shape); -+ -+ assert(num_rows == cute::get<0>(Y_shape)); -+ assert(num_cols == cute::get<1>(Y_shape)); -+ -+ auto matrix_ij = [&](const auto& A, std::size_t i, std::size_t j) { -+ return A(i, j); -+ }; -+ auto diff_ij = [&](std::size_t i, std::size_t j) { -+ return matrix_ij(X, i, j) - matrix_ij(Y, i, j); -+ }; -+ -+ error_type inf_norm = 0.0; -+ bool found_nan = false; -+ -+ for(index_type i = 0; i < num_rows; ++i) { -+ error_type row_abs_sum = 0.0; -+ for(index_type j = 0; j < num_cols; ++j) { -+ row_abs_sum += abs(diff_ij(i, j)); -+ } -+ if(std::isnan(row_abs_sum)) { -+ found_nan = true; -+ } else { -+ inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm; -+ } -+ } -+ -+ return {inf_norm, found_nan}; -+} -+ -+template -+void -+print_matrix_multiply_mollified_relative_error( -+ const char A_value_type_name[], -+ const cute::Tensor& A, -+ const char B_value_type_name[], -+ const cute::Tensor& B, -+ const char C_value_type_name[], -+ const cute::Tensor& C_computed, -+ const cute::Tensor& C_expected) -+{ -+ const auto [A_norm, A_has_nan] = matrix_inf_norm(A); -+ const auto [B_norm, B_has_nan] = matrix_inf_norm(B); -+ const auto [C_norm, C_has_nan] = matrix_inf_norm(C_expected); -+ const auto [diff_norm, diff_has_nan] = matrix_diff_inf_norm(C_computed, C_expected); -+ -+ const auto A_norm_times_B_norm = A_norm * B_norm; -+ const auto relative_error = A_norm_times_B_norm == 0.0 ? -+ diff_norm : (diff_norm / A_norm_times_B_norm); -+ -+ // For expected error bounds, please refer to the LAPACK Users' Guide, -+ // in particular https://netlib.org/lapack/lug/node108.html . -+ // Printing the infinity norm of C is a way to check -+ // that both the function being tested (C_computed) -+ // and the reference implementation (C_expected) -+ // don't just do nothing (or fill with zeros). 
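// Editor's note: a small standalone sketch (assumed row-major std::vector storage,
// not part of the patch) of the quantities computed above: the matrix infinity norm
// (maximum row sum of absolute values, skipping NaN rows as the code above does) and
// the mollified relative error ||C_computed - C_expected||_inf / (||A||_inf * ||B||_inf),
// which falls back to the absolute error when the denominator is zero.
#include <cmath>
#include <cstddef>
#include <vector>

double inf_norm(const std::vector<double>& A, std::size_t rows, std::size_t cols) {
  double norm = 0.0;
  for (std::size_t i = 0; i < rows; ++i) {
    double row_sum = 0.0;
    for (std::size_t j = 0; j < cols; ++j) row_sum += std::abs(A[i * cols + j]);
    if (!std::isnan(row_sum) && row_sum > norm) norm = row_sum;  // NaN rows are excluded, as above
  }
  return norm;
}

double mollified_relative_error(double diff_norm, double a_norm, double b_norm) {
  const double denom = a_norm * b_norm;
  return denom == 0.0 ? diff_norm : diff_norm / denom;
}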
-+ using std::cout; -+ cout << "Value type of A: " << A_value_type_name << '\n' -+ << std::scientific -+ << "Infinity norm of A: " << A_norm << '\n' -+ << "Value type of B: " << B_value_type_name << '\n' -+ << "Infinity norm of B: " << B_norm << '\n' -+ << "Value type of C: " << C_value_type_name << '\n' -+ << "Infinity norm of C_expected: " << C_norm << '\n' -+ << "Infinity norm of (C_computed - C_expected): " << diff_norm << '\n'; -+ -+ if(A_norm_times_B_norm == 0.0) { -+ cout << "Mollified relative error: " << relative_error << '\n'; -+ } else { -+ cout << "Relative error: " << relative_error << '\n'; -+ } -+ -+ cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n' -+ << "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n' -+ << "Did we encounter NaN in C_expected? " << (C_has_nan ? "yes" : "no") << '\n' -+ << "Did we encounter NaN in (C_computed - C_expected)? " -+ << (diff_has_nan ? "yes" : "no") << '\n'; -+} -+ -+template -+void -+print_matrix_multiply_mollified_relative_error( -+ const char value_type_name[], -+ const cute::Tensor& A, -+ const cute::Tensor& B, -+ const cute::Tensor& C_computed, -+ const cute::Tensor& C_expected) -+{ -+ print_matrix_multiply_mollified_relative_error(value_type_name, A, value_type_name, B, -+ value_type_name, C_computed, C_expected); -+} -+ -+// Take a CUTLASS HostTensor (or the like) as input, -+// and return a const CuTe Tensor. -+// This is useful for use with the above error printing functions. -+// This implicitly "transposes" if the layout is RowMajor. -+// Note that the HostTensor must be captured by nonconst reference -+// in order for X.host_ref().data() to compile. -+// (CUTLASS is a bit more container-y than CuTe.) -+template -+auto host_matrix_to_const_cute_tensor(CutlassHostTensorType& X) -+{ -+ // The tensors were created with post-transposed extents. -+ const auto extents = X.extent(); -+ const auto shape = cute::Shape{extents[0], extents[1]}; -+ // Both RowMajor and ColumnMajor only store one stride. -+ const int LDX = X.stride(0); -+ const auto strides = [&]() { -+ using input_layout_type = typename std::decay_t::Layout; -+ if constexpr (std::is_same_v) { -+ return cute::Stride{1, LDX}; -+ } -+ else { -+ static_assert(std::is_same_v); -+ return cute::Stride{LDX, 1}; -+ } -+ }(); -+ const auto layout = cute::make_layout(shape, strides); -+ auto X_data = X.host_ref().data(); -+ auto X_data_const = const_cast >(X_data); -+ return cute::make_tensor(X_data_const, layout); -+}; -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h -new file mode 100644 -index 0000000..b4bffa3 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h -@@ -0,0 +1,135 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for GEMM in host-side code. -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+ -+namespace cutlass { -+namespace reference { -+namespace detail { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template function to compute an inner product. -+#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate with a -+ // host-only type -+template -+CUTLASS_HOST_DEVICE -+Ctype inner_product(Atype a, Btype b, Ctype c) { -+ return Ctype(a) * Ctype(b) + c; -+} -+ -+/// Specialization for matrix multiplication with binary operands -+template <> -+CUTLASS_HOST_DEVICE -+int inner_product, Array, int>( -+ Array a, -+ Array b, -+ int c) { -+ -+ int accum = 0; -+ for (int bit = 0; bit < 32; bit++) { -+ accum += a[bit] ^ b[bit]; -+ } -+ return accum + c; -+} -+ -+/* -+/// Specialization for matrix multiplication with signed 4-bit integer operands -+template <> -+CUTLASS_HOST_DEVICE -+int inner_product, Array, int>( -+ Array a, -+ Array b, -+ int c) { -+ -+ int accum = 0; -+ for (int k = 0; k < 8; k++) { -+ accum += a[k] * b[k]; -+ } -+ return accum + c; -+} -+ -+/// Specialization for matrix multiplication with unsigned 4-bit integer operands -+template <> -+CUTLASS_HOST_DEVICE -+int inner_product, Array, int>( -+ Array a, -+ Array b, -+ int c) { -+ -+ int accum = 0; -+ for (int k = 0; k < 8; k++) { -+ accum += a[k] * b[k]; -+ } -+ return accum + c; -+} -+*/ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct Cast { -+ // Default behavior: convert to the destination type -+#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a -+ // host-only type -+ CUTLASS_HOST_DEVICE -+ static DstType apply(SrcType src) { return static_cast(src); }; -+}; -+ -+template <> -+struct Cast { -+ CUTLASS_HOST_DEVICE -+ static int8_t apply(float src) { -+ // Clamp to the range of signed 8-bit integers. 
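// Editor's note: the clamp-then-convert pattern used at this point (and in the
// uint8_t specialization that follows) can be sketched in isolation as below;
// the helper name is illustrative and not part of the patch.
#include <algorithm>
#include <cstdint>

inline int8_t saturate_cast_to_int8(float src) {
  // Clamp into the representable range of int8_t before truncating to integer.
  return static_cast<int8_t>(std::max(-128.f, std::min(127.f, src)));
}
// e.g. saturate_cast_to_int8(300.f) == 127 and saturate_cast_to_int8(-1000.f) == -128.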
-+ return static_cast(fmaxf(-128.f, fminf(127.f, src))); -+ }; -+}; -+ -+template <> -+struct Cast { -+ CUTLASS_HOST_DEVICE -+ static uint8_t apply(float src) { -+ // Clamp to the range of signed 8-bit integers. -+ return static_cast(fmaxf(0.f, fminf(255.f, src))); -+ }; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace detail -+} // namespace reference -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h -new file mode 100644 -index 0000000..ac22699 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h -@@ -0,0 +1,94 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for GEMM in host-side code. 
-+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reference { -+namespace detail { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct LinearToCoordinateHelper { -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(Coord &coord, int64_t idx, Coord const &extent) const { -+ -+ int64_t prod = 1; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = Rank - Index; i < Rank; ++i) { -+ prod *= int64_t(extent[i]); -+ } -+ -+ coord[Rank - Index - 1] = int(idx / prod); -+ -+ int64_t residual = idx % prod; -+ LinearToCoordinateHelper()(coord, residual, extent); -+ } -+}; -+ -+template -+struct LinearToCoordinateHelper { -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(Coord &coord, int64_t idx, Coord const &extent) const { -+ coord[Rank - 1] = int(idx); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct LinearToCoordinate { -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(Coord &coord, int64_t idx, Coord const &extent) const { -+ LinearToCoordinateHelper()(coord, idx, extent); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace detail -+} // namespace reference -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h -new file mode 100644 -index 0000000..fec0587 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h -@@ -0,0 +1,1549 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
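// Editor's note: a standalone sketch (plain arrays, rank fixed at 3, not part of the
// patch) of the decomposition that the LinearToCoordinate helper above performs for
// arbitrary rank: a linear index is split into coordinates with the last dimension
// varying fastest.
#include <array>
#include <cstdint>

std::array<int, 3> linear_to_coord_3d(int64_t idx, const std::array<int, 3>& extent) {
  std::array<int, 3> coord{};
  const int64_t prod12 = int64_t(extent[1]) * extent[2];  // elements per slice of dimension 0
  coord[0] = int(idx / prod12);
  int64_t residual = idx % prod12;
  coord[1] = int(residual / extent[2]);
  coord[2] = int(residual % extent[2]);
  return coord;
}
// Example: with extent {2, 3, 4}, idx 17 maps to {1, 1, 1} since 17 = 1*12 + 1*4 + 1.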
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Reference implementation for convolution in device-side code. -+*/ -+ -+#pragma once -+ -+#include "cutlass/coord.h" -+#include "cutlass/functional.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+/// Conv2d device reference kernel -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conv2d Fprop kernel - y = fprop(x, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add, -+ int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension -+ int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension -+ int kCtaShapeM = 16, // shape of a threadblock in units of threads -+ int kCtaShapeN = 8 // shape of a threadblock in units of threads -+> -+__global__ void Conv2dFprop( -+ conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_x, -+ TensorRef tensor_w, -+ TensorRef tensor_y_in, -+ TensorRef tensor_y_out, -+ ElementCompute alpha, -+ ElementCompute beta -+ ) { -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ ElementAccumulator element_A[kThreadM]; -+ ElementAccumulator element_B[kThreadN]; -+ ElementAccumulator accum[kThreadM][kThreadN]; -+ -+ int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; -+ int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; -+ -+ int thread_n[kThreadM]; -+ int thread_p[kThreadM]; -+ int thread_q[kThreadM]; -+ -+ // Compute N, P, Q coordinates for each row of a thread's tile -+ int64_t PQ = int64_t(problem_size.P) * problem_size.Q; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ -+ int64_t npq = npq_start + m; -+ -+ thread_n[m] = int(npq / PQ); -+ -+ int64_t residual = npq % PQ; -+ thread_p[m] = int(residual / problem_size.Q); -+ thread_q[m] = int(residual % problem_size.Q); -+ } -+ -+ // Clear accumulators -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = ElementAccumulator(); -+ } -+ } -+ -+ int c_per_group = 
problem_size.C / problem_size.groups; -+ int k_per_group = problem_size.K / problem_size.groups; -+ -+ // Compute convolution -+ for (int R = 0; R < problem_size.R; ++R) { -+ for (int S = 0; S < problem_size.S; ++S) { -+ for (int C = 0; C < problem_size.C; ++C) { -+ -+ // Get group id of currnet channel -+ int c_group_idx = C / c_per_group; -+ -+ // Load from activations tensor -+ int filter_r = R; -+ int filter_s = S; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_r = problem_size.R - 1 - R; -+ filter_s = problem_size.S - 1 - S; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; -+ int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; -+ -+ if (thread_n[m] < problem_size.N && h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) { -+ element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], h, w, C})); -+ } -+ else { -+ element_A[m] = ElementAccumulator(); -+ } -+ } -+ -+ // Load from filters tensor -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ int thread_k = k_start + n; -+ int k_group_idx = thread_k / k_per_group; -+ -+ if (thread_k < problem_size.K && k_group_idx == c_group_idx) { -+ element_B[n] = ElementAccumulator(tensor_w.at({thread_k, R, S, C % c_per_group})); -+ } -+ else { -+ element_B[n] = ElementAccumulator(); -+ } -+ } -+ -+ // Accumulate matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); -+ } -+ } -+ } -+ } -+ } -+ -+ // Write out the results -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ if (thread_n[m] < problem_size.N && thread_p[m] < problem_size.P && thread_q[m] < problem_size.Q) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ int thread_k = k_start + n; -+ if (thread_k < problem_size.K) { -+ -+ ElementCompute c_ref = ElementCompute(); -+ if (beta != ElementCompute()) { -+ c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k})); -+ } -+ -+ tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op( -+ alpha * ElementCompute(accum[m][n]) + beta * c_ref); -+ } -+ } -+ } -+ } -+} -+ -+// Conv3d Fprop kernel - y = fprop(x, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add, -+ int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension -+ int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension -+ int kCtaShapeM = 16, // shape of a threadblock in units of threads -+ int kCtaShapeN = 8 // shape of a threadblock in units of threads -+> -+__global__ void Conv3dFprop( -+ conv::Conv3dProblemSize problem_size, -+ TensorRef tensor_x, -+ TensorRef tensor_w, -+ TensorRef tensor_y_in, -+ TensorRef tensor_y_out, -+ ElementCompute alpha, -+ ElementCompute beta -+ ) { -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ ElementAccumulator element_A[kThreadM]; -+ ElementAccumulator element_B[kThreadN]; -+ ElementAccumulator accum[kThreadM][kThreadN]; -+ -+ int64_t nzpq_start = int64_t(blockIdx.x) * 
kCtaShapeM * kThreadM + threadIdx.x * kThreadM; -+ int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; -+ -+ int thread_n[kThreadM]; -+ int thread_z[kThreadM]; -+ int thread_p[kThreadM]; -+ int thread_q[kThreadM]; -+ -+ // Compute N, Z, P, Q coordinates for each row of a thread's tile -+ int64_t PQ = int64_t(problem_size.P) * problem_size.Q; -+ int64_t ZPQ = PQ * problem_size.Z; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ -+ int64_t nzpq = nzpq_start + m; -+ -+ thread_n[m] = int(nzpq / ZPQ); -+ -+ int64_t residual = nzpq % ZPQ; -+ thread_z[m] = int(residual / PQ); -+ -+ residual = residual % PQ; -+ thread_p[m] = int(residual / problem_size.Q); -+ thread_q[m] = int(residual % problem_size.Q); -+ } -+ -+ // Clear accumulators -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = ElementAccumulator(); -+ } -+ } -+ -+ // Compute convolution -+ for (int T = 0; T < problem_size.T; ++T) { -+ for (int R = 0; R < problem_size.R; ++R) { -+ for (int S = 0; S < problem_size.S; ++S) { -+ for (int C = 0; C < problem_size.C; ++C) { -+ -+ // Load from activations tensor -+ int filter_t = T; -+ int filter_r = R; -+ int filter_s = S; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_t = problem_size.T - 1 - T; -+ filter_r = problem_size.R - 1 - R; -+ filter_s = problem_size.S - 1 - S; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ int d = thread_z[m] * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d; -+ int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; -+ int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; -+ -+ if (thread_n[m] < problem_size.N && -+ d >= 0 && d < problem_size.D && -+ h >= 0 && h < problem_size.H && -+ w >= 0 && w < problem_size.W) { -+ -+ element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], d, h, w, C})); -+ } -+ else { -+ element_A[m] = ElementAccumulator(); -+ } -+ } -+ -+ // Load from filters tensor -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ int thread_k = k_start + n; -+ -+ if (thread_k < problem_size.K) { -+ element_B[n] = ElementAccumulator(tensor_w.at({thread_k, T, R, S, C})); -+ } -+ else { -+ element_B[n] = ElementAccumulator(); -+ } -+ } -+ -+ // Accumulate matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); -+ } -+ } -+ -+ } // for (C) -+ } // for (S) -+ } // for (R) -+ } // for (T) -+ -+ // Write out the results -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ -+ if (thread_n[m] < problem_size.N && -+ thread_z[m] < problem_size.Z && -+ thread_p[m] < problem_size.P && -+ thread_q[m] < problem_size.Q) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ int thread_k = k_start + n; -+ if (thread_k < problem_size.K) { -+ -+ ElementCompute c_ref = ElementCompute(); -+ if (beta != ElementCompute()) { -+ c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k})); -+ } -+ -+ tensor_y_out.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k}) = convert_op( -+ alpha * ElementCompute(accum[m][n]) + beta * c_ref); -+ } -+ } // for (n) -+ -+ } -+ } // for (m) -+} -+ 
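// Editor's note: a compact host-side sketch (not part of the patch) of the same NHWC
// Conv2d fprop computation the kernels above perform per thread:
//   y[n,p,q,k] = sum over (r,s,c) of x[n, p*stride_h - pad_h + r*dil_h,
//                                        q*stride_w - pad_w + s*dil_w, c] * w[k,r,s,c],
// with out-of-bounds activations treated as zero. Dense std::vector buffers are assumed;
// grouped convolution, the flipped kConvolution mode, and alpha/beta scaling are omitted.
#include <cstddef>
#include <vector>

struct Conv2dShape { int N, H, W, C, K, R, S, P, Q, pad_h, pad_w, stride_h, stride_w, dil_h, dil_w; };

void conv2d_fprop_ref(const Conv2dShape& s,
                      const std::vector<float>& x,   // N*H*W*C, NHWC
                      const std::vector<float>& w,   // K*R*S*C, KRSC
                      std::vector<float>& y) {       // N*P*Q*K, NPQK
  auto X = [&](int n, int h, int wi, int c) { return x[((std::size_t(n) * s.H + h) * s.W + wi) * s.C + c]; };
  auto W = [&](int k, int r, int t, int c) { return w[((std::size_t(k) * s.R + r) * s.S + t) * s.C + c]; };
  for (int n = 0; n < s.N; ++n)
    for (int p = 0; p < s.P; ++p)
      for (int q = 0; q < s.Q; ++q)
        for (int k = 0; k < s.K; ++k) {
          float acc = 0.f;
          for (int r = 0; r < s.R; ++r)
            for (int t = 0; t < s.S; ++t)
              for (int c = 0; c < s.C; ++c) {
                int h  = p * s.stride_h - s.pad_h + r * s.dil_h;
                int wi = q * s.stride_w - s.pad_w + t * s.dil_w;
                if (h >= 0 && h < s.H && wi >= 0 && wi < s.W)
                  acc += X(n, h, wi, c) * W(k, r, t, c);   // zero contribution from the padding region
              }
          y[((std::size_t(n) * s.P + p) * s.Q + q) * s.K + k] = acc;
        }
}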
-+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conv2d dgrad kernel - dx = dgrad(dy, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add, -+ int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension -+ int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension -+ int kCtaShapeM = 16, // shape of a threadblock in units of threads -+ int kCtaShapeN = 8 // shape of a threadblock in units of threads -+> -+__global__ void Conv2dDgrad( -+ conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_w, -+ TensorRef tensor_dx_in, -+ TensorRef tensor_dx_out, -+ ElementCompute alpha, -+ ElementCompute beta -+ ) { -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ ElementAccumulator element_A[kThreadM]; -+ ElementAccumulator element_B[kThreadN]; -+ ElementAccumulator accum[kThreadM][kThreadN]; -+ -+ int64_t nhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; -+ int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; -+ -+ int thread_n[kThreadM]; -+ int thread_h[kThreadM]; -+ int thread_w[kThreadM]; -+ -+ // Compute N, H, W coordinates for each row of a thread's tile -+ int64_t HW = int64_t(problem_size.H) * problem_size.W; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ -+ int64_t nhw = nhw_start + m; -+ -+ thread_n[m] = int(nhw / HW); -+ -+ int64_t residual = nhw % HW; -+ thread_h[m] = int(residual / problem_size.W); -+ thread_w[m] = int(residual % problem_size.W); -+ } -+ -+ // Clear accumulators -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = ElementAccumulator(); -+ } -+ } -+ -+ // Compute convolution -+ for (int R = 0; R < problem_size.R; ++R) { -+ for (int S = 0; S < problem_size.S; ++S) { -+ for (int K = 0; K < problem_size.K; ++K) { -+ -+ // Load from activations tensor -+ int filter_r = R; -+ int filter_s = S; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_r = problem_size.R - 1 - R; -+ filter_s = problem_size.S - 1 - S; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ -+ int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h; -+ int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w; -+ -+ element_A[m] = ElementAccumulator(); -+ -+ if (p >= 0 && !(p % problem_size.stride_h) && q >= 0 && !(q % problem_size.stride_w)) { -+ -+ p = p / problem_size.stride_h; -+ q = q / problem_size.stride_w; -+ -+ if (thread_n[m] < problem_size.N && p < problem_size.P && q < problem_size.Q) { -+ element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], p, q, K})); -+ } -+ } -+ } -+ -+ // Load from filters tensor -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ int thread_c = c_start + n; -+ -+ if (thread_c < problem_size.C) { -+ element_B[n] = ElementAccumulator(tensor_w.at({K, R, S, thread_c})); -+ } -+ else { -+ element_B[n] = ElementAccumulator(); -+ } -+ } -+ -+ // Accumulate matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = 
inner_product_op(element_A[m], element_B[n], accum[m][n]); -+ } -+ } -+ } -+ } -+ } -+ -+ // Write out the results -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ -+ if (thread_n[m] < problem_size.N && thread_h[m] < problem_size.H && thread_w[m] < problem_size.W) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ int thread_c = c_start + n; -+ if (thread_c < problem_size.C) { -+ -+ ElementCompute c_ref = ElementCompute(); -+ if (beta != ElementCompute()) { -+ c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_h[m], thread_w[m], thread_c})); -+ } -+ -+ tensor_dx_out.at({thread_n[m], thread_h[m], thread_w[m], thread_c}) = convert_op( -+ alpha * ElementCompute(accum[m][n]) + beta * c_ref); -+ } -+ } -+ } -+ } -+} -+ -+// Conv3d dgrad kernel - dx = dgrad(dy, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add, -+ int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension -+ int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension -+ int kCtaShapeM = 16, // shape of a threadblock in units of threads -+ int kCtaShapeN = 8 // shape of a threadblock in units of threads -+> -+__global__ void Conv3dDgrad( -+ conv::Conv3dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_w, -+ TensorRef tensor_dx_in, -+ TensorRef tensor_dx_out, -+ ElementCompute alpha, -+ ElementCompute beta -+ ) { -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ ElementAccumulator element_A[kThreadM]; -+ ElementAccumulator element_B[kThreadN]; -+ ElementAccumulator accum[kThreadM][kThreadN]; -+ -+ int64_t ndhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; -+ int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; -+ -+ int thread_n[kThreadM]; -+ int thread_d[kThreadM]; -+ int thread_h[kThreadM]; -+ int thread_w[kThreadM]; -+ -+ // Compute N, H, W coordinates for each row of a thread's tile -+ int64_t HW = int64_t(problem_size.H) * problem_size.W; -+ int64_t DHW = HW * problem_size.D; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ -+ int64_t ndhw = ndhw_start + m; -+ -+ thread_n[m] = int(ndhw / DHW); -+ -+ int64_t residual = ndhw % DHW; -+ thread_d[m] = int(residual / HW); -+ -+ residual = residual % HW; -+ thread_h[m] = int(residual / problem_size.W); -+ thread_w[m] = int(residual % problem_size.W); -+ } -+ -+ // Clear accumulators -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = ElementAccumulator(); -+ } -+ } -+ -+ // Compute convolution -+ for (int T = 0; T < problem_size.T; ++T) { -+ for (int R = 0; R < problem_size.R; ++R) { -+ for (int S = 0; S < problem_size.S; ++S) { -+ for (int K = 0; K < problem_size.K; ++K) { -+ -+ // Load from activations tensor -+ int filter_t = T; -+ int filter_r = R; -+ int filter_s = S; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_t = problem_size.T - 1 - T; -+ filter_r = problem_size.R - 1 - R; -+ filter_s = problem_size.S - 1 - S; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ -+ int z = thread_d[m] + problem_size.pad_d - filter_t * problem_size.dilation_d; -+ int p = thread_h[m] + problem_size.pad_h - 
filter_r * problem_size.dilation_h; -+ int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w; -+ -+ element_A[m] = ElementAccumulator(); -+ -+ if (z >= 0 && !(z % problem_size.stride_d) && -+ p >= 0 && !(p % problem_size.stride_h) && -+ q >= 0 && !(q % problem_size.stride_w)) { -+ -+ z = z / problem_size.stride_d; -+ p = p / problem_size.stride_h; -+ q = q / problem_size.stride_w; -+ -+ if (thread_n[m] < problem_size.N && z < problem_size.Z && p < problem_size.P && q < problem_size.Q) { -+ element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], z, p, q, K})); -+ } -+ } -+ } -+ -+ // Load from filters tensor -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ int thread_c = c_start + n; -+ -+ if (thread_c < problem_size.C) { -+ element_B[n] = ElementAccumulator(tensor_w.at({K, T, R, S, thread_c})); -+ } -+ else { -+ element_B[n] = ElementAccumulator(); -+ } -+ } -+ -+ // Accumulate matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); -+ } -+ } -+ -+ } // for (C) -+ } // for (S) -+ } // for (R) -+ } // for (T) -+ -+ // Write out the results -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ -+ if (thread_n[m] < problem_size.N && -+ thread_d[m] < problem_size.D && -+ thread_h[m] < problem_size.H && -+ thread_w[m] < problem_size.W) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ int thread_c = c_start + n; -+ if (thread_c < problem_size.C) { -+ -+ ElementCompute c_ref = ElementCompute(); -+ if (beta != ElementCompute()) { -+ c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c})); -+ } -+ -+ tensor_dx_out.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c}) = convert_op( -+ alpha * ElementCompute(accum[m][n]) + beta * c_ref); -+ } -+ } -+ } -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conv2d wgrad kernel - dw = wgrad(dy, x) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add, -+ int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension -+ int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension -+ int kCtaShapeM = 8, // shape of a threadblock in units of threads -+ int kCtaShapeN = 16 // shape of a threadblock in units of threads -+> -+__global__ void Conv2dWgrad( -+ conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_x, -+ TensorRef tensor_dw_in, -+ TensorRef tensor_dw_out, -+ ElementCompute alpha, -+ ElementCompute beta -+ ) { -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ ElementAccumulator element_A[kThreadM]; -+ ElementAccumulator element_B[kThreadN]; -+ ElementAccumulator accum[kThreadM][kThreadN]; -+ -+ int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; -+ int64_t rsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; -+ -+ int thread_r[kThreadN]; -+ int thread_s[kThreadN]; -+ int thread_c[kThreadN]; -+ -+ // Compute R, S, C coordinates for each row of a thread's tile -+ int64_t SC = int64_t(problem_size.S) * problem_size.C; -+ 
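// Editor's note: a small sketch (names illustrative, not part of the patch) of the index
// test used by the dgrad kernels above. For dx[n, h, w, c], a filter tap r contributes
// dy[n, p, q, k] only if h + pad_h - r*dil_h is nonnegative, divisible by stride_h, and
// the resulting p lies inside the output extent (and likewise for q and, in 3D, for z).
// The real kernels fold this test into the element_A load.
#include <optional>

std::optional<int> dgrad_output_index(int h, int pad, int r, int dilation, int stride, int extent_P) {
  const int num = h + pad - r * dilation;
  if (num < 0 || num % stride != 0) return std::nullopt;  // tap does not align with an output sample
  const int p = num / stride;
  if (p >= extent_P) return std::nullopt;                 // aligned, but outside the output tensor
  return p;
}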
-+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ -+ int64_t rsc = rsc_start + n; -+ int64_t residual = rsc % SC; -+ -+ thread_r[n] = int(rsc / SC); -+ thread_s[n] = int(residual / problem_size.C); -+ thread_c[n] = int(residual % problem_size.C); -+ } -+ -+ // Clear accumulators -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = ElementAccumulator(); -+ } -+ } -+ -+ // Compute convolution -+ for (int N = 0; N < problem_size.N; ++N) { -+ for (int P = 0; P < problem_size.P; ++P) { -+ for (int Q = 0; Q < problem_size.Q; ++Q) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ int thread_k = k_start + m; -+ -+ element_A[m] = ElementAccumulator(); -+ -+ if (thread_k < problem_size.K) { -+ element_A[m] = ElementAccumulator(tensor_dy.at({N, P, Q, thread_k})); -+ } -+ } -+ -+ // Load from filters tensor -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ -+ // Load from activations tensor -+ int filter_r = thread_r[n]; -+ int filter_s = thread_s[n]; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_r = problem_size.R - 1 - filter_r; -+ filter_s = problem_size.S - 1 - filter_s; -+ } -+ -+ int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; -+ int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; -+ -+ element_B[n] = ElementAccumulator(); -+ -+ if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W && thread_c[n] < problem_size.C) { -+ element_B[n] = ElementAccumulator(tensor_x.at({N, h, w, thread_c[n]})); -+ } -+ } -+ -+ // Accumulate matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); -+ } -+ } -+ } -+ } -+ } -+ -+ // Write out the results -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ int thread_k = k_start + m; -+ -+ if (thread_k < problem_size.K) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ -+ if (thread_r[n] < problem_size.R && thread_s[n] < problem_size.S && thread_c[n] < problem_size.C) { -+ -+ ElementCompute c_ref = ElementCompute(); -+ -+ if (beta != ElementCompute()) { -+ c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_r[n], thread_s[n], thread_c[n]})); -+ } -+ -+ tensor_dw_out.at({thread_k, thread_r[n], thread_s[n], thread_c[n]}) = convert_op( -+ alpha * ElementCompute(accum[m][n]) + beta * c_ref); -+ } -+ } -+ } -+ } -+} -+ -+// Conv3d wgrad kernel - dw = wgrad(dy, x) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add, -+ int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension -+ int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension -+ int kCtaShapeM = 8, // shape of a threadblock in units of threads -+ int kCtaShapeN = 16 // shape of a threadblock in units of threads -+> -+__global__ void Conv3dWgrad( -+ conv::Conv3dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_x, -+ TensorRef tensor_dw_in, -+ TensorRef tensor_dw_out, -+ ElementCompute alpha, -+ ElementCompute beta -+ ) { -+ -+ ConvertOp convert_op; 
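// Editor's note: a minimal host-side sketch (not part of the patch) of the reduction the
// Conv2d wgrad kernel above performs for a single filter element in cross-correlation mode:
//   dw[k,r,s,c] = sum over (n,p,q) of dy[n,p,q,k] * x[n, p*stride_h - pad_h + r*dil_h,
//                                                       q*stride_w - pad_w + s*dil_w, c],
// skipping taps that land in the padding region (the kernel also handles the flipped
// kConvolution ordering). Accessor callables are assumed to wrap dense NPQK / NHWC buffers.
#include <functional>

float conv2d_wgrad_element(int k, int r, int s, int c,
                           int N, int P, int Q, int H, int W,
                           int pad_h, int pad_w, int stride_h, int stride_w, int dil_h, int dil_w,
                           const std::function<float(int, int, int, int)>& dy,  // dy(n, p, q, k)
                           const std::function<float(int, int, int, int)>& x) { // x(n, h, w, c)
  float acc = 0.f;
  for (int n = 0; n < N; ++n)
    for (int p = 0; p < P; ++p)
      for (int q = 0; q < Q; ++q) {
        const int h = p * stride_h - pad_h + r * dil_h;
        const int w = q * stride_w - pad_w + s * dil_w;
        if (h >= 0 && h < H && w >= 0 && w < W) acc += dy(n, p, q, k) * x(n, h, w, c);
      }
  return acc;
}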
-+ InnerProductOp inner_product_op; -+ -+ ElementAccumulator element_A[kThreadM]; -+ ElementAccumulator element_B[kThreadN]; -+ ElementAccumulator accum[kThreadM][kThreadN]; -+ -+ int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; -+ int64_t trsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; -+ -+ int thread_t[kThreadN]; -+ int thread_r[kThreadN]; -+ int thread_s[kThreadN]; -+ int thread_c[kThreadN]; -+ -+ // Compute R, S, C coordinates for each row of a thread's tile -+ int64_t SC = int64_t(problem_size.S) * problem_size.C; -+ int64_t RSC = SC * problem_size.R; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ -+ int64_t trsc = trsc_start + n; -+ -+ thread_t[n] = int(trsc / RSC); -+ -+ int64_t residual = trsc % RSC; -+ thread_r[n] = int(residual / SC); -+ -+ residual = residual % SC; -+ thread_s[n] = int(residual / problem_size.C); -+ thread_c[n] = int(residual % problem_size.C); -+ } -+ -+ // Clear accumulators -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = ElementAccumulator(); -+ } -+ } -+ -+ // Compute convolution -+ for (int N = 0; N < problem_size.N; ++N) { -+ for (int Z = 0; Z < problem_size.Z; ++Z) { -+ for (int P = 0; P < problem_size.P; ++P) { -+ for (int Q = 0; Q < problem_size.Q; ++Q) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ int thread_k = k_start + m; -+ -+ element_A[m] = ElementAccumulator(); -+ -+ if (thread_k < problem_size.K) { -+ element_A[m] = ElementAccumulator(tensor_dy.at({N, Z, P, Q, thread_k})); -+ } -+ } -+ -+ // Load from filters tensor -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ -+ // Load from activations tensor -+ int filter_t = thread_t[n]; -+ int filter_r = thread_r[n]; -+ int filter_s = thread_s[n]; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_t = problem_size.T - 1 - filter_t; -+ filter_r = problem_size.R - 1 - filter_r; -+ filter_s = problem_size.S - 1 - filter_s; -+ } -+ -+ int d = Z * problem_size.stride_d - problem_size.pad_w + filter_t * problem_size.dilation_d; -+ int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; -+ int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; -+ -+ element_B[n] = ElementAccumulator(); -+ -+ if (d >= 0 && d < problem_size.D && -+ h >= 0 && h < problem_size.H && -+ w >= 0 && w < problem_size.W && -+ thread_c[n] < problem_size.C) { -+ -+ element_B[n] = ElementAccumulator(tensor_x.at({N, d, h, w, thread_c[n]})); -+ } -+ } -+ -+ // Accumulate matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); -+ } -+ } -+ -+ } // for (Q) -+ } // for (P) -+ } // for (Z) -+ } // for (N) -+ -+ // Write out the results -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ int thread_k = k_start + m; -+ -+ if (thread_k < problem_size.K) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ -+ if (thread_t[n] < problem_size.T && -+ thread_r[n] < problem_size.R && -+ thread_s[n] < problem_size.S && -+ thread_c[n] < problem_size.C) { -+ -+ ElementCompute c_ref = ElementCompute(); -+ -+ if (beta != ElementCompute()) { -+ c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_t[n], thread_r[n], thread_s[n], 
thread_c[n]})); -+ } -+ -+ tensor_dw_out.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]}) = convert_op( -+ alpha * ElementCompute(accum[m][n]) + beta * c_ref); -+ } -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Conv2d Fprop dispatcher - y = fprop(x, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+Status Conv2dFprop( -+ conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_x, -+ TensorRef tensor_w, -+ TensorRef tensor_y_in, -+ TensorRef tensor_y_out, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cudaStream_t stream = nullptr) { -+ -+ // -+ // Blocking factors improve performance of reference implementation -+ // -+ -+ int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension -+ int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension -+ int const kCtaShapeM = 16; // shape of a threadblock in units of threads -+ int const kCtaShapeN = 8; // shape of a threadblock in units of threads -+ -+ int64_t npq = int64_t(problem_size.N) * problem_size.P * problem_size.Q; -+ int64_t blocks_m = (npq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); -+ -+ dim3 block(kCtaShapeM, kCtaShapeN); -+ dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); -+ -+ kernel::Conv2dFprop< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp, -+ kThreadM, -+ kThreadN, -+ kCtaShapeM, -+ kCtaShapeN -+ ><<< grid, block, 0, stream >>>( -+ problem_size, -+ tensor_x, -+ tensor_w, -+ tensor_y_in, -+ tensor_y_out, -+ alpha, -+ beta -+ ); -+ -+ cudaError_t result = cudaPeekAtLastError(); -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ -+ return Status::kSuccess; -+} -+ -+/// Conv3d Fprop dispatcher - y = fprop(x, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+Status Conv3dFprop( -+ conv::Conv3dProblemSize problem_size, -+ TensorRef tensor_x, -+ TensorRef tensor_w, -+ TensorRef tensor_y_in, -+ TensorRef tensor_y_out, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cudaStream_t stream = nullptr) { -+ -+ // -+ // Blocking factors improve performance of reference implementation -+ // -+ -+ int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension -+ int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension -+ int const kCtaShapeM = 16; // shape of a threadblock in units of threads -+ int const kCtaShapeN = 8; // shape of a threadblock in units of threads -+ -+ int64_t nzpq = int64_t(problem_size.N) * problem_size.Z * problem_size.P * problem_size.Q; -+ int64_t blocks_m = (nzpq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); -+ -+ dim3 block(kCtaShapeM, kCtaShapeN); -+ dim3 grid(uint32_t(blocks_m), (problem_size.K + 
(kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); -+ -+ kernel::Conv3dFprop< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp, -+ kThreadM, -+ kThreadN, -+ kCtaShapeM, -+ kCtaShapeN -+ ><<< grid, block, 0, stream >>>( -+ problem_size, -+ tensor_x, -+ tensor_w, -+ tensor_y_in, -+ tensor_y_out, -+ alpha, -+ beta -+ ); -+ -+ cudaError_t result = cudaPeekAtLastError(); -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ -+ return Status::kSuccess; -+} -+ -+/// Conv2d Dgrad dispatcher - dx = dgrad(dy, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+Status Conv2dDgrad( -+ conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_w, -+ TensorRef tensor_dx_in, -+ TensorRef tensor_dx_out, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cudaStream_t stream = nullptr) { -+ -+ // -+ // Blocking factors improve performance of reference implementation -+ // -+ -+ int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension -+ int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension -+ int const kCtaShapeM = 16; // shape of a threadblock in units of threads -+ int const kCtaShapeN = 8; // shape of a threadblock in units of threads -+ -+ int64_t nhw = int64_t(problem_size.N) * problem_size.H * problem_size.W; -+ int64_t blocks_m = (nhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); -+ -+ dim3 block(kCtaShapeM, kCtaShapeN); -+ dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); -+ -+ kernel::Conv2dDgrad< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp, -+ kThreadM, -+ kThreadN, -+ kCtaShapeM, -+ kCtaShapeN -+ ><<< grid, block, 0, stream >>>( -+ problem_size, -+ tensor_dy, -+ tensor_w, -+ tensor_dx_in, -+ tensor_dx_out, -+ alpha, -+ beta -+ ); -+ -+ cudaError_t result = cudaPeekAtLastError(); -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ -+ return Status::kSuccess; -+} -+ -+/// Conv3d Dgrad dispatcher - dx = dgrad(dy, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+Status Conv3dDgrad( -+ conv::Conv3dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_w, -+ TensorRef tensor_dx_in, -+ TensorRef tensor_dx_out, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cudaStream_t stream = nullptr) { -+ -+ // -+ // Blocking factors improve performance of reference implementation -+ // -+ -+ int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension -+ int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension -+ int const kCtaShapeM = 16; // shape of a threadblock in units of threads -+ int const kCtaShapeN = 8; // shape of a threadblock in units of threads -+ -+ int64_t ndhw = int64_t(problem_size.N) * problem_size.D * problem_size.H * problem_size.W; -+ int64_t blocks_m 
= (ndhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); -+ -+ dim3 block(kCtaShapeM, kCtaShapeN); -+ dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); -+ -+ kernel::Conv3dDgrad< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp, -+ kThreadM, -+ kThreadN, -+ kCtaShapeM, -+ kCtaShapeN -+ ><<< grid, block, 0, stream >>>( -+ problem_size, -+ tensor_dy, -+ tensor_w, -+ tensor_dx_in, -+ tensor_dx_out, -+ alpha, -+ beta -+ ); -+ -+ cudaError_t result = cudaPeekAtLastError(); -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ -+ return Status::kSuccess; -+} -+ -+/// Conv2d Wgrad dispatcher - dw = wgrad(dy, x) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+Status Conv2dWgrad( -+ conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_x, -+ TensorRef tensor_dw_in, -+ TensorRef tensor_dw_out, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cudaStream_t stream = nullptr) { -+ -+ // -+ // Blocking factors improve performance of reference implementation -+ // -+ -+ int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension -+ int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension -+ int const kCtaShapeM = 8; // shape of a threadblock in units of threads -+ int const kCtaShapeN = 16; // shape of a threadblock in units of threads -+ -+ int64_t rsc = int64_t(problem_size.R) * problem_size.S * problem_size.C; -+ int64_t blocks_n = (rsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN); -+ -+ dim3 block(kCtaShapeM, kCtaShapeN); -+ dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n)); -+ -+ kernel::Conv2dWgrad< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp, -+ kThreadM, -+ kThreadN, -+ kCtaShapeM, -+ kCtaShapeN -+ ><<< grid, block, 0, stream >>>( -+ problem_size, -+ tensor_dy, -+ tensor_x, -+ tensor_dw_in, -+ tensor_dw_out, -+ alpha, -+ beta -+ ); -+ -+ cudaError_t result = cudaPeekAtLastError(); -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ -+ return Status::kSuccess; -+} -+ -+/// Conv3d Wgrad dispatcher - dw = wgrad(dy, x) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+Status Conv3dWgrad( -+ conv::Conv3dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_x, -+ TensorRef tensor_dw_in, -+ TensorRef tensor_dw_out, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cudaStream_t stream = nullptr) { -+ -+ // -+ // Blocking factors improve performance of reference implementation -+ // -+ -+ int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension -+ int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension -+ int const kCtaShapeM = 8; // shape of a threadblock in units of threads -+ int const kCtaShapeN = 16; 
// shape of a threadblock in units of threads -+ -+ int64_t trsc = int64_t(problem_size.T) * problem_size.R * problem_size.S * problem_size.C; -+ int64_t blocks_n = (trsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN); -+ -+ dim3 block(kCtaShapeM, kCtaShapeN); -+ dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n)); -+ -+ kernel::Conv3dWgrad< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp, -+ kThreadM, -+ kThreadN, -+ kCtaShapeM, -+ kCtaShapeN -+ ><<< grid, block, 0, stream >>>( -+ problem_size, -+ tensor_dy, -+ tensor_x, -+ tensor_dw_in, -+ tensor_dw_out, -+ alpha, -+ beta -+ ); -+ -+ cudaError_t result = cudaPeekAtLastError(); -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ -+ return Status::kSuccess; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+Status Conv2d( -+ conv::Operator convolutional_operator, -+ conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_A, -+ TensorRef tensor_B, -+ TensorRef tensor_C, -+ TensorRef tensor_D, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cudaStream_t stream = nullptr) { -+ -+ switch (convolutional_operator) { -+ case conv::Operator::kFprop: -+ return Conv2dFprop< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); -+ break; -+ -+ case conv::Operator::kDgrad: -+ return Conv2dDgrad< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); -+ break; -+ -+ case conv::Operator::kWgrad: -+ return Conv2dWgrad< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); -+ break; -+ -+ default: break; -+ } -+ -+ return Status::kErrorNotSupported; -+} -+ -+/// Generic 3D convolution targeting Conv3dFprop, Conv3dDgrad, and Conv3dWgrad. 
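// Editor's note: a short sketch (not part of the patch) of the launch-shape arithmetic the
// dispatchers above share: each thread computes a kThreadM x kThreadN tile, a threadblock is
// kCtaShapeM x kCtaShapeN threads, and the grid is the ceiling division of the two GEMM-like
// extents by the corresponding block footprints.
#include <cstdint>

struct LaunchShape { uint32_t grid_x, grid_y, block_x, block_y; };

LaunchShape reference_conv_launch_shape(int64_t gemm_m,   // e.g. N*P*Q for Conv2d fprop
                                         int64_t gemm_n,   // e.g. K for Conv2d fprop
                                         int kThreadM, int kThreadN,
                                         int kCtaShapeM, int kCtaShapeN) {
  const int64_t tile_m = int64_t(kCtaShapeM) * kThreadM;
  const int64_t tile_n = int64_t(kCtaShapeN) * kThreadN;
  return {uint32_t((gemm_m + tile_m - 1) / tile_m),
          uint32_t((gemm_n + tile_n - 1) / tile_n),
          uint32_t(kCtaShapeM), uint32_t(kCtaShapeN)};
}
// For the Conv2dFprop dispatcher above (kThreadM=4, kThreadN=4, kCtaShapeM=16, kCtaShapeN=8),
// gemm_m = N*P*Q is tiled in chunks of 64 rows and gemm_n = K in chunks of 32 columns.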
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+Status Conv3d( -+ conv::Operator convolutional_operator, -+ conv::Conv3dProblemSize problem_size, -+ TensorRef tensor_A, -+ TensorRef tensor_B, -+ TensorRef tensor_C, -+ TensorRef tensor_D, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cudaStream_t stream = nullptr) { -+ -+ switch (convolutional_operator) { -+ case conv::Operator::kFprop: -+ return Conv3dFprop< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); -+ -+ case conv::Operator::kDgrad: -+ return Conv3dDgrad< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); -+ -+ case conv::Operator::kWgrad: -+ return Conv3dWgrad< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); -+ -+ default: break; -+ } -+ -+ return Status::kErrorNotSupported; -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reference -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h -new file mode 100644 -index 0000000..1850c2f ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h -@@ -0,0 +1,385 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for GEMM in device-side code. -+*/ -+ -+#pragma once -+ -+#include "cutlass/coord.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/util/reference/device/kernel/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// Explicitly naming types needed by this template can be cumbersome, particularly for the -+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -+/// AccumulatorType(0) as the last function argument can be easier than naming all template -+/// arguments explicitly. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename AccumulatorType, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+void compute_gemm( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ AccumulatorType initial_accum) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ // Blocking structure potentially improves performance of reference implementation -+ // with a minor increase in complexity. -+ // -+ // Note, this reference implementation is NOT expected to approach peak performance. -+ using OutputTile = MatrixShape<4, 4>; -+ -+ dim3 block(16, 8); -+ -+ dim3 grid( -+ (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), -+ (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn) -+ ); -+ -+ // Launch a GEMM kernel -+ kernel::Gemm< -+ TensorRef, -+ TensorRef, -+ TensorRef, -+ ScalarType, -+ AccumulatorType, -+ OutputTile, -+ InnerProductOp, -+ ConvertOp -+ ><<< grid, block >>>( -+ problem_size, -+ alpha, -+ tensor_a, -+ tensor_b, -+ beta, -+ tensor_c, -+ tensor_d, -+ initial_accum -+ ); -+} -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// This assumes the accumulator type is the same type as the scalars. 
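As a usage note for the compute_gemm entry point above: the call below is a minimal host-side sketch, with illustrative element types, sizes, and the cutlass::HostTensor helper chosen only for the example; the Element/Layout template arguments are deduced from the TensorRef arguments. The overload declared next omits tensor_d and writes the result back into tensor_c.

#include <cuda_runtime.h>

#include "cutlass/layout/matrix.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"

void run_reference_gemm(int m, int n, int k) {
  using Element = float;
  using Layout  = cutlass::layout::ColumnMajor;

  cutlass::gemm::GemmCoord problem(m, n, k);

  cutlass::HostTensor<Element, Layout> A(cutlass::MatrixCoord(m, k));
  cutlass::HostTensor<Element, Layout> B(cutlass::MatrixCoord(k, n));
  cutlass::HostTensor<Element, Layout> C(cutlass::MatrixCoord(m, n));
  cutlass::HostTensor<Element, Layout> D(cutlass::MatrixCoord(m, n));

  // D = 1.0f * A * B + 0.0f * C, accumulating in float starting from zero.
  cutlass::reference::device::compute_gemm(
      problem,
      1.0f,                 // alpha (ScalarType)
      A.device_ref(),
      B.device_ref(),
      0.0f,                 // beta
      C.device_ref(),
      D.device_ref(),
      0.0f);                // initial_accum (AccumulatorType)

  // The launch is asynchronous; synchronize before reading D on the host.
  cudaDeviceSynchronize();
  D.sync_host();
}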
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename AccumulatorType, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+void compute_gemm( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ AccumulatorType initial_accum) { -+ -+ compute_gemm( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, -+ initial_accum); -+} -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename AccumulatorType, -+ typename InnerProductOp = cutlass::arch::OpMultiplyAdd -+> -+struct Gemm; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for multiply-add -+template -+struct Gemm { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ AccumulatorType initial_accum = AccumulatorType(0)) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); -+ } -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ AccumulatorType initial_accum = AccumulatorType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for multiply-add-saturate -+template -+struct Gemm { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ AccumulatorType initial_accum = AccumulatorType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm, -+ NumericConverterClamp>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); -+ } -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ AccumulatorType initial_accum = AccumulatorType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm, -+ NumericConverterClamp>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for XOR-popc -+template -+struct Gemm { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ AccumulatorType initial_accum = AccumulatorType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && 
LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); -+ } -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ AccumulatorType initial_accum = AccumulatorType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Batched GEMM -+// -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a batch of GEMMs over a set of matrices of common dimension. -+// -+// TensorRefCollection* is a type satisfying the TensorRefCollection concept. -+// -+template < -+ typename TensorRefCollectionA, -+ typename TensorRefCollectionB, -+ typename TensorRefCollectionC, -+ typename ScalarType, -+ typename AccumulatorType, -+ typename InnerProductOp, -+ typename ConvertOp -+> -+void BatchedGemm( -+ gemm::GemmCoord problem_size, -+ int batch_count, -+ ScalarType alpha, -+ TensorRefCollectionA const& tensor_a, -+ TensorRefCollectionB const& tensor_b, -+ ScalarType beta, -+ TensorRefCollectionC &tensor_c, -+ AccumulatorType initial_accum) { -+ -+ static_assert( -+ TensorRefCollectionA::kRank == 2 && -+ TensorRefCollectionB::kRank == 2 && -+ TensorRefCollectionC::kRank == 2, "Tensors must be of rank 2"); -+ -+ // Blocking structure potentially improves performance of reference implementation -+ // with a minor increase in complexity. -+ // -+ // Note, this reference implementation is NOT expected to approach peak performance. -+ using OutputTile = MatrixShape<4, 4>; -+ -+ dim3 block(16, 8); -+ dim3 grid( -+ (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), -+ (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn), -+ batch_count -+ ); -+ -+ // Launch a GEMM kernel -+ kernel::BatchedGemm< -+ TensorRefCollectionA, -+ TensorRefCollectionB, -+ TensorRefCollectionC, -+ ScalarType, -+ AccumulatorType, -+ OutputTile, -+ InnerProductOp, -+ ConvertOp -+ ><<< grid, block >>>( -+ problem_size, -+ alpha, -+ tensor_a, -+ tensor_b, -+ beta, -+ tensor_c, -+ initial_accum -+ ); -+} -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+// -+// TensorRefCollection* is a type satisfying the TensorRefCollection concept. 
-+// -+template < -+ typename TensorRefCollectionA, -+ typename TensorRefCollectionB, -+ typename TensorRefCollectionC, -+ typename ScalarType, -+ typename AccumulatorType -+> -+void BatchedGemm( -+ gemm::GemmCoord problem_size, -+ int batch_count, -+ ScalarType alpha, -+ TensorRefCollectionA const& tensor_a, -+ TensorRefCollectionB const& tensor_b, -+ ScalarType beta, -+ TensorRefCollectionC &tensor_c) { -+ -+ BatchedGemm(problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h -new file mode 100644 -index 0000000..0f3977b ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h -@@ -0,0 +1,345 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for complex-valued GEMM in device-side code. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace kernel { -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// Explicitly naming types needed by this template can be cumbersome, particularly for the -+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -+/// AccumulatorType(0) as the last function argument can be easier than naming all template -+/// arguments explicitly. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add, -+ int kMblock = 4, -+ int kNblock = 4 -+> -+__global__ void GemmComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ TensorRef tensor_b, -+ ComplexTransform transform_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum, -+ int batch_count = 1, -+ int64_t batch_stride_A = 0, -+ int64_t batch_stride_B = 0, -+ int64_t batch_stride_C = 0, -+ int64_t batch_stride_D = 0) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ int const K = problem_size.k(); -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; -+ int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; -+ int batch_idx = blockIdx.z; -+ -+ tensor_a.add_pointer_offset(batch_idx * batch_stride_A); -+ tensor_b.add_pointer_offset(batch_idx * batch_stride_B); -+ tensor_c.add_pointer_offset(batch_idx * batch_stride_C); -+ tensor_d.add_pointer_offset(batch_idx * batch_stride_D); -+ -+ for (; batch_idx < batch_count; batch_idx += gridDim.z) { -+ -+ // Compute matrix product using blocks -+ ComputeType accum[kMblock][kNblock]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < kNblock; j++) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kMblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < kNblock; j++) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kMblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N) { -+ ElementA a = tensor_a.at(MatrixCoord(row, k_block)); -+ ElementB b = tensor_b.at(MatrixCoord(k_block, col)); -+ -+ ComputeType a_ik = ComputeType(a); -+ ComputeType b_kj = ComputeType(b); -+ -+ if (transform_a == ComplexTransform::kConjugate) { -+ a_ik = conj(a_ik); -+ } -+ -+ if (transform_b == ComplexTransform::kConjugate) { -+ b_kj = conj(b_kj); -+ } -+ -+ accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < kNblock; j++) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kMblock; i++) { -+ int row = 
row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N) { -+ -+ tensor_d.at(coord) = convert_op( -+ alpha * ScalarType(accum[i][j]) + -+ beta * ScalarType(tensor_c.at(coord))); -+ } -+ } -+ } -+ -+ tensor_a.add_pointer_offset(batch_stride_A * gridDim.z); -+ tensor_b.add_pointer_offset(batch_stride_B * gridDim.z); -+ tensor_c.add_pointer_offset(batch_stride_C * gridDim.z); -+ tensor_d.add_pointer_offset(batch_stride_D * gridDim.z); -+ -+ } // for (batch_idx) -+} -+ -+} // namespace kernel -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// Explicitly naming types needed by this template can be cumbersome, particularly for the -+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -+/// AccumulatorType(0) as the last function argument can be easier than naming all template -+/// arguments explicitly. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void GemmComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ TensorRef tensor_b, -+ ComplexTransform transform_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum, -+ int batch_count = 1, -+ int64_t batch_stride_A = 0, -+ int64_t batch_stride_B = 0, -+ int64_t batch_stride_C = 0, -+ int64_t batch_stride_D = 0) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ int const kMblock = 4; -+ int const kNblock = 4; -+ -+ dim3 block(16, 8); -+ dim3 grid( -+ (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), -+ (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), -+ batch_count % std::numeric_limits::max() -+ ); -+ -+ if (grid.y <= std::numeric_limits::max()) { -+ kernel::GemmComplex< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ScalarType, -+ ComputeType, -+ ConvertOp, -+ InnerProductOp, -+ kMblock, -+ kNblock -+ ><<< grid, block >>>( -+ problem_size, -+ alpha, -+ tensor_a, -+ transform_a, -+ tensor_b, -+ transform_b, -+ beta, -+ tensor_c, -+ tensor_d, -+ initial_accum, -+ batch_count, -+ batch_stride_A, -+ batch_stride_B, -+ batch_stride_C, -+ batch_stride_D -+ ); -+ } else { -+ // Using bigger thread tile size -+ int const kBigMblock = 4; -+ int const kBigNblock = 16; -+ -+ dim3 Bigblock(16, 8); -+ dim3 Biggrid( -+ (problem_size.m() + block.x * kBigMblock - 1) / (block.x * kBigMblock), -+ (problem_size.n() + block.y * kBigNblock - 1) / (block.y * kBigNblock), -+ batch_count % std::numeric_limits::max() -+ ); -+ -+ kernel::GemmComplex< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ScalarType, -+ ComputeType, -+ ConvertOp, -+ InnerProductOp, -+ kBigMblock, -+ kBigNblock -+ ><<< Biggrid, Bigblock >>>( -+ problem_size, -+ alpha, -+ tensor_a, -+ transform_a, -+ tensor_b, -+ transform_b, -+ beta, -+ tensor_c, -+ tensor_d, -+ initial_accum, -+ batch_count, -+ batch_stride_A, -+ batch_stride_B, -+ batch_stride_C, -+ batch_stride_D -+ ); -+ } -+} -+ 
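For reference, a host-side sketch of invoking the GemmComplex wrapper above (which launches the kernel directly and, when the grid's second dimension would exceed the elided numeric limit, presumably the 16-bit gridDim.y bound, retries with a taller thread tile). The element types, sizes, and the cutlass::HostTensor helper are assumptions made for the example; the Element/Layout template arguments are deduced from the TensorRef arguments.

#include "cutlass/complex.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm_complex.h"

void run_reference_gemm_complex(int m, int n, int k) {
  using Element = cutlass::complex<float>;
  using Layout  = cutlass::layout::ColumnMajor;

  cutlass::gemm::GemmCoord problem(m, n, k);

  cutlass::HostTensor<Element, Layout> A(cutlass::MatrixCoord(m, k));
  cutlass::HostTensor<Element, Layout> B(cutlass::MatrixCoord(k, n));
  cutlass::HostTensor<Element, Layout> C(cutlass::MatrixCoord(m, n));
  cutlass::HostTensor<Element, Layout> D(cutlass::MatrixCoord(m, n));

  Element alpha(1.0f), beta(0.0f);

  // D = alpha * A * B + beta * C with no conjugation applied to either operand.
  cutlass::reference::device::GemmComplex(
      problem,
      alpha,
      A.device_ref(), cutlass::ComplexTransform::kNone,
      B.device_ref(), cutlass::ComplexTransform::kNone,
      beta,
      C.device_ref(),
      D.device_ref(),
      Element());           // initial_accum (ComputeType)
}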
-+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// This assumes the accumulator type is the same type as the scalars. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType -+> -+void GemmComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ TensorRef tensor_b, -+ ComplexTransform transform_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d) { -+ -+ GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h -new file mode 100644 -index 0000000..baab696 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h -@@ -0,0 +1,311 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for complex-valued GEMM in device code. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_ref_planar_complex.h" -+ -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static int const kGemmPlanarComplexBlockSize = 4; -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add> -+> -+__global__ void GemmPlanarComplex( -+ gemm::GemmCoord problem_size, -+ complex alpha, -+ TensorRefPlanarComplex tensor_a, -+ ComplexTransform transform_a, -+ TensorRefPlanarComplex tensor_b, -+ ComplexTransform transform_b, -+ complex beta, -+ TensorRefPlanarComplex tensor_c, -+ TensorRefPlanarComplex tensor_d, -+ complex initial_accum) { -+ -+ int const kMblock = kGemmPlanarComplexBlockSize; -+ int const kNblock = kGemmPlanarComplexBlockSize; -+ -+ using ComplexA = typename TensorRefPlanarComplex::ComplexElement; -+ using ComplexB = typename TensorRefPlanarComplex::ComplexElement; -+ using ComplexC = typename TensorRefPlanarComplex::ComplexElement; -+ -+ // Note: batch is ignored. -+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ int const K = problem_size.k(); -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ complex accum[kMblock][kNblock]; -+ -+ int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; -+ int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < kNblock; j++) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kMblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int k_block = 0; k_block < K; ++k_block) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < kNblock; j++) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kMblock; i++) { -+ -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N) { -+ -+ ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block)); -+ ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col)); -+ -+ complex a = complex{ -+ ComputeType(a_ik.real()), -+ ComputeType(a_ik.imag()) -+ }; -+ -+ complex b = complex{ -+ ComputeType(b_kj.real()), -+ ComputeType(b_kj.imag()) -+ }; -+ -+ if (transform_a == ComplexTransform::kConjugate) { -+ a = conj(a); -+ } -+ -+ if (transform_b == ComplexTransform::kConjugate) { -+ b = conj(b); -+ } -+ -+ accum[i][j] = inner_product_op(a, b, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < kNblock; j++) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kMblock; i++) { -+ -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N) { -+ -+ complex acc{ -+ ScalarType(accum[i][j].real()), -+ ScalarType(accum[i][j].imag()) -+ }; -+ -+ ComplexC c_ij = ComplexC(); -+ -+ if (beta.real() != ScalarType() || beta.imag() != ScalarType()) { -+ c_ij = 
tensor_c.at(coord); -+ } -+ -+ complex src{ -+ ScalarType(c_ij.real()), -+ ScalarType(c_ij.imag()) -+ }; -+ -+ complex result = alpha * acc + beta * src; -+ -+ ComplexC d_ij; -+ -+ d_ij.real() = convert_op(result.real()); -+ d_ij.imag() = convert_op(result.imag()); -+ -+ tensor_d.at(coord) = d_ij; -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// Explicitly naming types needed by this template can be cumbersome, particularly for the -+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -+/// AccumulatorType(0) as the last function argument can be easier than naming all template -+/// arguments explicitly. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add> -+> -+void GemmPlanarComplex( -+ gemm::GemmCoord problem_size, -+ complex alpha, -+ TensorRefPlanarComplex tensor_a, -+ ComplexTransform transform_a, -+ TensorRefPlanarComplex tensor_b, -+ ComplexTransform transform_b, -+ complex beta, -+ TensorRefPlanarComplex tensor_c, -+ TensorRefPlanarComplex tensor_d, -+ complex initial_accum) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ int const kMblock = kernel::kGemmPlanarComplexBlockSize; -+ int const kNblock = kernel::kGemmPlanarComplexBlockSize; -+ -+ dim3 block(16, 8); -+ -+ dim3 grid( -+ (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), -+ (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), -+ 1); -+ -+ kernel::GemmPlanarComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ScalarType, -+ ComputeType, -+ ConvertOp, -+ InnerProductOp -+ ><<< grid, block >>>( -+ problem_size, -+ alpha, -+ tensor_a, -+ transform_a, -+ tensor_b, -+ transform_b, -+ beta, -+ tensor_c, -+ tensor_d, -+ initial_accum -+ ); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// This assumes the accumulator type is the same type as the scalars. 
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType -+> -+void GemmPlanarComplex( -+ gemm::GemmCoord problem_size, -+ complex alpha, -+ TensorRefPlanarComplex tensor_a, -+ ComplexTransform transform_a, -+ TensorRefPlanarComplex tensor_b, -+ ComplexTransform transform_b, -+ complex beta, -+ TensorRefPlanarComplex tensor_c, -+ TensorRefPlanarComplex tensor_d) { -+ -+ GemmPlanarComplex( -+ problem_size, -+ alpha, -+ tensor_a, transform_a, -+ tensor_b, transform_b, -+ beta, -+ tensor_c, -+ tensor_d, -+ complex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reference -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h -new file mode 100644 -index 0000000..e917765 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h -@@ -0,0 +1,162 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for GEMM in host-side code. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/coord.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/util/reference/device/thread/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+template < -+ typename TensorRefA, -+ typename TensorRefB, -+ typename TensorRefC, -+ typename ScalarType, -+ typename AccumulatorType, -+ typename OutputTile, -+ typename InnerProductOp, -+ typename ConvertOp -+> -+__global__ void Gemm( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRefA tensor_a, -+ TensorRefB tensor_b, -+ ScalarType beta, -+ TensorRefC tensor_c, -+ TensorRefC tensor_d, -+ AccumulatorType initial_accum) { -+ -+ // Map each thread to a unique tile of the output matrix -+ MatrixCoord output_coord( -+ MatrixCoord::Index((threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow), -+ MatrixCoord::Index((threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn) -+ ); -+ -+ // Compute the general matrix product -+ thread::Gemm< -+ TensorRefA, -+ TensorRefB, -+ TensorRefC, -+ ScalarType, -+ AccumulatorType, -+ OutputTile, -+ InnerProductOp, -+ ConvertOp -+ > gemm(initial_accum); -+ -+ gemm.multiply_add( -+ problem_size, -+ tensor_a, -+ tensor_b, -+ output_coord); -+ -+ gemm.epilogue(problem_size, alpha, beta, tensor_c, tensor_d, output_coord); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. 
-+template < -+ typename TensorRefCollectionA, -+ typename TensorRefCollectionB, -+ typename TensorRefCollectionC, -+ typename ScalarType, -+ typename AccumulatorType, -+ typename OutputTile, -+ typename InnerProductOp, -+ typename ConvertOp -+> -+__global__ void BatchedGemm( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRefCollectionA tensor_collection_a, -+ TensorRefCollectionB tensor_collection_b, -+ ScalarType beta, -+ TensorRefCollectionC tensor_collection_c, -+ AccumulatorType initial_accum) { -+ -+ // Obtain batch ID -+ int batch_id = blockIdx.z; -+ -+ // Dereference based on batch_id -+ typename TensorRefCollectionA::TensorRef tensor_a = tensor_collection_a.at(batch_id); -+ typename TensorRefCollectionB::TensorRef tensor_b = tensor_collection_b.at(batch_id); -+ typename TensorRefCollectionC::TensorRef tensor_c = tensor_collection_c.at(batch_id); -+ -+ // Map each thread to a unique tile of the output matrix -+ MatrixCoord output_coord( -+ (threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kColumn, -+ (threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kRow -+ ); -+ -+ // Compute the general matrix product -+ thread::Gemm< -+ typename TensorRefCollectionA::TensorRef, -+ typename TensorRefCollectionB::TensorRef, -+ typename TensorRefCollectionC::TensorRef, -+ ScalarType, -+ AccumulatorType, -+ OutputTile, -+ InnerProductOp, -+ ConvertOp -+ > gemm(initial_accum); -+ -+ gemm.multiply_add( -+ problem_size, -+ tensor_a, -+ tensor_b, -+ output_coord); -+ -+ gemm.epilogue(problem_size, alpha, beta, tensor_c, output_coord); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace device -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h -new file mode 100644 -index 0000000..4850b98 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h -@@ -0,0 +1,168 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Kernel to initialize tensor to uniform random distribution -+template -+__global__ void TensorInitializeUniform( -+ Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { -+ __shared__ curandState_t rng_state[1024]; -+ -+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; -+ -+ curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); -+ -+ int c_idx = blockIdx.x * blockDim.x + threadIdx.x; -+ int s_idx = blockIdx.y * blockDim.x; -+ -+ tensor += s_idx * ldm + c_idx; -+ -+ for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { -+ if (s_idx < dim_strided && c_idx < dim_contiguous) { -+ double range = dist.uniform.max - dist.uniform.min; -+ -+ double rnd = curand_uniform(&rng_state[threadIdx.x]); -+ -+ rnd = dist.uniform.min + range * rnd; -+ -+ // Random values are cast to integer after scaling by a power of two to facilitate error -+ // testing -+ if (dist.int_scale >= 0) { -+ rnd = double(int(rnd * double(1 << dist.int_scale))); -+ *tensor = T(rnd / double(1 << dist.int_scale)); -+ } else { -+ *tensor = T(rnd); -+ } -+ -+ tensor += ldm; -+ } -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Kernel to initialize tensor to uniform distribution -+template -+__global__ void TensorInitializeGaussian( -+ Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { -+ __shared__ curandState_t rng_state[1024]; -+ -+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; -+ -+ curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); -+ -+ int c_idx = blockIdx.x * blockDim.x + threadIdx.x; -+ int s_idx = blockIdx.y * blockDim.x; -+ -+ tensor += s_idx * ldm + c_idx; -+ -+ for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { -+ if (s_idx < dim_strided && c_idx < dim_contiguous) { -+ // Random values are cast to integer after scaling by a power of two to facilitate error -+ // testing -+ -+ double rnd = curand_normal(&rng_state[threadIdx.x]); -+ -+ rnd = dist.gaussian.mean + dist.gaussian.stddev * rnd; -+ -+ if (dist.int_scale >= 0) { -+ rnd = double(int(rnd * double(1 << dist.int_scale))); -+ *tensor = T(rnd / double(1 << dist.int_scale)); -+ } else { -+ *tensor = T(rnd); -+ } -+ } -+ } -+} -+ -+/// Kernel to initialize tensor to an identity matrix -+template -+__global__ void TensorInitializeLinear( -+ Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { -+ __shared__ curandState_t rng_state[1024]; -+ -+ uint64_t gtid 
= threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; -+ -+ curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); -+ -+ int c_idx = blockIdx.x * blockDim.x + threadIdx.x; -+ int s_idx = blockIdx.y * blockDim.x; -+ -+ tensor += s_idx * ldm + c_idx; -+ -+ for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { -+ if (s_idx < dim_strided && c_idx < dim_contiguous) { -+ *tensor = -+ dist.linear.offset + dist.linear.delta_row * c_idx + dist.linear.delta_column * s_idx; -+ } -+ } -+} -+ -+/// Kernel to initialize tensor to an identity matrix -+template -+__global__ void TensorInitializeIdentity( -+ Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { -+ __shared__ curandState_t rng_state[1024]; -+ -+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; -+ -+ curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); -+ -+ int c_idx = blockIdx.x * blockDim.x + threadIdx.x; -+ int s_idx = blockIdx.y * blockDim.x; -+ -+ tensor += s_idx * ldm + c_idx; -+ -+ for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { -+ if (s_idx < dim_strided && c_idx < dim_contiguous) { -+ *tensor = (c_idx == s_idx ? T(1) : T(0)); -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace device -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h -new file mode 100644 -index 0000000..ea5359f ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h -@@ -0,0 +1,159 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+#include "cutlass/subbyte_reference.h" -+#include "cutlass/fast_math.h" -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+namespace kernel { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines several helpers -+namespace detail { -+ -+/// Helper to perform for-each operation -+template -+struct TensorForEachHelper { -+ -+ /// Constructor for general rank -+ __inline__ __device__ -+ TensorForEachHelper(Func &func, Coord const &size, Coord &coord, int64_t index) { -+ -+ int64_t product = 1; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = Rank - RankRemaining; i < Rank; ++i) { -+ product *= size[i]; -+ } -+ -+ coord[Rank - 1 - RankRemaining] = index / product; -+ int64_t remaining = index % product; -+ -+ TensorForEachHelper(func, size, coord, remaining); -+ } -+}; -+ -+/// Helper to perform for-each operation -+template -+struct TensorForEachHelper { -+ -+ /// Constructor for fastest chaning rank -+ __inline__ __device__ -+ TensorForEachHelper(Func &func, Coord const &size, Coord &coord, int64_t index) { -+ -+ coord[Rank - 1] = index; -+ -+ if (coord < size) { -+ func(coord); -+ } -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Kernel calls a functor for each element in a tensor's index space -+template -+__global__ void TensorForEach(Coord size, Params params = Params()) { -+ -+ Func func(params); -+ -+ int64_t index = threadIdx.x + blockIdx.x * blockDim.x; -+ int64_t max_index = 1; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Rank; ++i) { -+ max_index *= size[i]; -+ } -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ while (index < max_index) { -+ Coord coord; -+ -+ detail::TensorForEachHelper(func, size, coord, index); -+ index += blockDim.x * gridDim.x; -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Kernel calls a functor for each element along a tensor's diagonal -+template -+__global__ void TensorDiagonalForEach(Coord size, Params params, int start, int end) { -+ -+ Func func(params); -+ -+ int64_t index = threadIdx.x + blockIdx.x * blockDim.x + start; -+ -+ if (index < end) { -+ Coord coord; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Rank; ++i) { -+ coord[i] = index; -+ } -+ -+ func(coord); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void BlockForEach( -+ Element *ptr, -+ size_t capacity, -+ typename Func::Params params) { -+ -+ Func func(params); -+ -+ size_t index = threadIdx.x + blockIdx.x * blockDim.x; -+ -+ for (; index < capacity; index += blockDim.x * gridDim.x) { -+ ReferenceFactory::get(ptr, index) 
= func(); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace device -+} // namespace reference -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h -new file mode 100644 -index 0000000..357ca3c ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h -@@ -0,0 +1,355 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for complex-valued GEMM in device-side code. -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace kernel { -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// Explicitly naming types needed by this template can be cumbersome, particularly for the -+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -+/// AccumulatorType(0) as the last function argument can be easier than naming all template -+/// arguments explicitly. 
-+template <
-+  typename ElementA,
-+  typename LayoutA,
-+  typename ElementB,
-+  typename LayoutB,
-+  typename ElementC,
-+  typename LayoutC,
-+  typename ScalarType,
-+  typename ComputeType,
-+  typename ConvertOp = NumericConverter,
-+  typename InnerProductOp = multiply_add,
-+  int kMblock = 4,
-+  int kNblock = 4
-+>
-+__global__ void Rank2KComplex(
-+  gemm::GemmCoord problem_size,
-+  ScalarType alpha,
-+  TensorRef tensor_a,
-+  ComplexTransform transform_a,
-+  TensorRef tensor_b,
-+  ComplexTransform transform_b,
-+  ScalarType beta,
-+  TensorRef tensor_c,
-+  TensorRef tensor_d,
-+  ComputeType initial_accum,
-+  FillMode fill_mode_c,
-+  BlasMode blas_mode,
-+  int batch_count = 1,
-+  int64_t batch_stride_A = 0,
-+  int64_t batch_stride_B = 0,
-+  int64_t batch_stride_C = 0,
-+  int64_t batch_stride_D = 0) {
-+
-+  static_assert(
-+    LayoutA::kRank == 2 &&
-+    LayoutB::kRank == 2 &&
-+    LayoutC::kRank == 2, "Tensors must be of rank 2");
-+
-+  int const M = problem_size.m();
-+  int const N = problem_size.n();
-+  int const K = problem_size.k();
-+
-+  assert(M == N);
-+
-+  ConvertOp convert_op;
-+  InnerProductOp inner_product_op;
-+
-+  int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
-+  int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
-+  int batch_idx = blockIdx.z;
-+
-+  tensor_a.add_pointer_offset(batch_idx * batch_stride_A);
-+  tensor_b.add_pointer_offset(batch_idx * batch_stride_B);
-+  tensor_c.add_pointer_offset(batch_idx * batch_stride_C);
-+  tensor_d.add_pointer_offset(batch_idx * batch_stride_D);
-+
-+  for (; batch_idx < batch_count; batch_idx += gridDim.z) {
-+
-+    // Compute matrix product using blocks
-+    ComputeType accum[kMblock][kNblock];
-+
-+    CUTLASS_PRAGMA_UNROLL
-+    for (int j = 0; j < kNblock; j++) {
-+      CUTLASS_PRAGMA_UNROLL
-+      for (int i = 0; i < kMblock; i++) {
-+        accum[i][j] = initial_accum;
-+      }
-+    }
-+
-+    for (int k_block = 0; k_block < K; ++k_block) {
-+      CUTLASS_PRAGMA_UNROLL
-+      for (int j = 0; j < kNblock; j++) {
-+        CUTLASS_PRAGMA_UNROLL
-+        for (int i = 0; i < kMblock; i++) {
-+          int row = row_block + i;
-+          int col = col_block + j;
-+
-+          if (row < M && col < N &&
-+              ( (fill_mode_c == FillMode::kLower && row >= col) ||
-+                (fill_mode_c == FillMode::kUpper && row <= col) )
-+            ) {
-+
-+            // A x B^T (Symmetric) or A x B^H (Hermitian)
-+            // complex conjugation on operandB (b_t) is function of blas3 computation
-+            ElementA a = tensor_a.at(MatrixCoord(row, k_block));
-+            ElementB b_t = (blas_mode == BlasMode::kHermitian) ?
-+                             conj(tensor_b.at(MatrixCoord(col, k_block))) :
-+                             tensor_b.at(MatrixCoord(col, k_block));
-+
-+            ComputeType a_ik = ComputeType(a);
-+            ComputeType b_jk = ComputeType(b_t);
-+
-+            // complex conjugation is a function of operand layouts
-+            if (transform_a == ComplexTransform::kConjugate) {
-+              a_ik = conj(a_ik);
-+            }
-+            // complex conjugation is a function of operand layouts
-+            if (transform_b == ComplexTransform::kConjugate) {
-+              b_jk = conj(b_jk);
-+            }
-+
-+            accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]);
-+
-+            // B x A^T (Symmetric) or B x A^H (Hermitian)
-+            // complex conjugation on operandB (a_t) is function of blas3 computation
-+            ElementB b = tensor_b.at(MatrixCoord(row, k_block));
-+            ElementA a_t = (blas_mode == BlasMode::kHermitian) ?
-+                             conj(tensor_a.at(MatrixCoord(col, k_block))) :
-+                             tensor_a.at(MatrixCoord(col, k_block));
-+
-+            ComputeType b_ik = ComputeType(b);
-+            ComputeType a_jk = ComputeType(a_t);
-+
-+            // complex conjugation here is a function of operand layouts
-+            if (transform_b == ComplexTransform::kConjugate) {
-+              b_ik = conj(b_ik);
-+            }
-+            // complex conjugation here is a function of operand layouts
-+            if (transform_a == ComplexTransform::kConjugate) {
-+              a_jk = conj(a_jk);
-+            }
-+
-+            accum[i][j] = inner_product_op(b_ik, a_jk, accum[i][j]);
-+          }
-+        }
-+      }
-+    }
-+
-+    CUTLASS_PRAGMA_UNROLL
-+    for (int j = 0; j < kNblock; j++) {
-+      CUTLASS_PRAGMA_UNROLL
-+      for (int i = 0; i < kMblock; i++) {
-+        int row = row_block + i;
-+        int col = col_block + j;
-+
-+        MatrixCoord coord = MatrixCoord(row, col);
-+
-+        if (row < M && col < N &&
-+            ((fill_mode_c == FillMode::kLower && row >= col) ||
-+             (fill_mode_c == FillMode::kUpper && row <= col))
-+          ) {
-+
-+          ScalarType c = tensor_c.at(coord);
-+          // The imaginary parts of the diagonal elements of
-+          // a complex data type are assumed and set to zero
-+          if (blas_mode == BlasMode::kHermitian) {
-+            c = (row == col) ? real(c) : c;
-+          }
-+
-+          tensor_d.at(coord) = convert_op(
-+            alpha * ScalarType(accum[i][j]) +
-+            beta * c);
-+        }
-+      }
-+    }
-+
-+    tensor_a.add_pointer_offset(batch_stride_A * gridDim.z);
-+    tensor_b.add_pointer_offset(batch_stride_B * gridDim.z);
-+    tensor_c.add_pointer_offset(batch_stride_C * gridDim.z);
-+    tensor_d.add_pointer_offset(batch_stride_D * gridDim.z);
-+
-+  } // for (batch_idx)
-+}
-+
-+} // namespace kernel
-+
-+////////////////////////////////////////////////////////////////////////////////////////////////////
-+
-+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
-+/// objects.
-+///
-+/// Explicitly naming types needed by this template can be cumbersome, particularly for the
-+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
-+/// AccumulatorType(0) as the last function argument can be easier than naming all template
-+/// arguments explicitly.
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void Rank2KComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ TensorRef tensor_b, -+ ComplexTransform transform_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum, -+ FillMode fill_mode_c, -+ BlasMode blas_mode, -+ int batch_count = 1, -+ int64_t batch_stride_A = 0, -+ int64_t batch_stride_B = 0, -+ int64_t batch_stride_C = 0, -+ int64_t batch_stride_D = 0) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ int const kMblock = 4; -+ int const kNblock = 4; -+ -+ dim3 block(16, 8); -+ dim3 grid( -+ (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), -+ (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), -+ batch_count % std::numeric_limits::max() -+ ); -+ -+ kernel::Rank2KComplex< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ScalarType, -+ ComputeType, -+ ConvertOp, -+ InnerProductOp, -+ kMblock, -+ kNblock -+ ><<< grid, block >>>( -+ problem_size, -+ alpha, -+ tensor_a, -+ transform_a, -+ tensor_b, -+ transform_b, -+ beta, -+ tensor_c, -+ tensor_d, -+ initial_accum, -+ fill_mode_c, -+ blas_mode, -+ batch_count, -+ batch_stride_A, -+ batch_stride_B, -+ batch_stride_C, -+ batch_stride_D -+ ); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// This assumes the accumulator type is the same type as the scalars. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType -+> -+void Rank2KComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ TensorRef tensor_b, -+ ComplexTransform transform_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ FillMode fill_mode_c, -+ BlasMode blas_mode) { -+ -+ Rank2KComplex( -+ problem_size, alpha, -+ tensor_a, transform_a, -+ tensor_b, transform_b, -+ beta, tensor_c, tensor_d, -+ ScalarType(0), -+ fill_mode_c, -+ blas_mode); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h -new file mode 100644 -index 0000000..e29ad69 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h -@@ -0,0 +1,246 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines host-side elementwise operations on TensorView. -+*/ -+ -+#pragma once -+// Standard Library includes -+#include -+ -+// Cutlass includes -+#include "cutlass/cutlass.h" -+#include "cutlass/relatively_equal.h" -+ -+#include "cutlass/util/distribution.h" -+ -+#include "tensor_foreach.h" -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace kernel { -+ -+template -+__global__ void BlockCompareEqual( -+ int *equal, -+ Element const *ptr_A, -+ Element const *ptr_B, -+ size_t capacity) { -+ -+ size_t idx = threadIdx.x + blockDim.x * blockIdx.x; -+ -+ for (; idx < capacity; idx += gridDim.x * blockDim.x) { -+ -+ Element a = cutlass::ReferenceFactory::get(ptr_A, idx); -+ Element b = cutlass::ReferenceFactory::get(ptr_B, idx); -+ -+ if (a != b) { -+ *equal = 0; -+ -+ return; -+ } -+ } -+} -+ -+template -+__global__ void BlockCompareRelativelyEqual( -+ int *equal, -+ Element const *ptr_A, -+ Element const *ptr_B, -+ size_t capacity, -+ Element epsilon, -+ Element nonzero_floor) { -+ -+ size_t idx = threadIdx.x + blockDim.x * blockIdx.x; -+ -+ for (; idx < capacity; idx += gridDim.x * blockDim.x) { -+ -+ Element a = cutlass::ReferenceFactory::get(ptr_A, idx); -+ Element b = cutlass::ReferenceFactory::get(ptr_B, idx); -+ -+ if (!relatively_equal(a, b, epsilon, nonzero_floor)) { -+ *equal = 0; -+ return; -+ } -+ } -+} -+ -+} // namespace kernel -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Performs a bit-level equality check between two blocks -+template -+bool BlockCompareEqual( -+ Element const *ptr_A, -+ Element const *ptr_B, -+ size_t capacity, -+ int grid_size = 0, -+ int block_size = 0) { -+ -+ int equal_flag = 1; -+ int *device_equal_flag = nullptr; -+ -+ if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) { -+ throw std::runtime_error("Failed to allocate device flag."); -+ } -+ -+ if 
(cudaMemcpy( -+ device_equal_flag, -+ &equal_flag, -+ sizeof(int), -+ cudaMemcpyHostToDevice) != cudaSuccess) { -+ -+ throw std::runtime_error("Failed to copy equality flag to device."); -+ } -+ -+ if (!grid_size || !block_size) { -+ -+ // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API -+ cudaError_t result = cudaOccupancyMaxPotentialBlockSize( -+ &grid_size, -+ &block_size, -+ reinterpret_cast(kernel::BlockCompareEqual)); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("Failed to query occupancy."); -+ } -+ -+ // Limit block size. This has the effect of increasing the number of items processed by a -+ // single thread and reduces the impact of initialization overhead. -+ block_size = (block_size < 128 ? block_size : 128); -+ } -+ -+ dim3 grid(grid_size, 1, 1); -+ dim3 block(block_size, 1, 1); -+ -+ kernel::BlockCompareEqual<<< grid, block >>>(device_equal_flag, ptr_A, ptr_B, capacity); -+ -+ if (cudaMemcpy( -+ &equal_flag, -+ device_equal_flag, -+ sizeof(int), -+ cudaMemcpyDeviceToHost) != cudaSuccess) { -+ -+ cudaFree(device_equal_flag); -+ -+ throw std::runtime_error("Failed to copy equality flag from device."); -+ } -+ -+ cudaFree(device_equal_flag); -+ -+ return equal_flag; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Performs a bit-level equality check between two blocks -+template -+bool BlockCompareRelativelyEqual( -+ Element const *ptr_A, -+ Element const *ptr_B, -+ size_t capacity, -+ Element epsilon, -+ Element nonzero_floor, -+ int grid_size = 0, -+ int block_size = 0) { -+ -+ int equal_flag = 1; -+ int *device_equal_flag = nullptr; -+ -+ if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) { -+ throw std::runtime_error("Failed to allocate device flag."); -+ } -+ -+ if (cudaMemcpy( -+ device_equal_flag, -+ &equal_flag, -+ sizeof(int), -+ cudaMemcpyHostToDevice) != cudaSuccess) { -+ -+ throw std::runtime_error("Failed to copy equality flag to device."); -+ } -+ -+ if (!grid_size || !block_size) { -+ -+ // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API -+ cudaError_t result = cudaOccupancyMaxPotentialBlockSize( -+ &grid_size, -+ &block_size, -+ reinterpret_cast(kernel::BlockCompareRelativelyEqual)); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("Failed to query occupancy."); -+ } -+ -+ // Limit block size. This has the effect of increasing the number of items processed by a -+ // single thread and reduces the impact of initialization overhead. -+ block_size = (block_size < 128 ? 
block_size : 128); -+ } -+ -+ dim3 grid(grid_size, 1, 1); -+ dim3 block(block_size, 1, 1); -+ -+ kernel::BlockCompareRelativelyEqual<<< grid, block >>>( -+ device_equal_flag, -+ ptr_A, -+ ptr_B, -+ capacity, -+ epsilon, -+ nonzero_floor -+ ); -+ -+ if (cudaMemcpy( -+ &equal_flag, -+ device_equal_flag, -+ sizeof(int), -+ cudaMemcpyDeviceToHost) != cudaSuccess) { -+ -+ cudaFree(device_equal_flag); -+ -+ throw std::runtime_error("Failed to copy equality flag from device."); -+ } -+ -+ cudaFree(device_equal_flag); -+ -+ return equal_flag; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // device -+} // reference -+} // cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h -new file mode 100644 -index 0000000..8568e47 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h -@@ -0,0 +1,1898 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines device-side elementwise operations on TensorView. Note, the operations defined -+ in this header are not specialized for any particular data layout and are therefore not -+ intended to offer the best possible performance. Rather, they are intended to be generic -+ reference implementations to support the CUTLASS unit tests. 
-+*/ -+ -+#pragma once -+ -+#if !defined(__CUDACC_RTC__) -+ -+// Standard Library includes -+#include -+#include -+#include -+#include -+#include -+ -+#endif -+ -+// CUDA includes -+#include -+ -+// Cutlass includes -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/complex.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/blas3.h" -+ -+#include "cutlass/util/reference/device/tensor_foreach.h" -+#include "cutlass/util/distribution.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template -+CUTLASS_DEVICE -+FloatType random_normal_float(curandState_t *state) { -+ return curand_normal(state); -+} -+ -+template <> -+CUTLASS_DEVICE -+double random_normal_float(curandState_t *state) { -+ return curand_normal_double(state); -+} -+ -+template -+CUTLASS_DEVICE -+FloatType random_uniform_float(curandState_t *state) { -+ return curand_uniform(state); -+} -+ -+template <> -+CUTLASS_DEVICE -+double random_uniform_float(curandState_t *state) { -+ return curand_uniform_double(state); -+} -+ -+template -+struct RandomGaussianFunc { -+ -+ using FloatType = typename std::conditional<(sizeof(Element) > 4), double, float>::type; -+ using IntType = typename std::conditional<(sizeof(Element) > 4), int64_t, int>::type; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ uint64_t seed; -+ FloatType mean; -+ FloatType stddev; -+ int int_scale; -+ FloatType float_scale_up; -+ FloatType float_scale_down; -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. 
-+ Params( -+ uint64_t seed_ = 0, -+ Element mean_ = 0, -+ Element stddev_ = 1, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), -+ mean(static_cast(mean_)), -+ stddev(static_cast(stddev_)), -+ int_scale(int_scale_) { -+ -+ float_scale_up = FloatType(IntType(1) << int_scale); -+ float_scale_up += FloatType(0.5) * float_scale_up; -+ float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ /// RNG state object -+ curandState_t rng_state; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ RandomGaussianFunc(Params const ¶ms): params(params) { -+ -+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; -+ -+ curand_init(params.seed, gtid, 0, &rng_state); -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ Element operator()() { -+ -+ FloatType rnd = random_normal_float(&rng_state); -+ rnd = params.mean + params.stddev * rnd; -+ -+ Element result; -+ if (params.int_scale >= 0) { -+ rnd = FloatType(IntType(rnd * params.float_scale_up)); -+ result = Element(rnd * params.float_scale_down); -+ } -+ else { -+ result = Element(rnd); -+ } -+ -+ return result; -+ } -+}; -+ -+ -+template -+struct RandomGaussianFunc> { -+ -+ using Element = complex; -+ using FloatType = typename std::conditional<(sizeof(Real) > 4), double, float>::type; -+ using IntType = typename std::conditional<(sizeof(Real) > 4), int64_t, int>::type; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ uint64_t seed; -+ FloatType mean; -+ FloatType stddev; -+ int int_scale; -+ FloatType float_scale_up; -+ FloatType float_scale_down; -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. 
-+ Params( -+ uint64_t seed_ = 0, -+ Real mean_ = 0, -+ Real stddev_ = 1, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), -+ mean(static_cast(mean_)), -+ stddev(static_cast(stddev_)), -+ int_scale(int_scale_) { -+ -+ float_scale_up = FloatType(IntType(1) << int_scale); -+ float_scale_up += FloatType(0.5) * float_scale_up; -+ float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ /// RNG state object -+ curandState_t rng_state; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ RandomGaussianFunc(Params const ¶ms): params(params) { -+ -+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; -+ -+ curand_init(params.seed, gtid, 0, &rng_state); -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ Element operator()() { -+ -+ FloatType rnd_r = random_normal_float(&rng_state); -+ FloatType rnd_i = random_normal_float(&rng_state); -+ rnd_r = params.mean + params.stddev * rnd_r; -+ rnd_i = params.mean + params.stddev * rnd_i; -+ -+ Element result; -+ if (params.int_scale >= 0) { -+ rnd_r = FloatType(IntType(rnd_r * params.float_scale_up)); -+ rnd_i = FloatType(IntType(rnd_i * params.float_scale_down)); -+ -+ result = { -+ Real(rnd_r * params.float_scale_down), -+ Real(rnd_i * params.float_scale_down) -+ }; -+ } -+ else { -+ result = Element(Real(rnd_r), Real(rnd_i)); -+ } -+ -+ return result; -+ } -+}; -+ -+/// Computes a random Gaussian distribution -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillRandomGaussianFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Scalar type -+ typedef typename TensorView::Element T; -+ -+ /// Coordinate in tensor's index space -+ typedef typename TensorView::TensorCoord TensorCoord; -+ -+ using RandomFunc = RandomGaussianFunc; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ typename RandomFunc::Params random; -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ TensorView view_ = TensorView(), -+ typename RandomFunc::Params random_ = typename RandomFunc::Params() -+ ): -+ view(view_), random(random_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ Params params; -+ RandomFunc random; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ TensorFillRandomGaussianFunc(Params const ¶ms): params(params), random(params.random) { -+ -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ -+ params.view.at(coord) = random(); -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a Gaussian distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillRandomGaussian( -+ TensorView view, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ Element mean = Element(0), ///< Gaussian distribution's mean -+ Element stddev = Element(1), ///< Gaussian distribution's standard deviation -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. 
-+ -+ using RandomFunc = detail::RandomGaussianFunc; -+ using Func = detail::TensorFillRandomGaussianFunc; -+ using Params = typename Func::Params; -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, typename RandomFunc::Params(seed, mean, stddev, bits)) -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a Gaussian distribution. -+template ///< Element type -+void BlockFillRandomGaussian( -+ Element *ptr, -+ size_t capacity, -+ uint64_t seed, ///< seed for RNG -+ typename RealType::Type mean, ///< Gaussian distribution's mean -+ typename RealType::Type stddev, ///< Gaussian distribution's standard deviation -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ using RandomFunc = detail::RandomGaussianFunc; -+ -+ typename RandomFunc::Params params(seed, mean, stddev, bits); -+ -+ BlockForEach(ptr, capacity, params); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Computes a random Gaussian distribution -+template ///< Element type -+struct RandomUniformFunc { -+ -+ using FloatType = typename std::conditional< -+ (sizeof(Element) > 4), -+ double, -+ float>::type; -+ -+ using IntType = typename std::conditional< -+ (sizeof(Element) > 4), -+ int64_t, -+ int>::type; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ uint64_t seed; -+ FloatType range; -+ FloatType max; -+ int int_scale; -+ FloatType float_scale_up; -+ FloatType float_scale_down; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. 
-+ Params( -+ uint64_t seed_ = 0, -+ Element max_ = 1, -+ Element min = 0, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), -+ range(static_cast(max_ - min)), -+ max(static_cast(max_)), -+ int_scale(int_scale_) { -+ -+ float_scale_up = FloatType(IntType(1) << int_scale); -+ float_scale_up += FloatType(0.5) * float_scale_up; -+ float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ /// RNG state object -+ curandState_t rng_state; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ RandomUniformFunc(Params const ¶ms): params(params) { -+ -+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; -+ -+ curand_init(params.seed, gtid, 0, &rng_state); -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ Element operator()() { -+ -+ FloatType rnd = random_uniform_float(&rng_state); -+ rnd = params.max - params.range * rnd; -+ -+ // Random values are cast to integer after scaling by a power of two to facilitate error -+ // testing -+ Element result; -+ -+ if (params.int_scale >= 0) { -+ rnd = FloatType(IntType(rnd * params.float_scale_up)); -+ result = Element(rnd * params.float_scale_down); -+ } -+ else { -+ result = Element(rnd); -+ } -+ -+ return result; -+ } -+}; -+ -+/// Computes a random Gaussian distribution -+template -+struct RandomUniformFunc> { -+ -+ using Element = complex; -+ -+ using FloatType = typename std::conditional< -+ (sizeof(Real) > 4), -+ double, -+ float>::type; -+ -+ using IntType = typename std::conditional< -+ (sizeof(Real) > 4), -+ int64_t, -+ int>::type; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ uint64_t seed; -+ FloatType range; -+ FloatType min; -+ int int_scale; -+ FloatType float_scale_up; -+ FloatType float_scale_down; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. 
-+ Params( -+ uint64_t seed_ = 0, -+ FloatType max = 1, -+ FloatType min_ = 0, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), -+ range(static_cast(max - min_)), -+ min(static_cast(min_)), -+ int_scale(int_scale_) { -+ -+ float_scale_up = FloatType(IntType(1) << int_scale); -+ float_scale_up += FloatType(0.5) * float_scale_up; -+ float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ /// RNG state object -+ curandState_t rng_state; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ RandomUniformFunc(Params const ¶ms): params(params) { -+ -+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; -+ -+ curand_init(params.seed, gtid, 0, &rng_state); -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ Element operator()() { -+ -+ FloatType rnd_r = random_uniform_float(&rng_state); -+ FloatType rnd_i = random_uniform_float(&rng_state); -+ -+ rnd_r = params.min + params.range * rnd_r; -+ rnd_i = params.min + params.range * rnd_i; -+ -+ // Random values are cast to integer after scaling by a power of two to facilitate error -+ // testing -+ Element result; -+ -+ if (params.int_scale >= 0) { -+ rnd_r = FloatType(IntType(rnd_r * params.float_scale_up)); -+ rnd_i = FloatType(IntType(rnd_i * params.float_scale_up)); -+ -+ result = { -+ Real(rnd_r * params.float_scale_down), -+ Real(rnd_i * params.float_scale_down) -+ }; -+ } -+ else { -+ result = Element(Real(rnd_r), Real(rnd_i)); -+ } -+ -+ return result; -+ } -+}; -+ -+/// Computes a random Gaussian distribution -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillRandomUniformFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Scalar type -+ typedef typename TensorView::Element T; -+ -+ /// Coordinate in tensor's index space -+ typedef typename TensorView::TensorCoord TensorCoord; -+ -+ using RandomFunc = RandomUniformFunc; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ typename RandomFunc::Params random; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ TensorView view_ = TensorView(), -+ typename RandomFunc::Params random_ = RandomFunc::Params() -+ ): -+ view(view_), random(random_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ Params params; -+ RandomFunc random; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ TensorFillRandomUniformFunc(Params const ¶ms): params(params), random(params.random) { -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ -+ params.view.at(coord) = random(); -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a uniform random distribution. 
-+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillRandomUniform( -+ TensorView view, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ Element max = Element(1), ///< upper bound of distribution -+ Element min = Element(0), ///< lower bound for distribution -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ using RandomFunc = detail::RandomUniformFunc; -+ using Func = detail::TensorFillRandomUniformFunc; -+ using Params = typename Func::Params; -+ -+ typename RandomFunc::Params random(seed, max, min, bits); -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, random) -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a uniform random distribution. -+template -+void BlockFillRandomUniform( -+ Element *ptr, -+ size_t capacity, -+ uint64_t seed, ///< seed for RNG -+ typename RealType::Type max, ///< upper bound of distribution -+ typename RealType::Type min, ///< lower bound for distribution -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ using RandomFunc = detail::RandomUniformFunc; -+ -+ typename RandomFunc::Params params(seed, max, min, bits); -+ -+ BlockForEach(ptr, capacity, params); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Computes a random sparse meta -+template ///< Element type -+struct RandomSparseMetaFunc { -+ -+ using FloatType = float; -+ -+ using IntType = int32_t; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ uint64_t seed; -+ FloatType range; -+ int MetaSizeInBits; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ uint64_t seed_ = 0, -+ int MetaSizeInBits_ = 2 -+ ): -+ seed(seed_), -+ MetaSizeInBits(MetaSizeInBits_) { -+ if (MetaSizeInBits_ == 2) { -+ range = 6; -+ } else if (MetaSizeInBits_ == 4) { -+ range = 2; -+ } -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ /// RNG state object -+ curandState_t rng_state; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ RandomSparseMetaFunc(Params const ¶ms): params(params) { -+ -+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; -+ -+ curand_init(params.seed, gtid, 0, &rng_state); -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ Element operator()() { -+ Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe}; -+ Element TwoToOneMeta[2] = {0x4, 0xe}; -+ -+ Element *MetaArray = -+ (params.MetaSizeInBits == 2) ? 
FourToTwoMeta : TwoToOneMeta; -+ -+ Element result = 0x0; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < cutlass::sizeof_bits::value / 4; ++i) { -+ FloatType rnd = random_uniform_float(&rng_state); -+ rnd = params.range * rnd; -+ Element meta = MetaArray[(int)rnd]; -+ -+ result = (Element)(result | ((Element)(meta << (i * 4)))); -+ } -+ -+ return result; -+ } -+}; -+ -+/// Computes a random Gaussian distribution -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillRandomSparseMetaFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Scalar type -+ typedef typename TensorView::Element T; -+ -+ /// Coordinate in tensor's index space -+ typedef typename TensorView::TensorCoord TensorCoord; -+ -+ using RandomFunc = RandomSparseMetaFunc; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ typename RandomFunc::Params random; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ TensorView view_ = TensorView(), -+ typename RandomFunc::Params random_ = RandomFunc::Params() -+ ): -+ view(view_), random(random_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ Params params; -+ RandomFunc random; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ TensorFillRandomSparseMetaFunc(Params const ¶ms): params(params), random(params.random) { -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ -+ params.view.at(coord) = random(); -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a uniform random distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillRandomSparseMeta( -+ TensorView view, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ int MetaSizeInBits = 2) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ using RandomFunc = detail::RandomSparseMetaFunc; -+ using Func = detail::TensorFillRandomUniformFunc; -+ using Params = typename Func::Params; -+ -+ typename RandomFunc::Params random(seed, MetaSizeInBits); -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, random) -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a uniform random distribution. -+template -+void BlockFillRandomSparseMeta( -+ Element *ptr, -+ size_t capacity, -+ uint64_t seed, ///< seed for RNG -+ int MetaSizeInBits = 2) { ///< meta data size -+ -+ using RandomFunc = detail::RandomSparseMetaFunc; -+ -+ typename RandomFunc::Params params(seed, MetaSizeInBits); -+ -+ BlockForEach(ptr, capacity, params); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Functor to fill a tensor with zeros off the diagonal and a uniform value on the diagonal. 
-+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillDiagonalFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Scalar type -+ typedef typename TensorView::Element T; -+ -+ /// Coordinate in tensor's index space -+ typedef typename TensorView::TensorCoord TensorCoord; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element diag; -+ Element other; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ // -+ // Methods -+ // -+ -+ Params( -+ TensorView view_ = TensorView(), -+ Element diag_ = Element(1), -+ Element other_ = Element(0) -+ ): -+ view(view_), diag(diag_), other(other_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ TensorFillDiagonalFunc(Params const ¶ms): params(params) { -+ -+ } -+ -+ /// Updates the tensor -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ -+ bool is_diag = true; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < Layout::kRank; ++i) { -+ if (coord[i] != coord[i - 1]) { -+ is_diag = false; -+ break; -+ } -+ } -+ -+ params.view.at(coord) = (is_diag ? params.diag : params.other); -+ } -+}; -+ -+// Overwrites the elements of a tensor with a uniform value depending on fill mode -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillPartialFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Scalar type -+ typedef typename TensorView::Element T; -+ -+ /// Coordinate in tensor's index space -+ typedef typename TensorView::TensorCoord TensorCoord; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element element; -+ FillMode fill_mode; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params(): fill_mode(FillMode::kNone) { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ TensorView view_, -+ Element element_, -+ FillMode fill_mode_ -+ ): -+ view(view_), element(element_), fill_mode(fill_mode_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ TensorFillPartialFunc(Params const ¶ms): params(params) { -+ -+ } -+ -+ /// Overwrites the element if it is within the covered region. 
-+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ -+ bool predicate = true; -+ -+ switch (params.fill_mode) { -+ case FillMode::kFull: -+ predicate = true; -+ break; -+ -+ case FillMode::kLower: -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < Layout::kRank; ++i) { -+ if (coord[i - 1] < coord[i]) { -+ predicate = false; -+ break; -+ } -+ } -+ break; -+ -+ case FillMode::kUpper: -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < Layout::kRank; ++i) { -+ if (coord[i - 1] > coord[i]) { -+ predicate = false; -+ break; -+ } -+ } -+ break; -+ -+ case FillMode::kDiagonal: -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < Layout::kRank; ++i) { -+ if (coord[i - 1] != coord[i]) { -+ predicate = false; -+ break; -+ } -+ } -+ break; -+ -+ case FillMode::kNone: // fall-through -+ -+ default: -+ predicate = false; -+ break; -+ } -+ -+ if (predicate) { -+ params.view.at(coord) = params.element; -+ } -+ } -+}; -+ -+ -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorClearPartialFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Scalar type -+ typedef typename TensorView::Element T; -+ -+ /// Coordinate in tensor's index space -+ typedef typename TensorView::TensorCoord TensorCoord; -+ -+ /// -+ static_assert((Layout::kRank == 2), "TensorClearPartial is only supported for matrices"); -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element element; -+ FillMode fill_mode; -+ int alignment; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params(): fill_mode(FillMode::kNone) { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ TensorView view_, -+ Element element_, -+ FillMode fill_mode_, -+ int alignment_ -+ ): -+ view(view_), element(element_), fill_mode(fill_mode_), alignment(alignment_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ TensorClearPartialFunc(Params const ¶ms): params(params) { -+ -+ } -+ -+ /// Overwrites the element if it is within the covered region. -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ -+ bool predicate = true; -+ -+ switch (params.fill_mode) { -+ -+ case FillMode::kLower: -+ if ((coord[0] >= coord[1]) || -+ ((coord[1] - coord[0]) >= params.alignment)) { -+ predicate = false; -+ break; -+ } -+ break; -+ -+ case FillMode::kUpper: -+ if ((coord[0] <= coord[1]) || -+ ((coord[0] - coord[1]) >= params.alignment)) { -+ predicate = false; -+ break; -+ } -+ break; -+ -+ case FillMode::kNone: // fall-through -+ -+ default: -+ predicate = false; -+ break; -+ } -+ -+ if (predicate) { -+ params.view.at(coord) = params.element; -+ } -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor everywhere with a unique value for its diagonal. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillDiagonal( -+ TensorView view, ///< destination tensor -+ Element diag = Element(1), ///< value to write in the diagonal -+ Element other = Element(0)) { ///< value to write off the diagonal -+ -+ typedef detail::TensorFillDiagonalFunc Func; -+ typedef typename Func::Params Params; -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, diag, other) -+ ); -+} -+ -+/// Fills a tensor partially depending on fill mode. 
Elements not covered by the fillmode are -+/// not written. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillPartial( -+ TensorView view, ///< destination tensor -+ Element element, -+ FillMode fill_mode) { -+ -+ typedef detail::TensorFillPartialFunc Func; -+ typedef typename Func::Params Params; -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, element, fill_mode) -+ ); -+} -+ -+/// Clears a tensor partially depending on fill mode and alignment. Elements on the wrong-side -+/// of fillmode (upto the alignment) are overwritten with the user supplied element (typically zeros) -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorClearPartial( -+ TensorView view, ///< destination tensor -+ Element element, -+ FillMode fill_mode, -+ int alignment) { -+ -+ typedef detail::TensorClearPartialFunc Func; -+ typedef typename Func::Params Params; -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, element, fill_mode, alignment) -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with a uniform value -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFill( -+ TensorView view, ///< destination tensor -+ Element val = Element(0)) { ///< value to uniformly fill it with -+ -+ TensorFillDiagonal(view, val, val); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor's digonal with 1 and 0 everywhere else. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillIdentity( -+ TensorView view) { ///< destination tensor -+ -+ TensorFillDiagonal(view, Element(1), Element(0)); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Computes a random Gaussian distribution -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorUpdateDiagonalFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Scalar type -+ typedef typename TensorView::Element T; -+ -+ /// Coordinate in tensor's index space -+ typedef typename TensorView::TensorCoord TensorCoord; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element diag; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. 
-+ Params( -+ TensorView view_ = TensorView(), -+ Element diag_ = Element(1) -+ ): -+ view(view_), diag(diag_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ TensorUpdateDiagonalFunc(Params const ¶ms): params(params) { -+ -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ -+ bool is_diag = true; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < Layout::kRank; ++i) { -+ if (coord[i] != coord[i - 1]) { -+ is_diag = false; -+ break; -+ } -+ } -+ -+ if (is_diag) { -+ params.view.at(coord) = params.diag; -+ } -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorUpdateDiagonal( -+ TensorView view, ///< destination tensor -+ Element diag = Element(1)) { -+ -+ typedef detail::TensorUpdateDiagonalFunc Func; -+ typedef typename Func::Params Params; -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, diag) -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Computes a random Gaussian distribution -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorUpdateOffDiagonalFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Scalar type -+ typedef typename TensorView::Element T; -+ -+ /// Coordinate in tensor's index space -+ typedef typename TensorView::TensorCoord TensorCoord; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element other; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ TensorView view_ = TensorView(), -+ Element other_ = Element(0) -+ ): -+ view(view_), other(other_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ TensorUpdateOffDiagonalFunc(Params const ¶ms): params(params) { -+ -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ -+ bool is_diag = true; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < Layout::kRank; ++i) { -+ if (coord[i] != coord[i - 1]) { -+ is_diag = false; -+ break; -+ } -+ } -+ -+ if (!is_diag) { -+ params.view.at(coord) = params.other; -+ } -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Writes a uniform value to all elements in the tensor without modifying diagonal elements. 
-+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorUpdateOffDiagonal( -+ TensorView view, ///< destination tensor -+ Element other = Element(1)) { -+ -+ typedef detail::TensorUpdateOffDiagonalFunc Func; -+ typedef typename Func::Params Params; -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, other) -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Computes a random Gaussian distribution -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillLinearFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Scalar type -+ typedef typename TensorView::Element T; -+ -+ /// Coordinate in tensor's index space -+ typedef typename TensorView::TensorCoord TensorCoord; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Array v; -+ Element s; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ TensorView view_, ///< destination tensor -+ Array const & v_, -+ Element s_ = Element(0) -+ ): -+ view(view_), v(v_), s(s_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ TensorFillLinearFunc(Params const ¶ms): params(params) { -+ -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ Element sum = params.s; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank; ++i) { -+ sum += params.v[i] * Element(coord[i]); -+ } -+ -+ params.view.at(coord) = sum; -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills tensor with a linear combination of its coordinate and another vector -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillLinear( -+ TensorView view, ///< destination tensor -+ Array const & v, -+ Element s = Element(0)) { -+ -+ using Func = detail::TensorFillLinearFunc; -+ using Params = typename Func::Params; -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, v, s) -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a block of data with sequential elements -+template < -+ typename Element -+> -+void BlockFillSequential( -+ Element *ptr, -+ int64_t capacity, -+ Element v = Element(1), -+ Element s = Element(0)) { -+ -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a block of data with sequential elements -+template < -+ typename Element -+> -+void BlockFillRandom( -+ Element *ptr, -+ size_t capacity, -+ uint64_t seed, -+ Distribution dist) { -+ -+ using Real = typename RealType::Type; -+ -+ if (dist.kind == Distribution::Gaussian) { -+ BlockFillRandomGaussian( -+ ptr, -+ capacity, -+ seed, -+ 
static_cast(dist.gaussian.mean), -+ static_cast(dist.gaussian.stddev), -+ dist.int_scale); -+ } -+ else if (dist.kind == Distribution::Uniform) { -+ BlockFillRandomUniform( -+ ptr, -+ capacity, -+ seed, -+ static_cast(dist.uniform.max), -+ static_cast(dist.uniform.min), -+ dist.int_scale); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Computes a random Gaussian distribution -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorCopyDiagonalInFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Scalar type -+ typedef typename TensorView::Element T; -+ -+ /// Coordinate in tensor's index space -+ typedef typename TensorView::TensorCoord TensorCoord; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element const *ptr; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ TensorView view_, ///< destination tensor -+ Element const *ptr_ -+ ): -+ view(view_), ptr(ptr_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ TensorCopyDiagonalInFunc(Params const ¶ms): params(params) { -+ -+ } -+ -+ /// Only update the diagonal element -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ bool is_diagonal = true; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < Layout::kRank; ++i) { -+ if (coord[i] != coord[0]) { -+ is_diagonal = false; -+ } -+ } -+ if (is_diagonal) { -+ params.view.at(coord) = params.ptr[coord[0]]; -+ } -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Copies a diagonal in from host memory without modifying off-diagonal elements. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorCopyDiagonalIn( -+ TensorView view, ///< destination tensor -+ Element const *ptr) { ///< dense buffer of elements -+ -+ using Func = detail::TensorCopyDiagonalInFunc; -+ using Params = typename Func::Params; -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, ptr) -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+namespace detail { -+ -+/// Computes a random Gaussian distribution -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorCopyDiagonalOutFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Scalar type -+ typedef typename TensorView::Element T; -+ -+ /// Coordinate in tensor's index space -+ typedef typename TensorView::TensorCoord TensorCoord; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element *ptr; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. 
-+ Params( -+ TensorView view_, ///< destination tensor -+ Element *ptr_ -+ ): -+ view(view_), ptr(ptr_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ TensorCopyDiagonalOutFunc(Params const ¶ms): params(params) { -+ -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ bool is_diagonal = true; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < Layout::kRank; ++i) { -+ if (coord[i] != coord[0]) { -+ is_diagonal = false; -+ } -+ } -+ if (is_diagonal) { -+ params.ptr[coord[0]] = params.view.at(coord); -+ } -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Copies the diagonal of a tensor into a dense buffer in host memory. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorCopyDiagonalOut( -+ Element *ptr, ///< dense buffer of elements -+ TensorView view) { ///< source tensor -+ -+ using Func = detail::TensorCopyDiagonalOutFunc; -+ using Params = typename Func::Params; -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, ptr) -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h -new file mode 100644 -index 0000000..cac558d ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h -@@ -0,0 +1,136 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/util/reference/device/kernel/tensor_foreach.h" -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Launches a kernel calling a functor for each element in a tensor's index space. -+template -+struct TensorForEach { -+ -+ /// Constructor performs the operation. -+ TensorForEach(Coord size, Params params = Params(), int grid_size = 0, int block_size = 0) { -+ -+ if (!grid_size || !block_size) { -+ -+ // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API -+ cudaError_t result = cudaOccupancyMaxPotentialBlockSize( -+ &grid_size, -+ &block_size, -+ reinterpret_cast(kernel::TensorForEach)); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("Failed to query occupancy."); -+ } -+ -+ // Limit block size. This has the effect of increasing the number of items processed by a -+ // single thread and reduces the impact of initialization overhead. -+ block_size = (block_size < 128 ? block_size : 128); -+ } -+ -+ dim3 grid(grid_size, 1, 1); -+ dim3 block(block_size, 1, 1); -+ -+ kernel::TensorForEach<<< grid, block >>>(size, params); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Launches a kernel calling a functor for each element along a tensor's diagonal -+template -+struct TensorDiagonalForEach { -+ -+ /// Constructor performs the operation -+ TensorDiagonalForEach(Coord size, Params params = Params(), int start = 0, int end = -1, int block_size = 128) { -+ -+ if (end < 0) { -+ end = size.min(); -+ } -+ -+ dim3 block(block_size, 1, 1); -+ dim3 grid((end - start + block_size - 1) / block_size, 1, 1); -+ -+ kernel::TensorDiagonalForEach<<< grid, block >>>(size, params, start, end); -+ } -+}; -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct BlockForEach { -+ -+ /// Constructor performs the operation. -+ BlockForEach( -+ Element *ptr, -+ size_t capacity, -+ typename Func::Params params = typename Func::Params(), -+ int grid_size = 0, -+ int block_size = 0) { -+ -+ if (!grid_size || !block_size) { -+ -+ // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API -+ cudaError_t result = cudaOccupancyMaxPotentialBlockSize( -+ &grid_size, -+ &block_size, -+ reinterpret_cast(kernel::BlockForEach)); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("Failed to query occupancy."); -+ } -+ -+ // Limit block size. This has the effect of increasing the number of items processed by a -+ // single thread and reduces the impact of initialization overhead. -+ block_size = (block_size < 128 ? 
block_size : 128); -+ } -+ -+ dim3 grid(grid_size, 1, 1); -+ dim3 block(block_size, 1, 1); -+ -+ kernel::BlockForEach<<< grid, block >>>(ptr, capacity, params); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reference -+} // namesace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h -new file mode 100644 -index 0000000..09c11db ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h -@@ -0,0 +1,510 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
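// The TensorForEach launcher above expects a functor with a nested Params
// type, a CUTLASS_DEVICE constructor taking Params, and an operator() over a
// TensorCoord, exactly as the reference functors in this patch are written.
// A minimal sketch of that contract; ScaleFunc and tensor_scale are
// illustrative names, not part of this patch.
#include "cutlass/cutlass.h"
#include "cutlass/tensor_view.h"
#include "cutlass/util/reference/device/tensor_foreach.h"

template <typename Element, typename Layout>
struct ScaleFunc {

  using TensorView = cutlass::TensorView<Element, Layout>;
  using TensorCoord = typename TensorView::TensorCoord;

  /// Parameters passed by value to the kernel launch
  struct Params {
    TensorView view;
    Element scale;

    Params(TensorView view_ = TensorView(), Element scale_ = Element(1)):
      view(view_), scale(scale_) { }
  };

  Params params;

  CUTLASS_DEVICE
  ScaleFunc(Params const &params): params(params) { }

  /// Invoked once per coordinate in the tensor's index space
  CUTLASS_DEVICE
  void operator()(TensorCoord const &coord) {
    params.view.at(coord) = params.view.at(coord) * params.scale;
  }
};

/// Launches ScaleFunc over every element of a device-resident view
template <typename Element, typename Layout>
void tensor_scale(cutlass::TensorView<Element, Layout> view, Element scale) {
  using Func = ScaleFunc<Element, Layout>;

  cutlass::reference::device::TensorForEach<Func, Layout::kRank, typename Func::Params>(
    view.extent(),
    typename Func::Params(view, scale));
}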
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/util/device_memory.h" -+#include "cutlass/util/reference/detail/linear_to_coordinate.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace kernel { -+ -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType, -+ typename ReduceOp, -+ typename TransformOp, -+ int kBlockSize = 128 -+> -+__global__ void TensorTransformReducePartial( -+ TensorView view, /// View of the tensor to reduce over -+ ComputeType identity, /// Identity element of the reduction operation -+ ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType -+ TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType -+ ComputeType *workspace) { /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] -+ -+ int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; -+ int64_t size = view.size(); -+ -+ __shared__ ComputeType scratchpad[kBlockSize]; -+ -+ for (; idx < size; idx += blockDim.x * gridDim.x) { -+ -+ // Map linear thread ID onto tensor coordinate -+ typename Layout::TensorCoord coord; -+ -+ cutlass::reference::detail::LinearToCoordinate()(coord, idx, view.extent()); -+ -+ if (view.contains(coord)) { -+ -+ // Fetch element -+ Element x = view.at(coord); -+ -+ // Transform -+ identity = reduce(identity, transform(x)); -+ } -+ } -+ -+ scratchpad[threadIdx.x] = identity; -+ -+ __syncthreads(); -+ -+ // One thread performs the final reduction and stores out. This could be enhanced via -+ // a tree reduction and pipelining. -+ if (threadIdx.x == 0) { -+ -+ for (int i = 1; i < kBlockSize; ++i) { -+ identity = reduce(identity, scratchpad[i]); -+ } -+ -+ workspace[blockIdx.x] = identity; -+ } -+} -+ -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType, -+ typename ReduceOp, -+ typename TransformOp, -+ int kBlockSize = 128 -+> -+__global__ void TensorTransformReducePartial( -+ TensorView view_A, /// View of the tensor to reduce over -+ TensorView view_B, /// View of the tensor to reduce over -+ ComputeType identity, /// Identity element of the reduction operation -+ ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType -+ TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType -+ ComputeType *workspace) { /// Device-side workspace for accumulating partial results. 
The reduced element is stored in workspace[0] -+ -+ int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; -+ int64_t size = view_A.size(); -+ -+ __shared__ ComputeType scratchpad[kBlockSize]; -+ -+ for (; idx < size; idx += blockDim.x * gridDim.x) { -+ -+ // Map linear thread ID onto tensor coordinate -+ typename Layout::TensorCoord coord; -+ -+ cutlass::reference::detail::LinearToCoordinate()(coord, idx, view_A.extent()); -+ -+ if (view_A.contains(coord)) { -+ -+ // Fetch element -+ Element a = view_A.at(coord); -+ Element b = view_B.at(coord); -+ -+ // Transform -+ identity = reduce(identity, transform(a, b)); -+ } -+ } -+ -+ scratchpad[threadIdx.x] = identity; -+ -+ __syncthreads(); -+ -+ // One thread performs the final reduction and stores out. This could be enhanced via -+ // a tree reduction and pipelining. -+ if (threadIdx.x == 0) { -+ -+ for (int i = 1; i < kBlockSize; ++i) { -+ identity = reduce(identity, scratchpad[i]); -+ } -+ -+ workspace[blockIdx.x] = identity; -+ } -+} -+ -+ -+template < -+ typename ComputeType, -+ typename ReduceOp, -+ int kBlockSize = 32 -+> -+__global__ void TensorTransformReduceFinalize( -+ ComputeType *workspace, -+ ComputeType identity, -+ int workspace_size, -+ ReduceOp reduce) { -+ -+ __shared__ ComputeType scratchpad[kBlockSize]; -+ -+ for (int idx = threadIdx.x; idx < workspace_size; idx += kBlockSize) { -+ identity = reduce(identity, workspace[idx]); -+ } -+ -+ scratchpad[threadIdx.x] = identity; -+ -+ __syncthreads(); -+ -+ if (threadIdx.x == 0) { -+ -+ for (int i = 1; i < kBlockSize; ++i) { -+ identity = reduce(identity, scratchpad[i]); -+ } -+ -+ workspace[0] = identity; -+ } -+} -+ -+} // namespace kernel -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Transform-reduce operation over the elements of a tensor -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType, -+ typename ReduceOp, -+ typename TransformOp -+> -+ComputeType TensorTransformReduce( -+ TensorView view, /// View of the tensor to reduce over -+ ComputeType identity, /// Identity element of the reduction operation -+ ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType -+ TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType -+ ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] -+ int workspace_size, /// Number of elements in workspace -+ cudaStream_t stream = nullptr, /// CUDA stream to launch into -+ bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned. 
-+) { -+ -+ int const kBlockSize = 128; -+ -+ dim3 block(kBlockSize, 1); -+ dim3 grid(workspace_size, 1); -+ -+ kernel::TensorTransformReducePartial< -+ Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize -+ ><<< grid, block, 0, stream >>>( -+ view, identity, reduce, transform, workspace -+ ); -+ -+ int const kFinalizeBlockSize = 32; -+ -+ kernel::TensorTransformReduceFinalize< -+ ComputeType, ReduceOp, kFinalizeBlockSize -+ ><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>( -+ workspace, identity, workspace_size, reduce -+ ); -+ -+ if (copy_out) { -+ cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaMemcpy() failed"); -+ } -+ } -+ -+ return identity; -+} -+ -+/// Transform-reduce operation over the elements of two tensors, zipped together -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType, -+ typename ReduceOp, -+ typename TransformOp -+> -+ComputeType TensorTransformReduce( -+ TensorView view_A, /// View of the tensor to reduce over -+ TensorView view_B, /// View of the tensor to reduce over -+ ComputeType identity, /// Identity element of the reduction operation -+ ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType -+ TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType -+ ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] -+ int workspace_size, /// Number of elements in workspace -+ cudaStream_t stream = nullptr, /// CUDA stream to launch into -+ bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned. -+) { -+ -+ if (view_A.extent() != view_B.extent()) { -+ throw std::runtime_error("Extents must be equal."); -+ } -+ -+ int const kBlockSize = 128; -+ -+ dim3 block(kBlockSize, 1); -+ dim3 grid(workspace_size, 1); -+ -+ kernel::TensorTransformReducePartial< -+ Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize -+ ><<< grid, block, 0, stream >>>( -+ view_A, view_B, identity, reduce, transform, workspace -+ ); -+ -+ int const kFinalizeBlockSize = 32; -+ -+ kernel::TensorTransformReduceFinalize< -+ ComputeType, ReduceOp, kFinalizeBlockSize -+ ><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>( -+ workspace, identity, workspace_size, reduce -+ ); -+ -+ if (copy_out) { -+ cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaMemcpy() failed"); -+ } -+ } -+ -+ return identity; -+} -+ -+/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side -+/// workspace -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType, -+ typename ReduceOp, -+ typename TransformOp -+> -+ComputeType TensorTransformReduce( -+ TensorView view, -+ ComputeType identity, -+ ReduceOp reduce, -+ TransformOp transform, -+ cudaStream_t stream = nullptr, -+ int workspace_size = 0 -+) { -+ -+ // Optionally query for the SM count to size the workspace. 
-+ if (!workspace_size) { -+ -+ int device_idx = 0; -+ cudaDeviceProp prop; -+ -+ cudaError_t result = cudaGetDevice(&device_idx); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() failed"); -+ } -+ -+ result = cudaGetDeviceProperties(&prop, device_idx); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProp() failed"); -+ } -+ -+ workspace_size = int(prop.multiProcessorCount); -+ } -+ -+ DeviceAllocation workspace(workspace_size); -+ -+ ComputeType output = TensorTransformReduce( -+ view, -+ identity, -+ reduce, -+ transform, -+ workspace.get(), -+ workspace_size, -+ stream, -+ true); -+ -+ return output; -+} -+ -+ -+/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side -+/// workspace -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType, -+ typename ReduceOp, -+ typename TransformOp -+> -+ComputeType TensorTransformReduce( -+ TensorView view_A, -+ TensorView view_B, -+ ComputeType identity, -+ ReduceOp reduce, -+ TransformOp transform, -+ cudaStream_t stream = nullptr, -+ int workspace_size = 0 -+) { -+ -+ // Optionally query for the SM count to size the workspace. -+ if (!workspace_size) { -+ -+ int device_idx = 0; -+ cudaDeviceProp prop; -+ -+ cudaError_t result = cudaGetDevice(&device_idx); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() failed"); -+ } -+ -+ result = cudaGetDeviceProperties(&prop, device_idx); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProp() failed"); -+ } -+ -+ workspace_size = int(prop.multiProcessorCount); -+ } -+ -+ DeviceAllocation workspace(workspace_size); -+ -+ ComputeType output = TensorTransformReduce( -+ view_A, -+ view_B, -+ identity, -+ reduce, -+ transform, -+ workspace.get(), -+ workspace_size, -+ stream, -+ true); -+ -+ return output; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper to compute the sum of the elements of a tensor -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType = Element -+> -+ComputeType TensorSum( -+ TensorView view, -+ ComputeType identity = ComputeType(), -+ cudaStream_t stream = nullptr, -+ int workspace_size = 0 -+) { -+ -+ plus reduce; -+ NumericConverter transform; -+ -+ return TensorTransformReduce( -+ view, identity, reduce, transform, stream, workspace_size); -+} -+ -+/// Helper to compute the sum of the squares of the elements of a tensor -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType = Element -+> -+ComputeType TensorSumSq( -+ TensorView view, -+ ComputeType identity = ComputeType(), -+ cudaStream_t stream = nullptr, -+ int workspace_size = 0 -+) { -+ -+ plus reduce; -+ magnitude_squared transform; -+ -+ return TensorTransformReduce( -+ view, identity, reduce, transform, stream, workspace_size); -+} -+ -+/// Helper to compute the norm of the elements of a tensor. 
-+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType = double -+> -+ComputeType TensorNorm( -+ TensorView view, -+ ComputeType identity = ComputeType(), -+ cudaStream_t stream = nullptr, -+ int workspace_size = 0 -+) { -+ -+ return std::sqrt(TensorSumSq(view, identity, stream, workspace_size)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper to compute the sum of the squares of the differences of two tensors -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType = double -+> -+ComputeType TensorSumSqDiff( -+ TensorView view_A, -+ TensorView view_B, -+ ComputeType identity = ComputeType(), -+ cudaStream_t stream = nullptr, -+ int workspace_size = 0 -+) { -+ -+ plus reduce; -+ magnitude_squared_difference transform; -+ -+ return TensorTransformReduce( -+ view_A, view_B, identity, reduce, transform, stream, workspace_size); -+} -+ -+ -+/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType = double -+> -+ComputeType TensorNormDiff( -+ TensorView view_A, -+ TensorView view_B, -+ ComputeType identity = ComputeType(), -+ cudaStream_t stream = nullptr, -+ int workspace_size = 0 -+) { -+ -+ return std::sqrt(TensorSumSqDiff(view_A, view_B, identity, stream, workspace_size)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reference -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h -new file mode 100644 -index 0000000..c78f1dc ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h -@@ -0,0 +1,141 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
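// A minimal sketch of the device-side reduction helpers above, assuming
// device-resident views managed through cutlass::HostTensor and the companion
// tensor_fill.h helpers; the extents and fill values are illustrative.
#include "cutlass/layout/matrix.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_reduce.h"

void tensor_reduce_example() {
  cutlass::HostTensor<float, cutlass::layout::RowMajor> A({128, 256});
  cutlass::HostTensor<float, cutlass::layout::RowMajor> B({128, 256});

  cutlass::reference::device::TensorFill(A.device_view(), 1.0f);
  cutlass::reference::device::TensorFill(B.device_view(), 2.0f);

  // Sum and L2 norm of a single tensor; each helper sizes its own workspace,
  // launches the partial and finalize kernels, and copies the scalar back.
  double sum  = cutlass::reference::device::TensorSum<float, cutlass::layout::RowMajor, double>(A.device_view());
  double norm = cutlass::reference::device::TensorNorm(A.device_view());

  // L2 norm of the elementwise difference, e.g. when comparing a computed
  // tensor against a reference.
  double norm_diff = cutlass::reference::device::TensorNormDiff(A.device_view(), B.device_view());

  (void)sum; (void)norm; (void)norm_diff;
}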
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines device-side elementwise operations on TensorView. Note, the operations defined -+ in this header are not specialized for any particular data layout and are therefore not -+ intended to offer the best possible performance. Rather, they are intended to be generic -+ reference implementations to support the CUTLASS unit tests. -+*/ -+ -+#pragma once -+ -+// Cutlass includes -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_view.h" -+ -+#include "cutlass/util/reference/device/tensor_foreach.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorReLuFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Coordinate in tensor's index space -+ using TensorCoord = typename TensorView::TensorCoord; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element threshold; -+ -+ -+ // -+ // Methods -+ // -+ -+ Params( -+ TensorView view_ = TensorView(), -+ Element threshold_ = Element(0) -+ ): -+ view(view_), threshold(threshold_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ Params params; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ TensorReLuFunc(Params const ¶ms): params(params) { -+ -+ } -+ -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ -+ Element const & value = params.view.at(coord); -+ params.view.at(coord) = (value < params.threshold) ? 
params.threshold : value; -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Apply ReLu on a tensor -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorReLu( -+ TensorView view, ///< destination tensor -+ Element threshold = Element(0)) { ///< ReLu threshold -+ -+ using Func = detail::TensorReLuFunc; -+ using Params = typename Func::Params; -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, threshold) -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h -new file mode 100644 -index 0000000..094f716 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h -@@ -0,0 +1,186 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for GEMM in host-side code. -+*/ -+ -+#pragma once -+ -+#include "cutlass/coord.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+namespace thread { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Thread-level blocked general matrix product. 
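// A minimal sketch of the device-side TensorReLu defined above, applied to a
// device-resident view; it assumes the companion tensor_fill.h helpers, and
// the extent, seed, and threshold are illustrative.
#include "cutlass/layout/matrix.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_relu.h"

void tensor_relu_example() {
  cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor({64, 64});

  // Fill with values in [-1, 1], then clamp everything below zero.
  cutlass::reference::device::TensorFillRandomUniform(tensor.device_view(), /*seed=*/2023, 1.0f, -1.0f);
  cutlass::reference::device::TensorReLu(tensor.device_view(), 0.0f);

  // Bring the result back for host-side inspection.
  tensor.sync_host();
}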
-+// -+// Note, this is a reference implementation. Performance is not expected to approach peak. -+// -+template < -+ typename TensorRefA, -+ typename TensorRefB, -+ typename TensorRefC, -+ typename ScalarType, -+ typename AccumulatorType, -+ typename OutputTile, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+struct Gemm { -+ -+ using ElementA = typename TensorRefA::Element; -+ using ElementB = typename TensorRefB::Element; -+ using ElementC = typename TensorRefC::Element; -+ -+ // -+ // Data members -+ // -+ -+ /// Tile for A operand -+ ElementA A_tile[OutputTile::kColumn]; -+ -+ /// Tile for B operand -+ ElementB B_tile[OutputTile::kRow]; -+ -+ /// Tile for Accumulator -+ AccumulatorType accum[OutputTile::kColumn][OutputTile::kRow]; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ Gemm(AccumulatorType initial_accum = AccumulatorType(0)) { -+ -+ // Clear fetch registers -+ for (int i = 0; i < OutputTile::kColumn; ++i) { -+ A_tile[i] = ElementA(0); -+ } -+ -+ for (int j = 0; j < OutputTile::kColumn; ++j) { -+ B_tile[j] = ElementB(0); -+ } -+ -+ // Clear accumulators -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < OutputTile::kColumn; ++j) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < OutputTile::kRow; ++i) { -+ accum[j][i] = initial_accum; -+ } -+ } -+ } -+ -+ /// Computes a matrix product -+ CUTLASS_HOST_DEVICE -+ Gemm & multiply_add( -+ gemm::GemmCoord problem_size, -+ TensorRefA tensor_a, -+ TensorRefB tensor_b, -+ MatrixCoord output_coord = MatrixCoord()) { -+ -+ InnerProductOp inner_product_op; -+ -+ // Loop over the GEMM K dimension -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int k = 0; k < problem_size.k(); ++k) { -+ -+ // Fetch a slice of the A matrix -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < OutputTile::kColumn; ++i) { -+ if (output_coord.row() + i < problem_size.m()) { -+ A_tile[i] = tensor_a.at(make_Coord(output_coord.row() + i, k)); -+ } -+ } -+ -+ // Fetch a slice of the B matrix -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < OutputTile::kRow; ++j) { -+ if (output_coord.column() + j < problem_size.n()) { -+ B_tile[j] = tensor_b.at(make_Coord(k, output_coord.column() + j)); -+ } -+ } -+ -+ // Compute an accumulated matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < OutputTile::kRow; ++j) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < OutputTile::kColumn; ++i) { -+ accum[j][i] = inner_product_op(A_tile[i], B_tile[j], accum[j][i]); -+ } -+ } -+ } -+ -+ return *this; -+ } -+ -+ /// Performs linear scaling of matrix product and updates output tensor -+ CUTLASS_HOST_DEVICE -+ Gemm & epilogue( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ ScalarType beta, -+ TensorRefC tensor_c, -+ TensorRefC tensor_d, -+ MatrixCoord output_coord = MatrixCoord()) { -+ -+ ConvertOp convert_op; -+ -+ // Update the output tensor -+ for (int j = 0; j < OutputTile::kRow; ++j) { -+ for (int i = 0; i < OutputTile::kColumn; ++i) { -+ MatrixCoord coord = output_coord + MatrixCoord(i, j); -+ if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) { -+ -+ tensor_d.at(coord) = convert_op( -+ alpha * ScalarType(accum[j][i]) + -+ beta * ScalarType(tensor_c.at(coord)) -+ ); -+ } -+ } -+ } -+ -+ return *this; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace device -+} // namespace reference -+} // namespace cutlass -diff --git 
a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h -new file mode 100644 -index 0000000..4d8a7fc ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h -@@ -0,0 +1,789 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Reference implementation for convolution in host-side code. 
-+*/ -+ -+#pragma once -+ -+#include "cutlass/coord.h" -+#include "cutlass/functional.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+/// Forward propagation -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// y = conv2d(x, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void Conv2dFprop( -+ conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_x, -+ TensorRef tensor_w, -+ TensorRef tensor_y_in, -+ TensorRef tensor_y_out, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ // Apply MMA and accumulate ElementAccumulator -+ for (int n = 0; n < problem_size.N; ++n) { -+ for (int p = 0; p < problem_size.P; ++p) { -+ for (int q = 0; q < problem_size.Q; ++q) { -+ for (int k = 0; k < problem_size.K; ++k) { -+ -+ int group_idx = k / (problem_size.K / problem_size.groups); -+ int channels_per_group = problem_size.C / problem_size.groups; -+ -+ ElementAccumulator acc = ElementAccumulator(); -+ -+ for (int r = 0; r < problem_size.R; ++r) { -+ for (int s = 0; s < problem_size.S; ++s) { -+ for (int c = 0; c < channels_per_group; ++c) { -+ -+ int filter_r = r; -+ int filter_s = s; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_r = problem_size.R - 1 - r; -+ filter_s = problem_size.S - 1 - s; -+ } -+ -+ int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; -+ int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; -+ -+ if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) { -+ -+ ElementA a = tensor_x.at({n, h, w, c + group_idx * channels_per_group}); -+ ElementB b = tensor_w.at({k, r, s, c}); -+ -+ acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); -+ -+ } -+ } -+ } -+ } -+ -+ // Apply Epilogue, compute ElementCompute, convert and store ElementC -+ ElementC c_ref = ElementC(); -+ -+ if (beta != ElementCompute()) { -+ c_ref = tensor_y_in.at(cutlass::make_Coord(n, p, q, k)); -+ } -+ -+ tensor_y_out.at(cutlass::make_Coord(n, p, q, k)) = -+ convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); -+ } -+ } -+ } -+ } -+} -+ -+/// Depthwise-separable convolution -+template , -+ typename InnerProductOp = multiply_add > -+void Depsep_Fprop(cutlass::TensorView tensor_A, -+ cutlass::TensorView tensor_B, -+ cutlass::TensorView tensor_C, -+ cutlass::TensorView tensor_D, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cutlass::Tensor4DCoord padding = cutlass::Tensor4DCoord(), -+ cutlass::Coord<2> conv_stride = cutlass::Coord<2>(), -+ cutlass::Coord<2> dilation = cutlass::Coord<2>(), -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation) { -+ -+ ConvertOp convert_op; -+ InnerProductOp 
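// A minimal sketch of the host-side Conv2dFprop reference above. The problem
// shape and the Conv2dProblemSize constructor arguments below are illustrative
// assumptions; see cutlass/conv/conv2d_problem_size.h for the exact interface.
// x and w are left uninitialized here and would normally be filled first, for
// example with reference::host::TensorFillRandomUniform.
#include "cutlass/conv/conv2d_problem_size.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/convolution.h"

void conv2d_fprop_reference_example() {
  // 1x8x8x16 NHWC input, 32 filters of 3x3x16, pad 1, stride 1 -> 1x8x8x32 output.
  cutlass::conv::Conv2dProblemSize problem_size(
    {1, 8, 8, 16},                              // input  (N, H, W, C)
    {32, 3, 3, 16},                             // filter (K, R, S, C)
    {1, 1, 1, 1},                               // padding
    {1, 1},                                     // stride  (h, w)
    {1, 1},                                     // dilation (h, w)
    cutlass::conv::Mode::kCrossCorrelation);

  cutlass::HostTensor<float, cutlass::layout::TensorNHWC> x({1, 8, 8, 16});
  cutlass::HostTensor<float, cutlass::layout::TensorNHWC> w({32, 3, 3, 16});
  cutlass::HostTensor<float, cutlass::layout::TensorNHWC> y({1, 8, 8, 32});

  // y = 1.0 * conv2d(x, w) + 0.0 * y, computed entirely on the host.
  cutlass::reference::host::Conv2dFprop(
    problem_size,
    x.host_ref(), w.host_ref(),
    y.host_ref(), y.host_ref(),
    1.0f, 0.0f);
}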
inner_product_op; -+ -+ // Apply MMA and accumulate ElementAccumulator -+ for (int n = 0; n < tensor_C.extent().n(); ++n) { -+ for (int p = 0; p < tensor_C.extent().h(); ++p) { -+ for (int q = 0; q < tensor_C.extent().w(); ++q) { -+ for (int g = 0; g < tensor_C.extent().c(); ++g) { -+ ElementAccumulator acc = ElementAccumulator(); -+ for (int r = 0; r < tensor_B.extent().h(); ++r) { -+ for (int s = 0; s < tensor_B.extent().w(); ++s) { -+ -+ // input activation H and W -+ int h = p * conv_stride[0] - padding[0] + r * dilation[0]; -+ int w = q * conv_stride[1] - padding[2] + s * dilation[1]; -+ -+ if (h < tensor_A.extent().h() && h >= 0 && w < tensor_A.extent().w() && w >= 0) { -+ ElementA a = tensor_A.at(cutlass::make_Coord(n, h, w, g)); -+ -+ ElementB b = (mode == cutlass::conv::Mode::kCrossCorrelation) -+ ? tensor_B.at(cutlass::make_Coord(g, r, s, 0)) -+ : tensor_B.at(cutlass::make_Coord( -+ g, tensor_B.extent().h() - r - 1, tensor_B.extent().w() - s - 1, 0)); -+ -+ acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); -+ } -+ } -+ } -+ -+ // Apply Epilogue, compute ElementCompute, convert and store ElementC -+ ElementC c_ref = tensor_C.at(cutlass::make_Coord(n, p, q, g)); -+ tensor_D.at(cutlass::make_Coord(n, p, q, g)) = -+ convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); -+ } -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+/// Dgrad -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// dx = dgrad(dy, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void Conv2dDgrad( -+ cutlass::conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_w, -+ TensorRef tensor_dx_in, -+ TensorRef tensor_dx_out, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ // Apply MMA and accumulate ElementAccumulator -+ for (int n = 0; n < problem_size.N; ++n) { -+ for (int h = 0; h < problem_size.H; ++h) { -+ for (int w = 0; w < problem_size.W; ++w) { -+ for (int c = 0; c < problem_size.C; ++c) { -+ -+ ElementAccumulator acc = ElementAccumulator(); -+ -+ for (int r = 0; r < problem_size.R; ++r) { -+ for (int s = 0; s < problem_size.S; ++s) { -+ for (int k = 0; k < problem_size.K; ++k) { -+ -+ int filter_r = r; -+ int filter_s = s; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_r = problem_size.R - 1 - r; -+ filter_s = problem_size.S - 1 - s; -+ } -+ -+ int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h; -+ int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w; -+ -+ if (p >= 0 && (p % problem_size.stride_h) == 0 && -+ q >= 0 && (q % problem_size.stride_w) == 0) { -+ -+ p = p / problem_size.stride_h; -+ q = q / problem_size.stride_w; -+#if 0 -+ std::cout << "row:" -+ << n * problem_size.H * problem_size.W + -+ h * problem_size.W + -+ w << " " -+ << "n, p, q: (" -+ << n << ", " -+ << p << ", " -+ << q << ") * " -+ << "r, s: (" -+ << r << ", " -+ << s << ") [" -+ << ((p < problem_size.P && q < problem_size.Q) ? 
"true":"false") << "]" -+ << std::endl; -+#endif -+ if (p < problem_size.P && q < problem_size.Q) { -+ -+ ElementA a = tensor_dy.at(cutlass::make_Coord(n, p, q, k)); -+ ElementB b = tensor_w.at(cutlass::make_Coord(k, r, s, c)); -+ -+ acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); -+ } -+ } -+ -+ } // for (K) -+ } // for (S) -+ } // for (R) -+ -+ // Apply Epilogue, compute ElementCompute, convert and store ElementC -+ ElementC c_ref = ElementC(); -+ -+ if (beta != ElementCompute()) { -+ c_ref = tensor_dx_in.at(cutlass::make_Coord(n, h, w, c)); -+ } -+ -+ tensor_dx_out.at(cutlass::make_Coord(n, h, w, c)) = -+ convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); -+ -+ } // for (C) -+ } // for (W) -+ } // for (H) -+ } // for (N) -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+/// Wgrad -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// dw = wgrad(dy, x) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void Conv2dWgrad( -+ cutlass::conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_x, -+ TensorRef tensor_dw_in, -+ TensorRef tensor_dw_out, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ InnerProductOp inner_product_op; -+ ConvertOp convert_op; -+ -+ // Apply MMA and accumulate ElementAccumulator -+ for (int k = 0; k < problem_size.K; ++k) { -+ for (int r = 0; r < problem_size.R; ++r) { -+ for (int s = 0; s < problem_size.S; ++s) { -+ for (int c = 0; c < problem_size.C; ++c) { -+ -+ ElementAccumulator acc = ElementAccumulator(); -+ -+ for (int n = 0; n < problem_size.N; ++n) { -+ for (int p = 0; p < problem_size.P; ++p) { -+ for (int q = 0; q < problem_size.Q; ++q) { -+ -+ cutlass::Tensor4DCoord b_coord; -+ -+ int filter_r = r; -+ int filter_s = s; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_r = problem_size.R - 1 - r; -+ filter_s = problem_size.S - 1 - s; -+ } -+ -+ b_coord = make_Coord( -+ n, -+ p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h, -+ q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w, -+ c); -+ -+ if (b_coord.h() < problem_size.H && b_coord.h() >= 0 && -+ b_coord.w() < problem_size.W && b_coord.w() >= 0) { -+ -+ ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, p, q, k))); -+ ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord)); -+ acc = inner_product_op(a, b, acc); -+ } -+ } -+ } -+ } -+ -+ // Apply Epilogue, compute ElementCompute, convert and store ElementC -+ ElementC c_ref = ElementC(); -+ -+ if (beta != ElementCompute()) { -+ c_ref = tensor_dw_in.at(cutlass::make_Coord(k, r, s, c)); -+ } -+ -+ tensor_dw_out.at(cutlass::make_Coord(k, r, s, c)) = -+ convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); -+ -+ } // for (C) -+ } // for (S) -+ } // for (R) -+ } // for (K) -+} -+ -+/// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. 
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void Conv2d( -+ conv::Operator convolutional_operator, -+ conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_A, -+ TensorRef tensor_B, -+ TensorRef tensor_C, -+ TensorRef tensor_D, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ switch (convolutional_operator) { -+ case conv::Operator::kFprop: -+ Conv2dFprop< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); -+ break; -+ -+ case conv::Operator::kDgrad: -+ Conv2dDgrad< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); -+ break; -+ -+ case conv::Operator::kWgrad: -+ Conv2dWgrad< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); -+ break; -+ -+ default: -+ break; -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+/// 3D convolution -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// y = conv3d(x, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void Conv3dFprop( -+ conv::Conv3dProblemSize problem_size, -+ TensorRef tensor_x, -+ TensorRef tensor_w, -+ TensorRef tensor_y_in, -+ TensorRef tensor_y_out, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ // Apply MMA and accumulate ElementAccumulator -+ for (int n = 0; n < problem_size.N; ++n) { -+ for (int z = 0; z < problem_size.Z; ++z) { -+ for (int p = 0; p < problem_size.P; ++p) { -+ for (int q = 0; q < problem_size.Q; ++q) { -+ for (int k = 0; k < problem_size.K; ++k) { -+ -+ ElementAccumulator acc = ElementAccumulator(); -+ -+ for (int t = 0; t < problem_size.T; ++t) { -+ for (int r = 0; r < problem_size.R; ++r) { -+ for (int s = 0; s < problem_size.S; ++s) { -+ for (int c = 0; c < problem_size.C; ++c) { -+ -+ int filter_t = t; -+ int filter_r = r; -+ int filter_s = s; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_t = problem_size.T - 1 - t; -+ filter_r = problem_size.R - 1 - r; -+ filter_s = problem_size.S - 1 - s; -+ } -+ -+ int d = z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d; -+ int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; -+ int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; -+ -+ if (d >= 0 && d < problem_size.D && -+ h >=0 && h < problem_size.H && -+ w >= 0 && w < problem_size.W) { -+ -+ ElementA a = tensor_x.at({n, d, h, w, c}); -+ ElementB b = tensor_w.at({k, t, r, s, c}); -+ -+ acc = 
inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); -+ } -+ } -+ } -+ } -+ } -+ -+ // Apply Epilogue, compute ElementCompute, convert and store ElementC -+ ElementC c_ref = ElementC(); -+ -+ if (beta != ElementCompute()) { -+ c_ref = tensor_y_in.at(cutlass::make_Coord(n, z, p, q, k)); -+ } -+ -+ tensor_y_out.at(cutlass::make_Coord(n, z, p, q, k)) = -+ convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); -+ } -+ } -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+/// Dgrad -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// dx = dgrad(dy, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void Conv3dDgrad( -+ cutlass::conv::Conv3dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_w, -+ TensorRef tensor_dx_in, -+ TensorRef tensor_dx_out, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ // Apply MMA and accumulate ElementAccumulator -+ for (int n = 0; n < problem_size.N; ++n) { -+ for (int d = 0; d < problem_size.D; ++d) { -+ for (int h = 0; h < problem_size.H; ++h) { -+ for (int w = 0; w < problem_size.W; ++w) { -+ for (int c = 0; c < problem_size.C; ++c) { -+ -+ ElementAccumulator acc = ElementAccumulator(); -+ -+ for (int t = 0; t < problem_size.T; ++t) { -+ for (int r = 0; r < problem_size.R; ++r) { -+ for (int s = 0; s < problem_size.S; ++s) { -+ for (int k = 0; k < problem_size.K; ++k) { -+ -+ int filter_t = t; -+ int filter_r = r; -+ int filter_s = s; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_t = problem_size.T - 1 - t; -+ filter_r = problem_size.R - 1 - r; -+ filter_s = problem_size.S - 1 - s; -+ } -+ -+ int z = d + problem_size.pad_d - filter_t * problem_size.dilation_d; -+ int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h; -+ int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w; -+ -+ if (z >= 0 && (z % problem_size.stride_d) == 0 && -+ p >= 0 && (p % problem_size.stride_h) == 0 && -+ q >= 0 && (q % problem_size.stride_w) == 0) { -+ -+ z = z / problem_size.stride_d; -+ p = p / problem_size.stride_h; -+ q = q / problem_size.stride_w; -+ -+ if (z < problem_size.Z && p < problem_size.P && q < problem_size.Q) { -+ -+ ElementA a = tensor_dy.at(cutlass::make_Coord(n, z, p, q, k)); -+ ElementB b = tensor_w.at(cutlass::make_Coord(k, t, r, s, c)); -+ -+ acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); -+ } -+ } -+ -+ } // for (K) -+ } // for (S) -+ } // for (R) -+ } // for (T) -+ -+ // Apply Epilogue, compute ElementCompute, convert and store ElementC -+ ElementC c_ref = ElementC(); -+ -+ if (beta != ElementCompute()) { -+ c_ref = tensor_dx_in.at(cutlass::make_Coord(n, d, h, w, c)); -+ } -+ -+ tensor_dx_out.at(cutlass::make_Coord(n, d, h, w, c)) = -+ convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); -+ -+ } // for (C) -+ } // for (W) -+ } // for (H) -+ } // for (D) -+ } // for (N) -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+/// Wgrad 
-+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// dw = wgrad(dy, x) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void Conv3dWgrad( -+ cutlass::conv::Conv3dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_x, -+ TensorRef tensor_dw_in, -+ TensorRef tensor_dw_out, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ InnerProductOp inner_product_op; -+ ConvertOp convert_op; -+ -+ // Apply MMA and accumulate ElementAccumulator -+ for (int k = 0; k < problem_size.K; ++k) { -+ for (int t = 0; t < problem_size.T; ++t) { -+ for (int r = 0; r < problem_size.R; ++r) { -+ for (int s = 0; s < problem_size.S; ++s) { -+ for (int c = 0; c < problem_size.C; ++c) { -+ -+ ElementAccumulator acc = ElementAccumulator(); -+ -+ for (int n = 0; n < problem_size.N; ++n) { -+ for (int z = 0; z < problem_size.Z; ++z) { -+ for (int p = 0; p < problem_size.P; ++p) { -+ for (int q = 0; q < problem_size.Q; ++q) { -+ -+ int filter_t = t; -+ int filter_r = r; -+ int filter_s = s; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_t = problem_size.T - 1 - t; -+ filter_r = problem_size.R - 1 - r; -+ filter_s = problem_size.S - 1 - s; -+ } -+ -+ Tensor5DCoord b_coord = make_Coord( -+ n, -+ z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d, -+ p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h, -+ q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w, -+ c); -+ -+ if (b_coord.d() < problem_size.D && b_coord.d() >= 0 && -+ b_coord.h() < problem_size.H && b_coord.h() >= 0 && -+ b_coord.w() < problem_size.W && b_coord.w() >= 0) { -+ -+ ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, z, p, q, k))); -+ ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord)); -+ -+ acc = inner_product_op(a, b, acc); -+ } -+ } -+ } -+ } -+ } -+ -+ // Apply Epilogue, compute ElementCompute, convert and store ElementC -+ ElementC c_ref = ElementC(); -+ -+ if (beta != ElementCompute()) { -+ c_ref = tensor_dw_in.at(cutlass::make_Coord(k, t, r, s, c)); -+ } -+ -+ tensor_dw_out.at(cutlass::make_Coord(k, t, r, s, c)) = -+ convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); -+ -+ } // for (C) -+ } // for (S) -+ } // for (R) -+ } // for (T) -+ } // for (K) -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Generic 3D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. 
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void Conv3d( -+ conv::Operator convolutional_operator, -+ conv::Conv3dProblemSize problem_size, -+ TensorRef tensor_A, -+ TensorRef tensor_B, -+ TensorRef tensor_C, -+ TensorRef tensor_D, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ switch (convolutional_operator) { -+ case conv::Operator::kFprop: -+ Conv3dFprop< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); -+ break; -+ -+ case conv::Operator::kDgrad: -+ Conv3dDgrad< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); -+ break; -+ -+ case conv::Operator::kWgrad: -+ Conv3dWgrad< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); -+ break; -+ -+ default: -+ break; -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h -new file mode 100644 -index 0000000..0b4285c ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h -@@ -0,0 +1,66 @@ -+ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/util/reference/host/tensor_reduce.h" -+#include "cutlass/core_io.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/// Helper to compute the relative error metric for tensor A_computed w.r.t. to tensor A_reference -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType = double -+> -+ComputeType TensorRelativeErrorMetric( -+ TensorView view_A_computed, -+ TensorView view_B_reference, -+ ComputeType identity = ComputeType() -+) { -+ -+ return cutlass::reference::host::TensorNormDiff(view_A_computed, view_B_reference, identity) / -+ cutlass::reference::host::TensorNorm(view_B_reference, identity); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h -new file mode 100644 -index 0000000..cd87e6f ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h -@@ -0,0 +1,453 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. 
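// A minimal sketch of the relative error helper above, comparing a computed
// host-side tensor against a reference; the layout, element type, and the
// tolerance value are illustrative assumptions.
#include "cutlass/layout/matrix.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/error_metrics.h"

bool verify_example(
    cutlass::HostTensor<float, cutlass::layout::RowMajor> &computed,
    cutlass::HostTensor<float, cutlass::layout::RowMajor> &reference) {

  double const tolerance = 1e-6;

  // ||computed - reference|| / ||reference||, evaluated over the host views.
  double err = cutlass::reference::host::TensorRelativeErrorMetric(
    computed.host_view(), reference.host_view());

  return err < tolerance;
}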
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for GEMM in host-side code. -+*/ -+ -+#pragma once -+ -+#include "cutlass/coord.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/arch/mma.h" -+#include "cutlass/util/host_tensor.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+template -+struct CastIfScalar { -+ static Out cast(In in) { -+ return Out(in); -+ } -+}; -+ -+template -+struct CastIfScalar, In> { -+ typedef cutlass::complex Out; -+ static Out cast(In in) { -+ return Out(static_cast(in)); -+ } -+}; -+ -+template -+struct CastIfScalar, cutlass::complex> { -+ typedef cutlass::complex Out; -+ typedef cutlass::complex In; -+ static Out cast(In in) { -+ return Out(in); -+ } -+}; -+ -+template -+Out cast_if_scalar(In in) { -+ return CastIfScalar::cast(in); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+void compute_gemm( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ -+ // Note: batch is ignored. 
-+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ int const K = problem_size.k(); -+ -+ // Blocking necessary to speedup reference implementation -+ int const Mblock = 16; -+ int const Nblock = 16; -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ for (int row_block = 0; row_block < M; row_block += Mblock) { -+ for (int col_block = 0; col_block < N; col_block += Nblock) { -+ -+ ComputeType accum[Mblock][Nblock]; -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N) { -+ ElementA a = tensor_a.at(MatrixCoord(row, k_block)); -+ ElementB b = tensor_b.at(MatrixCoord(k_block, col)); -+ -+ ComputeType compute_a(cast_if_scalar(a)); -+ ComputeType compute_b(cast_if_scalar(b)); -+ -+ accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N) { -+ tensor_d.at(coord) = convert_op( -+ alpha * ScalarType(accum[i][j]) + -+ beta * ScalarType(tensor_c.at(coord))); -+ } -+ } -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+void compute_gemm( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ ComputeType initial_accum) { -+ compute_gemm( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, -+ initial_accum); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = cutlass::arch::OpMultiplyAdd -+> -+struct Gemm; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for multiply-add -+template -+struct Gemm { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); -+ } -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && 
LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for multiply-add -+template -+struct Gemm { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); -+ } -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for multiply-add-saturate -+template -+struct Gemm { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm, -+ NumericConverterClamp>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); -+ } -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm, -+ NumericConverterClamp>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for XOR-popc -+template -+struct Gemm { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); -+ } -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); -+ } -+}; -+ 
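The compute_gemm reference above accumulates a 16x16 block of outputs in a small local array, sweeps the full K extent for that block, and only then applies the alpha/beta epilogue and the output conversion. The following is a minimal standalone sketch of that blocking scheme in plain C++; row-major std::vector storage stands in for CUTLASS TensorRef, and the name ref_gemm_blocked and the float-only element types are illustrative assumptions, not part of the patch.

    // Sketch of the blocked host reference GEMM: D = alpha*(A x B) + beta*C,
    // with 16x16 output blocking so the accumulators stay in a local array.
    #include <vector>
    #include <cstdio>

    // Row-major accessors for an R x cols matrix held in a std::vector<float>.
    static float& at(std::vector<float>& m, int cols, int r, int c) { return m[r * cols + c]; }
    static float  at(const std::vector<float>& m, int cols, int r, int c) { return m[r * cols + c]; }

    void ref_gemm_blocked(int M, int N, int K,
                          float alpha,
                          const std::vector<float>& A,   // M x K
                          const std::vector<float>& B,   // K x N
                          float beta,
                          const std::vector<float>& C,   // M x N
                          std::vector<float>& D) {       // M x N
      constexpr int Mblock = 16, Nblock = 16;
      for (int row_block = 0; row_block < M; row_block += Mblock) {
        for (int col_block = 0; col_block < N; col_block += Nblock) {
          float accum[Mblock][Nblock] = {};        // per-block accumulators
          for (int k = 0; k < K; ++k) {
            for (int j = 0; j < Nblock; ++j) {
              for (int i = 0; i < Mblock; ++i) {
                int row = row_block + i, col = col_block + j;
                if (row < M && col < N) {
                  accum[i][j] += at(A, K, row, k) * at(B, N, k, col);
                }
              }
            }
          }
          // Epilogue: scale and add the source tensor, then write out.
          for (int j = 0; j < Nblock; ++j) {
            for (int i = 0; i < Mblock; ++i) {
              int row = row_block + i, col = col_block + j;
              if (row < M && col < N) {
                at(D, N, row, col) = alpha * accum[i][j] + beta * at(C, N, row, col);
              }
            }
          }
        }
      }
    }

    int main() {
      int M = 20, N = 18, K = 7;  // deliberately not multiples of 16
      std::vector<float> A(M * K, 1.0f), B(K * N, 2.0f), C(M * N, 3.0f), D(M * N, 0.0f);
      ref_gemm_blocked(M, N, K, /*alpha=*/1.0f, A, B, /*beta=*/0.5f, C, D);
      std::printf("D[0][0] = %f (expected %f)\n", at(D, N, 0, 0), K * 2.0f + 0.5f * 3.0f);
      return 0;
    }

Keeping the accumulator block in a local array is what makes the otherwise naive triple loop tolerable as a verification reference, and the in-bounds checks let M and N be arbitrary rather than multiples of the block size.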
-+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Batched GEMM -+// -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a batch of GEMMs over a set of matrices of common dimension. -+// -+// TensorRefCollection* is a type satisfying the TensorRefCollection concept. -+// -+template < -+ typename TensorRefCollectionA, -+ typename TensorRefCollectionB, -+ typename TensorRefCollectionC, -+ typename ScalarType, -+ typename AccumulatorType -+> -+void BatchedGemm( -+ gemm::GemmCoord problem_size, -+ int batch_count, -+ ScalarType alpha, -+ TensorRefCollectionA const& tensor_a, -+ TensorRefCollectionB const& tensor_b, -+ ScalarType beta, -+ TensorRefCollectionC &tensor_c, -+ AccumulatorType initial_accum) { -+ -+ typename TensorRefCollectionA::ConstIterator tensor_a_it = tensor_a.begin(); -+ typename TensorRefCollectionB::ConstIterator tensor_b_it = tensor_b.begin(); -+ typename TensorRefCollectionC::ConstIterator tensor_c_it = tensor_c.begin(); -+ -+ for (int batch = 0; -+ batch < batch_count; -+ ++batch, ++tensor_a_it, ++tensor_b_it, ++tensor_c_it) { -+ -+ Gemm -+ gemm; -+ -+ gemm(problem_size, alpha, *tensor_a_it, *tensor_b_it, beta, *tensor_c_it, -+ initial_accum); -+ } -+} -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+// -+// TensorRefCollection* is a type satisfying the TensorRefCollection concept. -+// -+template < -+ typename TensorRefCollectionA, -+ typename TensorRefCollectionB, -+ typename TensorRefCollectionC, -+ typename ScalarType, -+ typename AccumulatorType -+> -+void BatchedGemm( -+ gemm::GemmCoord problem_size, -+ int batch_count, -+ ScalarType alpha, -+ TensorRefCollectionA const& tensor_a, -+ TensorRefCollectionB const& tensor_b, -+ ScalarType beta, -+ TensorRefCollectionC &tensor_c) { -+ -+ BatchedGemm(problem_size, batch_count, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h -new file mode 100644 -index 0000000..f16e19c ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h -@@ -0,0 +1,208 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for complex-valued GEMM in host-side code. -+*/ -+ -+#pragma once -+ -+#include "cutlass/coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/matrix_coord.h" -+ -+#include "cutlass/tensor_view.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// Explicitly naming types needed by this template can be cumbersome, particularly for the -+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -+/// AccumulatorType(0) as the last function argument can be easier than naming all template -+/// arguments explicitly. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void GemmComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ TensorRef tensor_b, -+ ComplexTransform transform_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum, -+ int batch_count = 1, -+ int64_t batch_stride_A = 0, -+ int64_t batch_stride_B = 0, -+ int64_t batch_stride_C = 0, -+ int64_t batch_stride_D = 0) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ // Note: batch is ignored. 
-+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ int const K = problem_size.k(); -+ -+ // Blocking necessary to speedup reference implementation -+ int const Mblock = 16; -+ int const Nblock = 16; -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { -+ -+ // Compute matrix product using blocks -+ for (int row_block = 0; row_block < M; row_block += Mblock) { -+ for (int col_block = 0; col_block < N; col_block += Nblock) { -+ -+ ComputeType accum[Mblock][Nblock]; -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N) { -+ ElementA a = tensor_a.at(MatrixCoord(row, k_block)); -+ ElementB b = tensor_b.at(MatrixCoord(k_block, col)); -+ -+ ComputeType a_ik = ComputeType(a); -+ ComputeType b_kj = ComputeType(b); -+ -+ if (transform_a == ComplexTransform::kConjugate) { -+ a_ik = conj(a_ik); -+ } -+ -+ if (transform_b == ComplexTransform::kConjugate) { -+ b_kj = conj(b_kj); -+ } -+ -+ accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N) { -+ -+ tensor_d.at(coord) = convert_op( -+ alpha * ScalarType(accum[i][j]) + -+ beta * ScalarType(tensor_c.at(coord))); -+ } -+ } -+ } -+ -+ } // for (col_block) -+ } // for (row_block) -+ -+ tensor_a.add_pointer_offset(batch_stride_A); -+ tensor_b.add_pointer_offset(batch_stride_B); -+ tensor_c.add_pointer_offset(batch_stride_C); -+ tensor_d.add_pointer_offset(batch_stride_D); -+ -+ } // for (batch_idx) -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// This assumes the accumulator type is the same type as the scalars. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType -+> -+void GemmComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ TensorRef tensor_b, -+ ComplexTransform transform_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d) { -+ -+ GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h -new file mode 100644 -index 0000000..7e94210 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h -@@ -0,0 +1,228 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for complex-valued GEMM in host-side code. -+*/ -+ -+#pragma once -+ -+#include "cutlass/coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_ref_planar_complex.h" -+ -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// Explicitly naming types needed by this template can be cumbersome, particularly for the -+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -+/// AccumulatorType(0) as the last function argument can be easier than naming all template -+/// arguments explicitly. 
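GemmComplex above and the planar-complex variant that follows both apply an optional conjugation to each operand independently, controlled by ComplexTransform and unrelated to the operand's layout. A compact sketch of that per-operand handling, using std::complex<float> in place of cutlass::complex and an illustrative Transform enum; the blocking and batch-stride machinery of the real implementation is omitted here.

    // Sketch of per-operand conjugation in the complex reference GEMM:
    // D = alpha * op_a(A) * op_b(B) + beta * C, where each op is identity or conj.
    #include <complex>
    #include <vector>
    #include <cstdio>

    enum class Transform { kNone, kConjugate };
    using cpx = std::complex<float>;

    void ref_gemm_complex(int M, int N, int K,
                          cpx alpha,
                          const std::vector<cpx>& A, Transform ta,   // M x K, row-major
                          const std::vector<cpx>& B, Transform tb,   // K x N, row-major
                          cpx beta,
                          const std::vector<cpx>& C,                 // M x N, row-major
                          std::vector<cpx>& D) {                     // M x N, row-major
      for (int row = 0; row < M; ++row) {
        for (int col = 0; col < N; ++col) {
          cpx acc(0.0f, 0.0f);
          for (int k = 0; k < K; ++k) {
            cpx a = A[row * K + k];
            cpx b = B[k * N + col];
            if (ta == Transform::kConjugate) a = std::conj(a);
            if (tb == Transform::kConjugate) b = std::conj(b);
            acc += a * b;
          }
          D[row * N + col] = alpha * acc + beta * C[row * N + col];
        }
      }
    }

    int main() {
      int M = 2, N = 2, K = 2;
      std::vector<cpx> A(M * K, cpx(0.0f, 1.0f));   // every element is i
      std::vector<cpx> B(K * N, cpx(0.0f, 1.0f));
      std::vector<cpx> C(M * N, cpx(0.0f, 0.0f)), D(M * N);
      // conj(i) * i = 1, so each output element is K * 1 = 2.
      ref_gemm_complex(M, N, K, cpx(1.0f, 0.0f), A, Transform::kConjugate,
                       B, Transform::kNone, cpx(0.0f, 0.0f), C, D);
      std::printf("D[0] = (%g, %g)\n", D[0].real(), D[0].imag());
      return 0;
    }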
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add> -+> -+void GemmPlanarComplex( -+ gemm::GemmCoord problem_size, -+ complex alpha, -+ TensorRefPlanarComplex tensor_a, -+ ComplexTransform transform_a, -+ TensorRefPlanarComplex tensor_b, -+ ComplexTransform transform_b, -+ complex beta, -+ TensorRefPlanarComplex tensor_c, -+ TensorRefPlanarComplex tensor_d, -+ complex initial_accum) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ using ComplexA = typename TensorRefPlanarComplex::ComplexElement; -+ using ComplexB = typename TensorRefPlanarComplex::ComplexElement; -+ using ComplexC = typename TensorRefPlanarComplex::ComplexElement; -+ -+ // Note: batch is ignored. -+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ int const K = problem_size.k(); -+ -+ // Blocking necessary to speedup reference implementation -+ int const Mblock = 16; -+ int const Nblock = 16; -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ for (int row_block = 0; row_block < M; row_block += Mblock) { -+ for (int col_block = 0; col_block < N; col_block += Nblock) { -+ -+ complex accum[Mblock][Nblock]; -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N) { -+ -+ ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block)); -+ ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col)); -+ -+ complex a = complex{ -+ ComputeType(a_ik.real()), -+ ComputeType(a_ik.imag()) -+ }; -+ -+ complex b = complex{ -+ ComputeType(b_kj.real()), -+ ComputeType(b_kj.imag()) -+ }; -+ -+ if (transform_a == ComplexTransform::kConjugate) { -+ a = conj(a); -+ } -+ -+ if (transform_b == ComplexTransform::kConjugate) { -+ b = conj(b); -+ } -+ -+ accum[i][j] = inner_product_op(a, b, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N) { -+ -+ complex acc{ -+ ScalarType(accum[i][j].real()), -+ ScalarType(accum[i][j].imag()) -+ }; -+ -+ ComplexC d_ij = tensor_c.at(coord); -+ -+ complex src{ -+ ScalarType(d_ij.real()), -+ ScalarType(d_ij.imag()) -+ }; -+ -+ complex result = alpha * acc + beta * src; -+ -+ d_ij.real() = convert_op(result.real()); -+ d_ij.imag() = convert_op(result.imag()); -+ -+ tensor_d.at(coord) = d_ij; -+ } -+ } -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// This assumes the accumulator type is the same type as the scalars. 
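GemmPlanarComplex above reads its operands through TensorRefPlanarComplex, which keeps the real and imaginary parts in separate planes and only assembles a complex value at the access point. A rough sketch of that storage model, under the assumption of two identically laid-out row-major planes; PlanarComplexRef is an illustrative stand-in, not the CUTLASS type.

    // Sketch of planar-complex storage: separate real/imaginary planes, with
    // complex values materialized only on element access.
    #include <complex>
    #include <vector>
    #include <cstdio>

    struct PlanarComplexRef {
      float* real_plane;   // M x N, row-major
      float* imag_plane;   // same layout, held separately
      int cols;

      std::complex<float> get(int r, int c) const {
        return {real_plane[r * cols + c], imag_plane[r * cols + c]};
      }
      void set(int r, int c, std::complex<float> v) {
        real_plane[r * cols + c] = v.real();
        imag_plane[r * cols + c] = v.imag();
      }
    };

    int main() {
      int M = 2, N = 3;
      std::vector<float> real(M * N, 0.0f), imag(M * N, 0.0f);
      PlanarComplexRef ref{real.data(), imag.data(), N};

      ref.set(1, 2, {3.0f, -4.0f});
      std::complex<float> v = ref.get(1, 2);
      std::printf("(%g, %g)\n", v.real(), v.imag());  // prints (3, -4)
      return 0;
    }

Splitting the planes this way lets the real and imaginary data be strided independently of the element layout, which is the point of the planar-complex path.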
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType -+> -+void GemmPlanarComplex( -+ gemm::GemmCoord problem_size, -+ complex alpha, -+ TensorRefPlanarComplex tensor_a, -+ ComplexTransform transform_a, -+ TensorRefPlanarComplex tensor_b, -+ ComplexTransform transform_b, -+ complex beta, -+ TensorRefPlanarComplex tensor_c, -+ TensorRefPlanarComplex tensor_d) { -+ -+ GemmPlanarComplex( -+ problem_size, -+ alpha, -+ tensor_a, transform_a, -+ tensor_b, transform_b, -+ beta, -+ tensor_c, -+ tensor_d, -+ complex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp -new file mode 100644 -index 0000000..64a0600 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp -@@ -0,0 +1,311 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for GETT in host-side code. 
-+*/ -+ -+#pragma once -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "cutlass/complex.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cute/tensor.hpp" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::reference::host { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template< -+ class ElementAccumulator_, -+ class TensorA_, // (M, K, L) -+ class TensorB_ // (N, K, L) -+> -+struct GettMainloopParams { -+ using ElementAccumulator = ElementAccumulator_; -+ using TensorA = TensorA_; -+ using TensorB = TensorB_; -+ using EngineA = typename TensorA::engine_type; -+ using LayoutA = typename TensorA::layout_type; -+ using EngineB = typename TensorB::engine_type; -+ using LayoutB = typename TensorB::layout_type; -+ -+ TensorA A{}; -+ TensorB B{}; -+ -+ ComplexTransform transform_A = ComplexTransform::kNone; -+ ComplexTransform transform_B = ComplexTransform::kNone; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template< -+ class ElementScalar_, -+ class ElementAccumulator_, -+ class ElementCompute_, -+ class TensorC_, // (M, N, L) -+ class TensorD_ // (M, N, L) -+> -+struct GettEpilogueParams { -+ using ElementScalar = ElementScalar_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ using TensorC = TensorC_; -+ using TensorD = TensorD_; -+ using EngineC = typename TensorC::engine_type; -+ using LayoutC = typename TensorC::layout_type; -+ using EngineD = typename TensorD::engine_type; -+ using LayoutD = typename TensorD::layout_type; -+ ElementScalar alpha = ElementScalar(1); -+ ElementScalar beta = ElementScalar(0); -+ -+ TensorC C{}; -+ TensorD D{}; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// GETT - General Tensor-Tensor contraction reference kernel -+template < -+ class MainloopParams, -+ class EpilogueParams -+> -+void Gett( -+ MainloopParams const& mainloop_params, -+ EpilogueParams const& epilogue_params) -+{ -+ -+ static int constexpr kBlockM = 64; -+ static int constexpr kBlockN = 64; -+ -+ #pragma omp parallel for collapse(3) -+ for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { -+ for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { -+ for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { -+ typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN]; -+ gett_mainloop(mainloop_params, m, n, l, acc); -+ gett_epilogue(epilogue_params, m, n, l, acc); -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// GETT - Mainloop -+template -+void gett_mainloop( -+ MainloopParams const& mainloop_params, -+ int64_t m, -+ int64_t n, -+ int64_t l, -+ ElementAccumulator (&acc)[kBlockM][kBlockN]) -+{ -+ -+ static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B"); -+ static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B"); -+ -+ using ElementA = typename MainloopParams::EngineA::value_type; -+ using ElementB = typename MainloopParams::EngineB::value_type; -+ -+ using RingOp = multiply_add; -+ RingOp fma_op; -+ -+ // Zero out accumulators -+ for (int m_b = 0; m_b < kBlockM; ++m_b) { -+ for (int n_b = 0; 
n_b < kBlockN; ++n_b) { -+ acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity -+ } -+ } -+ -+ // Compute on this k-block -+ for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) { -+ // Load A -+ ElementAccumulator a_frag[kBlockM]; -+ for (int m_b = 0; m_b < kBlockM; ++m_b) { -+ if (m + m_b < cute::size<0>(mainloop_params.A.layout())) { -+ a_frag[m_b] = static_cast(mainloop_params.A(m + m_b, k, l)); -+ if (mainloop_params.transform_A == ComplexTransform::kConjugate) { -+ a_frag[m_b] = conj(a_frag[m_b]); -+ } -+ } else { -+ a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity -+ } -+ } -+ -+ // Load B -+ ElementAccumulator b_frag[kBlockN]; -+ for (int n_b = 0; n_b < kBlockN; ++n_b) { -+ if (n + n_b < cute::size<0>(mainloop_params.B.layout())) { -+ b_frag[n_b] = static_cast(mainloop_params.B(n + n_b, k, l)); -+ if (mainloop_params.transform_B == ComplexTransform::kConjugate) { -+ b_frag[n_b] = conj(b_frag[n_b]); -+ } -+ } else { -+ b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity -+ } -+ } -+ -+ // do compute -+ for (int m_b = 0; m_b < kBlockM; ++m_b) { -+ for (int n_b = 0; n_b < kBlockN; ++n_b) { -+ acc[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc[m_b][n_b]); -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// GETT - Epilogue -+template -+void gett_epilogue( -+ EpilogueParams const& epilogue_params, -+ int64_t m, -+ int64_t n, -+ int64_t l, -+ ElementAccumulator (&acc)[kBlockM][kBlockN]) -+{ -+ static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B"); -+ static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B"); -+ -+ using ElementCompute = typename EpilogueParams::ElementCompute; -+ using ElementC = typename EpilogueParams::EngineC::value_type; -+ -+ using ElementD = typename EpilogueParams::EngineD::value_type; -+ using ElementScalar = typename EpilogueParams::ElementScalar; -+ // Input related converter -+ NumericConverter accumulator_converter; -+ NumericConverter source_converter; -+ -+ // Scale related converter -+ NumericConverter scale_converter; -+ // Output related converter -+ NumericConverter destination_converter; -+ // Epilogue operations -+ multiply_add epilogue_fma; -+ multiplies mul; -+ -+ // Do conversion -+ ElementCompute converted_alpha = scale_converter(epilogue_params.alpha); -+ ElementCompute converted_beta = scale_converter(epilogue_params.beta); -+ for (int n_b = 0; n_b < kBlockN; ++n_b) { -+ for (int m_b = 0; m_b < kBlockM; ++m_b) { -+ if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { -+ // Convert every type to ElementCompute first, do compute, convert to output type, write it out -+ ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]); -+ ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l)); -+ -+ ElementScalar output = epilogue_fma(converted_alpha, converted_acc, ElementCompute(0)); -+ output = epilogue_fma(converted_beta, converted_src, output); -+ -+ epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(output); -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// GEMM - General Matrix-Matrix contraction without conjugation options -+template < -+ class MainloopParams, -+ class EpilogueParams -+> -+void Gemm3x( -+ MainloopParams const& mainloop_params, -+ EpilogueParams const& 
epilogue_params) -+{ -+ using namespace cute; -+ -+ static_assert(rank(typename MainloopParams::LayoutA{}) == rank(typename MainloopParams::LayoutB{})); -+ static_assert(rank(typename EpilogueParams::LayoutC{}) == rank(typename EpilogueParams::LayoutD{})); -+ static_assert(rank(typename MainloopParams::LayoutA{}) == rank(typename EpilogueParams::LayoutC{})); -+ -+ if constexpr (rank(typename MainloopParams::LayoutA{}) == 2) { -+ // append a batch mode of size 1 if we do not have tensors that are rank 3 -+ Layout layout_A = make_layout( -+ make_shape(get<0>(mainloop_params.A.shape()), get<1>(mainloop_params.A.shape()), Int<1>{}), -+ make_stride(get<0>(mainloop_params.A.stride()), get<1>(mainloop_params.A.stride()), int64_t(cosize(mainloop_params.A.layout())))); -+ -+ Layout layout_B = make_layout( -+ make_shape(get<0>(mainloop_params.B.shape()), get<1>(mainloop_params.B.shape()), Int<1>{}), -+ make_stride(get<0>(mainloop_params.B.stride()), get<1>(mainloop_params.B.stride()), int64_t(cosize(mainloop_params.B.layout())))); -+ -+ Layout layout_C = make_layout( -+ make_shape(get<0>(epilogue_params.C.shape()), get<1>(epilogue_params.C.shape()), Int<1>{}), -+ make_stride(get<0>(epilogue_params.C.stride()), get<1>(epilogue_params.C.stride()), int64_t(cosize(epilogue_params.C.layout())))); -+ -+ Layout layout_D = make_layout( -+ make_shape(get<0>(epilogue_params.D.shape()), get<1>(epilogue_params.D.shape()), Int<1>{}), -+ make_stride(get<0>(epilogue_params.D.stride()), get<1>(epilogue_params.D.stride()), int64_t(cosize(epilogue_params.D.layout())))); -+ auto TensorA = make_tensor(mainloop_params.A.data(), layout_A); -+ auto TensorB = make_tensor(mainloop_params.B.data(), layout_B); -+ auto TensorC = make_tensor(epilogue_params.C.data(), layout_C); -+ auto TensorD = make_tensor(epilogue_params.D.data(), layout_D); -+ // Reconstruct mainloop params -+ GettMainloopParams -+ mainloop_params_converted{TensorA, -+ TensorB, -+ mainloop_params.transform_A, -+ mainloop_params.transform_B}; -+ -+ // Reconstruct epilogue params -+ GettEpilogueParams -+ epilogue_params_converted{epilogue_params.alpha, -+ epilogue_params.beta, -+ TensorC, -+ TensorD -+ }; -+ -+ Gett(mainloop_params_converted, epilogue_params_converted); -+ } -+ else { -+ // if we already have a batch mode, just pass it through -+ Gett(mainloop_params, epilogue_params); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // cutlass::reference::host -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h -new file mode 100644 -index 0000000..5b34260 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h -@@ -0,0 +1,261 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for Rank 2k update in host-side code. -+ -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/arch/mma.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ FillMode FillModeC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+void compute_rank2k( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ static_assert( -+ FillModeC == FillMode::kLower || -+ FillModeC == FillMode::kUpper, -+ "Fill Mode can either be Lower or Upper."); -+ -+ using CompareOp = typename platform::conditional<(FillModeC == FillMode::kLower), -+ std::greater_equal, -+ std::less_equal>::type; -+ -+ // Note: batch is ignored. 
-+ // Note: M is same as N for Rank 2k update -+ int const N = problem_size.n(); -+ int const K = problem_size.k(); -+ -+ // Blocking necessary to speedup reference implementation -+ int const Nblock = 16; -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ CompareOp compare_op; -+ -+ for (int row_block = 0; row_block < N; row_block += Nblock) { -+ for (int col_block = 0; col_block < N; col_block += Nblock) { -+ -+ ComputeType accum[Nblock][Nblock]; -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Nblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Nblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < N && col < N && compare_op(row, col)) -+ { -+ -+ // A x B^T -+ ElementA a = tensor_a.at(MatrixCoord(row, k_block)); -+ ElementB b_t = tensor_b.at(MatrixCoord(col, k_block)); -+ -+ ComputeType compute_a(cast_if_scalar(a)); -+ ComputeType compute_b_t(cast_if_scalar(b_t)); -+ -+ accum[i][j] = inner_product_op(compute_a, compute_b_t, accum[i][j]); -+ -+ // B x A^T -+ ElementB b = tensor_b.at(MatrixCoord(row, k_block)); -+ ElementA a_t = tensor_a.at(MatrixCoord(col, k_block)); -+ -+ ComputeType compute_b(cast_if_scalar(b)); -+ ComputeType compute_a_t(cast_if_scalar(a_t)); -+ -+ accum[i][j] = inner_product_op(compute_b, compute_a_t, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Nblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < N && col < N && -+ ( (FillModeC == FillMode::kLower && row >= col) || -+ (FillModeC == FillMode::kUpper && row <= col) ) -+ ) { -+ tensor_d.at(coord) = convert_op( -+ alpha * ScalarType(accum[i][j]) + -+ beta * ScalarType(tensor_c.at(coord))); -+ } -+ } -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general Rank 2k update (tensors of rank=2) pointed to by TensorRef -+/// objects. 
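For the rank-2k update, the reference accumulates both the A*B^T and B*A^T contributions for each output element and writes only the triangle selected by the fill mode, leaving the other triangle of D untouched. Below is a standalone sketch of the real, symmetric case; row-major std::vector storage and the names ref_rank2k and Fill are illustrative assumptions rather than CUTLASS API.

    // Sketch of the symmetric rank-2k update: D = alpha*(A*B^T + B*A^T) + beta*C,
    // restricted to the lower or upper triangle of the N x N output.
    #include <vector>
    #include <cstdio>

    enum class Fill { kLower, kUpper };

    void ref_rank2k(int N, int K, float alpha,
                    const std::vector<float>& A,   // N x K, row-major
                    const std::vector<float>& B,   // N x K, row-major
                    float beta,
                    const std::vector<float>& C,   // N x N, row-major
                    std::vector<float>& D,         // N x N, row-major
                    Fill fill) {
      for (int row = 0; row < N; ++row) {
        for (int col = 0; col < N; ++col) {
          bool in_fill = (fill == Fill::kLower) ? (row >= col) : (row <= col);
          if (!in_fill) continue;          // elements outside the fill mode keep their old value
          float acc = 0.0f;
          for (int k = 0; k < K; ++k) {
            acc += A[row * K + k] * B[col * K + k];   // (A * B^T)(row, col)
            acc += B[row * K + k] * A[col * K + k];   // (B * A^T)(row, col)
          }
          D[row * N + col] = alpha * acc + beta * C[row * N + col];
        }
      }
    }

    int main() {
      int N = 3, K = 2;
      std::vector<float> A(N * K, 1.0f), B(N * K, 2.0f);
      std::vector<float> C(N * N, 0.0f), D(N * N, -1.0f);
      ref_rank2k(N, K, 1.0f, A, B, 0.0f, C, D, Fill::kLower);
      // Lower-triangle entries hold K * (1*2 + 2*1) = 8; strictly-upper entries stay -1.
      std::printf("D(2,0) = %g, D(0,2) = %g\n", D[2 * N + 0], D[0 * N + 2]);
      return 0;
    }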
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ FillMode FillModeC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+void compute_rank2k( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ ComputeType initial_accum) { -+ compute_rank2k( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, -+ initial_accum); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ FillMode FillModeC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = cutlass::arch::OpMultiplyAdd -+> -+struct Rank2K; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for multiply-add -+template -+struct Rank2K { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_rank2k>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); -+ } -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_rank2k>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h -new file mode 100644 -index 0000000..519379c ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h -@@ -0,0 +1,318 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. 
Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for complex-valued Rank 2K update in host-side code. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+#include -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// Explicitly naming types needed by this template can be cumbersome, particularly for the -+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -+/// AccumulatorType(0) as the last function argument can be easier than naming all template -+/// arguments explicitly. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void Rank2KComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ TensorRef tensor_b, -+ ComplexTransform transform_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum, -+ FillMode fill_mode_c, -+ BlasMode blas_mode, -+ int batch_count = 1, -+ int64_t batch_stride_A = 0, -+ int64_t batch_stride_B = 0, -+ int64_t batch_stride_C = 0, -+ int64_t batch_stride_D = 0) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ // Note: batch is ignored. 
-+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ int const K = problem_size.k(); -+ -+ // Rank2K update operates on A=NxK, B=NxK, and C=NxN -+ assert(M==N); -+ -+ // Blocking necessary to speedup reference implementation -+ int const Mblock = 16; -+ int const Nblock = 16; -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { -+ -+ // Compute matrix product using blocks -+ for (int row_block = 0; row_block < M; row_block += Mblock) { -+ for (int col_block = 0; col_block < N; col_block += Nblock) { -+ -+ ComputeType accum[Mblock][Nblock]; -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N && -+ ( (fill_mode_c == FillMode::kLower && row >= col) || -+ (fill_mode_c == FillMode::kUpper && row <= col) ) -+ ) { -+ -+ // A x B^T (Symmetric) or A x B^H (Hermitian) -+ // complex conjugation on operandB (b_t) is function of blas3 computation -+ ElementA a = tensor_a.at(MatrixCoord(row, k_block)); -+ ElementB b_t = (blas_mode == BlasMode::kHermitian) ? -+ conj(tensor_b.at(MatrixCoord(col, k_block))) : -+ tensor_b.at(MatrixCoord(col, k_block)); -+ -+ ComputeType a_ik = ComputeType(a); -+ ComputeType b_jk = ComputeType(b_t); -+ -+ // complex conjugation is a function of operand layouts -+ if (transform_a == ComplexTransform::kConjugate) { -+ a_ik = conj(a_ik); -+ } -+ // complex conjugation is a function of operand layouts -+ if (transform_b == ComplexTransform::kConjugate) { -+ b_jk = conj(b_jk); -+ } -+ -+ accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ /* HER2K need two epilogues to handle complex alpha value */ -+ if ( blas_mode == BlasMode::kHermitian ) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N && -+ ((fill_mode_c == FillMode::kLower && row >= col) || -+ (fill_mode_c == FillMode::kUpper && row <= col)) -+ ) { -+ -+ ScalarType c = tensor_c.at(coord); -+ // The imaginary parts of the diagonal elements of -+ // a complex data type are assumed and set to zero -+ if (blas_mode == BlasMode::kHermitian) { -+ c = (row == col) ? real(c) : c; -+ } -+ -+ tensor_d.at(coord) = convert_op(alpha * -+ ScalarType(accum[i][j]) + -+ beta * c); -+ } -+ } -+ } -+ -+ /* Zeoring out accum for second HERK */ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N && -+ ( (fill_mode_c == FillMode::kLower && row >= col) || -+ (fill_mode_c == FillMode::kUpper && row <= col) ) -+ ) { -+ -+ // B x A^T (Symmetric) or B x A^H (Hermitian) -+ // complex conjugation on operandB (a_t) is function of blas3 computation -+ ElementB b = tensor_b.at(MatrixCoord(row, k_block)); -+ ElementA a_t = (blas_mode == BlasMode::kHermitian) ? 
-+ conj(tensor_a.at(MatrixCoord(col, k_block))): -+ tensor_a.at(MatrixCoord(col, k_block)); -+ -+ ComputeType b_ik = ComputeType(b); -+ ComputeType a_jk = ComputeType(a_t); -+ -+ // complex conjugation here is a function of operand layouts -+ if (transform_b == ComplexTransform::kConjugate) { -+ b_ik = conj(b_ik); -+ } -+ // complex conjugation here is a function of operand layouts -+ if (transform_a == ComplexTransform::kConjugate) { -+ a_jk = conj(a_jk); -+ } -+ -+ accum[i][j] = inner_product_op(b_ik, a_jk, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ ScalarType alpha_hermitian = (blas_mode == BlasMode::kHermitian) ? -+ conj(alpha) : alpha; -+ ScalarType beta_hermitian = (blas_mode == BlasMode::kHermitian) ? -+ 1 : beta; -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N && -+ ((fill_mode_c == FillMode::kLower && row >= col) || -+ (fill_mode_c == FillMode::kUpper && row <= col)) -+ ) { -+ -+ ScalarType d = (blas_mode == BlasMode::kHermitian) ? -+ tensor_d.at(coord) : tensor_c.at(coord); -+ -+ ScalarType tmp_d = convert_op( -+ alpha_hermitian * ScalarType(accum[i][j]) + -+ beta_hermitian * d); -+ -+ if (blas_mode == BlasMode::kHermitian && row == col ) { -+ tensor_d.at(coord) = real(tmp_d); -+ } else { -+ tensor_d.at(coord) = tmp_d; -+ } -+ } -+ } -+ } -+ -+ } // for (col_block) -+ } // for (row_block) -+ -+ tensor_a.add_pointer_offset(batch_stride_A); -+ tensor_b.add_pointer_offset(batch_stride_B); -+ tensor_c.add_pointer_offset(batch_stride_C); -+ tensor_d.add_pointer_offset(batch_stride_D); -+ -+ } // for (batch_idx) -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// This assumes the accumulator type is the same type as the scalars. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType -+> -+void Rank2KComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ TensorRef tensor_b, -+ ComplexTransform transform_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ FillMode fill_mode_c, -+ BlasMode blas_mode) { -+ -+ Rank2KComplex( -+ problem_size, alpha, -+ tensor_a, transform_a, -+ tensor_b, transform_b, -+ beta, tensor_c, tensor_d, -+ ScalarType(0), -+ fill_mode_c, -+ blas_mode); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h -new file mode 100644 -index 0000000..d5f3f2e ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h -@@ -0,0 +1,234 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for complex-valued Rank 2K update in host-side code. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+#include -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// Explicitly naming types needed by this template can be cumbersome, particularly for the -+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -+/// AccumulatorType(0) as the last function argument can be easier than naming all template -+/// arguments explicitly. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void Rank2KComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum, -+ FillMode fill_mode_c, -+ BlasMode blas_mode, -+ int batch_count = 1, -+ int64_t batch_stride_A = 0, -+ int64_t batch_stride_C = 0, -+ int64_t batch_stride_D = 0) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ // Note: batch is ignored. 
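  // What follows is a blocked reference for the rank-k update
  //
  //   D = alpha * A * op(A) + beta * C,  op(A) = A^T (symmetric) or A^H (Hermitian),
  //
  // visiting only the triangle of the N x N output selected by fill_mode_c.
  // Each selected element accumulates the sum over k of A(row, k) * op(A)(k, col);
  // for BlasMode::kHermitian the second operand is conjugated and the diagonal
  // of C and D is treated as real-valued. The 16 x 16 blocking below exists only
  // to keep this host-side reference reasonably fast.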
-+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ int const K = problem_size.k(); -+ -+ // Rank2K update operates on A=NxK, B=NxK, and C=NxN -+ assert(M==N); -+ -+ // Blocking necessary to speedup reference implementation -+ int const Mblock = 16; -+ int const Nblock = 16; -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { -+ -+ // Compute matrix product using blocks -+ for (int row_block = 0; row_block < M; row_block += Mblock) { -+ for (int col_block = 0; col_block < N; col_block += Nblock) { -+ -+ ComputeType accum[Mblock][Nblock]; -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N && -+ ( (fill_mode_c == FillMode::kLower && row >= col) || -+ (fill_mode_c == FillMode::kUpper && row <= col) ) -+ ) { -+ -+ // A x A^T (Symmetric) or A x A^H (Hermitian) -+ // complex conjugation on operandB (a_t) (function of blas3 computation) -+ ElementA a = tensor_a.at(MatrixCoord(row, k_block)); -+ ElementA a_t = (blas_mode == BlasMode::kHermitian) ? -+ conj(tensor_a.at(MatrixCoord(col, k_block))) : -+ tensor_a.at(MatrixCoord(col, k_block)); -+ -+ ComputeType a_ik = ComputeType(a); -+ ComputeType b_jk = ComputeType(a_t); -+ -+ // complex conjugation (function of input layouts) -+ if (transform_a == ComplexTransform::kConjugate) { -+ a_ik = conj(a_ik); -+ } -+ // complex conjugation (function of input layouts) -+ if (transform_a == ComplexTransform::kConjugate) { -+ b_jk = conj(b_jk); -+ } -+ -+ accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]); -+ -+ } -+ } -+ } -+ } -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N && -+ ((fill_mode_c == FillMode::kLower && row >= col) || -+ (fill_mode_c == FillMode::kUpper && row <= col)) -+ ) { -+ -+ ScalarType c = tensor_c.at(coord); -+ // The imaginary parts of the diagonal elements of -+ // a complex data type are assumed and set to zero -+ if (blas_mode == BlasMode::kHermitian) { -+ c = (row == col) ? real(c) : c; -+ } -+ -+ ScalarType tmp_d = convert_op( -+ alpha * ScalarType(accum[i][j]) + -+ beta * c); -+ -+ if (blas_mode == BlasMode::kHermitian && row == col ) { -+ tensor_d.at(coord) = real(tmp_d); -+ } else { -+ tensor_d.at(coord) = tmp_d; -+ } -+ } -+ } -+ } -+ -+ } // for (col_block) -+ } // for (row_block) -+ -+ tensor_a.add_pointer_offset(batch_stride_A); -+ tensor_c.add_pointer_offset(batch_stride_C); -+ tensor_d.add_pointer_offset(batch_stride_D); -+ -+ } // for (batch_idx) -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// This assumes the accumulator type is the same type as the scalars. 
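///
/// A minimal usage sketch (hypothetical names: an N x K operand `a` and N x N
/// tensors `c` and `d` are assumed to be cutlass::HostTensor objects holding
/// cutlass::complex<float> elements in a ColumnMajor layout):
///
///   cutlass::reference::host::RankKComplex(
///       cutlass::gemm::GemmCoord(n, n, k),   // M == N for a rank-k update
///       cutlass::complex<float>(1.0f),       // alpha
///       a.host_ref(), cutlass::ComplexTransform::kNone,
///       cutlass::complex<float>(0.0f),       // beta
///       c.host_ref(), d.host_ref(),
///       cutlass::FillMode::kLower,
///       cutlass::BlasMode::kHermitian);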
-+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType -+> -+void RankKComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ FillMode fill_mode_c, -+ BlasMode blas_mode) { -+ -+ Rank2KComplex( -+ problem_size, alpha, -+ tensor_a, transform_a, -+ beta, tensor_c, tensor_d, -+ ScalarType(0), -+ fill_mode_c, -+ blas_mode); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/symm.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/symm.h -new file mode 100644 -index 0000000..736107a ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/symm.h -@@ -0,0 +1,285 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for SYMM update in host-side code. 
-+ -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/arch/mma.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+template < -+ typename ElementA, -+ typename LayoutA, -+ SideMode SideModeA, -+ FillMode FillModeA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+void compute_symm( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ static_assert(SideModeA != SideMode::kInvalid -+ , "Side Mode can either be Left or Right."); -+ -+ static_assert( -+ FillModeA == FillMode::kLower || -+ FillModeA == FillMode::kUpper, -+ "Fill Mode can either be Lower or Upper."); -+ -+ using CompareOp_w_diag = typename TrMatrixCompareOp::Type; -+ using CompareOp_wo_diag = typename TrMatrixCompareOp::Type; -+ -+ // Note: batch is ignored. -+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ // Assuming correct k-dimension value is passed -+ int const K = problem_size.k(); -+ -+ // Blocking necessary to speedup reference implementation -+ int const Mblock = 16; -+ int const Nblock = 16; -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ CompareOp_w_diag compare_op_1; -+ CompareOp_wo_diag compare_op_2; -+ -+ for (int row_block = 0; row_block < M; row_block += Mblock) { -+ for (int col_block = 0; col_block < N; col_block += Nblock) { -+ -+ ComputeType accum[Mblock][Nblock]; -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N) { -+ ElementA a_1 = ElementA(); -+ ElementB b_1 = ElementB(); -+ ElementA a_2 = ElementA(); -+ ElementB b_2 = ElementB(); -+ -+ // A x B or B x A (with diagonal) -+ if (SideModeA == SideMode::kLeft) { -+ a_1 = (compare_op_1(row, k_block)) ? -+ (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(); -+ b_1 = tensor_b.at(MatrixCoord(k_block, col)); -+ } else if (SideModeA == SideMode::kRight) { -+ a_1 = tensor_b.at(MatrixCoord(row, k_block)); -+ b_1 = (compare_op_1(k_block, col)) ? -+ tensor_a.at(MatrixCoord(k_block, col)) : ElementA(); -+ } -+ -+ ComputeType compute_a_1(cast_if_scalar(a_1)); -+ ComputeType compute_b_1(cast_if_scalar(b_1)); -+ -+ accum[i][j] = inner_product_op(compute_a_1, compute_b_1, accum[i][j]); -+ -+ // A^T x B or B x A^T (without diagonal) -+ if (SideModeA == SideMode::kLeft) { -+ a_2 = (compare_op_2(k_block, row)) ? 
-+ (tensor_a.at(MatrixCoord(k_block, row))) : ElementA(); -+ b_2 = tensor_b.at(MatrixCoord(k_block, col)); -+ } else if (SideModeA == SideMode::kRight) { -+ a_2 = tensor_b.at(MatrixCoord(row, k_block)); -+ b_2 = (compare_op_2(col, k_block)) ? -+ tensor_a.at(MatrixCoord(col, k_block)) : ElementA(); -+ } -+ -+ ComputeType compute_a_2(cast_if_scalar(a_2)); -+ ComputeType compute_b_2(cast_if_scalar(b_2)); -+ -+ accum[i][j] = inner_product_op(compute_a_2, compute_b_2, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N) { -+ tensor_d.at(coord) = convert_op( -+ alpha * ScalarType(accum[i][j]) + -+ beta * ScalarType(tensor_c.at(coord))); -+ } -+ } -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general Symm update (tensors of rank=2) pointed to by TensorRef -+/// objects. -+template < -+ typename ElementA, -+ typename LayoutA, -+ SideMode SideModeA, -+ FillMode FillModeA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+void compute_symm( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ ComputeType initial_accum) { -+ compute_symm( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, -+ initial_accum); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ SideMode SideModeA, -+ FillMode FillModeA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = cutlass::arch::OpMultiplyAdd -+> -+struct Symm; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for multiply-add -+template -+struct Symm { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_symm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); -+ } -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_symm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h 
b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h -new file mode 100644 -index 0000000..aa46891 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h -@@ -0,0 +1,319 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for complex-valued SYMM update in host-side code. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+#include -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// Explicitly naming types needed by this template can be cumbersome, particularly for the -+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -+/// AccumulatorType(0) as the last function argument can be easier than naming all template -+/// arguments explicitly. 
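///
/// Concretely, for SideMode::kLeft this computes D = alpha * A * B + beta * C,
/// and D = alpha * B * A + beta * C for SideMode::kRight, where A is a square
/// symmetric or Hermitian matrix of which only the FillModeA triangle is stored.
/// The implementation reconstructs the implicit half of A with a second pass over
/// the stored triangle, transposed, skipping the diagonal, and conjugated when
/// BlasMode_ is BlasMode::kHermitian.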
-+template < -+ typename ElementA, -+ typename LayoutA, -+ SideMode SideModeA, -+ FillMode FillModeA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ BlasMode BlasMode_ = BlasMode::kSymmetric, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+void compute_symm_complex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum, -+ int batch_count = 1, -+ int64_t batch_stride_A = 0, -+ int64_t batch_stride_B = 0, -+ int64_t batch_stride_C = 0, -+ int64_t batch_stride_D = 0) { -+ -+ static SideMode const kSideModeA = SideModeA; -+ static FillMode const kFillModeA = FillModeA; -+ static BlasMode const kBlasMode = BlasMode_; -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ static_assert(kSideModeA != SideMode::kInvalid -+ , "Side Mode can either be Left or Right."); -+ -+ static_assert( -+ kFillModeA == FillMode::kLower || -+ kFillModeA == FillMode::kUpper, -+ "Fill Mode can either be Lower or Upper."); -+ -+ using CompareOp_w_diag = typename TrMatrixCompareOp::Type; -+ using CompareOp_wo_diag = typename TrMatrixCompareOp::Type; -+ -+ // Note: batch is ignored. -+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ // Assuming correct k-dimension value is passed -+ int const K = problem_size.k(); -+ -+ // Blocking necessary to speedup reference implementation -+ int const Mblock = 16; -+ int const Nblock = 16; -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ CompareOp_w_diag compare_op_1; -+ CompareOp_wo_diag compare_op_2; -+ -+ for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { -+ -+ // Compute matrix product using blocks -+ for (int row_block = 0; row_block < M; row_block += Mblock) { -+ for (int col_block = 0; col_block < N; col_block += Nblock) { -+ -+ ComputeType accum[Mblock][Nblock]; -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N) -+ { -+ ElementA a_1 = ElementA(); -+ ElementB b_1 = ElementB(); -+ ElementA a_2 = ElementA(); -+ ElementB b_2 = ElementB(); -+ -+ // A x B or B x A (with diagonal) -+ if (kSideModeA == SideMode::kLeft) { -+ a_1 = (compare_op_1(row, k_block)) ? -+ (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(); -+ b_1 = tensor_b.at(MatrixCoord(k_block, col)); -+ } else if (kSideModeA == SideMode::kRight) { -+ a_1 = tensor_b.at(MatrixCoord(row, k_block)); -+ b_1 = (compare_op_1(k_block, col)) ? 
-+ tensor_a.at(MatrixCoord(k_block, col)) : ElementA(); -+ } -+ ComputeType compute_a_1 = ComputeType(a_1); -+ ComputeType compute_b_1 = ComputeType(b_1); -+ -+ // The imaginary parts of the diagonal elements of -+ // a complex data type are assumed and set to zero -+ if (kBlasMode == BlasMode::kHermitian && kSideModeA == SideMode::kLeft && row == k_block) { -+ compute_a_1 = real(compute_a_1); -+ } else if (kBlasMode == BlasMode::kHermitian && kSideModeA == SideMode::kRight && k_block == col) { -+ compute_b_1 = real(compute_b_1); -+ } -+ -+ accum[i][j] = inner_product_op(compute_a_1, compute_b_1, accum[i][j]); -+ -+ // A^T x B or B x A^T (without diagonal) -+ if (kSideModeA == SideMode::kLeft) { -+ a_2 = (compare_op_2(k_block, row)) ? -+ (tensor_a.at(MatrixCoord(k_block, row))) : ElementA(); -+ b_2 = tensor_b.at(MatrixCoord(k_block, col)); -+ if (kBlasMode == BlasMode::kHermitian) -+ a_2 = conj(a_2); -+ } else if (kSideModeA == SideMode::kRight) { -+ a_2 = tensor_b.at(MatrixCoord(row, k_block)); -+ b_2 = (compare_op_2(col, k_block)) ? -+ tensor_a.at(MatrixCoord(col, k_block)) : ElementA(); -+ if (kBlasMode == BlasMode::kHermitian) -+ b_2 = conj(b_2); -+ } -+ -+ ComputeType compute_a_2 = ComputeType(a_2); -+ ComputeType compute_b_2 = ComputeType(b_2); -+ -+ accum[i][j] = inner_product_op(compute_a_2, compute_b_2, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N) { -+ -+ ScalarType c = tensor_c.at(coord); -+ -+ tensor_d.at(coord) = convert_op( -+ alpha * ScalarType(accum[i][j]) + -+ beta * c); -+ } -+ } -+ } -+ -+ } // for (col_block) -+ } // for (row_block) -+ -+ tensor_a.add_pointer_offset(batch_stride_A); -+ tensor_b.add_pointer_offset(batch_stride_B); -+ tensor_c.add_pointer_offset(batch_stride_C); -+ tensor_d.add_pointer_offset(batch_stride_D); -+ -+ } // for (batch_idx) -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ SideMode SideModeA, -+ FillMode FillModeA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ BlasMode BlasMode_ = cutlass::BlasMode::kSymmetric, -+ typename InnerProductOp = cutlass::arch::OpMultiplyAddComplex -+> -+struct SymmComplex; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for multiply-add -+template -+struct SymmComplex { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_symm_complex>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for gaussian multiply-add -+template -+struct SymmComplex { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum = 
ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_symm_complex>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h -new file mode 100644 -index 0000000..f9a362e ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h -@@ -0,0 +1,305 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines host-side elementwise operations on TensorView. 
-+*/ -+ -+#pragma once -+ -+// Standard Library includes -+#include -+ -+// Cutlass includes -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/tensor_view_planar_complex.h" -+ -+#include "cutlass/util/distribution.h" -+//#include "cutlass/util/type_traits.h" -+#include "tensor_foreach.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorEqualsFunc { -+ -+ // -+ // Data members -+ // -+ -+ TensorView lhs; -+ TensorView rhs; -+ bool result; -+ -+ /// Ctor -+ TensorEqualsFunc(): result(true) { } -+ -+ /// Ctor -+ TensorEqualsFunc( -+ TensorView const &lhs_, -+ TensorView const &rhs_ -+ ) : -+ lhs(lhs_), rhs(rhs_), result(true) { } -+ -+ /// Visits a coordinate -+ void operator()(Coord const &coord) { -+ -+ Element lhs_ = lhs.at(coord); -+ Element rhs_ = rhs.at(coord); -+ -+ if (lhs_ != rhs_) { -+ result = false; -+ } -+ } -+ -+ /// Returns true if equal -+ operator bool() const { -+ return result; -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns true if two tensor views are equal. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+bool TensorEquals( -+ TensorView const &lhs, -+ TensorView const &rhs) { -+ -+ // Extents must be identical -+ if (lhs.extent() != rhs.extent()) { -+ return false; -+ } -+ -+ detail::TensorEqualsFunc func(lhs, rhs); -+ TensorForEach( -+ lhs.extent(), -+ func -+ ); -+ -+ return bool(func); -+} -+ -+/// Returns true if two tensor views are equal. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+bool TensorEquals( -+ TensorViewPlanarComplex const &lhs, -+ TensorViewPlanarComplex const &rhs) { -+ -+ // Extents must be identical -+ if (lhs.extent() != rhs.extent()) { -+ return false; -+ } -+ -+ detail::TensorEqualsFunc real_func( -+ {lhs.data(), lhs.layout(), lhs.extent()}, -+ {rhs.data(), rhs.layout(), rhs.extent()} -+ ); -+ -+ TensorForEach( -+ lhs.extent(), -+ real_func -+ ); -+ -+ if (!bool(real_func)) { -+ return false; -+ } -+ -+ detail::TensorEqualsFunc imag_func( -+ {lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()}, -+ {rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()} -+ ); -+ -+ TensorForEach( -+ lhs.extent(), -+ imag_func -+ ); -+ -+ return bool(imag_func); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns true if two tensor views are NOT equal. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+bool TensorNotEquals( -+ TensorView const &lhs, -+ TensorView const &rhs) { -+ -+ // Extents must be identical -+ if (lhs.extent() != rhs.extent()) { -+ return true; -+ } -+ -+ detail::TensorEqualsFunc func(lhs, rhs); -+ TensorForEach( -+ lhs.extent(), -+ func -+ ); -+ -+ return !bool(func); -+} -+ -+/// Returns true if two tensor views are equal. 
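///
/// Host-side tests typically use these helpers to validate a kernel against the
/// reference implementations, e.g. (hypothetical HostTensor objects of identical
/// element type and layout holding the computed and reference results):
///
///   bool passed = cutlass::reference::host::TensorEquals(
///       computed_d.host_view(), reference_d.host_view());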
-+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+bool TensorNotEquals( -+ TensorViewPlanarComplex const &lhs, -+ TensorViewPlanarComplex const &rhs) { -+ -+ return !TensorEquals(lhs, rhs); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorContainsFunc { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element value; -+ bool contains; -+ Coord location; -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ TensorContainsFunc(): contains(false) { } -+ -+ /// Ctor -+ TensorContainsFunc( -+ TensorView const &view_, -+ Element value_ -+ ) : -+ view(view_), value(value_), contains(false) { } -+ -+ /// Visits a coordinate -+ void operator()(Coord const &coord) { -+ -+ if (view.at(coord) == value) { -+ if (!contains) { -+ location = coord; -+ } -+ contains = true; -+ } -+ } -+ -+ /// Returns true if equal -+ operator bool() const { -+ return contains; -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns true if a value is present in a tensor -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+bool TensorContains( -+ TensorView const & view, -+ Element value) { -+ -+ detail::TensorContainsFunc func( -+ view, -+ value -+ ); -+ -+ TensorForEach( -+ view.extent(), -+ func -+ ); -+ -+ return bool(func); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns a pair containing a boolean of whether a value exists in a tensor and the location of -+/// of the first occurrence. If the value is not contained in the tensor, the second element of the -+/// pair is undefined. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+std::pair > TensorFind( -+ TensorView const & view, -+ Element value) { -+ -+ detail::TensorContainsFunc func( -+ view, -+ value -+ ); -+ -+ TensorForEach( -+ view.extent(), -+ func -+ ); -+ -+ return std::make_pair(bool(func), func.location); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp -new file mode 100644 -index 0000000..a4a5b4e ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp -@@ -0,0 +1,101 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. 
-+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Provides several functions for filling tensors with data. -+*/ -+ -+#pragma once -+ -+// Standard Library includes -+#include -+#include -+#include -+ -+// Cute includes -+#include "cute/tensor.hpp" -+ -+// Cutlass includes -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/quaternion.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns true if two tensor views are equal. -+template < -+ typename TensorL, -+ typename TensorR -+> -+bool TensorEquals( -+ TensorL lhs, -+ TensorR rhs) { -+ -+ // Extents must be identical -+ if (cute::size(lhs) != cute::size(rhs)) { -+ return false; -+ } -+ -+ for (int64_t idx = 0; idx < cute::size(lhs); ++idx) { -+ if (lhs(idx) != rhs(idx)) { -+ return false; -+ } -+ } -+ -+ return true; -+} -+ -+/// Returns true if two tensor views are NOT equal. -+template < -+ typename TensorL, -+ typename TensorR -+> -+bool TensorNotEquals( -+ TensorL lhs, -+ TensorR rhs) { -+ -+ return TensorEquals(lhs, rhs); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h -new file mode 100644 -index 0000000..053511c ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h -@@ -0,0 +1,256 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines host-side elementwise operations on TensorView. -+*/ -+ -+#pragma once -+ -+// Standard Library includes -+#include -+ -+// Cutlass includes -+#include "cutlass/cutlass.h" -+#include "tensor_foreach.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Helper to convert between types -+template < -+ typename DstElement, -+ typename SrcElement -+> -+struct TrivialConvert { -+ -+ TrivialConvert() { } -+ -+ DstElement operator()(SrcElement src) const { -+ return DstElement(src); -+ } -+}; -+ -+/// Helper to conditionally copy between tensor views. -+template < -+ typename DstElement, -+ typename DstLayout, -+ typename SrcElement, -+ typename SrcLayout, -+ typename F -+> -+struct TensorCopyIf { -+ -+ using DstTensorView = TensorView; -+ using SrcTensorView = TensorView; -+ -+ // -+ // Data members -+ // -+ -+ DstTensorView dst; -+ SrcTensorView src; -+ F convert; -+ -+ // -+ // Methods -+ // -+ -+ TensorCopyIf() { } -+ -+ TensorCopyIf( -+ DstTensorView const &dst_, -+ SrcTensorView const &src_, -+ F const &convert_): dst(dst_), src(src_), convert(convert_) {} -+ -+ /// Copies based on destination and source bounds -+ void operator()(Coord const &coord) { -+ if (dst.contains(coord) && src.contains(coord)) { -+ dst.at(coord) = convert(src.at(coord)); -+ } -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Copies elements from one tensor view into another, satisfying bounds of each tensor. 
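///
/// A sketch of a converting copy through the transform overload declared below
/// (hypothetical tensors: `src` holds float data, `dst` holds cutlass::half_t
/// data, and both have identical extents):
///
///   cutlass::NumericConverter<cutlass::half_t, float> convert;
///   cutlass::reference::host::TensorCopy(dst.host_view(), src.host_view(), convert);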
-+template < -+ typename DstElement, /// Destination tensor's element type -+ typename DstLayout, /// Destination tensor's layout -+ typename SrcElement, /// Source tensor's element type -+ typename SrcLayout, /// Source tensor's layout -+ typename F /// Transformation functor -+> -+void TensorCopy( -+ TensorView dst, -+ TensorView src, -+ F const &transform) { -+ -+ using CopyIf = detail::TensorCopyIf< -+ DstElement, -+ DstLayout, -+ SrcElement, -+ SrcLayout, -+ F>; -+ -+ CopyIf copy_if(dst, src, transform); -+ -+ TensorForEach(dst.extent(), copy_if); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Copies elements from a TensorRef into a TensorView. Assumes source tensor has sufficient extent -+/// to avoid out of bounds accesses. -+template < -+ typename DstElement, /// Destination tensor's element type -+ typename DstLayout, /// Destination tensor's layout -+ typename SrcElement, /// Source tensor's element type -+ typename SrcLayout, /// Source tensor's layout -+ typename F /// Transformation functor -+> -+void TensorCopy( -+ TensorView dst, -+ TensorRef src, -+ F const &transform) { -+ -+ using CopyIf = detail::TensorCopyIf< -+ DstElement, -+ DstLayout, -+ SrcElement, -+ SrcLayout, -+ F>; -+ -+ TensorView src_view(src, dst.extent()); -+ -+ CopyIf copy_if(dst, src_view, transform); -+ -+ TensorForEach(dst.extent(), copy_if); -+} -+ -+/// Copies elements from a TensorRef into a TensorView. Assumes source tensor has sufficient extent -+/// to avoid out of bounds accesses. -+template < -+ typename DstElement, /// Destination tensor's element type -+ typename DstLayout, /// Destination tensor's layout -+ typename SrcElement, /// Source tensor's element type -+ typename SrcLayout, /// Source tensor's layout -+ typename F /// Transformation functor -+> -+void TensorCopy( -+ TensorRef dst, -+ TensorView src, -+ F const &transform) { -+ -+ using CopyIf = detail::TensorCopyIf< -+ DstElement, -+ DstLayout, -+ SrcElement, -+ SrcLayout, -+ F>; -+ -+ TensorView dst_view(dst, src.extent()); -+ -+ CopyIf copy_if(dst_view, src, transform); -+ -+ TensorForEach(src.extent(), copy_if); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds -+/// if SrcElement can be converted to DstElement. -+template < -+ typename DstElement, /// Destination tensor's element type -+ typename DstLayout, /// Destination tensor's layout -+ typename SrcElement, /// Source tensor's element type -+ typename SrcLayout /// Source tensor's layout -+> -+void TensorCopy( -+ TensorView dst, -+ TensorView src) { -+ -+ detail::TrivialConvert convert; -+ -+ TensorCopy(dst, src, convert); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds -+/// if SrcElement can be converted to DstElement. 
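///
/// The overloads that take no transformation functor, such as the one below,
/// forward to the versions above with detail::TrivialConvert, so each element is
/// copied as DstElement(src_element).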
-+template < -+ typename DstElement, /// Destination tensor's element type -+ typename DstLayout, /// Destination tensor's layout -+ typename SrcElement, /// Source tensor's element type -+ typename SrcLayout, /// Source tensor's layout -+ typename F /// Transformation functor -+> -+void TensorCopy( -+ TensorView dst, -+ TensorRef src) { -+ -+ detail::TrivialConvert convert; -+ -+ TensorCopy(dst, src, convert); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds -+/// if SrcElement can be converted to DstElement. -+template < -+ typename DstElement, /// Destination tensor's element type -+ typename DstLayout, /// Destination tensor's layout -+ typename SrcElement, /// Source tensor's element type -+ typename SrcLayout /// Source tensor's layout -+> -+void TensorCopy( -+ TensorRef dst, -+ TensorView src) { -+ -+ detail::TrivialConvert convert; -+ -+ TensorCopy(dst, src, convert); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h -new file mode 100644 -index 0000000..72f5f24 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h -@@ -0,0 +1,341 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines host-side elementwise operations on TensorView. -+*/ -+ -+#pragma once -+ -+// Cutlass includes -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+ -+#include "tensor_foreach.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper to apply a binary operator in place -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementD, -+ typename LayoutD, -+ typename BinaryFunc> -+struct TensorFuncBinaryOp { -+ -+ // -+ // Data members -+ // -+ -+ /// View of left-hand-side tensor -+ TensorView view_d; -+ TensorRef view_a; -+ TensorRef view_b; -+ BinaryFunc func; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ TensorFuncBinaryOp() { } -+ -+ /// Constructor -+ TensorFuncBinaryOp( -+ TensorView const & view_d_, -+ TensorRef const & view_a_, -+ TensorRef const & view_b_, -+ BinaryFunc func = BinaryFunc() -+ ): -+ view_d(view_d_), view_a(view_a_), view_b(view_b_), func(func) { } -+ -+ /// Equality check -+ void operator()(Coord const &coord) const { -+ view_d.at(coord) = func( -+ ElementD(view_a.at(coord)), -+ ElementD(view_b.at(coord)) -+ ); -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Adds two tensors and stores in the destination tensor: d = a + b -+template < -+ typename ElementD, -+ typename LayoutD, -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB -+> -+void TensorAdd( -+ TensorView d, ///< destination tensor view -+ TensorRef a, ///< A tensor reference -+ TensorRef b ///< B tensor reference -+) { -+ -+ detail::TensorFuncBinaryOp< -+ ElementD, -+ LayoutD, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ cutlass::plus -+ > func(d, a, b); -+ -+ TensorForEach( -+ d.extent(), -+ func); -+} -+ -+/// Adds a tensor in place: d = d .+ a -+template < -+ typename ElementD, -+ typename LayoutD, -+ typename ElementA, -+ typename LayoutA -+> -+void TensorAdd( -+ TensorView d, ///< destination tensor view -+ TensorRef a ///< A tensor reference -+) { -+ TensorAdd(d, d, a); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Subtracts two tensors and stores in the destination tensor: d = a - b -+template < -+ typename ElementD, -+ typename LayoutD, -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB -+> -+void TensorSub( -+ TensorView d, ///< destination tensor view -+ TensorRef a, ///< A tensor reference -+ TensorRef b ///< B tensor reference -+ ) { -+ -+ detail::TensorFuncBinaryOp< -+ ElementD, -+ LayoutD, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ cutlass::minus -+ > func(d, a, b); -+ -+ TensorForEach( -+ d.extent(), -+ func); -+} -+ -+/// Subtracts two tensors in place: d = d .- a -+template < -+ typename ElementD, -+ typename LayoutD, -+ typename ElementA, -+ typename LayoutA, -+ 
typename ElementB, -+ typename LayoutB -+> -+void TensorSub( -+ TensorView d, ///< destination tensor view -+ TensorRef a ///< A tensor reference -+ ) { -+ -+ TensorSub(d, d, a); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Multiplies two tensors and stores in the destination tensor: d = a .* b -+template < -+ typename ElementD, -+ typename LayoutD, -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB -+> -+void TensorMul( -+ TensorView d, ///< destination tensor view -+ TensorRef a, ///< A tensor reference -+ TensorRef b ///< B tensor reference -+) { -+ -+ detail::TensorFuncBinaryOp< -+ ElementD, -+ LayoutD, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ cutlass::multiplies -+ > func(d, a, b); -+ -+ TensorForEach( -+ d.extent(), -+ func); -+} -+ -+/// Multiplies tensors in place: d = d .* a -+template < -+ typename ElementD, -+ typename LayoutD, -+ typename ElementA, -+ typename LayoutA -+> -+void TensorMul( -+ TensorView d, ///< destination tensor view -+ TensorRef a ///< A tensor reference -+) { -+ TensorMul(d, d, a); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Divides two tensors and stores in the destination tensor: d = a ./ b -+template < -+ typename ElementD, -+ typename LayoutD, -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB -+> -+void TensorDiv( -+ TensorView d, ///< destination tensor view -+ TensorRef a, ///< A tensor reference -+ TensorRef b ///< B tensor reference -+) { -+ -+ detail::TensorFuncBinaryOp< -+ ElementD, -+ LayoutD, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ cutlass::divides -+ > func(d, a, b); -+ -+ TensorForEach( -+ d.extent(), -+ func); -+} -+ -+/// Divides tensors in place: d = d ./ a -+template < -+ typename ElementD, -+ typename LayoutD, -+ typename ElementA, -+ typename LayoutA -+> -+void TensorDiv( -+ TensorView d, ///< destination tensor view -+ TensorRef a ///< A tensor reference -+) { -+ TensorDiv(d, d, a); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Divides two tensors and stores in the destination tensor: d = a ./ b -+template < -+ typename ElementD, -+ typename LayoutD, -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB -+> -+void TensorModulus( -+ TensorView d, ///< destination tensor view -+ TensorRef a, ///< A tensor reference -+ TensorRef b ///< B tensor reference -+) { -+ -+ detail::TensorFuncBinaryOp< -+ ElementD, -+ LayoutD, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ cutlass::divides -+ > func(d, a, b); -+ -+ TensorForEach( -+ d.extent(), -+ func); -+} -+ -+/// Divides tensors in place: d = d ./ a -+template < -+ typename ElementD, -+ typename LayoutD, -+ typename ElementA, -+ typename LayoutA -+> -+void TensorModulus( -+ TensorView d, ///< destination tensor view -+ TensorRef a ///< A tensor reference -+) { -+ TensorDiv(d, d, a); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h -new file mode 100644 -index 0000000..a8b938d ---- /dev/null -+++ 
b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h -@@ -0,0 +1,1468 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Provides several functions for filling tensors with data. 
-+*/ -+ -+#pragma once -+ -+// Standard Library includes -+#include -+#include -+#include -+ -+// Cutlass includes -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/quaternion.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/subbyte_reference.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/tensor_view_planar_complex.h" -+#include "cutlass/blas3.h" -+ -+#include "cutlass/util/distribution.h" -+#include "tensor_foreach.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillFunc { -+ -+ using TensorView = TensorView; -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element value; -+ -+ // -+ // Methods -+ // -+ -+ TensorFillFunc( -+ TensorView const &view_ = TensorView(), -+ Element value_ = Element(0) -+ ): view(view_), value(value_) { } -+ -+ void operator()(Coord const & coord) const { -+ view.at(coord) = value; -+ } -+}; -+ -+/// Returns a pair of values of the Gaussian distribution generated by the Box Muller method -+struct BoxMullerFunc { -+ -+ BoxMullerFunc() {} -+ -+ void operator()( -+ double* rnd, ///< Size-2 vector to be filled with random values -+ double mean = 0, ///< Mean of the Gaussian distribution -+ double stddev = 1, ///< Standard deviation of the Gaussian distribution -+ double pi = std::acos(-1)) const { -+ -+ double u1 = double(std::rand()) / double(RAND_MAX); -+ double u2 = double(std::rand()) / double(RAND_MAX); -+ rnd[0] = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); -+ rnd[1] = std::sqrt(-2 * std::log(u1)) * std::sin(2 * pi * u2); -+ rnd[0] = mean + stddev * rnd[0]; -+ rnd[1] = mean + stddev * rnd[1]; -+ } -+}; -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with a uniform value -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFill( -+ TensorView dst, ///< destination tensor -+ Element val = Element(0)) { ///< value to uniformly fill it with -+ -+ detail::TensorFillFunc func(dst, val); -+ -+ TensorForEach( -+ dst.extent(), -+ func -+ ); -+} -+ -+/// Fills a tensor with a uniform value -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFill( -+ TensorViewPlanarComplex dst, ///< destination tensor -+ cutlass::complex val = cutlass::complex(0)) { ///< value to uniformly fill it with -+ -+ TensorFill(dst.view_real(), val.real()); -+ TensorFill(dst.view_imag(), val.imag()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template -+struct RandomGaussianFunc { -+ -+ uint64_t seed; -+ double mean; -+ double stddev; -+ int int_scale; -+ double pi; -+ -+ // -+ // Methods -+ // -+ RandomGaussianFunc( -+ uint64_t seed_ = 0, -+ double mean_ = 0, -+ double stddev_ = 1, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), 
mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) { -+ std::srand((unsigned)seed); -+ } -+ -+ /// Compute random value and update RNG state -+ Element operator()() const { -+ -+ // Box-Muller transform to generate random numbers with Normal distribution -+ double u1 = double(std::rand()) / double(RAND_MAX); -+ double u2 = double(std::rand()) / double(RAND_MAX); -+ -+ // Compute Gaussian random value -+ double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); -+ rnd = mean + stddev * rnd; -+ -+ // Scale and convert final result -+ Element result; -+ -+ if (int_scale >= 0) { -+ rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); -+ result = static_cast(rnd); -+ } -+ else { -+ result = static_cast(rnd); -+ } -+ -+ return result; -+ } -+}; -+ -+/// Partial specialization for initializing a complex value. -+template -+struct RandomGaussianFunc > { -+ -+ uint64_t seed; -+ double mean; -+ double stddev; -+ int int_scale; -+ double pi; -+ -+ // -+ // Methods -+ // -+ RandomGaussianFunc( -+ uint64_t seed_ = 0, -+ double mean_ = 0, -+ double stddev_ = 1, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) { -+ std::srand((unsigned)seed); -+ } -+ -+ /// Compute random value and update RNG state -+ complex operator()() const { -+ -+ Element reals[2]; -+ -+ double rnd[2]; -+ detail::BoxMullerFunc func; -+ func(rnd, mean, stddev, pi); -+ -+ if (int_scale >= 0) { -+ rnd[0] = double(int(rnd[0] * double(1 << int_scale))); -+ rnd[1] = double(int(rnd[1] * double(1 << int_scale))); -+ reals[0] = from_real(rnd[0] / double(1 << int_scale)); -+ reals[1] = from_real(rnd[1] / double(1 << int_scale)); -+ } else { -+ reals[0] = from_real(rnd[0]); -+ reals[1] = from_real(rnd[1]); -+ } -+ -+ return complex(reals[0], reals[1]); -+ } -+}; -+ -+/// Partial specialization for initializing a complex value. 
-+template -+struct RandomGaussianFunc > { -+ -+ uint64_t seed; -+ double mean; -+ double stddev; -+ int int_scale; -+ double pi; -+ -+ // -+ // Methods -+ // -+ RandomGaussianFunc( -+ uint64_t seed_ = 0, -+ double mean_ = 0, -+ double stddev_ = 1, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) { -+ std::srand((unsigned)seed); -+ } -+ -+ /// Compute random value and update RNG state -+ Quaternion operator()() const { -+ -+ Element reals[4]; -+ -+ double rnd1[2]; -+ double rnd2[2]; -+ detail::BoxMullerFunc func; -+ func(rnd1, mean, stddev, pi); -+ func(rnd2, mean, stddev, pi); -+ -+ if (int_scale >= 0) { -+ rnd1[0] = double(int(rnd1[0] * double(1 << int_scale))); -+ rnd1[1] = double(int(rnd1[1] * double(1 << int_scale))); -+ rnd2[0] = double(int(rnd2[0] * double(1 << int_scale))); -+ rnd2[1] = double(int(rnd2[1] * double(1 << int_scale))); -+ -+ reals[0] = from_real(rnd1[0] / double(1 << int_scale)); -+ reals[1] = from_real(rnd1[1] / double(1 << int_scale)); -+ reals[2] = from_real(rnd2[0] / double(1 << int_scale)); -+ reals[3] = from_real(rnd2[1] / double(1 << int_scale)); -+ } else { -+ reals[0] = from_real(rnd1[0]); -+ reals[1] = from_real(rnd1[1]); -+ reals[2] = from_real(rnd2[0]); -+ reals[3] = from_real(rnd2[1]); -+ } -+ -+ return Quaternion(reals[0], reals[1], reals[2], reals[3]); -+ } -+}; -+ -+/// Computes a random Gaussian distribution -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillGaussianFunc { -+ -+ using TensorView = TensorView; -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ RandomGaussianFunc func; -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ TensorFillGaussianFunc( -+ TensorView view_ = TensorView(), -+ RandomGaussianFunc func_ = RandomGaussianFunc() -+ ): -+ view(view_), func(func_) { -+ -+ } -+ -+ /// Compute random value and update RNG state -+ void operator()(Coord const &coord) const { -+ view.at(coord) = func(); -+ } -+}; -+ -+/// Computes a random Gaussian distribution for a rank-2 tensor -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillSymmetricGaussianFunc { -+ -+ using TensorView = TensorView; -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ RandomGaussianFunc func; -+ cutlass::FillMode fill_mode; -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ TensorFillSymmetricGaussianFunc( -+ TensorView view_ = TensorView(), -+ RandomGaussianFunc func_ = RandomGaussianFunc(), -+ cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid -+ ): -+ view(view_), func(func_), fill_mode(fill_mode_) { -+ -+ } -+ -+ /// Compute random value and update RNG state -+ void operator()(Coord const &coord) const { -+ // Fill half of matrix based on FillMode -+ if (Layout::kRank == 2 && -+ fill_mode == cutlass::FillMode::kLower && -+ coord[0] >= coord[1]) { -+ view.at(coord) = func(); -+ } else if (Layout::kRank == 2 && -+ fill_mode == cutlass::FillMode::kUpper && -+ coord[0] <= coord[1]) { -+ view.at(coord) = func(); -+ } -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a Gaussian distribution. 
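For reference, the Gaussian functors in the hunk above all derive their samples from the Box-Muller transform: two independent uniform draws are mapped to two independent normal samples, which the complex and quaternion specializations then consume in pairs. Below is a minimal standalone sketch of that step; it is not part of the patch, and the helper name box_muller_sample is hypothetical.

// Standalone sketch (not part of the patch); box_muller_sample is a hypothetical name.
#include <cmath>
#include <cstdlib>
#include <utility>

// Returns two independent samples from N(mean, stddev^2) built from two uniform
// draws, mirroring the Box-Muller step used by the Gaussian functors above.
std::pair<double, double> box_muller_sample(double mean = 0.0, double stddev = 1.0) {
  const double pi = std::acos(-1.0);
  double u1 = double(std::rand()) / double(RAND_MAX);
  double u2 = double(std::rand()) / double(RAND_MAX);
  double r  = std::sqrt(-2.0 * std::log(u1));
  return { mean + stddev * r * std::cos(2.0 * pi * u2),
           mean + stddev * r * std::sin(2.0 * pi * u2) };
}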
-+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillRandomGaussian( -+ TensorView dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ double mean = 0, ///< Gaussian distribution's mean -+ double stddev = 1, ///< Gaussian distribution's standard deviation -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); -+ -+ detail::TensorFillGaussianFunc func( -+ dst, -+ random_func -+ ); -+ -+ TensorForEach( -+ dst.extent(), -+ func -+ ); -+} -+ -+/// Fills a tensor with random values with a Gaussian distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillRandomGaussian( -+ TensorViewPlanarComplex dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ double mean = 0, ///< Gaussian distribution's mean -+ double stddev = 1, ///< Gaussian distribution's standard deviation -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ TensorFillRandomGaussian(dst.view_real(), seed, mean, stddev, bits); -+ TensorFillRandomGaussian(dst.view_imag(), ~seed, mean, stddev, bits); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/// Fills the upper or lower part of a symmetric rank-2 tensor with random values of a Gaussian distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillSymmetricRandomGaussian( -+ TensorView dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices -+ double mean = 0, ///< Gaussian distribution's mean -+ double stddev = 1, ///< Gaussian distribution's standard deviation -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); -+ -+ detail::TensorFillSymmetricGaussianFunc func( -+ dst, -+ random_func, -+ fill_mode -+ ); -+ -+ TensorForEach( -+ dst.extent(), -+ func -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values of a Gaussian distribution. -+template < -+ typename Element ///< Element type -+> -+void BlockFillRandomGaussian( -+ Element *ptr, ///< destination buffer -+ size_t capacity, ///< number of elements -+ uint64_t seed, ///< seed for RNG -+ double mean = 0, ///< Gaussian distribution's mean -+ double stddev = 1, ///< Gaussian distribution's standard deviation -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. 
-+ -+ -+ detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); -+ -+ for (size_t i = 0; i < capacity; ++i) { -+ ReferenceFactory::get(ptr, i) = random_func(); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template -+struct RandomUniformFunc { -+ -+ using Real = typename RealType::Type; -+ -+ uint64_t seed; -+ double range; -+ double min; -+ int int_scale; -+ -+ // -+ // Methods -+ // -+ -+ RandomUniformFunc( -+ uint64_t seed_ = 0, -+ double max = 1, -+ double min_ = 0, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { -+ std::srand((unsigned)seed); -+ } -+ -+ -+ /// Compute random value and update RNG state -+ Element operator()() const { -+ -+ double rnd = double(std::rand()) / double(RAND_MAX); -+ -+ rnd = min + range * rnd; -+ -+ // Random values are cast to integer after scaling by a power of two to facilitate error -+ // testing -+ Element result; -+ -+ if (int_scale >= 0) { -+ rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); -+ result = static_cast(Real(rnd)); -+ } -+ else { -+ result = static_cast(Real(rnd)); -+ } -+ -+ return result; -+ } -+}; -+ -+/// Partial specialization for initializing a complex value. -+template -+struct RandomUniformFunc > { -+ -+ using Real = typename RealType::Type; -+ -+ uint64_t seed; -+ double range; -+ double min; -+ int int_scale; -+ -+ // -+ // Methods -+ // -+ -+ RandomUniformFunc( -+ uint64_t seed_ = 0, -+ double max = 1, -+ double min_ = 0, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { -+ std::srand((unsigned)seed); -+ } -+ -+ -+ /// Compute random value and update RNG state -+ complex operator()() const { -+ -+ Element reals[2]; -+ -+ for (int i = 0; i < 2; ++i) { -+ double rnd = double(std::rand()) / double(RAND_MAX); -+ -+ rnd = min + range * rnd; -+ -+ // Random values are cast to integer after scaling by a power of two to facilitate error -+ // testing -+ -+ if (int_scale >= 0) { -+ rnd = double(int(rnd * double(1 << int_scale))); -+ reals[i] = from_real(Real(rnd / double(1 << int_scale))); -+ } -+ else { -+ reals[i] = from_real(Real(rnd)); -+ } -+ } -+ -+ return complex(reals[0], reals[1]); -+ } -+}; -+ -+/// Partial specialization for initializing a Quaternion value. 
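The int_scale / bits parameter used by the uniform and Gaussian functors above snaps each sample onto a power-of-two grid (scale by 2^bits, truncate, rescale) so that reference and device results can be compared exactly during error testing. A small sketch of that quantization step follows; it is an illustration only, and quantize_pow2 is a hypothetical helper.

// Standalone sketch (not part of the patch); quantize_pow2 is a hypothetical name.
#include <cstdint>

// Snap a value onto a grid of spacing 2^-bits by truncating toward zero, as the
// random fill functors above do when int_scale >= 0; bits < 0 leaves it untouched.
double quantize_pow2(double x, int bits) {
  if (bits < 0) return x;
  double scale = double(1 << bits);
  return double(int64_t(x * scale)) / scale;
}
// Example: quantize_pow2(0.337, 2) == 0.25; quantize_pow2(0.337, -1) == 0.337.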
-+template -+struct RandomUniformFunc > { -+ -+ using Real = typename RealType::Type; -+ -+ uint64_t seed; -+ double range; -+ double min; -+ int int_scale; -+ -+ // -+ // Methods -+ // -+ -+ RandomUniformFunc( -+ uint64_t seed_ = 0, -+ double max = 1, -+ double min_ = 0, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { -+ std::srand((unsigned)seed); -+ } -+ -+ -+ /// Compute random value and update RNG state -+ Quaternion operator()() const { -+ -+ Element reals[4]; -+ -+ for (int i = 0; i < 4; ++i) { -+ double rnd = double(std::rand()) / double(RAND_MAX); -+ -+ rnd = min + range * rnd; -+ -+ // Random values are cast to integer after scaling by a power of two to facilitate error -+ // testing -+ -+ if (int_scale >= 0) { -+ rnd = double(int(rnd * double(1 << int_scale))); -+ reals[i] = from_real(Real(rnd / double(1 << int_scale))); -+ } -+ else { -+ reals[i] = from_real(Real(rnd)); -+ } -+ } -+ -+ return make_Quaternion(reals[0], reals[1], reals[2], reals[3]); -+ } -+}; -+ -+/// Computes a random uniform distribution -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillRandomUniformFunc { -+ -+ using TensorView = TensorView; -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ RandomUniformFunc func; -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of uniform RNG functor. -+ TensorFillRandomUniformFunc( -+ TensorView view_ = TensorView(), -+ RandomUniformFunc func_ = RandomUniformFunc() -+ ): -+ view(view_), func(func_) { -+ -+ } -+ -+ /// Compute random value and update RNG state -+ void operator()(Coord const &coord) const { -+ -+ view.at(coord) = func(); -+ } -+}; -+ -+/// Fills the upper or lower part of a symmetric rank-2 tensor with random values of a uniform distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillSymmetricRandomUniformFunc { -+ -+ using TensorView = TensorView; -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ RandomUniformFunc func; -+ cutlass::FillMode fill_mode; -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of uniform RNG functor. -+ TensorFillSymmetricRandomUniformFunc( -+ TensorView view_ = TensorView(), -+ RandomUniformFunc func_ = RandomUniformFunc(), -+ cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid -+ ): -+ view(view_), func(func_), fill_mode(fill_mode_) { -+ -+ } -+ -+ /// Compute random value and update RNG state -+ void operator()(Coord const &coord) const { -+ // Fill half of matrix based on FillMode -+ if (Layout::kRank == 2 && -+ fill_mode == cutlass::FillMode::kLower && -+ coord[0] >= coord[1]) { -+ view.at(coord) = func(); -+ } else if (Layout::kRank == 2 && -+ fill_mode == cutlass::FillMode::kUpper && -+ coord[0] <= coord[1]) { -+ view.at(coord) = func(); -+ } -+ } -+}; -+ -+/// Computes a random Uniform distribution and pads diagonal with zeros -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillPadDiagonalRandomUniformFunc { -+ -+ using TensorView = TensorView; -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ RandomUniformFunc func; -+ cutlass::FillMode fill_mode; -+ int alignment; -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of uniform RNG functor. 
-+ TensorFillPadDiagonalRandomUniformFunc( -+ TensorView view_ = TensorView(), -+ RandomUniformFunc func_ = RandomUniformFunc(), -+ cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid, -+ int alignment_ = 1 -+ ): -+ view(view_), func(func_), fill_mode(fill_mode_), alignment(alignment_) { -+ -+ } -+ -+ /// Compute random value and update RNG state -+ void operator()(Coord const &coord) const { -+ // Fill half of matrix based on FillMode -+ if (Layout::kRank == 2 && -+ (fill_mode == cutlass::FillMode::kLower) && -+ (coord[0] >= coord[1]) || -+ ((coord[1] - coord[0]) >= alignment)) { -+ view.at(coord) = func(); -+ } else if (Layout::kRank == 2 && -+ fill_mode == cutlass::FillMode::kUpper && -+ (coord[0] <= coord[1]) || -+ ((coord[0] - coord[1]) >= alignment)) { -+ view.at(coord) = func(); -+ } -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values of a uniform random distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillRandomUniform( -+ TensorView dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ double max = 1, ///< upper bound of distribution -+ double min = 0, ///< lower bound for distribution -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ detail::RandomUniformFunc random_func(seed, max, min, bits); -+ -+ detail::TensorFillRandomUniformFunc func( -+ dst, -+ random_func -+ ); -+ -+ TensorForEach( -+ dst.extent(), -+ func -+ ); -+} -+ -+/// Fills a tensor with random values of a uniform random distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillRandomUniform( -+ TensorViewPlanarComplex dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ double max = 1, ///< upper bound of distribution -+ double min = 0, ///< lower bound for distribution -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ TensorFillRandomUniform(dst.view_real(), seed, max, min, bits); -+ TensorFillRandomUniform(dst.view_imag(), ~seed, max, min, bits); -+} -+ -+ -+/// Fills a tensor with random values with a uniform random distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillRandomUniform( -+ TensorView, Layout> dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ double max = 1, ///< upper bound of distribution -+ double min = 0, ///< lower bound for distribution -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ detail::RandomUniformFunc> random_func(seed, max, min, bits); -+ -+ detail::TensorFillRandomUniformFunc, Layout> func( -+ dst, -+ random_func -+ ); -+ -+ TensorForEach( -+ dst.extent(), -+ func -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a uniform random distribution. 
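The symmetric fill functors above gate each write on the FillMode of a rank-2 tensor: kLower touches only coordinates with row >= col, kUpper only those with row <= col. The following standalone sketch shows just that predicate; the enum and in_fill_region are hypothetical stand-ins for the cutlass types, and the sketch is not part of the patch.

// Standalone sketch (not part of the patch); FillMode/in_fill_region are hypothetical stand-ins.
#include <cstdio>

enum class FillMode { kLower, kUpper };

// True when (row, col) lies in the triangle the symmetric fill functors may write.
bool in_fill_region(FillMode mode, int row, int col) {
  return (mode == FillMode::kLower) ? (row >= col) : (row <= col);
}

int main() {
  for (int r = 0; r < 3; ++r) {
    for (int c = 0; c < 3; ++c) {
      std::printf("%d ", int(in_fill_region(FillMode::kLower, r, c)));
    }
    std::printf("\n");  // prints the lower-triangular mask: 1 0 0 / 1 1 0 / 1 1 1
  }
  return 0;
}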
-+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillSymmetricRandomUniform( -+ TensorView dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices -+ double max = 1, ///< upper bound of distribution -+ double min = 0, ///< lower bound for distribution -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ detail::RandomUniformFunc random_func(seed, max, min, bits); -+ -+ detail::TensorFillSymmetricRandomUniformFunc func( -+ dst, -+ random_func, -+ fill_mode -+ ); -+ -+ TensorForEach( -+ dst.extent(), -+ func -+ ); -+} -+ -+/// Fills a tensor with random values with a uniform random distribution pads zeros along diagonal -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillPadDiagonalRandomUniform( -+ TensorView dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices -+ double max = 1, ///< upper bound of distribution -+ double min = 0, ///< lower bound for distribution -+ int bits = -1, ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ int alignment = 1 -+) { -+ -+ detail::RandomUniformFunc random_func(seed, max, min, bits); -+ -+ detail::TensorFillPadDiagonalRandomUniformFunc func( -+ dst, -+ random_func, -+ fill_mode, -+ alignment -+ ); -+ -+ TensorForEach( -+ dst.extent(), -+ func -+ ); -+} -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a uniform random distribution. -+template < -+ typename Element ///< Element type -+> -+void BlockFillRandomUniform( -+ Element *ptr, -+ size_t capacity, -+ uint64_t seed, ///< seed for RNG -+ double max = 1, ///< upper bound of distribution -+ double min = 0, ///< lower bound for distribution -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ detail::RandomUniformFunc random_func(seed, max, min, bits); -+ -+ for (size_t i = 0; i < capacity; ++i) { -+ ReferenceFactory::get(ptr, i) = random_func(); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillDiagonalFunc { -+ -+ using TensorView = TensorView; -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element diag; -+ Element other; -+ -+ // -+ // Methods -+ // -+ -+ TensorFillDiagonalFunc( -+ TensorView const &view_ = TensorView(), -+ Element diag_ = Element(1), -+ Element other_ = Element(0) -+ ): -+ view(view_), diag(diag_), other(other_) { } -+ -+ void operator()(Coord const & coord) const { -+ bool is_diag = true; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < Layout::kRank; ++i) { -+ if (coord[i] != coord[i - 1]) { -+ is_diag = false; -+ break; -+ } -+ } -+ -+ view.at(coord) = (is_diag ? 
diag : other); -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor everywhere with a unique value for its diagonal. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillDiagonal( -+ TensorView dst, ///< destination tensor -+ Element diag = Element(1), ///< value to write in the diagonal -+ Element other = Element(0)) { ///< value to write off the diagonal -+ -+ detail::TensorFillDiagonalFunc func( -+ dst, -+ diag, -+ other -+ ); -+ -+ TensorForEach( -+ dst.extent(), -+ func -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper to fill a tensor's digonal with 1 and 0 everywhere else. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillIdentity( -+ TensorView dst) { ///< destination tensor -+ -+ TensorFillDiagonal(dst, Element(1), Element(0)); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorUpdateDiagonal( -+ TensorView dst, ///< destination tensor -+ Element val = Element(1)) { -+ -+ typename Layout::Index extent = dst.extent().min(); -+ -+ for (typename Layout::Index i = 0; i < extent; ++i) { -+ Coord coord(i); -+ dst.at(coord) = val; -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorUpdateOffDiagonalFunc { -+ -+ using TensorView = TensorView; -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element other; -+ -+ // -+ // Methods -+ // -+ -+ TensorUpdateOffDiagonalFunc( -+ TensorView const &view_ = TensorView(), -+ Element other_ = Element(0) -+ ): -+ view(view_), other(other_) { } -+ -+ void operator()(Coord const & coord) const { -+ bool is_diag = true; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < Layout::kRank; ++i) { -+ if (coord[i] != coord[i - 1]) { -+ is_diag = false; -+ break; -+ } -+ } -+ -+ if (!is_diag) { -+ view.at(coord) = other; -+ } -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Writes a uniform value to all elements in the tensor without modifying diagonal elements. 
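The diagonal helpers above treat a coordinate as "on the diagonal" exactly when all of its components are equal, which generalizes the rank-2 matrix diagonal to arbitrary rank. A plain-array sketch of that test follows; it is an illustration only, with is_diagonal as a hypothetical helper using int arrays instead of cutlass::Coord.

// Standalone sketch (not part of the patch); is_diagonal is a hypothetical name.
#include <cstdio>

// A coordinate is diagonal iff all of its components are equal, the same
// rank-generic test used by the diagonal fill/update functors above.
template <int Rank>
bool is_diagonal(const int (&coord)[Rank]) {
  for (int i = 1; i < Rank; ++i) {
    if (coord[i] != coord[i - 1]) {
      return false;
    }
  }
  return true;
}

int main() {
  int on[3]  = {2, 2, 2};
  int off[2] = {1, 3};
  std::printf("%d %d\n", int(is_diagonal(on)), int(is_diagonal(off)));  // prints: 1 0
  return 0;
}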
-+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorUpdateOffDiagonal( -+ TensorView dst, ///< destination tensor -+ Element other = Element(1)) { -+ -+ detail::TensorUpdateOffDiagonalFunc func( -+ dst, -+ other -+ ); -+ -+ TensorForEach( -+ dst.extent(), -+ func -+ ); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillLinearFunc { -+ -+ using TensorView = TensorView; -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Array v; -+ Element s; -+ -+ // -+ // Methods -+ // -+ -+ TensorFillLinearFunc() { } -+ -+ /// Constructs functor -+ TensorFillLinearFunc( -+ TensorView const &view_, -+ Array const & v_, -+ Element s_ = Element(0) -+ ): -+ view(view_), v(v_), s(s_) { } -+ -+ /// Updates the tensor -+ void operator()(Coord const & coord) const { -+ -+ Element sum(s); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank; ++i) { -+ sum += Element(coord[i]) * v[i]; -+ } -+ -+ view.at(coord) = sum; -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills tensor with a linear combination of its coordinate and another vector -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillLinear( -+ TensorView dst, ///< destination tensor -+ Array const & v, -+ Element s = Element(0)) { -+ -+ detail::TensorFillLinearFunc func( -+ dst, -+ v, -+ s -+ ); -+ -+ TensorForEach( -+ dst.extent(), -+ func -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills tensor with a linear combination of its coordinate and another vector -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillSequential( -+ TensorView dst, ///< destination tensor -+ Element s = Element(0)) { -+ -+ Array stride; -+ -+ stride[0] = Element(1); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < Layout::kRank; ++i) { -+ stride[i] = stride[i - 1] * Element(dst.extent()[i - 1]); -+ } -+ -+ TensorFillLinear(dst, stride, s); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a block of data with sequential elements -+template < -+ typename Element -+> -+void BlockFillSequential( -+ Element *ptr, -+ int64_t capacity, -+ Element v = Element(1), -+ Element s = Element(0)) { -+ int i = 0; -+ -+ while (i < capacity) { -+ cutlass::ReferenceFactory::value < -+ 8)>::get(ptr, i) = s; -+ -+ s = Element(s + v); -+ ++i; -+ } -+} -+ -+/// Fills a block of data with sequential elements -+template < -+ typename Element -+> -+void BlockFillSequentialModN( -+ Element *ptr, -+ int64_t capacity, -+ int64_t mod, -+ int64_t v = int64_t(1), -+ int64_t s = int64_t(0)) { -+ int i = 0; -+ -+ while (i < capacity) { -+ cutlass::ReferenceFactory::value < -+ 8)>::get(ptr, i) = Element(s); -+ -+ s = int64_t(s + v) % mod; -+ ++i; -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// 
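BlockFillSequentialModN above counts upward by v and wraps at mod, routing each store through ReferenceFactory so that sub-byte element types are handled. The sketch below reproduces the same counting pattern on a plain int8_t buffer, without the sub-byte indirection; it is an illustration only, and fill_sequential_mod_n is a hypothetical helper.

// Standalone sketch (not part of the patch); fill_sequential_mod_n is a hypothetical name.
#include <cstdint>
#include <cstdio>

// Same counting pattern as BlockFillSequentialModN above, on a plain buffer.
void fill_sequential_mod_n(int8_t *ptr, int64_t capacity, int64_t mod,
                           int64_t v = 1, int64_t s = 0) {
  for (int64_t i = 0; i < capacity; ++i) {
    ptr[i] = static_cast<int8_t>(s);
    s = (s + v) % mod;
  }
}

int main() {
  int8_t buf[8];
  fill_sequential_mod_n(buf, 8, 3);
  for (int8_t x : buf) std::printf("%d ", int(x));  // prints: 0 1 2 0 1 2 0 1
  std::printf("\n");
  return 0;
}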
-+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a block of data with sequential elements -+template < -+ typename Element -+> -+void BlockFillRandom( -+ Element *ptr, -+ size_t capacity, -+ uint64_t seed, -+ Distribution dist) { -+ -+ if (dist.kind == Distribution::Gaussian) { -+ BlockFillRandomGaussian( -+ ptr, -+ capacity, -+ seed, -+ dist.gaussian.mean, -+ dist.gaussian.stddev, -+ dist.int_scale); -+ } -+ else if (dist.kind == Distribution::Uniform) { -+ BlockFillRandomUniform( -+ ptr, -+ capacity, -+ seed, -+ dist.uniform.max, -+ dist.uniform.min, -+ dist.int_scale); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template -+struct RandomSparseMetaFunc { -+ -+ uint64_t seed; -+ int range; -+ int MetaSizeInBits; -+ -+ // -+ // Methods -+ // -+ -+ RandomSparseMetaFunc( -+ uint64_t seed_ = 0, -+ int MetaSizeInBits_ = 2 -+ ): -+ seed(seed_), MetaSizeInBits(MetaSizeInBits_) { -+ std::srand((unsigned)seed); -+ if (MetaSizeInBits_ == 2) { -+ range = 6; -+ } else if (MetaSizeInBits_ == 4) { -+ range = 2; -+ } -+ } -+ -+ /// Compute random value and update RNG state -+ Element operator()() const { -+ Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe}; -+ Element TwoToOneMeta[2] = {0x4, 0xe}; -+ -+ Element * MetaArray = (MetaSizeInBits == 2) ? FourToTwoMeta : TwoToOneMeta; -+ -+ Element result = 0x0; -+ -+ for (int i = 0; i < cutlass::sizeof_bits::value / 4; ++i) { -+ int rnd = std::rand() % range; -+ Element meta = MetaArray[rnd]; -+ -+ result = (Element)(result | ((Element)(meta << (i * 4)))); -+ } -+ -+ return result; -+ } -+}; -+ -+/// Computes a random sparse meta -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillRandomSparseMetaFunc { -+ -+ using TensorView = TensorView; -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ RandomSparseMetaFunc func; -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ TensorFillRandomSparseMetaFunc( -+ TensorView view_ = TensorView(), -+ RandomSparseMetaFunc func_ = RandomSparseMetaFunc() -+ ): -+ view(view_), func(func_) { -+ -+ } -+ -+ /// Compute random value and update RNG state -+ void operator()(Coord const &coord) const { -+ -+ view.at(coord) = func(); -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a uniform random distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillRandomSparseMeta( -+ TensorView dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ int MetaSizeInBits) { ///< 2 bit or 4 bit -+ -+ detail::RandomSparseMetaFunc random_func(seed, MetaSizeInBits); -+ -+ detail::TensorFillRandomSparseMetaFunc func( -+ dst, -+ random_func -+ ); -+ -+ TensorForEach( -+ dst.extent(), -+ func -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a uniform random distribution. 
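RandomSparseMetaFunc above draws 4-bit nibbles from the FourToTwoMeta table {0x4, 0x8, 0x9, 0xc, 0xd, 0xe}. Reading each nibble as two 2-bit fields (an assumption of this sketch, not stated in the patch) yields the six ways of keeping two elements out of every group of four, i.e. 2:4 structured-sparsity metadata. The decoding sketch below is an illustration only; decode_meta_nibble is a hypothetical helper.

// Standalone sketch (not part of the patch); decode_meta_nibble is a hypothetical name,
// and the two-2-bit-field reading of each nibble is an assumption of this sketch.
#include <cstdio>

// Decode one 4-bit metadata nibble: the low and high 2-bit fields hold the
// indices of the two elements kept out of each group of four.
void decode_meta_nibble(unsigned nibble, int &first, int &second) {
  first  = nibble & 0x3;
  second = (nibble >> 2) & 0x3;
}

int main() {
  const unsigned valid[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe};
  for (unsigned m : valid) {
    int a = 0, b = 0;
    decode_meta_nibble(m, a, b);
    std::printf("0x%x -> keep elements %d and %d\n", m, a, b);
  }
  return 0;  // 0x4 -> 0,1  0x8 -> 0,2  0x9 -> 1,2  0xc -> 0,3  0xd -> 1,3  0xe -> 2,3
}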
-+template < -+ typename Element ///< Element type -+> -+void BlockFillRandomSparseMeta( -+ Element *ptr, -+ size_t capacity, -+ uint64_t seed, ///< seed for RNG -+ int MetaSizeInBits) { ///< 2 bit or 4bit -+ -+ detail::RandomSparseMetaFunc random_func(seed, MetaSizeInBits); -+ -+ for (size_t i = 0; i < capacity; ++i) { -+ ptr[i] = random_func(); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a ell block index matrix with random values with a uniform random distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillRandomEllIdx( -+ TensorView dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ int rows, int ell_cols, int cols) { ///< dimension of the matrix -+ -+ std::srand((unsigned)seed); -+ -+ for (int i = 0; i < rows; ++i) { -+ int col_idx = std::rand() % cols; -+ -+ for (int j = 0; j < ell_cols; ++j) { -+ dst.at({i, j}) = col_idx; -+ -+ if (col_idx != -1) { -+ if (col_idx == (cols - 1)) { -+ col_idx = -1; -+ } else { -+ col_idx = std::rand() % (cols - col_idx - 1) + col_idx + 1; -+ } -+ } -+ } -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Copies a diagonal in from host memory without modifying off-diagonal elements. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorCopyDiagonalIn( -+ TensorView dst, ///< destination tensor -+ Element const *ptr) { ///< dense buffer of elements -+ -+ typename Layout::Index extent = dst.extent().min(); -+ -+ for (typename Layout::Index i = 0; i < extent; ++i) { -+ Coord coord(i); -+ dst.at(coord) = ReferenceFactory::get(ptr, i); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Copies the diagonal of a tensor into a dense buffer in host memory. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorCopyDiagonalOut( -+ Element *ptr, ///< dense buffer of elements -+ TensorView src) { ///< source tensor -+ -+ typename Layout::Index extent = src.extent().min(); -+ -+ for (typename Layout::Index i = 0; i < extent; ++i) { -+ Coord coord(i); -+ ReferenceFactory::get(ptr, i) = src.at(coord); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp -new file mode 100644 -index 0000000..3262c53 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp -@@ -0,0 +1,432 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Provides several functions for filling tensors with data. 
-+*/ -+ -+#pragma once -+ -+// Standard Library includes -+#include -+#include -+#include -+ -+// Cute includes -+#include "cute/tensor.hpp" -+ -+// Cutlass includes -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/quaternion.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Uniform and procedural tensor fills -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with a scalar element -+template -+void TensorFill(Tensor dst, typename Tensor::value_type element) { -+ -+ for (int64_t idx = 0; idx < cute::size(dst); ++idx) { -+ dst(idx) = element; -+ } -+} -+ -+/// Fills a tensor with the contents of its layout -+template -+void TensorFillSequential(Tensor dst) { -+ -+ auto layout = dst.layout(); -+ -+ for (int64_t idx = 0; idx < cute::size(dst); ++idx) { -+ dst(idx) = layout(idx); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Random uniform values -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template -+struct RandomUniformFunc { -+ -+ using Real = typename RealType::Type; -+ -+ uint64_t seed; -+ double range; -+ double min; -+ int int_scale; -+ -+ // -+ // Methods -+ // -+ -+ RandomUniformFunc( -+ uint64_t seed_ = 0, -+ double max = 1, -+ double min_ = 0, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { -+ std::srand((unsigned)seed); -+ } -+ -+ -+ /// Compute random value and update RNG state -+ Element operator()() const { -+ -+ double rnd = double(std::rand()) / double(RAND_MAX); -+ -+ rnd = min + range * rnd; -+ -+ // Random values are cast to integer after scaling by a power of two to facilitate error -+ // testing -+ Element result; -+ -+ if (int_scale >= 0) { -+ rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); -+ result = static_cast(Real(rnd)); -+ } -+ else { -+ result = static_cast(Real(rnd)); -+ } -+ -+ return result; -+ } -+}; -+ -+/// Partial specialization for initializing a complex value. 
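Both here and in tensor_fill.h above, the complex and quaternion specializations of RandomUniformFunc draw every component independently from the same [min, max] range. A minimal std::complex sketch of that pattern follows; it is an illustration only, and random_uniform_complex is a hypothetical helper.

// Standalone sketch (not part of the patch); random_uniform_complex is a hypothetical name.
#include <complex>
#include <cstdlib>

// Real and imaginary parts are drawn independently from the same uniform range,
// mirroring the component-wise sampling of the specializations above and below.
std::complex<double> random_uniform_complex(double min, double max) {
  auto draw = [&]() {
    double u = double(std::rand()) / double(RAND_MAX);
    return min + (max - min) * u;
  };
  return {draw(), draw()};
}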
-+template -+struct RandomUniformFunc > { -+ -+ using Real = typename RealType::Type; -+ -+ uint64_t seed; -+ double range; -+ double min; -+ int int_scale; -+ -+ // -+ // Methods -+ // -+ -+ RandomUniformFunc( -+ uint64_t seed_ = 0, -+ double max = 1, -+ double min_ = 0, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { -+ std::srand((unsigned)seed); -+ } -+ -+ -+ /// Compute random value and update RNG state -+ complex operator()() const { -+ -+ Element reals[2]; -+ -+ for (int i = 0; i < 2; ++i) { -+ double rnd = double(std::rand()) / double(RAND_MAX); -+ -+ rnd = min + range * rnd; -+ -+ // Random values are cast to integer after scaling by a power of two to facilitate error -+ // testing -+ -+ if (int_scale >= 0) { -+ rnd = double(int(rnd * double(1 << int_scale))); -+ reals[i] = from_real(Real(rnd / double(1 << int_scale))); -+ } -+ else { -+ reals[i] = from_real(Real(rnd)); -+ } -+ } -+ -+ return complex(reals[0], reals[1]); -+ } -+}; -+ -+/// Partial specialization for initializing a Quaternion value. -+template -+struct RandomUniformFunc > { -+ -+ using Real = typename RealType::Type; -+ -+ uint64_t seed; -+ double range; -+ double min; -+ int int_scale; -+ -+ // -+ // Methods -+ // -+ -+ RandomUniformFunc( -+ uint64_t seed_ = 0, -+ double max = 1, -+ double min_ = 0, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { -+ std::srand((unsigned)seed); -+ } -+ -+ -+ /// Compute random value and update RNG state -+ Quaternion operator()() const { -+ -+ Element reals[4]; -+ -+ for (int i = 0; i < 4; ++i) { -+ double rnd = double(std::rand()) / double(RAND_MAX); -+ -+ rnd = min + range * rnd; -+ -+ // Random values are cast to integer after scaling by a power of two to facilitate error -+ // testing -+ -+ if (int_scale >= 0) { -+ rnd = double(int(rnd * double(1 << int_scale))); -+ reals[i] = from_real(Real(rnd / double(1 << int_scale))); -+ } -+ else { -+ reals[i] = from_real(Real(rnd)); -+ } -+ } -+ -+ return make_Quaternion(reals[0], reals[1], reals[2], reals[3]); -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a uniform random distribution. -+template ///< Tensor object -+void TensorFillRandomUniform( -+ Tensor dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ double max = 1, ///< upper bound of distribution -+ double min = 0, ///< lower bound for distribution -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ detail::RandomUniformFunc random_func(seed, max, min, bits); -+ -+ for (int64_t idx = 0; idx < cute::size(dst); ++idx) { -+ dst(idx) = random_func(); -+ } -+} -+ -+/// Fills a block with random values with a uniform random distribution. -+template < -+ typename Element ///< Element type -+> -+void BlockFillRandomUniform( -+ Element *ptr, -+ size_t capacity, -+ uint64_t seed, ///< seed for RNG -+ double max = 1, ///< upper bound of distribution -+ double min = 0, ///< lower bound for distribution -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. 
-+ detail::RandomUniformFunc random_func(seed, max, min, bits); -+ -+ for (size_t i = 0; i < capacity; ++i) { -+ ptr[i] = random_func(); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Random Gaussian -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template -+struct RandomGaussianFunc { -+ -+ uint64_t seed; -+ double mean; -+ double stddev; -+ int int_scale; -+ double pi; -+ -+ // -+ // Methods -+ // -+ RandomGaussianFunc( -+ uint64_t seed_ = 0, -+ double mean_ = 0, -+ double stddev_ = 1, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) { -+ std::srand((unsigned)seed); -+ } -+ -+ /// Compute random value and update RNG state -+ Element operator()() const { -+ -+ // Box-Muller transform to generate random numbers with Normal distribution -+ double u1 = double(std::rand()) / double(RAND_MAX); -+ double u2 = double(std::rand()) / double(RAND_MAX); -+ -+ // Compute Gaussian random value -+ double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); -+ rnd = mean + stddev * rnd; -+ -+ // Scale and convert final result -+ Element result; -+ -+ if (int_scale >= 0) { -+ rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); -+ result = static_cast(rnd); -+ } -+ else { -+ result = static_cast(rnd); -+ } -+ -+ return result; -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a Gaussian distribution. -+template < -+ typename Tensor -+> -+void TensorFillRandomGaussian( -+ Tensor dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ double mean = 0, ///< Gaussian distribution's mean -+ double stddev = 1, ///< Gaussian distribution's standard deviation -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); -+ -+ for (int64_t idx = 0; idx < cute::size(dst); ++idx) { -+ dst(idx) = random_func(); -+ } -+} -+ -+/// Fills a block with random values with a Gaussian distribution. -+template < -+ typename Element ///< Element type -+> -+void BlockFillRandomGaussian( -+ Element *ptr, ///< destination buffer -+ size_t capacity, ///< number of elements -+ uint64_t seed, ///< seed for RNG -+ double mean = 0, ///< Gaussian distribution's mean -+ double stddev = 1, ///< Gaussian distribution's standard deviation -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. 
-+ -+ detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); -+ -+ for (size_t i = 0; i < capacity; ++i) { -+ ptr[i] = random_func(); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a block of data with sequential elements -+template < -+ typename Element -+> -+void BlockFillSequential( -+ Element *ptr, -+ int64_t capacity, -+ Element v = Element(1), -+ Element s = Element(0)) { -+ int i = 0; -+ -+ while (i < capacity) { -+ -+ ptr[i] = Element(s + v); -+ ++i; -+ } -+} -+ -+/// Fills a block of data with sequential elements -+template < -+ typename Element -+> -+void BlockFillSequentialModN( -+ Element *ptr, -+ int64_t capacity, -+ int64_t mod, -+ int64_t v = int64_t(1), -+ int64_t s = int64_t(0)) { -+ int i = 0; -+ -+ while (i < capacity) { -+ -+ ptr[i] = static_cast(int32_t(int64_t(s + v) % mod)); -+ ++i; -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h -new file mode 100644 -index 0000000..a195893 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h -@@ -0,0 +1,134 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines several helpers -+namespace detail { -+ -+/// Helper to perform for-each operation -+template -+struct TensorForEachHelper { -+ -+ /// Index of the active rank -+ static int const kActiveRank = Rank - RankRemaining - 1; -+ -+ /// Constructor for general rank -+ TensorForEachHelper( -+ Func &func, -+ Coord const &extent, -+ Coord &coord) { -+ -+ for (int i = 0; i < extent.at(kActiveRank); ++i) { -+ coord[kActiveRank] = i; -+ TensorForEachHelper(func, extent, coord); -+ } -+ } -+}; -+ -+/// Helper to perform for-each operation -+template -+struct TensorForEachHelper { -+ -+ /// Index of the active rank -+ static int const kActiveRank = Rank - 1; -+ -+ /// Constructor for fastest chaning rank -+ TensorForEachHelper( -+ Func &func, -+ Coord const &extent, -+ Coord &coord) { -+ -+ for (int i = 0; i < extent.at(kActiveRank); ++i) { -+ coord[kActiveRank] = i; -+ func(coord); -+ } -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Iterates over the index space of a tensor -+template < -+ typename Func, ///< function applied to each point in a tensor's index space -+ int Rank> ///< rank of index space -+void TensorForEach(Coord extent, Func & func) { -+ Coord coord; -+ detail::TensorForEachHelper(func, extent, coord); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Iterates over the index space of a tensor and calls a C++ lambda -+template < -+ typename Func, ///< function applied to each point in a tensor's index space -+ int Rank> ///< rank of index space -+void TensorForEachLambda(Coord extent, Func func) { -+ Coord coord; -+ detail::TensorForEachHelper(func, extent, coord); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct BlockForEach { -+ -+ /// Constructor performs the operation. -+ BlockForEach( -+ Element *ptr, -+ size_t capacity, -+ typename Func::Params params = typename Func::Params()) { -+ -+ Func func(params); -+ -+ for (size_t index = 0; index < capacity; ++index) { -+ ptr[index] = func(); -+ } -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h -new file mode 100644 -index 0000000..9d52b08 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h -@@ -0,0 +1,42 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+ -+#include "cutlass/cutlass.h" -+ -+// The contents of this file have been moved to 'tensor_reduce' to cover other types of reductions. -+ -+#include "cutlass/util/reference/host/tensor_reduce.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h -new file mode 100644 -index 0000000..672e4d5 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h -@@ -0,0 +1,203 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. 
-+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/tensor_ref.h" -+ -+#include "cutlass/util/reference/detail/linear_to_coordinate.h" -+#include "cutlass/core_io.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side -+/// workspace -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType, -+ typename ReduceOp, -+ typename TransformOp -+> -+ComputeType TensorTransformReduce( -+ TensorView view, -+ ComputeType identity, -+ ReduceOp reduce, -+ TransformOp transform -+) { -+ -+ for (int64_t idx = 0; idx < view.size(); ++idx) { -+ typename Layout::TensorCoord coord; -+ cutlass::reference::detail::LinearToCoordinate()(coord, idx, view.extent()); -+ -+ if (view.contains(coord)) { -+ Element x = view.at(coord); -+ identity = reduce(identity, transform(x)); -+ } -+ } -+ -+ return identity; -+} -+ -+/// Transform-reduce operation over the elements of a tensor. 
This helper allocates the device-side -+/// workspace -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType, -+ typename ReduceOp, -+ typename TransformOp -+> -+ComputeType TensorTransformReduce( -+ TensorView view_A, -+ TensorView view_B, -+ ComputeType identity, -+ ReduceOp reduce, -+ TransformOp transform) { -+ -+ if (view_A.extent() != view_B.extent()) { -+ throw std::runtime_error("Tensor extents must match."); -+ } -+ -+ for (int64_t idx = 0; idx < view_A.size(); ++idx) { -+ -+ typename Layout::TensorCoord coord; -+ cutlass::reference::detail::LinearToCoordinate()(coord, idx, view_A.extent()); -+ -+ if (view_A.contains(coord)) { -+ Element a = view_A.at(coord); -+ Element b = view_B.at(coord); -+ identity = reduce(identity, transform(a, b)); -+ } -+ } -+ -+ return identity; -+} -+ -+/// Helper to compute the sum of the elements of a tensor -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType = Element -+> -+ComputeType TensorSum( -+ TensorView view, -+ ComputeType identity = ComputeType() -+) { -+ -+ plus reduce; -+ NumericConverter transform; -+ -+ return TensorTransformReduce( -+ view, identity, reduce, transform); -+} -+ -+/// Helper to compute the sum of the squares of the elements of a tensor -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType = Element -+> -+ComputeType TensorSumSq( -+ TensorView view, -+ ComputeType identity = ComputeType() -+) { -+ -+ plus reduce; -+ magnitude_squared transform; -+ -+ return TensorTransformReduce( -+ view, identity, reduce, transform); -+} -+ -+/// Helper to compute the norm of the elements of a tensor. -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType = double -+> -+ComputeType TensorNorm( -+ TensorView view, -+ ComputeType identity = ComputeType() -+) { -+ -+ return std::sqrt(TensorSumSq(view, identity)); -+} -+ -+/// Helper to compute the sum of the squares of the differences of two tensors -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType = double -+> -+ComputeType TensorSumSqDiff( -+ TensorView view_A, -+ TensorView view_B, -+ ComputeType identity = ComputeType() -+) { -+ -+ plus reduce; -+ magnitude_squared_difference transform; -+ -+ return TensorTransformReduce( -+ view_A, view_B, identity, reduce, transform); -+} -+ -+ -+/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType = double -+> -+ComputeType TensorNormDiff( -+ TensorView view_A, -+ TensorView view_B, -+ ComputeType identity = ComputeType() -+) { -+ -+ return std::sqrt(TensorSumSqDiff(view_A, view_B, identity)); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp -new file mode 100644 -index 0000000..aadf60a ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp -@@ -0,0 +1,203 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Provides several functions for filling tensors with data. -+*/ -+ -+#pragma once -+ -+// Standard Library includes -+#include -+#include -+#include -+ -+// Cute includes -+#include "cute/tensor.hpp" -+ -+// Cutlass includes -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/quaternion.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Tensor reductions -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side -+/// workspace -+template < -+ typename Tensor, -+ typename ComputeType, -+ typename ReduceOp, -+ typename TransformOp -+> -+ComputeType TensorTransformReduce( -+ Tensor view, -+ ComputeType identity, -+ ReduceOp reduce, -+ TransformOp transform -+) { -+ -+ for (int64_t idx = 0; idx < cute::size(view); ++idx) { -+ identity = reduce(identity, transform(view(idx))); -+ } -+ -+ return identity; -+} -+ -+/// Transform-reduce operation over the elements of a tensor. 
This helper allocates the device-side -+/// workspace -+template < -+ typename TensorA, -+ typename TensorB, -+ typename ComputeType, -+ typename ReduceOp, -+ typename TransformOp -+> -+ComputeType TensorTransformReduce( -+ TensorA view_A, -+ TensorB view_B, -+ ComputeType identity, -+ ReduceOp reduce, -+ TransformOp transform) { -+ -+ if (cute::size(view_A) != cute::size(view_B)) { -+ throw std::runtime_error("Tensor sizes must match."); -+ } -+ -+ for (int64_t idx = 0; idx < cute::size(view_A); ++idx) { -+ identity = reduce(identity, transform(view_A(idx), view_B(idx))); -+ } -+ -+ return identity; -+} -+ -+/// Helper to compute the sum of the elements of a tensor -+template < -+ typename Tensor, -+ typename ComputeType = typename Tensor::value_type -+> -+ComputeType TensorSum( -+ Tensor view, -+ ComputeType identity = ComputeType() -+) { -+ -+ plus reduce; -+ NumericConverter transform; -+ -+ return TensorTransformReduce( -+ view, identity, reduce, transform); -+} -+ -+/// Helper to compute the sum of the squares of the elements of a tensor -+template < -+ typename Tensor, -+ typename ComputeType = typename Tensor::value_type -+> -+ComputeType TensorSumSq( -+ Tensor view, -+ ComputeType identity = ComputeType() -+) { -+ -+ plus reduce; -+ magnitude_squared transform; -+ -+ return TensorTransformReduce( -+ view, identity, reduce, transform); -+} -+ -+/// Helper to compute the norm of the elements of a tensor. -+template < -+ typename Tensor, -+ typename ComputeType = double -+> -+ComputeType TensorNorm( -+ Tensor view, -+ ComputeType identity = ComputeType() -+) { -+ -+ return std::sqrt(TensorSumSq(view, identity)); -+} -+ -+/// Helper to compute the sum of the squares of the differences of two tensors -+template < -+ typename TensorA, -+ typename TensorB, -+ typename ComputeType = double -+> -+ComputeType TensorSumSqDiff( -+ TensorA view_A, -+ TensorB view_B, -+ ComputeType identity = ComputeType() -+) { -+ -+ plus reduce; -+ magnitude_squared_difference transform; -+ -+ return TensorTransformReduce( -+ view_A, view_B, identity, reduce, transform); -+} -+ -+ -+/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory -+template < -+ typename TensorA, -+ typename TensorB, -+ typename ComputeType = double -+> -+ComputeType TensorNormDiff( -+ TensorA view_A, -+ TensorB view_B, -+ ComputeType identity = ComputeType() -+) { -+ -+ return std::sqrt(TensorSumSqDiff(view_A, view_B, identity)); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h -new file mode 100644 -index 0000000..0c931ee ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h -@@ -0,0 +1,215 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. 
Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for TRMM in host-side code. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/arch/mma.h" -+#include "cutlass/util/host_tensor.h" -+ -+#include "cutlass/util/reference/host/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/// Computes a Triangular Matrix Multiplication (tensors of rank=2) pointed to by TensorRef -+/// objects. -+template < -+ typename ElementA, -+ typename LayoutA, -+ SideMode SideModeA, -+ FillMode FillModeA, -+ DiagType DiagTypeA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+void compute_trmm( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ TensorRef tensor_d, -+ ComputeType initial_accum) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ static_assert(SideModeA != SideMode::kInvalid -+ , "Side Mode can either be Left or Right."); -+ -+ static_assert(FillModeA == FillMode::kLower || FillModeA == FillMode::kUpper -+ , "Fill Mode can either be Lower or Upper."); -+ -+ using CompareOp = typename TrMatrixCompareOp::Type; -+ -+ // Note: batch is ignored. 
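  // Descriptive summary of the reference computed below (the 16x16 blocking in the
  // loops does not change the math): with tri(A) denoting the triangle of A selected
  // by FillModeA, and with the diagonal of A treated as 1 when DiagTypeA == DiagType::kUnit,
  //   SideMode::kLeft  : D = alpha * tri(A) * B
  //   SideMode::kRight : D = alpha * B * tri(A)
  // There is no beta/C term; tensor_d is simply overwritten with the converted result.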
-+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ // Assuming correct k-dimension value is passed -+ int const K = problem_size.k(); -+ -+ // Blocking necessary to speedup reference implementation -+ int const Mblock = 16; -+ int const Nblock = 16; -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ CompareOp compare_op; -+ -+ for (int row_block = 0; row_block < M; row_block += Mblock) { -+ for (int col_block = 0; col_block < N; col_block += Nblock) { -+ -+ ComputeType accum[Mblock][Nblock]; -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N) { -+ ElementA a = ElementA(); -+ ElementB b = ElementB(); -+ -+ if (SideModeA == SideMode::kLeft) { -+ a = (compare_op(row, k_block)) ? -+ (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(0); -+ if (row == k_block && DiagTypeA == DiagType::kUnit) { -+ a = ElementA(1); -+ } -+ b = tensor_b.at(MatrixCoord(k_block, col)); -+ } else if (SideModeA == SideMode::kRight) { -+ a = tensor_b.at(MatrixCoord(row, k_block)); -+ b = (compare_op(k_block, col)) ? -+ tensor_a.at(MatrixCoord(k_block, col)) : ElementA(0); -+ if (k_block == col && DiagTypeA == DiagType::kUnit) { -+ b = ElementA(1); -+ } -+ } -+ -+ ComputeType compute_a(cast_if_scalar(a)); -+ ComputeType compute_b(cast_if_scalar(b)); -+ -+ accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N) { -+ tensor_d.at(coord) = convert_op( -+ alpha * ScalarType(accum[i][j])); -+ } -+ } -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ SideMode SideModeA, -+ FillMode FillModeA, -+ DiagType DiagTypeA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = cutlass::arch::OpMultiplyAdd -+> -+struct Trmm; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for multiply-add -+template -+struct Trmm { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ TensorRef tensor_d, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_trmm>( -+ problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h -new file mode 100644 -index 0000000..455c8a9 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h -@@ -0,0 +1,262 @@ 
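For reference, a minimal sketch of how the host-side Trmm functor defined in trmm.h above might be instantiated and called. The tensor names, extents, and the use of cutlass::HostTensor are illustrative assumptions; only the template-argument order and the operator() signature follow the header itself.

#include "cutlass/layout/matrix.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/trmm.h"

// D = alpha * lower(A) * B for a left-side, lower-triangular, non-unit-diagonal TRMM.
void run_trmm_reference(int m, int n, float alpha) {
  using Layout = cutlass::layout::ColumnMajor;

  cutlass::HostTensor<float, Layout> A({m, m});  // triangular operand
  cutlass::HostTensor<float, Layout> B({m, n});
  cutlass::HostTensor<float, Layout> D({m, n});

  cutlass::reference::host::Trmm<
      float, Layout,
      cutlass::SideMode::kLeft, cutlass::FillMode::kLower, cutlass::DiagType::kNonUnit,
      float, Layout,
      float, Layout,
      float, float> trmm_op;

  // For a left-side TRMM the k extent equals m (see the "Assuming correct
  // k-dimension value is passed" note in compute_trmm).
  trmm_op({m, n, m}, alpha, A.host_ref(), B.host_ref(), D.host_ref());
}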
-+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for complex-valued TRMM in host-side code. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/util/reference/host/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/// Computes a Triangular Matrix Multiplication (tensors of rank=2) pointed to by TensorRef -+/// objects. -+template < -+ typename ElementA, -+ typename LayoutA, -+ ComplexTransform TransformA, -+ SideMode SideModeA, -+ FillMode FillModeA, -+ DiagType DiagTypeA, -+ typename ElementB, -+ typename LayoutB, -+ ComplexTransform TransformB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+void compute_trmm_complex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ TensorRef tensor_d, -+ ComputeType initial_accum) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ static_assert(SideModeA != SideMode::kInvalid -+ , "Side Mode can either be Left or Right."); -+ -+ static_assert(FillModeA == FillMode::kLower || FillModeA == FillMode::kUpper -+ , "Fill Mode can either be Lower or Upper."); -+ -+ using CompareOp = typename TrMatrixCompareOp::Type; -+ -+ // Note: batch is ignored. 
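  // Descriptive summary of the complex-valued reference computed below: as in the
  // real-valued compute_trmm, SideMode::kLeft yields D = alpha * tri(A) * B and
  // SideMode::kRight yields D = alpha * B * tri(A), with a unit diagonal substituted
  // when DiagTypeA == DiagType::kUnit. ComplexTransform::kConjugate conjugates only
  // the triangular operand A (see the conj() branch below), which is how hermitian
  // inputs are expressed.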
-+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ // Assuming correct k-dimension value is passed -+ int const K = problem_size.k(); -+ -+ // Blocking necessary to speedup reference implementation -+ int const Mblock = 16; -+ int const Nblock = 16; -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ CompareOp compare_op; -+ -+ for (int row_block = 0; row_block < M; row_block += Mblock) { -+ for (int col_block = 0; col_block < N; col_block += Nblock) { -+ -+ ComputeType accum[Mblock][Nblock]; -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N) { -+ ElementA a = ElementA(); -+ ElementB b = ElementB(); -+ -+ if (SideModeA == SideMode::kLeft) { -+ a = (compare_op(row, k_block)) ? -+ (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(0); -+ if (row == k_block && DiagTypeA == DiagType::kUnit) { -+ a = ElementA(1); -+ } -+ b = tensor_b.at(MatrixCoord(k_block, col)); -+ } else if (SideModeA == SideMode::kRight) { -+ a = tensor_b.at(MatrixCoord(row, k_block)); -+ b = (compare_op(k_block, col)) ? -+ tensor_a.at(MatrixCoord(k_block, col)) : ElementA(0); -+ if (k_block == col && DiagTypeA == DiagType::kUnit) { -+ b = ElementA(1); -+ } -+ } -+ -+ ComputeType a_ik = ComputeType(a); -+ ComputeType b_kj = ComputeType(b); -+ -+ // Conjugate, and hence hermitian, is only allowed for the triangular matrix -+ if (SideModeA == SideMode::kLeft && TransformA == ComplexTransform::kConjugate) { -+ a_ik = conj(a_ik); -+ } else if (SideModeA == SideMode::kRight && TransformA == ComplexTransform::kConjugate) { -+ b_kj = conj(b_kj); -+ } -+ -+ accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N) { -+ tensor_d.at(coord) = convert_op( -+ alpha * ScalarType(accum[i][j])); -+ } -+ } -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ ComplexTransform TransformA, -+ SideMode SideModeA, -+ FillMode FillModeA, -+ DiagType DiagTypeA, -+ typename ElementB, -+ typename LayoutB, -+ ComplexTransform TransformB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = cutlass::arch::OpMultiplyAddComplex -+> -+struct TrmmComplex; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for multiply-add -+template -+struct TrmmComplex { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ TensorRef tensor_d, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_trmm_complex>( -+ problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for gaussian multiply-add -+template -+struct 
TrmmComplex { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ TensorRef tensor_d, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_trmm_complex>( -+ problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/tensor_view_io.h b/3rdparty/cutlass/tools/util/include/cutlass/util/tensor_view_io.h -new file mode 100644 -index 0000000..6a352df ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/tensor_view_io.h -@@ -0,0 +1,262 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-+* -+**************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/core_io.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/tensor_view_planar_complex.h" -+#include "cutlass/complex.h" -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Helper to write the least significant rank of a TensorView -+template < -+ typename Element, -+ typename Layout -+> -+inline std::ostream & TensorView_WriteLeastSignificantRank( -+ std::ostream& out, -+ TensorView const& view, -+ Coord const &start_coord, -+ int rank, -+ std::streamsize width) { -+ -+ for (int idx = 0; idx < view.extent(rank); ++idx) { -+ -+ Coord coord(start_coord); -+ coord[rank] = idx; -+ -+ if (idx) { -+ out.width(0); -+ out << ", "; -+ } -+ if (idx || coord) { -+ out.width(width); -+ } -+ out << ScalarIO(view.at(coord)); -+ } -+ -+ return out; -+} -+ -+/// Helper to write a rank of a TensorView -+template < -+ typename Element, -+ typename Layout -+> -+inline std::ostream & TensorView_WriteRank( -+ std::ostream& out, -+ TensorView const& view, -+ Coord const &start_coord, -+ int rank, -+ std::streamsize width) { -+ -+ // If called on the least significant rank, write the result as a row -+ if (rank + 1 == Layout::kRank) { -+ return TensorView_WriteLeastSignificantRank(out, view, start_coord, rank, width); -+ } -+ -+ // Otherwise, write a sequence of rows and newlines -+ for (int idx = 0; idx < view.extent(rank); ++idx) { -+ -+ Coord coord(start_coord); -+ coord[rank] = idx; -+ -+ if (rank + 2 == Layout::kRank) { -+ // Write least significant ranks asa matrix with rows delimited by "\n" -+ out << (idx ? ",\n" : ""); -+ TensorView_WriteLeastSignificantRank(out, view, coord, rank + 1, width); -+ } -+ else { -+ // Higher ranks are separated by newlines -+ out << (idx ? ",\n\n" : ""); -+ TensorView_WriteRank(out, view, coord, rank + 1, width); -+ } -+ } -+ -+ return out; -+} -+ -+/// Helper to write the least significant rank of a TensorView -+template < -+ typename Element, -+ typename Layout -+> -+inline std::ostream & TensorViewPlanarComplex_WriteLeastSignificantRank( -+ std::ostream& out, -+ TensorViewPlanarComplex const& view, -+ Coord const &start_coord, -+ int rank, -+ std::streamsize width) { -+ -+ for (int idx = 0; idx < view.extent(rank); ++idx) { -+ -+ Coord coord(start_coord); -+ coord[rank] = idx; -+ -+ if (idx) { -+ out.width(0); -+ out << ", "; -+ } -+ if (idx || coord) { -+ out.width(width); -+ } -+ -+ complex x = view.at(coord); -+ out << x; -+ } -+ -+ return out; -+} -+ -+/// Helper to write a rank of a TensorView -+template < -+ typename Element, -+ typename Layout -+> -+inline std::ostream & TensorViewPlanarComplex_WriteRank( -+ std::ostream& out, -+ TensorViewPlanarComplex const& view, -+ Coord const &start_coord, -+ int rank, -+ std::streamsize width) { -+ -+ // If called on the least significant rank, write the result as a row -+ if (rank + 1 == Layout::kRank) { -+ return TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, start_coord, rank, width); -+ } -+ -+ // Otherwise, write a sequence of rows and newlines -+ for (int idx = 0; idx < view.extent(rank); ++idx) { -+ -+ Coord coord(start_coord); -+ coord[rank] = idx; -+ -+ if (rank + 2 == Layout::kRank) { -+ // Write least significant ranks asa matrix with rows delimited by ";\n" -+ out << (idx ? 
";\n" : ""); -+ TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, coord, rank + 1, width); -+ } -+ else { -+ // Higher ranks are separated by newlines -+ out << (idx ? "\n" : ""); -+ TensorViewPlanarComplex_WriteRank(out, view, coord, rank + 1, width); -+ } -+ } -+ -+ return out; -+} -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Prints human-readable representation of a TensorView to an ostream -+template < -+ typename Element, -+ typename Layout -+> -+inline std::ostream& TensorViewWrite( -+ std::ostream& out, -+ TensorView const& view) { -+ -+ // Prints a TensorView according to the following conventions: -+ // - least significant rank is printed as rows separated by ";\n" -+ // - all greater ranks are delimited with newlines -+ // -+ // The result is effectively a whitespace-delimited series of 2D matrices. -+ -+ return detail::TensorView_WriteRank(out, view, Coord(), 0, out.width()); -+} -+ -+/// Prints human-readable representation of a TensorView to an ostream -+template < -+ typename Element, -+ typename Layout -+> -+inline std::ostream& operator<<( -+ std::ostream& out, -+ TensorView const& view) { -+ -+ // Prints a TensorView according to the following conventions: -+ // - least significant rank is printed as rows separated by ";\n" -+ // - all greater ranks are delimited with newlines -+ // -+ // The result is effectively a whitespace-delimited series of 2D matrices. -+ -+ return TensorViewWrite(out, view); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Prints human-readable representation of a TensorView to an ostream -+template < -+ typename Element, -+ typename Layout -+> -+inline std::ostream& TensorViewWrite( -+ std::ostream& out, -+ TensorViewPlanarComplex const& view) { -+ -+ // Prints a TensorView according to the following conventions: -+ // - least significant rank is printed as rows separated by ";\n" -+ // - all greater ranks are delimited with newlines -+ // -+ // The result is effectively a whitespace-delimited series of 2D matrices. -+ -+ return detail::TensorViewPlanarComplex_WriteRank(out, view, Coord(), 0, out.width()); -+} -+ -+/// Prints human-readable representation of a TensorView to an ostream -+template < -+ typename Element, -+ typename Layout -+> -+inline std::ostream& operator<<( -+ std::ostream& out, -+ TensorViewPlanarComplex const& view) { -+ -+ // Prints a TensorView according to the following conventions: -+ // - least significant rank is printed as rows separated by ";\n" -+ // - all greater ranks are delimited with newlines -+ // -+ // The result is effectively a whitespace-delimited series of 2D matrices. -+ -+ return TensorViewWrite(out, view); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/type_traits.h b/3rdparty/cutlass/tools/util/include/cutlass/util/type_traits.h -new file mode 100644 -index 0000000..f187b97 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/type_traits.h -@@ -0,0 +1,238 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Type traits for common CUDA types -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/complex.h" -+ -+namespace cutlass { -+struct half_t; -+ -+template -+struct TypeTraits { -+ typedef T host_type; -+ typedef T device_type; -+ static inline T remove_negative_zero(T x) { return x; } -+ static inline T to_print(T x) { return x; } -+ static inline device_type to_device(host_type x) { return x; } -+}; -+ -+template <> -+struct TypeTraits { -+ static cudaDataType_t const cublas_type = CUDA_R_8I; -+ typedef int8_t host_type; -+ typedef int8_t device_type; -+ typedef int8_t integer_type; -+ typedef uint8_t unsigned_type; -+ static inline int8_t remove_negative_zero(int8_t x) { return x; } -+ static inline int to_print(int8_t x) { return (int)x; } -+ static inline device_type to_device(host_type x) { return x; } -+}; -+ -+template <> -+struct TypeTraits { -+ static cudaDataType_t const cublas_type = CUDA_R_8I; -+ typedef uint8_t host_type; -+ typedef uint8_t device_type; -+ typedef uint8_t integer_type; -+ typedef uint8_t unsigned_type; -+ static inline uint8_t remove_negative_zero(uint8_t x) { return x; } -+ static inline uint32_t to_print(uint8_t x) { return (uint32_t)x; } -+ static inline device_type to_device(host_type x) { return x; } -+}; -+ -+template <> -+struct TypeTraits { -+ static cudaDataType_t const cublas_type = CUDA_R_32I; -+ typedef int host_type; -+ typedef int device_type; -+ typedef int32_t integer_type; -+ typedef uint32_t unsigned_type; -+ static inline int32_t remove_negative_zero(int32_t x) { return x; } -+ static inline int to_print(int x) { return x; } -+ static inline device_type to_device(host_type x) { return x; } -+}; -+ -+template <> -+struct TypeTraits { -+ static cudaDataType_t const cublas_type = CUDA_R_32I; 
-+ typedef unsigned host_type; -+ typedef unsigned device_type; -+ typedef uint32_t integer_type; -+ typedef uint32_t unsigned_type; -+ static inline uint32_t remove_negative_zero(uint32_t x) { return x; } -+ static inline uint32_t to_print(uint32_t x) { return x; } -+ static inline device_type to_device(host_type x) { return x; } -+}; -+ -+template <> -+struct TypeTraits { -+ static cudaDataType_t const cublas_type = CUDA_R_8I; -+ typedef int64_t host_type; -+ typedef int64_t device_type; -+ typedef int64_t integer_type; -+ typedef uint64_t unsigned_type; -+ static inline int64_t remove_negative_zero(int64_t x) { return x; } -+ static inline int64_t to_print(int64_t x) { return x; } -+ static inline device_type to_device(host_type x) { return x; } -+}; -+ -+template <> -+struct TypeTraits { -+ static cudaDataType_t const cublas_type = CUDA_R_8I; -+ typedef uint64_t host_type; -+ typedef uint64_t device_type; -+ typedef uint64_t integer_type; -+ typedef uint64_t unsigned_type; -+ static inline uint64_t remove_negative_zero(uint64_t x) { return x; } -+ static inline uint64_t to_print(uint64_t x) { return x; } -+ static inline device_type to_device(host_type x) { return x; } -+}; -+ -+template <> -+struct TypeTraits { -+ static cudaDataType_t const cublas_type = CUDA_R_16F; -+ typedef half_t host_type; -+ typedef half_t device_type; -+ typedef int16_t integer_type; -+ typedef uint16_t unsigned_type; -+ static inline half_t remove_negative_zero(half_t x) { -+ return (x.raw() == 0x8000 ? half_t::bitcast(0) : x); -+ } -+ static inline half_t to_print(half_t x) { return x; } -+ static inline device_type to_device(half_t x) { return reinterpret_cast(x); } -+}; -+ -+template <> -+struct TypeTraits { -+ static cudaDataType_t const cublas_type = CUDA_R_32F; -+ typedef float host_type; -+ typedef float device_type; -+ typedef int32_t integer_type; -+ typedef uint32_t unsigned_type; -+ static inline float remove_negative_zero(float x) { return x == -0.f ? 0.f : x; } -+ static inline float to_print(float x) { return x; } -+ static inline device_type to_device(host_type x) { return x; } -+}; -+ -+template <> -+struct TypeTraits { -+ static cudaDataType_t const cublas_type = CUDA_R_64F; -+ typedef double host_type; -+ typedef double device_type; -+ typedef int64_t integer_type; -+ typedef uint64_t unsigned_type; -+ static inline double remove_negative_zero(double x) { return x == -0.0 ? 0.0 : x; } -+ static inline double to_print(double x) { return x; } -+ static inline device_type to_device(host_type x) { return x; } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Complex types -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct TypeTraits > { -+ static cudaDataType_t const cublas_type = CUDA_C_16F; -+ typedef complex host_type; -+ typedef complex device_type; -+ typedef int16_t integer_type; -+ typedef uint16_t unsigned_type; -+ static inline device_type to_device(complex x) { return reinterpret_cast(x); } -+}; -+ -+template <> -+struct TypeTraits > { -+ static cudaDataType_t const cublas_type = CUDA_C_16F; -+ typedef complex host_type; -+ typedef complex device_type; -+ typedef int16_t integer_type; -+ typedef uint16_t unsigned_type; -+ static inline complex remove_negative_zero(complex x) { -+ return complex( -+ real(x) == -0_hf ? 0_hf : real(x), -+ imag(x) == -0_hf ? 
0_hf : imag(x) -+ ); -+ } -+ static inline complex to_print(complex x) { return x; } -+ static inline device_type to_device(complex x) { return reinterpret_cast(x); } -+}; -+ -+template <> -+struct TypeTraits > { -+ -+ static cudaDataType_t const cublas_type = CUDA_C_32F; -+ typedef complex host_type; -+ typedef complex device_type; -+ typedef int64_t integer_type; -+ typedef uint64_t unsigned_type; -+ -+ static inline complex remove_negative_zero(complex x) { -+ return complex( -+ real(x) == -0.f ? 0.f : real(x), -+ imag(x) == -0.f ? 0.f : imag(x) -+ ); -+ } -+ -+ static inline complex to_print(complex x) { return x; } -+ static inline device_type to_device(complex x) { return reinterpret_cast(x); } -+}; -+ -+template <> -+struct TypeTraits > { -+ static cudaDataType_t const cublas_type = CUDA_C_64F; -+ typedef complex host_type; -+ typedef complex device_type; -+ struct integer_type { int64_t real, imag; }; -+ struct unsigned_type { uint64_t real, imag; }; -+ static inline complex remove_negative_zero(complex x) { -+ return complex( -+ real(x) == -0.0 ? 0.0 : real(x), -+ imag(x) == -0.0 ? 0.0 : imag(x) -+ ); -+ } -+ static inline complex to_print(complex x) { return x; } -+ static inline device_type to_device(complex x) { return reinterpret_cast(x); } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/trt_fused_multihead_attention/CMakeLists.txt b/3rdparty/trt_fused_multihead_attention/CMakeLists.txt -index 8707220..c9369e0 100644 ---- a/3rdparty/trt_fused_multihead_attention/CMakeLists.txt -+++ b/3rdparty/trt_fused_multihead_attention/CMakeLists.txt -@@ -21,7 +21,10 @@ set(trt_fused_multi_head_attention_files - ) - - file(GLOB trt_fused_multi_head_attention_files ${trt_fused_multi_head_attention_files} *.sm*.cpp) -- -+if(${CUDA_VERSION_STRING} VERSION_LESS_EQUAL "10.1.105" ) -+#this cuda don't support sm80 -+ list(REMOVE_ITEM trt_fused_multi_head_attention_files fused_mha_with_relPosBias_fp16_64_32_kernel.sm80.cpp) -+endif() - add_library(trt_fused_multi_head_attention STATIC ${trt_fused_multi_head_attention_files}) - target_link_libraries(trt_fused_multi_head_attention PUBLIC -lcublas -lcudart) - set_property(TARGET trt_fused_multi_head_attention PROPERTY POSITION_INDEPENDENT_CODE ON) -diff --git a/CMakeLists.txt b/CMakeLists.txt -index ea21014..66cf2af 100644 ---- a/CMakeLists.txt -+++ b/CMakeLists.txt -@@ -14,7 +14,9 @@ - cmake_minimum_required(VERSION 3.8 FATAL_ERROR) # for PyTorch extensions, version should be greater than 3.13 - project(FasterTransformer LANGUAGES CXX CUDA) - --find_package(CUDA 10.2 REQUIRED) -+find_package(CUDA 10.1 REQUIRED) -+ -+option(EXAMPLES "build examples" on) - - if(${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11") - add_definitions("-DENABLE_BF16") -@@ -61,7 +63,7 @@ if(USE_TRITONSERVER_DATATYPE) - add_definitions("-DUSE_TRITONSERVER_DATATYPE") - endif() - --set(CXX_STD "14" CACHE STRING "C++ standard") -+set(CXX_STD "17" CACHE STRING "C++ standard") - - set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR}) - -@@ -85,7 +85,7 @@ endif() - - # setting compiler flags - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") --set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") -+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fstack-protector-strong -D_FORTIFY_SOURCE=2") - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall -ldl") - - set(SM_SETS 52 60 61 70 75 80 86) -@@ -92,13 +94,15 @@ set(FIND_SM False) - - foreach(SM_NUM IN LISTS SM_SETS) - string(FIND "${SM}" "${SM_NUM}" SM_POS) -+ 
message("find ${SM} in ${SM_NUM}") - if(SM_POS GREATER -1) - if(FIND_SM STREQUAL False) - set(ENV{TORCH_CUDA_ARCH_LIST} "") - endif() - set(FIND_SM True) - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_${SM_NUM},code=\\\"sm_${SM_NUM},compute_${SM_NUM}\\\"") -- -+ math(EXPR CUDA_ARCH "${SM_NUM}*10") -+ add_definitions("-D__CUDA_ARCH_HOST__=${CUDA_ARCH}") - if (SM_NUM STREQUAL 70 OR SM_NUM STREQUAL 75 OR SM_NUM STREQUAL 80 OR SM_NUM STREQUAL 86) - set(USING_WMMA True) - endif() -@@ -125,8 +129,6 @@ if(NOT (FIND_SM STREQUAL True)) - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} \ - -gencode=arch=compute_70,code=\\\"sm_70,compute_70\\\" \ - -gencode=arch=compute_75,code=\\\"sm_75,compute_75\\\" \ -- -gencode=arch=compute_80,code=\\\"sm_80,compute_80\\\" \ -- -gencode=arch=compute_86,code=\\\"sm_86,compute_86\\\" \ - ") - # -rdc=true") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DWMMA") -@@ -136,7 +138,13 @@ if(NOT (FIND_SM STREQUAL True)) - set(ENV{TORCH_CUDA_ARCH_LIST} "7.0;7.5;8.0;8.6") - endif() - set(CMAKE_CUDA_ARCHITECTURES 70 75 80 86) -- message("-- Assign GPU architecture (sm=70,75,80,86)") -+ add_definitions("-D__CUDA_ARCH_HOST__=800") -+ if(${CUDA_VERSION_STRING} VERSION_LESS_EQUAL "10.1" ) -+ message("${CUDA_VERSION_STRING} removing unsupported sm 80 & 86") -+ list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES 80 86) -+endif() -+ message("-- Assign GPU architectures (sm=${CMAKE_CUDA_ARCHITECTURES})") -+ set(SM 70) - endif() - - if(BUILD_PYT) -@@ -152,8 +160,9 @@ set(CMAKE_CXX_STANDARD "${CXX_STD}") - set(CMAKE_CXX_STANDARD_REQUIRED ON) - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda") - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") --set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD}") -- -+if(${CUDA_VERSION_STRING} VERSION_GREATER "10.1.105" ) -+ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD}") -+endif() - set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") - # set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 --ptxas-options=--verbose") - set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3") -@@ -230,9 +239,10 @@ link_directories( - - add_subdirectory(3rdparty) - add_subdirectory(src) --add_subdirectory(examples) --add_subdirectory(tests) -- -+if(EXAMPLES) -+ add_subdirectory(examples) -+ add_subdirectory(tests) -+endif() - ######################################## - - if(BUILD_MULTI_GPU) -@@ -249,6 +259,7 @@ add_library(transformer-static STATIC - $ - $ - $ -+ $ - $ - $ - $ -@@ -313,8 +324,9 @@ add_library(transformer-static STATIC - set_property(TARGET transformer-static PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET transformer-static PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - target_link_libraries(transformer-static PUBLIC -lcudart -lnccl -lmpi -lcublas -lcublasLt -lcurand) -+endif() - --add_library(transformer-shared SHARED -+set(transformer_objects - $ - $ - $ -@@ -324,29 +336,10 @@ add_library(transformer-shared SHARED - $ - $ - $ -- $ -- $ -- $ -+ $ - $ -- $ - $ - $ -- $ -- $ -- $ -- $ -- $ -- $ -- $ -- $ -- $ -- $ -- $ -- $ -- $ -- $ -- $ -- $ - $ - $ - $ -@@ -373,9 +366,7 @@ add_library(transformer-shared SHARED - $ - $ - $ -- $ - $ -- $ - $ - $ - $ -@@ -387,14 +378,23 @@ add_library(transformer-shared SHARED - $ - $ - $) -+ -+if(${SM} GREATER_EQUAL 70) -+ set(transformer_objects ${transformer_objects} $) -+endif() -+ -+add_library(transformer-shared SHARED ${transformer_objects}) - set_target_properties(transformer-shared PROPERTIES 
POSITION_INDEPENDENT_CODE ON) - set_target_properties(transformer-shared PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON) - set_target_properties(transformer-shared PROPERTIES LINKER_LANGUAGE CXX) --target_link_libraries(transformer-shared PUBLIC -lcudart -lnccl -lmpi -lcublas -lcublasLt -lcurand) -+target_link_libraries(transformer-shared PUBLIC -lcudart -lcublas -lcublasLt -lcurand) -+target_link_options(transformer-shared PUBLIC -Wl,-z,now,-s,-fstack-protector-strong) - --include(GNUInstallDirs) -+#include(GNUInstallDirs) - set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/FasterTransformer) - -+ -+ - include(CMakePackageConfigHelpers) - configure_package_config_file( - ${CMAKE_CURRENT_LIST_DIR}/cmake/FasterTransformerConfig.cmake.in -@@ -402,52 +401,23 @@ configure_package_config_file( - INSTALL_DESTINATION ${INSTALL_CONFIGDIR} - ) - --install( -- FILES -- ${CMAKE_CURRENT_BINARY_DIR}/FasterTransformerConfig.cmake -- DESTINATION ${INSTALL_CONFIGDIR} --) - - install( - TARGETS - transformer-shared - EXPORT - transformer-shared-targets -- LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/fastertransformer -- ARCHIVE DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/fastertransformer --) -- --install( -- EXPORT -- transformer-shared-targets -- FILE -- FasterTransformerTargets.cmake -- DESTINATION -- ${INSTALL_CONFIGDIR} -+ LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/lib -+ ARCHIVE DESTINATION ${CMAKE_INSTALL_PREFIX}/lib - ) - - file(GLOB_RECURSE HEADER_FILES "*.h" "*.hpp" "*.cuh") - foreach ( file ${HEADER_FILES} ) - file( RELATIVE_PATH rfile ${CMAKE_CURRENT_SOURCE_DIR} ${file} ) - get_filename_component( dir ${rfile} DIRECTORY ) -- install( FILES ${file} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${dir} ) -+ install( FILES ${file} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${dir}) - endforeach() - - --################################################################################ --# add_executable(gpt sample/cpp/gpt_sample.cc ) --# target_link_libraries(gpt PUBLIC -lcublas -lcublasLt -lcudart -lcurand -lnccl -lmpi transformer-static) --# target_link_libraries(gpt PUBLIC -lcublas -lcublasLt -lcudart -lcurand -lnccl -lmpi decoder decoding) -- --export( -- EXPORT -- transformer-shared-targets -- FILE -- ${CMAKE_CURRENT_BINARY_DIR}/FasterTransformerTargets.cmake -- NAMESPACE -- TritonCore:: --) - --export(PACKAGE FasterTransformer) - --endif() # BUILD_MULTI_GPU -diff --git a/README.md b/README.md -index a60983c..45b5374 100644 ---- a/README.md -+++ b/README.md -@@ -52,7 +52,7 @@ FasterTransformer is built on top of CUDA, cuBLAS, cuBLASLt and C++. We provide - | Swin Transformer | PyTorch | Yes | Yes | - | - | - | - | Swin Transformer | TensorRT | Yes | Yes | - | - | - | - | ViT | PyTorch | Yes | Yes | - | - | - | --| ViT | TensorRT | Yes | Yes | - | - | - | -+| ViT | TensorRT | Yes | - | - | - | - | - - * Note that the FasterTransformer supports the models above on C++ because all source codes are built on C++. 
- -diff --git a/deploy.ch b/deploy.ch -new file mode 100644 -index 0000000..9df51e9 ---- /dev/null -+++ b/deploy.ch -@@ -0,0 +1,125 @@ -+ms_benchmark -+ -+sent 15,086 bytes received 22,535 bytes 25,080.67 bytes/sec -+total size is 14,006,280 speedup is 372.30 -+file=/home/shira/git-proj/FasterTransformer/build/bin/ms_benchmark -+T5_input1.fp32 -+T5_input2.fp32 -+T5_input3.fp32 -+T5_input4.fp32 -+T5_input5.fp32 -+T5_output1.fp32 -+mha_T5_cross_input1.fp32 -+mha_T5_cross_input2.fp32 -+mha_T5_cross_input3.fp32 -+mha_T5_cross_input4.fp32 -+mha_T5_cross_output1.fp32 -+mha_T5_cross_weight1.fp32 -+mha_T5_cross_weight2.fp32 -+mha_T5_cross_weight3.fp32 -+mha_T5_input1.fp32 -+mha_T5_input2.fp32 -+mha_T5_input3.fp32 -+mha_T5_output1.fp32 -+mha_T5_weight1.fp32 -+mha_T5_weight2.fp32 -+mha_cross_input1.fp32 -+mha_cross_input2.fp32 -+mha_cross_input3.fp32 -+mha_cross_output1.fp32 -+mha_cross_output2.fp32 -+mha_cross_output3.fp32 -+mha_cross_weight1.fp32 -+mha_cross_weight2.fp32 -+mha_cross_weight3.fp32 -+mha_cross_weight4.fp32 -+mha_cross_weight5.fp32 -+mha_x1_input1.fp32 -+mha_x1_input2.fp32 -+mha_x1_output1.fp32 -+mha_x1_output2.fp32 -+mha_x1_output3.fp32 -+mha_x1_weight1.fp32 -+mha_x1_weight2.fp32 -+mha_x1_weight3.fp32 -+mha_x1_weight4.fp32 -+test_input1.fp32 -+transformer_decoder_layer_input1.fp32 -+transformer_decoder_layer_input2.fp32 -+transformer_decoder_layer_input3.fp32 -+transformer_decoder_layer_input4.fp32 -+transformer_decoder_layer_output1.fp32 -+transformer_decoder_layer_t5_input1.fp32 -+transformer_decoder_layer_t5_input2.fp32 -+transformer_decoder_layer_t5_input3.fp32 -+transformer_decoder_layer_t5_input4.fp32 -+transformer_decoder_layer_t5_input5.fp32 -+transformer_decoder_layer_t5_input6.fp32 -+transformer_decoder_layer_t5_output1.fp32 -+transformer_decoder_layer_t5_weight1.fp32 -+transformer_decoder_layer_t5_weight10.fp16 -+transformer_decoder_layer_t5_weight2.fp32 -+transformer_decoder_layer_t5_weight3.fp32 -+transformer_decoder_layer_t5_weight4.fp32 -+transformer_decoder_layer_t5_weight5.fp32 -+transformer_decoder_layer_t5_weight6.fp32 -+transformer_decoder_layer_t5_weight7.fp32 -+transformer_decoder_layer_t5_weight8.fp32 -+transformer_decoder_layer_t5_weight9.fp16 -+transformer_decoder_layer_weight1.fp32 -+transformer_decoder_layer_weight10.fp32 -+transformer_decoder_layer_weight11.fp32 -+transformer_decoder_layer_weight12.fp32 -+transformer_decoder_layer_weight13.fp32 -+transformer_decoder_layer_weight14.fp32 -+transformer_decoder_layer_weight15.fp32 -+transformer_decoder_layer_weight16.fp16 -+transformer_decoder_layer_weight17.fp16 -+transformer_decoder_layer_weight18.fp16 -+transformer_decoder_layer_weight19.fp32 -+transformer_decoder_layer_weight2.fp32 -+transformer_decoder_layer_weight3.fp32 -+transformer_decoder_layer_weight4.fp32 -+transformer_decoder_layer_weight5.fp32 -+transformer_decoder_layer_weight6.fp32 -+transformer_decoder_layer_weight7.fp32 -+transformer_decoder_layer_weight8.fp32 -+transformer_decoder_layer_weight9.fp32 -+transformer_encoder_layer_input1.fp32 -+transformer_encoder_layer_input2.fp32 -+transformer_encoder_layer_output1.fp32 -+transformer_encoder_layer_t5_input1.fp32 -+transformer_encoder_layer_t5_input2.fp32 -+transformer_encoder_layer_t5_input3.fp32 -+transformer_encoder_layer_t5_output1.fp32 -+transformer_encoder_layer_t5_weight1.fp32 -+transformer_encoder_layer_t5_weight2.fp32 -+transformer_encoder_layer_t5_weight3.fp32 -+transformer_encoder_layer_t5_weight4.fp32 -+transformer_encoder_layer_t5_weight5.fp16 
-+transformer_encoder_layer_t5_weight6.fp16 -+transformer_encoder_layer_weight1.fp32 -+transformer_encoder_layer_weight10.fp16 -+transformer_encoder_layer_weight11.fp16 -+transformer_encoder_layer_weight12.fp32 -+transformer_encoder_layer_weight2.fp32 -+transformer_encoder_layer_weight3.fp32 -+transformer_encoder_layer_weight4.fp32 -+transformer_encoder_layer_weight5.fp32 -+transformer_encoder_layer_weight6.fp32 -+transformer_encoder_layer_weight7.fp32 -+transformer_encoder_layer_weight8.fp32 -+transformer_encoder_layer_weight9.fp16 -+ -+sent 224,270 bytes received 328,739 bytes 368,672.67 bytes/sec -+total size is 100,832,432 speedup is 182.33 -+libtransformer-shared.so -+ -+sent 40,407 bytes received 70,578 bytes 73,990.00 bytes/sec -+total size is 101,408,056 speedup is 913.71 -+command= CUDA_VISIBLE_DEVICES=5 LD_LIBRARY_PATH=/home/shira/git-proj/FasterTransformer/../FasterTransformer:/usr/local/cuda-11.7/lib64 /home/shira/git-proj/FasterTransformer/build/bin/ms_benchmark -b 1 -l 12 -H 2 -S 8 -s 20 -f 32 -x 1 -P 0 -m transformer_edeode_layer_t5 -+[INFO] Device: NVIDIA A100-PCIE-40GB -+[WARNING] gemm_config.in is not found; using default GEMM algo -+model_nametransformer_edeode_layer_t5 -+model num=-1TDL_T59 -+batch_size 1 seq_len 20 layer 12 AVG FT-CPP-time 0.00 ms (1000 iterations) Total Time 0.01 ms -diff --git a/deploy.sh b/deploy.sh -new file mode 100755 -index 0000000..5b0ed1b ---- /dev/null -+++ b/deploy.sh -@@ -0,0 +1,36 @@ -+#copy cuda folder (once) -+base=/home/batya/git-proj/FasterTransformer -+#`git rev-parse --show-toplevel` -+#debug="gdb --args" -+ -+server=pick -+while getopts "d" opt -+do -+case "${opt}" in -+ "d" ) -+ debug="gdb --args" -+ shift -+ ;; -+esac -+done -+file=/home/batya/git-proj/FasterTransformer/build/bin/ms_benchmark -+#`realpath $1` -+shift -+rsync -v ${file} ${server}:${file} -+echo "file=${file}" -+# rsync -Iv ${base}/../mindspore/trc/pangu/*.fp* ${server}:${base}/build/bin -+rm -f ${server}:${base}/build/bin/*decoder* -+ -+# rsync -v ${base}/../mindspore/trc/pangu/transformer_decoder_layer_weight*.fp16 ${server}:${base}/build/bin -+rsync -v ${base}/../mindspore/trc/transformer/*transformer_decoder_layer* ${server}:${base}/build/bin -+rsync -v ${base}/build/lib/*.so ${server}:${base}/build/lib -+# echo "cd ${base}/build/bin/" -+ -+command=$(cat <<-ENDM -+ CUDA_VISIBLE_DEVICES=5 \ -+ LD_LIBRARY_PATH=${base}/../FasterTransformer:/usr/local/cuda-11.7/lib64 \ -+ ${debug} ${file} $@ -+ENDM -+) -+echo "command=${command}" -+ssh ${server} "cd ${base}/build/bin ;${command}" -diff --git a/deploy.trc b/deploy.trc -new file mode 100644 -index 0000000..023cccf ---- /dev/null -+++ b/deploy.trc -@@ -0,0 +1,303 @@ -+ms_benchmark -+ -+sent 848,395 bytes received 22,547 bytes 580,628.00 bytes/sec -+total size is 14,046,536 speedup is 16.13 -+file=/home/shira/git-proj/FasterTransformer/build/bin/ms_benchmark -+T5_input1.fp32 -+T5_input2.fp32 -+T5_input3.fp32 -+T5_input4.fp32 -+T5_input5.fp32 -+T5_output1.fp32 -+mha_T5_input1.fp32 -+mha_T5_input2.fp32 -+mha_T5_input3.fp32 -+mha_T5_output1.fp32 -+mha_T5_weight1.fp32 -+mha_T5_weight2.fp32 -+mha_cross_input1.fp32 -+mha_cross_input2.fp32 -+mha_cross_input3.fp32 -+mha_cross_output1.fp32 -+mha_cross_output2.fp32 -+mha_cross_output3.fp32 -+mha_cross_weight1.fp32 -+mha_cross_weight2.fp32 -+mha_cross_weight3.fp32 -+mha_cross_weight4.fp32 -+mha_cross_weight5.fp32 -+mha_x1_input1.fp32 -+mha_x1_input2.fp32 -+mha_x1_output1.fp32 -+mha_x1_output2.fp32 -+mha_x1_output3.fp32 -+mha_x1_weight1.fp32 -+mha_x1_weight2.fp32 
-+mha_x1_weight3.fp32 -+mha_x1_weight4.fp32 -+test_input1.fp32 -+transformer_decoder_layer_input1.fp32 -+transformer_decoder_layer_input2.fp32 -+transformer_decoder_layer_input3.fp32 -+transformer_decoder_layer_input4.fp32 -+transformer_decoder_layer_output1.fp32 -+transformer_decoder_layer_t5_input1.fp32 -+transformer_decoder_layer_t5_input2.fp32 -+transformer_decoder_layer_t5_input3.fp32 -+transformer_decoder_layer_t5_input4.fp32 -+transformer_decoder_layer_t5_input5.fp32 -+transformer_decoder_layer_t5_input6.fp32 -+transformer_decoder_layer_t5_output1.fp32 -+transformer_decoder_layer_t5_weight1.fp32 -+transformer_decoder_layer_t5_weight10.fp16 -+transformer_decoder_layer_t5_weight2.fp32 -+transformer_decoder_layer_t5_weight3.fp32 -+transformer_decoder_layer_t5_weight4.fp32 -+transformer_decoder_layer_t5_weight5.fp32 -+transformer_decoder_layer_t5_weight6.fp32 -+transformer_decoder_layer_t5_weight7.fp32 -+transformer_decoder_layer_t5_weight8.fp32 -+transformer_decoder_layer_t5_weight9.fp16 -+transformer_decoder_layer_weight1.fp32 -+transformer_decoder_layer_weight10.fp32 -+transformer_decoder_layer_weight11.fp32 -+transformer_decoder_layer_weight12.fp32 -+transformer_decoder_layer_weight13.fp32 -+transformer_decoder_layer_weight14.fp32 -+transformer_decoder_layer_weight15.fp32 -+transformer_decoder_layer_weight16.fp32 -+transformer_decoder_layer_weight17.fp32 -+transformer_decoder_layer_weight18.fp32 -+transformer_decoder_layer_weight19.fp32 -+transformer_decoder_layer_weight2.fp32 -+transformer_decoder_layer_weight3.fp32 -+transformer_decoder_layer_weight4.fp32 -+transformer_decoder_layer_weight5.fp32 -+transformer_decoder_layer_weight6.fp32 -+transformer_decoder_layer_weight7.fp32 -+transformer_decoder_layer_weight8.fp32 -+transformer_decoder_layer_weight9.fp32 -+transformer_encoder_layer_input1.fp32 -+transformer_encoder_layer_input2.fp32 -+transformer_encoder_layer_output1.fp32 -+transformer_encoder_layer_t5_input1.fp32 -+transformer_encoder_layer_t5_input2.fp32 -+transformer_encoder_layer_t5_input3.fp32 -+transformer_encoder_layer_t5_output1.fp32 -+transformer_encoder_layer_t5_weight1.fp32 -+transformer_encoder_layer_t5_weight2.fp32 -+transformer_encoder_layer_t5_weight3.fp32 -+transformer_encoder_layer_t5_weight4.fp32 -+transformer_encoder_layer_t5_weight5.fp32 -+transformer_encoder_layer_t5_weight6.fp32 -+transformer_encoder_layer_weight1.fp32 -+transformer_encoder_layer_weight10.fp32 -+transformer_encoder_layer_weight11.fp32 -+transformer_encoder_layer_weight12.fp32 -+transformer_encoder_layer_weight2.fp32 -+transformer_encoder_layer_weight3.fp32 -+transformer_encoder_layer_weight4.fp32 -+transformer_encoder_layer_weight5.fp32 -+transformer_encoder_layer_weight6.fp32 -+transformer_encoder_layer_weight7.fp32 -+transformer_encoder_layer_weight8.fp32 -+transformer_encoder_layer_weight9.fp32 -+ -+sent 99,079 bytes received 141,615 bytes 481,388.00 bytes/sec -+total size is 41,715,056 speedup is 173.31 -+libtransformer-shared.so -+ -+sent 2,297,603 bytes received 70,578 bytes 947,272.40 bytes/sec -+total size is 101,408,056 speedup is 42.82 -+command= CUDA_VISIBLE_DEVICES=5 LD_LIBRARY_PATH=/home/shira/git-proj/FasterTransformer/../FasterTransformer:/usr/local/cuda-11.7/lib64 /home/shira/git-proj/FasterTransformer/build/bin/ms_benchmark -b 1 -l 12 -H 2 -S 8 -s 20 -f 32 -x 0 -P 1 -m transformer_decoder_layer_t5 -+[INFO] Device: NVIDIA A100-PCIE-40GB -+[WARNING] gemm_config.in is not found; using default GEMM algo -+model_nametransformer_decoder_layer_t5 -+model num=9TDL_T59 -+ffn hidden 
size= 32hidden_units= 8opt_a->hidden_size= 8InitWeight -+model_nametransformer_decoder_layer_t5 -+forward -+0 -+i: 0.458147, -+ -+-0.0708623,0.428798,0.185046,0.057154,-0.203372,0.412952,0.149334, -+1 -+i: 1, -+ -+1,1,1,1,1,1,1, -+2 -+i: -0.00567709, -+ -+0.00275152,-0.0183454,-0.00663061,0.00240306,-0.000401073,-0.0296616,-0.00454295, -+3 -+i: 0.042776, -+ -+-0.184324,0.159665,0.0819166,-0.551958,0.424301,-0.107998,0.906672, -+4 -+i: -0.509908, -+ -+-0.200033,0.462621,0.674143,0.119592,-0.104946,0.62515,-0.689601, -+5 -+i: -0.00660441, -+ -+0.000136423,0.00216104,-0.0194914,-0.0169052,-0.0184758,0.0166858,-0.00725936, -+6 -+i: 1, -+ -+1,1,1,1,1,1,1, -+7 -+i: -0.979686, -+ -+-0.655586,0.619562,-0.208414,-0.406135,-0.363263,0.452679,0.711803, -+8 -+i: 0.00186535, -+ -+-0.000453893,-0.00868625,0.0019144,0.00439803,0.00553658,-0.0059722,0.00984536, -+9 -+i: -0.0151223, -+ -+0.00973828,0.00486344,0.00208442,-0.00624479,0.0143989,0.00552422,0.00341345, -+10 -+i: -0.411828, -+ -+0.689115,0.670348,-0.657296,0.389433,-0.099893,0.639807,-0.0361588, -+11 -+i: 0.140165, -+ -+0.147417,0.453423,0.41165,0.410601,0.353489,0.748643,0.34843, -+12 -+i: 0.0161316, -+ -+0.00175404,-0.00734864,-0.0212876,-0.0123986,-0.00419559,0.00115276,0.00576639, -+13 -+i: 1, -+ -+1,1,1,1,1,1,1, -+14 -+i: 1.38904e-28, -+ -+7.80033e-34,2.25142e-20,-3.8139e-20,-8.21938e-19,-2.85581e-23,-8.682e-19,-1.04875e-19, -+15 -+i: 9.3225e-21, -+ -+3.67653e-27,8.89747e-23,5.8367e-25,4.46003e-20,-1.02178e-23,1.33927e-18,-1.14603e-18, -+0001001 -+tensor 1.22719, -+ -+-1.08314,1.09902,0.0344878,-0.524054,-1.66184,1.02982,-0.121479, -+tensor -0.00567709, -+ -+0.00275152,-0.0183454,-0.00663061,0.00240306,-0.000401073,-0.0296616,-0.00454295, -+tensor 0.042776, -+ -+-0.184324,0.159665,0.0819166,-0.551958,0.424301,-0.107998,0.906672, -+tensor -0.509908, -+ -+-0.200033,0.462621,0.674143,0.119592,-0.104946,0.62515,-0.689601, -+tensor -0.00660441, -+ -+0.000136423,0.00216104,-0.0194914,-0.0169052,-0.0184758,0.0166858,-0.00725936, -+tensor 1, -+ -+1,1,1,1,1,1,1, -+not cross -+weight_qkv -0.00567709, -+ -+0.00275152,-0.0183454,-0.00663061,0.00240306,-0.000401073,-0.0296616,-0.00454295, -+from_tensor 1.22719, -+ -+-1.08314,1.09902,0.0344878,-0.524054,-1.66184,1.02982,-0.121479, -+qkv_buf -0.034644, -+ -+-0.0221119,-0.0255647,-0.0296132,0.039853,-0.0505666,-0.0563698,-0.0394761, -+output1 -0.0189542, -+ -+-0.030236,-0.00238584,0.0115228,-0.00333089,-0.0369255,-0.0199997,0.0257213, -+q_buf_2 -0.034644, -+ -+-0.0221119,-0.0255647,-0.0296132,-0.0132158,0.0133059,-0.0213238,-0.0197066, -+output2 0.0174091, -+ -+-0.0174263,-0.00708923,-0.0546997,0.0247199,0.0332275,0.00747891,-0.0340745, -+output1 -0.0189542, -+ -+-0.030236,-0.00238584,0.0115228,-0.00333089,-0.0369255,-0.0199997,0.0257213, -+qk_buf 0.00104488, -+ -+0.000681124,-0.000605213,0.00103154,0.00169038,-0.00193581,-0.00120881,0.00293503, -+attention_mask 0.042776, -+ -+-0.184324,0.159665,0.0819166,-0.551958,0.424301,-0.107998,0.906672, -+position_bias -0.509908, -+ -+-0.200033,0.462621,0.674143,0.119592,-0.104946,0.62515,-0.689601, -+qk_buf 0, -+ -+0,0,0,0,0,0,0, -+qkv_buf_2 0.0560303, -+ -+-0.0145187,0.00715256,-0.0731201,0.0189056,-0.00694656,-0.00557709,0.0466919, -+qkv_buf_3 0.0560303, -+ -+-0.0145187,0.00715256,-0.0731201,0.00761032,-0.0180664,0.0290527,-0.0121078, -+param->in_idx5 -+output[0] 0.000853663, -+ -+-0.00158517,0.000399898,-0.00118085,-0.000752181,-0.000945039,0.000822433,0.00073974, -+tensor 1.22736, -+ -+-1.08373,1.09882,0.0334885,-0.52422,-1.66137,1.0301,-0.120446, -+tensor 
-0.979686, -+ -+-0.655586,0.619562,-0.208414,-0.406135,-0.363263,0.452679,0.711803, -+tensor 0.00186535, -+ -+-0.000453893,-0.00868625,0.0019144,0.00439803,0.00553658,-0.0059722,0.00984536, -+tensor -0.0151223, -+ -+0.00973828,0.00486344,0.00208442,-0.00624479,0.0143989,0.00552422,0.00341345, -+tensor -0.411828, -+ -+0.689115,0.670348,-0.657296,0.389433,-0.099893,0.639807,-0.0361588, -+tensor 0.140165, -+ -+0.147417,0.453423,0.41165,0.410601,0.353489,0.748643,0.34843, -+is cross -+output1 0.0290729, -+ -+-0.0301254,0.023922,0.0158201,-0.0092989,0.00192484,-0.00540288,-0.0149312, -+qk_buf 0.00079844, -+ -+-0.000482788,-0.000210089,0.000353139,0.000915573,0.000227468,0.000344165,-0.000868972, -+attention_mask -0.411828, -+ -+0.689115,0.670348,-0.657296,0.389433,-0.099893,0.639807,-0.0361588, -+position_bias 0.140165, -+ -+0.147417,0.453423,0.41165,0.410601,0.353489,0.748643,0.34843, -+qk_buf 0, -+ -+0,0,0,0,0,0,0, -+qkv_buf_2 -0.00668335, -+ -+0.0517578,0.00294304,-0.0211945,0.0103607,-0.0372925,0.012619,0.0177155, -+qkv_buf_3 -0.00668335, -+ -+0.0517578,0.00294304,-0.0211945,0.00642776,0.0122375,0.022171,-0.00510406, -+param->in_idx7 -+output[0] -0.000111274, -+ -+-0.00114325,-0.000122436,0.000859971,0.000143449,-0.000135836,-0.000207918,-0.000205797, -+gamma3 1, -+ -+1,1,1,1,1,1,1, -+normed_attn2_out 1.22722, -+ -+-1.08463,1.09869,0.034459,-0.523898,-1.6612,1.02988,-0.120521, -+attn2_out 1.22794, -+ -+-1.08587,1.09929,0.0341669,-0.524663,-1.66292,1.03043,-0.120945, -+attn_out 1.22805, -+ -+-1.08472,1.09942,0.0333069,-0.524806,-1.66279,1.03064,-0.120739, -+13 -diff --git a/docs/gpt_guide.md b/docs/gpt_guide.md -index afcba9a..71c4fab 100644 ---- a/docs/gpt_guide.md -+++ b/docs/gpt_guide.md -@@ -312,7 +312,7 @@ python tools/checkpoint_util.py --model-type GPT --loader megatron --saver faste - To convert the Megatron GPT model to binary, FasterTransformer provides a tool `examples/onnx/multi_gpu_gpt/onnx_ckpt_convert.py` to convert the checkpoint. - - ```bash --wget https://github.com/onnx/models/raw/master/text/machine_comprehension/gpt-2/model/gpt2-10.onnx -+wget https://github.com/onnx/models/raw/main/text/machine_comprehension/gpt-2/model/gpt2-10.onnx - python ../examples/onnx/multi_gpu_gpt/onnx_ckpt_convert.py -i gpt2-10.onnx -o ../models/onnx-models/c-model/124m/ -i_g 1 - python ../examples/onnx/multi_gpu_gpt/onnx_ckpt_convert.py -i gpt2-10.onnx -o ../models/onnx-models/c-model/124m/ -i_g 4 - ``` -diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt -index b67cd01..3cc4155 100644 ---- a/examples/cpp/CMakeLists.txt -+++ b/examples/cpp/CMakeLists.txt -@@ -13,6 +13,7 @@ - # limitations under the License. 
- - add_subdirectory(bert) -+add_subdirectory(ms) - add_subdirectory(bert_int8) - add_subdirectory(decoding) - add_subdirectory(gpt) -diff --git a/examples/cpp/gpt/gpt_example.cc b/examples/cpp/gpt/gpt_example.cc -index cacb09e..5fec0c9 100644 ---- a/examples/cpp/gpt/gpt_example.cc -+++ b/examples/cpp/gpt/gpt_example.cc -@@ -236,7 +236,7 @@ void gpt_example(const INIReader reader) - #endif - - if (std::is_same::value) { -- cublas_wrapper.setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F); -+ cublas_wrapper.setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUBLAS_COMPUTE_32F_FAST_TF32); - } - #ifdef ENABLE_BF16 - else if (std::is_same::value) { -diff --git a/examples/cpp/ms/CMakeLists.txt b/examples/cpp/ms/CMakeLists.txt -new file mode 100644 -index 0000000..33e562b ---- /dev/null -+++ b/examples/cpp/ms/CMakeLists.txt -@@ -0,0 +1,22 @@ -+# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. -+ -+add_executable(ms_benchmark ms.cc) -+if (SPARSITY_SUPPORT) -+# target_link_libraries(ms_benchmark PUBLIC -lcublas -lcublasLt -lcudart -lcusparse -lcusparseLt transformer-shared) -+target_link_libraries(ms_benchmark PUBLIC -lcublas -lcublasLt -lcudart -lcusparse -lcusparseLt GptContextAttentionLayer MSLayer) -+else() -+# target_link_libraries(ms_benchmark PUBLIC -lcublas -lcublasLt -lcudart transformer-shared) -+target_link_libraries(ms_benchmark PUBLIC -lcublas -lcublasLt -lcudart GptContextAttentionLayer MSLayer) -+endif() -diff --git a/examples/cpp/ms/initialize.h b/examples/cpp/ms/initialize.h -new file mode 100644 -index 0000000..9e72838 ---- /dev/null -+++ b/examples/cpp/ms/initialize.h -@@ -0,0 +1,746 @@ -+#pragma once -+ -+#include "src/fastertransformer/layers/ms_layers/MSAttentionLayer.h" -+#include "src/fastertransformer/layers/ms_layers/MSBaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/MSDecoderLayer.h" -+#include "src/fastertransformer/layers/ms_layers/MSEncoderLayer.h" -+#include "src/fastertransformer/layers/ms_layers/MSLayerWeight.h" -+using namespace fastertransformer; -+ -+template -+struct Decriptor { -+ std::vector input_tensors; // GPU -+ std::vector input_python_tensors; // CPU -+ std::vector output_tensors; // GPU -+ std::vector output_python_tensors; // CPU -+ std::vector w_tensors; -+ MSBaseLayer* MSLayer; -+}; -+ -+template -+void InitializeAttn(opt_arg* opt_a, -+ Decriptor& desc, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ Allocator* allocator) -+{ -+ const size_t hidden_units = opt_a->head_num * opt_a->size_per_head; -+ // TODO Nizzan - check if need to be -+ desc.MSLayer = new MSMHALayer(opt_a->batch_size, -+ opt_a->seq_len, -+ opt_a->tgt_seq_len, -+ opt_a->head_num, -+ opt_a->size_per_head, -+ stream, -+ cublas_wrapper, -+ cublas_handle, -+ allocator, -+ false, // free buffer after fwd -+ true, // is_qk_buf_float_ -+ false, // is_cross -+ false, // sparse -+ false); // is_position_bias -+ -+ 
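The gpt_example.cc hunk above replaces the CUDA_R_32F compute type passed to setGemmConfig() with CUBLAS_COMPUTE_32F_FAST_TF32, which allows cuBLAS to use TF32 tensor cores while accumulating in FP32. A minimal standalone sketch, independent of FasterTransformer's cublasMMWrapper, of how such a compute type is handed to cublasGemmEx() on CUDA 11+; FP32 storage is used here for simplicity and error checking is omitted:

```cpp
// Standalone sketch: FP32 GEMM through cublasGemmEx with TF32 math allowed
// via the compute type, analogous to the setGemmConfig() change above.
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <vector>
#include <cstdio>

int main() {
    const int m = 64, n = 64, k = 64;
    std::vector<float> hA(m * k, 1.0f), hB(k * n, 1.0f), hC(m * n, 0.0f);

    float *dA, *dB, *dC;
    cudaMalloc(&dA, hA.size() * sizeof(float));
    cudaMalloc(&dB, hB.size() * sizeof(float));
    cudaMalloc(&dC, hC.size() * sizeof(float));
    cudaMemcpy(dA, hA.data(), hA.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, hB.data(), hB.size() * sizeof(float), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);

    const float alpha = 1.0f, beta = 0.0f;
    // Column-major C = alpha*A*B + beta*C, FP32 storage, FP32 accumulation
    // with TF32 rounding permitted on Ampere and newer GPUs.
    cublasStatus_t st = cublasGemmEx(handle,
                                     CUBLAS_OP_N, CUBLAS_OP_N,
                                     m, n, k,
                                     &alpha,
                                     dA, CUDA_R_32F, m,
                                     dB, CUDA_R_32F, k,
                                     &beta,
                                     dC, CUDA_R_32F, m,
                                     CUBLAS_COMPUTE_32F_FAST_TF32,
                                     CUBLAS_GEMM_DEFAULT);
    cudaMemcpy(hC.data(), dC, hC.size() * sizeof(float), cudaMemcpyDeviceToHost);
    printf("status=%d C[0]=%f (expected %d)\n", (int)st, hC[0], k);

    cublasDestroy(handle);
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    return 0;
}
```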
desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->seq_len, hidden_units}, 0}); -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->seq_len}, 0}); -+ -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->seq_len, hidden_units}, 0}); -+ -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->seq_len}, 0}); -+ desc.output_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); -+ -+ desc.output_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, 3 * hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{3 * hidden_units}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units}, 0}); -+} -+template -+void InitializeAttnX2(opt_arg* opt_a, -+ Decriptor& desc, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ Allocator* allocator) -+{ -+ const size_t hidden_units = opt_a->head_num * opt_a->size_per_head; -+ -+ desc.MSLayer = new MSMHALayer(opt_a->batch_size, -+ opt_a->seq_len, -+ opt_a->tgt_seq_len, -+ opt_a->head_num, -+ opt_a->size_per_head, -+ stream, -+ cublas_wrapper, -+ cublas_handle, -+ allocator, -+ false, // free buffer after fwd -+ true, // is_qk_buf_float_ -+ false, // is_cross -+ false, // sparse -+ false); // is_position_bias -+ -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->seq_len, hidden_units}, 0}); -+ -+ // GPU RESULTS -+ desc.output_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); -+ desc.output_tensors.push_back( -+ Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->size_per_head}, -+ 0}); -+ desc.output_tensors.push_back( -+ Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->size_per_head}, -+ 0}); -+ -+ desc.output_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); -+ desc.output_python_tensors.push_back( -+ Tensor{MEMORY_CPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->size_per_head}, -+ 0}); -+ desc.output_python_tensors.push_back( -+ Tensor{MEMORY_CPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->size_per_head}, -+ 0}); -+ -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{3 * hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, 2 * hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); -+ 
desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units}, 0}); -+} -+ -+template -+void InitializeAttnCross(opt_arg* opt_a, -+ Decriptor& desc, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ Allocator* allocator) -+{ -+ const size_t hidden_units = opt_a->head_num * opt_a->size_per_head; -+ -+ desc.MSLayer = new MSMHALayer(opt_a->batch_size, -+ opt_a->seq_len, -+ opt_a->tgt_seq_len, -+ opt_a->head_num, -+ opt_a->size_per_head, -+ stream, -+ cublas_wrapper, -+ cublas_handle, -+ allocator, -+ false, // free buffer after fwd -+ true, // is_qk_buf_float_ -+ true, // is_cross -+ false, // sparse -+ false); // is_position_bias -+ -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->seq_len, hidden_units}, 0}); -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->tgt_seq_len, hidden_units}, 0}); -+ -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->tgt_seq_len}, 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->seq_len, hidden_units}, 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->tgt_seq_len, hidden_units}, 0}); -+ -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->tgt_seq_len}, 0}); -+ -+ // GPU RESULTS -+ -+ desc.output_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); -+ // desc.output_tensors.push_back(Tensor{ -+ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ // desc.output_tensors.push_back(Tensor{ -+ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ -+ desc.output_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); -+ // desc.output_python_tensors.push_back(Tensor{ -+ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ // desc.output_python_tensors.push_back(Tensor{ -+ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{3 * hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, 2 * hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units}, 0}); -+} -+template -+void InitializeAttnT5(opt_arg* opt_a, -+ Decriptor& desc, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ Allocator* allocator) -+{ -+ const size_t hidden_units = opt_a->head_num * opt_a->size_per_head; -+ -+ desc.MSLayer = new MSMHALayer(opt_a->batch_size, -+ opt_a->seq_len, -+ 
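For orientation, a short sketch that re-derives the projection-weight shapes the initializers above push into desc.w_tensors: the self-attention path fuses Q, K and V into one {hidden, 3*hidden} matrix, while the cross-attention path keeps a separate {hidden, hidden} Q projection and a fused {hidden, 2*hidden} K/V projection. The values below are illustrative; head_num = 2 and size_per_head = 4 correspond to the -H 2 -S 8 run recorded in deploy.trc.

```cpp
// Re-derives the weight shapes used by InitializeAttn / InitializeAttnCross above.
#include <cstdio>

int main() {
    const int head_num = 2, size_per_head = 4;
    const int hidden = head_num * size_per_head;   // hidden_units in the initializers

    printf("self-attention : fused QKV weight [%d x %d]\n", hidden, 3 * hidden);
    printf("self-attention : output proj      [%d x %d]\n", hidden, hidden);
    printf("cross-attention: Q weight         [%d x %d]\n", hidden, hidden);
    printf("cross-attention: fused KV weight  [%d x %d]\n", hidden, 2 * hidden);
    printf("cross-attention: output proj      [%d x %d]\n", hidden, hidden);
    return 0;
}
```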
opt_a->tgt_seq_len, -+ opt_a->head_num, -+ opt_a->size_per_head, -+ stream, -+ cublas_wrapper, -+ cublas_handle, -+ allocator, -+ false, // free buffer after fwd -+ true, // is_qk_buf_float_ -+ false, // is_cross -+ false, // sparse -+ true); // is_position_bias -+ -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->seq_len, hidden_units}, 0}); -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->seq_len}, 0}); -+ -+ desc.input_tensors.push_back( -+ Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->seq_len, hidden_units}, 0}); -+ -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->seq_len}, 0}); -+ -+ desc.input_python_tensors.push_back( -+ Tensor{MEMORY_CPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ -+ // GPU RESULTS -+ -+ desc.output_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); -+ // desc.output_tensors.push_back(Tensor{ -+ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ // desc.output_tensors.push_back(Tensor{ -+ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ // desc.output_tensors.push_back(Tensor{ -+ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, -+ // opt_a->tgt_seq_len},0}); -+ -+ desc.output_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); -+ // desc.output_python_tensors.push_back(Tensor{ -+ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ // desc.output_python_tensors.push_back(Tensor{ -+ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ // desc.output_python_tensors.push_back(Tensor{ -+ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, -+ // opt_a->tgt_seq_len}, 0}); -+ -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, 3 * hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); -+} -+ -+template -+void InitializeAttnT5Cross(opt_arg* opt_a, -+ Decriptor& desc, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ Allocator* allocator) -+{ -+ const size_t hidden_units = opt_a->head_num * opt_a->size_per_head; -+ -+ desc.MSLayer = new MSMHALayer(opt_a->batch_size, -+ opt_a->seq_len, -+ opt_a->tgt_seq_len, -+ opt_a->head_num, -+ opt_a->size_per_head, -+ stream, -+ cublas_wrapper, -+ cublas_handle, -+ allocator, -+ false, // free buffer after fwd -+ true, // is_qk_buf_float_ -+ true, // is_cross -+ false, // sparse -+ true); // is_position_bias -+ -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), 
std::vector{opt_a->batch_size * opt_a->seq_len, hidden_units}, 0}); -+ -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->tgt_seq_len, hidden_units}, 0}); -+ -+ desc.input_tensors.push_back(Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ -+ desc.input_tensors.push_back( -+ Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->seq_len, hidden_units}, 0}); -+ -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->tgt_seq_len, hidden_units}, 0}); -+ -+ desc.input_python_tensors.push_back( -+ Tensor{MEMORY_CPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ -+ desc.input_python_tensors.push_back( -+ Tensor{MEMORY_CPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ -+ // GPU RESULTS -+ -+ desc.output_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); -+ // desc.output_tensors.push_back(Tensor{ -+ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ // desc.output_tensors.push_back(Tensor{ -+ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ // desc.output_tensors.push_back(Tensor{ -+ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, -+ // opt_a->tgt_seq_len},0}); -+ -+ desc.output_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); -+ // desc.output_python_tensors.push_back(Tensor{ -+ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ // desc.output_python_tensors.push_back(Tensor{ -+ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ // desc.output_python_tensors.push_back(Tensor{ -+ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, -+ // opt_a->tgt_seq_len}, 0}); -+ -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, 2 * hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); -+} -+ -+template -+void InitializeEncoder(opt_arg* opt_a, -+ Decriptor& desc, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ Allocator* allocator) -+{ -+ // const size_t hidden_units = opt_a->head_num * opt_a->size_per_head; -+ const size_t hidden_units = opt_a->hidden_size; -+ // TODO Nizzan - check if need to be -+ desc.MSLayer = new MSELayer(opt_a->batch_size, -+ opt_a->seq_len, -+ opt_a->tgt_seq_len, -+ opt_a->head_num, -+ opt_a->size_per_head, -+ opt_a->ffn_hidden_size, -+ opt_a->eps1, -+ opt_a->eps2, -+ opt_a->post_layernorm_residual, -+ false, -+ 
opt_a->is_ffn_fp16, -+ stream, -+ cublas_wrapper, -+ cublas_handle, -+ allocator, -+ false, // free buffer after fwd -+ true, // is_qk_buf_float_ -+ false); // sparse -+ bool compress=false; -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->seq_len}, 0}); -+ if(compress) -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->seq_len}, 0}); -+ if(compress) -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ -+ -+ desc.output_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ -+ desc.output_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, 3 * hidden_units}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{3 * hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); -+ -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); -+ -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size, opt_a->ffn_hidden_size}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->ffn_hidden_size}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->ffn_hidden_size, opt_a->hidden_size}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); -+} -+template -+void InitializeEncoderT5(opt_arg* opt_a, -+ Decriptor& desc, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ Allocator* allocator) -+{ -+ // const size_t hidden_units = opt_a->head_num * opt_a->size_per_head; -+ const size_t hidden_units = opt_a->hidden_size; -+ // TODO Nizzan - check if need to be -+ desc.MSLayer = new MSELayer(opt_a->batch_size, -+ opt_a->seq_len, -+ opt_a->tgt_seq_len, -+ opt_a->head_num, -+ opt_a->size_per_head, -+ opt_a->ffn_hidden_size, -+ opt_a->eps1, -+ opt_a->eps2, -+ opt_a->post_layernorm_residual, -+ true, -+ opt_a->is_ffn_fp16, -+ stream, -+ cublas_wrapper, -+ cublas_handle, -+ allocator, -+ false, // free buffer after fwd -+ true, // is_qk_buf_float_ -+ false); // sparse -+ desc.input_tensors.push_back(Tensor{ -+ 
MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->seq_len}, 0}); -+ desc.input_tensors.push_back( -+ Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ desc.input_tensors.push_back( -+ Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->seq_len}, -+ 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->seq_len}, 0}); -+ desc.input_python_tensors.push_back( -+ Tensor{MEMORY_CPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ desc.input_python_tensors.push_back( -+ Tensor{MEMORY_CPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->seq_len}, -+ 0}); -+ desc.output_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ -+ desc.output_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // g1 -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, 3 * hidden_units}, 0}); // wt -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); // wp -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // g2 -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size, opt_a->ffn_hidden_size}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->ffn_hidden_size, opt_a->hidden_size}, 0}); -+} -+ -+template -+void InitializeDecoder(opt_arg* opt_a, -+ Decriptor& desc, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ Allocator* allocator) -+{ -+ const size_t hidden_units = opt_a->head_num * opt_a->size_per_head; -+ // TODO Nizzan - check if need to be -+ desc.MSLayer = new MSDLayer(opt_a->batch_size, -+ opt_a->seq_len, -+ opt_a->tgt_seq_len, -+ opt_a->head_num, -+ opt_a->size_per_head, -+ opt_a->ffn_hidden_size, -+ opt_a->eps1, -+ opt_a->eps2, -+ opt_a->eps3, -+ opt_a->post_layernorm_residual, -+ false, -+ false, -+ opt_a->is_ffn_fp16, -+ stream, -+ cublas_wrapper, -+ cublas_handle, -+ allocator, -+ false, // free buffer after fwd -+ true, // is_qk_buf_float_ -+ false); // sparse -+ desc.input_tensors.push_back(Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->tgt_seq_len, opt_a->hidden_size}, -+ 0}); -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->seq_len}, 0}); -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->tgt_seq_len, opt_a->seq_len}, 0}); -+ -+ 
desc.input_python_tensors.push_back( -+ Tensor{MEMORY_CPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->tgt_seq_len, opt_a->hidden_size}, -+ 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->seq_len}, 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->tgt_seq_len, opt_a->seq_len}, 0}); -+ -+ desc.output_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ -+ desc.output_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // G1 -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // B1 -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, 3 * hidden_units}, 0}); // wt -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{3 * hidden_units}, 0}); // bt -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); // wp -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units}, 0}); // bp -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // g1 -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // b2 -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units * 2}, 0}); // bt2 -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units * 3}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); // wp2 -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units}, 0}); // bp2 -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // g3 -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // b3 -+ desc.w_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size, opt_a->ffn_hidden_size}, 0}); // wm -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->ffn_hidden_size}, 0}); // bm -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size, opt_a->ffn_hidden_size}, 0}); -+ ; // wp -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // bp -+} -+template -+void InitializeDecoderT5(opt_arg* opt_a, -+ Decriptor& desc, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ Allocator* allocator) -+{ -+ const size_t hidden_units = opt_a->head_num * opt_a->size_per_head; -+ // TODO Nizzan - check if need to be -+ desc.MSLayer = new MSDLayer(opt_a->batch_size, -+ opt_a->seq_len, -+ opt_a->tgt_seq_len, -+ opt_a->head_num, -+ opt_a->size_per_head, -+ 
opt_a->ffn_hidden_size, -+ opt_a->eps1, -+ opt_a->eps2, -+ opt_a->eps3, -+ opt_a->post_layernorm_residual, -+ true, -+ true, -+ opt_a->is_ffn_fp16, -+ stream, -+ cublas_wrapper, -+ cublas_handle, -+ allocator, -+ false, // free buffer after fwd -+ true, // is_qk_buf_float_ -+ false); // sparse -+ desc.input_tensors.push_back(Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->tgt_seq_len, opt_a->hidden_size}, -+ 0}); -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->seq_len}, 0}); -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->tgt_seq_len, opt_a->seq_len}, 0}); -+ desc.input_tensors.push_back( -+ Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ desc.input_tensors.push_back( -+ Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ -+ desc.input_python_tensors.push_back( -+ Tensor{MEMORY_CPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->tgt_seq_len, opt_a->hidden_size}, -+ 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->seq_len}, 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->tgt_seq_len, opt_a->seq_len}, 0}); -+ desc.input_python_tensors.push_back( -+ Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ desc.input_python_tensors.push_back( -+ Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ -+ desc.output_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ -+ desc.output_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // G1 -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, 3 * hidden_units}, 0}); // wt -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); // wp -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // g1 -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, 2 * hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); // wp2 -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // g3 -+ desc.w_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size, opt_a->ffn_hidden_size}, 0}); 
// wm -+ desc.w_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size, opt_a->ffn_hidden_size}, 0}); // wp -+} -+ -+template -+void Init(opt_arg* opt_a, -+ Decriptor& desc, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ Allocator* allocator) -+{ -+ int model_num = ModelNum(opt_a->model_name); -+ switch (model_num) { -+ case MHA_X1: -+ InitializeAttn(opt_a, desc, stream, cublas_wrapper, cublas_handle, allocator); -+ break; -+ case MHA_X2: -+ InitializeAttnX2(opt_a, desc, stream, cublas_wrapper, cublas_handle, allocator); -+ break; -+ case MHA_CROSS: -+ InitializeAttnCross(opt_a, desc, stream, cublas_wrapper, cublas_handle, allocator); -+ break; -+ case MHA_T5: -+ InitializeAttnT5(opt_a, desc, stream, cublas_wrapper, cublas_handle, allocator); -+ break; -+ case MHA_T5_CROSS: -+ InitializeAttnT5Cross(opt_a, desc, stream, cublas_wrapper, cublas_handle, allocator); -+ break; -+ case TEL: -+ InitializeEncoder(opt_a, desc, stream, cublas_wrapper, cublas_handle, allocator); -+ break; -+ case TEL_T5: -+ InitializeEncoderT5(opt_a, desc, stream, cublas_wrapper, cublas_handle, allocator); -+ break; -+ case TDL: -+ InitializeDecoder(opt_a, desc, stream, cublas_wrapper, cublas_handle, allocator); -+ break; -+ case TDL_T5: -+ InitializeDecoderT5(opt_a, desc, stream, cublas_wrapper, cublas_handle, allocator); -+ break; -+ default: -+ break; -+ } -+} -diff --git a/examples/cpp/ms/ms.cc b/examples/cpp/ms/ms.cc -new file mode 100644 -index 0000000..085407f ---- /dev/null -+++ b/examples/cpp/ms/ms.cc -@@ -0,0 +1,494 @@ -+/* -+ * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. 
-+ */ -+#include "examples/cpp/ms/initialize.h" -+#include "src/fastertransformer/utils/logger.h" -+#include -+#include -+#include -+using namespace fastertransformer; -+ -+template -+int MsExample(opt_arg* opt_a); -+void usage() -+{ -+ std::cout << "Usage: ms_benchmark -b -l " -+ << "-s -H -S -p " -+ << "-T -W -F " -+ << "-m \n"; -+} -+ -+bool read_args(int argc, char* argv[], opt_arg* opt_a) -+{ -+ int opt; -+ while ((opt = getopt(argc, argv, "b:l:s:t:H:S:p:m:T:W:F:i:w:f:P:x:1:2:3")) != -1) { -+ switch (opt) { -+ case 'b': -+ opt_a->batch_size = atoi(optarg); -+ break; -+ case 'l': -+ opt_a->num_layers = atoi(optarg); -+ break; -+ case 's': -+ opt_a->seq_len = atoi(optarg); -+ break; -+ case 't': -+ opt_a->tgt_seq_len = atoi(optarg); -+ break; -+ case 'H': -+ opt_a->head_num = atoi(optarg); -+ break; -+ case 'S': -+ opt_a->hidden_size = atoi(optarg); -+ break; -+ case 'm': -+ opt_a->model_name = std::string(optarg); -+ break; -+ case 'T': -+ opt_a->compute_type = std::string(optarg); -+ break; -+ case 'W': -+ opt_a->w_compute_type = std::string(optarg); -+ break; -+ case 'F': -+ opt_a->s_compute_type = std::string(optarg); -+ break; -+ case 'f': -+ opt_a->ffn_hidden_size = atoi(optarg); -+ break; -+ case '1': -+ opt_a->eps1 = atoi(optarg); -+ break; -+ case '2': -+ opt_a->eps2 = atoi(optarg); -+ break; -+ case '3': -+ opt_a->eps3 = atoi(optarg); -+ break; -+ case 'P': -+ if (atoi(optarg) == 1) -+ opt_a->post_layernorm_residual = true; -+ else if (atoi(optarg) == 0) -+ opt_a->post_layernorm_residual = false; -+ break; -+ case 'p': -+ opt_a->is_remove_padding = bool(optarg); -+ break; -+ case 'x': -+ if (atoi(optarg) == 1) -+ opt_a->is_ffn_fp16 = true; -+ else if (atoi(optarg) == 0) -+ opt_a->is_ffn_fp16 = false; -+ break; -+ case 'i': -+ case 'w': -+ break; -+ case 'h': -+ default: -+ usage(); -+ return false; -+ } -+ } -+ opt_a->size_per_head = opt_a->hidden_size / opt_a->head_num; -+ opt_a->tgt_seq_len = (opt_a->tgt_seq_len == -1) ? 
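read_args() above collects the benchmark flags with getopt() and then derives size_per_head, a default tgt_seq_len and a default ffn_hidden_size. A minimal standalone sketch of the same pattern over a subset of the flags, using the same defaults as main(); note that the epsilon options ('1', '2', '3') are parsed with atoi(), so fractional values given on the command line would be truncated.

```cpp
// Minimal sketch of the option-handling pattern in read_args(); flag letters
// follow the usage() string above, defaults are the ones set in main().
#include <unistd.h>
#include <cstdio>
#include <cstdlib>

int main(int argc, char* argv[]) {
    int batch_size = 1, head_num = 1, hidden_size = 1;
    int seq_len = 1, tgt_seq_len = -1, ffn_hidden_size = -1;
    const int expand_ratio = 4;

    int opt;
    while ((opt = getopt(argc, argv, "b:H:S:s:t:f:")) != -1) {
        switch (opt) {
            case 'b': batch_size = atoi(optarg); break;
            case 'H': head_num = atoi(optarg); break;
            case 'S': hidden_size = atoi(optarg); break;
            case 's': seq_len = atoi(optarg); break;
            case 't': tgt_seq_len = atoi(optarg); break;
            case 'f': ffn_hidden_size = atoi(optarg); break;
            default:
                fprintf(stderr, "usage: %s -b batch -H heads -S hidden -s seq [-t tgt_seq] [-f ffn]\n",
                        argv[0]);
                return 1;
        }
    }
    // Derived values, mirroring the tail of read_args():
    const int size_per_head = hidden_size / head_num;
    if (tgt_seq_len == -1) tgt_seq_len = seq_len;                       // default to self-attention length
    if (ffn_hidden_size == -1) ffn_hidden_size = hidden_size * expand_ratio;

    printf("batch=%d heads=%d hidden=%d size_per_head=%d seq=%d tgt_seq=%d ffn=%d\n",
           batch_size, head_num, hidden_size, size_per_head, seq_len, tgt_seq_len, ffn_hidden_size);
    return 0;
}
```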
opt_a->seq_len : opt_a->tgt_seq_len; -+ if (opt_a->ffn_hidden_size == -1) { -+ opt_a->ffn_hidden_size = opt_a->hidden_size * opt_a->expand_ratio; -+ } -+ return true; -+} -+ -+int main(int argc, char** argv) -+{ -+ opt_arg opt_a; -+ opt_a.batch_size = 1; -+ opt_a.num_layers = 1; -+ opt_a.seq_len = 1; -+ opt_a.tgt_seq_len = -1; -+ opt_a.head_num = 1; -+ opt_a.hidden_size = 1; -+ opt_a.size_per_head = 1; -+ opt_a.expand_ratio = 4; -+ opt_a.ffn_hidden_size = -1; -+ opt_a.eps1 = 1e-6f; -+ opt_a.eps2 = 1e-6f; -+ opt_a.eps3 = 1e-6f; -+ opt_a.post_layernorm_residual = true; -+ opt_a.is_remove_padding = false; -+ opt_a.model_name = ""; -+ opt_a.compute_type = "fp32"; -+ opt_a.w_compute_type = "fp32"; -+ opt_a.s_compute_type = "fp32"; -+ opt_a.is_ffn_fp16 = false; -+ -+ if (read_args(argc, argv, &opt_a)) { -+ bool c_type_fp32 = (opt_a.compute_type.compare("fp32") == 0); -+ bool w_type_fp32 = (opt_a.w_compute_type.compare("fp32") == 0); -+ bool s_type_fp32 = (opt_a.s_compute_type.compare("fp32") == 0); -+ -+ s_type_fp32 = c_type_fp32; // Do softmax compute type as compute type -+ if (c_type_fp32 && w_type_fp32 && s_type_fp32) { -+ return MsExample(&opt_a); -+ } -+ else if (c_type_fp32 && w_type_fp32 && !s_type_fp32) { -+ return MsExample(&opt_a); -+ } -+ else if (c_type_fp32 && !w_type_fp32 && s_type_fp32) { -+ return MsExample(&opt_a); -+ } -+ else if (c_type_fp32 && !w_type_fp32 && !s_type_fp32) { -+ return MsExample(&opt_a); -+ } -+ else if (!c_type_fp32 && w_type_fp32 && s_type_fp32) { -+ return MsExample(&opt_a); -+ } -+ else if (!c_type_fp32 && w_type_fp32 && !s_type_fp32) { -+ return MsExample(&opt_a); -+ } -+ else if (!c_type_fp32 && !w_type_fp32 && s_type_fp32) { -+ return MsExample(&opt_a); -+ } -+ else { // (!c_type_fp32 && !w_type_fp32 && !s_type_fp32) -+ return MsExample(&opt_a); -+ } -+ } -+} -+ -+template -+int ReadFileBuf(const std::string file, T* buf, size_t size_buff) -+{ -+ if (file.empty()) { -+ FT_LOG_ERROR("file is nullptr\n"); -+ return -1; -+ } -+ -+ std::ifstream ifs(file); -+ if (!ifs.good()) { -+ FT_LOG_ERROR("file: %s does not exist\n", file.c_str()); -+ return -1; -+ } -+ -+ if (!ifs.is_open()) { -+ FT_LOG_ERROR("file: open failed\n"); -+ return -1; -+ } -+ -+ ifs.seekg(0, std::ios::end); -+ size_t file_size = ifs.tellg(); -+ if (file_size != size_buff) { -+ ifs.close(); -+ FT_LOG_ERROR("file: %s size is %d desc size is %d\n", file.c_str(), file_size, size_buff); -+ return -1; -+ } -+ // return 0; -+ ifs.seekg(0, std::ios::beg); -+ ifs.read(reinterpret_cast(buf), size_buff); -+ ifs.close(); -+ return 0; -+} -+ -+template -+int CalcTensorsSize(std::vector& tensors) -+{ -+ int total = 0; -+ for (size_t i = 0; i < tensors.size(); i++) { -+ float size = 1; -+ for (size_t j = 0; j < tensors[i].shape.size(); j++) { -+ size *= tensors[i].shape[j]; -+ } -+ total += size; -+ } -+ -+ return total * sizeof(T); -+} -+ -+template -+int ReadTensors(std::vector& tensors, std::string post, opt_arg* opt_a, bool cpy = true) -+{ -+ for (size_t i = 0; i < tensors.size(); i++) { -+ // if (tensors[i].type != TYPE_FP32) { -+ // FT_LOG_ERROR("Type not supported, exiting "); -+ // return -1; -+ // } -+ float size = 1; -+ for (size_t j = 0; j < tensors[i].shape.size(); j++) { -+ size *= tensors[i].shape[j]; -+ } -+ std::string suffix = post.compare("weight") == 0 ? opt_a->w_compute_type : opt_a->compute_type; -+ std::string fn = opt_a->model_name + "_" + post + std::to_string(i + 1) + "." 
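The tensor dumps listed in deploy.ch and deploy.trc are raw float32 blobs that ReadFileBuf() above loads byte-for-byte after checking the file size against the tensor's byte count. A standalone sketch of the same pattern, with a hypothetical file name and element count and the stream opened in binary mode:

```cpp
// Sketch of the ReadFileBuf() pattern: read a raw .fp32 blob into a pre-sized
// buffer and reject it when its size does not match the expected tensor size.
#include <fstream>
#include <string>
#include <vector>
#include <cstdio>

static bool ReadRawFloats(const std::string& path, std::vector<float>& buf) {
    std::ifstream ifs(path, std::ios::binary);
    if (!ifs.good()) {
        fprintf(stderr, "cannot open %s\n", path.c_str());
        return false;
    }
    ifs.seekg(0, std::ios::end);
    const std::streamsize file_size = ifs.tellg();
    const std::streamsize want = static_cast<std::streamsize>(buf.size() * sizeof(float));
    if (file_size != want) {   // same size check ReadFileBuf() performs
        fprintf(stderr, "%s: size %lld, expected %lld\n", path.c_str(),
                (long long)file_size, (long long)want);
        return false;
    }
    ifs.seekg(0, std::ios::beg);
    ifs.read(reinterpret_cast<char*>(buf.data()), want);
    return ifs.good();
}

int main() {
    std::vector<float> buf(1 * 20 * 8);   // illustrative: batch 1, seq 20, hidden 8
    if (ReadRawFloats("transformer_decoder_layer_t5_input1.fp32", buf)) {
        printf("first element: %f\n", buf[0]);
    }
    return 0;
}
```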
+ suffix; -+ T* input; -+ T* input_host = (T*)malloc(size * sizeof(T)); -+ int res = ReadFileBuf(fn, input_host, size * sizeof(T)); -+ if (res) { -+ fn = opt_a->model_name + "_" + post + std::to_string(i + 1) + "." + "fp16"; -+ res = ReadFileBuf(fn, input_host, size * 2); -+ } -+ FT_CHECK(!res); -+ if (tensors[i].where == MEMORY_GPU) { -+ deviceMalloc(&input, size, false); -+ if (cpy) -+ cudaH2Dcpy(input, input_host, size); -+ else -+ deviceMemSetZero(input, size); -+ tensors[i].data = input; -+ free(input_host); -+ input_host = 0; -+ } -+ else if (tensors[i].where == MEMORY_CPU) { -+ tensors[i].data = input_host; -+ } -+ } -+ return 0; -+} -+ -+template -+static float CompareData(const T* refOutput, int size, const T* msTensorData) -+{ -+ constexpr float relativeTolerance = 1e-5; -+ constexpr float absoluteTolerance = 1e-8; -+ size_t errorCount = 0; -+ float meanError = 0; -+ std::cout << "Out tensor size is: " << size << std::endl; -+ std::cout << "Data of model output: "; -+ static int x = 0; -+ int s = std::min(10, size); -+ if (x == 0) { -+ for (int j = 0; j < std::min(50, size); j++) { -+ std::cout << static_cast(msTensorData[j]) << " "; -+ } -+ std::cout << std::endl; -+ std::cout << "Data of Ref output : "; -+ for (int j = 0; j < std::min(50, size); j++) { -+ std::cout << static_cast(refOutput[j]) << " "; -+ } -+ std::cout << std::endl; -+ } -+ x++; -+ int nan_cnt = 0; -+ for (int j = 0; j < size; j++) { -+ if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) { -+ // std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl; -+ // FT_LOG_ERROR("Output tensor has nan or inf data, compare fail\n"); -+ // return RET_ERROR; -+ // return -1; -+ nan_cnt++; -+ continue; -+ } -+ -+ auto tolerance = absoluteTolerance + relativeTolerance * fabs(refOutput[j]); -+ auto absoluteError = std::fabs(static_cast(msTensorData[j]) - static_cast(refOutput[j])); -+ if (absoluteError > tolerance) { -+ if (fabs(refOutput[j]) == 0) { -+ if (absoluteError > 1e-5) { -+ meanError += absoluteError; -+ errorCount++; -+ } -+ else { -+ continue; -+ } -+ } -+ else { -+ // if (absoluteError > 1e-2) std::cout << "idx=" < 0.0f) { -+ meanError /= errorCount; -+ } -+ if (meanError <= 0.0000001) { -+ std::cout << "Mean bias of tensor: 0%" << std::endl; -+ } -+ else { -+ std::cout << "Mean bias of tensor: " << meanError * 100 << "%" << std::endl; -+ } -+ std::cout << std::endl; -+ return meanError; -+} -+ -+template -+int CompareOutput(std::vector output_python_tensors, std::vector output_tensors) -+{ -+ float total_bias = 0; -+ int total_size = 0; -+ float accuracy_threshold_ = 0.5f; -+ bool has_error = false; -+ for (size_t i = 0; i < output_tensors.size(); i++) { -+ float size = 1; -+ for (size_t j = 0; j < output_tensors[i].shape.size(); j++) { -+ size *= output_tensors[i].shape[j]; -+ } -+ T* output_device = (T*)output_tensors[i].data; -+ T* output_host = (T*)malloc(size * sizeof(T)); -+ cudaD2Hcpy(output_host, output_device, size); -+ float bias = CompareData((T*)output_python_tensors[i].data, size, output_host); -+ free(output_host); -+ if (bias >= 0) { -+ total_bias += bias; -+ total_size++; -+ } -+ else { -+ has_error = true; -+ break; -+ } -+ } -+ if (!has_error) { -+ float mean_bias; -+ if (total_size != 0) { -+ mean_bias = total_bias / total_size * 100; -+ } -+ else { -+ mean_bias = 0; -+ } -+ -+ std::cout << "Mean bias of all nodes/tensors: " << mean_bias << "%" -+ << " threshold is:" << accuracy_threshold_ << std::endl; -+ std::cout << 
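CompareData() above accepts an element when |out - ref| <= absoluteTolerance + relativeTolerance * |ref| and otherwise adds its error to a running mean that CompareOutput() reports as a percentage. A simplified standalone restatement of that rule (the original additionally skips NaN/Inf elements and special-cases small references); MeanRelativeBias and the sample vectors are illustrative:

```cpp
// Simplified per-element tolerance check in the spirit of CompareData().
#include <cmath>
#include <cstdio>
#include <vector>

static float MeanRelativeBias(const std::vector<float>& ref, const std::vector<float>& out) {
    constexpr float relTol = 1e-5f;   // relativeTolerance in CompareData()
    constexpr float absTol = 1e-8f;   // absoluteTolerance in CompareData()
    double err_sum = 0.0;
    size_t err_cnt = 0;
    for (size_t i = 0; i < ref.size(); ++i) {
        if (std::isnan(out[i]) || std::isinf(out[i])) continue;           // skipped, as above
        const float abs_err = std::fabs(out[i] - ref[i]);
        if (abs_err <= absTol + relTol * std::fabs(ref[i])) continue;     // within tolerance
        err_sum += (ref[i] != 0.0f) ? abs_err / std::fabs(ref[i]) : abs_err;
        ++err_cnt;
    }
    return err_cnt ? static_cast<float>(err_sum / err_cnt) : 0.0f;
}

int main() {
    std::vector<float> ref = {1.0f, 2.0f, 3.0f};
    std::vector<float> out = {1.0f, 2.00001f, 3.5f};   // only the last element is clearly off
    printf("mean bias: %.2f%%\n", MeanRelativeBias(ref, out) * 100.0f);
    return 0;
}
```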
"=======================================================" << std::endl << std::endl; -+ -+ if (mean_bias > accuracy_threshold_) { -+ FT_LOG_INFO("Mean bias of all nodes/tensors is too big: %f %", mean_bias); -+ std::cout << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%" << std::endl; -+ return -9; -+ } -+ else { -+ return 0; -+ } -+ } -+ else { -+ FT_LOG_ERROR("Error in CompareData"); -+ std::cerr << "Error in CompareData" << std::endl; -+ std::cout << "=======================================================" << std::endl << std::endl; -+ return -1; -+ } -+} -+ -+void FreeDesc(std::vector& desc) -+{ -+ for (size_t i = 0; i < desc.size(); i++) { -+ if (desc[i].where == MEMORY_GPU) { -+ cudaFree((float*)desc[i].data); -+ } -+ else if (desc[i].where == MEMORY_CPU) { -+ free((float*)desc[i].data); -+ } -+ } -+} -+ -+uint64_t GetTimeUs() -+{ -+ const int USEC = 1000000; -+ const int MSEC = 1000; -+ struct timespec ts = {0, 0}; -+ if (clock_gettime(CLOCK_MONOTONIC, &ts) != 0) { -+ return 0; -+ } -+ uint64_t retval = (uint64_t)((ts.tv_sec * USEC) + (ts.tv_nsec / MSEC)); -+ return retval; -+} -+ -+template -+int MsExample(opt_arg* opt_a) -+{ -+ printf("[INFO] Device: %s \n", getDeviceName().c_str()); -+ -+ cudaStream_t stream; -+ cublasHandle_t cublas_handle; -+ cublasLtHandle_t cublaslt_handle; -+ cudaStreamCreate(&stream); -+ cublasCreate(&cublas_handle); -+ cublasLtCreate(&cublaslt_handle); -+#ifdef SPARSITY_ENABLED -+ cusparseLtHandle_t cusparselt_handle; -+ CHECK_CUSPARSE(cusparseLtInit(&cusparselt_handle)); -+#endif -+ cublasSetStream(cublas_handle, stream); -+ cublasAlgoMap* cublas_algo_map = new cublasAlgoMap("gemm_config.in", ""); -+ -+ Allocator allocator(getDevice()); -+ -+ std::mutex* cublas_wrapper_mutex = new std::mutex(); -+#ifdef SPARSITY_ENABLED -+ cublasMMWrapper cublas_wrapper = cublasMMWrapper( -+ cublas_handle, cublaslt_handle, cusparselt_handle, stream, cublas_algo_map, cublas_wrapper_mutex, &allocator); -+#else -+ cublasMMWrapper cublas_wrapper = -+ cublasMMWrapper(cublas_handle, cublaslt_handle, stream, cublas_algo_map, cublas_wrapper_mutex, &allocator); -+#endif -+ if (std::is_same::value) { -+ if (std::is_same::value) { -+ cublas_wrapper.setFP16MixedGemmConfig(); -+ } -+ else { -+ cublas_wrapper.setFP16GemmConfig(); -+ } -+ } -+ else if (std::is_same::value) { -+ if (std::is_same::value) { -+ cublas_wrapper.setFP32MixedGemmConfig(); -+ } -+ else { -+ cublas_wrapper.setFP32GemmConfig(); -+ } -+ } -+ Decriptor desc; -+ Init(opt_a, desc, stream, &cublas_wrapper, cublas_handle, &allocator); -+ int res = ReadTensors(desc.input_tensors, std::string("input"), opt_a); -+ FT_CHECK(!res); -+ res = ReadTensors(desc.input_python_tensors, std::string("input"), opt_a); -+ FT_CHECK(!res); -+ res = ReadTensors(desc.output_tensors, std::string("output"), opt_a, false); -+ FT_CHECK(!res); -+ res = ReadTensors(desc.output_python_tensors, std::string("output"), opt_a); -+ FT_CHECK(!res); -+ res = ReadTensors(desc.w_tensors, std::string("weight"), opt_a); -+ FT_CHECK(!res); -+ desc.MSLayer->InitWeight(opt_a, desc.MSLayer->ms_weights, desc.w_tensors); -+ desc.MSLayer->forward(&desc.output_tensors, &desc.input_tensors, desc.MSLayer->ms_weights); -+ CompareOutput(desc.output_python_tensors, desc.output_tensors); -+#define DO_TIME1 -+#ifdef DO_TIME -+ // warmup -+ for (int i = 0; i < 10; i++) { -+ desc.MSLayer->forward(&desc.output_tensors, &desc.input_tensors, desc.MSLayer->ms_weights); -+ } -+ // profile time -+ const int ite = 1000; -+ CudaTimer cuda_timer(stream); -+ 
cuda_timer.start(); -+ float total_time = cuda_timer.stop(); -+ printf("batch_size %ld seq_len %ld layer %ld " -+ "AVG FT-CPP-time %.2f ms (%d iterations) " -+ "Total Time %.2f ms\n", -+ opt_a->batch_size, -+ opt_a->seq_len, -+ opt_a->num_layers, -+ total_time / ite, -+ ite, -+ total_time); -+#endif -+ -+#ifdef SPARSITY_ENABLED -+ cusparseLtDestroy(&cusparselt_handle); -+#endif -+ delete cublas_algo_map; -+ delete cublas_wrapper_mutex; -+ FreeDesc(desc.output_tensors); -+ FreeDesc(desc.input_tensors); -+ FreeDesc(desc.output_python_tensors); -+ FreeDesc(desc.w_tensors); -+ return 0; -+} -diff --git a/examples/pytorch/swin/Swin-Transformer-Quantization/SwinTransformer b/examples/pytorch/swin/Swin-Transformer-Quantization/SwinTransformer -new file mode 160000 -index 0000000..cbaa0d8 ---- /dev/null -+++ b/examples/pytorch/swin/Swin-Transformer-Quantization/SwinTransformer -@@ -0,0 +1 @@ -+Subproject commit cbaa0d8707db403d85ad0e13c59f2f71cd6db425 -diff --git a/examples/pytorch/vit/ViT-quantization/ViT-pytorch b/examples/pytorch/vit/ViT-quantization/ViT-pytorch -new file mode 160000 -index 0000000..460a162 ---- /dev/null -+++ b/examples/pytorch/vit/ViT-quantization/ViT-pytorch -@@ -0,0 +1 @@ -+Subproject commit 460a162767de1722a014ed2261463dbbc01196b6 -diff --git a/path.sh b/path.sh -new file mode 100755 -index 0000000..53f5ca6 ---- /dev/null -+++ b/path.sh -@@ -0,0 +1 @@ -+export PATH=/usr/local/cuda-11/bin:/home/yoni/.vscode-server/bin/4af164ea3a06f701fe3e89a2bcbb421d2026b68f/bin/remote-cli:/home/yoni/.local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin -diff --git a/src/fastertransformer/kernels/CMakeLists.txt b/src/fastertransformer/kernels/CMakeLists.txt -index 3db0830..3dd4210 100644 ---- a/src/fastertransformer/kernels/CMakeLists.txt -+++ b/src/fastertransformer/kernels/CMakeLists.txt -@@ -159,9 +159,12 @@ add_library(matrix_vector_multiplication STATIC matrix_vector_multiplication.cu) - set_property(TARGET matrix_vector_multiplication PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET matrix_vector_multiplication PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - --add_library(custom_ar_kernels STATIC custom_ar_kernels.cu) --set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) --set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -+if(${SM} GREATER_EQUAL 70) -+ message("-- Making custom kernels") -+ add_library(custom_ar_kernels STATIC custom_ar_kernels.cu) -+ set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) -+ set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -+endif() - - add_library(vit_kernels STATIC vit_kernels.cu) - set_property(TARGET vit_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) -diff --git a/src/fastertransformer/kernels/activation_kernels.cu b/src/fastertransformer/kernels/activation_kernels.cu -index 7ff8e0f..abe7634 100644 ---- a/src/fastertransformer/kernels/activation_kernels.cu -+++ b/src/fastertransformer/kernels/activation_kernels.cu -@@ -19,6 +19,82 @@ - #include "src/fastertransformer/utils/cuda_utils.h" - namespace fastertransformer { - -+template -+__inline__ __device__ T fastGelu(T x) -+{ -+ T abs_x = fabsf((T)x); -+ float numerator = expf(0.851f * (x - abs_x)); -+ float denominator = 1 + expf(-1.702f * abs_x); -+ return (T)(x / denominator * numerator); -+} -+ -+template<> -+__inline__ __device__ half fastGelu(half x) -+{ -+ half abs_x = (half)(fabsf(__half2float(x))); -+ half numerator = 
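As written, the profiling section of MsExample() above is guarded by #ifdef DO_TIME while the line before it defines DO_TIME1, and cuda_timer.start() is followed immediately by cuda_timer.stop() with no forward() calls in between, which would explain the 0.00 ms averages captured in deploy.ch and deploy.trc. A standalone sketch of event-based timing with the iteration loop inside the timed region; run_once() is a placeholder for the layer's forward call:

```cpp
// Event-based timing with the measured loop between the two records.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    auto run_once = [&]() {                       // placeholder workload
        cudaStreamSynchronize(stream);
    };

    for (int i = 0; i < 10; ++i) run_once();      // warmup, as in MsExample()

    cudaEvent_t start_ev, stop_ev;
    cudaEventCreate(&start_ev);
    cudaEventCreate(&stop_ev);

    const int ite = 1000;
    cudaEventRecord(start_ev, stream);
    for (int i = 0; i < ite; ++i) run_once();     // the loop the timed region needs
    cudaEventRecord(stop_ev, stream);
    cudaEventSynchronize(stop_ev);

    float total_ms = 0.0f;
    cudaEventElapsedTime(&total_ms, start_ev, stop_ev);
    printf("AVG time %.4f ms (%d iterations), total %.2f ms\n", total_ms / ite, ite, total_ms);

    cudaEventDestroy(start_ev);
    cudaEventDestroy(stop_ev);
    cudaStreamDestroy(stream);
    return 0;
}
```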
hexp((half)(0.851f) * (x - abs_x)); -+ half denominator = (half)1 + hexp(half(-1.702f) * abs_x); -+ return (x / denominator * numerator); -+} -+ -+template<> -+__inline__ __device__ half2 fastGelu(half2 x) -+{ -+ half2 half2_x_abs = __habs2(x); -+ half2 numerator = h2exp(half2(0.851, 0.851) * (x - half2_x_abs)); -+ half2 denominator = half2(1, 1) + h2exp(half2(-1.702, -1.702) * half2_x_abs); -+ return (x / denominator * numerator); -+} -+ -+template -+__global__ void addBiasFastGelu(T* out, const T* __restrict bias, int m, int n) -+{ -+ for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { -+ T val = out[id]; -+ if (bias != nullptr) { -+ T reg_bias = __ldg(&bias[id % n]); -+ val = val + reg_bias; -+ } -+ out[id] = (fastGelu(val)); -+ } -+} -+ -+template<> -+__global__ void addBiasFastGelu(half* out, const half* __restrict bias, int m, int n) -+{ -+ half2* out_ptr = (half2*)out; -+ const half2* bias_ptr = (half2*)bias; -+ -+ for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { -+ half2 val = out_ptr[id]; -+ if (bias != nullptr) { -+ half2 reg_bias = __ldg(&bias_ptr[id % n]); -+ val = __hadd2(val, reg_bias); -+ } -+ out_ptr[id] = fastGelu(val); -+ } -+} -+ -+template -+void invokeAddBiasFastGelu(T* out, const T* bias, const int m, const int n, cudaStream_t stream) -+{ -+ const int data_type_factor = 4 / sizeof(T); // 1 for fp32, 2 for fp16 and bf16 -+ dim3 block, grid; -+ if (n / 4 / data_type_factor <= 1024) { -+ block.x = n / 4 / data_type_factor; -+ grid.x = m; -+ } -+ else { -+ block.x = 1024; -+ grid.x = ceil(m * n / 1024.); -+ } -+ addBiasFastGelu<<>>(out, bias, m, n / data_type_factor); -+} -+ -+template void invokeAddBiasFastGelu(float* out, const float* bias, const int m, const int n, cudaStream_t stream); -+template void invokeAddBiasFastGelu(half* out, const half* bias, const int m, const int n, cudaStream_t stream); -+ -+ - template - __inline__ __device__ T gelu(T x) - { -@@ -201,12 +277,21 @@ __global__ void add_bias(H_T* out, const B_T* __restrict bias, int m, int n) - } - } - -+template -+__global__ void add_bias_basic(H_T* out, const B_T* __restrict bias, int m, int n) -+{ -+ for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { -+ out[id] = out[id] + (H_T)ldg(&bias[id % n]); -+ } -+} -+ - template<> - __global__ void add_bias(half* out, const half* __restrict bias, int m, int n) - { - half2* out_ptr = (half2*)out; - const half2* bias_ptr = (half2*)bias; -- for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { -+ int id = blockIdx.x * blockDim.x + threadIdx.x; -+ for (; id < m * n; id += blockDim.x * gridDim.x) { - out_ptr[id] = out_ptr[id] + __ldg(&bias_ptr[id % n]); - } - } -@@ -228,15 +313,29 @@ void invokeAddBias(H_T* out, const B_T* bias, const int m, const int n, cudaStre - { - const int data_type_factor = 4 / sizeof(H_T); // 1 for fp32, 2 for fp16 and bf16 - dim3 block, grid; -- if (n / 4 / data_type_factor <= 1024) { -- block.x = n / 4 / data_type_factor; -- grid.x = m; -- } -- else { -- block.x = 1024; -- grid.x = ceil(m * n / 1024.); -+ -+ bool reminder = (data_type_factor != 1) ? 
(n % data_type_factor) : false; -+ if (reminder) { -+ if (n / 4 <= 1024) { -+ block.x = n / 4; -+ grid.x = m; -+ } -+ else { -+ block.x = 1024; -+ grid.x = ceil(m * n / 1024.); -+ } -+ add_bias_basic<<>>(out, bias, m, n); -+ } else { -+ if (n / 4 / data_type_factor <= 1024) { -+ block.x = n / 4 / data_type_factor; -+ grid.x = m; -+ } -+ else { -+ block.x = 1024; -+ grid.x = ceil(m * n / 1024.); -+ } -+ add_bias<<>>(out, bias, m, (n / data_type_factor)); - } -- add_bias<<>>(out, bias, m, n / data_type_factor); - } - - template void invokeAddBias(float* out, const float* bias, const int m, const int n, cudaStream_t stream); -diff --git a/src/fastertransformer/kernels/activation_kernels.h b/src/fastertransformer/kernels/activation_kernels.h -index 6600457..f8c379a 100644 ---- a/src/fastertransformer/kernels/activation_kernels.h -+++ b/src/fastertransformer/kernels/activation_kernels.h -@@ -25,6 +25,9 @@ namespace fastertransformer { - template - void invokeAddBiasGelu(T* out, const T* bias, const int m, const int n, cudaStream_t stream); - -+template -+void invokeAddBiasFastGelu(T* out, const T* bias, const int m, const int n, cudaStream_t stream); -+ - template - void invokeAddBiasRelu(T* out, const T* bias, const int m, const int n, cudaStream_t stream); - -diff --git a/src/fastertransformer/kernels/add_residual_kernels.cu b/src/fastertransformer/kernels/add_residual_kernels.cu -index 4cd9f0f..42c9216 100644 ---- a/src/fastertransformer/kernels/add_residual_kernels.cu -+++ b/src/fastertransformer/kernels/add_residual_kernels.cu -@@ -29,6 +29,30 @@ __global__ void addBiasResidual(T* output, const T* input, const T* bias, const - } - } - -+template -+__global__ void addBiasResidualCast(U* output, const T* input, T* out, const T* bias, const int m, const int n) -+{ -+ S *out_cast = (S*)out; -+ const int col_index = blockIdx.y * blockDim.x + threadIdx.x; -+ if (col_index < n) { -+ T bias_val = (bias == nullptr) ? (T)(0.0f) : bias[col_index]; -+ out_cast[blockIdx.x * n + col_index] = -+ (S)((T)output[blockIdx.x * n + col_index] + (T)input[blockIdx.x * n + col_index] + bias_val); -+ } -+} -+ -+template -+__global__ void addBiasResidualSameTypeCast(U* output, const U* input, T* out, const T* bias, const int m, const int n) -+{ -+ S *out_cast = (S*)out; -+ const int col_index = blockIdx.y * blockDim.x + threadIdx.x; -+ if (col_index < n) { -+ T bias_val = (bias == nullptr) ? 
(T)(0.0f) : bias[col_index]; -+ out_cast[blockIdx.x * n + col_index] = -+ (S)((T)output[blockIdx.x * n + col_index] + (T)input[blockIdx.x * n + col_index] + bias_val); -+ } -+} -+ - template - void invokeAddBiasResidual(T* output, const T* input, const T* bias, const int m, const int n, cudaStream_t stream) - { -@@ -38,6 +62,31 @@ void invokeAddBiasResidual(T* output, const T* input, const T* bias, const int m - addBiasResidual<<>>(output, input, bias, m, n); - } - -+template -+void invokeAddBiasResidualCast(U* output, const T* input, T* out, const T* bias, const int m, const int n, cudaStream_t stream) -+{ -+ int blocks_per_row = ceil(float(n) / 1024); -+ dim3 grid(m, blocks_per_row); -+ dim3 block(min(n, 1024)); -+ addBiasResidualCast<<>>(output, input, out, bias, m, n); -+} -+ -+template -+void invokeAddBiasResidualSameTypeCast(U* output, const U* input, T* out, const T* bias, const int m, const int n, cudaStream_t stream) -+{ -+ int blocks_per_row = ceil(float(n) / 1024); -+ dim3 grid(m, blocks_per_row); -+ dim3 block(min(n, 1024)); -+ addBiasResidualSameTypeCast<<>>(output, input, out, bias, m, n); -+} -+ -+template void invokeAddBiasResidualCast(half* output, const float* input, float* out, const float* bias, const int m, const int n, cudaStream_t stream); -+template void invokeAddBiasResidualCast(float* output, const float* input, float* out, const float* bias, const int m, const int n, cudaStream_t stream); -+template void invokeAddBiasResidualCast(float* output, const float* input, float* out, const float* bias, const int m, const int n, cudaStream_t stream); -+template void invokeAddBiasResidualCast(half* output, const float* input, float* out, const float* bias, const int m, const int n, cudaStream_t stream); -+ -+template void invokeAddBiasResidualSameTypeCast(half* output, const half* input, float* out, const float* bias, const int m, const int n, cudaStream_t stream); -+ - template - __global__ void addBiasAttentionFfnResidual(T* block_output, - const T* ffn_output, -@@ -88,11 +137,9 @@ void invokeAddBiasAttentionFfnResidual(T* block_output, - } - } - --template void invokeAddBiasResidual( -- float* output, const float* input, const float* bias, const int m, const int n, cudaStream_t stream); -+template void invokeAddBiasResidual(float *output, const float *input, const float *bias, int m, int n, cudaStream_t stream); -+template void invokeAddBiasResidual(half *output, const half *input, const half *bias, int m, int n, cudaStream_t stream); - --template void --invokeAddBiasResidual(half* output, const half* input, const half* bias, const int m, const int n, cudaStream_t stream); - - #ifdef ENABLE_BF16 - template void invokeAddBiasResidual(__nv_bfloat16* output, -diff --git a/src/fastertransformer/kernels/add_residual_kernels.h b/src/fastertransformer/kernels/add_residual_kernels.h -index edd8179..afa5a77 100644 ---- a/src/fastertransformer/kernels/add_residual_kernels.h -+++ b/src/fastertransformer/kernels/add_residual_kernels.h -@@ -27,6 +27,9 @@ namespace fastertransformer { - template - void invokeAddBiasResidual(T* output, const T* input, const T* bias, const int m, const int n, cudaStream_t stream); - -+template -+void invokeAddBiasResidual(T* output, const T* input, const T* bias, const int m, const int n, int max_seq, const int *sequent_len, cudaStream_t stream); -+ - template - void invokeT5AddResidual(T* output, const T* input, const int m, const int n, cudaStream_t stream); - -@@ -65,4 +68,11 @@ void invokeAddBiasResidualCol32(T* output, - const float* 
input1_amax_ptr, - const int scale_is_vector = 0); - -+template -+void invokeAddBiasResidualCast(U* output, const T* input, T* out, const T* bias, const int m, const int n, cudaStream_t stream); -+ -+template -+void invokeAddBiasResidualSameTypeCast(U* output, const U* input, T* out, const T* bias, const int m, const int n, cudaStream_t stream); -+ - } // namespace fastertransformer -+ -diff --git a/src/fastertransformer/kernels/bert_preprocess_kernels.cu b/src/fastertransformer/kernels/bert_preprocess_kernels.cu -index c855fa1..19e29bc 100644 ---- a/src/fastertransformer/kernels/bert_preprocess_kernels.cu -+++ b/src/fastertransformer/kernels/bert_preprocess_kernels.cu -@@ -14,10 +14,13 @@ - * limitations under the License. - */ - -+#include "reduce_kernel_utils.cuh" - #include "bert_preprocess_kernels.h" -+#include "src/fastertransformer/utils/cuda_utils.h" - - namespace fastertransformer { - -+ - __global__ void getPaddingOffsetKernel(size_t* valid_word_num, - int* tmp_mask_offset, - const int* sequence_length, -@@ -29,7 +32,7 @@ __global__ void getPaddingOffsetKernel(size_t* valid_word_num, - int cum_offset = 0; - int index = 0; - for (int i = 0; i < batch_size; i++) { -- const int seq_len = sequence_length[i]; -+ const int seq_len = (sequence_length[i] == -1) ? 0 : sequence_length[i]; - for (int j = 0; j < seq_len; j++) { - tmp_mask_offset[index] = cum_offset; - index++; -@@ -50,50 +53,315 @@ void invokeGetPaddingOffset(size_t* h_token_num, - { - getPaddingOffsetKernel<<<1, 1, 0, stream>>>( - d_token_num, tmp_mask_offset, sequence_lengths, batch_size, max_seq_len); -- sync_check_cuda_error(); -- check_cuda_error(cudaMemcpyAsync(h_token_num, d_token_num, sizeof(size_t), cudaMemcpyDeviceToHost, stream)); -- sync_check_cuda_error(); -+ if (h_token_num != nullptr) { -+ cudaMemcpyAsync(h_token_num, d_token_num, sizeof(size_t), cudaMemcpyDeviceToHost, stream); -+ } - } - - template --__global__ void buildEncoderAttentionMaskKernel(T* attention_mask, const int* sequence_lengths, const int max_seq_len) -+__global__ void buildSequnceLength(const T * input, int *sequence_length, const int max_seq_length, const int hidden_size) { -+ __shared__ int s_max_val; -+ int bid = blockIdx.x; -+ const T * seq_base = input + bid* max_seq_length * hidden_size; -+ const T zero = static_cast(0.f); -+ int last = -max_seq_length; -+ for (int i=max_seq_length - 1 - threadIdx.x; i >= 0; i -= blockDim.x) { -+ const T * seq_ptr = seq_base + i * hidden_size; -+ if ((seq_ptr[0] == zero) && (seq_ptr[1] == zero)) { -+ last = -i; -+ } -+ } -+ int max_val = blockReduceMax(last); -+ if (threadIdx.x == 0) { -+ s_max_val = max_val; -+ } -+ __syncthreads(); -+ sequence_length[bid] = -s_max_val; -+} -+ -+__global__ void buildSequnceLength(const int *input, int *sequence_length, const int max_seq_length) { -+ __shared__ int s_max_val; -+ int bid = blockIdx.x; -+ int last = 0; -+ const int *base = input + bid * max_seq_length; -+ for (int i=threadIdx.x ; i < max_seq_length; i += blockDim.x) { -+ const int *ptr = base + i; -+ if (*ptr != 0){ -+ last = i; -+ } -+ } -+ int max_val = blockReduceMax(last); -+ if (threadIdx.x == 0) { -+ s_max_val = max_val + 1; -+ } -+ __syncthreads(); -+ sequence_length[bid] = s_max_val; -+} -+ -+__global__ void buildSequnceOffset(int *sequence_length, int *sequence_offset, int batch_size) { -+ for (int i = 0; i < batch_size ; i += 1) { -+ if (i == 0) {sequence_offset[i] = 0;} -+ else { -+ sequence_offset[i] = sequence_offset[i-1] + sequence_length[i-1] + 1; -+ } -+ } -+} -+__global__ void 
findStartPoint(int *sequence_lengths, int *start_points, int batch_size,int hidden_size) { -+ int i = blockDim.x * blockIdx.x + threadIdx.x; -+ if (i <= batch_size) { -+ int sum = 0; -+ for (int j = 0; j < i; j++) { -+ sum += sequence_lengths[j]*hidden_size; -+ } -+ start_points[i] = sum; -+ } -+} -+ -+ -+ -+template -+__global__ void buildEncoderAttentionMaskKernel(T* attention_mask, const int* q_sequence_lengths, const int* kv_sequence_lengths, const int src_seq_len, const int tgt_seq_len, const bool incremental_mode) - { - // sequence_lengths: [batch_size] - // attention_mask: [batch_size, 1, max_seq_len, max_seq_len] -- attention_mask += blockIdx.x * max_seq_len * max_seq_len; -- const int length = sequence_lengths[blockIdx.x]; -- for (int i = threadIdx.x; i < max_seq_len * max_seq_len; i += blockDim.x) { -- // int row_id = i / max_seq_len; -- int col_id = i % max_seq_len; -- // if (row_id < length && col_id < length) { -+ attention_mask += blockIdx.x * src_seq_len * tgt_seq_len; -+ const int q_length = q_sequence_lengths[blockIdx.x]; -+ const int kv_length = kv_sequence_lengths[blockIdx.x]; -+ for (int i = threadIdx.x; i < src_seq_len * tgt_seq_len; i += blockDim.x) { -+ int row_id = i / tgt_seq_len; -+ int col_id = i % tgt_seq_len; - // TODO (bhsueh) check this modification is ok or not on other rmodel -- if (col_id < length) { -- attention_mask[i] = (T)(1.0f); -- } -- else { -+ if (col_id >= q_length || row_id >= kv_length) { - attention_mask[i] = (T)(0.0f); - } - } - } - -+ - template - void invokeBuildEncoderAttentionMask( -- T* attention_mask, const int* sequence_lengths, const int batch_size, const int max_seq_len, cudaStream_t stream) -+ T* attention_mask, const int* q_sequence_lengths, const int* kv_sequence_lengths, const int batch_size, const int src_seq_len, const int tgt_seq_len, const bool incremental_mode, cudaStream_t stream) - { -- buildEncoderAttentionMaskKernel<<>>(attention_mask, sequence_lengths, max_seq_len); -+ buildEncoderAttentionMaskKernel<<>>(attention_mask, q_sequence_lengths, kv_sequence_lengths, src_seq_len, tgt_seq_len, incremental_mode); - } - -+ - template void invokeBuildEncoderAttentionMask(float* attention_mask, -- const int* sequence_lengths, -+ const int* q_sequence_lengths, -+ const int* kv_sequence_lengths, - const int batch_size, -- const int max_seq_len, -+ const int src_seq_len, -+ const int tgt_seq_len, -+ const bool incremental_mode, - cudaStream_t stream); - template void invokeBuildEncoderAttentionMask(half* attention_mask, -- const int* sequence_lengths, -+ const int* q_sequence_lengths, -+ const int* kv_sequence_lengths, - const int batch_size, -- const int max_seq_len, -+ const int src_seq_len, -+ const int tgt_seq_len, -+ const bool incremental_mode, - cudaStream_t stream); - -+__global__ void buildUsePastSeqLenKernel(int *sequence_length_src, int *sequence_offset_dst, int batch_size, bool incremental_mode)//inc fix -+{ -+ // sequence_lengths: [batch_size] -+ // sequence_lengths2: [batch_size] -+ if (sequence_length_src[blockIdx.x] == -1) -+ sequence_offset_dst[blockIdx.x] = -1; -+ else if (!incremental_mode) { -+ sequence_offset_dst[blockIdx.x] = sequence_length_src[blockIdx.x] + 1; -+ } else { -+ sequence_offset_dst[blockIdx.x] = 1; -+ } -+} -+ -+ -+void buildUsePastSeqLenKernel(int *sequence_length_src, int *sequence_offset_dst, int batch_size, bool incremental_mode, cudaStream_t stream) -+{ -+ buildUsePastSeqLenKernel<<>>(sequence_length_src, sequence_offset_dst, batch_size, incremental_mode); -+} -+ -+template -+__global__ void 
buildUsePastAttentionMaskKernel(T* attention_mask, const int tgt_seq_len) -+{ -+ // attention_mask: [1, 1, tgt_seq_len] -+ for (int i = threadIdx.x; i < tgt_seq_len; i += blockDim.x) { -+ attention_mask[i] = (T)(1.0f); -+ } -+} -+ -+ -+template -+void invokeBuildUsePastAttentionMask( -+ T* attention_mask, const int tgt_seq_len, cudaStream_t stream) -+{ -+ buildUsePastAttentionMaskKernel<<<1, 256, 0, stream>>>(attention_mask, tgt_seq_len); -+} -+ -+template void invokeBuildUsePastAttentionMask(float* attention_mask, -+ const int tgt_seq_len, -+ cudaStream_t stream); -+template void invokeBuildUsePastAttentionMask(half* attention_mask, -+ const int tgt_seq_len, -+ cudaStream_t stream); -+ -+ template -+__global__ void EmbeddingPanguSigma(const int *input, -+ const int *input_position, -+ const T *emmbeding_table, -+ const T *emmbeding_pos_table, -+ T *output, -+ int h_token_num, -+ int hidden_size) { -+ -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < h_token_num * hidden_size; -+ index += gridDim.x * blockDim.x) { -+ // Gather the value from the input array -+ // Store the value in the output array -+ -+ int h_token_idx = index / hidden_size; -+ T val = (T)(emmbeding_pos_table[input_position[h_token_idx] * hidden_size + index % hidden_size]); -+ output[index] = val + emmbeding_table[input[h_token_idx] * hidden_size + index % hidden_size]; -+ } -+} -+__global__ void EmbeddingPanguSigmaHalf(const int *input, -+ const int *input_position, -+ const half *emmbeding_table, -+ const half *emmbeding_pos_table, -+ half *output, -+ int h_token_num, -+ int hidden_size) { -+ half2* output_ptr = (half2*)output; -+ const half2* emmbeding_table_ptr = (half2*)emmbeding_table; -+ const half2* emmbeding_pos_table_ptr = (half2*)emmbeding_pos_table; -+ -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < h_token_num * (hidden_size / 2); -+ index += gridDim.x * blockDim.x) { -+ // Gather the value from the input array -+ // Store the value in the output array -+ int h_token_idx = (index * 2) / hidden_size; -+ half2 val = emmbeding_pos_table_ptr[(input_position[h_token_idx] * hidden_size + (index * 2) % hidden_size) / 2]; -+ half2 val2 = emmbeding_table_ptr[(input[h_token_idx] * hidden_size + (index * 2) % hidden_size) / 2]; -+ output_ptr[index] = __hadd2(val, val2); -+ } -+} -+template -+void invokeEmbeddingPanguSigma(const int *input, -+ const int *input_position, -+ const T *emmbeding_table, -+ const T *emmbeding_pos_table, -+ T *output, -+ int h_token_num, -+ int hidden_size, -+ cudaStream_t stream) { -+ const int m = h_token_num; -+ const int n = hidden_size; -+ const int data_type_factor = (hidden_size % 2 == 1) ? 
1 : 4 / sizeof(T); // 1 for fp32, 2 for fp16 and bf16 -+ dim3 block, grid; -+ if (n / 4 / data_type_factor <= 1024) { -+ block.x = n / 4 / data_type_factor + n % data_type_factor; -+ grid.x = m; -+ } -+ else { -+ block.x = 1024; -+ grid.x = ceil(m * (n / data_type_factor) / 1024.); -+ } -+ if (data_type_factor == 1) { -+ EmbeddingPanguSigma<<>>( -+ input, input_position, emmbeding_table, emmbeding_pos_table, output, h_token_num, hidden_size); -+ } else { -+ EmbeddingPanguSigmaHalf<<>>( -+ input, input_position, (const half*)emmbeding_table, (const half*)emmbeding_pos_table, (half*)output, h_token_num, hidden_size); -+ } -+} -+template void invokeEmbeddingPanguSigma(const int *input, -+ const int *input_position, -+ const float *emmbeding_table, -+ const float *emmbeding_pos_table, -+ float *output, -+ int h_token_num, -+ int hidden_size, -+ cudaStream_t stream); -+template void invokeEmbeddingPanguSigma(const int *input, -+ const int *input_position, -+ const half *emmbeding_table, -+ const half *emmbeding_pos_table, -+ half *output, -+ int h_token_num, -+ int hidden_size, -+ cudaStream_t stream); -+ -+ -+template -+__global__ void VocabEmbedding(const int *input, -+ const T *emmbeding_table, -+ T *output, -+ int h_token_num, -+ int hidden_size) { -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < h_token_num * hidden_size; -+ index += gridDim.x * blockDim.x) { -+ int h_token_idx = index / hidden_size; -+ // Gather the value from the input array -+ // Store the value in the output array -+ output[index] = emmbeding_table[input[h_token_idx] * hidden_size + index % hidden_size]; -+ } -+} -+ -+__global__ void VocabEmbeddingHalf(const int *input, -+ const half *emmbeding_table, -+ half *output, -+ int h_token_num, -+ int hidden_size) { -+ half2* output_ptr = (half2*)output; -+ const half2* emmbeding_table_ptr = (half2*)emmbeding_table; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < h_token_num * (hidden_size / 2); -+ index += gridDim.x * blockDim.x) { -+ int h_token_idx = (index * 2) / hidden_size; -+ // Gather the value from the input array -+ // Store the value in the output array -+ output_ptr[index] = emmbeding_table_ptr[(input[h_token_idx] * hidden_size + (index * 2) % hidden_size) / 2]; -+ } -+} -+template -+void invokeVocabEmbedding(const int *input, -+ const T *emmbeding_table, -+ T *output, -+ int h_token_num, -+ int hidden_size, -+ cudaStream_t stream) { -+ const int m = h_token_num; -+ int n = hidden_size; -+ const int data_type_factor = (hidden_size % 2 == 1) ? 
1 : 4 / sizeof(T); // 1 for fp32, 2 for fp16 and bf16 -+ dim3 block, grid; -+ if (n / 4 / data_type_factor <= 1024) { -+ block.x = n / 4 / data_type_factor; -+ grid.x = m; -+ } -+ else { -+ block.x = 1024; -+ grid.x = ceil(m * (n / data_type_factor) / 1024.); -+ } -+ if (data_type_factor == 1) { -+ VocabEmbedding<<>>( -+ input, emmbeding_table, output, h_token_num, hidden_size); -+ } else { -+ VocabEmbeddingHalf<<>>( -+ input, (const half*)emmbeding_table, (half*)output, h_token_num, hidden_size); -+ } -+} -+template void invokeVocabEmbedding(const int *input, -+ const float *emmbeding_table, -+ float *output, -+ int h_token_num, -+ int hidden_size, -+ cudaStream_t stream); -+template void invokeVocabEmbedding(const int *input, -+ const half *emmbeding_table, -+ half *output, -+ int h_token_num, -+ int hidden_size, -+ cudaStream_t stream); - __global__ void getTrtPaddingOffsetKernel(int* trt_mha_padding_offset, const int* sequence_length, const int batch_size) - { - // use for get tensorrt fused mha padding offset -@@ -113,6 +381,26 @@ __global__ void getTrtPaddingOffsetKernel(int* trt_mha_padding_offset, const int - } - } - -+ -+ -+template -+void invokeBuildSequenceLength(const T * input, int batch_size, int *sequnce_length, int max_seq_length, int hidden_size,cudaStream_t stream) { -+ buildSequnceLength<<>>(input,sequnce_length, max_seq_length,hidden_size); -+} -+ -+void invokeBuildSequenceLength(const int * input, int batch_size, int *sequnce_length, int max_seq_length,cudaStream_t stream) { -+ buildSequnceLength<<>>(input,sequnce_length, max_seq_length); -+} -+void invokeBuildSequnceOffset(int batch_size, int *sequnce_length, int* sequnce_offset,int hidden_size,cudaStream_t stream) { -+ findStartPoint<<<1, batch_size+1, 0, stream>>>(sequnce_length, sequnce_offset, batch_size+1,hidden_size); -+ -+} -+ -+ -+ -+ -+ -+ - void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, - const int* sequence_length, - const int batch_size, -@@ -176,7 +464,52 @@ void invokeRebuildPadding( - // dst: [batch_size*max_seq_len, hidden_dim] - rebuild_sequence_length_padding<<>>(src, dst, padding_offset, n); - } -- -+template -+__global__ void rebuild_query_padding(const T* src, T* dst, const int* d_seq_len, const int batch, const int n) -+{ -+ // const int tid = threadIdx.x; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < batch * n; -+ index += gridDim.x * blockDim.x) { -+ int dst_index = index; -+ int batch_id = index / n; -+ int token_idx = 0; -+ int i{0}; -+ while (i <= batch_id) -+ { -+ if (d_seq_len[i] == -1) { -+ dst_index += n; -+ batch_id++; -+ } -+ else { -+ token_idx += d_seq_len[i]; -+ } -+ i++; -+ } -+ dst[dst_index] = src[(token_idx - 1) * n + (index % n)]; -+ } -+} -+template -+void invokeRebuildQuery( -+ T* dst, const T* src, const int* d_seq_len, const int batch, const int n, cudaStream_t stream) -+{ -+ // src: [token_num, hidden_dim] -+ // dst: [batch_size, hidden_dim] -+ dim3 grid((int)(ceil(1.0 * batch * n / 512))); -+ dim3 block(512); -+ rebuild_query_padding<<>>(src, dst, d_seq_len, batch, n); -+} -+template void invokeRebuildQuery(float* dst, -+ const float* src, -+ const int* d_seq_len, -+ const int batch, -+ const int n, -+ cudaStream_t stream); -+template void invokeRebuildQuery(half* dst, -+ const half* src, -+ const int* d_seq_len, -+ const int batch, -+ const int n, -+ cudaStream_t stream); - template - void invokeRebuildPadding( - T* dst, const T* src, const int* padding_offset, const int token_num, const int hidden_dim, cudaStream_t stream); -@@ -226,6 
+559,12 @@ template void invokeRemovePadding(half* dst, - const int token_num, - const int hidden_dim, - cudaStream_t stream); -+template void invokeRemovePadding(int* dst, -+ const int* src, -+ const int* padding_offset, -+ const int token_num, -+ const int hidden_dim, -+ cudaStream_t stream); - - template - __global__ void buildRelativeAttentionBias(T* relative_attention_bias, -@@ -300,6 +639,8 @@ void invokeBuildRelativeAttentionBias(T* relative_attention_bias, - is_bidirectional, - max_distance); - } -+template void invokeBuildSequenceLength(const float * input, int batch_size, int *sequnce_length, int max_seq_length, int hidden_size,cudaStream_t stream); -+template void invokeBuildSequenceLength(const half * input, int batch_size, int *sequnce_length, int max_seq_length, int hidden_size,cudaStream_t stream); - - template void invokeBuildRelativeAttentionBias(float* relative_attention_bias, - const float* relative_attention_bias_table, -diff --git a/src/fastertransformer/kernels/bert_preprocess_kernels.h b/src/fastertransformer/kernels/bert_preprocess_kernels.h -index dcb8f85..f444b53 100644 ---- a/src/fastertransformer/kernels/bert_preprocess_kernels.h -+++ b/src/fastertransformer/kernels/bert_preprocess_kernels.h -@@ -19,6 +19,8 @@ - #include "src/fastertransformer/utils/cuda_utils.h" - #include - #include -+#include -+#include - - namespace fastertransformer { - -@@ -32,7 +34,17 @@ void invokeGetPaddingOffset(size_t* h_token_num, - - template - void invokeBuildEncoderAttentionMask( -- T* attention_mask, const int* sequence_lengths, const int batch_size, const int max_seq_len, cudaStream_t stream); -+ T* attention_mask, const int* q_sequence_lengths, const int* kv_sequence_lengths, const int batch_size, const int src_seq_len, const int tgt_seq_len, const bool incremental_mode, cudaStream_t stream); -+ -+template -+void invokeBuildUsePastAttentionMask( -+ T* attention_mask, const int tgt_seq_len, cudaStream_t stream); -+ -+template -+void invokeBuildSequenceLength(const T * input, int batch_size, int *sequnce_length, int max_seq_length, int hidden_size,cudaStream_t stream); -+ -+void invokeBuildSequenceLength(const int* input, int batch_size, int *sequnce_length, int max_seq_length,cudaStream_t stream); -+void invokeBuildSequnceOffset(int batch_size, int *sequnce_length, int* sequnce_offset,int hidden_size,cudaStream_t stream); - - void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, - const int* sequence_length, -@@ -46,6 +58,25 @@ void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, - cudaStream_t stream); - - template -+void invokeEmbeddingPanguSigma(const int *input, -+ const int *input_position, -+ const T *emmbeding_table, -+ const T *emmbeding_pos_table, -+ T *output, -+ int h_token_num, -+ int hidden_size, -+ cudaStream_t stream); -+template -+void invokeVocabEmbedding(const int *input, -+ const T *emmbeding_table, -+ T *output, -+ int h_token_num, -+ int hidden_size, -+ cudaStream_t stream); -+template -+void invokeRebuildQuery( -+ T* dst, const T* src, const int* d_seq_len, const int batch, const int n, cudaStream_t stream); -+template - void invokeRebuildPadding( - T* dst, const T* src, const int* padding_offset, const int token_num, const int hidden_dim, cudaStream_t stream); - -@@ -63,5 +94,6 @@ void invokeBuildRelativeAttentionBias(T* relative_attention_bias, - const int max_distance, - const PositionEmbeddingType position_embedding_type, - cudaStream_t stream); -- -+void buildUsePastSeqLenKernel( -+ int *sequence_length_src, int *sequence_offset_dst, int 
batch_size, bool incremental_mode, cudaStream_t stream); - } // namespace fastertransformer -diff --git a/src/fastertransformer/kernels/layernorm_kernels.cu b/src/fastertransformer/kernels/layernorm_kernels.cu -index 96a090e..e7bfec4 100644 ---- a/src/fastertransformer/kernels/layernorm_kernels.cu -+++ b/src/fastertransformer/kernels/layernorm_kernels.cu -@@ -13,6 +13,8 @@ - * See the License for the specific language governing permissions and - * limitations under the License. - */ -+#include -+ - - #include "src/fastertransformer/kernels/bfloat16_fallback_kenrels.cuh" - #include "src/fastertransformer/kernels/layernorm_kernels.h" -@@ -29,7 +31,8 @@ __global__ void generalAddBiasResidualLayerNormOpt(T* normed_output, - const T* __restrict gamma, - const T* __restrict beta, - int m, -- int n) -+ int n, -+ float eps) - { - __shared__ float s_mean; - __shared__ float s_variance; -@@ -74,7 +77,7 @@ __global__ void generalAddBiasResidualLayerNormOpt(T* normed_output, - variance = blockReduceSum(local_var_sum); - - if (threadIdx.x == 0) { -- s_variance = rsqrtf(variance / n / 2 + 1e-6f); -+ s_variance = rsqrtf(variance / n / 2 + eps); - } - __syncthreads(); - -@@ -93,14 +96,15 @@ __global__ void generalAddBiasResidualLayerNormOpt(T* normed_output, - - // * Note that typename T is half2 or bfloat2 type - template --__global__ void generalAddBiasResidualLayerNormOpt2(T* normed_output, -- T* output, -+__global__ void generalAddBiasResidualLayerNormOpt2(T* normed_output,//out -+ T* output,//out - const T* __restrict bias, -- const T* __restrict residual, -+ const T* __restrict residual,//input - const T* __restrict gamma, - const T* __restrict beta, - int m, -- int n) -+ int n, -+ float eps) - { - __shared__ float s_mean; - __shared__ float s_variance; -@@ -108,7 +112,6 @@ __global__ void generalAddBiasResidualLayerNormOpt2(T* normed_output, - float x2_sum = 0.0f; - const int b_offset = blockIdx.x * n; - using T1 = typename TypeConverter::Type; -- - #pragma unroll UNROLL_FACTOR - for (int i = threadIdx.x; i < n; i += blockDim.x) { - const int index = b_offset + i; -@@ -145,7 +148,7 @@ __global__ void generalAddBiasResidualLayerNormOpt2(T* normed_output, - - if (threadIdx.x == 0) { - s_mean = sums[0] / n / 2; -- s_variance = rsqrtf(sums[1] / n / 2 - s_mean * s_mean + 1e-6f); -+ s_variance = rsqrtf(sums[1] / n / 2 - s_mean * s_mean + eps); - } - __syncthreads(); - -@@ -166,7 +169,7 @@ __global__ void generalAddBiasResidualLayerNormOpt2(T* normed_output, - // TODO(bhsueh) add half2 implementation - template - __global__ void --addBiasResidualPostLayerNorm(T* out, const T* input, const T* bias, const T* gamma, const T* beta, int m, int n) -+addBiasResidualPostLayerNorm(T* out, const T* input, const T* bias, const T* gamma, const T* beta, int m, int n, float eps) - { - __shared__ float s_mean; - __shared__ float s_variance; -@@ -197,7 +200,7 @@ addBiasResidualPostLayerNorm(T* out, const T* input, const T* bias, const T* gam - } - variance = blockReduceSum(variance); - if (threadIdx.x == 0) { -- s_variance = variance / n + 1e-6f; -+ s_variance = variance / n + eps; - } - __syncthreads(); - -@@ -209,10 +212,62 @@ addBiasResidualPostLayerNorm(T* out, const T* input, const T* bias, const T* gam - idx += blockDim.x; - } - } -+template -+__global__ void addBiasResidualPostLayerNormCast(S* attn_output, -+ D* norm_attn_out, -+ const S* __restrict input, -+ const T* __restrict bias, -+ const T* __restrict gamma, -+ const T* __restrict beta, -+ int m, -+ int n, -+ float eps) -+{ -+ __shared__ float s_mean; -+ 
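The kernels in this hunk share one shape: add bias and residual in the compute type, reduce the per-row mean and variance, normalize with the now-configurable eps (replacing the hard-coded 1e-6f), and optionally store the normalized copy in a narrower type. Below is a self-contained sketch of that pattern with a plain shared-memory reduction instead of blockReduceSum; it assumes one block per row, a power-of-two blockDim.x (e.g. 256 with 256 * sizeof(float) dynamic shared memory), and covers only the float-in/half-out case. All names are illustrative.

#include <cuda_fp16.h>

__global__ void addBiasResidualLayerNormCastSketch(float* out, __half* norm_out,
                                                   const float* input, const float* bias,
                                                   const float* gamma, const float* beta,
                                                   int cols, float eps)
{
    extern __shared__ float red[];                       // blockDim.x floats for reductions
    const int row = blockIdx.x;
    float local_sum = 0.0f;
    for (int i = threadIdx.x; i < cols; i += blockDim.x) {
        float v = out[row * cols + i] + input[row * cols + i] + (bias ? bias[i] : 0.0f);
        out[row * cols + i] = v;                         // keep the residual sum in float
        local_sum += v;
    }
    red[threadIdx.x] = local_sum;                        // block reduction for the mean
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) red[threadIdx.x] += red[threadIdx.x + s];
        __syncthreads();
    }
    const float mean = red[0] / cols;
    __syncthreads();
    float local_var = 0.0f;
    for (int i = threadIdx.x; i < cols; i += blockDim.x) {
        float d = out[row * cols + i] - mean;
        local_var += d * d;
    }
    red[threadIdx.x] = local_var;                        // block reduction for the variance
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) red[threadIdx.x] += red[threadIdx.x + s];
        __syncthreads();
    }
    const float inv_std = rsqrtf(red[0] / cols + eps);   // eps instead of a hard-coded 1e-6f
    for (int i = threadIdx.x; i < cols; i += blockDim.x) {
        float n = (out[row * cols + i] - mean) * inv_std * gamma[i] + (beta ? beta[i] : 0.0f);
        norm_out[row * cols + i] = __float2half(n);      // only the normalized copy is downcast
    }
}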
__shared__ float s_variance; -+ float mean = 0.0f; -+ float variance = 0.0f; -+ float local_out_cache[N]; -+ -+#pragma unroll N -+ for (int idx = threadIdx.x, i = 0; idx < n && i < N; ++i) { -+ float local_out = (float)((T)attn_output[blockIdx.x * n + idx] + (T)input[blockIdx.x * n + idx] + (T)__ldg(&bias[idx])); -+ mean += local_out; -+ // save local_out to local_out_cache to save some recompute -+ local_out_cache[i] = local_out; -+ idx += blockDim.x; -+ } -+ -+ mean = blockReduceSum(mean); -+ if (threadIdx.x == 0) { -+ s_mean = mean / n; -+ } -+ __syncthreads(); -+ -+#pragma unroll N -+ for (int idx = threadIdx.x, i = 0; idx < n && i < N; ++i) { -+ float local_out = local_out_cache[i]; -+ variance += (local_out - s_mean) * (local_out - s_mean); -+ idx += blockDim.x; -+ } -+ variance = blockReduceSum(variance); -+ if (threadIdx.x == 0) { -+ s_variance = variance / n + eps; -+ } -+ __syncthreads(); -+ -+#pragma unroll N -+ for (int idx = threadIdx.x, i = 0; idx < n && i < N; ++i) { -+ float local_out = local_out_cache[i]; -+ norm_attn_out[blockIdx.x * n + idx] = -+ (D)(((local_out - s_mean) * rsqrtf(s_variance)) * (float)(__ldg(&gamma[idx])) + (float)(__ldg(&beta[idx]))); -+ idx += blockDim.x; -+ } -+} - - template - __global__ void addBiasResidualPostLayerNormHalf( -- half* out, const half* input, const half* bias, const half* gamma, const half* beta, int m, int n) -+ half* out, const half* input, const half* bias, const half* gamma, const half* beta, int m, int n, float eps) - { - __shared__ float s_mean; - __shared__ float s_variance; -@@ -255,7 +310,7 @@ __global__ void addBiasResidualPostLayerNormHalf( - - variance = blockReduceSum(variance); - if (threadIdx.x == 0) { -- s_variance = rsqrtf(variance / n + 1e-6f); -+ s_variance = rsqrtf(variance / n + eps); - } - __syncthreads(); - -@@ -274,7 +329,7 @@ __global__ void addBiasResidualPostLayerNormHalf( - - template - __global__ void --generalAddBiasResidualPostLayerNorm(T* out, const T* input, const T* bias, const T* gamma, const T* beta, int m, int n) -+generalAddBiasResidualPostLayerNorm(T* out, const T* input, const T* bias, const T* gamma, const T* beta, int m, int n, float eps) - { - __shared__ float s_mean; - __shared__ float s_variance; -@@ -300,7 +355,7 @@ generalAddBiasResidualPostLayerNorm(T* out, const T* input, const T* bias, const - } - variance = blockReduceSum(variance); - if (threadIdx.x == 0) { -- s_variance = variance / n + 1e-6f; -+ s_variance = variance / n + eps; - } - __syncthreads(); - -@@ -311,9 +366,55 @@ generalAddBiasResidualPostLayerNorm(T* out, const T* input, const T* bias, const - } - } - -+template -+__global__ void generalAddBiasResidualPostLayerNormCast(S* attn_output, -+ D* norm_attn_out, -+ const S* __restrict input, -+ const T* __restrict bias, -+ const T* __restrict gamma, -+ const T* __restrict beta, -+ int m, -+ int n, -+ float eps) -+{ -+ __shared__ float s_mean; -+ __shared__ float s_variance; -+ float mean = 0.0f; -+ float variance = 0.0f; -+ -+ for (int idx = threadIdx.x; idx < n; idx += blockDim.x) { -+ float local_out = (float)((T)attn_output[blockIdx.x * n + idx] + (T)input[blockIdx.x * n + idx] + (T)__ldg(&bias[idx])); -+ mean += local_out; -+ // save local_out to out to save some recompute -+ attn_output[blockIdx.x * n + idx] = (T)local_out; -+ } -+ -+ mean = blockReduceSum(mean); -+ if (threadIdx.x == 0) { -+ s_mean = mean / n; -+ } -+ __syncthreads(); -+ -+ for (int idx = threadIdx.x; idx < n; idx += blockDim.x) { -+ float local_out = (T)attn_output[blockIdx.x * n + idx]; -+ variance += 
(local_out - s_mean) * (local_out - s_mean); -+ } -+ variance = blockReduceSum(variance); -+ if (threadIdx.x == 0) { -+ s_variance = variance / n + eps; -+ } -+ __syncthreads(); -+ -+ for (int idx = threadIdx.x; idx < n; idx += blockDim.x) { -+ float local_out = attn_output[blockIdx.x * n + idx]; -+ norm_attn_out[blockIdx.x * n + idx] = -+ (D)(((local_out - s_mean) * rsqrtf(s_variance)) * (float)(__ldg(&gamma[idx])) + (float)(__ldg(&beta[idx]))); -+ } -+} -+ - template<> - __global__ void generalAddBiasResidualPostLayerNorm( -- half* out, const half* input, const half* bias, const half* gamma, const half* beta, int m, int n) -+ half* out, const half* input, const half* bias, const half* gamma, const half* beta, int m, int n, float eps) - { - __shared__ float s_mean; - __shared__ float s_variance; -@@ -352,7 +453,7 @@ __global__ void generalAddBiasResidualPostLayerNorm( - - variance = blockReduceSum(variance); - if (threadIdx.x == 0) { -- s_variance = rsqrtf(variance / n + 1e-6f); -+ s_variance = rsqrtf(variance / n + eps); - } - __syncthreads(); - -@@ -373,7 +474,8 @@ __global__ void addBiasResidualPostLayerNormV2(T* out, - const T* __restrict bias, - const T* __restrict gamma, - const T* __restrict beta, -- int n) -+ int n, -+ float eps) - { - const int ite = 4; - const int tid = threadIdx.x; -@@ -409,7 +511,7 @@ __global__ void addBiasResidualPostLayerNormV2(T* out, - - variance = blockReduceSum(var); - if (tid == 0) { -- s_variance = rsqrtf(variance / n + 1e-6f); -+ s_variance = rsqrtf(variance / n + eps); - } - __syncthreads(); - -@@ -428,7 +530,8 @@ __global__ void addBiasResidualPostLayerNormV2(half* out, - const half* __restrict bias, - const half* __restrict gamma, - const half* __restrict beta, -- int n) -+ int n, -+ float eps) - { - const int ite = 4; - const int tid = threadIdx.x; -@@ -473,7 +576,7 @@ __global__ void addBiasResidualPostLayerNormV2(half* out, - - variance = blockReduceSum(var); - if (threadIdx.x == 0) { -- s_variance = rsqrtf(variance / n + 1e-6f); -+ s_variance = rsqrtf(variance / n + eps); - } - __syncthreads(); - -@@ -486,26 +589,154 @@ __global__ void addBiasResidualPostLayerNormV2(half* out, - } - } - -+template -+__global__ void addBiasResidualPostLayerNormV2Cast(S* attn_output, -+ D* norm_attn_out, -+ const S* __restrict input, -+ const T* __restrict bias, -+ const T* __restrict gamma, -+ const T* __restrict beta, -+ int n, -+ float eps) -+{ -+ const int ite = 4; -+ const int tid = threadIdx.x; -+ const int bid = blockIdx.x; -+ -+ __shared__ float s_mean; -+ __shared__ float s_variance; -+ float mean = 0.0f; -+ float variance = 0.0f; -+ float local_out[ite]; -+ -+ float sum = 0.0f; -+#pragma unroll -+ for (int i = 0; i < ite; i++) { -+ int col_id = i * blockDim.x + tid; -+ int id = bid * n + col_id; -+ local_out[i] = (float)((T)(attn_output[id]) + (T)__ldg(&input[id]) + (T)__ldg(&bias[col_id])); -+ sum += local_out[i]; -+ } -+ -+ mean = blockReduceSum(sum); -+ if (tid == 0) { -+ s_mean = mean / n; -+ } -+ __syncthreads(); -+ -+ float var = 0.0f; -+#pragma unroll -+ for (int i = 0; i < ite; i++) { -+ float diff = local_out[i] - s_mean; -+ var += diff * diff; -+ } -+ -+ variance = blockReduceSum(var); -+ if (tid == 0) { -+ s_variance = rsqrtf(variance / n + eps); -+ } -+ __syncthreads(); -+ -+#pragma unroll -+ for (int i = 0; i < ite; i++) { -+ int col_id = i * blockDim.x + tid; -+ int id = bid * n + col_id; -+ norm_attn_out[id] = -+ (D)((local_out[i] - s_mean) * s_variance * (float)__ldg(&gamma[col_id]) + (float)__ldg(&beta[col_id])); -+ } -+} -+ -+template 
-+void invokeAddBiasResidualLayerNormCast( -+ S* attn_output, D* norm_attn_out, const S* input, const T* bias, const T* gamma, const T* beta, int m, int n, cudaStream_t stream, float eps) -+{ -+ dim3 grid(m); -+ dim3 block(std::min(n, 1024)); -+ if (n == 768 || n == 1024) { -+ addBiasResidualPostLayerNormV2Cast<<>>(attn_output, norm_attn_out, input, bias, gamma, beta, n, eps); -+ } -+ else { -+ block.x = std::min(n, 1024); -+ int num_trips = (n + block.x - 1) / block.x; -+ if (num_trips == 1) { -+ addBiasResidualPostLayerNormCast<<>>(attn_output, norm_attn_out, input, bias, gamma, beta, m, n, eps); -+ } -+ else if (num_trips == 2) { -+ addBiasResidualPostLayerNormCast<<>>(attn_output, norm_attn_out, input, bias, gamma, beta, m, n, eps); -+ } -+ else { -+ generalAddBiasResidualPostLayerNormCast<<>>(attn_output, norm_attn_out, input, bias, gamma, beta, m, n, eps); -+ } -+ } -+} -+ -+ -+template void invokeAddBiasResidualLayerNormCast(float* out, half* norm_attn_out, -+ const float* input, -+ const float* bias, -+ const float* gamma, -+ const float* beta, -+ int m, -+ int n, -+ cudaStream_t stream, -+ float eps); -+ -+template void invokeAddBiasResidualLayerNormCast(half* out, float* norm_attn_out, -+ const half* input, -+ const float* bias, -+ const float* gamma, -+ const float* beta, -+ int m, -+ int n, -+ cudaStream_t stream, -+ float eps); -+ -+ -+template void invokeGeneralAddBiasResidualPreLayerNormCast( -+ float* attn_output, -+ half* norm_output, -+ const float* from_tensor, -+ const float* gamma, -+ const float* beta, -+ const float* bias, -+ int m, -+ int n, -+ cudaStream_t stream, -+ float eps, -+ int opt_version); -+ -+template void invokeGeneralAddBiasResidualT5PreLayerNormCast( -+ float* attn_output, -+ half* norm_output, -+ const float* from_tensor, -+ const float* gamma, -+ const float* beta, -+ const float* bias, -+ int m, -+ int n, -+ cudaStream_t stream, -+ float eps, -+ int opt_version); - template - void invokeAddBiasResidualLayerNorm( -- T* out, const T* input, const T* bias, const T* gamma, const T* beta, int m, int n, cudaStream_t stream) -+ T* out, const T* input, const T* bias, const T* gamma, const T* beta, int m, int n, cudaStream_t stream, float eps) - { - dim3 grid(m); - dim3 block(std::min(n, 1024)); - if (n == 768 || n == 1024) { -- addBiasResidualPostLayerNormV2<<>>(out, input, bias, gamma, beta, n); -+ addBiasResidualPostLayerNormV2<<>>(out, input, bias, gamma, beta, n, eps); - } - else { - block.x = std::min(n, 1024); - int num_trips = (n + block.x - 1) / block.x; - if (num_trips == 1) { -- addBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n); -+ addBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n, eps); - } - else if (num_trips == 2) { -- addBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n); -+ addBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n, eps); - } - else { -- generalAddBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n); -+ generalAddBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n, eps); - } - } - } -@@ -518,25 +749,26 @@ void invokeAddBiasResidualLayerNorm(half* out, - const half* beta, - int m, - int n, -- cudaStream_t stream) -+ cudaStream_t stream, -+ float eps) - { - dim3 grid(m); - dim3 block(std::min(n, 1024)); - - if (m >= 512 && (n == 768 || n == 1024)) { -- addBiasResidualPostLayerNormV2<<>>(out, input, bias, gamma, beta, n); -+ addBiasResidualPostLayerNormV2<<>>(out, input, bias, gamma, beta, n, eps); - } - else { - block.x = std::min(n, 
1024); - int num_trips = (n + block.x - 1) / block.x; - if (num_trips == 1) { -- addBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n); -+ addBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n, eps); - } - else if (num_trips == 2) { -- addBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n); -+ addBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n, eps); - } - else { -- generalAddBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n); -+ generalAddBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n, eps); - } - } - } -@@ -548,7 +780,8 @@ template void invokeAddBiasResidualLayerNorm(float* out, - const float* beta, - int m, - int n, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps); - template void invokeAddBiasResidualLayerNorm(half* out, - const half* input, - const half* bias, -@@ -556,7 +789,8 @@ template void invokeAddBiasResidualLayerNorm(half* out, - const half* beta, - int m, - int n, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps); - - template - __global__ void generalAddBiasResidualLayerNorm(const T* __restrict input, -@@ -566,7 +800,8 @@ __global__ void generalAddBiasResidualLayerNorm(const T* __restrict input, - T* output, - T* norm_output, - int m, -- int n) -+ int n, -+ float eps) - { - int tid = threadIdx.x; - -@@ -601,7 +836,7 @@ __global__ void generalAddBiasResidualLayerNorm(const T* __restrict input, - variance = blockReduceSum(local_var_sum); - - if (threadIdx.x == 0) { -- s_variance = rsqrtf(variance / n + 1e-6f); -+ s_variance = rsqrtf(variance / n + eps); - } - __syncthreads(); - -@@ -612,6 +847,89 @@ __global__ void generalAddBiasResidualLayerNorm(const T* __restrict input, - } - } - -+template -+__global__ void generalAddBiasResidualLayerNormCast(const T* __restrict input, -+ const T* __restrict gamma, -+ const T* __restrict beta, -+ const T* __restrict bias, -+ T* output, -+ S* norm_output, -+ int m, -+ int n, -+ float eps) -+{ -+ int tid = threadIdx.x; -+ -+ __shared__ float s_mean; -+ __shared__ float s_variance; -+ float mean = 0.0f; -+ float variance = 0.0f; -+ -+ float local_sum = 0.0f; -+ for (int i = tid; i < n; i += blockDim.x) { -+ float local_out = (float)(ldg(&input[blockIdx.x * n + i])); -+ local_out += (float)((T)output[blockIdx.x * n + i]); -+ if (bias != nullptr) { -+ local_out += (float)(ldg(&bias[i])); -+ } -+ output[blockIdx.x * n + i] = (T)local_out; -+ local_sum += local_out; -+ } -+ -+ mean = blockReduceSum(local_sum); -+ -+ if (threadIdx.x == 0) { -+ s_mean = mean / n; -+ } -+ __syncthreads(); -+ -+ float local_var_sum = 0.0f; -+ for (int i = tid; i < n; i += blockDim.x) { -+ float diff = (float)(output[blockIdx.x * n + i]) - s_mean; -+ local_var_sum += diff * diff; -+ } -+ variance = blockReduceSum(local_var_sum); -+ -+ if (threadIdx.x == 0) { -+ s_variance = rsqrtf(variance / n + eps); -+ } -+ __syncthreads(); -+ -+ for (int i = tid; i < n; i += blockDim.x) { -+ float beta_val = (beta == nullptr) ? 
0.0f : (float)(ldg(&beta[i])); -+ norm_output[blockIdx.x * n + i] = -+ (S)((((float)output[blockIdx.x * n + i] - s_mean) * s_variance) * (float)(ldg(&gamma[i])) + beta_val); -+ } -+} -+template -+__global__ void generalAddBiasResidualT5LayerNormCast(const T* __restrict input, -+ const T* __restrict gamma, -+ T* output, -+ S* norm_output, -+ int m, -+ int n, -+ float eps) -+{ -+ __shared__ float s_variance; -+ float variance = 0.0f; -+ float local_var_sum = 0.0f; -+ for (int i = threadIdx.x; i < n; i += blockDim.x) { -+ output[blockIdx.x * n + i] = -+ clamp_inf_for_half((float)__ldg(&input[blockIdx.x * n + i]) + (float)output[blockIdx.x * n + i]); -+ float diff = (float)(output[blockIdx.x * n + i]); -+ local_var_sum += diff * diff; -+ } -+ variance = blockReduceSum(local_var_sum); -+ -+ if (threadIdx.x == 0) { -+ s_variance = rsqrtf(variance / n + eps); -+ } -+ __syncthreads(); -+ for (int i = threadIdx.x; i < n; i += blockDim.x) { -+ norm_output[blockIdx.x * n + i] = -+ (S)(clamp_inf_for_half((((float)output[blockIdx.x * n + i]) * s_variance) * (float)(__ldg(&gamma[i])))); -+ } -+} - #define HALF_LAYERNORM_OPT(UNROLL_FACTOR) \ - generalAddBiasResidualLayerNormOpt \ - <<>>((T2*)norm_output, \ -@@ -621,7 +939,8 @@ __global__ void generalAddBiasResidualLayerNorm(const T* __restrict input, - (const T2*)gamma, \ - (const T2*)beta, \ - m, \ -- half_n); -+ half_n, \ -+ eps); - - #define HALF_LAYERNORM_OPT2(UNROLL_FACTOR) \ - generalAddBiasResidualLayerNormOpt2 \ -@@ -632,7 +951,8 @@ __global__ void generalAddBiasResidualLayerNorm(const T* __restrict input, - (const T2*)gamma, \ - (const T2*)beta, \ - m, \ -- half_n); -+ half_n, \ -+ eps); - - template - void invokeGeneralAddBiasResidualPreLayerNorm(T* output, -@@ -644,6 +964,7 @@ void invokeGeneralAddBiasResidualPreLayerNorm(T* output, - int m, - int n, - cudaStream_t stream, -+ float eps, - int opt_version) - { - if (opt_version > 0 && sizeof(T) == 2 && n % 2 == 0) { -@@ -709,8 +1030,65 @@ void invokeGeneralAddBiasResidualPreLayerNorm(T* output, - - /* should pay attention to the rsqrt precision*/ - generalAddBiasResidualLayerNorm -- <<>>(input, gamma, beta, bias, output, norm_output, m, n); // For gpt-3 -+ <<>>(input, gamma, beta, bias, output, norm_output, m, n, eps); // For gpt-3 -+ } -+} -+ -+template -+void invokeGeneralAddBiasResidualPreLayerNormCast(T* attn_output, -+ S* norm_output, -+ const T* from_tensor, -+ const T* gamma, -+ const T* beta, -+ const T* bias, -+ int m, -+ int n, -+ cudaStream_t stream, -+ float eps, -+ int opt_version) -+{ -+ dim3 grid(m); -+ dim3 block(min(n, 1024)); -+ -+ /* For general cases, n is equal to hidden_units, e.g., 512/1024. -+ Since we have warp shuffle inside the code, block.x % 32 should be 0. -+ */ -+ -+ if (n % 32 != 0) { -+ block.x = 1024; -+ } -+ -+ /* should pay attention to the rsqrt precision*/ -+ generalAddBiasResidualLayerNormCast -+ <<>>(from_tensor, gamma, beta, bias, attn_output, norm_output, m, n, eps); // For gpt-3 -+} -+ -+template -+void invokeGeneralAddBiasResidualT5PreLayerNormCast(T* attn_output, -+ S* norm_output, -+ const T* from_tensor, -+ const T* gamma, -+ int m, -+ int n, -+ cudaStream_t stream, -+ float eps, -+ int opt_version) -+{ -+ -+ dim3 grid(m); -+ dim3 block(min(n, 1024)); -+ -+ /* For general cases, n is equal to hidden_units, e.g., 512/1024. -+ Since we have warp shuffle inside the code, block.x % 32 should be 0. 
-+ */ -+ -+ if (n % 32 != 0) { -+ block.x = 1024; - } -+ -+ /* should pay attention to the rsqrt precision*/ -+ generalAddBiasResidualT5LayerNormCast -+ <<>>(from_tensor, gamma, attn_output, norm_output, m, n, eps); // For gpt-3 - } - - #undef HALF_LAYERNORM_OPT -@@ -725,6 +1103,7 @@ template void invokeGeneralAddBiasResidualPreLayerNorm(float* output, - int m, - int n, - cudaStream_t stream, -+ float eps, - int opt_version); - - template void invokeGeneralAddBiasResidualPreLayerNorm(half* output, -@@ -736,6 +1115,7 @@ template void invokeGeneralAddBiasResidualPreLayerNorm(half* output, - int m, - int n, - cudaStream_t stream, -+ float eps, - int opt_version); - - #ifdef ENABLE_BF16 -@@ -748,12 +1128,13 @@ template void invokeGeneralAddBiasResidualPreLayerNorm(__nv_bfloat16* output, - int m, - int n, - cudaStream_t stream, -+ float eps, - int opt_version); - #endif - - template - __global__ void generalAddResidualT5LayerNorm( -- const T* __restrict input, const T* __restrict gamma, T* output, T* norm_output, int m, int n) -+ const T* __restrict input, const T* __restrict gamma, T* output, T* norm_output, int m, int n, float eps) - { - // layernorm module in the T5 style No bias and no subtraction of mean. - __shared__ float s_variance; -@@ -770,7 +1151,7 @@ __global__ void generalAddResidualT5LayerNorm( - variance = blockReduceSum(local_var_sum); - - if (threadIdx.x == 0) { -- s_variance = rsqrtf(variance / n + 1e-6f); -+ s_variance = rsqrtf(variance / n + eps); - } - __syncthreads(); - -@@ -783,7 +1164,7 @@ __global__ void generalAddResidualT5LayerNorm( - - template - void invokeGeneralAddResidualT5PreLayerNorm( -- T* output, T* norm_output, const T* input, const T* gamma, int m, int n, cudaStream_t stream) -+ T* output, T* norm_output, const T* input, const T* gamma, int m, int n, cudaStream_t stream, float eps) - { - dim3 grid(m); - dim3 block(min(n, 1024)); -@@ -799,14 +1180,14 @@ void invokeGeneralAddResidualT5PreLayerNorm( - block.x = block.x / (4 / sizeof(T)); // if using half, only need half of block.x - - /* should pay attention to the rsqrt precision*/ -- generalAddResidualT5LayerNorm<<>>(input, gamma, output, norm_output, m, n); -+ generalAddResidualT5LayerNorm<<>>(input, gamma, output, norm_output, m, n, eps); - } - - template void invokeGeneralAddResidualT5PreLayerNorm( -- float* output, float* norm_output, const float* input, const float* gamma, int m, int n, cudaStream_t stream); -+ float* output, float* norm_output, const float* input, const float* gamma, int m, int n, cudaStream_t stream, float eps); - - template void invokeGeneralAddResidualT5PreLayerNorm( -- half* output, half* norm_output, const half* input, const half* gamma, int m, int n, cudaStream_t stream); -+ half* output, half* norm_output, const half* input, const half* gamma, int m, int n, cudaStream_t stream, float eps); - - template - void invokeGeneralAddBiasResidualT5PreLayerNorm(T* output, -@@ -817,17 +1198,39 @@ void invokeGeneralAddBiasResidualT5PreLayerNorm(T* output, - const T* bias, - int m, - int n, -- cudaStream_t stream) -+ cudaStream_t stream, -+ float eps) - { -- if (beta != nullptr && bias != nullptr) { -- invokeGeneralAddBiasResidualPreLayerNorm(output, norm_output, input, gamma, beta, bias, m, n, stream); -+ if (beta != nullptr || bias != nullptr) { -+ invokeGeneralAddBiasResidualPreLayerNorm(output, norm_output, input, gamma, beta, bias, m, n, stream, eps); - } - else { -- invokeGeneralAddResidualT5PreLayerNorm(output, norm_output, input, gamma, m, n, stream); -+ 
invokeGeneralAddResidualT5PreLayerNorm(output, norm_output, input, gamma, m, n, stream, eps); - } - return; - } - -+template -+void invokeGeneralAddBiasResidualT5PreLayerNormCast(T* attn_output, -+ S* norm_output, -+ const T* from_tensor, -+ const T* gamma, -+ const T* beta, -+ const T* bias, -+ int m, -+ int n, -+ cudaStream_t stream, -+ float eps, -+ int opt_version) -+{ -+ if (beta != nullptr || bias != nullptr) { -+ invokeGeneralAddBiasResidualPreLayerNormCast(attn_output, norm_output, from_tensor, gamma, beta, bias, m, n, stream, eps); -+ } -+ else { -+ invokeGeneralAddBiasResidualT5PreLayerNormCast(attn_output, norm_output, from_tensor, gamma, m, n, stream, eps, opt_version); -+ } -+ return; -+} - template void invokeGeneralAddBiasResidualT5PreLayerNorm(float* output, - float* norm_output, - const float* input, -@@ -836,7 +1239,8 @@ template void invokeGeneralAddBiasResidualT5PreLayerNorm(float* output, - const float* bias, - int m, - int n, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps); - - template void invokeGeneralAddBiasResidualT5PreLayerNorm(half* output, - half* norm_output, -@@ -846,11 +1250,12 @@ template void invokeGeneralAddBiasResidualT5PreLayerNorm(half* output, - const half* bias, - int m, - int n, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps); - - template - __global__ void generalLayerNorm( -- const T* __restrict input, const T* __restrict gamma, const T* __restrict beta, T* output, int m, int n) -+ const T* __restrict input, const T* __restrict gamma, const T* __restrict beta, T* output, int m, int n, float eps) - { - const int tid = threadIdx.x; - -@@ -879,10 +1284,9 @@ __global__ void generalLayerNorm( - variance = blockReduceSum(local_var_sum); - - if (threadIdx.x == 0) { -- s_variance = rsqrtf(variance / n + 1e-6f); -+ s_variance = rsqrtf(variance / n + eps); - } - __syncthreads(); -- - for (int i = tid; i < n; i += blockDim.x) { - float beta_val = (beta == nullptr) ? 
0.0f : (float)ldg(&beta[i]); - output[blockIdx.x * n + i] = -@@ -892,11 +1296,11 @@ __global__ void generalLayerNorm( - - #define HALF_LAYERNORM_OPT(UNROLL_FACTOR) \ - generalAddBiasResidualLayerNormOpt<<>>( \ -- (T2*)out, (T2*)out, nullptr, (const T2*)input, (const T2*)gamma, (const T2*)beta, m, half_n); -+ (T2*)out, (T2*)out, nullptr, (const T2*)input, (const T2*)gamma, (const T2*)beta, m, half_n, eps); - - #define HALF_LAYERNORM_OPT2(UNROLL_FACTOR) \ - generalAddBiasResidualLayerNormOpt2<<>>( \ -- (T2*)out, (T2*)out, nullptr, (const T2*)input, (const T2*)gamma, (const T2*)beta, m, half_n); -+ (T2*)out, (T2*)out, nullptr, (const T2*)input, (const T2*)gamma, (const T2*)beta, m, half_n, eps); - - template - void invokeGeneralLayerNorm(T* out, -@@ -906,6 +1310,7 @@ void invokeGeneralLayerNorm(T* out, - const int m, - const int n, - cudaStream_t stream, -+ float eps, - int opt_version) - { - dim3 grid(m); -@@ -965,7 +1370,7 @@ void invokeGeneralLayerNorm(T* out, - } - - /* should pay attention to the rsqrt precision*/ -- generalLayerNorm<<>>(input, gamma, beta, out, m, n); // For gpt-3 -+ generalLayerNorm<<>>(input, gamma, beta, out, m, n, eps); // For gpt-3 - } - } - -@@ -979,6 +1384,7 @@ template void invokeGeneralLayerNorm(float* out, - const int m, - const int n, - cudaStream_t stream, -+ float eps, - int opt_version); - template void invokeGeneralLayerNorm(half* out, - const half* input, -@@ -987,6 +1393,7 @@ template void invokeGeneralLayerNorm(half* out, - const int m, - const int n, - cudaStream_t stream, -+ float eps, - int opt_version); - #ifdef ENABLE_BF16 - template void invokeGeneralLayerNorm(__nv_bfloat16* out, -@@ -996,11 +1403,12 @@ template void invokeGeneralLayerNorm(__nv_bfloat16* out, - const int m, - const int n, - cudaStream_t stream, -+ float eps, - int opt_version); - #endif - - template --__global__ void generalT5LayerNorm(const T* __restrict input, const T* __restrict gamma, T* output, int m, int n) -+__global__ void generalT5LayerNorm(const T* __restrict input, const T* __restrict gamma, T* output, int m, int n, float eps) - { - // layernorm module in the T5 style No bias and no subtraction of mean. 
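generalT5LayerNorm normalizes by the root mean square only (no bias, no mean subtraction), so a CPU reference for checking the kernel is short. A sketch with illustrative names, using the same row-major [m, n] layout and the eps argument added in this change:

#include <cmath>

void t5LayerNormRef(const float* input, const float* gamma, float* output,
                    int m, int n, float eps)
{
    for (int r = 0; r < m; ++r) {
        double sum_sq = 0.0;                             // accumulate in double for stability
        for (int c = 0; c < n; ++c) {
            const float v = input[r * n + c];
            sum_sq += static_cast<double>(v) * v;
        }
        const float inv_rms = 1.0f / std::sqrt(static_cast<float>(sum_sq / n) + eps);
        for (int c = 0; c < n; ++c) {
            output[r * n + c] = input[r * n + c] * inv_rms * gamma[c];   // no beta, no mean
        }
    }
}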
- const int tid = threadIdx.x; -@@ -1016,7 +1424,7 @@ __global__ void generalT5LayerNorm(const T* __restrict input, const T* __restric - variance = blockReduceSum(local_var_sum); - - if (threadIdx.x == 0) { -- s_variance = rsqrtf(variance / n + 1e-6f); -+ s_variance = rsqrtf(variance / n + eps); - } - __syncthreads(); - -@@ -1028,10 +1436,10 @@ __global__ void generalT5LayerNorm(const T* __restrict input, const T* __restric - - template - void invokeGeneralT5LayerNorm( -- T* out, const T* input, const T* gamma, const T* beta, const int m, const int n, cudaStream_t stream) -+ T* out, const T* input, const T* gamma, const T* beta, const int m, const int n, cudaStream_t stream, float eps) - { - if (beta != nullptr) { -- invokeGeneralLayerNorm(out, input, gamma, beta, m, n, stream); -+ invokeGeneralLayerNorm(out, input, gamma, beta, m, n, stream, eps); - return; - } - -@@ -1048,7 +1456,7 @@ void invokeGeneralT5LayerNorm( - block.x = block.x / (4 / sizeof(T)); // if using half, only need half of block.x - - /* should pay attention to the rsqrt precision*/ -- generalT5LayerNorm<<>>(input, gamma, out, m, n); // For gpt-3 -+ generalT5LayerNorm<<>>(input, gamma, out, m, n, eps); // For gpt-3 - } - - template void invokeGeneralT5LayerNorm(float* out, -@@ -1057,9 +1465,10 @@ template void invokeGeneralT5LayerNorm(float* out, - const float* beta, - const int m, - const int n, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps); - template void invokeGeneralT5LayerNorm( -- half* out, const half* input, const half* gamma, const half* beta, const int m, const int n, cudaStream_t stream); -+ half* out, const half* input, const half* gamma, const half* beta, const int m, const int n, cudaStream_t stream, float eps); - - /******************* invokeLayernormShiftPartition ***********************/ - -@@ -1073,7 +1482,8 @@ __global__ void layernorm_shift_partition(T* out, - int W, - int n, - int shift_size, -- int window_size) -+ int window_size, -+ float eps) - { - int tid = threadIdx.x; - const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; -@@ -1102,7 +1512,7 @@ __global__ void layernorm_shift_partition(T* out, - float diff = (tid < n) ? 
(local_out - s_mean) : 0.0f; - variance = blockReduceSum(diff * diff); - if (threadIdx.x == 0) { -- s_variance = variance / n + 1e-6f; -+ s_variance = variance / n + eps; - } - __syncthreads(); - -@@ -1122,7 +1532,8 @@ __global__ void layernorm_shift_partition(half2* out_ptr, - int W, - int n, - int shift_size, -- int window_size) -+ int window_size, -+ float eps) - { - const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; - const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; -@@ -1161,7 +1572,7 @@ __global__ void layernorm_shift_partition(half2* out_ptr, - } - variance = blockReduceSum(variance); - if (threadIdx.x == 0) { -- s_variance = rsqrtf(variance / (n * 2) + 1e-6f); -+ s_variance = rsqrtf(variance / (n * 2) + eps); - } - __syncthreads(); - -@@ -1184,7 +1595,8 @@ __global__ void layernorm_shift_partition_v2(T* out, - int W, - int n, - int shift_size, -- int window_size) -+ int window_size, -+ float eps) - { - const int ite = 4; - const int tid = threadIdx.x; -@@ -1236,7 +1648,7 @@ __global__ void layernorm_shift_partition_v2(T* out, - - variance = blockReduceSum(var); - if (tid == 0) { -- s_variance = rsqrtf(variance / n + 1e-6f); -+ s_variance = rsqrtf(variance / n + eps); - } - __syncthreads(); - -@@ -1260,7 +1672,8 @@ __global__ void layernorm_shift_partition_v2(half2* out_ptr, - int W, - int n, - int shift_size, -- int window_size) -+ int window_size, -+ float eps) - { - const int ite = 4; - const int tid = threadIdx.x; -@@ -1315,7 +1728,7 @@ __global__ void layernorm_shift_partition_v2(half2* out_ptr, - - variance = blockReduceSum(var); - if (threadIdx.x == 0) { -- s_variance = rsqrtf(variance / (n * 2) + 1e-6f); -+ s_variance = rsqrtf(variance / (n * 2) + eps); - } - __syncthreads(); - -@@ -1341,18 +1754,19 @@ void invokeLayernormShiftPartition(T* out, - int n, - int shift_size, - int window_size, -- cudaStream_t stream) -+ cudaStream_t stream, -+ float eps) - { - dim3 grid(W, H, batch); - int blockSize = (n + 31) / 32 * 32; - if (blockSize >= 768) { - blockSize = ((blockSize / 4) + 31) / 32 * 32; - layernorm_shift_partition_v2 -- <<>>(out, input, gamma, beta, batch, H, W, n, shift_size, window_size); -+ <<>>(out, input, gamma, beta, batch, H, W, n, shift_size, window_size, eps); - } - else { - layernorm_shift_partition -- <<>>(out, input, gamma, beta, batch, H, W, n, shift_size, window_size); -+ <<>>(out, input, gamma, beta, batch, H, W, n, shift_size, window_size, eps); - } - } - -@@ -1367,7 +1781,8 @@ void invokeLayernormShiftPartition(half* out, - int n, - int shift_size, - int window_size, -- cudaStream_t stream) -+ cudaStream_t stream, -+ float eps) - { - dim3 grid(W, H, batch); - int blockSize = n / 2; -@@ -1384,7 +1799,8 @@ void invokeLayernormShiftPartition(half* out, - W, - n / 2, - shift_size, -- window_size); -+ window_size, -+ eps); - } - else { - layernorm_shift_partition<<>>((half2*)out, -@@ -1396,7 +1812,8 @@ void invokeLayernormShiftPartition(half* out, - W, - n / 2, - shift_size, -- window_size); -+ window_size, -+ eps); - } - } - -@@ -1410,7 +1827,8 @@ template void invokeLayernormShiftPartition(float* out, - int n, - int shift_size, - int window_size, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps); - - template void invokeLayernormShiftPartition(half* out, - const half* input, -@@ -1422,12 +1840,13 @@ template void invokeLayernormShiftPartition(half* out, - int n, - int shift_size, - int window_size, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps); - - /******************* invokeAddBiasLayernorm 
***********************/ - - template --__global__ void add_bias_layernorm(T* out, const T* bias, const T* gamma, const T* beta, int n) -+__global__ void add_bias_layernorm(T* out, const T* bias, const T* gamma, const T* beta, int n, float eps) - { - int tid = threadIdx.x; - const int bid = blockIdx.x; -@@ -1447,7 +1866,7 @@ __global__ void add_bias_layernorm(T* out, const T* bias, const T* gamma, const - float diff = (tid < n) ? (local_out - s_mean) : 0.0f; - variance = blockReduceSum(diff * diff); - if (threadIdx.x == 0) { -- s_variance = variance / n + 1e-6f; -+ s_variance = variance / n + eps; - } - __syncthreads(); - -@@ -1459,7 +1878,7 @@ __global__ void add_bias_layernorm(T* out, const T* bias, const T* gamma, const - - template - __global__ void --add_bias_layernorm_v2(T* out, const T* __restrict bias, const T* __restrict gamma, const T* __restrict beta, int n) -+add_bias_layernorm_v2(T* out, const T* __restrict bias, const T* __restrict gamma, const T* __restrict beta, int n, float eps) - { - const int ite = 4; - const int tid = threadIdx.x; -@@ -1496,7 +1915,7 @@ add_bias_layernorm_v2(T* out, const T* __restrict bias, const T* __restrict gamm - - variance = blockReduceSum(var); - if (tid == 0) { -- s_variance = rsqrtf(variance / n + 1e-6f); -+ s_variance = rsqrtf(variance / n + eps); - } - __syncthreads(); - -@@ -1512,15 +1931,15 @@ add_bias_layernorm_v2(T* out, const T* __restrict bias, const T* __restrict gamm - - #define HALF_LAYERNORM_OPT(UNROLL_FACTOR) \ - generalAddBiasResidualLayerNormOpt<<>>( \ -- (T2*)out, (T2*)out, (const T2*)bias, (const T2*)out, (const T2*)gamma, (const T2*)beta, m, half_n); -+ (T2*)out, (T2*)out, (const T2*)bias, (const T2*)out, (const T2*)gamma, (const T2*)beta, m, half_n, eps); - - #define HALF_LAYERNORM_OPT2(UNROLL_FACTOR) \ - generalAddBiasResidualLayerNormOpt2<<>>( \ -- (T2*)out, (T2*)out, (const T2*)bias, (const T2*)out, (const T2*)gamma, (const T2*)beta, m, half_n); -+ (T2*)out, (T2*)out, (const T2*)bias, (const T2*)out, (const T2*)gamma, (const T2*)beta, m, half_n, eps); - - template - void invokeAddBiasLayernorm( -- T* out, const T* bias, const T* gamma, const T* beta, int m, int n, cudaStream_t stream, int opt_version) -+ T* out, const T* bias, const T* gamma, const T* beta, int m, int n, cudaStream_t stream, float eps, int opt_version) - { - dim3 grid(m); - if (n % 2 == 0 && std::is_same::value && opt_version > 0) { -@@ -1572,10 +1991,10 @@ void invokeAddBiasLayernorm( - int blockSize = (n + 31) / 32 * 32; - if (blockSize >= 768) { - blockSize = ((blockSize / 4) + 31) / 32 * 32; -- add_bias_layernorm_v2<<>>(out, bias, gamma, beta, n); -+ add_bias_layernorm_v2<<>>(out, bias, gamma, beta, n, eps); - } - else { -- add_bias_layernorm<<>>(out, bias, gamma, beta, n); -+ add_bias_layernorm<<>>(out, bias, gamma, beta, n, eps); - } - } - } -@@ -1590,6 +2009,7 @@ template void invokeAddBiasLayernorm(float* out, - int m, - int n, - cudaStream_t stream, -+ float eps, - int opt_version); - - template void invokeAddBiasLayernorm(half* out, -@@ -1599,6 +2019,7 @@ template void invokeAddBiasLayernorm(half* out, - int m, - int n, - cudaStream_t stream, -+ float eps, - int opt_version); - #ifdef ENABLE_BF16 - template void invokeAddBiasLayernorm<__nv_bfloat16>(__nv_bfloat16* out, -@@ -1608,6 +2029,7 @@ template void invokeAddBiasLayernorm<__nv_bfloat16>(__nv_bfloat16* out, - int m, - int n, - cudaStream_t stream, -+ float eps, - int opt_version); - #endif - -@@ -1625,7 +2047,8 @@ __global__ void merge_layernorm_v2(T* out, - int batch, - int H, - int W, -- 
int n) -+ int n, -+ float eps) - { - const int ite = 4; - const int tid = threadIdx.x; -@@ -1675,7 +2098,7 @@ __global__ void merge_layernorm_v2(T* out, - - variance = blockReduceSum(var); - if (tid == 0) { -- s_variance = rsqrtf(variance / n + 1e-6f); -+ s_variance = rsqrtf(variance / n + eps); - } - __syncthreads(); - -@@ -1693,7 +2116,7 @@ __global__ void merge_layernorm_v2(T* out, - // TODO : accelerate with half2 - template - void invokeMergeLayernorm( -- T* output, const T* input, const T* gamma, const T* beta, int batch, int H, int W, int n, cudaStream_t stream) -+ T* output, const T* input, const T* gamma, const T* beta, int batch, int H, int W, int n, cudaStream_t stream, float eps) - { - if ((W % 2 != 0) || (H % 2 != 0)) { - printf("[ERROR][invokeMergeLayernorm] H(W) should be a multiple of 2.\n"); -@@ -1706,7 +2129,7 @@ void invokeMergeLayernorm( - // if (blockSize >= 768) - { - blockSize = ((blockSize / 4) + 31) / 32 * 32; -- merge_layernorm_v2<<>>(output, input, gamma, beta, batch, H / 2, W / 2, n * 4); -+ merge_layernorm_v2<<>>(output, input, gamma, beta, batch, H / 2, W / 2, n * 4, eps); - } - /* - else -@@ -1722,7 +2145,8 @@ template void invokeMergeLayernorm(float* output, - int H, - int W, - int n, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps); - - template void invokeMergeLayernorm(half* output, - const half* input, -@@ -1732,6 +2156,45 @@ template void invokeMergeLayernorm(half* output, - int H, - int W, - int n, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps); -+ -+ -+ -+ -+ -+ -+__global__ void ToFloat(half* src, float* dst, int element_cnt) { -+ for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < element_cnt; pos += blockDim.x * gridDim.x) { -+ dst[pos] = (float)(src[pos]); -+ } -+} -+ -+__global__ void ToHalf(float* src, half* dst, int element_cnt) { -+ for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < element_cnt; pos += blockDim.x * gridDim.x) { -+ dst[pos] = (half)(src[pos]); -+ } -+} -+ -+__global__ void ToFlaotFromFloat(float* src, float* dst, int element_cnt) { -+ for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < element_cnt; pos += blockDim.x * gridDim.x) { -+ dst[pos] = (src[pos]); -+ -+ } -+} -+ -+void InvokeCast(void* src, void* dst, int element_cnt, int dir, cudaStream_t stream) { -+ dim3 block, grid; -+ -+ block.x = 1024; -+ grid.x = ceil(element_cnt / 1024.); -+ if (dir) { -+ ToFloat<<>>((half*)src, (float*)dst, element_cnt); -+ } else { -+ ToHalf<<>>((float*)src, (half*)dst, element_cnt); -+ } -+ return; -+} -+ - - } // namespace fastertransformer -\ No newline at end of file -diff --git a/src/fastertransformer/kernels/layernorm_kernels.h b/src/fastertransformer/kernels/layernorm_kernels.h -index e8319de..22e8b94 100644 ---- a/src/fastertransformer/kernels/layernorm_kernels.h -+++ b/src/fastertransformer/kernels/layernorm_kernels.h -@@ -42,7 +42,19 @@ void invokeAddBiasResidualLayerNorm(T* out, - const T* beta, - const int m, - const int n, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps = 1e-6f); -+ -+template -+void invokeAddBiasResidualT5LayerNorm(T* out, -+ const T* input, -+ const T* bias, -+ const T* gamma, -+ const T* beta, -+ const int m, -+ const int n, -+ cudaStream_t stream, -+ float eps = 1e-6f); - - template - void invokeGeneralAddBiasResidualPreLayerNorm(T* output, -@@ -54,6 +66,7 @@ void invokeGeneralAddBiasResidualPreLayerNorm(T* output, - int m, - int n, - cudaStream_t stream, -+ float eps = 1e-6f, - int opt_version = 2); - - template -@@ -64,15 +77,16 @@ void 
invokeGeneralLayerNorm(T* out, - const int m, - const int n, - cudaStream_t stream, -+ float eps = 1e-6f, - int opt_version = 2); - - template - void invokeGeneralT5LayerNorm( -- T* out, const T* input, const T* gamma, const T* beta, const int m, const int n, cudaStream_t stream); -+ T* out, const T* input, const T* gamma, const T* beta, const int m, const int n, cudaStream_t stream, float eps = 1e-6f); - - template - void invokeGeneralAddResidualT5PreLayerNorm( -- T* output, T* norm_output, const T* input, const T* gamma, int m, int n, cudaStream_t stream); -+ T* output, T* norm_output, const T* input, const T* gamma, int m, int n, cudaStream_t stream, float eps = 1e-6f); - - template - void invokeGeneralAddBiasResidualT5PreLayerNorm(T* output, -@@ -83,7 +97,8 @@ void invokeGeneralAddBiasResidualT5PreLayerNorm(T* output, - const T* bias, - int m, - int n, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps = 1e-6f); - - template - void invokeLayernormShiftPartition(T* out, -@@ -96,14 +111,49 @@ void invokeLayernormShiftPartition(T* out, - int n, - int shift_size, - int window_size, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps = 1e-6f); -+ -+template -+void invokeGeneralAddBiasResidualPreLayerNormCast(T* attn_output, -+ S* norm_output, -+ const T* from_tensor, -+ const T* gamma, -+ const T* beta, -+ const T* bias, -+ int m, -+ int n, -+ cudaStream_t stream, -+ float eps = 1e-6f, -+ int opt_version = 2); -+ -+template -+void invokeGeneralAddBiasResidualT5PreLayerNormCast(T* attn_output, -+ S* norm_output, -+ const T* from_tensor, -+ const T* gamma, -+ const T* beta, -+ const T* bias, -+ int m, -+ int n, -+ cudaStream_t stream, -+ float eps = 1e-6f, -+ int opt_version = 2); -+ -+template -+void invokeAddBiasResidualLayerNormCast( -+ S* attn_output, D* norm_output, const S* input, const T* bias, const T* gamma, const T* beta, int m, int n, cudaStream_t stream, float eps = 1e-6f); -+template -+void invokeAddBiasResidualT5LayerNormCast( -+ S* attn_output, D* norm_output, const S* input, const T* bias, const T* gamma, const T* beta, int m, int n, cudaStream_t stream, float eps = 1e-6f); - - template - void invokeAddBiasLayernorm( -- T* out, const T* bias, const T* gamma, const T* beta, int m, int n, cudaStream_t stream, int opt_version = 2); -+ T* out, const T* bias, const T* gamma, const T* beta, int m, int n, cudaStream_t stream, float eps = 1e-6f, int opt_version = 2); - - template - void invokeMergeLayernorm( -- T* output, const T* input, const T* gamma, const T* beta, int batch, int H, int W, int n, cudaStream_t stream); -+ T* output, const T* input, const T* gamma, const T* beta, int batch, int H, int W, int n, cudaStream_t stream, float eps = 1e-6f); - -+void InvokeCast(void* src, void* dst, int element_cnt, int dir, cudaStream_t stream); - } // namespace fastertransformer -diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.cu b/src/fastertransformer/kernels/unfused_attention_kernels.cu -index f951e71..f404f45 100644 ---- a/src/fastertransformer/kernels/unfused_attention_kernels.cu -+++ b/src/fastertransformer/kernels/unfused_attention_kernels.cu -@@ -15,6 +15,14 @@ - * limitations under the License. - */ - -+#ifndef CUDART_VERSION -+#error CUDART_VERSION Undefined! 
-+#elif (CUDART_VERSION >= 11050) -+#include -+#else -+#include "3rdparty/cub/cub.cuh" -+#endif -+ - #include "src/fastertransformer/kernels/bfloat16_fallback_kenrels.cuh" - #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" - #include "src/fastertransformer/kernels/reduce_kernel_utils.cuh" -@@ -23,6 +31,24 @@ - - namespace fastertransformer { - -+const int WARP_SIZE = 32; -+const bool ATTENION_OPT = true; -+const int ATTENTION_BLOCK_SIZE = 256; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+using Copy_half_t = typename std::conditional< -+ HALF_ELEMENTS_PER_WARP_LOAD == 32, -+ half, -+ typename std::conditional::type>::type>:: -+ type; -+ -+template -+using Copy_t = Copy_half_t; -+ - __inline__ __device__ int target_index(int id1, int id2, int id3, int id4, int dim_1, int dim_2, int dim_3, int dim_4) - { - return id1 * (dim_2 * dim_3 * dim_4) + id3 * (dim_2 * dim_4) + id2 * dim_4 + id4; -@@ -243,6 +269,77 @@ __global__ void softmax_kernel_v4(T* qk_buf_, - } - } - -+template -+__global__ void softmax_mix_kernel_bias_v4(T* qk_buf_, -+ const T_M* attr_mask, -+ const T* position_bias, -+ const int* d_seq_len, -+ const int* d_seq_len2, -+ const int batch_size, -+ const int head_num, -+ const int seq_len, -+ const int seq_stride, -+ const int trgt_seq_len, -+ const int trgt_seq_stride, -+ const int position_bias_head_num, -+ const T scalar, -+ const bool causal) -+{ -+ T* qk_buf_src = qk_buf_; -+ for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x) { -+ float data[ITEMS_PER_THREAD]; -+ int qk_offset; -+ __shared__ float s_mean, s_max; -+ float local_max = -1e20f; -+ for (int i = 0; blockDim.x * i + threadIdx.x < trgt_seq_len; i++) { -+ qk_offset = -+ ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id) * trgt_seq_len + blockDim.x * i + threadIdx.x; -+ int pos_offset = -+ ((blockIdx.z) * seq_stride + seq_id) * trgt_seq_stride + blockDim.x * i + threadIdx.x; -+ int mask_offset = (blockIdx.y * seq_stride + seq_id) * trgt_seq_stride + blockDim.x * i + threadIdx.x; -+ -+ int pos_offset2 = (seq_id) * trgt_seq_stride + blockDim.x * i + threadIdx.x; -+ int bias_offset = (position_bias_head_num == 1) ? pos_offset2 : pos_offset; -+ float qk = static_cast(qk_buf_src[qk_offset]); -+ float mask_val = (attr_mask != nullptr) ? static_cast(ldg(&attr_mask[mask_offset])) : 1.0f; -+ if (causal) { -+ mask_val = (blockDim.x * i + threadIdx.x <= seq_id && seq_id < d_seq_len[blockIdx.y]) ? 1.0f : 0.0f; -+ } else if (d_seq_len != nullptr) { -+ mask_val = (seq_id < d_seq_len[blockIdx.y] && blockDim.x * i + threadIdx.x < d_seq_len2[blockIdx.y]) ? mask_val : 0.0f; -+ } -+ float bias_val = (position_bias == nullptr) ? 0.0f : static_cast(ldg(&position_bias[bias_offset])); -+ mask_val = (1.0f - mask_val) * -10000.0f; -+ -+ data[i] = qk * static_cast(scalar) + mask_val + bias_val; -+ local_max = fmax(local_max, data[i]); -+ } -+ -+ float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax(local_max); -+ if (threadIdx.x == 0) { -+ s_max = max_val; -+ } -+ __syncthreads(); -+ -+ float local_sum = 0; -+ for (int i = 0; blockDim.x * i + threadIdx.x < trgt_seq_len; i++) { -+ data[i] = __expf(data[i] - s_max); -+ local_sum += data[i]; -+ } -+ float sum_val = blockDim.x <= 32 ? 
warpReduceSum(local_sum) : blockReduceSum(local_sum); -+ if (threadIdx.x == 0) { -+ s_mean = sum_val + 1e-6f; -+ s_mean = __fdividef(1.0f, s_mean); -+ } -+ __syncthreads(); -+ -+ for (int i = 0; blockDim.x * i + threadIdx.x < trgt_seq_len; i++) { -+ qk_offset = -+ ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id) * trgt_seq_len + blockDim.x * i + threadIdx.x; -+ qk_buf_[qk_offset] = (T)(data[i] * s_mean); -+ } -+ } -+} -+ - template - __global__ void softmax_kernel_v4_half2( - T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, const T scalar) -@@ -298,6 +395,89 @@ __global__ void softmax_kernel_v4_half2( - } - } - -+template -+__global__ void softmax_cross_kernel_bias_v4_half2(T* qk_buf_, -+ const T* attr_mask, -+ const T* position_bias, -+ const int* d_seq_len, -+ const int* d_seq_len2, -+ const int batch_size, -+ const int head_num, -+ const int seq_len, -+ const int seq_stride, -+ const int trgt_seq_len, -+ const int trgt_seq_stride, -+ const int position_bias_head_num, -+ const T scalar, -+ const bool causal) -+{ -+ using T2 = typename TypeConverter::Type; -+ T2* qk_buf_half2 = (T2*)qk_buf_; -+ const T2* attr_mask_half2 = (const T2*)attr_mask; -+ const T2* position_bias_half2 = (position_bias == nullptr) ? nullptr : (const T2*)position_bias; -+ const T2 zero = {0, 0}; -+ const T2 one = {1, 1}; -+ for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x) { -+ T2 data[ITEMS_PER_THREAD]; -+ int qk_offset; -+ int pos_offset; -+ int pos_offset2; -+ __shared__ float s_mean, s_max; -+ float local_max = -1e20f; -+ for (int i = 0; blockDim.x * i + threadIdx.x < (trgt_seq_len / 2) && i < ITEMS_PER_THREAD; i++) { -+ qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id) * (trgt_seq_len / 2) + blockDim.x * i -+ + threadIdx.x; -+ pos_offset = ((blockIdx.z) * seq_stride + seq_id) * (trgt_seq_stride / 2) + blockDim.x * i -+ + threadIdx.x; -+ pos_offset2 = (seq_id) * (trgt_seq_stride / 2) + blockDim.x * i + threadIdx.x; -+ int mask_offset = (blockIdx.y * seq_stride + seq_id) * (trgt_seq_stride / 2) + blockDim.x * i + threadIdx.x; -+ int bias_offset = (position_bias_head_num == 1) ? pos_offset2 : pos_offset; -+ -+ T2 qk = qk_buf_half2[qk_offset]; -+ T2 mask_val = (attr_mask_half2!= nullptr) ? ldg(&attr_mask_half2[mask_offset]) : one; -+ if (causal) { -+ mask_val.x = ((mask_offset % (trgt_seq_stride / 2)) * 2 <= seq_id && seq_id < d_seq_len[blockIdx.y]) ? 1 : 0; -+ mask_val.y = (((mask_offset % (trgt_seq_stride / 2)) * 2 + 1) <= seq_id && seq_id < d_seq_len[blockIdx.y]) ? 1 : 0; -+ } else if (d_seq_len != nullptr) { -+ mask_val.x = (seq_id < d_seq_len[blockIdx.y] && (mask_offset % (trgt_seq_stride / 2)) * 2 < d_seq_len2[blockIdx.y]) ? mask_val.x : (T)0; -+ mask_val.y = (seq_id < d_seq_len[blockIdx.y] && ((mask_offset % (trgt_seq_stride / 2)) * 2 + 1) < d_seq_len2[blockIdx.y]) ? mask_val.y : (T)0; -+ } -+ mask_val = hmul2(hsub2(float2type2(1.0f), mask_val), float2type2(-10000.0f)); -+ T2 bias_val = (position_bias_half2 == nullptr) ? zero : (ldg(&position_bias_half2[bias_offset])); -+ -+ data[i] = hadd2(hadd2(hmul2(qk, type2type2(scalar)), mask_val), bias_val); -+ -+ local_max = fmax(local_max, fmax((float)data[i].x, (float)data[i].y)); -+ } -+ -+ float max_val = blockDim.x <= 32 ? 
warpReduceMax(local_max) : blockReduceMax(local_max); -+ if (threadIdx.x == 0) { -+ s_max = max_val; -+ } -+ __syncthreads(); -+ -+ float local_sum = 0; -+ for (int i = 0; blockDim.x * i + threadIdx.x < (trgt_seq_len / 2) && i < ITEMS_PER_THREAD; i++) { -+ data[i] = hexp2(hsub2(data[i], float2type2(s_max))); -+ local_sum += (float)(data[i].x + data[i].y); -+ } -+ -+ float sum_val = blockDim.x <= 32 ? warpReduceSum(local_sum) : blockReduceSum(local_sum); -+ -+ if (threadIdx.x == 0) { -+ s_mean = sum_val + 1e-6f; -+ s_mean = __fdividef(1.0f, s_mean); -+ } -+ __syncthreads(); -+ -+ for (int i = 0; blockDim.x * i + threadIdx.x < (trgt_seq_len / 2) && i < ITEMS_PER_THREAD; i++) { -+ qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id) * (trgt_seq_len / 2) + blockDim.x * i -+ + threadIdx.x; -+ qk_buf_half2[qk_offset] = hmul2(data[i], float2type2(s_mean)); -+ } -+ } -+} -+ - template - __global__ void softmax_kernel_v5_half2( - T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, const T scalar) -@@ -415,6 +595,162 @@ __global__ void softmax_kernel_v5_half2( - } - } - -+template -+__global__ void softmax_cross_kernel_bias_v5_half2(T* qk_buf_, -+ const T* attr_mask, -+ const T* position_bias, -+ const int* d_seq_len, -+ const int* d_seq_len2, -+ const int batch_size, -+ const int head_num, -+ const int seq_len, -+ const int seq_stride, -+ const int trgt_seq_len, -+ const int trgt_seq_stride, -+ const int position_bias_head_num, -+ const T scalar, -+ const bool causal) -+{ -+ using T2 = typename TypeConverter::Type; -+ T2* qk_buf_half2 = (T2*)qk_buf_; -+ const T2* attr_mask_half2 = (const T2*)attr_mask; -+ const T2* position_bias_half2 = (position_bias == nullptr) ? nullptr : (const T2*)position_bias; -+ const T2 zero = {0, 0}; -+ const T2 one = {1, 1}; -+ for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x * NUM) { -+ T2 data[NUM][ITEMS_PER_THREAD]; -+ -+ int qk_offset[NUM]; -+ int pos_offset[NUM]; -+ int pos_offset2[NUM]; -+ int pos_bias_offset[NUM]; -+ -+ __shared__ float s_sum[NUM], s_max[NUM]; -+ float local_max[NUM]; -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ local_max[j] = -1e20f; -+ } -+ -+ for (int i = 0; blockDim.x * i + threadIdx.x < (trgt_seq_len / 2) && i < ITEMS_PER_THREAD; i++) { -+ int mask_offset[NUM]; -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ qk_offset[j] = -+ ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id + j * gridDim.x) * (trgt_seq_len / 2) -+ + blockDim.x * i + threadIdx.x; -+ pos_offset[j] = -+ ((blockIdx.z) * seq_stride + seq_id + j * gridDim.x) * (trgt_seq_stride / 2) -+ + blockDim.x * i + threadIdx.x; -+ pos_offset2[j] = (seq_id + j * gridDim.x) * (trgt_seq_stride / 2) + blockDim.x * i + threadIdx.x; -+ mask_offset[j] = (blockIdx.y * seq_stride + seq_id + j * gridDim.x) * (trgt_seq_stride / 2) + blockDim.x * i + threadIdx.x; -+ -+ pos_bias_offset[j] = (position_bias_head_num == 1) ? pos_offset2[j] : pos_offset[j]; -+ } -+ -+ T2 mask_val[NUM]; -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ mask_val[j] = (attr_mask_half2 != 0) ? ldg(&attr_mask_half2[mask_offset[j]]) : one; -+ if (causal) { -+ mask_val[j].x = ((mask_offset[j] % (trgt_seq_stride / 2)) * 2 <= seq_id && seq_id < d_seq_len[blockIdx.y]) ? 1 : 0; -+ mask_val[j].y = (((mask_offset[j] % (trgt_seq_stride / 2)) * 2 + 1) <= seq_id && seq_id < d_seq_len[blockIdx.y]) ? 
1 : 0; -+ } else if (d_seq_len != nullptr) { -+ mask_val[j].x = (seq_id < d_seq_len[blockIdx.y] && (mask_offset[j] % (trgt_seq_stride / 2)) * 2 < d_seq_len2[blockIdx.y]) ? mask_val[j].x : (T)0; -+ mask_val[j].y = (seq_id < d_seq_len[blockIdx.y] && ((mask_offset[j] % (trgt_seq_stride / 2)) * 2 + 1) < d_seq_len2[blockIdx.y]) ? mask_val[j].y : (T)0; -+ } -+ } -+ -+ T2 qk[NUM]; -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ qk[j] = qk_buf_half2[qk_offset[j]]; -+ } -+ -+ T2 pos_bias_val[NUM]; -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ pos_bias_val[j] = -+ (position_bias_half2 == nullptr) ? zero : ldg(&position_bias_half2[pos_bias_offset[j]]); -+ } -+ -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ mask_val[j] = hmul2(hsub2(float2type2(1.0f), mask_val[j]), float2type2(-10000.0f)); -+ } -+ -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ data[j][i] = -+ hadd2(hadd2(hmul2(qk[j], type2type2(scalar)), mask_val[j]), pos_bias_val[j]); -+ local_max[j] = fmax(local_max[j], fmax((float)data[j][i].x, (float)data[j][i].y)); -+ } -+ } -+ -+ if (blockDim.x <= 32) { -+ warpReduceMaxV2(local_max); -+ } -+ else { -+ blockReduceMaxV2(local_max); -+ } -+ -+ if (threadIdx.x == 0) { -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ s_max[j] = local_max[j]; -+ } -+ } -+ __syncthreads(); -+ -+ float local_sum[NUM]; -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ local_sum[j] = {0.f}; -+ } -+ -+ for (int i = 0; blockDim.x * i + threadIdx.x < (trgt_seq_len / 2) && i < ITEMS_PER_THREAD; i++) { -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ data[j][i] = hexp2(hsub2(data[j][i], float2type2(s_max[j]))); -+ } -+ -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ local_sum[j] += (float)(data[j][i].x + data[j][i].y); -+ } -+ } -+ -+ if (blockDim.x <= 32) { -+ warpReduceSumV2(local_sum); -+ } -+ else { -+ blockReduceSumV2(local_sum); -+ } -+ -+ if (threadIdx.x == 0) { -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ s_sum[j] = __fdividef(1.0f, local_sum[j] + 1e-6f); -+ } -+ } -+ __syncthreads(); -+ -+ for (int i = 0; blockDim.x * i + threadIdx.x < (trgt_seq_len / 2) && i < ITEMS_PER_THREAD; i++) { -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ qk_offset[j] = -+ ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id + j * gridDim.x) * (trgt_seq_len / 2) -+ + blockDim.x * i + threadIdx.x; -+ } -+ -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ qk_buf_half2[qk_offset[j]] = hmul2(data[j][i], float2type2(s_sum[j])); -+ } -+ } -+ } -+} -+ - #define SOFTMAX_KERNEL(ITEMS_PER_THREAD) \ - block.x /= ITEMS_PER_THREAD; \ - assert(block.x <= 1024); \ -@@ -434,6 +770,63 @@ __global__ void softmax_kernel_v5_half2( - <<>>(buffer, buffer_src, attr_mask, batch_size, head_num, seq_len, scalar); \ - } - -+#define SOFTMAX_MIX_KERNEL_BIAS(ITEMS_PER_THREAD) \ -+ block.x /= ITEMS_PER_THREAD; \ -+ assert(block.x <= 1024); \ -+ if (is_half2) { \ -+ if (grid.x % 4 == 0) { \ -+ grid.x /= 4; \ -+ softmax_cross_kernel_bias_v5_half2 \ -+ <<>>((half*)io_buffer, \ -+ (const half*)attr_mask, \ -+ (const half*)position_bias, \ -+ (const int*)d_seq_len, \ -+ (const int*)d_seq_len2, \ -+ batch_size, \ -+ head_num, \ -+ seq_len, \ -+ src_seq_stride, \ -+ trgt_seq_len, \ -+ tgt_seq_stride, \ -+ position_bias_head_num, \ -+ (const half)scalar, \ -+ causal); \ -+ } \ -+ else { \ -+ softmax_cross_kernel_bias_v4_half2 \ -+ <<>>((half*)io_buffer, \ -+ (const half*)attr_mask, \ -+ (const half*)position_bias, \ -+ (const int*)d_seq_len, \ -+ (const int*)d_seq_len2, \ -+ batch_size, \ -+ 
head_num, \ -+ seq_len, \ -+ src_seq_stride, \ -+ trgt_seq_len, \ -+ tgt_seq_stride, \ -+ position_bias_head_num, \ -+ (const half)scalar, \ -+ causal); \ -+ } \ -+ } \ -+ else { \ -+ softmax_mix_kernel_bias_v4<<>>(io_buffer, \ -+ attr_mask, \ -+ position_bias, \ -+ d_seq_len, \ -+ d_seq_len2, \ -+ batch_size, \ -+ head_num, \ -+ seq_len, \ -+ src_seq_stride, \ -+ trgt_seq_len, \ -+ tgt_seq_stride, \ -+ position_bias_head_num, \ -+ scalar, \ -+ causal); \ -+ } -+ - #ifdef ENABLE_BF16 - #define SOFTMAX_KERNEL_BF16(ITEMS_PER_THREAD) \ - block.x /= ITEMS_PER_THREAD; \ -@@ -501,6 +894,48 @@ void invokeMaskedSoftMax(T* buffer, - } - } - -+template -+void invokeMixMaskedSoftMax(T* io_buffer, -+ const T_M* attr_mask, -+ const T* position_bias, -+ const int* d_seq_len, -+ const int* d_seq_len2, -+ const int batch_size, -+ const int seq_len, -+ const int src_seq_stride, -+ const int trgt_seq_len, -+ const int tgt_seq_stride, -+ const int head_num, -+ const int position_bias_head_num, -+ const T scalar, -+ const bool causal, -+ cudaStream_t stream) -+{ -+ dim3 grid(seq_len, batch_size, head_num); -+ if (batch_size * head_num > 360) { -+ grid.x = ceil(float(seq_len) / 32.0f); -+ } -+ -+ bool is_half2 = sizeof(T) == 2 && sizeof(T_M) == 2 && trgt_seq_len % 2 == 0; -+ dim3 block((trgt_seq_len / (is_half2 ? 2 : 1) + 31) / 32 * 32); -+ -+ if (block.x > 3072 && block.x <= 4096) { -+ SOFTMAX_MIX_KERNEL_BIAS(4) -+ } -+ if (block.x > 2048) { -+ SOFTMAX_MIX_KERNEL_BIAS(3) -+ } -+ else if (block.x > 1024) { -+ SOFTMAX_MIX_KERNEL_BIAS(2) -+ } -+ else if (block.x > 0) { -+ SOFTMAX_MIX_KERNEL_BIAS(1) -+ } -+ else { -+ FT_CHECK(trgt_seq_len <= 4096 || seq_len <= 4096); -+ } -+} -+ - #ifdef ENABLE_BF16 - template<> - void invokeMaskedSoftMax(__nv_bfloat16* buffer, -@@ -574,8 +1009,73 @@ void invokeMaskedSoftMax(__nv_bfloat16* buffer, - FT_CHECK(seq_len <= 4096); - } - } -+ - #endif // ENABLE_BF16 - -+template void invokeMixMaskedSoftMax(float* io_buffer, -+ const float* attr_mask, -+ const float* position_bias, -+ const int* d_seq_len, -+ const int* d_seq_len2, -+ const int batch_size, -+ const int seq_len, -+ const int seq_len_stride, -+ const int tgt_seq_len, -+ const int tgt_seq_stride, -+ const int head_num, -+ const int position_bias_head_num, -+ const float scalar, -+ const bool causal, -+ cudaStream_t stream); -+ -+template void invokeMixMaskedSoftMax(half* io_buffer, -+ const half* attr_mask, -+ const half* position_bias, -+ const int* d_seq_len, -+ const int* d_seq_len2, -+ const int batch_size, -+ const int seq_len, -+ const int seq_len_stride, -+ const int tgt_seq_len, -+ const int tgt_seq_stride, -+ const int head_num, -+ const int position_bias_head_num, -+ const half scalar, -+ const bool causal, -+ cudaStream_t stream); -+ -+template void invokeMixMaskedSoftMax(float* io_buffer, -+ const half* attr_mask, -+ const float* position_bias, -+ const int* d_seq_len, -+ const int* d_seq_len2, -+ const int batch_size, -+ const int seq_len, -+ const int seq_len_stride, -+ const int tgt_seq_len, -+ const int tgt_seq_stride, -+ const int head_num, -+ const int position_bias_head_num, -+ const float scalar, -+ const bool causal, -+ cudaStream_t stream); -+ -+template void invokeMixMaskedSoftMax(half* io_buffer, -+ const float* attr_mask, -+ const half* position_bias, -+ const int* d_seq_len, -+ const int* d_seq_len2, -+ const int batch_size, -+ const int seq_len, -+ const int seq_len_stride, -+ const int tgt_seq_len, -+ const int tgt_seq_stride, -+ const int head_num, -+ const int position_bias_head_num, -+ const 
half scalar, -+ const bool causal, -+ cudaStream_t stream); -+ - template void invokeMaskedSoftMax(float* buffer, - const float* buffer_src, - const float* attr_mask, -@@ -621,6 +1121,7 @@ template void invokeMaskedSoftMax(__nv_bfloat16* buffer, - const int head_num, - const __nv_bfloat16 scalar, - cudaStream_t stream); -+ - #endif // ENABLE_BF16 - - template -@@ -726,9 +1227,9 @@ void invokeTransposeQKV(T* dst, - seq_per_block *= 2; - } - -- FT_CHECK(grid.x * seq_per_block == batch_size * head_num * seq_len); -+ FT_CHECK((int)(grid.x * seq_per_block) == batch_size * head_num * seq_len); - -- if (seq_per_block * size_per_head % 2 == 0) { -+ if (size_per_head % 2 == 0) { - block.x = seq_per_block * size_per_head / 2; - if (std::is_same::value) { - transpose<<>>( -@@ -778,6 +1279,7 @@ template void invokeTransposeQKV(__nv_bfloat16* src, - const int head_num, - const int size_per_head, - cudaStream_t stream); -+ - #endif - - template -@@ -993,12 +1495,14 @@ __global__ void transpose_remove_padding(const T* src, - - const int dst_seq_id = bid; - -+ const int src_offset_base = src_batch_id * seq_len * head_num * size_per_head + src_seq_id * size_per_head; -+ const int dst_offset_base = dst_seq_id * head_num * size_per_head; -+ - for (int idx = threadIdx.x; idx < head_num * size_per_head; idx += blockDim.x) { - const int head_id = idx / size_per_head; - const int hidden_id = idx % size_per_head; -- dst[dst_seq_id * head_num * size_per_head + idx] = -- __ldg(&src[src_batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head -- + src_seq_id * size_per_head + hidden_id]); -+ const T src_elem = ldg(&src[src_offset_base + head_id * seq_len * size_per_head + hidden_id]); -+ dst[dst_offset_base + idx] = src_elem; - } - } - -@@ -1061,12 +1565,12 @@ template void invokeTransposeAttentionOutRemovePadding(half* src, - const int* mask_offset, - cudaStream_t stream); - --template -+template - __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, - T* k_buf, - T* v_buf, - const T* __restrict QKV, -- const T* __restrict qkv_bias, -+ const U* __restrict qkv_bias, - const int batch_size, - const int seq_len, - const int head_num, -@@ -1081,8 +1585,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, - for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < batch_size * seq_len * 3 * n; - index += gridDim.x * blockDim.x) { - int bias_id = index % (3 * n); -- T val = ldg(&QKV[index]) + ldg(&qkv_bias[bias_id]); -- -+ T val = ldg(&QKV[index]) + (T)ldg(&qkv_bias[bias_id]); - int tmp_index = index; - const int target_batch_id = tmp_index / (seq_len * 3 * n); - tmp_index -= target_batch_id * seq_len * 3 * n; -@@ -1097,15 +1600,217 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, - + seq_id * size_per_head + size_id] = val; - } - } -- --template --struct Vec_t {}; --template<> --struct Vec_t { -- using Type = float2; --}; --template<> --struct Vec_t { -+template -+__global__ void add_fusedQKV_bias_transpose_kernel_mb_vsl(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ const T* __restrict QKV, -+ const U* __restrict qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int h_token_num, -+ const int batch_size, -+ const int max_seq_len, -+ const int head_num, -+ const int size_per_head) -+{ -+ // QKV: [m, 3, n] -+ // qkv_bias: [3, n] -+ // q_buf, k_buf, v_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[3] = {q_buf, k_buf, v_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + 
threadIdx.x; index < h_token_num * 3 * n; -+ index += gridDim.x * blockDim.x) { -+ int bias_id = index % (3 * n); -+ T bias_val = (qkv_bias == nullptr) ? (T)0.0f : (T)ldg(&qkv_bias[bias_id]); -+ T val = ldg(&QKV[index]) + bias_val; -+ int tmp_index = index; -+ int h_token_idx = (index) / (3 * n); -+ int batch_id = (padding_offset[h_token_idx] + h_token_idx) / max_seq_len; -+ int seq = (padding_offset[h_token_idx] + h_token_idx) % max_seq_len; -+ h_token_idx -= seq; -+ tmp_index -= h_token_idx * 3 * n; -+ const int seq_id = tmp_index / (3 * n); -+ tmp_index -= seq_id * 3 * n; -+ const int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ const int head_id = tmp_index / size_per_head; -+ const int size_id = tmp_index - head_id * size_per_head; -+ qkv_ptr[qkv_id][h_token_idx * head_num * size_per_head + head_id * d_sequence_length[batch_id] * size_per_head -+ + seq_id * size_per_head + size_id] = val; -+ } -+} -+template -+__global__ void add_fusedQKV_bias_transpose_kernel_use_past(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ const T* __restrict QKV, -+ const T* __restrict qkv_bias, -+ const int seq_len, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head) -+{ -+ // QKV: [m, 3, n] -+ // qkv_bias: [3, n] -+ // q_buf, k_buf, v_buf: [batch, head_num, seq_len, size_per_head] -+ T* qkv_ptr[3] = {q_buf, k_buf, v_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < actual_seq_len * 3 * n; -+ index += gridDim.x * blockDim.x) { -+ int bias_id = index % (3 * n); -+ T val = ldg(&QKV[index]) + (T)ldg(&qkv_bias[bias_id]); -+ int tmp_index = index; -+ const int seq_id = tmp_index / (3 * n); -+ tmp_index -= seq_id * 3 * n; -+ const int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ const int head_id = tmp_index / size_per_head; -+ const int size_id = tmp_index - head_id * size_per_head; -+ const int offset = head_id * seq_len * size_per_head + seq_id * size_per_head + size_id; -+ qkv_ptr[qkv_id][offset] = val; -+ } -+} -+template -+__global__ void add_fusedQKV_bias_transpose_kernel_use_past_mb(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ const T* __restrict QKV, -+ const T* __restrict qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int batch_size, -+ const int h_token_num, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head, -+ const bool incremental_mode, -+ const bool padding) -+{ -+ // QKV: [m, 3, n] -+ // qkv_bias: [3, n] -+ // q_buf, k_buf, v_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[3] = {q_buf, k_buf, v_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < h_token_num * 3 * n; -+ index += gridDim.x * blockDim.x) { -+ int bias_id = index % (3 * n); -+ T bias_val = (qkv_bias == nullptr) ? 
(T)0.0f : (T)ldg(&qkv_bias[bias_id]); -+ T val = ldg(&QKV[index]) + bias_val; -+ int tmp_index = index; -+ int h_token_idx = (index) / (3 * n); -+ int batch_id = (padding_offset[h_token_idx] + h_token_idx) / actual_seq_len; -+ int seq_id = (padding_offset[h_token_idx] + h_token_idx) % actual_seq_len; -+ tmp_index -= h_token_idx * 3 * n; -+ h_token_idx -= seq_id; -+ -+ const int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ const int head_id = tmp_index / size_per_head; -+ const int size_id = tmp_index - head_id * size_per_head; -+ int offset = seq_id * size_per_head + size_id; -+ if ((!padding || incremental_mode) && (qkv_id == 0)) { -+ offset += h_token_idx * n + head_id * d_sequence_length[batch_id] * size_per_head; -+ } else if (qkv_id != 1 || !padding) { -+ offset += batch_id * actual_seq_len * n + head_id * actual_seq_len * size_per_head; -+ } else { -+ offset = batch_id * actual_seq_len * n + head_id * actual_seq_len * size_per_head + size_id * actual_seq_len + seq_id; -+ } -+ if (incremental_mode && !(qkv_id == 0)) { -+ if (qkv_id == 1 && padding) -+ offset += (d_sequence_length2[batch_id] - 1); -+ else { -+ offset += (d_sequence_length2[batch_id] - 1) * size_per_head; -+ } -+ -+ } -+ qkv_ptr[qkv_id][offset] = val; -+ } -+} -+template -+__global__ void transposeQKV_kernel(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ const T* __restrict QKV, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head) -+{ -+ // QKV: [m, 3, n] -+ // qkv_bias: [3, n] -+ // q_buf, k_buf, v_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[3] = {q_buf, k_buf, v_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < batch_size * seq_len * 3 * n; -+ index += gridDim.x * blockDim.x) { -+ T val = ldg(&QKV[index]); -+ int tmp_index = index; -+ const int target_batch_id = tmp_index / (seq_len * 3 * n); -+ tmp_index -= target_batch_id * seq_len * 3 * n; -+ const int seq_id = tmp_index / (3 * n); -+ tmp_index -= seq_id * 3 * n; -+ const int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ const int head_id = tmp_index / size_per_head; -+ const int size_id = tmp_index - head_id * size_per_head; -+ qkv_ptr[qkv_id][target_batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head -+ + seq_id * size_per_head + size_id] = val; -+ } -+} -+ -+template -+__global__ void add_fusedQKV_ZP_bias_transpose_kernel(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ const T* __restrict QKV, -+ const U* __restrict qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int token_num, -+ int* mask_offset) -+{ -+ // QKV: [m, 3, n] -+ // qkv_bias: [3, n] -+ // q_buf, k_buf, v_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[3] = {q_buf, k_buf, v_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < token_num * 3 * n; -+ index += gridDim.x * blockDim.x) { -+ int bias_id = index % (3 * n); // 0 - 160 -+ T val = ldg(&QKV[index]); -+ if (qkv_bias != nullptr) -+ val += (T)ldg(&qkv_bias[bias_id]); -+ int tmp_index = index; // 0 -160 * 3 * n -+ int token_id = tmp_index / (3 * n); -+ int batch_id = (token_id + ldg(&mask_offset[token_id])) / seq_len; -+ int seq_id = (token_id + ldg(&mask_offset[token_id])) % seq_len; -+ tmp_index -= token_id * 3 * n; -+ int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ int head_id = tmp_index / size_per_head; -+ int size_id = 
tmp_index - head_id * size_per_head; -+ int dst_id = batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head -+ + seq_id * size_per_head + size_id; -+ qkv_ptr[qkv_id][dst_id] = val; -+ } -+} -+ -+template -+struct Vec_t {}; -+template<> -+struct Vec_t { -+ using Type = float2; -+}; -+template<> -+struct Vec_t { - using Type = uint32_t; - }; - -@@ -1116,12 +1821,12 @@ struct Vec_t<__nv_bfloat16> { - }; - #endif - --template -+template - __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, - T* k_buf, - T* v_buf, - const T* __restrict QKV, -- const T* __restrict qkv_bias, -+ const U* __restrict qkv_bias, - const int batch_size, - const int seq_len, - const int head_num, -@@ -1174,8 +1879,21 @@ template - void invokeAddFusedQKVBiasTranspose(T* q_buf, - T* k_buf, - T* v_buf, -- T* QKV, -+ const T* QKV, - const T* qkv_bias, -+ const int max_seq_len, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head, -+ const bool inc, -+ cudaStream_t stream); -+ -+template -+void invokeAddFusedQKVBiasTranspose(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ T* QKV, -+ const U* qkv_bias, - const int batch_size, - const int seq_len, - const int head_num, -@@ -1183,23 +1901,714 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, - const int rotary_embedding_dim, - cudaStream_t stream) - { -- if (rotary_embedding_dim == 0) { -+ if (qkv_bias != nullptr) { -+ if (rotary_embedding_dim == 0) { -+ const int m = batch_size * seq_len; -+ const int n = head_num * size_per_head; -+ dim3 block(384); -+ dim3 grid((int)(ceil(1.0 * m * n / 384))); -+ add_fusedQKV_bias_transpose_kernel<<>>( -+ q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head); -+ } -+ else { -+ // To implement rotary embeddings, each thread processes two QKV elems: -+ dim3 block((size_per_head / 2 + 31) / 32 * 32); -+ dim3 grid(seq_len, head_num, batch_size); -+ add_fusedQKV_bias_transpose_kernel<<>>( -+ q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head, rotary_embedding_dim); -+ } -+ } -+ else { - const int m = batch_size * seq_len; - const int n = head_num * size_per_head; - dim3 block(384); - dim3 grid((int)(ceil(1.0 * m * n / 384))); -- add_fusedQKV_bias_transpose_kernel<<>>( -- q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head); -+ transposeQKV_kernel<<>>( -+ q_buf, k_buf, v_buf, QKV, batch_size, seq_len, head_num, size_per_head); - } -- else { -+} -+template -+void invokeAddFusedQKVBiasTransposeMBVSL(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ T* QKV, -+ const U* qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int batch_size, -+ const int max_seq_len, -+ const int h_token_num, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream) -+{ -+ const int m = h_token_num; -+ const int n = head_num * size_per_head; -+ dim3 block(384); -+ dim3 grid((int)(ceil(1.0 * m * n / 384))); -+ add_fusedQKV_bias_transpose_kernel_mb_vsl<<>>( -+ q_buf, k_buf, v_buf, QKV, qkv_bias, padding_offset, d_sequence_length, h_token_num, batch_size, max_seq_len, head_num, size_per_head); -+} -+template -+void invokeAddFusedQKVBiasTransposeUsePast(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ const T* QKV, -+ const T* qkv_bias, -+ const int seq_len, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream) -+{ -+ if (qkv_bias != nullptr) { - // To implement rotary embeddings, each thread processes two QKV elems: - dim3 block((size_per_head / 2 + 31) / 32 * 32); 
-- dim3 grid(seq_len, head_num, batch_size); -- add_fusedQKV_bias_transpose_kernel<<>>( -- q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head, rotary_embedding_dim); -+ dim3 grid(actual_seq_len, head_num); -+ add_fusedQKV_bias_transpose_kernel_use_past<<>>( -+ q_buf, k_buf, v_buf, QKV, qkv_bias, seq_len, actual_seq_len, head_num, size_per_head); -+ } -+ else { -+ std::cout << "null qkv bias not supported" << std::endl; -+ } -+} -+template -+void invokeAddFusedQKVBiasTransposeUsePastMB(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ const T* QKV, -+ const T* qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int batch_size, -+ const int h_token_num, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head, -+ const bool incremental_mode, -+ const bool padding, -+ cudaStream_t stream) -+{ -+ const int m = h_token_num; -+ const int n = head_num * size_per_head; -+ dim3 block(384); -+ dim3 grid((int)(ceil(1.0 * m * n / 384))); -+ add_fusedQKV_bias_transpose_kernel_use_past_mb<<>>( -+ q_buf, k_buf, v_buf, QKV, qkv_bias, padding_offset, d_sequence_length, -+ d_sequence_length2, batch_size, h_token_num, -+ actual_seq_len, head_num, size_per_head, -+ incremental_mode, padding); -+} -+template -+void invokeAddFusedZP_QKVBiasTranspose(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ T* QKV, -+ const U* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int h_token, -+ int* padding_mask, -+ cudaStream_t stream) -+{ -+ -+ const int m = h_token; -+ const int n = head_num * size_per_head; -+ cudaMemsetAsync(q_buf, 0, batch_size * seq_len * n * sizeof(T), stream); -+ cudaMemsetAsync(k_buf, 0, batch_size * seq_len * n * sizeof(T), stream); -+ cudaMemsetAsync(v_buf, 0, batch_size * seq_len * n * sizeof(T), stream); -+ dim3 block(384); -+ dim3 grid((int)(ceil(1.0 * m * n / 384))); -+ add_fusedQKV_ZP_bias_transpose_kernel<<>>( -+ q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head, h_token, padding_mask); -+} -+ -+template void invokeAddFusedZP_QKVBiasTranspose(float* q_buf, -+ float* k_buf, -+ float* v_buf, -+ float* QKV, -+ const float* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int h_token, -+ int* padding_mask, -+ cudaStream_t stream); -+ -+template void invokeAddFusedZP_QKVBiasTranspose(half* q_buf, -+ half* k_buf, -+ half* v_buf, -+ half* QKV, -+ const half* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int h_token, -+ int* padding_mask, -+ cudaStream_t stream); -+ -+template void invokeCrossAddFusedZP_QKVBiasTranspose(float* q_buf, -+ float* k_buf, -+ float* v_buf, -+ float* QKV, -+ const float* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int h_token, -+ const int h_token2, -+ int* padding_mask, -+ int* padding_mask2, -+ cudaStream_t stream); -+ -+template void invokeCrossAddFusedZP_QKVBiasTranspose(half* q_buf, -+ half* k_buf, -+ half* v_buf, -+ half* QKV, -+ const half* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int h_token, -+ const int h_token2, -+ int* padding_mask, -+ int* padding_mask2, -+ cudaStream_t stream); -+ -+template -+__global__ void 
invokeCrossAddFusedQKVBiasTransposeQ(T* q_buf, -+ const T* __restrict QKV, -+ const U* __restrict qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head) -+{ -+ // QKV: [m, 1, n] -+ // qkv_bias: [1, n] -+ // q_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[1] = {q_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < batch_size * seq_len * 1 * n; -+ index += gridDim.x * blockDim.x) { -+ int bias_id = index % (1 * n); -+ T val = ldg(&QKV[index]) + (T)ldg(&qkv_bias[bias_id]); -+ -+ int tmp_index = index; -+ const int target_batch_id = tmp_index / (seq_len * 1 * n); -+ tmp_index -= target_batch_id * seq_len * 1 * n; -+ const int seq_id = tmp_index / (1 * n); -+ tmp_index -= seq_id * 1 * n; -+ const int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ const int head_id = tmp_index / size_per_head; -+ const int size_id = tmp_index - head_id * size_per_head; -+ -+ qkv_ptr[qkv_id][target_batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head -+ + seq_id * size_per_head + size_id] = val; -+ } -+} -+ -+template -+__global__ void invokeCrossAddFusedQKVBiasTransposeQMBVSL(T* q_buf, -+ const T* __restrict QKV, -+ const U* __restrict qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int h_token_num, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head) -+{ -+ // QKV: [m, 1, n] -+ // qkv_bias: [1, n] -+ // q_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[1] = {q_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < h_token_num * 1 * n; -+ index += gridDim.x * blockDim.x) { -+ int bias_id = index % (1 * n); -+ T bias_val = (qkv_bias == nullptr) ? 
(T)0.0f : (T)ldg(&qkv_bias[bias_id]); -+ T val = ldg(&QKV[index]) + bias_val; -+ -+ int tmp_index = index; -+ int h_token_idx = (index) / (1 * n); -+ int batch_id = (padding_offset[h_token_idx] + h_token_idx) / seq_len; -+ int seq = (padding_offset[h_token_idx] + h_token_idx) % seq_len; -+ h_token_idx -= seq; -+ tmp_index -= h_token_idx * 1 * n; -+ -+ const int seq_id = tmp_index / (1 * n); -+ tmp_index -= seq_id * 1 * n; -+ const int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ const int head_id = tmp_index / size_per_head; -+ const int size_id = tmp_index - head_id * size_per_head; -+ -+ qkv_ptr[qkv_id][h_token_idx * head_num * size_per_head + head_id * d_sequence_length[batch_id] * size_per_head -+ + seq_id * size_per_head + size_id] = val; -+ } -+} -+template -+__global__ void add_fusedQKV_ZP_bias_transpose_kernel_q(T* q_buf, -+ const T* __restrict QKV, -+ const U* __restrict qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int token_num, -+ int* mask_offset) -+{ -+ // QKV: [m, 1, n] -+ // qkv_bias: [1, n] -+ // q_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[1] = {q_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < token_num * 1 * n; -+ index += gridDim.x * blockDim.x) { -+ int bias_id = index % (1 * n); -+ T val = ldg(&QKV[index]); -+ if (qkv_bias != nullptr) -+ val += (T)ldg(&qkv_bias[bias_id]); -+ int tmp_index = index; -+ // const int target_batch_id = tmp_index / (seq_len * 1 * n); -+ int token_id = tmp_index / (1 * n); -+ int batch_id = (token_id + ldg(&mask_offset[token_id])) / seq_len; -+ int seq_id = (token_id + ldg(&mask_offset[token_id])) % seq_len; -+ tmp_index -= token_id * 1 * n; -+ int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ int head_id = tmp_index / size_per_head; -+ int size_id = tmp_index - head_id * size_per_head; -+ int dst_id = batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head -+ + seq_id * size_per_head + size_id; -+ qkv_ptr[qkv_id][dst_id] = val; -+ } -+} -+ -+template -+__global__ void invokeCrossTransposeQ(T* q_buf, -+ const T* __restrict QKV, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head) -+{ -+ // QKV: [m, 1, n] -+ // q_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[1] = {q_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < batch_size * seq_len * 1 * n; -+ index += gridDim.x * blockDim.x) { -+ T val = ldg(&QKV[index]); -+ -+ int tmp_index = index; -+ const int target_batch_id = tmp_index / (seq_len * 1 * n); -+ tmp_index -= target_batch_id * seq_len * 1 * n; -+ const int seq_id = tmp_index / (1 * n); -+ tmp_index -= seq_id * 1 * n; -+ const int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ const int head_id = tmp_index / size_per_head; -+ const int size_id = tmp_index - head_id * size_per_head; -+ -+ qkv_ptr[qkv_id][target_batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head -+ + seq_id * size_per_head + size_id] = val; -+ } -+} -+ -+template -+__global__ void invokeCrossAddFusedQKVBiasTransposeKV(T* k_buf, -+ T* v_buf, -+ const T* __restrict QKV, -+ const U* __restrict qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head) -+{ -+ // QKV: [m, 2, n] -+ // qkv_bias: [2, n] -+ // k_buf, v_buf: [batch, head_num, seq_len, size_per_head] 
-+ -+ T* qkv_ptr[2] = {k_buf, v_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < batch_size * seq_len * 2 * n; -+ index += gridDim.x * blockDim.x) { -+ int bias_id = index % (2 * n); -+ T val = ldg(&QKV[index]) + (T)ldg(&qkv_bias[bias_id]); -+ -+ int tmp_index = index; -+ const int target_batch_id = tmp_index / (seq_len * 2 * n); -+ tmp_index -= target_batch_id * seq_len * 2 * n; -+ const int seq_id = tmp_index / (2 * n); -+ tmp_index -= seq_id * 2 * n; -+ const int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ const int head_id = tmp_index / size_per_head; -+ const int size_id = tmp_index - head_id * size_per_head; -+ qkv_ptr[qkv_id][target_batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head -+ + seq_id * size_per_head + size_id] = val; -+ } -+} -+ -+template -+__global__ void invokeCrossAddFusedQKVBiasTransposeKVMBVSL(T* k_buf, -+ T* v_buf, -+ const T* __restrict QKV, -+ const U* __restrict qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int h_token_num, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head) -+{ -+ // QKV: [m, 2, n] -+ // qkv_bias: [2, n] -+ // k_buf, v_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[2] = {k_buf, v_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < h_token_num * 2 * n; -+ index += gridDim.x * blockDim.x) { -+ int bias_id = index % (2 * n); -+ T bias_val = (qkv_bias == nullptr) ? (T)0.0f : (T)ldg(&qkv_bias[bias_id]); -+ T val = ldg(&QKV[index]) + bias_val; -+ int tmp_index = index; -+ int h_token_idx = (index) / (2 * n); -+ int batch_id = (padding_offset[h_token_idx] + h_token_idx) / seq_len; -+ int seq = (padding_offset[h_token_idx] + h_token_idx) % seq_len; -+ h_token_idx -= seq; -+ tmp_index -= h_token_idx * 2 * n; -+ const int seq_id = tmp_index / (2 * n); -+ tmp_index -= seq_id * 2 * n; -+ const int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ const int head_id = tmp_index / size_per_head; -+ const int size_id = tmp_index - head_id * size_per_head; -+ qkv_ptr[qkv_id][h_token_idx * head_num * size_per_head + head_id * d_sequence_length[batch_id] * size_per_head -+ + seq_id * size_per_head + size_id] = val; -+ } -+} -+ -+template -+__global__ void add_fusedQKV_ZP_bias_transpose_kernel_kv(T* k_buf, -+ T* v_buf, -+ const T* __restrict QKV, -+ const U* __restrict qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int token_num, -+ int* mask_offset) -+{ -+ // QKV: [m, 2, n] -+ // qkv_bias: [2, n] -+ // k_buf, v_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[2] = {k_buf, v_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < token_num * 2 * n; -+ index += gridDim.x * blockDim.x) { -+ int bias_id = index % (2 * n); -+ T val = ldg(&QKV[index]); -+ if (qkv_bias != nullptr) -+ val += (T)ldg(&qkv_bias[bias_id]); -+ int tmp_index = index; -+ int token_id = tmp_index / (2 * n); -+ int batch_id = (token_id + ldg(&mask_offset[token_id])) / seq_len; -+ int seq_id = (token_id + ldg(&mask_offset[token_id])) % seq_len; -+ tmp_index -= token_id * 2 * n; -+ int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ int head_id = tmp_index / size_per_head; -+ int size_id = tmp_index - head_id * size_per_head; -+ int dst_id = batch_id * head_num * seq_len * 
size_per_head + head_id * seq_len * size_per_head -+ + seq_id * size_per_head + size_id; -+ qkv_ptr[qkv_id][dst_id] = val; - } - } - -+template -+__global__ void invokeCrossTransposeKV(T* k_buf, -+ T* v_buf, -+ const T* __restrict QKV, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head) -+{ -+ // QKV: [m, 2, n] -+ // k_buf, v_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[2] = {k_buf, v_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < batch_size * seq_len * 2 * n; -+ index += gridDim.x * blockDim.x) { -+ T val = ldg(&QKV[index]); -+ -+ int tmp_index = index; -+ const int target_batch_id = tmp_index / (seq_len * 2 * n); -+ tmp_index -= target_batch_id * seq_len * 2 * n; -+ const int seq_id = tmp_index / (2 * n); -+ tmp_index -= seq_id * 2 * n; -+ const int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ const int head_id = tmp_index / size_per_head; -+ const int size_id = tmp_index - head_id * size_per_head; -+ qkv_ptr[qkv_id][target_batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head -+ + seq_id * size_per_head + size_id] = val; -+ } -+} -+ -+template -+void invokeCrossAddFusedQKVBiasTranspose(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ T* QKV, -+ const U* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream) -+{ -+ if (qkv_bias != nullptr) { -+ const int m = batch_size * seq_len; -+ const int n = head_num * size_per_head; -+ dim3 block(384); -+ dim3 grid((int)(ceil(1.0 * m * n / 384))); -+ invokeCrossAddFusedQKVBiasTransposeQ<<>>( -+ q_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head); -+ -+ const int m2 = batch_size * tgt_seq_len; -+ const int n2 = head_num * size_per_head; -+ dim3 block2(384); -+ dim3 grid2((int)(ceil(1.0 * m2 * n2 / 384))); -+ invokeCrossAddFusedQKVBiasTransposeKV<<>>( -+ k_buf, v_buf, QKV + m * n, qkv_bias + n2, batch_size, tgt_seq_len, head_num, size_per_head); -+ } -+ else { -+ const int m = batch_size * seq_len; -+ const int n = head_num * size_per_head; -+ dim3 block(384); -+ dim3 grid((int)(ceil(1.0 * m * n / 384))); -+ invokeCrossTransposeQ<<>>(q_buf, QKV, batch_size, seq_len, head_num, size_per_head); -+ -+ const int m2 = batch_size * tgt_seq_len; -+ const int n2 = head_num * size_per_head; -+ dim3 block2(384); -+ dim3 grid2((int)(ceil(1.0 * m2 * n2 / 384))); -+ invokeCrossTransposeKV<<>>( -+ k_buf, v_buf, QKV + m * n, batch_size, tgt_seq_len, head_num, size_per_head); -+ } -+} -+template -+void invokeCrossAddFusedQKVBiasTransposeMBVSL(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ T* QKV, -+ const U* qkv_bias, -+ const int* padding_offset, -+ const int* padding_offset2, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int h_token_num, -+ const int h_token_num2, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream) -+{ -+ const int m = h_token_num; -+ const int n = head_num * size_per_head; -+ dim3 block(384); -+ dim3 grid((int)(ceil(1.0 * m * n / 384))); -+ invokeCrossAddFusedQKVBiasTransposeQMBVSL<<>>( -+ q_buf, QKV, qkv_bias, padding_offset, d_sequence_length, h_token_num, batch_size, seq_len, head_num, size_per_head); -+ const int m2 = h_token_num2; -+ const int n2 = head_num * size_per_head; -+ const U* kv_bias = (qkv_bias == nullptr) ? 
qkv_bias : qkv_bias + n2; -+ dim3 block2(384); -+ dim3 grid2((int)(ceil(1.0 * m2 * n2 / 384))); -+ invokeCrossAddFusedQKVBiasTransposeKVMBVSL<<>>( -+ k_buf, v_buf, QKV + h_token_num * n, kv_bias, padding_offset2, d_sequence_length2, h_token_num2, batch_size, tgt_seq_len, head_num, size_per_head); -+} -+template -+void invokeCrossAddFusedZP_QKVBiasTranspose(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ T* QKV, -+ const U* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int h_token, -+ const int h_token2, -+ int* padding_mask, -+ int* padding_mask2, -+ cudaStream_t stream) -+{ -+ const int m = h_token; -+ const int n = head_num * size_per_head; -+ cudaMemsetAsync(q_buf, 0, batch_size * seq_len * n * sizeof(T), stream); -+ dim3 block(384); -+ dim3 grid((int)(ceil(1.0 * m * n / 384))); -+ add_fusedQKV_ZP_bias_transpose_kernel_q<<>>( -+ q_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head, h_token, padding_mask); -+ -+ const int m2 = h_token2; -+ const int n2 = head_num * size_per_head; -+ cudaMemsetAsync(k_buf, 0, batch_size * tgt_seq_len * n2 * sizeof(T), stream); -+ cudaMemsetAsync(v_buf, 0, batch_size * tgt_seq_len * n2 * sizeof(T), stream); -+ dim3 block2(384); -+ dim3 grid2((int)(ceil(1.0 * m2 * n2 / 384))); -+ qkv_bias = (qkv_bias == nullptr) ? nullptr : qkv_bias + n2; -+ add_fusedQKV_ZP_bias_transpose_kernel_kv<<>>( -+ k_buf, v_buf, QKV + m * n, qkv_bias, batch_size, tgt_seq_len, head_num, size_per_head, h_token2, padding_mask2); -+} -+template void invokeCrossAddFusedQKVBiasTranspose(float* q_buf, -+ float* k_buf, -+ float* v_buf, -+ float* QKV, -+ const float* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeCrossAddFusedQKVBiasTranspose(half* q_buf, -+ half* k_buf, -+ half* v_buf, -+ half* QKV, -+ const half* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeCrossAddFusedQKVBiasTranspose(float* q_buf, -+ float* k_buf, -+ float* v_buf, -+ float* QKV, -+ const half* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeCrossAddFusedQKVBiasTranspose(half* q_buf, -+ half* k_buf, -+ half* v_buf, -+ half* QKV, -+ const float* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeCrossAddFusedQKVBiasTransposeMBVSL(float* q_buf, -+ float* k_buf, -+ float* v_buf, -+ float* QKV, -+ const float* qkv_bias, -+ const int* padding_offset, -+ const int* padding_offset2, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int h_token_num, -+ const int h_token_num2, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeCrossAddFusedQKVBiasTransposeMBVSL(half* q_buf, -+ half* k_buf, -+ half* v_buf, -+ half* QKV, -+ const half* qkv_bias, -+ const int* padding_offset, -+ const int* padding_offset2, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int h_token_num, -+ const int h_token_num2, -+ const int batch_size, -+ 
const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeCrossAddFusedQKVBiasTransposeMBVSL(float* q_buf, -+ float* k_buf, -+ float* v_buf, -+ float* QKV, -+ const half* qkv_bias, -+ const int* padding_offset, -+ const int* padding_offset2, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int h_token_num, -+ const int h_token_num2, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeCrossAddFusedQKVBiasTransposeMBVSL(half* q_buf, -+ half* k_buf, -+ half* v_buf, -+ half* QKV, -+ const float* qkv_bias, -+ const int* padding_offset, -+ const int* padding_offset2, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int h_token_num, -+ const int h_token_num2, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); - template void invokeAddFusedQKVBiasTranspose(float* q_buf, - float* k_buf, - float* v_buf, -@@ -1224,6 +2633,87 @@ template void invokeAddFusedQKVBiasTranspose(half* q_buf, - const int rotary_embedding_dim, - cudaStream_t stream); - -+template void invokeAddFusedQKVBiasTranspose(float* q_buf, -+ float* k_buf, -+ float* v_buf, -+ float* QKV, -+ const half* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int rotary_embedding_dim, -+ cudaStream_t stream); -+ -+template void invokeAddFusedQKVBiasTranspose(half* q_buf, -+ half* k_buf, -+ half* v_buf, -+ half* QKV, -+ const float* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int rotary_embedding_dim, -+ cudaStream_t stream); -+ -+template void invokeAddFusedQKVBiasTransposeMBVSL(float* q_buf, -+ float* k_buf, -+ float* v_buf, -+ float* QKV, -+ const float* qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int batch_size, -+ const int max_seq_len, -+ const int h_token_num, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeAddFusedQKVBiasTransposeMBVSL(half* q_buf, -+ half* k_buf, -+ half* v_buf, -+ half* QKV, -+ const half* qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int batch_size, -+ const int max_seq_len, -+ const int h_token_num, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeAddFusedQKVBiasTransposeMBVSL(float* q_buf, -+ float* k_buf, -+ float* v_buf, -+ float* QKV, -+ const half* qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int batch_size, -+ const int max_seq_len, -+ const int h_token_num, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeAddFusedQKVBiasTransposeMBVSL(half* q_buf, -+ half* k_buf, -+ half* v_buf, -+ half* QKV, -+ const float* qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int batch_size, -+ const int max_seq_len, -+ const int h_token_num, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+ - #ifdef ENABLE_BF16 - template void invokeAddFusedQKVBiasTranspose(__nv_bfloat16* q_buf, - __nv_bfloat16* k_buf, -@@ -1236,6 +2726,49 @@ template void invokeAddFusedQKVBiasTranspose(__nv_bfloat16* q_buf, - const int 
size_per_head, - const int rotary_embedding_dim, - cudaStream_t stream); -+ -+template void invokeAddFusedQKVBiasTransposeMBVSL(__nv_bfloat16* q_buf, -+ __nv_bfloat16* k_buf, -+ __nv_bfloat16* v_buf, -+ __nv_bfloat16* QKV, -+ const __nv_bfloat16* qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int batch_size, -+ const int max_seq_len, -+ const int h_token_num, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+template void invokeCrossAddFusedQKVBiasTranspose(__nv_bfloat16* q_buf, -+ __nv_bfloat16* k_buf, -+ __nv_bfloat16* v_buf, -+ __nv_bfloat16* QKV, -+ const __nv_bfloat16* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+template void invokeCrossAddFusedQKVBiasTransposeMBVSL(__nv_bfloat16* q_buf, -+ __nv_bfloat16* k_buf, -+ __nv_bfloat16* v_buf, -+ __nv_bfloat16* QKV, -+ const __nv_bfloat16* qkv_bias, -+ const int* padding_offset, -+ const int* padding_offset2, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int h_token_num, -+ const int h_token_num2, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ - #endif - - template -@@ -1860,4 +3393,419 @@ template void invokeMaskedSoftMaxWithRelPosBias(half* qk_buf, - const float qk_scale, - cudaStream_t stream); - -+template -+__global__ void attention_kernel(T* query_buf, -+ const T* Q_bias, -+ T* key_cache, -+ const T* K_bias, -+ T* value_cache, -+ const T* V_bias, -+ const int* length_per_sample, -+ T* context_buf, -+ const bool* finished, -+ int batch_size, -+ int head_num, -+ int size_per_head, -+ int step, -+ const int seq_len, -+ const T scalar) -+{ -+ if (finished != nullptr && finished[blockIdx.x / head_num] == true) { -+ return; -+ } -+ int tid = threadIdx.x; -+ int bid = blockIdx.x / head_num; -+ int head_id = blockIdx.x % head_num; -+ -+ extern __shared__ __align__(sizeof(T)) unsigned s_buf[]; -+ T* sq = reinterpret_cast(s_buf); -+ T* logits = reinterpret_cast(&sq[size_per_head]); -+ -+ int length = __ldg(&length_per_sample[bid]); -+ -+ int qkv_id = bid * head_num * size_per_head + head_id * size_per_head + tid; -+ int qkv_bias_id = head_id * size_per_head + tid; -+ -+ if (tid < size_per_head) { -+ sq[tid] = query_buf[qkv_id] + Q_bias[qkv_bias_id]; -+ } -+ __syncthreads(); -+ -+ for (int ite = 0; ite < length; ++ite) { -+ int key_id = bid * (seq_len * head_num * size_per_head) + ite * (head_num * size_per_head) -+ + head_id * size_per_head + tid; -+ -+ T key = tid < size_per_head ? key_cache[key_id] : (T)(0.0f); -+ -+ // For the first step, we should add bias to key memory cache. -+ // The KV memory cache only need to be updated at the first step. -+ if (step == 1 && tid < size_per_head) { -+ key += K_bias[head_id * size_per_head + tid]; -+ key_cache[key_id] = key; -+ } -+ -+ T val = (tid < size_per_head) ? key * sq[tid] * scalar : (T)(0.0f); -+ T qk = blockReduceSum(val); -+ if (threadIdx.x == 0) { -+ logits[ite] = qk; -+ } -+ __syncthreads(); // try to remove -+ } -+ __syncthreads(); -+ -+ __shared__ float s_max_val, s_sum; -+ -+ float local_i = tid < length ? (float)logits[tid] : -1e20f; -+ float max_val = blockReduceMax(local_i); -+ if (tid == 0) { -+ s_max_val = max_val; -+ } -+ __syncthreads(); -+ -+ local_i -= s_max_val; -+ float local_o = tid < length ? 
__expf(local_i) : 0.0f; -+ float val = blockReduceSum(local_o); -+ -+ if (tid == 0) { -+ s_sum = val + 1e-6; -+ } -+ __syncthreads(); -+ if (tid < length) { -+ logits[tid] = local_o / s_sum; -+ } -+ __syncthreads(); -+ -+ if (tid < size_per_head) { -+ T sum = (T)0.0f; -+ for (int ite = 0; ite < length; ++ite) { -+ int value_id = bid * seq_len * head_num * size_per_head + ite * head_num * size_per_head -+ + head_id * size_per_head + tid; -+ -+ T value = value_cache[value_id]; -+ -+ // for the first step, we should add bias to key memory cache -+ if (step == 1) { -+ value += V_bias[head_id * size_per_head + tid]; -+ value_cache[value_id] = value; -+ } -+ sum += value * logits[ite]; -+ } -+ context_buf[bid * head_num * size_per_head + head_id * size_per_head + tid] = sum; -+ } -+} -+ -+template -+__global__ void attention_kernel_opt(const T* __restrict qkv_buf, -+ const T* __restrict qkv_bias, -+ const T* __restrict attr_mask, -+ T* __restrict out_buf, -+ T* __restrict key_cache_output, -+ T* __restrict value_cache_output, -+ int batch_size, -+ int head_num, -+ const int seq_len, -+ const float scalar) -+{ -+ typedef Copy_t copy_t; -+ const int elems_per_thread = size_per_head / WARP_SIZE; -+ union Access_t { -+ copy_t v; -+ T x[elems_per_thread]; // supported size 1,2,4 -+ }; -+ typedef struct Float_n_t { -+ float x[elems_per_thread]; // supported size 1,2,4 -+ } float_n_t; -+ -+ __shared__ float_n_t sq[block_sz]; -+ extern __shared__ float logits[]; // use to store the logits from [0~step] -+ -+ const int warp_id = threadIdx.x / WARP_SIZE; -+ const int warp_num = block_sz / WARP_SIZE; -+ -+ typedef cub::BlockReduce MaxValBlockReduce; -+ typedef cub::BlockReduce BlockReduce; -+ __shared__ typename MaxValBlockReduce::TempStorage max_val_block_temp_storage; -+ __shared__ typename BlockReduce::TempStorage block_temp_storage; -+ -+ __shared__ typename cub::WarpReduce::TempStorage temp_storage[warp_num]; -+ -+ const int tid = threadIdx.x; -+ const int bid = blockIdx.x / head_num; -+ const int head_id = blockIdx.x % head_num; -+ int seq_id = blockIdx.y; -+ -+ int length = seq_len; -+ const int lane_id = tid % WARP_SIZE; -+ -+ // QKV [m 3 n] shape -+ int qkv_id = bid * (3 * seq_len * head_num * size_per_head) + seq_id * (3 * head_num * size_per_head) -+ + head_id * size_per_head; -+ int q_id = -+ bid * (seq_len * head_num * size_per_head) + seq_id * (head_num * size_per_head) + head_id * size_per_head; -+ int qkv_bias_id = head_id * size_per_head; -+ int key_id = bid * (3 * seq_len * head_num * size_per_head) + head_num * size_per_head + head_id * size_per_head; -+ int value_id = -+ bid * (3 * seq_len * head_num * size_per_head) + 2 * head_num * size_per_head + head_id * size_per_head; -+ -+ int key_trn_id = bid * (seq_len * head_num * size_per_head) + head_id * (size_per_head * seq_len); -+ int value_trn_id = bid * (seq_len * head_num * size_per_head) + head_id * (size_per_head * seq_len); -+ int mask_offset = bid * (seq_len * seq_len) + seq_id * seq_len; -+ -+ // get pointers -+ const T* query_buf = qkv_buf + qkv_id; -+ const T* Q_bias = qkv_bias + qkv_bias_id; -+ T* context_buf = out_buf + q_id; -+ -+ const T* key_cache = qkv_buf + key_id; -+ const T* K_bias = qkv_bias + head_num * size_per_head + qkv_bias_id; -+ T* key_cache_out = key_cache_output + key_trn_id; -+ -+ const T* value_cache = qkv_buf + value_id; -+ const T* V_bias = qkv_bias + 2 * head_num * size_per_head + qkv_bias_id; -+ T* value_cache_out = value_cache_output + value_trn_id; -+ -+ Access_t bias_r, key_val_r, query_buf_r; -+ // 
offset inside head -+ int minor_offset = lane_id; // offset in copy_t elements -+ // each warp will have its own copy of sq -+ query_buf_r.v = *((copy_t*)query_buf + minor_offset); -+ -+ bias_r.v = *((copy_t*)Q_bias + minor_offset); -+ float qb_r[elems_per_thread]; -+#pragma unroll -+ for (int i = 0; i < elems_per_thread; ++i) { -+ qb_r[i] = (float)query_buf_r.x[i] + (float)bias_r.x[i]; -+ } -+ -+ // offset for each step -+ int offset = 3 * head_num * size_per_head; -+ bias_r.v = *((copy_t*)K_bias + minor_offset); -+ for (int ite = warp_id; ite < length; ite += warp_num) { -+ key_val_r.v = *((copy_t*)&key_cache[ite * offset] + minor_offset); -+ -+ if (seq_id == 0) { -+ for (int i = 0; i < elems_per_thread; i++) { -+ key_val_r.x[i] = (float)key_val_r.x[i] + (float)bias_r.x[i]; -+ key_cache_out[ite + seq_len * (minor_offset * elems_per_thread + i)] = key_val_r.x[i]; -+ } -+ } -+ else { -+ for (int i = 0; i < elems_per_thread; i++) { -+ key_val_r.x[i] = (float)key_val_r.x[i] + (float)bias_r.x[i]; -+ } -+ } -+ float val = 0; -+ for (int i = 0; i < elems_per_thread; i++) { -+ val = val + (float)key_val_r.x[i] * qb_r[i]; -+ } -+ float qk = cub::WarpReduce(temp_storage[warp_id]).Sum(val); -+ -+ if (lane_id == 0) { -+ T mask_val = attr_mask[mask_offset + ite]; -+ mask_val = (1.0f - mask_val) * -10000.0f; -+ logits[ite] = qk * scalar + mask_val; -+ } -+ } -+ -+ __syncthreads(); -+ -+ __shared__ float s_max_val, s_sum; -+ float local_i = -1e20f; -+ for (int i = tid; i < length; i += blockDim.x) { -+ local_i = max(local_i, logits[i]); -+ } -+ -+ float max_val = MaxValBlockReduce(max_val_block_temp_storage).Reduce(local_i, cub::Max()); -+ if (tid == 0) { -+ s_max_val = max_val; -+ } -+ __syncthreads(); -+ -+ float local_o = 0.0f; -+ for (int i = tid; i < length; i += blockDim.x) { -+ logits[i] = __expf(logits[i] - s_max_val); -+ local_o += logits[i]; -+ } -+ float val = BlockReduce(block_temp_storage).Sum(local_o); -+ -+ if (tid == 0) { -+ s_sum = val + 1e-6; -+ } -+ __syncthreads(); -+ -+ float s_sum_inverse = __fdividef(1.0f, s_sum); -+ for (int i = tid; i < length; i += blockDim.x) { -+ logits[i] = logits[i] * s_sum_inverse; -+ } -+ __syncthreads(); -+ -+ // This optimization introduces discrepancy because of different order in FP32 summation -+ float sum_r[elems_per_thread] = {0.f}; -+ bias_r.v = *((copy_t*)V_bias + minor_offset); -+ for (int ite = warp_id; ite < length; ite += warp_num) { -+ key_val_r.v = *((copy_t*)&value_cache[ite * offset] + minor_offset); -+#pragma unroll -+ for (int i = 0; i < elems_per_thread; i++) { -+ key_val_r.x[i] = (float)key_val_r.x[i] + (float)bias_r.x[i]; -+ } -+ if (seq_id == 0) -+ *((copy_t*)&value_cache_out[ite * size_per_head] + minor_offset) = key_val_r.v; -+#pragma unroll -+ for (int i = 0; i < elems_per_thread; ++i) { -+ sum_r[i] += (float)key_val_r.x[i] * logits[ite]; -+ } -+ } -+ for (int i = 0; i < elems_per_thread; i++) { -+ sq[warp_id * WARP_SIZE + lane_id].x[i] = sum_r[i]; -+ } -+ __syncthreads(); -+ if (threadIdx.x < WARP_SIZE) { -+#pragma unroll -+ for (int j = 1; j < warp_num; j++) { -+ for (int i = 0; i < elems_per_thread; ++i) { -+ sum_r[i] = sum_r[i] + (float)sq[j * WARP_SIZE + threadIdx.x].x[i]; -+ } -+ } -+ } -+ __syncthreads(); -+#pragma unroll -+ for (int i = 0; i < elems_per_thread; i++) { -+ key_val_r.x[i] = sum_r[i]; -+ } -+ if (threadIdx.x < WARP_SIZE) { -+ *((copy_t*)context_buf + minor_offset) = key_val_r.v; -+ } -+} -+ -+template -+void myAttnention(const T* qkv_buf, -+ const T* qkv_bias, -+ const T* attr_mask, -+ T* context_buf, -+ T* 
key_cache_out, -+ T* value_cache_out, -+ const int inference_batch_size, -+ const int head_num, -+ const int size_per_head, -+ const int seq_len, -+ const float q_scaling, -+ cudaStream_t stream) -+{ -+ const int block_sz = ATTENTION_BLOCK_SIZE; // blockDim.x -+ float scalar = 1.f / (sqrtf(size_per_head * 1.0f) * q_scaling); -+ -+ dim3 grid(inference_batch_size * head_num, seq_len); // gridDim.x gridDim.y -+ int cond = size_per_head * ((ATTENION_OPT) ? 1 : 0); -+ switch (cond) { -+ case 32: -+ attention_kernel_opt -+ <<>>(qkv_buf, -+ qkv_bias, -+ attr_mask, -+ context_buf, -+ key_cache_out, -+ value_cache_out, -+ inference_batch_size, -+ head_num, -+ seq_len, -+ scalar); -+ break; -+ case 64: -+ attention_kernel_opt -+ <<>>(qkv_buf, -+ qkv_bias, -+ attr_mask, -+ context_buf, -+ key_cache_out, -+ value_cache_out, -+ inference_batch_size, -+ head_num, -+ seq_len, -+ scalar); -+ break; -+ case 128: -+ attention_kernel_opt -+ <<>>(qkv_buf, -+ qkv_bias, -+ attr_mask, -+ context_buf, -+ key_cache_out, -+ value_cache_out, -+ inference_batch_size, -+ head_num, -+ seq_len, -+ scalar); -+ break; -+ default:; -+ } -+} -+ -+template void myAttnention(const float* qkv_buf, -+ const float* qkv_bias, -+ const float* attr_mask, -+ float* context_buf, -+ float* key_cache_out, -+ float* value_cache_out, -+ const int inference_batch_size, -+ const int head_num, -+ const int size_per_head, -+ const int seq_len, -+ const float q_scaling, -+ cudaStream_t stream); -+ -+template void invokeAddFusedQKVBiasTransposeUsePast(float* q_buf, -+ float* k_buf, -+ float* v_buf, -+ const float* QKV, -+ const float* qkv_bias, -+ const int seq_len, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeAddFusedQKVBiasTransposeUsePast(half* q_buf, -+ half* k_buf, -+ half* v_buf, -+ const half* QKV, -+ const half* qkv_bias, -+ const int seq_len, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeAddFusedQKVBiasTransposeUsePastMB(float* q_buf, -+ float* k_buf, -+ float* v_buf, -+ const float* QKV, -+ const float* qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int batch_size, -+ const int h_token_num, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head, -+ const bool incremental_mode, -+ const bool padding, -+ cudaStream_t stream); -+ -+template void invokeAddFusedQKVBiasTransposeUsePastMB(half* q_buf, -+ half* k_buf, -+ half* v_buf, -+ const half* QKV, -+ const half* qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int batch_size, -+ const int h_token_num, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head, -+ const bool incremental_mode, -+ const bool padding, -+ cudaStream_t stream); - } // namespace fastertransformer -diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.h b/src/fastertransformer/kernels/unfused_attention_kernels.h -index be8b178..f418eab 100644 ---- a/src/fastertransformer/kernels/unfused_attention_kernels.h -+++ b/src/fastertransformer/kernels/unfused_attention_kernels.h -@@ -43,6 +43,23 @@ void invokeMaskedSoftMax(T* buffer, - const T scalar, - cudaStream_t stream); - -+template -+void invokeMixMaskedSoftMax(T* io_buffer, -+ const T_M* attr_mask, -+ const T* position_bias, -+ const int* d_seq_len, -+ const int* d_seq_len2, -+ const int batch_size, -+ 
const int seq_len, -+ const int seq_len_stride, -+ const int tgt_seq_len, -+ const int tgt_seq_stride, -+ const int head_num, -+ const int position_bias_head_num, -+ const T scalar, -+ const bool causal, -+ cudaStream_t stream); -+ - template - void invokeTransposeQKV(T* dst, - T* src, -@@ -81,12 +98,12 @@ void invokeTransposeAttentionOutRemovePadding(T* src, - const int* mask_offset, - cudaStream_t stream); - --template -+template - void invokeAddFusedQKVBiasTranspose(T* q_buf, - T* k_buf, - T* v_buf, - T* QKV, -- const T* qkv_bias, -+ const U* qkv_bias, - const int batch_size, - const int seq_len, - const int head_num, -@@ -97,12 +114,132 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, - q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head, 0, stream); - } - -+ - template -+void invokeAddFusedQKVBiasTransposeUsePast(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ const T* QKV, -+ const T* qkv_bias, -+ const int seq_len, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+template -+void invokeAddFusedQKVBiasTransposeUsePastMB(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ const T* QKV, -+ const T* qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int batch_size, -+ const int h_token_num, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head, -+ const bool incremental_mode, -+ const bool padding, -+ cudaStream_t stream); -+ -+template -+ void invokeAddFusedQKVBiasTranspose (T *q_buf, -+ T *k_buf, -+ T *v_buf, -+ const T *QKV, -+ const T *qkv_bias, -+ const int max_seq_len, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head, -+ const bool inc, -+ cudaStream_t stream) { -+ -+ } -+template -+void invokeAddFusedQKVBiasTransposeMBVSL(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ T* QKV, -+ const U* qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int batch_size, -+ const int max_seq_len, -+ const int h_token_num, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template -+void invokeAddFusedZP_QKVBiasTranspose(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ T* QKV, -+ const U* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int h_token, -+ int *padding_mask, -+ cudaStream_t stream); -+template -+void invokeCrossAddFusedZP_QKVBiasTranspose(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ T* QKV, -+ const U* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int h_token, -+ const int h_token2, -+ int *padding_mask, -+ int *padding_mask2, -+ cudaStream_t stream); -+ -+template -+void invokeCrossAddFusedQKVBiasTranspose(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ T* QKV, -+ const U* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template -+void invokeCrossAddFusedQKVBiasTransposeMBVSL(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ T* QKV, -+ const U* qkv_bias, -+ const int* padding_offset, -+ const int* padding_offset2, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int h_token_num, -+ const int h_token_num2, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+template - void 
invokeAddFusedQKVBiasTranspose(T* q_buf, - T* k_buf, - T* v_buf, - T* QKV, -- const T* qkv_bias, -+ const U* qkv_bias, - const int batch_size, - const int seq_len, - const int head_num, -@@ -166,4 +303,21 @@ void invokeMaskedSoftMaxWithRelPosBias(T* qk_buf, - const float qk_scale, - cudaStream_t stream); - -+ -+template -+void myAttnention(const T* qkv_buf, -+ const T* qkv_bias, -+ const T *attr_mask, -+ T* context_buf, -+ T* key_cache_out, -+ T* value_cache_out, -+ const int inference_batch_size, -+ const int head_num, -+ const int size_per_head, -+ const int seq_len, -+ const float q_scaling, -+ cudaStream_t stream); -+ -+ -+ - } // namespace fastertransformer -diff --git a/src/fastertransformer/layers/CMakeLists.txt b/src/fastertransformer/layers/CMakeLists.txt -index cbaf4fa..49779bf 100644 ---- a/src/fastertransformer/layers/CMakeLists.txt -+++ b/src/fastertransformer/layers/CMakeLists.txt -@@ -14,6 +14,7 @@ - - cmake_minimum_required(VERSION 3.8) - -+add_subdirectory(ms_layers) - add_subdirectory(attention_layers) - add_subdirectory(attention_layers_int8) - add_subdirectory(xlnet_attention_layers) -@@ -30,15 +31,18 @@ set_property(TARGET FfnLayerINT8 PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET FfnLayerINT8 PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - target_link_libraries(FfnLayerINT8 PUBLIC -lcublasLt -lcublas -lcudart cublasMMWrapper cublasINT8MMWrapper activation_int8_kernels memory_utils) - -+if(EXAMPLES) - add_library(TensorParallelGeluFfnLayer STATIC TensorParallelGeluFfnLayer.cc) - set_property(TARGET TensorParallelGeluFfnLayer PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET TensorParallelGeluFfnLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - target_link_libraries(TensorParallelGeluFfnLayer PUBLIC -lcudart FfnLayer nccl_utils) - -+ - add_library(TensorParallelReluFfnLayer STATIC TensorParallelReluFfnLayer.cc) - set_property(TARGET TensorParallelReluFfnLayer PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET TensorParallelReluFfnLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - target_link_libraries(TensorParallelReluFfnLayer PUBLIC -lcudart FfnLayer nccl_utils) -+endif() - - add_library(DynamicDecodeLayer STATIC DynamicDecodeLayer.cc) - set_property(TARGET DynamicDecodeLayer PROPERTY POSITION_INDEPENDENT_CODE ON) -diff --git a/src/fastertransformer/layers/DenseWeight.h b/src/fastertransformer/layers/DenseWeight.h -index 5a5eb6a..c95b97c 100644 ---- a/src/fastertransformer/layers/DenseWeight.h -+++ b/src/fastertransformer/layers/DenseWeight.h -@@ -28,4 +28,5 @@ struct DenseWeight { - const float* scale = nullptr; - }; - -+ - } // namespace fastertransformer -\ No newline at end of file -diff --git a/src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h b/src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h -index b21e3a7..746cb71 100644 ---- a/src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h -+++ b/src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h -@@ -62,13 +62,13 @@ AttentionType getAttentionTypeINT8( - } - } - --template -+template - class BaseAttentionLayer: public BaseLayer { - - public: - virtual void forward(std::vector* output_tensors, - const std::vector* input_tensors, -- const AttentionWeight* attention_weights) = 0; -+ const AttentionWeight* attention_weights) = 0; - BaseAttentionLayer(cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, -diff --git a/src/fastertransformer/layers/attention_layers/CMakeLists.txt 
b/src/fastertransformer/layers/attention_layers/CMakeLists.txt -index 9cef315..7170af4 100644 ---- a/src/fastertransformer/layers/attention_layers/CMakeLists.txt -+++ b/src/fastertransformer/layers/attention_layers/CMakeLists.txt -@@ -42,8 +42,8 @@ target_link_libraries(DecoderSelfAttentionLayer PUBLIC -lcublas -lcudart cublasM - add_library(GptContextAttentionLayer STATIC GptContextAttentionLayer.cc) - set_property(TARGET GptContextAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET GptContextAttentionLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) --target_link_libraries(GptContextAttentionLayer PUBLIC -lcublas -lcudart cublasMMWrapper memory_utils unfused_attention_kernels) -- -+target_link_libraries(GptContextAttentionLayer PUBLIC -lcublas -lcudart cublasMMWrapper memory_utils unfused_attention_kernels activation_kernels) -+if(EXAMPLES) - add_library(TensorParallelDecoderSelfAttentionLayer STATIC TensorParallelDecoderSelfAttentionLayer.cc) - set_property(TARGET TensorParallelDecoderSelfAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET TensorParallelDecoderSelfAttentionLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -@@ -63,6 +63,7 @@ add_library(TensorParallelUnfusedAttentionLayer STATIC TensorParallelUnfusedAtte - set_property(TARGET TensorParallelUnfusedAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET TensorParallelUnfusedAttentionLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - target_link_libraries(TensorParallelUnfusedAttentionLayer PUBLIC -lcudart UnfusedAttentionLayer nccl_utils) -+endif() - - add_library(WindowAttention STATIC WindowAttention.cc) - set_property(TARGET WindowAttention PROPERTY POSITION_INDEPENDENT_CODE ON) -diff --git a/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc -old mode 100644 -new mode 100755 -index bada640..4a48c48 ---- a/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc -+++ b/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc -@@ -16,10 +16,39 @@ - */ - - #include "src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.h" -+#include "src/fastertransformer/kernels/activation_kernels.h" - #include "src/fastertransformer/kernels/unfused_attention_kernels.h" - - namespace fastertransformer { - -+template -+cublasComputeType_t getCublasComputeType() -+{ -+ if (std::is_same::value) -+ return CUBLAS_COMPUTE_16F; -+ -+ else -+ return CUBLAS_COMPUTE_32F_FAST_TF32; -+} -+template -+static void printTensor(char* str, T* input, int size) -+{ -+ printf("%s ", str); -+ T* input_device = input; -+ T* input_host = (T*)malloc(size * sizeof(T)); -+ -+ fastertransformer::cudaD2Hcpy(input_host, input_device, size); -+ -+ for (int k = 0; k < (int)size; k++) { -+ std::cout << input_host[k] << ","; -+ if (k % 10 == 0) -+ std::cout << std::endl; -+ } -+ -+ std::cout << std::endl; -+ -+ free(input_host); -+} - template - void GptContextAttentionLayer::forward(std::vector* output_tensors, - const std::vector* input_tensors, -@@ -34,7 +63,6 @@ void GptContextAttentionLayer::forward(std::vector - // attention_out [batch_size * seq_len, hidden_dimension] - // key_cache [batch, local_head_num, size_per_head // x, max_seq_len, x] - // value_cache [batch, local_head_num, max_seq_len, size_per_head] -- - FT_CHECK(input_tensors->size() == 3); - FT_CHECK(output_tensors->size() == 3); - FT_CHECK(output_tensors->at(1).shape.size() == 5); -@@ -49,7 
+77,7 @@ void GptContextAttentionLayer::forward(std::vector - T* attention_out = (T*)output_tensors->at(0).data; - const T* attention_input = (const T*)input_tensors->at(0).data; - const T* attention_mask = (const T*)input_tensors->at(1).data; -- const bool is_final = *((bool*)(input_tensors->at(2).data)); -+ const bool is_final = false; // *((bool*)(input_tensors->at(2).data)); - - const int m = input_tensors->at(0).shape[0]; - -@@ -134,7 +162,7 @@ void GptContextAttentionLayer::forward(std::vector - request_seq_len, - request_seq_len * request_seq_len, - request_batch_size * local_head_num_, -- CUDA_R_32F); -+ getCublasComputeType()); - sync_check_cuda_error(); - T scalar = 1 / sqrtf(size_per_head_ * 1.0f); - invokeMaskedSoftMax(qk_buf_, -diff --git a/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.h b/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.h -old mode 100644 -new mode 100755 -index 92e2175..9e90e09 ---- a/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.h -+++ b/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.h -@@ -18,7 +18,6 @@ - #pragma once - - #include "src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h" -- - namespace fastertransformer { - - template -diff --git a/src/fastertransformer/layers/ms_layers/BaseLayer.h b/src/fastertransformer/layers/ms_layers/BaseLayer.h -new file mode 100644 -index 0000000..a4078d1 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/BaseLayer.h -@@ -0,0 +1,117 @@ -+/* -+ * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. 
-+ */ -+ -+#pragma once -+#include -+#include "src/fastertransformer/kernels/activation_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/MSBaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/BaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+#include -+ -+namespace fastertransformer { -+ -+class BaseLayerMS{ -+public: -+ typedef int (*allGatherFunc)(const void *input_addr, void *output_addr, size_t count, -+ nvinfer1::DataType data_type, cudaStream_t stream); -+ typedef int (*allReduceSumFunc)(const void *input_addr, void *output_addr, size_t count, -+ nvinfer1::DataType data_type, cudaStream_t stream); -+protected: -+ cublasGemmAlgo_t algo_; -+ size_t ws_offset_{0}; -+ int in_idx_; -+ size_t batch_size_; -+ size_t src_seq_len_; -+ size_t tgt_seq_len_; -+ size_t head_num_; -+ size_t head_size_; -+ size_t hidden_size_; -+ -+ int rank_num_{0}; -+ int rank_id_{0}; -+ BaseLayerMS::allGatherFunc all_gather_func_{nullptr}; -+ BaseLayerMS::allReduceSumFunc all_reduce_sum_func_{nullptr}; -+public: -+ template -+ T* GetBuf(void* ws, size_t buf) -+ { -+ return reinterpret_cast(static_cast(ws) + buf); -+ } -+ virtual void SetWSOffset(size_t ws_offset) -+ { -+ ws_offset_ = ws_offset; -+ } -+ size_t GetWSOffset() -+ { -+ return ws_offset_; -+ } -+ void SetIdx(int idx) -+ { -+ in_idx_ = idx; -+ } -+ int GetIdx() -+ { -+ return in_idx_; -+ } -+ virtual void SetAlgo(cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP) -+ { -+ algo_ = algo; -+ } -+ virtual void SetParallelFunc(BaseLayerMS::allGatherFunc all_gather_func, BaseLayerMS::allReduceSumFunc all_reduce_sum_func) -+ { -+ all_gather_func_ = all_gather_func; -+ all_reduce_sum_func_ = all_reduce_sum_func; -+ } -+ BaseLayerMS(size_t batch_size, -+ size_t src_seq_len, -+ size_t tgt_seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ int rank_num, -+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP): -+ batch_size_(batch_size), -+ src_seq_len_(src_seq_len), -+ tgt_seq_len_(tgt_seq_len), -+ head_num_(head_num), -+ head_size_(head_size), -+ hidden_size_(hidden_size), -+ rank_num_(rank_num), -+ algo_(algo), -+ in_idx_(0){} -+ virtual ~BaseLayerMS() = default; -+ virtual int GetRankNum() -+ { -+ return rank_num_; -+ } -+ virtual void SetRankNum(int rank_num) -+ { -+ rank_num_ = rank_num; -+ } -+ virtual int GetRankId() -+ { -+ return rank_id_; -+ } -+ virtual void SetRankId(int rank_id) -+ { -+ rank_id_ = rank_id; -+ } -+ virtual size_t GetWorkspaceSize() {return 0;} -+ virtual void forward(std::vector &inputs, const std::vector&outputs, void *ws, cublasHandle_t cublas_handle = nullptr, cudaStream_t stream = 0) = 0; -+}; -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/CMakeLists.txt b/src/fastertransformer/layers/ms_layers/CMakeLists.txt -new file mode 100644 -index 0000000..bd486d1 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/CMakeLists.txt -@@ -0,0 +1,42 @@ -+# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-+# See the License for the specific language governing permissions and -+# limitations under the License. -+ -+cmake_minimum_required(VERSION 3.8) -+ -+set(CUTLASS_INSTALL_DIR "${CMAKE_SOURCE_DIR}/3rdparty/cutlass") -+ -+add_library(MSLayer STATIC -+ MSDecoderLayer.cc -+ MSEncoderLayer.cc -+ MSAttentionLayer.cc -+ decoder.cc -+ encoder.cc -+ ffn.cc -+ gemm.cc -+ attention.cc -+ layer_norm.cc -+ opt_allocator.cc -+ debug_utils.cc -+ fmha_cutlass.cu -+ MoeFfnLayer.cu) -+set_property(TARGET MSLayer PROPERTY POSITION_INDEPENDENT_CODE ON) -+set_property(TARGET MSLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -+target_include_directories(MSLayer PUBLIC ${CUTLASS_INSTALL_DIR}/include/ -+ PUBLIC ${CUTLASS_INSTALL_DIR}/tools/util/include/ -+ PUBLIC ${CUTLASS_INSTALL_DIR}/examples/41_fused_multi_head_attention/ -+ PUBLIC ${CUTLASS_INSTALL_DIR}/examples/13_two_tensor_op_fusion/ -+ PUBLIC ${CUTLASS_INSTALL_DIR}/examples/39_gemm_permute/ -+ PUBLIC ${CUTLASS_INSTALL_DIR}/examples/common) -+target_link_libraries(MSLayer PUBLIC -lcublas -lcudart -lnvinfer unfused_attention_kernels activation_kernels -+ layernorm_kernels add_residual_kernels bert_preprocess_kernels) -diff --git a/src/fastertransformer/layers/ms_layers/MSAttentionLayer.cc b/src/fastertransformer/layers/ms_layers/MSAttentionLayer.cc -new file mode 100755 -index 0000000..51ba2e9 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/MSAttentionLayer.cc -@@ -0,0 +1,218 @@ -+/* -+ * Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. -+ * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. 
-+ */ -+ -+#include "src/fastertransformer/layers/ms_layers/MSAttentionLayer.h" -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+ -+namespace fastertransformer { -+ -+template -+MSMHALayer::MSMHALayer(size_t max_batch_size, -+ size_t max_src_seq_len, -+ size_t max_tgt_seq_len, -+ size_t head_num, -+ size_t size_per_head, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ IAllocator* allocator, -+ bool is_free_buffer_after_forward, -+ bool is_qk_buf_float, -+ bool is_cross, -+ bool sparse, -+ bool is_position_bias): -+ MSBaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, sparse) -+{ -+ common_param_.stream = stream_; -+ common_param_.cublas_handle = cublas_handle; -+ common_param_.batch_size = max_batch_size; -+ common_param_.src_seq_len = max_src_seq_len; -+ common_param_.tgt_seq_len = max_tgt_seq_len; -+ common_param_.head_num = head_num; -+ common_param_.head_size = size_per_head; -+ common_param_.hidden_size = head_num * size_per_head; -+ common_param_.in_idx = 0; -+ common_param_.algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; -+ common_param_.use_past = false; -+ common_param_.query_layer = false; -+ common_param_.h_token_num = max_batch_size * max_src_seq_len; -+ attn_param_.common_param = &common_param_; -+ attn_param_.attn.qkv_bias = !is_position_bias; -+ attn_param_.attn.projection_bias = !is_position_bias; -+ attn_param_.attn.is_cross = is_cross; -+ attn_param_.attn.position_bias = is_position_bias; -+ attn_param_.attn.mask = true; -+ attn_param_.attn.scale = is_position_bias ? 1.0f : 1.0f / sqrtf(size_per_head * 1.0f); -+ attn_param_.attn.padding_offset = nullptr; -+ this->ms_weights = new AttentionLayerWeight(); -+ -+ attn_layer_ = std::make_shared>( -+ common_param_.batch_size, -+ common_param_.src_seq_len, -+ common_param_.tgt_seq_len, -+ common_param_.head_num, -+ common_param_.head_size, -+ common_param_.hidden_size, -+ attn_param_.attn.qkv_bias, -+ attn_param_.attn.projection_bias, -+ attn_param_.attn.is_cross, -+ attn_param_.attn.position_bias, -+ attn_param_.attn.scale, -+ attn_param_.attn.mask, -+ common_param_.use_past, -+ common_param_.algo); -+ attn_layer_->SetAlgo(common_param_.algo); -+ attn_layer_->SetHTokenNum(common_param_.h_token_num); -+ attn_layer_->SetScale(attn_param_.attn.scale); -+ attn_layer_->SetCross(is_cross); -+ attn_layer_->SetOption(attn_param_.attn.qkv_bias, attn_param_.attn.projection_bias, attn_param_.attn.position_bias, attn_param_.attn.mask); -+} -+template -+void MSMHALayer::allocateBuffer() -+{ -+ if (buf_ == nullptr) { -+ size_t buff_size_allocator = attn_layer_->GetWorkspaceSize(); -+ buf_ = reinterpret_cast(allocator_->reMalloc(buf_, buff_size_allocator, true)); -+ attn_layer_->SetWSOffset(0); -+ } -+} -+template -+int MSMHALayer::forward(std::vector* output_tensors, -+ const std::vector* input_tensors, -+ const MSLayerWeight* weights) -+{ -+ const AttentionLayerWeight* attention_weights = dynamic_cast*>(this->ms_weights); -+ if (attention_weights == NULL) { -+ FT_LOG_ERROR("cast AttentionLayerWeight not sucsses\n"); -+ return -1; -+ } -+ allocateBuffer(); // only once -+ if (attn_param_.attn.position_bias) { -+ if (attn_param_.attn.is_cross) { -+ std::vector outputs = {(void*)output_tensors->at(0).data}; -+ std::vector inputs = {(void*)input_tensors->at(0).data, -+ (void*)input_tensors->at(1).data, -+ (void*)attention_weights->query_weight.kernel, -+ (void*)attention_weights->key_weight.kernel, -+ (void*)input_tensors->at(2).data, -+ 
(void*)input_tensors->at(3).data, -+ (void*)attention_weights->attention_output_weight.kernel}; -+ attn_layer_->forward(inputs, outputs, buf_, common_param_.cublas_handle, common_param_.stream); -+ } -+ else { -+ std::vector outputs = {(void*)output_tensors->at(0).data}; -+ std::vectorinputs = {(void*)input_tensors->at(0).data, -+ (void*)attention_weights->query_weight.kernel, -+ (void*)input_tensors->at(1).data, -+ (void*)input_tensors->at(2).data, -+ (void*)attention_weights->attention_output_weight.kernel}; -+ attn_layer_->forward(inputs, outputs, buf_, common_param_.cublas_handle, common_param_.stream); -+ } -+ } -+ else { -+ if (attn_param_.attn.is_cross) { -+ std::vector outputs= {(void*)output_tensors->at(0).data}; -+ std::vector inputs = {(void*)input_tensors->at(0).data, -+ (void*)input_tensors->at(1).data, -+ (void*)attention_weights->query_weight.kernel, -+ (void*)attention_weights->key_weight.kernel, -+ (void*)attention_weights->query_weight.bias, -+ (void*)input_tensors->at(2).data, -+ (void*)attention_weights->attention_output_weight.kernel, -+ (void*)attention_weights->attention_output_weight.bias}; -+ attn_layer_->forward(inputs, outputs,buf_); -+ } -+ else { -+ std::vector outputs = {(void*)output_tensors->at(0).data}; -+ std::vector inputs = {(void*)input_tensors->at(0).data, -+ (void*)attention_weights->query_weight.kernel, -+ (void*)attention_weights->query_weight.bias, -+ (void*)input_tensors->at(1).data, -+ (void*)attention_weights->attention_output_weight.kernel, -+ (void*)attention_weights->attention_output_weight.bias}; -+ attn_layer_->forward(inputs, outputs, buf_, common_param_.cublas_handle, common_param_.stream); -+ } -+ } -+ return 0; -+} -+template -+MSMHALayer::~MSMHALayer() -+{ -+ cublas_wrapper_ = nullptr; -+ freeBuffer(); -+} -+ -+template -+void MSMHALayer::freeBuffer() -+{ -+ if (buf_ != nullptr) { -+ allocator_->free(buf_); -+ buf_ = nullptr; -+ } -+} -+template -+int MSMHALayer::InitWeight(opt_arg* opt_a, MSLayerWeight* ms_weights, std::vector w_tensors) -+{ -+ AttentionLayerWeight* attn_weights = dynamic_cast*>(this->ms_weights); -+ if (attn_weights == NULL) { -+ FT_LOG_ERROR("cast AttentionLayerWeight not sucsses\n"); -+ return -1; -+ } -+ int modelId = ModelNum(opt_a->model_name); -+ if (modelId == MHA_X1) { -+ attn_weights->query_weight.kernel = reinterpret_cast(w_tensors[0].data); -+ attn_weights->attention_output_weight.kernel = reinterpret_cast(w_tensors[1].data); -+ attn_weights->query_weight.bias = reinterpret_cast(w_tensors[2].data); -+ attn_weights->attention_output_weight.bias = reinterpret_cast(w_tensors[3].data); -+ } -+ else if (modelId == MHA_X2 || modelId == MHA_CROSS) { -+ attn_weights->query_weight.kernel = reinterpret_cast(w_tensors[0].data); -+ attn_weights->query_weight.bias = reinterpret_cast(w_tensors[1].data); -+ attn_weights->key_weight.kernel = reinterpret_cast(w_tensors[2].data); -+ attn_weights->attention_output_weight.kernel = reinterpret_cast(w_tensors[3].data); -+ attn_weights->attention_output_weight.bias = reinterpret_cast(w_tensors[4].data); -+ } -+ else if (modelId == MHA_T5) { -+ attn_weights->query_weight.kernel = reinterpret_cast(w_tensors[0].data); -+ attn_weights->query_weight.bias = nullptr; -+ attn_weights->attention_output_weight.kernel = reinterpret_cast(w_tensors[1].data); -+ attn_weights->attention_output_weight.bias = nullptr; -+ } -+ else if (modelId == MHA_T5_CROSS) { -+ attn_weights->query_weight.kernel = reinterpret_cast(w_tensors[0].data); -+ attn_weights->query_weight.bias = nullptr; -+ 
attn_weights->key_weight.kernel = reinterpret_cast(w_tensors[1].data); -+ attn_weights->attention_output_weight.kernel = reinterpret_cast(w_tensors[2].data); -+ attn_weights->attention_output_weight.bias = nullptr; -+ } -+ else { -+ FT_LOG_ERROR("illegal model !\n"); -+ return -1; -+ } -+ return 0; -+} -+ -+template class MSMHALayer; -+template class MSMHALayer; -+template class MSMHALayer; -+template class MSMHALayer; -+template class MSMHALayer; -+template class MSMHALayer; -+template class MSMHALayer; -+template class MSMHALayer; -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/MSAttentionLayer.h b/src/fastertransformer/layers/ms_layers/MSAttentionLayer.h -new file mode 100755 -index 0000000..b8f1ce0 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/MSAttentionLayer.h -@@ -0,0 +1,68 @@ -+/* -+ * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -+ * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+ -+#pragma once -+#include "src/fastertransformer/layers/ms_layers/MSBaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/attention.h" -+#include "src/fastertransformer/layers/ms_layers/ffn.h" -+#include "src/fastertransformer/layers/ms_layers/param.h" -+namespace fastertransformer { -+ -+// TODO(haim): Add template according to "mix" compute type (fp32, fp16) -+template -+class MSMHALayer: public MSBaseLayer{ -+private: -+ void allocateBuffer() override; -+ void freeBuffer() override; -+ -+ attentionParamRun attn_param_; -+ CommonParam common_param_; -+ std::shared_ptr> attn_layer_; -+ using MSBaseLayer::is_free_buffer_after_forward_; -+ using MSBaseLayer::is_allocate_buffer_; -+ using MSBaseLayer::cublas_wrapper_; -+ using MSBaseLayer::allocator_; -+ -+protected: -+ using MSBaseLayer::stream_; -+ using MSBaseLayer::sparse_; -+ T* buf_ = nullptr; -+ -+public: -+ MSMHALayer(size_t batch_size, -+ size_t src_seq_len, -+ size_t tgt_seq_len, -+ size_t head_num, -+ size_t size_per_head, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ IAllocator* allocator, -+ bool is_free_buffer_after_forward, -+ bool is_qk_buf_float, -+ bool is_cross, -+ bool sparse = false, -+ bool is_position_bias=false); -+ MSMHALayer(MSMHALayer const& attention_layer); -+ virtual ~MSMHALayer(); -+ int forward(std::vector* output_tensors, -+ const std::vector* input_tensors, -+ const MSLayerWeight* weights) override; -+ int InitWeight(opt_arg* opt_a, MSLayerWeight* ms_weights, std::vector w_tensors) override; -+}; -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/MSBaseLayer.h b/src/fastertransformer/layers/ms_layers/MSBaseLayer.h -new file mode 100644 -index 0000000..b1f4d56 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/MSBaseLayer.h -@@ -0,0 +1,147 @@ -+/* -+ * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. 
-+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+ -+#pragma once -+ -+#include -+#include -+ -+#include "3rdparty/trt_fused_multihead_attention/fused_multihead_attention_common.h" -+#include "src/fastertransformer/layers/BaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/MSLayerWeight.h" -+#include "src/fastertransformer/utils/Tensor.h" -+#include "src/fastertransformer/utils/allocator.h" -+#include "src/fastertransformer/utils/cublasMMWrapper.h" -+#include "src/fastertransformer/utils/memory_utils.h" -+namespace fastertransformer { -+ -+enum class MSLayerType { -+ UNFUSED_MS_LAYER, -+ FUSED_MS_LAYER -+}; -+struct opt_arg { -+ size_t batch_size; -+ size_t num_layers; -+ size_t seq_len; // source seq len -+ size_t tgt_seq_len; -+ size_t head_num; -+ size_t hidden_size; -+ size_t size_per_head; -+ float eps1; -+ float eps2; -+ float eps3; -+ bool position_bias1; -+ bool position_bias2; -+ bool post_layernorm_residual; -+ bool is_ffn_fp16; -+ bool is_remove_padding; -+ std::string model_name; -+ std::string compute_type; -+ std::string w_compute_type; -+ std::string s_compute_type; -+ size_t ffn_hidden_size; -+ size_t expand_ratio; -+}; -+typedef enum { -+ MHA_X1 = 1, // AttnIn + AttnMask -+ MHA_X2, // AttnIn + EncOut -- same seq size + AttnMask -+ MHA_CROSS, // AttnIn + EncOut + AttnMAsk -+ MHA_T5, // AttnIn + EncOut + AttnMAsk + position_bias -+ MHA_T5_CROSS, // AttnIn + EncOut + AttnMAsk + position_bias -+ TEL, // transformer encoder layer -+ TEL_T5, // transformer encoder layer -+ TDL, -+ TDL_T5, -+} MODEL_TEST_ID_E; -+template -+MSLayerType getMSLayerType( -+ size_t size_per_head, const int sm, const bool remove_padding, const int max_seq_len, const bool is_fuse = true) -+{ -+ if (std::is_same::value && (sm == kSM_70 || sm == kSM_86 || sm == kSM_80 || sm == kSM_75 || sm == kSM_72) -+ && size_per_head == 64 && max_seq_len <= 384 && is_fuse == true) { -+ return remove_padding ? MSLayerType::FUSED_MS_LAYER : MSLayerType::FUSED_MS_LAYER; -+ } -+ else { -+ return remove_padding ? MSLayerType::FUSED_MS_LAYER : MSLayerType::FUSED_MS_LAYER; -+ } -+} -+ -+template -+MSLayerType getMSLayerTypeINT8( -+ size_t size_per_head, const int sm, const bool remove_padding, const int max_seq_len, const int int8_mode) -+{ -+ if ((int8_mode == 1 || int8_mode == 2) && (sm == kSM_86 || sm == kSM_80 || sm == kSM_75) && size_per_head == 64 -+ && max_seq_len <= 384) { -+ return remove_padding ? MSLayerType::FUSED_MS_LAYER : MSLayerType::FUSED_MS_LAYER; -+ } -+ else { -+ return remove_padding ? 
MSLayerType::FUSED_MS_LAYER : MSLayerType::FUSED_MS_LAYER; -+ } -+} -+ -+template -+class MSBaseLayer: public BaseLayer { -+protected: -+public: -+ MSLayerWeight* ms_weights; -+ virtual int forward(std::vector* output_tensors, -+ const std::vector* input_tensors, -+ const MSLayerWeight* layer_weights) = 0; -+ MSBaseLayer(cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ IAllocator* allocator, -+ bool is_free_buffer_after_forward, -+ bool sparse = false): -+ BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, nullptr, sparse) -+ { -+ } -+ virtual int InitWeight(opt_arg* opt_a, MSLayerWeight* weights, std::vector w_tensors) = 0; -+ virtual ~MSBaseLayer() = default; -+}; -+static int ModelNum(std::string model_name) -+{ -+ if (model_name == "mha_x1") { -+ return MHA_X1; -+ } -+ else if (model_name == "mha_x2") { -+ return MHA_X2; -+ } -+ else if (model_name == "mha_cross") { -+ return MHA_CROSS; -+ } -+ else if (model_name == "mha_T5") { -+ return MHA_T5; -+ } -+ else if (model_name == "mha_T5_cross") { -+ return MHA_T5_CROSS; -+ } -+ else if (model_name == "transformer_encoder_layer") { -+ return TEL; -+ } -+ else if (model_name == "transformer_encoder_layer_t5") { -+ return TEL_T5; -+ } -+ else if (model_name == "transformer_decoder_layer") { -+ return TDL; -+ } -+ else if (model_name == "transformer_decoder_layer_t5") { -+ return TDL_T5; -+ } -+ else { -+ return -1; -+ } -+} -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/MSDecoderLayer.cc b/src/fastertransformer/layers/ms_layers/MSDecoderLayer.cc -new file mode 100644 -index 0000000..7250635 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/MSDecoderLayer.cc -@@ -0,0 +1,248 @@ -+/* -+ * Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. -+ * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. 
-+ */ -+ -+#include "src/fastertransformer/layers/ms_layers/MSDecoderLayer.h" -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+ -+namespace fastertransformer { -+ -+template -+MSDLayer::MSDLayer(size_t max_batch_size, -+ size_t max_src_seq_len, -+ size_t max_tgt_seq_len, -+ size_t head_num, -+ size_t size_per_head, -+ size_t ffn_hidden_size, -+ float eps1, -+ float eps2, -+ float eps3, -+ bool post_layernorm, -+ bool position_bias1, -+ bool position_bias2, -+ bool is_ffn_fp16, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ IAllocator* allocator, -+ bool is_free_buffer_after_forward, -+ bool is_qk_buf_float, -+ bool sparse): -+ -+ MSBaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, sparse), buf_(nullptr) -+{ //update commonparam -+ decoder_param_.common_param.stream = stream_; -+ decoder_param_.common_param.cublas_handle = cublas_handle; -+ decoder_param_.common_param.batch_size = max_batch_size; -+ decoder_param_.common_param.src_seq_len = max_src_seq_len; -+ decoder_param_.common_param.tgt_seq_len = max_tgt_seq_len; -+ decoder_param_.common_param.head_num = head_num; -+ decoder_param_.common_param.head_size = size_per_head; -+ decoder_param_.common_param.hidden_size = head_num * size_per_head; -+ decoder_param_.common_param.in_idx = 0; -+ decoder_param_.common_param.algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; -+ decoder_param_.common_param.h_token_num = max_src_seq_len * max_batch_size; -+ decoder_param_.common_param.h_token_num2 = max_tgt_seq_len * max_batch_size; -+ //connect commonparam to attention and ffn -+ decoder_param_.attn1.common_param = &decoder_param_.common_param; -+ decoder_param_.attn2.common_param = &decoder_param_.common_param; -+ decoder_param_.ffn_param.common_param = &decoder_param_.common_param; -+ -+ decoder_param_.ffn_param.ffn_param.ffn_hidden_size = ffn_hidden_size; -+ decoder_param_.ffn_param.ffn_param.ffn_bias = !position_bias1; -+ decoder_param_.ffn_param.ffn_param.ffn_fp16 = is_ffn_fp16; -+ decoder_param_.ffn_param.ffn_param.act_type = !position_bias1 ? ActType::ActType_Gelu : ActType::ActType_Relu; // true; -+ decoder_param_.decoder.eps1 = eps1; -+ decoder_param_.decoder.eps2 = eps2; -+ decoder_param_.decoder.eps3 = eps3; -+ decoder_param_.decoder.layernorm_post = post_layernorm; -+ decoder_param_.decoder.has_beta = !position_bias1; -+ decoder_param_.attn1.attn.qkv_bias = !position_bias1; -+ decoder_param_.attn1.attn.projection_bias = !position_bias1; -+ decoder_param_.attn1.attn.is_cross = false; -+ decoder_param_.attn1.attn.position_bias = position_bias1; -+ decoder_param_.attn1.attn.scale = position_bias1 ? 1.0f : 1.0f / sqrtf(size_per_head * 1.0f); -+ decoder_param_.attn1.attn.mask = true; -+ decoder_param_.attn2.attn.position_bias = position_bias2; -+ decoder_param_.attn2.attn.qkv_bias = !position_bias2; -+ decoder_param_.attn2.attn.projection_bias = !position_bias2; -+ decoder_param_.attn2.attn.is_cross = true; -+ decoder_param_.attn2.attn.scale = position_bias2 ? 
1.0f : 1.0f / sqrtf(size_per_head * 1.0f); -+ decoder_param_.attn2.attn.mask = true; -+ this->ms_weights = new DecoderLayerWeight(); -+ decoder_layer_ = std::make_shared>( -+ decoder_param_.common_param.batch_size, -+ decoder_param_.common_param.src_seq_len, -+ decoder_param_.common_param.tgt_seq_len, -+ decoder_param_.common_param.head_num, -+ decoder_param_.common_param.head_size, -+ decoder_param_.common_param.hidden_size -+ ); -+ decoder_layer_->SetEps(decoder_param_.decoder.eps1, decoder_param_.decoder.eps2, decoder_param_.decoder.eps3, decoder_param_.decoder.eps3); -+ decoder_layer_->SetIsCross(decoder_param_.attn2.attn.is_cross); -+ decoder_layer_->SetScaleAttn(decoder_param_.attn1.attn.scale); -+ decoder_layer_->SetIsLayerNorm(false, 1e-6f); -+ decoder_layer_->SetFfnParam(decoder_param_.ffn_param.ffn_param.ffn_fp16, decoder_param_.ffn_param.ffn_param.ffn_hidden_size, (FfnBase::ActType)decoder_param_.ffn_param.ffn_param.act_type, decoder_param_.ffn_param.ffn_param.ffn_bias); -+ decoder_layer_->SetT5(position_bias1); -+ decoder_layer_->SetHTokenNum(decoder_param_.common_param.h_token_num, decoder_param_.common_param.h_token_num2); -+ decoder_layer_->SetAlgo(CUBLAS_GEMM_DEFAULT_TENSOR_OP); -+} -+ -+template -+void MSDLayer::allocateBuffer() -+{ -+ if (buf_ == nullptr) { -+ size_t buff_size_allocator = decoder_layer_->GetWorkspaceSize(); -+ buf_ = reinterpret_cast(allocator_->reMalloc(buf_, buff_size_allocator, true)); -+ } -+} -+ -+template -+void MSDLayer::freeBuffer() -+{ -+ if (buf_ != nullptr) { -+ allocator_->free(buf_); -+ buf_ = nullptr; -+ } -+} -+ -+template -+MSDLayer::~MSDLayer() -+{ -+ cublas_wrapper_ = nullptr; -+ freeBuffer(); -+} -+ -+template -+int MSDLayer::forward(std::vector* output_tensors, -+ const std::vector* input_tensors, -+ const MSLayerWeight* weights) -+{ -+ const DecoderLayerWeight* decoder_weights = dynamic_cast*>(this->ms_weights); -+ if (decoder_weights == NULL) { -+ FT_LOG_ERROR("cast DecoderLayerWeight not sucsses\n"); -+ return -1; -+ } -+ allocateBuffer(); // only once -+ std::vector outputs= {(void*)output_tensors->at(0).data}; -+ if (decoder_param_.attn1.attn.qkv_bias && decoder_param_.attn2.attn.qkv_bias && !decoder_param_.attn1.attn.position_bias -+ && !decoder_param_.attn2.attn.position_bias) { -+ std::vector inputs = {(void*)input_tensors->at(0).data, -+ (void*)decoder_weights->layernorm1.gamma, -+ (void*)decoder_weights->layernorm1.beta, -+ (void*)decoder_weights->attention.query_weight.kernel, -+ (void*)decoder_weights->attention.query_weight.bias, -+ (void*)input_tensors->at(1).data, -+ (void*)decoder_weights->attention.attention_output_weight.kernel, -+ (void*)decoder_weights->attention.attention_output_weight.bias, -+ (void*)decoder_weights->layernorm2.gamma, -+ (void*)decoder_weights->layernorm2.beta, -+ (void*)input_tensors->at(2).data, -+ (void*)decoder_weights->cross_attention.query_weight.kernel, -+ (void*)decoder_weights->cross_attention.key_weight.kernel, -+ (void*)decoder_weights->cross_attention.query_weight.bias, -+ (void*)input_tensors->at(3).data, -+ (void*)decoder_weights->cross_attention.attention_output_weight.kernel, -+ (void*)decoder_weights->cross_attention.attention_output_weight.bias, -+ (void*)decoder_weights->layernorm3.gamma, -+ (void*)decoder_weights->layernorm3.beta, -+ (void*)decoder_weights->decoder_output_mapping.kernel, -+ (void*)decoder_weights->decoder_output_mapping.bias, -+ (void*)decoder_weights->decoder_output_projection.kernel, -+ (void*)decoder_weights->decoder_output_projection.bias}; -+ 
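The `inputs` vector assembled above flattens every activation and weight pointer into one type-erased list whose positional order has to line up with what `decoder_layer_->forward` consumes. As a rough, hypothetical illustration of that packing convention — none of these names or helpers come from these sources — a caller-side sketch might look like this:

#include <cstdio>
#include <vector>

// Hypothetical typed weight bundle; the real DecoderLayerWeight carries many more fields.
struct ToyWeights {
    const float* ln1_gamma;
    const float* ln1_beta;
    const float* q_kernel;
    const float* q_bias;
};

// Pack pointers in the fixed order the consumer expects; any reordering silently corrupts the layer.
std::vector<void*> PackInputs(const float* hidden_in, const ToyWeights& w) {
    return {const_cast<float*>(hidden_in),
            const_cast<float*>(w.ln1_gamma),
            const_cast<float*>(w.ln1_beta),
            const_cast<float*>(w.q_kernel),
            const_cast<float*>(w.q_bias)};
}

// The consumer recovers typed pointers purely by slot index.
void ConsumeInputs(const std::vector<void*>& in) {
    const float* hidden = static_cast<const float*>(in[0]);
    const float* gamma  = static_cast<const float*>(in[1]);
    std::printf("hidden[0]=%f gamma[0]=%f\n", hidden[0], gamma[0]);
}

int main() {
    float hidden[4] = {1.f, 2.f, 3.f, 4.f};
    float gamma[4]  = {1.f, 1.f, 1.f, 1.f};
    float beta[4]   = {0.f, 0.f, 0.f, 0.f};
    float qk[4]     = {0.5f, 0.5f, 0.5f, 0.5f};
    float qb[4]     = {0.f, 0.f, 0.f, 0.f};
    ToyWeights w{gamma, beta, qk, qb};
    ConsumeInputs(PackInputs(hidden, w));
    return 0;
}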
decoder_layer_->forward(inputs, outputs, buf_, decoder_param_.common_param.cublas_handle, decoder_param_.common_param.stream); -+ } -+ else if (decoder_param_.attn1.attn.position_bias && decoder_param_.attn2.attn.position_bias) { -+ std::vector inputs = {(void*)input_tensors->at(0).data, -+ (void*)decoder_weights->layernorm1.gamma, -+ (void*)decoder_weights->attention.query_weight.kernel, -+ (void*)input_tensors->at(1).data, -+ (void*)input_tensors->at(4).data, -+ (void*)decoder_weights->attention.attention_output_weight.kernel, -+ (void*)decoder_weights->layernorm2.gamma, -+ (void*)input_tensors->at(2).data, -+ (void*)decoder_weights->cross_attention.query_weight.kernel, -+ (void*)decoder_weights->cross_attention.key_weight.kernel, -+ (void*)input_tensors->at(3).data, -+ (void*)input_tensors->at(5).data, -+ (void*)decoder_weights->cross_attention.attention_output_weight.kernel, -+ (void*)decoder_weights->layernorm3.gamma, -+ (void*)decoder_weights->decoder_output_mapping.kernel, -+ (void*)decoder_weights->decoder_output_projection.kernel}; -+ decoder_layer_->forward(inputs, outputs, buf_, decoder_param_.common_param.cublas_handle, decoder_param_.common_param.stream); -+ } -+ return 0; -+} -+ -+template -+int MSDLayer::InitWeight(opt_arg* opt_a, MSLayerWeight* ms_weights, std::vector w_tensors) -+{ -+ DecoderLayerWeight* decoder_weights = dynamic_cast*>(this->ms_weights); -+ if (decoder_weights == NULL) { -+ FT_LOG_ERROR("cast DecoderLayerWeight not sucsses\n"); -+ return -1; -+ } -+ int modelId = ModelNum(opt_a->model_name); -+ if (modelId == TDL) { -+ decoder_weights->layernorm1.gamma = (const U*)w_tensors[0].data; -+ decoder_weights->layernorm1.beta = (const U*)w_tensors[1].data; -+ decoder_weights->attention.query_weight.kernel = (const U*)w_tensors[2].data; -+ decoder_weights->attention.query_weight.bias = (const U*)w_tensors[3].data; -+ decoder_weights->attention.attention_output_weight.kernel = (const U*)w_tensors[4].data; -+ decoder_weights->attention.attention_output_weight.bias = (const U*)w_tensors[5].data; -+ decoder_weights->layernorm2.gamma = (const U*)w_tensors[6].data; -+ decoder_weights->layernorm2.beta = (const U*)w_tensors[7].data; -+ decoder_weights->cross_attention.query_weight.kernel = (const U*)w_tensors[8].data; -+ decoder_weights->cross_attention.key_weight.kernel = (const U*)w_tensors[9].data; -+ decoder_weights->cross_attention.query_weight.bias = (const U*)w_tensors[10].data; -+ decoder_weights->cross_attention.key_weight.bias = (const U*)w_tensors[10].data; -+ decoder_weights->cross_attention.attention_output_weight.kernel = (const U*)w_tensors[11].data; -+ decoder_weights->cross_attention.attention_output_weight.bias = (const U*)w_tensors[12].data; -+ decoder_weights->layernorm3.gamma = (const U*)w_tensors[13].data; -+ decoder_weights->layernorm3.beta = (const U*)w_tensors[14].data; -+ decoder_weights->decoder_output_mapping.kernel = (const U*)w_tensors[15].data; -+ decoder_weights->decoder_output_mapping.bias = (const U*)w_tensors[16].data; -+ decoder_weights->decoder_output_projection.kernel = (const U*)w_tensors[17].data; -+ decoder_weights->decoder_output_projection.bias = (const U*)w_tensors[18].data; -+ } -+ else if (modelId == TDL_T5) { -+ decoder_weights->layernorm1.gamma = (const U*)w_tensors[0].data; -+ decoder_weights->attention.query_weight.kernel = (const U*)w_tensors[1].data; -+ decoder_weights->attention.attention_output_weight.kernel = (const U*)w_tensors[2].data; -+ decoder_weights->layernorm2.gamma = (const U*)w_tensors[3].data; -+ 
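These InitWeight branches bind every weight field purely by its position in `w_tensors`: the plain decoder layout above consumes 19 tensors (with `w_tensors[10]` reused for both the cross-attention query and key bias), while the T5 variant continued below consumes 10. A small guard along these lines — an illustrative sketch only, with hypothetical names, not code from these sources — makes that positional contract explicit and fails fast on an undersized tensor list:

#include <cstddef>
#include <cstdio>

// Hypothetical: tensor counts mirroring the positional binding in InitWeight
// (19 tensors for the TDL layout, 10 for TDL_T5), keyed by the ModelNum() ids.
bool CheckWeightCount(int model_id, std::size_t n_tensors) {
    const std::size_t expected = (model_id == 8 /* TDL */) ? 19
                               : (model_id == 9 /* TDL_T5 */) ? 10
                               : 0;
    if (expected == 0 || n_tensors < expected) {
        std::fprintf(stderr, "InitWeight: expected %zu weight tensors, got %zu\n",
                     expected, n_tensors);
        return false;
    }
    return true;
}

int main() {
    // Example: a TDL_T5 call handing over only 9 tensors would be rejected up front.
    return CheckWeightCount(9, 9) ? 0 : 1;
}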
decoder_weights->cross_attention.query_weight.kernel = (const U*)w_tensors[4].data; -+ decoder_weights->cross_attention.key_weight.kernel = (const U*)w_tensors[5].data; -+ decoder_weights->cross_attention.attention_output_weight.kernel = (const U*)w_tensors[6].data; -+ decoder_weights->layernorm3.gamma = (const U*)w_tensors[7].data; -+ decoder_weights->decoder_output_mapping.kernel = (const U*)w_tensors[8].data; -+ decoder_weights->decoder_output_projection.kernel = (const U*)w_tensors[9].data; -+ } -+ else { -+ FT_LOG_ERROR("illegal model !\n"); -+ return -1; -+ } -+ return 0; -+} -+template class MSDLayer; -+template class MSDLayer; -+template class MSDLayer; -+template class MSDLayer; -+template class MSDLayer; -+template class MSDLayer; -+template class MSDLayer; -+template class MSDLayer; -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/MSDecoderLayer.h b/src/fastertransformer/layers/ms_layers/MSDecoderLayer.h -new file mode 100644 -index 0000000..b8f870c ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/MSDecoderLayer.h -@@ -0,0 +1,74 @@ -+/* -+ * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -+ * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. 
-+ */ -+ -+#pragma once -+ -+#include "src/fastertransformer/layers/ms_layers/MSBaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/decoder.h" -+#include "src/fastertransformer/layers/ms_layers/param.h" -+ -+namespace fastertransformer { -+ -+// TODO(haim): Add template according to "mix" compute type (fp32, fp16) -+template -+class MSDLayer: public MSBaseLayer { -+private: -+ mutable decoderParamRun decoder_param_; -+ void allocateBuffer() override; -+ void freeBuffer() override; -+ void* buf_; -+ using MSBaseLayer::is_free_buffer_after_forward_; -+ using MSBaseLayer::is_allocate_buffer_; -+ using MSBaseLayer::cublas_wrapper_; -+ using MSBaseLayer::allocator_; -+ std::shared_ptr> decoder_layer_; -+protected: -+ using MSBaseLayer::stream_; -+ using MSBaseLayer::sparse_; -+ -+public: -+ MSDLayer(size_t max_batch_size, -+ size_t max_src_seq_len, -+ size_t max_tgt_seq_len, -+ size_t head_num, -+ size_t size_per_head, -+ size_t ffn_hidden_size, -+ float eps1, -+ float eps2, -+ float eps3, -+ bool post_layernorm, -+ bool position_bias1, -+ bool position_bias2, -+ bool is_ffn_fp16, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ IAllocator* allocator, -+ bool is_free_buffer_after_forward, -+ bool is_qk_buf_float, -+ bool sparse); -+ -+ MSDLayer(MSDLayer const& decoder_layer); -+ -+ virtual ~MSDLayer(); -+ -+ int forward(std::vector* output_tensors, -+ const std::vector* input_tensors, -+ const MSLayerWeight* weights) override; -+ int InitWeight(opt_arg* opt_a, MSLayerWeight* ms_weights, std::vector w_tensors) override; -+}; -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/MSEncoderLayer.cc b/src/fastertransformer/layers/ms_layers/MSEncoderLayer.cc -new file mode 100644 -index 0000000..9b18049 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/MSEncoderLayer.cc -@@ -0,0 +1,250 @@ -+/* -+ * Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. -+ * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. 
-+ */ -+ -+#include "src/fastertransformer/layers/ms_layers/MSEncoderLayer.h" -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+#include "src/fastertransformer/layers/ms_layers/BaseLayer.h" -+ -+namespace fastertransformer { -+ -+template -+MSELayer::MSELayer(size_t max_batch_size, -+ size_t max_src_seq_len, -+ size_t max_tgt_seq_len, -+ size_t head_num, -+ size_t size_per_head, -+ size_t ffn_hidden_size, -+ float eps1, -+ float eps2, -+ bool post_layernorm, -+ bool position_bias, -+ bool is_ffn_fp16, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ IAllocator* allocator, -+ bool is_free_buffer_after_forward, -+ bool is_qk_buf_float, -+ bool sparse): -+ -+ MSBaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, sparse), buf_(nullptr) -+{ -+ // update commonparam -+ encoder_param_.common_param.stream = stream_; -+ encoder_param_.common_param.cublas_handle = cublas_handle; -+ encoder_param_.common_param.batch_size = max_batch_size; -+ encoder_param_.common_param.src_seq_len = max_src_seq_len; -+ encoder_param_.common_param.tgt_seq_len = max_tgt_seq_len; -+ encoder_param_.common_param.head_num = head_num; -+ encoder_param_.common_param.head_size = size_per_head; -+ encoder_param_.common_param.hidden_size = head_num * size_per_head; -+ encoder_param_.common_param.algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; -+ encoder_param_.common_param.in_idx = 0; -+ // connect commonparam to attention and ffn -+ encoder_param_.attn.common_param = &encoder_param_.common_param; -+ encoder_param_.ffn_param.common_param = &encoder_param_.common_param; -+ encoder_param_.common_param.h_token_num = max_src_seq_len * max_batch_size; -+ // update encoder_param_ -+ encoder_param_.encoder.layernorm_post = post_layernorm; -+ encoder_param_.encoder.eps1 = eps1; -+ encoder_param_.encoder.eps2 = eps2; -+ encoder_param_.ffn_param.ffn_param.ffn_hidden_size = ffn_hidden_size; -+ encoder_param_.ffn_param.ffn_param.ffn_fp16 = is_ffn_fp16; -+ encoder_param_.attn.attn.projection_bias = !position_bias; -+ encoder_param_.attn.attn.is_cross = false; -+ encoder_param_.attn.attn.position_bias = position_bias; -+ encoder_param_.attn.attn.qkv_bias = !position_bias; -+ encoder_param_.encoder.has_beta = !position_bias; -+ encoder_param_.ffn_param.ffn_param.ffn_bias = !position_bias; -+ encoder_param_.ffn_param.ffn_param.act_type = -+ !position_bias ? ActType::ActType_Gelu : ActType::ActType_Relu; // true; -+ encoder_param_.attn.attn.scale = position_bias ? 
1.0f : 1.0f / sqrtf(size_per_head * 1.0f); -+ encoder_param_.attn.attn.mask = true; -+ this->ms_weights = new EncoderLayerWeight(); -+ encoder_layer_ = std::make_shared>( -+ encoder_param_.common_param.batch_size, -+ encoder_param_.common_param.src_seq_len, -+ encoder_param_.common_param.head_num, -+ encoder_param_.common_param.head_size, -+ encoder_param_.common_param.hidden_size -+ ); -+ -+ encoder_layer_->SetT5(encoder_param_.attn.attn.position_bias); -+ encoder_layer_->SetScaleAttn(encoder_param_.attn.attn.scale); -+ encoder_layer_->SetUsePast(false); -+ encoder_layer_->SetIsLayerNorm(false, 1e-6f); -+ encoder_layer_->SetHTokenNum(encoder_param_.common_param.h_token_num, encoder_param_.common_param.h_token_num); -+ encoder_layer_->SetFfnParam(encoder_param_.ffn_param.ffn_param.ffn_fp16, encoder_param_.ffn_param.ffn_param.ffn_hidden_size, (FfnBase::ActType)encoder_param_.ffn_param.ffn_param.act_type, encoder_param_.ffn_param.ffn_param.ffn_bias); -+ encoder_layer_->SetQueryLayer(false); -+ encoder_layer_->SetAlgo(CUBLAS_GEMM_DEFAULT_TENSOR_OP); -+} -+ -+template -+void MSELayer::allocateBuffer() -+{ -+ if (buf_ == nullptr) { -+ size_t buff_size_allocator = encoder_layer_->GetWorkspaceSize(); -+ buf_ = reinterpret_cast(allocator_->reMalloc(buf_, sizeof(T) * buff_size_allocator, true)); -+ encoder_layer_->SetWSOffset(0); -+ } -+} -+ -+template -+void MSELayer::freeBuffer() -+{ -+ if (buf_ != nullptr) { -+ allocator_->free(buf_); -+ buf_ = nullptr; -+ } -+} -+ -+template -+MSELayer::~MSELayer() -+{ -+ cublas_wrapper_ = nullptr; -+ freeBuffer(); -+} -+ -+template -+int MSELayer::forward(std::vector* output_tensors, -+ const std::vector* input_tensors, -+ const MSLayerWeight* weights) -+{ -+ const EncoderLayerWeight* encoder_weights = dynamic_cast*>(this->ms_weights); -+ if (encoder_weights == NULL) { -+ FT_LOG_ERROR("cast EncoderLayerWeight not sucsses\n"); -+ return -1; -+ } -+ allocateBuffer(); // only once -+ std::vector outputs= {(void*)output_tensors->at(0).data}; -+ if (!encoder_param_.encoder.layernorm_post) { -+ if (encoder_param_.attn.attn.position_bias) { -+ std::vector inputs = {(void*)input_tensors->at(0).data, -+ (void*)encoder_weights->layernorm1.gamma, -+ (void*)encoder_weights->attention.query_weight.kernel, -+ (void*)input_tensors->at(1).data, -+ (void*)input_tensors->at(2).data, -+ (void*)encoder_weights->attention.attention_output_weight.kernel, -+ (void*)encoder_weights->layernorm2.gamma, -+ (void*)encoder_weights->encoder_output_mapping.kernel, -+ (void*)encoder_weights->encoder_output_projection.kernel -+ -+ }; -+ encoder_layer_->forward(inputs,outputs,buf_, encoder_param_.common_param.cublas_handle, encoder_param_.common_param.stream); -+ } -+ else { -+ std::vector inputs = {(void*)input_tensors->at(0).data, -+ (void*)encoder_weights->layernorm1.gamma, -+ (void*)encoder_weights->layernorm1.beta, -+ (void*)encoder_weights->attention.query_weight.kernel, -+ (void*)encoder_weights->attention.query_weight.bias, -+ (void*)input_tensors->at(1).data, -+ (void*)encoder_weights->attention.attention_output_weight.kernel, -+ (void*)encoder_weights->attention.attention_output_weight.bias, -+ (void*)encoder_weights->layernorm2.gamma, -+ (void*)encoder_weights->layernorm2.beta, -+ (void*)encoder_weights->encoder_output_mapping.kernel, -+ (void*)encoder_weights->encoder_output_mapping.bias, -+ (void*)encoder_weights->encoder_output_projection.kernel, -+ (void*)encoder_weights->encoder_output_projection.bias}; -+ encoder_layer_->forward(inputs, outputs,buf_, 
encoder_param_.common_param.cublas_handle, encoder_param_.common_param.stream); -+ } -+ } -+ else { -+ if (encoder_param_.attn.attn.position_bias) { -+ std::vector inputs = {(void*)input_tensors->at(0).data, -+ (void*)encoder_weights->layernorm1.gamma, -+ (void*)encoder_weights->attention.query_weight.kernel, -+ (void*)input_tensors->at(1).data, -+ (void*)input_tensors->at(2).data, -+ (void*)encoder_weights->attention.attention_output_weight.kernel, -+ (void*)encoder_weights->layernorm2.gamma, -+ (void*)encoder_weights->encoder_output_mapping.kernel, -+ (void*)encoder_weights->encoder_output_projection.kernel -+ -+ }; -+ encoder_layer_->forward(inputs, outputs,buf_, encoder_param_.common_param.cublas_handle, encoder_param_.common_param.stream); -+ } -+ else { -+ std::vector inputs = {(void*)input_tensors->at(0).data, -+ (void*)encoder_weights->attention.query_weight.kernel, -+ (void*)encoder_weights->attention.query_weight.bias, -+ (void*)input_tensors->at(1).data, -+ (void*)encoder_weights->attention.attention_output_weight.kernel, -+ (void*)encoder_weights->attention.attention_output_weight.bias, -+ (void*)encoder_weights->layernorm1.gamma, -+ (void*)encoder_weights->layernorm1.beta, -+ (void*)encoder_weights->encoder_output_mapping.kernel, -+ (void*)encoder_weights->encoder_output_mapping.bias, -+ (void*)encoder_weights->encoder_output_projection.kernel, -+ (void*)encoder_weights->encoder_output_projection.bias, -+ (void*)encoder_weights->layernorm2.gamma, -+ (void*)encoder_weights->layernorm2.beta}; -+ encoder_layer_->forward(inputs, outputs, buf_, encoder_param_.common_param.cublas_handle, encoder_param_.common_param.stream); -+ } -+ } -+ -+ return 0; -+} -+ -+template -+int MSELayer::InitWeight(opt_arg* opt_a, MSLayerWeight* ms_weights, std::vector w_tensors) -+{ -+ EncoderLayerWeight* encoder_weights = dynamic_cast*>(this->ms_weights); -+ if (encoder_weights == NULL) { -+ FT_LOG_ERROR("cast EncoderLayerWeight not sucsses\n"); -+ return -1; -+ } -+ int modelId = ModelNum(opt_a->model_name); -+ if (modelId == TEL) { -+ encoder_weights->attention.query_weight.kernel = reinterpret_cast(w_tensors[2].data); -+ encoder_weights->attention.query_weight.bias = reinterpret_cast(w_tensors[3].data); -+ encoder_weights->attention.attention_output_weight.kernel = reinterpret_cast(w_tensors[4].data); -+ encoder_weights->attention.attention_output_weight.bias = reinterpret_cast(w_tensors[5].data); -+ encoder_weights->layernorm1.gamma = reinterpret_cast(w_tensors[0].data); -+ encoder_weights->layernorm1.beta = reinterpret_cast(w_tensors[1].data); -+ encoder_weights->layernorm2.gamma = reinterpret_cast(w_tensors[6].data); -+ encoder_weights->layernorm2.beta = reinterpret_cast(w_tensors[7].data); -+ encoder_weights->encoder_output_mapping.kernel = reinterpret_cast(w_tensors[8].data); -+ encoder_weights->encoder_output_projection.kernel = reinterpret_cast(w_tensors[10].data); -+ encoder_weights->encoder_output_mapping.bias = reinterpret_cast(w_tensors[9].data); -+ encoder_weights->encoder_output_projection.bias = reinterpret_cast(w_tensors[11].data); -+ } -+ else if (modelId == TEL_T5) { -+ encoder_weights->attention.query_weight.kernel = reinterpret_cast(w_tensors[1].data); -+ encoder_weights->attention.attention_output_weight.kernel = reinterpret_cast(w_tensors[2].data); -+ encoder_weights->layernorm1.gamma = reinterpret_cast(w_tensors[0].data); -+ encoder_weights->layernorm2.gamma = reinterpret_cast(w_tensors[3].data); -+ encoder_weights->encoder_output_mapping.kernel = 
reinterpret_cast(w_tensors[4].data); -+ encoder_weights->encoder_output_projection.kernel = reinterpret_cast(w_tensors[5].data); -+ } -+ else { -+ FT_LOG_ERROR("illegal model !\n"); -+ return -1; -+ } -+ return 0; -+} -+template class MSELayer; -+template class MSELayer; -+template class MSELayer; -+template class MSELayer; -+template class MSELayer; -+template class MSELayer; -+template class MSELayer; -+template class MSELayer; -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/MSEncoderLayer.h b/src/fastertransformer/layers/ms_layers/MSEncoderLayer.h -new file mode 100644 -index 0000000..358a3ca ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/MSEncoderLayer.h -@@ -0,0 +1,72 @@ -+/* -+ * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -+ * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+ -+#pragma once -+ -+#include "src/fastertransformer/layers/ms_layers/MSBaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/encoder.h" -+#include "src/fastertransformer/layers/ms_layers/param.h" -+namespace fastertransformer { -+ -+// TODO(haim): Add template according to "mix" compute type (fp32, fp16) -+template -+class MSELayer: public MSBaseLayer { -+private: -+ void allocateBuffer() override; -+ void freeBuffer() override; -+ void* buf_; -+ using MSBaseLayer::is_free_buffer_after_forward_; -+ using MSBaseLayer::is_allocate_buffer_; -+ using MSBaseLayer::cublas_wrapper_; -+ using MSBaseLayer::allocator_; -+ std::shared_ptr> encoder_layer_; -+protected: -+ using MSBaseLayer::stream_; -+ using MSBaseLayer::sparse_; -+ -+public: -+ encoderParamRun encoder_param_; -+ MSELayer(size_t max_batch_size, -+ size_t max_src_seq_len, -+ size_t max_tgt_seq_len, -+ size_t head_num, -+ size_t size_per_head, -+ size_t ffn_hidden_size, -+ float eps1, -+ float eps2, -+ bool post_layernorm, -+ bool position_bias, -+ bool is_ffn_fp16, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ IAllocator* allocator, -+ bool is_free_buffer_after_forward, -+ bool is_qk_buf_float, -+ bool sparse); -+ -+ MSELayer(MSELayer const& encoder_layer); -+ -+ virtual ~MSELayer(); -+ -+ int forward(std::vector* output_tensors, -+ const std::vector* input_tensors, -+ const MSLayerWeight* weights) override; -+int InitWeight(opt_arg* opt_a, MSLayerWeight* ms_weights, std::vector w_tensors) override; -+}; -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/MSLayerWeight.h b/src/fastertransformer/layers/ms_layers/MSLayerWeight.h -new file mode 100644 -index 0000000..d4db37d ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/MSLayerWeight.h -@@ -0,0 +1,55 @@ -+/* -+ * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. 
-+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+ -+#pragma once -+ -+#include "src/fastertransformer/kernels/layernorm_kernels.h" -+#include "src/fastertransformer/layers/DenseWeight.h" -+namespace fastertransformer { -+ -+template -+struct MSLayerWeight { -+ virtual ~MSLayerWeight() {} -+}; -+ -+template -+struct AttentionLayerWeight: MSLayerWeight { -+ DenseWeight query_weight; -+ DenseWeight key_weight; -+ DenseWeight value_weight; -+ DenseWeight attention_output_weight; -+}; -+ -+template -+struct DecoderLayerWeight: MSLayerWeight { -+ AttentionLayerWeight attention; -+ AttentionLayerWeight cross_attention; -+ DenseWeight decoder_output_mapping; -+ DenseWeight decoder_output_projection; -+ LayerNormWeight layernorm1; -+ LayerNormWeight layernorm2; -+ LayerNormWeight layernorm3; -+}; -+ -+template -+struct EncoderLayerWeight: MSLayerWeight { -+ AttentionLayerWeight attention; -+ DenseWeight encoder_output_mapping; -+ DenseWeight encoder_output_projection; -+ LayerNormWeight layernorm1; -+ LayerNormWeight layernorm2; -+}; -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/MoeFfnLayer.cu b/src/fastertransformer/layers/ms_layers/MoeFfnLayer.cu -new file mode 100644 -index 0000000..ccf9189 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/MoeFfnLayer.cu -@@ -0,0 +1,629 @@ -+ -+#include "MoeFfnLayer.h" -+#include "cublas_api.h" -+#include "cuda_kernels.h" -+#include "src/fastertransformer/kernels/activation_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+#include -+#include -+ -+ -+namespace fastertransformer { -+ -+cublasStatus_t cublasGemmExRowMajor(cublasHandle_t handle, -+ cublasOperation_t transa, -+ cublasOperation_t transb, -+ int m, -+ int n, -+ int k, -+ const void* alpha, -+ const void* A, -+ cudaDataType_t Atype, -+ const void* B, -+ cudaDataType_t Btype, -+ const void* beta, -+ void* C, -+ cudaDataType_t Ctype, -+ cublasComputeType_t computeType, -+ cublasGemmAlgo_t algo) -+{ -+ const void* B_ = A; -+ const void* A_ = B; -+ int lda = (transa == CUBLAS_OP_T) ? m : k; -+ int ldb = (transb == CUBLAS_OP_T) ? k : n; -+ int lda_ = ldb; -+ int ldb_ = lda; -+ int ldc = n; -+ int m_ = n; -+ int n_ = m; -+ cudaDataType_t Atype_ = Btype; -+ cudaDataType_t Btype_ = Atype; -+ cublasOperation_t transa_ = transb; -+ cublasOperation_t transb_ = transa; -+ -+ return cublasGemmEx(handle, -+ transa_, -+ transb_, -+ m_, -+ n_, -+ k, -+ alpha, -+ A_, -+ Atype_, -+ lda_, -+ B_, -+ Btype_, -+ ldb_, -+ beta, -+ C, -+ Ctype, -+ ldc, -+ computeType, -+ algo); -+} -+ -+cublasStatus_t cublasGemmStridedBatchedExRowMajor(cublasHandle_t handle, -+ cublasOperation_t transa, -+ cublasOperation_t transb, -+ int m, -+ int n, -+ int k, -+ const void* alpha, -+ const void* A, -+ cudaDataType_t Atype, -+ const void* B, -+ cudaDataType_t Btype, -+ const void* beta, -+ void* C, -+ cudaDataType_t Ctype, -+ int batchCount, -+ cublasComputeType_t computeType, -+ cublasGemmAlgo_t algo) -+{ -+ const void* B_ = A; -+ const void* A_ = B; -+ int lda = (transa == CUBLAS_OP_T) ? m : k; -+ int ldb = (transb == CUBLAS_OP_T) ? 
k : n; -+ int lda_ = ldb; -+ int ldb_ = lda; -+ int ldc = n; -+ int m_ = n; -+ int n_ = m; -+ cudaDataType_t Atype_ = Btype; -+ cudaDataType_t Btype_ = Atype; -+ cublasOperation_t transa_ = transb; -+ cublasOperation_t transb_ = transa; -+ long long int stride_a = m * k; -+ long long int stride_b = n * k; -+ long long int stride_c = n * m; -+ long long int stride_a_ = stride_b; -+ long long int stride_b_ = stride_a; -+ -+ return cublasGemmStridedBatchedEx(handle, -+ transa_, -+ transb_, -+ m_, -+ n_, -+ k, -+ alpha, -+ A_, -+ Atype_, -+ lda_, -+ stride_a_, -+ B_, -+ Btype_, -+ ldb_, -+ stride_b_, -+ beta, -+ C, -+ Ctype, -+ ldc, -+ stride_c, -+ batchCount, -+ computeType, -+ algo); -+} -+ -+cublasStatus_t cublasGemmArrBatchedExRowMajor(cublasHandle_t handle, -+ cublasOperation_t transa, -+ cublasOperation_t transb, -+ int m, -+ int n, -+ int k, -+ const void* alpha, -+ const void* A[], -+ cudaDataType_t Atype, -+ const void* B[], -+ cudaDataType_t Btype, -+ const void* beta, -+ void* C[], -+ cudaDataType_t Ctype, -+ int batchCount, -+ cublasComputeType_t computeType, -+ cublasGemmAlgo_t algo) -+{ -+ const void** B_ = A; -+ const void** A_ = B; -+ int lda = (transa == CUBLAS_OP_T) ? m : k; -+ int ldb = (transb == CUBLAS_OP_T) ? k : n; -+ int lda_ = ldb; -+ int ldb_ = lda; -+ int ldc = n; -+ int m_ = n; -+ int n_ = m; -+ cudaDataType_t Atype_ = Btype; -+ cudaDataType_t Btype_ = Atype; -+ cublasOperation_t transa_ = transb; -+ cublasOperation_t transb_ = transa; -+ -+ return cublasGemmBatchedEx(handle, -+ transa_, -+ transb_, -+ m_, -+ n_, -+ k, -+ alpha, -+ A_, -+ Atype_, -+ lda_, -+ B_, -+ Btype_, -+ ldb_, -+ beta, -+ C, -+ Ctype, -+ ldc, -+ batchCount, -+ computeType, -+ algo); -+} -+ -+PanguMoeFfnLayer::PanguMoeFfnLayer(int hidden_size, -+ int expert_num, -+ int ffn_hidden_size, -+ int rank_num, -+ int seq_len, -+ float expert_capability, -+ int batch_size): -+ expert_num_(expert_num), -+ ffn_hidden_size_(ffn_hidden_size), -+ expert_capability_(expert_capability), -+ BaseLayerMS(batch_size, seq_len, seq_len, 0, 0, hidden_size, rank_num) -+{ -+ max_capacity_ = static_cast(std::ceil(expert_capability_ * src_seq_len_ / expert_num_) + 0.01f); -+} -+ -+size_t PanguMoeFfnLayer::GetWorkspaceSize() -+{ -+ size_t size = ALIGN(sizeof(half) * expert_num_ * gather_stride(), ALIGN_SIZE) + // gather tokens (sort by expert) -+ ALIGN(sizeof(int) * expert_num_ * (router_stride() + 1), ALIGN_SIZE) + // router -+ 4 * ALIGN(expert_num_ * sizeof(half*), ALIGN_SIZE); // Group/Batch matmul arrays -+ -+ size_t s1 = ALIGN(sizeof(int) * expert_num_ * batch_size_, ALIGN_SIZE); // capcity per batch and expert -+ size_t s2 = ALIGN(sizeof(half) * ffn_hidden_size_ * max_capacity_, ALIGN_SIZE); // size of mapping mm -+ size_t s3 = ALIGN(sizeof(half) * hidden_size_, ALIGN_SIZE); // size of projection bias (AllGather) -+ size_t s4 = ALIGN(sizeof(half) * expert_num_ * ffn_hidden_size_ * experimental_threshold() * batch_size_, ALIGN_SIZE); // every expert (in incremental can have up to batch token allocated) -+ size_t mx = max(s1, s2); -+ mx = max(mx, s3); -+ mx = max(mx, s4); -+ return size + mx; -+} -+ -+__global__ void HashRouter(const int* expert_id, -+ int capacity, -+ const int* padding_offset, -+ const int* seq_length, -+ int* ws, -+ int* router, -+ int token_num, -+ int seq_len, -+ int batch, -+ int router_stride) -+{ -+ int e_id = blockIdx.x; -+ int* r = router + gridDim.x + router_stride * e_id; -+ int w_id = 0; -+ -+ // zero actual capacity per batch -+ ws = ws + e_id * batch; -+ for (int i = 0; i < batch; i++) { 
-+ ws[i] = (capacity < seq_length[i]) ? capacity : seq_length[i]; -+ } -+ // route tokens to expert (priority for earlier tokens) -+ for (int i = 0; i < token_num; i++) { -+ int element_offset = padding_offset[i] + i; -+ int b_id = element_offset / seq_len; -+ int cur_expert_id = (expert_id[i] > 0) ? expert_id[i] : 0; // do Relu -+ if ((cur_expert_id == e_id) && (ws[b_id] >= 0)) { -+ r[w_id++] = i; -+ ws[b_id]--; -+ } -+ } -+ w_id = (w_id == 1) ? (r[0] | (1 << 31)) : w_id; -+ // Total tokens per expert -+ router[e_id] = w_id; -+} -+ -+__global__ void -+HashRouterGather(int* router, const half* in, half* gather, int router_stride, int gather_stride, int hidden_size) -+{ -+ int e_id = blockIdx.x; -+ int* r = router + gridDim.x + router_stride * e_id; -+ half* g = gather + gather_stride * e_id; -+ int tokens_per_expert = router[e_id]; -+ if (tokens_per_expert & (1 << 31)) return; -+ for (int index = threadIdx.x; index < hidden_size * tokens_per_expert; index += blockDim.x) { -+ int token_idx = index / hidden_size; -+ int hid_idx = index % hidden_size; -+ int token_id = r[token_idx]; -+ int src_offset = token_id * hidden_size + hid_idx; -+ int dst_offset = token_idx * hidden_size + hid_idx; -+ -+ g[dst_offset] = in[src_offset]; -+ } -+} -+ -+int GatherByExpert(const int* expert_id, -+ int expert_num, -+ int max_capacity, -+ const int* padding_offset, -+ const int* seq_length, -+ void* ws, -+ const half* in, -+ int* router, -+ half* gather, -+ int token_num, -+ int seq_len, -+ int batch, -+ int router_stride, -+ int gather_stride, -+ int hidden_size, -+ int* expert_per_token_h, -+ cudaStream_t stream) -+{ -+ // Step I - Build router -+ // ARR[EXPERT#] # of tokens per expert -+ // 0 token # [list of tokens] -+ // 1 token # [list of tokens] -+ // . -+ // . -+ // . -+ // 15 token # [list of tokens] -+ dim3 grid(expert_num); -+ dim3 block(1); -+ HashRouter<<>>( -+ expert_id, max_capacity, padding_offset, seq_length, (int*)ws, router, token_num, seq_len, batch, router_stride); -+ // Step II - Gather tokens -+ // 0 [List of token data] -+ // 1 [List of token data] -+ // . -+ // . -+ // . 
-+ // 15 [list of tokens] -+ -+ dim3 grid1(expert_num); -+ dim3 block1(1024); -+ HashRouterGather<<>>(router, in, gather, router_stride, gather_stride, hidden_size); -+ // step III - Copy token# per expert to host device -+ cudaMemcpyAsync(expert_per_token_h, router, sizeof(int) * expert_num, cudaMemcpyDeviceToHost, stream); -+ return 0; -+} -+ -+__global__ void -+HashRouterScatter(int* router, const half* gather, half* scater, int router_stride, int gather_stride, int hidden_size) -+{ -+ int e_id = blockIdx.x; -+ int* r = router + gridDim.x + router_stride * e_id; -+ const half* g = gather + gather_stride * e_id; -+ int tokens_per_expert = router[e_id]; -+ if (tokens_per_expert & (1 << 31)) return; -+ for (int index = threadIdx.x; index < hidden_size * tokens_per_expert; index += blockDim.x) { -+ int token_idx = index / hidden_size; -+ int hid_idx = index % hidden_size; -+ int token_id = r[token_idx]; -+ int src_offset = token_idx * hidden_size + hid_idx; -+ int dst_offset = token_id * hidden_size + hid_idx; -+ scater[dst_offset] = g[src_offset]; -+ } -+} -+ -+int ScatterByExpert(int* router, -+ half* gather, -+ half* scatter, -+ int expert_num, -+ int router_stride, -+ int gather_stride, -+ int hidden_size, -+ cudaStream_t stream) -+{ -+ dim3 grid1(expert_num); -+ dim3 block1(1024); -+ HashRouterScatter<<>>(router, gather, scatter, router_stride, gather_stride, hidden_size); -+ return 0; -+} -+ -+void PanguMoeFfnLayer::forward(std::vector& inputs, -+ const std::vector& outputs, -+ void* ws, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream) -+{ -+ in_idx_ = 0; -+ int moe_id = expert_offset_; -+ half* layernorm = reinterpret_cast(inputs[in_idx_++]); -+ const int* expert_ids_ptr = reinterpret_cast(inputs[in_idx_++]); -+ half* weight1_ptr = reinterpret_cast(inputs[in_idx_++]); -+ half* bias1_ptr = reinterpret_cast(inputs[in_idx_++]); -+ half* weight2_ptr = reinterpret_cast(inputs[in_idx_++]); -+ half* bias2_ptr = reinterpret_cast(inputs[in_idx_++]); -+ const int* padding_offset_d = padding_offset_d_; -+ int* seq_length_d = seq_len_d_; -+ half* output = reinterpret_cast(outputs[0]); -+ // Get WS pointer -+ ws = GetBuf(ws, ws_offset_); -+ bool incremental = batch_size_ >= h_token_num_; -+ int token_number = h_token_num_; -+ expert_ids_ptr += moe_id * token_number; -+ if (token_number == 1) { // special handle when token# is 1 -+ forward_single_token(layernorm, expert_ids_ptr, weight1_ptr, bias1_ptr, weight2_ptr, bias2_ptr, output, ws, cublas_handle, stream); -+ return; -+ } -+ // router & gather are used till end of process - malloc first -+ int* router = reinterpret_cast(ws); -+ half* gather = reinterpret_cast(router + ALIGN(expert_num_ * (router_stride() + 1), ALIGN_SIZE)); -+ void* workspace = reinterpret_cast(gather + ALIGN(expert_num_ * gather_stride(), ALIGN_SIZE)); -+ -+ // Step I - zero output (in incremental mode all tokens are set) -+ if (!incremental) cudaMemsetAsync(output, 0, sizeof(half) * token_number * hidden_size_, stream); -+ -+ // Step II - gather tokens according to expert id -+ int expert_per_token_h[expert_num_]; -+ GatherByExpert(expert_ids_ptr, -+ expert_num_, -+ max_capacity_, -+ padding_offset_d, -+ seq_length_d, -+ workspace, -+ layernorm, -+ router, -+ gather, -+ token_number, -+ src_seq_len_, -+ batch_size_, -+ router_stride(), -+ gather_stride(), -+ hidden_size_, -+ expert_per_token_h, -+ stream); -+ -+ cudaStreamSynchronize(stream); // make sure expert_per_token_h is update to host -+ -+ // Step III - Run FFN per expert -+ for (int ei = 0; ei < 
expert_num_; ei++) { -+ int expert_token_num = expert_per_token_h[ei]; -+ if (!(expert_token_num & (1 << 31)) && expert_token_num > experimental_threshold()) { -+ int g_offset = ei * gather_stride(); -+ half* g = gather + g_offset; -+ forward_expert(g, weight1_ptr, bias1_ptr, weight2_ptr, bias2_ptr, g, workspace, ei, expert_token_num, cublas_handle, stream); -+ expert_per_token_h[ei] = 0; -+ } -+ } -+ forward_expert_experimental(layernorm, expert_per_token_h, weight1_ptr, bias1_ptr, weight2_ptr, bias2_ptr, output, workspace, cublas_handle, stream); -+ -+ // step IV - Scatter tokens into output according to router -+ ScatterByExpert(router, gather, output, expert_num_, router_stride(), gather_stride(), hidden_size_, stream); -+} -+ -+void PanguMoeFfnLayer::forward_single_token(half *in, const int *expert_in_ids, half *weight1_ptr, half *bias1_ptr, half *weight2_ptr, half *bias2_ptr, half* output, void *workspace, cublasHandle_t cublas_handle, cudaStream_t stream) { -+ int expert_token_num = 1; -+ int expert_id; -+ cudaMemcpyAsync(&expert_id, expert_in_ids, sizeof(int), cudaMemcpyDeviceToHost, stream); -+ forward_expert(in, weight1_ptr, bias1_ptr, weight2_ptr, bias2_ptr, output, workspace, expert_id, expert_token_num, cublas_handle, stream); -+} -+ -+void PanguMoeFfnLayer::forward_expert(half* in, -+ half* weight1_ptr, -+ half* bias1_ptr, -+ half* weight2_ptr, -+ half* bias2_ptr, -+ half* output, -+ void *workspace, -+ int expert_id, -+ int expert_token_num, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream) -+{ -+ float alpha = (float)(1.0f); -+ float beta = (float)(0.0f); -+ cublasGemmAlgo_t algo = (cublasGemmAlgo_t)CUBLAS_GEMM_DEFAULT_TENSOR_OP; -+ cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; -+ -+ int w_offset = expert_id * hidden_size_ * ffn_hidden_size_; -+ int b1_offset = expert_id * ffn_hidden_size_; -+ int b2_size = hidden_size_;// / rank_num_; -+ int b2_offset = expert_id * b2_size; -+ half* w1 = weight1_ptr + w_offset; -+ half* w2 = weight2_ptr + w_offset; -+ half* bias1 = bias1_ptr + b1_offset; -+ half* bias2 = bias2_ptr + b2_offset; -+ half* mm1 = reinterpret_cast(workspace); -+ cublasGemmExRowMajor(cublas_handle, -+ CUBLAS_OP_N, -+ CUBLAS_OP_N, -+ expert_token_num, -+ ffn_hidden_size_, -+ hidden_size_, -+ &alpha, -+ in, -+ CUDA_R_16F, -+ w1, -+ CUDA_R_16F, -+ &beta, -+ mm1, -+ CUDA_R_16F, -+ compute_type, -+ algo); -+ if (act_type_ == FfnBase::ActType::Gelu) { -+ invokeAddBiasGelu(mm1, bias1, expert_token_num, ffn_hidden_size_, stream); -+ } -+ else if (act_type_ == FfnBase::ActType::FastGelu) { -+ invokeAddBiasFastGelu(mm1, bias1, expert_token_num, ffn_hidden_size_, stream); -+ } -+ cublasGemmExRowMajor(cublas_handle, -+ CUBLAS_OP_N, -+ CUBLAS_OP_N, -+ expert_token_num, -+ hidden_size_, -+ ffn_hidden_size_, -+ &alpha, -+ mm1, -+ CUDA_R_16F, -+ w2, -+ CUDA_R_16F, -+ &beta, -+ output, -+ CUDA_R_16F, -+ compute_type, -+ algo); -+ if (all_reduce_sum_func_ != nullptr) { -+ (all_reduce_sum_func_)(output, output, hidden_size_ * expert_token_num, nvinfer1::DataType::kHALF, stream); -+ } -+ -+ invokeAddBias(output, bias2, expert_token_num, hidden_size_, stream); -+} -+ -+ -+void PanguMoeFfnLayer::forward_expert_experimental(half* input, int *expert_per_token_h, -+ half* weight1_ptr, -+ half* bias1_ptr, -+ half* weight2_ptr, -+ half* bias2_ptr, -+ half* output, -+ void *workspace, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream) -+{ -+ float alpha = (float)(1.0f); -+ float beta = (float)(0.0f); -+ cublasGemmAlgo_t algo = 
(cublasGemmAlgo_t)CUBLAS_GEMM_DEFAULT_TENSOR_OP; -+ cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; -+ -+ int cnt = 0; -+ half* a1_arr_h[expert_num_]; -+ half* b1_arr_h[expert_num_]; -+ half* b2_arr_h[expert_num_]; -+ half* c1_arr_h[expert_num_]; -+ -+ uint8_t* align = reinterpret_cast(workspace); -+ a1_arr_ = reinterpret_cast(align); -+ b1_arr_ = a1_arr_ + expert_num_; -+ b2_arr_ = b1_arr_ + expert_num_; -+ c1_arr_ = b2_arr_ + expert_num_; -+ -+ half* mm1 = reinterpret_cast(c1_arr_ + ALIGN(expert_num_, ALIGN_SIZE)); -+ size_t mapping_stride = ffn_hidden_size_ * 1 * batch_size_; -+ // prepare GEMM arrays -+ int mx_token = 0; -+ for (int ei = 0; ei < expert_num_; ei++) { -+ bool out_flag = expert_per_token_h[ei] & (1 << 31); -+ int token_num = out_flag ? 1 : expert_per_token_h[ei]; -+ mx_token = max(mx_token, token_num); -+ if (token_num) { -+ a1_arr_h[cnt] = out_flag ? input + (expert_per_token_h[ei] & ~ (1 << 31)) * hidden_size_ : input + ei * gather_stride(); -+ b1_arr_h[cnt] = weight1_ptr + ei * hidden_size_ * ffn_hidden_size_; -+ b2_arr_h[cnt] = weight2_ptr + ei * hidden_size_ * ffn_hidden_size_; -+ c1_arr_h[cnt] = mm1 + ei * mapping_stride; -+ cnt++; -+ } -+ } -+ if (cnt == 0) return; -+ // copy array to device -+ cudaMemcpyAsync(a1_arr_, a1_arr_h, cnt * sizeof(half*), cudaMemcpyHostToDevice, stream); -+ cudaMemcpyAsync(b1_arr_, b1_arr_h, cnt * sizeof(half*), cudaMemcpyHostToDevice, stream); -+ cudaMemcpyAsync(b2_arr_, b2_arr_h, cnt * sizeof(half*), cudaMemcpyHostToDevice, stream); -+ cudaMemcpyAsync(c1_arr_, c1_arr_h, cnt * sizeof(half*), cudaMemcpyHostToDevice, stream); -+ -+ cublasGemmArrBatchedExRowMajor(cublas_handle, -+ CUBLAS_OP_N, -+ CUBLAS_OP_N, -+ mx_token, -+ ffn_hidden_size_, -+ hidden_size_, -+ &alpha, -+ (const void**)a1_arr_, -+ CUDA_R_16F, -+ (const void**)b1_arr_, -+ CUDA_R_16F, -+ &beta, -+ (void**)c1_arr_, -+ CUDA_R_16F, -+ cnt, -+ compute_type, -+ algo); -+ int arr_idx = 0; -+ for (int ei = 0; ei < expert_num_; ei++) { -+ bool out_flag = expert_per_token_h[ei] & (1 << 31); -+ int token_num = out_flag ? 1 : expert_per_token_h[ei]; -+ if (token_num) { -+ int b1_offset = ei * ffn_hidden_size_; -+ half* bias1 = bias1_ptr + b1_offset; -+ if (act_type_ == FfnBase::ActType::Gelu) { -+ invokeAddBiasGelu(c1_arr_h[arr_idx++], bias1, token_num, ffn_hidden_size_, stream); -+ } -+ else if (act_type_ == FfnBase::ActType::FastGelu) { -+ invokeAddBiasFastGelu(c1_arr_h[arr_idx++], bias1, token_num, ffn_hidden_size_, stream); -+ } -+ } -+ } -+ -+ cnt = 0; -+ if (input != output) { -+ for (int ei = 0; ei < expert_num_; ei++) { -+ bool out_flag = expert_per_token_h[ei] & (1 << 31); -+ int token_num = out_flag ? 1 : expert_per_token_h[ei]; -+ mx_token = max(mx_token, token_num); -+ if (token_num) { -+ a1_arr_h[cnt++] = out_flag ? output + (expert_per_token_h[ei] & ~ (1 << 31)) * hidden_size_ : output + ei * gather_stride(); -+ } -+ cudaMemcpyAsync(a1_arr_, a1_arr_h, cnt * sizeof(half*), cudaMemcpyHostToDevice, stream); -+ } -+ } -+ -+ cublasGemmArrBatchedExRowMajor(cublas_handle, -+ CUBLAS_OP_N, -+ CUBLAS_OP_N, -+ mx_token, -+ hidden_size_, -+ ffn_hidden_size_, -+ &alpha, -+ (const void**)c1_arr_, -+ CUDA_R_16F, -+ (const void**)b2_arr_, -+ CUDA_R_16F, -+ &beta, -+ (void**)a1_arr_, -+ CUDA_R_16F, -+ cnt, -+ compute_type, -+ algo); -+ arr_idx = 0; -+ for (int ei = 0; ei < expert_num_; ei++) { -+ bool out_flag = expert_per_token_h[ei] & (1 << 31); -+ int token_num = out_flag ? 
1 : expert_per_token_h[ei]; -+ if (token_num) { -+ if (all_reduce_sum_func_ != nullptr) { -+ (all_reduce_sum_func_)(a1_arr_h[arr_idx], -+ a1_arr_h[arr_idx], -+ hidden_size_ * token_num, -+ nvinfer1::DataType::kHALF, -+ stream); -+ } -+ -+ int b2_offset = ei * hidden_size_; -+ half* bias2 = bias2_ptr + b2_offset; -+ invokeAddBias(a1_arr_h[arr_idx], bias2, token_num, hidden_size_, stream); -+ arr_idx++; -+ } -+ } -+} -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/MoeFfnLayer.h b/src/fastertransformer/layers/ms_layers/MoeFfnLayer.h -new file mode 100644 -index 0000000..e3afe9c ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/MoeFfnLayer.h -@@ -0,0 +1,82 @@ -+#ifndef MOE_FFN_LAYER_H_ -+#define MOE_FFN_LAYER_H_ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include "src/fastertransformer/layers/ms_layers/ffn.h" -+#include "src/fastertransformer/layers/ms_layers/BaseLayer.h" -+namespace fastertransformer { -+ -+ -+class PanguMoeFfnLayer : public BaseLayerMS { -+ public: -+ PanguMoeFfnLayer(int hidden_size, -+ int expert_num, -+ int ffn_hidden_size, -+ int rank_num, -+ int seq_len, -+ float expert_capability, -+ int batch_size); -+ size_t GetWorkspaceSize() override; -+ void SetExpertNum(int expert_num) { -+ expert_num_ = expert_num; -+ max_capacity_ = static_cast(std::ceil(expert_capability_ * src_seq_len_ / expert_num_) + 0.01f); -+ } -+ void SetFfnHiddenSize(int ffn_hidden_size) {ffn_hidden_size_ = ffn_hidden_size;} -+ void SetExpertOffset(size_t expert_offset) {expert_offset_ = expert_offset;} -+ void SetHTokenNum(size_t h_token_num) -+ { -+ h_token_num_ = h_token_num; -+ } -+ void SetPaddingOffsetDevice(int* padding_offset) {padding_offset_d_ = padding_offset;} -+ void SetSeqLenDevice(int* seq_len) {seq_len_d_ = seq_len;} -+ void SetSeqLenHost(int* seq_len) {seq_len_h_ = seq_len;} -+ void SetActType(FfnBase::ActType act_type) { act_type_ = act_type; } -+ void forward(std::vector &inputs, const std::vector &outputs, void *ws, cublasHandle_t cublas_handle = nullptr, cudaStream_t stream = 0) override; -+ private: -+ size_t gather_stride() {return max_capacity_ * batch_size_ * hidden_size_;}; -+ size_t router_stride() {return max_capacity_ * batch_size_;}; -+ -+ void forward_single_token(half *in, const int *expert_in_ids, half *weight1_ptr, half *bias1_ptr, half *weight2_ptr, half *bias2_ptr, half* output, void *workspace, cublasHandle_t cublas_handle = nullptr, cudaStream_t stream = 0); -+ void forward_expert(half *in, half *weight1_ptr, half *bias1_ptr, half *weight2_ptr, half *bias2_ptr, half* output, void *workspace, int expert_id, int expert_token_num, cublasHandle_t cublas_handle = nullptr, cudaStream_t stream = 0); -+ void forward_expert_experimental(half *gather, int *token_per_expert, half *weight1_ptr, half *bias1_ptr, half *weight2_ptr, half *bias2_ptr, half* output, void *workspace, cublasHandle_t cublas_handle = nullptr, cudaStream_t stream = 0); -+ int experimental_threshold() {return 1;}; -+ void PanguMoeLayer(const int* d_in, -+ const half* layernorm, -+ half* out, -+ const half* weight1, -+ const half* weight2, -+ const half* bias1, -+ const half* bias2, -+ int batch_size, -+ int expert_num, -+ int length, -+ int hidden_size, -+ int hidden_size2, -+ int onehot_size, -+ void* workspace, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream); -+ int expert_num_; -+ int ffn_hidden_size_; -+ float expert_capability_; -+ int max_capacity_; -+ size_t expert_offset_; -+ int *capacity_ = nullptr; -+ int 
*padding_offset_d_ = nullptr; -+ int *seq_len_d_ = nullptr; -+ int *seq_len_h_ = nullptr; -+ FfnBase::ActType act_type_; -+ size_t h_token_num_; -+ half** a1_arr_ = nullptr; -+ half** b1_arr_ = nullptr; -+ half** b2_arr_ = nullptr; -+ half** c1_arr_ = nullptr; -+}; -+ -+} // pangumoe -+ -+#endif // MOE_FFN_LAYER_H_ -diff --git a/src/fastertransformer/layers/ms_layers/attention.cc b/src/fastertransformer/layers/ms_layers/attention.cc -new file mode 100644 -index 0000000..0659eab ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/attention.cc -@@ -0,0 +1,773 @@ -+ -+#include "src/fastertransformer/layers/ms_layers/attention.h" -+#include "fmha_cutlass.h" -+#include "src/fastertransformer/kernels/activation_kernels.h" -+#include "src/fastertransformer/kernels/add_residual_kernels.h" -+#include "src/fastertransformer/kernels/bert_preprocess_kernels.h" -+#include "src/fastertransformer/kernels/unfused_attention_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+#include "src/fastertransformer/layers/ms_layers/opt_allocator.h" -+#include -+namespace fastertransformer { -+ -+template -+size_t UnfusedMhaDispatch::GetWorkspaceSize() -+{ -+ size_t attn_out_size = batch_size_ * head_num_ * head_size_ * tgt_seq_len_; -+ size_t size_q = batch_size_ * src_seq_len_ * head_num_ * head_size_; -+ size_t size_k = batch_size_ * tgt_seq_len_ * head_num_ * head_size_; -+ size_t size_v = size_k; -+ size_t qkv_len = size_q + size_k + size_v; -+ size_t q_buf_2_len = size_q; -+ size_t qk_buf_len = batch_size_ * head_num_ * src_seq_len_ * tgt_seq_len_; -+ size_t qkv_buf_2_len = batch_size_ * src_seq_len_ * head_num_ * head_size_; -+ size_t qkv_buf_3_len = qkv_buf_2_len; -+ OptAllocator allocator(ALIGN_SIZE); -+ qkv_buf_ = allocator.Malloc(qkv_len * sizeof(T)); -+ q_buf_2_ = allocator.Malloc(q_buf_2_len * sizeof(T)); -+ if (use_past_) { -+ output1_ = 0; -+ output2_ = 0; -+ } else { -+ output1_ = allocator.Malloc(attn_out_size * sizeof(T)); -+ output2_ = allocator.Malloc(attn_out_size * sizeof(T)); -+ } -+ allocator.Free(qkv_buf_); -+ qk_buf_ = allocator.Malloc(qk_buf_len * sizeof(T)); -+ allocator.Free(q_buf_2_); -+ if (!use_past_) -+ allocator.Free(output1_); -+ qkv_buf_2_ = allocator.Malloc(qkv_buf_2_len * sizeof(T)); -+ allocator.Free(output2_); -+ allocator.Free(qk_buf_); -+ qkv_buf_3_ = allocator.Malloc(qkv_buf_3_len * sizeof(T)); -+ allocator.Free(qkv_buf_2_); -+ allocator.Free(qkv_buf_3_); -+ return allocator.total_size(); -+} -+ -+template -+void UnfusedMhaDispatch::forward(std::vector& inputs, -+ const std::vector& output, -+ void* ws, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream) -+{ -+ // setup inputs -+ T* q_buf_2 = reinterpret_cast(inputs[0]); -+ T* output1 = reinterpret_cast(inputs[1]); -+ T* output2 = reinterpret_cast(inputs[2]); -+ T* attention_mask = reinterpret_cast(inputs[3]); -+ T* position_bias = reinterpret_cast(inputs[4]); -+ // setup inner buffers -+ T* qk_buf = GetBuf(ws, qk_buf_); -+ T* qkv_buf_2 = GetBuf(ws, qkv_buf_2_); -+ int src_seq_len = src_seq_len_; // len of q tensor -+ int tgt_seq_len = tgt_seq_len_; // len of K, V tensors -+ int max = d_sequence_length_host_[0], min = d_sequence_length_host_[0]; -+ for (int i = 0; i < batch_size_; i++) -+ { -+ if (d_sequence_length2_host_[i] < min) -+ min = d_sequence_length2_host_[i]; -+ if (d_sequence_length2_host_[i] > max) -+ max = d_sequence_length2_host_[i]; -+ } -+ tgt_seq_len = max; -+ if (use_past_ && incremental_mode_) { -+ src_seq_len = 1; -+ } -+ // run unfused attention -+ int 
gemm_dims[] = {tgt_seq_len, src_seq_len, (int)head_size_}; -+ int gemm_lds[] = {(int)tgt_seq_len_, -+ (int)head_size_, -+ tgt_seq_len}; -+ cublasOperation_t gemm_ops[] = {CUBLAS_OP_N, CUBLAS_OP_N}; -+ cudaDataType gemm_data_types[] = {CUDA_R_32F, CUDA_R_32F, CUDA_R_32F}; -+ if (std::is_same::value) { -+ gemm_data_types[0] = CUDA_R_16F; -+ gemm_data_types[1] = CUDA_R_16F; -+ gemm_data_types[2] = CUDA_R_16F; -+ } -+ float alpha = 1.0f; -+ float beta = 0.0f; -+ int gemm_strides[] = {(int)(tgt_seq_len_ * head_size_), -+ (int)(src_seq_len * head_size_), -+ (src_seq_len * tgt_seq_len)}; -+ //If batch valid != batch_size - run loop to skip not valid batches -+ if (padding_offset_ != nullptr && h_token_num_ < batch_size_) { -+ int offset1 = 0; -+ int offset2 = 0; -+ int offset3 = 0; -+ int head = (is_cross_ && position_bias_) ? int(1) : int(head_num_); -+ for (int i = 0; i < batch_size_; i++) { -+ src_seq_len = (int)(d_sequence_length_host_[i]); -+ tgt_seq_len = (int)(d_sequence_length2_host_[i]); -+ -+ if (src_seq_len == -1) { -+ offset1 += head_num_ * head_size_ * tgt_seq_len_; -+ continue; -+ } -+ gemm_dims[0] = tgt_seq_len; -+ gemm_dims[1] = src_seq_len; -+ gemm_dims[2] = head_size_; -+ gemm_lds[0] = tgt_seq_len_; -+ gemm_lds[1] = head_size_; -+ gemm_lds[2] = tgt_seq_len; -+ -+ gemm_strides[0] = tgt_seq_len_ * head_size_; -+ gemm_strides[1] = src_seq_len * head_size_; -+ gemm_strides[2] = src_seq_len * tgt_seq_len; -+ gemm_ops[0] = CUBLAS_OP_N; -+ gemm_ops[1] = CUBLAS_OP_N; -+ CublasGemmStridedBatchedWrapper(output1 + offset1, -+ q_buf_2 + offset2, -+ qk_buf + offset3, -+ gemm_dims, -+ gemm_lds, -+ gemm_ops, -+ gemm_strides, -+ const_cast(gemm_data_types), -+ &alpha, -+ &beta, -+ head_num_, -+ cublas_handle, -+ algo_); -+ invokeMixMaskedSoftMax(static_cast(qk_buf) + offset3, -+ (use_past_ ) ? nullptr : (attention_mask == nullptr) ? 
nullptr : attention_mask + i * src_seq_len_ * tgt_seq_len_, -+ position_bias, -+ d_sequence_length_ + i, -+ d_sequence_length2_ + i, -+ 1, -+ src_seq_len, -+ src_seq_len_, -+ tgt_seq_len, -+ tgt_seq_len_, -+ head_num_, -+ head, -+ (T)(scale_), -+ (use_past_ && !incremental_mode_), -+ stream); -+ gemm_ops[0] = CUBLAS_OP_N; -+ gemm_ops[1] = CUBLAS_OP_N; -+ -+ gemm_dims[0] = head_size_; -+ gemm_dims[1] = src_seq_len; -+ gemm_dims[2] = tgt_seq_len; -+ -+ gemm_lds[0] = head_size_; -+ gemm_lds[1] = tgt_seq_len; -+ gemm_lds[2] = head_size_; -+ -+ gemm_strides[0] = tgt_seq_len_ * head_size_; -+ gemm_strides[1] = src_seq_len * tgt_seq_len; -+ gemm_strides[2] = src_seq_len * head_size_; -+ CublasGemmStridedBatchedWrapper(output2 + offset1, -+ qk_buf + offset3, -+ qkv_buf_2 + offset2, -+ gemm_dims, -+ gemm_lds, -+ gemm_ops, -+ gemm_strides, -+ const_cast(gemm_data_types), -+ &alpha, -+ &beta, -+ head_num_, -+ cublas_handle, -+ algo_); -+ offset1 += head_num_ * head_size_ * tgt_seq_len_; -+ offset2 += head_num_ * head_size_ * src_seq_len; -+ offset3 += head_num_ * src_seq_len * tgt_seq_len_; -+ } -+ offset1 = 0; -+ for (int i = 0; i < batch_size_; i++) { -+ src_seq_len = (int)(d_sequence_length_host_[i]); -+ if (src_seq_len == -1) continue; -+ invokeTransposeQKV(static_cast(output[0]) + offset1, -+ static_cast(qkv_buf_2) + offset1, -+ 1, -+ src_seq_len, -+ head_num_, -+ head_size_, -+ stream); -+ offset1 += head_num_ * head_size_ * src_seq_len; -+ } -+ } -+ -+ else { -+ CublasGemmStridedBatchedWrapper(output1, -+ q_buf_2, -+ qk_buf, -+ gemm_dims, -+ gemm_lds, -+ gemm_ops, -+ gemm_strides, -+ const_cast(gemm_data_types), -+ &alpha, -+ &beta, -+ batch_size_ * head_num_, -+ cublas_handle, -+ algo_); -+ invokeMixMaskedSoftMax(static_cast(qk_buf), -+ (use_past_ ) ? nullptr : attention_mask, -+ position_bias, -+ d_sequence_length_, -+ d_sequence_length2_, -+ batch_size_, -+ src_seq_len, -+ src_seq_len_, -+ tgt_seq_len, -+ tgt_seq_len_, -+ head_num_, -+ (is_cross_ && position_bias_) ? 
int(1) : int(head_num_), -+ (T)(scale_), -+ (use_past_ && !incremental_mode_), -+ stream); -+ gemm_ops[0] = CUBLAS_OP_N; -+ gemm_ops[1] = CUBLAS_OP_N; -+ -+ gemm_dims[0] = head_size_; -+ gemm_dims[1] = src_seq_len; -+ gemm_dims[2] = tgt_seq_len; -+ -+ gemm_lds[0] = head_size_; -+ gemm_lds[1] = tgt_seq_len; -+ gemm_lds[2] = head_size_; -+ -+ gemm_strides[0] = tgt_seq_len_ * head_size_; -+ gemm_strides[1] = src_seq_len * tgt_seq_len; -+ gemm_strides[2] = src_seq_len * head_size_; -+ CublasGemmStridedBatchedWrapper(output2, -+ qk_buf, -+ qkv_buf_2, -+ gemm_dims, -+ gemm_lds, -+ gemm_ops, -+ gemm_strides, -+ const_cast(gemm_data_types), -+ &alpha, -+ &beta, -+ batch_size_ * head_num_, -+ cublas_handle, -+ algo_); -+ if (padding_offset_ == nullptr || incremental_mode_) { -+ invokeTransposeQKV(static_cast(output[0]), -+ static_cast(qkv_buf_2), -+ batch_size_, -+ src_seq_len, -+ head_num_, -+ head_size_, -+ stream); -+ } else { -+ invokeTransposeAttentionOutRemovePadding(qkv_buf_2, -+ reinterpret_cast(output[0]), -+ h_token_num_, -+ batch_size_, -+ src_seq_len_, -+ head_num_, -+ head_size_, -+ padding_offset_, -+ stream); -+ } -+ } -+ return; -+} -+ -+template -+bool FusedCutlassMhaDispatch::isSupport() -+{ -+ return fuse_mha_->isSupport(); -+} -+template -+size_t FusedCutlassMhaDispatch::GetWorkspaceSize() -+{ -+ size_t attn_out_size = batch_size_ * head_num_ * head_size_ * tgt_seq_len_; -+ size_t size_q = batch_size_ * src_seq_len_ * head_num_ * head_size_; -+ size_t size_k = batch_size_ * tgt_seq_len_ * head_num_ * head_size_; -+ size_t size_v = size_k; -+ size_t qkv_len = size_q + size_k + size_v; -+ size_t q_buf_2_len = size_q; -+ size_t qk_buf_len = batch_size_ * head_num_ * src_seq_len_ * tgt_seq_len_; -+ size_t qkv_buf_2_len = batch_size_ * src_seq_len_ * head_num_ * head_size_; -+ size_t qkv_buf_3_len = qkv_buf_2_len; -+ OptAllocator allocator(ALIGN_SIZE); -+ qkv_buf_ = allocator.Malloc(qkv_len * sizeof(T)); -+ q_buf_2_ = allocator.Malloc(q_buf_2_len * sizeof(T)); -+ if (use_past_) { -+ output1_ = 0; -+ output2_ = 0; -+ } else { -+ output1_ = allocator.Malloc(attn_out_size * sizeof(T)); -+ output2_ = allocator.Malloc(attn_out_size * sizeof(T)); -+ } -+ allocator.Free(qkv_buf_); -+ qk_buf_ = 0; // not in use -+ qkv_buf_2_ = allocator.Malloc(qkv_buf_2_len * sizeof(T)); -+ qkv_buf_3_ = allocator.Malloc(qkv_buf_3_len * sizeof(T)); -+ size_t size = 0; -+ size = fuse_mha_->GetWorkspaceSize(); -+ if (size > 0) { -+ mha_ = allocator.Malloc(size); -+ } -+ else { -+ mha_ = 0; // not in use -+ } -+ fuse_mha_->SetWSOffset(mha_); -+ allocator.Free(qkv_buf_3_); -+ size_t total = allocator.total_size(); -+ return total; -+} -+template -+void FusedCutlassMhaDispatch::forward(std::vector& inputs, -+ const std::vector& outputs, -+ void* ws, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream) -+{ -+ fuse_mha_->forward(inputs, outputs, ws, cublas_handle, stream); -+} -+ -+template -+size_t Attention::GetWorkspaceSize() -+{ -+ size_t size = dispatch_->GetWorkspaceSize(); -+ qkv_buf_ = dispatch_->qkv_buf_; -+ q_buf_2_ = dispatch_->q_buf_2_; -+ output1_ = dispatch_->output1_; -+ output2_ = dispatch_->output2_; -+ qk_buf_ = dispatch_->qk_buf_; -+ qkv_buf_2_ = dispatch_->qkv_buf_2_; -+ qkv_buf_3_ = dispatch_->qkv_buf_3_; -+ mha_ = dispatch_->mha_; -+ return size; -+} -+ -+template -+Attention::Attention(size_t batch_size, -+ size_t src_seq_len, -+ size_t tgt_seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ bool qkv_bias, -+ bool projection_bias, -+ bool is_cross, -+ bool 
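
/* [Editor's sketch - not part of the patch] The gemm_dims / gemm_lds /
 * gemm_strides triples passed to CublasGemmStridedBatchedWrapper above are the
 * m/n/k sizes, leading dimensions and per-matrix element strides of a
 * strided-batched GEMM (one matrix per batch * head). For reference, a bare
 * cuBLAS call with that bookkeeping spelled out, in the column-major
 * convention cuBLAS expects (row-major callers typically pass their operands
 * swapped); fp16 data with fp32 accumulation, error handling omitted:
 */
#include <cublas_v2.h>
#include <cuda_fp16.h>

void BatchedGemmExample(cublasHandle_t handle, const half* A, const half* B, half* C,
                        int m, int n, int k, int batch_count) {
  float alpha = 1.0f, beta = 0.0f;
  // Each batch i computes the column-major product C_i(m x n) = A_i(m x k) * B_i(k x n).
  cublasGemmStridedBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                             m, n, k,
                             &alpha,
                             A, CUDA_R_16F, /*lda=*/m, /*strideA=*/(long long)m * k,
                             B, CUDA_R_16F, /*ldb=*/k, /*strideB=*/(long long)k * n,
                             &beta,
                             C, CUDA_R_16F, /*ldc=*/m, /*strideC=*/(long long)m * n,
                             batch_count,
                             CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
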
position_bias, -+ float scale, -+ bool mask, -+ bool use_past, -+ int rank_num, -+ cublasGemmAlgo_t algo): -+ qkv_bias_(qkv_bias), -+ projection_bias_(projection_bias), -+ is_cross_(is_cross), -+ position_bias_(position_bias), -+ mask_(mask), -+ use_past_(use_past), -+ BaseLayerMS(batch_size, src_seq_len, tgt_seq_len, head_num, head_size, hidden_size, rank_num, algo) -+{ -+ std::shared_ptr> fuse = std::make_shared>(batch_size, -+ src_seq_len, -+ tgt_seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ is_cross, -+ position_bias, -+ scale, -+ mask, -+ use_past, -+ algo); -+ if (fuse->isSupport()) { -+ fmha_type_ = MhaDispatch::Type::CutlassFix; -+ dispatch_ = fuse; -+ } -+ else { -+ std::shared_ptr> unfuse = std::make_shared>(batch_size, -+ src_seq_len, -+ tgt_seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ is_cross, -+ position_bias, -+ scale, -+ mask, -+ use_past, -+ algo); -+ fmha_type_ = MhaDispatch::Type::UnFused; -+ dispatch_ = unfuse; -+ } -+} -+ -+MhaDispatch::MhaDispatch(size_t batch_size, -+ size_t src_seq_len, -+ size_t tgt_seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ bool is_cross, -+ bool position_bias, -+ float scale, -+ bool mask, -+ bool use_past, -+ int rank_num, -+ cublasGemmAlgo_t algo): -+ is_cross_(is_cross), -+ position_bias_(position_bias), -+ scale_(scale), -+ mask_(mask), -+ use_past_(use_past), -+ BaseLayerMS(batch_size, src_seq_len, tgt_seq_len, head_num, head_size, hidden_size, rank_num, algo) { -+ } -+template -+FusedCutlassMhaDispatch::FusedCutlassMhaDispatch(size_t batch_size, -+ size_t src_seq_len, -+ size_t tgt_seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ bool is_cross, -+ bool position_bias, -+ float scale, -+ bool mask, -+ bool use_past, -+ int rank_num, -+ cublasGemmAlgo_t algo): -+ MhaDispatch(batch_size, -+ src_seq_len, -+ tgt_seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ is_cross, -+ position_bias, -+ scale, -+ mask, -+ use_past, -+ rank_num, -+ algo) -+{ -+ typedef typename std::conditional::value, cutlass::half_t, float>::type Type; -+ fuse_mha_ = std::make_shared>(batch_size, -+ src_seq_len, -+ tgt_seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ is_cross, -+ position_bias, -+ scale, -+ mask, -+ use_past, -+ rank_num, -+ algo); -+} -+ -+template -+void Attention::forward(std::vector& inputs, -+ const std::vector& outputs, -+ void* ws, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream) -+{dispatch_->SetRankId(rank_id_); -+ ws = GetBuf(ws, ws_offset_); -+ in_idx_ = 0; -+ T* qkv_buf = GetBuf(ws, qkv_buf_); -+ T* q_buf_2 = GetBuf(ws, q_buf_2_); -+ T* qkv_buf_3 = GetBuf(ws, qkv_buf_3_); -+ -+ T* output1 = nullptr; -+ T* output2 = nullptr; -+ if (use_past_) { -+ output1 = reinterpret_cast(k_cache_); -+ output2 = reinterpret_cast(v_cache_); -+ } else { -+ output1 = GetBuf(ws, output1_); -+ output2 = GetBuf(ws, output2_); -+ } -+ int actual_hidden_size = head_size_ * head_num_; -+ int gemm_dims[] = { -+ 3 * (int)actual_hidden_size, (int)h_token_num_, (int)hidden_size_}; -+ int gemm_lds[] = {3 * (int)actual_hidden_size, (int)hidden_size_, 3 * (int)actual_hidden_size}; -+ T* from_tensor = reinterpret_cast(inputs[in_idx_++]); -+ cublasOperation_t gemm_ops[] = {CUBLAS_OP_N, CUBLAS_OP_N}; -+ cudaDataType gemm_data_types[] = {CUDA_R_32F, CUDA_R_32F, CUDA_R_32F}; -+ if (std::is_same::value) { -+ gemm_data_types[0] = CUDA_R_16F; -+ gemm_data_types[1] = CUDA_R_16F; -+ gemm_data_types[2] = CUDA_R_16F; -+ } -+ float alpha = 1.0f; -+ float beta = 0.0f; -+ if (is_cross_) { -+ 
gemm_dims[0] = actual_hidden_size; -+ gemm_dims[1] = h_token_num_; -+ gemm_dims[2] = hidden_size_; -+ gemm_lds[0] = actual_hidden_size; -+ gemm_lds[1] = hidden_size_; -+ gemm_lds[2] = actual_hidden_size; -+ T* encoder_output = reinterpret_cast(inputs[in_idx_++]); -+ T* weight_q = reinterpret_cast(inputs[in_idx_++]); -+ if (use_past_) { -+ gemm_lds[2] = 3 * actual_hidden_size; -+ } -+ CublasGemmWrapper(weight_q, -+ from_tensor, -+ qkv_buf, -+ gemm_dims, -+ gemm_lds, -+ gemm_ops, -+ gemm_data_types, -+ &alpha, -+ &beta, -+ cublas_handle, -+ algo_); -+ -+ T *kv = qkv_buf + h_token_num_ * hidden_size_; -+ gemm_dims[0] = 2 * actual_hidden_size; -+ gemm_dims[1] = h_token_num2_; -+ gemm_dims[2] = hidden_size_; -+ -+ gemm_lds[0] = 2 * actual_hidden_size; -+ gemm_lds[1] = hidden_size_; -+ gemm_lds[2] = 2 * actual_hidden_size; -+ if (use_past_) { -+ gemm_lds[2] = 3 * actual_hidden_size; -+ kv = qkv_buf + actual_hidden_size; -+ } -+ T* weight_kv = reinterpret_cast(inputs[in_idx_++]); -+ CublasGemmWrapper(weight_kv, -+ encoder_output, -+ kv, -+ gemm_dims, -+ gemm_lds, -+ gemm_ops, -+ gemm_data_types, -+ &alpha, -+ &beta, -+ cublas_handle, -+ algo_); -+ T* bias_qkv = (qkv_bias_) ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ if (padding_offset_ == nullptr) { -+ if (use_past_) { -+ if (incremental_mode_) { -+ output1 += (cur_token_id_) * head_size_; -+ output2 += (cur_token_id_) * head_size_; -+ } -+ invokeAddFusedQKVBiasTransposeUsePast(static_cast(q_buf_2), -+ static_cast(output1), -+ static_cast(output2), -+ static_cast(qkv_buf), -+ bias_qkv, -+ src_seq_len_, -+ h_token_num_, -+ head_num_, -+ head_size_, -+ stream); -+ // restore cache to pointer start after concat -+ output1 = reinterpret_cast(k_cache_); -+ output2 = reinterpret_cast(v_cache_); -+ } -+ else { -+ invokeCrossAddFusedQKVBiasTranspose(q_buf_2, -+ output1, -+ output2, -+ qkv_buf, -+ bias_qkv, -+ batch_size_, -+ src_seq_len_, -+ tgt_seq_len_, -+ head_num_, -+ head_size_, -+ stream); -+ } -+ } else { -+ if (use_past_) { -+ invokeAddFusedQKVBiasTransposeUsePastMB(q_buf_2, -+ output1, -+ output2, -+ qkv_buf, -+ bias_qkv, -+ padding_offset_, -+ d_sequence_length_, -+ d_sequence_length2_, -+ batch_size_, -+ h_token_num_, -+ src_seq_len_, -+ head_num_, -+ head_size_, -+ incremental_mode_, -+ !(typeid(FusedCutlassMhaDispatch) == typeid(*dispatch_)), -+ stream); -+ -+ } -+ else if (typeid(FusedCutlassMhaDispatch) == typeid(*dispatch_)) { -+ invokeCrossAddFusedQKVBiasTransposeMBVSL(q_buf_2, -+ output1, -+ output2, -+ qkv_buf, -+ bias_qkv, -+ padding_offset_, -+ padding_offset2_, -+ d_sequence_length_, -+ d_sequence_length2_, -+ h_token_num_, -+ h_token_num2_, -+ batch_size_, -+ src_seq_len_, -+ tgt_seq_len_, -+ head_num_, -+ head_size_, -+ stream); -+ } else { -+ invokeCrossAddFusedZP_QKVBiasTranspose(q_buf_2, -+ output1, -+ output2, -+ qkv_buf, -+ bias_qkv, -+ batch_size_, -+ src_seq_len_, -+ tgt_seq_len_, -+ head_num_, -+ head_size_, -+ h_token_num_, -+ h_token_num2_, -+ padding_offset_, -+ padding_offset2_, -+ stream); -+ } -+ } -+ } else { // end of is_cross -+ T* weight_qkv = reinterpret_cast(inputs[in_idx_++]); -+ CublasGemmWrapper(weight_qkv, -+ from_tensor, -+ qkv_buf, -+ gemm_dims, -+ gemm_lds, -+ gemm_ops, -+ const_cast(gemm_data_types), -+ &alpha, -+ &beta, -+ cublas_handle, -+ algo_); -+ T* bias_qkv = (qkv_bias_) ? 
reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ if (padding_offset_ == nullptr) { -+ if (use_past_) { -+ if (incremental_mode_) { -+ output1 += (cur_token_id_) * head_size_; -+ output2 += (cur_token_id_) * head_size_; -+ } -+ invokeAddFusedQKVBiasTransposeUsePast(static_cast(q_buf_2), -+ static_cast(output1), -+ static_cast(output2), -+ static_cast(qkv_buf), -+ bias_qkv, -+ src_seq_len_, -+ h_token_num_, -+ head_num_, -+ head_size_, -+ stream); -+ // restore cache to pointer start after concat -+ output1 = reinterpret_cast(k_cache_); -+ output2 = reinterpret_cast(v_cache_); -+ } else { -+ invokeAddFusedQKVBiasTranspose(static_cast(q_buf_2), -+ static_cast(output1), -+ static_cast(output2), -+ static_cast(qkv_buf), -+ bias_qkv, -+ batch_size_, -+ src_seq_len_, -+ head_num_, -+ head_size_, -+ 0, -+ stream); -+ } -+ } else { -+ if (use_past_) { -+ invokeAddFusedQKVBiasTransposeUsePastMB(q_buf_2, -+ output1, -+ output2, -+ qkv_buf, -+ bias_qkv, -+ padding_offset_, -+ d_sequence_length_, -+ d_sequence_length2_, -+ batch_size_, -+ h_token_num_, -+ src_seq_len_, -+ head_num_, -+ head_size_, -+ incremental_mode_, -+ !(typeid(FusedCutlassMhaDispatch) == typeid(*dispatch_)), -+ stream); -+ } else if (typeid(FusedCutlassMhaDispatch) == typeid(*dispatch_)) { -+ invokeAddFusedQKVBiasTransposeMBVSL(static_cast(q_buf_2), -+ static_cast(output1), -+ static_cast(output2), -+ static_cast(qkv_buf), -+ bias_qkv, -+ padding_offset_, -+ d_sequence_length_, -+ batch_size_, -+ src_seq_len_, -+ h_token_num_, -+ head_num_, -+ head_size_, -+ stream); -+ } else { -+ invokeAddFusedZP_QKVBiasTranspose(static_cast(q_buf_2), -+ static_cast(output1), -+ static_cast(output2), -+ static_cast(qkv_buf), -+ bias_qkv, -+ batch_size_, -+ src_seq_len_, -+ head_num_, -+ head_size_, -+ h_token_num_, -+ padding_offset_, -+ stream); -+ } -+ } -+ } -+ // Do Softmax(Q*Kt + Bias + Mask) -+ T* attention_mask = (mask_) ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ T* position_bias = (position_bias_) ? 
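
/* [Editor's note - a reading of the two projection paths above, not part of
 * the patch] For self-attention a single GEMM projects every token to
 * 3 * actual_hidden_size (Q, K and V packed together), which is why
 * gemm_dims[0] and the leading dimensions start out as 3 * actual_hidden_size.
 * For cross-attention the query is projected from the decoder input
 * (gemm_dims[0] = actual_hidden_size) and a second GEMM projects the encoder
 * output to the packed K/V pair (2 * actual_hidden_size); in use_past_ mode
 * both results appear to be written into a 3 * actual_hidden_size-strided
 * qkv_buf so the layout matches the self-attention case for the downstream
 * bias/transpose kernels.
 */
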
reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ if (attention_mask && padding_offset_ != nullptr && typeid(FusedCutlassMhaDispatch) != typeid(dispatch_) && !(use_past_)) { -+ invokeBuildEncoderAttentionMask( -+ attention_mask, d_sequence_length2_, d_sequence_length_, batch_size_, src_seq_len_, tgt_seq_len_, incremental_mode_, stream); -+ } -+ std::vector dispatch_in = {q_buf_2, output1, output2, attention_mask, position_bias}; -+ std::vector dispatch_out = {qkv_buf_3}; -+ dispatch_->forward(dispatch_in, dispatch_out, ws, cublas_handle, stream); -+ gemm_ops[0] = CUBLAS_OP_N; -+ gemm_ops[1] = CUBLAS_OP_N; -+ gemm_dims[0] = hidden_size_; -+ gemm_dims[1] = h_token_num_; -+ gemm_dims[2] = actual_hidden_size; -+ -+ gemm_lds[0] = hidden_size_; -+ gemm_lds[1] = actual_hidden_size; -+ gemm_lds[2] = hidden_size_; -+ -+ CublasGemmWrapper(reinterpret_cast(inputs[in_idx_++]), -+ qkv_buf_3, -+ static_cast(outputs[0]), -+ gemm_dims, -+ gemm_lds, -+ gemm_ops, -+ const_cast(gemm_data_types), -+ &alpha, -+ &beta, -+ cublas_handle, -+ algo_); -+ if (projection_bias_) { -+ int len = h_token_num_; -+ invokeAddBias( -+ static_cast(outputs[0]), (const T*)(inputs[in_idx_++]), len, hidden_size_, stream); -+ } -+ return; -+} -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/attention.h b/src/fastertransformer/layers/ms_layers/attention.h -new file mode 100644 -index 0000000..567947c ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/attention.h -@@ -0,0 +1,411 @@ -+#pragma once -+ -+#include "src/fastertransformer/kernels/activation_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/MSBaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/gemm.h" -+#include "src/fastertransformer/layers/ms_layers/BaseLayer.h" -+#include -+#include -+ -+namespace fastertransformer { -+class MhaDispatch : public BaseLayerMS{ -+protected: -+ bool position_bias_; -+ float scale_; -+ bool mask_; -+ bool is_cross_; -+ int* padding_offset_{nullptr}; -+ int* d_sequence_length_{nullptr}; -+ int* padding_offset2_{nullptr}; -+ int* d_sequence_length2_{nullptr}; -+ size_t data_parallel_{false}; -+ int cur_token_id_{0}; // current token id id -+ size_t h_token_num_; -+ size_t h_token_num2_; -+ bool incremental_mode_{false}; -+ bool use_past_{false}; // use past mode -+ int* d_sequence_length_host_; -+ int* d_sequence_length2_host_; -+ -+public: -+ typedef enum Type { -+ UnFused, -+ CutlassFix -+ } Type; -+ size_t qkv_buf_{0}; -+ size_t q_buf_2_{0}; -+ size_t output1_{0}; -+ size_t output2_{0}; -+ size_t qk_buf_{0}; -+ size_t qkv_buf_2_{0}; -+ size_t qkv_buf_3_{0}; -+ size_t mha_{0}; -+ size_t GetBatchSize() {return batch_size_;} -+ size_t GetSrcSeqLen() {return src_seq_len_;} -+ size_t GetTgtSeqLen() {return tgt_seq_len_;} -+ size_t GetHeadNum() {return head_num_;} -+ size_t GetHeadSize() {return head_size_;} -+ size_t GetHiddenSize() {return hidden_size_;} -+ bool GetIsCross() {return is_cross_;} -+ float GetScale() {return scale_;} -+ size_t GetHTokenNum() {return h_token_num_;} -+ bool GetPositionBias() {return position_bias_;} -+ bool GetIncrementalMode() {return incremental_mode_;} -+ int GetCurTokenId() {return cur_token_id_;} -+ bool GetUsePast() {return use_past_;} -+ virtual void SetVslParam(int* padding_offset = nullptr, int* padding_offset2 = nullptr, int* d_sequence_length = nullptr, int* d_sequence_length2 = nullptr) -+ { -+ padding_offset_ = padding_offset; -+ padding_offset2_ = padding_offset2; -+ d_sequence_length_ = d_sequence_length; -+ 
d_sequence_length2_ = d_sequence_length2; -+ } -+ virtual void SetCurTokenId(int cur_token_id) -+ { -+ cur_token_id_ = cur_token_id; -+ } -+ virtual void SetHTokenNum(size_t h_token_num, size_t h_token_num2 = -1) -+ { -+ h_token_num_ = h_token_num; -+ h_token_num2_ = h_token_num2; -+ } -+ virtual void SetCross(bool cross) {is_cross_ = cross;} -+ virtual void SetIncrementalMode(bool incremental_mode) -+ { -+ incremental_mode_ = incremental_mode; -+ } -+ virtual void SetScale(float scale) {scale_ = scale;} -+ virtual void SetUsePast(bool use_past) -+ { -+ use_past_ = use_past; -+ } -+ void SetBuffers(size_t qkv_buf = 0, -+ size_t q_buf_2 = 0, -+ size_t output1 = 0, -+ size_t output2 = 0, -+ size_t qk_buf = 0, -+ size_t qkv_buf_2 = 0, -+ size_t qkv_buf_3 = 0, -+ size_t mha = 0) -+ { -+ qkv_buf_ = qkv_buf; -+ q_buf_2_ = q_buf_2; -+ output1_ = output1; -+ output2_ = output2; -+ qk_buf_ = qk_buf; -+ qkv_buf_2_ = qkv_buf_2; -+ qkv_buf_3_ = qkv_buf_3; -+ mha_ = mha; -+ } -+ virtual void SetOption(bool position_bias = false, bool mask = true) -+ { -+ position_bias_ = position_bias; -+ mask_ = mask; -+ } -+ MhaDispatch(size_t batch_size, -+ size_t src_seq_len, -+ size_t tgt_seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ bool is_cross_, -+ bool position_bias = false, -+ float scale = 1.0f, -+ bool mask = true, -+ bool use_past = false, -+ int rank_num = 1, -+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); -+ virtual bool isSupport() -+ { -+ return true; -+ } -+ virtual void SetFuseWS(void* ws){} -+ virtual size_t GetWorkspaceSize() override {return 0;} -+}; -+template -+class FusedCutlassMhaDispatch: public MhaDispatch { -+private: -+ std::shared_ptr fuse_mha_; -+public: -+ FusedCutlassMhaDispatch(size_t batch_size, -+ size_t src_seq_len, -+ size_t tgt_seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ bool is_cross_, -+ bool position_bias = false, -+ float scale = 1.0f, -+ bool mask = true, -+ bool use_past = false, -+ int rank_num = 1, -+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); -+ void SetVslParam(int* padding_offset = nullptr, int* padding_offset2 = nullptr, int* d_sequence_length = nullptr, int* d_sequence_length2 = nullptr) override -+ { -+ padding_offset_ = padding_offset; -+ padding_offset2_ = padding_offset2; -+ d_sequence_length_ = d_sequence_length; -+ d_sequence_length2_ = d_sequence_length2; -+ fuse_mha_->SetVslParam(padding_offset, padding_offset2, d_sequence_length, d_sequence_length2); -+ } -+ void SetOption(bool position_bias = false, bool mask = true) override -+ { -+ mask_ = mask; -+ position_bias_ = position_bias; -+ fuse_mha_->SetOption(position_bias, mask); -+ } -+ void SetCross(bool cross) override -+ { -+ is_cross_ = cross; -+ fuse_mha_->SetCross(cross); -+ } -+ void SetHTokenNum(size_t h_token_num, size_t h_token_num2 = -1) override -+ { -+ h_token_num_ = h_token_num; -+ h_token_num2_ = h_token_num2; -+ fuse_mha_->SetHTokenNum(h_token_num, h_token_num2); -+ } -+ void SetIncrementalMode(bool incremental_mode) override -+ { -+ incremental_mode_ = incremental_mode; -+ fuse_mha_->SetIncrementalMode(incremental_mode); -+ } -+ void SetScale(float scale) override -+ { -+ scale_ = scale; -+ fuse_mha_->SetScale(scale); -+ } -+ void SetUsePast(bool use_past) override -+ { -+ use_past_ = use_past; -+ fuse_mha_->SetUsePast(use_past); -+ } -+ void SetCurTokenId(int cur_token_id) override -+ { -+ cur_token_id_ = cur_token_id; -+ fuse_mha_->SetCurTokenId(cur_token_id); -+ } -+ bool isSupport() override; -+ size_t 
GetWorkspaceSize() override; -+ void forward(std::vector &inputs, const std::vector &outputs, void *ws, cublasHandle_t cublas_handle = nullptr, cudaStream_t stream = 0) override; -+}; -+template -+class UnfusedMhaDispatch: public MhaDispatch { -+public: -+ UnfusedMhaDispatch(size_t batch_size, -+ size_t src_seq_len, -+ size_t tgt_seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ bool is_cross = false, -+ bool position_bias = false, -+ float scale = 1.0f, -+ bool mask = true, -+ bool use_past = false, -+ int rank_num = 1, -+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP) : -+ MhaDispatch(batch_size, -+ src_seq_len, -+ tgt_seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ is_cross, -+ position_bias, -+ scale, -+ mask, -+ use_past, -+ rank_num, -+ algo) { -+ d_sequence_length_host_ = (int*)malloc(batch_size_ * sizeof(int)); -+ d_sequence_length2_host_ = (int*)malloc(batch_size_ * sizeof(int)); -+ -+ } -+ void SetVslParam(int* padding_offset = nullptr, int* padding_offset2 = nullptr, int* d_sequence_length = nullptr, int* d_sequence_length2 = nullptr) override -+ { -+ padding_offset_ = padding_offset; -+ padding_offset2_ = padding_offset2; -+ d_sequence_length_ = d_sequence_length; -+ d_sequence_length2_ = d_sequence_length2; -+ if (d_sequence_length_ != nullptr) { -+ cudaD2Hcpy(d_sequence_length_host_, d_sequence_length_, batch_size_); -+ } -+ if (d_sequence_length2_ != nullptr) { -+ cudaD2Hcpy(d_sequence_length2_host_, d_sequence_length2_, batch_size_); -+ } -+ } -+ size_t GetWorkspaceSize() override; -+ void forward(std::vector &inputs, const std::vector &outputs, void *ws, cublasHandle_t cublas_handle = nullptr, cudaStream_t stream = 0) override; -+}; -+template -+class Attention : public BaseLayerMS { -+private: -+ std::shared_ptr dispatch_; -+ -+ bool qkv_bias_; // ture -+ bool projection_bias_; // ture -+ bool is_cross_; // false -+ bool position_bias_; -+ bool mask_; -+ bool use_past_; // use past mode -+ MhaDispatch::Type fmha_type_; -+ size_t data_parallel_{false}; -+ int cur_token_id_{0}; // current token id id -+ size_t h_token_num_; -+ size_t h_token_num2_; -+ void* k_cache_{nullptr}; -+ void* v_cache_{nullptr}; -+ bool incremental_mode_{false}; -+ size_t qkv_buf_{0}; -+ size_t q_buf_2_{0}; -+ size_t output1_{0}; -+ size_t output2_{0}; -+ size_t qk_buf_{0}; -+ size_t qkv_buf_2_{0}; -+ size_t qkv_buf_3_{0}; -+ size_t mha_{0}; -+ int* padding_offset_{nullptr}; -+ int* d_sequence_length_{nullptr}; -+ int* padding_offset2_{nullptr}; -+ int* d_sequence_length2_{nullptr}; -+public: -+ void printParam() -+ { -+ std::cout<<"attn param\n"; -+ std::cout<<"batch_size = "<"; -+} -+ -+template -+void check(T result, char const* const func, const char* const file, int const line) -+{ -+ if (result) { -+ throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + (_cudaKernalGetErrorEnum(result)) + " " -+ + file + ":" + std::to_string(line) + " \n"); -+ } -+} -+ -+#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) -+ -+template -+__inline__ __device__ -+T gelu(T x) -+{ -+ float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); -+ return x * cdf; -+} -+ -+template <> -+__inline__ __device__ -+half2 gelu(half2 val) -+{ -+ half2 val_pow3 = __hmul2(val, __hmul2(val, val)); -+ float2 tmp_pow = __half22float2(val_pow3); -+ float2 tmp = __half22float2(val); -+ -+ tmp.x = 0.5f * (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); -+ tmp.y = 0.5f * (1.0f + 
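
/* [Editor's sketch - not part of the patch] The gelu() device functions above
 * use the tanh approximation
 *   gelu(x) ~= 0.5 * x * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3) ))
 * with sqrt(2/pi) ~= 0.7978845608028654, which is where that constant comes
 * from. A scalar host reference for checking kernel output:
 */
#include <cmath>

static float GeluRef(float x) {
  const float kSqrt2OverPi = 0.7978845608028654f;   // sqrt(2 / pi)
  return 0.5f * x * (1.0f + std::tanh(kSqrt2OverPi * (x + 0.044715f * x * x * x)));
}
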
tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); -+ return __hmul2(val, __float22half2_rn(tmp)); -+ -+} -+ -+template -+__global__ -+void add_bias_act(T* out, const T* bias, int m, int k, int n) -+{ -+ T val, reg_bias; -+ -+ int ite = n / blockDim.x; -+ int tid = threadIdx.x; -+ -+ for(int i = 0; i < ite; ++i) -+ { -+ int row_id = blockIdx.x; -+ while(row_id < m){ -+ reg_bias = __ldg(&bias[row_id / k * n + i * blockDim.x + tid]); -+ -+ val = out[tid + i * blockDim.x + row_id * n]+ reg_bias; -+ out[tid + i * blockDim.x + row_id * n] = gelu(val); -+ row_id += gridDim.x; -+ } -+ } -+} -+ -+template <> -+__global__ -+void add_bias_act(half* out, const half* bias, int m, int k, int n) -+{ -+ half2 val, reg_bias; -+ int ite = n / blockDim.x / 2; -+ int tid = threadIdx.x; -+ -+ half2* out_ptr = (half2*) out; -+ const half2* bias_ptr = (half2*) bias; -+ for(int i = 0; i < ite; ++i) -+ { -+ int row_id = blockIdx.x; -+ while(row_id < m){ -+ reg_bias = __ldg(&bias_ptr[row_id / k * n + i * blockDim.x + tid]); -+ val = out_ptr[tid + i * blockDim.x + row_id * n / 2]; -+ val = __hadd2(val, reg_bias); -+ out_ptr[tid + i * blockDim.x + row_id * n / 2] = gelu(val); -+ row_id += gridDim.x; -+ } -+ } -+} -+ -+template -+void add_bias_act_kernelLauncher(T* out, const T* bias, int m, int k, int n, cudaStream_t stream) -+{ -+ dim3 grid(m * k / 4.); -+ dim3 block(n / 16); -+ assert(block.x <= 1024); -+ add_bias_act<<>>(out, bias, m * k, k, n); -+} -+ -+__device__ void ScanWarp(int32_t* shm_data) { -+ volatile int32_t* vshm_data = shm_data; -+ vshm_data[0] += vshm_data[-1]; -+ vshm_data[0] += vshm_data[-2]; -+ vshm_data[0] += vshm_data[-4]; -+ vshm_data[0] += vshm_data[-8]; -+ vshm_data[0] += vshm_data[-16]; -+} -+ -+__device__ void ScanBlock(int32_t* shm_data) { -+ int32_t warp_id = threadIdx.x >> 5; -+ int32_t lane = threadIdx.x & 31; -+ extern __shared__ int32_t warp_sum[]; // 16 zero padding -+ // scan each warp -+ ScanWarp(shm_data); -+ __syncthreads(); -+ // write sum of each warp to warp_sum -+ if (lane == 31) { -+ warp_sum[16 + warp_id] = *shm_data; -+ } -+ __syncthreads(); -+ // use a single warp to scan warp_sum -+ if (warp_id == 0) { -+ ScanWarp(warp_sum + 16 + lane); -+ } -+ __syncthreads(); -+ // add base -+ if (warp_id > 0) { -+ *shm_data += warp_sum[16 + warp_id - 1]; -+ } -+ __syncthreads(); -+} -+ -+__global__ void ScanAndWritePartSumKernel(const int32_t* input, -+ int32_t* output, size_t n, -+ size_t part_num, size_t shared_num) { -+ // the first 16 + 32 is used to save warp sum -+ extern __shared__ int32_t shm[]; -+ int32_t warp_id = threadIdx.x >> 5; -+ int32_t lane = threadIdx.x & 31; -+ for (int tid = threadIdx.x; tid < shared_num; tid += blockDim.x) { -+ shm[tid] = 0; -+ } -+ __syncthreads(); -+ // process each part -+ for (size_t part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) { -+ // store this part input to shm -+ size_t index = part_i * blockDim.x + threadIdx.x; -+ int32_t* myshm = shm + (16 + 32) + warp_id * (16 + 32) + 16 + lane; -+ *myshm = index < n ? 
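
/* [Editor's note - not part of the patch] ScanWarp() above is the classic
 * warp-synchronous inclusive scan: it reads shm_data[-1] .. shm_data[-16],
 * which only stays in bounds because every 32-element window is preceded by 16
 * zero-padding slots (the "16 + 32" arithmetic in ScanBlock and
 * ScanAndWritePartSumKernel), removing per-step bounds checks. It also relies
 * on volatile shared-memory accesses and implicit warp-synchronous execution;
 * on recent architectures an equivalent scan is often written with warp
 * shuffles instead. A sketch of that alternative (not the patch's code):
 */
__device__ int WarpInclusiveScanShfl(int v) {
    const unsigned kFullMask = 0xffffffffu;
    const int lane = threadIdx.x & 31;
    // Kogge-Stone inclusive scan across the 32 lanes of a warp.
    for (int offset = 1; offset < 32; offset <<= 1) {
        int up = __shfl_up_sync(kFullMask, v, offset);
        if (lane >= offset) {
            v += up;
        }
    }
    return v;
}
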
input[index] : 0; -+ __syncthreads(); -+ // scan on shared memory -+ ScanBlock(myshm); -+ __syncthreads(); -+ // write result -+ if (index < n) { -+ output[index] = *myshm; -+ } -+ } -+} -+ -+__global__ void ScanAndWritePartSumKernel2(const int32_t* input, int32_t* output, size_t n, -+ size_t part_size) { -+ size_t part_begin = part_size * blockIdx.x; -+ size_t part_end = min(part_size * (blockIdx.x + 1), n); -+ int32_t acc = 0; -+ for (size_t i = part_begin; i < part_end; ++i) { -+ acc += input[i]; -+ output[i] = acc; -+ } -+} -+ -+void ScanThenFan(int32_t* input, int32_t* buffer, int32_t* output, -+ size_t n, size_t length, cudaStream_t stream) { -+ size_t part_size = length; -+ size_t part_num = (n + part_size - 1) / part_size; -+ size_t block_num = std::min(part_num, 128); -+ size_t warp_num = (part_size + 31) / 32; -+ size_t shm_num = 16 + 32 + warp_num * (16 + 32); -+ size_t shm_size = shm_num * sizeof(int32_t); -+ ScanAndWritePartSumKernel<<>>(input, output, n, part_num, shm_num); -+} -+ -+template -+__global__ void OneHotTransposeFusionKernel(const T* in, T* out1, T* out2, int batch_size, int length, int expert_num) { -+ //__shared__ T s_mem[2048]; -+ for (int tid = threadIdx.x + blockIdx.x * blockDim.x; tid < batch_size * length * expert_num; tid += blockDim.x * gridDim.x) { -+ int batchlength_idx = tid / expert_num; -+ int expert_idx = tid % expert_num; -+ int batch_idx = batchlength_idx / length; -+ int length_idx = batchlength_idx % length; -+ bool is_on = (in[batch_idx * length + length_idx] == expert_idx); -+ out1[tid] = static_cast(is_on); -+ out2[length_idx + expert_idx * length + batch_idx * expert_num * length] = static_cast(is_on); -+ } -+} -+ -+template -+void OneHotTransposeFusionKernelLaunch(const T* in, T* out1, T* out2, int batch_size, int length, int expert_num, cudaStream_t stream) { -+ if (length != 1) { -+ dim3 grid(batch_size * expert_num); -+ dim3 block(1024); -+ OneHotTransposeFusionKernel<<>>(in, out1, out2, batch_size, length, expert_num); -+ } else { -+ dim3 grid(batch_size); -+ dim3 block(batch_size * expert_num); -+ OneHotTransposeFusionKernel<<>>(in, out1, out2, batch_size, length, expert_num); -+ } -+} -+ -+template -+__global__ void MulLessCastMulReduceMulOnehotMulFusionKernel(T* in1, T* in2, S* out, int batch_size, int length, int expert_num, int max_expert_num, S threshold) { -+ __shared__ T s_mem[512][16 + 1]; -+ for (int tid = threadIdx.x + blockIdx.x * blockDim.x; tid < batch_size * length * expert_num; tid += blockDim.x * gridDim.x) { -+ int batch_idx = tid / length / expert_num; -+ int length_idx = tid / expert_num % length; -+ int expert_idx = tid % expert_num; -+ T val_in1 = in1[tid]; -+ // fuse transpose -+ T val_in2 = in2[batch_idx * length * expert_num + expert_idx * length + length_idx]; -+ T mul1 = val_in1 * val_in2; -+ T mul2 = static_cast(mul1 < threshold) * val_in1; -+ s_mem[(batch_idx * length + length_idx) % 512][expert_idx] = mul2; -+ __syncthreads(); -+ if (expert_idx < 8) { -+ s_mem[(batch_idx * length + length_idx) % 512][expert_idx] += s_mem[(batch_idx * length + length_idx) % 512][expert_idx + 8]; -+ } -+ __syncthreads(); -+ if (expert_idx < 4) { -+ s_mem[(batch_idx * length + length_idx) % 512][expert_idx] += s_mem[(batch_idx * length + length_idx) % 512][expert_idx + 4]; -+ } -+ __syncthreads(); -+ if (expert_idx < 2) { -+ s_mem[(batch_idx * length + length_idx) % 512][expert_idx] += s_mem[(batch_idx * length + length_idx) % 512][expert_idx + 2]; -+ } -+ __syncthreads(); -+ if (expert_idx < 1) { -+ s_mem[(batch_idx * length 
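
/* [Editor's sketch - not part of the patch] OneHotTransposeFusionKernel()
 * above expands the per-token expert id into a one-hot mask of shape
 * [batch, length, expert_num] and, in the same pass, its
 * [batch, expert_num, length] transpose. A host-side reference of that
 * expansion (array names hypothetical):
 */
#include <vector>

// expert_ids holds batch * length entries, each the chosen expert for that token.
static void OneHotAndTranspose(const std::vector<int>& expert_ids, int batch, int length,
                               int expert_num, std::vector<float>* one_hot,
                               std::vector<float>* one_hot_t) {
    one_hot->assign((size_t)batch * length * expert_num, 0.f);   // [batch, length, expert]
    one_hot_t->assign((size_t)batch * expert_num * length, 0.f); // [batch, expert, length]
    for (int b = 0; b < batch; ++b) {
        for (int t = 0; t < length; ++t) {
            const int e = expert_ids[b * length + t];
            (*one_hot)[((size_t)b * length + t) * expert_num + e] = 1.f;
            (*one_hot_t)[((size_t)b * expert_num + e) * length + t] = 1.f;
        }
    }
}
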
+ length_idx) % 512][expert_idx] += s_mem[(batch_idx * length + length_idx) % 512][expert_idx + 1]; -+ } -+ __syncthreads(); -+ -+ for (int i = 0; i != max_expert_num; ++i) { -+ out[tid * max_expert_num + i] = static_cast(0); -+ } -+ out[tid * max_expert_num + mul1] = static_cast(s_mem[length_idx % 512][0]); -+ __syncthreads(); -+ } -+} -+ -+template <> -+__global__ void MulLessCastMulReduceMulOnehotMulFusionKernel(int* in1, int* in2, half* out, int batch_size, int length, int expert_num, int max_expert_num, half threshold) { -+ __shared__ int s_mem[512][16 + 1]; -+ for (int tid = threadIdx.x + blockIdx.x * blockDim.x; tid < batch_size * length * expert_num; tid += blockDim.x * gridDim.x) { -+ int batch_idx = tid / length / expert_num; -+ int length_idx = tid / expert_num % length; -+ int expert_idx = tid % expert_num; -+ int val_in1 = in1[tid]; -+ // fuse transpose -+ int val_in2 = in2[batch_idx * length * expert_num + expert_idx * length + length_idx]; -+ int mul1 = val_in1 * val_in2; -+ int mul2 = static_cast(mul1 < __half2int_rn(threshold)) * val_in1; -+ s_mem[(batch_idx * length + length_idx) % 512][expert_idx] = mul2; -+ __syncthreads(); -+ if (expert_idx < 8) { -+ s_mem[(batch_idx * length + length_idx) % 512][expert_idx] += s_mem[(batch_idx * length + length_idx) % 512][expert_idx + 8]; -+ } -+ __syncthreads(); -+ if (expert_idx < 4) { -+ s_mem[(batch_idx * length + length_idx) % 512][expert_idx] += s_mem[(batch_idx * length + length_idx) % 512][expert_idx + 4]; -+ } -+ __syncthreads(); -+ if (expert_idx < 2) { -+ s_mem[(batch_idx * length + length_idx) % 512][expert_idx] += s_mem[(batch_idx * length + length_idx) % 512][expert_idx + 2]; -+ } -+ __syncthreads(); -+ if (expert_idx < 1) { -+ s_mem[(batch_idx * length + length_idx) % 512][expert_idx] += s_mem[(batch_idx * length + length_idx) % 512][expert_idx + 1]; -+ } -+ __syncthreads(); -+ -+ for (int i = 0; i != max_expert_num; ++i) { -+ out[tid * max_expert_num + i] = (half)0.f; -+ } -+ out[tid * max_expert_num + mul1] = __int2half_rd(s_mem[length_idx % 512][0]); -+ } -+} -+ -+template -+void MulLessCastMulReduceMulOnehotMulFusionKernelLaunch(T* in1, T* in2, S* out, int batch_size, int length, int expert_num, int onehot_size, S threshold, cudaStream_t stream) { -+ if (length != 1) { -+ dim3 grid(expert_num); -+ dim3 block(512); -+ MulLessCastMulReduceMulOnehotMulFusionKernel<<>>(in1, in2, out, batch_size, length, expert_num, onehot_size, threshold); -+ } else { -+ dim3 grid(expert_num); -+ dim3 block(512); -+ MulLessCastMulReduceMulOnehotMulFusionKernel<<>>(in1, in2, out, batch_size, length, expert_num, onehot_size, threshold); -+ } -+} -+ -+template -+__global__ void AddBiasTransposeTransposeFusionKernel(const T* in, const T* bias, T* out, int ori_s1, int ori_s2, int ori_s3, int trans_s1, int trans_s2, int trans2_s1, int trans2_s2) { -+ for (int tid = threadIdx.x + blockIdx.x * blockDim.x; tid < trans_s1 * trans_s2; tid += blockDim.x * gridDim.x) { -+ int expert_idx = tid / (ori_s2 * ori_s3); -+ int hidden_idx = tid % ori_s3; -+ T in_val = in[tid] + bias[expert_idx * ori_s3 + hidden_idx]; -+ -+ int ori_idx1 = tid / trans_s1; -+ int ori_idx2 = tid % trans_s1; -+ -+ int trans1_idx = ori_idx2 * trans_s2 + ori_idx1; -+ -+ int trans2_idx1 = trans1_idx / trans2_s1; -+ int trans2_idx2 = trans1_idx % trans2_s1; -+ int trans2_idx = trans2_idx2 * trans2_s2 + trans2_idx1; -+ out[trans2_idx] = in_val; -+ } -+} -+ -+template <> -+__global__ void AddBiasTransposeTransposeFusionKernel(const half* in, const half* bias, half* out, int ori_s1, 
int ori_s2, int ori_s3, int trans_s1, int trans_s2, int trans2_s1, int trans2_s2) { -+ for (int tid = threadIdx.x + blockIdx.x * blockDim.x; tid < trans_s1 * trans_s2; tid += blockDim.x * gridDim.x) { -+ int expert_idx = tid / (ori_s2 * ori_s3); -+ int hidden_idx = tid % ori_s3; -+ half in_val = __hadd(in[tid], bias[expert_idx * ori_s3 + hidden_idx]); -+ -+ int ori_idx1 = tid / trans_s1; -+ int ori_idx2 = tid % trans_s1; -+ -+ int trans1_idx = ori_idx2 * trans_s2 + ori_idx1; -+ -+ int trans2_idx1 = trans1_idx / trans2_s1; -+ int trans2_idx2 = trans1_idx % trans2_s1; -+ int trans2_idx = trans2_idx2 * trans2_s2 + trans2_idx1; -+ out[trans2_idx] = in_val; -+ } -+} -+ -+template -+void AddBiasTransposeTransposeFusionKernelLaunch(const T* in, const T* bias, T* out, int ori_s1, int ori_s2, int ori_s3, int trans_s1, int trans_s2, int trans2_s1, int trans2_s2, cudaStream_t stream) { -+ dim3 grid(trans_s1 * trans_s2 / 4); -+ dim3 block(1024); -+ AddBiasTransposeTransposeFusionKernel<<>>(in, bias, out, ori_s1, ori_s2, ori_s3, trans_s1, trans_s2, trans2_s1, trans2_s2); -+} -+ -+template <> -+void AddBiasTransposeTransposeFusionKernelLaunch(const half* in, const half* bias, half* out, int ori_s1, int ori_s2, int ori_s3, int trans_s1, int trans_s2, int trans2_s1, int trans2_s2, cudaStream_t stream) { -+ dim3 grid(trans_s1 * trans_s2 / 4); -+ dim3 block(1024); -+ AddBiasTransposeTransposeFusionKernel<<>>(in, bias, out, ori_s1, ori_s2, ori_s3, trans_s1, trans_s2, trans2_s1, trans2_s2); -+} -+ -+#endif -diff --git a/src/fastertransformer/layers/ms_layers/debug_utils.cc b/src/fastertransformer/layers/ms_layers/debug_utils.cc -new file mode 100644 -index 0000000..6fc1330 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/debug_utils.cc -@@ -0,0 +1,140 @@ -+#include -+#include -+#include -+#include "src/fastertransformer/utils/memory_utils.h" -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+#include -+namespace fastertransformer { -+ -+ -+template -+void printTensor(char* str, T* input, int size) -+{ -+ printf("%s ", str); -+ T* input_device = input; -+ T* input_host = (T*)malloc(size * sizeof(T)); -+ -+ cudaD2Hcpy(input_host, input_device, size); -+ -+ for (int k = 0; k < (int)size; k++) { -+ std::cout << input_host[k] << ","; -+ if (k % 16 == 0 && k != 0) -+ std::cout << std::endl; -+ } -+ -+ std::cout << std::endl; -+ -+ free(input_host); -+} -+int GetSeq(int* d_seq_len, int idx, int batch) -+{ -+ int* input_device = d_seq_len; -+ int* input_host = (int*)malloc(batch * sizeof(int)); -+ -+ cudaD2Hcpy(input_host, input_device, batch); -+ int num = input_host[idx]; -+ free(input_host); -+ return num; -+} -+template -+void isNan(char* str, T* input, int size) -+{ -+ std::cout << str << " " -+ << " size is " << size; -+ T* input_device = input; -+ T* input_host = (T*)malloc(size * sizeof(T)); -+ -+ cudaD2Hcpy(input_host, input_device, size); -+ -+ for (int k = 0; k < (int)size; k++) { -+ if (std::isnan((T)input_host[k]) || std ::isinf((T)input_host[k])) { -+ std::cout << "found NAN or INF " << k; -+ break; -+ } -+ } -+ -+ std::cout << std::endl; -+ free(input_host); -+} -+template -+T checksum(const T* tensor, int size) -+{ -+ auto tensor_host =(T*)malloc(size * sizeof(T)); -+ double sum = 0.; -+ cudaD2Hcpy(tensor_host, tensor, size); -+ for (int i = 0; i < size; i++) { -+ sum += (double)tensor_host[i]; -+ } -+ return static_cast(sum); -+} -+template -+double checksum2(char* str, const T* tensor, int size) -+ -+{ -+ double sum = 0.; -+ T* ptr = (T*)malloc(size * sizeof(T)); -+ -+ 
cudaD2Hcpy(ptr, tensor, size); -+ -+ for (int i = 0; i < size; i++) { -+ -+ sum += ptr[i]; -+ -+ } -+ std::cout << "checksum of "<< str << "is " << sum << std::endl; -+ free(ptr); -+ return sum; -+ -+} -+template -+void saveTensor(const std::string& name, T* tensor, int size) -+{ -+ auto tensor_host = std::make_unique(size); -+ T* ptr = tensor_host.get(); -+ cudaD2Hcpy(ptr, tensor, size); -+ std::ofstream wf(name + ".bin", std::ofstream::out | std::ofstream::binary); -+ wf.write(reinterpret_cast(ptr), size * sizeof(T)); -+ wf.close(); -+} -+ -+template -+void saveTensorFile(const std::string& name, T* tensor, int size) { -+ T* input_host = (T*)malloc(size * sizeof(T)); -+ cudaD2Hcpy(input_host, tensor, size); -+ std::ofstream wf(name+ ".bin", std::ios::out | std::ios::binary); -+ if(!wf) { -+ std::cout << "Cannot open file!" << std::endl; -+ return; -+ } -+ wf.write((char *)input_host,sizeof(T)*size); -+ wf.close(); -+} -+uint64_t GetTimeUs() -+{ -+ const int USEC = 1000000; -+ const int MSEC = 1000; -+ struct timespec ts = {0, 0}; -+ if (clock_gettime(CLOCK_MONOTONIC, &ts) != 0) { -+ return 0; -+ } -+ uint64_t retval = (uint64_t)((ts.tv_sec * USEC) + (ts.tv_nsec / MSEC)); -+ return retval; -+} -+template void saveTensorFile(const std::string& name, half* tensor, int size); -+template void saveTensorFile(const std::string& name, float* tensor, int size); -+template void saveTensorFile(const std::string& name, int* tensor, int size); -+template void printTensor(char* str, float* input, int size); -+template void isNan(char* str, float* input, int size); -+template float checksum(const float* tensor, int size); -+template void saveTensor(const std::string& name, float* tensor, int size); -+ -+template void printTensor(char* str, half* input, int size); -+template void isNan(char* str, half* input, int size); -+template half checksum(const half* tensor, int size); -+template void saveTensor(const std::string& name, half* tensor, int size); -+ -+template void printTensor(char* str, int* input, int size); -+template void saveTensor(const std::string& name, int* tensor, int size); -+template double checksum2(char* str, const float* tensor, int size); -+template double checksum2(char* str, const half* tensor, int size); -+} -\ No newline at end of file -diff --git a/src/fastertransformer/layers/ms_layers/debug_utils.h b/src/fastertransformer/layers/ms_layers/debug_utils.h -new file mode 100644 -index 0000000..68ee331 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/debug_utils.h -@@ -0,0 +1,41 @@ -+#pragma once -+#include -+#include -+#include -+#include -+#include -+#if __has_include("NvInferRuntimeCommon.h") -+#include "NvInferRuntimeCommon.h" -+#else -+namespace nvinfer1 { -+enum class DataType : int32_t -+{ -+ kFLOAT = 0, -+ kHALF = 1, -+ kINT8 = 2, -+ kINT32 = 3, -+ kBOOL = 4 -+}; -+} -+#endif -+namespace fastertransformer { -+ -+#define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) -+#define ALIGN(x, y) (UP_DIV(x, y) * (y)) -+#define ALIGN_SIZE 16 -+ -+template -+void printTensor(char* str, T* input, int size); -+template -+void isNan(char* str, T* input, int size); -+template -+T checksum(const T* tensor, int size); -+template -+void saveTensor(const std::string& name, T* tensor, int size); -+int GetSeq(int* d_seq_len, int idx, int batch); -+template -+double checksum2(char* str, const T* tensor, int size); -+template -+void saveTensorFile(const std::string& name, T* tensor, int size); -+uint64_t GetTimeUs(); -+} -\ No newline at end of file -diff --git 
a/src/fastertransformer/layers/ms_layers/decoder.cc b/src/fastertransformer/layers/ms_layers/decoder.cc -new file mode 100644 -index 0000000..16cd1e9 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/decoder.cc -@@ -0,0 +1,375 @@ -+ -+#include "src/fastertransformer/layers/ms_layers/decoder.h" -+#include "src/fastertransformer/kernels/activation_kernels.h" -+#include "src/fastertransformer/kernels/add_residual_kernels.h" -+#include "src/fastertransformer/kernels/bert_preprocess_kernels.h" -+#include "src/fastertransformer/kernels/layernorm_kernels.h" -+#include "src/fastertransformer/kernels/unfused_attention_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/attention.h" -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+#include "src/fastertransformer/layers/ms_layers/ffn.h" -+#include "src/fastertransformer/layers/ms_layers/opt_allocator.h" -+#include -+namespace fastertransformer { -+ -+template -+Decoder::Decoder(size_t batch_size, -+ size_t src_seq_len, -+ size_t tgt_seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ float eps1, -+ float eps2, -+ float eps3, -+ float eps4, -+ bool layernorm_post, -+ bool has_beta, -+ bool is_layernorm, -+ bool ffn_fp16, -+ bool qkv_bias, -+ bool projection_bias, -+ bool is_cross, -+ bool position_bias, -+ float scale, -+ bool mask, -+ bool ffn_bias, -+ size_t ffn_hidden_size, -+ FfnBase::ActType act_type, -+ cublasGemmAlgo_t algo): -+ layernorm_post_(layernorm_post), -+ has_beta_(has_beta), -+ is_layernorm_(is_layernorm), -+ ffn_fp16_(ffn_fp16), -+ DecoderBase(batch_size, src_seq_len, tgt_seq_len, head_num, head_size, hidden_size, algo) -+{ -+ attention_layer1_ = std::make_shared>(batch_size, -+ src_seq_len, -+ src_seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ qkv_bias, -+ projection_bias, -+ false, -+ position_bias, -+ scale, -+ mask, -+ false, -+ algo); -+ attention_layer2_ = std::make_shared>(batch_size, -+ src_seq_len, -+ tgt_seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ qkv_bias, -+ projection_bias, -+ true, -+ position_bias, -+ scale, -+ mask, -+ false, -+ algo); -+ is_ffn_fp16_ = (std::is_same::value && ffn_fp16_ == true); -+ if (is_ffn_fp16_) { -+ ffn_layer_ = std::make_shared>(batch_size, -+ src_seq_len, -+ tgt_seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ ffn_bias, -+ ffn_hidden_size, -+ act_type, -+ algo); -+ } -+ else { -+ ffn_layer_ = std::make_shared>(batch_size, -+ src_seq_len, -+ tgt_seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ ffn_bias, -+ ffn_hidden_size, -+ act_type, -+ algo); -+ } -+ layer_norm1_ = std::make_shared>(has_beta_, false, LayerNorm::Type::T5, eps1, algo); -+ layer_norm2_ = -+ std::make_shared>(has_beta_, true, LayerNorm::Type::ADD_BIAS_RESIDUAL_T5_PRE, eps2, algo); -+ if (std::is_same::value || !is_ffn_fp16_) -+ layer_norm3_ = -+ std::make_shared>(has_beta_, true, LayerNorm::Type::ADD_BIAS_RESIDUAL_T5_PRE, eps3, algo); -+ else -+ layer_norm3_ = std::make_shared>( -+ has_beta_, true, LayerNorm::Type::ADD_BIAS_RESIDUAL_T5_PRE_CAST, eps3, algo); -+ if (is_layernorm) -+ layer_norm4_ = std::make_shared>(has_beta_, false, LayerNorm::Type::T5, eps4, algo); -+} -+template -+size_t Decoder::GetWorkspaceSize() -+{ -+ size_t attn_out_size = batch_size_ * src_seq_len_ * hidden_size_; -+ size_t tmp_out_size = attn_out_size; -+ size_t compress_buffer_len = batch_size_ * src_seq_len_ * hidden_size_; -+ size_t compress_buffer_len2 = (src_seq_len_ > tgt_seq_len_) ? 
batch_size_ * src_seq_len_ * hidden_size_ : -+ batch_size_ * tgt_seq_len_ * hidden_size_; -+ size_t padding_len = batch_size_ * src_seq_len_; -+ size_t padding_len2 = batch_size_ * tgt_seq_len_; -+ -+ OptAllocator allocator(ALIGN_SIZE); -+ d_sequence_lengths_offset_buf_ = allocator.Malloc(batch_size_ * sizeof(int)); -+ padding_offset_buf_ = allocator.Malloc(padding_len * sizeof(int)); -+ d_token_num_buf_ = allocator.Malloc(1 * sizeof(size_t)); -+ d_sequence_lengths_offset_buf2_ = allocator.Malloc(batch_size_ * sizeof(int)); -+ padding_offset_buf2_ = allocator.Malloc(padding_len2 * sizeof(int)); -+ d_token_num_buf2_ = allocator.Malloc(1 * sizeof(size_t)); -+ compress_buf_ = allocator.Malloc(compress_buffer_len * sizeof(T)); -+ compress_buf2_ = allocator.Malloc(compress_buffer_len2 * sizeof(T)); -+ normed_from_tensor_buf_ = allocator.Malloc(attn_out_size * sizeof(T)); -+ attn_ws_buf_ = allocator.Malloc(attention_layer1_->GetWorkspaceSize()); -+ attention_layer1_->SetWSOffset(attn_ws_buf_); -+ attn_out_buf_ = allocator.Malloc(attn_out_size * sizeof(T)); -+ allocator.Free(attn_ws_buf_); -+ if (!layernorm_post_) -+ allocator.Free(normed_from_tensor_buf_); -+ normed_attn_out_buf_ = allocator.Malloc(attn_out_size * sizeof(T)); -+ if (layernorm_post_) -+ allocator.Free(normed_from_tensor_buf_); -+ attn2_ws_buf_ = allocator.Malloc(attention_layer2_->GetWorkspaceSize()); -+ attention_layer2_->SetWSOffset(attn2_ws_buf_); -+ attn2_out_buf_ = allocator.Malloc(attn_out_size * sizeof(T)); -+ allocator.Free(attn2_ws_buf_); -+ normed_attn2_out_buf_ = -+ is_ffn_fp16_ ? allocator.Malloc(attn_out_size * sizeof(half)) : allocator.Malloc(attn_out_size * sizeof(T)); -+ allocator.Free(attn_out_buf_); -+ tmp_out_buf_ = -+ is_ffn_fp16_ ? allocator.Malloc(tmp_out_size * sizeof(half)) : allocator.Malloc(tmp_out_size * sizeof(T)); -+ -+ ffn_ws_buf_ = allocator.Malloc(ffn_layer_->GetWorkspaceSize()); -+ ffn_layer_->SetWSOffset(ffn_ws_buf_); -+ return allocator.total_size(); -+} -+ -+template -+void Decoder::GetCompressBuffer(T* compress_buffer, -+ T* from_tensor, -+ int* d_sequence_lengths, -+ int* padding_offset, -+ size_t& h_token_num, -+ size_t* d_token_num, -+ size_t seq_len, -+ cudaStream_t stream) -+{ -+ invokeGetPaddingOffset(&h_token_num, d_token_num, padding_offset, d_sequence_lengths, batch_size_, seq_len, stream); -+ if (h_token_num * 2 <= batch_size_ * seq_len) { -+ invokeRemovePadding( -+ compress_buffer, (const T*)from_tensor, padding_offset, h_token_num, head_num_ * head_size_, stream); -+ } -+} -+ -+template -+void Decoder::ForwardAttention(std::shared_ptr> attention_layer_, -+ std::vector& inputs, -+ std::vector& from_tensor, -+ std::vector& output, -+ void* ws, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream) -+{ -+ inputs[--in_idx_] = from_tensor[0]; -+ bool is_projection_bias = attention_layer_->GetProjectionBias(); -+ attention_layer_->SetProjectionBias(false); -+ std::vector attn_in_vector(inputs.begin() + in_idx_, inputs.end()); -+ attention_layer_->forward(attn_in_vector, output, ws, cublas_handle, stream); -+ attention_layer_->SetProjectionBias(is_projection_bias); -+ in_idx_ = attention_layer_->GetIdx() + in_idx_; -+} -+template -+void Decoder::AddBiasResidual(std::vector& inputs, const std::vector& output, cudaStream_t stream) -+{ -+ if (std::is_same::value || !is_ffn_fp16_) { -+ invokeAddBiasResidual(reinterpret_cast(output[0]), -+ reinterpret_cast(inputs[0]), -+ reinterpret_cast(inputs[1]), -+ h_token_num_, -+ hidden_size_, -+ stream); -+ } -+ else { -+ if (layernorm_post_) { -+ 
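
/* [Editor's sketch - not part of the patch] GetCompressBuffer() above only
 * packs the batch when it pays off: invokeGetPaddingOffset counts the valid
 * tokens, and padding is removed only when h_token_num * 2 <= batch_size *
 * seq_len, i.e. at least half of the padded slots are empty. The offsets
 * follow the usual FasterTransformer convention of "number of padding slots
 * before this valid token", so padded_index = packed_index + offset. A
 * host-side version of that bookkeeping (names hypothetical):
 */
#include <vector>

static std::vector<int> PaddingOffsets(const std::vector<int>& seq_lens, int seq_len_max) {
    std::vector<int> offsets;                    // one entry per valid (packed) token
    int pad_so_far = 0;
    for (size_t b = 0; b < seq_lens.size(); ++b) {
        for (int t = 0; t < seq_lens[b]; ++t) offsets.push_back(pad_so_far);
        pad_so_far += seq_len_max - seq_lens[b]; // padding contributed by batch b
    }
    return offsets;                              // offsets.size() == h_token_num
}
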
invokeAddBiasResidualSameTypeCast(reinterpret_cast(output[0]), -+ reinterpret_cast(inputs[0]), -+ reinterpret_cast(output[1]), -+ reinterpret_cast(inputs[1]), -+ h_token_num_, -+ hidden_size_, -+ stream); -+ } -+ else { -+ invokeAddBiasResidualCast(reinterpret_cast(output[0]), -+ reinterpret_cast(inputs[0]), -+ reinterpret_cast(output[1]), -+ reinterpret_cast(inputs[1]), -+ h_token_num_, -+ hidden_size_, -+ stream); -+ } -+ } -+} -+template -+void Decoder::forward(std::vector& inputs, -+ const std::vector& output, -+ void* ws, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream) -+{ -+ -+ int in_len = inputs.size(); -+ ws = GetBuf(ws, ws_offset_); -+ in_idx_ = 0; -+ std::vector decoder_in = inputs; -+ size_t h_token_num = h_token_num_ = batch_size_ * src_seq_len_; -+ size_t h_token_num2 = h_token_num2_ = batch_size_ * tgt_seq_len_; -+ SetHTokenNum(h_token_num, h_token_num2); -+ int* padding_offset = nullptr; -+ int* padding_offset2 = nullptr; -+ T* input_tensor = reinterpret_cast(inputs[in_idx_++]); -+ T* from_tensor = input_tensor; -+ -+ int idx_encoder_out = attention_layer1_->GetPositionBias() ? 7 : 10; -+ T* encoder_output = reinterpret_cast(inputs[idx_encoder_out]); -+ int* d_sequence_lengths2 = reinterpret_cast(inputs[in_len - 2]); -+ int* d_sequence_lengths1 = reinterpret_cast(inputs[in_len - 1]); -+ T* compress_buffer = GetBuf(ws, compress_buf_); -+ T* compress_buffer2 = GetBuf(ws, +compress_buf2_); -+ size_t* d_token_num = GetBuf(ws, +d_token_num_buf_); -+ size_t* d_token_num2 = GetBuf(ws, +d_token_num_buf2_); -+ attention_layer1_->SetVslParam(nullptr, nullptr, nullptr, nullptr); -+ attention_layer2_->SetVslParam(nullptr, nullptr, nullptr, nullptr); -+ if (eft_) { -+ padding_offset = GetBuf(ws, padding_offset_buf_); -+ GetCompressBuffer(compress_buffer, -+ from_tensor, -+ d_sequence_lengths1, -+ padding_offset, -+ h_token_num, -+ d_token_num, -+ src_seq_len_, -+ stream); -+ if (batch_size_ > 1) { -+ if (h_token_num * 2 <= batch_size_ * src_seq_len_) { -+ h_token_num_ = h_token_num; -+ from_tensor = compress_buffer; -+ attention_layer1_->SetVslParam( -+ padding_offset, padding_offset, d_sequence_lengths1, d_sequence_lengths1); -+ attention_layer2_->SetVslParam( -+ padding_offset, padding_offset, d_sequence_lengths1, d_sequence_lengths1); -+ } -+ } -+ else { -+ SetHTokenNum(h_token_num, h_token_num2); -+ } -+ padding_offset2 = GetBuf(ws, padding_offset_buf2_); -+ GetCompressBuffer(compress_buffer2, -+ encoder_output, -+ d_sequence_lengths2, -+ padding_offset2, -+ h_token_num2, -+ d_token_num2, -+ tgt_seq_len_, -+ stream); -+ if (h_token_num2 * 2 <= batch_size_ * tgt_seq_len_) { -+ h_token_num2_ = h_token_num2; -+ decoder_in[idx_encoder_out] = compress_buffer2; -+ attention_layer2_->SetVslParam(padding_offset, padding_offset2, d_sequence_lengths1, d_sequence_lengths2); -+ } -+ } -+ SetHTokenNum(h_token_num_, h_token_num2_); -+ h_token_num = h_token_num_; -+ h_token_num2 = h_token_num2_; -+ T* attn_out = GetBuf(ws, attn_out_buf_); -+ T* normed_from_tensor = GetBuf(ws, normed_from_tensor_buf_); -+ -+ T* normed_attn_out = GetBuf(ws, normed_attn_out_buf_); -+ T* attn2_out = GetBuf(ws, attn2_out_buf_); -+ T* normed_attn2_out = GetBuf(ws, normed_attn2_out_buf_); -+ T* tmp_out = reinterpret_cast(output[0]); -+ if (attention_layer1_->GetPaddingOffset() != nullptr || is_ffn_fp16_ == true || is_layernorm_) { -+ tmp_out = GetBuf(ws, tmp_out_buf_); -+ } -+ T* tmp_out1 = reinterpret_cast(output[0]); -+ T* tmp_out2 = reinterpret_cast(output[0]); -+ -+ if (is_layernorm_ && 
(attention_layer1_->GetPaddingOffset() != nullptr || is_ffn_fp16_ == true)) { -+ tmp_out1 = compress_buffer2; -+ if (attention_layer1_->GetPaddingOffset() != nullptr) { -+ tmp_out2 = compress_buffer; -+ } -+ } -+ else if (attention_layer1_->GetPaddingOffset() != nullptr && is_ffn_fp16_ == true) { -+ tmp_out1 = compress_buffer; -+ } -+ T* out_buf = is_ffn_fp16_ ? tmp_out1 : tmp_out; -+ layer_norm1_->SetParams(h_token_num_, hidden_size_); -+ layer_norm2_->SetParams(h_token_num_, hidden_size_); -+ layer_norm3_->SetParams(h_token_num_, hidden_size_); -+ -+ T* gamma1 = reinterpret_cast(inputs[in_idx_++]); -+ T* beta1 = (has_beta_) ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ std::vector in = {from_tensor, gamma1, beta1}; -+ std::vector out = {normed_from_tensor}; -+ layer_norm1_->forward(in, out, ws, cublas_handle, stream); -+ std::vector attn_from_vector{normed_from_tensor}; -+ std::vector attn_out_vector{attn_out}; -+ ForwardAttention(attention_layer1_, decoder_in, attn_from_vector, attn_out_vector, ws, cublas_handle, stream); -+ -+ T* projection_bias = attention_layer1_->GetProjectionBias() ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ T* gamma2 = reinterpret_cast(inputs[in_idx_++]); -+ T* beta2 = (has_beta_) ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ from_tensor = (layernorm_post_) ? normed_from_tensor : from_tensor; -+ std::vector in2 = {from_tensor, gamma2, beta2, projection_bias}; -+ std::vector out2 = {attn_out, normed_attn_out}; -+ layer_norm2_->forward(in2, out2, ws, cublas_handle, stream); -+ std::vector attn2_from_vector{normed_attn_out}; -+ std::vector attn2_out_vector{attn2_out}; -+ ForwardAttention(attention_layer2_, decoder_in, attn2_from_vector, attn2_out_vector, ws, cublas_handle, stream); -+ -+ T* projection_bias2 = (attention_layer2_->GetProjectionBias()) ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ T* gamma3 = reinterpret_cast(inputs[in_idx_++]); -+ T* beta3 = (has_beta_) ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ attn_out = (layernorm_post_) ? normed_attn_out : attn_out; -+ std::vector in3 = {attn_out, gamma3, beta3, projection_bias2}; -+ std::vector out3 = {attn2_out, normed_attn2_out}; -+ layer_norm3_->forward(in3, out3, ws, cublas_handle, stream); -+ -+ inputs[--in_idx_] = normed_attn2_out; -+ std::vector ffn_in_vector(inputs.begin() + in_idx_, inputs.end()); -+ std::vector ffn_out_vector{tmp_out}; -+ ffn_layer_->forward(ffn_in_vector, ffn_out_vector, ws, cublas_handle, stream); -+ in_idx_ = ffn_layer_->GetIdx() + in_idx_; -+ -+ attn2_out = (layernorm_post_) ? normed_attn2_out : attn2_out; -+ T* ffn_bias = (ffn_layer_->GetffnBias()) ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ std::vector add_residual_in_vector{attn2_out, ffn_bias}; -+ std::vector add_residual_out_vector{tmp_out, tmp_out1}; -+ AddBiasResidual(add_residual_in_vector, add_residual_out_vector, stream); -+ if (is_layernorm_) { -+ T* gamma4 = reinterpret_cast(inputs[in_idx_++]); -+ T* beta4 = (has_beta_) ? 
reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ std::vector in4 = {out_buf, gamma4, beta4}; -+ std::vector out4 = {tmp_out2}; -+ layer_norm4_->SetParams(h_token_num_, hidden_size_); -+ layer_norm4_->forward(in4, out4, ws, cublas_handle, stream); -+ out_buf = tmp_out2; -+ } -+ if (attention_layer1_->GetPaddingOffset() != nullptr) { -+ cudaMemsetAsync(output[0], 0, batch_size_ * src_seq_len_ * head_size_ * head_num_ * sizeof(T), stream); -+ invokeRebuildPadding( -+ (T*)output[0], out_buf, attention_layer1_->GetPaddingOffset(), h_token_num, hidden_size_, stream); -+ } -+ return; -+} -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/decoder.h b/src/fastertransformer/layers/ms_layers/decoder.h -new file mode 100644 -index 0000000..b33b42b ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/decoder.h -@@ -0,0 +1,243 @@ -+#pragma once -+ -+#include "src/fastertransformer/kernels/activation_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/BaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/MSBaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/attention.h" -+#include "src/fastertransformer/layers/ms_layers/ffn.h" -+#include "src/fastertransformer/layers/ms_layers/layer_norm.h" -+ -+#include -+#include -+ -+namespace fastertransformer { -+ -+class DecoderBase: public BaseLayerMS { -+public: -+ DecoderBase(size_t batch_size, -+ size_t src_seq_len, -+ size_t tgt_seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP):BaseLayerMS(batch_size, src_seq_len, tgt_seq_len, head_num, head_size, hidden_size, 1, algo){} -+ bool GetEft(); -+ virtual void SetFfnParam(bool ffn_fp16, size_t ffn_hidden_size, FfnBase::ActType act_type = FfnBase::ActType::Gelu, bool ffn_bias = true) = 0; -+ virtual void SetScaleAttn(float scale = 1.0f) = 0; -+ virtual void SetIsLayerNorm(bool is_layernorm, float eps = 1e-6f) = 0; -+ virtual void SetLayerNormPost(bool layernorm_post) = 0; -+ virtual void SetVSL(bool eft) = 0; -+ virtual void SetT5(bool t5) = 0; -+ virtual void SetEps(float eps1, float eps2, float eps3, float eps4 = 1e-6f) = 0; -+ virtual void SetAlgo(cublasGemmAlgo_t algo) override = 0; -+ virtual void SetHTokenNum(size_t h_token_num, size_t h_token_num2) = 0; -+}; -+template -+class Decoder : public DecoderBase{ -+private: -+ std::shared_ptr> attention_layer1_; -+ std::shared_ptr> attention_layer2_; -+ std::shared_ptr ffn_layer_; -+ std::shared_ptr> layer_norm1_; -+ std::shared_ptr> layer_norm2_; -+ std::shared_ptr> layer_norm3_; -+ std::shared_ptr> layer_norm4_; -+ void ForwardAttention(std::shared_ptr> attention_layer_, std::vector &inputs, std::vector &from_tensor, std::vector &output, void *ws, cublasHandle_t cublas_handle, cudaStream_t stream); -+ void AddBiasResidual(std::vector &inputs, const std::vector &output, cudaStream_t stream); -+ -+ void GetCompressBuffer(T* compress_buffer, -+ T* from_tensor, -+ int* d_sequence_lengths, -+ int* padding_offset, -+ size_t& h_token_num, -+ size_t* d_token_num, -+ size_t seq_len, -+ cudaStream_t stream); -+ -+ bool eft_{false}; -+ size_t h_token_num_; -+ size_t h_token_num2_; -+ bool layernorm_post_; -+ bool has_beta_; -+ bool is_layernorm_; -+ bool ffn_fp16_; -+ bool is_ffn_fp16_{false}; -+ -+ size_t normed_from_tensor_buf_; -+ size_t attn_out_buf_; -+ size_t attn_ws_buf_; -+ size_t attn2_out_buf_; -+ size_t attn2_ws_buf_; -+ size_t tmp_out_buf_; -+ size_t ffn_ws_buf_; -+ size_t normed_attn_out_buf_; -+ size_t 
normed_attn2_out_buf_; -+ size_t compress_buf_; -+ size_t d_token_num_buf_; -+ size_t padding_offset_buf_; -+ size_t d_sequence_lengths_offset_buf_; -+ size_t compress_buf2_; -+ size_t d_token_num_buf2_; -+ size_t padding_offset_buf2_; -+ size_t d_sequence_lengths_offset_buf2_; -+public: -+ void printParam() -+ { -+ std::cout<<"batch_size = "< -+ -+namespace fastertransformer { -+ -+template -+Encoder::Encoder(size_t batch_size, -+ size_t seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ float eps1, -+ float eps2, -+ float eps3, -+ bool layernorm_post, -+ bool has_beta, -+ bool is_layernorm, -+ int embedding_size, -+ bool ffn_fp16, -+ bool qkv_bias, -+ bool projection_bias, -+ bool is_cross, -+ bool position_bias, -+ float scale, -+ bool mask, -+ bool ffn_bias, -+ size_t ffn_hidden_size, -+ FfnBase::ActType act_type, -+ bool use_past, -+ bool query_layer, -+ size_t expert_num, -+ int rank_num, -+ cublasGemmAlgo_t algo): -+ layernorm_post_(layernorm_post), -+ has_beta_(has_beta), -+ is_layernorm_(is_layernorm), -+ embedding_size_(embedding_size), -+ ffn_fp16_(ffn_fp16), -+ use_past_(use_past), -+ query_layer_(query_layer), -+ expert_num_(expert_num), -+ EncoderBase(batch_size, seq_len, head_num, head_size, hidden_size, rank_num, algo) -+{ -+ attention_layer_ = std::make_shared>(batch_size, -+ seq_len, -+ seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ qkv_bias, -+ projection_bias, -+ is_cross, -+ position_bias, -+ scale, -+ mask, -+ use_past, -+ rank_num, -+ algo); -+ is_ffn_fp16_ = (std::is_same::value && ffn_fp16_ == true); -+ if (is_ffn_fp16_) { -+ ffn_layer_ = std::make_shared>(batch_size, -+ seq_len, -+ seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ ffn_bias, -+ ffn_hidden_size, -+ act_type, -+ rank_num, -+ algo); -+ } -+ else { -+ ffn_layer_ = std::make_shared>(batch_size, -+ seq_len, -+ seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ ffn_bias, -+ ffn_hidden_size, -+ act_type, -+ rank_num, -+ algo); -+ } -+ if (layernorm_post == false || position_bias) { -+ layer_norm1_ = std::make_shared>(has_beta, false, LayerNorm::Type::T5, eps1, algo); -+ if (std::is_same::value || !is_ffn_fp16_) -+ layer_norm2_ = std::make_shared>( -+ has_beta, true, LayerNorm::Type::ADD_BIAS_RESIDUAL_T5_PRE, eps2, algo); -+ else -+ layer_norm2_ = std::make_shared>( -+ has_beta, true, LayerNorm::Type::ADD_BIAS_RESIDUAL_T5_PRE_CAST, eps2, algo); -+ } -+ else { -+ if (std::is_same::value || !is_ffn_fp16_) -+ layer_norm1_ = -+ std::make_shared>(has_beta, true, LayerNorm::Type::ADD_BIAS_RESIDUAL, eps1, algo); -+ else -+ layer_norm1_ = -+ std::make_shared>(has_beta, true, LayerNorm::Type::ADD_BIAS_RESIDUAL_CAST, eps1, algo); -+ if (std::is_same::value || !is_ffn_fp16_) -+ layer_norm2_ = -+ std::make_shared>(has_beta, true, LayerNorm::Type::ADD_BIAS_RESIDUAL, eps2, algo); -+ else -+ layer_norm2_ = std::make_shared>( -+ has_beta, true, LayerNorm::Type::ADD_BIAS_RESIDUAL_CAST_FFN, eps2, algo); -+ } -+ if (is_layernorm) -+ layer_norm3_ = std::make_shared>(has_beta, false, LayerNorm::Type::T5, eps3, algo); -+ moe_ = std::make_shared( -+ hidden_size, expert_num, ffn_hidden_size, rank_num, seq_len, 1.1, batch_size); -+ seq_len_host_ = (int*)malloc(batch_size * sizeof(int)); -+ first_layer_ = false; -+} -+template -+size_t Encoder::GetWorkspaceSize() -+{ -+ size_t attn_out_len = batch_size_ * src_seq_len_ * hidden_size_; -+ size_t normed_from_tensor_len = batch_size_ * src_seq_len_ * hidden_size_; -+ size_t normed_attn_out_len = batch_size_ * src_seq_len_ * hidden_size_; -+ 
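// --- Editorial sketch (not part of the original patch) -----------------------
// Encoder::GetWorkspaceSize(), which begins above and continues below, sizes
// each intermediate buffer and then plans them through an offset arena
// (OptAllocator, added later in this patch): Malloc() returns an aligned byte
// offset, Free() lets a later Malloc() reuse that region, total_size() is the
// peak footprint the caller allocates once, and GetBuf(ws, offset) resolves an
// offset with pointer arithmetic at forward time. The host-only arena below is
// a simplified illustration; its first-fit policy and names are assumptions
// (the patch's allocator is best-fit and also coalesces freed blocks).
#include <cstddef>
#include <cstdint>
#include <map>

class OffsetArenaSketch {
public:
    explicit OffsetArenaSketch(size_t align) : align_(align) {}
    size_t Malloc(size_t size) {
        size = (size + align_ - 1) / align_ * align_;
        for (auto it = free_.begin(); it != free_.end(); ++it) {
            if (it->second >= size) {                      // reuse a freed block
                size_t addr = it->first, rest = it->second - size;
                free_.erase(it);
                if (rest != 0) free_[addr + size] = rest;  // keep the remainder free
                used_[addr] = size;
                return addr;
            }
        }
        size_t addr = heap_;                               // otherwise grow the heap
        heap_ += size;
        used_[addr] = size;
        return addr;
    }
    void Free(size_t addr) { free_[addr] = used_[addr]; used_.erase(addr); }
    size_t total_size() const { return heap_; }
private:
    size_t align_;
    size_t heap_ = 0;
    std::map<size_t, size_t> free_, used_;                 // offset -> block size
};

// Offsets are then resolved against the single workspace allocation, roughly:
template <typename T>
static T* GetBufSketch(void* ws, size_t offset) {
    return reinterpret_cast<T*>(static_cast<uint8_t*>(ws) + offset);
}
// -----------------------------------------------------------------------------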
size_t tmp_out_size = attn_out_len; -+ size_t compress_buffer_len = batch_size_ * src_seq_len_ * hidden_size_; -+ -+ size_t padding_len = batch_size_ * src_seq_len_; -+ OptAllocator allocator(ALIGN_SIZE); -+ if (use_past_) -+ d_sequence_lengths_offset_buf_ = allocator.Malloc(batch_size_ * sizeof(int)); -+ if (use_past_) -+ d_sequence_lengths_offset_buf2_ = allocator.Malloc(batch_size_ * sizeof(int)); -+ -+ padding_offset_buf_ = allocator.Malloc(padding_len * sizeof(int)); -+ padding_offset_buf2_ = allocator.Malloc(padding_len * sizeof(int)); -+ -+ d_token_num_buf_ = allocator.Malloc(1 * sizeof(size_t)); -+ compress_buf_ = allocator.Malloc(compress_buffer_len * sizeof(T)); -+ -+ normed_from_tensor_buf_ = (!layernorm_post_ || attention_layer_->GetPositionBias()) ? -+ allocator.Malloc(normed_from_tensor_len * sizeof(T)) : -+ 0; -+ attn_ws_buf_ = allocator.Malloc(attention_layer_->GetWorkspaceSize()); -+ attention_layer_->SetWSOffset(attn_ws_buf_); -+ attn_out_buf_ = allocator.Malloc(attn_out_len * sizeof(T)); -+ -+ allocator.Free(d_token_num_buf_); -+ if (use_past_) -+ allocator.Free(d_sequence_lengths_offset_buf_); -+ if (use_past_) -+ allocator.Free(d_sequence_lengths_offset_buf2_); -+ -+ allocator.Free(attn_ws_buf_); -+ if (!layernorm_post_ || attention_layer_->GetPositionBias()) -+ allocator.Free(normed_from_tensor_buf_); -+ normed_attn_out_buf_ = ((!layernorm_post_ || attention_layer_->GetPositionBias()) || is_ffn_fp16_) ? -+ is_ffn_fp16_ ? allocator.Malloc(normed_attn_out_len * sizeof(half)) : -+ allocator.Malloc(normed_attn_out_len * sizeof(T)) : -+ 0; -+ if (is_moe_) { -+ size_t moe_ws_size = moe_->GetWorkspaceSize(); -+ size_t moe_offset = allocator.Malloc(moe_ws_size); -+ moe_->SetWSOffset(moe_offset); -+ } -+ else { -+ moe_ = nullptr; -+ ffn_ws_buf_ = allocator.Malloc(ffn_layer_->GetWorkspaceSize()); -+ ffn_layer_->SetWSOffset(ffn_ws_buf_); -+ } -+ tmp_out_buf_ = -+ is_ffn_fp16_ ? 
allocator.Malloc(tmp_out_size * sizeof(half)) : allocator.Malloc(tmp_out_size * sizeof(T)); -+ if (is_ffn_fp16_) -+ tmp_out1_buf_ = allocator.Malloc(tmp_out_size * sizeof(T)); -+ size_t size = allocator.total_size(); -+ return size; -+} -+template -+void Encoder::GetCompressBuffer(T* compress_buffer, -+ T* from_tensor, -+ int* d_sequence_lengths, -+ int* padding_offset, -+ size_t& h_token_num, -+ size_t* d_token_num, -+ cudaStream_t stream) -+{ -+ int old_h_token_num = h_token_num; -+ invokeGetPaddingOffset( -+ &h_token_num, d_token_num, padding_offset, d_sequence_lengths, batch_size_, src_seq_len_, stream); -+ if (h_token_num * 2 <= batch_size_ * src_seq_len_ || use_past_) { -+ invokeRemovePadding(compress_buffer, (const T*)from_tensor, padding_offset, h_token_num, hidden_size_, stream); -+ } -+ else -+ h_token_num = old_h_token_num; -+} -+template -+void Encoder::InitUsePast(std::vector& inputs, T*& from_tensor, void* ws, cudaStream_t stream) -+{ -+ int in_len = inputs.size(); -+ int k_cache_idx = in_idx_++; -+ int v_cache_idx = in_idx_++; -+ int vsl_inputs_idx = in_len - 4; -+ int position_idx = in_len - 5; -+ int emmbeding_pos_idx = in_len - 6; -+ int emmbeding_idx = in_len - 7; -+ -+ k_cache_ = reinterpret_cast(inputs[k_cache_idx]); -+ v_cache_ = reinterpret_cast(inputs[v_cache_idx]); -+ -+ size_t* d_token_num = GetBuf(ws, d_token_num_buf_); -+ T* compress_buffer = GetBuf(ws, compress_buf_); -+ int* d_sequence_lengths = reinterpret_cast(inputs[vsl_inputs_idx]); -+ int* d_sequence_lengths2 = reinterpret_cast(inputs[vsl_inputs_idx + 1]); -+ int* padding_offset = reinterpret_cast(inputs[vsl_inputs_idx + 2]); -+ size_t* input_h_token_num = reinterpret_cast(inputs[vsl_inputs_idx + 3]); -+ -+ int h_input_position = 0; -+ int* input_position = reinterpret_cast(inputs[position_idx]); -+ cudaMemcpyAsync(&h_input_position, input_position, sizeof(h_input_position), cudaMemcpyDeviceToHost, stream); -+ cudaStreamSynchronize(stream); -+ -+ cudaMemcpyAsync(&h_token_num_, input_h_token_num, sizeof(size_t), cudaMemcpyDeviceToHost, stream); -+ cudaStreamSynchronize(stream); -+ if (h_input_position == 0) { -+ const size_t size = batch_size_ * src_seq_len_ * head_num_ * head_size_; -+ T* k_cache = reinterpret_cast(k_cache_); -+ T* v_cache = reinterpret_cast(v_cache_); -+ cudaMemsetAsync(k_cache, 0, size * sizeof(T), stream); -+ cudaMemsetAsync(v_cache, 0, size * sizeof(T), stream); -+ attention_layer_->SetCache(k_cache_, v_cache_); -+ attention_layer_->SetIncrementalMode(false); -+ } -+ else { -+ attention_layer_->SetIncrementalMode(true); -+ } -+ -+ attention_layer_->SetVslParam(padding_offset, padding_offset, d_sequence_lengths, d_sequence_lengths2); -+ -+ if (first_layer_) { -+ T* input_after_emmbeding = compress_buffer; -+ T* emmbeding_table = reinterpret_cast(inputs[emmbeding_idx]); -+ T* emmbeding_pos_table = reinterpret_cast(inputs[emmbeding_pos_idx]); -+ int* input_position = reinterpret_cast(inputs[position_idx]); -+ invokeEmbeddingPanguSigma(const_cast(reinterpret_cast(from_tensor)), -+ const_cast(input_position), -+ const_cast(emmbeding_table), -+ const_cast(emmbeding_pos_table), -+ input_after_emmbeding, -+ h_token_num_, -+ hidden_size_, -+ stream); -+ from_tensor = input_after_emmbeding; -+ } -+ if (query_layer_) { -+ T* emmbeding_table = reinterpret_cast(inputs[in_len - 4]); -+ invokeVocabEmbedding(const_cast(reinterpret_cast(inputs[5])), -+ const_cast(emmbeding_table), -+ reinterpret_cast(compress_buffer), -+ h_token_num_, -+ hidden_size_, -+ stream); -+ inputs[5] = compress_buffer; -+ } 
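// --- Editorial sketch (not part of the original patch) -----------------------
// GetCompressBuffer() above only packs the batch when at least half of the
// batch_size * seq_len grid is padding (h_token_num * 2 <= batch * seq_len) or
// when use_past_ is set. The host code below mirrors what the
// invokeGetPaddingOffset / invokeRemovePadding kernels compute: for valid token
// i, padding_offset[i] is the number of padded slots that precede it, so its
// padded row index is i + padding_offset[i]. Simplified CPU illustration only;
// the helper names here are hypothetical.
#include <cstddef>
#include <vector>

static size_t BuildPaddingOffsetSketch(const std::vector<int>& seq_lens, int max_seq_len,
                                       std::vector<int>* padding_offset) {
    padding_offset->clear();
    int removed = 0;                               // padded slots seen so far
    for (size_t b = 0; b < seq_lens.size(); ++b) {
        for (int t = 0; t < seq_lens[b]; ++t) padding_offset->push_back(removed);
        removed += max_seq_len - seq_lens[b];
    }
    return padding_offset->size();                 // h_token_num = sum of lengths
}

template <typename T>
static void RemovePaddingSketch(T* compact, const T* padded,
                                const std::vector<int>& padding_offset, size_t hidden) {
    for (size_t i = 0; i < padding_offset.size(); ++i) {
        const T* row = padded + (i + padding_offset[i]) * hidden;
        for (size_t h = 0; h < hidden; ++h) compact[i * hidden + h] = row[h];
    }
}
// invokeRebuildPadding at the end of forward() performs the inverse scatter back
// into the padded [batch * seq_len, hidden] layout.
// -----------------------------------------------------------------------------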
-+ -+ SetHTokenNum(h_token_num_, h_token_num_); -+} -+ -+template -+void Encoder::ForwardAttention(std::vector& inputs, -+ std::vector& from_tensor, -+ std::vector& output, -+ void* ws, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream) -+{ -+ if (query_layer_) { -+ inputs[in_idx_ - 1] = inputs[in_idx_]; -+ inputs[in_idx_] = from_tensor[0]; -+ in_idx_--; -+ } -+ else { -+ inputs[--in_idx_] = from_tensor[0]; -+ } -+ bool is_projection_bias = attention_layer_->GetProjectionBias(); -+ attention_layer_->SetProjectionBias(false); -+ std::vector attn_in_vector(inputs.begin() + in_idx_, inputs.end()); -+ -+ attention_layer_->forward(attn_in_vector, output, ws, cublas_handle, stream); -+ in_idx_ = attention_layer_->GetIdx() + in_idx_; -+ attention_layer_->SetProjectionBias(is_projection_bias); -+ -+ nvinfer1::DataType type = (std::is_same::value) ? nvinfer1::DataType::kFLOAT : nvinfer1::DataType::kHALF; -+ if (all_reduce_sum_func_ != nullptr) { -+ all_reduce_sum_func_(output[0], output[0], h_token_num_ * hidden_size_, type, stream); -+ } -+} -+ -+template -+void Encoder::ForwardFfn(std::vector& inputs, -+ const std::vector& output, -+ void* ws, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream) -+{ -+ nvinfer1::DataType type = (std::is_same::value) ? nvinfer1::DataType::kFLOAT : nvinfer1::DataType::kHALF; -+ std::vector ffn_in_vector{inputs.begin() + in_idx_, inputs.end()}; -+ std::vector ffn_output = output; -+ if (is_moe_) { -+ if constexpr (std::is_same::value) { -+ moe_->SetPaddingOffsetDevice(attention_layer_->GetPaddingOffset()); -+ moe_->SetSeqLenHost(seq_len_host_); -+ moe_->SetSeqLenDevice(attention_layer_->GetSequenceLength()); -+ moe_->SetParallelFunc(all_gather_func_, all_reduce_sum_func_); -+ moe_->forward(ffn_in_vector, output, ws, cublas_handle, stream); -+ in_idx_ = in_idx_ + moe_->GetIdx(); -+ } -+ else { -+ std::cout << "moe support only half" << std::endl; -+ } -+ } -+ else { -+ std::vector ffn_in_vector(inputs.begin() + in_idx_, inputs.end()); -+ ffn_layer_->forward(ffn_in_vector, output, ws, cublas_handle, stream); -+ in_idx_ = in_idx_ + ffn_layer_->GetIdx(); -+ if (all_reduce_sum_func_ != nullptr) { -+ all_reduce_sum_func_(output[0], output[0], h_token_num_ * hidden_size_, type, stream); -+ } -+ } -+} -+template -+void Encoder::AddBiasResidual(std::vector& inputs, const std::vector& output, cudaStream_t stream) -+{ -+ if (std::is_same::value || !is_ffn_fp16_) { -+ invokeAddBiasResidual(reinterpret_cast(output[0]), -+ reinterpret_cast(inputs[0]), -+ reinterpret_cast(inputs[1]), -+ h_token_num_, -+ hidden_size_, -+ stream); -+ } -+ else { -+ if (layernorm_post_) { -+ invokeAddBiasResidualSameTypeCast(reinterpret_cast(output[0]), -+ reinterpret_cast(inputs[0]), -+ reinterpret_cast(output[1]), -+ reinterpret_cast(inputs[1]), -+ h_token_num_, -+ hidden_size_, -+ stream); -+ } -+ else { -+ invokeAddBiasResidualCast(reinterpret_cast(output[0]), -+ reinterpret_cast(inputs[0]), -+ reinterpret_cast(output[1]), -+ reinterpret_cast(inputs[1]), -+ h_token_num_, -+ hidden_size_, -+ stream); -+ } -+ } -+} -+template -+void Encoder::MulEmbeddingTable(std::vector& inputs, -+ const std::vector& output, -+ cublasHandle_t cublas_handle) -+{ -+ float alpha = 1.0f; -+ float beta = 0.0f; -+ int gemm_dims[] = {(int)embedding_size_, (int)batch_size_, (int)hidden_size_}; -+ int gemm_lds[] = {(int)hidden_size_, (int)hidden_size_, (int)embedding_size_}; -+ cublasOperation_t gemm_ops[] = {CUBLAS_OP_T, CUBLAS_OP_N}; -+ cudaDataType gemm_data_types[] = {CUDA_R_32F, CUDA_R_32F, CUDA_R_32F}; -+ 
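// --- Editorial sketch (not part of the original patch) -----------------------
// MulEmbeddingTable() projects the final hidden states onto the embedding table
// (an output projection over the embedding dimension): with dims
// {embedding_size, batch, hidden}, lds {hidden, hidden, embedding_size} and ops
// {CUBLAS_OP_T, CUBLAS_OP_N}, the column-major cublasGemmEx call computes, in
// row-major terms,
//   logits[b][v] = sum_h hidden_states[b][h] * embedding_table[v][h].
// The reference loop below states that contract on the CPU; it is an
// interpretation of the call above, not code from the patch.
#include <cstddef>

static void MulEmbeddingTableReference(const float* embedding_table,  // [embedding_size][hidden]
                                       const float* hidden_states,    // [batch][hidden]
                                       float* logits,                 // [batch][embedding_size]
                                       size_t batch, size_t hidden, size_t embedding_size) {
    for (size_t b = 0; b < batch; ++b) {
        for (size_t v = 0; v < embedding_size; ++v) {
            float acc = 0.f;
            for (size_t h = 0; h < hidden; ++h)
                acc += hidden_states[b * hidden + h] * embedding_table[v * hidden + h];
            logits[b * embedding_size + v] = acc;
        }
    }
}
// -----------------------------------------------------------------------------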
if (std::is_same::value) { -+ gemm_data_types[0] = CUDA_R_16F; -+ gemm_data_types[1] = CUDA_R_16F; -+ gemm_data_types[2] = CUDA_R_16F; -+ } -+ CublasGemmWrapper(reinterpret_cast(inputs[in_idx_++]), -+ reinterpret_cast(output[0]), -+ output[1], -+ gemm_dims, -+ gemm_lds, -+ gemm_ops, -+ gemm_data_types, -+ &alpha, -+ &beta, -+ cublas_handle, -+ algo_); -+} -+template -+void Encoder::forward(std::vector& inputs, -+ const std::vector& output, -+ void* ws, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream) -+{ -+ ws = GetBuf(ws, ws_offset_); -+ int in_len = inputs.size(); -+ in_idx_ = 0; -+ std::vector encoder_in = inputs; -+ size_t h_token_num = batch_size_ * src_seq_len_; -+ SetHTokenNum(h_token_num, h_token_num); -+ T* input_tensor = reinterpret_cast(encoder_in[in_idx_++]); -+ -+ T* from_tensor = input_tensor; -+ int* d_sequence_lengths = reinterpret_cast(inputs[in_len - 1]); -+ int* d_sequence_lengths2 = d_sequence_lengths; -+ T* compress_buffer = GetBuf(ws, compress_buf_); -+ int* padding_offset = GetBuf(ws, padding_offset_buf_); -+ size_t* d_token_num = GetBuf(ws, d_token_num_buf_); -+ const int batch = batch_size_; -+ -+ bool is_T5 = attention_layer_->GetPositionBias(); -+ SetT5(is_T5); -+ k_cache_ = nullptr; -+ v_cache_ = nullptr; -+ attention_layer_->SetVslParam(nullptr, nullptr, nullptr, nullptr); -+ if (use_past_) { -+ InitUsePast(encoder_in, from_tensor, ws, stream); -+ } -+ else { -+ if (eft_) { -+ GetCompressBuffer( -+ compress_buffer, from_tensor, d_sequence_lengths, padding_offset, h_token_num, d_token_num, stream); -+ if (batch > 1) { -+ if (h_token_num * 2 <= batch_size_ * src_seq_len_) { -+ SetHTokenNum(h_token_num, h_token_num); -+ from_tensor = compress_buffer; -+ attention_layer_->SetVslParam( -+ padding_offset, padding_offset, d_sequence_lengths, d_sequence_lengths); -+ } -+ } -+ else { -+ SetHTokenNum(h_token_num, h_token_num); -+ } -+ } -+ } -+ h_token_num = h_token_num_; -+ bool is_ffn_write_to_output = !(is_ffn_fp16_ || is_layernorm_ || query_layer_); -+ T* attn_out = GetBuf(ws, attn_out_buf_); -+ T* normed_attn_out = GetBuf(ws, normed_attn_out_buf_); -+ T* tmp_out = reinterpret_cast(output[0]); -+ if (!is_ffn_write_to_output) { -+ // in the case that ffn not write directly to output, allocate a diffrent tensor -+ tmp_out = GetBuf(ws, tmp_out_buf_); -+ } -+ T* tmp_out1 = reinterpret_cast(output[0]); -+ T* tmp_out2 = reinterpret_cast(output[0]); -+ if (is_layernorm_ -+ && (attention_layer_->GetPaddingOffset() != nullptr || (std::is_same::value && is_ffn_fp16_)) -+ && !use_past_) { -+ tmp_out1 = GetBuf(ws, tmp_out1_buf_); -+ if (attention_layer_->GetPaddingOffset() != nullptr) { -+ tmp_out2 = compress_buffer; -+ } -+ } -+ else if (attention_layer_->GetPaddingOffset() != nullptr && is_ffn_fp16_) { -+ tmp_out1 = compress_buffer; -+ } -+ T* query_out = reinterpret_cast(output[0]); -+ if (attention_layer_->GetPaddingOffset() != nullptr && query_layer_ -+ && (!attention_layer_->GetIncrementalMode() || h_token_num_ < batch_size_)) { -+ query_out = GetBuf(ws, tmp_out1_buf_); -+ } -+ T* out_buf = tmp_out; -+ -+ // Step I - Do Pre Layer Norm -+ T* normed_from_tensor = from_tensor; -+ layer_norm1_->SetParams(h_token_num_, hidden_size_); -+ layer_norm2_->SetParams(h_token_num_, hidden_size_); -+ if (layernorm_post_ == false || is_T5) { -+ T* gamma1 = reinterpret_cast(encoder_in[in_idx_++]); -+ T* beta1 = (has_beta_) ? 
reinterpret_cast(encoder_in[in_idx_++]) : nullptr; -+ normed_from_tensor = GetBuf(ws, normed_from_tensor_buf_); -+ std::vector in = {from_tensor, gamma1, beta1}; -+ std::vector normed_out = {normed_from_tensor}; -+ layer_norm1_->forward(in, normed_out, ws, cublas_handle, stream); -+ } -+ -+ std::vector attn_from_vector{normed_from_tensor}; -+ std::vector attn_out_vector{attn_out}; -+ ForwardAttention(encoder_in, attn_from_vector, attn_out_vector, ws, cublas_handle, stream); -+ T* projection_bias = -+ (attention_layer_->GetProjectionBias()) ? reinterpret_cast(encoder_in[in_idx_++]) : nullptr; -+ T* gamma2 = reinterpret_cast(encoder_in[in_idx_++]); -+ T* beta2 = (has_beta_) ? reinterpret_cast(encoder_in[in_idx_++]) : nullptr; -+ if (layernorm_post_ == false || is_T5) { -+ // setup skip connection -+ from_tensor = (layernorm_post_) ? normed_from_tensor : from_tensor; -+ std::vector in2 = {from_tensor, gamma2, beta2, projection_bias}; -+ std::vector out2 = {attn_out, normed_attn_out}; -+ layer_norm2_->forward(in2, out2, ws, cublas_handle, stream); -+ } -+ else { -+ std::vector in2 = {from_tensor, gamma2, beta2, projection_bias}; -+ std::vector out2 = {attn_out, normed_attn_out}; -+ layer_norm1_->forward(in2, out2, ws, cublas_handle, stream); -+ if (!is_ffn_fp16_) { -+ normed_attn_out = attn_out; -+ } -+ } -+ -+ encoder_in[--in_idx_] = normed_attn_out; -+ T* ffn_out_tensor = tmp_out; -+ std::vector ffn_out_vector{ffn_out_tensor}; -+ ForwardFfn(encoder_in, ffn_out_vector, ws, cublas_handle, stream); -+ T* ffn_bias = (ffn_layer_->GetffnBias() && !is_moe_) ? reinterpret_cast(encoder_in[in_idx_++]) : nullptr; -+ T* ffn_norm_tensor = (is_ffn_fp16_) ? tmp_out1 : ffn_out_tensor; -+ if (layernorm_post_ == true && !is_T5) { -+ T* gamma3 = reinterpret_cast(encoder_in[in_idx_++]); -+ T* beta3 = (has_beta_) ? reinterpret_cast(encoder_in[in_idx_++]) : nullptr; -+ attn_out = (is_ffn_fp16_) ? normed_attn_out : attn_out; -+ std::vector in3 = {attn_out, gamma3, beta3, ffn_bias}; -+ std::vector out3 = {ffn_out_tensor, tmp_out1}; -+ layer_norm2_->forward(in3, out3, ws, cublas_handle, stream); -+ } -+ else { -+ attn_out = (layernorm_post_) ? normed_attn_out : attn_out; -+ std::vector add_residual_in_vector{attn_out, ffn_bias}; -+ std::vector add_residual_out_vector{ffn_out_tensor, tmp_out1}; -+ AddBiasResidual(add_residual_in_vector, add_residual_out_vector, stream); -+ } -+ T* last_norm_Tensor = ffn_norm_tensor; -+ if (is_layernorm_) { -+ T* gamma4 = reinterpret_cast(encoder_in[in_idx_++]); -+ T* beta4 = (has_beta_) ? reinterpret_cast(encoder_in[in_idx_++]) : nullptr; -+ std::vector in4 = {ffn_norm_tensor, gamma4, beta4}; -+ std::vector out4 = {tmp_out2}; -+ layer_norm3_->SetParams(h_token_num_, hidden_size_); -+ layer_norm3_->forward(in4, out4, ws, cublas_handle, stream); -+ last_norm_Tensor = tmp_out2; -+ } -+ if (query_layer_) { -+ if (attention_layer_->GetIncrementalMode() && h_token_num_ == batch_size_) { -+ query_out = last_norm_Tensor; -+ } -+ else { -+ invokeRebuildQuery(query_out, -+ last_norm_Tensor, -+ attention_layer_->GetSequenceLength(), -+ attention_layer_->GetIncrementalMode() ? 
h_token_num_ : batch_size_, -+ hidden_size_, -+ stream); -+ } -+ std::vector mul_out = {query_out, output[0]}; -+ MulEmbeddingTable(encoder_in, mul_out, cublas_handle); -+ } -+ if (attention_layer_->GetPaddingOffset() != nullptr && !use_past_) { -+ int size = batch_size_ * src_seq_len_ * hidden_size_; -+ cudaMemsetAsync(output[0], 0, size * sizeof(T), stream); -+ invokeRebuildPadding( -+ (T*)output[0], last_norm_Tensor, attention_layer_->GetPaddingOffset(), h_token_num_, hidden_size_, stream); -+ } -+ return; -+} -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/encoder.h b/src/fastertransformer/layers/ms_layers/encoder.h -new file mode 100644 -index 0000000..881fc48 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/encoder.h -@@ -0,0 +1,402 @@ -+#pragma once -+ -+#include "src/fastertransformer/kernels/activation_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/BaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/attention.h" -+#include "src/fastertransformer/layers/ms_layers/ffn.h" -+#include "src/fastertransformer/layers/ms_layers/layer_norm.h" -+#include "src/fastertransformer/layers/ms_layers/MoeFfnLayer.h" -+ -+#include -+#include -+ -+namespace fastertransformer { -+class EncoderBase: public BaseLayerMS { -+public: -+ EncoderBase(size_t batch_size, -+ size_t seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ int rank_num = 1, -+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP):BaseLayerMS(batch_size, seq_len, seq_len, head_num, head_size, hidden_size, rank_num, algo){} -+ -+ virtual void SetFfnParam(bool ffn_fp16, size_t ffn_hidden_size, FfnBase::ActType act_type = FfnBase::ActType::Gelu, bool ffn_bias = true) = 0; -+ virtual void SetMoeParam(bool is_moe, size_t expert_num = 0, size_t expert_offset = 0, size_t capacity_factor = 0, FfnBase::ActType act_type = FfnBase::ActType::Gelu) = 0; -+ virtual void SetParallelFunc(BaseLayerMS::allGatherFunc all_gather_func, BaseLayerMS::allReduceSumFunc all_reduce_sum_func) override = 0; -+ virtual void SetRankNum(int rank_num) override = 0; -+ virtual void SetRankId(int rank_id) override = 0; -+ virtual void SetScaleAttn(float scale = 1.0f) = 0; -+ virtual void SetIsLayerNorm(bool is_layernorm, float eps = 1e-6f) = 0; -+ virtual void SetLayerNormPost(bool layernorm_post) = 0; -+ virtual void SetVSL(bool eft) = 0; -+ virtual void SetT5(bool t5) = 0; -+ virtual void SetEps(float eps1, float eps2, float eps3) = 0; -+ virtual void SetUsePast(bool use_past) = 0; -+ virtual void SetAlgo(cublasGemmAlgo_t algo) override = 0; -+ virtual void SetHTokenNum(size_t h_token_num, size_t h_token_num2 = -1) = 0; -+ virtual void SetCache(void* k_cache, void* v_cache) = 0; -+ virtual void SetQueryLayer(bool query_layer) = 0; -+ virtual void SetEmmbedingSize(size_t embedding_size) = 0; -+ virtual void SetFirstLayer(bool first_layer) = 0; -+ virtual void SetRankParam(int rank_num = 0, int rank_id = 0) = 0; -+}; -+template -+class Encoder: public EncoderBase { -+private: -+ std::shared_ptr> attention_layer_; -+ std::shared_ptr ffn_layer_; -+ std::shared_ptr> layer_norm1_; -+ std::shared_ptr> layer_norm2_; -+ std::shared_ptr> layer_norm3_; -+void GetCompressBuffer(T* compress_buffer, -+ T* from_tensor, -+ int* d_sequence_lengths, -+ int* padding_offset, -+ size_t& h_token_num, -+ size_t* d_token_num, -+ cudaStream_t stream); -+void InitUsePast(std::vector &inputs, T* &from_tensor, void *ws, cudaStream_t stream); -+void ForwardAttention(std::vector &inputs, 
std::vector &from_tensor, std::vector &output, void *ws, cublasHandle_t cublas_handle, cudaStream_t stream); -+void ForwardFfn(std::vector &inputs, const std::vector &output, void *ws, cublasHandle_t cublas_handle, cudaStream_t stream); -+void AddBiasResidual(std::vector &inputs, const std::vector &output, cudaStream_t stream); -+void MulEmbeddingTable(std::vector &inputs, const std::vector &output, cublasHandle_t cublas_handle); -+ -+ -+ bool use_past_; // use past mode -+ bool query_layer_; // check if quary layer -+ -+ size_t data_parallel_{false}; -+ int cur_token_id_{0}; // current token id id -+ bool eft_{false}; -+ size_t h_token_num_; -+ size_t h_token_num2_; -+ void* k_cache_{nullptr}; -+ void* v_cache_{nullptr}; -+ int* seq_len_host_{nullptr}; -+ bool layernorm_post_; -+ bool has_beta_; -+ bool is_layernorm_; -+ bool ffn_fp16_; -+ int embedding_size_; -+ size_t normed_from_tensor_buf_{0}; -+ size_t attn_ws_buf_{0}; -+ size_t attn_out_buf_{0}; -+ size_t normed_attn_out_buf_{0}; -+ size_t ffn_ws_buf_{0}; -+ size_t tmp_out_buf_{0}; -+ size_t tmp_out1_buf_{0}; -+ size_t compress_buf_{0}; -+ size_t compress_buf2_{0}; -+ size_t d_token_num_buf_{0}; -+ size_t padding_offset_buf_{0}; -+ size_t padding_offset_buf2_{0}; -+ size_t d_sequence_lengths_offset_buf_{0}; -+ size_t d_sequence_lengths_offset_buf2_{0}; -+ -+ size_t norm_out_buf_{0}; -+ bool is_ffn_fp16_{false}; -+ bool is_moe_{0}; -+ std::shared_ptr moe_; -+ size_t expert_num_{0}; -+ size_t expert_offset_{0}; -+ size_t capacity_factor_{0}; -+ -+ bool first_layer_{0}; -+ -+public: -+ void printParam() -+ { -+ std::cout<<"batch_size = "< -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+#include "src/fastertransformer/utils/gemm_test/gemm_func.cc" -+#include "src/fastertransformer/layers/ms_layers/opt_allocator.h" -+namespace fastertransformer { -+ -+template -+size_t Ffn::GetWorkspaceSize() -+{ -+ size_t ffn_len = -+ batch_size_ * src_seq_len_ * ffn_hidden_size_; -+ OptAllocator allocator(ALIGN_SIZE); -+ allocator.Malloc(ffn_len * sizeof(T)); -+ return allocator.total_size(); -+} -+ -+template -+void Ffn::forward(std::vector &inputs, const std::vector &outputs, void *ws, cublasHandle_t cublas_handle, cudaStream_t stream) -+{ -+ in_idx_ = 0; -+ ws = GetBuf(ws, ws_offset_); -+ size_t inter_size = ffn_hidden_size_; -+ size_t h_token_num = h_token_num_; -+ cublasOperation_t gemm_ops[] = {CUBLAS_OP_N, CUBLAS_OP_N}; -+ cudaDataType gemm_data_types[] = {CUDA_R_32F, CUDA_R_32F, CUDA_R_32F}; -+ if (std::is_same::value) { -+ gemm_data_types[0] = CUDA_R_16F; -+ gemm_data_types[1] = CUDA_R_16F; -+ gemm_data_types[2] = CUDA_R_16F; -+ } -+ float alpha = 1.0f; -+ float beta = 0.0f; -+ -+ int gemm_dims[] = {(int)inter_size, (int)h_token_num, (int)hidden_size_}; -+ int gemm_lds[] = {(int)inter_size, (int)hidden_size_, (int)inter_size}; -+ T* normed_attn_out = reinterpret_cast(inputs[in_idx_++]); -+ T* from = reinterpret_cast(inputs[in_idx_++]); -+ T* bias = (ffn_bias_) ? 
reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ -+ CublasGemmWrapper(from, -+ normed_attn_out, -+ ws, -+ gemm_dims, -+ gemm_lds, -+ gemm_ops, -+ gemm_data_types, -+ &alpha, -+ &beta, -+ cublas_handle, -+ algo_); -+ if (act_type_ == ActType::Gelu) { -+ invokeAddBiasGelu(reinterpret_cast(ws), bias, h_token_num, inter_size, stream); -+ } else if (act_type_ == ActType::FastGelu) { -+ invokeAddBiasFastGelu(reinterpret_cast(ws), bias, h_token_num, inter_size, stream); -+ } else if (act_type_ == ActType::Relu) { -+ invokeAddBiasRelu(reinterpret_cast(ws), bias, h_token_num, inter_size, stream); -+ } -+ else if (ffn_bias_ && act_type_ == ActType::No) { -+ invokeAddBias(reinterpret_cast(ws), bias, h_token_num, inter_size, stream); -+ } -+ gemm_dims[0] = hidden_size_; -+ gemm_dims[1] = h_token_num; -+ gemm_dims[2] = inter_size; -+ gemm_lds[0] = hidden_size_; -+ gemm_lds[1] = inter_size; -+ gemm_lds[2] = hidden_size_; -+ CublasGemmWrapper(reinterpret_cast(inputs[in_idx_++]), -+ ws, -+ outputs[0], -+ gemm_dims, -+ gemm_lds, -+ gemm_ops, -+ gemm_data_types, -+ &alpha, -+ &beta, -+ cublas_handle, -+ algo_); -+ -+ -+} -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/ffn.h b/src/fastertransformer/layers/ms_layers/ffn.h -new file mode 100644 -index 0000000..10bf6dc ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/ffn.h -@@ -0,0 +1,132 @@ -+#pragma once -+ -+#include "src/fastertransformer/kernels/activation_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/MSBaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/BaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+#include -+#include -+ -+namespace fastertransformer { -+class FfnBase : public BaseLayerMS { -+public: -+ enum ActType { -+ No = 0, -+ Relu = 1, -+ Sigmoid = 2, -+ Relu6 = 3, -+ Elu = 4, -+ LeakyRelu = 5, -+ Abs = 6, -+ Relu1 = 7, -+ Softsign = 8, -+ Softplus = 9, -+ Tanh = 10, -+ Selu = 11, -+ HSwish = 12, -+ HSigmoid = 13, -+ ThresholdRelu = 14, -+ Linear = 15, -+ HardTanh = 16, -+ Sign = 17, -+ Swish = 18, -+ Gelu = 19, -+ FastGelu = 20, -+ Unknown = 21 -+ }; -+protected: -+ bool ffn_bias_; -+ size_t ffn_hidden_size_; -+ ActType act_type_; -+ size_t h_token_num_; -+ -+public: -+ void SetFfnHiddenSize(size_t ffn_hidden_size) -+ { -+ ffn_hidden_size_ = ffn_hidden_size; -+ } -+ void SetActType(ActType act_type) -+ { -+ act_type_ = act_type; -+ } -+ void SetffnBias(bool ffn_bias) -+ { -+ ffn_bias_ = ffn_bias; -+ } -+ size_t GetFfnHiddenSize() -+ { -+ return ffn_hidden_size_; -+ } -+ ActType GetActType() -+ { -+ return act_type_; -+ } -+ bool GetffnBias() -+ { -+ return ffn_bias_; -+ } -+ void SetHTokenNum(size_t h_token_num) -+ { -+ h_token_num_ = h_token_num; -+ } -+ void printParam() -+ { -+ std::cout<<"ffn param:\n"; -+ std::cout<<"batch_size = "< -+ -+namespace fastertransformer { -+ -+#define __ARCH__ 80 -+#ifdef __CUDA_ARCH__ -+#undef __CUDA_ARCH_HOST__ -+#define __CUDA_ARCH_HOST__ __CUDA_ARCH__ -+#endif -+#ifndef __CUDA_ARCH_HOST__ -+ #ifndef __CUDA_ARCH__ -+ #error "Need cuda arch at least 5.0" -+ #else -+ #define __CUDA_ARCH_HOST__ __CUDA_ARCH__ -+ #endif -+#endif -+#if __CUDA_ARCH_HOST__ < 750 -+ #undef __ARCH__ -+ #define __ARCH__ 70 -+#elif __CUDA_ARCH_HOST__ < 800 -+ #undef __ARCH__ -+ #define __ARCH__ 75 -+#endif -+ -+#define CONCAT(_x) cutlass::arch::Sm##_x -+ -+#define INSTANTIATE_ATTENTION_KERNEL_INSTANCE_ENABLE( \ -+ ARCH, SCALAR_T, IS_ALIGNED, QUERIES_PER_BLOCK, KEYS_PER_BLOCK, SINGLE_VALUE_ITER) \ -+ AttentionKernel -+ 
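// --- Editorial sketch (not part of the original patch) -----------------------
// The preprocessor block above buckets the compile-time CUDA architecture into
// the cutlass arch tag used to instantiate the fused attention kernel
// (CONCAT(_x) expands to cutlass::arch::Sm##_x): anything below SM75 uses Sm70,
// below SM80 uses Sm75, and SM80 or newer uses Sm80. The template arguments of
// INSTANTIATE_ATTENTION_KERNEL_INSTANCE_ENABLE were lost in this rendering of
// the patch; the constexpr helper below only restates the bucketing logic.
constexpr int SelectFmhaArchSketch(int cuda_arch) {  // cuda_arch as in __CUDA_ARCH__, e.g. 750
    return cuda_arch < 750 ? 70 : (cuda_arch < 800 ? 75 : 80);
}
static_assert(SelectFmhaArchSketch(700) == 70, "pre-Turing falls back to Sm70 kernels");
static_assert(SelectFmhaArchSketch(750) == 75, "Turing uses Sm75 kernels");
static_assert(SelectFmhaArchSketch(860) == 80, "Ampere and newer use Sm80 kernels");
// -----------------------------------------------------------------------------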
-+#define INSTANTIATE_ATTENTION_KERNEL_PARAM(__p__, __in0__, __in1__, __in2__, __in3__, __in4__, __out__) \ -+ { \ -+ p.query_ptr = __in0__; \ -+ p.key_ptr = __in1__; \ -+ p.value_ptr = __in2__; \ -+ p.attn_mask_ptr = __in3__; \ -+ p.attn_bias_ptr = __in4__; \ -+ p.cu_seqlens_q_ptr = d_sequence_length_; \ -+ p.cu_seqlens_k_ptr = d_sequence_length2_; \ -+ p.logsumexp_ptr = nullptr; \ -+ p.output_accum_ptr = nullptr; \ -+ p.output_ptr = __out__; \ -+ p.num_heads = head_num_; \ -+ p.num_batches = batch_size_; \ -+ p.head_dim = head_size_; \ -+ p.head_dim_value = head_size_; \ -+ p.num_queries =src_seq_len_; \ -+ p.num_keys = tgt_seq_len_; \ -+ p.scale = scale_; \ -+ p.causal = false; \ -+ p.no_bias_head_dim = (is_cross_ && position_bias_); \ -+ p.q_strideM = head_size_; \ -+ p.k_strideM = head_size_; \ -+ p.v_strideM = head_size_; \ -+ p.attn_mask_strideM = tgt_seq_len_; \ -+ p.attn_bias_strideM = tgt_seq_len_; \ -+ p.q_strideH = p.q_strideM *src_seq_len_; \ -+ p.k_strideH = p.k_strideM * tgt_seq_len_; \ -+ p.v_strideH = p.v_strideM * tgt_seq_len_; \ -+ p.o_strideH = head_size_; \ -+ p.attn_mask_strideH = p.attn_mask_strideM *src_seq_len_; \ -+ p.attn_bias_strideH = (p.no_bias_head_dim) ? 0 : p.attn_mask_strideH; \ -+ p.q_strideB = p.q_strideH * head_num_; \ -+ p.k_strideB = p.k_strideH * head_num_; \ -+ p.v_strideB = p.v_strideH * head_num_; \ -+ p.o_strideB = \ -+ src_seq_len_ * head_num_ * head_size_; \ -+ p.attn_mask_strideB = p.attn_mask_strideH; \ -+ p.attn_bias_strideB = (p.no_bias_head_dim) ? (p.attn_mask_strideH) : p.attn_bias_strideH * p.head_dim; \ -+ if (use_past_) { \ -+ p.num_queries = h_token_num_; \ -+ p.num_keys = cur_token_id_ + 1; \ -+ p.attn_mask_ptr = nullptr; \ -+ if (!incremental_mode_) { \ -+ p.causal = true; \ -+ } \ -+ } \ -+ p.use_past = use_past_; \ -+ } -+template -+template -+void FusedCutlassMha::forward_fmha__(const std::vector &inputs, const std::vector &output, void *ws, cudaStream_t stream) -+{ -+ ws = GetBuf(ws, ws_offset_); -+ const bool isAligned = true; -+ using Attention = INSTANTIATE_ATTENTION_KERNEL_INSTANCE_ENABLE( -+ __ARCH__, T, isAligned, kQueriesPerBlock, kKeysPerBlock, kSingleValueIteration); -+ typename Attention::Params p; -+ INSTANTIATE_ATTENTION_KERNEL_PARAM(p, inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], output[0]); -+ constexpr auto kernel_fn = attention_kernel_batched_impl; -+ int smem_bytes = sizeof(typename Attention::SharedStorage); -+ if (smem_bytes > 0xc000) { -+ cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); -+ } -+ if (Attention::kNeedsOutputAccumulatorBuffer) { -+ p.output_accum_ptr = reinterpret_cast(ws); -+ } -+ kernel_fn<<>>(p); -+} -+ -+template -+template -+bool FusedCutlassMha::is_fmha_support__() -+{ -+ const bool isAligned = true; -+ using Attention = INSTANTIATE_ATTENTION_KERNEL_INSTANCE_ENABLE( -+ __ARCH__, T, isAligned, kQueriesPerBlock, kKeysPerBlock, kSingleValueIteration); -+ typename Attention::Params p; -+ INSTANTIATE_ATTENTION_KERNEL_PARAM(p, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); -+ return Attention::check_supported(p); -+} -+template -+bool FusedCutlassMha::isSupport() -+{ -+ // Determine kernel configuration based on head size. -+ // If head size is less than or equal to 64, each block operates over 64 queries and -+ // 64 keys, and parital results can be stored in the register file. -+ // If head size is greater than 64, each block operates over 32 queries and 128 keys, -+ // and partial results are stored in shared memory. 
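// --- Editorial sketch (not part of the original patch) -----------------------
// The selection described by the comment above, written out as data: head sizes
// up to 64 use a 64x64 (queries x keys) tile with accumulators kept in the
// register file; larger head sizes use a 32x128 tile and spill partial results
// to shared memory. Whether the single-value-iteration specialization applies
// in the large-tile case (head_size <= 128) is an assumption here, since the
// template arguments were stripped from this rendering of the patch.
struct FmhaTileSketch {
    int queries_per_block;
    int keys_per_block;
    bool single_value_iteration;
};

constexpr FmhaTileSketch SelectFmhaTileSketch(int head_size) {
    return head_size <= 64 ? FmhaTileSketch{64, 64, true}
                           : FmhaTileSketch{32, 128, head_size <= 128};
}
// -----------------------------------------------------------------------------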
-+ -+ if (head_size_ > 64) { -+ static int const kQueriesPerBlock = 32; -+ static int const kKeysPerBlock = 128; -+ if (head_size_ <= kKeysPerBlock) { -+ return is_fmha_support__(); -+ } -+ else { -+ return is_fmha_support__(); -+ } -+ } -+ else { -+ static int const kQueriesPerBlock = 64; -+ static int const kKeysPerBlock = 64; -+ return is_fmha_support__(); -+ } -+} -+ -+template -+template -+size_t FusedCutlassMha::get_fmha_workspace__() -+{ -+ const bool isAligned = true; -+ -+ using Attention = INSTANTIATE_ATTENTION_KERNEL_INSTANCE_ENABLE( -+ __ARCH__, T, isAligned, kQueriesPerBlock, kKeysPerBlock, kSingleValueIteration); -+ typename Attention::Params p; -+ INSTANTIATE_ATTENTION_KERNEL_PARAM(p, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); -+ size_t size = 0; -+ if (Attention::kNeedsOutputAccumulatorBuffer) { -+ size += batch_size_ * src_seq_len_ * head_num_ * head_size_; -+ } -+ return size * sizeof(T); -+} -+ -+template -+size_t FusedCutlassMha::GetWorkspaceSize() -+{ -+ // Determine kernel configuration based on head size. -+ // If head size is less than or equal to 64, each block operates over 64 queries and -+ // 64 keys, and parital results can be stored in the register file. -+ // If head size is greater than 64, each block operates over 32 queries and 128 keys, -+ // and partial results are stored in shared memory. -+ if (head_size_ > 64) { -+ static int const kQueriesPerBlock = 32; -+ static int const kKeysPerBlock = 128; -+ if (head_size_ <= kKeysPerBlock) { -+ return get_fmha_workspace__(); -+ } -+ else { -+ return get_fmha_workspace__(); -+ } -+ } -+ else { -+ static int const kQueriesPerBlock = 64; -+ static int const kKeysPerBlock = 64; -+ return get_fmha_workspace__(); -+ } -+} -+ -+template -+void FusedCutlassMha::forward(std::vector &inputs, const std::vector &outputs, void *ws, cublasHandle_t cublas_handle, cudaStream_t stream) -+{ -+ // Determine kernel configuration based on head size. -+ // If head size is less than or equal to 64, each block operates over 64 queries and -+ // 64 keys, and parital results can be stored in the register file. -+ // If head size is greater than 64, each block operates over 32 queries and 128 keys, -+ // and partial results are stored in shared memory. -+ std::vectormha_in(inputs.size()); -+ std::transform(inputs.begin(), inputs.end(), mha_in.begin(), [](void *x) { return reinterpret_cast(x);}); -+ std::vectormha_out(outputs.size()); -+ std::transform(outputs.begin(), outputs.end(), mha_out.begin(), [](void *x) { return reinterpret_cast(x);}); -+ -+ if (head_size_ > 64) { -+ static int const kQueriesPerBlock = 32; -+ static int const kKeysPerBlock = 128; -+ if (head_size_ <= kKeysPerBlock) { -+ return forward_fmha__(mha_in,mha_out,ws, stream); -+ } -+ else { -+ return forward_fmha__(mha_in, mha_out, ws, stream); -+ } -+ } -+ else { -+ static int const kQueriesPerBlock = 64; -+ static int const kKeysPerBlock = 64; -+ return forward_fmha__(mha_in, mha_out ,ws, stream); -+ } -+} -+ -+} // namespace fastertransformer -\ No newline at end of file -diff --git a/src/fastertransformer/layers/ms_layers/fmha_cutlass.h b/src/fastertransformer/layers/ms_layers/fmha_cutlass.h -new file mode 100644 -index 0000000..1b7eb78 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/fmha_cutlass.h -@@ -0,0 +1,64 @@ -+/* -+ * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. 
-+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+ -+#pragma once -+ -+#include "param.h" -+#include "cutlass/half.h" -+#include "src/fastertransformer/layers/ms_layers/attention.h" -+namespace fastertransformer { -+template -+class FusedCutlassMha: public MhaDispatch { -+ template -+ bool is_fmha_support__(); -+ template -+ size_t get_fmha_workspace__(); -+ template -+ void forward_fmha__(const std::vector &inputs, const std::vector &output, void *ws, cudaStream_t stream); -+public: -+ FusedCutlassMha(size_t batch_size, -+ size_t src_seq_len, -+ size_t tgt_seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ bool is_cross = false, -+ bool position_bias = false, -+ float scale = 1.0f, -+ bool mask = true, -+ bool use_past = false, -+ int rank_num = 1, -+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP) : -+ MhaDispatch(batch_size, -+ src_seq_len, -+ tgt_seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ is_cross, -+ position_bias, -+ scale, -+ mask, -+ use_past, -+ rank_num, -+ algo) {} -+ void forward(std::vector &inputs, const std::vector &output, void *ws, cublasHandle_t cublas_handle = nullptr, cudaStream_t stream = 0) override; -+ bool isSupport(); -+ size_t GetWorkspaceSize(); -+}; -+template class FusedCutlassMha; -+template class FusedCutlassMha; -+} -\ No newline at end of file -diff --git a/src/fastertransformer/layers/ms_layers/gemm.cc b/src/fastertransformer/layers/ms_layers/gemm.cc -new file mode 100644 -index 0000000..f249fa8 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/gemm.cc -@@ -0,0 +1,117 @@ -+ -+#include "src/fastertransformer/layers/ms_layers/gemm.h" -+#include "src/fastertransformer/kernels/activation_kernels.h" -+#include "src/fastertransformer/kernels/unfused_attention_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+#include -+namespace fastertransformer { -+ -+void CublasGemmWrapper(const void* a_addr, -+ const void* b_addr, -+ void* c_addr, -+ const int* params, -+ const int* lds, -+ const cublasOperation_t* operations, -+ const cudaDataType* data_types, -+ void* alpha, -+ void* beta, -+ cublasHandle_t cublas_handle, -+ cublasGemmAlgo_t algo) -+{ -+ const int m = params[0]; -+ const int n = params[1]; -+ const int k = params[2]; -+ cublasOperation_t trans_a = operations[0]; -+ cublasOperation_t trans_b = operations[1]; -+ const int lda = lds[0]; -+ const int ldb = lds[1]; -+ const int ldc = lds[2]; -+ cudaDataType type_a = data_types[0]; -+ cudaDataType type_b = data_types[1]; -+ cudaDataType type_c = data_types[2]; -+ cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; -+ -+ if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) { -+ compute_type = CUBLAS_COMPUTE_32F; -+ } -+ -+ cublasGemmEx(cublas_handle, -+ trans_a, -+ trans_b, -+ m, -+ n, -+ k, -+ alpha, -+ a_addr, -+ type_a, -+ lda, -+ b_addr, -+ type_b, -+ ldb, -+ beta, -+ c_addr, -+ type_c, -+ ldc, -+ compute_type, -+ algo); -+} -+ -+void CublasGemmStridedBatchedWrapper(const void* a_addr, -+ const void* b_addr, -+ void* c_addr, -+ const int* params, -+ const int* lds, -+ const cublasOperation_t* 
operations, -+ const int* strides, -+ const cudaDataType* data_types, -+ void* alpha, -+ void* beta, -+ int batch, -+ cublasHandle_t cublas_handle, -+ cublasGemmAlgo_t algo) -+{ -+ const int m = params[0]; -+ const int n = params[1]; -+ const int k = params[2]; -+ cublasOperation_t trans_a = operations[0]; -+ cublasOperation_t trans_b = operations[1]; -+ const int lda = lds[0]; -+ const int ldb = lds[1]; -+ const int ldc = lds[2]; -+ cudaDataType type_a = data_types[0]; -+ cudaDataType type_b = data_types[1]; -+ cudaDataType type_c = data_types[2]; -+ cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; -+ -+ if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) { -+ compute_type = CUBLAS_COMPUTE_32F; -+ } -+ const int stride_a = strides[0]; -+ const int stride_b = strides[1]; -+ const int stride_c = strides[2]; -+ cublasGemmStridedBatchedEx(cublas_handle, -+ trans_a, -+ trans_b, -+ m, -+ n, -+ k, -+ alpha, -+ a_addr, -+ type_a, -+ lda, -+ stride_a, -+ b_addr, -+ type_b, -+ ldb, -+ stride_b, -+ beta, -+ c_addr, -+ type_c, -+ ldc, -+ stride_c, -+ batch, -+ compute_type, -+ algo); -+} -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/gemm.h b/src/fastertransformer/layers/ms_layers/gemm.h -new file mode 100644 -index 0000000..8c25ea9 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/gemm.h -@@ -0,0 +1,14 @@ -+#pragma once -+ -+#include "src/fastertransformer/kernels/activation_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/MSBaseLayer.h" -+#include -+#include -+#include "src/fastertransformer/layers/ms_layers/ffn.h" -+#include "src/fastertransformer/layers/ms_layers/BaseLayer.h" -+namespace fastertransformer { -+ -+void CublasGemmWrapper(const void* a_addr, const void* b_addr, void* c_addr, const int* params, const int* lds, const cublasOperation_t* operations, const cudaDataType* data_types, void* alpha, void* beta, cublasHandle_t cublas_handle, cublasGemmAlgo_t algo); -+void CublasGemmStridedBatchedWrapper(const void* a_addr, const void* b_addr, void* c_addr, const int* params, const int* lds, const cublasOperation_t* operations, const int* strides, const cudaDataType* data_types, void* alpha, void* beta, int batch, cublasHandle_t cublas_handle, cublasGemmAlgo_t algo); -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/layer_norm.cc b/src/fastertransformer/layers/ms_layers/layer_norm.cc -new file mode 100644 -index 0000000..dcb8385 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/layer_norm.cc -@@ -0,0 +1,91 @@ -+ -+#include "src/fastertransformer/kernels/layernorm_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+#include "src/fastertransformer/kernels/add_residual_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/layer_norm.h" -+#include -+namespace fastertransformer { -+ template -+ void LayerNorm::forward(std::vector &inputs, const std::vector &output, void *ws, cublasHandle_t cublas_handle, cudaStream_t stream) -+ { -+ in_idx_ = 0; -+ T* input = reinterpret_cast(inputs[in_idx_++]); -+ T* gamma = reinterpret_cast(inputs[in_idx_++]); -+ T* beta = (has_beta_) ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ T* bias = (has_bias_) ? 
reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ switch (type_) -+ { -+ case T5: -+ invokeGeneralT5LayerNorm(reinterpret_cast(output[0]), -+ input, -+ gamma, -+ beta, -+ m_, -+ n_, -+ stream, -+ eps_); -+ break; -+ case ADD_BIAS_RESIDUAL: -+ invokeAddBiasResidualLayerNorm(reinterpret_cast(output[0]), -+ input, -+ bias, -+ gamma, -+ beta, -+ m_, -+ n_, -+ stream, -+ eps_); -+ break; -+ case ADD_BIAS_RESIDUAL_CAST: -+ if(!std::is_same::value) invokeAddBiasResidualLayerNormCast(reinterpret_cast(output[0]), -+ reinterpret_cast(output[1]), -+ reinterpret_cast(input), -+ bias, -+ gamma, // gamma -+ beta, // beta -+ m_, -+ n_, -+ stream, -+ eps_); -+ break; -+ case ADD_BIAS_RESIDUAL_CAST_FFN: -+ if(!std::is_same::value) invokeAddBiasResidualLayerNormCast(reinterpret_cast(output[0]), -+ reinterpret_cast(output[1]), -+ reinterpret_cast(input), -+ bias, -+ gamma, // gamma -+ beta, // beta -+ m_, -+ n_, -+ stream, -+ eps_); -+ break; -+ case ADD_BIAS_RESIDUAL_T5_PRE: -+ invokeGeneralAddBiasResidualT5PreLayerNorm(reinterpret_cast(output[0]), -+ reinterpret_cast(output[1]), -+ input, -+ gamma, // gamma -+ beta, // beta -+ bias, -+ m_, -+ n_, -+ stream, -+ eps_); -+ break; -+ case ADD_BIAS_RESIDUAL_T5_PRE_CAST: -+ if(!std::is_same::value) invokeGeneralAddBiasResidualT5PreLayerNormCast(reinterpret_cast(output[0]), -+ reinterpret_cast(output[1]), -+ input, -+ gamma, // gamma -+ beta, // beta -+ bias, -+ m_, -+ n_, -+ stream, -+ eps_); -+ break; -+ default: -+ break; -+ } -+ } -+} -\ No newline at end of file -diff --git a/src/fastertransformer/layers/ms_layers/layer_norm.h b/src/fastertransformer/layers/ms_layers/layer_norm.h -new file mode 100644 -index 0000000..df88423 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/layer_norm.h -@@ -0,0 +1,60 @@ -+#pragma once -+ -+#include "src/fastertransformer/layers/ms_layers/BaseLayer.h" -+namespace fastertransformer { -+ -+ template -+ class LayerNorm: public BaseLayerMS { -+ public: -+ enum Type { -+ T5, -+ ADD_BIAS_RESIDUAL, -+ ADD_BIAS_RESIDUAL_CAST, -+ ADD_BIAS_RESIDUAL_T5_PRE, -+ ADD_BIAS_RESIDUAL_T5_PRE_CAST, -+ ADD_BIAS_RESIDUAL_CAST_FFN -+ }; -+ static Type type; -+ private: -+ size_t m_; -+ size_t n_; -+ bool has_beta_; -+ bool has_bias_; -+ Type type_; -+ float eps_; -+ public: -+ void SetParams(size_t m, size_t n) -+ { -+ m_ = m; -+ n_ = n; -+ } -+ -+ void SetType(Type type) -+ { -+ type_ = type; -+ if (type != T5) has_bias_ = true; -+ else has_bias_ = false; -+ } -+ void SetEps(float eps) -+ { -+ eps_ = eps; -+ } -+ void SetBeta(bool has_beta) -+ { -+ has_beta_ = has_beta; -+ } -+ void forward(std::vector &inputs, const std::vector &outputs, void *ws, cublasHandle_t cublas_handle = nullptr, cudaStream_t stream = 0) override; -+ LayerNorm(bool has_beta = true, -+ bool has_bias = false, -+ Type type = 0, -+ float eps = 1e-6f, -+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP):BaseLayerMS(0, 0, 0, 0, 0, 0, 0, algo){ -+ has_beta_ = has_beta; -+ has_bias_ = has_bias; -+ type_ = type; -+ eps_ = eps; -+ } -+ }; -+ template class LayerNorm; -+ template class LayerNorm; -+} -\ No newline at end of file -diff --git a/src/fastertransformer/layers/ms_layers/opt_allocator.cc b/src/fastertransformer/layers/ms_layers/opt_allocator.cc -new file mode 100644 -index 0000000..560b5ba ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/opt_allocator.cc -@@ -0,0 +1,89 @@ -+/** -+ * Copyright 2021 Huawei Technologies Co., Ltd -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in 
compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+#include "src/fastertransformer/layers/ms_layers/opt_allocator.h" -+#include -+#define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) -+ -+namespace fastertransformer { -+size_t OptAllocator::FindFree(size_t size) { -+ size_t min_size = std::numeric_limits::max(); -+ size_t min_addr = std::numeric_limits::max(); -+ for (auto const &itr : arena_) { -+ // best fit -+ if (itr.second >= size) { -+ if (min_size > itr.second) { -+ min_size = itr.second; -+ min_addr = itr.first; -+ } -+ } -+ } -+ return min_addr; -+} -+ -+void OptAllocator::Reorder(size_t addr) { -+ size_t length = arena_[addr]; -+ size_t post = addr + length; -+ // connect to upper block -+ auto it = arena_.find(post); -+ if (it != arena_.end()) { -+ size_t post_size = it->second; -+ arena_[addr] = length + post_size; -+ arena_.erase(post); -+ } -+ // connect to lower block -+ auto itr = arena_.lower_bound(addr); -+ if (itr != arena_.begin()) { -+ itr--; -+ size_t last = itr->first; -+ if ((last + arena_[last]) == addr) { -+ arena_[last] = arena_[last] + arena_[addr]; -+ arena_.erase(addr); -+ } -+ } -+} -+ -+size_t OptAllocator::Malloc(size_t size) { -+ size = UP_DIV(size, align_size_) * align_size_; -+ size_t addr = FindFree(size); -+ // free block not found -+ if (addr == std::numeric_limits::max()) { -+ if (!arena_.empty()) { -+ addr = arena_.rbegin()->first; -+ if (addr + arena_[addr] < heap_) { -+ addr = heap_; -+ } else { -+ arena_.erase(addr); -+ } -+ } else { -+ addr = heap_; -+ } -+ heap_ = addr + size; -+ } else { -+ if (arena_[addr] > size) { -+ arena_[addr + size] = arena_[addr] - size; -+ } -+ arena_.erase(addr); -+ } -+ alloc_[addr] = size; -+ return addr; -+} -+ -+void OptAllocator::Free(size_t addr) { -+ arena_[addr] = alloc_[addr]; -+ alloc_.erase(addr); -+ Reorder(addr); -+} -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/opt_allocator.h b/src/fastertransformer/layers/ms_layers/opt_allocator.h -new file mode 100644 -index 0000000..13a539e ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/opt_allocator.h -@@ -0,0 +1,40 @@ -+/** -+ * Copyright 2020 Huawei Technologies Co., Ltd -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. 
-+ */ -+ -+#ifndef FASTERTRANSFOMER_LITE_SRC_TRAIN_OPT_ALLOCATOR_H_ -+#define FASTERTRANSFOMER_LITE_SRC_TRAIN_OPT_ALLOCATOR_H_ -+ -+#include -+ -+namespace fastertransformer { -+class OptAllocator { -+ public: -+ explicit OptAllocator(size_t aligned_size = 32) : align_size_(aligned_size) {} -+ ~OptAllocator() {} -+ size_t Malloc(size_t size); -+ void Free(size_t offset); -+ size_t total_size() { return heap_; } -+ -+ private: -+ size_t FindFree(size_t size); -+ void Reorder(size_t addr); -+ std::map arena_; -+ std::map alloc_; -+ size_t heap_ = 0; -+ size_t align_size_; -+}; -+}; // namespace fastertransformer -+#endif // FASTERTRANSFOMER_LITE_SRC_TRAIN_OPT_ALLOCATOR_H_ -diff --git a/src/fastertransformer/layers/ms_layers/param.h b/src/fastertransformer/layers/ms_layers/param.h -new file mode 100644 -index 0000000..c8cb149 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/param.h -@@ -0,0 +1,183 @@ -+#pragma once -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+#include -+#include -+#include -+#include -+#include -+#include "src/fastertransformer/layers/ms_layers/ffn.h" -+ -+ -+ -+namespace fastertransformer { -+ -+ -+typedef enum ActType { -+ ActType_No = 0, -+ ActType_Relu = 1, -+ ActType_Sigmoid = 2, -+ ActType_Relu6 = 3, -+ ActType_Elu = 4, -+ ActType_LeakyRelu = 5, -+ ActType_Abs = 6, -+ ActType_Relu1 = 7, -+ ActType_Softsign = 8, -+ ActType_Softplus = 9, -+ ActType_Tanh = 10, -+ ActType_Selu = 11, -+ ActType_HSwish = 12, -+ ActType_HSigmoid = 13, -+ ActType_ThresholdRelu = 14, -+ ActType_Linear = 15, -+ ActType_HardTanh = 16, -+ ActType_Sign = 17, -+ ActType_Swish = 18, -+ ActType_Gelu = 19, -+ ActType_FastGelu = 20, -+ ActType_Unknown = 21 -+} ActType; -+typedef enum FmhaType { -+ FmhaType_UnFused, -+ FmhaType_CutlassFix -+} FmhaType; -+ -+typedef struct { -+ size_t batch_size; -+ size_t src_seq_len; -+ size_t tgt_seq_len; -+ size_t head_num; -+ size_t head_size; -+ size_t data_parallel; -+ size_t hidden_size; -+ int rank_num; -+ int rank_id; -+ size_t h_token_num; -+ size_t h_token_num2; -+ cublasGemmAlgo_t algo; -+ cublasHandle_t cublas_handle; -+ cudaStream_t stream; -+ int in_idx; -+ bool eft; -+ FfnBase::allGatherFunc all_gather_func; -+ FfnBase::allReduceSumFunc all_reduce_sum_func; -+ -+ int embedding_size; -+ bool use_past; // use past mode -+ bool query_layer; // check if quary layer -+ int cur_token_id; // current token id id -+ bool incremental_mode; // mode of inference -+ void* k_cache; -+ void* v_cache; -+} CommonParam; -+ -+typedef struct { -+ bool ffn_bias; -+ bool has_beta; -+ bool ffn_fp16; -+ size_t ffn_hidden_size; -+ ActType act_type; -+ size_t expert_num; -+ size_t expert_offset; -+ size_t capacity_factor; -+ bool load_weights; -+ size_t weight_mapping; -+ size_t weight_projection; -+} ffnParam; -+ -+typedef struct { -+ CommonParam* common_param; -+ ffnParam ffn_param; -+ bool is_moe; -+ std::shared_ptr moe; -+} ffnParamRun; -+ -+typedef struct { -+ bool qkv_bias; // ture -+ bool projection_bias; // ture -+ bool is_cross; // false -+ bool position_bias; -+ float scale; -+ size_t qkv_buf; -+ size_t q_buf_2; -+ size_t output1; -+ size_t output2; -+ size_t qk_buf; -+ size_t qkv_buf_2; -+ size_t qkv_buf_3; -+ size_t mha; -+ bool mask; -+ FmhaType fmha_type; -+ int* padding_offset; -+ int* d_sequence_length; -+ int* padding_offset2; -+ int* d_sequence_length2; -+} attentionParam; -+ -+typedef struct { -+ CommonParam* common_param; -+ attentionParam attn; -+} attentionParamRun; -+ -+typedef struct { -+ float eps1; -+ float eps2; -+ float 
eps3; -+ float eps4; -+ bool layernorm_post; -+ bool has_beta; -+ size_t normed_from_tensor_buf; -+ size_t attn_out_buf; -+ size_t attn_ws_buf; -+ size_t attn2_out_buf; -+ size_t attn2_ws_buf; -+ size_t tmp_out_buf; -+ size_t ffn_ws_buf; -+ size_t normed_attn_out_buf; -+ size_t normed_attn2_out_buf; -+ size_t compress_buf; -+ size_t d_token_num_buf; -+ size_t padding_offset_buf; -+ size_t d_sequence_lengths_offset_buf; -+ size_t compress_buf2; -+ size_t d_token_num_buf2; -+ size_t padding_offset_buf2; -+ size_t d_sequence_lengths_offset_buf2; -+ bool is_layernorm; -+} decoderParam; -+ -+typedef struct { -+ CommonParam common_param; -+ attentionParamRun attn1; -+ attentionParamRun attn2; -+ ffnParamRun ffn_param; -+ decoderParam decoder; -+} decoderParamRun; -+ -+typedef struct { -+ float eps1; -+ float eps2; -+ float eps3; -+ bool layernorm_post; -+ bool has_beta; -+ size_t normed_from_tensor_buf; -+ size_t attn_ws_buf; -+ size_t attn_out_buf; -+ size_t normed_attn_out_buf; -+ size_t ffn_ws_buf; -+ size_t tmp_out_buf; -+ size_t tmp_out1_buf; -+ size_t compress_buf; -+ size_t d_token_num_buf; -+ size_t padding_offset_buf; -+ size_t d_sequence_lengths_offset_buf; -+ size_t norm_out_buf; -+ bool is_layernorm; -+} encoderParam; -+ -+typedef struct { -+ CommonParam common_param; -+ ffnParamRun ffn_param; -+ attentionParamRun attn; -+ encoderParam encoder; -+} encoderParamRun; -+} // namespace fastertransformer -\ No newline at end of file -diff --git a/src/fastertransformer/models/CMakeLists.txt b/src/fastertransformer/models/CMakeLists.txt -index af33e76..97fc471 100644 ---- a/src/fastertransformer/models/CMakeLists.txt -+++ b/src/fastertransformer/models/CMakeLists.txt -@@ -21,8 +21,11 @@ add_subdirectory(xlnet) - - add_subdirectory(t5) - add_subdirectory(gptj) --add_subdirectory(multi_gpu_gpt) -+if(EXAMPLES) -+ add_subdirectory(multi_gpu_gpt) -+endif() - add_subdirectory(swin) - add_subdirectory(swin_int8) - add_subdirectory(vit) --add_subdirectory(vit_int8) -\ No newline at end of file -+add_subdirectory(vit_int8) -+add_subdirectory(ms) -\ No newline at end of file -diff --git a/src/fastertransformer/models/bert/Bert.cc b/src/fastertransformer/models/bert/Bert.cc -index ac727df..0682288 100644 ---- a/src/fastertransformer/models/bert/Bert.cc -+++ b/src/fastertransformer/models/bert/Bert.cc -@@ -255,7 +255,7 @@ void Bert::forward(std::vector* output_tensors, - switch (attention_type_) { - case AttentionType::UNFUSED_MHA: { - invokeBuildEncoderAttentionMask( -- attention_mask_, sequence_lengths, request_batch_size, request_seq_len, stream_); -+ attention_mask_, sequence_lengths, sequence_lengths, request_batch_size, request_seq_len, request_seq_len, 0, stream_); - sync_check_cuda_error(); - invokeGetPaddingOffset(&h_token_num, - token_num_, -@@ -281,7 +281,7 @@ void Bert::forward(std::vector* output_tensors, - } - case AttentionType::UNFUSED_PADDED_MHA: { - invokeBuildEncoderAttentionMask( -- attention_mask_, sequence_lengths, request_batch_size, request_seq_len, stream_); -+ attention_mask_, sequence_lengths, sequence_lengths, request_batch_size, request_seq_len, request_seq_len, 0, stream_); - sync_check_cuda_error(); - h_token_num = request_batch_size * request_seq_len; - bert_input_ptr = (T*)input_tensors->at(0).data; -diff --git a/src/fastertransformer/models/bert_int8/BertINT8.cc b/src/fastertransformer/models/bert_int8/BertINT8.cc -index 7c6347b..5f374ee 100644 ---- a/src/fastertransformer/models/bert_int8/BertINT8.cc -+++ b/src/fastertransformer/models/bert_int8/BertINT8.cc -@@ -180,7 
+180,7 @@ void BertINT8::forward(std::vector* output_tensors, - switch (attention_type_) { - case AttentionType::UNFUSED_MHA: { - invokeBuildEncoderAttentionMask( -- attention_mask_, sequence_lengths, request_batch_size, request_seq_len, stream_); -+ attention_mask_, sequence_lengths, sequence_lengths, request_batch_size, request_seq_len, request_seq_len, 0, stream_); - sync_check_cuda_error(); - invokeGetPaddingOffset(&h_token_num, - token_num_, -@@ -206,7 +206,7 @@ void BertINT8::forward(std::vector* output_tensors, - } - case AttentionType::UNFUSED_PADDED_MHA: { - invokeBuildEncoderAttentionMask( -- attention_mask_, sequence_lengths, request_batch_size, request_seq_len, stream_); -+ attention_mask_, sequence_lengths, sequence_lengths, request_batch_size, request_seq_len, request_seq_len, 0,stream_); - sync_check_cuda_error(); - h_token_num = request_batch_size * request_seq_len; - bert_input_ptr = (T*)input_tensors->at(0).data; -diff --git a/src/fastertransformer/models/gptj/CMakeLists.txt b/src/fastertransformer/models/gptj/CMakeLists.txt -index d7d9d3e..e69a988 100644 ---- a/src/fastertransformer/models/gptj/CMakeLists.txt -+++ b/src/fastertransformer/models/gptj/CMakeLists.txt -@@ -19,6 +19,7 @@ set_property(TARGET GptJDecoderLayerWeight PROPERTY POSITION_INDEPENDENT_CODE O - set_property(TARGET GptJDecoderLayerWeight PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - target_link_libraries(GptJDecoderLayerWeight PUBLIC memory_utils) - -+if(off) - add_library(GptJDecoder STATIC GptJDecoder.cc) - set_property(TARGET GptJDecoder PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET GptJDecoder PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -@@ -40,12 +41,14 @@ target_link_libraries(GptJContextDecoder PUBLIC -lcudart cublasMMWrapper - add_residual_kernels - gpt_kernels - nccl_utils) -+endif() - - add_library(GptJWeight STATIC GptJWeight.cc) - set_property(TARGET GptJWeight PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET GptJWeight PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - target_link_libraries(GptJWeight PUBLIC GptJDecoderLayerWeight) - -+if(off) - add_library(GptJ STATIC GptJ.cc) - set_property(TARGET GptJ PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET GptJ PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -@@ -58,3 +61,4 @@ target_link_libraries(GptJ PUBLIC -lcudart - BaseBeamSearchLayer - bert_preprocess_kernels - GptJWeight) -+endif() -\ No newline at end of file -diff --git a/src/fastertransformer/models/gptj/GptJ.cc b/src/fastertransformer/models/gptj/GptJ.cc -index 0829e0d..fe41d4b 100644 ---- a/src/fastertransformer/models/gptj/GptJ.cc -+++ b/src/fastertransformer/models/gptj/GptJ.cc -@@ -665,7 +665,7 @@ void GptJ::forward(std::unordered_map* output_tensors, - logits_buf_ + vocab_size_units_offset, - CUDA_R_32F, - vocab_size_padded_, /* n */ -- CUDA_R_32F, -+ CUBLAS_COMPUTE_32F_FAST_TF32, - cublasGemmAlgo_t(-1)); - } - else { -@@ -691,7 +691,7 @@ void GptJ::forward(std::unordered_map* output_tensors, - + tensor_para_.rank_ * local_batch_size * beam_width * local_vocab_size, - CUDA_R_32F, - local_vocab_size, /* n */ -- CUDA_R_32F, -+ CUBLAS_COMPUTE_32F_FAST_TF32, - cublasGemmAlgo_t(-1)); - ftNcclAllGather(nccl_logits_buf_ + vocab_size_units_offset, - nccl_logits_buf_ + vocab_size_units_offset, -@@ -928,7 +928,7 @@ void GptJ::forward(std::unordered_map* output_tensors, - if (output_tensors->count("output_log_probs") > 0 - && output_tensors->at("output_log_probs").data != nullptr) { - ftNcclSend(output_tensors->at("output_log_probs").getPtr(), -- batch_size * beam_width 
* input_tensors->at("max_output_seq_len").getVal(), -+ output_tensors->at("output_log_probs").size(), - 0, - pipeline_para_, - stream_); -@@ -958,7 +958,7 @@ void GptJ::forward(std::unordered_map* output_tensors, - if (output_tensors->count("output_log_probs") > 0 - && output_tensors->at("output_log_probs").data != nullptr) { - ftNcclRecv(output_tensors->at("output_log_probs").getPtr(), -- batch_size * beam_width * input_tensors->at("max_output_seq_len").getVal(), -+ output_tensors->at("output_log_probs").size(), - pipeline_para_.world_size_ - 1, - pipeline_para_, - stream_); -diff --git a/src/fastertransformer/models/ms/CMakeLists.txt b/src/fastertransformer/models/ms/CMakeLists.txt -new file mode 100644 -index 0000000..8a99ce4 ---- /dev/null -+++ b/src/fastertransformer/models/ms/CMakeLists.txt -@@ -0,0 +1,19 @@ -+# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. -+ -+cmake_minimum_required(VERSION 3.8) -+ -+add_executable(ms_gemm main.cc) -+# target_link_libraries(ms_gemm PUBLIC -lcudart encoder_gemm_func encoder_igemm_func memory_utils) -+target_link_libraries(ms_gemm PUBLIC -lcudart ms_gemm_func memory_utils) -diff --git a/src/fastertransformer/models/ms/main.cc b/src/fastertransformer/models/ms/main.cc -new file mode 100644 -index 0000000..cd5844f ---- /dev/null -+++ b/src/fastertransformer/models/ms/main.cc -@@ -0,0 +1,179 @@ -+/* -+ * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. 
-+ */ -+ -+#include "src/fastertransformer/utils/gemm_test/ms_gemm_func.h" -+#include "src/fastertransformer/utils/memory_utils.h" -+ -+namespace ft = fastertransformer; -+ -+struct ms_opt_arg { -+ size_t batch_size; -+ size_t num_layers; -+ size_t seq_len; // source seq len -+ size_t tgt_seq_len; -+ size_t head_num; -+ size_t hidden_size; -+ size_t size_per_head; -+ bool is_remove_padding; -+ int m; -+ int n; -+ int k; -+ std::string model_name; -+ std::string compute_type; -+ std::string w_compute_type; -+ std::string s_compute_type; -+}; -+ -+void usage() { -+ std::cout << "Usage: ms_benchmark -b -l -t " -+ << "-s -H -S -p " -+ << "-T -W -F " -+ << "-m -c -M -N -K \n"; -+} -+ -+bool read_args(int argc, char* argv[], ms_opt_arg* opt_a) { -+ int opt; -+ while ((opt = getopt(argc, argv, "b:l:s:t:H:S:p:m:T:W:F:i:w:M:N:K:")) != -1) { -+ switch (opt) { -+ case 'b': -+ opt_a->batch_size = atoi(optarg); -+ break; -+ case 'l': -+ opt_a->num_layers = atoi(optarg); -+ break; -+ case 's': -+ opt_a->seq_len = atoi(optarg); -+ break; -+ case 't': -+ opt_a->tgt_seq_len = atoi(optarg); -+ break; -+ case 'H': -+ opt_a->head_num = atoi(optarg); -+ break; -+ case 'S': -+ opt_a->hidden_size = atoi(optarg); -+ break; -+ case 'p': -+ opt_a->is_remove_padding = static_cast(atoi(optarg)); -+ break; -+ case 'm': -+ opt_a->model_name = std::string(optarg); -+ break; -+ case 'T': -+ opt_a->compute_type = std::string(optarg); -+ break; -+ case 'W': -+ opt_a->w_compute_type = std::string(optarg); -+ break; -+ case 'F': -+ opt_a->s_compute_type = std::string(optarg); -+ break; -+ case 'M': -+ opt_a->m = atoi(optarg); -+ break; -+ case 'N': -+ opt_a->n = atoi(optarg); -+ break; -+ case 'K': -+ opt_a->k = atoi(optarg); -+ break; -+ case 'i': -+ case 'w': -+ break; -+ case 'h': -+ default: -+ usage(); -+ return false; -+ } -+ } -+ opt_a->size_per_head = opt_a->hidden_size / opt_a->head_num; -+ opt_a->tgt_seq_len = (opt_a->tgt_seq_len == -1) ? opt_a->seq_len : opt_a->tgt_seq_len; -+ return true; -+} -+ -+int main(int argc, char* argv[]) -+{ -+ ms_opt_arg opt_a; -+ opt_a.batch_size = 1; -+ opt_a.num_layers = 1; -+ opt_a.seq_len = 1; -+ opt_a.tgt_seq_len = -1; -+ opt_a.head_num = 1; -+ opt_a.hidden_size = 1; -+ opt_a.size_per_head = 1; -+ opt_a.is_remove_padding = false; -+ opt_a.m = 1; -+ opt_a.n = 1; -+ opt_a.k = 1; -+ opt_a.model_name = ""; -+ opt_a.compute_type = "fp32"; -+ opt_a.w_compute_type = "fp32"; -+ opt_a.s_compute_type = "fp32"; -+ -+ if (!read_args(argc, argv, &opt_a)) { -+ printf("[ERROR] Failed to read arguments. 
\n"); -+ usage(); -+ return 0; -+ } -+ -+ bool c_type_fp32 = (opt_a.compute_type.compare("fp32") == 0); -+ std::cout << "[INFO] arguments: " << std::endl; -+ std::cout << " batch_size: " << opt_a.batch_size << std::endl; -+ std::cout << " num of layers: " << opt_a.num_layers << std::endl; -+ std::cout << " seq len:" << opt_a.seq_len << std::endl; -+ std::cout << " target seq len: " << opt_a.tgt_seq_len << std::endl; -+ std::cout << " head_num: " << opt_a.head_num << std::endl; -+ std::cout << " size_per_head: " << opt_a.size_per_head << std::endl; -+ // std::cout << " compute_type: " << c_type_fp32 << std::endl; -+ -+ std::cout << std::endl; -+ -+ const int inter_size = 4 * opt_a.head_num * opt_a.size_per_head; -+ const ft::CublasDataType data_type = static_cast(0); // 0 FP32, 1 FP16, 2 BF 16 -+ void* gemm_test_buf; -+ size_t buf_size_in_byte = ft::calGemmTestBufSizeInByte(opt_a.batch_size, -+ opt_a.seq_len, -+ opt_a.head_num, -+ opt_a.size_per_head, -+ inter_size, -+ 0, // default -+ 0, // default -+ data_type); -+ -+ size_t total, free; -+ ft::check_cuda_error(cudaMemGetInfo(&free, &total)); -+ if (free < buf_size_in_byte + 10 * 1024 * 1024) { -+ printf("[ERROR] There is no enough device memory for gemm test!\n" -+ " %ld Bytes is needed, but only %ld Bytes is free.\n", -+ buf_size_in_byte, -+ free); -+ gemm_test_buf = NULL; -+ return -1; -+ } else { -+ ft::deviceMalloc(reinterpret_cast(&gemm_test_buf), buf_size_in_byte, false); -+ } -+ // int fast_algo = 0; -+ if (data_type == ft::FLOAT_DATATYPE) { -+ ft::generate_ms_gemm_config(opt_a.batch_size, opt_a.seq_len, opt_a.tgt_seq_len, opt_a.head_num, opt_a.size_per_head, gemm_test_buf, -+ false); -+ } else { -+ printf("[ERROR] data type only supports fp32(0). \n"); -+ return -1; -+ } -+ // std::cout << "main fast algo: " << fast_algo << std::endl; -+ ft::check_cuda_error(cudaFree(gemm_test_buf)); -+ return 0; -+} -\ No newline at end of file -diff --git a/src/fastertransformer/models/multi_gpu_gpt/CMakeLists.txt b/src/fastertransformer/models/multi_gpu_gpt/CMakeLists.txt -index 10b9e0b..86d733f 100644 ---- a/src/fastertransformer/models/multi_gpu_gpt/CMakeLists.txt -+++ b/src/fastertransformer/models/multi_gpu_gpt/CMakeLists.txt -@@ -37,7 +37,7 @@ set_property(TARGET ParallelGptDecoder PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - target_link_libraries(ParallelGptDecoder PUBLIC -lcudart TensorParallelGeluFfnLayer - TensorParallelDecoderSelfAttentionLayer layernorm_kernels - add_residual_kernels nccl_utils) -- -+ - add_library(ParallelGpt STATIC ParallelGpt.cc) - set_property(TARGET ParallelGpt PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET ParallelGpt PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc -index 17f9099..d171b4b 100644 ---- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc -+++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc -@@ -345,7 +345,7 @@ void ParallelGpt::computeContextCumLogProbs(float* cum_log_probs, - lp_logits_buf_, - CUDA_R_32F, - vocab_size_padded_, /* n */ -- CUDA_R_32F, -+ CUBLAS_COMPUTE_32F_FAST_TF32, - cublasGemmAlgo_t(-1)); - sync_check_cuda_error(); - } -@@ -370,7 +370,7 @@ void ParallelGpt::computeContextCumLogProbs(float* cum_log_probs, - lp_nccl_logits_buf_ + tensor_para_.rank_ * n_hidden_states * local_vocab_size, - CUDA_R_32F, - local_vocab_size, /* n */ -- CUDA_R_32F, -+ CUBLAS_COMPUTE_32F_FAST_TF32, - cublasGemmAlgo_t(-1)); - sync_check_cuda_error(); 
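// Illustrative sketch, not part of the patch: the hunks in this file swap the final
// cudaDataType_t argument of cublasGemmEx for a cublasComputeType_t, the CUDA 11+
// form of the API. A minimal, self-contained example of the new call shape, with a
// hypothetical handle and device buffers (column-major, FP32 storage, TF32 compute):
#include <cublas_v2.h>
static cublasStatus_t gemm_fp32_tf32_sketch(cublasHandle_t handle, const float* d_A,
                                            const float* d_B, float* d_C,
                                            int m, int n, int k) {
    const float alpha = 1.0f, beta = 0.0f;
    return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                        &alpha, d_A, CUDA_R_32F, m,
                                d_B, CUDA_R_32F, k,
                        &beta,  d_C, CUDA_R_32F, m,
                        CUBLAS_COMPUTE_32F_FAST_TF32,   // was CUDA_R_32F before the change
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}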
- ftNcclAllGather(lp_nccl_logits_buf_, -@@ -803,7 +803,7 @@ void ParallelGpt::forward(std::unordered_map* output_ten - logits_buf_ + vocab_size_units_offset, - CUDA_R_32F, - vocab_size_padded_, /* n */ -- CUDA_R_32F, -+ CUBLAS_COMPUTE_32F_FAST_TF32, - cublasGemmAlgo_t(-1)); - } - else { -@@ -829,7 +829,7 @@ void ParallelGpt::forward(std::unordered_map* output_ten - + tensor_para_.rank_ * local_batch_size * beam_width * local_vocab_size, - CUDA_R_32F, - local_vocab_size, /* n */ -- CUDA_R_32F, -+ CUBLAS_COMPUTE_32F_FAST_TF32, - cublasGemmAlgo_t(-1)); - ftNcclAllGather(nccl_logits_buf_ + vocab_size_units_offset, - nccl_logits_buf_ + vocab_size_units_offset, -@@ -1057,7 +1057,7 @@ void ParallelGpt::forward(std::unordered_map* output_ten - if (output_tensors->count("output_log_probs") > 0 - && output_tensors->at("output_log_probs").data != nullptr) { - ftNcclSend(output_tensors->at("output_log_probs").getPtr(), -- batch_size * beam_width * input_tensors->at("max_output_seq_len").getVal(), -+ output_tensors->at("output_log_probs").size(), - 0, - pipeline_para_, - stream_); -@@ -1087,7 +1087,7 @@ void ParallelGpt::forward(std::unordered_map* output_ten - if (output_tensors->count("output_log_probs") > 0 - && output_tensors->at("output_log_probs").data != nullptr) { - ftNcclRecv(output_tensors->at("output_log_probs").getPtr(), -- batch_size * beam_width * input_tensors->at("max_output_seq_len").getVal(), -+ output_tensors->at("output_log_probs").size(), - pipeline_para_.world_size_ - 1, - pipeline_para_, - stream_); -diff --git a/src/fastertransformer/models/t5/CMakeLists.txt b/src/fastertransformer/models/t5/CMakeLists.txt -index 9f3455d..e75bbbd 100644 ---- a/src/fastertransformer/models/t5/CMakeLists.txt -+++ b/src/fastertransformer/models/t5/CMakeLists.txt -@@ -14,6 +14,7 @@ - - cmake_minimum_required(VERSION 3.8) - -+if(False) - add_library(T5Decoder STATIC T5Decoder.cc T5DecoderLayerWeight.cc) - set_property(TARGET T5Decoder PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET T5Decoder PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -@@ -21,6 +22,7 @@ target_link_libraries(T5Decoder PUBLIC -lcudart cublasMMWrapper TensorParallelDe - TensorParallelDecoderCrossAttentionLayer TensorParallelReluFfnLayer - layernorm_kernels add_residual_kernels nccl_utils memory_utils) - -+ - add_library(T5Decoding STATIC T5Decoding.cc T5DecodingWeight.cc) - set_property(TARGET T5Decoding PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET T5Decoding PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -@@ -28,6 +30,8 @@ target_link_libraries(T5Decoding PUBLIC -lcudart cublasMMWrapper T5Decoder bert_ - decoding_kernels DynamicDecodeLayer BaseBeamSearchLayer - beam_search_topk_kernels gpt_kernels) - -+ -+ - add_library(T5Encoder STATIC T5Encoder.cc T5EncoderWeight.cc T5EncoderLayerWeight.cc) - set_property(TARGET T5Encoder PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET T5Encoder PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -@@ -36,4 +40,5 @@ target_link_libraries(T5Encoder PUBLIC -lcudart bert_preprocess_kernels cublasMM - TensorParallelGeluFfnLayer layernorm_kernels add_residual_kernels nccl_utils) - - add_executable(t5_gemm t5_gemm.cc) --target_link_libraries(t5_gemm PUBLIC -lcudart t5_gemm_func memory_utils) -\ No newline at end of file -+target_link_libraries(t5_gemm PUBLIC -lcudart t5_gemm_func memory_utils) -+endif() -\ No newline at end of file -diff --git a/src/fastertransformer/models/t5/T5Encoder.cc b/src/fastertransformer/models/t5/T5Encoder.cc -index 698e3d6..db989ff 100644 ---- 
a/src/fastertransformer/models/t5/T5Encoder.cc -+++ b/src/fastertransformer/models/t5/T5Encoder.cc -@@ -380,7 +380,7 @@ void T5Encoder::forward(std::unordered_map* output_tenso - request_seq_len, - request_seq_len, - local_batch_size, -- hidden_units_, -+ d_model_, - stream_); - } - else { -diff --git a/src/fastertransformer/models/vit/ViT.cc b/src/fastertransformer/models/vit/ViT.cc -index e785f2b..9a967e4 100644 ---- a/src/fastertransformer/models/vit/ViT.cc -+++ b/src/fastertransformer/models/vit/ViT.cc -@@ -415,7 +415,7 @@ bool ViTTransformer::setSeqLenVec(size_t batch_size) - template - void ViTTransformer::setDefaultMask(size_t batch_size) - { -- invokeBuildEncoderAttentionMask(mask_buf_, seq_len_vec_, batch_size, max_seq_len_, stream_); -+ invokeBuildEncoderAttentionMask(mask_buf_, seq_len_vec_, seq_len_vec_, batch_size, max_seq_len_, max_seq_len_, 0, stream_); - } - - template -diff --git a/src/fastertransformer/models/vit_int8/ViTINT8.cc b/src/fastertransformer/models/vit_int8/ViTINT8.cc -index f610785..44fc5fc 100644 ---- a/src/fastertransformer/models/vit_int8/ViTINT8.cc -+++ b/src/fastertransformer/models/vit_int8/ViTINT8.cc -@@ -462,7 +462,7 @@ bool ViTTransformerINT8::setSeqLenVec(size_t batch_size) - template - void ViTTransformerINT8::setDefaultMask(size_t batch_size) - { -- invokeBuildEncoderAttentionMask(mask_buf_, seq_len_vec_, batch_size, max_seq_len_, stream_); -+ invokeBuildEncoderAttentionMask(mask_buf_, seq_len_vec_, seq_len_vec_, batch_size, max_seq_len_, max_seq_len_, 0, stream_); - } - - template -diff --git a/src/fastertransformer/utils/CMakeLists.txt b/src/fastertransformer/utils/CMakeLists.txt -index 3d0f28a..3d2efbd 100644 ---- a/src/fastertransformer/utils/CMakeLists.txt -+++ b/src/fastertransformer/utils/CMakeLists.txt -@@ -44,10 +44,12 @@ set_property(TARGET memory_utils PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET memory_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - target_link_libraries(memory_utils PUBLIC -lnvToolsExt) - -+if(EXAMPLES) - add_library(nccl_utils STATIC nccl_utils.cc) - set_property(TARGET nccl_utils PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET nccl_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - target_link_libraries(nccl_utils PUBLIC -lnccl) -+endif() - - add_library(cublasINT8MMWrapper STATIC cublasINT8MMWrapper.cc) - set_property(TARGET cublasINT8MMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON) -diff --git a/src/fastertransformer/utils/cublasMMWrapper.cc b/src/fastertransformer/utils/cublasMMWrapper.cc -index e291151..e0c6d20 100644 ---- a/src/fastertransformer/utils/cublasMMWrapper.cc -+++ b/src/fastertransformer/utils/cublasMMWrapper.cc -@@ -99,7 +99,7 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, - void* C, - cudaDataType_t Ctype, - int ldc, -- cudaDataType_t computeType, -+ cublasComputeType_t computeType, - cublasGemmAlgo_t algo) - { - mu_->lock(); -@@ -160,7 +160,7 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, - - mu_->lock(); - // TODO: default cublas libs -- int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; -+ int is_fp16_computeType = computeType_ == CUBLAS_COMPUTE_16F ? 1 : 0; - bool using_cublasLt = (Atype_ == CUDA_R_16F) ? 
true : false; - int batch_count = 1; - // fp32 use cublas as default -@@ -187,14 +187,14 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, - #if (CUDART_VERSION >= 11000) - cublasComputeType_t computeType; - #else -- cudaDataType_t computeType; -+ cublasComputeType_t computeType; - #endif - - if (is_fp16_computeType) { - #if (CUDART_VERSION >= 11000) - computeType = CUBLAS_COMPUTE_16F; - #else -- computeType = CUDA_R_16F; -+ computeType = CUBLAS_COMPUTE_16F; - #endif - scaleType = CUDA_R_16F; - } -@@ -302,7 +302,7 @@ void cublasMMWrapper::setFP32GemmConfig() - Atype_ = CUDA_R_32F; - Btype_ = CUDA_R_32F; - Ctype_ = CUDA_R_32F; -- computeType_ = CUDA_R_32F; -+ computeType_ = CUBLAS_COMPUTE_32F_FAST_TF32; - } - - void cublasMMWrapper::setFP16GemmConfig() -@@ -310,7 +310,23 @@ void cublasMMWrapper::setFP16GemmConfig() - Atype_ = CUDA_R_16F; - Btype_ = CUDA_R_16F; - Ctype_ = CUDA_R_16F; -- computeType_ = CUDA_R_32F; -+ computeType_ = CUBLAS_COMPUTE_16F; -+} -+ -+void cublasMMWrapper::setFP32MixedGemmConfig() -+{ -+ Atype_ = CUDA_R_32F; -+ Btype_ = CUDA_R_16F; -+ Ctype_ = CUDA_R_32F; -+ computeType_ = CUBLAS_COMPUTE_32F_FAST_TF32; -+} -+ -+void cublasMMWrapper::setFP16MixedGemmConfig() -+{ -+ Atype_ = CUDA_R_16F; -+ Btype_ = CUDA_R_32F; -+ Ctype_ = CUDA_R_32F; -+ computeType_ = CUBLAS_COMPUTE_32F_FAST_TF32; - } - - #ifdef ENABLE_BF16 -@@ -319,14 +335,14 @@ void cublasMMWrapper::setBF16GemmConfig() - Atype_ = CUDA_R_16BF; - Btype_ = CUDA_R_16BF; - Ctype_ = CUDA_R_16BF; -- computeType_ = CUDA_R_32F; -+ computeType_ = CUBLAS_COMPUTE_16F; - } - #endif - - void cublasMMWrapper::setGemmConfig(cudaDataType_t aType, - cudaDataType_t bType, - cudaDataType_t cType, -- cudaDataType_t computeType) -+ cublasComputeType_t computeType) - { - Atype_ = aType; - Btype_ = bType; -@@ -451,7 +467,7 @@ void cublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, - half h_beta = (half)f_beta; - - mu_->lock(); -- int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; -+ int is_fp16_computeType = computeType_ == CUBLAS_COMPUTE_16F ? 1 : 0; - const void* alpha = - is_fp16_computeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); - const void* beta = is_fp16_computeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); -@@ -504,13 +520,13 @@ void cublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, - const int ldc, - const int64_t strideC, - const int batch_count, -- cudaDataType_t computeType) -+ cublasComputeType_t computeType) - { - half h_alpha = (half)f_alpha; - half h_beta = (half)f_beta; - - mu_->lock(); -- int is_fp16_computeType = computeType == CUDA_R_16F ? 1 : 0; -+ int is_fp16_computeType = computeType == CUBLAS_COMPUTE_16F ? 1 : 0; - const void* alpha = - is_fp16_computeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); - const void* beta = is_fp16_computeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); -@@ -563,7 +579,7 @@ void cublasMMWrapper::batchedGemm(cublasOperation_t transa, - half h_beta = (half)0.0f; - - mu_->lock(); -- int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; -+ int is_fp16_computeType = computeType_ == CUBLAS_COMPUTE_16F ? 1 : 0; - const void* alpha = is_fp16_computeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); - const void* beta = is_fp16_computeType ? 
reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); - cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_)); -diff --git a/src/fastertransformer/utils/cublasMMWrapper.h b/src/fastertransformer/utils/cublasMMWrapper.h -index 6f410ab..21a8ea8 100644 ---- a/src/fastertransformer/utils/cublasMMWrapper.h -+++ b/src/fastertransformer/utils/cublasMMWrapper.h -@@ -41,7 +41,7 @@ private: - cudaDataType_t Atype_; - cudaDataType_t Btype_; - cudaDataType_t Ctype_; -- cudaDataType_t computeType_; -+ cublasComputeType_t computeType_; - - cudaStream_t stream_; - cublasAlgoMap* cublas_algo_map_; -@@ -90,7 +90,7 @@ public: - void* C, - cudaDataType_t Ctype, - int ldc, -- cudaDataType_t computeType, -+ cublasComputeType_t computeType, - cublasGemmAlgo_t algo); - - void Gemm(cublasOperation_t transa, -@@ -121,12 +121,14 @@ public: - - void setFP32GemmConfig(); - void setFP16GemmConfig(); -+ void setFP32MixedGemmConfig(); -+ void setFP16MixedGemmConfig(); - #ifdef ENABLE_BF16 - void setBF16GemmConfig(); - #endif - void setStream(cudaStream_t stream); - -- void setGemmConfig(cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType); -+ void setGemmConfig(cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cublasComputeType_t computeType); - - CublasDataType getCublasDataType(cudaDataType_t data_type); - -@@ -183,7 +185,7 @@ public: - const int ldc, - const int64_t strideC, - const int batch_count, -- cudaDataType_t computeType); -+ cublasComputeType_t computeType); - - void batchedGemm(cublasOperation_t transa, - cublasOperation_t transb, -diff --git a/src/fastertransformer/utils/cuda_utils.h b/src/fastertransformer/utils/cuda_utils.h -index 5d73c87..aef6ab9 100644 ---- a/src/fastertransformer/utils/cuda_utils.h -+++ b/src/fastertransformer/utils/cuda_utils.h -@@ -382,7 +382,7 @@ public: - - static double diffTime(timeval start, timeval end) - { -- return (end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001; -+ return (end.tv_sec - start.tv_sec) * 1000000 + (end.tv_usec - start.tv_usec); - } - - /* ***************************** common utils ****************************** */ -diff --git a/src/fastertransformer/utils/custom_ar_comm.cc b/src/fastertransformer/utils/custom_ar_comm.cc -index ded1e58..159faaf 100644 ---- a/src/fastertransformer/utils/custom_ar_comm.cc -+++ b/src/fastertransformer/utils/custom_ar_comm.cc -@@ -54,6 +54,7 @@ void CustomAllReduceComm::customAllReduce(size_t elts, cudaStream_t stream) - output_tensor_->at(0).data = (const void*)tmp_tensor_data_; - } - -+ - template - void CustomAllReduceComm::allocateAndExchangePeerAccessPointer( - std::vector>* custom_all_reduce_comms) -diff --git a/src/fastertransformer/utils/gemm_test/CMakeLists.txt b/src/fastertransformer/utils/gemm_test/CMakeLists.txt -index 223b85d..ab48356 100644 ---- a/src/fastertransformer/utils/gemm_test/CMakeLists.txt -+++ b/src/fastertransformer/utils/gemm_test/CMakeLists.txt -@@ -49,6 +49,10 @@ set(swin_gemm_func_files - swin_gemm_func.cc - ) - -+set(ms_gemm_func_files -+ ms_gemm_func.cc -+) -+ - add_library(gemm_func STATIC ${gemm_func_files}) - target_link_libraries(gemm_func PUBLIC -lcublas -lcublasLt -lcudart) - set_property(TARGET gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) -@@ -109,3 +113,12 @@ add_library(swin_gemm_func STATIC ${swin_gemm_func_files}) - target_link_libraries(swin_gemm_func PUBLIC -lcublas -lcublasLt -lcudart gemm_func) - set_property(TARGET swin_gemm_func PROPERTY 
POSITION_INDEPENDENT_CODE ON) - set_property(TARGET swin_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -+ -+add_library(ms_gemm_func STATIC ${ms_gemm_func_files}) -+if (SPARSITY_SUPPORT) -+target_link_libraries(ms_gemm_func PUBLIC -lcublas -lcublasLt -lcudart gemm_func -lcusparse -lcusparseLt) -+else() -+target_link_libraries(ms_gemm_func PUBLIC -lcublas -lcublasLt -lcudart gemm_func) -+endif() -+set_property(TARGET ms_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) -+set_property(TARGET ms_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -\ No newline at end of file -diff --git a/src/fastertransformer/utils/gemm_test/encoder_gemm_func.cc b/src/fastertransformer/utils/gemm_test/encoder_gemm_func.cc -index 03c6947..00f8ca0 100644 ---- a/src/fastertransformer/utils/gemm_test/encoder_gemm_func.cc -+++ b/src/fastertransformer/utils/gemm_test/encoder_gemm_func.cc -@@ -26,11 +26,11 @@ void generate_encoder_gemm_config( - void* buffer; - int workSpaceSize; - --#ifdef ENABLE_BF16 -- if (std::is_same::value || std::is_same::value) { --#else -+// #ifdef ENABLE_BF16 -+// if (std::is_same::value || std::is_same::value) { -+// #else - if (std::is_same::value) { --#endif // ENABLE_BF16 -+// #endif // ENABLE_BF16 - // cublas_workspace_ should be the start pointer of cudaMalloc() - // to ensure 16B alignemnet - cublas_workspace = buffer_in; -diff --git a/src/fastertransformer/utils/gemm_test/encoder_gemm_func.h b/src/fastertransformer/utils/gemm_test/encoder_gemm_func.h -index fd067b9..4bf3d6c 100644 ---- a/src/fastertransformer/utils/gemm_test/encoder_gemm_func.h -+++ b/src/fastertransformer/utils/gemm_test/encoder_gemm_func.h -@@ -36,5 +36,4 @@ namespace fastertransformer { - template - void generate_encoder_gemm_config( - int batch_size, int seq_len, int head_num, int size_per_head, void* buffer, bool isAppend = true); -- - } // namespace fastertransformer -diff --git a/src/fastertransformer/utils/gemm_test/gemm_func.cc b/src/fastertransformer/utils/gemm_test/gemm_func.cc -index edbfc40..6187d45 100644 ---- a/src/fastertransformer/utils/gemm_test/gemm_func.cc -+++ b/src/fastertransformer/utils/gemm_test/gemm_func.cc -@@ -534,7 +534,6 @@ int LtHgemmCustomFind(cublasLtHandle_t ltHandle, - } - } - } -- - // workspacesize==0 - printf("workspacesize==0, run %d algos\n", AlgoCountRestrict); - for (int i = 0; i < AlgoCountRestrict && i < (maxNumTraversal - nbAlgoIds); i++) { -@@ -594,7 +593,8 @@ CLEANUP: - if (stopEvent) { - cudaEventDestroy(stopEvent); - } -- return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; -+ return AlgoCount; -+ // return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; - } - - template int LtHgemmCustomFind(cublasLtHandle_t ltHandle, -@@ -634,7 +634,6 @@ template int LtHgemmCustomFind(cublasLtHandle_t ltHandle, - FILE* fout, - customMatmulPerf_t perfResults[], - int AlgoCombinations); -- - #ifdef ENABLE_BF16 - template int LtHgemmCustomFind(cublasLtHandle_t ltHandle, - int batch_size, -diff --git a/src/fastertransformer/utils/gemm_test/ms_gemm_func.cc b/src/fastertransformer/utils/gemm_test/ms_gemm_func.cc -new file mode 100644 -index 0000000..e8f88fe ---- /dev/null -+++ b/src/fastertransformer/utils/gemm_test/ms_gemm_func.cc -@@ -0,0 +1,364 @@ -+/* -+ * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. 
-+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+ -+#include "src/fastertransformer/utils/gemm_test/ms_gemm_func.h" -+ -+namespace fastertransformer { -+ -+template -+void generate_ms_gemm_config( -+ int batch_size, int seq_len, int tgt_seq_len, int head_num, int size_per_head, void* buffer_in, bool isAppend) -+{ -+ void* cublas_workspace; -+ void* buffer; -+ int workSpaceSize; -+ -+#ifdef ENABLE_BF16 -+ if (std::is_same::value || std::is_same::value) { -+#else -+ if (std::is_same::value) { -+#endif // ENABLE_BF16 -+ // cublas_workspace_ should be the start pointer of cudaMalloc() -+ // to ensure 16B alignemnet -+ cublas_workspace = buffer_in; -+ buffer = (void*)((char*)cublas_workspace + CUBLAS_WORKSPACE_SIZE); -+ workSpaceSize = CUBLAS_WORKSPACE_SIZE; -+ } -+ else { -+ cublas_workspace = nullptr; -+ buffer = buffer_in; -+ workSpaceSize = 0; -+ } -+ -+ struct cudaDeviceProp prop; -+ check_cuda_error(cudaGetDeviceProperties(&prop, 0)); -+ printf("Device %s\n", prop.name); -+ -+ // check config -+ FILE* fd; -+ int line_count = 0; -+ if (!isAppend) { -+ fd = fopen(GEMM_CONFIG, "w+"); -+ } -+ else { -+ fd = fopen(GEMM_CONFIG, "a+"); -+ std::vector config; -+ char line[1024]; -+ while (fgets(line, 1024, fd) != NULL) { -+ config.push_back(std::string(line)); -+ } -+ line_count = config.size(); -+ if (config.size() >= (MAX_CONFIG_NUM * GEMM_NUM + 1)) // 6 cublas/cublasLt, first row is not included -+ { -+ int startIdx = config.size() - ((MAX_CONFIG_NUM - 1) * GEMM_NUM); -+ fclose(fd); -+ fd = fopen(GEMM_CONFIG, "w+"); -+ fprintf(fd, "%s", config[0].c_str()); -+ for (uint i = startIdx; i < config.size(); i++) { -+ fprintf(fd, "%s", config[i].c_str()); -+ } -+ line_count = config.size() - (GEMM_NUM + 3); -+ } -+ } -+ -+ const int gemm_num = 4; -+ int M[gemm_num]; -+ int N[gemm_num]; -+ int K[gemm_num]; -+ int batchCount[gemm_num] = {1, 1, 1, 1}; -+ char mess[gemm_num][256]; -+ float exec_times[gemm_num]; -+ int gemm_lds[gemm_num][3]; // = {3 * hidden_size, hidden_size, 3 * hidden_size}; -+ cublasOperation_t gemm_ops[gemm_num][2]; // = {CUBLAS_OP_N, CUBLAS_OP_N}; -+ int gemm_strides[2][3]; -+ -+ // gemm1 -+ // int gemm_dims[] = {3 * hidden_size, request_batch_size * request_src_seq_len, hidden_size}; -+ int hidden_size = head_num * size_per_head; -+ M[0] = 3 * hidden_size; -+ N[0] = batch_size * seq_len; -+ K[0] = hidden_size; -+ gemm_lds[0][0] = 3 * hidden_size; -+ gemm_lds[0][1] = hidden_size; -+ gemm_lds[0][2] = 3 * hidden_size; -+ gemm_ops[0][0] = CUBLAS_OP_N; -+ gemm_ops[0][1] = CUBLAS_OP_N; -+ strcpy(mess[0], "cublasGemmEx "); -+ -+ // gemm2 -+ M[1] = tgt_seq_len; -+ N[1] = seq_len; -+ K[1] = size_per_head; -+ gemm_ops[1][0] = CUBLAS_OP_T; -+ gemm_ops[1][1] = CUBLAS_OP_N; -+ -+ gemm_lds[1][0] = size_per_head; -+ gemm_lds[1][1] = size_per_head; -+ gemm_lds[1][2] = tgt_seq_len; -+ -+ gemm_strides[0][0] = tgt_seq_len * size_per_head; -+ gemm_strides[0][1] = seq_len * size_per_head; -+ gemm_strides[0][2] = seq_len * tgt_seq_len; -+ strcpy(mess[1], "cublasGemmStridedBatchedEx"); -+ -+ // gemm3 -+ M[2] = size_per_head; -+ N[2] = seq_len; -+ K[2] = tgt_seq_len; -+ gemm_ops[2][0] = CUBLAS_OP_N; -+ gemm_ops[2][1] = 
CUBLAS_OP_N; -+ -+ gemm_lds[2][0] = size_per_head; -+ gemm_lds[2][1] = tgt_seq_len; -+ gemm_lds[2][2] = size_per_head; -+ -+ gemm_strides[1][0] = tgt_seq_len * size_per_head; -+ gemm_strides[1][1] = seq_len * tgt_seq_len; -+ gemm_strides[1][2] = seq_len * size_per_head; -+ strcpy(mess[2], "cublasGemmStridedBatchedEx"); -+ -+ // gemm4 -+ M[3] = hidden_size; -+ N[3] = batch_size * seq_len; -+ K[3] = hidden_size; -+ gemm_ops[3][0] = CUBLAS_OP_N; -+ gemm_ops[3][1] = CUBLAS_OP_N; -+ -+ gemm_lds[3][0] = hidden_size; -+ gemm_lds[3][1] = hidden_size; -+ gemm_lds[3][2] = hidden_size; -+ strcpy(mess[3], "cublasGemmEx"); -+ -+ cublasHandle_t cublas_handle; -+ check_cuda_error(cublasCreate(&cublas_handle)); -+ cublasLtHandle_t ltHandle; -+ check_cuda_error(cublasLtCreate(<Handle)); -+ -+ cudaDataType_t AType; -+ cudaDataType_t BType; -+ cudaDataType_t CType; -+ cublasComputeType_t computeType; -+ int startAlgo, endAlgo; -+ const int ites = 10000; -+ const int warmup_ites = 10000; -+ struct timeval start, end; -+ -+ CublasDataType data_type; -+ if (std::is_same::value) { -+ data_type = FLOAT_DATATYPE; -+ AType = CUDA_R_32F; -+ BType = CUDA_R_32F; -+ CType = CUDA_R_32F; -+ computeType = CUBLAS_COMPUTE_32F_FAST_TF32; -+ startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; -+ endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; -+ } -+ else if (std::is_same::value) { -+ data_type = HALF_DATATYPE; -+ AType = CUDA_R_16F; -+ BType = CUDA_R_16F; -+ CType = CUDA_R_16F; -+ computeType = CUBLAS_COMPUTE_16F; -+ startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; -+ endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; -+ } -+#ifdef ENABLE_BF16 -+ else if (std::is_same::value) { -+ data_type = BFLOAT16_DATATYPE; -+ AType = CUDA_R_16BF; -+ BType = CUDA_R_16BF; -+ CType = CUDA_R_16BF; -+ computeType = CUBLAS_COMPUTE_32F; -+ startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; -+ endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; -+ } -+#endif -+ using scaleT = typename ScaleTypeConverter::Type; -+ -+ scaleT alpha = (scaleT)1.0f; -+ scaleT beta = (scaleT)0.0f; -+ -+ printf("***Encoder Gemm Testing Begin***\n"); -+ printf("***Cublas Gemm Testing Begin***\n"); -+ if (line_count == 0) { -+ fprintf(fd, -+ "batch_size, seq_len, head_num, size_per_head dataType ### batchCount, n, m, k, algoId, " -+ "customOption, tile, numSplitsK, swizzle, reductionScheme, workspaceSize, stages, exec_time\n"); -+ } -+ for (int i = 0; i < gemm_num; ++i) { -+ // if(i != 0 && i != 5) continue; -+ -+ int m = M[i], n = N[i], k = K[i]; -+ printf("\n-----------------------------\n"); -+ printf("GEMM test %d: [M: %d, K: %d, N: %d] %s\n", i, m, k, n, mess[i]); -+ // printf("GEMM test %d: [M: %d, K: %d, N: %d] \n", i, m, k, n); -+ T* d_A = (T*)buffer; -+ T* d_B = d_A + m * k * batchCount[i]; -+ T* d_C = d_B + k * n * batchCount[i]; -+ -+ // array of pointer for batchedGemm -+ T* harray[12]; -+ harray[0] = (T*)buffer; -+ harray[1] = (T*)((char*)buffer + sizeof(T) * m * k); -+ harray[2] = (T*)((char*)buffer + 2 * sizeof(T) * m * k); -+ harray[4] = (T*)((char*)buffer + 3 * sizeof(T) * m * k); -+ harray[5] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + sizeof(T) * k * n); -+ harray[6] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 2 * sizeof(T) * k * n); -+ harray[8] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n); -+ harray[9] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n + sizeof(T) * m * n); -+ harray[10] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n + 2 * sizeof(T) * m * n); -+ -+ T** darray = 0; -+ 
check_cuda_error(cudaMalloc((void**)&darray, sizeof(T*) * 12)); -+ cudaMemcpy((void*)darray, (void*)harray, sizeof(T*) * 12, cudaMemcpyHostToDevice); -+ T** dAarray = darray; -+ T** dBarray = darray + 4; -+ T** dCarray = darray + 8; -+ -+ float exec_time = 99999.0f; -+ int fast_algo = 0; -+ -+ // warmup -+ // for (int j = 0; j < ites*10; j++) { -+ // cublasGemmEx(cublas_handle, gemm_ops[i][0], gemm_ops[i][1], m, n, k, &alpha, d_A, AType, gemm_lds[i][0], d_B, BType, -+ // gemm_lds[i][1], &beta, d_C, CType, gemm_lds[i][2], computeType, static_cast(0)); -+ // } -+ -+ for (int algo = startAlgo; algo <= endAlgo; algo++) { -+ cublasStatus_t status; -+ //warmup -+ for (int ite = 0; ite < warmup_ites; ++ite) { -+ if ((i == 0) || (i == 3)) { -+ status = cublasGemmEx(cublas_handle, gemm_ops[i][0], gemm_ops[i][1], m, n, k, &alpha, d_A, AType, gemm_lds[i][0], d_B, BType, -+ gemm_lds[i][1], &beta, d_C, CType, gemm_lds[i][2], computeType, static_cast(algo)); -+ } else { -+ status = cublasGemmStridedBatchedEx(cublas_handle, gemm_ops[i][0], gemm_ops[i][1], m, n, k, &alpha, d_A, AType, gemm_lds[i][0], -+ gemm_strides[i-1][0], d_B, BType, gemm_lds[i][1], gemm_strides[i-1][1], &beta, d_C, CType, -+ gemm_lds[i][2], gemm_strides[i-1][2], batch_size, computeType, static_cast(algo)); -+ } -+ } -+ cudaDeviceSynchronize(); -+ gettimeofday(&start, NULL); -+ if ((i == 0) || (i == 3)) { -+ for (int ite = 0; ite < ites; ++ite) { -+ status = cublasGemmEx(cublas_handle, gemm_ops[i][0], gemm_ops[i][1], m, n, k, &alpha, d_A, AType, gemm_lds[i][0], d_B, BType, -+ gemm_lds[i][1], &beta, d_C, CType, gemm_lds[i][2], computeType, static_cast(algo)); -+ } -+ } else { -+ for (int ite = 0; ite < ites; ++ite) { -+ status = cublasGemmStridedBatchedEx(cublas_handle, gemm_ops[i][0], gemm_ops[i][1], m, n, k, &alpha, d_A, AType, gemm_lds[i][0], -+ gemm_strides[i-1][0], d_B, BType, gemm_lds[i][1], gemm_strides[i-1][1], &beta, d_C, CType, -+ gemm_lds[i][2], gemm_strides[i-1][2], batch_size, computeType, static_cast(algo)); -+ } -+ } -+ -+ if (status != CUBLAS_STATUS_SUCCESS) { -+ break; -+ } -+ // } -+ cudaDeviceSynchronize(); -+ gettimeofday(&end, NULL); -+ if (status == CUBLAS_STATUS_SUCCESS) { -+ printf("algo_%d costs %.6fms \n", algo, diffTime(start, end) / ites); -+ if (diffTime(start, end) / ites < exec_time) { -+ exec_time = diffTime(start, end) / ites; -+ fast_algo = algo; -+ } -+ } -+ } -+ printf("fast_algo %d costs %.6f ms \n", fast_algo, exec_time); -+ -+ // for fp16 and bf16, we compare cublasLt -+ if (i < 3 && data_type != FLOAT_DATATYPE) { -+ printf("***cublasLt Gemm Testing Beign***\n"); -+ // Let try a fixed number of combinations -+ int ALGO_COMBINATIONS = 5000; -+ customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; -+ LtHgemmCustomFind(ltHandle, -+ batch_size, -+ seq_len, -+ head_num, -+ size_per_head, -+ n, -+ m, -+ k, -+ &alpha, -+ d_B, -+ d_A, -+ &beta, -+ d_C, -+ cublas_workspace, -+ workSpaceSize, -+ fd, -+ perfResults, -+ ALGO_COMBINATIONS); -+ if (perfResults[0].time < exec_time) { -+ printPerfStructure( -+ batch_size, seq_len, head_num, size_per_head, n, m, k, perfResults[0], fd, data_type, 0); -+ exec_time = perfResults[0].time; -+ } -+ else { -+ fprintf(fd, -+ "%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 %f\n", -+ batch_size, -+ seq_len, -+ head_num, -+ size_per_head, -+ data_type, -+ batchCount[i], -+ n, -+ m, -+ k, -+ fast_algo, -+ exec_time); -+ } -+ printf("***cublasLt Gemm Testing End***\n"); -+ } -+ else { -+ fprintf(fd, -+ "%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 %f\n", -+ 
batch_size, -+ seq_len, -+ head_num, -+ size_per_head, -+ data_type, -+ batchCount[i], -+ n, -+ m, -+ k, -+ fast_algo, -+ exec_time); -+ } -+ exec_times[i] = exec_time; -+ cudaFree(darray); -+ } -+ printf("***cublas Gemm Testing End***\n\n"); -+ fclose(fd); -+ printf("***Encoder Gemm Testing End***\n"); -+ -+ return; -+} -+ -+template void generate_ms_gemm_config( -+ int batch_size, int seq_len, int tgt_seq_len, int head_num, int size_per_head, void* buffer, bool isAppend); -+template void generate_ms_gemm_config( -+ int batch_size, int seq_len, int tgt_seq_len, int head_num, int size_per_head, void* buffer, bool isAppend); -+#ifdef ENABLE_BF16 -+template void generate_ms_gemm_config<__nv_bfloat16>( -+ int batch_size, int seq_len, int tgt_seq_len, int head_num, int size_per_head, void* buffer, bool isAppend); -+#endif -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/utils/gemm_test/ms_gemm_func.h b/src/fastertransformer/utils/gemm_test/ms_gemm_func.h -new file mode 100644 -index 0000000..c6f68ca ---- /dev/null -+++ b/src/fastertransformer/utils/gemm_test/ms_gemm_func.h -@@ -0,0 +1,40 @@ -+/* -+ * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+ -+#pragma once -+ -+#include "src/fastertransformer/utils/cublasAlgoMap.h" -+#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" -+#include "src/fastertransformer/utils/cuda_utils.h" -+#include "src/fastertransformer/utils/gemm_test/gemm_func.h" -+ -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+ -+namespace fastertransformer { -+ -+template -+void generate_ms_gemm_config( -+ int batch_size, int seq_len, int tgt_seq_len, int head_num, int size_per_head, void* buffer, bool isAppend = true); -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/utils/logger.h b/src/fastertransformer/utils/logger.h -index bcdf8fa..e3e7007 100644 ---- a/src/fastertransformer/utils/logger.h -+++ b/src/fastertransformer/utils/logger.h -@@ -65,7 +65,7 @@ private: - #else - const Level DEFAULT_LOG_LEVEL = INFO; - #endif -- Level level_ = DEFAULT_LOG_LEVEL; -+ Level level_ = ERROR; // DEFAULT_LOG_LEVEL; - - Logger() - { -diff --git a/tests/unittests/test_gemm.cu b/tests/unittests/test_gemm.cu -index 13719f7..4ecf0bd 100644 ---- a/tests/unittests/test_gemm.cu -+++ b/tests/unittests/test_gemm.cu -@@ -157,7 +157,7 @@ void computeReference(GemmOp transa, - cudaDataType_t atype = (A.type == TYPE_FP16) ? CUDA_R_16F : CUDA_R_32F; - cudaDataType_t btype = (B.type == TYPE_FP16) ? CUDA_R_16F : CUDA_R_32F; - cudaDataType_t ctype = (C.type == TYPE_FP16) ? CUDA_R_16F : CUDA_R_32F; -- cudaDataType_t compute_type = (computeType == TYPE_FP16) ? CUDA_R_16F : CUDA_R_32F; -+ cublasComputeType_t compute_type = (computeType == TYPE_FP16) ? 
CUBLAS_COMPUTE_16F : CUBLAS_COMPUTE_32F_FAST_TF32; - - cublasHandle_t cublas_handle; - check_cuda_error(cublasCreate(&cublas_handle)); -@@ -391,7 +391,11 @@ void testGemmConsistencyMatmul(size_t m, size_t n, size_t k) { - - cudaDataType_t cuda_dtype = std::is_same::value ? CUDA_R_32F : CUDA_R_16F; - cudaDataType_t cuda_ctype = (DataType::TYPE_FP32 == computeType) ? CUDA_R_32F : CUDA_R_16F; -- cublas_wrapper.setGemmConfig(cuda_dtype, cuda_dtype, cuda_dtype, cuda_ctype); -+ // add culab type -+ cublasComputeType_t cublasComputeType = (DataType::TYPE_FP32 == computeType) ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_16F; -+ cublas_wrapper.setGemmConfig(cuda_dtype, cuda_dtype, cuda_dtype, cublasComputeType); -+ //before change -+ // cublas_wrapper.setGemmConfig(cuda_dtype, cuda_dtype, cuda_dtype, cuda_ctype); - - std::shared_ptr gemm = createGemm(&allocator, stream, false, false); - gemm->setTypes(a_tensor.type, b_tensor.type, c_tensor.type, computeType); -@@ -506,8 +510,12 @@ void testGemmConsistencyBatchedMatmul(size_t m, size_t n, size_t k) { - &allocator); - - cudaDataType_t dtype = std::is_same::value ? CUDA_R_32F : CUDA_R_16F; -- cudaDataType_t ctype = (computeType == DataType::TYPE_FP32) ? CUDA_R_32F : CUDA_R_16F; -+ // cudaDataType_t ctype = (computeType == DataType::TYPE_FP32) ? CUDA_R_32F : CUDA_R_16F; -+ // add culab type -+ cublasComputeType_t ctype = (computeType == DataType::TYPE_FP32) ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_16F; - cublas_wrapper.setGemmConfig(dtype, dtype, dtype, ctype); -+ //before change -+ // cublas_wrapper.setGemmConfig(dtype, dtype, dtype, ctype); - - std::shared_ptr gemm = createGemm(&allocator, stream, false, false); - gemm->setTypes(a_type, b_type, c_type, computeType); -@@ -606,8 +614,12 @@ void testGemmConsistencyStridedBatchedMatmul(size_t batch_size, size_t m, size_t - &allocator); - - cudaDataType_t dtype = std::is_same::value ? CUDA_R_32F : CUDA_R_16F; -- cudaDataType_t ctype = (computeType == DataType::TYPE_FP32) ? CUDA_R_32F : CUDA_R_16F; -+ // add culab type -+ cublasComputeType_t ctype = (computeType == DataType::TYPE_FP32) ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_16F; - cublas_wrapper.setGemmConfig(dtype, dtype, dtype, ctype); -+ //before change -+ // cudaDataType_t ctype = (computeType == DataType::TYPE_FP32) ? 
CUDA_R_32F : CUDA_R_16F; -+ // cublas_wrapper.setGemmConfig(dtype, dtype, dtype, ctype); - - std::shared_ptr gemm = createGemm(&allocator, stream, false, false); - gemm->setTypes(a_tensor.type, b_tensor.type, c_tensor.type, computeType); -@@ -647,7 +659,7 @@ void testGemmConsistencyStridedBatchedMatmul(size_t batch_size, size_t m, size_t - ldc, - stridec, - batch_size, -- getCublasDataType(computeType)); -+ getCublasComputeType(computeType)); - - c_tensor.setInvalidValues(); // to guarantee C has invalid data - gemm->stridedBatchedGemm(op_pair.transa, op_pair.transb, m, n, k, diff --git a/third_party/patch/libevent/libevent.patch001 b/third_party/patch/libevent/libevent.patch001 deleted file mode 100644 index ad8a3105de9cf3f0f41e5cbffc463ee3a9d4d355..0000000000000000000000000000000000000000 --- a/third_party/patch/libevent/libevent.patch001 +++ /dev/null @@ -1,17 +0,0 @@ -diff -Npur libevent/CMakeLists.txt libevent-modify/CMakeLists.txt ---- libevent/CMakeLists.txt 2020-07-05 20:02:46.000000000 +0800 -+++ libevent-modify/CMakeLists.txt 2021-04-19 16:36:57.982307500 +0800 -@@ -852,7 +852,7 @@ if (NOT EVENT__DISABLE_OPENSSL) - - list(APPEND SRC_OPENSSL bufferevent_openssl.c) - list(APPEND HDR_PUBLIC include/event2/bufferevent_ssl.h) -- list(APPEND LIB_APPS ${OPENSSL_LIBRARIES}) -+ list(APPEND LIB_APPS ${OPENSSL_LIBRARIES} -ldl) - endif() - - if (NOT EVENT__DISABLE_THREAD_SUPPORT) -diff -Npur libevent/cmake/AddEventLibrary.cmake libevent-modify/cmake/AddEventLibrary.cmake ---- libevent/cmake/AddEventLibrary.cmake 2020-07-05 20:02:46.000000000 +0800 -+++ libevent-modify/cmake/AddEventLibrary.cmake 2021-04-19 16:36:57.982307500 +0800 -@@ -153,1 +153,0 @@ -- INSTALL_NAME_DIR "${CMAKE_INSTALL_PREFIX}/lib" diff --git a/third_party/patch/onednn/0001-fix-user-threadpool-bug.patch b/third_party/patch/onednn/0001-fix-user-threadpool-bug.patch deleted file mode 100644 index f52565177519742077d84e3641444576ba0b00db..0000000000000000000000000000000000000000 --- a/third_party/patch/onednn/0001-fix-user-threadpool-bug.patch +++ /dev/null @@ -1,20 +0,0 @@ -diff --git a/src/common/dnnl_thread.hpp b/src/common/dnnl_thread.hpp -index 342bc3b00..0b9190f9c 100644 ---- a/src/common/dnnl_thread.hpp -+++ b/src/common/dnnl_thread.hpp -@@ -104,10 +104,11 @@ inline int dnnl_get_max_threads() { - def_max_threads - = (int)dnnl::impl::cpu::platform::get_max_threads_to_use(); - assert(def_max_threads > 0); -- // Use the default value if the threadpool-provided is outside the range -- // [1, def_max_threads] -- return tp ? std::min(std::max(1, tp->get_num_threads()), def_max_threads) -- : def_max_threads; -+ -+ // Make user responsible for number of threads provided at execution time. -+ // This relates to the fact that the library may identify `def_max_threads` -+ // incorrectly for a platform. -+ return tp ? 
std::max(1, tp->get_num_threads()) : def_max_threads; - } - inline int dnnl_in_parallel() { - using namespace dnnl::impl::threadpool_utils; diff --git a/third_party/patch/onednn/0002-fix-pool-nthr-bug.patch b/third_party/patch/onednn/0002-fix-pool-nthr-bug.patch deleted file mode 100644 index d0ecb2f0cfe3bc8b267525ff77dd26e0b05a170d..0000000000000000000000000000000000000000 --- a/third_party/patch/onednn/0002-fix-pool-nthr-bug.patch +++ /dev/null @@ -1,334 +0,0 @@ -diff --git a/src/cpu/nchw_pooling.cpp b/src/cpu/nchw_pooling.cpp -index b678200a1..09736ccae 100644 ---- a/src/cpu/nchw_pooling.cpp -+++ b/src/cpu/nchw_pooling.cpp -@@ -609,10 +609,12 @@ status_t nchw_pooling_bwd_t::execute_backward( - int od_end = min(OD, 1 + (padF + ID - 1) / SD); - - dim_t c_blk = pd()->channel_block_size_; -- int c_blk_tail = C % c_blk; -+ dim_t c_blk_tail = C % c_blk; -+ const int nthr = pd()->nthr_; -+ - if (alg == alg_kind::pooling_max) { -- parallel_nd_ext(0, MB, utils::div_up(C, c_blk), -- [&](int ithr, int, int mb, int cb) { -+ parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk), -+ [&](int ithr, int, dim_t mb, dim_t cb) { - bool is_last_c_block - = c_blk_tail > 0 && (cb + 1) * c_blk > C; - int curr_c_block = is_last_c_block ? c_blk_tail : c_blk; -@@ -649,8 +651,8 @@ status_t nchw_pooling_bwd_t::execute_backward( - diff_src_fp32, src_sp_size * curr_c_block); - }); - } else { -- parallel_nd_ext(0, MB, utils::div_up(C, c_blk), -- [&](int ithr, int, int mb, int cb) { -+ parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk), -+ [&](int ithr, int, dim_t mb, dim_t cb) { - bool is_last_c_block - = c_blk_tail > 0 && (cb + 1) * c_blk > C; - int curr_c_block = is_last_c_block ? c_blk_tail : c_blk; -diff --git a/src/cpu/nchw_pooling.hpp b/src/cpu/nchw_pooling.hpp -index 9d649f3f5..2a73f6ae6 100644 ---- a/src/cpu/nchw_pooling.hpp -+++ b/src/cpu/nchw_pooling.hpp -@@ -139,6 +139,7 @@ struct nchw_pooling_bwd_t : public primitive_t { - ws_md_ = *hint_fwd_pd_->workspace_md(); - } - -+ nthr_ = dnnl_get_max_threads(); - calculate_channel_block_size(); - init_scratchpad(); - -@@ -146,6 +147,7 @@ struct nchw_pooling_bwd_t : public primitive_t { - } - - dim_t channel_block_size_; -+ int nthr_; // To not exceed the limit in execute used for set up. 
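// Illustrative sketch, not part of the patch: the recurring fix in this onednn patch
// is to pin the thread count once, when the primitive descriptor is created, and to
// reuse that same nthr_ for both scratchpad sizing and parallel_nd_ext, instead of
// querying the threadpool again at execute time. A stand-alone sketch of the idea,
// with hypothetical names:
#include <cstddef>
struct pooling_pd_sketch {
    int nthr = 1;                        // fixed at init, e.g. from dnnl_get_max_threads()
    std::size_t per_thread_bytes = 0;    // e.g. one bf16->f32 conversion row per thread
    std::size_t scratchpad_bytes() const {
        // execute() must launch at most `nthr` workers, otherwise a worker would
        // index past the buffer booked here.
        return per_thread_bytes * static_cast<std::size_t>(nthr);
    }
};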
- - private: - void init_scratchpad() { -@@ -153,13 +155,12 @@ struct nchw_pooling_bwd_t : public primitive_t { - if (diff_dst_md()->data_type == data_type::bf16) { - size_t dst_sz_ = OD() * OH() * OW(); - size_t src_sz_ = ID() * IH() * IW(); -- size_t nthrs = dnnl_get_max_threads(); - auto scratchpad = scratchpad_registry().registrar(); - - scratchpad.template book(key_pool_src_bf16cvt, -- src_sz_ * nthrs * channel_block_size_); -+ src_sz_ * nthr_ * channel_block_size_); - scratchpad.template book(key_pool_dst_bf16cvt, -- dst_sz_ * nthrs * channel_block_size_); -+ dst_sz_ * nthr_ * channel_block_size_); - } - } - -@@ -169,8 +170,7 @@ struct nchw_pooling_bwd_t : public primitive_t { - // spatial - dim_t dst_sz_ = OD() * OH() * OW(); - dim_t src_sz_ = ID() * IH() * IW(); -- dim_t nthrs = dnnl_get_max_threads(); -- dim_t C_per_thr = nstl::min(MB() * C() / nthrs, C()); -+ dim_t C_per_thr = nstl::min(MB() * C() / nthr_, C()); - const dim_t max_block_size - = platform::get_per_core_cache_size(1) / 2; - dim_t data_size_per_ch = (dst_sz_ + src_sz_) * 6; // f32 + bf16 -diff --git a/src/cpu/nhwc_pooling.cpp b/src/cpu/nhwc_pooling.cpp -index 48d9e1240..efe3083f7 100644 ---- a/src/cpu/nhwc_pooling.cpp -+++ b/src/cpu/nhwc_pooling.cpp -@@ -378,8 +378,9 @@ status_t nhwc_pooling_fwd_t::execute_forward( - return OSP * OC * mb + OSP * oc + SP * od + OW * oh + ow; - }; - const bool are_postops_set = !(pd()->attr()->post_ops_.entry_.empty()); -+ const int nthr = pd()->nthr_; - -- parallel_nd_ext(0, MB, OD, OH, OW, -+ parallel_nd_ext(nthr, MB, OD, OH, OW, - [&](int ithr, int, int mb, int od, int oh, int ow) { - const size_t dst_offset_init = strided_offset(mb, dst_n_stride, - od, dst_d_stride, oh, dst_h_stride, ow, dst_w_stride); -@@ -682,8 +683,9 @@ status_t nhwc_pooling_bwd_t::execute_backward( - auto apply_offset = [=](int index, int offset) { - return (index > offset) ? index - offset : 0; - }; -+ const int nthr = pd()->nthr_; - -- parallel_nd_ext(0, MB, ID, IH, IW, -+ parallel_nd_ext(nthr, MB, ID, IH, IW, - [&](int ithr, int, int mb, int id, int ih, int iw) { - size_t src_offset_init = strided_offset(mb, diff_src_n_stride, - id, diff_src_d_stride, ih, diff_src_h_stride, iw, -diff --git a/src/cpu/nhwc_pooling.hpp b/src/cpu/nhwc_pooling.hpp -index c65196a94..c16e840a2 100644 ---- a/src/cpu/nhwc_pooling.hpp -+++ b/src/cpu/nhwc_pooling.hpp -@@ -73,16 +73,19 @@ struct nhwc_pooling_fwd_t : public primitive_t { - init_default_ws(); - } - -+ nthr_ = dnnl_get_max_threads(); - init_scratchpad(); - - return status::success; - } - -+ int nthr_; // To not exceed the limit in execute used for set up. -+ - private: - void init_scratchpad() { - using namespace memory_tracking::names; - if (src_md()->data_type == data_type::bf16) { -- const size_t bf16cvt_sz_ = C() * dnnl_get_max_threads(); -+ const size_t bf16cvt_sz_ = C() * nthr_; - auto scratchpad = scratchpad_registry().registrar(); - scratchpad.template book( - key_pool_src_bf16cvt, bf16cvt_sz_); -@@ -148,16 +151,19 @@ struct nhwc_pooling_bwd_t : public primitive_t { - if (!compare_ws(hint_fwd_pd_)) return status::unimplemented; - } - -+ nthr_ = dnnl_get_max_threads(); - init_scratchpad(); - - return status::success; - } - -+ int nthr_; // To not exceed the limit in execute used for set up. 
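// Illustrative sketch, not part of the patch: with the bf16 conversion scratchpad
// booked as C() * nthr_ elements, each worker is assumed to own one C()-sized slice
// selected by its thread id, which is why the hunks above forward nthr_ (rather than
// 0) to parallel_nd_ext. A hedged sketch of that per-thread slicing, hypothetical names:
#include <cstddef>
inline float* thread_slice(float* scratch_base, std::size_t channels, int ithr) {
    // Worker `ithr` converts its bf16 values into this private float slice; the
    // pointer is only valid while ithr stays below the nthr used to size the buffer.
    return scratch_base + static_cast<std::size_t>(ithr) * channels;
}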
-+ - private: - void init_scratchpad() { - using namespace memory_tracking::names; - if (diff_src_md()->data_type == data_type::bf16) { -- size_t bf16cvt_sz_ = C() * dnnl_get_max_threads(); -+ size_t bf16cvt_sz_ = C() * nthr_; - auto scratchpad = scratchpad_registry().registrar(); - scratchpad.template book( - key_pool_src_bf16cvt, bf16cvt_sz_); -diff --git a/src/cpu/x64/jit_primitive_conf.hpp b/src/cpu/x64/jit_primitive_conf.hpp -index a2a181cfa..5befb81ac 100644 ---- a/src/cpu/x64/jit_primitive_conf.hpp -+++ b/src/cpu/x64/jit_primitive_conf.hpp -@@ -672,6 +672,7 @@ struct jit_pool_conf_t { - bool with_postops; - bool with_eltwise; - bool with_binary; -+ int nthr; - }; - - struct jit_pool_call_s { -diff --git a/src/cpu/x64/jit_uni_pool_kernel.cpp b/src/cpu/x64/jit_uni_pool_kernel.cpp -index 36d129e6d..ebd4f3af1 100644 ---- a/src/cpu/x64/jit_uni_pool_kernel.cpp -+++ b/src/cpu/x64/jit_uni_pool_kernel.cpp -@@ -76,8 +76,7 @@ jit_uni_pool_kernel::jit_uni_pool_kernel( - - template - status_t jit_uni_pool_kernel::init_conf(jit_pool_conf_t &jpp, -- memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd, -- int nthreads) { -+ memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd) { - - const auto &pd = *ppd->desc(); - const memory_desc_wrapper src_d( -@@ -87,6 +86,7 @@ status_t jit_uni_pool_kernel::init_conf(jit_pool_conf_t &jpp, - - const int ndims = src_d.ndims(); - -+ jpp.nthr = dnnl_get_max_threads(); - jpp.is_training = pd.prop_kind == prop_kind::forward_training; - jpp.is_backward = pd.prop_kind == prop_kind::backward_data; - -@@ -248,7 +248,7 @@ status_t jit_uni_pool_kernel::init_conf(jit_pool_conf_t &jpp, - ? (ndims == 5 && jpp.simple_alg ? jpp.od : 1) - : (ndims == 5 ? jpp.od : jpp.oh); - work *= jpp.mb * nb2_c; -- auto eff = (float)work / utils::rnd_up(work, nthreads); -+ auto eff = (float)work / utils::rnd_up(work, jpp.nthr); - if (eff > best_eff) { - - best_eff = eff; -diff --git a/src/cpu/x64/jit_uni_pool_kernel.hpp b/src/cpu/x64/jit_uni_pool_kernel.hpp -index d5d5f25a2..57ce6f43d 100644 ---- a/src/cpu/x64/jit_uni_pool_kernel.hpp -+++ b/src/cpu/x64/jit_uni_pool_kernel.hpp -@@ -46,8 +46,7 @@ struct jit_uni_pool_kernel : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel) - - static status_t init_conf(jit_pool_conf_t &jbp, -- memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd, -- int nthreads); -+ memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd); - - private: - using Xmm = Xbyak::Xmm; -diff --git a/src/cpu/x64/jit_uni_pooling.cpp b/src/cpu/x64/jit_uni_pooling.cpp -index b2055f2a9..29987f70c 100644 ---- a/src/cpu/x64/jit_uni_pooling.cpp -+++ b/src/cpu/x64/jit_uni_pooling.cpp -@@ -612,6 +612,8 @@ void jit_uni_pooling_fwd_t::execute_forward(const data_t *src, - (*kernel_)(&arg); - }; - -+ const int nthr = jpp.nthr; -+ - if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) { - const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); - parallel_nd(jpp.mb, jpp.oh, nb2_c, [&](int n, int oh, int b2_c) { -@@ -622,7 +624,7 @@ void jit_uni_pooling_fwd_t::execute_forward(const data_t *src, - } else { - if (trans_src || trans_dst) { - // ncsp format -- parallel_nd_ext(0, jpp.mb, jpp.nb_c, -+ parallel_nd_ext(nthr, jpp.mb, jpp.nb_c, - [&](int ithr, int nthr, int n, int b_c) { - if (trans_src) - transpose_facade.execute_transpose_input( -@@ -635,7 +637,7 @@ void jit_uni_pooling_fwd_t::execute_forward(const data_t *src, - }); - } else { - // nChw16c, nChw8c format -- parallel(0, [&](std::size_t ithr, std::size_t nthr) { -+ 
parallel(nthr, [&](int ithr, int nthr) { - const std::size_t work_amount - = static_cast(jpp.mb) * jpp.nb_c * jpp.oh; - if (ithr >= work_amount) return; -@@ -739,6 +741,8 @@ void jit_uni_pooling_fwd_t::execute_forward_3d(const data_t *src, - (*kernel_)(&arg); - }; - -+ const int nthr = jpp.nthr; -+ - if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) { - const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); - parallel_nd(jpp.mb, jpp.od, nb2_c, [&](int n, int od, int b2_c) { -@@ -757,7 +761,7 @@ void jit_uni_pooling_fwd_t::execute_forward_3d(const data_t *src, - }); - } else { - if (trans_src || trans_dst) { -- parallel_nd_ext(0, jpp.mb, jpp.nb_c, -+ parallel_nd_ext(nthr, jpp.mb, jpp.nb_c, - [&](int ithr, int nthr, int n, int b_c) { - if (trans_src) - transpose_facade.execute_transpose_input( -@@ -948,7 +952,9 @@ void jit_uni_pooling_bwd_t::execute_backward( - transpose_facade.execute_transpose_output(ithr, n, b_c); - }; - -- parallel(0, [&](int ithr, int nthr) { -+ const int nthr = jpp.nthr; -+ -+ parallel(nthr, [&](int ithr, int nthr) { - const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); - const std::size_t work_amount - = static_cast(jpp.mb) * nb2_c; -@@ -1098,6 +1104,8 @@ void jit_uni_pooling_bwd_t::execute_backward_3d( - } - }; - -+ const int nthr = jpp.nthr; -+ - if (jpp.simple_alg) { - if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) { - const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); -@@ -1109,7 +1117,7 @@ void jit_uni_pooling_bwd_t::execute_backward_3d( - } else { - assert(jpp.ur_bc == 1); - if (trans_src || trans_dst) { -- parallel_nd_ext(0, jpp.mb, jpp.nb_c, -+ parallel_nd_ext(nthr, jpp.mb, jpp.nb_c, - [&](int ithr, int nthr, int n, int b_c) { - if (trans_src) - transpose_facade.execute_transpose_input( -@@ -1142,7 +1150,7 @@ void jit_uni_pooling_bwd_t::execute_backward_3d( - if (!trans_src) { - const size_t chunk_size - = (size_t)jpp.id * jpp.ih * jpp.iw * jpp.c_block; -- parallel_nd_ext(0, jpp.mb, jpp.nb_c, -+ parallel_nd_ext(nthr, jpp.mb, jpp.nb_c, - [&](int ithr, int nthr, int n, int b_c) { - const size_t offset - = ((size_t)n * jpp.nb_c + b_c) * chunk_size; -@@ -1155,8 +1163,8 @@ void jit_uni_pooling_bwd_t::execute_backward_3d( - - const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); - if (trans_src || trans_dst) { -- parallel_nd_ext( -- 0, jpp.mb, nb2_c, [&](int ithr, int nthr, int n, int b2_c) { -+ parallel_nd_ext(nthr, jpp.mb, nb2_c, -+ [&](int ithr, int nthr, int n, int b2_c) { - const auto b_c = b2_c * jpp.ur_bc; - - if (trans_dst) { -diff --git a/src/cpu/x64/jit_uni_pooling.hpp b/src/cpu/x64/jit_uni_pooling.hpp -index ec4b04a2b..e25d9ce05 100644 ---- a/src/cpu/x64/jit_uni_pooling.hpp -+++ b/src/cpu/x64/jit_uni_pooling.hpp -@@ -66,8 +66,9 @@ struct jit_uni_pooling_fwd_t : public primitive_t { - init_default_ws(); - - auto scratchpad = scratchpad_registry().registrar(); -- return jit_uni_pool_kernel::init_conf( -- jpp_, scratchpad, this, dnnl_get_max_threads()); -+ CHECK(jit_uni_pool_kernel::init_conf(jpp_, scratchpad, this)); -+ -+ return status::success; - } - - jit_pool_conf_t jpp_; -@@ -130,9 +131,11 @@ struct jit_uni_pooling_bwd_t : public primitive_t { - init_default_ws(); - if (!compare_ws(hint_fwd_pd_)) return status::unimplemented; - } -+ - auto scratchpad = scratchpad_registry().registrar(); -- return jit_uni_pool_kernel::init_conf( -- jpp_, scratchpad, this, dnnl_get_max_threads()); -+ CHECK(jit_uni_pool_kernel::init_conf(jpp_, scratchpad, this)); -+ -+ return status::success; - } - - jit_pool_conf_t jpp_; diff --git 
a/third_party/patch/onednn/0003-fix-zero-threads-identified-on-AMD.patch b/third_party/patch/onednn/0003-fix-zero-threads-identified-on-AMD.patch deleted file mode 100644 index 0c3b6a76ed2d2cd2ccb8570ba635edf32b76172b..0000000000000000000000000000000000000000 --- a/third_party/patch/onednn/0003-fix-zero-threads-identified-on-AMD.patch +++ /dev/null @@ -1,13 +0,0 @@ -diff --git a/src/cpu/platform.cpp b/src/cpu/platform.cpp -index 1397073ba..041a3436f 100644 ---- a/src/cpu/platform.cpp -+++ b/src/cpu/platform.cpp -@@ -154,6 +154,8 @@ unsigned get_num_cores() { - // function supports process affinity. - unsigned get_max_threads_to_use() { - int num_cores_per_socket = (int)dnnl::impl::cpu::platform::get_num_cores(); -+ if (num_cores_per_socket <= 1) -+ num_cores_per_socket = std::thread::hardware_concurrency(); - #if defined(_WIN32) - DWORD_PTR proc_affinity_mask; - DWORD_PTR sys_affinity_mask; diff --git a/third_party/patch/onednn/0004-fix-dnnl-limits.patch b/third_party/patch/onednn/0004-fix-dnnl-limits.patch deleted file mode 100644 index 7638e4ae6516bfd0851b010b35c062bf8be367fd..0000000000000000000000000000000000000000 --- a/third_party/patch/onednn/0004-fix-dnnl-limits.patch +++ /dev/null @@ -1,10 +0,0 @@ ---- a/src/cpu/aarch64/xbyak_aarch64/xbyak_aarch64/xbyak_aarch64.h -+++ b/src/cpu/aarch64/xbyak_aarch64/xbyak_aarch64/xbyak_aarch64.h -@@ -28,6 +28,7 @@ - #include - #include - #include -+#include - #include - #include - #include diff --git a/third_party/patch/robin_hood_hashing/0001-fix-unused-var-warning.patch b/third_party/patch/robin_hood_hashing/0001-fix-unused-var-warning.patch deleted file mode 100644 index 54b7857e253be93be0cd24de67d7d82e2be05be7..0000000000000000000000000000000000000000 --- a/third_party/patch/robin_hood_hashing/0001-fix-unused-var-warning.patch +++ /dev/null @@ -1,60 +0,0 @@ -diff --git a/src/include/robin_hood.h b/src/include/robin_hood.h ---- a/src/include/robin_hood.h -+++ b/src/include/robin_hood.h -@@ -2541,4 +2541,56 @@ using unordered_set = detail::Table -+struct tuple_size> : std::integral_constant {}; -+ -+template -+struct tuple_element> { -+ typedef typename std::conditional::type type; -+}; -+} // namespace std -+ -+namespace robin_hood { -+template -+typename std::enable_if::type get(robin_hood::pair &p) { -+ return p.first; -+} -+ -+template -+typename std::enable_if::type get(robin_hood::pair &p) { -+ return p.second; -+} -+ -+template -+typename std::enable_if::type get(const robin_hood::pair &p) { -+ return p.first; -+} -+ -+template -+typename std::enable_if::type get(const robin_hood::pair &p) { -+ return p.second; -+} -+ -+template -+typename std::enable_if::type get(robin_hood::pair &&p) { -+ return std::move(p.first); -+} -+ -+template -+typename std::enable_if::type get(robin_hood::pair &&p) { -+ return std::move(p.second); -+} -+ -+template -+typename std::enable_if::type get(const robin_hood::pair &&p) { -+ return std::move(p.first); -+} -+ -+template -+typename std::enable_if::type get(const robin_hood::pair &&p) { -+ return std::move(p.second); -+} -+} // namespace robin_hood -+ - #endif diff --git a/third_party/patch/robin_hood_hashing/0002-fix-string-isflat-symbol.patch b/third_party/patch/robin_hood_hashing/0002-fix-string-isflat-symbol.patch deleted file mode 100644 index f2cb59f16bb9db7ba43c2aa7b4abe9b336ab0bf3..0000000000000000000000000000000000000000 --- a/third_party/patch/robin_hood_hashing/0002-fix-string-isflat-symbol.patch +++ /dev/null @@ -1,13 +0,0 @@ -diff --git a/src/include/robin_hood.h b/src/include/robin_hood.h ---- 
a/src/include/robin_hood.h -+++ b/src/include/robin_hood.h -@@ -2519,7 +2519,8 @@ - using unordered_map = - detail::Table) <= sizeof(size_t) * 6 && - std::is_nothrow_move_constructible>::value && -- std::is_nothrow_move_assignable>::value, -+ std::is_nothrow_move_assignable>::value && -+ !std::is_same::value, - MaxLoadFactor100, Key, T, Hash, KeyEqual>; - - // set